"""AutoML Classifier."""

import pandas as pd
import requests
from autogluon.tabular import TabularDataset

from squirro.lib.nlp.steps.classifiers.base import Classifier


class AutoMLClassifier(Classifier):
    """The AutoML :class:`Classifier` delegates inference to a remote
    `autogluon tabular <https://auto.gluon.ai>`_ model served over HTTP.

    Documents are converted into a tabular batch (``extract``, ``label``
    and ``document_id`` columns) which is POSTed, together with
    ``model_id``, to :attr:`inference_url`.

    **Input** - all input fields need to be of type :class:`str`.

    **Output** - a single-entry :class:`dict` mapping the predicted label
    to its score, e.g. ``{"label_a": 0.93}``.

    Parameters:
        type (str): `automl_remote`
        model_id (str): model folder name

    **Example**

    .. code-block:: json

        {
            "step": "classifier",
            "type": "automl_remote",
            "model_id": "my_model",
            "input_fields": ["extract"],
            "output_field": "prediction",
            "label_field": "label",
            "hyperparameters": {"FASTTEXT": {"epoch": 25}},
            "fit_parameters": {"time_limit": 30, "num_bag_folds": 5}
        }
    """

    #: Endpoint of the inference service. Kept as a class attribute so a
    #: subclass or an instance can point the step at another host/port.
    inference_url = "http://localhost:25565/inference"

    #: Seconds to wait for the inference service. Without a timeout,
    #: ``requests`` may block forever on an unresponsive server.
    request_timeout = 300

    def load(self):
        """No-op: there is no local model to load.

        Inference is delegated to the remote service in
        :meth:`process_batch`.
        """

    def process_batch(self, batch):
        """Classify ``batch`` through the remote inference service.

        Skipped documents are removed first (via ``filter_skipped``). Each
        remaining document gets ``{predicted_label: score}`` written to
        ``output_field``.

        Returns:
            The filtered batch of documents.

        Raises:
            requests.HTTPError: if the service answers with an error status.
            requests.Timeout: if the service does not answer within
                :attr:`request_timeout` seconds.
            ValueError: if the number of returned rows differs from the
                number of documents sent.
        """
        batch = self.filter_skipped(batch)
        if not batch:
            # Nothing to classify; avoid a pointless round-trip.
            return batch

        tabular_batch = self._doc_to_tabular(batch)
        response = requests.post(
            self.inference_url,
            # A DataFrame is not JSON-serializable, so send plain row dicts.
            # NOTE(review): assumes the service rebuilds the table from a
            # list of records (e.g. ``pd.DataFrame(data)``) - confirm.
            json={"data": tabular_batch.to_dict(orient="records")},
            params={"model_id": self.model_id},
            timeout=self.request_timeout,
        )
        # Surface server errors instead of parsing an error body as scores.
        response.raise_for_status()

        # Treated as one row per document (in request order) and one column
        # per class label; the highest-scoring column is the prediction.
        y_pred = pd.DataFrame(response.json())
        if len(y_pred) != len(batch):
            # zip() would otherwise silently leave documents unlabeled.
            raise ValueError(
                f"Inference service returned {len(y_pred)} predictions "
                f"for {len(batch)} documents"
            )
        for doc, prediction, proba in zip(
            batch, y_pred.idxmax(axis=1), y_pred.max(axis=1)
        ):
            doc.fields[self.output_field] = {prediction: proba}
        return batch

    def _doc_to_tabular(self, docs, train: bool = False):
        """Convert documents into a ``TabularDataset``.

        The columns are ``extract`` (from ``input_field``), ``label`` (from
        ``label_field``) and ``document_id`` (``doc.id``).

        Args:
            docs: documents to convert.
            train: if true, keep only documents that have a label; a
                list-valued label contributes its first element. If false,
                keep every document with ``label=None``.
        """
        rows = []
        for doc in docs:
            extract = doc.fields.get(self.input_field)
            if not train:
                # Inference: keep every document so rows align with `docs`.
                rows.append((extract, None, doc.id))
                continue

            label = doc.fields.get(self.label_field)
            if isinstance(label, list):
                # Use the first value; an empty list counts as unlabeled
                # rather than raising IndexError.
                label = label[0] if label else None
            if label is not None:
                # Only labelled documents are usable for training.
                rows.append((extract, label, doc.id))

        return TabularDataset(rows, columns=["extract", "label", "document_id"])

    def train(self, docs):
        """No-op: this step does not train a model."""
