jarvis-tasks/jarvis_tasks/config.py
2023-05-10 18:06:44 -06:00

149 lines
3.6 KiB
Python

"""Task config."""
from enum import Enum
from os import environ
from pathlib import Path
from typing import Optional
import yaml
import orjson as json
from dotenv import load_dotenv
from jarvis_core.util import find_all
from pydantic import BaseModel
try:
from yaml import CLoader as Loader
except ImportError:
from yaml import Loader
class Environment(Enum):
    """JARVIS running environment."""

    # Values are the strings a config supplies for ``environment``;
    # ``Config`` falls back to ``develop`` when none is given.
    production = "production"
    develop = "develop"
class Mongo(BaseModel):
    """MongoDB connection settings; every field has a default."""

    # Either one host string or a list of host strings.
    host: list[str] | str = "localhost"
    username: str | None = None
    password: str | None = None
    port: int = 27017
class Reddit(BaseModel):
    """Reddit client settings; only ``user_agent`` may be omitted."""

    user_agent: str | None = None
    client_secret: str
    client_id: str
class Twitter(BaseModel):
    """Twitter config; every field is required (no defaults)."""

    consumer_key: str
    consumer_secret: str
    access_token: str
    access_secret: str
    bearer_token: str
class Config(BaseModel):
    """
    Root configuration model for JARVIS tasks.

    Only ``token`` and ``mongo`` must be supplied; the Reddit and Twitter
    sections are optional and the remaining fields have defaults.
    """

    token: str
    mongo: Mongo
    reddit: Reddit | None = None
    twitter: Twitter | None = None
    log_level: str = "INFO"
    environment: Environment = Environment.develop
# Cached result of ``load_config``; None until a loader succeeds.
_config: Config | None = None
def _load_json() -> Config | None:
    """
    Load the config from ``config.json`` in the current working directory.

    Returns:
        The parsed config, or None if ``config.json`` does not exist.
    """
    path = Path("config.json")
    if not path.exists():
        return None
    # Read raw bytes: JSON is UTF-8 by spec and orjson decodes bytes itself,
    # so the locale-dependent default encoding of text-mode ``open`` can't
    # mis-decode non-ASCII values.
    return Config(**json.loads(path.read_bytes()))
def _load_yaml() -> Config | None:
    """
    Load the config from ``config.yaml`` in the current working directory.

    Returns:
        The parsed config, or None if ``config.yaml`` does not exist.
    """
    path = Path("config.yaml")
    if not path.exists():
        return None
    # Pass raw bytes so PyYAML detects the encoding itself rather than the
    # locale's default text encoding being applied.
    # NOTE(review): ``Loader`` (C or pure) is PyYAML's unsafe loader and can
    # construct arbitrary Python objects; only acceptable because config.yaml
    # is a trusted local file. Consider a safe loader.
    data = yaml.load(path.read_bytes(), Loader=Loader)
    # An empty document parses to None; treat it as an empty mapping so the
    # user gets validation errors naming the missing fields instead of a
    # TypeError from ``**None``.
    if data is None:
        data = {}
    return Config(**data)
def _load_env() -> Config | None:
    """
    Load the config from environment variables, after applying any ``.env`` file.

    ``TOKEN``, ``LOG_LEVEL`` and ``ENVIRONMENT`` map to the top-level field of
    the same lower-cased name. Variables whose name starts with ``MONGO``,
    ``REDDIT`` or ``TWITTER`` (compared case-insensitively) populate the
    matching section; the part of the name after the first ``_``, lower-cased,
    becomes the field name (e.g. ``TWITTER_BEARER_TOKEN`` ->
    ``twitter.bearer_token``).

    Returns:
        The config, or None when ``TOKEN`` is not set, so that
        ``load_config`` can fall back to another source.
    """
    load_dotenv()
    data: dict[str, object] = {}
    mongo: dict[str, str] = {}
    reddit: dict[str, str] = {}
    twitter: dict[str, str] = {}
    sections = (("MONGO", mongo), ("REDDIT", reddit), ("TWITTER", twitter))
    top_level_keys = {"TOKEN", "LOG_LEVEL", "ENVIRONMENT"}

    # Single pass over the environment; prefix checks replace the former
    # per-item list membership tests.
    for item, value in environ.items():
        upper = item.upper()
        for prefix, section in sections:
            if upper.startswith(prefix):
                section[item.partition("_")[2].lower()] = value
                break
        else:
            if item in top_level_keys:
                data[item.lower()] = value

    # ``token`` is required by Config and has no default: without it there is
    # no usable environment config, so let the caller try other sources.
    if "token" not in data:
        return None

    data["mongo"] = mongo
    # Reddit/Twitter sections are optional in Config, but their models have
    # required fields. Environment values are never None, so the right test
    # is whether any variable for the section was set at all.
    if reddit:
        data["reddit"] = reddit
    if twitter:
        data["twitter"] = twitter
    return Config(**data)
def load_config(method: Optional[str] = None) -> Config:
    """
    Load the config, trying ``method`` before the other loaders.

    The first successful result is cached in the module-level ``_config``;
    later calls return it unchanged and ignore ``method``.

    Args:
        method: Loader to try first: ``"yaml"``, ``"json"`` or ``"env"``.
            Any other value (or None) keeps the default order yaml, json, env.

    Returns:
        The first config produced by a loader.

    Raises:
        FileNotFoundError: If no loader produced a config.
    """
    global _config
    if _config is not None:
        return _config

    loaders = {"yaml": _load_yaml, "json": _load_json, "env": _load_env}
    order = list(loaders)
    if method in order:
        # Stable sort: the preferred loader moves to the front, the rest
        # keep their default relative order.
        order.sort(key=lambda name: name != method)

    for name in order:
        _config = loaders[name]()
        if _config:
            return _config
    raise FileNotFoundError("Missing one of: config.yaml, config.json, .env")