Fix subreddit name checking with official regex, closes #141

This commit is contained in:
Zeva Rose 2022-05-29 19:31:32 -06:00
parent 638ae08bdd
commit 9188b13695

View file

@ -1,6 +1,7 @@
"""JARVIS Reddit cog."""
import asyncio
import logging
import re
from typing import List, Optional
from asyncpraw import Reddit
@ -28,6 +29,7 @@ from jarvis.utils import build_embed
from jarvis.utils.permissions import admin_or_permissions
DEFAULT_USER_AGENT = f"python:JARVIS:{const.__version__} (by u/zevaryx)"
sub_name = re.compile(r"\A[A-Za-z0-9][A-Za-z0-9_]{2,20}\Z")
class RedditCog(Extension):
@ -135,8 +137,7 @@ class RedditCog(Extension):
)
@check(admin_or_permissions(Permissions.MANAGE_GUILD))
async def _reddit_follow(self, ctx: InteractionContext, name: str, channel: GuildText) -> None:
name = name.replace("r/", "")
if len(name) > 20 or len(name) < 3:
if not sub_name.match(name):
await ctx.send("Invalid Subreddit name", ephemeral=True)
return
@ -248,12 +249,11 @@ class RedditCog(Extension):
name="name", description="Subreddit name", opt_type=OptionTypes.STRING, required=True
)
async def _subreddit_hot(self, ctx: InteractionContext, name: str) -> None:
await ctx.defer()
name = name.replace("r/", "")
if len(name) > 20 or len(name) < 3:
if not sub_name.match(name):
await ctx.send("Invalid Subreddit name", ephemeral=True)
return
try:
await ctx.defer()
subreddit = await self.api.subreddit(name)
await subreddit.load()
except (NotFound, Forbidden, Redirect) as e:
@ -300,12 +300,11 @@ class RedditCog(Extension):
],
)
async def _subreddit_top(self, ctx: InteractionContext, name: str, time: str = "all") -> None:
await ctx.defer()
name = name.replace("r/", "")
if len(name) > 20 or len(name) < 3:
if not sub_name.match(name):
await ctx.send("Invalid Subreddit name", ephemeral=True)
return
try:
await ctx.defer()
subreddit = await self.api.subreddit(name)
await subreddit.load()
except (NotFound, Forbidden, Redirect) as e:
@ -340,12 +339,11 @@ class RedditCog(Extension):
name="name", description="Subreddit name", opt_type=OptionTypes.STRING, required=True
)
async def _subreddit_random(self, ctx: InteractionContext, name: str) -> None:
await ctx.defer()
name = name.replace("r/", "")
if len(name) > 20 or len(name) < 3:
if not sub_name.match(name):
await ctx.send("Invalid Subreddit name", ephemeral=True)
return
try:
await ctx.defer()
subreddit = await self.api.subreddit(name)
await subreddit.load()
except (NotFound, Forbidden, Redirect) as e:
@ -380,12 +378,11 @@ class RedditCog(Extension):
name="name", description="Subreddit name", opt_type=OptionTypes.STRING, required=True
)
async def _subreddit_rising(self, ctx: InteractionContext, name: str) -> None:
await ctx.defer()
name = name.replace("r/", "")
if len(name) > 20 or len(name) < 3:
if not sub_name.match(name):
await ctx.send("Invalid Subreddit name", ephemeral=True)
return
try:
await ctx.defer()
subreddit = await self.api.subreddit(name)
await subreddit.load()
except (NotFound, Forbidden, Redirect) as e: