diff --git a/jarvis/cogs/reddit.py b/jarvis/cogs/reddit.py index b69b4c1..9ecdf87 100644 --- a/jarvis/cogs/reddit.py +++ b/jarvis/cogs/reddit.py @@ -1,6 +1,7 @@ """JARVIS Reddit cog.""" import asyncio import logging +import re from typing import List, Optional from asyncpraw import Reddit @@ -28,6 +29,7 @@ from jarvis.utils import build_embed from jarvis.utils.permissions import admin_or_permissions DEFAULT_USER_AGENT = f"python:JARVIS:{const.__version__} (by u/zevaryx)" +sub_name = re.compile(r"\A[A-Za-z0-9][A-Za-z0-9_]{2,20}\Z") class RedditCog(Extension): @@ -135,8 +137,7 @@ class RedditCog(Extension): ) @check(admin_or_permissions(Permissions.MANAGE_GUILD)) async def _reddit_follow(self, ctx: InteractionContext, name: str, channel: GuildText) -> None: - name = name.replace("r/", "") - if len(name) > 20 or len(name) < 3: + if not sub_name.match(name): await ctx.send("Invalid Subreddit name", ephemeral=True) return @@ -248,12 +249,11 @@ class RedditCog(Extension): name="name", description="Subreddit name", opt_type=OptionTypes.STRING, required=True ) async def _subreddit_hot(self, ctx: InteractionContext, name: str) -> None: - await ctx.defer() - name = name.replace("r/", "") - if len(name) > 20 or len(name) < 3: + if not sub_name.match(name): await ctx.send("Invalid Subreddit name", ephemeral=True) return try: + await ctx.defer() subreddit = await self.api.subreddit(name) await subreddit.load() except (NotFound, Forbidden, Redirect) as e: @@ -300,12 +300,11 @@ class RedditCog(Extension): ], ) async def _subreddit_top(self, ctx: InteractionContext, name: str, time: str = "all") -> None: - await ctx.defer() - name = name.replace("r/", "") - if len(name) > 20 or len(name) < 3: + if not sub_name.match(name): await ctx.send("Invalid Subreddit name", ephemeral=True) return try: + await ctx.defer() subreddit = await self.api.subreddit(name) await subreddit.load() except (NotFound, Forbidden, Redirect) as e: @@ -340,12 +339,11 @@ class RedditCog(Extension): name="name", description="Subreddit name", opt_type=OptionTypes.STRING, required=True ) async def _subreddit_random(self, ctx: InteractionContext, name: str) -> None: - await ctx.defer() - name = name.replace("r/", "") - if len(name) > 20 or len(name) < 3: + if not sub_name.match(name): await ctx.send("Invalid Subreddit name", ephemeral=True) return try: + await ctx.defer() subreddit = await self.api.subreddit(name) await subreddit.load() except (NotFound, Forbidden, Redirect) as e: @@ -380,12 +378,11 @@ class RedditCog(Extension): name="name", description="Subreddit name", opt_type=OptionTypes.STRING, required=True ) async def _subreddit_rising(self, ctx: InteractionContext, name: str) -> None: - await ctx.defer() - name = name.replace("r/", "") - if len(name) > 20 or len(name) < 3: + if not sub_name.match(name): await ctx.send("Invalid Subreddit name", ephemeral=True) return try: + await ctx.defer() subreddit = await self.api.subreddit(name) await subreddit.load() except (NotFound, Forbidden, Redirect) as e: