diff --git a/wallabag_kindle_consumer/config.py b/wallabag_kindle_consumer/config.py index a2087ce..542a6be 100644 --- a/wallabag_kindle_consumer/config.py +++ b/wallabag_kindle_consumer/config.py @@ -1,36 +1,29 @@ -import os +import dataclasses import configparser +import os from logbook import Logger logger = Logger(__name__) +@dataclasses.dataclass class Config: - known_values = ["wallabag_host", "db_uri", "client_id", "client_secret", "domain", "smtp_from", "smtp_host", - "smtp_port", "smtp_user", "smtp_passwd", "tag", "refresh_grace", "consume_interval", - "interface_host", "interface_port"] - required_values = ["wallabag_host", "db_uri", "client_id", "client_secret", "domain", "smtp_from", "smtp_host", - "smtp_port", "smtp_user", "smtp_passwd"] - - def __init__(self, wallabag_host, db_uri, client_id, client_secret, domain, smtp_from, smtp_host, smtp_port, - smtp_user, smtp_passwd, tag='kindle', refresh_grace=120, consume_interval=30, - interface_host="127.0.0.1", interface_port=8080): - self.wallabag_host = wallabag_host - self.db_uri = db_uri - self.client_id = client_id - self.client_secret = client_secret - self.domain = domain - self.smtp_from = smtp_from - self.smtp_host = smtp_host - self.smtp_port = smtp_port - self.smtp_user = smtp_user - self.smtp_passwd = smtp_passwd - self.tag = tag - self.refresh_grace = refresh_grace - self.consume_interval = consume_interval - self.interface_host = interface_host - self.interface_port = interface_port + wallabag_host: str + db_uri: str + client_id: str + client_secret: str + domain: str + smtp_from: str + smtp_host: str + smtp_port: int + smtp_user: str + smtp_passwd: str + tag: str = "kindle" + refresh_grace: int = 120 + consume_interval: int = 30 + interface_host: str = "127.0.0.1" + interface_port: int = 8080 @staticmethod def from_file(filename): @@ -51,12 +44,12 @@ class Config: tmp = {} missing = [] - for key in Config.known_values: - if key in dflt: - tmp[key] = dflt[key] + for field in dataclasses.fields(Config): + if field.name in dflt: + tmp[field.name] = field.type(dflt[field.name]) else: - if key in Config.required_values: - missing.append(key) + if field.default is dataclasses.MISSING: + missing.append(field.name) if 0 != len(missing): logger.warn("Config file {filename} does not contain configs for: {lst}", filename=filename, @@ -70,12 +63,12 @@ class Config: logger.info("Read config from environment") tmp = {} missing = [] - for key in Config.known_values: - if key.upper() in os.environ: - tmp[key] = os.environ[key.upper()] + for field in dataclasses.fields(Config): + if field.name.upper() in os.environ: + tmp[field.name] = field.type(os.environ[field.name.upper()]) else: - if key in Config.required_values: - missing.append(key.upper()) + if field.default is dataclasses.MISSING: + missing.append(field.name.upper()) if 0 != len(missing): logger.warn("Environment config does not contain configs for: {lst}", lst=", ".join(missing))