"""Track items using Redis."""

import json
from typing import TYPE_CHECKING

from octopus.clients import init_redis_client
from squirro.sdk import PipeletV1, require

if TYPE_CHECKING:
    from logging import Logger
    from typing import Any


# Placeholder source type used when an item has no usable `source_type`
# keyword. It does not start with "WFI", so such items are written to the
# regular status hash.
UNKNOWN = "UNKNOWN"


# pylint: disable=too-few-public-methods
@require("log")
class TrackItems(PipeletV1):  # type: ignore[misc]
    """Track items in Redis.

    Add the ids of the items being ingested to a Redis hash. The field is
    the item id and the value is a JSON object holding the item's id,
    title, creation date and keywords. The ids will later be used for
    tracking the ingestion status of the items.

    If items are uploaded (dashboard, email) by users, use
    ``item_status_hash``, which is later used for sending out email
    notifications.

    If items are from WFI (their ``source_type`` keyword starts with
    ``"WFI"``), use ``item_wfi_status_hash``, which is separately used to
    keep track of which items from WFI are successful.

    The ``item_wfi_status_hash`` will be very large as it will contain
    thousands of items from WFI during migration/loading of documents.
    Hence it is kept independent from ``item_status_hash``.
    """

    # Logger requested from the SDK through ``@require("log")``.
    log: "Logger"

    def __init__(self, _: "dict[str, Any]") -> None:
        """Initialize the pipelet and its Redis client.

        Args:
            _: The pipelet configuration. It is not used by this pipelet.
        """
        self.redis_client = init_redis_client()

    def consume(self, item: "dict[str, Any]") -> "dict[str, Any]":
        """Record the item in a Redis hash and pass it through unchanged.

        Args:
            item: The item to consume. ``id`` and ``created_at`` are
                required; ``title`` and ``keywords`` are optional.

        Returns:
            The consumed item, unmodified.

        Raises:
            KeyError: If the item has no ``id`` or no ``created_at``.
            Exception: Any error raised while serialising the metadata or
                writing it to Redis is logged and re-raised.
        """
        item_id = item["id"]
        # Items without keywords (missing or explicitly None) are still
        # tracked instead of failing the pipelet.
        keywords: "dict[str, list[str]]" = item.get("keywords") or {}
        # Fall back to UNKNOWN both when the keyword is absent and when it
        # is an empty list (indexing an empty list would raise IndexError).
        source_type = (keywords.get("source_type") or [UNKNOWN])[0]

        metadata: "dict[str, Any]" = {
            "id": item_id,
            "title": item.get("title", ""),
            "created_at": item["created_at"],
        }
        # Merge the keywords without letting them overwrite the core fields
        # above: a keyword named e.g. "id" or "title" must not replace the
        # item's real id/title in the tracking record.
        for name, values in keywords.items():
            metadata.setdefault(name, values)

        hash_name = (
            "item_wfi_status_hash"
            if source_type.startswith("WFI")
            else "item_status_hash"
        )
        msg = f"Adding item `{item_id}` to Redis hash `{hash_name}`"
        self.log.info(msg)

        try:
            self.redis_client.hset(hash_name, item_id, json.dumps(metadata))
        except Exception:
            # Log with context, then propagate so the ingestion failure is
            # not silently swallowed.
            self.log.exception("%s failed", msg)
            raise

        return item
