mirror of
https://github.com/stake-house/poap-reddit-bot.git
synced 2026-01-10 06:27:56 -05:00
major refactoring
This commit is contained in:
16
app.py
16
app.py
@@ -11,16 +11,16 @@ import pandas as pd
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from models.settings import RedditSettings, DBSettings, FastAPISettings
|
||||
from poapbot.models.settings import RedditSettings, DBSettings, FastAPISettings
|
||||
|
||||
SETTINGS = yaml.safe_load(open('settings.yaml', 'r'))
|
||||
REDDIT_SETTINGS = RedditSettings.parse_obj(SETTINGS['reddit'])
|
||||
DB_SETTINGS = DBSettings.parse_obj(SETTINGS['db'])
|
||||
API_SETTINGS = FastAPISettings.parse_obj(SETTINGS['fastapi'])
|
||||
|
||||
from scraper import RedditScraper
|
||||
from bot import RedditBot
|
||||
from models import metadata, database, Event, Attendee, Claim
|
||||
from poapbot.scraper import RedditScraper
|
||||
from poapbot.bot import RedditBot
|
||||
from poapbot.models import metadata, database, Event, Attendee, Claim
|
||||
|
||||
engine = sqlalchemy.create_engine(DB_SETTINGS.url)
|
||||
metadata.create_all(engine)
|
||||
@@ -67,12 +67,12 @@ async def shutdown():
|
||||
tags=['admin'],
|
||||
response_model=Event
|
||||
)
|
||||
async def create_event(request: Request, id: str, expiry_date: datetime, description: Optional[str] = ""):
|
||||
async def create_event(request: Request, id: str, name: str, code: str, expiry_date: datetime, description: Optional[str] = ""):
|
||||
existing_event = await Event.objects.get_or_none(pk=id)
|
||||
if existing_event:
|
||||
raise HTTPException(status_code=409, detail=f'Event with id "{id}" already exists')
|
||||
else:
|
||||
event = Event(id=id, description=description, expiry_date=expiry_date)
|
||||
event = Event(id=id, name=name, code=code, description=description, expiry_date=expiry_date)
|
||||
await event.save()
|
||||
return event
|
||||
|
||||
@@ -122,8 +122,8 @@ async def upload_claims(request: Request, event_id: str, file: UploadFile = File
|
||||
claim_map = df.set_index('username')['link'].to_dict()
|
||||
|
||||
existing_attendees = await Attendee.objects.filter(id__in=usernames).all()
|
||||
existing_claims = await Claim.objects.filter(ormar.and_(event__id__exact=event_id, attendee__id__in=usernames)).all()
|
||||
existing_claim_usernames = [c.attendee.id for c in existing_claims]
|
||||
existing_claims = await Claim.objects.filter(ormar.and_(event__id__exact=event_id, attendee__username__in=usernames)).all()
|
||||
existing_claim_usernames = [c.attendee.username for c in existing_claims]
|
||||
|
||||
new_attendees = [Attendee(id=username) for username in usernames if username not in [p.id for p in existing_attendees]]
|
||||
await Attendee.objects.bulk_create(new_attendees)
|
||||
|
||||
@@ -1,54 +0,0 @@
|
||||
from asyncpraw import Reddit
|
||||
from asyncpraw.models import Message
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
import logging
|
||||
|
||||
from models import Claim, RequestMessage
|
||||
|
||||
class RedditBot:
|
||||
|
||||
def __init__(self, client: Reddit):
|
||||
self.client = client
|
||||
|
||||
async def message_handler(self, message: Message):
|
||||
request_message = await RequestMessage.objects.get_or_none(id=message.id)
|
||||
if request_message:
|
||||
return
|
||||
else:
|
||||
request_message = RequestMessage(
|
||||
id=message.id,
|
||||
username=message.author.name,
|
||||
created=message.created_utc,
|
||||
subject=message.subject,
|
||||
body=message.body
|
||||
)
|
||||
await request_message.save()
|
||||
|
||||
if message.body and 'ping' in message.body.lower():
|
||||
await message.reply('pong')
|
||||
elif message.body:
|
||||
code = message.body.split(' ')[0].lower()
|
||||
claim = await Claim.objects.filter(event__id__exact=code, attendee__id__exact=message.author.name.lower()).get_or_none()
|
||||
if claim:
|
||||
if claim.event.expiry_date < datetime.utcnow():
|
||||
await message.reply(f'Sorry, your claim for {claim.event.id} has expired')
|
||||
else:
|
||||
await message.reply(f'Your claim link for {claim.event.id} is {claim.link}')
|
||||
claim.notified = True
|
||||
await claim.update()
|
||||
else:
|
||||
await message.reply(f'Sorry, you do not have a claim for {code}')
|
||||
await message.mark_read()
|
||||
|
||||
async def run(self):
|
||||
while True:
|
||||
try:
|
||||
async for item in self.client.inbox.stream():
|
||||
if isinstance(item, Message):
|
||||
asyncio.create_task(self.message_handler(item))
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except:
|
||||
logging.error('Encountered error in run loop', exc_info=True)
|
||||
await asyncio.sleep(1)
|
||||
77
poapbot/bot/__init__.py
Normal file
77
poapbot/bot/__init__.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from asyncpraw import Reddit
|
||||
from asyncpraw.models import Message
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
import logging
|
||||
import ormar
|
||||
|
||||
from ..models import database, Event, Claim, Attendee, RequestMessage
|
||||
|
||||
logger = logging.getLogger('redditbot')
|
||||
|
||||
class RedditBot:
|
||||
|
||||
def __init__(self, client: Reddit):
|
||||
self.client = client
|
||||
|
||||
async def message_handler(self, message: Message):
|
||||
username = message.author.name.lower() if message.author else None
|
||||
code = message.body.split(' ')[0].lower()
|
||||
|
||||
if code == 'ping':
|
||||
await message.reply('pong')
|
||||
await message.mark_read()
|
||||
logger.info('Received ping, sending pong')
|
||||
return
|
||||
|
||||
request_message = await RequestMessage.objects.get_or_none(id=message.id)
|
||||
if request_message:
|
||||
logger.debug(f'Request message {request_message.id} has already been processed, skipping')
|
||||
await message.mark_read()
|
||||
return
|
||||
else:
|
||||
request_message = RequestMessage(
|
||||
id=message.id,
|
||||
username=message.author.name,
|
||||
created=message.created_utc,
|
||||
subject=message.subject,
|
||||
body=message.body
|
||||
)
|
||||
await request_message.save()
|
||||
|
||||
event = await Event.objects.get_or_none(code=code)
|
||||
if event:
|
||||
expired = event.expired()
|
||||
claim = await Claim.objects.filter(ormar.and_(attendee__username__exact=username, event__id__exact=event.id)).get_or_none()
|
||||
if claim:
|
||||
comment = await message.reply(f'Your claim link for {claim.event.name} is {claim.link}')
|
||||
elif not expired:
|
||||
claim = await Claim.objects.filter(ormar.and_(attendee__id__isnull=True, reserved__exact=False, event__id__exact=event.id)).get_or_none()
|
||||
if claim:
|
||||
async with database.transaction():
|
||||
attendee = await Attendee.objects.get_or_create(username=username)
|
||||
await event.attendees.add(attendee)
|
||||
claim.attendee = attendee
|
||||
claim.reserved = True
|
||||
await claim.update()
|
||||
comment = await message.reply(f'Your claim link for {claim.event.name} is {claim.link}')
|
||||
else:
|
||||
comment = await message.reply(f'Sorry, there are no more claims available for {event.name}')
|
||||
elif expired:
|
||||
comment = await message.reply(f'Sorry, event {event.name} has expired')
|
||||
else:
|
||||
comment = await message.reply(f'Invalid event code: {code}')
|
||||
|
||||
await message.mark_read()
|
||||
|
||||
async def run(self):
|
||||
while True:
|
||||
try:
|
||||
async for item in self.client.inbox.stream():
|
||||
if isinstance(item, Message):
|
||||
await self.message_handler(item)
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except:
|
||||
logging.error('Encountered error in run loop', exc_info=True)
|
||||
await asyncio.sleep(1)
|
||||
@@ -9,8 +9,8 @@ class Attendee(ormar.Model):
|
||||
|
||||
class Meta(BaseMeta):
|
||||
tablename = "attendees"
|
||||
constraints = [ormar.UniqueColumns('username')]
|
||||
|
||||
id: str = ormar.String(primary_key=True, max_length=100)
|
||||
id: str = ormar.Integer(primary_key=True)
|
||||
username: str = ormar.String(max_length=100)
|
||||
channel: str = ormar.String(max_length=100, choices=['DISCORD','REDDIT'])
|
||||
events: Optional[List[Event]] = ormar.ManyToMany(Event)
|
||||
@@ -5,7 +5,6 @@ from typing import List, Optional
|
||||
from . import BaseMeta
|
||||
from .event import Event
|
||||
from .attendee import Attendee
|
||||
from .message import RequestMessage
|
||||
|
||||
class Claim(ormar.Model):
|
||||
|
||||
@@ -14,7 +13,7 @@ class Claim(ormar.Model):
|
||||
constraints = [ormar.UniqueColumns('attendee','event')]
|
||||
|
||||
id: int = ormar.Integer(primary_key=True)
|
||||
attendee: Attendee = ormar.ForeignKey(Attendee)
|
||||
attendee: Attendee = ormar.ForeignKey(Attendee, nullable=True)
|
||||
event: Event = ormar.ForeignKey(Event)
|
||||
link: str = ormar.String(max_length=256)
|
||||
notified: Optional[bool] = ormar.Boolean(default=False)
|
||||
reserved: Optional[bool] = ormar.Boolean(default=False)
|
||||
@@ -7,7 +7,13 @@ from . import BaseMeta
|
||||
class Event(ormar.Model):
|
||||
class Meta(BaseMeta):
|
||||
tablename = "events"
|
||||
constraints = [ormar.UniqueColumns('code')]
|
||||
|
||||
id: str = ormar.String(primary_key=True, max_length=100)
|
||||
name: str = ormar.String(max_length=256)
|
||||
description: str = ormar.String(max_length=256)
|
||||
expiry_date: datetime = ormar.DateTime()
|
||||
code: str = ormar.String(max_length=256)
|
||||
expiry_date: datetime = ormar.DateTime()
|
||||
|
||||
def expired(self):
|
||||
return self.expiry_date < datetime.utcnow()
|
||||
@@ -3,6 +3,7 @@ from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from . import BaseMeta
|
||||
from .claim import Claim
|
||||
|
||||
class RequestMessage(ormar.Model):
|
||||
|
||||
@@ -23,4 +24,5 @@ class ResponseMessage(ormar.Model):
|
||||
id: int = ormar.String(primary_key=True, max_length=100)
|
||||
username: str = ormar.String(max_length=100)
|
||||
created: datetime = ormar.DateTime()
|
||||
body: str = ormar.String(max_length=1024)
|
||||
body: str = ormar.String(max_length=1024)
|
||||
claim: Claim = ormar.ForeignKey(Claim, nullable=True)
|
||||
Reference in New Issue
Block a user