From ff059656bd5f1792775457c02ed8b566358d8c18 Mon Sep 17 00:00:00 2001 From: zevaryx Date: Mon, 7 Feb 2022 09:25:15 -0700 Subject: [PATCH] Update query wrapper to account for nested keys --- jarvis_core/db/__init__.py | 57 ++++++++++++++++++---- jarvis_core/util.py | 97 ++++++++++++++++++++++++++++++++++++++ pyproject.toml | 2 +- 3 files changed, 147 insertions(+), 9 deletions(-) diff --git a/jarvis_core/db/__init__.py b/jarvis_core/db/__init__.py index 6dfb282..2eef157 100644 --- a/jarvis_core/db/__init__.py +++ b/jarvis_core/db/__init__.py @@ -3,6 +3,8 @@ from bson import ObjectId from motor.motor_asyncio import AsyncIOMotorClient from umongo.frameworks import MotorAsyncIOInstance +from jarvis_core.util import find + CLIENT = None JARVISDB = None CTC2DB = None @@ -31,15 +33,54 @@ def connect( CTC2_INST.set_db(CTC2DB) +QUERY_OPS = ["ne", "lt", "lte", "gt", "gte", "not", "in", "nin", "mod", "all", "size"] +STRING_OPS = [ + "exact", + "iexact", + "contains", + "icontains", + "startswith", + "istartswith", + "endswith", + "iendswith", + "wholeword", + "iwholeword", + "regex", + "iregex" "match", +] +GEO_OPS = [ + "get_within", + "geo_within_box", + "geo_within_polygon", + "geo_within_center", + "geo_within_sphere", + "geo_intersects", + "near", + "within_distance", + "within_spherical_distance", + "near_sphere", + "within_box", + "within_polygon", + "max_distance", + "min_distance", +] + +ALL_OPS = QUERY_OPS + STRING_OPS + GEO_OPS + + def q(**kwargs: dict) -> dict: """uMongo query wrapper.""" # noqa: D403 query = {} - for k, v in kwargs.items(): - if k == "_id": - v = ObjectId(v) - elif "__" in k: - k, mod, *_ = k.split("__") - if mod: - v = {f"${mod}": v} - query[k] = v + for key, value in kwargs.items(): + if key == "_id": + value = ObjectId(value) + elif "__" in key: + args = key.split("__") + if not any(x in ALL_OPS for x in args): + key = ".".join(args) + else: + idx = args.index(find(lambda x: x in ALL_OPS, args)) + key = ".".join(args[:idx]) + value = {f"${args[idx]}": value} + query[key] = value return query diff --git a/jarvis_core/util.py b/jarvis_core/util.py index dc1df41..df1d542 100644 --- a/jarvis_core/util.py +++ b/jarvis_core/util.py @@ -1,4 +1,5 @@ """JARVIS quality of life utilities.""" +from typing import Any, Callable, Iterable, List, Optional class Singleton(object): @@ -31,3 +32,99 @@ class Singleton(object): for key, value in self.OPTIONAL.items(): if not getattr(self, key, None): setattr(self, key, value) + + +def find(predicate: Callable, sequence: Iterable) -> Optional[Any]: + """ + Find the first element in a sequence that matches the predicate. + + ??? Hint "Example Usage:" + ```python + member = find(lambda m: m.name == "UserName", guild.members) + ``` + Args: + predicate: A callable that returns a boolean value + sequence: A sequence to be searched + + Returns: + A match if found, otherwise None + + """ + for el in sequence: + if predicate(el): + return el + return None + + +def find_all(predicate: Callable, sequence: Iterable) -> List[Any]: + """ + Find all elements in a sequence that match the predicate. + + ??? Hint "Example Usage:" + ```python + members = find_all(lambda m: m.name == "UserName", guild.members) + ``` + Args: + predicate: A callable that returns a boolean value + sequence: A sequence to be searched + + Returns: + A list of matches + + """ + return [el for el in sequence if predicate(el)] + + +def get(sequence: Iterable, **kwargs: Any) -> Optional[Any]: + """ + Find the first element in a sequence that matches all attrs. + + ??? Hint "Example Usage:" + ```python + channel = get(guild.channels, nsfw=False, category="General") + ``` + + Args: + sequence: A sequence to be searched + kwargs: Keyword arguments to search the sequence for + + Returns: + A match if found, otherwise None + """ + if not kwargs: + return sequence[0] + + for el in sequence: + if any(not hasattr(el, attr) for attr in kwargs.keys()): + continue + if all(getattr(el, attr) == value for attr, value in kwargs.items()): + return el + return None + + +def get_all(sequence: Iterable, **kwargs: Any) -> List[Any]: + """ + Find all elements in a sequence that match all attrs. + + ??? Hint "Example Usage:" + ```python + channels = get_all(guild.channels, nsfw=False, category="General") + ``` + + Args: + sequence: A sequence to be searched + kwargs: Keyword arguments to search the sequence for + + Returns: + A list of matches + """ + if not kwargs: + return sequence + + matches = [] + for el in sequence: + if any(not hasattr(el, attr) for attr in kwargs.keys()): + continue + if all(getattr(el, attr) == value for attr, value in kwargs.items()): + matches.append(el) + return matches diff --git a/pyproject.toml b/pyproject.toml index 7a97d88..df8266d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "jarvis-core" -version = "0.1.2" +version = "0.1.3" description = "" authors = ["Your Name "]