"""Compute document-type prediction accuracy for a Squirro project and store it in Redis."""

import json

import numpy as np

from octopus.clients import init_redis_client, init_squirro_client


def main() -> None:
    """Entrypoint: score stored document-type predictions and cache the result.

    Scans the project for items matching the query below (items that have a
    ``document_type_true`` keyword, excluding ``source_type`` values matching
    ``WFI*`` and items with ``is_deleted:true``). For each item it pairs the
    ``document_type_true`` keyword values with the ``document_type_pred``
    values. It then computes accuracy statistics with ``compute_accuracy`` and
    stores them as JSON in Redis under the key ``ml_accuracy_<project_id>``.
    """
    sq_client, project_id = init_squirro_client()
    redis_client = init_redis_client()

    labels, preds = [], []
    for item in sq_client.scan(
        project_id,
        query='-source_type:"WFI*" document_type_true:* -is_deleted:true',
        fields=["id", "keywords"],
        count=10000,
    ):
        keywords = item["keywords"]
        truths = keywords["document_type_true"]
        # The query only guarantees a ground-truth label. Items without a
        # prediction would previously raise KeyError and abort the whole run,
        # so skip them. Also skip items whose label/prediction counts differ.
        # Extending both lists independently would otherwise shift every
        # subsequent label/prediction pair out of alignment.
        predicted = keywords.get("document_type_pred")
        if not predicted or len(predicted) != len(truths):
            continue
        labels.extend(truths)
        preds.extend(predicted)

    results = compute_accuracy(labels, preds)
    redis_client.set(f"ml_accuracy_{project_id}", json.dumps(results))


def compute_accuracy(
    ground_truths: "list[str]", predictions: "list[str]"
) -> "dict[str,dict[str,int]]":
    """Compute accuracy, number of correctly predicted and total count.

    Statistics are reported overall under the key ``"ALL"`` and once per
    distinct ground-truth class, in sorted class order. For a class, the
    statistics cover only the items whose ground truth is that class, i.e. the
    per-class "accuracy" is that class's recall.

    Args:
        ground_truths: List of ground truths.
        predictions: List of predictions, aligned element-wise with
            ``ground_truths``.

    Returns:
        Dict mapping ``"ALL"`` and each class name to a dict with
        ``accuracy`` (integer percentage, rounded down), ``correct`` and
        ``total``. For empty input only ``"ALL"`` is present, with all zeros.

    Raises:
        ValueError: If ``ground_truths`` and ``predictions`` differ in length.
    """
    labels = np.asarray(ground_truths)
    preds = np.asarray(predictions)
    # Without this check a length-1 `predictions` would silently broadcast
    # against every label, and other mismatches fail with an obscure error.
    if labels.shape != preds.shape:
        raise ValueError(
            "ground_truths and predictions must have the same length "
            f"(got shapes {labels.shape} and {preds.shape})"
        )

    matches = labels == preds
    results = {"ALL": _accuracy_stats(matches)}
    # np.unique already returns the distinct values in sorted order.
    for c in np.unique(labels):
        results[str(c)] = _accuracy_stats(matches[labels == c])

    return results


def _accuracy_stats(matches: np.ndarray) -> "dict[str, int]":
    """Summarise a boolean array of per-item hits as accuracy/correct/total."""
    correct = int(np.count_nonzero(matches))
    total = int(matches.size)
    # Integer floor division gives the exact floored percentage.
    # `int(mean * 100)` can land just below an integer (0.29 * 100 ==
    # 28.999...). Empty input is reported as 0% instead of int(NaN) raising.
    accuracy = 100 * correct // total if total else 0
    return {
        "accuracy": accuracy,
        "correct": correct,
        "total": total,
    }


if __name__ == "__main__":
    # Run the accuracy job only when executed as a script, not when imported.
    main()
