"""Generate reports."""

import logging
from datetime import datetime
from io import BytesIO
from typing import TYPE_CHECKING

import pandas as pd
import pytz
import requests

from octopus.clients import init_squirro_client
from octopus.utils import build_query_string, load_config, set_log_verbosity

from .columns import REPORT_COLUMNS

if TYPE_CHECKING:
    from typing import Any

# Set the log verbosity to INFO as a side effect of importing this module.
set_log_verbosity(logging.INFO)


# Timezone in which timestamps are rendered in the reports.
LOCAL_TIMEZONE = pytz.timezone("Asia/Singapore")
# Midnight of the current day in LOCAL_TIMEZONE (timezone-aware).
# NOTE(review): evaluated once at import time, so a long-running process keeps
# using the day on which the module was first imported.
TODAY = datetime.now(LOCAL_TIMEZONE).replace(hour=0, minute=0, second=0, microsecond=0)
# Maps the single-letter segment codes to the labels written to the reports.
SEGMENT_MAPPING = {"R": "RE", "W": "WCM"}


def generate_batch_report(
    statuses: "list[str]", query_options: "dict[str, str | None]", fp: str
) -> None:
    """Generate a batch report.

    Reports are generated for R and W segments and saved as excel files on disk.

    Args:
        statuses: Query documents with the given statuses. When empty, only the
            ``-is_deleted:true`` filter is applied to the ES query.
        query_options: Arguments to the squirro client query method. The mapping
            is left untouched; any ``query`` entry it contains is overridden
            for the ES call. Its ``created_after`` / ``created_before`` entries,
            if any, are also forwarded to the DB lookup.
        fp: Output filepath template of the report. It is expanded with
            ``fp.format(segment=...)`` once for each of the segments ``R``
            and ``W``.
    """
    sq_client, project_id = init_squirro_client()

    # Build the ES options on a copy so that the caller's dict is not mutated.
    es_query_options = {
        **query_options,
        "query": (
            build_query_string({"current_doc_status": statuses}, "-is_deleted:true")
            if statuses
            else "-is_deleted:true"
        ),
    }

    # NOTE(review): at most 10000 items are requested; presumably anything
    # beyond that is silently left out of the report - confirm.
    es_res = sq_client.query(
        project_id,
        count=10000,
        **es_query_options,
    )
    es_data = _process_es_results(es_res["items"])
    db_res = _fetch_db_records(
        statuses=statuses,
        created_after=query_options.get("created_after"),
        created_before=query_options.get("created_before"),
    )
    db_data = _process_db_results(db_res)

    es_df = pd.DataFrame(
        es_data,
        columns=[field.name for field in REPORT_COLUMNS if field.source == "ES"],
    )
    db_df = pd.DataFrame(
        db_data,
        columns=[
            "Item ID",  # must be available in both dfs to join later
            *[field.name for field in REPORT_COLUMNS if field.source == "DB"],
        ],
    )
    df = _merge_and_sanitize_dfs(es_df, db_df)
    # Oldest documents first.
    df = df.sort_values(by="Ageing Days", ascending=False)

    for segment in ["R", "W"]:
        # Copy the slice so the per-segment edits do not touch the shared frame.
        df_seg = df[df["Segment"] == segment].copy()
        df_seg["Segment"] = df_seg["Segment"].map(SEGMENT_MAPPING)
        # Running serial number, restarting at 1 for every segment.
        df_seg.insert(0, "S/N", range(1, len(df_seg) + 1))
        export_to_excel(df_seg, fp.format(segment=segment))


def generate_online_report(items: "list[dict[str, Any]]") -> "BytesIO":
    """Build the online report for the given items as an in-memory excel file.

    Args:
        items: Items returned by a squirro query. Each item's ``id`` is used to
            look up the matching DB records.

    Returns:
        BytesIO object holding the generated excel file, rewound to position 0.
    """
    es_rows = _process_es_results(items)
    item_ids = [item["id"] for item in items]
    db_rows = _process_db_results(_fetch_db_records_by_ids(item_ids))

    es_columns = [field.name for field in REPORT_COLUMNS if field.source == "ES"]
    # "Item ID" must be available in both frames to join them later.
    db_columns = [
        "Item ID",
        *[field.name for field in REPORT_COLUMNS if field.source == "DB"],
    ]

    report = _merge_and_sanitize_dfs(
        pd.DataFrame(es_rows, columns=es_columns),
        pd.DataFrame(db_rows, columns=db_columns),
    )

    # The online report does not show the ageing column.
    report = report.drop(columns="Ageing Days")
    report["Segment"] = report["Segment"].map(SEGMENT_MAPPING)
    report.insert(0, "S/N", range(1, len(report) + 1))

    buffer = BytesIO()
    export_to_excel(report, buffer)
    # Rewind so that the caller can read the file from the start.
    buffer.seek(0)

    return buffer


def _fetch_db_records(
    statuses: "list[str]",
    created_before: "str | None" = None,
    created_after: "str | None" = None,
) -> "list[dict[str, Any]]":
    """Request the DB records that match the given filters.

    Args:
        statuses: The list of status codes to filter by.
        created_before: Upper limit of the records' creation date.
        created_after: Lower limit of the records' creation date.

    Returns:
        The decoded JSON body of the response.

    Raises:
        HTTPError: If the service answers with an error status. The error is
            logged before it is re-raised.
    """
    squirro_cfg = load_config()["squirro"]

    query_params: dict[str, Any] = {
        "statuses": statuses,
        "token": squirro_cfg["token"],
    }
    date_bounds = (
        ("created_before", created_before),
        ("created_after", created_after),
    )
    for key, bound in date_bounds:
        if bound:
            # The "T" separator of the timestamp is replaced by a space.
            query_params[key] = bound.replace("T", " ")

    endpoint = (
        f"http://localhost/studio/document_status_tracking/projects/"
        f"{squirro_cfg['project_id']}/documents"
    )
    try:
        response = requests.get(endpoint, params=query_params, timeout=30)
        response.raise_for_status()
    except requests.HTTPError:
        logging.exception("Something went wrong.")
        raise

    records: list[dict[str, Any]] = response.json()
    return records


def _fetch_db_records_by_ids(ids: "list[str]") -> "list[dict[str, Any]]":
    """Request the DB records of the documents with the given ids.

    The ids are sent as the JSON body of a POST request.

    Args:
        ids: IDs of the documents to filter by.

    Returns:
        The decoded JSON body of the response.

    Raises:
        HTTPError: If the service answers with an error status. The error is
            logged before it is re-raised.
    """
    squirro_cfg = load_config()["squirro"]

    query_params = {"token": squirro_cfg["token"]}

    endpoint = (
        f"http://localhost/studio/document_status_tracking/projects/"
        f"{squirro_cfg['project_id']}/documents/ids"
    )
    try:
        response = requests.post(endpoint, params=query_params, json=ids, timeout=30)
        response.raise_for_status()
    except requests.HTTPError:
        logging.exception("Something went wrong.")
        raise

    records: list[dict[str, Any]] = response.json()
    return records


def _process_es_results(data: "list[dict[str, Any]]") -> "list[dict[str, str]]":
    """Process the Elasticsearch results.

    Get the values that are to be included in the reports from Squirro query
    results. Every ``REPORT_COLUMNS`` entry whose source is ``"ES"`` becomes one
    key of the output row; a value that cannot be extracted or formatted is
    replaced by an empty string. An integer ``"Ageing Days"`` entry is added to
    every row.

    Args:
        data: A list of Squirro Items.

    Returns:
        Processed data.

    Raises:
        KeyError: If an item has no ``created_at`` entry.
        ValueError: If an item's ``created_at`` is not an ISO formatted string.
    """
    # The set of ES-backed columns is the same for every item: filter it once.
    es_fields = [field for field in REPORT_COLUMNS if field.source == "ES"]

    items: list[dict[str, str]] = []
    for item in data:
        report_item = {}

        for field in es_fields:
            try:
                # Keyword fields hold a list of values; only the first is used.
                value = (
                    item["keywords"].get(field.squirro_field, [""])[0]
                    if field.is_keyword
                    else item.get(field.squirro_field, "")  # type: ignore[arg-type]
                )

                if field.is_datetime:
                    # The parsed timestamp is labelled as UTC and rendered in
                    # the local timezone.
                    value = (
                        datetime.fromisoformat(value)
                        .replace(tzinfo=pytz.utc)
                        .astimezone(LOCAL_TIMEZONE)
                        .strftime(field.datetime_format)  # type: ignore[arg-type]
                    )
            except Exception:
                # Best effort: missing or malformed values become empty cells.
                value = ""
            report_item[field.name] = value

        created_at = datetime.fromisoformat(item["created_at"]).replace(
            tzinfo=pytz.utc
        )
        # TODAY is already timezone-aware (local midnight), so it is subtracted
        # from the UTC creation time directly. Overwriting its tzinfo with UTC
        # would relabel local midnight as UTC midnight and skew the age by the
        # local UTC offset.
        report_item["Ageing Days"] = (TODAY - created_at).days

        items.append(report_item)

    return items


def _process_db_results(data: "list[dict[str, Any]]") -> "list[dict[str, str]]":
    """Extract the report values from the database records.

    For every record the status trail entry at ``status_trail_report_idx`` is
    reported; an index past the end of the trail falls back to the last entry.

    Args:
        data: A list of DB records.

    Returns:
        Processed data, one row per record.
    """
    processed: list[dict[str, str]] = []

    for record in data:
        trail = record["status_trail"]
        # Clamp the index to the last available entry of the trail.
        report_idx = min(record["status_trail_report_idx"], len(trail) - 1)
        status = trail[report_idx]

        row = {
            "Item ID": record["document_id"],
            "Latest Status Code": status["code"],
            "Latest Status Header": status["header"],
            "Creator of Latest Status": status["user"]["name"],
        }

        # The parsed timestamp is labelled as UTC and rendered in the local
        # timezone.
        status_time = datetime.strptime(
            status["timestamp"], "%Y-%m-%d %H:%M:%S"
        ).replace(tzinfo=pytz.utc)
        row["Date of Latest Status"] = status_time.astimezone(
            LOCAL_TIMEZONE
        ).strftime("%d/%m/%Y %H:%M:%S")

        processed.append(row)

    return processed


def _merge_and_sanitize_dfs(es_df: pd.DataFrame, db_df: pd.DataFrame) -> pd.DataFrame:
    """Merge two dataframes together. Log any inconsistency between ES and DB.

    Both frames must contain an ``Item ID`` column; ``es_df`` must also contain
    ``ES Status Code`` and ``db_df`` must contain ``Latest Status Code``. Empty
    frames are supported and produce an empty result.

    Args:
        es_df: Dataframe containing data from ES.
        db_df: Dataframe containing data from DB.

    Returns:
        The merged dataframe, with a ``Status Update Failed`` column added
        (``"Yes"`` where the ES and DB status codes differ, ``""`` otherwise)
        and the ``Item ID`` and ``ES Status Code`` columns removed.
    """
    # Log IDs that are not present in both
    if diff := (set(es_df["Item ID"]).symmetric_difference(set(db_df["Item ID"]))):
        logging.warning("%i IDs present only in one of ES and DB: %s", len(diff), diff)

    # The left join keeps every ES document, including those without a DB
    # record. Records found only in the DB (e.g. docs soft-deleted from ES) are
    # left out of the reports.
    df = es_df.merge(db_df, on="Item ID", how="left")

    # Find status code discrepancies between ES and DB and flag them in the
    # reports. The comparison is vectorised, which also keeps it well defined
    # when the frame has no rows. ES rows without a DB record have a missing
    # "Latest Status Code" and are therefore flagged as well.
    mismatch = df["ES Status Code"] != df["Latest Status Code"]
    df["Status Update Failed"] = mismatch.map({True: "Yes", False: ""})

    inconsistent = df.loc[mismatch, ["Item ID", "ES Status Code", "Latest Status Code"]]
    for item_id, es_code, db_code in inconsistent.itertuples(index=False, name=None):
        logging.warning(
            "Status inconsistent - ID: %s, ES: %s, DB: %s",
            item_id,
            es_code,
            db_code,
        )

    return df.drop(columns=["Item ID", "ES Status Code"])


def export_to_excel(df: "pd.DataFrame", fp: "str | BytesIO") -> None:
    """Write a dataframe to an excel file and fit the column widths.

    Each column is made one character wider than the longer of its header and
    its longest cell (both measured on their string form).

    Args:
        df: Pandas dataframe to export.
        fp: A filepath str or BytesIO object to write to.
    """
    with pd.ExcelWriter(fp) as writer:  # pylint: disable=abstract-class-instantiated
        df.to_excel(writer, index=False)
        # NOTE(review): set_column looks like the XlsxWriter worksheet API -
        # presumably the xlsxwriter engine is expected here; confirm.
        sheet = writer.sheets["Sheet1"]
        for position, column_name in enumerate(df):
            column = df[column_name]
            try:
                longest_cell = int(column.astype(str).map(len).max())
                width = max(longest_cell, len(str(column.name))) + 1
            except Exception:
                # E.g. a column without rows: fall back to the header width.
                width = len(str(column.name)) + 1
            sheet.set_column(position, position, width)
