"""Redis snapshot."""

import json
import logging
from datetime import UTC, datetime
from pathlib import Path
from typing import TYPE_CHECKING

from octopus.clients import init_redis_client
from octopus.utils import set_log_verbosity

if TYPE_CHECKING:
    from argparse import Namespace

    from redis import Redis


# Runs at import time: ask the project's logging helper for INFO-level output
# so the progress messages emitted by the functions below are visible.
set_log_verbosity(logging.INFO)


def create(client: "Redis[bytes]", db: int, backup_dir: str) -> None:
    """Create a snapshot of the redis database.

    Args:
        client: Redis client
        db: Redis database
        backup_dir: Path to the backup directory
    """
    # Fetch all keys
    keys = client.keys()
    if not keys:
        logging.warning("No keys found.")
        return

    data = {}
    for key_b in keys:
        key = key_b.decode()
        logging.info("Dumping %s", key)

        # Dump does not work on redis client 4.3.4
        value = client.execute_command("dump", key)
        if not value:
            logging.exception("\t> No value found. Skipping...")
            continue
        data[key] = str(value, "latin-1")

    # Write to a json file
    today = datetime.now(UTC).strftime("%Y%m%d")
    path = Path(backup_dir) / f"{today}" / f"db_{db}.json"
    path.parent.mkdir(parents=True, exist_ok=True)
    logging.warning("Writing to %s", path)
    with path.open("w") as f:
        json.dump(data, f)


def restore(
    client: "Redis[bytes]",
    db: int,
    backup_dir: str,
    *,
    exclude_extauth_saml: bool = True,
) -> None:
    """Restore a snapshot of the redis database.

    Args:
        client: Redis client
        db: Redis database
        backup_dir: Path to the backup directory
        exclude_extauth_saml: Exclude extauth_saml keys
    """
    today = datetime.now(UTC).strftime("%Y%m%d")
    path = Path(backup_dir) / f"{today}" / f"db_{db}.json"
    if not path.exists():
        logging.exception("Snapshot %s does not exist.", path)
        return

    with Path(path).open(encoding="utf-8") as f:
        data = json.load(f)

    for key, value in data.items():
        if exclude_extauth_saml and "extauth_saml" in key:
            continue
        logging.warning("Restoring %s...", key)

        # Restore does not work on redis client 4.3.4
        client.execute_command("restore", key, 0, value.encode("latin-1"), "REPLACE")


def main(args: "Namespace") -> None:
    """Entrypoint.

    Selects the snapshot operation from ``args.mode`` and applies it to
    databases 0 through 15, opening a fresh client for each one.

    Args:
        args: Parsed arguments; ``mode`` ("create" or "restore") and
            ``backup_dir`` are read.

    Raises:
        ValueError: If ``args.mode`` is neither "create" nor "restore".
    """
    if args.mode == "create":
        logging.info("Creating snapshot")
        snapshot = create
    elif args.mode == "restore":
        logging.info("Restoring snapshot")
        snapshot = restore
    else:
        # Fail early with a clear message instead of hitting an unbound
        # ``snapshot`` further down when called with an unexpected mode.
        msg = f"Unsupported mode: {args.mode!r}"
        raise ValueError(msg)

    # NOTE(review): 16 is presumably the number of databases configured on the
    # server (the stock Redis default) — confirm against the deployment.
    for db in range(16):
        logging.warning("Connecting to db %d", db)
        client = init_redis_client(db=db)
        snapshot(client, db, args.backup_dir)

    logging.info("Done")


if __name__ == "__main__":
    import argparse

    # Command-line interface: where snapshots live and which operation to run.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--backup-dir",
        default="/flash/redis-snapshots",
        type=str,
    )
    cli.add_argument(
        "--mode",
        required=True,
        choices=["create", "restore"],
        type=str,
        help="Create/Restore Redis snapshot",
    )

    parsed_args = cli.parse_args()
    main(parsed_args)
