diff --git a/lib/importer.py b/lib/importer.py index 3dc6da2..5618206 100644 --- a/lib/importer.py +++ b/lib/importer.py @@ -2,8 +2,6 @@ import functools import logging import xml.etree.ElementTree as ET -import psycopg2 - from lib.binary import is_binary_or_large from lib.db import DB from lib.db_revision import DBRevision @@ -12,7 +10,7 @@ from lib.history import History from lib.obs import OBS from lib.obs_revision import OBSRevision from lib.proxy_sha256 import ProxySHA256, md5, sha256 -from lib.tree_builder import TreeBuilder +from lib.tree_builder import AbstractWalker, TreeBuilder from lib.user import User @@ -261,7 +259,13 @@ class Importer: def export_as_git(self): db = DB() - TreeBuilder(db).build(self.package).print() + tree = TreeBuilder(db).build(self.package) + + class ExportWalker(AbstractWalker): + def call(self, node, is_source): + pass + + tree.walk(ExportWalker()) def import_into_db(self): db = DB() diff --git a/lib/request.py b/lib/request.py index fcdba9f..23429c7 100644 --- a/lib/request.py +++ b/lib/request.py @@ -61,21 +61,22 @@ class Request: with db.cursor() as cur: cur.execute("""SELECT * from requests WHERE id=%s""", (request_id,)) row = cur.fetchone() - ret = Request() - ret._from_db(row) - return ret + return Request.from_db(row) - def _from_db(self, row): + @staticmethod + def from_db(row): + ret = Request() ( - self.dbid, - self.number, - self.creator, - self.type_, - self.state, - self.source_package, - self.source_project, - self.source_rev, + ret.dbid, + ret.number, + ret.creator, + ret.type_, + ret.state, + ret.source_package, + ret.source_project, + ret.source_rev, ) = row + return ret def as_dict(self): return { diff --git a/lib/tree_builder.py b/lib/tree_builder.py index 41baa32..15ff847 100644 --- a/lib/tree_builder.py +++ b/lib/tree_builder.py @@ -1,3 +1,6 @@ +from typing import Dict +from xmlrpc.client import Boolean + from lib.db_revision import DBRevision from lib.request import Request @@ -103,24 +106,57 @@ class TreeBuilder: def add_merge_points(self, factory_revisions): """For all target revisions that accepted a request, look up the merge points in the source chains (ignoring the actual revision submitted for now)""" - source_revisions = dict() - factory_node = factory_revisions - while factory_node: - if factory_node.revision.request_id: - req = Request.find(self.db, factory_node.revision.request_id) + + class FindRequestsWalker(AbstractWalker): + def __init__(self) -> None: + super().__init__() + self.requests = set() + + def call(self, node: TreeNode, _: Boolean) -> None: + if not node.revision.request_id: + return + self.requests.add(node.revision.request_id) + + class FindMergeWalker(AbstractWalker): + def __init__(self, builder: TreeBuilder, requests: Dict) -> None: + super().__init__() + self.source_revisions = dict() + self.builder = builder + self.requests = requests + + def call(self, node, is_source) -> None: + # not going to happen, but better safe + if is_source: + return + if not node.revision.request_id: + return + req = self.requests.get(node.revision.request_id) key = f"{req.source_project}/{req.source_package}" - if key not in source_revisions: - source_revisions[key] = self.revisions_chain( + if key not in self.source_revisions: + self.source_revisions[key] = self.builder.revisions_chain( req.source_project, req.source_package ) - factory_node.merged = self.find_merge( - factory_node.revision, source_revisions[key] + node.merged = self.builder.find_merge( + node.revision, self.source_revisions[key] ) # add a reverse lookup - if factory_node.merged: - factory_node.merged.merged_into = factory_node + if node.merged: + node.merged.merged_into = node - factory_node = factory_node.parent + # walk the tree twice. First we collect all requests to be looked up + # to avoid going into the DB a thousand times + frqs = FindRequestsWalker() + factory_revisions.walk(frqs) + requests = dict() + with self.db.cursor() as cur: + cur.execute( + "SELECT * from requests WHERE id = ANY(%s)", (list(frqs.requests),) + ) + for row in cur.fetchall(): + req = Request.from_db(row) + requests[req.dbid] = req + sw = FindMergeWalker(self, requests) + factory_revisions.walk(sw) def prune_loose_end(self, factory_node): """Look for source revisions that end in a new root and prune them"""