import hashlib
import logging

try:
    import magic
except:
    print("Install python3-python-magic, not python3-magic")
    raise

from lib.db import DB
from lib.lfs_oid import LFSOid
from lib.obs import OBS


class ProxySHA256:
    """Look up, or compute and record, SHA-256 digests of OBS package files.

    Digests already stored in the database are loaded per package and cached
    in memory. Files not yet known are downloaded from OBS, hashed, and
    registered through ``LFSOid``.
    """

    def __init__(self, obs: "OBS", db: "DB"):
        self.obs = obs
        self.db = db
        # "{file_md5}-{filename}" -> (sha256, size) for self._hashes_package;
        # None until loaded.
        self.hashes = None
        # Filenames listed in text_files for self._texts_package; None until loaded.
        self.texts = None
        # Package each cache above was loaded for. The caches are keyed by
        # md5/filename only, so they must be reloaded when the package changes.
        self._hashes_package = None
        self._texts_package = None
        # python-magic MIME detector, created lazily on the first put().
        self.mime = None

    def get(self, package, name, file_md5):
        """Return the recorded ``(sha256, size)`` of a file, or None if unknown.

        The per-package table is (re)loaded from the database on first use and
        whenever *package* differs from the package currently cached.
        """
        if self.hashes is None or self._hashes_package != package:
            self.load_hashes(package)
        return self.hashes.get(f"{file_md5}-{name}")

    def load_hashes(self, package):
        """Cache every recorded (md5, filename) -> (sha256, size) entry of *package*."""
        with self.db.cursor() as cur:
            cur.execute(
                """SELECT lfs_oids.file_md5,lop.filename,lfs_oids.sha256,lfs_oids.size
               FROM lfs_oid_in_package lop
               JOIN lfs_oids ON lfs_oids.id=lop.lfs_oid_id
               WHERE lop.package=%s""",
                (package,),
            )
            self.hashes = {
                f"{md5}-{filename}": (sha256, size)
                for md5, filename, sha256, size in cur.fetchall()
            }
        self._hashes_package = package

    def put(self, project, package, name, revision, file_md5, size):
        """Download a file from OBS, record its SHA-256 and MIME type, and look it up.

        Returns what :meth:`get` finds for the file after the insert, which is
        normally ``(sha256, size)``.
        """
        if not self.mime:
            self.mime = magic.Magic(mime=True)

        logging.debug("Add LFS for %s/%s/%s", project, package, name)
        sha = hashlib.sha256()
        mimetype = None
        fin = self.obs._download(project, package, name, revision)
        try:
            while True:
                buffer = fin.read(10000)
                if not buffer:
                    break
                sha.update(buffer)
                # Guess the MIME type from the first chunk that yields a result.
                if not mimetype:
                    mimetype = self.mime.from_buffer(buffer)
        finally:
            # Close the download even if reading or hashing raises.
            fin.close()

        LFSOid(self.db).add(
            project, package, name, revision, sha.hexdigest(), size, mimetype, file_md5
        )

        # Invalidate the caches so the lookup below sees the new row.
        self.hashes = None
        self.texts = None
        self._hashes_package = None
        self._texts_package = None
        return self.get(package, name, file_md5)

    def is_text(self, package, filename):
        """Return True if *filename* is listed in ``text_files`` for *package*."""
        if self.texts is None or self._texts_package != package:
            self.load_texts(package)
        return filename in self.texts

    def load_texts(self, package):
        """Cache the set of filenames recorded in ``text_files`` for *package*."""
        with self.db.cursor() as cur:
            cur.execute("SELECT filename from text_files where package=%s", (package,))
            # Assign only after the query succeeds, so a failed load is retried
            # instead of leaving an empty set cached.
            self.texts = {row[0] for row in cur.fetchall()}
        self._texts_package = package

    def get_or_put(self, project, package, name, revision, file_md5, size):
        """Return the SHA-256 of a file, downloading and recording it if unknown.

        Raises:
            AssertionError: if the size recorded in the database differs from *size*.
        """
        result = self.get(package, name, file_md5)
        if result is None:
            result = self.put(project, package, name, revision, file_md5, size)

        sha256, db_size = result
        # Explicit raise (not `assert`) so the check still runs under `python -O`.
        if db_size != size:
            raise AssertionError(
                f"size mismatch for {package}/{name}: recorded {db_size}, expected {size}"
            )
        return sha256