"""Sub Items Enrichment."""

from typing import TYPE_CHECKING

from pypdf import PdfReader

from octopus.classification import doc_type_pred
from octopus.clients import init_redis_client
from octopus.data.company_data import (
    CompanyDataIndex,
    company_data_augmentation,
    define_display_labels,
    tag_company_data,
    tag_document_type_data,
)
from octopus.text import date_extraction
from octopus.utils import open_pdf
from squirro.sdk import PipeletV1, require

if TYPE_CHECKING:
    from logging import Logger
    from typing import Any

    from redis import Redis


@require("log")
class SubItemsEnrichment(PipeletV1):  # type: ignore[misc]
    """Pipelet to enrich a document based on its sub items.

    For every item it computes the page count, predicts the document type,
    extracts a document date, extracts companies from the text, defines
    display labels, augments the item with company data and, when
    applicable, tags WFI labels.
    """

    # Default for the ``max_pages_to_process`` argument; shared by
    # ``__init__`` and ``getArguments`` so the two cannot drift apart.
    DEFAULT_MAX_PAGES_TO_PROCESS = 3

    log: "Logger"
    max_pages_to_process: int
    company_index: "CompanyDataIndex"
    redis_client: "Redis[bytes]"

    def __init__(self, config: "dict[str, Any]") -> None:
        """Initialize the pipelet.

        Args:
            config: Pipelet configuration. ``max_pages_to_process`` is read
                from it (coerced to ``int``); when absent or ``None`` the
                class default is used.
        """
        max_pages = config.get("max_pages_to_process")
        self.max_pages_to_process = (
            self.DEFAULT_MAX_PAGES_TO_PROCESS if max_pages is None else int(max_pages)
        )
        self.company_index = CompanyDataIndex.load_index()
        self.redis_client = init_redis_client()

    def consume(self, item: "dict[str, Any]") -> "dict[str, Any]":
        """Consume an item.

        The item is enriched in place. Processing stops early (after display
        labels are defined) when the item has no ``company_name`` keywords,
        or when not every company name is found in the company index.

        Args:
            item: The item to consume

        Returns:
            The consumed (mutated) item
        """
        self.log.info("Running sub items enrichment on `%s`", item["id"])

        # Several steps below write into item["keywords"]; make sure it exists.
        if item.get("keywords") is None:
            item["keywords"] = {}

        # Get pages from item, falling back to the raw body when the sub
        # items yielded no text. `or ""` also covers an explicit `body: None`.
        pages, num_pages = self._extract_compute_num_pages(item)
        if not pages:
            pages = item.get("body") or ""

        self.log.info("Number of pages: %d", num_pages)
        item["keywords"]["num_pages"] = [num_pages]

        # Document type prediction
        self.log.info("Predicting document type")
        doc_type_pred(item, pages)

        # Date extraction from title + page text
        title = item.get("title") or ""
        item["keywords"]["document_date_pred"] = [date_extraction(f"{title}\n{pages}")]

        # Extract companies from the (truncated) page text
        self._company_extraction(item, pages)

        # Define display labels
        skip_company_uid = self._define_display_labels(item)

        names: list[str] = item.get("keywords", {}).get("company_name", [])
        n_companies = len(names)
        if not n_companies:
            return item

        company_data_by_names = self.company_index.search_by_names(names)
        if len(company_data_by_names) != n_companies:
            self.log.error(
                "Not all companies found in the company index, "
                "found %d companies, expected %d",
                len(company_data_by_names),
                n_companies,
            )
            return item

        # Augment item with company data
        self.log.info("Augmenting item with company data")
        company_data_augmentation(
            item,
            company_data_by_names,
            skip_company_uid=skip_company_uid,
            redis_client=self.redis_client,
        )

        # Tag WFI labels. company_data_by_names is non-empty here because its
        # length equals n_companies, which is > 0.
        if self._is_initial_checkin(item):
            self._tag_wfi_labels(item, company_data_by_names[0])

        return item

    @staticmethod
    def _is_initial_checkin(item: "dict[str, Any]") -> bool:
        """Return whether WFI labels should be tagged for ``item``.

        Items whose first ``source_type`` keyword starts with ``WFI`` are only
        tagged when ``wfi:initial_checkin`` is truthy; all other items are
        always tagged.
        """
        # `or [""]` guards against both a missing key and an empty list.
        source_types = item.get("keywords", {}).get("source_type") or [""]
        if str(source_types[0]).startswith("WFI"):
            return bool(item.get("wfi:initial_checkin", False))
        return True

    def _tag_wfi_labels(self, item: "dict[str, Any]", company_data: "Any") -> None:
        """Tag company / document type labels and copy the document date.

        Args:
            item: The item to tag (mutated in place)
            company_data: Company data entry of the first company name
        """
        tag_company_data(item, company_data)
        tag_document_type_data(item, self.redis_client)
        if document_date := item["keywords"].get("document_date", []):
            item["keywords"]["wfi_document_date"] = document_date

    def _company_extraction(self, item: "dict[str, Any]", pages: str) -> None:
        """Extract company names from ``pages`` into ``company_name_pred``.

        Args:
            item: The item to write the predicted company names to
            pages: The text to extract companies from
        """
        self.log.info("Extracting companies")
        if not pages:
            self.log.warning("No pages to extract companies from")
            return

        companies = self.company_index.extract_companies(pages)
        if not companies:
            self.log.info("No companies found")
            return

        item.setdefault("keywords", {})["company_name_pred"] = companies
        self.log.info("Found company names: %s", companies)

    def _define_display_labels(self, item: "dict[str, Any]") -> bool:
        """Define display labels, using user-provided company uids if any.

        Args:
            item: The item to define display labels for

        Returns:
            True when the item already carries ``company_uid`` keywords (the
            value is passed on as ``skip_company_uid`` to the company data
            augmentation), False otherwise.
        """
        self.log.info("Defining display labels")
        uids = item.get("keywords", {}).get("company_uid", [])

        company_data_by_uids = []
        if uids:
            self.log.info("Company uid provided by user for item: %s", uids)
            company_data_by_uids = self.company_index.search_by_uids(uids)

        define_display_labels(item, company_data_by_uids)
        return bool(uids)

    def _extract_compute_num_pages(self, item: "dict[str, Any]") -> "tuple[str, int]":
        """Extract sub pages and compute number of pages.

        When the item has sub items (present for OCR-ed documents), the page
        count is the number of sub items and the text is the newline-joined
        bodies of the first ``max_pages_to_process`` of them (at least one).
        Missing or empty bodies are skipped so that an all-empty result is
        ``""``. Otherwise the page count is read from the item's PDF file and
        no text is returned.

        Args:
            item: The item to inspect

        Returns:
            The pages text and the number of pages.
        """
        sub_items = item.get("sub_items") or []
        if not sub_items:
            return "", self._count_pdf_pages(item)

        limit = max(1, self.max_pages_to_process)
        pages = "\n".join(
            body for page in sub_items[:limit] if (body := page.get("body"))
        )
        return pages, len(sub_items)

    def _count_pdf_pages(self, item: "dict[str, Any]") -> int:
        """Return the page count of the item's PDF file, best effort.

        Returns 0 (after logging) when no PDF is found or it cannot be read.
        """
        try:
            pdf = open_pdf(item)
            if not pdf:
                self.log.warning("No PDF file found")
                return 0
            # NOTE(review): the handle returned by open_pdf is not closed here;
            # confirm whether open_pdf owns its lifetime.
            return int(PdfReader(pdf).get_num_pages())
        except Exception:
            self.log.exception("Failed to read PDF file")
            return 0

    @staticmethod
    # pylint: disable-next=invalid-name
    def getArguments() -> "list[dict[str, Any]]":  # noqa: N802
        """Return the arguments for the pipelet."""
        return [
            {
                "name": "max_pages_to_process",
                "display_label": "Maximum number of pages to process",
                "type": "int",
                "default": SubItemsEnrichment.DEFAULT_MAX_PAGES_TO_PROCESS,
            }
        ]
