"""Upload job."""

import hashlib
import logging
import shutil
import time
from datetime import datetime
from pathlib import Path
from typing import Any, NoReturn

import pytz
from pdf2image import convert_from_path
from PIL import Image, ImageDraw, ImageEnhance
from pypdf import PdfReader, PdfWriter
from pyzbar.pyzbar import decode

from octopus.activity_tracking import RejectionReason
from octopus.clients import init_squirro_client
from octopus.email import STREAM_NAME as EMAIL_STREAM_NAME
from octopus.email import EmailPayload
from octopus.stream import Streamer, add_to_stream
from octopus.uploader import (
    GROUP_NAME,
    STREAM_NAME,
    UploadPayload,
    generate_failed_file_payload,
)
from octopus.utils import (
    check_is_file_valid,
    compute_hash,
    extract_zip,
    get_mime_type,
    is_zip_file,
    load_config,
    set_log_verbosity,
)
from squirro_client import DocumentUploader

# Configure log verbosity once, at import time, for the whole job.
set_log_verbosity(logging.INFO)

# Accumulated size in bytes (10 MiB) of queued files at which
# ``UploadStreamer._upload_to_squirro`` triggers a flush of the uploader.
MAX_BATCH_SIZE = 10 * 1024 * 1024
# Largest top-level file in bytes (50 MiB) that ``UploadStreamer.process``
# accepts; anything bigger is reported as ``RejectionReason.FILE_TOO_LARGE``.
MAX_FILE_SIZE = 50 * 1024 * 1024
# Number of upload attempts made per file before it is reported as failed.
MAX_RETRY = 3


class FailToUploadError(Exception):
    """Raised when a file could not be uploaded to Squirro."""

    def __init__(self, fp: "Path") -> None:
        """Build the error message from the name of the file that failed.

        Args:
            fp: Path of the file whose upload failed; only its final
                component (``fp.name``) appears in the message.
        """
        message = f"Failed to upload {fp.name} to Squirro."
        super().__init__(message)


class UploadStreamer(Streamer):
    """Streamer that validates uploaded files and sends them to Squirro.

    For every payload, ``process`` validates each file, splits PDFs on
    separator barcodes, extracts ZIP archives and hands the resulting files
    to a ``DocumentUploader``. ``postprocess`` flushes what is still queued
    and emails a report when some files were rejected.
    """

    # Sum of the sizes (bytes) of the files queued since the last flush;
    # reset to 0 by ``_batch_upload``.
    batch_size: int = 0
    # Files handed to the uploader and not yet flushed; each one is deleted
    # from disk by ``_batch_upload`` once the flush has been issued.
    current_batch: list[Path]
    uploaded_files: set[Path]  # Track uploaded files to avoid duplicates
    # Payload currently being processed; replaced by an empty payload at the
    # end of ``postprocess``.
    upload_payload: UploadPayload
    # Squirro document uploader, created in ``preprocess``.
    uploader: DocumentUploader

    def process(self, payload: "dict[str, Any]") -> None:
        """Handle one upload message.

        Every file listed in the payload is validated and then routed by
        type: PDFs are split on separator barcodes and uploaded, ZIP archives
        are extracted and processed, anything else is uploaded as-is. Files
        that fail validation or exceed ``MAX_FILE_SIZE`` are recorded in the
        payload's ``invalids`` list instead of being uploaded.

        Args:
            payload: Payload to process.
        """
        logging.info("Payload: %s", payload)
        self.upload_payload = UploadPayload.from_dict(payload)
        # The duplicate tracker is scoped to a single message.
        self.uploaded_files = set()

        # Separator barcodes come from the reference PDF shipped in assets.
        assets_dir = Path(__file__).parent.parent.parent / "assets"
        separator_pdf = (
            assets_dir / "Separator_barcode" / "OCTOPUS_SEPARATOR_Barcode.pdf"
        )
        barcodes = self._extract_barcodes_from_reference(str(separator_pdf))

        job = self.upload_payload
        while job.file_paths:
            fp = job.file_paths.pop()

            # Decide first whether the file must be rejected, and why.
            if not check_is_file_valid(fp):
                rejection = RejectionReason.INVALID_FILE
            elif fp.stat().st_size > MAX_FILE_SIZE:
                rejection = RejectionReason.FILE_TOO_LARGE
            else:
                rejection = None

            if rejection is not None:
                job.invalids.append(generate_failed_file_payload(fp.name, rejection))
            elif fp.suffix.lower() == ".pdf":
                # Single PDF: split on separator pages, upload each part.
                self._split_and_upload_pdf(fp, barcodes, job.labels)
            elif is_zip_file(str(fp)):
                self._process_zip(fp, barcodes)
            else:
                # Any other type is uploaded unchanged.
                self._upload_to_squirro(fp, fp.stat().st_size, job.labels)

    @staticmethod
    def _preprocess_image(image: Image.Image) -> Image.Image:
        """Prepare a page image for barcode detection.

        The image is converted to grayscale, its contrast is doubled and it
        is scaled to twice its width and height.

        Args:
            image (Image.Image): The input image to process.

        Returns:
            Image.Image: The enhanced image at twice the original dimensions.
        """
        gray = image.convert("L")
        boosted = ImageEnhance.Contrast(gray).enhance(2.0)
        width, height = boosted.size
        return boosted.resize((width * 2, height * 2))

    @staticmethod
    def _extract_barcodes_from_reference(reference_pdf: str) -> set[str]:
        """Extract barcodes from the reference PDF.

        Converts each page of the PDF into an image, decodes the barcodes
        found on it and returns the set of unique barcode values. A page that
        fails to decode is logged and skipped.

        Args:
            reference_pdf (str): Path to the reference PDF.

        Returns:
            Set[str]: A set of extracted barcodes as strings.

        Raises:
            ValueError: If the PDF cannot be converted to images, or if no
                barcode could be extracted from any page.
        """
        # Only the conversion is guarded here, so the "no barcodes" error
        # raised below is not caught, logged a second time and re-wrapped as
        # a "Failed to process PDF" error.
        try:
            pages = convert_from_path(reference_pdf)
        except Exception as e:
            error_message = f"Failed to process PDF: {e}"
            logging.exception(error_message)
            raise ValueError(error_message) from e

        barcodes: set[str] = set()
        for page in pages:
            try:
                decoded_objects = decode(page)
                barcodes.update(obj.data.decode("utf-8") for obj in decoded_objects)
            except Exception as decode_error:
                # Best effort: one unreadable page must not discard the rest.
                logging.warning("Failed to decode barcode on page: %s", decode_error)

        if not barcodes:
            message = "No barcodes could be extracted from the PDF."
            logging.error(message)
            raise ValueError(message)

        logging.info("Extracted barcodes: %s", barcodes)
        return barcodes

    @staticmethod
    def _remove_barcodes(
        image: Image.Image, matching_barcodes: list[Any]
    ) -> Image.Image:
        """Paint a white rectangle over every given barcode region.

        The image is modified in place and also returned.

        Args:
            image (Image.Image): The image to draw on.
            matching_barcodes (List): Barcode objects exposing a ``rect``
                                      with ``left``, ``top``, ``width`` and
                                      ``height`` attributes.

        Returns:
            Image.Image: The same image with the barcode regions whited out.
        """
        painter = ImageDraw.Draw(image)
        for barcode in matching_barcodes:
            box = barcode.rect
            top_left = (box.left, box.top)
            bottom_right = (box.left + box.width, box.top + box.height)
            painter.rectangle([top_left, bottom_right], fill="white")
        return image

    def _split_pdf_by_barcode(
        self, pdf_path: Path, barcodes: set[str], output_dir: Path
    ) -> None:
        """Split a PDF into separate documents at separator-barcode pages.

        Every page is rendered to an image and scanned for barcodes. A page
        carrying one of ``barcodes`` is a separator: it is dropped and closes
        the current document. The outcome written to ``output_dir`` is:

        * no separator found: an unmodified copy of the original file;
        * exactly one resulting document: a PDF named like the original;
        * several documents: ``<stem>_1.pdf``, ``<stem>_2.pdf``, ...

        Intermediate ``temp_page_<i>.pdf`` files are also left in
        ``output_dir``. Errors are logged and swallowed, so the directory may
        end up empty or partially filled.

        Args:
            pdf_path (Path): Path to the input PDF file.
            barcodes (Set[str]): Set of barcode strings to match against.
            output_dir (Path): Directory where split files will be saved.
        """
        try:
            output_dir.mkdir(exist_ok=True)
            pages = convert_from_path(pdf_path)

            # Consecutive non-separator pages, grouped per output document.
            segments: list[list[tuple[int, Image.Image]]] = []
            current_segment: list[tuple[int, Image.Image]] = []
            barcode_detected = False

            for i, page in enumerate(pages):
                decoded_objects = decode(self._preprocess_image(page))
                # errors="replace" keeps an unrelated barcode with a
                # non-UTF-8 payload from aborting the whole split; such a
                # payload simply does not match a separator value.
                is_separator = any(
                    obj.data.decode("utf-8", errors="replace") in barcodes
                    for obj in decoded_objects
                )
                if is_separator:
                    barcode_detected = True
                    if current_segment:
                        segments.append(current_segment)
                        current_segment = []
                    continue
                current_segment.append((i, page))

            if current_segment:
                segments.append(current_segment)

            # Without any separator there is nothing to split, so the
            # original file is kept as-is rather than a re-rendered copy.
            if not barcode_detected:
                final_path = output_dir / pdf_path.name
                shutil.copy(pdf_path, final_path)
                logging.info(
                    "No barcode detected. Original file saved as: %s", final_path
                )
                return

            for count, segment in enumerate(segments, start=1):
                writer = PdfWriter()
                for i, page in segment:
                    # NOTE(review): kept pages are written from their
                    # rendered image, as before. They never contain a
                    # matching barcode (those pages are dropped above), so
                    # no barcode masking is needed here.
                    temp_pdf_path = output_dir / f"temp_page_{i}.pdf"
                    page.save(temp_pdf_path)
                    writer.add_page(PdfReader(temp_pdf_path).pages[0])

                # A single resulting document (separators only at the start
                # and/or end) keeps the original file name.
                if len(segments) == 1:
                    split_path = output_dir / pdf_path.name
                else:
                    split_path = output_dir / f"{pdf_path.stem}_{count}.pdf"
                with split_path.open("wb") as f:
                    writer.write(f)
                logging.info("Split saved: %s", split_path)

        except Exception:
            # Best effort by design: the caller uploads whatever was produced.
            logging.exception("Error splitting PDF %s", pdf_path)

    def _split_and_upload_pdf(
        self, pdf_path: Path, barcodes: set[str], labels: dict[str, Any] | None = None
    ) -> None:
        """Split PDF based on barcodes and upload each split.

        The splits are written to a ``splits`` directory next to the PDF,
        uploaded one by one, and the directory is removed afterwards whether
        or not the upload succeeded.

        Args:
            pdf_path (Path): Path to the input PDF file.
            barcodes (Set[str]): Set of barcode strings to match against.
            labels (dict, None): Labels to associate with the uploaded files;
                an empty dict is used when ``None`` or empty.

        Raises:
            Exception: Any error raised while uploading is logged and
                re-raised.
        """
        # Bound before the ``try`` so the ``finally`` clause can always refer
        # to it, even when a failure happens very early.
        output_dir = pdf_path.parent / "splits"
        try:
            output_dir.mkdir(exist_ok=True)

            self._split_pdf_by_barcode(pdf_path, barcodes, output_dir)

            uploaded_count = 0
            for split_file in output_dir.iterdir():
                if (
                    split_file.suffix.lower() == ".pdf"
                    and not split_file.name.startswith("temp_page_")
                    and split_file not in self.uploaded_files
                ):
                    self._upload_to_squirro(
                        split_file, split_file.stat().st_size, labels or {}
                    )
                    self.uploaded_files.add(split_file)
                    uploaded_count += 1

            if not uploaded_count:
                # Splitting errors are swallowed by _split_pdf_by_barcode, so
                # make it visible when a PDF ends up producing no upload.
                logging.warning(
                    "No document was produced from %s; nothing uploaded", pdf_path
                )

        except Exception:
            logging.exception(
                "Error splitting or uploading PDF %s",
                pdf_path,
            )
            raise
        finally:
            if output_dir.exists():
                try:
                    shutil.rmtree(output_dir)
                    logging.info(
                        "Deleted output directory after processing: %s", output_dir
                    )
                except Exception:
                    logging.exception(
                        "Failed to delete output directory %s", output_dir
                    )

    def _process_zip(self, zip_path: Path, barcodes: set[str] | None = None) -> None:
        """Extract a ZIP archive and upload its content.

        Entries reported as invalid by the extraction, and extracted files
        failing validation, are recorded in the payload's ``invalids`` list.
        Every uploaded file carries a ``zip_reference`` label built from the
        archive name, the current Asia/Singapore time and a hash of both.

        Args:
            zip_path (Path): Path to the ZIP file to process.
            barcodes (Set[str]): Set of barcode strings to match against;
                an empty set is used when ``None``.
        """
        extracted, invalid_zips = extract_zip(zip_path)
        invalids = self.upload_payload.invalids
        for bad_entry in invalid_zips:
            invalids.append(
                generate_failed_file_payload(bad_entry, RejectionReason.INVALID_FILE)
            )

        if not extracted:
            return

        # Reference label shared by every file coming out of this archive.
        singapore = pytz.timezone("Asia/Singapore")
        dt = datetime.now().astimezone(singapore).strftime("%d/%m/%Y-%H:%M:%S")
        reference_hash = hashlib.blake2b(
            f"{zip_path.name}|{dt}".encode(), digest_size=16
        ).hexdigest()
        zip_labels = self.upload_payload.labels | {
            "zip_reference": [f"{zip_path.name}|{dt}|{reference_hash}"]
        }

        for f in extracted:
            file_path = Path(f)

            if not check_is_file_valid(f):
                self.upload_payload.invalids.append(
                    generate_failed_file_payload(f, RejectionReason.INVALID_FILE)
                )
            elif file_path.suffix.lower() == ".pdf":
                self._split_and_upload_pdf(file_path, barcodes or set(), zip_labels)
            else:
                self._upload_to_squirro(file_path, file_path.stat().st_size, zip_labels)

    def preprocess(self) -> None:
        """Reset the pending batch and create the Squirro document uploader.

        The source name and pipeline workflow name are read from the project
        configuration, falling back to ``"User Upload"`` and ``"Upload"``.
        """
        self.current_batch: list[Path] = []

        sq_client, project_id = init_squirro_client()
        project_cfg: dict[str, Any] = sq_client.get_project_configuration(project_id)[
            "config"
        ]

        cfg = load_config()
        token = cfg["squirro"]["token"]
        cluster = cfg["squirro"]["cluster"]
        source_setting = project_cfg.get("app.user-upload-source-name", {})
        source_name = source_setting.get("value", "User Upload")
        workflow_setting = project_cfg.get(
            "app.user-upload-pipeline-workflow-name", {}
        )
        workflow_name = workflow_setting.get("value", "Upload")

        self.uploader = DocumentUploader(
            project_id=project_id,
            token=token,
            cluster=cluster,
            source_name=source_name,
            pipeline_workflow_name=workflow_name,
        )

    def postprocess(self) -> None:
        """Flush pending uploads, report rejected files and reset the payload.

        When the payload recorded invalid files, an ``ingestion_report``
        email payload is added to the email stream for the payload's
        recipients.
        """
        # Whatever is still queued gets flushed now.
        self._batch_upload()

        if self.upload_payload.invalids:
            user_info = {
                "name": self.upload_payload.name,
                "failure": self.upload_payload.invalids,
            }
            report = EmailPayload.create_payload(
                data={"user_info": user_info},
                recipients=self.upload_payload.email,
                type="ingestion_report",
            )
            add_to_stream(
                report.to_dict(),
                EMAIL_STREAM_NAME,
                redis_client=self.redis_client,
            )

        # Start the next message from an empty payload.
        self.upload_payload = UploadPayload()

    def _batch_upload(self) -> None:
        """Flush the uploader and discard the files of the current batch.

        The flush is unconditional. Afterwards every file tracked in
        ``current_batch`` is logged, deleted from disk (a missing file is
        ignored) and the accumulated batch size is reset to zero.
        """
        self.uploader.flush()
        pending = self.current_batch
        while pending:
            sent_file = pending.pop()
            logging.info("Uploaded %s", sent_file)
            sent_file.unlink(missing_ok=True)
        self.batch_size = 0

    def _upload_to_squirro(
        self, fp: "Path", fsize: int, labels: "dict[str, Any]"
    ) -> None:
        """Upload a file to Squirro, retrying up to ``MAX_RETRY`` times.

        The file is hashed, queued on the uploader with the given labels plus
        ``binary_hash`` and ``skip_binary_hash_compute``, and tracked in
        ``current_batch``. Once the accumulated batch size reaches
        ``MAX_BATCH_SIZE`` the batch is flushed.

        Args:
            fp: File path.
            fsize: File size.
            labels: Labels to add to the file. The mapping is not modified.

        Raises:
            FailToUploadError: If every upload attempt failed. The file is
                deleted and recorded in the payload's ``invalids`` first.
            Exception: If the file hash cannot be computed; the file is
                deleted before the error is re-raised.
        """
        try:
            with fp.open("rb") as b:
                binary_hash = compute_hash(b)
        except Exception:
            logging.exception("Error computing hash for %s", fp.name)
            fp.unlink(missing_ok=True)
            raise

        # Work on a per-file copy: callers pass the same labels dict for
        # every file of a payload or archive. Writing the hash into that
        # shared dict would leak it to the caller, and could also alter items
        # still queued in the uploader if it keeps a reference to the dict.
        item_labels = {
            **labels,
            "binary_hash": [binary_hash],
            "skip_binary_hash_compute": ["true"],
        }

        for attempt in range(1, MAX_RETRY + 1):
            try:
                # Create a squirro item for upload, items are buffered internally
                self.uploader.upload(
                    str(fp),
                    mime_type=get_mime_type(str(fp)),
                    title=fp.name,
                    doc_id=f"{binary_hash}{int(time.time())}",
                    keywords=item_labels,
                )
                self.current_batch.append(fp)
                self.batch_size += fsize
                break
            except Exception:
                logging.exception(
                    "Attempt #%d failed. Error uploading document %s",
                    attempt,
                    fp.name,
                )
        else:
            # Every attempt failed: drop the file and report it.
            fp.unlink(missing_ok=True)
            self.upload_payload.invalids.append(
                generate_failed_file_payload(
                    fp.name,
                    RejectionReason.INTERNAL_ERRORS,
                )
            )
            raise FailToUploadError(fp)

        if self.batch_size >= MAX_BATCH_SIZE:
            self._batch_upload()


# Start the upload streamer only when this module is executed as a script.
if __name__ == "__main__":
    UploadStreamer(STREAM_NAME, GROUP_NAME).run()
