diff --git a/alembic/env.py b/alembic/env.py index 40f4df6..6651868 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -34,6 +34,7 @@ from models.groupphaseuser import Groupphaseuser from models.speechlist import Speechlistmodel from models.game import Game +from models.tournament import Tournament, TournamentTeam, TournamentTeamMember # set metadata target_metadata = Base.metadata diff --git a/alembic/versions/370497031613_tournaments.py b/alembic/versions/370497031613_tournaments.py new file mode 100644 index 0000000..23dc3db --- /dev/null +++ b/alembic/versions/370497031613_tournaments.py @@ -0,0 +1,56 @@ +"""Tournaments + +Revision ID: 370497031613 +Revises: f1a5ab97b1db +Create Date: 2020-11-02 09:32:20.371475 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '370497031613' +down_revision = 'f1a5ab97b1db' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('tournaments', + sa.Column('message_id', sa.BigInteger(), nullable=False), + sa.Column('game_role', sa.BigInteger(), nullable=False), + sa.Column('team_size', sa.Integer(), nullable=False), + sa.Column('team_count', sa.Integer(), nullable=False), + sa.Column('registration_expires', sa.DateTime(), nullable=False), + sa.Column('voice_channel_id', sa.BigInteger(), nullable=False), + sa.Column('text_channel_id', sa.BigInteger(), nullable=False), + sa.Column('role_id', sa.BigInteger(), nullable=False), + sa.PrimaryKeyConstraint('message_id') + ) + op.create_table('tournament_teams', + sa.Column('reaction', sa.Text(), nullable=False), + sa.Column('tournament_message_id', sa.BigInteger(), nullable=False), + sa.Column('voice_channel_id', sa.BigInteger(), nullable=False), + sa.Column('text_channel_id', sa.BigInteger(), nullable=False), + sa.Column('role_id', sa.BigInteger(), nullable=False), + sa.ForeignKeyConstraint(['tournament_message_id'], ['tournaments.message_id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('reaction', 'tournament_message_id') + ) + op.create_table('tournament_team_members', + sa.Column('member_id', sa.BigInteger(), nullable=False), + sa.Column('team_reaction', sa.Text(), nullable=False), + sa.Column('tournament_message_id', sa.BigInteger(), nullable=False), + sa.ForeignKeyConstraint(['team_reaction', 'tournament_message_id'], ['tournament_teams.reaction', 'tournament_teams.tournament_message_id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('member_id', 'team_reaction', 'tournament_message_id') + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('tournament_team_members') + op.drop_table('tournament_teams') + op.drop_table('tournaments') + # ### end Alembic commands ### diff --git a/extensions/tournament.py b/extensions/tournament.py new file mode 100644 index 0000000..36f305b --- /dev/null +++ b/extensions/tournament.py @@ -0,0 +1,113 @@ +from datetime import datetime, timedelta + +import discord +from discord.ext import commands + +from db import db_session +from extensions.util import remove_reaction, create_role_and_channels +from models.tournament import Tournament + + +class Tournaments(commands.Cog, name="Tournaments"): + + def __init__(self, bot): + self.bot: discord.Client = bot + + @commands.command() + async def tournament(self, ctx, role: discord.Role, team_size: int, team_count: int, period: int = 15): + expires = datetime.now() + timedelta(minutes=period) + embed = discord.Embed( + color=discord.Color.blue(), + title=f"A new {role.name} Tournament was started", + ) + embed.add_field(name="Teams", + value=f"This tournament will have {team_size} member(s) per team " + f"and a maximum of {team_count} teams.") + embed.add_field(name="Registration", + value=f"If you want to enter a new team react with a new Reaction.\n" + f"If you want to enter an existing team click in its Reaction.\n" + f"To exit a team remove your Reaction.") + embed.add_field(name="Deadline", + value=f"Registration will be closed when {team_count} **full** teams are formed\n" + f"OR\n" + f"at {expires:%H:%M}.") + msg = await ctx.send(embed=embed) + role, voice, text = await create_role_and_channels(ctx.guild, f"{role.name} Tournament Participant", + f" {role.name} Tournament") + Tournament(message_id=msg.id, game_role_id=role.id, size=team_size, count=team_count, expires=expires, + voice_id=voice.id, text_id=text.id, role_id=role.id) + db_session.commit() + + @commands.Cog.listener() + async def on_raw_reaction_add(self, event: discord.RawReactionActionEvent): + tournament = Tournament.get(event.message_id) + if tournament is None: + return + guild = self.guild(event.guild_id) + if tournament.is_player_in_tournament(event.user_id): + await remove_reaction(guild, event) + return + tournament_role = guild.get_role(tournament.role_id) + reaction = event.emoji.name + team = tournament.get_team(reaction) + team_role: discord.Role = None + if team is None: + team_role, voice, text = await create_role_and_channels(guild, f"Team {reaction} Member", + f"Team {reaction}") + team = tournament.add_team(reaction=reaction, voice_id=voice.id, text_id=text.id, role_id=team_role.id) + if team_role is None: + team_role = guild.get_role(team.role_id) + if len(team.members) >= tournament.team_size: + await remove_reaction(guild, event) + return + await event.member.add_roles(tournament_role, team_role) + team.add_member(event.user_id) + db_session.commit() + + @commands.Cog.listener() + async def on_raw_reaction_remove(self, event: discord.RawReactionActionEvent): + tournament = Tournament.get(event.message_id) + if tournament is None: + return + team = tournament.get_team(event.emoji.name) + if team is None: + return + if not team.has_member(event.user_id): + return + guild = self.guild(event.guild_id) + member: discord.Member = guild.get_member(event.user_id) + tournament_role: discord.Role = guild.get_role(tournament.role_id) + team_role = guild.get_role(team.role_id) + await member.remove_roles(tournament_role, team_role) + team.remove_member(event.user_id) + if len(team.members) == 0: + await guild.get_channel(team.voice_channel_id).delete() + await guild.get_channel(team.text_channel_id).delete() + await team_role.delete() + tournament.remove_team(team.reaction) + db_session.commit() + + @commands.Cog.listener() + async def on_raw_message_edit(self, event: discord.RawMessageUpdateEvent): + tournament = Tournament.get(event.message_id) + if tournament is None: + return + if len(event.data["embeds"]) > 0: + return + guild = self.guild(int(event.data["guild_id"])) + await guild.get_channel(tournament.voice_channel_id).delete() + await guild.get_channel(tournament.text_channel_id).delete() + await guild.get_role(tournament.role_id).delete() + for team in tournament.teams: + await guild.get_channel(team.voice_channel_id).delete() + await guild.get_channel(team.text_channel_id).delete() + await guild.get_role(team.role_id).delete() + Tournament.delete(event.message_id) + db_session.commit() + + def guild(self, guild_id: int) -> discord.Guild: + return self.bot.get_guild(guild_id) + + +def setup(bot): + bot.add_cog(Tournaments(bot)) diff --git a/extensions/util.py b/extensions/util.py new file mode 100644 index 0000000..1d946fe --- /dev/null +++ b/extensions/util.py @@ -0,0 +1,20 @@ +import discord + + +async def remove_reaction(guild: discord.Guild, payload: discord.RawReactionActionEvent): + channel: discord.TextChannel = guild.get_channel(payload.channel_id) + msg = await channel.fetch_message(payload.message_id) + await msg.remove_reaction(payload.emoji, payload.member) + return + + +async def create_role_and_channels(guild: discord.Guild, role_name: str, channel_name: str) -> \ + (discord.Role, discord.VoiceChannel, discord.TextChannel): + role = await guild.create_role(name=role_name) + overwrites = { + role: discord.PermissionOverwrite(view_channel=True, read_messages=True, connect=True), + guild.default_role: discord.PermissionOverwrite(view_channel=False, read_messages=False, connect=False) + } + voice = await guild.create_voice_channel(name=channel_name, overwrites=overwrites) + text = await guild.create_text_channel(name=channel_name, overwrites=overwrites) + return role, voice, text diff --git a/models/tournament.py b/models/tournament.py new file mode 100644 index 0000000..109e691 --- /dev/null +++ b/models/tournament.py @@ -0,0 +1,125 @@ +from datetime import datetime + +from sqlalchemy import Column, Integer, BigInteger, DateTime, ForeignKey, Text, PrimaryKeyConstraint, \ + ForeignKeyConstraint +from sqlalchemy.orm import relationship + +from db import db_session +from models.base import Base + + +class Tournament(Base): + __tablename__ = "tournaments" + + message_id = Column(BigInteger, primary_key=True) + game_role = Column(BigInteger, nullable=False) + team_size = Column(Integer, nullable=False) + team_count = Column(Integer, nullable=False) + registration_expires = Column(DateTime, nullable=False) + voice_channel_id = Column(BigInteger, nullable=False) + text_channel_id = Column(BigInteger, nullable=False) + role_id = Column(BigInteger, nullable=False) + + teams = relationship("TournamentTeam", back_populates="tournament") + + def __init__(self, message_id: int, game_role_id: int, size: int, count: int, expires: datetime, + voice_id: int, text_id: int, role_id: int): + self.message_id = message_id + self.game_role = game_role_id + self.team_size = size + self.team_count = count + self.registration_expires = expires + self.voice_channel_id = voice_id + self.text_channel_id = text_id + self.role_id = role_id + db_session.add(self) + + @classmethod + def get(cls, message_id: int) -> "Tournament": + return db_session.query(Tournament).filter(Tournament.message_id == message_id).first() + + @classmethod + def delete(cls, message_id: int): + db_session.query(Tournament).filter(Tournament.message_id == message_id).delete() + + def get_team(self, reaction: str) -> "TournamentTeam": + return db_session.query(TournamentTeam) \ + .filter(TournamentTeam.tournament_message_id == self.message_id) \ + .filter(TournamentTeam.reaction == reaction) \ + .first() + + def add_team(self, reaction: str, voice_id: int, text_id: int, role_id: int) -> "TournamentTeam": + return TournamentTeam(reaction, self.message_id, voice_id, text_id, role_id) + + def remove_team(self, reaction: str): + db_session.query(TournamentTeam) \ + .filter(TournamentTeam.reaction == reaction and TournamentTeam.tournament_message_id == self.message_id) \ + .delete() + + def is_player_in_tournament(self, member_id: int) -> bool: + member = db_session.query(TournamentTeamMember) \ + .join(TournamentTeam, TournamentTeamMember.team_reaction == TournamentTeam.reaction) \ + .join(Tournament, TournamentTeam.tournament_message_id == Tournament.message_id) \ + .filter(TournamentTeamMember.member_id == member_id) \ + .filter(Tournament.message_id == self.message_id).first() + return member is not None + + +class TournamentTeam(Base): + __tablename__ = "tournament_teams" + + reaction = Column(Text, nullable=False) + tournament_message_id = Column(BigInteger, ForeignKey("tournaments.message_id", ondelete="CASCADE")) + voice_channel_id = Column(BigInteger, nullable=False) + text_channel_id = Column(BigInteger, nullable=False) + role_id = Column(BigInteger, nullable=False) + + PrimaryKeyConstraint(reaction, tournament_message_id) + + members = relationship("TournamentTeamMember", back_populates="team") + tournament = relationship("Tournament", back_populates="teams") + + def __init__(self, reaction: str, tournament_id: int, voice_id: int, text_id: int, role_id: int): + self.tournament_message_id = tournament_id + self.voice_channel_id = voice_id + self.text_channel_id = text_id + self.reaction = reaction + self.role_id = role_id + db_session.add(self) + + def add_member(self, member_id: int) -> "TournamentTeamMember": + return TournamentTeamMember(member_id, self.reaction, self.tournament_message_id) + + def remove_member(self, member_id): + db_session.query(TournamentTeamMember) \ + .filter(TournamentTeamMember.member_id == member_id) \ + .filter(TournamentTeamMember.team_reaction == self.reaction) \ + .delete() + + def has_member(self, member_id) -> bool: + member = db_session.query(TournamentTeamMember) \ + .filter(TournamentTeamMember.member_id == member_id) \ + .filter(TournamentTeamMember.team_reaction == self.reaction) \ + .first() + return member is not None + + +class TournamentTeamMember(Base): + __tablename__ = "tournament_team_members" + + member_id = Column(BigInteger, nullable=False) + team_reaction = Column(Text, nullable=False) + tournament_message_id = Column(BigInteger, nullable=False) + + PrimaryKeyConstraint(member_id, team_reaction, tournament_message_id) + ForeignKeyConstraint((team_reaction, tournament_message_id), + ("tournament_teams.reaction", "tournament_teams.tournament_message_id"), + ondelete="CASCADE") + + team = relationship("TournamentTeam", back_populates="members") + + def __init__(self, member_id: int, team_reaction: str, tournament_message_id: int): + self.member_id = member_id + self.team_reaction = team_reaction + self.tournament_message_id = tournament_message_id + db_session.add(self)