149 lines
3.6 KiB
Python
149 lines
3.6 KiB
Python
"""Task config."""
|
|
from enum import Enum
|
|
from os import environ
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
import yaml
|
|
import orjson as json
|
|
from dotenv import load_dotenv
|
|
from jarvis_core.util import find_all
|
|
from pydantic import BaseModel
|
|
|
|
try:
|
|
from yaml import CLoader as Loader
|
|
except ImportError:
|
|
from yaml import Loader
|
|
|
|
|
|
class Environment(Enum):
    """JARVIS running environment."""

    # Each member's value is the same string as its name, so a plain string
    # such as "production" can be turned into a member with Environment(value).
    production = "production"
    develop = "develop"
|
|
|
|
|
|
class Mongo(BaseModel):
    """MongoDB config."""

    # Either one host or a list of hosts; falls back to "localhost".
    host: list[str] | str = "localhost"
    # Credentials are optional and unset by default.
    username: str | None = None
    password: str | None = None
    port: int = 27017
|
|
|
|
|
|
class Reddit(BaseModel):
    """Reddit config."""

    # Optional; left as None when not provided.
    user_agent: str | None = None
    # Both client fields are required (no default).
    client_secret: str
    client_id: str
|
|
|
|
|
|
class Twitter(BaseModel):
    """Twitter config."""

    # Every field is a required string; none of them has a default.
    consumer_key: str
    consumer_secret: str
    access_token: str
    access_secret: str
    bearer_token: str
|
|
|
|
|
|
class Config(BaseModel):
    """Tasks config model."""

    # Required fields.
    token: str
    mongo: Mongo
    # Optional sections; None when the section is not configured.
    reddit: Reddit | None = None
    twitter: Twitter | None = None
    # Defaults applied when the field is not provided.
    log_level: str = "INFO"
    environment: Environment = Environment.develop
|
|
|
|
|
|
# Module-level cache of the loaded config. It starts out as None;
# load_config() assigns it and returns the cached value on later calls.
_config: Config | None = None
|
|
|
|
|
|
def _load_json() -> Config | None:
    """Build the config from ``config.json``.

    The path is relative, so the file is looked up in the current working
    directory.

    Returns:
        A ``Config`` built from the parsed JSON object, or ``None`` when the
        file does not exist.
    """
    config_file = Path("config.json")
    if not config_file.exists():
        return None

    # `json` is the module imported as `orjson as json` at the top of the file.
    parsed = json.loads(config_file.read_text())
    return Config(**parsed)
|
|
|
|
|
|
def _load_yaml() -> Config | None:
    """Build the config from ``config.yaml``.

    The path is relative, so the file is looked up in the current working
    directory.

    Returns:
        A ``Config`` built from the parsed YAML mapping, or ``None`` when the
        file does not exist.
    """
    config_file = Path("config.yaml")
    if not config_file.exists():
        return None

    # Loader is yaml.CLoader when it can be imported, otherwise yaml.Loader.
    parsed = yaml.load(config_file.read_text(), Loader=Loader)
    return Config(**parsed)
|
|
|
|
|
|
def _load_env() -> Config | None:
    """Build the config from environment variables.

    ``load_dotenv()`` is called first, then ``os.environ`` is scanned:

    * Variables whose upper-cased name starts with ``MONGO``, ``REDDIT`` or
      ``TWITTER`` go into the matching section.  The text after the first
      underscore, lower-cased, is used as the field name
      (``MONGO_HOST`` -> ``mongo["host"]``).
    * ``TOKEN``, ``LOG_LEVEL`` and ``ENVIRONMENT`` become the lower-cased
      top-level fields.
    * Every other variable is ignored.

    The ``reddit`` and ``twitter`` sections are passed to ``Config`` only
    when at least one matching variable exists, so an unconfigured service
    stays ``None`` instead of failing validation on its required fields.

    Returns:
        The config built from the environment, or ``None`` when none of the
        recognised variables is present.
    """
    load_dotenv()

    mongo_keys = find_all(lambda x: x.upper().startswith("MONGO"), environ.keys())
    reddit_keys = find_all(lambda x: x.upper().startswith("REDDIT"), environ.keys())
    twitter_keys = find_all(lambda x: x.upper().startswith("TWITTER"), environ.keys())
    top_level_keys = ("TOKEN", "LOG_LEVEL", "ENVIRONMENT")

    data = {}
    mongo = {}
    twitter = {}
    reddit = {}
    # The three prefixes cannot overlap, so the order of this table does not
    # affect which section a variable lands in.
    sections = (
        (mongo_keys, mongo),
        (twitter_keys, twitter),
        (reddit_keys, reddit),
    )

    for item, value in environ.items():
        for keys, section in sections:
            if item in keys:
                # Drop everything up to the first underscore:
                # "MONGO_HOST" -> "host".
                section[item.partition("_")[2].lower()] = value
                break
        else:
            if item in top_level_keys:
                data[item.lower()] = value

    # Nothing recognised in the environment: report "no config here" so the
    # caller can fall through to its own error handling.
    if not (data or mongo or reddit or twitter):
        return None

    data["mongo"] = mongo
    # Environment values are always strings, never None, so test for an empty
    # section rather than for None values.
    if reddit:
        data["reddit"] = reddit
    if twitter:
        data["twitter"] = twitter

    return Config(**data)
|
|
|
|
|
|
def load_config(method: Optional[str] = None) -> Config:
    """
    Load the config, trying the specified method first.

    The result is cached in the module-level ``_config``; once a config has
    been loaded, later calls return it without loading again.

    Args:
        method: Name of the loader to try first: "yaml", "json" or "env".
            Any other value (including None) keeps the default order
            yaml, json, env.

    Returns:
        The first truthy config produced by a loader.

    Raises:
        FileNotFoundError: If no loader produced a config.
    """
    global _config
    if _config is not None:
        return _config

    loaders = {"yaml": _load_yaml, "json": _load_json, "env": _load_env}

    # Requested loader first, the remaining ones in their default order.
    order = [name for name in loaders if name != method]
    if method in loaders:
        order.insert(0, method)

    for name in order:
        # Assign to the global on every attempt so a successful load is cached.
        _config = loaders[name]()
        if _config:
            return _config

    raise FileNotFoundError("Missing one of: config.yaml, config.json, .env")
|