Fix subreddit name checking with official regex, closes #141
This commit is contained in:
parent
638ae08bdd
commit
9188b13695
1 changed files with 11 additions and 14 deletions
|
@ -1,6 +1,7 @@
|
|||
"""JARVIS Reddit cog."""
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
from typing import List, Optional
|
||||
|
||||
from asyncpraw import Reddit
|
||||
|
@ -28,6 +29,7 @@ from jarvis.utils import build_embed
|
|||
from jarvis.utils.permissions import admin_or_permissions
|
||||
|
||||
DEFAULT_USER_AGENT = f"python:JARVIS:{const.__version__} (by u/zevaryx)"
|
||||
sub_name = re.compile(r"\A[A-Za-z0-9][A-Za-z0-9_]{2,20}\Z")
|
||||
|
||||
|
||||
class RedditCog(Extension):
|
||||
|
@ -135,8 +137,7 @@ class RedditCog(Extension):
|
|||
)
|
||||
@check(admin_or_permissions(Permissions.MANAGE_GUILD))
|
||||
async def _reddit_follow(self, ctx: InteractionContext, name: str, channel: GuildText) -> None:
|
||||
name = name.replace("r/", "")
|
||||
if len(name) > 20 or len(name) < 3:
|
||||
if not sub_name.match(name):
|
||||
await ctx.send("Invalid Subreddit name", ephemeral=True)
|
||||
return
|
||||
|
||||
|
@ -248,12 +249,11 @@ class RedditCog(Extension):
|
|||
name="name", description="Subreddit name", opt_type=OptionTypes.STRING, required=True
|
||||
)
|
||||
async def _subreddit_hot(self, ctx: InteractionContext, name: str) -> None:
|
||||
await ctx.defer()
|
||||
name = name.replace("r/", "")
|
||||
if len(name) > 20 or len(name) < 3:
|
||||
if not sub_name.match(name):
|
||||
await ctx.send("Invalid Subreddit name", ephemeral=True)
|
||||
return
|
||||
try:
|
||||
await ctx.defer()
|
||||
subreddit = await self.api.subreddit(name)
|
||||
await subreddit.load()
|
||||
except (NotFound, Forbidden, Redirect) as e:
|
||||
|
@ -300,12 +300,11 @@ class RedditCog(Extension):
|
|||
],
|
||||
)
|
||||
async def _subreddit_top(self, ctx: InteractionContext, name: str, time: str = "all") -> None:
|
||||
await ctx.defer()
|
||||
name = name.replace("r/", "")
|
||||
if len(name) > 20 or len(name) < 3:
|
||||
if not sub_name.match(name):
|
||||
await ctx.send("Invalid Subreddit name", ephemeral=True)
|
||||
return
|
||||
try:
|
||||
await ctx.defer()
|
||||
subreddit = await self.api.subreddit(name)
|
||||
await subreddit.load()
|
||||
except (NotFound, Forbidden, Redirect) as e:
|
||||
|
@ -340,12 +339,11 @@ class RedditCog(Extension):
|
|||
name="name", description="Subreddit name", opt_type=OptionTypes.STRING, required=True
|
||||
)
|
||||
async def _subreddit_random(self, ctx: InteractionContext, name: str) -> None:
|
||||
await ctx.defer()
|
||||
name = name.replace("r/", "")
|
||||
if len(name) > 20 or len(name) < 3:
|
||||
if not sub_name.match(name):
|
||||
await ctx.send("Invalid Subreddit name", ephemeral=True)
|
||||
return
|
||||
try:
|
||||
await ctx.defer()
|
||||
subreddit = await self.api.subreddit(name)
|
||||
await subreddit.load()
|
||||
except (NotFound, Forbidden, Redirect) as e:
|
||||
|
@ -380,12 +378,11 @@ class RedditCog(Extension):
|
|||
name="name", description="Subreddit name", opt_type=OptionTypes.STRING, required=True
|
||||
)
|
||||
async def _subreddit_rising(self, ctx: InteractionContext, name: str) -> None:
|
||||
await ctx.defer()
|
||||
name = name.replace("r/", "")
|
||||
if len(name) > 20 or len(name) < 3:
|
||||
if not sub_name.match(name):
|
||||
await ctx.send("Invalid Subreddit name", ephemeral=True)
|
||||
return
|
||||
try:
|
||||
await ctx.defer()
|
||||
subreddit = await self.api.subreddit(name)
|
||||
await subreddit.load()
|
||||
except (NotFound, Forbidden, Redirect) as e:
|
||||
|
|
Loading…
Add table
Reference in a new issue