Revert "Migrate as much as possible to pip and requirements.txt"

This reverts commit 982594a4a5.
Author: rembo10
Date:   2016-01-29 15:38:12 +00:00
Parent: 3c015990c5
Commit: 2110eb9855
1195 changed files with 196887 additions and 81 deletions

.gitignore (vendored, 4 lines changed)

@@ -171,7 +171,3 @@ _ReSharper*/
 /logs
 .project
 .pydevproject
-.vs/
-
-#Python virtual environment
-env/

.travis.yml (modified)

@@ -1,24 +1,21 @@
 # Travis CI configuration file
 # http://about.travis-ci.org/docs/
 language: python
 # Available Python versions:
 # http://about.travis-ci.org/docs/user/ci-environment/#Python-VM-images
 python:
-- '2.6'
-- '2.7'
+- "2.6"
+- "2.7"
+# pylint 1.4 does not run under python 2.6
 install:
-- pip install -r requirements.txt
-- pip install pylint
-- pip install pyflakes
-- pip install pep8
-- pip install pyOpenSSL
+- pip install pylint==1.3.1
+- pip install pyflakes
+- pip install pep8
 script:
 - pep8 headphones
 - pylint --rcfile=pylintrc headphones
 - pyflakes headphones
-before_deploy:
-- pip install -r requirements.txt -t lib/
-- find . -name "*.pyc" -delete
-- find ./lib/ -path "*/*\-info/*" -delete
-- zip -r -9 headphones.zip data/ headphones/ init-scripts/ lib/ Headphones.py
-deploy:
-  provider: releases
-  api_key:
-    secure: <GITHUB OATH TOKEN HERE>
-  file: headphones.zip
-  on:
-    tags: true
+- nosetests headphones

lib/MultipartPostHandler.py (new file, 88 lines)

@@ -0,0 +1,88 @@
#!/usr/bin/python
####
# 06/2010 Nic Wolfe <nic@wolfeden.ca>
# 02/2006 Will Holcomb <wholcomb@gmail.com>
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
import urllib
import urllib2
import mimetools, mimetypes
import os, sys
# Controls how sequences are encoded. If true, elements may be given multiple values by
# assigning a sequence.
doseq = 1
class MultipartPostHandler(urllib2.BaseHandler):
handler_order = urllib2.HTTPHandler.handler_order - 10 # needs to run first
def http_request(self, request):
data = request.get_data()
if data is not None and type(data) != str:
v_files = []
v_vars = []
try:
for(key, value) in data.items():
if type(value) in (file, list, tuple):
v_files.append((key, value))
else:
v_vars.append((key, value))
except TypeError:
systype, value, traceback = sys.exc_info()
raise TypeError, "not a valid non-string sequence or mapping object", traceback
if len(v_files) == 0:
data = urllib.urlencode(v_vars, doseq)
else:
boundary, data = MultipartPostHandler.multipart_encode(v_vars, v_files)
contenttype = 'multipart/form-data; boundary=%s' % boundary
if(request.has_header('Content-Type')
and request.get_header('Content-Type').find('multipart/form-data') != 0):
print "Replacing %s with %s" % (request.get_header('content-type'), 'multipart/form-data')
request.add_unredirected_header('Content-Type', contenttype)
request.add_data(data)
return request
@staticmethod
def multipart_encode(vars, files, boundary = None, buffer = None):
if boundary is None:
boundary = mimetools.choose_boundary()
if buffer is None:
buffer = ''
for(key, value) in vars:
buffer += '--%s\r\n' % boundary
buffer += 'Content-Disposition: form-data; name="%s"' % key
buffer += '\r\n\r\n' + value + '\r\n'
for(key, fd) in files:
# allow them to pass in a file or a tuple with name & data
if type(fd) == file:
name_in = fd.name
fd.seek(0)
data_in = fd.read()
elif type(fd) in (tuple, list):
name_in, data_in = fd
filename = os.path.basename(name_in)
contenttype = mimetypes.guess_type(filename)[0] or 'application/octet-stream'
buffer += '--%s\r\n' % boundary
buffer += 'Content-Disposition: form-data; name="%s"; filename="%s"\r\n' % (key, filename)
buffer += 'Content-Type: %s\r\n' % contenttype
# buffer += 'Content-Length: %s\r\n' % file_size
buffer += '\r\n' + data_in + '\r\n'
buffer += '--%s--\r\n\r\n' % boundary
return boundary, buffer
https_request = http_request
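
As a usage sketch of the handler above (the import assumes the module is on sys.path as MultipartPostHandler; the upload URL is hypothetical), installing it into a urllib2 opener lets any file object among the POST parameters be sent as a multipart/form-data part:

import urllib2
from MultipartPostHandler import MultipartPostHandler

# Build an opener whose http_request hook rewrites dict data into a multipart body
opener = urllib2.build_opener(MultipartPostHandler)
params = {'artist': 'headphones',             # plain form field
          'cover': open('cover.jpg', 'rb')}   # file object becomes a file part
response = opener.open('http://example.com/upload', params)  # hypothetical URL
print response.read()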

lib/apscheduler/__init__.py (new file, 5 lines)

@@ -0,0 +1,5 @@
version_info = (3, 0, 1)
version = '3.0.1'
release = '3.0.1'
__version__ = release # PEP 396

lib/apscheduler/events.py (new file, 73 lines)

@@ -0,0 +1,73 @@
__all__ = ('EVENT_SCHEDULER_START', 'EVENT_SCHEDULER_SHUTDOWN', 'EVENT_EXECUTOR_ADDED', 'EVENT_EXECUTOR_REMOVED',
'EVENT_JOBSTORE_ADDED', 'EVENT_JOBSTORE_REMOVED', 'EVENT_ALL_JOBS_REMOVED', 'EVENT_JOB_ADDED',
'EVENT_JOB_REMOVED', 'EVENT_JOB_MODIFIED', 'EVENT_JOB_EXECUTED', 'EVENT_JOB_ERROR', 'EVENT_JOB_MISSED',
'SchedulerEvent', 'JobEvent', 'JobExecutionEvent')
EVENT_SCHEDULER_START = 1
EVENT_SCHEDULER_SHUTDOWN = 2
EVENT_EXECUTOR_ADDED = 4
EVENT_EXECUTOR_REMOVED = 8
EVENT_JOBSTORE_ADDED = 16
EVENT_JOBSTORE_REMOVED = 32
EVENT_ALL_JOBS_REMOVED = 64
EVENT_JOB_ADDED = 128
EVENT_JOB_REMOVED = 256
EVENT_JOB_MODIFIED = 512
EVENT_JOB_EXECUTED = 1024
EVENT_JOB_ERROR = 2048
EVENT_JOB_MISSED = 4096
EVENT_ALL = (EVENT_SCHEDULER_START | EVENT_SCHEDULER_SHUTDOWN | EVENT_JOBSTORE_ADDED | EVENT_JOBSTORE_REMOVED |
EVENT_JOB_ADDED | EVENT_JOB_REMOVED | EVENT_JOB_MODIFIED | EVENT_JOB_EXECUTED |
EVENT_JOB_ERROR | EVENT_JOB_MISSED)
class SchedulerEvent(object):
"""
An event that concerns the scheduler itself.
:ivar code: the type code of this event
:ivar alias: alias of the job store or executor that was added or removed (if applicable)
"""
def __init__(self, code, alias=None):
super(SchedulerEvent, self).__init__()
self.code = code
self.alias = alias
def __repr__(self):
return '<%s (code=%d)>' % (self.__class__.__name__, self.code)
class JobEvent(SchedulerEvent):
"""
An event that concerns a job.
:ivar code: the type code of this event
:ivar job_id: identifier of the job in question
:ivar jobstore: alias of the job store containing the job in question
"""
def __init__(self, code, job_id, jobstore):
super(JobEvent, self).__init__(code)
self.code = code
self.job_id = job_id
self.jobstore = jobstore
class JobExecutionEvent(JobEvent):
"""
An event that concerns the execution of individual jobs.
:ivar scheduled_run_time: the time when the job was scheduled to be run
:ivar retval: the return value of the successfully executed job
:ivar exception: the exception raised by the job
:ivar traceback: a formatted traceback for the exception
"""
def __init__(self, code, job_id, jobstore, scheduled_run_time, retval=None, exception=None, traceback=None):
super(JobExecutionEvent, self).__init__(code, job_id, jobstore)
self.scheduled_run_time = scheduled_run_time
self.retval = retval
self.exception = exception
self.traceback = traceback
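
A brief sketch of how these event classes and bitmask constants are consumed (sched stands for any scheduler instance; add_listener appears later in this commit, in the schedulers/base.py diff):

from apscheduler.events import EVENT_JOB_EXECUTED, EVENT_JOB_ERROR

def on_job_done(event):
    # event is a JobExecutionEvent; its code distinguishes success from failure
    if event.code == EVENT_JOB_ERROR:
        print 'job %s raised %r' % (event.job_id, event.exception)
    else:
        print 'job %s returned %r' % (event.job_id, event.retval)

# The mask ORs together the event codes the listener should receive
sched.add_listener(on_job_done, EVENT_JOB_EXECUTED | EVENT_JOB_ERROR)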

lib/apscheduler/executors/__init__.py (new file, empty)

lib/apscheduler/executors/asyncio.py (new file, 28 lines)

@@ -0,0 +1,28 @@
from __future__ import absolute_import
import sys
from apscheduler.executors.base import BaseExecutor, run_job
class AsyncIOExecutor(BaseExecutor):
"""
Runs jobs in the default executor of the event loop.
Plugin alias: ``asyncio``
"""
def start(self, scheduler, alias):
super(AsyncIOExecutor, self).start(scheduler, alias)
self._eventloop = scheduler._eventloop
def _do_submit_job(self, job, run_times):
def callback(f):
try:
events = f.result()
except:
self._run_job_error(job.id, *sys.exc_info()[1:])
else:
self._run_job_success(job.id, events)
f = self._eventloop.run_in_executor(None, run_job, job, job._jobstore_alias, run_times, self._logger.name)
f.add_done_callback(callback)

lib/apscheduler/executors/base.py (new file, 119 lines)

@@ -0,0 +1,119 @@
from abc import ABCMeta, abstractmethod
from collections import defaultdict
from datetime import datetime, timedelta
from traceback import format_tb
import logging
import sys
from pytz import utc
import six
from apscheduler.events import JobExecutionEvent, EVENT_JOB_MISSED, EVENT_JOB_ERROR, EVENT_JOB_EXECUTED
class MaxInstancesReachedError(Exception):
def __init__(self, job):
super(MaxInstancesReachedError, self).__init__(
'Job "%s" has already reached its maximum number of instances (%d)' % (job.id, job.max_instances))
class BaseExecutor(six.with_metaclass(ABCMeta, object)):
"""Abstract base class that defines the interface that every executor must implement."""
_scheduler = None
_lock = None
_logger = logging.getLogger('apscheduler.executors')
def __init__(self):
super(BaseExecutor, self).__init__()
self._instances = defaultdict(lambda: 0)
def start(self, scheduler, alias):
"""
Called by the scheduler when the scheduler is being started or when the executor is being added to an already
running scheduler.
:param apscheduler.schedulers.base.BaseScheduler scheduler: the scheduler that is starting this executor
:param str|unicode alias: alias of this executor as it was assigned to the scheduler
"""
self._scheduler = scheduler
self._lock = scheduler._create_lock()
self._logger = logging.getLogger('apscheduler.executors.%s' % alias)
def shutdown(self, wait=True):
"""
Shuts down this executor.
:param bool wait: ``True`` to wait until all submitted jobs have been executed
"""
def submit_job(self, job, run_times):
"""
Submits job for execution.
:param Job job: job to execute
:param list[datetime] run_times: list of datetimes specifying when the job should have been run
:raises MaxInstancesReachedError: if the maximum number of allowed instances for this job has been reached
"""
assert self._lock is not None, 'This executor has not been started yet'
with self._lock:
if self._instances[job.id] >= job.max_instances:
raise MaxInstancesReachedError(job)
self._do_submit_job(job, run_times)
self._instances[job.id] += 1
@abstractmethod
def _do_submit_job(self, job, run_times):
"""Performs the actual task of scheduling `run_job` to be called."""
def _run_job_success(self, job_id, events):
"""Called by the executor with the list of generated events when `run_job` has been successfully called."""
with self._lock:
self._instances[job_id] -= 1
for event in events:
self._scheduler._dispatch_event(event)
def _run_job_error(self, job_id, exc, traceback=None):
"""Called by the executor with the exception if there is an error calling `run_job`."""
with self._lock:
self._instances[job_id] -= 1
exc_info = (exc.__class__, exc, traceback)
self._logger.error('Error running job %s', job_id, exc_info=exc_info)
def run_job(job, jobstore_alias, run_times, logger_name):
"""Called by executors to run the job. Returns a list of scheduler events to be dispatched by the scheduler."""
events = []
logger = logging.getLogger(logger_name)
for run_time in run_times:
# See if the job missed its run time window, and handle possible misfires accordingly
if job.misfire_grace_time is not None:
difference = datetime.now(utc) - run_time
grace_time = timedelta(seconds=job.misfire_grace_time)
if difference > grace_time:
events.append(JobExecutionEvent(EVENT_JOB_MISSED, job.id, jobstore_alias, run_time))
logger.warning('Run time of job "%s" was missed by %s', job, difference)
continue
logger.info('Running job "%s" (scheduled at %s)', job, run_time)
try:
retval = job.func(*job.args, **job.kwargs)
except:
exc, tb = sys.exc_info()[1:]
formatted_tb = ''.join(format_tb(tb))
events.append(JobExecutionEvent(EVENT_JOB_ERROR, job.id, jobstore_alias, run_time, exception=exc,
traceback=formatted_tb))
logger.exception('Job "%s" raised an exception', job)
else:
events.append(JobExecutionEvent(EVENT_JOB_EXECUTED, job.id, jobstore_alias, run_time, retval=retval))
logger.info('Job "%s" executed successfully', job)
return events
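
The grace-time check in run_job() compares each run's lateness against the job's misfire_grace_time; a sketch of how those limits are set per job (sched and do_backup are placeholders):

# If the scheduler was blocked and this run is over 30 seconds late,
# run_job() emits EVENT_JOB_MISSED instead of calling do_backup.
sched.add_job(do_backup, 'interval', hours=1, misfire_grace_time=30)

# max_instances feeds the MaxInstancesReachedError guard in submit_job()
sched.add_job(do_backup, 'interval', minutes=5, max_instances=1)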

lib/apscheduler/executors/debug.py (new file, 19 lines)

@@ -0,0 +1,19 @@
import sys
from apscheduler.executors.base import BaseExecutor, run_job
class DebugExecutor(BaseExecutor):
"""
A special executor that executes the target callable directly instead of deferring it to a thread or process.
Plugin alias: ``debug``
"""
def _do_submit_job(self, job, run_times):
try:
events = run_job(job, job._jobstore_alias, run_times, self._logger.name)
except:
self._run_job_error(job.id, *sys.exc_info()[1:])
else:
self._run_job_success(job.id, events)

lib/apscheduler/executors/gevent.py (new file, 29 lines)

@@ -0,0 +1,29 @@
from __future__ import absolute_import
import sys
from apscheduler.executors.base import BaseExecutor, run_job
try:
import gevent
except ImportError: # pragma: nocover
raise ImportError('GeventExecutor requires gevent installed')
class GeventExecutor(BaseExecutor):
"""
Runs jobs as greenlets.
Plugin alias: ``gevent``
"""
def _do_submit_job(self, job, run_times):
def callback(greenlet):
try:
events = greenlet.get()
except:
self._run_job_error(job.id, *sys.exc_info()[1:])
else:
self._run_job_success(job.id, events)
gevent.spawn(run_job, job, job._jobstore_alias, run_times, self._logger.name).link(callback)

lib/apscheduler/executors/pool.py (new file, 54 lines)

@@ -0,0 +1,54 @@
from abc import abstractmethod
import concurrent.futures
from apscheduler.executors.base import BaseExecutor, run_job
class BasePoolExecutor(BaseExecutor):
@abstractmethod
def __init__(self, pool):
super(BasePoolExecutor, self).__init__()
self._pool = pool
def _do_submit_job(self, job, run_times):
def callback(f):
exc, tb = (f.exception_info() if hasattr(f, 'exception_info') else
(f.exception(), getattr(f.exception(), '__traceback__', None)))
if exc:
self._run_job_error(job.id, exc, tb)
else:
self._run_job_success(job.id, f.result())
f = self._pool.submit(run_job, job, job._jobstore_alias, run_times, self._logger.name)
f.add_done_callback(callback)
def shutdown(self, wait=True):
self._pool.shutdown(wait)
class ThreadPoolExecutor(BasePoolExecutor):
"""
An executor that runs jobs in a concurrent.futures thread pool.
Plugin alias: ``threadpool``
:param max_workers: the maximum number of spawned threads.
"""
def __init__(self, max_workers=10):
pool = concurrent.futures.ThreadPoolExecutor(int(max_workers))
super(ThreadPoolExecutor, self).__init__(pool)
class ProcessPoolExecutor(BasePoolExecutor):
"""
An executor that runs jobs in a concurrent.futures process pool.
Plugin alias: ``processpool``
:param max_workers: the maximum number of spawned processes.
"""
def __init__(self, max_workers=10):
pool = concurrent.futures.ProcessPoolExecutor(int(max_workers))
super(ProcessPoolExecutor, self).__init__(pool)
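
A configuration sketch wiring these pool executors into a scheduler (the aliases, worker counts and crunch_numbers callable are illustrative):

from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.executors.pool import ThreadPoolExecutor, ProcessPoolExecutor

executors = {
    'default': ThreadPoolExecutor(max_workers=20),      # I/O-bound jobs
    'processpool': ProcessPoolExecutor(max_workers=4),  # CPU-bound jobs
}
sched = BackgroundScheduler(executors=executors)
sched.add_job(crunch_numbers, 'interval', minutes=10, executor='processpool')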

lib/apscheduler/executors/twisted.py (new file, 25 lines)

@@ -0,0 +1,25 @@
from __future__ import absolute_import
from apscheduler.executors.base import BaseExecutor, run_job
class TwistedExecutor(BaseExecutor):
"""
Runs jobs in the reactor's thread pool.
Plugin alias: ``twisted``
"""
def start(self, scheduler, alias):
super(TwistedExecutor, self).start(scheduler, alias)
self._reactor = scheduler._reactor
def _do_submit_job(self, job, run_times):
def callback(success, result):
if success:
self._run_job_success(job.id, result)
else:
self._run_job_error(job.id, result.value, result.tb)
self._reactor.getThreadPool().callInThreadWithCallback(callback, run_job, job, job._jobstore_alias, run_times,
self._logger.name)

lib/apscheduler/job.py (new file, 252 lines)

@@ -0,0 +1,252 @@
from collections import Iterable, Mapping
from uuid import uuid4
import six
from apscheduler.triggers.base import BaseTrigger
from apscheduler.util import ref_to_obj, obj_to_ref, datetime_repr, repr_escape, get_callable_name, check_callable_args, \
convert_to_datetime
class Job(object):
"""
Contains the options given when scheduling callables and its current schedule and other state.
This class should never be instantiated by the user.
:var str id: the unique identifier of this job
:var str name: the description of this job
:var func: the callable to execute
:var tuple|list args: positional arguments to the callable
:var dict kwargs: keyword arguments to the callable
:var bool coalesce: whether to only run the job once when several run times are due
:var trigger: the trigger object that controls the schedule of this job
:var str executor: the name of the executor that will run this job
:var int misfire_grace_time: the time (in seconds) by which this job's execution is allowed to be late
:var int max_instances: the maximum number of concurrently executing instances allowed for this job
:var datetime.datetime next_run_time: the next scheduled run time of this job
"""
__slots__ = ('_scheduler', '_jobstore_alias', 'id', 'trigger', 'executor', 'func', 'func_ref', 'args', 'kwargs',
'name', 'misfire_grace_time', 'coalesce', 'max_instances', 'next_run_time')
def __init__(self, scheduler, id=None, **kwargs):
super(Job, self).__init__()
self._scheduler = scheduler
self._jobstore_alias = None
self._modify(id=id or uuid4().hex, **kwargs)
def modify(self, **changes):
"""
Makes the given changes to this job and saves it in the associated job store.
Accepted keyword arguments are the same as the variables on this class.
.. seealso:: :meth:`~apscheduler.schedulers.base.BaseScheduler.modify_job`
"""
self._scheduler.modify_job(self.id, self._jobstore_alias, **changes)
def reschedule(self, trigger, **trigger_args):
"""
Shortcut for switching the trigger on this job.
.. seealso:: :meth:`~apscheduler.schedulers.base.BaseScheduler.reschedule_job`
"""
self._scheduler.reschedule_job(self.id, self._jobstore_alias, trigger, **trigger_args)
def pause(self):
"""
Temporarily suspend the execution of this job.
.. seealso:: :meth:`~apscheduler.schedulers.base.BaseScheduler.pause_job`
"""
self._scheduler.pause_job(self.id, self._jobstore_alias)
def resume(self):
"""
Resume the schedule of this job if previously paused.
.. seealso:: :meth:`~apscheduler.schedulers.base.BaseScheduler.resume_job`
"""
self._scheduler.resume_job(self.id, self._jobstore_alias)
def remove(self):
"""
Unschedules this job and removes it from its associated job store.
.. seealso:: :meth:`~apscheduler.schedulers.base.BaseScheduler.remove_job`
"""
self._scheduler.remove_job(self.id, self._jobstore_alias)
@property
def pending(self):
"""Returns ``True`` if the referenced job is still waiting to be added to its designated job store."""
return self._jobstore_alias is None
#
# Private API
#
def _get_run_times(self, now):
"""
Computes the scheduled run times between ``next_run_time`` and ``now`` (inclusive).
:type now: datetime.datetime
:rtype: list[datetime.datetime]
"""
run_times = []
next_run_time = self.next_run_time
while next_run_time and next_run_time <= now:
run_times.append(next_run_time)
next_run_time = self.trigger.get_next_fire_time(next_run_time, now)
return run_times
def _modify(self, **changes):
"""Validates the changes to the Job and makes the modifications if and only if all of them validate."""
approved = {}
if 'id' in changes:
value = changes.pop('id')
if not isinstance(value, six.string_types):
raise TypeError("id must be a nonempty string")
if hasattr(self, 'id'):
raise ValueError('The job ID may not be changed')
approved['id'] = value
if 'func' in changes or 'args' in changes or 'kwargs' in changes:
func = changes.pop('func') if 'func' in changes else self.func
args = changes.pop('args') if 'args' in changes else self.args
kwargs = changes.pop('kwargs') if 'kwargs' in changes else self.kwargs
if isinstance(func, str):
func_ref = func
func = ref_to_obj(func)
elif callable(func):
try:
func_ref = obj_to_ref(func)
except ValueError:
# If this happens, this Job won't be serializable
func_ref = None
else:
raise TypeError('func must be a callable or a textual reference to one')
if not hasattr(self, 'name') and changes.get('name', None) is None:
changes['name'] = get_callable_name(func)
if isinstance(args, six.string_types) or not isinstance(args, Iterable):
raise TypeError('args must be a non-string iterable')
if isinstance(kwargs, six.string_types) or not isinstance(kwargs, Mapping):
raise TypeError('kwargs must be a dict-like object')
check_callable_args(func, args, kwargs)
approved['func'] = func
approved['func_ref'] = func_ref
approved['args'] = args
approved['kwargs'] = kwargs
if 'name' in changes:
value = changes.pop('name')
if not value or not isinstance(value, six.string_types):
raise TypeError("name must be a nonempty string")
approved['name'] = value
if 'misfire_grace_time' in changes:
value = changes.pop('misfire_grace_time')
if value is not None and (not isinstance(value, six.integer_types) or value <= 0):
raise TypeError('misfire_grace_time must be either None or a positive integer')
approved['misfire_grace_time'] = value
if 'coalesce' in changes:
value = bool(changes.pop('coalesce'))
approved['coalesce'] = value
if 'max_instances' in changes:
value = changes.pop('max_instances')
if not isinstance(value, six.integer_types) or value <= 0:
raise TypeError('max_instances must be a positive integer')
approved['max_instances'] = value
if 'trigger' in changes:
trigger = changes.pop('trigger')
if not isinstance(trigger, BaseTrigger):
raise TypeError('Expected a trigger instance, got %s instead' % trigger.__class__.__name__)
approved['trigger'] = trigger
if 'executor' in changes:
value = changes.pop('executor')
if not isinstance(value, six.string_types):
raise TypeError('executor must be a string')
approved['executor'] = value
if 'next_run_time' in changes:
value = changes.pop('next_run_time')
approved['next_run_time'] = convert_to_datetime(value, self._scheduler.timezone, 'next_run_time')
if changes:
raise AttributeError('The following are not modifiable attributes of Job: %s' % ', '.join(changes))
for key, value in six.iteritems(approved):
setattr(self, key, value)
def __getstate__(self):
# Don't allow this Job to be serialized if the function reference could not be determined
if not self.func_ref:
raise ValueError('This Job cannot be serialized since the reference to its callable (%r) could not be '
'determined. Consider giving a textual reference (module:function name) instead.' %
(self.func,))
return {
'version': 1,
'id': self.id,
'func': self.func_ref,
'trigger': self.trigger,
'executor': self.executor,
'args': self.args,
'kwargs': self.kwargs,
'name': self.name,
'misfire_grace_time': self.misfire_grace_time,
'coalesce': self.coalesce,
'max_instances': self.max_instances,
'next_run_time': self.next_run_time
}
def __setstate__(self, state):
if state.get('version', 1) > 1:
raise ValueError('Job has version %s, but only version 1 can be handled' % state['version'])
self.id = state['id']
self.func_ref = state['func']
self.func = ref_to_obj(self.func_ref)
self.trigger = state['trigger']
self.executor = state['executor']
self.args = state['args']
self.kwargs = state['kwargs']
self.name = state['name']
self.misfire_grace_time = state['misfire_grace_time']
self.coalesce = state['coalesce']
self.max_instances = state['max_instances']
self.next_run_time = state['next_run_time']
def __eq__(self, other):
if isinstance(other, Job):
return self.id == other.id
return NotImplemented
def __repr__(self):
return '<Job (id=%s name=%s)>' % (repr_escape(self.id), repr_escape(self.name))
def __str__(self):
return '%s (trigger: %s, next run at: %s)' % (repr_escape(self.name), repr_escape(str(self.trigger)),
datetime_repr(self.next_run_time))
def __unicode__(self):
return six.u('%s (trigger: %s, next run at: %s)') % (self.name, self.trigger, datetime_repr(self.next_run_time))
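
The public methods above are thin proxies to the owning scheduler; a usage sketch (sched and send_report are placeholders):

job = sched.add_job(send_report, 'cron', hour=7, id='daily_report')

job.modify(max_instances=2, name='Morning report')  # proxies scheduler.modify_job()
job.reschedule('cron', hour=8)                      # swaps the trigger
job.pause()                                         # clears next_run_time
job.resume()                                        # recomputed from the trigger
job.remove()                                        # unschedules permanently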

lib/apscheduler/jobstores/__init__.py (new file, empty)

lib/apscheduler/jobstores/base.py (new file, 127 lines)

@@ -0,0 +1,127 @@
from abc import ABCMeta, abstractmethod
import logging
import six
class JobLookupError(KeyError):
"""Raised when the job store cannot find a job for update or removal."""
def __init__(self, job_id):
super(JobLookupError, self).__init__(six.u('No job by the id of %s was found') % job_id)
class ConflictingIdError(KeyError):
"""Raised when the uniqueness of job IDs is being violated."""
def __init__(self, job_id):
super(ConflictingIdError, self).__init__(six.u('Job identifier (%s) conflicts with an existing job') % job_id)
class TransientJobError(ValueError):
"""Raised when an attempt to add transient (with no func_ref) job to a persistent job store is detected."""
def __init__(self, job_id):
super(TransientJobError, self).__init__(
six.u('Job (%s) cannot be added to this job store because a reference to the callable could not be '
'determined.') % job_id)
class BaseJobStore(six.with_metaclass(ABCMeta)):
"""Abstract base class that defines the interface that every job store must implement."""
_scheduler = None
_alias = None
_logger = logging.getLogger('apscheduler.jobstores')
def start(self, scheduler, alias):
"""
Called by the scheduler when the scheduler is being started or when the job store is being added to an already
running scheduler.
:param apscheduler.schedulers.base.BaseScheduler scheduler: the scheduler that is starting this job store
:param str|unicode alias: alias of this job store as it was assigned to the scheduler
"""
self._scheduler = scheduler
self._alias = alias
self._logger = logging.getLogger('apscheduler.jobstores.%s' % alias)
def shutdown(self):
"""Frees any resources still bound to this job store."""
@abstractmethod
def lookup_job(self, job_id):
"""
Returns a specific job, or ``None`` if it isn't found.
The job store is responsible for setting the ``scheduler`` and ``jobstore`` attributes of the returned job to
point to the scheduler and itself, respectively.
:param str|unicode job_id: identifier of the job
:rtype: Job
"""
@abstractmethod
def get_due_jobs(self, now):
"""
Returns the list of jobs that have ``next_run_time`` earlier or equal to ``now``.
The returned jobs must be sorted by next run time (ascending).
:param datetime.datetime now: the current (timezone aware) datetime
:rtype: list[Job]
"""
@abstractmethod
def get_next_run_time(self):
"""
Returns the earliest run time of all the jobs stored in this job store, or ``None`` if there are no active jobs.
:rtype: datetime.datetime
"""
@abstractmethod
def get_all_jobs(self):
"""
Returns a list of all jobs in this job store. The returned jobs should be sorted by next run time (ascending).
Paused jobs (next_run_time is None) should be sorted last.
The job store is responsible for setting the ``scheduler`` and ``jobstore`` attributes of the returned jobs to
point to the scheduler and itself, respectively.
:rtype: list[Job]
"""
@abstractmethod
def add_job(self, job):
"""
Adds the given job to this store.
:param Job job: the job to add
:raises ConflictingIdError: if there is another job in this store with the same ID
"""
@abstractmethod
def update_job(self, job):
"""
Replaces the job in the store with the given newer version.
:param Job job: the job to update
:raises JobLookupError: if the job does not exist
"""
@abstractmethod
def remove_job(self, job_id):
"""
Removes the given job from this store.
:param str|unicode job_id: identifier of the job
:raises JobLookupError: if the job does not exist
"""
@abstractmethod
def remove_all_jobs(self):
"""Removes all jobs from this store."""
def __repr__(self):
return '<%s>' % self.__class__.__name__
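
To illustrate the contract, a minimal non-persistent job store sketch that satisfies the abstract methods above (a real implementation must keep get_due_jobs() and get_all_jobs() sorted by next run time, with paused jobs last):

from apscheduler.jobstores.base import BaseJobStore, JobLookupError, ConflictingIdError

class DictJobStore(BaseJobStore):
    """Toy store backed by a plain dict, for illustration only."""
    def __init__(self):
        super(DictJobStore, self).__init__()
        self._jobs = {}

    def lookup_job(self, job_id):
        return self._jobs.get(job_id)

    def get_due_jobs(self, now):
        due = [j for j in self._jobs.values() if j.next_run_time and j.next_run_time <= now]
        return sorted(due, key=lambda j: j.next_run_time)

    def get_next_run_time(self):
        times = [j.next_run_time for j in self._jobs.values() if j.next_run_time]
        return min(times) if times else None

    def get_all_jobs(self):
        # Sort by next run time; jobs with next_run_time of None go last
        return sorted(self._jobs.values(),
                      key=lambda j: (j.next_run_time is None, j.next_run_time))

    def add_job(self, job):
        if job.id in self._jobs:
            raise ConflictingIdError(job.id)
        self._jobs[job.id] = job

    def update_job(self, job):
        if job.id not in self._jobs:
            raise JobLookupError(job.id)
        self._jobs[job.id] = job

    def remove_job(self, job_id):
        if self._jobs.pop(job_id, None) is None:
            raise JobLookupError(job_id)

    def remove_all_jobs(self):
        self._jobs.clear()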

lib/apscheduler/jobstores/memory.py (new file, 107 lines)

@@ -0,0 +1,107 @@
from __future__ import absolute_import
from apscheduler.jobstores.base import BaseJobStore, JobLookupError, ConflictingIdError
from apscheduler.util import datetime_to_utc_timestamp
class MemoryJobStore(BaseJobStore):
"""
Stores jobs in an array in RAM. Provides no persistence support.
Plugin alias: ``memory``
"""
def __init__(self):
super(MemoryJobStore, self).__init__()
self._jobs = [] # list of (job, timestamp), sorted by next_run_time and job id (ascending)
self._jobs_index = {} # id -> (job, timestamp) lookup table
def lookup_job(self, job_id):
return self._jobs_index.get(job_id, (None, None))[0]
def get_due_jobs(self, now):
now_timestamp = datetime_to_utc_timestamp(now)
pending = []
for job, timestamp in self._jobs:
if timestamp is None or timestamp > now_timestamp:
break
pending.append(job)
return pending
def get_next_run_time(self):
return self._jobs[0][0].next_run_time if self._jobs else None
def get_all_jobs(self):
return [j[0] for j in self._jobs]
def add_job(self, job):
if job.id in self._jobs_index:
raise ConflictingIdError(job.id)
timestamp = datetime_to_utc_timestamp(job.next_run_time)
index = self._get_job_index(timestamp, job.id)
self._jobs.insert(index, (job, timestamp))
self._jobs_index[job.id] = (job, timestamp)
def update_job(self, job):
old_job, old_timestamp = self._jobs_index.get(job.id, (None, None))
if old_job is None:
raise JobLookupError(job.id)
# If the next run time has not changed, simply replace the job in its present index.
# Otherwise, reinsert the job to the list to preserve the ordering.
old_index = self._get_job_index(old_timestamp, old_job.id)
new_timestamp = datetime_to_utc_timestamp(job.next_run_time)
if old_timestamp == new_timestamp:
self._jobs[old_index] = (job, new_timestamp)
else:
del self._jobs[old_index]
new_index = self._get_job_index(new_timestamp, job.id)
self._jobs.insert(new_index, (job, new_timestamp))
self._jobs_index[old_job.id] = (job, new_timestamp)
def remove_job(self, job_id):
job, timestamp = self._jobs_index.get(job_id, (None, None))
if job is None:
raise JobLookupError(job_id)
index = self._get_job_index(timestamp, job_id)
del self._jobs[index]
del self._jobs_index[job.id]
def remove_all_jobs(self):
self._jobs = []
self._jobs_index = {}
def shutdown(self):
self.remove_all_jobs()
def _get_job_index(self, timestamp, job_id):
"""
Returns the index of the given job, or if it's not found, the index where the job should be inserted based on
the given timestamp.
:type timestamp: int
:type job_id: str
"""
lo, hi = 0, len(self._jobs)
timestamp = float('inf') if timestamp is None else timestamp
while lo < hi:
mid = (lo + hi) // 2
mid_job, mid_timestamp = self._jobs[mid]
mid_timestamp = float('inf') if mid_timestamp is None else mid_timestamp
if mid_timestamp > timestamp:
hi = mid
elif mid_timestamp < timestamp:
lo = mid + 1
elif mid_job.id > job_id:
hi = mid
elif mid_job.id < job_id:
lo = mid + 1
else:
return mid
return lo
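
The bisection above is equivalent to bisect_left over (timestamp, job id) keys, with None timestamps treated as +infinity so that paused jobs sort last; a standalone sketch:

from bisect import bisect_left

def job_index(jobs, timestamp, job_id):
    # jobs is the (job, timestamp) list kept by MemoryJobStore
    inf = float('inf')
    keys = [(inf if ts is None else ts, job.id) for job, ts in jobs]
    return bisect_left(keys, (inf if timestamp is None else timestamp, job_id))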

lib/apscheduler/jobstores/mongodb.py (new file, 124 lines)

@@ -0,0 +1,124 @@
from __future__ import absolute_import
from apscheduler.jobstores.base import BaseJobStore, JobLookupError, ConflictingIdError
from apscheduler.util import maybe_ref, datetime_to_utc_timestamp, utc_timestamp_to_datetime
from apscheduler.job import Job
try:
import cPickle as pickle
except ImportError: # pragma: nocover
import pickle
try:
from bson.binary import Binary
from pymongo.errors import DuplicateKeyError
from pymongo import MongoClient, ASCENDING
except ImportError: # pragma: nocover
raise ImportError('MongoDBJobStore requires PyMongo installed')
class MongoDBJobStore(BaseJobStore):
"""
Stores jobs in a MongoDB database. Any leftover keyword arguments are directly passed to pymongo's `MongoClient
<http://api.mongodb.org/python/current/api/pymongo/mongo_client.html#pymongo.mongo_client.MongoClient>`_.
Plugin alias: ``mongodb``
:param str database: database to store jobs in
:param str collection: collection to store jobs in
:param client: a :class:`~pymongo.mongo_client.MongoClient` instance to use instead of providing connection
arguments
:param int pickle_protocol: pickle protocol level to use (for serialization), defaults to the highest available
"""
def __init__(self, database='apscheduler', collection='jobs', client=None,
pickle_protocol=pickle.HIGHEST_PROTOCOL, **connect_args):
super(MongoDBJobStore, self).__init__()
self.pickle_protocol = pickle_protocol
if not database:
raise ValueError('The "database" parameter must not be empty')
if not collection:
raise ValueError('The "collection" parameter must not be empty')
if client:
self.connection = maybe_ref(client)
else:
connect_args.setdefault('w', 1)
self.connection = MongoClient(**connect_args)
self.collection = self.connection[database][collection]
self.collection.ensure_index('next_run_time', sparse=True)
def lookup_job(self, job_id):
document = self.collection.find_one(job_id, ['job_state'])
return self._reconstitute_job(document['job_state']) if document else None
def get_due_jobs(self, now):
timestamp = datetime_to_utc_timestamp(now)
return self._get_jobs({'next_run_time': {'$lte': timestamp}})
def get_next_run_time(self):
document = self.collection.find_one({'next_run_time': {'$ne': None}}, fields=['next_run_time'],
sort=[('next_run_time', ASCENDING)])
return utc_timestamp_to_datetime(document['next_run_time']) if document else None
def get_all_jobs(self):
return self._get_jobs({})
def add_job(self, job):
try:
self.collection.insert({
'_id': job.id,
'next_run_time': datetime_to_utc_timestamp(job.next_run_time),
'job_state': Binary(pickle.dumps(job.__getstate__(), self.pickle_protocol))
})
except DuplicateKeyError:
raise ConflictingIdError(job.id)
def update_job(self, job):
changes = {
'next_run_time': datetime_to_utc_timestamp(job.next_run_time),
'job_state': Binary(pickle.dumps(job.__getstate__(), self.pickle_protocol))
}
result = self.collection.update({'_id': job.id}, {'$set': changes})
if result and result['n'] == 0:
raise JobLookupError(job.id)
def remove_job(self, job_id):
result = self.collection.remove(job_id)
if result and result['n'] == 0:
raise JobLookupError(job_id)
def remove_all_jobs(self):
self.collection.remove()
def shutdown(self):
self.connection.disconnect()
def _reconstitute_job(self, job_state):
job_state = pickle.loads(job_state)
job = Job.__new__(Job)
job.__setstate__(job_state)
job._scheduler = self._scheduler
job._jobstore_alias = self._alias
return job
def _get_jobs(self, conditions):
jobs = []
failed_job_ids = []
for document in self.collection.find(conditions, ['_id', 'job_state'], sort=[('next_run_time', ASCENDING)]):
try:
jobs.append(self._reconstitute_job(document['job_state']))
except:
self._logger.exception('Unable to restore job "%s" -- removing it', document['_id'])
failed_job_ids.append(document['_id'])
# Remove all the jobs we failed to restore
if failed_job_ids:
self.collection.remove({'_id': {'$in': failed_job_ids}})
return jobs
def __repr__(self):
return '<%s (client=%s)>' % (self.__class__.__name__, self.connection)

lib/apscheduler/jobstores/redis.py (new file, 138 lines)

@@ -0,0 +1,138 @@
from __future__ import absolute_import
import six
from apscheduler.jobstores.base import BaseJobStore, JobLookupError, ConflictingIdError
from apscheduler.util import datetime_to_utc_timestamp, utc_timestamp_to_datetime
from apscheduler.job import Job
try:
import cPickle as pickle
except ImportError: # pragma: nocover
import pickle
try:
from redis import StrictRedis
except ImportError: # pragma: nocover
raise ImportError('RedisJobStore requires redis installed')
class RedisJobStore(BaseJobStore):
"""
Stores jobs in a Redis database. Any leftover keyword arguments are directly passed to redis's StrictRedis.
Plugin alias: ``redis``
:param int db: the database number to store jobs in
:param str jobs_key: key to store jobs in
:param str run_times_key: key to store the jobs' run times in
:param int pickle_protocol: pickle protocol level to use (for serialization), defaults to the highest available
"""
def __init__(self, db=0, jobs_key='apscheduler.jobs', run_times_key='apscheduler.run_times',
pickle_protocol=pickle.HIGHEST_PROTOCOL, **connect_args):
super(RedisJobStore, self).__init__()
if db is None:
raise ValueError('The "db" parameter must not be empty')
if not jobs_key:
raise ValueError('The "jobs_key" parameter must not be empty')
if not run_times_key:
raise ValueError('The "run_times_key" parameter must not be empty')
self.pickle_protocol = pickle_protocol
self.jobs_key = jobs_key
self.run_times_key = run_times_key
self.redis = StrictRedis(db=int(db), **connect_args)
def lookup_job(self, job_id):
job_state = self.redis.hget(self.jobs_key, job_id)
return self._reconstitute_job(job_state) if job_state else None
def get_due_jobs(self, now):
timestamp = datetime_to_utc_timestamp(now)
job_ids = self.redis.zrangebyscore(self.run_times_key, 0, timestamp)
if job_ids:
job_states = self.redis.hmget(self.jobs_key, *job_ids)
return self._reconstitute_jobs(six.moves.zip(job_ids, job_states))
return []
def get_next_run_time(self):
next_run_time = self.redis.zrange(self.run_times_key, 0, 0, withscores=True)
if next_run_time:
return utc_timestamp_to_datetime(next_run_time[0][1])
def get_all_jobs(self):
job_states = self.redis.hgetall(self.jobs_key)
jobs = self._reconstitute_jobs(six.iteritems(job_states))
return sorted(jobs, key=lambda job: job.next_run_time)
def add_job(self, job):
if self.redis.hexists(self.jobs_key, job.id):
raise ConflictingIdError(job.id)
with self.redis.pipeline() as pipe:
pipe.multi()
pipe.hset(self.jobs_key, job.id, pickle.dumps(job.__getstate__(), self.pickle_protocol))
pipe.zadd(self.run_times_key, datetime_to_utc_timestamp(job.next_run_time), job.id)
pipe.execute()
def update_job(self, job):
if not self.redis.hexists(self.jobs_key, job.id):
raise JobLookupError(job.id)
with self.redis.pipeline() as pipe:
pipe.hset(self.jobs_key, job.id, pickle.dumps(job.__getstate__(), self.pickle_protocol))
if job.next_run_time:
pipe.zadd(self.run_times_key, datetime_to_utc_timestamp(job.next_run_time), job.id)
else:
pipe.zrem(self.run_times_key, job.id)
pipe.execute()
def remove_job(self, job_id):
if not self.redis.hexists(self.jobs_key, job_id):
raise JobLookupError(job_id)
with self.redis.pipeline() as pipe:
pipe.hdel(self.jobs_key, job_id)
pipe.zrem(self.run_times_key, job_id)
pipe.execute()
def remove_all_jobs(self):
with self.redis.pipeline() as pipe:
pipe.delete(self.jobs_key)
pipe.delete(self.run_times_key)
pipe.execute()
def shutdown(self):
self.redis.connection_pool.disconnect()
def _reconstitute_job(self, job_state):
job_state = pickle.loads(job_state)
job = Job.__new__(Job)
job.__setstate__(job_state)
job._scheduler = self._scheduler
job._jobstore_alias = self._alias
return job
def _reconstitute_jobs(self, job_states):
jobs = []
failed_job_ids = []
for job_id, job_state in job_states:
try:
jobs.append(self._reconstitute_job(job_state))
except:
self._logger.exception('Unable to restore job "%s" -- removing it', job_id)
failed_job_ids.append(job_id)
# Remove all the jobs we failed to restore
if failed_job_ids:
with self.redis.pipeline() as pipe:
pipe.hdel(self.jobs_key, *failed_job_ids)
pipe.zrem(self.run_times_key, *failed_job_ids)
pipe.execute()
return jobs
def __repr__(self):
return '<%s>' % self.__class__.__name__

lib/apscheduler/jobstores/sqlalchemy.py (new file, 137 lines)

@@ -0,0 +1,137 @@
from __future__ import absolute_import
from apscheduler.jobstores.base import BaseJobStore, JobLookupError, ConflictingIdError
from apscheduler.util import maybe_ref, datetime_to_utc_timestamp, utc_timestamp_to_datetime
from apscheduler.job import Job
try:
import cPickle as pickle
except ImportError: # pragma: nocover
import pickle
try:
from sqlalchemy import create_engine, Table, Column, MetaData, Unicode, Float, LargeBinary, select
from sqlalchemy.exc import IntegrityError
except ImportError: # pragma: nocover
raise ImportError('SQLAlchemyJobStore requires SQLAlchemy installed')
class SQLAlchemyJobStore(BaseJobStore):
"""
Stores jobs in a database table using SQLAlchemy. The table will be created if it doesn't exist in the database.
Plugin alias: ``sqlalchemy``
:param str url: connection string (see `SQLAlchemy documentation
<http://docs.sqlalchemy.org/en/latest/core/engines.html?highlight=create_engine#database-urls>`_
on this)
:param engine: an SQLAlchemy Engine to use instead of creating a new one based on ``url``
:param str tablename: name of the table to store jobs in
:param metadata: a :class:`~sqlalchemy.MetaData` instance to use instead of creating a new one
:param int pickle_protocol: pickle protocol level to use (for serialization), defaults to the highest available
"""
def __init__(self, url=None, engine=None, tablename='apscheduler_jobs', metadata=None,
pickle_protocol=pickle.HIGHEST_PROTOCOL):
super(SQLAlchemyJobStore, self).__init__()
self.pickle_protocol = pickle_protocol
metadata = maybe_ref(metadata) or MetaData()
if engine:
self.engine = maybe_ref(engine)
elif url:
self.engine = create_engine(url)
else:
raise ValueError('Need either "engine" or "url" defined')
# 191 = max key length in MySQL for InnoDB/utf8mb4 tables, 25 = precision that translates to an 8-byte float
self.jobs_t = Table(
tablename, metadata,
Column('id', Unicode(191, _warn_on_bytestring=False), primary_key=True),
Column('next_run_time', Float(25), index=True),
Column('job_state', LargeBinary, nullable=False)
)
self.jobs_t.create(self.engine, True)
def lookup_job(self, job_id):
selectable = select([self.jobs_t.c.job_state]).where(self.jobs_t.c.id == job_id)
job_state = self.engine.execute(selectable).scalar()
return self._reconstitute_job(job_state) if job_state else None
def get_due_jobs(self, now):
timestamp = datetime_to_utc_timestamp(now)
return self._get_jobs(self.jobs_t.c.next_run_time <= timestamp)
def get_next_run_time(self):
selectable = select([self.jobs_t.c.next_run_time]).where(self.jobs_t.c.next_run_time != None).\
order_by(self.jobs_t.c.next_run_time).limit(1)
next_run_time = self.engine.execute(selectable).scalar()
return utc_timestamp_to_datetime(next_run_time)
def get_all_jobs(self):
return self._get_jobs()
def add_job(self, job):
insert = self.jobs_t.insert().values(**{
'id': job.id,
'next_run_time': datetime_to_utc_timestamp(job.next_run_time),
'job_state': pickle.dumps(job.__getstate__(), self.pickle_protocol)
})
try:
self.engine.execute(insert)
except IntegrityError:
raise ConflictingIdError(job.id)
def update_job(self, job):
update = self.jobs_t.update().values(**{
'next_run_time': datetime_to_utc_timestamp(job.next_run_time),
'job_state': pickle.dumps(job.__getstate__(), self.pickle_protocol)
}).where(self.jobs_t.c.id == job.id)
result = self.engine.execute(update)
if result.rowcount == 0:
raise JobLookupError(job.id)
def remove_job(self, job_id):
delete = self.jobs_t.delete().where(self.jobs_t.c.id == job_id)
result = self.engine.execute(delete)
if result.rowcount == 0:
raise JobLookupError(job_id)
def remove_all_jobs(self):
delete = self.jobs_t.delete()
self.engine.execute(delete)
def shutdown(self):
self.engine.dispose()
def _reconstitute_job(self, job_state):
job_state = pickle.loads(job_state)
job_state['jobstore'] = self
job = Job.__new__(Job)
job.__setstate__(job_state)
job._scheduler = self._scheduler
job._jobstore_alias = self._alias
return job
def _get_jobs(self, *conditions):
jobs = []
selectable = select([self.jobs_t.c.id, self.jobs_t.c.job_state]).order_by(self.jobs_t.c.next_run_time)
selectable = selectable.where(*conditions) if conditions else selectable
failed_job_ids = set()
for row in self.engine.execute(selectable):
try:
jobs.append(self._reconstitute_job(row.job_state))
except:
self._logger.exception('Unable to restore job "%s" -- removing it', row.id)
failed_job_ids.add(row.id)
# Remove all the jobs we failed to restore
if failed_job_ids:
delete = self.jobs_t.delete().where(self.jobs_t.c.id.in_(failed_job_ids))
self.engine.execute(delete)
return jobs
def __repr__(self):
return '<%s (url=%s)>' % (self.__class__.__name__, self.engine.url)
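
A configuration sketch covering the persistent stores in this commit (connection parameters are illustrative, and each backend needs its driver installed):

from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.jobstores.sqlalchemy import SQLAlchemyJobStore
from apscheduler.jobstores.redis import RedisJobStore
from apscheduler.jobstores.mongodb import MongoDBJobStore

jobstores = {
    'default': SQLAlchemyJobStore(url='sqlite:///jobs.sqlite'),
    'cache': RedisJobStore(db=1),                     # extra kwargs go to StrictRedis
    'docs': MongoDBJobStore(database='apscheduler'),  # extra kwargs go to MongoClient
}
sched = BackgroundScheduler(jobstores=jobstores)
# Only jobs whose callable is resolvable as a textual reference can be pickled
# into a persistent store; 'myapp.tasks:cleanup' is a hypothetical example.
sched.add_job('myapp.tasks:cleanup', 'cron', hour=3, jobstore='default')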

lib/apscheduler/schedulers/__init__.py (new file, 12 lines)

@@ -0,0 +1,12 @@
class SchedulerAlreadyRunningError(Exception):
"""Raised when attempting to start or configure the scheduler when it's already running."""
def __str__(self):
return 'Scheduler is already running'
class SchedulerNotRunningError(Exception):
"""Raised when attempting to shutdown the scheduler when it's not running."""
def __str__(self):
return 'Scheduler is not running'

lib/apscheduler/schedulers/asyncio.py (new file, 68 lines)

@@ -0,0 +1,68 @@
from __future__ import absolute_import
from functools import wraps
from apscheduler.schedulers.base import BaseScheduler
from apscheduler.util import maybe_ref
try:
import asyncio
except ImportError: # pragma: nocover
try:
import trollius as asyncio
except ImportError:
raise ImportError('AsyncIOScheduler requires either Python 3.4 or the asyncio package installed')
def run_in_event_loop(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
self._eventloop.call_soon_threadsafe(func, self, *args, **kwargs)
return wrapper
class AsyncIOScheduler(BaseScheduler):
"""
A scheduler that runs on an asyncio (:pep:`3156`) event loop.
Extra options:
============== =============================================================
``event_loop`` AsyncIO event loop to use (defaults to the global event loop)
============== =============================================================
"""
_eventloop = None
_timeout = None
def start(self):
super(AsyncIOScheduler, self).start()
self.wakeup()
@run_in_event_loop
def shutdown(self, wait=True):
super(AsyncIOScheduler, self).shutdown(wait)
self._stop_timer()
def _configure(self, config):
self._eventloop = maybe_ref(config.pop('event_loop', None)) or asyncio.get_event_loop()
super(AsyncIOScheduler, self)._configure(config)
def _start_timer(self, wait_seconds):
self._stop_timer()
if wait_seconds is not None:
self._timeout = self._eventloop.call_later(wait_seconds, self.wakeup)
def _stop_timer(self):
if self._timeout:
self._timeout.cancel()
del self._timeout
@run_in_event_loop
def wakeup(self):
self._stop_timer()
wait_seconds = self._process_jobs()
self._start_timer(wait_seconds)
def _create_default_executor(self):
from apscheduler.executors.asyncio import AsyncIOExecutor
return AsyncIOExecutor()
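
A run sketch for this scheduler (tick is a placeholder; under Python 2 the trollius fallback above supplies the event loop):

from datetime import datetime
from apscheduler.schedulers.asyncio import AsyncIOScheduler
try:
    import asyncio
except ImportError:
    import trollius as asyncio

def tick():
    print('Tick! %s' % datetime.now())

sched = AsyncIOScheduler()
sched.add_job(tick, 'interval', seconds=3)
sched.start()  # schedules a wakeup on the event loop
try:
    asyncio.get_event_loop().run_forever()
except (KeyboardInterrupt, SystemExit):
    pass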

lib/apscheduler/schedulers/background.py (new file, 39 lines)

@@ -0,0 +1,39 @@
from __future__ import absolute_import
from threading import Thread, Event
from apscheduler.schedulers.base import BaseScheduler
from apscheduler.schedulers.blocking import BlockingScheduler
from apscheduler.util import asbool
class BackgroundScheduler(BlockingScheduler):
"""
A scheduler that runs in the background using a separate thread
(:meth:`~apscheduler.schedulers.base.BaseScheduler.start` will return immediately).
Extra options:
========== ============================================================================================
``daemon`` Set the ``daemon`` option in the background thread (defaults to ``True``,
see `the documentation <https://docs.python.org/3.4/library/threading.html#thread-objects>`_
for further details)
========== ============================================================================================
"""
_thread = None
def _configure(self, config):
self._daemon = asbool(config.pop('daemon', True))
super(BackgroundScheduler, self)._configure(config)
def start(self):
BaseScheduler.start(self)
self._event = Event()
self._thread = Thread(target=self._main_loop, name='APScheduler')
self._thread.daemon = self._daemon
self._thread.start()
def shutdown(self, wait=True):
super(BackgroundScheduler, self).shutdown(wait)
self._thread.join()
del self._thread
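
A start/shutdown sketch; because the worker thread is daemonic by default, the main thread has to stay alive for jobs to keep running:

import time
from apscheduler.schedulers.background import BackgroundScheduler

sched = BackgroundScheduler()  # daemon=True unless configured otherwise
sched.start()                  # returns immediately; _main_loop runs in a thread
sched.add_job(lambda: None, 'interval', seconds=5)
try:
    while True:
        time.sleep(1)
except (KeyboardInterrupt, SystemExit):
    sched.shutdown()           # sets the stop flag and joins the thread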

lib/apscheduler/schedulers/base.py (new file, 845 lines)

@@ -0,0 +1,845 @@
from __future__ import print_function
from abc import ABCMeta, abstractmethod
from collections import MutableMapping
from threading import RLock
from datetime import datetime
from logging import getLogger
import sys
from pkg_resources import iter_entry_points
from tzlocal import get_localzone
import six
from apscheduler.schedulers import SchedulerAlreadyRunningError, SchedulerNotRunningError
from apscheduler.executors.base import MaxInstancesReachedError, BaseExecutor
from apscheduler.executors.pool import ThreadPoolExecutor
from apscheduler.jobstores.base import ConflictingIdError, JobLookupError, BaseJobStore
from apscheduler.jobstores.memory import MemoryJobStore
from apscheduler.job import Job
from apscheduler.triggers.base import BaseTrigger
from apscheduler.util import asbool, asint, astimezone, maybe_ref, timedelta_seconds, undefined
from apscheduler.events import (
SchedulerEvent, JobEvent, EVENT_SCHEDULER_START, EVENT_SCHEDULER_SHUTDOWN, EVENT_JOBSTORE_ADDED,
EVENT_JOBSTORE_REMOVED, EVENT_ALL, EVENT_JOB_MODIFIED, EVENT_JOB_REMOVED, EVENT_JOB_ADDED, EVENT_EXECUTOR_ADDED,
EVENT_EXECUTOR_REMOVED, EVENT_ALL_JOBS_REMOVED)
class BaseScheduler(six.with_metaclass(ABCMeta)):
"""
Abstract base class for all schedulers. Takes the following keyword arguments:
:param str|logging.Logger logger: logger to use for the scheduler's logging (defaults to apscheduler.scheduler)
:param str|datetime.tzinfo timezone: the default time zone (defaults to the local timezone)
:param dict job_defaults: default values for newly added jobs
:param dict jobstores: a dictionary of job store alias -> job store instance or configuration dict
:param dict executors: a dictionary of executor alias -> executor instance or configuration dict
.. seealso:: :ref:`scheduler-config`
"""
_trigger_plugins = dict((ep.name, ep) for ep in iter_entry_points('apscheduler.triggers'))
_trigger_classes = {}
_executor_plugins = dict((ep.name, ep) for ep in iter_entry_points('apscheduler.executors'))
_executor_classes = {}
_jobstore_plugins = dict((ep.name, ep) for ep in iter_entry_points('apscheduler.jobstores'))
_jobstore_classes = {}
_stopped = True
#
# Public API
#
def __init__(self, gconfig={}, **options):
super(BaseScheduler, self).__init__()
self._executors = {}
self._executors_lock = self._create_lock()
self._jobstores = {}
self._jobstores_lock = self._create_lock()
self._listeners = []
self._listeners_lock = self._create_lock()
self._pending_jobs = []
self.configure(gconfig, **options)
def configure(self, gconfig={}, prefix='apscheduler.', **options):
"""
Reconfigures the scheduler with the given options. Can only be done when the scheduler isn't running.
:param dict gconfig: a "global" configuration dictionary whose values can be overridden by keyword arguments to
this method
:param str|unicode prefix: pick only those keys from ``gconfig`` that are prefixed with this string
(pass an empty string or ``None`` to use all keys)
:raises SchedulerAlreadyRunningError: if the scheduler is already running
"""
if self.running:
raise SchedulerAlreadyRunningError
# If a non-empty prefix was given, strip it from the keys in the global configuration dict
if prefix:
prefixlen = len(prefix)
gconfig = dict((key[prefixlen:], value) for key, value in six.iteritems(gconfig) if key.startswith(prefix))
# Create a structure from the dotted options (e.g. "a.b.c = d" -> {'a': {'b': {'c': 'd'}}})
config = {}
for key, value in six.iteritems(gconfig):
parts = key.split('.')
parent = config
key = parts.pop(0)
while parts:
parent = parent.setdefault(key, {})
key = parts.pop(0)
parent[key] = value
# Override any options with explicit keyword arguments
config.update(options)
self._configure(config)
@abstractmethod
def start(self):
"""
Starts the scheduler. The details of this process depend on the implementation.
:raises SchedulerAlreadyRunningError: if the scheduler is already running
"""
if self.running:
raise SchedulerAlreadyRunningError
with self._executors_lock:
# Create a default executor if nothing else is configured
if 'default' not in self._executors:
self.add_executor(self._create_default_executor(), 'default')
# Start all the executors
for alias, executor in six.iteritems(self._executors):
executor.start(self, alias)
with self._jobstores_lock:
# Create a default job store if nothing else is configured
if 'default' not in self._jobstores:
self.add_jobstore(self._create_default_jobstore(), 'default')
# Start all the job stores
for alias, store in six.iteritems(self._jobstores):
store.start(self, alias)
# Schedule all pending jobs
for job, jobstore_alias, replace_existing in self._pending_jobs:
self._real_add_job(job, jobstore_alias, replace_existing, False)
del self._pending_jobs[:]
self._stopped = False
self._logger.info('Scheduler started')
# Notify listeners that the scheduler has been started
self._dispatch_event(SchedulerEvent(EVENT_SCHEDULER_START))
@abstractmethod
def shutdown(self, wait=True):
"""
Shuts down the scheduler. Does not interrupt any currently running jobs.
:param bool wait: ``True`` to wait until all currently executing jobs have finished
:raises SchedulerNotRunningError: if the scheduler has not been started yet
"""
if not self.running:
raise SchedulerNotRunningError
self._stopped = True
# Shut down all executors
for executor in six.itervalues(self._executors):
executor.shutdown(wait)
# Shut down all job stores
for jobstore in six.itervalues(self._jobstores):
jobstore.shutdown()
self._logger.info('Scheduler has been shut down')
self._dispatch_event(SchedulerEvent(EVENT_SCHEDULER_SHUTDOWN))
@property
def running(self):
return not self._stopped
def add_executor(self, executor, alias='default', **executor_opts):
"""
Adds an executor to this scheduler. Any extra keyword arguments will be passed to the executor plugin's
constructor, assuming that the first argument is the name of an executor plugin.
:param str|unicode|apscheduler.executors.base.BaseExecutor executor: either an executor instance or the name of
an executor plugin
:param str|unicode alias: alias for the executor
:raises ValueError: if there is already an executor by the given alias
"""
with self._executors_lock:
if alias in self._executors:
raise ValueError('This scheduler already has an executor by the alias of "%s"' % alias)
if isinstance(executor, BaseExecutor):
self._executors[alias] = executor
elif isinstance(executor, six.string_types):
self._executors[alias] = executor = self._create_plugin_instance('executor', executor, executor_opts)
else:
raise TypeError('Expected an executor instance or a string, got %s instead' %
executor.__class__.__name__)
# Start the executor right away if the scheduler is running
if self.running:
executor.start(self, alias)
self._dispatch_event(SchedulerEvent(EVENT_EXECUTOR_ADDED, alias))
def remove_executor(self, alias, shutdown=True):
"""
Removes the executor by the given alias from this scheduler.
:param str|unicode alias: alias of the executor
:param bool shutdown: ``True`` to shut down the executor after removing it
"""
with self._executors_lock:
executor = self._lookup_executor(alias)
del self._executors[alias]
if shutdown:
executor.shutdown()
self._dispatch_event(SchedulerEvent(EVENT_EXECUTOR_REMOVED, alias))
def add_jobstore(self, jobstore, alias='default', **jobstore_opts):
"""
Adds a job store to this scheduler. Any extra keyword arguments will be passed to the job store plugin's
constructor, assuming that the first argument is the name of a job store plugin.
:param str|unicode|apscheduler.jobstores.base.BaseJobStore jobstore: job store to be added
:param str|unicode alias: alias for the job store
:raises ValueError: if there is already a job store by the given alias
"""
with self._jobstores_lock:
if alias in self._jobstores:
raise ValueError('This scheduler already has a job store by the alias of "%s"' % alias)
if isinstance(jobstore, BaseJobStore):
self._jobstores[alias] = jobstore
elif isinstance(jobstore, six.string_types):
self._jobstores[alias] = jobstore = self._create_plugin_instance('jobstore', jobstore, jobstore_opts)
else:
raise TypeError('Expected a job store instance or a string, got %s instead' %
jobstore.__class__.__name__)
# Start the job store right away if the scheduler is running
if self.running:
jobstore.start(self, alias)
# Notify listeners that a new job store has been added
self._dispatch_event(SchedulerEvent(EVENT_JOBSTORE_ADDED, alias))
# Notify the scheduler so it can scan the new job store for jobs
if self.running:
self.wakeup()
def remove_jobstore(self, alias, shutdown=True):
"""
Removes the job store by the given alias from this scheduler.
:param str|unicode alias: alias of the job store
:param bool shutdown: ``True`` to shut down the job store after removing it
"""
with self._jobstores_lock:
jobstore = self._lookup_jobstore(alias)
del self._jobstores[alias]
if shutdown:
jobstore.shutdown()
self._dispatch_event(SchedulerEvent(EVENT_JOBSTORE_REMOVED, alias))
def add_listener(self, callback, mask=EVENT_ALL):
"""
add_listener(callback, mask=EVENT_ALL)
Adds a listener for scheduler events. When a matching event occurs, ``callback`` is executed with the event
object as its sole argument. If the ``mask`` parameter is not provided, the callback will receive events of all
types.
:param callback: any callable that takes one argument
:param int mask: bitmask that indicates which events should be listened to
.. seealso:: :mod:`apscheduler.events`
.. seealso:: :ref:`scheduler-events`
"""
with self._listeners_lock:
self._listeners.append((callback, mask))
def remove_listener(self, callback):
"""Removes a previously added event listener."""
with self._listeners_lock:
for i, (cb, _) in enumerate(self._listeners):
if callback == cb:
del self._listeners[i]
def add_job(self, func, trigger=None, args=None, kwargs=None, id=None, name=None, misfire_grace_time=undefined,
coalesce=undefined, max_instances=undefined, next_run_time=undefined, jobstore='default',
executor='default', replace_existing=False, **trigger_args):
"""
add_job(func, trigger=None, args=None, kwargs=None, id=None, name=None, misfire_grace_time=undefined, \
coalesce=undefined, max_instances=undefined, next_run_time=undefined, jobstore='default', \
executor='default', replace_existing=False, **trigger_args)
Adds the given job to the job list and wakes up the scheduler if it's already running.
Any option that defaults to ``undefined`` will be replaced with the corresponding default value when the job is
scheduled (which happens when the scheduler is started, or immediately if the scheduler is already running).
The ``func`` argument can be given either as a callable object or a textual reference in the
``package.module:some.object`` format, where the first half (separated by ``:``) is an importable module and the
second half is a reference to the callable object, relative to the module.
The ``trigger`` argument can either be:
#. the alias name of the trigger (e.g. ``date``, ``interval`` or ``cron``), in which case any extra keyword
arguments to this method are passed on to the trigger's constructor
#. an instance of a trigger class
:param func: callable (or a textual reference to one) to run at the given time
:param str|apscheduler.triggers.base.BaseTrigger trigger: trigger that determines when ``func`` is called
:param list|tuple args: list of positional arguments to call func with
:param dict kwargs: dict of keyword arguments to call func with
:param str|unicode id: explicit identifier for the job (for modifying it later)
:param str|unicode name: textual description of the job
:param int misfire_grace_time: seconds after the designated run time that the job is still allowed to be run
:param bool coalesce: run once instead of many times if the scheduler determines that the job should be run more
than once in succession
:param int max_instances: maximum number of concurrently running instances allowed for this job
:param datetime next_run_time: when to first run the job, regardless of the trigger (pass ``None`` to add the
job as paused)
:param str|unicode jobstore: alias of the job store to store the job in
:param str|unicode executor: alias of the executor to run the job with
:param bool replace_existing: ``True`` to replace an existing job with the same ``id`` (but retain the
number of runs from the existing one)
:rtype: Job
"""
job_kwargs = {
'trigger': self._create_trigger(trigger, trigger_args),
'executor': executor,
'func': func,
'args': tuple(args) if args is not None else (),
'kwargs': dict(kwargs) if kwargs is not None else {},
'id': id,
'name': name,
'misfire_grace_time': misfire_grace_time,
'coalesce': coalesce,
'max_instances': max_instances,
'next_run_time': next_run_time
}
job_kwargs = dict((key, value) for key, value in six.iteritems(job_kwargs) if value is not undefined)
job = Job(self, **job_kwargs)
# Don't really add jobs to job stores before the scheduler is up and running
with self._jobstores_lock:
if not self.running:
self._pending_jobs.append((job, jobstore, replace_existing))
self._logger.info('Adding job tentatively -- it will be properly scheduled when the scheduler starts')
else:
self._real_add_job(job, jobstore, replace_existing, True)
return job
def scheduled_job(self, trigger, args=None, kwargs=None, id=None, name=None, misfire_grace_time=undefined,
coalesce=undefined, max_instances=undefined, next_run_time=undefined, jobstore='default',
executor='default', **trigger_args):
"""
scheduled_job(trigger, args=None, kwargs=None, id=None, name=None, misfire_grace_time=undefined, \
coalesce=undefined, max_instances=undefined, next_run_time=undefined, jobstore='default', \
executor='default', **trigger_args)
A decorator version of :meth:`add_job`, except that ``replace_existing`` is always ``True``.
.. important:: The ``id`` argument must be given if scheduling a job in a persistent job store. The scheduler
cannot, however, enforce this requirement.
"""
def inner(func):
self.add_job(func, trigger, args, kwargs, id, name, misfire_grace_time, coalesce, max_instances,
next_run_time, jobstore, executor, True, **trigger_args)
return func
return inner
def modify_job(self, job_id, jobstore=None, **changes):
"""
Modifies the properties of a single job. Modifications are passed to this method as extra keyword arguments.
:param str|unicode job_id: the identifier of the job
:param str|unicode jobstore: alias of the job store that contains the job
"""
with self._jobstores_lock:
job, jobstore = self._lookup_job(job_id, jobstore)
job._modify(**changes)
if jobstore:
self._lookup_jobstore(jobstore).update_job(job)
self._dispatch_event(JobEvent(EVENT_JOB_MODIFIED, job_id, jobstore))
# Wake up the scheduler since the job's next run time may have been changed
self.wakeup()
def reschedule_job(self, job_id, jobstore=None, trigger=None, **trigger_args):
"""
Constructs a new trigger for a job and updates its next run time.
Extra keyword arguments are passed directly to the trigger's constructor.
:param str|unicode job_id: the identifier of the job
:param str|unicode jobstore: alias of the job store that contains the job
:param trigger: alias of the trigger type or a trigger instance
"""
trigger = self._create_trigger(trigger, trigger_args)
now = datetime.now(self.timezone)
next_run_time = trigger.get_next_fire_time(None, now)
self.modify_job(job_id, jobstore, trigger=trigger, next_run_time=next_run_time)
def pause_job(self, job_id, jobstore=None):
"""
Causes the given job not to be executed until it is explicitly resumed.
:param str|unicode job_id: the identifier of the job
:param str|unicode jobstore: alias of the job store that contains the job
"""
self.modify_job(job_id, jobstore, next_run_time=None)
def resume_job(self, job_id, jobstore=None):
"""
Resumes the schedule of the given job, or removes the job if its schedule is finished.
:param str|unicode job_id: the identifier of the job
:param str|unicode jobstore: alias of the job store that contains the job
"""
with self._jobstores_lock:
job, jobstore = self._lookup_job(job_id, jobstore)
now = datetime.now(self.timezone)
next_run_time = job.trigger.get_next_fire_time(None, now)
if next_run_time:
self.modify_job(job_id, jobstore, next_run_time=next_run_time)
else:
self.remove_job(job.id, jobstore)
def get_jobs(self, jobstore=None, pending=None):
"""
Returns a list of pending jobs (if the scheduler hasn't been started yet) and scheduled jobs, either from a
specific job store or from all of them.
:param str|unicode jobstore: alias of the job store
:param bool pending: ``False`` to leave out pending jobs (jobs waiting for the scheduler to start before
                     being added to their respective job stores), ``True`` to only include pending jobs,
                     anything else to return both
:rtype: list[Job]
"""
with self._jobstores_lock:
jobs = []
if pending is not False:
for job, alias, replace_existing in self._pending_jobs:
if jobstore is None or alias == jobstore:
jobs.append(job)
if pending is not True:
for alias, store in six.iteritems(self._jobstores):
if jobstore is None or alias == jobstore:
jobs.extend(store.get_all_jobs())
return jobs
def get_job(self, job_id, jobstore=None):
"""
Returns the Job that matches the given ``job_id``.
:param str|unicode job_id: the identifier of the job
:param str|unicode jobstore: alias of the job store that most likely contains the job
:return: the Job by the given ID, or ``None`` if it wasn't found
:rtype: Job
"""
with self._jobstores_lock:
try:
return self._lookup_job(job_id, jobstore)[0]
except JobLookupError:
return
def remove_job(self, job_id, jobstore=None):
"""
Removes a job, preventing it from being run any more.
:param str|unicode job_id: the identifier of the job
:param str|unicode jobstore: alias of the job store that contains the job
:raises JobLookupError: if the job was not found
"""
with self._jobstores_lock:
# Check if the job is among the pending jobs
for i, (job, jobstore_alias, replace_existing) in enumerate(self._pending_jobs):
if job.id == job_id:
del self._pending_jobs[i]
jobstore = jobstore_alias
break
else:
# Otherwise, try to remove it from each store until it succeeds or we run out of stores to check
for alias, store in six.iteritems(self._jobstores):
if jobstore in (None, alias):
try:
store.remove_job(job_id)
except JobLookupError:
continue
jobstore = alias
break
if jobstore is None:
raise JobLookupError(job_id)
# Notify listeners that a job has been removed
event = JobEvent(EVENT_JOB_REMOVED, job_id, jobstore)
self._dispatch_event(event)
self._logger.info('Removed job %s', job_id)
def remove_all_jobs(self, jobstore=None):
"""
Removes all jobs from the specified job store, or all job stores if none is given.
:param str|unicode jobstore: alias of the job store
"""
with self._jobstores_lock:
if jobstore:
self._pending_jobs = [pending for pending in self._pending_jobs if pending[1] != jobstore]
else:
self._pending_jobs = []
for alias, store in six.iteritems(self._jobstores):
if jobstore in (None, alias):
store.remove_all_jobs()
self._dispatch_event(SchedulerEvent(EVENT_ALL_JOBS_REMOVED, jobstore))
def print_jobs(self, jobstore=None, out=None):
"""
print_jobs(jobstore=None, out=sys.stdout)
Prints out a textual listing of all jobs currently scheduled on either all job stores or just a specific one.
:param str|unicode jobstore: alias of the job store, ``None`` to list jobs from all stores
:param file out: a file-like object to print to (defaults to **sys.stdout** if nothing is given)
"""
out = out or sys.stdout
with self._jobstores_lock:
if self._pending_jobs:
print(six.u('Pending jobs:'), file=out)
for job, jobstore_alias, replace_existing in self._pending_jobs:
if jobstore in (None, jobstore_alias):
print(six.u(' %s') % job, file=out)
for alias, store in six.iteritems(self._jobstores):
if jobstore in (None, alias):
print(six.u('Jobstore %s:') % alias, file=out)
jobs = store.get_all_jobs()
if jobs:
for job in jobs:
print(six.u(' %s') % job, file=out)
else:
print(six.u(' No scheduled jobs'), file=out)
@abstractmethod
def wakeup(self):
"""
Notifies the scheduler that there may be jobs due for execution.
Triggers :meth:`_process_jobs` to be run in an implementation specific manner.
"""
#
# Private API
#
def _configure(self, config):
# Set general options
self._logger = maybe_ref(config.pop('logger', None)) or getLogger('apscheduler.scheduler')
self.timezone = astimezone(config.pop('timezone', None)) or get_localzone()
# Set the job defaults
job_defaults = config.get('job_defaults', {})
self._job_defaults = {
'misfire_grace_time': asint(job_defaults.get('misfire_grace_time', 1)),
'coalesce': asbool(job_defaults.get('coalesce', True)),
'max_instances': asint(job_defaults.get('max_instances', 1))
}
# Configure executors
self._executors.clear()
for alias, value in six.iteritems(config.get('executors', {})):
if isinstance(value, BaseExecutor):
self.add_executor(value, alias)
elif isinstance(value, MutableMapping):
executor_class = value.pop('class', None)
plugin = value.pop('type', None)
if plugin:
executor = self._create_plugin_instance('executor', plugin, value)
elif executor_class:
cls = maybe_ref(executor_class)
executor = cls(**value)
else:
raise ValueError('Cannot create executor "%s" -- either "type" or "class" must be defined' % alias)
self.add_executor(executor, alias)
else:
raise TypeError("Expected executor instance or dict for executors['%s'], got %s instead" % (
alias, value.__class__.__name__))
# Configure job stores
self._jobstores.clear()
for alias, value in six.iteritems(config.get('jobstores', {})):
if isinstance(value, BaseJobStore):
self.add_jobstore(value, alias)
elif isinstance(value, MutableMapping):
jobstore_class = value.pop('class', None)
plugin = value.pop('type', None)
if plugin:
jobstore = self._create_plugin_instance('jobstore', plugin, value)
elif jobstore_class:
cls = maybe_ref(jobstore_class)
jobstore = cls(**value)
else:
raise ValueError('Cannot create job store "%s" -- either "type" or "class" must be defined' % alias)
self.add_jobstore(jobstore, alias)
else:
raise TypeError("Expected job store instance or dict for jobstores['%s'], got %s instead" % (
alias, value.__class__.__name__))
def _create_default_executor(self):
"""Creates a default executor store, specific to the particular scheduler type."""
return ThreadPoolExecutor()
def _create_default_jobstore(self):
"""Creates a default job store, specific to the particular scheduler type."""
return MemoryJobStore()
def _lookup_executor(self, alias):
"""
Returns the executor instance by the given name from the list of executors that were added to this scheduler.
:type alias: str
:raises KeyError: if no executor by the given alias is found
"""
try:
return self._executors[alias]
except KeyError:
raise KeyError('No such executor: %s' % alias)
def _lookup_jobstore(self, alias):
"""
Returns the job store instance by the given name from the list of job stores that were added to this scheduler.
:type alias: str
:raises KeyError: if no job store by the given alias is found
"""
try:
return self._jobstores[alias]
except KeyError:
raise KeyError('No such job store: %s' % alias)
def _lookup_job(self, job_id, jobstore_alias):
"""
Finds a job by its ID.
:type job_id: str
:param str jobstore_alias: alias of a job store to look in
:return tuple[Job, str]: a tuple of job, jobstore alias (jobstore alias is None in case of a pending job)
:raises JobLookupError: if no job by the given ID is found.
"""
# Check if the job is among the pending jobs
for job, alias, replace_existing in self._pending_jobs:
if job.id == job_id:
return job, None
# Look in all job stores
for alias, store in six.iteritems(self._jobstores):
if jobstore_alias in (None, alias):
job = store.lookup_job(job_id)
if job is not None:
return job, alias
raise JobLookupError(job_id)
def _dispatch_event(self, event):
"""
Dispatches the given event to interested listeners.
:param SchedulerEvent event: the event to send
"""
with self._listeners_lock:
listeners = tuple(self._listeners)
for cb, mask in listeners:
if event.code & mask:
try:
cb(event)
except:
self._logger.exception('Error notifying listener')
def _real_add_job(self, job, jobstore_alias, replace_existing, wakeup):
"""
:param Job job: the job to add
:param str jobstore_alias: alias of the job store to add the job to
:param bool replace_existing: ``True`` to use update_job() in case the job already exists in the store
:param bool wakeup: ``True`` to wake up the scheduler after adding the job
"""
# Fill in undefined values with defaults
replacements = {}
for key, value in six.iteritems(self._job_defaults):
if not hasattr(job, key):
replacements[key] = value
# Calculate the next run time if there is none defined
if not hasattr(job, 'next_run_time'):
now = datetime.now(self.timezone)
replacements['next_run_time'] = job.trigger.get_next_fire_time(None, now)
# Apply any replacements
job._modify(**replacements)
# Add the job to the given job store
store = self._lookup_jobstore(jobstore_alias)
try:
store.add_job(job)
except ConflictingIdError:
if replace_existing:
store.update_job(job)
else:
raise
# Mark the job as no longer pending
job._jobstore_alias = jobstore_alias
# Notify listeners that a new job has been added
event = JobEvent(EVENT_JOB_ADDED, job.id, jobstore_alias)
self._dispatch_event(event)
self._logger.info('Added job "%s" to job store "%s"', job.name, jobstore_alias)
# Notify the scheduler about the new job
if wakeup:
self.wakeup()
def _create_plugin_instance(self, type_, alias, constructor_kwargs):
"""Creates an instance of the given plugin type, loading the plugin first if necessary."""
plugin_container, class_container, base_class = {
'trigger': (self._trigger_plugins, self._trigger_classes, BaseTrigger),
'jobstore': (self._jobstore_plugins, self._jobstore_classes, BaseJobStore),
'executor': (self._executor_plugins, self._executor_classes, BaseExecutor)
}[type_]
try:
plugin_cls = class_container[alias]
except KeyError:
if alias in plugin_container:
plugin_cls = class_container[alias] = plugin_container[alias].load()
if not issubclass(plugin_cls, base_class):
raise TypeError('The {0} entry point does not point to a {0} class'.format(type_))
else:
raise LookupError('No {0} by the name "{1}" was found'.format(type_, alias))
return plugin_cls(**constructor_kwargs)
def _create_trigger(self, trigger, trigger_args):
if isinstance(trigger, BaseTrigger):
return trigger
elif trigger is None:
trigger = 'date'
elif not isinstance(trigger, six.string_types):
raise TypeError('Expected a trigger instance or string, got %s instead' % trigger.__class__.__name__)
# Use the scheduler's time zone if nothing else is specified
trigger_args.setdefault('timezone', self.timezone)
# Instantiate the trigger class
return self._create_plugin_instance('trigger', trigger, trigger_args)
def _create_lock(self):
"""Creates a reentrant lock object."""
return RLock()
def _process_jobs(self):
"""
Iterates through jobs in every jobstore, starts jobs that are due and figures out how long to wait for the next
round.
"""
self._logger.debug('Looking for jobs to run')
now = datetime.now(self.timezone)
next_wakeup_time = None
with self._jobstores_lock:
for jobstore_alias, jobstore in six.iteritems(self._jobstores):
for job in jobstore.get_due_jobs(now):
# Look up the job's executor
try:
executor = self._lookup_executor(job.executor)
except:
self._logger.error(
'Executor lookup ("%s") failed for job "%s" -- removing it from the job store',
job.executor, job)
self.remove_job(job.id, jobstore_alias)
continue
run_times = job._get_run_times(now)
run_times = run_times[-1:] if run_times and job.coalesce else run_times
if run_times:
try:
executor.submit_job(job, run_times)
except MaxInstancesReachedError:
self._logger.warning(
'Execution of job "%s" skipped: maximum number of running instances reached (%d)',
job, job.max_instances)
except:
self._logger.exception('Error submitting job "%s" to executor "%s"', job, job.executor)
# Update the job if it has a next execution time. Otherwise remove it from the job store.
job_next_run = job.trigger.get_next_fire_time(run_times[-1], now)
if job_next_run:
job._modify(next_run_time=job_next_run)
jobstore.update_job(job)
else:
self.remove_job(job.id, jobstore_alias)
# Set a new next wakeup time if there isn't one yet or the jobstore has an even earlier one
jobstore_next_run_time = jobstore.get_next_run_time()
if jobstore_next_run_time and (next_wakeup_time is None or jobstore_next_run_time < next_wakeup_time):
next_wakeup_time = jobstore_next_run_time
# Determine the delay until this method should be called again
if next_wakeup_time is not None:
wait_seconds = max(timedelta_seconds(next_wakeup_time - now), 0)
self._logger.debug('Next wakeup is due at %s (in %f seconds)', next_wakeup_time, wait_seconds)
else:
wait_seconds = None
self._logger.debug('No jobs; waiting until a job is added')
return wait_seconds
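A minimal usage sketch of the public API above (not part of the library source), assuming the stock BlockingScheduler subclass; the tick() job and listener are purely illustrative:

from datetime import datetime

from apscheduler.events import EVENT_JOB_ADDED
from apscheduler.schedulers.blocking import BlockingScheduler


def tick():
    print('The time is %s' % datetime.now())


def on_job_added(event):
    # JobEvent carries the job's id and the alias of its job store
    print('Job %s added to job store "%s"' % (event.job_id, event.jobstore))


scheduler = BlockingScheduler()
scheduler.add_listener(on_job_added, EVENT_JOB_ADDED)
scheduler.add_job(tick, 'interval', seconds=5, id='tick')
scheduler.start()  # blocks until shutdown() is called from elsewhere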


@@ -0,0 +1,32 @@
from __future__ import absolute_import
from threading import Event
from apscheduler.schedulers.base import BaseScheduler
class BlockingScheduler(BaseScheduler):
"""
A scheduler that runs in the foreground (:meth:`~apscheduler.schedulers.base.BaseScheduler.start` will block).
"""
MAX_WAIT_TIME = 4294967 # Maximum value accepted by Event.wait() on Windows
_event = None
def start(self):
super(BlockingScheduler, self).start()
self._event = Event()
self._main_loop()
def shutdown(self, wait=True):
super(BlockingScheduler, self).shutdown(wait)
self._event.set()
def _main_loop(self):
while self.running:
wait_seconds = self._process_jobs()
self._event.wait(wait_seconds if wait_seconds is not None else self.MAX_WAIT_TIME)
self._event.clear()
def wakeup(self):
self._event.set()
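Because start() blocks, shutdown() has to come from another thread of execution; a sketch, with the ten-second stop condition purely illustrative:

import threading
import time

from apscheduler.schedulers.blocking import BlockingScheduler

scheduler = BlockingScheduler()


def stop_after_ten_seconds():
    # shutdown() sets the internal Event, which releases _main_loop()
    time.sleep(10)
    scheduler.shutdown(wait=False)


threading.Thread(target=stop_after_ten_seconds).start()
scheduler.start()  # returns once shutdown() has been called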


@@ -0,0 +1,35 @@
from __future__ import absolute_import
from apscheduler.schedulers.blocking import BlockingScheduler
from apscheduler.schedulers.base import BaseScheduler
try:
from gevent.event import Event
from gevent.lock import RLock
import gevent
except ImportError: # pragma: nocover
raise ImportError('GeventScheduler requires gevent installed')
class GeventScheduler(BlockingScheduler):
"""A scheduler that runs as a Gevent greenlet."""
_greenlet = None
def start(self):
BaseScheduler.start(self)
self._event = Event()
self._greenlet = gevent.spawn(self._main_loop)
return self._greenlet
def shutdown(self, wait=True):
super(GeventScheduler, self).shutdown(wait)
self._greenlet.join()
del self._greenlet
def _create_lock(self):
return RLock()
def _create_default_executor(self):
from apscheduler.executors.gevent import GeventExecutor
return GeventExecutor()
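A usage sketch: unlike BlockingScheduler, start() here returns the main-loop greenlet, which the caller is expected to join:

from apscheduler.schedulers.gevent import GeventScheduler

scheduler = GeventScheduler()
greenlet = scheduler.start()  # spawns the main loop and returns immediately
try:
    greenlet.join()  # block until the scheduler is shut down
except (KeyboardInterrupt, SystemExit):
    scheduler.shutdown()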


@@ -0,0 +1,46 @@
from __future__ import absolute_import
from apscheduler.schedulers.base import BaseScheduler
try:
from PyQt5.QtCore import QObject, QTimer
except ImportError: # pragma: nocover
try:
from PyQt4.QtCore import QObject, QTimer
except ImportError:
try:
from PySide.QtCore import QObject, QTimer # flake8: noqa
except ImportError:
raise ImportError('QtScheduler requires either PyQt5, PyQt4 or PySide installed')
class QtScheduler(BaseScheduler):
"""A scheduler that runs in a Qt event loop."""
_timer = None
def start(self):
super(QtScheduler, self).start()
self.wakeup()
def shutdown(self, wait=True):
super(QtScheduler, self).shutdown(wait)
self._stop_timer()
def _start_timer(self, wait_seconds):
    self._stop_timer()
    if wait_seconds is not None:
        # QTimer.singleShot() returns None, so keep a real QTimer instance
        # around that _stop_timer() can inspect and cancel
        self._timer = QTimer()
        self._timer.setSingleShot(True)
        self._timer.timeout.connect(self._process_jobs)
        self._timer.start(int(wait_seconds * 1000))
def _stop_timer(self):
if self._timer:
if self._timer.isActive():
self._timer.stop()
del self._timer
def wakeup(self):
self._start_timer(0)
def _process_jobs(self):
wait_seconds = super(QtScheduler, self)._process_jobs()
self._start_timer(wait_seconds)
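A usage sketch: the scheduler piggybacks on the Qt event loop, so a QApplication must be running (PyQt4 shown purely as an example; PyQt5 and PySide work the same way):

import sys

from PyQt4.QtGui import QApplication

from apscheduler.schedulers.qt import QtScheduler


def tick():
    print('tick')


app = QApplication(sys.argv)
scheduler = QtScheduler()
scheduler.add_job(tick, 'interval', seconds=10)
scheduler.start()
sys.exit(app.exec_())  # jobs fire from QTimer callbacks inside this loop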


@@ -0,0 +1,60 @@
from __future__ import absolute_import
from datetime import timedelta
from functools import wraps
from apscheduler.schedulers.base import BaseScheduler
from apscheduler.util import maybe_ref
try:
from tornado.ioloop import IOLoop
except ImportError: # pragma: nocover
raise ImportError('TornadoScheduler requires tornado installed')
def run_in_ioloop(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
self._ioloop.add_callback(func, self, *args, **kwargs)
return wrapper
class TornadoScheduler(BaseScheduler):
"""
A scheduler that runs on a Tornado IOLoop.
=========== ===============================================================
``io_loop`` Tornado IOLoop instance to use (defaults to the global IO loop)
=========== ===============================================================
"""
_ioloop = None
_timeout = None
def start(self):
super(TornadoScheduler, self).start()
self.wakeup()
@run_in_ioloop
def shutdown(self, wait=True):
super(TornadoScheduler, self).shutdown(wait)
self._stop_timer()
def _configure(self, config):
self._ioloop = maybe_ref(config.pop('io_loop', None)) or IOLoop.current()
super(TornadoScheduler, self)._configure(config)
def _start_timer(self, wait_seconds):
self._stop_timer()
if wait_seconds is not None:
self._timeout = self._ioloop.add_timeout(timedelta(seconds=wait_seconds), self.wakeup)
def _stop_timer(self):
if self._timeout:
self._ioloop.remove_timeout(self._timeout)
del self._timeout
@run_in_ioloop
def wakeup(self):
self._stop_timer()
wait_seconds = self._process_jobs()
self._start_timer(wait_seconds)
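A usage sketch: the scheduler registers its wakeup timeouts on the IOLoop, but the loop itself still has to be started by the caller:

from tornado.ioloop import IOLoop

from apscheduler.schedulers.tornado import TornadoScheduler

scheduler = TornadoScheduler()
scheduler.start()
try:
    IOLoop.current().start()  # wakeup() runs _process_jobs() via add_timeout()
except (KeyboardInterrupt, SystemExit):
    pass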


@@ -0,0 +1,65 @@
from __future__ import absolute_import
from functools import wraps
from apscheduler.schedulers.base import BaseScheduler
from apscheduler.util import maybe_ref
try:
from twisted.internet import reactor as default_reactor
except ImportError: # pragma: nocover
raise ImportError('TwistedScheduler requires Twisted installed')
def run_in_reactor(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
self._reactor.callFromThread(func, self, *args, **kwargs)
return wrapper
class TwistedScheduler(BaseScheduler):
"""
A scheduler that runs on a Twisted reactor.
Extra options:
=========== ========================================================
``reactor`` Reactor instance to use (defaults to the global reactor)
=========== ========================================================
"""
_reactor = None
_delayedcall = None
def _configure(self, config):
self._reactor = maybe_ref(config.pop('reactor', default_reactor))
super(TwistedScheduler, self)._configure(config)
def start(self):
super(TwistedScheduler, self).start()
self.wakeup()
@run_in_reactor
def shutdown(self, wait=True):
super(TwistedScheduler, self).shutdown(wait)
self._stop_timer()
def _start_timer(self, wait_seconds):
self._stop_timer()
if wait_seconds is not None:
self._delayedcall = self._reactor.callLater(wait_seconds, self.wakeup)
def _stop_timer(self):
if self._delayedcall and self._delayedcall.active():
self._delayedcall.cancel()
del self._delayedcall
@run_in_reactor
def wakeup(self):
self._stop_timer()
wait_seconds = self._process_jobs()
self._start_timer(wait_seconds)
def _create_default_executor(self):
from apscheduler.executors.twisted import TwistedExecutor
return TwistedExecutor()
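A usage sketch: as with the Tornado variant, the caller owns the event loop:

from twisted.internet import reactor

from apscheduler.schedulers.twisted import TwistedScheduler

scheduler = TwistedScheduler()
scheduler.start()
reactor.run()  # wakeup() re-arms itself via reactor.callLater()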


@@ -0,0 +1,16 @@
from abc import ABCMeta, abstractmethod
import six
class BaseTrigger(six.with_metaclass(ABCMeta)):
"""Abstract base class that defines the interface that every trigger must implement."""
@abstractmethod
def get_next_fire_time(self, previous_fire_time, now):
"""
Returns the next datetime to fire on. If no such datetime can be calculated, returns ``None``.
:param datetime.datetime previous_fire_time: the previous time the trigger was fired
:param datetime.datetime now: current datetime
"""


@@ -0,0 +1,176 @@
from datetime import datetime, timedelta
from tzlocal import get_localzone
import six
from apscheduler.triggers.base import BaseTrigger
from apscheduler.triggers.cron.fields import BaseField, WeekField, DayOfMonthField, DayOfWeekField, DEFAULT_VALUES
from apscheduler.util import datetime_ceil, convert_to_datetime, datetime_repr, astimezone
class CronTrigger(BaseTrigger):
"""
Triggers when current time matches all specified time constraints, similarly to how the UNIX cron scheduler works.
:param int|str year: 4-digit year
:param int|str month: month (1-12)
:param int|str day: day of the month (1-31)
:param int|str week: ISO week (1-53)
:param int|str day_of_week: number or name of weekday (0-6 or mon,tue,wed,thu,fri,sat,sun)
:param int|str hour: hour (0-23)
:param int|str minute: minute (0-59)
:param int|str second: second (0-59)
:param datetime|str start_date: earliest possible date/time to trigger on (inclusive)
:param datetime|str end_date: latest possible date/time to trigger on (inclusive)
:param datetime.tzinfo|str timezone: time zone to use for the date/time calculations
(defaults to scheduler timezone)
.. note:: The first weekday is always **monday**.
"""
FIELD_NAMES = ('year', 'month', 'day', 'week', 'day_of_week', 'hour', 'minute', 'second')
FIELDS_MAP = {
'year': BaseField,
'month': BaseField,
'week': WeekField,
'day': DayOfMonthField,
'day_of_week': DayOfWeekField,
'hour': BaseField,
'minute': BaseField,
'second': BaseField
}
__slots__ = 'timezone', 'start_date', 'end_date', 'fields'
def __init__(self, year=None, month=None, day=None, week=None, day_of_week=None, hour=None, minute=None,
second=None, start_date=None, end_date=None, timezone=None):
if timezone:
self.timezone = astimezone(timezone)
elif start_date and start_date.tzinfo:
self.timezone = start_date.tzinfo
elif end_date and end_date.tzinfo:
self.timezone = end_date.tzinfo
else:
self.timezone = get_localzone()
self.start_date = convert_to_datetime(start_date, self.timezone, 'start_date')
self.end_date = convert_to_datetime(end_date, self.timezone, 'end_date')
values = dict((key, value) for (key, value) in six.iteritems(locals())
if key in self.FIELD_NAMES and value is not None)
self.fields = []
assign_defaults = False
for field_name in self.FIELD_NAMES:
if field_name in values:
exprs = values.pop(field_name)
is_default = False
assign_defaults = not values
elif assign_defaults:
exprs = DEFAULT_VALUES[field_name]
is_default = True
else:
exprs = '*'
is_default = True
field_class = self.FIELDS_MAP[field_name]
field = field_class(field_name, exprs, is_default)
self.fields.append(field)
def _increment_field_value(self, dateval, fieldnum):
"""
Increments the designated field and resets all less significant fields to their minimum values.
:type dateval: datetime
:type fieldnum: int
:return: a tuple containing the new date, and the number of the field that was actually incremented
:rtype: tuple
"""
values = {}
i = 0
while i < len(self.fields):
field = self.fields[i]
if not field.REAL:
if i == fieldnum:
fieldnum -= 1
i -= 1
else:
i += 1
continue
if i < fieldnum:
values[field.name] = field.get_value(dateval)
i += 1
elif i > fieldnum:
values[field.name] = field.get_min(dateval)
i += 1
else:
value = field.get_value(dateval)
maxval = field.get_max(dateval)
if value == maxval:
fieldnum -= 1
i -= 1
else:
values[field.name] = value + 1
i += 1
difference = datetime(**values) - dateval.replace(tzinfo=None)
return self.timezone.normalize(dateval + difference), fieldnum
def _set_field_value(self, dateval, fieldnum, new_value):
values = {}
for i, field in enumerate(self.fields):
if field.REAL:
if i < fieldnum:
values[field.name] = field.get_value(dateval)
elif i > fieldnum:
values[field.name] = field.get_min(dateval)
else:
values[field.name] = new_value
difference = datetime(**values) - dateval.replace(tzinfo=None)
return self.timezone.normalize(dateval + difference)
def get_next_fire_time(self, previous_fire_time, now):
if previous_fire_time:
start_date = max(now, previous_fire_time + timedelta(microseconds=1))
else:
start_date = max(now, self.start_date) if self.start_date else now
fieldnum = 0
next_date = datetime_ceil(start_date).astimezone(self.timezone)
while 0 <= fieldnum < len(self.fields):
field = self.fields[fieldnum]
curr_value = field.get_value(next_date)
next_value = field.get_next_value(next_date)
if next_value is None:
# No valid value was found
next_date, fieldnum = self._increment_field_value(next_date, fieldnum - 1)
elif next_value > curr_value:
# A valid value, higher than the starting one, was found
if field.REAL:
next_date = self._set_field_value(next_date, fieldnum, next_value)
fieldnum += 1
else:
next_date, fieldnum = self._increment_field_value(next_date, fieldnum)
else:
# A valid value was found, no changes necessary
fieldnum += 1
# Return if the date has rolled past the end date
if self.end_date and next_date > self.end_date:
return None
if fieldnum >= 0:
return next_date
def __str__(self):
options = ["%s='%s'" % (f.name, f) for f in self.fields if not f.is_default]
return 'cron[%s]' % (', '.join(options))
def __repr__(self):
options = ["%s='%s'" % (f.name, f) for f in self.fields if not f.is_default]
if self.start_date:
options.append("start_date='%s'" % datetime_repr(self.start_date))
return '<%s (%s)>' % (self.__class__.__name__, ', '.join(options))
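A quick sanity check of the trigger in isolation (the dates are chosen arbitrarily for the example):

from datetime import datetime

from pytz import utc

from apscheduler.triggers.cron import CronTrigger

# Fire at 03:30 every Monday and Friday
trigger = CronTrigger(day_of_week='mon,fri', hour=3, minute=30, timezone=utc)
now = datetime(2016, 1, 29, 12, 0, tzinfo=utc)  # a Friday, already past 03:30
print(trigger.get_next_fire_time(None, now))  # 2016-02-01 03:30:00+00:00 (Monday)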


@@ -0,0 +1,188 @@
"""
This module contains the expressions applicable for CronTrigger's fields.
"""
from calendar import monthrange
import re
from apscheduler.util import asint
__all__ = ('AllExpression', 'RangeExpression', 'WeekdayRangeExpression', 'WeekdayPositionExpression',
'LastDayOfMonthExpression')
WEEKDAYS = ['mon', 'tue', 'wed', 'thu', 'fri', 'sat', 'sun']
class AllExpression(object):
value_re = re.compile(r'\*(?:/(?P<step>\d+))?$')
def __init__(self, step=None):
self.step = asint(step)
if self.step == 0:
raise ValueError('Increment must be higher than 0')
def get_next_value(self, date, field):
start = field.get_value(date)
minval = field.get_min(date)
maxval = field.get_max(date)
start = max(start, minval)
if not self.step:
next = start
else:
distance_to_next = (self.step - (start - minval)) % self.step
next = start + distance_to_next
if next <= maxval:
return next
def __str__(self):
if self.step:
return '*/%d' % self.step
return '*'
def __repr__(self):
return "%s(%s)" % (self.__class__.__name__, self.step)
class RangeExpression(AllExpression):
value_re = re.compile(
r'(?P<first>\d+)(?:-(?P<last>\d+))?(?:/(?P<step>\d+))?$')
def __init__(self, first, last=None, step=None):
AllExpression.__init__(self, step)
first = asint(first)
last = asint(last)
if last is None and step is None:
last = first
if last is not None and first > last:
raise ValueError('The minimum value in a range must not be higher than the maximum')
self.first = first
self.last = last
def get_next_value(self, date, field):
start = field.get_value(date)
minval = field.get_min(date)
maxval = field.get_max(date)
# Apply range limits
minval = max(minval, self.first)
if self.last is not None:
maxval = min(maxval, self.last)
start = max(start, minval)
if not self.step:
next = start
else:
distance_to_next = (self.step - (start - minval)) % self.step
next = start + distance_to_next
if next <= maxval:
return next
def __str__(self):
if self.last != self.first and self.last is not None:
range = '%d-%d' % (self.first, self.last)
else:
range = str(self.first)
if self.step:
return '%s/%d' % (range, self.step)
return range
def __repr__(self):
args = [str(self.first)]
if self.last != self.first and self.last is not None or self.step:
args.append(str(self.last))
if self.step:
args.append(str(self.step))
return "%s(%s)" % (self.__class__.__name__, ', '.join(args))
class WeekdayRangeExpression(RangeExpression):
value_re = re.compile(r'(?P<first>[a-z]+)(?:-(?P<last>[a-z]+))?', re.IGNORECASE)
def __init__(self, first, last=None):
try:
first_num = WEEKDAYS.index(first.lower())
except ValueError:
raise ValueError('Invalid weekday name "%s"' % first)
if last:
try:
last_num = WEEKDAYS.index(last.lower())
except ValueError:
raise ValueError('Invalid weekday name "%s"' % last)
else:
last_num = None
RangeExpression.__init__(self, first_num, last_num)
def __str__(self):
if self.last != self.first and self.last is not None:
return '%s-%s' % (WEEKDAYS[self.first], WEEKDAYS[self.last])
return WEEKDAYS[self.first]
def __repr__(self):
args = ["'%s'" % WEEKDAYS[self.first]]
if self.last != self.first and self.last is not None:
args.append("'%s'" % WEEKDAYS[self.last])
return "%s(%s)" % (self.__class__.__name__, ', '.join(args))
class WeekdayPositionExpression(AllExpression):
options = ['1st', '2nd', '3rd', '4th', '5th', 'last']
value_re = re.compile(r'(?P<option_name>%s) +(?P<weekday_name>(?:\d+|\w+))' % '|'.join(options), re.IGNORECASE)
def __init__(self, option_name, weekday_name):
try:
self.option_num = self.options.index(option_name.lower())
except ValueError:
raise ValueError('Invalid weekday position "%s"' % option_name)
try:
self.weekday = WEEKDAYS.index(weekday_name.lower())
except ValueError:
raise ValueError('Invalid weekday name "%s"' % weekday_name)
def get_next_value(self, date, field):
# Figure out the weekday of the month's first day and the number
# of days in that month
first_day_wday, last_day = monthrange(date.year, date.month)
# Calculate which day of the month is the first of the target weekdays
first_hit_day = self.weekday - first_day_wday + 1
if first_hit_day <= 0:
first_hit_day += 7
# Calculate what day of the month the target weekday would be
if self.option_num < 5:
target_day = first_hit_day + self.option_num * 7
else:
target_day = first_hit_day + ((last_day - first_hit_day) / 7) * 7
if target_day <= last_day and target_day >= date.day:
return target_day
def __str__(self):
return '%s %s' % (self.options[self.option_num], WEEKDAYS[self.weekday])
def __repr__(self):
return "%s('%s', '%s')" % (self.__class__.__name__, self.options[self.option_num], WEEKDAYS[self.weekday])
class LastDayOfMonthExpression(AllExpression):
value_re = re.compile(r'last', re.IGNORECASE)
def __init__(self):
pass
def get_next_value(self, date, field):
return monthrange(date.year, date.month)[1]
def __str__(self):
return 'last'
def __repr__(self):
return "%s()" % self.__class__.__name__


@@ -0,0 +1,97 @@
"""
Fields represent CronTrigger options which map to :class:`~datetime.datetime`
fields.
"""
from calendar import monthrange
from apscheduler.triggers.cron.expressions import (
AllExpression, RangeExpression, WeekdayPositionExpression, LastDayOfMonthExpression, WeekdayRangeExpression)
__all__ = ('MIN_VALUES', 'MAX_VALUES', 'DEFAULT_VALUES', 'BaseField', 'WeekField', 'DayOfMonthField', 'DayOfWeekField')
MIN_VALUES = {'year': 1970, 'month': 1, 'day': 1, 'week': 1, 'day_of_week': 0, 'hour': 0, 'minute': 0, 'second': 0}
MAX_VALUES = {'year': 2 ** 63, 'month': 12, 'day': 31, 'week': 53, 'day_of_week': 6, 'hour': 23, 'minute': 59,
'second': 59}
DEFAULT_VALUES = {'year': '*', 'month': 1, 'day': 1, 'week': '*', 'day_of_week': '*', 'hour': 0, 'minute': 0,
'second': 0}
class BaseField(object):
REAL = True
COMPILERS = [AllExpression, RangeExpression]
def __init__(self, name, exprs, is_default=False):
self.name = name
self.is_default = is_default
self.compile_expressions(exprs)
def get_min(self, dateval):
return MIN_VALUES[self.name]
def get_max(self, dateval):
return MAX_VALUES[self.name]
def get_value(self, dateval):
return getattr(dateval, self.name)
def get_next_value(self, dateval):
smallest = None
for expr in self.expressions:
value = expr.get_next_value(dateval, self)
if smallest is None or (value is not None and value < smallest):
smallest = value
return smallest
def compile_expressions(self, exprs):
self.expressions = []
# Split a comma-separated expression list, if any
exprs = str(exprs).strip()
if ',' in exprs:
for expr in exprs.split(','):
self.compile_expression(expr)
else:
self.compile_expression(exprs)
def compile_expression(self, expr):
for compiler in self.COMPILERS:
match = compiler.value_re.match(expr)
if match:
compiled_expr = compiler(**match.groupdict())
self.expressions.append(compiled_expr)
return
raise ValueError('Unrecognized expression "%s" for field "%s"' % (expr, self.name))
def __str__(self):
expr_strings = (str(e) for e in self.expressions)
return ','.join(expr_strings)
def __repr__(self):
return "%s('%s', '%s')" % (self.__class__.__name__, self.name, self)
class WeekField(BaseField):
REAL = False
def get_value(self, dateval):
return dateval.isocalendar()[1]
class DayOfMonthField(BaseField):
COMPILERS = BaseField.COMPILERS + [WeekdayPositionExpression, LastDayOfMonthExpression]
def get_max(self, dateval):
return monthrange(dateval.year, dateval.month)[1]
class DayOfWeekField(BaseField):
REAL = False
COMPILERS = BaseField.COMPILERS + [WeekdayRangeExpression]
def get_value(self, dateval):
return dateval.weekday()
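A sketch of the fields in isolation (the dates are arbitrary):

from datetime import datetime

from apscheduler.triggers.cron.fields import BaseField, DayOfWeekField

minute = BaseField('minute', '10-30/5')
print(minute.get_next_value(datetime(2016, 1, 29, 12, 22)))  # 25

day_of_week = DayOfWeekField('day_of_week', 'mon-fri')
print(day_of_week.get_value(datetime(2016, 1, 31)))  # 6 (a Sunday)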


@@ -0,0 +1,30 @@
from datetime import datetime
from tzlocal import get_localzone
from apscheduler.triggers.base import BaseTrigger
from apscheduler.util import convert_to_datetime, datetime_repr, astimezone
class DateTrigger(BaseTrigger):
"""
Triggers once on the given datetime. If ``run_date`` is left empty, current time is used.
:param datetime|str run_date: the date/time to run the job at
:param datetime.tzinfo|str timezone: time zone for ``run_date`` if it doesn't have one already
"""
__slots__ = 'timezone', 'run_date'
def __init__(self, run_date=None, timezone=None):
timezone = astimezone(timezone) or get_localzone()
self.run_date = convert_to_datetime(run_date or datetime.now(), timezone, 'run_date')
def get_next_fire_time(self, previous_fire_time, now):
return self.run_date if previous_fire_time is None else None
def __str__(self):
return 'date[%s]' % datetime_repr(self.run_date)
def __repr__(self):
return "<%s (run_date='%s')>" % (self.__class__.__name__, datetime_repr(self.run_date))


@@ -0,0 +1,65 @@
from datetime import timedelta, datetime
from math import ceil
from tzlocal import get_localzone
from apscheduler.triggers.base import BaseTrigger
from apscheduler.util import convert_to_datetime, timedelta_seconds, datetime_repr, astimezone
class IntervalTrigger(BaseTrigger):
"""
Triggers on specified intervals, starting on ``start_date`` if specified, ``datetime.now()`` + interval
otherwise.
:param int weeks: number of weeks to wait
:param int days: number of days to wait
:param int hours: number of hours to wait
:param int minutes: number of minutes to wait
:param int seconds: number of seconds to wait
:param datetime|str start_date: starting point for the interval calculation
:param datetime|str end_date: latest possible date/time to trigger on
:param datetime.tzinfo|str timezone: time zone to use for the date/time calculations
"""
__slots__ = 'timezone', 'start_date', 'end_date', 'interval'
def __init__(self, weeks=0, days=0, hours=0, minutes=0, seconds=0, start_date=None, end_date=None, timezone=None):
self.interval = timedelta(weeks=weeks, days=days, hours=hours, minutes=minutes, seconds=seconds)
self.interval_length = timedelta_seconds(self.interval)
if self.interval_length == 0:
self.interval = timedelta(seconds=1)
self.interval_length = 1
if timezone:
self.timezone = astimezone(timezone)
elif start_date and start_date.tzinfo:
self.timezone = start_date.tzinfo
elif end_date and end_date.tzinfo:
self.timezone = end_date.tzinfo
else:
self.timezone = get_localzone()
start_date = start_date or (datetime.now(self.timezone) + self.interval)
self.start_date = convert_to_datetime(start_date, self.timezone, 'start_date')
self.end_date = convert_to_datetime(end_date, self.timezone, 'end_date')
def get_next_fire_time(self, previous_fire_time, now):
if previous_fire_time:
next_fire_time = previous_fire_time + self.interval
elif self.start_date > now:
next_fire_time = self.start_date
else:
timediff_seconds = timedelta_seconds(now - self.start_date)
next_interval_num = int(ceil(timediff_seconds / self.interval_length))
next_fire_time = self.start_date + self.interval * next_interval_num
if not self.end_date or next_fire_time <= self.end_date:
return self.timezone.normalize(next_fire_time)
def __str__(self):
return 'interval[%s]' % str(self.interval)
def __repr__(self):
return "<%s (interval=%r, start_date='%s')>" % (self.__class__.__name__, self.interval,
datetime_repr(self.start_date))
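A sketch showing how the next fire time snaps to the interval grid anchored at ``start_date`` (dates arbitrary):

from datetime import datetime

from pytz import utc

from apscheduler.triggers.interval import IntervalTrigger

trigger = IntervalTrigger(hours=6, start_date='2016-01-29 00:00:00', timezone=utc)
now = datetime(2016, 1, 29, 7, 0, tzinfo=utc)  # one hour past the 06:00 slot
print(trigger.get_next_fire_time(None, now))  # 2016-01-29 12:00:00+00:00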

lib/apscheduler/util.py Normal file

@@ -0,0 +1,385 @@
"""This module contains several handy functions primarily meant for internal use."""
from __future__ import division
from datetime import date, datetime, time, timedelta, tzinfo
from inspect import isfunction, ismethod, getargspec
from calendar import timegm
import re
from pytz import timezone, utc
import six
try:
from inspect import signature
except ImportError: # pragma: nocover
try:
from funcsigs import signature
except ImportError:
signature = None
__all__ = ('asint', 'asbool', 'astimezone', 'convert_to_datetime', 'datetime_to_utc_timestamp',
'utc_timestamp_to_datetime', 'timedelta_seconds', 'datetime_ceil', 'get_callable_name', 'obj_to_ref',
'ref_to_obj', 'maybe_ref', 'repr_escape', 'check_callable_args')
class _Undefined(object):
def __nonzero__(self):
return False
def __bool__(self):
return False
def __repr__(self):
return '<undefined>'
undefined = _Undefined() #: a unique object that only signifies that no value is defined
def asint(text):
"""
Safely converts a string to an integer, returning None if the string is None.
:type text: str
:rtype: int
"""
if text is not None:
return int(text)
def asbool(obj):
"""
Interprets an object as a boolean value.
:rtype: bool
"""
if isinstance(obj, str):
obj = obj.strip().lower()
if obj in ('true', 'yes', 'on', 'y', 't', '1'):
return True
if obj in ('false', 'no', 'off', 'n', 'f', '0'):
return False
raise ValueError('Unable to interpret value "%s" as boolean' % obj)
return bool(obj)
def astimezone(obj):
"""
Interprets an object as a timezone.
:rtype: tzinfo
"""
if isinstance(obj, six.string_types):
return timezone(obj)
if isinstance(obj, tzinfo):
if not hasattr(obj, 'localize') or not hasattr(obj, 'normalize'):
raise TypeError('Only timezones from the pytz library are supported')
if obj.zone == 'local':
raise ValueError('Unable to determine the name of the local timezone -- use an explicit timezone instead')
return obj
if obj is not None:
raise TypeError('Expected tzinfo, got %s instead' % obj.__class__.__name__)
_DATE_REGEX = re.compile(
r'(?P<year>\d{4})-(?P<month>\d{1,2})-(?P<day>\d{1,2})'
r'(?: (?P<hour>\d{1,2}):(?P<minute>\d{1,2}):(?P<second>\d{1,2})'
r'(?:\.(?P<microsecond>\d{1,6}))?)?')
def convert_to_datetime(input, tz, arg_name):
"""
Converts the given object to a timezone aware datetime object.
If a timezone aware datetime object is passed, it is returned unmodified.
If a naive datetime object is passed, it is given the specified timezone.
If the input is a string, it is parsed as a datetime with the given timezone.
Date strings are accepted in three different forms: date only (Y-m-d), date with time
(Y-m-d H:M:S), or date and time with microseconds (Y-m-d H:M:S.micro).
:param str|datetime input: the datetime or string to convert to a timezone aware datetime
:param datetime.tzinfo tz: timezone to interpret ``input`` in
:param str arg_name: the name of the argument (used in an error message)
:rtype: datetime
"""
if input is None:
return
elif isinstance(input, datetime):
datetime_ = input
elif isinstance(input, date):
datetime_ = datetime.combine(input, time())
elif isinstance(input, six.string_types):
m = _DATE_REGEX.match(input)
if not m:
raise ValueError('Invalid date string')
values = [(k, int(v or 0)) for k, v in m.groupdict().items()]
values = dict(values)
datetime_ = datetime(**values)
else:
raise TypeError('Unsupported type for %s: %s' % (arg_name, input.__class__.__name__))
if datetime_.tzinfo is not None:
return datetime_
if tz is None:
raise ValueError('The "tz" argument must be specified if %s has no timezone information' % arg_name)
if isinstance(tz, six.string_types):
tz = timezone(tz)
try:
return tz.localize(datetime_, is_dst=None)
except AttributeError:
raise TypeError('Only pytz timezones are supported (need the localize() and normalize() methods)')
def datetime_to_utc_timestamp(timeval):
"""
Converts a datetime instance to a timestamp.
:type timeval: datetime
:rtype: float
"""
if timeval is not None:
return timegm(timeval.utctimetuple()) + timeval.microsecond / 1000000
def utc_timestamp_to_datetime(timestamp):
"""
Converts the given timestamp to a datetime instance.
:type timestamp: float
:rtype: datetime
"""
if timestamp is not None:
return datetime.fromtimestamp(timestamp, utc)
def timedelta_seconds(delta):
"""
Converts the given timedelta to seconds.
:type delta: timedelta
:rtype: float
"""
return delta.days * 24 * 60 * 60 + delta.seconds + \
delta.microseconds / 1000000.0
def datetime_ceil(dateval):
"""
Rounds the given datetime object upwards.
:type dateval: datetime
"""
if dateval.microsecond > 0:
return dateval + timedelta(seconds=1, microseconds=-dateval.microsecond)
return dateval
def datetime_repr(dateval):
return dateval.strftime('%Y-%m-%d %H:%M:%S %Z') if dateval else 'None'
def get_callable_name(func):
"""
Returns the best available display name for the given function/callable.
:rtype: str
"""
# the easy case (on Python 3.3+)
if hasattr(func, '__qualname__'):
return func.__qualname__
# class methods, bound and unbound methods
f_self = getattr(func, '__self__', None) or getattr(func, 'im_self', None)
if f_self and hasattr(func, '__name__'):
f_class = f_self if isinstance(f_self, type) else f_self.__class__
else:
f_class = getattr(func, 'im_class', None)
if f_class and hasattr(func, '__name__'):
return '%s.%s' % (f_class.__name__, func.__name__)
# class or class instance
if hasattr(func, '__call__'):
# class
if hasattr(func, '__name__'):
return func.__name__
# instance of a class with a __call__ method
return func.__class__.__name__
raise TypeError('Unable to determine a name for %r -- maybe it is not a callable?' % func)
def obj_to_ref(obj):
"""
Returns the path to the given object.
:rtype: str
"""
try:
ref = '%s:%s' % (obj.__module__, get_callable_name(obj))
obj2 = ref_to_obj(ref)
if obj != obj2:
raise ValueError
except Exception:
raise ValueError('Cannot determine the reference to %r' % obj)
return ref
def ref_to_obj(ref):
"""
Returns the object pointed to by ``ref``.
:type ref: str
"""
if not isinstance(ref, six.string_types):
raise TypeError('References must be strings')
if ':' not in ref:
raise ValueError('Invalid reference')
modulename, rest = ref.split(':', 1)
try:
obj = __import__(modulename)
except ImportError:
raise LookupError('Error resolving reference %s: could not import module' % ref)
try:
for name in modulename.split('.')[1:] + rest.split('.'):
obj = getattr(obj, name)
return obj
except Exception:
raise LookupError('Error resolving reference %s: error looking up object' % ref)
def maybe_ref(ref):
"""
Returns the object that the given reference points to, if it is indeed a reference.
If it is not a reference, the object is returned as-is.
"""
if not isinstance(ref, str):
return ref
return ref_to_obj(ref)
if six.PY2:
def repr_escape(string):
if isinstance(string, six.text_type):
return string.encode('ascii', 'backslashreplace')
return string
else:
repr_escape = lambda string: string
def check_callable_args(func, args, kwargs):
"""
Ensures that the given callable can be called with the given arguments.
:type args: tuple
:type kwargs: dict
"""
pos_kwargs_conflicts = [] # parameters that have a match in both args and kwargs
positional_only_kwargs = [] # positional-only parameters that have a match in kwargs
unsatisfied_args = [] # parameters in signature that don't have a match in args or kwargs
unsatisfied_kwargs = [] # keyword-only arguments that don't have a match in kwargs
unmatched_args = list(args) # args that didn't match any of the parameters in the signature
unmatched_kwargs = list(kwargs) # kwargs that didn't match any of the parameters in the signature
has_varargs = has_var_kwargs = False # indicates if the signature defines *args and **kwargs respectively
if signature:
try:
sig = signature(func)
except ValueError:
return # signature() doesn't work against every kind of callable
for param in six.itervalues(sig.parameters):
if param.kind == param.POSITIONAL_OR_KEYWORD:
if param.name in unmatched_kwargs and unmatched_args:
pos_kwargs_conflicts.append(param.name)
elif unmatched_args:
del unmatched_args[0]
elif param.name in unmatched_kwargs:
unmatched_kwargs.remove(param.name)
elif param.default is param.empty:
unsatisfied_args.append(param.name)
elif param.kind == param.POSITIONAL_ONLY:
if unmatched_args:
del unmatched_args[0]
elif param.name in unmatched_kwargs:
unmatched_kwargs.remove(param.name)
positional_only_kwargs.append(param.name)
elif param.default is param.empty:
unsatisfied_args.append(param.name)
elif param.kind == param.KEYWORD_ONLY:
if param.name in unmatched_kwargs:
unmatched_kwargs.remove(param.name)
elif param.default is param.empty:
unsatisfied_kwargs.append(param.name)
elif param.kind == param.VAR_POSITIONAL:
has_varargs = True
elif param.kind == param.VAR_KEYWORD:
has_var_kwargs = True
else:
if not isfunction(func) and not ismethod(func) and hasattr(func, '__call__'):
func = func.__call__
try:
argspec = getargspec(func)
except TypeError:
return  # getargspec() doesn't work with certain callables
argspec_args = argspec.args if not ismethod(func) else argspec.args[1:]
has_varargs = bool(argspec.varargs)
has_var_kwargs = bool(argspec.keywords)
for arg, default in six.moves.zip_longest(argspec_args, argspec.defaults or (), fillvalue=undefined):
if arg in unmatched_kwargs and unmatched_args:
pos_kwargs_conflicts.append(arg)
elif unmatched_args:
del unmatched_args[0]
elif arg in unmatched_kwargs:
unmatched_kwargs.remove(arg)
elif default is undefined:
unsatisfied_args.append(arg)
# Make sure there are no conflicts between args and kwargs
if pos_kwargs_conflicts:
raise ValueError('The following arguments are supplied in both args and kwargs: %s' %
', '.join(pos_kwargs_conflicts))
# Check if keyword arguments are being fed to positional-only parameters
if positional_only_kwargs:
raise ValueError('The following arguments cannot be given as keyword arguments: %s' %
', '.join(positional_only_kwargs))
# Check that the number of positional arguments minus the number of matched kwargs matches the argspec
if unsatisfied_args:
raise ValueError('The following arguments have not been supplied: %s' % ', '.join(unsatisfied_args))
# Check that all keyword-only arguments have been supplied
if unsatisfied_kwargs:
raise ValueError('The following keyword-only arguments have not been supplied in kwargs: %s' %
', '.join(unsatisfied_kwargs))
# Check that the callable can accept the given number of positional arguments
if not has_varargs and unmatched_args:
raise ValueError('The list of positional arguments is longer than the target callable can handle '
'(allowed: %d, given in args: %d)' % (len(args) - len(unmatched_args), len(args)))
# Check that the callable can accept the given keyword arguments
if not has_var_kwargs and unmatched_kwargs:
raise ValueError('The target callable does not accept the following keyword arguments: %s' %
', '.join(unmatched_kwargs))
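A few of the helpers above in action (the 'run_date' argument name is only used in error messages):

from pytz import timezone

from apscheduler.util import convert_to_datetime, obj_to_ref, ref_to_obj, timedelta_seconds

eastern = timezone('US/Eastern')
print(convert_to_datetime('2016-01-29 15:38:12', eastern, 'run_date'))
# 2016-01-29 15:38:12-05:00

ref = obj_to_ref(timedelta_seconds)
print(ref)  # apscheduler.util:timedelta_seconds
assert ref_to_obj(ref) is timedelta_seconds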

lib/argparse.py Normal file

File diff suppressed because it is too large

lib/beets/LICENSE Normal file

@@ -0,0 +1,21 @@
The MIT License
Copyright (c) 2010-2014 Adrian Sampson
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

lib/beets/README.rst Normal file

@@ -0,0 +1,94 @@
.. image:: https://travis-ci.org/sampsyo/beets.svg?branch=master
:target: https://travis-ci.org/sampsyo/beets
.. image:: http://img.shields.io/coveralls/sampsyo/beets.svg
:target: https://coveralls.io/r/sampsyo/beets
.. image:: http://img.shields.io/pypi/v/beets.svg
:target: https://pypi.python.org/pypi/beets
Beets is the media library management system for obsessive-compulsive music
geeks.
The purpose of beets is to get your music collection right once and for all.
It catalogs your collection, automatically improving its metadata as it goes.
It then provides a bouquet of tools for manipulating and accessing your music.
Here's an example of beets' brainy tag corrector doing its thing::
  $ beet import ~/music/ladytron
  Tagging:
      Ladytron - Witching Hour
      (Similarity: 98.4%)
   * Last One Standing -> The Last One Standing
   * Beauty -> Beauty*2
   * White Light Generation -> Whitelightgenerator
   * All the Way -> All the Way...
Because beets is designed as a library, it can do almost anything you can
imagine for your music collection. Via `plugins`_, beets becomes a panacea:
- Fetch or calculate all the metadata you could possibly need: `album art`_,
`lyrics`_, `genres`_, `tempos`_, `ReplayGain`_ levels, or `acoustic
fingerprints`_.
- Get metadata from `MusicBrainz`_, `Discogs`_, or `Beatport`_. Or guess
metadata using songs' filenames or their acoustic fingerprints.
- `Transcode audio`_ to any format you like.
- Check your library for `duplicate tracks and albums`_ or for `albums that
are missing tracks`_.
- Clean up crufty tags left behind by other, less-awesome tools.
- Embed and extract album art from files' metadata.
- Browse your music library graphically through a Web browser and play it in any
browser that supports `HTML5 Audio`_.
- Analyze music files' metadata from the command line.
- Listen to your library with a music player that speaks the `MPD`_ protocol
and works with a staggering variety of interfaces.
If beets doesn't do what you want yet, `writing your own plugin`_ is
shockingly simple if you know a little Python.
.. _plugins: http://beets.readthedocs.org/page/plugins/
.. _MPD: http://www.musicpd.org/
.. _MusicBrainz music collection: http://musicbrainz.org/doc/Collections/
.. _writing your own plugin:
http://beets.readthedocs.org/page/dev/plugins.html
.. _HTML5 Audio:
http://www.w3.org/TR/html-markup/audio.html
.. _albums that are missing tracks:
http://beets.readthedocs.org/page/plugins/missing.html
.. _duplicate tracks and albums:
http://beets.readthedocs.org/page/plugins/duplicates.html
.. _Transcode audio:
http://beets.readthedocs.org/page/plugins/convert.html
.. _Beatport: http://www.beatport.com/
.. _Discogs: http://www.discogs.com/
.. _acoustic fingerprints:
http://beets.readthedocs.org/page/plugins/chroma.html
.. _ReplayGain: http://beets.readthedocs.org/page/plugins/replaygain.html
.. _tempos: http://beets.readthedocs.org/page/plugins/echonest.html
.. _genres: http://beets.readthedocs.org/page/plugins/lastgenre.html
.. _album art: http://beets.readthedocs.org/page/plugins/fetchart.html
.. _lyrics: http://beets.readthedocs.org/page/plugins/lyrics.html
.. _MusicBrainz: http://musicbrainz.org/
Read More
---------
Learn more about beets at `its Web site`_. Follow `@b33ts`_ on Twitter for
news and updates.
You can install beets by typing ``pip install beets``. Then check out the
`Getting Started`_ guide.
.. _its Web site: http://beets.radbox.org/
.. _Getting Started: http://beets.readthedocs.org/page/guides/main.html
.. _@b33ts: http://twitter.com/b33ts/
Authors
-------
Beets is by `Adrian Sampson`_ with a supporting cast of thousands. For help,
please contact the `mailing list`_.
.. _mailing list: https://groups.google.com/forum/#!forum/beets-users
.. _Adrian Sampson: http://homes.cs.washington.edu/~asampson/

lib/beets/__init__.py Normal file

@@ -0,0 +1,28 @@
# This file is part of beets.
# Copyright 2014, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
# This particular version has been slightly modified to work with Headphones
# https://github.com/rembo10/headphones
__version__ = '1.3.10-headphones'
__author__ = 'Adrian Sampson <adrian@radbox.org>'
import os
import beets.library
from beets.util import confit
Library = beets.library.Library
config = confit.LazyConfig(os.path.dirname(__file__), __name__)

lib/beets/autotag/__init__.py Normal file

@@ -0,0 +1,136 @@
# This file is part of beets.
# Copyright 2013, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""Facilities for automatically determining files' correct metadata.
"""
import logging
from beets import config
# Parts of external interface.
from .hooks import AlbumInfo, TrackInfo, AlbumMatch, TrackMatch # noqa
from .match import tag_item, tag_album # noqa
from .match import Recommendation # noqa
# Global logger.
log = logging.getLogger('beets')
# Additional utilities for the main interface.
def apply_item_metadata(item, track_info):
"""Set an item's metadata from its matched TrackInfo object.
"""
item.artist = track_info.artist
item.artist_sort = track_info.artist_sort
item.artist_credit = track_info.artist_credit
item.title = track_info.title
item.mb_trackid = track_info.track_id
if track_info.artist_id:
item.mb_artistid = track_info.artist_id
# At the moment, the other metadata is left intact (including album
# and track number). Perhaps these should be emptied?
def apply_metadata(album_info, mapping):
"""Set the items' metadata to match an AlbumInfo object using a
mapping from Items to TrackInfo objects.
"""
for item, track_info in mapping.iteritems():
# Album, artist, track count.
if track_info.artist:
item.artist = track_info.artist
else:
item.artist = album_info.artist
item.albumartist = album_info.artist
item.album = album_info.album
# Artist sort and credit names.
item.artist_sort = track_info.artist_sort or album_info.artist_sort
item.artist_credit = (track_info.artist_credit or
album_info.artist_credit)
item.albumartist_sort = album_info.artist_sort
item.albumartist_credit = album_info.artist_credit
# Release date.
for prefix in '', 'original_':
if config['original_date'] and not prefix:
# Ignore specific release date.
continue
for suffix in 'year', 'month', 'day':
key = prefix + suffix
value = getattr(album_info, key) or 0
# If we don't even have a year, apply nothing.
if suffix == 'year' and not value:
break
# Otherwise, set the fetched value (or 0 for the month
# and day if not available).
item[key] = value
# If we're using original release date for both fields,
# also set item.year = info.original_year, etc.
if config['original_date']:
item[suffix] = value
# Title.
item.title = track_info.title
if config['per_disc_numbering']:
item.track = track_info.medium_index or track_info.index
item.tracktotal = track_info.medium_total or len(album_info.tracks)
else:
item.track = track_info.index
item.tracktotal = len(album_info.tracks)
# Disc and disc count.
item.disc = track_info.medium
item.disctotal = album_info.mediums
# MusicBrainz IDs.
item.mb_trackid = track_info.track_id
item.mb_albumid = album_info.album_id
if track_info.artist_id:
item.mb_artistid = track_info.artist_id
else:
item.mb_artistid = album_info.artist_id
item.mb_albumartistid = album_info.artist_id
item.mb_releasegroupid = album_info.releasegroup_id
# Compilation flag.
item.comp = album_info.va
# Miscellaneous metadata.
for field in ('albumtype',
'label',
'asin',
'catalognum',
'script',
'language',
'country',
'albumstatus',
'albumdisambig'):
value = getattr(album_info, field)
if value is not None:
item[field] = value
if track_info.disctitle is not None:
item.disctitle = track_info.disctitle
if track_info.media is not None:
item.media = track_info.media
# Headphones seal of approval
item.comments = 'tagged by headphones/beets'

579
lib/beets/autotag/hooks.py Normal file

@@ -0,0 +1,579 @@
# This file is part of beets.
# Copyright 2013, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""Glue between metadata sources and the matching logic."""
import logging
from collections import namedtuple
import re
from beets import plugins
from beets import config
from beets.autotag import mb
from beets.util import levenshtein
from unidecode import unidecode
log = logging.getLogger('beets')
# Classes used to represent candidate options.
class AlbumInfo(object):
"""Describes a canonical release that may be used to match a release
in the library. Consists of these data members:
- ``album``: the release title
- ``album_id``: MusicBrainz ID; UUID fragment only
- ``artist``: name of the release's primary artist
- ``artist_id``
- ``tracks``: list of TrackInfo objects making up the release
- ``asin``: Amazon ASIN
- ``albumtype``: string describing the kind of release
- ``va``: boolean: whether the release has "various artists"
- ``year``: release year
- ``month``: release month
- ``day``: release day
- ``label``: music label responsible for the release
- ``mediums``: the number of discs in this release
- ``artist_sort``: name of the release's artist for sorting
- ``releasegroup_id``: MBID for the album's release group
- ``catalognum``: the label's catalog number for the release
- ``script``: character set used for metadata
- ``language``: human language of the metadata
- ``country``: the release country
- ``albumstatus``: MusicBrainz release status (Official, etc.)
- ``media``: delivery mechanism (Vinyl, etc.)
- ``albumdisambig``: MusicBrainz release disambiguation comment
- ``artist_credit``: Release-specific artist name
- ``data_source``: The original data source (MusicBrainz, Discogs, etc.)
- ``data_url``: The data source release URL.
The fields up through ``tracks`` are required. The others are
optional and may be None.
"""
def __init__(self, album, album_id, artist, artist_id, tracks, asin=None,
albumtype=None, va=False, year=None, month=None, day=None,
label=None, mediums=None, artist_sort=None,
releasegroup_id=None, catalognum=None, script=None,
language=None, country=None, albumstatus=None, media=None,
albumdisambig=None, artist_credit=None, original_year=None,
original_month=None, original_day=None, data_source=None,
data_url=None):
self.album = album
self.album_id = album_id
self.artist = artist
self.artist_id = artist_id
self.tracks = tracks
self.asin = asin
self.albumtype = albumtype
self.va = va
self.year = year
self.month = month
self.day = day
self.label = label
self.mediums = mediums
self.artist_sort = artist_sort
self.releasegroup_id = releasegroup_id
self.catalognum = catalognum
self.script = script
self.language = language
self.country = country
self.albumstatus = albumstatus
self.media = media
self.albumdisambig = albumdisambig
self.artist_credit = artist_credit
self.original_year = original_year
self.original_month = original_month
self.original_day = original_day
self.data_source = data_source
self.data_url = data_url
# Work around a bug in python-musicbrainz-ngs that causes some
# strings to be bytes rather than Unicode.
# https://github.com/alastair/python-musicbrainz-ngs/issues/85
def decode(self, codec='utf8'):
"""Ensure that all string attributes on this object, and the
constituent `TrackInfo` objects, are decoded to Unicode.
"""
for fld in ['album', 'artist', 'albumtype', 'label', 'artist_sort',
'catalognum', 'script', 'language', 'country',
'albumstatus', 'albumdisambig', 'artist_credit', 'media']:
value = getattr(self, fld)
if isinstance(value, str):
setattr(self, fld, value.decode(codec, 'ignore'))
if self.tracks:
for track in self.tracks:
track.decode(codec)
class TrackInfo(object):
"""Describes a canonical track present on a release. Appears as part
of an AlbumInfo's ``tracks`` list. Consists of these data members:
- ``title``: name of the track
- ``track_id``: MusicBrainz ID; UUID fragment only
- ``artist``: individual track artist name
- ``artist_id``
- ``length``: float: duration of the track in seconds
- ``index``: position on the entire release
- ``media``: delivery mechanism (Vinyl, etc.)
- ``medium``: the disc number this track appears on in the album
- ``medium_index``: the track's position on the disc
- ``medium_total``: the number of tracks on the item's disc
- ``artist_sort``: name of the track artist for sorting
- ``disctitle``: name of the individual medium (subtitle)
- ``artist_credit``: Recording-specific artist name
Only ``title`` and ``track_id`` are required. The rest of the fields
may be None. The indices ``index``, ``medium``, and ``medium_index``
are all 1-based.
"""
def __init__(self, title, track_id, artist=None, artist_id=None,
length=None, index=None, medium=None, medium_index=None,
medium_total=None, artist_sort=None, disctitle=None,
artist_credit=None, data_source=None, data_url=None,
media=None):
self.title = title
self.track_id = track_id
self.artist = artist
self.artist_id = artist_id
self.length = length
self.index = index
self.media = media
self.medium = medium
self.medium_index = medium_index
self.medium_total = medium_total
self.artist_sort = artist_sort
self.disctitle = disctitle
self.artist_credit = artist_credit
self.data_source = data_source
self.data_url = data_url
# As above, work around a bug in python-musicbrainz-ngs.
def decode(self, codec='utf8'):
"""Ensure that all string attributes on this object are decoded
to Unicode.
"""
for fld in ['title', 'artist', 'medium', 'artist_sort', 'disctitle',
'artist_credit', 'media']:
value = getattr(self, fld)
if isinstance(value, str):
setattr(self, fld, value.decode(codec, 'ignore'))
# Candidate distance scoring.
# Parameters for string distance function.
# Words that can be moved to the end of a string using a comma.
SD_END_WORDS = ['the', 'a', 'an']
# Reduced weights for certain portions of the string.
SD_PATTERNS = [
(r'^the ', 0.1),
(r'[\[\(]?(ep|single)[\]\)]?', 0.0),
(r'[\[\(]?(featuring|feat|ft)[\. :].+', 0.1),
(r'\(.*?\)', 0.3),
(r'\[.*?\]', 0.3),
(r'(, )?(pt\.|part) .+', 0.2),
]
# Replacements to use before testing distance.
SD_REPLACE = [
(r'&', 'and'),
]
def _string_dist_basic(str1, str2):
"""Basic edit distance between two strings, ignoring
non-alphanumeric characters and case. Comparisons are based on a
transliteration/lowering to ASCII characters. Normalized by string
length.
"""
str1 = unidecode(str1)
str2 = unidecode(str2)
str1 = re.sub(r'[^a-z0-9]', '', str1.lower())
str2 = re.sub(r'[^a-z0-9]', '', str2.lower())
if not str1 and not str2:
return 0.0
return levenshtein(str1, str2) / float(max(len(str1), len(str2)))
def string_dist(str1, str2):
"""Gives an "intuitive" edit distance between two strings. This is
an edit distance, normalized by the string length, with a number of
tweaks that reflect intuition about text.
"""
if str1 is None and str2 is None:
return 0.0
if str1 is None or str2 is None:
return 1.0
str1 = str1.lower()
str2 = str2.lower()
# Don't penalize strings that move certain words to the end. For
# example, "the something" should be considered equal to
# "something, the".
for word in SD_END_WORDS:
if str1.endswith(', %s' % word):
str1 = '%s %s' % (word, str1[:-len(word) - 2])
if str2.endswith(', %s' % word):
str2 = '%s %s' % (word, str2[:-len(word) - 2])
# Perform a couple of basic normalizing substitutions.
for pat, repl in SD_REPLACE:
str1 = re.sub(pat, repl, str1)
str2 = re.sub(pat, repl, str2)
# Change the weight for certain string portions matched by a set
# of regular expressions. We gradually change the strings and build
# up penalties associated with parts of the string that were
# deleted.
base_dist = _string_dist_basic(str1, str2)
penalty = 0.0
for pat, weight in SD_PATTERNS:
# Get strings that drop the pattern.
case_str1 = re.sub(pat, '', str1)
case_str2 = re.sub(pat, '', str2)
if case_str1 != str1 or case_str2 != str2:
# If the pattern was present (i.e., it is deleted in the
# current case), recalculate the distances for the
# modified strings.
case_dist = _string_dist_basic(case_str1, case_str2)
case_delta = max(0.0, base_dist - case_dist)
if case_delta == 0.0:
continue
# Shift our baseline strings down (to avoid rematching the
# same part of the string) and add a scaled distance
# amount to the penalties.
str1 = case_str1
str2 = case_str2
base_dist = case_dist
penalty += weight * case_delta
return base_dist + penalty
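# Illustrative examples, not part of beets (values follow from the code
# above): moving an SD_END_WORDS word to the end costs nothing, so
#   string_dist(u'Something, The', u'The Something')  -> 0.0
# while a parenthesized suffix is only lightly penalized via SD_PATTERNS:
#   string_dist(u'All the Way', u'All the Way (remix)')  -> ~0.11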
class LazyClassProperty(object):
"""A decorator implementing a read-only property that is *lazy* in
the sense that the getter is only invoked once. Subsequent accesses
through *any* instance use the cached result.
"""
def __init__(self, getter):
self.getter = getter
self.computed = False
def __get__(self, obj, owner):
if not self.computed:
self.value = self.getter(owner)
self.computed = True
return self.value
class Distance(object):
"""Keeps track of multiple distance penalties. Provides a single
weighted distance for all penalties as well as a weighted distance
for each individual penalty.
"""
def __init__(self):
self._penalties = {}
@LazyClassProperty
def _weights(cls):
"""A dictionary from keys to floating-point weights.
"""
weights_view = config['match']['distance_weights']
weights = {}
for key in weights_view.keys():
weights[key] = weights_view[key].as_number()
return weights
# Access the components and their aggregates.
@property
def distance(self):
"""Return a weighted and normalized distance across all
penalties.
"""
dist_max = self.max_distance
if dist_max:
return self.raw_distance / self.max_distance
return 0.0
@property
def max_distance(self):
"""Return the maximum distance penalty (normalization factor).
"""
dist_max = 0.0
for key, penalty in self._penalties.iteritems():
dist_max += len(penalty) * self._weights[key]
return dist_max
@property
def raw_distance(self):
"""Return the raw (denormalized) distance.
"""
dist_raw = 0.0
for key, penalty in self._penalties.iteritems():
dist_raw += sum(penalty) * self._weights[key]
return dist_raw
def items(self):
"""Return a list of (key, dist) pairs, with `dist` being the
weighted distance, sorted from highest to lowest. Does not
include penalties with a zero value.
"""
list_ = []
for key in self._penalties:
dist = self[key]
if dist:
list_.append((key, dist))
# Convert each distance into a negative float so we can sort items
# in ascending order (by key, when the penalty is equal) and still
# get the items with the biggest distance first.
return sorted(list_, key=lambda (key, dist): (0 - dist, key))
# Behave like a float.
def __cmp__(self, other):
return cmp(self.distance, other)
def __float__(self):
return self.distance
def __sub__(self, other):
return self.distance - other
def __rsub__(self, other):
return other - self.distance
# Behave like a dict.
def __getitem__(self, key):
"""Returns the weighted distance for a named penalty.
"""
dist = sum(self._penalties[key]) * self._weights[key]
dist_max = self.max_distance
if dist_max:
return dist / dist_max
return 0.0
def __iter__(self):
return iter(self.items())
def __len__(self):
return len(self.items())
def keys(self):
return [key for key, _ in self.items()]
def update(self, dist):
"""Adds all the distance penalties from `dist`.
"""
if not isinstance(dist, Distance):
raise ValueError(
'`dist` must be a Distance object, not {0}'.format(type(dist))
)
for key, penalties in dist._penalties.iteritems():
self._penalties.setdefault(key, []).extend(penalties)
# Adding components.
def _eq(self, value1, value2):
"""Returns True if `value1` is equal to `value2`. `value1` may
be a compiled regular expression, in which case it will be
matched against `value2`.
"""
if isinstance(value1, re._pattern_type):
return bool(value1.match(value2))
return value1 == value2
def add(self, key, dist):
"""Adds a distance penalty. `key` must correspond with a
configured weight setting. `dist` must be a float between 0.0
and 1.0, and will be added to any existing distance penalties
for the same key.
"""
if not 0.0 <= dist <= 1.0:
raise ValueError(
'`dist` must be between 0.0 and 1.0, not {0}'.format(dist)
)
self._penalties.setdefault(key, []).append(dist)
def add_equality(self, key, value, options):
"""Adds a distance penalty of 1.0 if `value` doesn't match any
of the values in `options`. If an option is a compiled regular
expression, it will be considered equal if it matches against
`value`.
"""
if not isinstance(options, (list, tuple)):
options = [options]
for opt in options:
if self._eq(opt, value):
dist = 0.0
break
else:
dist = 1.0
self.add(key, dist)
def add_expr(self, key, expr):
"""Adds a distance penalty of 1.0 if `expr` evaluates to True,
or 0.0.
"""
if expr:
self.add(key, 1.0)
else:
self.add(key, 0.0)
def add_number(self, key, number1, number2):
"""Adds a distance penalty of 1.0 for each number of difference
between `number1` and `number2`, or 0.0 when there is no
difference. Use this when there is no upper limit on the
difference between the two numbers.
"""
diff = abs(number1 - number2)
if diff:
for i in range(diff):
self.add(key, 1.0)
else:
self.add(key, 0.0)
def add_priority(self, key, value, options):
"""Adds a distance penalty that corresponds to the position at
which `value` appears in `options`. A distance penalty of 0.0
for the first option, or 1.0 if there is no matching option. If
an option is a compiled regular expression, it will be
considered equal if it matches against `value`.
"""
if not isinstance(options, (list, tuple)):
options = [options]
unit = 1.0 / (len(options) or 1)
for i, opt in enumerate(options):
if self._eq(opt, value):
dist = i * unit
break
else:
dist = 1.0
self.add(key, dist)
def add_ratio(self, key, number1, number2):
"""Adds a distance penalty for `number1` as a ratio of `number2`.
`number1` is bound at 0 and `number2`.
"""
number = float(max(min(number1, number2), 0))
if number2:
dist = number / number2
else:
dist = 0.0
self.add(key, dist)
def add_string(self, key, str1, str2):
"""Adds a distance penalty based on the edit distance between
`str1` and `str2`.
"""
dist = string_dist(str1, str2)
self.add(key, dist)
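# Illustrative usage, assuming a weight is configured for every key
# added (as in the match.distance_weights defaults):
#   dist = Distance()
#   dist.add_string('album', u'Witching Hour', u'Witching hour')
#   dist.add_expr('track_id', False)
#   float(dist)  # overall weighted distance, normalized into [0.0, 1.0]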
# Structures that compose all the information for a candidate match.
AlbumMatch = namedtuple('AlbumMatch', ['distance', 'info', 'mapping',
'extra_items', 'extra_tracks'])
TrackMatch = namedtuple('TrackMatch', ['distance', 'info'])
# Aggregation of sources.
def album_for_mbid(release_id):
"""Get an AlbumInfo object for a MusicBrainz release ID. Return None
if the ID is not found.
"""
try:
return mb.album_for_id(release_id)
except mb.MusicBrainzAPIError as exc:
exc.log(log)
def track_for_mbid(recording_id):
"""Get a TrackInfo object for a MusicBrainz recording ID. Return None
if the ID is not found.
"""
try:
return mb.track_for_id(recording_id)
except mb.MusicBrainzAPIError as exc:
exc.log(log)
def albums_for_id(album_id):
"""Get a list of albums for an ID."""
candidates = [album_for_mbid(album_id)]
candidates.extend(plugins.album_for_id(album_id))
return filter(None, candidates)
def tracks_for_id(track_id):
"""Get a list of tracks for an ID."""
candidates = [track_for_mbid(track_id)]
candidates.extend(plugins.track_for_id(track_id))
return filter(None, candidates)
def album_candidates(items, artist, album, va_likely):
"""Search for album matches. ``items`` is a list of Item objects
that make up the album. ``artist`` and ``album`` are the respective
names (strings), which may be derived from the item list or may be
entered by the user. ``va_likely`` is a boolean indicating whether
the album is likely to be a "various artists" release.
"""
out = []
# Base candidates if we have album and artist to match.
if artist and album:
try:
out.extend(mb.match_album(artist, album, len(items)))
except mb.MusicBrainzAPIError as exc:
exc.log(log)
# Also add VA matches from MusicBrainz where appropriate.
if va_likely and album:
try:
out.extend(mb.match_album(None, album, len(items)))
except mb.MusicBrainzAPIError as exc:
exc.log(log)
# Candidates from plugins.
out.extend(plugins.candidates(items, artist, album, va_likely))
return out
def item_candidates(item, artist, title):
"""Search for item matches. ``item`` is the Item to be matched.
``artist`` and ``title`` are strings and either reflect the item or
are specified by the user.
"""
out = []
# MusicBrainz candidates.
if artist and title:
try:
out.extend(mb.match_track(artist, title))
except mb.MusicBrainzAPIError as exc:
exc.log(log)
# Plugin candidates.
out.extend(plugins.item_candidates(item, artist, title))
return out

492
lib/beets/autotag/match.py Normal file

@@ -0,0 +1,492 @@
# This file is part of beets.
# Copyright 2013, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""Matches existing metadata with canonical information to identify
releases and tracks.
"""
from __future__ import division
import datetime
import logging
import re
from munkres import Munkres
from beets import plugins
from beets import config
from beets.util import plurality
from beets.autotag import hooks
from beets.util.enumeration import OrderedEnum
# Artist signals that indicate "various artists". These are used at the
# album level to determine whether a given release is likely a VA
# release and also on the track level to remove the penalty for
# differing artists.
VA_ARTISTS = (u'', u'various artists', u'various', u'va', u'unknown')
# Global logger.
log = logging.getLogger('beets')
# Recommendation enumeration.
class Recommendation(OrderedEnum):
"""Indicates a qualitative suggestion to the user about what should
be done with a given match.
"""
none = 0
low = 1
medium = 2
strong = 3
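# OrderedEnum gives these values a total order, so recommendations can
# be clamped with min()/max(); e.g. min(Recommendation.strong,
# Recommendation.medium) is Recommendation.medium, which is how
# _recommendation below enforces per-penalty maximums.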
# Primary matching functionality.
def current_metadata(items):
"""Extract the likely current metadata for an album given a list of its
items. Return two dictionaries:
- The most common value for each field.
- Whether each field's value was unanimous (values are booleans).
"""
assert items # Must be nonempty.
likelies = {}
consensus = {}
fields = ['artist', 'album', 'albumartist', 'year', 'disctotal',
'mb_albumid', 'label', 'catalognum', 'country', 'media',
'albumdisambig']
for field in fields:
values = [item[field] for item in items if item]
likelies[field], freq = plurality(values)
consensus[field] = (freq == len(values))
# If there's an album artist consensus, use this for the artist.
if consensus['albumartist'] and likelies['albumartist']:
likelies['artist'] = likelies['albumartist']
return likelies, consensus
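# Illustrative example: for three items tagged with artists 'Ladytron',
# 'Ladytron' and 'Unknown', likelies['artist'] would be 'Ladytron' and
# consensus['artist'] would be False, since only 2 of the 3 values agree.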
def assign_items(items, tracks):
"""Given a list of Items and a list of TrackInfo objects, find the
best mapping between them. Returns a mapping from Items to TrackInfo
objects, a set of extra Items, and a set of extra TrackInfo
objects. These "extra" objects occur when there is an unequal number
of objects of the two types.
"""
# Construct the cost matrix.
costs = []
for item in items:
row = []
for i, track in enumerate(tracks):
row.append(track_distance(item, track))
costs.append(row)
# Find a minimum-cost bipartite matching.
matching = Munkres().compute(costs)
# Produce the output matching.
mapping = dict((items[i], tracks[j]) for (i, j) in matching)
extra_items = list(set(items) - set(mapping.keys()))
extra_items.sort(key=lambda i: (i.disc, i.track, i.title))
extra_tracks = list(set(tracks) - set(mapping.values()))
extra_tracks.sort(key=lambda t: (t.index, t.title))
return mapping, extra_items, extra_tracks
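# Illustrative sketch with a hypothetical 2x3 cost matrix
# [[0.1, 0.9, 0.8], [0.7, 0.2, 0.9]]: Munkres picks the cheapest
# pairing {items[0]: tracks[0], items[1]: tracks[1]}, leaving the
# third TrackInfo in extra_tracks.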
def track_index_changed(item, track_info):
"""Returns True if the item and track info index is different. Tolerates
per-disc and per-release numbering.
"""
return item.track not in (track_info.medium_index, track_info.index)
def track_distance(item, track_info, incl_artist=False):
"""Determines the significance of a track metadata change. Returns a
Distance object. `incl_artist` indicates that a distance component should
be included for the track artist (i.e., for various-artist releases).
"""
dist = hooks.Distance()
# Length.
if track_info.length:
diff = abs(item.length - track_info.length) - \
config['match']['track_length_grace'].as_number()
dist.add_ratio('track_length', diff,
config['match']['track_length_max'].as_number())
# Title.
dist.add_string('track_title', item.title, track_info.title)
# Artist. Only check if there is actually an artist in the track data.
if incl_artist and track_info.artist and \
item.artist.lower() not in VA_ARTISTS:
dist.add_string('track_artist', item.artist, track_info.artist)
# Track index.
if track_info.index and item.track:
dist.add_expr('track_index', track_index_changed(item, track_info))
# Track ID.
if item.mb_trackid:
dist.add_expr('track_id', item.mb_trackid != track_info.track_id)
# Plugins.
dist.update(plugins.track_distance(item, track_info))
return dist
def distance(items, album_info, mapping):
"""Determines how "significant" an album metadata change would be.
Returns a Distance object. `album_info` is an AlbumInfo object
reflecting the album to be compared. `items` is a sequence of all
Item objects that will be matched (order is not important).
`mapping` is a dictionary mapping Items to TrackInfo objects; the
keys are a subset of `items` and the values are a subset of
`album_info.tracks`.
"""
likelies, _ = current_metadata(items)
dist = hooks.Distance()
# Artist, if not various.
if not album_info.va:
dist.add_string('artist', likelies['artist'], album_info.artist)
# Album.
dist.add_string('album', likelies['album'], album_info.album)
# Current or preferred media.
if album_info.media:
# Preferred media options.
patterns = config['match']['preferred']['media'].as_str_seq()
options = [re.compile(r'(\d+x)?(%s)' % pat, re.I) for pat in patterns]
if options:
dist.add_priority('media', album_info.media, options)
# Current media.
elif likelies['media']:
dist.add_equality('media', album_info.media, likelies['media'])
# Mediums.
if likelies['disctotal'] and album_info.mediums:
dist.add_number('mediums', likelies['disctotal'], album_info.mediums)
# Prefer earliest release.
if album_info.year and config['match']['preferred']['original_year']:
# Assume 1889 (earliest first gramophone discs) if we don't know the
# original year.
original = album_info.original_year or 1889
diff = abs(album_info.year - original)
diff_max = abs(datetime.date.today().year - original)
dist.add_ratio('year', diff, diff_max)
# Year.
elif likelies['year'] and album_info.year:
if likelies['year'] in (album_info.year, album_info.original_year):
# No penalty for matching release or original year.
dist.add('year', 0.0)
elif album_info.original_year:
# Prefer matches closest to the release year.
diff = abs(likelies['year'] - album_info.year)
diff_max = abs(datetime.date.today().year -
album_info.original_year)
dist.add_ratio('year', diff, diff_max)
else:
# Full penalty when there is no original year.
dist.add('year', 1.0)
# Preferred countries.
patterns = config['match']['preferred']['countries'].as_str_seq()
options = [re.compile(pat, re.I) for pat in patterns]
if album_info.country and options:
dist.add_priority('country', album_info.country, options)
# Country.
elif likelies['country'] and album_info.country:
dist.add_string('country', likelies['country'], album_info.country)
# Label.
if likelies['label'] and album_info.label:
dist.add_string('label', likelies['label'], album_info.label)
# Catalog number.
if likelies['catalognum'] and album_info.catalognum:
dist.add_string('catalognum', likelies['catalognum'],
album_info.catalognum)
# Disambiguation.
if likelies['albumdisambig'] and album_info.albumdisambig:
dist.add_string('albumdisambig', likelies['albumdisambig'],
album_info.albumdisambig)
# Album ID.
if likelies['mb_albumid']:
dist.add_equality('album_id', likelies['mb_albumid'],
album_info.album_id)
# Tracks.
dist.tracks = {}
for item, track in mapping.iteritems():
dist.tracks[track] = track_distance(item, track, album_info.va)
dist.add('tracks', dist.tracks[track].distance)
# Missing tracks.
for i in range(len(album_info.tracks) - len(mapping)):
dist.add('missing_tracks', 1.0)
# Unmatched tracks.
for i in range(len(items) - len(mapping)):
dist.add('unmatched_tracks', 1.0)
# Plugins.
dist.update(plugins.album_distance(items, album_info, mapping))
return dist
def match_by_id(items):
"""If the items are tagged with a MusicBrainz album ID, returns an
AlbumInfo object for the corresponding album. Otherwise, returns
None.
"""
# Is there a consensus on the MB album ID?
albumids = [item.mb_albumid for item in items if item.mb_albumid]
if not albumids:
log.debug(u'No album IDs found.')
return None
# If all album IDs are equal, look up the album.
if bool(reduce(lambda x, y: x if x == y else (), albumids)):
albumid = albumids[0]
log.debug(u'Searching for discovered album ID: {0}'.format(albumid))
return hooks.album_for_mbid(albumid)
else:
log.debug(u'No album ID consensus.')
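# Note on the reduce() trick above: it collapses the ID list to a single
# (truthy) value only when all elements are equal, e.g.
#   reduce(lambda x, y: x if x == y else (), ['a', 'a'])  -> 'a'
#   reduce(lambda x, y: x if x == y else (), ['a', 'b'])  -> ()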
def _recommendation(results):
"""Given a sorted list of AlbumMatch or TrackMatch objects, return a
recommendation based on the results' distances.
If the recommendation is higher than the configured maximum for
an applied penalty, the recommendation will be downgraded to the
configured maximum for that penalty.
"""
if not results:
# No candidates: no recommendation.
return Recommendation.none
# Basic distance thresholding.
min_dist = results[0].distance
if min_dist < config['match']['strong_rec_thresh'].as_number():
# Strong recommendation level.
rec = Recommendation.strong
elif min_dist <= config['match']['medium_rec_thresh'].as_number():
# Medium recommendation level.
rec = Recommendation.medium
elif len(results) == 1:
# Only a single candidate.
rec = Recommendation.low
elif results[1].distance - min_dist >= \
config['match']['rec_gap_thresh'].as_number():
# Gap between first two candidates is large.
rec = Recommendation.low
else:
# No conclusion. Return immediately. Can't be downgraded any further.
return Recommendation.none
# Downgrade to the max rec if it is lower than the current rec for an
# applied penalty.
keys = set(min_dist.keys())
if isinstance(results[0], hooks.AlbumMatch):
for track_dist in min_dist.tracks.values():
keys.update(track_dist.keys())
max_rec_view = config['match']['max_rec']
for key in keys:
if key in max_rec_view.keys():
max_rec = max_rec_view[key].as_choice({
'strong': Recommendation.strong,
'medium': Recommendation.medium,
'low': Recommendation.low,
'none': Recommendation.none,
})
rec = min(rec, max_rec)
return rec
def _add_candidate(items, results, info):
"""Given a candidate AlbumInfo object, attempt to add the candidate
to the output dictionary of AlbumMatch objects. This involves
checking the track count, ordering the items, checking for
duplicates, and calculating the distance.
"""
log.debug(u'Candidate: {0} - {1}'.format(info.artist, info.album))
# Discard albums with zero tracks.
if not info.tracks:
log.debug('No tracks.')
return
# Don't duplicate.
if info.album_id in results:
log.debug(u'Duplicate.')
return
# Discard matches without required tags.
for req_tag in config['match']['required'].as_str_seq():
if getattr(info, req_tag) is None:
log.debug(u'Ignored. Missing required tag: {0}'.format(req_tag))
return
# Find mapping between the items and the track info.
mapping, extra_items, extra_tracks = assign_items(items, info.tracks)
# Get the change distance.
dist = distance(items, info, mapping)
# Skip matches with ignored penalties.
penalties = [key for _, key in dist]
for penalty in config['match']['ignored'].as_str_seq():
if penalty in penalties:
log.debug(u'Ignored. Penalty: {0}'.format(penalty))
return
log.debug(u'Success. Distance: {0}'.format(dist))
results[info.album_id] = hooks.AlbumMatch(dist, info, mapping,
extra_items, extra_tracks)
def tag_album(items, search_artist=None, search_album=None,
search_id=None):
"""Return a tuple of a artist name, an album name, a list of
`AlbumMatch` candidates from the metadata backend, and a
`Recommendation`.
The artist and album are the most common values of these fields
among `items`.
The `AlbumMatch` objects are generated by searching the metadata
backends. By default, the metadata of the items is used for the
search. This can be customized by setting the parameters. The
`mapping` field of the album has the matched `items` as keys.
The recommendation is calculated from the match quality of the
candidates.
"""
# Get current metadata.
likelies, consensus = current_metadata(items)
cur_artist = likelies['artist']
cur_album = likelies['album']
log.debug(u'Tagging {0} - {1}'.format(cur_artist, cur_album))
# The output result (distance, AlbumInfo) tuples (keyed by MB album
# ID).
candidates = {}
# Search by explicit ID.
if search_id is not None:
log.debug(u'Searching for album ID: {0}'.format(search_id))
search_cands = hooks.albums_for_id(search_id)
# Use existing metadata or text search.
else:
# Try search based on current ID.
id_info = match_by_id(items)
if id_info:
_add_candidate(items, candidates, id_info)
rec = _recommendation(candidates.values())
log.debug(u'Album ID match recommendation is {0}'.format(str(rec)))
if candidates and not config['import']['timid']:
# If we have a very good MBID match, return immediately.
# Otherwise, this match will compete against metadata-based
# matches.
if rec == Recommendation.strong:
log.debug(u'ID match.')
return cur_artist, cur_album, candidates.values(), rec
# Search terms.
if not (search_artist and search_album):
# No explicit search terms -- use current metadata.
search_artist, search_album = cur_artist, cur_album
log.debug(u'Search terms: {0} - {1}'.format(search_artist,
search_album))
# Is this album likely to be a "various artist" release?
va_likely = ((not consensus['artist']) or
(search_artist.lower() in VA_ARTISTS) or
any(item.comp for item in items))
log.debug(u'Album might be VA: {0}'.format(str(va_likely)))
# Get the results from the data sources.
search_cands = hooks.album_candidates(items, search_artist,
search_album, va_likely)
log.debug(u'Evaluating {0} candidates.'.format(len(search_cands)))
for info in search_cands:
_add_candidate(items, candidates, info)
# Sort and get the recommendation.
candidates = sorted(candidates.itervalues())
rec = _recommendation(candidates)
return cur_artist, cur_album, candidates, rec
def tag_item(item, search_artist=None, search_title=None,
search_id=None):
"""Attempts to find metadata for a single track. Returns a
`(candidates, recommendation)` pair where `candidates` is a list of
TrackMatch objects. `search_artist` and `search_title` may be used
to override the current metadata for the purposes of the MusicBrainz
title; likewise `search_id`.
"""
# Holds candidates found so far: keys are MBIDs; values are
# (distance, TrackInfo) pairs.
candidates = {}
# First, try matching by MusicBrainz ID.
trackid = search_id or item.mb_trackid
if trackid:
log.debug(u'Searching for track ID: {0}'.format(trackid))
for track_info in hooks.tracks_for_id(trackid):
dist = track_distance(item, track_info, incl_artist=True)
candidates[track_info.track_id] = \
hooks.TrackMatch(dist, track_info)
# If this is a good match, then don't keep searching.
rec = _recommendation(candidates.values())
if rec == Recommendation.strong and not config['import']['timid']:
log.debug(u'Track ID match.')
return candidates.values(), rec
# If we're searching by ID, don't proceed.
if search_id is not None:
if candidates:
return candidates.values(), rec
else:
return [], Recommendation.none
# Search terms.
if not (search_artist and search_title):
search_artist, search_title = item.artist, item.title
log.debug(u'Item search terms: {0} - {1}'.format(search_artist,
search_title))
# Get and evaluate candidate metadata.
for track_info in hooks.item_candidates(item, search_artist, search_title):
dist = track_distance(item, track_info, incl_artist=True)
candidates[track_info.track_id] = hooks.TrackMatch(dist, track_info)
# Sort by distance and return with recommendation.
log.debug(u'Found {0} candidates.'.format(len(candidates)))
candidates = sorted(candidates.itervalues())
rec = _recommendation(candidates)
return candidates, rec

407
lib/beets/autotag/mb.py Normal file

@@ -0,0 +1,407 @@
# This file is part of beets.
# Copyright 2013, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""Searches for albums in the MusicBrainz database.
"""
import logging
import musicbrainzngs
import re
import traceback
from urlparse import urljoin
import beets.autotag.hooks
import beets
from beets import util
from beets import config
SEARCH_LIMIT = 5
VARIOUS_ARTISTS_ID = '89ad4ac3-39f7-470e-963a-56509c546377'
BASE_URL = 'http://musicbrainz.org/'
musicbrainzngs.set_useragent('beets', beets.__version__,
'http://beets.radbox.org/')
class MusicBrainzAPIError(util.HumanReadableException):
"""An error while talking to MusicBrainz. The `query` field is the
parameter to the action and may have any type.
"""
def __init__(self, reason, verb, query, tb=None):
self.query = query
super(MusicBrainzAPIError, self).__init__(reason, verb, tb)
def get_message(self):
return u'{0} in {1} with query {2}'.format(
self._reasonstr(), self.verb, repr(self.query)
)
log = logging.getLogger('beets')
RELEASE_INCLUDES = ['artists', 'media', 'recordings', 'release-groups',
'labels', 'artist-credits', 'aliases']
TRACK_INCLUDES = ['artists', 'aliases']
def track_url(trackid):
return urljoin(BASE_URL, 'recording/' + trackid)
def album_url(albumid):
return urljoin(BASE_URL, 'release/' + albumid)
def configure():
"""Set up the python-musicbrainz-ngs module according to settings
from the beets configuration. This should be called at startup.
"""
musicbrainzngs.set_hostname(config['musicbrainz']['host'].get(unicode))
musicbrainzngs.set_rate_limit(
config['musicbrainz']['ratelimit_interval'].as_number(),
config['musicbrainz']['ratelimit'].get(int),
)
def _preferred_alias(aliases):
"""Given an list of alias structures for an artist credit, select
and return the user's preferred alias alias or None if no matching
alias is found.
"""
if not aliases:
return
# Only consider aliases that have locales set.
aliases = [a for a in aliases if 'locale' in a]
# Search configured locales in order.
for locale in config['import']['languages'].as_str_seq():
# Find matching primary aliases for this locale.
matches = [a for a in aliases
if a['locale'] == locale and 'primary' in a]
# Skip to the next locale if we have no matches
if not matches:
continue
return matches[0]
def _flatten_artist_credit(credit):
"""Given a list representing an ``artist-credit`` block, flatten the
data into a triple of joined artist name strings: canonical, sort, and
credit.
"""
artist_parts = []
artist_sort_parts = []
artist_credit_parts = []
for el in credit:
if isinstance(el, basestring):
# Join phrase.
artist_parts.append(el)
artist_credit_parts.append(el)
artist_sort_parts.append(el)
else:
alias = _preferred_alias(el['artist'].get('alias-list', ()))
# An artist.
if alias:
cur_artist_name = alias['alias']
else:
cur_artist_name = el['artist']['name']
artist_parts.append(cur_artist_name)
# Artist sort name.
if alias:
artist_sort_parts.append(alias['sort-name'])
elif 'sort-name' in el['artist']:
artist_sort_parts.append(el['artist']['sort-name'])
else:
artist_sort_parts.append(cur_artist_name)
# Artist credit.
if 'name' in el:
artist_credit_parts.append(el['name'])
else:
artist_credit_parts.append(cur_artist_name)
return (
''.join(artist_parts),
''.join(artist_sort_parts),
''.join(artist_credit_parts),
)
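# Illustrative input (hypothetical, abbreviated MusicBrainz structure):
#   [{'artist': {'name': 'A', 'sort-name': 'A'}}, ' feat. ',
#    {'artist': {'name': 'B', 'sort-name': 'B'}}]
# flattens to ('A feat. B', 'A feat. B', 'A feat. B').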
def track_info(recording, index=None, medium=None, medium_index=None,
medium_total=None):
"""Translates a MusicBrainz recording result dictionary into a beets
``TrackInfo`` object. Three parameters are optional and are used
only for tracks that appear on releases (non-singletons): ``index``,
the overall track number; ``medium``, the disc number;
``medium_index``, the track's index on its medium; ``medium_total``,
the number of tracks on the medium. Each number is a 1-based index.
"""
info = beets.autotag.hooks.TrackInfo(
recording['title'],
recording['id'],
index=index,
medium=medium,
medium_index=medium_index,
medium_total=medium_total,
data_url=track_url(recording['id']),
)
if recording.get('artist-credit'):
# Get the artist names.
info.artist, info.artist_sort, info.artist_credit = \
_flatten_artist_credit(recording['artist-credit'])
# Get the ID and sort name of the first artist.
artist = recording['artist-credit'][0]['artist']
info.artist_id = artist['id']
if recording.get('length'):
info.length = int(recording['length']) / (1000.0)
info.decode()
return info
def _set_date_str(info, date_str, original=False):
"""Given a (possibly partial) YYYY-MM-DD string and an AlbumInfo
object, set the object's release date fields appropriately. If
`original`, then set the original_year, etc., fields.
"""
if date_str:
date_parts = date_str.split('-')
for key in ('year', 'month', 'day'):
if date_parts:
date_part = date_parts.pop(0)
try:
date_num = int(date_part)
except ValueError:
continue
if original:
key = 'original_' + key
setattr(info, key, date_num)
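# Illustrative behavior: '2005-10' sets only the year and month fields;
# a non-numeric component such as the '??' in '2005-??-03' simply fails
# the int() conversion and is skipped rather than aborting the parse.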
def album_info(release):
"""Takes a MusicBrainz release result dictionary and returns a beets
AlbumInfo object containing the interesting data about that release.
"""
# Get artist name using join phrases.
artist_name, artist_sort_name, artist_credit_name = \
_flatten_artist_credit(release['artist-credit'])
# Basic info.
track_infos = []
index = 0
for medium in release['medium-list']:
disctitle = medium.get('title')
format = medium.get('format')
for track in medium['track-list']:
# Basic information from the recording.
index += 1
ti = track_info(
track['recording'],
index,
int(medium['position']),
int(track['position']),
len(medium['track-list']),
)
ti.disctitle = disctitle
ti.media = format
# Prefer track data, where present, over recording data.
if track.get('title'):
ti.title = track['title']
if track.get('artist-credit'):
# Get the artist names.
ti.artist, ti.artist_sort, ti.artist_credit = \
_flatten_artist_credit(track['artist-credit'])
ti.artist_id = track['artist-credit'][0]['artist']['id']
if track.get('length'):
ti.length = int(track['length']) / (1000.0)
track_infos.append(ti)
info = beets.autotag.hooks.AlbumInfo(
release['title'],
release['id'],
artist_name,
release['artist-credit'][0]['artist']['id'],
track_infos,
mediums=len(release['medium-list']),
artist_sort=artist_sort_name,
artist_credit=artist_credit_name,
data_source='MusicBrainz',
data_url=album_url(release['id']),
)
info.va = info.artist_id == VARIOUS_ARTISTS_ID
info.asin = release.get('asin')
info.releasegroup_id = release['release-group']['id']
info.country = release.get('country')
info.albumstatus = release.get('status')
# Build up the disambiguation string from the release group and release.
disambig = []
if release['release-group'].get('disambiguation'):
disambig.append(release['release-group'].get('disambiguation'))
if release.get('disambiguation'):
disambig.append(release.get('disambiguation'))
info.albumdisambig = u', '.join(disambig)
# Release type not always populated.
if 'type' in release['release-group']:
reltype = release['release-group']['type']
if reltype:
info.albumtype = reltype.lower()
# Release dates.
release_date = release.get('date')
release_group_date = release['release-group'].get('first-release-date')
if not release_date:
# Fall back if release-specific date is not available.
release_date = release_group_date
_set_date_str(info, release_date, False)
_set_date_str(info, release_group_date, True)
# Label name.
if release.get('label-info-list'):
label_info = release['label-info-list'][0]
if label_info.get('label'):
label = label_info['label']['name']
if label != '[no label]':
info.label = label
info.catalognum = label_info.get('catalog-number')
# Text representation data.
if release.get('text-representation'):
rep = release['text-representation']
info.script = rep.get('script')
info.language = rep.get('language')
# Media (format).
if release['medium-list']:
first_medium = release['medium-list'][0]
info.media = first_medium.get('format')
info.decode()
return info
def match_album(artist, album, tracks=None, limit=SEARCH_LIMIT):
"""Searches for a single album ("release" in MusicBrainz parlance)
and returns an iterator over AlbumInfo objects. May raise a
MusicBrainzAPIError.
The query consists of an artist name, an album name, and,
optionally, a number of tracks on the album.
"""
# Build search criteria.
criteria = {'release': album.lower().strip()}
if artist is not None:
criteria['artist'] = artist.lower().strip()
else:
# Various Artists search.
criteria['arid'] = VARIOUS_ARTISTS_ID
if tracks is not None:
criteria['tracks'] = str(tracks)
# Abort if we have no search terms.
if not any(criteria.itervalues()):
return
try:
res = musicbrainzngs.search_releases(limit=limit, **criteria)
except musicbrainzngs.MusicBrainzError as exc:
raise MusicBrainzAPIError(exc, 'release search', criteria,
traceback.format_exc())
for release in res['release-list']:
# The search result is missing some data (namely, the tracks),
# so we just use the ID and fetch the rest of the information.
albuminfo = album_for_id(release['id'])
if albuminfo is not None:
yield albuminfo
def match_track(artist, title, limit=SEARCH_LIMIT):
"""Searches for a single track and returns an iterable of TrackInfo
objects. May raise a MusicBrainzAPIError.
"""
criteria = {
'artist': artist.lower().strip(),
'recording': title.lower().strip(),
}
if not any(criteria.itervalues()):
return
try:
res = musicbrainzngs.search_recordings(limit=limit, **criteria)
except musicbrainzngs.MusicBrainzError as exc:
raise MusicBrainzAPIError(exc, 'recording search', criteria,
traceback.format_exc())
for recording in res['recording-list']:
yield track_info(recording)
def _parse_id(s):
"""Search for a MusicBrainz ID in the given string and return it. If
no ID can be found, return None.
"""
# Find the first thing that looks like a UUID/MBID.
match = re.search('[a-f0-9]{8}(-[a-f0-9]{4}){3}-[a-f0-9]{12}', s)
if match:
return match.group()
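# Illustrative: both a bare MBID and a full URL yield the embedded ID,
# e.g. _parse_id(u'http://musicbrainz.org/release/'
#                u'89ad4ac3-39f7-470e-963a-56509c546377')
# returns u'89ad4ac3-39f7-470e-963a-56509c546377'.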
def album_for_id(releaseid):
"""Fetches an album by its MusicBrainz ID and returns an AlbumInfo
object or None if the album is not found. May raise a
MusicBrainzAPIError.
"""
albumid = _parse_id(releaseid)
if not albumid:
log.debug(u'Invalid MBID ({0}).'.format(releaseid))
return
try:
res = musicbrainzngs.get_release_by_id(albumid,
RELEASE_INCLUDES)
except musicbrainzngs.ResponseError:
log.debug(u'Album ID match failed.')
return None
except musicbrainzngs.MusicBrainzError as exc:
raise MusicBrainzAPIError(exc, 'get release by ID', albumid,
traceback.format_exc())
return album_info(res['release'])
def track_for_id(releaseid):
"""Fetches a track by its MusicBrainz ID. Returns a TrackInfo object
or None if no track is found. May raise a MusicBrainzAPIError.
"""
trackid = _parse_id(releaseid)
if not trackid:
log.debug(u'Invalid MBID ({0}).'.format(releaseid))
return
try:
res = musicbrainzngs.get_recording_by_id(trackid, TRACK_INCLUDES)
except musicbrainzngs.ResponseError:
log.debug(u'Track ID match failed.')
return None
except musicbrainzngs.MusicBrainzError as exc:
raise MusicBrainzAPIError(exc, 'get recording by ID', trackid,
traceback.format_exc())
return track_info(res['recording'])

lib/beets/config_default.yaml Normal file

@@ -0,0 +1,109 @@
library: library.db
directory: ~/Music
import:
write: yes
copy: yes
move: no
link: no
delete: no
resume: ask
incremental: no
quiet_fallback: skip
none_rec_action: ask
timid: no
log:
autotag: yes
quiet: no
singletons: no
default_action: apply
languages: []
detail: no
flat: no
group_albums: no
pretend: false
clutter: ["Thumbs.DB", ".DS_Store"]
ignore: [".*", "*~", "System Volume Information"]
replace:
'[\\/]': _
'^\.': _
'[\x00-\x1f]': _
'[<>:"\?\*\|]': _
'\.$': _
'\s+$': ''
'^\s+': ''
path_sep_replace: _
asciify_paths: false
art_filename: cover
max_filename_length: 0
plugins: []
pluginpath: []
threaded: yes
color: yes
timeout: 5.0
per_disc_numbering: no
verbose: no
terminal_encoding: utf8
original_date: no
id3v23: no
ui:
terminal_width: 80
length_diff_thresh: 10.0
list_format_item: $artist - $album - $title
list_format_album: $albumartist - $album
time_format: '%Y-%m-%d %H:%M:%S'
sort_album: albumartist+ album+
sort_item: artist+ album+ disc+ track+
paths:
default: $albumartist/$album%aunique{}/$track $title
singleton: Non-Album/$artist/$title
comp: Compilations/$album%aunique{}/$track $title
statefile: state.pickle
musicbrainz:
host: musicbrainz.org
ratelimit: 1
ratelimit_interval: 1.0
match:
strong_rec_thresh: 0.04
medium_rec_thresh: 0.25
rec_gap_thresh: 0.25
max_rec:
missing_tracks: medium
unmatched_tracks: medium
distance_weights:
source: 2.0
artist: 3.0
album: 3.0
media: 1.0
mediums: 1.0
year: 1.0
country: 0.5
label: 0.5
catalognum: 0.5
albumdisambig: 0.5
album_id: 5.0
tracks: 2.0
missing_tracks: 0.9
unmatched_tracks: 0.6
track_title: 3.0
track_artist: 2.0
track_index: 1.0
track_length: 2.0
track_id: 5.0
preferred:
countries: []
media: []
original_year: no
ignored: []
required: []
track_length_grace: 10
track_length_max: 30

lib/beets/dbcore/__init__.py Normal file

@@ -0,0 +1,25 @@
# This file is part of beets.
# Copyright 2014, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""DBCore is an abstract database package that forms the basis for beets'
Library.
"""
from .db import Model, Database
from .query import Query, FieldQuery, MatchQuery, AndQuery, OrQuery
from .types import Type
from .queryparse import query_from_strings
from .queryparse import sort_from_strings
from .queryparse import parse_sorted_query
# flake8: noqa

820
lib/beets/dbcore/db.py Normal file

@@ -0,0 +1,820 @@
# This file is part of beets.
# Copyright 2014, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""The central Model and Database constructs for DBCore.
"""
import time
import os
from collections import defaultdict
import threading
import sqlite3
import contextlib
import collections
import beets
from beets.util.functemplate import Template
from beets.dbcore import types
from .query import MatchQuery, NullSort, TrueQuery
class FormattedMapping(collections.Mapping):
"""A `dict`-like formatted view of a model.
The accessor `mapping[key]` returns the formatted version of
`model[key]` as a unicode string.
If `for_path` is true, all path separators in the formatted values
are replaced.
"""
def __init__(self, model, for_path=False):
self.for_path = for_path
self.model = model
self.model_keys = model.keys(True)
def __getitem__(self, key):
if key in self.model_keys:
return self._get_formatted(self.model, key)
else:
raise KeyError(key)
def __iter__(self):
return iter(self.model_keys)
def __len__(self):
return len(self.model_keys)
def get(self, key, default=None):
if default is None:
default = self.model._type(key).format(None)
return super(FormattedMapping, self).get(key, default)
def _get_formatted(self, model, key):
value = model._type(key).format(model.get(key))
if isinstance(value, bytes):
value = value.decode('utf8', 'ignore')
if self.for_path:
sep_repl = beets.config['path_sep_replace'].get(unicode)
for sep in (os.path.sep, os.path.altsep):
if sep:
value = value.replace(sep, sep_repl)
return value
# Abstract base for model classes.
class Model(object):
"""An abstract object representing an object in the database. Model
objects act like dictionaries (i.e., they allow subscript access like
``obj['field']``). The same field set is available via attribute
access as a shortcut (i.e., ``obj.field``). Three kinds of attributes are
available:
* **Fixed attributes** come from a predetermined list of field
names. These fields correspond to SQLite table columns and are
thus fast to read, write, and query.
* **Flexible attributes** are free-form and do not need to be listed
ahead of time.
* **Computed attributes** are read-only fields computed by a getter
function provided by a plugin.
Access to all three field types is uniform: ``obj.field`` works the
same regardless of whether ``field`` is fixed, flexible, or
computed.
Model objects can optionally be associated with a `Library` object,
in which case they can be loaded and stored from the database. Dirty
flags are used to track which fields need to be stored.
"""
# Abstract components (to be provided by subclasses).
_table = None
"""The main SQLite table name.
"""
_flex_table = None
"""The flex field SQLite table name.
"""
_fields = {}
"""A mapping indicating available "fixed" fields on this type. The
keys are field names and the values are `Type` objects.
"""
_search_fields = ()
"""The fields that should be queried by default by unqualified query
terms.
"""
_types = {}
"""Optional Types for non-fixed (i.e., flexible and computed) fields.
"""
_sorts = {}
"""Optional named sort criteria. The keys are strings and the values
are subclasses of `Sort`.
"""
_always_dirty = False
"""By default, fields only become "dirty" when their value actually
changes. Enabling this flag marks fields as dirty even when the new
value is the same as the old value (e.g., `o.f = o.f`).
"""
@classmethod
def _getters(cls):
"""Return a mapping from field names to getter functions.
"""
# We could cache this if it becomes a performance problem to
# gather the getter mapping every time.
raise NotImplementedError()
def _template_funcs(self):
"""Return a mapping from function names to text-transformer
functions.
"""
# As above: we could consider caching this result.
raise NotImplementedError()
# Basic operation.
def __init__(self, db=None, **values):
"""Create a new object with an optional Database association and
initial field values.
"""
self._db = db
self._dirty = set()
self._values_fixed = {}
self._values_flex = {}
# Initial contents.
self.update(values)
self.clear_dirty()
@classmethod
def _awaken(cls, db=None, fixed_values={}, flex_values={}):
"""Create an object with values drawn from the database.
This is a performance optimization: the checks involved with
ordinary construction are bypassed.
"""
obj = cls(db)
for key, value in fixed_values.iteritems():
obj._values_fixed[key] = cls._type(key).from_sql(value)
for key, value in flex_values.iteritems():
obj._values_flex[key] = cls._type(key).from_sql(value)
return obj
def __repr__(self):
return '{0}({1})'.format(
type(self).__name__,
', '.join('{0}={1!r}'.format(k, v) for k, v in dict(self).items()),
)
def clear_dirty(self):
"""Mark all fields as *clean* (i.e., not needing to be stored to
the database).
"""
self._dirty = set()
def _check_db(self, need_id=True):
"""Ensure that this object is associated with a database row: it
has a reference to a database (`_db`) and an id. A ValueError
exception is raised otherwise.
"""
if not self._db:
raise ValueError('{0} has no database'.format(type(self).__name__))
if need_id and not self.id:
raise ValueError('{0} has no id'.format(type(self).__name__))
# Essential field accessors.
@classmethod
def _type(cls, key):
"""Get the type of a field, a `Type` instance.
If the field has no explicit type, it is given the base `Type`,
which does no conversion.
"""
return cls._fields.get(key) or cls._types.get(key) or types.DEFAULT
def __getitem__(self, key):
"""Get the value for a field. Raise a KeyError if the field is
not available.
"""
getters = self._getters()
if key in getters: # Computed.
return getters[key](self)
elif key in self._fields: # Fixed.
return self._values_fixed.get(key)
elif key in self._values_flex: # Flexible.
return self._values_flex[key]
else:
raise KeyError(key)
def __setitem__(self, key, value):
"""Assign the value for a field.
"""
# Choose where to place the value.
if key in self._fields:
source = self._values_fixed
else:
source = self._values_flex
# If the field has a type, filter the value.
value = self._type(key).normalize(value)
# Assign value and possibly mark as dirty.
old_value = source.get(key)
source[key] = value
if self._always_dirty or old_value != value:
self._dirty.add(key)
def __delitem__(self, key):
"""Remove a flexible attribute from the model.
"""
if key in self._values_flex: # Flexible.
del self._values_flex[key]
self._dirty.add(key) # Mark for dropping on store.
elif key in self._getters(): # Computed.
raise KeyError('computed field {0} cannot be deleted'.format(key))
elif key in self._fields: # Fixed.
raise KeyError('fixed field {0} cannot be deleted'.format(key))
else:
raise KeyError('no such field {0}'.format(key))
def keys(self, computed=False):
"""Get a list of available field names for this object. The
`computed` parameter controls whether computed (plugin-provided)
fields are included in the key list.
"""
base_keys = list(self._fields) + self._values_flex.keys()
if computed:
return base_keys + self._getters().keys()
else:
return base_keys
# Act like a dictionary.
def update(self, values):
"""Assign all values in the given dict.
"""
for key, value in values.items():
self[key] = value
def items(self):
"""Iterate over (key, value) pairs that this object contains.
Computed fields are not included.
"""
for key in self:
yield key, self[key]
def get(self, key, default=None):
"""Get the value for a given key or `default` if it does not
exist.
"""
if key in self:
return self[key]
else:
return default
def __contains__(self, key):
"""Determine whether `key` is an attribute on this object.
"""
return key in self.keys(True)
def __iter__(self):
"""Iterate over the available field names (excluding computed
fields).
"""
return iter(self.keys())
# Convenient attribute access.
def __getattr__(self, key):
if key.startswith('_'):
raise AttributeError('model has no attribute {0!r}'.format(key))
else:
try:
return self[key]
except KeyError:
raise AttributeError('no such field {0!r}'.format(key))
def __setattr__(self, key, value):
if key.startswith('_'):
super(Model, self).__setattr__(key, value)
else:
self[key] = value
def __delattr__(self, key):
if key.startswith('_'):
super(Model, self).__delattr__(key)
else:
del self[key]
# Database interaction (CRUD methods).
def store(self):
"""Save the object's metadata into the library database.
"""
self._check_db()
# Build assignments for query.
assignments = []
subvars = []
for key in self._fields:
if key != 'id' and key in self._dirty:
self._dirty.remove(key)
assignments.append(key + '=?')
value = self._type(key).to_sql(self[key])
subvars.append(value)
assignments = ','.join(assignments)
with self._db.transaction() as tx:
# Main table update.
if assignments:
query = 'UPDATE {0} SET {1} WHERE id=?'.format(
self._table, assignments
)
subvars.append(self.id)
tx.mutate(query, subvars)
# Modified/added flexible attributes.
for key, value in self._values_flex.items():
if key in self._dirty:
self._dirty.remove(key)
tx.mutate(
'INSERT INTO {0} '
'(entity_id, key, value) '
'VALUES (?, ?, ?);'.format(self._flex_table),
(self.id, key, value),
)
# Deleted flexible attributes.
for key in self._dirty:
tx.mutate(
'DELETE FROM {0} '
'WHERE entity_id=? AND key=?'.format(self._flex_table),
(self.id, key)
)
self.clear_dirty()
def load(self):
"""Refresh the object's metadata from the library database.
"""
self._check_db()
stored_obj = self._db._get(type(self), self.id)
assert stored_obj is not None, "object {0} not in DB".format(self.id)
self._values_fixed = {}
self._values_flex = {}
self.update(dict(stored_obj))
self.clear_dirty()
def remove(self):
"""Remove the object's associated rows from the database.
"""
self._check_db()
with self._db.transaction() as tx:
tx.mutate(
'DELETE FROM {0} WHERE id=?'.format(self._table),
(self.id,)
)
tx.mutate(
'DELETE FROM {0} WHERE entity_id=?'.format(self._flex_table),
(self.id,)
)
def add(self, db=None):
"""Add the object to the library database. This object must be
associated with a database; you can provide one via the `db`
parameter or use the currently associated database.
The object's `id` and `added` fields are set along with any
current field values.
"""
if db:
self._db = db
self._check_db(False)
with self._db.transaction() as tx:
new_id = tx.mutate(
'INSERT INTO {0} DEFAULT VALUES'.format(self._table)
)
self.id = new_id
self.added = time.time()
# Mark every non-null field as dirty and store.
for key in self:
if self[key] is not None:
self._dirty.add(key)
self.store()
# Formatting and templating.
_formatter = FormattedMapping
def formatted(self, for_path=False):
"""Get a mapping containing all values on this object formatted
as human-readable unicode strings.
"""
return self._formatter(self, for_path)
def evaluate_template(self, template, for_path=False):
"""Evaluate a template (a string or a `Template` object) using
the object's fields. If `for_path` is true, then no new path
separators will be added to the template.
"""
# Perform substitution.
if isinstance(template, basestring):
template = Template(template)
return template.substitute(self.formatted(for_path),
self._template_funcs())
# Parsing.
@classmethod
def _parse(cls, key, string):
"""Parse a string as a value for the given key.
"""
if not isinstance(string, basestring):
raise TypeError("_parse() argument must be a string")
return cls._type(key).parse(string)
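# Example (editor's sketch, not part of beets): a minimal concrete
# Model subclass, assuming the `types` module defined later in this
# diff. The class and field names are hypothetical, chosen only to
# show the three attribute kinds side by side.
class ExampleNote(Model):
    _table = 'notes'
    _flex_table = 'note_attributes'
    _fields = {
        'id': types.PRIMARY_ID,
        'title': types.STRING,
    }
    _search_fields = ('title',)

    @classmethod
    def _getters(cls):
        # One computed, read-only field.
        return {'shout': lambda obj: obj.title.upper()}

    def _template_funcs(self):
        return {}

note = ExampleNote(title=u'hello')  # Fixed field.
note.mood = u'calm'                 # Flexible field, created on the fly.
assert note.shout == u'HELLO'       # Computed field.
assert note.formatted()['title'] == u'hello'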
# Database controller and supporting interfaces.
class Results(object):
"""An item query result set. Iterating over the collection lazily
constructs LibModel objects that reflect database rows.
"""
def __init__(self, model_class, rows, db, query=None, sort=None):
"""Create a result set that will construct objects of type
`model_class`.
`model_class` is a subclass of `LibModel` that will be
constructed. `rows` is a query result: a list of mappings. The
new objects will be associated with the database `db`.
If `query` is provided, it is used as a predicate to filter the
results for a "slow query" that cannot be evaluated by the
database directly. If `sort` is provided, it is used to sort the
full list of results before returning. This means it is a "slow
sort" and all objects must be built before returning the first
one.
"""
self.model_class = model_class
self.rows = rows
self.db = db
self.query = query
self.sort = sort
# We keep a queue of rows we haven't yet consumed for
# materialization. We preserve the original total number of
# rows.
self._rows = rows
self._row_count = len(rows)
# The materialized objects corresponding to rows that have been
# consumed.
self._objects = []
def _get_objects(self):
"""Construct and generate Model objects for they query. The
objects are returned in the order emitted from the database; no
slow sort is applied.
For performance, this generator caches materialized objects to
avoid constructing them more than once. This way, iterating over
a `Results` object a second time should be much faster than the
first.
"""
index = 0 # Position in the materialized objects.
while index < len(self._objects) or self._rows:
# Are there previously-materialized objects to produce?
if index < len(self._objects):
yield self._objects[index]
index += 1
# Otherwise, we consume another row, materialize its object
# and produce it.
else:
while self._rows:
row = self._rows.pop(0)
obj = self._make_model(row)
# If there is a slow-query predicate, ensure that the
# object passes it.
if not self.query or self.query.match(obj):
self._objects.append(obj)
index += 1
yield obj
break
def __iter__(self):
"""Construct and generate Model objects for all matching
objects, in sorted order.
"""
if self.sort:
# Slow sort. Must build the full list first.
objects = self.sort.sort(list(self._get_objects()))
return iter(objects)
else:
# Objects are pre-sorted (i.e., by the database).
return self._get_objects()
def _make_model(self, row):
# Get the flexible attributes for the object.
with self.db.transaction() as tx:
flex_rows = tx.query(
'SELECT * FROM {0} WHERE entity_id=?'.format(
self.model_class._flex_table
),
(row['id'],)
)
cols = dict(row)
values = dict((k, v) for (k, v) in cols.items()
if not k[:4] == 'flex')
flex_values = dict((row['key'], row['value']) for row in flex_rows)
# Construct the Python object
obj = self.model_class._awaken(self.db, values, flex_values)
return obj
def __len__(self):
"""Get the number of matching objects.
"""
if not self._rows:
# Fully materialized. Just count the objects.
return len(self._objects)
elif self.query:
# A slow query. Fall back to testing every object.
count = 0
for obj in self:
count += 1
return count
else:
# A fast query. Just count the rows.
return self._row_count
def __nonzero__(self):
"""Does this result contain any objects?
"""
return bool(len(self))
def __getitem__(self, n):
"""Get the nth item in this result set. This is inefficient: all
items up to n are materialized and thrown away.
"""
if not self._rows and not self.sort:
# Fully materialized and already in order. Just look up the
# object.
return self._objects[n]
it = iter(self)
try:
for i in range(n):
it.next()
return it.next()
except StopIteration:
raise IndexError('result index {0} out of range'.format(n))
def get(self):
"""Return the first matching object, or None if no objects
match.
"""
it = iter(self)
try:
return it.next()
except StopIteration:
return None
class Transaction(object):
"""A context manager for safe, concurrent access to the database.
All SQL commands should be executed through a transaction.
"""
def __init__(self, db):
self.db = db
def __enter__(self):
"""Begin a transaction. This transaction may be created while
another is active in a different thread.
"""
with self.db._tx_stack() as stack:
first = not stack
stack.append(self)
if first:
# Beginning a "root" transaction, which corresponds to an
# SQLite transaction.
self.db._db_lock.acquire()
return self
def __exit__(self, exc_type, exc_value, traceback):
"""Complete a transaction. This must be the most recently
entered but not yet exited transaction. If it is the last active
transaction, the database updates are committed.
"""
with self.db._tx_stack() as stack:
assert stack.pop() is self
empty = not stack
if empty:
# Ending a "root" transaction. End the SQLite transaction.
self.db._connection().commit()
self.db._db_lock.release()
def query(self, statement, subvals=()):
"""Execute an SQL statement with substitution values and return
a list of rows from the database.
"""
cursor = self.db._connection().execute(statement, subvals)
return cursor.fetchall()
def mutate(self, statement, subvals=()):
"""Execute an SQL statement with substitution values and return
the row ID of the last affected row.
"""
cursor = self.db._connection().execute(statement, subvals)
return cursor.lastrowid
def script(self, statements):
"""Execute a string containing multiple SQL statements."""
self.db._connection().executescript(statements)
class Database(object):
"""A container for Model objects that wraps an SQLite database as
the backend.
"""
_models = ()
"""The Model subclasses representing tables in this database.
"""
def __init__(self, path):
self.path = path
self._connections = {}
self._tx_stacks = defaultdict(list)
# A lock to protect the _connections and _tx_stacks maps, which
# both map thread IDs to private resources.
self._shared_map_lock = threading.Lock()
# A lock to protect access to the database itself. SQLite does
# allow multiple threads to access the database at the same
# time, but many users were experiencing crashes related to this
# capability: where SQLite was compiled without HAVE_USLEEP, its
# backoff algorithm in the case of contention was causing
# whole-second sleeps (!) that would trigger its internal
# timeout. Using this lock ensures only one SQLite transaction
# is active at a time.
self._db_lock = threading.Lock()
# Set up database schema.
for model_cls in self._models:
self._make_table(model_cls._table, model_cls._fields)
self._make_attribute_table(model_cls._flex_table)
# Primitive access control: connections and transactions.
def _connection(self):
"""Get a SQLite connection object to the underlying database.
One connection object is created per thread.
"""
thread_id = threading.current_thread().ident
with self._shared_map_lock:
if thread_id in self._connections:
return self._connections[thread_id]
else:
# Make a new connection.
conn = sqlite3.connect(
self.path,
timeout=beets.config['timeout'].as_number(),
)
# Access SELECT results like dictionaries.
conn.row_factory = sqlite3.Row
self._connections[thread_id] = conn
return conn
@contextlib.contextmanager
def _tx_stack(self):
"""A context manager providing access to the current thread's
transaction stack. The context manager synchronizes access to
the stack map. Transactions should never migrate across threads.
"""
thread_id = threading.current_thread().ident
with self._shared_map_lock:
yield self._tx_stacks[thread_id]
def transaction(self):
"""Get a :class:`Transaction` object for interacting directly
with the underlying SQLite database.
"""
return Transaction(self)
# Schema setup and migration.
def _make_table(self, table, fields):
"""Set up the schema of the database. `fields` is a mapping
from field names to `Type`s. Columns are added if necessary.
"""
# Get current schema.
with self.transaction() as tx:
rows = tx.query('PRAGMA table_info(%s)' % table)
current_fields = set([row[1] for row in rows])
field_names = set(fields.keys())
if current_fields.issuperset(field_names):
# Table exists and has all the required columns.
return
if not current_fields:
# No table exists.
columns = []
for name, typ in fields.items():
columns.append('{0} {1}'.format(name, typ.sql))
setup_sql = 'CREATE TABLE {0} ({1});\n'.format(table,
', '.join(columns))
else:
# Table exists but does not match the field set.
setup_sql = ''
for name, typ in fields.items():
if name in current_fields:
continue
setup_sql += 'ALTER TABLE {0} ADD COLUMN {1} {2};\n'.format(
table, name, typ.sql
)
with self.transaction() as tx:
tx.script(setup_sql)
def _make_attribute_table(self, flex_table):
"""Create a table and associated index for flexible attributes
for the given entity (if they don't exist).
"""
with self.transaction() as tx:
tx.script("""
CREATE TABLE IF NOT EXISTS {0} (
id INTEGER PRIMARY KEY,
entity_id INTEGER,
key TEXT,
value TEXT,
UNIQUE(entity_id, key) ON CONFLICT REPLACE);
CREATE INDEX IF NOT EXISTS {0}_by_entity
ON {0} (entity_id);
""".format(flex_table))
# Querying.
def _fetch(self, model_cls, query=None, sort=None):
"""Fetch the objects of type `model_cls` matching the given
query. The query may be given as a string, string sequence, a
Query object, or None (to fetch everything). `sort` is a
`Sort` object.
"""
query = query or TrueQuery() # A null query.
sort = sort or NullSort() # Unsorted.
where, subvals = query.clause()
order_by = sort.order_clause()
sql = ("SELECT * FROM {0} WHERE {1} {2}").format(
model_cls._table,
where or '1',
"ORDER BY {0}".format(order_by) if order_by else '',
)
with self.transaction() as tx:
rows = tx.query(sql, subvals)
return Results(
model_cls, rows, self,
None if where else query, # Slow query component.
sort if sort.is_slow() else None, # Slow sort component.
)
def _get(self, model_cls, id):
"""Get a Model object by its id or None if the id does not
exist.
"""
return self._fetch(model_cls, MatchQuery('id', id)).get()
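# Example (editor's sketch, not part of beets): wiring a model into a
# Database and round-tripping an object. Assumes the hypothetical
# ExampleNote class sketched above and beets' default configuration,
# which supplies the SQLite `timeout` setting read by _connection().
class ExampleDatabase(Database):
    _models = (ExampleNote,)

db = ExampleDatabase(':memory:')  # Schema is created on construction.
note = ExampleNote(db, title=u'hello')
note.add()                        # INSERT; sets `id` and `added`.
note.mood = u'calm'               # Dirty flexible field...
note.store()                      # ...flushed to the flex table here.
fetched = db._fetch(ExampleNote, MatchQuery('title', u'hello')).get()
assert fetched.id == note.id
assert fetched.mood == u'calm'
# Transactions nest within a thread; only the outermost one commits.
with db.transaction() as tx:
    rows = tx.query('SELECT COUNT(*) FROM notes')
assert rows[0][0] == 1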

654
lib/beets/dbcore/query.py Normal file
View File

@@ -0,0 +1,654 @@
# This file is part of beets.
# Copyright 2014, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""The Query type hierarchy for DBCore.
"""
import re
from operator import attrgetter
from beets import util
from datetime import datetime, timedelta
class Query(object):
"""An abstract class representing a query into the item database.
"""
def clause(self):
"""Generate an SQLite expression implementing the query.
Return a tuple (clause, subvals): `clause` is a valid SQLite
WHERE clause implementing the query and `subvals` is a list of
items to be substituted for ?s in the clause.
"""
return None, ()
def match(self, item):
"""Check whether this query matches a given Item. Can be used to
perform queries on arbitrary sets of Items.
"""
raise NotImplementedError
class FieldQuery(Query):
"""An abstract query that searches in a specific field for a
pattern. Subclasses must provide a `value_match` class method, which
determines whether a certain pattern string matches a certain value
string. Subclasses may also provide `col_clause` to implement the
same matching functionality in SQLite.
"""
def __init__(self, field, pattern, fast=True):
self.field = field
self.pattern = pattern
self.fast = fast
def col_clause(self):
return None, ()
def clause(self):
if self.fast:
return self.col_clause()
else:
# Matching a flexattr. This is a slow query.
return None, ()
@classmethod
def value_match(cls, pattern, value):
"""Determine whether the value matches the pattern. Both
arguments are strings.
"""
raise NotImplementedError()
def match(self, item):
return self.value_match(self.pattern, item.get(self.field))
class MatchQuery(FieldQuery):
"""A query that looks for exact matches in an item field."""
def col_clause(self):
return self.field + " = ?", [self.pattern]
@classmethod
def value_match(cls, pattern, value):
return pattern == value
class NoneQuery(FieldQuery):
def __init__(self, field, fast=True):
super(NoneQuery, self).__init__(field, None, fast)
def col_clause(self):
return self.field + " IS NULL", ()
def match(self, item):
try:
return item[self.field] is None
except KeyError:
return True
class StringFieldQuery(FieldQuery):
"""A FieldQuery that converts values to strings before matching
them.
"""
@classmethod
def value_match(cls, pattern, value):
"""Determine whether the value matches the pattern. The value
may have any type.
"""
return cls.string_match(pattern, util.as_string(value))
@classmethod
def string_match(cls, pattern, value):
"""Determine whether the value matches the pattern. Both
arguments are strings. Subclasses implement this method.
"""
raise NotImplementedError()
class SubstringQuery(StringFieldQuery):
"""A query that matches a substring in a specific item field."""
def col_clause(self):
pattern = (self.pattern
.replace('\\', '\\\\')
.replace('%', '\\%')
.replace('_', '\\_'))
search = '%' + pattern + '%'
clause = self.field + " like ? escape '\\'"
subvals = [search]
return clause, subvals
@classmethod
def string_match(cls, pattern, value):
return pattern.lower() in value.lower()
class RegexpQuery(StringFieldQuery):
"""A query that matches a regular expression in a specific item
field.
"""
@classmethod
def string_match(cls, pattern, value):
try:
res = re.search(pattern, value)
except re.error:
# Invalid regular expression.
return False
return res is not None
class BooleanQuery(MatchQuery):
"""Matches a boolean field. Pattern should either be a boolean or a
string reflecting a boolean.
"""
def __init__(self, field, pattern, fast=True):
super(BooleanQuery, self).__init__(field, pattern, fast)
if isinstance(pattern, basestring):
self.pattern = util.str2bool(pattern)
self.pattern = int(self.pattern)
class BytesQuery(MatchQuery):
"""Match a raw bytes field (i.e., a path). This is a necessary hack
to work around the `sqlite3` module's desire to treat `str` and
`unicode` equivalently in Python 2. Always use this query instead of
`MatchQuery` when matching on BLOB values.
"""
def __init__(self, field, pattern):
super(BytesQuery, self).__init__(field, pattern)
# Use a buffer representation of the pattern for SQLite
# matching. This instructs SQLite to treat the blob as binary
# rather than encoded Unicode.
if isinstance(self.pattern, basestring):
# Implicitly coerce Unicode strings to their bytes
# equivalents.
if isinstance(self.pattern, unicode):
self.pattern = self.pattern.encode('utf8')
self.buf_pattern = buffer(self.pattern)
elif isinstance(self.pattern, buffer):
self.buf_pattern = self.pattern
self.pattern = bytes(self.pattern)
def col_clause(self):
return self.field + " = ?", [self.buf_pattern]
class NumericQuery(FieldQuery):
"""Matches numeric fields. A syntax using Ruby-style range ellipses
(``..``) lets users specify one- or two-sided ranges. For example,
``year:2001..`` finds music released since the turn of the century.
"""
def _convert(self, s):
"""Convert a string to a numeric type (float or int). If the
string cannot be converted, return None.
"""
# This is really just a bit of fun premature optimization.
try:
return int(s)
except ValueError:
try:
return float(s)
except ValueError:
return None
def __init__(self, field, pattern, fast=True):
super(NumericQuery, self).__init__(field, pattern, fast)
parts = pattern.split('..', 1)
if len(parts) == 1:
# No range.
self.point = self._convert(parts[0])
self.rangemin = None
self.rangemax = None
else:
# One- or two-sided range.
self.point = None
self.rangemin = self._convert(parts[0])
self.rangemax = self._convert(parts[1])
def match(self, item):
if self.field not in item:
return False
value = item[self.field]
if isinstance(value, basestring):
value = self._convert(value)
if self.point is not None:
return value == self.point
else:
if self.rangemin is not None and value < self.rangemin:
return False
if self.rangemax is not None and value > self.rangemax:
return False
return True
def col_clause(self):
if self.point is not None:
return self.field + '=?', (self.point,)
else:
if self.rangemin is not None and self.rangemax is not None:
return (u'{0} >= ? AND {0} <= ?'.format(self.field),
(self.rangemin, self.rangemax))
elif self.rangemin is not None:
return u'{0} >= ?'.format(self.field), (self.rangemin,)
elif self.rangemax is not None:
return u'{0} <= ?'.format(self.field), (self.rangemax,)
else:
return '1', ()
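# Example (editor's sketch, not part of beets): the one-sided range
# syntax described above. The field name is illustrative.
q = NumericQuery('year', '2001..')
assert q.col_clause() == (u'year >= ?', (2001,))
assert q.match({'year': 2010})
assert not q.match({'year': 1999})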
class CollectionQuery(Query):
"""An abstract query class that aggregates other queries. Can be
indexed like a list to access the sub-queries.
"""
def __init__(self, subqueries=()):
self.subqueries = subqueries
# Act like a sequence.
def __len__(self):
return len(self.subqueries)
def __getitem__(self, key):
return self.subqueries[key]
def __iter__(self):
return iter(self.subqueries)
def __contains__(self, item):
return item in self.subqueries
def clause_with_joiner(self, joiner):
"""Returns a clause created by joining together the clauses of
all subqueries with the string joiner (padded by spaces).
"""
clause_parts = []
subvals = []
for subq in self.subqueries:
subq_clause, subq_subvals = subq.clause()
if not subq_clause:
# Fall back to slow query.
return None, ()
clause_parts.append('(' + subq_clause + ')')
subvals += subq_subvals
clause = (' ' + joiner + ' ').join(clause_parts)
return clause, subvals
class AnyFieldQuery(CollectionQuery):
"""A query that matches if a given FieldQuery subclass matches in
any field. The individual field query class is provided to the
constructor.
"""
def __init__(self, pattern, fields, cls):
self.pattern = pattern
self.fields = fields
self.query_class = cls
subqueries = []
for field in self.fields:
subqueries.append(cls(field, pattern, True))
super(AnyFieldQuery, self).__init__(subqueries)
def clause(self):
return self.clause_with_joiner('or')
def match(self, item):
for subq in self.subqueries:
if subq.match(item):
return True
return False
class MutableCollectionQuery(CollectionQuery):
"""A collection query whose subqueries may be modified after the
query is initialized.
"""
def __setitem__(self, key, value):
self.subqueries[key] = value
def __delitem__(self, key):
del self.subqueries[key]
class AndQuery(MutableCollectionQuery):
"""A conjunction of a list of other queries."""
def clause(self):
return self.clause_with_joiner('and')
def match(self, item):
return all([q.match(item) for q in self.subqueries])
class OrQuery(MutableCollectionQuery):
"""A conjunction of a list of other queries."""
def clause(self):
return self.clause_with_joiner('or')
def match(self, item):
return any([q.match(item) for q in self.subqueries])
class TrueQuery(Query):
"""A query that always matches."""
def clause(self):
return '1', ()
def match(self, item):
return True
class FalseQuery(Query):
"""A query that never matches."""
def clause(self):
return '0', ()
def match(self, item):
return False
# Time/date queries.
def _to_epoch_time(date):
"""Convert a `datetime` object to an integer number of seconds since
the (local) Unix epoch.
"""
epoch = datetime.fromtimestamp(0)
delta = date - epoch
try:
return int(delta.total_seconds())
except AttributeError:
# datetime.timedelta.total_seconds() is not available on Python 2.6
return delta.seconds + delta.days * 24 * 3600
def _parse_periods(pattern):
"""Parse a string containing two dates separated by two dots (..).
Return a pair of `Period` objects.
"""
parts = pattern.split('..', 1)
if len(parts) == 1:
instant = Period.parse(parts[0])
return (instant, instant)
else:
start = Period.parse(parts[0])
end = Period.parse(parts[1])
return (start, end)
class Period(object):
"""A period of time given by a date, time and precision.
Example: 2014-01-01 10:50:30 with precision 'month' represents all
instants of time during January 2014.
"""
precisions = ('year', 'month', 'day')
date_formats = ('%Y', '%Y-%m', '%Y-%m-%d')
def __init__(self, date, precision):
"""Create a period with the given date (a `datetime` object) and
precision (a string, one of "year", "month", or "day").
"""
if precision not in Period.precisions:
raise ValueError('Invalid precision ' + str(precision))
self.date = date
self.precision = precision
@classmethod
def parse(cls, string):
"""Parse a date and return a `Period` object or `None` if the
string is empty.
"""
if not string:
return None
ordinal = string.count('-')
if ordinal >= len(cls.date_formats):
# Too many components.
return None
date_format = cls.date_formats[ordinal]
try:
date = datetime.strptime(string, date_format)
except ValueError:
# Parsing failed.
return None
precision = cls.precisions[ordinal]
return cls(date, precision)
def open_right_endpoint(self):
"""Based on the precision, convert the period to a precise
`datetime` for use as a right endpoint in a right-open interval.
"""
precision = self.precision
date = self.date
if 'year' == precision:
return date.replace(year=date.year + 1, month=1)
elif 'month' == precision:
if (date.month < 12):
return date.replace(month=date.month + 1)
else:
return date.replace(year=date.year + 1, month=1)
elif 'day' == precision:
return date + timedelta(days=1)
else:
raise ValueError('unhandled precision ' + str(precision))
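# Example (editor's sketch, not part of beets): precision is inferred
# from the number of date components supplied.
p = Period.parse('2014-01')
assert p.precision == 'month'
assert p.open_right_endpoint() == datetime(2014, 2, 1)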
class DateInterval(object):
"""A closed-open interval of dates.
A left endpoint of None means since the beginning of time.
A right endpoint of None means towards infinity.
"""
def __init__(self, start, end):
if start is not None and end is not None and not start < end:
raise ValueError("start date {0} is not before end date {1}"
.format(start, end))
self.start = start
self.end = end
@classmethod
def from_periods(cls, start, end):
"""Create an interval with two Periods as the endpoints.
"""
end_date = end.open_right_endpoint() if end is not None else None
start_date = start.date if start is not None else None
return cls(start_date, end_date)
def contains(self, date):
if self.start is not None and date < self.start:
return False
if self.end is not None and date >= self.end:
return False
return True
def __str__(self):
return '[{0}, {1})'.format(self.start, self.end)
class DateQuery(FieldQuery):
"""Matches date fields stored as seconds since Unix epoch time.
Dates can be specified as ``year-month-day`` strings where only year
is mandatory.
The value of a date field can be matched against a date interval by
using an ellipsis interval syntax similar to that of NumericQuery.
"""
def __init__(self, field, pattern, fast=True):
super(DateQuery, self).__init__(field, pattern, fast)
start, end = _parse_periods(pattern)
self.interval = DateInterval.from_periods(start, end)
def match(self, item):
timestamp = float(item[self.field])
date = datetime.utcfromtimestamp(timestamp)
return self.interval.contains(date)
_clause_tmpl = "{0} {1} ?"
def col_clause(self):
clause_parts = []
subvals = []
if self.interval.start:
clause_parts.append(self._clause_tmpl.format(self.field, ">="))
subvals.append(_to_epoch_time(self.interval.start))
if self.interval.end:
clause_parts.append(self._clause_tmpl.format(self.field, "<"))
subvals.append(_to_epoch_time(self.interval.end))
if clause_parts:
# One- or two-sided interval.
clause = ' AND '.join(clause_parts)
else:
# Match any date.
clause = '1'
return clause, subvals
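# Example (editor's sketch, not part of beets): a one-sided date
# interval. The field name and dates are illustrative.
dq = DateQuery('added', '2014..')
assert dq.col_clause() == ('added >= ?', [_to_epoch_time(datetime(2014, 1, 1))])
assert dq.match({'added': _to_epoch_time(datetime(2014, 6, 1))})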
# Sorting.
class Sort(object):
"""An abstract class representing a sort operation for a query into
the item database.
"""
def order_clause(self):
"""Generates a SQL fragment to be used in a ORDER BY clause, or
None if no fragment is used (i.e., this is a slow sort).
"""
return None
def sort(self, items):
"""Sort the list of objects and return a list.
"""
return sorted(items)
def is_slow(self):
"""Indicate whether this query is *slow*, meaning that it cannot
be executed in SQL and must be executed in Python.
"""
return False
class MultipleSort(Sort):
"""Sort that encapsulates multiple sub-sorts.
"""
def __init__(self, sorts=None):
self.sorts = sorts or []
def add_sort(self, sort):
self.sorts.append(sort)
def _sql_sorts(self):
"""Return the list of sub-sorts for which we can be (at least
partially) fast.
A contiguous suffix of fast (SQL-capable) sub-sorts can be
executed in SQL. The remaining sub-sorts, even if individually
fast, must be executed slowly in Python.
"""
sql_sorts = []
for sort in reversed(self.sorts):
if not sort.order_clause() is None:
sql_sorts.append(sort)
else:
break
sql_sorts.reverse()
return sql_sorts
def order_clause(self):
order_strings = []
for sort in self._sql_sorts():
order = sort.order_clause()
order_strings.append(order)
return ", ".join(order_strings)
def is_slow(self):
for sort in self.sorts:
if sort.is_slow():
return True
return False
def sort(self, items):
slow_sorts = []
switch_slow = False
for sort in reversed(self.sorts):
if switch_slow:
slow_sorts.append(sort)
elif sort.order_clause() is None:
switch_slow = True
slow_sorts.append(sort)
else:
pass
for sort in slow_sorts:
items = sort.sort(items)
return items
def __repr__(self):
return u'MultipleSort({0})'.format(repr(self.sorts))
class FieldSort(Sort):
"""An abstract sort criterion that orders by a specific field (of
any kind).
"""
def __init__(self, field, ascending=True):
self.field = field
self.ascending = ascending
def sort(self, objs):
# TODO: Conversion and null-detection here. In Python 3,
# comparisons with None fail. We should also support flexible
# attributes with different types without falling over.
return sorted(objs, key=attrgetter(self.field),
reverse=not self.ascending)
def __repr__(self):
return u'<{0}: {1}{2}>'.format(
type(self).__name__,
self.field,
'+' if self.ascending else '-',
)
class FixedFieldSort(FieldSort):
"""Sort object to sort on a fixed field.
"""
def order_clause(self):
order = "ASC" if self.ascending else "DESC"
return "{0} {1}".format(self.field, order)
class SlowFieldSort(FieldSort):
"""A sort criterion by some model field other than a fixed field:
i.e., a computed or flexible field.
"""
def is_slow(self):
return True
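# Example (editor's sketch, not part of beets): only the contiguous
# suffix of SQL-capable sub-sorts runs in the database; the flexible
# field at the front forces a slow in-Python pass. Field names are
# illustrative.
s = MultipleSort([SlowFieldSort('mood'), FixedFieldSort('year')])
assert s.order_clause() == 'year ASC'
assert s.is_slow()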
class NullSort(Sort):
"""No sorting. Leave results unsorted."""
def sort(self, items):
return items

180
lib/beets/dbcore/queryparse.py Normal file
View File

@@ -0,0 +1,180 @@
# This file is part of beets.
# Copyright 2014, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""Parsing of strings into DBCore queries.
"""
import re
import itertools
from . import query
PARSE_QUERY_PART_REGEX = re.compile(
# Non-capturing optional segment for the keyword.
r'(?:'
r'(\S+?)' # The field key.
r'(?<!\\):' # Unescaped :
r')?'
r'(.*)', # The term itself.
re.I # Case-insensitive.
)
def parse_query_part(part, query_classes={}, prefixes={},
default_class=query.SubstringQuery):
"""Take a query in the form of a key/value pair separated by a
colon and return a tuple of `(key, value, cls)`. `key` may be None,
indicating that any field may be matched. `cls` is a subclass of
`FieldQuery`.
The optional `query_classes` parameter maps field names to default
query types; `default_class` is the fallback. `prefixes` is a map
from query prefix markers to query types. Prefix-indicated queries
take precedence over type-based queries.
To determine the query class, two factors are used: prefixes and
field types. For example, the colon prefix denotes a regular
expression query and a type map might provide a special kind of
query for numeric values. If neither a prefix nor a specific query
class is available, `default_class` is used.
For instance,
'stapler' -> (None, 'stapler', SubstringQuery)
'color:red' -> ('color', 'red', SubstringQuery)
':^Quiet' -> (None, '^Quiet', RegexpQuery)
'color::b..e' -> ('color', 'b..e', RegexpQuery)
Prefixes may be "escaped" with a backslash to disable the keying
behavior.
"""
part = part.strip()
match = PARSE_QUERY_PART_REGEX.match(part)
assert match # Regex should always match.
key = match.group(1)
term = match.group(2).replace('\:', ':')
# Match the search term against the list of prefixes.
for pre, query_class in prefixes.items():
if term.startswith(pre):
return key, term[len(pre):], query_class
# No matching prefix: use type-based or fallback/default query.
query_class = query_classes.get(key, default_class)
return key, term, query_class
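# Example (editor's sketch, not part of beets): the docstring cases
# above, exercised with an explicit prefix map.
key, term, cls = parse_query_part('color:red')
assert (key, term, cls) == ('color', 'red', query.SubstringQuery)
key, term, cls = parse_query_part(':^Quiet', prefixes={':': query.RegexpQuery})
assert (key, term, cls) == (None, '^Quiet', query.RegexpQuery)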
def construct_query_part(model_cls, prefixes, query_part):
"""Create a query from a single query component, `query_part`, for
querying instances of `model_cls`. Return a `Query` instance.
"""
# Shortcut for empty query parts.
if not query_part:
return query.TrueQuery()
# Get the query classes for each possible field.
query_classes = {}
for k, t in itertools.chain(model_cls._fields.items(),
model_cls._types.items()):
query_classes[k] = t.query
# Parse the string.
key, pattern, query_class = \
parse_query_part(query_part, query_classes, prefixes)
# No key specified.
if key is None:
if issubclass(query_class, query.FieldQuery):
# The query type matches a specific field, but none was
# specified. So we use a version of the query that matches
# any field.
return query.AnyFieldQuery(pattern, model_cls._search_fields,
query_class)
else:
# Other query type.
return query_class(pattern)
key = key.lower()
return query_class(key, pattern, key in model_cls._fields)
def query_from_strings(query_cls, model_cls, prefixes, query_parts):
"""Creates a collection query of type `query_cls` from a list of
strings in the format used by parse_query_part. `model_cls`
determines how queries are constructed from strings.
"""
subqueries = []
for part in query_parts:
subqueries.append(construct_query_part(model_cls, prefixes, part))
if not subqueries: # No terms in query.
subqueries = [query.TrueQuery()]
return query_cls(subqueries)
def construct_sort_part(model_cls, part):
"""Create a `Sort` from a single string criterion.
`model_cls` is the `Model` being queried. `part` is a single string
ending in ``+`` or ``-`` indicating the sort.
"""
assert part, "part must be a field name and + or -"
field = part[:-1]
assert field, "field is missing"
direction = part[-1]
assert direction in ('+', '-'), "part must end with + or -"
is_ascending = direction == '+'
if field in model_cls._sorts:
sort = model_cls._sorts[field](model_cls, is_ascending)
elif field in model_cls._fields:
sort = query.FixedFieldSort(field, is_ascending)
else:
# Flexible or computed.
sort = query.SlowFieldSort(field, is_ascending)
return sort
def sort_from_strings(model_cls, sort_parts):
"""Create a `Sort` from a list of sort criteria (strings).
"""
if not sort_parts:
return query.NullSort()
else:
sort = query.MultipleSort()
for part in sort_parts:
sort.add_sort(construct_sort_part(model_cls, part))
return sort
def parse_sorted_query(model_cls, parts, prefixes={},
query_cls=query.AndQuery):
"""Given a list of strings, create the `Query` and `Sort` that they
represent.
"""
# Separate query token and sort token.
query_parts = []
sort_parts = []
for part in parts:
if part.endswith((u'+', u'-')) and u':' not in part:
sort_parts.append(part)
else:
query_parts.append(part)
# Parse each.
q = query_from_strings(
query_cls, model_cls, prefixes, query_parts
)
s = sort_from_strings(model_cls, sort_parts)
return q, s
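# Example (editor's sketch, not part of beets): mixed tokens are split
# into a query and a sort. Assumes the hypothetical ExampleNote model
# sketched in db.py above, imported here for illustration.
q, s = parse_sorted_query(ExampleNote, [u'title:hello', u'title+'])
assert isinstance(q, query.AndQuery)
assert s.order_clause() == 'title ASC'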

208
lib/beets/dbcore/types.py Normal file
View File

@@ -0,0 +1,208 @@
# This file is part of beets.
# Copyright 2014, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""Representation of type information for DBCore model fields.
"""
from . import query
from beets.util import str2bool
# Abstract base.
class Type(object):
"""An object encapsulating the type of a model field. Includes
information about how to store, query, format, and parse a given
field.
"""
sql = u'TEXT'
"""The SQLite column type for the value.
"""
query = query.SubstringQuery
"""The `Query` subclass to be used when querying the field.
"""
model_type = unicode
"""The Python type that is used to represent the value in the model.
The model is guaranteed to return a value of this type if the field
is accessed. To this end, the constructor is used by the `normalize`
and `from_sql` methods and the `null` property.
"""
@property
def null(self):
"""The value to be exposed when the underlying value is None.
"""
return self.model_type()
def format(self, value):
"""Given a value of this type, produce a Unicode string
representing the value. This is used in template evaluation.
"""
if value is None:
value = self.null
# `self.null` might be `None`
if value is None:
value = u''
if isinstance(value, bytes):
value = value.decode('utf8', 'ignore')
return unicode(value)
def parse(self, string):
"""Parse a (possibly human-written) string and return the
indicated value of this type.
"""
try:
return self.model_type(string)
except ValueError:
return self.null
def normalize(self, value):
"""Given a value that will be assigned into a field of this
type, normalize the value to have the appropriate type. This
base implementation only reinterprets `None`.
"""
if value is None:
return self.null
else:
# TODO This should eventually be replaced by
# `self.model_type(value)`
return value
def from_sql(self, sql_value):
"""Receives the value stored in the SQL backend and return the
value to be stored in the model.
For fixed fields the type of `value` is determined by the column
type affinity given in the `sql` property and the SQL to Python
mapping of the database adapter. For more information see:
http://www.sqlite.org/datatype3.html
https://docs.python.org/2/library/sqlite3.html#sqlite-and-python-types
Flexible fields have the type affinity `TEXT`. This means the
`sql_value` is either a `buffer` or a `unicode` object, and the
method must handle both.
"""
if isinstance(sql_value, buffer):
sql_value = bytes(sql_value).decode('utf8', 'ignore')
if isinstance(sql_value, unicode):
return self.parse(sql_value)
else:
return self.normalize(sql_value)
def to_sql(self, model_value):
"""Convert a value as stored in the model object to a value used
by the database adapter.
"""
return model_value
# Reusable types.
class Default(Type):
null = None
class Integer(Type):
"""A basic integer type.
"""
sql = u'INTEGER'
query = query.NumericQuery
model_type = int
class PaddedInt(Integer):
"""An integer field that is formatted with a given number of digits,
padded with zeroes.
"""
def __init__(self, digits):
self.digits = digits
def format(self, value):
return u'{0:0{1}d}'.format(value or 0, self.digits)
class ScaledInt(Integer):
"""An integer whose formatting operation scales the number by a
constant and adds a suffix. Good for units with large magnitudes.
"""
def __init__(self, unit, suffix=u''):
self.unit = unit
self.suffix = suffix
def format(self, value):
return u'{0}{1}'.format((value or 0) // self.unit, self.suffix)
class Id(Integer):
"""An integer used as the row id or a foreign key in a SQLite table.
This type is nullable: None values are not translated to zero.
"""
null = None
def __init__(self, primary=True):
if primary:
self.sql = u'INTEGER PRIMARY KEY'
class Float(Type):
"""A basic floating-point type.
"""
sql = u'REAL'
query = query.NumericQuery
model_type = float
def format(self, value):
return u'{0:.1f}'.format(value or 0.0)
class NullFloat(Float):
"""Same as `Float`, but does not normalize `None` to `0.0`.
"""
null = None
class String(Type):
"""A Unicode string type.
"""
sql = u'TEXT'
query = query.SubstringQuery
class Boolean(Type):
"""A boolean type.
"""
sql = u'INTEGER'
query = query.BooleanQuery
model_type = bool
def format(self, value):
return unicode(bool(value))
def parse(self, string):
return str2bool(string)
# Shared instances of common types.
DEFAULT = Default()
INTEGER = Integer()
PRIMARY_ID = Id(True)
FOREIGN_ID = Id(False)
FLOAT = Float()
NULL_FLOAT = NullFloat()
STRING = String()
BOOLEAN = Boolean()
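# Example (editor's sketch, not part of beets): how the formatting and
# parsing hooks behave. The values are illustrative.
assert PaddedInt(4).format(7) == u'0007'
assert ScaledInt(1024, u'KiB').format(4096) == u'4KiB'
assert FLOAT.format(None) == u'0.0'
assert BOOLEAN.parse('yes') is True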

1428
lib/beets/importer.py Normal file

File diff suppressed because it is too large

1299
lib/beets/library.py Normal file

File diff suppressed because it is too large

1929
lib/beets/mediafile.py Normal file

File diff suppressed because it is too large

435
lib/beets/plugins.py Executable file
View File

@@ -0,0 +1,435 @@
# This file is part of beets.
# Copyright 2013, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""Support for beets plugins."""
import logging
import traceback
import inspect
import re
from collections import defaultdict
import beets
from beets import mediafile
PLUGIN_NAMESPACE = 'beetsplug'
# Plugins using the Last.fm API can share the same API key.
LASTFM_KEY = '2dc3914abf35f0d9c92d97d8f8e42b43'
# Global logger.
log = logging.getLogger('beets')
class PluginConflictException(Exception):
"""Indicates that the services provided by one plugin conflict with
those of another.
For example, two plugins may define different types for the same flexible field.
"""
# Managing the plugins themselves.
class BeetsPlugin(object):
"""The base class for all beets plugins. Plugins provide
functionality by defining a subclass of BeetsPlugin and overriding
the abstract methods defined here.
"""
def __init__(self, name=None):
"""Perform one-time plugin setup.
"""
self.import_stages = []
self.name = name or self.__module__.split('.')[-1]
self.config = beets.config[self.name]
if not self.template_funcs:
self.template_funcs = {}
if not self.template_fields:
self.template_fields = {}
if not self.album_template_fields:
self.album_template_fields = {}
def commands(self):
"""Should return a list of beets.ui.Subcommand objects for
commands that should be added to beets' CLI.
"""
return ()
def queries(self):
"""Should return a dict mapping prefixes to Query subclasses.
"""
return {}
def track_distance(self, item, info):
"""Should return a Distance object to be added to the
distance for every track comparison.
"""
return beets.autotag.hooks.Distance()
def album_distance(self, items, album_info, mapping):
"""Should return a Distance object to be added to the
distance for every album-level comparison.
"""
return beets.autotag.hooks.Distance()
def candidates(self, items, artist, album, va_likely):
"""Should return a sequence of AlbumInfo objects that match the
album whose items are provided.
"""
return ()
def item_candidates(self, item, artist, title):
"""Should return a sequence of TrackInfo objects that match the
item provided.
"""
return ()
def album_for_id(self, album_id):
"""Return an AlbumInfo object or None if no matching release was
found.
"""
return None
def track_for_id(self, track_id):
"""Return a TrackInfo object or None if no matching release was
found.
"""
return None
def add_media_field(self, name, descriptor):
"""Add a field that is synchronized between media files and items.
When a media field is added ``item.write()`` will set the name
property of the item's MediaFile to ``item[name]`` and save the
changes. Similarly ``item.read()`` will set ``item[name]`` to
the value of the name property of the media file.
``descriptor`` must be an instance of ``mediafile.MediaField``.
"""
# Defer import to prevent a circular dependency.
from beets import library
mediafile.MediaFile.add_field(name, descriptor)
library.Item._media_fields.add(name)
listeners = None
@classmethod
def register_listener(cls, event, func):
"""Add a function as a listener for the specified event. (An
imperative alternative to the @listen decorator.)
"""
if cls.listeners is None:
cls.listeners = defaultdict(list)
cls.listeners[event].append(func)
@classmethod
def listen(cls, event):
"""Decorator that adds a function as an event handler for the
specified event (as a string). The parameters passed to function
will vary depending on what event occurred.
The function should respond to named parameters.
function(**kwargs) will trap all arguments in a dictionary.
Example:
>>> @MyPlugin.listen("imported")
>>> def importListener(**kwargs):
... pass
"""
def helper(func):
if cls.listeners is None:
cls.listeners = defaultdict(list)
cls.listeners[event].append(func)
return func
return helper
template_funcs = None
template_fields = None
album_template_fields = None
@classmethod
def template_func(cls, name):
"""Decorator that registers a path template function. The
function will be invoked as ``%name{}`` from path format
strings.
"""
def helper(func):
if cls.template_funcs is None:
cls.template_funcs = {}
cls.template_funcs[name] = func
return func
return helper
@classmethod
def template_field(cls, name):
"""Decorator that registers a path template field computation.
The value will be referenced as ``$name`` from path format
strings. The function must accept a single parameter, the Item
being formatted.
"""
def helper(func):
if cls.template_fields is None:
cls.template_fields = {}
cls.template_fields[name] = func
return func
return helper
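# Example (editor's sketch, not part of beets): registering a template
# function on a hypothetical plugin class. Real plugins subclass
# BeetsPlugin in a module under the `beetsplug` namespace package.
class _DemoPlugin(BeetsPlugin):
    """Hypothetical plugin used only for illustration."""

@_DemoPlugin.template_func('upper')
def _upper(text):
    return text.upper()

assert _DemoPlugin.template_funcs['upper'](u'hi') == u'HI'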
_classes = set()
def load_plugins(names=()):
"""Imports the modules for a sequence of plugin names. Each name
must be the name of a Python module under the "beetsplug" namespace
package in sys.path; the module indicated should contain the
BeetsPlugin subclasses desired.
"""
for name in names:
modname = '%s.%s' % (PLUGIN_NAMESPACE, name)
try:
try:
namespace = __import__(modname, None, None)
except ImportError as exc:
# Again, this is hacky:
if exc.args[0].endswith(' ' + name):
log.warn(u'** plugin {0} not found'.format(name))
else:
raise
else:
for obj in getattr(namespace, name).__dict__.values():
if isinstance(obj, type) and issubclass(obj, BeetsPlugin) \
and obj != BeetsPlugin and obj not in _classes:
_classes.add(obj)
except:
log.warn(u'** error loading plugin {0}'.format(name))
log.warn(traceback.format_exc())
_instances = {}
def find_plugins():
"""Returns a list of BeetsPlugin subclass instances from all
currently loaded beets plugins. Loads the default plugin set
first.
"""
load_plugins()
plugins = []
for cls in _classes:
# Only instantiate each plugin class once.
if cls not in _instances:
_instances[cls] = cls()
plugins.append(_instances[cls])
return plugins
# Communication with plugins.
def commands():
"""Returns a list of Subcommand objects from all loaded plugins.
"""
out = []
for plugin in find_plugins():
out += plugin.commands()
return out
def queries():
"""Returns a dict mapping prefix strings to Query subclasses all loaded
plugins.
"""
out = {}
for plugin in find_plugins():
out.update(plugin.queries())
return out
def types(model_cls):
# Gives us `item_types` and `album_types`
attr_name = '{0}_types'.format(model_cls.__name__.lower())
types = {}
for plugin in find_plugins():
plugin_types = getattr(plugin, attr_name, {})
for field in plugin_types:
if field in types and plugin_types[field] != types[field]:
raise PluginConflictException(
u'Plugin {0} defines flexible field {1} '
'which has already been defined with '
'another type.'.format(plugin.name, field)
)
types.update(plugin_types)
return types
def track_distance(item, info):
"""Gets the track distance calculated by all loaded plugins.
Returns a Distance object.
"""
from beets.autotag.hooks import Distance
dist = Distance()
for plugin in find_plugins():
dist.update(plugin.track_distance(item, info))
return dist
def album_distance(items, album_info, mapping):
"""Returns the album distance calculated by plugins."""
from beets.autotag.hooks import Distance
dist = Distance()
for plugin in find_plugins():
dist.update(plugin.album_distance(items, album_info, mapping))
return dist
def candidates(items, artist, album, va_likely):
"""Gets MusicBrainz candidates for an album from each plugin.
"""
out = []
for plugin in find_plugins():
out.extend(plugin.candidates(items, artist, album, va_likely))
return out
def item_candidates(item, artist, title):
"""Gets MusicBrainz candidates for an item from the plugins.
"""
out = []
for plugin in find_plugins():
out.extend(plugin.item_candidates(item, artist, title))
return out
def album_for_id(album_id):
"""Get AlbumInfo objects for a given ID string.
"""
out = []
for plugin in find_plugins():
res = plugin.album_for_id(album_id)
if res:
out.append(res)
return out
def track_for_id(track_id):
"""Get TrackInfo objects for a given ID string.
"""
out = []
for plugin in find_plugins():
res = plugin.track_for_id(track_id)
if res:
out.append(res)
return out
def template_funcs():
"""Get all the template functions declared by plugins as a
dictionary.
"""
funcs = {}
for plugin in find_plugins():
if plugin.template_funcs:
funcs.update(plugin.template_funcs)
return funcs
def import_stages():
"""Get a list of import stage functions defined by plugins."""
stages = []
for plugin in find_plugins():
if hasattr(plugin, 'import_stages'):
stages += plugin.import_stages
return stages
# New-style (lazy) plugin-provided fields.
def item_field_getters():
"""Get a dictionary mapping field names to unary functions that
compute the field's value.
"""
funcs = {}
for plugin in find_plugins():
if plugin.template_fields:
funcs.update(plugin.template_fields)
return funcs
def album_field_getters():
"""As above, for album fields.
"""
funcs = {}
for plugin in find_plugins():
if plugin.album_template_fields:
funcs.update(plugin.album_template_fields)
return funcs
# Event dispatch.
def event_handlers():
"""Find all event handlers from plugins as a dictionary mapping
event names to sequences of callables.
"""
all_handlers = defaultdict(list)
for plugin in find_plugins():
if plugin.listeners:
for event, handlers in plugin.listeners.items():
all_handlers[event] += handlers
return all_handlers
def send(event, **arguments):
"""Sends an event to all assigned event listeners. Event is the
name of the event to send, all other named arguments go to the
event handler(s).
Returns a list of return values from the handlers.
"""
log.debug(u'Sending event: {0}'.format(event))
for handler in event_handlers()[event]:
# Don't break legacy plugins if we want to pass more arguments
argspec = inspect.getargspec(handler).args
args = dict((k, v) for k, v in arguments.items() if k in argspec)
handler(**args)
def feat_tokens(for_artist=True):
"""Return a regular expression that matches phrases like "featuring"
that separate a main artist or a song title from secondary artists.
The `for_artist` option determines whether the regex should be
suitable for matching artist fields (the default) or title fields.
"""
feat_words = ['ft', 'featuring', 'feat', 'feat.', 'ft.']
if for_artist:
feat_words += ['with', 'vs', 'and', 'con', '&']
return '(?<=\s)(?:{0})(?=\s)'.format(
'|'.join(re.escape(x) for x in feat_words)
)
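# Example (editor's sketch, not part of beets): the pattern only fires
# on whitespace-delimited tokens, so "ft" inside a word is ignored.
assert re.search(feat_tokens(), 'A ft. B')
assert not re.search(feat_tokens(), 'Crafty')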
def sanitize_choices(choices, choices_all):
"""Clean up a stringlist configuration attribute: keep only choices
elements present in choices_all, remove duplicate elements, expand '*'
wildcard while keeping original stringlist order.
"""
seen = set()
others = [x for x in choices_all if x not in choices]
res = []
for s in choices:
if s in list(choices_all) + ['*']:
if not (s in seen or seen.add(s)):
res.extend(list(others) if s == '*' else [s])
return res
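# Example (editor's sketch, not part of beets): '*' expands to every
# remaining choice, duplicates and unknown entries are dropped, and
# the original ordering is preserved.
assert sanitize_choices(['b', '*'], ['a', 'b', 'c']) == ['b', 'a', 'c']
assert sanitize_choices(['x', 'a', 'a'], ['a', 'b']) == ['a']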

970
lib/beets/ui/__init__.py Normal file
View File

@@ -0,0 +1,970 @@
# This file is part of beets.
# Copyright 2014, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""This module contains all of the core logic for beets' command-line
interface. To invoke the CLI, just call beets.ui.main(). The actual
CLI commands are implemented in the ui.commands module.
"""
from __future__ import print_function
import locale
import optparse
import textwrap
import sys
from difflib import SequenceMatcher
import logging
import sqlite3
import errno
import re
import struct
import traceback
import os.path
from beets import library
from beets import plugins
from beets import util
from beets.util.functemplate import Template
from beets import config
from beets.util import confit
from beets.autotag import mb
# On Windows platforms, use colorama to support "ANSI" terminal colors.
if sys.platform == 'win32':
try:
import colorama
except ImportError:
pass
else:
colorama.init()
log = logging.getLogger('beets')
if not log.handlers:
log.addHandler(logging.StreamHandler())
log.propagate = False # Don't propagate to root handler.
PF_KEY_QUERIES = {
'comp': 'comp:true',
'singleton': 'singleton:true',
}
class UserError(Exception):
"""UI exception. Commands should throw this in order to display
nonrecoverable errors to the user.
"""
# Utilities.
def _encoding():
"""Tries to guess the encoding used by the terminal."""
# Configured override?
encoding = config['terminal_encoding'].get()
if encoding:
return encoding
# Determine from locale settings.
try:
return locale.getdefaultlocale()[1] or 'utf8'
except ValueError:
# Invalid locale environment variable setting. To avoid
# failing entirely for no good reason, assume UTF-8.
return 'utf8'
def decargs(arglist):
"""Given a list of command-line argument bytestrings, attempts to
decode them to Unicode strings.
"""
return [s.decode(_encoding()) for s in arglist]
def print_(*strings):
"""Like print, but rather than raising an error when a character
is not in the terminal's encoding's character set, just silently
replaces it.
"""
if strings:
if isinstance(strings[0], unicode):
txt = u' '.join(strings)
else:
txt = ' '.join(strings)
else:
txt = u''
if isinstance(txt, unicode):
txt = txt.encode(_encoding(), 'replace')
print(txt)
def input_(prompt=None):
"""Like `raw_input`, but decodes the result to a Unicode string.
Raises a UserError if stdin is not available. The prompt is sent to
    stdout rather than stderr. A space is printed between the prompt
    and the input cursor.
"""
# raw_input incorrectly sends prompts to stderr, not stdout, so we
# use print() explicitly to display prompts.
# http://bugs.python.org/issue1927
if prompt:
if isinstance(prompt, unicode):
prompt = prompt.encode(_encoding(), 'replace')
print(prompt, end=' ')
try:
resp = raw_input()
except EOFError:
raise UserError('stdin stream ended while input required')
return resp.decode(sys.stdin.encoding or 'utf8', 'ignore')
def input_options(options, require=False, prompt=None, fallback_prompt=None,
numrange=None, default=None, max_width=72):
"""Prompts a user for input. The sequence of `options` defines the
choices the user has. A single-letter shortcut is inferred for each
option; the user's choice is returned as that single, lower-case
letter. The options should be provided as lower-case strings unless
a particular shortcut is desired; in that case, only that letter
should be capitalized.
    By default, the first option is the default. `default` can be provided to
    override this. If `require` is true, then there is no default. The
    prompt and fallback prompt are also inferred but can be overridden.
    If numrange is provided, it is a pair of `(low, high)` (both ints)
    indicating that, in addition to `options`, the user may enter an
    integer in that inclusive range.
`max_width` specifies the maximum number of columns in the
automatically generated prompt string.
"""
# Assign single letters to each option. Also capitalize the options
# to indicate the letter.
letters = {}
display_letters = []
capitalized = []
first = True
for option in options:
# Is a letter already capitalized?
for letter in option:
if letter.isalpha() and letter.upper() == letter:
found_letter = letter
break
else:
# Infer a letter.
for letter in option:
if not letter.isalpha():
continue # Don't use punctuation.
if letter not in letters:
found_letter = letter
break
else:
raise ValueError('no unambiguous lettering found')
letters[found_letter.lower()] = option
index = option.index(found_letter)
# Mark the option's shortcut letter for display.
if not require and (
(default is None and not numrange and first) or
(isinstance(default, basestring) and
found_letter.lower() == default.lower())):
# The first option is the default; mark it.
show_letter = '[%s]' % found_letter.upper()
is_default = True
else:
show_letter = found_letter.upper()
is_default = False
# Colorize the letter shortcut.
show_letter = colorize('turquoise' if is_default else 'blue',
show_letter)
# Insert the highlighted letter back into the word.
capitalized.append(
option[:index] + show_letter + option[index + 1:]
)
display_letters.append(found_letter.upper())
first = False
# The default is just the first option if unspecified.
if require:
default = None
elif default is None:
if numrange:
default = numrange[0]
else:
default = display_letters[0].lower()
# Make a prompt if one is not provided.
if not prompt:
prompt_parts = []
prompt_part_lengths = []
if numrange:
if isinstance(default, int):
default_name = str(default)
default_name = colorize('turquoise', default_name)
tmpl = '# selection (default %s)'
prompt_parts.append(tmpl % default_name)
prompt_part_lengths.append(len(tmpl % str(default)))
else:
prompt_parts.append('# selection')
prompt_part_lengths.append(len(prompt_parts[-1]))
prompt_parts += capitalized
prompt_part_lengths += [len(s) for s in options]
# Wrap the query text.
prompt = ''
line_length = 0
for i, (part, length) in enumerate(zip(prompt_parts,
prompt_part_lengths)):
# Add punctuation.
if i == len(prompt_parts) - 1:
part += '?'
else:
part += ','
length += 1
# Choose either the current line or the beginning of the next.
if line_length + length + 1 > max_width:
prompt += '\n'
line_length = 0
if line_length != 0:
# Not the beginning of the line; need a space.
part = ' ' + part
length += 1
prompt += part
line_length += length
# Make a fallback prompt too. This is displayed if the user enters
# something that is not recognized.
if not fallback_prompt:
fallback_prompt = 'Enter one of '
if numrange:
fallback_prompt += '%i-%i, ' % numrange
fallback_prompt += ', '.join(display_letters) + ':'
resp = input_(prompt)
while True:
resp = resp.strip().lower()
# Try default option.
if default is not None and not resp:
resp = default
# Try an integer input if available.
if numrange:
try:
resp = int(resp)
except ValueError:
pass
else:
low, high = numrange
if low <= resp <= high:
return resp
else:
resp = None
# Try a normal letter input.
if resp:
resp = resp[0]
if resp in letters:
return resp
# Prompt for new input.
resp = input_(fallback_prompt)
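# Example (editor's illustration; the call is interactive, so only the
# generated prompt is sketched here): for options ('apply', 'more',
# 'skip') the shortcuts a/m/s are inferred and the prompt looks roughly
# like "[A]pply, More, Skip?", returning 'a', 'm', or 's'.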
def input_yn(prompt, require=False):
"""Prompts the user for a "yes" or "no" response. The default is
"yes" unless `require` is `True`, in which case there is no default.
"""
sel = input_options(
('y', 'n'), require, prompt, 'Enter Y or N:'
)
return sel == 'y'
def human_bytes(size):
"""Formats size, a number of bytes, in a human-readable way."""
suffices = ['B', 'KB', 'MB', 'GB', 'TB', 'PB', 'EB', 'ZB', 'YB', 'HB']
for suffix in suffices:
if size < 1024:
return "%3.1f %s" % (size, suffix)
size /= 1024.0
return "big"
def human_seconds(interval):
"""Formats interval, a number of seconds, as a human-readable time
interval using English words.
"""
units = [
(1, 'second'),
(60, 'minute'),
(60, 'hour'),
(24, 'day'),
(7, 'week'),
(52, 'year'),
(10, 'decade'),
]
for i in range(len(units) - 1):
increment, suffix = units[i]
next_increment, _ = units[i + 1]
interval /= float(increment)
if interval < next_increment:
break
else:
# Last unit.
increment, suffix = units[-1]
interval /= float(increment)
return "%3.1f %ss" % (interval, suffix)
def human_seconds_short(interval):
"""Formats a number of seconds as a short human-readable M:SS
string.
"""
interval = int(interval)
return u'%i:%02i' % (interval // 60, interval % 60)
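# Examples (editor's illustration):
#
#     >>> human_seconds(3600)
#     '1.0 hours'
#     >>> human_seconds_short(191)
#     u'3:11'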
# ANSI terminal colorization code heavily inspired by pygments:
# http://dev.pocoo.org/hg/pygments-main/file/b2deea5b5030/pygments/console.py
# (pygments is by Tim Hatch, Armin Ronacher, et al.)
COLOR_ESCAPE = "\x1b["
DARK_COLORS = ["black", "darkred", "darkgreen", "brown", "darkblue",
"purple", "teal", "lightgray"]
LIGHT_COLORS = ["darkgray", "red", "green", "yellow", "blue",
"fuchsia", "turquoise", "white"]
RESET_COLOR = COLOR_ESCAPE + "39;49;00m"
def _colorize(color, text):
"""Returns a string that prints the given text in the given color
in a terminal that is ANSI color-aware. The color must be something
in DARK_COLORS or LIGHT_COLORS.
"""
if color in DARK_COLORS:
escape = COLOR_ESCAPE + "%im" % (DARK_COLORS.index(color) + 30)
elif color in LIGHT_COLORS:
escape = COLOR_ESCAPE + "%i;01m" % (LIGHT_COLORS.index(color) + 30)
else:
        raise ValueError('no such color %s' % color)
return escape + text + RESET_COLOR
def colorize(color, text):
"""Colorize text if colored output is enabled. (Like _colorize but
conditional.)
"""
if config['color']:
return _colorize(color, text)
else:
return text
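# Example (editor's illustration): 'red' is a light color at index 1,
# so the escape selects bold color 31 and the reset code is appended:
#
#     >>> _colorize('red', 'x')
#     '\x1b[31;01mx\x1b[39;49;00m'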
def _colordiff(a, b, highlight='red', minor_highlight='lightgray'):
"""Given two values, return the same pair of strings except with
their differences highlighted in the specified color. Strings are
highlighted intelligently to show differences; other values are
stringified and highlighted in their entirety.
"""
if not isinstance(a, basestring) or not isinstance(b, basestring):
# Non-strings: use ordinary equality.
a = unicode(a)
b = unicode(b)
if a == b:
return a, b
else:
return colorize(highlight, a), colorize(highlight, b)
if isinstance(a, bytes) or isinstance(b, bytes):
# A path field.
a = util.displayable_path(a)
b = util.displayable_path(b)
a_out = []
b_out = []
matcher = SequenceMatcher(lambda x: False, a, b)
for op, a_start, a_end, b_start, b_end in matcher.get_opcodes():
if op == 'equal':
# In both strings.
a_out.append(a[a_start:a_end])
b_out.append(b[b_start:b_end])
elif op == 'insert':
# Right only.
b_out.append(colorize(highlight, b[b_start:b_end]))
elif op == 'delete':
# Left only.
a_out.append(colorize(highlight, a[a_start:a_end]))
elif op == 'replace':
# Right and left differ. Colorise with second highlight if
# it's just a case change.
if a[a_start:a_end].lower() != b[b_start:b_end].lower():
color = highlight
else:
color = minor_highlight
a_out.append(colorize(color, a[a_start:a_end]))
b_out.append(colorize(color, b[b_start:b_end]))
else:
assert(False)
return u''.join(a_out), u''.join(b_out)
def colordiff(a, b, highlight='red'):
"""Colorize differences between two values if color is enabled.
(Like _colordiff but conditional.)
"""
if config['color']:
return _colordiff(a, b, highlight)
else:
return unicode(a), unicode(b)
def get_path_formats(subview=None):
"""Get the configuration's path formats as a list of query/template
pairs.
"""
path_formats = []
subview = subview or config['paths']
for query, view in subview.items():
query = PF_KEY_QUERIES.get(query, query) # Expand common queries.
path_formats.append((query, Template(view.get(unicode))))
return path_formats
def get_replacements():
"""Confit validation function that reads regex/string pairs.
"""
replacements = []
for pattern, repl in config['replace'].get(dict).items():
repl = repl or ''
try:
replacements.append((re.compile(pattern), repl))
except re.error:
raise UserError(
u'malformed regular expression in replace: {0}'.format(
pattern
)
)
return replacements
def _pick_format(album, fmt=None):
"""Pick a format string for printing Album or Item objects,
falling back to config options and defaults.
"""
if fmt:
return fmt
if album:
return config['list_format_album'].get(unicode)
else:
return config['list_format_item'].get(unicode)
def print_obj(obj, lib, fmt=None):
"""Print an Album or Item object. If `fmt` is specified, use that
format string. Otherwise, use the configured template.
"""
album = isinstance(obj, library.Album)
fmt = _pick_format(album, fmt)
if isinstance(fmt, Template):
template = fmt
else:
template = Template(fmt)
print_(obj.evaluate_template(template))
def term_width():
"""Get the width (columns) of the terminal."""
fallback = config['ui']['terminal_width'].get(int)
# The fcntl and termios modules are not available on non-Unix
# platforms, so we fall back to a constant.
try:
import fcntl
import termios
except ImportError:
return fallback
try:
buf = fcntl.ioctl(0, termios.TIOCGWINSZ, ' ' * 4)
except IOError:
return fallback
try:
height, width = struct.unpack('hh', buf)
except struct.error:
return fallback
return width
FLOAT_EPSILON = 0.01
def _field_diff(field, old, new):
"""Given two Model objects, format their values for `field` and
highlight changes among them. Return a human-readable string. If the
value has not changed, return None instead.
"""
oldval = old.get(field)
newval = new.get(field)
# If no change, abort.
if isinstance(oldval, float) and isinstance(newval, float) and \
abs(oldval - newval) < FLOAT_EPSILON:
return None
elif oldval == newval:
return None
# Get formatted values for output.
oldstr = old.formatted().get(field, u'')
newstr = new.formatted().get(field, u'')
# For strings, highlight changes. For others, colorize the whole
# thing.
if isinstance(oldval, basestring):
        oldstr, newstr = colordiff(oldstr, newstr)
else:
oldstr, newstr = colorize('red', oldstr), colorize('red', newstr)
return u'{0} -> {1}'.format(oldstr, newstr)
def show_model_changes(new, old=None, fields=None, always=False):
"""Given a Model object, print a list of changes from its pristine
version stored in the database. Return a boolean indicating whether
any changes were found.
`old` may be the "original" object to avoid using the pristine
version from the database. `fields` may be a list of fields to
restrict the detection to. `always` indicates whether the object is
always identified, regardless of whether any changes are present.
"""
old = old or new._db._get(type(new), new.id)
# Build up lines showing changed fields.
changes = []
for field in old:
# Subset of the fields. Never show mtime.
if field == 'mtime' or (fields and field not in fields):
continue
# Detect and show difference for this field.
line = _field_diff(field, old, new)
if line:
changes.append(u' {0}: {1}'.format(field, line))
# New fields.
for field in set(new) - set(old):
if fields and field not in fields:
continue
changes.append(u' {0}: {1}'.format(
field,
colorize('red', new.formatted()[field])
))
# Print changes.
if changes or always:
print_obj(old, old._db)
if changes:
print_(u'\n'.join(changes))
return bool(changes)
# Subcommand parsing infrastructure.
#
# This is a fairly generic subcommand parser for optparse. It is
# maintained externally here:
# http://gist.github.com/462717
# There you will also find a better description of the code and a more
# succinct example program.
class Subcommand(object):
"""A subcommand of a root command-line application that may be
invoked by a SubcommandOptionParser.
"""
def __init__(self, name, parser=None, help='', aliases=(), hide=False):
"""Creates a new subcommand. name is the primary way to invoke
the subcommand; aliases are alternate names. parser is an
OptionParser responsible for parsing the subcommand's options.
help is a short description of the command. If no parser is
given, it defaults to a new, empty OptionParser.
"""
self.name = name
self.parser = parser or optparse.OptionParser()
self.aliases = aliases
self.help = help
self.hide = hide
self._root_parser = None
def print_help(self):
self.parser.print_help()
def parse_args(self, args):
return self.parser.parse_args(args)
@property
def root_parser(self):
return self._root_parser
@root_parser.setter
def root_parser(self, root_parser):
self._root_parser = root_parser
self.parser.prog = '{0} {1}'.format(root_parser.get_prog_name(),
self.name)
class SubcommandsOptionParser(optparse.OptionParser):
"""A variant of OptionParser that parses subcommands and their
arguments.
"""
def __init__(self, *args, **kwargs):
"""Create a new subcommand-aware option parser. All of the
        options to OptionParser.__init__ are supported. Subcommands
        are attached afterward with `add_subcommand`.
"""
# A more helpful default usage.
if 'usage' not in kwargs:
kwargs['usage'] = """
%prog COMMAND [ARGS...]
%prog help COMMAND"""
kwargs['add_help_option'] = False
# Super constructor.
optparse.OptionParser.__init__(self, *args, **kwargs)
# Our root parser needs to stop on the first unrecognized argument.
self.disable_interspersed_args()
self.subcommands = []
def add_subcommand(self, *cmds):
"""Adds a Subcommand object to the parser's list of commands.
"""
for cmd in cmds:
cmd.root_parser = self
self.subcommands.append(cmd)
# Add the list of subcommands to the help message.
def format_help(self, formatter=None):
# Get the original help message, to which we will append.
out = optparse.OptionParser.format_help(self, formatter)
if formatter is None:
formatter = self.formatter
# Subcommands header.
result = ["\n"]
result.append(formatter.format_heading('Commands'))
formatter.indent()
# Generate the display names (including aliases).
# Also determine the help position.
disp_names = []
help_position = 0
subcommands = [c for c in self.subcommands if not c.hide]
subcommands.sort(key=lambda c: c.name)
for subcommand in subcommands:
name = subcommand.name
if subcommand.aliases:
name += ' (%s)' % ', '.join(subcommand.aliases)
disp_names.append(name)
# Set the help position based on the max width.
proposed_help_position = len(name) + formatter.current_indent + 2
if proposed_help_position <= formatter.max_help_position:
help_position = max(help_position, proposed_help_position)
# Add each subcommand to the output.
for subcommand, name in zip(subcommands, disp_names):
# Lifted directly from optparse.py.
name_width = help_position - formatter.current_indent - 2
if len(name) > name_width:
name = "%*s%s\n" % (formatter.current_indent, "", name)
indent_first = help_position
else:
name = "%*s%-*s " % (formatter.current_indent, "",
name_width, name)
indent_first = 0
result.append(name)
help_width = formatter.width - help_position
            help_lines = textwrap.wrap(subcommand.help, help_width) or ['']
result.append("%*s%s\n" % (indent_first, "", help_lines[0]))
result.extend(["%*s%s\n" % (help_position, "", line)
for line in help_lines[1:]])
formatter.dedent()
# Concatenate the original help message with the subcommand
# list.
return out + "".join(result)
def _subcommand_for_name(self, name):
"""Return the subcommand in self.subcommands matching the
given name. The name may either be the name of a subcommand or
an alias. If no subcommand matches, returns None.
"""
for subcommand in self.subcommands:
if name == subcommand.name or \
name in subcommand.aliases:
return subcommand
return None
def parse_global_options(self, args):
"""Parse options up to the subcommand argument. Returns a tuple
of the options object and the remaining arguments.
"""
options, subargs = self.parse_args(args)
# Force the help command
if options.help:
subargs = ['help']
elif options.version:
subargs = ['version']
return options, subargs
def parse_subcommand(self, args):
"""Given the `args` left unused by a `parse_global_options`,
return the invoked subcommand, the subcommand options, and the
subcommand arguments.
"""
# Help is default command
if not args:
args = ['help']
cmdname = args.pop(0)
subcommand = self._subcommand_for_name(cmdname)
if not subcommand:
raise UserError("unknown command '{0}'".format(cmdname))
suboptions, subargs = subcommand.parse_args(args)
return subcommand, suboptions, subargs
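# Illustrative usage (editor's sketch; the 'stats' command here is a
# hypothetical example, not one of beets' real subcommands):
#
#     parser = SubcommandsOptionParser()
#     parser.add_subcommand(Subcommand('stats', help='show statistics'))
#     cmd, opts, args = parser.parse_subcommand(['stats', 'beatles'])
#     # cmd.name == 'stats'; args == ['beatles']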
optparse.Option.ALWAYS_TYPED_ACTIONS += ('callback',)
def vararg_callback(option, opt_str, value, parser):
"""Callback for an option with variable arguments.
Manually collect arguments right of a callback-action
option (ie. with action="callback"), and add the resulting
list to the destination var.
Usage:
parser.add_option("-c", "--callback", dest="vararg_attr",
action="callback", callback=vararg_callback)
Details:
http://docs.python.org/2/library/optparse.html#callback-example-6-variable
-arguments
"""
value = [value]
def floatable(str):
try:
float(str)
return True
except ValueError:
return False
for arg in parser.rargs:
# stop on --foo like options
if arg[:2] == "--" and len(arg) > 2:
break
# stop on -a, but not on -3 or -3.0
if arg[:1] == "-" and len(arg) > 1 and not floatable(arg):
break
value.append(arg)
del parser.rargs[:len(value) - 1]
setattr(parser.values, option.dest, value)
# The main entry point and bootstrapping.
def _load_plugins(config):
"""Load the plugins specified in the configuration.
"""
paths = config['pluginpath'].get(confit.StrSeq(split=False))
paths = map(util.normpath, paths)
import beetsplug
beetsplug.__path__ = paths + beetsplug.__path__
# For backwards compatibility.
sys.path += paths
plugins.load_plugins(config['plugins'].as_str_seq())
plugins.send("pluginload")
return plugins
def _setup(options, lib=None):
"""Prepare and global state and updates it with command line options.
Returns a list of subcommands, a list of plugins, and a library instance.
"""
# Configure the MusicBrainz API.
mb.configure()
config = _configure(options)
plugins = _load_plugins(config)
# Get the default subcommands.
from beets.ui.commands import default_commands
subcommands = list(default_commands)
subcommands.extend(plugins.commands())
if lib is None:
lib = _open_library(config)
plugins.send("library_opened", lib=lib)
library.Item._types = plugins.types(library.Item)
library.Album._types = plugins.types(library.Album)
return subcommands, plugins, lib
def _configure(options):
"""Amend the global configuration object with command line options.
"""
# Add any additional config files specified with --config. This
# special handling lets specified plugins get loaded before we
# finish parsing the command line.
if getattr(options, 'config', None) is not None:
config_path = options.config
del options.config
config.set_file(config_path)
config.set_args(options)
# Configure the logger.
if config['verbose'].get(bool):
log.setLevel(logging.DEBUG)
else:
log.setLevel(logging.INFO)
config_path = config.user_config_path()
if os.path.isfile(config_path):
log.debug(u'user configuration: {0}'.format(
util.displayable_path(config_path)))
else:
log.debug(u'no user configuration found at {0}'.format(
util.displayable_path(config_path)))
log.debug(u'data directory: {0}'
.format(util.displayable_path(config.config_dir())))
return config
def _open_library(config):
"""Create a new library instance from the configuration.
"""
dbpath = config['library'].as_filename()
try:
lib = library.Library(
dbpath,
config['directory'].as_filename(),
get_path_formats(),
get_replacements(),
)
lib.get_item(0) # Test database connection.
except (sqlite3.OperationalError, sqlite3.DatabaseError):
log.debug(traceback.format_exc())
raise UserError(u"database file {0} could not be opened".format(
util.displayable_path(dbpath)
))
log.debug(u'library database: {0}\n'
u'library directory: {1}'
.format(util.displayable_path(lib.path),
util.displayable_path(lib.directory)))
return lib
def _raw_main(args, lib=None):
"""A helper function for `main` without top-level exception
handling.
"""
parser = SubcommandsOptionParser()
parser.add_option('-l', '--library', dest='library',
help='library database file to use')
parser.add_option('-d', '--directory', dest='directory',
help="destination music directory")
parser.add_option('-v', '--verbose', dest='verbose', action='store_true',
help='print debugging information')
parser.add_option('-c', '--config', dest='config',
help='path to configuration file')
parser.add_option('-h', '--help', dest='help', action='store_true',
                      help='show this help message and exit')
parser.add_option('--version', dest='version', action='store_true',
help=optparse.SUPPRESS_HELP)
options, subargs = parser.parse_global_options(args)
# Special case for the `config --edit` command: bypass _setup so
# that an invalid configuration does not prevent the editor from
# starting.
    if subargs and subargs[0] == 'config' and (
            '-e' in subargs or '--edit' in subargs):
from beets.ui.commands import config_edit
return config_edit()
subcommands, plugins, lib = _setup(options, lib)
parser.add_subcommand(*subcommands)
subcommand, suboptions, subargs = parser.parse_subcommand(subargs)
subcommand.func(lib, suboptions, subargs)
plugins.send('cli_exit', lib=lib)
def main(args=None):
"""Run the main command-line interface for beets. Includes top-level
exception handlers that print friendly error messages.
"""
try:
_raw_main(args)
except UserError as exc:
message = exc.args[0] if exc.args else None
log.error(u'error: {0}'.format(message))
sys.exit(1)
except util.HumanReadableException as exc:
exc.log(log)
sys.exit(1)
except library.FileOperationError as exc:
# These errors have reasonable human-readable descriptions, but
# we still want to log their tracebacks for debugging.
log.debug(traceback.format_exc())
log.error(exc)
sys.exit(1)
except confit.ConfigError as exc:
log.error(u'configuration error: {0}'.format(exc))
sys.exit(1)
except IOError as exc:
if exc.errno == errno.EPIPE:
# "Broken pipe". End silently.
pass
else:
raise
except KeyboardInterrupt:
# Silently ignore ^C except in verbose mode.
log.debug(traceback.format_exc())

1646
lib/beets/ui/commands.py Normal file

File diff suppressed because it is too large

View File

@@ -0,0 +1,162 @@
# This file is part of beets.
# Copyright (c) 2014, Thomas Scholtes.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
# Completion for the `beet` command
# =================================
#
# Load this script to complete beets subcommands, options, and
# queries.
#
# If a beets command is found on the command line it completes filenames and
# the subcommand's options. Otherwise it will complete global options and
# subcommands. If the previous option on the command line expects an argument,
# it also completes filenames or directories. Options are only
# completed if '-' has already been typed on the command line.
#
# Note that completion of plugin commands only works for those plugins
# that were enabled when running `beet completion`. It does not check
# plugins dynamically
#
# Currently, only Bash 3.2 and newer is supported and the
# `bash-completion` package is requied.
#
# TODO
# ----
#
# * There are some issues with arguments that are quoted on the command line.
#
# * Complete arguments for the `--format` option by expanding field variables.
#
# beet ls -f "$tit[TAB]
# beet ls -f "$title
#
# * Support long options with `=`, e.g. `--config=file`. Debian's bash
# completion package can handle this.
#
# Determines the beets subcommand and dispatches the completion
# accordingly.
_beet_dispatch() {
local cur prev cmd=
COMPREPLY=()
_get_comp_words_by_ref -n : cur prev
# Look for the beets subcommand
local arg
for (( i=1; i < COMP_CWORD; i++ )); do
arg="${COMP_WORDS[i]}"
if _list_include_item "${opts___global}" $arg; then
((i++))
elif [[ "$arg" != -* ]]; then
cmd="$arg"
break
fi
done
# Replace command shortcuts
if [[ -n $cmd ]] && _list_include_item "$aliases" "$cmd"; then
eval "cmd=\$alias__$cmd"
fi
case $cmd in
help)
COMPREPLY+=( $(compgen -W "$commands" -- $cur) )
;;
list|remove|move|update|write|stats)
_beet_complete_query
;;
"")
_beet_complete_global
;;
*)
_beet_complete
;;
esac
}
# Adds option and file completion to COMPREPLY for the subcommand $cmd
_beet_complete() {
if [[ $cur == -* ]]; then
local opts flags completions
eval "opts=\$opts__$cmd"
eval "flags=\$flags__$cmd"
completions="${flags___common} ${opts} ${flags}"
COMPREPLY+=( $(compgen -W "$completions" -- $cur) )
else
_filedir
fi
}
# Add global options and subcommands to the completion
_beet_complete_global() {
case $prev in
-h|--help)
# Complete commands
COMPREPLY+=( $(compgen -W "$commands" -- $cur) )
return
;;
-l|--library|-c|--config)
# Filename completion
_filedir
return
;;
-d|--directory)
# Directory completion
_filedir -d
return
;;
esac
if [[ $cur == -* ]]; then
local completions="$opts___global $flags___global"
COMPREPLY+=( $(compgen -W "$completions" -- $cur) )
elif [[ -n $cur ]] && _list_include_item "$aliases" "$cur"; then
local cmd
eval "cmd=\$alias__$cur"
COMPREPLY+=( "$cmd" )
else
COMPREPLY+=( $(compgen -W "$commands" -- $cur) )
fi
}
_beet_complete_query() {
local opts
eval "opts=\$opts__$cmd"
if [[ $cur == -* ]] || _list_include_item "$opts" "$prev"; then
_beet_complete
elif [[ $cur != \'* && $cur != \"* &&
$cur != *:* ]]; then
        # Do not complete quoted queries or those that already have a
        # field set.
compopt -o nospace
COMPREPLY+=( $(compgen -S : -W "$fields" -- $cur) )
return 0
fi
}
# Returns true if the space separated list $1 includes $2
_list_include_item() {
[[ " $1 " == *[[:space:]]$2[[:space:]]* ]]
}
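# Example (editor's illustration): the exit status is 0 (true) only for
# whole-word membership:
#
#   _list_include_item "list remove move" "remove"   # true
#   _list_include_item "list remove move" "rem"      # false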
# This is where beets dynamically adds the _beet function. This
# function sets the variables $flags, $opts, $commands, and $aliases.
complete -o filenames -F _beet beet

686
lib/beets/util/__init__.py Normal file
View File

@@ -0,0 +1,686 @@
# This file is part of beets.
# Copyright 2013, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""Miscellaneous utility functions."""
from __future__ import division
import os
import sys
import re
import shutil
import fnmatch
from collections import defaultdict
import traceback
import subprocess
import platform
MAX_FILENAME_LENGTH = 200
WINDOWS_MAGIC_PREFIX = u'\\\\?\\'
class HumanReadableException(Exception):
"""An Exception that can include a human-readable error message to
be logged without a traceback. Can preserve a traceback for
debugging purposes as well.
Has at least two fields: `reason`, the underlying exception or a
string describing the problem; and `verb`, the action being
performed during the error.
If `tb` is provided, it is a string containing a traceback for the
associated exception. (Note that this is not necessary in Python 3.x
and should be removed when we make the transition.)
"""
error_kind = 'Error' # Human-readable description of error type.
def __init__(self, reason, verb, tb=None):
self.reason = reason
self.verb = verb
self.tb = tb
super(HumanReadableException, self).__init__(self.get_message())
def _gerund(self):
"""Generate a (likely) gerund form of the English verb.
"""
if ' ' in self.verb:
return self.verb
gerund = self.verb[:-1] if self.verb.endswith('e') else self.verb
gerund += 'ing'
return gerund
def _reasonstr(self):
"""Get the reason as a string."""
if isinstance(self.reason, unicode):
return self.reason
elif isinstance(self.reason, basestring): # Byte string.
return self.reason.decode('utf8', 'ignore')
elif hasattr(self.reason, 'strerror'): # i.e., EnvironmentError
return self.reason.strerror
else:
return u'"{0}"'.format(unicode(self.reason))
def get_message(self):
"""Create the human-readable description of the error, sans
introduction.
"""
raise NotImplementedError
def log(self, logger):
"""Log to the provided `logger` a human-readable message as an
error and a verbose traceback as a debug message.
"""
if self.tb:
logger.debug(self.tb)
logger.error(u'{0}: {1}'.format(self.error_kind, self.args[0]))
class FilesystemError(HumanReadableException):
"""An error that occurred while performing a filesystem manipulation
via a function in this module. The `paths` field is a sequence of
pathnames involved in the operation.
"""
def __init__(self, reason, verb, paths, tb=None):
self.paths = paths
super(FilesystemError, self).__init__(reason, verb, tb)
def get_message(self):
# Use a nicer English phrasing for some specific verbs.
if self.verb in ('move', 'copy', 'rename'):
clause = u'while {0} {1} to {2}'.format(
self._gerund(),
displayable_path(self.paths[0]),
displayable_path(self.paths[1])
)
elif self.verb in ('delete', 'write', 'create', 'read'):
clause = u'while {0} {1}'.format(
self._gerund(),
displayable_path(self.paths[0])
)
else:
clause = u'during {0} of paths {1}'.format(
self.verb, u', '.join(displayable_path(p) for p in self.paths)
)
return u'{0} {1}'.format(self._reasonstr(), clause)
def normpath(path):
"""Provide the canonical form of the path suitable for storing in
the database.
"""
path = syspath(path, prefix=False)
path = os.path.normpath(os.path.abspath(os.path.expanduser(path)))
return bytestring_path(path)
def ancestry(path):
"""Return a list consisting of path's parent directory, its
grandparent, and so on. For instance:
>>> ancestry('/a/b/c')
['/', '/a', '/a/b']
The argument should *not* be the result of a call to `syspath`.
"""
out = []
last_path = None
while path:
path = os.path.dirname(path)
if path == last_path:
break
last_path = path
if path:
# don't yield ''
out.insert(0, path)
return out
def sorted_walk(path, ignore=(), logger=None):
"""Like `os.walk`, but yields things in case-insensitive sorted,
breadth-first order. Directory and file names matching any glob
pattern in `ignore` are skipped. If `logger` is provided, then
warning messages are logged there when a directory cannot be listed.
"""
# Make sure the path isn't a Unicode string.
path = bytestring_path(path)
# Get all the directories and files at this level.
try:
contents = os.listdir(syspath(path))
except OSError as exc:
if logger:
logger.warn(u'could not list directory {0}: {1}'.format(
displayable_path(path), exc.strerror
))
return
dirs = []
files = []
for base in contents:
base = bytestring_path(base)
# Skip ignored filenames.
skip = False
for pat in ignore:
if fnmatch.fnmatch(base, pat):
skip = True
break
if skip:
continue
# Add to output as either a file or a directory.
cur = os.path.join(path, base)
if os.path.isdir(syspath(cur)):
dirs.append(base)
else:
files.append(base)
# Sort lists (case-insensitive) and yield the current level.
dirs.sort(key=bytes.lower)
files.sort(key=bytes.lower)
yield (path, dirs, files)
# Recurse into directories.
for base in dirs:
cur = os.path.join(path, base)
# yield from sorted_walk(...)
for res in sorted_walk(cur, ignore, logger):
yield res
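# Usage sketch (editor's example; the path and ignore patterns are
# hypothetical):
#
#     for dirpath, dirs, files in sorted_walk('/music', ignore=('.*',)):
#         print dirpath  # each directory, breadth-first and sorted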
def mkdirall(path):
"""Make all the enclosing directories of path (like mkdir -p on the
parent).
"""
for ancestor in ancestry(path):
if not os.path.isdir(syspath(ancestor)):
try:
os.mkdir(syspath(ancestor))
except (OSError, IOError) as exc:
raise FilesystemError(exc, 'create', (ancestor,),
traceback.format_exc())
def fnmatch_all(names, patterns):
"""Determine whether all strings in `names` match at least one of
the `patterns`, which should be shell glob expressions.
"""
for name in names:
matches = False
for pattern in patterns:
matches = fnmatch.fnmatch(name, pattern)
if matches:
break
if not matches:
return False
return True
def prune_dirs(path, root=None, clutter=('.DS_Store', 'Thumbs.db')):
"""If path is an empty directory, then remove it. Recursively remove
path's ancestry up to root (which is never removed) where there are
empty directories. If path is not contained in root, then nothing is
removed. Glob patterns in clutter are ignored when determining
emptiness. If root is not provided, then only path may be removed
(i.e., no recursive removal).
"""
path = normpath(path)
if root is not None:
root = normpath(root)
ancestors = ancestry(path)
if root is None:
# Only remove the top directory.
ancestors = []
elif root in ancestors:
# Only remove directories below the root.
ancestors = ancestors[ancestors.index(root) + 1:]
else:
# Remove nothing.
return
# Traverse upward from path.
ancestors.append(path)
ancestors.reverse()
for directory in ancestors:
directory = syspath(directory)
if not os.path.exists(directory):
# Directory gone already.
continue
if fnmatch_all(os.listdir(directory), clutter):
# Directory contains only clutter (or nothing).
try:
shutil.rmtree(directory)
except OSError:
break
else:
break
def components(path):
"""Return a list of the path components in path. For instance:
>>> components('/a/b/c')
['a', 'b', 'c']
The argument should *not* be the result of a call to `syspath`.
"""
comps = []
ances = ancestry(path)
for anc in ances:
comp = os.path.basename(anc)
if comp:
comps.append(comp)
else: # root
comps.append(anc)
last = os.path.basename(path)
if last:
comps.append(last)
return comps
def _fsencoding():
"""Get the system's filesystem encoding. On Windows, this is always
UTF-8 (not MBCS).
"""
encoding = sys.getfilesystemencoding() or sys.getdefaultencoding()
if encoding == 'mbcs':
# On Windows, a broken encoding known to Python as "MBCS" is
# used for the filesystem. However, we only use the Unicode API
# for Windows paths, so the encoding is actually immaterial so
# we can avoid dealing with this nastiness. We arbitrarily
# choose UTF-8.
encoding = 'utf8'
return encoding
def bytestring_path(path):
"""Given a path, which is either a str or a unicode, returns a str
path (ensuring that we never deal with Unicode pathnames).
"""
# Pass through bytestrings.
if isinstance(path, str):
return path
# On Windows, remove the magic prefix added by `syspath`. This makes
# ``bytestring_path(syspath(X)) == X``, i.e., we can safely
# round-trip through `syspath`.
if os.path.__name__ == 'ntpath' and path.startswith(WINDOWS_MAGIC_PREFIX):
path = path[len(WINDOWS_MAGIC_PREFIX):]
# Try to encode with default encodings, but fall back to UTF8.
try:
return path.encode(_fsencoding())
except (UnicodeError, LookupError):
return path.encode('utf8')
def displayable_path(path, separator=u'; '):
"""Attempts to decode a bytestring path to a unicode object for the
purpose of displaying it to the user. If the `path` argument is a
list or a tuple, the elements are joined with `separator`.
"""
if isinstance(path, (list, tuple)):
return separator.join(displayable_path(p) for p in path)
elif isinstance(path, unicode):
return path
elif not isinstance(path, str):
# A non-string object: just get its unicode representation.
return unicode(path)
try:
return path.decode(_fsencoding(), 'ignore')
except (UnicodeError, LookupError):
return path.decode('utf8', 'ignore')
def syspath(path, prefix=True):
"""Convert a path for use by the operating system. In particular,
paths on Windows must receive a magic prefix and must be converted
to Unicode before they are sent to the OS. To disable the magic
prefix on Windows, set `prefix` to False---but only do this if you
*really* know what you're doing.
"""
# Don't do anything if we're not on windows
if os.path.__name__ != 'ntpath':
return path
if not isinstance(path, unicode):
# Beets currently represents Windows paths internally with UTF-8
# arbitrarily. But earlier versions used MBCS because it is
# reported as the FS encoding by Windows. Try both.
try:
path = path.decode('utf8')
except UnicodeError:
# The encoding should always be MBCS, Windows' broken
# Unicode representation.
encoding = sys.getfilesystemencoding() or sys.getdefaultencoding()
path = path.decode(encoding, 'replace')
# Add the magic prefix if it isn't already there.
# http://msdn.microsoft.com/en-us/library/windows/desktop/aa365247.aspx
if prefix and not path.startswith(WINDOWS_MAGIC_PREFIX):
if path.startswith(u'\\\\'):
# UNC path. Final path should look like \\?\UNC\...
path = u'UNC' + path[1:]
path = WINDOWS_MAGIC_PREFIX + path
return path
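# Example (editor's illustration; on non-Windows platforms the path is
# returned unchanged):
#
#     syspath(u'C:\\Music\\a.mp3')   # -> u'\\\\?\\C:\\Music\\a.mp3'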
def samefile(p1, p2):
"""Safer equality for paths."""
return shutil._samefile(syspath(p1), syspath(p2))
def remove(path, soft=True):
"""Remove the file. If `soft`, then no error will be raised if the
file does not exist.
"""
path = syspath(path)
if soft and not os.path.exists(path):
return
try:
os.remove(path)
except (OSError, IOError) as exc:
raise FilesystemError(exc, 'delete', (path,), traceback.format_exc())
def copy(path, dest, replace=False):
"""Copy a plain file. Permissions are not copied. If `dest` already
exists, raises a FilesystemError unless `replace` is True. Has no
effect if `path` is the same as `dest`. Paths are translated to
system paths before the syscall.
"""
if samefile(path, dest):
return
path = syspath(path)
dest = syspath(dest)
if not replace and os.path.exists(dest):
raise FilesystemError('file exists', 'copy', (path, dest))
try:
shutil.copyfile(path, dest)
except (OSError, IOError) as exc:
raise FilesystemError(exc, 'copy', (path, dest),
traceback.format_exc())
def move(path, dest, replace=False):
"""Rename a file. `dest` may not be a directory. If `dest` already
exists, raises an OSError unless `replace` is True. Has no effect if
`path` is the same as `dest`. If the paths are on different
filesystems (or the rename otherwise fails), a copy is attempted
instead, in which case metadata will *not* be preserved. Paths are
translated to system paths.
"""
if samefile(path, dest):
return
path = syspath(path)
dest = syspath(dest)
if os.path.exists(dest) and not replace:
raise FilesystemError('file exists', 'rename', (path, dest),
traceback.format_exc())
# First, try renaming the file.
try:
os.rename(path, dest)
except OSError:
# Otherwise, copy and delete the original.
try:
shutil.copyfile(path, dest)
os.remove(path)
except (OSError, IOError) as exc:
raise FilesystemError(exc, 'move', (path, dest),
traceback.format_exc())
def link(path, dest, replace=False):
"""Create a symbolic link from path to `dest`. Raises an OSError if
`dest` already exists, unless `replace` is True. Does nothing if
`path` == `dest`."""
if (samefile(path, dest)):
return
path = syspath(path)
dest = syspath(dest)
if os.path.exists(dest) and not replace:
raise FilesystemError('file exists', 'rename', (path, dest),
traceback.format_exc())
try:
os.symlink(path, dest)
except OSError:
raise FilesystemError('Operating system does not support symbolic '
'links.', 'link', (path, dest),
traceback.format_exc())
def unique_path(path):
"""Returns a version of ``path`` that does not exist on the
    filesystem. Specifically, if ``path`` itself already exists, then
something unique is appended to the path.
"""
if not os.path.exists(syspath(path)):
return path
base, ext = os.path.splitext(path)
    match = re.search(r'\.(\d+)$', base)
if match:
num = int(match.group(1))
base = base[:match.start()]
else:
num = 0
while True:
num += 1
new_path = '%s.%i%s' % (base, num, ext)
if not os.path.exists(new_path):
return new_path
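# Example (editor's illustration): if 'cover.jpg' exists but
# 'cover.1.jpg' does not:
#
#     unique_path('cover.jpg')   # -> 'cover.1.jpg'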
# Note: The Windows "reserved characters" are, of course, allowed on
# Unix. They are forbidden here because they cause problems on Samba
# shares, which are sufficiently common as to cause frequent problems.
# http://msdn.microsoft.com/en-us/library/windows/desktop/aa365247.aspx
CHAR_REPLACE = [
(re.compile(ur'[\\/]'), u'_'), # / and \ -- forbidden everywhere.
(re.compile(ur'^\.'), u'_'), # Leading dot (hidden files on Unix).
(re.compile(ur'[\x00-\x1f]'), u''), # Control characters.
(re.compile(ur'[<>:"\?\*\|]'), u'_'), # Windows "reserved characters".
(re.compile(ur'\.$'), u'_'), # Trailing dots.
(re.compile(ur'\s+$'), u''), # Trailing whitespace.
]
def sanitize_path(path, replacements=None):
"""Takes a path (as a Unicode string) and makes sure that it is
legal. Returns a new path. Only works with fragments; won't work
reliably on Windows when a path begins with a drive letter. Path
separators (including altsep!) should already be cleaned from the
path components. If replacements is specified, it is used *instead*
of the default set of replacements; it must be a list of (compiled
regex, replacement string) pairs.
"""
replacements = replacements or CHAR_REPLACE
comps = components(path)
if not comps:
return ''
for i, comp in enumerate(comps):
for regex, repl in replacements:
comp = regex.sub(repl, comp)
comps[i] = comp
return os.path.join(*comps)
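# Example (editor's illustration; assumes POSIX path separators):
#
#     >>> sanitize_path(u'.hidden/What?.mp3')
#     u'_hidden/What_.mp3'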
def truncate_path(path, length=MAX_FILENAME_LENGTH):
"""Given a bytestring path or a Unicode path fragment, truncate the
components to a legal length. In the last component, the extension
is preserved.
"""
comps = components(path)
out = [c[:length] for c in comps]
base, ext = os.path.splitext(comps[-1])
if ext:
# Last component has an extension.
base = base[:length - len(ext)]
out[-1] = base + ext
return os.path.join(*out)
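# Example (editor's illustration): the extension survives truncation:
#
#     >>> truncate_path(u'abcdefgh.mp3', 8)
#     u'abcd.mp3'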
def str2bool(value):
"""Returns a boolean reflecting a human-entered string."""
if value.lower() in ('yes', '1', 'true', 't', 'y'):
return True
else:
return False
def as_string(value):
"""Convert a value to a Unicode object for matching with a query.
None becomes the empty string. Bytestrings are silently decoded.
"""
if value is None:
return u''
elif isinstance(value, buffer):
return str(value).decode('utf8', 'ignore')
elif isinstance(value, str):
return value.decode('utf8', 'ignore')
else:
return unicode(value)
def levenshtein(s1, s2):
"""A nice DP edit distance implementation from Wikibooks:
http://en.wikibooks.org/wiki/Algorithm_implementation/Strings/
Levenshtein_distance#Python
"""
if len(s1) < len(s2):
return levenshtein(s2, s1)
if not s1:
return len(s2)
previous_row = xrange(len(s2) + 1)
for i, c1 in enumerate(s1):
current_row = [i + 1]
for j, c2 in enumerate(s2):
insertions = previous_row[j + 1] + 1
deletions = current_row[j] + 1
substitutions = previous_row[j] + (c1 != c2)
current_row.append(min(insertions, deletions, substitutions))
previous_row = current_row
return previous_row[-1]
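# Example (editor's illustration):
#
#     >>> levenshtein('kitten', 'sitting')
#     3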
def plurality(objs):
"""Given a sequence of comparable objects, returns the object that
is most common in the set and the frequency of that object. The
sequence must contain at least one object.
"""
# Calculate frequencies.
freqs = defaultdict(int)
for obj in objs:
freqs[obj] += 1
if not freqs:
raise ValueError('sequence must be non-empty')
# Find object with maximum frequency.
max_freq = 0
res = None
for obj, freq in freqs.items():
if freq > max_freq:
max_freq = freq
res = obj
return res, max_freq
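# Example (editor's illustration): the most common object and its count:
#
#     >>> plurality([1, 1, 2, 3])
#     (1, 2)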
def cpu_count():
"""Return the number of hardware thread contexts (cores or SMT
threads) in the system.
"""
# Adapted from the soundconverter project:
# https://github.com/kassoulet/soundconverter
if sys.platform == 'win32':
try:
num = int(os.environ['NUMBER_OF_PROCESSORS'])
except (ValueError, KeyError):
num = 0
elif sys.platform == 'darwin':
try:
num = int(command_output(['sysctl', '-n', 'hw.ncpu']))
except ValueError:
num = 0
else:
try:
num = os.sysconf('SC_NPROCESSORS_ONLN')
except (ValueError, OSError, AttributeError):
num = 0
if num >= 1:
return num
else:
return 1
def command_output(cmd, shell=False):
"""Runs the command and returns its output after it has exited.
``cmd`` is a list of arguments starting with the command names. If
``shell`` is true, ``cmd`` is assumed to be a string and passed to a
shell to execute.
If the process exits with a non-zero return code
``subprocess.CalledProcessError`` is raised. May also raise
``OSError``.
This replaces `subprocess.check_output`, which isn't available in
Python 2.6 and which can have problems if lots of output is sent to
stderr.
"""
proc = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
close_fds=platform.system() != 'Windows',
shell=shell
)
stdout, stderr = proc.communicate()
if proc.returncode:
raise subprocess.CalledProcessError(
returncode=proc.returncode,
cmd=' '.join(cmd),
)
return stdout
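# Example (editor's illustration; assumes a POSIX `echo` is available):
#
#     command_output(['echo', 'hello'])   # -> 'hello\n'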
def max_filename_length(path, limit=MAX_FILENAME_LENGTH):
"""Attempt to determine the maximum filename length for the
filesystem containing `path`. If the value is greater than `limit`,
then `limit` is used instead (to prevent errors when a filesystem
misreports its capacity). If it cannot be determined (e.g., on
Windows), return `limit`.
"""
if hasattr(os, 'statvfs'):
try:
res = os.statvfs(path)
except OSError:
return limit
        return min(res[9], limit)  # statvfs field 9 is f_namemax.
else:
return limit

View File

@@ -0,0 +1,211 @@
# This file is part of beets.
# Copyright 2014, Fabrice Laporte
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""Abstraction layer to resize images using PIL, ImageMagick, or a
public resizing proxy if neither is available.
"""
import urllib
import subprocess
import os
import re
from tempfile import NamedTemporaryFile
import logging
from beets import util
# Resizing methods
PIL = 1
IMAGEMAGICK = 2
WEBPROXY = 3
PROXY_URL = 'http://images.weserv.nl/'
log = logging.getLogger('beets')
def resize_url(url, maxwidth):
"""Return a proxied image URL that resizes the original image to
maxwidth (preserving aspect ratio).
"""
return '{0}?{1}'.format(PROXY_URL, urllib.urlencode({
'url': url.replace('http://', ''),
'w': str(maxwidth),
}))
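# Example (editor's illustration; the query-parameter order depends on
# dict iteration order):
#
#     resize_url('http://example.com/a.jpg', 300)
#     # -> 'http://images.weserv.nl/?url=example.com%2Fa.jpg&w=300'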
def temp_file_for(path):
"""Return an unused filename with the same extension as the
specified path.
"""
ext = os.path.splitext(path)[1]
with NamedTemporaryFile(suffix=ext, delete=False) as f:
return f.name
def pil_resize(maxwidth, path_in, path_out=None):
"""Resize using Python Imaging Library (PIL). Return the output path
of resized image.
"""
path_out = path_out or temp_file_for(path_in)
from PIL import Image
log.debug(u'artresizer: PIL resizing {0} to {1}'.format(
util.displayable_path(path_in), util.displayable_path(path_out)
))
try:
im = Image.open(util.syspath(path_in))
size = maxwidth, maxwidth
im.thumbnail(size, Image.ANTIALIAS)
im.save(path_out)
return path_out
except IOError:
log.error(u"PIL cannot create thumbnail for '{0}'".format(
util.displayable_path(path_in)
))
return path_in
def im_resize(maxwidth, path_in, path_out=None):
"""Resize using ImageMagick's ``convert`` tool.
Return the output path of resized image.
"""
path_out = path_out or temp_file_for(path_in)
log.debug(u'artresizer: ImageMagick resizing {0} to {1}'.format(
util.displayable_path(path_in), util.displayable_path(path_out)
))
# "-resize widthxheight>" shrinks images with dimension(s) larger
# than the corresponding width and/or height dimension(s). The >
# "only shrink" flag is prefixed by ^ escape char for Windows
# compatibility.
try:
util.command_output([
'convert', util.syspath(path_in),
'-resize', '{0}x^>'.format(maxwidth), path_out
])
except subprocess.CalledProcessError:
log.warn(u'artresizer: IM convert failed for {0}'.format(
util.displayable_path(path_in)
))
return path_in
return path_out
BACKEND_FUNCS = {
PIL: pil_resize,
IMAGEMAGICK: im_resize,
}
class Shareable(type):
"""A pseudo-singleton metaclass that allows both shared and
non-shared instances. The ``MyClass.shared`` property holds a
lazily-created shared instance of ``MyClass`` while calling
``MyClass()`` to construct a new object works as usual.
"""
def __init__(cls, name, bases, dict):
super(Shareable, cls).__init__(name, bases, dict)
cls._instance = None
@property
def shared(cls):
if cls._instance is None:
cls._instance = cls()
return cls._instance
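# Example (editor's illustration, using the ArtResizer class defined
# below): the shared instance is created lazily and reused, while
# direct construction still yields fresh objects:
#
#     ArtResizer.shared is ArtResizer.shared   # -> True
#     ArtResizer() is ArtResizer.shared        # -> False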
class ArtResizer(object):
"""A singleton class that performs image resizes.
"""
__metaclass__ = Shareable
def __init__(self, method=None):
"""Create a resizer object for the given method or, if none is
specified, with an inferred method.
"""
self.method = self._check_method(method)
log.debug(u"artresizer: method is {0}".format(self.method))
self.can_compare = self._can_compare()
def resize(self, maxwidth, path_in, path_out=None):
"""Manipulate an image file according to the method, returning a
        new path. For PIL or IMAGEMAGICK methods, resizes the image to a
temporary file. For WEBPROXY, returns `path_in` unmodified.
"""
if self.local:
func = BACKEND_FUNCS[self.method[0]]
return func(maxwidth, path_in, path_out)
else:
return path_in
def proxy_url(self, maxwidth, url):
"""Modifies an image URL according the method, returning a new
URL. For WEBPROXY, a URL on the proxy server is returned.
Otherwise, the URL is returned unmodified.
"""
if self.local:
return url
else:
return resize_url(url, maxwidth)
@property
def local(self):
"""A boolean indicating whether the resizing method is performed
locally (i.e., PIL or ImageMagick).
"""
return self.method[0] in BACKEND_FUNCS
def _can_compare(self):
"""A boolean indicating whether image comparison is available"""
return self.method[0] == IMAGEMAGICK and self.method[1] > (6, 8, 7)
@staticmethod
def _check_method(method=None):
"""A tuple indicating whether current method is available and its
version. If no method is given, it returns a supported one.
"""
# Guess available method
if not method:
for m in [IMAGEMAGICK, PIL]:
_, version = ArtResizer._check_method(m)
if version:
return (m, version)
return (WEBPROXY, (0))
if method == IMAGEMAGICK:
# Try invoking ImageMagick's "convert".
try:
out = util.command_output(['identify', '--version'])
if 'imagemagick' in out.lower():
pattern = r".+ (\d+)\.(\d+)\.(\d+).*"
match = re.search(pattern, out)
if match:
return (IMAGEMAGICK,
(int(match.group(1)),
int(match.group(2)),
int(match.group(3))))
return (IMAGEMAGICK, (0))
except (subprocess.CalledProcessError, OSError):
return (IMAGEMAGICK, None)
if method == PIL:
# Try importing PIL.
try:
__import__('PIL', fromlist=['Image'])
return (PIL, (0))
except ImportError:
return (PIL, None)

647
lib/beets/util/bluelet.py Normal file
View File

@@ -0,0 +1,647 @@
"""Extremely simple pure-Python implementation of coroutine-style
asynchronous socket I/O. Inspired by, but inferior to, Eventlet.
Bluelet can also be thought of as a less-terrible replacement for
asyncore.
Bluelet: easy concurrency without all the messy parallelism.
"""
import socket
import select
import sys
import types
import errno
import traceback
import time
import collections
# A little bit of "six" (Python 2/3 compatibility): cope with PEP 3109 syntax
# changes.
PY3 = sys.version_info[0] == 3
if PY3:
def _reraise(typ, exc, tb):
raise exc.with_traceback(tb)
else:
exec("""
def _reraise(typ, exc, tb):
raise typ, exc, tb
""")
# Basic events used for thread scheduling.
class Event(object):
"""Just a base class identifying Bluelet events. An event is an
object yielded from a Bluelet thread coroutine to suspend operation
and communicate with the scheduler.
"""
pass
class WaitableEvent(Event):
"""A waitable event is one encapsulating an action that can be
waited for using a select() call. That is, it's an event with an
associated file descriptor.
"""
def waitables(self):
"""Return "waitable" objects to pass to select(). Should return
three iterables for input readiness, output readiness, and
exceptional conditions (i.e., the three lists passed to
select()).
"""
return (), (), ()
def fire(self):
"""Called when an associated file descriptor becomes ready
(i.e., is returned from a select() call).
"""
pass
class ValueEvent(Event):
"""An event that does nothing but return a fixed value."""
def __init__(self, value):
self.value = value
class ExceptionEvent(Event):
"""Raise an exception at the yield point. Used internally."""
def __init__(self, exc_info):
self.exc_info = exc_info
class SpawnEvent(Event):
"""Add a new coroutine thread to the scheduler."""
def __init__(self, coro):
self.spawned = coro
class JoinEvent(Event):
"""Suspend the thread until the specified child thread has
completed.
"""
def __init__(self, child):
self.child = child
class KillEvent(Event):
"""Unschedule a child thread."""
def __init__(self, child):
self.child = child
class DelegationEvent(Event):
"""Suspend execution of the current thread, start a new thread and,
once the child thread finished, return control to the parent
thread.
"""
def __init__(self, coro):
self.spawned = coro
class ReturnEvent(Event):
"""Return a value the current thread's delegator at the point of
delegation. Ends the current (delegate) thread.
"""
def __init__(self, value):
self.value = value
class SleepEvent(WaitableEvent):
"""Suspend the thread for a given duration.
"""
def __init__(self, duration):
self.wakeup_time = time.time() + duration
def time_left(self):
return max(self.wakeup_time - time.time(), 0.0)
class ReadEvent(WaitableEvent):
"""Reads from a file-like object."""
def __init__(self, fd, bufsize):
self.fd = fd
self.bufsize = bufsize
def waitables(self):
return (self.fd,), (), ()
def fire(self):
return self.fd.read(self.bufsize)
class WriteEvent(WaitableEvent):
"""Writes to a file-like object."""
def __init__(self, fd, data):
self.fd = fd
self.data = data
    def waitables(self):
return (), (self.fd,), ()
def fire(self):
self.fd.write(self.data)
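# Usage sketch (editor's example; `greet` and the file-like object are
# hypothetical, and this relies on the scheduler defined below sending
# each fired event's result back into the coroutine at the yield point):
#
#     def greet(fd):
#         data = yield ReadEvent(fd, 1024)    # suspend until readable
#         yield WriteEvent(fd, 'hi ' + data)  # suspend until written
#
#     run(greet(some_file_like_object))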
# Core logic for executing and scheduling threads.
def _event_select(events):
"""Perform a select() over all the Events provided, returning the
ones ready to be fired. Only WaitableEvents (including SleepEvents)
matter here; all other events are ignored (and thus postponed).
"""
# Gather waitables and wakeup times.
waitable_to_event = {}
rlist, wlist, xlist = [], [], []
earliest_wakeup = None
for event in events:
if isinstance(event, SleepEvent):
if not earliest_wakeup:
earliest_wakeup = event.wakeup_time
else:
earliest_wakeup = min(earliest_wakeup, event.wakeup_time)
elif isinstance(event, WaitableEvent):
r, w, x = event.waitables()
rlist += r
wlist += w
xlist += x
for waitable in r:
waitable_to_event[('r', waitable)] = event
for waitable in w:
waitable_to_event[('w', waitable)] = event
for waitable in x:
waitable_to_event[('x', waitable)] = event
    # If we have any sleeping threads, determine how long to sleep.
if earliest_wakeup:
timeout = max(earliest_wakeup - time.time(), 0.0)
else:
timeout = None
# Perform select() if we have any waitables.
if rlist or wlist or xlist:
rready, wready, xready = select.select(rlist, wlist, xlist, timeout)
else:
rready, wready, xready = (), (), ()
if timeout:
time.sleep(timeout)
# Gather ready events corresponding to the ready waitables.
ready_events = set()
for ready in rready:
ready_events.add(waitable_to_event[('r', ready)])
for ready in wready:
ready_events.add(waitable_to_event[('w', ready)])
for ready in xready:
ready_events.add(waitable_to_event[('x', ready)])
# Gather any finished sleeps.
for event in events:
if isinstance(event, SleepEvent) and event.time_left() == 0.0:
ready_events.add(event)
return ready_events
class ThreadException(Exception):
def __init__(self, coro, exc_info):
self.coro = coro
self.exc_info = exc_info
def reraise(self):
_reraise(self.exc_info[0], self.exc_info[1], self.exc_info[2])
SUSPENDED = Event() # Special sentinel placeholder for suspended threads.
class Delegated(Event):
"""Placeholder indicating that a thread has delegated execution to a
different thread.
"""
def __init__(self, child):
self.child = child
def run(root_coro):
"""Schedules a coroutine, running it to completion. This
encapsulates the Bluelet scheduler, which the root coroutine can
add to by spawning new coroutines.
"""
# The "threads" dictionary keeps track of all the currently-
# executing and suspended coroutines. It maps coroutines to their
# currently "blocking" event. The event value may be SUSPENDED if
# the coroutine is waiting on some other condition: namely, a
# delegated coroutine or a joined coroutine. In this case, the
# coroutine should *also* appear as a value in one of the below
# dictionaries `delegators` or `joiners`.
threads = {root_coro: ValueEvent(None)}
# Maps child coroutines to delegating parents.
delegators = {}
# Maps child coroutines to joining (exit-waiting) parents.
joiners = collections.defaultdict(list)
def complete_thread(coro, return_value):
"""Remove a coroutine from the scheduling pool, awaking
delegators and joiners as necessary and returning the specified
value to any delegating parent.
"""
del threads[coro]
# Resume delegator.
if coro in delegators:
threads[delegators[coro]] = ValueEvent(return_value)
del delegators[coro]
# Resume joiners.
if coro in joiners:
for parent in joiners[coro]:
threads[parent] = ValueEvent(None)
del joiners[coro]
def advance_thread(coro, value, is_exc=False):
"""After an event is fired, run a given coroutine associated with
it in the threads dict until it yields again. If the coroutine
exits, then the thread is removed from the pool. If the coroutine
raises an exception, it is reraised in a ThreadException. If
is_exc is True, then the value must be an exc_info tuple and the
exception is thrown into the coroutine.
"""
try:
if is_exc:
next_event = coro.throw(*value)
else:
next_event = coro.send(value)
except StopIteration:
# Thread is done.
complete_thread(coro, None)
except:
# Thread raised some other exception.
del threads[coro]
raise ThreadException(coro, sys.exc_info())
else:
if isinstance(next_event, types.GeneratorType):
# Automatically invoke sub-coroutines. (Shorthand for
# explicit bluelet.call().)
next_event = DelegationEvent(next_event)
threads[coro] = next_event
def kill_thread(coro):
"""Unschedule this thread and its (recursive) delegates.
"""
# Collect all coroutines in the delegation stack.
coros = [coro]
while isinstance(threads[coro], Delegated):
coro = threads[coro].child
coros.append(coro)
# Complete each coroutine from the top to the bottom of the
# stack.
for coro in reversed(coros):
complete_thread(coro, None)
# Continue advancing threads until root thread exits.
exit_te = None
while threads:
try:
# Look for events that can be run immediately. Continue
# running immediate events until nothing is ready.
while True:
have_ready = False
for coro, event in list(threads.items()):
if isinstance(event, SpawnEvent):
threads[event.spawned] = ValueEvent(None) # Spawn.
advance_thread(coro, None)
have_ready = True
elif isinstance(event, ValueEvent):
advance_thread(coro, event.value)
have_ready = True
elif isinstance(event, ExceptionEvent):
advance_thread(coro, event.exc_info, True)
have_ready = True
elif isinstance(event, DelegationEvent):
threads[coro] = Delegated(event.spawned) # Suspend.
threads[event.spawned] = ValueEvent(None) # Spawn.
delegators[event.spawned] = coro
have_ready = True
elif isinstance(event, ReturnEvent):
# Thread is done.
complete_thread(coro, event.value)
have_ready = True
elif isinstance(event, JoinEvent):
threads[coro] = SUSPENDED # Suspend.
joiners[event.child].append(coro)
have_ready = True
elif isinstance(event, KillEvent):
threads[coro] = ValueEvent(None)
kill_thread(event.child)
have_ready = True
# Only start the select when nothing else is ready.
if not have_ready:
break
# Wait and fire.
event2coro = dict((v, k) for k, v in threads.items())
for event in _event_select(threads.values()):
# Run the IO operation, but catch socket errors.
try:
value = event.fire()
except socket.error as exc:
if isinstance(exc.args, tuple) and \
exc.args[0] == errno.EPIPE:
# Broken pipe. Remote host disconnected.
pass
else:
traceback.print_exc()
# Abort the coroutine.
threads[event2coro[event]] = ReturnEvent(None)
else:
advance_thread(event2coro[event], value)
except ThreadException as te:
# Exception raised from inside a thread.
event = ExceptionEvent(te.exc_info)
if te.coro in delegators:
# The thread is a delegate. Raise exception in its
# delegator.
threads[delegators[te.coro]] = event
del delegators[te.coro]
else:
# The thread is root-level. Raise in client code.
exit_te = te
break
except:
# For instance, KeyboardInterrupt during select(). Raise
# into root thread and terminate others.
threads = {root_coro: ExceptionEvent(sys.exc_info())}
# If any threads still remain, kill them.
for coro in threads:
coro.close()
# If we're exiting with an exception, raise it in the client.
if exit_te:
exit_te.reraise()
# Sockets and their associated events.
class SocketClosedError(Exception):
pass
class Listener(object):
"""A socket wrapper object for listening sockets.
"""
def __init__(self, host, port):
"""Create a listening socket on the given hostname and port.
"""
self._closed = False
self.host = host
self.port = port
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.sock.bind((host, port))
self.sock.listen(5)
def accept(self):
"""An event that waits for a connection on the listening socket.
When a connection is made, the event returns a Connection
object.
"""
if self._closed:
raise SocketClosedError()
return AcceptEvent(self)
def close(self):
"""Immediately close the listening socket. (Not an event.)
"""
self._closed = True
self.sock.close()
class Connection(object):
"""A socket wrapper object for connected sockets.
"""
def __init__(self, sock, addr):
self.sock = sock
self.addr = addr
self._buf = b''
self._closed = False
def close(self):
"""Close the connection."""
self._closed = True
self.sock.close()
def recv(self, size):
"""Read at most size bytes of data from the socket."""
if self._closed:
raise SocketClosedError()
if self._buf:
# We already have data read previously.
out = self._buf[:size]
self._buf = self._buf[size:]
return ValueEvent(out)
else:
return ReceiveEvent(self, size)
def send(self, data):
"""Sends data on the socket, returning the number of bytes
successfully sent.
"""
if self._closed:
raise SocketClosedError()
return SendEvent(self, data)
def sendall(self, data):
"""Send all of data on the socket."""
if self._closed:
raise SocketClosedError()
return SendEvent(self, data, True)
def readline(self, terminator=b"\n", bufsize=1024):
"""Reads a line (delimited by terminator) from the socket."""
if self._closed:
raise SocketClosedError()
while True:
if terminator in self._buf:
line, self._buf = self._buf.split(terminator, 1)
line += terminator
yield ReturnEvent(line)
break
data = yield ReceiveEvent(self, bufsize)
if data:
self._buf += data
else:
line = self._buf
self._buf = b''
yield ReturnEvent(line)
break
class AcceptEvent(WaitableEvent):
"""An event for Listener objects (listening sockets) that suspends
execution until the socket gets a connection.
"""
def __init__(self, listener):
self.listener = listener
def waitables(self):
return (self.listener.sock,), (), ()
def fire(self):
sock, addr = self.listener.sock.accept()
return Connection(sock, addr)
class ReceiveEvent(WaitableEvent):
"""An event for Connection objects (connected sockets) for
asynchronously reading data.
"""
def __init__(self, conn, bufsize):
self.conn = conn
self.bufsize = bufsize
def waitables(self):
return (self.conn.sock,), (), ()
def fire(self):
return self.conn.sock.recv(self.bufsize)
class SendEvent(WaitableEvent):
"""An event for Connection objects (connected sockets) for
asynchronously writing data.
"""
def __init__(self, conn, data, sendall=False):
self.conn = conn
self.data = data
self.sendall = sendall
def waitables(self):
return (), (self.conn.sock,), ()
def fire(self):
if self.sendall:
return self.conn.sock.sendall(self.data)
else:
return self.conn.sock.send(self.data)
# Public interface for threads; each returns an event object that
# can immediately be "yield"ed.
def null():
"""Event: yield to the scheduler without doing anything special.
"""
return ValueEvent(None)
def spawn(coro):
"""Event: add another coroutine to the scheduler. Both the parent
and child coroutines run concurrently.
"""
if not isinstance(coro, types.GeneratorType):
raise ValueError('%s is not a coroutine' % str(coro))
return SpawnEvent(coro)
def call(coro):
"""Event: delegate to another coroutine. The current coroutine
is resumed once the sub-coroutine finishes. If the sub-coroutine
returns a value using end(), then this event returns that value.
"""
if not isinstance(coro, types.GeneratorType):
raise ValueError('%s is not a coroutine' % str(coro))
return DelegationEvent(coro)
def end(value=None):
"""Event: ends the coroutine and returns a value to its
delegator.
"""
return ReturnEvent(value)
def read(fd, bufsize=None):
"""Event: read from a file descriptor asynchronously."""
if bufsize is None:
# Read all.
def reader():
buf = []
while True:
data = yield read(fd, 1024)
if not data:
break
buf.append(data)
yield ReturnEvent(''.join(buf))
return DelegationEvent(reader())
else:
return ReadEvent(fd, bufsize)
def write(fd, data):
"""Event: write to a file descriptor asynchronously."""
return WriteEvent(fd, data)
def connect(host, port):
"""Event: connect to a network address and return a Connection
object for communicating on the socket.
"""
addr = (host, port)
sock = socket.create_connection(addr)
return ValueEvent(Connection(sock, addr))
def sleep(duration):
"""Event: suspend the thread for ``duration`` seconds.
"""
return SleepEvent(duration)
def join(coro):
"""Suspend the thread until another, previously `spawn`ed thread
completes.
"""
return JoinEvent(coro)
def kill(coro):
"""Halt the execution of a different `spawn`ed thread.
"""
return KillEvent(coro)
# Convenience function for running socket servers.
def server(host, port, func):
"""A coroutine that runs a network server. Host and port specify the
listening address. func should be a coroutine that takes a single
parameter, a Connection object. The coroutine is invoked for every
incoming connection on the listening socket.
"""
def handler(conn):
try:
yield func(conn)
finally:
conn.close()
listener = Listener(host, port)
try:
while True:
conn = yield listener.accept()
yield spawn(handler(conn))
except KeyboardInterrupt:
pass
finally:
listener.close()
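# A minimal usage sketch of the public interface above (hypothetical
# coroutines, not part of the original module): spawn() runs a child
# concurrently, call() delegates and waits, and sleep() suspends the
# current thread without blocking the others.
if __name__ == '__main__':
    def ticker(name, interval, count):
        for i in range(count):
            yield sleep(interval)  # suspend; other threads keep running
            print('%s %d' % (name, i))
    def demo():
        yield spawn(ticker('fast', 0.1, 4))  # concurrent child thread
        yield call(ticker('slow', 0.2, 2))   # delegate and wait
    run(demo())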

1281
lib/beets/util/confit.py Normal file

File diff suppressed because it is too large

40
lib/beets/util/enumeration.py Normal file
View File

@@ -0,0 +1,40 @@
# This file is part of beets.
# Copyright 2013, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
from enum import Enum
class OrderedEnum(Enum):
"""
An Enum subclass that allows comparison of members.
"""
def __ge__(self, other):
if self.__class__ is other.__class__:
return self.value >= other.value
return NotImplemented
def __gt__(self, other):
if self.__class__ is other.__class__:
return self.value > other.value
return NotImplemented
def __le__(self, other):
if self.__class__ is other.__class__:
return self.value <= other.value
return NotImplemented
def __lt__(self, other):
if self.__class__ is other.__class__:
return self.value < other.value
return NotImplemented
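# A minimal sanity check (the Size enum is hypothetical, not part of
# beets): members of an OrderedEnum subclass compare by their values.
if __name__ == '__main__':
    class Size(OrderedEnum):
        SMALL = 1
        LARGE = 2
    assert Size.SMALL < Size.LARGE
    assert Size.LARGE >= Size.SMALL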

571
lib/beets/util/functemplate.py Normal file
View File

@@ -0,0 +1,571 @@
# This file is part of beets.
# Copyright 2013, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""This module implements a string formatter based on the standard PEP
292 string.Template class extended with function calls. Variables, as
with string.Template, are indicated with $ and functions are delimited
with %.
This module assumes that everything is Unicode: the template and the
substitution values. Bytestrings are not supported. Also, the templates
always behave like the ``safe_substitute`` method in the standard
library: unknown symbols are left intact.
This is sort of like a tiny, horrible degeneration of a real templating
engine like Jinja2 or Mustache.
"""
from __future__ import print_function
import re
import ast
import dis
import types
SYMBOL_DELIM = u'$'
FUNC_DELIM = u'%'
GROUP_OPEN = u'{'
GROUP_CLOSE = u'}'
ARG_SEP = u','
ESCAPE_CHAR = u'$'
VARIABLE_PREFIX = '__var_'
FUNCTION_PREFIX = '__func_'
class Environment(object):
"""Contains the values and functions to be substituted into a
template.
"""
def __init__(self, values, functions):
self.values = values
self.functions = functions
# Code generation helpers.
def ex_lvalue(name):
    """A variable store expression."""
    return ast.Name(name, ast.Store())
def ex_rvalue(name):
    """A variable load expression."""
    return ast.Name(name, ast.Load())
def ex_literal(val):
"""An int, float, long, bool, string, or None literal with the given
value.
"""
    if val is None:
        return ast.Name('None', ast.Load())
    elif isinstance(val, bool):
        # Check bool before int: True and False are also int instances.
        return ast.Name(str(val), ast.Load())
    elif isinstance(val, (int, float, long)):
        return ast.Num(val)
    elif isinstance(val, basestring):
        return ast.Str(val)
    raise TypeError('no literal for {0}'.format(type(val)))
def ex_varassign(name, expr):
"""Assign an expression into a single variable. The expression may
either be an `ast.expr` object or a value to be used as a literal.
"""
if not isinstance(expr, ast.expr):
expr = ex_literal(expr)
return ast.Assign([ex_lvalue(name)], expr)
def ex_call(func, args):
"""A function-call expression with only positional parameters. The
function may be an expression or the name of a function. Each
argument may be an expression or a value to be used as a literal.
"""
if isinstance(func, basestring):
func = ex_rvalue(func)
args = list(args)
for i in range(len(args)):
if not isinstance(args[i], ast.expr):
args[i] = ex_literal(args[i])
return ast.Call(func, args, [], None, None)
def compile_func(arg_names, statements, name='_the_func', debug=False):
"""Compile a list of statements as the body of a function and return
the resulting Python function. If `debug`, then print out the
bytecode of the compiled function.
"""
func_def = ast.FunctionDef(
name,
ast.arguments(
[ast.Name(n, ast.Param()) for n in arg_names],
None, None,
[ex_literal(None) for _ in arg_names],
),
statements,
[],
)
mod = ast.Module([func_def])
ast.fix_missing_locations(mod)
prog = compile(mod, '<generated>', 'exec')
# Debug: show bytecode.
if debug:
dis.dis(prog)
for const in prog.co_consts:
if isinstance(const, types.CodeType):
dis.dis(const)
the_locals = {}
exec prog in {}, the_locals
return the_locals[name]
# AST nodes for the template language.
class Symbol(object):
"""A variable-substitution symbol in a template."""
def __init__(self, ident, original):
self.ident = ident
self.original = original
def __repr__(self):
return u'Symbol(%s)' % repr(self.ident)
def evaluate(self, env):
"""Evaluate the symbol in the environment, returning a Unicode
string.
"""
if self.ident in env.values:
# Substitute for a value.
return env.values[self.ident]
else:
# Keep original text.
return self.original
def translate(self):
"""Compile the variable lookup."""
expr = ex_rvalue(VARIABLE_PREFIX + self.ident.encode('utf8'))
return [expr], set([self.ident.encode('utf8')]), set()
class Call(object):
"""A function call in a template."""
def __init__(self, ident, args, original):
self.ident = ident
self.args = args
self.original = original
def __repr__(self):
return u'Call(%s, %s, %s)' % (repr(self.ident), repr(self.args),
repr(self.original))
def evaluate(self, env):
"""Evaluate the function call in the environment, returning a
Unicode string.
"""
if self.ident in env.functions:
arg_vals = [expr.evaluate(env) for expr in self.args]
try:
out = env.functions[self.ident](*arg_vals)
except Exception as exc:
# Function raised exception! Maybe inlining the name of
# the exception will help debug.
return u'<%s>' % unicode(exc)
return unicode(out)
else:
return self.original
def translate(self):
"""Compile the function call."""
varnames = set()
funcnames = set([self.ident.encode('utf8')])
arg_exprs = []
for arg in self.args:
subexprs, subvars, subfuncs = arg.translate()
varnames.update(subvars)
funcnames.update(subfuncs)
# Create a subexpression that joins the result components of
# the arguments.
arg_exprs.append(ex_call(
ast.Attribute(ex_literal(u''), 'join', ast.Load()),
[ex_call(
'map',
[
ex_rvalue('unicode'),
ast.List(subexprs, ast.Load()),
]
)],
))
subexpr_call = ex_call(
FUNCTION_PREFIX + self.ident.encode('utf8'),
arg_exprs
)
return [subexpr_call], varnames, funcnames
class Expression(object):
"""Top-level template construct: contains a list of text blobs,
Symbols, and Calls.
"""
def __init__(self, parts):
self.parts = parts
def __repr__(self):
return u'Expression(%s)' % (repr(self.parts))
def evaluate(self, env):
"""Evaluate the entire expression in the environment, returning
a Unicode string.
"""
out = []
for part in self.parts:
if isinstance(part, basestring):
out.append(part)
else:
out.append(part.evaluate(env))
return u''.join(map(unicode, out))
def translate(self):
"""Compile the expression to a list of Python AST expressions, a
set of variable names used, and a set of function names.
"""
expressions = []
varnames = set()
funcnames = set()
for part in self.parts:
if isinstance(part, basestring):
expressions.append(ex_literal(part))
else:
e, v, f = part.translate()
expressions.extend(e)
varnames.update(v)
funcnames.update(f)
return expressions, varnames, funcnames
# Parser.
class ParseError(Exception):
pass
class Parser(object):
"""Parses a template expression string. Instantiate the class with
the template source and call ``parse_expression``. The ``pos`` field
will indicate the character after the expression finished and
``parts`` will contain a list of Unicode strings, Symbols, and Calls
reflecting the concatenated portions of the expression.
This is a terrible, ad-hoc parser implementation based on a
left-to-right scan with no lexing step to speak of; it's probably
both inefficient and incorrect. Maybe this should eventually be
replaced with a real, accepted parsing technique (PEG, parser
generator, etc.).
"""
def __init__(self, string):
self.string = string
self.pos = 0
self.parts = []
# Common parsing resources.
special_chars = (SYMBOL_DELIM, FUNC_DELIM, GROUP_OPEN, GROUP_CLOSE,
ARG_SEP, ESCAPE_CHAR)
special_char_re = re.compile(ur'[%s]|$' %
u''.join(re.escape(c) for c in special_chars))
def parse_expression(self):
"""Parse a template expression starting at ``pos``. Resulting
components (Unicode strings, Symbols, and Calls) are added to
the ``parts`` field, a list. The ``pos`` field is updated to be
the next character after the expression.
"""
text_parts = []
while self.pos < len(self.string):
char = self.string[self.pos]
if char not in self.special_chars:
# A non-special character. Skip to the next special
# character, treating the interstice as literal text.
next_pos = (
self.special_char_re.search(self.string[self.pos:]).start()
+ self.pos
)
text_parts.append(self.string[self.pos:next_pos])
self.pos = next_pos
continue
if self.pos == len(self.string) - 1:
# The last character can never begin a structure, so we
# just interpret it as a literal character (unless it
# terminates the expression, as with , and }).
if char not in (GROUP_CLOSE, ARG_SEP):
text_parts.append(char)
self.pos += 1
break
next_char = self.string[self.pos + 1]
if char == ESCAPE_CHAR and next_char in \
(SYMBOL_DELIM, FUNC_DELIM, GROUP_CLOSE, ARG_SEP):
# An escaped special character ($$, $}, etc.). Note that
# ${ is not an escape sequence: this is ambiguous with
# the start of a symbol and it's not necessary (just
# using { suffices in all cases).
text_parts.append(next_char)
self.pos += 2 # Skip the next character.
continue
# Shift all characters collected so far into a single string.
if text_parts:
self.parts.append(u''.join(text_parts))
text_parts = []
if char == SYMBOL_DELIM:
# Parse a symbol.
self.parse_symbol()
elif char == FUNC_DELIM:
# Parse a function call.
self.parse_call()
elif char in (GROUP_CLOSE, ARG_SEP):
# Template terminated.
break
elif char == GROUP_OPEN:
                # Start of a group has no meaning here; just pass
# through the character.
text_parts.append(char)
self.pos += 1
else:
assert False
# If any parsed characters remain, shift them into a string.
if text_parts:
self.parts.append(u''.join(text_parts))
def parse_symbol(self):
"""Parse a variable reference (like ``$foo`` or ``${foo}``)
starting at ``pos``. Possibly appends a Symbol object (or,
failing that, text) to the ``parts`` field and updates ``pos``.
The character at ``pos`` must, as a precondition, be ``$``.
"""
assert self.pos < len(self.string)
assert self.string[self.pos] == SYMBOL_DELIM
if self.pos == len(self.string) - 1:
# Last character.
self.parts.append(SYMBOL_DELIM)
self.pos += 1
return
next_char = self.string[self.pos + 1]
start_pos = self.pos
self.pos += 1
if next_char == GROUP_OPEN:
# A symbol like ${this}.
self.pos += 1 # Skip opening.
closer = self.string.find(GROUP_CLOSE, self.pos)
if closer == -1 or closer == self.pos:
# No closing brace found or identifier is empty.
self.parts.append(self.string[start_pos:self.pos])
else:
# Closer found.
ident = self.string[self.pos:closer]
self.pos = closer + 1
self.parts.append(Symbol(ident,
self.string[start_pos:self.pos]))
else:
# A bare-word symbol.
ident = self._parse_ident()
if ident:
# Found a real symbol.
self.parts.append(Symbol(ident,
self.string[start_pos:self.pos]))
else:
# A standalone $.
self.parts.append(SYMBOL_DELIM)
def parse_call(self):
"""Parse a function call (like ``%foo{bar,baz}``) starting at
        ``pos``. Possibly appends a Call object to ``parts`` and updates
``pos``. The character at ``pos`` must be ``%``.
"""
assert self.pos < len(self.string)
assert self.string[self.pos] == FUNC_DELIM
start_pos = self.pos
self.pos += 1
ident = self._parse_ident()
if not ident:
# No function name.
self.parts.append(FUNC_DELIM)
return
if self.pos >= len(self.string):
# Identifier terminates string.
self.parts.append(self.string[start_pos:self.pos])
return
if self.string[self.pos] != GROUP_OPEN:
# Argument list not opened.
self.parts.append(self.string[start_pos:self.pos])
return
# Skip past opening brace and try to parse an argument list.
self.pos += 1
args = self.parse_argument_list()
if self.pos >= len(self.string) or \
self.string[self.pos] != GROUP_CLOSE:
# Arguments unclosed.
self.parts.append(self.string[start_pos:self.pos])
return
self.pos += 1 # Move past closing brace.
self.parts.append(Call(ident, args, self.string[start_pos:self.pos]))
def parse_argument_list(self):
"""Parse a list of arguments starting at ``pos``, returning a
list of Expression objects. Does not modify ``parts``. Should
leave ``pos`` pointing to a } character or the end of the
string.
"""
# Try to parse a subexpression in a subparser.
expressions = []
while self.pos < len(self.string):
subparser = Parser(self.string[self.pos:])
subparser.parse_expression()
# Extract and advance past the parsed expression.
expressions.append(Expression(subparser.parts))
self.pos += subparser.pos
if self.pos >= len(self.string) or \
self.string[self.pos] == GROUP_CLOSE:
# Argument list terminated by EOF or closing brace.
break
# Only other way to terminate an expression is with ,.
# Continue to the next argument.
assert self.string[self.pos] == ARG_SEP
self.pos += 1
return expressions
def _parse_ident(self):
"""Parse an identifier and return it (possibly an empty string).
Updates ``pos``.
"""
remainder = self.string[self.pos:]
ident = re.match(ur'\w*', remainder).group(0)
self.pos += len(ident)
return ident
def _parse(template):
"""Parse a top-level template string Expression. Any extraneous text
is considered literal text.
"""
parser = Parser(template)
parser.parse_expression()
parts = parser.parts
remainder = parser.string[parser.pos:]
if remainder:
parts.append(remainder)
return Expression(parts)
# External interface.
class Template(object):
"""A string template, including text, Symbols, and Calls.
"""
def __init__(self, template):
self.expr = _parse(template)
self.original = template
self.compiled = self.translate()
def __eq__(self, other):
return self.original == other.original
def interpret(self, values={}, functions={}):
"""Like `substitute`, but forces the interpreter (rather than
the compiled version) to be used. The interpreter includes
exception-handling code for missing variables and buggy template
functions but is much slower.
"""
return self.expr.evaluate(Environment(values, functions))
def substitute(self, values={}, functions={}):
"""Evaluate the template given the values and functions.
"""
try:
res = self.compiled(values, functions)
except: # Handle any exceptions thrown by compiled version.
res = self.interpret(values, functions)
return res
def translate(self):
"""Compile the template to a Python function."""
expressions, varnames, funcnames = self.expr.translate()
argnames = []
for varname in varnames:
argnames.append(VARIABLE_PREFIX.encode('utf8') + varname)
for funcname in funcnames:
argnames.append(FUNCTION_PREFIX.encode('utf8') + funcname)
func = compile_func(
argnames,
[ast.Return(ast.List(expressions, ast.Load()))],
)
def wrapper_func(values={}, functions={}):
args = {}
for varname in varnames:
args[VARIABLE_PREFIX + varname] = values[varname]
for funcname in funcnames:
args[FUNCTION_PREFIX + funcname] = functions[funcname]
parts = func(**args)
return u''.join(parts)
return wrapper_func
# Performance tests.
if __name__ == '__main__':
import timeit
_tmpl = Template(u'foo $bar %baz{foozle $bar barzle} $bar')
_vars = {'bar': 'qux'}
_funcs = {'baz': unicode.upper}
interp_time = timeit.timeit('_tmpl.interpret(_vars, _funcs)',
'from __main__ import _tmpl, _vars, _funcs',
number=10000)
print(interp_time)
comp_time = timeit.timeit('_tmpl.substitute(_vars, _funcs)',
'from __main__ import _tmpl, _vars, _funcs',
number=10000)
print(comp_time)
print('Speedup:', interp_time / comp_time)
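# A short usage sketch of the syntax described in the module docstring
# (the values and the %upper function are hypothetical).
if __name__ == '__main__':
    t = Template(u'$artist - %upper{$title}')
    print(t.substitute(
        {u'artist': u'Neko Case', u'title': u'hold on'},
        {u'upper': unicode.upper},
    ))  # -> u'Neko Case - HOLD ON'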

517
lib/beets/util/pipeline.py Normal file
View File

@@ -0,0 +1,517 @@
# This file is part of beets.
# Copyright 2013, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""Simple but robust implementation of generator/coroutine-based
pipelines in Python. The pipelines may be run either sequentially
(single-threaded) or in parallel (one thread per pipeline stage).
This implementation supports pipeline bubbles (indications that the
processing for a certain item should abort). To use them, yield the
BUBBLE constant from any stage coroutine except the last.
In the parallel case, the implementation transparently handles thread
shutdown when the processing is complete and when a stage raises an
exception. KeyboardInterrupts (^C) are also handled.
When running a parallel pipeline, it is also possible to use
multiple coroutines for the same pipeline stage; this lets you speed
up a bottleneck stage by dividing its work among multiple threads.
To do so, pass an iterable of coroutines to the Pipeline constructor
in place of any single coroutine.
"""
from __future__ import print_function
import Queue
from threading import Thread, Lock
import sys
BUBBLE = '__PIPELINE_BUBBLE__'
POISON = '__PIPELINE_POISON__'
DEFAULT_QUEUE_SIZE = 16
def _invalidate_queue(q, val=None, sync=True):
"""Breaks a Queue such that it never blocks, always has size 1,
and has no maximum size. get()ing from the queue returns `val`,
which defaults to None. `sync` controls whether a lock is
required (because it's not reentrant!).
"""
def _qsize(len=len):
return 1
def _put(item):
pass
def _get():
return val
if sync:
q.mutex.acquire()
try:
q.maxsize = 0
q._qsize = _qsize
q._put = _put
q._get = _get
q.not_empty.notifyAll()
q.not_full.notifyAll()
finally:
if sync:
q.mutex.release()
class CountedQueue(Queue.Queue):
"""A queue that keeps track of the number of threads that are
still feeding into it. The queue is poisoned when all threads are
finished with the queue.
"""
def __init__(self, maxsize=0):
Queue.Queue.__init__(self, maxsize)
self.nthreads = 0
self.poisoned = False
def acquire(self):
"""Indicate that a thread will start putting into this queue.
Should not be called after the queue is already poisoned.
"""
with self.mutex:
assert not self.poisoned
assert self.nthreads >= 0
self.nthreads += 1
def release(self):
"""Indicate that a thread that was putting into this queue has
exited. If this is the last thread using the queue, the queue
is poisoned.
"""
with self.mutex:
self.nthreads -= 1
assert self.nthreads >= 0
if self.nthreads == 0:
# All threads are done adding to this queue. Poison it
# when it becomes empty.
self.poisoned = True
# Replacement _get invalidates when no items remain.
_old_get = self._get
def _get():
out = _old_get()
if not self.queue:
_invalidate_queue(self, POISON, False)
return out
if self.queue:
# Items remain.
self._get = _get
else:
# No items. Invalidate immediately.
_invalidate_queue(self, POISON, False)
class MultiMessage(object):
"""A message yielded by a pipeline stage encapsulating multiple
values to be sent to the next stage.
"""
def __init__(self, messages):
self.messages = messages
def multiple(messages):
"""Yield multiple([message, ..]) from a pipeline stage to send
multiple values to the next pipeline stage.
"""
return MultiMessage(messages)
def stage(func):
"""Decorate a function to become a simple stage.
>>> @stage
... def add(n, i):
... return i + n
>>> pipe = Pipeline([
... iter([1, 2, 3]),
... add(2),
... ])
>>> list(pipe.pull())
[3, 4, 5]
"""
def coro(*args):
task = None
while True:
task = yield task
task = func(*(args + (task,)))
return coro
def mutator_stage(func):
"""Decorate a function that manipulates items in a coroutine to
become a simple stage.
>>> @mutator_stage
... def setkey(key, item):
... item[key] = True
>>> pipe = Pipeline([
... iter([{'x': False}, {'a': False}]),
... setkey('x'),
... ])
>>> list(pipe.pull())
[{'x': True}, {'a': False, 'x': True}]
"""
def coro(*args):
task = None
while True:
task = yield task
func(*(args + (task,)))
return coro
def _allmsgs(obj):
"""Returns a list of all the messages encapsulated in obj. If obj
is a MultiMessage, returns its enclosed messages. If obj is BUBBLE,
returns an empty list. Otherwise, returns a list containing obj.
"""
if isinstance(obj, MultiMessage):
return obj.messages
elif obj == BUBBLE:
return []
else:
return [obj]
class PipelineThread(Thread):
"""Abstract base class for pipeline-stage threads."""
def __init__(self, all_threads):
super(PipelineThread, self).__init__()
self.abort_lock = Lock()
self.abort_flag = False
self.all_threads = all_threads
self.exc_info = None
def abort(self):
"""Shut down the thread at the next chance possible.
"""
with self.abort_lock:
self.abort_flag = True
# Ensure that we are not blocking on a queue read or write.
if hasattr(self, 'in_queue'):
_invalidate_queue(self.in_queue, POISON)
if hasattr(self, 'out_queue'):
_invalidate_queue(self.out_queue, POISON)
def abort_all(self, exc_info):
"""Abort all other threads in the system for an exception.
"""
self.exc_info = exc_info
for thread in self.all_threads:
thread.abort()
class FirstPipelineThread(PipelineThread):
"""The thread running the first stage in a parallel pipeline setup.
The coroutine should just be a generator.
"""
def __init__(self, coro, out_queue, all_threads):
super(FirstPipelineThread, self).__init__(all_threads)
self.coro = coro
self.out_queue = out_queue
self.out_queue.acquire()
def run(self):
try:
while True:
with self.abort_lock:
if self.abort_flag:
return
# Get the value from the generator.
try:
msg = self.coro.next()
except StopIteration:
break
# Send messages to the next stage.
for msg in _allmsgs(msg):
with self.abort_lock:
if self.abort_flag:
return
self.out_queue.put(msg)
except:
self.abort_all(sys.exc_info())
return
# Generator finished; shut down the pipeline.
self.out_queue.release()
class MiddlePipelineThread(PipelineThread):
"""A thread running any stage in the pipeline except the first or
last.
"""
def __init__(self, coro, in_queue, out_queue, all_threads):
super(MiddlePipelineThread, self).__init__(all_threads)
self.coro = coro
self.in_queue = in_queue
self.out_queue = out_queue
self.out_queue.acquire()
def run(self):
try:
# Prime the coroutine.
self.coro.next()
while True:
with self.abort_lock:
if self.abort_flag:
return
# Get the message from the previous stage.
msg = self.in_queue.get()
if msg is POISON:
break
with self.abort_lock:
if self.abort_flag:
return
# Invoke the current stage.
out = self.coro.send(msg)
# Send messages to next stage.
for msg in _allmsgs(out):
with self.abort_lock:
if self.abort_flag:
return
self.out_queue.put(msg)
except:
self.abort_all(sys.exc_info())
return
# Pipeline is shutting down normally.
self.out_queue.release()
class LastPipelineThread(PipelineThread):
"""A thread running the last stage in a pipeline. The coroutine
should yield nothing.
"""
def __init__(self, coro, in_queue, all_threads):
super(LastPipelineThread, self).__init__(all_threads)
self.coro = coro
self.in_queue = in_queue
def run(self):
# Prime the coroutine.
self.coro.next()
try:
while True:
with self.abort_lock:
if self.abort_flag:
return
# Get the message from the previous stage.
msg = self.in_queue.get()
if msg is POISON:
break
with self.abort_lock:
if self.abort_flag:
return
# Send to consumer.
self.coro.send(msg)
except:
self.abort_all(sys.exc_info())
return
class Pipeline(object):
"""Represents a staged pattern of work. Each stage in the pipeline
is a coroutine that receives messages from the previous stage and
yields messages to be sent to the next stage.
"""
def __init__(self, stages):
"""Makes a new pipeline from a list of coroutines. There must
be at least two stages.
"""
if len(stages) < 2:
raise ValueError('pipeline must have at least two stages')
self.stages = []
for stage in stages:
if isinstance(stage, (list, tuple)):
self.stages.append(stage)
else:
# Default to one thread per stage.
self.stages.append((stage,))
def run_sequential(self):
"""Run the pipeline sequentially in the current thread. The
stages are run one after the other. Only the first coroutine
in each stage is used.
"""
list(self.pull())
def run_parallel(self, queue_size=DEFAULT_QUEUE_SIZE):
"""Run the pipeline in parallel using one thread per stage. The
messages between the stages are stored in queues of the given
size.
"""
queue_count = len(self.stages) - 1
queues = [CountedQueue(queue_size) for i in range(queue_count)]
threads = []
# Set up first stage.
for coro in self.stages[0]:
threads.append(FirstPipelineThread(coro, queues[0], threads))
# Middle stages.
for i in range(1, queue_count):
for coro in self.stages[i]:
threads.append(MiddlePipelineThread(
coro, queues[i - 1], queues[i], threads
))
# Last stage.
for coro in self.stages[-1]:
threads.append(
LastPipelineThread(coro, queues[-1], threads)
)
# Start threads.
for thread in threads:
thread.start()
# Wait for termination. The final thread lasts the longest.
try:
# Using a timeout allows us to receive KeyboardInterrupt
# exceptions during the join().
while threads[-1].isAlive():
threads[-1].join(1)
except:
# Stop all the threads immediately.
for thread in threads:
thread.abort()
raise
finally:
# Make completely sure that all the threads have finished
# before we return. They should already be either finished,
# in normal operation, or aborted, in case of an exception.
for thread in threads[:-1]:
thread.join()
for thread in threads:
exc_info = thread.exc_info
if exc_info:
# Make the exception appear as it was raised originally.
raise exc_info[0], exc_info[1], exc_info[2]
def pull(self):
"""Yield elements from the end of the pipeline. Runs the stages
sequentially until the last yields some messages. Each of the messages
        is then yielded by the returned generator's ``next()``. If the
        pipeline has a consumer, that is, if the last stage does not
        yield any messages, then pull will not yield any messages. Only
        the first coroutine in each stage is used.
        """
coros = [stage[0] for stage in self.stages]
# "Prime" the coroutines.
for coro in coros[1:]:
coro.next()
# Begin the pipeline.
for out in coros[0]:
msgs = _allmsgs(out)
for coro in coros[1:]:
next_msgs = []
for msg in msgs:
out = coro.send(msg)
next_msgs.extend(_allmsgs(out))
msgs = next_msgs
for msg in msgs:
yield msg
# Smoke test.
if __name__ == '__main__':
import time
# Test a normally-terminating pipeline both in sequence and
# in parallel.
def produce():
for i in range(5):
print('generating %i' % i)
time.sleep(1)
yield i
def work():
num = yield
while True:
print('processing %i' % num)
time.sleep(2)
num = yield num * 2
def consume():
while True:
num = yield
time.sleep(1)
print('received %i' % num)
ts_start = time.time()
Pipeline([produce(), work(), consume()]).run_sequential()
ts_seq = time.time()
Pipeline([produce(), work(), consume()]).run_parallel()
ts_par = time.time()
Pipeline([produce(), (work(), work()), consume()]).run_parallel()
ts_end = time.time()
print('Sequential time:', ts_seq - ts_start)
print('Parallel time:', ts_par - ts_seq)
print('Multiply-parallel time:', ts_end - ts_par)
print()
# Test a pipeline that raises an exception.
def exc_produce():
for i in range(10):
print('generating %i' % i)
time.sleep(1)
yield i
def exc_work():
num = yield
while True:
print('processing %i' % num)
time.sleep(3)
if num == 3:
raise Exception()
num = yield num * 2
def exc_consume():
while True:
num = yield
print('received %i' % num)
Pipeline([exc_produce(), exc_work(), exc_consume()]).run_parallel(1)
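# A sketch of the BUBBLE mechanism described in the module docstring:
# a stage can drop an item by yielding BUBBLE (hypothetical stage).
if __name__ == '__main__':
    @stage
    def evens_only(n):
        return n if n % 2 == 0 else BUBBLE
    print(list(Pipeline([iter([1, 2, 3, 4]), evens_only()]).pull()))
    # -> [2, 4]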

50
lib/beets/vfs.py Normal file
View File

@@ -0,0 +1,50 @@
# This file is part of beets.
# Copyright 2013, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""A simple utility for constructing filesystem-like trees from beets
libraries.
"""
from collections import namedtuple
from beets import util
Node = namedtuple('Node', ['files', 'dirs'])
def _insert(node, path, itemid):
"""Insert an item into a virtual filesystem node."""
if len(path) == 1:
# Last component. Insert file.
node.files[path[0]] = itemid
else:
# In a directory.
dirname = path[0]
rest = path[1:]
if dirname not in node.dirs:
node.dirs[dirname] = Node({}, {})
_insert(node.dirs[dirname], rest, itemid)
def libtree(lib):
"""Generates a filesystem-like directory tree for the files
contained in `lib`. Filesystem nodes are (files, dirs) named
tuples in which both components are dictionaries. The first
maps filenames to Item ids. The second maps directory names to
child node tuples.
"""
root = Node({}, {})
for item in lib.items():
dest = item.destination(fragment=True)
parts = util.components(dest)
_insert(root, parts, item.id)
return root
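# Illustration of the Node structure built by _insert (hypothetical
# path components and item ids).
if __name__ == '__main__':
    root = Node({}, {})
    _insert(root, ['Artist', 'Album', '01 Track.mp3'], 1)
    _insert(root, ['Artist', 'Album', '02 Track.mp3'], 2)
    album = root.dirs['Artist'].dirs['Album']
    assert album.files == {'01 Track.mp3': 1, '02 Track.mp3': 2}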

19
lib/beetsplug/__init__.py Normal file
View File

@@ -0,0 +1,19 @@
# This file is part of beets.
# Copyright 2013, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""A namespace package for beets plugins."""
# Make this a namespace package.
from pkgutil import extend_path
__path__ = extend_path(__path__, __name__)

280
lib/beetsplug/embedart.py Normal file
View File

@@ -0,0 +1,280 @@
# This file is part of beets.
# Copyright 2014, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""Allows beets to embed album art into file metadata."""
import os.path
import logging
import imghdr
import subprocess
import platform
from tempfile import NamedTemporaryFile
from beets.plugins import BeetsPlugin
from beets import mediafile
from beets import ui
from beets.ui import decargs
from beets.util import syspath, normpath, displayable_path
from beets.util.artresizer import ArtResizer
from beets import config
log = logging.getLogger('beets')
class EmbedCoverArtPlugin(BeetsPlugin):
"""Allows albumart to be embedded into the actual files.
"""
def __init__(self):
super(EmbedCoverArtPlugin, self).__init__()
self.config.add({
'maxwidth': 0,
'auto': True,
'compare_threshold': 0,
'ifempty': False,
})
if self.config['maxwidth'].get(int) and not ArtResizer.shared.local:
self.config['maxwidth'] = 0
log.warn(u"embedart: ImageMagick or PIL not found; "
u"'maxwidth' option ignored")
if self.config['compare_threshold'].get(int) and not \
ArtResizer.shared.can_compare:
self.config['compare_threshold'] = 0
log.warn(u"embedart: ImageMagick 6.8.7 or higher not installed; "
u"'compare_threshold' option ignored")
def commands(self):
# Embed command.
embed_cmd = ui.Subcommand(
'embedart', help='embed image files into file metadata'
)
embed_cmd.parser.add_option(
'-f', '--file', metavar='PATH', help='the image file to embed'
)
maxwidth = config['embedart']['maxwidth'].get(int)
compare_threshold = config['embedart']['compare_threshold'].get(int)
ifempty = config['embedart']['ifempty'].get(bool)
def embed_func(lib, opts, args):
if opts.file:
imagepath = normpath(opts.file)
for item in lib.items(decargs(args)):
embed_item(item, imagepath, maxwidth, None,
compare_threshold, ifempty)
else:
for album in lib.albums(decargs(args)):
embed_album(album, maxwidth)
embed_cmd.func = embed_func
# Extract command.
extract_cmd = ui.Subcommand('extractart',
help='extract an image from file metadata')
extract_cmd.parser.add_option('-o', dest='outpath',
help='image output file')
def extract_func(lib, opts, args):
outpath = normpath(opts.outpath or 'cover')
item = lib.items(decargs(args)).get()
extract(outpath, item)
extract_cmd.func = extract_func
# Clear command.
clear_cmd = ui.Subcommand('clearart',
help='remove images from file metadata')
def clear_func(lib, opts, args):
clear(lib, decargs(args))
clear_cmd.func = clear_func
return [embed_cmd, extract_cmd, clear_cmd]
@EmbedCoverArtPlugin.listen('album_imported')
def album_imported(lib, album):
"""Automatically embed art into imported albums.
"""
if album.artpath and config['embedart']['auto']:
embed_album(album, config['embedart']['maxwidth'].get(int), True)
def embed_item(item, imagepath, maxwidth=None, itempath=None,
compare_threshold=0, ifempty=False, as_album=False):
"""Embed an image into the item's media file.
"""
if compare_threshold:
if not check_art_similarity(item, imagepath, compare_threshold):
log.warn(u'Image not similar; skipping.')
return
    if ifempty:
        art = get_art(item)
        if art:
            log.debug(u'embedart: media file already contains art: {0}'.format(
                displayable_path(imagepath)
            ))
            return
if maxwidth and not as_album:
imagepath = resize_image(imagepath, maxwidth)
try:
log.debug(u'embedart: embedding {0}'.format(
displayable_path(imagepath)
))
item['images'] = [_mediafile_image(imagepath, maxwidth)]
except IOError as exc:
log.error(u'embedart: could not read image file: {0}'.format(exc))
else:
# We don't want to store the image in the database.
item.try_write(itempath)
del item['images']
def embed_album(album, maxwidth=None, quiet=False):
"""Embed album art into all of the album's items.
"""
imagepath = album.artpath
if not imagepath:
log.info(u'No album art present: {0} - {1}'.
format(album.albumartist, album.album))
return
if not os.path.isfile(syspath(imagepath)):
log.error(u'Album art not found at {0}'
.format(displayable_path(imagepath)))
return
if maxwidth:
imagepath = resize_image(imagepath, maxwidth)
log.log(
logging.DEBUG if quiet else logging.INFO,
u'Embedding album art into {0.albumartist} - {0.album}.'.format(album),
)
for item in album.items():
embed_item(item, imagepath, maxwidth, None,
config['embedart']['compare_threshold'].get(int),
config['embedart']['ifempty'].get(bool), as_album=True)
def resize_image(imagepath, maxwidth):
"""Returns path to an image resized to maxwidth.
"""
log.info(u'Resizing album art to {0} pixels wide'
.format(maxwidth))
imagepath = ArtResizer.shared.resize(maxwidth, syspath(imagepath))
return imagepath
def check_art_similarity(item, imagepath, compare_threshold):
"""A boolean indicating if an image is similar to embedded item art.
"""
with NamedTemporaryFile(delete=True) as f:
art = extract(f.name, item)
if art:
# Converting images to grayscale tends to minimize the weight
# of colors in the diff score
cmd = 'convert {0} {1} -colorspace gray MIFF:- | ' \
'compare -metric PHASH - null:'.format(syspath(imagepath),
syspath(art))
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
close_fds=platform.system() != 'Windows',
shell=True)
stdout, stderr = proc.communicate()
if proc.returncode:
if proc.returncode != 1:
log.warn(u'embedart: IM phashes compare failed for {0}, \
{1}'.format(displayable_path(imagepath),
displayable_path(art)))
return
phashDiff = float(stderr)
else:
phashDiff = float(stdout)
log.info(u'embedart: compare PHASH score is {0}'.format(phashDiff))
if phashDiff > compare_threshold:
return False
return True
def _mediafile_image(image_path, maxwidth=None):
"""Return a `mediafile.Image` object for the path.
"""
with open(syspath(image_path), 'rb') as f:
data = f.read()
return mediafile.Image(data, type=mediafile.ImageType.front)
def get_art(item):
# Extract the art.
try:
mf = mediafile.MediaFile(syspath(item.path))
except mediafile.UnreadableFileError as exc:
log.error(u'Could not extract art from {0}: {1}'.format(
displayable_path(item.path), exc
))
return
return mf.art
# 'extractart' command.
def extract(outpath, item):
if not item:
log.error(u'No item matches query.')
return
art = get_art(item)
if not art:
log.error(u'No album art present in {0} - {1}.'
.format(item.artist, item.title))
return
# Add an extension to the filename.
ext = imghdr.what(None, h=art)
if not ext:
log.error(u'Unknown image type.')
return
outpath += '.' + ext
log.info(u'Extracting album art from: {0.artist} - {0.title} '
u'to: {1}'.format(item, displayable_path(outpath)))
with open(syspath(outpath), 'wb') as f:
f.write(art)
return outpath
# 'clearart' command.
def clear(lib, query):
log.info(u'Clearing album art from items:')
for item in lib.items(query):
log.info(u'{0} - {1}'.format(item.artist, item.title))
try:
mf = mediafile.MediaFile(syspath(item.path),
config['id3v23'].get(bool))
except mediafile.UnreadableFileError as exc:
log.error(u'Could not clear art from {0}: {1}'.format(
displayable_path(item.path), exc
))
continue
del mf.art
mf.save()
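# A hypothetical sketch of driving the helpers above from Python
# rather than via the CLI; the database path and queries are invented.
if __name__ == '__main__':
    from beets.library import Library
    EmbedCoverArtPlugin()  # registers the embedart config defaults
    lib = Library('/path/to/library.blb')  # hypothetical path
    # Embed each album's cover art, resized to at most 500 pixels.
    for album in lib.albums(u'year:2013'):
        embed_album(album, maxwidth=500)
    # Write one matching item's embedded art to /tmp/cover.<ext>.
    item = lib.items(u'artist:Case').get()
    extract('/tmp/cover', item)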

396
lib/beetsplug/fetchart.py Normal file
View File

@@ -0,0 +1,396 @@
# This file is part of beets.
# Copyright 2014, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""Fetches album art.
"""
from contextlib import closing
import logging
import os
import re
from tempfile import NamedTemporaryFile
import requests
from beets import plugins
from beets import importer
from beets import ui
from beets import util
from beets import config
from beets.util.artresizer import ArtResizer
try:
import itunes
HAVE_ITUNES = True
except ImportError:
HAVE_ITUNES = False
IMAGE_EXTENSIONS = ['png', 'jpg', 'jpeg']
CONTENT_TYPES = ('image/jpeg',)
DOWNLOAD_EXTENSION = '.jpg'
log = logging.getLogger('beets')
requests_session = requests.Session()
requests_session.headers = {'User-Agent': 'beets'}
def _fetch_image(url):
"""Downloads an image from a URL and checks whether it seems to
actually be an image. If so, returns a path to the downloaded image.
Otherwise, returns None.
"""
log.debug(u'fetchart: downloading art: {0}'.format(url))
try:
with closing(requests_session.get(url, stream=True)) as resp:
if 'Content-Type' not in resp.headers \
or resp.headers['Content-Type'] not in CONTENT_TYPES:
log.debug(u'fetchart: not an image')
return
# Generate a temporary file with the correct extension.
with NamedTemporaryFile(suffix=DOWNLOAD_EXTENSION, delete=False) \
as fh:
for chunk in resp.iter_content():
fh.write(chunk)
log.debug(u'fetchart: downloaded art to: {0}'.format(
util.displayable_path(fh.name)
))
return fh.name
except (IOError, requests.RequestException):
log.debug(u'fetchart: error fetching art')
# ART SOURCES ################################################################
# Cover Art Archive.
CAA_URL = 'http://coverartarchive.org/release/{mbid}/front-500.jpg'
CAA_GROUP_URL = 'http://coverartarchive.org/release-group/{mbid}/front-500.jpg'
def caa_art(album):
"""Return the Cover Art Archive and Cover Art Archive release group URLs
using album MusicBrainz release ID and release group ID.
"""
if album.mb_albumid:
yield CAA_URL.format(mbid=album.mb_albumid)
if album.mb_releasegroupid:
yield CAA_GROUP_URL.format(mbid=album.mb_releasegroupid)
# Art from Amazon.
AMAZON_URL = 'http://images.amazon.com/images/P/%s.%02i.LZZZZZZZ.jpg'
AMAZON_INDICES = (1, 2)
def art_for_asin(album):
"""Generate URLs using Amazon ID (ASIN) string.
"""
if album.asin:
for index in AMAZON_INDICES:
yield AMAZON_URL % (album.asin, index)
# AlbumArt.org scraper.
AAO_URL = 'http://www.albumart.org/index_detail.php'
AAO_PAT = r'href\s*=\s*"([^>"]*)"[^>]*title\s*=\s*"View larger image"'
def aao_art(album):
"""Return art URL from AlbumArt.org using album ASIN.
"""
if not album.asin:
return
# Get the page from albumart.org.
try:
resp = requests_session.get(AAO_URL, params={'asin': album.asin})
log.debug(u'fetchart: scraped art URL: {0}'.format(resp.url))
except requests.RequestException:
log.debug(u'fetchart: error scraping art page')
return
# Search the page for the image URL.
m = re.search(AAO_PAT, resp.text)
if m:
image_url = m.group(1)
yield image_url
else:
log.debug(u'fetchart: no image found on page')
# Google Images scraper.
GOOGLE_URL = 'https://ajax.googleapis.com/ajax/services/search/images'
def google_art(album):
"""Return art URL from google.org given an album title and
interpreter.
"""
if not (album.albumartist and album.album):
return
search_string = (album.albumartist + ',' + album.album).encode('utf-8')
response = requests_session.get(GOOGLE_URL, params={
'v': '1.0',
'q': search_string,
'start': '0',
})
# Get results using JSON.
try:
results = response.json()
data = results['responseData']
dataInfo = data['results']
for myUrl in dataInfo:
yield myUrl['unescapedUrl']
except:
log.debug(u'fetchart: error scraping art page')
return
# Art from the iTunes Store.
def itunes_art(album):
"""Return art URL from iTunes Store given an album title.
"""
search_string = (album.albumartist + ' ' + album.album).encode('utf-8')
try:
# Isolate bugs in the iTunes library while searching.
try:
itunes_album = itunes.search_album(search_string)[0]
except Exception as exc:
log.debug('fetchart: iTunes search failed: {0}'.format(exc))
return
if itunes_album.get_artwork()['100']:
small_url = itunes_album.get_artwork()['100']
big_url = small_url.replace('100x100', '1200x1200')
yield big_url
else:
log.debug(u'fetchart: album has no artwork in iTunes Store')
except IndexError:
log.debug(u'fetchart: album not found in iTunes Store')
# Art from the filesystem.
def filename_priority(filename, cover_names):
"""Sort order for image names.
    Return the indexes of the cover names found in the image filename.
    Used as an ascending sort key, this gives priority to filenames
    that match earlier-listed cover names.
"""
return [idx for (idx, x) in enumerate(cover_names) if x in filename]
def art_in_path(path, cover_names, cautious):
"""Look for album art files in a specified directory.
"""
if not os.path.isdir(path):
return
# Find all files that look like images in the directory.
images = []
for fn in os.listdir(path):
for ext in IMAGE_EXTENSIONS:
if fn.lower().endswith('.' + ext):
images.append(fn)
# Look for "preferred" filenames.
images = sorted(images, key=lambda x: filename_priority(x, cover_names))
cover_pat = r"(\b|_)({0})(\b|_)".format('|'.join(cover_names))
for fn in images:
if re.search(cover_pat, os.path.splitext(fn)[0], re.I):
log.debug(u'fetchart: using well-named art file {0}'.format(
util.displayable_path(fn)
))
return os.path.join(path, fn)
# Fall back to any image in the folder.
if images and not cautious:
log.debug(u'fetchart: using fallback art file {0}'.format(
util.displayable_path(images[0])
))
return os.path.join(path, images[0])
# Try each source in turn.
SOURCES_ALL = [u'coverart', u'itunes', u'amazon', u'albumart', u'google']
ART_FUNCS = {
u'coverart': caa_art,
u'itunes': itunes_art,
u'albumart': aao_art,
u'amazon': art_for_asin,
u'google': google_art,
}
def _source_urls(album, sources=SOURCES_ALL):
"""Generate possible source URLs for an album's art. The URLs are
not guaranteed to work so they each need to be attempted in turn.
This allows the main `art_for_album` function to abort iteration
through this sequence early to avoid the cost of scraping when not
necessary.
"""
for s in sources:
urls = ART_FUNCS[s](album)
for url in urls:
yield url
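# A minimal sketch of the early-abort pattern described above (hypothetical
# `album` object): iteration stops at the first usable image, so later,
# more expensive scrapers are never invoked.
#
#   for url in _source_urls(album, [u'coverart', u'amazon']):
#       if _fetch_image(url):
#           break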
def art_for_album(album, paths, maxwidth=None, local_only=False):
"""Given an Album object, returns a path to downloaded art for the
album (or None if no art is found). If `maxwidth`, then images are
resized to this maximum pixel size. If `local_only`, then only local
image files from the filesystem are returned; no network requests
are made.
"""
out = None
# Local art.
cover_names = config['fetchart']['cover_names'].as_str_seq()
cover_names = map(util.bytestring_path, cover_names)
cautious = config['fetchart']['cautious'].get(bool)
if paths:
for path in paths:
out = art_in_path(path, cover_names, cautious)
if out:
break
# Web art sources.
remote_priority = config['fetchart']['remote_priority'].get(bool)
if not local_only and (remote_priority or not out):
for url in _source_urls(album,
config['fetchart']['sources'].as_str_seq()):
if maxwidth:
url = ArtResizer.shared.proxy_url(maxwidth, url)
candidate = _fetch_image(url)
if candidate:
out = candidate
break
if maxwidth and out:
out = ArtResizer.shared.resize(maxwidth, out)
return out
# PLUGIN LOGIC ###############################################################
def batch_fetch_art(lib, albums, force, maxwidth=None):
"""Fetch album art for each of the albums. This implements the manual
fetchart CLI command.
"""
for album in albums:
if album.artpath and not force:
message = 'has album art'
else:
# In ordinary invocations, look for images on the
# filesystem. When forcing, however, always go to the Web
# sources.
local_paths = None if force else [album.path]
path = art_for_album(album, local_paths, maxwidth)
if path:
album.set_art(path, False)
album.store()
message = ui.colorize('green', 'found album art')
else:
message = ui.colorize('red', 'no art found')
log.info(u'{0} - {1}: {2}'.format(album.albumartist, album.album,
message))
class FetchArtPlugin(plugins.BeetsPlugin):
def __init__(self):
super(FetchArtPlugin, self).__init__()
self.config.add({
'auto': True,
'maxwidth': 0,
'remote_priority': False,
'cautious': False,
'google_search': False,
'cover_names': ['cover', 'front', 'art', 'album', 'folder'],
'sources': SOURCES_ALL,
})
# Holds paths to downloaded images between fetching them and
# placing them in the filesystem.
self.art_paths = {}
self.maxwidth = self.config['maxwidth'].get(int)
if self.config['auto']:
# Enable two import hooks when fetching is enabled.
self.import_stages = [self.fetch_art]
self.register_listener('import_task_files', self.assign_art)
available_sources = list(SOURCES_ALL)
if not HAVE_ITUNES and u'itunes' in available_sources:
available_sources.remove(u'itunes')
self.config['sources'] = plugins.sanitize_choices(
self.config['sources'].as_str_seq(), available_sources)
# Asynchronous; after music is added to the library.
def fetch_art(self, session, task):
"""Find art for the album being imported."""
if task.is_album: # Only fetch art for full albums.
if task.choice_flag == importer.action.ASIS:
# For as-is imports, don't search Web sources for art.
local = True
elif task.choice_flag == importer.action.APPLY:
# Search everywhere for art.
local = False
else:
# For any other choices (e.g., TRACKS), do nothing.
return
path = art_for_album(task.album, task.paths, self.maxwidth, local)
if path:
self.art_paths[task] = path
# Synchronous; after music files are put in place.
def assign_art(self, session, task):
"""Place the discovered art in the filesystem."""
if task in self.art_paths:
path = self.art_paths.pop(task)
album = task.album
src_removed = (config['import']['delete'].get(bool) or
config['import']['move'].get(bool))
album.set_art(path, not src_removed)
album.store()
if src_removed:
task.prune(path)
# Manual album art fetching.
def commands(self):
cmd = ui.Subcommand('fetchart', help='download album art')
cmd.parser.add_option('-f', '--force', dest='force',
action='store_true', default=False,
help='re-download art when already present')
def func(lib, opts, args):
batch_fetch_art(lib, lib.albums(ui.decargs(args)), opts.force,
self.maxwidth)
cmd.func = func
return [cmd]
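# Typical invocations of the subcommand defined above (usage sketch):
#   beet fetchart            # fetch art for albums that lack it
#   beet fetchart -f         # re-download art even when already present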

544
lib/beetsplug/lyrics.py Normal file
View File

@@ -0,0 +1,544 @@
# This file is part of beets.
# Copyright 2014, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""Fetches, embeds, and displays lyrics.
"""
from __future__ import print_function
import re
import logging
import requests
import json
import unicodedata
import urllib
import difflib
import itertools
from HTMLParser import HTMLParseError
from beets import plugins
from beets import config, ui
# Global logger.
log = logging.getLogger('beets')
DIV_RE = re.compile(r'<(/?)div>?', re.I)
COMMENT_RE = re.compile(r'<!--.*-->', re.S)
TAG_RE = re.compile(r'<[^>]*>')
BREAK_RE = re.compile(r'\n?\s*<br([\s|/][^>]*)*>\s*\n?', re.I)
URL_CHARACTERS = {
u'\u2018': u"'",
u'\u2019': u"'",
u'\u201c': u'"',
u'\u201d': u'"',
u'\u2010': u'-',
u'\u2011': u'-',
u'\u2012': u'-',
u'\u2013': u'-',
u'\u2014': u'-',
u'\u2015': u'-',
u'\u2016': u'-',
u'\u2026': u'...',
}
# Utilities.
def fetch_url(url):
"""Retrieve the content at a given URL, or return None if the source
is unreachable.
"""
try:
r = requests.get(url, verify=False)
except requests.RequestException as exc:
log.debug(u'lyrics request failed: {0}'.format(exc))
return
if r.status_code == requests.codes.ok:
return r.text
else:
log.debug(u'failed to fetch: {0} ({1})'.format(url, r.status_code))
def unescape(text):
"""Resolves &#xxx; HTML entities (and some others)."""
if isinstance(text, str):
text = text.decode('utf8', 'ignore')
out = text.replace(u'&nbsp;', u' ')
def replchar(m):
num = m.group(1)
return unichr(int(num))
out = re.sub(u"&#(\d+);", replchar, out)
return out
def extract_text_between(html, start_marker, end_marker):
try:
_, html = html.split(start_marker, 1)
html, _ = html.split(end_marker, 1)
except ValueError:
return u''
return html
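# e.g. extract_text_between(u'<b>la la</b>', u'<b>', u'</b>') == u'la la',
# and u'' is returned when either marker is missing.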
def extract_text_in(html, starttag):
"""Extract the text from a <DIV> tag in the HTML starting with
``starttag``. Returns None if parsing fails.
"""
# Strip off the leading text before opening tag.
try:
_, html = html.split(starttag, 1)
except ValueError:
return
# Walk through balanced DIV tags.
level = 0
parts = []
pos = 0
for match in DIV_RE.finditer(html):
if match.group(1): # Closing tag.
level -= 1
if level == 0:
pos = match.end()
else: # Opening tag.
if level == 0:
parts.append(html[pos:match.start()])
level += 1
if level == -1:
parts.append(html[pos:match.start()])
break
else:
print('no closing tag found!')
return
return u''.join(parts)
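# Balanced-tag example:
#   extract_text_in(u"<div class='x'>la<div>y</div>la</div>!",
#                   u"<div class='x'>")
# returns u'lala': text inside the nested <div> is skipped and anything
# after the matching closing tag is ignored.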
def search_pairs(item):
Yield pairs of artists and titles to search for.
The first item in the pair is the name of the artist, the second
item is a list of song names.
In addition to the artist and title obtained from the `item`, the
method tries to strip extra information like parenthesized suffixes
and featured artists from the strings and adds them as candidates.
The method also tries to split multiple titles separated by `/`.
"""
title, artist = item.title, item.artist
titles = [title]
artists = [artist]
# Remove any featuring artists from the artists name
pattern = r"(.*?) {0}".format(plugins.feat_tokens())
match = re.search(pattern, artist, re.IGNORECASE)
if match:
artists.append(match.group(1))
# Remove a parenthesized suffix from a title string. Common
# examples include (live), (remix), and (acoustic).
pattern = r"(.+?)\s+[(].*[)]$"
match = re.search(pattern, title, re.IGNORECASE)
if match:
titles.append(match.group(1))
# Remove any featuring artists from the title
pattern = r"(.*?) {0}".format(plugins.feat_tokens(for_artist=False))
for title in titles[:]:
match = re.search(pattern, title, re.IGNORECASE)
if match:
titles.append(match.group(1))
# Check for a dual song (e.g. Pink Floyd - Speak to Me / Breathe)
# and add each part as a separate candidate.
multi_titles = []
for title in titles:
multi_titles.append([title])
if '/' in title:
multi_titles.append([x.strip() for x in title.split('/')])
return itertools.product(artists, multi_titles)
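# Rough illustration (hypothetical item attributes): with
#   item.artist = u'Pink Floyd feat. Guest'
#   item.title  = u'Speak to Me / Breathe (live)'
# the generated pairs include (u'Pink Floyd', [u'Speak to Me', u'Breathe'])
# alongside the unmodified (artist, [title]) pair.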
def _encode(s):
"""Encode the string for inclusion in a URL (common to both
LyricsWiki and Lyrics.com).
"""
if isinstance(s, unicode):
for char, repl in URL_CHARACTERS.items():
s = s.replace(char, repl)
s = s.encode('utf8', 'ignore')
return urllib.quote(s)
# Musixmatch
MUSIXMATCH_URL_PATTERN = 'https://www.musixmatch.com/lyrics/%s/%s'
def fetch_musixmatch(artist, title):
url = MUSIXMATCH_URL_PATTERN % (_lw_encode(artist.title()),
_lw_encode(title.title()))
html = fetch_url(url)
if not html:
return
lyrics = extract_text_between(html, '"lyrics_body":', '"lyrics_language":')
return lyrics.strip(',"').replace('\\n', '\n')
# LyricsWiki.
LYRICSWIKI_URL_PATTERN = 'http://lyrics.wikia.com/%s:%s'
def _lw_encode(s):
s = re.sub(r'\s+', '_', s)
s = s.replace("<", "Less_Than")
s = s.replace(">", "Greater_Than")
s = s.replace("#", "Number_")
s = re.sub(r'[\[\{]', '(', s)
s = re.sub(r'[\]\}]', ')', s)
return _encode(s)
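# e.g. _lw_encode(u'Back in Black') == 'Back_in_Black' and
#      _lw_encode(u'#9 Dream') == 'Number_9_Dream'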
def fetch_lyricswiki(artist, title):
"""Fetch lyrics from LyricsWiki."""
url = LYRICSWIKI_URL_PATTERN % (_lw_encode(artist), _lw_encode(title))
html = fetch_url(url)
if not html:
return
lyrics = extract_text_in(html, u"<div class='lyricbox'>")
if lyrics and 'Unfortunately, we are not licensed' not in lyrics:
return lyrics
# Lyrics.com.
LYRICSCOM_URL_PATTERN = 'http://www.lyrics.com/%s-lyrics-%s.html'
LYRICSCOM_NOT_FOUND = (
'Sorry, we do not have the lyric',
'Submit Lyrics',
)
def _lc_encode(s):
s = re.sub(r'[^\w\s-]', '', s)
s = re.sub(r'\s+', '-', s)
return _encode(s).lower()
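# e.g. _lc_encode(u"Don't Stop") == 'dont-stop'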
def fetch_lyricscom(artist, title):
"""Fetch lyrics from Lyrics.com."""
url = LYRICSCOM_URL_PATTERN % (_lc_encode(title), _lc_encode(artist))
html = fetch_url(url)
if not html:
return
lyrics = extract_text_between(html, '<div id="lyrics" class="SCREENONLY" '
'itemprop="description">', '</div>')
if not lyrics:
return
for not_found_str in LYRICSCOM_NOT_FOUND:
if not_found_str in lyrics:
return
parts = lyrics.split('\n---\nLyrics powered by', 1)
if parts:
return parts[0]
# Optional Google custom search API backend.
def slugify(text):
"""Normalize a string and remove non-alphanumeric characters.
"""
text = re.sub(r"[-'_\s]", '_', text)
text = re.sub(r"_+", '_', text).strip('_')
pat = "([^,\(]*)\((.*?)\)" # Remove content within parentheses
text = re.sub(pat, '\g<1>', text).strip()
try:
text = unicodedata.normalize('NFKD', text).encode('ascii', 'ignore')
text = unicode(re.sub('[-\s]+', ' ', text))
except UnicodeDecodeError:
log.exception(u"Failing to normalize '{0}'".format(text))
return text
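# e.g. slugify(u'\xc9chame La Culpa') == u'Echame_La_Culpa' (the accent is
# folded away and whitespace becomes underscores).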
BY_TRANS = ['by', 'par', 'de', 'von']
LYRICS_TRANS = ['lyrics', 'paroles', 'letras', 'liedtexte']
def is_page_candidate(urlLink, urlTitle, title, artist):
"""Return True if the URL title makes it a good candidate to be a
page that contains lyrics of title by artist.
"""
title = slugify(title.lower())
artist = slugify(artist.lower())
sitename = re.search(u"//([^/]+)/.*", slugify(urlLink.lower())).group(1)
urlTitle = slugify(urlTitle.lower())
# Check if URL title contains song title (exact match)
if urlTitle.find(title) != -1:
return True
# or try extracting song title from URL title and check if
# they are close enough
tokens = [by + '_' + artist for by in BY_TRANS] + \
[artist, sitename, sitename.replace('www.', '')] + LYRICS_TRANS
songTitle = re.sub(u'(%s)' % u'|'.join(tokens), u'', urlTitle)
songTitle = songTitle.strip('_|')
typoRatio = .9
return difflib.SequenceMatcher(None, songTitle, title).ratio() >= typoRatio
def remove_credits(text):
"""Remove first/last line of text if it contains the word 'lyrics'
eg 'Lyrics by songsdatabase.com'
"""
textlines = text.split('\n')
credits = None
for i in (0, -1):
if textlines and 'lyrics' in textlines[i].lower():
credits = textlines.pop(i)
if credits:
text = '\n'.join(textlines)
return text
def is_lyrics(text, artist=None):
"""Determine whether the text seems to be valid lyrics.
"""
if not text:
return False
badTriggersOcc = []
nbLines = text.count('\n')
if nbLines <= 1:
log.debug(u"Ignoring too short lyrics '{0}'".format(text))
return False
elif nbLines < 5:
badTriggersOcc.append('too_short')
else:
# Lyrics look legit, remove credits to avoid being penalized further
# down
text = remove_credits(text)
badTriggers = ['lyrics', 'copyright', 'property', 'links']
if artist:
badTriggersOcc += [artist]
for item in badTriggers:
badTriggersOcc += [item] * len(re.findall(r'\W%s\W' % item,
text, re.I))
if badTriggersOcc:
log.debug(u'Bad triggers detected: {0}'.format(badTriggersOcc))
return len(badTriggersOcc) < 2
def _scrape_strip_cruft(html, plain_text_out=False):
"""Clean up HTML
"""
html = unescape(html)
html = html.replace('\r', '\n') # Normalize EOL.
html = re.sub(r' +', ' ', html) # Whitespaces collapse.
html = BREAK_RE.sub('\n', html) # <br> eats up surrounding '\n'.
html = re.sub(r'<(script).*?</\1>(?s)', '', html) # Strip script tags.
if plain_text_out: # Strip remaining HTML tags
html = COMMENT_RE.sub('', html)
html = TAG_RE.sub('', html)
html = '\n'.join([x.strip() for x in html.strip().split('\n')])
html = re.sub(r'\n{3,}', r'\n\n', html)
return html
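# e.g. _scrape_strip_cruft(u'Hello<br/>world  &#33;', plain_text_out=True)
# resolves the entity, turns the <br/> into a newline and strips the
# remaining markup, giving u'Hello\nworld !'.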
def _scrape_merge_paragraphs(html):
html = re.sub(r'</p>\s*<p(\s*[^>]*)>', '\n', html)
return re.sub(r'<div .*>\s*</div>', '\n', html)
def scrape_lyrics_from_html(html):
"""Scrape lyrics from a URL. If no lyrics can be found, return None
instead.
"""
from bs4 import SoupStrainer, BeautifulSoup
if not html:
return None
def is_text_notcode(text):
length = len(text)
return (length > 20 and
text.count(' ') > length / 25 and
(text.find('{') == -1 or text.find(';') == -1))
html = _scrape_strip_cruft(html)
html = _scrape_merge_paragraphs(html)
# extract all long text blocks that are not code
try:
soup = BeautifulSoup(html, "html.parser",
parse_only=SoupStrainer(text=is_text_notcode))
except HTMLParseError:
return None
soup = sorted(soup.stripped_strings, key=len)[-1]
return soup
def fetch_google(artist, title):
"""Fetch lyrics from Google search results.
"""
query = u"%s %s" % (artist, title)
api_key = config['lyrics']['google_API_key'].get(unicode)
engine_id = config['lyrics']['google_engine_ID'].get(unicode)
url = u'https://www.googleapis.com/customsearch/v1?key=%s&cx=%s&q=%s' % \
(api_key, engine_id, urllib.quote(query.encode('utf8')))
data = urllib.urlopen(url)
data = json.load(data)
if 'error' in data:
reason = data['error']['errors'][0]['reason']
log.debug(u'google lyrics backend error: {0}'.format(reason))
return
if 'items' in data.keys():
for item in data['items']:
urlLink = item['link']
urlTitle = item.get('title', u'')
if not is_page_candidate(urlLink, urlTitle, title, artist):
continue
html = fetch_url(urlLink)
lyrics = scrape_lyrics_from_html(html)
if not lyrics:
continue
if is_lyrics(lyrics, artist):
log.debug(u'got lyrics from {0}'.format(item['displayLink']))
return lyrics
# Plugin logic.
SOURCES = ['google', 'lyricwiki', 'lyrics.com', 'musixmatch']
SOURCE_BACKENDS = {
'google': fetch_google,
'lyricwiki': fetch_lyricswiki,
'lyrics.com': fetch_lyricscom,
'musixmatch': fetch_musixmatch,
}
class LyricsPlugin(plugins.BeetsPlugin):
def __init__(self):
super(LyricsPlugin, self).__init__()
self.import_stages = [self.imported]
self.config.add({
'auto': True,
'google_API_key': None,
'google_engine_ID': u'009217259823014548361:lndtuqkycfu',
'fallback': None,
'force': False,
'sources': SOURCES,
})
available_sources = list(SOURCES)
if not self.config['google_API_key'].get() and \
'google' in SOURCES:
available_sources.remove('google')
self.config['sources'] = plugins.sanitize_choices(
self.config['sources'].as_str_seq(), available_sources)
self.backends = []
for key in self.config['sources'].as_str_seq():
self.backends.append(SOURCE_BACKENDS[key])
def commands(self):
cmd = ui.Subcommand('lyrics', help='fetch song lyrics')
cmd.parser.add_option('-p', '--print', dest='printlyr',
action='store_true', default=False,
help='print lyrics to console')
cmd.parser.add_option('-f', '--force', dest='force_refetch',
action='store_true', default=False,
help='always re-download lyrics')
def func(lib, opts, args):
# The "write to files" option corresponds to the
# import_write config value.
write = config['import']['write'].get(bool)
for item in lib.items(ui.decargs(args)):
self.fetch_item_lyrics(
lib, logging.INFO, item, write,
opts.force_refetch or self.config['force'],
)
if opts.printlyr and item.lyrics:
ui.print_(item.lyrics)
cmd.func = func
return [cmd]
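# Typical invocations of the subcommand defined above (usage sketch):
#   beet lyrics -p           # also print the fetched lyrics to the console
#   beet lyrics -f           # always re-download lyrics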
def imported(self, session, task):
"""Import hook for fetching lyrics automatically.
"""
if self.config['auto']:
for item in task.imported_items():
self.fetch_item_lyrics(session.lib, logging.DEBUG, item,
False, self.config['force'])
def fetch_item_lyrics(self, lib, loglevel, item, write, force):
"""Fetch and store lyrics for a single item. If ``write``, then the
lyrics will also be written to the file itself. The ``loglevel``
parameter controls the visibility of the function's status log
messages.
"""
# Skip if the item already has lyrics.
if not force and item.lyrics:
log.log(loglevel, u'lyrics already present: {0} - {1}'
.format(item.artist, item.title))
return
lyrics = None
for artist, titles in search_pairs(item):
lyrics = [self.get_lyrics(artist, title) for title in titles]
if any(lyrics):
break
lyrics = u"\n\n---\n\n".join([l for l in lyrics if l])
if lyrics:
log.log(loglevel, u'fetched lyrics: {0} - {1}'
.format(item.artist, item.title))
else:
log.log(loglevel, u'lyrics not found: {0} - {1}'
.format(item.artist, item.title))
fallback = self.config['fallback'].get()
if fallback:
lyrics = fallback
else:
return
item.lyrics = lyrics
if write:
item.try_write()
item.store()
def get_lyrics(self, artist, title):
"""Fetch lyrics, trying each source in turn. Return a string or
None if no lyrics were found.
"""
for backend in self.backends:
lyrics = backend(artist, title)
if lyrics:
log.debug(u'got lyrics from backend: {0}'
.format(backend.__name__))
return _scrape_strip_cruft(lyrics, True)

View File

@@ -14,9 +14,6 @@
# Minor modifications made by Andrew Resch to replace the BTFailure errors with Exceptions
from types import StringType, IntType, LongType, DictType, ListType, TupleType
def decode_int(x, f):
f += 1
newf = x.index('e', f)
@@ -24,36 +21,32 @@ def decode_int(x, f):
if x[f] == '-':
if x[f + 1] == '0':
raise ValueError
elif x[f] == '0' and newf != f + 1:
elif x[f] == '0' and newf != f+1:
raise ValueError
return (n, newf + 1)
return (n, newf+1)
def decode_string(x, f):
colon = x.index(':', f)
n = int(x[f:colon])
if x[f] == '0' and colon != f + 1:
if x[f] == '0' and colon != f+1:
raise ValueError
colon += 1
return (x[colon:colon + n], colon + n)
return (x[colon:colon+n], colon+n)
def decode_list(x, f):
r, f = [], f + 1
r, f = [], f+1
while x[f] != 'e':
v, f = decode_func[x[f]](x, f)
r.append(v)
return (r, f + 1)
def decode_dict(x, f):
r, f = {}, f + 1
r, f = {}, f+1
while x[f] != 'e':
k, f = decode_string(x, f)
r[k], f = decode_func[x[f]](x, f)
return (r, f + 1)
decode_func = {}
decode_func['l'] = decode_list
decode_func['d'] = decode_dict
@@ -69,7 +62,6 @@ decode_func['7'] = decode_string
decode_func['8'] = decode_string
decode_func['9'] = decode_string
def bdecode(x):
try:
r, l = decode_func[x[0]](x, 0)
@@ -78,41 +70,38 @@ def bdecode(x):
return r
from types import StringType, IntType, LongType, DictType, ListType, TupleType
class Bencached(object):
__slots__ = ['bencoded']
def __init__(self, s):
self.bencoded = s
def encode_bencached(x, r):
def encode_bencached(x,r):
r.append(x.bencoded)
def encode_int(x, r):
r.extend(('i', str(x), 'e'))
def encode_bool(x, r):
if x:
encode_int(1, r)
else:
encode_int(0, r)
def encode_string(x, r):
r.extend((str(len(x)), ':', x))
def encode_list(x, r):
r.append('l')
for i in x:
encode_func[type(i)](i, r)
r.append('e')
def encode_dict(x, r):
def encode_dict(x,r):
r.append('d')
ilist = x.items()
ilist.sort()
@@ -121,7 +110,6 @@ def encode_dict(x, r):
encode_func[type(v)](v, r)
r.append('e')
encode_func = {}
encode_func[Bencached] = encode_bencached
encode_func[IntType] = encode_int
@@ -133,12 +121,10 @@ encode_func[DictType] = encode_dict
try:
from types import BooleanType
encode_func[BooleanType] = encode_bool
except ImportError:
pass
def bencode(x):
r = []
encode_func[type(x)](x, r)

803
lib/biplist/__init__.py Executable file
View File

@@ -0,0 +1,803 @@
"""biplist -- a library for reading and writing binary property list files.
Binary Property List (plist) files provide a faster and smaller serialization
format for property lists on OS X. This is a library for generating binary
plists which can be read by OS X, iOS, or other clients.
The API models the plistlib API, and will call through to plistlib when
XML serialization or deserialization is required.
To generate plists with UID values, wrap the values with the Uid object. The
value must be an int.
To generate plists with NSData/CFData values, wrap the values with the
Data object. The value must be a string.
Date values can only be datetime.datetime objects.
The exceptions InvalidPlistException and NotBinaryPlistException may be
thrown to indicate that the data cannot be serialized or deserialized as
a binary plist.
Plist generation example:
from biplist import *
from datetime import datetime
plist = {'aKey':'aValue',
'0':1.322,
'now':datetime.now(),
'list':[1,2,3],
'tuple':('a','b','c')
}
try:
writePlist(plist, "example.plist")
except (InvalidPlistException, NotBinaryPlistException), e:
print "Something bad happened:", e
Plist parsing example:
from biplist import *
try:
plist = readPlist("example.plist")
print plist
except (InvalidPlistException, NotBinaryPlistException), e:
print "Not a plist:", e
"""
import sys
from collections import namedtuple
import datetime
import io
import math
import plistlib
from struct import pack, unpack
from struct import error as struct_error
import sys
import time
try:
unicode
unicodeEmpty = r''
except NameError:
unicode = str
unicodeEmpty = ''
try:
long
except NameError:
long = int
try:
{}.iteritems
iteritems = lambda x: x.iteritems()
except AttributeError:
iteritems = lambda x: x.items()
__all__ = [
'Uid', 'Data', 'readPlist', 'writePlist', 'readPlistFromString',
'writePlistToString', 'InvalidPlistException', 'NotBinaryPlistException'
]
# Apple uses Jan 1, 2001 as a base for all plist date/times.
apple_reference_date = datetime.datetime.utcfromtimestamp(978307200)
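# e.g. a stored timestamp of 0.0 decodes to datetime.datetime(2001, 1, 1).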
class Uid(int):
"""Wrapper around integers for representing UID values. This
is used in keyed archiving."""
def __repr__(self):
return "Uid(%d)" % self
class Data(bytes):
"""Wrapper around str types for representing Data values."""
pass
class InvalidPlistException(Exception):
"""Raised when the plist is incorrectly formatted."""
pass
class NotBinaryPlistException(Exception):
"""Raised when a binary plist was expected but not encountered."""
pass
def readPlist(pathOrFile):
"""Raises NotBinaryPlistException, InvalidPlistException"""
didOpen = False
result = None
if isinstance(pathOrFile, (bytes, unicode)):
pathOrFile = open(pathOrFile, 'rb')
didOpen = True
try:
reader = PlistReader(pathOrFile)
result = reader.parse()
except NotBinaryPlistException as e:
try:
pathOrFile.seek(0)
result = None
if hasattr(plistlib, 'loads'):
contents = None
if isinstance(pathOrFile, (bytes, unicode)):
with open(pathOrFile, 'rb') as f:
contents = f.read()
else:
contents = pathOrFile.read()
result = plistlib.loads(contents)
else:
result = plistlib.readPlist(pathOrFile)
result = wrapDataObject(result, for_binary=True)
except Exception as e:
raise InvalidPlistException(e)
finally:
if didOpen:
pathOrFile.close()
return result
def wrapDataObject(o, for_binary=False):
if isinstance(o, Data) and not for_binary:
v = sys.version_info
if not (v[0] >= 3 and v[1] >= 4):
o = plistlib.Data(o)
elif isinstance(o, (bytes, plistlib.Data)) and for_binary:
if hasattr(o, 'data'):
o = Data(o.data)
elif isinstance(o, tuple):
o = wrapDataObject(list(o), for_binary)
o = tuple(o)
elif isinstance(o, list):
for i in range(len(o)):
o[i] = wrapDataObject(o[i], for_binary)
elif isinstance(o, dict):
for k in o:
o[k] = wrapDataObject(o[k], for_binary)
return o
def writePlist(rootObject, pathOrFile, binary=True):
if not binary:
rootObject = wrapDataObject(rootObject, binary)
if hasattr(plistlib, "dump"):
if isinstance(pathOrFile, (bytes, unicode)):
with open(pathOrFile, 'wb') as f:
return plistlib.dump(rootObject, f)
else:
return plistlib.dump(rootObject, pathOrFile)
else:
return plistlib.writePlist(rootObject, pathOrFile)
else:
didOpen = False
if isinstance(pathOrFile, (bytes, unicode)):
pathOrFile = open(pathOrFile, 'wb')
didOpen = True
writer = PlistWriter(pathOrFile)
result = writer.writeRoot(rootObject)
if didOpen:
pathOrFile.close()
return result
def readPlistFromString(data):
return readPlist(io.BytesIO(data))
def writePlistToString(rootObject, binary=True):
if not binary:
rootObject = wrapDataObject(rootObject, binary)
if hasattr(plistlib, "dumps"):
return plistlib.dumps(rootObject)
elif hasattr(plistlib, "writePlistToBytes"):
return plistlib.writePlistToBytes(rootObject)
else:
return plistlib.writePlistToString(rootObject)
else:
ioObject = io.BytesIO()
writer = PlistWriter(ioObject)
writer.writeRoot(rootObject)
return ioObject.getvalue()
def is_stream_binary_plist(stream):
stream.seek(0)
header = stream.read(7)
if header == b'bplist0':
return True
else:
return False
PlistTrailer = namedtuple('PlistTrailer', 'offsetSize, objectRefSize, offsetCount, topLevelObjectNumber, offsetTableOffset')
PlistByteCounts = namedtuple('PlistByteCounts', 'nullBytes, boolBytes, intBytes, realBytes, dateBytes, dataBytes, stringBytes, uidBytes, arrayBytes, setBytes, dictBytes')
class PlistReader(object):
file = None
contents = ''
offsets = None
trailer = None
currentOffset = 0
def __init__(self, fileOrStream):
"""Raises NotBinaryPlistException."""
self.reset()
self.file = fileOrStream
def parse(self):
return self.readRoot()
def reset(self):
self.trailer = None
self.contents = ''
self.offsets = []
self.currentOffset = 0
def readRoot(self):
result = None
self.reset()
# Get the header, make sure it's a valid file.
if not is_stream_binary_plist(self.file):
raise NotBinaryPlistException()
self.file.seek(0)
self.contents = self.file.read()
if len(self.contents) < 32:
raise InvalidPlistException("File is too short.")
trailerContents = self.contents[-32:]
try:
self.trailer = PlistTrailer._make(unpack("!xxxxxxBBQQQ", trailerContents))
offset_size = self.trailer.offsetSize * self.trailer.offsetCount
offset = self.trailer.offsetTableOffset
offset_contents = self.contents[offset:offset+offset_size]
offset_i = 0
while offset_i < self.trailer.offsetCount:
begin = self.trailer.offsetSize*offset_i
tmp_contents = offset_contents[begin:begin+self.trailer.offsetSize]
tmp_sized = self.getSizedInteger(tmp_contents, self.trailer.offsetSize)
self.offsets.append(tmp_sized)
offset_i += 1
self.setCurrentOffsetToObjectNumber(self.trailer.topLevelObjectNumber)
result = self.readObject()
except TypeError as e:
raise InvalidPlistException(e)
return result
def setCurrentOffsetToObjectNumber(self, objectNumber):
self.currentOffset = self.offsets[objectNumber]
def readObject(self):
result = None
tmp_byte = self.contents[self.currentOffset:self.currentOffset+1]
marker_byte = unpack("!B", tmp_byte)[0]
format = (marker_byte >> 4) & 0x0f
extra = marker_byte & 0x0f
self.currentOffset += 1
def proc_extra(extra):
if extra == 0b1111:
#self.currentOffset += 1
extra = self.readObject()
return extra
# bool, null, or fill byte
if format == 0b0000:
if extra == 0b0000:
result = None
elif extra == 0b1000:
result = False
elif extra == 0b1001:
result = True
elif extra == 0b1111:
pass # fill byte
else:
raise InvalidPlistException("Invalid object found at offset: %d" % (self.currentOffset - 1))
# int
elif format == 0b0001:
extra = proc_extra(extra)
result = self.readInteger(pow(2, extra))
# real
elif format == 0b0010:
extra = proc_extra(extra)
result = self.readReal(extra)
# date
elif format == 0b0011 and extra == 0b0011:
result = self.readDate()
# data
elif format == 0b0100:
extra = proc_extra(extra)
result = self.readData(extra)
# ascii string
elif format == 0b0101:
extra = proc_extra(extra)
result = self.readAsciiString(extra)
# Unicode string
elif format == 0b0110:
extra = proc_extra(extra)
result = self.readUnicode(extra)
# uid
elif format == 0b1000:
result = self.readUid(extra)
# array
elif format == 0b1010:
extra = proc_extra(extra)
result = self.readArray(extra)
# set
elif format == 0b1100:
extra = proc_extra(extra)
result = set(self.readArray(extra))
# dict
elif format == 0b1101:
extra = proc_extra(extra)
result = self.readDict(extra)
else:
raise InvalidPlistException("Invalid object found: {format: %s, extra: %s}" % (bin(format), bin(extra)))
return result
def readInteger(self, byteSize):
result = 0
original_offset = self.currentOffset
data = self.contents[self.currentOffset:self.currentOffset + byteSize]
result = self.getSizedInteger(data, byteSize, as_number=True)
self.currentOffset = original_offset + byteSize
return result
def readReal(self, length):
result = 0.0
to_read = pow(2, length)
data = self.contents[self.currentOffset:self.currentOffset+to_read]
if length == 2: # 4 bytes
result = unpack('>f', data)[0]
elif length == 3: # 8 bytes
result = unpack('>d', data)[0]
else:
raise InvalidPlistException("Unknown real of length %d bytes" % to_read)
return result
def readRefs(self, count):
refs = []
i = 0
while i < count:
fragment = self.contents[self.currentOffset:self.currentOffset+self.trailer.objectRefSize]
ref = self.getSizedInteger(fragment, len(fragment))
refs.append(ref)
self.currentOffset += self.trailer.objectRefSize
i += 1
return refs
def readArray(self, count):
result = []
values = self.readRefs(count)
i = 0
while i < len(values):
self.setCurrentOffsetToObjectNumber(values[i])
value = self.readObject()
result.append(value)
i += 1
return result
def readDict(self, count):
result = {}
keys = self.readRefs(count)
values = self.readRefs(count)
i = 0
while i < len(keys):
self.setCurrentOffsetToObjectNumber(keys[i])
key = self.readObject()
self.setCurrentOffsetToObjectNumber(values[i])
value = self.readObject()
result[key] = value
i += 1
return result
def readAsciiString(self, length):
result = unpack("!%ds" % length, self.contents[self.currentOffset:self.currentOffset+length])[0]
self.currentOffset += length
return result
def readUnicode(self, length):
actual_length = length*2
data = self.contents[self.currentOffset:self.currentOffset+actual_length]
# unpack not needed?!! data = unpack(">%ds" % (actual_length), data)[0]
self.currentOffset += actual_length
return data.decode('utf_16_be')
def readDate(self):
result = unpack(">d", self.contents[self.currentOffset:self.currentOffset+8])[0]
# Use timedelta to workaround time_t size limitation on 32-bit python.
result = datetime.timedelta(seconds=result) + apple_reference_date
self.currentOffset += 8
return result
def readData(self, length):
result = self.contents[self.currentOffset:self.currentOffset+length]
self.currentOffset += length
return Data(result)
def readUid(self, length):
return Uid(self.readInteger(length+1))
def getSizedInteger(self, data, byteSize, as_number=False):
"""Numbers of 8 bytes are signed integers when they refer to numbers, but unsigned otherwise."""
result = 0
# 1, 2, and 4 byte integers are unsigned
if byteSize == 1:
result = unpack('>B', data)[0]
elif byteSize == 2:
result = unpack('>H', data)[0]
elif byteSize == 4:
result = unpack('>L', data)[0]
elif byteSize == 8:
if as_number:
result = unpack('>q', data)[0]
else:
result = unpack('>Q', data)[0]
elif byteSize <= 16:
# Handle odd-sized or integers larger than 8 bytes
# Don't naively go over 16 bytes, in order to prevent infinite loops.
result = 0
if hasattr(int, 'from_bytes'):
result = int.from_bytes(data, 'big')
else:
for byte in data:
result = (result << 8) | unpack('>B', byte)[0]
else:
raise InvalidPlistException("Encountered integer longer than 16 bytes.")
return result
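# For instance, getSizedInteger(b'\x01\x00', 2) == 256; an 8-byte value is
# unpacked as signed only when as_number=True (a plist integer object) and
# as unsigned otherwise (offsets and object counts).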
class HashableWrapper(object):
def __init__(self, value):
self.value = value
def __repr__(self):
return "<HashableWrapper: %s>" % [self.value]
class BoolWrapper(object):
def __init__(self, value):
self.value = value
def __repr__(self):
return "<BoolWrapper: %s>" % self.value
class FloatWrapper(object):
_instances = {}
def __new__(klass, value):
# Ensure FloatWrapper(x) for a given float x is always the same object
wrapper = klass._instances.get(value)
if wrapper is None:
wrapper = object.__new__(klass)
wrapper.value = value
klass._instances[value] = wrapper
return wrapper
def __repr__(self):
return "<FloatWrapper: %s>" % self.value
class PlistWriter(object):
header = b'bplist00bybiplist1.0'
file = None
byteCounts = None
trailer = None
computedUniques = None
writtenReferences = None
referencePositions = None
wrappedTrue = None
wrappedFalse = None
def __init__(self, file):
self.reset()
self.file = file
self.wrappedTrue = BoolWrapper(True)
self.wrappedFalse = BoolWrapper(False)
def reset(self):
self.byteCounts = PlistByteCounts(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)
self.trailer = PlistTrailer(0, 0, 0, 0, 0)
# A set of all the uniques which have been computed.
self.computedUniques = set()
# A list of all the uniques which have been written.
self.writtenReferences = {}
# A dict of the positions of the written uniques.
self.referencePositions = {}
def positionOfObjectReference(self, obj):
"""If the given object has been written already, return its
position in the offset table. Otherwise, return None."""
return self.writtenReferences.get(obj)
def writeRoot(self, root):
"""
Strategy is:
- write header
- wrap root object so everything is hashable
- compute size of objects which will be written
- need to do this in order to know how large the object refs
will be in the list/dict/set reference lists
- write objects
- keep objects in writtenReferences
- keep positions of object references in referencePositions
- write object references with the length computed previously
- compute object reference length
- write object reference positions
- write trailer
"""
output = self.header
wrapped_root = self.wrapRoot(root)
should_reference_root = True#not isinstance(wrapped_root, HashableWrapper)
self.computeOffsets(wrapped_root, asReference=should_reference_root, isRoot=True)
self.trailer = self.trailer._replace(**{'objectRefSize':self.intSize(len(self.computedUniques))})
(_, output) = self.writeObjectReference(wrapped_root, output)
output = self.writeObject(wrapped_root, output, setReferencePosition=True)
# output size at this point is an upper bound on how big the
# object reference offsets need to be.
self.trailer = self.trailer._replace(**{
'offsetSize':self.intSize(len(output)),
'offsetCount':len(self.computedUniques),
'offsetTableOffset':len(output),
'topLevelObjectNumber':0
})
output = self.writeOffsetTable(output)
output += pack('!xxxxxxBBQQQ', *self.trailer)
self.file.write(output)
def wrapRoot(self, root):
if isinstance(root, bool):
if root is True:
return self.wrappedTrue
else:
return self.wrappedFalse
elif isinstance(root, float):
return FloatWrapper(root)
elif isinstance(root, set):
n = set()
for value in root:
n.add(self.wrapRoot(value))
return HashableWrapper(n)
elif isinstance(root, dict):
n = {}
for key, value in iteritems(root):
n[self.wrapRoot(key)] = self.wrapRoot(value)
return HashableWrapper(n)
elif isinstance(root, list):
n = []
for value in root:
n.append(self.wrapRoot(value))
return HashableWrapper(n)
elif isinstance(root, tuple):
n = tuple([self.wrapRoot(value) for value in root])
return HashableWrapper(n)
else:
return root
def incrementByteCount(self, field, incr=1):
self.byteCounts = self.byteCounts._replace(**{field:self.byteCounts.__getattribute__(field) + incr})
def computeOffsets(self, obj, asReference=False, isRoot=False):
def check_key(key):
if key is None:
raise InvalidPlistException('Dictionary keys cannot be null in plists.')
elif isinstance(key, Data):
raise InvalidPlistException('Data cannot be dictionary keys in plists.')
elif not isinstance(key, (bytes, unicode)):
raise InvalidPlistException('Keys must be strings.')
def proc_size(size):
if size > 0b1110:
size += self.intSize(size)
return size
# If this should be a reference, then we keep a record of it in the
# uniques table.
if asReference:
if obj in self.computedUniques:
return
else:
self.computedUniques.add(obj)
if obj is None:
self.incrementByteCount('nullBytes')
elif isinstance(obj, BoolWrapper):
self.incrementByteCount('boolBytes')
elif isinstance(obj, Uid):
size = self.intSize(obj)
self.incrementByteCount('uidBytes', incr=1+size)
elif isinstance(obj, (int, long)):
size = self.intSize(obj)
self.incrementByteCount('intBytes', incr=1+size)
elif isinstance(obj, FloatWrapper):
size = self.realSize(obj)
self.incrementByteCount('realBytes', incr=1+size)
elif isinstance(obj, datetime.datetime):
self.incrementByteCount('dateBytes', incr=2)
elif isinstance(obj, Data):
size = proc_size(len(obj))
self.incrementByteCount('dataBytes', incr=1+size)
elif isinstance(obj, (unicode, bytes)):
size = proc_size(len(obj))
self.incrementByteCount('stringBytes', incr=1+size)
elif isinstance(obj, HashableWrapper):
obj = obj.value
if isinstance(obj, set):
size = proc_size(len(obj))
self.incrementByteCount('setBytes', incr=1+size)
for value in obj:
self.computeOffsets(value, asReference=True)
elif isinstance(obj, (list, tuple)):
size = proc_size(len(obj))
self.incrementByteCount('arrayBytes', incr=1+size)
for value in obj:
self.computeOffsets(value, asReference=True)
elif isinstance(obj, dict):
size = proc_size(len(obj))
self.incrementByteCount('dictBytes', incr=1+size)
for key, value in iteritems(obj):
check_key(key)
self.computeOffsets(key, asReference=True)
self.computeOffsets(value, asReference=True)
else:
raise InvalidPlistException("Unknown object type.")
def writeObjectReference(self, obj, output):
"""Tries to write an object reference, adding it to the references
table. Does not write the actual object bytes or set the reference
position. Returns a tuple of whether the object was a new reference
(True if it was, False if it already was in the reference table)
and the new output.
"""
position = self.positionOfObjectReference(obj)
if position is None:
self.writtenReferences[obj] = len(self.writtenReferences)
output += self.binaryInt(len(self.writtenReferences) - 1, byteSize=self.trailer.objectRefSize)
return (True, output)
else:
output += self.binaryInt(position, byteSize=self.trailer.objectRefSize)
return (False, output)
def writeObject(self, obj, output, setReferencePosition=False):
"""Serializes the given object to the output. Returns output.
If setReferencePosition is True, will set the position the
object was written.
"""
def proc_variable_length(format, length):
result = b''
if length > 0b1110:
result += pack('!B', (format << 4) | 0b1111)
result = self.writeObject(length, result)
else:
result += pack('!B', (format << 4) | length)
return result
if isinstance(obj, (str, unicode)) and obj == unicodeEmpty:
# The Apple Plist decoder can't decode a zero length Unicode string.
obj = b''
if setReferencePosition:
self.referencePositions[obj] = len(output)
if obj is None:
output += pack('!B', 0b00000000)
elif isinstance(obj, BoolWrapper):
if obj.value is False:
output += pack('!B', 0b00001000)
else:
output += pack('!B', 0b00001001)
elif isinstance(obj, Uid):
size = self.intSize(obj)
output += pack('!B', (0b1000 << 4) | size - 1)
output += self.binaryInt(obj)
elif isinstance(obj, (int, long)):
byteSize = self.intSize(obj)
root = math.log(byteSize, 2)
output += pack('!B', (0b0001 << 4) | int(root))
output += self.binaryInt(obj, as_number=True)
elif isinstance(obj, FloatWrapper):
# just use doubles
output += pack('!B', (0b0010 << 4) | 3)
output += self.binaryReal(obj)
elif isinstance(obj, datetime.datetime):
timestamp = (obj - apple_reference_date).total_seconds()
output += pack('!B', 0b00110011)
output += pack('!d', float(timestamp))
elif isinstance(obj, Data):
output += proc_variable_length(0b0100, len(obj))
output += obj
elif isinstance(obj, unicode):
byteData = obj.encode('utf_16_be')
output += proc_variable_length(0b0110, len(byteData)//2)
output += byteData
elif isinstance(obj, bytes):
output += proc_variable_length(0b0101, len(obj))
output += obj
elif isinstance(obj, HashableWrapper):
obj = obj.value
if isinstance(obj, (set, list, tuple)):
if isinstance(obj, set):
output += proc_variable_length(0b1100, len(obj))
else:
output += proc_variable_length(0b1010, len(obj))
objectsToWrite = []
for objRef in obj:
(isNew, output) = self.writeObjectReference(objRef, output)
if isNew:
objectsToWrite.append(objRef)
for objRef in objectsToWrite:
output = self.writeObject(objRef, output, setReferencePosition=True)
elif isinstance(obj, dict):
output += proc_variable_length(0b1101, len(obj))
keys = []
values = []
objectsToWrite = []
for key, value in iteritems(obj):
keys.append(key)
values.append(value)
for key in keys:
(isNew, output) = self.writeObjectReference(key, output)
if isNew:
objectsToWrite.append(key)
for value in values:
(isNew, output) = self.writeObjectReference(value, output)
if isNew:
objectsToWrite.append(value)
for objRef in objectsToWrite:
output = self.writeObject(objRef, output, setReferencePosition=True)
return output
def writeOffsetTable(self, output):
"""Writes all of the object reference offsets."""
all_positions = []
writtenReferences = list(self.writtenReferences.items())
writtenReferences.sort(key=lambda x: x[1])
for obj,order in writtenReferences:
# Porting note: Elsewhere we deliberately replace empty unicode strings
# with empty binary strings, but the empty unicode string
# goes into writtenReferences. This isn't an issue in Py2
# because u'' and b'' have the same hash; but it is in
# Py3, where they don't.
if bytes != str and obj == unicodeEmpty:
obj = b''
position = self.referencePositions.get(obj)
if position is None:
raise InvalidPlistException("Error while writing offsets table. Object not found. %s" % obj)
output += self.binaryInt(position, self.trailer.offsetSize)
all_positions.append(position)
return output
def binaryReal(self, obj):
# just use doubles
result = pack('>d', obj.value)
return result
def binaryInt(self, obj, byteSize=None, as_number=False):
result = b''
if byteSize is None:
byteSize = self.intSize(obj)
if byteSize == 1:
result += pack('>B', obj)
elif byteSize == 2:
result += pack('>H', obj)
elif byteSize == 4:
result += pack('>L', obj)
elif byteSize == 8:
if as_number:
result += pack('>q', obj)
else:
result += pack('>Q', obj)
elif byteSize <= 16:
try:
result = pack('>Q', 0) + pack('>Q', obj)
except struct_error as e:
raise InvalidPlistException("Unable to pack integer %d: %s" % (obj, e))
else:
raise InvalidPlistException("Core Foundation can't handle integers with size greater than 16 bytes.")
return result
def intSize(self, obj):
"""Returns the number of bytes necessary to store the given integer."""
# SIGNED
if obj < 0: # Signed integer, always 8 bytes
return 8
# UNSIGNED
elif obj <= 0xFF: # 1 byte
return 1
elif obj <= 0xFFFF: # 2 bytes
return 2
elif obj <= 0xFFFFFFFF: # 4 bytes
return 4
# SIGNED
# 0x7FFFFFFFFFFFFFFF is the max.
elif obj <= 0x7FFFFFFFFFFFFFFF: # 8 bytes signed
return 8
elif obj <= 0xffffffffffffffff: # fits in 8 unsigned bytes; stored in 16
return 16
else:
raise InvalidPlistException("Core Foundation can't handle integers with size greater than 8 bytes.")
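# e.g. intSize(255) == 1, intSize(256) == 2, intSize(-1) == 8, and anything
# above 0x7FFFFFFFFFFFFFFF needs the 16-byte form to stay unsigned.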
def realSize(self, obj):
return 8

468
lib/bs4/__init__.py Normal file
View File

@@ -0,0 +1,468 @@
"""Beautiful Soup
Elixir and Tonic
"The Screen-Scraper's Friend"
http://www.crummy.com/software/BeautifulSoup/
Beautiful Soup uses a pluggable XML or HTML parser to parse a
(possibly invalid) document into a tree representation. Beautiful Soup
provides methods and Pythonic idioms that make it easy to
navigate, search, and modify the parse tree.
Beautiful Soup works with Python 2.6 and up. It works better if lxml
and/or html5lib is installed.
For more than you ever wanted to know about Beautiful Soup, see the
documentation:
http://www.crummy.com/software/BeautifulSoup/bs4/doc/
"""
__author__ = "Leonard Richardson (leonardr@segfault.org)"
__version__ = "4.4.0"
__copyright__ = "Copyright (c) 2004-2015 Leonard Richardson"
__license__ = "MIT"
__all__ = ['BeautifulSoup']
import os
import re
import warnings
from .builder import builder_registry, ParserRejectedMarkup
from .dammit import UnicodeDammit
from .element import (
CData,
Comment,
DEFAULT_OUTPUT_ENCODING,
Declaration,
Doctype,
NavigableString,
PageElement,
ProcessingInstruction,
ResultSet,
SoupStrainer,
Tag,
)
# The very first thing we do is give a useful error if someone is
# running this code under Python 3 without converting it.
'You are trying to run the Python 2 version of Beautiful Soup under Python 3. This will not work.'<>'You need to convert the code, either by installing it (`python setup.py install`) or by running 2to3 (`2to3 -w bs4`).'
class BeautifulSoup(Tag):
"""
This class defines the basic interface called by the tree builders.
These methods will be called by the parser:
reset()
feed(markup)
The tree builder may call these methods from its feed() implementation:
handle_starttag(name, attrs) # See note about return value
handle_endtag(name)
handle_data(data) # Appends to the current data node
endData(containerClass=NavigableString) # Ends the current data node
No matter how complicated the underlying parser is, you should be
able to build a tree using 'start tag' events, 'end tag' events,
'data' events, and "done with data" events.
If you encounter an empty-element tag (aka a self-closing tag,
like HTML's <br> tag), call handle_starttag and then
handle_endtag.
"""
ROOT_TAG_NAME = u'[document]'
# If the end-user gives no indication which tree builder they
# want, look for one with these features.
DEFAULT_BUILDER_FEATURES = ['html', 'fast']
ASCII_SPACES = '\x20\x0a\x09\x0c\x0d'
NO_PARSER_SPECIFIED_WARNING = "No parser was explicitly specified, so I'm using the best available %(markup_type)s parser for this system (\"%(parser)s\"). This usually isn't a problem, but if you run this code on another system, or in a different virtual environment, it may use a different parser and behave differently.\n\nTo get rid of this warning, change this:\n\n BeautifulSoup([your markup])\n\nto this:\n\n BeautifulSoup([your markup], \"%(parser)s\")\n"
def __init__(self, markup="", features=None, builder=None,
parse_only=None, from_encoding=None, exclude_encodings=None,
**kwargs):
"""The Soup object is initialized as the 'root tag', and the
provided markup (which can be a string or a file-like object)
is fed into the underlying parser."""
if 'convertEntities' in kwargs:
warnings.warn(
"BS4 does not respect the convertEntities argument to the "
"BeautifulSoup constructor. Entities are always converted "
"to Unicode characters.")
if 'markupMassage' in kwargs:
del kwargs['markupMassage']
warnings.warn(
"BS4 does not respect the markupMassage argument to the "
"BeautifulSoup constructor. The tree builder is responsible "
"for any necessary markup massage.")
if 'smartQuotesTo' in kwargs:
del kwargs['smartQuotesTo']
warnings.warn(
"BS4 does not respect the smartQuotesTo argument to the "
"BeautifulSoup constructor. Smart quotes are always converted "
"to Unicode characters.")
if 'selfClosingTags' in kwargs:
del kwargs['selfClosingTags']
warnings.warn(
"BS4 does not respect the selfClosingTags argument to the "
"BeautifulSoup constructor. The tree builder is responsible "
"for understanding self-closing tags.")
if 'isHTML' in kwargs:
del kwargs['isHTML']
warnings.warn(
"BS4 does not respect the isHTML argument to the "
"BeautifulSoup constructor. Suggest you use "
"features='lxml' for HTML and features='lxml-xml' for "
"XML.")
def deprecated_argument(old_name, new_name):
if old_name in kwargs:
warnings.warn(
'The "%s" argument to the BeautifulSoup constructor '
'has been renamed to "%s."' % (old_name, new_name))
value = kwargs[old_name]
del kwargs[old_name]
return value
return None
parse_only = parse_only or deprecated_argument(
"parseOnlyThese", "parse_only")
from_encoding = from_encoding or deprecated_argument(
"fromEncoding", "from_encoding")
if len(kwargs) > 0:
arg = kwargs.keys().pop()
raise TypeError(
"__init__() got an unexpected keyword argument '%s'" % arg)
if builder is None:
original_features = features
if isinstance(features, basestring):
features = [features]
if features is None or len(features) == 0:
features = self.DEFAULT_BUILDER_FEATURES
builder_class = builder_registry.lookup(*features)
if builder_class is None:
raise FeatureNotFound(
"Couldn't find a tree builder with the features you "
"requested: %s. Do you need to install a parser library?"
% ",".join(features))
builder = builder_class()
if not (original_features == builder.NAME or
original_features in builder.ALTERNATE_NAMES):
if builder.is_xml:
markup_type = "XML"
else:
markup_type = "HTML"
warnings.warn(self.NO_PARSER_SPECIFIED_WARNING % dict(
parser=builder.NAME,
markup_type=markup_type))
self.builder = builder
self.is_xml = builder.is_xml
self.builder.soup = self
self.parse_only = parse_only
if hasattr(markup, 'read'): # It's a file-type object.
markup = markup.read()
elif len(markup) <= 256:
# Print out warnings for a couple beginner problems
# involving passing non-markup to Beautiful Soup.
# Beautiful Soup will still parse the input as markup,
# just in case that's what the user really wants.
if (isinstance(markup, unicode)
and not os.path.supports_unicode_filenames):
possible_filename = markup.encode("utf8")
else:
possible_filename = markup
is_file = False
try:
is_file = os.path.exists(possible_filename)
except Exception, e:
# This is almost certainly a problem involving
# characters not valid in filenames on this
# system. Just let it go.
pass
if is_file:
if isinstance(markup, unicode):
markup = markup.encode("utf8")
warnings.warn(
'"%s" looks like a filename, not markup. You should probably open this file and pass the filehandle into Beautiful Soup.' % markup)
if markup[:5] == "http:" or markup[:6] == "https:":
# TODO: This is ugly but I couldn't get it to work in
# Python 3 otherwise.
if ((isinstance(markup, bytes) and not b' ' in markup)
or (isinstance(markup, unicode) and not u' ' in markup)):
if isinstance(markup, unicode):
markup = markup.encode("utf8")
warnings.warn(
'"%s" looks like a URL. Beautiful Soup is not an HTTP client. You should probably use an HTTP client to get the document behind the URL, and feed that document to Beautiful Soup.' % markup)
for (self.markup, self.original_encoding, self.declared_html_encoding,
self.contains_replacement_characters) in (
self.builder.prepare_markup(
markup, from_encoding, exclude_encodings=exclude_encodings)):
self.reset()
try:
self._feed()
break
except ParserRejectedMarkup:
pass
# Clear out the markup and remove the builder's circular
# reference to this object.
self.markup = None
self.builder.soup = None
def __copy__(self):
return type(self)(self.encode(), builder=self.builder)
def __getstate__(self):
# Frequently a tree builder can't be pickled.
d = dict(self.__dict__)
if 'builder' in d and not self.builder.picklable:
del d['builder']
return d
def _feed(self):
# Convert the document to Unicode.
self.builder.reset()
self.builder.feed(self.markup)
# Close out any unfinished strings and close all the open tags.
self.endData()
while self.currentTag.name != self.ROOT_TAG_NAME:
self.popTag()
def reset(self):
Tag.__init__(self, self, self.builder, self.ROOT_TAG_NAME)
self.hidden = 1
self.builder.reset()
self.current_data = []
self.currentTag = None
self.tagStack = []
self.preserve_whitespace_tag_stack = []
self.pushTag(self)
def new_tag(self, name, namespace=None, nsprefix=None, **attrs):
"""Create a new tag associated with this soup."""
return Tag(None, self.builder, name, namespace, nsprefix, attrs)
def new_string(self, s, subclass=NavigableString):
"""Create a new NavigableString associated with this soup."""
return subclass(s)
def insert_before(self, successor):
raise NotImplementedError("BeautifulSoup objects don't support insert_before().")
def insert_after(self, successor):
raise NotImplementedError("BeautifulSoup objects don't support insert_after().")
def popTag(self):
tag = self.tagStack.pop()
if self.preserve_whitespace_tag_stack and tag == self.preserve_whitespace_tag_stack[-1]:
self.preserve_whitespace_tag_stack.pop()
#print "Pop", tag.name
if self.tagStack:
self.currentTag = self.tagStack[-1]
return self.currentTag
def pushTag(self, tag):
#print "Push", tag.name
if self.currentTag:
self.currentTag.contents.append(tag)
self.tagStack.append(tag)
self.currentTag = self.tagStack[-1]
if tag.name in self.builder.preserve_whitespace_tags:
self.preserve_whitespace_tag_stack.append(tag)
def endData(self, containerClass=NavigableString):
if self.current_data:
current_data = u''.join(self.current_data)
# If whitespace is not preserved, and this string contains
# nothing but ASCII spaces, replace it with a single space
# or newline.
if not self.preserve_whitespace_tag_stack:
strippable = True
for i in current_data:
if i not in self.ASCII_SPACES:
strippable = False
break
if strippable:
if '\n' in current_data:
current_data = '\n'
else:
current_data = ' '
# Reset the data collector.
self.current_data = []
# Should we add this string to the tree at all?
if self.parse_only and len(self.tagStack) <= 1 and \
(not self.parse_only.text or \
not self.parse_only.search(current_data)):
return
o = containerClass(current_data)
self.object_was_parsed(o)
def object_was_parsed(self, o, parent=None, most_recent_element=None):
"""Add an object to the parse tree."""
parent = parent or self.currentTag
previous_element = most_recent_element or self._most_recent_element
next_element = previous_sibling = next_sibling = None
if isinstance(o, Tag):
next_element = o.next_element
next_sibling = o.next_sibling
previous_sibling = o.previous_sibling
if not previous_element:
previous_element = o.previous_element
o.setup(parent, previous_element, next_element, previous_sibling, next_sibling)
self._most_recent_element = o
parent.contents.append(o)
if parent.next_sibling:
# This node is being inserted into an element that has
# already been parsed. Deal with any dangling references.
index = parent.contents.index(o)
if index == 0:
previous_element = parent
previous_sibling = None
else:
previous_element = previous_sibling = parent.contents[index-1]
if index == len(parent.contents)-1:
next_element = parent.next_sibling
next_sibling = None
else:
next_element = next_sibling = parent.contents[index+1]
o.previous_element = previous_element
if previous_element:
previous_element.next_element = o
o.next_element = next_element
if next_element:
next_element.previous_element = o
o.next_sibling = next_sibling
if next_sibling:
next_sibling.previous_sibling = o
o.previous_sibling = previous_sibling
if previous_sibling:
previous_sibling.next_sibling = o
def _popToTag(self, name, nsprefix=None, inclusivePop=True):
"""Pops the tag stack up to and including the most recent
instance of the given tag. If inclusivePop is false, pops the tag
stack up to but *not* including the most recent instance of
the given tag."""
#print "Popping to %s" % name
if name == self.ROOT_TAG_NAME:
# The BeautifulSoup object itself can never be popped.
return
most_recently_popped = None
stack_size = len(self.tagStack)
for i in range(stack_size - 1, 0, -1):
t = self.tagStack[i]
if (name == t.name and nsprefix == t.prefix):
if inclusivePop:
most_recently_popped = self.popTag()
break
most_recently_popped = self.popTag()
return most_recently_popped
def handle_starttag(self, name, namespace, nsprefix, attrs):
"""Push a start tag on to the stack.
If this method returns None, the tag was rejected by the
SoupStrainer. You should proceed as if the tag had not occurred
in the document. For instance, if this was a self-closing tag,
don't call handle_endtag.
"""
# print "Start tag %s: %s" % (name, attrs)
self.endData()
if (self.parse_only and len(self.tagStack) <= 1
and (self.parse_only.text
or not self.parse_only.search_tag(name, attrs))):
return None
tag = Tag(self, self.builder, name, namespace, nsprefix, attrs,
self.currentTag, self._most_recent_element)
if tag is None:
return tag
if self._most_recent_element:
self._most_recent_element.next_element = tag
self._most_recent_element = tag
self.pushTag(tag)
return tag
def handle_endtag(self, name, nsprefix=None):
#print "End tag: " + name
self.endData()
self._popToTag(name, nsprefix)
def handle_data(self, data):
self.current_data.append(data)
def decode(self, pretty_print=False,
eventual_encoding=DEFAULT_OUTPUT_ENCODING,
formatter="minimal"):
"""Returns a string or Unicode representation of this document.
To get Unicode, pass None for encoding."""
if self.is_xml:
# Print the XML declaration
encoding_part = ''
if eventual_encoding is not None:
encoding_part = ' encoding="%s"' % eventual_encoding
prefix = u'<?xml version="1.0"%s?>\n' % encoding_part
else:
prefix = u''
if not pretty_print:
indent_level = None
else:
indent_level = 0
return prefix + super(BeautifulSoup, self).decode(
indent_level, eventual_encoding, formatter)
# Alias to make it easier to type import: 'from bs4 import _soup'
_s = BeautifulSoup
_soup = BeautifulSoup
class BeautifulStoneSoup(BeautifulSoup):
"""Deprecated interface to an XML parser."""
def __init__(self, *args, **kwargs):
kwargs['features'] = 'xml'
warnings.warn(
'The BeautifulStoneSoup class is deprecated. Instead of using '
'it, pass features="xml" into the BeautifulSoup constructor.')
super(BeautifulStoneSoup, self).__init__(*args, **kwargs)
class StopParsing(Exception):
pass
class FeatureNotFound(ValueError):
pass
#By default, act as an HTML pretty-printer.
if __name__ == '__main__':
import sys
soup = BeautifulSoup(sys.stdin)
print soup.prettify()
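# Illustrative sketch, not part of upstream bs4: how the parse_only plumbing
# in endData() and handle_starttag() above is driven from user code.
# SoupStrainer is the strainer class exposed by this package.
#
#   >>> from bs4 import BeautifulSoup, SoupStrainer
#   >>> only_links = SoupStrainer('a')
#   >>> soup = BeautifulSoup('<p>skip</p><a href="/">keep</a>',
#   ...                      'html.parser', parse_only=only_links)
#   >>> soup.find_all('a')
#   [<a href="/">keep</a>]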

324
lib/bs4/builder/__init__.py Normal file
View File

@@ -0,0 +1,324 @@
from collections import defaultdict
import itertools
import sys
from bs4.element import (
CharsetMetaAttributeValue,
ContentMetaAttributeValue,
whitespace_re
)
__all__ = [
'HTMLTreeBuilder',
'SAXTreeBuilder',
'TreeBuilder',
'TreeBuilderRegistry',
]
# Some useful features for a TreeBuilder to have.
FAST = 'fast'
PERMISSIVE = 'permissive'
STRICT = 'strict'
XML = 'xml'
HTML = 'html'
HTML_5 = 'html5'
class TreeBuilderRegistry(object):
def __init__(self):
self.builders_for_feature = defaultdict(list)
self.builders = []
def register(self, treebuilder_class):
"""Register a treebuilder based on its advertised features."""
for feature in treebuilder_class.features:
self.builders_for_feature[feature].insert(0, treebuilder_class)
self.builders.insert(0, treebuilder_class)
def lookup(self, *features):
if len(self.builders) == 0:
# There are no builders at all.
return None
if len(features) == 0:
# They didn't ask for any features. Give them the most
# recently registered builder.
return self.builders[0]
# Go down the list of features in order, and eliminate any builders
# that don't match every feature.
features = list(features)
features.reverse()
candidates = None
candidate_set = None
while len(features) > 0:
feature = features.pop()
we_have_the_feature = self.builders_for_feature.get(feature, [])
if len(we_have_the_feature) > 0:
if candidates is None:
candidates = we_have_the_feature
candidate_set = set(candidates)
else:
# Eliminate any candidates that don't have this feature.
candidate_set = candidate_set.intersection(
set(we_have_the_feature))
# The only valid candidates are the ones in candidate_set.
# Go through the original list of candidates and pick the first one
# that's in candidate_set.
if candidate_set is None:
return None
for candidate in candidates:
if candidate in candidate_set:
return candidate
return None
# The BeautifulSoup class will take feature lists from developers and use them
# to look up builders in this registry.
builder_registry = TreeBuilderRegistry()
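# Illustrative sketch, not part of upstream bs4: exercising the registry.
# With no arguments lookup() returns the most recently registered builder;
# otherwise every requested feature must match, and None means no match.
#
#   >>> from bs4.builder import builder_registry
#   >>> builder_registry.lookup('html') is not None
#   True
#   >>> builder_registry.lookup('no-such-feature') is None
#   True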
class TreeBuilder(object):
"""Turn a document into a Beautiful Soup object tree."""
NAME = "[Unknown tree builder]"
ALTERNATE_NAMES = []
features = []
is_xml = False
picklable = False
preserve_whitespace_tags = set()
empty_element_tags = None # A tag will be considered an empty-element
# tag when and only when it has no contents.
# A value for these tag/attribute combinations is a space- or
# comma-separated list of CDATA, rather than a single CDATA.
cdata_list_attributes = {}
def __init__(self):
self.soup = None
def reset(self):
pass
def can_be_empty_element(self, tag_name):
"""Might a tag with this name be an empty-element tag?
The final markup may or may not actually present this tag as
self-closing.
For instance: an HTMLBuilder does not consider a <p> tag to be
an empty-element tag (it's not in
HTMLBuilder.empty_element_tags). This means an empty <p> tag
will be presented as "<p></p>", not "<p />".
The default implementation has no opinion about which tags are
empty-element tags, so a tag will be presented as an
empty-element tag if and only if it has no contents.
"<foo></foo>" will become "<foo />", and "<foo>bar</foo>" will
be left alone.
"""
if self.empty_element_tags is None:
return True
return tag_name in self.empty_element_tags
def feed(self, markup):
raise NotImplementedError()
def prepare_markup(self, markup, user_specified_encoding=None,
document_declared_encoding=None):
return markup, None, None, False
def test_fragment_to_document(self, fragment):
"""Wrap an HTML fragment to make it look like a document.
Different parsers do this differently. For instance, lxml
introduces an empty <head> tag, and html5lib
doesn't. Abstracting this away lets us write simple tests
which run HTML fragments through the parser and compare the
results against other HTML fragments.
This method should not be used outside of tests.
"""
return fragment
def set_up_substitutions(self, tag):
return False
def _replace_cdata_list_attribute_values(self, tag_name, attrs):
"""Replaces class="foo bar" with class=["foo", "bar"]
Modifies its input in place.
"""
if not attrs:
return attrs
if self.cdata_list_attributes:
universal = self.cdata_list_attributes.get('*', [])
tag_specific = self.cdata_list_attributes.get(
tag_name.lower(), None)
for attr in attrs.keys():
if attr in universal or (tag_specific and attr in tag_specific):
# We have a "class"-type attribute whose string
# value is a whitespace-separated list of
# values. Split it into a list.
value = attrs[attr]
if isinstance(value, basestring):
values = whitespace_re.split(value)
else:
# html5lib sometimes calls setAttributes twice
# for the same tag when rearranging the parse
# tree. On the second call the attribute value
# here is already a list. If this happens,
# leave the value alone rather than trying to
# split it again.
values = value
attrs[attr] = values
return attrs
class SAXTreeBuilder(TreeBuilder):
"""A Beautiful Soup treebuilder that listens for SAX events."""
def feed(self, markup):
raise NotImplementedError()
def close(self):
pass
def startElement(self, name, attrs):
attrs = dict((key[1], value) for key, value in list(attrs.items()))
#print "Start %s, %r" % (name, attrs)
self.soup.handle_starttag(name, attrs)
def endElement(self, name):
#print "End %s" % name
self.soup.handle_endtag(name)
def startElementNS(self, nsTuple, nodeName, attrs):
# Throw away (ns, nodeName) for now.
self.startElement(nodeName, attrs)
def endElementNS(self, nsTuple, nodeName):
# Throw away (ns, nodeName) for now.
self.endElement(nodeName)
#handler.endElementNS((ns, node.nodeName), node.nodeName)
def startPrefixMapping(self, prefix, nodeValue):
# Ignore the prefix for now.
pass
def endPrefixMapping(self, prefix):
# Ignore the prefix for now.
# handler.endPrefixMapping(prefix)
pass
def characters(self, content):
self.soup.handle_data(content)
def startDocument(self):
pass
def endDocument(self):
pass
class HTMLTreeBuilder(TreeBuilder):
"""This TreeBuilder knows facts about HTML.
Such as which tags are empty-element tags.
"""
preserve_whitespace_tags = set(['pre', 'textarea'])
empty_element_tags = set(['br' , 'hr', 'input', 'img', 'meta',
'spacer', 'link', 'frame', 'base'])
# The HTML standard defines these attributes as containing a
# space-separated list of values, not a single value. That is,
# class="foo bar" means that the 'class' attribute has two values,
# 'foo' and 'bar', not the single value 'foo bar'. When we
# encounter one of these attributes, we will parse its value into
# a list of values if possible. Upon output, the list will be
# converted back into a string.
cdata_list_attributes = {
"*" : ['class', 'accesskey', 'dropzone'],
"a" : ['rel', 'rev'],
"link" : ['rel', 'rev'],
"td" : ["headers"],
"th" : ["headers"],
"td" : ["headers"],
"form" : ["accept-charset"],
"object" : ["archive"],
# These are HTML5 specific, as are *.accesskey and *.dropzone above.
"area" : ["rel"],
"icon" : ["sizes"],
"iframe" : ["sandbox"],
"output" : ["for"],
}
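# Illustrative example, not part of upstream bs4: with the table above in
# effect, a multi-valued attribute is parsed into a list of values.
#
#   >>> from bs4 import BeautifulSoup
#   >>> BeautifulSoup('<p class="foo bar"></p>', 'html.parser').p['class']
#   ['foo', 'bar']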
def set_up_substitutions(self, tag):
# We are only interested in <meta> tags
if tag.name != 'meta':
return False
http_equiv = tag.get('http-equiv')
content = tag.get('content')
charset = tag.get('charset')
# We are interested in <meta> tags that say what encoding the
# document was originally in. This means HTML 5-style <meta>
# tags that provide the "charset" attribute. It also means
# HTML 4-style <meta> tags that provide the "content"
# attribute and have "http-equiv" set to "content-type".
#
# In both cases we will replace the value of the appropriate
# attribute with a standin object that can take on any
# encoding.
meta_encoding = None
if charset is not None:
# HTML 5 style:
# <meta charset="utf8">
meta_encoding = charset
tag['charset'] = CharsetMetaAttributeValue(charset)
elif (content is not None and http_equiv is not None
and http_equiv.lower() == 'content-type'):
# HTML 4 style:
# <meta http-equiv="content-type" content="text/html; charset=utf8">
tag['content'] = ContentMetaAttributeValue(content)
return (meta_encoding is not None)
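# Illustrative sketch, not part of upstream bs4: thanks to the standin
# attribute values set up above, re-encoding a document rewrites its
# declared charset to match (exact output shape depends on the formatter).
#
#   >>> from bs4 import BeautifulSoup
#   >>> soup = BeautifulSoup('<meta charset="utf-8">', 'html.parser')
#   >>> 'iso-8859-1' in soup.encode('iso-8859-1')
#   True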
def register_treebuilders_from(module):
"""Copy TreeBuilders from the given module into this module."""
# I'm fairly sure this is not the best way to do this.
this_module = sys.modules['bs4.builder']
for name in module.__all__:
obj = getattr(module, name)
if issubclass(obj, TreeBuilder):
setattr(this_module, name, obj)
this_module.__all__.append(name)
# Register the builder while we're at it.
this_module.builder_registry.register(obj)
class ParserRejectedMarkup(Exception):
pass
# Builders are registered in reverse order of priority, so that custom
# builder registrations will take precedence. In general, we want lxml
# to take precedence over html5lib, because it's faster. And we only
# want to use HTMLParser as a last resort.
from . import _htmlparser
register_treebuilders_from(_htmlparser)
try:
from . import _html5lib
register_treebuilders_from(_html5lib)
except ImportError:
# They don't have html5lib installed.
pass
try:
from . import _lxml
register_treebuilders_from(_lxml)
except ImportError:
# They don't have lxml installed.
pass

329
lib/bs4/builder/_html5lib.py Normal file
View File

@@ -0,0 +1,329 @@
__all__ = [
'HTML5TreeBuilder',
]
from pdb import set_trace
import warnings
from bs4.builder import (
PERMISSIVE,
HTML,
HTML_5,
HTMLTreeBuilder,
)
from bs4.element import (
NamespacedAttribute,
whitespace_re,
)
import html5lib
from html5lib.constants import namespaces
from bs4.element import (
Comment,
Doctype,
NavigableString,
Tag,
)
class HTML5TreeBuilder(HTMLTreeBuilder):
"""Use html5lib to build a tree."""
NAME = "html5lib"
features = [NAME, PERMISSIVE, HTML_5, HTML]
def prepare_markup(self, markup, user_specified_encoding,
document_declared_encoding=None, exclude_encodings=None):
# Store the user-specified encoding for use later on.
self.user_specified_encoding = user_specified_encoding
# document_declared_encoding and exclude_encodings aren't used
# ATM because the html5lib TreeBuilder doesn't use
# UnicodeDammit.
if exclude_encodings:
warnings.warn("You provided a value for exclude_encoding, but the html5lib tree builder doesn't support exclude_encoding.")
yield (markup, None, None, False)
# These methods are defined by Beautiful Soup.
def feed(self, markup):
if self.soup.parse_only is not None:
warnings.warn("You provided a value for parse_only, but the html5lib tree builder doesn't support parse_only. The entire document will be parsed.")
parser = html5lib.HTMLParser(tree=self.create_treebuilder)
doc = parser.parse(markup, encoding=self.user_specified_encoding)
# Set the character encoding detected by the tokenizer.
if isinstance(markup, unicode):
# We need to special-case this because html5lib sets
# charEncoding to UTF-8 if it gets Unicode input.
doc.original_encoding = None
else:
doc.original_encoding = parser.tokenizer.stream.charEncoding[0]
def create_treebuilder(self, namespaceHTMLElements):
self.underlying_builder = TreeBuilderForHtml5lib(
self.soup, namespaceHTMLElements)
return self.underlying_builder
def test_fragment_to_document(self, fragment):
"""See `TreeBuilder`."""
return u'<html><head></head><body>%s</body></html>' % fragment
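# Illustrative note, not part of upstream bs4: any feature string advertised
# above ('html5lib', 'html5', ...) selects this builder.
#
#   >>> from bs4 import BeautifulSoup
#   >>> soup = BeautifulSoup('<p>unclosed', 'html5lib')
#   >>> # html5lib repairs the fragment into a full <html><head><body> tree.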
class TreeBuilderForHtml5lib(html5lib.treebuilders._base.TreeBuilder):
def __init__(self, soup, namespaceHTMLElements):
self.soup = soup
super(TreeBuilderForHtml5lib, self).__init__(namespaceHTMLElements)
def documentClass(self):
self.soup.reset()
return Element(self.soup, self.soup, None)
def insertDoctype(self, token):
name = token["name"]
publicId = token["publicId"]
systemId = token["systemId"]
doctype = Doctype.for_name_and_ids(name, publicId, systemId)
self.soup.object_was_parsed(doctype)
def elementClass(self, name, namespace):
tag = self.soup.new_tag(name, namespace)
return Element(tag, self.soup, namespace)
def commentClass(self, data):
return TextNode(Comment(data), self.soup)
def fragmentClass(self):
# Local import: bs4/__init__.py imports this module, so a top-level
# import of BeautifulSoup here would be circular.
from bs4 import BeautifulSoup
self.soup = BeautifulSoup("")
self.soup.name = "[document_fragment]"
return Element(self.soup, self.soup, None)
def appendChild(self, node):
# XXX This code is not covered by the BS4 tests.
self.soup.append(node.element)
def getDocument(self):
return self.soup
def getFragment(self):
return html5lib.treebuilders._base.TreeBuilder.getFragment(self).element
class AttrList(object):
def __init__(self, element):
self.element = element
self.attrs = dict(self.element.attrs)
def __iter__(self):
return list(self.attrs.items()).__iter__()
def __setitem__(self, name, value):
# If this attribute is a multi-valued attribute for this element,
# turn its value into a list.
list_attr = HTML5TreeBuilder.cdata_list_attributes
if (name in list_attr['*']
or (self.element.name in list_attr
and name in list_attr[self.element.name])):
value = whitespace_re.split(value)
self.element[name] = value
def items(self):
return list(self.attrs.items())
def keys(self):
return list(self.attrs.keys())
def __len__(self):
return len(self.attrs)
def __getitem__(self, name):
return self.attrs[name]
def __contains__(self, name):
return name in list(self.attrs.keys())
class Element(html5lib.treebuilders._base.Node):
def __init__(self, element, soup, namespace):
html5lib.treebuilders._base.Node.__init__(self, element.name)
self.element = element
self.soup = soup
self.namespace = namespace
def appendChild(self, node):
string_child = child = None
if isinstance(node, basestring):
# Some other piece of code decided to pass in a string
# instead of creating a TextElement object to contain the
# string.
string_child = child = node
elif isinstance(node, Tag):
# Some other piece of code decided to pass in a Tag
# instead of creating an Element object to contain the
# Tag.
child = node
elif node.element.__class__ == NavigableString:
string_child = child = node.element
else:
child = node.element
if not isinstance(child, basestring) and child.parent is not None:
node.element.extract()
if (string_child and self.element.contents
and self.element.contents[-1].__class__ == NavigableString):
# We are appending a string onto another string.
# TODO This has O(n^2) performance, for input like
# "a</a>a</a>a</a>..."
old_element = self.element.contents[-1]
new_element = self.soup.new_string(old_element + string_child)
old_element.replace_with(new_element)
self.soup._most_recent_element = new_element
else:
if isinstance(node, basestring):
# Create a brand new NavigableString from this string.
child = self.soup.new_string(node)
# Tell Beautiful Soup to act as if it parsed this element
# immediately after the parent's last descendant. (Or
# immediately after the parent, if it has no children.)
if self.element.contents:
most_recent_element = self.element._last_descendant(False)
elif self.element.next_element is not None:
# Something from further ahead in the parse tree is
# being inserted into this earlier element. This is
# very annoying because it means an expensive search
# for the last element in the tree.
most_recent_element = self.soup._last_descendant()
else:
most_recent_element = self.element
self.soup.object_was_parsed(
child, parent=self.element,
most_recent_element=most_recent_element)
def getAttributes(self):
return AttrList(self.element)
def setAttributes(self, attributes):
if attributes is not None and len(attributes) > 0:
converted_attributes = []
for name, value in list(attributes.items()):
if isinstance(name, tuple):
new_name = NamespacedAttribute(*name)
del attributes[name]
attributes[new_name] = value
self.soup.builder._replace_cdata_list_attribute_values(
self.name, attributes)
for name, value in attributes.items():
self.element[name] = value
# The attributes may contain variables that need substitution.
# Call set_up_substitutions manually.
#
# The Tag constructor called this method when the Tag was created,
# but we just set/changed the attributes, so call it again.
self.soup.builder.set_up_substitutions(self.element)
attributes = property(getAttributes, setAttributes)
def insertText(self, data, insertBefore=None):
if insertBefore:
text = TextNode(self.soup.new_string(data), self.soup)
self.insertBefore(text, insertBefore)
else:
self.appendChild(data)
def insertBefore(self, node, refNode):
index = self.element.index(refNode.element)
if (node.element.__class__ == NavigableString and self.element.contents
and self.element.contents[index-1].__class__ == NavigableString):
# (See comments in appendChild)
old_node = self.element.contents[index-1]
new_str = self.soup.new_string(old_node + node.element)
old_node.replace_with(new_str)
else:
self.element.insert(index, node.element)
node.parent = self
def removeChild(self, node):
node.element.extract()
def reparentChildren(self, new_parent):
"""Move all of this tag's children into another tag."""
# print "MOVE", self.element.contents
# print "FROM", self.element
# print "TO", new_parent.element
element = self.element
new_parent_element = new_parent.element
# Determine what this tag's next_element will be once all the children
# are removed.
final_next_element = element.next_sibling
new_parents_last_descendant = new_parent_element._last_descendant(False, False)
if len(new_parent_element.contents) > 0:
# The new parent already contains children. We will be
# appending this tag's children to the end.
new_parents_last_child = new_parent_element.contents[-1]
new_parents_last_descendant_next_element = new_parents_last_descendant.next_element
else:
# The new parent contains no children.
new_parents_last_child = None
new_parents_last_descendant_next_element = new_parent_element.next_element
to_append = element.contents
append_after = new_parent_element.contents
if len(to_append) > 0:
# Set the first child's previous_element and previous_sibling
# to elements within the new parent
first_child = to_append[0]
if new_parents_last_descendant:
first_child.previous_element = new_parents_last_descendant
else:
first_child.previous_element = new_parent_element
first_child.previous_sibling = new_parents_last_child
if new_parents_last_descendant:
new_parents_last_descendant.next_element = first_child
else:
new_parent_element.next_element = first_child
if new_parents_last_child:
new_parents_last_child.next_sibling = first_child
# Fix the last child's next_element and next_sibling
last_child = to_append[-1]
last_child.next_element = new_parents_last_descendant_next_element
if new_parents_last_descendant_next_element:
new_parents_last_descendant_next_element.previous_element = last_child
last_child.next_sibling = None
for child in to_append:
child.parent = new_parent_element
new_parent_element.contents.append(child)
# Now that this element has no children, change its .next_element.
element.contents = []
element.next_element = final_next_element
# print "DONE WITH MOVE"
# print "FROM", self.element
# print "TO", new_parent_element
def cloneNode(self):
tag = self.soup.new_tag(self.element.name, self.namespace)
node = Element(tag, self.soup, self.namespace)
for key,value in self.attributes:
node.attributes[key] = value
return node
def hasContent(self):
return self.element.contents
def getNameTuple(self):
if self.namespace == None:
return namespaces["html"], self.name
else:
return self.namespace, self.name
nameTuple = property(getNameTuple)
class TextNode(Element):
def __init__(self, element, soup):
html5lib.treebuilders._base.Node.__init__(self, None)
self.element = element
self.soup = soup
def cloneNode(self):
raise NotImplementedError

262
lib/bs4/builder/_htmlparser.py Normal file
View File

@@ -0,0 +1,262 @@
"""Use the HTMLParser library to parse HTML files that aren't too bad."""
__all__ = [
'HTMLParserTreeBuilder',
]
from HTMLParser import HTMLParser
try:
from HTMLParser import HTMLParseError
except ImportError, e:
# HTMLParseError is removed in Python 3.5. Since it can never be
# thrown in 3.5, we can just define our own class as a placeholder.
class HTMLParseError(Exception):
pass
import sys
import warnings
# Starting in Python 3.2, the HTMLParser constructor takes a 'strict'
# argument, which we'd like to set to False. Unfortunately,
# http://bugs.python.org/issue13273 makes strict=True a better bet
# before Python 3.2.3.
#
# At the end of this file, we monkeypatch HTMLParser so that
# strict=True works well on Python 3.2.2.
major, minor, release = sys.version_info[:3]
CONSTRUCTOR_TAKES_STRICT = major == 3 and minor == 2 and release >= 3
CONSTRUCTOR_STRICT_IS_DEPRECATED = major == 3 and minor == 3
CONSTRUCTOR_TAKES_CONVERT_CHARREFS = major == 3 and minor >= 4
from bs4.element import (
CData,
Comment,
Declaration,
Doctype,
ProcessingInstruction,
)
from bs4.dammit import EntitySubstitution, UnicodeDammit
from bs4.builder import (
HTML,
HTMLTreeBuilder,
STRICT,
)
HTMLPARSER = 'html.parser'
class BeautifulSoupHTMLParser(HTMLParser):
def handle_starttag(self, name, attrs):
# XXX namespace
attr_dict = {}
for key, value in attrs:
# Change None attribute values to the empty string
# for consistency with the other tree builders.
if value is None:
value = ''
attr_dict[key] = value
attrvalue = '""'
self.soup.handle_starttag(name, None, None, attr_dict)
def handle_endtag(self, name):
self.soup.handle_endtag(name)
def handle_data(self, data):
self.soup.handle_data(data)
def handle_charref(self, name):
# XXX workaround for a bug in HTMLParser. Remove this once
# it's fixed in all supported versions.
# http://bugs.python.org/issue13633
if name.startswith('x'):
real_name = int(name.lstrip('x'), 16)
elif name.startswith('X'):
real_name = int(name.lstrip('X'), 16)
else:
real_name = int(name)
try:
data = unichr(real_name)
except (ValueError, OverflowError), e:
data = u"\N{REPLACEMENT CHARACTER}"
self.handle_data(data)
def handle_entityref(self, name):
character = EntitySubstitution.HTML_ENTITY_TO_CHARACTER.get(name)
if character is not None:
data = character
else:
data = "&%s;" % name
self.handle_data(data)
def handle_comment(self, data):
self.soup.endData()
self.soup.handle_data(data)
self.soup.endData(Comment)
def handle_decl(self, data):
self.soup.endData()
if data.startswith("DOCTYPE "):
data = data[len("DOCTYPE "):]
elif data == 'DOCTYPE':
# i.e. "<!DOCTYPE>"
data = ''
self.soup.handle_data(data)
self.soup.endData(Doctype)
def unknown_decl(self, data):
if data.upper().startswith('CDATA['):
cls = CData
data = data[len('CDATA['):]
else:
cls = Declaration
self.soup.endData()
self.soup.handle_data(data)
self.soup.endData(cls)
def handle_pi(self, data):
self.soup.endData()
self.soup.handle_data(data)
self.soup.endData(ProcessingInstruction)
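# Illustrative example, not part of upstream bs4: the handle_charref()
# workaround above normalizes hex and decimal character references alike.
#
#   >>> from bs4 import BeautifulSoup
#   >>> BeautifulSoup('<p>&#x41;&#66;</p>', 'html.parser').p.string
#   u'AB'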
class HTMLParserTreeBuilder(HTMLTreeBuilder):
is_xml = False
picklable = True
NAME = HTMLPARSER
features = [NAME, HTML, STRICT]
def __init__(self, *args, **kwargs):
if CONSTRUCTOR_TAKES_STRICT and not CONSTRUCTOR_STRICT_IS_DEPRECATED:
kwargs['strict'] = False
if CONSTRUCTOR_TAKES_CONVERT_CHARREFS:
kwargs['convert_charrefs'] = False
self.parser_args = (args, kwargs)
def prepare_markup(self, markup, user_specified_encoding=None,
document_declared_encoding=None, exclude_encodings=None):
"""
:return: A 4-tuple (markup, original encoding, encoding
declared within markup, whether any characters had to be
replaced with REPLACEMENT CHARACTER).
"""
if isinstance(markup, unicode):
yield (markup, None, None, False)
return
try_encodings = [user_specified_encoding, document_declared_encoding]
dammit = UnicodeDammit(markup, try_encodings, is_html=True,
exclude_encodings=exclude_encodings)
yield (dammit.markup, dammit.original_encoding,
dammit.declared_html_encoding,
dammit.contains_replacement_characters)
def feed(self, markup):
args, kwargs = self.parser_args
parser = BeautifulSoupHTMLParser(*args, **kwargs)
parser.soup = self.soup
try:
parser.feed(markup)
except HTMLParseError, e:
warnings.warn(RuntimeWarning(
"Python's built-in HTMLParser cannot parse the given document. This is not a bug in Beautiful Soup. The best solution is to install an external parser (lxml or html5lib), and use Beautiful Soup with that parser. See http://www.crummy.com/software/BeautifulSoup/bs4/doc/#installing-a-parser for help."))
raise e
# Patch 3.2 versions of HTMLParser earlier than 3.2.3 to use some
# 3.2.3 code. This ensures they don't treat markup like <p></p> as a
# string.
#
# XXX This code can be removed once most Python 3 users are on 3.2.3.
if major == 3 and minor == 2 and not CONSTRUCTOR_TAKES_STRICT:
import re
attrfind_tolerant = re.compile(
r'\s*((?<=[\'"\s])[^\s/>][^\s/=>]*)(\s*=+\s*'
r'(\'[^\']*\'|"[^"]*"|(?![\'"])[^>\s]*))?')
HTMLParserTreeBuilder.attrfind_tolerant = attrfind_tolerant
locatestarttagend = re.compile(r"""
<[a-zA-Z][-.a-zA-Z0-9:_]* # tag name
(?:\s+ # whitespace before attribute name
(?:[a-zA-Z_][-.:a-zA-Z0-9_]* # attribute name
(?:\s*=\s* # value indicator
(?:'[^']*' # LITA-enclosed value
|\"[^\"]*\" # LIT-enclosed value
|[^'\">\s]+ # bare value
)
)?
)
)*
\s* # trailing whitespace
""", re.VERBOSE)
BeautifulSoupHTMLParser.locatestarttagend = locatestarttagend
from html.parser import tagfind, attrfind
def parse_starttag(self, i):
self.__starttag_text = None
endpos = self.check_for_whole_start_tag(i)
if endpos < 0:
return endpos
rawdata = self.rawdata
self.__starttag_text = rawdata[i:endpos]
# Now parse the data between i+1 and j into a tag and attrs
attrs = []
match = tagfind.match(rawdata, i+1)
assert match, 'unexpected call to parse_starttag()'
k = match.end()
self.lasttag = tag = rawdata[i+1:k].lower()
while k < endpos:
if self.strict:
m = attrfind.match(rawdata, k)
else:
m = attrfind_tolerant.match(rawdata, k)
if not m:
break
attrname, rest, attrvalue = m.group(1, 2, 3)
if not rest:
attrvalue = None
elif attrvalue[:1] == '\'' == attrvalue[-1:] or \
attrvalue[:1] == '"' == attrvalue[-1:]:
attrvalue = attrvalue[1:-1]
if attrvalue:
attrvalue = self.unescape(attrvalue)
attrs.append((attrname.lower(), attrvalue))
k = m.end()
end = rawdata[k:endpos].strip()
if end not in (">", "/>"):
lineno, offset = self.getpos()
if "\n" in self.__starttag_text:
lineno = lineno + self.__starttag_text.count("\n")
offset = len(self.__starttag_text) \
- self.__starttag_text.rfind("\n")
else:
offset = offset + len(self.__starttag_text)
if self.strict:
self.error("junk characters in start tag: %r"
% (rawdata[k:endpos][:20],))
self.handle_data(rawdata[i:endpos])
return endpos
if end.endswith('/>'):
# XHTML-style empty tag: <span attr="value" />
self.handle_startendtag(tag, attrs)
else:
self.handle_starttag(tag, attrs)
if tag in self.CDATA_CONTENT_ELEMENTS:
self.set_cdata_mode(tag)
return endpos
def set_cdata_mode(self, elem):
self.cdata_elem = elem.lower()
self.interesting = re.compile(r'</\s*%s\s*>' % self.cdata_elem, re.I)
BeautifulSoupHTMLParser.parse_starttag = parse_starttag
BeautifulSoupHTMLParser.set_cdata_mode = set_cdata_mode
CONSTRUCTOR_TAKES_STRICT = True

248
lib/bs4/builder/_lxml.py Normal file
View File

@@ -0,0 +1,248 @@
__all__ = [
'LXMLTreeBuilderForXML',
'LXMLTreeBuilder',
]
from io import BytesIO
from StringIO import StringIO
import collections
from lxml import etree
from bs4.element import (
Comment,
Doctype,
NamespacedAttribute,
ProcessingInstruction,
)
from bs4.builder import (
FAST,
HTML,
HTMLTreeBuilder,
PERMISSIVE,
ParserRejectedMarkup,
TreeBuilder,
XML)
from bs4.dammit import EncodingDetector
LXML = 'lxml'
class LXMLTreeBuilderForXML(TreeBuilder):
DEFAULT_PARSER_CLASS = etree.XMLParser
is_xml = True
NAME = "lxml-xml"
ALTERNATE_NAMES = ["xml"]
# Well, it's permissive by XML parser standards.
features = [NAME, LXML, XML, FAST, PERMISSIVE]
CHUNK_SIZE = 512
# This namespace mapping is specified in the XML Namespace
# standard.
DEFAULT_NSMAPS = {'http://www.w3.org/XML/1998/namespace' : "xml"}
def default_parser(self, encoding):
# This can either return a parser object or a class, which
# will be instantiated with default arguments.
if self._default_parser is not None:
return self._default_parser
return etree.XMLParser(
target=self, strip_cdata=False, recover=True, encoding=encoding)
def parser_for(self, encoding):
# Use the default parser.
parser = self.default_parser(encoding)
if isinstance(parser, collections.Callable):
# Instantiate the parser with default arguments
parser = parser(target=self, strip_cdata=False, encoding=encoding)
return parser
def __init__(self, parser=None, empty_element_tags=None):
# TODO: Issue a warning if parser is present but not a
# callable, since that means there's no way to create new
# parsers for different encodings.
self._default_parser = parser
if empty_element_tags is not None:
self.empty_element_tags = set(empty_element_tags)
self.soup = None
self.nsmaps = [self.DEFAULT_NSMAPS]
def _getNsTag(self, tag):
# Split the namespace URL out of a fully-qualified lxml tag
# name. Copied from lxml's src/lxml/sax.py.
if tag[0] == '{':
return tuple(tag[1:].split('}', 1))
else:
return (None, tag)
def prepare_markup(self, markup, user_specified_encoding=None,
exclude_encodings=None,
document_declared_encoding=None):
"""
:yield: A series of 4-tuples.
(markup, encoding, declared encoding,
has undergone character replacement)
Each 4-tuple represents a strategy for parsing the document.
"""
if isinstance(markup, unicode):
# We were given Unicode. Maybe lxml can parse Unicode on
# this system?
yield markup, None, document_declared_encoding, False
if isinstance(markup, unicode):
# No, apparently not. Convert the Unicode to UTF-8 and
# tell lxml to parse it as UTF-8.
yield (markup.encode("utf8"), "utf8",
document_declared_encoding, False)
# Instead of using UnicodeDammit to convert the bytestring to
# Unicode using different encodings, use EncodingDetector to
# iterate over the encodings, and tell lxml to try to parse
# the document as each one in turn.
is_html = not self.is_xml
try_encodings = [user_specified_encoding, document_declared_encoding]
detector = EncodingDetector(
markup, try_encodings, is_html, exclude_encodings)
for encoding in detector.encodings:
yield (detector.markup, encoding, document_declared_encoding, False)
def feed(self, markup):
if isinstance(markup, bytes):
markup = BytesIO(markup)
elif isinstance(markup, unicode):
markup = StringIO(markup)
# Call feed() at least once, even if the markup is empty,
# or the parser won't be initialized.
data = markup.read(self.CHUNK_SIZE)
try:
self.parser = self.parser_for(self.soup.original_encoding)
self.parser.feed(data)
while len(data) != 0:
# Now call feed() on the rest of the data, chunk by chunk.
data = markup.read(self.CHUNK_SIZE)
if len(data) != 0:
self.parser.feed(data)
self.parser.close()
except (UnicodeDecodeError, LookupError, etree.ParserError), e:
raise ParserRejectedMarkup(str(e))
def close(self):
self.nsmaps = [self.DEFAULT_NSMAPS]
def start(self, name, attrs, nsmap={}):
# Make sure attrs is a mutable dict--lxml may send an immutable dictproxy.
attrs = dict(attrs)
nsprefix = None
# Invert each namespace map as it comes in.
if len(self.nsmaps) > 1:
# There are no new namespaces for this tag, but
# non-default namespaces are in play, so we need a
# separate tag stack to know when they end.
self.nsmaps.append(None)
elif len(nsmap) > 0:
# A new namespace mapping has come into play.
inverted_nsmap = dict((value, key) for key, value in nsmap.items())
self.nsmaps.append(inverted_nsmap)
# Also treat the namespace mapping as a set of attributes on the
# tag, so we can recreate it later.
attrs = attrs.copy()
for prefix, namespace in nsmap.items():
attribute = NamespacedAttribute(
"xmlns", prefix, "http://www.w3.org/2000/xmlns/")
attrs[attribute] = namespace
# Namespaces are in play. Find any attributes that came in
# from lxml with namespaces attached to their names, and
# turn then into NamespacedAttribute objects.
new_attrs = {}
for attr, value in attrs.items():
namespace, attr = self._getNsTag(attr)
if namespace is None:
new_attrs[attr] = value
else:
nsprefix = self._prefix_for_namespace(namespace)
attr = NamespacedAttribute(nsprefix, attr, namespace)
new_attrs[attr] = value
attrs = new_attrs
namespace, name = self._getNsTag(name)
nsprefix = self._prefix_for_namespace(namespace)
self.soup.handle_starttag(name, namespace, nsprefix, attrs)
def _prefix_for_namespace(self, namespace):
"""Find the currently active prefix for the given namespace."""
if namespace is None:
return None
for inverted_nsmap in reversed(self.nsmaps):
if inverted_nsmap is not None and namespace in inverted_nsmap:
return inverted_nsmap[namespace]
return None
def end(self, name):
self.soup.endData()
completed_tag = self.soup.tagStack[-1]
namespace, name = self._getNsTag(name)
nsprefix = None
if namespace is not None:
for inverted_nsmap in reversed(self.nsmaps):
if inverted_nsmap is not None and namespace in inverted_nsmap:
nsprefix = inverted_nsmap[namespace]
break
self.soup.handle_endtag(name, nsprefix)
if len(self.nsmaps) > 1:
# This tag, or one of its parents, introduced a namespace
# mapping, so pop it off the stack.
self.nsmaps.pop()
def pi(self, target, data):
self.soup.endData()
self.soup.handle_data(target + ' ' + data)
self.soup.endData(ProcessingInstruction)
def data(self, content):
self.soup.handle_data(content)
def doctype(self, name, pubid, system):
self.soup.endData()
doctype = Doctype.for_name_and_ids(name, pubid, system)
self.soup.object_was_parsed(doctype)
def comment(self, content):
"Handle comments as Comment objects."
self.soup.endData()
self.soup.handle_data(content)
self.soup.endData(Comment)
def test_fragment_to_document(self, fragment):
"""See `TreeBuilder`."""
return u'<?xml version="1.0" encoding="utf-8"?>\n%s' % fragment
class LXMLTreeBuilder(HTMLTreeBuilder, LXMLTreeBuilderForXML):
NAME = LXML
ALTERNATE_NAMES = ["lxml-html"]
features = ALTERNATE_NAMES + [NAME, HTML, FAST, PERMISSIVE]
is_xml = False
def default_parser(self, encoding):
return etree.HTMLParser
def feed(self, markup):
encoding = self.soup.original_encoding
try:
self.parser = self.parser_for(encoding)
self.parser.feed(markup)
self.parser.close()
except (UnicodeDecodeError, LookupError, etree.ParserError), e:
raise ParserRejectedMarkup(str(e))
def test_fragment_to_document(self, fragment):
"""See `TreeBuilder`."""
return u'<html><body>%s</body></html>' % fragment
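# Illustrative note, not part of upstream bs4: these two classes are what
# the 'lxml' and 'xml' feature strings resolve to.
#
#   >>> from bs4 import BeautifulSoup
#   >>> html_soup = BeautifulSoup('<p>one<p>two', 'lxml')    # lenient HTML
#   >>> xml_soup = BeautifulSoup('<root><a/></root>', 'xml') # real XML tree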

839
lib/bs4/dammit.py Normal file
View File

@@ -0,0 +1,839 @@
# -*- coding: utf-8 -*-
"""Beautiful Soup bonus library: Unicode, Dammit
This library converts a bytestream to Unicode through any means
necessary. It is heavily based on code from Mark Pilgrim's Universal
Feed Parser. It works best on XML and HTML, but it does not rewrite the
XML or HTML to reflect a new encoding; that's the tree builder's job.
"""
from pdb import set_trace
import codecs
from htmlentitydefs import codepoint2name
import re
import logging
import string
# Import a library to autodetect character encodings.
chardet_type = None
try:
# First try the fast C implementation.
# PyPI package: cchardet
import cchardet
def chardet_dammit(s):
return cchardet.detect(s)['encoding']
except ImportError:
try:
# Fall back to the pure Python implementation
# Debian package: python-chardet
# PyPI package: chardet
import chardet
def chardet_dammit(s):
return chardet.detect(s)['encoding']
#import chardet.constants
#chardet.constants._debug = 1
except ImportError:
# No chardet available.
def chardet_dammit(s):
return None
# Available from http://cjkpython.i18n.org/.
try:
import iconv_codec
except ImportError:
pass
xml_encoding_re = re.compile(
'^<\?.*encoding=[\'"](.*?)[\'"].*\?>'.encode(), re.I)
html_meta_re = re.compile(
'<\s*meta[^>]+charset\s*=\s*["\']?([^>]*?)[ /;\'">]'.encode(), re.I)
class EntitySubstitution(object):
"""Substitute XML or HTML entities for the corresponding characters."""
def _populate_class_variables():
lookup = {}
reverse_lookup = {}
characters_for_re = []
for codepoint, name in list(codepoint2name.items()):
character = unichr(codepoint)
if codepoint != 34:
# There's no point in turning the quotation mark into
# &quot;, unless it happens within an attribute value, which
# is handled elsewhere.
characters_for_re.append(character)
lookup[character] = name
# But we do want to turn &quot; into the quotation mark.
reverse_lookup[name] = character
re_definition = "[%s]" % "".join(characters_for_re)
return lookup, reverse_lookup, re.compile(re_definition)
(CHARACTER_TO_HTML_ENTITY, HTML_ENTITY_TO_CHARACTER,
CHARACTER_TO_HTML_ENTITY_RE) = _populate_class_variables()
CHARACTER_TO_XML_ENTITY = {
"'": "apos",
'"': "quot",
"&": "amp",
"<": "lt",
">": "gt",
}
BARE_AMPERSAND_OR_BRACKET = re.compile("([<>]|"
"&(?!#\d+;|#x[0-9a-fA-F]+;|\w+;)"
")")
AMPERSAND_OR_BRACKET = re.compile("([<>&])")
@classmethod
def _substitute_html_entity(cls, matchobj):
entity = cls.CHARACTER_TO_HTML_ENTITY.get(matchobj.group(0))
return "&%s;" % entity
@classmethod
def _substitute_xml_entity(cls, matchobj):
"""Used with a regular expression to substitute the
appropriate XML entity for an XML special character."""
entity = cls.CHARACTER_TO_XML_ENTITY[matchobj.group(0)]
return "&%s;" % entity
@classmethod
def quoted_attribute_value(cls, value):
"""Make a value into a quoted XML attribute, possibly escaping it.
Most strings will be quoted using double quotes.
Bob's Bar -> "Bob's Bar"
If a string contains double quotes, it will be quoted using
single quotes.
Welcome to "my bar" -> 'Welcome to "my bar"'
If a string contains both single and double quotes, the
double quotes will be escaped, and the string will be quoted
using double quotes.
Welcome to "Bob's Bar" -> "Welcome to &quot;Bob's bar&quot;
"""
quote_with = '"'
if '"' in value:
if "'" in value:
# The string contains both single and double
# quotes. Turn the double quotes into
# entities. We quote the double quotes rather than
# the single quotes because the entity name is
# "&quot;" whether this is HTML or XML. If we
# quoted the single quotes, we'd have to decide
# between &apos; and &squot;.
replace_with = "&quot;"
value = value.replace('"', replace_with)
else:
# There are double quotes but no single quotes.
# We can use single quotes to quote the attribute.
quote_with = "'"
return quote_with + value + quote_with
@classmethod
def substitute_xml(cls, value, make_quoted_attribute=False):
"""Substitute XML entities for special XML characters.
:param value: A string to be substituted. The less-than sign
will become &lt;, the greater-than sign will become &gt;,
and any ampersands will become &amp;. If you want ampersands
that appear to be part of an entity definition to be left
alone, use substitute_xml_containing_entities() instead.
:param make_quoted_attribute: If True, then the string will be
quoted, as befits an attribute value.
"""
# Escape angle brackets and ampersands.
value = cls.AMPERSAND_OR_BRACKET.sub(
cls._substitute_xml_entity, value)
if make_quoted_attribute:
value = cls.quoted_attribute_value(value)
return value
@classmethod
def substitute_xml_containing_entities(
cls, value, make_quoted_attribute=False):
"""Substitute XML entities for special XML characters.
:param value: A string to be substituted. The less-than sign will
become &lt;, the greater-than sign will become &gt;, and any
ampersands that are not part of an entity definition will
become &amp;.
:param make_quoted_attribute: If True, then the string will be
quoted, as befits an attribute value.
"""
# Escape angle brackets, and ampersands that aren't part of
# entities.
value = cls.BARE_AMPERSAND_OR_BRACKET.sub(
cls._substitute_xml_entity, value)
if make_quoted_attribute:
value = cls.quoted_attribute_value(value)
return value
@classmethod
def substitute_html(cls, s):
"""Replace certain Unicode characters with named HTML entities.
This differs from data.encode(encoding, 'xmlcharrefreplace')
in that the goal is to make the result more readable (to those
with ASCII displays) rather than to recover from
errors. There's absolutely nothing wrong with a UTF-8 string
containing a LATIN SMALL LETTER E WITH ACUTE, but replacing that
character with "&eacute;" will make it more readable to some
people.
"""
return cls.CHARACTER_TO_HTML_ENTITY_RE.sub(
cls._substitute_html_entity, s)
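# Illustrative sketch, not part of upstream bs4: the three substitution
# entry points side by side.
#
#   >>> EntitySubstitution.substitute_xml('1 < 2', make_quoted_attribute=True)
#   '"1 &lt; 2"'
#   >>> EntitySubstitution.substitute_xml_containing_entities('&amp; & x')
#   '&amp; &amp; x'
#   >>> EntitySubstitution.substitute_html(u'caf\xe9')
#   u'caf&eacute;'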
class EncodingDetector:
"""Suggests a number of possible encodings for a bytestring.
Order of precedence:
1. Encodings you specifically tell EncodingDetector to try first
(the override_encodings argument to the constructor).
2. An encoding declared within the bytestring itself, either in an
XML declaration (if the bytestring is to be interpreted as an XML
document), or in a <meta> tag (if the bytestring is to be
interpreted as an HTML document.)
3. An encoding detected through textual analysis by chardet,
cchardet, or a similar external library.
4. UTF-8.
5. Windows-1252.
"""
def __init__(self, markup, override_encodings=None, is_html=False,
exclude_encodings=None):
self.override_encodings = override_encodings or []
exclude_encodings = exclude_encodings or []
self.exclude_encodings = set([x.lower() for x in exclude_encodings])
self.chardet_encoding = None
self.is_html = is_html
self.declared_encoding = None
# First order of business: strip a byte-order mark.
self.markup, self.sniffed_encoding = self.strip_byte_order_mark(markup)
def _usable(self, encoding, tried):
if encoding is not None:
encoding = encoding.lower()
if encoding in self.exclude_encodings:
return False
if encoding not in tried:
tried.add(encoding)
return True
return False
@property
def encodings(self):
"""Yield a number of encodings that might work for this markup."""
tried = set()
for e in self.override_encodings:
if self._usable(e, tried):
yield e
# Did the document originally start with a byte-order mark
# that indicated its encoding?
if self._usable(self.sniffed_encoding, tried):
yield self.sniffed_encoding
# Look within the document for an XML or HTML encoding
# declaration.
if self.declared_encoding is None:
self.declared_encoding = self.find_declared_encoding(
self.markup, self.is_html)
if self._usable(self.declared_encoding, tried):
yield self.declared_encoding
# Use third-party character set detection to guess at the
# encoding.
if self.chardet_encoding is None:
self.chardet_encoding = chardet_dammit(self.markup)
if self._usable(self.chardet_encoding, tried):
yield self.chardet_encoding
# As a last-ditch effort, try utf-8 and windows-1252.
for e in ('utf-8', 'windows-1252'):
if self._usable(e, tried):
yield e
@classmethod
def strip_byte_order_mark(cls, data):
"""If a byte-order mark is present, strip it and return the encoding it implies."""
encoding = None
if isinstance(data, unicode):
# Unicode data cannot have a byte-order mark.
return data, encoding
if (len(data) >= 4) and (data[:2] == b'\xfe\xff') \
and (data[2:4] != '\x00\x00'):
encoding = 'utf-16be'
data = data[2:]
elif (len(data) >= 4) and (data[:2] == b'\xff\xfe') \
and (data[2:4] != '\x00\x00'):
encoding = 'utf-16le'
data = data[2:]
elif data[:3] == b'\xef\xbb\xbf':
encoding = 'utf-8'
data = data[3:]
elif data[:4] == b'\x00\x00\xfe\xff':
encoding = 'utf-32be'
data = data[4:]
elif data[:4] == b'\xff\xfe\x00\x00':
encoding = 'utf-32le'
data = data[4:]
return data, encoding
@classmethod
def find_declared_encoding(cls, markup, is_html=False, search_entire_document=False):
"""Given a document, tries to find its declared encoding.
An XML encoding is declared at the beginning of the document.
An HTML encoding is declared in a <meta> tag, hopefully near the
beginning of the document.
"""
if search_entire_document:
xml_endpos = html_endpos = len(markup)
else:
xml_endpos = 1024
html_endpos = max(2048, int(len(markup) * 0.05))
declared_encoding = None
declared_encoding_match = xml_encoding_re.search(markup, endpos=xml_endpos)
if not declared_encoding_match and is_html:
declared_encoding_match = html_meta_re.search(markup, endpos=html_endpos)
if declared_encoding_match is not None:
declared_encoding = declared_encoding_match.groups()[0].decode(
'ascii', 'replace')
if declared_encoding:
return declared_encoding.lower()
return None
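# Illustrative sketch, not part of upstream bs4: candidate encodings come
# out in the order of precedence documented in the class docstring.
#
#   >>> data = b'<?xml version="1.0" encoding="latin-1"?><r/>'
#   >>> detector = EncodingDetector(data, override_encodings=['utf-16'])
#   >>> list(detector.encodings)[:2]
#   ['utf-16', 'latin-1']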
class UnicodeDammit:
"""A class for detecting the encoding of a *ML document and
converting it to a Unicode string. If the source encoding is
windows-1252, can replace MS smart quotes with their HTML or XML
equivalents."""
# This dictionary maps commonly seen values for "charset" in HTML
# meta tags to the corresponding Python codec names. It only covers
# values that aren't in Python's aliases and can't be determined
# by the heuristics in find_codec.
CHARSET_ALIASES = {"macintosh": "mac-roman",
"x-sjis": "shift-jis"}
ENCODINGS_WITH_SMART_QUOTES = [
"windows-1252",
"iso-8859-1",
"iso-8859-2",
]
def __init__(self, markup, override_encodings=[],
smart_quotes_to=None, is_html=False, exclude_encodings=[]):
self.smart_quotes_to = smart_quotes_to
self.tried_encodings = []
self.contains_replacement_characters = False
self.is_html = is_html
self.detector = EncodingDetector(
markup, override_encodings, is_html, exclude_encodings)
# Short-circuit if the data is in Unicode to begin with.
if isinstance(markup, unicode) or markup == '':
self.markup = markup
self.unicode_markup = unicode(markup)
self.original_encoding = None
return
# The encoding detector may have stripped a byte-order mark.
# Use the stripped markup from this point on.
self.markup = self.detector.markup
u = None
for encoding in self.detector.encodings:
markup = self.detector.markup
u = self._convert_from(encoding)
if u is not None:
break
if not u:
# None of the encodings worked. As an absolute last resort,
# try them again with character replacement.
for encoding in self.detector.encodings:
if encoding != "ascii":
u = self._convert_from(encoding, "replace")
if u is not None:
logging.warning(
"Some characters could not be decoded, and were "
"replaced with REPLACEMENT CHARACTER.")
self.contains_replacement_characters = True
break
# If none of that worked, we could at this point force it to
# ASCII, but that would destroy so much data that I think
# giving up is better.
self.unicode_markup = u
if not u:
self.original_encoding = None
def _sub_ms_char(self, match):
"""Changes a MS smart quote character to an XML or HTML
entity, or an ASCII character."""
orig = match.group(1)
if self.smart_quotes_to == 'ascii':
sub = self.MS_CHARS_TO_ASCII.get(orig).encode()
else:
sub = self.MS_CHARS.get(orig)
if type(sub) == tuple:
if self.smart_quotes_to == 'xml':
sub = '&#x'.encode() + sub[1].encode() + ';'.encode()
else:
sub = '&'.encode() + sub[0].encode() + ';'.encode()
else:
sub = sub.encode()
return sub
def _convert_from(self, proposed, errors="strict"):
proposed = self.find_codec(proposed)
if not proposed or (proposed, errors) in self.tried_encodings:
return None
self.tried_encodings.append((proposed, errors))
markup = self.markup
# Convert smart quotes to HTML if coming from an encoding
# that might have them.
if (self.smart_quotes_to is not None
and proposed in self.ENCODINGS_WITH_SMART_QUOTES):
smart_quotes_re = b"([\x80-\x9f])"
smart_quotes_compiled = re.compile(smart_quotes_re)
markup = smart_quotes_compiled.sub(self._sub_ms_char, markup)
try:
#print "Trying to convert document to %s (errors=%s)" % (
# proposed, errors)
u = self._to_unicode(markup, proposed, errors)
self.markup = u
self.original_encoding = proposed
except Exception as e:
#print "That didn't work!"
#print e
return None
#print "Correct encoding: %s" % proposed
return self.markup
def _to_unicode(self, data, encoding, errors="strict"):
'''Given a string and its encoding, decodes the string into Unicode.
%encoding is a string recognized by encodings.aliases'''
return unicode(data, encoding, errors)
@property
def declared_html_encoding(self):
if not self.is_html:
return None
return self.detector.declared_encoding
def find_codec(self, charset):
value = (self._codec(self.CHARSET_ALIASES.get(charset, charset))
or (charset and self._codec(charset.replace("-", "")))
or (charset and self._codec(charset.replace("-", "_")))
or (charset and charset.lower())
or charset
)
if value:
return value.lower()
return None
def _codec(self, charset):
if not charset:
return charset
codec = None
try:
codecs.lookup(charset)
codec = charset
except (LookupError, ValueError):
pass
return codec
# A partial mapping of ISO-Latin-1 to HTML entities/XML numeric entities.
MS_CHARS = {b'\x80': ('euro', '20AC'),
b'\x81': ' ',
b'\x82': ('sbquo', '201A'),
b'\x83': ('fnof', '192'),
b'\x84': ('bdquo', '201E'),
b'\x85': ('hellip', '2026'),
b'\x86': ('dagger', '2020'),
b'\x87': ('Dagger', '2021'),
b'\x88': ('circ', '2C6'),
b'\x89': ('permil', '2030'),
b'\x8A': ('Scaron', '160'),
b'\x8B': ('lsaquo', '2039'),
b'\x8C': ('OElig', '152'),
b'\x8D': '?',
b'\x8E': ('#x17D', '17D'),
b'\x8F': '?',
b'\x90': '?',
b'\x91': ('lsquo', '2018'),
b'\x92': ('rsquo', '2019'),
b'\x93': ('ldquo', '201C'),
b'\x94': ('rdquo', '201D'),
b'\x95': ('bull', '2022'),
b'\x96': ('ndash', '2013'),
b'\x97': ('mdash', '2014'),
b'\x98': ('tilde', '2DC'),
b'\x99': ('trade', '2122'),
b'\x9a': ('scaron', '161'),
b'\x9b': ('rsaquo', '203A'),
b'\x9c': ('oelig', '153'),
b'\x9d': '?',
b'\x9e': ('#x17E', '17E'),
b'\x9f': ('Yuml', '178'),}
# A parochial partial mapping of ISO-Latin-1 to ASCII. Contains
# horrors like stripping diacritical marks to turn á into a, but also
# contains non-horrors like turning “ into ".
MS_CHARS_TO_ASCII = {
b'\x80' : 'EUR',
b'\x81' : ' ',
b'\x82' : ',',
b'\x83' : 'f',
b'\x84' : ',,',
b'\x85' : '...',
b'\x86' : '+',
b'\x87' : '++',
b'\x88' : '^',
b'\x89' : '%',
b'\x8a' : 'S',
b'\x8b' : '<',
b'\x8c' : 'OE',
b'\x8d' : '?',
b'\x8e' : 'Z',
b'\x8f' : '?',
b'\x90' : '?',
b'\x91' : "'",
b'\x92' : "'",
b'\x93' : '"',
b'\x94' : '"',
b'\x95' : '*',
b'\x96' : '-',
b'\x97' : '--',
b'\x98' : '~',
b'\x99' : '(TM)',
b'\x9a' : 's',
b'\x9b' : '>',
b'\x9c' : 'oe',
b'\x9d' : '?',
b'\x9e' : 'z',
b'\x9f' : 'Y',
b'\xa0' : ' ',
b'\xa1' : '!',
b'\xa2' : 'c',
b'\xa3' : 'GBP',
b'\xa4' : '$', #This approximation is especially parochial--this is the
#generic currency symbol.
b'\xa5' : 'YEN',
b'\xa6' : '|',
b'\xa7' : 'S',
b'\xa8' : '..',
b'\xa9' : '',
b'\xaa' : '(th)',
b'\xab' : '<<',
b'\xac' : '!',
b'\xad' : ' ',
b'\xae' : '(R)',
b'\xaf' : '-',
b'\xb0' : 'o',
b'\xb1' : '+-',
b'\xb2' : '2',
b'\xb3' : '3',
b'\xb4' : ("'", 'acute'),
b'\xb5' : 'u',
b'\xb6' : 'P',
b'\xb7' : '*',
b'\xb8' : ',',
b'\xb9' : '1',
b'\xba' : '(th)',
b'\xbb' : '>>',
b'\xbc' : '1/4',
b'\xbd' : '1/2',
b'\xbe' : '3/4',
b'\xbf' : '?',
b'\xc0' : 'A',
b'\xc1' : 'A',
b'\xc2' : 'A',
b'\xc3' : 'A',
b'\xc4' : 'A',
b'\xc5' : 'A',
b'\xc6' : 'AE',
b'\xc7' : 'C',
b'\xc8' : 'E',
b'\xc9' : 'E',
b'\xca' : 'E',
b'\xcb' : 'E',
b'\xcc' : 'I',
b'\xcd' : 'I',
b'\xce' : 'I',
b'\xcf' : 'I',
b'\xd0' : 'D',
b'\xd1' : 'N',
b'\xd2' : 'O',
b'\xd3' : 'O',
b'\xd4' : 'O',
b'\xd5' : 'O',
b'\xd6' : 'O',
b'\xd7' : '*',
b'\xd8' : 'O',
b'\xd9' : 'U',
b'\xda' : 'U',
b'\xdb' : 'U',
b'\xdc' : 'U',
b'\xdd' : 'Y',
b'\xde' : 'b',
b'\xdf' : 'B',
b'\xe0' : 'a',
b'\xe1' : 'a',
b'\xe2' : 'a',
b'\xe3' : 'a',
b'\xe4' : 'a',
b'\xe5' : 'a',
b'\xe6' : 'ae',
b'\xe7' : 'c',
b'\xe8' : 'e',
b'\xe9' : 'e',
b'\xea' : 'e',
b'\xeb' : 'e',
b'\xec' : 'i',
b'\xed' : 'i',
b'\xee' : 'i',
b'\xef' : 'i',
b'\xf0' : 'o',
b'\xf1' : 'n',
b'\xf2' : 'o',
b'\xf3' : 'o',
b'\xf4' : 'o',
b'\xf5' : 'o',
b'\xf6' : 'o',
b'\xf7' : '/',
b'\xf8' : 'o',
b'\xf9' : 'u',
b'\xfa' : 'u',
b'\xfb' : 'u',
b'\xfc' : 'u',
b'\xfd' : 'y',
b'\xfe' : 'b',
b'\xff' : 'y',
}
# A map used when removing rogue Windows-1252/ISO-8859-1
# characters in otherwise UTF-8 documents.
#
# Note that \x81, \x8d, \x8f, \x90, and \x9d are undefined in
# Windows-1252.
WINDOWS_1252_TO_UTF8 = {
0x80 : b'\xe2\x82\xac', # €
0x82 : b'\xe2\x80\x9a', # ‚
0x83 : b'\xc6\x92', # ƒ
0x84 : b'\xe2\x80\x9e', # „
0x85 : b'\xe2\x80\xa6', # …
0x86 : b'\xe2\x80\xa0', # †
0x87 : b'\xe2\x80\xa1', # ‡
0x88 : b'\xcb\x86', # ˆ
0x89 : b'\xe2\x80\xb0', # ‰
0x8a : b'\xc5\xa0', # Š
0x8b : b'\xe2\x80\xb9', # ‹
0x8c : b'\xc5\x92', # Œ
0x8e : b'\xc5\xbd', # Ž
0x91 : b'\xe2\x80\x98', # ‘
0x92 : b'\xe2\x80\x99', # ’
0x93 : b'\xe2\x80\x9c', # “
0x94 : b'\xe2\x80\x9d', # ”
0x95 : b'\xe2\x80\xa2', # •
0x96 : b'\xe2\x80\x93', # –
0x97 : b'\xe2\x80\x94', # —
0x98 : b'\xcb\x9c', # ˜
0x99 : b'\xe2\x84\xa2', # ™
0x9a : b'\xc5\xa1', # š
0x9b : b'\xe2\x80\xba', # ›
0x9c : b'\xc5\x93', # œ
0x9e : b'\xc5\xbe', # ž
0x9f : b'\xc5\xb8', # Ÿ
0xa0 : b'\xc2\xa0', #  
0xa1 : b'\xc2\xa1', # ¡
0xa2 : b'\xc2\xa2', # ¢
0xa3 : b'\xc2\xa3', # £
0xa4 : b'\xc2\xa4', # ¤
0xa5 : b'\xc2\xa5', # ¥
0xa6 : b'\xc2\xa6', # ¦
0xa7 : b'\xc2\xa7', # §
0xa8 : b'\xc2\xa8', # ¨
0xa9 : b'\xc2\xa9', # ©
0xaa : b'\xc2\xaa', # ª
0xab : b'\xc2\xab', # «
0xac : b'\xc2\xac', # ¬
0xad : b'\xc2\xad', # ­
0xae : b'\xc2\xae', # ®
0xaf : b'\xc2\xaf', # ¯
0xb0 : b'\xc2\xb0', # °
0xb1 : b'\xc2\xb1', # ±
0xb2 : b'\xc2\xb2', # ²
0xb3 : b'\xc2\xb3', # ³
0xb4 : b'\xc2\xb4', # ´
0xb5 : b'\xc2\xb5', # µ
0xb6 : b'\xc2\xb6', # ¶
0xb7 : b'\xc2\xb7', # ·
0xb8 : b'\xc2\xb8', # ¸
0xb9 : b'\xc2\xb9', # ¹
0xba : b'\xc2\xba', # º
0xbb : b'\xc2\xbb', # »
0xbc : b'\xc2\xbc', # ¼
0xbd : b'\xc2\xbd', # ½
0xbe : b'\xc2\xbe', # ¾
0xbf : b'\xc2\xbf', # ¿
0xc0 : b'\xc3\x80', # À
0xc1 : b'\xc3\x81', # Á
0xc2 : b'\xc3\x82', # Â
0xc3 : b'\xc3\x83', # Ã
0xc4 : b'\xc3\x84', # Ä
0xc5 : b'\xc3\x85', # Å
0xc6 : b'\xc3\x86', # Æ
0xc7 : b'\xc3\x87', # Ç
0xc8 : b'\xc3\x88', # È
0xc9 : b'\xc3\x89', # É
0xca : b'\xc3\x8a', # Ê
0xcb : b'\xc3\x8b', # Ë
0xcc : b'\xc3\x8c', # Ì
0xcd : b'\xc3\x8d', # Í
0xce : b'\xc3\x8e', # Î
0xcf : b'\xc3\x8f', # Ï
0xd0 : b'\xc3\x90', # Ð
0xd1 : b'\xc3\x91', # Ñ
0xd2 : b'\xc3\x92', # Ò
0xd3 : b'\xc3\x93', # Ó
0xd4 : b'\xc3\x94', # Ô
0xd5 : b'\xc3\x95', # Õ
0xd6 : b'\xc3\x96', # Ö
0xd7 : b'\xc3\x97', # ×
0xd8 : b'\xc3\x98', # Ø
0xd9 : b'\xc3\x99', # Ù
0xda : b'\xc3\x9a', # Ú
0xdb : b'\xc3\x9b', # Û
0xdc : b'\xc3\x9c', # Ü
0xdd : b'\xc3\x9d', # Ý
0xde : b'\xc3\x9e', # Þ
0xdf : b'\xc3\x9f', # ß
0xe0 : b'\xc3\xa0', # à
0xe1 : b'\xc3\xa1', # á
0xe2 : b'\xc3\xa2', # â
0xe3 : b'\xc3\xa3', # ã
0xe4 : b'\xc3\xa4', # ä
0xe5 : b'\xc3\xa5', # å
0xe6 : b'\xc3\xa6', # æ
0xe7 : b'\xc3\xa7', # ç
0xe8 : b'\xc3\xa8', # è
0xe9 : b'\xc3\xa9', # é
0xea : b'\xc3\xaa', # ê
0xeb : b'\xc3\xab', # ë
0xec : b'\xc3\xac', # ì
0xed : b'\xc3\xad', # í
0xee : b'\xc3\xae', # î
0xef : b'\xc3\xaf', # ï
0xf0 : b'\xc3\xb0', # ð
0xf1 : b'\xc3\xb1', # ñ
0xf2 : b'\xc3\xb2', # ò
0xf3 : b'\xc3\xb3', # ó
0xf4 : b'\xc3\xb4', # ô
0xf5 : b'\xc3\xb5', # õ
0xf6 : b'\xc3\xb6', # ö
0xf7 : b'\xc3\xb7', # ÷
0xf8 : b'\xc3\xb8', # ø
0xf9 : b'\xc3\xb9', # ù
0xfa : b'\xc3\xba', # ú
0xfb : b'\xc3\xbb', # û
0xfc : b'\xc3\xbc', # ü
0xfd : b'\xc3\xbd', # ý
0xfe : b'\xc3\xbe', # þ
}
MULTIBYTE_MARKERS_AND_SIZES = [
(0xc2, 0xdf, 2), # 2-byte characters start with a byte C2-DF
(0xe0, 0xef, 3), # 3-byte characters start with E0-EF
(0xf0, 0xf4, 4), # 4-byte characters start with F0-F4
]
FIRST_MULTIBYTE_MARKER = MULTIBYTE_MARKERS_AND_SIZES[0][0]
LAST_MULTIBYTE_MARKER = MULTIBYTE_MARKERS_AND_SIZES[-1][1]
@classmethod
def detwingle(cls, in_bytes, main_encoding="utf8",
embedded_encoding="windows-1252"):
"""Fix characters from one encoding embedded in some other encoding.
Currently the only situation supported is Windows-1252 (or its
subset ISO-8859-1), embedded in UTF-8.
The input must be a bytestring. If you've already converted
the document to Unicode, you're too late.
The output is a bytestring in which `embedded_encoding`
characters have been converted to their `main_encoding`
equivalents.
"""
if embedded_encoding.replace('_', '-').lower() not in (
'windows-1252', 'windows_1252'):
raise NotImplementedError(
"Windows-1252 and ISO-8859-1 are the only currently supported "
"embedded encodings.")
if main_encoding.lower() not in ('utf8', 'utf-8'):
raise NotImplementedError(
"UTF-8 is the only currently supported main encoding.")
byte_chunks = []
chunk_start = 0
pos = 0
while pos < len(in_bytes):
byte = in_bytes[pos]
if not isinstance(byte, int):
# Python 2.x
byte = ord(byte)
if (byte >= cls.FIRST_MULTIBYTE_MARKER
and byte <= cls.LAST_MULTIBYTE_MARKER):
# This is the start of a UTF-8 multibyte character. Skip
# to the end.
for start, end, size in cls.MULTIBYTE_MARKERS_AND_SIZES:
if byte >= start and byte <= end:
pos += size
break
elif byte >= 0x80 and byte in cls.WINDOWS_1252_TO_UTF8:
# We found a Windows-1252 character!
# Save the string up to this point as a chunk.
byte_chunks.append(in_bytes[chunk_start:pos])
# Now translate the Windows-1252 character into UTF-8
# and add it as another, one-byte chunk.
byte_chunks.append(cls.WINDOWS_1252_TO_UTF8[byte])
pos += 1
chunk_start = pos
else:
# Go on to the next character.
pos += 1
if chunk_start == 0:
# The string is unchanged.
return in_bytes
else:
# Store the final chunk.
byte_chunks.append(in_bytes[chunk_start:])
return b''.join(byte_chunks)
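The translation table and detwingle() above are easiest to understand from the calling side. A minimal sketch of how detwingle() would be used (hypothetical snippet; assumes this vendored bs4 is importable):

from bs4 import UnicodeDammit

snowmen = (u"\N{SNOWMAN}" * 3).encode("utf8")
quote = (u"\N{LEFT DOUBLE QUOTATION MARK}I like snowmen!"
         u"\N{RIGHT DOUBLE QUOTATION MARK}").encode("windows_1252")
doc = snowmen + quote                 # two encodings mixed in one bytestring
fixed = UnicodeDammit.detwingle(doc)  # smart quotes become valid UTF-8 bytes
print fixed.decode("utf8")            # Python 2 print, matching this codebase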

lib/bs4/diagnose.py Normal file

@@ -0,0 +1,213 @@
"""Diagnostic functions, mainly for use when doing tech support."""
import cProfile
from StringIO import StringIO
from HTMLParser import HTMLParser
import bs4
from bs4 import BeautifulSoup, __version__
from bs4.builder import builder_registry
import os
import pstats
import random
import tempfile
import time
import traceback
import sys
def diagnose(data):
"""Diagnostic suite for isolating common problems."""
print "Diagnostic running on Beautiful Soup %s" % __version__
print "Python version %s" % sys.version
basic_parsers = ["html.parser", "html5lib", "lxml"]
for name in basic_parsers:
for builder in builder_registry.builders:
if name in builder.features:
break
else:
basic_parsers.remove(name)
print (
"I noticed that %s is not installed. Installing it may help." %
name)
if 'lxml' in basic_parsers:
basic_parsers.append(["lxml", "xml"])
try:
from lxml import etree
print "Found lxml version %s" % ".".join(map(str,etree.LXML_VERSION))
except ImportError, e:
print (
"lxml is not installed or couldn't be imported.")
if 'html5lib' in basic_parsers:
try:
import html5lib
print "Found html5lib version %s" % html5lib.__version__
except ImportError, e:
print (
"html5lib is not installed or couldn't be imported.")
if hasattr(data, 'read'):
data = data.read()
elif os.path.exists(data):
print '"%s" looks like a filename. Reading data from the file.' % data
data = open(data).read()
elif data.startswith("http:") or data.startswith("https:"):
print '"%s" looks like a URL. Beautiful Soup is not an HTTP client.' % data
print "You need to use some other library to get the document behind the URL, and feed that document to Beautiful Soup."
return
print
for parser in basic_parsers:
print "Trying to parse your markup with %s" % parser
success = False
try:
soup = BeautifulSoup(data, parser)
success = True
except Exception, e:
print "%s could not parse the markup." % parser
traceback.print_exc()
if success:
print "Here's what %s did with the markup:" % parser
print soup.prettify()
print "-" * 80
def lxml_trace(data, html=True, **kwargs):
"""Print out the lxml events that occur during parsing.
This lets you see how lxml parses a document when no Beautiful
Soup code is running.
"""
from lxml import etree
for event, element in etree.iterparse(StringIO(data), html=html, **kwargs):
print("%s, %4s, %s" % (event, element.tag, element.text))
class AnnouncingParser(HTMLParser):
"""Announces HTMLParser parse events, without doing anything else."""
def _p(self, s):
print(s)
def handle_starttag(self, name, attrs):
self._p("%s START" % name)
def handle_endtag(self, name):
self._p("%s END" % name)
def handle_data(self, data):
self._p("%s DATA" % data)
def handle_charref(self, name):
self._p("%s CHARREF" % name)
def handle_entityref(self, name):
self._p("%s ENTITYREF" % name)
def handle_comment(self, data):
self._p("%s COMMENT" % data)
def handle_decl(self, data):
self._p("%s DECL" % data)
def unknown_decl(self, data):
self._p("%s UNKNOWN-DECL" % data)
def handle_pi(self, data):
self._p("%s PI" % data)
def htmlparser_trace(data):
"""Print out the HTMLParser events that occur during parsing.
This lets you see how HTMLParser parses a document when no
Beautiful Soup code is running.
"""
parser = AnnouncingParser()
parser.feed(data)
_vowels = "aeiou"
_consonants = "bcdfghjklmnpqrstvwxyz"
def rword(length=5):
"Generate a random word-like string."
s = ''
for i in range(length):
if i % 2 == 0:
t = _consonants
else:
t = _vowels
s += random.choice(t)
return s
def rsentence(length=4):
"Generate a random sentence-like string."
return " ".join(rword(random.randint(4,9)) for i in range(length))
def rdoc(num_elements=1000):
"""Randomly generate an invalid HTML document."""
tag_names = ['p', 'div', 'span', 'i', 'b', 'script', 'table']
elements = []
for i in range(num_elements):
choice = random.randint(0,3)
if choice == 0:
# New tag.
tag_name = random.choice(tag_names)
elements.append("<%s>" % tag_name)
elif choice == 1:
elements.append(rsentence(random.randint(1,4)))
elif choice == 2:
# Close a tag.
tag_name = random.choice(tag_names)
elements.append("</%s>" % tag_name)
return "<html>" + "\n".join(elements) + "</html>"
def benchmark_parsers(num_elements=100000):
"""Very basic head-to-head performance benchmark."""
print "Comparative parser benchmark on Beautiful Soup %s" % __version__
data = rdoc(num_elements)
print "Generated a large invalid HTML document (%d bytes)." % len(data)
for parser in ["lxml", ["lxml", "html"], "html5lib", "html.parser"]:
success = False
try:
a = time.time()
soup = BeautifulSoup(data, parser)
b = time.time()
success = True
except Exception, e:
print "%s could not parse the markup." % parser
traceback.print_exc()
if success:
print "BS4+%s parsed the markup in %.2fs." % (parser, b-a)
from lxml import etree
a = time.time()
etree.HTML(data)
b = time.time()
print "Raw lxml parsed the markup in %.2fs." % (b-a)
import html5lib
parser = html5lib.HTMLParser()
a = time.time()
parser.parse(data)
b = time.time()
print "Raw html5lib parsed the markup in %.2fs." % (b-a)
def profile(num_elements=100000, parser="lxml"):
filehandle = tempfile.NamedTemporaryFile()
filename = filehandle.name
data = rdoc(num_elements)
vars = dict(bs4=bs4, data=data, parser=parser)
cProfile.runctx('bs4.BeautifulSoup(data, parser)' , vars, vars, filename)
stats = pstats.Stats(filename)
# stats.strip_dirs()
stats.sort_stats("cumulative")
stats.print_stats('_html5lib|bs4', 50)
if __name__ == '__main__':
diagnose(sys.stdin.read())
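As a rough illustration, the module is meant to be driven either from a shell pipe or from an interpreter session (a hypothetical invocation; the file name is made up):

# From a shell, feed the problem markup on stdin:
#   python -m bs4.diagnose < bad.html
# Or interactively:
from bs4.diagnose import diagnose
diagnose(open("bad.html").read())  # prints parser availability and per-parser results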

lib/bs4/element.py Normal file

File diff suppressed because it is too large.

lib/bs4/testing.py Normal file

@@ -0,0 +1,680 @@
"""Helper classes for tests."""
import pickle
import copy
import functools
import unittest
from unittest import TestCase
from bs4 import BeautifulSoup
from bs4.element import (
CharsetMetaAttributeValue,
Comment,
ContentMetaAttributeValue,
Doctype,
SoupStrainer,
)
from bs4.builder import HTMLParserTreeBuilder
default_builder = HTMLParserTreeBuilder
class SoupTest(unittest.TestCase):
@property
def default_builder(self):
return default_builder()
def soup(self, markup, **kwargs):
"""Build a Beautiful Soup object from markup."""
builder = kwargs.pop('builder', self.default_builder)
return BeautifulSoup(markup, builder=builder, **kwargs)
def document_for(self, markup):
"""Turn an HTML fragment into a document.
The details depend on the builder.
"""
return self.default_builder.test_fragment_to_document(markup)
def assertSoupEquals(self, to_parse, compare_parsed_to=None):
builder = self.default_builder
obj = BeautifulSoup(to_parse, builder=builder)
if compare_parsed_to is None:
compare_parsed_to = to_parse
self.assertEqual(obj.decode(), self.document_for(compare_parsed_to))
def assertConnectedness(self, element):
"""Ensure that next_element and previous_element are properly
set for all descendants of the given element.
"""
earlier = None
for e in element.descendants:
if earlier:
self.assertEqual(e, earlier.next_element)
self.assertEqual(earlier, e.previous_element)
earlier = e
class HTMLTreeBuilderSmokeTest(object):
"""A basic test of a treebuilder's competence.
Any HTML treebuilder, present or future, should be able to pass
these tests. With invalid markup, there's room for interpretation,
and different parsers can handle it differently. But with the
markup in these tests, there's not much room for interpretation.
"""
def test_pickle_and_unpickle_identity(self):
# Pickling a tree, then unpickling it, yields a tree identical
# to the original.
tree = self.soup("<a><b>foo</a>")
dumped = pickle.dumps(tree, 2)
loaded = pickle.loads(dumped)
self.assertEqual(loaded.__class__, BeautifulSoup)
self.assertEqual(loaded.decode(), tree.decode())
def assertDoctypeHandled(self, doctype_fragment):
"""Assert that a given doctype string is handled correctly."""
doctype_str, soup = self._document_with_doctype(doctype_fragment)
# Make sure a Doctype object was created.
doctype = soup.contents[0]
self.assertEqual(doctype.__class__, Doctype)
self.assertEqual(doctype, doctype_fragment)
self.assertEqual(str(soup)[:len(doctype_str)], doctype_str)
# Make sure that the doctype was correctly associated with the
# parse tree and that the rest of the document parsed.
self.assertEqual(soup.p.contents[0], 'foo')
def _document_with_doctype(self, doctype_fragment):
"""Generate and parse a document with the given doctype."""
doctype = '<!DOCTYPE %s>' % doctype_fragment
markup = doctype + '\n<p>foo</p>'
soup = self.soup(markup)
return doctype, soup
def test_normal_doctypes(self):
"""Make sure normal, everyday HTML doctypes are handled correctly."""
self.assertDoctypeHandled("html")
self.assertDoctypeHandled(
'html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN"')
def test_empty_doctype(self):
soup = self.soup("<!DOCTYPE>")
doctype = soup.contents[0]
self.assertEqual("", doctype.strip())
def test_public_doctype_with_url(self):
doctype = 'html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"'
self.assertDoctypeHandled(doctype)
def test_system_doctype(self):
self.assertDoctypeHandled('foo SYSTEM "http://www.example.com/"')
def test_namespaced_system_doctype(self):
# We can handle a namespaced doctype with a system ID.
self.assertDoctypeHandled('xsl:stylesheet SYSTEM "htmlent.dtd"')
def test_namespaced_public_doctype(self):
# Test a namespaced doctype with a public id.
self.assertDoctypeHandled('xsl:stylesheet PUBLIC "htmlent.dtd"')
def test_real_xhtml_document(self):
"""A real XHTML document should come out more or less the same as it went in."""
markup = b"""<?xml version="1.0" encoding="utf-8"?>
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN">
<html xmlns="http://www.w3.org/1999/xhtml">
<head><title>Hello.</title></head>
<body>Goodbye.</body>
</html>"""
soup = self.soup(markup)
self.assertEqual(
soup.encode("utf-8").replace(b"\n", b""),
markup.replace(b"\n", b""))
def test_processing_instruction(self):
markup = b"""<?PITarget PIContent?>"""
soup = self.soup(markup)
self.assertEqual(markup, soup.encode("utf8"))
def test_deepcopy(self):
"""Make sure you can copy the tree builder.
This is important because the builder is part of a
BeautifulSoup object, and we want to be able to copy that.
"""
copy.deepcopy(self.default_builder)
def test_p_tag_is_never_empty_element(self):
"""A <p> tag is never designated as an empty-element tag.
Even if the markup shows it as an empty-element tag, it
shouldn't be presented that way.
"""
soup = self.soup("<p/>")
self.assertFalse(soup.p.is_empty_element)
self.assertEqual(str(soup.p), "<p></p>")
def test_unclosed_tags_get_closed(self):
"""A tag that's not closed by the end of the document should be closed.
This applies to all tags except empty-element tags.
"""
self.assertSoupEquals("<p>", "<p></p>")
self.assertSoupEquals("<b>", "<b></b>")
self.assertSoupEquals("<br>", "<br/>")
def test_br_is_always_empty_element_tag(self):
"""A <br> tag is designated as an empty-element tag.
Some parsers treat <br></br> as one <br/> tag, some parsers as
two tags, but it should always be an empty-element tag.
"""
soup = self.soup("<br></br>")
self.assertTrue(soup.br.is_empty_element)
self.assertEqual(str(soup.br), "<br/>")
def test_nested_formatting_elements(self):
self.assertSoupEquals("<em><em></em></em>")
def test_double_head(self):
html = '''<!DOCTYPE html>
<html>
<head>
<title>Ordinary HEAD element test</title>
</head>
<script type="text/javascript">
alert("Help!");
</script>
<body>
Hello, world!
</body>
</html>
'''
soup = self.soup(html)
self.assertEqual("text/javascript", soup.find('script')['type'])
def test_comment(self):
# Comments are represented as Comment objects.
markup = "<p>foo<!--foobar-->baz</p>"
self.assertSoupEquals(markup)
soup = self.soup(markup)
comment = soup.find(text="foobar")
self.assertEqual(comment.__class__, Comment)
# The comment is properly integrated into the tree.
foo = soup.find(text="foo")
self.assertEqual(comment, foo.next_element)
baz = soup.find(text="baz")
self.assertEqual(comment, baz.previous_element)
def test_preserved_whitespace_in_pre_and_textarea(self):
"""Whitespace must be preserved in <pre> and <textarea> tags."""
self.assertSoupEquals("<pre> </pre>")
self.assertSoupEquals("<textarea> woo </textarea>")
def test_nested_inline_elements(self):
"""Inline elements can be nested indefinitely."""
b_tag = "<b>Inside a B tag</b>"
self.assertSoupEquals(b_tag)
nested_b_tag = "<p>A <i>nested <b>tag</b></i></p>"
self.assertSoupEquals(nested_b_tag)
double_nested_b_tag = "<p>A <a>doubly <i>nested <b>tag</b></i></a></p>"
self.assertSoupEquals(double_nested_b_tag)
def test_nested_block_level_elements(self):
"""Block elements can be nested."""
soup = self.soup('<blockquote><p><b>Foo</b></p></blockquote>')
blockquote = soup.blockquote
self.assertEqual(blockquote.p.b.string, 'Foo')
self.assertEqual(blockquote.b.string, 'Foo')
def test_correctly_nested_tables(self):
"""One table can go inside another one."""
markup = ('<table id="1">'
'<tr>'
"<td>Here's another table:"
'<table id="2">'
'<tr><td>foo</td></tr>'
'</table></td>')
self.assertSoupEquals(
markup,
'<table id="1"><tr><td>Here\'s another table:'
'<table id="2"><tr><td>foo</td></tr></table>'
'</td></tr></table>')
self.assertSoupEquals(
"<table><thead><tr><td>Foo</td></tr></thead>"
"<tbody><tr><td>Bar</td></tr></tbody>"
"<tfoot><tr><td>Baz</td></tr></tfoot></table>")
def test_deeply_nested_multivalued_attribute(self):
# html5lib can set the attributes of the same tag many times
# as it rearranges the tree. This has caused problems with
# multivalued attributes.
markup = '<table><div><div class="css"></div></div></table>'
soup = self.soup(markup)
self.assertEqual(["css"], soup.div.div['class'])
def test_multivalued_attribute_on_html(self):
# html5lib uses a different API to set the attributes ot the
# <html> tag. This has caused problems with multivalued
# attributes.
markup = '<html class="a b"></html>'
soup = self.soup(markup)
self.assertEqual(["a", "b"], soup.html['class'])
def test_angle_brackets_in_attribute_values_are_escaped(self):
self.assertSoupEquals('<a b="<a>"></a>', '<a b="&lt;a&gt;"></a>')
def test_entities_in_attributes_converted_to_unicode(self):
expect = u'<p id="pi\N{LATIN SMALL LETTER N WITH TILDE}ata"></p>'
self.assertSoupEquals('<p id="pi&#241;ata"></p>', expect)
self.assertSoupEquals('<p id="pi&#xf1;ata"></p>', expect)
self.assertSoupEquals('<p id="pi&#Xf1;ata"></p>', expect)
self.assertSoupEquals('<p id="pi&ntilde;ata"></p>', expect)
def test_entities_in_text_converted_to_unicode(self):
expect = u'<p>pi\N{LATIN SMALL LETTER N WITH TILDE}ata</p>'
self.assertSoupEquals("<p>pi&#241;ata</p>", expect)
self.assertSoupEquals("<p>pi&#xf1;ata</p>", expect)
self.assertSoupEquals("<p>pi&#Xf1;ata</p>", expect)
self.assertSoupEquals("<p>pi&ntilde;ata</p>", expect)
def test_quot_entity_converted_to_quotation_mark(self):
self.assertSoupEquals("<p>I said &quot;good day!&quot;</p>",
'<p>I said "good day!"</p>')
def test_out_of_range_entity(self):
expect = u"\N{REPLACEMENT CHARACTER}"
self.assertSoupEquals("&#10000000000000;", expect)
self.assertSoupEquals("&#x10000000000000;", expect)
self.assertSoupEquals("&#1000000000;", expect)
def test_multipart_strings(self):
"Mostly to prevent a recurrence of a bug in the html5lib treebuilder."
soup = self.soup("<html><h2>\nfoo</h2><p></p></html>")
self.assertEqual("p", soup.h2.string.next_element.name)
self.assertEqual("p", soup.p.name)
self.assertConnectedness(soup)
def test_head_tag_between_head_and_body(self):
"Prevent recurrence of a bug in the html5lib treebuilder."
content = """<html><head></head>
<link></link>
<body>foo</body>
</html>
"""
soup = self.soup(content)
self.assertNotEqual(None, soup.html.body)
self.assertConnectedness(soup)
def test_multiple_copies_of_a_tag(self):
"Prevent recurrence of a bug in the html5lib treebuilder."
content = """<!DOCTYPE html>
<html>
<body>
<article id="a" >
<div><a href="1"></div>
<footer>
<a href="2"></a>
</footer>
</article>
</body>
</html>
"""
soup = self.soup(content)
self.assertConnectedness(soup.article)
def test_basic_namespaces(self):
"""Parsers don't need to *understand* namespaces, but at the
very least they should not choke on namespaces or lose
data."""
markup = b'<html xmlns="http://www.w3.org/1999/xhtml" xmlns:mathml="http://www.w3.org/1998/Math/MathML" xmlns:svg="http://www.w3.org/2000/svg"><head></head><body><mathml:msqrt>4</mathml:msqrt><b svg:fill="red"></b></body></html>'
soup = self.soup(markup)
self.assertEqual(markup, soup.encode())
html = soup.html
self.assertEqual('http://www.w3.org/1999/xhtml', soup.html['xmlns'])
self.assertEqual(
'http://www.w3.org/1998/Math/MathML', soup.html['xmlns:mathml'])
self.assertEqual(
'http://www.w3.org/2000/svg', soup.html['xmlns:svg'])
def test_multivalued_attribute_value_becomes_list(self):
markup = b'<a class="foo bar">'
soup = self.soup(markup)
self.assertEqual(['foo', 'bar'], soup.a['class'])
#
# Generally speaking, tests below this point are more tests of
# Beautiful Soup than tests of the tree builders. But parsers are
# weird, so we run these tests separately for every tree builder
# to detect any differences between them.
#
def test_can_parse_unicode_document(self):
# A seemingly innocuous document... but it's in Unicode! And
# it contains characters that can't be represented in the
# encoding found in the declaration! The horror!
markup = u'<html><head><meta encoding="euc-jp"></head><body>Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!</body>'
soup = self.soup(markup)
self.assertEqual(u'Sacr\xe9 bleu!', soup.body.string)
def test_soupstrainer(self):
"""Parsers should be able to work with SoupStrainers."""
strainer = SoupStrainer("b")
soup = self.soup("A <b>bold</b> <meta/> <i>statement</i>",
parse_only=strainer)
self.assertEqual(soup.decode(), "<b>bold</b>")
def test_single_quote_attribute_values_become_double_quotes(self):
self.assertSoupEquals("<foo attr='bar'></foo>",
'<foo attr="bar"></foo>')
def test_attribute_values_with_nested_quotes_are_left_alone(self):
text = """<foo attr='bar "brawls" happen'>a</foo>"""
self.assertSoupEquals(text)
def test_attribute_values_with_double_nested_quotes_get_quoted(self):
text = """<foo attr='bar "brawls" happen'>a</foo>"""
soup = self.soup(text)
soup.foo['attr'] = 'Brawls happen at "Bob\'s Bar"'
self.assertSoupEquals(
soup.foo.decode(),
"""<foo attr="Brawls happen at &quot;Bob\'s Bar&quot;">a</foo>""")
def test_ampersand_in_attribute_value_gets_escaped(self):
self.assertSoupEquals('<this is="really messed up & stuff"></this>',
'<this is="really messed up &amp; stuff"></this>')
self.assertSoupEquals(
'<a href="http://example.org?a=1&b=2;3">foo</a>',
'<a href="http://example.org?a=1&amp;b=2;3">foo</a>')
def test_escaped_ampersand_in_attribute_value_is_left_alone(self):
self.assertSoupEquals('<a href="http://example.org?a=1&amp;b=2;3"></a>')
def test_entities_in_strings_converted_during_parsing(self):
# Both XML and HTML entities are converted to Unicode characters
# during parsing.
text = "<p>&lt;&lt;sacr&eacute;&#32;bleu!&gt;&gt;</p>"
expected = u"<p>&lt;&lt;sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!&gt;&gt;</p>"
self.assertSoupEquals(text, expected)
def test_smart_quotes_converted_on_the_way_in(self):
# Microsoft smart quotes are converted to Unicode characters during
# parsing.
quote = b"<p>\x91Foo\x92</p>"
soup = self.soup(quote)
self.assertEqual(
soup.p.string,
u"\N{LEFT SINGLE QUOTATION MARK}Foo\N{RIGHT SINGLE QUOTATION MARK}")
def test_non_breaking_spaces_converted_on_the_way_in(self):
soup = self.soup("<a>&nbsp;&nbsp;</a>")
self.assertEqual(soup.a.string, u"\N{NO-BREAK SPACE}" * 2)
def test_entities_converted_on_the_way_out(self):
text = "<p>&lt;&lt;sacr&eacute;&#32;bleu!&gt;&gt;</p>"
expected = u"<p>&lt;&lt;sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!&gt;&gt;</p>".encode("utf-8")
soup = self.soup(text)
self.assertEqual(soup.p.encode("utf-8"), expected)
def test_real_iso_latin_document(self):
# Smoke test of interrelated functionality, using an
# easy-to-understand document.
# Here it is in Unicode. Note that it claims to be in ISO-Latin-1.
unicode_html = u'<html><head><meta content="text/html; charset=ISO-Latin-1" http-equiv="Content-type"/></head><body><p>Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!</p></body></html>'
# That's because we're going to encode it into ISO-Latin-1, and use
# that to test.
iso_latin_html = unicode_html.encode("iso-8859-1")
# Parse the ISO-Latin-1 HTML.
soup = self.soup(iso_latin_html)
# Encode it to UTF-8.
result = soup.encode("utf-8")
# What do we expect the result to look like? Well, it would
# look like unicode_html, except that the META tag would say
# UTF-8 instead of ISO-Latin-1.
expected = unicode_html.replace("ISO-Latin-1", "utf-8")
# And, of course, it would be in UTF-8, not Unicode.
expected = expected.encode("utf-8")
# Ta-da!
self.assertEqual(result, expected)
def test_real_shift_jis_document(self):
# Smoke test to make sure the parser can handle a document in
# Shift-JIS encoding, without choking.
shift_jis_html = (
b'<html><head></head><body><pre>'
b'\x82\xb1\x82\xea\x82\xcdShift-JIS\x82\xc5\x83R\x81[\x83f'
b'\x83B\x83\x93\x83O\x82\xb3\x82\xea\x82\xbd\x93\xfa\x96{\x8c'
b'\xea\x82\xcc\x83t\x83@\x83C\x83\x8b\x82\xc5\x82\xb7\x81B'
b'</pre></body></html>')
unicode_html = shift_jis_html.decode("shift-jis")
soup = self.soup(unicode_html)
# Make sure the parse tree is correctly encoded to various
# encodings.
self.assertEqual(soup.encode("utf-8"), unicode_html.encode("utf-8"))
self.assertEqual(soup.encode("euc_jp"), unicode_html.encode("euc_jp"))
def test_real_hebrew_document(self):
# A real-world test to make sure we can convert ISO-8859-8 (a
# Hebrew encoding) to UTF-8.
hebrew_document = b'<html><head><title>Hebrew (ISO 8859-8) in Visual Directionality</title></head><body><h1>Hebrew (ISO 8859-8) in Visual Directionality</h1>\xed\xe5\xec\xf9</body></html>'
soup = self.soup(
hebrew_document, from_encoding="iso8859-8")
self.assertEqual(soup.original_encoding, 'iso8859-8')
self.assertEqual(
soup.encode('utf-8'),
hebrew_document.decode("iso8859-8").encode("utf-8"))
def test_meta_tag_reflects_current_encoding(self):
# Here's the <meta> tag saying that a document is
# encoded in Shift-JIS.
meta_tag = ('<meta content="text/html; charset=x-sjis" '
'http-equiv="Content-type"/>')
# Here's a document incorporating that meta tag.
shift_jis_html = (
'<html><head>\n%s\n'
'<meta http-equiv="Content-language" content="ja"/>'
'</head><body>Shift-JIS markup goes here.') % meta_tag
soup = self.soup(shift_jis_html)
# Parse the document, and the charset is seemingly unaffected.
parsed_meta = soup.find('meta', {'http-equiv': 'Content-type'})
content = parsed_meta['content']
self.assertEqual('text/html; charset=x-sjis', content)
# But that value is actually a ContentMetaAttributeValue object.
self.assertTrue(isinstance(content, ContentMetaAttributeValue))
# And it will take on a value that reflects its current
# encoding.
self.assertEqual('text/html; charset=utf8', content.encode("utf8"))
# For the rest of the story, see TestSubstitutions in
# test_tree.py.
def test_html5_style_meta_tag_reflects_current_encoding(self):
# Here's the <meta> tag saying that a document is
# encoded in Shift-JIS.
meta_tag = ('<meta id="encoding" charset="x-sjis" />')
# Here's a document incorporating that meta tag.
shift_jis_html = (
'<html><head>\n%s\n'
'<meta http-equiv="Content-language" content="ja"/>'
'</head><body>Shift-JIS markup goes here.') % meta_tag
soup = self.soup(shift_jis_html)
# Parse the document, and the charset is seemingly unaffected.
parsed_meta = soup.find('meta', id="encoding")
charset = parsed_meta['charset']
self.assertEqual('x-sjis', charset)
# But that value is actually a CharsetMetaAttributeValue object.
self.assertTrue(isinstance(charset, CharsetMetaAttributeValue))
# And it will take on a value that reflects its current
# encoding.
self.assertEqual('utf8', charset.encode("utf8"))
def test_tag_with_no_attributes_can_have_attributes_added(self):
data = self.soup("<a>text</a>")
data.a['foo'] = 'bar'
self.assertEqual('<a foo="bar">text</a>', data.a.decode())
class XMLTreeBuilderSmokeTest(object):
def test_pickle_and_unpickle_identity(self):
# Pickling a tree, then unpickling it, yields a tree identical
# to the original.
tree = self.soup("<a><b>foo</a>")
dumped = pickle.dumps(tree, 2)
loaded = pickle.loads(dumped)
self.assertEqual(loaded.__class__, BeautifulSoup)
self.assertEqual(loaded.decode(), tree.decode())
def test_docstring_generated(self):
soup = self.soup("<root/>")
self.assertEqual(
soup.encode(), b'<?xml version="1.0" encoding="utf-8"?>\n<root/>')
def test_real_xhtml_document(self):
"""A real XHTML document should come out *exactly* the same as it went in."""
markup = b"""<?xml version="1.0" encoding="utf-8"?>
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN">
<html xmlns="http://www.w3.org/1999/xhtml">
<head><title>Hello.</title></head>
<body>Goodbye.</body>
</html>"""
soup = self.soup(markup)
self.assertEqual(
soup.encode("utf-8"), markup)
def test_formatter_processes_script_tag_for_xml_documents(self):
doc = """
<script type="text/javascript">
</script>
"""
soup = BeautifulSoup(doc, "lxml-xml")
# lxml would have stripped this while parsing, but we can add
# it later.
soup.script.string = 'console.log("< < hey > > ");'
encoded = soup.encode()
self.assertTrue(b"&lt; &lt; hey &gt; &gt;" in encoded)
def test_can_parse_unicode_document(self):
markup = u'<?xml version="1.0" encoding="euc-jp"><root>Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!</root>'
soup = self.soup(markup)
self.assertEqual(u'Sacr\xe9 bleu!', soup.root.string)
def test_popping_namespaced_tag(self):
markup = '<rss xmlns:dc="foo"><dc:creator>b</dc:creator><dc:date>2012-07-02T20:33:42Z</dc:date><dc:rights>c</dc:rights><image>d</image></rss>'
soup = self.soup(markup)
self.assertEqual(
unicode(soup.rss), markup)
def test_docstring_includes_correct_encoding(self):
soup = self.soup("<root/>")
self.assertEqual(
soup.encode("latin1"),
b'<?xml version="1.0" encoding="latin1"?>\n<root/>')
def test_large_xml_document(self):
"""A large XML document should come out the same as it went in."""
markup = (b'<?xml version="1.0" encoding="utf-8"?>\n<root>'
+ b'0' * (2**12)
+ b'</root>')
soup = self.soup(markup)
self.assertEqual(soup.encode("utf-8"), markup)
def test_tags_are_empty_element_if_and_only_if_they_are_empty(self):
self.assertSoupEquals("<p>", "<p/>")
self.assertSoupEquals("<p>foo</p>")
def test_namespaces_are_preserved(self):
markup = '<root xmlns:a="http://example.com/" xmlns:b="http://example.net/"><a:foo>This tag is in the a namespace</a:foo><b:foo>This tag is in the b namespace</b:foo></root>'
soup = self.soup(markup)
root = soup.root
self.assertEqual("http://example.com/", root['xmlns:a'])
self.assertEqual("http://example.net/", root['xmlns:b'])
def test_closing_namespaced_tag(self):
markup = '<p xmlns:dc="http://purl.org/dc/elements/1.1/"><dc:date>20010504</dc:date></p>'
soup = self.soup(markup)
self.assertEqual(unicode(soup.p), markup)
def test_namespaced_attributes(self):
markup = '<foo xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"><bar xsi:schemaLocation="http://www.example.com"/></foo>'
soup = self.soup(markup)
self.assertEqual(unicode(soup.foo), markup)
def test_namespaced_attributes_xml_namespace(self):
markup = '<foo xml:lang="fr">bar</foo>'
soup = self.soup(markup)
self.assertEqual(unicode(soup.foo), markup)
class HTML5TreeBuilderSmokeTest(HTMLTreeBuilderSmokeTest):
"""Smoke test for a tree builder that supports HTML5."""
def test_real_xhtml_document(self):
# Since XHTML is not HTML5, HTML5 parsers are not tested to handle
# XHTML documents in any particular way.
pass
def test_html_tags_have_namespace(self):
markup = "<a>"
soup = self.soup(markup)
self.assertEqual("http://www.w3.org/1999/xhtml", soup.a.namespace)
def test_svg_tags_have_namespace(self):
markup = '<svg><circle/></svg>'
soup = self.soup(markup)
namespace = "http://www.w3.org/2000/svg"
self.assertEqual(namespace, soup.svg.namespace)
self.assertEqual(namespace, soup.circle.namespace)
def test_mathml_tags_have_namespace(self):
markup = '<math><msqrt>5</msqrt></math>'
soup = self.soup(markup)
namespace = 'http://www.w3.org/1998/Math/MathML'
self.assertEqual(namespace, soup.math.namespace)
self.assertEqual(namespace, soup.msqrt.namespace)
def test_xml_declaration_becomes_comment(self):
markup = '<?xml version="1.0" encoding="utf-8"?><html></html>'
soup = self.soup(markup)
self.assertTrue(isinstance(soup.contents[0], Comment))
self.assertEqual(soup.contents[0], '?xml version="1.0" encoding="utf-8"?')
self.assertEqual("html", soup.contents[0].next_element.name)
def skipIf(condition, reason):
def nothing(test, *args, **kwargs):
return None
def decorator(test_item):
if condition:
return nothing
else:
return test_item
return decorator
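To make the intent of these helpers concrete, here is a minimal sketch of how a tree-builder test case is assembled from them (hypothetical test module; the builder import mirrors the default used above):

from bs4.builder import HTMLParserTreeBuilder
from bs4.testing import SoupTest, HTMLTreeBuilderSmokeTest

class HTMLParserTreeBuilderTest(SoupTest, HTMLTreeBuilderSmokeTest):
    """Runs the whole smoke-test battery against a single builder."""
    @property
    def default_builder(self):
        return HTMLParserTreeBuilder()

Every test_* method inherited from HTMLTreeBuilderSmokeTest then runs against that builder under the normal unittest runner.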


@@ -8,9 +8,8 @@
Certificate generation module.
"""
import time

from OpenSSL import crypto
TYPE_RSA = crypto.TYPE_RSA
TYPE_DSA = crypto.TYPE_DSA
@@ -30,7 +29,6 @@ def createKeyPair(type, bits):
pkey.generate_key(type, bits)
return pkey
def createCertRequest(pkey, digest="md5", **name):
"""
Create a certificate request.
@@ -51,14 +49,13 @@ def createCertRequest(pkey, digest="md5", **name):
req = crypto.X509Req()
subj = req.get_subject()
for (key, value) in name.items():
setattr(subj, key, value)
req.set_pubkey(pkey)
req.sign(pkey, digest)
return req
def createCertificate(req, (issuerCert, issuerKey), serial, (notBefore, notAfter), digest="md5"):
"""
Generate a certificate given a certificate request.

lib/cherrypy/LICENSE.txt Normal file

@@ -0,0 +1,25 @@
Copyright (c) 2004-2011, CherryPy Team (team@cherrypy.org)
All rights reserved.
Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice,
this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of the CherryPy Team nor the names of its contributors
may be used to endorse or promote products derived from this software
without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

lib/cherrypy/__init__.py Normal file

@@ -0,0 +1,652 @@
"""CherryPy is a pythonic, object-oriented HTTP framework.
CherryPy consists of not one, but four separate API layers.
The APPLICATION LAYER is the simplest. CherryPy applications are written as
a tree of classes and methods, where each branch in the tree corresponds to
a branch in the URL path. Each method is a 'page handler', which receives
GET and POST params as keyword arguments, and returns or yields the (HTML)
body of the response. The special method name 'index' is used for paths
that end in a slash, and the special method name 'default' is used to
handle multiple paths via a single handler. This layer also includes:
* the 'exposed' attribute (and cherrypy.expose)
* cherrypy.quickstart()
* _cp_config attributes
* cherrypy.tools (including cherrypy.session)
* cherrypy.url()
The ENVIRONMENT LAYER is used by developers at all levels. It provides
information about the current request and response, plus the application
and server environment, via a (default) set of top-level objects:
* cherrypy.request
* cherrypy.response
* cherrypy.engine
* cherrypy.server
* cherrypy.tree
* cherrypy.config
* cherrypy.thread_data
* cherrypy.log
* cherrypy.HTTPError, NotFound, and HTTPRedirect
* cherrypy.lib
The EXTENSION LAYER allows advanced users to construct and share their own
plugins. It consists of:
* Hook API
* Tool API
* Toolbox API
* Dispatch API
* Config Namespace API
Finally, there is the CORE LAYER, which uses the core API's to construct
the default components which are available at higher layers. You can think
of the default components as the 'reference implementation' for CherryPy.
Megaframeworks (and advanced users) may replace the default components
with customized or extended components. The core API's are:
* Application API
* Engine API
* Request API
* Server API
* WSGI API
These API's are described in the `CherryPy specification <https://bitbucket.org/cherrypy/cherrypy/wiki/CherryPySpec>`_.
"""
__version__ = "3.6.0"
from cherrypy._cpcompat import urljoin as _urljoin, urlencode as _urlencode
from cherrypy._cpcompat import basestring, unicodestr, set
from cherrypy._cperror import HTTPError, HTTPRedirect, InternalRedirect
from cherrypy._cperror import NotFound, CherryPyException, TimeoutError
from cherrypy import _cpdispatch as dispatch
from cherrypy import _cptools
tools = _cptools.default_toolbox
Tool = _cptools.Tool
from cherrypy import _cprequest
from cherrypy.lib import httputil as _httputil
from cherrypy import _cptree
tree = _cptree.Tree()
from cherrypy._cptree import Application
from cherrypy import _cpwsgi as wsgi
from cherrypy import process
try:
from cherrypy.process import win32
engine = win32.Win32Bus()
engine.console_control_handler = win32.ConsoleCtrlHandler(engine)
del win32
except ImportError:
engine = process.bus
# Timeout monitor. We add two channels to the engine
# to which cherrypy.Application will publish.
engine.listeners['before_request'] = set()
engine.listeners['after_request'] = set()
class _TimeoutMonitor(process.plugins.Monitor):
def __init__(self, bus):
self.servings = []
process.plugins.Monitor.__init__(self, bus, self.run)
def before_request(self):
self.servings.append((serving.request, serving.response))
def after_request(self):
try:
self.servings.remove((serving.request, serving.response))
except ValueError:
pass
def run(self):
"""Check timeout on all responses. (Internal)"""
for req, resp in self.servings:
resp.check_timeout()
engine.timeout_monitor = _TimeoutMonitor(engine)
engine.timeout_monitor.subscribe()
engine.autoreload = process.plugins.Autoreloader(engine)
engine.autoreload.subscribe()
engine.thread_manager = process.plugins.ThreadManager(engine)
engine.thread_manager.subscribe()
engine.signal_handler = process.plugins.SignalHandler(engine)
class _HandleSignalsPlugin(object):
"""Handle signals from other processes based on the configured
platform handlers above."""
def __init__(self, bus):
self.bus = bus
def subscribe(self):
"""Add the handlers based on the platform"""
if hasattr(self.bus, "signal_handler"):
self.bus.signal_handler.subscribe()
if hasattr(self.bus, "console_control_handler"):
self.bus.console_control_handler.subscribe()
engine.signals = _HandleSignalsPlugin(engine)
from cherrypy import _cpserver
server = _cpserver.Server()
server.subscribe()
def quickstart(root=None, script_name="", config=None):
"""Mount the given root, start the builtin server (and engine), then block.
root: an instance of a "controller class" (a collection of page handler
methods) which represents the root of the application.
script_name: a string containing the "mount point" of the application.
This should start with a slash, and be the path portion of the URL
at which to mount the given root. For example, if root.index() will
handle requests to "http://www.example.com:8080/dept/app1/", then
the script_name argument would be "/dept/app1".
It MUST NOT end in a slash. If the script_name refers to the root
of the URI, it MUST be an empty string (not "/").
config: a file or dict containing application config. If this contains
a [global] section, those entries will be used in the global
(site-wide) config.
"""
if config:
_global_conf_alias.update(config)
tree.mount(root, script_name, config)
engine.signals.subscribe()
engine.start()
engine.block()
from cherrypy._cpcompat import threadlocal as _local
class _Serving(_local):
"""An interface for registering request and response objects.
Rather than have a separate "thread local" object for the request and
the response, this class works as a single threadlocal container for
both objects (and any others which developers wish to define). In this
way, we can easily dump those objects when we stop/start a new HTTP
conversation, yet still refer to them as module-level globals in a
thread-safe way.
"""
request = _cprequest.Request(_httputil.Host("127.0.0.1", 80),
_httputil.Host("127.0.0.1", 1111))
"""
The request object for the current thread. In the main thread,
and any threads which are not receiving HTTP requests, this is None."""
response = _cprequest.Response()
"""
The response object for the current thread. In the main thread,
and any threads which are not receiving HTTP requests, this is None."""
def load(self, request, response):
self.request = request
self.response = response
def clear(self):
"""Remove all attributes of self."""
self.__dict__.clear()
serving = _Serving()
class _ThreadLocalProxy(object):
__slots__ = ['__attrname__', '__dict__']
def __init__(self, attrname):
self.__attrname__ = attrname
def __getattr__(self, name):
child = getattr(serving, self.__attrname__)
return getattr(child, name)
def __setattr__(self, name, value):
if name in ("__attrname__", ):
object.__setattr__(self, name, value)
else:
child = getattr(serving, self.__attrname__)
setattr(child, name, value)
def __delattr__(self, name):
child = getattr(serving, self.__attrname__)
delattr(child, name)
def _get_dict(self):
child = getattr(serving, self.__attrname__)
d = child.__class__.__dict__.copy()
d.update(child.__dict__)
return d
__dict__ = property(_get_dict)
def __getitem__(self, key):
child = getattr(serving, self.__attrname__)
return child[key]
def __setitem__(self, key, value):
child = getattr(serving, self.__attrname__)
child[key] = value
def __delitem__(self, key):
child = getattr(serving, self.__attrname__)
del child[key]
def __contains__(self, key):
child = getattr(serving, self.__attrname__)
return key in child
def __len__(self):
child = getattr(serving, self.__attrname__)
return len(child)
def __nonzero__(self):
child = getattr(serving, self.__attrname__)
return bool(child)
# Python 3
__bool__ = __nonzero__
# Create request and response object (the same objects will be used
# throughout the entire life of the webserver, but will redirect
# to the "serving" object)
request = _ThreadLocalProxy('request')
response = _ThreadLocalProxy('response')
# Create thread_data object as a thread-specific all-purpose storage
class _ThreadData(_local):
"""A container for thread-specific data."""
thread_data = _ThreadData()
# Monkeypatch pydoc to allow help() to go through the threadlocal proxy.
# Jan 2007: no Googleable examples of anyone else replacing pydoc.resolve.
# The only other way would be to change what is returned from type(request)
# and that's not possible in pure Python (you'd have to fake ob_type).
def _cherrypy_pydoc_resolve(thing, forceload=0):
"""Given an object or a path to an object, get the object and its name."""
if isinstance(thing, _ThreadLocalProxy):
thing = getattr(serving, thing.__attrname__)
return _pydoc._builtin_resolve(thing, forceload)
try:
import pydoc as _pydoc
_pydoc._builtin_resolve = _pydoc.resolve
_pydoc.resolve = _cherrypy_pydoc_resolve
except ImportError:
pass
from cherrypy import _cplogging
class _GlobalLogManager(_cplogging.LogManager):
"""A site-wide LogManager; routes to app.log or global log as appropriate.
This :class:`LogManager<cherrypy._cplogging.LogManager>` implements
cherrypy.log() and cherrypy.log.access(). If either
function is called during a request, the message will be sent to the
logger for the current Application. If they are called outside of a
request, the message will be sent to the site-wide logger.
"""
def __call__(self, *args, **kwargs):
"""Log the given message to the app.log or global log as appropriate.
"""
# Do NOT use try/except here. See
# https://bitbucket.org/cherrypy/cherrypy/issue/945
if hasattr(request, 'app') and hasattr(request.app, 'log'):
log = request.app.log
else:
log = self
return log.error(*args, **kwargs)
def access(self):
"""Log an access message to the app.log or global log as appropriate.
"""
try:
return request.app.log.access()
except AttributeError:
return _cplogging.LogManager.access(self)
log = _GlobalLogManager()
# Set a default screen handler on the global log.
log.screen = True
log.error_file = ''
# Using an access file makes CP about 10% slower. Leave off by default.
log.access_file = ''
def _buslog(msg, level):
log.error(msg, 'ENGINE', severity=level)
engine.subscribe('log', _buslog)
# Helper functions for CP apps #
def expose(func=None, alias=None):
"""Expose the function, optionally providing an alias or set of aliases."""
def expose_(func):
func.exposed = True
if alias is not None:
if isinstance(alias, basestring):
parents[alias.replace(".", "_")] = func
else:
for a in alias:
parents[a.replace(".", "_")] = func
return func
import sys
import types
if isinstance(func, (types.FunctionType, types.MethodType)):
if alias is None:
# @expose
func.exposed = True
return func
else:
# func = expose(func, alias)
parents = sys._getframe(1).f_locals
return expose_(func)
elif func is None:
if alias is None:
# @expose()
parents = sys._getframe(1).f_locals
return expose_
else:
# @expose(alias="alias") or
# @expose(alias=["alias1", "alias2"])
parents = sys._getframe(1).f_locals
return expose_
else:
# @expose("alias") or
# @expose(["alias1", "alias2"])
parents = sys._getframe(1).f_locals
alias = func
return expose_
def popargs(*args, **kwargs):
"""A decorator for _cp_dispatch
(cherrypy.dispatch.Dispatcher.dispatch_method_name).
Optional keyword argument: handler=(Object or Function)
Provides a _cp_dispatch function that pops off path segments into
cherrypy.request.params under the names specified. The dispatch
is then forwarded on to the next vpath element.
Note that any existing (and exposed) member function of the class that
popargs is applied to will override that value of the argument. For
instance, if you have a method named "list" on the class decorated with
popargs, then accessing "/list" will call that function instead of popping
it off as the requested parameter. This restriction applies to all
_cp_dispatch functions. The only way around this restriction is to create
a "blank class" whose only function is to provide _cp_dispatch.
If there are path elements after the arguments, or more arguments
are requested than are available in the vpath, then the 'handler'
keyword argument specifies the next object to handle the parameterized
request. If handler is not specified or is None, then self is used.
If handler is a function rather than an instance, then that function
will be called with the args specified and the return value from that
function used as the next object INSTEAD of adding the parameters to
cherrypy.request.args.
This decorator may be used in one of two ways:
As a class decorator:
@cherrypy.popargs('year', 'month', 'day')
class Blog:
def index(self, year=None, month=None, day=None):
#Process the parameters here; any url like
#/, /2009, /2009/12, or /2009/12/31
#will fill in the appropriate parameters.
def create(self):
#This link will still be available at /create. Defined functions
#take precedence over arguments.
Or as a member of a class:
class Blog:
_cp_dispatch = cherrypy.popargs('year', 'month', 'day')
#...
The handler argument may be used to mix arguments with built in functions.
For instance, the following setup allows different activities at the
day, month, and year level:
class DayHandler:
def index(self, year, month, day):
#Do something with this day; probably list entries
def delete(self, year, month, day):
#Delete all entries for this day
@cherrypy.popargs('day', handler=DayHandler())
class MonthHandler:
def index(self, year, month):
#Do something with this month; probably list entries
def delete(self, year, month):
#Delete all entries for this month
@cherrypy.popargs('month', handler=MonthHandler())
class YearHandler:
def index(self, year):
#Do something with this year
#...
@cherrypy.popargs('year', handler=YearHandler())
class Root:
def index(self):
#...
"""
# Since keyword arg comes after *args, we have to process it ourselves
# for lower versions of python.
handler = None
handler_call = False
for k, v in kwargs.items():
if k == 'handler':
handler = v
else:
raise TypeError(
"cherrypy.popargs() got an unexpected keyword argument '{0}'"
.format(k)
)
import inspect
if handler is not None \
and (hasattr(handler, '__call__') or inspect.isclass(handler)):
handler_call = True
def decorated(cls_or_self=None, vpath=None):
if inspect.isclass(cls_or_self):
# cherrypy.popargs is a class decorator
cls = cls_or_self
setattr(cls, dispatch.Dispatcher.dispatch_method_name, decorated)
return cls
# We're in the actual function
self = cls_or_self
parms = {}
for arg in args:
if not vpath:
break
parms[arg] = vpath.pop(0)
if handler is not None:
if handler_call:
return handler(**parms)
else:
request.params.update(parms)
return handler
request.params.update(parms)
# If we are the ultimate handler, then to prevent our _cp_dispatch
# from being called again, we will resolve remaining elements through
# getattr() directly.
if vpath:
return getattr(self, vpath.pop(0), None)
else:
return self
return decorated
def url(path="", qs="", script_name=None, base=None, relative=None):
"""Create an absolute URL for the given path.
If 'path' starts with a slash ('/'), this will return
(base + script_name + path + qs).
If it does not start with a slash, this returns
(base + script_name [+ request.path_info] + path + qs).
If script_name is None, cherrypy.request will be used
to find a script_name, if available.
If base is None, cherrypy.request.base will be used (if available).
Note that you can use cherrypy.tools.proxy to change this.
Finally, note that this function can be used to obtain an absolute URL
for the current request path (minus the querystring) by passing no args.
If you call url(qs=cherrypy.request.query_string), you should get the
original browser URL (assuming no internal redirections).
If relative is None or not provided, request.app.relative_urls will
be used (if available, else False). If False, the output will be an
absolute URL (including the scheme, host, vhost, and script_name).
If True, the output will instead be a URL that is relative to the
current request path, perhaps including '..' atoms. If relative is
the string 'server', the output will instead be a URL that is
relative to the server root; i.e., it will start with a slash.
"""
if isinstance(qs, (tuple, list, dict)):
qs = _urlencode(qs)
if qs:
qs = '?' + qs
if request.app:
if not path.startswith("/"):
# Append/remove trailing slash from path_info as needed
# (this is to support mistyped URL's without redirecting;
# if you want to redirect, use tools.trailing_slash).
pi = request.path_info
if request.is_index is True:
if not pi.endswith('/'):
pi = pi + '/'
elif request.is_index is False:
if pi.endswith('/') and pi != '/':
pi = pi[:-1]
if path == "":
path = pi
else:
path = _urljoin(pi, path)
if script_name is None:
script_name = request.script_name
if base is None:
base = request.base
newurl = base + script_name + path + qs
else:
# No request.app (we're being called outside a request).
# We'll have to guess the base from server.* attributes.
# This will produce very different results from the above
# if you're using vhosts or tools.proxy.
if base is None:
base = server.base()
path = (script_name or "") + path
newurl = base + path + qs
if './' in newurl:
# Normalize the URL by removing ./ and ../
atoms = []
for atom in newurl.split('/'):
if atom == '.':
pass
elif atom == '..':
atoms.pop()
else:
atoms.append(atom)
newurl = '/'.join(atoms)
# At this point, we should have a fully-qualified absolute URL.
if relative is None:
relative = getattr(request.app, "relative_urls", False)
# See http://www.ietf.org/rfc/rfc2396.txt
if relative == 'server':
# "A relative reference beginning with a single slash character is
# termed an absolute-path reference, as defined by <abs_path>..."
# This is also sometimes called "server-relative".
newurl = '/' + '/'.join(newurl.split('/', 3)[3:])
elif relative:
# "A relative reference that does not begin with a scheme name
# or a slash character is termed a relative-path reference."
old = url(relative=False).split('/')[:-1]
new = newurl.split('/')
while old and new:
a, b = old[0], new[0]
if a != b:
break
old.pop(0)
new.pop(0)
new = (['..'] * len(old)) + new
newurl = '/'.join(new)
return newurl
# import _cpconfig last so it can reference other top-level objects
from cherrypy import _cpconfig
# Use _global_conf_alias so quickstart can use 'config' as an arg
# without shadowing cherrypy.config.
config = _global_conf_alias = _cpconfig.Config()
config.defaults = {
'tools.log_tracebacks.on': True,
'tools.log_headers.on': True,
'tools.trailing_slash.on': True,
'tools.encode.on': True
}
config.namespaces["log"] = lambda k, v: setattr(log, k, v)
config.namespaces["checker"] = lambda k, v: setattr(checker, k, v)
# Must reset to get our defaults applied.
config.reset()
from cherrypy import _cpchecker
checker = _cpchecker.Checker()
engine.subscribe('start', checker)
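Tying the application layer together, here is a minimal sketch of the quickstart() flow documented above (hypothetical app; port chosen arbitrarily):

import cherrypy

class Root(object):
    @cherrypy.expose
    def index(self):
        return "Hello world!"

if __name__ == '__main__':
    # Mounts Root at the URI root, starts the builtin server, and blocks.
    cherrypy.quickstart(Root(), '', {'global': {'server.socket_port': 8080}})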

lib/cherrypy/_cpchecker.py Normal file

@@ -0,0 +1,332 @@
import os
import warnings
import cherrypy
from cherrypy._cpcompat import iteritems, copykeys, builtins
class Checker(object):
"""A checker for CherryPy sites and their mounted applications.
When this object is called at engine startup, it executes each
of its own methods whose names start with ``check_``. If you wish
to disable selected checks, simply add a line in your global
config which sets the appropriate method to False::
[global]
checker.check_skipped_app_config = False
You may also dynamically add or replace ``check_*`` methods in this way.
"""
on = True
"""If True (the default), run all checks; if False, turn off all checks."""
def __init__(self):
self._populate_known_types()
def __call__(self):
"""Run all check_* methods."""
if self.on:
oldformatwarning = warnings.formatwarning
warnings.formatwarning = self.formatwarning
try:
for name in dir(self):
if name.startswith("check_"):
method = getattr(self, name)
if method and hasattr(method, '__call__'):
method()
finally:
warnings.formatwarning = oldformatwarning
def formatwarning(self, message, category, filename, lineno, line=None):
"""Function to format a warning."""
return "CherryPy Checker:\n%s\n\n" % message
# This value should be set inside _cpconfig.
global_config_contained_paths = False
def check_app_config_entries_dont_start_with_script_name(self):
"""Check for Application config with sections that repeat script_name.
"""
for sn, app in cherrypy.tree.apps.items():
if not isinstance(app, cherrypy.Application):
continue
if not app.config:
continue
if sn == '':
continue
sn_atoms = sn.strip("/").split("/")
for key in app.config.keys():
key_atoms = key.strip("/").split("/")
if key_atoms[:len(sn_atoms)] == sn_atoms:
warnings.warn(
"The application mounted at %r has config "
"entries that start with its script name: %r" % (sn,
key))
def check_site_config_entries_in_app_config(self):
"""Check for mounted Applications that have site-scoped config."""
for sn, app in iteritems(cherrypy.tree.apps):
if not isinstance(app, cherrypy.Application):
continue
msg = []
for section, entries in iteritems(app.config):
if section.startswith('/'):
for key, value in iteritems(entries):
for n in ("engine.", "server.", "tree.", "checker."):
if key.startswith(n):
msg.append("[%s] %s = %s" %
(section, key, value))
if msg:
msg.insert(0,
"The application mounted at %r contains the "
"following config entries, which are only allowed "
"in site-wide config. Move them to a [global] "
"section and pass them to cherrypy.config.update() "
"instead of tree.mount()." % sn)
warnings.warn(os.linesep.join(msg))
def check_skipped_app_config(self):
"""Check for mounted Applications that have no config."""
for sn, app in cherrypy.tree.apps.items():
if not isinstance(app, cherrypy.Application):
continue
if not app.config:
msg = "The Application mounted at %r has an empty config." % sn
if self.global_config_contained_paths:
msg += (" It looks like the config you passed to "
"cherrypy.config.update() contains application-"
"specific sections. You must explicitly pass "
"application config via "
"cherrypy.tree.mount(..., config=app_config)")
warnings.warn(msg)
return
def check_app_config_brackets(self):
"""Check for Application config with extraneous brackets in section
names.
"""
for sn, app in cherrypy.tree.apps.items():
if not isinstance(app, cherrypy.Application):
continue
if not app.config:
continue
for key in app.config.keys():
if key.startswith("[") or key.endswith("]"):
warnings.warn(
"The application mounted at %r has config "
"section names with extraneous brackets: %r. "
"Config *files* need brackets; config *dicts* "
"(e.g. passed to tree.mount) do not." % (sn, key))
def check_static_paths(self):
"""Check Application config for incorrect static paths."""
# Use the dummy Request object in the main thread.
request = cherrypy.request
for sn, app in cherrypy.tree.apps.items():
if not isinstance(app, cherrypy.Application):
continue
request.app = app
for section in app.config:
# get_resource will populate request.config
request.get_resource(section + "/dummy.html")
conf = request.config.get
if conf("tools.staticdir.on", False):
msg = ""
root = conf("tools.staticdir.root")
dir = conf("tools.staticdir.dir")
if dir is None:
msg = "tools.staticdir.dir is not set."
else:
fulldir = ""
if os.path.isabs(dir):
fulldir = dir
if root:
msg = ("dir is an absolute path, even "
"though a root is provided.")
testdir = os.path.join(root, dir[1:])
if os.path.exists(testdir):
msg += (
"\nIf you meant to serve the "
"filesystem folder at %r, remove the "
"leading slash from dir." % (testdir,))
else:
if not root:
msg = (
"dir is a relative path and "
"no root provided.")
else:
fulldir = os.path.join(root, dir)
if not os.path.isabs(fulldir):
msg = ("%r is not an absolute path." % (
fulldir,))
if fulldir and not os.path.exists(fulldir):
if msg:
msg += "\n"
msg += ("%r (root + dir) is not an existing "
"filesystem path." % fulldir)
if msg:
warnings.warn("%s\nsection: [%s]\nroot: %r\ndir: %r"
% (msg, section, root, dir))
# -------------------------- Compatibility -------------------------- #
obsolete = {
'server.default_content_type': 'tools.response_headers.headers',
'log_access_file': 'log.access_file',
'log_config_options': None,
'log_file': 'log.error_file',
'log_file_not_found': None,
'log_request_headers': 'tools.log_headers.on',
'log_to_screen': 'log.screen',
'show_tracebacks': 'request.show_tracebacks',
'throw_errors': 'request.throw_errors',
'profiler.on': ('cherrypy.tree.mount(profiler.make_app('
'cherrypy.Application(Root())))'),
}
deprecated = {}
def _compat(self, config):
"""Process config and warn on each obsolete or deprecated entry."""
for section, conf in config.items():
if isinstance(conf, dict):
for k, v in conf.items():
if k in self.obsolete:
warnings.warn("%r is obsolete. Use %r instead.\n"
"section: [%s]" %
(k, self.obsolete[k], section))
elif k in self.deprecated:
warnings.warn("%r is deprecated. Use %r instead.\n"
"section: [%s]" %
(k, self.deprecated[k], section))
else:
if section in self.obsolete:
warnings.warn("%r is obsolete. Use %r instead."
% (section, self.obsolete[section]))
elif section in self.deprecated:
warnings.warn("%r is deprecated. Use %r instead."
% (section, self.deprecated[section]))
def check_compatibility(self):
"""Process config and warn on each obsolete or deprecated entry."""
self._compat(cherrypy.config)
for sn, app in cherrypy.tree.apps.items():
if not isinstance(app, cherrypy.Application):
continue
self._compat(app.config)
# ------------------------ Known Namespaces ------------------------ #
extra_config_namespaces = []
def _known_ns(self, app):
ns = ["wsgi"]
ns.extend(copykeys(app.toolboxes))
ns.extend(copykeys(app.namespaces))
ns.extend(copykeys(app.request_class.namespaces))
ns.extend(copykeys(cherrypy.config.namespaces))
ns += self.extra_config_namespaces
for section, conf in app.config.items():
is_path_section = section.startswith("/")
if is_path_section and isinstance(conf, dict):
for k, v in conf.items():
atoms = k.split(".")
if len(atoms) > 1:
if atoms[0] not in ns:
# Spit out a special warning if a known
# namespace is preceded by "cherrypy."
if atoms[0] == "cherrypy" and atoms[1] in ns:
msg = (
"The config entry %r is invalid; "
"try %r instead.\nsection: [%s]"
% (k, ".".join(atoms[1:]), section))
else:
msg = (
"The config entry %r is invalid, "
"because the %r config namespace "
"is unknown.\n"
"section: [%s]" % (k, atoms[0], section))
warnings.warn(msg)
elif atoms[0] == "tools":
if atoms[1] not in dir(cherrypy.tools):
msg = (
"The config entry %r may be invalid, "
"because the %r tool was not found.\n"
"section: [%s]" % (k, atoms[1], section))
warnings.warn(msg)
def check_config_namespaces(self):
"""Process config and warn on each unknown config namespace."""
for sn, app in cherrypy.tree.apps.items():
if not isinstance(app, cherrypy.Application):
continue
self._known_ns(app)
# -------------------------- Config Types -------------------------- #
known_config_types = {}
def _populate_known_types(self):
b = [x for x in vars(builtins).values()
if type(x) is type(str)]
def traverse(obj, namespace):
for name in dir(obj):
# Hack for 3.2's warning about body_params
if name == 'body_params':
continue
vtype = type(getattr(obj, name, None))
if vtype in b:
self.known_config_types[namespace + "." + name] = vtype
traverse(cherrypy.request, "request")
traverse(cherrypy.response, "response")
traverse(cherrypy.server, "server")
traverse(cherrypy.engine, "engine")
traverse(cherrypy.log, "log")
def _known_types(self, config):
msg = ("The config entry %r in section %r is of type %r, "
"which does not match the expected type %r.")
for section, conf in config.items():
if isinstance(conf, dict):
for k, v in conf.items():
if v is not None:
expected_type = self.known_config_types.get(k, None)
vtype = type(v)
if expected_type and vtype != expected_type:
warnings.warn(msg % (k, section, vtype.__name__,
expected_type.__name__))
else:
k, v = section, conf
if v is not None:
expected_type = self.known_config_types.get(k, None)
vtype = type(v)
if expected_type and vtype != expected_type:
warnings.warn(msg % (k, section, vtype.__name__,
expected_type.__name__))
def check_config_types(self):
"""Assert that config values are of the same type as default values."""
self._known_types(cherrypy.config)
for sn, app in cherrypy.tree.apps.items():
if not isinstance(app, cherrypy.Application):
continue
self._known_types(app.config)
# -------------------- Specific config warnings -------------------- #
def check_localhost(self):
"""Warn if any socket_host is 'localhost'. See #711."""
for k, v in cherrypy.config.items():
if k == 'server.socket_host' and v == 'localhost':
warnings.warn("The use of 'localhost' as a socket host can "
"cause problems on newer systems, since "
"'localhost' can map to either an IPv4 or an "
"IPv6 address. You should use '127.0.0.1' "
"or '[::1]' instead.")

383
lib/cherrypy/_cpcompat.py Normal file
View File

@@ -0,0 +1,383 @@
"""Compatibility code for using CherryPy with various versions of Python.
CherryPy 3.2 is compatible with Python versions 2.3+. This module provides a
useful abstraction over the differences between Python versions, sometimes by
preferring a newer idiom, sometimes an older one, and sometimes a custom one.
In particular, Python 2 uses str and '' for byte strings, while Python 3
uses str and '' for unicode strings. We will call each of these the 'native
string' type for each version. Because of this major difference, this module
provides new 'bytestr', 'unicodestr', and 'nativestr' attributes, as well as
two functions: 'ntob', which translates native strings (of type 'str') into
byte strings regardless of Python version, and 'ntou', which translates native
strings to unicode strings. This also provides a 'BytesIO' name for dealing
specifically with bytes, and a 'StringIO' name for dealing with native strings.
It also provides a 'base64_decode' function with native strings as input and
output.
"""
import os
import re
import sys
import threading
if sys.version_info >= (3, 0):
py3k = True
bytestr = bytes
unicodestr = str
nativestr = unicodestr
basestring = (bytes, str)
def ntob(n, encoding='ISO-8859-1'):
"""Return the given native string as a byte string in the given
encoding.
"""
assert_native(n)
# In Python 3, the native string type is unicode
return n.encode(encoding)
def ntou(n, encoding='ISO-8859-1'):
"""Return the given native string as a unicode string with the given
encoding.
"""
assert_native(n)
# In Python 3, the native string type is unicode
return n
def tonative(n, encoding='ISO-8859-1'):
"""Return the given string as a native string in the given encoding."""
# In Python 3, the native string type is unicode
if isinstance(n, bytes):
return n.decode(encoding)
return n
# type("")
from io import StringIO
# bytes:
from io import BytesIO as BytesIO
else:
# Python 2
py3k = False
bytestr = str
unicodestr = unicode
nativestr = bytestr
basestring = basestring
def ntob(n, encoding='ISO-8859-1'):
"""Return the given native string as a byte string in the given
encoding.
"""
assert_native(n)
# In Python 2, the native string type is bytes. Assume it's already
# in the given encoding, which for ISO-8859-1 is almost always what
# was intended.
return n
def ntou(n, encoding='ISO-8859-1'):
"""Return the given native string as a unicode string with the given
encoding.
"""
assert_native(n)
# In Python 2, the native string type is bytes.
# First, check for the special encoding 'escape'. The test suite uses
# this to signal that it wants to pass a string with embedded \uXXXX
# escapes, without having to write a u'' literal for Python 2 and a
# plain '' literal for Python 3.
if encoding == 'escape':
return unicode(
re.sub(r'\\u([0-9a-zA-Z]{4})',
lambda m: unichr(int(m.group(1), 16)),
n.decode('ISO-8859-1')))
# Assume it's already in the given encoding, which for ISO-8859-1
# is almost always what was intended.
return n.decode(encoding)
def tonative(n, encoding='ISO-8859-1'):
"""Return the given string as a native string in the given encoding."""
# In Python 2, the native string type is bytes.
if isinstance(n, unicode):
return n.encode(encoding)
return n
try:
# type("")
from cStringIO import StringIO
except ImportError:
# type("")
from StringIO import StringIO
# bytes:
BytesIO = StringIO
def assert_native(n):
if not isinstance(n, nativestr):
raise TypeError("n must be a native str (got %s)" % type(n).__name__)
try:
set = set
except NameError:
from sets import Set as set
try:
# Python 3.1+
from base64 import decodebytes as _base64_decodebytes
except ImportError:
# Python 3.0-
# since CherryPy claims compatibility with Python 2.3, we must use
# the legacy API of base64
from base64 import decodestring as _base64_decodebytes
def base64_decode(n, encoding='ISO-8859-1'):
"""Return the native string base64-decoded (as a native string)."""
if isinstance(n, unicodestr):
b = n.encode(encoding)
else:
b = n
b = _base64_decodebytes(b)
if nativestr is unicodestr:
return b.decode(encoding)
else:
return b
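# An illustrative usage sketch (not part of CherryPy): both the input and
# the result are native strings, whatever the Python version.
def _example_base64():
    return base64_decode('YWJj')  # -> 'abc'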
try:
# Python 2.5+
from hashlib import md5
except ImportError:
from md5 import new as md5
try:
# Python 2.5+
from hashlib import sha1 as sha
except ImportError:
from sha import new as sha
try:
sorted = sorted
except NameError:
def sorted(i):
i = i[:]
i.sort()
return i
try:
reversed = reversed
except NameError:
def reversed(x):
i = len(x)
while i > 0:
i -= 1
yield x[i]
try:
# Python 3
from urllib.parse import urljoin, urlencode
from urllib.parse import quote, quote_plus
from urllib.request import unquote, urlopen
from urllib.request import parse_http_list, parse_keqv_list
except ImportError:
# Python 2
from urlparse import urljoin
from urllib import urlencode, urlopen
from urllib import quote, quote_plus
from urllib import unquote
from urllib2 import parse_http_list, parse_keqv_list
try:
from threading import local as threadlocal
except ImportError:
from cherrypy._cpthreadinglocal import local as threadlocal
try:
dict.iteritems
# Python 2
iteritems = lambda d: d.iteritems()
copyitems = lambda d: d.items()
except AttributeError:
# Python 3
iteritems = lambda d: d.items()
copyitems = lambda d: list(d.items())
try:
dict.iterkeys
# Python 2
iterkeys = lambda d: d.iterkeys()
copykeys = lambda d: d.keys()
except AttributeError:
# Python 3
iterkeys = lambda d: d.keys()
copykeys = lambda d: list(d.keys())
try:
dict.itervalues
# Python 2
itervalues = lambda d: d.itervalues()
copyvalues = lambda d: d.values()
except AttributeError:
# Python 3
itervalues = lambda d: d.values()
copyvalues = lambda d: list(d.values())
try:
# Python 3
import builtins
except ImportError:
# Python 2
import __builtin__ as builtins
try:
# Python 2. We try Python 2 first so that clients on Python 2
# don't try to import the 'http' module from cherrypy.lib
from Cookie import SimpleCookie, CookieError
from httplib import BadStatusLine, HTTPConnection, IncompleteRead
from httplib import NotConnected
from BaseHTTPServer import BaseHTTPRequestHandler
except ImportError:
# Python 3
from http.cookies import SimpleCookie, CookieError
from http.client import BadStatusLine, HTTPConnection, IncompleteRead
from http.client import NotConnected
from http.server import BaseHTTPRequestHandler
# Some platforms don't expose HTTPSConnection, so handle it separately
if py3k:
try:
from http.client import HTTPSConnection
except ImportError:
# Some platforms which don't have SSL don't expose HTTPSConnection
HTTPSConnection = None
else:
try:
from httplib import HTTPSConnection
except ImportError:
HTTPSConnection = None
try:
# Python 2
xrange = xrange
except NameError:
# Python 3
xrange = range
import threading
if hasattr(threading.Thread, "daemon"):
# Python 2.6+
def get_daemon(t):
return t.daemon
def set_daemon(t, val):
t.daemon = val
else:
def get_daemon(t):
return t.isDaemon()
def set_daemon(t, val):
t.setDaemon(val)
try:
from email.utils import formatdate
def HTTPDate(timeval=None):
return formatdate(timeval, usegmt=True)
except ImportError:
from rfc822 import formatdate as HTTPDate
try:
# Python 3
from urllib.parse import unquote as parse_unquote
def unquote_qs(atom, encoding, errors='strict'):
return parse_unquote(
atom.replace('+', ' '),
encoding=encoding,
errors=errors)
except ImportError:
# Python 2
from urllib import unquote as parse_unquote
def unquote_qs(atom, encoding, errors='strict'):
return parse_unquote(atom.replace('+', ' ')).decode(encoding, errors)
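# An illustrative usage sketch (not part of CherryPy): '+' becomes a
# space and percent-escapes are decoded, on both Python versions.
def _example_unquote_qs():
    return unquote_qs('a+b%21', 'utf-8')  # -> 'a b!'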
try:
# Prefer simplejson, which is usually more advanced than the builtin
# module.
import simplejson as json
json_decode = json.JSONDecoder().decode
_json_encode = json.JSONEncoder().iterencode
except ImportError:
if sys.version_info >= (2, 6):
# Python >=2.6 : json is part of the standard library
import json
json_decode = json.JSONDecoder().decode
_json_encode = json.JSONEncoder().iterencode
else:
json = None
def json_decode(s):
raise ValueError('No JSON library is available')
def _json_encode(s):
raise ValueError('No JSON library is available')
finally:
if json and py3k:
# The two Python 3 implementations (simplejson/json)
# output str. We need bytes.
def json_encode(value):
for chunk in _json_encode(value):
yield chunk.encode('utf8')
else:
json_encode = _json_encode
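# An illustrative usage sketch (not part of CherryPy): json_decode takes
# a native string; json_encode yields chunks, which are bytes on
# Python 3 and native strings on Python 2.
def _example_json_roundtrip():
    obj = json_decode('{"a": 1}')    # -> {'a': 1}
    chunks = list(json_encode(obj))  # e.g. [b'{"a": 1}'] on Python 3
    return obj, chunks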
try:
import cPickle as pickle
except ImportError:
# In Python 2, pickle is the pure-Python version.
# In Python 3, pickle is the sped-up C version.
import pickle
try:
os.urandom(20)
import binascii
def random20():
return binascii.hexlify(os.urandom(20)).decode('ascii')
except (AttributeError, NotImplementedError):
import random
# os.urandom not available until Python 2.4. Fall back to random.random.
def random20():
return sha('%s' % random.random()).hexdigest()
try:
from _thread import get_ident as get_thread_ident
except ImportError:
from thread import get_ident as get_thread_ident
try:
# Python 3
next = next
except NameError:
# Python 2
def next(i):
return i.next()
if sys.version_info >= (3, 3):
Timer = threading.Timer
Event = threading.Event
else:
# Python 3.2 and earlier
Timer = threading._Timer
Event = threading._Event
# Prior to Python 2.6, the Thread class did not have a .daemon property.
# This mix-in adds that property.
class SetDaemonProperty:
def __get_daemon(self):
return self.isDaemon()
def __set_daemon(self, daemon):
self.setDaemon(daemon)
if sys.version_info < (2, 6):
daemon = property(__get_daemon, __set_daemon)

File diff suppressed because it is too large Load Diff

317
lib/cherrypy/_cpconfig.py Normal file
View File

@@ -0,0 +1,317 @@
"""
Configuration system for CherryPy.
Configuration in CherryPy is implemented via dictionaries. Keys are strings
which name the mapped value, which may be of any type.
Architecture
------------
CherryPy Requests are part of an Application, which runs in a global context,
and configuration data may apply to any of those three scopes:
Global
Configuration entries which apply everywhere are stored in
cherrypy.config.
Application
Entries which apply to each mounted application are stored
on the Application object itself, as 'app.config'. This is a two-level
dict where each key is a path, or "relative URL" (for example, "/" or
"/path/to/my/page"), and each value is a config dict. Usually, this
data is provided in the call to tree.mount(root(), config=conf),
although you may also use app.merge(conf).
Request
Each Request object possesses a single 'Request.config' dict.
Early in the request process, this dict is populated by merging global
config entries, Application entries (whose path equals or is a parent
of Request.path_info), and any config acquired while looking up the
page handler (see next).
Declaration
-----------
Configuration data may be supplied as a Python dictionary, as a filename,
or as an open file object. When you supply a filename or file, CherryPy
uses Python's builtin ConfigParser; you declare Application config by
writing each path as a section header::
[/path/to/my/page]
request.stream = True
To declare global configuration entries, place them in a [global] section.
You may also declare config entries directly on the classes and methods
(page handlers) that make up your CherryPy application via the ``_cp_config``
attribute. For example::
class Demo:
_cp_config = {'tools.gzip.on': True}
def index(self):
return "Hello world"
index.exposed = True
index._cp_config = {'request.show_tracebacks': False}
.. note::
This behavior is only guaranteed for the default dispatcher.
Other dispatchers may have different restrictions on where
you can attach _cp_config attributes.
Namespaces
----------
Configuration keys are separated into namespaces by the first "." in the key.
Current namespaces:
engine
Controls the 'application engine', including autoreload.
These can only be declared in the global config.
tree
Grafts cherrypy.Application objects onto cherrypy.tree.
These can only be declared in the global config.
hooks
Declares additional request-processing functions.
log
Configures the logging for each application.
These can only be declared in the global or / config.
request
Adds attributes to each Request.
response
Adds attributes to each Response.
server
Controls the default HTTP server via cherrypy.server.
These can only be declared in the global config.
tools
Runs and configures additional request-processing packages.
wsgi
Adds WSGI middleware to an Application's "pipeline".
These can only be declared in the app's root config ("/").
checker
Controls the 'checker', which looks for common errors in
app state (including config) when the engine starts.
Global config only.
The only key that does not exist in a namespace is the "environment" entry.
This special entry 'imports' other config entries from a template stored in
cherrypy._cpconfig.environments[environment]. It only applies to the global
config, and only when you use cherrypy.config.update.
You can define your own namespaces to be called at the Global, Application,
or Request level, by adding a named handler to cherrypy.config.namespaces,
app.namespaces, or app.request_class.namespaces. The name can
be any string, and the handler must be either a callable or a (Python 2.5
style) context manager.
"""
import cherrypy
from cherrypy._cpcompat import set, basestring
from cherrypy.lib import reprconf
# Deprecated in CherryPy 3.2--remove in 3.3
NamespaceSet = reprconf.NamespaceSet
def merge(base, other):
"""Merge one app config (from a dict, file, or filename) into another.
If the given config is a filename, it will be appended to
the list of files to monitor for "autoreload" changes.
"""
if isinstance(other, basestring):
cherrypy.engine.autoreload.files.add(other)
# Load other into base
for section, value_map in reprconf.as_dict(other).items():
if not isinstance(value_map, dict):
raise ValueError(
"Application config must include section headers, but the "
"config you tried to merge doesn't have any sections. "
"Wrap your config in another dict with paths as section "
"headers, for example: {'/': config}.")
base.setdefault(section, {}).update(value_map)
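# An illustrative usage sketch (not part of CherryPy): merging extra
# sections into an already-mounted app's config. The '/admin' path and
# the tool shown are hypothetical.
def _example_merge(app):
    # 'app' is a mounted cherrypy.Application.
    merge(app.config, {'/admin': {'tools.gzip.on': True}})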
class Config(reprconf.Config):
"""The 'global' configuration data for the entire CherryPy process."""
def update(self, config):
"""Update self from a dict, file or filename."""
if isinstance(config, basestring):
# Filename
cherrypy.engine.autoreload.files.add(config)
reprconf.Config.update(self, config)
def _apply(self, config):
"""Update self from a dict."""
if isinstance(config.get("global"), dict):
if len(config) > 1:
cherrypy.checker.global_config_contained_paths = True
config = config["global"]
if 'tools.staticdir.dir' in config:
config['tools.staticdir.section'] = "global"
reprconf.Config._apply(self, config)
def __call__(self, *args, **kwargs):
"""Decorator for page handlers to set _cp_config."""
if args:
raise TypeError(
"The cherrypy.config decorator does not accept positional "
"arguments; you must use keyword arguments.")
def tool_decorator(f):
if not hasattr(f, "_cp_config"):
f._cp_config = {}
for k, v in kwargs.items():
f._cp_config[k] = v
return f
return tool_decorator
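# An illustrative usage sketch (not part of CherryPy): the keyword-only
# decorator form defined above stores its entries on handler._cp_config.
def _example_config_decorator():
    conf = Config()
    @conf(**{'request.show_tracebacks': False})
    def page():
        return "Hello"
    return page._cp_config  # -> {'request.show_tracebacks': False}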
# Sphinx begin config.environments
Config.environments = environments = {
"staging": {
'engine.autoreload.on': False,
'checker.on': False,
'tools.log_headers.on': False,
'request.show_tracebacks': False,
'request.show_mismatched_params': False,
},
"production": {
'engine.autoreload.on': False,
'checker.on': False,
'tools.log_headers.on': False,
'request.show_tracebacks': False,
'request.show_mismatched_params': False,
'log.screen': False,
},
"embedded": {
# For use with CherryPy embedded in another deployment stack.
'engine.autoreload.on': False,
'checker.on': False,
'tools.log_headers.on': False,
'request.show_tracebacks': False,
'request.show_mismatched_params': False,
'log.screen': False,
'engine.SIGHUP': None,
'engine.SIGTERM': None,
},
"test_suite": {
'engine.autoreload.on': False,
'checker.on': False,
'tools.log_headers.on': False,
'request.show_tracebacks': True,
'request.show_mismatched_params': True,
'log.screen': False,
},
}
# Sphinx end config.environments
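# An illustrative usage sketch (not part of CherryPy): select one of the
# templates above via the special 'environment' key (global config only,
# as the module docstring explains).
def _example_select_environment():
    cherrypy.config.update({'environment': 'production'})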
def _server_namespace_handler(k, v):
"""Config handler for the "server" namespace."""
atoms = k.split(".", 1)
if len(atoms) > 1:
# Special-case config keys of the form 'server.servername.socket_port'
# to configure additional HTTP servers.
if not hasattr(cherrypy, "servers"):
cherrypy.servers = {}
servername, k = atoms
if servername not in cherrypy.servers:
from cherrypy import _cpserver
cherrypy.servers[servername] = _cpserver.Server()
# On by default, but 'on = False' can unsubscribe it (see below).
cherrypy.servers[servername].subscribe()
if k == 'on':
if v:
cherrypy.servers[servername].subscribe()
else:
cherrypy.servers[servername].unsubscribe()
else:
setattr(cherrypy.servers[servername], k, v)
else:
setattr(cherrypy.server, k, v)
Config.namespaces["server"] = _server_namespace_handler
def _engine_namespace_handler(k, v):
"""Backward compatibility handler for the "engine" namespace."""
engine = cherrypy.engine
deprecated = {
'autoreload_on': 'autoreload.on',
'autoreload_frequency': 'autoreload.frequency',
'autoreload_match': 'autoreload.match',
'reload_files': 'autoreload.files',
'deadlock_poll_freq': 'timeout_monitor.frequency'
}
if k in deprecated:
engine.log(
'WARNING: Use of engine.%s is deprecated and will be removed in a '
'future version. Use engine.%s instead.' % (k, deprecated[k]))
if k == 'autoreload_on':
if v:
engine.autoreload.subscribe()
else:
engine.autoreload.unsubscribe()
elif k == 'autoreload_frequency':
engine.autoreload.frequency = v
elif k == 'autoreload_match':
engine.autoreload.match = v
elif k == 'reload_files':
engine.autoreload.files = set(v)
elif k == 'deadlock_poll_freq':
engine.timeout_monitor.frequency = v
elif k == 'SIGHUP':
engine.listeners['SIGHUP'] = set([v])
elif k == 'SIGTERM':
engine.listeners['SIGTERM'] = set([v])
elif "." in k:
plugin, attrname = k.split(".", 1)
plugin = getattr(engine, plugin)
if attrname == 'on':
if v and hasattr(getattr(plugin, 'subscribe', None), '__call__'):
plugin.subscribe()
return
elif (
(not v) and
hasattr(getattr(plugin, 'unsubscribe', None), '__call__')
):
plugin.unsubscribe()
return
setattr(plugin, attrname, v)
else:
setattr(engine, k, v)
Config.namespaces["engine"] = _engine_namespace_handler
def _tree_namespace_handler(k, v):
"""Namespace handler for the 'tree' config namespace."""
if isinstance(v, dict):
for script_name, app in v.items():
cherrypy.tree.graft(app, script_name)
cherrypy.engine.log("Mounted: %s on %s" %
(app, script_name or "/"))
else:
cherrypy.tree.graft(v, v.script_name)
cherrypy.engine.log("Mounted: %s on %s" % (v, v.script_name or "/"))
Config.namespaces["tree"] = _tree_namespace_handler

686
lib/cherrypy/_cpdispatch.py Normal file
View File

@@ -0,0 +1,686 @@
"""CherryPy dispatchers.
A 'dispatcher' is the object which looks up the 'page handler' callable
and collects config for the current request based on the path_info, other
request attributes, and the application architecture. The core calls the
dispatcher as early as possible, passing it a 'path_info' argument.
The default dispatcher discovers the page handler by matching path_info
to a hierarchical arrangement of objects, starting at request.app.root.
"""
import string
import sys
import types
try:
classtype = (type, types.ClassType)
except AttributeError:
classtype = type
import cherrypy
from cherrypy._cpcompat import set
class PageHandler(object):
"""Callable which sets response.body."""
def __init__(self, callable, *args, **kwargs):
self.callable = callable
self.args = args
self.kwargs = kwargs
def get_args(self):
return cherrypy.serving.request.args
def set_args(self, args):
cherrypy.serving.request.args = args
return cherrypy.serving.request.args
args = property(
get_args,
set_args,
doc="The ordered args should be accessible from post dispatch hooks"
)
def get_kwargs(self):
return cherrypy.serving.request.kwargs
def set_kwargs(self, kwargs):
cherrypy.serving.request.kwargs = kwargs
return cherrypy.serving.request.kwargs
kwargs = property(
get_kwargs,
set_kwargs,
doc="The named kwargs should be accessible from post dispatch hooks"
)
def __call__(self):
try:
return self.callable(*self.args, **self.kwargs)
except TypeError:
x = sys.exc_info()[1]
try:
test_callable_spec(self.callable, self.args, self.kwargs)
except cherrypy.HTTPError:
raise sys.exc_info()[1]
except:
raise x
raise
def test_callable_spec(callable, callable_args, callable_kwargs):
"""
Inspect callable and test to see if the given args are suitable for it.
When an error occurs while invoking the handler, there are two
error cases:
1. Too many parameters passed to a function which doesn't define
one of *args or **kwargs.
2. Too few parameters are passed to the function.
There are 3 sources of parameters to a cherrypy handler.
1. query string parameters are passed as keyword parameters to the
handler.
2. body parameters are also passed as keyword parameters.
3. when partial matching occurs, the final path atoms are passed as
positional args.
Both the query string and path atoms are part of the URI. If they are
incorrect, then a 404 Not Found should be raised. Conversely the body
parameters are part of the request; if they are invalid, a 400 Bad
Request should be raised.
"""
show_mismatched_params = getattr(
cherrypy.serving.request, 'show_mismatched_params', False)
try:
(args, varargs, varkw, defaults) = getargspec(callable)
except TypeError:
if isinstance(callable, object) and hasattr(callable, '__call__'):
(args, varargs, varkw,
defaults) = getargspec(callable.__call__)
else:
# If it wasn't one of our own types, re-raise
# the original error
raise
if args and args[0] == 'self':
args = args[1:]
arg_usage = dict([(arg, 0,) for arg in args])
vararg_usage = 0
varkw_usage = 0
extra_kwargs = set()
for i, value in enumerate(callable_args):
try:
arg_usage[args[i]] += 1
except IndexError:
vararg_usage += 1
for key in callable_kwargs.keys():
try:
arg_usage[key] += 1
except KeyError:
varkw_usage += 1
extra_kwargs.add(key)
# figure out which args have defaults.
args_with_defaults = args[-len(defaults or []):]
for i, val in enumerate(defaults or []):
# Defaults take effect only when the arg hasn't been used yet.
if arg_usage[args_with_defaults[i]] == 0:
arg_usage[args_with_defaults[i]] += 1
missing_args = []
multiple_args = []
for key, usage in arg_usage.items():
if usage == 0:
missing_args.append(key)
elif usage > 1:
multiple_args.append(key)
if missing_args:
# In the case where the method allows body arguments
# there are 3 potential errors:
# 1. not enough query string parameters -> 404
# 2. not enough body parameters -> 400
# 3. not enough path parts (partial matches) -> 404
#
# We can't actually tell which case it is,
# so I'm raising a 404 because that covers 2/3 of the
# possibilities
#
# In the case where the method does not allow body
# arguments it's definitely a 404.
message = None
if show_mismatched_params:
message = "Missing parameters: %s" % ",".join(missing_args)
raise cherrypy.HTTPError(404, message=message)
# the extra positional arguments come from the path - 404 Not Found
if not varargs and vararg_usage > 0:
raise cherrypy.HTTPError(404)
body_params = cherrypy.serving.request.body.params or {}
body_params = set(body_params.keys())
qs_params = set(callable_kwargs.keys()) - body_params
if multiple_args:
if qs_params.intersection(set(multiple_args)):
# If any of the multiple parameters came from the query string then
# it's a 404 Not Found
error = 404
else:
# Otherwise it's a 400 Bad Request
error = 400
message = None
if show_mismatched_params:
message = "Multiple values for parameters: "\
"%s" % ",".join(multiple_args)
raise cherrypy.HTTPError(error, message=message)
if not varkw and varkw_usage > 0:
# If there were extra query string parameters, it's a 404 Not Found
extra_qs_params = set(qs_params).intersection(extra_kwargs)
if extra_qs_params:
message = None
if show_mismatched_params:
message = "Unexpected query string "\
"parameters: %s" % ", ".join(extra_qs_params)
raise cherrypy.HTTPError(404, message=message)
# If there were any extra body parameters, it's a 400 Bad Request
extra_body_params = set(body_params).intersection(extra_kwargs)
if extra_body_params:
message = None
if show_mismatched_params:
message = "Unexpected body parameters: "\
"%s" % ", ".join(extra_body_params)
raise cherrypy.HTTPError(400, message=message)
try:
import inspect
except ImportError:
test_callable_spec = lambda callable, args, kwargs: None
else:
getargspec = inspect.getargspec
# Python 3 requires using getfullargspec if keyword-only arguments are present
if hasattr(inspect, 'getfullargspec'):
def getargspec(callable):
return inspect.getfullargspec(callable)[:4]
class LateParamPageHandler(PageHandler):
"""When passing cherrypy.request.params to the page handler, we do not
want to capture that dict too early; we want to give tools like the
decoding tool a chance to modify the params dict in-between the lookup
of the handler and the actual calling of the handler. This subclass
takes that into account, and allows request.params to be 'bound late'
(it's more complicated than that, but that's the effect).
"""
def _get_kwargs(self):
kwargs = cherrypy.serving.request.params.copy()
if self._kwargs:
kwargs.update(self._kwargs)
return kwargs
def _set_kwargs(self, kwargs):
cherrypy.serving.request.kwargs = kwargs
self._kwargs = kwargs
kwargs = property(_get_kwargs, _set_kwargs,
doc='page handler kwargs (with '
'cherrypy.request.params copied in)')
if sys.version_info < (3, 0):
punctuation_to_underscores = string.maketrans(
string.punctuation, '_' * len(string.punctuation))
def validate_translator(t):
if not isinstance(t, str) or len(t) != 256:
raise ValueError(
"The translate argument must be a str of len 256.")
else:
punctuation_to_underscores = str.maketrans(
string.punctuation, '_' * len(string.punctuation))
def validate_translator(t):
if not isinstance(t, dict):
raise ValueError("The translate argument must be a dict.")
class Dispatcher(object):
"""CherryPy Dispatcher which walks a tree of objects to find a handler.
The tree is rooted at cherrypy.request.app.root, and each hierarchical
component in the path_info argument is matched to a corresponding nested
attribute of the root object. Matching handlers must have an 'exposed'
attribute which evaluates to True. The special method name "index"
matches a URI which ends in a slash ("/"). The special method name
"default" may match a portion of the path_info (but only when no longer
substring of the path_info matches some other object).
This is the default, built-in dispatcher for CherryPy.
"""
dispatch_method_name = '_cp_dispatch'
"""
The name of the dispatch method that nodes may optionally implement
to provide their own dynamic dispatch algorithm.
"""
def __init__(self, dispatch_method_name=None,
translate=punctuation_to_underscores):
validate_translator(translate)
self.translate = translate
if dispatch_method_name:
self.dispatch_method_name = dispatch_method_name
def __call__(self, path_info):
"""Set handler and config for the current request."""
request = cherrypy.serving.request
func, vpath = self.find_handler(path_info)
if func:
# Decode any leftover %2F in the virtual_path atoms.
vpath = [x.replace("%2F", "/") for x in vpath]
request.handler = LateParamPageHandler(func, *vpath)
else:
request.handler = cherrypy.NotFound()
def find_handler(self, path):
"""Return the appropriate page handler, plus any virtual path.
This will return two objects. The first will be a callable,
which can be used to generate page output. Any parameters from
the query string or request body will be sent to that callable
as keyword arguments.
The callable is found by traversing the application's tree,
starting from cherrypy.request.app.root, and matching path
components to successive objects in the tree. For example, the
URL "/path/to/handler" might return root.path.to.handler.
The second object returned will be a list of names which are
'virtual path' components: parts of the URL which are dynamic,
and were not used when looking up the handler.
These virtual path components are passed to the handler as
positional arguments.
"""
request = cherrypy.serving.request
app = request.app
root = app.root
dispatch_name = self.dispatch_method_name
# Get config for the root object/path.
fullpath = [x for x in path.strip('/').split('/') if x] + ['index']
fullpath_len = len(fullpath)
segleft = fullpath_len
nodeconf = {}
if hasattr(root, "_cp_config"):
nodeconf.update(root._cp_config)
if "/" in app.config:
nodeconf.update(app.config["/"])
object_trail = [['root', root, nodeconf, segleft]]
node = root
iternames = fullpath[:]
while iternames:
name = iternames[0]
# map to legal Python identifiers (e.g. replace '.' with '_')
objname = name.translate(self.translate)
nodeconf = {}
subnode = getattr(node, objname, None)
pre_len = len(iternames)
if subnode is None:
dispatch = getattr(node, dispatch_name, None)
if dispatch and hasattr(dispatch, '__call__') and not \
getattr(dispatch, 'exposed', False) and \
pre_len > 1:
# Don't expose the hidden 'index' token to _cp_dispatch
# We skip this if pre_len == 1 since it makes no sense
# to call a dispatcher when we have no tokens left.
index_name = iternames.pop()
subnode = dispatch(vpath=iternames)
iternames.append(index_name)
else:
# We didn't find a path, but keep processing in case there
# is a default() handler.
iternames.pop(0)
else:
# We found the path, remove the vpath entry
iternames.pop(0)
segleft = len(iternames)
if segleft > pre_len:
# No path segment was removed. Raise an error.
raise cherrypy.CherryPyException(
"A vpath segment was added. Custom dispatchers may only "
+ "remove elements. While trying to process "
+ "{0} in {1}".format(name, fullpath)
)
elif segleft == pre_len:
# Assume that the handler used the current path segment, but
# did not pop it. This allows things like
# return getattr(self, vpath[0], None)
iternames.pop(0)
segleft -= 1
node = subnode
if node is not None:
# Get _cp_config attached to this node.
if hasattr(node, "_cp_config"):
nodeconf.update(node._cp_config)
# Mix in values from app.config for this path.
existing_len = fullpath_len - pre_len
if existing_len != 0:
curpath = '/' + '/'.join(fullpath[0:existing_len])
else:
curpath = ''
new_segs = fullpath[fullpath_len - pre_len:fullpath_len - segleft]
for seg in new_segs:
curpath += '/' + seg
if curpath in app.config:
nodeconf.update(app.config[curpath])
object_trail.append([name, node, nodeconf, segleft])
def set_conf():
"""Collapse all object_trail config into cherrypy.request.config.
"""
base = cherrypy.config.copy()
# Note that we merge the config from each node
# even if that node was None.
for name, obj, conf, segleft in object_trail:
base.update(conf)
if 'tools.staticdir.dir' in conf:
base['tools.staticdir.section'] = '/' + \
'/'.join(fullpath[0:fullpath_len - segleft])
return base
# Try successive objects (reverse order)
num_candidates = len(object_trail) - 1
for i in range(num_candidates, -1, -1):
name, candidate, nodeconf, segleft = object_trail[i]
if candidate is None:
continue
# Try a "default" method on the current leaf.
if hasattr(candidate, "default"):
defhandler = candidate.default
if getattr(defhandler, 'exposed', False):
# Insert any extra _cp_config from the default handler.
conf = getattr(defhandler, "_cp_config", {})
object_trail.insert(
i + 1, ["default", defhandler, conf, segleft])
request.config = set_conf()
# See https://bitbucket.org/cherrypy/cherrypy/issue/613
request.is_index = path.endswith("/")
return defhandler, fullpath[fullpath_len - segleft:-1]
# Uncomment the next line to restrict positional params to
# "default".
# if i < num_candidates - 2: continue
# Try the current leaf.
if getattr(candidate, 'exposed', False):
request.config = set_conf()
if i == num_candidates:
# We found the extra ".index". Mark request so tools
# can redirect if path_info has no trailing slash.
request.is_index = True
else:
# We're not at an 'index' handler. Mark request so tools
# can redirect if path_info has NO trailing slash.
# Note that this also includes handlers which take
# positional parameters (virtual paths).
request.is_index = False
return candidate, fullpath[fullpath_len - segleft:-1]
# We didn't find anything
request.config = set_conf()
return None, []
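# An illustrative sketch (not part of CherryPy) of the traversal above:
# with this hypothetical tree mounted at '/', the URL "/path/to/handler"
# resolves to root.path.to.handler, and "/" resolves to root.index.
def _example_tree():
    class To(object):
        def handler(self):
            return "found"
        handler.exposed = True
    class Path(object):
        to = To()
    class Root(object):
        path = Path()
        def index(self):
            return "home"
        index.exposed = True
    return Root()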
class MethodDispatcher(Dispatcher):
"""Additional dispatch based on cherrypy.request.method.upper().
Methods named GET, POST, etc will be called on an exposed class.
The method names must be all caps; the appropriate Allow header
will be output showing all capitalized method names as allowable
HTTP verbs.
Note that the containing class must be exposed, not the methods.
"""
def __call__(self, path_info):
"""Set handler and config for the current request."""
request = cherrypy.serving.request
resource, vpath = self.find_handler(path_info)
if resource:
# Set Allow header
avail = [m for m in dir(resource) if m.isupper()]
if "GET" in avail and "HEAD" not in avail:
avail.append("HEAD")
avail.sort()
cherrypy.serving.response.headers['Allow'] = ", ".join(avail)
# Find the subhandler
meth = request.method.upper()
func = getattr(resource, meth, None)
if func is None and meth == "HEAD":
func = getattr(resource, "GET", None)
if func:
# Grab any _cp_config on the subhandler.
if hasattr(func, "_cp_config"):
request.config.update(func._cp_config)
# Decode any leftover %2F in the virtual_path atoms.
vpath = [x.replace("%2F", "/") for x in vpath]
request.handler = LateParamPageHandler(func, *vpath)
else:
request.handler = cherrypy.HTTPError(405)
else:
request.handler = cherrypy.NotFound()
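# An illustrative sketch (not part of CherryPy): with MethodDispatcher
# the *class* is exposed and uppercase method names map to HTTP verbs.
# The mount shown in the trailing comment is hypothetical.
class _ExampleResource(object):
    exposed = True
    def GET(self):
        return "read"
    def PUT(self, data):
        return "wrote %r" % (data,)
# cherrypy.tree.mount(_ExampleResource(), '/res',
#     config={'/': {'request.dispatch': MethodDispatcher()}})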
class RoutesDispatcher(object):
"""A Routes based dispatcher for CherryPy."""
def __init__(self, full_result=False, **mapper_options):
"""
Routes dispatcher
Set full_result to True if you wish the controller
and the action to be passed on to the page handler as
parameters. By default they won't be.
"""
import routes
self.full_result = full_result
self.controllers = {}
self.mapper = routes.Mapper(**mapper_options)
self.mapper.controller_scan = self.controllers.keys
def connect(self, name, route, controller, **kwargs):
self.controllers[name] = controller
self.mapper.connect(name, route, controller=name, **kwargs)
def redirect(self, url):
raise cherrypy.HTTPRedirect(url)
def __call__(self, path_info):
"""Set handler and config for the current request."""
func = self.find_handler(path_info)
if func:
cherrypy.serving.request.handler = LateParamPageHandler(func)
else:
cherrypy.serving.request.handler = cherrypy.NotFound()
def find_handler(self, path_info):
"""Find the right page handler, and set request.config."""
import routes
request = cherrypy.serving.request
config = routes.request_config()
config.mapper = self.mapper
if hasattr(request, 'wsgi_environ'):
config.environ = request.wsgi_environ
config.host = request.headers.get('Host', None)
config.protocol = request.scheme
config.redirect = self.redirect
result = self.mapper.match(path_info)
config.mapper_dict = result
params = {}
if result:
params = result.copy()
if not self.full_result:
params.pop('controller', None)
params.pop('action', None)
request.params.update(params)
# Get config for the root object/path.
request.config = base = cherrypy.config.copy()
curpath = ""
def merge(nodeconf):
if 'tools.staticdir.dir' in nodeconf:
nodeconf['tools.staticdir.section'] = curpath or "/"
base.update(nodeconf)
app = request.app
root = app.root
if hasattr(root, "_cp_config"):
merge(root._cp_config)
if "/" in app.config:
merge(app.config["/"])
# Mix in values from app.config.
atoms = [x for x in path_info.split("/") if x]
if atoms:
last = atoms.pop()
else:
last = None
for atom in atoms:
curpath = "/".join((curpath, atom))
if curpath in app.config:
merge(app.config[curpath])
handler = None
if result:
controller = result.get('controller')
controller = self.controllers.get(controller, controller)
if controller:
if isinstance(controller, classtype):
controller = controller()
# Get config from the controller.
if hasattr(controller, "_cp_config"):
merge(controller._cp_config)
action = result.get('action')
if action is not None:
handler = getattr(controller, action, None)
# Get config from the handler
if hasattr(handler, "_cp_config"):
merge(handler._cp_config)
else:
handler = controller
# Do the last path atom here so it can
# override the controller's _cp_config.
if last:
curpath = "/".join((curpath, last))
if curpath in app.config:
merge(app.config[curpath])
return handler
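# An illustrative usage sketch (not part of CherryPy; requires the
# external 'routes' package). The route and controller are hypothetical.
def _example_routes_dispatcher():
    class Blog(object):
        def index(self):
            return "blog index"
    d = RoutesDispatcher()
    d.connect('blog', '/blog/:action', Blog())
    # Attach via app config: {'/': {'request.dispatch': d}}
    return d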
def XMLRPCDispatcher(next_dispatcher=Dispatcher()):
from cherrypy.lib import xmlrpcutil
def xmlrpc_dispatch(path_info):
path_info = xmlrpcutil.patched_path(path_info)
return next_dispatcher(path_info)
return xmlrpc_dispatch
def VirtualHost(next_dispatcher=Dispatcher(), use_x_forwarded_host=True,
**domains):
"""
Select a different handler based on the Host header.
This can be useful when running multiple sites within one CP server.
It allows several domains to point to different parts of a single
website structure. For example::
http://www.domain.example -> root
http://www.domain2.example -> root/domain2/
http://www.domain2.example:443 -> root/secure
can be accomplished via the following config::
[/]
request.dispatch = cherrypy.dispatch.VirtualHost(
**{'www.domain2.example': '/domain2',
'www.domain2.example:443': '/secure',
})
next_dispatcher
The next dispatcher object in the dispatch chain.
The VirtualHost dispatcher adds a prefix to the URL and calls
another dispatcher. Defaults to cherrypy.dispatch.Dispatcher().
use_x_forwarded_host
If True (the default), any "X-Forwarded-Host"
request header will be used instead of the "Host" header. This
is commonly added by HTTP servers (such as Apache) when proxying.
``**domains``
A dict of {host header value: virtual prefix} pairs.
The incoming "Host" request header is looked up in this dict,
and, if a match is found, the corresponding "virtual prefix"
value will be prepended to the URL path before calling the
next dispatcher. Note that you often need separate entries
for "example.com" and "www.example.com". In addition, "Host"
headers may contain the port number.
"""
from cherrypy.lib import httputil
def vhost_dispatch(path_info):
request = cherrypy.serving.request
header = request.headers.get
domain = header('Host', '')
if use_x_forwarded_host:
domain = header("X-Forwarded-Host", domain)
prefix = domains.get(domain, "")
if prefix:
path_info = httputil.urljoin(prefix, path_info)
result = next_dispatcher(path_info)
# Touch up staticdir config. See
# https://bitbucket.org/cherrypy/cherrypy/issue/614.
section = request.config.get('tools.staticdir.section')
if section:
section = section[len(prefix):]
request.config['tools.staticdir.section'] = section
return result
return vhost_dispatch

609
lib/cherrypy/_cperror.py Normal file
View File

@@ -0,0 +1,609 @@
"""Exception classes for CherryPy.
CherryPy provides (and uses) exceptions for declaring that the HTTP response
should be a status other than the default "200 OK". You can ``raise`` them like
normal Python exceptions. You can also call them and they will raise
themselves; this means you can set an
:class:`HTTPError<cherrypy._cperror.HTTPError>`
or :class:`HTTPRedirect<cherrypy._cperror.HTTPRedirect>` as the
:attr:`request.handler<cherrypy._cprequest.Request.handler>`.
.. _redirectingpost:
Redirecting POST
================
When you GET a resource and are redirected by the server to another Location,
there's generally no problem since GET is both a "safe method" (there should
be no side-effects) and an "idempotent method" (multiple calls are no different
than a single call).
POST, however, is neither safe nor idempotent--if you
charge a credit card, you don't want to be charged twice by a redirect!
For this reason, *none* of the 3xx responses permit a user-agent (browser) to
resubmit a POST on redirection without first confirming the action with the
user:
=====  =================================  ================================
300    Multiple Choices                   Confirm with the user
301    Moved Permanently                  Confirm with the user
302    Found (Object moved temporarily)   Confirm with the user
303    See Other                          GET the new URI--no confirmation
304    Not modified                       (for conditional GET only--POST
                                          should not raise this error)
305    Use Proxy                          Confirm with the user
307    Temporary Redirect                 Confirm with the user
=====  =================================  ================================
However, browsers have historically implemented these restrictions poorly;
in particular, many browsers do not force the user to confirm 301, 302
or 307 when redirecting POST. For this reason, CherryPy defaults to 303,
which most user-agents appear to have implemented correctly. Therefore, if
you raise HTTPRedirect for a POST request, the user-agent will most likely
attempt to GET the new URI (without asking for confirmation from the user).
We realize this is confusing for developers, but it's the safest thing we
could do. You are of course free to raise ``HTTPRedirect(uri, status=302)``
or any other 3xx status if you know what you're doing, but given the
environment, we couldn't let any of those be the default.
Custom Error Handling
=====================
.. image:: /refman/cperrors.gif
Anticipated HTTP responses
--------------------------
The 'error_page' config namespace can be used to provide custom HTML output for
expected responses (like 404 Not Found). Supply a filename from which the
output will be read. The contents will be interpolated with the values
%(status)s, %(message)s, %(traceback)s, and %(version)s using plain old Python
`string formatting <http://docs.python.org/2/library/stdtypes.html#string-formatting-operations>`_.
::
_cp_config = {
'error_page.404': os.path.join(localDir, "static/index.html")
}
Beginning in version 3.1, you may also provide a function or other callable as
an error_page entry. It will be passed the same status, message, traceback and
version arguments that are interpolated into templates::
def error_page_402(status, message, traceback, version):
return "Error %s - Well, I'm very sorry but you haven't paid!" % status
cherrypy.config.update({'error_page.402': error_page_402})
Also in 3.1, in addition to the numbered error codes, you may also supply
"error_page.default" to handle all codes which do not have their own error_page
entry.
Unanticipated errors
--------------------
CherryPy also has a generic error handling mechanism: whenever an unanticipated
error occurs in your code, it will call
:func:`Request.error_response<cherrypy._cprequest.Request.error_response>` to
set the response status, headers, and body. By default, this is the same
output as
:class:`HTTPError(500) <cherrypy._cperror.HTTPError>`. If you want to provide
some other behavior, you generally replace "request.error_response".
Here is some sample code that shows how to display a custom error message and
send an e-mail containing the error::
from cherrypy import _cperror
def handle_error():
cherrypy.response.status = 500
cherrypy.response.body = [
"<html><body>Sorry, an error occured</body></html>"
]
sendMail('error@domain.com',
'Error in your web app',
_cperror.format_exc())
class Root:
_cp_config = {'request.error_response': handle_error}
Note that you have to explicitly set
:attr:`response.body <cherrypy._cprequest.Response.body>`
and not simply return an error message as a result.
"""
from cgi import escape as _escape
from sys import exc_info as _exc_info
from traceback import format_exception as _format_exception
from cherrypy._cpcompat import basestring, bytestr, iteritems, ntob
from cherrypy._cpcompat import tonative, urljoin as _urljoin
from cherrypy.lib import httputil as _httputil
class CherryPyException(Exception):
"""A base class for CherryPy exceptions."""
pass
class TimeoutError(CherryPyException):
"""Exception raised when Response.timed_out is detected."""
pass
class InternalRedirect(CherryPyException):
"""Exception raised to switch to the handler for a different URL.
This exception will redirect processing to another path within the site
(without informing the client). Provide the new path as an argument when
raising the exception. Provide any params in the querystring for the new
URL.
"""
def __init__(self, path, query_string=""):
import cherrypy
self.request = cherrypy.serving.request
self.query_string = query_string
if "?" in path:
# Separate any params included in the path
path, self.query_string = path.split("?", 1)
# Note that urljoin will "do the right thing" whether url is:
# 1. a URL relative to root (e.g. "/dummy")
# 2. a URL relative to the current path
# Note that any query string will be discarded.
path = _urljoin(self.request.path_info, path)
# Set a 'path' member attribute so that code which traps this
# error can have access to it.
self.path = path
CherryPyException.__init__(self, path, self.query_string)
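# An illustrative usage sketch (not part of CherryPy): hand processing to
# another path on the server without informing the client. The path is
# hypothetical.
def _example_internal_redirect():
    raise InternalRedirect('/other/page?x=1')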
class HTTPRedirect(CherryPyException):
"""Exception raised when the request should be redirected.
This exception will force a HTTP redirect to the URL or URL's you give it.
The new URL must be passed as the first argument to the Exception,
e.g., HTTPRedirect(newUrl). Multiple URLs are allowed in a list.
If a URL is absolute, it will be used as-is. If it is relative, it is
assumed to be relative to the current cherrypy.request.path_info.
If one of the provided URLs is a unicode object, it will be encoded
using the default encoding or the one passed in parameter.
There are multiple types of redirect, from which you can select via the
``status`` argument. If you do not provide a ``status`` arg, it defaults to
303 (or 302 if responding with HTTP/1.0).
Examples::
raise cherrypy.HTTPRedirect("")
raise cherrypy.HTTPRedirect("/abs/path", 307)
raise cherrypy.HTTPRedirect(["path1", "path2?a=1&b=2"], 301)
See :ref:`redirectingpost` for additional caveats.
"""
status = None
"""The integer HTTP status code to emit."""
urls = None
"""The list of URL's to emit."""
encoding = 'utf-8'
"""The encoding when passed urls are not native strings"""
def __init__(self, urls, status=None, encoding=None):
import cherrypy
request = cherrypy.serving.request
if isinstance(urls, basestring):
urls = [urls]
abs_urls = []
for url in urls:
url = tonative(url, encoding or self.encoding)
# Note that urljoin will "do the right thing" whether url is:
# 1. a complete URL with host (e.g. "http://www.example.com/test")
# 2. a URL relative to root (e.g. "/dummy")
# 3. a URL relative to the current path
# Note that any query string in cherrypy.request is discarded.
url = _urljoin(cherrypy.url(), url)
abs_urls.append(url)
self.urls = abs_urls
# RFC 2616 indicates a 301 response code fits our goal; however,
# browser support for 301 is quite messy. Do 302/303 instead. See
# http://www.alanflavell.org.uk/www/post-redirect.html
if status is None:
if request.protocol >= (1, 1):
status = 303
else:
status = 302
else:
status = int(status)
if status < 300 or status > 399:
raise ValueError("status must be between 300 and 399.")
self.status = status
CherryPyException.__init__(self, abs_urls, status)
def set_response(self):
"""Modify cherrypy.response status, headers, and body to represent
self.
CherryPy uses this internally, but you can also use it to create an
HTTPRedirect object and set its output without *raising* the exception.
"""
import cherrypy
response = cherrypy.serving.response
response.status = status = self.status
if status in (300, 301, 302, 303, 307):
response.headers['Content-Type'] = "text/html;charset=utf-8"
# "The ... URI SHOULD be given by the Location field
# in the response."
response.headers['Location'] = self.urls[0]
# "Unless the request method was HEAD, the entity of the response
# SHOULD contain a short hypertext note with a hyperlink to the
# new URI(s)."
msg = {
300: "This resource can be found at ",
301: "This resource has permanently moved to ",
302: "This resource resides temporarily at ",
303: "This resource can be found at ",
307: "This resource has moved temporarily to ",
}[status]
msg += '<a href=%s>%s</a>.'
from xml.sax import saxutils
msgs = [msg % (saxutils.quoteattr(u), u) for u in self.urls]
response.body = ntob("<br />\n".join(msgs), 'utf-8')
# Previous code may have set C-L, so we have to reset it
# (allow finalize to set it).
response.headers.pop('Content-Length', None)
elif status == 304:
# Not Modified.
# "The response MUST include the following header fields:
# Date, unless its omission is required by section 14.18.1"
# The "Date" header should have been set in Response.__init__
# "...the response SHOULD NOT include other entity-headers."
for key in ('Allow', 'Content-Encoding', 'Content-Language',
'Content-Length', 'Content-Location', 'Content-MD5',
'Content-Range', 'Content-Type', 'Expires',
'Last-Modified'):
if key in response.headers:
del response.headers[key]
# "The 304 response MUST NOT contain a message-body."
response.body = None
# Previous code may have set C-L, so we have to reset it.
response.headers.pop('Content-Length', None)
elif status == 305:
# Use Proxy.
# self.urls[0] should be the URI of the proxy.
response.headers['Location'] = self.urls[0]
response.body = None
# Previous code may have set C-L, so we have to reset it.
response.headers.pop('Content-Length', None)
else:
raise ValueError("The %s status code is unknown." % status)
def __call__(self):
"""Use this exception as a request.handler (raise self)."""
raise self
def clean_headers(status):
"""Remove any headers which should not apply to an error response."""
import cherrypy
response = cherrypy.serving.response
# Remove headers which applied to the original content,
# but do not apply to the error page.
respheaders = response.headers
for key in ["Accept-Ranges", "Age", "ETag", "Location", "Retry-After",
"Vary", "Content-Encoding", "Content-Length", "Expires",
"Content-Location", "Content-MD5", "Last-Modified"]:
if key in respheaders:
del respheaders[key]
if status != 416:
# A server sending a response with status code 416 (Requested
# range not satisfiable) SHOULD include a Content-Range field
# with a byte-range-resp-spec of "*". The instance-length
# specifies the current length of the selected resource.
# A response with status code 206 (Partial Content) MUST NOT
# include a Content-Range field with a byte-range-resp-spec of "*".
if "Content-Range" in respheaders:
del respheaders["Content-Range"]
class HTTPError(CherryPyException):
"""Exception used to return an HTTP error code (4xx-5xx) to the client.
This exception can be used to automatically send a response using an
HTTP status code, with an appropriate error page. It takes an optional
``status`` argument (which must be between 400 and 599); it defaults to 500
("Internal Server Error"). It also takes an optional ``message`` argument,
which will be returned in the response body. See
`RFC2616 <http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.4>`_
for a complete list of available error codes and when to use them.
Examples::
raise cherrypy.HTTPError(403)
raise cherrypy.HTTPError(
"403 Forbidden", "You are not allowed to access this resource.")
"""
status = None
"""The HTTP status code. May be of type int or str (with a Reason-Phrase).
"""
code = None
"""The integer HTTP status code."""
reason = None
"""The HTTP Reason-Phrase string."""
def __init__(self, status=500, message=None):
self.status = status
try:
self.code, self.reason, defaultmsg = _httputil.valid_status(status)
except ValueError:
raise self.__class__(500, _exc_info()[1].args[0])
if self.code < 400 or self.code > 599:
raise ValueError("status must be between 400 and 599.")
# See http://www.python.org/dev/peps/pep-0352/
# self.message = message
self._message = message or defaultmsg
CherryPyException.__init__(self, status, message)
def set_response(self):
"""Modify cherrypy.response status, headers, and body to represent
self.
CherryPy uses this internally, but you can also use it to create an
HTTPError object and set its output without *raising* the exception.
"""
import cherrypy
response = cherrypy.serving.response
clean_headers(self.code)
# In all cases, finalize will be called after this method,
# so don't bother cleaning up response values here.
response.status = self.status
tb = None
if cherrypy.serving.request.show_tracebacks:
tb = format_exc()
response.headers.pop('Content-Length', None)
content = self.get_error_page(self.status, traceback=tb,
message=self._message)
response.body = content
_be_ie_unfriendly(self.code)
def get_error_page(self, *args, **kwargs):
return get_error_page(*args, **kwargs)
def __call__(self):
"""Use this exception as a request.handler (raise self)."""
raise self
class NotFound(HTTPError):
"""Exception raised when a URL could not be mapped to any handler (404).
This is equivalent to raising
:class:`HTTPError("404 Not Found") <cherrypy._cperror.HTTPError>`.
"""
def __init__(self, path=None):
if path is None:
import cherrypy
request = cherrypy.serving.request
path = request.script_name + request.path_info
self.args = (path,)
HTTPError.__init__(self, 404, "The path '%s' was not found." % path)
_HTTPErrorTemplate = '''<!DOCTYPE html PUBLIC
"-//W3C//DTD XHTML 1.0 Transitional//EN"
"http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
<html>
<head>
<meta http-equiv="Content-Type" content="text/html; charset=utf-8"></meta>
<title>%(status)s</title>
<style type="text/css">
#powered_by {
margin-top: 20px;
border-top: 2px solid black;
font-style: italic;
}
#traceback {
color: red;
}
</style>
</head>
<body>
<h2>%(status)s</h2>
<p>%(message)s</p>
<pre id="traceback">%(traceback)s</pre>
<div id="powered_by">
<span>
Powered by <a href="http://www.cherrypy.org">CherryPy %(version)s</a>
</span>
</div>
</body>
</html>
'''
def get_error_page(status, **kwargs):
"""Return an HTML page, containing a pretty error response.
status should be an int or a str.
kwargs will be interpolated into the page template.
"""
import cherrypy
try:
code, reason, message = _httputil.valid_status(status)
except ValueError:
raise cherrypy.HTTPError(500, _exc_info()[1].args[0])
# We can't use setdefault here, because some
# callers send None for kwarg values.
if kwargs.get('status') is None:
kwargs['status'] = "%s %s" % (code, reason)
if kwargs.get('message') is None:
kwargs['message'] = message
if kwargs.get('traceback') is None:
kwargs['traceback'] = ''
if kwargs.get('version') is None:
kwargs['version'] = cherrypy.__version__
for k, v in iteritems(kwargs):
if v is None:
kwargs[k] = ""
else:
kwargs[k] = _escape(kwargs[k])
# Use a custom template or callable for the error page?
pages = cherrypy.serving.request.error_page
error_page = pages.get(code) or pages.get('default')
# Default template, can be overridden below.
template = _HTTPErrorTemplate
if error_page:
try:
if hasattr(error_page, '__call__'):
# The caller function may be setting headers manually,
# so we delegate to it completely. We may be returning
# an iterator as well as a string here.
#
# We *must* make sure any content is not unicode.
result = error_page(**kwargs)
if cherrypy.lib.is_iterator(result):
from cherrypy.lib.encoding import UTF8StreamEncoder
return UTF8StreamEncoder(result)
elif isinstance(result, cherrypy._cpcompat.unicodestr):
return result.encode('utf-8')
else:
if not isinstance(result, cherrypy._cpcompat.bytestr):
raise ValueError('error page function did not '
'return a bytestring, unicodestring or an '
'iterator - returned object of type %s.'
% (type(result).__name__))
return result
else:
# Load the template from this path.
template = tonative(open(error_page, 'rb').read())
except:
e = _format_exception(*_exc_info())[-1]
m = kwargs['message']
if m:
m += "<br />"
m += "In addition, the custom error page failed:\n<br />%s" % e
kwargs['message'] = m
response = cherrypy.serving.response
response.headers['Content-Type'] = "text/html;charset=utf-8"
result = template % kwargs
return result.encode('utf-8')
_ie_friendly_error_sizes = {
400: 512, 403: 256, 404: 512, 405: 256,
406: 512, 408: 512, 409: 512, 410: 256,
500: 512, 501: 512, 505: 512,
}
def _be_ie_unfriendly(status):
import cherrypy
response = cherrypy.serving.response
# For some statuses, Internet Explorer 5+ shows "friendly error
# messages" instead of our response.body if the body is smaller
# than a given size. Fix this by returning a body over that size
# (by adding whitespace).
# See http://support.microsoft.com/kb/q218155/
s = _ie_friendly_error_sizes.get(status, 0)
if s:
s += 1
# Since we are issuing an HTTP error status, we assume that
# the entity is short, and we should just collapse it.
content = response.collapse_body()
l = len(content)
if l and l < s:
# IN ADDITION: the response must be written to IE
# in one chunk or it will still get replaced! Bah.
content = content + (ntob(" ") * (s - l))
response.body = content
response.headers['Content-Length'] = str(len(content))
def format_exc(exc=None):
"""Return exc (or sys.exc_info if None), formatted."""
try:
if exc is None:
exc = _exc_info()
if exc == (None, None, None):
return ""
import traceback
return "".join(traceback.format_exception(*exc))
finally:
del exc
def bare_error(extrabody=None):
"""Produce status, headers, body for a critical error.
Returns a triple without calling any other questionable functions,
so it should be as error-free as possible. Call it from an HTTP server
if you get errors outside of the request.
If extrabody is None, a friendly but rather unhelpful error message
is set in the body. If extrabody is a string, it will be appended
as-is to the body.
"""
# The whole point of this function is to be a last line-of-defense
# in handling errors. That is, it must not raise any errors itself;
# it cannot be allowed to fail. Therefore, don't add to it!
# In particular, don't call any other CP functions.
body = ntob("Unrecoverable error in the server.")
if extrabody is not None:
if not isinstance(extrabody, bytestr):
extrabody = extrabody.encode('utf-8')
body += ntob("\n") + extrabody
return (ntob("500 Internal Server Error"),
[(ntob('Content-Type'), ntob('text/plain')),
(ntob('Content-Length'), ntob(str(len(body)), 'ISO-8859-1'))],
[body])
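# For illustration, the default triple looks like this (the default body
# is 34 bytes long):
#
#     status, headers, body = bare_error()
#     # status  == b"500 Internal Server Error"
#     # headers == [(b"Content-Type", b"text/plain"),
#     #             (b"Content-Length", b"34")]
#     # body    == [b"Unrecoverable error in the server."]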

459
lib/cherrypy/_cplogging.py Normal file
View File

@@ -0,0 +1,459 @@
"""
Simple config
=============
Although CherryPy uses the :mod:`Python logging module <logging>`, it does so
behind the scenes so that simple logging is simple, but complicated logging
is still possible. "Simple" logging means that you can log to the screen
(i.e. console/stdout) or to a file, and that you can easily have separate
error and access log files.
Here are the simplified logging settings. You use these by adding lines to
your config file or dict. You should set these at either the global level or
per application (see next), but generally not both.
* ``log.screen``: Set this to True to have both "error" and "access" messages
printed to stdout.
* ``log.access_file``: Set this to an absolute filename where you want
"access" messages written.
* ``log.error_file``: Set this to an absolute filename where you want "error"
messages written.
Many events are automatically logged; to log your own application events, call
:func:`cherrypy.log`.
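For example (a minimal sketch; the file paths here are placeholders)::
#python
import cherrypy
cherrypy.config.update({
'log.screen': False,
'log.access_file': '/var/log/myapp/access.log',
'log.error_file': '/var/log/myapp/error.log',
})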
Architecture
============
Separate scopes
---------------
CherryPy provides log managers at both the global and application layers.
This means you can have one set of logging rules for your entire site,
and another set of rules specific to each application. The global log
manager is found at :func:`cherrypy.log`, and the log manager for each
application is found at :attr:`app.log<cherrypy._cptree.Application.log>`.
If you're inside a request, the latter is reachable from
``cherrypy.request.app.log``; if you're outside a request, you'll have to
obtain a reference to the ``app``: either the return value of
:func:`tree.mount()<cherrypy._cptree.Tree.mount>` or, if you used
:func:`quickstart()<cherrypy.quickstart>` instead, via
``cherrypy.tree.apps['/']``.
By default, the global logs are named "cherrypy.error" and "cherrypy.access",
and the application logs are named "cherrypy.error.2378745" and
"cherrypy.access.2378745" (the number is the id of the Application object).
This means that the application logs "bubble up" to the site logs, so if your
application has no log handlers, the site-level handlers will still log the
messages.
Errors vs. Access
-----------------
Each log manager handles both "access" messages (one per HTTP request) and
"error" messages (everything else). Note that the "error" log is not just for
errors! The format of access messages is highly formalized, but the error log
isn't--it receives messages from a variety of sources (including full error
tracebacks, if enabled).
If you are logging the access log and error log to the same source, then there
is a possibility that a specially crafted error message may replicate an access
log message as described in CWE-117. In this case it is the application
developer's responsibility to manually escape data before using CherryPy's log()
functionality, or they may create an application that is vulnerable to CWE-117.
Such escaping can be achieved by using a custom handler that escapes any
special characters, attached as described below.
Custom Handlers
===============
The simple settings above work by manipulating Python's standard :mod:`logging`
module. So when you need something more complex, the full power of the standard
module is yours to exploit. You can borrow or create custom handlers, formats,
filters, and much more. Here's an example that skips the standard FileHandler
and uses a RotatingFileHandler instead:
::
#python
from logging import DEBUG, handlers
from cherrypy import _cplogging
log = app.log
# Remove the default FileHandlers if present.
log.error_file = ""
log.access_file = ""
maxBytes = getattr(log, "rot_maxBytes", 10000000)
backupCount = getattr(log, "rot_backupCount", 1000)
# Make a new RotatingFileHandler for the error log.
fname = getattr(log, "rot_error_file", "error.log")
h = handlers.RotatingFileHandler(fname, 'a', maxBytes, backupCount)
h.setLevel(DEBUG)
h.setFormatter(_cplogging.logfmt)
log.error_log.addHandler(h)
# Make a new RotatingFileHandler for the access log.
fname = getattr(log, "rot_access_file", "access.log")
h = handlers.RotatingFileHandler(fname, 'a', maxBytes, backupCount)
h.setLevel(DEBUG)
h.setFormatter(_cplogging.logfmt)
log.access_log.addHandler(h)
The ``rot_*`` attributes are pulled straight from the application log object.
Since "log.*" config entries simply set attributes on the log object, you can
add custom attributes to your heart's content. Note that these handlers are
used *instead* of the default, simple handlers outlined above (so don't set
the "log.error_file" config entry, for example).
"""
import datetime
import logging
# Silence the no-handlers "warning" (stderr write!) in stdlib logging
logging.Logger.manager.emittedNoHandlerWarning = 1
logfmt = logging.Formatter("%(message)s")
import os
import sys
import cherrypy
from cherrypy import _cperror
from cherrypy._cpcompat import ntob, py3k
class NullHandler(logging.Handler):
"""A no-op logging handler to silence the logging.lastResort handler."""
def handle(self, record):
pass
def emit(self, record):
pass
def createLock(self):
self.lock = None
class LogManager(object):
"""An object to assist both simple and advanced logging.
``cherrypy.log`` is an instance of this class.
"""
appid = None
"""The id() of the Application object which owns this log manager. If this
is a global log manager, appid is None."""
error_log = None
"""The actual :class:`logging.Logger` instance for error messages."""
access_log = None
"""The actual :class:`logging.Logger` instance for access messages."""
if py3k:
access_log_format = \
'{h} {l} {u} {t} "{r}" {s} {b} "{f}" "{a}"'
else:
access_log_format = \
'%(h)s %(l)s %(u)s %(t)s "%(r)s" %(s)s %(b)s "%(f)s" "%(a)s"'
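# An example line in this (Apache/NCSA Combined) format, with purely
# illustrative values:
#   127.0.0.1 - - [29/Jan/2016:15:38:12] "GET / HTTP/1.1" 200 1234 "-" "curl/7.43.0"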
logger_root = None
"""The "top-level" logger name.
This string will be used as the first segment in the Logger names.
The default is "cherrypy", for example, in which case the Logger names
will be of the form::
cherrypy.error.<appid>
cherrypy.access.<appid>
"""
def __init__(self, appid=None, logger_root="cherrypy"):
self.logger_root = logger_root
self.appid = appid
if appid is None:
self.error_log = logging.getLogger("%s.error" % logger_root)
self.access_log = logging.getLogger("%s.access" % logger_root)
else:
self.error_log = logging.getLogger(
"%s.error.%s" % (logger_root, appid))
self.access_log = logging.getLogger(
"%s.access.%s" % (logger_root, appid))
self.error_log.setLevel(logging.INFO)
self.access_log.setLevel(logging.INFO)
# Silence the no-handlers "warning" (stderr write!) in stdlib logging
self.error_log.addHandler(NullHandler())
self.access_log.addHandler(NullHandler())
cherrypy.engine.subscribe('graceful', self.reopen_files)
def reopen_files(self):
"""Close and reopen all file handlers."""
for log in (self.error_log, self.access_log):
for h in log.handlers:
if isinstance(h, logging.FileHandler):
h.acquire()
h.stream.close()
h.stream = open(h.baseFilename, h.mode)
h.release()
def error(self, msg='', context='', severity=logging.INFO,
traceback=False):
"""Write the given ``msg`` to the error log.
This is not just for errors! Applications may call this at any time
to log application-specific information.
If ``traceback`` is True, the traceback of the current exception
(if any) will be appended to ``msg``.
"""
if traceback:
msg += _cperror.format_exc()
self.error_log.log(severity, ' '.join((self.time(), context, msg)))
def __call__(self, *args, **kwargs):
"""An alias for ``error``."""
return self.error(*args, **kwargs)
def access(self):
"""Write to the access log (in Apache/NCSA Combined Log format).
See the
`apache documentation <http://httpd.apache.org/docs/current/logs.html#combined>`_
for format details.
CherryPy calls this automatically for you. Note there are no arguments;
it collects the data itself from
:class:`cherrypy.request<cherrypy._cprequest.Request>`.
Like Apache started doing in 2.0.46, non-printable and other special
characters in %r (and we expand that to all parts) are escaped using
\\xhh sequences, where hh stands for the hexadecimal representation
of the raw byte. Exceptions from this rule are " and \\, which are
escaped by prepending a backslash, and all whitespace characters,
which are written in their C-style notation (\\n, \\t, etc).
"""
request = cherrypy.serving.request
remote = request.remote
response = cherrypy.serving.response
outheaders = response.headers
inheaders = request.headers
if response.output_status is None:
status = "-"
else:
status = response.output_status.split(ntob(" "), 1)[0]
if py3k:
status = status.decode('ISO-8859-1')
atoms = {'h': remote.name or remote.ip,
'l': '-',
'u': getattr(request, "login", None) or "-",
't': self.time(),
'r': request.request_line,
's': status,
'b': dict.get(outheaders, 'Content-Length', '') or "-",
'f': dict.get(inheaders, 'Referer', ''),
'a': dict.get(inheaders, 'User-Agent', ''),
'o': dict.get(inheaders, 'Host', '-'),
}
if py3k:
for k, v in atoms.items():
if not isinstance(v, str):
v = str(v)
v = v.replace('"', '\\"').encode('utf8')
# Fortunately, repr(str) escapes unprintable chars, \n, \t, etc
# and backslash for us. All we have to do is strip the quotes.
v = repr(v)[2:-1]
# In Python 3 the repr of bytes (as returned by encode) uses
# doubled backslashes. But then the logger escapes them yet
# again, resulting in quadruple slashes. Remove the extras here.
v = v.replace('\\\\', '\\')
# (The double-quote was already escaped above, before repr.)
atoms[k] = v
try:
self.access_log.log(
logging.INFO, self.access_log_format.format(**atoms))
except:
self(traceback=True)
else:
for k, v in atoms.items():
if isinstance(v, unicode):
v = v.encode('utf8')
elif not isinstance(v, str):
v = str(v)
# Fortunately, repr(str) escapes unprintable chars, \n, \t, etc
# and backslash for us. All we have to do is strip the quotes.
v = repr(v)[1:-1]
# Escape double-quote.
atoms[k] = v.replace('"', '\\"')
try:
self.access_log.log(
logging.INFO, self.access_log_format % atoms)
except:
self(traceback=True)
def time(self):
"""Return now() in Apache Common Log Format (no timezone)."""
now = datetime.datetime.now()
monthnames = ['jan', 'feb', 'mar', 'apr', 'may', 'jun',
'jul', 'aug', 'sep', 'oct', 'nov', 'dec']
month = monthnames[now.month - 1].capitalize()
return ('[%02d/%s/%04d:%02d:%02d:%02d]' %
(now.day, month, now.year, now.hour, now.minute, now.second))
def _get_builtin_handler(self, log, key):
for h in log.handlers:
if getattr(h, "_cpbuiltin", None) == key:
return h
# ------------------------- Screen handlers ------------------------- #
def _set_screen_handler(self, log, enable, stream=None):
h = self._get_builtin_handler(log, "screen")
if enable:
if not h:
if stream is None:
stream = sys.stderr
h = logging.StreamHandler(stream)
h.setFormatter(logfmt)
h._cpbuiltin = "screen"
log.addHandler(h)
elif h:
log.handlers.remove(h)
def _get_screen(self):
h = self._get_builtin_handler
has_h = h(self.error_log, "screen") or h(self.access_log, "screen")
return bool(has_h)
def _set_screen(self, newvalue):
self._set_screen_handler(self.error_log, newvalue, stream=sys.stderr)
self._set_screen_handler(self.access_log, newvalue, stream=sys.stdout)
screen = property(_get_screen, _set_screen,
doc="""Turn stderr/stdout logging on or off.
If you set this to True, it'll add the appropriate StreamHandler for
you. If you set it to False, it will remove the handler.
""")
# -------------------------- File handlers -------------------------- #
def _add_builtin_file_handler(self, log, fname):
h = logging.FileHandler(fname)
h.setFormatter(logfmt)
h._cpbuiltin = "file"
log.addHandler(h)
def _set_file_handler(self, log, filename):
h = self._get_builtin_handler(log, "file")
if filename:
if h:
if h.baseFilename != os.path.abspath(filename):
h.close()
log.handlers.remove(h)
self._add_builtin_file_handler(log, filename)
else:
self._add_builtin_file_handler(log, filename)
else:
if h:
h.close()
log.handlers.remove(h)
def _get_error_file(self):
h = self._get_builtin_handler(self.error_log, "file")
if h:
return h.baseFilename
return ''
def _set_error_file(self, newvalue):
self._set_file_handler(self.error_log, newvalue)
error_file = property(_get_error_file, _set_error_file,
doc="""The filename for self.error_log.
If you set this to a string, it'll add the appropriate FileHandler for
you. If you set it to ``None`` or ``''``, it will remove the handler.
""")
def _get_access_file(self):
h = self._get_builtin_handler(self.access_log, "file")
if h:
return h.baseFilename
return ''
def _set_access_file(self, newvalue):
self._set_file_handler(self.access_log, newvalue)
access_file = property(_get_access_file, _set_access_file,
doc="""The filename for self.access_log.
If you set this to a string, it'll add the appropriate FileHandler for
you. If you set it to ``None`` or ``''``, it will remove the handler.
""")
# ------------------------- WSGI handlers ------------------------- #
def _set_wsgi_handler(self, log, enable):
h = self._get_builtin_handler(log, "wsgi")
if enable:
if not h:
h = WSGIErrorHandler()
h.setFormatter(logfmt)
h._cpbuiltin = "wsgi"
log.addHandler(h)
elif h:
log.handlers.remove(h)
def _get_wsgi(self):
return bool(self._get_builtin_handler(self.error_log, "wsgi"))
def _set_wsgi(self, newvalue):
self._set_wsgi_handler(self.error_log, newvalue)
wsgi = property(_get_wsgi, _set_wsgi,
doc="""Write errors to wsgi.errors.
If you set this to True, it'll add the appropriate
:class:`WSGIErrorHandler<cherrypy._cplogging.WSGIErrorHandler>` for you
(which writes errors to ``wsgi.errors``).
If you set it to False, it will remove the handler.
""")
class WSGIErrorHandler(logging.Handler):
"A handler class which writes logging records to environ['wsgi.errors']."
def flush(self):
"""Flushes the stream."""
try:
stream = cherrypy.serving.request.wsgi_environ.get('wsgi.errors')
except (AttributeError, KeyError):
pass
else:
stream.flush()
def emit(self, record):
"""Emit a record."""
try:
stream = cherrypy.serving.request.wsgi_environ.get('wsgi.errors')
except (AttributeError, KeyError):
pass
else:
try:
msg = self.format(record)
fs = "%s\n"
import types
# if no unicode support...
if not hasattr(types, "UnicodeType"):
stream.write(fs % msg)
else:
try:
stream.write(fs % msg)
except UnicodeError:
stream.write(fs % msg.encode("UTF-8"))
self.flush()
except:
self.handleError(record)
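# Usage sketch (``app`` is assumed to be a mounted cherrypy Application):
#
#     app.log.screen = True                      # add stderr/stdout handlers
#     app.log.error_file = '/tmp/myapp.err.log'  # add a FileHandler
#     app.log("something happened")              # alias for app.log.error()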

353
lib/cherrypy/_cpmodpy.py Normal file
View File

@@ -0,0 +1,353 @@
"""Native adapter for serving CherryPy via mod_python
Basic usage:
##########################################
# Application in a module called myapp.py
##########################################
import cherrypy
class Root:
@cherrypy.expose
def index(self):
return 'Hi there, Ho there, Hey there'
# We will use this method from the mod_python configuration
# as the entry point to our application
def setup_server():
cherrypy.tree.mount(Root())
cherrypy.config.update({'environment': 'production',
'log.screen': False,
'show_tracebacks': False})
##########################################
# mod_python settings for apache2
# This should reside in your httpd.conf
# or a file that will be loaded at
# apache startup
##########################################
# Start
DocumentRoot "/"
Listen 8080
LoadModule python_module /usr/lib/apache2/modules/mod_python.so
<Location "/">
PythonPath "sys.path+['/path/to/my/application']"
SetHandler python-program
PythonHandler cherrypy._cpmodpy::handler
PythonOption cherrypy.setup myapp::setup_server
PythonDebug On
</Location>
# End
The actual path to your mod_python.so is dependent on your
environment. In this case we suppose a global mod_python
installation on a Linux distribution such as Ubuntu.
We set the PythonPath configuration setting so that
your application can be found by the Python interpreter
running inside the apache2 instance. Of course, if your application
resides in the global site-packages directory, this won't be needed.
Then restart apache2 and access http://127.0.0.1:8080
"""
import logging
import sys
import cherrypy
from cherrypy._cpcompat import BytesIO, copyitems, ntob
from cherrypy._cperror import format_exc, bare_error
from cherrypy.lib import httputil
# ------------------------------ Request-handling
def setup(req):
from mod_python import apache
# Run any setup functions defined by a "PythonOption cherrypy.setup"
# directive.
options = req.get_options()
if 'cherrypy.setup' in options:
for function in options['cherrypy.setup'].split():
atoms = function.split('::', 1)
if len(atoms) == 1:
mod = __import__(atoms[0], globals(), locals())
else:
modname, fname = atoms
mod = __import__(modname, globals(), locals(), [fname])
func = getattr(mod, fname)
func()
cherrypy.config.update({'log.screen': False,
"tools.ignore_headers.on": True,
"tools.ignore_headers.headers": ['Range'],
})
engine = cherrypy.engine
if hasattr(engine, "signal_handler"):
engine.signal_handler.unsubscribe()
if hasattr(engine, "console_control_handler"):
engine.console_control_handler.unsubscribe()
engine.autoreload.unsubscribe()
cherrypy.server.unsubscribe()
def _log(msg, level):
newlevel = apache.APLOG_ERR
if logging.DEBUG >= level:
newlevel = apache.APLOG_DEBUG
elif logging.INFO >= level:
newlevel = apache.APLOG_INFO
elif logging.WARNING >= level:
newlevel = apache.APLOG_WARNING
# On Windows, req.server is required or the msg will vanish. See
# http://www.modpython.org/pipermail/mod_python/2003-October/014291.html
# Also, "When server is not specified...LogLevel does not apply..."
apache.log_error(msg, newlevel, req.server)
engine.subscribe('log', _log)
engine.start()
def cherrypy_cleanup(data):
engine.exit()
try:
# apache.register_cleanup wasn't available until 3.1.4.
apache.register_cleanup(cherrypy_cleanup)
except AttributeError:
req.server.register_cleanup(req, cherrypy_cleanup)
class _ReadOnlyRequest:
expose = ('read', 'readline', 'readlines')
def __init__(self, req):
for method in self.expose:
self.__dict__[method] = getattr(req, method)
recursive = False
_isSetUp = False
def handler(req):
from mod_python import apache
try:
global _isSetUp
if not _isSetUp:
setup(req)
_isSetUp = True
# Obtain a Request object from CherryPy
local = req.connection.local_addr
local = httputil.Host(
local[0], local[1], req.connection.local_host or "")
remote = req.connection.remote_addr
remote = httputil.Host(
remote[0], remote[1], req.connection.remote_host or "")
scheme = req.parsed_uri[0] or 'http'
req.get_basic_auth_pw()
try:
# apache.mpm_query only became available in mod_python 3.1
q = apache.mpm_query
threaded = q(apache.AP_MPMQ_IS_THREADED)
forked = q(apache.AP_MPMQ_IS_FORKED)
except AttributeError:
bad_value = ("You must provide a PythonOption '%s', "
"either 'on' or 'off', when running a version "
"of mod_python < 3.1")
threaded = options.get('multithread', '').lower()
if threaded == 'on':
threaded = True
elif threaded == 'off':
threaded = False
else:
raise ValueError(bad_value % "multithread")
forked = options.get('multiprocess', '').lower()
if forked == 'on':
forked = True
elif forked == 'off':
forked = False
else:
raise ValueError(bad_value % "multiprocess")
sn = cherrypy.tree.script_name(req.uri or "/")
if sn is None:
send_response(req, '404 Not Found', [], '')
else:
app = cherrypy.tree.apps[sn]
method = req.method
path = req.uri
qs = req.args or ""
reqproto = req.protocol
headers = copyitems(req.headers_in)
rfile = _ReadOnlyRequest(req)
prev = None
try:
redirections = []
while True:
request, response = app.get_serving(local, remote, scheme,
"HTTP/1.1")
request.login = req.user
request.multithread = bool(threaded)
request.multiprocess = bool(forked)
request.app = app
request.prev = prev
# Run the CherryPy Request object and obtain the response
try:
request.run(method, path, qs, reqproto, headers, rfile)
break
except cherrypy.InternalRedirect:
ir = sys.exc_info()[1]
app.release_serving()
prev = request
if not recursive:
if ir.path in redirections:
raise RuntimeError(
"InternalRedirector visited the same URL "
"twice: %r" % ir.path)
else:
# Add the *previous* path_info + qs to
# redirections.
if qs:
qs = "?" + qs
redirections.append(sn + path + qs)
# Munge environment and try again.
method = "GET"
path = ir.path
qs = ir.query_string
rfile = BytesIO()
send_response(
req, response.output_status, response.header_list,
response.body, response.stream)
finally:
app.release_serving()
except:
tb = format_exc()
cherrypy.log(tb, 'MOD_PYTHON', severity=logging.ERROR)
s, h, b = bare_error()
send_response(req, s, h, b)
return apache.OK
def send_response(req, status, headers, body, stream=False):
# Set response status
req.status = int(status[:3])
# Set response headers
req.content_type = "text/plain"
for header, value in headers:
if header.lower() == 'content-type':
req.content_type = value
continue
req.headers_out.add(header, value)
if stream:
# Flush now so the status and headers are sent immediately.
req.flush()
# Set response body
if isinstance(body, basestring):
req.write(body)
else:
for seg in body:
req.write(seg)
# --------------- Startup tools for CherryPy + mod_python --------------- #
import os
import re
try:
import subprocess
def popen(fullcmd):
p = subprocess.Popen(fullcmd, shell=True,
stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
close_fds=True)
return p.stdout
except ImportError:
def popen(fullcmd):
pipein, pipeout = os.popen4(fullcmd)
return pipeout
def read_process(cmd, args=""):
fullcmd = "%s %s" % (cmd, args)
pipeout = popen(fullcmd)
try:
firstline = pipeout.readline()
cmd_not_found = re.search(
ntob("(not recognized|No such file|not found)"),
firstline,
re.IGNORECASE
)
if cmd_not_found:
raise IOError('%s must be on your system path.' % cmd)
output = firstline + pipeout.read()
finally:
pipeout.close()
return output
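# Usage sketch (the command shown is illustrative):
#
#     banner = read_process("apache2ctl", "-v")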
class ModPythonServer(object):
template = """
# Apache2 server configuration file for running CherryPy with mod_python.
DocumentRoot "/"
Listen %(port)s
LoadModule python_module modules/mod_python.so
<Location %(loc)s>
SetHandler python-program
PythonHandler %(handler)s
PythonDebug On
%(opts)s
</Location>
"""
def __init__(self, loc="/", port=80, opts=None, apache_path="apache",
handler="cherrypy._cpmodpy::handler"):
self.loc = loc
self.port = port
self.opts = opts
self.apache_path = apache_path
self.handler = handler
def start(self):
opts = "".join([" PythonOption %s %s\n" % (k, v)
for k, v in self.opts])
conf_data = self.template % {"port": self.port,
"loc": self.loc,
"opts": opts,
"handler": self.handler,
}
mpconf = os.path.join(os.path.dirname(__file__), "cpmodpy.conf")
f = open(mpconf, 'wb')
try:
f.write(conf_data)
finally:
f.close()
response = read_process(self.apache_path, "-k start -f %s" % mpconf)
self.ready = True
return response
def stop(self):
os.popen("apache -k stop")
self.ready = False
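# A minimal usage sketch (port and option values are illustrative):
#
#     server = ModPythonServer(
#         loc="/", port=8080,
#         opts=[("cherrypy.setup", "myapp::setup_server")])
#     server.start()
#     ...
#     server.stop()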

154
lib/cherrypy/_cpnative_server.py Normal file
View File

@@ -0,0 +1,154 @@
"""Native adapter for serving CherryPy via its builtin server."""
import logging
import sys
import cherrypy
from cherrypy._cpcompat import BytesIO
from cherrypy._cperror import format_exc, bare_error
from cherrypy.lib import httputil
from cherrypy import wsgiserver
class NativeGateway(wsgiserver.Gateway):
recursive = False
def respond(self):
req = self.req
try:
# Obtain a Request object from CherryPy
local = req.server.bind_addr
local = httputil.Host(local[0], local[1], "")
remote = req.conn.remote_addr, req.conn.remote_port
remote = httputil.Host(remote[0], remote[1], "")
scheme = req.scheme
sn = cherrypy.tree.script_name(req.uri or "/")
if sn is None:
self.send_response('404 Not Found', [], [''])
else:
app = cherrypy.tree.apps[sn]
method = req.method
path = req.path
qs = req.qs or ""
headers = req.inheaders.items()
rfile = req.rfile
prev = None
try:
redirections = []
while True:
request, response = app.get_serving(
local, remote, scheme, "HTTP/1.1")
request.multithread = True
request.multiprocess = False
request.app = app
request.prev = prev
# Run the CherryPy Request object and obtain the
# response
try:
request.run(method, path, qs,
req.request_protocol, headers, rfile)
break
except cherrypy.InternalRedirect:
ir = sys.exc_info()[1]
app.release_serving()
prev = request
if not self.recursive:
if ir.path in redirections:
raise RuntimeError(
"InternalRedirector visited the same "
"URL twice: %r" % ir.path)
else:
# Add the *previous* path_info + qs to
# redirections.
if qs:
qs = "?" + qs
redirections.append(sn + path + qs)
# Munge environment and try again.
method = "GET"
path = ir.path
qs = ir.query_string
rfile = BytesIO()
self.send_response(
response.output_status, response.header_list,
response.body)
finally:
app.release_serving()
except:
tb = format_exc()
# print tb
cherrypy.log(tb, 'NATIVE_ADAPTER', severity=logging.ERROR)
s, h, b = bare_error()
self.send_response(s, h, b)
def send_response(self, status, headers, body):
req = self.req
# Set response status
req.status = str(status or "500 Server Error")
# Set response headers
for header, value in headers:
req.outheaders.append((header, value))
if (req.ready and not req.sent_headers):
req.sent_headers = True
req.send_headers()
# Set response body
for seg in body:
req.write(seg)
class CPHTTPServer(wsgiserver.HTTPServer):
"""Wrapper for wsgiserver.HTTPServer.
wsgiserver has been designed to not reference CherryPy in any way,
so that it can be used in other frameworks and applications.
Therefore, we wrap it here, so we can apply some attributes
from config -> cherrypy.server -> HTTPServer.
"""
def __init__(self, server_adapter=cherrypy.server):
self.server_adapter = server_adapter
server_name = (self.server_adapter.socket_host or
self.server_adapter.socket_file or
None)
wsgiserver.HTTPServer.__init__(
self, server_adapter.bind_addr, NativeGateway,
minthreads=server_adapter.thread_pool,
maxthreads=server_adapter.thread_pool_max,
server_name=server_name)
self.max_request_header_size = (
self.server_adapter.max_request_header_size or 0)
self.max_request_body_size = (
self.server_adapter.max_request_body_size or 0)
self.request_queue_size = self.server_adapter.socket_queue_size
self.timeout = self.server_adapter.socket_timeout
self.shutdown_timeout = self.server_adapter.shutdown_timeout
self.protocol = self.server_adapter.protocol_version
self.nodelay = self.server_adapter.nodelay
ssl_module = self.server_adapter.ssl_module or 'pyopenssl'
if self.server_adapter.ssl_context:
adapter_class = wsgiserver.get_ssl_adapter_class(ssl_module)
self.ssl_adapter = adapter_class(
self.server_adapter.ssl_certificate,
self.server_adapter.ssl_private_key,
self.server_adapter.ssl_certificate_chain)
self.ssl_adapter.context = self.server_adapter.ssl_context
elif self.server_adapter.ssl_certificate:
adapter_class = wsgiserver.get_ssl_adapter_class(ssl_module)
self.ssl_adapter = adapter_class(
self.server_adapter.ssl_certificate,
self.server_adapter.ssl_private_key,
self.server_adapter.ssl_certificate_chain)
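# Usage sketch, assuming the standard recipe for swapping in this native
# (non-WSGI) server:
#
#     import cherrypy
#     from cherrypy._cpnative_server import CPHTTPServer
#
#     cherrypy.server.httpserver = CPHTTPServer(cherrypy.server)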

1013
lib/cherrypy/_cpreqbody.py Normal file

File diff suppressed because it is too large

973
lib/cherrypy/_cprequest.py Normal file
View File

@@ -0,0 +1,973 @@
import os
import sys
import time
import warnings
import cherrypy
from cherrypy._cpcompat import basestring, copykeys, ntob, unicodestr
from cherrypy._cpcompat import SimpleCookie, CookieError, py3k
from cherrypy import _cpreqbody, _cpconfig
from cherrypy._cperror import format_exc, bare_error
from cherrypy.lib import httputil, file_generator
class Hook(object):
"""A callback and its metadata: failsafe, priority, and kwargs."""
callback = None
"""
The bare callable that this Hook object is wrapping, which will
be called when the Hook is called."""
failsafe = False
"""
If True, the callback is guaranteed to run even if other callbacks
from the same call point raise exceptions."""
priority = 50
"""
Defines the order of execution for a list of Hooks. Priority numbers
should be limited to the closed interval [0, 100], but values outside
this range are acceptable, as are fractional values."""
kwargs = {}
"""
A set of keyword arguments that will be passed to the
callable on each call."""
def __init__(self, callback, failsafe=None, priority=None, **kwargs):
self.callback = callback
if failsafe is None:
failsafe = getattr(callback, "failsafe", False)
self.failsafe = failsafe
if priority is None:
priority = getattr(callback, "priority", 50)
self.priority = priority
self.kwargs = kwargs
def __lt__(self, other):
# Python 3
return self.priority < other.priority
def __cmp__(self, other):
# Python 2
return cmp(self.priority, other.priority)
def __call__(self):
"""Run self.callback(**self.kwargs)."""
return self.callback(**self.kwargs)
def __repr__(self):
cls = self.__class__
return ("%s.%s(callback=%r, failsafe=%r, priority=%r, %s)"
% (cls.__module__, cls.__name__, self.callback,
self.failsafe, self.priority,
", ".join(['%s=%r' % (k, v)
for k, v in self.kwargs.items()])))
class HookMap(dict):
"""A map of call points to lists of callbacks (Hook objects)."""
def __new__(cls, points=None):
d = dict.__new__(cls)
for p in points or []:
d[p] = []
return d
def __init__(self, *a, **kw):
pass
def attach(self, point, callback, failsafe=None, priority=None, **kwargs):
"""Append a new Hook made from the supplied arguments."""
self[point].append(Hook(callback, failsafe, priority, **kwargs))
def run(self, point):
"""Execute all registered Hooks (callbacks) for the given point."""
exc = None
hooks = self[point]
hooks.sort()
for hook in hooks:
# Some hooks are guaranteed to run even if others at
# the same hookpoint fail. We will still log the failure,
# but proceed on to the next hook. The only way
# to stop all processing from one of these hooks is
# to raise SystemExit and stop the whole server.
if exc is None or hook.failsafe:
try:
hook()
except (KeyboardInterrupt, SystemExit):
raise
except (cherrypy.HTTPError, cherrypy.HTTPRedirect,
cherrypy.InternalRedirect):
exc = sys.exc_info()[1]
except:
exc = sys.exc_info()[1]
cherrypy.log(traceback=True, severity=40)
if exc:
raise exc
def __copy__(self):
newmap = self.__class__()
# We can't just use 'update' because we want copies of the
# mutable values (each is a list) as well.
for k, v in self.items():
newmap[k] = v[:]
return newmap
copy = __copy__
def __repr__(self):
cls = self.__class__
return "%s.%s(points=%r)" % (
cls.__module__,
cls.__name__,
copykeys(self)
)
# Config namespace handlers
def hooks_namespace(k, v):
"""Attach bare hooks declared in config."""
# Use split again to allow multiple hooks for a single
# hookpoint per path (e.g. "hooks.before_handler.1").
# Little-known fact you only get from reading source ;)
hookpoint = k.split(".", 1)[0]
if isinstance(v, basestring):
v = cherrypy.lib.attributes(v)
if not isinstance(v, Hook):
v = Hook(v)
cherrypy.serving.request.hooks[hookpoint].append(v)
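# Config sketch (the hook callables are hypothetical): attaching two hooks
# to the same hookpoint for one path via numbered keys, as parsed above:
#
#     config = {'/': {
#         'hooks.before_handler.1': my_auth_check,
#         'hooks.before_handler.2': my_rate_limiter,
#     }}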
def request_namespace(k, v):
"""Attach request attributes declared in config."""
# Provides config entries to set request.body attrs (like
# attempt_charsets).
if k[:5] == 'body.':
setattr(cherrypy.serving.request.body, k[5:], v)
else:
setattr(cherrypy.serving.request, k, v)
def response_namespace(k, v):
"""Attach response attributes declared in config."""
# Provides config entries to set default response headers
# http://cherrypy.org/ticket/889
if k[:8] == 'headers.':
cherrypy.serving.response.headers[k.split('.', 1)[1]] = v
else:
setattr(cherrypy.serving.response, k, v)
def error_page_namespace(k, v):
"""Attach error pages declared in config."""
if k != 'default':
k = int(k)
cherrypy.serving.request.error_page[k] = v
hookpoints = ['on_start_resource', 'before_request_body',
'before_handler', 'before_finalize',
'on_end_resource', 'on_end_request',
'before_error_response', 'after_error_response']
class Request(object):
"""An HTTP request.
This object represents the metadata of an HTTP request message;
that is, it contains attributes which describe the environment
in which the request URL, headers, and body were sent (if you
want tools to interpret the headers and body, those are elsewhere,
mostly in Tools). This 'metadata' consists of socket data,
transport characteristics, and the Request-Line. This object
also contains data regarding the configuration in effect for
the given URL, and the execution plan for generating a response.
"""
prev = None
"""
The previous Request object (if any). This should be None
unless we are processing an InternalRedirect."""
# Conversation/connection attributes
local = httputil.Host("127.0.0.1", 80)
"An httputil.Host(ip, port, hostname) object for the server socket."
remote = httputil.Host("127.0.0.1", 1111)
"An httputil.Host(ip, port, hostname) object for the client socket."
scheme = "http"
"""
The protocol used between client and server. In most cases,
this will be either 'http' or 'https'."""
server_protocol = "HTTP/1.1"
"""
The HTTP version for which the HTTP server is at least
conditionally compliant."""
base = ""
"""The (scheme://host) portion of the requested URL.
In some cases (e.g. when proxying via mod_rewrite), this may contain
path segments which cherrypy.url uses when constructing url's, but
which otherwise are ignored by CherryPy. Regardless, this value
MUST NOT end in a slash."""
# Request-Line attributes
request_line = ""
"""
The complete Request-Line received from the client. This is a
single string consisting of the request method, URI, and protocol
version (joined by spaces). Any final CRLF is removed."""
method = "GET"
"""
Indicates the HTTP method to be performed on the resource identified
by the Request-URI. Common methods include GET, HEAD, POST, PUT, and
DELETE. CherryPy allows any extension method; however, various HTTP
servers and gateways may restrict the set of allowable methods.
CherryPy applications SHOULD restrict the set (on a per-URI basis)."""
query_string = ""
"""
The query component of the Request-URI, a string of information to be
interpreted by the resource. The query portion of a URI follows the
path component, and is separated by a '?'. For example, the URI
'http://www.cherrypy.org/wiki?a=3&b=4' has the query component,
'a=3&b=4'."""
query_string_encoding = 'utf8'
"""
The encoding expected for query string arguments after % HEX HEX decoding.
If a query string is provided that cannot be decoded with this encoding,
404 is raised (since technically it's a different URI). If you want
arbitrary encodings to not error, set this to 'Latin-1'; you can then
encode back to bytes and re-decode to whatever encoding you like later.
"""
protocol = (1, 1)
"""The HTTP protocol version corresponding to the set
of features which should be allowed in the response. If BOTH
the client's request message AND the server's level of HTTP
compliance is HTTP/1.1, this attribute will be the tuple (1, 1).
If either is 1.0, this attribute will be the tuple (1, 0).
Lower HTTP protocol versions are not explicitly supported."""
params = {}
"""
A dict which combines query string (GET) and request entity (POST)
variables. This is populated in two stages: GET params are added
before the 'on_start_resource' hook, and POST params are added
between the 'before_request_body' and 'before_handler' hooks."""
# Message attributes
header_list = []
"""
A list of the HTTP request headers as (name, value) tuples.
In general, you should use request.headers (a dict) instead."""
headers = httputil.HeaderMap()
"""
A dict-like object containing the request headers. Keys are header
names (in Title-Case format); however, you may get and set them in
a case-insensitive manner. That is, headers['Content-Type'] and
headers['content-type'] refer to the same value. Values are header
values (decoded according to :rfc:`2047` if necessary). See also:
httputil.HeaderMap, httputil.HeaderElement."""
cookie = SimpleCookie()
"""See help(Cookie)."""
rfile = None
"""
If the request included an entity (body), it will be available
as a stream in this attribute. However, the rfile will normally
be read for you between the 'before_request_body' hook and the
'before_handler' hook, and the resulting string is placed into
either request.params or the request.body attribute.
You may disable the automatic consumption of the rfile by setting
request.process_request_body to False, either in config for the desired
path, or in an 'on_start_resource' or 'before_request_body' hook.
WARNING: In almost every case, you should not attempt to read from the
rfile stream after CherryPy's automatic mechanism has read it. If you
turn off the automatic parsing of rfile, you should read exactly the
number of bytes specified in request.headers['Content-Length'].
Ignoring either of these warnings may result in a hung request thread
or in corruption of the next (pipelined) request.
"""
process_request_body = True
"""
If True, the rfile (if any) is automatically read and parsed,
and the result placed into request.params or request.body."""
methods_with_bodies = ("POST", "PUT")
"""
A sequence of HTTP methods for which CherryPy will automatically
attempt to read a body from the rfile. If you are going to change
this property, modify it on the configuration (recommended)
or on the "hook point" `on_start_resource`.
"""
body = None
"""
If the request Content-Type is 'application/x-www-form-urlencoded'
or multipart, this will be None. Otherwise, this will be an instance
of :class:`RequestBody<cherrypy._cpreqbody.RequestBody>` (which you
can .read()); this value is set between the 'before_request_body' and
'before_handler' hooks (assuming that process_request_body is True)."""
# Dispatch attributes
dispatch = cherrypy.dispatch.Dispatcher()
"""
The object which looks up the 'page handler' callable and collects
config for the current request based on the path_info, other
request attributes, and the application architecture. The core
calls the dispatcher as early as possible, passing it a 'path_info'
argument.
The default dispatcher discovers the page handler by matching path_info
to a hierarchical arrangement of objects, starting at request.app.root.
See help(cherrypy.dispatch) for more information."""
script_name = ""
"""
The 'mount point' of the application which is handling this request.
This attribute MUST NOT end in a slash. If the script_name refers to
the root of the URI, it MUST be an empty string (not "/").
"""
path_info = "/"
"""
The 'relative path' portion of the Request-URI. This is relative
to the script_name ('mount point') of the application which is
handling this request."""
login = None
"""
When authentication is used during the request processing this is
set to 'False' if it failed and to the 'username' value if it succeeded.
The default 'None' implies that no authentication happened."""
# Note that cherrypy.url uses "if request.app:" to determine whether
# the call is during a real HTTP request or not. So leave this None.
app = None
"""The cherrypy.Application object which is handling this request."""
handler = None
"""
The function, method, or other callable which CherryPy will call to
produce the response. The discovery of the handler and the arguments
it will receive are determined by the request.dispatch object.
By default, the handler is discovered by walking a tree of objects
starting at request.app.root, and is then passed all HTTP params
(from the query string and POST body) as keyword arguments."""
toolmaps = {}
"""
A nested dict of all Toolboxes and Tools in effect for this request,
of the form: {Toolbox.namespace: {Tool.name: config dict}}."""
config = None
"""
A flat dict of all configuration entries which apply to the
current request. These entries are collected from global config,
application config (based on request.path_info), and from handler
config (exactly how is governed by the request.dispatch object in
effect for this request; by default, handler config can be attached
anywhere in the tree between request.app.root and the final handler,
and inherits downward)."""
is_index = None
"""
This will be True if the current request is mapped to an 'index'
resource handler (also, a 'default' handler if path_info ends with
a slash). The value may be used to automatically redirect the
user-agent to a 'more canonical' URL which either adds or removes
the trailing slash. See cherrypy.tools.trailing_slash."""
hooks = HookMap(hookpoints)
"""
A HookMap (dict-like object) of the form: {hookpoint: [hook, ...]}.
Each key is a str naming the hook point, and each value is a list
of hooks which will be called at that hook point during this request.
The list of hooks is generally populated as early as possible (mostly
from Tools specified in config), but may be extended at any time.
See also: _cprequest.Hook, _cprequest.HookMap, and cherrypy.tools."""
error_response = cherrypy.HTTPError(500).set_response
"""
The no-arg callable which will handle unexpected, untrapped errors
during request processing. This is not used for expected exceptions
(like NotFound, HTTPError, or HTTPRedirect) which are raised in
response to expected conditions (those should be customized either
via request.error_page or by overriding HTTPError.set_response).
By default, error_response uses HTTPError(500) to return a generic
error response to the user-agent."""
error_page = {}
"""
A dict of {error code: response filename or callable} pairs.
The error code must be an int representing a given HTTP error code,
or the string 'default', which will be used if no matching entry
is found for a given numeric code.
If a filename is provided, the file should contain a Python string-
formatting template, and can expect by default to receive format
values with the mapping keys %(status)s, %(message)s, %(traceback)s,
and %(version)s. The set of format mappings can be extended by
overriding HTTPError.set_response.
If a callable is provided, it will be called by default with keyword
arguments 'status', 'message', 'traceback', and 'version', as for a
string-formatting template. The callable must return a string or
iterable of strings which will be set to response.body. It may also
override headers or perform any other processing.
If no entry is given for an error code, and no 'default' entry exists,
a default template will be used.
"""
show_tracebacks = True
"""
If True, unexpected errors encountered during request processing will
include a traceback in the response body."""
show_mismatched_params = True
"""
If True, mismatched parameters encountered during PageHandler invocation
processing will be included in the response body."""
throws = (KeyboardInterrupt, SystemExit, cherrypy.InternalRedirect)
"""The sequence of exceptions which Request.run does not trap."""
throw_errors = False
"""
If True, Request.run will not trap any errors (except HTTPRedirect and
HTTPError, which are more properly called 'exceptions', not errors)."""
closed = False
"""True once the close method has been called, False otherwise."""
stage = None
"""
A string containing the stage reached in the request-handling process.
This is useful when debugging a live server with hung requests."""
namespaces = _cpconfig.NamespaceSet(
**{"hooks": hooks_namespace,
"request": request_namespace,
"response": response_namespace,
"error_page": error_page_namespace,
"tools": cherrypy.tools,
})
def __init__(self, local_host, remote_host, scheme="http",
server_protocol="HTTP/1.1"):
"""Populate a new Request object.
local_host should be an httputil.Host object with the server info.
remote_host should be an httputil.Host object with the client info.
scheme should be a string, either "http" or "https".
"""
self.local = local_host
self.remote = remote_host
self.scheme = scheme
self.server_protocol = server_protocol
self.closed = False
# Put a *copy* of the class error_page into self.
self.error_page = self.error_page.copy()
# Put a *copy* of the class namespaces into self.
self.namespaces = self.namespaces.copy()
self.stage = None
def close(self):
"""Run cleanup code. (Core)"""
if not self.closed:
self.closed = True
self.stage = 'on_end_request'
self.hooks.run('on_end_request')
self.stage = 'close'
def run(self, method, path, query_string, req_protocol, headers, rfile):
r"""Process the Request. (Core)
method, path, query_string, and req_protocol should be pulled directly
from the Request-Line (e.g. "GET /path?key=val HTTP/1.0").
path
This should be %XX-unquoted, but query_string should not be.
When using Python 2, they both MUST be byte strings,
not unicode strings.
When using Python 3, they both MUST be unicode strings,
not byte strings, and preferably not bytes \x00-\xFF
disguised as unicode.
headers
A list of (name, value) tuples.
rfile
A file-like object containing the HTTP request entity.
When run() is done, the returned object should have 3 attributes:
* status, e.g. "200 OK"
* header_list, a list of (name, value) tuples
* body, an iterable yielding strings
Consumer code (HTTP servers) should then access these response
attributes to build the outbound stream.
"""
response = cherrypy.serving.response
self.stage = 'run'
try:
self.error_response = cherrypy.HTTPError(500).set_response
self.method = method
path = path or "/"
self.query_string = query_string or ''
self.params = {}
# Compare request and server HTTP protocol versions, in case our
# server does not support the requested protocol. Limit our output
# to min(req, server). We want the following output:
# request server actual written supported response
# protocol protocol response protocol feature set
# a 1.0 1.0 1.0 1.0
# b 1.0 1.1 1.1 1.0
# c 1.1 1.0 1.0 1.0
# d 1.1 1.1 1.1 1.1
# Notice that, in (b), the response will be "HTTP/1.1" even though
# the client only understands 1.0. RFC 2616 10.5.6 says we should
# only return 505 if the _major_ version is different.
rp = int(req_protocol[5]), int(req_protocol[7])
sp = int(self.server_protocol[5]), int(self.server_protocol[7])
self.protocol = min(rp, sp)
response.headers.protocol = self.protocol
# Rebuild first line of the request (e.g. "GET /path HTTP/1.0").
url = path
if query_string:
url += '?' + query_string
self.request_line = '%s %s %s' % (method, url, req_protocol)
self.header_list = list(headers)
self.headers = httputil.HeaderMap()
self.rfile = rfile
self.body = None
self.cookie = SimpleCookie()
self.handler = None
# path_info should be the path from the
# app root (script_name) to the handler.
self.script_name = self.app.script_name
self.path_info = pi = path[len(self.script_name):]
self.stage = 'respond'
self.respond(pi)
except self.throws:
raise
except:
if self.throw_errors:
raise
else:
# Failure in setup, error handler or finalize. Bypass them.
# Can't use handle_error because we may not have hooks yet.
cherrypy.log(traceback=True, severity=40)
if self.show_tracebacks:
body = format_exc()
else:
body = ""
r = bare_error(body)
response.output_status, response.header_list, response.body = r
if self.method == "HEAD":
# HEAD requests MUST NOT return a message-body in the response.
response.body = []
try:
cherrypy.log.access()
except:
cherrypy.log.error(traceback=True)
if response.timed_out:
raise cherrypy.TimeoutError()
return response
# Uncomment for stage debugging
# stage = property(lambda self: self._stage, lambda self, v: print(v))
def respond(self, path_info):
"""Generate a response for the resource at self.path_info. (Core)"""
response = cherrypy.serving.response
try:
try:
try:
if self.app is None:
raise cherrypy.NotFound()
# Get the 'Host' header, so we can HTTPRedirect properly.
self.stage = 'process_headers'
self.process_headers()
# Make a copy of the class hooks
self.hooks = self.__class__.hooks.copy()
self.toolmaps = {}
self.stage = 'get_resource'
self.get_resource(path_info)
self.body = _cpreqbody.RequestBody(
self.rfile, self.headers, request_params=self.params)
self.namespaces(self.config)
self.stage = 'on_start_resource'
self.hooks.run('on_start_resource')
# Parse the querystring
self.stage = 'process_query_string'
self.process_query_string()
# Process the body
if self.process_request_body:
if self.method not in self.methods_with_bodies:
self.process_request_body = False
self.stage = 'before_request_body'
self.hooks.run('before_request_body')
if self.process_request_body:
self.body.process()
# Run the handler
self.stage = 'before_handler'
self.hooks.run('before_handler')
if self.handler:
self.stage = 'handler'
response.body = self.handler()
# Finalize
self.stage = 'before_finalize'
self.hooks.run('before_finalize')
response.finalize()
except (cherrypy.HTTPRedirect, cherrypy.HTTPError):
inst = sys.exc_info()[1]
inst.set_response()
self.stage = 'before_finalize (HTTPError)'
self.hooks.run('before_finalize')
response.finalize()
finally:
self.stage = 'on_end_resource'
self.hooks.run('on_end_resource')
except self.throws:
raise
except:
if self.throw_errors:
raise
self.handle_error()
def process_query_string(self):
"""Parse the query string into Python structures. (Core)"""
try:
p = httputil.parse_query_string(
self.query_string, encoding=self.query_string_encoding)
except UnicodeDecodeError:
raise cherrypy.HTTPError(
404, "The given query string could not be processed. Query "
"strings for this resource must be encoded with %r." %
self.query_string_encoding)
# Python 2 only: keyword arguments must be byte strings (type 'str').
if not py3k:
for key, value in p.items():
if isinstance(key, unicode):
del p[key]
p[key.encode(self.query_string_encoding)] = value
self.params.update(p)
def process_headers(self):
"""Parse HTTP header data into Python structures. (Core)"""
# Process the headers into self.headers
headers = self.headers
for name, value in self.header_list:
# Call title() now (and use dict.__method__(headers))
# so title doesn't have to be called twice.
name = name.title()
value = value.strip()
# Warning: if there is more than one header entry for cookies
# (AFAIK, only Konqueror does that), only the last one will
# remain in headers (but they will be correctly stored in
# request.cookie).
if "=?" in value:
dict.__setitem__(headers, name, httputil.decode_TEXT(value))
else:
dict.__setitem__(headers, name, value)
# Handle cookies differently because on Konqueror, multiple
# cookies come on different lines with the same key
if name == 'Cookie':
try:
self.cookie.load(value)
except CookieError:
msg = "Illegal cookie name %s" % value.split('=')[0]
raise cherrypy.HTTPError(400, msg)
if not dict.__contains__(headers, 'Host'):
# All Internet-based HTTP/1.1 servers MUST respond with a 400
# (Bad Request) status code to any HTTP/1.1 request message
# which lacks a Host header field.
if self.protocol >= (1, 1):
msg = "HTTP/1.1 requires a 'Host' request header."
raise cherrypy.HTTPError(400, msg)
host = dict.get(headers, 'Host')
if not host:
host = self.local.name or self.local.ip
self.base = "%s://%s" % (self.scheme, host)
def get_resource(self, path):
"""Call a dispatcher (which sets self.handler and .config). (Core)"""
# First, see if there is a custom dispatch at this URI. Custom
# dispatchers can only be specified in app.config, not in _cp_config
# (since custom dispatchers may not even have an app.root).
dispatch = self.app.find_config(
path, "request.dispatch", self.dispatch)
# dispatch() should set self.handler and self.config
dispatch(path)
def handle_error(self):
"""Handle the last unanticipated exception. (Core)"""
try:
self.hooks.run("before_error_response")
if self.error_response:
self.error_response()
self.hooks.run("after_error_response")
cherrypy.serving.response.finalize()
except cherrypy.HTTPRedirect:
inst = sys.exc_info()[1]
inst.set_response()
cherrypy.serving.response.finalize()
# ------------------------- Properties ------------------------- #
def _get_body_params(self):
warnings.warn(
"body_params is deprecated in CherryPy 3.2, will be removed in "
"CherryPy 3.3.",
DeprecationWarning
)
return self.body.params
body_params = property(_get_body_params,
doc="""
If the request Content-Type is 'application/x-www-form-urlencoded' or
multipart, this will be a dict of the params pulled from the entity
body; that is, it will be the portion of request.params that come
from the message body (sometimes called "POST params", although they
can be sent with various HTTP method verbs). This value is set between
the 'before_request_body' and 'before_handler' hooks (assuming that
process_request_body is True).
Deprecated in 3.2, will be removed for 3.3 in favor of
:attr:`request.body.params<cherrypy._cprequest.RequestBody.params>`.""")
class ResponseBody(object):
"""The body of the HTTP response (the response entity)."""
if py3k:
unicode_err = ("Page handlers MUST return bytes. Use tools.encode "
"if you wish to return unicode.")
def __get__(self, obj, objclass=None):
if obj is None:
# When calling on the class instead of an instance...
return self
else:
return obj._body
def __set__(self, obj, value):
# Convert the given value to an iterable object.
if py3k and isinstance(value, str):
raise ValueError(self.unicode_err)
if isinstance(value, basestring):
# strings get wrapped in a list because iterating over a single
# item list is much faster than iterating over every character
# in a long string.
if value:
value = [value]
else:
# [''] doesn't evaluate to False, so replace it with [].
value = []
elif py3k and isinstance(value, list):
# every item in a list must be bytes...
for i, item in enumerate(value):
if isinstance(item, str):
raise ValueError(self.unicode_err)
# Don't use isinstance here; io.IOBase, which is an ABC, takes
# 1000 times as long as, say, isinstance(value, str)
elif hasattr(value, 'read'):
value = file_generator(value)
elif value is None:
value = []
obj._body = value
class Response(object):
"""An HTTP Response, including status, headers, and body."""
status = ""
"""The HTTP Status-Code and Reason-Phrase."""
header_list = []
"""
A list of the HTTP response headers as (name, value) tuples.
In general, you should use response.headers (a dict) instead. This
attribute is generated from response.headers and is not valid until
after the finalize phase."""
headers = httputil.HeaderMap()
"""
A dict-like object containing the response headers. Keys are header
names (in Title-Case format); however, you may get and set them in
a case-insensitive manner. That is, headers['Content-Type'] and
headers['content-type'] refer to the same value. Values are header
values (decoded according to :rfc:`2047` if necessary).
.. seealso:: classes :class:`HeaderMap`, :class:`HeaderElement`
"""
cookie = SimpleCookie()
"""See help(Cookie)."""
body = ResponseBody()
"""The body (entity) of the HTTP response."""
time = None
"""The value of time.time() when created. Use in HTTP dates."""
timeout = 300
"""Seconds after which the response will be aborted."""
timed_out = False
"""
Flag to indicate the response should be aborted, because it has
exceeded its timeout."""
stream = False
"""If False, buffer the response body."""
def __init__(self):
self.status = None
self.header_list = None
self._body = []
self.time = time.time()
self.headers = httputil.HeaderMap()
# Since we know all our keys are titled strings, we can
# bypass HeaderMap.update and get a big speed boost.
dict.update(self.headers, {
"Content-Type": 'text/html',
"Server": "CherryPy/" + cherrypy.__version__,
"Date": httputil.HTTPDate(self.time),
})
self.cookie = SimpleCookie()
def collapse_body(self):
"""Collapse self.body to a single string; replace it and return it."""
if isinstance(self.body, basestring):
return self.body
newbody = []
for chunk in self.body:
if py3k and not isinstance(chunk, bytes):
raise TypeError("Chunk %s is not of type 'bytes'." %
repr(chunk))
newbody.append(chunk)
newbody = ntob('').join(newbody)
self.body = newbody
return newbody
def finalize(self):
"""Transform headers (and cookies) into self.header_list. (Core)"""
try:
code, reason, _ = httputil.valid_status(self.status)
except ValueError:
raise cherrypy.HTTPError(500, sys.exc_info()[1].args[0])
headers = self.headers
self.status = "%s %s" % (code, reason)
self.output_status = ntob(str(code), 'ascii') + \
ntob(" ") + headers.encode(reason)
if self.stream:
# The upshot: wsgiserver will chunk the response if
# you pop Content-Length (or set it explicitly to None).
# Note that lib.static sets C-L to the file's st_size.
if dict.get(headers, 'Content-Length') is None:
dict.pop(headers, 'Content-Length', None)
elif code < 200 or code in (204, 205, 304):
# "All 1xx (informational), 204 (no content),
# and 304 (not modified) responses MUST NOT
# include a message-body."
dict.pop(headers, 'Content-Length', None)
self.body = ntob("")
else:
# Responses which are not streamed should have a Content-Length,
# but allow user code to set Content-Length if desired.
if dict.get(headers, 'Content-Length') is None:
content = self.collapse_body()
dict.__setitem__(headers, 'Content-Length', len(content))
# Transform our header dict into a list of tuples.
self.header_list = h = headers.output()
cookie = self.cookie.output()
if cookie:
for line in cookie.split("\n"):
if line.endswith("\r"):
# Python 2.4 emits cookies joined by LF but 2.5+ by CRLF.
line = line[:-1]
name, value = line.split(": ", 1)
if isinstance(name, unicodestr):
name = name.encode("ISO-8859-1")
if isinstance(value, unicodestr):
value = headers.encode(value)
h.append((name, value))
def check_timeout(self):
"""If now > self.time + self.timeout, set self.timed_out.
This purposefully sets a flag, rather than raising an error,
so that a monitor thread can interrupt the Response thread.
"""
if time.time() > self.time + self.timeout:
self.timed_out = True

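The ResponseBody descriptor above normalizes whatever a handler assigns to response.body: strings are wrapped in a one-item list, None becomes [], and file-like objects are wrapped in file_generator(). A minimal sketch of that coercion under Python 2, assuming it is run from a Headphones checkout with the bundled lib/ directory available:

import sys
sys.path.insert(0, 'lib')  # Headphones vendors CherryPy under lib/
from cherrypy import _cprequest

resp = _cprequest.Response()
resp.body = "hello"          # a basestring is wrapped: ["hello"]
print(resp.collapse_body())  # joins the chunks back into "hello"
resp.body = None             # None is normalized to []
print(resp.body)             # []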
226
lib/cherrypy/_cpserver.py Normal file
View File

@@ -0,0 +1,226 @@
"""Manage HTTP servers with CherryPy."""
import warnings
import cherrypy
from cherrypy.lib import attributes
from cherrypy._cpcompat import basestring, py3k
# We import * because we want to export check_port
# et al as attributes of this module.
from cherrypy.process.servers import *
class Server(ServerAdapter):
"""An adapter for an HTTP server.
You can set attributes (like socket_host and socket_port)
on *this* object (which is probably cherrypy.server), and call
quickstart. For example::
cherrypy.server.socket_port = 80
cherrypy.quickstart()
"""
socket_port = 8080
"""The TCP port on which to listen for connections."""
_socket_host = '127.0.0.1'
def _get_socket_host(self):
return self._socket_host
def _set_socket_host(self, value):
if value == '':
raise ValueError("The empty string ('') is not an allowed value. "
"Use '0.0.0.0' instead to listen on all active "
"interfaces (INADDR_ANY).")
self._socket_host = value
socket_host = property(
_get_socket_host,
_set_socket_host,
doc="""The hostname or IP address on which to listen for connections.
Host values may be any IPv4 or IPv6 address, or any valid hostname.
The string 'localhost' is a synonym for '127.0.0.1' (or '::1', if
your hosts file prefers IPv6). The string '0.0.0.0' is a special
IPv4 entry meaning "any active interface" (INADDR_ANY), and '::'
is the similar IN6ADDR_ANY for IPv6. The empty string or None are
not allowed.""")
socket_file = None
"""If given, the name of the UNIX socket to use instead of TCP/IP.
When this option is not None, the `socket_host` and `socket_port` options
are ignored."""
socket_queue_size = 5
"""The 'backlog' argument to socket.listen(); specifies the maximum number
of queued connections (default 5)."""
socket_timeout = 10
"""The timeout in seconds for accepted connections (default 10)."""
accepted_queue_size = -1
"""The maximum number of requests which will be queued up before
the server refuses to accept it (default -1, meaning no limit)."""
accepted_queue_timeout = 10
"""The timeout in seconds for attempting to add a request to the
queue when the queue is full (default 10)."""
shutdown_timeout = 5
"""The time to wait for HTTP worker threads to clean up."""
protocol_version = 'HTTP/1.1'
"""The version string to write in the Status-Line of all HTTP responses,
for example, "HTTP/1.1" (the default). Depending on the HTTP server used,
this should also limit the supported features used in the response."""
thread_pool = 10
"""The number of worker threads to start up in the pool."""
thread_pool_max = -1
"""The maximum size of the worker-thread pool. Use -1 to indicate no limit.
"""
max_request_header_size = 500 * 1024
"""The maximum number of bytes allowable in the request headers.
If exceeded, the HTTP server should return "413 Request Entity Too Large".
"""
max_request_body_size = 100 * 1024 * 1024
"""The maximum number of bytes allowable in the request body. If exceeded,
the HTTP server should return "413 Request Entity Too Large"."""
instance = None
"""If not None, this should be an HTTP server instance (such as
CPWSGIServer) which cherrypy.server will control. Use this when you need
more control over object instantiation than is available in the various
configuration options."""
ssl_context = None
"""When using PyOpenSSL, an instance of SSL.Context."""
ssl_certificate = None
"""The filename of the SSL certificate to use."""
ssl_certificate_chain = None
"""When using PyOpenSSL, the certificate chain to pass to
Context.load_verify_locations."""
ssl_private_key = None
"""The filename of the private key to use with SSL."""
if py3k:
ssl_module = 'builtin'
"""The name of a registered SSL adaptation module to use with
the builtin WSGI server. Builtin options are: 'builtin' (to
use the SSL library built into recent versions of Python).
You may also register your own classes in the
wsgiserver.ssl_adapters dict."""
else:
ssl_module = 'pyopenssl'
"""The name of a registered SSL adaptation module to use with the
builtin WSGI server. Builtin options are 'builtin' (to use the SSL
library built into recent versions of Python) and 'pyopenssl' (to
use the PyOpenSSL project, which you must install separately). You
may also register your own classes in the wsgiserver.ssl_adapters
dict."""
statistics = False
"""Turns statistics-gathering on or off for aware HTTP servers."""
nodelay = True
"""If True (the default since 3.1), sets the TCP_NODELAY socket option."""
wsgi_version = (1, 0)
"""The WSGI version tuple to use with the builtin WSGI server.
The provided options are (1, 0) [which includes support for PEP 3333,
which declares it covers WSGI version 1.0.1 but still mandates the
wsgi.version (1, 0)] and ('u', 0), an experimental unicode version.
You may create and register your own experimental versions of the WSGI
protocol by adding custom classes to the wsgiserver.wsgi_gateways dict."""
def __init__(self):
self.bus = cherrypy.engine
self.httpserver = None
self.interrupt = None
self.running = False
def httpserver_from_self(self, httpserver=None):
"""Return a (httpserver, bind_addr) pair based on self attributes."""
if httpserver is None:
httpserver = self.instance
if httpserver is None:
from cherrypy import _cpwsgi_server
httpserver = _cpwsgi_server.CPWSGIServer(self)
if isinstance(httpserver, basestring):
# Is anyone using this? Can I add an arg?
httpserver = attributes(httpserver)(self)
return httpserver, self.bind_addr
def start(self):
"""Start the HTTP server."""
if not self.httpserver:
self.httpserver, self.bind_addr = self.httpserver_from_self()
ServerAdapter.start(self)
start.priority = 75
def _get_bind_addr(self):
if self.socket_file:
return self.socket_file
if self.socket_host is None and self.socket_port is None:
return None
return (self.socket_host, self.socket_port)
def _set_bind_addr(self, value):
if value is None:
self.socket_file = None
self.socket_host = None
self.socket_port = None
elif isinstance(value, basestring):
self.socket_file = value
self.socket_host = None
self.socket_port = None
else:
try:
self.socket_host, self.socket_port = value
self.socket_file = None
except ValueError:
raise ValueError("bind_addr must be a (host, port) tuple "
"(for TCP sockets) or a string (for Unix "
"domain sockets), not %r" % value)
bind_addr = property(
_get_bind_addr,
_set_bind_addr,
doc='A (host, port) tuple for TCP sockets or '
'a str for Unix domain sockets.')
def base(self):
"""Return the base (scheme://host[:port] or sock file) for this server.
"""
if self.socket_file:
return self.socket_file
host = self.socket_host
if host in ('0.0.0.0', '::'):
# 0.0.0.0 is INADDR_ANY and :: is IN6ADDR_ANY.
# Look up the host name, which should be the
# safest thing to spit out in a URL.
import socket
host = socket.gethostname()
port = self.socket_port
if self.ssl_certificate:
scheme = "https"
if port != 443:
host += ":%s" % port
else:
scheme = "http"
if port != 80:
host += ":%s" % port
return "%s://%s" % (scheme, host)

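A minimal sketch of the Server adapter's address handling above: socket_host rejects the empty string, and the bind_addr property accepts either a (host, port) tuple or a UNIX-socket path (the path below is a placeholder).

import cherrypy

cherrypy.server.socket_host = '0.0.0.0'   # '' would raise ValueError
cherrypy.server.socket_port = 8181
assert cherrypy.server.bind_addr == ('0.0.0.0', 8181)

cherrypy.server.bind_addr = '/tmp/headphones.sock'  # placeholder path
assert cherrypy.server.socket_file == '/tmp/headphones.sock'
assert cherrypy.server.socket_host is None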
View File

@@ -0,0 +1,241 @@
# This is a backport of Python-2.4's threading.local() implementation
"""Thread-local objects
(Note that this module provides a Python version of the
threading.local class. Depending on the version of Python you're
using, there may be a faster one available. You should always import
the local class from threading.)
Thread-local objects support the management of thread-local data.
If you have data that you want to be local to a thread, simply create
a thread-local object and use its attributes:
>>> mydata = local()
>>> mydata.number = 42
>>> mydata.number
42
You can also access the local-object's dictionary:
>>> mydata.__dict__
{'number': 42}
>>> mydata.__dict__.setdefault('widgets', [])
[]
>>> mydata.widgets
[]
What's important about thread-local objects is that their data are
local to a thread. If we access the data in a different thread:
>>> log = []
>>> def f():
... items = mydata.__dict__.items()
... items.sort()
... log.append(items)
... mydata.number = 11
... log.append(mydata.number)
>>> import threading
>>> thread = threading.Thread(target=f)
>>> thread.start()
>>> thread.join()
>>> log
[[], 11]
we get different data. Furthermore, changes made in the other thread
don't affect data seen in this thread:
>>> mydata.number
42
Of course, values you get from a local object, including a __dict__
attribute, are for whatever thread was current at the time the
attribute was read. For that reason, you generally don't want to save
these values across threads, as they apply only to the thread they
came from.
You can create custom local objects by subclassing the local class:
>>> class MyLocal(local):
... number = 2
... initialized = False
... def __init__(self, **kw):
... if self.initialized:
... raise SystemError('__init__ called too many times')
... self.initialized = True
... self.__dict__.update(kw)
... def squared(self):
... return self.number ** 2
This can be useful to support default values, methods and
initialization. Note that if you define an __init__ method, it will be
called each time the local object is used in a separate thread. This
is necessary to initialize each thread's dictionary.
Now if we create a local object:
>>> mydata = MyLocal(color='red')
Now we have a default number:
>>> mydata.number
2
an initial color:
>>> mydata.color
'red'
>>> del mydata.color
And a method that operates on the data:
>>> mydata.squared()
4
As before, we can access the data in a separate thread:
>>> log = []
>>> thread = threading.Thread(target=f)
>>> thread.start()
>>> thread.join()
>>> log
[[('color', 'red'), ('initialized', True)], 11]
without affecting this thread's data:
>>> mydata.number
2
>>> mydata.color
Traceback (most recent call last):
...
AttributeError: 'MyLocal' object has no attribute 'color'
Note that subclasses can define slots, but they are not thread
local. They are shared across threads:
>>> class MyLocal(local):
... __slots__ = 'number'
>>> mydata = MyLocal()
>>> mydata.number = 42
>>> mydata.color = 'red'
So, the separate thread:
>>> thread = threading.Thread(target=f)
>>> thread.start()
>>> thread.join()
affects what we see:
>>> mydata.number
11
>>> del mydata
"""
# Threading import is at end
class _localbase(object):
__slots__ = '_local__key', '_local__args', '_local__lock'
def __new__(cls, *args, **kw):
self = object.__new__(cls)
key = 'thread.local.' + str(id(self))
object.__setattr__(self, '_local__key', key)
object.__setattr__(self, '_local__args', (args, kw))
object.__setattr__(self, '_local__lock', RLock())
if (args or kw) and (cls.__init__ is object.__init__):
raise TypeError("Initialization arguments are not supported")
# We need to create the thread dict in anticipation of
# __init__ being called, to make sure we don't call it
# again ourselves.
dict = object.__getattribute__(self, '__dict__')
currentThread().__dict__[key] = dict
return self
def _patch(self):
key = object.__getattribute__(self, '_local__key')
d = currentThread().__dict__.get(key)
if d is None:
d = {}
currentThread().__dict__[key] = d
object.__setattr__(self, '__dict__', d)
# we have a new instance dict, so call our __init__ if we have
# one
cls = type(self)
if cls.__init__ is not object.__init__:
args, kw = object.__getattribute__(self, '_local__args')
cls.__init__(self, *args, **kw)
else:
object.__setattr__(self, '__dict__', d)
class local(_localbase):
def __getattribute__(self, name):
lock = object.__getattribute__(self, '_local__lock')
lock.acquire()
try:
_patch(self)
return object.__getattribute__(self, name)
finally:
lock.release()
def __setattr__(self, name, value):
lock = object.__getattribute__(self, '_local__lock')
lock.acquire()
try:
_patch(self)
return object.__setattr__(self, name, value)
finally:
lock.release()
def __delattr__(self, name):
lock = object.__getattribute__(self, '_local__lock')
lock.acquire()
try:
_patch(self)
return object.__delattr__(self, name)
finally:
lock.release()
def __del__():
threading_enumerate = enumerate
__getattribute__ = object.__getattribute__
def __del__(self):
key = __getattribute__(self, '_local__key')
try:
threads = list(threading_enumerate())
except:
# if enumerate fails, as it seems to do during
# shutdown, we'll skip cleanup under the assumption
# that there is nothing to clean up
return
for thread in threads:
try:
__dict__ = thread.__dict__
except AttributeError:
# Thread is dying, rest in peace
continue
if key in __dict__:
try:
del __dict__[key]
except KeyError:
pass # didn't have anything in this thread
return __del__
__del__ = __del__()
from threading import currentThread, enumerate, RLock

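A minimal sketch of the backported thread-local class above, mirroring the module docstring's doctests. The diff header omits this file's name; the import path below assumes CherryPy's usual cherrypy/_cpthreadinglocal.py location.

import threading
from cherrypy._cpthreadinglocal import local  # assumed module path

mydata = local()
mydata.number = 42

seen = []
def worker():
    # A fresh thread gets an empty __dict__, so 'number' is absent here.
    seen.append(getattr(mydata, 'number', None))
    mydata.number = 11

t = threading.Thread(target=worker)
t.start()
t.join()
print(seen)           # [None]
print(mydata.number)  # still 42 in the main thread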
529
lib/cherrypy/_cptools.py Normal file
View File

@@ -0,0 +1,529 @@
"""CherryPy tools. A "tool" is any helper, adapted to CP.
Tools are usually designed to be used in a variety of ways (although some
may only offer one if they choose):
Library calls
All tools are callables that can be used wherever needed.
The arguments are straightforward and should be detailed within the
docstring.
Function decorators
All tools, when called, may be used as decorators which configure
individual CherryPy page handlers (methods on the CherryPy tree).
That is, "@tools.anytool()" should "turn on" the tool via the
decorated function's _cp_config attribute.
CherryPy config
If a tool exposes a "_setup" callable, it will be called
once per Request (if the feature is "turned on" via config).
Tools may be implemented as any object with a namespace. The builtins
are generally either modules or instances of the tools.Tool class.
"""
import sys
import warnings
import cherrypy
def _getargs(func):
"""Return the names of all static arguments to the given function."""
# Use this instead of importing inspect for less mem overhead.
import types
if sys.version_info >= (3, 0):
if isinstance(func, types.MethodType):
func = func.__func__
co = func.__code__
else:
if isinstance(func, types.MethodType):
func = func.im_func
co = func.func_code
return co.co_varnames[:co.co_argcount]
_attr_error = (
"CherryPy Tools cannot be turned on directly. Instead, turn them "
"on via config, or use them as decorators on your page handlers."
)
class Tool(object):
"""A registered function for use with CherryPy request-processing hooks.
help(tool.callable) should give you more information about this Tool.
"""
namespace = "tools"
def __init__(self, point, callable, name=None, priority=50):
self._point = point
self.callable = callable
self._name = name
self._priority = priority
self.__doc__ = self.callable.__doc__
self._setargs()
def _get_on(self):
raise AttributeError(_attr_error)
def _set_on(self, value):
raise AttributeError(_attr_error)
on = property(_get_on, _set_on)
def _setargs(self):
"""Copy func parameter names to obj attributes."""
try:
for arg in _getargs(self.callable):
setattr(self, arg, None)
except (TypeError, AttributeError):
if hasattr(self.callable, "__call__"):
for arg in _getargs(self.callable.__call__):
setattr(self, arg, None)
# IronPython 1.0 raises NotImplementedError because
# inspect.getargspec tries to access Python bytecode
# in co_code attribute.
except NotImplementedError:
pass
# IronPython 1B1 may raise IndexError in some cases,
# but if we trap it here it doesn't prevent CP from working.
except IndexError:
pass
def _merged_args(self, d=None):
"""Return a dict of configuration entries for this Tool."""
if d:
conf = d.copy()
else:
conf = {}
tm = cherrypy.serving.request.toolmaps[self.namespace]
if self._name in tm:
conf.update(tm[self._name])
if "on" in conf:
del conf["on"]
return conf
def __call__(self, *args, **kwargs):
"""Compile-time decorator (turn on the tool in config).
For example::
@tools.proxy()
def whats_my_base(self):
return cherrypy.request.base
whats_my_base.exposed = True
"""
if args:
raise TypeError("The %r Tool does not accept positional "
"arguments; you must use keyword arguments."
% self._name)
def tool_decorator(f):
if not hasattr(f, "_cp_config"):
f._cp_config = {}
subspace = self.namespace + "." + self._name + "."
f._cp_config[subspace + "on"] = True
for k, v in kwargs.items():
f._cp_config[subspace + k] = v
return f
return tool_decorator
def _setup(self):
"""Hook this tool into cherrypy.request.
The standard CherryPy request object will automatically call this
method when the tool is "turned on" in config.
"""
conf = self._merged_args()
p = conf.pop("priority", None)
if p is None:
p = getattr(self.callable, "priority", self._priority)
cherrypy.serving.request.hooks.attach(self._point, self.callable,
priority=p, **conf)
class HandlerTool(Tool):
"""Tool which is called 'before main', that may skip normal handlers.
If the tool successfully handles the request (by setting response.body),
it should return True. This will cause CherryPy to skip any 'normal' page
handler. If the tool did not handle the request, it should return False
to tell CherryPy to continue on and call the normal page handler. If the
tool is declared AS a page handler (see the 'handler' method), returning
False will raise NotFound.
"""
def __init__(self, callable, name=None):
Tool.__init__(self, 'before_handler', callable, name)
def handler(self, *args, **kwargs):
"""Use this tool as a CherryPy page handler.
For example::
class Root:
nav = tools.staticdir.handler(section="/nav", dir="nav",
root=absDir)
"""
def handle_func(*a, **kw):
handled = self.callable(*args, **self._merged_args(kwargs))
if not handled:
raise cherrypy.NotFound()
return cherrypy.serving.response.body
handle_func.exposed = True
return handle_func
def _wrapper(self, **kwargs):
if self.callable(**kwargs):
cherrypy.serving.request.handler = None
def _setup(self):
"""Hook this tool into cherrypy.request.
The standard CherryPy request object will automatically call this
method when the tool is "turned on" in config.
"""
conf = self._merged_args()
p = conf.pop("priority", None)
if p is None:
p = getattr(self.callable, "priority", self._priority)
cherrypy.serving.request.hooks.attach(self._point, self._wrapper,
priority=p, **conf)
class HandlerWrapperTool(Tool):
"""Tool which wraps request.handler in a provided wrapper function.
The 'newhandler' arg must be a handler wrapper function that takes a
'next_handler' argument, plus ``*args`` and ``**kwargs``. Like all
page handler
functions, it must return an iterable for use as cherrypy.response.body.
For example, to allow your 'inner' page handlers to return dicts
which then get interpolated into a template::
def interpolator(next_handler, *args, **kwargs):
filename = cherrypy.request.config.get('template')
cherrypy.response.template = env.get_template(filename)
response_dict = next_handler(*args, **kwargs)
return cherrypy.response.template.render(**response_dict)
cherrypy.tools.jinja = HandlerWrapperTool(interpolator)
"""
def __init__(self, newhandler, point='before_handler', name=None,
priority=50):
self.newhandler = newhandler
self._point = point
self._name = name
self._priority = priority
def callable(self, *args, **kwargs):
innerfunc = cherrypy.serving.request.handler
def wrap(*args, **kwargs):
return self.newhandler(innerfunc, *args, **kwargs)
cherrypy.serving.request.handler = wrap
class ErrorTool(Tool):
"""Tool which is used to replace the default request.error_response."""
def __init__(self, callable, name=None):
Tool.__init__(self, None, callable, name)
def _wrapper(self):
self.callable(**self._merged_args())
def _setup(self):
"""Hook this tool into cherrypy.request.
The standard CherryPy request object will automatically call this
method when the tool is "turned on" in config.
"""
cherrypy.serving.request.error_response = self._wrapper
# Builtin tools #
from cherrypy.lib import cptools, encoding, auth, static, jsontools
from cherrypy.lib import sessions as _sessions, xmlrpcutil as _xmlrpc
from cherrypy.lib import caching as _caching
from cherrypy.lib import auth_basic, auth_digest
class SessionTool(Tool):
"""Session Tool for CherryPy.
sessions.locking
When 'implicit' (the default), the session will be locked for you,
just before running the page handler.
When 'early', the session will be locked before reading the request
body. This is off by default for safety reasons; for example,
a large upload would block the session, denying an AJAX
progress meter
(`issue <https://bitbucket.org/cherrypy/cherrypy/issue/630>`_).
When 'explicit' (or any other value), you need to call
cherrypy.session.acquire_lock() yourself before using
session data.
"""
def __init__(self):
# _sessions.init must be bound after headers are read
Tool.__init__(self, 'before_request_body', _sessions.init)
def _lock_session(self):
cherrypy.serving.session.acquire_lock()
def _setup(self):
"""Hook this tool into cherrypy.request.
The standard CherryPy request object will automatically call this
method when the tool is "turned on" in config.
"""
hooks = cherrypy.serving.request.hooks
conf = self._merged_args()
p = conf.pop("priority", None)
if p is None:
p = getattr(self.callable, "priority", self._priority)
hooks.attach(self._point, self.callable, priority=p, **conf)
locking = conf.pop('locking', 'implicit')
if locking == 'implicit':
hooks.attach('before_handler', self._lock_session)
elif locking == 'early':
# Lock before the request body (but after _sessions.init runs!)
hooks.attach('before_request_body', self._lock_session,
priority=60)
else:
# Don't lock
pass
hooks.attach('before_finalize', _sessions.save)
hooks.attach('on_end_request', _sessions.close)
def regenerate(self):
"""Drop the current session and make a new one (with a new id)."""
sess = cherrypy.serving.session
sess.regenerate()
# Grab cookie-relevant tool args
conf = dict([(k, v) for k, v in self._merged_args().items()
if k in ('path', 'path_header', 'name', 'timeout',
'domain', 'secure')])
_sessions.set_response_cookie(**conf)
class XMLRPCController(object):
"""A Controller (page handler collection) for XML-RPC.
To use it, have your controllers subclass this base class (it will
turn on the tool for you).
You can also supply the following optional config entries::
tools.xmlrpc.encoding: 'utf-8'
tools.xmlrpc.allow_none: 0
XML-RPC is a rather discontinuous layer over HTTP; dispatching to the
appropriate handler must first be performed according to the URL, and
then a second dispatch step must take place according to the RPC method
specified in the request body. It also allows a superfluous "/RPC2"
prefix in the URL, supplies its own handler args in the body, and
requires a 200 OK "Fault" response instead of 404 when the desired
method is not found.
Therefore, XML-RPC cannot be implemented for CherryPy via a Tool alone.
This Controller acts as the dispatch target for the first half (based
on the URL); it then reads the RPC method from the request body and
does its own second dispatch step based on that method. It also reads
body params, and returns a Fault on error.
The XMLRPCDispatcher strips any /RPC2 prefix; if you aren't using /RPC2
in your URLs, you can safely skip turning on the XMLRPCDispatcher.
Otherwise, you need to declare it in config::
request.dispatch: cherrypy.dispatch.XMLRPCDispatcher()
"""
# Note we're hard-coding this into the 'tools' namespace. We could do
# a huge amount of work to make it relocatable, but the only reason why
# would be if someone actually disabled the default_toolbox. Meh.
_cp_config = {'tools.xmlrpc.on': True}
def default(self, *vpath, **params):
rpcparams, rpcmethod = _xmlrpc.process_body()
subhandler = self
for attr in str(rpcmethod).split('.'):
subhandler = getattr(subhandler, attr, None)
if subhandler and getattr(subhandler, "exposed", False):
body = subhandler(*(vpath + rpcparams), **params)
else:
# https://bitbucket.org/cherrypy/cherrypy/issue/533
# if a method is not found, an xmlrpclib.Fault should be returned
# raising an exception here will do that; see
# cherrypy.lib.xmlrpcutil.on_error
raise Exception('method "%s" is not supported' % attr)
conf = cherrypy.serving.request.toolmaps['tools'].get("xmlrpc", {})
_xmlrpc.respond(body,
conf.get('encoding', 'utf-8'),
conf.get('allow_none', 0))
return cherrypy.serving.response.body
default.exposed = True
class SessionAuthTool(HandlerTool):
def _setargs(self):
for name in dir(cptools.SessionAuth):
if not name.startswith("__"):
setattr(self, name, None)
class CachingTool(Tool):
"""Caching Tool for CherryPy."""
def _wrapper(self, **kwargs):
request = cherrypy.serving.request
if _caching.get(**kwargs):
request.handler = None
else:
if request.cacheable:
# Note the devious technique here of adding hooks on the fly
request.hooks.attach('before_finalize', _caching.tee_output,
priority=90)
_wrapper.priority = 20
def _setup(self):
"""Hook caching into cherrypy.request."""
conf = self._merged_args()
p = conf.pop("priority", None)
cherrypy.serving.request.hooks.attach('before_handler', self._wrapper,
priority=p, **conf)
class Toolbox(object):
"""A collection of Tools.
This object also functions as a config namespace handler for itself.
Custom toolboxes should be added to each Application's toolboxes dict.
"""
def __init__(self, namespace):
self.namespace = namespace
def __setattr__(self, name, value):
# If the Tool._name is None, supply it from the attribute name.
if isinstance(value, Tool):
if value._name is None:
value._name = name
value.namespace = self.namespace
object.__setattr__(self, name, value)
def __enter__(self):
"""Populate request.toolmaps from tools specified in config."""
cherrypy.serving.request.toolmaps[self.namespace] = map = {}
def populate(k, v):
toolname, arg = k.split(".", 1)
bucket = map.setdefault(toolname, {})
bucket[arg] = v
return populate
def __exit__(self, exc_type, exc_val, exc_tb):
"""Run tool._setup() for each tool in our toolmap."""
map = cherrypy.serving.request.toolmaps.get(self.namespace)
if map:
for name, settings in map.items():
if settings.get("on", False):
tool = getattr(self, name)
tool._setup()
class DeprecatedTool(Tool):
_name = None
warnmsg = "This Tool is deprecated."
def __init__(self, point, warnmsg=None):
self.point = point
if warnmsg is not None:
self.warnmsg = warnmsg
def __call__(self, *args, **kwargs):
warnings.warn(self.warnmsg)
def tool_decorator(f):
return f
return tool_decorator
def _setup(self):
warnings.warn(self.warnmsg)
default_toolbox = _d = Toolbox("tools")
_d.session_auth = SessionAuthTool(cptools.session_auth)
_d.allow = Tool('on_start_resource', cptools.allow)
_d.proxy = Tool('before_request_body', cptools.proxy, priority=30)
_d.response_headers = Tool('on_start_resource', cptools.response_headers)
_d.log_tracebacks = Tool('before_error_response', cptools.log_traceback)
_d.log_headers = Tool('before_error_response', cptools.log_request_headers)
_d.log_hooks = Tool('on_end_request', cptools.log_hooks, priority=100)
_d.err_redirect = ErrorTool(cptools.redirect)
_d.etags = Tool('before_finalize', cptools.validate_etags, priority=75)
_d.decode = Tool('before_request_body', encoding.decode)
# the order of encoding, gzip, caching is important
_d.encode = Tool('before_handler', encoding.ResponseEncoder, priority=70)
_d.gzip = Tool('before_finalize', encoding.gzip, priority=80)
_d.staticdir = HandlerTool(static.staticdir)
_d.staticfile = HandlerTool(static.staticfile)
_d.sessions = SessionTool()
_d.xmlrpc = ErrorTool(_xmlrpc.on_error)
_d.caching = CachingTool('before_handler', _caching.get, 'caching')
_d.expires = Tool('before_finalize', _caching.expires)
_d.tidy = DeprecatedTool(
'before_finalize',
"The tidy tool has been removed from the standard distribution of "
"CherryPy. The most recent version can be found at "
"http://tools.cherrypy.org/browser.")
_d.nsgmls = DeprecatedTool(
'before_finalize',
"The nsgmls tool has been removed from the standard distribution of "
"CherryPy. The most recent version can be found at "
"http://tools.cherrypy.org/browser.")
_d.ignore_headers = Tool('before_request_body', cptools.ignore_headers)
_d.referer = Tool('before_request_body', cptools.referer)
_d.basic_auth = Tool('on_start_resource', auth.basic_auth)
_d.digest_auth = Tool('on_start_resource', auth.digest_auth)
_d.trailing_slash = Tool('before_handler', cptools.trailing_slash, priority=60)
_d.flatten = Tool('before_finalize', cptools.flatten)
_d.accept = Tool('on_start_resource', cptools.accept)
_d.redirect = Tool('on_start_resource', cptools.redirect)
_d.autovary = Tool('on_start_resource', cptools.autovary, priority=0)
_d.json_in = Tool('before_request_body', jsontools.json_in, priority=30)
_d.json_out = Tool('before_handler', jsontools.json_out, priority=30)
_d.auth_basic = Tool('before_handler', auth_basic.basic_auth, priority=1)
_d.auth_digest = Tool('before_handler', auth_digest.digest_auth, priority=1)
del _d, cptools, encoding, auth, static

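A minimal sketch of extending the default toolbox built above: assigning a Tool onto cherrypy.tools supplies its name via Toolbox.__setattr__, and calling the tool makes it a decorator that writes _cp_config. The hook function and header value here are illustrative.

import cherrypy

def add_powered_by():
    cherrypy.serving.response.headers['X-Powered-By'] = 'Headphones'

cherrypy.tools.powered_by = cherrypy.Tool('before_finalize', add_powered_by)

class Root(object):
    @cherrypy.tools.powered_by()  # turns the tool on for this handler
    def index(self):
        return "hello"
    index.exposed = True

Mounting and serving Root as usual would then attach add_powered_by at the 'before_finalize' hook point for every request to index().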
299
lib/cherrypy/_cptree.py Normal file
View File

@@ -0,0 +1,299 @@
"""CherryPy Application and Tree objects."""
import os
import cherrypy
from cherrypy._cpcompat import ntou, py3k
from cherrypy import _cpconfig, _cplogging, _cprequest, _cpwsgi, tools
from cherrypy.lib import httputil
class Application(object):
"""A CherryPy Application.
Servers and gateways should not instantiate Request objects directly.
Instead, they should ask an Application object for a request object.
An instance of this class may also be used as a WSGI callable
(WSGI application object) for itself.
"""
root = None
"""The top-most container of page handlers for this app. Handlers should
be arranged in a hierarchy of attributes, matching the expected URI
hierarchy; the default dispatcher then searches this hierarchy for a
matching handler. When using a dispatcher other than the default,
this value may be None."""
config = {}
"""A dict of {path: pathconf} pairs, where 'pathconf' is itself a dict
of {key: value} pairs."""
namespaces = _cpconfig.NamespaceSet()
toolboxes = {'tools': cherrypy.tools}
log = None
"""A LogManager instance. See _cplogging."""
wsgiapp = None
"""A CPWSGIApp instance. See _cpwsgi."""
request_class = _cprequest.Request
response_class = _cprequest.Response
relative_urls = False
def __init__(self, root, script_name="", config=None):
self.log = _cplogging.LogManager(id(self), cherrypy.log.logger_root)
self.root = root
self.script_name = script_name
self.wsgiapp = _cpwsgi.CPWSGIApp(self)
self.namespaces = self.namespaces.copy()
self.namespaces["log"] = lambda k, v: setattr(self.log, k, v)
self.namespaces["wsgi"] = self.wsgiapp.namespace_handler
self.config = self.__class__.config.copy()
if config:
self.merge(config)
def __repr__(self):
return "%s.%s(%r, %r)" % (self.__module__, self.__class__.__name__,
self.root, self.script_name)
script_name_doc = """The URI "mount point" for this app. A mount point
is that portion of the URI which is constant for all URIs that are
serviced by this application; it does not include scheme, host, or proxy
("virtual host") portions of the URI.
For example, if script_name is "/my/cool/app", then the URL
"http://www.example.com/my/cool/app/page1" might be handled by a
"page1" method on the root object.
The value of script_name MUST NOT end in a slash. If the script_name
refers to the root of the URI, it MUST be an empty string (not "/").
If script_name is explicitly set to None, then the script_name will be
provided for each call from request.wsgi_environ['SCRIPT_NAME'].
"""
def _get_script_name(self):
if self._script_name is not None:
return self._script_name
# A `_script_name` with a value of None signals that the script name
# should be pulled from WSGI environ.
return cherrypy.serving.request.wsgi_environ['SCRIPT_NAME'].rstrip("/")
def _set_script_name(self, value):
if value:
value = value.rstrip("/")
self._script_name = value
script_name = property(fget=_get_script_name, fset=_set_script_name,
doc=script_name_doc)
def merge(self, config):
"""Merge the given config into self.config."""
_cpconfig.merge(self.config, config)
# Handle namespaces specified in config.
self.namespaces(self.config.get("/", {}))
def find_config(self, path, key, default=None):
"""Return the most-specific value for key along path, or default."""
trail = path or "/"
while trail:
nodeconf = self.config.get(trail, {})
if key in nodeconf:
return nodeconf[key]
lastslash = trail.rfind("/")
if lastslash == -1:
break
elif lastslash == 0 and trail != "/":
trail = "/"
else:
trail = trail[:lastslash]
return default
def get_serving(self, local, remote, scheme, sproto):
"""Create and return a Request and Response object."""
req = self.request_class(local, remote, scheme, sproto)
req.app = self
for name, toolbox in self.toolboxes.items():
req.namespaces[name] = toolbox
resp = self.response_class()
cherrypy.serving.load(req, resp)
cherrypy.engine.publish('acquire_thread')
cherrypy.engine.publish('before_request')
return req, resp
def release_serving(self):
"""Release the current serving (request and response)."""
req = cherrypy.serving.request
cherrypy.engine.publish('after_request')
try:
req.close()
except:
cherrypy.log(traceback=True, severity=40)
cherrypy.serving.clear()
def __call__(self, environ, start_response):
return self.wsgiapp(environ, start_response)
class Tree(object):
"""A registry of CherryPy applications, mounted at diverse points.
An instance of this class may also be used as a WSGI callable
(WSGI application object), in which case it dispatches to all
mounted apps.
"""
apps = {}
"""
A dict of the form {script name: application}, where "script name"
is a string declaring the URI mount point (no trailing slash), and
"application" is an instance of cherrypy.Application (or an arbitrary
WSGI callable if you happen to be using a WSGI server)."""
def __init__(self):
self.apps = {}
def mount(self, root, script_name="", config=None):
"""Mount a new app from a root object, script_name, and config.
root
An instance of a "controller class" (a collection of page
handler methods) which represents the root of the application.
This may also be an Application instance, or None if using
a dispatcher other than the default.
script_name
A string containing the "mount point" of the application.
This should start with a slash, and be the path portion of the
URL at which to mount the given root. For example, if root.index()
will handle requests to "http://www.example.com:8080/dept/app1/",
then the script_name argument would be "/dept/app1".
It MUST NOT end in a slash. If the script_name refers to the
root of the URI, it MUST be an empty string (not "/").
config
A file or dict containing application config.
"""
if script_name is None:
raise TypeError(
"The 'script_name' argument may not be None. Application "
"objects may, however, possess a script_name of None (in "
"order to inpect the WSGI environ for SCRIPT_NAME upon each "
"request). You cannot mount such Applications on this Tree; "
"you must pass them to a WSGI server interface directly.")
# Next line both 1) strips trailing slash and 2) maps "/" -> "".
script_name = script_name.rstrip("/")
if isinstance(root, Application):
app = root
if script_name != "" and script_name != app.script_name:
raise ValueError(
"Cannot specify a different script name and pass an "
"Application instance to cherrypy.mount")
script_name = app.script_name
else:
app = Application(root, script_name)
# If mounted at "", add favicon.ico
if (script_name == "" and root is not None
and not hasattr(root, "favicon_ico")):
favicon = os.path.join(os.getcwd(), os.path.dirname(__file__),
"favicon.ico")
root.favicon_ico = tools.staticfile.handler(favicon)
if config:
app.merge(config)
self.apps[script_name] = app
return app
def graft(self, wsgi_callable, script_name=""):
"""Mount a wsgi callable at the given script_name."""
# Next line both 1) strips trailing slash and 2) maps "/" -> "".
script_name = script_name.rstrip("/")
self.apps[script_name] = wsgi_callable
def script_name(self, path=None):
"""The script_name of the app at the given path, or None.
If path is None, cherrypy.request is used.
"""
if path is None:
try:
request = cherrypy.serving.request
path = httputil.urljoin(request.script_name,
request.path_info)
except AttributeError:
return None
while True:
if path in self.apps:
return path
if path == "":
return None
# Move one node up the tree and try again.
path = path[:path.rfind("/")]
def __call__(self, environ, start_response):
# If you're calling this, then you're probably setting SCRIPT_NAME
# to '' (some WSGI servers always set SCRIPT_NAME to '').
# Try to look up the app using the full path.
env1x = environ
if environ.get(ntou('wsgi.version')) == (ntou('u'), 0):
env1x = _cpwsgi.downgrade_wsgi_ux_to_1x(environ)
path = httputil.urljoin(env1x.get('SCRIPT_NAME', ''),
env1x.get('PATH_INFO', ''))
sn = self.script_name(path or "/")
if sn is None:
start_response('404 Not Found', [])
return []
app = self.apps[sn]
# Correct the SCRIPT_NAME and PATH_INFO environ entries.
environ = environ.copy()
if not py3k:
if environ.get(ntou('wsgi.version')) == (ntou('u'), 0):
# Python 2/WSGI u.0: all strings MUST be of type unicode
enc = environ[ntou('wsgi.url_encoding')]
environ[ntou('SCRIPT_NAME')] = sn.decode(enc)
environ[ntou('PATH_INFO')] = path[
len(sn.rstrip("/")):].decode(enc)
else:
# Python 2/WSGI 1.x: all strings MUST be of type str
environ['SCRIPT_NAME'] = sn
environ['PATH_INFO'] = path[len(sn.rstrip("/")):]
else:
if environ.get(ntou('wsgi.version')) == (ntou('u'), 0):
# Python 3/WSGI u.0: all strings MUST be full unicode
environ['SCRIPT_NAME'] = sn
environ['PATH_INFO'] = path[len(sn.rstrip("/")):]
else:
# Python 3/WSGI 1.x: all strings MUST be ISO-8859-1 str
environ['SCRIPT_NAME'] = sn.encode(
'utf-8').decode('ISO-8859-1')
environ['PATH_INFO'] = path[
len(sn.rstrip("/")):].encode('utf-8').decode('ISO-8859-1')
return app(environ, start_response)

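A minimal sketch of the Tree machinery above: mount two controllers and let script_name() walk the path back up to the longest matching mount point. The class and path names are illustrative.

import cherrypy

class Root(object):
    def index(self):
        return "main app"
    index.exposed = True

class Api(object):
    def index(self):
        return "api app"
    index.exposed = True

cherrypy.tree.mount(Root(), "")     # site root; script_name is ""
cherrypy.tree.mount(Api(), "/api")  # must not end in a slash

assert cherrypy.tree.script_name("/api/users") == "/api"
assert cherrypy.tree.script_name("/other") == ""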
438
lib/cherrypy/_cpwsgi.py Normal file
View File

@@ -0,0 +1,438 @@
"""WSGI interface (see PEP 333 and 3333).
Note that WSGI environ keys and values are 'native strings'; that is,
whatever the type of "" is. For Python 2, that's a byte string; for Python 3,
it's a unicode string. But PEP 3333 says: "even if Python's str type is
actually Unicode "under the hood", the content of native strings must
still be translatable to bytes via the Latin-1 encoding!"
"""
import sys as _sys
import cherrypy as _cherrypy
from cherrypy._cpcompat import BytesIO, bytestr, ntob, ntou, py3k, unicodestr
from cherrypy import _cperror
from cherrypy.lib import httputil
from cherrypy.lib import is_closable_iterator
def downgrade_wsgi_ux_to_1x(environ):
"""Return a new environ dict for WSGI 1.x from the given WSGI u.x environ.
"""
env1x = {}
url_encoding = environ[ntou('wsgi.url_encoding')]
for k, v in list(environ.items()):
if k in [ntou('PATH_INFO'), ntou('SCRIPT_NAME'), ntou('QUERY_STRING')]:
v = v.encode(url_encoding)
elif isinstance(v, unicodestr):
v = v.encode('ISO-8859-1')
env1x[k.encode('ISO-8859-1')] = v
return env1x
class VirtualHost(object):
"""Select a different WSGI application based on the Host header.
This can be useful when running multiple sites within one CP server.
It allows several domains to point to different applications. For example::
root = Root()
RootApp = cherrypy.Application(root)
Domain2App = cherrypy.Application(root)
SecureApp = cherrypy.Application(Secure())
vhost = cherrypy._cpwsgi.VirtualHost(RootApp,
domains={'www.domain2.example': Domain2App,
'www.domain2.example:443': SecureApp,
})
cherrypy.tree.graft(vhost)
"""
default = None
"""Required. The default WSGI application."""
use_x_forwarded_host = True
"""If True (the default), any "X-Forwarded-Host"
request header will be used instead of the "Host" header. This
is commonly added by HTTP servers (such as Apache) when proxying."""
domains = {}
"""A dict of {host header value: application} pairs.
The incoming "Host" request header is looked up in this dict,
and, if a match is found, the corresponding WSGI application
will be called instead of the default. Note that you often need
separate entries for "example.com" and "www.example.com".
In addition, "Host" headers may contain the port number.
"""
def __init__(self, default, domains=None, use_x_forwarded_host=True):
self.default = default
self.domains = domains or {}
self.use_x_forwarded_host = use_x_forwarded_host
def __call__(self, environ, start_response):
domain = environ.get('HTTP_HOST', '')
if self.use_x_forwarded_host:
domain = environ.get("HTTP_X_FORWARDED_HOST", domain)
nextapp = self.domains.get(domain)
if nextapp is None:
nextapp = self.default
return nextapp(environ, start_response)
class InternalRedirector(object):
"""WSGI middleware that handles raised cherrypy.InternalRedirect."""
def __init__(self, nextapp, recursive=False):
self.nextapp = nextapp
self.recursive = recursive
def __call__(self, environ, start_response):
redirections = []
while True:
environ = environ.copy()
try:
return self.nextapp(environ, start_response)
except _cherrypy.InternalRedirect:
ir = _sys.exc_info()[1]
sn = environ.get('SCRIPT_NAME', '')
path = environ.get('PATH_INFO', '')
qs = environ.get('QUERY_STRING', '')
# Add the *previous* path_info + qs to redirections.
old_uri = sn + path
if qs:
old_uri += "?" + qs
redirections.append(old_uri)
if not self.recursive:
# Check to see if the new URI has been redirected to
# already
new_uri = sn + ir.path
if ir.query_string:
new_uri += "?" + ir.query_string
if new_uri in redirections:
ir.request.close()
raise RuntimeError("InternalRedirector visited the "
"same URL twice: %r" % new_uri)
# Munge the environment and try again.
environ['REQUEST_METHOD'] = "GET"
environ['PATH_INFO'] = ir.path
environ['QUERY_STRING'] = ir.query_string
environ['wsgi.input'] = BytesIO()
environ['CONTENT_LENGTH'] = "0"
environ['cherrypy.previous_request'] = ir.request
class ExceptionTrapper(object):
"""WSGI middleware that traps exceptions."""
def __init__(self, nextapp, throws=(KeyboardInterrupt, SystemExit)):
self.nextapp = nextapp
self.throws = throws
def __call__(self, environ, start_response):
return _TrappedResponse(
self.nextapp,
environ,
start_response,
self.throws
)
class _TrappedResponse(object):
response = iter([])
def __init__(self, nextapp, environ, start_response, throws):
self.nextapp = nextapp
self.environ = environ
self.start_response = start_response
self.throws = throws
self.started_response = False
self.response = self.trap(
self.nextapp, self.environ, self.start_response)
self.iter_response = iter(self.response)
def __iter__(self):
self.started_response = True
return self
if py3k:
def __next__(self):
return self.trap(next, self.iter_response)
else:
def next(self):
return self.trap(self.iter_response.next)
def close(self):
if hasattr(self.response, 'close'):
self.response.close()
def trap(self, func, *args, **kwargs):
try:
return func(*args, **kwargs)
except self.throws:
raise
except StopIteration:
raise
except:
tb = _cperror.format_exc()
#print('trapped (started %s):' % self.started_response, tb)
_cherrypy.log(tb, severity=40)
if not _cherrypy.request.show_tracebacks:
tb = ""
s, h, b = _cperror.bare_error(tb)
if py3k:
# What fun.
s = s.decode('ISO-8859-1')
h = [(k.decode('ISO-8859-1'), v.decode('ISO-8859-1'))
for k, v in h]
if self.started_response:
# Empty our iterable (so future calls raise StopIteration)
self.iter_response = iter([])
else:
self.iter_response = iter(b)
try:
self.start_response(s, h, _sys.exc_info())
except:
# "The application must not trap any exceptions raised by
# start_response, if it called start_response with exc_info.
# Instead, it should allow such exceptions to propagate
# back to the server or gateway."
# But we still log and call close() to clean up ourselves.
_cherrypy.log(traceback=True, severity=40)
raise
if self.started_response:
return ntob("").join(b)
else:
return b
# WSGI-to-CP Adapter #
class AppResponse(object):
"""WSGI response iterable for CherryPy applications."""
def __init__(self, environ, start_response, cpapp):
self.cpapp = cpapp
try:
if not py3k:
if environ.get(ntou('wsgi.version')) == (ntou('u'), 0):
environ = downgrade_wsgi_ux_to_1x(environ)
self.environ = environ
self.run()
r = _cherrypy.serving.response
outstatus = r.output_status
if not isinstance(outstatus, bytestr):
raise TypeError("response.output_status is not a byte string.")
outheaders = []
for k, v in r.header_list:
if not isinstance(k, bytestr):
raise TypeError(
"response.header_list key %r is not a byte string." %
k)
if not isinstance(v, bytestr):
raise TypeError(
"response.header_list value %r is not a byte string." %
v)
outheaders.append((k, v))
if py3k:
# According to PEP 3333, when using Python 3, the response
# status and headers must be bytes masquerading as unicode;
# that is, they must be of type "str" but are restricted to
# code points in the "latin-1" set.
outstatus = outstatus.decode('ISO-8859-1')
outheaders = [(k.decode('ISO-8859-1'), v.decode('ISO-8859-1'))
for k, v in outheaders]
self.iter_response = iter(r.body)
self.write = start_response(outstatus, outheaders)
except:
self.close()
raise
def __iter__(self):
return self
if py3k:
def __next__(self):
return next(self.iter_response)
else:
def next(self):
return self.iter_response.next()
def close(self):
"""Close and de-reference the current request and response. (Core)"""
streaming = _cherrypy.serving.response.stream
self.cpapp.release_serving()
# We avoid the expense of examining the iterator to see if it's
# closable unless we are streaming the response, as that's the
# only situation where we are going to have an iterator which
# may not have been exhausted yet.
if streaming and is_closable_iterator(self.iter_response):
iter_close = self.iter_response.close
try:
iter_close()
except Exception:
_cherrypy.log(traceback=True, severity=40)
def run(self):
"""Create a Request object using environ."""
env = self.environ.get
local = httputil.Host('', int(env('SERVER_PORT', 80)),
env('SERVER_NAME', ''))
remote = httputil.Host(env('REMOTE_ADDR', ''),
int(env('REMOTE_PORT', -1) or -1),
env('REMOTE_HOST', ''))
scheme = env('wsgi.url_scheme')
sproto = env('ACTUAL_SERVER_PROTOCOL', "HTTP/1.1")
request, resp = self.cpapp.get_serving(local, remote, scheme, sproto)
# LOGON_USER is served by IIS, and is the name of the
# user after having been mapped to a local account.
# Both IIS and Apache set REMOTE_USER, when possible.
request.login = env('LOGON_USER') or env('REMOTE_USER') or None
request.multithread = self.environ['wsgi.multithread']
request.multiprocess = self.environ['wsgi.multiprocess']
request.wsgi_environ = self.environ
request.prev = env('cherrypy.previous_request', None)
meth = self.environ['REQUEST_METHOD']
path = httputil.urljoin(self.environ.get('SCRIPT_NAME', ''),
self.environ.get('PATH_INFO', ''))
qs = self.environ.get('QUERY_STRING', '')
if py3k:
# This isn't perfect; if the given PATH_INFO is in the
# wrong encoding, it may fail to match the appropriate config
# section URI. But meh.
old_enc = self.environ.get('wsgi.url_encoding', 'ISO-8859-1')
new_enc = self.cpapp.find_config(self.environ.get('PATH_INFO', ''),
"request.uri_encoding", 'utf-8')
if new_enc.lower() != old_enc.lower():
# Even though the path and qs are unicode, the WSGI server
# is required by PEP 3333 to coerce them to ISO-8859-1
# masquerading as unicode. So we have to encode back to
# bytes and then decode again using the "correct" encoding.
try:
u_path = path.encode(old_enc).decode(new_enc)
u_qs = qs.encode(old_enc).decode(new_enc)
except (UnicodeEncodeError, UnicodeDecodeError):
# Just pass them through without transcoding and hope.
pass
else:
# Only set transcoded values if they both succeed.
path = u_path
qs = u_qs
rproto = self.environ.get('SERVER_PROTOCOL')
headers = self.translate_headers(self.environ)
rfile = self.environ['wsgi.input']
request.run(meth, path, qs, rproto, headers, rfile)
headerNames = {'HTTP_CGI_AUTHORIZATION': 'Authorization',
'CONTENT_LENGTH': 'Content-Length',
'CONTENT_TYPE': 'Content-Type',
'REMOTE_HOST': 'Remote-Host',
'REMOTE_ADDR': 'Remote-Addr',
}
def translate_headers(self, environ):
"""Translate CGI-environ header names to HTTP header names."""
for cgiName in environ:
# We assume all incoming header keys are uppercase already.
if cgiName in self.headerNames:
yield self.headerNames[cgiName], environ[cgiName]
elif cgiName[:5] == "HTTP_":
# Hackish attempt at recovering original header names.
translatedHeader = cgiName[5:].replace("_", "-")
yield translatedHeader, environ[cgiName]
class CPWSGIApp(object):
"""A WSGI application object for a CherryPy Application."""
pipeline = [('ExceptionTrapper', ExceptionTrapper),
('InternalRedirector', InternalRedirector),
]
"""A list of (name, wsgiapp) pairs. Each 'wsgiapp' MUST be a
constructor that takes an initial, positional 'nextapp' argument,
plus optional keyword arguments, and returns a WSGI application
(that takes environ and start_response arguments). The 'name' can
be any you choose, and will correspond to keys in self.config."""
head = None
"""Rather than nest all apps in the pipeline on each call, it's only
done the first time, and the result is memoized into self.head. Set
this to None again if you change self.pipeline after calling self."""
config = {}
"""A dict whose keys match names listed in the pipeline. Each
value is a further dict which will be passed to the corresponding
named WSGI callable (from the pipeline) as keyword arguments."""
response_class = AppResponse
"""The class to instantiate and return as the next app in the WSGI chain.
"""
def __init__(self, cpapp, pipeline=None):
self.cpapp = cpapp
self.pipeline = self.pipeline[:]
if pipeline:
self.pipeline.extend(pipeline)
self.config = self.config.copy()
def tail(self, environ, start_response):
"""WSGI application callable for the actual CherryPy application.
You probably shouldn't call this; call self.__call__ instead,
so that any WSGI middleware in self.pipeline can run first.
"""
return self.response_class(environ, start_response, self.cpapp)
def __call__(self, environ, start_response):
head = self.head
if head is None:
# Create and nest the WSGI apps in our pipeline (in reverse order).
# Then memoize the result in self.head.
head = self.tail
for name, callable in self.pipeline[::-1]:
conf = self.config.get(name, {})
head = callable(head, **conf)
self.head = head
return head(environ, start_response)
def namespace_handler(self, k, v):
"""Config handler for the 'wsgi' namespace."""
if k == "pipeline":
# Note this allows multiple 'wsgi.pipeline' config entries
# (but each entry will be processed in a 'random' order).
# It should also allow developers to set default middleware
# in code (passed to self.__init__) that deployers can add to
# (but not remove) via config.
self.pipeline.extend(v)
elif k == "response_class":
self.response_class = v
else:
name, arg = k.split(".", 1)
bucket = self.config.setdefault(name, {})
bucket[arg] = v

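A minimal sketch of the per-application pipeline contract documented above: each entry is a (name, constructor) pair, and the constructor receives the next app plus keyword arguments drawn from wsgiapp.config[name]. The middleware below is a made-up pass-through.

import cherrypy

class Passthrough(object):
    """Illustrative middleware: takes nextapp plus config kwargs."""
    def __init__(self, nextapp, tag='demo'):
        self.nextapp = nextapp
        self.tag = tag
    def __call__(self, environ, start_response):
        environ['x.passthrough.tag'] = self.tag  # stash a marker
        return self.nextapp(environ, start_response)

class Root(object):
    def index(self):
        return "hello"
    index.exposed = True

app = cherrypy.tree.mount(Root(), "")
# Extend before the first request; self.head memoizes the nested chain.
app.wsgiapp.pipeline.append(('passthrough', Passthrough))
app.wsgiapp.config['passthrough'] = {'tag': 'headphones'}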
View File

@@ -0,0 +1,70 @@
"""WSGI server interface (see PEP 333). This adds some CP-specific bits to
the framework-agnostic wsgiserver package.
"""
import sys
import cherrypy
from cherrypy import wsgiserver
class CPWSGIServer(wsgiserver.CherryPyWSGIServer):
"""Wrapper for wsgiserver.CherryPyWSGIServer.
wsgiserver has been designed to not reference CherryPy in any way,
so that it can be used in other frameworks and applications. Therefore,
we wrap it here, so we can set our own mount points from cherrypy.tree
and apply some attributes from config -> cherrypy.server -> wsgiserver.
"""
def __init__(self, server_adapter=cherrypy.server):
self.server_adapter = server_adapter
self.max_request_header_size = (
self.server_adapter.max_request_header_size or 0
)
self.max_request_body_size = (
self.server_adapter.max_request_body_size or 0
)
server_name = (self.server_adapter.socket_host or
self.server_adapter.socket_file or
None)
self.wsgi_version = self.server_adapter.wsgi_version
s = wsgiserver.CherryPyWSGIServer
s.__init__(self, server_adapter.bind_addr, cherrypy.tree,
self.server_adapter.thread_pool,
server_name,
max=self.server_adapter.thread_pool_max,
request_queue_size=self.server_adapter.socket_queue_size,
timeout=self.server_adapter.socket_timeout,
shutdown_timeout=self.server_adapter.shutdown_timeout,
accepted_queue_size=self.server_adapter.accepted_queue_size,
accepted_queue_timeout=self.server_adapter.accepted_queue_timeout,
)
self.protocol = self.server_adapter.protocol_version
self.nodelay = self.server_adapter.nodelay
if sys.version_info >= (3, 0):
ssl_module = self.server_adapter.ssl_module or 'builtin'
else:
ssl_module = self.server_adapter.ssl_module or 'pyopenssl'
if self.server_adapter.ssl_context:
adapter_class = wsgiserver.get_ssl_adapter_class(ssl_module)
self.ssl_adapter = adapter_class(
self.server_adapter.ssl_certificate,
self.server_adapter.ssl_private_key,
self.server_adapter.ssl_certificate_chain)
self.ssl_adapter.context = self.server_adapter.ssl_context
elif self.server_adapter.ssl_certificate:
adapter_class = wsgiserver.get_ssl_adapter_class(ssl_module)
self.ssl_adapter = adapter_class(
self.server_adapter.ssl_certificate,
self.server_adapter.ssl_private_key,
self.server_adapter.ssl_certificate_chain)
self.stats['Enabled'] = getattr(
self.server_adapter, 'statistics', False)
def error_log(self, msg="", level=20, traceback=False):
cherrypy.engine.log(msg, level, traceback)

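A minimal sketch of driving the wrapped server above through the usual front door: cherrypy.quickstart builds a CPWSGIServer from cherrypy.server's attributes, including the ssl_module selection logic shown. The certificate paths are placeholders.

import cherrypy

class Root(object):
    def index(self):
        return "Hello from the bundled CherryPy"
    index.exposed = True

cherrypy.server.socket_host = '127.0.0.1'
cherrypy.server.socket_port = 8181
# Uncomment to exercise the ssl_module / ssl_adapter paths above:
# cherrypy.server.ssl_certificate = '/path/to/cert.pem'  # placeholder
# cherrypy.server.ssl_private_key = '/path/to/key.pem'   # placeholder

cherrypy.quickstart(Root())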
Some files were not shown because too many files have changed in this diff.