From ba7436f10c0698bef084537dd3b141fd321024c5 Mon Sep 17 00:00:00 2001 From: Stephan Kulow Date: Wed, 2 Nov 2022 18:07:24 +0100 Subject: [PATCH] Keep a reference to the database in DBRevision To avoid passing the db to all actions --- lib/db_revision.py | 68 +++++++++++++++++++++++++-------------------- lib/git_exporter.py | 6 ++-- lib/importer.py | 24 ++++++++-------- 3 files changed, 53 insertions(+), 45 deletions(-) diff --git a/lib/db_revision.py b/lib/db_revision.py index da37166..9f751f8 100644 --- a/lib/db_revision.py +++ b/lib/db_revision.py @@ -6,11 +6,12 @@ from pathlib import PurePath from typing import Optional from lib.db import DB +from lib.obs_revision import OBSRevision from lib.request import Request class DBRevision: - def __init__(self, row): + def __init__(self, db: DB, row: tuple): # need to stay in sync with the schema creation in db.py ( self.dbid, @@ -29,6 +30,7 @@ class DBRevision: ) = row self.rev = float(self.rev) self._files = None + self.db = db def short_string(self): return f"{self.project}/{self.package}/{self.rev}" @@ -49,7 +51,7 @@ class DBRevision: return self.package < other.package return self.rev < other.rev - def as_dict(self, db): + def as_dict(self): """Return a dict we can put into YAML for test cases""" ret = { "project": self.project, @@ -62,21 +64,21 @@ class DBRevision: "broken": self.broken, "expanded_srcmd5": self.expanded_srcmd5, "files_hash": self.files_hash, - "files": self.files_list(db), + "files": self.files_list(), } if self.request_id: - ret["request"] = Request.find(db, self.request_id).as_dict() + ret["request"] = Request.find(self.db, self.request_id).as_dict() return ret - def links_to(self, db, project, package): - with db.cursor() as cur: + def links_to(self, project: str, package: str) -> None: + with self.db.cursor() as cur: cur.execute( "INSERT INTO links (revision_id, project, package) VALUES (%s,%s,%s)", (self.dbid, project, package), ) - @classmethod - def import_obs_rev(cls, db, revision): + @staticmethod + def import_obs_rev(db: DB, revision: OBSRevision): with db.cursor() as cur: cur.execute( """INSERT INTO revisions (project, package, rev, unexpanded_srcmd5, commit_time, userid, comment, request_number) @@ -92,7 +94,9 @@ class DBRevision: revision.request_number, ), ) - return cls.fetch_revision(db, revision.project, revision.package, revision.rev) + return DBRevision.fetch_revision( + db, revision.project, revision.package, revision.rev + ) @staticmethod def fetch_revision(db, project, package, rev): @@ -103,7 +107,7 @@ class DBRevision: ) row = cur.fetchone() if row: - return DBRevision(row) + return DBRevision(db, row) @staticmethod def latest_revision(db, project, package): @@ -126,13 +130,13 @@ class DBRevision: ) ret = [] for row in cur.fetchall(): - ret.append(DBRevision(row)) + ret.append(DBRevision(db, row)) return ret - def linked_rev(self, db): + def linked_rev(self): if self.broken: return None - with db.cursor() as cur: + with self.db.cursor() as cur: cur.execute( "SELECT project,package FROM links where revision_id=%s", (self.dbid,) ) @@ -144,19 +148,19 @@ class DBRevision: "SELECT * FROM revisions where project=%s and package=%s and commit_time <= %s ORDER BY commit_time DESC LIMIT 1", (project, package, self.commit_time), ) - revisions = [DBRevision(row) for row in cur.fetchall()] + revisions = [DBRevision(self.db, row) for row in cur.fetchall()] if revisions: return revisions[0] else: - self.set_broken(db) + self.set_broken() return None - def set_broken(self, db): - with db.cursor() as cur: + def set_broken(self): + with self.db.cursor() as cur: cur.execute("UPDATE revisions SET broken=TRUE where id=%s", (self.dbid,)) - def import_dir_list(self, db, xml): - with db.cursor() as cur: + def import_dir_list(self, xml): + with self.db.cursor() as cur: cur.execute( "UPDATE revisions SET expanded_srcmd5=%s where id=%s", (xml.get("srcmd5"), self.dbid), @@ -174,15 +178,19 @@ class DBRevision: ), ) - def previous_commit(self, db): - return self.fetch_revision(db, self.project, self.package, int(self.rev) - 1) + def previous_commit(self): + return DBRevision.fetch_revision( + self.db, self.project, self.package, int(self.rev) - 1 + ) - def next_commit(self, db): - return self.fetch_revision(db, self.project, self.package, int(self.rev) + 1) + def next_commit(self): + return DBRevision.fetch_revision( + self.db, self.project, self.package, int(self.rev) + 1 + ) - def calculate_files_hash(self, db): + def calculate_files_hash(self): m = md5() - for file_dict in self.files_list(db): + for file_dict in self.files_list(): m.update( ( file_dict["name"] @@ -194,10 +202,10 @@ class DBRevision: ) return m.hexdigest() - def files_list(self, db): + def files_list(self): if self._files: return self._files - with db.cursor() as cur: + with self.db.cursor() as cur: cur.execute("SELECT * from files where revision_id=%s", (self.dbid,)) self._files = [] for row in cur.fetchall(): @@ -208,7 +216,7 @@ class DBRevision: self._files.sort(key=lambda x: x["name"]) return self._files - def calc_delta(self, db: DB, current_rev: Optional[DBRevision]): + def calc_delta(self, current_rev: Optional[DBRevision]): """Calculate the list of files to download and to delete. Param current_rev is the revision that's currently checked out. If it's None, the repository is empty. @@ -217,11 +225,11 @@ class DBRevision: to_delete = [] if current_rev: old_files = { - e["name"]: f"{e['md5']}-{e['size']}" for e in current_rev.files_list(db) + e["name"]: f"{e['md5']}-{e['size']}" for e in current_rev.files_list() } else: old_files = dict() - for entry in self.files_list(db): + for entry in self.files_list(): if old_files.get(entry["name"]) != f"{entry['md5']}-{entry['size']}": logging.debug(f"Download {entry['name']}") to_download.append((PurePath(entry["name"]), entry["md5"])) diff --git a/lib/git_exporter.py b/lib/git_exporter.py index 5f89f94..73b4b80 100644 --- a/lib/git_exporter.py +++ b/lib/git_exporter.py @@ -120,19 +120,19 @@ class GitExporter: self.git.gc() gc_cnt = self.gc_interval logging.debug(f"Committing {flat}") - self.commit_flat(db, flat, branch_state) + self.commit_flat(flat, branch_state) def limit_download(self, file: Path): return file.suffix in (".spec", ".changes") - def commit_flat(self, db, flat, branch_state): + def commit_flat(self, flat, branch_state): parents = [] self.git.checkout(flat.branch) if flat.parent1: parents.append(flat.parent1.git_commit) if flat.parent2: parents.append(flat.parent2.git_commit) - to_download, to_delete = flat.commit.calc_delta(db, branch_state[flat.branch]) + to_download, to_delete = flat.commit.calc_delta(branch_state[flat.branch]) for file in to_delete: if not self.limit_download(file): continue diff --git a/lib/importer.py b/lib/importer.py index 6445b6a..b5431b4 100644 --- a/lib/importer.py +++ b/lib/importer.py @@ -31,12 +31,12 @@ class Importer: try: root = rev.read_link() except ET.ParseError: - dbrev.set_broken(db) + dbrev.set_broken() continue if root is not None: tprj = root.get("project") or project tpkg = root.get("package") or package - dbrev.links_to(db, tprj, tpkg) + dbrev.links_to(tprj, tpkg) def find_linked_revs(self, db): with db.cursor() as cur: @@ -46,8 +46,8 @@ class Importer: WHERE lrevs.id IS NULL) and broken is FALSE;""" ) for row in cur.fetchall(): - rev = DBRevision(row) - linked_rev = rev.linked_rev(db) + rev = DBRevision(db, row) + linked_rev = rev.linked_rev() if not linked_rev: logging.debug(f"No link {rev}") continue @@ -75,10 +75,10 @@ class Importer: "SELECT * from revisions WHERE id in (SELECT linked_id from linked_revs WHERE considered=FALSE)" ) for row in cur.fetchall(): - self._find_fake_revision(db, DBRevision(row)) + self._find_fake_revision(db, DBRevision(db, row)) def _find_fake_revision(self, db, rev): - prev = rev.previous_commit(db) + prev = rev.previous_commit() if not prev: with db.cursor() as cur: cur.execute( @@ -95,8 +95,8 @@ class Importer: ) last_linked = None for linked in cur.fetchall(): - linked = DBRevision(linked) - nextrev = linked.next_commit(db) + linked = DBRevision(db, linked) + nextrev = linked.next_commit() if nextrev and nextrev.commit_time < rev.commit_time: continue last_linked = linked @@ -149,7 +149,7 @@ class Importer: cur.execute( "SELECT * FROM revisions WHERE broken=FALSE AND expanded_srcmd5 IS NULL" ) - return [DBRevision(row) for row in cur.fetchall()] + return [DBRevision(db, row) for row in cur.fetchall()] def fill_file_lists(self, db): self.find_linked_revs(db) @@ -169,15 +169,15 @@ class Importer: rev.project, rev.package, rev.unexpanded_srcmd5, linked_rev ) if list: - rev.import_dir_list(db, list) - md5 = rev.calculate_files_hash(db) + rev.import_dir_list(list) + md5 = rev.calculate_files_hash() with db.cursor() as cur: cur.execute( "UPDATE revisions SET files_hash=%s WHERE id=%s", (md5, rev.dbid), ) else: - rev.set_broken(db) + rev.set_broken() def refresh_package(self, db, project, package): key = f"{project}/{package}"