Fix subreddit name checking with official regex, closes #141
This commit is contained in:
parent
638ae08bdd
commit
9188b13695
1 changed files with 11 additions and 14 deletions
|
@ -1,6 +1,7 @@
|
||||||
"""JARVIS Reddit cog."""
|
"""JARVIS Reddit cog."""
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from asyncpraw import Reddit
|
from asyncpraw import Reddit
|
||||||
|
@ -28,6 +29,7 @@ from jarvis.utils import build_embed
|
||||||
from jarvis.utils.permissions import admin_or_permissions
|
from jarvis.utils.permissions import admin_or_permissions
|
||||||
|
|
||||||
DEFAULT_USER_AGENT = f"python:JARVIS:{const.__version__} (by u/zevaryx)"
|
DEFAULT_USER_AGENT = f"python:JARVIS:{const.__version__} (by u/zevaryx)"
|
||||||
|
sub_name = re.compile(r"\A[A-Za-z0-9][A-Za-z0-9_]{2,20}\Z")
|
||||||
|
|
||||||
|
|
||||||
class RedditCog(Extension):
|
class RedditCog(Extension):
|
||||||
|
@ -135,8 +137,7 @@ class RedditCog(Extension):
|
||||||
)
|
)
|
||||||
@check(admin_or_permissions(Permissions.MANAGE_GUILD))
|
@check(admin_or_permissions(Permissions.MANAGE_GUILD))
|
||||||
async def _reddit_follow(self, ctx: InteractionContext, name: str, channel: GuildText) -> None:
|
async def _reddit_follow(self, ctx: InteractionContext, name: str, channel: GuildText) -> None:
|
||||||
name = name.replace("r/", "")
|
if not sub_name.match(name):
|
||||||
if len(name) > 20 or len(name) < 3:
|
|
||||||
await ctx.send("Invalid Subreddit name", ephemeral=True)
|
await ctx.send("Invalid Subreddit name", ephemeral=True)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -248,12 +249,11 @@ class RedditCog(Extension):
|
||||||
name="name", description="Subreddit name", opt_type=OptionTypes.STRING, required=True
|
name="name", description="Subreddit name", opt_type=OptionTypes.STRING, required=True
|
||||||
)
|
)
|
||||||
async def _subreddit_hot(self, ctx: InteractionContext, name: str) -> None:
|
async def _subreddit_hot(self, ctx: InteractionContext, name: str) -> None:
|
||||||
await ctx.defer()
|
if not sub_name.match(name):
|
||||||
name = name.replace("r/", "")
|
|
||||||
if len(name) > 20 or len(name) < 3:
|
|
||||||
await ctx.send("Invalid Subreddit name", ephemeral=True)
|
await ctx.send("Invalid Subreddit name", ephemeral=True)
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
|
await ctx.defer()
|
||||||
subreddit = await self.api.subreddit(name)
|
subreddit = await self.api.subreddit(name)
|
||||||
await subreddit.load()
|
await subreddit.load()
|
||||||
except (NotFound, Forbidden, Redirect) as e:
|
except (NotFound, Forbidden, Redirect) as e:
|
||||||
|
@ -300,12 +300,11 @@ class RedditCog(Extension):
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
async def _subreddit_top(self, ctx: InteractionContext, name: str, time: str = "all") -> None:
|
async def _subreddit_top(self, ctx: InteractionContext, name: str, time: str = "all") -> None:
|
||||||
await ctx.defer()
|
if not sub_name.match(name):
|
||||||
name = name.replace("r/", "")
|
|
||||||
if len(name) > 20 or len(name) < 3:
|
|
||||||
await ctx.send("Invalid Subreddit name", ephemeral=True)
|
await ctx.send("Invalid Subreddit name", ephemeral=True)
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
|
await ctx.defer()
|
||||||
subreddit = await self.api.subreddit(name)
|
subreddit = await self.api.subreddit(name)
|
||||||
await subreddit.load()
|
await subreddit.load()
|
||||||
except (NotFound, Forbidden, Redirect) as e:
|
except (NotFound, Forbidden, Redirect) as e:
|
||||||
|
@ -340,12 +339,11 @@ class RedditCog(Extension):
|
||||||
name="name", description="Subreddit name", opt_type=OptionTypes.STRING, required=True
|
name="name", description="Subreddit name", opt_type=OptionTypes.STRING, required=True
|
||||||
)
|
)
|
||||||
async def _subreddit_random(self, ctx: InteractionContext, name: str) -> None:
|
async def _subreddit_random(self, ctx: InteractionContext, name: str) -> None:
|
||||||
await ctx.defer()
|
if not sub_name.match(name):
|
||||||
name = name.replace("r/", "")
|
|
||||||
if len(name) > 20 or len(name) < 3:
|
|
||||||
await ctx.send("Invalid Subreddit name", ephemeral=True)
|
await ctx.send("Invalid Subreddit name", ephemeral=True)
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
|
await ctx.defer()
|
||||||
subreddit = await self.api.subreddit(name)
|
subreddit = await self.api.subreddit(name)
|
||||||
await subreddit.load()
|
await subreddit.load()
|
||||||
except (NotFound, Forbidden, Redirect) as e:
|
except (NotFound, Forbidden, Redirect) as e:
|
||||||
|
@ -380,12 +378,11 @@ class RedditCog(Extension):
|
||||||
name="name", description="Subreddit name", opt_type=OptionTypes.STRING, required=True
|
name="name", description="Subreddit name", opt_type=OptionTypes.STRING, required=True
|
||||||
)
|
)
|
||||||
async def _subreddit_rising(self, ctx: InteractionContext, name: str) -> None:
|
async def _subreddit_rising(self, ctx: InteractionContext, name: str) -> None:
|
||||||
await ctx.defer()
|
if not sub_name.match(name):
|
||||||
name = name.replace("r/", "")
|
|
||||||
if len(name) > 20 or len(name) < 3:
|
|
||||||
await ctx.send("Invalid Subreddit name", ephemeral=True)
|
await ctx.send("Invalid Subreddit name", ephemeral=True)
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
|
await ctx.defer()
|
||||||
subreddit = await self.api.subreddit(name)
|
subreddit = await self.api.subreddit(name)
|
||||||
await subreddit.load()
|
await subreddit.load()
|
||||||
except (NotFound, Forbidden, Redirect) as e:
|
except (NotFound, Forbidden, Redirect) as e:
|
||||||
|
|
Loading…
Add table
Reference in a new issue