from xmlrpc.client import Boolean

from lib.db_revision import DBRevision
from lib.flat_walker import FlatTreeWalker
from lib.request import Request


class AbstractWalker:
    """Visitor interface used by ``TreeNode.walk``.

    Subclasses override ``call``; this base implementation ignores every node.
    """

    def call(self, node, is_source):
        """Visit ``node``.

        ``is_source`` is False for nodes of the walked chain itself and True for
        nodes reached through a node's ``merged`` link.
        """
        return None


class PrintWalker(AbstractWalker):
    """Walker that prints one line per visited node to stdout.

    Source (merged) nodes are indented; target nodes additionally show which
    source revision was merged into them, if any.
    """

    def call(self, node, is_source):
        rev = node.revision
        if is_source:
            print("   ", rev.short_string(), rev.files_hash)
            return
        merge_str = (
            f"merged:{node.merged.revision.short_string()}" if node.merged else ""
        )
        print(rev.short_string(), rev.files_hash, merge_str)


class TreeNode:
    """
    Nodes in this "tree" have either no parent (root), one parent (in a chain)
    or two parents (in this case the merged revision wins in conflicts).

    ``parent`` links to the previous revision of the same package, ``merged``
    to the revision of another package merged into this one, and
    ``merged_into`` is the reverse link of ``merged``.
    """

    def __init__(self, rev):
        self.parent = None
        self.merged = None
        self.revision = rev
        self.merged_into = None
        # not set in this module; presumably filled in when exporting to git
        self.git_commit = None

    def _merged_sources(self):
        """Yield the source nodes merged into this node, starting at ``merged``.

        Follows ``parent`` links and stops before a node that was itself merged
        into some (other) target node, so the part of the source chain below an
        earlier merge point is left to that merge.
        """
        source_node = self.merged
        while source_node:
            yield source_node
            source_node = source_node.parent
            if source_node and source_node.merged_into:
                break

    def walk(self, walker: "AbstractWalker"):
        """Visit this node and its ancestors, following ``parent`` links.

        ``walker.call(node, False)`` is invoked for each node of this chain,
        followed by ``walker.call(source, True)`` for every source node merged
        into it (see ``_merged_sources``).
        """
        node = self
        while node:
            walker.call(node, False)
            for source_node in node._merged_sources():
                walker.call(source_node, True)
            node = node.parent

    def print(self):
        self.walk(PrintWalker())

    def as_flat_list(self):
        """Return the tree as git commits to do"""
        ftw = FlatTreeWalker()
        self.walk(ftw)
        return ftw.flats

    def as_list(self):
        """Return a list for test cases

        One dict per node of the chain: ``{"commit": <short string>}``, plus a
        ``"merged"`` list with the short strings of the merged source nodes for
        nodes that have a merge.
        """
        ret = []
        node = self
        while node:
            entry = {"commit": node.revision.short_string()}
            if node.merged:
                entry["merged"] = [
                    source.revision.short_string()
                    for source in node._merged_sources()
                ]
            ret.append(entry)
            node = node.parent
        return ret


class TreeBuilder:
    """Build the revision tree of a (Factory) package, including the merge
    points into the source packages of the requests it accepted."""

    def __init__(self, db):
        # passed to DBRevision.all_revisions and used via self.db.cursor()
        self.db = db

    def revisions_chain(self, project, package):
        """Build a tree without branches (chain) from a project's
        history ignoring broken revisions and revisions that did not change
        the files

        Returns the node of the last revision in sort order, whose ``parent``
        links lead back to the first one, or None if no revision was kept.
        """
        revisions = DBRevision.all_revisions(self.db, project, package)
        revisions.sort()
        prev = None
        tree = None
        for rev in revisions:
            if rev.broken:
                continue
            # same files as the previously kept revision: nothing to add
            if prev and prev.files_hash == rev.files_hash:
                continue
            prev = rev
            new_tree = TreeNode(rev)
            new_tree.parent = tree
            tree = new_tree

        return tree

    def find_merge(self, revision, source_chain):
        """For a given revision in the target, find the node in the source chain
        that matches the files

        A candidate has the same ``files_hash`` as ``revision`` and a
        ``commit_time`` not after it. Candidates at or above a source node that
        was already merged elsewhere are discarded. Returns the first remaining
        candidate in ``parent`` order, or None.
        """
        node = source_chain
        candidates = []
        while node:
            # exclude reverts happening after the merge
            if (
                node.revision.commit_time <= revision.commit_time
                and node.revision.files_hash == revision.files_hash
            ):
                candidates.append(node)
            if node.merged_into:
                # we can't have candidates that are crossing previous merges
                # see https://src.opensuse.org/importers/git-importer/issues/14
                candidates = []
            node = node.parent
        if candidates:
            # the first candidate is the youngest one that matches the check. That's
            # good enough. See FastCGI test case for rev 36 and 38: 37 reverted 36 and
            # then 38 reverting the revert before it was submitted.
            return candidates[0]
        return None

    def add_merge_points(self, factory_revisions):
        """For all target revisions that accepted a request, look up the merge
        points in the source chains (ignoring the actual revision submitted for now)

        Sets ``merged`` on the target node and ``merged_into`` on the matching
        source node. Does nothing for an empty (None) tree.
        """
        if factory_revisions is None:
            return

        class FindRequestsWalker(AbstractWalker):
            def __init__(self) -> None:
                super().__init__()
                self.requests = set()

            def call(self, node: TreeNode, _: bool) -> None:
                if not node.revision.request_id:
                    return
                self.requests.add(node.revision.request_id)

        class FindMergeWalker(AbstractWalker):
            def __init__(self, builder: TreeBuilder, requests: dict) -> None:
                super().__init__()
                # cache of source chains, keyed by "project/package"
                self.source_revisions = dict()
                self.builder = builder
                self.requests = requests

            def call(self, node, is_source) -> None:
                # not going to happen, but better safe
                if is_source:
                    return
                if not node.revision.request_id:
                    return
                req = self.requests.get(node.revision.request_id)
                if req is None:
                    # the request is not in the database, so there is no source
                    # package to search: keep the revision as a plain commit
                    return
                key = f"{req.source_project}/{req.source_package}"
                if key not in self.source_revisions:
                    self.source_revisions[key] = self.builder.revisions_chain(
                        req.source_project, req.source_package
                    )
                node.merged = self.builder.find_merge(
                    node.revision, self.source_revisions[key]
                )
                # add a reverse lookup
                if node.merged:
                    node.merged.merged_into = node

        # walk the tree twice. First we collect all requests to be looked up
        # to avoid going into the DB a thousand times
        frqs = FindRequestsWalker()
        factory_revisions.walk(frqs)
        requests = dict()
        if frqs.requests:
            with self.db.cursor() as cur:
                cur.execute(
                    "SELECT * from requests WHERE id = ANY(%s)", (list(frqs.requests),)
                )
                for row in cur.fetchall():
                    req = Request.from_db(row)
                    requests[req.dbid] = req
        sw = FindMergeWalker(self, requests)
        factory_revisions.walk(sw)

    def prune_loose_end(self, factory_node):
        """Look for source revisions that end in a new root and prune them

        Only the two merges reached last when following ``parent`` links are
        touched: the source chain of the second-to-last one is cut right before
        a node that was merged elsewhere (if any), and the last one is
        dissolved. If the node holding that last merge has no parent, the
        merged source chain becomes its parent.
        """
        merge_before_last = None
        last_merge = None
        while factory_node:
            if factory_node.merged:
                merge_before_last, last_merge = last_merge, factory_node
            factory_node = factory_node.parent

        # a package without requests
        if not last_merge:
            return
        if merge_before_last:
            # we need to find the last merged_into that didn't end nowhere
            # and cut the rope there
            node = merge_before_last.merged
            last_node = None
            while node:
                last_node = node
                node = node.parent
                if node and node.merged_into:
                    break
            if last_node:
                last_node.parent = None

        if not last_merge.parent:
            last_merge.parent = last_merge.merged
        last_merge.merged.merged_into = None
        last_merge.merged = None

    def build(self, project, package):
        """Create a Factory tree (returning the top)

        Returns None when the package has no usable revisions.
        """
        factory_revisions = self.revisions_chain(project, package)
        self.add_merge_points(factory_revisions)
        # factory_revisions.print() before/after pruning helps debugging merges
        self.prune_loose_end(factory_revisions)
        return factory_revisions