from hashlib import md5

from lib.request import Request


class DBRevision:
    """In-memory view of one row of the ``revisions`` table.

    Instances are built from a raw ``SELECT * FROM revisions`` row. The helper
    methods run further queries through a ``db`` object whose ``cursor()`` is
    used as a context manager.
    """

    def __init__(self, row):
        # need to stay in sync with the schema creation in db.py
        (
            self.dbid,
            self.project,
            self.package,
            self.rev,
            self.unexpanded_srcmd5,
            self.commit_time,
            self.userid,
            self.comment,
            self.broken,
            self.expanded_srcmd5,
            self.request_number,
            self.request_id,
            self.files_hash,
        ) = row
        # Normalize to float so revisions order numerically in __lt__.
        self.rev = float(self.rev)
        # Cache for files_list(); None means "not loaded yet" (an empty list
        # is a valid, cached result).
        self._files = None

    def short_string(self):
        """Return a compact ``Rev project/package/rev`` label."""
        return f"Rev {self.project}/{self.package}/{self.rev}"

    def __str__(self):
        return f"Rev {self.project}/{self.package}/{self.rev} Md5 {self.unexpanded_srcmd5} {self.commit_time} {self.userid} {self.request_number}"

    def __repr__(self):
        return f"[{self.__str__()}]"

    def __eq__(self, other):
        """Two revisions are equal when they refer to the same database row."""
        if not isinstance(other, DBRevision):
            return NotImplemented
        return self.dbid == other.dbid

    def __hash__(self):
        # Defining __eq__ alone makes instances unhashable; hash on the same
        # key __eq__ compares so revisions work in sets and as dict keys.
        return hash(self.dbid)

    def __lt__(self, other):
        """Order by project, then package, then numeric revision."""
        if not isinstance(other, DBRevision):
            return NotImplemented
        return (self.project, self.package, self.rev) < (
            other.project,
            other.package,
            other.rev,
        )

    def as_dict(self, db):
        """Return a dict we can put into YAML for test cases"""
        ret = {
            "project": self.project,
            "package": self.package,
            "rev": self.rev,
            "unexpanded_srcmd5": self.unexpanded_srcmd5,
            "commit_time": self.commit_time,
            "userid": self.userid,
            "comment": self.comment,
            "broken": self.broken,
            "expanded_srcmd5": self.expanded_srcmd5,
            "files_hash": self.files_hash,
            "files": self.files_list(db),
        }
        if self.request_id:
            ret["request"] = Request.find(db, self.request_id).as_dict()
        return ret

    def links_to(self, db, project, package):
        """Record in ``links`` that this revision links to project/package."""
        with db.cursor() as cur:
            cur.execute(
                "INSERT INTO links (revision_id, project, package) VALUES (%s,%s,%s)",
                (self.dbid, project, package),
            )

    @classmethod
    def import_obs_rev(cls, db, revision):
        """Insert an OBS revision object into ``revisions`` and reload it.

        Returns the freshly stored row as a DBRevision (via fetch_revision).
        """
        with db.cursor() as cur:
            cur.execute(
                """INSERT INTO revisions (project, package, rev, unexpanded_srcmd5, commit_time, userid, comment, request_number)  
                        VALUES(%s, %s, %s, %s, %s, %s, %s, %s)""",
                (
                    revision.project,
                    revision.package,
                    revision.rev,
                    revision.unexpanded_srcmd5,
                    revision.time,
                    revision.userid,
                    revision.comment,
                    revision.request_number,
                ),
            )
        return cls.fetch_revision(db, revision.project, revision.package, revision.rev)

    @staticmethod
    def fetch_revision(db, project, package, rev):
        """Return the DBRevision for project/package/rev, or None if absent."""
        with db.cursor() as cur:
            cur.execute(
                "SELECT * FROM revisions where project=%s and package=%s and rev=%s",
                (project, package, str(rev)),
            )
            row = cur.fetchone()
        if row:
            return DBRevision(row)
        return None

    @staticmethod
    def latest_revision(db, project, package):
        """Return the revision with the highest ``rev`` for a package, or None."""
        with db.cursor() as cur:
            cur.execute(
                "SELECT MAX(rev) FROM revisions where project=%s and package=%s",
                (project, package),
            )
            max_rev = cur.fetchone()[0]
        # MAX() yields NULL (None) when the package has no revisions.
        if max_rev is not None:
            return DBRevision.fetch_revision(db, project, package, max_rev)
        return None

    @staticmethod
    def all_revisions(db, project, package):
        """Return every stored revision of project/package (unordered)."""
        with db.cursor() as cur:
            cur.execute(
                "SELECT * FROM revisions where project=%s and package=%s",
                (project, package),
            )
            return [DBRevision(row) for row in cur.fetchall()]

    def linked_rev(self, db):
        """Return the revision this one links to, as of this commit's time.

        Looks up the link target in ``links`` and returns that package's latest
        revision with ``commit_time`` <= this revision's ``commit_time``.
        Returns None if this revision is broken or has no link. If a link
        exists but no matching target revision is found, the revision is
        marked broken and None is returned.
        """
        if self.broken:
            return None
        with db.cursor() as cur:
            cur.execute(
                "SELECT project,package FROM links where revision_id=%s", (self.dbid,)
            )
            link = cur.fetchone()
            if not link:
                return None
            project, package = link
            cur.execute(
                "SELECT * FROM revisions where project=%s and package=%s and commit_time <= %s ORDER BY commit_time DESC LIMIT 1",
                (project, package, self.commit_time),
            )
            target = cur.fetchone()
        if target:
            return DBRevision(target)
        self.set_broken(db)
        return None

    def set_broken(self, db):
        """Mark this revision broken in the DB and on this instance."""
        with db.cursor() as cur:
            cur.execute("UPDATE revisions SET broken=TRUE where id=%s", (self.dbid,))
        # Keep in-memory state in sync with the row we just updated.
        self.broken = True

    def import_dir_list(self, db, xml):
        """Store a directory listing for this revision.

        ``xml`` is an ElementTree-style element: its ``srcmd5`` attribute
        becomes ``expanded_srcmd5`` and each ``entry`` child becomes a row in
        ``files``.
        """
        srcmd5 = xml.get("srcmd5")
        with db.cursor() as cur:
            cur.execute(
                "UPDATE revisions SET expanded_srcmd5=%s where id=%s",
                (srcmd5, self.dbid),
            )
            for entry in xml.findall("entry"):
                cur.execute(
                    """INSERT INTO files (name, md5, size, mtime, revision_id) 
                            VALUES (%s,%s,%s,%s,%s)""",
                    (
                        entry.get("name"),
                        entry.get("md5"),
                        entry.get("size"),
                        entry.get("mtime"),
                        self.dbid,
                    ),
                )
        self.expanded_srcmd5 = srcmd5
        # New files were inserted, so any cached files_list() result is stale.
        self._files = None

    def previous_commit(self, db):
        """Return the revision numbered one below this one, or None."""
        return self.fetch_revision(db, self.project, self.package, int(self.rev) - 1)

    def next_commit(self, db):
        """Return the revision numbered one above this one, or None."""
        return self.fetch_revision(db, self.project, self.package, int(self.rev) + 1)

    def calculate_files_hash(self, db):
        """Return an MD5 hex digest over ``name/md5/size`` of every file.

        Files are fed in files_list() order (sorted by name), so the digest
        does not depend on DB row order.
        """
        m = md5()
        for file_dict in self.files_list(db):
            m.update(
                (
                    file_dict["name"]
                    + "/"
                    + file_dict["md5"]
                    + "/"
                    + str(file_dict["size"])
                ).encode("utf-8")
            )
        return m.hexdigest()

    def files_list(self, db):
        """Return this revision's files as dicts sorted by ``name``.

        Each dict has the keys ``name``, ``md5``, ``size`` and ``mtime``. The
        result (including an empty list) is cached on the instance.
        """
        if self._files is not None:
            return self._files
        with db.cursor() as cur:
            # Name the columns explicitly rather than relying on the table's
            # column order for positional unpacking.
            cur.execute(
                "SELECT name, md5, size, mtime FROM files WHERE revision_id=%s",
                (self.dbid,),
            )
            files = [
                {"md5": file_md5, "size": size, "mtime": mtime, "name": name}
                for name, file_md5, size, mtime in cur.fetchall()
            ]
        files.sort(key=lambda f: f["name"])
        self._files = files
        return files

    @staticmethod
    def requests_to_fetch(db):
        """Return request numbers referenced by revisions but not in ``requests``.

        The query has no DISTINCT, so a number may appear more than once when
        several revisions share it.
        """
        with db.cursor() as cur:
            cur.execute(
                """SELECT request_number FROM revisions revs LEFT JOIN requests
                reqs ON reqs.number=revs.request_number WHERE reqs.id is null AND
                revs.request_number IS NOT NULL""",
            )
            return [row[0] for row in cur.fetchall()]

    @staticmethod
    def import_fixture_dict(db, rev_dict):
        """Used in test cases to read a revision from fixtures into the test database

        Inserts the revision, its ``files`` and, if present, its ``request``
        (linking the revision to the new request row).
        """
        with db.cursor() as cur:
            cur.execute(
                """INSERT INTO revisions (project, package, rev, unexpanded_srcmd5, expanded_srcmd5, 
                                        commit_time, userid, comment, broken, files_hash)  
                        VALUES(%s, %s, %s, %s, %s, %s, %s, %s, %s, %s) RETURNING id""",
                (
                    rev_dict["project"],
                    rev_dict["package"],
                    rev_dict["rev"],
                    rev_dict["unexpanded_srcmd5"],
                    rev_dict["expanded_srcmd5"],
                    rev_dict["commit_time"],
                    rev_dict["userid"],
                    rev_dict["comment"],
                    rev_dict["broken"],
                    rev_dict["files_hash"],
                ),
            )
            rev_id = cur.fetchone()[0]
            for file_dict in rev_dict["files"]:
                cur.execute(
                    "INSERT INTO files (md5, mtime, name, size, revision_id) VALUES(%s, %s, %s, %s, %s)",
                    (
                        file_dict["md5"],
                        file_dict["mtime"],
                        file_dict["name"],
                        file_dict["size"],
                        rev_id,
                    ),
                )
            request = rev_dict.get("request")
            if request:
                cur.execute(
                    """INSERT INTO requests (creator, number, source_project, source_package,
                            source_rev, state, type) VALUES (%s, %s, %s, %s, %s, %s, %s) RETURNING id""",
                    (
                        request["creator"],
                        request["number"],
                        request.get("source_project"),
                        request.get("source_package"),
                        request.get("source_rev"),
                        request["state"],
                        request["type"],
                    ),
                )
                request_id = cur.fetchone()[0]
                cur.execute(
                    "UPDATE revisions SET request_id=%s, request_number=%s WHERE id=%s",
                    (request_id, request["number"], rev_id),
                )