import io
import json
import zipfile
from unittest.mock import patch

import pytest

from .plugin import (
    get_all_labels_to_update,
    get_wfi_metadata,
    plugin,
    update_document_status,
)


class TestGetSquirroLabelsAndWfiMetadata:
    """Tests for ``get_all_labels_to_update`` and ``get_wfi_metadata``.

    Each test changes one label, checks the full set of labels that
    ``get_all_labels_to_update`` derives from it against the
    ``expected_labels_values`` fixture, and then checks the WFI metadata that
    ``get_wfi_metadata`` extracts from those labels.
    """

    @staticmethod
    def _pop_permission_codes(actual, expected):
        """Remove ``permission_code`` from both dicts and return the values as sets.

        Permission codes are compared order-insensitively because their order
        changes after being turned into a list.
        """
        actual_codes = set(actual.pop("permission_code"))
        expected_codes = set(expected.pop("permission_code"))
        return actual_codes, expected_codes

    def test_update_document_type(
        self, expected_labels_values, mock_redis_client, document_types_mapping
    ):
        mock_redis_client.hget.return_value = json.dumps(document_types_mapping)
        expected = expected_labels_values["document_type"]

        updated = get_all_labels_to_update({"document_type": ["Document Type"]})

        assert updated == combine_two_dicts(expected["squirro"], expected["wfi"])
        assert get_wfi_metadata(updated) == dict_first_val(expected["wfi"])

    def test_update_company_name_no_wfi_update_required(self, expected_labels_values):
        expected_squirro = expected_labels_values["company_name"]["squirro"]
        changed = {"company_name": ["Company A", "Company B"]}

        updated = get_all_labels_to_update(changed)

        actual_codes, expected_codes = self._pop_permission_codes(
            updated, expected_squirro
        )
        assert actual_codes == expected_codes
        assert updated == expected_squirro
        # Without a ``wfi_company_name`` change there is no WFI metadata.
        assert get_wfi_metadata(updated) == {}

    def test_update_company_name_wfi_update_required(self, expected_labels_values):
        expected = expected_labels_values["company_name"]
        changed = {
            "company_name": ["Company A", "Company B"],
            "wfi_company_name": ["Company A"],
        }

        updated = get_all_labels_to_update(changed)

        actual_codes, expected_codes = self._pop_permission_codes(
            updated, expected["squirro"]
        )
        assert actual_codes == expected_codes
        assert updated == combine_two_dicts(expected["squirro"], expected["wfi"])
        assert get_wfi_metadata(updated) == dict_first_val(expected["wfi"])

    def test_update_document_date(self, expected_labels_values):
        expected = expected_labels_values["document_date"]

        updated = get_all_labels_to_update({"document_date": ["2022-12-12"]})

        assert updated == combine_two_dicts(expected["squirro"], expected["wfi"])
        assert get_wfi_metadata(updated) == dict_first_val(expected["wfi"])

    def test_update_references(self, expected_labels_values):
        expected = expected_labels_values["references"]

        updated = get_all_labels_to_update({"references": ["Ref 1", "Ref 2"]})

        assert updated == combine_two_dicts(expected["squirro"], expected["wfi"])
        assert get_wfi_metadata(updated) == dict_first_val(expected["wfi"])

    def test_update_notes(self, expected_labels_values):
        updated = get_all_labels_to_update({"notes": ["Note 1", "Note 2"]})

        # Notes produce no WFI metadata.
        assert updated == expected_labels_values["notes"]["squirro"]
        assert get_wfi_metadata(updated) == {}

    def test_update_cso_name(self, expected_labels_values):
        updated = get_all_labels_to_update({"cso_name": ["John Doe"]})

        # The CSO name produces no WFI metadata.
        assert updated == expected_labels_values["cso_name"]["squirro"]
        assert get_wfi_metadata(updated) == {}

    def test_update_rm_name(self, expected_labels_values):
        updated = get_all_labels_to_update({"rm_name": ["Jane Doe"]})

        # The RM name produces no WFI metadata.
        assert updated == expected_labels_values["rm_name"]["squirro"]
        assert get_wfi_metadata(updated) == {}


class TestUpdateDocumentStatus:
    """Tests for ``update_document_status``."""

    def test_no_update_document_status(self, mock_requests_put):
        # Changing only ``references`` must not send a status-tracking request.
        changed = {"references": ["Ref 1"]}

        update_document_status("project-id", "document-id", changed)

        mock_requests_put.assert_not_called()

    def test_update_document_status(self, monkeypatch, mock_requests_put):
        class FakeRequest:
            # Stand-in for the module-level ``request`` object in the plugin;
            # only ``headers`` is provided.
            headers = None

        monkeypatch.setattr("documents.plugin.request", FakeRequest)
        changed = {
            "document_type": ["Document Type"],
            "wfi_company_name": ["Company A"],
            "document_date": ["2022-12-12"],
        }

        update_document_status("project-id", "document-id", changed)

        expected_url = (
            "https://example.com/studio/document_status_tracking"
            "/projects/project-id/documents/document-id"
        )
        # The PUT body carries the first value of each changed label, with
        # ``wfi_company_name`` sent as ``company_name``.
        expected_body = {
            "document_type": "Document Type",
            "company_name": "Company A",
            "document_date": "2022-12-12",
        }
        mock_requests_put.assert_called_with(
            expected_url,
            json=expected_body,
            headers=None,
            timeout=10,
        )


@patch("documents.plugin.get_injected")
class TestRequestsUploadDocument:
    """Request-level tests for uploading documents via ``POST /``.

    ``get_injected`` is patched for every test method in this class; each
    method receives that mock as its first argument (``_``) and ignores it.
    """

    def test_upload_no_document(self, _, test_client) -> None:
        # A POST with no form data at all is rejected.
        response = test_client.post("/")
        assert response.status_code == 400

    def test_upload_single_pdf_with_labels(self, _, test_client) -> None:
        # NOTE(review): this path is relative to the directory pytest runs from.
        with open("tests/octopus/pdf/test-data/force-ocr.pdf", "rb") as pdf_file:
            pdf_stream = io.BytesIO(pdf_file.read())
        form = {
            "project_id": "project-id",
            "labels": json.dumps(
                {
                    "document_type_true": ["ANNUAL REPORT"],
                    "company_name_true": ["Frasers"],
                }
            ),
            "documents": (pdf_stream, "test-file.pdf"),
        }

        response = test_client.post("/", data=form)

        assert response.status_code == 201

    def test_upload_multiple_files(self, _, test_client) -> None:
        file_count = 5
        uploads = [
            (io.BytesIO(b"test file"), f"test-file-{index}.pdf")
            for index in range(file_count)
        ]

        response = test_client.post(
            "/", data={"project_id": "project-id", "documents": uploads}
        )

        assert response.status_code == 201

    def test_upload_zip(self, _, test_client) -> None:
        archive = io.BytesIO()
        with zipfile.ZipFile(archive, "w") as zip_file:
            zip_file.writestr("test-file.txt", b"test file")
        # Rewind so the client reads the archive from its first byte.
        archive.seek(0)

        response = test_client.post(
            "/",
            data={
                "project_id": "project-id",
                "documents": (archive, "test-file.zip"),
            },
        )

        assert response.status_code == 201


class TestRequestsModifySquirroLabels:
    """Request-level tests for ``PATCH /<item id>`` (modifying item labels)."""

    def test_no_body(self, test_client):
        response = test_client.patch("/item-id")
        assert response.status_code == 400

    def test_non_json_body(self, test_client):
        response = test_client.patch("/item-id", data="123")
        assert response.status_code == 400

    def test_invalid_body(self, test_client):
        # Valid JSON, but not the request shape the endpoint accepts.
        response = test_client.patch("/item-id", json={"test_item": "123"})
        assert response.status_code == 400

    @patch("documents.plugin.get_injected")
    def test_valid_request(self, _get_injected, test_client):
        payload = {
            "project_id": "project-id",
            "labels": {"document_type": ["Report"]},
        }
        response = test_client.patch("/item-id", json=payload)
        assert response.status_code == 204


class TestRequestsUpdateDocument:
    """Request-level tests for ``PUT /<item id>`` (updating a document)."""

    def test_no_body(self, test_client):
        response = test_client.put("/item-id")
        assert response.status_code == 400

    def test_non_json_body(self, test_client):
        response = test_client.put("/item-id", data="123")
        assert response.status_code == 400

    def test_invalid_body(self, test_client):
        # Valid JSON, but not the request shape the endpoint accepts.
        response = test_client.put("/item-id", json={"test_item": "123"})
        assert response.status_code == 400

    # ``patch`` decorators apply bottom-up, so the innermost one
    # (``update_document_status``) supplies the first mock argument.
    @patch("documents.plugin.update_squirro_item")
    @patch("documents.plugin.add_wfi_payload_to_redis")
    @patch("documents.plugin.update_document_status")
    def test_valid_request(
        self, _update_status, _add_to_redis, _update_item, test_client
    ):
        body = {
            "project_id": "project-id",
            "wfi_document_id": "wfi-id",
            "labels": {"references": ["Ref 1"]},
        }
        response = test_client.put("/item-id", json=body)
        assert response.status_code == 200


class TestRequestsBulkAssign:
    """Request-level tests for ``PATCH /bulk`` (assigning labels to many items)."""

    @patch("documents.plugin.get_injected")
    def test_invalid_body(self, _get_injected, test_client):
        # Send the malformed body to the bulk endpoint itself; the single-item
        # route is already covered by TestRequestsModifySquirroLabels.
        # ``get_injected`` is patched so the test never reaches a real client.
        res = test_client.patch("/bulk", json={"test_item": "123"})
        assert res.status_code == 400

    @patch("documents.plugin.get_injected")
    def test_valid_request(self, mock_get_injected, test_client):
        # Every call to the patched ``get_injected`` returns this same object,
        # so assertions on it see the plugin's ``modify_items`` call without
        # the test itself recording an extra call on the mock.
        mock_client = mock_get_injected.return_value

        res = test_client.patch(
            "/bulk",
            json={
                "project_id": "project-id",
                "labels": {"cso_name": ["John Doe"]},
                "ids": ["item-0", "item-1", "item-2"],
            },
        )

        # One batched call that sets the same keywords on every requested id.
        mock_client.modify_items.assert_called_once_with(
            "project-id",
            items=[
                {"id": f"item-{i}", "keywords": {"cso_name": ["John Doe"]}}
                for i in range(3)
            ],
        )
        assert res.status_code == 204


@pytest.fixture
def test_client():
    """Yield a test client for ``plugin``; its context is exited after the test."""
    client_context = plugin.test_client()
    with client_context as client:
        yield client


@pytest.fixture
def expected_labels_values():
    """Expected labels derived from each single changed label.

    Keys are the changed label names. Each entry has a ``"squirro"`` dict and,
    for labels that carry WFI metadata, a ``"wfi"`` dict. Tests compare
    ``get_all_labels_to_update`` output against ``"squirro"`` (merged with
    ``"wfi"`` when WFI fields are expected), and compare ``get_wfi_metadata``
    output against the first value of each ``"wfi"`` list.

    Some tests ``pop`` keys from this data, so it relies on pytest's default
    function scope to give every test a fresh copy.
    """
    return {
        "document_type": {
            "squirro": {
                "document_type": ["Document Type"],
                "document_type_true": ["Document Type"],
                "document_category": ["Document Category"],
            },
            "wfi": {
                "wfi_document_name": ["WFI Document Name"],
                "wfi_document_type": ["WFI Document Type"],
                "wfi_document_category": ["WFI Document Category"],
            },
        },
        "company_name": {
            "squirro": {
                "company_cif": ["22222", "11111"],
                "company_name": ["Company A", "Company B"],
                "company_name_true": ["Company A", "Company B"],
                "company_uid": ["SG_22222", "SG_11111"],
                # Compared as a set by the tests: the order is not stable.
                "permission_code": ["B1", "C1", "A1"],
                "rm_name": ["Company A RM"],
                # Format: "<company_uid>___<permission code>".
                "uid_permission_code": [
                    "SG_22222___B1",
                    "SG_22222___C1",
                    "SG_11111___A1",
                    "SG_11111___B1",
                ],
            },
            # Only expected when ``wfi_company_name`` is among the changed labels.
            "wfi": {
                "wfi_company_name": ["Company A"],
                "wfi_company_cif": ["22222"],
                "wfi_company_rm_code": ["R011"],
                "wfi_company_team_code": ["B1"],
                "wfi_company_team_name": ["Team A"],
                "wfi_company_segment": ["R"],
            },
        },
        "document_date": {
            "squirro": {
                "document_date": ["2022-12-12"],
                "document_date_true": ["2022-12-12"],
            },
            "wfi": {
                "wfi_document_date": ["2022-12-12"],
            },
        },
        "references": {
            "squirro": {
                "references": ["Ref 1", "Ref 2"],
            },
            # The WFI value is the references joined with ";".
            "wfi": {
                "wfi_references": ["Ref 1;Ref 2"],
            },
        },
        # The remaining labels produce no WFI metadata.
        "cso_name": {
            "squirro": {
                "cso_name": ["John Doe"],
            }
        },
        "rm_name": {
            "squirro": {
                "rm_name": ["Jane Doe"],
            }
        },
        "notes": {
            "squirro": {
                "notes": ["Note 1", "Note 2"],
            }
        },
    }


@pytest.fixture
def document_types_mapping():
    """Document-type metadata that the tests serve from the mocked Redis client.

    ``test_update_document_type`` JSON-encodes this dict and uses it as the
    return value of ``mock_redis_client.hget``.
    """
    return dict(
        document_category="Document Category",
        wfi_document_name="WFI Document Name",
        wfi_document_type="WFI Document Type",
        wfi_document_category="WFI Document Category",
    )


def combine_two_dicts(dict1, dict2):
    """Return a new dict holding *dict1*'s items overlaid with *dict2*'s.

    Neither argument is modified. On duplicate keys the value from *dict2*
    wins, and key order follows *dict1* first, then new keys from *dict2*.
    """
    merged = dict(dict1)
    merged.update(dict2)
    return merged


def dict_first_val(dictionary):
    """Map every key of *dictionary* to the first element of its value.

    Values must be indexable; an empty list value raises ``IndexError``.
    Key order is preserved.
    """
    first_values = {}
    for key, values in dictionary.items():
        first_values[key] = values[0]
    return first_values
