Fix subreddit name checking with official regex, closes #141

This commit is contained in:
Zeva Rose 2022-05-29 19:31:32 -06:00
parent 638ae08bdd
commit 9188b13695

View file

@ -1,6 +1,7 @@
"""JARVIS Reddit cog.""" """JARVIS Reddit cog."""
import asyncio import asyncio
import logging import logging
import re
from typing import List, Optional from typing import List, Optional
from asyncpraw import Reddit from asyncpraw import Reddit
@ -28,6 +29,7 @@ from jarvis.utils import build_embed
from jarvis.utils.permissions import admin_or_permissions from jarvis.utils.permissions import admin_or_permissions
DEFAULT_USER_AGENT = f"python:JARVIS:{const.__version__} (by u/zevaryx)" DEFAULT_USER_AGENT = f"python:JARVIS:{const.__version__} (by u/zevaryx)"
sub_name = re.compile(r"\A[A-Za-z0-9][A-Za-z0-9_]{2,20}\Z")
class RedditCog(Extension): class RedditCog(Extension):
@ -135,8 +137,7 @@ class RedditCog(Extension):
) )
@check(admin_or_permissions(Permissions.MANAGE_GUILD)) @check(admin_or_permissions(Permissions.MANAGE_GUILD))
async def _reddit_follow(self, ctx: InteractionContext, name: str, channel: GuildText) -> None: async def _reddit_follow(self, ctx: InteractionContext, name: str, channel: GuildText) -> None:
name = name.replace("r/", "") if not sub_name.match(name):
if len(name) > 20 or len(name) < 3:
await ctx.send("Invalid Subreddit name", ephemeral=True) await ctx.send("Invalid Subreddit name", ephemeral=True)
return return
@ -248,12 +249,11 @@ class RedditCog(Extension):
name="name", description="Subreddit name", opt_type=OptionTypes.STRING, required=True name="name", description="Subreddit name", opt_type=OptionTypes.STRING, required=True
) )
async def _subreddit_hot(self, ctx: InteractionContext, name: str) -> None: async def _subreddit_hot(self, ctx: InteractionContext, name: str) -> None:
await ctx.defer() if not sub_name.match(name):
name = name.replace("r/", "")
if len(name) > 20 or len(name) < 3:
await ctx.send("Invalid Subreddit name", ephemeral=True) await ctx.send("Invalid Subreddit name", ephemeral=True)
return return
try: try:
await ctx.defer()
subreddit = await self.api.subreddit(name) subreddit = await self.api.subreddit(name)
await subreddit.load() await subreddit.load()
except (NotFound, Forbidden, Redirect) as e: except (NotFound, Forbidden, Redirect) as e:
@ -300,12 +300,11 @@ class RedditCog(Extension):
], ],
) )
async def _subreddit_top(self, ctx: InteractionContext, name: str, time: str = "all") -> None: async def _subreddit_top(self, ctx: InteractionContext, name: str, time: str = "all") -> None:
await ctx.defer() if not sub_name.match(name):
name = name.replace("r/", "")
if len(name) > 20 or len(name) < 3:
await ctx.send("Invalid Subreddit name", ephemeral=True) await ctx.send("Invalid Subreddit name", ephemeral=True)
return return
try: try:
await ctx.defer()
subreddit = await self.api.subreddit(name) subreddit = await self.api.subreddit(name)
await subreddit.load() await subreddit.load()
except (NotFound, Forbidden, Redirect) as e: except (NotFound, Forbidden, Redirect) as e:
@ -340,12 +339,11 @@ class RedditCog(Extension):
name="name", description="Subreddit name", opt_type=OptionTypes.STRING, required=True name="name", description="Subreddit name", opt_type=OptionTypes.STRING, required=True
) )
async def _subreddit_random(self, ctx: InteractionContext, name: str) -> None: async def _subreddit_random(self, ctx: InteractionContext, name: str) -> None:
await ctx.defer() if not sub_name.match(name):
name = name.replace("r/", "")
if len(name) > 20 or len(name) < 3:
await ctx.send("Invalid Subreddit name", ephemeral=True) await ctx.send("Invalid Subreddit name", ephemeral=True)
return return
try: try:
await ctx.defer()
subreddit = await self.api.subreddit(name) subreddit = await self.api.subreddit(name)
await subreddit.load() await subreddit.load()
except (NotFound, Forbidden, Redirect) as e: except (NotFound, Forbidden, Redirect) as e:
@ -380,12 +378,11 @@ class RedditCog(Extension):
name="name", description="Subreddit name", opt_type=OptionTypes.STRING, required=True name="name", description="Subreddit name", opt_type=OptionTypes.STRING, required=True
) )
async def _subreddit_rising(self, ctx: InteractionContext, name: str) -> None: async def _subreddit_rising(self, ctx: InteractionContext, name: str) -> None:
await ctx.defer() if not sub_name.match(name):
name = name.replace("r/", "")
if len(name) > 20 or len(name) < 3:
await ctx.send("Invalid Subreddit name", ephemeral=True) await ctx.send("Invalid Subreddit name", ephemeral=True)
return return
try: try:
await ctx.defer()
subreddit = await self.api.subreddit(name) subreddit = await self.api.subreddit(name)
await subreddit.load() await subreddit.load()
except (NotFound, Forbidden, Redirect) as e: except (NotFound, Forbidden, Redirect) as e: