diff --git a/r2/r2/models/traffic.py b/r2/r2/models/traffic.py index 4fd94646c..c8ff2f7cc 100644 --- a/r2/r2/models/traffic.py +++ b/r2/r2/models/traffic.py @@ -39,8 +39,14 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import scoped_session, sessionmaker from sqlalchemy.schema import Column from sqlalchemy.sql.expression import desc, distinct -from sqlalchemy.sql.functions import sum -from sqlalchemy.types import DateTime, Integer, String, BigInteger +from sqlalchemy.sql.functions import sum as sa_sum +from sqlalchemy.types import ( + BigInteger, + DateTime, + Integer, + String, + TypeDecorator, +) from r2.lib.memoize import memoize from r2.lib.utils import timedelta_by_name, tup @@ -252,6 +258,32 @@ def top_last_month(cls, key): for r in q.all()] +class CoerceToLong(TypeDecorator): + # source: + # https://groups.google.com/forum/?fromgroups=#!topic/sqlalchemy/3fipkThttQA + + impl = BigInteger + + def process_result_value(self, value, dialect): + if value is not None: + value = long(value) + return value + + +def sum(column): + """Wrapper around sqlalchemy.sql.functions.sum to handle BigInteger. + + sqlalchemy returns a Decimal for sum over BigInteger values. Detect the + column type and coerce to long if it's a BigInteger. + + """ + + if isinstance(column.property.columns[0].type, BigInteger): + return sa_sum(column, type_=CoerceToLong) + else: + return sa_sum(column) + + def totals(cls, interval): """Aggregate sitewide totals for self-serve promotion traffic.