diff --git a/poapbot/db/database.py b/poapbot/db/database.py index 98821c2..5fff1ca 100644 --- a/poapbot/db/database.py +++ b/poapbot/db/database.py @@ -106,20 +106,12 @@ class POAPDatabase: except NoMatch: return [] - async def get_claims_by_event_id(self, event_id: str, select_related: List[str] = None, offset: int = 0, limit: int = 0) -> List[Claim]: + async def get_claims_by_event_id(self, event_id: str, reserved: bool = None, select_related: List[str] = None, offset: int = 0, limit: int = 0) -> List[Claim]: + filter_kwargs = dict(event__id__exact=event_id) + if reserved is not None: + filter_kwargs['reserved__exact'] = reserved try: - q = Claim.objects.filter(event__id__exact=event_id).offset(offset) - if select_related: - q = q.select_related(select_related) - if limit > 0: - q = q.limit(limit) - return await q.all() - except NoMatch: - return [] - - async def get_available_claims_by_event_id(self, event_id: str, select_related: List[str] = None, offset: int = 0, limit: int = 0) -> List[Claim]: - try: - q = Claim.objects.filter(event__id__exact=event_id, reserved__exact=False).offset(offset) + q = Claim.objects.filter(**filter_kwargs).offset(offset) if select_related: q = q.select_related(select_related) if limit > 0: @@ -262,9 +254,12 @@ class POAPDatabase: ## Bulk - async def set_claims_bulk(self, event_id: str, usernames: List[str]) -> List[Claim]: + async def set_claims_bulk(self, event_id: str, usernames: List[str]): async with self.db.transaction(): - available_claims = await self.get_available_claims_by_event_id(event_id, limit=len(usernames)) + reserved_claims = await self.get_claims_by_event_id(event_id, reserved=True, select_related=['attendee']) + existing_attendees = set([claim.attendee.username for claim in reserved_claims]) + usernames = [u for u in usernames if u not in existing_attendees] + available_claims = await self.get_claims_by_event_id(event_id, reserved=False, limit=len(usernames)) if len(available_claims) < len(usernames): raise DoesNotExist(f'Insufficient available claims: {len(available_claims)} available, {len(usernames)} requested') for n, username in enumerate(usernames):