Update query wrapper to account for nested keys

This commit is contained in:
Zeva Rose 2022-02-07 09:25:15 -07:00
parent a40c85065c
commit ff059656bd
3 changed files with 147 additions and 9 deletions

View file

@ -3,6 +3,8 @@ from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient from motor.motor_asyncio import AsyncIOMotorClient
from umongo.frameworks import MotorAsyncIOInstance from umongo.frameworks import MotorAsyncIOInstance
from jarvis_core.util import find
CLIENT = None CLIENT = None
JARVISDB = None JARVISDB = None
CTC2DB = None CTC2DB = None
@ -31,15 +33,54 @@ def connect(
CTC2_INST.set_db(CTC2DB) CTC2_INST.set_db(CTC2DB)
QUERY_OPS = ["ne", "lt", "lte", "gt", "gte", "not", "in", "nin", "mod", "all", "size"]
STRING_OPS = [
"exact",
"iexact",
"contains",
"icontains",
"startswith",
"istartswith",
"endswith",
"iendswith",
"wholeword",
"iwholeword",
"regex",
"iregex" "match",
]
GEO_OPS = [
"get_within",
"geo_within_box",
"geo_within_polygon",
"geo_within_center",
"geo_within_sphere",
"geo_intersects",
"near",
"within_distance",
"within_spherical_distance",
"near_sphere",
"within_box",
"within_polygon",
"max_distance",
"min_distance",
]
ALL_OPS = QUERY_OPS + STRING_OPS + GEO_OPS
def q(**kwargs: dict) -> dict: def q(**kwargs: dict) -> dict:
"""uMongo query wrapper.""" # noqa: D403 """uMongo query wrapper.""" # noqa: D403
query = {} query = {}
for k, v in kwargs.items(): for key, value in kwargs.items():
if k == "_id": if key == "_id":
v = ObjectId(v) value = ObjectId(value)
elif "__" in k: elif "__" in key:
k, mod, *_ = k.split("__") args = key.split("__")
if mod: if not any(x in ALL_OPS for x in args):
v = {f"${mod}": v} key = ".".join(args)
query[k] = v else:
idx = args.index(find(lambda x: x in ALL_OPS, args))
key = ".".join(args[:idx])
value = {f"${args[idx]}": value}
query[key] = value
return query return query

View file

@ -1,4 +1,5 @@
"""JARVIS quality of life utilities.""" """JARVIS quality of life utilities."""
from typing import Any, Callable, Iterable, List, Optional
class Singleton(object): class Singleton(object):
@ -31,3 +32,99 @@ class Singleton(object):
for key, value in self.OPTIONAL.items(): for key, value in self.OPTIONAL.items():
if not getattr(self, key, None): if not getattr(self, key, None):
setattr(self, key, value) setattr(self, key, value)
def find(predicate: Callable, sequence: Iterable) -> Optional[Any]:
"""
Find the first element in a sequence that matches the predicate.
??? Hint "Example Usage:"
```python
member = find(lambda m: m.name == "UserName", guild.members)
```
Args:
predicate: A callable that returns a boolean value
sequence: A sequence to be searched
Returns:
A match if found, otherwise None
"""
for el in sequence:
if predicate(el):
return el
return None
def find_all(predicate: Callable, sequence: Iterable) -> List[Any]:
"""
Find all elements in a sequence that match the predicate.
??? Hint "Example Usage:"
```python
members = find_all(lambda m: m.name == "UserName", guild.members)
```
Args:
predicate: A callable that returns a boolean value
sequence: A sequence to be searched
Returns:
A list of matches
"""
return [el for el in sequence if predicate(el)]
def get(sequence: Iterable, **kwargs: Any) -> Optional[Any]:
"""
Find the first element in a sequence that matches all attrs.
??? Hint "Example Usage:"
```python
channel = get(guild.channels, nsfw=False, category="General")
```
Args:
sequence: A sequence to be searched
kwargs: Keyword arguments to search the sequence for
Returns:
A match if found, otherwise None
"""
if not kwargs:
return sequence[0]
for el in sequence:
if any(not hasattr(el, attr) for attr in kwargs.keys()):
continue
if all(getattr(el, attr) == value for attr, value in kwargs.items()):
return el
return None
def get_all(sequence: Iterable, **kwargs: Any) -> List[Any]:
"""
Find all elements in a sequence that match all attrs.
??? Hint "Example Usage:"
```python
channels = get_all(guild.channels, nsfw=False, category="General")
```
Args:
sequence: A sequence to be searched
kwargs: Keyword arguments to search the sequence for
Returns:
A list of matches
"""
if not kwargs:
return sequence
matches = []
for el in sequence:
if any(not hasattr(el, attr) for attr in kwargs.keys()):
continue
if all(getattr(el, attr) == value for attr, value in kwargs.items()):
matches.append(el)
return matches

View file

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "jarvis-core" name = "jarvis-core"
version = "0.1.2" version = "0.1.3"
description = "" description = ""
authors = ["Your Name <you@example.com>"] authors = ["Your Name <you@example.com>"]