Update query wrapper to account for nested keys
This commit is contained in:
parent
a40c85065c
commit
ff059656bd
3 changed files with 147 additions and 9 deletions
|
@ -3,6 +3,8 @@ from bson import ObjectId
|
|||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from umongo.frameworks import MotorAsyncIOInstance
|
||||
|
||||
from jarvis_core.util import find
|
||||
|
||||
CLIENT = None
|
||||
JARVISDB = None
|
||||
CTC2DB = None
|
||||
|
@ -31,15 +33,54 @@ def connect(
|
|||
CTC2_INST.set_db(CTC2DB)
|
||||
|
||||
|
||||
QUERY_OPS = ["ne", "lt", "lte", "gt", "gte", "not", "in", "nin", "mod", "all", "size"]
|
||||
STRING_OPS = [
|
||||
"exact",
|
||||
"iexact",
|
||||
"contains",
|
||||
"icontains",
|
||||
"startswith",
|
||||
"istartswith",
|
||||
"endswith",
|
||||
"iendswith",
|
||||
"wholeword",
|
||||
"iwholeword",
|
||||
"regex",
|
||||
"iregex" "match",
|
||||
]
|
||||
GEO_OPS = [
|
||||
"get_within",
|
||||
"geo_within_box",
|
||||
"geo_within_polygon",
|
||||
"geo_within_center",
|
||||
"geo_within_sphere",
|
||||
"geo_intersects",
|
||||
"near",
|
||||
"within_distance",
|
||||
"within_spherical_distance",
|
||||
"near_sphere",
|
||||
"within_box",
|
||||
"within_polygon",
|
||||
"max_distance",
|
||||
"min_distance",
|
||||
]
|
||||
|
||||
ALL_OPS = QUERY_OPS + STRING_OPS + GEO_OPS
|
||||
|
||||
|
||||
def q(**kwargs: dict) -> dict:
|
||||
"""uMongo query wrapper.""" # noqa: D403
|
||||
query = {}
|
||||
for k, v in kwargs.items():
|
||||
if k == "_id":
|
||||
v = ObjectId(v)
|
||||
elif "__" in k:
|
||||
k, mod, *_ = k.split("__")
|
||||
if mod:
|
||||
v = {f"${mod}": v}
|
||||
query[k] = v
|
||||
for key, value in kwargs.items():
|
||||
if key == "_id":
|
||||
value = ObjectId(value)
|
||||
elif "__" in key:
|
||||
args = key.split("__")
|
||||
if not any(x in ALL_OPS for x in args):
|
||||
key = ".".join(args)
|
||||
else:
|
||||
idx = args.index(find(lambda x: x in ALL_OPS, args))
|
||||
key = ".".join(args[:idx])
|
||||
value = {f"${args[idx]}": value}
|
||||
query[key] = value
|
||||
return query
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
"""JARVIS quality of life utilities."""
|
||||
from typing import Any, Callable, Iterable, List, Optional
|
||||
|
||||
|
||||
class Singleton(object):
|
||||
|
@ -31,3 +32,99 @@ class Singleton(object):
|
|||
for key, value in self.OPTIONAL.items():
|
||||
if not getattr(self, key, None):
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
def find(predicate: Callable, sequence: Iterable) -> Optional[Any]:
|
||||
"""
|
||||
Find the first element in a sequence that matches the predicate.
|
||||
|
||||
??? Hint "Example Usage:"
|
||||
```python
|
||||
member = find(lambda m: m.name == "UserName", guild.members)
|
||||
```
|
||||
Args:
|
||||
predicate: A callable that returns a boolean value
|
||||
sequence: A sequence to be searched
|
||||
|
||||
Returns:
|
||||
A match if found, otherwise None
|
||||
|
||||
"""
|
||||
for el in sequence:
|
||||
if predicate(el):
|
||||
return el
|
||||
return None
|
||||
|
||||
|
||||
def find_all(predicate: Callable, sequence: Iterable) -> List[Any]:
|
||||
"""
|
||||
Find all elements in a sequence that match the predicate.
|
||||
|
||||
??? Hint "Example Usage:"
|
||||
```python
|
||||
members = find_all(lambda m: m.name == "UserName", guild.members)
|
||||
```
|
||||
Args:
|
||||
predicate: A callable that returns a boolean value
|
||||
sequence: A sequence to be searched
|
||||
|
||||
Returns:
|
||||
A list of matches
|
||||
|
||||
"""
|
||||
return [el for el in sequence if predicate(el)]
|
||||
|
||||
|
||||
def get(sequence: Iterable, **kwargs: Any) -> Optional[Any]:
|
||||
"""
|
||||
Find the first element in a sequence that matches all attrs.
|
||||
|
||||
??? Hint "Example Usage:"
|
||||
```python
|
||||
channel = get(guild.channels, nsfw=False, category="General")
|
||||
```
|
||||
|
||||
Args:
|
||||
sequence: A sequence to be searched
|
||||
kwargs: Keyword arguments to search the sequence for
|
||||
|
||||
Returns:
|
||||
A match if found, otherwise None
|
||||
"""
|
||||
if not kwargs:
|
||||
return sequence[0]
|
||||
|
||||
for el in sequence:
|
||||
if any(not hasattr(el, attr) for attr in kwargs.keys()):
|
||||
continue
|
||||
if all(getattr(el, attr) == value for attr, value in kwargs.items()):
|
||||
return el
|
||||
return None
|
||||
|
||||
|
||||
def get_all(sequence: Iterable, **kwargs: Any) -> List[Any]:
|
||||
"""
|
||||
Find all elements in a sequence that match all attrs.
|
||||
|
||||
??? Hint "Example Usage:"
|
||||
```python
|
||||
channels = get_all(guild.channels, nsfw=False, category="General")
|
||||
```
|
||||
|
||||
Args:
|
||||
sequence: A sequence to be searched
|
||||
kwargs: Keyword arguments to search the sequence for
|
||||
|
||||
Returns:
|
||||
A list of matches
|
||||
"""
|
||||
if not kwargs:
|
||||
return sequence
|
||||
|
||||
matches = []
|
||||
for el in sequence:
|
||||
if any(not hasattr(el, attr) for attr in kwargs.keys()):
|
||||
continue
|
||||
if all(getattr(el, attr) == value for attr, value in kwargs.items()):
|
||||
matches.append(el)
|
||||
return matches
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[tool.poetry]
|
||||
name = "jarvis-core"
|
||||
version = "0.1.2"
|
||||
version = "0.1.3"
|
||||
description = ""
|
||||
authors = ["Your Name <you@example.com>"]
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue