import pytest

from octopus.text import (
    normalize_whitespace,
    remove_html_tags,
    remove_special_characters,
    remove_stopwords,
    sanitize_text,
)


@pytest.fixture(
    params=[
        # Single spaces only: expected to come back unchanged.
        ("Hello World", "Hello World"),
        # A run of spaces is expected to collapse to one space.
        ("Hello     World", "Hello World"),
        # Mixed space/newline/tab run is expected to collapse to one space.
        ("Hello \n\t World", "Hello World"),
        # Empty input is expected to stay empty.
        ("", ""),
    ]
)
def normalize_whitespace_cases(request):
    """Provide one ``(text, expected)`` pair per parametrized run."""
    case = request.param
    return case


def test_normalize_whitespace(normalize_whitespace_cases):
    """Check ``normalize_whitespace`` against each parametrized case."""
    raw, wanted = normalize_whitespace_cases
    actual = normalize_whitespace(raw)
    assert actual == wanted


@pytest.fixture(
    params=[
        # A single wrapping tag pair: only the inner text is expected.
        ("<p>Hello World</p>", "Hello World"),
        # Adjacent paragraphs: the expected text is joined with no separator.
        ("<p>Hello World</p><p>How are you?</p>", "Hello WorldHow are you?"),
        # Empty input is expected to stay empty.
        ("", ""),
    ]
)
def remove_html_tags_cases(request):
    """Provide one ``(html, expected)`` pair per parametrized run."""
    case = request.param
    return case


def test_remove_html_tags(remove_html_tags_cases):
    """Check ``remove_html_tags`` against each parametrized case."""
    markup, wanted = remove_html_tags_cases
    actual = remove_html_tags(markup)
    assert actual == wanted


@pytest.fixture(
    params=[
        # Letters only, no keep-list: expected to come back unchanged.
        ("HelloWorld", "HelloWorld", None),
        # "@" is expected to be dropped when no keep-list is given.
        ("Hello@World", "HelloWorld", None),
        # Input made only of symbols is expected to reduce to "".
        ("@!$%^&*", "", None),
        # Empty input is expected to stay empty.
        ("", "", None),
        # With "@" in the keep-list, the "@" is expected to survive.
        ("Hello@World", "Hello@World", ["@"]),
    ]
)
def remove_special_characters_cases(request):
    """Provide one ``(text, expected, keep)`` triple per parametrized run."""
    case = request.param
    return case


def test_remove_special_characters(remove_special_characters_cases):
    """Check ``remove_special_characters`` against each parametrized case.

    The third element of each case is passed positionally as the second
    argument of ``remove_special_characters``.
    """
    raw, wanted, kept = remove_special_characters_cases
    actual = remove_special_characters(raw, kept)
    assert actual == wanted


@pytest.fixture(
    params=[
        # Neither word is expected to be dropped.
        ("Hello World", "Hello World"),
        # "The" and "is" are expected to be dropped; the rest is kept in order.
        ("The world is beautiful", "world beautiful"),
        # Every word is expected to be dropped, leaving "".
        ("The is in", ""),
        # Empty input is expected to stay empty.
        ("", ""),
    ]
)
def remove_stopwords_cases(request):
    """Provide one ``(text, expected)`` pair per parametrized run."""
    case = request.param
    return case


def test_remove_stopwords(remove_stopwords_cases):
    """Check ``remove_stopwords`` against each parametrized case."""
    raw, wanted = remove_stopwords_cases
    actual = remove_stopwords(raw)
    assert actual == wanted


@pytest.fixture(
    params=[
        # Expected output has no tags or punctuation and is lower-cased.
        ("<p>Hello, World!</p>", "hello world"),
        # Expected output also has single spaces and omits "How" and "are".
        ("<p>Hello     World</p><p> How are you?!!!</p>", "hello world you"),
        # Empty input is expected to stay empty.
        ("", ""),
    ]
)
def sanitize_text_cases(request):
    """Provide one ``(text, expected)`` pair per parametrized run."""
    case = request.param
    return case


def test_sanitize_text(sanitize_text_cases):
    """Check ``sanitize_text`` with default options against each case."""
    raw, wanted = sanitize_text_cases
    actual = sanitize_text(raw)
    assert actual == wanted


def test_sanitize_text_keep_stopwords():
    """With ``keep_stopwords=True`` the leading "The" must remain, lower-cased."""
    actual = sanitize_text("The Hello World", keep_stopwords=True)
    assert actual == "the hello world"


def test_sanitize_text_keep_characters():
    """With ``keep_chars=[" ", "!"]`` the "!" must remain and the "," must not."""
    actual = sanitize_text("Hello, World!", keep_chars=[" ", "!"])
    assert actual == "hello world!"


def test_sanitize_text_keep_all_characters():
    """With ``keep_chars="all"`` the punctuation must remain; only case changes."""
    actual = sanitize_text("Hello, World!", keep_chars="all")
    assert actual == "hello, world!"
