"""Custom activity plugin."""

# mypy: ignore-errors
import collections
import hashlib
import json
import logging
import re
import socket
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING

import pandas as pd
from dateutil import parser

from octopus.utils import load_config
from squirro.common.dependency import get_injected
from squirro.dataloader.data_source import DataSource
from squirro_client import SquirroClient

if TYPE_CHECKING:
    from collections.abc import Generator
    from typing import Any

log = logging.getLogger(__name__)

# Glob used to find activity files, e.g. ``activity.2022-02-01.jsonl``.
ACTIVITY_FILE_PATTERN = "activity*.jsonl"
# Extracts the date part of an activity file name. The dots are escaped so
# they only match a literal "." (an unescaped "." matches any character).
ACTIVITY_FILE_DATE_REGEX = re.compile(r"activity\.(\d+-\d+-\d+)\.jsonl")


class SquirroActivitySource(DataSource):  # pylint: disable=abstract-method
    """Data loader plugin to parse Squirro activity logs.

    Reads JSON-lines activity files from a folder (the ``source_folder``
    argument or the ``[activity] path`` configuration value). It drops page
    views and records whose tenant does not match ``source_tenant``, flattens
    the nested records and yields them in batches. The newest ``now`` value
    seen is persisted per host in the key-value store, so later runs skip
    records that were already loaded.
    """

    # Record field used to track incremental progress.
    inc_column = "now"

    def __init__(self, *args: "dict[str, Any]", **kwargs: "dict[str, Any]") -> None:
        """Initialize the SquirroActivitySource.

        Reads the Squirro token, cluster URL and activity project id from the
        main configuration, authenticates a ``SquirroClient`` and makes sure
        the custom label facets exist in the activity project.

        Raises:
            Exception: Re-raised (after logging) when the configuration cannot
                be read or the client cannot be authenticated.
        """
        super().__init__(*args, **kwargs)
        self.preview: bool = False
        self.hostname: str | None = None
        self.activity_path: str | None = None

        try:
            cfg = load_config()
            sq_token = cfg.get("squirro", "token")
            self.activity_project_id = cfg.get("activity", "project_id")
            cluster_url = cfg.get("squirro", "cluster")
        except Exception:
            log.exception("Exception occurred when reading main.ini file")
            raise

        try:
            self.sq_client = SquirroClient(None, None, cluster=cluster_url)
            self.sq_client.authenticate(refresh_token=sq_token)
        except Exception:
            log.exception("Exception occurred when instantiating squirro client")
            raise

        self.additional_fields: dict[str, str] = self.get_custom_fields()

    @property
    def inc_column_cache_key(self) -> str:
        """Key-value store key holding the incremental value of this host."""
        return f"inc_column_{self.hostname}"

    @property
    def max_inc_value(self) -> "datetime | None":
        """Newest stored incremental value as a naive datetime, or None."""
        inc_value = self.key_value_store.get(self.inc_column_cache_key)
        return self._convert_to_datetime(inc_value) if inc_value else None

    @max_inc_value.setter
    def max_inc_value(self, value: str) -> None:
        # The raw string is stored; the getter parses it on every read.
        self.key_value_store[self.inc_column_cache_key] = value

    @staticmethod
    def _convert_to_datetime(date_str: "str | None") -> datetime:
        """Parse ``date_str`` with dateutil, ignoring any timezone info."""
        return parser.parse(date_str, ignoretz=True)

    def get_custom_fields(self) -> "dict":
        """Load the custom label mapping and create the missing facets.

        The mapping is the ``value`` of the
        ``app.custom-activity-plugin-labels`` entry in the activity project's
        configuration. It is flattened to dot-separated keys. Every key that
        is not yet a facet of the project is created as a searchable facet,
        using the mapped value as its display name.

        Returns:
            Flattened mapping of facet name to display name. Empty when the
            configuration entry is missing, empty or not a dictionary.
        """
        project_config: dict = self.sq_client.get_project_configuration(
            project_id=self.activity_project_id
        )["config"]

        # Default to "no labels". The former "" default could not be
        # normalized by pandas and made the plugin fail at start-up.
        label_config = project_config.get("app.custom-activity-plugin-labels") or {}
        additional_fields = (
            label_config.get("value") if isinstance(label_config, dict) else None
        )
        if not additional_fields:
            return {}
        if not isinstance(additional_fields, dict):
            log.warning(
                "Ignoring non-dictionary custom activity labels: %r",
                additional_fields,
            )
            return {}

        # A single dict normalizes to exactly one record.
        additional_fields_flattened: dict = pd.json_normalize(
            additional_fields, sep="."
        ).to_dict(orient="records")[0]

        existing_facets = {
            facet["name"]
            for facet in self.sq_client.get_facets(self.activity_project_id)
        }
        for key, value in additional_fields_flattened.items():
            if key not in existing_facets:
                self.sq_client.new_facet(
                    project_id=self.activity_project_id,
                    name=key,
                    display_name=value,
                    searchable=True,
                )
        return additional_fields_flattened

    def connect(self, _: str | None = None, __: str | None = None) -> None:
        """Prepare the source for loading.

        Clears the key-value stores when the ``reset`` argument is set and
        records the host's fully qualified domain name. It then resolves the
        activity folder (the ``source_folder`` argument first, then the
        ``[activity] path`` configuration value) without a trailing slash.
        """
        if self.args.reset:
            log.info("Resetting key-value stores")
            self.key_value_cache.clear()
            self.key_value_store.clear()

        self.hostname = socket.getfqdn()
        log.debug("Incremental Column: %r", self.inc_column)
        log.debug("Incremental Last Value: %r", self.max_inc_value)

        config = get_injected("config")
        activity_path = self.args.source_folder or config.get("activity", "path")
        self.activity_path = activity_path.rstrip("/")
        log.debug("Activity path: %r", self.activity_path)

    def disconnect(self) -> None:
        """Disconnect from the source."""

    def getDataBatch(
        self, batch_size: int, *_: "dict[str, Any]", **kwargs: "dict[str, Any]"
    ) -> "Generator[list[dict[str, Any]], None, None]":
        """Get data from source on batches.

        Args:
            batch_size: Maximum number of records per yielded batch.
            **kwargs: A truthy ``get_schema`` runs the load in preview mode,
                like the framework's ``preview_mode``: no state is saved and
                already-loaded records are not skipped.

        Yields:
            List of dictionaries
        """
        self.preview = bool(self.preview_mode or kwargs.get("get_schema"))
        rows: list[dict[str, Any]] = []
        for row in self.getActivityReports():
            rows.append(row)
            if len(rows) >= batch_size:
                yield self._batch_it(rows)
                rows = []
        if rows:
            yield self._batch_it(rows)

    # pylint: disable-next=invalid-name
    def getActivityReports(
        self,
    ) -> "Generator[dict[str, Any], None, None]":
        """Yield processed, non-skipped records from all pending files."""
        files = self.gather_files_to_process()
        for file in files:
            log.debug("Analysing the %r file", file)
            with Path(file).open(encoding="utf-8") as f:
                for line in f:
                    record = self._load_item(line)
                    if not record or self._skip_item(record):
                        continue
                    yield self._process_item(record)

    # pylint: disable-next=invalid-name
    def getSchema(self) -> "list[str]":
        """Return the schema of the dataset.

        This is done by analysing the first batch of up to 100 events, so we
        get a reasonably complete list of keys to use. Fields listed in the
        ``additional_fields`` argument are added as well.

        :returns a List containing the names of the columns retrieved
        from the source
        """
        schema: set = set()
        batches = self.getDataBatch(100, get_schema=True)
        try:
            batch = next(batches, ())
        finally:
            # Close the partially consumed generator right away so the
            # activity file it is reading is not left open.
            batches.close()
        for entry in batch:
            schema |= set(entry.keys())

        if additional_fields := self.args.additional_fields:
            schema |= {field.strip() for field in additional_fields.split(",")}

        return sorted(schema)

    # pylint: disable-next=invalid-name
    def getJobId(self) -> str:
        """Return a unique string for each different select.

        The id is the SHA-256 of the activity path's repr, so it stays stable
        as long as the path does.

        :returns a string
        """
        # Generate a stable id that changes with the main parameters
        m = hashlib.sha256()
        m.update(repr(self.activity_path).encode("utf-8"))
        job_id = m.hexdigest()
        log.debug("Job ID: %s", job_id)
        return job_id

    # pylint: disable-next=invalid-name
    def getArguments(self) -> "list[dict[str, Any]]":
        """Get arguments required by the plugin.

        Returns:
            List of dictionaries
        """
        return [
            {
                "name": "source_folder",
                "display_label": "Activity path",
                "help": (
                    "Custom path the dataloader will search for activity files. "
                    "If not set, the default path provided in the .ini configuration "
                    "is used. The dataloader is searching for the activity files in "
                    "the root level of provided path and also inside the folder with "
                    "corresponding hostname."
                ),
                "default": "",
                "type": "str",
                "advanced": True,
            },
            {
                "name": "additional_fields",
                "display_label": "Additional fields",
                "default": "",
                "type": "str",
                "advanced": True,
            },
        ]

    # pylint: disable-next=invalid-name
    def getIncrementalColumns(self) -> None:
        """Return None; incremental state is tracked in the key-value store."""
        return None

    def gather_files_to_process(self) -> list[str]:
        """Gather files which have not been loaded yet.

        The function lists all the activity files in ascending order,
        e.g.: [activity.2022-02-01.jsonl, activity.2022-02-02.jsonl].
        It then keeps only the files whose date is equal to or newer than
        the date of the saved incremental value. Files without a parsable
        date in their name are skipped in that case.

        Returns:
            List of file paths.
        """
        files: list[str] = self._gather_files()
        max_inc_value = self.max_inc_value
        if not max_inc_value:
            log.debug(
                "Incremental value not found in the store. Analysing all %d files.",
                len(files),
            )
            return files

        files_to_process: list[str] = []
        for file in files:
            file_datetime = self._get_datetime_from_file_name(file)
            # Same-day files are re-read; their old records are filtered
            # per record by ``_already_fetched``.
            if not file_datetime or file_datetime.date() < max_inc_value.date():
                continue
            files_to_process.append(file)
        log.debug("Found %d files to process", len(files_to_process))
        return files_to_process

    def _gather_files(self) -> "list[str]":
        """Gather files from the filesystem.

        Gathers the activity files from the root level and then checks
        for files inside the folder with the corresponding hostname.
        Gathering files from the root level provides backward
        compatibility with older Squirro versions, where activity files
        were saved without hostname knowledge. If a file with the same
        date is found both in the root level and in the hostname folder,
        the hostname one is assumed to be newer and is ordered last.

        Returns:
            List of file paths (as strings).
        """
        activity_folder = Path(self.activity_path)
        root_level_files = activity_folder.glob(ACTIVITY_FILE_PATTERN)
        hostname_specific_files = activity_folder.glob(
            f"{self.hostname}/{ACTIVITY_FILE_PATTERN}"
        )
        # Sort by file name (i.e. by date). For equal names the root-level
        # file has fewer parents and therefore comes first.
        ordered = sorted(
            {*root_level_files, *hostname_specific_files},
            key=lambda path: (path.name, len(path.parents)),
        )
        # Return strings as annotated; the date regex cannot search Paths.
        return [str(path) for path in ordered]

    def _get_datetime_from_file_name(self, file: "str | Path") -> "datetime | None":
        """Return the date encoded in an activity file name, or None.

        Accepts a string or ``Path``; only the file name is inspected. A
        warning is logged when no valid date can be extracted.
        """
        match = ACTIVITY_FILE_DATE_REGEX.search(Path(file).name)
        if match:
            try:
                return self._convert_to_datetime(match.group(1))
            except (ValueError, OverflowError):
                # e.g. "2022-13-45": date-shaped but not a real date.
                pass
        log.warning("Not found a proper date in the file name: %r", file)
        return None

    @staticmethod
    def _load_item(line: str) -> "dict[str, Any] | None":
        """Parse one JSON line into a record dict; None for blank/bad lines."""
        if not line.strip():
            return None
        try:
            record = json.loads(line)
        except Exception:  # pylint: disable=broad-except
            log.warning("Could not parse line %r. Skipping the line", line)
            return None
        if not isinstance(record, dict):
            # Later steps index the record by key; anything but an object
            # would crash them.
            log.warning("Line %r is not a JSON object. Skipping the line", line)
            return None
        return record

    def _skip_item(self, record: "dict[str, Any]") -> bool:
        """Whether ``record`` must not be loaded.

        Skips page views, records whose tenant differs from
        ``source_tenant`` (all records when it is unset) and records that
        were already loaded.
        """
        if record.get("action") == "pageview":
            return True
        if not self.source_tenant or record.get("tenant") != self.source_tenant:
            return True
        return self._already_fetched(record.get(self.inc_column))

    def _already_fetched(self, inc_value: "str | None") -> bool:
        """Whether a record with ``inc_value`` is older than the stored value.

        Always False in preview mode or when the record has no value.
        """
        return not self.preview and bool(inc_value) and self._old_inc_value(inc_value)

    def _old_inc_value(self, inc_value: "str | None") -> bool:
        """Whether the stored value is strictly newer than ``inc_value``.

        Records with the same timestamp as the stored value are not
        considered old.
        """
        max_inc_value = self.max_inc_value
        return bool(
            max_inc_value and max_inc_value > self._convert_to_datetime(inc_value)
        )

    def _process_item(self, item: "dict[str, Any]") -> "dict[str, Any]":
        """Flatten a record and add the ``hostname`` and ``type`` fields."""
        # Activity logs are stored in a nested dictionary
        # but facet mapping works only with flat keys
        item = _flatten(item)
        item["hostname"] = self.hostname
        item["type"] = "activity_log"

        # workaround for SB-617
        if "query" in item:
            item["query.query"] = item.pop("query")

        # Put microseconds in a separate field and remove them from the
        # datetime. A missing or non-string ``now`` is left untouched.
        now = item.get("now")
        if isinstance(now, str) and "." in now:
            item["now"], item["microseconds"] = now.split(".", 1)
        return item

    def _batch_it(self, batch: "list[dict[str, Any]]") -> "list[dict[str, Any]]":
        """Save state and return batch.

        The method saves the latest state to the store and thus keeps
        track of max_inc_value. Nothing is saved in preview mode.

        Args:
            batch: List of items to save.

        Returns:
            List of items to save.
        """
        if not self.preview:
            self._save_state(batch)
        return batch

    def _save_state(self, batch: "list[dict[str, Any]]") -> None:
        """Persist the newest incremental value found in ``batch``.

        Stores the same value as writing every non-outdated item in turn,
        but reads and writes the key-value store only once per batch. Among
        items with equal timestamps the last one's raw value is stored.
        Nothing is written when the whole batch is older than the stored
        value.
        """
        newest_value = None
        newest_datetime = None
        for item in batch:
            item_inc_value = item.get(self.inc_column)
            if not item_inc_value:
                continue
            item_datetime = self._convert_to_datetime(item_inc_value)
            # ``>=`` keeps the last of several equal timestamps.
            if newest_datetime is None or item_datetime >= newest_datetime:
                newest_value, newest_datetime = item_inc_value, item_datetime
        if newest_value is None:
            return
        stored_datetime = self.max_inc_value
        if stored_datetime is None or newest_datetime >= stored_datetime:
            self.max_inc_value = newest_value


def _flatten(input_dict: dict) -> dict:
    stack = collections.deque([("", input_dict)])
    output_dict: dict = {}
    while stack:
        key, value = stack.popleft()
        if isinstance(value, dict):
            prefix = f"{key}." if key else ""
            stack.extend((f"{prefix}{k}", v) for k, v in value.items())
        else:
            output_dict[key] = value
    return output_dict
