Mirror of https://github.com/rembo10/headphones.git, synced 2026-01-09 14:48:07 -05:00

Revert "Migrate as much as possible to pip and requirements.txt"

This reverts commit 982594a4a5.
4 .gitignore vendored
@@ -171,7 +171,3 @@ _ReSharper*/
/logs
.project
.pydevproject
.vs/

#Python virtual environment
env/
33 .travis.yml
@@ -1,24 +1,21 @@
# Travis CI configuration file
# http://about.travis-ci.org/docs/

language: python

# Available Python versions:
# http://about.travis-ci.org/docs/user/ci-environment/#Python-VM-images
python:
- '2.6'
- '2.7'
- "2.6"
- "2.7"
# pylint 1.4 does not run under python 2.6
install:
- pip install -r requirements.txt
- pip install pylint
- pip install pyflakes
- pip install pep8
- pip install pyOpenSSL
- pip install pylint==1.3.1
- pip install pyflakes
- pip install pep8
script:
- pep8 headphones
- pylint --rcfile=pylintrc headphones
- pyflakes headphones
before_deploy:
- pip install -r requirements.txt -t lib/
- find . -name "*.pyc" -delete
- find ./lib/ -path "*/*\-info/*" -delete
- zip -r -9 headphones.zip data/ headphones/ init-scripts/ lib/ Headphones.py
deploy:
  provider: releases
  api_key:
    secure: <GITHUB OATH TOKEN HERE>
  file: headphones.zip
  on:
    tags: true
- nosetests headphones
88 lib/MultipartPostHandler.py Normal file
@@ -0,0 +1,88 @@
#!/usr/bin/python

####
# 06/2010 Nic Wolfe <nic@wolfeden.ca>
# 02/2006 Will Holcomb <wholcomb@gmail.com>
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#

import urllib
import urllib2
import mimetools, mimetypes
import os, sys

# Controls how sequences are uncoded. If true, elements may be given multiple values by
# assigning a sequence.
doseq = 1

class MultipartPostHandler(urllib2.BaseHandler):
    handler_order = urllib2.HTTPHandler.handler_order - 10 # needs to run first

    def http_request(self, request):
        data = request.get_data()
        if data is not None and type(data) != str:
            v_files = []
            v_vars = []
            try:
                for(key, value) in data.items():
                    if type(value) in (file, list, tuple):
                        v_files.append((key, value))
                    else:
                        v_vars.append((key, value))
            except TypeError:
                systype, value, traceback = sys.exc_info()
                raise TypeError, "not a valid non-string sequence or mapping object", traceback

            if len(v_files) == 0:
                data = urllib.urlencode(v_vars, doseq)
            else:
                boundary, data = MultipartPostHandler.multipart_encode(v_vars, v_files)
                contenttype = 'multipart/form-data; boundary=%s' % boundary
                if(request.has_header('Content-Type')
                   and request.get_header('Content-Type').find('multipart/form-data') != 0):
                    print "Replacing %s with %s" % (request.get_header('content-type'), 'multipart/form-data')
                request.add_unredirected_header('Content-Type', contenttype)

            request.add_data(data)
        return request

    @staticmethod
    def multipart_encode(vars, files, boundary = None, buffer = None):
        if boundary is None:
            boundary = mimetools.choose_boundary()
        if buffer is None:
            buffer = ''
        for(key, value) in vars:
            buffer += '--%s\r\n' % boundary
            buffer += 'Content-Disposition: form-data; name="%s"' % key
            buffer += '\r\n\r\n' + value + '\r\n'
        for(key, fd) in files:

            # allow them to pass in a file or a tuple with name & data
            if type(fd) == file:
                name_in = fd.name
                fd.seek(0)
                data_in = fd.read()
            elif type(fd) in (tuple, list):
                name_in, data_in = fd

            filename = os.path.basename(name_in)
            contenttype = mimetypes.guess_type(filename)[0] or 'application/octet-stream'
            buffer += '--%s\r\n' % boundary
            buffer += 'Content-Disposition: form-data; name="%s"; filename="%s"\r\n' % (key, filename)
            buffer += 'Content-Type: %s\r\n' % contenttype
            # buffer += 'Content-Length: %s\r\n' % file_size
            buffer += '\r\n' + data_in + '\r\n'
        buffer += '--%s--\r\n\r\n' % boundary
        return boundary, buffer

    https_request = http_request
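For context, a minimal usage sketch of the handler above (Python 2; not part of this commit, and the URL, filename, and import path are illustrative). Passing a dict whose values include file objects makes the handler encode the POST body as multipart/form-data:

# Hypothetical usage sketch -- not part of the diff.
import urllib2
from lib.MultipartPostHandler import MultipartPostHandler

opener = urllib2.build_opener(MultipartPostHandler)
params = {'name': 'value', 'upload': open('example.nzb', 'rb')}
# http_request() intercepts the request and multipart-encodes the dict
response = opener.open('http://example.com/upload', params)
print response.read()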
5 lib/apscheduler/__init__.py Normal file
@@ -0,0 +1,5 @@
version_info = (3, 0, 1)
version = '3.0.1'
release = '3.0.1'

__version__ = release # PEP 396
73 lib/apscheduler/events.py Normal file
@@ -0,0 +1,73 @@
__all__ = ('EVENT_SCHEDULER_START', 'EVENT_SCHEDULER_SHUTDOWN', 'EVENT_EXECUTOR_ADDED', 'EVENT_EXECUTOR_REMOVED',
           'EVENT_JOBSTORE_ADDED', 'EVENT_JOBSTORE_REMOVED', 'EVENT_ALL_JOBS_REMOVED', 'EVENT_JOB_ADDED',
           'EVENT_JOB_REMOVED', 'EVENT_JOB_MODIFIED', 'EVENT_JOB_EXECUTED', 'EVENT_JOB_ERROR', 'EVENT_JOB_MISSED',
           'SchedulerEvent', 'JobEvent', 'JobExecutionEvent')


EVENT_SCHEDULER_START = 1
EVENT_SCHEDULER_SHUTDOWN = 2
EVENT_EXECUTOR_ADDED = 4
EVENT_EXECUTOR_REMOVED = 8
EVENT_JOBSTORE_ADDED = 16
EVENT_JOBSTORE_REMOVED = 32
EVENT_ALL_JOBS_REMOVED = 64
EVENT_JOB_ADDED = 128
EVENT_JOB_REMOVED = 256
EVENT_JOB_MODIFIED = 512
EVENT_JOB_EXECUTED = 1024
EVENT_JOB_ERROR = 2048
EVENT_JOB_MISSED = 4096
EVENT_ALL = (EVENT_SCHEDULER_START | EVENT_SCHEDULER_SHUTDOWN | EVENT_JOBSTORE_ADDED | EVENT_JOBSTORE_REMOVED |
             EVENT_JOB_ADDED | EVENT_JOB_REMOVED | EVENT_JOB_MODIFIED | EVENT_JOB_EXECUTED |
             EVENT_JOB_ERROR | EVENT_JOB_MISSED)


class SchedulerEvent(object):
    """
    An event that concerns the scheduler itself.

    :ivar code: the type code of this event
    :ivar alias: alias of the job store or executor that was added or removed (if applicable)
    """

    def __init__(self, code, alias=None):
        super(SchedulerEvent, self).__init__()
        self.code = code
        self.alias = alias

    def __repr__(self):
        return '<%s (code=%d)>' % (self.__class__.__name__, self.code)


class JobEvent(SchedulerEvent):
    """
    An event that concerns a job.

    :ivar code: the type code of this event
    :ivar job_id: identifier of the job in question
    :ivar jobstore: alias of the job store containing the job in question
    """

    def __init__(self, code, job_id, jobstore):
        super(JobEvent, self).__init__(code)
        self.code = code
        self.job_id = job_id
        self.jobstore = jobstore


class JobExecutionEvent(JobEvent):
    """
    An event that concerns the execution of individual jobs.

    :ivar scheduled_run_time: the time when the job was scheduled to be run
    :ivar retval: the return value of the successfully executed job
    :ivar exception: the exception raised by the job
    :ivar traceback: a formatted traceback for the exception
    """

    def __init__(self, code, job_id, jobstore, scheduled_run_time, retval=None, exception=None, traceback=None):
        super(JobExecutionEvent, self).__init__(code, job_id, jobstore)
        self.scheduled_run_time = scheduled_run_time
        self.retval = retval
        self.exception = exception
        self.traceback = traceback
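A short sketch of how these event classes are typically consumed (not part of this commit; it assumes a scheduler instance with the standard add_listener API). Because the codes are distinct powers of two, a listener mask can OR several of them together:

# Hypothetical listener sketch -- not part of the diff.
from apscheduler.events import EVENT_JOB_EXECUTED, EVENT_JOB_ERROR

def job_listener(event):
    # JobExecutionEvent carries either a retval or an exception
    if event.exception:
        print 'job %s failed: %s' % (event.job_id, event.exception)
    else:
        print 'job %s returned %r' % (event.job_id, event.retval)

# scheduler.add_listener(job_listener, EVENT_JOB_EXECUTED | EVENT_JOB_ERROR)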
0 lib/apscheduler/executors/__init__.py Normal file
28 lib/apscheduler/executors/asyncio.py Normal file
@@ -0,0 +1,28 @@
from __future__ import absolute_import
import sys

from apscheduler.executors.base import BaseExecutor, run_job


class AsyncIOExecutor(BaseExecutor):
    """
    Runs jobs in the default executor of the event loop.

    Plugin alias: ``asyncio``
    """

    def start(self, scheduler, alias):
        super(AsyncIOExecutor, self).start(scheduler, alias)
        self._eventloop = scheduler._eventloop

    def _do_submit_job(self, job, run_times):
        def callback(f):
            try:
                events = f.result()
            except:
                self._run_job_error(job.id, *sys.exc_info()[1:])
            else:
                self._run_job_success(job.id, events)

        f = self._eventloop.run_in_executor(None, run_job, job, job._jobstore_alias, run_times, self._logger.name)
        f.add_done_callback(callback)
119 lib/apscheduler/executors/base.py Normal file
@@ -0,0 +1,119 @@
from abc import ABCMeta, abstractmethod
from collections import defaultdict
from datetime import datetime, timedelta
from traceback import format_tb
import logging
import sys

from pytz import utc
import six

from apscheduler.events import JobExecutionEvent, EVENT_JOB_MISSED, EVENT_JOB_ERROR, EVENT_JOB_EXECUTED


class MaxInstancesReachedError(Exception):
    def __init__(self, job):
        super(MaxInstancesReachedError, self).__init__(
            'Job "%s" has already reached its maximum number of instances (%d)' % (job.id, job.max_instances))


class BaseExecutor(six.with_metaclass(ABCMeta, object)):
    """Abstract base class that defines the interface that every executor must implement."""

    _scheduler = None
    _lock = None
    _logger = logging.getLogger('apscheduler.executors')

    def __init__(self):
        super(BaseExecutor, self).__init__()
        self._instances = defaultdict(lambda: 0)

    def start(self, scheduler, alias):
        """
        Called by the scheduler when the scheduler is being started or when the executor is being added to an already
        running scheduler.

        :param apscheduler.schedulers.base.BaseScheduler scheduler: the scheduler that is starting this executor
        :param str|unicode alias: alias of this executor as it was assigned to the scheduler
        """

        self._scheduler = scheduler
        self._lock = scheduler._create_lock()
        self._logger = logging.getLogger('apscheduler.executors.%s' % alias)

    def shutdown(self, wait=True):
        """
        Shuts down this executor.

        :param bool wait: ``True`` to wait until all submitted jobs have been executed
        """

    def submit_job(self, job, run_times):
        """
        Submits job for execution.

        :param Job job: job to execute
        :param list[datetime] run_times: list of datetimes specifying when the job should have been run
        :raises MaxInstancesReachedError: if the maximum number of allowed instances for this job has been reached
        """

        assert self._lock is not None, 'This executor has not been started yet'
        with self._lock:
            if self._instances[job.id] >= job.max_instances:
                raise MaxInstancesReachedError(job)

            self._do_submit_job(job, run_times)
            self._instances[job.id] += 1

    @abstractmethod
    def _do_submit_job(self, job, run_times):
        """Performs the actual task of scheduling `run_job` to be called."""

    def _run_job_success(self, job_id, events):
        """Called by the executor with the list of generated events when `run_job` has been successfully called."""

        with self._lock:
            self._instances[job_id] -= 1

        for event in events:
            self._scheduler._dispatch_event(event)

    def _run_job_error(self, job_id, exc, traceback=None):
        """Called by the executor with the exception if there is an error calling `run_job`."""

        with self._lock:
            self._instances[job_id] -= 1

        exc_info = (exc.__class__, exc, traceback)
        self._logger.error('Error running job %s', job_id, exc_info=exc_info)


def run_job(job, jobstore_alias, run_times, logger_name):
    """Called by executors to run the job. Returns a list of scheduler events to be dispatched by the scheduler."""

    events = []
    logger = logging.getLogger(logger_name)
    for run_time in run_times:
        # See if the job missed its run time window, and handle possible misfires accordingly
        if job.misfire_grace_time is not None:
            difference = datetime.now(utc) - run_time
            grace_time = timedelta(seconds=job.misfire_grace_time)
            if difference > grace_time:
                events.append(JobExecutionEvent(EVENT_JOB_MISSED, job.id, jobstore_alias, run_time))
                logger.warning('Run time of job "%s" was missed by %s', job, difference)
                continue

        logger.info('Running job "%s" (scheduled at %s)', job, run_time)
        try:
            retval = job.func(*job.args, **job.kwargs)
        except:
            exc, tb = sys.exc_info()[1:]
            formatted_tb = ''.join(format_tb(tb))
            events.append(JobExecutionEvent(EVENT_JOB_ERROR, job.id, jobstore_alias, run_time, exception=exc,
                                            traceback=formatted_tb))
            logger.exception('Job "%s" raised an exception', job)
        else:
            events.append(JobExecutionEvent(EVENT_JOB_EXECUTED, job.id, jobstore_alias, run_time, retval=retval))
            logger.info('Job "%s" executed successfully', job)

    return events
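A worked example of the misfire check in run_job above (illustrative values only): a run scheduled for 12:00:00 with misfire_grace_time=30 that is only picked up at 12:00:45 is 45 seconds late, which exceeds the grace window, so the run is skipped and an EVENT_JOB_MISSED event is emitted instead of the job being executed:

# Illustrative arithmetic only -- not part of the diff.
from datetime import datetime, timedelta

run_time = datetime(2014, 10, 1, 12, 0, 0)
now = datetime(2014, 10, 1, 12, 0, 45)
grace = timedelta(seconds=30)
assert (now - run_time) > grace  # late beyond the window: the run is missed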
19 lib/apscheduler/executors/debug.py Normal file
@@ -0,0 +1,19 @@
import sys

from apscheduler.executors.base import BaseExecutor, run_job


class DebugExecutor(BaseExecutor):
    """
    A special executor that executes the target callable directly instead of deferring it to a thread or process.

    Plugin alias: ``debug``
    """

    def _do_submit_job(self, job, run_times):
        try:
            events = run_job(job, job._jobstore_alias, run_times, self._logger.name)
        except:
            self._run_job_error(job.id, *sys.exc_info()[1:])
        else:
            self._run_job_success(job.id, events)
29 lib/apscheduler/executors/gevent.py Normal file
@@ -0,0 +1,29 @@
from __future__ import absolute_import
import sys

from apscheduler.executors.base import BaseExecutor, run_job


try:
    import gevent
except ImportError:  # pragma: nocover
    raise ImportError('GeventExecutor requires gevent installed')


class GeventExecutor(BaseExecutor):
    """
    Runs jobs as greenlets.

    Plugin alias: ``gevent``
    """

    def _do_submit_job(self, job, run_times):
        def callback(greenlet):
            try:
                events = greenlet.get()
            except:
                self._run_job_error(job.id, *sys.exc_info()[1:])
            else:
                self._run_job_success(job.id, events)

        gevent.spawn(run_job, job, job._jobstore_alias, run_times, self._logger.name).link(callback)
54 lib/apscheduler/executors/pool.py Normal file
@@ -0,0 +1,54 @@
from abc import abstractmethod
import concurrent.futures

from apscheduler.executors.base import BaseExecutor, run_job


class BasePoolExecutor(BaseExecutor):
    @abstractmethod
    def __init__(self, pool):
        super(BasePoolExecutor, self).__init__()
        self._pool = pool

    def _do_submit_job(self, job, run_times):
        def callback(f):
            exc, tb = (f.exception_info() if hasattr(f, 'exception_info') else
                       (f.exception(), getattr(f.exception(), '__traceback__', None)))
            if exc:
                self._run_job_error(job.id, exc, tb)
            else:
                self._run_job_success(job.id, f.result())

        f = self._pool.submit(run_job, job, job._jobstore_alias, run_times, self._logger.name)
        f.add_done_callback(callback)

    def shutdown(self, wait=True):
        self._pool.shutdown(wait)


class ThreadPoolExecutor(BasePoolExecutor):
    """
    An executor that runs jobs in a concurrent.futures thread pool.

    Plugin alias: ``threadpool``

    :param max_workers: the maximum number of spawned threads.
    """

    def __init__(self, max_workers=10):
        pool = concurrent.futures.ThreadPoolExecutor(int(max_workers))
        super(ThreadPoolExecutor, self).__init__(pool)


class ProcessPoolExecutor(BasePoolExecutor):
    """
    An executor that runs jobs in a concurrent.futures process pool.

    Plugin alias: ``processpool``

    :param max_workers: the maximum number of spawned processes.
    """

    def __init__(self, max_workers=10):
        pool = concurrent.futures.ProcessPoolExecutor(int(max_workers))
        super(ProcessPoolExecutor, self).__init__(pool)
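A configuration sketch for the two pool executors above (not part of this commit; the scheduler wiring is an assumption based on the executors keyword documented in schedulers/base.py further down):

# Hypothetical wiring sketch -- not part of the diff.
from apscheduler.executors.pool import ThreadPoolExecutor, ProcessPoolExecutor

executors = {
    'default': ThreadPoolExecutor(max_workers=20),      # I/O-bound jobs
    'processpool': ProcessPoolExecutor(max_workers=4),  # CPU-bound jobs
}
# scheduler = BackgroundScheduler(executors=executors)
# scheduler.add_job(crunch, 'interval', hours=1, executor='processpool')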
25 lib/apscheduler/executors/twisted.py Normal file
@@ -0,0 +1,25 @@
from __future__ import absolute_import

from apscheduler.executors.base import BaseExecutor, run_job


class TwistedExecutor(BaseExecutor):
    """
    Runs jobs in the reactor's thread pool.

    Plugin alias: ``twisted``
    """

    def start(self, scheduler, alias):
        super(TwistedExecutor, self).start(scheduler, alias)
        self._reactor = scheduler._reactor

    def _do_submit_job(self, job, run_times):
        def callback(success, result):
            if success:
                self._run_job_success(job.id, result)
            else:
                self._run_job_error(job.id, result.value, result.tb)

        self._reactor.getThreadPool().callInThreadWithCallback(callback, run_job, job, job._jobstore_alias, run_times,
                                                               self._logger.name)
252 lib/apscheduler/job.py Normal file
@@ -0,0 +1,252 @@
from collections import Iterable, Mapping
from uuid import uuid4

import six

from apscheduler.triggers.base import BaseTrigger
from apscheduler.util import ref_to_obj, obj_to_ref, datetime_repr, repr_escape, get_callable_name, check_callable_args, \
    convert_to_datetime


class Job(object):
    """
    Contains the options given when scheduling callables and its current schedule and other state.
    This class should never be instantiated by the user.

    :var str id: the unique identifier of this job
    :var str name: the description of this job
    :var func: the callable to execute
    :var tuple|list args: positional arguments to the callable
    :var dict kwargs: keyword arguments to the callable
    :var bool coalesce: whether to only run the job once when several run times are due
    :var trigger: the trigger object that controls the schedule of this job
    :var str executor: the name of the executor that will run this job
    :var int misfire_grace_time: the time (in seconds) how much this job's execution is allowed to be late
    :var int max_instances: the maximum number of concurrently executing instances allowed for this job
    :var datetime.datetime next_run_time: the next scheduled run time of this job
    """

    __slots__ = ('_scheduler', '_jobstore_alias', 'id', 'trigger', 'executor', 'func', 'func_ref', 'args', 'kwargs',
                 'name', 'misfire_grace_time', 'coalesce', 'max_instances', 'next_run_time')

    def __init__(self, scheduler, id=None, **kwargs):
        super(Job, self).__init__()
        self._scheduler = scheduler
        self._jobstore_alias = None
        self._modify(id=id or uuid4().hex, **kwargs)

    def modify(self, **changes):
        """
        Makes the given changes to this job and saves it in the associated job store.
        Accepted keyword arguments are the same as the variables on this class.

        .. seealso:: :meth:`~apscheduler.schedulers.base.BaseScheduler.modify_job`
        """

        self._scheduler.modify_job(self.id, self._jobstore_alias, **changes)

    def reschedule(self, trigger, **trigger_args):
        """
        Shortcut for switching the trigger on this job.

        .. seealso:: :meth:`~apscheduler.schedulers.base.BaseScheduler.reschedule_job`
        """

        self._scheduler.reschedule_job(self.id, self._jobstore_alias, trigger, **trigger_args)

    def pause(self):
        """
        Temporarily suspend the execution of this job.

        .. seealso:: :meth:`~apscheduler.schedulers.base.BaseScheduler.pause_job`
        """

        self._scheduler.pause_job(self.id, self._jobstore_alias)

    def resume(self):
        """
        Resume the schedule of this job if previously paused.

        .. seealso:: :meth:`~apscheduler.schedulers.base.BaseScheduler.resume_job`
        """

        self._scheduler.resume_job(self.id, self._jobstore_alias)

    def remove(self):
        """
        Unschedules this job and removes it from its associated job store.

        .. seealso:: :meth:`~apscheduler.schedulers.base.BaseScheduler.remove_job`
        """

        self._scheduler.remove_job(self.id, self._jobstore_alias)

    @property
    def pending(self):
        """Returns ``True`` if the referenced job is still waiting to be added to its designated job store."""

        return self._jobstore_alias is None

    #
    # Private API
    #

    def _get_run_times(self, now):
        """
        Computes the scheduled run times between ``next_run_time`` and ``now`` (inclusive).

        :type now: datetime.datetime
        :rtype: list[datetime.datetime]
        """

        run_times = []
        next_run_time = self.next_run_time
        while next_run_time and next_run_time <= now:
            run_times.append(next_run_time)
            next_run_time = self.trigger.get_next_fire_time(next_run_time, now)

        return run_times

    def _modify(self, **changes):
        """Validates the changes to the Job and makes the modifications if and only if all of them validate."""

        approved = {}

        if 'id' in changes:
            value = changes.pop('id')
            if not isinstance(value, six.string_types):
                raise TypeError("id must be a nonempty string")
            if hasattr(self, 'id'):
                raise ValueError('The job ID may not be changed')
            approved['id'] = value

        if 'func' in changes or 'args' in changes or 'kwargs' in changes:
            func = changes.pop('func') if 'func' in changes else self.func
            args = changes.pop('args') if 'args' in changes else self.args
            kwargs = changes.pop('kwargs') if 'kwargs' in changes else self.kwargs

            if isinstance(func, str):
                func_ref = func
                func = ref_to_obj(func)
            elif callable(func):
                try:
                    func_ref = obj_to_ref(func)
                except ValueError:
                    # If this happens, this Job won't be serializable
                    func_ref = None
            else:
                raise TypeError('func must be a callable or a textual reference to one')

            if not hasattr(self, 'name') and changes.get('name', None) is None:
                changes['name'] = get_callable_name(func)

            if isinstance(args, six.string_types) or not isinstance(args, Iterable):
                raise TypeError('args must be a non-string iterable')
            if isinstance(kwargs, six.string_types) or not isinstance(kwargs, Mapping):
                raise TypeError('kwargs must be a dict-like object')

            check_callable_args(func, args, kwargs)

            approved['func'] = func
            approved['func_ref'] = func_ref
            approved['args'] = args
            approved['kwargs'] = kwargs

        if 'name' in changes:
            value = changes.pop('name')
            if not value or not isinstance(value, six.string_types):
                raise TypeError("name must be a nonempty string")
            approved['name'] = value

        if 'misfire_grace_time' in changes:
            value = changes.pop('misfire_grace_time')
            if value is not None and (not isinstance(value, six.integer_types) or value <= 0):
                raise TypeError('misfire_grace_time must be either None or a positive integer')
            approved['misfire_grace_time'] = value

        if 'coalesce' in changes:
            value = bool(changes.pop('coalesce'))
            approved['coalesce'] = value

        if 'max_instances' in changes:
            value = changes.pop('max_instances')
            if not isinstance(value, six.integer_types) or value <= 0:
                raise TypeError('max_instances must be a positive integer')
            approved['max_instances'] = value

        if 'trigger' in changes:
            trigger = changes.pop('trigger')
            if not isinstance(trigger, BaseTrigger):
                raise TypeError('Expected a trigger instance, got %s instead' % trigger.__class__.__name__)

            approved['trigger'] = trigger

        if 'executor' in changes:
            value = changes.pop('executor')
            if not isinstance(value, six.string_types):
                raise TypeError('executor must be a string')
            approved['executor'] = value

        if 'next_run_time' in changes:
            value = changes.pop('next_run_time')
            approved['next_run_time'] = convert_to_datetime(value, self._scheduler.timezone, 'next_run_time')

        if changes:
            raise AttributeError('The following are not modifiable attributes of Job: %s' % ', '.join(changes))

        for key, value in six.iteritems(approved):
            setattr(self, key, value)

    def __getstate__(self):
        # Don't allow this Job to be serialized if the function reference could not be determined
        if not self.func_ref:
            raise ValueError('This Job cannot be serialized since the reference to its callable (%r) could not be '
                             'determined. Consider giving a textual reference (module:function name) instead.' %
                             (self.func,))

        return {
            'version': 1,
            'id': self.id,
            'func': self.func_ref,
            'trigger': self.trigger,
            'executor': self.executor,
            'args': self.args,
            'kwargs': self.kwargs,
            'name': self.name,
            'misfire_grace_time': self.misfire_grace_time,
            'coalesce': self.coalesce,
            'max_instances': self.max_instances,
            'next_run_time': self.next_run_time
        }

    def __setstate__(self, state):
        if state.get('version', 1) > 1:
            raise ValueError('Job has version %s, but only version 1 can be handled' % state['version'])

        self.id = state['id']
        self.func_ref = state['func']
        self.func = ref_to_obj(self.func_ref)
        self.trigger = state['trigger']
        self.executor = state['executor']
        self.args = state['args']
        self.kwargs = state['kwargs']
        self.name = state['name']
        self.misfire_grace_time = state['misfire_grace_time']
        self.coalesce = state['coalesce']
        self.max_instances = state['max_instances']
        self.next_run_time = state['next_run_time']

    def __eq__(self, other):
        if isinstance(other, Job):
            return self.id == other.id
        return NotImplemented

    def __repr__(self):
        return '<Job (id=%s name=%s)>' % (repr_escape(self.id), repr_escape(self.name))

    def __str__(self):
        return '%s (trigger: %s, next run at: %s)' % (repr_escape(self.name), repr_escape(str(self.trigger)),
                                                      datetime_repr(self.next_run_time))

    def __unicode__(self):
        return six.u('%s (trigger: %s, next run at: %s)') % (self.name, self.trigger, datetime_repr(self.next_run_time))
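A sketch of the public Job API above in use (not part of this commit; scheduler and myfunc are assumed, and the trigger-alias strings rely on the scheduler resolving them):

# Hypothetical usage sketch -- not part of the diff.
job = scheduler.add_job(myfunc, 'interval', minutes=2)
job.modify(max_instances=6, name='Alternate name')  # proxied to modify_job()
job.reschedule('cron', minute='*/5')                # swap the trigger in place
job.pause()
job.resume()
job.remove()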
0 lib/apscheduler/jobstores/__init__.py Normal file
127 lib/apscheduler/jobstores/base.py Normal file
@@ -0,0 +1,127 @@
from abc import ABCMeta, abstractmethod
import logging

import six


class JobLookupError(KeyError):
    """Raised when the job store cannot find a job for update or removal."""

    def __init__(self, job_id):
        super(JobLookupError, self).__init__(six.u('No job by the id of %s was found') % job_id)


class ConflictingIdError(KeyError):
    """Raised when the uniqueness of job IDs is being violated."""

    def __init__(self, job_id):
        super(ConflictingIdError, self).__init__(six.u('Job identifier (%s) conflicts with an existing job') % job_id)


class TransientJobError(ValueError):
    """Raised when an attempt to add transient (with no func_ref) job to a persistent job store is detected."""

    def __init__(self, job_id):
        super(TransientJobError, self).__init__(
            six.u('Job (%s) cannot be added to this job store because a reference to the callable could not be '
                  'determined.') % job_id)


class BaseJobStore(six.with_metaclass(ABCMeta)):
    """Abstract base class that defines the interface that every job store must implement."""

    _scheduler = None
    _alias = None
    _logger = logging.getLogger('apscheduler.jobstores')

    def start(self, scheduler, alias):
        """
        Called by the scheduler when the scheduler is being started or when the job store is being added to an already
        running scheduler.

        :param apscheduler.schedulers.base.BaseScheduler scheduler: the scheduler that is starting this job store
        :param str|unicode alias: alias of this job store as it was assigned to the scheduler
        """

        self._scheduler = scheduler
        self._alias = alias
        self._logger = logging.getLogger('apscheduler.jobstores.%s' % alias)

    def shutdown(self):
        """Frees any resources still bound to this job store."""

    @abstractmethod
    def lookup_job(self, job_id):
        """
        Returns a specific job, or ``None`` if it isn't found.

        The job store is responsible for setting the ``scheduler`` and ``jobstore`` attributes of the returned job to
        point to the scheduler and itself, respectively.

        :param str|unicode job_id: identifier of the job
        :rtype: Job
        """

    @abstractmethod
    def get_due_jobs(self, now):
        """
        Returns the list of jobs that have ``next_run_time`` earlier or equal to ``now``.
        The returned jobs must be sorted by next run time (ascending).

        :param datetime.datetime now: the current (timezone aware) datetime
        :rtype: list[Job]
        """

    @abstractmethod
    def get_next_run_time(self):
        """
        Returns the earliest run time of all the jobs stored in this job store, or ``None`` if there are no active jobs.

        :rtype: datetime.datetime
        """

    @abstractmethod
    def get_all_jobs(self):
        """
        Returns a list of all jobs in this job store. The returned jobs should be sorted by next run time (ascending).
        Paused jobs (next_run_time is None) should be sorted last.

        The job store is responsible for setting the ``scheduler`` and ``jobstore`` attributes of the returned jobs to
        point to the scheduler and itself, respectively.

        :rtype: list[Job]
        """

    @abstractmethod
    def add_job(self, job):
        """
        Adds the given job to this store.

        :param Job job: the job to add
        :raises ConflictingIdError: if there is another job in this store with the same ID
        """

    @abstractmethod
    def update_job(self, job):
        """
        Replaces the job in the store with the given newer version.

        :param Job job: the job to update
        :raises JobLookupError: if the job does not exist
        """

    @abstractmethod
    def remove_job(self, job_id):
        """
        Removes the given job from this store.

        :param str|unicode job_id: identifier of the job
        :raises JobLookupError: if the job does not exist
        """

    @abstractmethod
    def remove_all_jobs(self):
        """Removes all jobs from this store."""

    def __repr__(self):
        return '<%s>' % self.__class__.__name__
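Since JobLookupError and ConflictingIdError subclass KeyError, callers can guard store operations conventionally (sketch only; store is an assumed BaseJobStore instance):

# Hypothetical sketch -- not part of the diff.
from apscheduler.jobstores.base import JobLookupError

try:
    store.remove_job('no-such-id')
except JobLookupError:
    pass  # the job was already gone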
107 lib/apscheduler/jobstores/memory.py Normal file
@@ -0,0 +1,107 @@
from __future__ import absolute_import

from apscheduler.jobstores.base import BaseJobStore, JobLookupError, ConflictingIdError
from apscheduler.util import datetime_to_utc_timestamp


class MemoryJobStore(BaseJobStore):
    """
    Stores jobs in an array in RAM. Provides no persistence support.

    Plugin alias: ``memory``
    """

    def __init__(self):
        super(MemoryJobStore, self).__init__()
        self._jobs = []  # list of (job, timestamp), sorted by next_run_time and job id (ascending)
        self._jobs_index = {}  # id -> (job, timestamp) lookup table

    def lookup_job(self, job_id):
        return self._jobs_index.get(job_id, (None, None))[0]

    def get_due_jobs(self, now):
        now_timestamp = datetime_to_utc_timestamp(now)
        pending = []
        for job, timestamp in self._jobs:
            if timestamp is None or timestamp > now_timestamp:
                break
            pending.append(job)

        return pending

    def get_next_run_time(self):
        return self._jobs[0][0].next_run_time if self._jobs else None

    def get_all_jobs(self):
        return [j[0] for j in self._jobs]

    def add_job(self, job):
        if job.id in self._jobs_index:
            raise ConflictingIdError(job.id)

        timestamp = datetime_to_utc_timestamp(job.next_run_time)
        index = self._get_job_index(timestamp, job.id)
        self._jobs.insert(index, (job, timestamp))
        self._jobs_index[job.id] = (job, timestamp)

    def update_job(self, job):
        old_job, old_timestamp = self._jobs_index.get(job.id, (None, None))
        if old_job is None:
            raise JobLookupError(job.id)

        # If the next run time has not changed, simply replace the job in its present index.
        # Otherwise, reinsert the job to the list to preserve the ordering.
        old_index = self._get_job_index(old_timestamp, old_job.id)
        new_timestamp = datetime_to_utc_timestamp(job.next_run_time)
        if old_timestamp == new_timestamp:
            self._jobs[old_index] = (job, new_timestamp)
        else:
            del self._jobs[old_index]
            new_index = self._get_job_index(new_timestamp, job.id)
            self._jobs.insert(new_index, (job, new_timestamp))

        self._jobs_index[old_job.id] = (job, new_timestamp)

    def remove_job(self, job_id):
        job, timestamp = self._jobs_index.get(job_id, (None, None))
        if job is None:
            raise JobLookupError(job_id)

        index = self._get_job_index(timestamp, job_id)
        del self._jobs[index]
        del self._jobs_index[job.id]

    def remove_all_jobs(self):
        self._jobs = []
        self._jobs_index = {}

    def shutdown(self):
        self.remove_all_jobs()

    def _get_job_index(self, timestamp, job_id):
        """
        Returns the index of the given job, or if it's not found, the index where the job should be inserted based on
        the given timestamp.

        :type timestamp: int
        :type job_id: str
        """

        lo, hi = 0, len(self._jobs)
        timestamp = float('inf') if timestamp is None else timestamp
        while lo < hi:
            mid = (lo + hi) // 2
            mid_job, mid_timestamp = self._jobs[mid]
            mid_timestamp = float('inf') if mid_timestamp is None else mid_timestamp
            if mid_timestamp > timestamp:
                hi = mid
            elif mid_timestamp < timestamp:
                lo = mid + 1
            elif mid_job.id > job_id:
                hi = mid
            elif mid_job.id < job_id:
                lo = mid + 1
            else:
                return mid

        return lo
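The hand-rolled bisection in _get_job_index above orders entries by the (timestamp, job id) pair, with paused jobs (None timestamps) treated as +infinity so they sort last. A sketch of the same ordering via the standard library, on illustrative data:

# Equivalence sketch -- not part of the diff.
from bisect import bisect_left

entries = [(5.0, 'a'), (5.0, 'c'), (float('inf'), 'z')]
assert bisect_left(entries, (5.0, 'b')) == 1  # 'b' slots between 'a' and 'c'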
124 lib/apscheduler/jobstores/mongodb.py Normal file
@@ -0,0 +1,124 @@
from __future__ import absolute_import

from apscheduler.jobstores.base import BaseJobStore, JobLookupError, ConflictingIdError
from apscheduler.util import maybe_ref, datetime_to_utc_timestamp, utc_timestamp_to_datetime
from apscheduler.job import Job

try:
    import cPickle as pickle
except ImportError:  # pragma: nocover
    import pickle

try:
    from bson.binary import Binary
    from pymongo.errors import DuplicateKeyError
    from pymongo import MongoClient, ASCENDING
except ImportError:  # pragma: nocover
    raise ImportError('MongoDBJobStore requires PyMongo installed')


class MongoDBJobStore(BaseJobStore):
    """
    Stores jobs in a MongoDB database. Any leftover keyword arguments are directly passed to pymongo's `MongoClient
    <http://api.mongodb.org/python/current/api/pymongo/mongo_client.html#pymongo.mongo_client.MongoClient>`_.

    Plugin alias: ``mongodb``

    :param str database: database to store jobs in
    :param str collection: collection to store jobs in
    :param client: a :class:`~pymongo.mongo_client.MongoClient` instance to use instead of providing connection
                   arguments
    :param int pickle_protocol: pickle protocol level to use (for serialization), defaults to the highest available
    """

    def __init__(self, database='apscheduler', collection='jobs', client=None,
                 pickle_protocol=pickle.HIGHEST_PROTOCOL, **connect_args):
        super(MongoDBJobStore, self).__init__()
        self.pickle_protocol = pickle_protocol

        if not database:
            raise ValueError('The "database" parameter must not be empty')
        if not collection:
            raise ValueError('The "collection" parameter must not be empty')

        if client:
            self.connection = maybe_ref(client)
        else:
            connect_args.setdefault('w', 1)
            self.connection = MongoClient(**connect_args)

        self.collection = self.connection[database][collection]
        self.collection.ensure_index('next_run_time', sparse=True)

    def lookup_job(self, job_id):
        document = self.collection.find_one(job_id, ['job_state'])
        return self._reconstitute_job(document['job_state']) if document else None

    def get_due_jobs(self, now):
        timestamp = datetime_to_utc_timestamp(now)
        return self._get_jobs({'next_run_time': {'$lte': timestamp}})

    def get_next_run_time(self):
        document = self.collection.find_one({'next_run_time': {'$ne': None}}, fields=['next_run_time'],
                                            sort=[('next_run_time', ASCENDING)])
        return utc_timestamp_to_datetime(document['next_run_time']) if document else None

    def get_all_jobs(self):
        return self._get_jobs({})

    def add_job(self, job):
        try:
            self.collection.insert({
                '_id': job.id,
                'next_run_time': datetime_to_utc_timestamp(job.next_run_time),
                'job_state': Binary(pickle.dumps(job.__getstate__(), self.pickle_protocol))
            })
        except DuplicateKeyError:
            raise ConflictingIdError(job.id)

    def update_job(self, job):
        changes = {
            'next_run_time': datetime_to_utc_timestamp(job.next_run_time),
            'job_state': Binary(pickle.dumps(job.__getstate__(), self.pickle_protocol))
        }
        result = self.collection.update({'_id': job.id}, {'$set': changes})
        if result and result['n'] == 0:
            raise JobLookupError(id)

    def remove_job(self, job_id):
        result = self.collection.remove(job_id)
        if result and result['n'] == 0:
            raise JobLookupError(job_id)

    def remove_all_jobs(self):
        self.collection.remove()

    def shutdown(self):
        self.connection.disconnect()

    def _reconstitute_job(self, job_state):
        job_state = pickle.loads(job_state)
        job = Job.__new__(Job)
        job.__setstate__(job_state)
        job._scheduler = self._scheduler
        job._jobstore_alias = self._alias
        return job

    def _get_jobs(self, conditions):
        jobs = []
        failed_job_ids = []
        for document in self.collection.find(conditions, ['_id', 'job_state'], sort=[('next_run_time', ASCENDING)]):
            try:
                jobs.append(self._reconstitute_job(document['job_state']))
            except:
                self._logger.exception('Unable to restore job "%s" -- removing it', document['_id'])
                failed_job_ids.append(document['_id'])

        # Remove all the jobs we failed to restore
        if failed_job_ids:
            self.collection.remove({'_id': {'$in': failed_job_ids}})

        return jobs

    def __repr__(self):
        return '<%s (client=%s)>' % (self.__class__.__name__, self.connection)
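A construction sketch for the store above (not part of this commit; the host name is illustrative). Leftover keyword arguments flow straight into MongoClient, so connection details can be given directly:

# Hypothetical construction sketch -- not part of the diff.
from apscheduler.jobstores.mongodb import MongoDBJobStore

store = MongoDBJobStore(database='apscheduler', collection='jobs',
                        host='mongo.example.org', port=27017)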
138 lib/apscheduler/jobstores/redis.py Normal file
@@ -0,0 +1,138 @@
from __future__ import absolute_import

import six

from apscheduler.jobstores.base import BaseJobStore, JobLookupError, ConflictingIdError
from apscheduler.util import datetime_to_utc_timestamp, utc_timestamp_to_datetime
from apscheduler.job import Job

try:
    import cPickle as pickle
except ImportError:  # pragma: nocover
    import pickle

try:
    from redis import StrictRedis
except ImportError:  # pragma: nocover
    raise ImportError('RedisJobStore requires redis installed')


class RedisJobStore(BaseJobStore):
    """
    Stores jobs in a Redis database. Any leftover keyword arguments are directly passed to redis's StrictRedis.

    Plugin alias: ``redis``

    :param int db: the database number to store jobs in
    :param str jobs_key: key to store jobs in
    :param str run_times_key: key to store the jobs' run times in
    :param int pickle_protocol: pickle protocol level to use (for serialization), defaults to the highest available
    """

    def __init__(self, db=0, jobs_key='apscheduler.jobs', run_times_key='apscheduler.run_times',
                 pickle_protocol=pickle.HIGHEST_PROTOCOL, **connect_args):
        super(RedisJobStore, self).__init__()

        if db is None:
            raise ValueError('The "db" parameter must not be empty')
        if not jobs_key:
            raise ValueError('The "jobs_key" parameter must not be empty')
        if not run_times_key:
            raise ValueError('The "run_times_key" parameter must not be empty')

        self.pickle_protocol = pickle_protocol
        self.jobs_key = jobs_key
        self.run_times_key = run_times_key
        self.redis = StrictRedis(db=int(db), **connect_args)

    def lookup_job(self, job_id):
        job_state = self.redis.hget(self.jobs_key, job_id)
        return self._reconstitute_job(job_state) if job_state else None

    def get_due_jobs(self, now):
        timestamp = datetime_to_utc_timestamp(now)
        job_ids = self.redis.zrangebyscore(self.run_times_key, 0, timestamp)
        if job_ids:
            job_states = self.redis.hmget(self.jobs_key, *job_ids)
            return self._reconstitute_jobs(six.moves.zip(job_ids, job_states))
        return []

    def get_next_run_time(self):
        next_run_time = self.redis.zrange(self.run_times_key, 0, 0, withscores=True)
        if next_run_time:
            return utc_timestamp_to_datetime(next_run_time[0][1])

    def get_all_jobs(self):
        job_states = self.redis.hgetall(self.jobs_key)
        jobs = self._reconstitute_jobs(six.iteritems(job_states))
        return sorted(jobs, key=lambda job: job.next_run_time)

    def add_job(self, job):
        if self.redis.hexists(self.jobs_key, job.id):
            raise ConflictingIdError(job.id)

        with self.redis.pipeline() as pipe:
            pipe.multi()
            pipe.hset(self.jobs_key, job.id, pickle.dumps(job.__getstate__(), self.pickle_protocol))
            pipe.zadd(self.run_times_key, datetime_to_utc_timestamp(job.next_run_time), job.id)
            pipe.execute()

    def update_job(self, job):
        if not self.redis.hexists(self.jobs_key, job.id):
            raise JobLookupError(job.id)

        with self.redis.pipeline() as pipe:
            pipe.hset(self.jobs_key, job.id, pickle.dumps(job.__getstate__(), self.pickle_protocol))
            if job.next_run_time:
                pipe.zadd(self.run_times_key, datetime_to_utc_timestamp(job.next_run_time), job.id)
            else:
                pipe.zrem(self.run_times_key, job.id)
            pipe.execute()

    def remove_job(self, job_id):
        if not self.redis.hexists(self.jobs_key, job_id):
            raise JobLookupError(job_id)

        with self.redis.pipeline() as pipe:
            pipe.hdel(self.jobs_key, job_id)
            pipe.zrem(self.run_times_key, job_id)
            pipe.execute()

    def remove_all_jobs(self):
        with self.redis.pipeline() as pipe:
            pipe.delete(self.jobs_key)
            pipe.delete(self.run_times_key)
            pipe.execute()

    def shutdown(self):
        self.redis.connection_pool.disconnect()

    def _reconstitute_job(self, job_state):
        job_state = pickle.loads(job_state)
        job = Job.__new__(Job)
        job.__setstate__(job_state)
        job._scheduler = self._scheduler
        job._jobstore_alias = self._alias
        return job

    def _reconstitute_jobs(self, job_states):
        jobs = []
        failed_job_ids = []
        for job_id, job_state in job_states:
            try:
                jobs.append(self._reconstitute_job(job_state))
            except:
                self._logger.exception('Unable to restore job "%s" -- removing it', job_id)
                failed_job_ids.append(job_id)

        # Remove all the jobs we failed to restore
        if failed_job_ids:
            with self.redis.pipeline() as pipe:
                pipe.hdel(self.jobs_key, *failed_job_ids)
                pipe.zrem(self.run_times_key, *failed_job_ids)
                pipe.execute()

        return jobs

    def __repr__(self):
        return '<%s>' % self.__class__.__name__
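Analogously for the Redis store (sketch only, host name illustrative; extra kwargs go straight to StrictRedis per the docstring above):

# Hypothetical construction sketch -- not part of the diff.
from apscheduler.jobstores.redis import RedisJobStore

store = RedisJobStore(db=2, host='redis.example.org', port=6379)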
137 lib/apscheduler/jobstores/sqlalchemy.py Normal file
@@ -0,0 +1,137 @@
from __future__ import absolute_import

from apscheduler.jobstores.base import BaseJobStore, JobLookupError, ConflictingIdError
from apscheduler.util import maybe_ref, datetime_to_utc_timestamp, utc_timestamp_to_datetime
from apscheduler.job import Job

try:
    import cPickle as pickle
except ImportError:  # pragma: nocover
    import pickle

try:
    from sqlalchemy import create_engine, Table, Column, MetaData, Unicode, Float, LargeBinary, select
    from sqlalchemy.exc import IntegrityError
except ImportError:  # pragma: nocover
    raise ImportError('SQLAlchemyJobStore requires SQLAlchemy installed')


class SQLAlchemyJobStore(BaseJobStore):
    """
    Stores jobs in a database table using SQLAlchemy. The table will be created if it doesn't exist in the database.

    Plugin alias: ``sqlalchemy``

    :param str url: connection string (see `SQLAlchemy documentation
                    <http://docs.sqlalchemy.org/en/latest/core/engines.html?highlight=create_engine#database-urls>`_
                    on this)
    :param engine: an SQLAlchemy Engine to use instead of creating a new one based on ``url``
    :param str tablename: name of the table to store jobs in
    :param metadata: a :class:`~sqlalchemy.MetaData` instance to use instead of creating a new one
    :param int pickle_protocol: pickle protocol level to use (for serialization), defaults to the highest available
    """

    def __init__(self, url=None, engine=None, tablename='apscheduler_jobs', metadata=None,
                 pickle_protocol=pickle.HIGHEST_PROTOCOL):
        super(SQLAlchemyJobStore, self).__init__()
        self.pickle_protocol = pickle_protocol
        metadata = maybe_ref(metadata) or MetaData()

        if engine:
            self.engine = maybe_ref(engine)
        elif url:
            self.engine = create_engine(url)
        else:
            raise ValueError('Need either "engine" or "url" defined')

        # 191 = max key length in MySQL for InnoDB/utf8mb4 tables, 25 = precision that translates to an 8-byte float
        self.jobs_t = Table(
            tablename, metadata,
            Column('id', Unicode(191, _warn_on_bytestring=False), primary_key=True),
            Column('next_run_time', Float(25), index=True),
            Column('job_state', LargeBinary, nullable=False)
        )

        self.jobs_t.create(self.engine, True)

    def lookup_job(self, job_id):
        selectable = select([self.jobs_t.c.job_state]).where(self.jobs_t.c.id == job_id)
        job_state = self.engine.execute(selectable).scalar()
        return self._reconstitute_job(job_state) if job_state else None

    def get_due_jobs(self, now):
        timestamp = datetime_to_utc_timestamp(now)
        return self._get_jobs(self.jobs_t.c.next_run_time <= timestamp)

    def get_next_run_time(self):
        selectable = select([self.jobs_t.c.next_run_time]).where(self.jobs_t.c.next_run_time != None).\
            order_by(self.jobs_t.c.next_run_time).limit(1)
        next_run_time = self.engine.execute(selectable).scalar()
        return utc_timestamp_to_datetime(next_run_time)

    def get_all_jobs(self):
        return self._get_jobs()

    def add_job(self, job):
        insert = self.jobs_t.insert().values(**{
            'id': job.id,
            'next_run_time': datetime_to_utc_timestamp(job.next_run_time),
            'job_state': pickle.dumps(job.__getstate__(), self.pickle_protocol)
        })
        try:
            self.engine.execute(insert)
        except IntegrityError:
            raise ConflictingIdError(job.id)

    def update_job(self, job):
        update = self.jobs_t.update().values(**{
            'next_run_time': datetime_to_utc_timestamp(job.next_run_time),
            'job_state': pickle.dumps(job.__getstate__(), self.pickle_protocol)
        }).where(self.jobs_t.c.id == job.id)
        result = self.engine.execute(update)
        if result.rowcount == 0:
            raise JobLookupError(id)

    def remove_job(self, job_id):
        delete = self.jobs_t.delete().where(self.jobs_t.c.id == job_id)
        result = self.engine.execute(delete)
        if result.rowcount == 0:
            raise JobLookupError(job_id)

    def remove_all_jobs(self):
        delete = self.jobs_t.delete()
        self.engine.execute(delete)

    def shutdown(self):
        self.engine.dispose()

    def _reconstitute_job(self, job_state):
        job_state = pickle.loads(job_state)
        job_state['jobstore'] = self
        job = Job.__new__(Job)
        job.__setstate__(job_state)
        job._scheduler = self._scheduler
        job._jobstore_alias = self._alias
        return job

    def _get_jobs(self, *conditions):
        jobs = []
        selectable = select([self.jobs_t.c.id, self.jobs_t.c.job_state]).order_by(self.jobs_t.c.next_run_time)
        selectable = selectable.where(*conditions) if conditions else selectable
        failed_job_ids = set()
        for row in self.engine.execute(selectable):
            try:
                jobs.append(self._reconstitute_job(row.job_state))
            except:
                self._logger.exception('Unable to restore job "%s" -- removing it', row.id)
                failed_job_ids.add(row.id)

        # Remove all the jobs we failed to restore
        if failed_job_ids:
            delete = self.jobs_t.delete().where(self.jobs_t.c.id.in_(failed_job_ids))
            self.engine.execute(delete)

        return jobs

    def __repr__(self):
        return '<%s (url=%s)>' % (self.__class__.__name__, self.engine.url)
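And for the SQLAlchemy store (sketch only): per the constructor above, either a URL or a ready-made engine works; with a URL the engine is created internally:

# Hypothetical construction sketch -- not part of the diff.
from apscheduler.jobstores.sqlalchemy import SQLAlchemyJobStore

store = SQLAlchemyJobStore(url='sqlite:///jobs.sqlite')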
12 lib/apscheduler/schedulers/__init__.py Normal file
@@ -0,0 +1,12 @@
class SchedulerAlreadyRunningError(Exception):
    """Raised when attempting to start or configure the scheduler when it's already running."""

    def __str__(self):
        return 'Scheduler is already running'


class SchedulerNotRunningError(Exception):
    """Raised when attempting to shutdown the scheduler when it's not running."""

    def __str__(self):
        return 'Scheduler is not running'
68 lib/apscheduler/schedulers/asyncio.py Normal file
@@ -0,0 +1,68 @@
from __future__ import absolute_import
from functools import wraps

from apscheduler.schedulers.base import BaseScheduler
from apscheduler.util import maybe_ref

try:
    import asyncio
except ImportError:  # pragma: nocover
    try:
        import trollius as asyncio
    except ImportError:
        raise ImportError('AsyncIOScheduler requires either Python 3.4 or the asyncio package installed')


def run_in_event_loop(func):
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        self._eventloop.call_soon_threadsafe(func, self, *args, **kwargs)
    return wrapper


class AsyncIOScheduler(BaseScheduler):
    """
    A scheduler that runs on an asyncio (:pep:`3156`) event loop.

    Extra options:

    ============== =============================================================
    ``event_loop`` AsyncIO event loop to use (defaults to the global event loop)
    ============== =============================================================
    """

    _eventloop = None
    _timeout = None

    def start(self):
        super(AsyncIOScheduler, self).start()
        self.wakeup()

    @run_in_event_loop
    def shutdown(self, wait=True):
        super(AsyncIOScheduler, self).shutdown(wait)
        self._stop_timer()

    def _configure(self, config):
        self._eventloop = maybe_ref(config.pop('event_loop', None)) or asyncio.get_event_loop()
        super(AsyncIOScheduler, self)._configure(config)

    def _start_timer(self, wait_seconds):
        self._stop_timer()
        if wait_seconds is not None:
            self._timeout = self._eventloop.call_later(wait_seconds, self.wakeup)

    def _stop_timer(self):
        if self._timeout:
            self._timeout.cancel()
            del self._timeout

    @run_in_event_loop
    def wakeup(self):
        self._stop_timer()
        wait_seconds = self._process_jobs()
        self._start_timer(wait_seconds)

    def _create_default_executor(self):
        from apscheduler.executors.asyncio import AsyncIOExecutor
        return AsyncIOExecutor()
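A start-up sketch for AsyncIOScheduler (not part of this commit): start() only schedules work on the loop via call_soon_threadsafe, so the caller still has to run the loop itself:

# Hypothetical usage sketch -- not part of the diff.
import asyncio
from apscheduler.schedulers.asyncio import AsyncIOScheduler

scheduler = AsyncIOScheduler()
scheduler.start()  # wakes up through the event loop
asyncio.get_event_loop().run_forever()  # the loop drives the timers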
39 lib/apscheduler/schedulers/background.py Normal file
@@ -0,0 +1,39 @@
|
||||
from __future__ import absolute_import
|
||||
from threading import Thread, Event
|
||||
|
||||
from apscheduler.schedulers.base import BaseScheduler
|
||||
from apscheduler.schedulers.blocking import BlockingScheduler
|
||||
from apscheduler.util import asbool
|
||||
|
||||
|
||||
class BackgroundScheduler(BlockingScheduler):
|
||||
"""
|
||||
A scheduler that runs in the background using a separate thread
|
||||
(:meth:`~apscheduler.schedulers.base.BaseScheduler.start` will return immediately).
|
||||
|
||||
Extra options:
|
||||
|
||||
========== ============================================================================================
|
||||
``daemon`` Set the ``daemon`` option in the background thread (defaults to ``True``,
|
||||
see `the documentation <https://docs.python.org/3.4/library/threading.html#thread-objects>`_
|
||||
for further details)
|
||||
========== ============================================================================================
|
||||
"""
|
||||
|
||||
_thread = None
|
||||
|
||||
def _configure(self, config):
|
||||
self._daemon = asbool(config.pop('daemon', True))
|
||||
super(BackgroundScheduler, self)._configure(config)
|
||||
|
||||
def start(self):
|
||||
BaseScheduler.start(self)
|
||||
self._event = Event()
|
||||
self._thread = Thread(target=self._main_loop, name='APScheduler')
|
||||
self._thread.daemon = self._daemon
|
||||
self._thread.start()
|
||||
|
||||
def shutdown(self, wait=True):
|
||||
super(BackgroundScheduler, self).shutdown(wait)
|
||||
self._thread.join()
|
||||
del self._thread
|
||||
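Not part of the diff: a sketch of how this class is typically driven, assuming a trivial job. Since start() returns immediately, the main thread must be kept alive.

import time
from apscheduler.schedulers.background import BackgroundScheduler

def tick():
    print('tick')  # hypothetical job

scheduler = BackgroundScheduler(daemon=True)
scheduler.add_job(tick, 'interval', seconds=5)
scheduler.start()      # returns at once; jobs run on the 'APScheduler' thread
try:
    while True:
        time.sleep(1)  # keep the main thread alive
finally:
    scheduler.shutdown()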
845
lib/apscheduler/schedulers/base.py
Normal file
@@ -0,0 +1,845 @@
from __future__ import print_function
from abc import ABCMeta, abstractmethod
from collections import MutableMapping
from threading import RLock
from datetime import datetime
from logging import getLogger
import sys

from pkg_resources import iter_entry_points
from tzlocal import get_localzone
import six

from apscheduler.schedulers import SchedulerAlreadyRunningError, SchedulerNotRunningError
from apscheduler.executors.base import MaxInstancesReachedError, BaseExecutor
from apscheduler.executors.pool import ThreadPoolExecutor
from apscheduler.jobstores.base import ConflictingIdError, JobLookupError, BaseJobStore
from apscheduler.jobstores.memory import MemoryJobStore
from apscheduler.job import Job
from apscheduler.triggers.base import BaseTrigger
from apscheduler.util import asbool, asint, astimezone, maybe_ref, timedelta_seconds, undefined
from apscheduler.events import (
    SchedulerEvent, JobEvent, EVENT_SCHEDULER_START, EVENT_SCHEDULER_SHUTDOWN, EVENT_JOBSTORE_ADDED,
    EVENT_JOBSTORE_REMOVED, EVENT_ALL, EVENT_JOB_MODIFIED, EVENT_JOB_REMOVED, EVENT_JOB_ADDED, EVENT_EXECUTOR_ADDED,
    EVENT_EXECUTOR_REMOVED, EVENT_ALL_JOBS_REMOVED)


class BaseScheduler(six.with_metaclass(ABCMeta)):
    """
    Abstract base class for all schedulers. Takes the following keyword arguments:

    :param str|logging.Logger logger: logger to use for the scheduler's logging (defaults to apscheduler.scheduler)
    :param str|datetime.tzinfo timezone: the default time zone (defaults to the local timezone)
    :param dict job_defaults: default values for newly added jobs
    :param dict jobstores: a dictionary of job store alias -> job store instance or configuration dict
    :param dict executors: a dictionary of executor alias -> executor instance or configuration dict

    .. seealso:: :ref:`scheduler-config`
    """

    _trigger_plugins = dict((ep.name, ep) for ep in iter_entry_points('apscheduler.triggers'))
    _trigger_classes = {}
    _executor_plugins = dict((ep.name, ep) for ep in iter_entry_points('apscheduler.executors'))
    _executor_classes = {}
    _jobstore_plugins = dict((ep.name, ep) for ep in iter_entry_points('apscheduler.jobstores'))
    _jobstore_classes = {}
    _stopped = True

    #
    # Public API
    #

    def __init__(self, gconfig={}, **options):
        super(BaseScheduler, self).__init__()
        self._executors = {}
        self._executors_lock = self._create_lock()
        self._jobstores = {}
        self._jobstores_lock = self._create_lock()
        self._listeners = []
        self._listeners_lock = self._create_lock()
        self._pending_jobs = []
        self.configure(gconfig, **options)

    def configure(self, gconfig={}, prefix='apscheduler.', **options):
        """
        Reconfigures the scheduler with the given options. Can only be done when the scheduler isn't running.

        :param dict gconfig: a "global" configuration dictionary whose values can be overridden by keyword arguments to
            this method
        :param str|unicode prefix: pick only those keys from ``gconfig`` that are prefixed with this string
            (pass an empty string or ``None`` to use all keys)
        :raises SchedulerAlreadyRunningError: if the scheduler is already running
        """

        if self.running:
            raise SchedulerAlreadyRunningError

        # If a non-empty prefix was given, strip it from the keys in the global configuration dict
        if prefix:
            prefixlen = len(prefix)
            gconfig = dict((key[prefixlen:], value) for key, value in six.iteritems(gconfig) if key.startswith(prefix))

        # Create a structure from the dotted options (e.g. "a.b.c = d" -> {'a': {'b': {'c': 'd'}}})
        config = {}
        for key, value in six.iteritems(gconfig):
            parts = key.split('.')
            parent = config
            key = parts.pop(0)
            while parts:
                parent = parent.setdefault(key, {})
                key = parts.pop(0)
            parent[key] = value

        # Override any options with explicit keyword arguments
        config.update(options)
        self._configure(config)
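    # Example (editorial, not part of this commit): the dotted-key expansion
    # above turns prefixed flat options into nested config; values are invented:
    #     gconfig = {'apscheduler.timezone': 'UTC',
    #                'apscheduler.job_defaults.coalesce': 'false',
    #                'apscheduler.executors.default.type': 'threadpool'}
    # becomes, after prefix stripping and dot splitting:
    #     {'timezone': 'UTC',
    #      'job_defaults': {'coalesce': 'false'},
    #      'executors': {'default': {'type': 'threadpool'}}}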
    @abstractmethod
    def start(self):
        """
        Starts the scheduler. The details of this process depend on the implementation.

        :raises SchedulerAlreadyRunningError: if the scheduler is already running
        """

        if self.running:
            raise SchedulerAlreadyRunningError

        with self._executors_lock:
            # Create a default executor if nothing else is configured
            if 'default' not in self._executors:
                self.add_executor(self._create_default_executor(), 'default')

            # Start all the executors
            for alias, executor in six.iteritems(self._executors):
                executor.start(self, alias)

        with self._jobstores_lock:
            # Create a default job store if nothing else is configured
            if 'default' not in self._jobstores:
                self.add_jobstore(self._create_default_jobstore(), 'default')

            # Start all the job stores
            for alias, store in six.iteritems(self._jobstores):
                store.start(self, alias)

            # Schedule all pending jobs
            for job, jobstore_alias, replace_existing in self._pending_jobs:
                self._real_add_job(job, jobstore_alias, replace_existing, False)
            del self._pending_jobs[:]

        self._stopped = False
        self._logger.info('Scheduler started')

        # Notify listeners that the scheduler has been started
        self._dispatch_event(SchedulerEvent(EVENT_SCHEDULER_START))

    @abstractmethod
    def shutdown(self, wait=True):
        """
        Shuts down the scheduler. Does not interrupt any currently running jobs.

        :param bool wait: ``True`` to wait until all currently executing jobs have finished
        :raises SchedulerNotRunningError: if the scheduler has not been started yet
        """

        if not self.running:
            raise SchedulerNotRunningError

        self._stopped = True

        # Shut down all executors
        for executor in six.itervalues(self._executors):
            executor.shutdown(wait)

        # Shut down all job stores
        for jobstore in six.itervalues(self._jobstores):
            jobstore.shutdown()

        self._logger.info('Scheduler has been shut down')
        self._dispatch_event(SchedulerEvent(EVENT_SCHEDULER_SHUTDOWN))

    @property
    def running(self):
        return not self._stopped

    def add_executor(self, executor, alias='default', **executor_opts):
        """
        Adds an executor to this scheduler. Any extra keyword arguments will be passed to the executor plugin's
        constructor, assuming that the first argument is the name of an executor plugin.

        :param str|unicode|apscheduler.executors.base.BaseExecutor executor: either an executor instance or the name of
            an executor plugin
        :param str|unicode alias: alias for the executor
        :raises ValueError: if there is already an executor by the given alias
        """

        with self._executors_lock:
            if alias in self._executors:
                raise ValueError('This scheduler already has an executor by the alias of "%s"' % alias)

            if isinstance(executor, BaseExecutor):
                self._executors[alias] = executor
            elif isinstance(executor, six.string_types):
                self._executors[alias] = executor = self._create_plugin_instance('executor', executor, executor_opts)
            else:
                raise TypeError('Expected an executor instance or a string, got %s instead' %
                                executor.__class__.__name__)

            # Start the executor right away if the scheduler is running
            if self.running:
                executor.start(self)

        self._dispatch_event(SchedulerEvent(EVENT_EXECUTOR_ADDED, alias))

    def remove_executor(self, alias, shutdown=True):
        """
        Removes the executor by the given alias from this scheduler.

        :param str|unicode alias: alias of the executor
        :param bool shutdown: ``True`` to shut down the executor after removing it
        """

        with self._jobstores_lock:
            executor = self._lookup_executor(alias)
            del self._executors[alias]

        if shutdown:
            executor.shutdown()

        self._dispatch_event(SchedulerEvent(EVENT_EXECUTOR_REMOVED, alias))

    def add_jobstore(self, jobstore, alias='default', **jobstore_opts):
        """
        Adds a job store to this scheduler. Any extra keyword arguments will be passed to the job store plugin's
        constructor, assuming that the first argument is the name of a job store plugin.

        :param str|unicode|apscheduler.jobstores.base.BaseJobStore jobstore: job store to be added
        :param str|unicode alias: alias for the job store
        :raises ValueError: if there is already a job store by the given alias
        """

        with self._jobstores_lock:
            if alias in self._jobstores:
                raise ValueError('This scheduler already has a job store by the alias of "%s"' % alias)

            if isinstance(jobstore, BaseJobStore):
                self._jobstores[alias] = jobstore
            elif isinstance(jobstore, six.string_types):
                self._jobstores[alias] = jobstore = self._create_plugin_instance('jobstore', jobstore, jobstore_opts)
            else:
                raise TypeError('Expected a job store instance or a string, got %s instead' %
                                jobstore.__class__.__name__)

            # Start the job store right away if the scheduler is running
            if self.running:
                jobstore.start(self, alias)

        # Notify listeners that a new job store has been added
        self._dispatch_event(SchedulerEvent(EVENT_JOBSTORE_ADDED, alias))

        # Notify the scheduler so it can scan the new job store for jobs
        if self.running:
            self.wakeup()

    def remove_jobstore(self, alias, shutdown=True):
        """
        Removes the job store by the given alias from this scheduler.

        :param str|unicode alias: alias of the job store
        :param bool shutdown: ``True`` to shut down the job store after removing it
        """

        with self._jobstores_lock:
            jobstore = self._lookup_jobstore(alias)
            del self._jobstores[alias]

        if shutdown:
            jobstore.shutdown()

        self._dispatch_event(SchedulerEvent(EVENT_JOBSTORE_REMOVED, alias))

    def add_listener(self, callback, mask=EVENT_ALL):
        """
        add_listener(callback, mask=EVENT_ALL)

        Adds a listener for scheduler events. When a matching event occurs, ``callback`` is executed with the event
        object as its sole argument. If the ``mask`` parameter is not provided, the callback will receive events of all
        types.

        :param callback: any callable that takes one argument
        :param int mask: bitmask that indicates which events should be listened to

        .. seealso:: :mod:`apscheduler.events`
        .. seealso:: :ref:`scheduler-events`
        """

        with self._listeners_lock:
            self._listeners.append((callback, mask))
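    # Example (editorial, not part of this commit): a listener hooked to the
    # method above; the callback name is an assumption, the constants come
    # from apscheduler.events:
    #     from apscheduler.events import EVENT_JOB_ADDED, EVENT_JOB_REMOVED
    #     def on_job_change(event):  # JobEvent carries job_id and jobstore
    #         print('job %s (store %s)' % (event.job_id, event.jobstore))
    #     scheduler.add_listener(on_job_change, EVENT_JOB_ADDED | EVENT_JOB_REMOVED)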
    def remove_listener(self, callback):
        """Removes a previously added event listener."""

        with self._listeners_lock:
            for i, (cb, _) in enumerate(self._listeners):
                if callback == cb:
                    del self._listeners[i]

    def add_job(self, func, trigger=None, args=None, kwargs=None, id=None, name=None, misfire_grace_time=undefined,
                coalesce=undefined, max_instances=undefined, next_run_time=undefined, jobstore='default',
                executor='default', replace_existing=False, **trigger_args):
        """
        add_job(func, trigger=None, args=None, kwargs=None, id=None, name=None, misfire_grace_time=undefined, \
            coalesce=undefined, max_instances=undefined, next_run_time=undefined, jobstore='default', \
            executor='default', replace_existing=False, **trigger_args)

        Adds the given job to the job list and wakes up the scheduler if it's already running.

        Any option that defaults to ``undefined`` will be replaced with the corresponding default value when the job is
        scheduled (which happens when the scheduler is started, or immediately if the scheduler is already running).

        The ``func`` argument can be given either as a callable object or a textual reference in the
        ``package.module:some.object`` format, where the first half (separated by ``:``) is an importable module and the
        second half is a reference to the callable object, relative to the module.

        The ``trigger`` argument can either be:
        #. the alias name of the trigger (e.g. ``date``, ``interval`` or ``cron``), in which case any extra keyword
           arguments to this method are passed on to the trigger's constructor
        #. an instance of a trigger class

        :param func: callable (or a textual reference to one) to run at the given time
        :param str|apscheduler.triggers.base.BaseTrigger trigger: trigger that determines when ``func`` is called
        :param list|tuple args: list of positional arguments to call func with
        :param dict kwargs: dict of keyword arguments to call func with
        :param str|unicode id: explicit identifier for the job (for modifying it later)
        :param str|unicode name: textual description of the job
        :param int misfire_grace_time: seconds after the designated run time that the job is still allowed to be run
        :param bool coalesce: run once instead of many times if the scheduler determines that the job should be run more
            than once in succession
        :param int max_instances: maximum number of concurrently running instances allowed for this job
        :param datetime next_run_time: when to first run the job, regardless of the trigger (pass ``None`` to add the
            job as paused)
        :param str|unicode jobstore: alias of the job store to store the job in
        :param str|unicode executor: alias of the executor to run the job with
        :param bool replace_existing: ``True`` to replace an existing job with the same ``id`` (but retain the
            number of runs from the existing one)
        :rtype: Job
        """

        job_kwargs = {
            'trigger': self._create_trigger(trigger, trigger_args),
            'executor': executor,
            'func': func,
            'args': tuple(args) if args is not None else (),
            'kwargs': dict(kwargs) if kwargs is not None else {},
            'id': id,
            'name': name,
            'misfire_grace_time': misfire_grace_time,
            'coalesce': coalesce,
            'max_instances': max_instances,
            'next_run_time': next_run_time
        }
        job_kwargs = dict((key, value) for key, value in six.iteritems(job_kwargs) if value is not undefined)
        job = Job(self, **job_kwargs)

        # Don't really add jobs to job stores before the scheduler is up and running
        with self._jobstores_lock:
            if not self.running:
                self._pending_jobs.append((job, jobstore, replace_existing))
                self._logger.info('Adding job tentatively -- it will be properly scheduled when the scheduler starts')
            else:
                self._real_add_job(job, jobstore, replace_existing, True)

        return job
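    # Example (editorial, not part of this commit): add_job() with a trigger
    # alias, and with a textual func reference; paths and times are invented:
    #     scheduler.add_job(tick, 'interval', seconds=30)
    #     scheduler.add_job('mypackage.jobs:send_report', 'cron',
    #                       hour=7, minute=30, id='morning_report',
    #                       replace_existing=True)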
    def scheduled_job(self, trigger, args=None, kwargs=None, id=None, name=None, misfire_grace_time=undefined,
                      coalesce=undefined, max_instances=undefined, next_run_time=undefined, jobstore='default',
                      executor='default', **trigger_args):
        """
        scheduled_job(trigger, args=None, kwargs=None, id=None, name=None, misfire_grace_time=undefined, \
            coalesce=undefined, max_instances=undefined, next_run_time=undefined, jobstore='default', \
            executor='default', **trigger_args)

        A decorator version of :meth:`add_job`, except that ``replace_existing`` is always ``True``.

        .. important:: The ``id`` argument must be given if scheduling a job in a persistent job store. The scheduler
            cannot, however, enforce this requirement.
        """

        def inner(func):
            self.add_job(func, trigger, args, kwargs, id, name, misfire_grace_time, coalesce, max_instances,
                         next_run_time, jobstore, executor, True, **trigger_args)
            return func
        return inner
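    # Example (editorial, not part of this commit): the decorator form; since
    # replace_existing is always True, a stable id avoids duplicates in a
    # persistent job store:
    #     @scheduler.scheduled_job('interval', minutes=10, id='cache_cleanup')
    #     def cleanup():
    #         print('cleaning')   # hypothetical job body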
    def modify_job(self, job_id, jobstore=None, **changes):
        """
        Modifies the properties of a single job. Modifications are passed to this method as extra keyword arguments.

        :param str|unicode job_id: the identifier of the job
        :param str|unicode jobstore: alias of the job store that contains the job
        """
        with self._jobstores_lock:
            job, jobstore = self._lookup_job(job_id, jobstore)
            job._modify(**changes)
            if jobstore:
                self._lookup_jobstore(jobstore).update_job(job)

        self._dispatch_event(JobEvent(EVENT_JOB_MODIFIED, job_id, jobstore))

        # Wake up the scheduler since the job's next run time may have been changed
        self.wakeup()

    def reschedule_job(self, job_id, jobstore=None, trigger=None, **trigger_args):
        """
        Constructs a new trigger for a job and updates its next run time.
        Extra keyword arguments are passed directly to the trigger's constructor.

        :param str|unicode job_id: the identifier of the job
        :param str|unicode jobstore: alias of the job store that contains the job
        :param trigger: alias of the trigger type or a trigger instance
        """

        trigger = self._create_trigger(trigger, trigger_args)
        now = datetime.now(self.timezone)
        next_run_time = trigger.get_next_fire_time(None, now)
        self.modify_job(job_id, jobstore, trigger=trigger, next_run_time=next_run_time)

    def pause_job(self, job_id, jobstore=None):
        """
        Causes the given job not to be executed until it is explicitly resumed.

        :param str|unicode job_id: the identifier of the job
        :param str|unicode jobstore: alias of the job store that contains the job
        """

        self.modify_job(job_id, jobstore, next_run_time=None)

    def resume_job(self, job_id, jobstore=None):
        """
        Resumes the schedule of the given job, or removes the job if its schedule is finished.

        :param str|unicode job_id: the identifier of the job
        :param str|unicode jobstore: alias of the job store that contains the job
        """

        with self._jobstores_lock:
            job, jobstore = self._lookup_job(job_id, jobstore)
            now = datetime.now(self.timezone)
            next_run_time = job.trigger.get_next_fire_time(None, now)
            if next_run_time:
                self.modify_job(job_id, jobstore, next_run_time=next_run_time)
            else:
                self.remove_job(job.id, jobstore)

    def get_jobs(self, jobstore=None, pending=None):
        """
        Returns a list of pending jobs (if the scheduler hasn't been started yet) and scheduled jobs, either from a
        specific job store or from all of them.

        :param str|unicode jobstore: alias of the job store
        :param bool pending: ``False`` to leave out pending jobs (jobs that are waiting until the scheduler starts
            to be added to their respective job stores), ``True`` to only include pending jobs, anything else
            to return both
        :rtype: list[Job]
        """

        with self._jobstores_lock:
            jobs = []

            if pending is not False:
                for job, alias, replace_existing in self._pending_jobs:
                    if jobstore is None or alias == jobstore:
                        jobs.append(job)

            if pending is not True:
                for alias, store in six.iteritems(self._jobstores):
                    if jobstore is None or alias == jobstore:
                        jobs.extend(store.get_all_jobs())

            return jobs

    def get_job(self, job_id, jobstore=None):
        """
        Returns the Job that matches the given ``job_id``.

        :param str|unicode job_id: the identifier of the job
        :param str|unicode jobstore: alias of the job store that most likely contains the job
        :return: the Job by the given ID, or ``None`` if it wasn't found
        :rtype: Job
        """

        with self._jobstores_lock:
            try:
                return self._lookup_job(job_id, jobstore)[0]
            except JobLookupError:
                return

    def remove_job(self, job_id, jobstore=None):
        """
        Removes a job, preventing it from being run any more.

        :param str|unicode job_id: the identifier of the job
        :param str|unicode jobstore: alias of the job store that contains the job
        :raises JobLookupError: if the job was not found
        """

        with self._jobstores_lock:
            # Check if the job is among the pending jobs
            for i, (job, jobstore_alias, replace_existing) in enumerate(self._pending_jobs):
                if job.id == job_id:
                    del self._pending_jobs[i]
                    jobstore = jobstore_alias
                    break
            else:
                # Otherwise, try to remove it from each store until it succeeds or we run out of stores to check
                for alias, store in six.iteritems(self._jobstores):
                    if jobstore in (None, alias):
                        try:
                            store.remove_job(job_id)
                        except JobLookupError:
                            continue

                        jobstore = alias
                        break

        if jobstore is None:
            raise JobLookupError(job_id)

        # Notify listeners that a job has been removed
        event = JobEvent(EVENT_JOB_REMOVED, job_id, jobstore)
        self._dispatch_event(event)

        self._logger.info('Removed job %s', job_id)

    def remove_all_jobs(self, jobstore=None):
        """
        Removes all jobs from the specified job store, or all job stores if none is given.

        :param str|unicode jobstore: alias of the job store
        """

        with self._jobstores_lock:
            if jobstore:
                self._pending_jobs = [pending for pending in self._pending_jobs if pending[1] != jobstore]
            else:
                self._pending_jobs = []

            for alias, store in six.iteritems(self._jobstores):
                if jobstore in (None, alias):
                    store.remove_all_jobs()

        self._dispatch_event(SchedulerEvent(EVENT_ALL_JOBS_REMOVED, jobstore))

    def print_jobs(self, jobstore=None, out=None):
        """
        print_jobs(jobstore=None, out=sys.stdout)

        Prints out a textual listing of all jobs currently scheduled on either all job stores or just a specific one.

        :param str|unicode jobstore: alias of the job store, ``None`` to list jobs from all stores
        :param file out: a file-like object to print to (defaults to **sys.stdout** if nothing is given)
        """

        out = out or sys.stdout
        with self._jobstores_lock:
            if self._pending_jobs:
                print(six.u('Pending jobs:'), file=out)
                for job, jobstore_alias, replace_existing in self._pending_jobs:
                    if jobstore in (None, jobstore_alias):
                        print(six.u('    %s') % job, file=out)

            for alias, store in six.iteritems(self._jobstores):
                if jobstore in (None, alias):
                    print(six.u('Jobstore %s:') % alias, file=out)
                    jobs = store.get_all_jobs()
                    if jobs:
                        for job in jobs:
                            print(six.u('    %s') % job, file=out)
                    else:
                        print(six.u('    No scheduled jobs'), file=out)

    @abstractmethod
    def wakeup(self):
        """
        Notifies the scheduler that there may be jobs due for execution.
        Triggers :meth:`_process_jobs` to be run in an implementation specific manner.
        """
    #
    # Private API
    #

    def _configure(self, config):
        # Set general options
        self._logger = maybe_ref(config.pop('logger', None)) or getLogger('apscheduler.scheduler')
        self.timezone = astimezone(config.pop('timezone', None)) or get_localzone()

        # Set the job defaults
        job_defaults = config.get('job_defaults', {})
        self._job_defaults = {
            'misfire_grace_time': asint(job_defaults.get('misfire_grace_time', 1)),
            'coalesce': asbool(job_defaults.get('coalesce', True)),
            'max_instances': asint(job_defaults.get('max_instances', 1))
        }

        # Configure executors
        self._executors.clear()
        for alias, value in six.iteritems(config.get('executors', {})):
            if isinstance(value, BaseExecutor):
                self.add_executor(value, alias)
            elif isinstance(value, MutableMapping):
                executor_class = value.pop('class', None)
                plugin = value.pop('type', None)
                if plugin:
                    executor = self._create_plugin_instance('executor', plugin, value)
                elif executor_class:
                    cls = maybe_ref(executor_class)
                    executor = cls(**value)
                else:
                    raise ValueError('Cannot create executor "%s" -- either "type" or "class" must be defined' % alias)

                self.add_executor(executor, alias)
            else:
                raise TypeError("Expected executor instance or dict for executors['%s'], got %s instead" % (
                    alias, value.__class__.__name__))

        # Configure job stores
        self._jobstores.clear()
        for alias, value in six.iteritems(config.get('jobstores', {})):
            if isinstance(value, BaseJobStore):
                self.add_jobstore(value, alias)
            elif isinstance(value, MutableMapping):
                jobstore_class = value.pop('class', None)
                plugin = value.pop('type', None)
                if plugin:
                    jobstore = self._create_plugin_instance('jobstore', plugin, value)
                elif jobstore_class:
                    cls = maybe_ref(jobstore_class)
                    jobstore = cls(**value)
                else:
                    raise ValueError('Cannot create job store "%s" -- either "type" or "class" must be defined' % alias)

                self.add_jobstore(jobstore, alias)
            else:
                raise TypeError("Expected job store instance or dict for jobstores['%s'], got %s instead" % (
                    alias, value.__class__.__name__))

    def _create_default_executor(self):
        """Creates a default executor, specific to the particular scheduler type."""

        return ThreadPoolExecutor()

    def _create_default_jobstore(self):
        """Creates a default job store, specific to the particular scheduler type."""

        return MemoryJobStore()

    def _lookup_executor(self, alias):
        """
        Returns the executor instance by the given name from the list of executors that were added to this scheduler.

        :type alias: str
        :raises KeyError: if no executor by the given alias is found
        """

        try:
            return self._executors[alias]
        except KeyError:
            raise KeyError('No such executor: %s' % alias)

    def _lookup_jobstore(self, alias):
        """
        Returns the job store instance by the given name from the list of job stores that were added to this scheduler.

        :type alias: str
        :raises KeyError: if no job store by the given alias is found
        """

        try:
            return self._jobstores[alias]
        except KeyError:
            raise KeyError('No such job store: %s' % alias)

    def _lookup_job(self, job_id, jobstore_alias):
        """
        Finds a job by its ID.

        :type job_id: str
        :param str jobstore_alias: alias of a job store to look in
        :return tuple[Job, str]: a tuple of job, jobstore alias (jobstore alias is None in case of a pending job)
        :raises JobLookupError: if no job by the given ID is found.
        """

        # Check if the job is among the pending jobs
        for job, alias, replace_existing in self._pending_jobs:
            if job.id == job_id:
                return job, None

        # Look in all job stores
        for alias, store in six.iteritems(self._jobstores):
            if jobstore_alias in (None, alias):
                job = store.lookup_job(job_id)
                if job is not None:
                    return job, alias

        raise JobLookupError(job_id)

    def _dispatch_event(self, event):
        """
        Dispatches the given event to interested listeners.

        :param SchedulerEvent event: the event to send
        """

        with self._listeners_lock:
            listeners = tuple(self._listeners)

        for cb, mask in listeners:
            if event.code & mask:
                try:
                    cb(event)
                except:
                    self._logger.exception('Error notifying listener')

    def _real_add_job(self, job, jobstore_alias, replace_existing, wakeup):
        """
        :param Job job: the job to add
        :param bool replace_existing: ``True`` to use update_job() in case the job already exists in the store
        :param bool wakeup: ``True`` to wake up the scheduler after adding the job
        """

        # Fill in undefined values with defaults
        replacements = {}
        for key, value in six.iteritems(self._job_defaults):
            if not hasattr(job, key):
                replacements[key] = value

        # Calculate the next run time if there is none defined
        if not hasattr(job, 'next_run_time'):
            now = datetime.now(self.timezone)
            replacements['next_run_time'] = job.trigger.get_next_fire_time(None, now)

        # Apply any replacements
        job._modify(**replacements)

        # Add the job to the given job store
        store = self._lookup_jobstore(jobstore_alias)
        try:
            store.add_job(job)
        except ConflictingIdError:
            if replace_existing:
                store.update_job(job)
            else:
                raise

        # Mark the job as no longer pending
        job._jobstore_alias = jobstore_alias

        # Notify listeners that a new job has been added
        event = JobEvent(EVENT_JOB_ADDED, job.id, jobstore_alias)
        self._dispatch_event(event)

        self._logger.info('Added job "%s" to job store "%s"', job.name, jobstore_alias)

        # Notify the scheduler about the new job
        if wakeup:
            self.wakeup()

    def _create_plugin_instance(self, type_, alias, constructor_kwargs):
        """Creates an instance of the given plugin type, loading the plugin first if necessary."""

        plugin_container, class_container, base_class = {
            'trigger': (self._trigger_plugins, self._trigger_classes, BaseTrigger),
            'jobstore': (self._jobstore_plugins, self._jobstore_classes, BaseJobStore),
            'executor': (self._executor_plugins, self._executor_classes, BaseExecutor)
        }[type_]

        try:
            plugin_cls = class_container[alias]
        except KeyError:
            if alias in plugin_container:
                plugin_cls = class_container[alias] = plugin_container[alias].load()
                if not issubclass(plugin_cls, base_class):
                    raise TypeError('The {0} entry point does not point to a {0} class'.format(type_))
            else:
                raise LookupError('No {0} by the name "{1}" was found'.format(type_, alias))

        return plugin_cls(**constructor_kwargs)

    def _create_trigger(self, trigger, trigger_args):
        if isinstance(trigger, BaseTrigger):
            return trigger
        elif trigger is None:
            trigger = 'date'
        elif not isinstance(trigger, six.string_types):
            raise TypeError('Expected a trigger instance or string, got %s instead' % trigger.__class__.__name__)

        # Use the scheduler's time zone if nothing else is specified
        trigger_args.setdefault('timezone', self.timezone)

        # Instantiate the trigger class
        return self._create_plugin_instance('trigger', trigger, trigger_args)

    def _create_lock(self):
        """Creates a reentrant lock object."""

        return RLock()

    def _process_jobs(self):
        """
        Iterates through jobs in every jobstore, starts jobs that are due and figures out how long to wait for the next
        round.
        """

        self._logger.debug('Looking for jobs to run')
        now = datetime.now(self.timezone)
        next_wakeup_time = None

        with self._jobstores_lock:
            for jobstore_alias, jobstore in six.iteritems(self._jobstores):
                for job in jobstore.get_due_jobs(now):
                    # Look up the job's executor
                    try:
                        executor = self._lookup_executor(job.executor)
                    except:
                        self._logger.error(
                            'Executor lookup ("%s") failed for job "%s" -- removing it from the job store',
                            job.executor, job)
                        self.remove_job(job.id, jobstore_alias)
                        continue

                    run_times = job._get_run_times(now)
                    run_times = run_times[-1:] if run_times and job.coalesce else run_times
                    if run_times:
                        try:
                            executor.submit_job(job, run_times)
                        except MaxInstancesReachedError:
                            self._logger.warning(
                                'Execution of job "%s" skipped: maximum number of running instances reached (%d)',
                                job, job.max_instances)
                        except:
                            self._logger.exception('Error submitting job "%s" to executor "%s"', job, job.executor)

                        # Update the job if it has a next execution time. Otherwise remove it from the job store.
                        job_next_run = job.trigger.get_next_fire_time(run_times[-1], now)
                        if job_next_run:
                            job._modify(next_run_time=job_next_run)
                            jobstore.update_job(job)
                        else:
                            self.remove_job(job.id, jobstore_alias)

                # Set a new next wakeup time if there isn't one yet or the jobstore has an even earlier one
                jobstore_next_run_time = jobstore.get_next_run_time()
                if jobstore_next_run_time and (next_wakeup_time is None or jobstore_next_run_time < next_wakeup_time):
                    next_wakeup_time = jobstore_next_run_time

        # Determine the delay until this method should be called again
        if next_wakeup_time is not None:
            wait_seconds = max(timedelta_seconds(next_wakeup_time - now), 0)
            self._logger.debug('Next wakeup is due at %s (in %f seconds)', next_wakeup_time, wait_seconds)
        else:
            wait_seconds = None
            self._logger.debug('No jobs; waiting until a job is added')

        return wait_seconds
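Not part of the diff: a configuration sketch exercising the 'type'/'class' dict handling in _configure() above. The aliases and option values are assumptions.

from apscheduler.schedulers.background import BackgroundScheduler

scheduler = BackgroundScheduler(
    executors={'default': {'type': 'threadpool', 'max_workers': 10}},                 # entry point plugin
    jobstores={'default': {'class': 'apscheduler.jobstores.memory:MemoryJobStore'}},  # direct class reference
    job_defaults={'coalesce': True, 'max_instances': 2})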
32
lib/apscheduler/schedulers/blocking.py
Normal file
@@ -0,0 +1,32 @@
from __future__ import absolute_import
from threading import Event

from apscheduler.schedulers.base import BaseScheduler


class BlockingScheduler(BaseScheduler):
    """
    A scheduler that runs in the foreground (:meth:`~apscheduler.schedulers.base.BaseScheduler.start` will block).
    """

    MAX_WAIT_TIME = 4294967  # Maximum value accepted by Event.wait() on Windows

    _event = None

    def start(self):
        super(BlockingScheduler, self).start()
        self._event = Event()
        self._main_loop()

    def shutdown(self, wait=True):
        super(BlockingScheduler, self).shutdown(wait)
        self._event.set()

    def _main_loop(self):
        while self.running:
            wait_seconds = self._process_jobs()
            self._event.wait(wait_seconds if wait_seconds is not None else self.MAX_WAIT_TIME)
            self._event.clear()

    def wakeup(self):
        self._event.set()
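Not part of the diff: a usage sketch; start() enters _main_loop() and blocks until shutdown() is called from a job, a listener, or a signal handler.

from apscheduler.schedulers.blocking import BlockingScheduler

scheduler = BlockingScheduler()
scheduler.add_job(tick, 'interval', seconds=5)  # tick() as in the earlier sketches
try:
    scheduler.start()                           # blocks here
except (KeyboardInterrupt, SystemExit):
    pass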
35
lib/apscheduler/schedulers/gevent.py
Normal file
@@ -0,0 +1,35 @@
from __future__ import absolute_import

from apscheduler.schedulers.blocking import BlockingScheduler
from apscheduler.schedulers.base import BaseScheduler

try:
    from gevent.event import Event
    from gevent.lock import RLock
    import gevent
except ImportError:  # pragma: nocover
    raise ImportError('GeventScheduler requires gevent installed')


class GeventScheduler(BlockingScheduler):
    """A scheduler that runs as a Gevent greenlet."""

    _greenlet = None

    def start(self):
        BaseScheduler.start(self)
        self._event = Event()
        self._greenlet = gevent.spawn(self._main_loop)
        return self._greenlet

    def shutdown(self, wait=True):
        super(GeventScheduler, self).shutdown(wait)
        self._greenlet.join()
        del self._greenlet

    def _create_lock(self):
        return RLock()

    def _create_default_executor(self):
        from apscheduler.executors.gevent import GeventExecutor
        return GeventExecutor()
46
lib/apscheduler/schedulers/qt.py
Normal file
@@ -0,0 +1,46 @@
from __future__ import absolute_import

from apscheduler.schedulers.base import BaseScheduler

try:
    from PyQt5.QtCore import QObject, QTimer
except ImportError:  # pragma: nocover
    try:
        from PyQt4.QtCore import QObject, QTimer
    except ImportError:
        try:
            from PySide.QtCore import QObject, QTimer  # flake8: noqa
        except ImportError:
            raise ImportError('QtScheduler requires either PyQt5, PyQt4 or PySide installed')


class QtScheduler(BaseScheduler):
    """A scheduler that runs in a Qt event loop."""

    _timer = None

    def start(self):
        super(QtScheduler, self).start()
        self.wakeup()

    def shutdown(self, wait=True):
        super(QtScheduler, self).shutdown(wait)
        self._stop_timer()

    def _start_timer(self, wait_seconds):
        self._stop_timer()
        if wait_seconds is not None:
            self._timer = QTimer.singleShot(wait_seconds * 1000, self._process_jobs)

    def _stop_timer(self):
        if self._timer:
            if self._timer.isActive():
                self._timer.stop()
            del self._timer

    def wakeup(self):
        self._start_timer(0)

    def _process_jobs(self):
        wait_seconds = super(QtScheduler, self)._process_jobs()
        self._start_timer(wait_seconds)
60
lib/apscheduler/schedulers/tornado.py
Normal file
@@ -0,0 +1,60 @@
from __future__ import absolute_import
from datetime import timedelta
from functools import wraps

from apscheduler.schedulers.base import BaseScheduler
from apscheduler.util import maybe_ref

try:
    from tornado.ioloop import IOLoop
except ImportError:  # pragma: nocover
    raise ImportError('TornadoScheduler requires tornado installed')


def run_in_ioloop(func):
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        self._ioloop.add_callback(func, self, *args, **kwargs)
    return wrapper


class TornadoScheduler(BaseScheduler):
    """
    A scheduler that runs on a Tornado IOLoop.

    =========== ===============================================================
    ``io_loop`` Tornado IOLoop instance to use (defaults to the global IO loop)
    =========== ===============================================================
    """

    _ioloop = None
    _timeout = None

    def start(self):
        super(TornadoScheduler, self).start()
        self.wakeup()

    @run_in_ioloop
    def shutdown(self, wait=True):
        super(TornadoScheduler, self).shutdown(wait)
        self._stop_timer()

    def _configure(self, config):
        self._ioloop = maybe_ref(config.pop('io_loop', None)) or IOLoop.current()
        super(TornadoScheduler, self)._configure(config)

    def _start_timer(self, wait_seconds):
        self._stop_timer()
        if wait_seconds is not None:
            self._timeout = self._ioloop.add_timeout(timedelta(seconds=wait_seconds), self.wakeup)

    def _stop_timer(self):
        if self._timeout:
            self._ioloop.remove_timeout(self._timeout)
            del self._timeout

    @run_in_ioloop
    def wakeup(self):
        self._stop_timer()
        wait_seconds = self._process_jobs()
        self._start_timer(wait_seconds)
65
lib/apscheduler/schedulers/twisted.py
Normal file
@@ -0,0 +1,65 @@
from __future__ import absolute_import
from functools import wraps

from apscheduler.schedulers.base import BaseScheduler
from apscheduler.util import maybe_ref

try:
    from twisted.internet import reactor as default_reactor
except ImportError:  # pragma: nocover
    raise ImportError('TwistedScheduler requires Twisted installed')


def run_in_reactor(func):
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        self._reactor.callFromThread(func, self, *args, **kwargs)
    return wrapper


class TwistedScheduler(BaseScheduler):
    """
    A scheduler that runs on a Twisted reactor.

    Extra options:

    =========== ========================================================
    ``reactor`` Reactor instance to use (defaults to the global reactor)
    =========== ========================================================
    """

    _reactor = None
    _delayedcall = None

    def _configure(self, config):
        self._reactor = maybe_ref(config.pop('reactor', default_reactor))
        super(TwistedScheduler, self)._configure(config)

    def start(self):
        super(TwistedScheduler, self).start()
        self.wakeup()

    @run_in_reactor
    def shutdown(self, wait=True):
        super(TwistedScheduler, self).shutdown(wait)
        self._stop_timer()

    def _start_timer(self, wait_seconds):
        self._stop_timer()
        if wait_seconds is not None:
            self._delayedcall = self._reactor.callLater(wait_seconds, self.wakeup)

    def _stop_timer(self):
        if self._delayedcall and self._delayedcall.active():
            self._delayedcall.cancel()
            del self._delayedcall

    @run_in_reactor
    def wakeup(self):
        self._stop_timer()
        wait_seconds = self._process_jobs()
        self._start_timer(wait_seconds)

    def _create_default_executor(self):
        from apscheduler.executors.twisted import TwistedExecutor
        return TwistedExecutor()
0
lib/apscheduler/triggers/__init__.py
Normal file
16
lib/apscheduler/triggers/base.py
Normal file
@@ -0,0 +1,16 @@
from abc import ABCMeta, abstractmethod

import six


class BaseTrigger(six.with_metaclass(ABCMeta)):
    """Abstract base class that defines the interface that every trigger must implement."""

    @abstractmethod
    def get_next_fire_time(self, previous_fire_time, now):
        """
        Returns the next datetime to fire on. If no such datetime can be calculated, returns ``None``.

        :param datetime.datetime previous_fire_time: the previous time the trigger was fired
        :param datetime.datetime now: current datetime
        """
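Not part of the diff: a minimal custom trigger against the interface above (the class is hypothetical). It fires once, a fixed delay after being scheduled, then returns None so the scheduler retires the job.

from datetime import timedelta
from apscheduler.triggers.base import BaseTrigger

class OneShotDelayTrigger(BaseTrigger):
    def __init__(self, seconds):
        self.delay = timedelta(seconds=seconds)

    def get_next_fire_time(self, previous_fire_time, now):
        # Fire once: after the first run there is no next fire time
        return None if previous_fire_time else now + self.delay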
176
lib/apscheduler/triggers/cron/__init__.py
Normal file
@@ -0,0 +1,176 @@
from datetime import datetime, timedelta

from tzlocal import get_localzone
import six

from apscheduler.triggers.base import BaseTrigger
from apscheduler.triggers.cron.fields import BaseField, WeekField, DayOfMonthField, DayOfWeekField, DEFAULT_VALUES
from apscheduler.util import datetime_ceil, convert_to_datetime, datetime_repr, astimezone


class CronTrigger(BaseTrigger):
    """
    Triggers when current time matches all specified time constraints, similarly to how the UNIX cron scheduler works.

    :param int|str year: 4-digit year
    :param int|str month: month (1-12)
    :param int|str day: day of the month (1-31)
    :param int|str week: ISO week (1-53)
    :param int|str day_of_week: number or name of weekday (0-6 or mon,tue,wed,thu,fri,sat,sun)
    :param int|str hour: hour (0-23)
    :param int|str minute: minute (0-59)
    :param int|str second: second (0-59)
    :param datetime|str start_date: earliest possible date/time to trigger on (inclusive)
    :param datetime|str end_date: latest possible date/time to trigger on (inclusive)
    :param datetime.tzinfo|str timezone: time zone to use for the date/time calculations
        (defaults to scheduler timezone)

    .. note:: The first weekday is always **monday**.
    """

    FIELD_NAMES = ('year', 'month', 'day', 'week', 'day_of_week', 'hour', 'minute', 'second')
    FIELDS_MAP = {
        'year': BaseField,
        'month': BaseField,
        'week': WeekField,
        'day': DayOfMonthField,
        'day_of_week': DayOfWeekField,
        'hour': BaseField,
        'minute': BaseField,
        'second': BaseField
    }

    __slots__ = 'timezone', 'start_date', 'end_date', 'fields'

    def __init__(self, year=None, month=None, day=None, week=None, day_of_week=None, hour=None, minute=None,
                 second=None, start_date=None, end_date=None, timezone=None):
        if timezone:
            self.timezone = astimezone(timezone)
        elif start_date and start_date.tzinfo:
            self.timezone = start_date.tzinfo
        elif end_date and end_date.tzinfo:
            self.timezone = end_date.tzinfo
        else:
            self.timezone = get_localzone()

        self.start_date = convert_to_datetime(start_date, self.timezone, 'start_date')
        self.end_date = convert_to_datetime(end_date, self.timezone, 'end_date')

        values = dict((key, value) for (key, value) in six.iteritems(locals())
                      if key in self.FIELD_NAMES and value is not None)
        self.fields = []
        assign_defaults = False
        for field_name in self.FIELD_NAMES:
            if field_name in values:
                exprs = values.pop(field_name)
                is_default = False
                assign_defaults = not values
            elif assign_defaults:
                exprs = DEFAULT_VALUES[field_name]
                is_default = True
            else:
                exprs = '*'
                is_default = True

            field_class = self.FIELDS_MAP[field_name]
            field = field_class(field_name, exprs, is_default)
            self.fields.append(field)

    def _increment_field_value(self, dateval, fieldnum):
        """
        Increments the designated field and resets all less significant fields to their minimum values.

        :type dateval: datetime
        :type fieldnum: int
        :return: a tuple containing the new date, and the number of the field that was actually incremented
        :rtype: tuple
        """

        values = {}
        i = 0
        while i < len(self.fields):
            field = self.fields[i]
            if not field.REAL:
                if i == fieldnum:
                    fieldnum -= 1
                    i -= 1
                else:
                    i += 1
                continue

            if i < fieldnum:
                values[field.name] = field.get_value(dateval)
                i += 1
            elif i > fieldnum:
                values[field.name] = field.get_min(dateval)
                i += 1
            else:
                value = field.get_value(dateval)
                maxval = field.get_max(dateval)
                if value == maxval:
                    fieldnum -= 1
                    i -= 1
                else:
                    values[field.name] = value + 1
                    i += 1

        difference = datetime(**values) - dateval.replace(tzinfo=None)
        return self.timezone.normalize(dateval + difference), fieldnum

    def _set_field_value(self, dateval, fieldnum, new_value):
        values = {}
        for i, field in enumerate(self.fields):
            if field.REAL:
                if i < fieldnum:
                    values[field.name] = field.get_value(dateval)
                elif i > fieldnum:
                    values[field.name] = field.get_min(dateval)
                else:
                    values[field.name] = new_value

        difference = datetime(**values) - dateval.replace(tzinfo=None)
        return self.timezone.normalize(dateval + difference)

    def get_next_fire_time(self, previous_fire_time, now):
        if previous_fire_time:
            start_date = max(now, previous_fire_time + timedelta(microseconds=1))
        else:
            start_date = max(now, self.start_date) if self.start_date else now

        fieldnum = 0
        next_date = datetime_ceil(start_date).astimezone(self.timezone)
        while 0 <= fieldnum < len(self.fields):
            field = self.fields[fieldnum]
            curr_value = field.get_value(next_date)
            next_value = field.get_next_value(next_date)

            if next_value is None:
                # No valid value was found
                next_date, fieldnum = self._increment_field_value(next_date, fieldnum - 1)
            elif next_value > curr_value:
                # A valid value, higher than the starting value, was found
                if field.REAL:
                    next_date = self._set_field_value(next_date, fieldnum, next_value)
                    fieldnum += 1
                else:
                    next_date, fieldnum = self._increment_field_value(next_date, fieldnum)
            else:
                # A valid value was found, no changes necessary
                fieldnum += 1

            # Return if the date has rolled past the end date
            if self.end_date and next_date > self.end_date:
                return None

        if fieldnum >= 0:
            return next_date

    def __str__(self):
        options = ["%s='%s'" % (f.name, f) for f in self.fields if not f.is_default]
        return 'cron[%s]' % (', '.join(options))

    def __repr__(self):
        options = ["%s='%s'" % (f.name, f) for f in self.fields if not f.is_default]
        if self.start_date:
            options.append("start_date='%s'" % datetime_repr(self.start_date))
        return '<%s (%s)>' % (self.__class__.__name__, ', '.join(options))
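Not part of the diff: two illustrative CronTrigger instances built from the field expressions this trigger accepts; the schedules are invented.

from apscheduler.triggers.cron import CronTrigger

CronTrigger(day_of_week='mon-fri', hour=7, minute=30)  # every weekday at 07:30
CronTrigger(day='last sun')                            # last Sunday of each month, at midnight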
188
lib/apscheduler/triggers/cron/expressions.py
Normal file
188
lib/apscheduler/triggers/cron/expressions.py
Normal file
@@ -0,0 +1,188 @@
|
||||
"""
|
||||
This module contains the expressions applicable for CronTrigger's fields.
|
||||
"""
|
||||
|
||||
from calendar import monthrange
|
||||
import re
|
||||
|
||||
from apscheduler.util import asint
|
||||
|
||||
__all__ = ('AllExpression', 'RangeExpression', 'WeekdayRangeExpression', 'WeekdayPositionExpression',
|
||||
'LastDayOfMonthExpression')
|
||||
|
||||
|
||||
WEEKDAYS = ['mon', 'tue', 'wed', 'thu', 'fri', 'sat', 'sun']
|
||||
|
||||
|
||||
class AllExpression(object):
|
||||
value_re = re.compile(r'\*(?:/(?P<step>\d+))?$')
|
||||
|
||||
def __init__(self, step=None):
|
||||
self.step = asint(step)
|
||||
if self.step == 0:
|
||||
raise ValueError('Increment must be higher than 0')
|
||||
|
||||
def get_next_value(self, date, field):
|
||||
start = field.get_value(date)
|
||||
minval = field.get_min(date)
|
||||
maxval = field.get_max(date)
|
||||
start = max(start, minval)
|
||||
|
||||
if not self.step:
|
||||
next = start
|
||||
else:
|
||||
distance_to_next = (self.step - (start - minval)) % self.step
|
||||
next = start + distance_to_next
|
||||
|
||||
if next <= maxval:
|
||||
return next
|
||||
|
||||
def __str__(self):
|
||||
if self.step:
|
||||
return '*/%d' % self.step
|
||||
return '*'
|
||||
|
||||
def __repr__(self):
|
||||
return "%s(%s)" % (self.__class__.__name__, self.step)
|
||||
|
||||
|
||||
class RangeExpression(AllExpression):
|
||||
value_re = re.compile(
|
||||
r'(?P<first>\d+)(?:-(?P<last>\d+))?(?:/(?P<step>\d+))?$')
|
||||
|
||||
def __init__(self, first, last=None, step=None):
|
||||
        AllExpression.__init__(self, step)
        first = asint(first)
        last = asint(last)
        if last is None and step is None:
            last = first
        if last is not None and first > last:
            raise ValueError('The minimum value in a range must not be higher than the maximum')
        self.first = first
        self.last = last

    def get_next_value(self, date, field):
        start = field.get_value(date)
        minval = field.get_min(date)
        maxval = field.get_max(date)

        # Apply range limits
        minval = max(minval, self.first)
        if self.last is not None:
            maxval = min(maxval, self.last)
        start = max(start, minval)

        if not self.step:
            next = start
        else:
            distance_to_next = (self.step - (start - minval)) % self.step
            next = start + distance_to_next

        if next <= maxval:
            return next

    def __str__(self):
        if self.last != self.first and self.last is not None:
            range = '%d-%d' % (self.first, self.last)
        else:
            range = str(self.first)

        if self.step:
            return '%s/%d' % (range, self.step)
        return range

    def __repr__(self):
        args = [str(self.first)]
        if self.last != self.first and self.last is not None or self.step:
            args.append(str(self.last))
        if self.step:
            args.append(str(self.step))
        return "%s(%s)" % (self.__class__.__name__, ', '.join(args))


class WeekdayRangeExpression(RangeExpression):
    value_re = re.compile(r'(?P<first>[a-z]+)(?:-(?P<last>[a-z]+))?', re.IGNORECASE)

    def __init__(self, first, last=None):
        try:
            first_num = WEEKDAYS.index(first.lower())
        except ValueError:
            raise ValueError('Invalid weekday name "%s"' % first)

        if last:
            try:
                last_num = WEEKDAYS.index(last.lower())
            except ValueError:
                raise ValueError('Invalid weekday name "%s"' % last)
        else:
            last_num = None

        RangeExpression.__init__(self, first_num, last_num)

    def __str__(self):
        if self.last != self.first and self.last is not None:
            return '%s-%s' % (WEEKDAYS[self.first], WEEKDAYS[self.last])
        return WEEKDAYS[self.first]

    def __repr__(self):
        args = ["'%s'" % WEEKDAYS[self.first]]
        if self.last != self.first and self.last is not None:
            args.append("'%s'" % WEEKDAYS[self.last])
        return "%s(%s)" % (self.__class__.__name__, ', '.join(args))


class WeekdayPositionExpression(AllExpression):
    options = ['1st', '2nd', '3rd', '4th', '5th', 'last']
    value_re = re.compile(r'(?P<option_name>%s) +(?P<weekday_name>(?:\d+|\w+))' % '|'.join(options), re.IGNORECASE)

    def __init__(self, option_name, weekday_name):
        try:
            self.option_num = self.options.index(option_name.lower())
        except ValueError:
            raise ValueError('Invalid weekday position "%s"' % option_name)

        try:
            self.weekday = WEEKDAYS.index(weekday_name.lower())
        except ValueError:
            raise ValueError('Invalid weekday name "%s"' % weekday_name)

    def get_next_value(self, date, field):
        # Figure out the weekday of the month's first day and the number
        # of days in that month
        first_day_wday, last_day = monthrange(date.year, date.month)

        # Calculate which day of the month is the first of the target weekdays
        first_hit_day = self.weekday - first_day_wday + 1
        if first_hit_day <= 0:
            first_hit_day += 7

        # Calculate what day of the month the target weekday would be
        if self.option_num < 5:
            target_day = first_hit_day + self.option_num * 7
        else:
            target_day = first_hit_day + ((last_day - first_hit_day) / 7) * 7

        if target_day <= last_day and target_day >= date.day:
            return target_day

    def __str__(self):
        return '%s %s' % (self.options[self.option_num], WEEKDAYS[self.weekday])

    def __repr__(self):
        return "%s('%s', '%s')" % (self.__class__.__name__, self.options[self.option_num], WEEKDAYS[self.weekday])


class LastDayOfMonthExpression(AllExpression):
    value_re = re.compile(r'last', re.IGNORECASE)

    def __init__(self):
        pass

    def get_next_value(self, date, field):
        return monthrange(date.year, date.month)[1]

    def __str__(self):
        return 'last'

    def __repr__(self):
        return "%s()" % self.__class__.__name__
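Taken together, these expression classes compile a single cron field token and answer "what is the next matching value?". A minimal sketch (not part of the diff, assuming the vendored apscheduler package above is importable) of how a stepped range like '10-30/5' resolves:

# Sketch only: exercise RangeExpression through a minute field.
from datetime import datetime

from apscheduler.triggers.cron.fields import BaseField

field = BaseField('minute', '10-30/5')    # compiles to RangeExpression(10, 30, 5)
now = datetime(2014, 1, 1, 12, 13)        # current minute is 13
print(field.get_next_value(now))          # -> 15, the next step multiple after 13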
97
lib/apscheduler/triggers/cron/fields.py
Normal file
@@ -0,0 +1,97 @@
"""
Fields represent CronTrigger options which map to :class:`~datetime.datetime`
fields.
"""

from calendar import monthrange

from apscheduler.triggers.cron.expressions import (
    AllExpression, RangeExpression, WeekdayPositionExpression, LastDayOfMonthExpression, WeekdayRangeExpression)


__all__ = ('MIN_VALUES', 'MAX_VALUES', 'DEFAULT_VALUES', 'BaseField', 'WeekField', 'DayOfMonthField', 'DayOfWeekField')


MIN_VALUES = {'year': 1970, 'month': 1, 'day': 1, 'week': 1, 'day_of_week': 0, 'hour': 0, 'minute': 0, 'second': 0}
MAX_VALUES = {'year': 2 ** 63, 'month': 12, 'day': 31, 'week': 53, 'day_of_week': 6, 'hour': 23, 'minute': 59,
              'second': 59}
DEFAULT_VALUES = {'year': '*', 'month': 1, 'day': 1, 'week': '*', 'day_of_week': '*', 'hour': 0, 'minute': 0,
                  'second': 0}


class BaseField(object):
    REAL = True
    COMPILERS = [AllExpression, RangeExpression]

    def __init__(self, name, exprs, is_default=False):
        self.name = name
        self.is_default = is_default
        self.compile_expressions(exprs)

    def get_min(self, dateval):
        return MIN_VALUES[self.name]

    def get_max(self, dateval):
        return MAX_VALUES[self.name]

    def get_value(self, dateval):
        return getattr(dateval, self.name)

    def get_next_value(self, dateval):
        smallest = None
        for expr in self.expressions:
            value = expr.get_next_value(dateval, self)
            if smallest is None or (value is not None and value < smallest):
                smallest = value

        return smallest

    def compile_expressions(self, exprs):
        self.expressions = []

        # Split a comma-separated expression list, if any
        exprs = str(exprs).strip()
        if ',' in exprs:
            for expr in exprs.split(','):
                self.compile_expression(expr)
        else:
            self.compile_expression(exprs)

    def compile_expression(self, expr):
        for compiler in self.COMPILERS:
            match = compiler.value_re.match(expr)
            if match:
                compiled_expr = compiler(**match.groupdict())
                self.expressions.append(compiled_expr)
                return

        raise ValueError('Unrecognized expression "%s" for field "%s"' % (expr, self.name))

    def __str__(self):
        expr_strings = (str(e) for e in self.expressions)
        return ','.join(expr_strings)

    def __repr__(self):
        return "%s('%s', '%s')" % (self.__class__.__name__, self.name, self)


class WeekField(BaseField):
    REAL = False

    def get_value(self, dateval):
        return dateval.isocalendar()[1]


class DayOfMonthField(BaseField):
    COMPILERS = BaseField.COMPILERS + [WeekdayPositionExpression, LastDayOfMonthExpression]

    def get_max(self, dateval):
        return monthrange(dateval.year, dateval.month)[1]


class DayOfWeekField(BaseField):
    REAL = False
    COMPILERS = BaseField.COMPILERS + [WeekdayRangeExpression]

    def get_value(self, dateval):
        return dateval.weekday()
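The field classes wire those expressions to concrete datetime attributes. A quick sketch (illustrative, not part of the diff): a day-of-month field built from the token 'last' compiles to LastDayOfMonthExpression and tracks month lengths via its get_max() override:

from datetime import datetime

from apscheduler.triggers.cron.fields import DayOfMonthField

field = DayOfMonthField('day', 'last')               # compiles to LastDayOfMonthExpression
print(str(field))                                    # -> 'last'
print(field.get_next_value(datetime(2014, 2, 10)))   # -> 28 (February 2014)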
30
lib/apscheduler/triggers/date.py
Normal file
@@ -0,0 +1,30 @@
from datetime import datetime

from tzlocal import get_localzone

from apscheduler.triggers.base import BaseTrigger
from apscheduler.util import convert_to_datetime, datetime_repr, astimezone


class DateTrigger(BaseTrigger):
    """
    Triggers once on the given datetime. If ``run_date`` is left empty, current time is used.

    :param datetime|str run_date: the date/time to run the job at
    :param datetime.tzinfo|str timezone: time zone for ``run_date`` if it doesn't have one already
    """

    __slots__ = 'timezone', 'run_date'

    def __init__(self, run_date=None, timezone=None):
        timezone = astimezone(timezone) or get_localzone()
        self.run_date = convert_to_datetime(run_date or datetime.now(), timezone, 'run_date')

    def get_next_fire_time(self, previous_fire_time, now):
        return self.run_date if previous_fire_time is None else None

    def __str__(self):
        return 'date[%s]' % datetime_repr(self.run_date)

    def __repr__(self):
        return "<%s (run_date='%s')>" % (self.__class__.__name__, datetime_repr(self.run_date))
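A short usage sketch (assumptions: pytz and the vendored apscheduler are importable): DateTrigger reports its run date exactly once and None on every later query, which is how the scheduler knows the job is finished.

from datetime import datetime

from pytz import utc

from apscheduler.triggers.date import DateTrigger

trigger = DateTrigger(run_date=datetime(2015, 1, 1, 12, 0), timezone=utc)
now = datetime(2014, 12, 31, tzinfo=utc)
print(trigger.get_next_fire_time(None, now))              # -> 2015-01-01 12:00:00+00:00
print(trigger.get_next_fire_time(trigger.run_date, now))  # -> None (already fired)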
65
lib/apscheduler/triggers/interval.py
Normal file
@@ -0,0 +1,65 @@
from datetime import timedelta, datetime
from math import ceil

from tzlocal import get_localzone

from apscheduler.triggers.base import BaseTrigger
from apscheduler.util import convert_to_datetime, timedelta_seconds, datetime_repr, astimezone


class IntervalTrigger(BaseTrigger):
    """
    Triggers on specified intervals, starting on ``start_date`` if specified, ``datetime.now()`` + interval
    otherwise.

    :param int weeks: number of weeks to wait
    :param int days: number of days to wait
    :param int hours: number of hours to wait
    :param int minutes: number of minutes to wait
    :param int seconds: number of seconds to wait
    :param datetime|str start_date: starting point for the interval calculation
    :param datetime|str end_date: latest possible date/time to trigger on
    :param datetime.tzinfo|str timezone: time zone to use for the date/time calculations
    """

    __slots__ = 'timezone', 'start_date', 'end_date', 'interval'

    def __init__(self, weeks=0, days=0, hours=0, minutes=0, seconds=0, start_date=None, end_date=None, timezone=None):
        self.interval = timedelta(weeks=weeks, days=days, hours=hours, minutes=minutes, seconds=seconds)
        self.interval_length = timedelta_seconds(self.interval)
        if self.interval_length == 0:
            self.interval = timedelta(seconds=1)
            self.interval_length = 1

        if timezone:
            self.timezone = astimezone(timezone)
        elif start_date and start_date.tzinfo:
            self.timezone = start_date.tzinfo
        elif end_date and end_date.tzinfo:
            self.timezone = end_date.tzinfo
        else:
            self.timezone = get_localzone()

        start_date = start_date or (datetime.now(self.timezone) + self.interval)
        self.start_date = convert_to_datetime(start_date, self.timezone, 'start_date')
        self.end_date = convert_to_datetime(end_date, self.timezone, 'end_date')

    def get_next_fire_time(self, previous_fire_time, now):
        if previous_fire_time:
            next_fire_time = previous_fire_time + self.interval
        elif self.start_date > now:
            next_fire_time = self.start_date
        else:
            timediff_seconds = timedelta_seconds(now - self.start_date)
            next_interval_num = int(ceil(timediff_seconds / self.interval_length))
            next_fire_time = self.start_date + self.interval * next_interval_num

        if not self.end_date or next_fire_time <= self.end_date:
            return self.timezone.normalize(next_fire_time)

    def __str__(self):
        return 'interval[%s]' % str(self.interval)

    def __repr__(self):
        return "<%s (interval=%r, start_date='%s')>" % (self.__class__.__name__, self.interval,
                                                        datetime_repr(self.start_date))
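Worth noting in get_next_fire_time(): when the scheduler wakes up late, the trigger rounds up to the next whole multiple of the interval past start_date rather than firing at `now`. A sketch (illustrative values only):

from datetime import datetime

from pytz import utc

from apscheduler.triggers.interval import IntervalTrigger

trigger = IntervalTrigger(minutes=10, timezone=utc,
                          start_date=datetime(2014, 1, 1, tzinfo=utc))
now = datetime(2014, 1, 1, 0, 25, tzinfo=utc)
# 25 minutes past start_date: ceil(1500 / 600) = 3 intervals -> 00:30
print(trigger.get_next_fire_time(None, now))  # -> 2014-01-01 00:30:00+00:00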
385
lib/apscheduler/util.py
Normal file
@@ -0,0 +1,385 @@
"""This module contains several handy functions primarily meant for internal use."""

from __future__ import division
from datetime import date, datetime, time, timedelta, tzinfo
from inspect import isfunction, ismethod, getargspec
from calendar import timegm
import re

from pytz import timezone, utc
import six

try:
    from inspect import signature
except ImportError:  # pragma: nocover
    try:
        from funcsigs import signature
    except ImportError:
        signature = None

__all__ = ('asint', 'asbool', 'astimezone', 'convert_to_datetime', 'datetime_to_utc_timestamp',
           'utc_timestamp_to_datetime', 'timedelta_seconds', 'datetime_ceil', 'get_callable_name', 'obj_to_ref',
           'ref_to_obj', 'maybe_ref', 'repr_escape', 'check_callable_args')


class _Undefined(object):
    def __nonzero__(self):
        return False

    def __bool__(self):
        return False

    def __repr__(self):
        return '<undefined>'

undefined = _Undefined()  #: a unique object that only signifies that no value is defined


def asint(text):
    """
    Safely converts a string to an integer, returning None if the string is None.

    :type text: str
    :rtype: int
    """

    if text is not None:
        return int(text)


def asbool(obj):
    """
    Interprets an object as a boolean value.

    :rtype: bool
    """

    if isinstance(obj, str):
        obj = obj.strip().lower()
        if obj in ('true', 'yes', 'on', 'y', 't', '1'):
            return True
        if obj in ('false', 'no', 'off', 'n', 'f', '0'):
            return False
        raise ValueError('Unable to interpret value "%s" as boolean' % obj)
    return bool(obj)


def astimezone(obj):
    """
    Interprets an object as a timezone.

    :rtype: tzinfo
    """

    if isinstance(obj, six.string_types):
        return timezone(obj)
    if isinstance(obj, tzinfo):
        if not hasattr(obj, 'localize') or not hasattr(obj, 'normalize'):
            raise TypeError('Only timezones from the pytz library are supported')
        if obj.zone == 'local':
            raise ValueError('Unable to determine the name of the local timezone -- use an explicit timezone instead')
        return obj
    if obj is not None:
        raise TypeError('Expected tzinfo, got %s instead' % obj.__class__.__name__)


_DATE_REGEX = re.compile(
    r'(?P<year>\d{4})-(?P<month>\d{1,2})-(?P<day>\d{1,2})'
    r'(?: (?P<hour>\d{1,2}):(?P<minute>\d{1,2}):(?P<second>\d{1,2})'
    r'(?:\.(?P<microsecond>\d{1,6}))?)?')


def convert_to_datetime(input, tz, arg_name):
    """
    Converts the given object to a timezone aware datetime object.
    If a timezone aware datetime object is passed, it is returned unmodified.
    If a native datetime object is passed, it is given the specified timezone.
    If the input is a string, it is parsed as a datetime with the given timezone.

    Date strings are accepted in three different forms: date only (Y-m-d),
    date with time (Y-m-d H:M:S) or with date+time with microseconds
    (Y-m-d H:M:S.micro).

    :param str|datetime input: the datetime or string to convert to a timezone aware datetime
    :param datetime.tzinfo tz: timezone to interpret ``input`` in
    :param str arg_name: the name of the argument (used in an error message)
    :rtype: datetime
    """

    if input is None:
        return
    elif isinstance(input, datetime):
        datetime_ = input
    elif isinstance(input, date):
        datetime_ = datetime.combine(input, time())
    elif isinstance(input, six.string_types):
        m = _DATE_REGEX.match(input)
        if not m:
            raise ValueError('Invalid date string')
        values = [(k, int(v or 0)) for k, v in m.groupdict().items()]
        values = dict(values)
        datetime_ = datetime(**values)
    else:
        raise TypeError('Unsupported type for %s: %s' % (arg_name, input.__class__.__name__))

    if datetime_.tzinfo is not None:
        return datetime_
    if tz is None:
        raise ValueError('The "tz" argument must be specified if %s has no timezone information' % arg_name)
    if isinstance(tz, six.string_types):
        tz = timezone(tz)

    try:
        return tz.localize(datetime_, is_dst=None)
    except AttributeError:
        raise TypeError('Only pytz timezones are supported (need the localize() and normalize() methods)')


def datetime_to_utc_timestamp(timeval):
    """
    Converts a datetime instance to a timestamp.

    :type timeval: datetime
    :rtype: float
    """

    if timeval is not None:
        return timegm(timeval.utctimetuple()) + timeval.microsecond / 1000000


def utc_timestamp_to_datetime(timestamp):
    """
    Converts the given timestamp to a datetime instance.

    :type timestamp: float
    :rtype: datetime
    """

    if timestamp is not None:
        return datetime.fromtimestamp(timestamp, utc)


def timedelta_seconds(delta):
    """
    Converts the given timedelta to seconds.

    :type delta: timedelta
    :rtype: float
    """

    return delta.days * 24 * 60 * 60 + delta.seconds + \
        delta.microseconds / 1000000.0


def datetime_ceil(dateval):
    """
    Rounds the given datetime object upwards.

    :type dateval: datetime
    """

    if dateval.microsecond > 0:
        return dateval + timedelta(seconds=1, microseconds=-dateval.microsecond)
    return dateval


def datetime_repr(dateval):
    return dateval.strftime('%Y-%m-%d %H:%M:%S %Z') if dateval else 'None'


def get_callable_name(func):
    """
    Returns the best available display name for the given function/callable.

    :rtype: str
    """

    # the easy case (on Python 3.3+)
    if hasattr(func, '__qualname__'):
        return func.__qualname__

    # class methods, bound and unbound methods
    f_self = getattr(func, '__self__', None) or getattr(func, 'im_self', None)
    if f_self and hasattr(func, '__name__'):
        f_class = f_self if isinstance(f_self, type) else f_self.__class__
    else:
        f_class = getattr(func, 'im_class', None)

    if f_class and hasattr(func, '__name__'):
        return '%s.%s' % (f_class.__name__, func.__name__)

    # class or class instance
    if hasattr(func, '__call__'):
        # class
        if hasattr(func, '__name__'):
            return func.__name__

        # instance of a class with a __call__ method
        return func.__class__.__name__

    raise TypeError('Unable to determine a name for %r -- maybe it is not a callable?' % func)


def obj_to_ref(obj):
    """
    Returns the path to the given object.

    :rtype: str
    """

    try:
        ref = '%s:%s' % (obj.__module__, get_callable_name(obj))
        obj2 = ref_to_obj(ref)
        if obj != obj2:
            raise ValueError
    except Exception:
        raise ValueError('Cannot determine the reference to %r' % obj)

    return ref


def ref_to_obj(ref):
    """
    Returns the object pointed to by ``ref``.

    :type ref: str
    """

    if not isinstance(ref, six.string_types):
        raise TypeError('References must be strings')
    if ':' not in ref:
        raise ValueError('Invalid reference')

    modulename, rest = ref.split(':', 1)
    try:
        obj = __import__(modulename)
    except ImportError:
        raise LookupError('Error resolving reference %s: could not import module' % ref)

    try:
        for name in modulename.split('.')[1:] + rest.split('.'):
            obj = getattr(obj, name)
        return obj
    except Exception:
        raise LookupError('Error resolving reference %s: error looking up object' % ref)


def maybe_ref(ref):
    """
    Returns the object that the given reference points to, if it is indeed a reference.
    If it is not a reference, the object is returned as-is.
    """

    if not isinstance(ref, str):
        return ref
    return ref_to_obj(ref)


if six.PY2:
    def repr_escape(string):
        if isinstance(string, six.text_type):
            return string.encode('ascii', 'backslashreplace')
        return string
else:
    repr_escape = lambda string: string


def check_callable_args(func, args, kwargs):
    """
    Ensures that the given callable can be called with the given arguments.

    :type args: tuple
    :type kwargs: dict
    """

    pos_kwargs_conflicts = []  # parameters that have a match in both args and kwargs
    positional_only_kwargs = []  # positional-only parameters that have a match in kwargs
    unsatisfied_args = []  # parameters in signature that don't have a match in args or kwargs
    unsatisfied_kwargs = []  # keyword-only arguments that don't have a match in kwargs
    unmatched_args = list(args)  # args that didn't match any of the parameters in the signature
    unmatched_kwargs = list(kwargs)  # kwargs that didn't match any of the parameters in the signature
    has_varargs = has_var_kwargs = False  # indicates if the signature defines *args and **kwargs respectively

    if signature:
        try:
            sig = signature(func)
        except ValueError:
            return  # signature() doesn't work against every kind of callable

        for param in six.itervalues(sig.parameters):
            if param.kind == param.POSITIONAL_OR_KEYWORD:
                if param.name in unmatched_kwargs and unmatched_args:
                    pos_kwargs_conflicts.append(param.name)
                elif unmatched_args:
                    del unmatched_args[0]
                elif param.name in unmatched_kwargs:
                    unmatched_kwargs.remove(param.name)
                elif param.default is param.empty:
                    unsatisfied_args.append(param.name)
            elif param.kind == param.POSITIONAL_ONLY:
                if unmatched_args:
                    del unmatched_args[0]
                elif param.name in unmatched_kwargs:
                    unmatched_kwargs.remove(param.name)
                    positional_only_kwargs.append(param.name)
                elif param.default is param.empty:
                    unsatisfied_args.append(param.name)
            elif param.kind == param.KEYWORD_ONLY:
                if param.name in unmatched_kwargs:
                    unmatched_kwargs.remove(param.name)
                elif param.default is param.empty:
                    unsatisfied_kwargs.append(param.name)
            elif param.kind == param.VAR_POSITIONAL:
                has_varargs = True
            elif param.kind == param.VAR_KEYWORD:
                has_var_kwargs = True
    else:
        if not isfunction(func) and not ismethod(func) and hasattr(func, '__call__'):
            func = func.__call__

        try:
            argspec = getargspec(func)
        except TypeError:
            return  # getargspec() doesn't work against certain callables

        argspec_args = argspec.args if not ismethod(func) else argspec.args[1:]
        has_varargs = bool(argspec.varargs)
        has_var_kwargs = bool(argspec.keywords)
        for arg, default in six.moves.zip_longest(argspec_args, argspec.defaults or (), fillvalue=undefined):
            if arg in unmatched_kwargs and unmatched_args:
                pos_kwargs_conflicts.append(arg)
            elif unmatched_args:
                del unmatched_args[0]
            elif arg in unmatched_kwargs:
                unmatched_kwargs.remove(arg)
            elif default is undefined:
                unsatisfied_args.append(arg)

    # Make sure there are no conflicts between args and kwargs
    if pos_kwargs_conflicts:
        raise ValueError('The following arguments are supplied in both args and kwargs: %s' %
                         ', '.join(pos_kwargs_conflicts))

    # Check if keyword arguments are being fed to positional-only parameters
    if positional_only_kwargs:
        raise ValueError('The following arguments cannot be given as keyword arguments: %s' %
                         ', '.join(positional_only_kwargs))

    # Check that the number of positional arguments minus the number of matched kwargs matches the argspec
    if unsatisfied_args:
        raise ValueError('The following arguments have not been supplied: %s' % ', '.join(unsatisfied_args))

    # Check that all keyword-only arguments have been supplied
    if unsatisfied_kwargs:
        raise ValueError('The following keyword-only arguments have not been supplied in kwargs: %s' %
                         ', '.join(unsatisfied_kwargs))

    # Check that the callable can accept the given number of positional arguments
    if not has_varargs and unmatched_args:
        raise ValueError('The list of positional arguments is longer than the target callable can handle '
                         '(allowed: %d, given in args: %d)' % (len(args) - len(unmatched_args), len(args)))

    # Check that the callable can accept the given keyword arguments
    if not has_var_kwargs and unmatched_kwargs:
        raise ValueError('The target callable does not accept the following keyword arguments: %s' %
                         ', '.join(unmatched_kwargs))
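Two of these helpers carry most of the scheduler's serialization weight: obj_to_ref()/ref_to_obj() turn job callables into "module:name" strings and back, and convert_to_datetime() normalizes the string, naive, and aware datetime inputs accepted by the triggers above. A round-trip sketch (not part of the diff):

from pytz import utc

from apscheduler.util import convert_to_datetime, obj_to_ref, ref_to_obj

ref = obj_to_ref(convert_to_datetime)
print(ref)                                    # -> 'apscheduler.util:convert_to_datetime'
assert ref_to_obj(ref) is convert_to_datetime

# String inputs are parsed and localized with the given pytz timezone.
print(convert_to_datetime('2014-05-01 12:30:00', utc, 'run_date'))
# -> 2014-05-01 12:30:00+00:00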
2386
lib/argparse.py
Normal file
File diff suppressed because it is too large
21
lib/beets/LICENSE
Normal file
@@ -0,0 +1,21 @@
The MIT License

Copyright (c) 2010-2014 Adrian Sampson

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
94
lib/beets/README.rst
Normal file
@@ -0,0 +1,94 @@
.. image:: https://travis-ci.org/sampsyo/beets.svg?branch=master
    :target: https://travis-ci.org/sampsyo/beets

.. image:: http://img.shields.io/coveralls/sampsyo/beets.svg
    :target: https://coveralls.io/r/sampsyo/beets

.. image:: http://img.shields.io/pypi/v/beets.svg
    :target: https://pypi.python.org/pypi/beets

Beets is the media library management system for obsessive-compulsive music
geeks.

The purpose of beets is to get your music collection right once and for all.
It catalogs your collection, automatically improving its metadata as it goes.
It then provides a bouquet of tools for manipulating and accessing your music.

Here's an example of beets' brainy tag corrector doing its thing::

  $ beet import ~/music/ladytron
  Tagging:
      Ladytron - Witching Hour
  (Similarity: 98.4%)
   * Last One Standing      -> The Last One Standing
   * Beauty                 -> Beauty*2
   * White Light Generation -> Whitelightgenerator
   * All the Way            -> All the Way...

Because beets is designed as a library, it can do almost anything you can
imagine for your music collection. Via `plugins`_, beets becomes a panacea:

- Fetch or calculate all the metadata you could possibly need: `album art`_,
  `lyrics`_, `genres`_, `tempos`_, `ReplayGain`_ levels, or `acoustic
  fingerprints`_.
- Get metadata from `MusicBrainz`_, `Discogs`_, or `Beatport`_. Or guess
  metadata using songs' filenames or their acoustic fingerprints.
- `Transcode audio`_ to any format you like.
- Check your library for `duplicate tracks and albums`_ or for `albums that
  are missing tracks`_.
- Clean up crufty tags left behind by other, less-awesome tools.
- Embed and extract album art from files' metadata.
- Browse your music library graphically through a Web browser and play it in any
  browser that supports `HTML5 Audio`_.
- Analyze music files' metadata from the command line.
- Listen to your library with a music player that speaks the `MPD`_ protocol
  and works with a staggering variety of interfaces.

If beets doesn't do what you want yet, `writing your own plugin`_ is
shockingly simple if you know a little Python.

.. _plugins: http://beets.readthedocs.org/page/plugins/
.. _MPD: http://www.musicpd.org/
.. _MusicBrainz music collection: http://musicbrainz.org/doc/Collections/
.. _writing your own plugin:
    http://beets.readthedocs.org/page/dev/plugins.html
.. _HTML5 Audio:
    http://www.w3.org/TR/html-markup/audio.html
.. _albums that are missing tracks:
    http://beets.readthedocs.org/page/plugins/missing.html
.. _duplicate tracks and albums:
    http://beets.readthedocs.org/page/plugins/duplicates.html
.. _Transcode audio:
    http://beets.readthedocs.org/page/plugins/convert.html
.. _Beatport: http://www.beatport.com/
.. _Discogs: http://www.discogs.com/
.. _acoustic fingerprints:
    http://beets.readthedocs.org/page/plugins/chroma.html
.. _ReplayGain: http://beets.readthedocs.org/page/plugins/replaygain.html
.. _tempos: http://beets.readthedocs.org/page/plugins/echonest.html
.. _genres: http://beets.readthedocs.org/page/plugins/lastgenre.html
.. _album art: http://beets.readthedocs.org/page/plugins/fetchart.html
.. _lyrics: http://beets.readthedocs.org/page/plugins/lyrics.html
.. _MusicBrainz: http://musicbrainz.org/

Read More
---------

Learn more about beets at `its Web site`_. Follow `@b33ts`_ on Twitter for
news and updates.

You can install beets by typing ``pip install beets``. Then check out the
`Getting Started`_ guide.

.. _its Web site: http://beets.radbox.org/
.. _Getting Started: http://beets.readthedocs.org/page/guides/main.html
.. _@b33ts: http://twitter.com/b33ts/

Authors
-------

Beets is by `Adrian Sampson`_ with a supporting cast of thousands. For help,
please contact the `mailing list`_.

.. _mailing list: https://groups.google.com/forum/#!forum/beets-users
.. _Adrian Sampson: http://homes.cs.washington.edu/~asampson/
28
lib/beets/__init__.py
Normal file
@@ -0,0 +1,28 @@
# This file is part of beets.
# Copyright 2014, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.

# This particular version has been slightly modified to work with Headphones
# https://github.com/rembo10/headphones

__version__ = '1.3.10-headphones'
__author__ = 'Adrian Sampson <adrian@radbox.org>'

import os

import beets.library
from beets.util import confit

Library = beets.library.Library

config = confit.LazyConfig(os.path.dirname(__file__), __name__)
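The LazyConfig object defers reading configuration files until a value is first accessed; views are then resolved with typed get() calls. A hypothetical lookup (per_disc_numbering is a real beets option, used by apply_metadata() in the next file):

from beets import config

per_disc = config['per_disc_numbering'].get(bool)  # read and type-check lazily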
136
lib/beets/autotag/__init__.py
Normal file
@@ -0,0 +1,136 @@
# This file is part of beets.
# Copyright 2013, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.

"""Facilities for automatically determining files' correct metadata.
"""
import logging

from beets import config

# Parts of external interface.
from .hooks import AlbumInfo, TrackInfo, AlbumMatch, TrackMatch  # noqa
from .match import tag_item, tag_album  # noqa
from .match import Recommendation  # noqa

# Global logger.
log = logging.getLogger('beets')


# Additional utilities for the main interface.

def apply_item_metadata(item, track_info):
    """Set an item's metadata from its matched TrackInfo object.
    """
    item.artist = track_info.artist
    item.artist_sort = track_info.artist_sort
    item.artist_credit = track_info.artist_credit
    item.title = track_info.title
    item.mb_trackid = track_info.track_id
    if track_info.artist_id:
        item.mb_artistid = track_info.artist_id
    # At the moment, the other metadata is left intact (including album
    # and track number). Perhaps these should be emptied?


def apply_metadata(album_info, mapping):
    """Set the items' metadata to match an AlbumInfo object using a
    mapping from Items to TrackInfo objects.
    """
    for item, track_info in mapping.iteritems():
        # Album, artist, track count.
        if track_info.artist:
            item.artist = track_info.artist
        else:
            item.artist = album_info.artist
        item.albumartist = album_info.artist
        item.album = album_info.album

        # Artist sort and credit names.
        item.artist_sort = track_info.artist_sort or album_info.artist_sort
        item.artist_credit = (track_info.artist_credit or
                              album_info.artist_credit)
        item.albumartist_sort = album_info.artist_sort
        item.albumartist_credit = album_info.artist_credit

        # Release date.
        for prefix in '', 'original_':
            if config['original_date'] and not prefix:
                # Ignore specific release date.
                continue

            for suffix in 'year', 'month', 'day':
                key = prefix + suffix
                value = getattr(album_info, key) or 0

                # If we don't even have a year, apply nothing.
                if suffix == 'year' and not value:
                    break

                # Otherwise, set the fetched value (or 0 for the month
                # and day if not available).
                item[key] = value

                # If we're using original release date for both fields,
                # also set item.year = info.original_year, etc.
                if config['original_date']:
                    item[suffix] = value

        # Title.
        item.title = track_info.title

        if config['per_disc_numbering']:
            item.track = track_info.medium_index or track_info.index
            item.tracktotal = track_info.medium_total or len(album_info.tracks)
        else:
            item.track = track_info.index
            item.tracktotal = len(album_info.tracks)

        # Disc and disc count.
        item.disc = track_info.medium
        item.disctotal = album_info.mediums

        # MusicBrainz IDs.
        item.mb_trackid = track_info.track_id
        item.mb_albumid = album_info.album_id
        if track_info.artist_id:
            item.mb_artistid = track_info.artist_id
        else:
            item.mb_artistid = album_info.artist_id
        item.mb_albumartistid = album_info.artist_id
        item.mb_releasegroupid = album_info.releasegroup_id

        # Compilation flag.
        item.comp = album_info.va

        # Miscellaneous metadata.
        for field in ('albumtype',
                      'label',
                      'asin',
                      'catalognum',
                      'script',
                      'language',
                      'country',
                      'albumstatus',
                      'albumdisambig'):
            value = getattr(album_info, field)
            if value is not None:
                item[field] = value
        if track_info.disctitle is not None:
            item.disctitle = track_info.disctitle

        if track_info.media is not None:
            item.media = track_info.media

        # Headphones seal of approval
        item.comments = 'tagged by headphones/beets'
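apply_item_metadata() works purely by attribute assignment, so a stand-in object (hypothetical, for illustration only) shows which fields a track match writes:

from beets.autotag import apply_item_metadata
from beets.autotag.hooks import TrackInfo


class FakeItem(object):
    """Hypothetical stand-in; a real beets Item receives the same attributes."""

item = FakeItem()
info = TrackInfo(title=u'Witching Hour', track_id=u'xxxx', artist=u'Ladytron',
                 artist_id=u'yyyy')
apply_item_metadata(item, info)
print(item.title)        # -> Witching Hour
print(item.mb_artistid)  # -> yyyy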
579
lib/beets/autotag/hooks.py
Normal file
@@ -0,0 +1,579 @@
# This file is part of beets.
# Copyright 2013, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.

"""Glue between metadata sources and the matching logic."""
import logging
from collections import namedtuple
import re

from beets import plugins
from beets import config
from beets.autotag import mb
from beets.util import levenshtein
from unidecode import unidecode

log = logging.getLogger('beets')


# Classes used to represent candidate options.

class AlbumInfo(object):
    """Describes a canonical release that may be used to match a release
    in the library. Consists of these data members:

    - ``album``: the release title
    - ``album_id``: MusicBrainz ID; UUID fragment only
    - ``artist``: name of the release's primary artist
    - ``artist_id``
    - ``tracks``: list of TrackInfo objects making up the release
    - ``asin``: Amazon ASIN
    - ``albumtype``: string describing the kind of release
    - ``va``: boolean: whether the release has "various artists"
    - ``year``: release year
    - ``month``: release month
    - ``day``: release day
    - ``label``: music label responsible for the release
    - ``mediums``: the number of discs in this release
    - ``artist_sort``: name of the release's artist for sorting
    - ``releasegroup_id``: MBID for the album's release group
    - ``catalognum``: the label's catalog number for the release
    - ``script``: character set used for metadata
    - ``language``: human language of the metadata
    - ``country``: the release country
    - ``albumstatus``: MusicBrainz release status (Official, etc.)
    - ``media``: delivery mechanism (Vinyl, etc.)
    - ``albumdisambig``: MusicBrainz release disambiguation comment
    - ``artist_credit``: Release-specific artist name
    - ``data_source``: The original data source (MusicBrainz, Discogs, etc.)
    - ``data_url``: The data source release URL.

    The fields up through ``tracks`` are required. The others are
    optional and may be None.
    """
    def __init__(self, album, album_id, artist, artist_id, tracks, asin=None,
                 albumtype=None, va=False, year=None, month=None, day=None,
                 label=None, mediums=None, artist_sort=None,
                 releasegroup_id=None, catalognum=None, script=None,
                 language=None, country=None, albumstatus=None, media=None,
                 albumdisambig=None, artist_credit=None, original_year=None,
                 original_month=None, original_day=None, data_source=None,
                 data_url=None):
        self.album = album
        self.album_id = album_id
        self.artist = artist
        self.artist_id = artist_id
        self.tracks = tracks
        self.asin = asin
        self.albumtype = albumtype
        self.va = va
        self.year = year
        self.month = month
        self.day = day
        self.label = label
        self.mediums = mediums
        self.artist_sort = artist_sort
        self.releasegroup_id = releasegroup_id
        self.catalognum = catalognum
        self.script = script
        self.language = language
        self.country = country
        self.albumstatus = albumstatus
        self.media = media
        self.albumdisambig = albumdisambig
        self.artist_credit = artist_credit
        self.original_year = original_year
        self.original_month = original_month
        self.original_day = original_day
        self.data_source = data_source
        self.data_url = data_url

    # Work around a bug in python-musicbrainz-ngs that causes some
    # strings to be bytes rather than Unicode.
    # https://github.com/alastair/python-musicbrainz-ngs/issues/85
    def decode(self, codec='utf8'):
        """Ensure that all string attributes on this object, and the
        constituent `TrackInfo` objects, are decoded to Unicode.
        """
        for fld in ['album', 'artist', 'albumtype', 'label', 'artist_sort',
                    'catalognum', 'script', 'language', 'country',
                    'albumstatus', 'albumdisambig', 'artist_credit', 'media']:
            value = getattr(self, fld)
            if isinstance(value, str):
                setattr(self, fld, value.decode(codec, 'ignore'))

        if self.tracks:
            for track in self.tracks:
                track.decode(codec)


class TrackInfo(object):
    """Describes a canonical track present on a release. Appears as part
    of an AlbumInfo's ``tracks`` list. Consists of these data members:

    - ``title``: name of the track
    - ``track_id``: MusicBrainz ID; UUID fragment only
    - ``artist``: individual track artist name
    - ``artist_id``
    - ``length``: float: duration of the track in seconds
    - ``index``: position on the entire release
    - ``media``: delivery mechanism (Vinyl, etc.)
    - ``medium``: the disc number this track appears on in the album
    - ``medium_index``: the track's position on the disc
    - ``medium_total``: the number of tracks on the item's disc
    - ``artist_sort``: name of the track artist for sorting
    - ``disctitle``: name of the individual medium (subtitle)
    - ``artist_credit``: Recording-specific artist name

    Only ``title`` and ``track_id`` are required. The rest of the fields
    may be None. The indices ``index``, ``medium``, and ``medium_index``
    are all 1-based.
    """
    def __init__(self, title, track_id, artist=None, artist_id=None,
                 length=None, index=None, medium=None, medium_index=None,
                 medium_total=None, artist_sort=None, disctitle=None,
                 artist_credit=None, data_source=None, data_url=None,
                 media=None):
        self.title = title
        self.track_id = track_id
        self.artist = artist
        self.artist_id = artist_id
        self.length = length
        self.index = index
        self.media = media
        self.medium = medium
        self.medium_index = medium_index
        self.medium_total = medium_total
        self.artist_sort = artist_sort
        self.disctitle = disctitle
        self.artist_credit = artist_credit
        self.data_source = data_source
        self.data_url = data_url

    # As above, work around a bug in python-musicbrainz-ngs.
    def decode(self, codec='utf8'):
        """Ensure that all string attributes on this object are decoded
        to Unicode.
        """
        for fld in ['title', 'artist', 'medium', 'artist_sort', 'disctitle',
                    'artist_credit', 'media']:
            value = getattr(self, fld)
            if isinstance(value, str):
                setattr(self, fld, value.decode(codec, 'ignore'))


# Candidate distance scoring.

# Parameters for string distance function.
# Words that can be moved to the end of a string using a comma.
SD_END_WORDS = ['the', 'a', 'an']
# Reduced weights for certain portions of the string.
SD_PATTERNS = [
    (r'^the ', 0.1),
    (r'[\[\(]?(ep|single)[\]\)]?', 0.0),
    (r'[\[\(]?(featuring|feat|ft)[\. :].+', 0.1),
    (r'\(.*?\)', 0.3),
    (r'\[.*?\]', 0.3),
    (r'(, )?(pt\.|part) .+', 0.2),
]
# Replacements to use before testing distance.
SD_REPLACE = [
    (r'&', 'and'),
]


def _string_dist_basic(str1, str2):
    """Basic edit distance between two strings, ignoring
    non-alphanumeric characters and case. Comparisons are based on a
    transliteration/lowering to ASCII characters. Normalized by string
    length.
    """
    str1 = unidecode(str1)
    str2 = unidecode(str2)
    str1 = re.sub(r'[^a-z0-9]', '', str1.lower())
    str2 = re.sub(r'[^a-z0-9]', '', str2.lower())
    if not str1 and not str2:
        return 0.0
    return levenshtein(str1, str2) / float(max(len(str1), len(str2)))


def string_dist(str1, str2):
    """Gives an "intuitive" edit distance between two strings. This is
    an edit distance, normalized by the string length, with a number of
    tweaks that reflect intuition about text.
    """
    if str1 is None and str2 is None:
        return 0.0
    if str1 is None or str2 is None:
        return 1.0

    str1 = str1.lower()
    str2 = str2.lower()

    # Don't penalize strings that move certain words to the end. For
    # example, "the something" should be considered equal to
    # "something, the".
    for word in SD_END_WORDS:
        if str1.endswith(', %s' % word):
            str1 = '%s %s' % (word, str1[:-len(word) - 2])
        if str2.endswith(', %s' % word):
            str2 = '%s %s' % (word, str2[:-len(word) - 2])

    # Perform a couple of basic normalizing substitutions.
    for pat, repl in SD_REPLACE:
        str1 = re.sub(pat, repl, str1)
        str2 = re.sub(pat, repl, str2)

    # Change the weight for certain string portions matched by a set
    # of regular expressions. We gradually change the strings and build
    # up penalties associated with parts of the string that were
    # deleted.
    base_dist = _string_dist_basic(str1, str2)
    penalty = 0.0
    for pat, weight in SD_PATTERNS:
        # Get strings that drop the pattern.
        case_str1 = re.sub(pat, '', str1)
        case_str2 = re.sub(pat, '', str2)

        if case_str1 != str1 or case_str2 != str2:
            # If the pattern was present (i.e., it is deleted in the
            # current case), recalculate the distances for the
            # modified strings.
            case_dist = _string_dist_basic(case_str1, case_str2)
            case_delta = max(0.0, base_dist - case_dist)
            if case_delta == 0.0:
                continue

            # Shift our baseline strings down (to avoid rematching the
            # same part of the string) and add a scaled distance
            # amount to the penalties.
            str1 = case_str1
            str2 = case_str2
            base_dist = case_dist
            penalty += weight * case_delta

    return base_dist + penalty


class LazyClassProperty(object):
    """A decorator implementing a read-only property that is *lazy* in
    the sense that the getter is only invoked once. Subsequent accesses
    through *any* instance use the cached result.
    """
    def __init__(self, getter):
        self.getter = getter
        self.computed = False

    def __get__(self, obj, owner):
        if not self.computed:
            self.value = self.getter(owner)
            self.computed = True
        return self.value


class Distance(object):
    """Keeps track of multiple distance penalties. Provides a single
    weighted distance for all penalties as well as a weighted distance
    for each individual penalty.
    """
    def __init__(self):
        self._penalties = {}

    @LazyClassProperty
    def _weights(cls):
        """A dictionary from keys to floating-point weights.
        """
        weights_view = config['match']['distance_weights']
        weights = {}
        for key in weights_view.keys():
            weights[key] = weights_view[key].as_number()
        return weights

    # Access the components and their aggregates.

    @property
    def distance(self):
        """Return a weighted and normalized distance across all
        penalties.
        """
        dist_max = self.max_distance
        if dist_max:
            return self.raw_distance / self.max_distance
        return 0.0

    @property
    def max_distance(self):
        """Return the maximum distance penalty (normalization factor).
        """
        dist_max = 0.0
        for key, penalty in self._penalties.iteritems():
            dist_max += len(penalty) * self._weights[key]
        return dist_max

    @property
    def raw_distance(self):
        """Return the raw (denormalized) distance.
        """
        dist_raw = 0.0
        for key, penalty in self._penalties.iteritems():
            dist_raw += sum(penalty) * self._weights[key]
        return dist_raw

    def items(self):
        """Return a list of (key, dist) pairs, with `dist` being the
        weighted distance, sorted from highest to lowest. Does not
        include penalties with a zero value.
        """
        list_ = []
        for key in self._penalties:
            dist = self[key]
            if dist:
                list_.append((key, dist))
        # Convert distance into a negative float so we can sort items in
        # ascending order (for keys, when the penalty is equal) and
        # still get the items with the biggest distance first.
        return sorted(list_, key=lambda (key, dist): (0 - dist, key))

    # Behave like a float.

    def __cmp__(self, other):
        return cmp(self.distance, other)

    def __float__(self):
        return self.distance

    def __sub__(self, other):
        return self.distance - other

    def __rsub__(self, other):
        return other - self.distance

    # Behave like a dict.

    def __getitem__(self, key):
        """Returns the weighted distance for a named penalty.
        """
        dist = sum(self._penalties[key]) * self._weights[key]
        dist_max = self.max_distance
        if dist_max:
            return dist / dist_max
        return 0.0

    def __iter__(self):
        return iter(self.items())

    def __len__(self):
        return len(self.items())

    def keys(self):
        return [key for key, _ in self.items()]

    def update(self, dist):
        """Adds all the distance penalties from `dist`.
        """
        if not isinstance(dist, Distance):
            raise ValueError(
                '`dist` must be a Distance object, not {0}'.format(type(dist))
            )
        for key, penalties in dist._penalties.iteritems():
            self._penalties.setdefault(key, []).extend(penalties)

    # Adding components.

    def _eq(self, value1, value2):
        """Returns True if `value1` is equal to `value2`. `value1` may
        be a compiled regular expression, in which case it will be
        matched against `value2`.
        """
        if isinstance(value1, re._pattern_type):
            return bool(value1.match(value2))
        return value1 == value2

    def add(self, key, dist):
        """Adds a distance penalty. `key` must correspond with a
        configured weight setting. `dist` must be a float between 0.0
        and 1.0, and will be added to any existing distance penalties
        for the same key.
        """
        if not 0.0 <= dist <= 1.0:
            raise ValueError(
                '`dist` must be between 0.0 and 1.0, not {0}'.format(dist)
            )
        self._penalties.setdefault(key, []).append(dist)

    def add_equality(self, key, value, options):
        """Adds a distance penalty of 1.0 if `value` doesn't match any
        of the values in `options`. If an option is a compiled regular
        expression, it will be considered equal if it matches against
        `value`.
        """
        if not isinstance(options, (list, tuple)):
            options = [options]
        for opt in options:
            if self._eq(opt, value):
                dist = 0.0
                break
        else:
            dist = 1.0
        self.add(key, dist)

    def add_expr(self, key, expr):
        """Adds a distance penalty of 1.0 if `expr` evaluates to True,
        or 0.0.
        """
        if expr:
            self.add(key, 1.0)
        else:
            self.add(key, 0.0)

    def add_number(self, key, number1, number2):
        """Adds a distance penalty of 1.0 for each number of difference
        between `number1` and `number2`, or 0.0 when there is no
        difference. Use this when there is no upper limit on the
        difference between the two numbers.
        """
        diff = abs(number1 - number2)
        if diff:
            for i in range(diff):
                self.add(key, 1.0)
        else:
            self.add(key, 0.0)

    def add_priority(self, key, value, options):
        """Adds a distance penalty that corresponds to the position at
        which `value` appears in `options`. A distance penalty of 0.0
        for the first option, or 1.0 if there is no matching option. If
        an option is a compiled regular expression, it will be
        considered equal if it matches against `value`.
        """
        if not isinstance(options, (list, tuple)):
            options = [options]
        unit = 1.0 / (len(options) or 1)
        for i, opt in enumerate(options):
            if self._eq(opt, value):
                dist = i * unit
                break
        else:
            dist = 1.0
        self.add(key, dist)

    def add_ratio(self, key, number1, number2):
        """Adds a distance penalty for `number1` as a ratio of `number2`.
        `number1` is bound at 0 and `number2`.
        """
        number = float(max(min(number1, number2), 0))
        if number2:
            dist = number / number2
        else:
            dist = 0.0
        self.add(key, dist)

    def add_string(self, key, str1, str2):
        """Adds a distance penalty based on the edit distance between
        `str1` and `str2`.
        """
        dist = string_dist(str1, str2)
        self.add(key, dist)


# Structures that compose all the information for a candidate match.

AlbumMatch = namedtuple('AlbumMatch', ['distance', 'info', 'mapping',
                                       'extra_items', 'extra_tracks'])

TrackMatch = namedtuple('TrackMatch', ['distance', 'info'])


# Aggregation of sources.

def album_for_mbid(release_id):
    """Get an AlbumInfo object for a MusicBrainz release ID. Return None
    if the ID is not found.
    """
    try:
        return mb.album_for_id(release_id)
    except mb.MusicBrainzAPIError as exc:
        exc.log(log)


def track_for_mbid(recording_id):
    """Get a TrackInfo object for a MusicBrainz recording ID. Return None
    if the ID is not found.
    """
    try:
        return mb.track_for_id(recording_id)
    except mb.MusicBrainzAPIError as exc:
        exc.log(log)


def albums_for_id(album_id):
    """Get a list of albums for an ID."""
    candidates = [album_for_mbid(album_id)]
    candidates.extend(plugins.album_for_id(album_id))
    return filter(None, candidates)


def tracks_for_id(track_id):
    """Get a list of tracks for an ID."""
    candidates = [track_for_mbid(track_id)]
    candidates.extend(plugins.track_for_id(track_id))
    return filter(None, candidates)


def album_candidates(items, artist, album, va_likely):
    """Search for album matches. ``items`` is a list of Item objects
    that make up the album. ``artist`` and ``album`` are the respective
    names (strings), which may be derived from the item list or may be
    entered by the user. ``va_likely`` is a boolean indicating whether
    the album is likely to be a "various artists" release.
    """
    out = []

    # Base candidates if we have album and artist to match.
    if artist and album:
        try:
            out.extend(mb.match_album(artist, album, len(items)))
        except mb.MusicBrainzAPIError as exc:
            exc.log(log)

    # Also add VA matches from MusicBrainz where appropriate.
    if va_likely and album:
        try:
            out.extend(mb.match_album(None, album, len(items)))
        except mb.MusicBrainzAPIError as exc:
            exc.log(log)

    # Candidates from plugins.
    out.extend(plugins.candidates(items, artist, album, va_likely))

    return out


def item_candidates(item, artist, title):
    """Search for item matches. ``item`` is the Item to be matched.
    ``artist`` and ``title`` are strings and either reflect the item or
    are specified by the user.
    """
    out = []

    # MusicBrainz candidates.
    if artist and title:
        try:
            out.extend(mb.match_track(artist, title))
        except mb.MusicBrainzAPIError as exc:
            exc.log(log)

    # Plugin candidates.
    out.extend(plugins.item_candidates(item, artist, title))

    return out
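A sketch of the scoring primitives above: string_dist() treats "X, the" as "the X" for free, and Distance normalizes weighted penalties against the match.distance_weights config section (so the second half assumes a loaded beets configuration):

from beets.autotag.hooks import Distance, string_dist

print(string_dist(u'The Beatles', u'Beatles, The'))  # -> 0.0

dist = Distance()
dist.add_string('album', u'Witching Hour', u'Witching Hour (Deluxe)')
print(float(dist))  # small normalized penalty from the parenthesized suffix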
492
lib/beets/autotag/match.py
Normal file
@@ -0,0 +1,492 @@
|
||||
# This file is part of beets.
|
||||
# Copyright 2013, Adrian Sampson.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files (the
|
||||
# "Software"), to deal in the Software without restriction, including
|
||||
# without limitation the rights to use, copy, modify, merge, publish,
|
||||
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||
# permit persons to whom the Software is furnished to do so, subject to
|
||||
# the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
|
||||
"""Matches existing metadata with canonical information to identify
|
||||
releases and tracks.
|
||||
"""
|
||||
from __future__ import division
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
import re
|
||||
from munkres import Munkres
|
||||
|
||||
from beets import plugins
|
||||
from beets import config
|
||||
from beets.util import plurality
|
||||
from beets.autotag import hooks
|
||||
from beets.util.enumeration import OrderedEnum
|
||||
|
||||
# Artist signals that indicate "various artists". These are used at the
|
||||
# album level to determine whether a given release is likely a VA
|
||||
# release and also on the track level to remove the penalty for
|
||||
# differing artists.
|
||||
VA_ARTISTS = (u'', u'various artists', u'various', u'va', u'unknown')
|
||||
|
||||
# Global logger.
|
||||
log = logging.getLogger('beets')
|
||||
|
||||
|
||||
# Recommendation enumeration.
|
||||
|
||||
class Recommendation(OrderedEnum):
|
||||
"""Indicates a qualitative suggestion to the user about what should
|
||||
be done with a given match.
|
||||
"""
|
||||
none = 0
|
||||
low = 1
|
||||
medium = 2
|
||||
strong = 3
|
||||
|
||||
|
||||
# Primary matching functionality.
|
||||
|
||||
def current_metadata(items):
|
||||
"""Extract the likely current metadata for an album given a list of its
|
||||
items. Return two dictionaries:
|
||||
- The most common value for each field.
|
||||
- Whether each field's value was unanimous (values are booleans).
|
||||
"""
|
||||
assert items # Must be nonempty.
|
||||
|
||||
likelies = {}
|
||||
consensus = {}
|
||||
fields = ['artist', 'album', 'albumartist', 'year', 'disctotal',
|
||||
'mb_albumid', 'label', 'catalognum', 'country', 'media',
|
||||
'albumdisambig']
|
||||
for field in fields:
|
||||
values = [item[field] for item in items if item]
|
||||
likelies[field], freq = plurality(values)
|
||||
consensus[field] = (freq == len(values))
|
||||
|
||||
# If there's an album artist consensus, use this for the artist.
|
||||
if consensus['albumartist'] and likelies['albumartist']:
|
||||
likelies['artist'] = likelies['albumartist']
|
||||
|
||||
return likelies, consensus
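# Worked example (hypothetical tags, not from the source): for three
# items tagged artist='X', 'X', 'Y', plurality() yields ('X', 2), so
# likelies['artist'] = 'X' and consensus['artist'] is False because
# freq (2) != len(values) (3).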
|
||||
|
||||
|
||||
def assign_items(items, tracks):
|
||||
"""Given a list of Items and a list of TrackInfo objects, find the
|
||||
best mapping between them. Returns a mapping from Items to TrackInfo
|
||||
objects, a set of extra Items, and a set of extra TrackInfo
|
||||
objects. These "extra" objects occur when there is an unequal number
|
||||
of objects of the two types.
|
||||
"""
|
||||
# Construct the cost matrix.
|
||||
costs = []
|
||||
for item in items:
|
||||
row = []
|
||||
for i, track in enumerate(tracks):
|
||||
row.append(track_distance(item, track))
|
||||
costs.append(row)
|
||||
|
||||
# Find a minimum-cost bipartite matching.
|
||||
matching = Munkres().compute(costs)
|
||||
|
||||
# Produce the output matching.
|
||||
mapping = dict((items[i], tracks[j]) for (i, j) in matching)
|
||||
extra_items = list(set(items) - set(mapping.keys()))
|
||||
extra_items.sort(key=lambda i: (i.disc, i.track, i.title))
|
||||
extra_tracks = list(set(tracks) - set(mapping.values()))
|
||||
extra_tracks.sort(key=lambda t: (t.index, t.title))
|
||||
return mapping, extra_items, extra_tracks
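# A minimal sketch of the assignment step above, using only the
# third-party munkres package (the cost values are hypothetical):
#
#     from munkres import Munkres
#     costs = [[0.1, 0.9, 0.8],   # item 0 vs. tracks 0..2
#              [0.7, 0.2, 0.9]]   # item 1 vs. tracks 0..2
#     print Munkres().compute(costs)   # -> [(0, 0), (1, 1)]
#
# Track 2 is left unassigned and surfaces here as an "extra" track.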
|
||||
|
||||
|
||||
def track_index_changed(item, track_info):
|
||||
"""Returns True if the item and track info index is different. Tolerates
|
||||
per disc and per release numbering.
|
||||
"""
|
||||
return item.track not in (track_info.medium_index, track_info.index)
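# E.g. track 3 on disc 2: the item may carry the per-disc number
# (item.track == track_info.medium_index == 3) or the absolute
# position on the release (item.track == track_info.index, say 13);
# neither counts as a changed index.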
|
||||
|
||||
|
||||
def track_distance(item, track_info, incl_artist=False):
|
||||
"""Determines the significance of a track metadata change. Returns a
|
||||
Distance object. `incl_artist` indicates that a distance component should
|
||||
be included for the track artist (i.e., for various-artist releases).
|
||||
"""
|
||||
dist = hooks.Distance()
|
||||
|
||||
# Length.
|
||||
if track_info.length:
|
||||
diff = abs(item.length - track_info.length) - \
|
||||
config['match']['track_length_grace'].as_number()
|
||||
dist.add_ratio('track_length', diff,
|
||||
config['match']['track_length_max'].as_number())
|
||||
|
||||
# Title.
|
||||
dist.add_string('track_title', item.title, track_info.title)
|
||||
|
||||
# Artist. Only check if there is actually an artist in the track data.
|
||||
if incl_artist and track_info.artist and \
|
||||
item.artist.lower() not in VA_ARTISTS:
|
||||
dist.add_string('track_artist', item.artist, track_info.artist)
|
||||
|
||||
# Track index.
|
||||
if track_info.index and item.track:
|
||||
dist.add_expr('track_index', track_index_changed(item, track_info))
|
||||
|
||||
# Track ID.
|
||||
if item.mb_trackid:
|
||||
dist.add_expr('track_id', item.mb_trackid != track_info.track_id)
|
||||
|
||||
# Plugins.
|
||||
dist.update(plugins.track_distance(item, track_info))
|
||||
|
||||
return dist
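# Worked example for the length component above, under the defaults in
# config_default.yaml (track_length_grace: 10, track_length_max: 30):
# an item that is 18 seconds off yields diff = 18 - 10 = 8, so
# add_ratio() records a 'track_length' penalty of 8 / 30 (about 0.27);
# differences inside the grace period clamp to zero, i.e. no penalty.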
|
||||
|
||||
|
||||
def distance(items, album_info, mapping):
|
||||
"""Determines how "significant" an album metadata change would be.
|
||||
Returns a Distance object. `album_info` is an AlbumInfo object
|
||||
reflecting the album to be compared. `items` is a sequence of all
|
||||
Item objects that will be matched (order is not important).
|
||||
`mapping` is a dictionary mapping Items to TrackInfo objects; the
|
||||
keys are a subset of `items` and the values are a subset of
|
||||
`album_info.tracks`.
|
||||
"""
|
||||
likelies, _ = current_metadata(items)
|
||||
|
||||
dist = hooks.Distance()
|
||||
|
||||
# Artist, if not various.
|
||||
if not album_info.va:
|
||||
dist.add_string('artist', likelies['artist'], album_info.artist)
|
||||
|
||||
# Album.
|
||||
dist.add_string('album', likelies['album'], album_info.album)
|
||||
|
||||
# Current or preferred media.
|
||||
if album_info.media:
|
||||
# Preferred media options.
|
||||
patterns = config['match']['preferred']['media'].as_str_seq()
|
||||
options = [re.compile(r'(\d+x)?(%s)' % pat, re.I) for pat in patterns]
|
||||
if options:
|
||||
dist.add_priority('media', album_info.media, options)
|
||||
# Current media.
|
||||
elif likelies['media']:
|
||||
dist.add_equality('media', album_info.media, likelies['media'])
|
||||
|
||||
# Mediums.
|
||||
if likelies['disctotal'] and album_info.mediums:
|
||||
dist.add_number('mediums', likelies['disctotal'], album_info.mediums)
|
||||
|
||||
# Prefer earliest release.
|
||||
if album_info.year and config['match']['preferred']['original_year']:
|
||||
# Assume 1889 (the year of the earliest gramophone discs) if we don't know the
|
||||
# original year.
|
||||
original = album_info.original_year or 1889
|
||||
diff = abs(album_info.year - original)
|
||||
diff_max = abs(datetime.date.today().year - original)
|
||||
dist.add_ratio('year', diff, diff_max)
|
||||
# Year.
|
||||
elif likelies['year'] and album_info.year:
|
||||
if likelies['year'] in (album_info.year, album_info.original_year):
|
||||
# No penalty for matching release or original year.
|
||||
dist.add('year', 0.0)
|
||||
elif album_info.original_year:
|
||||
# Prefer matches closest to the release year.
|
||||
diff = abs(likelies['year'] - album_info.year)
|
||||
diff_max = abs(datetime.date.today().year -
|
||||
album_info.original_year)
|
||||
dist.add_ratio('year', diff, diff_max)
|
||||
else:
|
||||
# Full penalty when there is no original year.
|
||||
dist.add('year', 1.0)
|
||||
|
||||
# Preferred countries.
|
||||
patterns = config['match']['preferred']['countries'].as_str_seq()
|
||||
options = [re.compile(pat, re.I) for pat in patterns]
|
||||
if album_info.country and options:
|
||||
dist.add_priority('country', album_info.country, options)
|
||||
# Country.
|
||||
elif likelies['country'] and album_info.country:
|
||||
dist.add_string('country', likelies['country'], album_info.country)
|
||||
|
||||
# Label.
|
||||
if likelies['label'] and album_info.label:
|
||||
dist.add_string('label', likelies['label'], album_info.label)
|
||||
|
||||
# Catalog number.
|
||||
if likelies['catalognum'] and album_info.catalognum:
|
||||
dist.add_string('catalognum', likelies['catalognum'],
|
||||
album_info.catalognum)
|
||||
|
||||
# Disambiguation.
|
||||
if likelies['albumdisambig'] and album_info.albumdisambig:
|
||||
dist.add_string('albumdisambig', likelies['albumdisambig'],
|
||||
album_info.albumdisambig)
|
||||
|
||||
# Album ID.
|
||||
if likelies['mb_albumid']:
|
||||
dist.add_equality('album_id', likelies['mb_albumid'],
|
||||
album_info.album_id)
|
||||
|
||||
# Tracks.
|
||||
dist.tracks = {}
|
||||
for item, track in mapping.iteritems():
|
||||
dist.tracks[track] = track_distance(item, track, album_info.va)
|
||||
dist.add('tracks', dist.tracks[track].distance)
|
||||
|
||||
# Missing tracks.
|
||||
for i in range(len(album_info.tracks) - len(mapping)):
|
||||
dist.add('missing_tracks', 1.0)
|
||||
|
||||
# Unmatched tracks.
|
||||
for i in range(len(items) - len(mapping)):
|
||||
dist.add('unmatched_tracks', 1.0)
|
||||
|
||||
# Plugins.
|
||||
dist.update(plugins.album_distance(items, album_info, mapping))
|
||||
|
||||
return dist
|
||||
|
||||
|
||||
def match_by_id(items):
|
||||
"""If the items are tagged with a MusicBrainz album ID, returns an
|
||||
AlbumInfo object for the corresponding album. Otherwise, returns
|
||||
None.
|
||||
"""
|
||||
# Is there a consensus on the MB album ID?
|
||||
albumids = [item.mb_albumid for item in items if item.mb_albumid]
|
||||
if not albumids:
|
||||
log.debug(u'No album IDs found.')
|
||||
return None
|
||||
|
||||
# If all album IDs are equal, look up the album.
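    # (The reduce collapses the list to its common value when every ID
    # agrees, and to the falsy empty tuple as soon as two IDs differ.)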
|
||||
if bool(reduce(lambda x, y: x if x == y else (), albumids)):
|
||||
albumid = albumids[0]
|
||||
log.debug(u'Searching for discovered album ID: {0}'.format(albumid))
|
||||
return hooks.album_for_mbid(albumid)
|
||||
else:
|
||||
log.debug(u'No album ID consensus.')
|
||||
|
||||
|
||||
def _recommendation(results):
|
||||
"""Given a sorted list of AlbumMatch or TrackMatch objects, return a
|
||||
recommendation based on the results' distances.
|
||||
|
||||
If the recommendation is higher than the configured maximum for
|
||||
an applied penalty, the recommendation will be downgraded to the
|
||||
configured maximum for that penalty.
|
||||
"""
|
||||
if not results:
|
||||
# No candidates: no recommendation.
|
||||
return Recommendation.none
|
||||
|
||||
# Basic distance thresholding.
|
||||
min_dist = results[0].distance
|
||||
if min_dist < config['match']['strong_rec_thresh'].as_number():
|
||||
# Strong recommendation level.
|
||||
rec = Recommendation.strong
|
||||
elif min_dist <= config['match']['medium_rec_thresh'].as_number():
|
||||
# Medium recommendation level.
|
||||
rec = Recommendation.medium
|
||||
elif len(results) == 1:
|
||||
# Only a single candidate.
|
||||
rec = Recommendation.low
|
||||
elif results[1].distance - min_dist >= \
|
||||
config['match']['rec_gap_thresh'].as_number():
|
||||
# Gap between first two candidates is large.
|
||||
rec = Recommendation.low
|
||||
else:
|
||||
# No conclusion. Return immediately. Can't be downgraded any further.
|
||||
return Recommendation.none
|
||||
|
||||
# Downgrade to the max rec if it is lower than the current rec for an
|
||||
# applied penalty.
|
||||
keys = set(min_dist.keys())
|
||||
if isinstance(results[0], hooks.AlbumMatch):
|
||||
for track_dist in min_dist.tracks.values():
|
||||
keys.update(track_dist.keys())
|
||||
max_rec_view = config['match']['max_rec']
|
||||
for key in keys:
|
||||
if key in max_rec_view.keys():
|
||||
max_rec = max_rec_view[key].as_choice({
|
||||
'strong': Recommendation.strong,
|
||||
'medium': Recommendation.medium,
|
||||
'low': Recommendation.low,
|
||||
'none': Recommendation.none,
|
||||
})
|
||||
rec = min(rec, max_rec)
|
||||
|
||||
return rec
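# Downgrade sketch, under the defaults in config_default.yaml: with
# max_rec.missing_tracks set to 'medium', any candidate whose distance
# carries a 'missing_tracks' penalty is capped at
# Recommendation.medium even if its raw distance would otherwise earn
# Recommendation.strong.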
|
||||
|
||||
|
||||
def _add_candidate(items, results, info):
|
||||
"""Given a candidate AlbumInfo object, attempt to add the candidate
|
||||
to the output dictionary of AlbumMatch objects. This involves
|
||||
checking the track count, ordering the items, checking for
|
||||
duplicates, and calculating the distance.
|
||||
"""
|
||||
log.debug(u'Candidate: {0} - {1}'.format(info.artist, info.album))
|
||||
|
||||
# Discard albums with zero tracks.
|
||||
if not info.tracks:
|
||||
log.debug('No tracks.')
|
||||
return
|
||||
|
||||
# Don't duplicate.
|
||||
if info.album_id in results:
|
||||
log.debug(u'Duplicate.')
|
||||
return
|
||||
|
||||
# Discard matches without required tags.
|
||||
for req_tag in config['match']['required'].as_str_seq():
|
||||
if getattr(info, req_tag) is None:
|
||||
log.debug(u'Ignored. Missing required tag: {0}'.format(req_tag))
|
||||
return
|
||||
|
||||
# Find mapping between the items and the track info.
|
||||
mapping, extra_items, extra_tracks = assign_items(items, info.tracks)
|
||||
|
||||
# Get the change distance.
|
||||
dist = distance(items, info, mapping)
|
||||
|
||||
# Skip matches with ignored penalties.
|
||||
penalties = [key for _, key in dist]
|
||||
for penalty in config['match']['ignored'].as_str_seq():
|
||||
if penalty in penalties:
|
||||
log.debug(u'Ignored. Penalty: {0}'.format(penalty))
|
||||
return
|
||||
|
||||
log.debug(u'Success. Distance: {0}'.format(dist))
|
||||
results[info.album_id] = hooks.AlbumMatch(dist, info, mapping,
|
||||
extra_items, extra_tracks)
|
||||
|
||||
|
||||
def tag_album(items, search_artist=None, search_album=None,
|
||||
search_id=None):
|
||||
"""Return a tuple of a artist name, an album name, a list of
|
||||
`AlbumMatch` candidates from the metadata backend, and a
|
||||
`Recommendation`.
|
||||
|
||||
The artist and album are the most common values of these fields
|
||||
among `items`.
|
||||
|
||||
The `AlbumMatch` objects are generated by searching the metadata
|
||||
backends. By default, the metadata of the items is used for the
|
||||
search. This can be customized by setting the parameters. The
|
||||
`mapping` field of the album has the matched `items` as keys.
|
||||
|
||||
The recommendation is calculated from the match quality of the
|
||||
candidates.
|
||||
"""
|
||||
# Get current metadata.
|
||||
likelies, consensus = current_metadata(items)
|
||||
cur_artist = likelies['artist']
|
||||
cur_album = likelies['album']
|
||||
log.debug(u'Tagging {0} - {1}'.format(cur_artist, cur_album))
|
||||
|
||||
# The output result (distance, AlbumInfo) tuples (keyed by MB album
|
||||
# ID).
|
||||
candidates = {}
|
||||
|
||||
# Search by explicit ID.
|
||||
if search_id is not None:
|
||||
log.debug(u'Searching for album ID: {0}'.format(search_id))
|
||||
search_cands = hooks.albums_for_id(search_id)
|
||||
|
||||
# Use existing metadata or text search.
|
||||
else:
|
||||
# Try search based on current ID.
|
||||
id_info = match_by_id(items)
|
||||
if id_info:
|
||||
_add_candidate(items, candidates, id_info)
|
||||
rec = _recommendation(candidates.values())
|
||||
log.debug(u'Album ID match recommendation is {0}'.format(str(rec)))
|
||||
if candidates and not config['import']['timid']:
|
||||
# If we have a very good MBID match, return immediately.
|
||||
# Otherwise, this match will compete against metadata-based
|
||||
# matches.
|
||||
if rec == Recommendation.strong:
|
||||
log.debug(u'ID match.')
|
||||
return cur_artist, cur_album, candidates.values(), rec
|
||||
|
||||
# Search terms.
|
||||
if not (search_artist and search_album):
|
||||
# No explicit search terms -- use current metadata.
|
||||
search_artist, search_album = cur_artist, cur_album
|
||||
log.debug(u'Search terms: {0} - {1}'.format(search_artist,
|
||||
search_album))
|
||||
|
||||
# Is this album likely to be a "various artist" release?
|
||||
va_likely = ((not consensus['artist']) or
|
||||
(search_artist.lower() in VA_ARTISTS) or
|
||||
any(item.comp for item in items))
|
||||
log.debug(u'Album might be VA: {0}'.format(str(va_likely)))
|
||||
|
||||
# Get the results from the data sources.
|
||||
search_cands = hooks.album_candidates(items, search_artist,
|
||||
search_album, va_likely)
|
||||
|
||||
log.debug(u'Evaluating {0} candidates.'.format(len(search_cands)))
|
||||
for info in search_cands:
|
||||
_add_candidate(items, candidates, info)
|
||||
|
||||
# Sort and get the recommendation.
|
||||
candidates = sorted(candidates.itervalues())
|
||||
rec = _recommendation(candidates)
|
||||
return cur_artist, cur_album, candidates, rec
|
||||
|
||||
|
||||
def tag_item(item, search_artist=None, search_title=None,
|
||||
search_id=None):
|
||||
"""Attempts to find metadata for a single track. Returns a
|
||||
`(candidates, recommendation)` pair where `candidates` is a list of
|
||||
TrackMatch objects. `search_artist` and `search_title` may be used
|
||||
to override the current metadata for the purposes of the MusicBrainz
|
||||
search; likewise `search_id`.
|
||||
"""
|
||||
# Holds candidates found so far: keys are MBIDs; values are
|
||||
# (distance, TrackInfo) pairs.
|
||||
candidates = {}
|
||||
|
||||
# First, try matching by MusicBrainz ID.
|
||||
trackid = search_id or item.mb_trackid
|
||||
if trackid:
|
||||
log.debug(u'Searching for track ID: {0}'.format(trackid))
|
||||
for track_info in hooks.tracks_for_id(trackid):
|
||||
dist = track_distance(item, track_info, incl_artist=True)
|
||||
candidates[track_info.track_id] = \
|
||||
hooks.TrackMatch(dist, track_info)
|
||||
# If this is a good match, then don't keep searching.
|
||||
rec = _recommendation(candidates.values())
|
||||
if rec == Recommendation.strong and not config['import']['timid']:
|
||||
log.debug(u'Track ID match.')
|
||||
return candidates.values(), rec
|
||||
|
||||
# If we're searching by ID, don't proceed.
|
||||
if search_id is not None:
|
||||
if candidates:
|
||||
return candidates.values(), rec
|
||||
else:
|
||||
return [], Recommendation.none
|
||||
|
||||
# Search terms.
|
||||
if not (search_artist and search_title):
|
||||
search_artist, search_title = item.artist, item.title
|
||||
log.debug(u'Item search terms: {0} - {1}'.format(search_artist,
|
||||
search_title))
|
||||
|
||||
# Get and evaluate candidate metadata.
|
||||
for track_info in hooks.item_candidates(item, search_artist, search_title):
|
||||
dist = track_distance(item, track_info, incl_artist=True)
|
||||
candidates[track_info.track_id] = hooks.TrackMatch(dist, track_info)
|
||||
|
||||
# Sort by distance and return with recommendation.
|
||||
log.debug(u'Found {0} candidates.'.format(len(candidates)))
|
||||
candidates = sorted(candidates.itervalues())
|
||||
rec = _recommendation(candidates)
|
||||
return candidates, rec
|
||||
407
lib/beets/autotag/mb.py
Normal file
@@ -0,0 +1,407 @@
|
||||
# This file is part of beets.
|
||||
# Copyright 2013, Adrian Sampson.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files (the
|
||||
# "Software"), to deal in the Software without restriction, including
|
||||
# without limitation the rights to use, copy, modify, merge, publish,
|
||||
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||
# permit persons to whom the Software is furnished to do so, subject to
|
||||
# the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
|
||||
"""Searches for albums in the MusicBrainz database.
|
||||
"""
|
||||
import logging
|
||||
import musicbrainzngs
|
||||
import re
|
||||
import traceback
|
||||
from urlparse import urljoin
|
||||
|
||||
import beets.autotag.hooks
|
||||
import beets
|
||||
from beets import util
|
||||
from beets import config
|
||||
|
||||
SEARCH_LIMIT = 5
|
||||
VARIOUS_ARTISTS_ID = '89ad4ac3-39f7-470e-963a-56509c546377'
|
||||
BASE_URL = 'http://musicbrainz.org/'
|
||||
|
||||
musicbrainzngs.set_useragent('beets', beets.__version__,
|
||||
'http://beets.radbox.org/')
|
||||
|
||||
|
||||
class MusicBrainzAPIError(util.HumanReadableException):
|
||||
"""An error while talking to MusicBrainz. The `query` field is the
|
||||
parameter to the action and may have any type.
|
||||
"""
|
||||
def __init__(self, reason, verb, query, tb=None):
|
||||
self.query = query
|
||||
super(MusicBrainzAPIError, self).__init__(reason, verb, tb)
|
||||
|
||||
def get_message(self):
|
||||
return u'{0} in {1} with query {2}'.format(
|
||||
self._reasonstr(), self.verb, repr(self.query)
|
||||
)
|
||||
|
||||
log = logging.getLogger('beets')
|
||||
|
||||
RELEASE_INCLUDES = ['artists', 'media', 'recordings', 'release-groups',
|
||||
'labels', 'artist-credits', 'aliases']
|
||||
TRACK_INCLUDES = ['artists', 'aliases']
|
||||
|
||||
|
||||
def track_url(trackid):
|
||||
return urljoin(BASE_URL, 'recording/' + trackid)
|
||||
|
||||
|
||||
def album_url(albumid):
|
||||
return urljoin(BASE_URL, 'release/' + albumid)
|
||||
|
||||
|
||||
def configure():
|
||||
"""Set up the python-musicbrainz-ngs module according to settings
|
||||
from the beets configuration. This should be called at startup.
|
||||
"""
|
||||
musicbrainzngs.set_hostname(config['musicbrainz']['host'].get(unicode))
|
||||
musicbrainzngs.set_rate_limit(
|
||||
config['musicbrainz']['ratelimit_interval'].as_number(),
|
||||
config['musicbrainz']['ratelimit'].get(int),
|
||||
)
|
||||
|
||||
|
||||
def _preferred_alias(aliases):
|
||||
"""Given an list of alias structures for an artist credit, select
|
||||
and return the user's preferred alias or None if no matching
|
||||
alias is found.
|
||||
"""
|
||||
if not aliases:
|
||||
return
|
||||
|
||||
# Only consider aliases that have locales set.
|
||||
aliases = [a for a in aliases if 'locale' in a]
|
||||
|
||||
# Search configured locales in order.
|
||||
for locale in config['import']['languages'].as_str_seq():
|
||||
# Find matching primary aliases for this locale.
|
||||
matches = [a for a in aliases
|
||||
if a['locale'] == locale and 'primary' in a]
|
||||
# Skip to the next locale if we have no matches
|
||||
if not matches:
|
||||
continue
|
||||
|
||||
return matches[0]
|
||||
|
||||
|
||||
def _flatten_artist_credit(credit):
|
||||
"""Given a list representing an ``artist-credit`` block, flatten the
|
||||
data into a triple of joined artist name strings: canonical, sort, and
|
||||
credit.
|
||||
"""
|
||||
artist_parts = []
|
||||
artist_sort_parts = []
|
||||
artist_credit_parts = []
|
||||
for el in credit:
|
||||
if isinstance(el, basestring):
|
||||
# Join phrase.
|
||||
artist_parts.append(el)
|
||||
artist_credit_parts.append(el)
|
||||
artist_sort_parts.append(el)
|
||||
|
||||
else:
|
||||
alias = _preferred_alias(el['artist'].get('alias-list', ()))
|
||||
|
||||
# An artist.
|
||||
if alias:
|
||||
cur_artist_name = alias['alias']
|
||||
else:
|
||||
cur_artist_name = el['artist']['name']
|
||||
artist_parts.append(cur_artist_name)
|
||||
|
||||
# Artist sort name.
|
||||
if alias:
|
||||
artist_sort_parts.append(alias['sort-name'])
|
||||
elif 'sort-name' in el['artist']:
|
||||
artist_sort_parts.append(el['artist']['sort-name'])
|
||||
else:
|
||||
artist_sort_parts.append(cur_artist_name)
|
||||
|
||||
# Artist credit.
|
||||
if 'name' in el:
|
||||
artist_credit_parts.append(el['name'])
|
||||
else:
|
||||
artist_credit_parts.append(cur_artist_name)
|
||||
|
||||
return (
|
||||
''.join(artist_parts),
|
||||
''.join(artist_sort_parts),
|
||||
''.join(artist_credit_parts),
|
||||
)
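# Sketch with a hypothetical artist-credit block (no aliases, so the
# plain names feed all three joined strings):
#
#     credit = [{'artist': {'name': 'A', 'sort-name': 'A'}},
#               ' & ',
#               {'artist': {'name': 'B', 'sort-name': 'B'}}]
#     _flatten_artist_credit(credit)  # -> (u'A & B', u'A & B', u'A & B')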
|
||||
|
||||
|
||||
def track_info(recording, index=None, medium=None, medium_index=None,
|
||||
medium_total=None):
|
||||
"""Translates a MusicBrainz recording result dictionary into a beets
|
||||
``TrackInfo`` object. Three parameters are optional and are used
|
||||
only for tracks that appear on releases (non-singletons): ``index``,
|
||||
the overall track number; ``medium``, the disc number;
|
||||
``medium_index``, the track's index on its medium; ``medium_total``,
|
||||
the number of tracks on the medium. Each number is a 1-based index.
|
||||
"""
|
||||
info = beets.autotag.hooks.TrackInfo(
|
||||
recording['title'],
|
||||
recording['id'],
|
||||
index=index,
|
||||
medium=medium,
|
||||
medium_index=medium_index,
|
||||
medium_total=medium_total,
|
||||
data_url=track_url(recording['id']),
|
||||
)
|
||||
|
||||
if recording.get('artist-credit'):
|
||||
# Get the artist names.
|
||||
info.artist, info.artist_sort, info.artist_credit = \
|
||||
_flatten_artist_credit(recording['artist-credit'])
|
||||
|
||||
# Get the ID and sort name of the first artist.
|
||||
artist = recording['artist-credit'][0]['artist']
|
||||
info.artist_id = artist['id']
|
||||
|
||||
if recording.get('length'):
|
||||
info.length = int(recording['length']) / (1000.0)
|
||||
|
||||
info.decode()
|
||||
return info
|
||||
|
||||
|
||||
def _set_date_str(info, date_str, original=False):
|
||||
"""Given a (possibly partial) YYYY-MM-DD string and an AlbumInfo
|
||||
object, set the object's release date fields appropriately. If
|
||||
`original`, then set the original_year, etc., fields.
|
||||
"""
|
||||
if date_str:
|
||||
date_parts = date_str.split('-')
|
||||
for key in ('year', 'month', 'day'):
|
||||
if date_parts:
|
||||
date_part = date_parts.pop(0)
|
||||
try:
|
||||
date_num = int(date_part)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
if original:
|
||||
key = 'original_' + key
|
||||
setattr(info, key, date_num)
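# E.g. the partial MusicBrainz date string '2004-07' sets info.year =
# 2004 and info.month = 7 and leaves the day untouched; with
# original=True the same string would set original_year and
# original_month instead.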
|
||||
|
||||
|
||||
def album_info(release):
|
||||
"""Takes a MusicBrainz release result dictionary and returns a beets
|
||||
AlbumInfo object containing the interesting data about that release.
|
||||
"""
|
||||
# Get artist name using join phrases.
|
||||
artist_name, artist_sort_name, artist_credit_name = \
|
||||
_flatten_artist_credit(release['artist-credit'])
|
||||
|
||||
# Basic info.
|
||||
track_infos = []
|
||||
index = 0
|
||||
for medium in release['medium-list']:
|
||||
disctitle = medium.get('title')
|
||||
format = medium.get('format')
|
||||
for track in medium['track-list']:
|
||||
# Basic information from the recording.
|
||||
index += 1
|
||||
ti = track_info(
|
||||
track['recording'],
|
||||
index,
|
||||
int(medium['position']),
|
||||
int(track['position']),
|
||||
len(medium['track-list']),
|
||||
)
|
||||
ti.disctitle = disctitle
|
||||
ti.media = format
|
||||
|
||||
# Prefer track data, where present, over recording data.
|
||||
if track.get('title'):
|
||||
ti.title = track['title']
|
||||
if track.get('artist-credit'):
|
||||
# Get the artist names.
|
||||
ti.artist, ti.artist_sort, ti.artist_credit = \
|
||||
_flatten_artist_credit(track['artist-credit'])
|
||||
ti.artist_id = track['artist-credit'][0]['artist']['id']
|
||||
if track.get('length'):
|
||||
ti.length = int(track['length']) / (1000.0)
|
||||
|
||||
track_infos.append(ti)
|
||||
|
||||
info = beets.autotag.hooks.AlbumInfo(
|
||||
release['title'],
|
||||
release['id'],
|
||||
artist_name,
|
||||
release['artist-credit'][0]['artist']['id'],
|
||||
track_infos,
|
||||
mediums=len(release['medium-list']),
|
||||
artist_sort=artist_sort_name,
|
||||
artist_credit=artist_credit_name,
|
||||
data_source='MusicBrainz',
|
||||
data_url=album_url(release['id']),
|
||||
)
|
||||
info.va = info.artist_id == VARIOUS_ARTISTS_ID
|
||||
info.asin = release.get('asin')
|
||||
info.releasegroup_id = release['release-group']['id']
|
||||
info.country = release.get('country')
|
||||
info.albumstatus = release.get('status')
|
||||
|
||||
# Build up the disambiguation string from the release group and release.
|
||||
disambig = []
|
||||
if release['release-group'].get('disambiguation'):
|
||||
disambig.append(release['release-group'].get('disambiguation'))
|
||||
if release.get('disambiguation'):
|
||||
disambig.append(release.get('disambiguation'))
|
||||
info.albumdisambig = u', '.join(disambig)
|
||||
|
||||
# Release type not always populated.
|
||||
if 'type' in release['release-group']:
|
||||
reltype = release['release-group']['type']
|
||||
if reltype:
|
||||
info.albumtype = reltype.lower()
|
||||
|
||||
# Release dates.
|
||||
release_date = release.get('date')
|
||||
release_group_date = release['release-group'].get('first-release-date')
|
||||
if not release_date:
|
||||
# Fall back if release-specific date is not available.
|
||||
release_date = release_group_date
|
||||
_set_date_str(info, release_date, False)
|
||||
_set_date_str(info, release_group_date, True)
|
||||
|
||||
# Label name.
|
||||
if release.get('label-info-list'):
|
||||
label_info = release['label-info-list'][0]
|
||||
if label_info.get('label'):
|
||||
label = label_info['label']['name']
|
||||
if label != '[no label]':
|
||||
info.label = label
|
||||
info.catalognum = label_info.get('catalog-number')
|
||||
|
||||
# Text representation data.
|
||||
if release.get('text-representation'):
|
||||
rep = release['text-representation']
|
||||
info.script = rep.get('script')
|
||||
info.language = rep.get('language')
|
||||
|
||||
# Media (format).
|
||||
if release['medium-list']:
|
||||
first_medium = release['medium-list'][0]
|
||||
info.media = first_medium.get('format')
|
||||
|
||||
info.decode()
|
||||
return info
|
||||
|
||||
|
||||
def match_album(artist, album, tracks=None, limit=SEARCH_LIMIT):
|
||||
"""Searches for a single album ("release" in MusicBrainz parlance)
|
||||
and returns an iterator over AlbumInfo objects. May raise a
|
||||
MusicBrainzAPIError.
|
||||
|
||||
The query consists of an artist name, an album name, and,
|
||||
optionally, a number of tracks on the album.
|
||||
"""
|
||||
# Build search criteria.
|
||||
criteria = {'release': album.lower().strip()}
|
||||
if artist is not None:
|
||||
criteria['artist'] = artist.lower().strip()
|
||||
else:
|
||||
# Various Artists search.
|
||||
criteria['arid'] = VARIOUS_ARTISTS_ID
|
||||
if tracks is not None:
|
||||
criteria['tracks'] = str(tracks)
|
||||
|
||||
# Abort if we have no search terms.
|
||||
if not any(criteria.itervalues()):
|
||||
return
|
||||
|
||||
try:
|
||||
res = musicbrainzngs.search_releases(limit=limit, **criteria)
|
||||
except musicbrainzngs.MusicBrainzError as exc:
|
||||
raise MusicBrainzAPIError(exc, 'release search', criteria,
|
||||
traceback.format_exc())
|
||||
for release in res['release-list']:
|
||||
# The search result is missing some data (namely, the tracks),
|
||||
# so we just use the ID and fetch the rest of the information.
|
||||
albuminfo = album_for_id(release['id'])
|
||||
if albuminfo is not None:
|
||||
yield albuminfo
|
||||
|
||||
|
||||
def match_track(artist, title, limit=SEARCH_LIMIT):
|
||||
"""Searches for a single track and returns an iterable of TrackInfo
|
||||
objects. May raise a MusicBrainzAPIError.
|
||||
"""
|
||||
criteria = {
|
||||
'artist': artist.lower().strip(),
|
||||
'recording': title.lower().strip(),
|
||||
}
|
||||
|
||||
if not any(criteria.itervalues()):
|
||||
return
|
||||
|
||||
try:
|
||||
res = musicbrainzngs.search_recordings(limit=limit, **criteria)
|
||||
except musicbrainzngs.MusicBrainzError as exc:
|
||||
raise MusicBrainzAPIError(exc, 'recording search', criteria,
|
||||
traceback.format_exc())
|
||||
for recording in res['recording-list']:
|
||||
yield track_info(recording)
|
||||
|
||||
|
||||
def _parse_id(s):
|
||||
"""Search for a MusicBrainz ID in the given string and return it. If
|
||||
no ID can be found, return None.
|
||||
"""
|
||||
# Find the first thing that looks like a UUID/MBID.
|
||||
match = re.search('[a-f0-9]{8}(-[a-f0-9]{4}){3}-[a-f0-9]{12}', s)
|
||||
if match:
|
||||
return match.group()
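# E.g. a bare MBID and a full MusicBrainz URL (hypothetical ID shown)
# both parse to the same identifier:
#
#     _parse_id('http://musicbrainz.org/release/'
#               '0e4a7363-c687-4b45-a9f1-8c3f9fd0f08f')
#     # -> '0e4a7363-c687-4b45-a9f1-8c3f9fd0f08f'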
|
||||
|
||||
|
||||
def album_for_id(releaseid):
|
||||
"""Fetches an album by its MusicBrainz ID and returns an AlbumInfo
|
||||
object or None if the album is not found. May raise a
|
||||
MusicBrainzAPIError.
|
||||
"""
|
||||
albumid = _parse_id(releaseid)
|
||||
if not albumid:
|
||||
log.debug(u'Invalid MBID ({0}).'.format(releaseid))
|
||||
return
|
||||
try:
|
||||
res = musicbrainzngs.get_release_by_id(albumid,
|
||||
RELEASE_INCLUDES)
|
||||
except musicbrainzngs.ResponseError:
|
||||
log.debug(u'Album ID match failed.')
|
||||
return None
|
||||
except musicbrainzngs.MusicBrainzError as exc:
|
||||
raise MusicBrainzAPIError(exc, 'get release by ID', albumid,
|
||||
traceback.format_exc())
|
||||
return album_info(res['release'])
|
||||
|
||||
|
||||
def track_for_id(releaseid):
|
||||
"""Fetches a track by its MusicBrainz ID. Returns a TrackInfo object
|
||||
or None if no track is found. May raise a MusicBrainzAPIError.
|
||||
"""
|
||||
trackid = _parse_id(releaseid)
|
||||
if not trackid:
|
||||
log.debug(u'Invalid MBID ({0}).'.format(releaseid))
|
||||
return
|
||||
try:
|
||||
res = musicbrainzngs.get_recording_by_id(trackid, TRACK_INCLUDES)
|
||||
except musicbrainzngs.ResponseError:
|
||||
log.debug(u'Track ID match failed.')
|
||||
return None
|
||||
except musicbrainzngs.MusicBrainzError as exc:
|
||||
raise MusicBrainzAPIError(exc, 'get recording by ID', trackid,
|
||||
traceback.format_exc())
|
||||
return track_info(res['recording'])
|
||||
109
lib/beets/config_default.yaml
Normal file
@@ -0,0 +1,109 @@
|
||||
library: library.db
|
||||
directory: ~/Music
|
||||
|
||||
import:
|
||||
write: yes
|
||||
copy: yes
|
||||
move: no
|
||||
link: no
|
||||
delete: no
|
||||
resume: ask
|
||||
incremental: no
|
||||
quiet_fallback: skip
|
||||
none_rec_action: ask
|
||||
timid: no
|
||||
log:
|
||||
autotag: yes
|
||||
quiet: no
|
||||
singletons: no
|
||||
default_action: apply
|
||||
languages: []
|
||||
detail: no
|
||||
flat: no
|
||||
group_albums: no
|
||||
pretend: false
|
||||
|
||||
clutter: ["Thumbs.DB", ".DS_Store"]
|
||||
ignore: [".*", "*~", "System Volume Information"]
|
||||
replace:
|
||||
'[\\/]': _
|
||||
'^\.': _
|
||||
'[\x00-\x1f]': _
|
||||
'[<>:"\?\*\|]': _
|
||||
'\.$': _
|
||||
'\s+$': ''
|
||||
'^\s+': ''
|
||||
path_sep_replace: _
|
||||
asciify_paths: false
|
||||
art_filename: cover
|
||||
max_filename_length: 0
|
||||
|
||||
plugins: []
|
||||
pluginpath: []
|
||||
threaded: yes
|
||||
color: yes
|
||||
timeout: 5.0
|
||||
per_disc_numbering: no
|
||||
verbose: no
|
||||
terminal_encoding: utf8
|
||||
original_date: no
|
||||
id3v23: no
|
||||
|
||||
ui:
|
||||
terminal_width: 80
|
||||
length_diff_thresh: 10.0
|
||||
|
||||
list_format_item: $artist - $album - $title
|
||||
list_format_album: $albumartist - $album
|
||||
time_format: '%Y-%m-%d %H:%M:%S'
|
||||
|
||||
sort_album: albumartist+ album+
|
||||
sort_item: artist+ album+ disc+ track+
|
||||
|
||||
paths:
|
||||
default: $albumartist/$album%aunique{}/$track $title
|
||||
singleton: Non-Album/$artist/$title
|
||||
comp: Compilations/$album%aunique{}/$track $title
|
||||
|
||||
statefile: state.pickle
|
||||
|
||||
musicbrainz:
|
||||
host: musicbrainz.org
|
||||
ratelimit: 1
|
||||
ratelimit_interval: 1.0
|
||||
|
||||
match:
|
||||
strong_rec_thresh: 0.04
|
||||
medium_rec_thresh: 0.25
|
||||
rec_gap_thresh: 0.25
|
||||
max_rec:
|
||||
missing_tracks: medium
|
||||
unmatched_tracks: medium
|
||||
distance_weights:
|
||||
source: 2.0
|
||||
artist: 3.0
|
||||
album: 3.0
|
||||
media: 1.0
|
||||
mediums: 1.0
|
||||
year: 1.0
|
||||
country: 0.5
|
||||
label: 0.5
|
||||
catalognum: 0.5
|
||||
albumdisambig: 0.5
|
||||
album_id: 5.0
|
||||
tracks: 2.0
|
||||
missing_tracks: 0.9
|
||||
unmatched_tracks: 0.6
|
||||
track_title: 3.0
|
||||
track_artist: 2.0
|
||||
track_index: 1.0
|
||||
track_length: 2.0
|
||||
track_id: 5.0
|
||||
preferred:
|
||||
countries: []
|
||||
media: []
|
||||
original_year: no
|
||||
ignored: []
|
||||
required: []
|
||||
track_length_grace: 10
|
||||
track_length_max: 30
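# Any of the defaults above can be overridden in the user's own
# configuration file; a hypothetical override might look like:
#
# match:
#   preferred:
#     countries: ['US', 'GB|UK']
#   max_rec:
#     missing_tracks: low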
|
||||
25
lib/beets/dbcore/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
# This file is part of beets.
|
||||
# Copyright 2014, Adrian Sampson.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files (the
|
||||
# "Software"), to deal in the Software without restriction, including
|
||||
# without limitation the rights to use, copy, modify, merge, publish,
|
||||
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||
# permit persons to whom the Software is furnished to do so, subject to
|
||||
# the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
|
||||
"""DBCore is an abstract database package that forms the basis for beets'
|
||||
Library.
|
||||
"""
|
||||
from .db import Model, Database
|
||||
from .query import Query, FieldQuery, MatchQuery, AndQuery, OrQuery
|
||||
from .types import Type
|
||||
from .queryparse import query_from_strings
|
||||
from .queryparse import sort_from_strings
|
||||
from .queryparse import parse_sorted_query
|
||||
|
||||
# flake8: noqa
|
||||
820
lib/beets/dbcore/db.py
Normal file
@@ -0,0 +1,820 @@
|
||||
# This file is part of beets.
|
||||
# Copyright 2014, Adrian Sampson.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files (the
|
||||
# "Software"), to deal in the Software without restriction, including
|
||||
# without limitation the rights to use, copy, modify, merge, publish,
|
||||
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||
# permit persons to whom the Software is furnished to do so, subject to
|
||||
# the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
|
||||
"""The central Model and Database constructs for DBCore.
|
||||
"""
|
||||
import time
|
||||
import os
|
||||
from collections import defaultdict
|
||||
import threading
|
||||
import sqlite3
|
||||
import contextlib
|
||||
import collections
|
||||
|
||||
import beets
|
||||
from beets.util.functemplate import Template
|
||||
from beets.dbcore import types
|
||||
from .query import MatchQuery, NullSort, TrueQuery
|
||||
|
||||
|
||||
class FormattedMapping(collections.Mapping):
|
||||
"""A `dict`-like formatted view of a model.
|
||||
|
||||
The accessor `mapping[key]` returns the formatted version of
|
||||
`model[key]` as a unicode string.
|
||||
|
||||
If `for_path` is true, all path separators in the formatted values
|
||||
are replaced.
|
||||
"""
|
||||
|
||||
def __init__(self, model, for_path=False):
|
||||
self.for_path = for_path
|
||||
self.model = model
|
||||
self.model_keys = model.keys(True)
|
||||
|
||||
def __getitem__(self, key):
|
||||
if key in self.model_keys:
|
||||
return self._get_formatted(self.model, key)
|
||||
else:
|
||||
raise KeyError(key)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.model_keys)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.model_keys)
|
||||
|
||||
def get(self, key, default=None):
|
||||
if default is None:
|
||||
default = self.model._type(key).format(None)
|
||||
return super(FormattedMapping, self).get(key, default)
|
||||
|
||||
def _get_formatted(self, model, key):
|
||||
value = model._type(key).format(model.get(key))
|
||||
if isinstance(value, bytes):
|
||||
value = value.decode('utf8', 'ignore')
|
||||
|
||||
if self.for_path:
|
||||
sep_repl = beets.config['path_sep_replace'].get(unicode)
|
||||
for sep in (os.path.sep, os.path.altsep):
|
||||
if sep:
|
||||
value = value.replace(sep, sep_repl)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
# Abstract base for model classes.
|
||||
|
||||
class Model(object):
|
||||
"""An abstract object representing an object in the database. Model
|
||||
objects act like dictionaries (i.e., they allow subscript access like
|
||||
``obj['field']``). The same field set is available via attribute
|
||||
access as a shortcut (i.e., ``obj.field``). Three kinds of attributes are
|
||||
available:
|
||||
|
||||
* **Fixed attributes** come from a predetermined list of field
|
||||
names. These fields correspond to SQLite table columns and are
|
||||
thus fast to read, write, and query.
|
||||
* **Flexible attributes** are free-form and do not need to be listed
|
||||
ahead of time.
|
||||
* **Computed attributes** are read-only fields computed by a getter
|
||||
function provided by a plugin.
|
||||
|
||||
Access to all three field types is uniform: ``obj.field`` works the
|
||||
same regardless of whether ``field`` is fixed, flexible, or
|
||||
computed.
|
||||
|
||||
Model objects can optionally be associated with a `Library` object,
|
||||
in which case they can be loaded and stored from the database. Dirty
|
||||
flags are used to track which fields need to be stored.
|
||||
"""
|
||||
|
||||
# Abstract components (to be provided by subclasses).
|
||||
|
||||
_table = None
|
||||
"""The main SQLite table name.
|
||||
"""
|
||||
|
||||
_flex_table = None
|
||||
"""The flex field SQLite table name.
|
||||
"""
|
||||
|
||||
_fields = {}
|
||||
"""A mapping indicating available "fixed" fields on this type. The
|
||||
keys are field names and the values are `Type` objects.
|
||||
"""
|
||||
|
||||
_search_fields = ()
|
||||
"""The fields that should be queried by default by unqualified query
|
||||
terms.
|
||||
"""
|
||||
|
||||
_types = {}
|
||||
"""Optional Types for non-fixed (i.e., flexible and computed) fields.
|
||||
"""
|
||||
|
||||
_sorts = {}
|
||||
"""Optional named sort criteria. The keys are strings and the values
|
||||
are subclasses of `Sort`.
|
||||
"""
|
||||
|
||||
_always_dirty = False
|
||||
"""By default, fields only become "dirty" when their value actually
|
||||
changes. Enabling this flag marks fields as dirty even when the new
|
||||
value is the same as the old value (e.g., `o.f = o.f`).
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def _getters(cls):
|
||||
"""Return a mapping from field names to getter functions.
|
||||
"""
|
||||
# We could cache this if it becomes a performance problem to
|
||||
# gather the getter mapping every time.
|
||||
raise NotImplementedError()
|
||||
|
||||
def _template_funcs(self):
|
||||
"""Return a mapping from function names to text-transformer
|
||||
functions.
|
||||
"""
|
||||
# As above: we could consider caching this result.
|
||||
raise NotImplementedError()
|
||||
|
||||
# Basic operation.
|
||||
|
||||
def __init__(self, db=None, **values):
|
||||
"""Create a new object with an optional Database association and
|
||||
initial field values.
|
||||
"""
|
||||
self._db = db
|
||||
self._dirty = set()
|
||||
self._values_fixed = {}
|
||||
self._values_flex = {}
|
||||
|
||||
# Initial contents.
|
||||
self.update(values)
|
||||
self.clear_dirty()
|
||||
|
||||
@classmethod
|
||||
def _awaken(cls, db=None, fixed_values={}, flex_values={}):
|
||||
"""Create an object with values drawn from the database.
|
||||
|
||||
This is a performance optimization: the checks involved with
|
||||
ordinary construction are bypassed.
|
||||
"""
|
||||
obj = cls(db)
|
||||
for key, value in fixed_values.iteritems():
|
||||
obj._values_fixed[key] = cls._type(key).from_sql(value)
|
||||
for key, value in flex_values.iteritems():
|
||||
obj._values_flex[key] = cls._type(key).from_sql(value)
|
||||
return obj
|
||||
|
||||
def __repr__(self):
|
||||
return '{0}({1})'.format(
|
||||
type(self).__name__,
|
||||
', '.join('{0}={1!r}'.format(k, v) for k, v in dict(self).items()),
|
||||
)
|
||||
|
||||
def clear_dirty(self):
|
||||
"""Mark all fields as *clean* (i.e., not needing to be stored to
|
||||
the database).
|
||||
"""
|
||||
self._dirty = set()
|
||||
|
||||
def _check_db(self, need_id=True):
|
||||
"""Ensure that this object is associated with a database row: it
|
||||
has a reference to a database (`_db`) and an id. A ValueError
|
||||
exception is raised otherwise.
|
||||
"""
|
||||
if not self._db:
|
||||
raise ValueError('{0} has no database'.format(type(self).__name__))
|
||||
if need_id and not self.id:
|
||||
raise ValueError('{0} has no id'.format(type(self).__name__))
|
||||
|
||||
# Essential field accessors.
|
||||
|
||||
@classmethod
|
||||
def _type(self, key):
|
||||
"""Get the type of a field, a `Type` instance.
|
||||
|
||||
If the field has no explicit type, it is given the base `Type`,
|
||||
which does no conversion.
|
||||
"""
|
||||
return self._fields.get(key) or self._types.get(key) or types.DEFAULT
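    # Resolution order: a fixed field declared in _fields wins, then a
    # flexible/computed type declared in _types, and anything else
    # falls back to types.DEFAULT, which performs no conversion.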
|
||||
|
||||
def __getitem__(self, key):
|
||||
"""Get the value for a field. Raise a KeyError if the field is
|
||||
not available.
|
||||
"""
|
||||
getters = self._getters()
|
||||
if key in getters: # Computed.
|
||||
return getters[key](self)
|
||||
elif key in self._fields: # Fixed.
|
||||
return self._values_fixed.get(key)
|
||||
elif key in self._values_flex: # Flexible.
|
||||
return self._values_flex[key]
|
||||
else:
|
||||
raise KeyError(key)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
"""Assign the value for a field.
|
||||
"""
|
||||
# Choose where to place the value.
|
||||
if key in self._fields:
|
||||
source = self._values_fixed
|
||||
else:
|
||||
source = self._values_flex
|
||||
|
||||
# If the field has a type, filter the value.
|
||||
value = self._type(key).normalize(value)
|
||||
|
||||
# Assign value and possibly mark as dirty.
|
||||
old_value = source.get(key)
|
||||
source[key] = value
|
||||
if self._always_dirty or old_value != value:
|
||||
self._dirty.add(key)
|
||||
|
||||
def __delitem__(self, key):
|
||||
"""Remove a flexible attribute from the model.
|
||||
"""
|
||||
if key in self._values_flex: # Flexible.
|
||||
del self._values_flex[key]
|
||||
self._dirty.add(key) # Mark for dropping on store.
|
||||
elif key in self._getters(): # Computed.
|
||||
raise KeyError('computed field {0} cannot be deleted'.format(key))
|
||||
elif key in self._fields: # Fixed.
|
||||
raise KeyError('fixed field {0} cannot be deleted'.format(key))
|
||||
else:
|
||||
raise KeyError('no such field {0}'.format(key))
|
||||
|
||||
def keys(self, computed=False):
|
||||
"""Get a list of available field names for this object. The
|
||||
`computed` parameter controls whether computed (plugin-provided)
|
||||
fields are included in the key list.
|
||||
"""
|
||||
base_keys = list(self._fields) + self._values_flex.keys()
|
||||
if computed:
|
||||
return base_keys + self._getters().keys()
|
||||
else:
|
||||
return base_keys
|
||||
|
||||
# Act like a dictionary.
|
||||
|
||||
def update(self, values):
|
||||
"""Assign all values in the given dict.
|
||||
"""
|
||||
for key, value in values.items():
|
||||
self[key] = value
|
||||
|
||||
def items(self):
|
||||
"""Iterate over (key, value) pairs that this object contains.
|
||||
Computed fields are not included.
|
||||
"""
|
||||
for key in self:
|
||||
yield key, self[key]
|
||||
|
||||
def get(self, key, default=None):
|
||||
"""Get the value for a given key or `default` if it does not
|
||||
exist.
|
||||
"""
|
||||
if key in self:
|
||||
return self[key]
|
||||
else:
|
||||
return default
|
||||
|
||||
def __contains__(self, key):
|
||||
"""Determine whether `key` is an attribute on this object.
|
||||
"""
|
||||
return key in self.keys(True)
|
||||
|
||||
def __iter__(self):
|
||||
"""Iterate over the available field names (excluding computed
|
||||
fields).
|
||||
"""
|
||||
return iter(self.keys())
|
||||
|
||||
# Convenient attribute access.
|
||||
|
||||
def __getattr__(self, key):
|
||||
if key.startswith('_'):
|
||||
raise AttributeError('model has no attribute {0!r}'.format(key))
|
||||
else:
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError:
|
||||
raise AttributeError('no such field {0!r}'.format(key))
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if key.startswith('_'):
|
||||
super(Model, self).__setattr__(key, value)
|
||||
else:
|
||||
self[key] = value
|
||||
|
||||
def __delattr__(self, key):
|
||||
if key.startswith('_'):
|
||||
super(Model, self).__delattr__(key)
|
||||
else:
|
||||
del self[key]
|
||||
|
||||
# Database interaction (CRUD methods).
|
||||
|
||||
def store(self):
|
||||
"""Save the object's metadata into the library database.
|
||||
"""
|
||||
self._check_db()
|
||||
|
||||
# Build assignments for query.
|
||||
assignments = []
|
||||
subvars = []
|
||||
for key in self._fields:
|
||||
if key != 'id' and key in self._dirty:
|
||||
self._dirty.remove(key)
|
||||
assignments.append(key + '=?')
|
||||
value = self._type(key).to_sql(self[key])
|
||||
subvars.append(value)
|
||||
assignments = ','.join(assignments)
|
||||
|
||||
with self._db.transaction() as tx:
|
||||
# Main table update.
|
||||
if assignments:
|
||||
query = 'UPDATE {0} SET {1} WHERE id=?'.format(
|
||||
self._table, assignments
|
||||
)
|
||||
subvars.append(self.id)
|
||||
tx.mutate(query, subvars)
|
||||
|
||||
# Modified/added flexible attributes.
|
||||
for key, value in self._values_flex.items():
|
||||
if key in self._dirty:
|
||||
self._dirty.remove(key)
|
||||
tx.mutate(
|
||||
'INSERT INTO {0} '
|
||||
'(entity_id, key, value) '
|
||||
'VALUES (?, ?, ?);'.format(self._flex_table),
|
||||
(self.id, key, value),
|
||||
)
|
||||
|
||||
# Deleted flexible attributes.
|
||||
for key in self._dirty:
|
||||
tx.mutate(
|
||||
'DELETE FROM {0} '
|
||||
'WHERE entity_id=? AND key=?'.format(self._flex_table),
|
||||
(self.id, key)
|
||||
)
|
||||
|
||||
self.clear_dirty()
|
||||
|
||||
def load(self):
|
||||
"""Refresh the object's metadata from the library database.
|
||||
"""
|
||||
self._check_db()
|
||||
stored_obj = self._db._get(type(self), self.id)
|
||||
assert stored_obj is not None, "object {0} not in DB".format(self.id)
|
||||
self._values_fixed = {}
|
||||
self._values_flex = {}
|
||||
self.update(dict(stored_obj))
|
||||
self.clear_dirty()
|
||||
|
||||
def remove(self):
|
||||
"""Remove the object's associated rows from the database.
|
||||
"""
|
||||
self._check_db()
|
||||
with self._db.transaction() as tx:
|
||||
tx.mutate(
|
||||
'DELETE FROM {0} WHERE id=?'.format(self._table),
|
||||
(self.id,)
|
||||
)
|
||||
tx.mutate(
|
||||
'DELETE FROM {0} WHERE entity_id=?'.format(self._flex_table),
|
||||
(self.id,)
|
||||
)
|
||||
|
||||
def add(self, db=None):
|
||||
"""Add the object to the library database. This object must be
|
||||
associated with a database; you can provide one via the `db`
|
||||
parameter or use the currently associated database.
|
||||
|
||||
The object's `id` and `added` fields are set along with any
|
||||
current field values.
|
||||
"""
|
||||
if db:
|
||||
self._db = db
|
||||
self._check_db(False)
|
||||
|
||||
with self._db.transaction() as tx:
|
||||
new_id = tx.mutate(
|
||||
'INSERT INTO {0} DEFAULT VALUES'.format(self._table)
|
||||
)
|
||||
self.id = new_id
|
||||
self.added = time.time()
|
||||
|
||||
# Mark every non-null field as dirty and store.
|
||||
for key in self:
|
||||
if self[key] is not None:
|
||||
self._dirty.add(key)
|
||||
self.store()
|
||||
|
||||
# Formatting and templating.
|
||||
|
||||
_formatter = FormattedMapping
|
||||
|
||||
def formatted(self, for_path=False):
|
||||
"""Get a mapping containing all values on this object formatted
|
||||
as human-readable unicode strings.
|
||||
"""
|
||||
return self._formatter(self, for_path)
|
||||
|
||||
def evaluate_template(self, template, for_path=False):
|
||||
"""Evaluate a template (a string or a `Template` object) using
|
||||
the object's fields. If `for_path` is true, then no new path
|
||||
separators will be added to the template.
|
||||
"""
|
||||
# Perform substitution.
|
||||
if isinstance(template, basestring):
|
||||
template = Template(template)
|
||||
return template.substitute(self.formatted(for_path),
|
||||
self._template_funcs())
|
||||
|
||||
# Parsing.
|
||||
|
||||
@classmethod
|
||||
def _parse(cls, key, string):
|
||||
"""Parse a string as a value for the given key.
|
||||
"""
|
||||
if not isinstance(string, basestring):
|
||||
raise TypeError("_parse() argument must be a string")
|
||||
|
||||
return cls._type(key).parse(string)
|
||||
|
||||
|
||||
# Database controller and supporting interfaces.
|
||||
|
||||
class Results(object):
|
||||
"""An item query result set. Iterating over the collection lazily
|
||||
constructs LibModel objects that reflect database rows.
|
||||
"""
|
||||
def __init__(self, model_class, rows, db, query=None, sort=None):
|
||||
"""Create a result set that will construct objects of type
|
||||
`model_class`.
|
||||
|
||||
`model_class` is a subclass of `LibModel` that will be
|
||||
constructed. `rows` is a query result: a list of mappings. The
|
||||
new objects will be associated with the database `db`.
|
||||
|
||||
If `query` is provided, it is used as a predicate to filter the
|
||||
results for a "slow query" that cannot be evaluated by the
|
||||
database directly. If `sort` is provided, it is used to sort the
|
||||
full list of results before returning. This means it is a "slow
|
||||
sort" and all objects must be built before returning the first
|
||||
one.
|
||||
"""
|
||||
self.model_class = model_class
|
||||
self.rows = rows
|
||||
self.db = db
|
||||
self.query = query
|
||||
self.sort = sort
|
||||
|
||||
# We keep a queue of rows we haven't yet consumed for
|
||||
# materialization. We preserve the original total number of
|
||||
# rows.
|
||||
self._rows = rows
|
||||
self._row_count = len(rows)
|
||||
|
||||
# The materialized objects corresponding to rows that have been
|
||||
# consumed.
|
||||
self._objects = []
|
||||
|
||||
def _get_objects(self):
|
||||
"""Construct and generate Model objects for they query. The
|
||||
objects are returned in the order emitted from the database; no
|
||||
slow sort is applied.
|
||||
|
||||
For performance, this generator caches materialized objects to
|
||||
avoid constructing them more than once. This way, iterating over
|
||||
a `Results` object a second time should be much faster than the
|
||||
first.
|
||||
"""
|
||||
index = 0 # Position in the materialized objects.
|
||||
while index < len(self._objects) or self._rows:
|
||||
# Are there previously-materialized objects to produce?
|
||||
if index < len(self._objects):
|
||||
yield self._objects[index]
|
||||
index += 1
|
||||
|
||||
# Otherwise, we consume another row, materialize its object
|
||||
# and produce it.
|
||||
else:
|
||||
while self._rows:
|
||||
row = self._rows.pop(0)
|
||||
obj = self._make_model(row)
|
||||
# If there is a slow-query predicate, ensure that the
|
||||
# object passes it.
|
||||
if not self.query or self.query.match(obj):
|
||||
self._objects.append(obj)
|
||||
index += 1
|
||||
yield obj
|
||||
break
|
||||
|
||||
def __iter__(self):
|
||||
"""Construct and generate Model objects for all matching
|
||||
objects, in sorted order.
|
||||
"""
|
||||
if self.sort:
|
||||
# Slow sort. Must build the full list first.
|
||||
objects = self.sort.sort(list(self._get_objects()))
|
||||
return iter(objects)
|
||||
|
||||
else:
|
||||
# Objects are pre-sorted (i.e., by the database).
|
||||
return self._get_objects()
|
||||
|
||||
def _make_model(self, row):
|
||||
# Get the flexible attributes for the object.
|
||||
with self.db.transaction() as tx:
|
||||
flex_rows = tx.query(
|
||||
'SELECT * FROM {0} WHERE entity_id=?'.format(
|
||||
self.model_class._flex_table
|
||||
),
|
||||
(row['id'],)
|
||||
)
|
||||
|
||||
cols = dict(row)
|
||||
values = dict((k, v) for (k, v) in cols.items()
|
||||
if not k[:4] == 'flex')
|
||||
flex_values = dict((row['key'], row['value']) for row in flex_rows)
|
||||
|
||||
# Construct the Python object
|
||||
obj = self.model_class._awaken(self.db, values, flex_values)
|
||||
return obj
|
||||
|
||||
def __len__(self):
|
||||
"""Get the number of matching objects.
|
||||
"""
|
||||
if not self._rows:
|
||||
# Fully materialized. Just count the objects.
|
||||
return len(self._objects)
|
||||
|
||||
elif self.query:
|
||||
# A slow query. Fall back to testing every object.
|
||||
count = 0
|
||||
for obj in self:
|
||||
count += 1
|
||||
return count
|
||||
|
||||
else:
|
||||
# A fast query. Just count the rows.
|
||||
return self._row_count
|
||||
|
||||
def __nonzero__(self):
|
||||
"""Does this result contain any objects?
|
||||
"""
|
||||
return bool(len(self))
|
||||
|
||||
def __getitem__(self, n):
|
||||
"""Get the nth item in this result set. This is inefficient: all
|
||||
items up to n are materialized and thrown away.
|
||||
"""
|
||||
if not self._rows and not self.sort:
|
||||
# Fully materialized and already in order. Just look up the
|
||||
# object.
|
||||
return self._objects[n]
|
||||
|
||||
it = iter(self)
|
||||
try:
|
||||
for i in range(n):
|
||||
it.next()
|
||||
return it.next()
|
||||
except StopIteration:
|
||||
raise IndexError('result index {0} out of range'.format(n))
|
||||
|
||||
def get(self):
|
||||
"""Return the first matching object, or None if no objects
|
||||
match.
|
||||
"""
|
||||
it = iter(self)
|
||||
try:
|
||||
return it.next()
|
||||
except StopIteration:
|
||||
return None
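    # Usage note: iterating a Results object a second time is cheap
    # because materialized objects are cached; get() and __getitem__
    # both walk the same iterator, so get() is the inexpensive way to
    # take just the first match.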
|
||||
|
||||
|
||||
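# A minimal usage sketch for `Results` (editorial example; `db` and
# `Song` are hypothetical, standing in for a concrete `Database` and a
# registered model class):
#
#     results = db._fetch(Song, SubstringQuery('title', 'rain'))
#     print len(results)      # cheap when no slow query is involved
#     first = results.get()   # first match, or None
#     for song in results:    # a second pass reuses cached objects
#         print song.title
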
class Transaction(object):
    """A context manager for safe, concurrent access to the database.
    All SQL commands should be executed through a transaction.
    """
    def __init__(self, db):
        self.db = db

    def __enter__(self):
        """Begin a transaction. This transaction may be created while
        another is active in a different thread.
        """
        with self.db._tx_stack() as stack:
            first = not stack
            stack.append(self)
        if first:
            # Beginning a "root" transaction, which corresponds to an
            # SQLite transaction.
            self.db._db_lock.acquire()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Complete a transaction. This must be the most recently
        entered but not yet exited transaction. If it is the last active
        transaction, the database updates are committed.
        """
        with self.db._tx_stack() as stack:
            assert stack.pop() is self
            empty = not stack
        if empty:
            # Ending a "root" transaction. End the SQLite transaction.
            self.db._connection().commit()
            self.db._db_lock.release()

    def query(self, statement, subvals=()):
        """Execute an SQL statement with substitution values and return
        a list of rows from the database.
        """
        cursor = self.db._connection().execute(statement, subvals)
        return cursor.fetchall()

    def mutate(self, statement, subvals=()):
        """Execute an SQL statement with substitution values and return
        the row ID of the last affected row.
        """
        cursor = self.db._connection().execute(statement, subvals)
        return cursor.lastrowid

    def script(self, statements):
        """Execute a string containing multiple SQL statements."""
        self.db._connection().executescript(statements)


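# Transactions nest: entering a second `with db.transaction()` block on
# the same thread just pushes onto that thread's stack, and only the
# outermost exit commits and releases the lock. A small sketch
# (editorial example; table and column names are made up):
#
#     with db.transaction() as tx:         # root: acquires the DB lock
#         tx.mutate('INSERT INTO items (title) VALUES (?)', ('a',))
#         with db.transaction() as inner:  # nested: no new SQLite txn
#             inner.query('SELECT * FROM items')
#     # outermost exit commits and releases the lock
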
class Database(object):
    """A container for Model objects that wraps an SQLite database as
    the backend.
    """
    _models = ()
    """The Model subclasses representing tables in this database.
    """

    def __init__(self, path):
        self.path = path

        self._connections = {}
        self._tx_stacks = defaultdict(list)

        # A lock to protect the _connections and _tx_stacks maps, which
        # both map thread IDs to private resources.
        self._shared_map_lock = threading.Lock()

        # A lock to protect access to the database itself. SQLite does
        # allow multiple threads to access the database at the same
        # time, but many users were experiencing crashes related to this
        # capability: where SQLite was compiled without HAVE_USLEEP, its
        # backoff algorithm in the case of contention was causing
        # whole-second sleeps (!) that would trigger its internal
        # timeout. Using this lock ensures only one SQLite transaction
        # is active at a time.
        self._db_lock = threading.Lock()

        # Set up database schema.
        for model_cls in self._models:
            self._make_table(model_cls._table, model_cls._fields)
            self._make_attribute_table(model_cls._flex_table)

    # Primitive access control: connections and transactions.

    def _connection(self):
        """Get a SQLite connection object to the underlying database.
        One connection object is created per thread.
        """
        thread_id = threading.current_thread().ident
        with self._shared_map_lock:
            if thread_id in self._connections:
                return self._connections[thread_id]
            else:
                # Make a new connection.
                conn = sqlite3.connect(
                    self.path,
                    timeout=beets.config['timeout'].as_number(),
                )

                # Access SELECT results like dictionaries.
                conn.row_factory = sqlite3.Row

                self._connections[thread_id] = conn
                return conn

    @contextlib.contextmanager
    def _tx_stack(self):
        """A context manager providing access to the current thread's
        transaction stack. The context manager synchronizes access to
        the stack map. Transactions should never migrate across threads.
        """
        thread_id = threading.current_thread().ident
        with self._shared_map_lock:
            yield self._tx_stacks[thread_id]

    def transaction(self):
        """Get a :class:`Transaction` object for interacting directly
        with the underlying SQLite database.
        """
        return Transaction(self)

    # Schema setup and migration.

    def _make_table(self, table, fields):
        """Set up the schema of the database. `fields` is a mapping
        from field names to `Type`s. Columns are added if necessary.
        """
        # Get current schema.
        with self.transaction() as tx:
            rows = tx.query('PRAGMA table_info(%s)' % table)
        current_fields = set([row[1] for row in rows])

        field_names = set(fields.keys())
        if current_fields.issuperset(field_names):
            # Table exists and has all the required columns.
            return

        if not current_fields:
            # No table exists.
            columns = []
            for name, typ in fields.items():
                columns.append('{0} {1}'.format(name, typ.sql))
            setup_sql = 'CREATE TABLE {0} ({1});\n'.format(table,
                                                           ', '.join(columns))

        else:
            # Table exists but does not match the field set.
            setup_sql = ''
            for name, typ in fields.items():
                if name in current_fields:
                    continue
                setup_sql += 'ALTER TABLE {0} ADD COLUMN {1} {2};\n'.format(
                    table, name, typ.sql
                )

        with self.transaction() as tx:
            tx.script(setup_sql)

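# Migration here is purely additive: a sketch of what `_make_table`
# emits when a new field is added to an existing `items` table
# (editorial example; the column name is made up):
#
#     ALTER TABLE items ADD COLUMN comments TEXT;
#
# Existing columns are never dropped or retyped, so an old database
# keeps working after an upgrade.
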
    def _make_attribute_table(self, flex_table):
        """Create a table and associated index for flexible attributes
        for the given entity (if they don't exist).
        """
        with self.transaction() as tx:
            tx.script("""
                CREATE TABLE IF NOT EXISTS {0} (
                    id INTEGER PRIMARY KEY,
                    entity_id INTEGER,
                    key TEXT,
                    value TEXT,
                    UNIQUE(entity_id, key) ON CONFLICT REPLACE);
                CREATE INDEX IF NOT EXISTS {0}_by_entity
                    ON {0} (entity_id);
                """.format(flex_table))

    # Querying.

    def _fetch(self, model_cls, query=None, sort=None):
        """Fetch the objects of type `model_cls` matching the given
        query. The query may be given as a string, string sequence, a
        Query object, or None (to fetch everything). `sort` is a
        `Sort` object.
        """
        query = query or TrueQuery()  # A null query.
        sort = sort or NullSort()  # Unsorted.
        where, subvals = query.clause()
        order_by = sort.order_clause()

        sql = ("SELECT * FROM {0} WHERE {1} {2}").format(
            model_cls._table,
            where or '1',
            "ORDER BY {0}".format(order_by) if order_by else '',
        )

        with self.transaction() as tx:
            rows = tx.query(sql, subvals)

        return Results(
            model_cls, rows, self,
            None if where else query,  # Slow query component.
            sort if sort.is_slow() else None,  # Slow sort component.
        )

    def _get(self, model_cls, id):
        """Get a Model object by its id or None if the id does not
        exist.
        """
        return self._fetch(model_cls, MatchQuery('id', id)).get()
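
# Fetching ties the pieces together: fast query parts become the SQL
# WHERE clause, while slow parts are matched in Python by `Results`. A
# rough sketch (editorial example; `Song` is a hypothetical model):
#
#     song = db._get(Song, 42)  # look up by primary key
#     hits = db._fetch(Song, NumericQuery('year', '2001..'),
#                      FixedFieldSort('year', ascending=False))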
654
lib/beets/dbcore/query.py
Normal file
@@ -0,0 +1,654 @@
# This file is part of beets.
# Copyright 2014, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.

"""The Query type hierarchy for DBCore.
"""
import re
from operator import attrgetter
from beets import util
from datetime import datetime, timedelta


class Query(object):
    """An abstract class representing a query into the item database.
    """
    def clause(self):
        """Generate an SQLite expression implementing the query.

        Return (clause, subvals) where clause is a valid SQLite
        WHERE clause implementing the query and subvals is a list of
        items to be substituted for ?s in the clause. A clause of
        None indicates a "slow" query that must be matched in Python.
        """
        return None, ()

    def match(self, item):
        """Check whether this query matches a given Item. Can be used to
        perform queries on arbitrary sets of Items.
        """
        raise NotImplementedError


class FieldQuery(Query):
    """An abstract query that searches in a specific field for a
    pattern. Subclasses must provide a `value_match` class method, which
    determines whether a certain pattern string matches a certain value
    string. Subclasses may also provide `col_clause` to implement the
    same matching functionality in SQLite.
    """
    def __init__(self, field, pattern, fast=True):
        self.field = field
        self.pattern = pattern
        self.fast = fast

    def col_clause(self):
        return None, ()

    def clause(self):
        if self.fast:
            return self.col_clause()
        else:
            # Matching a flexattr. This is a slow query.
            return None, ()

    @classmethod
    def value_match(cls, pattern, value):
        """Determine whether the value matches the pattern. Both
        arguments are strings.
        """
        raise NotImplementedError()

    def match(self, item):
        return self.value_match(self.pattern, item.get(self.field))


class MatchQuery(FieldQuery):
    """A query that looks for exact matches in an item field."""
    def col_clause(self):
        return self.field + " = ?", [self.pattern]

    @classmethod
    def value_match(cls, pattern, value):
        return pattern == value


class NoneQuery(FieldQuery):

    def __init__(self, field, fast=True):
        self.field = field
        self.fast = fast

    def col_clause(self):
        return self.field + " IS NULL", ()

    def match(self, item):
        try:
            return item[self.field] is None
        except KeyError:
            return True


class StringFieldQuery(FieldQuery):
    """A FieldQuery that converts values to strings before matching
    them.
    """
    @classmethod
    def value_match(cls, pattern, value):
        """Determine whether the value matches the pattern. The value
        may have any type.
        """
        return cls.string_match(pattern, util.as_string(value))

    @classmethod
    def string_match(cls, pattern, value):
        """Determine whether the value matches the pattern. Both
        arguments are strings. Subclasses implement this method.
        """
        raise NotImplementedError()


class SubstringQuery(StringFieldQuery):
    """A query that matches a substring in a specific item field."""
    def col_clause(self):
        pattern = (self.pattern
                   .replace('\\', '\\\\')
                   .replace('%', '\\%')
                   .replace('_', '\\_'))
        search = '%' + pattern + '%'
        clause = self.field + " like ? escape '\\'"
        subvals = [search]
        return clause, subvals

    @classmethod
    def string_match(cls, pattern, value):
        return pattern.lower() in value.lower()


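# The escaping above keeps user input from being interpreted as LIKE
# wildcards. A quick illustration (editorial example):
#
#     >>> SubstringQuery('title', '100%').col_clause()
#     ("title like ? escape '\\'", ['%100\\%%'])
#
# The literal '%' in the pattern is escaped, while the surrounding '%'s
# added by the query do the substring matching.
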
class RegexpQuery(StringFieldQuery):
    """A query that matches a regular expression in a specific item
    field.
    """
    @classmethod
    def string_match(cls, pattern, value):
        try:
            res = re.search(pattern, value)
        except re.error:
            # Invalid regular expression.
            return False
        return res is not None


class BooleanQuery(MatchQuery):
    """Matches a boolean field. Pattern should either be a boolean or a
    string reflecting a boolean.
    """
    def __init__(self, field, pattern, fast=True):
        super(BooleanQuery, self).__init__(field, pattern, fast)
        if isinstance(pattern, basestring):
            self.pattern = util.str2bool(pattern)
        self.pattern = int(self.pattern)


class BytesQuery(MatchQuery):
    """Match a raw bytes field (i.e., a path). This is a necessary hack
    to work around the `sqlite3` module's desire to treat `str` and
    `unicode` equivalently in Python 2. Always use this query instead of
    `MatchQuery` when matching on BLOB values.
    """
    def __init__(self, field, pattern):
        super(BytesQuery, self).__init__(field, pattern)

        # Use a buffer representation of the pattern for SQLite
        # matching. This instructs SQLite to treat the blob as binary
        # rather than encoded Unicode.
        if isinstance(self.pattern, basestring):
            # Implicitly coerce Unicode strings to their bytes
            # equivalents.
            if isinstance(self.pattern, unicode):
                self.pattern = self.pattern.encode('utf8')
            self.buf_pattern = buffer(self.pattern)
        elif isinstance(self.pattern, buffer):
            self.buf_pattern = self.pattern
            self.pattern = bytes(self.pattern)

    def col_clause(self):
        return self.field + " = ?", [self.buf_pattern]


class NumericQuery(FieldQuery):
    """Matches numeric fields. A syntax using Ruby-style range ellipses
    (``..``) lets users specify one- or two-sided ranges. For example,
    ``year:2001..`` finds music released since the turn of the century.
    """
    def _convert(self, s):
        """Convert a string to a numeric type (float or int). If the
        string cannot be converted, return None.
        """
        # This is really just a bit of fun premature optimization.
        try:
            return int(s)
        except ValueError:
            try:
                return float(s)
            except ValueError:
                return None

    def __init__(self, field, pattern, fast=True):
        super(NumericQuery, self).__init__(field, pattern, fast)

        parts = pattern.split('..', 1)
        if len(parts) == 1:
            # No range.
            self.point = self._convert(parts[0])
            self.rangemin = None
            self.rangemax = None
        else:
            # One- or two-sided range.
            self.point = None
            self.rangemin = self._convert(parts[0])
            self.rangemax = self._convert(parts[1])

    def match(self, item):
        if self.field not in item:
            return False
        value = item[self.field]
        if isinstance(value, basestring):
            value = self._convert(value)

        if self.point is not None:
            return value == self.point
        else:
            if self.rangemin is not None and value < self.rangemin:
                return False
            if self.rangemax is not None and value > self.rangemax:
                return False
            return True

    def col_clause(self):
        if self.point is not None:
            return self.field + '=?', (self.point,)
        else:
            if self.rangemin is not None and self.rangemax is not None:
                return (u'{0} >= ? AND {0} <= ?'.format(self.field),
                        (self.rangemin, self.rangemax))
            elif self.rangemin is not None:
                return u'{0} >= ?'.format(self.field), (self.rangemin,)
            elif self.rangemax is not None:
                return u'{0} <= ?'.format(self.field), (self.rangemax,)
            else:
                return '1', ()


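# How the range syntax maps to SQL, for a couple of patterns (editorial
# example):
#
#     >>> NumericQuery('year', '2001..').col_clause()
#     (u'year >= ?', (2001,))
#     >>> NumericQuery('year', '1990..1999').col_clause()
#     (u'year >= ? AND year <= ?', (1990, 1999))
#
# A pattern containing no '..' matches a single point instead.
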
class CollectionQuery(Query):
    """An abstract query class that aggregates other queries. Can be
    indexed like a list to access the sub-queries.
    """
    def __init__(self, subqueries=()):
        self.subqueries = subqueries

    # Act like a sequence.

    def __len__(self):
        return len(self.subqueries)

    def __getitem__(self, key):
        return self.subqueries[key]

    def __iter__(self):
        return iter(self.subqueries)

    def __contains__(self, item):
        return item in self.subqueries

    def clause_with_joiner(self, joiner):
        """Returns a clause created by joining together the clauses of
        all subqueries with the string joiner (padded by spaces).
        """
        clause_parts = []
        subvals = []
        for subq in self.subqueries:
            subq_clause, subq_subvals = subq.clause()
            if not subq_clause:
                # Fall back to slow query.
                return None, ()
            clause_parts.append('(' + subq_clause + ')')
            subvals += subq_subvals
        clause = (' ' + joiner + ' ').join(clause_parts)
        return clause, subvals


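# One slow subquery makes the whole collection slow: the clause falls
# back to None and matching happens in Python. When every part is fast,
# the clauses are simply parenthesized and joined (editorial example):
#
#     >>> q = AndQuery([MatchQuery('album', 'x'),
#     ...               NumericQuery('year', '2001..')])
#     >>> q.clause()
#     (u'(album = ?) and (year >= ?)', ['x', 2001])
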
class AnyFieldQuery(CollectionQuery):
    """A query that matches if a given FieldQuery subclass matches in
    any field. The individual field query class is provided to the
    constructor.
    """
    def __init__(self, pattern, fields, cls):
        self.pattern = pattern
        self.fields = fields
        self.query_class = cls

        subqueries = []
        for field in self.fields:
            subqueries.append(cls(field, pattern, True))
        super(AnyFieldQuery, self).__init__(subqueries)

    def clause(self):
        return self.clause_with_joiner('or')

    def match(self, item):
        for subq in self.subqueries:
            if subq.match(item):
                return True
        return False


class MutableCollectionQuery(CollectionQuery):
    """A collection query whose subqueries may be modified after the
    query is initialized.
    """
    def __setitem__(self, key, value):
        self.subqueries[key] = value

    def __delitem__(self, key):
        del self.subqueries[key]


class AndQuery(MutableCollectionQuery):
    """A conjunction of a list of other queries."""
    def clause(self):
        return self.clause_with_joiner('and')

    def match(self, item):
        return all([q.match(item) for q in self.subqueries])


class OrQuery(MutableCollectionQuery):
    """A disjunction of a list of other queries."""
    def clause(self):
        return self.clause_with_joiner('or')

    def match(self, item):
        return any([q.match(item) for q in self.subqueries])


class TrueQuery(Query):
    """A query that always matches."""
    def clause(self):
        return '1', ()

    def match(self, item):
        return True


class FalseQuery(Query):
    """A query that never matches."""
    def clause(self):
        return '0', ()

    def match(self, item):
        return False


# Time/date queries.

def _to_epoch_time(date):
    """Convert a `datetime` object to an integer number of seconds since
    the (local) Unix epoch.
    """
    epoch = datetime.fromtimestamp(0)
    delta = date - epoch
    try:
        return int(delta.total_seconds())
    except AttributeError:
        # datetime.timedelta.total_seconds() is not available on Python 2.6.
        return delta.seconds + delta.days * 24 * 3600


def _parse_periods(pattern):
    """Parse a string containing two dates separated by two dots (..).
    Return a pair of `Period` objects.
    """
    parts = pattern.split('..', 1)
    if len(parts) == 1:
        instant = Period.parse(parts[0])
        return (instant, instant)
    else:
        start = Period.parse(parts[0])
        end = Period.parse(parts[1])
        return (start, end)


class Period(object):
    """A period of time given by a date, time and precision.

    Example: 2014-01-01 10:50:30 with precision 'month' represents all
    instants of time during January 2014.
    """

    precisions = ('year', 'month', 'day')
    date_formats = ('%Y', '%Y-%m', '%Y-%m-%d')

    def __init__(self, date, precision):
        """Create a period with the given date (a `datetime` object) and
        precision (a string, one of "year", "month", or "day").
        """
        if precision not in Period.precisions:
            raise ValueError('Invalid precision ' + str(precision))
        self.date = date
        self.precision = precision

    @classmethod
    def parse(cls, string):
        """Parse a date and return a `Period` object or `None` if the
        string is empty.
        """
        if not string:
            return None
        ordinal = string.count('-')
        if ordinal >= len(cls.date_formats):
            # Too many components.
            return None
        date_format = cls.date_formats[ordinal]
        try:
            date = datetime.strptime(string, date_format)
        except ValueError:
            # Parsing failed.
            return None
        precision = cls.precisions[ordinal]
        return cls(date, precision)

    def open_right_endpoint(self):
        """Based on the precision, convert the period to a precise
        `datetime` for use as a right endpoint in a right-open interval.
        """
        precision = self.precision
        date = self.date
        if 'year' == precision:
            return date.replace(year=date.year + 1, month=1)
        elif 'month' == precision:
            if date.month < 12:
                return date.replace(month=date.month + 1)
            else:
                return date.replace(year=date.year + 1, month=1)
        elif 'day' == precision:
            return date + timedelta(days=1)
        else:
            raise ValueError('unhandled precision ' + str(precision))


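# The number of '-' separators picks the format and the precision in
# one step (editorial example):
#
#     >>> Period.parse('2014').precision
#     'year'
#     >>> Period.parse('2014-05').open_right_endpoint()
#     datetime.datetime(2014, 6, 1, 0, 0)
#
# So '2014-05' covers all instants strictly before June 1, 2014.
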
class DateInterval(object):
    """A closed-open interval of dates.

    A left endpoint of None means since the beginning of time.
    A right endpoint of None means towards infinity.
    """

    def __init__(self, start, end):
        if start is not None and end is not None and not start < end:
            raise ValueError("start date {0} is not before end date {1}"
                             .format(start, end))
        self.start = start
        self.end = end

    @classmethod
    def from_periods(cls, start, end):
        """Create an interval with two Periods as the endpoints.
        """
        end_date = end.open_right_endpoint() if end is not None else None
        start_date = start.date if start is not None else None
        return cls(start_date, end_date)

    def contains(self, date):
        if self.start is not None and date < self.start:
            return False
        if self.end is not None and date >= self.end:
            return False
        return True

    def __str__(self):
        return '[{0}, {1})'.format(self.start, self.end)


class DateQuery(FieldQuery):
    """Matches date fields stored as seconds since Unix epoch time.

    Dates can be specified as ``year-month-day`` strings where only year
    is mandatory.

    The value of a date field can be matched against a date interval by
    using an ellipsis interval syntax similar to that of NumericQuery.
    """
    def __init__(self, field, pattern, fast=True):
        super(DateQuery, self).__init__(field, pattern, fast)
        start, end = _parse_periods(pattern)
        self.interval = DateInterval.from_periods(start, end)

    def match(self, item):
        timestamp = float(item[self.field])
        date = datetime.utcfromtimestamp(timestamp)
        return self.interval.contains(date)

    _clause_tmpl = "{0} {1} ?"

    def col_clause(self):
        clause_parts = []
        subvals = []

        if self.interval.start:
            clause_parts.append(self._clause_tmpl.format(self.field, ">="))
            subvals.append(_to_epoch_time(self.interval.start))

        if self.interval.end:
            clause_parts.append(self._clause_tmpl.format(self.field, "<"))
            subvals.append(_to_epoch_time(self.interval.end))

        if clause_parts:
            # One- or two-sided interval.
            clause = ' AND '.join(clause_parts)
        else:
            # Match any date.
            clause = '1'
        return clause, subvals


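# A date range becomes a closed-open SQL interval (editorial example;
# the exact epoch values depend on the local timezone):
#
#     >>> DateQuery('added', '2013..2014').col_clause()
#     ('added >= ? AND added < ?', [1356998400, 1420070400])
#
# That is, everything from the start of 2013 up to, but not including,
# the start of 2015: the end period '2014' is widened by its year
# precision via `open_right_endpoint`.
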
# Sorting.

class Sort(object):
    """An abstract class representing a sort operation for a query into
    the item database.
    """

    def order_clause(self):
        """Generates a SQL fragment to be used in an ORDER BY clause, or
        None if no fragment is used (i.e., this is a slow sort).
        """
        return None

    def sort(self, items):
        """Sort the list of objects and return a list.
        """
        return sorted(items)

    def is_slow(self):
        """Indicate whether this query is *slow*, meaning that it cannot
        be executed in SQL and must be executed in Python.
        """
        return False


class MultipleSort(Sort):
    """Sort that encapsulates multiple sub-sorts.
    """

    def __init__(self, sorts=None):
        self.sorts = sorts or []

    def add_sort(self, sort):
        self.sorts.append(sort)

    def _sql_sorts(self):
        """Return the list of sub-sorts for which we can be (at least
        partially) fast.

        A contiguous suffix of fast (SQL-capable) sub-sorts are
        executable in SQL. The remaining, even if they are fast
        independently, must be executed slowly.
        """
        sql_sorts = []
        for sort in reversed(self.sorts):
            if sort.order_clause() is not None:
                sql_sorts.append(sort)
            else:
                break
        sql_sorts.reverse()
        return sql_sorts

    def order_clause(self):
        order_strings = []
        for sort in self._sql_sorts():
            order = sort.order_clause()
            order_strings.append(order)

        return ", ".join(order_strings)

    def is_slow(self):
        for sort in self.sorts:
            if sort.is_slow():
                return True
        return False

    def sort(self, items):
        slow_sorts = []
        switch_slow = False
        for sort in reversed(self.sorts):
            if switch_slow:
                slow_sorts.append(sort)
            elif sort.order_clause() is None:
                switch_slow = True
                slow_sorts.append(sort)
            else:
                pass

        for sort in slow_sorts:
            items = sort.sort(items)
        return items

    def __repr__(self):
        return u'MultipleSort({0})'.format(repr(self.sorts))


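# Because Python's sort is stable, applying the slow criteria as
# successive passes (least-significant first, via the reversed
# iteration) preserves the ordering established by the SQL suffix.
# Sketch (editorial example; field names are made up):
#
#     s = MultipleSort([SlowFieldSort('play_count'),
#                       FixedFieldSort('year')])
#     s.order_clause()   # "year ASC": the fast suffix runs in SQL
#     s.is_slow()        # True: 'play_count' must be sorted in Python
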
class FieldSort(Sort):
    """An abstract sort criterion that orders by a specific field (of
    any kind).
    """
    def __init__(self, field, ascending=True):
        self.field = field
        self.ascending = ascending

    def sort(self, objs):
        # TODO: Conversion and null-detection here. In Python 3,
        # comparisons with None fail. We should also support flexible
        # attributes with different types without falling over.
        return sorted(objs, key=attrgetter(self.field),
                      reverse=not self.ascending)

    def __repr__(self):
        return u'<{0}: {1}{2}>'.format(
            type(self).__name__,
            self.field,
            '+' if self.ascending else '-',
        )


class FixedFieldSort(FieldSort):
    """Sort object to sort on a fixed field.
    """
    def order_clause(self):
        order = "ASC" if self.ascending else "DESC"
        return "{0} {1}".format(self.field, order)


class SlowFieldSort(FieldSort):
    """A sort criterion by some model field other than a fixed field:
    i.e., a computed or flexible field.
    """
    def is_slow(self):
        return True


class NullSort(Sort):
    """No sorting. Leave results unsorted."""
    def sort(self, items):
        return items
180
lib/beets/dbcore/queryparse.py
Normal file
@@ -0,0 +1,180 @@
# This file is part of beets.
# Copyright 2014, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.

"""Parsing of strings into DBCore queries.
"""
import re
import itertools
from . import query


PARSE_QUERY_PART_REGEX = re.compile(
    # Non-capturing optional segment for the keyword.
    r'(?:'
    r'(\S+?)'    # The field key.
    r'(?<!\\):'  # Unescaped :
    r')?'

    r'(.*)',     # The term itself.

    re.I  # Case-insensitive.
)


def parse_query_part(part, query_classes={}, prefixes={},
                     default_class=query.SubstringQuery):
    """Take a query in the form of a key/value pair separated by a
    colon and return a tuple of `(key, value, cls)`. `key` may be None,
    indicating that any field may be matched. `cls` is a subclass of
    `FieldQuery`.

    The optional `query_classes` parameter maps field names to default
    query types; `default_class` is the fallback. `prefixes` is a map
    from query prefix markers to query types. Prefix-indicated queries
    take precedence over type-based queries.

    To determine the query class, two factors are used: prefixes and
    field types. For example, the colon prefix denotes a regular
    expression query and a type map might provide a special kind of
    query for numeric values. If neither a prefix nor a specific query
    class is available, `default_class` is used.

    For instance,
    'stapler' -> (None, 'stapler', SubstringQuery)
    'color:red' -> ('color', 'red', SubstringQuery)
    ':^Quiet' -> (None, '^Quiet', RegexpQuery)
    'color::b..e' -> ('color', 'b..e', RegexpQuery)

    Prefixes may be "escaped" with a backslash to disable the keying
    behavior.
    """
    part = part.strip()
    match = PARSE_QUERY_PART_REGEX.match(part)

    assert match  # Regex should always match.
    key = match.group(1)
    term = match.group(2).replace('\:', ':')

    # Match the search term against the list of prefixes.
    for pre, query_class in prefixes.items():
        if term.startswith(pre):
            return key, term[len(pre):], query_class

    # No matching prefix: use type-based or fallback/default query.
    query_class = query_classes.get(key, default_class)
    return key, term, query_class


def construct_query_part(model_cls, prefixes, query_part):
    """Create a query from a single query component, `query_part`, for
    querying instances of `model_cls`. Return a `Query` instance.
    """
    # Shortcut for empty query parts.
    if not query_part:
        return query.TrueQuery()

    # Get the query classes for each possible field.
    query_classes = {}
    for k, t in itertools.chain(model_cls._fields.items(),
                                model_cls._types.items()):
        query_classes[k] = t.query

    # Parse the string.
    key, pattern, query_class = \
        parse_query_part(query_part, query_classes, prefixes)

    # No key specified.
    if key is None:
        if issubclass(query_class, query.FieldQuery):
            # The query type matches a specific field, but none was
            # specified. So we use a version of the query that matches
            # any field.
            return query.AnyFieldQuery(pattern, model_cls._search_fields,
                                       query_class)
        else:
            # Other query type.
            return query_class(pattern)

    key = key.lower()
    return query_class(key, pattern, key in model_cls._fields)


def query_from_strings(query_cls, model_cls, prefixes, query_parts):
    """Creates a collection query of type `query_cls` from a list of
    strings in the format used by parse_query_part. `model_cls`
    determines how queries are constructed from strings.
    """
    subqueries = []
    for part in query_parts:
        subqueries.append(construct_query_part(model_cls, prefixes, part))
    if not subqueries:  # No terms in query.
        subqueries = [query.TrueQuery()]
    return query_cls(subqueries)


def construct_sort_part(model_cls, part):
    """Create a `Sort` from a single string criterion.

    `model_cls` is the `Model` being queried. `part` is a single string
    ending in ``+`` or ``-`` indicating the sort.
    """
    assert part, "part must be a field name and + or -"
    field = part[:-1]
    assert field, "field is missing"
    direction = part[-1]
    assert direction in ('+', '-'), "part must end with + or -"
    is_ascending = direction == '+'

    if field in model_cls._sorts:
        sort = model_cls._sorts[field](model_cls, is_ascending)
    elif field in model_cls._fields:
        sort = query.FixedFieldSort(field, is_ascending)
    else:
        # Flexible or computed.
        sort = query.SlowFieldSort(field, is_ascending)
    return sort


def sort_from_strings(model_cls, sort_parts):
    """Create a `Sort` from a list of sort criteria (strings).
    """
    if not sort_parts:
        return query.NullSort()
    else:
        sort = query.MultipleSort()
        for part in sort_parts:
            sort.add_sort(construct_sort_part(model_cls, part))
        return sort


def parse_sorted_query(model_cls, parts, prefixes={},
                       query_cls=query.AndQuery):
    """Given a list of strings, create the `Query` and `Sort` that they
    represent.
    """
    # Separate query tokens from sort tokens.
    query_parts = []
    sort_parts = []
    for part in parts:
        if part.endswith((u'+', u'-')) and u':' not in part:
            sort_parts.append(part)
        else:
            query_parts.append(part)

    # Parse each.
    q = query_from_strings(
        query_cls, model_cls, prefixes, query_parts
    )
    s = sort_from_strings(model_cls, sort_parts)
    return q, s
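
# End-to-end, a command-line query like ``year:2001.. title+`` would be
# split and parsed roughly like this (editorial example; `Item` stands
# in for a registered model class):
#
#     q, s = parse_sorted_query(Item, [u'year:2001..', u'title+'])
#     # q -> AndQuery([NumericQuery('year', '2001..')])
#     # s -> MultipleSort([FixedFieldSort('title', True)])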
208
lib/beets/dbcore/types.py
Normal file
@@ -0,0 +1,208 @@
# This file is part of beets.
# Copyright 2014, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.

"""Representation of type information for DBCore model fields.
"""
from . import query
from beets.util import str2bool


# Abstract base.

class Type(object):
    """An object encapsulating the type of a model field. Includes
    information about how to store, query, format, and parse a given
    field.
    """

    sql = u'TEXT'
    """The SQLite column type for the value.
    """

    query = query.SubstringQuery
    """The `Query` subclass to be used when querying the field.
    """

    model_type = unicode
    """The Python type that is used to represent the value in the model.

    The model is guaranteed to return a value of this type if the field
    is accessed. To this end, the constructor is used by the `normalize`
    and `from_sql` methods and the `default` property.
    """

    @property
    def null(self):
        """The value to be exposed when the underlying value is None.
        """
        return self.model_type()

    def format(self, value):
        """Given a value of this type, produce a Unicode string
        representing the value. This is used in template evaluation.
        """
        if value is None:
            value = self.null
        # `self.null` might be `None`.
        if value is None:
            value = u''
        if isinstance(value, bytes):
            value = value.decode('utf8', 'ignore')

        return unicode(value)

    def parse(self, string):
        """Parse a (possibly human-written) string and return the
        indicated value of this type.
        """
        try:
            return self.model_type(string)
        except ValueError:
            return self.null

    def normalize(self, value):
        """Given a value that will be assigned into a field of this
        type, normalize the value to have the appropriate type. This
        base implementation only reinterprets `None`.
        """
        if value is None:
            return self.null
        else:
            # TODO: This should eventually be replaced by
            # `self.model_type(value)`.
            return value

    def from_sql(self, sql_value):
        """Receive the value stored in the SQL backend and return the
        value to be stored in the model.

        For fixed fields the type of `value` is determined by the column
        type affinity given in the `sql` property and the SQL to Python
        mapping of the database adapter. For more information see:
        http://www.sqlite.org/datatype3.html
        https://docs.python.org/2/library/sqlite3.html#sqlite-and-python-types

        Flexible fields have the type affinity `TEXT`. This means the
        `sql_value` is either a `buffer` or a `unicode` object and the
        method must handle both.
        """
        if isinstance(sql_value, buffer):
            sql_value = bytes(sql_value).decode('utf8', 'ignore')
        if isinstance(sql_value, unicode):
            return self.parse(sql_value)
        else:
            return self.normalize(sql_value)

    def to_sql(self, model_value):
        """Convert a value as stored in the model object to a value used
        by the database adapter.
        """
        return model_value


# Reusable types.

class Default(Type):
    null = None


class Integer(Type):
    """A basic integer type.
    """
    sql = u'INTEGER'
    query = query.NumericQuery
    model_type = int


class PaddedInt(Integer):
    """An integer field that is formatted with a given number of digits,
    padded with zeroes.
    """
    def __init__(self, digits):
        self.digits = digits

    def format(self, value):
        return u'{0:0{1}d}'.format(value or 0, self.digits)


class ScaledInt(Integer):
    """An integer whose formatting operation scales the number by a
    constant and adds a suffix. Good for units with large magnitudes.
    """
    def __init__(self, unit, suffix=u''):
        self.unit = unit
        self.suffix = suffix

    def format(self, value):
        return u'{0}{1}'.format((value or 0) // self.unit, self.suffix)


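# How the two formatting variants behave (editorial example):
#
#     >>> PaddedInt(2).format(7)
#     u'07'
#     >>> ScaledInt(1024, u'kB').format(3500)
#     u'3kB'
#
# Both treat a missing value (None) as zero before formatting.
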
class Id(Integer):
    """An integer used as the row id or a foreign key in a SQLite table.
    This type is nullable: None values are not translated to zero.
    """
    null = None

    def __init__(self, primary=True):
        if primary:
            self.sql = u'INTEGER PRIMARY KEY'


class Float(Type):
    """A basic floating-point type.
    """
    sql = u'REAL'
    query = query.NumericQuery
    model_type = float

    def format(self, value):
        return u'{0:.1f}'.format(value or 0.0)


class NullFloat(Float):
    """Same as `Float`, but does not normalize `None` to `0.0`.
    """
    null = None


class String(Type):
    """A Unicode string type.
    """
    sql = u'TEXT'
    query = query.SubstringQuery


class Boolean(Type):
    """A boolean type.
    """
    sql = u'INTEGER'
    query = query.BooleanQuery
    model_type = bool

    def format(self, value):
        return unicode(bool(value))

    def parse(self, string):
        return str2bool(string)


# Shared instances of common types.
DEFAULT = Default()
INTEGER = Integer()
PRIMARY_ID = Id(True)
FOREIGN_ID = Id(False)
FLOAT = Float()
NULL_FLOAT = NullFloat()
STRING = String()
BOOLEAN = Boolean()
1428
lib/beets/importer.py
Normal file
File diff suppressed because it is too large
1299
lib/beets/library.py
Normal file
File diff suppressed because it is too large
1929
lib/beets/mediafile.py
Normal file
File diff suppressed because it is too large
435
lib/beets/plugins.py
Executable file
@@ -0,0 +1,435 @@
# This file is part of beets.
# Copyright 2013, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.

"""Support for beets plugins."""

import logging
import traceback
import inspect
import re
from collections import defaultdict


import beets
from beets import mediafile

PLUGIN_NAMESPACE = 'beetsplug'

# Plugins using the Last.fm API can share the same API key.
LASTFM_KEY = '2dc3914abf35f0d9c92d97d8f8e42b43'

# Global logger.
log = logging.getLogger('beets')


class PluginConflictException(Exception):
    """Indicates that the services provided by one plugin conflict with
    those of another.

    For example, two plugins may define different types for flexible fields.
    """


# Managing the plugins themselves.

class BeetsPlugin(object):
    """The base class for all beets plugins. Plugins provide
    functionality by defining a subclass of BeetsPlugin and overriding
    the abstract methods defined here.
    """
    def __init__(self, name=None):
        """Perform one-time plugin setup.
        """
        self.import_stages = []
        self.name = name or self.__module__.split('.')[-1]
        self.config = beets.config[self.name]
        if not self.template_funcs:
            self.template_funcs = {}
        if not self.template_fields:
            self.template_fields = {}
        if not self.album_template_fields:
            self.album_template_fields = {}

    def commands(self):
        """Should return a list of beets.ui.Subcommand objects for
        commands that should be added to beets' CLI.
        """
        return ()

    def queries(self):
        """Should return a dict mapping prefixes to Query subclasses.
        """
        return {}

    def track_distance(self, item, info):
        """Should return a Distance object to be added to the
        distance for every track comparison.
        """
        return beets.autotag.hooks.Distance()

    def album_distance(self, items, album_info, mapping):
        """Should return a Distance object to be added to the
        distance for every album-level comparison.
        """
        return beets.autotag.hooks.Distance()

    def candidates(self, items, artist, album, va_likely):
        """Should return a sequence of AlbumInfo objects that match the
        album whose items are provided.
        """
        return ()

    def item_candidates(self, item, artist, title):
        """Should return a sequence of TrackInfo objects that match the
        item provided.
        """
        return ()

    def album_for_id(self, album_id):
        """Return an AlbumInfo object or None if no matching release was
        found.
        """
        return None

    def track_for_id(self, track_id):
        """Return a TrackInfo object or None if no matching track was
        found.
        """
        return None

    def add_media_field(self, name, descriptor):
        """Add a field that is synchronized between media files and items.

        When a media field is added ``item.write()`` will set the name
        property of the item's MediaFile to ``item[name]`` and save the
        changes. Similarly ``item.read()`` will set ``item[name]`` to
        the value of the name property of the media file.

        ``descriptor`` must be an instance of ``mediafile.MediaField``.
        """
        # Defer import to prevent circular dependency.
        from beets import library
        mediafile.MediaFile.add_field(name, descriptor)
        library.Item._media_fields.add(name)

    listeners = None

    @classmethod
    def register_listener(cls, event, func):
        """Add a function as a listener for the specified event. (An
        imperative alternative to the @listen decorator.)
        """
        if cls.listeners is None:
            cls.listeners = defaultdict(list)
        cls.listeners[event].append(func)

    @classmethod
    def listen(cls, event):
        """Decorator that adds a function as an event handler for the
        specified event (as a string). The parameters passed to the
        function will vary depending on what event occurred.

        The function should respond to named parameters.
        function(**kwargs) will trap all arguments in a dictionary.
        Example:

        >>> @MyPlugin.listen("imported")
        ... def importListener(**kwargs):
        ...     pass
        """
        def helper(func):
            if cls.listeners is None:
                cls.listeners = defaultdict(list)
            cls.listeners[event].append(func)
            return func
        return helper

    template_funcs = None
    template_fields = None
    album_template_fields = None

    @classmethod
    def template_func(cls, name):
        """Decorator that registers a path template function. The
        function will be invoked as ``%name{}`` from path format
        strings.
        """
        def helper(func):
            if cls.template_funcs is None:
                cls.template_funcs = {}
            cls.template_funcs[name] = func
            return func
        return helper

    @classmethod
    def template_field(cls, name):
        """Decorator that registers a path template field computation.
        The value will be referenced as ``$name`` from path format
        strings. The function must accept a single parameter, the Item
        being formatted.
        """
        def helper(func):
            if cls.template_fields is None:
                cls.template_fields = {}
            cls.template_fields[name] = func
            return func
        return helper


_classes = set()


def load_plugins(names=()):
    """Imports the modules for a sequence of plugin names. Each name
    must be the name of a Python module under the "beetsplug" namespace
    package in sys.path; the module indicated should contain the
    BeetsPlugin subclasses desired.
    """
    for name in names:
        modname = '%s.%s' % (PLUGIN_NAMESPACE, name)
        try:
            try:
                namespace = __import__(modname, None, None)
            except ImportError as exc:
                # Again, this is hacky:
                if exc.args[0].endswith(' ' + name):
                    log.warn(u'** plugin {0} not found'.format(name))
                else:
                    raise
            else:
                for obj in getattr(namespace, name).__dict__.values():
                    if isinstance(obj, type) and issubclass(obj, BeetsPlugin) \
                            and obj != BeetsPlugin and obj not in _classes:
                        _classes.add(obj)

        except:
            log.warn(u'** error loading plugin {0}'.format(name))
            log.warn(traceback.format_exc())


_instances = {}


def find_plugins():
    """Returns a list of BeetsPlugin subclass instances from all
    currently loaded beets plugins. Loads the default plugin set
    first.
    """
    load_plugins()
    plugins = []
    for cls in _classes:
        # Only instantiate each plugin class once.
        if cls not in _instances:
            _instances[cls] = cls()
        plugins.append(_instances[cls])
    return plugins


# Communication with plugins.

def commands():
    """Returns a list of Subcommand objects from all loaded plugins.
    """
    out = []
    for plugin in find_plugins():
        out += plugin.commands()
    return out


def queries():
    """Returns a dict mapping prefix strings to Query subclasses from
    all loaded plugins.
    """
    out = {}
    for plugin in find_plugins():
        out.update(plugin.queries())
    return out


def types(model_cls):
    # Gives us `item_types` and `album_types`.
    attr_name = '{0}_types'.format(model_cls.__name__.lower())
    types = {}
    for plugin in find_plugins():
        plugin_types = getattr(plugin, attr_name, {})
        for field in plugin_types:
            if field in types and plugin_types[field] != types[field]:
                raise PluginConflictException(
                    u'Plugin {0} defines flexible field {1} '
                    'which has already been defined with '
                    'another type.'.format(plugin.name, field)
                )
        types.update(plugin_types)
    return types


def track_distance(item, info):
    """Gets the track distance calculated by all loaded plugins.
    Returns a Distance object.
    """
    from beets.autotag.hooks import Distance
    dist = Distance()
    for plugin in find_plugins():
        dist.update(plugin.track_distance(item, info))
    return dist


def album_distance(items, album_info, mapping):
    """Returns the album distance calculated by plugins."""
    from beets.autotag.hooks import Distance
    dist = Distance()
    for plugin in find_plugins():
        dist.update(plugin.album_distance(items, album_info, mapping))
    return dist


def candidates(items, artist, album, va_likely):
    """Gets MusicBrainz candidates for an album from each plugin.
    """
    out = []
    for plugin in find_plugins():
        out.extend(plugin.candidates(items, artist, album, va_likely))
    return out


def item_candidates(item, artist, title):
    """Gets MusicBrainz candidates for an item from the plugins.
    """
    out = []
    for plugin in find_plugins():
        out.extend(plugin.item_candidates(item, artist, title))
    return out


def album_for_id(album_id):
    """Get AlbumInfo objects for a given ID string.
    """
    out = []
    for plugin in find_plugins():
        res = plugin.album_for_id(album_id)
        if res:
            out.append(res)
    return out


def track_for_id(track_id):
    """Get TrackInfo objects for a given ID string.
    """
    out = []
    for plugin in find_plugins():
        res = plugin.track_for_id(track_id)
        if res:
            out.append(res)
    return out


def template_funcs():
    """Get all the template functions declared by plugins as a
    dictionary.
    """
    funcs = {}
    for plugin in find_plugins():
        if plugin.template_funcs:
            funcs.update(plugin.template_funcs)
    return funcs


def import_stages():
    """Get a list of import stage functions defined by plugins."""
    stages = []
    for plugin in find_plugins():
        if hasattr(plugin, 'import_stages'):
            stages += plugin.import_stages
    return stages


# New-style (lazy) plugin-provided fields.

def item_field_getters():
    """Get a dictionary mapping field names to unary functions that
    compute the field's value.
    """
    funcs = {}
    for plugin in find_plugins():
        if plugin.template_fields:
            funcs.update(plugin.template_fields)
    return funcs


def album_field_getters():
    """As above, for album fields.
    """
    funcs = {}
    for plugin in find_plugins():
        if plugin.album_template_fields:
            funcs.update(plugin.album_template_fields)
    return funcs


# Event dispatch.

def event_handlers():
    """Find all event handlers from plugins as a dictionary mapping
    event names to sequences of callables.
    """
    all_handlers = defaultdict(list)
    for plugin in find_plugins():
        if plugin.listeners:
            for event, handlers in plugin.listeners.items():
                all_handlers[event] += handlers
    return all_handlers


def send(event, **arguments):
    """Sends an event to all assigned event listeners. Event is the
    name of the event to send, all other named arguments go to the
    event handler(s).

    Returns a list of return values from the handlers.
    """
    log.debug(u'Sending event: {0}'.format(event))
    results = []
    for handler in event_handlers()[event]:
        # Don't break legacy plugins if we want to pass more arguments.
        argspec = inspect.getargspec(handler).args
        args = dict((k, v) for k, v in arguments.items() if k in argspec)
        results.append(handler(**args))
    return results


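# Firing an event is a one-liner from elsewhere in beets, and handlers
# only receive the keyword arguments their signatures declare
# (editorial example):
#
#     plugins.send('imported', lib=lib, item=item)
#
# A handler defined as ``def listener(item)`` would get only ``item``;
# the ``lib`` argument is filtered out by the argspec check above.
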
def feat_tokens(for_artist=True):
    """Return a regular expression that matches phrases like "featuring"
    that separate a main artist or a song title from secondary artists.
    The `for_artist` option determines whether the regex should be
    suitable for matching artist fields (the default) or title fields.
    """
    feat_words = ['ft', 'featuring', 'feat', 'feat.', 'ft.']
    if for_artist:
        feat_words += ['with', 'vs', 'and', 'con', '&']
    return r'(?<=\s)(?:{0})(?=\s)'.format(
        '|'.join(re.escape(x) for x in feat_words)
    )


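# The lookbehind/lookahead pair requires surrounding whitespace, so
# only whole words split the string (editorial example):
#
#     >>> re.split(feat_tokens(), u'Main Artist feat. Guest')
#     [u'Main Artist ', u' Guest']
#
# Splitting keeps the spaces because they sit outside the match itself.
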
def sanitize_choices(choices, choices_all):
    """Clean up a stringlist configuration attribute: keep only choices
    elements present in choices_all, remove duplicate elements, expand '*'
    wildcard while keeping original stringlist order.
    """
    seen = set()
    others = [x for x in choices_all if x not in choices]
    res = []
    for s in choices:
        if s in list(choices_all) + ['*']:
            if not (s in seen or seen.add(s)):
                res.extend(list(others) if s == '*' else [s])
    return res
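
# A quick illustration of the wildcard expansion (editorial example):
#
#     >>> sanitize_choices(['a', '*'], ['a', 'b', 'c'])
#     ['a', 'b', 'c']
#
# '*' expands to the choices not already named, and order is preserved.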
970
lib/beets/ui/__init__.py
Normal file
@@ -0,0 +1,970 @@
# This file is part of beets.
|
||||
# Copyright 2014, Adrian Sampson.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files (the
|
||||
# "Software"), to deal in the Software without restriction, including
|
||||
# without limitation the rights to use, copy, modify, merge, publish,
|
||||
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||
# permit persons to whom the Software is furnished to do so, subject to
|
||||
# the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
|
||||
"""This module contains all of the core logic for beets' command-line
|
||||
interface. To invoke the CLI, just call beets.ui.main(). The actual
|
||||
CLI commands are implemented in the ui.commands module.
|
||||
"""
|
||||
from __future__ import print_function
|
||||
|
||||
import locale
|
||||
import optparse
|
||||
import textwrap
|
||||
import sys
|
||||
from difflib import SequenceMatcher
|
||||
import logging
|
||||
import sqlite3
|
||||
import errno
|
||||
import re
|
||||
import struct
|
||||
import traceback
|
||||
import os.path
|
||||
|
||||
from beets import library
|
||||
from beets import plugins
|
||||
from beets import util
|
||||
from beets.util.functemplate import Template
|
||||
from beets import config
|
||||
from beets.util import confit
|
||||
from beets.autotag import mb
|
||||
|
||||
# On Windows platforms, use colorama to support "ANSI" terminal colors.
|
||||
if sys.platform == 'win32':
|
||||
try:
|
||||
import colorama
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
colorama.init()
|
||||
|
||||
|
||||
log = logging.getLogger('beets')
|
||||
if not log.handlers:
|
||||
log.addHandler(logging.StreamHandler())
|
||||
log.propagate = False # Don't propagate to root handler.
|
||||
|
||||
|
||||
PF_KEY_QUERIES = {
|
||||
'comp': 'comp:true',
|
||||
'singleton': 'singleton:true',
|
||||
}
|
||||
|
||||
|
||||
class UserError(Exception):
|
||||
"""UI exception. Commands should throw this in order to display
|
||||
nonrecoverable errors to the user.
|
||||
"""
|
||||
|
||||
|
||||
# Utilities.
|
||||
|
||||
def _encoding():
|
||||
"""Tries to guess the encoding used by the terminal."""
|
||||
# Configured override?
|
||||
encoding = config['terminal_encoding'].get()
|
||||
if encoding:
|
||||
return encoding
|
||||
|
||||
# Determine from locale settings.
|
||||
try:
|
||||
return locale.getdefaultlocale()[1] or 'utf8'
|
||||
except ValueError:
|
||||
# Invalid locale environment variable setting. To avoid
|
||||
# failing entirely for no good reason, assume UTF-8.
|
||||
return 'utf8'
|
||||
|
||||
|
||||
def decargs(arglist):
|
||||
"""Given a list of command-line argument bytestrings, attempts to
|
||||
decode them to Unicode strings.
|
||||
"""
|
||||
return [s.decode(_encoding()) for s in arglist]
|
||||
|
||||
|
||||
def print_(*strings):
|
||||
"""Like print, but rather than raising an error when a character
|
||||
is not in the terminal's encoding's character set, just silently
|
||||
replaces it.
|
||||
"""
|
||||
if strings:
|
||||
if isinstance(strings[0], unicode):
|
||||
txt = u' '.join(strings)
|
||||
else:
|
||||
txt = ' '.join(strings)
|
||||
else:
|
||||
txt = u''
|
||||
if isinstance(txt, unicode):
|
||||
txt = txt.encode(_encoding(), 'replace')
|
||||
print(txt)


def input_(prompt=None):
    """Like `raw_input`, but decodes the result to a Unicode string.
    Raises a UserError if stdin is not available. The prompt is sent to
    stdout rather than stderr. A space is printed between the prompt
    and the input cursor.
    """
    # raw_input incorrectly sends prompts to stderr, not stdout, so we
    # use print() explicitly to display prompts.
    # http://bugs.python.org/issue1927
    if prompt:
        if isinstance(prompt, unicode):
            prompt = prompt.encode(_encoding(), 'replace')
        print(prompt, end=' ')

    try:
        resp = raw_input()
    except EOFError:
        raise UserError('stdin stream ended while input required')

    return resp.decode(sys.stdin.encoding or 'utf8', 'ignore')


def input_options(options, require=False, prompt=None, fallback_prompt=None,
                  numrange=None, default=None, max_width=72):
    """Prompts a user for input. The sequence of `options` defines the
    choices the user has. A single-letter shortcut is inferred for each
    option; the user's choice is returned as that single, lower-case
    letter. The options should be provided as lower-case strings unless
    a particular shortcut is desired; in that case, only that letter
    should be capitalized.

    By default, the first option is the default. `default` can be provided to
    override this. If `require` is provided, then there is no default. The
    prompt and fallback prompt are also inferred but can be overridden.

    If numrange is provided, it is a pair of `(low, high)` (both ints)
    indicating that, in addition to `options`, the user may enter an
    integer in that inclusive range.

    `max_width` specifies the maximum number of columns in the
    automatically generated prompt string.
    """
    # Assign single letters to each option. Also capitalize the options
    # to indicate the letter.
    letters = {}
    display_letters = []
    capitalized = []
    first = True
    for option in options:
        # Is a letter already capitalized?
        for letter in option:
            if letter.isalpha() and letter.upper() == letter:
                found_letter = letter
                break
        else:
            # Infer a letter.
            for letter in option:
                if not letter.isalpha():
                    continue  # Don't use punctuation.
                if letter not in letters:
                    found_letter = letter
                    break
            else:
                raise ValueError('no unambiguous lettering found')

        letters[found_letter.lower()] = option
        index = option.index(found_letter)

        # Mark the option's shortcut letter for display.
        if not require and (
            (default is None and not numrange and first) or
            (isinstance(default, basestring) and
                found_letter.lower() == default.lower())):
            # The first option is the default; mark it.
            show_letter = '[%s]' % found_letter.upper()
            is_default = True
        else:
            show_letter = found_letter.upper()
            is_default = False

        # Colorize the letter shortcut.
        show_letter = colorize('turquoise' if is_default else 'blue',
                               show_letter)

        # Insert the highlighted letter back into the word.
        capitalized.append(
            option[:index] + show_letter + option[index + 1:]
        )
        display_letters.append(found_letter.upper())

        first = False

    # The default is just the first option if unspecified.
    if require:
        default = None
    elif default is None:
        if numrange:
            default = numrange[0]
        else:
            default = display_letters[0].lower()

    # Make a prompt if one is not provided.
    if not prompt:
        prompt_parts = []
        prompt_part_lengths = []
        if numrange:
            if isinstance(default, int):
                default_name = str(default)
                default_name = colorize('turquoise', default_name)
                tmpl = '# selection (default %s)'
                prompt_parts.append(tmpl % default_name)
                prompt_part_lengths.append(len(tmpl % str(default)))
            else:
                prompt_parts.append('# selection')
                prompt_part_lengths.append(len(prompt_parts[-1]))
        prompt_parts += capitalized
        prompt_part_lengths += [len(s) for s in options]

        # Wrap the query text.
        prompt = ''
        line_length = 0
        for i, (part, length) in enumerate(zip(prompt_parts,
                                               prompt_part_lengths)):
            # Add punctuation.
            if i == len(prompt_parts) - 1:
                part += '?'
            else:
                part += ','
            length += 1

            # Choose either the current line or the beginning of the next.
            if line_length + length + 1 > max_width:
                prompt += '\n'
                line_length = 0

            if line_length != 0:
                # Not the beginning of the line; need a space.
                part = ' ' + part
                length += 1

            prompt += part
            line_length += length

    # Make a fallback prompt too. This is displayed if the user enters
    # something that is not recognized.
    if not fallback_prompt:
        fallback_prompt = 'Enter one of '
        if numrange:
            fallback_prompt += '%i-%i, ' % numrange
        fallback_prompt += ', '.join(display_letters) + ':'

    resp = input_(prompt)
    while True:
        resp = resp.strip().lower()

        # Try default option.
        if default is not None and not resp:
            resp = default

        # Try an integer input if available.
        if numrange:
            try:
                resp = int(resp)
            except ValueError:
                pass
            else:
                low, high = numrange
                if low <= resp <= high:
                    return resp
                else:
                    resp = None

        # Try a normal letter input.
        if resp:
            resp = resp[0]
            if resp in letters:
                return resp

        # Prompt for new input.
        resp = input_(fallback_prompt)
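
# Illustrative example (editor's addition, not in the original diff):
# input_options(('Apply', 'cancel')) prompts with "[A]pply, Cancel?"
# (the bracketed letter marks the default) and returns 'a' or 'c';
# pressing Enter accepts the default and returns 'a'.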


def input_yn(prompt, require=False):
    """Prompts the user for a "yes" or "no" response. The default is
    "yes" unless `require` is `True`, in which case there is no default.
    """
    sel = input_options(
        ('y', 'n'), require, prompt, 'Enter Y or N:'
    )
    return sel == 'y'


def human_bytes(size):
    """Formats size, a number of bytes, in a human-readable way."""
    suffices = ['B', 'KB', 'MB', 'GB', 'TB', 'PB', 'EB', 'ZB', 'YB', 'HB']
    for suffix in suffices:
        if size < 1024:
            return "%3.1f %s" % (size, suffix)
        size /= 1024.0
    return "big"


def human_seconds(interval):
    """Formats interval, a number of seconds, as a human-readable time
    interval using English words.
    """
    units = [
        (1, 'second'),
        (60, 'minute'),
        (60, 'hour'),
        (24, 'day'),
        (7, 'week'),
        (52, 'year'),
        (10, 'decade'),
    ]
    for i in range(len(units) - 1):
        increment, suffix = units[i]
        next_increment, _ = units[i + 1]
        interval /= float(increment)
        if interval < next_increment:
            break
    else:
        # Last unit.
        increment, suffix = units[-1]
        interval /= float(increment)

    return "%3.1f %ss" % (interval, suffix)


def human_seconds_short(interval):
    """Formats a number of seconds as a short human-readable M:SS
    string.
    """
    interval = int(interval)
    return u'%i:%02i' % (interval // 60, interval % 60)


# ANSI terminal colorization code heavily inspired by pygments:
# http://dev.pocoo.org/hg/pygments-main/file/b2deea5b5030/pygments/console.py
# (pygments is by Tim Hatch, Armin Ronacher, et al.)
COLOR_ESCAPE = "\x1b["
DARK_COLORS = ["black", "darkred", "darkgreen", "brown", "darkblue",
               "purple", "teal", "lightgray"]
LIGHT_COLORS = ["darkgray", "red", "green", "yellow", "blue",
                "fuchsia", "turquoise", "white"]
RESET_COLOR = COLOR_ESCAPE + "39;49;00m"


def _colorize(color, text):
    """Returns a string that prints the given text in the given color
    in a terminal that is ANSI color-aware. The color must be something
    in DARK_COLORS or LIGHT_COLORS.
    """
    if color in DARK_COLORS:
        escape = COLOR_ESCAPE + "%im" % (DARK_COLORS.index(color) + 30)
    elif color in LIGHT_COLORS:
        escape = COLOR_ESCAPE + "%i;01m" % (LIGHT_COLORS.index(color) + 30)
    else:
        raise ValueError('no such color %s' % color)
    return escape + text + RESET_COLOR


def colorize(color, text):
    """Colorize text if colored output is enabled. (Like _colorize but
    conditional.)
    """
    if config['color']:
        return _colorize(color, text)
    else:
        return text
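
# Illustrative example (editor's addition, not in the original diff):
# 'red' is index 1 in LIGHT_COLORS, so the escape code is 31 with the
# bold attribute:
#
#     >>> _colorize('red', 'x')
#     '\x1b[31;01mx\x1b[39;49;00m'
#
# colorize() applies the escape codes only when config['color'] is
# enabled; otherwise it returns the text unchanged.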


def _colordiff(a, b, highlight='red', minor_highlight='lightgray'):
    """Given two values, return the same pair of strings except with
    their differences highlighted in the specified color. Strings are
    highlighted intelligently to show differences; other values are
    stringified and highlighted in their entirety.
    """
    if not isinstance(a, basestring) or not isinstance(b, basestring):
        # Non-strings: use ordinary equality.
        a = unicode(a)
        b = unicode(b)
        if a == b:
            return a, b
        else:
            return colorize(highlight, a), colorize(highlight, b)

    if isinstance(a, bytes) or isinstance(b, bytes):
        # A path field.
        a = util.displayable_path(a)
        b = util.displayable_path(b)

    a_out = []
    b_out = []

    matcher = SequenceMatcher(lambda x: False, a, b)
    for op, a_start, a_end, b_start, b_end in matcher.get_opcodes():
        if op == 'equal':
            # In both strings.
            a_out.append(a[a_start:a_end])
            b_out.append(b[b_start:b_end])
        elif op == 'insert':
            # Right only.
            b_out.append(colorize(highlight, b[b_start:b_end]))
        elif op == 'delete':
            # Left only.
            a_out.append(colorize(highlight, a[a_start:a_end]))
        elif op == 'replace':
            # Right and left differ. Colorise with second highlight if
            # it's just a case change.
            if a[a_start:a_end].lower() != b[b_start:b_end].lower():
                color = highlight
            else:
                color = minor_highlight
            a_out.append(colorize(color, a[a_start:a_end]))
            b_out.append(colorize(color, b[b_start:b_end]))
        else:
            assert(False)

    return u''.join(a_out), u''.join(b_out)


def colordiff(a, b, highlight='red'):
    """Colorize differences between two values if color is enabled.
    (Like _colordiff but conditional.)
    """
    if config['color']:
        return _colordiff(a, b, highlight)
    else:
        return unicode(a), unicode(b)


def get_path_formats(subview=None):
    """Get the configuration's path formats as a list of query/template
    pairs.
    """
    path_formats = []
    subview = subview or config['paths']
    for query, view in subview.items():
        query = PF_KEY_QUERIES.get(query, query)  # Expand common queries.
        path_formats.append((query, Template(view.get(unicode))))
    return path_formats


def get_replacements():
    """Confit validation function that reads regex/string pairs.
    """
    replacements = []
    for pattern, repl in config['replace'].get(dict).items():
        repl = repl or ''
        try:
            replacements.append((re.compile(pattern), repl))
        except re.error:
            raise UserError(
                u'malformed regular expression in replace: {0}'.format(
                    pattern
                )
            )
    return replacements


def _pick_format(album, fmt=None):
    """Pick a format string for printing Album or Item objects,
    falling back to config options and defaults.
    """
    if fmt:
        return fmt
    if album:
        return config['list_format_album'].get(unicode)
    else:
        return config['list_format_item'].get(unicode)


def print_obj(obj, lib, fmt=None):
    """Print an Album or Item object. If `fmt` is specified, use that
    format string. Otherwise, use the configured template.
    """
    album = isinstance(obj, library.Album)
    fmt = _pick_format(album, fmt)
    if isinstance(fmt, Template):
        template = fmt
    else:
        template = Template(fmt)
    print_(obj.evaluate_template(template))


def term_width():
    """Get the width (columns) of the terminal."""
    fallback = config['ui']['terminal_width'].get(int)

    # The fcntl and termios modules are not available on non-Unix
    # platforms, so we fall back to a constant.
    try:
        import fcntl
        import termios
    except ImportError:
        return fallback

    try:
        buf = fcntl.ioctl(0, termios.TIOCGWINSZ, ' ' * 4)
    except IOError:
        return fallback
    try:
        height, width = struct.unpack('hh', buf)
    except struct.error:
        return fallback
    return width


FLOAT_EPSILON = 0.01


def _field_diff(field, old, new):
    """Given two Model objects, format their values for `field` and
    highlight changes among them. Return a human-readable string. If the
    value has not changed, return None instead.
    """
    oldval = old.get(field)
    newval = new.get(field)

    # If no change, abort.
    if isinstance(oldval, float) and isinstance(newval, float) and \
            abs(oldval - newval) < FLOAT_EPSILON:
        return None
    elif oldval == newval:
        return None

    # Get formatted values for output.
    oldstr = old.formatted().get(field, u'')
    newstr = new.formatted().get(field, u'')

    # For strings, highlight changes. For others, colorize the whole
    # thing.
    if isinstance(oldval, basestring):
        oldstr, newstr = colordiff(oldval, newstr)
    else:
        oldstr, newstr = colorize('red', oldstr), colorize('red', newstr)

    return u'{0} -> {1}'.format(oldstr, newstr)


def show_model_changes(new, old=None, fields=None, always=False):
    """Given a Model object, print a list of changes from its pristine
    version stored in the database. Return a boolean indicating whether
    any changes were found.

    `old` may be the "original" object to avoid using the pristine
    version from the database. `fields` may be a list of fields to
    restrict the detection to. `always` indicates whether the object is
    always identified, regardless of whether any changes are present.
    """
    old = old or new._db._get(type(new), new.id)

    # Build up lines showing changed fields.
    changes = []
    for field in old:
        # Subset of the fields. Never show mtime.
        if field == 'mtime' or (fields and field not in fields):
            continue

        # Detect and show difference for this field.
        line = _field_diff(field, old, new)
        if line:
            changes.append(u'  {0}: {1}'.format(field, line))

    # New fields.
    for field in set(new) - set(old):
        if fields and field not in fields:
            continue

        changes.append(u'  {0}: {1}'.format(
            field,
            colorize('red', new.formatted()[field])
        ))

    # Print changes.
    if changes or always:
        print_obj(old, old._db)
    if changes:
        print_(u'\n'.join(changes))

    return bool(changes)


# Subcommand parsing infrastructure.
#
# This is a fairly generic subcommand parser for optparse. It is
# maintained externally here:
# http://gist.github.com/462717
# There you will also find a better description of the code and a more
# succinct example program.

class Subcommand(object):
    """A subcommand of a root command-line application that may be
    invoked by a SubcommandOptionParser.
    """
    def __init__(self, name, parser=None, help='', aliases=(), hide=False):
        """Creates a new subcommand. name is the primary way to invoke
        the subcommand; aliases are alternate names. parser is an
        OptionParser responsible for parsing the subcommand's options.
        help is a short description of the command. If no parser is
        given, it defaults to a new, empty OptionParser.
        """
        self.name = name
        self.parser = parser or optparse.OptionParser()
        self.aliases = aliases
        self.help = help
        self.hide = hide
        self._root_parser = None

    def print_help(self):
        self.parser.print_help()

    def parse_args(self, args):
        return self.parser.parse_args(args)

    @property
    def root_parser(self):
        return self._root_parser

    @root_parser.setter
    def root_parser(self, root_parser):
        self._root_parser = root_parser
        self.parser.prog = '{0} {1}'.format(root_parser.get_prog_name(),
                                            self.name)


class SubcommandsOptionParser(optparse.OptionParser):
    """A variant of OptionParser that parses subcommands and their
    arguments.
    """

    def __init__(self, *args, **kwargs):
        """Create a new subcommand-aware option parser. All of the
        options to OptionParser.__init__ are supported in addition
        to subcommands, a sequence of Subcommand objects.
        """
        # A more helpful default usage.
        if 'usage' not in kwargs:
            kwargs['usage'] = """
                %prog COMMAND [ARGS...]
                %prog help COMMAND"""
        kwargs['add_help_option'] = False

        # Super constructor.
        optparse.OptionParser.__init__(self, *args, **kwargs)

        # Our root parser needs to stop on the first unrecognized argument.
        self.disable_interspersed_args()

        self.subcommands = []

    def add_subcommand(self, *cmds):
        """Adds a Subcommand object to the parser's list of commands.
        """
        for cmd in cmds:
            cmd.root_parser = self
            self.subcommands.append(cmd)

    # Add the list of subcommands to the help message.
    def format_help(self, formatter=None):
        # Get the original help message, to which we will append.
        out = optparse.OptionParser.format_help(self, formatter)
        if formatter is None:
            formatter = self.formatter

        # Subcommands header.
        result = ["\n"]
        result.append(formatter.format_heading('Commands'))
        formatter.indent()

        # Generate the display names (including aliases).
        # Also determine the help position.
        disp_names = []
        help_position = 0
        subcommands = [c for c in self.subcommands if not c.hide]
        subcommands.sort(key=lambda c: c.name)
        for subcommand in subcommands:
            name = subcommand.name
            if subcommand.aliases:
                name += ' (%s)' % ', '.join(subcommand.aliases)
            disp_names.append(name)

            # Set the help position based on the max width.
            proposed_help_position = len(name) + formatter.current_indent + 2
            if proposed_help_position <= formatter.max_help_position:
                help_position = max(help_position, proposed_help_position)

        # Add each subcommand to the output.
        for subcommand, name in zip(subcommands, disp_names):
            # Lifted directly from optparse.py.
            name_width = help_position - formatter.current_indent - 2
            if len(name) > name_width:
                name = "%*s%s\n" % (formatter.current_indent, "", name)
                indent_first = help_position
            else:
                name = "%*s%-*s  " % (formatter.current_indent, "",
                                      name_width, name)
                indent_first = 0
            result.append(name)
            help_width = formatter.width - help_position
            help_lines = textwrap.wrap(subcommand.help, help_width)
            result.append("%*s%s\n" % (indent_first, "", help_lines[0]))
            result.extend(["%*s%s\n" % (help_position, "", line)
                           for line in help_lines[1:]])
        formatter.dedent()

        # Concatenate the original help message with the subcommand
        # list.
        return out + "".join(result)

    def _subcommand_for_name(self, name):
        """Return the subcommand in self.subcommands matching the
        given name. The name may either be the name of a subcommand or
        an alias. If no subcommand matches, returns None.
        """
        for subcommand in self.subcommands:
            if name == subcommand.name or \
               name in subcommand.aliases:
                return subcommand
        return None

    def parse_global_options(self, args):
        """Parse options up to the subcommand argument. Returns a tuple
        of the options object and the remaining arguments.
        """
        options, subargs = self.parse_args(args)

        # Force the help command
        if options.help:
            subargs = ['help']
        elif options.version:
            subargs = ['version']
        return options, subargs

    def parse_subcommand(self, args):
        """Given the `args` left unused by a `parse_global_options`,
        return the invoked subcommand, the subcommand options, and the
        subcommand arguments.
        """
        # Help is default command
        if not args:
            args = ['help']

        cmdname = args.pop(0)
        subcommand = self._subcommand_for_name(cmdname)
        if not subcommand:
            raise UserError("unknown command '{0}'".format(cmdname))

        suboptions, subargs = subcommand.parse_args(args)
        return subcommand, suboptions, subargs


optparse.Option.ALWAYS_TYPED_ACTIONS += ('callback',)


def vararg_callback(option, opt_str, value, parser):
    """Callback for an option with variable arguments.
    Manually collect arguments right of a callback-action
    option (ie. with action="callback"), and add the resulting
    list to the destination var.

    Usage:
        parser.add_option("-c", "--callback", dest="vararg_attr",
                          action="callback", callback=vararg_callback)

    Details:
    http://docs.python.org/2/library/optparse.html#callback-example-6-variable
    -arguments
    """
    value = [value]

    def floatable(str):
        try:
            float(str)
            return True
        except ValueError:
            return False

    for arg in parser.rargs:
        # stop on --foo like options
        if arg[:2] == "--" and len(arg) > 2:
            break
        # stop on -a, but not on -3 or -3.0
        if arg[:1] == "-" and len(arg) > 1 and not floatable(arg):
            break
        value.append(arg)

    del parser.rargs[:len(value) - 1]
    setattr(parser.values, option.dest, value)


# The main entry point and bootstrapping.

def _load_plugins(config):
    """Load the plugins specified in the configuration.
    """
    paths = config['pluginpath'].get(confit.StrSeq(split=False))
    paths = map(util.normpath, paths)

    import beetsplug
    beetsplug.__path__ = paths + beetsplug.__path__
    # For backwards compatibility.
    sys.path += paths

    plugins.load_plugins(config['plugins'].as_str_seq())
    plugins.send("pluginload")
    return plugins


def _setup(options, lib=None):
    """Prepare the global state and update it with command line options.

    Returns a list of subcommands, a list of plugins, and a library instance.
    """
    # Configure the MusicBrainz API.
    mb.configure()

    config = _configure(options)

    plugins = _load_plugins(config)

    # Get the default subcommands.
    from beets.ui.commands import default_commands

    subcommands = list(default_commands)
    subcommands.extend(plugins.commands())

    if lib is None:
        lib = _open_library(config)
        plugins.send("library_opened", lib=lib)
    library.Item._types = plugins.types(library.Item)
    library.Album._types = plugins.types(library.Album)

    return subcommands, plugins, lib


def _configure(options):
    """Amend the global configuration object with command line options.
    """
    # Add any additional config files specified with --config. This
    # special handling lets specified plugins get loaded before we
    # finish parsing the command line.
    if getattr(options, 'config', None) is not None:
        config_path = options.config
        del options.config
        config.set_file(config_path)
    config.set_args(options)

    # Configure the logger.
    if config['verbose'].get(bool):
        log.setLevel(logging.DEBUG)
    else:
        log.setLevel(logging.INFO)

    config_path = config.user_config_path()
    if os.path.isfile(config_path):
        log.debug(u'user configuration: {0}'.format(
            util.displayable_path(config_path)))
    else:
        log.debug(u'no user configuration found at {0}'.format(
            util.displayable_path(config_path)))

    log.debug(u'data directory: {0}'
              .format(util.displayable_path(config.config_dir())))
    return config


def _open_library(config):
    """Create a new library instance from the configuration.
    """
    dbpath = config['library'].as_filename()
    try:
        lib = library.Library(
            dbpath,
            config['directory'].as_filename(),
            get_path_formats(),
            get_replacements(),
        )
        lib.get_item(0)  # Test database connection.
    except (sqlite3.OperationalError, sqlite3.DatabaseError):
        log.debug(traceback.format_exc())
        raise UserError(u"database file {0} could not be opened".format(
            util.displayable_path(dbpath)
        ))
    log.debug(u'library database: {0}\n'
              u'library directory: {1}'
              .format(util.displayable_path(lib.path),
                      util.displayable_path(lib.directory)))
    return lib


def _raw_main(args, lib=None):
    """A helper function for `main` without top-level exception
    handling.
    """
    parser = SubcommandsOptionParser()
    parser.add_option('-l', '--library', dest='library',
                      help='library database file to use')
    parser.add_option('-d', '--directory', dest='directory',
                      help="destination music directory")
    parser.add_option('-v', '--verbose', dest='verbose', action='store_true',
                      help='print debugging information')
    parser.add_option('-c', '--config', dest='config',
                      help='path to configuration file')
    parser.add_option('-h', '--help', dest='help', action='store_true',
                      help='show this help message and exit')
    parser.add_option('--version', dest='version', action='store_true',
                      help=optparse.SUPPRESS_HELP)

    options, subargs = parser.parse_global_options(args)

    # Special case for the `config --edit` command: bypass _setup so
    # that an invalid configuration does not prevent the editor from
    # starting.
    if subargs[0] == 'config' and ('-e' in subargs or '--edit' in subargs):
        from beets.ui.commands import config_edit
        return config_edit()

    subcommands, plugins, lib = _setup(options, lib)
    parser.add_subcommand(*subcommands)

    subcommand, suboptions, subargs = parser.parse_subcommand(subargs)
    subcommand.func(lib, suboptions, subargs)

    plugins.send('cli_exit', lib=lib)


def main(args=None):
    """Run the main command-line interface for beets. Includes top-level
    exception handlers that print friendly error messages.
    """
    try:
        _raw_main(args)
    except UserError as exc:
        message = exc.args[0] if exc.args else None
        log.error(u'error: {0}'.format(message))
        sys.exit(1)
    except util.HumanReadableException as exc:
        exc.log(log)
        sys.exit(1)
    except library.FileOperationError as exc:
        # These errors have reasonable human-readable descriptions, but
        # we still want to log their tracebacks for debugging.
        log.debug(traceback.format_exc())
        log.error(exc)
        sys.exit(1)
    except confit.ConfigError as exc:
        log.error(u'configuration error: {0}'.format(exc))
        sys.exit(1)
    except IOError as exc:
        if exc.errno == errno.EPIPE:
            # "Broken pipe". End silently.
            pass
        else:
            raise
    except KeyboardInterrupt:
        # Silently ignore ^C except in verbose mode.
        log.debug(traceback.format_exc())

1646
lib/beets/ui/commands.py
Normal file
File diff suppressed because it is too large
162
lib/beets/ui/completion_base.sh
Normal file
@@ -0,0 +1,162 @@
# This file is part of beets.
# Copyright (c) 2014, Thomas Scholtes.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.



# Completion for the `beet` command
# =================================
#
# Load this script to complete beets subcommands, options, and
# queries.
#
# If a beets command is found on the command line it completes filenames and
# the subcommand's options. Otherwise it will complete global options and
# subcommands. If the previous option on the command line expects an argument,
# it also completes filenames or directories. Options are only
# completed if '-' has already been typed on the command line.
#
# Note that completion of plugin commands only works for those plugins
# that were enabled when running `beet completion`. It does not check
# plugins dynamically.
#
# Currently, only Bash 3.2 and newer is supported and the
# `bash-completion` package is required.
#
# TODO
# ----
#
# * There are some issues with arguments that are quoted on the command line.
#
# * Complete arguments for the `--format` option by expanding field variables.
#
#     beet ls -f "$tit[TAB]
#     beet ls -f "$title
#
# * Support long options with `=`, e.g. `--config=file`. Debian's bash
#   completion package can handle this.
#


# Determines the beets subcommand and dispatches the completion
# accordingly.
_beet_dispatch() {
    local cur prev cmd=

    COMPREPLY=()
    _get_comp_words_by_ref -n : cur prev

    # Look for the beets subcommand
    local arg
    for (( i=1; i < COMP_CWORD; i++ )); do
        arg="${COMP_WORDS[i]}"
        if _list_include_item "${opts___global}" $arg; then
            ((i++))
        elif [[ "$arg" != -* ]]; then
            cmd="$arg"
            break
        fi
    done

    # Replace command shortcuts
    if [[ -n $cmd ]] && _list_include_item "$aliases" "$cmd"; then
        eval "cmd=\$alias__$cmd"
    fi

    case $cmd in
        help)
            COMPREPLY+=( $(compgen -W "$commands" -- $cur) )
            ;;
        list|remove|move|update|write|stats)
            _beet_complete_query
            ;;
        "")
            _beet_complete_global
            ;;
        *)
            _beet_complete
            ;;
    esac
}


# Adds option and file completion to COMPREPLY for the subcommand $cmd
_beet_complete() {
    if [[ $cur == -* ]]; then
        local opts flags completions
        eval "opts=\$opts__$cmd"
        eval "flags=\$flags__$cmd"
        completions="${flags___common} ${opts} ${flags}"
        COMPREPLY+=( $(compgen -W "$completions" -- $cur) )
    else
        _filedir
    fi
}


# Add global options and subcommands to the completion
_beet_complete_global() {
    case $prev in
        -h|--help)
            # Complete commands
            COMPREPLY+=( $(compgen -W "$commands" -- $cur) )
            return
            ;;
        -l|--library|-c|--config)
            # Filename completion
            _filedir
            return
            ;;
        -d|--directory)
            # Directory completion
            _filedir -d
            return
            ;;
    esac

    if [[ $cur == -* ]]; then
        local completions="$opts___global $flags___global"
        COMPREPLY+=( $(compgen -W "$completions" -- $cur) )
    elif [[ -n $cur ]] && _list_include_item "$aliases" "$cur"; then
        local cmd
        eval "cmd=\$alias__$cur"
        COMPREPLY+=( "$cmd" )
    else
        COMPREPLY+=( $(compgen -W "$commands" -- $cur) )
    fi
}

_beet_complete_query() {
    local opts
    eval "opts=\$opts__$cmd"

    if [[ $cur == -* ]] || _list_include_item "$opts" "$prev"; then
        _beet_complete
    elif [[ $cur != \'* && $cur != \"* &&
            $cur != *:* ]]; then
        # Do not complete quoted queries or those that already have a
        # field set.
        compopt -o nospace
        COMPREPLY+=( $(compgen -S : -W "$fields" -- $cur) )
        return 0
    fi
}

# Returns true if the space separated list $1 includes $2
_list_include_item() {
    [[ " $1 " == *[[:space:]]$2[[:space:]]* ]]
}

# This is where beets dynamically adds the _beet function. This
# function sets the variables $flags, $opts, $commands, and $aliases.
complete -o filenames -F _beet beet

686
lib/beets/util/__init__.py
Normal file
@@ -0,0 +1,686 @@
# This file is part of beets.
# Copyright 2013, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.

"""Miscellaneous utility functions."""
from __future__ import division

import os
import sys
import re
import shutil
import fnmatch
from collections import defaultdict
import traceback
import subprocess
import platform


MAX_FILENAME_LENGTH = 200
WINDOWS_MAGIC_PREFIX = u'\\\\?\\'


class HumanReadableException(Exception):
    """An Exception that can include a human-readable error message to
    be logged without a traceback. Can preserve a traceback for
    debugging purposes as well.

    Has at least two fields: `reason`, the underlying exception or a
    string describing the problem; and `verb`, the action being
    performed during the error.

    If `tb` is provided, it is a string containing a traceback for the
    associated exception. (Note that this is not necessary in Python 3.x
    and should be removed when we make the transition.)
    """
    error_kind = 'Error'  # Human-readable description of error type.

    def __init__(self, reason, verb, tb=None):
        self.reason = reason
        self.verb = verb
        self.tb = tb
        super(HumanReadableException, self).__init__(self.get_message())

    def _gerund(self):
        """Generate a (likely) gerund form of the English verb.
        """
        if ' ' in self.verb:
            return self.verb
        gerund = self.verb[:-1] if self.verb.endswith('e') else self.verb
        gerund += 'ing'
        return gerund

    def _reasonstr(self):
        """Get the reason as a string."""
        if isinstance(self.reason, unicode):
            return self.reason
        elif isinstance(self.reason, basestring):  # Byte string.
            return self.reason.decode('utf8', 'ignore')
        elif hasattr(self.reason, 'strerror'):  # i.e., EnvironmentError
            return self.reason.strerror
        else:
            return u'"{0}"'.format(unicode(self.reason))

    def get_message(self):
        """Create the human-readable description of the error, sans
        introduction.
        """
        raise NotImplementedError

    def log(self, logger):
        """Log to the provided `logger` a human-readable message as an
        error and a verbose traceback as a debug message.
        """
        if self.tb:
            logger.debug(self.tb)
        logger.error(u'{0}: {1}'.format(self.error_kind, self.args[0]))


class FilesystemError(HumanReadableException):
    """An error that occurred while performing a filesystem manipulation
    via a function in this module. The `paths` field is a sequence of
    pathnames involved in the operation.
    """
    def __init__(self, reason, verb, paths, tb=None):
        self.paths = paths
        super(FilesystemError, self).__init__(reason, verb, tb)

    def get_message(self):
        # Use a nicer English phrasing for some specific verbs.
        if self.verb in ('move', 'copy', 'rename'):
            clause = u'while {0} {1} to {2}'.format(
                self._gerund(),
                displayable_path(self.paths[0]),
                displayable_path(self.paths[1])
            )
        elif self.verb in ('delete', 'write', 'create', 'read'):
            clause = u'while {0} {1}'.format(
                self._gerund(),
                displayable_path(self.paths[0])
            )
        else:
            clause = u'during {0} of paths {1}'.format(
                self.verb, u', '.join(displayable_path(p) for p in self.paths)
            )

        return u'{0} {1}'.format(self._reasonstr(), clause)


def normpath(path):
    """Provide the canonical form of the path suitable for storing in
    the database.
    """
    path = syspath(path, prefix=False)
    path = os.path.normpath(os.path.abspath(os.path.expanduser(path)))
    return bytestring_path(path)


def ancestry(path):
    """Return a list consisting of path's parent directory, its
    grandparent, and so on. For instance:

       >>> ancestry('/a/b/c')
       ['/', '/a', '/a/b']

    The argument should *not* be the result of a call to `syspath`.
    """
    out = []
    last_path = None
    while path:
        path = os.path.dirname(path)

        if path == last_path:
            break
        last_path = path

        if path:
            # don't yield ''
            out.insert(0, path)
    return out


def sorted_walk(path, ignore=(), logger=None):
    """Like `os.walk`, but yields things in case-insensitive sorted,
    breadth-first order. Directory and file names matching any glob
    pattern in `ignore` are skipped. If `logger` is provided, then
    warning messages are logged there when a directory cannot be listed.
    """
    # Make sure the path isn't a Unicode string.
    path = bytestring_path(path)

    # Get all the directories and files at this level.
    try:
        contents = os.listdir(syspath(path))
    except OSError as exc:
        if logger:
            logger.warn(u'could not list directory {0}: {1}'.format(
                displayable_path(path), exc.strerror
            ))
        return
    dirs = []
    files = []
    for base in contents:
        base = bytestring_path(base)

        # Skip ignored filenames.
        skip = False
        for pat in ignore:
            if fnmatch.fnmatch(base, pat):
                skip = True
                break
        if skip:
            continue

        # Add to output as either a file or a directory.
        cur = os.path.join(path, base)
        if os.path.isdir(syspath(cur)):
            dirs.append(base)
        else:
            files.append(base)

    # Sort lists (case-insensitive) and yield the current level.
    dirs.sort(key=bytes.lower)
    files.sort(key=bytes.lower)
    yield (path, dirs, files)

    # Recurse into directories.
    for base in dirs:
        cur = os.path.join(path, base)
        # yield from sorted_walk(...)
        for res in sorted_walk(cur, ignore, logger):
            yield res


def mkdirall(path):
    """Make all the enclosing directories of path (like mkdir -p on the
    parent).
    """
    for ancestor in ancestry(path):
        if not os.path.isdir(syspath(ancestor)):
            try:
                os.mkdir(syspath(ancestor))
            except (OSError, IOError) as exc:
                raise FilesystemError(exc, 'create', (ancestor,),
                                      traceback.format_exc())


def fnmatch_all(names, patterns):
    """Determine whether all strings in `names` match at least one of
    the `patterns`, which should be shell glob expressions.
    """
    for name in names:
        matches = False
        for pattern in patterns:
            matches = fnmatch.fnmatch(name, pattern)
            if matches:
                break
        if not matches:
            return False
    return True


def prune_dirs(path, root=None, clutter=('.DS_Store', 'Thumbs.db')):
    """If path is an empty directory, then remove it. Recursively remove
    path's ancestry up to root (which is never removed) where there are
    empty directories. If path is not contained in root, then nothing is
    removed. Glob patterns in clutter are ignored when determining
    emptiness. If root is not provided, then only path may be removed
    (i.e., no recursive removal).
    """
    path = normpath(path)
    if root is not None:
        root = normpath(root)

    ancestors = ancestry(path)
    if root is None:
        # Only remove the top directory.
        ancestors = []
    elif root in ancestors:
        # Only remove directories below the root.
        ancestors = ancestors[ancestors.index(root) + 1:]
    else:
        # Remove nothing.
        return

    # Traverse upward from path.
    ancestors.append(path)
    ancestors.reverse()
    for directory in ancestors:
        directory = syspath(directory)
        if not os.path.exists(directory):
            # Directory gone already.
            continue
        if fnmatch_all(os.listdir(directory), clutter):
            # Directory contains only clutter (or nothing).
            try:
                shutil.rmtree(directory)
            except OSError:
                break
        else:
            break


def components(path):
    """Return a list of the path components in path. For instance:

       >>> components('/a/b/c')
       ['a', 'b', 'c']

    The argument should *not* be the result of a call to `syspath`.
    """
    comps = []
    ances = ancestry(path)
    for anc in ances:
        comp = os.path.basename(anc)
        if comp:
            comps.append(comp)
        else:  # root
            comps.append(anc)

    last = os.path.basename(path)
    if last:
        comps.append(last)

    return comps


def _fsencoding():
    """Get the system's filesystem encoding. On Windows, this is always
    UTF-8 (not MBCS).
    """
    encoding = sys.getfilesystemencoding() or sys.getdefaultencoding()
    if encoding == 'mbcs':
        # On Windows, a broken encoding known to Python as "MBCS" is
        # used for the filesystem. However, we only use the Unicode API
        # for Windows paths, so the encoding is actually immaterial so
        # we can avoid dealing with this nastiness. We arbitrarily
        # choose UTF-8.
        encoding = 'utf8'
    return encoding


def bytestring_path(path):
    """Given a path, which is either a str or a unicode, returns a str
    path (ensuring that we never deal with Unicode pathnames).
    """
    # Pass through bytestrings.
    if isinstance(path, str):
        return path

    # On Windows, remove the magic prefix added by `syspath`. This makes
    # ``bytestring_path(syspath(X)) == X``, i.e., we can safely
    # round-trip through `syspath`.
    if os.path.__name__ == 'ntpath' and path.startswith(WINDOWS_MAGIC_PREFIX):
        path = path[len(WINDOWS_MAGIC_PREFIX):]

    # Try to encode with default encodings, but fall back to UTF8.
    try:
        return path.encode(_fsencoding())
    except (UnicodeError, LookupError):
        return path.encode('utf8')


def displayable_path(path, separator=u'; '):
    """Attempts to decode a bytestring path to a unicode object for the
    purpose of displaying it to the user. If the `path` argument is a
    list or a tuple, the elements are joined with `separator`.
    """
    if isinstance(path, (list, tuple)):
        return separator.join(displayable_path(p) for p in path)
    elif isinstance(path, unicode):
        return path
    elif not isinstance(path, str):
        # A non-string object: just get its unicode representation.
        return unicode(path)

    try:
        return path.decode(_fsencoding(), 'ignore')
    except (UnicodeError, LookupError):
        return path.decode('utf8', 'ignore')


def syspath(path, prefix=True):
    """Convert a path for use by the operating system. In particular,
    paths on Windows must receive a magic prefix and must be converted
    to Unicode before they are sent to the OS. To disable the magic
    prefix on Windows, set `prefix` to False---but only do this if you
    *really* know what you're doing.
    """
    # Don't do anything if we're not on windows
    if os.path.__name__ != 'ntpath':
        return path

    if not isinstance(path, unicode):
        # Beets currently represents Windows paths internally with UTF-8
        # arbitrarily. But earlier versions used MBCS because it is
        # reported as the FS encoding by Windows. Try both.
        try:
            path = path.decode('utf8')
        except UnicodeError:
            # The encoding should always be MBCS, Windows' broken
            # Unicode representation.
            encoding = sys.getfilesystemencoding() or sys.getdefaultencoding()
            path = path.decode(encoding, 'replace')

    # Add the magic prefix if it isn't already there.
    # http://msdn.microsoft.com/en-us/library/windows/desktop/aa365247.aspx
    if prefix and not path.startswith(WINDOWS_MAGIC_PREFIX):
        if path.startswith(u'\\\\'):
            # UNC path. Final path should look like \\?\UNC\...
            path = u'UNC' + path[1:]
        path = WINDOWS_MAGIC_PREFIX + path

    return path


def samefile(p1, p2):
    """Safer equality for paths."""
    return shutil._samefile(syspath(p1), syspath(p2))


def remove(path, soft=True):
    """Remove the file. If `soft`, then no error will be raised if the
    file does not exist.
    """
    path = syspath(path)
    if soft and not os.path.exists(path):
        return
    try:
        os.remove(path)
    except (OSError, IOError) as exc:
        raise FilesystemError(exc, 'delete', (path,), traceback.format_exc())


def copy(path, dest, replace=False):
    """Copy a plain file. Permissions are not copied. If `dest` already
    exists, raises a FilesystemError unless `replace` is True. Has no
    effect if `path` is the same as `dest`. Paths are translated to
    system paths before the syscall.
    """
    if samefile(path, dest):
        return
    path = syspath(path)
    dest = syspath(dest)
    if not replace and os.path.exists(dest):
        raise FilesystemError('file exists', 'copy', (path, dest))
    try:
        shutil.copyfile(path, dest)
    except (OSError, IOError) as exc:
        raise FilesystemError(exc, 'copy', (path, dest),
                              traceback.format_exc())


def move(path, dest, replace=False):
    """Rename a file. `dest` may not be a directory. If `dest` already
    exists, raises an OSError unless `replace` is True. Has no effect if
    `path` is the same as `dest`. If the paths are on different
    filesystems (or the rename otherwise fails), a copy is attempted
    instead, in which case metadata will *not* be preserved. Paths are
    translated to system paths.
    """
    if samefile(path, dest):
        return
    path = syspath(path)
    dest = syspath(dest)
    if os.path.exists(dest) and not replace:
        raise FilesystemError('file exists', 'rename', (path, dest),
                              traceback.format_exc())

    # First, try renaming the file.
    try:
        os.rename(path, dest)
    except OSError:
        # Otherwise, copy and delete the original.
        try:
            shutil.copyfile(path, dest)
            os.remove(path)
        except (OSError, IOError) as exc:
            raise FilesystemError(exc, 'move', (path, dest),
                                  traceback.format_exc())


def link(path, dest, replace=False):
    """Create a symbolic link from path to `dest`. Raises an OSError if
    `dest` already exists, unless `replace` is True. Does nothing if
    `path` == `dest`."""
    if samefile(path, dest):
        return

    path = syspath(path)
    dest = syspath(dest)
    if os.path.exists(dest) and not replace:
        raise FilesystemError('file exists', 'rename', (path, dest),
                              traceback.format_exc())
    try:
        os.symlink(path, dest)
    except OSError:
        raise FilesystemError('Operating system does not support symbolic '
                              'links.', 'link', (path, dest),
                              traceback.format_exc())


def unique_path(path):
    """Returns a version of ``path`` that does not exist on the
    filesystem. Specifically, if ``path`` itself already exists, then
    something unique is appended to the path.
    """
    if not os.path.exists(syspath(path)):
        return path

    base, ext = os.path.splitext(path)
    match = re.search(r'\.(\d+)$', base)
    if match:
        num = int(match.group(1))
        base = base[:match.start()]
    else:
        num = 0
    while True:
        num += 1
        new_path = '%s.%i%s' % (base, num, ext)
        if not os.path.exists(new_path):
            return new_path
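
# Illustrative example (editor's addition, not in the original diff):
# if 'track.mp3' already exists, unique_path('track.mp3') returns
# 'track.1.mp3' (or 'track.2.mp3' if that exists too, and so on); an
# existing numeric suffix is incremented rather than stacked.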

# Note: The Windows "reserved characters" are, of course, allowed on
# Unix. They are forbidden here because they cause problems on Samba
# shares, which are sufficiently common as to cause frequent problems.
# http://msdn.microsoft.com/en-us/library/windows/desktop/aa365247.aspx
CHAR_REPLACE = [
    (re.compile(ur'[\\/]'), u'_'),  # / and \ -- forbidden everywhere.
    (re.compile(ur'^\.'), u'_'),  # Leading dot (hidden files on Unix).
    (re.compile(ur'[\x00-\x1f]'), u''),  # Control characters.
    (re.compile(ur'[<>:"\?\*\|]'), u'_'),  # Windows "reserved characters".
    (re.compile(ur'\.$'), u'_'),  # Trailing dots.
    (re.compile(ur'\s+$'), u''),  # Trailing whitespace.
]


def sanitize_path(path, replacements=None):
    """Takes a path (as a Unicode string) and makes sure that it is
    legal. Returns a new path. Only works with fragments; won't work
    reliably on Windows when a path begins with a drive letter. Path
    separators (including altsep!) should already be cleaned from the
    path components. If replacements is specified, it is used *instead*
    of the default set of replacements; it must be a list of (compiled
    regex, replacement string) pairs.
    """
    replacements = replacements or CHAR_REPLACE

    comps = components(path)
    if not comps:
        return ''
    for i, comp in enumerate(comps):
        for regex, repl in replacements:
            comp = regex.sub(repl, comp)
        comps[i] = comp
    return os.path.join(*comps)
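
# Illustrative example (editor's addition, not in the original diff):
#
#     >>> sanitize_path(u'one/.two')
#     u'one/_two'
#
# The leading dot in the second component is replaced so the file is
# not hidden on Unix; each component is filtered independently.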


def truncate_path(path, length=MAX_FILENAME_LENGTH):
    """Given a bytestring path or a Unicode path fragment, truncate the
    components to a legal length. In the last component, the extension
    is preserved.
    """
    comps = components(path)

    out = [c[:length] for c in comps]
    base, ext = os.path.splitext(comps[-1])
    if ext:
        # Last component has an extension.
        base = base[:length - len(ext)]
        out[-1] = base + ext

    return os.path.join(*out)
|
||||
|
||||
|
||||
def str2bool(value):
|
||||
"""Returns a boolean reflecting a human-entered string."""
|
||||
return value.lower() in ('yes', '1', 'true', 't', 'y')
|
||||
|
||||
|
||||
def as_string(value):
|
||||
"""Convert a value to a Unicode object for matching with a query.
|
||||
None becomes the empty string. Bytestrings are silently decoded.
|
||||
"""
|
||||
if value is None:
|
||||
return u''
|
||||
elif isinstance(value, buffer):
|
||||
return str(value).decode('utf8', 'ignore')
|
||||
elif isinstance(value, str):
|
||||
return value.decode('utf8', 'ignore')
|
||||
else:
|
||||
return unicode(value)
|
||||
|
||||
|
||||
def levenshtein(s1, s2):
|
||||
"""A nice DP edit distance implementation from Wikibooks:
|
||||
http://en.wikibooks.org/wiki/Algorithm_implementation/Strings/
|
||||
Levenshtein_distance#Python
|
||||
"""
|
||||
if len(s1) < len(s2):
|
||||
return levenshtein(s2, s1)
|
||||
if not s1:
|
||||
return len(s2)
|
||||
|
||||
previous_row = xrange(len(s2) + 1)
|
||||
for i, c1 in enumerate(s1):
|
||||
current_row = [i + 1]
|
||||
for j, c2 in enumerate(s2):
|
||||
insertions = previous_row[j + 1] + 1
|
||||
deletions = current_row[j] + 1
|
||||
substitutions = previous_row[j] + (c1 != c2)
|
||||
current_row.append(min(insertions, deletions, substitutions))
|
||||
previous_row = current_row
|
||||
|
||||
return previous_row[-1]
|
||||
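# Illustrative sketch (not part of the original file): the classic
# textbook pair differs by three edits.
def _levenshtein_example():
    return levenshtein(u'kitten', u'sitting')  # -> 3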
|
||||
|
||||
def plurality(objs):
|
||||
"""Given a sequence of comparable objects, returns the object that
|
||||
is most common in the set and the frequency of that object. The
|
||||
sequence must contain at least one object.
|
||||
"""
|
||||
# Calculate frequencies.
|
||||
freqs = defaultdict(int)
|
||||
for obj in objs:
|
||||
freqs[obj] += 1
|
||||
|
||||
if not freqs:
|
||||
raise ValueError('sequence must be non-empty')
|
||||
|
||||
# Find object with maximum frequency.
|
||||
max_freq = 0
|
||||
res = None
|
||||
for obj, freq in freqs.items():
|
||||
if freq > max_freq:
|
||||
max_freq = freq
|
||||
res = obj
|
||||
|
||||
return res, max_freq
|
||||
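# Illustrative sketch (not part of the original file): the most
# frequent element and its count. An empty sequence raises ValueError.
def _plurality_example():
    return plurality(['flac', 'mp3', 'flac'])  # -> ('flac', 2)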
|
||||
|
||||
def cpu_count():
|
||||
"""Return the number of hardware thread contexts (cores or SMT
|
||||
threads) in the system.
|
||||
"""
|
||||
# Adapted from the soundconverter project:
|
||||
# https://github.com/kassoulet/soundconverter
|
||||
if sys.platform == 'win32':
|
||||
try:
|
||||
num = int(os.environ['NUMBER_OF_PROCESSORS'])
|
||||
except (ValueError, KeyError):
|
||||
num = 0
|
||||
elif sys.platform == 'darwin':
|
||||
try:
|
||||
num = int(command_output(['sysctl', '-n', 'hw.ncpu']))
|
||||
except ValueError:
|
||||
num = 0
|
||||
else:
|
||||
try:
|
||||
num = os.sysconf('SC_NPROCESSORS_ONLN')
|
||||
except (ValueError, OSError, AttributeError):
|
||||
num = 0
|
||||
if num >= 1:
|
||||
return num
|
||||
else:
|
||||
return 1
|
||||
|
||||
|
||||
def command_output(cmd, shell=False):
|
||||
"""Runs the command and returns its output after it has exited.
|
||||
|
||||
``cmd`` is a list of arguments starting with the command name. If
|
||||
``shell`` is true, ``cmd`` is assumed to be a string and passed to a
|
||||
shell to execute.
|
||||
|
||||
If the process exits with a non-zero return code
|
||||
``subprocess.CalledProcessError`` is raised. May also raise
|
||||
``OSError``.
|
||||
|
||||
This replaces `subprocess.check_output`, which isn't available in
|
||||
Python 2.6 and which can have problems if lots of output is sent to
|
||||
stderr.
|
||||
"""
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
close_fds=platform.system() != 'Windows',
|
||||
shell=shell
|
||||
)
|
||||
stdout, stderr = proc.communicate()
|
||||
if proc.returncode:
|
||||
raise subprocess.CalledProcessError(
|
||||
returncode=proc.returncode,
|
||||
cmd=cmd if shell else ' '.join(cmd),  # cmd is already a string when shell=True.
|
||||
)
|
||||
return stdout
|
||||
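# Illustrative sketch (not part of the original file), assuming a
# POSIX host with `echo` on the PATH. A non-zero exit status raises
# subprocess.CalledProcessError instead of returning output.
def _command_output_example():
    return command_output(['echo', 'hello']).strip()  # -> 'hello'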
|
||||
|
||||
def max_filename_length(path, limit=MAX_FILENAME_LENGTH):
|
||||
"""Attempt to determine the maximum filename length for the
|
||||
filesystem containing `path`. If the value is greater than `limit`,
|
||||
then `limit` is used instead (to prevent errors when a filesystem
|
||||
misreports its capacity). If it cannot be determined (e.g., on
|
||||
Windows), return `limit`.
|
||||
"""
|
||||
if hasattr(os, 'statvfs'):
|
||||
try:
|
||||
res = os.statvfs(path)
|
||||
except OSError:
|
||||
return limit
|
||||
return min(res[9], limit)  # res[9] is statvfs f_namemax.
|
||||
else:
|
||||
return limit
|
||||
211
lib/beets/util/artresizer.py
Normal file
@@ -0,0 +1,211 @@
|
||||
# This file is part of beets.
|
||||
# Copyright 2014, Fabrice Laporte
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files (the
|
||||
# "Software"), to deal in the Software without restriction, including
|
||||
# without limitation the rights to use, copy, modify, merge, publish,
|
||||
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||
# permit persons to whom the Software is furnished to do so, subject to
|
||||
# the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
|
||||
"""Abstraction layer to resize images using PIL, ImageMagick, or a
|
||||
public resizing proxy if neither is available.
|
||||
"""
|
||||
import urllib
|
||||
import subprocess
|
||||
import os
|
||||
import re
|
||||
from tempfile import NamedTemporaryFile
|
||||
import logging
|
||||
from beets import util
|
||||
|
||||
# Resizing methods
|
||||
PIL = 1
|
||||
IMAGEMAGICK = 2
|
||||
WEBPROXY = 3
|
||||
|
||||
PROXY_URL = 'http://images.weserv.nl/'
|
||||
|
||||
log = logging.getLogger('beets')
|
||||
|
||||
|
||||
def resize_url(url, maxwidth):
|
||||
"""Return a proxied image URL that resizes the original image to
|
||||
maxwidth (preserving aspect ratio).
|
||||
"""
|
||||
return '{0}?{1}'.format(PROXY_URL, urllib.urlencode({
|
||||
'url': url.replace('http://', ''),
|
||||
'w': str(maxwidth),
|
||||
}))
|
||||
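# Illustrative sketch (not part of the original file): the proxied URL
# points at the weserv.nl resizer (query parameter order may vary).
def _resize_url_example():
    return resize_url('http://example.com/art.jpg', 600)
    # -> 'http://images.weserv.nl/?url=example.com%2Fart.jpg&w=600'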
|
||||
|
||||
def temp_file_for(path):
|
||||
"""Return an unused filename with the same extension as the
|
||||
specified path.
|
||||
"""
|
||||
ext = os.path.splitext(path)[1]
|
||||
with NamedTemporaryFile(suffix=ext, delete=False) as f:
|
||||
return f.name
|
||||
|
||||
|
||||
def pil_resize(maxwidth, path_in, path_out=None):
|
||||
"""Resize using Python Imaging Library (PIL). Return the output path
|
||||
of resized image.
|
||||
"""
|
||||
path_out = path_out or temp_file_for(path_in)
|
||||
from PIL import Image
|
||||
log.debug(u'artresizer: PIL resizing {0} to {1}'.format(
|
||||
util.displayable_path(path_in), util.displayable_path(path_out)
|
||||
))
|
||||
|
||||
try:
|
||||
im = Image.open(util.syspath(path_in))
|
||||
size = maxwidth, maxwidth
|
||||
im.thumbnail(size, Image.ANTIALIAS)
|
||||
im.save(path_out)
|
||||
return path_out
|
||||
except IOError:
|
||||
log.error(u"PIL cannot create thumbnail for '{0}'".format(
|
||||
util.displayable_path(path_in)
|
||||
))
|
||||
return path_in
|
||||
|
||||
|
||||
def im_resize(maxwidth, path_in, path_out=None):
|
||||
"""Resize using ImageMagick's ``convert`` tool.
|
||||
Return the output path of resized image.
|
||||
"""
|
||||
path_out = path_out or temp_file_for(path_in)
|
||||
log.debug(u'artresizer: ImageMagick resizing {0} to {1}'.format(
|
||||
util.displayable_path(path_in), util.displayable_path(path_out)
|
||||
))
|
||||
|
||||
# "-resize widthxheight>" shrinks images with dimension(s) larger
|
||||
# than the corresponding width and/or height dimension(s). The >
|
||||
# "only shrink" flag is prefixed by ^ escape char for Windows
|
||||
# compatibility.
|
||||
try:
|
||||
util.command_output([
|
||||
'convert', util.syspath(path_in),
|
||||
'-resize', '{0}x^>'.format(maxwidth), path_out
|
||||
])
|
||||
except subprocess.CalledProcessError:
|
||||
log.warn(u'artresizer: IM convert failed for {0}'.format(
|
||||
util.displayable_path(path_in)
|
||||
))
|
||||
return path_in
|
||||
return path_out
|
||||
|
||||
|
||||
BACKEND_FUNCS = {
|
||||
PIL: pil_resize,
|
||||
IMAGEMAGICK: im_resize,
|
||||
}
|
||||
|
||||
|
||||
class Shareable(type):
|
||||
"""A pseudo-singleton metaclass that allows both shared and
|
||||
non-shared instances. The ``MyClass.shared`` property holds a
|
||||
lazily-created shared instance of ``MyClass`` while calling
|
||||
``MyClass()`` to construct a new object works as usual.
|
||||
"""
|
||||
def __init__(cls, name, bases, dict):
|
||||
super(Shareable, cls).__init__(name, bases, dict)
|
||||
cls._instance = None
|
||||
|
||||
@property
|
||||
def shared(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
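# Illustrative sketch (not part of the original file): opting into the
# pseudo-singleton behaviour with Python 2 metaclass syntax.
class _SharedThing(object):
    __metaclass__ = Shareable

def _shareable_example():
    a = _SharedThing.shared   # lazily created, cached instance
    b = _SharedThing.shared   # the same object again
    c = _SharedThing()        # a fresh, independent instance
    return a is b and a is not c  # -> True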
|
||||
|
||||
class ArtResizer(object):
|
||||
"""A singleton class that performs image resizes.
|
||||
"""
|
||||
__metaclass__ = Shareable
|
||||
|
||||
def __init__(self, method=None):
|
||||
"""Create a resizer object for the given method or, if none is
|
||||
specified, with an inferred method.
|
||||
"""
|
||||
self.method = self._check_method(method)
|
||||
log.debug(u"artresizer: method is {0}".format(self.method))
|
||||
self.can_compare = self._can_compare()
|
||||
|
||||
def resize(self, maxwidth, path_in, path_out=None):
|
||||
"""Manipulate an image file according to the method, returning a
|
||||
new path. For PIL or IMAGEMAGICK methods, resizes the image to a
|
||||
temporary file. For WEBPROXY, returns `path_in` unmodified.
|
||||
"""
|
||||
if self.local:
|
||||
func = BACKEND_FUNCS[self.method[0]]
|
||||
return func(maxwidth, path_in, path_out)
|
||||
else:
|
||||
return path_in
|
||||
|
||||
def proxy_url(self, maxwidth, url):
|
||||
"""Modifies an image URL according the method, returning a new
|
||||
URL. For WEBPROXY, a URL on the proxy server is returned.
|
||||
Otherwise, the URL is returned unmodified.
|
||||
"""
|
||||
if self.local:
|
||||
return url
|
||||
else:
|
||||
return resize_url(url, maxwidth)
|
||||
|
||||
@property
|
||||
def local(self):
|
||||
"""A boolean indicating whether the resizing method is performed
|
||||
locally (i.e., PIL or ImageMagick).
|
||||
"""
|
||||
return self.method[0] in BACKEND_FUNCS
|
||||
|
||||
def _can_compare(self):
|
||||
"""A boolean indicating whether image comparison is available"""
|
||||
|
||||
return self.method[0] == IMAGEMAGICK and self.method[1] > (6, 8, 7)
|
||||
|
||||
@staticmethod
|
||||
def _check_method(method=None):
|
||||
"""A tuple indicating whether current method is available and its
|
||||
version. If no method is given, it returns a supported one.
|
||||
"""
|
||||
# Guess available method
|
||||
if not method:
|
||||
for m in [IMAGEMAGICK, PIL]:
|
||||
_, version = ArtResizer._check_method(m)
|
||||
if version:
|
||||
return (m, version)
|
||||
return (WEBPROXY, (0))
|
||||
|
||||
if method == IMAGEMAGICK:
|
||||
|
||||
# Try invoking ImageMagick's "identify" to probe its version.
|
||||
try:
|
||||
out = util.command_output(['identify', '--version'])
|
||||
|
||||
if 'imagemagick' in out.lower():
|
||||
pattern = r".+ (\d+)\.(\d+)\.(\d+).*"
|
||||
match = re.search(pattern, out)
|
||||
if match:
|
||||
return (IMAGEMAGICK,
|
||||
(int(match.group(1)),
|
||||
int(match.group(2)),
|
||||
int(match.group(3))))
|
||||
return (IMAGEMAGICK, (0))
|
||||
|
||||
except (subprocess.CalledProcessError, OSError):
|
||||
return (IMAGEMAGICK, None)
|
||||
|
||||
if method == PIL:
|
||||
# Try importing PIL.
|
||||
try:
|
||||
__import__('PIL', fromlist=['Image'])
|
||||
return (PIL, (0))
|
||||
except ImportError:
|
||||
return (PIL, None)
|
||||
647
lib/beets/util/bluelet.py
Normal file
@@ -0,0 +1,647 @@
|
||||
"""Extremely simple pure-Python implementation of coroutine-style
|
||||
asynchronous socket I/O. Inspired by, but inferior to, Eventlet.
|
||||
Bluelet can also be thought of as a less-terrible replacement for
|
||||
asyncore.
|
||||
|
||||
Bluelet: easy concurrency without all the messy parallelism.
|
||||
"""
|
||||
import socket
|
||||
import select
|
||||
import sys
|
||||
import types
|
||||
import errno
|
||||
import traceback
|
||||
import time
|
||||
import collections
|
||||
|
||||
|
||||
# A little bit of "six" (Python 2/3 compatibility): cope with PEP 3109 syntax
|
||||
# changes.
|
||||
|
||||
PY3 = sys.version_info[0] == 3
|
||||
if PY3:
|
||||
def _reraise(typ, exc, tb):
|
||||
raise exc.with_traceback(tb)
|
||||
else:
|
||||
exec("""
|
||||
def _reraise(typ, exc, tb):
|
||||
raise typ, exc, tb
|
||||
""")
|
||||
|
||||
|
||||
# Basic events used for thread scheduling.
|
||||
|
||||
class Event(object):
|
||||
"""Just a base class identifying Bluelet events. An event is an
|
||||
object yielded from a Bluelet thread coroutine to suspend operation
|
||||
and communicate with the scheduler.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class WaitableEvent(Event):
|
||||
"""A waitable event is one encapsulating an action that can be
|
||||
waited for using a select() call. That is, it's an event with an
|
||||
associated file descriptor.
|
||||
"""
|
||||
def waitables(self):
|
||||
"""Return "waitable" objects to pass to select(). Should return
|
||||
three iterables for input readiness, output readiness, and
|
||||
exceptional conditions (i.e., the three lists passed to
|
||||
select()).
|
||||
"""
|
||||
return (), (), ()
|
||||
|
||||
def fire(self):
|
||||
"""Called when an associated file descriptor becomes ready
|
||||
(i.e., is returned from a select() call).
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ValueEvent(Event):
|
||||
"""An event that does nothing but return a fixed value."""
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
|
||||
class ExceptionEvent(Event):
|
||||
"""Raise an exception at the yield point. Used internally."""
|
||||
def __init__(self, exc_info):
|
||||
self.exc_info = exc_info
|
||||
|
||||
|
||||
class SpawnEvent(Event):
|
||||
"""Add a new coroutine thread to the scheduler."""
|
||||
def __init__(self, coro):
|
||||
self.spawned = coro
|
||||
|
||||
|
||||
class JoinEvent(Event):
|
||||
"""Suspend the thread until the specified child thread has
|
||||
completed.
|
||||
"""
|
||||
def __init__(self, child):
|
||||
self.child = child
|
||||
|
||||
|
||||
class KillEvent(Event):
|
||||
"""Unschedule a child thread."""
|
||||
def __init__(self, child):
|
||||
self.child = child
|
||||
|
||||
|
||||
class DelegationEvent(Event):
|
||||
"""Suspend execution of the current thread, start a new thread and,
|
||||
once the child thread finished, return control to the parent
|
||||
thread.
|
||||
"""
|
||||
def __init__(self, coro):
|
||||
self.spawned = coro
|
||||
|
||||
|
||||
class ReturnEvent(Event):
|
||||
"""Return a value the current thread's delegator at the point of
|
||||
delegation. Ends the current (delegate) thread.
|
||||
"""
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
|
||||
class SleepEvent(WaitableEvent):
|
||||
"""Suspend the thread for a given duration.
|
||||
"""
|
||||
def __init__(self, duration):
|
||||
self.wakeup_time = time.time() + duration
|
||||
|
||||
def time_left(self):
|
||||
return max(self.wakeup_time - time.time(), 0.0)
|
||||
|
||||
|
||||
class ReadEvent(WaitableEvent):
|
||||
"""Reads from a file-like object."""
|
||||
def __init__(self, fd, bufsize):
|
||||
self.fd = fd
|
||||
self.bufsize = bufsize
|
||||
|
||||
def waitables(self):
|
||||
return (self.fd,), (), ()
|
||||
|
||||
def fire(self):
|
||||
return self.fd.read(self.bufsize)
|
||||
|
||||
|
||||
class WriteEvent(WaitableEvent):
|
||||
"""Writes to a file-like object."""
|
||||
def __init__(self, fd, data):
|
||||
self.fd = fd
|
||||
self.data = data
|
||||
|
||||
def waitables(self):
|
||||
return (), (self.fd,), ()
|
||||
|
||||
def fire(self):
|
||||
self.fd.write(self.data)
|
||||
|
||||
|
||||
# Core logic for executing and scheduling threads.
|
||||
|
||||
def _event_select(events):
|
||||
"""Perform a select() over all the Events provided, returning the
|
||||
ones ready to be fired. Only WaitableEvents (including SleepEvents)
|
||||
matter here; all other events are ignored (and thus postponed).
|
||||
"""
|
||||
# Gather waitables and wakeup times.
|
||||
waitable_to_event = {}
|
||||
rlist, wlist, xlist = [], [], []
|
||||
earliest_wakeup = None
|
||||
for event in events:
|
||||
if isinstance(event, SleepEvent):
|
||||
if not earliest_wakeup:
|
||||
earliest_wakeup = event.wakeup_time
|
||||
else:
|
||||
earliest_wakeup = min(earliest_wakeup, event.wakeup_time)
|
||||
elif isinstance(event, WaitableEvent):
|
||||
r, w, x = event.waitables()
|
||||
rlist += r
|
||||
wlist += w
|
||||
xlist += x
|
||||
for waitable in r:
|
||||
waitable_to_event[('r', waitable)] = event
|
||||
for waitable in w:
|
||||
waitable_to_event[('w', waitable)] = event
|
||||
for waitable in x:
|
||||
waitable_to_event[('x', waitable)] = event
|
||||
|
||||
# If we have any sleeping threads, determine how long to sleep.
|
||||
if earliest_wakeup:
|
||||
timeout = max(earliest_wakeup - time.time(), 0.0)
|
||||
else:
|
||||
timeout = None
|
||||
|
||||
# Perform select() if we have any waitables.
|
||||
if rlist or wlist or xlist:
|
||||
rready, wready, xready = select.select(rlist, wlist, xlist, timeout)
|
||||
else:
|
||||
rready, wready, xready = (), (), ()
|
||||
if timeout:
|
||||
time.sleep(timeout)
|
||||
|
||||
# Gather ready events corresponding to the ready waitables.
|
||||
ready_events = set()
|
||||
for ready in rready:
|
||||
ready_events.add(waitable_to_event[('r', ready)])
|
||||
for ready in wready:
|
||||
ready_events.add(waitable_to_event[('w', ready)])
|
||||
for ready in xready:
|
||||
ready_events.add(waitable_to_event[('x', ready)])
|
||||
|
||||
# Gather any finished sleeps.
|
||||
for event in events:
|
||||
if isinstance(event, SleepEvent) and event.time_left() == 0.0:
|
||||
ready_events.add(event)
|
||||
|
||||
return ready_events
|
||||
|
||||
|
||||
class ThreadException(Exception):
|
||||
def __init__(self, coro, exc_info):
|
||||
self.coro = coro
|
||||
self.exc_info = exc_info
|
||||
|
||||
def reraise(self):
|
||||
_reraise(self.exc_info[0], self.exc_info[1], self.exc_info[2])
|
||||
|
||||
|
||||
SUSPENDED = Event() # Special sentinel placeholder for suspended threads.
|
||||
|
||||
|
||||
class Delegated(Event):
|
||||
"""Placeholder indicating that a thread has delegated execution to a
|
||||
different thread.
|
||||
"""
|
||||
def __init__(self, child):
|
||||
self.child = child
|
||||
|
||||
|
||||
def run(root_coro):
|
||||
"""Schedules a coroutine, running it to completion. This
|
||||
encapsulates the Bluelet scheduler, which the root coroutine can
|
||||
add to by spawning new coroutines.
|
||||
"""
|
||||
# The "threads" dictionary keeps track of all the currently-
|
||||
# executing and suspended coroutines. It maps coroutines to their
|
||||
# currently "blocking" event. The event value may be SUSPENDED if
|
||||
# the coroutine is waiting on some other condition: namely, a
|
||||
# delegated coroutine or a joined coroutine. In this case, the
|
||||
# coroutine should *also* appear as a value in one of the below
|
||||
# dictionaries `delegators` or `joiners`.
|
||||
threads = {root_coro: ValueEvent(None)}
|
||||
|
||||
# Maps child coroutines to delegating parents.
|
||||
delegators = {}
|
||||
|
||||
# Maps child coroutines to joining (exit-waiting) parents.
|
||||
joiners = collections.defaultdict(list)
|
||||
|
||||
def complete_thread(coro, return_value):
|
||||
"""Remove a coroutine from the scheduling pool, awaking
|
||||
delegators and joiners as necessary and returning the specified
|
||||
value to any delegating parent.
|
||||
"""
|
||||
del threads[coro]
|
||||
|
||||
# Resume delegator.
|
||||
if coro in delegators:
|
||||
threads[delegators[coro]] = ValueEvent(return_value)
|
||||
del delegators[coro]
|
||||
|
||||
# Resume joiners.
|
||||
if coro in joiners:
|
||||
for parent in joiners[coro]:
|
||||
threads[parent] = ValueEvent(None)
|
||||
del joiners[coro]
|
||||
|
||||
def advance_thread(coro, value, is_exc=False):
|
||||
"""After an event is fired, run a given coroutine associated with
|
||||
it in the threads dict until it yields again. If the coroutine
|
||||
exits, then the thread is removed from the pool. If the coroutine
|
||||
raises an exception, it is reraised in a ThreadException. If
|
||||
is_exc is True, then the value must be an exc_info tuple and the
|
||||
exception is thrown into the coroutine.
|
||||
"""
|
||||
try:
|
||||
if is_exc:
|
||||
next_event = coro.throw(*value)
|
||||
else:
|
||||
next_event = coro.send(value)
|
||||
except StopIteration:
|
||||
# Thread is done.
|
||||
complete_thread(coro, None)
|
||||
except:
|
||||
# Thread raised some other exception.
|
||||
del threads[coro]
|
||||
raise ThreadException(coro, sys.exc_info())
|
||||
else:
|
||||
if isinstance(next_event, types.GeneratorType):
|
||||
# Automatically invoke sub-coroutines. (Shorthand for
|
||||
# explicit bluelet.call().)
|
||||
next_event = DelegationEvent(next_event)
|
||||
threads[coro] = next_event
|
||||
|
||||
def kill_thread(coro):
|
||||
"""Unschedule this thread and its (recursive) delegates.
|
||||
"""
|
||||
# Collect all coroutines in the delegation stack.
|
||||
coros = [coro]
|
||||
while isinstance(threads[coro], Delegated):
|
||||
coro = threads[coro].child
|
||||
coros.append(coro)
|
||||
|
||||
# Complete each coroutine from the top to the bottom of the
|
||||
# stack.
|
||||
for coro in reversed(coros):
|
||||
complete_thread(coro, None)
|
||||
|
||||
# Continue advancing threads until root thread exits.
|
||||
exit_te = None
|
||||
while threads:
|
||||
try:
|
||||
# Look for events that can be run immediately. Continue
|
||||
# running immediate events until nothing is ready.
|
||||
while True:
|
||||
have_ready = False
|
||||
for coro, event in list(threads.items()):
|
||||
if isinstance(event, SpawnEvent):
|
||||
threads[event.spawned] = ValueEvent(None) # Spawn.
|
||||
advance_thread(coro, None)
|
||||
have_ready = True
|
||||
elif isinstance(event, ValueEvent):
|
||||
advance_thread(coro, event.value)
|
||||
have_ready = True
|
||||
elif isinstance(event, ExceptionEvent):
|
||||
advance_thread(coro, event.exc_info, True)
|
||||
have_ready = True
|
||||
elif isinstance(event, DelegationEvent):
|
||||
threads[coro] = Delegated(event.spawned) # Suspend.
|
||||
threads[event.spawned] = ValueEvent(None) # Spawn.
|
||||
delegators[event.spawned] = coro
|
||||
have_ready = True
|
||||
elif isinstance(event, ReturnEvent):
|
||||
# Thread is done.
|
||||
complete_thread(coro, event.value)
|
||||
have_ready = True
|
||||
elif isinstance(event, JoinEvent):
|
||||
threads[coro] = SUSPENDED # Suspend.
|
||||
joiners[event.child].append(coro)
|
||||
have_ready = True
|
||||
elif isinstance(event, KillEvent):
|
||||
threads[coro] = ValueEvent(None)
|
||||
kill_thread(event.child)
|
||||
have_ready = True
|
||||
|
||||
# Only start the select when nothing else is ready.
|
||||
if not have_ready:
|
||||
break
|
||||
|
||||
# Wait and fire.
|
||||
event2coro = dict((v, k) for k, v in threads.items())
|
||||
for event in _event_select(threads.values()):
|
||||
# Run the IO operation, but catch socket errors.
|
||||
try:
|
||||
value = event.fire()
|
||||
except socket.error as exc:
|
||||
if isinstance(exc.args, tuple) and \
|
||||
exc.args[0] == errno.EPIPE:
|
||||
# Broken pipe. Remote host disconnected.
|
||||
pass
|
||||
else:
|
||||
traceback.print_exc()
|
||||
# Abort the coroutine.
|
||||
threads[event2coro[event]] = ReturnEvent(None)
|
||||
else:
|
||||
advance_thread(event2coro[event], value)
|
||||
|
||||
except ThreadException as te:
|
||||
# Exception raised from inside a thread.
|
||||
event = ExceptionEvent(te.exc_info)
|
||||
if te.coro in delegators:
|
||||
# The thread is a delegate. Raise exception in its
|
||||
# delegator.
|
||||
threads[delegators[te.coro]] = event
|
||||
del delegators[te.coro]
|
||||
else:
|
||||
# The thread is root-level. Raise in client code.
|
||||
exit_te = te
|
||||
break
|
||||
|
||||
except:
|
||||
# For instance, KeyboardInterrupt during select(). Raise
|
||||
# into root thread and terminate others.
|
||||
threads = {root_coro: ExceptionEvent(sys.exc_info())}
|
||||
|
||||
# If any threads still remain, kill them.
|
||||
for coro in threads:
|
||||
coro.close()
|
||||
|
||||
# If we're exiting with an exception, raise it in the client.
|
||||
if exit_te:
|
||||
exit_te.reraise()
|
||||
|
||||
|
||||
# Sockets and their associated events.
|
||||
|
||||
class SocketClosedError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Listener(object):
|
||||
"""A socket wrapper object for listening sockets.
|
||||
"""
|
||||
def __init__(self, host, port):
|
||||
"""Create a listening socket on the given hostname and port.
|
||||
"""
|
||||
self._closed = False
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
self.sock.bind((host, port))
|
||||
self.sock.listen(5)
|
||||
|
||||
def accept(self):
|
||||
"""An event that waits for a connection on the listening socket.
|
||||
When a connection is made, the event returns a Connection
|
||||
object.
|
||||
"""
|
||||
if self._closed:
|
||||
raise SocketClosedError()
|
||||
return AcceptEvent(self)
|
||||
|
||||
def close(self):
|
||||
"""Immediately close the listening socket. (Not an event.)
|
||||
"""
|
||||
self._closed = True
|
||||
self.sock.close()
|
||||
|
||||
|
||||
class Connection(object):
|
||||
"""A socket wrapper object for connected sockets.
|
||||
"""
|
||||
def __init__(self, sock, addr):
|
||||
self.sock = sock
|
||||
self.addr = addr
|
||||
self._buf = b''
|
||||
self._closed = False
|
||||
|
||||
def close(self):
|
||||
"""Close the connection."""
|
||||
self._closed = True
|
||||
self.sock.close()
|
||||
|
||||
def recv(self, size):
|
||||
"""Read at most size bytes of data from the socket."""
|
||||
if self._closed:
|
||||
raise SocketClosedError()
|
||||
|
||||
if self._buf:
|
||||
# We already have data read previously.
|
||||
out = self._buf[:size]
|
||||
self._buf = self._buf[size:]
|
||||
return ValueEvent(out)
|
||||
else:
|
||||
return ReceiveEvent(self, size)
|
||||
|
||||
def send(self, data):
|
||||
"""Sends data on the socket, returning the number of bytes
|
||||
successfully sent.
|
||||
"""
|
||||
if self._closed:
|
||||
raise SocketClosedError()
|
||||
return SendEvent(self, data)
|
||||
|
||||
def sendall(self, data):
|
||||
"""Send all of data on the socket."""
|
||||
if self._closed:
|
||||
raise SocketClosedError()
|
||||
return SendEvent(self, data, True)
|
||||
|
||||
def readline(self, terminator=b"\n", bufsize=1024):
|
||||
"""Reads a line (delimited by terminator) from the socket."""
|
||||
if self._closed:
|
||||
raise SocketClosedError()
|
||||
|
||||
while True:
|
||||
if terminator in self._buf:
|
||||
line, self._buf = self._buf.split(terminator, 1)
|
||||
line += terminator
|
||||
yield ReturnEvent(line)
|
||||
break
|
||||
data = yield ReceiveEvent(self, bufsize)
|
||||
if data:
|
||||
self._buf += data
|
||||
else:
|
||||
line = self._buf
|
||||
self._buf = b''
|
||||
yield ReturnEvent(line)
|
||||
break
|
||||
|
||||
|
||||
class AcceptEvent(WaitableEvent):
|
||||
"""An event for Listener objects (listening sockets) that suspends
|
||||
execution until the socket gets a connection.
|
||||
"""
|
||||
def __init__(self, listener):
|
||||
self.listener = listener
|
||||
|
||||
def waitables(self):
|
||||
return (self.listener.sock,), (), ()
|
||||
|
||||
def fire(self):
|
||||
sock, addr = self.listener.sock.accept()
|
||||
return Connection(sock, addr)
|
||||
|
||||
|
||||
class ReceiveEvent(WaitableEvent):
|
||||
"""An event for Connection objects (connected sockets) for
|
||||
asynchronously reading data.
|
||||
"""
|
||||
def __init__(self, conn, bufsize):
|
||||
self.conn = conn
|
||||
self.bufsize = bufsize
|
||||
|
||||
def waitables(self):
|
||||
return (self.conn.sock,), (), ()
|
||||
|
||||
def fire(self):
|
||||
return self.conn.sock.recv(self.bufsize)
|
||||
|
||||
|
||||
class SendEvent(WaitableEvent):
|
||||
"""An event for Connection objects (connected sockets) for
|
||||
asynchronously writing data.
|
||||
"""
|
||||
def __init__(self, conn, data, sendall=False):
|
||||
self.conn = conn
|
||||
self.data = data
|
||||
self.sendall = sendall
|
||||
|
||||
def waitables(self):
|
||||
return (), (self.conn.sock,), ()
|
||||
|
||||
def fire(self):
|
||||
if self.sendall:
|
||||
return self.conn.sock.sendall(self.data)
|
||||
else:
|
||||
return self.conn.sock.send(self.data)
|
||||
|
||||
|
||||
# Public interface for threads; each returns an event object that
|
||||
# can immediately be "yield"ed.
|
||||
|
||||
def null():
|
||||
"""Event: yield to the scheduler without doing anything special.
|
||||
"""
|
||||
return ValueEvent(None)
|
||||
|
||||
|
||||
def spawn(coro):
|
||||
"""Event: add another coroutine to the scheduler. Both the parent
|
||||
and child coroutines run concurrently.
|
||||
"""
|
||||
if not isinstance(coro, types.GeneratorType):
|
||||
raise ValueError('%s is not a coroutine' % str(coro))
|
||||
return SpawnEvent(coro)
|
||||
|
||||
|
||||
def call(coro):
|
||||
"""Event: delegate to another coroutine. The current coroutine
|
||||
is resumed once the sub-coroutine finishes. If the sub-coroutine
|
||||
returns a value using end(), then this event returns that value.
|
||||
"""
|
||||
if not isinstance(coro, types.GeneratorType):
|
||||
raise ValueError('%s is not a coroutine' % str(coro))
|
||||
return DelegationEvent(coro)
|
||||
|
||||
|
||||
def end(value=None):
|
||||
"""Event: ends the coroutine and returns a value to its
|
||||
delegator.
|
||||
"""
|
||||
return ReturnEvent(value)
|
||||
|
||||
|
||||
def read(fd, bufsize=None):
|
||||
"""Event: read from a file descriptor asynchronously."""
|
||||
if bufsize is None:
|
||||
# Read all.
|
||||
def reader():
|
||||
buf = []
|
||||
while True:
|
||||
data = yield read(fd, 1024)
|
||||
if not data:
|
||||
break
|
||||
buf.append(data)
|
||||
yield ReturnEvent(''.join(buf))
|
||||
return DelegationEvent(reader())
|
||||
|
||||
else:
|
||||
return ReadEvent(fd, bufsize)
|
||||
|
||||
|
||||
def write(fd, data):
|
||||
"""Event: write to a file descriptor asynchronously."""
|
||||
return WriteEvent(fd, data)
|
||||
|
||||
|
||||
def connect(host, port):
|
||||
"""Event: connect to a network address and return a Connection
|
||||
object for communicating on the socket.
|
||||
"""
|
||||
addr = (host, port)
|
||||
sock = socket.create_connection(addr)
|
||||
return ValueEvent(Connection(sock, addr))
|
||||
|
||||
|
||||
def sleep(duration):
|
||||
"""Event: suspend the thread for ``duration`` seconds.
|
||||
"""
|
||||
return SleepEvent(duration)
|
||||
|
||||
|
||||
def join(coro):
|
||||
"""Suspend the thread until another, previously `spawn`ed thread
|
||||
completes.
|
||||
"""
|
||||
return JoinEvent(coro)
|
||||
|
||||
|
||||
def kill(coro):
|
||||
"""Halt the execution of a different `spawn`ed thread.
|
||||
"""
|
||||
return KillEvent(coro)
|
||||
|
||||
|
||||
# Convenience function for running socket servers.
|
||||
|
||||
def server(host, port, func):
|
||||
"""A coroutine that runs a network server. Host and port specify the
|
||||
listening address. func should be a coroutine that takes a single
|
||||
parameter, a Connection object. The coroutine is invoked for every
|
||||
incoming connection on the listening socket.
|
||||
"""
|
||||
def handler(conn):
|
||||
try:
|
||||
yield func(conn)
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
listener = Listener(host, port)
|
||||
try:
|
||||
while True:
|
||||
conn = yield listener.accept()
|
||||
yield spawn(handler(conn))
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
listener.close()
|
||||
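# Illustrative sketch (not part of the original file): an echo server
# built from the helpers above. The host and port are arbitrary; drive
# it with run(_echo_server()).
def _echo_server():
    def echo(conn):
        while True:
            data = yield conn.recv(1024)
            if not data:
                break
            yield conn.sendall(data)
    return server('127.0.0.1', 4915, echo)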
1281
lib/beets/util/confit.py
Normal file
File diff suppressed because it is too large
40
lib/beets/util/enumeration.py
Normal file
@@ -0,0 +1,40 @@
|
||||
# This file is part of beets.
|
||||
# Copyright 2013, Adrian Sampson.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files (the
|
||||
# "Software"), to deal in the Software without restriction, including
|
||||
# without limitation the rights to use, copy, modify, merge, publish,
|
||||
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||
# permit persons to whom the Software is furnished to do so, subject to
|
||||
# the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class OrderedEnum(Enum):
|
||||
"""
|
||||
An Enum subclass that allows comparison of members.
|
||||
"""
|
||||
def __ge__(self, other):
|
||||
if self.__class__ is other.__class__:
|
||||
return self.value >= other.value
|
||||
return NotImplemented
|
||||
|
||||
def __gt__(self, other):
|
||||
if self.__class__ is other.__class__:
|
||||
return self.value > other.value
|
||||
return NotImplemented
|
||||
|
||||
def __le__(self, other):
|
||||
if self.__class__ is other.__class__:
|
||||
return self.value <= other.value
|
||||
return NotImplemented
|
||||
|
||||
def __lt__(self, other):
|
||||
if self.__class__ is other.__class__:
|
||||
return self.value < other.value
|
||||
return NotImplemented
|
||||
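# Illustrative sketch (not part of the original file): member values
# define the ordering, and comparisons across different Enum classes
# are refused (NotImplemented).
class _Quality(OrderedEnum):
    LOW = 1
    HIGH = 2

# _Quality.HIGH > _Quality.LOW  evaluates to True.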
571
lib/beets/util/functemplate.py
Normal file
@@ -0,0 +1,571 @@
|
||||
# This file is part of beets.
|
||||
# Copyright 2013, Adrian Sampson.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining
|
||||
# a copy of this software and associated documentation files (the
|
||||
# "Software"), to deal in the Software without restriction, including
|
||||
# without limitation the rights to use, copy, modify, merge, publish,
|
||||
# distribute, sublicense, and/or sell copies of the Software, and to
|
||||
# permit persons to whom the Software is furnished to do so, subject to
|
||||
# the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be
|
||||
# included in all copies or substantial portions of the Software.
|
||||
|
||||
"""This module implements a string formatter based on the standard PEP
|
||||
292 string.Template class extended with function calls. Variables, as
|
||||
with string.Template, are indicated with $ and functions are delimited
|
||||
with %.
|
||||
|
||||
This module assumes that everything is Unicode: the template and the
|
||||
substitution values. Bytestrings are not supported. Also, the templates
|
||||
always behave like the ``safe_substitute`` method in the standard
|
||||
library: unknown symbols are left intact.
|
||||
|
||||
This is sort of like a tiny, horrible degeneration of a real templating
|
||||
engine like Jinja2 or Mustache.
|
||||
"""
|
||||
from __future__ import print_function
|
||||
|
||||
import re
|
||||
import ast
|
||||
import dis
|
||||
import types
|
||||
|
||||
SYMBOL_DELIM = u'$'
|
||||
FUNC_DELIM = u'%'
|
||||
GROUP_OPEN = u'{'
|
||||
GROUP_CLOSE = u'}'
|
||||
ARG_SEP = u','
|
||||
ESCAPE_CHAR = u'$'
|
||||
|
||||
VARIABLE_PREFIX = '__var_'
|
||||
FUNCTION_PREFIX = '__func_'
|
||||
|
||||
|
||||
class Environment(object):
|
||||
"""Contains the values and functions to be substituted into a
|
||||
template.
|
||||
"""
|
||||
def __init__(self, values, functions):
|
||||
self.values = values
|
||||
self.functions = functions
|
||||
|
||||
|
||||
# Code generation helpers.
|
||||
|
||||
def ex_lvalue(name):
|
||||
"""A variable load expression."""
|
||||
return ast.Name(name, ast.Store())
|
||||
|
||||
|
||||
def ex_rvalue(name):
|
||||
"""A variable store expression."""
|
||||
return ast.Name(name, ast.Load())
|
||||
|
||||
|
||||
def ex_literal(val):
|
||||
"""An int, float, long, bool, string, or None literal with the given
|
||||
value.
|
||||
"""
|
||||
if val is None:
|
||||
return ast.Name('None', ast.Load())
|
||||
elif isinstance(val, bool):
# bool must be tested before int: bool is a subclass of int, so
# the Num branch would otherwise capture True and False.
return ast.Name(str(val), ast.Load())
elif isinstance(val, (int, float, long)):
return ast.Num(val)
|
||||
elif isinstance(val, basestring):
|
||||
return ast.Str(val)
|
||||
raise TypeError('no literal for {0}'.format(type(val)))
|
||||
|
||||
|
||||
def ex_varassign(name, expr):
|
||||
"""Assign an expression into a single variable. The expression may
|
||||
either be an `ast.expr` object or a value to be used as a literal.
|
||||
"""
|
||||
if not isinstance(expr, ast.expr):
|
||||
expr = ex_literal(expr)
|
||||
return ast.Assign([ex_lvalue(name)], expr)
|
||||
|
||||
|
||||
def ex_call(func, args):
|
||||
"""A function-call expression with only positional parameters. The
|
||||
function may be an expression or the name of a function. Each
|
||||
argument may be an expression or a value to be used as a literal.
|
||||
"""
|
||||
if isinstance(func, basestring):
|
||||
func = ex_rvalue(func)
|
||||
|
||||
args = list(args)
|
||||
for i in range(len(args)):
|
||||
if not isinstance(args[i], ast.expr):
|
||||
args[i] = ex_literal(args[i])
|
||||
|
||||
return ast.Call(func, args, [], None, None)
|
||||
|
||||
|
||||
def compile_func(arg_names, statements, name='_the_func', debug=False):
|
||||
"""Compile a list of statements as the body of a function and return
|
||||
the resulting Python function. If `debug`, then print out the
|
||||
bytecode of the compiled function.
|
||||
"""
|
||||
func_def = ast.FunctionDef(
|
||||
name,
|
||||
ast.arguments(
|
||||
[ast.Name(n, ast.Param()) for n in arg_names],
|
||||
None, None,
|
||||
[ex_literal(None) for _ in arg_names],
|
||||
),
|
||||
statements,
|
||||
[],
|
||||
)
|
||||
mod = ast.Module([func_def])
|
||||
ast.fix_missing_locations(mod)
|
||||
|
||||
prog = compile(mod, '<generated>', 'exec')
|
||||
|
||||
# Debug: show bytecode.
|
||||
if debug:
|
||||
dis.dis(prog)
|
||||
for const in prog.co_consts:
|
||||
if isinstance(const, types.CodeType):
|
||||
dis.dis(const)
|
||||
|
||||
the_locals = {}
|
||||
exec prog in {}, the_locals
|
||||
return the_locals[name]
|
||||
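# Illustrative sketch (not part of the original file): compiling the
# identity function, `def _the_func(x): return x`, from the AST
# helpers above. All arguments default to None, so keyword calls are
# safest.
def _compile_func_example():
    func = compile_func(['x'], [ast.Return(ex_rvalue('x'))])
    return func(x=42)  # -> 42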
|
||||
|
||||
# AST nodes for the template language.
|
||||
|
||||
class Symbol(object):
|
||||
"""A variable-substitution symbol in a template."""
|
||||
def __init__(self, ident, original):
|
||||
self.ident = ident
|
||||
self.original = original
|
||||
|
||||
def __repr__(self):
|
||||
return u'Symbol(%s)' % repr(self.ident)
|
||||
|
||||
def evaluate(self, env):
|
||||
"""Evaluate the symbol in the environment, returning a Unicode
|
||||
string.
|
||||
"""
|
||||
if self.ident in env.values:
|
||||
# Substitute for a value.
|
||||
return env.values[self.ident]
|
||||
else:
|
||||
# Keep original text.
|
||||
return self.original
|
||||
|
||||
def translate(self):
|
||||
"""Compile the variable lookup."""
|
||||
expr = ex_rvalue(VARIABLE_PREFIX + self.ident.encode('utf8'))
|
||||
return [expr], set([self.ident.encode('utf8')]), set()
|
||||
|
||||
|
||||
class Call(object):
|
||||
"""A function call in a template."""
|
||||
def __init__(self, ident, args, original):
|
||||
self.ident = ident
|
||||
self.args = args
|
||||
self.original = original
|
||||
|
||||
def __repr__(self):
|
||||
return u'Call(%s, %s, %s)' % (repr(self.ident), repr(self.args),
|
||||
repr(self.original))
|
||||
|
||||
def evaluate(self, env):
|
||||
"""Evaluate the function call in the environment, returning a
|
||||
Unicode string.
|
||||
"""
|
||||
if self.ident in env.functions:
|
||||
arg_vals = [expr.evaluate(env) for expr in self.args]
|
||||
try:
|
||||
out = env.functions[self.ident](*arg_vals)
|
||||
except Exception as exc:
|
||||
# Function raised exception! Maybe inlining the name of
|
||||
# the exception will help debug.
|
||||
return u'<%s>' % unicode(exc)
|
||||
return unicode(out)
|
||||
else:
|
||||
return self.original
|
||||
|
||||
def translate(self):
|
||||
"""Compile the function call."""
|
||||
varnames = set()
|
||||
funcnames = set([self.ident.encode('utf8')])
|
||||
|
||||
arg_exprs = []
|
||||
for arg in self.args:
|
||||
subexprs, subvars, subfuncs = arg.translate()
|
||||
varnames.update(subvars)
|
||||
funcnames.update(subfuncs)
|
||||
|
||||
# Create a subexpression that joins the result components of
|
||||
# the arguments.
|
||||
arg_exprs.append(ex_call(
|
||||
ast.Attribute(ex_literal(u''), 'join', ast.Load()),
|
||||
[ex_call(
|
||||
'map',
|
||||
[
|
||||
ex_rvalue('unicode'),
|
||||
ast.List(subexprs, ast.Load()),
|
||||
]
|
||||
)],
|
||||
))
|
||||
|
||||
subexpr_call = ex_call(
|
||||
FUNCTION_PREFIX + self.ident.encode('utf8'),
|
||||
arg_exprs
|
||||
)
|
||||
return [subexpr_call], varnames, funcnames
|
||||
|
||||
|
||||
class Expression(object):
|
||||
"""Top-level template construct: contains a list of text blobs,
|
||||
Symbols, and Calls.
|
||||
"""
|
||||
def __init__(self, parts):
|
||||
self.parts = parts
|
||||
|
||||
def __repr__(self):
|
||||
return u'Expression(%s)' % (repr(self.parts))
|
||||
|
||||
def evaluate(self, env):
|
||||
"""Evaluate the entire expression in the environment, returning
|
||||
a Unicode string.
|
||||
"""
|
||||
out = []
|
||||
for part in self.parts:
|
||||
if isinstance(part, basestring):
|
||||
out.append(part)
|
||||
else:
|
||||
out.append(part.evaluate(env))
|
||||
return u''.join(map(unicode, out))
|
||||
|
||||
def translate(self):
|
||||
"""Compile the expression to a list of Python AST expressions, a
|
||||
set of variable names used, and a set of function names.
|
||||
"""
|
||||
expressions = []
|
||||
varnames = set()
|
||||
funcnames = set()
|
||||
for part in self.parts:
|
||||
if isinstance(part, basestring):
|
||||
expressions.append(ex_literal(part))
|
||||
else:
|
||||
e, v, f = part.translate()
|
||||
expressions.extend(e)
|
||||
varnames.update(v)
|
||||
funcnames.update(f)
|
||||
return expressions, varnames, funcnames
|
||||
|
||||
|
||||
# Parser.
|
||||
|
||||
class ParseError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Parser(object):
|
||||
"""Parses a template expression string. Instantiate the class with
|
||||
the template source and call ``parse_expression``. The ``pos`` field
|
||||
will indicate the character after the expression finished and
|
||||
``parts`` will contain a list of Unicode strings, Symbols, and Calls
|
||||
reflecting the concatenated portions of the expression.
|
||||
|
||||
This is a terrible, ad-hoc parser implementation based on a
|
||||
left-to-right scan with no lexing step to speak of; it's probably
|
||||
both inefficient and incorrect. Maybe this should eventually be
|
||||
replaced with a real, accepted parsing technique (PEG, parser
|
||||
generator, etc.).
|
||||
"""
|
||||
def __init__(self, string):
|
||||
self.string = string
|
||||
self.pos = 0
|
||||
self.parts = []
|
||||
|
||||
# Common parsing resources.
|
||||
special_chars = (SYMBOL_DELIM, FUNC_DELIM, GROUP_OPEN, GROUP_CLOSE,
|
||||
ARG_SEP, ESCAPE_CHAR)
|
||||
special_char_re = re.compile(ur'[%s]|$' %
|
||||
u''.join(re.escape(c) for c in special_chars))
|
||||
|
||||
def parse_expression(self):
|
||||
"""Parse a template expression starting at ``pos``. Resulting
|
||||
components (Unicode strings, Symbols, and Calls) are added to
|
||||
the ``parts`` field, a list. The ``pos`` field is updated to be
|
||||
the next character after the expression.
|
||||
"""
|
||||
text_parts = []
|
||||
|
||||
while self.pos < len(self.string):
|
||||
char = self.string[self.pos]
|
||||
|
||||
if char not in self.special_chars:
|
||||
# A non-special character. Skip to the next special
|
||||
# character, treating the interstice as literal text.
|
||||
next_pos = (
|
||||
self.special_char_re.search(self.string[self.pos:]).start()
|
||||
+ self.pos
|
||||
)
|
||||
text_parts.append(self.string[self.pos:next_pos])
|
||||
self.pos = next_pos
|
||||
continue
|
||||
|
||||
if self.pos == len(self.string) - 1:
|
||||
# The last character can never begin a structure, so we
|
||||
# just interpret it as a literal character (unless it
|
||||
# terminates the expression, as with , and }).
|
||||
if char not in (GROUP_CLOSE, ARG_SEP):
|
||||
text_parts.append(char)
|
||||
self.pos += 1
|
||||
break
|
||||
|
||||
next_char = self.string[self.pos + 1]
|
||||
if char == ESCAPE_CHAR and next_char in \
|
||||
(SYMBOL_DELIM, FUNC_DELIM, GROUP_CLOSE, ARG_SEP):
|
||||
# An escaped special character ($$, $}, etc.). Note that
|
||||
# ${ is not an escape sequence: this is ambiguous with
|
||||
# the start of a symbol and it's not necessary (just
|
||||
# using { suffices in all cases).
|
||||
text_parts.append(next_char)
|
||||
self.pos += 2 # Skip the next character.
|
||||
continue
|
||||
|
||||
# Shift all characters collected so far into a single string.
|
||||
if text_parts:
|
||||
self.parts.append(u''.join(text_parts))
|
||||
text_parts = []
|
||||
|
||||
if char == SYMBOL_DELIM:
|
||||
# Parse a symbol.
|
||||
self.parse_symbol()
|
||||
elif char == FUNC_DELIM:
|
||||
# Parse a function call.
|
||||
self.parse_call()
|
||||
elif char in (GROUP_CLOSE, ARG_SEP):
|
||||
# Template terminated.
|
||||
break
|
||||
elif char == GROUP_OPEN:
|
||||
# Start of a group has no meaning here; just pass
|
||||
# through the character.
|
||||
text_parts.append(char)
|
||||
self.pos += 1
|
||||
else:
|
||||
assert False
|
||||
|
||||
# If any parsed characters remain, shift them into a string.
|
||||
if text_parts:
|
||||
self.parts.append(u''.join(text_parts))
|
||||
|
||||
def parse_symbol(self):
|
||||
"""Parse a variable reference (like ``$foo`` or ``${foo}``)
|
||||
starting at ``pos``. Possibly appends a Symbol object (or,
|
||||
failing that, text) to the ``parts`` field and updates ``pos``.
|
||||
The character at ``pos`` must, as a precondition, be ``$``.
|
||||
"""
|
||||
assert self.pos < len(self.string)
|
||||
assert self.string[self.pos] == SYMBOL_DELIM
|
||||
|
||||
if self.pos == len(self.string) - 1:
|
||||
# Last character.
|
||||
self.parts.append(SYMBOL_DELIM)
|
||||
self.pos += 1
|
||||
return
|
||||
|
||||
next_char = self.string[self.pos + 1]
|
||||
start_pos = self.pos
|
||||
self.pos += 1
|
||||
|
||||
if next_char == GROUP_OPEN:
|
||||
# A symbol like ${this}.
|
||||
self.pos += 1 # Skip opening.
|
||||
closer = self.string.find(GROUP_CLOSE, self.pos)
|
||||
if closer == -1 or closer == self.pos:
|
||||
# No closing brace found or identifier is empty.
|
||||
self.parts.append(self.string[start_pos:self.pos])
|
||||
else:
|
||||
# Closer found.
|
||||
ident = self.string[self.pos:closer]
|
||||
self.pos = closer + 1
|
||||
self.parts.append(Symbol(ident,
|
||||
self.string[start_pos:self.pos]))
|
||||
|
||||
else:
|
||||
# A bare-word symbol.
|
||||
ident = self._parse_ident()
|
||||
if ident:
|
||||
# Found a real symbol.
|
||||
self.parts.append(Symbol(ident,
|
||||
self.string[start_pos:self.pos]))
|
||||
else:
|
||||
# A standalone $.
|
||||
self.parts.append(SYMBOL_DELIM)
|
||||
|
||||
def parse_call(self):
|
||||
"""Parse a function call (like ``%foo{bar,baz}``) starting at
|
||||
``pos``. Possibly appends a Call object to ``parts`` and updates
|
||||
``pos``. The character at ``pos`` must be ``%``.
|
||||
"""
|
||||
assert self.pos < len(self.string)
|
||||
assert self.string[self.pos] == FUNC_DELIM
|
||||
|
||||
start_pos = self.pos
|
||||
self.pos += 1
|
||||
|
||||
ident = self._parse_ident()
|
||||
if not ident:
|
||||
# No function name.
|
||||
self.parts.append(FUNC_DELIM)
|
||||
return
|
||||
|
||||
if self.pos >= len(self.string):
|
||||
# Identifier terminates string.
|
||||
self.parts.append(self.string[start_pos:self.pos])
|
||||
return
|
||||
|
||||
if self.string[self.pos] != GROUP_OPEN:
|
||||
# Argument list not opened.
|
||||
self.parts.append(self.string[start_pos:self.pos])
|
||||
return
|
||||
|
||||
# Skip past opening brace and try to parse an argument list.
|
||||
self.pos += 1
|
||||
args = self.parse_argument_list()
|
||||
if self.pos >= len(self.string) or \
|
||||
self.string[self.pos] != GROUP_CLOSE:
|
||||
# Arguments unclosed.
|
||||
self.parts.append(self.string[start_pos:self.pos])
|
||||
return
|
||||
|
||||
self.pos += 1 # Move past closing brace.
|
||||
self.parts.append(Call(ident, args, self.string[start_pos:self.pos]))
|
||||
|
||||
def parse_argument_list(self):
|
||||
"""Parse a list of arguments starting at ``pos``, returning a
|
||||
list of Expression objects. Does not modify ``parts``. Should
|
||||
leave ``pos`` pointing to a } character or the end of the
|
||||
string.
|
||||
"""
|
||||
# Try to parse a subexpression in a subparser.
|
||||
expressions = []
|
||||
|
||||
while self.pos < len(self.string):
|
||||
subparser = Parser(self.string[self.pos:])
|
||||
subparser.parse_expression()
|
||||
|
||||
# Extract and advance past the parsed expression.
|
||||
expressions.append(Expression(subparser.parts))
|
||||
self.pos += subparser.pos
|
||||
|
||||
if self.pos >= len(self.string) or \
|
||||
self.string[self.pos] == GROUP_CLOSE:
|
||||
# Argument list terminated by EOF or closing brace.
|
||||
break
|
||||
|
||||
# Only other way to terminate an expression is with ,.
|
||||
# Continue to the next argument.
|
||||
assert self.string[self.pos] == ARG_SEP
|
||||
self.pos += 1
|
||||
|
||||
return expressions
|
||||
|
||||
def _parse_ident(self):
|
||||
"""Parse an identifier and return it (possibly an empty string).
|
||||
Updates ``pos``.
|
||||
"""
|
||||
remainder = self.string[self.pos:]
|
||||
ident = re.match(ur'\w*', remainder).group(0)
|
||||
self.pos += len(ident)
|
||||
return ident
|
||||
|
||||
|
||||
def _parse(template):
|
||||
"""Parse a top-level template string Expression. Any extraneous text
|
||||
is considered literal text.
|
||||
"""
|
||||
parser = Parser(template)
|
||||
parser.parse_expression()
|
||||
|
||||
parts = parser.parts
|
||||
remainder = parser.string[parser.pos:]
|
||||
if remainder:
|
||||
parts.append(remainder)
|
||||
return Expression(parts)
|
||||
|
||||
|
||||
# External interface.
|
||||
|
||||
class Template(object):
|
||||
"""A string template, including text, Symbols, and Calls.
|
||||
"""
|
||||
def __init__(self, template):
|
||||
self.expr = _parse(template)
|
||||
self.original = template
|
||||
self.compiled = self.translate()
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.original == other.original
|
||||
|
||||
def interpret(self, values={}, functions={}):
|
||||
"""Like `substitute`, but forces the interpreter (rather than
|
||||
the compiled version) to be used. The interpreter includes
|
||||
exception-handling code for missing variables and buggy template
|
||||
functions but is much slower.
|
||||
"""
|
||||
return self.expr.evaluate(Environment(values, functions))
|
||||
|
||||
def substitute(self, values={}, functions={}):
|
||||
"""Evaluate the template given the values and functions.
|
||||
"""
|
||||
try:
|
||||
res = self.compiled(values, functions)
|
||||
except: # Handle any exceptions thrown by compiled version.
|
||||
res = self.interpret(values, functions)
|
||||
return res
|
||||
|
||||
def translate(self):
|
||||
"""Compile the template to a Python function."""
|
||||
expressions, varnames, funcnames = self.expr.translate()
|
||||
|
||||
argnames = []
|
||||
for varname in varnames:
|
||||
argnames.append(VARIABLE_PREFIX.encode('utf8') + varname)
|
||||
for funcname in funcnames:
|
||||
argnames.append(FUNCTION_PREFIX.encode('utf8') + funcname)
|
||||
|
||||
func = compile_func(
|
||||
argnames,
|
||||
[ast.Return(ast.List(expressions, ast.Load()))],
|
||||
)
|
||||
|
||||
def wrapper_func(values={}, functions={}):
|
||||
args = {}
|
||||
for varname in varnames:
|
||||
args[VARIABLE_PREFIX + varname] = values[varname]
|
||||
for funcname in funcnames:
|
||||
args[FUNCTION_PREFIX + funcname] = functions[funcname]
|
||||
parts = func(**args)
|
||||
return u''.join(parts)
|
||||
|
||||
return wrapper_func
|
||||
|
||||
|
||||
# Performance tests.
|
||||
|
||||
if __name__ == '__main__':
|
||||
import timeit
|
||||
_tmpl = Template(u'foo $bar %baz{foozle $bar barzle} $bar')
|
||||
_vars = {'bar': 'qux'}
|
||||
_funcs = {'baz': unicode.upper}
|
||||
interp_time = timeit.timeit('_tmpl.interpret(_vars, _funcs)',
|
||||
'from __main__ import _tmpl, _vars, _funcs',
|
||||
number=10000)
|
||||
print(interp_time)
|
||||
comp_time = timeit.timeit('_tmpl.substitute(_vars, _funcs)',
|
||||
'from __main__ import _tmpl, _vars, _funcs',
|
||||
number=10000)
|
||||
print(comp_time)
|
||||
print('Speedup:', interp_time / comp_time)
|
||||
517
lib/beets/util/pipeline.py
Normal file
@@ -0,0 +1,517 @@
# This file is part of beets.
# Copyright 2013, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.

"""Simple but robust implementation of generator/coroutine-based
pipelines in Python. The pipelines may be run either sequentially
(single-threaded) or in parallel (one thread per pipeline stage).

This implementation supports pipeline bubbles (indications that the
processing for a certain item should abort). To use them, yield the
BUBBLE constant from any stage coroutine except the last.

In the parallel case, the implementation transparently handles thread
shutdown when the processing is complete and when a stage raises an
exception. KeyboardInterrupts (^C) are also handled.

When running a parallel pipeline, it is also possible to use
multiple coroutines for the same pipeline stage; this lets you speed
up a bottleneck stage by dividing its work among multiple threads.
To do so, pass an iterable of coroutines to the Pipeline constructor
in place of any single coroutine.
"""
from __future__ import print_function

import Queue
from threading import Thread, Lock
import sys

BUBBLE = '__PIPELINE_BUBBLE__'
POISON = '__PIPELINE_POISON__'

DEFAULT_QUEUE_SIZE = 16


def _invalidate_queue(q, val=None, sync=True):
    """Breaks a Queue such that it never blocks, always has size 1,
    and has no maximum size. get()ing from the queue returns `val`,
    which defaults to None. `sync` controls whether a lock is
    required (because it's not reentrant!).
    """
    def _qsize(len=len):
        return 1

    def _put(item):
        pass

    def _get():
        return val

    if sync:
        q.mutex.acquire()

    try:
        q.maxsize = 0
        q._qsize = _qsize
        q._put = _put
        q._get = _get
        q.not_empty.notifyAll()
        q.not_full.notifyAll()

    finally:
        if sync:
            q.mutex.release()


class CountedQueue(Queue.Queue):
    """A queue that keeps track of the number of threads that are
    still feeding into it. The queue is poisoned when all threads are
    finished with the queue.
    """
    def __init__(self, maxsize=0):
        Queue.Queue.__init__(self, maxsize)
        self.nthreads = 0
        self.poisoned = False

    def acquire(self):
        """Indicate that a thread will start putting into this queue.
        Should not be called after the queue is already poisoned.
        """
        with self.mutex:
            assert not self.poisoned
            assert self.nthreads >= 0
            self.nthreads += 1

    def release(self):
        """Indicate that a thread that was putting into this queue has
        exited. If this is the last thread using the queue, the queue
        is poisoned.
        """
        with self.mutex:
            self.nthreads -= 1
            assert self.nthreads >= 0
            if self.nthreads == 0:
                # All threads are done adding to this queue. Poison it
                # when it becomes empty.
                self.poisoned = True

                # Replacement _get invalidates when no items remain.
                _old_get = self._get

                def _get():
                    out = _old_get()
                    if not self.queue:
                        _invalidate_queue(self, POISON, False)
                    return out

                if self.queue:
                    # Items remain.
                    self._get = _get
                else:
                    # No items. Invalidate immediately.
                    _invalidate_queue(self, POISON, False)


class MultiMessage(object):
    """A message yielded by a pipeline stage encapsulating multiple
    values to be sent to the next stage.
    """
    def __init__(self, messages):
        self.messages = messages


def multiple(messages):
    """Yield multiple([message, ..]) from a pipeline stage to send
    multiple values to the next pipeline stage.
    """
    return MultiMessage(messages)
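

# For illustration only (not part of the upstream module): a stage that
# fans a single message out into several by yielding multiple().
#
#     def split_words():
#         line = yield
#         while True:
#             line = yield multiple(line.split())
#
# Each word is then delivered to the next stage as a separate message.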


def stage(func):
    """Decorate a function to become a simple stage.

    >>> @stage
    ... def add(n, i):
    ...     return i + n
    >>> pipe = Pipeline([
    ...     iter([1, 2, 3]),
    ...     add(2),
    ... ])
    >>> list(pipe.pull())
    [3, 4, 5]
    """

    def coro(*args):
        task = None
        while True:
            task = yield task
            task = func(*(args + (task,)))
    return coro


def mutator_stage(func):
    """Decorate a function that manipulates items in a coroutine to
    become a simple stage.

    >>> @mutator_stage
    ... def setkey(key, item):
    ...     item[key] = True
    >>> pipe = Pipeline([
    ...     iter([{'x': False}, {'a': False}]),
    ...     setkey('x'),
    ... ])
    >>> list(pipe.pull())
    [{'x': True}, {'a': False, 'x': True}]
    """

    def coro(*args):
        task = None
        while True:
            task = yield task
            func(*(args + (task,)))
    return coro


def _allmsgs(obj):
    """Returns a list of all the messages encapsulated in obj. If obj
    is a MultiMessage, returns its enclosed messages. If obj is BUBBLE,
    returns an empty list. Otherwise, returns a list containing obj.
    """
    if isinstance(obj, MultiMessage):
        return obj.messages
    elif obj == BUBBLE:
        return []
    else:
        return [obj]


class PipelineThread(Thread):
    """Abstract base class for pipeline-stage threads."""
    def __init__(self, all_threads):
        super(PipelineThread, self).__init__()
        self.abort_lock = Lock()
        self.abort_flag = False
        self.all_threads = all_threads
        self.exc_info = None

    def abort(self):
        """Shut down the thread at the next chance possible.
        """
        with self.abort_lock:
            self.abort_flag = True

        # Ensure that we are not blocking on a queue read or write.
        if hasattr(self, 'in_queue'):
            _invalidate_queue(self.in_queue, POISON)
        if hasattr(self, 'out_queue'):
            _invalidate_queue(self.out_queue, POISON)

    def abort_all(self, exc_info):
        """Abort all other threads in the system for an exception.
        """
        self.exc_info = exc_info
        for thread in self.all_threads:
            thread.abort()


class FirstPipelineThread(PipelineThread):
    """The thread running the first stage in a parallel pipeline setup.
    The coroutine should just be a generator.
    """
    def __init__(self, coro, out_queue, all_threads):
        super(FirstPipelineThread, self).__init__(all_threads)
        self.coro = coro
        self.out_queue = out_queue
        self.out_queue.acquire()

        self.abort_lock = Lock()
        self.abort_flag = False

    def run(self):
        try:
            while True:
                with self.abort_lock:
                    if self.abort_flag:
                        return

                # Get the value from the generator.
                try:
                    msg = self.coro.next()
                except StopIteration:
                    break

                # Send messages to the next stage.
                for msg in _allmsgs(msg):
                    with self.abort_lock:
                        if self.abort_flag:
                            return
                    self.out_queue.put(msg)

        except:
            self.abort_all(sys.exc_info())
            return

        # Generator finished; shut down the pipeline.
        self.out_queue.release()


class MiddlePipelineThread(PipelineThread):
    """A thread running any stage in the pipeline except the first or
    last.
    """
    def __init__(self, coro, in_queue, out_queue, all_threads):
        super(MiddlePipelineThread, self).__init__(all_threads)
        self.coro = coro
        self.in_queue = in_queue
        self.out_queue = out_queue
        self.out_queue.acquire()

    def run(self):
        try:
            # Prime the coroutine.
            self.coro.next()

            while True:
                with self.abort_lock:
                    if self.abort_flag:
                        return

                # Get the message from the previous stage.
                msg = self.in_queue.get()
                if msg is POISON:
                    break

                with self.abort_lock:
                    if self.abort_flag:
                        return

                # Invoke the current stage.
                out = self.coro.send(msg)

                # Send messages to next stage.
                for msg in _allmsgs(out):
                    with self.abort_lock:
                        if self.abort_flag:
                            return
                    self.out_queue.put(msg)

        except:
            self.abort_all(sys.exc_info())
            return

        # Pipeline is shutting down normally.
        self.out_queue.release()


class LastPipelineThread(PipelineThread):
    """A thread running the last stage in a pipeline. The coroutine
    should yield nothing.
    """
    def __init__(self, coro, in_queue, all_threads):
        super(LastPipelineThread, self).__init__(all_threads)
        self.coro = coro
        self.in_queue = in_queue

    def run(self):
        # Prime the coroutine.
        self.coro.next()

        try:
            while True:
                with self.abort_lock:
                    if self.abort_flag:
                        return

                # Get the message from the previous stage.
                msg = self.in_queue.get()
                if msg is POISON:
                    break

                with self.abort_lock:
                    if self.abort_flag:
                        return

                # Send to consumer.
                self.coro.send(msg)

        except:
            self.abort_all(sys.exc_info())
            return


class Pipeline(object):
    """Represents a staged pattern of work. Each stage in the pipeline
    is a coroutine that receives messages from the previous stage and
    yields messages to be sent to the next stage.
    """
    def __init__(self, stages):
        """Makes a new pipeline from a list of coroutines. There must
        be at least two stages.
        """
        if len(stages) < 2:
            raise ValueError('pipeline must have at least two stages')
        self.stages = []
        for stage in stages:
            if isinstance(stage, (list, tuple)):
                self.stages.append(stage)
            else:
                # Default to one thread per stage.
                self.stages.append((stage,))

    def run_sequential(self):
        """Run the pipeline sequentially in the current thread. The
        stages are run one after the other. Only the first coroutine
        in each stage is used.
        """
        list(self.pull())

    def run_parallel(self, queue_size=DEFAULT_QUEUE_SIZE):
        """Run the pipeline in parallel using one thread per stage. The
        messages between the stages are stored in queues of the given
        size.
        """
        queue_count = len(self.stages) - 1
        queues = [CountedQueue(queue_size) for i in range(queue_count)]
        threads = []

        # Set up first stage.
        for coro in self.stages[0]:
            threads.append(FirstPipelineThread(coro, queues[0], threads))

        # Middle stages.
        for i in range(1, queue_count):
            for coro in self.stages[i]:
                threads.append(MiddlePipelineThread(
                    coro, queues[i - 1], queues[i], threads
                ))

        # Last stage.
        for coro in self.stages[-1]:
            threads.append(
                LastPipelineThread(coro, queues[-1], threads)
            )

        # Start threads.
        for thread in threads:
            thread.start()

        # Wait for termination. The final thread lasts the longest.
        try:
            # Using a timeout allows us to receive KeyboardInterrupt
            # exceptions during the join().
            while threads[-1].isAlive():
                threads[-1].join(1)

        except:
            # Stop all the threads immediately.
            for thread in threads:
                thread.abort()
            raise

        finally:
            # Make completely sure that all the threads have finished
            # before we return. They should already be either finished,
            # in normal operation, or aborted, in case of an exception.
            for thread in threads[:-1]:
                thread.join()

        for thread in threads:
            exc_info = thread.exc_info
            if exc_info:
                # Make the exception appear as it was raised originally.
                raise exc_info[0], exc_info[1], exc_info[2]

    def pull(self):
        """Yield elements from the end of the pipeline. Runs the stages
        sequentially until the last yields some messages. Each of the
        messages is then yielded by ``pulled.next()``. If the pipeline
        has a consumer, that is, the last stage does not yield any
        messages, then pull will not yield any messages. Only the first
        coroutine in each stage is used.
        """
        coros = [stage[0] for stage in self.stages]

        # "Prime" the coroutines.
        for coro in coros[1:]:
            coro.next()

        # Begin the pipeline.
        for out in coros[0]:
            msgs = _allmsgs(out)
            for coro in coros[1:]:
                next_msgs = []
                for msg in msgs:
                    out = coro.send(msg)
                    next_msgs.extend(_allmsgs(out))
                msgs = next_msgs
            for msg in msgs:
                yield msg


# Smoke test.
if __name__ == '__main__':
    import time

    # Test a normally-terminating pipeline both in sequence and
    # in parallel.
    def produce():
        for i in range(5):
            print('generating %i' % i)
            time.sleep(1)
            yield i

    def work():
        num = yield
        while True:
            print('processing %i' % num)
            time.sleep(2)
            num = yield num * 2

    def consume():
        while True:
            num = yield
            time.sleep(1)
            print('received %i' % num)

    ts_start = time.time()
    Pipeline([produce(), work(), consume()]).run_sequential()
    ts_seq = time.time()
    Pipeline([produce(), work(), consume()]).run_parallel()
    ts_par = time.time()
    Pipeline([produce(), (work(), work()), consume()]).run_parallel()
    ts_end = time.time()
    print('Sequential time:', ts_seq - ts_start)
    print('Parallel time:', ts_par - ts_seq)
    print('Multiply-parallel time:', ts_end - ts_par)
    print()
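
    # A small additional check, not in the upstream smoke test: BUBBLE
    # drops an item mid-pipeline and multiple() fans one item out into
    # several, as described in the module docstring above.
    def drop_odd():
        num = yield
        while True:
            num = yield (BUBBLE if num % 2 else num)

    def dupe():
        num = yield
        while True:
            num = yield multiple([num, num])

    # Odd numbers are bubbled away; each even number arrives twice,
    # so this prints [0, 0, 2, 2].
    print(list(Pipeline([iter(range(4)), drop_odd(), dupe()]).pull()))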

    # Test a pipeline that raises an exception.
    def exc_produce():
        for i in range(10):
            print('generating %i' % i)
            time.sleep(1)
            yield i

    def exc_work():
        num = yield
        while True:
            print('processing %i' % num)
            time.sleep(3)
            if num == 3:
                raise Exception()
            num = yield num * 2

    def exc_consume():
        while True:
            num = yield
            print('received %i' % num)

    Pipeline([exc_produce(), exc_work(), exc_consume()]).run_parallel(1)
50
lib/beets/vfs.py
Normal file
@@ -0,0 +1,50 @@
# This file is part of beets.
# Copyright 2013, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.

"""A simple utility for constructing filesystem-like trees from beets
libraries.
"""
from collections import namedtuple
from beets import util

Node = namedtuple('Node', ['files', 'dirs'])


def _insert(node, path, itemid):
    """Insert an item into a virtual filesystem node."""
    if len(path) == 1:
        # Last component. Insert file.
        node.files[path[0]] = itemid
    else:
        # In a directory.
        dirname = path[0]
        rest = path[1:]
        if dirname not in node.dirs:
            node.dirs[dirname] = Node({}, {})
        _insert(node.dirs[dirname], rest, itemid)


def libtree(lib):
    """Generates a filesystem-like directory tree for the files
    contained in `lib`. Filesystem nodes are (files, dirs) named
    tuples in which both components are dictionaries. The first
    maps filenames to Item ids. The second maps directory names to
    child node tuples.
    """
    root = Node({}, {})
    for item in lib.items():
        dest = item.destination(fragment=True)
        parts = util.components(dest)
        _insert(root, parts, item.id)
    return root
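

# For illustration (not in the original file): for a hypothetical library
# containing one item with id 1 whose destination is 'A/B/song.mp3',
# libtree() would return:
#
#     Node(files={}, dirs={'A': Node(files={}, dirs={
#         'B': Node(files={'song.mp3': 1}, dirs={})})})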
19
lib/beetsplug/__init__.py
Normal file
@@ -0,0 +1,19 @@
# This file is part of beets.
# Copyright 2013, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.

"""A namespace package for beets plugins."""

# Make this a namespace package.
from pkgutil import extend_path
__path__ = extend_path(__path__, __name__)
280
lib/beetsplug/embedart.py
Normal file
@@ -0,0 +1,280 @@
# This file is part of beets.
# Copyright 2014, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.

"""Allows beets to embed album art into file metadata."""
import os.path
import logging
import imghdr
import subprocess
import platform
from tempfile import NamedTemporaryFile

from beets.plugins import BeetsPlugin
from beets import mediafile
from beets import ui
from beets.ui import decargs
from beets.util import syspath, normpath, displayable_path
from beets.util.artresizer import ArtResizer
from beets import config


log = logging.getLogger('beets')


class EmbedCoverArtPlugin(BeetsPlugin):
    """Allows album art to be embedded into the actual files.
    """
    def __init__(self):
        super(EmbedCoverArtPlugin, self).__init__()
        self.config.add({
            'maxwidth': 0,
            'auto': True,
            'compare_threshold': 0,
            'ifempty': False,
        })

        if self.config['maxwidth'].get(int) and not ArtResizer.shared.local:
            self.config['maxwidth'] = 0
            log.warn(u"embedart: ImageMagick or PIL not found; "
                     u"'maxwidth' option ignored")
        if self.config['compare_threshold'].get(int) and not \
                ArtResizer.shared.can_compare:
            self.config['compare_threshold'] = 0
            log.warn(u"embedart: ImageMagick 6.8.7 or higher not installed; "
                     u"'compare_threshold' option ignored")

    def commands(self):
        # Embed command.
        embed_cmd = ui.Subcommand(
            'embedart', help='embed image files into file metadata'
        )
        embed_cmd.parser.add_option(
            '-f', '--file', metavar='PATH', help='the image file to embed'
        )
        maxwidth = config['embedart']['maxwidth'].get(int)
        compare_threshold = config['embedart']['compare_threshold'].get(int)
        ifempty = config['embedart']['ifempty'].get(bool)

        def embed_func(lib, opts, args):
            if opts.file:
                imagepath = normpath(opts.file)
                for item in lib.items(decargs(args)):
                    embed_item(item, imagepath, maxwidth, None,
                               compare_threshold, ifempty)
            else:
                for album in lib.albums(decargs(args)):
                    embed_album(album, maxwidth)

        embed_cmd.func = embed_func

        # Extract command.
        extract_cmd = ui.Subcommand('extractart',
                                    help='extract an image from file metadata')
        extract_cmd.parser.add_option('-o', dest='outpath',
                                      help='image output file')

        def extract_func(lib, opts, args):
            outpath = normpath(opts.outpath or 'cover')
            item = lib.items(decargs(args)).get()
            extract(outpath, item)
        extract_cmd.func = extract_func

        # Clear command.
        clear_cmd = ui.Subcommand('clearart',
                                  help='remove images from file metadata')

        def clear_func(lib, opts, args):
            clear(lib, decargs(args))
        clear_cmd.func = clear_func

        return [embed_cmd, extract_cmd, clear_cmd]
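

# A sketch of the matching config.yaml section; the option names come from
# the config.add() defaults above, while the values here are only examples:
#
#     embedart:
#         auto: yes
#         maxwidth: 600
#         compare_threshold: 20
#         ifempty: no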


@EmbedCoverArtPlugin.listen('album_imported')
def album_imported(lib, album):
    """Automatically embed art into imported albums.
    """
    if album.artpath and config['embedart']['auto']:
        embed_album(album, config['embedart']['maxwidth'].get(int), True)


def embed_item(item, imagepath, maxwidth=None, itempath=None,
               compare_threshold=0, ifempty=False, as_album=False):
    """Embed an image into the item's media file.
    """
    if compare_threshold:
        if not check_art_similarity(item, imagepath, compare_threshold):
            log.warn(u'Image not similar; skipping.')
            return
    if ifempty:
        art = get_art(item)
        if art:
            log.debug(u'embedart: media file already contains art; '
                      u'skipping {0}'.format(displayable_path(imagepath)))
            return
    if maxwidth and not as_album:
        imagepath = resize_image(imagepath, maxwidth)

    try:
        log.debug(u'embedart: embedding {0}'.format(
            displayable_path(imagepath)
        ))
        item['images'] = [_mediafile_image(imagepath, maxwidth)]
    except IOError as exc:
        log.error(u'embedart: could not read image file: {0}'.format(exc))
    else:
        # We don't want to store the image in the database.
        item.try_write(itempath)
        del item['images']


def embed_album(album, maxwidth=None, quiet=False):
    """Embed album art into all of the album's items.
    """
    imagepath = album.artpath
    if not imagepath:
        log.info(u'No album art present: {0} - {1}'.
                 format(album.albumartist, album.album))
        return
    if not os.path.isfile(syspath(imagepath)):
        log.error(u'Album art not found at {0}'
                  .format(displayable_path(imagepath)))
        return
    if maxwidth:
        imagepath = resize_image(imagepath, maxwidth)

    log.log(
        logging.DEBUG if quiet else logging.INFO,
        u'Embedding album art into {0.albumartist} - {0.album}.'.format(album),
    )

    for item in album.items():
        embed_item(item, imagepath, maxwidth, None,
                   config['embedart']['compare_threshold'].get(int),
                   config['embedart']['ifempty'].get(bool), as_album=True)


def resize_image(imagepath, maxwidth):
    """Returns path to an image resized to maxwidth.
    """
    log.info(u'Resizing album art to {0} pixels wide'
             .format(maxwidth))
    imagepath = ArtResizer.shared.resize(maxwidth, syspath(imagepath))
    return imagepath


def check_art_similarity(item, imagepath, compare_threshold):
    """Return a boolean indicating whether the image at `imagepath` is
    similar to the art already embedded in `item`.
    """
    with NamedTemporaryFile(delete=True) as f:
        art = extract(f.name, item)

        if art:
            # Converting images to grayscale tends to minimize the weight
            # of colors in the diff score
            cmd = 'convert {0} {1} -colorspace gray MIFF:- | ' \
                  'compare -metric PHASH - null:'.format(syspath(imagepath),
                                                         syspath(art))

            proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                                    stderr=subprocess.PIPE,
                                    close_fds=platform.system() != 'Windows',
                                    shell=True)
            stdout, stderr = proc.communicate()
            if proc.returncode:
                if proc.returncode != 1:
                    log.warn(u'embedart: IM phashes compare failed for '
                             u'{0}, {1}'.format(displayable_path(imagepath),
                                                displayable_path(art)))
                    return
                phashDiff = float(stderr)
            else:
                phashDiff = float(stdout)

            log.info(u'embedart: compare PHASH score is {0}'.format(phashDiff))
            if phashDiff > compare_threshold:
                return False

    return True


def _mediafile_image(image_path, maxwidth=None):
    """Return a `mediafile.Image` object for the path.
    """
    with open(syspath(image_path), 'rb') as f:
        data = f.read()
    return mediafile.Image(data, type=mediafile.ImageType.front)


def get_art(item):
    """Return the art data embedded in the item's media file, or None
    if the file cannot be read.
    """
    # Extract the art.
    try:
        mf = mediafile.MediaFile(syspath(item.path))
    except mediafile.UnreadableFileError as exc:
        log.error(u'Could not extract art from {0}: {1}'.format(
            displayable_path(item.path), exc
        ))
        return

    return mf.art


# 'extractart' command.

def extract(outpath, item):
    """Extract the item's embedded art to `outpath` (an extension is
    appended based on the image type). Return the final path, or None
    if no art is present.
    """
    if not item:
        log.error(u'No item matches query.')
        return

    art = get_art(item)

    if not art:
        log.error(u'No album art present in {0} - {1}.'
                  .format(item.artist, item.title))
        return

    # Add an extension to the filename.
    ext = imghdr.what(None, h=art)
    if not ext:
        log.error(u'Unknown image type.')
        return
    outpath += '.' + ext

    log.info(u'Extracting album art from: {0.artist} - {0.title} '
             u'to: {1}'.format(item, displayable_path(outpath)))
    with open(syspath(outpath), 'wb') as f:
        f.write(art)
    return outpath


# 'clearart' command.

def clear(lib, query):
    """Remove embedded art from all items matching `query`."""
    log.info(u'Clearing album art from items:')
    for item in lib.items(query):
        log.info(u'{0} - {1}'.format(item.artist, item.title))
        try:
            mf = mediafile.MediaFile(syspath(item.path),
                                     config['id3v23'].get(bool))
        except mediafile.UnreadableFileError as exc:
            log.error(u'Could not clear art from {0}: {1}'.format(
                displayable_path(item.path), exc
            ))
            continue
        del mf.art
        mf.save()
396
lib/beetsplug/fetchart.py
Normal file
@@ -0,0 +1,396 @@
# This file is part of beets.
# Copyright 2014, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.

"""Fetches album art.
"""
from contextlib import closing
import logging
import os
import re
from tempfile import NamedTemporaryFile

import requests

from beets import plugins
from beets import importer
from beets import ui
from beets import util
from beets import config
from beets.util.artresizer import ArtResizer

try:
    import itunes
    HAVE_ITUNES = True
except ImportError:
    HAVE_ITUNES = False

IMAGE_EXTENSIONS = ['png', 'jpg', 'jpeg']
CONTENT_TYPES = ('image/jpeg',)
DOWNLOAD_EXTENSION = '.jpg'

log = logging.getLogger('beets')

requests_session = requests.Session()
requests_session.headers = {'User-Agent': 'beets'}


def _fetch_image(url):
    """Downloads an image from a URL and checks whether it seems to
    actually be an image. If so, returns a path to the downloaded image.
    Otherwise, returns None.
    """
    log.debug(u'fetchart: downloading art: {0}'.format(url))
    try:
        with closing(requests_session.get(url, stream=True)) as resp:
            if 'Content-Type' not in resp.headers \
                    or resp.headers['Content-Type'] not in CONTENT_TYPES:
                log.debug(u'fetchart: not an image')
                return

            # Generate a temporary file with the correct extension.
            with NamedTemporaryFile(suffix=DOWNLOAD_EXTENSION, delete=False) \
                    as fh:
                for chunk in resp.iter_content():
                    fh.write(chunk)
            log.debug(u'fetchart: downloaded art to: {0}'.format(
                util.displayable_path(fh.name)
            ))
            return fh.name
    except (IOError, requests.RequestException):
        log.debug(u'fetchart: error fetching art')


# ART SOURCES ################################################################

# Cover Art Archive.

CAA_URL = 'http://coverartarchive.org/release/{mbid}/front-500.jpg'
CAA_GROUP_URL = 'http://coverartarchive.org/release-group/{mbid}/front-500.jpg'


def caa_art(album):
    """Yield the Cover Art Archive release and release-group URLs
    derived from the album's MusicBrainz release ID and release group ID.
    """
    if album.mb_albumid:
        yield CAA_URL.format(mbid=album.mb_albumid)
    if album.mb_releasegroupid:
        yield CAA_GROUP_URL.format(mbid=album.mb_releasegroupid)


# Art from Amazon.

AMAZON_URL = 'http://images.amazon.com/images/P/%s.%02i.LZZZZZZZ.jpg'
AMAZON_INDICES = (1, 2)


def art_for_asin(album):
    """Generate URLs using Amazon ID (ASIN) string.
    """
    if album.asin:
        for index in AMAZON_INDICES:
            yield AMAZON_URL % (album.asin, index)


# AlbumArt.org scraper.

AAO_URL = 'http://www.albumart.org/index_detail.php'
AAO_PAT = r'href\s*=\s*"([^>"]*)"[^>]*title\s*=\s*"View larger image"'


def aao_art(album):
    """Return art URL from AlbumArt.org using album ASIN.
    """
    if not album.asin:
        return
    # Get the page from albumart.org.
    try:
        resp = requests_session.get(AAO_URL, params={'asin': album.asin})
        log.debug(u'fetchart: scraped art URL: {0}'.format(resp.url))
    except requests.RequestException:
        log.debug(u'fetchart: error scraping art page')
        return

    # Search the page for the image URL.
    m = re.search(AAO_PAT, resp.text)
    if m:
        image_url = m.group(1)
        yield image_url
    else:
        log.debug(u'fetchart: no image found on page')


# Google Images scraper.

GOOGLE_URL = 'https://ajax.googleapis.com/ajax/services/search/images'


def google_art(album):
    """Yield art URLs from a Google image search for the album artist
    and title.
    """
    if not (album.albumartist and album.album):
        return
    search_string = (album.albumartist + ',' + album.album).encode('utf-8')
    response = requests_session.get(GOOGLE_URL, params={
        'v': '1.0',
        'q': search_string,
        'start': '0',
    })

    # Get results using JSON.
    try:
        results = response.json()
        data = results['responseData']
        dataInfo = data['results']
        for myUrl in dataInfo:
            yield myUrl['unescapedUrl']
    except:
        log.debug(u'fetchart: error scraping art page')
        return


# Art from the iTunes Store.

def itunes_art(album):
    """Yield an art URL from the iTunes Store given an album artist
    and title.
    """
    search_string = (album.albumartist + ' ' + album.album).encode('utf-8')
    try:
        # Isolate bugs in the iTunes library while searching.
        try:
            itunes_album = itunes.search_album(search_string)[0]
        except Exception as exc:
            log.debug('fetchart: iTunes search failed: {0}'.format(exc))
            return

        if itunes_album.get_artwork()['100']:
            small_url = itunes_album.get_artwork()['100']
            big_url = small_url.replace('100x100', '1200x1200')
            yield big_url
        else:
            log.debug(u'fetchart: album has no artwork in iTunes Store')
    except IndexError:
        log.debug(u'fetchart: album not found in iTunes Store')


# Art from the filesystem.


def filename_priority(filename, cover_names):
    """Sort order for image names.

    Return the indexes of the cover names that occur in the image
    filename; the resulting lists are compared to rank candidate
    images.
    """
    return [idx for (idx, x) in enumerate(cover_names) if x in filename]
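

# For illustration (not part of the original file): with
# cover_names = ['cover', 'front'], the key is simply the indexes of the
# names found in the filename:
#
#     >>> filename_priority('front-cover.jpg', ['cover', 'front'])
#     [0, 1]
#     >>> filename_priority('back.jpg', ['cover', 'front'])
#     []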


def art_in_path(path, cover_names, cautious):
    """Look for album art files in a specified directory.
    """
    if not os.path.isdir(path):
        return

    # Find all files that look like images in the directory.
    images = []
    for fn in os.listdir(path):
        for ext in IMAGE_EXTENSIONS:
            if fn.lower().endswith('.' + ext):
                images.append(fn)

    # Look for "preferred" filenames.
    images = sorted(images, key=lambda x: filename_priority(x, cover_names))
    cover_pat = r"(\b|_)({0})(\b|_)".format('|'.join(cover_names))
    for fn in images:
        if re.search(cover_pat, os.path.splitext(fn)[0], re.I):
            log.debug(u'fetchart: using well-named art file {0}'.format(
                util.displayable_path(fn)
            ))
            return os.path.join(path, fn)

    # Fall back to any image in the folder.
    if images and not cautious:
        log.debug(u'fetchart: using fallback art file {0}'.format(
            util.displayable_path(images[0])
        ))
        return os.path.join(path, images[0])


# Try each source in turn.

SOURCES_ALL = [u'coverart', u'itunes', u'amazon', u'albumart', u'google']

ART_FUNCS = {
    u'coverart': caa_art,
    u'itunes': itunes_art,
    u'albumart': aao_art,
    u'amazon': art_for_asin,
    u'google': google_art,
}
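

# A sketch of the matching config.yaml section; the source names come from
# SOURCES_ALL above, and their order decides which source is tried first
# (the values here are only an example):
#
#     fetchart:
#         sources: coverart itunes amazon albumart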


def _source_urls(album, sources=SOURCES_ALL):
    """Generate possible source URLs for an album's art. The URLs are
    not guaranteed to work so they each need to be attempted in turn.
    This allows the main `art_for_album` function to abort iteration
    through this sequence early to avoid the cost of scraping when not
    necessary.
    """
    for s in sources:
        urls = ART_FUNCS[s](album)
        for url in urls:
            yield url


def art_for_album(album, paths, maxwidth=None, local_only=False):
    """Given an Album object, returns a path to downloaded art for the
    album (or None if no art is found). If `maxwidth`, then images are
    resized to this maximum pixel size. If `local_only`, then only local
    image files from the filesystem are returned; no network requests
    are made.
    """
    out = None

    # Local art.
    cover_names = config['fetchart']['cover_names'].as_str_seq()
    cover_names = map(util.bytestring_path, cover_names)
    cautious = config['fetchart']['cautious'].get(bool)
    if paths:
        for path in paths:
            out = art_in_path(path, cover_names, cautious)
            if out:
                break

    # Web art sources.
    remote_priority = config['fetchart']['remote_priority'].get(bool)
    if not local_only and (remote_priority or not out):
        for url in _source_urls(album,
                                config['fetchart']['sources'].as_str_seq()):
            if maxwidth:
                url = ArtResizer.shared.proxy_url(maxwidth, url)
            candidate = _fetch_image(url)
            if candidate:
                out = candidate
                break

    if maxwidth and out:
        out = ArtResizer.shared.resize(maxwidth, out)
    return out
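

# For illustration (a hypothetical call, not in the original file): fetch
# art for one album, preferring images already on disk next to the files:
#
#     path = art_for_album(album, [album.path], maxwidth=600)
#
# Local candidates are tried first, then each web source in turn, and the
# winning image is resized only once at the end.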


# PLUGIN LOGIC ###############################################################


def batch_fetch_art(lib, albums, force, maxwidth=None):
    """Fetch album art for each of the albums. This implements the manual
    fetchart CLI command.
    """
    for album in albums:
        if album.artpath and not force:
            message = 'has album art'
        else:
            # In ordinary invocations, look for images on the
            # filesystem. When forcing, however, always go to the Web
            # sources.
            local_paths = None if force else [album.path]

            path = art_for_album(album, local_paths, maxwidth)
            if path:
                album.set_art(path, False)
                album.store()
                message = ui.colorize('green', 'found album art')
            else:
                message = ui.colorize('red', 'no art found')

        log.info(u'{0} - {1}: {2}'.format(album.albumartist, album.album,
                                          message))


class FetchArtPlugin(plugins.BeetsPlugin):
    def __init__(self):
        super(FetchArtPlugin, self).__init__()

        self.config.add({
            'auto': True,
            'maxwidth': 0,
            'remote_priority': False,
            'cautious': False,
            'google_search': False,
            'cover_names': ['cover', 'front', 'art', 'album', 'folder'],
            'sources': SOURCES_ALL,
        })

        # Holds paths to downloaded images between fetching them and
        # placing them in the filesystem.
        self.art_paths = {}

        self.maxwidth = self.config['maxwidth'].get(int)
        if self.config['auto']:
            # Enable two import hooks when fetching is enabled.
            self.import_stages = [self.fetch_art]
            self.register_listener('import_task_files', self.assign_art)

        available_sources = list(SOURCES_ALL)
        if not HAVE_ITUNES and u'itunes' in available_sources:
            available_sources.remove(u'itunes')
        self.config['sources'] = plugins.sanitize_choices(
            self.config['sources'].as_str_seq(), available_sources)

    # Asynchronous; after music is added to the library.
    def fetch_art(self, session, task):
        """Find art for the album being imported."""
        if task.is_album:  # Only fetch art for full albums.
            if task.choice_flag == importer.action.ASIS:
                # For as-is imports, don't search Web sources for art.
                local = True
            elif task.choice_flag == importer.action.APPLY:
                # Search everywhere for art.
                local = False
            else:
                # For any other choices (e.g., TRACKS), do nothing.
                return

            path = art_for_album(task.album, task.paths, self.maxwidth, local)

            if path:
                self.art_paths[task] = path

    # Synchronous; after music files are put in place.
    def assign_art(self, session, task):
        """Place the discovered art in the filesystem."""
        if task in self.art_paths:
            path = self.art_paths.pop(task)

            album = task.album
            src_removed = (config['import']['delete'].get(bool) or
                           config['import']['move'].get(bool))
            album.set_art(path, not src_removed)
            album.store()
            if src_removed:
                task.prune(path)

    # Manual album art fetching.
    def commands(self):
        cmd = ui.Subcommand('fetchart', help='download album art')
        cmd.parser.add_option('-f', '--force', dest='force',
                              action='store_true', default=False,
                              help='re-download art when already present')

        def func(lib, opts, args):
            batch_fetch_art(lib, lib.albums(ui.decargs(args)), opts.force,
                            self.maxwidth)
        cmd.func = func
        return [cmd]
544
lib/beetsplug/lyrics.py
Normal file
@@ -0,0 +1,544 @@
# This file is part of beets.
# Copyright 2014, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.

"""Fetches, embeds, and displays lyrics.
"""
from __future__ import print_function

import re
import logging
import requests
import json
import unicodedata
import urllib
import difflib
import itertools
from HTMLParser import HTMLParseError

from beets import plugins
from beets import config, ui


# Global logger.

log = logging.getLogger('beets')

DIV_RE = re.compile(r'<(/?)div>?', re.I)
COMMENT_RE = re.compile(r'<!--.*-->', re.S)
TAG_RE = re.compile(r'<[^>]*>')
BREAK_RE = re.compile(r'\n?\s*<br([\s|/][^>]*)*>\s*\n?', re.I)
URL_CHARACTERS = {
    u'\u2018': u"'",
    u'\u2019': u"'",
    u'\u201c': u'"',
    u'\u201d': u'"',
    u'\u2010': u'-',
    u'\u2011': u'-',
    u'\u2012': u'-',
    u'\u2013': u'-',
    u'\u2014': u'-',
    u'\u2015': u'-',
    u'\u2016': u'-',
    u'\u2026': u'...',
}


# Utilities.

def fetch_url(url):
    """Retrieve the content at a given URL, or return None if the source
    is unreachable.
    """
    try:
        r = requests.get(url, verify=False)
    except requests.RequestException as exc:
        log.debug(u'lyrics request failed: {0}'.format(exc))
        return
    if r.status_code == requests.codes.ok:
        return r.text
    else:
        log.debug(u'failed to fetch: {0} ({1})'.format(url, r.status_code))


def unescape(text):
    """Resolves &#xxx; HTML entities (and some others)."""
    if isinstance(text, str):
        text = text.decode('utf8', 'ignore')
    out = text.replace(u'&nbsp;', u' ')

    def replchar(m):
        num = m.group(1)
        return unichr(int(num))
    out = re.sub(u"&#(\d+);", replchar, out)
    return out


def extract_text_between(html, start_marker, end_marker):
    """Return the text between `start_marker` and `end_marker`, or an
    empty string if either marker is missing.
    """
    try:
        _, html = html.split(start_marker, 1)
        html, _ = html.split(end_marker, 1)
    except ValueError:
        return u''
    return html


def extract_text_in(html, starttag):
    """Extract the text from a <DIV> tag in the HTML starting with
    ``starttag``. Returns None if parsing fails.
    """

    # Strip off the leading text before opening tag.
    try:
        _, html = html.split(starttag, 1)
    except ValueError:
        return

    # Walk through balanced DIV tags.
    level = 0
    parts = []
    pos = 0
    for match in DIV_RE.finditer(html):
        if match.group(1):  # Closing tag.
            level -= 1
            if level == 0:
                pos = match.end()
        else:  # Opening tag.
            if level == 0:
                parts.append(html[pos:match.start()])
            level += 1

        if level == -1:
            parts.append(html[pos:match.start()])
            break
    else:
        print('no closing tag found!')
        return
    return u''.join(parts)
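

# For illustration (not in the original file): text at the top level of the
# balanced <div> is kept, nested divs are skipped, and content after the
# closing tag is ignored:
#
#     >>> extract_text_in(u"<div class='a'>foo<div>bar</div>baz</div>tail",
#     ...                 u"<div class='a'>")
#     u'foobaz'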


def search_pairs(item):
    """Yield pairs of artists and titles to search for.

    The first item in the pair is the name of the artist, the second
    item is a list of song names.

    In addition to the artist and title obtained from the `item`, the
    method tries to strip extra information like parenthesized suffixes
    and featured artists from the strings and adds them as candidates.
    The method also tries to split multiple titles separated with `/`.
    """

    title, artist = item.title, item.artist
    titles = [title]
    artists = [artist]

    # Remove any featuring artists from the artists name
    pattern = r"(.*?) {0}".format(plugins.feat_tokens())
    match = re.search(pattern, artist, re.IGNORECASE)
    if match:
        artists.append(match.group(1))

    # Remove a parenthesized suffix from a title string. Common
    # examples include (live), (remix), and (acoustic).
    pattern = r"(.+?)\s+[(].*[)]$"
    match = re.search(pattern, title, re.IGNORECASE)
    if match:
        titles.append(match.group(1))

    # Remove any featuring artists from the title
    pattern = r"(.*?) {0}".format(plugins.feat_tokens(for_artist=False))
    for title in titles[:]:
        match = re.search(pattern, title, re.IGNORECASE)
        if match:
            titles.append(match.group(1))

    # Check for a dual song (e.g. Pink Floyd - Speak to Me / Breathe)
    # and, if found, search for each part separately.
    multi_titles = []
    for title in titles:
        multi_titles.append([title])
        if '/' in title:
            multi_titles.append([x.strip() for x in title.split('/')])

    return itertools.product(artists, multi_titles)
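

# For illustration (not in the original file): for an item with artist
# u'Pink Floyd' and title u'Speak to Me / Breathe', search_pairs() yields
#
#     (u'Pink Floyd', [u'Speak to Me / Breathe'])
#     (u'Pink Floyd', [u'Speak to Me', u'Breathe'])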


def _encode(s):
    """Encode the string for inclusion in a URL (common to both
    LyricsWiki and Lyrics.com).
    """
    if isinstance(s, unicode):
        for char, repl in URL_CHARACTERS.items():
            s = s.replace(char, repl)
        s = s.encode('utf8', 'ignore')
    return urllib.quote(s)


# Musixmatch

MUSIXMATCH_URL_PATTERN = 'https://www.musixmatch.com/lyrics/%s/%s'


def fetch_musixmatch(artist, title):
    """Fetch lyrics from Musixmatch."""
    url = MUSIXMATCH_URL_PATTERN % (_lw_encode(artist.title()),
                                    _lw_encode(title.title()))
    html = fetch_url(url)
    if not html:
        return
    lyrics = extract_text_between(html, '"lyrics_body":', '"lyrics_language":')
    return lyrics.strip(',"').replace('\\n', '\n')


# LyricsWiki.

LYRICSWIKI_URL_PATTERN = 'http://lyrics.wikia.com/%s:%s'


def _lw_encode(s):
    s = re.sub(r'\s+', '_', s)
    s = s.replace("<", "Less_Than")
    s = s.replace(">", "Greater_Than")
    s = s.replace("#", "Number_")
    s = re.sub(r'[\[\{]', '(', s)
    s = re.sub(r'[\]\}]', ')', s)
    return _encode(s)


def fetch_lyricswiki(artist, title):
    """Fetch lyrics from LyricsWiki."""
    url = LYRICSWIKI_URL_PATTERN % (_lw_encode(artist), _lw_encode(title))
    html = fetch_url(url)
    if not html:
        return

    lyrics = extract_text_in(html, u"<div class='lyricbox'>")
    if lyrics and 'Unfortunately, we are not licensed' not in lyrics:
        return lyrics


# Lyrics.com.

LYRICSCOM_URL_PATTERN = 'http://www.lyrics.com/%s-lyrics-%s.html'
LYRICSCOM_NOT_FOUND = (
    'Sorry, we do not have the lyric',
    'Submit Lyrics',
)


def _lc_encode(s):
    s = re.sub(r'[^\w\s-]', '', s)
    s = re.sub(r'\s+', '-', s)
    return _encode(s).lower()


def fetch_lyricscom(artist, title):
    """Fetch lyrics from Lyrics.com."""
    url = LYRICSCOM_URL_PATTERN % (_lc_encode(title), _lc_encode(artist))
    html = fetch_url(url)
    if not html:
        return
    lyrics = extract_text_between(html, '<div id="lyrics" class="SCREENONLY" '
                                  'itemprop="description">', '</div>')
    if not lyrics:
        return
    for not_found_str in LYRICSCOM_NOT_FOUND:
        if not_found_str in lyrics:
            return

    parts = lyrics.split('\n---\nLyrics powered by', 1)
    if parts:
        return parts[0]


# Optional Google custom search API backend.

def slugify(text):
    """Normalize a string and remove non-alphanumeric characters.
    """
    text = re.sub(r"[-'_\s]", '_', text)
    text = re.sub(r"_+", '_', text).strip('_')
    pat = "([^,\(]*)\((.*?)\)"  # Remove content within parentheses
    text = re.sub(pat, '\g<1>', text).strip()
    try:
        text = unicodedata.normalize('NFKD', text).encode('ascii', 'ignore')
        text = unicode(re.sub('[-\s]+', ' ', text))
    except UnicodeDecodeError:
        log.exception(u"Failing to normalize '{0}'".format(text))
    return text
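

# For illustration (not in the original file):
#
#     >>> slugify(u"The Beatles")
#     u'The_Beatles'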


BY_TRANS = ['by', 'par', 'de', 'von']
LYRICS_TRANS = ['lyrics', 'paroles', 'letras', 'liedtexte']


def is_page_candidate(urlLink, urlTitle, title, artist):
    """Return True if the URL title makes it a good candidate to be a
    page that contains lyrics of title by artist.
    """
    title = slugify(title.lower())
    artist = slugify(artist.lower())
    sitename = re.search(u"//([^/]+)/.*", slugify(urlLink.lower())).group(1)
    urlTitle = slugify(urlTitle.lower())
    # Check if URL title contains song title (exact match)
    if urlTitle.find(title) != -1:
        return True
    # or try extracting song title from URL title and check if
    # they are close enough
    tokens = [by + '_' + artist for by in BY_TRANS] + \
             [artist, sitename, sitename.replace('www.', '')] + LYRICS_TRANS
    songTitle = re.sub(u'(%s)' % u'|'.join(tokens), u'', urlTitle)
    songTitle = songTitle.strip('_|')
    typoRatio = .9
    return difflib.SequenceMatcher(None, songTitle, title).ratio() >= typoRatio
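

# For illustration (a hypothetical example, not in the original file): a
# result titled "Breathe lyrics by Pink Floyd" is accepted immediately
# because its slugified title contains the slugified song title; otherwise
# tokens such as 'lyrics', 'by_<artist>' and the site name are stripped and
# the remainder must reach a 0.9 difflib ratio against the title.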


def remove_credits(text):
    """Remove the first or last line of the text if it contains the word
    'lyrics', e.g. 'Lyrics by songsdatabase.com'.
    """
    textlines = text.split('\n')
    credits = None
    for i in (0, -1):
        if textlines and 'lyrics' in textlines[i].lower():
            credits = textlines.pop(i)
    if credits:
        text = '\n'.join(textlines)
    return text


def is_lyrics(text, artist=None):
    """Determine whether the text seems to be valid lyrics.
    """
    if not text:
        return False
    badTriggersOcc = []
    nbLines = text.count('\n')
    if nbLines <= 1:
        log.debug(u"Ignoring too short lyrics '{0}'".format(text))
        return False
    elif nbLines < 5:
        badTriggersOcc.append('too_short')
    else:
        # Lyrics look legit, remove credits to avoid being penalized further
        # down
        text = remove_credits(text)

    badTriggers = ['lyrics', 'copyright', 'property', 'links']
    if artist:
        # Count occurrences of the artist name as a bad trigger too.
        badTriggers += [artist]

    for item in badTriggers:
        badTriggersOcc += [item] * len(re.findall(r'\W%s\W' % item,
                                                  text, re.I))

    if badTriggersOcc:
        log.debug(u'Bad triggers detected: {0}'.format(badTriggersOcc))
    return len(badTriggersOcc) < 2


def _scrape_strip_cruft(html, plain_text_out=False):
    """Clean up HTML: normalize line breaks, collapse runs of whitespace,
    and strip script tags (and, when `plain_text_out` is set, all
    remaining tags).
    """
    html = unescape(html)

    html = html.replace('\r', '\n')  # Normalize EOL.
    html = re.sub(r' +', ' ', html)  # Whitespaces collapse.
    html = BREAK_RE.sub('\n', html)  # <br> eats up surrounding '\n'.
    html = re.sub(r'<(script).*?</\1>(?s)', '', html)  # Strip script tags.

    if plain_text_out:  # Strip remaining HTML tags
        html = COMMENT_RE.sub('', html)
        html = TAG_RE.sub('', html)

    html = '\n'.join([x.strip() for x in html.strip().split('\n')])
    html = re.sub(r'\n{3,}', r'\n\n', html)
    return html


def _scrape_merge_paragraphs(html):
    html = re.sub(r'</p>\s*<p(\s*[^>]*)>', '\n', html)
    return re.sub(r'<div .*>\s*</div>', '\n', html)


def scrape_lyrics_from_html(html):
    """Scrape lyrics out of a fetched HTML document. If no lyrics can be
    found, return None instead.
    """
    from bs4 import SoupStrainer, BeautifulSoup

    if not html:
        return None

    def is_text_notcode(text):
        length = len(text)
        return (length > 20 and
                text.count(' ') > length / 25 and
                (text.find('{') == -1 or text.find(';') == -1))
    html = _scrape_strip_cruft(html)
    html = _scrape_merge_paragraphs(html)

    # extract all long text blocks that are not code
    try:
        soup = BeautifulSoup(html, "html.parser",
                             parse_only=SoupStrainer(text=is_text_notcode))
    except HTMLParseError:
        return None
    soup = sorted(soup.stripped_strings, key=len)[-1]
    return soup


def fetch_google(artist, title):
    """Fetch lyrics from Google search results.
    """
    query = u"%s %s" % (artist, title)
    api_key = config['lyrics']['google_API_key'].get(unicode)
    engine_id = config['lyrics']['google_engine_ID'].get(unicode)
    url = u'https://www.googleapis.com/customsearch/v1?key=%s&cx=%s&q=%s' % \
        (api_key, engine_id, urllib.quote(query.encode('utf8')))

    data = urllib.urlopen(url)
    data = json.load(data)
    if 'error' in data:
        reason = data['error']['errors'][0]['reason']
        log.debug(u'google lyrics backend error: {0}'.format(reason))
        return

    if 'items' in data.keys():
        for item in data['items']:
            urlLink = item['link']
            urlTitle = item.get('title', u'')
            if not is_page_candidate(urlLink, urlTitle, title, artist):
                continue
            html = fetch_url(urlLink)
            lyrics = scrape_lyrics_from_html(html)
            if not lyrics:
                continue

            if is_lyrics(lyrics, artist):
                log.debug(u'got lyrics from {0}'.format(item['displayLink']))
                return lyrics


# Plugin logic.

SOURCES = ['google', 'lyricwiki', 'lyrics.com', 'musixmatch']
SOURCE_BACKENDS = {
    'google': fetch_google,
    'lyricwiki': fetch_lyricswiki,
    'lyrics.com': fetch_lyricscom,
    'musixmatch': fetch_musixmatch,
}


class LyricsPlugin(plugins.BeetsPlugin):
    def __init__(self):
        super(LyricsPlugin, self).__init__()
        self.import_stages = [self.imported]
        self.config.add({
            'auto': True,
            'google_API_key': None,
            'google_engine_ID': u'009217259823014548361:lndtuqkycfu',
            'fallback': None,
            'force': False,
            'sources': SOURCES,
        })

        available_sources = list(SOURCES)
        if not self.config['google_API_key'].get() and \
                'google' in SOURCES:
            available_sources.remove('google')
        self.config['sources'] = plugins.sanitize_choices(
            self.config['sources'].as_str_seq(), available_sources)
        self.backends = []
        for key in self.config['sources'].as_str_seq():
            self.backends.append(SOURCE_BACKENDS[key])

    def commands(self):
        cmd = ui.Subcommand('lyrics', help='fetch song lyrics')
        cmd.parser.add_option('-p', '--print', dest='printlyr',
                              action='store_true', default=False,
                              help='print lyrics to console')
        cmd.parser.add_option('-f', '--force', dest='force_refetch',
                              action='store_true', default=False,
                              help='always re-download lyrics')

        def func(lib, opts, args):
            # The "write to files" option corresponds to the
            # import_write config value.
            write = config['import']['write'].get(bool)
            for item in lib.items(ui.decargs(args)):
                self.fetch_item_lyrics(
                    lib, logging.INFO, item, write,
                    opts.force_refetch or self.config['force'],
                )
                if opts.printlyr and item.lyrics:
                    ui.print_(item.lyrics)

        cmd.func = func
        return [cmd]
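
    # A sketch of the matching config.yaml section; the option names come
    # from the config.add() defaults above (values are examples only, and
    # the API key shown is hypothetical):
    #
    #     lyrics:
    #         auto: yes
    #         fallback: ''
    #         sources: google lyricwiki lyrics.com
    #         google_API_key: AZERTYUIOP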

    def imported(self, session, task):
        """Import hook for fetching lyrics automatically.
        """
        if self.config['auto']:
            for item in task.imported_items():
                self.fetch_item_lyrics(session.lib, logging.DEBUG, item,
                                       False, self.config['force'])

    def fetch_item_lyrics(self, lib, loglevel, item, write, force):
        """Fetch and store lyrics for a single item. If ``write``, then the
        lyrics will also be written to the file itself. The ``loglevel``
        parameter controls the visibility of the function's status log
        messages.
        """
        # Skip if the item already has lyrics.
        if not force and item.lyrics:
            log.log(loglevel, u'lyrics already present: {0} - {1}'
                    .format(item.artist, item.title))
            return

        lyrics = None
        for artist, titles in search_pairs(item):
            lyrics = [self.get_lyrics(artist, title) for title in titles]
            if any(lyrics):
                break

        lyrics = u"\n\n---\n\n".join([l for l in lyrics if l])

        if lyrics:
            log.log(loglevel, u'fetched lyrics: {0} - {1}'
                    .format(item.artist, item.title))
        else:
            log.log(loglevel, u'lyrics not found: {0} - {1}'
                    .format(item.artist, item.title))
            fallback = self.config['fallback'].get()
            if fallback:
                lyrics = fallback
            else:
                return

        item.lyrics = lyrics

        if write:
            item.try_write()
        item.store()

    def get_lyrics(self, artist, title):
        """Fetch lyrics, trying each source in turn. Return a string or
        None if no lyrics were found.
        """
        for backend in self.backends:
            lyrics = backend(artist, title)
            if lyrics:
                log.debug(u'got lyrics from backend: {0}'
                          .format(backend.__name__))
                return _scrape_strip_cruft(lyrics, True)
@@ -14,9 +14,6 @@

# Minor modifications made by Andrew Resch to replace the BTFailure errors with Exceptions

from types import StringType, IntType, LongType, DictType, ListType, TupleType


def decode_int(x, f):
    f += 1
    newf = x.index('e', f)
@@ -24,36 +21,32 @@ def decode_int(x, f):
    if x[f] == '-':
        if x[f + 1] == '0':
            raise ValueError
    elif x[f] == '0' and newf != f + 1:
    elif x[f] == '0' and newf != f+1:
        raise ValueError
    return (n, newf + 1)

    return (n, newf+1)

def decode_string(x, f):
    colon = x.index(':', f)
    n = int(x[f:colon])
    if x[f] == '0' and colon != f + 1:
    if x[f] == '0' and colon != f+1:
        raise ValueError
    colon += 1
    return (x[colon:colon + n], colon + n)

    return (x[colon:colon+n], colon+n)

def decode_list(x, f):
    r, f = [], f + 1
    r, f = [], f+1
    while x[f] != 'e':
        v, f = decode_func[x[f]](x, f)
        r.append(v)
    return (r, f + 1)


def decode_dict(x, f):
    r, f = {}, f + 1
    r, f = {}, f+1
    while x[f] != 'e':
        k, f = decode_string(x, f)
        r[k], f = decode_func[x[f]](x, f)
    return (r, f + 1)


decode_func = {}
decode_func['l'] = decode_list
decode_func['d'] = decode_dict
@@ -69,7 +62,6 @@ decode_func['7'] = decode_string
decode_func['8'] = decode_string
decode_func['9'] = decode_string


def bdecode(x):
    try:
        r, l = decode_func[x[0]](x, 0)
@@ -78,41 +70,38 @@ def bdecode(x):

    return r

from types import StringType, IntType, LongType, DictType, ListType, TupleType


class Bencached(object):

    __slots__ = ['bencoded']

    def __init__(self, s):
        self.bencoded = s


def encode_bencached(x, r):
def encode_bencached(x,r):
    r.append(x.bencoded)


def encode_int(x, r):
    r.extend(('i', str(x), 'e'))


def encode_bool(x, r):
    if x:
        encode_int(1, r)
    else:
        encode_int(0, r)


def encode_string(x, r):
    r.extend((str(len(x)), ':', x))


def encode_list(x, r):
    r.append('l')
    for i in x:
        encode_func[type(i)](i, r)
    r.append('e')


def encode_dict(x, r):
def encode_dict(x,r):
    r.append('d')
    ilist = x.items()
    ilist.sort()
@@ -121,7 +110,6 @@ def encode_dict(x, r):
        encode_func[type(v)](v, r)
    r.append('e')


encode_func = {}
encode_func[Bencached] = encode_bencached
encode_func[IntType] = encode_int
@@ -133,12 +121,10 @@ encode_func[DictType] = encode_dict

try:
    from types import BooleanType

    encode_func[BooleanType] = encode_bool
except ImportError:
    pass


def bencode(x):
    r = []
    encode_func[type(x)](x, r)
|
||||
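Taken together, the two halves above are inverses. A minimal round trip, assuming the complete module (the final hunk above cuts off just before bencode's closing `return ''.join(r)`):

    meta = {'announce': 'http://tracker.example/announce',
            'info': {'length': 42, 'name': 'example.bin'}}
    encoded = bencode(meta)   # encode_dict writes keys in sorted order
    assert bdecode(encoded) == meta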
803
lib/biplist/__init__.py
Executable file
@@ -0,0 +1,803 @@
"""biplist -- a library for reading and writing binary property list files.

Binary Property List (plist) files provide a faster and smaller serialization
format for property lists on OS X. This is a library for generating binary
plists which can be read by OS X, iOS, or other clients.

The API models the plistlib API, and will call through to plistlib when
XML serialization or deserialization is required.

To generate plists with UID values, wrap the values with the Uid object. The
value must be an int.

To generate plists with NSData/CFData values, wrap the values with the
Data object. The value must be a string.

Date values can only be datetime.datetime objects.

The exceptions InvalidPlistException and NotBinaryPlistException may be
thrown to indicate that the data cannot be serialized or deserialized as
a binary plist.

Plist generation example:

    from biplist import *
    from datetime import datetime
    plist = {'aKey':'aValue',
             '0':1.322,
             'now':datetime.now(),
             'list':[1,2,3],
             'tuple':('a','b','c')
             }
    try:
        writePlist(plist, "example.plist")
    except (InvalidPlistException, NotBinaryPlistException), e:
        print "Something bad happened:", e

Plist parsing example:

    from biplist import *
    try:
        plist = readPlist("example.plist")
        print plist
    except (InvalidPlistException, NotBinaryPlistException), e:
        print "Not a plist:", e
"""

import sys
from collections import namedtuple
import datetime
import io
import math
import plistlib
from struct import pack, unpack
from struct import error as struct_error
import sys
import time

try:
    unicode
    unicodeEmpty = r''
except NameError:
    unicode = str
    unicodeEmpty = ''
try:
    long
except NameError:
    long = int
try:
    {}.iteritems
    iteritems = lambda x: x.iteritems()
except AttributeError:
    iteritems = lambda x: x.items()

__all__ = [
    'Uid', 'Data', 'readPlist', 'writePlist', 'readPlistFromString',
    'writePlistToString', 'InvalidPlistException', 'NotBinaryPlistException'
]

# Apple uses Jan 1, 2001 as a base for all plist date/times.
apple_reference_date = datetime.datetime.utcfromtimestamp(978307200)

class Uid(int):
    """Wrapper around integers for representing UID values. This
    is used in keyed archiving."""
    def __repr__(self):
        return "Uid(%d)" % self

class Data(bytes):
    """Wrapper around str types for representing Data values."""
    pass

class InvalidPlistException(Exception):
    """Raised when the plist is incorrectly formatted."""
    pass

class NotBinaryPlistException(Exception):
    """Raised when a binary plist was expected but not encountered."""
    pass

def readPlist(pathOrFile):
    """Raises NotBinaryPlistException, InvalidPlistException"""
    didOpen = False
    result = None
    if isinstance(pathOrFile, (bytes, unicode)):
        pathOrFile = open(pathOrFile, 'rb')
        didOpen = True
    try:
        reader = PlistReader(pathOrFile)
        result = reader.parse()
    except NotBinaryPlistException as e:
        try:
            pathOrFile.seek(0)
            result = None
            if hasattr(plistlib, 'loads'):
                contents = None
                if isinstance(pathOrFile, (bytes, unicode)):
                    with open(pathOrFile, 'rb') as f:
                        contents = f.read()
                else:
                    contents = pathOrFile.read()
                result = plistlib.loads(contents)
            else:
                result = plistlib.readPlist(pathOrFile)
            result = wrapDataObject(result, for_binary=True)
        except Exception as e:
            raise InvalidPlistException(e)
    finally:
        if didOpen:
            pathOrFile.close()
    return result

def wrapDataObject(o, for_binary=False):
    if isinstance(o, Data) and not for_binary:
        v = sys.version_info
        if not (v[0] >= 3 and v[1] >= 4):
            o = plistlib.Data(o)
    elif isinstance(o, (bytes, plistlib.Data)) and for_binary:
        if hasattr(o, 'data'):
            o = Data(o.data)
    elif isinstance(o, tuple):
        o = wrapDataObject(list(o), for_binary)
        o = tuple(o)
    elif isinstance(o, list):
        for i in range(len(o)):
            o[i] = wrapDataObject(o[i], for_binary)
    elif isinstance(o, dict):
        for k in o:
            o[k] = wrapDataObject(o[k], for_binary)
    return o

def writePlist(rootObject, pathOrFile, binary=True):
    if not binary:
        rootObject = wrapDataObject(rootObject, binary)
        if hasattr(plistlib, "dump"):
            if isinstance(pathOrFile, (bytes, unicode)):
                with open(pathOrFile, 'wb') as f:
                    return plistlib.dump(rootObject, f)
            else:
                return plistlib.dump(rootObject, pathOrFile)
        else:
            return plistlib.writePlist(rootObject, pathOrFile)
    else:
        didOpen = False
        if isinstance(pathOrFile, (bytes, unicode)):
            pathOrFile = open(pathOrFile, 'wb')
            didOpen = True
        writer = PlistWriter(pathOrFile)
        result = writer.writeRoot(rootObject)
        if didOpen:
            pathOrFile.close()
        return result

def readPlistFromString(data):
    return readPlist(io.BytesIO(data))

def writePlistToString(rootObject, binary=True):
    if not binary:
        rootObject = wrapDataObject(rootObject, binary)
        if hasattr(plistlib, "dumps"):
            return plistlib.dumps(rootObject)
        elif hasattr(plistlib, "writePlistToBytes"):
            return plistlib.writePlistToBytes(rootObject)
        else:
            return plistlib.writePlistToString(rootObject)
    else:
        ioObject = io.BytesIO()
        writer = PlistWriter(ioObject)
        writer.writeRoot(rootObject)
        return ioObject.getvalue()

def is_stream_binary_plist(stream):
    stream.seek(0)
    header = stream.read(7)
    if header == b'bplist0':
        return True
    else:
        return False

PlistTrailer = namedtuple('PlistTrailer', 'offsetSize, objectRefSize, offsetCount, topLevelObjectNumber, offsetTableOffset')
PlistByteCounts = namedtuple('PlistByteCounts', 'nullBytes, boolBytes, intBytes, realBytes, dateBytes, dataBytes, stringBytes, uidBytes, arrayBytes, setBytes, dictBytes')

class PlistReader(object):
    file = None
    contents = ''
    offsets = None
    trailer = None
    currentOffset = 0

    def __init__(self, fileOrStream):
        """Raises NotBinaryPlistException."""
        self.reset()
        self.file = fileOrStream

    def parse(self):
        return self.readRoot()

    def reset(self):
        self.trailer = None
        self.contents = ''
        self.offsets = []
        self.currentOffset = 0

    def readRoot(self):
        result = None
        self.reset()
        # Get the header, make sure it's a valid file.
        if not is_stream_binary_plist(self.file):
            raise NotBinaryPlistException()
        self.file.seek(0)
        self.contents = self.file.read()
        if len(self.contents) < 32:
            raise InvalidPlistException("File is too short.")
        trailerContents = self.contents[-32:]
        try:
            self.trailer = PlistTrailer._make(unpack("!xxxxxxBBQQQ", trailerContents))
            offset_size = self.trailer.offsetSize * self.trailer.offsetCount
            offset = self.trailer.offsetTableOffset
            offset_contents = self.contents[offset:offset+offset_size]
            offset_i = 0
            while offset_i < self.trailer.offsetCount:
                begin = self.trailer.offsetSize*offset_i
                tmp_contents = offset_contents[begin:begin+self.trailer.offsetSize]
                tmp_sized = self.getSizedInteger(tmp_contents, self.trailer.offsetSize)
                self.offsets.append(tmp_sized)
                offset_i += 1
            self.setCurrentOffsetToObjectNumber(self.trailer.topLevelObjectNumber)
            result = self.readObject()
        except TypeError as e:
            raise InvalidPlistException(e)
        return result

    def setCurrentOffsetToObjectNumber(self, objectNumber):
        self.currentOffset = self.offsets[objectNumber]

    def readObject(self):
        result = None
        tmp_byte = self.contents[self.currentOffset:self.currentOffset+1]
        marker_byte = unpack("!B", tmp_byte)[0]
        format = (marker_byte >> 4) & 0x0f
        extra = marker_byte & 0x0f
        self.currentOffset += 1

        def proc_extra(extra):
            if extra == 0b1111:
                #self.currentOffset += 1
                extra = self.readObject()
            return extra

        # bool, null, or fill byte
        if format == 0b0000:
            if extra == 0b0000:
                result = None
            elif extra == 0b1000:
                result = False
            elif extra == 0b1001:
                result = True
            elif extra == 0b1111:
                pass # fill byte
            else:
                raise InvalidPlistException("Invalid object found at offset: %d" % (self.currentOffset - 1))
        # int
        elif format == 0b0001:
            extra = proc_extra(extra)
            result = self.readInteger(pow(2, extra))
        # real
        elif format == 0b0010:
            extra = proc_extra(extra)
            result = self.readReal(extra)
        # date
        elif format == 0b0011 and extra == 0b0011:
            result = self.readDate()
        # data
        elif format == 0b0100:
            extra = proc_extra(extra)
            result = self.readData(extra)
        # ascii string
        elif format == 0b0101:
            extra = proc_extra(extra)
            result = self.readAsciiString(extra)
        # Unicode string
        elif format == 0b0110:
            extra = proc_extra(extra)
            result = self.readUnicode(extra)
        # uid
        elif format == 0b1000:
            result = self.readUid(extra)
        # array
        elif format == 0b1010:
            extra = proc_extra(extra)
            result = self.readArray(extra)
        # set
        elif format == 0b1100:
            extra = proc_extra(extra)
            result = set(self.readArray(extra))
        # dict
        elif format == 0b1101:
            extra = proc_extra(extra)
            result = self.readDict(extra)
        else:
            raise InvalidPlistException("Invalid object found: {format: %s, extra: %s}" % (bin(format), bin(extra)))
        return result

    def readInteger(self, byteSize):
        result = 0
        original_offset = self.currentOffset
        data = self.contents[self.currentOffset:self.currentOffset + byteSize]
        result = self.getSizedInteger(data, byteSize, as_number=True)
        self.currentOffset = original_offset + byteSize
        return result

    def readReal(self, length):
        result = 0.0
        to_read = pow(2, length)
        data = self.contents[self.currentOffset:self.currentOffset+to_read]
        if length == 2: # 4 bytes
            result = unpack('>f', data)[0]
        elif length == 3: # 8 bytes
            result = unpack('>d', data)[0]
        else:
            raise InvalidPlistException("Unknown real of length %d bytes" % to_read)
        return result

    def readRefs(self, count):
        refs = []
        i = 0
        while i < count:
            fragment = self.contents[self.currentOffset:self.currentOffset+self.trailer.objectRefSize]
            ref = self.getSizedInteger(fragment, len(fragment))
            refs.append(ref)
            self.currentOffset += self.trailer.objectRefSize
            i += 1
        return refs

    def readArray(self, count):
        result = []
        values = self.readRefs(count)
        i = 0
        while i < len(values):
            self.setCurrentOffsetToObjectNumber(values[i])
            value = self.readObject()
            result.append(value)
            i += 1
        return result

    def readDict(self, count):
        result = {}
        keys = self.readRefs(count)
        values = self.readRefs(count)
        i = 0
        while i < len(keys):
            self.setCurrentOffsetToObjectNumber(keys[i])
            key = self.readObject()
            self.setCurrentOffsetToObjectNumber(values[i])
            value = self.readObject()
            result[key] = value
            i += 1
        return result

    def readAsciiString(self, length):
        result = unpack("!%ds" % length, self.contents[self.currentOffset:self.currentOffset+length])[0]
        self.currentOffset += length
        return result

    def readUnicode(self, length):
        actual_length = length*2
        data = self.contents[self.currentOffset:self.currentOffset+actual_length]
        # unpack not needed?!! data = unpack(">%ds" % (actual_length), data)[0]
        self.currentOffset += actual_length
        return data.decode('utf_16_be')

    def readDate(self):
        result = unpack(">d", self.contents[self.currentOffset:self.currentOffset+8])[0]
        # Use timedelta to workaround time_t size limitation on 32-bit python.
        result = datetime.timedelta(seconds=result) + apple_reference_date
        self.currentOffset += 8
        return result

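    # Illustrative note (not part of the upstream file): a stored double of
    # 0.0 decodes to the Apple epoch itself, datetime(2001, 1, 1, 0, 0) --
    # 978307200 seconds after the Unix epoch -- and 86400.0 decodes to
    # datetime(2001, 1, 2, 0, 0).
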
    def readData(self, length):
        result = self.contents[self.currentOffset:self.currentOffset+length]
        self.currentOffset += length
        return Data(result)

    def readUid(self, length):
        return Uid(self.readInteger(length+1))

    def getSizedInteger(self, data, byteSize, as_number=False):
        """Numbers of 8 bytes are signed integers when they refer to numbers, but unsigned otherwise."""
        result = 0
        # 1, 2, and 4 byte integers are unsigned
        if byteSize == 1:
            result = unpack('>B', data)[0]
        elif byteSize == 2:
            result = unpack('>H', data)[0]
        elif byteSize == 4:
            result = unpack('>L', data)[0]
        elif byteSize == 8:
            if as_number:
                result = unpack('>q', data)[0]
            else:
                result = unpack('>Q', data)[0]
        elif byteSize <= 16:
            # Handle odd-sized or integers larger than 8 bytes
            # Don't naively go over 16 bytes, in order to prevent infinite loops.
            result = 0
            if hasattr(int, 'from_bytes'):
                result = int.from_bytes(data, 'big')
            else:
                for byte in data:
                    result = (result << 8) | unpack('>B', byte)[0]
        else:
            raise InvalidPlistException("Encountered integer longer than 16 bytes.")
        return result

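# Illustrative cases for getSizedInteger above (not part of the upstream file):
#   getSizedInteger(b'\x10', 1)                     -> 16
#   getSizedInteger(b'\xff' * 8, 8)                 -> 2**64 - 1  (unsigned read)
#   getSizedInteger(b'\xff' * 8, 8, as_number=True) -> -1         (signed read)
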
class HashableWrapper(object):
    def __init__(self, value):
        self.value = value
    def __repr__(self):
        return "<HashableWrapper: %s>" % [self.value]

class BoolWrapper(object):
    def __init__(self, value):
        self.value = value
    def __repr__(self):
        return "<BoolWrapper: %s>" % self.value

class FloatWrapper(object):
    _instances = {}
    def __new__(klass, value):
        # Ensure FloatWrapper(x) for a given float x is always the same object
        wrapper = klass._instances.get(value)
        if wrapper is None:
            wrapper = object.__new__(klass)
            wrapper.value = value
            klass._instances[value] = wrapper
        return wrapper
    def __repr__(self):
        return "<FloatWrapper: %s>" % self.value

class PlistWriter(object):
    header = b'bplist00bybiplist1.0'
    file = None
    byteCounts = None
    trailer = None
    computedUniques = None
    writtenReferences = None
    referencePositions = None
    wrappedTrue = None
    wrappedFalse = None

    def __init__(self, file):
        self.reset()
        self.file = file
        self.wrappedTrue = BoolWrapper(True)
        self.wrappedFalse = BoolWrapper(False)

    def reset(self):
        self.byteCounts = PlistByteCounts(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)
        self.trailer = PlistTrailer(0, 0, 0, 0, 0)

        # A set of all the uniques which have been computed.
        self.computedUniques = set()
        # A list of all the uniques which have been written.
        self.writtenReferences = {}
        # A dict of the positions of the written uniques.
        self.referencePositions = {}

    def positionOfObjectReference(self, obj):
        """If the given object has been written already, return its
        position in the offset table. Otherwise, return None."""
        return self.writtenReferences.get(obj)

    def writeRoot(self, root):
        """
        Strategy is:
        - write header
        - wrap root object so everything is hashable
        - compute size of objects which will be written
          - need to do this in order to know how large the object refs
            will be in the list/dict/set reference lists
        - write objects
          - keep objects in writtenReferences
          - keep positions of object references in referencePositions
          - write object references with the length computed previously
        - compute object reference length
        - write object reference positions
        - write trailer
        """
        output = self.header
        wrapped_root = self.wrapRoot(root)
        should_reference_root = True#not isinstance(wrapped_root, HashableWrapper)
        self.computeOffsets(wrapped_root, asReference=should_reference_root, isRoot=True)
        self.trailer = self.trailer._replace(**{'objectRefSize':self.intSize(len(self.computedUniques))})
        (_, output) = self.writeObjectReference(wrapped_root, output)
        output = self.writeObject(wrapped_root, output, setReferencePosition=True)

        # output size at this point is an upper bound on how big the
        # object reference offsets need to be.
        self.trailer = self.trailer._replace(**{
            'offsetSize':self.intSize(len(output)),
            'offsetCount':len(self.computedUniques),
            'offsetTableOffset':len(output),
            'topLevelObjectNumber':0
            })

        output = self.writeOffsetTable(output)
        output += pack('!xxxxxxBBQQQ', *self.trailer)
        self.file.write(output)

    def wrapRoot(self, root):
        if isinstance(root, bool):
            if root is True:
                return self.wrappedTrue
            else:
                return self.wrappedFalse
        elif isinstance(root, float):
            return FloatWrapper(root)
        elif isinstance(root, set):
            n = set()
            for value in root:
                n.add(self.wrapRoot(value))
            return HashableWrapper(n)
        elif isinstance(root, dict):
            n = {}
            for key, value in iteritems(root):
                n[self.wrapRoot(key)] = self.wrapRoot(value)
            return HashableWrapper(n)
        elif isinstance(root, list):
            n = []
            for value in root:
                n.append(self.wrapRoot(value))
            return HashableWrapper(n)
        elif isinstance(root, tuple):
            n = tuple([self.wrapRoot(value) for value in root])
            return HashableWrapper(n)
        else:
            return root

    def incrementByteCount(self, field, incr=1):
        self.byteCounts = self.byteCounts._replace(**{field:self.byteCounts.__getattribute__(field) + incr})

    def computeOffsets(self, obj, asReference=False, isRoot=False):
        def check_key(key):
            if key is None:
                raise InvalidPlistException('Dictionary keys cannot be null in plists.')
            elif isinstance(key, Data):
                raise InvalidPlistException('Data cannot be dictionary keys in plists.')
            elif not isinstance(key, (bytes, unicode)):
                raise InvalidPlistException('Keys must be strings.')

        def proc_size(size):
            if size > 0b1110:
                size += self.intSize(size)
            return size
        # If this should be a reference, then we keep a record of it in the
        # uniques table.
        if asReference:
            if obj in self.computedUniques:
                return
            else:
                self.computedUniques.add(obj)

        if obj is None:
            self.incrementByteCount('nullBytes')
        elif isinstance(obj, BoolWrapper):
            self.incrementByteCount('boolBytes')
        elif isinstance(obj, Uid):
            size = self.intSize(obj)
            self.incrementByteCount('uidBytes', incr=1+size)
        elif isinstance(obj, (int, long)):
            size = self.intSize(obj)
            self.incrementByteCount('intBytes', incr=1+size)
        elif isinstance(obj, FloatWrapper):
            size = self.realSize(obj)
            self.incrementByteCount('realBytes', incr=1+size)
        elif isinstance(obj, datetime.datetime):
            self.incrementByteCount('dateBytes', incr=2)
        elif isinstance(obj, Data):
            size = proc_size(len(obj))
            self.incrementByteCount('dataBytes', incr=1+size)
        elif isinstance(obj, (unicode, bytes)):
            size = proc_size(len(obj))
            self.incrementByteCount('stringBytes', incr=1+size)
        elif isinstance(obj, HashableWrapper):
            obj = obj.value
            if isinstance(obj, set):
                size = proc_size(len(obj))
                self.incrementByteCount('setBytes', incr=1+size)
                for value in obj:
                    self.computeOffsets(value, asReference=True)
            elif isinstance(obj, (list, tuple)):
                size = proc_size(len(obj))
                self.incrementByteCount('arrayBytes', incr=1+size)
                for value in obj:
                    asRef = True
                    self.computeOffsets(value, asReference=True)
            elif isinstance(obj, dict):
                size = proc_size(len(obj))
                self.incrementByteCount('dictBytes', incr=1+size)
                for key, value in iteritems(obj):
                    check_key(key)
                    self.computeOffsets(key, asReference=True)
                    self.computeOffsets(value, asReference=True)
        else:
            raise InvalidPlistException("Unknown object type.")

    def writeObjectReference(self, obj, output):
        """Tries to write an object reference, adding it to the references
        table. Does not write the actual object bytes or set the reference
        position. Returns a tuple of whether the object was a new reference
        (True if it was, False if it already was in the reference table)
        and the new output.
        """
        position = self.positionOfObjectReference(obj)
        if position is None:
            self.writtenReferences[obj] = len(self.writtenReferences)
            output += self.binaryInt(len(self.writtenReferences) - 1, byteSize=self.trailer.objectRefSize)
            return (True, output)
        else:
            output += self.binaryInt(position, byteSize=self.trailer.objectRefSize)
            return (False, output)

    def writeObject(self, obj, output, setReferencePosition=False):
        """Serializes the given object to the output. Returns output.
        If setReferencePosition is True, will set the position the
        object was written.
        """
        def proc_variable_length(format, length):
            result = b''
            if length > 0b1110:
                result += pack('!B', (format << 4) | 0b1111)
                result = self.writeObject(length, result)
            else:
                result += pack('!B', (format << 4) | length)
            return result

        if isinstance(obj, (str, unicode)) and obj == unicodeEmpty:
            # The Apple Plist decoder can't decode a zero length Unicode string.
            obj = b''

        if setReferencePosition:
            self.referencePositions[obj] = len(output)

        if obj is None:
            output += pack('!B', 0b00000000)
        elif isinstance(obj, BoolWrapper):
            if obj.value is False:
                output += pack('!B', 0b00001000)
            else:
                output += pack('!B', 0b00001001)
        elif isinstance(obj, Uid):
            size = self.intSize(obj)
            output += pack('!B', (0b1000 << 4) | size - 1)
            output += self.binaryInt(obj)
        elif isinstance(obj, (int, long)):
            byteSize = self.intSize(obj)
            root = math.log(byteSize, 2)
            output += pack('!B', (0b0001 << 4) | int(root))
            output += self.binaryInt(obj, as_number=True)
        elif isinstance(obj, FloatWrapper):
            # just use doubles
            output += pack('!B', (0b0010 << 4) | 3)
            output += self.binaryReal(obj)
        elif isinstance(obj, datetime.datetime):
            timestamp = (obj - apple_reference_date).total_seconds()
            output += pack('!B', 0b00110011)
            output += pack('!d', float(timestamp))
        elif isinstance(obj, Data):
            output += proc_variable_length(0b0100, len(obj))
            output += obj
        elif isinstance(obj, unicode):
            byteData = obj.encode('utf_16_be')
            output += proc_variable_length(0b0110, len(byteData)//2)
            output += byteData
        elif isinstance(obj, bytes):
            output += proc_variable_length(0b0101, len(obj))
            output += obj
        elif isinstance(obj, HashableWrapper):
            obj = obj.value
            if isinstance(obj, (set, list, tuple)):
                if isinstance(obj, set):
                    output += proc_variable_length(0b1100, len(obj))
                else:
                    output += proc_variable_length(0b1010, len(obj))

                objectsToWrite = []
                for objRef in obj:
                    (isNew, output) = self.writeObjectReference(objRef, output)
                    if isNew:
                        objectsToWrite.append(objRef)
                for objRef in objectsToWrite:
                    output = self.writeObject(objRef, output, setReferencePosition=True)
            elif isinstance(obj, dict):
                output += proc_variable_length(0b1101, len(obj))
                keys = []
                values = []
                objectsToWrite = []
                for key, value in iteritems(obj):
                    keys.append(key)
                    values.append(value)
                for key in keys:
                    (isNew, output) = self.writeObjectReference(key, output)
                    if isNew:
                        objectsToWrite.append(key)
                for value in values:
                    (isNew, output) = self.writeObjectReference(value, output)
                    if isNew:
                        objectsToWrite.append(value)
                for objRef in objectsToWrite:
                    output = self.writeObject(objRef, output, setReferencePosition=True)
        return output

    def writeOffsetTable(self, output):
        """Writes all of the object reference offsets."""
        all_positions = []
        writtenReferences = list(self.writtenReferences.items())
        writtenReferences.sort(key=lambda x: x[1])
        for obj,order in writtenReferences:
            # Porting note: Elsewhere we deliberately replace empty unicode strings
            # with empty binary strings, but the empty unicode string
            # goes into writtenReferences. This isn't an issue in Py2
            # because u'' and b'' have the same hash; but it is in
            # Py3, where they don't.
            if bytes != str and obj == unicodeEmpty:
                obj = b''
            position = self.referencePositions.get(obj)
            if position is None:
                raise InvalidPlistException("Error while writing offsets table. Object not found. %s" % obj)
            output += self.binaryInt(position, self.trailer.offsetSize)
            all_positions.append(position)
        return output

    def binaryReal(self, obj):
        # just use doubles
        result = pack('>d', obj.value)
        return result

    def binaryInt(self, obj, byteSize=None, as_number=False):
        result = b''
        if byteSize is None:
            byteSize = self.intSize(obj)
        if byteSize == 1:
            result += pack('>B', obj)
        elif byteSize == 2:
            result += pack('>H', obj)
        elif byteSize == 4:
            result += pack('>L', obj)
        elif byteSize == 8:
            if as_number:
                result += pack('>q', obj)
            else:
                result += pack('>Q', obj)
        elif byteSize <= 16:
            try:
                result = pack('>Q', 0) + pack('>Q', obj)
            except struct_error as e:
                raise InvalidPlistException("Unable to pack integer %d: %s" % (obj, e))
        else:
            raise InvalidPlistException("Core Foundation can't handle integers with size greater than 16 bytes.")
        return result

    def intSize(self, obj):
        """Returns the number of bytes necessary to store the given integer."""
        # SIGNED
        if obj < 0: # Signed integer, always 8 bytes
            return 8
        # UNSIGNED
        elif obj <= 0xFF: # 1 byte
            return 1
        elif obj <= 0xFFFF: # 2 bytes
            return 2
        elif obj <= 0xFFFFFFFF: # 4 bytes
            return 4
        # SIGNED
        # 0x7FFFFFFFFFFFFFFF is the max.
        elif obj <= 0x7FFFFFFFFFFFFFFF: # 8 bytes signed
            return 8
        elif obj <= 0xffffffffffffffff: # 8 bytes unsigned
            return 16
        else:
            raise InvalidPlistException("Core Foundation can't handle integers with size greater than 8 bytes.")

    def realSize(self, obj):
        return 8
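As a quick check of the reader/writer pair above, the module-level helpers round-trip in memory (a sketch using only names defined in this file):

    blob = writePlistToString({'name': u'example', 'count': 3})
    assert blob.startswith(b'bplist00')
    assert readPlistFromString(blob)['count'] == 3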
468
lib/bs4/__init__.py
Normal file
@@ -0,0 +1,468 @@
"""Beautiful Soup
Elixir and Tonic
"The Screen-Scraper's Friend"
http://www.crummy.com/software/BeautifulSoup/

Beautiful Soup uses a pluggable XML or HTML parser to parse a
(possibly invalid) document into a tree representation. Beautiful Soup
provides methods and Pythonic idioms that make it easy to
navigate, search, and modify the parse tree.

Beautiful Soup works with Python 2.6 and up. It works better if lxml
and/or html5lib is installed.

For more than you ever wanted to know about Beautiful Soup, see the
documentation:
http://www.crummy.com/software/BeautifulSoup/bs4/doc/
"""

__author__ = "Leonard Richardson (leonardr@segfault.org)"
__version__ = "4.4.0"
__copyright__ = "Copyright (c) 2004-2015 Leonard Richardson"
__license__ = "MIT"

__all__ = ['BeautifulSoup']

import os
import re
import warnings

from .builder import builder_registry, ParserRejectedMarkup
from .dammit import UnicodeDammit
from .element import (
    CData,
    Comment,
    DEFAULT_OUTPUT_ENCODING,
    Declaration,
    Doctype,
    NavigableString,
    PageElement,
    ProcessingInstruction,
    ResultSet,
    SoupStrainer,
    Tag,
    )

# The very first thing we do is give a useful error if someone is
# running this code under Python 3 without converting it.
'You are trying to run the Python 2 version of Beautiful Soup under Python 3. This will not work.'<>'You need to convert the code, either by installing it (`python setup.py install`) or by running 2to3 (`2to3 -w bs4`).'

class BeautifulSoup(Tag):
    """
    This class defines the basic interface called by the tree builders.

    These methods will be called by the parser:
      reset()
      feed(markup)

    The tree builder may call these methods from its feed() implementation:
      handle_starttag(name, attrs) # See note about return value
      handle_endtag(name)
      handle_data(data) # Appends to the current data node
      endData(containerClass=NavigableString) # Ends the current data node

    No matter how complicated the underlying parser is, you should be
    able to build a tree using 'start tag' events, 'end tag' events,
    'data' events, and "done with data" events.

    If you encounter an empty-element tag (aka a self-closing tag,
    like HTML's <br> tag), call handle_starttag and then
    handle_endtag.
    """
    ROOT_TAG_NAME = u'[document]'

    # If the end-user gives no indication which tree builder they
    # want, look for one with these features.
    DEFAULT_BUILDER_FEATURES = ['html', 'fast']

    ASCII_SPACES = '\x20\x0a\x09\x0c\x0d'

    NO_PARSER_SPECIFIED_WARNING = "No parser was explicitly specified, so I'm using the best available %(markup_type)s parser for this system (\"%(parser)s\"). This usually isn't a problem, but if you run this code on another system, or in a different virtual environment, it may use a different parser and behave differently.\n\nTo get rid of this warning, change this:\n\n BeautifulSoup([your markup])\n\nto this:\n\n BeautifulSoup([your markup], \"%(parser)s\")\n"

    def __init__(self, markup="", features=None, builder=None,
                 parse_only=None, from_encoding=None, exclude_encodings=None,
                 **kwargs):
        """The Soup object is initialized as the 'root tag', and the
        provided markup (which can be a string or a file-like object)
        is fed into the underlying parser."""

        if 'convertEntities' in kwargs:
            warnings.warn(
                "BS4 does not respect the convertEntities argument to the "
                "BeautifulSoup constructor. Entities are always converted "
                "to Unicode characters.")

        if 'markupMassage' in kwargs:
            del kwargs['markupMassage']
            warnings.warn(
                "BS4 does not respect the markupMassage argument to the "
                "BeautifulSoup constructor. The tree builder is responsible "
                "for any necessary markup massage.")

        if 'smartQuotesTo' in kwargs:
            del kwargs['smartQuotesTo']
            warnings.warn(
                "BS4 does not respect the smartQuotesTo argument to the "
                "BeautifulSoup constructor. Smart quotes are always converted "
                "to Unicode characters.")

        if 'selfClosingTags' in kwargs:
            del kwargs['selfClosingTags']
            warnings.warn(
                "BS4 does not respect the selfClosingTags argument to the "
                "BeautifulSoup constructor. The tree builder is responsible "
                "for understanding self-closing tags.")

        if 'isHTML' in kwargs:
            del kwargs['isHTML']
            warnings.warn(
                "BS4 does not respect the isHTML argument to the "
                "BeautifulSoup constructor. Suggest you use "
                "features='lxml' for HTML and features='lxml-xml' for "
                "XML.")

        def deprecated_argument(old_name, new_name):
            if old_name in kwargs:
                warnings.warn(
                    'The "%s" argument to the BeautifulSoup constructor '
                    'has been renamed to "%s."' % (old_name, new_name))
                value = kwargs[old_name]
                del kwargs[old_name]
                return value
            return None

        parse_only = parse_only or deprecated_argument(
            "parseOnlyThese", "parse_only")

        from_encoding = from_encoding or deprecated_argument(
            "fromEncoding", "from_encoding")

        if len(kwargs) > 0:
            arg = kwargs.keys().pop()
            raise TypeError(
                "__init__() got an unexpected keyword argument '%s'" % arg)

        if builder is None:
            original_features = features
            if isinstance(features, basestring):
                features = [features]
            if features is None or len(features) == 0:
                features = self.DEFAULT_BUILDER_FEATURES
            builder_class = builder_registry.lookup(*features)
            if builder_class is None:
                raise FeatureNotFound(
                    "Couldn't find a tree builder with the features you "
                    "requested: %s. Do you need to install a parser library?"
                    % ",".join(features))
            builder = builder_class()
            if not (original_features == builder.NAME or
                    original_features in builder.ALTERNATE_NAMES):
                if builder.is_xml:
                    markup_type = "XML"
                else:
                    markup_type = "HTML"
                warnings.warn(self.NO_PARSER_SPECIFIED_WARNING % dict(
                    parser=builder.NAME,
                    markup_type=markup_type))

        self.builder = builder
        self.is_xml = builder.is_xml
        self.builder.soup = self

        self.parse_only = parse_only

        if hasattr(markup, 'read'):  # It's a file-type object.
            markup = markup.read()
        elif len(markup) <= 256:
            # Print out warnings for a couple beginner problems
            # involving passing non-markup to Beautiful Soup.
            # Beautiful Soup will still parse the input as markup,
            # just in case that's what the user really wants.
            if (isinstance(markup, unicode)
                and not os.path.supports_unicode_filenames):
                possible_filename = markup.encode("utf8")
            else:
                possible_filename = markup
            is_file = False
            try:
                is_file = os.path.exists(possible_filename)
            except Exception, e:
                # This is almost certainly a problem involving
                # characters not valid in filenames on this
                # system. Just let it go.
                pass
            if is_file:
                if isinstance(markup, unicode):
                    markup = markup.encode("utf8")
                warnings.warn(
                    '"%s" looks like a filename, not markup. You should probably open this file and pass the filehandle into Beautiful Soup.' % markup)
            if markup[:5] == "http:" or markup[:6] == "https:":
                # TODO: This is ugly but I couldn't get it to work in
                # Python 3 otherwise.
                if ((isinstance(markup, bytes) and not b' ' in markup)
                    or (isinstance(markup, unicode) and not u' ' in markup)):
                    if isinstance(markup, unicode):
                        markup = markup.encode("utf8")
                    warnings.warn(
                        '"%s" looks like a URL. Beautiful Soup is not an HTTP client. You should probably use an HTTP client to get the document behind the URL, and feed that document to Beautiful Soup.' % markup)

        for (self.markup, self.original_encoding, self.declared_html_encoding,
             self.contains_replacement_characters) in (
             self.builder.prepare_markup(
                 markup, from_encoding, exclude_encodings=exclude_encodings)):
            self.reset()
            try:
                self._feed()
                break
            except ParserRejectedMarkup:
                pass

        # Clear out the markup and remove the builder's circular
        # reference to this object.
        self.markup = None
        self.builder.soup = None

    def __copy__(self):
        return type(self)(self.encode(), builder=self.builder)

    def __getstate__(self):
        # Frequently a tree builder can't be pickled.
        d = dict(self.__dict__)
        if 'builder' in d and not self.builder.picklable:
            del d['builder']
        return d

    def _feed(self):
        # Convert the document to Unicode.
        self.builder.reset()

        self.builder.feed(self.markup)
        # Close out any unfinished strings and close all the open tags.
        self.endData()
        while self.currentTag.name != self.ROOT_TAG_NAME:
            self.popTag()

    def reset(self):
        Tag.__init__(self, self, self.builder, self.ROOT_TAG_NAME)
        self.hidden = 1
        self.builder.reset()
        self.current_data = []
        self.currentTag = None
        self.tagStack = []
        self.preserve_whitespace_tag_stack = []
        self.pushTag(self)

    def new_tag(self, name, namespace=None, nsprefix=None, **attrs):
        """Create a new tag associated with this soup."""
        return Tag(None, self.builder, name, namespace, nsprefix, attrs)

    def new_string(self, s, subclass=NavigableString):
        """Create a new NavigableString associated with this soup."""
        return subclass(s)

    def insert_before(self, successor):
        raise NotImplementedError("BeautifulSoup objects don't support insert_before().")

    def insert_after(self, successor):
        raise NotImplementedError("BeautifulSoup objects don't support insert_after().")

    def popTag(self):
        tag = self.tagStack.pop()
        if self.preserve_whitespace_tag_stack and tag == self.preserve_whitespace_tag_stack[-1]:
            self.preserve_whitespace_tag_stack.pop()
        #print "Pop", tag.name
        if self.tagStack:
            self.currentTag = self.tagStack[-1]
        return self.currentTag

    def pushTag(self, tag):
        #print "Push", tag.name
        if self.currentTag:
            self.currentTag.contents.append(tag)
        self.tagStack.append(tag)
        self.currentTag = self.tagStack[-1]
        if tag.name in self.builder.preserve_whitespace_tags:
            self.preserve_whitespace_tag_stack.append(tag)

    def endData(self, containerClass=NavigableString):
        if self.current_data:
            current_data = u''.join(self.current_data)
            # If whitespace is not preserved, and this string contains
            # nothing but ASCII spaces, replace it with a single space
            # or newline.
            if not self.preserve_whitespace_tag_stack:
                strippable = True
                for i in current_data:
                    if i not in self.ASCII_SPACES:
                        strippable = False
                        break
                if strippable:
                    if '\n' in current_data:
                        current_data = '\n'
                    else:
                        current_data = ' '

            # Reset the data collector.
            self.current_data = []

            # Should we add this string to the tree at all?
            if self.parse_only and len(self.tagStack) <= 1 and \
                   (not self.parse_only.text or \
                    not self.parse_only.search(current_data)):
                return

            o = containerClass(current_data)
            self.object_was_parsed(o)

    def object_was_parsed(self, o, parent=None, most_recent_element=None):
        """Add an object to the parse tree."""
        parent = parent or self.currentTag
        previous_element = most_recent_element or self._most_recent_element

        next_element = previous_sibling = next_sibling = None
        if isinstance(o, Tag):
            next_element = o.next_element
            next_sibling = o.next_sibling
            previous_sibling = o.previous_sibling
            if not previous_element:
                previous_element = o.previous_element

        o.setup(parent, previous_element, next_element, previous_sibling, next_sibling)

        self._most_recent_element = o
        parent.contents.append(o)

        if parent.next_sibling:
            # This node is being inserted into an element that has
            # already been parsed. Deal with any dangling references.
            index = parent.contents.index(o)
            if index == 0:
                previous_element = parent
                previous_sibling = None
            else:
                previous_element = previous_sibling = parent.contents[index-1]
            if index == len(parent.contents)-1:
                next_element = parent.next_sibling
                next_sibling = None
            else:
                next_element = next_sibling = parent.contents[index+1]

            o.previous_element = previous_element
            if previous_element:
                previous_element.next_element = o
            o.next_element = next_element
            if next_element:
                next_element.previous_element = o
            o.next_sibling = next_sibling
            if next_sibling:
                next_sibling.previous_sibling = o
            o.previous_sibling = previous_sibling
            if previous_sibling:
                previous_sibling.next_sibling = o

    def _popToTag(self, name, nsprefix=None, inclusivePop=True):
        """Pops the tag stack up to and including the most recent
        instance of the given tag. If inclusivePop is false, pops the tag
        stack up to but *not* including the most recent instance of
        the given tag."""
        #print "Popping to %s" % name
        if name == self.ROOT_TAG_NAME:
            # The BeautifulSoup object itself can never be popped.
            return

        most_recently_popped = None

        stack_size = len(self.tagStack)
        for i in range(stack_size - 1, 0, -1):
            t = self.tagStack[i]
            if (name == t.name and nsprefix == t.prefix):
                if inclusivePop:
                    most_recently_popped = self.popTag()
                break
            most_recently_popped = self.popTag()

        return most_recently_popped

    def handle_starttag(self, name, namespace, nsprefix, attrs):
        """Push a start tag on to the stack.

        If this method returns None, the tag was rejected by the
        SoupStrainer. You should proceed as if the tag had not occured
        in the document. For instance, if this was a self-closing tag,
        don't call handle_endtag.
        """

        # print "Start tag %s: %s" % (name, attrs)
        self.endData()

        if (self.parse_only and len(self.tagStack) <= 1
            and (self.parse_only.text
                 or not self.parse_only.search_tag(name, attrs))):
            return None

        tag = Tag(self, self.builder, name, namespace, nsprefix, attrs,
                  self.currentTag, self._most_recent_element)
        if tag is None:
            return tag
        if self._most_recent_element:
            self._most_recent_element.next_element = tag
        self._most_recent_element = tag
        self.pushTag(tag)
        return tag

    def handle_endtag(self, name, nsprefix=None):
        #print "End tag: " + name
        self.endData()
        self._popToTag(name, nsprefix)

    def handle_data(self, data):
        self.current_data.append(data)

    def decode(self, pretty_print=False,
               eventual_encoding=DEFAULT_OUTPUT_ENCODING,
               formatter="minimal"):
        """Returns a string or Unicode representation of this document.
        To get Unicode, pass None for encoding."""

        if self.is_xml:
            # Print the XML declaration
            encoding_part = ''
            if eventual_encoding != None:
                encoding_part = ' encoding="%s"' % eventual_encoding
            prefix = u'<?xml version="1.0"%s?>\n' % encoding_part
        else:
            prefix = u''
        if not pretty_print:
            indent_level = None
        else:
            indent_level = 0
        return prefix + super(BeautifulSoup, self).decode(
            indent_level, eventual_encoding, formatter)

# Alias to make it easier to type import: 'from bs4 import _soup'
_s = BeautifulSoup
_soup = BeautifulSoup

class BeautifulStoneSoup(BeautifulSoup):
    """Deprecated interface to an XML parser."""

    def __init__(self, *args, **kwargs):
        kwargs['features'] = 'xml'
        warnings.warn(
            'The BeautifulStoneSoup class is deprecated. Instead of using '
            'it, pass features="xml" into the BeautifulSoup constructor.')
        super(BeautifulStoneSoup, self).__init__(*args, **kwargs)

class StopParsing(Exception):
    pass

class FeatureNotFound(ValueError):
    pass

#By default, act as an HTML pretty-printer.
if __name__ == '__main__':
    import sys
    soup = BeautifulSoup(sys.stdin)
    print soup.prettify()
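A minimal use of the class above (a sketch in the same Python 2 idiom as the file; 'html.parser' selects the builder registered in bs4.builder._htmlparser):

    from bs4 import BeautifulSoup
    soup = BeautifulSoup('<p class="a b">hi</p>', 'html.parser')
    print soup.p['class']    # ['a', 'b'] -- split via cdata_list_attributes
    print soup.p.get_text()  # hi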
324
lib/bs4/builder/__init__.py
Normal file
@@ -0,0 +1,324 @@
|
||||
from collections import defaultdict
|
||||
import itertools
|
||||
import sys
|
||||
from bs4.element import (
|
||||
CharsetMetaAttributeValue,
|
||||
ContentMetaAttributeValue,
|
||||
whitespace_re
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'HTMLTreeBuilder',
|
||||
'SAXTreeBuilder',
|
||||
'TreeBuilder',
|
||||
'TreeBuilderRegistry',
|
||||
]
|
||||
|
||||
# Some useful features for a TreeBuilder to have.
|
||||
FAST = 'fast'
|
||||
PERMISSIVE = 'permissive'
|
||||
STRICT = 'strict'
|
||||
XML = 'xml'
|
||||
HTML = 'html'
|
||||
HTML_5 = 'html5'
|
||||
|
||||
|
||||
class TreeBuilderRegistry(object):
|
||||
|
||||
def __init__(self):
|
||||
self.builders_for_feature = defaultdict(list)
|
||||
self.builders = []
|
||||
|
||||
def register(self, treebuilder_class):
|
||||
"""Register a treebuilder based on its advertised features."""
|
||||
for feature in treebuilder_class.features:
|
||||
self.builders_for_feature[feature].insert(0, treebuilder_class)
|
||||
self.builders.insert(0, treebuilder_class)
|
||||
|
||||
def lookup(self, *features):
|
||||
if len(self.builders) == 0:
|
||||
# There are no builders at all.
|
||||
return None
|
||||
|
||||
if len(features) == 0:
|
||||
# They didn't ask for any features. Give them the most
|
||||
# recently registered builder.
|
||||
return self.builders[0]
|
||||
|
||||
# Go down the list of features in order, and eliminate any builders
|
||||
# that don't match every feature.
|
||||
features = list(features)
|
||||
features.reverse()
|
||||
candidates = None
|
||||
candidate_set = None
|
||||
while len(features) > 0:
|
||||
feature = features.pop()
|
||||
we_have_the_feature = self.builders_for_feature.get(feature, [])
|
||||
if len(we_have_the_feature) > 0:
|
||||
if candidates is None:
|
||||
candidates = we_have_the_feature
|
||||
candidate_set = set(candidates)
|
||||
else:
|
||||
# Eliminate any candidates that don't have this feature.
|
||||
candidate_set = candidate_set.intersection(
|
||||
set(we_have_the_feature))
|
||||
|
||||
# The only valid candidates are the ones in candidate_set.
|
||||
# Go through the original list of candidates and pick the first one
|
||||
# that's in candidate_set.
|
||||
if candidate_set is None:
|
||||
return None
|
||||
for candidate in candidates:
|
||||
if candidate in candidate_set:
|
||||
return candidate
|
||||
return None
|
||||
|
||||
# The BeautifulSoup class will take feature lists from developers and use them
|
||||
# to look up builders in this registry.
|
||||
builder_registry = TreeBuilderRegistry()
|
||||
|
||||
class TreeBuilder(object):
|
||||
"""Turn a document into a Beautiful Soup object tree."""
|
||||
|
||||
NAME = "[Unknown tree builder]"
|
||||
ALTERNATE_NAMES = []
|
||||
features = []
|
||||
|
||||
is_xml = False
|
||||
picklable = False
|
||||
preserve_whitespace_tags = set()
|
||||
empty_element_tags = None # A tag will be considered an empty-element
|
||||
# tag when and only when it has no contents.
|
||||
|
||||
# A value for these tag/attribute combinations is a space- or
|
||||
# comma-separated list of CDATA, rather than a single CDATA.
|
||||
cdata_list_attributes = {}
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.soup = None
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
def can_be_empty_element(self, tag_name):
|
||||
"""Might a tag with this name be an empty-element tag?
|
||||
|
||||
The final markup may or may not actually present this tag as
|
||||
self-closing.
|
||||
|
||||
For instance: an HTMLBuilder does not consider a <p> tag to be
|
||||
an empty-element tag (it's not in
|
||||
HTMLBuilder.empty_element_tags). This means an empty <p> tag
|
||||
will be presented as "<p></p>", not "<p />".
|
||||
|
||||
The default implementation has no opinion about which tags are
|
||||
empty-element tags, so a tag will be presented as an
|
||||
        empty-element tag if and only if it has no contents.

        "<foo></foo>" will become "<foo />", and "<foo>bar</foo>" will
        be left alone.
        """
        if self.empty_element_tags is None:
            return True
        return tag_name in self.empty_element_tags

    def feed(self, markup):
        raise NotImplementedError()

    def prepare_markup(self, markup, user_specified_encoding=None,
                       document_declared_encoding=None):
        return markup, None, None, False

    def test_fragment_to_document(self, fragment):
        """Wrap an HTML fragment to make it look like a document.

        Different parsers do this differently. For instance, lxml
        introduces an empty <head> tag, and html5lib
        doesn't. Abstracting this away lets us write simple tests
        which run HTML fragments through the parser and compare the
        results against other HTML fragments.

        This method should not be used outside of tests.
        """
        return fragment

    def set_up_substitutions(self, tag):
        return False

    def _replace_cdata_list_attribute_values(self, tag_name, attrs):
        """Replaces class="foo bar" with class=["foo", "bar"]

        Modifies its input in place.
        """
        if not attrs:
            return attrs
        if self.cdata_list_attributes:
            universal = self.cdata_list_attributes.get('*', [])
            tag_specific = self.cdata_list_attributes.get(
                tag_name.lower(), None)
            for attr in attrs.keys():
                if attr in universal or (tag_specific and attr in tag_specific):
                    # We have a "class"-type attribute whose string
                    # value is a whitespace-separated list of
                    # values. Split it into a list.
                    value = attrs[attr]
                    if isinstance(value, basestring):
                        values = whitespace_re.split(value)
                    else:
                        # html5lib sometimes calls setAttributes twice
                        # for the same tag when rearranging the parse
                        # tree. On the second call the attribute value
                        # here is already a list. If this happens,
                        # leave the value alone rather than trying to
                        # split it again.
                        values = value
                    attrs[attr] = values
        return attrs


class SAXTreeBuilder(TreeBuilder):
    """A Beautiful Soup treebuilder that listens for SAX events."""

    def feed(self, markup):
        raise NotImplementedError()

    def close(self):
        pass

    def startElement(self, name, attrs):
        attrs = dict((key[1], value) for key, value in list(attrs.items()))
        #print "Start %s, %r" % (name, attrs)
        self.soup.handle_starttag(name, attrs)

    def endElement(self, name):
        #print "End %s" % name
        self.soup.handle_endtag(name)

    def startElementNS(self, nsTuple, nodeName, attrs):
        # Throw away (ns, nodeName) for now.
        self.startElement(nodeName, attrs)

    def endElementNS(self, nsTuple, nodeName):
        # Throw away (ns, nodeName) for now.
        self.endElement(nodeName)
        #handler.endElementNS((ns, node.nodeName), node.nodeName)

    def startPrefixMapping(self, prefix, nodeValue):
        # Ignore the prefix for now.
        pass

    def endPrefixMapping(self, prefix):
        # Ignore the prefix for now.
        # handler.endPrefixMapping(prefix)
        pass

    def characters(self, content):
        self.soup.handle_data(content)

    def startDocument(self):
        pass

    def endDocument(self):
        pass


class HTMLTreeBuilder(TreeBuilder):
    """This TreeBuilder knows facts about HTML.

    Such as which tags are empty-element tags.
    """

    preserve_whitespace_tags = set(['pre', 'textarea'])
    empty_element_tags = set(['br', 'hr', 'input', 'img', 'meta',
                              'spacer', 'link', 'frame', 'base'])

    # The HTML standard defines these attributes as containing a
    # space-separated list of values, not a single value. That is,
    # class="foo bar" means that the 'class' attribute has two values,
    # 'foo' and 'bar', not the single value 'foo bar'. When we
    # encounter one of these attributes, we will parse its value into
    # a list of values if possible. Upon output, the list will be
    # converted back into a string.
    cdata_list_attributes = {
        "*" : ['class', 'accesskey', 'dropzone'],
        "a" : ['rel', 'rev'],
        "link" : ['rel', 'rev'],
        "td" : ["headers"],
        "th" : ["headers"],
        "form" : ["accept-charset"],
        "object" : ["archive"],

        # These are HTML5 specific, as are *.accesskey and *.dropzone above.
        "area" : ["rel"],
        "icon" : ["sizes"],
        "iframe" : ["sandbox"],
        "output" : ["for"],
        }
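
    # For example: with the mapping above, markup like
    # <p class="foo bar" id="x"> parses so that tag['class'] == ['foo', 'bar']
    # while tag['id'] stays the plain string 'x'; on output the list is
    # re-joined into class="foo bar".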

    def set_up_substitutions(self, tag):
        # We are only interested in <meta> tags
        if tag.name != 'meta':
            return False

        http_equiv = tag.get('http-equiv')
        content = tag.get('content')
        charset = tag.get('charset')

        # We are interested in <meta> tags that say what encoding the
        # document was originally in. This means HTML 5-style <meta>
        # tags that provide the "charset" attribute. It also means
        # HTML 4-style <meta> tags that provide the "content"
        # attribute and have "http-equiv" set to "content-type".
        #
        # In both cases we will replace the value of the appropriate
        # attribute with a standin object that can take on any
        # encoding.
        meta_encoding = None
        if charset is not None:
            # HTML 5 style:
            # <meta charset="utf8">
            meta_encoding = charset
            tag['charset'] = CharsetMetaAttributeValue(charset)

        elif (content is not None and http_equiv is not None
              and http_equiv.lower() == 'content-type'):
            # HTML 4 style:
            # <meta http-equiv="content-type" content="text/html; charset=utf8">
            tag['content'] = ContentMetaAttributeValue(content)

        return (meta_encoding is not None)
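
    # For example: after this runs on <meta charset="utf8">, the charset
    # value is a CharsetMetaAttributeValue standin, so re-encoding the tree
    # to another encoding rewrites the declared charset to match the output.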


def register_treebuilders_from(module):
    """Copy TreeBuilders from the given module into this module."""
    # I'm fairly sure this is not the best way to do this.
    this_module = sys.modules['bs4.builder']
    for name in module.__all__:
        obj = getattr(module, name)

        if issubclass(obj, TreeBuilder):
            setattr(this_module, name, obj)
            this_module.__all__.append(name)
            # Register the builder while we're at it.
            this_module.builder_registry.register(obj)

class ParserRejectedMarkup(Exception):
    pass

# Builders are registered in reverse order of priority, so that custom
# builder registrations will take precedence. In general, we want lxml
# to take precedence over html5lib, because it's faster. And we only
# want to use HTMLParser as a last resort.
from . import _htmlparser
register_treebuilders_from(_htmlparser)
try:
    from . import _html5lib
    register_treebuilders_from(_html5lib)
except ImportError:
    # They don't have html5lib installed.
    pass
try:
    from . import _lxml
    register_treebuilders_from(_lxml)
except ImportError:
    # They don't have lxml installed.
    pass
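
# For example: because lxml is registered last, builder_registry.lookup('html')
# prefers the lxml builder when lxml is importable, then html5lib, and only
# falls back to the html.parser builder registered first above.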
329
lib/bs4/builder/_html5lib.py
Normal file
@@ -0,0 +1,329 @@
__all__ = [
    'HTML5TreeBuilder',
    ]

from pdb import set_trace
import warnings
from bs4.builder import (
    PERMISSIVE,
    HTML,
    HTML_5,
    HTMLTreeBuilder,
    )
from bs4.element import (
    NamespacedAttribute,
    whitespace_re,
    )
import html5lib
from html5lib.constants import namespaces
from bs4.element import (
    Comment,
    Doctype,
    NavigableString,
    Tag,
    )

class HTML5TreeBuilder(HTMLTreeBuilder):
    """Use html5lib to build a tree."""

    NAME = "html5lib"

    features = [NAME, PERMISSIVE, HTML_5, HTML]

    def prepare_markup(self, markup, user_specified_encoding,
                       document_declared_encoding=None, exclude_encodings=None):
        # Store the user-specified encoding for use later on.
        self.user_specified_encoding = user_specified_encoding

        # document_declared_encoding and exclude_encodings aren't used
        # ATM because the html5lib TreeBuilder doesn't use
        # UnicodeDammit.
        if exclude_encodings:
            warnings.warn("You provided a value for exclude_encodings, but the html5lib tree builder doesn't support exclude_encodings.")
        yield (markup, None, None, False)

    # These methods are defined by Beautiful Soup.
    def feed(self, markup):
        if self.soup.parse_only is not None:
            warnings.warn("You provided a value for parse_only, but the html5lib tree builder doesn't support parse_only. The entire document will be parsed.")
        parser = html5lib.HTMLParser(tree=self.create_treebuilder)
        doc = parser.parse(markup, encoding=self.user_specified_encoding)

        # Set the character encoding detected by the tokenizer.
        if isinstance(markup, unicode):
            # We need to special-case this because html5lib sets
            # charEncoding to UTF-8 if it gets Unicode input.
            doc.original_encoding = None
        else:
            doc.original_encoding = parser.tokenizer.stream.charEncoding[0]

    def create_treebuilder(self, namespaceHTMLElements):
        self.underlying_builder = TreeBuilderForHtml5lib(
            self.soup, namespaceHTMLElements)
        return self.underlying_builder

    def test_fragment_to_document(self, fragment):
        """See `TreeBuilder`."""
        return u'<html><head></head><body>%s</body></html>' % fragment


class TreeBuilderForHtml5lib(html5lib.treebuilders._base.TreeBuilder):

    def __init__(self, soup, namespaceHTMLElements):
        self.soup = soup
        super(TreeBuilderForHtml5lib, self).__init__(namespaceHTMLElements)

    def documentClass(self):
        self.soup.reset()
        return Element(self.soup, self.soup, None)

    def insertDoctype(self, token):
        name = token["name"]
        publicId = token["publicId"]
        systemId = token["systemId"]

        doctype = Doctype.for_name_and_ids(name, publicId, systemId)
        self.soup.object_was_parsed(doctype)

    def elementClass(self, name, namespace):
        tag = self.soup.new_tag(name, namespace)
        return Element(tag, self.soup, namespace)

    def commentClass(self, data):
        return TextNode(Comment(data), self.soup)

    def fragmentClass(self):
        # Imported here to avoid a circular import at module load time.
        from bs4 import BeautifulSoup
        self.soup = BeautifulSoup("")
        self.soup.name = "[document_fragment]"
        return Element(self.soup, self.soup, None)

    def appendChild(self, node):
        # XXX This code is not covered by the BS4 tests.
        self.soup.append(node.element)

    def getDocument(self):
        return self.soup

    def getFragment(self):
        return html5lib.treebuilders._base.TreeBuilder.getFragment(self).element

class AttrList(object):
    def __init__(self, element):
        self.element = element
        self.attrs = dict(self.element.attrs)
    def __iter__(self):
        return list(self.attrs.items()).__iter__()
    def __setitem__(self, name, value):
        # If this attribute is a multi-valued attribute for this element,
        # turn its value into a list.
        list_attr = HTML5TreeBuilder.cdata_list_attributes
        if (name in list_attr['*']
            or (self.element.name in list_attr
                and name in list_attr[self.element.name])):
            value = whitespace_re.split(value)
        self.element[name] = value
    def items(self):
        return list(self.attrs.items())
    def keys(self):
        return list(self.attrs.keys())
    def __len__(self):
        return len(self.attrs)
    def __getitem__(self, name):
        return self.attrs[name]
    def __contains__(self, name):
        return name in list(self.attrs.keys())


class Element(html5lib.treebuilders._base.Node):
    def __init__(self, element, soup, namespace):
        html5lib.treebuilders._base.Node.__init__(self, element.name)
        self.element = element
        self.soup = soup
        self.namespace = namespace

    def appendChild(self, node):
        string_child = child = None
        if isinstance(node, basestring):
            # Some other piece of code decided to pass in a string
            # instead of creating a TextElement object to contain the
            # string.
            string_child = child = node
        elif isinstance(node, Tag):
            # Some other piece of code decided to pass in a Tag
            # instead of creating an Element object to contain the
            # Tag.
            child = node
        elif node.element.__class__ == NavigableString:
            string_child = child = node.element
        else:
            child = node.element

        if not isinstance(child, basestring) and child.parent is not None:
            node.element.extract()

        if (string_child and self.element.contents
            and self.element.contents[-1].__class__ == NavigableString):
            # We are appending a string onto another string.
            # TODO This has O(n^2) performance, for input like
            # "a</a>a</a>a</a>..."
            old_element = self.element.contents[-1]
            new_element = self.soup.new_string(old_element + string_child)
            old_element.replace_with(new_element)
            self.soup._most_recent_element = new_element
        else:
            if isinstance(node, basestring):
                # Create a brand new NavigableString from this string.
                child = self.soup.new_string(node)

            # Tell Beautiful Soup to act as if it parsed this element
            # immediately after the parent's last descendant. (Or
            # immediately after the parent, if it has no children.)
            if self.element.contents:
                most_recent_element = self.element._last_descendant(False)
            elif self.element.next_element is not None:
                # Something from further ahead in the parse tree is
                # being inserted into this earlier element. This is
                # very annoying because it means an expensive search
                # for the last element in the tree.
                most_recent_element = self.soup._last_descendant()
            else:
                most_recent_element = self.element

            self.soup.object_was_parsed(
                child, parent=self.element,
                most_recent_element=most_recent_element)

    def getAttributes(self):
        return AttrList(self.element)

    def setAttributes(self, attributes):
        if attributes is not None and len(attributes) > 0:
            for name, value in list(attributes.items()):
                if isinstance(name, tuple):
                    new_name = NamespacedAttribute(*name)
                    del attributes[name]
                    attributes[new_name] = value

            self.soup.builder._replace_cdata_list_attribute_values(
                self.name, attributes)
            for name, value in attributes.items():
                self.element[name] = value

            # The attributes may contain variables that need substitution.
            # Call set_up_substitutions manually.
            #
            # The Tag constructor called this method when the Tag was created,
            # but we just set/changed the attributes, so call it again.
            self.soup.builder.set_up_substitutions(self.element)
    attributes = property(getAttributes, setAttributes)

    def insertText(self, data, insertBefore=None):
        if insertBefore:
            text = TextNode(self.soup.new_string(data), self.soup)
            self.insertBefore(text, insertBefore)
        else:
            self.appendChild(data)

    def insertBefore(self, node, refNode):
        index = self.element.index(refNode.element)
        if (node.element.__class__ == NavigableString and self.element.contents
            and self.element.contents[index-1].__class__ == NavigableString):
            # (See comments in appendChild)
            old_node = self.element.contents[index-1]
            new_str = self.soup.new_string(old_node + node.element)
            old_node.replace_with(new_str)
        else:
            self.element.insert(index, node.element)
            node.parent = self

    def removeChild(self, node):
        node.element.extract()

    def reparentChildren(self, new_parent):
        """Move all of this tag's children into another tag."""
        # print "MOVE", self.element.contents
        # print "FROM", self.element
        # print "TO", new_parent.element
        element = self.element
        new_parent_element = new_parent.element
        # Determine what this tag's next_element will be once all the children
        # are removed.
        final_next_element = element.next_sibling

        new_parents_last_descendant = new_parent_element._last_descendant(False, False)
        if len(new_parent_element.contents) > 0:
            # The new parent already contains children. We will be
            # appending this tag's children to the end.
            new_parents_last_child = new_parent_element.contents[-1]
            new_parents_last_descendant_next_element = new_parents_last_descendant.next_element
        else:
            # The new parent contains no children.
            new_parents_last_child = None
            new_parents_last_descendant_next_element = new_parent_element.next_element

        to_append = element.contents
        append_after = new_parent_element.contents
        if len(to_append) > 0:
            # Set the first child's previous_element and previous_sibling
            # to elements within the new parent
            first_child = to_append[0]
            if new_parents_last_descendant:
                first_child.previous_element = new_parents_last_descendant
            else:
                first_child.previous_element = new_parent_element
            first_child.previous_sibling = new_parents_last_child
            if new_parents_last_descendant:
                new_parents_last_descendant.next_element = first_child
            else:
                new_parent_element.next_element = first_child
            if new_parents_last_child:
                new_parents_last_child.next_sibling = first_child

            # Fix the last child's next_element and next_sibling
            last_child = to_append[-1]
            last_child.next_element = new_parents_last_descendant_next_element
            if new_parents_last_descendant_next_element:
                new_parents_last_descendant_next_element.previous_element = last_child
            last_child.next_sibling = None

        for child in to_append:
            child.parent = new_parent_element
            new_parent_element.contents.append(child)

        # Now that this element has no children, change its .next_element.
        element.contents = []
        element.next_element = final_next_element

        # print "DONE WITH MOVE"
        # print "FROM", self.element
        # print "TO", new_parent_element

    def cloneNode(self):
        tag = self.soup.new_tag(self.element.name, self.namespace)
        node = Element(tag, self.soup, self.namespace)
        for key, value in self.attributes:
            node.attributes[key] = value
        return node

    def hasContent(self):
        return self.element.contents

    def getNameTuple(self):
        if self.namespace is None:
            return namespaces["html"], self.name
        else:
            return self.namespace, self.name

    nameTuple = property(getNameTuple)

class TextNode(Element):
    def __init__(self, element, soup):
        html5lib.treebuilders._base.Node.__init__(self, None)
        self.element = element
        self.soup = soup

    def cloneNode(self):
        raise NotImplementedError
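
# Usage sketch: BeautifulSoup(markup, "html5lib") selects HTML5TreeBuilder
# via the feature names above; html5lib then drives TreeBuilderForHtml5lib
# to assemble the soup's parse tree.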
262
lib/bs4/builder/_htmlparser.py
Normal file
@@ -0,0 +1,262 @@
"""Use the HTMLParser library to parse HTML files that aren't too bad."""

__all__ = [
    'HTMLParserTreeBuilder',
    ]

from HTMLParser import HTMLParser

try:
    from HTMLParser import HTMLParseError
except ImportError, e:
    # HTMLParseError is removed in Python 3.5. Since it can never be
    # thrown in 3.5, we can just define our own class as a placeholder.
    class HTMLParseError(Exception):
        pass

import sys
import warnings

# Starting in Python 3.2, the HTMLParser constructor takes a 'strict'
# argument, which we'd like to set to False. Unfortunately,
# http://bugs.python.org/issue13273 makes strict=True a better bet
# before Python 3.2.3.
#
# At the end of this file, we monkeypatch HTMLParser so that
# strict=True works well on Python 3.2.2.
major, minor, release = sys.version_info[:3]
CONSTRUCTOR_TAKES_STRICT = major == 3 and minor == 2 and release >= 3
CONSTRUCTOR_STRICT_IS_DEPRECATED = major == 3 and minor == 3
CONSTRUCTOR_TAKES_CONVERT_CHARREFS = major == 3 and minor >= 4


from bs4.element import (
    CData,
    Comment,
    Declaration,
    Doctype,
    ProcessingInstruction,
    )
from bs4.dammit import EntitySubstitution, UnicodeDammit

from bs4.builder import (
    HTML,
    HTMLTreeBuilder,
    STRICT,
    )


HTMLPARSER = 'html.parser'

class BeautifulSoupHTMLParser(HTMLParser):
    def handle_starttag(self, name, attrs):
        # XXX namespace
        attr_dict = {}
        for key, value in attrs:
            # Change None attribute values to the empty string
            # for consistency with the other tree builders.
            if value is None:
                value = ''
            attr_dict[key] = value
        self.soup.handle_starttag(name, None, None, attr_dict)

    def handle_endtag(self, name):
        self.soup.handle_endtag(name)

    def handle_data(self, data):
        self.soup.handle_data(data)

    def handle_charref(self, name):
        # XXX workaround for a bug in HTMLParser. Remove this once
        # it's fixed in all supported versions.
        # http://bugs.python.org/issue13633
        if name.startswith('x'):
            real_name = int(name.lstrip('x'), 16)
        elif name.startswith('X'):
            real_name = int(name.lstrip('X'), 16)
        else:
            real_name = int(name)

        try:
            data = unichr(real_name)
        except (ValueError, OverflowError), e:
            data = u"\N{REPLACEMENT CHARACTER}"

        self.handle_data(data)
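
    # For example: handle_charref('xe9') decodes to u'\xe9' (LATIN SMALL
    # LETTER E WITH ACUTE), while a value outside the Unicode range falls
    # back to U+FFFD REPLACEMENT CHARACTER.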

    def handle_entityref(self, name):
        character = EntitySubstitution.HTML_ENTITY_TO_CHARACTER.get(name)
        if character is not None:
            data = character
        else:
            data = "&%s;" % name
        self.handle_data(data)

    def handle_comment(self, data):
        self.soup.endData()
        self.soup.handle_data(data)
        self.soup.endData(Comment)

    def handle_decl(self, data):
        self.soup.endData()
        if data.startswith("DOCTYPE "):
            data = data[len("DOCTYPE "):]
        elif data == 'DOCTYPE':
            # i.e. "<!DOCTYPE>"
            data = ''
        self.soup.handle_data(data)
        self.soup.endData(Doctype)

    def unknown_decl(self, data):
        if data.upper().startswith('CDATA['):
            cls = CData
            data = data[len('CDATA['):]
        else:
            cls = Declaration
        self.soup.endData()
        self.soup.handle_data(data)
        self.soup.endData(cls)

    def handle_pi(self, data):
        self.soup.endData()
        self.soup.handle_data(data)
        self.soup.endData(ProcessingInstruction)


class HTMLParserTreeBuilder(HTMLTreeBuilder):

    is_xml = False
    picklable = True
    NAME = HTMLPARSER
    features = [NAME, HTML, STRICT]

    def __init__(self, *args, **kwargs):
        if CONSTRUCTOR_TAKES_STRICT and not CONSTRUCTOR_STRICT_IS_DEPRECATED:
            kwargs['strict'] = False
        if CONSTRUCTOR_TAKES_CONVERT_CHARREFS:
            kwargs['convert_charrefs'] = False
        self.parser_args = (args, kwargs)

    def prepare_markup(self, markup, user_specified_encoding=None,
                       document_declared_encoding=None, exclude_encodings=None):
        """
        :yield: A 4-tuple (markup, original encoding, encoding
           declared within markup, whether any characters had to be
           replaced with REPLACEMENT CHARACTER).
        """
        if isinstance(markup, unicode):
            yield (markup, None, None, False)
            return

        try_encodings = [user_specified_encoding, document_declared_encoding]
        dammit = UnicodeDammit(markup, try_encodings, is_html=True,
                               exclude_encodings=exclude_encodings)
        yield (dammit.markup, dammit.original_encoding,
               dammit.declared_html_encoding,
               dammit.contains_replacement_characters)
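
    # For example: byte input b'<p>caf\xc3\xa9</p>' with no overrides
    # typically yields (u'<p>caf\xe9</p>', 'utf-8', None, False) once
    # UnicodeDammit settles on UTF-8.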

    def feed(self, markup):
        args, kwargs = self.parser_args
        parser = BeautifulSoupHTMLParser(*args, **kwargs)
        parser.soup = self.soup
        try:
            parser.feed(markup)
        except HTMLParseError, e:
            warnings.warn(RuntimeWarning(
                "Python's built-in HTMLParser cannot parse the given document. This is not a bug in Beautiful Soup. The best solution is to install an external parser (lxml or html5lib), and use Beautiful Soup with that parser. See http://www.crummy.com/software/BeautifulSoup/bs4/doc/#installing-a-parser for help."))
            raise e

# Patch 3.2 versions of HTMLParser earlier than 3.2.3 to use some
# 3.2.3 code. This ensures they don't treat markup like <p></p> as a
# string.
#
# XXX This code can be removed once most Python 3 users are on 3.2.3.
if major == 3 and minor == 2 and not CONSTRUCTOR_TAKES_STRICT:
    import re
    attrfind_tolerant = re.compile(
        r'\s*((?<=[\'"\s])[^\s/>][^\s/=>]*)(\s*=+\s*'
        r'(\'[^\']*\'|"[^"]*"|(?![\'"])[^>\s]*))?')
    HTMLParserTreeBuilder.attrfind_tolerant = attrfind_tolerant

    locatestarttagend = re.compile(r"""
  <[a-zA-Z][-.a-zA-Z0-9:_]*          # tag name
  (?:\s+                             # whitespace before attribute name
    (?:[a-zA-Z_][-.:a-zA-Z0-9_]*     # attribute name
      (?:\s*=\s*                     # value indicator
        (?:'[^']*'                   # LITA-enclosed value
          |\"[^\"]*\"                # LIT-enclosed value
          |[^'\">\s]+                # bare value
         )
       )?
     )
   )*
  \s*                                # trailing whitespace
""", re.VERBOSE)
    BeautifulSoupHTMLParser.locatestarttagend = locatestarttagend

    from html.parser import tagfind, attrfind

    def parse_starttag(self, i):
        self.__starttag_text = None
        endpos = self.check_for_whole_start_tag(i)
        if endpos < 0:
            return endpos
        rawdata = self.rawdata
        self.__starttag_text = rawdata[i:endpos]

        # Now parse the data between i+1 and j into a tag and attrs
        attrs = []
        match = tagfind.match(rawdata, i+1)
        assert match, 'unexpected call to parse_starttag()'
        k = match.end()
        self.lasttag = tag = rawdata[i+1:k].lower()
        while k < endpos:
            if self.strict:
                m = attrfind.match(rawdata, k)
            else:
                m = attrfind_tolerant.match(rawdata, k)
            if not m:
                break
            attrname, rest, attrvalue = m.group(1, 2, 3)
            if not rest:
                attrvalue = None
            elif attrvalue[:1] == '\'' == attrvalue[-1:] or \
                 attrvalue[:1] == '"' == attrvalue[-1:]:
                attrvalue = attrvalue[1:-1]
            if attrvalue:
                attrvalue = self.unescape(attrvalue)
            attrs.append((attrname.lower(), attrvalue))
            k = m.end()

        end = rawdata[k:endpos].strip()
        if end not in (">", "/>"):
            lineno, offset = self.getpos()
            if "\n" in self.__starttag_text:
                lineno = lineno + self.__starttag_text.count("\n")
                offset = len(self.__starttag_text) \
                         - self.__starttag_text.rfind("\n")
            else:
                offset = offset + len(self.__starttag_text)
            if self.strict:
                self.error("junk characters in start tag: %r"
                           % (rawdata[k:endpos][:20],))
            self.handle_data(rawdata[i:endpos])
            return endpos
        if end.endswith('/>'):
            # XHTML-style empty tag: <span attr="value" />
            self.handle_startendtag(tag, attrs)
        else:
            self.handle_starttag(tag, attrs)
            if tag in self.CDATA_CONTENT_ELEMENTS:
                self.set_cdata_mode(tag)
        return endpos

    def set_cdata_mode(self, elem):
        self.cdata_elem = elem.lower()
        self.interesting = re.compile(r'</\s*%s\s*>' % self.cdata_elem, re.I)

    BeautifulSoupHTMLParser.parse_starttag = parse_starttag
    BeautifulSoupHTMLParser.set_cdata_mode = set_cdata_mode

    CONSTRUCTOR_TAKES_STRICT = True
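
# Usage sketch: BeautifulSoup(markup, "html.parser") selects this builder
# through the HTMLPARSER feature name above; it needs no third-party parser.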
248
lib/bs4/builder/_lxml.py
Normal file
@@ -0,0 +1,248 @@
__all__ = [
    'LXMLTreeBuilderForXML',
    'LXMLTreeBuilder',
    ]

from io import BytesIO
from StringIO import StringIO
import collections
from lxml import etree
from bs4.element import (
    Comment,
    Doctype,
    NamespacedAttribute,
    ProcessingInstruction,
)
from bs4.builder import (
    FAST,
    HTML,
    HTMLTreeBuilder,
    PERMISSIVE,
    ParserRejectedMarkup,
    TreeBuilder,
    XML)
from bs4.dammit import EncodingDetector

LXML = 'lxml'

class LXMLTreeBuilderForXML(TreeBuilder):
    DEFAULT_PARSER_CLASS = etree.XMLParser

    is_xml = True

    NAME = "lxml-xml"
    ALTERNATE_NAMES = ["xml"]

    # Well, it's permissive by XML parser standards.
    features = [NAME, LXML, XML, FAST, PERMISSIVE]

    CHUNK_SIZE = 512

    # This namespace mapping is specified in the XML Namespace
    # standard.
    DEFAULT_NSMAPS = {'http://www.w3.org/XML/1998/namespace' : "xml"}

    def default_parser(self, encoding):
        # This can either return a parser object or a class, which
        # will be instantiated with default arguments.
        if self._default_parser is not None:
            return self._default_parser
        return etree.XMLParser(
            target=self, strip_cdata=False, recover=True, encoding=encoding)

    def parser_for(self, encoding):
        # Use the default parser.
        parser = self.default_parser(encoding)

        if isinstance(parser, collections.Callable):
            # Instantiate the parser with default arguments
            parser = parser(target=self, strip_cdata=False, encoding=encoding)
        return parser

    def __init__(self, parser=None, empty_element_tags=None):
        # TODO: Issue a warning if parser is present but not a
        # callable, since that means there's no way to create new
        # parsers for different encodings.
        self._default_parser = parser
        if empty_element_tags is not None:
            self.empty_element_tags = set(empty_element_tags)
        self.soup = None
        self.nsmaps = [self.DEFAULT_NSMAPS]

    def _getNsTag(self, tag):
        # Split the namespace URL out of a fully-qualified lxml tag
        # name. Copied from lxml's src/lxml/sax.py.
        if tag[0] == '{':
            return tuple(tag[1:].split('}', 1))
        else:
            return (None, tag)
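
    # For example: _getNsTag('{http://www.w3.org/1999/xhtml}p') returns
    # ('http://www.w3.org/1999/xhtml', 'p'); a bare 'p' returns (None, 'p').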

    def prepare_markup(self, markup, user_specified_encoding=None,
                       exclude_encodings=None,
                       document_declared_encoding=None):
        """
        :yield: A series of 4-tuples.
         (markup, encoding, declared encoding,
          has undergone character replacement)

        Each 4-tuple represents a strategy for parsing the document.
        """
        if isinstance(markup, unicode):
            # We were given Unicode. Maybe lxml can parse Unicode on
            # this system?
            yield markup, None, document_declared_encoding, False

        if isinstance(markup, unicode):
            # No, apparently not. Convert the Unicode to UTF-8 and
            # tell lxml to parse it as UTF-8.
            yield (markup.encode("utf8"), "utf8",
                   document_declared_encoding, False)

        # Instead of using UnicodeDammit to convert the bytestring to
        # Unicode using different encodings, use EncodingDetector to
        # iterate over the encodings, and tell lxml to try to parse
        # the document as each one in turn.
        is_html = not self.is_xml
        try_encodings = [user_specified_encoding, document_declared_encoding]
        detector = EncodingDetector(
            markup, try_encodings, is_html, exclude_encodings)
        for encoding in detector.encodings:
            yield (detector.markup, encoding, document_declared_encoding, False)
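
    # For example: for byte input with user_specified_encoding='latin-1',
    # the loop above yields a (markup, 'latin-1', ...) strategy first, then
    # one strategy per further candidate from EncodingDetector.encodings.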

    def feed(self, markup):
        if isinstance(markup, bytes):
            markup = BytesIO(markup)
        elif isinstance(markup, unicode):
            markup = StringIO(markup)

        # Call feed() at least once, even if the markup is empty,
        # or the parser won't be initialized.
        data = markup.read(self.CHUNK_SIZE)
        try:
            self.parser = self.parser_for(self.soup.original_encoding)
            self.parser.feed(data)
            while len(data) != 0:
                # Now call feed() on the rest of the data, chunk by chunk.
                data = markup.read(self.CHUNK_SIZE)
                if len(data) != 0:
                    self.parser.feed(data)
            self.parser.close()
        except (UnicodeDecodeError, LookupError, etree.ParserError), e:
            raise ParserRejectedMarkup(str(e))

    def close(self):
        self.nsmaps = [self.DEFAULT_NSMAPS]

    def start(self, name, attrs, nsmap={}):
        # Make sure attrs is a mutable dict--lxml may send an immutable dictproxy.
        attrs = dict(attrs)
        nsprefix = None
        # Invert each namespace map as it comes in.
        if len(self.nsmaps) > 1:
            # There are no new namespaces for this tag, but
            # non-default namespaces are in play, so we need a
            # separate tag stack to know when they end.
            self.nsmaps.append(None)
        elif len(nsmap) > 0:
            # A new namespace mapping has come into play.
            inverted_nsmap = dict((value, key) for key, value in nsmap.items())
            self.nsmaps.append(inverted_nsmap)
            # Also treat the namespace mapping as a set of attributes on the
            # tag, so we can recreate it later.
            attrs = attrs.copy()
            for prefix, namespace in nsmap.items():
                attribute = NamespacedAttribute(
                    "xmlns", prefix, "http://www.w3.org/2000/xmlns/")
                attrs[attribute] = namespace

        # Namespaces are in play. Find any attributes that came in
        # from lxml with namespaces attached to their names, and
        # turn them into NamespacedAttribute objects.
        new_attrs = {}
        for attr, value in attrs.items():
            namespace, attr = self._getNsTag(attr)
            if namespace is None:
                new_attrs[attr] = value
            else:
                nsprefix = self._prefix_for_namespace(namespace)
                attr = NamespacedAttribute(nsprefix, attr, namespace)
                new_attrs[attr] = value
        attrs = new_attrs

        namespace, name = self._getNsTag(name)
        nsprefix = self._prefix_for_namespace(namespace)
        self.soup.handle_starttag(name, namespace, nsprefix, attrs)

    def _prefix_for_namespace(self, namespace):
        """Find the currently active prefix for the given namespace."""
        if namespace is None:
            return None
        for inverted_nsmap in reversed(self.nsmaps):
            if inverted_nsmap is not None and namespace in inverted_nsmap:
                return inverted_nsmap[namespace]
        return None

    def end(self, name):
        self.soup.endData()
        completed_tag = self.soup.tagStack[-1]
        namespace, name = self._getNsTag(name)
        nsprefix = None
        if namespace is not None:
            for inverted_nsmap in reversed(self.nsmaps):
                if inverted_nsmap is not None and namespace in inverted_nsmap:
                    nsprefix = inverted_nsmap[namespace]
                    break
        self.soup.handle_endtag(name, nsprefix)
        if len(self.nsmaps) > 1:
            # This tag, or one of its parents, introduced a namespace
            # mapping, so pop it off the stack.
            self.nsmaps.pop()

    def pi(self, target, data):
        self.soup.endData()
        self.soup.handle_data(target + ' ' + data)
        self.soup.endData(ProcessingInstruction)

    def data(self, content):
        self.soup.handle_data(content)

    def doctype(self, name, pubid, system):
        self.soup.endData()
        doctype = Doctype.for_name_and_ids(name, pubid, system)
        self.soup.object_was_parsed(doctype)

    def comment(self, content):
        "Handle comments as Comment objects."
        self.soup.endData()
        self.soup.handle_data(content)
        self.soup.endData(Comment)

    def test_fragment_to_document(self, fragment):
        """See `TreeBuilder`."""
        return u'<?xml version="1.0" encoding="utf-8"?>\n%s' % fragment


class LXMLTreeBuilder(HTMLTreeBuilder, LXMLTreeBuilderForXML):

    NAME = LXML
    ALTERNATE_NAMES = ["lxml-html"]

    features = ALTERNATE_NAMES + [NAME, HTML, FAST, PERMISSIVE]
    is_xml = False

    def default_parser(self, encoding):
        return etree.HTMLParser

    def feed(self, markup):
        encoding = self.soup.original_encoding
        try:
            self.parser = self.parser_for(encoding)
            self.parser.feed(markup)
            self.parser.close()
        except (UnicodeDecodeError, LookupError, etree.ParserError), e:
            raise ParserRejectedMarkup(str(e))

    def test_fragment_to_document(self, fragment):
        """See `TreeBuilder`."""
        return u'<html><body>%s</body></html>' % fragment
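
# Usage sketch: BeautifulSoup(markup, "lxml") selects LXMLTreeBuilder for
# HTML, while BeautifulSoup(markup, "xml") selects LXMLTreeBuilderForXML
# through its "xml" alternate name.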
839
lib/bs4/dammit.py
Normal file
@@ -0,0 +1,839 @@
# -*- coding: utf-8 -*-
"""Beautiful Soup bonus library: Unicode, Dammit

This library converts a bytestream to Unicode through any means
necessary. It is heavily based on code from Mark Pilgrim's Universal
Feed Parser. It works best on XML and HTML, but it does not rewrite the
XML or HTML to reflect a new encoding; that's the tree builder's job.
"""

from pdb import set_trace
import codecs
from htmlentitydefs import codepoint2name
import re
import logging
import string

# Import a library to autodetect character encodings.
chardet_type = None
try:
    # First try the fast C implementation.
    # PyPI package: cchardet
    import cchardet
    def chardet_dammit(s):
        return cchardet.detect(s)['encoding']
except ImportError:
    try:
        # Fall back to the pure Python implementation
        # Debian package: python-chardet
        # PyPI package: chardet
        import chardet
        def chardet_dammit(s):
            return chardet.detect(s)['encoding']
        #import chardet.constants
        #chardet.constants._debug = 1
    except ImportError:
        # No chardet available.
        def chardet_dammit(s):
            return None

# Available from http://cjkpython.i18n.org/.
try:
    import iconv_codec
except ImportError:
    pass

xml_encoding_re = re.compile(
    '^<\?.*encoding=[\'"](.*?)[\'"].*\?>'.encode(), re.I)
html_meta_re = re.compile(
    '<\s*meta[^>]+charset\s*=\s*["\']?([^>]*?)[ /;\'">]'.encode(), re.I)

class EntitySubstitution(object):

    """Substitute XML or HTML entities for the corresponding characters."""

    def _populate_class_variables():
        lookup = {}
        reverse_lookup = {}
        characters_for_re = []
        for codepoint, name in list(codepoint2name.items()):
            character = unichr(codepoint)
            if codepoint != 34:
                # There's no point in turning the quotation mark into
                # &quot;, unless it happens within an attribute value, which
                # is handled elsewhere.
                characters_for_re.append(character)
                lookup[character] = name
            # But we do want to turn &quot; into the quotation mark.
            reverse_lookup[name] = character
        re_definition = "[%s]" % "".join(characters_for_re)
        return lookup, reverse_lookup, re.compile(re_definition)
    (CHARACTER_TO_HTML_ENTITY, HTML_ENTITY_TO_CHARACTER,
     CHARACTER_TO_HTML_ENTITY_RE) = _populate_class_variables()

    CHARACTER_TO_XML_ENTITY = {
        "'": "apos",
        '"': "quot",
        "&": "amp",
        "<": "lt",
        ">": "gt",
        }

    BARE_AMPERSAND_OR_BRACKET = re.compile("([<>]|"
                                           "&(?!#\d+;|#x[0-9a-fA-F]+;|\w+;)"
                                           ")")

    AMPERSAND_OR_BRACKET = re.compile("([<>&])")

    @classmethod
    def _substitute_html_entity(cls, matchobj):
        entity = cls.CHARACTER_TO_HTML_ENTITY.get(matchobj.group(0))
        return "&%s;" % entity

    @classmethod
    def _substitute_xml_entity(cls, matchobj):
        """Used with a regular expression to substitute the
        appropriate XML entity for an XML special character."""
        entity = cls.CHARACTER_TO_XML_ENTITY[matchobj.group(0)]
        return "&%s;" % entity

    @classmethod
    def quoted_attribute_value(cls, value):
        """Make a value into a quoted XML attribute, possibly escaping it.

        Most strings will be quoted using double quotes.

         Bob's Bar -> "Bob's Bar"

        If a string contains double quotes, it will be quoted using
        single quotes.

         Welcome to "my bar" -> 'Welcome to "my bar"'

        If a string contains both single and double quotes, the
        double quotes will be escaped, and the string will be quoted
        using double quotes.

         Welcome to "Bob's Bar" -> "Welcome to &quot;Bob's bar&quot;"
        """
        quote_with = '"'
        if '"' in value:
            if "'" in value:
                # The string contains both single and double
                # quotes. Turn the double quotes into
                # entities. We quote the double quotes rather than
                # the single quotes because the entity name is
                # "&quot;" whether this is HTML or XML. If we
                # quoted the single quotes, we'd have to decide
                # between &apos; and &squot;.
                replace_with = "&quot;"
                value = value.replace('"', replace_with)
            else:
                # There are double quotes but no single quotes.
                # We can use single quotes to quote the attribute.
                quote_with = "'"
        return quote_with + value + quote_with

    @classmethod
    def substitute_xml(cls, value, make_quoted_attribute=False):
        """Substitute XML entities for special XML characters.

        :param value: A string to be substituted. The less-than sign
          will become &lt;, the greater-than sign will become &gt;,
          and any ampersands will become &amp;. If you want ampersands
          that appear to be part of an entity definition to be left
          alone, use substitute_xml_containing_entities() instead.

        :param make_quoted_attribute: If True, then the string will be
          quoted, as befits an attribute value.
        """
        # Escape angle brackets and ampersands.
        value = cls.AMPERSAND_OR_BRACKET.sub(
            cls._substitute_xml_entity, value)

        if make_quoted_attribute:
            value = cls.quoted_attribute_value(value)
        return value
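
    # For example: substitute_xml('a < b & c', make_quoted_attribute=True)
    # returns '"a &lt; b &amp; c"'.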

    @classmethod
    def substitute_xml_containing_entities(
        cls, value, make_quoted_attribute=False):
        """Substitute XML entities for special XML characters.

        :param value: A string to be substituted. The less-than sign will
          become &lt;, the greater-than sign will become &gt;, and any
          ampersands that are not part of an entity definition will
          become &amp;.

        :param make_quoted_attribute: If True, then the string will be
          quoted, as befits an attribute value.
        """
        # Escape angle brackets, and ampersands that aren't part of
        # entities.
        value = cls.BARE_AMPERSAND_OR_BRACKET.sub(
            cls._substitute_xml_entity, value)

        if make_quoted_attribute:
            value = cls.quoted_attribute_value(value)
        return value

    @classmethod
    def substitute_html(cls, s):
        """Replace certain Unicode characters with named HTML entities.

        This differs from data.encode(encoding, 'xmlcharrefreplace')
        in that the goal is to make the result more readable (to those
        with ASCII displays) rather than to recover from
        errors. There's absolutely nothing wrong with a UTF-8 string
        containing a LATIN SMALL LETTER E WITH ACUTE, but replacing that
        character with "&eacute;" will make it more readable to some
        people.
        """
        return cls.CHARACTER_TO_HTML_ENTITY_RE.sub(
            cls._substitute_html_entity, s)


class EncodingDetector:
    """Suggests a number of possible encodings for a bytestring.

    Order of precedence:

    1. Encodings you specifically tell EncodingDetector to try first
    (the override_encodings argument to the constructor).

    2. An encoding declared within the bytestring itself, either in an
    XML declaration (if the bytestring is to be interpreted as an XML
    document), or in a <meta> tag (if the bytestring is to be
    interpreted as an HTML document.)

    3. An encoding detected through textual analysis by chardet,
    cchardet, or a similar external library.

    4. UTF-8.

    5. Windows-1252.
    """
    def __init__(self, markup, override_encodings=None, is_html=False,
                 exclude_encodings=None):
        self.override_encodings = override_encodings or []
        exclude_encodings = exclude_encodings or []
        self.exclude_encodings = set([x.lower() for x in exclude_encodings])
        self.chardet_encoding = None
        self.is_html = is_html
        self.declared_encoding = None

        # First order of business: strip a byte-order mark.
        self.markup, self.sniffed_encoding = self.strip_byte_order_mark(markup)

    def _usable(self, encoding, tried):
        if encoding is not None:
            encoding = encoding.lower()
            if encoding in self.exclude_encodings:
                return False
            if encoding not in tried:
                tried.add(encoding)
                return True
        return False

    @property
    def encodings(self):
        """Yield a number of encodings that might work for this markup."""
        tried = set()
        for e in self.override_encodings:
            if self._usable(e, tried):
                yield e

        # Did the document originally start with a byte-order mark
        # that indicated its encoding?
        if self._usable(self.sniffed_encoding, tried):
            yield self.sniffed_encoding

        # Look within the document for an XML or HTML encoding
        # declaration.
        if self.declared_encoding is None:
            self.declared_encoding = self.find_declared_encoding(
                self.markup, self.is_html)
        if self._usable(self.declared_encoding, tried):
            yield self.declared_encoding

        # Use third-party character set detection to guess at the
        # encoding.
        if self.chardet_encoding is None:
            self.chardet_encoding = chardet_dammit(self.markup)
        if self._usable(self.chardet_encoding, tried):
            yield self.chardet_encoding

        # As a last-ditch effort, try utf-8 and windows-1252.
        for e in ('utf-8', 'windows-1252'):
            if self._usable(e, tried):
                yield e

    @classmethod
    def strip_byte_order_mark(cls, data):
        """If a byte-order mark is present, strip it and return the encoding it implies."""
        encoding = None
        if isinstance(data, unicode):
            # Unicode data cannot have a byte-order mark.
            return data, encoding
        if (len(data) >= 4) and (data[:2] == b'\xfe\xff') \
               and (data[2:4] != b'\x00\x00'):
            encoding = 'utf-16be'
            data = data[2:]
        elif (len(data) >= 4) and (data[:2] == b'\xff\xfe') \
                 and (data[2:4] != b'\x00\x00'):
            encoding = 'utf-16le'
            data = data[2:]
        elif data[:3] == b'\xef\xbb\xbf':
            encoding = 'utf-8'
            data = data[3:]
        elif data[:4] == b'\x00\x00\xfe\xff':
            encoding = 'utf-32be'
            data = data[4:]
        elif data[:4] == b'\xff\xfe\x00\x00':
            encoding = 'utf-32le'
            data = data[4:]
        return data, encoding
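
    # For example: strip_byte_order_mark(b'\xef\xbb\xbf<a>') returns
    # (b'<a>', 'utf-8'); data without a recognized BOM comes back
    # unchanged with encoding None.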

    @classmethod
    def find_declared_encoding(cls, markup, is_html=False, search_entire_document=False):
        """Given a document, tries to find its declared encoding.

        An XML encoding is declared at the beginning of the document.

        An HTML encoding is declared in a <meta> tag, hopefully near the
        beginning of the document.
        """
        if search_entire_document:
            xml_endpos = html_endpos = len(markup)
        else:
            xml_endpos = 1024
            html_endpos = max(2048, int(len(markup) * 0.05))

        declared_encoding = None
        declared_encoding_match = xml_encoding_re.search(markup, endpos=xml_endpos)
        if not declared_encoding_match and is_html:
            declared_encoding_match = html_meta_re.search(markup, endpos=html_endpos)
        if declared_encoding_match is not None:
            declared_encoding = declared_encoding_match.groups()[0].decode(
                'ascii', 'replace')
        if declared_encoding:
            return declared_encoding.lower()
        return None
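
    # For example: find_declared_encoding(
    #     b'<?xml version="1.0" encoding="ISO-8859-1"?>') returns
    # 'iso-8859-1'; an HTML <meta charset=...> is only consulted when
    # is_html=True.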

class UnicodeDammit:
    """A class for detecting the encoding of a *ML document and
    converting it to a Unicode string. If the source encoding is
    windows-1252, can replace MS smart quotes with their HTML or XML
    equivalents."""

    # This dictionary maps commonly seen values for "charset" in HTML
    # meta tags to the corresponding Python codec names. It only covers
    # values that aren't in Python's aliases and can't be determined
    # by the heuristics in find_codec.
    CHARSET_ALIASES = {"macintosh": "mac-roman",
                       "x-sjis": "shift-jis"}

    ENCODINGS_WITH_SMART_QUOTES = [
        "windows-1252",
        "iso-8859-1",
        "iso-8859-2",
        ]

    def __init__(self, markup, override_encodings=[],
                 smart_quotes_to=None, is_html=False, exclude_encodings=[]):
        self.smart_quotes_to = smart_quotes_to
        self.tried_encodings = []
        self.contains_replacement_characters = False
        self.is_html = is_html

        self.detector = EncodingDetector(
            markup, override_encodings, is_html, exclude_encodings)

        # Short-circuit if the data is in Unicode to begin with.
        if isinstance(markup, unicode) or markup == '':
            self.markup = markup
            self.unicode_markup = unicode(markup)
            self.original_encoding = None
            return

        # The encoding detector may have stripped a byte-order mark.
        # Use the stripped markup from this point on.
        self.markup = self.detector.markup

        u = None
        for encoding in self.detector.encodings:
            markup = self.detector.markup
            u = self._convert_from(encoding)
            if u is not None:
                break

        if not u:
            # None of the encodings worked. As an absolute last resort,
            # try them again with character replacement.
            for encoding in self.detector.encodings:
                if encoding != "ascii":
                    u = self._convert_from(encoding, "replace")
                if u is not None:
                    logging.warning(
                        "Some characters could not be decoded, and were "
                        "replaced with REPLACEMENT CHARACTER.")
                    self.contains_replacement_characters = True
                    break

        # If none of that worked, we could at this point force it to
        # ASCII, but that would destroy so much data that I think
        # giving up is better.
        self.unicode_markup = u
        if not u:
            self.original_encoding = None

    def _sub_ms_char(self, match):
        """Changes a MS smart quote character to an XML or HTML
        entity, or an ASCII character."""
        orig = match.group(1)
        if self.smart_quotes_to == 'ascii':
            sub = self.MS_CHARS_TO_ASCII.get(orig).encode()
        else:
            sub = self.MS_CHARS.get(orig)
            if type(sub) == tuple:
                if self.smart_quotes_to == 'xml':
                    sub = '&#x'.encode() + sub[1].encode() + ';'.encode()
                else:
                    sub = '&'.encode() + sub[0].encode() + ';'.encode()
            else:
                sub = sub.encode()
        return sub

    def _convert_from(self, proposed, errors="strict"):
        proposed = self.find_codec(proposed)
        if not proposed or (proposed, errors) in self.tried_encodings:
            return None
        self.tried_encodings.append((proposed, errors))
        markup = self.markup
        # Convert smart quotes to HTML if coming from an encoding
        # that might have them.
        if (self.smart_quotes_to is not None
            and proposed in self.ENCODINGS_WITH_SMART_QUOTES):
            smart_quotes_re = b"([\x80-\x9f])"
            smart_quotes_compiled = re.compile(smart_quotes_re)
            markup = smart_quotes_compiled.sub(self._sub_ms_char, markup)

        try:
            #print "Trying to convert document to %s (errors=%s)" % (
            #    proposed, errors)
            u = self._to_unicode(markup, proposed, errors)
            self.markup = u
            self.original_encoding = proposed
        except Exception as e:
            #print "That didn't work!"
            #print e
            return None
        #print "Correct encoding: %s" % proposed
        return self.markup

    def _to_unicode(self, data, encoding, errors="strict"):
        '''Given a string and its encoding, decodes the string into Unicode.
        %encoding is a string recognized by encodings.aliases'''
        return unicode(data, encoding, errors)

    @property
    def declared_html_encoding(self):
        if not self.is_html:
            return None
        return self.detector.declared_encoding

    def find_codec(self, charset):
        value = (self._codec(self.CHARSET_ALIASES.get(charset, charset))
                 or (charset and self._codec(charset.replace("-", "")))
                 or (charset and self._codec(charset.replace("-", "_")))
                 or (charset and charset.lower())
                 or charset
                 )
        if value:
            return value.lower()
        return None
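
    # For example: find_codec('macintosh') returns 'mac-roman' via
    # CHARSET_ALIASES, and find_codec('UTF-8') normalizes to 'utf-8'.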
|
||||
|
||||
def _codec(self, charset):
|
||||
if not charset:
|
||||
return charset
|
||||
codec = None
|
||||
try:
|
||||
codecs.lookup(charset)
|
||||
codec = charset
|
||||
except (LookupError, ValueError):
|
||||
pass
|
||||
return codec
|
||||
|
||||
|
||||
# A partial mapping of ISO-Latin-1 to HTML entities/XML numeric entities.
|
||||
MS_CHARS = {b'\x80': ('euro', '20AC'),
|
||||
b'\x81': ' ',
|
||||
b'\x82': ('sbquo', '201A'),
|
||||
b'\x83': ('fnof', '192'),
|
||||
b'\x84': ('bdquo', '201E'),
|
||||
b'\x85': ('hellip', '2026'),
|
||||
b'\x86': ('dagger', '2020'),
|
||||
b'\x87': ('Dagger', '2021'),
|
||||
b'\x88': ('circ', '2C6'),
|
||||
b'\x89': ('permil', '2030'),
|
||||
b'\x8A': ('Scaron', '160'),
|
||||
b'\x8B': ('lsaquo', '2039'),
|
||||
b'\x8C': ('OElig', '152'),
|
||||
b'\x8D': '?',
|
||||
b'\x8E': ('#x17D', '17D'),
|
||||
b'\x8F': '?',
|
||||
b'\x90': '?',
|
||||
b'\x91': ('lsquo', '2018'),
|
||||
b'\x92': ('rsquo', '2019'),
|
||||
b'\x93': ('ldquo', '201C'),
|
||||
b'\x94': ('rdquo', '201D'),
|
||||
b'\x95': ('bull', '2022'),
|
||||
b'\x96': ('ndash', '2013'),
|
||||
b'\x97': ('mdash', '2014'),
|
||||
b'\x98': ('tilde', '2DC'),
|
||||
b'\x99': ('trade', '2122'),
|
||||
b'\x9a': ('scaron', '161'),
|
||||
b'\x9b': ('rsaquo', '203A'),
|
||||
b'\x9c': ('oelig', '153'),
|
||||
b'\x9d': '?',
|
||||
b'\x9e': ('#x17E', '17E'),
|
||||
b'\x9f': ('Yuml', ''),}
|
||||
|
||||
# A parochial partial mapping of ISO-Latin-1 to ASCII. Contains
|
||||
# horrors like stripping diacritical marks to turn á into a, but also
|
||||
# contains non-horrors like turning “ into ".
|
||||
MS_CHARS_TO_ASCII = {
|
||||
b'\x80' : 'EUR',
|
||||
b'\x81' : ' ',
|
||||
b'\x82' : ',',
|
||||
b'\x83' : 'f',
|
||||
b'\x84' : ',,',
|
||||
b'\x85' : '...',
|
||||
b'\x86' : '+',
|
||||
b'\x87' : '++',
|
||||
b'\x88' : '^',
|
||||
b'\x89' : '%',
|
||||
b'\x8a' : 'S',
|
||||
b'\x8b' : '<',
|
||||
b'\x8c' : 'OE',
|
||||
b'\x8d' : '?',
|
||||
b'\x8e' : 'Z',
|
||||
b'\x8f' : '?',
|
||||
b'\x90' : '?',
|
||||
b'\x91' : "'",
|
||||
b'\x92' : "'",
|
||||
b'\x93' : '"',
|
||||
b'\x94' : '"',
|
||||
b'\x95' : '*',
|
||||
b'\x96' : '-',
|
||||
b'\x97' : '--',
|
||||
b'\x98' : '~',
|
||||
b'\x99' : '(TM)',
|
||||
b'\x9a' : 's',
|
||||
b'\x9b' : '>',
|
||||
b'\x9c' : 'oe',
|
||||
b'\x9d' : '?',
|
||||
b'\x9e' : 'z',
|
||||
b'\x9f' : 'Y',
|
||||
b'\xa0' : ' ',
|
||||
b'\xa1' : '!',
|
||||
b'\xa2' : 'c',
|
||||
b'\xa3' : 'GBP',
|
||||
b'\xa4' : '$', #This approximation is especially parochial--this is the
|
||||
#generic currency symbol.
|
||||
b'\xa5' : 'YEN',
|
||||
b'\xa6' : '|',
|
||||
b'\xa7' : 'S',
|
||||
b'\xa8' : '..',
|
||||
b'\xa9' : '',
|
||||
b'\xaa' : '(th)',
|
||||
b'\xab' : '<<',
|
||||
b'\xac' : '!',
|
||||
b'\xad' : ' ',
|
||||
b'\xae' : '(R)',
|
||||
b'\xaf' : '-',
|
||||
b'\xb0' : 'o',
|
||||
b'\xb1' : '+-',
|
||||
b'\xb2' : '2',
|
||||
b'\xb3' : '3',
|
||||
b'\xb4' : ("'", 'acute'),
|
||||
b'\xb5' : 'u',
|
||||
b'\xb6' : 'P',
|
||||
b'\xb7' : '*',
|
||||
b'\xb8' : ',',
|
||||
b'\xb9' : '1',
|
||||
b'\xba' : '(th)',
|
||||
b'\xbb' : '>>',
|
||||
b'\xbc' : '1/4',
|
||||
b'\xbd' : '1/2',
|
||||
b'\xbe' : '3/4',
|
||||
b'\xbf' : '?',
|
||||
b'\xc0' : 'A',
|
||||
b'\xc1' : 'A',
|
||||
b'\xc2' : 'A',
|
||||
b'\xc3' : 'A',
|
||||
b'\xc4' : 'A',
|
||||
b'\xc5' : 'A',
|
||||
b'\xc6' : 'AE',
|
||||
b'\xc7' : 'C',
|
||||
b'\xc8' : 'E',
|
||||
b'\xc9' : 'E',
|
||||
b'\xca' : 'E',
|
||||
b'\xcb' : 'E',
|
||||
b'\xcc' : 'I',
|
||||
b'\xcd' : 'I',
|
||||
b'\xce' : 'I',
|
||||
b'\xcf' : 'I',
|
||||
b'\xd0' : 'D',
|
||||
b'\xd1' : 'N',
|
||||
b'\xd2' : 'O',
|
||||
b'\xd3' : 'O',
|
||||
b'\xd4' : 'O',
|
||||
b'\xd5' : 'O',
|
||||
b'\xd6' : 'O',
|
||||
b'\xd7' : '*',
|
||||
b'\xd8' : 'O',
|
||||
b'\xd9' : 'U',
|
||||
b'\xda' : 'U',
|
||||
b'\xdb' : 'U',
|
||||
b'\xdc' : 'U',
|
||||
b'\xdd' : 'Y',
|
||||
b'\xde' : 'b',
|
||||
b'\xdf' : 'B',
|
||||
b'\xe0' : 'a',
|
||||
b'\xe1' : 'a',
|
||||
b'\xe2' : 'a',
|
||||
b'\xe3' : 'a',
|
||||
b'\xe4' : 'a',
|
||||
b'\xe5' : 'a',
|
||||
b'\xe6' : 'ae',
|
||||
b'\xe7' : 'c',
|
||||
b'\xe8' : 'e',
|
||||
b'\xe9' : 'e',
|
||||
b'\xea' : 'e',
|
||||
b'\xeb' : 'e',
|
||||
b'\xec' : 'i',
|
||||
b'\xed' : 'i',
|
||||
b'\xee' : 'i',
|
||||
b'\xef' : 'i',
|
||||
b'\xf0' : 'o',
|
||||
b'\xf1' : 'n',
|
||||
b'\xf2' : 'o',
|
||||
b'\xf3' : 'o',
|
||||
b'\xf4' : 'o',
|
||||
b'\xf5' : 'o',
|
||||
b'\xf6' : 'o',
|
||||
b'\xf7' : '/',
|
||||
b'\xf8' : 'o',
|
||||
b'\xf9' : 'u',
|
||||
b'\xfa' : 'u',
|
||||
b'\xfb' : 'u',
|
||||
b'\xfc' : 'u',
|
||||
b'\xfd' : 'y',
|
||||
b'\xfe' : 'b',
|
||||
b'\xff' : 'y',
|
||||
}
|

    # A map used when removing rogue Windows-1252/ISO-8859-1
    # characters in otherwise UTF-8 documents.
    #
    # Note that \x81, \x8d, \x8f, \x90, and \x9d are undefined in
    # Windows-1252.
    WINDOWS_1252_TO_UTF8 = {
        0x80 : b'\xe2\x82\xac', # €
        0x82 : b'\xe2\x80\x9a', # ‚
        0x83 : b'\xc6\x92',     # ƒ
        0x84 : b'\xe2\x80\x9e', # „
        0x85 : b'\xe2\x80\xa6', # …
        0x86 : b'\xe2\x80\xa0', # †
        0x87 : b'\xe2\x80\xa1', # ‡
        0x88 : b'\xcb\x86',     # ˆ
        0x89 : b'\xe2\x80\xb0', # ‰
        0x8a : b'\xc5\xa0',     # Š
        0x8b : b'\xe2\x80\xb9', # ‹
        0x8c : b'\xc5\x92',     # Œ
        0x8e : b'\xc5\xbd',     # Ž
        0x91 : b'\xe2\x80\x98', # ‘
        0x92 : b'\xe2\x80\x99', # ’
        0x93 : b'\xe2\x80\x9c', # “
        0x94 : b'\xe2\x80\x9d', # ”
        0x95 : b'\xe2\x80\xa2', # •
        0x96 : b'\xe2\x80\x93', # –
        0x97 : b'\xe2\x80\x94', # —
        0x98 : b'\xcb\x9c',     # ˜
        0x99 : b'\xe2\x84\xa2', # ™
        0x9a : b'\xc5\xa1',     # š
        0x9b : b'\xe2\x80\xba', # ›
        0x9c : b'\xc5\x93',     # œ
        0x9e : b'\xc5\xbe',     # ž
        0x9f : b'\xc5\xb8',     # Ÿ
        0xa0 : b'\xc2\xa0',     # (no-break space)
        0xa1 : b'\xc2\xa1',     # ¡
        0xa2 : b'\xc2\xa2',     # ¢
        0xa3 : b'\xc2\xa3',     # £
        0xa4 : b'\xc2\xa4',     # ¤
        0xa5 : b'\xc2\xa5',     # ¥
        0xa6 : b'\xc2\xa6',     # ¦
        0xa7 : b'\xc2\xa7',     # §
        0xa8 : b'\xc2\xa8',     # ¨
        0xa9 : b'\xc2\xa9',     # ©
        0xaa : b'\xc2\xaa',     # ª
        0xab : b'\xc2\xab',     # «
        0xac : b'\xc2\xac',     # ¬
        0xad : b'\xc2\xad',     # (soft hyphen)
        0xae : b'\xc2\xae',     # ®
        0xaf : b'\xc2\xaf',     # ¯
        0xb0 : b'\xc2\xb0',     # °
        0xb1 : b'\xc2\xb1',     # ±
        0xb2 : b'\xc2\xb2',     # ²
        0xb3 : b'\xc2\xb3',     # ³
        0xb4 : b'\xc2\xb4',     # ´
        0xb5 : b'\xc2\xb5',     # µ
        0xb6 : b'\xc2\xb6',     # ¶
        0xb7 : b'\xc2\xb7',     # ·
        0xb8 : b'\xc2\xb8',     # ¸
        0xb9 : b'\xc2\xb9',     # ¹
        0xba : b'\xc2\xba',     # º
        0xbb : b'\xc2\xbb',     # »
        0xbc : b'\xc2\xbc',     # ¼
        0xbd : b'\xc2\xbd',     # ½
        0xbe : b'\xc2\xbe',     # ¾
        0xbf : b'\xc2\xbf',     # ¿
        0xc0 : b'\xc3\x80',     # À
        0xc1 : b'\xc3\x81',     # Á
        0xc2 : b'\xc3\x82',     # Â
        0xc3 : b'\xc3\x83',     # Ã
        0xc4 : b'\xc3\x84',     # Ä
        0xc5 : b'\xc3\x85',     # Å
        0xc6 : b'\xc3\x86',     # Æ
        0xc7 : b'\xc3\x87',     # Ç
        0xc8 : b'\xc3\x88',     # È
        0xc9 : b'\xc3\x89',     # É
        0xca : b'\xc3\x8a',     # Ê
        0xcb : b'\xc3\x8b',     # Ë
        0xcc : b'\xc3\x8c',     # Ì
        0xcd : b'\xc3\x8d',     # Í
        0xce : b'\xc3\x8e',     # Î
        0xcf : b'\xc3\x8f',     # Ï
        0xd0 : b'\xc3\x90',     # Ð
        0xd1 : b'\xc3\x91',     # Ñ
        0xd2 : b'\xc3\x92',     # Ò
        0xd3 : b'\xc3\x93',     # Ó
        0xd4 : b'\xc3\x94',     # Ô
        0xd5 : b'\xc3\x95',     # Õ
        0xd6 : b'\xc3\x96',     # Ö
        0xd7 : b'\xc3\x97',     # ×
        0xd8 : b'\xc3\x98',     # Ø
        0xd9 : b'\xc3\x99',     # Ù
        0xda : b'\xc3\x9a',     # Ú
        0xdb : b'\xc3\x9b',     # Û
        0xdc : b'\xc3\x9c',     # Ü
        0xdd : b'\xc3\x9d',     # Ý
        0xde : b'\xc3\x9e',     # Þ
        0xdf : b'\xc3\x9f',     # ß
        0xe0 : b'\xc3\xa0',     # à
        0xe1 : b'\xc3\xa1',     # á
        0xe2 : b'\xc3\xa2',     # â
        0xe3 : b'\xc3\xa3',     # ã
        0xe4 : b'\xc3\xa4',     # ä
        0xe5 : b'\xc3\xa5',     # å
        0xe6 : b'\xc3\xa6',     # æ
        0xe7 : b'\xc3\xa7',     # ç
        0xe8 : b'\xc3\xa8',     # è
        0xe9 : b'\xc3\xa9',     # é
        0xea : b'\xc3\xaa',     # ê
        0xeb : b'\xc3\xab',     # ë
        0xec : b'\xc3\xac',     # ì
        0xed : b'\xc3\xad',     # í
        0xee : b'\xc3\xae',     # î
        0xef : b'\xc3\xaf',     # ï
        0xf0 : b'\xc3\xb0',     # ð
        0xf1 : b'\xc3\xb1',     # ñ
        0xf2 : b'\xc3\xb2',     # ò
        0xf3 : b'\xc3\xb3',     # ó
        0xf4 : b'\xc3\xb4',     # ô
        0xf5 : b'\xc3\xb5',     # õ
        0xf6 : b'\xc3\xb6',     # ö
        0xf7 : b'\xc3\xb7',     # ÷
        0xf8 : b'\xc3\xb8',     # ø
        0xf9 : b'\xc3\xb9',     # ù
        0xfa : b'\xc3\xba',     # ú
        0xfb : b'\xc3\xbb',     # û
        0xfc : b'\xc3\xbc',     # ü
        0xfd : b'\xc3\xbd',     # ý
        0xfe : b'\xc3\xbe',     # þ
        }

    MULTIBYTE_MARKERS_AND_SIZES = [
        (0xc2, 0xdf, 2), # 2-byte characters start with a byte C2-DF
        (0xe0, 0xef, 3), # 3-byte characters start with E0-EF
        (0xf0, 0xf4, 4), # 4-byte characters start with F0-F4
        ]

    FIRST_MULTIBYTE_MARKER = MULTIBYTE_MARKERS_AND_SIZES[0][0]
    LAST_MULTIBYTE_MARKER = MULTIBYTE_MARKERS_AND_SIZES[-1][1]

    @classmethod
    def detwingle(cls, in_bytes, main_encoding="utf8",
                  embedded_encoding="windows-1252"):
        """Fix characters from one encoding embedded in some other encoding.

        Currently the only situation supported is Windows-1252 (or its
        subset ISO-8859-1), embedded in UTF-8.

        The input must be a bytestring. If you've already converted
        the document to Unicode, you're too late.

        The output is a bytestring in which `embedded_encoding`
        characters have been converted to their `main_encoding`
        equivalents.
        """
        if embedded_encoding.replace('_', '-').lower() not in (
            'windows-1252', 'windows_1252'):
            raise NotImplementedError(
                "Windows-1252 and ISO-8859-1 are the only currently supported "
                "embedded encodings.")

        if main_encoding.lower() not in ('utf8', 'utf-8'):
            raise NotImplementedError(
                "UTF-8 is the only currently supported main encoding.")

        byte_chunks = []

        chunk_start = 0
        pos = 0
        while pos < len(in_bytes):
            byte = in_bytes[pos]
            if not isinstance(byte, int):
                # Python 2.x
                byte = ord(byte)
            if (byte >= cls.FIRST_MULTIBYTE_MARKER
                and byte <= cls.LAST_MULTIBYTE_MARKER):
                # This is the start of a UTF-8 multibyte character. Skip
                # to the end.
                for start, end, size in cls.MULTIBYTE_MARKERS_AND_SIZES:
                    if byte >= start and byte <= end:
                        pos += size
                        break
            elif byte >= 0x80 and byte in cls.WINDOWS_1252_TO_UTF8:
                # We found a Windows-1252 character!
                # Save the string up to this point as a chunk.
                byte_chunks.append(in_bytes[chunk_start:pos])

                # Now translate the Windows-1252 character into UTF-8
                # and add it as another, one-byte chunk.
                byte_chunks.append(cls.WINDOWS_1252_TO_UTF8[byte])
                pos += 1
                chunk_start = pos
            else:
                # Go on to the next character.
                pos += 1
        if chunk_start == 0:
            # The string is unchanged.
            return in_bytes
        else:
            # Store the final chunk.
            byte_chunks.append(in_bytes[chunk_start:])
        return b''.join(byte_chunks)

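# A minimal usage sketch of detwingle(), assuming the vendored bs4 package is
# importable (Python 2 era API); the sample document mirrors the one in the
# Beautiful Soup documentation.
#
#     from bs4 import UnicodeDammit
#
#     # A UTF-8 document with Windows-1252 smart quotes (\x93, \x94) mixed in.
#     doc = (u"\N{SNOWMAN}" * 3).encode("utf8") + b"\x93Is this a snowman?\x94"
#
#     # detwingle() rewrites the rogue Windows-1252 bytes as UTF-8, after
#     # which the whole bytestring decodes cleanly.
#     fixed = UnicodeDammit.detwingle(doc)
#     print fixed.decode("utf8")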
213
lib/bs4/diagnose.py
Normal file
@@ -0,0 +1,213 @@
"""Diagnostic functions, mainly for use when doing tech support."""
|
||||
import cProfile
|
||||
from StringIO import StringIO
|
||||
from HTMLParser import HTMLParser
|
||||
import bs4
|
||||
from bs4 import BeautifulSoup, __version__
|
||||
from bs4.builder import builder_registry
|
||||
|
||||
import os
|
||||
import pstats
|
||||
import random
|
||||
import tempfile
|
||||
import time
|
||||
import traceback
|
||||
import sys
|
||||
import cProfile
|
||||
|
||||
def diagnose(data):
    """Diagnostic suite for isolating common problems."""
    print "Diagnostic running on Beautiful Soup %s" % __version__
    print "Python version %s" % sys.version

    basic_parsers = ["html.parser", "html5lib", "lxml"]
    # Iterate over a copy so entries can be removed from the real list.
    for name in list(basic_parsers):
        for builder in builder_registry.builders:
            if name in builder.features:
                break
        else:
            basic_parsers.remove(name)
            print (
                "I noticed that %s is not installed. Installing it may help." %
                name)

    if 'lxml' in basic_parsers:
        basic_parsers.append(["lxml", "xml"])
        try:
            from lxml import etree
            print "Found lxml version %s" % ".".join(map(str,etree.LXML_VERSION))
        except ImportError, e:
            print (
                "lxml is not installed or couldn't be imported.")

    if 'html5lib' in basic_parsers:
        try:
            import html5lib
            print "Found html5lib version %s" % html5lib.__version__
        except ImportError, e:
            print (
                "html5lib is not installed or couldn't be imported.")

    if hasattr(data, 'read'):
        data = data.read()
    elif os.path.exists(data):
        print '"%s" looks like a filename. Reading data from the file.' % data
        data = open(data).read()
    elif data.startswith("http:") or data.startswith("https:"):
        print '"%s" looks like a URL. Beautiful Soup is not an HTTP client.' % data
        print "You need to use some other library to get the document behind the URL, and feed that document to Beautiful Soup."
        return
    print

    for parser in basic_parsers:
        print "Trying to parse your markup with %s" % parser
        success = False
        try:
            soup = BeautifulSoup(data, parser)
            success = True
        except Exception, e:
            print "%s could not parse the markup." % parser
            traceback.print_exc()
        if success:
            print "Here's what %s did with the markup:" % parser
            print soup.prettify()

        print "-" * 80

def lxml_trace(data, html=True, **kwargs):
    """Print out the lxml events that occur during parsing.

    This lets you see how lxml parses a document when no Beautiful
    Soup code is running.
    """
    from lxml import etree
    for event, element in etree.iterparse(StringIO(data), html=html, **kwargs):
        print("%s, %4s, %s" % (event, element.tag, element.text))

class AnnouncingParser(HTMLParser):
    """Announces HTMLParser parse events, without doing anything else."""

    def _p(self, s):
        print(s)

    def handle_starttag(self, name, attrs):
        self._p("%s START" % name)

    def handle_endtag(self, name):
        self._p("%s END" % name)

    def handle_data(self, data):
        self._p("%s DATA" % data)

    def handle_charref(self, name):
        self._p("%s CHARREF" % name)

    def handle_entityref(self, name):
        self._p("%s ENTITYREF" % name)

    def handle_comment(self, data):
        self._p("%s COMMENT" % data)

    def handle_decl(self, data):
        self._p("%s DECL" % data)

    def unknown_decl(self, data):
        self._p("%s UNKNOWN-DECL" % data)

    def handle_pi(self, data):
        self._p("%s PI" % data)

def htmlparser_trace(data):
    """Print out the HTMLParser events that occur during parsing.

    This lets you see how HTMLParser parses a document when no
    Beautiful Soup code is running.
    """
    parser = AnnouncingParser()
    parser.feed(data)

_vowels = "aeiou"
_consonants = "bcdfghjklmnpqrstvwxyz"

def rword(length=5):
    "Generate a random word-like string."
    s = ''
    for i in range(length):
        if i % 2 == 0:
            t = _consonants
        else:
            t = _vowels
        s += random.choice(t)
    return s

def rsentence(length=4):
    "Generate a random sentence-like string."
    return " ".join(rword(random.randint(4,9)) for i in range(length))

def rdoc(num_elements=1000):
    """Randomly generate an invalid HTML document."""
    tag_names = ['p', 'div', 'span', 'i', 'b', 'script', 'table']
    elements = []
    for i in range(num_elements):
        choice = random.randint(0,3)
        if choice == 0:
            # New tag.
            tag_name = random.choice(tag_names)
            elements.append("<%s>" % tag_name)
        elif choice == 1:
            elements.append(rsentence(random.randint(1,4)))
        elif choice == 2:
            # Close a tag.
            tag_name = random.choice(tag_names)
            elements.append("</%s>" % tag_name)
    return "<html>" + "\n".join(elements) + "</html>"

def benchmark_parsers(num_elements=100000):
    """Very basic head-to-head performance benchmark."""
    print "Comparative parser benchmark on Beautiful Soup %s" % __version__
    data = rdoc(num_elements)
    print "Generated a large invalid HTML document (%d bytes)." % len(data)

    for parser in ["lxml", ["lxml", "html"], "html5lib", "html.parser"]:
        success = False
        try:
            a = time.time()
            soup = BeautifulSoup(data, parser)
            b = time.time()
            success = True
        except Exception, e:
            print "%s could not parse the markup." % parser
            traceback.print_exc()
        if success:
            print "BS4+%s parsed the markup in %.2fs." % (parser, b-a)

    from lxml import etree
    a = time.time()
    etree.HTML(data)
    b = time.time()
    print "Raw lxml parsed the markup in %.2fs." % (b-a)

    import html5lib
    parser = html5lib.HTMLParser()
    a = time.time()
    parser.parse(data)
    b = time.time()
    print "Raw html5lib parsed the markup in %.2fs." % (b-a)

def profile(num_elements=100000, parser="lxml"):

    filehandle = tempfile.NamedTemporaryFile()
    filename = filehandle.name

    data = rdoc(num_elements)
    vars = dict(bs4=bs4, data=data, parser=parser)
    cProfile.runctx('bs4.BeautifulSoup(data, parser)' , vars, vars, filename)

    stats = pstats.Stats(filename)
    # stats.strip_dirs()
    stats.sort_stats("cumulative")
    stats.print_stats('_html5lib|bs4', 50)

if __name__ == '__main__':
    diagnose(sys.stdin.read())
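# A minimal usage sketch of the diagnostic helpers above, assuming the
# vendored bs4 package is importable (Python 2). diagnose() accepts markup,
# a filename, or a file-like object; benchmark_parsers() takes a document size.
#
#     from bs4.diagnose import diagnose, benchmark_parsers
#
#     diagnose("<html><body><p>Does this parse?</body></html>")
#     benchmark_parsers(num_elements=10000)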
1713
lib/bs4/element.py
Normal file
File diff suppressed because it is too large
680
lib/bs4/testing.py
Normal file
@@ -0,0 +1,680 @@
"""Helper classes for tests."""
|
||||
|
||||
import pickle
|
||||
import copy
|
||||
import functools
|
||||
import unittest
|
||||
from unittest import TestCase
|
||||
from bs4 import BeautifulSoup
|
||||
from bs4.element import (
|
||||
CharsetMetaAttributeValue,
|
||||
Comment,
|
||||
ContentMetaAttributeValue,
|
||||
Doctype,
|
||||
SoupStrainer,
|
||||
)
|
||||
|
||||
from bs4.builder import HTMLParserTreeBuilder
|
||||
default_builder = HTMLParserTreeBuilder
|
||||
|
||||
|
||||
class SoupTest(unittest.TestCase):
|
||||
|
||||
@property
|
||||
def default_builder(self):
|
||||
return default_builder()
|
||||
|
||||
def soup(self, markup, **kwargs):
|
||||
"""Build a Beautiful Soup object from markup."""
|
||||
builder = kwargs.pop('builder', self.default_builder)
|
||||
return BeautifulSoup(markup, builder=builder, **kwargs)
|
||||
|
||||
def document_for(self, markup):
|
||||
"""Turn an HTML fragment into a document.
|
||||
|
||||
The details depend on the builder.
|
||||
"""
|
||||
return self.default_builder.test_fragment_to_document(markup)
|
||||
|
||||
def assertSoupEquals(self, to_parse, compare_parsed_to=None):
|
||||
builder = self.default_builder
|
||||
obj = BeautifulSoup(to_parse, builder=builder)
|
||||
if compare_parsed_to is None:
|
||||
compare_parsed_to = to_parse
|
||||
|
||||
self.assertEqual(obj.decode(), self.document_for(compare_parsed_to))
|
||||
|
||||
def assertConnectedness(self, element):
|
||||
"""Ensure that next_element and previous_element are properly
|
||||
set for all descendants of the given element.
|
||||
"""
|
||||
earlier = None
|
||||
for e in element.descendants:
|
||||
if earlier:
|
||||
self.assertEqual(e, earlier.next_element)
|
||||
self.assertEqual(earlier, e.previous_element)
|
||||
earlier = e
|
||||
|
||||
class HTMLTreeBuilderSmokeTest(object):
|
||||
|
||||
"""A basic test of a treebuilder's competence.
|
||||
|
||||
Any HTML treebuilder, present or future, should be able to pass
|
||||
these tests. With invalid markup, there's room for interpretation,
|
||||
and different parsers can handle it differently. But with the
|
||||
markup in these tests, there's not much room for interpretation.
|
||||
"""
|
||||
|
||||
def test_pickle_and_unpickle_identity(self):
|
||||
# Pickling a tree, then unpickling it, yields a tree identical
|
||||
# to the original.
|
||||
tree = self.soup("<a><b>foo</a>")
|
||||
dumped = pickle.dumps(tree, 2)
|
||||
loaded = pickle.loads(dumped)
|
||||
self.assertEqual(loaded.__class__, BeautifulSoup)
|
||||
self.assertEqual(loaded.decode(), tree.decode())
|
||||
|
||||
def assertDoctypeHandled(self, doctype_fragment):
|
||||
"""Assert that a given doctype string is handled correctly."""
|
||||
doctype_str, soup = self._document_with_doctype(doctype_fragment)
|
||||
|
||||
# Make sure a Doctype object was created.
|
||||
doctype = soup.contents[0]
|
||||
self.assertEqual(doctype.__class__, Doctype)
|
||||
self.assertEqual(doctype, doctype_fragment)
|
||||
self.assertEqual(str(soup)[:len(doctype_str)], doctype_str)
|
||||
|
||||
# Make sure that the doctype was correctly associated with the
|
||||
# parse tree and that the rest of the document parsed.
|
||||
self.assertEqual(soup.p.contents[0], 'foo')
|
||||
|
||||
def _document_with_doctype(self, doctype_fragment):
|
||||
"""Generate and parse a document with the given doctype."""
|
||||
doctype = '<!DOCTYPE %s>' % doctype_fragment
|
||||
markup = doctype + '\n<p>foo</p>'
|
||||
soup = self.soup(markup)
|
||||
return doctype, soup
|
||||
|
||||
def test_normal_doctypes(self):
|
||||
"""Make sure normal, everyday HTML doctypes are handled correctly."""
|
||||
self.assertDoctypeHandled("html")
|
||||
self.assertDoctypeHandled(
|
||||
'html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN"')
|
||||
|
||||
def test_empty_doctype(self):
|
||||
soup = self.soup("<!DOCTYPE>")
|
||||
doctype = soup.contents[0]
|
||||
self.assertEqual("", doctype.strip())
|
||||
|
||||
def test_public_doctype_with_url(self):
|
||||
doctype = 'html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"'
|
||||
self.assertDoctypeHandled(doctype)
|
||||
|
||||
def test_system_doctype(self):
|
||||
self.assertDoctypeHandled('foo SYSTEM "http://www.example.com/"')
|
||||
|
||||
def test_namespaced_system_doctype(self):
|
||||
# We can handle a namespaced doctype with a system ID.
|
||||
self.assertDoctypeHandled('xsl:stylesheet SYSTEM "htmlent.dtd"')
|
||||
|
||||
def test_namespaced_public_doctype(self):
|
||||
# Test a namespaced doctype with a public id.
|
||||
self.assertDoctypeHandled('xsl:stylesheet PUBLIC "htmlent.dtd"')
|
||||
|
||||
def test_real_xhtml_document(self):
|
||||
"""A real XHTML document should come out more or less the same as it went in."""
|
||||
markup = b"""<?xml version="1.0" encoding="utf-8"?>
|
||||
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN">
|
||||
<html xmlns="http://www.w3.org/1999/xhtml">
|
||||
<head><title>Hello.</title></head>
|
||||
<body>Goodbye.</body>
|
||||
</html>"""
|
||||
soup = self.soup(markup)
|
||||
self.assertEqual(
|
||||
soup.encode("utf-8").replace(b"\n", b""),
|
||||
markup.replace(b"\n", b""))
|
||||
|
||||
def test_processing_instruction(self):
|
||||
markup = b"""<?PITarget PIContent?>"""
|
||||
soup = self.soup(markup)
|
||||
self.assertEqual(markup, soup.encode("utf8"))
|
||||
|
||||
def test_deepcopy(self):
|
||||
"""Make sure you can copy the tree builder.
|
||||
|
||||
This is important because the builder is part of a
|
||||
BeautifulSoup object, and we want to be able to copy that.
|
||||
"""
|
||||
copy.deepcopy(self.default_builder)
|
||||
|
||||
def test_p_tag_is_never_empty_element(self):
|
||||
"""A <p> tag is never designated as an empty-element tag.
|
||||
|
||||
Even if the markup shows it as an empty-element tag, it
|
||||
shouldn't be presented that way.
|
||||
"""
|
||||
soup = self.soup("<p/>")
|
||||
self.assertFalse(soup.p.is_empty_element)
|
||||
self.assertEqual(str(soup.p), "<p></p>")
|
||||
|
||||
def test_unclosed_tags_get_closed(self):
|
||||
"""A tag that's not closed by the end of the document should be closed.
|
||||
|
||||
This applies to all tags except empty-element tags.
|
||||
"""
|
||||
self.assertSoupEquals("<p>", "<p></p>")
|
||||
self.assertSoupEquals("<b>", "<b></b>")
|
||||
|
||||
self.assertSoupEquals("<br>", "<br/>")
|
||||
|
||||
def test_br_is_always_empty_element_tag(self):
|
||||
"""A <br> tag is designated as an empty-element tag.
|
||||
|
||||
Some parsers treat <br></br> as one <br/> tag, some parsers as
|
||||
two tags, but it should always be an empty-element tag.
|
||||
"""
|
||||
soup = self.soup("<br></br>")
|
||||
self.assertTrue(soup.br.is_empty_element)
|
||||
self.assertEqual(str(soup.br), "<br/>")
|
||||
|
||||
def test_nested_formatting_elements(self):
|
||||
self.assertSoupEquals("<em><em></em></em>")
|
||||
|
||||
def test_double_head(self):
|
||||
html = '''<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Ordinary HEAD element test</title>
|
||||
</head>
|
||||
<script type="text/javascript">
|
||||
alert("Help!");
|
||||
</script>
|
||||
<body>
|
||||
Hello, world!
|
||||
</body>
|
||||
</html>
|
||||
'''
|
||||
soup = self.soup(html)
|
||||
self.assertEqual("text/javascript", soup.find('script')['type'])
|
||||
|
||||
def test_comment(self):
|
||||
# Comments are represented as Comment objects.
|
||||
markup = "<p>foo<!--foobar-->baz</p>"
|
||||
self.assertSoupEquals(markup)
|
||||
|
||||
soup = self.soup(markup)
|
||||
comment = soup.find(text="foobar")
|
||||
self.assertEqual(comment.__class__, Comment)
|
||||
|
||||
# The comment is properly integrated into the tree.
|
||||
foo = soup.find(text="foo")
|
||||
self.assertEqual(comment, foo.next_element)
|
||||
baz = soup.find(text="baz")
|
||||
self.assertEqual(comment, baz.previous_element)
|
||||
|
||||
def test_preserved_whitespace_in_pre_and_textarea(self):
|
||||
"""Whitespace must be preserved in <pre> and <textarea> tags."""
|
||||
self.assertSoupEquals("<pre> </pre>")
|
||||
self.assertSoupEquals("<textarea> woo </textarea>")
|
||||
|
||||
def test_nested_inline_elements(self):
|
||||
"""Inline elements can be nested indefinitely."""
|
||||
b_tag = "<b>Inside a B tag</b>"
|
||||
self.assertSoupEquals(b_tag)
|
||||
|
||||
nested_b_tag = "<p>A <i>nested <b>tag</b></i></p>"
|
||||
self.assertSoupEquals(nested_b_tag)
|
||||
|
        double_nested_b_tag = "<p>A <a>doubly <i>nested <b>tag</b></i></a></p>"
        self.assertSoupEquals(double_nested_b_tag)

    def test_nested_block_level_elements(self):
        """Block elements can be nested."""
        soup = self.soup('<blockquote><p><b>Foo</b></p></blockquote>')
        blockquote = soup.blockquote
        self.assertEqual(blockquote.p.b.string, 'Foo')
        self.assertEqual(blockquote.b.string, 'Foo')

    def test_correctly_nested_tables(self):
        """One table can go inside another one."""
        markup = ('<table id="1">'
                  '<tr>'
                  "<td>Here's another table:"
                  '<table id="2">'
                  '<tr><td>foo</td></tr>'
                  '</table></td>')

        self.assertSoupEquals(
            markup,
            '<table id="1"><tr><td>Here\'s another table:'
            '<table id="2"><tr><td>foo</td></tr></table>'
            '</td></tr></table>')

        self.assertSoupEquals(
            "<table><thead><tr><td>Foo</td></tr></thead>"
            "<tbody><tr><td>Bar</td></tr></tbody>"
            "<tfoot><tr><td>Baz</td></tr></tfoot></table>")

    def test_deeply_nested_multivalued_attribute(self):
        # html5lib can set the attributes of the same tag many times
        # as it rearranges the tree. This has caused problems with
        # multivalued attributes.
        markup = '<table><div><div class="css"></div></div></table>'
        soup = self.soup(markup)
        self.assertEqual(["css"], soup.div.div['class'])

    def test_multivalued_attribute_on_html(self):
        # html5lib uses a different API to set the attributes of the
        # <html> tag. This has caused problems with multivalued
        # attributes.
        markup = '<html class="a b"></html>'
        soup = self.soup(markup)
        self.assertEqual(["a", "b"], soup.html['class'])

    def test_angle_brackets_in_attribute_values_are_escaped(self):
        self.assertSoupEquals('<a b="<a>"></a>', '<a b="&lt;a&gt;"></a>')

    def test_entities_in_attributes_converted_to_unicode(self):
        expect = u'<p id="pi\N{LATIN SMALL LETTER N WITH TILDE}ata"></p>'
        self.assertSoupEquals('<p id="pi&#241;ata"></p>', expect)
        self.assertSoupEquals('<p id="pi&#xf1;ata"></p>', expect)
        self.assertSoupEquals('<p id="pi&#Xf1;ata"></p>', expect)
        self.assertSoupEquals('<p id="pi&ntilde;ata"></p>', expect)

    def test_entities_in_text_converted_to_unicode(self):
        expect = u'<p>pi\N{LATIN SMALL LETTER N WITH TILDE}ata</p>'
        self.assertSoupEquals("<p>pi&#241;ata</p>", expect)
        self.assertSoupEquals("<p>pi&#xf1;ata</p>", expect)
        self.assertSoupEquals("<p>pi&#Xf1;ata</p>", expect)
        self.assertSoupEquals("<p>pi&ntilde;ata</p>", expect)

    def test_quot_entity_converted_to_quotation_mark(self):
        self.assertSoupEquals("<p>I said &quot;good day!&quot;</p>",
                              '<p>I said "good day!"</p>')

    def test_out_of_range_entity(self):
        expect = u"\N{REPLACEMENT CHARACTER}"
        self.assertSoupEquals("&#10000000000000;", expect)
        self.assertSoupEquals("&#x10000000000000;", expect)
        self.assertSoupEquals("&#1000000000;", expect)

    def test_multipart_strings(self):
        "Mostly to prevent a recurrence of a bug in the html5lib treebuilder."
        soup = self.soup("<html><h2>\nfoo</h2><p></p></html>")
        self.assertEqual("p", soup.h2.string.next_element.name)
        self.assertEqual("p", soup.p.name)
        self.assertConnectedness(soup)

    def test_head_tag_between_head_and_body(self):
        "Prevent recurrence of a bug in the html5lib treebuilder."
        content = """<html><head></head>
<link></link>
<body>foo</body>
</html>
"""
        soup = self.soup(content)
        self.assertNotEqual(None, soup.html.body)
        self.assertConnectedness(soup)

    def test_multiple_copies_of_a_tag(self):
        "Prevent recurrence of a bug in the html5lib treebuilder."
        content = """<!DOCTYPE html>
<html>
<body>
<article id="a" >
<div><a href="1"></div>
<footer>
<a href="2"></a>
</footer>
</article>
</body>
</html>
"""
        soup = self.soup(content)
        self.assertConnectedness(soup.article)

    def test_basic_namespaces(self):
        """Parsers don't need to *understand* namespaces, but at the
        very least they should not choke on namespaces or lose
        data."""

        markup = b'<html xmlns="http://www.w3.org/1999/xhtml" xmlns:mathml="http://www.w3.org/1998/Math/MathML" xmlns:svg="http://www.w3.org/2000/svg"><head></head><body><mathml:msqrt>4</mathml:msqrt><b svg:fill="red"></b></body></html>'
        soup = self.soup(markup)
        self.assertEqual(markup, soup.encode())
        html = soup.html
        self.assertEqual('http://www.w3.org/1999/xhtml', soup.html['xmlns'])
        self.assertEqual(
            'http://www.w3.org/1998/Math/MathML', soup.html['xmlns:mathml'])
        self.assertEqual(
            'http://www.w3.org/2000/svg', soup.html['xmlns:svg'])

    def test_multivalued_attribute_value_becomes_list(self):
        markup = b'<a class="foo bar">'
        soup = self.soup(markup)
        self.assertEqual(['foo', 'bar'], soup.a['class'])

    #
    # Generally speaking, tests below this point are more tests of
    # Beautiful Soup than tests of the tree builders. But parsers are
    # weird, so we run these tests separately for every tree builder
    # to detect any differences between them.
    #

    def test_can_parse_unicode_document(self):
        # A seemingly innocuous document... but it's in Unicode! And
        # it contains characters that can't be represented in the
        # encoding found in the declaration! The horror!
        markup = u'<html><head><meta encoding="euc-jp"></head><body>Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!</body>'
        soup = self.soup(markup)
        self.assertEqual(u'Sacr\xe9 bleu!', soup.body.string)

    def test_soupstrainer(self):
        """Parsers should be able to work with SoupStrainers."""
        strainer = SoupStrainer("b")
        soup = self.soup("A <b>bold</b> <meta/> <i>statement</i>",
                         parse_only=strainer)
        self.assertEqual(soup.decode(), "<b>bold</b>")

    def test_single_quote_attribute_values_become_double_quotes(self):
        self.assertSoupEquals("<foo attr='bar'></foo>",
                              '<foo attr="bar"></foo>')

    def test_attribute_values_with_nested_quotes_are_left_alone(self):
        text = """<foo attr='bar "brawls" happen'>a</foo>"""
        self.assertSoupEquals(text)

    def test_attribute_values_with_double_nested_quotes_get_quoted(self):
        text = """<foo attr='bar "brawls" happen'>a</foo>"""
        soup = self.soup(text)
        soup.foo['attr'] = 'Brawls happen at "Bob\'s Bar"'
        self.assertSoupEquals(
            soup.foo.decode(),
            """<foo attr="Brawls happen at &quot;Bob\'s Bar&quot;">a</foo>""")

    def test_ampersand_in_attribute_value_gets_escaped(self):
        self.assertSoupEquals('<this is="really messed up & stuff"></this>',
                              '<this is="really messed up &amp; stuff"></this>')

        self.assertSoupEquals(
            '<a href="http://example.org?a=1&b=2;3">foo</a>',
            '<a href="http://example.org?a=1&amp;b=2;3">foo</a>')

    def test_escaped_ampersand_in_attribute_value_is_left_alone(self):
        self.assertSoupEquals('<a href="http://example.org?a=1&amp;b=2;3"></a>')

    def test_entities_in_strings_converted_during_parsing(self):
        # Both XML and HTML entities are converted to Unicode characters
        # during parsing.
        text = "<p>&lt;&lt;sacr&eacute; bleu!&gt;&gt;</p>"
        expected = u"<p>&lt;&lt;sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!&gt;&gt;</p>"
        self.assertSoupEquals(text, expected)

    def test_smart_quotes_converted_on_the_way_in(self):
        # Microsoft smart quotes are converted to Unicode characters during
        # parsing.
        quote = b"<p>\x91Foo\x92</p>"
        soup = self.soup(quote)
        self.assertEqual(
            soup.p.string,
            u"\N{LEFT SINGLE QUOTATION MARK}Foo\N{RIGHT SINGLE QUOTATION MARK}")

    def test_non_breaking_spaces_converted_on_the_way_in(self):
        soup = self.soup("<a>&nbsp;&nbsp;</a>")
        self.assertEqual(soup.a.string, u"\N{NO-BREAK SPACE}" * 2)

    def test_entities_converted_on_the_way_out(self):
        text = "<p>&lt;&lt;sacr&eacute; bleu!&gt;&gt;</p>"
        expected = u"<p>&lt;&lt;sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!&gt;&gt;</p>".encode("utf-8")
        soup = self.soup(text)
        self.assertEqual(soup.p.encode("utf-8"), expected)

    def test_real_iso_latin_document(self):
        # Smoke test of interrelated functionality, using an
        # easy-to-understand document.

        # Here it is in Unicode. Note that it claims to be in ISO-Latin-1.
        unicode_html = u'<html><head><meta content="text/html; charset=ISO-Latin-1" http-equiv="Content-type"/></head><body><p>Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!</p></body></html>'

        # That's because we're going to encode it into ISO-Latin-1, and use
        # that to test.
        iso_latin_html = unicode_html.encode("iso-8859-1")

        # Parse the ISO-Latin-1 HTML.
        soup = self.soup(iso_latin_html)
        # Encode it to UTF-8.
        result = soup.encode("utf-8")

        # What do we expect the result to look like? Well, it would
        # look like unicode_html, except that the META tag would say
        # UTF-8 instead of ISO-Latin-1.
        expected = unicode_html.replace("ISO-Latin-1", "utf-8")

        # And, of course, it would be in UTF-8, not Unicode.
        expected = expected.encode("utf-8")

        # Ta-da!
        self.assertEqual(result, expected)

    def test_real_shift_jis_document(self):
        # Smoke test to make sure the parser can handle a document in
        # Shift-JIS encoding, without choking.
        shift_jis_html = (
            b'<html><head></head><body><pre>'
            b'\x82\xb1\x82\xea\x82\xcdShift-JIS\x82\xc5\x83R\x81[\x83f'
            b'\x83B\x83\x93\x83O\x82\xb3\x82\xea\x82\xbd\x93\xfa\x96{\x8c'
            b'\xea\x82\xcc\x83t\x83@\x83C\x83\x8b\x82\xc5\x82\xb7\x81B'
            b'</pre></body></html>')
        unicode_html = shift_jis_html.decode("shift-jis")
        soup = self.soup(unicode_html)

        # Make sure the parse tree is correctly encoded to various
        # encodings.
        self.assertEqual(soup.encode("utf-8"), unicode_html.encode("utf-8"))
        self.assertEqual(soup.encode("euc_jp"), unicode_html.encode("euc_jp"))

    def test_real_hebrew_document(self):
        # A real-world test to make sure we can convert ISO-8859-8 (a
        # Hebrew encoding) to UTF-8.
        hebrew_document = b'<html><head><title>Hebrew (ISO 8859-8) in Visual Directionality</title></head><body><h1>Hebrew (ISO 8859-8) in Visual Directionality</h1>\xed\xe5\xec\xf9</body></html>'
        soup = self.soup(
            hebrew_document, from_encoding="iso8859-8")
        self.assertEqual(soup.original_encoding, 'iso8859-8')
        self.assertEqual(
            soup.encode('utf-8'),
            hebrew_document.decode("iso8859-8").encode("utf-8"))

    def test_meta_tag_reflects_current_encoding(self):
        # Here's the <meta> tag saying that a document is
        # encoded in Shift-JIS.
        meta_tag = ('<meta content="text/html; charset=x-sjis" '
                    'http-equiv="Content-type"/>')

        # Here's a document incorporating that meta tag.
        shift_jis_html = (
            '<html><head>\n%s\n'
            '<meta http-equiv="Content-language" content="ja"/>'
            '</head><body>Shift-JIS markup goes here.') % meta_tag
        soup = self.soup(shift_jis_html)

        # Parse the document, and the charset is seemingly unaffected.
        parsed_meta = soup.find('meta', {'http-equiv': 'Content-type'})
        content = parsed_meta['content']
        self.assertEqual('text/html; charset=x-sjis', content)

        # But that value is actually a ContentMetaAttributeValue object.
        self.assertTrue(isinstance(content, ContentMetaAttributeValue))

        # And it will take on a value that reflects its current
        # encoding.
        self.assertEqual('text/html; charset=utf8', content.encode("utf8"))

        # For the rest of the story, see TestSubstitutions in
        # test_tree.py.

    def test_html5_style_meta_tag_reflects_current_encoding(self):
        # Here's the <meta> tag saying that a document is
        # encoded in Shift-JIS.
        meta_tag = ('<meta id="encoding" charset="x-sjis" />')

        # Here's a document incorporating that meta tag.
        shift_jis_html = (
            '<html><head>\n%s\n'
            '<meta http-equiv="Content-language" content="ja"/>'
            '</head><body>Shift-JIS markup goes here.') % meta_tag
        soup = self.soup(shift_jis_html)

        # Parse the document, and the charset is seemingly unaffected.
        parsed_meta = soup.find('meta', id="encoding")
        charset = parsed_meta['charset']
        self.assertEqual('x-sjis', charset)

        # But that value is actually a CharsetMetaAttributeValue object.
        self.assertTrue(isinstance(charset, CharsetMetaAttributeValue))

        # And it will take on a value that reflects its current
        # encoding.
        self.assertEqual('utf8', charset.encode("utf8"))

    def test_tag_with_no_attributes_can_have_attributes_added(self):
        data = self.soup("<a>text</a>")
        data.a['foo'] = 'bar'
        self.assertEqual('<a foo="bar">text</a>', data.a.decode())

class XMLTreeBuilderSmokeTest(object):

    def test_pickle_and_unpickle_identity(self):
        # Pickling a tree, then unpickling it, yields a tree identical
        # to the original.
        tree = self.soup("<a><b>foo</a>")
        dumped = pickle.dumps(tree, 2)
        loaded = pickle.loads(dumped)
        self.assertEqual(loaded.__class__, BeautifulSoup)
        self.assertEqual(loaded.decode(), tree.decode())

    def test_docstring_generated(self):
        soup = self.soup("<root/>")
        self.assertEqual(
            soup.encode(), b'<?xml version="1.0" encoding="utf-8"?>\n<root/>')

    def test_real_xhtml_document(self):
        """A real XHTML document should come out *exactly* the same as it went in."""
        markup = b"""<?xml version="1.0" encoding="utf-8"?>
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN">
<html xmlns="http://www.w3.org/1999/xhtml">
<head><title>Hello.</title></head>
<body>Goodbye.</body>
</html>"""
        soup = self.soup(markup)
        self.assertEqual(
            soup.encode("utf-8"), markup)

    def test_formatter_processes_script_tag_for_xml_documents(self):
        doc = """
<script type="text/javascript">
</script>
"""
        soup = BeautifulSoup(doc, "lxml-xml")
        # lxml would have stripped this while parsing, but we can add
        # it later.
        soup.script.string = 'console.log("< < hey > > ");'
        encoded = soup.encode()
self.assertTrue(b"< < hey > >" in encoded)
|

    def test_can_parse_unicode_document(self):
        markup = u'<?xml version="1.0" encoding="euc-jp"><root>Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!</root>'
        soup = self.soup(markup)
        self.assertEqual(u'Sacr\xe9 bleu!', soup.root.string)

    def test_popping_namespaced_tag(self):
        markup = '<rss xmlns:dc="foo"><dc:creator>b</dc:creator><dc:date>2012-07-02T20:33:42Z</dc:date><dc:rights>c</dc:rights><image>d</image></rss>'
        soup = self.soup(markup)
        self.assertEqual(
            unicode(soup.rss), markup)

    def test_docstring_includes_correct_encoding(self):
        soup = self.soup("<root/>")
        self.assertEqual(
            soup.encode("latin1"),
            b'<?xml version="1.0" encoding="latin1"?>\n<root/>')

    def test_large_xml_document(self):
        """A large XML document should come out the same as it went in."""
        markup = (b'<?xml version="1.0" encoding="utf-8"?>\n<root>'
                  + b'0' * (2**12)
                  + b'</root>')
        soup = self.soup(markup)
        self.assertEqual(soup.encode("utf-8"), markup)


    def test_tags_are_empty_element_if_and_only_if_they_are_empty(self):
        self.assertSoupEquals("<p>", "<p/>")
        self.assertSoupEquals("<p>foo</p>")

    def test_namespaces_are_preserved(self):
        markup = '<root xmlns:a="http://example.com/" xmlns:b="http://example.net/"><a:foo>This tag is in the a namespace</a:foo><b:foo>This tag is in the b namespace</b:foo></root>'
        soup = self.soup(markup)
        root = soup.root
        self.assertEqual("http://example.com/", root['xmlns:a'])
        self.assertEqual("http://example.net/", root['xmlns:b'])

    def test_closing_namespaced_tag(self):
        markup = '<p xmlns:dc="http://purl.org/dc/elements/1.1/"><dc:date>20010504</dc:date></p>'
        soup = self.soup(markup)
        self.assertEqual(unicode(soup.p), markup)

    def test_namespaced_attributes(self):
        markup = '<foo xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"><bar xsi:schemaLocation="http://www.example.com"/></foo>'
        soup = self.soup(markup)
        self.assertEqual(unicode(soup.foo), markup)

    def test_namespaced_attributes_xml_namespace(self):
        markup = '<foo xml:lang="fr">bar</foo>'
        soup = self.soup(markup)
        self.assertEqual(unicode(soup.foo), markup)

class HTML5TreeBuilderSmokeTest(HTMLTreeBuilderSmokeTest):
    """Smoke test for a tree builder that supports HTML5."""

    def test_real_xhtml_document(self):
        # Since XHTML is not HTML5, HTML5 parsers are not tested to handle
        # XHTML documents in any particular way.
        pass

    def test_html_tags_have_namespace(self):
        markup = "<a>"
        soup = self.soup(markup)
        self.assertEqual("http://www.w3.org/1999/xhtml", soup.a.namespace)

    def test_svg_tags_have_namespace(self):
        markup = '<svg><circle/></svg>'
        soup = self.soup(markup)
        namespace = "http://www.w3.org/2000/svg"
        self.assertEqual(namespace, soup.svg.namespace)
        self.assertEqual(namespace, soup.circle.namespace)


    def test_mathml_tags_have_namespace(self):
        markup = '<math><msqrt>5</msqrt></math>'
        soup = self.soup(markup)
        namespace = 'http://www.w3.org/1998/Math/MathML'
        self.assertEqual(namespace, soup.math.namespace)
        self.assertEqual(namespace, soup.msqrt.namespace)

    def test_xml_declaration_becomes_comment(self):
        markup = '<?xml version="1.0" encoding="utf-8"?><html></html>'
        soup = self.soup(markup)
        self.assertTrue(isinstance(soup.contents[0], Comment))
        self.assertEqual(soup.contents[0], '?xml version="1.0" encoding="utf-8"?')
        self.assertEqual("html", soup.contents[0].next_element.name)

def skipIf(condition, reason):
    def nothing(test, *args, **kwargs):
        return None

    def decorator(test_item):
        if condition:
            return nothing
        else:
            return test_item

    return decorator
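# A sketch of how these helpers are meant to be combined in a concrete test
# module; the subclass name below is illustrative, not part of this file.
#
#     from bs4.builder import HTMLParserTreeBuilder
#     from bs4.testing import SoupTest, HTMLTreeBuilderSmokeTest
#
#     class HTMLParserSmokeTest(SoupTest, HTMLTreeBuilderSmokeTest):
#         """Run the generic smoke tests against the html.parser builder."""
#
#         @property
#         def default_builder(self):
#             return HTMLParserTreeBuilder()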
@@ -8,9 +8,8 @@
 Certificate generation module.
 """
 
-import time
-
 from OpenSSL import crypto
+import time
 
 TYPE_RSA = crypto.TYPE_RSA
 TYPE_DSA = crypto.TYPE_DSA
@@ -30,7 +29,6 @@ def createKeyPair(type, bits):
     pkey.generate_key(type, bits)
     return pkey
 
-
 def createCertRequest(pkey, digest="md5", **name):
     """
     Create a certificate request.
@@ -51,14 +49,13 @@ def createCertRequest(pkey, digest="md5", **name):
     req = crypto.X509Req()
     subj = req.get_subject()
 
-    for (key, value) in name.items():
+    for (key,value) in name.items():
         setattr(subj, key, value)
 
     req.set_pubkey(pkey)
     req.sign(pkey, digest)
     return req
 
-
 def createCertificate(req, (issuerCert, issuerKey), serial, (notBefore, notAfter), digest="md5"):
     """
     Generate a certificate given a certificate request.
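# A hedged usage sketch of the certificate helpers above; 'certgen' is an
# assumed import name for this module (it follows pyOpenSSL's classic
# certgen example, where a request can act as its own issuer).
#
#     from OpenSSL import crypto
#     import certgen  # assumed module name
#
#     # 2048-bit RSA key, a request for CN=localhost, then a self-signed
#     # certificate valid from now for one year.
#     pkey = certgen.createKeyPair(certgen.TYPE_RSA, 2048)
#     req = certgen.createCertRequest(pkey, CN='localhost')
#     cert = certgen.createCertificate(
#         req, (req, pkey), 1, (0, 60 * 60 * 24 * 365))
#
#     open('server.key', 'w').write(
#         crypto.dump_privatekey(crypto.FILETYPE_PEM, pkey))
#     open('server.crt', 'w').write(
#         crypto.dump_certificate(crypto.FILETYPE_PEM, cert))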
25
lib/cherrypy/LICENSE.txt
Normal file
@@ -0,0 +1,25 @@
Copyright (c) 2004-2011, CherryPy Team (team@cherrypy.org)
All rights reserved.

Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:

    * Redistributions of source code must retain the above copyright notice,
      this list of conditions and the following disclaimer.
    * Redistributions in binary form must reproduce the above copyright notice,
      this list of conditions and the following disclaimer in the documentation
      and/or other materials provided with the distribution.
    * Neither the name of the CherryPy Team nor the names of its contributors
      may be used to endorse or promote products derived from this software
      without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
652
lib/cherrypy/__init__.py
Normal file
@@ -0,0 +1,652 @@
"""CherryPy is a pythonic, object-oriented HTTP framework.
|
||||
|
||||
|
||||
CherryPy consists of not one, but four separate API layers.
|
||||
|
||||
The APPLICATION LAYER is the simplest. CherryPy applications are written as
|
||||
a tree of classes and methods, where each branch in the tree corresponds to
|
||||
a branch in the URL path. Each method is a 'page handler', which receives
|
||||
GET and POST params as keyword arguments, and returns or yields the (HTML)
|
||||
body of the response. The special method name 'index' is used for paths
|
||||
that end in a slash, and the special method name 'default' is used to
|
||||
handle multiple paths via a single handler. This layer also includes:
|
||||
|
||||
* the 'exposed' attribute (and cherrypy.expose)
|
||||
* cherrypy.quickstart()
|
||||
* _cp_config attributes
|
||||
* cherrypy.tools (including cherrypy.session)
|
||||
* cherrypy.url()
|
||||
|
||||
The ENVIRONMENT LAYER is used by developers at all levels. It provides
|
||||
information about the current request and response, plus the application
|
||||
and server environment, via a (default) set of top-level objects:
|
||||
|
||||
* cherrypy.request
|
||||
* cherrypy.response
|
||||
* cherrypy.engine
|
||||
* cherrypy.server
|
||||
* cherrypy.tree
|
||||
* cherrypy.config
|
||||
* cherrypy.thread_data
|
||||
* cherrypy.log
|
||||
* cherrypy.HTTPError, NotFound, and HTTPRedirect
|
||||
* cherrypy.lib
|
||||
|
||||
The EXTENSION LAYER allows advanced users to construct and share their own
|
||||
plugins. It consists of:
|
||||
|
||||
* Hook API
|
||||
* Tool API
|
||||
* Toolbox API
|
||||
* Dispatch API
|
||||
* Config Namespace API
|
||||
|
||||
Finally, there is the CORE LAYER, which uses the core API's to construct
|
||||
the default components which are available at higher layers. You can think
|
||||
of the default components as the 'reference implementation' for CherryPy.
|
||||
Megaframeworks (and advanced users) may replace the default components
|
||||
with customized or extended components. The core API's are:
|
||||
|
||||
* Application API
|
||||
* Engine API
|
||||
* Request API
|
||||
* Server API
|
||||
* WSGI API
|
||||
|
||||
These API's are described in the `CherryPy specification <https://bitbucket.org/cherrypy/cherrypy/wiki/CherryPySpec>`_.
|
||||
"""
|
||||
|
||||
__version__ = "3.6.0"
|
||||
|
||||
from cherrypy._cpcompat import urljoin as _urljoin, urlencode as _urlencode
|
||||
from cherrypy._cpcompat import basestring, unicodestr, set
|
||||
|
||||
from cherrypy._cperror import HTTPError, HTTPRedirect, InternalRedirect
|
||||
from cherrypy._cperror import NotFound, CherryPyException, TimeoutError
|
||||
|
||||
from cherrypy import _cpdispatch as dispatch
|
||||
|
||||
from cherrypy import _cptools
|
||||
tools = _cptools.default_toolbox
|
||||
Tool = _cptools.Tool
|
||||
|
||||
from cherrypy import _cprequest
|
||||
from cherrypy.lib import httputil as _httputil
|
||||
|
||||
from cherrypy import _cptree
|
||||
tree = _cptree.Tree()
|
||||
from cherrypy._cptree import Application
|
||||
from cherrypy import _cpwsgi as wsgi
|
||||
|
||||
from cherrypy import process
|
||||
try:
|
||||
from cherrypy.process import win32
|
||||
engine = win32.Win32Bus()
|
||||
engine.console_control_handler = win32.ConsoleCtrlHandler(engine)
|
||||
del win32
|
||||
except ImportError:
|
||||
engine = process.bus
|
||||
|
||||
|
||||
# Timeout monitor. We add two channels to the engine
|
||||
# to which cherrypy.Application will publish.
|
||||
engine.listeners['before_request'] = set()
|
||||
engine.listeners['after_request'] = set()
|
||||
|
||||
|
||||
class _TimeoutMonitor(process.plugins.Monitor):
|
||||
|
||||
def __init__(self, bus):
|
||||
self.servings = []
|
||||
process.plugins.Monitor.__init__(self, bus, self.run)
|
||||
|
||||
def before_request(self):
|
||||
self.servings.append((serving.request, serving.response))
|
||||
|
||||
def after_request(self):
|
||||
try:
|
||||
self.servings.remove((serving.request, serving.response))
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def run(self):
|
||||
"""Check timeout on all responses. (Internal)"""
|
||||
for req, resp in self.servings:
|
||||
resp.check_timeout()
|
||||
engine.timeout_monitor = _TimeoutMonitor(engine)
|
||||
engine.timeout_monitor.subscribe()
|
||||
|
||||
engine.autoreload = process.plugins.Autoreloader(engine)
|
||||
engine.autoreload.subscribe()
|
||||
|
||||
engine.thread_manager = process.plugins.ThreadManager(engine)
|
||||
engine.thread_manager.subscribe()
|
||||
|
||||
engine.signal_handler = process.plugins.SignalHandler(engine)
|
||||
|
||||
|
||||
class _HandleSignalsPlugin(object):
|
||||
|
||||
"""Handle signals from other processes based on the configured
|
||||
platform handlers above."""
|
||||
|
||||
def __init__(self, bus):
|
||||
self.bus = bus
|
||||
|
||||
def subscribe(self):
|
||||
"""Add the handlers based on the platform"""
|
||||
if hasattr(self.bus, "signal_handler"):
|
||||
self.bus.signal_handler.subscribe()
|
||||
if hasattr(self.bus, "console_control_handler"):
|
||||
self.bus.console_control_handler.subscribe()
|
||||
|
||||
engine.signals = _HandleSignalsPlugin(engine)
|
||||
|
||||
|
||||
from cherrypy import _cpserver
|
||||
server = _cpserver.Server()
|
||||
server.subscribe()
|
||||
|
||||
|
||||
def quickstart(root=None, script_name="", config=None):
|
||||
"""Mount the given root, start the builtin server (and engine), then block.
|
||||
|
||||
root: an instance of a "controller class" (a collection of page handler
|
||||
methods) which represents the root of the application.
|
||||
script_name: a string containing the "mount point" of the application.
|
||||
This should start with a slash, and be the path portion of the URL
|
||||
at which to mount the given root. For example, if root.index() will
|
||||
handle requests to "http://www.example.com:8080/dept/app1/", then
|
||||
the script_name argument would be "/dept/app1".
|
||||
|
||||
It MUST NOT end in a slash. If the script_name refers to the root
|
||||
of the URI, it MUST be an empty string (not "/").
|
||||
config: a file or dict containing application config. If this contains
|
||||
a [global] section, those entries will be used in the global
|
||||
(site-wide) config.
|
||||
"""
|
||||
if config:
|
||||
_global_conf_alias.update(config)
|
||||
|
||||
tree.mount(root, script_name, config)
|
||||
|
||||
engine.signals.subscribe()
|
||||
engine.start()
|
||||
engine.block()
|
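# A minimal quickstart() sketch; HelloWorld is an illustrative class, not
# part of CherryPy. This mounts the app at the URI root (script_name ""),
# starts the engine, and blocks until shutdown.
#
#     import cherrypy
#
#     class HelloWorld(object):
#         @cherrypy.expose
#         def index(self):
#             return "Hello world!"
#
#     if __name__ == '__main__':
#         cherrypy.quickstart(HelloWorld())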

from cherrypy._cpcompat import threadlocal as _local


class _Serving(_local):

    """An interface for registering request and response objects.

    Rather than have a separate "thread local" object for the request and
    the response, this class works as a single threadlocal container for
    both objects (and any others which developers wish to define). In this
    way, we can easily dump those objects when we stop/start a new HTTP
    conversation, yet still refer to them as module-level globals in a
    thread-safe way.
    """

    request = _cprequest.Request(_httputil.Host("127.0.0.1", 80),
                                 _httputil.Host("127.0.0.1", 1111))
    """
    The request object for the current thread. In the main thread,
    and any threads which are not receiving HTTP requests, this is None."""

    response = _cprequest.Response()
    """
    The response object for the current thread. In the main thread,
    and any threads which are not receiving HTTP requests, this is None."""

    def load(self, request, response):
        self.request = request
        self.response = response

    def clear(self):
        """Remove all attributes of self."""
        self.__dict__.clear()

serving = _Serving()


class _ThreadLocalProxy(object):

    __slots__ = ['__attrname__', '__dict__']

    def __init__(self, attrname):
        self.__attrname__ = attrname

    def __getattr__(self, name):
        child = getattr(serving, self.__attrname__)
        return getattr(child, name)

    def __setattr__(self, name, value):
        if name in ("__attrname__", ):
            object.__setattr__(self, name, value)
        else:
            child = getattr(serving, self.__attrname__)
            setattr(child, name, value)

    def __delattr__(self, name):
        child = getattr(serving, self.__attrname__)
        delattr(child, name)

    def _get_dict(self):
        child = getattr(serving, self.__attrname__)
        d = child.__class__.__dict__.copy()
        d.update(child.__dict__)
        return d
    __dict__ = property(_get_dict)

    def __getitem__(self, key):
        child = getattr(serving, self.__attrname__)
        return child[key]

    def __setitem__(self, key, value):
        child = getattr(serving, self.__attrname__)
        child[key] = value

    def __delitem__(self, key):
        child = getattr(serving, self.__attrname__)
        del child[key]

    def __contains__(self, key):
        child = getattr(serving, self.__attrname__)
        return key in child

    def __len__(self):
        child = getattr(serving, self.__attrname__)
        return len(child)

    def __nonzero__(self):
        child = getattr(serving, self.__attrname__)
        return bool(child)
    # Python 3
    __bool__ = __nonzero__

# Create request and response object (the same objects will be used
# throughout the entire life of the webserver, but will redirect
# to the "serving" object)
request = _ThreadLocalProxy('request')
response = _ThreadLocalProxy('response')

# Create thread_data object as a thread-specific all-purpose storage


class _ThreadData(_local):

    """A container for thread-specific data."""
thread_data = _ThreadData()


# Monkeypatch pydoc to allow help() to go through the threadlocal proxy.
# Jan 2007: no Googleable examples of anyone else replacing pydoc.resolve.
# The only other way would be to change what is returned from type(request)
# and that's not possible in pure Python (you'd have to fake ob_type).
def _cherrypy_pydoc_resolve(thing, forceload=0):
    """Given an object or a path to an object, get the object and its name."""
    if isinstance(thing, _ThreadLocalProxy):
        thing = getattr(serving, thing.__attrname__)
    return _pydoc._builtin_resolve(thing, forceload)

try:
    import pydoc as _pydoc
    _pydoc._builtin_resolve = _pydoc.resolve
    _pydoc.resolve = _cherrypy_pydoc_resolve
except ImportError:
    pass


from cherrypy import _cplogging


class _GlobalLogManager(_cplogging.LogManager):

    """A site-wide LogManager; routes to app.log or global log as appropriate.

    This :class:`LogManager<cherrypy._cplogging.LogManager>` implements
    cherrypy.log() and cherrypy.log.access(). If either
    function is called during a request, the message will be sent to the
    logger for the current Application. If they are called outside of a
    request, the message will be sent to the site-wide logger.
    """

    def __call__(self, *args, **kwargs):
        """Log the given message to the app.log or global log as appropriate.
        """
        # Do NOT use try/except here. See
        # https://bitbucket.org/cherrypy/cherrypy/issue/945
        if hasattr(request, 'app') and hasattr(request.app, 'log'):
            log = request.app.log
        else:
            log = self
        return log.error(*args, **kwargs)

    def access(self):
        """Log an access message to the app.log or global log as appropriate.
        """
        try:
            return request.app.log.access()
        except AttributeError:
            return _cplogging.LogManager.access(self)


log = _GlobalLogManager()
# Set a default screen handler on the global log.
log.screen = True
log.error_file = ''
# Using an access file makes CP about 10% slower. Leave off by default.
log.access_file = ''


def _buslog(msg, level):
    log.error(msg, 'ENGINE', severity=level)
engine.subscribe('log', _buslog)

# Helper functions for CP apps #


def expose(func=None, alias=None):
    """Expose the function, optionally providing an alias or set of aliases."""
    def expose_(func):
        func.exposed = True
        if alias is not None:
            if isinstance(alias, basestring):
                parents[alias.replace(".", "_")] = func
            else:
                for a in alias:
                    parents[a.replace(".", "_")] = func
        return func

    import sys
    import types
    if isinstance(func, (types.FunctionType, types.MethodType)):
        if alias is None:
            # @expose
            func.exposed = True
            return func
        else:
            # func = expose(func, alias)
            parents = sys._getframe(1).f_locals
            return expose_(func)
    elif func is None:
        if alias is None:
            # @expose()
            parents = sys._getframe(1).f_locals
            return expose_
        else:
            # @expose(alias="alias") or
            # @expose(alias=["alias1", "alias2"])
            parents = sys._getframe(1).f_locals
            return expose_
    else:
        # @expose("alias") or
        # @expose(["alias1", "alias2"])
        parents = sys._getframe(1).f_locals
        alias = func
        return expose_
||||
|
||||
|
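For reference, a minimal sketch of the decoration forms that the expose() helper above supports; the Root class and its handler names are illustrative only, not part of this changeset.

import cherrypy


class Root(object):

    @cherrypy.expose
    def index(self):
        return "Hello world"

    # The string form registers an alias in the enclosing class body,
    # so this handler also answers at /page2.
    @cherrypy.expose("page2")
    def page(self):
        return "Aliased"
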
def popargs(*args, **kwargs):
    """A decorator for _cp_dispatch
    (cherrypy.dispatch.Dispatcher.dispatch_method_name).

    Optional keyword argument: handler=(Object or Function)

    Provides a _cp_dispatch function that pops off path segments into
    cherrypy.request.params under the names specified. The dispatch
    is then forwarded on to the next vpath element.

    Note that any existing (and exposed) member function of the class that
    popargs is applied to will override that value of the argument. For
    instance, if you have a method named "list" on the class decorated with
    popargs, then accessing "/list" will call that function instead of popping
    it off as the requested parameter. This restriction applies to all
    _cp_dispatch functions. The only way around this restriction is to create
    a "blank class" whose only function is to provide _cp_dispatch.

    If there are path elements after the arguments, or more arguments
    are requested than are available in the vpath, then the 'handler'
    keyword argument specifies the next object to handle the parameterized
    request. If handler is not specified or is None, then self is used.
    If handler is a function rather than an instance, then that function
    will be called with the args specified and the return value from that
    function used as the next object INSTEAD of adding the parameters to
    cherrypy.request.args.

    This decorator may be used in one of two ways:

    As a class decorator:
    @cherrypy.popargs('year', 'month', 'day')
    class Blog:
        def index(self, year=None, month=None, day=None):
            #Process the parameters here; any url like
            #/, /2009, /2009/12, or /2009/12/31
            #will fill in the appropriate parameters.

        def create(self):
            #This link will still be available at /create. Defined functions
            #take precedence over arguments.

    Or as a member of a class:
    class Blog:
        _cp_dispatch = cherrypy.popargs('year', 'month', 'day')
        #...

    The handler argument may be used to mix arguments with built in functions.
    For instance, the following setup allows different activities at the
    day, month, and year level:

    class DayHandler:
        def index(self, year, month, day):
            #Do something with this day; probably list entries

        def delete(self, year, month, day):
            #Delete all entries for this day

    @cherrypy.popargs('day', handler=DayHandler())
    class MonthHandler:
        def index(self, year, month):
            #Do something with this month; probably list entries

        def delete(self, year, month):
            #Delete all entries for this month

    @cherrypy.popargs('month', handler=MonthHandler())
    class YearHandler:
        def index(self, year):
            #Do something with this year

        #...

    @cherrypy.popargs('year', handler=YearHandler())
    class Root:
        def index(self):
            #...

    """

    # Since keyword arg comes after *args, we have to process it ourselves
    # for lower versions of python.

    handler = None
    handler_call = False
    for k, v in kwargs.items():
        if k == 'handler':
            handler = v
        else:
            raise TypeError(
                "cherrypy.popargs() got an unexpected keyword argument '{0}'"
                .format(k)
            )

    import inspect

    if handler is not None \
            and (hasattr(handler, '__call__') or inspect.isclass(handler)):
        handler_call = True

    def decorated(cls_or_self=None, vpath=None):
        if inspect.isclass(cls_or_self):
            # cherrypy.popargs is a class decorator
            cls = cls_or_self
            setattr(cls, dispatch.Dispatcher.dispatch_method_name, decorated)
            return cls

        # We're in the actual function
        self = cls_or_self
        parms = {}
        for arg in args:
            if not vpath:
                break
            parms[arg] = vpath.pop(0)

        if handler is not None:
            if handler_call:
                return handler(**parms)
            else:
                request.params.update(parms)
                return handler

        request.params.update(parms)

        # If we are the ultimate handler, then to prevent our _cp_dispatch
        # from being called again, we will resolve remaining elements through
        # getattr() directly.
        if vpath:
            return getattr(self, vpath.pop(0), None)
        else:
            return self

    return decorated

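The class-decorator form from the docstring above, assembled into a runnable sketch; the Blog class body and the quickstart call are illustrative:

import cherrypy


@cherrypy.popargs('year', 'month', 'day')
class Blog(object):

    @cherrypy.expose
    def index(self, year=None, month=None, day=None):
        # Any of /, /2009, /2009/12, or /2009/12/31 lands here, with the
        # consumed path segments delivered as keyword parameters.
        return "blog: year=%r month=%r day=%r" % (year, month, day)

if __name__ == '__main__':
    cherrypy.quickstart(Blog())
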
def url(path="", qs="", script_name=None, base=None, relative=None):
    """Create an absolute URL for the given path.

    If 'path' starts with a slash ('/'), this will return
    (base + script_name + path + qs).
    If it does not start with a slash, this returns
    (base + script_name [+ request.path_info] + path + qs).

    If script_name is None, cherrypy.request will be used
    to find a script_name, if available.

    If base is None, cherrypy.request.base will be used (if available).
    Note that you can use cherrypy.tools.proxy to change this.

    Finally, note that this function can be used to obtain an absolute URL
    for the current request path (minus the querystring) by passing no args.
    If you call url(qs=cherrypy.request.query_string), you should get the
    original browser URL (assuming no internal redirections).

    If relative is None or not provided, request.app.relative_urls will
    be used (if available, else False). If False, the output will be an
    absolute URL (including the scheme, host, vhost, and script_name).
    If True, the output will instead be a URL that is relative to the
    current request path, perhaps including '..' atoms. If relative is
    the string 'server', the output will instead be a URL that is
    relative to the server root; i.e., it will start with a slash.
    """
    if isinstance(qs, (tuple, list, dict)):
        qs = _urlencode(qs)
    if qs:
        qs = '?' + qs

    if request.app:
        if not path.startswith("/"):
            # Append/remove trailing slash from path_info as needed
            # (this is to support mistyped URL's without redirecting;
            # if you want to redirect, use tools.trailing_slash).
            pi = request.path_info
            if request.is_index is True:
                if not pi.endswith('/'):
                    pi = pi + '/'
            elif request.is_index is False:
                if pi.endswith('/') and pi != '/':
                    pi = pi[:-1]

            if path == "":
                path = pi
            else:
                path = _urljoin(pi, path)

        if script_name is None:
            script_name = request.script_name
        if base is None:
            base = request.base

        newurl = base + script_name + path + qs
    else:
        # No request.app (we're being called outside a request).
        # We'll have to guess the base from server.* attributes.
        # This will produce very different results from the above
        # if you're using vhosts or tools.proxy.
        if base is None:
            base = server.base()

        path = (script_name or "") + path
        newurl = base + path + qs

    if './' in newurl:
        # Normalize the URL by removing ./ and ../
        atoms = []
        for atom in newurl.split('/'):
            if atom == '.':
                pass
            elif atom == '..':
                atoms.pop()
            else:
                atoms.append(atom)
        newurl = '/'.join(atoms)

    # At this point, we should have a fully-qualified absolute URL.

    if relative is None:
        relative = getattr(request.app, "relative_urls", False)

    # See http://www.ietf.org/rfc/rfc2396.txt
    if relative == 'server':
        # "A relative reference beginning with a single slash character is
        # termed an absolute-path reference, as defined by <abs_path>..."
        # This is also sometimes called "server-relative".
        newurl = '/' + '/'.join(newurl.split('/', 3)[3:])
    elif relative:
        # "A relative reference that does not begin with a scheme name
        # or a slash character is termed a relative-path reference."
        old = url(relative=False).split('/')[:-1]
        new = newurl.split('/')
        while old and new:
            a, b = old[0], new[0]
            if a != b:
                break
            old.pop(0)
            new.pop(0)
        new = (['..'] * len(old)) + new
        newurl = '/'.join(new)

    return newurl

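The './' and '../' normalization step in url() above is self-contained; a standalone sketch of the same loop (the normalize name is hypothetical):

def normalize(u):
    # Drop '.' atoms; let each '..' consume the atom before it.
    atoms = []
    for atom in u.split('/'):
        if atom == '.':
            pass
        elif atom == '..':
            atoms.pop()
        else:
            atoms.append(atom)
    return '/'.join(atoms)

print(normalize('http://host/app/a/./b/../c'))  # http://host/app/a/c
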
# import _cpconfig last so it can reference other top-level objects
from cherrypy import _cpconfig
# Use _global_conf_alias so quickstart can use 'config' as an arg
# without shadowing cherrypy.config.
config = _global_conf_alias = _cpconfig.Config()
config.defaults = {
    'tools.log_tracebacks.on': True,
    'tools.log_headers.on': True,
    'tools.trailing_slash.on': True,
    'tools.encode.on': True
}
config.namespaces["log"] = lambda k, v: setattr(log, k, v)
config.namespaces["checker"] = lambda k, v: setattr(checker, k, v)
# Must reset to get our defaults applied.
config.reset()

from cherrypy import _cpchecker
checker = _cpchecker.Checker()
engine.subscribe('start', checker)
332
lib/cherrypy/_cpchecker.py
Normal file
@@ -0,0 +1,332 @@
import os
import warnings

import cherrypy
from cherrypy._cpcompat import iteritems, copykeys, builtins


class Checker(object):

    """A checker for CherryPy sites and their mounted applications.

    When this object is called at engine startup, it executes each
    of its own methods whose names start with ``check_``. If you wish
    to disable selected checks, simply add a line in your global
    config which sets the appropriate method to False::

        [global]
        checker.check_skipped_app_config = False

    You may also dynamically add or replace ``check_*`` methods in this way.
    """

    on = True
    """If True (the default), run all checks; if False, turn off all checks."""

    def __init__(self):
        self._populate_known_types()

    def __call__(self):
        """Run all check_* methods."""
        if self.on:
            oldformatwarning = warnings.formatwarning
            warnings.formatwarning = self.formatwarning
            try:
                for name in dir(self):
                    if name.startswith("check_"):
                        method = getattr(self, name)
                        if method and hasattr(method, '__call__'):
                            method()
            finally:
                warnings.formatwarning = oldformatwarning

    def formatwarning(self, message, category, filename, lineno, line=None):
        """Function to format a warning."""
        return "CherryPy Checker:\n%s\n\n" % message

    # This value should be set inside _cpconfig.
    global_config_contained_paths = False

    def check_app_config_entries_dont_start_with_script_name(self):
        """Check for Application config with sections that repeat script_name.
        """
        for sn, app in cherrypy.tree.apps.items():
            if not isinstance(app, cherrypy.Application):
                continue
            if not app.config:
                continue
            if sn == '':
                continue
            sn_atoms = sn.strip("/").split("/")
            for key in app.config.keys():
                key_atoms = key.strip("/").split("/")
                if key_atoms[:len(sn_atoms)] == sn_atoms:
                    warnings.warn(
                        "The application mounted at %r has config "
                        "entries that start with its script name: %r" % (sn,
                                                                         key))

    def check_site_config_entries_in_app_config(self):
        """Check for mounted Applications that have site-scoped config."""
        for sn, app in iteritems(cherrypy.tree.apps):
            if not isinstance(app, cherrypy.Application):
                continue

            msg = []
            for section, entries in iteritems(app.config):
                if section.startswith('/'):
                    for key, value in iteritems(entries):
                        for n in ("engine.", "server.", "tree.", "checker."):
                            if key.startswith(n):
                                msg.append("[%s] %s = %s" %
                                           (section, key, value))
            if msg:
                msg.insert(0,
                           "The application mounted at %r contains the "
                           "following config entries, which are only allowed "
                           "in site-wide config. Move them to a [global] "
                           "section and pass them to cherrypy.config.update() "
                           "instead of tree.mount()." % sn)
                warnings.warn(os.linesep.join(msg))

    def check_skipped_app_config(self):
        """Check for mounted Applications that have no config."""
        for sn, app in cherrypy.tree.apps.items():
            if not isinstance(app, cherrypy.Application):
                continue
            if not app.config:
                msg = "The Application mounted at %r has an empty config." % sn
                if self.global_config_contained_paths:
                    msg += (" It looks like the config you passed to "
                            "cherrypy.config.update() contains application-"
                            "specific sections. You must explicitly pass "
                            "application config via "
                            "cherrypy.tree.mount(..., config=app_config)")
                warnings.warn(msg)
                return

    def check_app_config_brackets(self):
        """Check for Application config with extraneous brackets in section
        names.
        """
        for sn, app in cherrypy.tree.apps.items():
            if not isinstance(app, cherrypy.Application):
                continue
            if not app.config:
                continue
            for key in app.config.keys():
                if key.startswith("[") or key.endswith("]"):
                    warnings.warn(
                        "The application mounted at %r has config "
                        "section names with extraneous brackets: %r. "
                        "Config *files* need brackets; config *dicts* "
                        "(e.g. passed to tree.mount) do not." % (sn, key))

    def check_static_paths(self):
        """Check Application config for incorrect static paths."""
        # Use the dummy Request object in the main thread.
        request = cherrypy.request
        for sn, app in cherrypy.tree.apps.items():
            if not isinstance(app, cherrypy.Application):
                continue
            request.app = app
            for section in app.config:
                # get_resource will populate request.config
                request.get_resource(section + "/dummy.html")
                conf = request.config.get

                if conf("tools.staticdir.on", False):
                    msg = ""
                    root = conf("tools.staticdir.root")
                    dir = conf("tools.staticdir.dir")
                    if dir is None:
                        msg = "tools.staticdir.dir is not set."
                    else:
                        fulldir = ""
                        if os.path.isabs(dir):
                            fulldir = dir
                            if root:
                                msg = ("dir is an absolute path, even "
                                       "though a root is provided.")
                                testdir = os.path.join(root, dir[1:])
                                if os.path.exists(testdir):
                                    msg += (
                                        "\nIf you meant to serve the "
                                        "filesystem folder at %r, remove the "
                                        "leading slash from dir." % (testdir,))
                        else:
                            if not root:
                                msg = (
                                    "dir is a relative path and "
                                    "no root provided.")
                            else:
                                fulldir = os.path.join(root, dir)
                                if not os.path.isabs(fulldir):
                                    msg = ("%r is not an absolute path." % (
                                        fulldir,))

                    if fulldir and not os.path.exists(fulldir):
                        if msg:
                            msg += "\n"
                        msg += ("%r (root + dir) is not an existing "
                                "filesystem path." % fulldir)

                    if msg:
                        warnings.warn("%s\nsection: [%s]\nroot: %r\ndir: %r"
                                      % (msg, section, root, dir))

    # -------------------------- Compatibility -------------------------- #
    obsolete = {
        'server.default_content_type': 'tools.response_headers.headers',
        'log_access_file': 'log.access_file',
        'log_config_options': None,
        'log_file': 'log.error_file',
        'log_file_not_found': None,
        'log_request_headers': 'tools.log_headers.on',
        'log_to_screen': 'log.screen',
        'show_tracebacks': 'request.show_tracebacks',
        'throw_errors': 'request.throw_errors',
        'profiler.on': ('cherrypy.tree.mount(profiler.make_app('
                        'cherrypy.Application(Root())))'),
    }

    deprecated = {}

    def _compat(self, config):
        """Process config and warn on each obsolete or deprecated entry."""
        for section, conf in config.items():
            if isinstance(conf, dict):
                for k, v in conf.items():
                    if k in self.obsolete:
                        warnings.warn("%r is obsolete. Use %r instead.\n"
                                      "section: [%s]" %
                                      (k, self.obsolete[k], section))
                    elif k in self.deprecated:
                        warnings.warn("%r is deprecated. Use %r instead.\n"
                                      "section: [%s]" %
                                      (k, self.deprecated[k], section))
            else:
                if section in self.obsolete:
                    warnings.warn("%r is obsolete. Use %r instead."
                                  % (section, self.obsolete[section]))
                elif section in self.deprecated:
                    warnings.warn("%r is deprecated. Use %r instead."
                                  % (section, self.deprecated[section]))

    def check_compatibility(self):
        """Process config and warn on each obsolete or deprecated entry."""
        self._compat(cherrypy.config)
        for sn, app in cherrypy.tree.apps.items():
            if not isinstance(app, cherrypy.Application):
                continue
            self._compat(app.config)

    # ------------------------ Known Namespaces ------------------------ #
    extra_config_namespaces = []

    def _known_ns(self, app):
        ns = ["wsgi"]
        ns.extend(copykeys(app.toolboxes))
        ns.extend(copykeys(app.namespaces))
        ns.extend(copykeys(app.request_class.namespaces))
        ns.extend(copykeys(cherrypy.config.namespaces))
        ns += self.extra_config_namespaces

        for section, conf in app.config.items():
            is_path_section = section.startswith("/")
            if is_path_section and isinstance(conf, dict):
                for k, v in conf.items():
                    atoms = k.split(".")
                    if len(atoms) > 1:
                        if atoms[0] not in ns:
                            # Spit out a special warning if a known
                            # namespace is preceded by "cherrypy."
                            if atoms[0] == "cherrypy" and atoms[1] in ns:
                                msg = (
                                    "The config entry %r is invalid; "
                                    "try %r instead.\nsection: [%s]"
                                    % (k, ".".join(atoms[1:]), section))
                            else:
                                msg = (
                                    "The config entry %r is invalid, "
                                    "because the %r config namespace "
                                    "is unknown.\n"
                                    "section: [%s]" % (k, atoms[0], section))
                            warnings.warn(msg)
                        elif atoms[0] == "tools":
                            if atoms[1] not in dir(cherrypy.tools):
                                msg = (
                                    "The config entry %r may be invalid, "
                                    "because the %r tool was not found.\n"
                                    "section: [%s]" % (k, atoms[1], section))
                                warnings.warn(msg)

    def check_config_namespaces(self):
        """Process config and warn on each unknown config namespace."""
        for sn, app in cherrypy.tree.apps.items():
            if not isinstance(app, cherrypy.Application):
                continue
            self._known_ns(app)

    # -------------------------- Config Types -------------------------- #
    known_config_types = {}

    def _populate_known_types(self):
        b = [x for x in vars(builtins).values()
             if type(x) is type(str)]

        def traverse(obj, namespace):
            for name in dir(obj):
                # Hack for 3.2's warning about body_params
                if name == 'body_params':
                    continue
                vtype = type(getattr(obj, name, None))
                if vtype in b:
                    self.known_config_types[namespace + "." + name] = vtype

        traverse(cherrypy.request, "request")
        traverse(cherrypy.response, "response")
        traverse(cherrypy.server, "server")
        traverse(cherrypy.engine, "engine")
        traverse(cherrypy.log, "log")

    def _known_types(self, config):
        msg = ("The config entry %r in section %r is of type %r, "
               "which does not match the expected type %r.")

        for section, conf in config.items():
            if isinstance(conf, dict):
                for k, v in conf.items():
                    if v is not None:
                        expected_type = self.known_config_types.get(k, None)
                        vtype = type(v)
                        if expected_type and vtype != expected_type:
                            warnings.warn(msg % (k, section, vtype.__name__,
                                                 expected_type.__name__))
            else:
                k, v = section, conf
                if v is not None:
                    expected_type = self.known_config_types.get(k, None)
                    vtype = type(v)
                    if expected_type and vtype != expected_type:
                        warnings.warn(msg % (k, section, vtype.__name__,
                                             expected_type.__name__))

    def check_config_types(self):
        """Assert that config values are of the same type as default values."""
        self._known_types(cherrypy.config)
        for sn, app in cherrypy.tree.apps.items():
            if not isinstance(app, cherrypy.Application):
                continue
            self._known_types(app.config)

    # -------------------- Specific config warnings -------------------- #
    def check_localhost(self):
        """Warn if any socket_host is 'localhost'. See #711."""
        for k, v in cherrypy.config.items():
            if k == 'server.socket_host' and v == 'localhost':
                warnings.warn("The use of 'localhost' as a socket host can "
                              "cause problems on newer systems, since "
                              "'localhost' can map to either an IPv4 or an "
                              "IPv6 address. You should use '127.0.0.1' "
                              "or '[::1]' instead.")
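As the Checker docstring notes, check_* methods may also be added dynamically. A hypothetical extra check attached to the live checker instance; the check_root_mounted name and its warning text are assumptions, not part of this changeset:

import warnings

import cherrypy


def check_root_mounted():
    # Checker.__call__ invokes every attribute whose name starts with
    # "check_", so a plain function attribute works as well as a method.
    if '' not in cherrypy.tree.apps:
        warnings.warn("No application is mounted at the site root.")

cherrypy.checker.check_root_mounted = check_root_mounted
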
383
lib/cherrypy/_cpcompat.py
Normal file
@@ -0,0 +1,383 @@
"""Compatibility code for using CherryPy with various versions of Python.

CherryPy 3.2 is compatible with Python versions 2.3+. This module provides a
useful abstraction over the differences between Python versions, sometimes by
preferring a newer idiom, sometimes an older one, and sometimes a custom one.

In particular, Python 2 uses str and '' for byte strings, while Python 3
uses str and '' for unicode strings. We will call each of these the 'native
string' type for each version. Because of this major difference, this module
provides new 'bytestr', 'unicodestr', and 'nativestr' attributes, as well as
two functions: 'ntob', which translates native strings (of type 'str') into
byte strings regardless of Python version, and 'ntou', which translates native
strings to unicode strings. This also provides a 'BytesIO' name for dealing
specifically with bytes, and a 'StringIO' name for dealing with native strings.
It also provides a 'base64_decode' function with native strings as input and
output.
"""
import os
import re
import sys
import threading

if sys.version_info >= (3, 0):
    py3k = True
    bytestr = bytes
    unicodestr = str
    nativestr = unicodestr
    basestring = (bytes, str)

    def ntob(n, encoding='ISO-8859-1'):
        """Return the given native string as a byte string in the given
        encoding.
        """
        assert_native(n)
        # In Python 3, the native string type is unicode
        return n.encode(encoding)

    def ntou(n, encoding='ISO-8859-1'):
        """Return the given native string as a unicode string with the given
        encoding.
        """
        assert_native(n)
        # In Python 3, the native string type is unicode
        return n

    def tonative(n, encoding='ISO-8859-1'):
        """Return the given string as a native string in the given encoding."""
        # In Python 3, the native string type is unicode
        if isinstance(n, bytes):
            return n.decode(encoding)
        return n
    # type("")
    from io import StringIO
    # bytes:
    from io import BytesIO as BytesIO
else:
    # Python 2
    py3k = False
    bytestr = str
    unicodestr = unicode
    nativestr = bytestr
    basestring = basestring

    def ntob(n, encoding='ISO-8859-1'):
        """Return the given native string as a byte string in the given
        encoding.
        """
        assert_native(n)
        # In Python 2, the native string type is bytes. Assume it's already
        # in the given encoding, which for ISO-8859-1 is almost always what
        # was intended.
        return n

    def ntou(n, encoding='ISO-8859-1'):
        """Return the given native string as a unicode string with the given
        encoding.
        """
        assert_native(n)
        # In Python 2, the native string type is bytes.
        # First, check for the special encoding 'escape'. The test suite uses
        # this to signal that it wants to pass a string with embedded \uXXXX
        # escapes, but without having to prefix it with u'' for Python 2,
        # but no prefix for Python 3.
        if encoding == 'escape':
            return unicode(
                re.sub(r'\\u([0-9a-zA-Z]{4})',
                       lambda m: unichr(int(m.group(1), 16)),
                       n.decode('ISO-8859-1')))
        # Assume it's already in the given encoding, which for ISO-8859-1
        # is almost always what was intended.
        return n.decode(encoding)

    def tonative(n, encoding='ISO-8859-1'):
        """Return the given string as a native string in the given encoding."""
        # In Python 2, the native string type is bytes.
        if isinstance(n, unicode):
            return n.encode(encoding)
        return n
    try:
        # type("")
        from cStringIO import StringIO
    except ImportError:
        # type("")
        from StringIO import StringIO
    # bytes:
    BytesIO = StringIO


def assert_native(n):
    if not isinstance(n, nativestr):
        raise TypeError("n must be a native str (got %s)" % type(n).__name__)

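A quick round trip through the helpers defined above, assuming an ASCII-safe input (the example string is illustrative):

from cherrypy._cpcompat import ntob, ntou, py3k

b = ntob('abc')  # bytes on Python 3; the same str on Python 2
u = ntou('abc')  # unicode text on both major versions
print(py3k, type(b), type(u))
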
try:
    set = set
except NameError:
    from sets import Set as set

try:
    # Python 3.1+
    from base64 import decodebytes as _base64_decodebytes
except ImportError:
    # Python 3.0-
    # since CherryPy claims compatibility with Python 2.3, we must use
    # the legacy API of base64
    from base64 import decodestring as _base64_decodebytes


def base64_decode(n, encoding='ISO-8859-1'):
    """Return the native string base64-decoded (as a native string)."""
    if isinstance(n, unicodestr):
        b = n.encode(encoding)
    else:
        b = n
    b = _base64_decodebytes(b)
    if nativestr is unicodestr:
        return b.decode(encoding)
    else:
        return b

try:
    # Python 2.5+
    from hashlib import md5
except ImportError:
    from md5 import new as md5

try:
    # Python 2.5+
    from hashlib import sha1 as sha
except ImportError:
    from sha import new as sha

try:
    sorted = sorted
except NameError:
    def sorted(i):
        i = i[:]
        i.sort()
        return i

try:
    reversed = reversed
except NameError:
    def reversed(x):
        i = len(x)
        while i > 0:
            i -= 1
            yield x[i]

try:
    # Python 3
    from urllib.parse import urljoin, urlencode
    from urllib.parse import quote, quote_plus
    from urllib.request import unquote, urlopen
    from urllib.request import parse_http_list, parse_keqv_list
except ImportError:
    # Python 2
    from urlparse import urljoin
    from urllib import urlencode, urlopen
    from urllib import quote, quote_plus
    from urllib import unquote
    from urllib2 import parse_http_list, parse_keqv_list

try:
    from threading import local as threadlocal
except ImportError:
    from cherrypy._cpthreadinglocal import local as threadlocal

try:
    dict.iteritems
    # Python 2
    iteritems = lambda d: d.iteritems()
    copyitems = lambda d: d.items()
except AttributeError:
    # Python 3
    iteritems = lambda d: d.items()
    copyitems = lambda d: list(d.items())

try:
    dict.iterkeys
    # Python 2
    iterkeys = lambda d: d.iterkeys()
    copykeys = lambda d: d.keys()
except AttributeError:
    # Python 3
    iterkeys = lambda d: d.keys()
    copykeys = lambda d: list(d.keys())

try:
    dict.itervalues
    # Python 2
    itervalues = lambda d: d.itervalues()
    copyvalues = lambda d: d.values()
except AttributeError:
    # Python 3
    itervalues = lambda d: d.values()
    copyvalues = lambda d: list(d.values())

try:
    # Python 3
    import builtins
except ImportError:
    # Python 2
    import __builtin__ as builtins

try:
    # Python 2. We try Python 2 first so that clients on Python 2
    # don't try to import the 'http' module from cherrypy.lib.
    from Cookie import SimpleCookie, CookieError
    from httplib import BadStatusLine, HTTPConnection, IncompleteRead
    from httplib import NotConnected
    from BaseHTTPServer import BaseHTTPRequestHandler
except ImportError:
    # Python 3
    from http.cookies import SimpleCookie, CookieError
    from http.client import BadStatusLine, HTTPConnection, IncompleteRead
    from http.client import NotConnected
    from http.server import BaseHTTPRequestHandler

# Some platforms don't expose HTTPSConnection, so handle it separately
if py3k:
    try:
        from http.client import HTTPSConnection
    except ImportError:
        # Some platforms which don't have SSL don't expose HTTPSConnection
        HTTPSConnection = None
else:
    try:
        from httplib import HTTPSConnection
    except ImportError:
        HTTPSConnection = None

try:
    # Python 2
    xrange = xrange
except NameError:
    # Python 3
    xrange = range

import threading
if hasattr(threading.Thread, "daemon"):
    # Python 2.6+
    def get_daemon(t):
        return t.daemon

    def set_daemon(t, val):
        t.daemon = val
else:
    def get_daemon(t):
        return t.isDaemon()

    def set_daemon(t, val):
        t.setDaemon(val)

try:
    from email.utils import formatdate

    def HTTPDate(timeval=None):
        return formatdate(timeval, usegmt=True)
except ImportError:
    from rfc822 import formatdate as HTTPDate

try:
    # Python 3
    from urllib.parse import unquote as parse_unquote

    def unquote_qs(atom, encoding, errors='strict'):
        return parse_unquote(
            atom.replace('+', ' '),
            encoding=encoding,
            errors=errors)
except ImportError:
    # Python 2
    from urllib import unquote as parse_unquote

    def unquote_qs(atom, encoding, errors='strict'):
        return parse_unquote(atom.replace('+', ' ')).decode(encoding, errors)

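unquote_qs() layers '+'-for-space handling on top of percent-decoding on both major versions; a small sketch (the sample atom is illustrative):

from cherrypy._cpcompat import unquote_qs

# '+' is translated to a space before the %XX escapes are decoded.
print(unquote_qs('caf%C3%A9+au+lait', 'utf-8'))  # café au lait
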
try:
    # Prefer simplejson, which is usually more advanced than the builtin
    # module.
    import simplejson as json
    json_decode = json.JSONDecoder().decode
    _json_encode = json.JSONEncoder().iterencode
except ImportError:
    if sys.version_info >= (2, 6):
        # Python >=2.6 : json is part of the standard library
        import json
        json_decode = json.JSONDecoder().decode
        _json_encode = json.JSONEncoder().iterencode
    else:
        json = None

        def json_decode(s):
            raise ValueError('No JSON library is available')

        def _json_encode(s):
            raise ValueError('No JSON library is available')
finally:
    if json and py3k:
        # The two Python 3 implementations (simplejson/json)
        # output str. We need bytes.
        def json_encode(value):
            for chunk in _json_encode(value):
                yield chunk.encode('utf8')
    else:
        json_encode = _json_encode


try:
    import cPickle as pickle
except ImportError:
    # In Python 2, pickle is the pure-Python fallback for cPickle.
    # In Python 3, pickle is the sped-up C version.
    import pickle

try:
    os.urandom(20)
    import binascii

    def random20():
        return binascii.hexlify(os.urandom(20)).decode('ascii')
except (AttributeError, NotImplementedError):
    import random
    # os.urandom not available until Python 2.4. Fall back to random.random.

    def random20():
        return sha('%s' % random.random()).hexdigest()

try:
    from _thread import get_ident as get_thread_ident
except ImportError:
    from thread import get_ident as get_thread_ident

try:
    # Python 3
    next = next
except NameError:
    # Python 2
    def next(i):
        return i.next()

if sys.version_info >= (3, 3):
    Timer = threading.Timer
    Event = threading.Event
else:
    # Python 3.2 and earlier
    Timer = threading._Timer
    Event = threading._Event

# Prior to Python 2.6, the Thread class did not have a .daemon property.
# This mix-in adds that property.


class SetDaemonProperty:

    def __get_daemon(self):
        return self.isDaemon()

    def __set_daemon(self, daemon):
        self.setDaemon(daemon)

    if sys.version_info < (2, 6):
        daemon = property(__get_daemon, __set_daemon)
1544
lib/cherrypy/_cpcompat_subprocess.py
Normal file
File diff suppressed because it is too large
317
lib/cherrypy/_cpconfig.py
Normal file
@@ -0,0 +1,317 @@
"""
Configuration system for CherryPy.

Configuration in CherryPy is implemented via dictionaries. Keys are strings
which name the mapped value, which may be of any type.


Architecture
------------

CherryPy Requests are part of an Application, which runs in a global context,
and configuration data may apply to any of those three scopes:

Global
    Configuration entries which apply everywhere are stored in
    cherrypy.config.

Application
    Entries which apply to each mounted application are stored
    on the Application object itself, as 'app.config'. This is a two-level
    dict where each key is a path, or "relative URL" (for example, "/" or
    "/path/to/my/page"), and each value is a config dict. Usually, this
    data is provided in the call to tree.mount(root(), config=conf),
    although you may also use app.merge(conf).

Request
    Each Request object possesses a single 'Request.config' dict.
    Early in the request process, this dict is populated by merging global
    config entries, Application entries (whose path equals or is a parent
    of Request.path_info), and any config acquired while looking up the
    page handler (see next).


Declaration
-----------

Configuration data may be supplied as a Python dictionary, as a filename,
or as an open file object. When you supply a filename or file, CherryPy
uses Python's builtin ConfigParser; you declare Application config by
writing each path as a section header::

    [/path/to/my/page]
    request.stream = True

To declare global configuration entries, place them in a [global] section.

You may also declare config entries directly on the classes and methods
(page handlers) that make up your CherryPy application via the ``_cp_config``
attribute. For example::

    class Demo:
        _cp_config = {'tools.gzip.on': True}

        def index(self):
            return "Hello world"
        index.exposed = True
        index._cp_config = {'request.show_tracebacks': False}

.. note::

    This behavior is only guaranteed for the default dispatcher.
    Other dispatchers may have different restrictions on where
    you can attach _cp_config attributes.


Namespaces
----------

Configuration keys are separated into namespaces by the first "." in the key.
Current namespaces:

engine
    Controls the 'application engine', including autoreload.
    These can only be declared in the global config.

tree
    Grafts cherrypy.Application objects onto cherrypy.tree.
    These can only be declared in the global config.

hooks
    Declares additional request-processing functions.

log
    Configures the logging for each application.
    These can only be declared in the global or / config.

request
    Adds attributes to each Request.

response
    Adds attributes to each Response.

server
    Controls the default HTTP server via cherrypy.server.
    These can only be declared in the global config.

tools
    Runs and configures additional request-processing packages.

wsgi
    Adds WSGI middleware to an Application's "pipeline".
    These can only be declared in the app's root config ("/").

checker
    Controls the 'checker', which looks for common errors in
    app state (including config) when the engine starts.
    Global config only.

The only key that does not exist in a namespace is the "environment" entry.
This special entry 'imports' other config entries from a template stored in
cherrypy._cpconfig.environments[environment]. It only applies to the global
config, and only when you use cherrypy.config.update.

You can define your own namespaces to be called at the Global, Application,
or Request level, by adding a named handler to cherrypy.config.namespaces,
app.namespaces, or app.request_class.namespaces. The name can
be any string, and the handler must be either a callable or a (Python 2.5
style) context manager.
"""

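Pulling the docstring's pieces together, a minimal sketch of the two declaration scopes; the Demo class and the port number are illustrative, not part of this changeset:

import cherrypy


class Demo(object):

    @cherrypy.expose
    def index(self):
        return "Hello world"

# Application config: paths as section keys, passed to tree.mount().
app_conf = {'/': {'request.stream': True}}
cherrypy.tree.mount(Demo(), '/', config=app_conf)

# Global config: equivalent to a [global] section in a config file.
cherrypy.config.update({'server.socket_port': 8080})
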
import cherrypy
from cherrypy._cpcompat import set, basestring
from cherrypy.lib import reprconf

# Deprecated in CherryPy 3.2--remove in 3.3
NamespaceSet = reprconf.NamespaceSet


def merge(base, other):
    """Merge one app config (from a dict, file, or filename) into another.

    If the given config is a filename, it will be appended to
    the list of files to monitor for "autoreload" changes.
    """
    if isinstance(other, basestring):
        cherrypy.engine.autoreload.files.add(other)

    # Load other into base
    for section, value_map in reprconf.as_dict(other).items():
        if not isinstance(value_map, dict):
            raise ValueError(
                "Application config must include section headers, but the "
                "config you tried to merge doesn't have any sections. "
                "Wrap your config in another dict with paths as section "
                "headers, for example: {'/': config}.")
        base.setdefault(section, {}).update(value_map)


class Config(reprconf.Config):

    """The 'global' configuration data for the entire CherryPy process."""

    def update(self, config):
        """Update self from a dict, file or filename."""
        if isinstance(config, basestring):
            # Filename
            cherrypy.engine.autoreload.files.add(config)
        reprconf.Config.update(self, config)

    def _apply(self, config):
        """Update self from a dict."""
        if isinstance(config.get("global"), dict):
            if len(config) > 1:
                cherrypy.checker.global_config_contained_paths = True
            config = config["global"]
        if 'tools.staticdir.dir' in config:
            config['tools.staticdir.section'] = "global"
        reprconf.Config._apply(self, config)

    def __call__(self, *args, **kwargs):
        """Decorator for page handlers to set _cp_config."""
        if args:
            raise TypeError(
                "The cherrypy.config decorator does not accept positional "
                "arguments; you must use keyword arguments.")

        def tool_decorator(f):
            if not hasattr(f, "_cp_config"):
                f._cp_config = {}
            for k, v in kwargs.items():
                f._cp_config[k] = v
            return f
        return tool_decorator

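The __call__ decorator above is reachable as cherrypy.config; a sketch of its keyword-only usage (the Demo handler is illustrative):

import cherrypy


class Demo(object):

    # Equivalent to setting index._cp_config = {'tools.gzip.on': True}.
    @cherrypy.config(**{'tools.gzip.on': True})
    @cherrypy.expose
    def index(self):
        return "Hello world"
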
# Sphinx begin config.environments
Config.environments = environments = {
    "staging": {
        'engine.autoreload.on': False,
        'checker.on': False,
        'tools.log_headers.on': False,
        'request.show_tracebacks': False,
        'request.show_mismatched_params': False,
    },
    "production": {
        'engine.autoreload.on': False,
        'checker.on': False,
        'tools.log_headers.on': False,
        'request.show_tracebacks': False,
        'request.show_mismatched_params': False,
        'log.screen': False,
    },
    "embedded": {
        # For use with CherryPy embedded in another deployment stack.
        'engine.autoreload.on': False,
        'checker.on': False,
        'tools.log_headers.on': False,
        'request.show_tracebacks': False,
        'request.show_mismatched_params': False,
        'log.screen': False,
        'engine.SIGHUP': None,
        'engine.SIGTERM': None,
    },
    "test_suite": {
        'engine.autoreload.on': False,
        'checker.on': False,
        'tools.log_headers.on': False,
        'request.show_tracebacks': True,
        'request.show_mismatched_params': True,
        'log.screen': False,
    },
}
# Sphinx end config.environments

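Selecting one of these templates happens through the special "environment" key, which (per the module docstring above) only takes effect via cherrypy.config.update:

import cherrypy

# Copies the "production" entries above into the live global config,
# e.g. turning off autoreload, the checker, and screen logging.
cherrypy.config.update({'environment': 'production'})
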
def _server_namespace_handler(k, v):
    """Config handler for the "server" namespace."""
    atoms = k.split(".", 1)
    if len(atoms) > 1:
        # Special-case config keys of the form 'server.servername.socket_port'
        # to configure additional HTTP servers.
        if not hasattr(cherrypy, "servers"):
            cherrypy.servers = {}

        servername, k = atoms
        if servername not in cherrypy.servers:
            from cherrypy import _cpserver
            cherrypy.servers[servername] = _cpserver.Server()
            # On by default, but 'on = False' can unsubscribe it (see below).
            cherrypy.servers[servername].subscribe()

        if k == 'on':
            if v:
                cherrypy.servers[servername].subscribe()
            else:
                cherrypy.servers[servername].unsubscribe()
        else:
            setattr(cherrypy.servers[servername], k, v)
    else:
        setattr(cherrypy.server, k, v)
Config.namespaces["server"] = _server_namespace_handler

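A sketch of the 'server.servername.*' special case handled above; the server name '2' and the host/port values are illustrative:

import cherrypy

# Each key creates/configures an extra _cpserver.Server registered in
# cherrypy.servers['2']; a 'server.2.on': False entry would unsubscribe it.
cherrypy.config.update({
    'server.2.socket_host': '0.0.0.0',
    'server.2.socket_port': 8443,
})
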
def _engine_namespace_handler(k, v):
    """Backward compatibility handler for the "engine" namespace."""
    engine = cherrypy.engine

    deprecated = {
        'autoreload_on': 'autoreload.on',
        'autoreload_frequency': 'autoreload.frequency',
        'autoreload_match': 'autoreload.match',
        'reload_files': 'autoreload.files',
        'deadlock_poll_freq': 'timeout_monitor.frequency'
    }

    if k in deprecated:
        engine.log(
            'WARNING: Use of engine.%s is deprecated and will be removed in a '
            'future version. Use engine.%s instead.' % (k, deprecated[k]))

    if k == 'autoreload_on':
        if v:
            engine.autoreload.subscribe()
        else:
            engine.autoreload.unsubscribe()
    elif k == 'autoreload_frequency':
        engine.autoreload.frequency = v
    elif k == 'autoreload_match':
        engine.autoreload.match = v
    elif k == 'reload_files':
        engine.autoreload.files = set(v)
    elif k == 'deadlock_poll_freq':
        engine.timeout_monitor.frequency = v
    elif k == 'SIGHUP':
        engine.listeners['SIGHUP'] = set([v])
    elif k == 'SIGTERM':
        engine.listeners['SIGTERM'] = set([v])
    elif "." in k:
        plugin, attrname = k.split(".", 1)
        plugin = getattr(engine, plugin)
        if attrname == 'on':
            if v and hasattr(getattr(plugin, 'subscribe', None), '__call__'):
                plugin.subscribe()
                return
            elif (
                (not v) and
                hasattr(getattr(plugin, 'unsubscribe', None), '__call__')
            ):
                plugin.unsubscribe()
                return
        setattr(plugin, attrname, v)
    else:
        setattr(engine, k, v)
Config.namespaces["engine"] = _engine_namespace_handler


def _tree_namespace_handler(k, v):
    """Namespace handler for the 'tree' config namespace."""
    if isinstance(v, dict):
        for script_name, app in v.items():
            cherrypy.tree.graft(app, script_name)
            cherrypy.engine.log("Mounted: %s on %s" %
                                (app, script_name or "/"))
    else:
        cherrypy.tree.graft(v, v.script_name)
        cherrypy.engine.log("Mounted: %s on %s" % (v, v.script_name or "/"))
Config.namespaces["tree"] = _tree_namespace_handler
686
lib/cherrypy/_cpdispatch.py
Normal file
@@ -0,0 +1,686 @@
"""CherryPy dispatchers.

A 'dispatcher' is the object which looks up the 'page handler' callable
and collects config for the current request based on the path_info, other
request attributes, and the application architecture. The core calls the
dispatcher as early as possible, passing it a 'path_info' argument.

The default dispatcher discovers the page handler by matching path_info
to a hierarchical arrangement of objects, starting at request.app.root.
"""

import string
import sys
import types
try:
    classtype = (type, types.ClassType)
except AttributeError:
    classtype = type

import cherrypy
from cherrypy._cpcompat import set


class PageHandler(object):

    """Callable which sets response.body."""

    def __init__(self, callable, *args, **kwargs):
        self.callable = callable
        self.args = args
        self.kwargs = kwargs

    def get_args(self):
        return cherrypy.serving.request.args

    def set_args(self, args):
        cherrypy.serving.request.args = args
        return cherrypy.serving.request.args

    args = property(
        get_args,
        set_args,
        doc="The ordered args should be accessible from post dispatch hooks"
    )

    def get_kwargs(self):
        return cherrypy.serving.request.kwargs

    def set_kwargs(self, kwargs):
        cherrypy.serving.request.kwargs = kwargs
        return cherrypy.serving.request.kwargs

    kwargs = property(
        get_kwargs,
        set_kwargs,
        doc="The named kwargs should be accessible from post dispatch hooks"
    )

    def __call__(self):
        try:
            return self.callable(*self.args, **self.kwargs)
        except TypeError:
            x = sys.exc_info()[1]
            try:
                test_callable_spec(self.callable, self.args, self.kwargs)
            except cherrypy.HTTPError:
                raise sys.exc_info()[1]
            except:
                raise x
            raise


def test_callable_spec(callable, callable_args, callable_kwargs):
    """
    Inspect callable and test to see if the given args are suitable for it.

    When an error occurs during the handler's invocation there are 2
    erroneous cases:
    1. Too many parameters passed to a function which doesn't define
       one of *args or **kwargs.
    2. Too few parameters are passed to the function.

    There are 3 sources of parameters to a cherrypy handler.
    1. query string parameters are passed as keyword parameters to the
       handler.
    2. body parameters are also passed as keyword parameters.
    3. when partial matching occurs, the final path atoms are passed as
       positional args.
    Both the query string and path atoms are part of the URI. If they are
    incorrect, then a 404 Not Found should be raised. Conversely the body
    parameters are part of the request; if they are invalid a 400 Bad Request.
    """
    show_mismatched_params = getattr(
        cherrypy.serving.request, 'show_mismatched_params', False)
    try:
        (args, varargs, varkw, defaults) = getargspec(callable)
    except TypeError:
        if isinstance(callable, object) and hasattr(callable, '__call__'):
            (args, varargs, varkw,
             defaults) = getargspec(callable.__call__)
        else:
            # If it wasn't one of our own types, re-raise
            # the original error
            raise

    if args and args[0] == 'self':
        args = args[1:]

    arg_usage = dict([(arg, 0,) for arg in args])
    vararg_usage = 0
    varkw_usage = 0
    extra_kwargs = set()

    for i, value in enumerate(callable_args):
        try:
            arg_usage[args[i]] += 1
        except IndexError:
            vararg_usage += 1

    for key in callable_kwargs.keys():
        try:
            arg_usage[key] += 1
        except KeyError:
            varkw_usage += 1
            extra_kwargs.add(key)

    # figure out which args have defaults.
    args_with_defaults = args[-len(defaults or []):]
    for i, val in enumerate(defaults or []):
        # Defaults take effect only when the arg hasn't been used yet.
        if arg_usage[args_with_defaults[i]] == 0:
            arg_usage[args_with_defaults[i]] += 1

    missing_args = []
    multiple_args = []
    for key, usage in arg_usage.items():
        if usage == 0:
            missing_args.append(key)
        elif usage > 1:
            multiple_args.append(key)

    if missing_args:
        # In the case where the method allows body arguments
        # there are 3 potential errors:
        # 1. not enough query string parameters -> 404
        # 2. not enough body parameters -> 400
        # 3. not enough path parts (partial matches) -> 404
        #
        # We can't actually tell which case it is,
        # so I'm raising a 404 because that covers 2/3 of the
        # possibilities
        #
        # In the case where the method does not allow body
        # arguments it's definitely a 404.
        message = None
        if show_mismatched_params:
            message = "Missing parameters: %s" % ",".join(missing_args)
        raise cherrypy.HTTPError(404, message=message)

    # the extra positional arguments come from the path - 404 Not Found
    if not varargs and vararg_usage > 0:
        raise cherrypy.HTTPError(404)

    body_params = cherrypy.serving.request.body.params or {}
    body_params = set(body_params.keys())
    qs_params = set(callable_kwargs.keys()) - body_params

    if multiple_args:
        if qs_params.intersection(set(multiple_args)):
            # If any of the multiple parameters came from the query string then
            # it's a 404 Not Found
            error = 404
        else:
            # Otherwise it's a 400 Bad Request
            error = 400

        message = None
        if show_mismatched_params:
            message = "Multiple values for parameters: "\
                      "%s" % ",".join(multiple_args)
        raise cherrypy.HTTPError(error, message=message)

    if not varkw and varkw_usage > 0:

        # If there were extra query string parameters, it's a 404 Not Found
        extra_qs_params = set(qs_params).intersection(extra_kwargs)
        if extra_qs_params:
            message = None
            if show_mismatched_params:
                message = "Unexpected query string "\
                          "parameters: %s" % ", ".join(extra_qs_params)
            raise cherrypy.HTTPError(404, message=message)

        # If there were any extra body parameters, it's a 400 Bad Request
        extra_body_params = set(body_params).intersection(extra_kwargs)
        if extra_body_params:
            message = None
            if show_mismatched_params:
                message = "Unexpected body parameters: "\
                          "%s" % ", ".join(extra_body_params)
            raise cherrypy.HTTPError(400, message=message)

try:
    import inspect
except ImportError:
    test_callable_spec = lambda callable, args, kwargs: None
else:
    getargspec = inspect.getargspec
    # Python 3 requires using getfullargspec if keyword-only arguments
    # are present
    if hasattr(inspect, 'getfullargspec'):
        def getargspec(callable):
            return inspect.getfullargspec(callable)[:4]

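The shim above keeps the classic 4-tuple shape on Python 3 by trimming getfullargspec's richer result; a runnable sketch (the handler signature is illustrative):

import inspect


def handler(self, year, month=None):
    pass

args, varargs, varkw, defaults = inspect.getfullargspec(handler)[:4]
print(args, varargs, varkw, defaults)
# ['self', 'year', 'month'] None None (None,)
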
class LateParamPageHandler(PageHandler):

    """When passing cherrypy.request.params to the page handler, we do not
    want to capture that dict too early; we want to give tools like the
    decoding tool a chance to modify the params dict in-between the lookup
    of the handler and the actual calling of the handler. This subclass
    takes that into account, and allows request.params to be 'bound late'
    (it's more complicated than that, but that's the effect).
    """

    def _get_kwargs(self):
        kwargs = cherrypy.serving.request.params.copy()
        if self._kwargs:
            kwargs.update(self._kwargs)
        return kwargs

    def _set_kwargs(self, kwargs):
        cherrypy.serving.request.kwargs = kwargs
        self._kwargs = kwargs

    kwargs = property(_get_kwargs, _set_kwargs,
                      doc='page handler kwargs (with '
                      'cherrypy.request.params copied in)')


if sys.version_info < (3, 0):
    punctuation_to_underscores = string.maketrans(
        string.punctuation, '_' * len(string.punctuation))

    def validate_translator(t):
        if not isinstance(t, str) or len(t) != 256:
            raise ValueError(
                "The translate argument must be a str of len 256.")
else:
    punctuation_to_underscores = str.maketrans(
        string.punctuation, '_' * len(string.punctuation))

    def validate_translator(t):
        if not isinstance(t, dict):
            raise ValueError("The translate argument must be a dict.")

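The translation table maps every punctuation character to '_' so that path segments become legal Python identifiers; a Python 3 sketch:

import string

punctuation_to_underscores = str.maketrans(
    string.punctuation, '_' * len(string.punctuation))

# The default dispatcher applies this table to each path segment, so a
# request for /foo.html is looked up as the attribute foo_html.
print('foo.html'.translate(punctuation_to_underscores))  # foo_html
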
class Dispatcher(object):

    """CherryPy Dispatcher which walks a tree of objects to find a handler.

    The tree is rooted at cherrypy.request.app.root, and each hierarchical
    component in the path_info argument is matched to a corresponding nested
    attribute of the root object. Matching handlers must have an 'exposed'
    attribute which evaluates to True. The special method name "index"
    matches a URI which ends in a slash ("/"). The special method name
    "default" may match a portion of the path_info (but only when no longer
    substring of the path_info matches some other object).

    This is the default, built-in dispatcher for CherryPy.
    """

    dispatch_method_name = '_cp_dispatch'
    """
    The name of the dispatch method that nodes may optionally implement
    to provide their own dynamic dispatch algorithm.
    """

    def __init__(self, dispatch_method_name=None,
                 translate=punctuation_to_underscores):
        validate_translator(translate)
        self.translate = translate
        if dispatch_method_name:
            self.dispatch_method_name = dispatch_method_name

    def __call__(self, path_info):
        """Set handler and config for the current request."""
        request = cherrypy.serving.request
        func, vpath = self.find_handler(path_info)

        if func:
            # Decode any leftover %2F in the virtual_path atoms.
            vpath = [x.replace("%2F", "/") for x in vpath]
            request.handler = LateParamPageHandler(func, *vpath)
        else:
            request.handler = cherrypy.NotFound()

    def find_handler(self, path):
        """Return the appropriate page handler, plus any virtual path.

        This will return two objects. The first will be a callable,
        which can be used to generate page output. Any parameters from
        the query string or request body will be sent to that callable
        as keyword arguments.

        The callable is found by traversing the application's tree,
        starting from cherrypy.request.app.root, and matching path
        components to successive objects in the tree. For example, the
        URL "/path/to/handler" might return root.path.to.handler.

        The second object returned will be a list of names which are
        'virtual path' components: parts of the URL which are dynamic,
        and were not used when looking up the handler.
        These virtual path components are passed to the handler as
        positional arguments.
        """
        request = cherrypy.serving.request
        app = request.app
        root = app.root
        dispatch_name = self.dispatch_method_name

        # Get config for the root object/path.
        fullpath = [x for x in path.strip('/').split('/') if x] + ['index']
        fullpath_len = len(fullpath)
        segleft = fullpath_len
        nodeconf = {}
        if hasattr(root, "_cp_config"):
            nodeconf.update(root._cp_config)
        if "/" in app.config:
            nodeconf.update(app.config["/"])
        object_trail = [['root', root, nodeconf, segleft]]

        node = root
        iternames = fullpath[:]
        while iternames:
            name = iternames[0]
            # map to legal Python identifiers (e.g. replace '.' with '_')
            objname = name.translate(self.translate)

            nodeconf = {}
            subnode = getattr(node, objname, None)
            pre_len = len(iternames)
            if subnode is None:
                dispatch = getattr(node, dispatch_name, None)
                if dispatch and hasattr(dispatch, '__call__') and not \
                        getattr(dispatch, 'exposed', False) and \
                        pre_len > 1:
                    # Don't expose the hidden 'index' token to _cp_dispatch
                    # We skip this if pre_len == 1 since it makes no sense
                    # to call a dispatcher when we have no tokens left.
                    index_name = iternames.pop()
                    subnode = dispatch(vpath=iternames)
                    iternames.append(index_name)
                else:
                    # We didn't find a path, but keep processing in case there
                    # is a default() handler.
                    iternames.pop(0)
            else:
                # We found the path, remove the vpath entry
                iternames.pop(0)
            segleft = len(iternames)
            if segleft > pre_len:
                # No path segment was removed. Raise an error.
                raise cherrypy.CherryPyException(
                    "A vpath segment was added. Custom dispatchers may only "
                    + "remove elements. While trying to process "
                    + "{0} in {1}".format(name, fullpath)
                )
            elif segleft == pre_len:
                # Assume that the handler used the current path segment, but
                # did not pop it. This allows things like
                # return getattr(self, vpath[0], None)
                iternames.pop(0)
                segleft -= 1
            node = subnode

            if node is not None:
                # Get _cp_config attached to this node.
                if hasattr(node, "_cp_config"):
                    nodeconf.update(node._cp_config)

            # Mix in values from app.config for this path.
            existing_len = fullpath_len - pre_len
            if existing_len != 0:
                curpath = '/' + '/'.join(fullpath[0:existing_len])
            else:
                curpath = ''
            new_segs = fullpath[fullpath_len - pre_len:fullpath_len - segleft]
            for seg in new_segs:
                curpath += '/' + seg
                if curpath in app.config:
                    nodeconf.update(app.config[curpath])

            object_trail.append([name, node, nodeconf, segleft])

        def set_conf():
            """Collapse all object_trail config into cherrypy.request.config.
            """
            base = cherrypy.config.copy()
            # Note that we merge the config from each node
            # even if that node was None.
            for name, obj, conf, segleft in object_trail:
                base.update(conf)
                if 'tools.staticdir.dir' in conf:
                    base['tools.staticdir.section'] = '/' + \
                        '/'.join(fullpath[0:fullpath_len - segleft])
            return base

        # Try successive objects (reverse order)
        num_candidates = len(object_trail) - 1
        for i in range(num_candidates, -1, -1):

            name, candidate, nodeconf, segleft = object_trail[i]
            if candidate is None:
                continue

            # Try a "default" method on the current leaf.
            if hasattr(candidate, "default"):
                defhandler = candidate.default
                if getattr(defhandler, 'exposed', False):
                    # Insert any extra _cp_config from the default handler.
                    conf = getattr(defhandler, "_cp_config", {})
                    object_trail.insert(
                        i + 1, ["default", defhandler, conf, segleft])
                    request.config = set_conf()
                    # See https://bitbucket.org/cherrypy/cherrypy/issue/613
                    request.is_index = path.endswith("/")
                    return defhandler, fullpath[fullpath_len - segleft:-1]

            # Uncomment the next line to restrict positional params to
            # "default".
            # if i < num_candidates - 2: continue

            # Try the current leaf.
            if getattr(candidate, 'exposed', False):
                request.config = set_conf()
                if i == num_candidates:
                    # We found the extra ".index". Mark request so tools
                    # can redirect if path_info has no trailing slash.
                    request.is_index = True
                else:
                    # We're not at an 'index' handler. Mark request so tools
                    # can redirect if path_info has NO trailing slash.
                    # Note that this also includes handlers which take
                    # positional parameters (virtual paths).
                    request.is_index = False
                return candidate, fullpath[fullpath_len - segleft:-1]

        # We didn't find anything
        request.config = set_conf()
        return None, []

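
# --- Illustrative example (not part of the original module) ---
# A minimal sketch of the tree walk performed by Dispatcher.find_handler,
# assuming an app mounted at "/"; all names below are hypothetical:
#
#     class Root(object):
#         @cherrypy.expose
#         def index(self):            # matches "/"
#             return "home"
#
#         @cherrypy.expose
#         def default(self, *vpath):  # catches everything else
#             return "caught %r" % (vpath,)
#
# For path_info "/users/42", fullpath becomes ['users', '42', 'index'];
# since Root has no 'users' attribute, root.default is returned and the
# unused segments ['users', '42'] become the virtual path, passed to the
# handler as positional arguments.
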
class MethodDispatcher(Dispatcher):

    """Additional dispatch based on cherrypy.request.method.upper().

    Methods named GET, POST, etc will be called on an exposed class.
    The method names must be all caps; the appropriate Allow header
    will be output showing all capitalized method names as allowable
    HTTP verbs.

    Note that the containing class must be exposed, not the methods.
    """

    def __call__(self, path_info):
        """Set handler and config for the current request."""
        request = cherrypy.serving.request
        resource, vpath = self.find_handler(path_info)

        if resource:
            # Set Allow header
            avail = [m for m in dir(resource) if m.isupper()]
            if "GET" in avail and "HEAD" not in avail:
                avail.append("HEAD")
            avail.sort()
            cherrypy.serving.response.headers['Allow'] = ", ".join(avail)

            # Find the subhandler
            meth = request.method.upper()
            func = getattr(resource, meth, None)
            if func is None and meth == "HEAD":
                func = getattr(resource, "GET", None)
            if func:
                # Grab any _cp_config on the subhandler.
                if hasattr(func, "_cp_config"):
                    request.config.update(func._cp_config)

                # Decode any leftover %2F in the virtual_path atoms.
                vpath = [x.replace("%2F", "/") for x in vpath]
                request.handler = LateParamPageHandler(func, *vpath)
            else:
                request.handler = cherrypy.HTTPError(405)
        else:
            request.handler = cherrypy.NotFound()

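
# --- Illustrative example (not part of the original module) ---
# A minimal REST-style resource for MethodDispatcher; names hypothetical:
#
#     class Item(object):
#         exposed = True              # the class is exposed, not the methods
#
#         def GET(self, id):
#             return "item %s" % id
#
#         def PUT(self, id, name):
#             return "renamed %s -> %s" % (id, name)
#
#     conf = {'/': {'request.dispatch': cherrypy.dispatch.MethodDispatcher()}}
#     cherrypy.tree.mount(Item(), '/item', conf)
#
# A PUT to /item/42 then calls Item.PUT with '42' as the virtual path;
# unsupported verbs get a 405 with an Allow header of "GET, HEAD, PUT".
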
class RoutesDispatcher(object):

    """A Routes based dispatcher for CherryPy."""

    def __init__(self, full_result=False, **mapper_options):
        """
        Routes dispatcher

        Set full_result to True if you wish the controller
        and the action to be passed on to the page handler
        parameters. By default they won't be.
        """
        import routes
        self.full_result = full_result
        self.controllers = {}
        self.mapper = routes.Mapper(**mapper_options)
        self.mapper.controller_scan = self.controllers.keys

    def connect(self, name, route, controller, **kwargs):
        self.controllers[name] = controller
        self.mapper.connect(name, route, controller=name, **kwargs)

    def redirect(self, url):
        raise cherrypy.HTTPRedirect(url)

    def __call__(self, path_info):
        """Set handler and config for the current request."""
        func = self.find_handler(path_info)
        if func:
            cherrypy.serving.request.handler = LateParamPageHandler(func)
        else:
            cherrypy.serving.request.handler = cherrypy.NotFound()

    def find_handler(self, path_info):
        """Find the right page handler, and set request.config."""
        import routes

        request = cherrypy.serving.request

        config = routes.request_config()
        config.mapper = self.mapper
        if hasattr(request, 'wsgi_environ'):
            config.environ = request.wsgi_environ
        config.host = request.headers.get('Host', None)
        config.protocol = request.scheme
        config.redirect = self.redirect

        result = self.mapper.match(path_info)

        config.mapper_dict = result
        params = {}
        if result:
            params = result.copy()
        if not self.full_result:
            params.pop('controller', None)
            params.pop('action', None)
        request.params.update(params)

        # Get config for the root object/path.
        request.config = base = cherrypy.config.copy()
        curpath = ""

        def merge(nodeconf):
            if 'tools.staticdir.dir' in nodeconf:
                nodeconf['tools.staticdir.section'] = curpath or "/"
            base.update(nodeconf)

        app = request.app
        root = app.root
        if hasattr(root, "_cp_config"):
            merge(root._cp_config)
        if "/" in app.config:
            merge(app.config["/"])

        # Mix in values from app.config.
        atoms = [x for x in path_info.split("/") if x]
        if atoms:
            last = atoms.pop()
        else:
            last = None
        for atom in atoms:
            curpath = "/".join((curpath, atom))
            if curpath in app.config:
                merge(app.config[curpath])

        handler = None
        if result:
            controller = result.get('controller')
            controller = self.controllers.get(controller, controller)
            if controller:
                if isinstance(controller, classtype):
                    controller = controller()
                # Get config from the controller.
                if hasattr(controller, "_cp_config"):
                    merge(controller._cp_config)

            action = result.get('action')
            if action is not None:
                handler = getattr(controller, action, None)
                # Get config from the handler
                if hasattr(handler, "_cp_config"):
                    merge(handler._cp_config)
            else:
                handler = controller

        # Do the last path atom here so it can
        # override the controller's _cp_config.
        if last:
            curpath = "/".join((curpath, last))
            if curpath in app.config:
                merge(app.config[curpath])

        return handler

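
# --- Illustrative example (not part of the original module) ---
# A minimal sketch of wiring a route (requires the third-party `routes`
# package); the controller, route, and names are hypothetical:
#
#     class City(object):
#         def show(self, name):
#             return "Welcome to %s" % name
#
#     d = cherrypy.dispatch.RoutesDispatcher()
#     d.connect('city', '/city/{name}', controller=City(), action='show')
#     conf = {'/': {'request.dispatch': d}}
#     cherrypy.tree.mount(root=None, config=conf)
#
# The matched {name} atom is merged into request.params and so reaches
# City.show as a keyword argument.
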
def XMLRPCDispatcher(next_dispatcher=Dispatcher()):
    from cherrypy.lib import xmlrpcutil

    def xmlrpc_dispatch(path_info):
        path_info = xmlrpcutil.patched_path(path_info)
        return next_dispatcher(path_info)
    return xmlrpc_dispatch


def VirtualHost(next_dispatcher=Dispatcher(), use_x_forwarded_host=True,
                **domains):
    """
    Select a different handler based on the Host header.

    This can be useful when running multiple sites within one CP server.
    It allows several domains to point to different parts of a single
    website structure. For example::

        http://www.domain.example      ->  root
        http://www.domain2.example     ->  root/domain2/
        http://www.domain2.example:443 ->  root/secure

    can be accomplished via the following config::

        [/]
        request.dispatch = cherrypy.dispatch.VirtualHost(
            **{'www.domain2.example': '/domain2',
               'www.domain2.example:443': '/secure',
               })

    next_dispatcher
        The next dispatcher object in the dispatch chain.
        The VirtualHost dispatcher adds a prefix to the URL and calls
        another dispatcher. Defaults to cherrypy.dispatch.Dispatcher().

    use_x_forwarded_host
        If True (the default), any "X-Forwarded-Host"
        request header will be used instead of the "Host" header. This
        is commonly added by HTTP servers (such as Apache) when proxying.

    ``**domains``
        A dict of {host header value: virtual prefix} pairs.
        The incoming "Host" request header is looked up in this dict,
        and, if a match is found, the corresponding "virtual prefix"
        value will be prepended to the URL path before calling the
        next dispatcher. Note that you often need separate entries
        for "example.com" and "www.example.com". In addition, "Host"
        headers may contain the port number.
    """
    from cherrypy.lib import httputil

    def vhost_dispatch(path_info):
        request = cherrypy.serving.request
        header = request.headers.get

        domain = header('Host', '')
        if use_x_forwarded_host:
            domain = header("X-Forwarded-Host", domain)

        prefix = domains.get(domain, "")
        if prefix:
            path_info = httputil.urljoin(prefix, path_info)

        result = next_dispatcher(path_info)

        # Touch up staticdir config. See
        # https://bitbucket.org/cherrypy/cherrypy/issue/614.
        section = request.config.get('tools.staticdir.section')
        if section:
            section = section[len(prefix):]
            request.config['tools.staticdir.section'] = section

        return result
    return vhost_dispatch
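
# --- Illustrative example (not part of the original module) ---
# The same VirtualHost setup as in the docstring above, expressed as a
# Python config dict rather than an INI file (hostnames are placeholders):
#
#     conf = {'/': {'request.dispatch': cherrypy.dispatch.VirtualHost(
#         **{'www.domain2.example': '/domain2',
#            'www.domain2.example:443': '/secure'})}}
#     cherrypy.quickstart(Root(), '/', conf)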
609
lib/cherrypy/_cperror.py
Normal file
@@ -0,0 +1,609 @@
"""Exception classes for CherryPy.

CherryPy provides (and uses) exceptions for declaring that the HTTP response
should be a status other than the default "200 OK". You can ``raise`` them like
normal Python exceptions. You can also call them and they will raise
themselves; this means you can set an
:class:`HTTPError<cherrypy._cperror.HTTPError>`
or :class:`HTTPRedirect<cherrypy._cperror.HTTPRedirect>` as the
:attr:`request.handler<cherrypy._cprequest.Request.handler>`.

.. _redirectingpost:

Redirecting POST
================

When you GET a resource and are redirected by the server to another Location,
there's generally no problem since GET is both a "safe method" (there should
be no side-effects) and an "idempotent method" (multiple calls are no different
than a single call).

POST, however, is neither safe nor idempotent--if you
charge a credit card, you don't want to be charged twice by a redirect!

For this reason, *none* of the 3xx responses permit a user-agent (browser) to
resubmit a POST on redirection without first confirming the action with the
user:

=====    =================================    ===========
300      Multiple Choices                     Confirm with the user
301      Moved Permanently                    Confirm with the user
302      Found (Object moved temporarily)     Confirm with the user
303      See Other                            GET the new URI--no confirmation
304      Not modified                         (for conditional GET only--POST should not raise this error)
305      Use Proxy                            Confirm with the user
307      Temporary Redirect                   Confirm with the user
=====    =================================    ===========

However, browsers have historically implemented these restrictions poorly;
in particular, many browsers do not force the user to confirm 301, 302
or 307 when redirecting POST. For this reason, CherryPy defaults to 303,
which most user-agents appear to have implemented correctly. Therefore, if
you raise HTTPRedirect for a POST request, the user-agent will most likely
attempt to GET the new URI (without asking for confirmation from the user).
We realize this is confusing for developers, but it's the safest thing we
could do. You are of course free to raise ``HTTPRedirect(uri, status=302)``
or any other 3xx status if you know what you're doing, but given the
environment, we couldn't let any of those be the default.

Custom Error Handling
=====================

.. image:: /refman/cperrors.gif

Anticipated HTTP responses
--------------------------

The 'error_page' config namespace can be used to provide custom HTML output for
expected responses (like 404 Not Found). Supply a filename from which the
output will be read. The contents will be interpolated with the values
%(status)s, %(message)s, %(traceback)s, and %(version)s using plain old Python
`string formatting <http://docs.python.org/2/library/stdtypes.html#string-formatting-operations>`_.

::

    _cp_config = {
        'error_page.404': os.path.join(localDir, "static/index.html")
    }


Beginning in version 3.1, you may also provide a function or other callable as
an error_page entry. It will be passed the same status, message, traceback and
version arguments that are interpolated into templates::

    def error_page_402(status, message, traceback, version):
        return "Error %s - Well, I'm very sorry but you haven't paid!" % status
    cherrypy.config.update({'error_page.402': error_page_402})

Also in 3.1, in addition to the numbered error codes, you may also supply
"error_page.default" to handle all codes which do not have their own error_page
entry.
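
For example, a minimal catch-all entry might look like this (an
illustrative sketch, not part of the original module)::

    def error_page_default(status, message, traceback, version):
        return "Sorry, %s: %s" % (status, message)
    cherrypy.config.update({'error_page.default': error_page_default})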


Unanticipated errors
--------------------

CherryPy also has a generic error handling mechanism: whenever an unanticipated
error occurs in your code, it will call
:func:`Request.error_response<cherrypy._cprequest.Request.error_response>` to
set the response status, headers, and body. By default, this is the same
output as
:class:`HTTPError(500) <cherrypy._cperror.HTTPError>`. If you want to provide
some other behavior, you generally replace "request.error_response".

Here is some sample code that shows how to display a custom error message and
send an e-mail containing the error::

    from cherrypy import _cperror

    def handle_error():
        cherrypy.response.status = 500
        cherrypy.response.body = [
            "<html><body>Sorry, an error occurred</body></html>"
        ]
        sendMail('error@domain.com',
                 'Error in your web app',
                 _cperror.format_exc())

    class Root:
        _cp_config = {'request.error_response': handle_error}


Note that you have to explicitly set
:attr:`response.body <cherrypy._cprequest.Response.body>`
and not simply return an error message as a result.
"""

from cgi import escape as _escape
from sys import exc_info as _exc_info
from traceback import format_exception as _format_exception
from cherrypy._cpcompat import basestring, bytestr, iteritems, ntob
from cherrypy._cpcompat import tonative, urljoin as _urljoin
from cherrypy.lib import httputil as _httputil


class CherryPyException(Exception):

    """A base class for CherryPy exceptions."""
    pass


class TimeoutError(CherryPyException):

    """Exception raised when Response.timed_out is detected."""
    pass


class InternalRedirect(CherryPyException):

    """Exception raised to switch to the handler for a different URL.

    This exception will redirect processing to another path within the site
    (without informing the client). Provide the new path as an argument when
    raising the exception. Provide any params in the querystring for the new
    URL.
    """

    def __init__(self, path, query_string=""):
        import cherrypy
        self.request = cherrypy.serving.request

        self.query_string = query_string
        if "?" in path:
            # Separate any params included in the path
            path, self.query_string = path.split("?", 1)

        # Note that urljoin will "do the right thing" whether url is:
        # 1. a URL relative to root (e.g. "/dummy")
        # 2. a URL relative to the current path
        # Note that any query string will be discarded.
        path = _urljoin(self.request.path_info, path)

        # Set a 'path' member attribute so that code which traps this
        # error can have access to it.
        self.path = path

        CherryPyException.__init__(self, path, self.query_string)


class HTTPRedirect(CherryPyException):

    """Exception raised when the request should be redirected.

    This exception will force an HTTP redirect to the URL or URLs you give
    it. The new URL must be passed as the first argument to the Exception,
    e.g., HTTPRedirect(newUrl). Multiple URLs are allowed in a list.
    If a URL is absolute, it will be used as-is. If it is relative, it is
    assumed to be relative to the current cherrypy.request.path_info.

    If one of the provided URLs is a unicode object, it will be encoded
    using the default encoding or the one passed in parameter.

    There are multiple types of redirect, from which you can select via the
    ``status`` argument. If you do not provide a ``status`` arg, it defaults
    to 303 (or 302 if responding with HTTP/1.0).

    Examples::

        raise cherrypy.HTTPRedirect("")
        raise cherrypy.HTTPRedirect("/abs/path", 307)
        raise cherrypy.HTTPRedirect(["path1", "path2?a=1&b=2"], 301)

    See :ref:`redirectingpost` for additional caveats.
    """

    status = None
    """The integer HTTP status code to emit."""

    urls = None
    """The list of URLs to emit."""

    encoding = 'utf-8'
    """The encoding when passed urls are not native strings"""

    def __init__(self, urls, status=None, encoding=None):
        import cherrypy
        request = cherrypy.serving.request

        if isinstance(urls, basestring):
            urls = [urls]

        abs_urls = []
        for url in urls:
            url = tonative(url, encoding or self.encoding)

            # Note that urljoin will "do the right thing" whether url is:
            # 1. a complete URL with host (e.g. "http://www.example.com/test")
            # 2. a URL relative to root (e.g. "/dummy")
            # 3. a URL relative to the current path
            # Note that any query string in cherrypy.request is discarded.
            url = _urljoin(cherrypy.url(), url)
            abs_urls.append(url)
        self.urls = abs_urls

        # RFC 2616 indicates a 301 response code fits our goal; however,
        # browser support for 301 is quite messy. Do 302/303 instead. See
        # http://www.alanflavell.org.uk/www/post-redirect.html
        if status is None:
            if request.protocol >= (1, 1):
                status = 303
            else:
                status = 302
        else:
            status = int(status)
            if status < 300 or status > 399:
                raise ValueError("status must be between 300 and 399.")

        self.status = status
        CherryPyException.__init__(self, abs_urls, status)

    def set_response(self):
        """Modify cherrypy.response status, headers, and body to represent
        self.

        CherryPy uses this internally, but you can also use it to create an
        HTTPRedirect object and set its output without *raising* the
        exception.
        """
        import cherrypy
        response = cherrypy.serving.response
        response.status = status = self.status

        if status in (300, 301, 302, 303, 307):
            response.headers['Content-Type'] = "text/html;charset=utf-8"
            # "The ... URI SHOULD be given by the Location field
            # in the response."
            response.headers['Location'] = self.urls[0]

            # "Unless the request method was HEAD, the entity of the response
            # SHOULD contain a short hypertext note with a hyperlink to the
            # new URI(s)."
            msg = {
                300: "This resource can be found at ",
                301: "This resource has permanently moved to ",
                302: "This resource resides temporarily at ",
                303: "This resource can be found at ",
                307: "This resource has moved temporarily to ",
            }[status]
            msg += '<a href=%s>%s</a>.'
            from xml.sax import saxutils
            msgs = [msg % (saxutils.quoteattr(u), u) for u in self.urls]
            response.body = ntob("<br />\n".join(msgs), 'utf-8')
            # Previous code may have set C-L, so we have to reset it
            # (allow finalize to set it).
            response.headers.pop('Content-Length', None)
        elif status == 304:
            # Not Modified.
            # "The response MUST include the following header fields:
            # Date, unless its omission is required by section 14.18.1"
            # The "Date" header should have been set in Response.__init__

            # "...the response SHOULD NOT include other entity-headers."
            for key in ('Allow', 'Content-Encoding', 'Content-Language',
                        'Content-Length', 'Content-Location', 'Content-MD5',
                        'Content-Range', 'Content-Type', 'Expires',
                        'Last-Modified'):
                if key in response.headers:
                    del response.headers[key]

            # "The 304 response MUST NOT contain a message-body."
            response.body = None
            # Previous code may have set C-L, so we have to reset it.
            response.headers.pop('Content-Length', None)
        elif status == 305:
            # Use Proxy.
            # self.urls[0] should be the URI of the proxy.
            response.headers['Location'] = self.urls[0]
            response.body = None
            # Previous code may have set C-L, so we have to reset it.
            response.headers.pop('Content-Length', None)
        else:
            raise ValueError("The %s status code is unknown." % status)

    def __call__(self):
        """Use this exception as a request.handler (raise self)."""
        raise self

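
# --- Illustrative example (not part of the original module) ---
# As the docstring above notes, set_response() can be used without raising.
# A hypothetical tool callback might prepare a redirect response in place:
#
#     redir = cherrypy.HTTPRedirect("/login", status=307)
#     redir.set_response()   # fills cherrypy.serving.response in place
#     # ...finalization then sends the prepared 307 response.
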
def clean_headers(status):
    """Remove any headers which should not apply to an error response."""
    import cherrypy

    response = cherrypy.serving.response

    # Remove headers which applied to the original content,
    # but do not apply to the error page.
    respheaders = response.headers
    for key in ["Accept-Ranges", "Age", "ETag", "Location", "Retry-After",
                "Vary", "Content-Encoding", "Content-Length", "Expires",
                "Content-Location", "Content-MD5", "Last-Modified"]:
        if key in respheaders:
            del respheaders[key]

    if status != 416:
        # A server sending a response with status code 416 (Requested
        # range not satisfiable) SHOULD include a Content-Range field
        # with a byte-range-resp-spec of "*". The instance-length
        # specifies the current length of the selected resource.
        # A response with status code 206 (Partial Content) MUST NOT
        # include a Content-Range field with a byte-range- resp-spec of "*".
        if "Content-Range" in respheaders:
            del respheaders["Content-Range"]


class HTTPError(CherryPyException):

    """Exception used to return an HTTP error code (4xx-5xx) to the client.

    This exception can be used to automatically send a response using an
    HTTP status code, with an appropriate error page. It takes an optional
    ``status`` argument (which must be between 400 and 599); it defaults to 500
    ("Internal Server Error"). It also takes an optional ``message`` argument,
    which will be returned in the response body. See
    `RFC2616 <http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.4>`_
    for a complete list of available error codes and when to use them.

    Examples::

        raise cherrypy.HTTPError(403)
        raise cherrypy.HTTPError(
            "403 Forbidden", "You are not allowed to access this resource.")
    """

    status = None
    """The HTTP status code. May be of type int or str (with a Reason-Phrase).
    """

    code = None
    """The integer HTTP status code."""

    reason = None
    """The HTTP Reason-Phrase string."""

    def __init__(self, status=500, message=None):
        self.status = status
        try:
            self.code, self.reason, defaultmsg = _httputil.valid_status(status)
        except ValueError:
            raise self.__class__(500, _exc_info()[1].args[0])

        if self.code < 400 or self.code > 599:
            raise ValueError("status must be between 400 and 599.")

        # See http://www.python.org/dev/peps/pep-0352/
        # self.message = message
        self._message = message or defaultmsg
        CherryPyException.__init__(self, status, message)

    def set_response(self):
        """Modify cherrypy.response status, headers, and body to represent
        self.

        CherryPy uses this internally, but you can also use it to create an
        HTTPError object and set its output without *raising* the exception.
        """
        import cherrypy

        response = cherrypy.serving.response

        clean_headers(self.code)

        # In all cases, finalize will be called after this method,
        # so don't bother cleaning up response values here.
        response.status = self.status
        tb = None
        if cherrypy.serving.request.show_tracebacks:
            tb = format_exc()

        response.headers.pop('Content-Length', None)

        content = self.get_error_page(self.status, traceback=tb,
                                      message=self._message)
        response.body = content

        _be_ie_unfriendly(self.code)

    def get_error_page(self, *args, **kwargs):
        return get_error_page(*args, **kwargs)

    def __call__(self):
        """Use this exception as a request.handler (raise self)."""
        raise self


class NotFound(HTTPError):

    """Exception raised when a URL could not be mapped to any handler (404).

    This is equivalent to raising
    :class:`HTTPError("404 Not Found") <cherrypy._cperror.HTTPError>`.
    """

    def __init__(self, path=None):
        if path is None:
            import cherrypy
            request = cherrypy.serving.request
            path = request.script_name + request.path_info
        self.args = (path,)
        HTTPError.__init__(self, 404, "The path '%s' was not found." % path)


_HTTPErrorTemplate = '''<!DOCTYPE html PUBLIC
"-//W3C//DTD XHTML 1.0 Transitional//EN"
"http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
<html>
<head>
    <meta http-equiv="Content-Type" content="text/html; charset=utf-8"></meta>
    <title>%(status)s</title>
    <style type="text/css">
    #powered_by {
        margin-top: 20px;
        border-top: 2px solid black;
        font-style: italic;
    }

    #traceback {
        color: red;
    }
    </style>
</head>
    <body>
        <h2>%(status)s</h2>
        <p>%(message)s</p>
        <pre id="traceback">%(traceback)s</pre>
    <div id="powered_by">
      <span>
        Powered by <a href="http://www.cherrypy.org">CherryPy %(version)s</a>
      </span>
    </div>
    </body>
</html>
'''


def get_error_page(status, **kwargs):
    """Return an HTML page, containing a pretty error response.

    status should be an int or a str.
    kwargs will be interpolated into the page template.
    """
    import cherrypy

    try:
        code, reason, message = _httputil.valid_status(status)
    except ValueError:
        raise cherrypy.HTTPError(500, _exc_info()[1].args[0])

    # We can't use setdefault here, because some
    # callers send None for kwarg values.
    if kwargs.get('status') is None:
        kwargs['status'] = "%s %s" % (code, reason)
    if kwargs.get('message') is None:
        kwargs['message'] = message
    if kwargs.get('traceback') is None:
        kwargs['traceback'] = ''
    if kwargs.get('version') is None:
        kwargs['version'] = cherrypy.__version__

    for k, v in iteritems(kwargs):
        if v is None:
            kwargs[k] = ""
        else:
            kwargs[k] = _escape(kwargs[k])

    # Use a custom template or callable for the error page?
    pages = cherrypy.serving.request.error_page
    error_page = pages.get(code) or pages.get('default')

    # Default template, can be overridden below.
    template = _HTTPErrorTemplate
    if error_page:
        try:
            if hasattr(error_page, '__call__'):
                # The caller function may be setting headers manually,
                # so we delegate to it completely. We may be returning
                # an iterator as well as a string here.
                #
                # We *must* make sure any content is not unicode.
                result = error_page(**kwargs)
                if cherrypy.lib.is_iterator(result):
                    from cherrypy.lib.encoding import UTF8StreamEncoder
                    return UTF8StreamEncoder(result)
                elif isinstance(result, cherrypy._cpcompat.unicodestr):
                    return result.encode('utf-8')
                else:
                    if not isinstance(result, cherrypy._cpcompat.bytestr):
                        raise ValueError('error page function did not '
                                         'return a bytestring, unicodestring or an '
                                         'iterator - returned object of type %s.'
                                         % (type(result).__name__))
                    return result
            else:
                # Load the template from this path.
                template = tonative(open(error_page, 'rb').read())
        except:
            e = _format_exception(*_exc_info())[-1]
            m = kwargs['message']
            if m:
                m += "<br />"
            m += "In addition, the custom error page failed:\n<br />%s" % e
            kwargs['message'] = m

    response = cherrypy.serving.response
    response.headers['Content-Type'] = "text/html;charset=utf-8"
    result = template % kwargs
    return result.encode('utf-8')


_ie_friendly_error_sizes = {
    400: 512, 403: 256, 404: 512, 405: 256,
    406: 512, 408: 512, 409: 512, 410: 256,
    500: 512, 501: 512, 505: 512,
}


def _be_ie_unfriendly(status):
    import cherrypy
    response = cherrypy.serving.response

    # For some statuses, Internet Explorer 5+ shows "friendly error
    # messages" instead of our response.body if the body is smaller
    # than a given size. Fix this by returning a body over that size
    # (by adding whitespace).
    # See http://support.microsoft.com/kb/q218155/
    s = _ie_friendly_error_sizes.get(status, 0)
    if s:
        s += 1
        # Since we are issuing an HTTP error status, we assume that
        # the entity is short, and we should just collapse it.
        content = response.collapse_body()
        l = len(content)
        if l and l < s:
            # IN ADDITION: the response must be written to IE
            # in one chunk or it will still get replaced! Bah.
            content = content + (ntob(" ") * (s - l))
        response.body = content
        response.headers['Content-Length'] = str(len(content))


def format_exc(exc=None):
    """Return exc (or sys.exc_info if None), formatted."""
    try:
        if exc is None:
            exc = _exc_info()
        if exc == (None, None, None):
            return ""
        import traceback
        return "".join(traceback.format_exception(*exc))
    finally:
        del exc


def bare_error(extrabody=None):
    """Produce status, headers, body for a critical error.

    Returns a triple without calling any other questionable functions,
    so it should be as error-free as possible. Call it from an HTTP server
    if you get errors outside of the request.

    If extrabody is None, a friendly but rather unhelpful error message
    is set in the body. If extrabody is a string, it will be appended
    as-is to the body.
    """

    # The whole point of this function is to be a last line-of-defense
    # in handling errors. That is, it must not raise any errors itself;
    # it cannot be allowed to fail. Therefore, don't add to it!
    # In particular, don't call any other CP functions.

    body = ntob("Unrecoverable error in the server.")
    if extrabody is not None:
        if not isinstance(extrabody, bytestr):
            extrabody = extrabody.encode('utf-8')
        body += ntob("\n") + extrabody

    return (ntob("500 Internal Server Error"),
            [(ntob('Content-Type'), ntob('text/plain')),
             (ntob('Content-Length'), ntob(str(len(body)), 'ISO-8859-1'))],
            [body])
459
lib/cherrypy/_cplogging.py
Normal file
@@ -0,0 +1,459 @@
"""
Simple config
=============

Although CherryPy uses the :mod:`Python logging module <logging>`, it does so
behind the scenes so that simple logging is simple, but complicated logging
is still possible. "Simple" logging means that you can log to the screen
(i.e. console/stdout) or to a file, and that you can easily have separate
error and access log files.

Here are the simplified logging settings. You use these by adding lines to
your config file or dict. You should set these at either the global level or
per application (see next), but generally not both.

* ``log.screen``: Set this to True to have both "error" and "access" messages
  printed to stdout.
* ``log.access_file``: Set this to an absolute filename where you want
  "access" messages written.
* ``log.error_file``: Set this to an absolute filename where you want "error"
  messages written.

Many events are automatically logged; to log your own application events, call
:func:`cherrypy.log`.
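
For instance, a global config dict enabling all three settings might look
like this (an illustrative sketch, not part of the original module; the
file paths are placeholders)::

    cherrypy.config.update({
        'log.screen': True,
        'log.access_file': '/var/log/myapp/access.log',
        'log.error_file': '/var/log/myapp/error.log',
    })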

Architecture
============

Separate scopes
---------------

CherryPy provides log managers at both the global and application layers.
This means you can have one set of logging rules for your entire site,
and another set of rules specific to each application. The global log
manager is found at :func:`cherrypy.log`, and the log manager for each
application is found at :attr:`app.log<cherrypy._cptree.Application.log>`.
If you're inside a request, the latter is reachable from
``cherrypy.request.app.log``; if you're outside a request, you'll have to
obtain a reference to the ``app``: either the return value of
:func:`tree.mount()<cherrypy._cptree.Tree.mount>` or, if you used
:func:`quickstart()<cherrypy.quickstart>` instead, via
``cherrypy.tree.apps['/']``.

By default, the global logs are named "cherrypy.error" and "cherrypy.access",
and the application logs are named "cherrypy.error.2378745" and
"cherrypy.access.2378745" (the number is the id of the Application object).
This means that the application logs "bubble up" to the site logs, so if your
application has no log handlers, the site-level handlers will still log the
messages.

Errors vs. Access
-----------------

Each log manager handles both "access" messages (one per HTTP request) and
"error" messages (everything else). Note that the "error" log is not just for
errors! The format of access messages is highly formalized, but the error log
isn't--it receives messages from a variety of sources (including full error
tracebacks, if enabled).

If you are logging the access log and error log to the same source, then there
is a possibility that a specially crafted error message may replicate an access
log message as described in CWE-117. In this case it is the application
developer's responsibility to manually escape data before using CherryPy's
log() functionality, or they may create an application that is vulnerable to
CWE-117. This can be achieved by using a custom handler to escape any special
characters, attached as described below.

Custom Handlers
===============

The simple settings above work by manipulating Python's standard :mod:`logging`
module. So when you need something more complex, the full power of the standard
module is yours to exploit. You can borrow or create custom handlers, formats,
filters, and much more. Here's an example that skips the standard FileHandler
and uses a RotatingFileHandler instead:

::

    #python
    log = app.log

    # Remove the default FileHandlers if present.
    log.error_file = ""
    log.access_file = ""

    maxBytes = getattr(log, "rot_maxBytes", 10000000)
    backupCount = getattr(log, "rot_backupCount", 1000)

    # Make a new RotatingFileHandler for the error log.
    fname = getattr(log, "rot_error_file", "error.log")
    h = handlers.RotatingFileHandler(fname, 'a', maxBytes, backupCount)
    h.setLevel(DEBUG)
    h.setFormatter(_cplogging.logfmt)
    log.error_log.addHandler(h)

    # Make a new RotatingFileHandler for the access log.
    fname = getattr(log, "rot_access_file", "access.log")
    h = handlers.RotatingFileHandler(fname, 'a', maxBytes, backupCount)
    h.setLevel(DEBUG)
    h.setFormatter(_cplogging.logfmt)
    log.access_log.addHandler(h)


The ``rot_*`` attributes are pulled straight from the application log object.
Since "log.*" config entries simply set attributes on the log object, you can
add custom attributes to your heart's content. Note that these handlers are
used *instead* of the default, simple handlers outlined above (so don't set
the "log.error_file" config entry, for example).
"""

import datetime
import logging
# Silence the no-handlers "warning" (stderr write!) in stdlib logging
logging.Logger.manager.emittedNoHandlerWarning = 1
logfmt = logging.Formatter("%(message)s")
import os
import sys

import cherrypy
from cherrypy import _cperror
from cherrypy._cpcompat import ntob, py3k


class NullHandler(logging.Handler):

    """A no-op logging handler to silence the logging.lastResort handler."""

    def handle(self, record):
        pass

    def emit(self, record):
        pass

    def createLock(self):
        self.lock = None


class LogManager(object):

    """An object to assist both simple and advanced logging.

    ``cherrypy.log`` is an instance of this class.
    """

    appid = None
    """The id() of the Application object which owns this log manager. If this
    is a global log manager, appid is None."""

    error_log = None
    """The actual :class:`logging.Logger` instance for error messages."""

    access_log = None
    """The actual :class:`logging.Logger` instance for access messages."""

    if py3k:
        access_log_format = \
            '{h} {l} {u} {t} "{r}" {s} {b} "{f}" "{a}"'
    else:
        access_log_format = \
            '%(h)s %(l)s %(u)s %(t)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s"'

    logger_root = None
    """The "top-level" logger name.

    This string will be used as the first segment in the Logger names.
    The default is "cherrypy", for example, in which case the Logger names
    will be of the form::

        cherrypy.error.<appid>
        cherrypy.access.<appid>
    """

    def __init__(self, appid=None, logger_root="cherrypy"):
        self.logger_root = logger_root
        self.appid = appid
        if appid is None:
            self.error_log = logging.getLogger("%s.error" % logger_root)
            self.access_log = logging.getLogger("%s.access" % logger_root)
        else:
            self.error_log = logging.getLogger(
                "%s.error.%s" % (logger_root, appid))
            self.access_log = logging.getLogger(
                "%s.access.%s" % (logger_root, appid))
        self.error_log.setLevel(logging.INFO)
        self.access_log.setLevel(logging.INFO)

        # Silence the no-handlers "warning" (stderr write!) in stdlib logging
        self.error_log.addHandler(NullHandler())
        self.access_log.addHandler(NullHandler())

        cherrypy.engine.subscribe('graceful', self.reopen_files)

    def reopen_files(self):
        """Close and reopen all file handlers."""
        for log in (self.error_log, self.access_log):
            for h in log.handlers:
                if isinstance(h, logging.FileHandler):
                    h.acquire()
                    h.stream.close()
                    h.stream = open(h.baseFilename, h.mode)
                    h.release()

    def error(self, msg='', context='', severity=logging.INFO,
              traceback=False):
        """Write the given ``msg`` to the error log.

        This is not just for errors! Applications may call this at any time
        to log application-specific information.

        If ``traceback`` is True, the traceback of the current exception
        (if any) will be appended to ``msg``.
        """
        if traceback:
            msg += _cperror.format_exc()
        self.error_log.log(severity, ' '.join((self.time(), context, msg)))

    def __call__(self, *args, **kwargs):
        """An alias for ``error``."""
        return self.error(*args, **kwargs)
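
    # --- Illustrative example (not part of the original module) ---
    # Since __call__ aliases error(), application logging is just:
    #
    #     cherrypy.log("Job finished", context='WORKER')
    #     cherrypy.log("Job failed", severity=logging.ERROR, traceback=True)
    #
    # Both lines write to the "cherrypy.error" logger; the 'WORKER' context
    # label is a hypothetical tag prepended to the message.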

    def access(self):
        """Write to the access log (in Apache/NCSA Combined Log format).

        See the
        `apache documentation <http://httpd.apache.org/docs/current/logs.html#combined>`_
        for format details.

        CherryPy calls this automatically for you. Note there are no
        arguments; it collects the data itself from
        :class:`cherrypy.request<cherrypy._cprequest.Request>`.

        Like Apache started doing in 2.0.46, non-printable and other special
        characters in %r (and we expand that to all parts) are escaped using
        \\xhh sequences, where hh stands for the hexadecimal representation
        of the raw byte. Exceptions from this rule are " and \\, which are
        escaped by prepending a backslash, and all whitespace characters,
        which are written in their C-style notation (\\n, \\t, etc).
        """
        request = cherrypy.serving.request
        remote = request.remote
        response = cherrypy.serving.response
        outheaders = response.headers
        inheaders = request.headers
        if response.output_status is None:
            status = "-"
        else:
            status = response.output_status.split(ntob(" "), 1)[0]
            if py3k:
                status = status.decode('ISO-8859-1')

        atoms = {'h': remote.name or remote.ip,
                 'l': '-',
                 'u': getattr(request, "login", None) or "-",
                 't': self.time(),
                 'r': request.request_line,
                 's': status,
                 'b': dict.get(outheaders, 'Content-Length', '') or "-",
                 'f': dict.get(inheaders, 'Referer', ''),
                 'a': dict.get(inheaders, 'User-Agent', ''),
                 'o': dict.get(inheaders, 'Host', '-'),
                 }
        if py3k:
            for k, v in atoms.items():
                if not isinstance(v, str):
                    v = str(v)
                v = v.replace('"', '\\"').encode('utf8')
                # Fortunately, repr(str) escapes unprintable chars, \n, \t, etc
                # and backslash for us. All we have to do is strip the quotes.
                v = repr(v)[2:-1]

                # in python 3.0 the repr of bytes (as returned by encode)
                # uses double \'s. But then the logger escapes them yet again,
                # resulting in quadruple slashes. Remove the extra one here.
                v = v.replace('\\\\', '\\')

                # Escape double-quote.
                atoms[k] = v

            try:
                self.access_log.log(
                    logging.INFO, self.access_log_format.format(**atoms))
            except:
                self(traceback=True)
        else:
            for k, v in atoms.items():
                if isinstance(v, unicode):
                    v = v.encode('utf8')
                elif not isinstance(v, str):
                    v = str(v)
                # Fortunately, repr(str) escapes unprintable chars, \n, \t, etc
                # and backslash for us. All we have to do is strip the quotes.
                v = repr(v)[1:-1]
                # Escape double-quote.
                atoms[k] = v.replace('"', '\\"')

            try:
                self.access_log.log(
                    logging.INFO, self.access_log_format % atoms)
            except:
                self(traceback=True)

    def time(self):
        """Return now() in Apache Common Log Format (no timezone)."""
        now = datetime.datetime.now()
        monthnames = ['jan', 'feb', 'mar', 'apr', 'may', 'jun',
                      'jul', 'aug', 'sep', 'oct', 'nov', 'dec']
        month = monthnames[now.month - 1].capitalize()
        return ('[%02d/%s/%04d:%02d:%02d:%02d]' %
                (now.day, month, now.year, now.hour, now.minute, now.second))

    def _get_builtin_handler(self, log, key):
        for h in log.handlers:
            if getattr(h, "_cpbuiltin", None) == key:
                return h

    # ------------------------- Screen handlers ------------------------- #
    def _set_screen_handler(self, log, enable, stream=None):
        h = self._get_builtin_handler(log, "screen")
        if enable:
            if not h:
                if stream is None:
                    stream = sys.stderr
                h = logging.StreamHandler(stream)
                h.setFormatter(logfmt)
                h._cpbuiltin = "screen"
                log.addHandler(h)
        elif h:
            log.handlers.remove(h)

    def _get_screen(self):
        h = self._get_builtin_handler
        has_h = h(self.error_log, "screen") or h(self.access_log, "screen")
        return bool(has_h)

    def _set_screen(self, newvalue):
        self._set_screen_handler(self.error_log, newvalue, stream=sys.stderr)
        self._set_screen_handler(self.access_log, newvalue, stream=sys.stdout)
    screen = property(_get_screen, _set_screen,
                      doc="""Turn stderr/stdout logging on or off.

        If you set this to True, it'll add the appropriate StreamHandler for
        you. If you set it to False, it will remove the handler.
        """)

    # -------------------------- File handlers -------------------------- #

    def _add_builtin_file_handler(self, log, fname):
        h = logging.FileHandler(fname)
        h.setFormatter(logfmt)
        h._cpbuiltin = "file"
        log.addHandler(h)

    def _set_file_handler(self, log, filename):
        h = self._get_builtin_handler(log, "file")
        if filename:
            if h:
                if h.baseFilename != os.path.abspath(filename):
                    h.close()
                    log.handlers.remove(h)
                    self._add_builtin_file_handler(log, filename)
            else:
                self._add_builtin_file_handler(log, filename)
        else:
            if h:
                h.close()
                log.handlers.remove(h)

    def _get_error_file(self):
        h = self._get_builtin_handler(self.error_log, "file")
        if h:
            return h.baseFilename
        return ''

    def _set_error_file(self, newvalue):
        self._set_file_handler(self.error_log, newvalue)
    error_file = property(_get_error_file, _set_error_file,
                          doc="""The filename for self.error_log.

        If you set this to a string, it'll add the appropriate FileHandler for
        you. If you set it to ``None`` or ``''``, it will remove the handler.
        """)

    def _get_access_file(self):
        h = self._get_builtin_handler(self.access_log, "file")
        if h:
            return h.baseFilename
        return ''

    def _set_access_file(self, newvalue):
        self._set_file_handler(self.access_log, newvalue)
    access_file = property(_get_access_file, _set_access_file,
                           doc="""The filename for self.access_log.

        If you set this to a string, it'll add the appropriate FileHandler for
        you. If you set it to ``None`` or ``''``, it will remove the handler.
        """)

    # ------------------------- WSGI handlers ------------------------- #

    def _set_wsgi_handler(self, log, enable):
        h = self._get_builtin_handler(log, "wsgi")
        if enable:
            if not h:
                h = WSGIErrorHandler()
                h.setFormatter(logfmt)
                h._cpbuiltin = "wsgi"
                log.addHandler(h)
        elif h:
            log.handlers.remove(h)

    def _get_wsgi(self):
        return bool(self._get_builtin_handler(self.error_log, "wsgi"))

    def _set_wsgi(self, newvalue):
        self._set_wsgi_handler(self.error_log, newvalue)
    wsgi = property(_get_wsgi, _set_wsgi,
                    doc="""Write errors to wsgi.errors.

        If you set this to True, it'll add the appropriate
        :class:`WSGIErrorHandler<cherrypy._cplogging.WSGIErrorHandler>` for you
        (which writes errors to ``wsgi.errors``).
        If you set it to False, it will remove the handler.
        """)


class WSGIErrorHandler(logging.Handler):

    "A handler class which writes logging records to environ['wsgi.errors']."

    def flush(self):
        """Flushes the stream."""
        try:
            stream = cherrypy.serving.request.wsgi_environ.get('wsgi.errors')
        except (AttributeError, KeyError):
            pass
        else:
            stream.flush()

    def emit(self, record):
        """Emit a record."""
        try:
            stream = cherrypy.serving.request.wsgi_environ.get('wsgi.errors')
        except (AttributeError, KeyError):
            pass
        else:
            try:
                msg = self.format(record)
                fs = "%s\n"
                import types
                # if no unicode support...
                if not hasattr(types, "UnicodeType"):
                    stream.write(fs % msg)
                else:
                    try:
                        stream.write(fs % msg)
                    except UnicodeError:
                        stream.write(fs % msg.encode("UTF-8"))
                self.flush()
            except:
                self.handleError(record)
353
lib/cherrypy/_cpmodpy.py
Normal file
@@ -0,0 +1,353 @@
"""Native adapter for serving CherryPy via mod_python

Basic usage:

##########################################
# Application in a module called myapp.py
##########################################

import cherrypy

class Root:
    @cherrypy.expose
    def index(self):
        return 'Hi there, Ho there, Hey there'


# We will use this method from the mod_python configuration
# as the entry point to our application
def setup_server():
    cherrypy.tree.mount(Root())
    cherrypy.config.update({'environment': 'production',
                            'log.screen': False,
                            'show_tracebacks': False})

##########################################
# mod_python settings for apache2
# This should reside in your httpd.conf
# or a file that will be loaded at
# apache startup
##########################################

# Start
DocumentRoot "/"
Listen 8080
LoadModule python_module /usr/lib/apache2/modules/mod_python.so

<Location "/">
    PythonPath "sys.path+['/path/to/my/application']"
    SetHandler python-program
    PythonHandler cherrypy._cpmodpy::handler
    PythonOption cherrypy.setup myapp::setup_server
    PythonDebug On
</Location>
# End

The actual path to your mod_python.so is dependent on your
environment. In this case we suppose a global mod_python
installation on a Linux distribution such as Ubuntu.

We do set the PythonPath configuration setting so that
your application can be found by the user running
the apache2 instance. Of course if your application
resides in the global site-package this won't be needed.

Then restart apache2 and access http://127.0.0.1:8080
"""

import logging
import sys

import cherrypy
from cherrypy._cpcompat import BytesIO, copyitems, ntob
from cherrypy._cperror import format_exc, bare_error
from cherrypy.lib import httputil


# ------------------------------ Request-handling


def setup(req):
    from mod_python import apache

    # Run any setup functions defined by a "PythonOption cherrypy.setup"
    # directive.
    options = req.get_options()
    if 'cherrypy.setup' in options:
        for function in options['cherrypy.setup'].split():
            atoms = function.split('::', 1)
            if len(atoms) == 1:
                mod = __import__(atoms[0], globals(), locals())
            else:
                modname, fname = atoms
                mod = __import__(modname, globals(), locals(), [fname])
                func = getattr(mod, fname)
                func()

    cherrypy.config.update({'log.screen': False,
                            "tools.ignore_headers.on": True,
                            "tools.ignore_headers.headers": ['Range'],
                            })

    engine = cherrypy.engine
    if hasattr(engine, "signal_handler"):
        engine.signal_handler.unsubscribe()
    if hasattr(engine, "console_control_handler"):
        engine.console_control_handler.unsubscribe()
    engine.autoreload.unsubscribe()
    cherrypy.server.unsubscribe()

    def _log(msg, level):
        newlevel = apache.APLOG_ERR
        if logging.DEBUG >= level:
            newlevel = apache.APLOG_DEBUG
        elif logging.INFO >= level:
            newlevel = apache.APLOG_INFO
        elif logging.WARNING >= level:
            newlevel = apache.APLOG_WARNING
        # On Windows, req.server is required or the msg will vanish. See
        # http://www.modpython.org/pipermail/mod_python/2003-October/014291.html
        # Also, "When server is not specified...LogLevel does not apply..."
        apache.log_error(msg, newlevel, req.server)
    engine.subscribe('log', _log)

    engine.start()

    def cherrypy_cleanup(data):
        engine.exit()
    try:
        # apache.register_cleanup wasn't available until 3.1.4.
        apache.register_cleanup(cherrypy_cleanup)
    except AttributeError:
        req.server.register_cleanup(req, cherrypy_cleanup)


class _ReadOnlyRequest:
    expose = ('read', 'readline', 'readlines')

    def __init__(self, req):
        for method in self.expose:
            self.__dict__[method] = getattr(req, method)


recursive = False

_isSetUp = False


def handler(req):
    from mod_python import apache
    try:
        global _isSetUp
        if not _isSetUp:
            setup(req)
            _isSetUp = True

        # Obtain a Request object from CherryPy
        local = req.connection.local_addr
        local = httputil.Host(
            local[0], local[1], req.connection.local_host or "")
        remote = req.connection.remote_addr
        remote = httputil.Host(
            remote[0], remote[1], req.connection.remote_host or "")

        scheme = req.parsed_uri[0] or 'http'
        req.get_basic_auth_pw()

        try:
            # apache.mpm_query only became available in mod_python 3.1
            q = apache.mpm_query
            threaded = q(apache.AP_MPMQ_IS_THREADED)
            forked = q(apache.AP_MPMQ_IS_FORKED)
        except AttributeError:
            bad_value = ("You must provide a PythonOption '%s', "
                         "either 'on' or 'off', when running a version "
                         "of mod_python < 3.1")

            # mod_python < 3.1: fall back to explicit PythonOption directives.
            options = req.get_options()
            threaded = options.get('multithread', '').lower()
            if threaded == 'on':
                threaded = True
            elif threaded == 'off':
                threaded = False
            else:
                raise ValueError(bad_value % "multithread")

            forked = options.get('multiprocess', '').lower()
            if forked == 'on':
                forked = True
            elif forked == 'off':
                forked = False
            else:
                raise ValueError(bad_value % "multiprocess")

        sn = cherrypy.tree.script_name(req.uri or "/")
        if sn is None:
            send_response(req, '404 Not Found', [], '')
        else:
            app = cherrypy.tree.apps[sn]
            method = req.method
            path = req.uri
            qs = req.args or ""
            reqproto = req.protocol
            headers = copyitems(req.headers_in)
            rfile = _ReadOnlyRequest(req)
            prev = None

            try:
                redirections = []
                while True:
                    request, response = app.get_serving(local, remote, scheme,
                                                        "HTTP/1.1")
                    request.login = req.user
                    request.multithread = bool(threaded)
                    request.multiprocess = bool(forked)
                    request.app = app
                    request.prev = prev

                    # Run the CherryPy Request object and obtain the response
                    try:
                        request.run(method, path, qs, reqproto, headers, rfile)
                        break
                    except cherrypy.InternalRedirect:
                        ir = sys.exc_info()[1]
                        app.release_serving()
                        prev = request

                        if not recursive:
                            if ir.path in redirections:
                                raise RuntimeError(
                                    "InternalRedirector visited the same URL "
                                    "twice: %r" % ir.path)
                            else:
                                # Add the *previous* path_info + qs to
                                # redirections.
                                if qs:
                                    qs = "?" + qs
                                redirections.append(sn + path + qs)

                        # Munge environment and try again.
                        method = "GET"
                        path = ir.path
                        qs = ir.query_string
                        rfile = BytesIO()

                send_response(
                    req, response.output_status, response.header_list,
                    response.body, response.stream)
            finally:
                app.release_serving()
    except:
        tb = format_exc()
        cherrypy.log(tb, 'MOD_PYTHON', severity=logging.ERROR)
        s, h, b = bare_error()
        send_response(req, s, h, b)
    return apache.OK


def send_response(req, status, headers, body, stream=False):
    # Set response status
    req.status = int(status[:3])

    # Set response headers
    req.content_type = "text/plain"
    for header, value in headers:
        if header.lower() == 'content-type':
            req.content_type = value
            continue
        req.headers_out.add(header, value)

    if stream:
        # Flush now so the status and headers are sent immediately.
        req.flush()

    # Set response body
    if isinstance(body, basestring):
        req.write(body)
    else:
        for seg in body:
            req.write(seg)


# --------------- Startup tools for CherryPy + mod_python --------------- #
import os
import re
try:
    import subprocess

    def popen(fullcmd):
        p = subprocess.Popen(fullcmd, shell=True,
                             stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                             close_fds=True)
        return p.stdout
except ImportError:
    def popen(fullcmd):
        pipein, pipeout = os.popen4(fullcmd)
        return pipeout


def read_process(cmd, args=""):
    fullcmd = "%s %s" % (cmd, args)
    pipeout = popen(fullcmd)
    try:
        firstline = pipeout.readline()
        cmd_not_found = re.search(
            ntob("(not recognized|No such file|not found)"),
            firstline,
            re.IGNORECASE
        )
        if cmd_not_found:
            raise IOError('%s must be on your system path.' % cmd)
        output = firstline + pipeout.read()
    finally:
        pipeout.close()
    return output


class ModPythonServer(object):

    template = """
# Apache2 server configuration file for running CherryPy with mod_python.

DocumentRoot "/"
Listen %(port)s
LoadModule python_module modules/mod_python.so

<Location %(loc)s>
    SetHandler python-program
    PythonHandler %(handler)s
    PythonDebug On
%(opts)s
</Location>
"""

    def __init__(self, loc="/", port=80, opts=None, apache_path="apache",
                 handler="cherrypy._cpmodpy::handler"):
        self.loc = loc
        self.port = port
        self.opts = opts
        self.apache_path = apache_path
        self.handler = handler

    def start(self):
        opts = "".join([" PythonOption %s %s\n" % (k, v)
                        for k, v in self.opts])
        conf_data = self.template % {"port": self.port,
                                     "loc": self.loc,
                                     "opts": opts,
                                     "handler": self.handler,
                                     }

        mpconf = os.path.join(os.path.dirname(__file__), "cpmodpy.conf")
        f = open(mpconf, 'wb')
        try:
            f.write(conf_data)
        finally:
            f.close()

        response = read_process(self.apache_path, "-k start -f %s" % mpconf)
        self.ready = True
        return response

    def stop(self):
        os.popen("apache -k stop")
        self.ready = False
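
# Usage sketch (illustrative, not part of the vendored source): driving the
# ModPythonServer helper above; the port and PythonOption pairs are made up.
#
#     from cherrypy._cpmodpy import ModPythonServer
#     s = ModPythonServer(loc="/", port=8081,
#                         opts=[("cherrypy.setup", "myapp::setup_server")])
#     s.start()   # writes cpmodpy.conf next to this module, runs apache -k start
#     s.stop()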
154
lib/cherrypy/_cpnative_server.py
Normal file
@@ -0,0 +1,154 @@
"""Native adapter for serving CherryPy via its builtin server."""

import logging
import sys

import cherrypy
from cherrypy._cpcompat import BytesIO
from cherrypy._cperror import format_exc, bare_error
from cherrypy.lib import httputil
from cherrypy import wsgiserver


class NativeGateway(wsgiserver.Gateway):

    recursive = False

    def respond(self):
        req = self.req
        try:
            # Obtain a Request object from CherryPy
            local = req.server.bind_addr
            local = httputil.Host(local[0], local[1], "")
            remote = req.conn.remote_addr, req.conn.remote_port
            remote = httputil.Host(remote[0], remote[1], "")

            scheme = req.scheme
            sn = cherrypy.tree.script_name(req.uri or "/")
            if sn is None:
                self.send_response('404 Not Found', [], [''])
            else:
                app = cherrypy.tree.apps[sn]
                method = req.method
                path = req.path
                qs = req.qs or ""
                headers = req.inheaders.items()
                rfile = req.rfile
                prev = None

                try:
                    redirections = []
                    while True:
                        request, response = app.get_serving(
                            local, remote, scheme, "HTTP/1.1")
                        request.multithread = True
                        request.multiprocess = False
                        request.app = app
                        request.prev = prev

                        # Run the CherryPy Request object and obtain the
                        # response
                        try:
                            request.run(method, path, qs,
                                        req.request_protocol, headers, rfile)
                            break
                        except cherrypy.InternalRedirect:
                            ir = sys.exc_info()[1]
                            app.release_serving()
                            prev = request

                            if not self.recursive:
                                if ir.path in redirections:
                                    raise RuntimeError(
                                        "InternalRedirector visited the same "
                                        "URL twice: %r" % ir.path)
                                else:
                                    # Add the *previous* path_info + qs to
                                    # redirections.
                                    if qs:
                                        qs = "?" + qs
                                    redirections.append(sn + path + qs)

                            # Munge environment and try again.
                            method = "GET"
                            path = ir.path
                            qs = ir.query_string
                            rfile = BytesIO()

                    self.send_response(
                        response.output_status, response.header_list,
                        response.body)
                finally:
                    app.release_serving()
        except:
            tb = format_exc()
            # print tb
            cherrypy.log(tb, 'NATIVE_ADAPTER', severity=logging.ERROR)
            s, h, b = bare_error()
            self.send_response(s, h, b)

    def send_response(self, status, headers, body):
        req = self.req

        # Set response status
        req.status = str(status or "500 Server Error")

        # Set response headers
        for header, value in headers:
            req.outheaders.append((header, value))
        if (req.ready and not req.sent_headers):
            req.sent_headers = True
            req.send_headers()

        # Set response body
        for seg in body:
            req.write(seg)


class CPHTTPServer(wsgiserver.HTTPServer):

    """Wrapper for wsgiserver.HTTPServer.

    wsgiserver has been designed to not reference CherryPy in any way,
    so that it can be used in other frameworks and applications.
    Therefore, we wrap it here, so we can apply some attributes
    from config -> cherrypy.server -> HTTPServer.
    """

    def __init__(self, server_adapter=cherrypy.server):
        self.server_adapter = server_adapter

        server_name = (self.server_adapter.socket_host or
                       self.server_adapter.socket_file or
                       None)

        wsgiserver.HTTPServer.__init__(
            self, server_adapter.bind_addr, NativeGateway,
            minthreads=server_adapter.thread_pool,
            maxthreads=server_adapter.thread_pool_max,
            server_name=server_name)

        self.max_request_header_size = (
            self.server_adapter.max_request_header_size or 0)
        self.max_request_body_size = (
            self.server_adapter.max_request_body_size or 0)
        self.request_queue_size = self.server_adapter.socket_queue_size
        self.timeout = self.server_adapter.socket_timeout
        self.shutdown_timeout = self.server_adapter.shutdown_timeout
        self.protocol = self.server_adapter.protocol_version
        self.nodelay = self.server_adapter.nodelay

        ssl_module = self.server_adapter.ssl_module or 'pyopenssl'
        if self.server_adapter.ssl_context:
            adapter_class = wsgiserver.get_ssl_adapter_class(ssl_module)
            self.ssl_adapter = adapter_class(
                self.server_adapter.ssl_certificate,
                self.server_adapter.ssl_private_key,
                self.server_adapter.ssl_certificate_chain)
            self.ssl_adapter.context = self.server_adapter.ssl_context
        elif self.server_adapter.ssl_certificate:
            adapter_class = wsgiserver.get_ssl_adapter_class(ssl_module)
            self.ssl_adapter = adapter_class(
                self.server_adapter.ssl_certificate,
                self.server_adapter.ssl_private_key,
                self.server_adapter.ssl_certificate_chain)
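
# Usage sketch (illustrative, not part of the vendored source): swapping the
# default WSGI server for this native adapter; Root stands for any app class.
#
#     import cherrypy
#     from cherrypy._cpnative_server import CPHTTPServer
#     cherrypy.server.httpserver = CPHTTPServer(cherrypy.server)
#     cherrypy.quickstart(Root(), '/')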
1013
lib/cherrypy/_cpreqbody.py
Normal file
File diff suppressed because it is too large
973
lib/cherrypy/_cprequest.py
Normal file
@@ -0,0 +1,973 @@

import os
import sys
import time
import warnings

import cherrypy
from cherrypy._cpcompat import basestring, copykeys, ntob, unicodestr
from cherrypy._cpcompat import SimpleCookie, CookieError, py3k
from cherrypy import _cpreqbody, _cpconfig
from cherrypy._cperror import format_exc, bare_error
from cherrypy.lib import httputil, file_generator


class Hook(object):

    """A callback and its metadata: failsafe, priority, and kwargs."""

    callback = None
    """
    The bare callable that this Hook object is wrapping, which will
    be called when the Hook is called."""

    failsafe = False
    """
    If True, the callback is guaranteed to run even if other callbacks
    from the same call point raise exceptions."""

    priority = 50
    """
    Defines the order of execution for a list of Hooks. Priority numbers
    should be limited to the closed interval [0, 100], but values outside
    this range are acceptable, as are fractional values."""

    kwargs = {}
    """
    A set of keyword arguments that will be passed to the
    callable on each call."""

    def __init__(self, callback, failsafe=None, priority=None, **kwargs):
        self.callback = callback

        if failsafe is None:
            failsafe = getattr(callback, "failsafe", False)
        self.failsafe = failsafe

        if priority is None:
            priority = getattr(callback, "priority", 50)
        self.priority = priority

        self.kwargs = kwargs

    def __lt__(self, other):
        # Python 3
        return self.priority < other.priority

    def __cmp__(self, other):
        # Python 2
        return cmp(self.priority, other.priority)

    def __call__(self):
        """Run self.callback(**self.kwargs)."""
        return self.callback(**self.kwargs)

    def __repr__(self):
        cls = self.__class__
        return ("%s.%s(callback=%r, failsafe=%r, priority=%r, %s)"
                % (cls.__module__, cls.__name__, self.callback,
                   self.failsafe, self.priority,
                   ", ".join(['%s=%r' % (k, v)
                              for k, v in self.kwargs.items()])))
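
# Usage sketch (illustrative, not part of the vendored source): attaching a
# callback to the current request via the HookMap defined below; 'audit' is
# a made-up function.
#
#     def audit():
#         cherrypy.log("before_finalize reached")
#     cherrypy.serving.request.hooks.attach('before_finalize', audit,
#                                           priority=60)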


class HookMap(dict):

    """A map of call points to lists of callbacks (Hook objects)."""

    def __new__(cls, points=None):
        d = dict.__new__(cls)
        for p in points or []:
            d[p] = []
        return d

    def __init__(self, *a, **kw):
        pass

    def attach(self, point, callback, failsafe=None, priority=None, **kwargs):
        """Append a new Hook made from the supplied arguments."""
        self[point].append(Hook(callback, failsafe, priority, **kwargs))

    def run(self, point):
        """Execute all registered Hooks (callbacks) for the given point."""
        exc = None
        hooks = self[point]
        hooks.sort()
        for hook in hooks:
            # Some hooks are guaranteed to run even if others at
            # the same hookpoint fail. We will still log the failure,
            # but proceed on to the next hook. The only way
            # to stop all processing from one of these hooks is
            # to raise SystemExit and stop the whole server.
            if exc is None or hook.failsafe:
                try:
                    hook()
                except (KeyboardInterrupt, SystemExit):
                    raise
                except (cherrypy.HTTPError, cherrypy.HTTPRedirect,
                        cherrypy.InternalRedirect):
                    exc = sys.exc_info()[1]
                except:
                    exc = sys.exc_info()[1]
                    cherrypy.log(traceback=True, severity=40)
        if exc:
            raise exc

    def __copy__(self):
        newmap = self.__class__()
        # We can't just use 'update' because we want copies of the
        # mutable values (each is a list) as well.
        for k, v in self.items():
            newmap[k] = v[:]
        return newmap
    copy = __copy__

    def __repr__(self):
        cls = self.__class__
        return "%s.%s(points=%r)" % (
            cls.__module__,
            cls.__name__,
            copykeys(self)
        )


# Config namespace handlers

def hooks_namespace(k, v):
    """Attach bare hooks declared in config."""
    # Use split again to allow multiple hooks for a single
    # hookpoint per path (e.g. "hooks.before_handler.1").
    # Little-known fact you only get from reading source ;)
    hookpoint = k.split(".", 1)[0]
    if isinstance(v, basestring):
        v = cherrypy.lib.attributes(v)
    if not isinstance(v, Hook):
        v = Hook(v)
    cherrypy.serving.request.hooks[hookpoint].append(v)


def request_namespace(k, v):
    """Attach request attributes declared in config."""
    # Provides config entries to set request.body attrs (like
    # attempt_charsets).
    if k[:5] == 'body.':
        setattr(cherrypy.serving.request.body, k[5:], v)
    else:
        setattr(cherrypy.serving.request, k, v)


def response_namespace(k, v):
    """Attach response attributes declared in config."""
    # Provides config entries to set default response headers
    # http://cherrypy.org/ticket/889
    if k[:8] == 'headers.':
        cherrypy.serving.response.headers[k.split('.', 1)[1]] = v
    else:
        setattr(cherrypy.serving.response, k, v)


def error_page_namespace(k, v):
    """Attach error pages declared in config."""
    if k != 'default':
        k = int(k)
    cherrypy.serving.request.error_page[k] = v


hookpoints = ['on_start_resource', 'before_request_body',
              'before_handler', 'before_finalize',
              'on_end_resource', 'on_end_request',
              'before_error_response', 'after_error_response']
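
# Usage sketch (illustrative, not part of the vendored source): the "hooks"
# config namespace handled by hooks_namespace above; the handler names are
# made up, and the ".1" suffix attaches a second hook to the same point.
#
#     config = {'/': {
#         'hooks.before_handler': myapp.check_auth,
#         'hooks.before_handler.1': myapp.start_timer,
#     }}
#     cherrypy.tree.mount(Root(), '/', config)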


class Request(object):

    """An HTTP request.

    This object represents the metadata of an HTTP request message;
    that is, it contains attributes which describe the environment
    in which the request URL, headers, and body were sent (if you
    want tools to interpret the headers and body, those are elsewhere,
    mostly in Tools). This 'metadata' consists of socket data,
    transport characteristics, and the Request-Line. This object
    also contains data regarding the configuration in effect for
    the given URL, and the execution plan for generating a response.
    """

    prev = None
    """
    The previous Request object (if any). This should be None
    unless we are processing an InternalRedirect."""

    # Conversation/connection attributes
    local = httputil.Host("127.0.0.1", 80)
    "An httputil.Host(ip, port, hostname) object for the server socket."

    remote = httputil.Host("127.0.0.1", 1111)
    "An httputil.Host(ip, port, hostname) object for the client socket."

    scheme = "http"
    """
    The protocol used between client and server. In most cases,
    this will be either 'http' or 'https'."""

    server_protocol = "HTTP/1.1"
    """
    The HTTP version for which the HTTP server is at least
    conditionally compliant."""

    base = ""
    """The (scheme://host) portion of the requested URL.
    In some cases (e.g. when proxying via mod_rewrite), this may contain
    path segments which cherrypy.url uses when constructing url's, but
    which otherwise are ignored by CherryPy. Regardless, this value
    MUST NOT end in a slash."""

    # Request-Line attributes
    request_line = ""
    """
    The complete Request-Line received from the client. This is a
    single string consisting of the request method, URI, and protocol
    version (joined by spaces). Any final CRLF is removed."""

    method = "GET"
    """
    Indicates the HTTP method to be performed on the resource identified
    by the Request-URI. Common methods include GET, HEAD, POST, PUT, and
    DELETE. CherryPy allows any extension method; however, various HTTP
    servers and gateways may restrict the set of allowable methods.
    CherryPy applications SHOULD restrict the set (on a per-URI basis)."""

    query_string = ""
    """
    The query component of the Request-URI, a string of information to be
    interpreted by the resource. The query portion of a URI follows the
    path component, and is separated by a '?'. For example, the URI
    'http://www.cherrypy.org/wiki?a=3&b=4' has the query component,
    'a=3&b=4'."""

    query_string_encoding = 'utf8'
    """
    The encoding expected for query string arguments after % HEX HEX decoding.
    If a query string is provided that cannot be decoded with this encoding,
    404 is raised (since technically it's a different URI). If you want
    arbitrary encodings to not error, set this to 'Latin-1'; you can then
    encode back to bytes and re-decode to whatever encoding you like later.
    """

    protocol = (1, 1)
    """The HTTP protocol version corresponding to the set
    of features which should be allowed in the response. If BOTH
    the client's request message AND the server's level of HTTP
    compliance is HTTP/1.1, this attribute will be the tuple (1, 1).
    If either is 1.0, this attribute will be the tuple (1, 0).
    Lower HTTP protocol versions are not explicitly supported."""

    params = {}
    """
    A dict which combines query string (GET) and request entity (POST)
    variables. This is populated in two stages: GET params are added
    before the 'on_start_resource' hook, and POST params are added
    between the 'before_request_body' and 'before_handler' hooks."""

    # Message attributes
    header_list = []
    """
    A list of the HTTP request headers as (name, value) tuples.
    In general, you should use request.headers (a dict) instead."""

    headers = httputil.HeaderMap()
    """
    A dict-like object containing the request headers. Keys are header
    names (in Title-Case format); however, you may get and set them in
    a case-insensitive manner. That is, headers['Content-Type'] and
    headers['content-type'] refer to the same value. Values are header
    values (decoded according to :rfc:`2047` if necessary). See also:
    httputil.HeaderMap, httputil.HeaderElement."""

    cookie = SimpleCookie()
    """See help(Cookie)."""

    rfile = None
    """
    If the request included an entity (body), it will be available
    as a stream in this attribute. However, the rfile will normally
    be read for you between the 'before_request_body' hook and the
    'before_handler' hook, and the resulting string is placed into
    either request.params or the request.body attribute.

    You may disable the automatic consumption of the rfile by setting
    request.process_request_body to False, either in config for the desired
    path, or in an 'on_start_resource' or 'before_request_body' hook.

    WARNING: In almost every case, you should not attempt to read from the
    rfile stream after CherryPy's automatic mechanism has read it. If you
    turn off the automatic parsing of rfile, you should read exactly the
    number of bytes specified in request.headers['Content-Length'].
    Ignoring either of these warnings may result in a hung request thread
    or in corruption of the next (pipelined) request.
    """

    process_request_body = True
    """
    If True, the rfile (if any) is automatically read and parsed,
    and the result placed into request.params or request.body."""

    methods_with_bodies = ("POST", "PUT")
    """
    A sequence of HTTP methods for which CherryPy will automatically
    attempt to read a body from the rfile. If you are going to change
    this property, modify it on the configuration (recommended)
    or on the "hook point" `on_start_resource`.
    """

    body = None
    """
    If the request Content-Type is 'application/x-www-form-urlencoded'
    or multipart, this will be None. Otherwise, this will be an instance
    of :class:`RequestBody<cherrypy._cpreqbody.RequestBody>` (which you
    can .read()); this value is set between the 'before_request_body' and
    'before_handler' hooks (assuming that process_request_body is True)."""

    # Dispatch attributes
    dispatch = cherrypy.dispatch.Dispatcher()
    """
    The object which looks up the 'page handler' callable and collects
    config for the current request based on the path_info, other
    request attributes, and the application architecture. The core
    calls the dispatcher as early as possible, passing it a 'path_info'
    argument.

    The default dispatcher discovers the page handler by matching path_info
    to a hierarchical arrangement of objects, starting at request.app.root.
    See help(cherrypy.dispatch) for more information."""

    script_name = ""
    """
    The 'mount point' of the application which is handling this request.

    This attribute MUST NOT end in a slash. If the script_name refers to
    the root of the URI, it MUST be an empty string (not "/").
    """

    path_info = "/"
    """
    The 'relative path' portion of the Request-URI. This is relative
    to the script_name ('mount point') of the application which is
    handling this request."""

    login = None
    """
    When authentication is used during the request processing this is
    set to 'False' if it failed and to the 'username' value if it succeeded.
    The default 'None' implies that no authentication happened."""

    # Note that cherrypy.url uses "if request.app:" to determine whether
    # the call is during a real HTTP request or not. So leave this None.
    app = None
    """The cherrypy.Application object which is handling this request."""

    handler = None
    """
    The function, method, or other callable which CherryPy will call to
    produce the response. The discovery of the handler and the arguments
    it will receive are determined by the request.dispatch object.
    By default, the handler is discovered by walking a tree of objects
    starting at request.app.root, and is then passed all HTTP params
    (from the query string and POST body) as keyword arguments."""

    toolmaps = {}
    """
    A nested dict of all Toolboxes and Tools in effect for this request,
    of the form: {Toolbox.namespace: {Tool.name: config dict}}."""

    config = None
    """
    A flat dict of all configuration entries which apply to the
    current request. These entries are collected from global config,
    application config (based on request.path_info), and from handler
    config (exactly how is governed by the request.dispatch object in
    effect for this request; by default, handler config can be attached
    anywhere in the tree between request.app.root and the final handler,
    and inherits downward)."""

    is_index = None
    """
    This will be True if the current request is mapped to an 'index'
    resource handler (also, a 'default' handler if path_info ends with
    a slash). The value may be used to automatically redirect the
    user-agent to a 'more canonical' URL which either adds or removes
    the trailing slash. See cherrypy.tools.trailing_slash."""

    hooks = HookMap(hookpoints)
    """
    A HookMap (dict-like object) of the form: {hookpoint: [hook, ...]}.
    Each key is a str naming the hook point, and each value is a list
    of hooks which will be called at that hook point during this request.
    The list of hooks is generally populated as early as possible (mostly
    from Tools specified in config), but may be extended at any time.
    See also: _cprequest.Hook, _cprequest.HookMap, and cherrypy.tools."""

    error_response = cherrypy.HTTPError(500).set_response
    """
    The no-arg callable which will handle unexpected, untrapped errors
    during request processing. This is not used for expected exceptions
    (like NotFound, HTTPError, or HTTPRedirect) which are raised in
    response to expected conditions (those should be customized either
    via request.error_page or by overriding HTTPError.set_response).
    By default, error_response uses HTTPError(500) to return a generic
    error response to the user-agent."""

    error_page = {}
    """
    A dict of {error code: response filename or callable} pairs.

    The error code must be an int representing a given HTTP error code,
    or the string 'default', which will be used if no matching entry
    is found for a given numeric code.

    If a filename is provided, the file should contain a Python string-
    formatting template, and can expect by default to receive format
    values with the mapping keys %(status)s, %(message)s, %(traceback)s,
    and %(version)s. The set of format mappings can be extended by
    overriding HTTPError.set_response.

    If a callable is provided, it will be called by default with keyword
    arguments 'status', 'message', 'traceback', and 'version', as for a
    string-formatting template. The callable must return a string or
    iterable of strings which will be set to response.body. It may also
    override headers or perform any other processing.

    If no entry is given for an error code, and no 'default' entry exists,
    a default template will be used.
    """
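
    # Usage sketch (illustrative, not part of the vendored source): the two
    # error_page forms described above; the path and callable are made up.
    #
    #     cherrypy.config.update({
    #         'error_page.404': '/path/to/404.html',
    #         'error_page.default': myapp.render_error,
    #     })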

    show_tracebacks = True
    """
    If True, unexpected errors encountered during request processing will
    include a traceback in the response body."""

    show_mismatched_params = True
    """
    If True, mismatched parameters encountered during PageHandler invocation
    processing will be included in the response body."""

    throws = (KeyboardInterrupt, SystemExit, cherrypy.InternalRedirect)
    """The sequence of exceptions which Request.run does not trap."""

    throw_errors = False
    """
    If True, Request.run will not trap any errors (except HTTPRedirect and
    HTTPError, which are more properly called 'exceptions', not errors)."""

    closed = False
    """True once the close method has been called, False otherwise."""

    stage = None
    """
    A string containing the stage reached in the request-handling process.
    This is useful when debugging a live server with hung requests."""

    namespaces = _cpconfig.NamespaceSet(
        **{"hooks": hooks_namespace,
           "request": request_namespace,
           "response": response_namespace,
           "error_page": error_page_namespace,
           "tools": cherrypy.tools,
           })

    def __init__(self, local_host, remote_host, scheme="http",
                 server_protocol="HTTP/1.1"):
        """Populate a new Request object.

        local_host should be an httputil.Host object with the server info.
        remote_host should be an httputil.Host object with the client info.
        scheme should be a string, either "http" or "https".
        """
        self.local = local_host
        self.remote = remote_host
        self.scheme = scheme
        self.server_protocol = server_protocol

        self.closed = False

        # Put a *copy* of the class error_page into self.
        self.error_page = self.error_page.copy()

        # Put a *copy* of the class namespaces into self.
        self.namespaces = self.namespaces.copy()

        self.stage = None

    def close(self):
        """Run cleanup code. (Core)"""
        if not self.closed:
            self.closed = True
            self.stage = 'on_end_request'
            self.hooks.run('on_end_request')
            self.stage = 'close'

    def run(self, method, path, query_string, req_protocol, headers, rfile):
        r"""Process the Request. (Core)

        method, path, query_string, and req_protocol should be pulled directly
        from the Request-Line (e.g. "GET /path?key=val HTTP/1.0").

        path
            This should be %XX-unquoted, but query_string should not be.

            When using Python 2, they both MUST be byte strings,
            not unicode strings.

            When using Python 3, they both MUST be unicode strings,
            not byte strings, and preferably not bytes \x00-\xFF
            disguised as unicode.

        headers
            A list of (name, value) tuples.

        rfile
            A file-like object containing the HTTP request entity.

        When run() is done, the returned object should have 3 attributes:

        * status, e.g. "200 OK"
        * header_list, a list of (name, value) tuples
        * body, an iterable yielding strings

        Consumer code (HTTP servers) should then access these response
        attributes to build the outbound stream.

        """
        response = cherrypy.serving.response
        self.stage = 'run'
        try:
            self.error_response = cherrypy.HTTPError(500).set_response

            self.method = method
            path = path or "/"
            self.query_string = query_string or ''
            self.params = {}

            # Compare request and server HTTP protocol versions, in case our
            # server does not support the requested protocol. Limit our output
            # to min(req, server). We want the following output:
            #     request    server    actual written    supported response
            #     protocol   protocol  response protocol    feature set
            # a      1.0        1.0          1.0                1.0
            # b      1.0        1.1          1.1                1.0
            # c      1.1        1.0          1.0                1.0
            # d      1.1        1.1          1.1                1.1
            # Notice that, in (b), the response will be "HTTP/1.1" even though
            # the client only understands 1.0. RFC 2616 10.5.6 says we should
            # only return 505 if the _major_ version is different.
            rp = int(req_protocol[5]), int(req_protocol[7])
            sp = int(self.server_protocol[5]), int(self.server_protocol[7])
            self.protocol = min(rp, sp)
            response.headers.protocol = self.protocol

            # Rebuild first line of the request (e.g. "GET /path HTTP/1.0").
            url = path
            if query_string:
                url += '?' + query_string
            self.request_line = '%s %s %s' % (method, url, req_protocol)

            self.header_list = list(headers)
            self.headers = httputil.HeaderMap()

            self.rfile = rfile
            self.body = None

            self.cookie = SimpleCookie()
            self.handler = None

            # path_info should be the path from the
            # app root (script_name) to the handler.
            self.script_name = self.app.script_name
            self.path_info = pi = path[len(self.script_name):]

            self.stage = 'respond'
            self.respond(pi)

        except self.throws:
            raise
        except:
            if self.throw_errors:
                raise
            else:
                # Failure in setup, error handler or finalize. Bypass them.
                # Can't use handle_error because we may not have hooks yet.
                cherrypy.log(traceback=True, severity=40)
                if self.show_tracebacks:
                    body = format_exc()
                else:
                    body = ""
                r = bare_error(body)
                response.output_status, response.header_list, response.body = r

        if self.method == "HEAD":
            # HEAD requests MUST NOT return a message-body in the response.
            response.body = []

        try:
            cherrypy.log.access()
        except:
            cherrypy.log.error(traceback=True)

        if response.timed_out:
            raise cherrypy.TimeoutError()

        return response

    # Uncomment for stage debugging
    # stage = property(lambda self: self._stage, lambda self, v: print(v))

    def respond(self, path_info):
        """Generate a response for the resource at self.path_info. (Core)"""
        response = cherrypy.serving.response
        try:
            try:
                try:
                    if self.app is None:
                        raise cherrypy.NotFound()

                    # Get the 'Host' header, so we can HTTPRedirect properly.
                    self.stage = 'process_headers'
                    self.process_headers()

                    # Make a copy of the class hooks
                    self.hooks = self.__class__.hooks.copy()
                    self.toolmaps = {}

                    self.stage = 'get_resource'
                    self.get_resource(path_info)

                    self.body = _cpreqbody.RequestBody(
                        self.rfile, self.headers, request_params=self.params)

                    self.namespaces(self.config)

                    self.stage = 'on_start_resource'
                    self.hooks.run('on_start_resource')

                    # Parse the querystring
                    self.stage = 'process_query_string'
                    self.process_query_string()

                    # Process the body
                    if self.process_request_body:
                        if self.method not in self.methods_with_bodies:
                            self.process_request_body = False
                    self.stage = 'before_request_body'
                    self.hooks.run('before_request_body')
                    if self.process_request_body:
                        self.body.process()

                    # Run the handler
                    self.stage = 'before_handler'
                    self.hooks.run('before_handler')
                    if self.handler:
                        self.stage = 'handler'
                        response.body = self.handler()

                    # Finalize
                    self.stage = 'before_finalize'
                    self.hooks.run('before_finalize')
                    response.finalize()
                except (cherrypy.HTTPRedirect, cherrypy.HTTPError):
                    inst = sys.exc_info()[1]
                    inst.set_response()
                    self.stage = 'before_finalize (HTTPError)'
                    self.hooks.run('before_finalize')
                    response.finalize()
            finally:
                self.stage = 'on_end_resource'
                self.hooks.run('on_end_resource')
        except self.throws:
            raise
        except:
            if self.throw_errors:
                raise
            self.handle_error()

    def process_query_string(self):
        """Parse the query string into Python structures. (Core)"""
        try:
            p = httputil.parse_query_string(
                self.query_string, encoding=self.query_string_encoding)
        except UnicodeDecodeError:
            raise cherrypy.HTTPError(
                404, "The given query string could not be processed. Query "
                "strings for this resource must be encoded with %r." %
                self.query_string_encoding)

        # Python 2 only: keyword arguments must be byte strings (type 'str').
        if not py3k:
            for key, value in p.items():
                if isinstance(key, unicode):
                    del p[key]
                    p[key.encode(self.query_string_encoding)] = value
        self.params.update(p)

    def process_headers(self):
        """Parse HTTP header data into Python structures. (Core)"""
        # Process the headers into self.headers
        headers = self.headers
        for name, value in self.header_list:
            # Call title() now (and use dict.__method__(headers))
            # so title doesn't have to be called twice.
            name = name.title()
            value = value.strip()

            # Warning: if there is more than one header entry for cookies
            # (AFAIK, only Konqueror does that), only the last one will
            # remain in headers (but they will be correctly stored in
            # request.cookie).
            if "=?" in value:
                dict.__setitem__(headers, name, httputil.decode_TEXT(value))
            else:
                dict.__setitem__(headers, name, value)

            # Handle cookies differently because on Konqueror, multiple
            # cookies come on different lines with the same key
            if name == 'Cookie':
                try:
                    self.cookie.load(value)
                except CookieError:
                    msg = "Illegal cookie name %s" % value.split('=')[0]
                    raise cherrypy.HTTPError(400, msg)

        if not dict.__contains__(headers, 'Host'):
            # All Internet-based HTTP/1.1 servers MUST respond with a 400
            # (Bad Request) status code to any HTTP/1.1 request message
            # which lacks a Host header field.
            if self.protocol >= (1, 1):
                msg = "HTTP/1.1 requires a 'Host' request header."
                raise cherrypy.HTTPError(400, msg)
        host = dict.get(headers, 'Host')
        if not host:
            host = self.local.name or self.local.ip
        self.base = "%s://%s" % (self.scheme, host)

    def get_resource(self, path):
        """Call a dispatcher (which sets self.handler and .config). (Core)"""
        # First, see if there is a custom dispatch at this URI. Custom
        # dispatchers can only be specified in app.config, not in _cp_config
        # (since custom dispatchers may not even have an app.root).
        dispatch = self.app.find_config(
            path, "request.dispatch", self.dispatch)

        # dispatch() should set self.handler and self.config
        dispatch(path)

    def handle_error(self):
        """Handle the last unanticipated exception. (Core)"""
        try:
            self.hooks.run("before_error_response")
            if self.error_response:
                self.error_response()
            self.hooks.run("after_error_response")
            cherrypy.serving.response.finalize()
        except cherrypy.HTTPRedirect:
            inst = sys.exc_info()[1]
            inst.set_response()
            cherrypy.serving.response.finalize()

    # ------------------------- Properties ------------------------- #

    def _get_body_params(self):
        warnings.warn(
            "body_params is deprecated in CherryPy 3.2, will be removed in "
            "CherryPy 3.3.",
            DeprecationWarning
        )
        return self.body.params
    body_params = property(_get_body_params,
                           doc="""
    If the request Content-Type is 'application/x-www-form-urlencoded' or
    multipart, this will be a dict of the params pulled from the entity
    body; that is, it will be the portion of request.params that come
    from the message body (sometimes called "POST params", although they
    can be sent with various HTTP method verbs). This value is set between
    the 'before_request_body' and 'before_handler' hooks (assuming that
    process_request_body is True).

    Deprecated in 3.2, will be removed for 3.3 in favor of
    :attr:`request.body.params<cherrypy._cpreqbody.RequestBody.params>`.""")


class ResponseBody(object):

    """The body of the HTTP response (the response entity)."""

    if py3k:
        unicode_err = ("Page handlers MUST return bytes. Use tools.encode "
                       "if you wish to return unicode.")

    def __get__(self, obj, objclass=None):
        if obj is None:
            # When calling on the class instead of an instance...
            return self
        else:
            return obj._body

    def __set__(self, obj, value):
        # Convert the given value to an iterable object.
        if py3k and isinstance(value, str):
            raise ValueError(self.unicode_err)

        if isinstance(value, basestring):
            # strings get wrapped in a list because iterating over a single
            # item list is much faster than iterating over every character
            # in a long string.
            if value:
                value = [value]
            else:
                # [''] doesn't evaluate to False, so replace it with [].
                value = []
        elif py3k and isinstance(value, list):
            # every item in a list must be bytes...
            for i, item in enumerate(value):
                if isinstance(item, str):
                    raise ValueError(self.unicode_err)
        # Don't use isinstance here; io.IOBase which has an ABC takes
        # 1000 times as long as, say, isinstance(value, str)
        elif hasattr(value, 'read'):
            value = file_generator(value)
        elif value is None:
            value = []
        obj._body = value
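
# Informal summary (illustrative, not part of the vendored source) of the
# coercions ResponseBody.__set__ performs on Python 2; 'somepath' is made up.
#
#     response.body = "hello"          # -> ["hello"]
#     response.body = ""               # -> []
#     response.body = open(somepath)   # -> file_generator(<file>)
#     response.body = None             # -> []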


class Response(object):

    """An HTTP Response, including status, headers, and body."""

    status = ""
    """The HTTP Status-Code and Reason-Phrase."""

    header_list = []
    """
    A list of the HTTP response headers as (name, value) tuples.
    In general, you should use response.headers (a dict) instead. This
    attribute is generated from response.headers and is not valid until
    after the finalize phase."""

    headers = httputil.HeaderMap()
    """
    A dict-like object containing the response headers. Keys are header
    names (in Title-Case format); however, you may get and set them in
    a case-insensitive manner. That is, headers['Content-Type'] and
    headers['content-type'] refer to the same value. Values are header
    values (decoded according to :rfc:`2047` if necessary).

    .. seealso:: classes :class:`HeaderMap`, :class:`HeaderElement`
    """

    cookie = SimpleCookie()
    """See help(Cookie)."""

    body = ResponseBody()
    """The body (entity) of the HTTP response."""

    time = None
    """The value of time.time() when created. Use in HTTP dates."""

    timeout = 300
    """Seconds after which the response will be aborted."""

    timed_out = False
    """
    Flag to indicate the response should be aborted, because it has
    exceeded its timeout."""

    stream = False
    """If False, buffer the response body."""

    def __init__(self):
        self.status = None
        self.header_list = None
        self._body = []
        self.time = time.time()

        self.headers = httputil.HeaderMap()
        # Since we know all our keys are titled strings, we can
        # bypass HeaderMap.update and get a big speed boost.
        dict.update(self.headers, {
            "Content-Type": 'text/html',
            "Server": "CherryPy/" + cherrypy.__version__,
            "Date": httputil.HTTPDate(self.time),
        })
        self.cookie = SimpleCookie()

    def collapse_body(self):
        """Collapse self.body to a single string; replace it and return it."""
        if isinstance(self.body, basestring):
            return self.body

        newbody = []
        for chunk in self.body:
            if py3k and not isinstance(chunk, bytes):
                raise TypeError("Chunk %s is not of type 'bytes'." %
                                repr(chunk))
            newbody.append(chunk)
        newbody = ntob('').join(newbody)

        self.body = newbody
        return newbody

    def finalize(self):
        """Transform headers (and cookies) into self.header_list. (Core)"""
        try:
            code, reason, _ = httputil.valid_status(self.status)
        except ValueError:
            raise cherrypy.HTTPError(500, sys.exc_info()[1].args[0])

        headers = self.headers

        self.status = "%s %s" % (code, reason)
        self.output_status = ntob(str(code), 'ascii') + \
            ntob(" ") + headers.encode(reason)

        if self.stream:
            # The upshot: wsgiserver will chunk the response if
            # you pop Content-Length (or set it explicitly to None).
            # Note that lib.static sets C-L to the file's st_size.
            if dict.get(headers, 'Content-Length') is None:
                dict.pop(headers, 'Content-Length', None)
        elif code < 200 or code in (204, 205, 304):
            # "All 1xx (informational), 204 (no content),
            # and 304 (not modified) responses MUST NOT
            # include a message-body."
            dict.pop(headers, 'Content-Length', None)
            self.body = ntob("")
        else:
            # Responses which are not streamed should have a Content-Length,
            # but allow user code to set Content-Length if desired.
            if dict.get(headers, 'Content-Length') is None:
                content = self.collapse_body()
                dict.__setitem__(headers, 'Content-Length', len(content))

        # Transform our header dict into a list of tuples.
        self.header_list = h = headers.output()

        cookie = self.cookie.output()
        if cookie:
            for line in cookie.split("\n"):
                if line.endswith("\r"):
                    # Python 2.4 emits cookies joined by LF but 2.5+ by CRLF.
                    line = line[:-1]
                name, value = line.split(": ", 1)
                if isinstance(name, unicodestr):
                    name = name.encode("ISO-8859-1")
                if isinstance(value, unicodestr):
                    value = headers.encode(value)
                h.append((name, value))

    def check_timeout(self):
        """If now > self.time + self.timeout, set self.timed_out.

        This purposefully sets a flag, rather than raising an error,
        so that a monitor thread can interrupt the Response thread.
        """
        if time.time() > self.time + self.timeout:
            self.timed_out = True
226
lib/cherrypy/_cpserver.py
Normal file
@@ -0,0 +1,226 @@
"""Manage HTTP servers with CherryPy."""

import warnings

import cherrypy
from cherrypy.lib import attributes
from cherrypy._cpcompat import basestring, py3k

# We import * because we want to export check_port
# et al as attributes of this module.
from cherrypy.process.servers import *


class Server(ServerAdapter):

    """An adapter for an HTTP server.

    You can set attributes (like socket_host and socket_port)
    on *this* object (which is probably cherrypy.server), and call
    quickstart. For example::

        cherrypy.server.socket_port = 80
        cherrypy.quickstart()
    """

    socket_port = 8080
    """The TCP port on which to listen for connections."""

    _socket_host = '127.0.0.1'

    def _get_socket_host(self):
        return self._socket_host

    def _set_socket_host(self, value):
        if value == '':
            raise ValueError("The empty string ('') is not an allowed value. "
                             "Use '0.0.0.0' instead to listen on all active "
                             "interfaces (INADDR_ANY).")
        self._socket_host = value
    socket_host = property(
        _get_socket_host,
        _set_socket_host,
        doc="""The hostname or IP address on which to listen for connections.

        Host values may be any IPv4 or IPv6 address, or any valid hostname.
        The string 'localhost' is a synonym for '127.0.0.1' (or '::1', if
        your hosts file prefers IPv6). The string '0.0.0.0' is a special
        IPv4 entry meaning "any active interface" (INADDR_ANY), and '::'
        is the similar IN6ADDR_ANY for IPv6. The empty string or None are
        not allowed.""")

    socket_file = None
    """If given, the name of the UNIX socket to use instead of TCP/IP.

    When this option is not None, the `socket_host` and `socket_port` options
    are ignored."""

    socket_queue_size = 5
    """The 'backlog' argument to socket.listen(); specifies the maximum number
    of queued connections (default 5)."""

    socket_timeout = 10
    """The timeout in seconds for accepted connections (default 10)."""

    accepted_queue_size = -1
    """The maximum number of requests which will be queued up before
    the server refuses to accept any more (default -1, meaning no limit)."""

    accepted_queue_timeout = 10
    """The timeout in seconds for attempting to add a request to the
    queue when the queue is full (default 10)."""

    shutdown_timeout = 5
    """The time to wait for HTTP worker threads to clean up."""

    protocol_version = 'HTTP/1.1'
    """The version string to write in the Status-Line of all HTTP responses,
    for example, "HTTP/1.1" (the default). Depending on the HTTP server used,
    this should also limit the supported features used in the response."""

    thread_pool = 10
    """The number of worker threads to start up in the pool."""

    thread_pool_max = -1
    """The maximum size of the worker-thread pool. Use -1 to indicate no limit.
    """

    max_request_header_size = 500 * 1024
    """The maximum number of bytes allowable in the request headers.
    If exceeded, the HTTP server should return "413 Request Entity Too Large".
    """

    max_request_body_size = 100 * 1024 * 1024
    """The maximum number of bytes allowable in the request body. If exceeded,
    the HTTP server should return "413 Request Entity Too Large"."""

    instance = None
    """If not None, this should be an HTTP server instance (such as
    CPWSGIServer) which cherrypy.server will control. Use this when you need
    more control over object instantiation than is available in the various
    configuration options."""

    ssl_context = None
    """When using PyOpenSSL, an instance of SSL.Context."""

    ssl_certificate = None
    """The filename of the SSL certificate to use."""

    ssl_certificate_chain = None
    """When using PyOpenSSL, the certificate chain to pass to
    Context.load_verify_locations."""

    ssl_private_key = None
    """The filename of the private key to use with SSL."""

    if py3k:
        ssl_module = 'builtin'
        """The name of a registered SSL adaptation module to use with
        the builtin WSGI server. Builtin options are: 'builtin' (to
        use the SSL library built into recent versions of Python).
        You may also register your own classes in the
        wsgiserver.ssl_adapters dict."""
    else:
        ssl_module = 'pyopenssl'
        """The name of a registered SSL adaptation module to use with the
        builtin WSGI server. Builtin options are 'builtin' (to use the SSL
        library built into recent versions of Python) and 'pyopenssl' (to
        use the PyOpenSSL project, which you must install separately). You
        may also register your own classes in the wsgiserver.ssl_adapters
        dict."""

    statistics = False
    """Turns statistics-gathering on or off for aware HTTP servers."""

    nodelay = True
    """If True (the default since 3.1), sets the TCP_NODELAY socket option."""

    wsgi_version = (1, 0)
    """The WSGI version tuple to use with the builtin WSGI server.
    The provided options are (1, 0) [which includes support for PEP 3333,
    which declares it covers WSGI version 1.0.1 but still mandates the
    wsgi.version (1, 0)] and ('u', 0), an experimental unicode version.
    You may create and register your own experimental versions of the WSGI
    protocol by adding custom classes to the wsgiserver.wsgi_gateways dict."""

    def __init__(self):
        self.bus = cherrypy.engine
        self.httpserver = None
        self.interrupt = None
        self.running = False

    def httpserver_from_self(self, httpserver=None):
        """Return a (httpserver, bind_addr) pair based on self attributes."""
        if httpserver is None:
            httpserver = self.instance
        if httpserver is None:
            from cherrypy import _cpwsgi_server
            httpserver = _cpwsgi_server.CPWSGIServer(self)
        if isinstance(httpserver, basestring):
            # Is anyone using this? Can I add an arg?
            httpserver = attributes(httpserver)(self)
        return httpserver, self.bind_addr

    def start(self):
        """Start the HTTP server."""
        if not self.httpserver:
            self.httpserver, self.bind_addr = self.httpserver_from_self()
        ServerAdapter.start(self)
    start.priority = 75

    def _get_bind_addr(self):
        if self.socket_file:
            return self.socket_file
        if self.socket_host is None and self.socket_port is None:
            return None
        return (self.socket_host, self.socket_port)

    def _set_bind_addr(self, value):
        if value is None:
            self.socket_file = None
            self.socket_host = None
            self.socket_port = None
        elif isinstance(value, basestring):
            self.socket_file = value
            self.socket_host = None
            self.socket_port = None
        else:
            try:
                self.socket_host, self.socket_port = value
                self.socket_file = None
            except ValueError:
                raise ValueError("bind_addr must be a (host, port) tuple "
                                 "(for TCP sockets) or a string (for Unix "
                                 "domain sockets), not %r" % value)
    bind_addr = property(
        _get_bind_addr,
        _set_bind_addr,
        doc='A (host, port) tuple for TCP sockets or '
            'a str for Unix domain sockets.')
|
||||
def base(self):
|
||||
"""Return the base (scheme://host[:port] or sock file) for this server.
|
||||
"""
|
||||
if self.socket_file:
|
||||
return self.socket_file
|
||||
|
||||
host = self.socket_host
|
||||
if host in ('0.0.0.0', '::'):
|
||||
# 0.0.0.0 is INADDR_ANY and :: is IN6ADDR_ANY.
|
||||
# Look up the host name, which should be the
|
||||
# safest thing to spit out in a URL.
|
||||
import socket
|
||||
host = socket.gethostname()
|
||||
|
||||
port = self.socket_port
|
||||
|
||||
if self.ssl_certificate:
|
||||
scheme = "https"
|
||||
if port != 443:
|
||||
host += ":%s" % port
|
||||
else:
|
||||
scheme = "http"
|
||||
if port != 80:
|
||||
host += ":%s" % port
|
||||
|
||||
return "%s://%s" % (scheme, host)
|
||||
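For reference, a minimal sketch (not part of this diff; host, port, and certificate paths are placeholder values) of how the Server attributes above are typically driven from application code. bind_addr may also be set directly; a string value would mean a UNIX domain socket instead of a (host, port) TCP pair:

    import cherrypy

    class Root(object):
        def index(self):
            return "Hello"
        index.exposed = True

    cherrypy.config.update({
        'server.socket_host': '0.0.0.0',   # INADDR_ANY; base() looks up the hostname
        'server.socket_port': 8181,
        'server.thread_pool': 10,          # size of the worker-thread pool
        'server.socket_queue_size': 5,     # listen() backlog
        # 'server.ssl_certificate': '/path/to/cert.pem',  # switches base() to https
        # 'server.ssl_private_key': '/path/to/key.pem',
    })

    # Equivalent to setting socket_host/socket_port above:
    cherrypy.server.bind_addr = ('0.0.0.0', 8181)

    cherrypy.quickstart(Root())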
241
lib/cherrypy/_cpthreadinglocal.py
Normal file
@@ -0,0 +1,241 @@
# This is a backport of Python-2.4's threading.local() implementation

"""Thread-local objects

(Note that this module provides a Python version of the threading.local
class. Depending on the version of Python you're using, there may be a
faster one available. You should always import the local class from
threading.)

Thread-local objects support the management of thread-local data.
If you have data that you want to be local to a thread, simply create
a thread-local object and use its attributes:

>>> mydata = local()
>>> mydata.number = 42
>>> mydata.number
42

You can also access the local-object's dictionary:

>>> mydata.__dict__
{'number': 42}
>>> mydata.__dict__.setdefault('widgets', [])
[]
>>> mydata.widgets
[]

What's important about thread-local objects is that their data are
local to a thread. If we access the data in a different thread:

>>> log = []
>>> def f():
...     items = mydata.__dict__.items()
...     items.sort()
...     log.append(items)
...     mydata.number = 11
...     log.append(mydata.number)

>>> import threading
>>> thread = threading.Thread(target=f)
>>> thread.start()
>>> thread.join()
>>> log
[[], 11]

we get different data. Furthermore, changes made in the other thread
don't affect data seen in this thread:

>>> mydata.number
42

Of course, values you get from a local object, including a __dict__
attribute, are for whatever thread was current at the time the
attribute was read. For that reason, you generally don't want to save
these values across threads, as they apply only to the thread they
came from.

You can create custom local objects by subclassing the local class:

>>> class MyLocal(local):
...     number = 2
...     initialized = False
...     def __init__(self, **kw):
...         if self.initialized:
...             raise SystemError('__init__ called too many times')
...         self.initialized = True
...         self.__dict__.update(kw)
...     def squared(self):
...         return self.number ** 2

This can be useful to support default values, methods and
initialization. Note that if you define an __init__ method, it will be
called each time the local object is used in a separate thread. This
is necessary to initialize each thread's dictionary.

Now if we create a local object:

>>> mydata = MyLocal(color='red')

Now we have a default number:

>>> mydata.number
2

an initial color:

>>> mydata.color
'red'
>>> del mydata.color

And a method that operates on the data:

>>> mydata.squared()
4

As before, we can access the data in a separate thread:

>>> log = []
>>> thread = threading.Thread(target=f)
>>> thread.start()
>>> thread.join()
>>> log
[[('color', 'red'), ('initialized', True)], 11]

without affecting this thread's data:

>>> mydata.number
2
>>> mydata.color
Traceback (most recent call last):
...
AttributeError: 'MyLocal' object has no attribute 'color'

Note that subclasses can define slots, but they are not thread
local. They are shared across threads:

>>> class MyLocal(local):
...     __slots__ = 'number'

>>> mydata = MyLocal()
>>> mydata.number = 42
>>> mydata.color = 'red'

So, the separate thread:

>>> thread = threading.Thread(target=f)
>>> thread.start()
>>> thread.join()

affects what we see:

>>> mydata.number
11

>>> del mydata
"""

# Threading import is at end


class _localbase(object):
    __slots__ = '_local__key', '_local__args', '_local__lock'

    def __new__(cls, *args, **kw):
        self = object.__new__(cls)
        key = 'thread.local.' + str(id(self))
        object.__setattr__(self, '_local__key', key)
        object.__setattr__(self, '_local__args', (args, kw))
        object.__setattr__(self, '_local__lock', RLock())

        # Note the parentheses around (args or kw): without them, "and"
        # binds tighter than "or" and args alone would never raise.
        if (args or kw) and (cls.__init__ is object.__init__):
            raise TypeError("Initialization arguments are not supported")

        # We need to create the thread dict in anticipation of
        # __init__ being called, to make sure we don't call it
        # again ourselves.
        dict = object.__getattribute__(self, '__dict__')
        currentThread().__dict__[key] = dict

        return self


def _patch(self):
    key = object.__getattribute__(self, '_local__key')
    d = currentThread().__dict__.get(key)
    if d is None:
        d = {}
        currentThread().__dict__[key] = d
        object.__setattr__(self, '__dict__', d)

        # we have a new instance dict, so call our __init__ if we have
        # one
        cls = type(self)
        if cls.__init__ is not object.__init__:
            args, kw = object.__getattribute__(self, '_local__args')
            cls.__init__(self, *args, **kw)
    else:
        object.__setattr__(self, '__dict__', d)


class local(_localbase):

    def __getattribute__(self, name):
        lock = object.__getattribute__(self, '_local__lock')
        lock.acquire()
        try:
            _patch(self)
            return object.__getattribute__(self, name)
        finally:
            lock.release()

    def __setattr__(self, name, value):
        lock = object.__getattribute__(self, '_local__lock')
        lock.acquire()
        try:
            _patch(self)
            return object.__setattr__(self, name, value)
        finally:
            lock.release()

    def __delattr__(self, name):
        lock = object.__getattribute__(self, '_local__lock')
        lock.acquire()
        try:
            _patch(self)
            return object.__delattr__(self, name)
        finally:
            lock.release()

    def __del__():
        threading_enumerate = enumerate
        __getattribute__ = object.__getattribute__

        def __del__(self):
            key = __getattribute__(self, '_local__key')

            try:
                threads = list(threading_enumerate())
            except:
                # If enumerate fails, as it seems to do during
                # shutdown, we'll skip cleanup under the assumption
                # that there is nothing to clean up.
                return

            for thread in threads:
                try:
                    __dict__ = thread.__dict__
                except AttributeError:
                    # Thread is dying; rest in peace.
                    continue

                if key in __dict__:
                    try:
                        del __dict__[key]
                    except KeyError:
                        pass  # didn't have anything in this thread

        return __del__
    __del__ = __del__()

from threading import currentThread, enumerate, RLock
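The doctests above cover the general API; as a small aside (assumed usage, assuming the vendored module is importable as cherrypy._cpthreadinglocal), this is the pattern the backport enables inside CherryPy -- one object whose attributes are isolated per thread:

    import threading
    from cherrypy._cpthreadinglocal import local

    request_state = local()

    def handle(name):
        request_state.user = name        # visible only in this thread
        assert request_state.user == name

    threads = [threading.Thread(target=handle, args=('u%d' % i,))
               for i in range(4)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()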
529
lib/cherrypy/_cptools.py
Normal file
@@ -0,0 +1,529 @@
"""CherryPy tools. A "tool" is any helper, adapted to CP.

Tools are usually designed to be used in a variety of ways (although some
may only offer one if they choose):

    Library calls
        All tools are callables that can be used wherever needed.
        The arguments are straightforward and should be detailed within the
        docstring.

    Function decorators
        All tools, when called, may be used as decorators which configure
        individual CherryPy page handlers (methods on the CherryPy tree).
        That is, "@tools.anytool()" should "turn on" the tool via the
        decorated function's _cp_config attribute.

    CherryPy config
        If a tool exposes a "_setup" callable, it will be called
        once per Request (if the feature is "turned on" via config).

Tools may be implemented as any object with a namespace. The builtins
are generally either modules or instances of the tools.Tool class.
"""

import sys
import warnings

import cherrypy


def _getargs(func):
    """Return the names of all static arguments to the given function."""
    # Use this instead of importing inspect for less mem overhead.
    import types
    if sys.version_info >= (3, 0):
        if isinstance(func, types.MethodType):
            func = func.__func__
        co = func.__code__
    else:
        if isinstance(func, types.MethodType):
            func = func.im_func
        co = func.func_code
    return co.co_varnames[:co.co_argcount]


_attr_error = (
    "CherryPy Tools cannot be turned on directly. Instead, turn them "
    "on via config, or use them as decorators on your page handlers."
)


class Tool(object):

    """A registered function for use with CherryPy request-processing hooks.

    help(tool.callable) should give you more information about this Tool.
    """

    namespace = "tools"

    def __init__(self, point, callable, name=None, priority=50):
        self._point = point
        self.callable = callable
        self._name = name
        self._priority = priority
        self.__doc__ = self.callable.__doc__
        self._setargs()

    def _get_on(self):
        raise AttributeError(_attr_error)

    def _set_on(self, value):
        raise AttributeError(_attr_error)
    on = property(_get_on, _set_on)

    def _setargs(self):
        """Copy func parameter names to obj attributes."""
        try:
            for arg in _getargs(self.callable):
                setattr(self, arg, None)
        except (TypeError, AttributeError):
            if hasattr(self.callable, "__call__"):
                for arg in _getargs(self.callable.__call__):
                    setattr(self, arg, None)
        # IronPython 1.0 raises NotImplementedError because
        # inspect.getargspec tries to access Python bytecode
        # in co_code attribute.
        except NotImplementedError:
            pass
        # IronPython 1B1 may raise IndexError in some cases,
        # but if we trap it here it doesn't prevent CP from working.
        except IndexError:
            pass

    def _merged_args(self, d=None):
        """Return a dict of configuration entries for this Tool."""
        if d:
            conf = d.copy()
        else:
            conf = {}

        tm = cherrypy.serving.request.toolmaps[self.namespace]
        if self._name in tm:
            conf.update(tm[self._name])

        if "on" in conf:
            del conf["on"]

        return conf

    def __call__(self, *args, **kwargs):
        """Compile-time decorator (turn on the tool in config).

        For example::

            @tools.proxy()
            def whats_my_base(self):
                return cherrypy.request.base
            whats_my_base.exposed = True
        """
        if args:
            raise TypeError("The %r Tool does not accept positional "
                            "arguments; you must use keyword arguments."
                            % self._name)

        def tool_decorator(f):
            if not hasattr(f, "_cp_config"):
                f._cp_config = {}
            subspace = self.namespace + "." + self._name + "."
            f._cp_config[subspace + "on"] = True
            for k, v in kwargs.items():
                f._cp_config[subspace + k] = v
            return f
        return tool_decorator

    def _setup(self):
        """Hook this tool into cherrypy.request.

        The standard CherryPy request object will automatically call this
        method when the tool is "turned on" in config.
        """
        conf = self._merged_args()
        p = conf.pop("priority", None)
        if p is None:
            p = getattr(self.callable, "priority", self._priority)
        cherrypy.serving.request.hooks.attach(self._point, self.callable,
                                              priority=p, **conf)


class HandlerTool(Tool):

    """Tool which is called 'before main', that may skip normal handlers.

    If the tool successfully handles the request (by setting response.body),
    it should return True. This will cause CherryPy to skip any 'normal' page
    handler. If the tool did not handle the request, it should return False
    to tell CherryPy to continue on and call the normal page handler. If the
    tool is declared AS a page handler (see the 'handler' method), returning
    False will raise NotFound.
    """

    def __init__(self, callable, name=None):
        Tool.__init__(self, 'before_handler', callable, name)

    def handler(self, *args, **kwargs):
        """Use this tool as a CherryPy page handler.

        For example::

            class Root:
                nav = tools.staticdir.handler(section="/nav", dir="nav",
                                              root=absDir)
        """
        def handle_func(*a, **kw):
            handled = self.callable(*args, **self._merged_args(kwargs))
            if not handled:
                raise cherrypy.NotFound()
            return cherrypy.serving.response.body
        handle_func.exposed = True
        return handle_func

    def _wrapper(self, **kwargs):
        if self.callable(**kwargs):
            cherrypy.serving.request.handler = None

    def _setup(self):
        """Hook this tool into cherrypy.request.

        The standard CherryPy request object will automatically call this
        method when the tool is "turned on" in config.
        """
        conf = self._merged_args()
        p = conf.pop("priority", None)
        if p is None:
            p = getattr(self.callable, "priority", self._priority)
        cherrypy.serving.request.hooks.attach(self._point, self._wrapper,
                                              priority=p, **conf)


class HandlerWrapperTool(Tool):

    """Tool which wraps request.handler in a provided wrapper function.

    The 'newhandler' arg must be a handler wrapper function that takes a
    'next_handler' argument, plus ``*args`` and ``**kwargs``. Like all
    page handler functions, it must return an iterable for use as
    cherrypy.response.body.

    For example, to allow your 'inner' page handlers to return dicts
    which then get interpolated into a template::

        def interpolator(next_handler, *args, **kwargs):
            filename = cherrypy.request.config.get('template')
            cherrypy.response.template = env.get_template(filename)
            response_dict = next_handler(*args, **kwargs)
            return cherrypy.response.template.render(**response_dict)
        cherrypy.tools.jinja = HandlerWrapperTool(interpolator)
    """

    def __init__(self, newhandler, point='before_handler', name=None,
                 priority=50):
        self.newhandler = newhandler
        self._point = point
        self._name = name
        self._priority = priority

    def callable(self, *args, **kwargs):
        innerfunc = cherrypy.serving.request.handler

        def wrap(*args, **kwargs):
            return self.newhandler(innerfunc, *args, **kwargs)
        cherrypy.serving.request.handler = wrap


class ErrorTool(Tool):

    """Tool which is used to replace the default request.error_response."""

    def __init__(self, callable, name=None):
        Tool.__init__(self, None, callable, name)

    def _wrapper(self):
        self.callable(**self._merged_args())

    def _setup(self):
        """Hook this tool into cherrypy.request.

        The standard CherryPy request object will automatically call this
        method when the tool is "turned on" in config.
        """
        cherrypy.serving.request.error_response = self._wrapper


# Builtin tools #

from cherrypy.lib import cptools, encoding, auth, static, jsontools
from cherrypy.lib import sessions as _sessions, xmlrpcutil as _xmlrpc
from cherrypy.lib import caching as _caching
from cherrypy.lib import auth_basic, auth_digest


class SessionTool(Tool):

    """Session Tool for CherryPy.

    sessions.locking
        When 'implicit' (the default), the session will be locked for you,
        just before running the page handler.

        When 'early', the session will be locked before reading the request
        body. This is off by default for safety reasons; for example,
        a large upload would block the session, denying an AJAX
        progress meter
        (`issue <https://bitbucket.org/cherrypy/cherrypy/issue/630>`_).

        When 'explicit' (or any other value), you need to call
        cherrypy.session.acquire_lock() yourself before using
        session data.
    """

    def __init__(self):
        # _sessions.init must be bound after headers are read
        Tool.__init__(self, 'before_request_body', _sessions.init)

    def _lock_session(self):
        cherrypy.serving.session.acquire_lock()

    def _setup(self):
        """Hook this tool into cherrypy.request.

        The standard CherryPy request object will automatically call this
        method when the tool is "turned on" in config.
        """
        hooks = cherrypy.serving.request.hooks

        conf = self._merged_args()

        p = conf.pop("priority", None)
        if p is None:
            p = getattr(self.callable, "priority", self._priority)

        hooks.attach(self._point, self.callable, priority=p, **conf)

        locking = conf.pop('locking', 'implicit')
        if locking == 'implicit':
            hooks.attach('before_handler', self._lock_session)
        elif locking == 'early':
            # Lock before the request body (but after _sessions.init runs!)
            hooks.attach('before_request_body', self._lock_session,
                         priority=60)
        else:
            # Don't lock
            pass

        hooks.attach('before_finalize', _sessions.save)
        hooks.attach('on_end_request', _sessions.close)

    def regenerate(self):
        """Drop the current session and make a new one (with a new id)."""
        sess = cherrypy.serving.session
        sess.regenerate()

        # Grab cookie-relevant tool args
        conf = dict([(k, v) for k, v in self._merged_args().items()
                     if k in ('path', 'path_header', 'name', 'timeout',
                              'domain', 'secure')])
        _sessions.set_response_cookie(**conf)


class XMLRPCController(object):

    """A Controller (page handler collection) for XML-RPC.

    To use it, have your controllers subclass this base class (it will
    turn on the tool for you).

    You can also supply the following optional config entries::

        tools.xmlrpc.encoding: 'utf-8'
        tools.xmlrpc.allow_none: 0

    XML-RPC is a rather discontinuous layer over HTTP; dispatching to the
    appropriate handler must first be performed according to the URL, and
    then a second dispatch step must take place according to the RPC method
    specified in the request body. It also allows a superfluous "/RPC2"
    prefix in the URL, supplies its own handler args in the body, and
    requires a 200 OK "Fault" response instead of 404 when the desired
    method is not found.

    Therefore, XML-RPC cannot be implemented for CherryPy via a Tool alone.
    This Controller acts as the dispatch target for the first half (based
    on the URL); it then reads the RPC method from the request body and
    does its own second dispatch step based on that method. It also reads
    body params, and returns a Fault on error.

    The XMLRPCDispatcher strips any /RPC2 prefix; if you aren't using /RPC2
    in your URLs, you can safely skip turning on the XMLRPCDispatcher.
    Otherwise, you need to declare it in config::

        request.dispatch: cherrypy.dispatch.XMLRPCDispatcher()
    """

    # Note we're hard-coding this into the 'tools' namespace. We could do
    # a huge amount of work to make it relocatable, but the only reason why
    # would be if someone actually disabled the default_toolbox. Meh.
    _cp_config = {'tools.xmlrpc.on': True}

    def default(self, *vpath, **params):
        rpcparams, rpcmethod = _xmlrpc.process_body()

        subhandler = self
        for attr in str(rpcmethod).split('.'):
            subhandler = getattr(subhandler, attr, None)

        if subhandler and getattr(subhandler, "exposed", False):
            body = subhandler(*(vpath + rpcparams), **params)

        else:
            # https://bitbucket.org/cherrypy/cherrypy/issue/533
            # if a method is not found, an xmlrpclib.Fault should be returned
            # raising an exception here will do that; see
            # cherrypy.lib.xmlrpcutil.on_error
            raise Exception('method "%s" is not supported' % attr)

        conf = cherrypy.serving.request.toolmaps['tools'].get("xmlrpc", {})
        _xmlrpc.respond(body,
                        conf.get('encoding', 'utf-8'),
                        conf.get('allow_none', 0))
        return cherrypy.serving.response.body
    default.exposed = True


class SessionAuthTool(HandlerTool):

    def _setargs(self):
        for name in dir(cptools.SessionAuth):
            if not name.startswith("__"):
                setattr(self, name, None)


class CachingTool(Tool):

    """Caching Tool for CherryPy."""

    def _wrapper(self, **kwargs):
        request = cherrypy.serving.request
        if _caching.get(**kwargs):
            request.handler = None
        else:
            if request.cacheable:
                # Note the devious technique here of adding hooks on the fly
                request.hooks.attach('before_finalize', _caching.tee_output,
                                     priority=90)
    _wrapper.priority = 20

    def _setup(self):
        """Hook caching into cherrypy.request."""
        conf = self._merged_args()

        p = conf.pop("priority", None)
        cherrypy.serving.request.hooks.attach('before_handler', self._wrapper,
                                              priority=p, **conf)


class Toolbox(object):

    """A collection of Tools.

    This object also functions as a config namespace handler for itself.
    Custom toolboxes should be added to each Application's toolboxes dict.
    """

    def __init__(self, namespace):
        self.namespace = namespace

    def __setattr__(self, name, value):
        # If the Tool._name is None, supply it from the attribute name.
        if isinstance(value, Tool):
            if value._name is None:
                value._name = name
            value.namespace = self.namespace
        object.__setattr__(self, name, value)

    def __enter__(self):
        """Populate request.toolmaps from tools specified in config."""
        cherrypy.serving.request.toolmaps[self.namespace] = map = {}

        def populate(k, v):
            toolname, arg = k.split(".", 1)
            bucket = map.setdefault(toolname, {})
            bucket[arg] = v
        return populate

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Run tool._setup() for each tool in our toolmap."""
        map = cherrypy.serving.request.toolmaps.get(self.namespace)
        if map:
            for name, settings in map.items():
                if settings.get("on", False):
                    tool = getattr(self, name)
                    tool._setup()


class DeprecatedTool(Tool):

    _name = None
    warnmsg = "This Tool is deprecated."

    def __init__(self, point, warnmsg=None):
        self.point = point
        if warnmsg is not None:
            self.warnmsg = warnmsg

    def __call__(self, *args, **kwargs):
        warnings.warn(self.warnmsg)

        def tool_decorator(f):
            return f
        return tool_decorator

    def _setup(self):
        warnings.warn(self.warnmsg)


default_toolbox = _d = Toolbox("tools")
_d.session_auth = SessionAuthTool(cptools.session_auth)
_d.allow = Tool('on_start_resource', cptools.allow)
_d.proxy = Tool('before_request_body', cptools.proxy, priority=30)
_d.response_headers = Tool('on_start_resource', cptools.response_headers)
_d.log_tracebacks = Tool('before_error_response', cptools.log_traceback)
_d.log_headers = Tool('before_error_response', cptools.log_request_headers)
_d.log_hooks = Tool('on_end_request', cptools.log_hooks, priority=100)
_d.err_redirect = ErrorTool(cptools.redirect)
_d.etags = Tool('before_finalize', cptools.validate_etags, priority=75)
_d.decode = Tool('before_request_body', encoding.decode)
# the order of encoding, gzip, caching is important
_d.encode = Tool('before_handler', encoding.ResponseEncoder, priority=70)
_d.gzip = Tool('before_finalize', encoding.gzip, priority=80)
_d.staticdir = HandlerTool(static.staticdir)
_d.staticfile = HandlerTool(static.staticfile)
_d.sessions = SessionTool()
_d.xmlrpc = ErrorTool(_xmlrpc.on_error)
_d.caching = CachingTool('before_handler', _caching.get, 'caching')
_d.expires = Tool('before_finalize', _caching.expires)
_d.tidy = DeprecatedTool(
    'before_finalize',
    "The tidy tool has been removed from the standard distribution of "
    "CherryPy. The most recent version can be found at "
    "http://tools.cherrypy.org/browser.")
_d.nsgmls = DeprecatedTool(
    'before_finalize',
    "The nsgmls tool has been removed from the standard distribution of "
    "CherryPy. The most recent version can be found at "
    "http://tools.cherrypy.org/browser.")
_d.ignore_headers = Tool('before_request_body', cptools.ignore_headers)
_d.referer = Tool('before_request_body', cptools.referer)
_d.basic_auth = Tool('on_start_resource', auth.basic_auth)
_d.digest_auth = Tool('on_start_resource', auth.digest_auth)
_d.trailing_slash = Tool('before_handler', cptools.trailing_slash, priority=60)
_d.flatten = Tool('before_finalize', cptools.flatten)
_d.accept = Tool('on_start_resource', cptools.accept)
_d.redirect = Tool('on_start_resource', cptools.redirect)
_d.autovary = Tool('on_start_resource', cptools.autovary, priority=0)
_d.json_in = Tool('before_request_body', jsontools.json_in, priority=30)
_d.json_out = Tool('before_handler', jsontools.json_out, priority=30)
_d.auth_basic = Tool('before_handler', auth_basic.basic_auth, priority=1)
_d.auth_digest = Tool('before_handler', auth_digest.digest_auth, priority=1)

del _d, cptools, encoding, auth, static
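As a hedged sketch of the Tool/Toolbox machinery above (not from this diff; the header name and tool name are made up), registering a custom tool in the default toolbox and enabling it per handler:

    import cherrypy

    def add_powered_by():
        # Runs as a 'before_finalize' hook on each request it is enabled for.
        cherrypy.serving.response.headers['X-Powered-By'] = 'headphones'

    cherrypy.tools.powered_by = cherrypy.Tool('before_finalize', add_powered_by)

    class Root(object):
        @cherrypy.tools.powered_by()   # decorator form: populates _cp_config
        def index(self):
            return "Hello"
        index.exposed = True

    # Config form, equivalent to the decorator:
    # {'/': {'tools.powered_by.on': True}}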
299
lib/cherrypy/_cptree.py
Normal file
@@ -0,0 +1,299 @@
"""CherryPy Application and Tree objects."""

import os

import cherrypy
from cherrypy._cpcompat import ntou, py3k
from cherrypy import _cpconfig, _cplogging, _cprequest, _cpwsgi, tools
from cherrypy.lib import httputil


class Application(object):

    """A CherryPy Application.

    Servers and gateways should not instantiate Request objects directly.
    Instead, they should ask an Application object for a request object.

    An instance of this class may also be used as a WSGI callable
    (WSGI application object) for itself.
    """

    root = None
    """The top-most container of page handlers for this app. Handlers should
    be arranged in a hierarchy of attributes, matching the expected URI
    hierarchy; the default dispatcher then searches this hierarchy for a
    matching handler. When using a dispatcher other than the default,
    this value may be None."""

    config = {}
    """A dict of {path: pathconf} pairs, where 'pathconf' is itself a dict
    of {key: value} pairs."""

    namespaces = _cpconfig.NamespaceSet()
    toolboxes = {'tools': cherrypy.tools}

    log = None
    """A LogManager instance. See _cplogging."""

    wsgiapp = None
    """A CPWSGIApp instance. See _cpwsgi."""

    request_class = _cprequest.Request
    response_class = _cprequest.Response

    relative_urls = False

    def __init__(self, root, script_name="", config=None):
        self.log = _cplogging.LogManager(id(self), cherrypy.log.logger_root)
        self.root = root
        self.script_name = script_name
        self.wsgiapp = _cpwsgi.CPWSGIApp(self)

        self.namespaces = self.namespaces.copy()
        self.namespaces["log"] = lambda k, v: setattr(self.log, k, v)
        self.namespaces["wsgi"] = self.wsgiapp.namespace_handler

        self.config = self.__class__.config.copy()
        if config:
            self.merge(config)

    def __repr__(self):
        return "%s.%s(%r, %r)" % (self.__module__, self.__class__.__name__,
                                  self.root, self.script_name)

    script_name_doc = """The URI "mount point" for this app. A mount point
    is that portion of the URI which is constant for all URIs that are
    serviced by this application; it does not include scheme, host, or proxy
    ("virtual host") portions of the URI.

    For example, if script_name is "/my/cool/app", then the URL
    "http://www.example.com/my/cool/app/page1" might be handled by a
    "page1" method on the root object.

    The value of script_name MUST NOT end in a slash. If the script_name
    refers to the root of the URI, it MUST be an empty string (not "/").

    If script_name is explicitly set to None, then the script_name will be
    provided for each call from request.wsgi_environ['SCRIPT_NAME'].
    """

    def _get_script_name(self):
        if self._script_name is not None:
            return self._script_name

        # A `_script_name` with a value of None signals that the script name
        # should be pulled from WSGI environ.
        return cherrypy.serving.request.wsgi_environ['SCRIPT_NAME'].rstrip("/")

    def _set_script_name(self, value):
        if value:
            value = value.rstrip("/")
        self._script_name = value
    script_name = property(fget=_get_script_name, fset=_set_script_name,
                           doc=script_name_doc)

    def merge(self, config):
        """Merge the given config into self.config."""
        _cpconfig.merge(self.config, config)

        # Handle namespaces specified in config.
        self.namespaces(self.config.get("/", {}))

    def find_config(self, path, key, default=None):
        """Return the most-specific value for key along path, or default."""
        trail = path or "/"
        while trail:
            nodeconf = self.config.get(trail, {})

            if key in nodeconf:
                return nodeconf[key]

            lastslash = trail.rfind("/")
            if lastslash == -1:
                break
            elif lastslash == 0 and trail != "/":
                trail = "/"
            else:
                trail = trail[:lastslash]

        return default

    def get_serving(self, local, remote, scheme, sproto):
        """Create and return a Request and Response object."""
        req = self.request_class(local, remote, scheme, sproto)
        req.app = self

        for name, toolbox in self.toolboxes.items():
            req.namespaces[name] = toolbox

        resp = self.response_class()
        cherrypy.serving.load(req, resp)
        cherrypy.engine.publish('acquire_thread')
        cherrypy.engine.publish('before_request')

        return req, resp

    def release_serving(self):
        """Release the current serving (request and response)."""
        req = cherrypy.serving.request

        cherrypy.engine.publish('after_request')

        try:
            req.close()
        except:
            cherrypy.log(traceback=True, severity=40)

        cherrypy.serving.clear()

    def __call__(self, environ, start_response):
        return self.wsgiapp(environ, start_response)


class Tree(object):

    """A registry of CherryPy applications, mounted at diverse points.

    An instance of this class may also be used as a WSGI callable
    (WSGI application object), in which case it dispatches to all
    mounted apps.
    """

    apps = {}
    """
    A dict of the form {script name: application}, where "script name"
    is a string declaring the URI mount point (no trailing slash), and
    "application" is an instance of cherrypy.Application (or an arbitrary
    WSGI callable if you happen to be using a WSGI server)."""

    def __init__(self):
        self.apps = {}

    def mount(self, root, script_name="", config=None):
        """Mount a new app from a root object, script_name, and config.

        root
            An instance of a "controller class" (a collection of page
            handler methods) which represents the root of the application.
            This may also be an Application instance, or None if using
            a dispatcher other than the default.

        script_name
            A string containing the "mount point" of the application.
            This should start with a slash, and be the path portion of the
            URL at which to mount the given root. For example, if root.index()
            will handle requests to "http://www.example.com:8080/dept/app1/",
            then the script_name argument would be "/dept/app1".

            It MUST NOT end in a slash. If the script_name refers to the
            root of the URI, it MUST be an empty string (not "/").

        config
            A file or dict containing application config.
        """
        if script_name is None:
            raise TypeError(
                "The 'script_name' argument may not be None. Application "
                "objects may, however, possess a script_name of None (in "
                "order to inspect the WSGI environ for SCRIPT_NAME upon each "
                "request). You cannot mount such Applications on this Tree; "
                "you must pass them to a WSGI server interface directly.")

        # Next line both 1) strips trailing slash and 2) maps "/" -> "".
        script_name = script_name.rstrip("/")

        if isinstance(root, Application):
            app = root
            if script_name != "" and script_name != app.script_name:
                raise ValueError(
                    "Cannot specify a different script name and pass an "
                    "Application instance to cherrypy.mount")
            script_name = app.script_name
        else:
            app = Application(root, script_name)

            # If mounted at "", add favicon.ico
            if (script_name == "" and root is not None
                    and not hasattr(root, "favicon_ico")):
                favicon = os.path.join(os.getcwd(), os.path.dirname(__file__),
                                       "favicon.ico")
                root.favicon_ico = tools.staticfile.handler(favicon)

        if config:
            app.merge(config)

        self.apps[script_name] = app

        return app

    def graft(self, wsgi_callable, script_name=""):
        """Mount a wsgi callable at the given script_name."""
        # Next line both 1) strips trailing slash and 2) maps "/" -> "".
        script_name = script_name.rstrip("/")
        self.apps[script_name] = wsgi_callable

    def script_name(self, path=None):
        """The script_name of the app at the given path, or None.

        If path is None, cherrypy.request is used.
        """
        if path is None:
            try:
                request = cherrypy.serving.request
                path = httputil.urljoin(request.script_name,
                                        request.path_info)
            except AttributeError:
                return None

        while True:
            if path in self.apps:
                return path

            if path == "":
                return None

            # Move one node up the tree and try again.
            path = path[:path.rfind("/")]

    def __call__(self, environ, start_response):
        # If you're calling this, then you're probably setting SCRIPT_NAME
        # to '' (some WSGI servers always set SCRIPT_NAME to '').
        # Try to look up the app using the full path.
        env1x = environ
        if environ.get(ntou('wsgi.version')) == (ntou('u'), 0):
            env1x = _cpwsgi.downgrade_wsgi_ux_to_1x(environ)
        path = httputil.urljoin(env1x.get('SCRIPT_NAME', ''),
                                env1x.get('PATH_INFO', ''))
        sn = self.script_name(path or "/")
        if sn is None:
            start_response('404 Not Found', [])
            return []

        app = self.apps[sn]

        # Correct the SCRIPT_NAME and PATH_INFO environ entries.
        environ = environ.copy()
        if not py3k:
            if environ.get(ntou('wsgi.version')) == (ntou('u'), 0):
                # Python 2/WSGI u.0: all strings MUST be of type unicode
                enc = environ[ntou('wsgi.url_encoding')]
                environ[ntou('SCRIPT_NAME')] = sn.decode(enc)
                environ[ntou('PATH_INFO')] = path[
                    len(sn.rstrip("/")):].decode(enc)
            else:
                # Python 2/WSGI 1.x: all strings MUST be of type str
                environ['SCRIPT_NAME'] = sn
                environ['PATH_INFO'] = path[len(sn.rstrip("/")):]
        else:
            if environ.get(ntou('wsgi.version')) == (ntou('u'), 0):
                # Python 3/WSGI u.0: all strings MUST be full unicode
                environ['SCRIPT_NAME'] = sn
                environ['PATH_INFO'] = path[len(sn.rstrip("/")):]
            else:
                # Python 3/WSGI 1.x: all strings MUST be ISO-8859-1 str
                environ['SCRIPT_NAME'] = sn.encode(
                    'utf-8').decode('ISO-8859-1')
                environ['PATH_INFO'] = path[
                    len(sn.rstrip("/")):].encode('utf-8').decode('ISO-8859-1')
        return app(environ, start_response)
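A brief sketch of the Tree API above (illustrative layout, not from this diff): mounting a CherryPy root object at the site root and grafting a plain WSGI callable beside it.

    import cherrypy

    class Root(object):
        def index(self):
            return "headphones UI"
        index.exposed = True

    def raw_wsgi_app(environ, start_response):
        start_response('200 OK', [('Content-Type', 'text/plain')])
        return ['raw wsgi\n']

    cherrypy.tree.mount(Root(), script_name='')            # "" == site root
    cherrypy.tree.graft(raw_wsgi_app, script_name='/raw')  # arbitrary WSGI callable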
438
lib/cherrypy/_cpwsgi.py
Normal file
@@ -0,0 +1,438 @@
"""WSGI interface (see PEP 333 and 3333).

Note that WSGI environ keys and values are 'native strings'; that is,
whatever the type of "" is. For Python 2, that's a byte string; for Python 3,
it's a unicode string. But PEP 3333 says: "even if Python's str type is
actually Unicode "under the hood", the content of native strings must
still be translatable to bytes via the Latin-1 encoding!"
"""

import sys as _sys

import cherrypy as _cherrypy
from cherrypy._cpcompat import BytesIO, bytestr, ntob, ntou, py3k, unicodestr
from cherrypy import _cperror
from cherrypy.lib import httputil
from cherrypy.lib import is_closable_iterator


def downgrade_wsgi_ux_to_1x(environ):
    """Return a new environ dict for WSGI 1.x from the given WSGI u.x environ.
    """
    env1x = {}

    url_encoding = environ[ntou('wsgi.url_encoding')]
    for k, v in list(environ.items()):
        if k in [ntou('PATH_INFO'), ntou('SCRIPT_NAME'), ntou('QUERY_STRING')]:
            v = v.encode(url_encoding)
        elif isinstance(v, unicodestr):
            v = v.encode('ISO-8859-1')
        env1x[k.encode('ISO-8859-1')] = v

    return env1x


class VirtualHost(object):

    """Select a different WSGI application based on the Host header.

    This can be useful when running multiple sites within one CP server.
    It allows several domains to point to different applications.
    For example::

        root = Root()
        RootApp = cherrypy.Application(root)
        Domain2App = cherrypy.Application(root)
        SecureApp = cherrypy.Application(Secure())

        vhost = cherrypy._cpwsgi.VirtualHost(RootApp,
            domains={'www.domain2.example': Domain2App,
                     'www.domain2.example:443': SecureApp,
                     })

        cherrypy.tree.graft(vhost)
    """
    default = None
    """Required. The default WSGI application."""

    use_x_forwarded_host = True
    """If True (the default), any "X-Forwarded-Host"
    request header will be used instead of the "Host" header. This
    is commonly added by HTTP servers (such as Apache) when proxying."""

    domains = {}
    """A dict of {host header value: application} pairs.
    The incoming "Host" request header is looked up in this dict,
    and, if a match is found, the corresponding WSGI application
    will be called instead of the default. Note that you often need
    separate entries for "example.com" and "www.example.com".
    In addition, "Host" headers may contain the port number.
    """

    def __init__(self, default, domains=None, use_x_forwarded_host=True):
        self.default = default
        self.domains = domains or {}
        self.use_x_forwarded_host = use_x_forwarded_host

    def __call__(self, environ, start_response):
        domain = environ.get('HTTP_HOST', '')
        if self.use_x_forwarded_host:
            domain = environ.get("HTTP_X_FORWARDED_HOST", domain)

        nextapp = self.domains.get(domain)
        if nextapp is None:
            nextapp = self.default
        return nextapp(environ, start_response)


class InternalRedirector(object):

    """WSGI middleware that handles raised cherrypy.InternalRedirect."""

    def __init__(self, nextapp, recursive=False):
        self.nextapp = nextapp
        self.recursive = recursive

    def __call__(self, environ, start_response):
        redirections = []
        while True:
            environ = environ.copy()
            try:
                return self.nextapp(environ, start_response)
            except _cherrypy.InternalRedirect:
                ir = _sys.exc_info()[1]
                sn = environ.get('SCRIPT_NAME', '')
                path = environ.get('PATH_INFO', '')
                qs = environ.get('QUERY_STRING', '')

                # Add the *previous* path_info + qs to redirections.
                old_uri = sn + path
                if qs:
                    old_uri += "?" + qs
                redirections.append(old_uri)

                if not self.recursive:
                    # Check to see if the new URI has been redirected to
                    # already
                    new_uri = sn + ir.path
                    if ir.query_string:
                        new_uri += "?" + ir.query_string
                    if new_uri in redirections:
                        ir.request.close()
                        raise RuntimeError("InternalRedirector visited the "
                                           "same URL twice: %r" % new_uri)

                # Munge the environment and try again.
                environ['REQUEST_METHOD'] = "GET"
                environ['PATH_INFO'] = ir.path
                environ['QUERY_STRING'] = ir.query_string
                environ['wsgi.input'] = BytesIO()
                environ['CONTENT_LENGTH'] = "0"
                environ['cherrypy.previous_request'] = ir.request


class ExceptionTrapper(object):

    """WSGI middleware that traps exceptions."""

    def __init__(self, nextapp, throws=(KeyboardInterrupt, SystemExit)):
        self.nextapp = nextapp
        self.throws = throws

    def __call__(self, environ, start_response):
        return _TrappedResponse(
            self.nextapp,
            environ,
            start_response,
            self.throws
        )


class _TrappedResponse(object):

    response = iter([])

    def __init__(self, nextapp, environ, start_response, throws):
        self.nextapp = nextapp
        self.environ = environ
        self.start_response = start_response
        self.throws = throws
        self.started_response = False
        self.response = self.trap(
            self.nextapp, self.environ, self.start_response)
        self.iter_response = iter(self.response)

    def __iter__(self):
        self.started_response = True
        return self

    if py3k:
        def __next__(self):
            return self.trap(next, self.iter_response)
    else:
        def next(self):
            return self.trap(self.iter_response.next)

    def close(self):
        if hasattr(self.response, 'close'):
            self.response.close()

    def trap(self, func, *args, **kwargs):
        try:
            return func(*args, **kwargs)
        except self.throws:
            raise
        except StopIteration:
            raise
        except:
            tb = _cperror.format_exc()
            #print('trapped (started %s):' % self.started_response, tb)
            _cherrypy.log(tb, severity=40)
            if not _cherrypy.request.show_tracebacks:
                tb = ""
            s, h, b = _cperror.bare_error(tb)
            if py3k:
                # What fun.
                s = s.decode('ISO-8859-1')
                h = [(k.decode('ISO-8859-1'), v.decode('ISO-8859-1'))
                     for k, v in h]
            if self.started_response:
                # Empty our iterable (so future calls raise StopIteration)
                self.iter_response = iter([])
            else:
                self.iter_response = iter(b)

            try:
                self.start_response(s, h, _sys.exc_info())
            except:
                # "The application must not trap any exceptions raised by
                # start_response, if it called start_response with exc_info.
                # Instead, it should allow such exceptions to propagate
                # back to the server or gateway."
                # But we still log and call close() to clean up ourselves.
                _cherrypy.log(traceback=True, severity=40)
                raise

            if self.started_response:
                return ntob("").join(b)
            else:
                return b


# WSGI-to-CP Adapter #


class AppResponse(object):

    """WSGI response iterable for CherryPy applications."""

    def __init__(self, environ, start_response, cpapp):
        self.cpapp = cpapp
        try:
            if not py3k:
                if environ.get(ntou('wsgi.version')) == (ntou('u'), 0):
                    environ = downgrade_wsgi_ux_to_1x(environ)
            self.environ = environ
            self.run()

            r = _cherrypy.serving.response

            outstatus = r.output_status
            if not isinstance(outstatus, bytestr):
                raise TypeError("response.output_status is not a byte string.")

            outheaders = []
            for k, v in r.header_list:
                if not isinstance(k, bytestr):
                    raise TypeError(
                        "response.header_list key %r is not a byte string." %
                        k)
                if not isinstance(v, bytestr):
                    raise TypeError(
                        "response.header_list value %r is not a byte string." %
                        v)
                outheaders.append((k, v))

            if py3k:
                # According to PEP 3333, when using Python 3, the response
                # status and headers must be bytes masquerading as unicode;
                # that is, they must be of type "str" but are restricted to
                # code points in the "latin-1" set.
                outstatus = outstatus.decode('ISO-8859-1')
                outheaders = [(k.decode('ISO-8859-1'), v.decode('ISO-8859-1'))
                              for k, v in outheaders]

            self.iter_response = iter(r.body)
            self.write = start_response(outstatus, outheaders)
        except:
            self.close()
            raise

    def __iter__(self):
        return self

    if py3k:
        def __next__(self):
            return next(self.iter_response)
    else:
        def next(self):
            return self.iter_response.next()

    def close(self):
        """Close and de-reference the current request and response. (Core)"""
        streaming = _cherrypy.serving.response.stream
        self.cpapp.release_serving()

        # We avoid the expense of examining the iterator to see if it's
        # closable unless we are streaming the response, as that's the
        # only situation where we are going to have an iterator which
        # may not have been exhausted yet.
        if streaming and is_closable_iterator(self.iter_response):
            iter_close = self.iter_response.close
            try:
                iter_close()
            except Exception:
                _cherrypy.log(traceback=True, severity=40)

    def run(self):
        """Create a Request object using environ."""
        env = self.environ.get

        local = httputil.Host('', int(env('SERVER_PORT', 80)),
                              env('SERVER_NAME', ''))
        remote = httputil.Host(env('REMOTE_ADDR', ''),
                               int(env('REMOTE_PORT', -1) or -1),
                               env('REMOTE_HOST', ''))
        scheme = env('wsgi.url_scheme')
        sproto = env('ACTUAL_SERVER_PROTOCOL', "HTTP/1.1")
        request, resp = self.cpapp.get_serving(local, remote, scheme, sproto)

        # LOGON_USER is served by IIS, and is the name of the
        # user after having been mapped to a local account.
        # Both IIS and Apache set REMOTE_USER, when possible.
        request.login = env('LOGON_USER') or env('REMOTE_USER') or None
        request.multithread = self.environ['wsgi.multithread']
        request.multiprocess = self.environ['wsgi.multiprocess']
        request.wsgi_environ = self.environ
        request.prev = env('cherrypy.previous_request', None)

        meth = self.environ['REQUEST_METHOD']

        path = httputil.urljoin(self.environ.get('SCRIPT_NAME', ''),
                                self.environ.get('PATH_INFO', ''))
        qs = self.environ.get('QUERY_STRING', '')

        if py3k:
            # This isn't perfect; if the given PATH_INFO is in the
            # wrong encoding, it may fail to match the appropriate config
            # section URI. But meh.
            old_enc = self.environ.get('wsgi.url_encoding', 'ISO-8859-1')
            new_enc = self.cpapp.find_config(self.environ.get('PATH_INFO', ''),
                                             "request.uri_encoding", 'utf-8')
            if new_enc.lower() != old_enc.lower():
                # Even though the path and qs are unicode, the WSGI server
                # is required by PEP 3333 to coerce them to ISO-8859-1
                # masquerading as unicode. So we have to encode back to
                # bytes and then decode again using the "correct" encoding.
                try:
                    u_path = path.encode(old_enc).decode(new_enc)
                    u_qs = qs.encode(old_enc).decode(new_enc)
                except (UnicodeEncodeError, UnicodeDecodeError):
                    # Just pass them through without transcoding and hope.
                    pass
                else:
                    # Only set transcoded values if they both succeed.
                    path = u_path
                    qs = u_qs

        rproto = self.environ.get('SERVER_PROTOCOL')
        headers = self.translate_headers(self.environ)
        rfile = self.environ['wsgi.input']
        request.run(meth, path, qs, rproto, headers, rfile)

    headerNames = {'HTTP_CGI_AUTHORIZATION': 'Authorization',
                   'CONTENT_LENGTH': 'Content-Length',
                   'CONTENT_TYPE': 'Content-Type',
                   'REMOTE_HOST': 'Remote-Host',
                   'REMOTE_ADDR': 'Remote-Addr',
                   }

    def translate_headers(self, environ):
        """Translate CGI-environ header names to HTTP header names."""
        for cgiName in environ:
            # We assume all incoming header keys are uppercase already.
            if cgiName in self.headerNames:
                yield self.headerNames[cgiName], environ[cgiName]
            elif cgiName[:5] == "HTTP_":
                # Hackish attempt at recovering original header names.
                translatedHeader = cgiName[5:].replace("_", "-")
                yield translatedHeader, environ[cgiName]


class CPWSGIApp(object):

    """A WSGI application object for a CherryPy Application."""

    pipeline = [('ExceptionTrapper', ExceptionTrapper),
                ('InternalRedirector', InternalRedirector),
                ]
    """A list of (name, wsgiapp) pairs. Each 'wsgiapp' MUST be a
    constructor that takes an initial, positional 'nextapp' argument,
    plus optional keyword arguments, and returns a WSGI application
    (that takes environ and start_response arguments). The 'name' can
    be any you choose, and will correspond to keys in self.config."""

    head = None
    """Rather than nest all apps in the pipeline on each call, it's only
    done the first time, and the result is memoized into self.head. Set
    this to None again if you change self.pipeline after calling self."""

    config = {}
    """A dict whose keys match names listed in the pipeline. Each
    value is a further dict which will be passed to the corresponding
    named WSGI callable (from the pipeline) as keyword arguments."""

    response_class = AppResponse
    """The class to instantiate and return as the next app in the WSGI chain.
    """

    def __init__(self, cpapp, pipeline=None):
        self.cpapp = cpapp
        self.pipeline = self.pipeline[:]
        if pipeline:
            self.pipeline.extend(pipeline)
        self.config = self.config.copy()

    def tail(self, environ, start_response):
        """WSGI application callable for the actual CherryPy application.

        You probably shouldn't call this; call self.__call__ instead,
        so that any WSGI middleware in self.pipeline can run first.
        """
        return self.response_class(environ, start_response, self.cpapp)

    def __call__(self, environ, start_response):
        head = self.head
        if head is None:
            # Create and nest the WSGI apps in our pipeline (in reverse
            # order). Then memoize the result in self.head.
            head = self.tail
            for name, callable in self.pipeline[::-1]:
                conf = self.config.get(name, {})
                head = callable(head, **conf)
            self.head = head
        return head(environ, start_response)

    def namespace_handler(self, k, v):
        """Config handler for the 'wsgi' namespace."""
        if k == "pipeline":
            # Note this allows multiple 'wsgi.pipeline' config entries
            # (but each entry will be processed in a 'random' order).
            # It should also allow developers to set default middleware
            # in code (passed to self.__init__) that deployers can add to
            # (but not remove) via config.
            self.pipeline.extend(v)
        elif k == "response_class":
            self.response_class = v
        else:
            name, arg = k.split(".", 1)
            bucket = self.config.setdefault(name, {})
            bucket[arg] = v
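A hedged sketch of extending the CPWSGIApp pipeline described above via the 'wsgi' config namespace (the middleware class and its keyword argument are illustrative only). Per namespace_handler, 'wsgi.pipeline' extends self.pipeline and 'wsgi.timer.label' becomes a keyword argument when the pipeline is nested:

    import time
    import cherrypy

    class TimingMiddleware(object):
        """Matches the pipeline contract: positional nextapp, then kwargs."""
        def __init__(self, nextapp, label='timing'):
            self.nextapp = nextapp
            self.label = label

        def __call__(self, environ, start_response):
            start = time.time()
            try:
                return self.nextapp(environ, start_response)
            finally:
                cherrypy.log('%s: %.3fs' % (self.label, time.time() - start))

    config = {'/': {
        'wsgi.pipeline': [('timer', TimingMiddleware)],  # extends self.pipeline
        'wsgi.timer.label': 'req-time',                  # kwarg via self.config
    }}
    # cherrypy.quickstart(Root(), '/', config)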
70
lib/cherrypy/_cpwsgi_server.py
Normal file
@@ -0,0 +1,70 @@
"""WSGI server interface (see PEP 333). This adds some CP-specific bits to
the framework-agnostic wsgiserver package.
"""
import sys

import cherrypy
from cherrypy import wsgiserver


class CPWSGIServer(wsgiserver.CherryPyWSGIServer):

    """Wrapper for wsgiserver.CherryPyWSGIServer.

    wsgiserver has been designed to not reference CherryPy in any way,
    so that it can be used in other frameworks and applications. Therefore,
    we wrap it here, so we can set our own mount points from cherrypy.tree
    and apply some attributes from config -> cherrypy.server -> wsgiserver.
    """

    def __init__(self, server_adapter=cherrypy.server):
        self.server_adapter = server_adapter
        self.max_request_header_size = (
            self.server_adapter.max_request_header_size or 0
        )
        self.max_request_body_size = (
            self.server_adapter.max_request_body_size or 0
        )

        server_name = (self.server_adapter.socket_host or
                       self.server_adapter.socket_file or
                       None)

        self.wsgi_version = self.server_adapter.wsgi_version
        s = wsgiserver.CherryPyWSGIServer
        s.__init__(self, server_adapter.bind_addr, cherrypy.tree,
                   self.server_adapter.thread_pool,
                   server_name,
                   max=self.server_adapter.thread_pool_max,
                   request_queue_size=self.server_adapter.socket_queue_size,
                   timeout=self.server_adapter.socket_timeout,
                   shutdown_timeout=self.server_adapter.shutdown_timeout,
                   accepted_queue_size=self.server_adapter.accepted_queue_size,
                   accepted_queue_timeout=self.server_adapter.accepted_queue_timeout,
                   )
        self.protocol = self.server_adapter.protocol_version
        self.nodelay = self.server_adapter.nodelay

        if sys.version_info >= (3, 0):
            ssl_module = self.server_adapter.ssl_module or 'builtin'
        else:
            ssl_module = self.server_adapter.ssl_module or 'pyopenssl'
        if self.server_adapter.ssl_context:
            adapter_class = wsgiserver.get_ssl_adapter_class(ssl_module)
            self.ssl_adapter = adapter_class(
                self.server_adapter.ssl_certificate,
                self.server_adapter.ssl_private_key,
                self.server_adapter.ssl_certificate_chain)
            self.ssl_adapter.context = self.server_adapter.ssl_context
        elif self.server_adapter.ssl_certificate:
            adapter_class = wsgiserver.get_ssl_adapter_class(ssl_module)
            self.ssl_adapter = adapter_class(
                self.server_adapter.ssl_certificate,
                self.server_adapter.ssl_private_key,
                self.server_adapter.ssl_certificate_chain)

        self.stats['Enabled'] = getattr(
            self.server_adapter, 'statistics', False)

    def error_log(self, msg="", level=20, traceback=False):
        cherrypy.engine.log(msg, level, traceback)
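As a short, assumed sketch tying this back to the `instance` attribute documented in _cpserver.py above: rather than letting Server.httpserver_from_self() construct the wrapped wsgiserver, the application can build it itself, e.g. to adjust it before start.

    import cherrypy
    from cherrypy import _cpwsgi_server

    # Hand cherrypy.server a pre-built CPWSGIServer to control; configure
    # cherrypy.server (bind_addr, thread_pool, ...) before constructing it.
    cherrypy.server.instance = _cpwsgi_server.CPWSGIServer(cherrypy.server)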