diff --git a/.github/workflows/DiscordBot.yml b/.github/workflows/DiscordBot.yml index ada92d0..d64e307 100644 --- a/.github/workflows/DiscordBot.yml +++ b/.github/workflows/DiscordBot.yml @@ -40,21 +40,35 @@ jobs: docker: runs-on: ubuntu-latest steps: + - name: Checkout + uses: actions/checkout@v2 + - name: Set up QEMU uses: docker/setup-qemu-action@v1 - + - name: Set up Docker Buildx uses: docker/setup-buildx-action@v1 - + + - name: Docker meta + id: meta + uses: docker/metadata-action@v4 + with: + images: plop91/plop_discord + tags: | + type=ref,event=branch + type=ref,event=pr + type=sha,prefix={{branch}}- + - name: Login to DockerHub - uses: docker/login-action@v1 + uses: docker/login-action@v1 with: username: ${{ secrets.DOCKER_USERNAME }} password: ${{ secrets.DOCKER_SECRET }} - + - name: Build and push id: docker_build uses: docker/build-push-action@v2 with: push: true - tags: plop91/plop_discord:nightly + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index 8d9ebbe..24b50b8 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -1,71 +1,68 @@ -# For most projects, this workflow file will not need changing; you simply need -# to commit it to your repository. -# -# You may wish to alter this file to override the set of languages analyzed, -# or to provide custom queries or build logic. -# -# ******** NOTE ******** -# We have attempted to detect the languages in your repository. Please check -# the `language` matrix defined below to confirm you have the correct set of -# supported CodeQL languages. 
-# -name: "CodeQL" - -on: - push: - branches: [ master ] - pull_request: - # The branches below must be a subset of the branches above - branches: [ master ] - schedule: - - cron: '23 1 * * 1' - -jobs: - analyze: - name: Analyze - runs-on: ubuntu-latest - permissions: - actions: read - contents: read - security-events: write - - strategy: - fail-fast: false - matrix: - language: [ 'python' ] - # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python' ] - # Learn more: - # https://docs.github.com/en/free-pro-team@latest/github/finding-security-vulnerabilities-and-errors-in-your-code/configuring-code-scanning#changing-the-languages-that-are-analyzed - - steps: - - name: Checkout repository - uses: actions/checkout@v2 - - # Initializes the CodeQL tools for scanning. - - name: Initialize CodeQL - uses: github/codeql-action/init@v1 - with: - languages: ${{ matrix.language }} - # If you wish to specify custom queries, you can do so here or in a config file. - # By default, queries listed here will override any specified in a config file. - # Prefix the list here with "+" to use these queries and those in the config file. - # queries: ./path/to/local/query, your-org/your-repo/queries@main - - # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). - # If this step fails, then you should remove it and run the build manually (see below) - - name: Autobuild - uses: github/codeql-action/autobuild@v1 - - # â„šī¸ Command-line programs to run using the OS shell. - # 📚 https://git.io/JvXDl - - # âœī¸ If the Autobuild fails above, remove it and uncomment the following three lines - # and modify them (or add more) to build your code if your project - # uses a compiled language - - #- run: | - # make bootstrap - # make release - - - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v1 +# For most projects, this workflow file will not need changing; you simply need +# to commit it to your repository. 
+# +# You may wish to alter this file to override the set of languages analyzed, +# or to provide custom queries or build logic. +# +# ******** NOTE ******** +# We have attempted to detect the languages in your repository. Please check +# the `language` matrix defined below to confirm you have the correct set of +# supported CodeQL languages. +# +name: "CodeQL" + +on: + push: + pull_request: + schedule: + - cron: '23 1 * * 1' + +jobs: + analyze: + name: Analyze + runs-on: ubuntu-latest + permissions: + actions: read + contents: read + security-events: write + + strategy: + fail-fast: false + matrix: + language: [ 'python' ] + # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python' ] + # Learn more: + # https://docs.github.com/en/free-pro-team@latest/github/finding-security-vulnerabilities-and-errors-in-your-code/configuring-code-scanning#changing-the-languages-that-are-analyzed + + steps: + - name: Checkout repository + uses: actions/checkout@v2 + + # Initializes the CodeQL tools for scanning. + - name: Initialize CodeQL + uses: github/codeql-action/init@v1 + with: + languages: ${{ matrix.language }} + # If you wish to specify custom queries, you can do so here or in a config file. + # By default, queries listed here will override any specified in a config file. + # Prefix the list here with "+" to use these queries and those in the config file. + # queries: ./path/to/local/query, your-org/your-repo/queries@main + + # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). + # If this step fails, then you should remove it and run the build manually (see below) + - name: Autobuild + uses: github/codeql-action/autobuild@v1 + + # â„šī¸ Command-line programs to run using the OS shell. 
+ # 📚 https://git.io/JvXDl + + # âœī¸ If the Autobuild fails above, remove it and uncomment the following three lines + # and modify them (or add more) to build your code if your project + # uses a compiled language + + #- run: | + # make bootstrap + # make release + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v1 diff --git a/BotHead.py b/BotHead.py index be1cb47..b74467f 100644 --- a/BotHead.py +++ b/BotHead.py @@ -29,9 +29,31 @@ async def setup_hook(self) -> None: Setup hook for the bot, loads all cogs. :return: None """ - for f in os.listdir('./cogs'): + cogs_dir = './cogs' + if not os.path.exists(cogs_dir): + settings.logger.error(f"Cogs directory '{cogs_dir}' does not exist. Cannot load cogs.") + return + + if not os.path.isdir(cogs_dir): + settings.logger.error(f"'{cogs_dir}' exists but is not a directory. Cannot load cogs.") + return + + try: + cog_files = os.listdir(cogs_dir) + except PermissionError: + settings.logger.error(f"Permission denied reading cogs directory '{cogs_dir}'") + return + except OSError as e: + settings.logger.error(f"Error reading cogs directory '{cogs_dir}': {e}") + return + + for f in cog_files: if f.endswith('.py'): - await self.load_extension(f'cogs.{f[:-3]}') + try: + await self.load_extension(f'cogs.{f[:-3]}') + settings.logger.info(f"Loaded cog: {f[:-3]}") + except Exception as e: + settings.logger.error(f"Failed to load cog {f[:-3]}: {e}") # @commands.command(brief="Admin only command: Load a Cog.") # async def load(self, ctx, extension): diff --git a/Dockerfile b/Dockerfile index f59150e..6abf102 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,7 +7,7 @@ RUN git clone https://github.com/plop91/PlopBot.git # Set working dir to the git repo WORKDIR /usr/src/app/PlopBot/ # Set branch to pull from -ARG branch=devel +ARG branch=ian-nightly # fetch branchs RUN git fetch # Checkout branch diff --git a/cogs/adminCog.py b/cogs/adminCog.py index 0c4f371..ca4bb91 100644 --- a/cogs/adminCog.py +++ 
b/cogs/adminCog.py @@ -3,6 +3,7 @@ """ from discord.ext import commands import settings +import subprocess class Admin(commands.Cog): @@ -32,7 +33,34 @@ async def hash(self, ctx): :arg ctx: context of the command :return: None """ - pass + # Use Discord user IDs instead of string names to prevent spoofing + admin_ids = settings.info_json.get("admin_ids", []) + # Fallback to old string-based check if admin_ids not configured + admin_strings = settings.info_json.get("admins", []) + is_admin = ctx.author.id in admin_ids or str(ctx.author) in admin_strings + + if is_admin: + try: + # Get the current git commit hash + result = subprocess.run( + ['git', 'rev-parse', '--short', 'HEAD'], + capture_output=True, + text=True, + timeout=5 + ) + if result.returncode == 0: + commit_hash = result.stdout.strip() + await ctx.send(f"Current commit hash: `{commit_hash}`") + else: + await ctx.send("Failed to get git commit hash") + except subprocess.TimeoutExpired: + await ctx.send("Git command timed out") + except Exception as e: + settings.logger.error(f"Error getting git hash: {e}") + await ctx.send("Error retrieving git commit hash") + else: + await ctx.send("You do not have permission to run this command") + settings.logger.warning(f"Unauthorized hash command attempt by {ctx.author}") @commands.command(brief="Admin only command: provide current version") async def version(self, ctx): @@ -41,7 +69,33 @@ async def version(self, ctx): :arg ctx: context of the command :return: None """ - pass + # Use Discord user IDs instead of string names to prevent spoofing + admin_ids = settings.info_json.get("admin_ids", []) + admin_strings = settings.info_json.get("admins", []) + is_admin = ctx.author.id in admin_ids or str(ctx.author) in admin_strings + + if is_admin: + try: + # Try to get version from git tag + result = subprocess.run( + ['git', 'describe', '--tags', '--always'], + capture_output=True, + text=True, + timeout=5 + ) + if result.returncode == 0: + version = result.stdout.strip() 
+ await ctx.send(f"Bot version: `{version}`") + else: + await ctx.send("Version information not available") + except subprocess.TimeoutExpired: + await ctx.send("Git command timed out") + except Exception as e: + settings.logger.error(f"Error getting version: {e}") + await ctx.send("Error retrieving version information") + else: + await ctx.send("You do not have permission to run this command") + settings.logger.warning(f"Unauthorized version command attempt by {ctx.author}") @commands.command(brief="Admin only command: Turn the bot off.") async def kill(self, ctx): @@ -50,18 +104,22 @@ async def kill(self, ctx): :arg ctx: context of the command :return: None """ + # Use Discord user IDs instead of string names to prevent spoofing + admin_ids = settings.info_json.get("admin_ids", []) + admin_strings = settings.info_json.get("admins", []) + is_admin = ctx.author.id in admin_ids or str(ctx.author) in admin_strings + # try to gracefully shut down the bot # noinspection PyBroadException try: - if str(ctx.author) in settings.info_json["admins"]: + if is_admin: settings.logger.info(f"kill from {ctx.author}!") if str(ctx.message.channel) in settings.info_json["command_channels"]: - await self.client.logout() + await self.client.close() else: - await ctx.channel.send(ctx.author) - await ctx.channel.send(settings.info_json["admins"]) - await ctx.channel.send("you are not an admin") - # if the bot fails to log out kill it + await ctx.channel.send("You do not have permission to run this command") + settings.logger.warning(f"Unauthorized kill attempt by {ctx.author}") + # if the bot fails to close kill it except Exception: exit(1) @@ -69,22 +127,28 @@ async def kill(self, ctx): async def restart(self, ctx): """ Preforms a restart of the bot + Note: This command closes the bot. The bot should be managed by a process manager + (like systemd or docker) that will automatically restart it. 
:arg ctx: context of the command :return: None """ + # Use Discord user IDs instead of string names to prevent spoofing + admin_ids = settings.info_json.get("admin_ids", []) + admin_strings = settings.info_json.get("admins", []) + is_admin = ctx.author.id in admin_ids or str(ctx.author) in admin_strings + # try to gracefully shut down the bot # noinspection PyBroadException try: - if str(ctx.author) in settings.info_json["admins"]: + if is_admin: settings.logger.info(f"restart from {ctx.author}!") if str(ctx.message.channel) in settings.info_json["command_channels"]: - await self.client.logout() - await self.client.start(settings.info_json["token"]) + await ctx.send("Restarting bot... (requires process manager)") + await self.client.close() else: - await ctx.channel.send(ctx.author) - await ctx.channel.send(settings.info_json["admins"]) - await ctx.channel.send("you are not an admin") - # if the bot fails to log out kill it + await ctx.channel.send("You do not have permission to run this command") + settings.logger.warning(f"Unauthorized restart attempt by {ctx.author}") + # if the bot fails to close kill it except Exception: exit(1) diff --git a/cogs/audioCog.py b/cogs/audioCog.py index ed6c2d7..09ec00d 100644 --- a/cogs/audioCog.py +++ b/cogs/audioCog.py @@ -2,6 +2,8 @@ This cog is used to play audio from YouTube and the soundboard. It also has the ability to download mp3's from YouTube and add them to the soundboard. It also has the ability to play TTS audio. """ +from asyncio import sleep + from discord.ext import commands, tasks from discord.errors import ClientException from discord.utils import get @@ -15,6 +17,7 @@ import shutil import settings import traceback +import re ytdl_format_options = { 'format': 'bestaudio/best', @@ -95,6 +98,42 @@ def __init__(self, client): self.ghost_message = {} + # URL validation pattern for YouTube and common video sites + self.url_pattern = re.compile( + r'^https?://' # http:// or https:// + r'(?:(?:www|m)\.)?' 
# optional www. or m. + r'(?:youtube\.com|youtu\.be|twitch\.tv|soundcloud\.com|vimeo\.com|dailymotion\.com)' # allowed domains + r'[^\s]*$', # rest of URL + re.IGNORECASE + ) + + def validate_url(self, url): + """ + Validates that a URL is from an allowed domain + :param url: URL to validate + :return: True if valid, False otherwise + """ + if not url or not isinstance(url, str): + return False + return bool(self.url_pattern.match(url.strip())) + + def sanitize_filename(self, filename): + """ + Sanitizes a filename to prevent path traversal and other attacks + :param filename: Filename to sanitize + :return: Sanitized filename or None if invalid + """ + if not filename or not isinstance(filename, str): + return None + # Remove any path components + filename = os.path.basename(filename) + # Remove any dangerous characters + filename = re.sub(r'[^\w\s\-.]', '', filename) + # Prevent empty or hidden files + if not filename or filename.startswith('.'): + return None + return filename.strip().lower() + @staticmethod def clean_youtube(): """ @@ -102,8 +141,35 @@ def clean_youtube(): :return: """ settings.logger.info(f"cleaning youtube folder!") - for f in os.listdir("youtube"): - os.remove(os.path.join("youtube", f)) + youtube_dir = "youtube" + + if not os.path.exists(youtube_dir): + settings.logger.warning(f"YouTube directory '{youtube_dir}' does not exist, skipping cleanup") + return + + if not os.path.isdir(youtube_dir): + settings.logger.warning(f"'{youtube_dir}' exists but is not a directory, skipping cleanup") + return + + try: + files = os.listdir(youtube_dir) + except PermissionError: + settings.logger.error(f"Permission denied reading YouTube directory '{youtube_dir}'") + return + except OSError as e: + settings.logger.error(f"Error reading YouTube directory '{youtube_dir}': {e}") + return + + for f in files: + file_path = os.path.join(youtube_dir, f) + try: + if os.path.isfile(file_path): + os.remove(file_path) + settings.logger.debug(f"Removed YouTube file: 
{f}") + except PermissionError: + settings.logger.error(f"Permission denied removing file '{file_path}'") + except OSError as e: + settings.logger.error(f"Error removing file '{file_path}': {e}") @commands.Cog.listener() async def on_ready(self): @@ -153,13 +219,29 @@ async def on_message(self, message): "for admin approval. Notify an admin to resolve.") # If this is a new filename else: - filename = attachment.filename.lower().replace(' ', '').replace('_', '') + # Sanitize filename to prevent path traversal attacks + sanitized = self.sanitize_filename(attachment.filename) + if not sanitized or not sanitized.endswith('.mp3'): + await message.channel.send("Invalid filename. Only MP3 files with safe names are allowed.") + continue + + filename = sanitized + + # Validate the full path to prevent directory traversal + raw_path = os.path.join("./soundboard/raw", filename) + if not os.path.abspath(raw_path).startswith(os.path.abspath("./soundboard/raw")): + await message.channel.send("Invalid filename - path traversal detected") + settings.logger.warning(f"Path traversal attempt by {message.author}: {attachment.filename}") + continue + settings.logger.info(f"{message.author} added a mp3 file: {attachment}") await message.channel.send( f"The audio is being downloaded and should be ready shortly the name of the clip will " f"be: {filename.replace('.mp3', '')}") - await attachment.save(f"./soundboard/raw/{filename}") - audio_json = ffmpeg.probe(f"./soundboard/raw/{filename}") + await attachment.save(raw_path) + # Run ffmpeg.probe in executor to avoid blocking event loop + loop = asyncio.get_event_loop() + audio_json = await loop.run_in_executor(None, ffmpeg.probe, raw_path) # If the clip is too long it needs to be reviewed if float(audio_json['streams'][0]['duration']) >= 60: @@ -167,10 +249,17 @@ async def on_message(self, message): "reviewed before it can be played.") else: try: - shutil.copy(f"./soundboard/raw/{filename}", f"./soundboard/{filename}") + # Validate 
destination path as well + dest_path = os.path.join("./soundboard", filename) + if not os.path.abspath(dest_path).startswith(os.path.abspath("./soundboard")): + await message.channel.send("Invalid filename - path traversal detected") + settings.logger.warning(f"Path traversal attempt in destination by {message.author}") + continue + + shutil.copy(raw_path, dest_path) settings.soundboard_db.add_db_entry(filename.lower(), filename.replace(".mp3", "").lower()) - self.sounds[filename.replace(".mp3", "").lower()] = f"./soundboard/{filename}" + self.sounds[filename.replace(".mp3", "").lower()] = dest_path except ValueError: await message.channel.send("A file with that name already existed in the database, " "contact an admin!") @@ -180,8 +269,8 @@ async def on_message(self, message): else: # divide message as though it was a webhook command data = message.content.split(':') - # check if it has a valid source - if data[0] == "www.sodersjerna.com": + # check if it has a valid source AND sufficient parts to prevent IndexError + if len(data) >= 4 and data[0] == "www.sodersjerna.com": member = discord.utils.get(message.guild.members, name=data[1]) if member is not None and member.voice is not None: for client in self.client.voice_clients: @@ -238,14 +327,16 @@ async def play_clip(self, text_channel, voice_channel, filename): f = filename.replace("soundboard/", "").replace(".mp3", "") embed_var = discord.Embed(title="Play Command", - description=f"{text_channel.author} played a random clip: {f}", + description=f"Playing random clip: {f}", color=0xffff00) else: embed_var = discord.Embed(title="Play Command", - description=f"{text_channel.author} played: {fn}", + description=f"Playing: {fn}", color=0xffff00) - self.ghost_message[text_channel.guild.id] = await text_channel.channel.send(embed=embed_var) + # text_channel could be a Context object or a Channel object + channel = text_channel.channel if hasattr(text_channel, 'channel') else text_channel + 
self.ghost_message[text_channel.guild.id] = await channel.send(embed=embed_var) except AttributeError: settings.logger.info(f"Attribute Error: {traceback.format_exc()}") @@ -273,14 +364,35 @@ async def play(self, ctx, filename=None): """ settings.logger.info(f"play from {ctx.author} :{filename}") if ctx.author not in settings.info_json["blacklist"]: + # Sanitize filename if provided + if filename is not None: + sanitized = self.sanitize_filename(filename) + if not sanitized: + await ctx.send("Invalid filename provided.") + await ctx.message.delete() + return + filename = sanitized + if filename is None: embed_var = discord.Embed(title="Soundboard files", - description="type '.play ' followed by a name to play " + description="type '.play ' or '.p' followed by a name to play " "file", color=0x00ff00) s = "" + field_index = 0 for file in self.sounds.keys(): - if len(s) + len(file) >= 1024: + settings.logger.info(f"DEBUG: field_index: {field_index}") + if len(s) + len(file) >= 1024 and field_index > 3: + settings.logger.info(f"DEBUG: SENDING MESSAGE") + await ctx.channel.send(embed=embed_var) + field_index = 0 + embed_var = discord.Embed(title="Soundboard files", + description="type '.play ' or '.p' followed by a name to play " + "file", color=0x00ff00) + s = "" + + elif len(s) + len(file) >= 1024: embed_var.add_field(name="play from a filename:", value=s, inline=False) + field_index += 1 s = "" s += file + ", " @@ -288,9 +400,10 @@ async def play(self, ctx, filename=None): embed_var.add_field(name="play a random file:", value="random", inline=False) + settings.logger.info(f"DEBUG: SENDING REAL MESSAGE") await ctx.channel.send(embed=embed_var) - await ctx.message.delete() + await ctx.message.delete() return await self.play_clip(ctx, ctx.voice_client, filename) @@ -311,9 +424,20 @@ async def youtube(self, ctx, *, url): :return: None """ settings.logger.info(f"youtube from {ctx.author} :{url}") + + # Validate URL before processing + if not self.validate_url(url): + 
await ctx.send("Invalid URL. Only YouTube, Twitch, SoundCloud, Vimeo, and Dailymotion URLs are allowed.") + await ctx.message.delete() + return + async with ctx.typing(): - player = await YTDLSource.from_url(url, loop=self.client.loop, volume=self.volume) - ctx.voice_client.play(player) + try: + player = await YTDLSource.from_url(url, loop=self.client.loop, volume=self.volume) + ctx.voice_client.play(player) + except Exception as e: + settings.logger.error(f"Error playing YouTube URL: {e}") + await ctx.send("Failed to play the requested URL.") await ctx.message.delete() @commands.command(pass_context=True, @@ -327,9 +451,19 @@ async def stream(self, ctx, *, url): :arg url: url of the YouTube video to play :return: None """ + # Validate URL before processing + if not self.validate_url(url): + await ctx.send("Invalid URL. Only YouTube, Twitch, SoundCloud, Vimeo, and Dailymotion URLs are allowed.") + await ctx.message.delete() + return + async with ctx.typing(): - player = await YTDLSource.from_url(url, loop=self.client.loop, stream=True, volume=self.volume) - ctx.voice_client.play(player) + try: + player = await YTDLSource.from_url(url, loop=self.client.loop, stream=True, volume=self.volume) + ctx.voice_client.play(player) + except Exception as e: + settings.logger.error(f"Error streaming URL: {e}") + await ctx.send("Failed to stream the requested URL.") await ctx.message.delete() @commands.command(pass_context=True, @@ -458,8 +592,19 @@ async def get(self, ctx, sound: str): :arg sound: sound to return :return: None """ - if os.path.isfile(os.path.join("soundboard", sound)): - await ctx.channel.send(sound, file=discord.File(sound + ".mp3", os.path.join("soundboard", sound))) + # Validate filename to prevent path traversal + if '..' 
in sound or sound.startswith('/') or sound.startswith('\\'): + await ctx.channel.send("Invalid filename") + return + + filepath = os.path.join("soundboard", sound) + # Ensure the resolved path is within the soundboard directory + if not os.path.abspath(filepath).startswith(os.path.abspath("soundboard")): + await ctx.channel.send("Invalid filename") + return + + if os.path.isfile(filepath): + await ctx.channel.send(sound, file=discord.File(filepath)) @commands.command(aliases=['SAY'], brief="", @@ -474,8 +619,32 @@ async def say(self, ctx, text, *, tts_file='say'): """ settings.logger.info(f"say from {ctx.author} text:{text}") text = text.strip().lower() - gTTS(text).save(os.path.join("soundboard", tts_file + '.mp3')) - await self.play_clip(ctx, ctx.voice_client, tts_file) + + # Limit TTS text length to prevent abuse + MAX_TTS_LENGTH = 500 + if len(text) > MAX_TTS_LENGTH: + await ctx.send(f"Text too long. Max {MAX_TTS_LENGTH} characters allowed") + return + + # Sanitize tts_file parameter to prevent path traversal + sanitized_tts_file = self.sanitize_filename(tts_file) + if not sanitized_tts_file: + await ctx.send("Invalid filename") + return + + # Remove .mp3 extension if user provided it (we'll add it) + if sanitized_tts_file.endswith('.mp3'): + sanitized_tts_file = sanitized_tts_file[:-4] + + # Validate the full path + filepath = os.path.join("soundboard", sanitized_tts_file + '.mp3') + if not os.path.abspath(filepath).startswith(os.path.abspath("soundboard")): + await ctx.send("Invalid filename - path traversal detected") + settings.logger.warning(f"Path traversal attempt in TTS by {ctx.author}: {tts_file}") + return + + gTTS(text).save(filepath) + await self.play_clip(ctx, ctx.voice_client, sanitized_tts_file) await ctx.message.delete() @play.before_invoke diff --git a/cogs/gameCog.py b/cogs/gameCog.py index 87876fb..4790d90 100644 --- a/cogs/gameCog.py +++ b/cogs/gameCog.py @@ -75,7 +75,11 @@ async def teams(self, ctx, teams="2"): :return: None """ - iteams = 
int(teams) + try: + iteams = int(teams) + except ValueError: + await ctx.send("Invalid number of teams") + return # get current voice channel of author voice = ctx.author.voice.channel @@ -83,21 +87,40 @@ async def teams(self, ctx, teams="2"): if voice is not None: # filter bots from list of members in channel people = list(filter(lambda x: (not x.bot), voice.members)) - settings.logger.info(f"{iteams} teams with members: ".join(m.name for m in people)) + members_str = ", ".join(m.name for m in people) + settings.logger.info(f"{iteams} teams with members: {members_str}") if iteams < 2: iteams = 2 + if len(people) < iteams: settings.logger.info(f"Not enough players for {iteams} teams.") await ctx.send(f"Not enough players for {iteams} teams.") else: + # Calculate players per team and extra players + players_per_team = len(people) // iteams + extra_players = len(people) % iteams + + if players_per_team == 0: + await ctx.send(f"Not enough players for {iteams} teams. Need at least {iteams} players.") + return + for x in range(1, iteams + 1): - players = random.sample(people, int(len(people) / iteams)) - for p in players: - people.remove(p) - await ctx.send(f"Team {x}: " + ", ".join(m.name for m in players)) - iteams -= 1 + # Give extra players to first teams + team_size = players_per_team + (1 if x <= extra_players else 0) + + if len(people) >= team_size: + players = random.sample(people, team_size) + for p in players: + people.remove(p) + await ctx.send(f"Team {x}: " + ", ".join(m.name for m in players)) + else: + # Assign remaining players to last team + if people: + await ctx.send(f"Team {x}: " + ", ".join(m.name for m in people)) + people.clear() + break else: settings.logger.info(f"Could not find voice channel of member.") await ctx.send("Don't think you're in a voice channel") @@ -113,8 +136,13 @@ async def roll(self, ctx, sides, times="1"): :arg times: number of times to roll :return: None """ - isides = int(sides) - itimes = int(times) + try: + isides = 
int(sides) + itimes = int(times) + except ValueError: + await ctx.send("Invalid number for sides or times") + return + settings.logger.info(f"roll from {ctx.author}: {isides} sides") if isides > 1 and itimes > 0: await ctx.message.channel.send( diff --git a/cogs/generalCog.py b/cogs/generalCog.py index df2bc13..c7fda18 100644 --- a/cogs/generalCog.py +++ b/cogs/generalCog.py @@ -33,13 +33,17 @@ async def on_ready(self): @commands.Cog.listener() async def on_message(self, message): """ - logs any incoming messages and responds to 'hey' with 'hi' to verify bot is functional. + Responds to 'hey' with 'hi' to verify bot is functional. + Only logs command messages to respect user privacy. :arg message: message object :return: None """ _id = message.guild - message.content = message.content.strip().lower() - settings.logger.info(f"Message from {message.author}: {message.content}") + # Only log if it's a command (starts with prefix) or from the bot itself + # This reduces privacy concerns and log file size + if message.content.startswith(tuple(settings.info_json.get("command_prefixes", ["!"]))): + settings.logger.info(f"Command from {message.author}: {message.content}") + if message.author != self.client.user: if message.content.strip().lower() == "hey": await message.channel.send("Hi") @@ -47,11 +51,15 @@ async def on_message(self, message): @commands.Cog.listener() async def on_message_delete(self, message): """ - logs any deleted messages + Logs metadata of deleted messages without content to respect user privacy :arg message: message object :return: None """ - settings.logger.info(f"deleted message- {message.author} : {message.content}") + # Only log metadata, not content, to respect user privacy and GDPR + settings.logger.info( + f"deleted message- {message.author} in #{message.channel} " + f"at {message.created_at} (length: {len(message.content)} chars)" + ) @commands.Cog.listener() async def on_member_join(self, member): @@ -88,6 +96,10 @@ async def repeat(self, 
ctx, times: int, content='repeating...'): :arg content: content to repeat :return: None """ + MAX_REPEATS = 10 + if times > MAX_REPEATS: + await ctx.send(f"Max {MAX_REPEATS} repeats allowed") + return for i in range(times): await ctx.send(content) @@ -112,15 +124,21 @@ async def status(self, ctx): await ctx.channel.send(embed=embed_var) await ctx.message.delete() - @tasks.loop(seconds=0, minutes=30, hours=1) + @tasks.loop(hours=1) async def change_status(self): """ changes the bot to a randomly provided status. :return: None """ + statuses = settings.info_json.get("status", []) + if not statuses: + settings.logger.warning("No status messages configured, skipping status change") + return + settings.logger.info(f"status changed automatically") - await self.client.change_presence(status=discord.Status.online, activity=discord.Game( - settings.info_json["status"][random.randint(0, len(settings.info_json["status"]) - 1)])) + # Use random.choice instead of randint for cleaner code + status = random.choice(statuses) + await self.client.change_presence(status=discord.Status.online, activity=discord.Game(status)) async def setup(client): diff --git a/cogs/openAiCog.py b/cogs/openAiCog.py index 9740f46..1dce936 100644 --- a/cogs/openAiCog.py +++ b/cogs/openAiCog.py @@ -6,7 +6,7 @@ import discord import settings from discord.ext import commands -from openai import OpenAI as Oai +from openai import OpenAI as Oai, BadRequestError import wget import os import textwrap @@ -14,55 +14,35 @@ import json import asyncio -import mysql.connector -from mysql.connector import errorcode +from db.openai_database_manager import OpenAIDatabaseManager global logger +BLACKLIST_FILE = "openai_blacklist.json" -class OpenAIDatabaseManager: - """ - This class is for managing the openai database - """ - def __init__(self, db_host, db_username, db_password, database_name): - """ - Constructor for the openai database manager - :param db_host: Database host - :param db_username: Database username - 
:param db_password: Database password - :param database_name: Database name - """ - self.db = None - self.my_cursor = None - - self.db_host = db_host - self.db_username = db_username - self.db_password = db_password - self.database_name = database_name - - self.connect() - - def connect(self): - """Connects to the database""" - try: - self.db = mysql.connector.connect( - host=self.db_host, - user=self.db_username, - password=self.db_password, - database=self.database_name - ) - self.my_cursor = self.db.cursor() - except mysql.connector.Error as e: - if e.errno == errorcode.ER_ACCESS_DENIED_ERROR: - settings.logger.warning("Soundboard user name or password is Bad") - elif e.errno == errorcode.ER_BAD_DB_ERROR: - settings.logger.warning("Database does not exist") - else: - settings.logger.warning(e) +def load_blacklist(): + """Load blacklist from file""" + try: + if os.path.exists(BLACKLIST_FILE): + with open(BLACKLIST_FILE, 'r') as f: + return json.load(f) + return [] + except Exception as e: + settings.logger.error(f"Error loading blacklist: {e}") + return [] -blacklist = [] +def save_blacklist(blacklist_data): + """Save blacklist to file""" + try: + with open(BLACKLIST_FILE, 'w') as f: + json.dump(blacklist_data, f, indent=4) + except Exception as e: + settings.logger.error(f"Error saving blacklist: {e}") + + +blacklist = load_blacklist() def blacklisted(user): @@ -85,9 +65,23 @@ def __init__(self, client): :param client: Client object """ self.client = client - self.api_key = settings.info_json["openai"]["apikey"] - # openai.api_key = self.api_key - self.openai_client = Oai(api_key=self.api_key) + # Use environment variable for OpenAI API key if available, otherwise fallback to JSON + self.api_key = os.environ.get('OPENAI_API_KEY', settings.info_json.get("openai", {}).get("apikey")) + if not self.api_key: + settings.logger.warning("OpenAI API key not found in environment variables or config file") + self.openai_client = Oai(api_key=self.api_key) if self.api_key 
else None + + # self.db_manager = OpenAIDatabaseManager( + # settings.info_json["openai"]["db_host"], + # settings.info_json["openai"]["db_username"], + # settings.info_json["openai"]["db_password"], + # settings.info_json["openai"]["database_name"] + # ) + # + # try: + # self.db_manager.connect() + # except Exception as e: + # settings.logger.warning(f"Error connecting to openai database: {e}") self.active_assistants = {} self.active_threads = {} @@ -102,6 +96,7 @@ async def on_ready(self): @commands.command(pass_context=True, aliases=["genimg", "genimage", "gen_image"], brief="generate an image from a prompt using openai") + @commands.cooldown(1, 60, commands.BucketType.user) async def gen_img(self, ctx, *args): """ Generate an image from a prompt using openai @@ -113,24 +108,35 @@ async def gen_img(self, ctx, *args): if not blacklisted(ctx.author): prompt = ' '.join(args) settings.logger.info(f"generating image") - response = self.openai_client.images.generate( - model="dall-e-3", - prompt=prompt, - size="1024x1024", - quality="standard", - n=1 - ) - # image_url = response['data'][0]['url'] - image_url = response.data[0].url - image_filename = wget.download(image_url) - await ctx.send(file=discord.File(image_filename)) - os.remove(image_filename) + try: + response = self.openai_client.images.generate( + model="dall-e-3", + prompt=prompt, + size="1024x1024", + quality="standard", + n=1 + ) + image_url = response.data[0].url + image_filename = wget.download(image_url) + await ctx.send(file=discord.File(image_filename)) + os.remove(image_filename) + # TODO: add to database + except BadRequestError as e: + """ + openai.BadRequestError: Error code: 400 - {'error': {'code': 'content_policy_violation', 'message': 'Your request was rejected as a result of our safety system. 
Your prompt may contain text that is not allowed by our safety system.', 'param': None, 'type': 'invalid_request_error'}} + """ + if e.status_code == 400: # HTTP status; e.code holds the API error-code string + if e.code == "content_policy_violation": + await ctx.send("Your prompt was rejected by OpenAI's safety system due to content policy violation") + return + raise e else: settings.logger.info(f"User {ctx.author} is blacklisted from AI cog!") @commands.command(pass_context=True, aliases=["editimg", "editimage", "edit_image"], brief="edit an image from a prompt using openai") - async def edit_img(self, ctx): + @commands.cooldown(1, 60, commands.BucketType.user) + async def edit_img(self, ctx, *args): """ Edit an image from a prompt using openai :arg ctx: Context @@ -138,6 +144,9 @@ async def edit_img(self, ctx): """ if not blacklisted(ctx.author): + if not ctx.message.attachments: + await ctx.send("No image attached") + return if ctx.message.attachments[0] is None: await ctx.send("No image attached") return @@ -149,16 +158,19 @@ async def edit_img(self, ctx): png = png.resize((1024, 1024)) png.save("temp.png", 'png', quality=100) settings.logger.info(f"editing image") - response = self.openai_client.images.create_variation( - image=open("temp.png", "rb"), - n=1, - size="1024x1024" - ) + with open("temp.png", "rb") as image_file: + response = self.openai_client.images.create_variation( + image=image_file, + n=1, + size="1024x1024" + ) os.remove("temp.png") image_url = response['data'][0]['url'] image_filename = wget.download(image_url) await ctx.send(file=discord.File(image_filename)) os.remove(image_filename) + + # todo: add to database else: settings.logger.info(f"User {ctx.author} is blacklisted from AI cog!") @@ -204,6 +216,7 @@ async def list_assistants(self, ctx): @commands.command(pass_context=True, aliases=["cra", "createassistant"], brief="Create an assistant from a prompt using openai") + @commands.cooldown(1, 60, commands.BucketType.user) async def create_assistant(self, ctx, name, *args): """ Create an
assistant from a prompt using openai @@ -368,6 +381,7 @@ async def handle_tool_call(self, ctx, run, thread_id): @commands.command(pass_context=True, aliases=["ca", "chatassistant"], brief="chat with an assistant using openai") + @commands.cooldown(1, 60, commands.BucketType.user) async def chat_assistant(self, ctx, name, *args): """ Chat with an assistant using openai @@ -401,26 +415,35 @@ async def chat_assistant(self, ctx, name, *args): self.active_threads[guild] = {name: thread} thread_id = self.active_threads[guild][name].id - self.openai_client.beta.threads.messages.create( + # Wrap blocking message create call in asyncio.to_thread + await asyncio.to_thread( + self.openai_client.beta.threads.messages.create, thread_id=thread_id, role="user", content=prompt ) - run = self.openai_client.beta.threads.runs.create( + # Wrap blocking run create call in asyncio.to_thread + run = await asyncio.to_thread( + self.openai_client.beta.threads.runs.create, thread_id=thread_id, assistant_id=assistant.id ) start_time = run.created_at while True: - run = self.openai_client.beta.threads.runs.retrieve( + # Wrap blocking OpenAI API call in asyncio.to_thread to prevent event loop blocking + run = await asyncio.to_thread( + self.openai_client.beta.threads.runs.retrieve, thread_id=thread_id, run_id=run.id ) current_time = time.time() if current_time - start_time > 60: + # TODO: if the assistant times out, deduct from the user's usage, then cancel the run await ctx.send("Assistant time out - cancelling") - run = self.openai_client.beta.threads.runs.cancel( + # Wrap blocking cancel call in asyncio.to_thread + run = await asyncio.to_thread( + self.openai_client.beta.threads.runs.cancel, thread_id=thread_id, run_id=run.id ) @@ -430,7 +453,7 @@ async def chat_assistant(self, ctx, name, *args): await ctx.send("Error cancelling assistant run") return - if "completed" in run.status: + if run.status == "completed": break elif run.status == "queued": pass @@ -455,7 +478,9 @@ async def 
chat_assistant(self, ctx, name, *args): return await asyncio.sleep(2) - messages = self.openai_client.beta.threads.messages.list( + # Wrap blocking messages list call in asyncio.to_thread + messages = await asyncio.to_thread( + self.openai_client.beta.threads.messages.list, thread_id=thread_id ) @@ -466,7 +491,11 @@ async def chat_assistant(self, ctx, name, *args): await ctx.send(line) # retrieve image file else: - image_data = self.openai_client.files.content(content.image_file.file_id) + # Wrap blocking file content call in asyncio.to_thread + image_data = await asyncio.to_thread( + self.openai_client.files.content, + content.image_file.file_id + ) image_data_bytes = image_data.read() image_filename = "./my-image.png" @@ -490,11 +519,22 @@ async def openai_ban(self, ctx, *user): :param user: User to ban :return: None """ - if ctx.author in settings.info_json["admins"]: - blacklist.append(str(user).strip().lower()) - await ctx.send(f"{user} has been banned from using the openai cog") + # Use Discord user IDs instead of string names to prevent spoofing + admin_ids = settings.info_json.get("admin_ids", []) + admin_strings = settings.info_json.get("admins", []) + is_admin = ctx.author.id in admin_ids or str(ctx.author) in admin_strings + + if is_admin: + user_str = str(user).strip().lower() + if user_str not in blacklist: + blacklist.append(user_str) + save_blacklist(blacklist) + await ctx.send(f"{user} has been banned from using the openai cog") + else: + await ctx.send(f"{user} is already banned") else: - await ctx.send(f"{ctx.author} is not an admin and cannot ban someone from using the openai cog") + await ctx.send(f"You do not have permission to run this command") + settings.logger.warning(f"Unauthorized openai_ban attempt by {ctx.author}") @commands.command(pass_context=True, aliases=["openai_unban_user", "openai_unbanuser"], brief="Unban a user from using the openai cog") @@ -505,11 +545,22 @@ async def openai_unban(self, ctx, *user): :param user: User to 
unban :return: None """ - if ctx.author in settings.info_json["admins"]: - blacklist.remove(str(user).strip().lower()) - await ctx.send(f"{user} has been unbanned from using the openai cog") + # Use Discord user IDs instead of string names to prevent spoofing + admin_ids = settings.info_json.get("admin_ids", []) + admin_strings = settings.info_json.get("admins", []) + is_admin = ctx.author.id in admin_ids or str(ctx.author) in admin_strings + + if is_admin: + user_str = str(user).strip().lower() + if user_str in blacklist: + blacklist.remove(user_str) + save_blacklist(blacklist) + await ctx.send(f"{user} has been unbanned from using the openai cog") + else: + await ctx.send(f"{user} is not in the blacklist") else: - await ctx.send(f"{user} is not an admin and cannot be unbanned from using the openai cog") + await ctx.send(f"You do not have permission to run this command") + settings.logger.warning(f"Unauthorized openai_unban attempt by {ctx.author}") async def setup(client): diff --git a/cogs/twitterCog.py b/cogs/twitterCog.py index 67dcad5..1c23b31 100644 --- a/cogs/twitterCog.py +++ b/cogs/twitterCog.py @@ -23,10 +23,19 @@ def __init__(self, client): """ self.client = client - self.auth = OAuthHandler(settings.info_json["twitter"]["apikey"], settings.info_json["twitter"]["apisecret"]) - self.auth.set_access_token(settings.info_json["twitter"]["accesstoken"], - settings.info_json["twitter"]["accesstokensecret"]) - self.auth_api = API(self.auth) + # Use environment variables for Twitter credentials with JSON fallback + api_key = os.environ.get('TWITTER_API_KEY', settings.info_json.get("twitter", {}).get("apikey")) + api_secret = os.environ.get('TWITTER_API_SECRET', settings.info_json.get("twitter", {}).get("apisecret")) + access_token = os.environ.get('TWITTER_ACCESS_TOKEN', settings.info_json.get("twitter", {}).get("accesstoken")) + access_secret = os.environ.get('TWITTER_ACCESS_SECRET', settings.info_json.get("twitter", {}).get("accesstokensecret")) + + if not 
all([api_key, api_secret, access_token, access_secret]): + settings.logger.warning("Twitter credentials not fully configured in environment variables or config file") + self.auth_api = None + else: + self.auth = OAuthHandler(api_key, api_secret) + self.auth.set_access_token(access_token, access_secret) + self.auth_api = API(self.auth) @commands.Cog.listener() async def on_ready(self): @@ -43,6 +52,10 @@ async def factbot(self, ctx): :param ctx: Context of the command :return: None """ + if not self.auth_api: + await ctx.send("Twitter API not configured") + return + filename = "factbot.jpg" settings.logger.info(f"factbot : {ctx.author}") await self.get_last_tweet_image("@factbot1", save_as=filename) @@ -64,24 +77,30 @@ async def get_last_tweet_image(self, username, save_as="image.jpg"): :param save_as: filename to save the image as :return: None """ - tweets = self.auth_api.user_timeline(screen_name=username, count=1, include_rts=False, - exclude_replies=True) - tmp = [] - tweets_for_csv = [tweet.text for tweet in tweets] # CSV file created - for j in tweets_for_csv: - # Appending tweets to the empty array tmp - tmp.append(j) - print(tmp) - media_files = set() - for status in tweets: - media = status.entities.get('media', []) - if len(media) > 0: - media_files.add(media[0]['media_url']) - for media_file in media_files: - if save_as.endswith(".jpg") or save_as.endswith(".png"): - wget.download(media_file, save_as) - else: - wget.download(media_file, "image.jpg") + try: + tweets = self.auth_api.user_timeline(screen_name=username, count=1, include_rts=False, + exclude_replies=True) + tmp = [] + tweets_for_csv = [tweet.text for tweet in tweets] # CSV file created + for j in tweets_for_csv: + # Appending tweets to the empty array tmp + tmp.append(j) + settings.logger.debug(f"Tweet data: {tmp}") + media_files = set() + for status in tweets: + media = status.entities.get('media', []) + if len(media) > 0: + media_files.add(media[0]['media_url']) + for media_file in 
media_files: + try: + if save_as.endswith(".jpg") or save_as.endswith(".png"): + wget.download(media_file, save_as) + else: + wget.download(media_file, "image.jpg") + except Exception as e: + settings.logger.error(f"Download failed: {e}") + except Exception as e: + settings.logger.error(f"Twitter API error: {e}") async def setup(client): diff --git a/cogs/voiceCog.py b/cogs/voiceCog.py new file mode 100644 index 0000000..f019ba1 --- /dev/null +++ b/cogs/voiceCog.py @@ -0,0 +1,253 @@ +import discord +import settings +from discord.ext import commands +import requests +import os +import time +import json +import asyncio +import re +from urllib.parse import quote, urlparse + + +class Voices(commands.Cog): + """ + Functions: + add_voice: this function adds a voice to the database (NOTE: any clips attached to the message will be included) + arg: voice_name: str + add_clip: + arg: voice_name: str + make clip: this function creates new clips with the given text and voice, this function will also create a new message, + the message will inform the user that the clip is being processed, the message will be updated when the clip is ready. + arg: voice_name: str + arg: text: str + """ + + def __init__(self, client): + """ + Constructor for the voices cog + :param client: Client object + """ + self.client = client + # Get voice API URL from config, fallback to localhost + raw_url = settings.info_json.get("voice_api", {}).get("url", "http://localhost:8000") + + # Validate and sanitize the API URL + if not self._validate_api_url(raw_url): + settings.logger.error(f"Invalid voice API URL configured: {raw_url}. 
Using localhost fallback.") + self.voice_api_url = "http://localhost:8000" + else: + self.voice_api_url = raw_url.rstrip('/') + + # Request timeout in seconds to prevent DoS + self.request_timeout = 30 + + def _validate_api_url(self, url: str) -> bool: + """ + Validates the API URL to prevent SSRF attacks + :param url: URL to validate + :return: True if valid, False otherwise + """ + try: + parsed = urlparse(url) + # Only allow http and https schemes + if parsed.scheme not in ['http', 'https']: + settings.logger.warning(f"Invalid URL scheme: {parsed.scheme}") + return False + + # Ensure hostname is present + if not parsed.netloc: + settings.logger.warning("URL missing hostname") + return False + + # Block localhost variations, internal IPs (for production security) + # Uncomment these checks in production: + # blocked_hosts = ['127.', '0.0.0.0', 'localhost', '10.', '172.16.', '192.168.', '169.254.'] + # if any(parsed.netloc.startswith(blocked) for blocked in blocked_hosts): + # settings.logger.warning(f"Blocked internal/localhost URL: {parsed.netloc}") + # return False + + return True + except Exception as e: + settings.logger.error(f"Error validating URL: {e}") + return False + + def _sanitize_voice_name(self, voice_name: str) -> str: + """ + Sanitizes voice name to prevent injection attacks + :param voice_name: Voice name to sanitize + :return: Sanitized voice name or None if invalid + """ + # Only allow alphanumeric, underscore, and hyphen + if not re.match(r'^[a-zA-Z0-9_-]+$', voice_name): + return None + # Limit length + if len(voice_name) > 50: + return None + return voice_name + + def _sanitize_uuid(self, uuid: str) -> str: + """ + Sanitizes UUID to prevent injection attacks + :param uuid: UUID to sanitize + :return: Sanitized UUID or None if invalid + """ + # UUID format validation + if not re.match(r'^[a-f0-9-]+$', uuid, re.IGNORECASE): + return None + # Limit length + if len(uuid) > 36: + return None + return uuid + + @commands.command(pass_context=True, 
aliases=['av'], brief='Adds a voice', help='Adds a voice') + async def add_voice(self, ctx, voice_name: str): + """ + Adds a voice to the database + :param ctx: context + :param voice_name: voice name + """ + # Sanitize voice name to prevent injection + sanitized_name = self._sanitize_voice_name(voice_name) + if not sanitized_name: + await ctx.send(f'Invalid voice name. Use only alphanumeric characters, underscores, and hyphens.') + settings.logger.warning(f"Invalid voice name attempted by {ctx.author}: {voice_name}") + return + + # try to make a voice + # Use URL encoding for safety + r = requests.put(f'{self.voice_api_url}/new_voice?name={quote(sanitized_name)}', timeout=self.request_timeout) + if r.status_code == 200: + await ctx.send(f'Voice {sanitized_name} added') + else: + await ctx.send(f'Failed to add voice {sanitized_name}') + + if ctx.message.attachments: + # TODO: add the clips to the database + pass + + @commands.command(pass_context=True, aliases=['ac'], brief='Adds a clip', help='Adds a clip') + async def add_clip(self, ctx, voice_name: str): + """ + Adds a clip to the database + :param ctx: context + :param voice_name: voice name + """ + # Sanitize voice name to prevent injection + sanitized_name = self._sanitize_voice_name(voice_name) + if not sanitized_name: + await ctx.send(f'Invalid voice name. 
Use only alphanumeric characters, underscores, and hyphens.') + settings.logger.warning(f"Invalid voice name attempted by {ctx.author}: {voice_name}") + return + + if not ctx.message.attachments: + await ctx.send('No clip attached') + return + + # mk temp dir + if not os.path.exists('temp'): + os.makedirs('temp') + # download clip + for f in ctx.message.attachments: + await f.save(f'temp/{f.filename}') + + # upload clip to server + with open(f'temp/{f.filename}', 'rb') as file: + files = {'file': file} + # try to make a clip - use URL encoding for safety + r = requests.put(f'{self.voice_api_url}/new_clip?voice_name={quote(sanitized_name)}', files=files, timeout=self.request_timeout) + + if r.status_code == 200: + await ctx.send(f'Clip {f.filename} added to voice {sanitized_name}') + else: + await ctx.send(f'Failed to add clip {f.filename} to voice {sanitized_name}') + + @commands.command(pass_context=True, aliases=['mc'], brief='Makes a clip', help='Makes a clip') + async def make_clip(self, ctx, voice_name: str, *text: str): + """ + Makes a clip with the given text and voice + :param ctx: context + :param voice_name: voice name + :param text: text + """ + # Sanitize voice name to prevent injection + sanitized_name = self._sanitize_voice_name(voice_name) + if not sanitized_name: + await ctx.send(f'Invalid voice name. 
Use only alphanumeric characters, underscores, and hyphens.') + settings.logger.warning(f"Invalid voice name attempted by {ctx.author}: {voice_name}") + return + + # create data - use sanitized name; *text arrives as separate words, so re-join them with spaces + data = {'model': sanitized_name, 'text': ' '.join(text), 'preset': "standard", "candidates": 1} + json_data = json.dumps(data) + + # make request in a worker thread so the blocking HTTP call does not stall the event loop + r = await asyncio.to_thread(requests.put, f'{self.voice_api_url}/gen_voice', data=json_data, timeout=self.request_timeout) + + # check if request was successful + if r.status_code == 200: + await ctx.send(f'Clip is being processed') + else: + await ctx.send(f'Failed to make clip: {r.status_code} {r.text}') + return + + # get uuid + try: + uuid = r.json()["uuid"] + except (KeyError, json.JSONDecodeError) as e: + await ctx.send(f'Invalid response from voice API') + settings.logger.error(f"Invalid response from voice API: {e}") + return + + # Sanitize UUID to prevent injection + sanitized_uuid = self._sanitize_uuid(uuid) + if not sanitized_uuid: + await ctx.send(f'Invalid UUID received from API') + settings.logger.error(f"Invalid UUID from API: {uuid}") + return + + # get start time + start_time = time.time() + + while True: + r = await asyncio.to_thread(requests.get, f'{self.voice_api_url}/get_clip?uid={quote(sanitized_uuid)}&clip=0', timeout=self.request_timeout) + # TODO: schedule a task to check every few seconds so the bot can do other things + if r.status_code == 200: + # download clip + if not os.path.exists("voices"): + os.mkdir("voices") + with open(f'voices/{sanitized_uuid}.wav', 'wb') as f: + f.write(r.content) + await ctx.send(f'Clip ready', file=discord.File(f'voices/{sanitized_uuid}.wav')) + # TODO: fix this + # await ctx.author.voice.channel.connect() + # source = discord.PCMVolumeTransformer( + # discord.FFmpegPCMAudio(source=f"{f'voices/{sanitized_uuid}.wav'}"), volume=1.0) + # ctx.voice_client.play(source) + break + if time.time() - start_time > 180: + await ctx.send(f'Clip timed out') + break + await asyncio.sleep(5) + + @commands.command(pass_context=True,
aliases=['lv'], brief='Lists voices', help='Lists voices') + async def list_voices(self, ctx): + """ + Lists voices + :param ctx: context + """ + r = requests.get(f'{self.voice_api_url}/get_voices', timeout=self.request_timeout) + if r.status_code == 200: + voices = r.json()["voices"] + await ctx.send('Voices:\n' + "\n".join(voices)) + else: + await ctx.send(f'Failed to list voices') + + +async def setup(client): + """ + Sets up the cog + :param client: Client object + :return: None + """ + await client.add_cog(Voices(client)) diff --git a/db/openai_database_manager.py b/db/openai_database_manager.py new file mode 100644 index 0000000..63861da --- /dev/null +++ b/db/openai_database_manager.py @@ -0,0 +1,369 @@ +import mysql.connector +from mysql.connector import errorcode +from datetime import datetime + +import logging + + +class OpenAIDatabaseManager: + """ + This class is for managing the openai database + """ + + def __init__(self, db_host: str, db_username: str, db_password: str, database_name: str): + """ + Constructor for the openai database manager + :param db_host: Database host + :param db_username: Database username + :param db_password: Database password + :param database_name: Database name + """ + self.db = None + self.my_cursor = None + + self.db_host = db_host + self.db_username = db_username + self.db_password = db_password + self.database_name = database_name + + def connect(self): + """Connects to the database""" + try: + self.db = mysql.connector.connect( + host=self.db_host, + user=self.db_username, + password=self.db_password, + database=self.database_name + ) + self.my_cursor = self.db.cursor() + except mysql.connector.Error as e: + if e.errno == errorcode.ER_ACCESS_DENIED_ERROR: + # settings.logger.warning("Soundboard username or password is Bad") + print("Soundboard user name or password is Bad") + elif e.errno == errorcode.ER_BAD_DB_ERROR: + # settings.logger.warning("Database does not exist") + print("Database does not exist") + else: + # 
settings.logger.warning(e) + print(e) + + def disconnect(self): + """Disconnects from the database""" + self.my_cursor.close() + self.db.close() + + # user + def add_user(self, user): + """ + Adds a new user to the database + :param user: User to add + :return: None + :raises: ValueError if the user already exists + :raises: ConnectionError if the user cannot be added because the database is not connected + :raises: Exception if the user cannot be added for any other reason + """ + if not self.db or not self.my_cursor: + raise ConnectionError("Database not connected") + + try: + sql = "INSERT INTO usernames (username, date_created) VALUES (%s, %s)" + val = (user, datetime.now()) + self.my_cursor.execute(sql, val) + self.db.commit() + except mysql.connector.errors.IntegrityError: + raise ValueError(f"User {user} already exists") + except mysql.connector.Error as e: + if e.errno == errorcode.CR_SERVER_GONE_ERROR: + self.connect() + # Retry once + self.my_cursor.execute(sql, val) + self.db.commit() + else: + raise ConnectionError(f"Database error: {e.errno}") + + def get_users(self): + """ + Gets the users from the database + :return: List of users + :raises: ConnectionError if the user cannot be added because the database is not connected + :raises: Exception if the list of users cannot be retrieved for any other reason + """ + if not self.db or not self.my_cursor: + raise ConnectionError("Database not connected") + + try: + sql = "SELECT username FROM usernames" + self.my_cursor.execute(sql) + result = self.my_cursor.fetchall() + return [row[0] for row in result] + except mysql.connector.Error as e: + if e.errno == errorcode.CR_SERVER_GONE_ERROR: + self.connect() + # Retry once + self.my_cursor.execute(sql) + result = self.my_cursor.fetchall() + return [row[0] for row in result] + else: + raise ConnectionError(f"Database error: {e.errno}") + + def remove_user(self, user): + """ + Removes a user from the database + :param user: User to remove + :return: None + 
:raises: ValueError if the user does not exist + :raises: ConnectionError if the user cannot be added because the database is not connected + :raises: Exception if the user cannot be removed for any other reason + """ + if not self.db or not self.my_cursor: + raise ConnectionError("Database not connected") + + try: + sql = "DELETE FROM usernames WHERE username = %s" + adr = (user,) + self.my_cursor.execute(sql, adr) + if self.my_cursor.rowcount == 0: + raise ValueError(f"User {user} does not exist") + self.db.commit() + except mysql.connector.Error as e: + if e.errno == errorcode.CR_SERVER_GONE_ERROR: + self.connect() + # Retry once + self.my_cursor.execute(sql, adr) + if self.my_cursor.rowcount == 0: + raise ValueError(f"User {user} does not exist") + self.db.commit() + else: + raise ConnectionError(f"Database error: {e.errno}") + + + # blacklist + def add_blacklist(self, blacklist): + """ + Adds a new blacklist to the database + :param blacklist: Blacklist to add + :return: None + :raises: ValueError if the blacklist already exists + :raises: ConnectionError if the blacklist cannot be added because the database is not connected + :raises: Exception if the blacklist cannot be added for any other reason + """ + if not self.db or not self.my_cursor: + raise ConnectionError("Database not connected") + + try: + sql = "INSERT INTO blacklists (blacklist, date_created) VALUES (%s, %s)" + val = (blacklist, datetime.now()) + self.my_cursor.execute(sql, val) + self.db.commit() + except mysql.connector.errors.IntegrityError: + raise ValueError(f"Blacklist entry {blacklist} already exists") + except mysql.connector.Error as e: + if e.errno == errorcode.CR_SERVER_GONE_ERROR: + self.connect() + # Retry once + self.my_cursor.execute(sql, val) + self.db.commit() + else: + raise ConnectionError(f"Database error: {e.errno}") + + def blacklisted(self, user): + """ + Checks if a user is blacklisted + :param user: User to check + :return: True if the user is blacklisted, False 
otherwise + :raises: ConnectionError if the blacklist cannot be added because the database is not connected + :raises: Exception if the blacklist cannot be added for any other reason + """ + # TODO: add error handling + try: + blacklist = self.get_blacklists() + if user in blacklist: + return True + return False + except Exception as e: + raise e + + def get_blacklists(self): + """ + Gets the blacklists from the database + :return: List of blacklisted users + :raises: ConnectionError if the blacklist cannot be added because the database is not connected + :raises: Exception if the list of blacklists cannot be retrieved for any other reason + """ + if not self.db or not self.my_cursor: + raise ConnectionError("Database not connected") + + try: + sql = "SELECT blacklist FROM blacklists" + self.my_cursor.execute(sql) + result = self.my_cursor.fetchall() + return [row[0] for row in result] + except mysql.connector.Error as e: + if e.errno == errorcode.CR_SERVER_GONE_ERROR: + self.connect() + # Retry once + self.my_cursor.execute(sql) + result = self.my_cursor.fetchall() + return [row[0] for row in result] + else: + raise ConnectionError(f"Database error: {e.errno}") + + def remove_blacklist(self, blacklist): + """ + Removes a blacklist from the database + :param blacklist: Blacklist to remove + :return: None + :raises: ValueError if the blacklist does not exist + :raises: ConnectionError if the blacklist cannot be added because the database is not connected + :raises: Exception if the blacklist cannot be removed for any other reason + """ + if not self.db or not self.my_cursor: + raise ConnectionError("Database not connected") + + try: + sql = "DELETE FROM blacklists WHERE blacklist = %s" + adr = (blacklist,) + self.my_cursor.execute(sql, adr) + if self.my_cursor.rowcount == 0: + raise ValueError(f"Blacklist entry {blacklist} does not exist") + self.db.commit() + except mysql.connector.Error as e: + if e.errno == errorcode.CR_SERVER_GONE_ERROR: + self.connect() + # 
Retry once + self.my_cursor.execute(sql, adr) + if self.my_cursor.rowcount == 0: + raise ValueError(f"Blacklist entry {blacklist} does not exist") + self.db.commit() + else: + raise ConnectionError(f"Database error: {e.errno}") + + # assistants + def add_assistant(self, assistant): + """ + Adds an assistant to the database + :param assistant: Assistant to add + :return: None + :raises: ValueError if the assistant already exists + :raises: ConnectionError if the assistant cannot be added because the database is not connected + :raises: Exception if the assistant cannot be added for any other reason + """ + # TODO: add error handling + pass + + def get_assistants(self): + """ + Gets the assistants from the database + :return: List of assistants + :raises: ConnectionError if the assistant cannot be added because the database is not connected + :raises: Exception if the list of assistants cannot be retrieved for any other reason + """ + # TODO: add error handling + pass + + def remove_assistant(self, assistant): + """ + Removes an assistant from the database + :param assistant: Assistant to remove + :return: None + :raises: ValueError if the assistant does not exist + :raises: ConnectionError if the assistant cannot be added because the database is not connected + :raises: Exception if the assistant cannot be removed for any other reason + """ + # TODO: add error handling + pass + + # threads + def add_thread(self, thread): + """ + Adds a thread to the database + :param thread: Thread to add + :return: None + :raises: ValueError if the thread already exists + :raises: ConnectionError if the thread cannot be added because the database is not connected + :raises: Exception if the thread cannot be added for any other reason + """ + # TODO: add error handling + pass + + def get_threads(self): + """ + Gets the threads from the database + :return: List of threads + :raises: ConnectionError if the thread cannot be added because the database is not connected + :raises: 
Exception if the list of threads cannot be retrieved for any other reason + """ + # TODO: add error handling + pass + + def remove_thread(self, thread): + """ + Removes a thread from the database + :param thread: Thread to remove + :return: None + :raises: ValueError if the thread does not exist + :raises: ConnectionError if the thread cannot be added because the database is not connected + :raises: Exception if the thread cannot be removed for any other reason + """ + # TODO: add error handling + pass + + def update_usage(self, username, function, count=1, month=None, year=None): + """ + Updates the usage of a user for a function, adding 1 to the count for the current month and year or the given + year, if the user has not used the function that month and year before it will be added to the database. + :param: username: Username to update + :param: function: Function to update + :param: count: Count to add + :param: month: Month to update + :param: year: Year to update + :return: None + :raises: ValueError if the assistant or thread does not exist + :raises: ConnectionError if the thread cannot be added because the database is not connected + :raises: Exception if the thread cannot be removed for any other reason + """ + # TODO: add error handling + pass + + def get_usage(self, username, function, month=None, year=None): + """ + Gets the usage of a user for a function for the current month and year or the given month and year + :param: username: Username to get + :param: function: Function to get + :param: month: Month to get Default: Current month + :param: year: Year to get Default: Current year + :return: Int number of times the user has used the function + :raises: ValueError if the assistant or thread does not exist + :raises: ConnectionError if the thread cannot be added because the database is not connected + :raises: Exception if the thread cannot be removed for any other reason + """ + # TODO: add error handling + pass + + def get_user_functions(self, 
username, month=None, year=None): + """ + Gets the functions a user has used for the current month and year or the given month and year + :param username: Username to get + :param month: Month to get Default: Current month + :param year: Year to get Default: Current year + :return: List of functions the user has used + :raises: ValueError if the user does not exist + :raises: ConnectionError if the functions cannot be retrieved because the database is not connected + :raises: Exception if the functions cannot be retrieved for any other reason + """ + # TODO: add error handling + pass + + def get_function_users(self, function, month=None, year=None): + """ + Gets the users that have used a function for the current month and year or the given month and year + :param function: Function to get + :param month: Month to get Default: Current month + :param year: Year to get Default: Current year + :return: List of users that have used the function + :raises: ValueError if the function does not exist + :raises: ConnectionError if the users cannot be retrieved because the database is not connected + :raises: Exception if the users cannot be retrieved for any other reason + """ + # TODO: add error handling + pass diff --git a/info/.gitignore b/info/.gitignore index c96a04f..a3a0c8b 100644 --- a/info/.gitignore +++ b/info/.gitignore @@ -1,2 +1,2 @@ -* +* !.gitignore \ No newline at end of file diff --git a/notes/OpenAIUsageTrackerPlan.md b/notes/OpenAIUsageTrackerPlan.md new file mode 100644 index 0000000..7564783 --- /dev/null +++ b/notes/OpenAIUsageTrackerPlan.md @@ -0,0 +1,80 @@ +# Plan for OpenAI Usage Tracker + +## Database schema +database name: openai + +### usernames_table +a table to store the usernames of the users +- id: int (primary key) the id of the record +- username: string the username of the user +- date_created: date the date the user was created + +### cost_table +a table to store the cost of each endpoint on openai +- endpoint: int (primary key) (foreign key) the id of the endpoint +- date: date the last 
date the price was updated +- cost: float the cost of the endpoint per request in dollars on openai + +### endpoint_table +a table to store the name of the endpoints +- id: int (primary key) the id of the record +- endpoint_name: string (unique) the name of the endpoint + + initial data: + openai_images_generate + openai_images_variation + openai_assistant_create + openai_assistant_submit_tool_run + openai_assistant_thread_create + openai_assistant_run + +### assistant_table +list of assistants created by users, including the creator username, assistant id, and the date created +- id: int (primary key) the id of the record +- creator_id: int (foreign key) the id of the user who created the assistant +- assistant_id: string (unique) the id of the assistant + +### thread_table +list of threads created by users, including the creator username, thread id, and the date created +- id: int (primary key) the id of the record +- creator_username: string (foreign key) the username of the creator +- thread_id: string (unique) the id of the thread + +### endpoint_usage_table +a table to store the usage of each endpoint on openai +- id: int (primary key) the id of the record +- endpoint: string (foreign key) the name of the endpoint +- date: date the date of the last usage +- usage: int the number of requests made on that date +- cost: float the cost of the requests made on that date + +### endpoint_user_usage_table +a table to store the usage of each endpoint by each user +- id: int (primary key) the id of the record +- endpoint: string (foreign key) the name of the endpoint +- username: string (foreign key) the username of the user +- date: date the date of the last usage +- usage: int the number of requests made on that date +- cost: float the cost of the requests made on that date + +### blacklisted_user_table +a table to store the blacklisted users +- id: int (primary key) the id of the record +- user_id: int (foreign key) the user_id of the user +- date: date the date the user was 
blacklisted +- reason: string the reason the user was blacklisted +- blacklisted_by: int (foreign key) the user_id of the user who blacklisted the user +- blacklisted_until: date the date the user will be blacklisted until + +### rate_limited_user_table +a table to store the rate limited users +- id: int (primary key) the id of the record +- user_id: int (foreign key) the user_id of the user +- date: date the user was rate limited +- reason: string the reason the user was rate limited +- rate_limited_by: int (foreign key) the user_id of the user who rate limited the user +- rate_limited_until: date the date the user will be rate limited until +- rate_limit: int the number of requests the user is limited to + + + diff --git a/settings.py b/settings.py index f1b9d33..94ae1be 100644 --- a/settings.py +++ b/settings.py @@ -47,11 +47,30 @@ def init(args): # global info_json - with open(args.json, 'r') as f: - info_json = json.load(f) - f.close() + try: + with open(args.json, 'r') as f: + info_json = json.load(f) + f.close() + except FileNotFoundError: + raise FileNotFoundError( + f"Configuration file not found at '{args.json}'. " + f"Please ensure the info.json file exists at the specified path." + ) + except json.JSONDecodeError as e: + raise ValueError( + f"Configuration file '{args.json}' contains invalid JSON: {e}. " + f"Please check the file for syntax errors." + ) + except PermissionError: + raise PermissionError( + f"Cannot read configuration file '{args.json}': Permission denied. " + f"Please check file permissions." 
+ ) global token - token = info_json["token"] + # Use environment variable for token if available, otherwise fallback to JSON + token = os.environ.get('DISCORD_BOT_TOKEN', info_json.get("token")) + if not token: + raise ValueError("Discord bot token not found in environment variables or config file") # # @@ -65,17 +84,18 @@ def init(args): if args.db_username is None: db_username = info_json['soundboard_database']['username'] else: - db_username = args.db_host + db_username = args.db_username if args.db_password is None: - db_password = info_json['soundboard_database']['password'] + # Use environment variable for DB password if available, otherwise fallback to JSON + db_password = os.environ.get('DB_PASSWORD', info_json['soundboard_database']['password']) else: - db_password = args.db_host + db_password = args.db_password if args.db_name is None: db_name = info_json['soundboard_database']['database'] else: - db_name = args.db_host + db_name = args.db_name soundboard_db = SoundboardDBManager(db_host=host, db_username=db_username, db_password=db_password, database_name=db_name) @@ -136,13 +156,12 @@ def add_db_entry(self, filename: str, name: str): self.my_cursor.execute(sql, val) self.db.commit() logger.info(f"adding sound to db filename:{filename} name:{name}") - except Exception as e: - logger.warning(f"unknown exception while adding to db!") - logger.warning(e) - - except Exception as e: - logger.warning("unknown exception while adding to db!") - logger.warning(e) + except mysql.connector.Error as retry_error: + logger.error(f"Database error during retry: {retry_error}") + raise + else: + logger.error(f"Database error: {e}") + raise def remove_db_entry(self, filename: str): """Removes database entry for the given filename""" @@ -166,16 +185,12 @@ def remove_db_entry(self, filename: str): self.db.commit() logger.info(f"removed sound from db filename:{filename}") - except Exception as e: - logger.warning(f"unknown exception while adding to db!") - logger.warning(e) 
+ except mysql.connector.Error as retry_error: + logger.error(f"Database error during retry while removing: {retry_error}") + raise else: - logger.warning(f"unknown exception while adding to db!") - logger.warning(e) - - except Exception as e: - logger.warning("unknown exception while removing from db!") - logger.warning(e) + logger.error(f"Database error while removing: {e}") + raise def list_db_files(self): """Returns a list of database entries""" @@ -197,17 +212,12 @@ def list_db_files(self): my_result = self.my_cursor.fetchall() return my_result - except Exception as e: - logger.warning(f"List_db_file inner unknown exception while listing db!") - logger.warning(e) - + except mysql.connector.Error as retry_error: + logger.error(f"Database error during retry while listing: {retry_error}") + raise else: - logger.warning(f"List_db_file unknown my sql exception while listing db!") - logger.warning(e) - - except Exception as e: - logger.warning("List_db_file unknown exception while listing db!") - logger.warning(e) + logger.error(f"Database error while listing: {e}") + raise def verify_db(self): """Checks database against files on server and manages database accordingly @@ -225,24 +235,25 @@ def verify_db(self): self.connect() db_files = self.list_db_files() - except Exception as e: - logger.warning(f"unknown exception while adding to db!") - logger.warning(e) + except mysql.connector.Error as retry_error: + logger.error(f"Database error during retry while verifying: {retry_error}") + return else: - logger.warning(f"unknown exception while verifying db!") - logger.warning(e) - - except Exception as e: - logger.warning(f"unknown exception while verifying db!") - logger.warning(e) + logger.error(f"Database error while verifying: {e}") + return try: + # Find files in soundboard directory + soundboard_files = set() for file in os.listdir("./soundboard"): if file.endswith(".mp3"): - for temp in db_files: - if temp[0] == file: - db_files.remove(temp) + 
soundboard_files.add(file) - for file in db_files: + # Build list of database entries to remove (not in filesystem) + # Use list comprehension instead of modifying list during iteration + files_to_remove = [db_file for db_file in db_files if db_file[0] not in soundboard_files] + + # Remove orphaned database entries + for file in files_to_remove: self.remove_db_entry(file[1]) for file in os.listdir("./soundboard"): @@ -254,9 +265,10 @@ def verify_db(self): self.add_db_entry(file.lower(), file.replace(".mp3", "").lower()) except ValueError: continue - except Exception as e: - logger.warning(f"unknown exception while verifying db!") - logger.warning(e) + except OSError as e: + logger.error(f"File system error while verifying db: {e}") + except mysql.connector.Error as e: + logger.error(f"Database error while verifying db: {e}") def add_to_json(filename, json_data, tag, data):