mirror of
https://github.com/reddit-archive/reddit.git
synced 2026-04-27 03:00:12 -04:00
1445 lines
44 KiB
Python
1445 lines
44 KiB
Python
# The contents of this file are subject to the Common Public Attribution
|
|
# License Version 1.0. (the "License"); you may not use this file except in
|
|
# compliance with the License. You may obtain a copy of the License at
|
|
# http://code.reddit.com/LICENSE. The License is based on the Mozilla Public
|
|
# License Version 1.1, but Sections 14 and 15 have been added to cover use of
|
|
# software over a computer network and provide for limited attribution for the
|
|
# Original Developer. In addition, Exhibit A has been modified to be consistent
|
|
# with Exhibit B.
|
|
#
|
|
# Software distributed under the License is distributed on an "AS IS" basis,
|
|
# WITHOUT WARRANTY OF ANY KIND, either express or implied. See the License for
|
|
# the specific language governing rights and limitations under the License.
|
|
#
|
|
# The Original Code is reddit.
|
|
#
|
|
# The Original Developer is the Initial Developer. The Initial Developer of
|
|
# the Original Code is reddit Inc.
|
|
#
|
|
# All portions of the code written by reddit are Copyright (c) 2006-2012 reddit
|
|
# Inc. All Rights Reserved.
|
|
###############################################################################
|
|
|
|
import os
|
|
import base64
|
|
import traceback
|
|
|
|
from urllib import unquote_plus
|
|
from urllib2 import urlopen
|
|
from urlparse import urlparse, urlunparse
|
|
import signal
|
|
from copy import deepcopy
|
|
import cPickle as pickle
|
|
import re, math, random
|
|
from decimal import Decimal
|
|
|
|
from BeautifulSoup import BeautifulSoup, SoupStrainer
|
|
|
|
from time import sleep
|
|
from datetime import datetime, timedelta
|
|
from functools import wraps, partial, WRAPPER_ASSIGNMENTS
|
|
from pylons import g
|
|
from pylons.i18n import ungettext, _
|
|
from r2.lib.filters import _force_unicode, _force_utf8
|
|
from mako.filters import url_escape
|
|
from r2.lib.contrib import ipaddress
|
|
from r2.lib.require import require, require_split
|
|
import snudown
|
|
|
|
from r2.lib.utils._utils import *
|
|
|
|
# The built-in container types treated as "collections"; kept as a tuple of
# types so it can be handed directly to isinstance().
iters = (
    list,
    tuple,
    set,
)
|
|
|
|
def randstr(len, reallyrandom = False):
    """Return a random string of ``len`` characters.

    By default the characters are lowercase letters and digits (base-36
    compatible).  If ``reallyrandom`` is true, uppercase letters and
    punctuation are added to the alphabet and the characters are drawn
    from the operating system's cryptographic RNG (random.SystemRandom),
    making the result suitable for use as a salt.

    Note: the punctuation set contains the backslash twice, so it is
    slightly more likely than the other characters.
    """
    alphabet = 'abcdefghijklmnopqrstuvwxyz0123456789'
    if reallyrandom:
        # '\\(' is the same two characters the original '\(' produced,
        # spelled without an invalid escape sequence.
        alphabet += 'ABCDEFGHIJKLMNOPQRSTUVWXYZ!#$%&\\()*+,-./:;<=>?@[\\]^_{|}~'
        # Salts must not come from the seedable, predictable Mersenne
        # Twister behind the module-level random functions.
        rng = random.SystemRandom()
    else:
        rng = random
    return ''.join(rng.choice(alphabet)
                   for i in range(len))
|
|
|
|
class Storage(dict):
    """
    A dictionary whose items can also be used as attributes: `obj.foo`
    works in addition to `obj['foo']` for reading, assigning and deleting.
    A missing key raises AttributeError on attribute access.

        >>> o = storage(a=1)
        >>> o.a
        1
        >>> o['a']
        1
        >>> o.a = 2
        >>> o['a']
        2
        >>> del o.a
        >>> o.a
        Traceback (most recent call last):
            ...
        AttributeError: 'a'

    """
    def __getattr__(self, key):
        # Only reached when normal attribute lookup fails, so genuine dict
        # methods (keys, get, ...) take precedence over same-named items.
        try:
            value = self[key]
        except KeyError as missing:
            raise AttributeError(missing)
        return value

    def __setattr__(self, key, value):
        # Every attribute assignment is stored as an item.
        self[key] = value

    def __delattr__(self, key):
        try:
            del self[key]
        except KeyError as missing:
            raise AttributeError(missing)

    def __repr__(self):
        return '<Storage %s>' % dict.__repr__(self)


# lowercase alias for Storage
storage = Storage
|
|
|
|
def storify(mapping, *requireds, **defaults):
    """
    Creates a `storage` object from dictionary `mapping`, raising `KeyError` if
    `mapping` doesn't have all of the keys in `requireds` and using the default
    values for keys found in `defaults`.

    For example, `storify({'a':1, 'c':3}, b=2, c=0)` will return the equivalent of
    `storage({'a':1, 'b':2, 'c':3})`.

    If a `storify` value is a list (e.g. multiple values in a form submission),
    `storify` returns the last element of the list, unless the key appears in
    `defaults` as a list. Thus:

        >>> storify({'a':[1, 2]}).a
        2
        >>> storify({'a':[1, 2]}, a=[]).a
        [1, 2]
        >>> storify({'a':1}, a=[]).a
        [1]
        >>> storify({}, a=[]).a
        []

    Similarly, if the value has a `value` attribute, `storify` will return _its_
    value, unless the key appears in `defaults` as a dictionary.

        >>> storify({'a':storage(value=1)}).a
        1
        >>> storify({'a':storage(value=1)}, a={}).a
        <Storage {'value': 1}>
        >>> storify({}, a={}).a
        {}

    A default of `()` wraps a non-tuple value in a 1-tuple.
    """
    def getvalue(x):
        # unwrap objects exposing a `value` attribute (e.g. form fields)
        if hasattr(x, 'value'):
            return x.value
        else:
            return x

    stor = Storage()
    # required keys are looked up first so a missing one raises KeyError
    for key in requireds + tuple(mapping.keys()):
        value = mapping[key]
        if isinstance(value, list):
            if isinstance(defaults.get(key), list):
                value = [getvalue(x) for x in value]
            else:
                value = value[-1]
        if not isinstance(defaults.get(key), dict):
            value = getvalue(value)
        if isinstance(defaults.get(key), list) and not isinstance(value, list):
            value = [value]
        setattr(stor, key, value)

    for (key, value) in defaults.iteritems():
        result = value
        # Test membership rather than hasattr(): hasattr() is also true for
        # dict method names such as 'get' or 'keys', which would then make
        # stor[key] raise KeyError instead of falling back to the default.
        if key in stor:
            result = stor[key]
        if value == () and not isinstance(result, tuple):
            result = (result,)
        setattr(stor, key, result)

    return stor
|
|
|
|
class Enum(Storage):
    """An enumeration built from positional names.

    Enum('a', 'b') maps 'a' -> 0 and 'b' -> 1; members are readable as
    attributes or items.  The names, in order, are kept as the tuple
    ``name``.  Because Storage turns attribute assignment into item
    assignment, ``name`` is also a key of the mapping.
    """
    def __init__(self, *a):
        labels = tuple(a)
        self.name = labels
        Storage.__init__(self, [(label, position)
                                for position, label in enumerate(labels)])

    def __contains__(self, item):
        # ints are checked against the enum's values, anything else
        # against its keys
        if not isinstance(item, int):
            return Storage.__contains__(self, item)
        return item in self.values()
|
|
|
|
|
|
class Results():
    """Wrap a result proxy, turning its rows into objects via ``build_fn``.

    ``sa_ResultProxy`` is presumably a SQLAlchemy ResultProxy (per the
    parameter name); only ``rowcount``, ``fetchall``, ``fetchmany`` and
    ``fetchone`` are used.  With ``do_batch`` true, ``build_fn`` receives
    a whole list of rows at once; otherwise it is called once per row.
    """
    def __init__(self, sa_ResultProxy, build_fn, do_batch=False):
        self.rp = sa_ResultProxy
        self.fn = build_fn
        self.do_batch = do_batch

    @property
    def rowcount(self):
        return self.rp.rowcount

    def _fetch(self, res):
        if not self.do_batch:
            return [self.fn(row) for row in res]
        return self.fn(res)

    def fetchall(self):
        return self._fetch(self.rp.fetchall())

    def fetchmany(self, n):
        """Return up to ``n`` built rows; raise StopIteration when none
        are left."""
        built = self._fetch(self.rp.fetchmany(n))
        if not built:
            raise StopIteration
        return built

    def fetchone(self):
        """Return the next built row; raise StopIteration when none is
        left."""
        row = self.rp.fetchone()
        if not row:
            raise StopIteration
        if not self.do_batch:
            return self.fn(row)
        # batch builders take a sequence; wrap the single row and unwrap
        return self.fn(tup(row))[0]
|
|
|
|
def strip_www(domain):
    """Remove a leading "www." from ``domain``.

    The prefix is only removed when the domain contains at least two dots,
    so a bare "www.com" is returned unchanged.
    """
    prefix = "www."
    if not domain.startswith(prefix) or domain.count('.') < 2:
        return domain
    return domain[len(prefix):]
|
|
|
|
def is_subdomain(subdomain, base):
    """Return True if ``subdomain`` equals ``base`` or ends with
    ``'.' + base``.

    A None ``subdomain`` is never a subdomain (unless ``base`` is None too).
    """
    if subdomain == base:
        return True
    if subdomain is None:
        return False
    return subdomain.endswith('.' + base)
|
|
|
|
r_base_url = re.compile(r"(?i)(?:.+?://)?(?:www[\d]*\.)?([^#]*[^#/])/?")
def base_url(url):
    """Reduce ``url`` to a comparable base form.

    Drops the scheme, a leading "www." / "wwwN." label, anything from '#'
    onwards and trailing slashes.  Returns ``url`` unchanged if the pattern
    does not match.
    """
    matches = r_base_url.findall(url)
    if matches and matches[0]:
        return matches[0]
    return url
|
|
|
|
r_domain = re.compile(r"(?i)(?:.+?://)?(?:www[\d]*\.)?([^/:#?]*)")
def domain(s):
    """
    Takes a URL and returns its host part, lowercased, with any scheme,
    leading "www." / "wwwN." label, port, path, query or fragment removed.
    """
    matches = r_domain.findall(s)
    if matches and matches[0]:
        host = matches[0]
    else:
        host = s
    return host.lower()
|
|
|
|
r_path_component = re.compile(".*?/(.*)")
def path_component(s):
    """
    Takes a url such as http://www.foo.com/i/like/cheese and returns the
    path without its leading slash: i/like/cheese

    The URL is first normalized with base_url().  If that result contains
    no '/', the original ``s`` is returned unchanged.
    """
    matches = r_path_component.findall(base_url(s))
    if matches and matches[0]:
        return matches[0]
    return s
|
|
|
|
def get_title(url):
    """Fetch ``url`` and return the text of its <title> element.

    Only http:// and https:// URLs are fetched (15 second timeout).  The
    first 1 KiB of the response is searched; if no title is found, another
    2 KiB is read and the combined data is searched again.  The title is
    returned as produced by extract_title (utf-8 encoded and stripped).

    Best effort: returns None for unsupported URLs and for any error while
    fetching or parsing.
    """
    if not url or not url.startswith(('http://', 'https://')):
        return None

    opener = None
    try:
        opener = urlopen(url, timeout=15)

        # Attempt to find the title in the first 1kb
        data = opener.read(1024)
        title = extract_title(data)

        # Title not found in the first kb, try searching an additional 2kb
        if not title:
            data += opener.read(2048)
            title = extract_title(data)

        return title
    except Exception:
        # Network errors, timeouts, malformed HTML, ... all mean "no title".
        # (Unlike a bare except, this lets KeyboardInterrupt/SystemExit out.)
        return None
    finally:
        # Always release the connection, even when reading/parsing failed.
        if opener is not None:
            try:
                opener.close()
            except Exception:
                pass
|
|
|
|
def extract_title(data):
    """Try to extract the text of the <html><head><title> element from a
    string of HTML.

    Returns the title's string content, utf-8 encoded and stripped, or
    None when the document has no such element (or it has no plain string
    content).
    """
    bs = BeautifulSoup(data, convertEntities=BeautifulSoup.HTML_ENTITIES)
    if not bs:
        return

    # .html / .head are None when the (possibly truncated) document lacks
    # those elements; guard so we return None instead of raising
    # AttributeError on None.
    html = bs.html
    head = html.head if html is not None else None
    title_bs = head.title if head is not None else None

    if not title_bs or not title_bs.string:
        return

    return title_bs.string.encode('utf-8').strip()
|
|
|
|
valid_schemes = ('http', 'https', 'ftp', 'mailto')
valid_dns = re.compile(r'\A[-a-zA-Z0-9]+\Z')
def sanitize_url(url, require_scheme = False):
    """Validate ``url`` and return it (stripped of surrounding whitespace),
    or None if it is not acceptable.

    The literal string 'self' (any case) is returned as-is.  When the url
    has no scheme and ``require_scheme`` is false, 'http://' is prepended.
    The url is rejected when:
      * urlparse raises ValueError,
      * its scheme is missing or not in valid_schemes,
      * it has no hostname,
      * it carries a username or password,
      * any dot-separated hostname label, after IDNA encoding, is not
        made of letters, digits and '-' only (empty labels included).
    """
    if not url:
        return None

    url = url.strip()
    if url.lower() == 'self':
        return url

    try:
        parsed = urlparse(url)
        if not parsed.scheme and not require_scheme:
            # no scheme given and none required: assume http
            url = 'http://' + url
            parsed = urlparse(url)
    except ValueError:
        return None

    if not parsed.scheme or parsed.scheme not in valid_schemes:
        return None

    # a scheme without a host is a bad url
    if not parsed.hostname:
        return None
    # reject embedded credentials
    if parsed.username is not None or parsed.password is not None:
        return None

    for label in parsed.hostname.split('.'):
        try:
            # succeeds only for (almost) valid labels, converting to ascii
            ascii_label = label.encode('idna')
        except TypeError:
            print("label sucks: [%r]" % label)
            raise
        except UnicodeError:
            return None
        # the ascii form must be a plain DNS label
        if not valid_dns.match(ascii_label):
            return None

    return url
|
|
|
|
def trunc_string(text, length):
    """Return ``text`` cut to its first ``length`` characters, with '...'
    appended when anything was removed."""
    if len(text) <= length:
        return text
    return text[:length] + '...'
|
|
|
|
# Truncate a time to a certain number of minutes
# e.g, trunc_time(5:52, 30) == 5:30
def trunc_time(time, mins, hours=None):
    """Round ``time`` down to a multiple of ``mins`` minutes (and, if
    ``hours`` is given, a multiple of ``hours`` hours), zeroing seconds and
    microseconds.

    Raises ValueError if ``mins`` or ``hours`` is outside 1..60.
    """
    if hours is not None:
        if hours < 1 or hours > 60:
            # report the offending hours value (previously printed mins)
            raise ValueError("Hours %d is weird" % hours)
        time = time.replace(hour = hours * (time.hour // hours))

    if mins < 1 or mins > 60:
        raise ValueError("Mins %d is weird" % mins)

    return time.replace(minute = mins * (time.minute // mins),
                        second = 0,
                        microsecond = 0)
|
|
|
|
def long_datetime(datetime):
    """Convert ``datetime`` to the ``g.tz`` timezone and return its ctime()
    text followed by a space and ``str(g.tz)``."""
    local = datetime.astimezone(g.tz)
    return "%s %s" % (local.ctime(), str(g.tz))
|
|
|
|
def median(l):
    """Return the element at index len(l) // 2 of ``l`` sorted, i.e. the
    upper of the two middle elements for even-length input.  Returns None
    for an empty (falsy) ``l``."""
    if not l:
        return None
    ordered = sorted(l)
    return ordered[len(ordered) // 2]
|
|
|
|
def query_string(dict):
    """Build a URL query string ('?k=v&k2=v2') from a mapping.

    Pairs whose value is None are skipped, as are pairs whose key or value
    cannot be decoded to unicode.  Keys and values are url-escaped.  Returns
    '' when no pairs remain.
    """
    encoded = []
    for key, value in dict.iteritems():
        if value is None:
            continue
        try:
            encoded.append(url_escape(_force_unicode(key)) + '=' +
                           url_escape(_force_unicode(value)))
        except UnicodeDecodeError:
            continue
    return '?' + '&'.join(encoded) if encoded else ''
|
|
|
|
class UrlParser(object):
    """
    Wrapper for urlparse and urlunparse for making changes to urls.
    All attributes present on the tuple-like object returned by
    urlparse are present on this class, and are settable, with the
    exception of netloc, which is instead treated via a getter method
    as a concatenation of hostname and port.

    Unlike urlparse, this class allows the query parameters to be
    converted to a dictionary via the query_dict property (and
    correspondingly updated via update_query).  The extension of the
    path can also be set and queried.

    The class also contains reddit-specific functions for setting,
    checking, and getting a path's subreddit.  It also can convert
    paths between in-frame and out of frame cname'd forms.
    """

    __slots__ = ['scheme', 'path', 'params', 'query',
                 'fragment', 'username', 'password', 'hostname',
                 'port', '_url_updates', '_orig_url', '_query_dict']

    valid_schemes = ('http', 'https', 'ftp', 'mailto')
    cname_get = "cnameframe"

    def __init__(self, url):
        u = urlparse(url)
        # copy every urlparse attribute that has a matching slot
        for s in self.__slots__:
            if hasattr(u, s):
                setattr(self, s, getattr(u, s))
        self._url_updates = {}
        self._orig_url = url
        self._query_dict = None

    def update_query(self, **updates):
        """
        Can be used instead of self.query_dict.update() to add/change
        query params in situations where the original contents are not
        required.
        """
        self._url_updates.update(updates)

    @property
    def query_dict(self):
        """
        Parses the `query' attribute of the original urlparse and
        generates a dictionary where both the keys and values have
        been unquote_plus'd.  The dict is built once and cached; any
        updates or changes to it will be reflected in the query string
        produced by unparse().
        """
        if self._query_dict is None:
            def _split(param):
                p = param.split('=')
                return (unquote_plus(p[0]),
                        unquote_plus('='.join(p[1:])))
            self._query_dict = dict(_split(p) for p in self.query.split('&')
                                    if p)
        return self._query_dict

    def path_extension(self):
        """
        Fetches the current extension of the path: the text after the last
        '.' of the last path segment (the whole segment if it has no '.').
        """
        return self.path.split('/')[-1].split('.')[-1]

    def set_extension(self, extension):
        """
        Changes the extension of the path to the provided value (the
        "." should not be included in the extension as a "." is
        provided).  A falsy extension removes the existing one.
        """
        pieces = self.path.split('/')
        dirs = pieces[:-1]
        base = pieces[-1].split('.')
        base = '.'.join(base[:-1] if len(base) > 1 else base)
        if extension:
            base += '.' + extension
        dirs.append(base)
        self.path = '/'.join(dirs)
        return self

    def unparse(self):
        """
        Converts the url back to a string, applying all updates made
        to the fields thereof.

        Note: if a host name has been added and none was present
        before, will enforce scheme -> "http" unless otherwise
        specified.  Occurrences of '//' in the path are replaced by '/'
        (one pass), and the query string is reconstructed only if the
        query_dict has been accessed or update_query was called.
        """
        # only parse the query params if there is an update dict
        q = self.query
        if self._url_updates or self._query_dict is not None:
            q = self._query_dict or self.query_dict
            q.update(self._url_updates)
            q = query_string(q).lstrip('?')

        # make sure the port is not doubly specified
        if self.port and ":" in self.hostname:
            self.hostname = self.hostname.split(':')[0]

        # if there is a netloc, there had better be a scheme
        if self.netloc and not self.scheme:
            self.scheme = "http"

        return urlunparse((self.scheme, self.netloc,
                           self.path.replace('//', '/'),
                           self.params, q, self.fragment))

    def path_has_subreddit(self):
        """
        utility method for checking if the path starts with a
        subreddit specifier (namely /r/ or /reddits/).
        """
        return self.path.startswith(('/r/', '/reddits/'))

    def get_subreddit(self):
        """checks if the current url refers to a subreddit and returns
        that subreddit object. The cases here are:

          * the hostname is unset, is g.domain or one of its subdomains
            (or starts with g.domain), in which case it looks for /r/XXXX
            or /reddits. The default in this case is Default.
          * the hostname is a cname to a known subreddit.

        On failure to find a subreddit, returns None.
        """
        from pylons import g
        from r2.models import Subreddit, Sub, NotFound, DefaultSR
        try:
            # is_subdomain() lets www./pay. style hosts through; the old
            # startswith() test is kept for backwards compatibility.
            if (not self.hostname or
                    is_subdomain(self.hostname, g.domain) or
                    self.hostname.startswith(g.domain)):
                if self.path.startswith('/r/'):
                    return Subreddit._by_name(self.path.split('/')[2])
                elif self.path.startswith('/reddits/'):
                    return Sub
                else:
                    return DefaultSR()
            elif self.hostname:
                return Subreddit._by_domain(self.hostname)
        except NotFound:
            pass
        return None

    def is_reddit_url(self, subreddit = None):
        """utility method for seeing if the url is associated with
        reddit as we don't necessarily want to mangle non-reddit
        domains

        returns true only if hostname is nonexistent, a subdomain of
        g.domain, or a subdomain of the provided subreddit's cname.
        """
        from pylons import g
        return (not self.hostname or
                is_subdomain(self.hostname, g.domain) or
                (subreddit and subreddit.domain and
                 is_subdomain(self.hostname, subreddit.domain)))

    def path_add_subreddit(self, subreddit):
        """
        Adds the subreddit's path to the path if another subreddit's
        prefix is not already present.
        """
        if not self.path_has_subreddit():
            self.path = (subreddit.path + self.path)
        return self

    @property
    def netloc(self):
        """
        Getter method which returns the hostname:port, or empty string
        if no hostname is present.
        """
        if not self.hostname:
            return ""
        elif getattr(self, "port", None):
            return self.hostname + ":" + str(self.port)
        return self.hostname

    def mk_cname(self, require_frame = True, subreddit = None, port = None):
        """
        Converts a ?cnameframe url into the corresponding cnamed
        domain if applicable. Useful for frame-busting on redirect.
        """

        # make sure the url is indeed in a frame
        if require_frame and self.cname_get not in self.query_dict:
            return self

        # fetch the subreddit and make sure it has a cname domain
        subreddit = subreddit or self.get_subreddit()
        if subreddit and subreddit.domain:

            # no guarantee there was a scheme
            self.scheme = self.scheme or "http"

            # update the domain (preserving the port)
            self.hostname = subreddit.domain
            self.port = self.port or port

            # and remove any cnameframe GET parameters
            if self.cname_get in self.query_dict:
                del self._query_dict[self.cname_get]

            # remove the subreddit reference
            self.path = lstrips(self.path, subreddit.path)
            if not self.path.startswith('/'):
                self.path = '/' + self.path

        return self

    def is_in_frame(self):
        """
        Checks if the url is in a frame by determining if
        cls.cname_get is present.
        """
        return self.cname_get in self.query_dict

    def put_in_frame(self):
        """
        Adds the cls.cname_get get parameter to the query string.
        """
        self.update_query(**{self.cname_get:random.random()})

    def __repr__(self):
        return "<URL %s>" % repr(self.unparse())

    def domain_permutations(self, fragments=False, subdomains=True):
        """
        Takes a domain like `www.reddit.com`, and returns a set of ways
        that a user might search for it.

        With subdomains=True (default), every dotted suffix of two or
        more labels:
          * www.reddit.com
          * reddit.com
        With fragments=True, each individual label:
          * www
          * reddit
          * com
        Returns an empty set if there is no hostname.
        """
        ret = set()
        if self.hostname:
            r = self.hostname.split('.')

            if subdomains:
                for x in range(len(r)-1):
                    ret.add('.'.join(r[x:len(r)]))

            if fragments:
                for x in r:
                    ret.add(x)

        return ret

    @classmethod
    def base_url(cls, url):
        u = cls(url)

        # strip off any www and lowercase the hostname:
        netloc = strip_www(u.netloc.lower())

        # http://code.google.com/web/ajaxcrawling/docs/specification.html
        fragment = u.fragment if u.fragment.startswith("!") else ""

        return urlunparse((u.scheme.lower(), netloc,
                           u.path, u.params, u.query, fragment))
|
|
|
|
|
|
def to_js(content, callback="document.write", escape=True):
    """Wrap ``content`` in a JavaScript call: ``callback(content);``.

    If ``callback`` is falsy the (optionally escaped) content is returned
    bare.  If ``escape`` is true, ``content`` is first passed through
    string2js (defined elsewhere; presumably producing a JS string
    literal -- confirm).
    """
    if callback:
        prefix, suffix = callback + "(", ");"
    else:
        prefix, suffix = '', ''
    if escape:
        content = string2js(content)
    return prefix + content + suffix
|
|
|
|
def pload(fname, default = None):
    """Load a pickled object from the file ``fname``.

    Returns ``default`` if the file cannot be opened or read (IOError);
    other unpickling errors propagate.  The file is closed in every case
    (previously it leaked when pickle.load raised).

    The file is opened in binary mode, as pickle requires; on POSIX this
    is the same as the old text mode under Python 2.

    SECURITY: unpickling can execute arbitrary code -- only use this on
    files the application itself wrote (e.g. with psave), never on
    untrusted input.
    """
    try:
        with open(fname, 'rb') as f:
            return pickle.load(f)
    except IOError:
        return default
|
|
|
|
def psave(fname, d):
    """Pickle ``d`` into the file ``fname``, replacing its contents.

    The file is opened in binary mode (as pickle requires) and is closed
    even if pickling raises.
    """
    with open(fname, 'wb') as f:
        pickle.dump(d, f)
|
|
|
|
def unicode_safe(res):
    """Return ``res`` as a byte string, utf-8 encoding it when plain str()
    cannot represent it.

    Tries str(res) first; on UnicodeEncodeError falls back to
    unicode(res).encode('utf-8'), and finally to
    res.decode('utf-8').encode('utf-8').
    """
    try:
        return str(res)
    except UnicodeEncodeError:
        pass
    try:
        return unicode(res).encode('utf-8')
    except UnicodeEncodeError:
        return res.decode('utf-8').encode('utf-8')
|
|
|
|
def decompose_fullname(fullname):
    """
    Split a fullname into (base class, type id, object id).

        decompose_fullname("t3_e4fa") ->
            (Thing, 3, 658918)

    The first character selects the base class ('t' -> Thing,
    'r' -> Relation); the rest is "<type id>_<object id>", both base 36.

    Raises ValueError for an unknown prefix (previously an
    UnboundLocalError) or an otherwise malformed name.
    """
    from r2.lib.db.thing import Thing, Relation

    if fullname[0] == 't':
        type_class = Thing
    elif fullname[0] == 'r':
        type_class = Relation
    else:
        raise ValueError("unrecognized fullname prefix: %r" % (fullname,))

    type_id36, thing_id36 = fullname[1:].split('_')

    type_id = int(type_id36, 36)
    id = int(thing_id36, 36)

    return (type_class, type_id, id)
|
|
|
|
def cols(lst, ncols):
    """Divide a sequence into ``ncols`` columns (filled top to bottom) and
    return the rows as a list of tuples.

    e.g. cols('abcdef', 2) returns [('a', 'd'), ('b', 'e'), ('c', 'f')]

    Any sequence (list, tuple, str, ...) is accepted; an empty one yields
    [].  The last column may be short, so the trailing rows can have fewer
    entries.  Note that None items are dropped from the rows, because None
    is used internally as padding.
    """
    items = list(lst)
    if not items:
        # nrows would be 0 and range() below would get a step of 0
        return []
    nrows = int(math.ceil(1. * len(items) / ncols))
    items = items + [None] * (nrows * ncols - len(items))
    columns = [items[i:i + nrows] for i in range(0, nrows * ncols, nrows)]
    return [tuple(x for x in row if x is not None)
            for row in zip(*columns)]
|
|
|
|
def fetch_things(t_class,since,until,batch_fn=None,
                 *query_params, **extra_query_dict):
    """
    Simple utility function to fetch all Things of class t_class
    (spam or not) that were created from 'since' (inclusive) to 'until'
    (exclusive), oldest first, 100 per query by default.  Extra query
    rules and query options can be passed through.

    batch_fn, if given, transforms each fetched batch before its items
    are yielded.  Pagination always continues after the last row the
    query returned, not after the last transformed item.
    """

    from r2.lib.db.operators import asc

    if not batch_fn:
        batch_fn = lambda x: x

    query_params = ([t_class.c._date >= since,
                     t_class.c._date < until,
                     t_class.c._spam == (True,False)]
                    + list(query_params))
    query_dict = {'sort': asc('_date'),
                  'limit': 100,
                  'data': True}
    query_dict.update(extra_query_dict)

    q = t_class._query(*query_params,
                       **query_dict)

    orig_rules = deepcopy(q._rules)

    things = list(q)
    while things:
        # Remember the last raw row before batch_fn runs: batch_fn may
        # filter, reorder or wrap items (or return nothing, which used to
        # leave the loop variable unbound).
        after = things[-1]
        for t in batch_fn(things):
            yield t
        q._rules = deepcopy(orig_rules)
        q._after(after)
        things = list(q)
|
|
|
|
def fetch_things2(query, chunk_size = 100, batch_fn = None, chunks = False):
    """Run ``query`` repeatedly with a limit of ``chunk_size``, each time
    continuing after the last row of the previous page, until a page comes
    back empty or short.

    If ``batch_fn`` is given, each page is passed through it before being
    returned.  With ``chunks`` true, whole pages are yielded; otherwise the
    individual items are.
    """
    orig_rules = deepcopy(query._rules)
    query._limit = chunk_size
    page = list(query)

    while page:
        # a short page means the query is exhausted; don't query again
        exhausted = len(page) < chunk_size
        last_row = page[-1]

        if batch_fn:
            page = batch_fn(page)

        if chunks:
            yield page
        else:
            for item in page:
                yield item

        if exhausted:
            break

        query._rules = deepcopy(orig_rules)
        query._after(last_row)
        page = list(query)
|
|
|
|
def fix_if_broken(thing, delete = True, fudge_links = False):
    """Check that ``thing`` has all of its class's essential attributes,
    repairing or deleting it if not.

    ``thing`` must be a Link, Comment, Subreddit or Message (exact class),
    otherwise TypeError is raised.  Each attribute in ``cls._essentials`` is
    read; on the first AttributeError the thing is reloaded once with
    ``_load()`` and the attribute retried.  If it is still missing:
      * with ``delete`` false, the AttributeError is re-raised;
      * otherwise, for a Link with ``fudge_links`` true, a missing
        ``sr_id`` / ``author_id`` is set to a hard-coded value (6 /
        8244672; their meaning is not visible here), and the thing is
        marked ``_deleted`` (if not already) and committed.
    Without ``fudge_links`` the scan stops after the first broken
    attribute; with it, all attributes are checked.
    """
    from r2.models import Link, Comment, Subreddit, Message

    # the minimum set of attributes that are required
    attrs = dict((cls, cls._essentials)
                 for cls
                 in (Link, Comment, Subreddit, Message))

    if thing.__class__ not in attrs:
        raise TypeError

    # only reload from the database once, however many attrs are missing
    tried_loading = False
    for attr in attrs[thing.__class__]:
        try:
            # try to retrieve the attribute
            getattr(thing, attr)
        except AttributeError:
            # that failed; let's explicitly load it and try again

            if not tried_loading:
                tried_loading = True
                thing._load()

            try:
                getattr(thing, attr)
            except AttributeError:
                if not delete:
                    raise
                if isinstance(thing, Link) and fudge_links:
                    # NOTE(review): hard-coded placeholder ids; what they
                    # refer to is not visible here -- confirm before reuse
                    if attr == "sr_id":
                        thing.sr_id = 6
                        print "Fudging %s.sr_id to %d" % (thing._fullname,
                                                          thing.sr_id)
                    elif attr == "author_id":
                        thing.author_id = 8244672
                        print "Fudging %s.author_id to %d" % (thing._fullname,
                                                              thing.author_id)
                    else:
                        print "Got weird attr %s; can't fudge" % attr

                if not thing._deleted:
                    print "%s is missing %r, deleting" % (thing._fullname, attr)
                    thing._deleted = True

                thing._commit()

                # without fudging there is nothing more to repair
                if not fudge_links:
                    break
|
|
|
|
|
|
def find_recent_broken_things(from_time = None, to_time = None,
                              delete = False):
    """
    Occasionally (usually during app-server crashes), Things will
    be partially written out to the database.  Things missing data
    attributes break the contract for these things, which often
    breaks various pages.  This function scans Links and Comments created
    between ``from_time`` (default: one hour ago) and ``to_time``
    (default: now in g.tz), newest first, and runs fix_if_broken on each,
    passing ``delete`` through.
    """
    from r2.models import Link, Comment
    from r2.lib.db.operators import desc
    from pylons import g

    window_start = from_time or timeago('1 hour')
    window_end = to_time or datetime.now(g.tz)

    for cls in (Link, Comment):
        recent = cls._query(cls.c._date > window_start,
                            cls.c._date < window_end,
                            data=True,
                            sort=desc('_date'))
        for thing in fetch_things2(recent):
            fix_if_broken(thing, delete = delete)
|
|
|
|
|
|
def timeit(func):
    """Run ``func()`` and return ``(elapsed_seconds, result)``."""
    # Only ``sleep`` is imported from the time module at the top of this
    # file, so import the module itself here instead of relying on a
    # global ``time`` name existing.
    import time
    before = time.time()
    res = func()
    return (time.time() - before, res)
|
|
def lineno():
    """Print the current time and the caller's line number (tab-separated)
    and return that line number.

    Previously the docstring promised a return value but the function
    returned None; returning the number keeps existing callers working.
    """
    import inspect
    caller_line = inspect.currentframe().f_back.f_lineno
    print("%s\t%s" % (datetime.now(), caller_line))
    return caller_line
|
|
|
|
def IteratorFilter(iterator, fn):
    """Lazily yield the items of ``iterator`` for which ``fn`` returns a
    truthy value."""
    for item in iterator:
        if not fn(item):
            continue
        yield item
|
|
|
|
def UniqueIterator(iterator, key = lambda x: x):
    """
    Return a lazy iterator over ``iterator`` that yields only the first
    item for each distinct ``key(item)`` (keys must be hashable).
    """
    seen = set()

    def _first_occurrences():
        for item in iterator:
            marker = key(item)
            if marker in seen:
                continue
            seen.add(marker)
            yield item

    return _first_occurrences()
|
|
|
|
def modhash(user, rand = None, test = False):
    """Stub: the modhash is simply ``user.name``.

    ``rand`` and ``test`` are accepted for interface compatibility and
    ignored.
    """
    name = user.name
    return name
|
|
|
|
def valid_hash(user, hash):
    """Stub: accepts every hash for every user.

    NOTE(review): performs no validation at all; presumably a real
    implementation replaces this in production -- confirm.
    """
    accepted = True
    return accepted
|
|
|
|
def check_cheating(loc):
    """Stub: performs no check and returns None."""
    return None
|
|
|
|
def vote_hash(user, thing, note='valid'):
    """Stub: the vote hash is simply ``user.name``; ``thing`` and ``note``
    are ignored."""
    name = user.name
    return name
|
|
|
|
def valid_vote_hash(hash, user, thing):
    """Stub: accepts every vote hash.

    NOTE(review): performs no validation; presumably replaced by a real
    implementation in production -- confirm.
    """
    accepted = True
    return accepted
|
|
|
|
def safe_eval_str(unsafe_str):
    """Replace the literal escape sequences '\\x3d' and '\\x26' with '='
    and '&' respectively (in that order)."""
    for escaped, plain in (('\\x3d', '='), ('\\x26', '&')):
        unsafe_str = unsafe_str.replace(escaped, plain)
    return unsafe_str
|
|
|
|
rx_whitespace = re.compile(r'\s+', re.UNICODE)
rx_notsafe = re.compile(r'\W+', re.UNICODE)
rx_underscore = re.compile(r'_+', re.UNICODE)
def title_to_url(title, max_length = 50):
    """Turn ``title`` into a lowercase slug suitable for use in URLs.

    Whitespace runs become '_', non-word characters are dropped, repeated
    underscores are collapsed and leading/trailing underscores removed.
    Slugs longer than ``max_length`` are cut back to the last '_' within
    the limit (if any).  Returns "_" if nothing is left.
    """
    slug = _force_unicode(title)
    slug = rx_whitespace.sub('_', slug)   # whitespace runs -> '_'
    slug = rx_notsafe.sub('', slug)       # drop non-word characters
    slug = rx_underscore.sub('_', slug)   # collapse repeated '_'
    slug = slug.strip('_').lower()

    if len(slug) > max_length:
        # hard cut, then back up to the last word boundary
        slug = slug[:max_length]
        boundary = slug.rfind('_')
        if boundary > 0:
            slug = slug[:boundary]
    return slug or "_"
|
|
|
|
def dbg(s):
    """Write ``s`` followed by a newline to standard error."""
    import sys
    line = '%s\n' % (s,)
    sys.stderr.write(line)
|
|
|
|
def trace(fn):
    """Decorator: after each call to ``fn``, log the function, its
    positional and keyword arguments and its return value via dbg(), then
    return the value.

    functools.wraps keeps the wrapped function's name and docstring, which
    the previous wrapper discarded.
    """
    @wraps(fn)
    def new_fn(*a, **kw):
        ret = fn(*a, **kw)
        dbg("Fn: %s; a=%s; kw=%s\nRet: %s"
            % (fn, a, kw, ret))
        return ret
    return new_fn
|
|
|
|
def common_subdomain(domain1, domain2):
    """Return the longest dotted suffix (of at least two labels) that the
    shorter of the two domains shares with the other on a label boundary,
    ignoring any ':port'.  Identical domains are returned as-is.  Returns
    "" if either domain is empty or no such suffix exists.

    e.g. common_subdomain("www.reddit.com", "pay.reddit.com") == "reddit.com"
    """
    if not domain1 or not domain2:
        return ""
    domain1 = domain1.split(":")[0]
    domain2 = domain2.split(":")[0]
    if len(domain1) > len(domain2):
        domain1, domain2 = domain2, domain1

    if domain1 == domain2:
        return domain1
    else:
        dom = domain1.split(".")
        for i in range(len(dom), 1, -1):
            d = '.'.join(dom[-i:])
            # Match whole labels only: "evilreddit.com" does not share
            # "reddit.com" with "reddit.com" (plain endswith said it did).
            if domain2 == d or domain2.endswith('.' + d):
                return d
        return ""
|
|
|
|
def interleave_lists(*args):
    """Yield elements round-robin from the given sequences: index 0 of each
    sequence, then index 1 of each, and so on, skipping sequences that have
    run out.  Yields nothing when called with no arguments (previously
    max() of an empty sequence raised ValueError).
    """
    max_len = max([len(x) for x in args] or [0])
    for i in range(max_len):
        for a in args:
            if i < len(a):
                yield a[i]
|
|
|
|
def link_from_url(path, filter_spam = False, multiple = True):
    """Look up the Link(s) submitted with URL ``path`` in the current site
    (``Link._by_url(path, c.site)``) and run them through filter_links.

    When ``path`` is empty or no link matches, returns [] if ``multiple``
    is true and None otherwise (an empty ``path`` used to return None even
    with ``multiple`` true).  Otherwise returns whatever filter_links
    returns.
    """
    from pylons import c
    from r2.models import Link, NotFound

    if not path:
        return [] if multiple else None

    try:
        links = Link._by_url(path, c.site)
    except NotFound:
        return [] if multiple else None

    return filter_links(tup(links), filter_spam = filter_spam,
                        multiple = multiple)
|
|
|
|
def filter_links(links, filter_spam = False, multiple = True):
    """Reduce ``links`` to those the current user may see and order them.

    Links are passed through IDBuilder; if ``filter_spam`` is true, spam is
    dropped unless everything is spam.  Links in the user's subscribed
    subreddits come first, then higher ``_hot`` first.

    Returns the ordered list if ``multiple`` is true, else its first
    element.  If no link survives, returns [] / None respectively
    (previously None in both cases).
    """
    from pylons import c
    from r2.models import IDBuilder, Subreddit

    # run the list through a builder to remove any that the user
    # isn't allowed to see
    links = IDBuilder([link._fullname for link in links],
                      skip = False).get_items()[0]
    if not links:
        # keep the return type consistent with the non-empty case
        return [] if multiple else None

    if filter_spam:
        # first, try to remove any spam
        links_nonspam = [link for link in links
                         if not link._spam]
        if links_nonspam:
            links = links_nonspam

    # if it occurs in one or more of their subscriptions, show them
    # that one first
    subs = set(Subreddit.user_subreddits(c.user, limit = None))
    def cmp_links(a, b):
        if a.sr_id in subs and b.sr_id not in subs:
            return -1
        elif a.sr_id not in subs and b.sr_id in subs:
            return 1
        else:
            return cmp(b._hot, a._hot)
    links = sorted(links, cmp = cmp_links)

    # among those, show them the hottest one
    return links if multiple else links[0]
|
|
|
|
def link_duplicates(article):
    """Return the other links submitted with the same URL as ``article``.

    Self posts (``is_self`` true) have no external URL, so [] is returned
    without a lookup.
    """
    if not getattr(article, 'is_self', False):
        return url_links(article.url, exclude = article._fullname)
    return []
|
|
|
|
def url_links(url, exclude=None):
    """Return every Link submitted with ``url`` (across all subreddits),
    omitting the one whose fullname equals ``exclude``.  Returns [] when
    there are none.
    """
    from r2.models import Link, NotFound

    try:
        found = tup(Link._by_url(url, None))
    except NotFound:
        return []

    return [link for link in found
            if link._fullname != exclude]
|
|
|
|
class TimeoutFunctionException(Exception):
    """Raised by TimeoutFunction when the wrapped call runs too long."""
|
|
|
|
class TimeoutFunction:
    """Force an operation to timeout after N seconds.

    TimeoutFunction(fn, seconds)(*args, **kwargs) calls fn and raises
    TimeoutFunctionException if it has not returned within ``seconds``
    (an int, as signal.alarm requires).  Works with POSIX SIGALRM, so it
    must be called from the main thread and is not safe to use in a
    multi-threaded environment.  The previous SIGALRM handler is restored
    afterwards, and any pending alarm is cancelled.
    """
    def __init__(self, function, timeout):
        self.timeout = timeout
        self.function = function

    def handle_timeout(self, signum, frame):
        raise TimeoutFunctionException()

    def __call__(self, *args, **kwargs):
        # can only be called from the main thread
        old = signal.signal(signal.SIGALRM, self.handle_timeout)
        signal.alarm(self.timeout)
        try:
            # keyword arguments are now forwarded too
            return self.function(*args, **kwargs)
        finally:
            signal.alarm(0)
            signal.signal(signal.SIGALRM, old)
|
|
|
|
def make_offset_date(start_date, interval, future = True,
                     business_days = False):
    """
    Generates a date in the future or past "interval" days from start_date.
    Returns start_date unchanged if interval is None; interval may be
    anything int() accepts.

    Can optionally give weekends no weight in the calculation if
    "business_days" is set to true; then "interval" counts Monday-Friday
    days only, and a negative interval goes in the opposite direction.
    Start dates falling on a weekend are not specially normalized.

    The business-day arithmetic used to treat 7 business days as one week,
    giving wrong results for intervals of 10 or more (and e.g. Friday + 6).
    """
    if interval is not None:
        interval = int(interval)
        if business_days:
            if interval < 0:
                interval = -interval
                future = not future
            # Every 5 business days span exactly one calendar week; the
            # remaining 0-4 days cross a weekend when they would land
            # outside Monday(0)..Friday(4).
            weeks, extra = divmod(interval, 5)
            dow = start_date.weekday()
            landing_dow = dow + extra if future else dow - extra
            interval = weeks * 7 + extra
            if landing_dow < 0 or landing_dow > 4:
                interval += 2
        if future:
            return start_date + timedelta(interval)
        return start_date - timedelta(interval)
    return start_date
|
|
|
|
def in_chunks(it, size=25):
    """Yield successive lists of up to ``size`` items taken from the
    iterable ``it``; the last list may be shorter.  Nothing is yielded for
    an empty iterable."""
    batch = []
    for element in iter(it):
        batch.append(element)
        if len(batch) >= size:
            yield batch
            batch = []
    if batch:
        yield batch
|
|
|
|
def spaceout(items, targetseconds,
             minsleep = 0, die = False,
             estimate = None):
    """Yield each item of `items`, sleeping between them so that the whole
    iteration is spread out over roughly `targetseconds` seconds.

    :param items: any iterable. If `estimate` is None and len(items) works,
        that length is used as the estimate.
    :param targetseconds: total wall-clock time to spread the items over.
    :param minsleep: minimum number of seconds to sleep after each item.
    :param die: if true, stop yielding once `targetseconds` have passed
        since iteration began.
    :param estimate: expected number of items. If it is None and cannot be
        derived from len(items), only `minsleep` is enforced.

    The time the consumer spends on each item (between receiving it and
    asking for the next) is averaged over a sliding window of the last ten
    samples and subtracted from the per-item time budget.
    """
    targetseconds = float(targetseconds)
    # seed sample; it is averaged in until pushed out of the 10-sample window
    state = [1.0]

    if estimate is None:
        try:
            estimate = len(items)
        except TypeError:
            # if we can't come up with an estimate, the best we can do
            # is just enforce the minimum sleep time (and the max
            # targetseconds if die==True)
            pass

    mean = lambda lst: sum(float(x) for x in lst)/float(len(lst))

    beginning = datetime.now()

    for item in items:
        start = datetime.now()
        yield item
        end = datetime.now()

        took_delta = end - start
        # a day is 24*60*60 seconds (the old code multiplied by 60*24 only)
        took = (took_delta.days * 24 * 60 * 60
                + took_delta.seconds
                + took_delta.microseconds/1000000.0)
        state.append(took)
        if len(state) > 10:
            del state[0]

        if die and end > beginning + timedelta(seconds=targetseconds):
            # we ran out of time, ignore the rest of the iterator
            break

        if estimate is None:
            if minsleep:
                # we have no idea how many items we're going to get
                sleep(minsleep)
        else:
            sleeptime = max((targetseconds / estimate) - mean(state),
                            minsleep)
            if sleeptime > 0:
                sleep(sleeptime)
|
|
|
|
def progress(it, verbosity=100, key=repr, estimate=None, persec=True):
    """An iterator that yields everything from `it', but prints progress
    information to stderr along the way, including time-estimates if
    possible.

    :param it: the iterable to pass through unchanged.
    :param verbosity: write a status line after every `verbosity` items;
        values below 1 are treated as 1.
    :param key: if truthy, called on the last item of each full batch, and
        its result is appended to the status line.
    :param estimate: expected total number of items. Defaults to len(it)
        when `it` supports it, and enables percentage and remaining-time
        output.
    :param persec: include items-per-second rates in the output.
    """
    from itertools import islice
    from datetime import datetime
    import sys

    # islice(it, 0) never yields anything, so a verbosity below 1 would spin
    # forever without consuming the iterator
    verbosity = max(verbosity, 1)

    start = datetime.now()

    # try to guess at the estimate if we can
    if estimate is None:
        try:
            estimate = len(it)
        except Exception:
            # unsized iterables (generators etc.): run without an estimate
            pass

    def timedelta_to_seconds(td):
        return td.days * (24*60*60) + td.seconds + (float(td.microseconds) / 1000000)

    def format_timedelta(td, sep=''):
        ret = []
        s = timedelta_to_seconds(td)
        if s < 0:
            neg = True
            s *= -1
        else:
            neg = False

        if s >= (24*60*60):
            days = int(s//(24*60*60))
            ret.append('%dd' % days)
            s -= days*(24*60*60)
        if s >= 60*60:
            hours = int(s//(60*60))
            ret.append('%dh' % hours)
            s -= hours*(60*60)
        if s >= 60:
            minutes = int(s//60)
            ret.append('%dm' % minutes)
            s -= minutes*60
        if s >= 1:
            seconds = int(s)
            ret.append('%ds' % seconds)
            s -= seconds

        if not ret:
            return '0s'

        return ('-' if neg else '') + sep.join(ret)

    def format_datetime(dt, show_date=False):
        if show_date:
            return dt.strftime('%Y-%m-%d %H:%M')
        else:
            return dt.strftime('%H:%M:%S')

    def deq(dt1, dt2):
        "Indicates whether the two datetimes' dates describe the same (day,month,year)"
        return dt1.date() == dt2.date()

    sys.stderr.write('Starting at %s\n' % (start,))

    # we're going to islice it so we need to start an iterator
    it = iter(it)

    seen = 0
    while True:
        this_chunk = 0
        thischunk_started = datetime.now()

        # the simple bit: just iterate and yield
        for item in islice(it, verbosity):
            this_chunk += 1
            seen += 1
            yield item

        if this_chunk < verbosity:
            # we're done, the iterator is empty
            break

        now = datetime.now()
        elapsed = now - start
        thischunk_seconds = timedelta_to_seconds(now - thischunk_started)

        if estimate:
            # the estimate is based on the total number of items that
            # we've processed in the total amount of time that's
            # passed, so it should smooth over momentary spikes in
            # speed (but will take a while to adjust to long-term
            # changes in speed)
            remaining = ((elapsed/seen)*estimate)-elapsed
            completion = now + remaining
            count_str = ('%d/%d %.2f%%'
                         % (seen, estimate, float(seen)/estimate*100))
            completion_str = format_datetime(completion, not deq(completion,now))
            estimate_str = (' (%s remaining; completion %s)'
                            % (format_timedelta(remaining),
                               completion_str))
        else:
            count_str = '%d' % seen
            estimate_str = ''

        if key:
            # `item` is the last item of the batch just yielded
            key_str = ': %s' % key(item)
        else:
            key_str = ''

        # unlike the estimate, the persec count is the number per
        # second for *this* batch only, without smoothing
        if persec and thischunk_seconds > 0:
            persec_str = ' (%.1f/s)' % (float(this_chunk)/thischunk_seconds,)
        else:
            persec_str = ''

        sys.stderr.write('%s%s, %s%s%s\n'
                         % (count_str, persec_str,
                            format_timedelta(elapsed), estimate_str, key_str))

    now = datetime.now()
    elapsed = now - start
    elapsed_seconds = timedelta_to_seconds(elapsed)
    if persec and seen > 0 and elapsed_seconds > 0:
        persec_str = ' (@%.1f/sec)' % (float(seen)/elapsed_seconds)
    else:
        persec_str = ''
    sys.stderr.write('Processed %d%s items in %s..%s (%s)\n'
                     % (seen,
                        persec_str,
                        format_datetime(start, not deq(start, now)),
                        format_datetime(now, not deq(start, now)),
                        format_timedelta(elapsed)))
|
|
|
|
class Hell(Exception):
    """The exception Bomb raises when one of its traps is touched.

    It derives from Exception because raising an instance of a plain
    ``object`` subclass is itself a TypeError, which would mask the intent.
    """
    def __str__(self):
        return "boom!"
|
|
|
|
class Bomb(object):
    """A booby-trapped object.

    On an instance, reading a missing attribute, assigning any attribute,
    or calling repr() raises Hell().
    """

    @classmethod
    def __repr__(cls):
        raise Hell()

    @classmethod
    def __setattr__(cls, name, value):
        raise Hell()

    @classmethod
    def __getattr__(cls, name):
        raise Hell()
|
|
|
|
class SimpleSillyStub(object):
    """A simple stub object that does nothing when you call its methods.

    Any attribute not found by normal lookup resolves to ``stub``, a method
    that accepts any arguments and returns None. The stub is always falsy.
    """
    def __nonzero__(self):
        return False

    # Python 3 name for the truthiness hook; special-method lookup bypasses
    # __getattr__, so without this alias the stub would be truthy there
    __bool__ = __nonzero__

    def __getattr__(self, name):
        return self.stub

    def stub(self, *args, **kwargs):
        pass
|
|
|
|
def strordict_fullname(item, key='fullname'):
    """Sometimes we migrate AMQP queues from simple strings to pickled
    dictionaries. During the migratory period there may be items in
    the queue of both types, so this function tries to detect which
    the item is. It shouldn't really be used on a given queue for more
    than a few hours or days.

    Returns a dict whose `key` entry is a str: either the unpickled dict,
    or ``{key: item}`` when `item` does not unpickle. Raises ValueError if
    `item` unpickles to anything else.
    """
    try:
        # NOTE(review): pickle.loads can execute arbitrary code; this is only
        # safe if nobody untrusted can publish to the queue
        d = pickle.loads(item)
    except Exception:
        # not a pickle (loads raises many different types): treat the item
        # as a bare fullname string
        d = {key: item}

    if (not isinstance(d, dict)
        or key not in d
        or not isinstance(d[key], str)):
        raise ValueError('Error trying to migrate %r (%r)'
                         % (item, d))

    return d
|
|
|
|
def thread_dump(*a):
    """Write the current stack of every live thread to stderr.

    Positional arguments are accepted and ignored (presumably so this can be
    registered directly as a signal handler).
    """
    import sys, traceback
    from datetime import datetime

    stars = '*' * 15
    header = '%(t)s Thread Dump @%(d)s %(t)s\n' % {'t': stars,
                                                   'd': datetime.now()}
    sys.stderr.write(header)

    frames_by_thread = sys._current_frames()
    for tid, top_frame in frames_by_thread.items():
        sys.stderr.write('\t-- Thread ID: %s--\n' % (tid,))

        for entry in traceback.extract_stack(top_frame):
            filename, lineno, fnname, line = entry
            location = {'filename': filename, 'lineno': lineno,
                        'fnname': fnname}
            sys.stderr.write('\t\t%(filename)s(%(lineno)d): %(fnname)s\n'
                             % location)
            sys.stderr.write('\t\t\t%(line)s\n' % {'line': line})
|
|
|
|
|
|
def constant_time_compare(actual, expected):
    """
    Returns True if the two strings are equal, False otherwise

    The time taken is dependent on the number of characters provided
    instead of the number of characters that match.
    """
    expected_len = len(expected)
    # differing lengths alone make the result non-zero, but every character
    # of `actual` is still visited so the loop never exits early
    result = len(actual) ^ expected_len
    if expected_len > 0:
        for position, char in enumerate(actual):
            # wrap around `expected` so each character of `actual` is
            # compared against something even when `expected` is shorter
            result |= ord(char) ^ ord(expected[position % expected_len])
    return result == 0
|
|
|
|
def wraps_api(f):
    """Like functools.wraps(f), but also copies the ``_api_doc`` attribute.

    If `f` has no ``_api_doc`` yet, an empty dict is attached to it first,
    because wraps() requires every assigned attribute to exist.
    """
    if hasattr(f, '_api_doc'):
        pass
    else:
        f._api_doc = {}
    copied_attrs = WRAPPER_ASSIGNMENTS + ('_api_doc',)
    return wraps(f, assigned=copied_attrs)
|
|
|
|
|
|
def extract_urls_from_markdown(md):
    """Render `md` with snudown and lazily yield the non-empty ``href``
    of every ``<a>`` tag in the resulting HTML."""
    rendered = snudown.markdown(_force_utf8(md))
    anchors = BeautifulSoup(rendered, parseOnlyThese=SoupStrainer("a"))

    for anchor in anchors:
        href = anchor.get('href')
        if not href:
            continue
        yield href
|
|
|
|
|
|
def summarize_markdown(md):
    """Return the text before the first pair of consecutive newlines in
    `md` (its first paragraph), cut to at most 500 characters."""
    first_paragraph = md.split("\n\n", 1)[0]
    return first_paragraph[:500]
|
|
|
|
|
|
def find_containing_network(ip_ranges, address):
    """Return the first network in `ip_ranges` that contains `address`,
    or None if none of them do."""
    target = ipaddress.ip_address(address)
    matches = (net for net in ip_ranges if target in net)
    return next(matches, None)
|
|
|
|
|
|
def is_throttled(address):
    """Return True if `address` lies in any network listed in g.throttles."""
    containing = find_containing_network(g.throttles, address)
    return bool(containing)
|
|
|
|
|
|
def parse_http_basic(authorization_header):
    """Parse the username/credentials out of an HTTP Basic Auth header.

    The header must have two parts, a "basic" scheme (case-insensitive) and
    a base64 token. The decoded token is split into two parts on ":" by
    require_split, and that result is returned.

    Raises RequirementException if anything is uncool.
    """
    auth_scheme, auth_token = require_split(authorization_header, 2)
    require(auth_scheme.lower() == "basic")
    try:
        auth_data = base64.b64decode(auth_token)
    except (TypeError, ValueError):
        # Python 2 reports bad padding as TypeError, and a unicode token with
        # non-ASCII characters fails with UnicodeEncodeError (a ValueError).
        # Python 3 raises binascii.Error, also a ValueError.
        raise RequirementException
    return require_split(auth_data, 2, ":")
|
|
|
|
|
|
def simple_traceback():
    """Generate a pared-down traceback that's human readable but small.

    Looks at the 7 innermost stack frames and drops the last two (this
    function and its immediate caller). Each remaining frame becomes a
    ``basename:function:lineno`` line, oldest first.
    """
    frames = traceback.extract_stack(limit=7)[:-2]

    lines = []
    for path, lineno, func, _source in frames:
        lines.append(":".join((os.path.basename(path), func, str(lineno))))
    return "\n".join(lines)
|
|
|
|
|
|
class GoldPrice(object):
    """Simple price math / formatting type.

    Prices are assumed to be USD at the moment.

    Wraps a Decimal. Multiplying or dividing by a number returns a new
    GoldPrice, and str() renders dollars to two places, e.g. "$3.99".
    """
    def __init__(self, decimal):
        self.decimal = Decimal(decimal)

    def __mul__(self, other):
        return type(self)(self.decimal * other)

    def __div__(self, other):
        return type(self)(self.decimal / other)

    # `/` resolves to __truediv__ under `from __future__ import division`
    # and on Python 3; without it `price / n` raises TypeError there
    __truediv__ = __div__

    def __str__(self):
        return "$%s" % self.decimal.quantize(Decimal("1.00"))

    def __repr__(self):
        return "%s(%s)" % (type(self).__name__, self)

    @property
    def pennies(self):
        # int() truncates toward zero; fractional cents are dropped
        return int(self.decimal * 100)
|
|
|
|
|
|
def config_gold_price(v, key=None, data=None):
    """Build a GoldPrice from the raw config value `v`.

    `key` and `data` are accepted but unused here (presumably part of the
    config-parser callback signature; confirm against the caller).
    """
    price = GoldPrice(v)
    return price
|
|
|