"""Synchronize company data changes to Squirro."""

import argparse
import json
import logging
from collections import defaultdict
from pathlib import Path
from typing import TYPE_CHECKING

import pandas as pd
from organisation_structure import construct_mappings

from octopus.clients import init_squirro_client
from octopus.data import CompanyDataIndex
from octopus.utils import set_log_verbosity

if TYPE_CHECKING:
    # Names below exist only for static type checking. Annotations that use
    # them are either string literals or local-variable annotations, so they
    # are never evaluated at runtime.
    from typing import Any

    # Keyword name -> list of keyword values.
    Keywords = dict[str, list[str]]


# Project helper: set log verbosity to INFO as soon as the module is imported.
set_log_verbosity(logging.INFO)


# Company data CSV column name -> Squirro keyword name. The CSV columns are
# renamed with this mapping and then narrowed down to SQUIRRO_FIELDS.
COMPANY_DATA_MAPPING = {
    "Unique_Identifier": "company_uid",
    "CustSIBSCIFKey": "company_cif",
    "CustSIBSName": "company_name",
    "RMLanID": "wfi_company_rm_code",
    "RMName": "rm_name",
    "Team_CD": "wfi_company_team_code",
    "Segment": "wfi_company_segment",
    "Team": "wfi_company_team_name",
}
# Keyword columns kept from the CSV after renaming, in mapping order.
SQUIRRO_FIELDS = list(COMPANY_DATA_MAPPING.values())


def main() -> None:
    """Run the company data sync end to end.

    Builds the permission code mapping from the territory list and writes it
    as JSON, loads the company data CSV into keyword records, diffs them
    against the company data index and pushes the changed keywords to Squirro.

    Raises:
        FileNotFoundError: If the company data file or territory list file does
        not exist.
    """
    args = initialize_args()

    company_fp = Path(args.company_data_file)
    territory_fp = Path(args.territory_list_file)

    # Fail fast, checking the company file first and the territory list second.
    for required_fp in (company_fp, territory_fp):
        if required_fp.exists():
            continue
        msg = f"File {required_fp} does not exist."
        logging.error(msg)
        raise FileNotFoundError(msg)

    permission_code_mapping: dict[str, list[str]] = construct_mappings(territory_fp)
    mapping_fp = Path(args.permission_code_mapping_file)
    with mapping_fp.open("w", encoding="utf-8") as mapping_file:
        json.dump(permission_code_mapping, mapping_file)

    frame = pd.read_csv(
        args.company_data_file, dtype=str, encoding="latin-1"
    ).fillna("")
    # Keyword values are handled as lists throughout, so wrap every cell.
    frame = frame.apply(lambda column: column.map(lambda cell: [str(cell)]))
    frame = frame.rename(columns=COMPANY_DATA_MAPPING)[SQUIRRO_FIELDS]
    frame["permission_code"] = frame["wfi_company_team_code"].apply(
        lambda team_code: permission_code_mapping.get(team_code[0], [])
    )

    company_data_new = frame.to_dict(orient="records")
    company_data_old = CompanyDataIndex.load_index()._index_uid  # noqa: SLF001

    update_items(
        compare_company_data(
            company_data_old,  # type: ignore[arg-type]
            company_data_new,  # type: ignore[arg-type]
        )
    )


def compare_company_data(
    old_data: "dict[str, Keywords]", new_data: "list[dict[str, Any]]"
) -> "defaultdict[str, Any]":
    """Diff freshly loaded company records against the indexed ones.

    Companies that are not present in ``old_data`` (or whose entry is empty)
    are treated as new and skipped. A label missing from the old entry is
    compared as an empty list.

    Args:
        old_data: Old company data read from company data index
        new_data: New company data read from company data file

    Returns:
        DefaultDict mapping each changed company UID to
        ``{label: {"old": old_value, "new": new_value}}`` for every label whose
        value changed.
    """
    changes: defaultdict[str, Any] = defaultdict(dict)
    for record in new_data:
        company_uid = record["company_uid"][0]
        previous = old_data.get(company_uid)
        if not previous:
            # New company: nothing to diff against.
            continue

        for field, new_value in record.items():
            prior_value = previous.get(field, [])
            # Permission codes are an unordered collection; every other
            # label is compared exactly, including order.
            if field == "permission_code":
                changed = set(new_value) != set(prior_value)
            else:
                changed = new_value != prior_value
            if changed:
                changes[company_uid][field] = {"old": prior_value, "new": new_value}
    return changes


# Number of items requested per Squirro query page.
_QUERY_PAGE_SIZE = 10000
# Number of distinct items queued before a modify_items call is sent.
_MODIFY_BATCH_SIZE = 1000


def _flush_items(
    sq_client: "Any", project_id: "str | None", pending: "dict[str, Keywords]"
) -> None:
    """Send every queued keyword update to Squirro in one ``modify_items`` call.

    Args:
        sq_client: Squirro client used for the request.
        project_id: Project containing the items.
        pending: Item ID to keywords mapping to apply.
    """
    sq_client.modify_items(
        project_id,
        [{"id": item_id, "keywords": kw} for item_id, kw in pending.items()],
    )


def _item_keyword_updates(
    uid: str,
    updated_labels: "dict[str, Any]",
    item_keywords: "Keywords",
    pending: "Keywords",
) -> "Keywords":
    """Compute the keywords to write on one item for one changed company.

    Args:
        uid: UID of the company whose data changed.
        updated_labels: Label to ``{"old": [...], "new": [...]}`` mapping, as
            produced by ``compare_company_data``.
        item_keywords: Keywords of the item as returned by the Squirro query.
        pending: Keywords already queued for this item in the current batch
            (e.g. by another company on the same item). These take precedence
            over ``item_keywords`` so successive updates build on each other.

    Returns:
        Keywords to merge into the item's queued update.
    """
    keywords = {label: change["new"] for label, change in updated_labels.items()}

    def current(label: str, default: "list[str]") -> "list[str]":
        # Prefer values queued earlier in this batch over the indexed ones.
        return pending.get(label, item_keywords.get(label, default))

    # An item can carry several company names; only replace the changed one.
    if name_change := updated_labels.get("company_name"):
        new_name = name_change["new"][0]
        current_names = current("company_name", [])
        if name_change["old"]:
            old_name = name_change["old"][0]
            keywords["company_name"] = [
                new_name if name == old_name else name for name in current_names
            ]
            wfi_company_name = current("wfi_company_name", [""])
            if wfi_company_name and wfi_company_name[0] == old_name:
                keywords["wfi_company_name"] = [new_name]
        else:
            # No previous name was recorded, so there is nothing to replace:
            # keep the existing names and make sure the new one is present.
            keywords["company_name"] = [*current_names]
            if new_name not in current_names:
                keywords["company_name"].append(new_name)

    # Replace this company's "<uid>___<code>" entries, keeping other companies'.
    if permission_change := updated_labels.get("permission_code"):
        # Match on the full "<uid>___" prefix; a bare uid prefix would also
        # drop codes of companies whose UID merely starts with this one.
        prefix = f"{uid}___"
        uid_permission_codes = [
            code
            for code in current("uid_permission_code", [])
            if not code.startswith(prefix)
        ]
        uid_permission_codes += [
            f"{prefix}{code}" for code in permission_change["new"]
        ]
        keywords["uid_permission_code"] = uid_permission_codes
        # De-duplicated codes across every company on the item; sorted so the
        # written value is deterministic.
        keywords["permission_code"] = sorted(
            {code.split("___")[1] for code in uid_permission_codes}
        )

    return keywords


def update_items(
    companies: "dict[str, dict[str, Any]]",
    *,
    sq_client: "Any" = None,
    project_id: "str | None" = None,
) -> None:
    """Update items in the project based on changes in company data.

    For every changed company, the items matching ``company_uid:<uid>`` are
    fetched page by page. The changed keywords are computed for each item, and
    the updates are sent in batches of ``_MODIFY_BATCH_SIZE`` items.

    Args:
        companies: UID to changed-labels mapping, as returned by
            ``compare_company_data``.
        sq_client: Optional Squirro client to reuse. When omitted, a client and
            project ID are obtained from ``init_squirro_client()``.
        project_id: Project to update; only used together with ``sq_client``.
    """
    if sq_client is None:
        sq_client, project_id = init_squirro_client()

    pending: dict[str, Keywords] = {}
    for uid, updated_labels in companies.items():
        logging.info("Getting items for company %s", uid)

        start = 0
        while True:
            results = sq_client.query(
                project_id=project_id,
                query=f"company_uid:{uid}",
                fields=["id", "keywords"],
                start=start,
                count=_QUERY_PAGE_SIZE,
            )
            page = results.get("items") or []
            for item in page:
                queued = pending.setdefault(item["id"], {})
                queued.update(
                    _item_keyword_updates(
                        uid, updated_labels, item.get("keywords") or {}, queued
                    )
                )
                if len(pending) >= _MODIFY_BATCH_SIZE:
                    _flush_items(sq_client, project_id, pending)
                    pending = {}

            # Stop on the last page. An empty page also ends the loop, so a
            # missing "eof" flag cannot cause endless re-querying.
            if results.get("eof") or not page:
                break
            # Advance past the items already processed.
            start += len(page)

    if pending:
        _flush_items(sq_client, project_id, pending)


def initialize_args() -> "argparse.Namespace":
    """Parse the command-line options of the sync script.

    Returns:
        Parsed arguments with ``territory_list_file``, ``company_data_file``
        and ``permission_code_mapping_file`` attributes.
    """
    cli_options: tuple[tuple[str, dict[str, Any]], ...] = (
        (
            "--territory-list-file",
            {"required": True, "help": "Path to the territory list"},
        ),
        (
            "--company-data-file",
            {
                "required": True,
                "help": (
                    "Path to the updated company data file provided by client daily"
                ),
            },
        ),
        (
            "--permission-code-mapping-file",
            {
                "default": "/flash/octopus/cache/permission_code_mapping.json",
                "help": "Output path of permission code mapping file",
            },
        ),
    )

    parser = argparse.ArgumentParser()
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    return parser.parse_args()


# Run the sync only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
