"""Text preprocessing functions."""

import re

from .stopwords import STOPWORDS

# Matches a "<", any run of characters other than ">", and the closing ">".
HTML_TAGS_PATTERN = re.compile(r"<[^>]*>")
# Matches one or more consecutive whitespace characters.
WHITESPACE_PATTERN = re.compile(r"\s+")


def normalize_whitespace(txt: str) -> str:
    """Collapse whitespace runs and trim the ends of the text.

    Every match of ``WHITESPACE_PATTERN`` is replaced by a single space,
    then leading and trailing whitespace is stripped.

    Args:
        txt: text whose whitespace should be normalized

    Returns:
        the text with whitespace runs collapsed and both ends stripped
    """
    collapsed = re.sub(WHITESPACE_PATTERN, " ", txt)
    return collapsed.strip()


def remove_html_tags(txt: str) -> str:
    """Delete every match of ``HTML_TAGS_PATTERN`` from the text.

    Matched tags are replaced by the empty string; nothing is inserted in
    their place, so text on either side of a tag ends up adjacent.

    Args:
        txt: text that may contain HTML tags

    Returns:
        the text with all matched tags removed
    """
    without_tags = re.sub(HTML_TAGS_PATTERN, "", txt)
    return without_tags


def remove_special_characters(
    txt: str, keep: "list[str] | None" = None, *, replace_with_space: bool = False
) -> str:
    """Strip everything except ASCII letters, digits, and the kept characters.

    Each run of consecutive disallowed characters is replaced as a unit,
    either by nothing or by a single space.

    Args:
        txt: text to clean
        keep: extra characters to leave untouched; a falsy value (None or
            an empty list) keeps nothing extra
        replace_with_space: when True, each removed run becomes one space;
            when False, it is dropped entirely (default: False)

    Returns:
        the text with disallowed characters removed or replaced
    """
    extra_allowed = "".join(keep) if keep else ""
    # Escape the kept characters so "-", "]", "^" and "\" are taken literally
    # inside the character class.
    disallowed_run = re.compile(f"[^a-zA-Z0-9{re.escape(extra_allowed)}]+")
    if replace_with_space:
        replacement = " "
    else:
        replacement = ""
    return disallowed_run.sub(replacement, txt)


def remove_stopwords(txt: str) -> str:
    """Drop whitespace-separated words that appear in ``STOPWORDS``.

    The text is split on whitespace, each word is lower-cased for the
    membership check only, and the surviving words are rejoined with
    single spaces in their original casing.

    Args:
        txt: text to filter

    Returns:
        the remaining words joined by single spaces
    """
    surviving_words = (
        word for word in txt.split() if word.lower() not in STOPWORDS
    )
    return " ".join(surviving_words)


def sanitize_text(
    txt: str,
    *,
    keep_stopwords: bool = False,
    keep_chars: "str | list[str] | None" = None,
) -> str:
    """Lower-case the text and clean it step by step.

    The steps run in this order: lower-casing, HTML tag removal, optional
    stopword removal, optional special-character removal, and finally
    whitespace normalization.

    Args:
        txt: text to sanitize
        keep_stopwords: when True, the stopword removal step is skipped.
            Defaults to False.
        keep_chars: special symbols to keep, given as a single string or a
            list of strings. The string `all` skips special-character
            removal entirely. A falsy value (None, "" or []) falls back
            to [" "].

    Returns:
        sanitized text
    """
    cleaned = remove_html_tags(txt.lower())
    if not keep_stopwords:
        cleaned = remove_stopwords(cleaned)

    allowed = keep_chars or [" "]
    if allowed == "all":
        # Every special symbol is kept, so only whitespace is tidied.
        return normalize_whitespace(cleaned)

    if not isinstance(allowed, list):
        # A bare string is wrapped so the helper always receives a list.
        allowed = [allowed]
    cleaned = remove_special_characters(cleaned, keep=allowed)
    return normalize_whitespace(cleaned)
