# postpone type checks until we upgrade to sqlalchemy v2
# - https://docs.sqlalchemy.org/en/14/orm/extensions/mypy.html
# - https://docs.sqlalchemy.org/en/20/orm/extensions/index.html
import logging
from datetime import datetime
from typing import TYPE_CHECKING, TypedDict

import pytz
from fpdf import FPDF, XPos, YPos
from sqlalchemy import (
    JSON,
    Column,
    DateTime,
    Index,
    Integer,
    String,
    UniqueConstraint,
    event,
)
from sqlalchemy.ext.declarative import declarative_base

from .document_types import DocumentTypes
from .errors import AccessNotAllowed, InvalidRequest, InvalidStatus
from .schemas import DOC_SOURCE
from .util import current_timestamp, utc_to_local

if TYPE_CHECKING:
    from typing import Any, Literal, NotRequired


# Declarative base class shared by the ORM models defined in this module
# (imported from sqlalchemy.ext.declarative, the pre-v2 location).
Base = declarative_base()

log = logging.getLogger(__name__)


class Status(TypedDict):
    """Represents an individual status item of a ``StatusMap``."""

    # key referenced by ``StatusMap["transitions"]`` and ``DocumentStatus.status_code``
    code: str
    # copied into each new status-trail record
    description: str
    # NOTE(review): presumably marks a terminal status; not read in this module
    final: bool
    # copied into each new status-trail record; used as the PDF section title
    header: str
    # NOTE(review): not read in this module; trail records carry their own remarks
    remarks: str


class StatusMap(TypedDict):
    """Status tracking configuration stored in ``StatusTrackingConfig.status_map``."""

    # roles (matched against ``role_ocbc``) allowed to access status tracking
    access_roles: "list[str]"
    # every status known to the process; their codes must cover all transitions
    statuses: "list[Status]"
    # from_code -> to_code -> role -> document types ("BBCA" / "non-BBCA")
    # for which that role may perform the transition
    transitions: "dict[str,dict[str,dict[str,list[str]]]]"


class UserInformation(TypedDict):
    """OCBC extended user information.

    Users can have multiple roles assigned and pick one when signing in.
    Every field holds a list of strings.
    """

    role: "list[str]"
    email: "list[str]"
    givenName: "list[str]"  # noqa: N815
    surname: "list[str]"
    uid: "list[str]"
    # the first entry identifies the user in the dual-role (001 -> 001Z) check
    lan_id: "list[str]"
    org_unit_id: "list[str]"
    rmcode: "list[str]"
    # roles used for access checks and for evaluating allowed transitions
    role_ocbc: "list[str]"


class User(TypedDict):
    """Represents the logged-in user (session)."""

    id: str
    email: str
    full_name: str
    # extended OCBC attributes (roles, lan_id, ...), each stored as a list
    user_information: UserInformation


class Updater(TypedDict):
    """User performing a status update.

    Stored as the ``user`` field of every ``StatusUpdate`` record.
    """

    id: str
    # compared with ``UserInformation["lan_id"][0]`` by the dual-role rule
    lan_id: str
    email: str
    # the first entry is printed next to the name in the PDF trail
    role: "list[str]"
    # printed as the reporter in the PDF trail
    name: str


class DocumentUploader(TypedDict):
    """Uploader details; comes from the pipelet.

    Unlike ``Updater``, every field except ``name`` is a list of strings.
    """

    email: "list[str]"
    name: str
    uid: "list[str]"
    lan_id: "list[str]"
    role_ocbc: "list[str]"


class StatusUpdate(TypedDict):
    """Represents one record of ``DocumentStatus.status_trail``."""

    # optional; ``NotRequired`` is only imported under TYPE_CHECKING, so this
    # annotation must stay a string
    remarks: "NotRequired[str]"
    code: str
    description: str
    header: str
    user: Updater
    # produced by current_timestamp(); converted with utc_to_local for the PDF
    timestamp: str


# Base is created dynamically by declarative_base(), so mypy can't analyze these models statically:
class StatusTrackingConfig(Base):
    """Define the document status tracking process (a state machine).

    Whenever the config changes, a new record must be created with a new
    version. This allows us to later reconcile possibly invalid statuses.

    ``status_map["transitions"]`` is nested as
    ``{from_code: {to_code: {role: [doc_type, ...]}}}`` where ``doc_type`` is
    ``"BBCA"`` or ``"non-BBCA"``.

    The ``codes`` lookup is only built by ``_init``, which runs when an
    instance is loaded from the database.
    """

    __tablename__ = "status_tracking_config"

    id: int = Column("id", Integer, primary_key=True)
    project_id: str = Column("project_id", String(50), nullable=False)
    # on updates always create a new record with a new version
    version: int = Column("version", Integer, nullable=False)
    # json column to store the automata definition
    status_map: "StatusMap" = Column("status_map", JSON)
    updated: datetime = Column(
        "updated", DateTime, default=datetime.utcnow, nullable=False
    )

    __table_args__ = (
        # ensure a single version per project and create an index for efficient queries:
        UniqueConstraint(
            "project_id", "version", name="unq_version_per_project"
        ),  # comma is mandatory!
    )

    def _init(self) -> None:
        """Build the ``codes`` lookup and validate ``transitions``.

        Raises:
            InvalidStatus: if a transition's source or target code is not
                listed in ``status_map["statuses"]``.
        """
        # a dict to support quick lookups by code (later duplicates win)
        self.codes: dict[str, Status] = {  # pylint: disable=W0201
            s["code"]: s for s in self.status_map["statuses"]
        }
        for p, targets in self.status_map["transitions"].items():
            if p not in self.codes:
                raise InvalidStatus(f"invalid status {p}")
            for q in targets:
                if q not in self.codes:
                    raise InvalidStatus(f"invalid status {q}")

    @staticmethod
    def after_load(
        target: "StatusTrackingConfig",
        context: "Any",  # pylint: disable=W0613
    ) -> None:
        """ORM ``load`` event hook: initialise the freshly loaded ``target``."""
        target._init()  # pylint: disable=W0212

    def __repr__(self) -> str:
        return (
            f"<StatusTrackingConfig(project={self.project_id}, version={self.version})>"
        )

    def to_dict(self) -> "dict[str, Any]":
        """Return a JSON-friendly dict (``updated`` as ISO string or None)."""
        return {
            "id": self.id,
            "project_id": self.project_id,
            "version": self.version,
            "status_map": self.status_map,
            "updated": self.updated.isoformat() if self.updated else None,
        }

    def raise_not_allowed_to_access_status_tracking(
        self, user_info: "UserInformation"
    ) -> None:
        """Raise ``AccessNotAllowed`` unless a ``role_ocbc`` is in ``access_roles``."""
        if set(user_info["role_ocbc"]).isdisjoint(self.status_map["access_roles"]):
            raise AccessNotAllowed(f"{user_info['role_ocbc']} not allowed access.")

    def list_available_followup_statuses(
        self, user_info: "UserInformation", doc: "DocumentStatus"
    ) -> "list[Status]":
        """Return the statuses the user may move ``doc`` to next.

        A transition ``doc.status_code -> next`` is allowed when any of the
        user's ``role_ocbc`` roles lists the document's type ("BBCA" or
        "non-BBCA") for it. The result follows the order in which transitions
        are defined in ``status_map``, so callers get a stable list.

        A user holding more than one role who performed the current "001"
        update cannot move the document to "001Z" as well.
        """
        doc_type = "BBCA" if doc.is_bbca() else "non-BBCA"
        roles = user_info["role_ocbc"]
        log.debug("follow up for %s | %s | %s", doc.status_code, roles, doc_type)
        t = self.status_map["transitions"]
        if doc.status_code not in t:
            return []
        # a list (not a set) keeps the configured order deterministic
        allowed_next = [
            next_status
            for next_status, role_doc_types in t[doc.status_code].items()
            if any(doc_type in role_doc_types.get(role, []) for role in roles)
        ]
        log.debug("transitions: %s", allowed_next)

        # Dual role cannot update from 001 to 001Z; the previous updater is
        # only looked up when the rule can apply.
        if (
            doc.status_code == "001"
            and len(roles) > 1
            and doc.current_status()["user"]["lan_id"] == user_info["lan_id"][0]
        ):
            log.debug("preventing user holding two roles from updating 001 to 001Z")
            allowed_next = [code for code in allowed_next if code != "001Z"]

        log.debug(
            "follow up for %s | %s | %s: %s",
            doc.status_code,
            roles,
            doc_type,
            allowed_next,
        )
        # _init guarantees every transition target is a known code
        return [self.codes[code] for code in allowed_next]

    def is_status_update_allowed(
        self, doc: "DocumentStatus", new_status: str, user_info: "UserInformation"
    ) -> bool:
        """Return True if ``new_status`` is an allowed follow-up for ``doc``."""
        doc_type = "BBCA" if doc.is_bbca() else "non-BBCA"
        roles = user_info["role_ocbc"]
        log.info(
            "%s->%s allowed for %s | %s ?",
            doc.status_code,
            new_status,
            roles,
            doc_type,
        )
        allowed_statuses = self.list_available_followup_statuses(user_info, doc)
        return any(s["code"] == new_status for s in allowed_statuses)


# Run StatusTrackingConfig._init (build the ``codes`` lookup and validate the
# transitions) whenever an instance is loaded from the database.
# propagate=True applies the listener to subclasses too. NOTE: instances
# constructed in Python instead of loaded never get ``codes``.

event.listen(
    StatusTrackingConfig, "load", StatusTrackingConfig.after_load, propagate=True
)


class DocumentStatus(Base):
    """Keep track of individual document status changes.

    Every change is appended to ``status_trail`` (a JSON list of
    ``StatusUpdate`` records). ``status_code`` holds the status used for the
    process flow; audit records (``AT00x``) are appended to the trail without
    changing it.
    """

    __tablename__ = "document_status"

    document_name: str = Column("document_name", String(4096))
    document_id: str = Column("document_id", String(50), primary_key=True)
    company_name: "str | None" = Column("company_name", String(4096))
    document_type: "str | None" = Column("document_type", String(250))
    # store as-is, no need to parse, enforce specific formats...
    document_date: "str | None" = Column("document_date", String(250))
    project_id: str = Column("project_id", String(50), nullable=False)
    status_trail: "list[StatusUpdate]" = Column("status_trail", JSON)
    # status_code used for process flow (excludes AT00x for example)
    status_code: str = Column("status_code", String(50))
    # Track the last non-ATxx status
    status_trail_report_idx: int = Column(
        "status_trail_report_idx", Integer, default=1, nullable=False
    )
    updated: datetime = Column(
        "updated", DateTime, default=datetime.utcnow, nullable=False
    )
    wfi_document_id = Column("wfi_document_id", String(50))

    __table_args__ = (
        # reporting: allows querying all documents changed in a given time
        # period, optionally filtered by status
        Index("idx_upd_code", "updated", "status_code"),  # comma is mandatory!
    )
    document_types = DocumentTypes.load()

    # Initial status keyed by presence of (company_name, document_type,
    # document_date) for newly created, not yet processed documents.
    # fmt: off
    _INITIAL_STATUS = {
        (True, True, True): "001",     # none missing, document is classified
        (True, True, False): "906",    # date missing
        (True, False, True): "905",    # doc_type missing
        (True, False, False): "903",   # doc_type and date missing
        (False, True, True): "904",    # name missing
        (False, True, False): "902",   # name and date missing
        (False, False, True): "901",   # name and doc_type missing
        (False, False, False): "907",  # all 3 missing
    }
    # fmt: on

    def __repr__(self) -> str:
        return f"<DocumentStatus(id={self.document_id},status={self.status_code})>"

    def is_bbca(self) -> bool:
        """Return True if the (upper-cased) document type is a BBCA type."""
        if not self.document_type:
            return False
        return DocumentStatus.document_types.is_bbca_type(self.document_type.upper())

    def to_dict(self) -> "dict[str, Any]":
        """Return a JSON-friendly dict; ``next_statuses`` defaults to ``[]``."""
        return {
            "document_id": self.document_id,
            "document_name": self.document_name,
            "company_name": self.company_name,
            "document_type": self.document_type,
            "document_date": self.document_date,
            "project_id": self.project_id,
            "status_trail": self.status_trail,
            "status_code": self.status_code,
            "status_trail_report_idx": self.status_trail_report_idx,
            "updated": self.updated.isoformat() if self.updated else None,
            # attached by callers at runtime; not a mapped column
            "next_statuses": getattr(self, "next_statuses", []),
        }

    def is_classified(self) -> bool:
        """Return True when company name, document date and type are all non-empty strings."""
        return all(
            isinstance(var, str) and var
            for var in [self.company_name, self.document_date, self.document_type]
        )

    def current_status(self) -> "StatusUpdate":
        """Return the latest trail record whose code equals ``status_code``.

        Raises:
            StopIteration: if the trail has no record with that code.
        """
        return next(
            filter(lambda s: s["code"] == self.status_code, reversed(self.status_trail))
        )

    def update(
        self,
        config: StatusTrackingConfig,
        updater: "Updater",
        attributes: "dict[str,str]",
    ) -> None:
        """Apply classification edits and record them in the status trail.

        For each of ``document_date``, ``document_type`` and ``company_name``
        in ``attributes`` that is non-empty and differs from the current value,
        an audit record (AT003, AT002, AT001 respectively) is appended and the
        attribute is updated. A ``9xx`` document that becomes classified is
        moved to status "001".

        Raises:
            InvalidRequest: if none of the three attributes is given, or if
                all given values equal the current ones.
        """
        company_name = attributes.get("company_name")
        document_type = attributes.get("document_type")
        document_date = attributes.get("document_date")
        if not company_name and not document_type and not document_date:
            raise InvalidRequest(
                "Nothing to update, must have one of [company,type,date]"
            )
        timenow = current_timestamp()
        audit_update: list[StatusUpdate] = []

        if document_date and self.document_date != document_date:
            sg_tz = pytz.timezone("Asia/Singapore")
            # format old and new dates for display; empty values pass through
            from_date, to_date = (
                utc_to_local(date, sg_tz, display_format="%d/%m/%Y") if date else date
                for date in [self.document_date, document_date]
            )

            audit_update.append(
                {
                    "code": "AT003",
                    "header": config.codes["AT003"]["header"],
                    "description": f"From {from_date} to {to_date}",
                    "user": updater,
                    "timestamp": timenow,
                }
            )
            self.document_date = document_date

        if document_type and self.document_type != document_type:
            audit_update.append(
                {
                    "code": "AT002",
                    "header": config.codes["AT002"]["header"],
                    "description": f"From {self.document_type} to {document_type}",
                    "user": updater,
                    "timestamp": timenow,
                }
            )
            self.document_type = document_type

        if company_name and self.company_name != company_name:
            audit_update.append(
                {
                    "code": "AT001",
                    "header": config.codes["AT001"]["header"],
                    "description": f"From {self.company_name} to {company_name}",
                    "user": updater,
                    "timestamp": timenow,
                }
            )
            self.company_name = company_name

        if not audit_update:
            raise InvalidRequest(
                "Nothing to update, provided values are identical to the current ones"
            )

        # reassign (never mutate in place) so the JSON column change is detected
        self.status_trail = self.status_trail + audit_update
        # doc might have become classified after this update and must change status:
        if self.status_code.startswith("9") and self.is_classified():
            updated_status = "001"
            self.status_trail = self.status_trail + [
                {
                    "code": updated_status,
                    "header": config.codes[updated_status]["header"],
                    "description": config.codes[updated_status]["description"],
                    "timestamp": timenow,
                    "user": updater,
                }
            ]
            self.status_code = updated_status

        self.updated = datetime.utcnow()

        # Update status trail report index
        if not self.status_trail[-1]["code"].startswith("AT"):
            self.status_trail_report_idx = len(self.status_trail) - 1

    @staticmethod
    def create(  # pylint: disable-msg=too-many-arguments,too-many-positional-arguments,too-many-locals
        config: StatusTrackingConfig,
        uploader: "Updater",
        source_type: str,
        document_id: str,
        document_name: str,
        project_id: str,
        company_name: "str | None" = None,
        document_type: "str | None" = None,
        document_date: "str | None" = None,
        wfi_document_id: "str | None" = None,
    ) -> "DocumentStatus":
        """Build a new (not yet persisted) document status record.

        The trail starts with the ``DOC_SOURCE`` code for ``source_type``.
        Except for "WFI Migration" and "End point Upload" (already processed),
        it is followed by an initial classification record: "001" when name,
        type and date are all non-empty, otherwise a ``9xx`` code describing
        which of them are missing. Headers and descriptions are taken from
        ``config.codes``. The passed-in ``uploader`` is never modified.
        """
        time_now = current_timestamp()

        # If a dual role uploads, it'll always be as Workbench:CSSupportMaker.
        # Work on a copy so the caller's user dict is left untouched.
        roles = uploader["role"]
        if roles and roles[0] == "Workbench:CSSupportMaker&Checker":
            uploader = {**uploader, "role": ["Workbench:CSSupportMaker", *roles[1:]]}

        codes = [DOC_SOURCE[source_type]]
        if source_type not in ("WFI Migration", "End point Upload"):
            # same truthiness rule as is_classified()/update(): "" is missing
            presence = (bool(company_name), bool(document_type), bool(document_date))
            codes.append(DocumentStatus._INITIAL_STATUS[presence])

        status_trail: list[StatusUpdate] = [
            {
                "code": code,
                "user": uploader,
                "header": config.codes[code]["header"],
                "description": config.codes[code]["description"],
                "timestamp": time_now,
            }
            for code in codes
        ]

        return DocumentStatus(
            document_id=document_id,
            document_name=document_name,
            project_id=project_id,
            status_trail=status_trail,
            company_name=company_name,
            document_type=document_type,
            document_date=document_date,
            wfi_document_id=wfi_document_id,
            status_code=status_trail[-1]["code"],
            # index of the last record; set explicitly because the column
            # default (1) is out of range for single-record trails
            status_trail_report_idx=len(status_trail) - 1,
        )


class StatusTrailPDF(FPDF):
    """Generates a PDF report of a DocumentStatus trail."""

    def __init__(  # pylint: disable=too-many-arguments,too-many-positional-arguments
        self,
        doc: DocumentStatus,
        filename: str,
        company_cif: str,
        doc_format: str,
        sort_order: 'Literal["ASC", "DESC"]' = "DESC",
    ) -> None:
        """Render the whole report immediately.

        Args:
            doc: document whose ``status_trail`` is rendered.
            filename: shown as the title and in every page header.
            company_cif: shown in the intro section.
            doc_format: shown in the intro section.
            sort_order: "ASC" renders the trail in stored order; any other
                value renders it reversed.
        """
        super().__init__()
        self.doc = doc
        self.filename = filename
        self.cif = company_cif
        self.doc_format = doc_format
        self.sort_order = sort_order
        # default alias is {nb} which then must be escaped in f-strings, so
        # change it; setting it once covers the whole document
        self.alias_nb_pages("<%tp%>")
        # a single generation time shared by every page footer
        self._generated_at = current_timestamp()
        self._render_pdf()

    def save_to_file(self, filepath: str = "document-status.pdf") -> None:
        """Write the rendered PDF to ``filepath``."""
        self.output(filepath)

    def get_binary(self) -> bytearray:
        """Return the rendered PDF as bytes."""
        return self.output()

    def _render_pdf(self) -> None:
        """Lay out the intro section followed by one entry per trail record."""
        self.add_page()
        self.set_left_margin(20)
        self.set_right_margin(20)
        self.set_auto_page_break(auto=True, margin=25)

        self._set_intro()

        trail_history = (
            self.doc.status_trail
            if self.sort_order == "ASC"
            else reversed(self.doc.status_trail)
        )
        sg_tz = pytz.timezone("Asia/Singapore")
        for i, status in enumerate(trail_history, start=1):
            self._status_title(
                i,
                status["header"],
                utc_to_local(
                    status["timestamp"],
                    sg_tz,
                    display_format="%d/%m/%Y %H:%M:%S",
                ),
            )
            self._status_details(
                status["description"],
                f"{status['user']['name']} ({status['user']['role'][0]})",
                status.get("remarks", "-"),
            )

    def _intro_row(self, label: str, value: str) -> None:
        """Write one grey label / black value row of the intro section."""
        self.set_text_color(64, 64, 64)
        self.cell(40, 5, label, border=0)
        self.set_text_color(0, 0, 0)
        self.cell(40, 5, value, border=0)
        self.ln()

    def _set_intro(
        self,
    ) -> None:
        """Write the title, document attributes and the trail section heading."""
        document_date = (
            utc_to_local(
                self.doc.document_date,
                pytz.timezone("Asia/Singapore"),
                display_format="%d/%m/%Y",
            )
            if self.doc.document_date
            else "-"
        )

        self.set_text_color(128, 128, 128)
        self.set_font("Helvetica", "", 8)
        self.cell(
            0, 8, "Document History", new_x=XPos.LMARGIN, new_y=YPos.NEXT, align="L"
        )

        self.set_text_color(0, 0, 0)
        self.set_font("Helvetica", "B", 16)
        self.multi_cell(
            0, 6, self.filename, new_x=XPos.LMARGIN, new_y=YPos.NEXT, align="L"
        )
        self.ln()

        self.set_font("Helvetica", "", 10)
        self._intro_row("Document Type:", self.doc.document_type or "-")
        self._intro_row("Document Date:", document_date)
        self._intro_row("Company CIF:", self.cif)
        self._intro_row("Company Name:", self.doc.company_name or "-")

        # the format value uses multi_cell so long values wrap in their column
        self.set_text_color(64, 64, 64)
        self.cell(40, 5, "Document Format:", border=0)
        self.set_text_color(0, 0, 0)
        self.multi_cell(40, 5, self.doc_format, border=0)
        self.ln()

        self.set_draw_color(128, 128, 128)
        self.set_line_width(0.05)  # width is in millimeters
        self.line(20, self.get_y() + 3, 190, self.get_y() + 3)
        order = "Descending" if self.sort_order == "DESC" else "Ascending"
        self.cell(140, 15, f"Document Status (sort by Date - {order})", border=0)
        self.ln(12)

    def header(self) -> None:
        """FPDF page hook: filename caption and a rule at the top of each page."""
        self.set_text_color(64, 64, 64)
        self.set_font("Helvetica", "", 8)
        self.multi_cell(
            0,
            4,
            f"{self.filename} status trail history",
            align="L",
        )
        self.ln()
        self.set_draw_color(0, 0, 0)
        self.set_line_width(0.1)
        self.line(10, self.get_y(), 200, self.get_y())
        self.ln()

    def footer(self) -> None:
        """FPDF page hook: rule, page count, generation time and labels."""
        self.set_text_color(64, 64, 64)
        # Position at 2.5 cm from bottom
        self.set_y(-25)
        self.set_draw_color(0, 0, 0)
        self.set_line_width(0.1)
        self.line(10, self.get_y(), 200, self.get_y())
        self.ln(2)
        self.set_font("Helvetica", "", 8)
        # "<%tp%>" is replaced by the total page count (alias set in __init__)
        self.cell(
            0,
            4,
            f"Page {self.page_no()!s} of <%tp%>",
            align="L",
        )
        self.ln()
        self.cell(
            0,
            4,
            f"PDF generated: {self._generated_at!s}",
            align="L",
        )
        self.ln()
        self.cell(
            0,
            4,
            "For Internal Use Only",
            align="L",
        )
        self.ln()
        self.set_font("Helvetica", "", 12)
        self.cell(
            0,
            6,
            "Confidential",
            align="C",
        )

    def _status_title(self, num: int, header: str, timestamp: str) -> None:
        """Write the numbered status header with its timestamp on the right."""
        self.set_text_color(30, 144, 255)
        self.set_font("Helvetica", "", 11)
        self.cell(0, 10, f"{num!s}. {header}", align="L")
        self.set_font("Helvetica", "I", 10)
        self.cell(0, 10, timestamp, align="R")
        self.ln(8)

    def _status_details(
        self, description: str, updater: str, remarks: "str | None" = "-"
    ) -> None:
        """Write description, reporter and (unless ``remarks`` is None) remarks."""
        self.set_text_color(0, 0, 0)
        self.set_font("Helvetica", "", 10)
        self.multi_cell(0, 5, description)
        self.ln(3)

        self.set_text_color(64, 64, 64)
        self.set_font("Helvetica", "", 10)
        self.cell(0, 6, f"Reporter: {updater}", align="L")
        if remarks is not None:
            self.ln(5)
            self.multi_cell(0, 6, f"Remarks: {remarks}", align="L")

        self.ln(8)
