diff --git a/wallabag_kindle_consumer/consumer.py b/wallabag_kindle_consumer/consumer.py index 3324ada..b105aea 100644 --- a/wallabag_kindle_consumer/consumer.py +++ b/wallabag_kindle_consumer/consumer.py @@ -4,7 +4,7 @@ import asyncio from logbook import Logger from sqlalchemy.orm import joinedload -from wallabag_kindle_consumer.models import User, Job, session_maker +from wallabag_kindle_consumer.models import User, Job, context_session logger = Logger(__name__) @@ -12,7 +12,7 @@ logger = Logger(__name__) class Consumer: def __init__(self, wallabag, cfg, sender): self.wallabag = wallabag - self.session = session_maker(cfg.db_uri)() + self.sessionmaker = context_session(cfg) self.interval = cfg.consume_interval self.sender = sender self.running = True @@ -25,22 +25,23 @@ class Consumer: user.jobs.append(job) await self.wallabag.remove_tag(user, entry) - async def process_job(self, job): + async def process_job(self, job, session): logger.info("Process export for job {id} ({format})", id=job.article, format=job.format) data = await self.wallabag.export_article(job.user, job.article, job.format) await self.sender.send_mail(job, data) - self.session.delete(job) + session.delete(job) async def consume(self): while self.running: - logger.info("Start consume run") - fetches = [self.fetch_jobs(user) for user in self.session.query(User).all()] - await asyncio.gather(*fetches) - self.session.commit() + with self.sessionmaker as session: + logger.info("Start consume run") + fetches = [self.fetch_jobs(user) for user in session.query(User).all()] + await asyncio.gather(*fetches) + session.commit() - jobs = [self.process_job(job) for job in self.session.query(Job).options(joinedload('user'))] - await asyncio.gather(*jobs) - self.session.commit() + jobs = [self.process_job(job, session) for job in session.query(Job).options(joinedload('user'))] + await asyncio.gather(*jobs) + session.commit() await asyncio.sleep(self.interval) diff --git a/wallabag_kindle_consumer/refresher.py b/wallabag_kindle_consumer/refresher.py index 3296151..811c7db 100644 --- a/wallabag_kindle_consumer/refresher.py +++ b/wallabag_kindle_consumer/refresher.py @@ -3,19 +3,20 @@ from datetime import datetime, timedelta from logbook import Logger from sqlalchemy import func -from .models import User, session_maker + +from .models import User, context_session logger = Logger(__name__) class Refresher: def __init__(self, config, wallabag): - self.session = session_maker(config.db_uri)() + self.sessionmaker = context_session(config) self.wallabag = wallabag self.grace = config.refresh_grace - def _wait_time(self): - next = self.session.query(func.min(User.token_valid).label("min")).first() + def _wait_time(self, session): + next = session.query(func.min(User.token_valid).label("min")).first() if next is None or next.min is None: return 3 delta = next.min - datetime.utcnow() @@ -27,15 +28,16 @@ class Refresher: async def refresh(self): while True: - await asyncio.sleep(self._wait_time()) + with self.sessionmaker as session: + await asyncio.sleep(self._wait_time(session)) - ts = datetime.utcnow() + timedelta(seconds=self.grace) - refreshes = [self._refresh_user(user) for user - in self.session.query(User).filter(User.token_valid < ts).all()] - await asyncio.gather(*refreshes) + ts = datetime.utcnow() + timedelta(seconds=self.grace) + refreshes = [self._refresh_user(user) for user + in session.query(User).filter(User.token_valid < ts).all()] + await asyncio.gather(*refreshes) - self.session.commit() - self.session.remove() + session.commit() + session.remove() async def _refresh_user(self, user): logger.info("Refresh token for {}", user.name)