"""Create or restore elasticsearch snapshot."""

import logging
import subprocess  # noqa: S404
from datetime import UTC, datetime, timedelta
from typing import TYPE_CHECKING

from octopus.clients import init_es_client
from octopus.utils import set_log_verbosity

if TYPE_CHECKING:
    from argparse import Namespace

    from elasticsearch import Elasticsearch


set_log_verbosity(logging.INFO)


# Timeout for ES operations
_TIMEOUT = 500


def _sq_services(mode: str) -> None:
    """Start or stop squirro services that might block modify the index.

    Args:
        mode: start or stop

    Raises:
        RuntimeError: If the service fails to start
        ValueError: If the service is not found
    """
    sq_services = [
        "sqfrontendd",
        "sqingesterd",
        "sqplumberd",
        "sqtopicd",
    ]

    for service in sq_services:
        logging.warning("%s %s", mode.capitalize(), service)
        if not mode or not service:
            msg = "Mode and service must be provided."
            logging.error(msg)
            raise ValueError(msg)

        command = ["sudo", "systemctl", mode, service]

        if (
            subprocess.run(  # noqa: S603
                command, shell=False, check=True
            ).returncode
            != 0
        ):
            command = ["sudo", "systemctl", "start", service]
            subprocess.run(  # noqa: S603
                command, shell=False, check=True
            )
            msg = f"Error {mode.lower()}ing {service}"
            logging.error(msg)
            raise RuntimeError(msg)


def create_snapshot(es: "Elasticsearch") -> None:
    """Create a snapshot of all indices.

    Args:
        es: Elasticsearch client
    """
    logging.info("Creating ES snapshot")

    today = datetime.now(UTC)
    try:
        es.snapshot.create(
            repository="es-snapshot-repo",
            snapshot=f"{today.strftime('%Y%m%d')}",
            body={"indices": "_all"},
            wait_for_completion=True,
            request_timeout=_TIMEOUT,
        )
    except Exception:
        logging.exception("Error creating snapshots.")
        raise
    logging.info("Snapshot complete")


def restore_snapshot(  # noqa: C901 - TODO: refactor
    es: "Elasticsearch",
    start_date: "str | None" = None,
    end_date: "str | None" = None,
) -> None:
    """Restore a snapshot of all indices.

    Args:
        es: Elasticsearch client
        start_date: Start date of the snapshot
        end_date: End date of the snapshot

    Raises:
        ValueError: If the date range is invalid
        SystemError: If no snapshot is present
    """
    logging.warning("Restoring ES snapshot")

    # Get yesterday's date
    yesterday = datetime.now(UTC) - timedelta(days=1)

    if not start_date:
        start_date = yesterday.strftime("%Y%m%d")
    start = int(
        datetime.strptime(start_date, "%Y%m%d").astimezone(UTC).strftime("%Y%m%d")
    )

    if not end_date:
        end_date = yesterday.strftime("%Y%m%d")
    end = int(datetime.strptime(end_date, "%Y%m%d").astimezone(UTC).strftime("%Y%m%d"))

    # Get all snapshots
    try:
        snapshot_info = es.snapshot.get(repository="es-snapshot-repo", snapshot="_all")
    except Exception:
        logging.exception("Error getting snapshots.")
        raise

    if not (n_snapshots := len(snapshot_info["snapshots"])):
        msg = "There seems to be no snapshot present."
        logging.error(msg)
        raise SystemError(msg)
    logging.info("There are a total of %d snapshots.", n_snapshots)

    # Get the first and last snapshot dates
    first = int(snapshot_info["snapshots"][0]["snapshot"])
    last = int(snapshot_info["snapshots"][-1]["snapshot"])
    if start < first or end > last:
        msg = f"Invalid date range. Valid range is {first}-{last}"
        logging.error(msg)
        raise ValueError(msg)
    logging.warning("Restoring snapshots from %s to %s", start, end)

    for snapshot in snapshot_info["snapshots"]:
        if not snapshot["snapshot"]:
            continue

        # Break if the snapshot is after the end date
        if int(snapshot["snapshot"]) > end:
            break

        # Skip if the snapshot is before the start date
        if int(snapshot["snapshot"]) < start:
            continue

        # Restore all indices from the snapshot
        for index in snapshot["indices"]:
            # Some system indices can only be restored as part of feature state
            if index.startswith("."):
                continue

            logging.warning("Restoring index: %s", index)
            if es.indices.exists(index=index):
                es.indices.close(index=index)

            es.snapshot.restore(
                repository="es-snapshot-repo",
                snapshot=snapshot["snapshot"],
                body={"indices": index},
                wait_for_completion=True,
                request_timeout=_TIMEOUT,
            )
    logging.info("Restore complete")


def main(args: "Namespace") -> None:
    """Create or restore elasticsearch snapshot."""
    logging.info("Connecting to ES Client")
    es = init_es_client()

    if args.mode == "create":
        create_snapshot(es)
    else:
        restore_snapshot(es, args.start_date, args.end_date)
        _sq_services("restart")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--mode",
        type=str,
        choices=["create", "restore"],
        required=True,
        help="Create/Restore elasticsearch snapshot",
    )
    parser.add_argument(
        "--end-date",
        type=str,
        required=False,
        help="End date of snapshot in the form of YYYYMMDD (e.g 20230717).",
    )
    parser.add_argument(
        "--start-date",
        type=str,
        required=False,
        help="Start date of snapshot in the form of YYYYMMDD (e.g 20230717).",
    )
    main(parser.parse_args())
