# mypy: ignore-errors
# Plugin identity. Both values are hashed into the job id (``getJobId``) and
# are sent in the API client's ``User-Agent`` header.
__plugin_name__ = "exchange_plugin"
__version__ = "0.1"

import base64
import collections
import enum
import hashlib
import json
import logging
import mimetypes
import pathlib
import pickle  # nosec B403
import sys
from datetime import datetime, timedelta
from typing import Any, Dict, Generator, List, Optional, Set, Tuple
from urllib.parse import quote, urljoin, urlparse

import requests
from dateutil import parser
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry  # pylint: disable=import-error

from squirro.common.dependency import get_injected
from squirro.dataloader.data_source import DataSource

# Module-level logger used by the data source and the API client.
log = logging.getLogger(__name__)

# A single item, keyed by field name.
ITEM_ENTRY = Dict[str, Any]  # pylint: disable=invalid-name
# Generator yielding items one at a time.
ITEMS_GENERATOR = Generator[ITEM_ENTRY, None, None]  # pylint: disable=invalid-name

# Preview items are cached under this prefix followed by the job id and
# expire after the given number of seconds.
CACHE_PREVIEW_PREFIX = "preview_items_"
CACHE_PREVIEW_EXPIRES_SECONDS = 5 * 60


class ExchangeSource(DataSource):  # pylint: disable=abstract-method
    """Exchange data source."""

    @enum.unique
    class StatKeys(enum.Enum):
        """Statistics keys.

        Each member's value is the label written to the log by
        ``_stats_log``.
        """

        # Number of items returned in released batches.
        ITEMS_BATCHED = "Items batched"
        # Not incremented in the code visible here (message ingestion is
        # commented out in ``_extract_message``).
        MESSAGES = "Messages downloaded"
        # Delta entries flagged with ``@removed``.
        MESSAGES_SKIP = "Messages skipped"
        # Attachments that were not filtered out by ``_skip_attachment``.
        ATTACHMENTS = "Attachments downloaded"
        # Attachments rejected by the size or media-type filter.
        ATTACHMENTS_SKIP = "Attachments skipped"

    # Counter per statistics key; (re)initialised by ``_stats_reset``.
    stats: Dict[StatKeys, int]

    def __init__(self, *args, **kwargs):
        """Create the data source with zeroed statistics and no client.

        All arguments are passed through unchanged to the parent class.
        """
        super().__init__(*args, **kwargs)
        # Counters start at zero; the remaining attributes are placeholders
        # that ``connect`` and ``getDataBatch`` fill in later.
        self._stats_reset()
        self.next_links: Dict[str, Any] = {}
        self.preview: bool = False
        self.client: Optional[ExchangeClient] = None

    @property
    def arg_access_id(self) -> str:
        """Return the token value used to identify this source.

        The refresh token is preferred; the access token is returned when
        the refresh token is missing or empty.
        """
        refresh_token = self.args.token.get("refresh_token")
        if refresh_token:
            return refresh_token
        return self.args.token.get("access_token")

    @property
    def arg_file_size_limit(self) -> int:
        """Return the per-file size limit in bytes (fallback: 50 MB)."""
        configured_megabytes = self.args.file_size_limit
        return self._get_size_limit(configured_megabytes, 50)

    @property
    def arg_batch_size_limit(self) -> int:
        """Return the batch size limit in bytes (fallback: 50 MB).

        NOTE(review): ``getArguments`` advertises a default of 1 (MB) for
        ``batch_size_limit`` while the fallback used here is 50 - confirm
        which value is intended.
        """
        configured_megabytes = self.args.batch_size_limit
        return self._get_size_limit(configured_megabytes, 50)

    @staticmethod
    def _get_size_limit(size: Optional[int], default: int) -> int:
        """Convert a limit expressed in megabytes into bytes.

        Args:
            size: Configured limit in megabytes; ``None`` or ``0`` selects
                ``default``.
            default: Fallback limit in megabytes.

        Returns:
            The effective limit in bytes.
        """
        megabytes = size if size else default
        return megabytes * 1024 * 1024

    @property
    def arg_download_spam_and_trash(self) -> bool:
        """Return the ``download_spam_and_trash`` argument, ``False`` if unset."""
        configured = self.args.download_spam_and_trash
        return configured if configured else False

    @property
    def arg_download_media_files(self) -> bool:
        """Return the ``download_media_files`` argument, ``False`` if unset."""
        configured = self.args.download_media_files
        return configured if configured else False

    def connect(self, inc_column: Optional[str] = None, max_inc_value=None):
        """Reset run state and build the Exchange API client.

        When ``args.reset`` is set, the key-value cache and the key-value
        store are cleared first. An access token saved by a previous run is
        preferred over the one supplied in ``args.token``.

        Args:
            inc_column: Not used by this method.
            max_inc_value: Not used by this method.
        """
        self._stats_reset()
        if self.args.reset:
            log.debug("Reset cache and store")
            self.key_value_cache.clear()
            self.key_value_store.clear()

        token = self.args.token
        log.debug("Token keys: %r", token.keys())

        access_token, expires_at = self._load_access_token()
        if not access_token:
            log.debug("Using access_token supplied in args")
            access_token = token.get("access_token")
            expires_at = token.get("expires_at")
        else:
            log.debug("Using access_token from previous run")

        settings = get_injected("config")
        client_options = {
            "client_id": settings.get("dataloader", "exchange_client_id"),
            "client_secret": settings.get("dataloader", "exchange_client_secret"),
            "access_token": access_token,
            # The expiry is handled as a POSIX timestamp here; the client
            # compares it against ``datetime.utcnow()``, hence the naive UTC
            # datetime (or ``None`` when no expiry is known).
            "access_token_expiration": (
                datetime.utcfromtimestamp(int(expires_at)) if expires_at else None
            ),
            "refresh_token": token.get("refresh_token"),
            "scope": token.get("scope"),
        }
        self.client = ExchangeClient(**client_options)

        log.info("Parameters:")
        log.info("File size limit: %d", self.arg_file_size_limit)
        log.info("Batch size limit: %d", self.arg_batch_size_limit)
        log.info("Download spam and trash: %r", self.arg_download_spam_and_trash)
        log.info("Download media files: %r", self.arg_download_media_files)

    def disconnect(self):
        """Persist the current access token, close the client and log stats.

        The token expiry is stored as a POSIX timestamp. The client keeps the
        expiry as a *naive UTC* datetime (it is built with
        ``utcfromtimestamp`` / ``utcnow``), so it is converted relative to the
        UTC epoch; ``datetime.timestamp()`` would interpret a naive value as
        local time and shift the stored value by the host's UTC offset.

        The client session is closed even if saving the token fails.
        """
        log.info("Disconnecting; cleaning up session & dumping state")
        client = self.client
        if client:
            try:
                token_expiration = client.access_token_expiration
                if token_expiration is None:
                    expires_at = None
                elif token_expiration.tzinfo is None:
                    # Naive datetime holding UTC: measure from the UTC epoch.
                    expires_at = (
                        token_expiration - datetime(1970, 1, 1)
                    ).total_seconds()
                else:
                    # Aware datetime: ``timestamp()`` is already correct.
                    expires_at = token_expiration.timestamp()
                self._save_access_token(client.access_token, expires_at)
            finally:
                client.close()
                self.client = None
        self._stats_log()

    # pylint: disable-next=invalid-name
    def getDataBatch(  # noqa: N802
        self, batch_size: int, *_, **kwargs
    ) -> Generator[List[ITEM_ENTRY], None, None]:
        """Yield batches of attachment items from the user's mail folders.

        Preview behaviour is enabled by ``self.preview_mode`` or a truthy
        ``get_schema`` keyword argument. In preview, cached preview items (if
        any) are yielded first, stored delta links are ignored, nested
        folders are not listed and released batches are cached instead of
        persisting state. Generation continues with live data after the
        cached batch if the consumer keeps iterating.

        A batch is released when it holds ``batch_size`` items, or earlier
        when ``_release_batch_earlier`` reports that the size limit is
        exceeded.

        Args:
            batch_size: Maximum number of items per batch.
            **kwargs: Only ``get_schema`` is read.

        Yields:
            Lists of item dictionaries.
        """
        self.preview = bool(self.preview_mode or kwargs.get("get_schema"))

        if self.preview:
            items = self._get_cached_items()
            if items:
                yield items

        me_folders = self.client.list_me_folders(
            self.arg_download_spam_and_trash, not self.preview
        )

        next_links: Dict = {}
        if not self.preview:
            next_links = self.key_value_store.get("next_links") or next_links
        self.next_links = {  # filter out old folders & populate with None for new ones
            folder_id: next_links.get(folder_id) for folder_id in me_folders
        }
        log.info(
            "Using %d cached next_links for iterative update",
            sum(bool(v) for v in self.next_links.values()),
        )

        batch: List = []
        for message in self._message_entries():
            if message.get("@removed"):
                self.stats[self.StatKeys.MESSAGES_SKIP] += 1
                continue
            for item in self._extract_message(message):
                if self._release_batch_earlier(batch):
                    yield self._batch_it(batch)
                    batch = []
                batch.append(item)
                if len(batch) >= batch_size:
                    yield self._batch_it(batch)
                    batch = []
        if batch:
            yield self._batch_it(batch)
        if self.preview:
            self._stats_reset()
        else:
            # ``_batch_it`` persists the links only when a batch is released.
            # Links updated after the last released batch (e.g. trailing
            # pages without attachments) would otherwise be lost and the
            # same pages fetched again on the next run.
            self._save_state()

    # pylint: disable-next=invalid-name
    def getJobId(self) -> str:  # noqa: N802
        """Return a SHA-256 hex digest identifying this job configuration.

        The digest covers the plugin name and version, the access id and the
        effective values of the size-limit and download arguments.
        """
        identity = (
            __plugin_name__,
            __version__,
            self.arg_access_id,
            self.arg_file_size_limit,
            self.arg_batch_size_limit,
            self.arg_download_spam_and_trash,
            self.arg_download_media_files,
        )
        # Hashing the concatenated reprs is equivalent to feeding them to the
        # hash one after another.
        digest = hashlib.sha256("".join(map(repr, identity)).encode())
        return digest.hexdigest()

    # pylint: disable-next=invalid-name
    def getSchema(self) -> List[str]:  # noqa: N802
        """Return the sorted field names known to this source.

        The names are the mapped fields from the mappings file plus every key
        found in the first batch produced in schema mode. Statistics are
        reset afterwards.
        """
        log.info("Getting Schema")
        schema_fields = load_mapped_fields()
        first_batch = next(self.getDataBatch(10, get_schema=True), ())
        schema_fields.update(*(entry.keys() for entry in first_batch))
        self._stats_reset()
        return sorted(schema_fields)

    # pylint: disable-next=invalid-name
    def getArguments(self) -> list:  # noqa: N802
        """Return the declarations of the plugin's custom arguments.

        All four arguments are flagged as advanced: two size limits in
        megabytes and two boolean download switches.
        """
        file_size_limit = {
            "name": "file_size_limit",
            "display_label": "File Size Limit",
            "default": 50,
            "help": (
                "File size limit in megabytes. "
                "If a file is bigger than this limit, "
                "then the file will not be indexed."
            ),
            "type": "int",
            "advanced": True,
        }
        batch_size_limit = {
            "name": "batch_size_limit",
            "display_label": "Batch Size Limit",
            "default": 1,
            "help": (
                "Size limit in megabytes for a batch of items "
                "(triggering early batch release if necessary)."
            ),
            "type": "int",
            "advanced": True,
        }
        download_spam_and_trash = {
            "name": "download_spam_and_trash",
            "display_label": "Download spam and trash",
            "help": (
                "If set, messages from Junk Email and Deleted Items folders "
                "will be downloaded."
            ),
            "type": "bool",
            "default": False,
            "action": "store_true",
            "advanced": True,
        }
        download_media_files = {
            "name": "download_media_files",
            "display_label": "Download media files",
            "help": "If set, media files will be downloaded.",
            "type": "bool",
            "default": True,
            "action": "store_true",
            "advanced": True,
        }
        return [
            file_size_limit,
            batch_size_limit,
            download_spam_and_trash,
            download_media_files,
        ]

    # pylint: disable-next=invalid-name
    def getIncrementalColumns(self) -> None:  # noqa: N802
        """Return ``None``: no incremental columns are exposed.

        This plugin uses API specific incremental loading (the per-folder
        links kept in ``next_links``).
        """
        return None

    def _stats_reset(self) -> None:
        """Set the counter of every statistics key back to zero."""
        self.stats = dict.fromkeys(self.StatKeys, 0)

    def _load_access_token(self) -> Tuple[Optional[str], Optional[float]]:
        """Read the saved access token and its expiry from the state store.

        Returns:
            ``(access_token, expires_at)``; the expiry is converted to
            ``float`` and is ``None`` when no (truthy) value is stored.
        """
        store = self.key_value_store
        raw_expiry = store.get("access_token_expires_at")
        access_token = store.get("access_token")
        expires_at = float(raw_expiry) if raw_expiry else None
        return access_token, expires_at

    def _save_access_token(
        self, access_token: Optional[str], expires_at: Optional[float]
    ) -> None:
        """Write the access token and its expiry to the state store.

        Only the presence of the token is logged, never its value.
        """
        has_token = bool(access_token)
        log.debug("Saving access_token %r with expire_at %s", has_token, expires_at)
        store = self.key_value_store
        store["access_token"] = access_token
        store["access_token_expires_at"] = expires_at

    def _stats_log(self) -> None:
        """Log every statistics counter with labels padded to equal width."""
        log.info("Statistics:")
        labelled = [(stat_key.value, count) for stat_key, count in self.stats.items()]
        width = max(len(label) for label, _ in labelled)
        for label, count in labelled:
            log.info("%s: %d", label.ljust(width), count)

    def _get_cached_items(self) -> Optional[List[ITEM_ENTRY]]:
        """Return the preview items cached for this job, if any.

        The cache entry is a base64 string wrapping a pickle, as written by
        ``_cache_items``.

        Returns:
            The unpickled items, or ``None`` when nothing is cached.
        """
        cache_key = f"{CACHE_PREVIEW_PREFIX}{self.getJobId()}"
        encoded = self.key_value_cache.get(cache_key)
        if not encoded:
            return None
        # NOTE(review): unpickling is only safe as long as the cache content
        # is written exclusively by this plugin.
        items = pickle.loads(base64.b64decode(encoded))  # nosec B301
        log.debug("Loaded %d items from cache", len(items))
        return items

    def _message_entries(self) -> ITEMS_GENERATOR:
        """Yield delta entries folder by folder, tracking each folder's link.

        After a page has been fully yielded, the link returned with it is
        recorded in ``self.next_links``. In preview only the first page of
        every folder is fetched.
        """
        # Iterate over a snapshot: ``self.next_links`` is written in the loop.
        for folder_id, link in list(self.next_links.items()):
            while True:
                entries, link, has_more = self.client.delta(folder_id, link)
                yield from entries
                self.next_links[folder_id] = link
                if self.preview or not has_more:
                    break

    def _extract_message(self, message: ITEM_ENTRY) -> ITEMS_GENERATOR:
        """Yield one item per accepted attachment of ``message``.

        The message is flattened, its date fields are parsed (``None`` when
        absent or empty) and its recipient lists are reduced to plain e-mail
        addresses. Every yielded attachment is enriched with the message
        information from ``_extract_message_info``.

        Only attachments are emitted; the message itself is intentionally
        not ingested, so ``StatKeys.MESSAGES`` is not incremented here.
        """
        flat = flatten(message)

        for date_field in ExchangeAPI.Dates.fields:
            raw_date = flat.get(date_field)
            flat[date_field] = self._convert_to_datetime(raw_date) if raw_date else None

        for field in ("toRecipients", "ccRecipients", "bccRecipients", "replyTo"):
            flat[field] = self._extract_recipients_emails(flat.get(field, []))

        if not flat.get("hasAttachments"):
            return

        message_info = self._extract_message_info(flat)
        for attachment in self._process_attachments(flat["id"]):
            attachment.update(message_info)
            yield attachment

    @staticmethod
    def _convert_to_datetime(date_str: str) -> datetime:
        """Parse ``date_str`` with ``dateutil``.

        Time zone information in the string is ignored (``ignoretz=True``).
        """
        return parser.parse(date_str, ignoretz=True)

    @staticmethod
    def _extract_recipients_emails(
        recipients: List[Dict[str, Dict[str, str]]]
    ) -> List[str]:
        """Return the e-mail addresses found in a list of recipients.

        Each recipient is read at ``emailAddress.address``; recipients
        without a (non-empty) address are left out.
        """
        addresses: List[str] = []
        for recipient in recipients:
            address = recipient.get("emailAddress", {}).get("address")
            if address:
                addresses.append(address)
        return addresses

    @staticmethod
    def _extract_message_info(message: ITEM_ENTRY) -> ITEM_ENTRY:
        """Collect the message fields that are copied onto each attachment.

        Args:
            message: A flattened message (dotted keys such as
                ``from.emailAddress.address``).

        Returns:
            A dictionary with the message id under ``messageId`` plus sender,
            recipient, subject and body fields. Apart from the id, a field
            that is missing from ``message`` is returned as ``None`` instead
            of raising ``KeyError``, so one incomplete message does not abort
            the whole load.

        Raises:
            KeyError: If ``message`` has no ``id``.
        """
        # (target key, source key in the flattened message)
        optional_fields = (
            ("createdDateTime", "createdDateTime"),
            ("from.emailAddress.address", "from.emailAddress.address"),
            ("from.emailAddress.name", "from.emailAddress.name"),
            ("toRecipients", "toRecipients"),
            ("ccRecipients", "ccRecipients"),
            ("bccRecipients", "bccRecipients"),
            ("replyTo", "replyTo"),
            ("emailSubject", "subject"),
            ("emailBodyPreview", "bodyPreview"),
            ("emailBodyContent", "body.content"),
            ("emailBodyContentType", "body.contentType"),
        )
        info: ITEM_ENTRY = {"messageId": message["id"]}
        for target, source in optional_fields:
            info[target] = message.get(source)
        return info

    def _process_attachments(self, message_id: str) -> ITEMS_GENERATOR:
        """Yield the processed attachments of a message, applying the filters.

        Every attachment that is not skipped is counted under
        ``StatKeys.ATTACHMENTS``.
        """
        for attachment in self.client.get_attachments(message_id):
            if not self._skip_attachment(attachment):
                self.stats[self.StatKeys.ATTACHMENTS] += 1
                yield self._process_attachment(attachment)

    def _skip_attachment(self, attachment: ITEM_ENTRY) -> bool:
        """Tell whether an attachment must be skipped.

        An attachment is skipped when it is larger than the file size limit
        or, if media downloads are disabled, when its MIME type is a media
        type. Skipped attachments are counted under
        ``StatKeys.ATTACHMENTS_SKIP``.
        """
        size = attachment["size"]
        if self._skip_too_big_file(size):
            log.debug("Skipped file with size (%d) above the limit", size)
        else:
            # The MIME type is only resolved for files that pass the size check.
            mime_type = self._guess_type_if_unknown(
                attachment["contentType"], attachment["name"]
            )
            if not self._skip_media_file(mime_type):
                return False
            log.debug("Skipped media file: %s", mime_type)
        self.stats[self.StatKeys.ATTACHMENTS_SKIP] += 1
        return True

    def _skip_too_big_file(self, file_size: int) -> bool:
        """Return ``True`` when ``file_size`` (bytes) exceeds the file limit."""
        limit = self.arg_file_size_limit
        if not limit:
            return False
        return bool(file_size > limit)

    @staticmethod
    def _guess_type_if_unknown(mime_type: str, file_name: str) -> str:
        """Replace the generic binary MIME type with a guess from the name.

        Any other ``mime_type`` is returned untouched. For the generic type,
        ``.msg`` files map to the Outlook type; otherwise ``mimetypes`` is
        consulted (non-strict) and the generic type is kept when it has no
        answer.
        """
        if mime_type != ExchangeAPI.Files.unknown_mime_type:
            return mime_type
        if file_name.lower().endswith(".msg"):
            return "application/vnd.ms-outlook"
        guessed_type, _encoding = mimetypes.guess_type(file_name, strict=False)
        return guessed_type if guessed_type else mime_type

    def _skip_media_file(self, mime_type: str) -> bool:
        """Return ``True`` for media types when media downloads are disabled."""
        if self.arg_download_media_files:
            return False
        return is_media_file(mime_type)

    def _process_attachment(self, attachment: ITEM_ENTRY) -> ITEM_ENTRY:
        """Build the item for one attachment without mutating the input.

        The ``name`` key is replaced by ``subject`` and ``filename``, the
        modification date is parsed, an unknown content type is guessed from
        the file name and ``contentBytes`` is decoded. In preview the content
        is replaced by a placeholder text.
        """
        item = dict(attachment)
        file_name = item.pop("name")
        item["subject"] = file_name
        item["filename"] = file_name

        modified = item["lastModifiedDateTime"]
        item["lastModifiedDateTime"] = self._convert_to_datetime(modified)

        content_type = item["contentType"]
        item["contentType"] = self._guess_type_if_unknown(content_type, file_name)

        if self.preview:
            item["contentBytes"] = "BINARY CONTENT [disabled in Data Preview]"
        else:
            item["contentBytes"] = self._get_content(item.get("contentBytes"))
        return item

    def _get_content(self, content: Optional[str]) -> Optional[bytes]:
        """Decode attachment content, or return ``None`` when there is none."""
        if content:
            return self._decode_content(content)
        log.debug("No content in the attachment")
        return None

    @staticmethod
    def _decode_content(encoded_content: str) -> bytes:
        """Decode base64 text into bytes using ``base64.urlsafe_b64decode``."""
        return base64.urlsafe_b64decode(encoded_content)

    def _release_batch_earlier(self, batch: List[ITEM_ENTRY]) -> bool:
        """Tell whether ``batch`` already exceeds the batch size limit.

        The batch size is the sum of ``get_item_size`` over its items. An
        empty batch never triggers an early release.
        """
        limit = self.arg_batch_size_limit
        if not (limit and batch):
            return False
        current_batch_size = sum(map(get_item_size, batch))
        if current_batch_size <= limit:
            return False
        log.debug(
            "Early batch release. Current batch size %dB, limit %dB",
            current_batch_size,
            limit,
        )
        return True

    def _batch_it(self, batch: List[ITEM_ENTRY]) -> List[ITEM_ENTRY]:
        """Account for a batch that is about to be released and return it.

        The batch is cached in preview; otherwise the current ``next_links``
        are persisted.
        """
        self.stats[self.StatKeys.ITEMS_BATCHED] += len(batch)
        if not self.preview:
            self._save_state()
        else:
            self._cache_items(batch)
        return batch

    def _cache_items(self, batch: List[ITEM_ENTRY]) -> None:
        """Store ``batch`` in the preview cache for this job.

        The batch is pickled and base64-encoded into a string; the entry
        expires after ``CACHE_PREVIEW_EXPIRES_SECONDS``.
        """
        cache_key = f"{CACHE_PREVIEW_PREFIX}{self.getJobId()}"
        payload = base64.b64encode(pickle.dumps(batch)).decode("utf-8")
        cache = self.key_value_cache
        cache[cache_key] = payload
        cache.expires(cache_key, CACHE_PREVIEW_EXPIRES_SECONDS)
        log.debug("Cached %d items", len(batch))

    def _save_state(self) -> None:
        """Persist the per-folder ``next_links`` in the key-value store."""
        self.key_value_store["next_links"] = self.next_links
        log.debug("Saved next_links %r", self.next_links)


def load_mapped_fields(file_name: str = "mappings.json") -> Set[str]:
    """Return the field names that the mappings file maps.

    The file is looked up next to this module (an absolute ``file_name`` is
    used as is) and is read as UTF-8 rather than with the locale's default
    encoding.

    Args:
        file_name: Name of the JSON mappings file.

    Returns:
        The values of all top-level entries whose key starts with ``map_``.
    """
    mappings_path = pathlib.Path(__file__).parent / file_name
    with mappings_path.open(encoding="utf-8") as json_file:
        fields = json.load(json_file)
    return {value for key, value in fields.items() if key.startswith("map_")}


def is_media_file(mimetype: Optional[str]) -> bool:
    """Tell whether a MIME type is an audio, image or video type.

    A missing or empty MIME type is not a media type.

    Move to common utils in STX-28.
    """
    if not mimetype:
        return False
    return mimetype.startswith(("audio/", "image/", "video/"))


def flatten(input_dict: dict) -> dict:
    """Flatten nested dictionaries into a single level with dotted keys.

    Nested keys are joined with ``.``; values that are not dictionaries
    (lists included) are copied as they are. Empty nested dictionaries
    produce no entry. Entries are visited breadth-first.
    """
    flat: dict = {}
    pending = collections.deque()
    pending.append(("", input_dict))
    while pending:
        path, node = pending.popleft()
        if not isinstance(node, dict):
            flat[path] = node
            continue
        for child_key, child in node.items():
            child_path = f"{path}.{child_key}" if path else f"{child_key}"
            pending.append((child_path, child))
    return flat


def get_item_size(item: ITEM_ENTRY) -> int:
    """Return the summed ``sys.getsizeof`` of the item's values.

    ``sys.getsizeof`` is shallow, so objects referenced by a value are not
    included.
    """
    total = 0
    for value in item.values():
        total += sys.getsizeof(value)
    return total


# Read once at import time: the optional tenant id selects the OAuth2 token
# endpoint defined in ``ExchangeAPI.Oauth2`` below.
config = get_injected("config")
tenant_id = config.get("dataloader", "exchange_tenant_id", fallback=None)


# pylint: disable-next=too-few-public-methods
class ExchangeAPI:
    """Namespace of constants used to talk to Microsoft Graph."""

    # pylint: disable-next=too-few-public-methods
    class Oauth2:
        """OAuth2 endpoints."""

        # Tenant-specific token endpoint when a tenant id is configured,
        # otherwise the shared ``common`` endpoint.
        refresh_token_url = (
            "https://login.microsoftonline.com/{}/oauth2/v2.0/token".format(
                tenant_id or "common"
            )
        )

    # pylint: disable-next=too-few-public-methods
    class Links:
        """Base URLs."""

        # Relative API paths are resolved against this URL.
        base_url = "https://graph.microsoft.com/v1.0/"

    # pylint: disable-next=too-few-public-methods
    class Folders:
        """Display names of the folders that can be excluded."""

        spam = "Junk Email"
        trash = "Deleted Items"

    # pylint: disable-next=too-few-public-methods
    class Files:
        """File related constants."""

        # Generic binary type; triggers guessing the type from the file name.
        unknown_mime_type = "application/octet-stream"

    # pylint: disable-next=too-few-public-methods
    class Dates:
        """Message fields that hold date strings."""

        fields = [
            "createdDateTime",
            "lastModifiedDateTime",
            "receivedDateTime",
            "sentDateTime",
        ]


class OAuthSession(requests.Session):
    """OAuth2 session with automatic token refresh.

    FIXME to be moved to common plugin library as part of STX-11
    """

    # Dedicated child logger for the OAuth session.
    LOG = logging.getLogger(f"{__name__}.OAuthSession")
    # ``check_and_refresh_access_token`` refreshes the token once the current
    # time plus this buffer reaches the recorded expiry.
    TOKEN_EXPIRATION_BUFFER = timedelta(seconds=300)

    # URL the refresh request is posted to. Only declared here; a value is
    # assigned by ``ExchangeClient``.
    token_endpoint_url: str

    def __init__(  # pylint: disable=too-many-arguments
        self,
        access_token: Optional[str] = None,
        access_token_expiration: Optional[datetime] = None,
        refresh_token: Optional[str] = None,
        client_id: Optional[str] = None,
        client_secret: Optional[str] = None,
        scope: Optional[List[str]] = None,
    ):
        """Create a session holding the given OAuth2 credentials.

        Args:
            access_token: Current access token, if any.
            access_token_expiration: Expiry of ``access_token``; compared
                against ``datetime.utcnow()`` when checking for a refresh.
            refresh_token: Token used to obtain new access tokens.
            client_id: OAuth2 client id sent with refresh requests.
            client_secret: OAuth2 client secret sent with refresh requests.
            scope: Scopes sent (space separated) with refresh requests.
        """
        super().__init__()

        # Credentials needed to refresh the access token.
        self.scope = scope
        self.refresh_token = refresh_token
        self.client_secret = client_secret
        self.client_id = client_id

        # Order matters: assigning the token resets the expiration, so the
        # expiration has to be set afterwards.
        self.access_token = access_token
        self.access_token_expiration = access_token_expiration

    @property
    def access_token(self) -> Optional[str]:
        """Return the access token currently held by the session."""
        return self._access_token

    @access_token.setter
    def access_token(self, value: Optional[str]):
        """Store the token and keep the ``Authorization`` header in sync.

        Assigning a token also clears ``access_token_expiration``; a caller
        that knows the expiry has to set it afterwards.

        Args:
            value: The new access token, or ``None``/empty for no token.
        """
        self._access_token = value
        self.access_token_expiration = None
        if value:
            self.headers["Authorization"] = f"Bearer {value}"
        else:
            # Without a token send no credentials at all instead of the
            # literal header value "Bearer None".
            self.headers.pop("Authorization", None)

    @property
    def can_refresh(self) -> bool:
        """Tell whether a refresh token, client id and client secret are set."""
        required = (self.refresh_token, self.client_id, self.client_secret)
        return all(required)

    def request(self, *args, oauth2_refresh: bool = True, **kwargs):
        """Send a request, refreshing the access token when required.

        With ``oauth2_refresh`` enabled the token is refreshed up front if it
        is missing or about to expire, and the request is repeated once after
        a refresh when the response signals an expired token.

        Args:
            *args: Passed through to ``requests.Session.request``.
            oauth2_refresh: Set to ``False`` to bypass all token handling.
            **kwargs: Passed through to ``requests.Session.request``.

        Returns:
            The response of the last request that was sent.
        """
        if not oauth2_refresh:
            return super().request(*args, **kwargs)

        self.check_and_refresh_access_token()
        response = super().request(*args, **kwargs)
        if not self.is_expired_access_token_response(response):
            return response

        self.LOG.warning("Request failed with expired access token")
        if not self.can_refresh:
            self.LOG.error("Unable to refresh access token")
            return response

        self.refresh_access_token()
        return super().request(*args, **kwargs)

    @staticmethod
    def is_expired_access_token_response(response: requests.Response) -> bool:
        """Return ``True`` when the response status code is 401."""
        return response.status_code == 401

    def check_and_refresh_access_token(self):
        """Refresh the access token if it is missing or about to expire.

        A token counts as about to expire once the current UTC time plus
        ``TOKEN_EXPIRATION_BUFFER`` reaches its expiration. Nothing happens
        when the session is unable to refresh.
        """
        expiration = self.access_token_expiration
        expiring = bool(expiration) and (
            datetime.utcnow() + self.TOKEN_EXPIRATION_BUFFER >= expiration
        )
        missing = not self.access_token
        if not (expiring or missing):
            return
        if self.can_refresh:
            self.refresh_access_token()

    def refresh_access_token(self):
        """Obtain a new access token with the refresh token grant.

        The request is posted to ``token_endpoint_url`` without the session's
        ``Authorization`` header and without triggering another refresh. On
        success the access token and its expiry (now + ``expires_in``) are
        updated. If the response carries a new refresh token, it replaces the
        current one, because the authorization server may rotate refresh
        tokens.

        Raises:
            requests.HTTPError: If the token endpoint answers with an error
                status.
        """
        self.LOG.info("Refreshing access token")
        body = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "grant_type": "refresh_token",
            "refresh_token": self.refresh_token,
        }
        if self.scope is not None:
            body["scope"] = " ".join(self.scope)
        res = self.post(
            self.token_endpoint_url,
            data=body,
            oauth2_refresh=False,
            # ``None`` removes the session-level header for this request.
            headers={"Authorization": None},
        )
        res.raise_for_status()

        token = res.json()
        # Assigning the token resets the expiration, so set the token first.
        self.access_token = token["access_token"]
        self.access_token_expiration = datetime.utcnow() + timedelta(
            seconds=int(token["expires_in"])
        )
        rotated_refresh_token = token.get("refresh_token")
        if rotated_refresh_token:
            self.refresh_token = rotated_refresh_token


class ExchangeClient(OAuthSession):
    """Exchange API client."""

    # Endpoint used by the inherited token refresh.
    token_endpoint_url = ExchangeAPI.Oauth2.refresh_token_url
    # URLs without a scheme are resolved against this base in ``api_request``.
    base_url = ExchangeAPI.Links.base_url

    def __init__(self, *args, **kwargs):
        """Create the client, set its User-Agent and enable HTTPS retries.

        All arguments are passed through to ``OAuthSession``. Requests to
        ``https://`` URLs are retried up to 9 times with a backoff factor of
        0.5, including on the status codes 429, 500, 502, 503 and 504.
        """
        super().__init__(*args, **kwargs)
        self.headers["User-Agent"] = f"Squirro-{__plugin_name__}-{__version__}"
        retry_policy = Retry(
            total=9,
            connect=9,
            backoff_factor=0.5,
            status_forcelist=[429, 500, 502, 503, 504],
        )
        self.mount("https://", HTTPAdapter(max_retries=retry_policy))

    def list_me_folders(
        self, download_spam_and_trash: bool, load_nested: bool
    ) -> Generator[str, None, None]:
        """Yield the ids of the signed-in user's mail folders.

        Result pages and, when requested, child folder listings are queued
        and fetched in turn.

        Args:
            download_spam_and_trash: When false, folders whose display name
                is the spam or trash name are left out (their child folders
                are then not listed either).
            load_nested: Whether to also list child folders.

        Yields:
            Folder ids.
        """
        # NOTE(review): the spam/trash filter compares display names -
        # confirm it also holds for mailboxes with non-English folder names.
        pending = collections.deque(["./me/mailFolders"])
        while pending:
            url = pending.popleft()
            res = self.api_request(url).json()
            for folder in res["value"]:
                if not download_spam_and_trash and folder["displayName"] in {
                    ExchangeAPI.Folders.spam,
                    ExchangeAPI.Folders.trash,
                }:
                    continue
                if load_nested and folder["childFolderCount"]:
                    # Quote the id like ``delta`` does when building its URL.
                    pending.append(
                        f"./me/mailFolders/{quote(folder['id'])}/childFolders"
                    )
                yield folder["id"]
            next_url = res.get("@odata.nextLink")
            if next_url:
                pending.append(next_url)

    def delta(
        self, folder_id: str, last_url: Optional[str]
    ) -> Tuple[List[dict], str, bool]:
        """Fetch one page of message changes for a folder.

        https://docs.microsoft.com/en-us/graph/api/message-delta?view=graph-rest-1.0&tabs=http

        Args:
            folder_id: Id of the mail folder.
            last_url: Link returned by a previous call; when empty, a new
                delta query is started for the folder.

        Returns:
            ``(entries, link, has_more)``. ``link`` is the next-page link
            while more pages are available (``has_more`` is ``True``) and the
            delta link otherwise.
        """
        if last_url:
            url = last_url
        else:
            url = f"./me/mailfolders/{quote(folder_id)}/messages/delta"
        payload = self.api_request(url).json()
        entries = payload["value"]
        next_page_link = payload.get("@odata.nextLink")
        if next_page_link:
            return entries, next_page_link, True
        return entries, payload["@odata.deltaLink"], False

    def get_attachments(self, message_id: str) -> List[Dict[str, Any]]:
        """Return the attachments listed for a message.

        The message id is percent-quoted before it is placed in the URL path,
        in the same way ``delta`` quotes folder ids.

        Args:
            message_id: Id of the message.

        Returns:
            The ``value`` list of the attachments response.
        """
        res = self.api_request(f"./me/messages/{quote(message_id)}/attachments")
        return res.json()["value"]

    def api_request(self, url: str, method: str = "GET", **kwargs) -> Any:
        """Send an API request and raise for error responses.

        Args:
            url: Absolute URL, or a path relative to ``base_url`` (any URL
                without a scheme is joined onto ``base_url``).
            method: HTTP method.
            **kwargs: Passed through to ``request``.

        Returns:
            The response object.

        Raises:
            requests.HTTPError: For 4xx/5xx responses. The error payload is
                logged first.
        """
        if not urlparse(url).scheme:
            url = urljoin(self.base_url, url)
        res = self.request(method, url, **kwargs)
        try:
            res.raise_for_status()
        except requests.HTTPError as e:
            if e.response is not None:
                try:
                    details = e.response.json()
                except ValueError:
                    # Error bodies are not always JSON (e.g. gateway pages);
                    # logging must not mask the HTTP error raised below.
                    details = e.response.text
                log.warning("Error occurred %r", details)
            raise
        return res
