Keep a reference to the database in DBRevision

To avoid passing the db to all actions
This commit is contained in:
Stephan Kulow 2022-11-02 18:07:24 +01:00
parent 75f9f56a57
commit ba7436f10c
3 changed files with 53 additions and 45 deletions

View File

@ -6,11 +6,12 @@ from pathlib import PurePath
from typing import Optional from typing import Optional
from lib.db import DB from lib.db import DB
from lib.obs_revision import OBSRevision
from lib.request import Request from lib.request import Request
class DBRevision: class DBRevision:
def __init__(self, row): def __init__(self, db: DB, row: tuple):
# need to stay in sync with the schema creation in db.py # need to stay in sync with the schema creation in db.py
( (
self.dbid, self.dbid,
@ -29,6 +30,7 @@ class DBRevision:
) = row ) = row
self.rev = float(self.rev) self.rev = float(self.rev)
self._files = None self._files = None
self.db = db
def short_string(self): def short_string(self):
return f"{self.project}/{self.package}/{self.rev}" return f"{self.project}/{self.package}/{self.rev}"
@ -49,7 +51,7 @@ class DBRevision:
return self.package < other.package return self.package < other.package
return self.rev < other.rev return self.rev < other.rev
def as_dict(self, db): def as_dict(self):
"""Return a dict we can put into YAML for test cases""" """Return a dict we can put into YAML for test cases"""
ret = { ret = {
"project": self.project, "project": self.project,
@ -62,21 +64,21 @@ class DBRevision:
"broken": self.broken, "broken": self.broken,
"expanded_srcmd5": self.expanded_srcmd5, "expanded_srcmd5": self.expanded_srcmd5,
"files_hash": self.files_hash, "files_hash": self.files_hash,
"files": self.files_list(db), "files": self.files_list(),
} }
if self.request_id: if self.request_id:
ret["request"] = Request.find(db, self.request_id).as_dict() ret["request"] = Request.find(self.db, self.request_id).as_dict()
return ret return ret
def links_to(self, db, project, package): def links_to(self, project: str, package: str) -> None:
with db.cursor() as cur: with self.db.cursor() as cur:
cur.execute( cur.execute(
"INSERT INTO links (revision_id, project, package) VALUES (%s,%s,%s)", "INSERT INTO links (revision_id, project, package) VALUES (%s,%s,%s)",
(self.dbid, project, package), (self.dbid, project, package),
) )
@classmethod @staticmethod
def import_obs_rev(cls, db, revision): def import_obs_rev(db: DB, revision: OBSRevision):
with db.cursor() as cur: with db.cursor() as cur:
cur.execute( cur.execute(
"""INSERT INTO revisions (project, package, rev, unexpanded_srcmd5, commit_time, userid, comment, request_number) """INSERT INTO revisions (project, package, rev, unexpanded_srcmd5, commit_time, userid, comment, request_number)
@ -92,7 +94,9 @@ class DBRevision:
revision.request_number, revision.request_number,
), ),
) )
return cls.fetch_revision(db, revision.project, revision.package, revision.rev) return DBRevision.fetch_revision(
db, revision.project, revision.package, revision.rev
)
@staticmethod @staticmethod
def fetch_revision(db, project, package, rev): def fetch_revision(db, project, package, rev):
@ -103,7 +107,7 @@ class DBRevision:
) )
row = cur.fetchone() row = cur.fetchone()
if row: if row:
return DBRevision(row) return DBRevision(db, row)
@staticmethod @staticmethod
def latest_revision(db, project, package): def latest_revision(db, project, package):
@ -126,13 +130,13 @@ class DBRevision:
) )
ret = [] ret = []
for row in cur.fetchall(): for row in cur.fetchall():
ret.append(DBRevision(row)) ret.append(DBRevision(db, row))
return ret return ret
def linked_rev(self, db): def linked_rev(self):
if self.broken: if self.broken:
return None return None
with db.cursor() as cur: with self.db.cursor() as cur:
cur.execute( cur.execute(
"SELECT project,package FROM links where revision_id=%s", (self.dbid,) "SELECT project,package FROM links where revision_id=%s", (self.dbid,)
) )
@ -144,19 +148,19 @@ class DBRevision:
"SELECT * FROM revisions where project=%s and package=%s and commit_time <= %s ORDER BY commit_time DESC LIMIT 1", "SELECT * FROM revisions where project=%s and package=%s and commit_time <= %s ORDER BY commit_time DESC LIMIT 1",
(project, package, self.commit_time), (project, package, self.commit_time),
) )
revisions = [DBRevision(row) for row in cur.fetchall()] revisions = [DBRevision(self.db, row) for row in cur.fetchall()]
if revisions: if revisions:
return revisions[0] return revisions[0]
else: else:
self.set_broken(db) self.set_broken()
return None return None
def set_broken(self, db): def set_broken(self):
with db.cursor() as cur: with self.db.cursor() as cur:
cur.execute("UPDATE revisions SET broken=TRUE where id=%s", (self.dbid,)) cur.execute("UPDATE revisions SET broken=TRUE where id=%s", (self.dbid,))
def import_dir_list(self, db, xml): def import_dir_list(self, xml):
with db.cursor() as cur: with self.db.cursor() as cur:
cur.execute( cur.execute(
"UPDATE revisions SET expanded_srcmd5=%s where id=%s", "UPDATE revisions SET expanded_srcmd5=%s where id=%s",
(xml.get("srcmd5"), self.dbid), (xml.get("srcmd5"), self.dbid),
@ -174,15 +178,19 @@ class DBRevision:
), ),
) )
def previous_commit(self, db): def previous_commit(self):
return self.fetch_revision(db, self.project, self.package, int(self.rev) - 1) return DBRevision.fetch_revision(
self.db, self.project, self.package, int(self.rev) - 1
)
def next_commit(self, db): def next_commit(self):
return self.fetch_revision(db, self.project, self.package, int(self.rev) + 1) return DBRevision.fetch_revision(
self.db, self.project, self.package, int(self.rev) + 1
)
def calculate_files_hash(self, db): def calculate_files_hash(self):
m = md5() m = md5()
for file_dict in self.files_list(db): for file_dict in self.files_list():
m.update( m.update(
( (
file_dict["name"] file_dict["name"]
@ -194,10 +202,10 @@ class DBRevision:
) )
return m.hexdigest() return m.hexdigest()
def files_list(self, db): def files_list(self):
if self._files: if self._files:
return self._files return self._files
with db.cursor() as cur: with self.db.cursor() as cur:
cur.execute("SELECT * from files where revision_id=%s", (self.dbid,)) cur.execute("SELECT * from files where revision_id=%s", (self.dbid,))
self._files = [] self._files = []
for row in cur.fetchall(): for row in cur.fetchall():
@ -208,7 +216,7 @@ class DBRevision:
self._files.sort(key=lambda x: x["name"]) self._files.sort(key=lambda x: x["name"])
return self._files return self._files
def calc_delta(self, db: DB, current_rev: Optional[DBRevision]): def calc_delta(self, current_rev: Optional[DBRevision]):
"""Calculate the list of files to download and to delete. """Calculate the list of files to download and to delete.
Param current_rev is the revision that's currently checked out. Param current_rev is the revision that's currently checked out.
If it's None, the repository is empty. If it's None, the repository is empty.
@ -217,11 +225,11 @@ class DBRevision:
to_delete = [] to_delete = []
if current_rev: if current_rev:
old_files = { old_files = {
e["name"]: f"{e['md5']}-{e['size']}" for e in current_rev.files_list(db) e["name"]: f"{e['md5']}-{e['size']}" for e in current_rev.files_list()
} }
else: else:
old_files = dict() old_files = dict()
for entry in self.files_list(db): for entry in self.files_list():
if old_files.get(entry["name"]) != f"{entry['md5']}-{entry['size']}": if old_files.get(entry["name"]) != f"{entry['md5']}-{entry['size']}":
logging.debug(f"Download {entry['name']}") logging.debug(f"Download {entry['name']}")
to_download.append((PurePath(entry["name"]), entry["md5"])) to_download.append((PurePath(entry["name"]), entry["md5"]))

View File

@ -120,19 +120,19 @@ class GitExporter:
self.git.gc() self.git.gc()
gc_cnt = self.gc_interval gc_cnt = self.gc_interval
logging.debug(f"Committing {flat}") logging.debug(f"Committing {flat}")
self.commit_flat(db, flat, branch_state) self.commit_flat(flat, branch_state)
def limit_download(self, file: Path): def limit_download(self, file: Path):
return file.suffix in (".spec", ".changes") return file.suffix in (".spec", ".changes")
def commit_flat(self, db, flat, branch_state): def commit_flat(self, flat, branch_state):
parents = [] parents = []
self.git.checkout(flat.branch) self.git.checkout(flat.branch)
if flat.parent1: if flat.parent1:
parents.append(flat.parent1.git_commit) parents.append(flat.parent1.git_commit)
if flat.parent2: if flat.parent2:
parents.append(flat.parent2.git_commit) parents.append(flat.parent2.git_commit)
to_download, to_delete = flat.commit.calc_delta(db, branch_state[flat.branch]) to_download, to_delete = flat.commit.calc_delta(branch_state[flat.branch])
for file in to_delete: for file in to_delete:
if not self.limit_download(file): if not self.limit_download(file):
continue continue

View File

@ -31,12 +31,12 @@ class Importer:
try: try:
root = rev.read_link() root = rev.read_link()
except ET.ParseError: except ET.ParseError:
dbrev.set_broken(db) dbrev.set_broken()
continue continue
if root is not None: if root is not None:
tprj = root.get("project") or project tprj = root.get("project") or project
tpkg = root.get("package") or package tpkg = root.get("package") or package
dbrev.links_to(db, tprj, tpkg) dbrev.links_to(tprj, tpkg)
def find_linked_revs(self, db): def find_linked_revs(self, db):
with db.cursor() as cur: with db.cursor() as cur:
@ -46,8 +46,8 @@ class Importer:
WHERE lrevs.id IS NULL) and broken is FALSE;""" WHERE lrevs.id IS NULL) and broken is FALSE;"""
) )
for row in cur.fetchall(): for row in cur.fetchall():
rev = DBRevision(row) rev = DBRevision(db, row)
linked_rev = rev.linked_rev(db) linked_rev = rev.linked_rev()
if not linked_rev: if not linked_rev:
logging.debug(f"No link {rev}") logging.debug(f"No link {rev}")
continue continue
@ -75,10 +75,10 @@ class Importer:
"SELECT * from revisions WHERE id in (SELECT linked_id from linked_revs WHERE considered=FALSE)" "SELECT * from revisions WHERE id in (SELECT linked_id from linked_revs WHERE considered=FALSE)"
) )
for row in cur.fetchall(): for row in cur.fetchall():
self._find_fake_revision(db, DBRevision(row)) self._find_fake_revision(db, DBRevision(db, row))
def _find_fake_revision(self, db, rev): def _find_fake_revision(self, db, rev):
prev = rev.previous_commit(db) prev = rev.previous_commit()
if not prev: if not prev:
with db.cursor() as cur: with db.cursor() as cur:
cur.execute( cur.execute(
@ -95,8 +95,8 @@ class Importer:
) )
last_linked = None last_linked = None
for linked in cur.fetchall(): for linked in cur.fetchall():
linked = DBRevision(linked) linked = DBRevision(db, linked)
nextrev = linked.next_commit(db) nextrev = linked.next_commit()
if nextrev and nextrev.commit_time < rev.commit_time: if nextrev and nextrev.commit_time < rev.commit_time:
continue continue
last_linked = linked last_linked = linked
@ -149,7 +149,7 @@ class Importer:
cur.execute( cur.execute(
"SELECT * FROM revisions WHERE broken=FALSE AND expanded_srcmd5 IS NULL" "SELECT * FROM revisions WHERE broken=FALSE AND expanded_srcmd5 IS NULL"
) )
return [DBRevision(row) for row in cur.fetchall()] return [DBRevision(db, row) for row in cur.fetchall()]
def fill_file_lists(self, db): def fill_file_lists(self, db):
self.find_linked_revs(db) self.find_linked_revs(db)
@ -169,15 +169,15 @@ class Importer:
rev.project, rev.package, rev.unexpanded_srcmd5, linked_rev rev.project, rev.package, rev.unexpanded_srcmd5, linked_rev
) )
if list: if list:
rev.import_dir_list(db, list) rev.import_dir_list(list)
md5 = rev.calculate_files_hash(db) md5 = rev.calculate_files_hash()
with db.cursor() as cur: with db.cursor() as cur:
cur.execute( cur.execute(
"UPDATE revisions SET files_hash=%s WHERE id=%s", "UPDATE revisions SET files_hash=%s WHERE id=%s",
(md5, rev.dbid), (md5, rev.dbid),
) )
else: else:
rev.set_broken(db) rev.set_broken()
def refresh_package(self, db, project, package): def refresh_package(self, db, project, package):
key = f"{project}/{package}" key = f"{project}/{package}"