Merge pull request 'Keep a reference to the database in DBRevision' (#12) from add_export into main

Reviewed-on: https://gitea.opensuse.org/importers/git-importer/pulls/12
This commit is contained in:
coolo 2022-11-03 08:16:30 +01:00
commit 74f5cd901e
3 changed files with 87 additions and 104 deletions

View File

@@ -2,15 +2,16 @@ from __future__ import annotations
import logging import logging
from hashlib import md5 from hashlib import md5
from pathlib import PurePath from pathlib import Path
from typing import Optional from typing import Optional
from lib.db import DB from lib.db import DB
from lib.obs_revision import OBSRevision
from lib.request import Request from lib.request import Request
class DBRevision: class DBRevision:
def __init__(self, row): def __init__(self, db: DB, row: tuple):
# need to stay in sync with the schema creation in db.py # need to stay in sync with the schema creation in db.py
( (
self.dbid, self.dbid,
@@ -29,6 +30,7 @@ class DBRevision:
) = row ) = row
self.rev = float(self.rev) self.rev = float(self.rev)
self._files = None self._files = None
self.db = db
def short_string(self): def short_string(self):
return f"{self.project}/{self.package}/{self.rev}" return f"{self.project}/{self.package}/{self.rev}"
@@ -49,7 +51,7 @@
return self.package < other.package return self.package < other.package
return self.rev < other.rev return self.rev < other.rev
def as_dict(self, db): def as_dict(self):
"""Return a dict we can put into YAML for test cases""" """Return a dict we can put into YAML for test cases"""
ret = { ret = {
"project": self.project, "project": self.project,
@@ -62,21 +64,21 @@ class DBRevision:
"broken": self.broken, "broken": self.broken,
"expanded_srcmd5": self.expanded_srcmd5, "expanded_srcmd5": self.expanded_srcmd5,
"files_hash": self.files_hash, "files_hash": self.files_hash,
"files": self.files_list(db), "files": self.files_list(),
} }
if self.request_id: if self.request_id:
ret["request"] = Request.find(db, self.request_id).as_dict() ret["request"] = Request.find(self.db, self.request_id).as_dict()
return ret return ret
def links_to(self, db, project, package): def links_to(self, project: str, package: str) -> None:
with db.cursor() as cur: with self.db.cursor() as cur:
cur.execute( cur.execute(
"INSERT INTO links (revision_id, project, package) VALUES (%s,%s,%s)", "INSERT INTO links (revision_id, project, package) VALUES (%s,%s,%s)",
(self.dbid, project, package), (self.dbid, project, package),
) )
@classmethod @staticmethod
def import_obs_rev(cls, db, revision): def import_obs_rev(db: DB, revision: OBSRevision):
with db.cursor() as cur: with db.cursor() as cur:
cur.execute( cur.execute(
"""INSERT INTO revisions (project, package, rev, unexpanded_srcmd5, commit_time, userid, comment, request_number) """INSERT INTO revisions (project, package, rev, unexpanded_srcmd5, commit_time, userid, comment, request_number)
@@ -92,7 +94,9 @@ class DBRevision:
revision.request_number, revision.request_number,
), ),
) )
return cls.fetch_revision(db, revision.project, revision.package, revision.rev) return DBRevision.fetch_revision(
db, revision.project, revision.package, revision.rev
)
@staticmethod @staticmethod
def fetch_revision(db, project, package, rev): def fetch_revision(db, project, package, rev):
@@ -103,7 +107,7 @@
) )
row = cur.fetchone() row = cur.fetchone()
if row: if row:
return DBRevision(row) return DBRevision(db, row)
@staticmethod @staticmethod
def latest_revision(db, project, package): def latest_revision(db, project, package):
@@ -126,13 +130,13 @@
) )
ret = [] ret = []
for row in cur.fetchall(): for row in cur.fetchall():
ret.append(DBRevision(row)) ret.append(DBRevision(db, row))
return ret return ret
def linked_rev(self, db): def linked_rev(self):
if self.broken: if self.broken:
return None return None
with db.cursor() as cur: with self.db.cursor() as cur:
cur.execute( cur.execute(
"SELECT project,package FROM links where revision_id=%s", (self.dbid,) "SELECT project,package FROM links where revision_id=%s", (self.dbid,)
) )
@@ -144,19 +148,19 @@
"SELECT * FROM revisions where project=%s and package=%s and commit_time <= %s ORDER BY commit_time DESC LIMIT 1", "SELECT * FROM revisions where project=%s and package=%s and commit_time <= %s ORDER BY commit_time DESC LIMIT 1",
(project, package, self.commit_time), (project, package, self.commit_time),
) )
revisions = [DBRevision(row) for row in cur.fetchall()] revisions = [DBRevision(self.db, row) for row in cur.fetchall()]
if revisions: if revisions:
return revisions[0] return revisions[0]
else: else:
self.set_broken(db) self.set_broken()
return None return None
def set_broken(self, db): def set_broken(self):
with db.cursor() as cur: with self.db.cursor() as cur:
cur.execute("UPDATE revisions SET broken=TRUE where id=%s", (self.dbid,)) cur.execute("UPDATE revisions SET broken=TRUE where id=%s", (self.dbid,))
def import_dir_list(self, db, xml): def import_dir_list(self, xml):
with db.cursor() as cur: with self.db.cursor() as cur:
cur.execute( cur.execute(
"UPDATE revisions SET expanded_srcmd5=%s where id=%s", "UPDATE revisions SET expanded_srcmd5=%s where id=%s",
(xml.get("srcmd5"), self.dbid), (xml.get("srcmd5"), self.dbid),
@@ -174,15 +178,19 @@
), ),
) )
def previous_commit(self, db): def previous_commit(self):
return self.fetch_revision(db, self.project, self.package, int(self.rev) - 1) return DBRevision.fetch_revision(
self.db, self.project, self.package, int(self.rev) - 1
)
def next_commit(self, db): def next_commit(self):
return self.fetch_revision(db, self.project, self.package, int(self.rev) + 1) return DBRevision.fetch_revision(
self.db, self.project, self.package, int(self.rev) + 1
)
def calculate_files_hash(self, db): def calculate_files_hash(self):
m = md5() m = md5()
for file_dict in self.files_list(db): for file_dict in self.files_list():
m.update( m.update(
( (
file_dict["name"] file_dict["name"]
@@ -194,10 +202,10 @@
) )
return m.hexdigest() return m.hexdigest()
def files_list(self, db): def files_list(self):
if self._files: if self._files:
return self._files return self._files
with db.cursor() as cur: with self.db.cursor() as cur:
cur.execute("SELECT * from files where revision_id=%s", (self.dbid,)) cur.execute("SELECT * from files where revision_id=%s", (self.dbid,))
self._files = [] self._files = []
for row in cur.fetchall(): for row in cur.fetchall():
@@ -208,7 +216,7 @@
self._files.sort(key=lambda x: x["name"]) self._files.sort(key=lambda x: x["name"])
return self._files return self._files
def calc_delta(self, db: DB, current_rev: Optional[DBRevision]): def calc_delta(self, current_rev: Optional[DBRevision]):
"""Calculate the list of files to download and to delete. """Calculate the list of files to download and to delete.
Param current_rev is the revision that's currently checked out. Param current_rev is the revision that's currently checked out.
If it's None, the repository is empty. If it's None, the repository is empty.
@@ -217,18 +225,18 @@
to_delete = [] to_delete = []
if current_rev: if current_rev:
old_files = { old_files = {
e["name"]: f"{e['md5']}-{e['size']}" for e in current_rev.files_list(db) e["name"]: f"{e['md5']}-{e['size']}" for e in current_rev.files_list()
} }
else: else:
old_files = dict() old_files = dict()
for entry in self.files_list(db): for entry in self.files_list():
if old_files.get(entry["name"]) != f"{entry['md5']}-{entry['size']}": if old_files.get(entry["name"]) != f"{entry['md5']}-{entry['size']}":
logging.debug(f"Download {entry['name']}") logging.debug(f"Download {entry['name']}")
to_download.append((PurePath(entry["name"]), entry["md5"])) to_download.append((Path(entry["name"]), entry["size"], entry["md5"]))
old_files.pop(entry["name"], None) old_files.pop(entry["name"], None)
for entry in old_files.keys(): for entry in old_files.keys():
logging.debug(f"Delete {entry}") logging.debug(f"Delete {entry}")
to_delete.append(PurePath(entry)) to_delete.append(Path(entry))
return to_download, to_delete return to_download, to_delete
@staticmethod @staticmethod

View File

@@ -1,6 +1,5 @@
import logging import logging
import os import os
from pathlib import Path
import yaml import yaml
@@ -42,44 +41,10 @@ class GitExporter:
# Download each file in OBS if it is not a binary (or large) # Download each file in OBS if it is not a binary (or large)
# file # file
for (name, size, file_md5) in obs_files: for (name, size, file_md5) in obs_files:
# this file creates easily 100k commits and is just useless data :( # Validate the MD5 of the downloaded file
# unfortunately it's stored in the same meta package as the project config if md5(self.git.path / name) != file_md5:
if revision.package == "_project" and name == "_staging_workflow": raise Exception(f"Download error in {name}")
continue self.git.add(name)
# have such files been detected as text mimetype before?
is_text = self.proxy_sha256.is_text(name)
if not is_text and is_binary_or_large(name, size):
file_sha256 = self.proxy_sha256.get_or_put(
revision.project,
revision.package,
name,
revision.srcmd5,
file_md5,
size,
)
self.git.add_lfs(name, file_sha256["sha256"], size)
else:
if (name, size, file_md5) not in git_files:
logging.debug(f"Download {name}")
self.obs.download(
revision.project,
revision.package,
name,
revision.srcmd5,
self.git.path,
file_md5=file_md5,
)
# Validate the MD5 of the downloaded file
if md5(self.git.path / name) != file_md5:
raise Exception(f"Download error in {name}")
self.git.add(name)
# Remove extra files
obs_names = {n for (n, _, _) in obs_files}
git_names = {n for (n, _, _) in git_files}
for name in git_names - obs_names:
logging.debug(f"Remove {name}")
self.git.remove(name)
def set_gc_interval(self, gc): def set_gc_interval(self, gc):
self.gc_interval = gc self.gc_interval = gc
@@ -120,35 +85,45 @@ class GitExporter:
self.git.gc() self.git.gc()
gc_cnt = self.gc_interval gc_cnt = self.gc_interval
logging.debug(f"Committing {flat}") logging.debug(f"Committing {flat}")
self.commit_flat(db, flat, branch_state) self.commit_flat(flat, branch_state)
def limit_download(self, file: Path): def commit_flat(self, flat, branch_state):
return file.suffix in (".spec", ".changes")
def commit_flat(self, db, flat, branch_state):
parents = [] parents = []
self.git.checkout(flat.branch) self.git.checkout(flat.branch)
# Overwrite ".gitattributes" with the
self.git.add_default_lfs_gitattributes(force=True)
if flat.parent1: if flat.parent1:
parents.append(flat.parent1.git_commit) parents.append(flat.parent1.git_commit)
if flat.parent2: if flat.parent2:
parents.append(flat.parent2.git_commit) parents.append(flat.parent2.git_commit)
to_download, to_delete = flat.commit.calc_delta(db, branch_state[flat.branch]) to_download, to_delete = flat.commit.calc_delta(branch_state[flat.branch])
for file in to_delete: for file in to_delete:
if not self.limit_download(file):
continue
self.git.remove(file) self.git.remove(file)
for file, md5 in to_download: for file, size, md5 in to_download:
if not self.limit_download(file): # have such files been detected as text mimetype before?
continue is_text = self.proxy_sha256.is_text(file.name)
self.obs.download( if not is_text and is_binary_or_large(file.name, size):
flat.commit.project, file_sha256 = self.proxy_sha256.get_or_put(
flat.commit.package, flat.commit.project,
file.name, flat.commit.package,
flat.commit.expanded_srcmd5, file.name,
self.git.path, flat.commit.expanded_srcmd5,
file_md5=md5, md5,
) size,
self.git.add(file) )
self.git.add_lfs(file.name, file_sha256["sha256"], size)
else:
self.obs.download(
flat.commit.project,
flat.commit.package,
file.name,
flat.commit.expanded_srcmd5,
self.git.path,
file_md5=md5,
)
self.git.add(file)
commit = self.git.commit( commit = self.git.commit(
f"OBS User {flat.commit.userid}", f"OBS User {flat.commit.userid}",

View File

@@ -31,12 +31,12 @@ class Importer:
try: try:
root = rev.read_link() root = rev.read_link()
except ET.ParseError: except ET.ParseError:
dbrev.set_broken(db) dbrev.set_broken()
continue continue
if root is not None: if root is not None:
tprj = root.get("project") or project tprj = root.get("project") or project
tpkg = root.get("package") or package tpkg = root.get("package") or package
dbrev.links_to(db, tprj, tpkg) dbrev.links_to(tprj, tpkg)
def find_linked_revs(self, db): def find_linked_revs(self, db):
with db.cursor() as cur: with db.cursor() as cur:
@@ -46,8 +46,8 @@
WHERE lrevs.id IS NULL) and broken is FALSE;""" WHERE lrevs.id IS NULL) and broken is FALSE;"""
) )
for row in cur.fetchall(): for row in cur.fetchall():
rev = DBRevision(row) rev = DBRevision(db, row)
linked_rev = rev.linked_rev(db) linked_rev = rev.linked_rev()
if not linked_rev: if not linked_rev:
logging.debug(f"No link {rev}") logging.debug(f"No link {rev}")
continue continue
@@ -75,10 +75,10 @@ class Importer:
"SELECT * from revisions WHERE id in (SELECT linked_id from linked_revs WHERE considered=FALSE)" "SELECT * from revisions WHERE id in (SELECT linked_id from linked_revs WHERE considered=FALSE)"
) )
for row in cur.fetchall(): for row in cur.fetchall():
self._find_fake_revision(db, DBRevision(row)) self._find_fake_revision(db, DBRevision(db, row))
def _find_fake_revision(self, db, rev): def _find_fake_revision(self, db, rev):
prev = rev.previous_commit(db) prev = rev.previous_commit()
if not prev: if not prev:
with db.cursor() as cur: with db.cursor() as cur:
cur.execute( cur.execute(
@@ -95,8 +95,8 @@ class Importer:
) )
last_linked = None last_linked = None
for linked in cur.fetchall(): for linked in cur.fetchall():
linked = DBRevision(linked) linked = DBRevision(db, linked)
nextrev = linked.next_commit(db) nextrev = linked.next_commit()
if nextrev and nextrev.commit_time < rev.commit_time: if nextrev and nextrev.commit_time < rev.commit_time:
continue continue
last_linked = linked last_linked = linked
@@ -149,7 +149,7 @@ class Importer:
cur.execute( cur.execute(
"SELECT * FROM revisions WHERE broken=FALSE AND expanded_srcmd5 IS NULL" "SELECT * FROM revisions WHERE broken=FALSE AND expanded_srcmd5 IS NULL"
) )
return [DBRevision(row) for row in cur.fetchall()] return [DBRevision(db, row) for row in cur.fetchall()]
def fill_file_lists(self, db): def fill_file_lists(self, db):
self.find_linked_revs(db) self.find_linked_revs(db)
@@ -169,15 +169,15 @@ class Importer:
rev.project, rev.package, rev.unexpanded_srcmd5, linked_rev rev.project, rev.package, rev.unexpanded_srcmd5, linked_rev
) )
if list: if list:
rev.import_dir_list(db, list) rev.import_dir_list(list)
md5 = rev.calculate_files_hash(db) md5 = rev.calculate_files_hash()
with db.cursor() as cur: with db.cursor() as cur:
cur.execute( cur.execute(
"UPDATE revisions SET files_hash=%s WHERE id=%s", "UPDATE revisions SET files_hash=%s WHERE id=%s",
(md5, rev.dbid), (md5, rev.dbid),
) )
else: else:
rev.set_broken(db) rev.set_broken()
def refresh_package(self, db, project, package): def refresh_package(self, db, project, package):
key = f"{project}/{package}" key = f"{project}/{package}"