diff --git a/ptbcontrib/postgres_persistence/postgrespersistence.py b/ptbcontrib/postgres_persistence/postgrespersistence.py index 085e3f4..c5eacdd 100644 --- a/ptbcontrib/postgres_persistence/postgrespersistence.py +++ b/ptbcontrib/postgres_persistence/postgrespersistence.py @@ -47,6 +47,8 @@ class PostgresPersistence(DictPersistence): persistence instance. """ + PERSISTENCE_ID = 1 + def __init__( self, url: str = None, @@ -57,7 +59,7 @@ def __init__( if url: if not url.startswith("postgresql://"): raise TypeError(f"{url} isn't a valid PostgreSQL database URL.") - engine = create_engine(url, client_encoding="utf8") + engine = create_engine(url, client_encoding="utf8", pool_pre_ping=True) self._session = scoped_session(sessionmaker(bind=engine, autoflush=False)) elif session: @@ -90,9 +92,11 @@ def __init__( # `UPDATE` operations if column have some data already present inside it. if not data: upsert_qry = """ - INSERT INTO persistence (data) VALUES (:jsondata) + INSERT INTO persistence (id, data) VALUES (:id, :jsondata) ON CONFLICT (id) DO UPDATE SET data = :jsondata""" - self._session.execute(text(upsert_qry), {"jsondata": "{}"}) + self._session.execute( + text(upsert_qry), {"id": self.PERSISTENCE_ID, "jsondata": "{}"} + ) self._session.commit() super().__init__( @@ -104,7 +108,7 @@ def __init__( conversations_json=conversations_json, ) finally: - self._session.close() + self._session.remove() def __init_database(self) -> None: """ @@ -113,11 +117,11 @@ def __init_database(self) -> None: runs schema migration if necessary. """ try: - create_table_qry = """ + create_table_qry = f""" CREATE TABLE IF NOT EXISTS persistence( - id INT PRIMARY KEY DEFAULT 1, + id INT PRIMARY KEY DEFAULT {self.PERSISTENCE_ID}, data json NOT NULL, - CONSTRAINT single_row CHECK (id = 1));""" + CONSTRAINT single_row CHECK (id = {self.PERSISTENCE_ID}));""" self._session.execute(text(create_table_qry)) # Check if id column exists, is an integer type, and is a primary key @@ -141,8 +145,9 @@ def __init_database(self) -> None: data_valid = False if column_valid: check_data_qry = """ - SELECT 1 FROM persistence WHERE id = 1;""" - data_valid = self._session.execute(text(check_data_qry)).first() is not None + SELECT 1 FROM persistence WHERE id = :id;""" + result = self._session.execute(text(check_data_qry), {"id": self.PERSISTENCE_ID}) + data_valid = result.first() is not None needs_migration = not (column_valid and data_valid) @@ -150,17 +155,18 @@ def __init_database(self) -> None: self.logger.info("Old database schema detected. Running migration...") migration_commands = [ "ALTER TABLE persistence ADD COLUMN id INT;", - """ - UPDATE persistence SET id = 1 WHERE ctid = ( - SELECT ctid FROM persistence LIMIT 1 - );""", + """UPDATE persistence SET id = :id WHERE ctid = (" + "SELECT ctid FROM persistence LIMIT 1);""", "DELETE FROM persistence WHERE id IS NULL;", "ALTER TABLE persistence ALTER COLUMN id SET NOT NULL;", "ALTER TABLE persistence ADD PRIMARY KEY (id);", - "ALTER TABLE persistence ADD CONSTRAINT single_row CHECK (id = 1);", + "ALTER TABLE persistence ADD CONSTRAINT single_row CHECK (id = :id);", ] for command in migration_commands: - self._session.execute(text(command)) + if ":id" in command: + self._session.execute(text(command), {"id": self.PERSISTENCE_ID}) + else: + self._session.execute(text(command)) self.logger.info("Database migration successful!") self._session.commit() @@ -187,9 +193,9 @@ def _update_database(self) -> None: self.logger.debug("Updating database...") try: upsert_qry = """ - INSERT INTO persistence (data) VALUES (:jsondata) + INSERT INTO persistence (id, data) VALUES (:id, :jsondata) ON CONFLICT (id) DO UPDATE SET data = :jsondata""" - params = {"jsondata": self._dump_into_json()} + params = {"id": self.PERSISTENCE_ID, "jsondata": self._dump_into_json()} self._session.execute(text(upsert_qry), params) self._session.commit() except Exception as excp: # pylint: disable=W0703 @@ -198,6 +204,8 @@ def _update_database(self) -> None: exc_info=excp, ) self._session.rollback() + finally: + self._session.remove() async def update_conversation( self, name: str, key: Tuple[int, ...], new_state: Optional[object] diff --git a/tests/test_postgres_persistence.py b/tests/test_postgres_persistence.py index 045c139..d4ff8c0 100644 --- a/tests/test_postgres_persistence.py +++ b/tests/test_postgres_persistence.py @@ -81,7 +81,7 @@ async def test_with_handler(self, bot, update, monkeypatch): session = scoped_session("a") monkeypatch.setattr(session, "execute", self.mocked_execute) monkeypatch.setattr(session, "commit", self.mock_commit) - monkeypatch.setattr(session, "close", self.mock_ses_close) + monkeypatch.setattr(session, "remove", self.mock_ses_close) app = ( Application.builder() @@ -128,7 +128,7 @@ async def test_on_flush(self, bot, update, monkeypatch, on_flush, expected): session = scoped_session("a") monkeypatch.setattr(session, "execute", self.mocked_execute) monkeypatch.setattr(session, "commit", self.mock_commit) - monkeypatch.setattr(session, "close", self.mock_ses_close) + monkeypatch.setattr(session, "remove", self.mock_ses_close) persistence = PostgresPersistence(session=session, on_flush=on_flush) @@ -169,13 +169,13 @@ def test_load_on_boot(self, monkeypatch): session = scoped_session("a") monkeypatch.setattr(session, "execute", self.mocked_execute) monkeypatch.setattr(session, "commit", self.mock_commit) - monkeypatch.setattr(session, "close", self.mock_ses_close) + monkeypatch.setattr(session, "remove", self.mock_ses_close) PostgresPersistence(session=session) # Check that either SELECT or UPSERT query was executed (upsert for fresh db) executed_text = self.executed.text.strip() assert "SELECT data FROM persistence" in executed_text or ( - "INSERT INTO persistence (data) VALUES (:jsondata)" in executed_text + "INSERT INTO persistence (id, data) VALUES (:id, :jsondata)" in executed_text and "ON CONFLICT (id) DO UPDATE SET data = :jsondata" in executed_text ) assert self.commited == 555 @@ -185,7 +185,7 @@ async def test_flush(self, monkeypatch): session = scoped_session("a") monkeypatch.setattr(session, "execute", self.mocked_execute) monkeypatch.setattr(session, "commit", self.mock_commit) - monkeypatch.setattr(session, "close", self.mock_ses_close) + monkeypatch.setattr(session, "remove", self.mock_ses_close) await PostgresPersistence(session=session).flush() assert self.executed != "" @@ -202,7 +202,7 @@ def mock_execute(query, params=None): # pylint: disable=unused-argument session = scoped_session("a") monkeypatch.setattr(session, "execute", mock_execute) monkeypatch.setattr(session, "commit", self.mock_commit) - monkeypatch.setattr(session, "close", self.mock_ses_close) + monkeypatch.setattr(session, "remove", self.mock_ses_close) PostgresPersistence(session=session) @@ -252,7 +252,7 @@ def mock_execute(query, params=None): return FakeExecResultValidPK() # Check for data validation query (id=1 exists) - if "WHERE id = 1" in query.text and "information_schema" not in query.text: + if "WHERE id = :id" in query.text and "information_schema" not in query.text: return FakeExecResultValidData() return FakeExecResult() @@ -260,14 +260,14 @@ def mock_execute(query, params=None): session = scoped_session("a") monkeypatch.setattr(session, "execute", mock_execute) monkeypatch.setattr(session, "commit", self.mock_commit) - monkeypatch.setattr(session, "close", self.mock_ses_close) + monkeypatch.setattr(session, "remove", self.mock_ses_close) PostgresPersistence(session=session) # Verify no migration commands were run migration_commands = [ "ALTER TABLE persistence ADD COLUMN id INT", - "UPDATE persistence SET id = 1", + "UPDATE persistence SET id = :id", "DELETE FROM persistence WHERE id IS NULL", ] for migration_cmd in migration_commands: @@ -310,7 +310,7 @@ def mock_execute(query, params=None): session = scoped_session("a") monkeypatch.setattr(session, "execute", mock_execute) monkeypatch.setattr(session, "commit", self.mock_commit) - monkeypatch.setattr(session, "close", self.mock_ses_close) + monkeypatch.setattr(session, "remove", self.mock_ses_close) PostgresPersistence(session=session) @@ -340,18 +340,18 @@ def mock_execute(query, params=None): # pylint: disable=unused-argument session = scoped_session("a") monkeypatch.setattr(session, "execute", mock_execute) monkeypatch.setattr(session, "commit", self.mock_commit) - monkeypatch.setattr(session, "close", self.mock_ses_close) + monkeypatch.setattr(session, "remove", self.mock_ses_close) PostgresPersistence(session=session) # Verify migration commands were run in correct order expected_migration_steps = [ "ALTER TABLE persistence ADD COLUMN id INT", - "UPDATE persistence SET id = 1", + "UPDATE persistence SET id = :id", "DELETE FROM persistence WHERE id IS NULL", "ALTER TABLE persistence ALTER COLUMN id SET NOT NULL", "ALTER TABLE persistence ADD PRIMARY KEY (id)", - "ALTER TABLE persistence ADD CONSTRAINT single_row CHECK (id = 1)", + "ALTER TABLE persistence ADD CONSTRAINT single_row CHECK (id = :id)", ] for expected_step in expected_migration_steps: @@ -385,7 +385,7 @@ def mock_execute(query, params=None): session = scoped_session("a") monkeypatch.setattr(session, "execute", mock_execute) monkeypatch.setattr(session, "commit", self.mock_commit) - monkeypatch.setattr(session, "close", self.mock_ses_close) + monkeypatch.setattr(session, "remove", self.mock_ses_close) persistence = PostgresPersistence(session=session) @@ -403,13 +403,13 @@ def mock_execute(query, params=None): # pylint: disable=unused-argument session = scoped_session("a") monkeypatch.setattr(session, "execute", mock_execute) monkeypatch.setattr(session, "commit", self.mock_commit) - monkeypatch.setattr(session, "close", self.mock_ses_close) + monkeypatch.setattr(session, "remove", self.mock_ses_close) PostgresPersistence(session=session) # Check that upsert query was used for initialization upsert_found = any( - "INSERT INTO persistence (data) VALUES (:jsondata)" in query + "INSERT INTO persistence (id, data) VALUES (:id, :jsondata)" in query and "ON CONFLICT (id) DO UPDATE SET data = :jsondata" in query for query in executed_queries ) @@ -429,7 +429,7 @@ def mock_execute(query, params=None): session = scoped_session("a") monkeypatch.setattr(session, "execute", mock_execute) monkeypatch.setattr(session, "commit", self.mock_commit) - monkeypatch.setattr(session, "close", self.mock_ses_close) + monkeypatch.setattr(session, "remove", self.mock_ses_close) monkeypatch.setattr(session, "rollback", lambda: None) persistence = PostgresPersistence(session=session) @@ -443,7 +443,7 @@ def mock_execute(query, params=None): # Verify upsert query was used upsert_found = any( - "INSERT INTO persistence (data) VALUES (:jsondata)" in query + "INSERT INTO persistence (id, data) VALUES (:id, :jsondata)" in query and "ON CONFLICT (id) DO UPDATE SET data = :jsondata" in query for query in executed_queries ) @@ -452,6 +452,7 @@ def mock_execute(query, params=None): # Verify parameters were passed assert len(executed_params) > 0 assert "jsondata" in executed_params[0] + assert "id" in executed_params[0] def test_single_row_constraint_in_schema(self, monkeypatch): """Test that single_row constraint is present in schema""" @@ -464,7 +465,7 @@ def mock_execute(query, params=None): # pylint: disable=unused-argument session = scoped_session("a") monkeypatch.setattr(session, "execute", mock_execute) monkeypatch.setattr(session, "commit", self.mock_commit) - monkeypatch.setattr(session, "close", self.mock_ses_close) + monkeypatch.setattr(session, "remove", self.mock_ses_close) PostgresPersistence(session=session) @@ -493,13 +494,13 @@ def mock_execute(query, params=None): # pylint: disable=unused-argument session = scoped_session("a") monkeypatch.setattr(session, "execute", mock_execute) monkeypatch.setattr(session, "commit", self.mock_commit) - monkeypatch.setattr(session, "close", self.mock_ses_close) + monkeypatch.setattr(session, "remove", self.mock_ses_close) PostgresPersistence(session=session) # Verify that migration includes step to update first row to id=1 update_first_row = any( - "UPDATE persistence SET id = 1" in query and "LIMIT 1" in query + "UPDATE persistence SET id = :id" in query and "LIMIT 1" in query for query in executed_queries ) assert update_first_row @@ -515,7 +516,7 @@ async def test_data_persistence_with_upsert(self, bot, update, monkeypatch): session = scoped_session("a") monkeypatch.setattr(session, "execute", self.mocked_execute) monkeypatch.setattr(session, "commit", self.mock_commit) - monkeypatch.setattr(session, "close", self.mock_ses_close) + monkeypatch.setattr(session, "remove", self.mock_ses_close) app = ( Application.builder() @@ -563,7 +564,7 @@ def mock_rollback(): session = scoped_session("a") monkeypatch.setattr(session, "execute", mock_execute_with_error) monkeypatch.setattr(session, "commit", self.mock_commit) - monkeypatch.setattr(session, "close", self.mock_ses_close) + monkeypatch.setattr(session, "remove", self.mock_ses_close) monkeypatch.setattr(session, "rollback", mock_rollback) # Should not raise exception, but handle it gracefully @@ -589,7 +590,7 @@ def mock_execute(query, params=None): session = scoped_session("a") monkeypatch.setattr(session, "execute", mock_execute) monkeypatch.setattr(session, "commit", self.mock_commit) - monkeypatch.setattr(session, "close", self.mock_ses_close) + monkeypatch.setattr(session, "remove", self.mock_ses_close) PostgresPersistence(session=session)