"""Retry plugin."""

# mypy: ignore-errors
import hashlib
import logging
from typing import TYPE_CHECKING

from octopus.clients.redis_client import init_redis_client
from squirro.dataloader.data_source import DataSource

if TYPE_CHECKING:
    from collections.abc import Generator
    from typing import Any

    from redis import Redis

# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Presumably the location of the octopus main configuration file.
# NOTE(review): not referenced in the visible part of this file - confirm it
# is still needed.
_CONFIG_PATH = "/opt/squirro/octopus/config/main.ini"
# Name of the redis hash whose field names are read as the items to retry.
REDIS_FAILED_ITEMS_HASH = "failed_items_hash"


class RedisRetrySource(DataSource):
    """Data Loader source that lists the entries of the failed-items redis hash.

    Every field name stored in the ``REDIS_FAILED_ITEMS_HASH`` hash is emitted
    as an item with a single ``content_url`` column, so that uploading those
    items can be retried.
    """

    redis_client: "Redis[bytes]"

    def connect(self, _: str | None = None, __: str | None = None) -> None:
        """Create the redis client used by this source.

        Both optional positional arguments are accepted but ignored.
        """
        self.redis_client = init_redis_client()

    def disconnect(self) -> None:
        """Disconnect from the source.

        Currently a no-op: nothing is closed or released here.
        """

    def getDataBatch(
        self, batch_size: int
    ) -> "Generator[list[dict[str, str]], Any, None]":
        """Yield the items to retry in batches.

        All field names of the failed-items hash are fetched from redis in a
        single call and decoded, then handed out in slices of ``batch_size``.

        Args:
            batch_size: Maximum number of items per yielded batch.

        Yields:
            Lists of ``{"content_url": <hash field name>}`` dictionaries.

        Raises:
            Exception: Whatever fetching or decoding the hash fields raised;
                the error is logged and then re-raised unchanged.
        """
        rows: list[dict[str, str]] = []
        try:
            raw_keys: list[bytes] = self.redis_client.hkeys(
                REDIS_FAILED_ITEMS_HASH
            )
            for raw_key in raw_keys:
                rows.append({"content_url": raw_key.decode()})
        except Exception:
            log.exception("Failed to get files to retry from redis")
            raise

        total = len(rows)
        for start in range(0, total, batch_size):
            yield rows[start : start + batch_size]

    def getSchema(self) -> "list[str]":
        """Return the schema of the dataset.

        Returns:
            List of column names; this source only provides ``content_url``.
        """
        columns = ["content_url"]
        return columns

    def getJobId(self) -> str:
        """Return an identifier derived from the redis endpoint.

        The id is the SHA-256 hex digest of ``"<host>:<port>"`` as configured
        on the client's connection pool, so it stays the same for as long as
        the source points at the same redis host and port.

        Returns:
            Hex digest string identifying this job.
        """
        pool_kwargs = self.redis_client.connection_pool.connection_kwargs
        # Plain concatenation (not an f-string) so that a non-string host
        # still fails loudly instead of being silently stringified.
        endpoint = pool_kwargs["host"] + ":" + str(pool_kwargs["port"])

        digest = hashlib.sha256(endpoint.encode("utf-8")).hexdigest()
        log.debug("Job ID: %s", digest)
        return digest
