"""Documents plugin."""

import io
import json
import logging
import sys
import tempfile
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING

import pandas as pd
import pytz
import requests
from flask import jsonify, make_response, render_template, request, send_file
from flask import session as flask_session
from werkzeug.exceptions import BadRequest
from werkzeug.utils import secure_filename

from octopus.activity_tracking import RejectionReason
from octopus.clients import init_redis_client, init_wfi_client
from octopus.data import CompanyDataIndex, export_to_excel, generate_online_report
from octopus.stream import add_to_stream
from octopus.uploader import STREAM_NAME, UploadPayload, generate_failed_file_payload
from octopus.utils import load_config
from squirro.common.dependency import get_injected
from squirro.integration.frontend.context import execute_in_studioaware_context
from squirro.sdk.studio import StudioPlugin

if TYPE_CHECKING:
    from typing import Any

    from flask import Response
    from werkzeug.datastructures import FileStorage

    from squirro_client import SquirroClient


# Redis hash: WFI document ID -> JSON-encoded WFI metadata queued for that
# document (read and merged in `add_wfi_payload_to_redis`).
ITEM_ID_WFI_PAYLOAD_HASH: "str" = "item_id_wfi_payload_hash"
# Redis hash: document type -> JSON-encoded mapping of document-type labels
# (read in `get_document_type_data`).
DOCUMENT_TYPE_MAPPING_HASH: "str" = "document_type_mapping_hash"
# Upper bound on the number of files accepted by one upload request.
MAX_NUM_FILES = 10
# Upper bound on the number of matching items an online report may contain.
MAX_REPORT_ITEM = 1000
# NOTE(review): not referenced in this module; confirm it is used elsewhere
# before relying on it (or remove it).
RETRY_COUNT = 3

log = logging.getLogger(__name__)
plugin = StudioPlugin(__name__)
# Configuration loaded once at import time; `update_document_status` reads it.
cfg = load_config()
# Module-wide Redis client created at import time; shared by the upload
# stream and the WFI payload / document-type hashes.
redis_client = init_redis_client()

# Options passed to every `plugin.route` below. pytest only works with
# skip_authentication=True, so authentication is skipped whenever pytest is
# loaded; otherwise project readers are allowed to call the routes.
plugin_options: "dict[str, bool]" = (
    {"skip_authentication": True}
    if "pytest" in sys.modules
    else {"allow_project_readers": True}
)

# NOTE(review): second `load_config()` call duplicating `cfg`; `config` is not
# referenced in this module — confirm no importer uses it before removing.
config = load_config()


@execute_in_studioaware_context  # type: ignore[misc]
def _get_user_info() -> "dict[str, str]":
    """Return the email and name of the signed-in user.

    The user ID is read from the Flask session and resolved through the
    injected Squirro client.

    Returns:
        A dict with the non-empty ``email`` and ``name`` of the user.

    Raises:
        BadRequest: If the email or name is missing or empty.
        Exception: Any error from the Squirro client is logged and re-raised.
    """
    sq_client: SquirroClient = get_injected("squirro_client")
    try:
        user = sq_client.get_user_data(flask_session.get("user_id"))["user_information"]
    except Exception:  # pylint: disable=broad-except
        # Log through the module logger, like the rest of this file.
        log.exception("Failed to get user data.")
        raise

    user_info: dict[str, str] = {}
    for key in ("email", "name"):
        # Values are list-like; a missing key, an empty/None value and an
        # empty first element are all reported as missing (400) instead of
        # failing with IndexError/TypeError.
        value = (user.get(key) or [""])[0]
        if not value:
            msg = f"Field `{key}` is missing from user information."
            log.error(msg)
            raise BadRequest(msg)
        user_info[key] = value
    return user_info


def _create_upload_dir() -> Path:
    """Create and return a fresh, request-private directory for uploaded files.

    ``/flash/upload-tmp`` (common storage) is preferred when ``/flash``
    exists; ``upload-tmp`` is created on first use. If ``/flash`` is absent or
    unusable, a new directory in the system temp location is used instead.
    """
    shared_dir = Path("/flash") / "upload-tmp"
    if shared_dir.parent.is_dir():
        try:
            shared_dir.mkdir(exist_ok=True)
            return Path(tempfile.mkdtemp(dir=str(shared_dir)))
        except OSError:
            log.warning(
                "Cannot use shared upload directory %s, using local temp dir",
                shared_dir,
                exc_info=True,
            )
    return Path(tempfile.mkdtemp())


def _parse_upload_labels(raw_labels: str) -> "dict[str, list[str]]":
    """Decode the ``labels`` form field, rejecting malformed input with 400."""
    try:
        labels = json.loads(raw_labels)
    except ValueError as err:
        msg = "Invalid labels: not valid JSON."
        log.error(msg)
        raise BadRequest(msg) from err
    if not isinstance(labels, dict):
        msg = "Invalid labels: expected a JSON object."
        log.error(msg)
        raise BadRequest(msg)
    return labels


def _unique_name(file_name: str, used_names: "set[str]") -> str:
    """Return ``file_name``, prefixed with a counter if already used, and record it."""
    candidate, counter = file_name, 1
    while candidate in used_names:
        candidate = f"{counter}_{file_name}"
        counter += 1
    used_names.add(candidate)
    return candidate


def _remove_quietly(path: Path) -> None:
    """Best-effort removal of a file or empty directory; failures are only logged."""
    try:
        if path.is_dir():
            path.rmdir()
        else:
            path.unlink(missing_ok=True)
    except OSError:
        log.warning("Failed to remove %s", path, exc_info=True)


@plugin.route("/", methods=["POST"], **plugin_options)  # type: ignore[misc]
def upload_document() -> "Response":
    """Upload one or more documents.

    Expects a multipart form with 1..``MAX_NUM_FILES`` files under
    ``documents`` and an optional JSON object of labels under ``labels``.
    Files are saved into a new per-request directory and an upload payload
    (user info, saved file paths, labels, rejected files) is added to the
    upload stream.

    Returns:
        A response with a status code of 201 once the payload is queued.

    Raises:
        BadRequest: If the user information is incomplete, the number of
            files is invalid, or ``labels`` is not a JSON object.
    """
    payload: dict[str, Any] = _get_user_info()

    files: list[FileStorage] = request.files.getlist("documents")
    if not files or len(files) > MAX_NUM_FILES:
        msg = f"Invalid number of files {len(files)}"
        log.error(msg)
        raise BadRequest(msg)

    # An absent or empty field means "no extra labels".
    labels = _parse_upload_labels(request.form.get("labels") or "{}")
    labels["source_type"] = ["User Upload"]

    # A fresh directory per request, so concurrent uploads of identically
    # named files cannot overwrite each other before they are processed.
    # NOTE(review): this directory is not removed here after a successful
    # upload — confirm the stream consumer / a cleanup job removes it.
    upload_dir = _create_upload_dir()

    file_paths: list[Path] = []
    invalids = []
    used_names: set[str] = set()
    for file in files:
        file_name = secure_filename(file.filename or "")
        if not file_name:
            # secure_filename() reduces e.g. dot-only or non-ASCII-only names
            # to "", which cannot be stored as a file.
            log.error("Rejected file with unusable name %r", file.filename)
            invalids.append(
                generate_failed_file_payload(
                    file.filename or "",
                    RejectionReason.INVALID_FILE,
                ),
            )
            continue

        # Distinct originals can sanitize to the same name within a request.
        file_path = upload_dir / _unique_name(file_name, used_names)
        try:
            file.save(file_path)
        except Exception:
            log.exception("Failed to save file %s", file_path)
            _remove_quietly(file_path)  # drop any partially written file
            invalids.append(
                generate_failed_file_payload(
                    file_name,
                    RejectionReason.INVALID_FILE,
                ),
            )
        else:
            file_paths.append(file_path)

    payload["file_paths"] = file_paths
    payload["labels"] = labels
    payload["invalids"] = invalids
    try:
        add_to_stream(
            UploadPayload.create_payload(**payload).to_dict(),
            STREAM_NAME,
            redis_client=redis_client,
        )
    except Exception:
        log.exception("Failed to add payload to stream")
        # Cleanup must not mask the original error, so each step is best-effort.
        for path in (*file_paths, upload_dir):
            _remove_quietly(path)
        raise

    # Upload is considered successful once the files are saved and queued.
    return make_response({"message": "File uploaded successfully"}, 201)


@plugin.route("/bulk", methods=["PATCH"], **plugin_options)  # type: ignore[misc]
def bulk_assign() -> "Response":
    """Bulk assign the same labels to several items of a project.

    Expects a JSON object body with:
        ids: List of item IDs (at most 30).
        labels: Mapping of label name to list of values, written as the
            keywords of every listed item.
        project_id: The project ID.

    Returns:
        A response with a status code of 204 if the update is successful.

    Raises:
        BadRequest: If the body is missing, malformed, incomplete, or lists
            more items than the limit.
    """
    json_data = request.json
    if not json_data or not isinstance(json_data, dict):
        msg = "Invalid JSON data."
        log.error(msg)
        raise BadRequest(msg)

    ids: list[str] | None = json_data.get("ids")
    labels: dict[str, list[str]] | None = json_data.get("labels")
    project_id: str | None = json_data.get("project_id")

    if not (ids and labels and project_id):
        msg = "Incomplete data to bulk assign items."
        log.error(msg)
        raise BadRequest(msg)

    # Without the type checks a bare string `ids` would pass and be iterated
    # character by character, and a non-dict body/labels would give a 500.
    if not (
        isinstance(ids, list)
        and all(isinstance(item_id, str) for item_id in ids)
        and isinstance(labels, dict)
        and isinstance(project_id, str)
    ):
        msg = "Invalid data to bulk assign items."
        log.error(msg)
        raise BadRequest(msg)

    bulk_assign_limit = 30
    if (num_items := len(ids)) > bulk_assign_limit:
        msg = "Exceeded limit for bulk assignment."
        log.error(msg)
        raise BadRequest(msg)

    sq_client = get_injected("squirro_client")
    try:
        sq_client.modify_items(
            project_id,
            items=[{"id": item_id, "keywords": labels} for item_id in ids],
        )
    except Exception:
        log.exception(
            "Failed to bulk modify %d items in project %s",
            num_items,
            project_id,
        )
        raise

    # NOTE: a 204 response carries no body, so clients never receive this
    # message; only the status code is meaningful.
    return make_response(
        {"message": f"{num_items} items updated successfully."},
        204,
    )


@plugin.route("/<document_id>", methods=["PATCH"], **plugin_options)  # type: ignore[misc]
def modify_squirro_labels(document_id: str) -> "Response":
    """Modify ONLY the labels of a Squirro item.

    Expects a JSON object body with ``labels`` and ``project_id``. An optional
    ``wfi_document_id`` label is removed before the labels are written to
    Squirro; when it is present together with an ``is_deleted`` label, the
    WFI document is soft-deleted (value ``"true"``) or restored (any other
    value).

    Args:
        document_id: The Squirro item ID.

    Returns:
        A response with a status code of 204 if the update is successful.

    Raises:
        BadRequest: If the body is missing, malformed or incomplete.
    """
    json_data = request.json
    if not json_data or not isinstance(json_data, dict):
        msg = "Invalid JSON data."
        log.error(msg)
        raise BadRequest(msg)

    labels: dict[str, list[str]] | None = json_data.get("labels")
    project_id: str | None = json_data.get("project_id")

    if not (labels and project_id):
        msg = f"Incomplete data to update item {document_id}."
        log.error(msg)
        raise BadRequest(msg)

    if not isinstance(labels, dict):
        msg = f"Invalid labels to update item {document_id}."
        log.error(msg)
        raise BadRequest(msg)

    # Popped so it is not stored as a Squirro label. An empty/None value is
    # treated as absent instead of raising IndexError.
    wfi_document_id = (labels.pop("wfi_document_id", None) or [""])[0]
    update_squirro_item(project_id, document_id, labels)

    is_deleted = (labels.get("is_deleted") or [None])[0]
    if is_deleted is not None and wfi_document_id:
        wfi_client = init_wfi_client()
        if is_deleted == "true":
            wfi_client.soft_delete(wfi_document_id)
        else:
            wfi_client.restore_soft_delete(wfi_document_id)

    return make_response({}, 204)


def update_squirro_item(
    project_id: str,
    document_id: str,
    labels: dict[str, list[str]],
) -> None:
    """Write ``labels`` as the keywords of a single Squirro item.

    Args:
        project_id: Project that owns the item.
        document_id: ID of the item to modify.
        labels: Keyword mapping passed to ``modify_item``.

    Raises:
        Exception: Any failure of the Squirro client is logged, then re-raised.
    """
    log.info("Updating squirro item %s", document_id)
    log.debug("Labels: %s", labels)
    try:
        get_injected("squirro_client").modify_item(
            project_id,
            document_id,
            keywords=labels,
        )
    except Exception:
        log.exception("Failed to modify item %s", document_id)
        raise


@plugin.route("/<document_id>", methods=["PUT"], **plugin_options)  # type: ignore[misc]
def update_document(document_id: str) -> "Response":
    """Update the document.

    Updates Squirro labels, WFI metadata, and document status. Expects a JSON
    object body with ``labels``, ``project_id`` and ``wfi_document_id``.

    Args:
        document_id: The Squirro item ID.

    Returns:
        A response with a status code of 200 whose JSON body contains the
        resolved ``rm_name`` label when one was derived.

    Raises:
        BadRequest: If the body is missing, malformed or incomplete.
    """
    json_data = request.json
    if not json_data or not isinstance(json_data, dict):
        msg = "Invalid JSON data."
        log.error(msg)
        raise BadRequest(msg)

    labels: dict[str, list[str]] | None = json_data.get("labels")
    project_id: str | None = json_data.get("project_id")
    wfi_document_id: str | None = json_data.get("wfi_document_id")

    if not (labels and project_id and wfi_document_id):
        msg = f"Incomplete data to update item {document_id}."
        log.error(msg)
        raise BadRequest(msg)

    # A non-object `labels` would otherwise fail deep inside the helpers (500).
    if not isinstance(labels, dict):
        msg = f"Invalid labels to update item {document_id}."
        log.error(msg)
        raise BadRequest(msg)

    all_labels = get_all_labels_to_update(labels)
    wfi_metadata = get_wfi_metadata(all_labels)

    update_squirro_item(project_id, document_id, all_labels)
    # Status tracking receives the submitted labels, not the expanded set.
    update_document_status(project_id, document_id, labels)
    if wfi_metadata:
        add_wfi_payload_to_redis(wfi_document_id, wfi_metadata)

    res = {}
    if rm_name := all_labels.get("rm_name"):
        res["rm_name"] = rm_name
    return make_response(jsonify(res), 200)


def get_all_labels_to_update(labels: dict[str, list[str]]) -> dict[str, list[str]]:
    """Expand user-submitted labels into the full Squirro keyword payload.

    The result starts as a copy of ``labels`` and is enriched, in order, with:
    ``<key>_true`` copies of ``document_type``/``company_name``/
    ``document_date``; the company data of ``company_name``; the document-type
    data of the first ``document_type``; ``wfi_document_date`` mirroring
    ``document_date``; and ``wfi_references`` holding all ``references``
    joined with ``";"``. Later additions override earlier keys.

    Args:
        labels: The submitted labels (not modified).

    Returns:
        A new dict with every label to write.
    """
    tracked_keys = ("document_type", "company_name", "document_date")

    payload: dict[str, list[str]] = dict(labels)
    for key, value in labels.items():
        if key in tracked_keys:
            payload[f"{key}_true"] = value

    company_names = labels.get("company_name")
    if company_names:
        payload.update(
            get_company_data(company_names, labels.get("wfi_company_name")),
        )

    document_types = labels.get("document_type")
    if document_types:
        payload.update(get_document_type_data(document_types[0]))

    document_dates = labels.get("document_date")
    if document_dates:
        payload["wfi_document_date"] = document_dates

    references = labels.get("references")
    if references:
        payload["wfi_references"] = [";".join(references)]

    return payload


def get_company_data(
    company_names: "list[str]",
    wfi_company_name: "list[str] | None" = None,
) -> dict[str, list[str]]:
    """Collect the Squirro (and optionally WFI) labels for the given companies.

    Companies are looked up in the ``CompanyDataIndex``. The labels of every
    match that has a ``company_uid`` are merged. For each ``permission_code``
    an extra ``"<company_uid>___<permission_code>"`` value is added under
    ``uid_permission_code``. ``permission_code`` is then de-duplicated, keeping
    first-seen order, and only the first ``rm_name`` is kept.

    When ``wfi_company_name`` is given, the ``wfi_*`` labels of that company
    are added, along with its ``company_cif`` as ``wfi_company_cif``.

    Args:
        company_names: The company names to look up.
        wfi_company_name: List whose first element names the company whose
            data must also be pushed to WFI, if any.

    Returns:
        The merged labels, each mapped to a list of values.

    Raises:
        BadRequest: If ``wfi_company_name`` is not among the matched companies.
    """
    company_data_index = CompanyDataIndex.load_index()
    company_data = company_data_index.search_by_names(company_names)

    ret = defaultdict(list)
    labels_to_update = {
        "company_name",
        "wfi_company_name",
        "company_name_true",
        "company_cif",
        "permission_code",
        "uid_permission_code",
        "rm_name",
        "company_uid",
        "wfi_company_cif",
        "wfi_company_rm_code",
        "wfi_company_segment",
        "wfi_company_team_name",
        "wfi_company_team_code",
    }

    for company in company_data:
        # A missing, None or empty company_uid just skips the company.
        company_uid = (company.get("company_uid") or [""])[0]
        if not company_uid:
            continue

        for label, values in company.items():
            # Skip WFI labels here; they are only added below, for the company
            # selected for the WFI update.
            if (
                not label.startswith("wfi_")  # type:ignore[attr-defined]
                and label in labels_to_update
            ):
                ret[label].extend(values)

            if label == "permission_code":
                ret["uid_permission_code"].extend(
                    f"{company_uid}___{permission_code}" for permission_code in values
                )

    # dict.fromkeys de-duplicates with a stable order (set() order is not).
    ret["permission_code"] = list(dict.fromkeys(ret["permission_code"]))
    ret["rm_name"] = ret["rm_name"][:1]  # Use only the first rm_name

    if wfi_company_name:
        # Needs to update company data in WFI
        target_name = wfi_company_name[0]
        wfi_company_data = next(
            (
                company
                for company in company_data
                if (company.get("company_name") or [])[:1] == [target_name]
            ),
            None,
        )
        if wfi_company_data is None:
            # Previously an uncaught StopIteration (500).
            msg = f"Company `{target_name}` not found for WFI update."
            log.error(msg)
            raise BadRequest(msg)

        for label, values in wfi_company_data.items():
            if label == "company_cif":
                ret["wfi_company_cif"] = values
            elif label.startswith("wfi_"):  # type:ignore[attr-defined]
                ret[label] = values

    return dict(ret)  # type: ignore[arg-type]


def get_document_type_data(document_type: str) -> dict[str, list[str]]:
    """Look up the labels derived from a document type.

    The mapping is read from the ``DOCUMENT_TYPE_MAPPING_HASH`` Redis hash as
    JSON. Only the document-type related entries are returned, each wrapped in
    a one-element list.

    Args:
        document_type: Field name to read from the Redis hash.

    Returns:
        The document-category / WFI document labels for the type.

    Raises:
        KeyError: If no (or an empty) mapping is stored for the type.
    """
    raw_mapping = redis_client.hget(DOCUMENT_TYPE_MAPPING_HASH, document_type)
    if not raw_mapping:
        msg = f"Document type {document_type} not found."
        log.error(msg)
        raise KeyError(msg)

    wanted_labels = (
        "document_category",
        "wfi_document_type",
        "wfi_document_name",
        "wfi_document_category",
    )

    result: dict[str, list[str]] = {}
    for label, value in json.loads(raw_mapping).items():
        if label in wanted_labels:
            result[label] = [value]
    return result


def get_wfi_metadata(labels: dict[str, list[str]]) -> "dict[str, str]":
    """Extract the WFI metadata from a label payload.

    Every ``wfi_*`` label contributes its first value. Labels whose value is
    empty (``[]`` or ``None``) are skipped instead of raising ``IndexError``.

    Args:
        labels: The labels to get the WFI metadata from.

    Returns:
        Mapping of each non-empty ``wfi_*`` label to its first value.
    """
    return {
        key: values[0]
        for key, values in labels.items()
        if key.startswith("wfi_") and values
    }


def add_wfi_payload_to_redis(
    wfi_document_id: str,
    metadata: "dict[str, str]",
) -> None:
    """Queue WFI metadata for a document in Redis.

    The metadata is stored as JSON under ``wfi_document_id`` in the
    ``ITEM_ID_WFI_PAYLOAD_HASH`` hash. When an entry already exists, the new
    keys are merged over the stored ones.

    NOTE(review): the read-modify-write (``hget`` then ``hset``) is not atomic;
    concurrent updates of the same document can lose keys.

    Args:
        wfi_document_id: The WFI document ID (hash field).
        metadata: The metadata to merge in.

    Raises:
        BadRequest: If the WFI document ID or metadata is empty.
        Exception: Redis/JSON errors are logged and re-raised.
    """
    if not wfi_document_id or not metadata:
        msg = "Incomplete data to update."
        log.error(msg)
        raise BadRequest(msg)

    log.info("Adding WFI metadata to redis for document %s", wfi_document_id)
    log.debug("WFI Metadata: %s", metadata)

    try:
        stored = redis_client.hget(ITEM_ID_WFI_PAYLOAD_HASH, wfi_document_id)
        merged: dict[str, str] = (
            metadata if stored is None else {**json.loads(stored), **metadata}
        )
        redis_client.hset(
            ITEM_ID_WFI_PAYLOAD_HASH,
            wfi_document_id,
            json.dumps(merged),
        )
    except Exception:
        log.exception("Error while trying to update payload in redis")
        raise


def update_document_status(
    project_id: str,
    document_id: str,
    labels: dict[str, list[str]],
) -> None:
    """Push changed status-relevant labels to the document status tracking route.

    Only the first value of ``document_type``, ``wfi_company_name`` (sent as
    ``company_name``) and ``document_date`` is sent. Nothing is sent when none
    of them is present.

    Args:
        project_id: The project ID.
        document_id: The document ID.
        labels: The submitted labels.

    Raises:
        requests.HTTPError: If the status tracking route returns an error status.
    """
    status_update_labels: dict[str, str] = {}
    for source_key, target_key in (
        ("document_type", "document_type"),
        ("wfi_company_name", "company_name"),
        ("document_date", "document_date"),
    ):
        if values := labels.get(source_key):
            status_update_labels[target_key] = values[0]

    if not status_update_labels:
        return

    status_tracking_url = (
        f"{cfg['squirro']['cluster']}/studio/document_status_tracking/projects"
        f"/{project_id}/documents/{document_id}"
    )

    # The caller's headers are forwarded to keep the session. Hop-by-hop and
    # body/target-specific headers describe the incoming request and must not
    # be copied onto a new request, which has its own body and URL.
    excluded_headers = {
        "host",
        "content-length",
        "content-type",
        "transfer-encoding",
        "connection",
        "keep-alive",
        "proxy-connection",
        "te",
        "trailer",
        "upgrade",
        "expect",
    }
    forwarded_headers = {
        name: value
        for name, value in request.headers.items()
        if name.lower() not in excluded_headers
    }

    log.info("Updating status for document %s", document_id)
    log.debug("Labels changed: %s", status_update_labels)
    res = requests.put(
        status_tracking_url,
        json=status_update_labels,
        headers=forwarded_headers,
        timeout=10,
    )
    res.raise_for_status()


def _no_results_page(error_msg: str) -> "Response":
    """Render ``no-results.html`` with ``error_msg`` and status code 400."""
    return make_response(
        render_template("no-results.html", error_msg=error_msg),
        400,
    )


@plugin.route("/generate_report", **plugin_options)  # type: ignore[misc]
def generate_report() -> "Response":
    """Generate an Excel report of the items matching the request filters.

    Query args: ``project_id`` (required), ``query``, ``created_after`` and
    ``created_before``.

    Returns:
        The report as an xlsx attachment (200), or the ``no-results.html``
        page with status 400 when ``project_id`` is missing, nothing matches,
        or more than ``MAX_REPORT_ITEM`` items match.
    """
    project_id = request.args.get("project_id")
    if not project_id:
        log.error("Missing project_id for report generation.")
        return _no_results_page("Missing project ID.")

    sq_client = get_injected("squirro_client")
    res = sq_client.query(
        project_id,
        query=request.args.get("query"),
        created_after=request.args.get("created_after"),
        created_before=request.args.get("created_before"),
        # Never fetch more than a valid report can contain.
        count=MAX_REPORT_ITEM,
    )

    if not res["total"]:
        return _no_results_page(
            "No records found, please refine the filter criteria.",
        )

    if res["total"] > MAX_REPORT_ITEM:
        return _no_results_page(
            f"More than {MAX_REPORT_ITEM} records found, "
            "please refine the filter criteria.",
        )

    report_bytes = generate_online_report(res["items"])
    timestamp = datetime.now(tz=pytz.timezone("Asia/Singapore")).strftime(
        "%Y%m%d-%H%M%S",
    )

    return send_file(  # type: ignore[no-any-return]
        report_bytes,
        as_attachment=True,
        mimetype="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        download_name=f"OCtopus-Report-{timestamp}.xlsx",
    )


def _flatten_export_item(
    item: "dict[str, Any]",
    export_config: "dict[str, Any]",
) -> "dict[str, Any]":
    """Turn one Squirro query result into a flat row of the export sheet.

    ``id``, ``sources`` and ``keywords`` are dropped. Every exported keyword
    other than ``created_at``/``title`` becomes a column whose values are cast
    to ``str`` and joined with ``", "``.
    """
    item_keywords = item.get("keywords") or {}
    row = {
        key: value
        for key, value in item.items()
        if key not in {"id", "sources", "keywords"}
    }
    for keyword in export_config:
        if keyword in {"created_at", "title"}:
            continue
        row[keyword] = ", ".join(map(str, item_keywords.get(keyword) or []))
    return row


def _format_created_at(value: object) -> str:
    """Convert a naive UTC ISO timestamp to ``dd/mm/YYYY HH:MM:SS`` in SGT."""
    if not isinstance(value, str) or not value:
        return ""
    return (
        datetime.fromisoformat(value)
        .replace(tzinfo=pytz.utc)
        .astimezone(pytz.timezone("Asia/Singapore"))
        .strftime("%d/%m/%Y %H:%M:%S")
    )


def _format_document_date(value: object) -> str:
    """Reformat a ``YYYY-mm-dd`` date as ``dd/mm/YYYY``.

    A calendar date carries no time of day, so it is not converted between
    time zones: converting from the server's local zone could shift the day.
    Values that are not a single ISO date (e.g. several joined dates) are
    exported unchanged rather than failing the whole export.
    """
    if not isinstance(value, str) or not value:
        return ""
    try:
        return datetime.strptime(value, "%Y-%m-%d").strftime("%d/%m/%Y")
    except ValueError:
        return value


@plugin.route("/save_to_excel", **plugin_options)  # type: ignore[misc]
def save_to_excel() -> "Response":
    """Return the items matching a query as an Excel attachment.

    Query args:
        project_id, query: Passed to the Squirro query (up to 10000 items).
        dashboard: Used in the download name (default ``"Squirro"``).
        keywords: JSON object mapping each exported column key to an object
            with an ``index`` (column order) and a ``column_title`` (header).

    Raises:
        BadRequest: If ``keywords`` is not a JSON object.
    """
    query_args = request.args.to_dict()
    dashboard = query_args.pop("dashboard", "Squirro")
    try:
        # The default must be a JSON string: json.loads({}) raises TypeError.
        export_config = json.loads(query_args.pop("keywords", "{}"))
    except ValueError as err:
        msg = "Invalid keywords: not valid JSON."
        log.error(msg)
        raise BadRequest(msg) from err
    if not isinstance(export_config, dict):
        msg = "Invalid keywords: expected a JSON object."
        log.error(msg)
        raise BadRequest(msg)

    sq_client = get_injected("squirro_client")
    res = sq_client.query(
        query_args.get("project_id"),
        query=query_args.get("query"),
        fields=["title", "created_at", "keywords"],
        count=10000,
    )

    rows = [
        _flatten_export_item(item, export_config)
        for item in res.get("items") or []
    ]

    ordered_columns = sorted(
        export_config.items(),
        key=lambda entry: entry[1]["index"],
    )
    items_df = pd.DataFrame(rows, columns=[key for key, _ in ordered_columns])
    items_df.insert(0, "S/N", range(1, len(items_df) + 1))

    # Date columns are only reformatted when they are part of the export.
    # FUTURE: Migrate 'utc_to_local' into octopus util.py
    if "created_at" in items_df.columns:
        items_df["created_at"] = items_df["created_at"].apply(_format_created_at)
    if "document_date" in items_df.columns:
        items_df["document_date"] = items_df["document_date"].apply(
            _format_document_date,
        )

    # Rename columns to column_title specified in frontend config
    items_df = items_df.rename(
        columns={key: spec["column_title"] for key, spec in ordered_columns},
    )

    stream = io.BytesIO()
    export_to_excel(items_df, stream)
    stream.seek(0)

    export_date = datetime.now(pytz.timezone("Asia/Singapore")).strftime("%Y-%m-%d")
    resd: Response = send_file(
        stream,
        as_attachment=True,
        mimetype="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        download_name=f"{export_date}-{dashboard}-Export.xlsx",
    )
    return resd
