Files
reddit/r2/r2/lib/utils/utils.py

1364 lines
42 KiB
Python

# The contents of this file are subject to the Common Public Attribution
# License Version 1.0. (the "License"); you may not use this file except in
# compliance with the License. You may obtain a copy of the License at
# http://code.reddit.com/LICENSE. The License is based on the Mozilla Public
# License Version 1.1, but Sections 14 and 15 have been added to cover use of
# software over a computer network and provide for limited attribution for the
# Original Developer. In addition, Exhibit A has been modified to be consistent
# with Exhibit B.
#
# Software distributed under the License is distributed on an "AS IS" basis,
# WITHOUT WARRANTY OF ANY KIND, either express or implied. See the License for
# the specific language governing rights and limitations under the License.
#
# The Original Code is Reddit.
#
# The Original Developer is the Initial Developer. The Initial Developer of the
# Original Code is CondeNet, Inc.
#
# All portions of the code written by CondeNet are Copyright (c) 2006-2010
# CondeNet, Inc. All Rights Reserved.
################################################################################
from urllib import unquote_plus
from urllib2 import urlopen
from urlparse import urlparse, urlunparse
from threading import local
import signal
from copy import deepcopy
import cPickle as pickle
import re, math, random
from BeautifulSoup import BeautifulSoup
from time import sleep
from datetime import datetime, timedelta
from pylons.i18n import ungettext, _
from r2.lib.filters import _force_unicode
from mako.filters import url_escape
from r2.lib.utils._utils import *
# Tuple of builtin container types, suitable as the second argument to
# isinstance().  NOTE(review): nothing in this part of the file uses it;
# presumably consumed by importers of this module - confirm before removing.
iters = (list, tuple, set)
def randstr(len, reallyrandom = False):
    """Return a random string of `len` characters.

    If reallyrandom is False, characters come from the lowercase
    alphanumeric alphabet (base-36 compatible) and are chosen with the
    module-level `random` generator.  If reallyrandom is True, uppercase
    letters and punctuation are added to the alphabet and characters are
    chosen with the operating system's CSPRNG (random.SystemRandom), so the
    result is suitable for use as a salt.

    Note: the parameter name `len` shadows the builtin inside this
    function; it is kept because callers may pass it by keyword.
    """
    alphabet = 'abcdefghijklmnopqrstuvwxyz0123456789'
    choose = random.choice
    if reallyrandom:
        # '\\(' is the same two characters the old '\(' produced
        alphabet += 'ABCDEFGHIJKLMNOPQRSTUVWXYZ!#$%&\\()*+,-./:;<=>?@[\\]^_{|}~'
        # the default Mersenne Twister is predictable; salts need a CSPRNG
        choose = random.SystemRandom().choice
    return ''.join(choose(alphabet) for i in range(len))
def is_authorized_cname(domain, cnames):
    """Return True if `domain` equals one of `cnames` or is a subdomain
    of one of them.

    Subdomains must end with "." + cname, so "evilexample.com" does not
    match "example.com".
    """
    for cname in cnames:
        if domain == cname:
            return True
        if domain.endswith('.' + cname):
            return True
    return False
class Storage(dict):
    """
    A dict whose keys can also be read, written and deleted as attributes:
    `obj.foo` is equivalent to `obj['foo']`.

        >>> o = storage(a=1)
        >>> o.a
        1
        >>> o['a']
        1
        >>> o.a = 2
        >>> o['a']
        2
        >>> del o.a
        >>> o.a
        Traceback (most recent call last):
            ...
        AttributeError: 'a'
    """

    def __getattr__(self, key):
        # A missing key must surface as AttributeError so that hasattr(),
        # getattr(obj, name, default), copy and pickle behave normally.
        try:
            value = self[key]
        except KeyError as missing:
            raise AttributeError(missing)
        return value

    def __setattr__(self, key, value):
        self[key] = value

    def __delattr__(self, key):
        try:
            del self[key]
        except KeyError as missing:
            raise AttributeError(missing)

    def __repr__(self):
        return '<Storage %s>' % dict.__repr__(self)

# lower-case alias used by existing callers
storage = Storage
def storify(mapping, *requireds, **defaults):
    """
    Creates a `storage` object from dictionary `mapping`, raising `KeyError` if
    `mapping` doesn't have all of the keys in `requireds` and using the default
    values for keys found in `defaults`.

    For example, `storify({'a':1, 'c':3}, b=2, c=0)` will return the equivalent of
    `storage({'a':1, 'b':2, 'c':3})`.

    If a `storify` value is a list (e.g. multiple values in a form submission),
    `storify` returns the last element of the list, unless the key appears in
    `defaults` as a list. Thus:

        >>> storify({'a':[1, 2]}).a
        2
        >>> storify({'a':[1, 2]}, a=[]).a
        [1, 2]
        >>> storify({'a':1}, a=[]).a
        [1]
        >>> storify({}, a=[]).a
        []

    Similarly, if the value has a `value` attribute, `storify` will return _its_
    value, unless the key appears in `defaults` as a dictionary.

        >>> storify({'a':storage(value=1)}).a
        1
        >>> storify({'a':storage(value=1)}, a={}).a
        <Storage {'value': 1}>
        >>> storify({}, a={}).a
        {}

    A default of `()` makes the result a tuple, wrapping a scalar if needed.
    """
    def getvalue(x):
        # unwrap form-field-like objects that carry their payload in `.value`
        if hasattr(x, 'value'):
            return x.value
        else:
            return x

    stor = Storage()
    for key in requireds + tuple(mapping.keys()):
        value = mapping[key]
        if isinstance(value, list):
            if isinstance(defaults.get(key), list):
                value = [getvalue(x) for x in value]
            else:
                value = value[-1]
        if not isinstance(defaults.get(key), dict):
            value = getvalue(value)
        if isinstance(defaults.get(key), list) and not isinstance(value, list):
            value = [value]
        setattr(stor, key, value)

    for (key, value) in defaults.items():
        result = value
        # Test key membership, not hasattr(): hasattr() is also True for
        # dict methods such as 'items' or 'keys', and stor[key] would then
        # raise KeyError for a default that was never supplied.
        if key in stor:
            result = stor[key]
        if value == () and not isinstance(result, tuple):
            result = (result,)
        setattr(stor, key, result)
    return stor
class Enum(Storage):
    """Map each positional name to its index, e.g. Enum('a', 'b').b == 1.

    `in` with an int checks the indices; anything else checks the names.
    """
    def __init__(self, *a):
        # assigned through Storage.__setattr__, so the ordered names are
        # also stored under the dict key 'name'
        self.name = tuple(a)
        Storage.__init__(self, zip(a, range(len(a))))

    def __contains__(self, item):
        if not isinstance(item, int):
            return Storage.__contains__(self, item)
        return item in self.values()
class Results:
    """Wrap a SQLAlchemy result proxy, converting rows with `build_fn`.

    With do_batch=False, build_fn is applied to each row individually.
    With do_batch=True, it receives a whole sequence of rows and must
    return a list.  fetchmany() raises StopIteration when no built rows
    come back; fetchone() raises it when the proxy returns no row.
    """
    def __init__(self, sa_ResultProxy, build_fn, do_batch=False):
        self.rp = sa_ResultProxy
        self.fn = build_fn
        self.do_batch = do_batch

    @property
    def rowcount(self):
        return self.rp.rowcount

    def _fetch(self, res):
        if not self.do_batch:
            return [self.fn(row) for row in res]
        return self.fn(res)

    def fetchall(self):
        return self._fetch(self.rp.fetchall())

    def fetchmany(self, n):
        built = self._fetch(self.rp.fetchmany(n))
        if not built:
            raise StopIteration
        return built

    def fetchone(self):
        row = self.rp.fetchone()
        if not row:
            raise StopIteration
        if not self.do_batch:
            return self.fn(row)
        # in batch mode build_fn expects a sequence of rows
        return self.fn(tup(row))[0]
r_base_url = re.compile(r"(?i)(?:.+?://)?(?:www[\d]*\.)?([^#]*[^#/])/?")
def base_url(url):
    """Return `url` without its scheme, without a leading "www."/"wwwN."
    label, without its #fragment and without trailing slashes.

    If the pattern finds nothing, `url` is returned unchanged.
    """
    found = r_base_url.findall(url)
    if found and found[0]:
        return found[0]
    return url
r_domain = re.compile(r"(?i)(?:.+?://)?(?:www[\d]*\.)?([^/:#?]*)")
def domain(s):
    """
    Return the lowercased host part of URL `s`, without the scheme, a
    leading "www."/"wwwN." label, the port, or anything after it.
    When no host can be extracted, the whole input is lowercased and
    returned instead.
    """
    found = r_domain.findall(s)
    if found and found[0]:
        host = found[0]
    else:
        host = s
    return host.lower()
r_path_component = re.compile(".*?/(.*)")
def path_component(s):
    """
    Return what follows the first slash of base_url(s), e.g.
    http://www.foo.com/i/like/cheese -> i/like/cheese.  When there is no
    such slash (or nothing follows it), `s` itself is returned.
    """
    found = r_path_component.findall(base_url(s))
    if found and found[0]:
        return found[0]
    return s
def get_title(url):
    """Fetch `url` and return the contents of its <title>, utf-8 encoded.

    Only plain http:// URLs are fetched.  Returns None in these cases:
      * the URL is empty or not http;
      * the request or parsing fails;
      * no title appears in the first kilobyte of the response.

    Failures are deliberately swallowed because this is a best-effort
    helper.
    """
    if not url or not url.startswith('http://'):
        return None

    try:
        # if we don't find it in the first kb of the resource, we
        # probably won't find it
        opener = urlopen(url, timeout=15)
        try:
            text = opener.read(1024)
        finally:
            # close even when read() fails so the connection isn't leaked
            opener.close()

        bs = BeautifulSoup(text)
        if not bs:
            return None

        # a missing html/head element raises AttributeError, which is
        # handled below like any other parse failure
        title_bs = bs.html.head.title
        if not title_bs or not title_bs.string:
            return None

        return title_bs.string.encode('utf-8')
    except Exception:
        # best effort: any network/parse error means "no title".  Unlike
        # the previous bare except, this lets KeyboardInterrupt and
        # SystemExit propagate.
        return None
valid_schemes = ('http', 'https', 'ftp', 'mailto')
valid_dns = re.compile(r'\A[-a-zA-Z0-9]+\Z')
def sanitize_url(url, require_scheme = False):
    """Validate `url` and return it (whitespace-stripped), or None.

    The url is expected to look like
    scheme://domain/path/to/content?query#anchor and is parsed with the
    built-in urlparse.  It is accepted only when all of these hold:
      * its scheme is one of `valid_schemes`;
      * it has a hostname and carries no username/password;
      * every dot-separated hostname label survives IDNA encoding and then
        consists only of ASCII letters, digits and hyphens.

    The literal 'self' (any case) is returned as-is.  If no scheme is given
    and require_scheme is False, 'http://' is prepended before validation.
    """
    if not url:
        return None

    url = url.strip()
    if url.lower() == 'self':
        return url

    u = urlparse(url)
    # first pass: make sure a scheme has been specified
    if not require_scheme and not u.scheme:
        url = 'http://' + url
        u = urlparse(url)

    if not (u.scheme and u.scheme in valid_schemes):
        return None

    # if there is a scheme and no hostname, it is a bad url.
    if not u.hostname:
        return None
    # credentials embedded in the url are rejected outright
    if u.username is not None or u.password is not None:
        return None

    for label in u.hostname.split('.'):
        try:
            # if this succeeds, this portion of the dns is almost valid and
            # converted to ascii; decode() keeps it a text string so it can
            # be matched against valid_dns
            label = label.encode('idna').decode('ascii')
        except UnicodeError:
            return None
        # then if this succeeds, this portion of the dns is really valid
        if not valid_dns.match(label):
            return None

    return url
def trunc_string(text, length):
    """Cut `text` to `length` characters and append '...' when it is
    longer than that; otherwise return it unchanged."""
    if len(text) <= length:
        return text
    return text[:length] + '...'
# Truncate a time to a certain number of minutes,
# e.g. trunc_time(5:52, 30) == 5:30
def trunc_time(time, mins, hours=None):
    """Round `time` down to a multiple of `mins` minutes.

    If `hours` is given, the hour is first rounded down to a multiple of
    `hours`.  Seconds and microseconds are always zeroed.

    Raises ValueError when mins or hours is outside 1..60.
    NOTE(review): for hours the upper bound is 60 although any value above
    23 simply truncates to midnight; kept as-is for compatibility.
    """
    if hours is not None:
        if hours < 1 or hours > 60:
            # report the offending hours value (previously reported mins)
            raise ValueError("Hours %d is weird" % hours)
        # floor division keeps this an int even under true division
        time = time.replace(hour = hours * (time.hour // hours))
    if mins < 1 or mins > 60:
        raise ValueError("Mins %d is weird" % mins)
    return time.replace(minute = mins * (time.minute // mins),
                        second = 0,
                        microsecond = 0)
def long_datetime(datetime):
    """Format `datetime` in ctime() style followed by " GMT".

    NOTE(review): no timezone conversion is done here; presumably callers
    pass a value that is already in UTC.
    """
    return "%s GMT" % (datetime.ctime(),)
def median(l):
    """Return the middle element of `l` after sorting, or None if empty.

    For an even number of elements this is the upper of the two middle
    values (no averaging), so the result is always an element of `l`.
    """
    if not l:
        return None
    s = sorted(l)
    # floor division: `/` would give a float index under true division
    return s[len(s) // 2]
def query_string(dict):
    """Build a '?k=v&...' query string from the mapping `dict`.

    Keys and values are converted to unicode, utf-8 encoded and
    url-escaped.  Pairs whose value is None are skipped, and so are pairs
    whose conversion raises UnicodeDecodeError.  Returns '' when no pair
    remains.
    """
    pairs = []
    for key, val in dict.items():
        if val is None:
            continue
        try:
            escaped_key = url_escape(unicode(key).encode('utf-8'))
            escaped_val = url_escape(unicode(val).encode('utf-8'))
        except UnicodeDecodeError:
            continue
        pairs.append(escaped_key + '=' + escaped_val)
    if not pairs:
        return ''
    return '?' + '&'.join(pairs)
class UrlParser(object):
    """
    Wrapper for urlparse and urlunparse for making changes to urls.
    All attributes present on the tuple-like object returned by
    urlparse are present on this class and are settable.  The exception
    is netloc, which is instead a read-only property built from hostname
    and port.

    Unlike urlparse, this class allows the query parameters to be
    converted to a dictionary via the query_dict property (and
    correspondingly updated via update_query).  The extension of the
    path can also be set and queried.

    The class also contains reddit-specific functions for setting,
    checking, and getting a path's subreddit.  It also can convert
    paths between in-frame and out of frame cname'd forms.
    """

    __slots__ = ['scheme', 'path', 'params', 'query',
                 'fragment', 'username', 'password', 'hostname',
                 'port', '_url_updates', '_orig_url', '_query_dict']

    valid_schemes = ('http', 'https', 'ftp', 'mailto')
    cname_get = "cnameframe"

    def __init__(self, url):
        u = urlparse(url)
        # copy every urlparse attribute that we have a slot for
        for s in self.__slots__:
            if hasattr(u, s):
                setattr(self, s, getattr(u, s))
        self._url_updates = {}
        self._orig_url = url
        self._query_dict = None

    def update_query(self, **updates):
        """
        Can be used instead of self.query_dict.update() to add/change
        query params in situations where the original contents are not
        required.
        """
        self._url_updates.update(updates)

    @property
    def query_dict(self):
        """
        Parses the `query' attribute of the original urlparse into a
        dictionary whose keys and values have been unquote_plus'd.  The
        dict is built lazily and cached.  Changes made to it are
        reflected in the query string produced by unparse().
        """
        if self._query_dict is None:
            def _split(param):
                # everything after the first '=' belongs to the value
                p = param.split('=')
                return (unquote_plus(p[0]),
                        unquote_plus('='.join(p[1:])))
            self._query_dict = dict(_split(p) for p in self.query.split('&')
                                    if p)
        return self._query_dict

    def path_extension(self):
        """
        Fetches the current extension of the path: the text after the last
        '.' of the final path segment (the whole segment if it has no '.').
        """
        return self.path.split('/')[-1].split('.')[-1]

    def set_extension(self, extension):
        """
        Changes the extension of the path to the provided value (the
        "." should not be included in the extension as a "." is
        provided).  A falsy extension removes the current one.
        Returns self.
        """
        pieces = self.path.split('/')
        dirs = pieces[:-1]
        base = pieces[-1].split('.')
        base = '.'.join(base[:-1] if len(base) > 1 else base)
        if extension:
            base += '.' + extension
        dirs.append(base)
        self.path = '/'.join(dirs)
        return self

    def unparse(self):
        """
        Converts the url back to a string, applying all updates made
        to the fields thereof.

        Note: if a host name has been added and none was present
        before, will enforce scheme -> "http" unless otherwise
        specified.  Double-slashes are removed from the resultant
        path, and the query string is reconstructed only if the
        query_dict has been modified/updated.
        """
        # only parse the query params if there is an update dict
        q = self.query
        if self._url_updates or self._query_dict is not None:
            q = self._query_dict or self.query_dict
            q.update(self._url_updates)
            q = query_string(q).lstrip('?')

        # make sure the port is not doubly specified
        if self.port and ":" in self.hostname:
            self.hostname = self.hostname.split(':')[0]

        # if there is a netloc, there had better be a scheme
        if self.netloc and not self.scheme:
            self.scheme = "http"

        return urlunparse((self.scheme, self.netloc,
                           self.path.replace('//', '/'),
                           self.params, q, self.fragment))

    def path_has_subreddit(self):
        """
        utility method for checking if the path starts with a
        subreddit specifier (namely /r/ or /reddits/).
        """
        return (self.path.startswith('/r/') or
                self.path.startswith('/reddits/'))

    def get_subreddit(self):
        """checks if the current url refers to a subreddit and returns
        that subreddit object.  The cases here are:

          * the hostname is unset or starts with g.domain, in which case
            it looks for /r/XXXX or /reddits.  The default in this case
            is DefaultSR().
          * otherwise the hostname is looked up as a cname of a known
            subreddit.

        On failure to find a subreddit, returns None.
        """
        from pylons import g
        from r2.models import Subreddit, Sub, NotFound, DefaultSR
        try:
            if not self.hostname or self.hostname.startswith(g.domain):
                if self.path.startswith('/r/'):
                    return Subreddit._by_name(self.path.split('/')[2])
                elif self.path.startswith('/reddits/'):
                    return Sub
                else:
                    return DefaultSR()
            elif self.hostname:
                return Subreddit._by_domain(self.hostname)
        except NotFound:
            pass
        return None

    def is_reddit_url(self, subreddit = None):
        """utility method for seeing if the url is associated with
        reddit as we don't necessarily want to mangle non-reddit
        domains

        returns true only if hostname is nonexistent, g.domain or a
        subdomain of it, an authorized cname (or subdomain of one), or the
        provided subreddit's cname (or a subdomain of it).
        """
        from pylons import g
        # Match on label boundaries: a bare endswith() would accept e.g.
        # "evilreddit.com" as a reddit url, which matters for redirects.
        return (not self.hostname or
                is_authorized_cname(self.hostname, [g.domain]) or
                is_authorized_cname(self.hostname, g.authorized_cnames) or
                (subreddit and subreddit.domain and
                 is_authorized_cname(self.hostname, [subreddit.domain])))

    def path_add_subreddit(self, subreddit):
        """
        Adds the subreddit's path to the path if another subreddit's
        prefix is not already present.  Returns self.
        """
        if not self.path_has_subreddit():
            self.path = (subreddit.path + self.path)
        return self

    @property
    def netloc(self):
        """
        Getter method which returns the hostname:port, or empty string
        if no hostname is present.
        """
        if not self.hostname:
            return ""
        elif getattr(self, "port", None):
            return self.hostname + ":" + str(self.port)
        return self.hostname

    def mk_cname(self, require_frame = True, subreddit = None, port = None):
        """
        Converts a ?cnameframe url into the corresponding cnamed
        domain if applicable.  Useful for frame-busting on redirect.
        Returns self.
        """

        # make sure the url is indeed in a frame
        if require_frame and self.cname_get not in self.query_dict:
            return self

        # fetch the subreddit and make sure it has a cname
        subreddit = subreddit or self.get_subreddit()
        if subreddit and subreddit.domain:

            # no guarantee there was a scheme
            self.scheme = self.scheme or "http"

            # update the domain (preserving the port)
            self.hostname = subreddit.domain
            self.port = self.port or port

            # and remove any cnameframe GET parameters
            if self.cname_get in self.query_dict:
                del self._query_dict[self.cname_get]

            # remove the subreddit reference
            self.path = lstrips(self.path, subreddit.path)
            if not self.path.startswith('/'):
                self.path = '/' + self.path

        return self

    def is_in_frame(self):
        """
        Checks if the url is in a frame by determining if
        cls.cname_get is present.
        """
        return self.cname_get in self.query_dict

    def put_in_frame(self):
        """
        Adds the cls.cname_get get parameter to the query string.
        """
        self.update_query(**{self.cname_get:random.random()})

    def __repr__(self):
        return "<URL %s>" % repr(self.unparse())

    def domain_permutations(self, fragments=False, subdomains=True):
        """
        Takes a domain like `www.reddit.com` and returns a set of ways
        that a user might search for it:

          * subdomains=True (default): every suffix of at least two
            labels, i.e. www.reddit.com and reddit.com
          * fragments=True: every individual label, i.e. www, reddit, com
        """
        ret = set()
        if self.hostname:
            r = self.hostname.split('.')

            if subdomains:
                for x in range(len(r)-1):
                    ret.add('.'.join(r[x:len(r)]))

            if fragments:
                for x in r:
                    ret.add(x)

        return ret

    @classmethod
    def base_url(cls, url):
        """Return a normalized form of `url`.

        The scheme and host are lowercased and a leading "www." is dropped
        from hosts with more than two labels.  The fragment is discarded
        unless it is an AJAX-crawling "#!" fragment.
        """
        u = cls(url)

        # strip off any www and lowercase the hostname:
        netloc = u.netloc.lower()
        if len(netloc.split('.')) > 2 and netloc.startswith("www."):
            netloc = netloc[4:]

        # http://code.google.com/web/ajaxcrawling/docs/specification.html
        fragment = u.fragment if u.fragment.startswith("!") else ""

        return urlunparse((u.scheme.lower(), netloc,
                           u.path, u.params, u.query, fragment))
def to_js(content, callback="document.write", escape=True):
    """Wrap `content` as a JavaScript call `callback(content);`.

    With escape=True the content is first passed through string2js.  A
    falsy callback returns the (possibly escaped) content unwrapped.
    """
    if callback:
        opening, closing = callback + "(", ");"
    else:
        opening, closing = '', ''
    if escape:
        content = string2js(content)
    return opening + content + closing
class TransSet(local):
    """Thread-local set of open database transactions.

    Engines passed to add_engine() are ignored until begin() has been
    called.  After that, each engine gets a transaction started on it
    (engine.begin()), and the transaction is tracked.  commit() and
    rollback() finish every tracked transaction and then forget them.
    Transactions still tracked when the object is destroyed are
    committed.
    """
    def __init__(self, items = ()):
        self.set = set(items)
        self.trans = False

    def begin(self):
        self.trans = True

    def add_engine(self, engine):
        if not self.trans:
            return None
        return self.set.add(engine.begin())

    def clear(self):
        return self.set.clear()

    def __iter__(self):
        return iter(self.set)

    def commit(self):
        for transaction in self:
            transaction.commit()
        self.clear()

    def rollback(self):
        for transaction in self:
            transaction.rollback()
        self.clear()

    def __del__(self):
        self.commit()
def pload(fname, default = None):
    """Load a pickled object from file `fname`.

    Returns `default` if the file cannot be opened or read (IOError).
    Other errors raised while unpickling propagate.
    """
    try:
        # binary mode: pickles are byte streams.  `with` closes the file
        # even when unpickling raises (previously it was leaked).
        with open(fname, 'rb') as f:
            return pickle.load(f)
    except IOError:
        return default
def psave(fname, d):
    """Pickle object `d` into file `fname`, overwriting it.

    The file is opened in binary mode (pickles are byte streams) and is
    closed even if pickling fails.
    """
    with open(fname, 'wb') as f:
        pickle.dump(d, f)
def unicode_safe(res):
    """Return a byte-string representation of `res`.

    str(res) is tried first.  If that fails with UnicodeEncodeError (a
    non-ASCII unicode value), the value is utf-8 encoded instead.
    """
    try:
        return str(res)
    except UnicodeEncodeError:
        pass
    try:
        return unicode(res).encode('utf-8')
    except UnicodeEncodeError:
        return res.decode('utf-8').encode('utf-8')
def decompose_fullname(fullname):
    """
    Split a fullname into (type_class, type_id, thing_id):

        decompose_fullname("t3_e4fa") -> (Thing, 3, 658918)

    A leading 't' selects Thing and a leading 'r' selects Relation.  Both
    ids are base-36.  Raises ValueError for any other prefix or for a
    malformed name.
    """
    prefix = fullname[0]
    if prefix not in ('t', 'r'):
        # previously fell through to an UnboundLocalError on type_class
        raise ValueError("Unknown fullname prefix in %r" % (fullname,))

    from r2.lib.db.thing import Thing, Relation
    type_class = Thing if prefix == 't' else Relation

    type_id36, thing_id36 = fullname[1:].split('_')
    type_id = int(type_id36, 36)
    thing_id = int(thing_id36, 36)
    return (type_class, type_id, thing_id)
def cols(lst, ncols):
    """Lay `lst` out in `ncols` columns, filling each column top to
    bottom, and return the rows as tuples.  For example,
    cols('abcdef', 2) returns [('a', 'd'), ('b', 'e'), ('c', 'f')].

    Any sequence is accepted.  When len(lst) is not a multiple of ncols
    the trailing rows are shorter.  None is used internally as padding,
    so None entries are dropped from the result.  An empty input gives [].
    """
    items = list(lst)
    if not items:
        # nrows would be 0 and range() with a zero step raises ValueError
        return []
    nrows = int(math.ceil(1. * len(items) / ncols))
    items.extend([None] * (nrows * ncols - len(items)))
    columns = [items[i:i + nrows] for i in range(0, nrows * ncols, nrows)]
    return [tuple(x for x in row if x is not None)
            for row in zip(*columns)]
def fetch_things(t_class,since,until,batch_fn=None,
                 *query_params, **extra_query_dict):
    """
    Simple utility function to fetch all Things of class t_class,
    spam or not, created from 'since' (inclusive) to 'until' (exclusive),
    in ascending date order.  NOTE(review): deleted Things are presumably
    excluded by _query's defaults; nothing here filters them explicitly.

    Things are fetched 100 at a time by default (override with limit=).
    batch_fn, if given, transforms each fetched batch before its items
    are yielded.  Extra positional query_params and keyword query options
    are passed through to t_class._query.
    """
    from r2.lib.db.operators import asc
    if not batch_fn:
        batch_fn = lambda x: x

    query_params = ([t_class.c._date >= since,
                     t_class.c._date < until,
                     t_class.c._spam == (True,False)]
                    + list(query_params))
    query_dict = {'sort': asc('_date'),
                  'limit': 100,
                  'data': True}
    query_dict.update(extra_query_dict)

    q = t_class._query(*query_params,
                       **query_dict)

    orig_rules = deepcopy(q._rules)

    things = list(q)
    while things:
        # Page from the last *fetched* Thing.  batch_fn may transform or
        # filter the batch, so its output is not a safe paging anchor
        # (and an empty output previously left `t` undefined or stale).
        after = things[-1]
        for t in batch_fn(things):
            yield t
        q._rules = deepcopy(orig_rules)
        q._after(after)
        things = list(q)
def fetch_things2(query, chunk_size = 100, batch_fn = None, chunks = False):
    """Run `query` chunk_size rows at a time until it is exhausted.

    Each new page starts after the last row of the previous page (via
    query._after), with the query's original rules restored first.
    batch_fn, if given, transforms each chunk before it is handed out.
    With chunks=True the whole (transformed) chunks are yielded;
    otherwise the individual items are.
    """
    orig_rules = deepcopy(query._rules)
    query._limit = chunk_size
    batch = list(query)

    while batch:
        # a short page means there is nothing left to ask for
        exhausted = len(batch) < chunk_size
        last = batch[-1]

        if batch_fn:
            batch = batch_fn(batch)

        if chunks:
            yield batch
        else:
            for item in batch:
                yield item

        if exhausted:
            break

        query._rules = deepcopy(orig_rules)
        query._after(last)
        batch = list(query)
def fix_if_broken(thing, delete = True, fudge_links = False):
    """Check that `thing` has all of its class's `_essentials` attributes
    and repair or delete it if not.

    Raises TypeError unless thing is a Link, Comment, Subreddit or Message.
    For each essential attribute:
      * if it is missing, the thing is reloaded once with
        _load(check_essentials=False) and the attribute is read again;
      * if it is still missing and delete is False, the AttributeError is
        re-raised;
      * otherwise the thing is (for Links with fudge_links) patched with
        placeholder values, marked _deleted if not already, and committed;
      * without fudge_links, processing stops after the first broken
        attribute.
    """
    from r2.models import Link, Comment, Subreddit, Message

    # the minimum set of attributes that are required
    attrs = dict((cls, cls._essentials)
                 for cls
                 in (Link, Comment, Subreddit, Message))

    if thing.__class__ not in attrs:
        raise TypeError

    # reload from the database at most once per call
    tried_loading = False
    for attr in attrs[thing.__class__]:
        try:
            # try to retrieve the attribute
            getattr(thing, attr)
        except AttributeError:
            # that failed; let's explicitly load it and try again

            if not tried_loading:
                tried_loading = True
                thing._load(check_essentials=False)

            try:
                getattr(thing, attr)
            except AttributeError:
                if not delete:
                    raise
                if isinstance(thing, Link) and fudge_links:
                    # NOTE(review): hard-coded ids; presumably a placeholder
                    # subreddit (6) and account (8244672) - confirm
                    if attr == "sr_id":
                        thing.sr_id = 6
                        print "Fudging %s.sr_id to %d" % (thing._fullname,
                                                          thing.sr_id)
                    elif attr == "author_id":
                        thing.author_id = 8244672
                        print "Fudging %s.author_id to %d" % (thing._fullname,
                                                              thing.author_id)
                    else:
                        print "Got weird attr %s; can't fudge" % attr

                if not thing._deleted:
                    print "%s is missing %r, deleting" % (thing._fullname, attr)
                    thing._deleted = True

                thing._commit()

                # when fudging, keep going so every missing attr is patched
                if not fudge_links:
                    break
def find_recent_broken_things(from_time = None, to_time = None,
                              delete = False):
    """
    Occasionally (usually during app-server crashes), Things are only
    partially written out to the database.  A Thing missing data
    attributes breaks the contract for its class, which often breaks
    various pages.

    This function runs fix_if_broken() over every Link and Comment created
    between from_time (default: one hour ago) and to_time (default: now,
    in g.tz), newest first.
    """
    from r2.models import Link, Comment
    from r2.lib.db.operators import desc
    from pylons import g

    if not from_time:
        from_time = timeago('1 hour')
    if not to_time:
        to_time = datetime.now(g.tz)

    for thing_class in (Link, Comment):
        recent = thing_class._query(thing_class.c._date > from_time,
                                    thing_class.c._date < to_time,
                                    data=True,
                                    sort=desc('_date'))
        for thing in fetch_things2(recent):
            fix_if_broken(thing, delete = delete)
def timeit(func):
    "Run some function, and return (RunTimeInSeconds,Result)"
    # The module imports only `sleep` from time, so `time.time` was an
    # unresolved name here; import the module locally.
    import time
    before = time.time()
    res = func()
    return (time.time() - before, res)
def lineno():
    """Print the current time and the caller's line number, and return
    that line number (previously the docstring promised a return value
    but None was returned)."""
    import inspect
    caller_lineno = inspect.currentframe().f_back.f_lineno
    print("%s\t%s" % (datetime.now(), caller_lineno))
    return caller_lineno
def IteratorFilter(iterator, fn):
    """Lazily yield the elements of `iterator` for which fn(element) is
    true."""
    for candidate in iterator:
        if not fn(candidate):
            continue
        yield candidate
def UniqueIterator(iterator, key = lambda x: x):
    """
    Lazily yield only the first element of `iterator` for each distinct
    key(element).  Keys must be hashable.
    """
    seen_keys = set()
    for element in iterator:
        k = key(element)
        if k in seen_keys:
            continue
        seen_keys.add(k)
        yield element
# NOTE(review): the hash helpers below are permissive placeholders - every
# hash is just the user's name and every check passes.  Presumably a real
# implementation replaces them elsewhere; confirm before relying on them
# for request or vote validation.
def modhash(user, rand = None, test = False):
    """Placeholder modhash: the user's name."""
    name = user.name
    return name

def valid_hash(user, hash):
    """Placeholder check: accepts any hash."""
    return True

def check_cheating(loc):
    """Placeholder: performs no check."""
    return None

def vote_hash(user, thing, note='valid'):
    """Placeholder vote hash: the user's name."""
    name = user.name
    return name

def valid_vote_hash(hash, user, thing):
    """Placeholder check: accepts any vote hash."""
    return True
def safe_eval_str(unsafe_str):
    """Replace the literal escape sequences \\x3d and \\x26 with '=' and
    '&' respectively (in that order)."""
    for escaped, plain in (('\\x3d', '='), ('\\x26', '&')):
        unsafe_str = unsafe_str.replace(escaped, plain)
    return unsafe_str
rx_whitespace = re.compile(r'\s+', re.UNICODE)
rx_notsafe = re.compile(r'\W+', re.UNICODE)
rx_underscore = re.compile(r'_+', re.UNICODE)
def title_to_url(title, max_length = 50):
    """Turn a title into a lowercase, underscore-separated slug for URLs.

    The slug is at most max_length characters, cut back to a word
    boundary when possible.  "_" is returned if nothing usable remains.
    """
    slug = _force_unicode(title)           # work on unicode text
    slug = rx_whitespace.sub('_', slug)    # whitespace runs -> underscore
    slug = rx_notsafe.sub('', slug)        # drop non-word characters
    slug = rx_underscore.sub('_', slug)    # collapse repeated underscores
    slug = slug.strip('_').lower()         # trim edges and lowercase
    if len(slug) > max_length:
        slug = slug[:max_length]
        # cut back to the last underscore, if it is past index 0
        cut = slug.rfind('_')
        if cut > 0:
            slug = slug[:cut]
    return slug or "_"
def dbg(s):
    """Write `s` followed by a newline to stderr."""
    import sys
    line = '%s\n' % (s,)
    sys.stderr.write(line)
def trace(fn):
    """Decorator that logs every call of `fn` to stderr via dbg().

    The log includes the function, the positional and keyword arguments,
    and the return value.  The wrapper keeps fn's __name__ and __doc__.
    """
    import functools

    def new_fn(*a, **kw):
        ret = fn(*a, **kw)
        dbg("Fn: %s; a=%s; kw=%s\nRet: %s"
            % (fn, a, kw, ret))
        return ret

    try:
        new_fn = functools.wraps(fn)(new_fn)
    except AttributeError:
        # e.g. functools.partial objects lack __name__ under Python 2
        pass
    return new_fn
def common_subdomain(domain1, domain2):
    """Return the longest domain suffix shared by both hosts.

    Ports are ignored.  A shared suffix must contain at least two labels
    (a bare TLD like "com" is never returned).  It must also match on a
    label boundary, so "reddit.com" and "evilreddit.com" share nothing.
    Returns "" when either argument is empty or nothing is shared.
    """
    if not domain1 or not domain2:
        return ""
    domain1 = domain1.split(":")[0]
    domain2 = domain2.split(":")[0]
    if len(domain1) > len(domain2):
        domain1, domain2 = domain2, domain1

    if domain1 == domain2:
        return domain1

    dom = domain1.split(".")
    for i in range(len(dom), 1, -1):
        d = '.'.join(dom[-i:])
        # require a label boundary: a bare endswith() would let
        # "evilreddit.com" share "reddit.com"
        if domain2 == d or domain2.endswith('.' + d):
            return d
    return ""
def interleave_lists(*args):
    """Yield the elements of the given sequences round-robin.

    First the first element of each sequence, then the second of each,
    and so on; sequences that have run out are skipped.  Yields nothing
    when called with no arguments.
    """
    # `or [0]` avoids max() of an empty sequence when no args are given
    max_len = max([len(x) for x in args] or [0])
    for i in range(max_len):
        for a in args:
            if i < len(a):
                yield a[i]
def link_from_url(path, filter_spam = False, multiple = True):
    """Find the Link(s) submitted with url `path` on the current site
    (c.site) and pass them through filter_links().

    Returns None for an empty path.  When no link exists, returns [] if
    `multiple` is true, else None.
    """
    from pylons import c
    from r2.models import IDBuilder, Link, Subreddit, NotFound

    if not path:
        return None

    try:
        found = Link._by_url(path, c.site)
    except NotFound:
        if multiple:
            return []
        return None

    return filter_links(tup(found), filter_spam = filter_spam,
                        multiple = multiple)
def filter_links(links, filter_spam = False, multiple = True):
    """Return the links the current user may see, best first.

    Links are run through IDBuilder to drop any the user isn't allowed to
    see.  With filter_spam, spam is dropped as long as at least one
    non-spam link remains.  Links in subreddits the user is subscribed to
    come first; within each group, links are ordered by descending _hot.

    Returns the ordered list when `multiple` is true, otherwise only the
    first link.  Returns None when no visible link remains.
    """
    # run the list through a builder to remove any that the user
    # isn't allowed to see
    from pylons import c
    from r2.models import IDBuilder, Subreddit
    links = IDBuilder([link._fullname for link in links],
                      skip = False).get_items()[0]
    if not links:
        return None

    if filter_spam:
        # first, try to remove any spam
        links_nonspam = [ link for link in links
                          if not link._spam ]
        if links_nonspam:
            links = links_nonspam

    # if it occurs in one or more of their subscriptions, show them
    # that one first; among those, show them the hottest one
    subs = set(Subreddit.user_subreddits(c.user, limit = None))

    def sort_key(link):
        # False sorts before True, so subscribed subreddits come first;
        # negated _hot gives descending hotness.  Equivalent to the old
        # cmp-based ordering, but also valid where sorted() has no cmp=.
        return (link.sr_id not in subs, -link._hot)

    links = sorted(links, key = sort_key)

    return links if multiple else links[0]
def link_duplicates(article):
    """Return the other Links submitted with the same url as `article`.

    Link._by_url is called with None as the site.  Self posts are never
    looked up and always give [].
    """
    from r2.models import Link, NotFound

    # don't bother looking it up if the link doesn't have a URL anyway
    if getattr(article, 'is_self', False):
        return []

    try:
        same_url = tup(Link._by_url(article.url, None))
    except NotFound:
        return []

    return [link for link in same_url
            if link._fullname != article._fullname]
class TimeoutFunctionException(Exception):
    """Raised by TimeoutFunction when the wrapped call runs too long."""
    pass

class TimeoutFunction:
    """Force an operation to time out after N seconds.

    Works with POSIX signals (SIGALRM), so it's not safe to use in a
    multi-threaded environment and may only be called from the main
    thread.  `timeout` is handed to signal.alarm(), so it must be a whole
    number of seconds; 0 means no alarm is scheduled.
    """
    def __init__(self, function, timeout):
        self.timeout = timeout
        self.function = function

    def handle_timeout(self, signum, frame):
        raise TimeoutFunctionException()

    def __call__(self, *args, **kwargs):
        # can only be called from the main thread
        old = signal.signal(signal.SIGALRM, self.handle_timeout)
        signal.alarm(self.timeout)
        try:
            # keyword arguments are now forwarded as well
            return self.function(*args, **kwargs)
        finally:
            # cancel any pending alarm and restore the previous handler
            signal.alarm(0)
            signal.signal(signal.SIGALRM, old)
def make_offset_date(start_date, interval, future = True,
                     business_days = False):
    """
    Generate a date `interval` days in the future (or the past, with
    future=False) from start_date.  Weekends can optionally be given no
    weight in the calculation by setting business_days to True.

    With business_days the date is moved one day at a time and only
    Monday-Friday days are counted.  For example:
      * 5 business days after a Monday is the next Monday;
      * 10 business days after a Monday is the Monday two weeks later;
      * 1 business day after a Saturday is the following Monday.
    A negative interval moves in the opposite direction.

    Returns start_date unchanged when interval is None.
    """
    if interval is None:
        return start_date

    interval = int(interval)

    if not business_days:
        if future:
            return start_date + timedelta(interval)
        return start_date - timedelta(interval)

    # The old closed form added two days per *calendar* week of
    # `interval`, which under-counts weekends for longer intervals
    # (e.g. Monday + 10 business days landed on a Saturday).  Stepping one
    # day at a time is simple and exact.
    if interval < 0:
        future = not future
        interval = -interval
    step = timedelta(1) if future else timedelta(-1)

    current = start_date
    remaining = interval
    while remaining > 0:
        current += step
        if current.weekday() < 5:  # Monday=0 ... Friday=4
            remaining -= 1
    return current
def to_csv(table):
    """Serialize `table` (an iterable of rows, each an iterable of unicode
    strings) as CSV text, with rows separated by "\\n".

    As RFC 4180 requires, a field is wrapped in double quotes (with
    embedded quotes doubled) when it contains a comma, a double quote or
    a line break.
    """
    # commas, quotes and linebreaks must result in a quoted string
    def quote_commas(x):
        if ',' in x or '"' in x or '\n' in x or '\r' in x:
            return u'"%s"' % x.replace('"', '""')
        return x

    return u"\n".join(u','.join(quote_commas(y) for y in x)
                      for x in table)
def in_chunks(it, size=25):
    """Lazily split iterable `it` into lists of `size` elements; the last
    list may be shorter.  Nothing is yielded for an empty iterable."""
    chunk = []
    for element in it:
        chunk.append(element)
        if len(chunk) >= size:
            yield chunk
            chunk = []
    if chunk:
        yield chunk
r_subnet = re.compile(r"\A(\d+\.\d+)\.\d+\.\d+\Z")
def ip_and_slash16(req):
    """Return (ip, "A.B.x.x") for the dotted-quad IPv4 address req.ip.

    Raises ValueError when req.ip is None or does not consist of four
    dot-separated runs of digits.
    """
    ip = req.ip
    if ip is None:
        raise ValueError("request.ip is None")
    match = r_subnet.match(ip)
    if not match:
        raise ValueError("Couldn't parse IP %s" % ip)
    return (ip, '%s.x.x' % match.group(1))
def spaceout(items, targetseconds,
             minsleep = 0, die = False,
             estimate = None):
    """Yield each element of `items`, sleeping between elements so the
    whole iteration takes roughly `targetseconds`.

    After each element the generator sleeps for targetseconds / estimate
    minus the mean time the consumer spent on the last (up to 10)
    elements, but never less than `minsleep`.  `estimate` defaults to
    len(items).  If that isn't available, only `minsleep` is enforced.
    With die=True, iteration stops once targetseconds have elapsed since
    the start.
    """
    targetseconds = float(targetseconds)
    # rolling window of recent per-element consumer times in seconds,
    # seeded with a 1s guess
    state = [1.0]

    if estimate is None:
        try:
            estimate = len(items)
        except TypeError:
            # if we can't come up with an estimate, the best we can do
            # is just enforce the minimum sleep time (and the max
            # targetseconds if die==True)
            pass

    mean = lambda lst: sum(float(x) for x in lst)/float(len(lst))
    beginning = datetime.now()

    for item in items:
        start = datetime.now()
        yield item
        end = datetime.now()

        took_delta = end - start
        # a day is 24*60*60 seconds (this previously used 60*24)
        took = (took_delta.days * 24 * 60 * 60
                + took_delta.seconds
                + took_delta.microseconds/1000000.0)
        state.append(took)
        if len(state) > 10:
            del state[0]

        if die and end > beginning + timedelta(seconds=targetseconds):
            # we ran out of time, ignore the rest of the iterator
            break

        if estimate is None:
            if minsleep:
                # we have no idea how many items we're going to get
                sleep(minsleep)
        else:
            sleeptime = max((targetseconds / estimate) - mean(state),
                            minsleep)
            if sleeptime > 0:
                sleep(sleeptime)
def progress(it, verbosity=100, key=repr, estimate=None, persec=True):
    """An iterator that yields everything from `it', but writes progress
       information to stderr along the way.

       After every `verbosity` items it writes one line with:
         * the count (and, when an estimate is known, the percentage,
           remaining time and completion time);
         * the per-second rate for that chunk (if persec);
         * the elapsed time;
         * key(last item) (if key).
       A summary line is written when `it` is exhausted.  `estimate`
       defaults to len(it) when available.

       NOTE(review): verbosity must be >= 1; with 0 the end-of-input
       check below never triggers."""
    from itertools import islice
    from datetime import datetime
    import sys

    now = start = datetime.now()
    elapsed = start - start  # a zero timedelta

    # try to guess at the estimate if we can
    if estimate is None:
        try:
            estimate = len(it)
        except:
            pass

    def timedelta_to_seconds(td):
        return td.days * (24*60*60) + td.seconds + (float(td.microseconds) / 1000000)
    # render a timedelta like '1d2h3m4s' (sub-second parts dropped; '0s'
    # if nothing is left), with a leading '-' for negative values
    def format_timedelta(td, sep=''):
        ret = []
        s = timedelta_to_seconds(td)
        if s < 0:
            neg = True
            s *= -1
        else:
            neg = False

        if s >= (24*60*60):
            days = int(s//(24*60*60))
            ret.append('%dd' % days)
            s -= days*(24*60*60)
        if s >= 60*60:
            hours = int(s//(60*60))
            ret.append('%dh' % hours)
            s -= hours*(60*60)
        if s >= 60:
            minutes = int(s//60)
            ret.append('%dm' % minutes)
            s -= minutes*60
        if s >= 1:
            seconds = int(s)
            ret.append('%ds' % seconds)
            s -= seconds

        if not ret:
            return '0s'

        return ('-' if neg else '') + sep.join(ret)
    # time of day only, or date + time when show_date is set
    def format_datetime(dt, show_date=False):
        if show_date:
            return dt.strftime('%Y-%m-%d %H:%M')
        else:
            return dt.strftime('%H:%M:%S')
    def deq(dt1, dt2):
        "Indicates whether the two datetimes' dates describe the same (day,month,year)"
        d1, d2 = dt1.date(), dt2.date()
        return (    d1.day == d2.day
                and d1.month == d2.month
                and d1.year == d2.year)

    sys.stderr.write('Starting at %s\n' % (start,))

    # we're going to islice it so we need to start an iterator
    it = iter(it)

    seen = 0
    while True:
        this_chunk = 0
        thischunk_started = datetime.now()

        # the simple bit: just iterate and yield
        for item in islice(it, verbosity):
            this_chunk += 1
            seen += 1
            yield item

        if this_chunk < verbosity:
            # we're done, the iterator is empty
            break

        now = datetime.now()
        elapsed = now - start
        thischunk_seconds = timedelta_to_seconds(now - thischunk_started)

        if estimate:
            # the estimate is based on the total number of items that
            # we've processed in the total amount of time that's
            # passed, so it should smooth over momentary spikes in
            # speed (but will take a while to adjust to long-term
            # changes in speed)
            remaining = ((elapsed/seen)*estimate)-elapsed
            completion = now + remaining
            count_str = ('%d/%d %.2f%%'
                         % (seen, estimate, float(seen)/estimate*100))
            # include the date only when completion falls on another day
            completion_str = format_datetime(completion, not deq(completion,now))
            estimate_str = (' (%s remaining; completion %s)'
                            % (format_timedelta(remaining),
                               completion_str))
        else:
            count_str = '%d' % seen
            estimate_str = ''

        # `item` is the last element yielded in this chunk
        if key:
            key_str = ': %s' % key(item)
        else:
            key_str = ''

        # unlike the estimate, the persec count is the number per
        # second for *this* batch only, without smoothing
        if persec and thischunk_seconds > 0:
            persec_str = ' (%.1f/s)' % (float(this_chunk)/thischunk_seconds,)
        else:
            persec_str = ''

        sys.stderr.write('%s%s, %s%s%s\n'
                         % (count_str, persec_str,
                            format_timedelta(elapsed), estimate_str, key_str))

    # final summary: overall rate over the whole run
    now = datetime.now()
    elapsed = now - start
    elapsed_seconds = timedelta_to_seconds(elapsed)
    if persec and seen > 0 and elapsed_seconds > 0:
        persec_str = ' (@%.1f/sec)' % (float(seen)/elapsed_seconds)
    else:
        persec_str = ''
    sys.stderr.write('Processed %d%s items in %s..%s (%s)\n'
                     % (seen,
                        persec_str,
                        format_datetime(start, not deq(start, now)),
                        format_datetime(now, not deq(start, now)),
                        format_timedelta(elapsed)))
class Hell(Exception):
    """The error Bomb raises when it is touched; str() gives "boom!".

    It must derive from Exception: raising an instance of a plain
    `object` subclass is itself a TypeError, so Bomb used to fail with
    TypeError instead of Hell.
    """
    def __str__(self):
        return "boom!"
class Bomb(object):
    """An object that raises Hell as soon as anyone touches it.

    Reading a missing attribute, setting any attribute, or calling repr()
    on an instance raises Hell.  The handlers are classmethods, so they
    receive the class rather than the instance.
    """
    def _detonate(cls, *ignored):
        raise Hell()

    __getattr__ = classmethod(_detonate)
    __setattr__ = classmethod(_detonate)
    __repr__ = classmethod(_detonate)

    # keep the helper out of the class namespace
    del _detonate
def strordict_fullname(item, key='fullname'):
    """Sometimes we migrate AMQP queues from simple strings to pickled
    dictionaries.  During the migratory period there may be items of both
    types in the queue, so this function tries to detect which kind the
    item is.  It shouldn't really be used on a given queue for more than
    a few hours or days.

    Returns a dict whose `key` entry is a str.  Raises ValueError when the
    item can't be interpreted that way.
    """
    try:
        # SECURITY NOTE(review): pickle.loads can execute arbitrary code;
        # this is only acceptable while every producer on the queue is
        # trusted.  Do not reuse on external input.
        d = pickle.loads(item)
    except Exception:
        # not a pickle: the raw item is the value itself.  (The previous
        # bare except also swallowed KeyboardInterrupt/SystemExit.)
        d = {key: item}

    if (not isinstance(d, dict)
        or key not in d
        or not isinstance(d[key], str)):
        raise ValueError('Error trying to migrate %r (%r)'
                         % (item, d))

    return d
def thread_dump(*a):
    """Write the current stack of every running thread to stderr.

    Any positional arguments are accepted and ignored (presumably so it
    can be installed directly as a signal handler - confirm).
    """
    import sys, traceback
    from datetime import datetime

    banner = '*' * 15
    sys.stderr.write('%s Thread Dump @%s %s\n'
                     % (banner, datetime.now(), banner))

    frames = sys._current_frames()
    for thread_id, stack in frames.items():
        sys.stderr.write('\t-- Thread ID: %s--\n' % (thread_id,))
        for filename, line_no, fnname, line in traceback.extract_stack(stack):
            sys.stderr.write('\t\t%s(%d): %s\n' % (filename, line_no, fnname))
            sys.stderr.write('\t\t\t%s\n' % (line,))
def constant_time_compare(actual, expected):
    """
    Return True if the two strings are equal, False otherwise.

    The running time depends on the number of characters in `actual`,
    not on how many characters match, so timing does not reveal the
    length of the matching prefix.
    """
    n_actual = len(actual)
    n_expected = len(expected)
    # any length mismatch already makes the result non-zero
    diff = n_actual ^ n_expected
    if n_expected > 0:
        for idx, ch in enumerate(actual):
            # wrap around `expected` so every character of `actual` is
            # compared against something
            diff |= ord(ch) ^ ord(expected[idx % n_expected])
    return diff == 0