#!/usr/bin/env python3

# Libervia plugin for Extended Channel Search (XEP-0433)
# Copyright (C) 2009-2025 Jérôme Poisson (goffi@goffi.org)

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import difflib
from typing import Any, Final, Iterator, Self, cast
from pydantic import BaseModel, Field, ConfigDict, RootModel, model_validator
from twisted.internet import defer
from twisted.words.protocols.jabber import jid
from twisted.words.xish import domish
from wokkel import data_form
from libervia.backend import G
from libervia.backend.core import exceptions
from libervia.backend.core.constants import Const as C
from libervia.backend.core.core_types import SatXMPPEntity
from libervia.backend.core.i18n import _
from libervia.backend.core.log import getLogger
from libervia.backend.models.types import JIDType
from libervia.backend.plugins import plugin_misc_jid_search
from libervia.backend.plugins.plugin_xep_0059 import RSMRequest

# Module-level logger for this plugin.
log = getLogger(__name__)

# Namespaces: every XEP-0433 namespace is derived from the base channel-search one.
NS_CHANNEL_SEARCH: Final[str] = "urn:xmpp:channel-search:0"
NS_SEARCH: Final[str] = NS_CHANNEL_SEARCH + ":search"
NS_SEARCH_PARAMS: Final[str] = NS_CHANNEL_SEARCH + ":search-params"
NS_ORDER: Final[str] = NS_CHANNEL_SEARCH + ":order"
NS_ERROR: Final[str] = NS_CHANNEL_SEARCH + ":error"

# Common sort keys, written in Clark notation ("{namespace}name").
SORT_ADDRESS: Final[str] = "{" + NS_ORDER + "}address"
SORT_NUSERS: Final[str] = "{" + NS_ORDER + "}nusers"

# Plugin metadata, presumably consumed by the Libervia plugin loader (C.PI_* keys).
PLUGIN_INFO = {
    C.PI_NAME: "Extended Channel Search",
    C.PI_IMPORT_NAME: "XEP-0433",
    C.PI_TYPE: "XEP",
    C.PI_MODES: C.PLUG_MODE_BOTH,
    # XEP-0059 provides RSMRequest (used for result paging in SearchRequest).
    # JID_SEARCH presumably provides the "JID_SEARCH_perform_search" trigger
    # hooked by XEP_0433.__init__.
    C.PI_DEPENDENCIES: ["XEP-0059", "JID_SEARCH"],
    C.PI_RECOMMENDATIONS: [],
    C.PI_MAIN: "XEP_0433",
    # XEP_0433 defines no get_handler; NOTE(review): confirm "no" means that no
    # XMPP handler is registered for this plugin.
    C.PI_HANDLER: "no",
    C.PI_DESCRIPTION: _("Cross-domain search for public group chats"),
}


class SearchRequest(BaseModel):
    """Parameters for channel search request.

    Field names follow the ``search-params`` data form vars, except ``query``
    whose form var (and pydantic alias) is ``q``: when building the model
    from keyword arguments or JSON, the query must be given as ``q``.
    Unknown inputs are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    # Free-text query (form var and alias: "q").
    query: str | None = Field(None, alias="q")
    # Request every channel; cannot be combined with a query.
    all: bool = False
    # Presumably the XEP-0433 "search in name/description/address" toggles
    # for the query.
    sinname: bool = True
    sindescription: bool = True
    sinaddr: bool = True
    min_users: int | None = Field(default=None, ge=0)
    types: list[str] = []
    # Sort key in Clark notation, see SORT_* constants.
    key: str = SORT_ADDRESS
    # Optional result paging (XEP-0059).
    rsm: RSMRequest | None = None

    @model_validator(mode="after")
    def check_conflicts(self) -> Self:
        """Reject requests asking for "all" channels while also giving a query."""
        if self.all and self.query:
            raise ValueError('Cannot combine "all" with search query')
        return self

    @classmethod
    def from_element(cls, element: domish.Element) -> Self:
        """Parse from XMPP data form element.

        ``element`` is the data form element itself (not the wrapping
        ``<search>`` element); RSM paging is not parsed here.

        @param element: data form with FORM_TYPE ``NS_SEARCH_PARAMS``.
        @return: parsed search request.
        @raise ValueError: wrong FORM_TYPE, invalid ``min_users``, or any other
            validation failure (pydantic's ValidationError is a ValueError).
        """
        form = data_form.Form.fromElement(element)
        if form.formNamespace != NS_SEARCH_PARAMS:
            raise ValueError("Invalid FORM_TYPE")

        kwargs: dict[str, Any] = {}

        # "query" is declared with the "q" alias and the model does not
        # populate by field name, so the alias must be used: passing "query"
        # would be rejected as an extra input.
        if "q" in form:
            kwargs["q"] = form["q"]

        if "all" in form:
            kwargs["all"] = form["all"]

        if "min_users" in form:
            try:
                kwargs["min_users"] = int(form["min_users"])
            except (TypeError, ValueError) as e:
                # TypeError covers a field sent without any value (None).
                raise ValueError("Invalid min_users value") from e

        for field in ("sinname", "sindescription", "sinaddr", "key"):
            if field in form:
                kwargs[field] = form[field]

        if "types" in form:
            # Always take the full list of values: for an untyped field with a
            # single value, form["types"] would return a bare string.
            kwargs["types"] = list(form.fields["types"].values)

        return cls(**kwargs)

    def to_form(self) -> data_form.Form:
        """Convert to "submit" data form.

        Only parameters differing from their defaults are added to the form.
        """
        form = data_form.Form("submit", formNamespace=NS_SEARCH_PARAMS)

        # Add fields with original XML field names
        if self.query is not None:
            form.addField(data_form.Field(var="q", value=self.query))

        if self.all:
            form.addField(data_form.Field("boolean", "all", value=True))

        if not self.sinname:
            form.addField(data_form.Field("boolean", "sinname", value=False))

        if not self.sindescription:
            form.addField(data_form.Field("boolean", "sindescription", value=False))

        if not self.sinaddr:
            form.addField(data_form.Field("boolean", "sinaddr", value=False))

        if self.min_users is not None:
            form.addField(data_form.Field(var="min_users", value=str(self.min_users)))

        if self.types:
            form.addField(data_form.Field("list-multi", "types", values=self.types))

        if self.key != SORT_ADDRESS:
            form.addField(data_form.Field("list-single", "key", value=self.key))

        return form

    def to_element(self) -> domish.Element:
        """Convert to a ``<search>`` element with the form and optional RSM set."""
        form = self.to_form()
        search_elt = domish.Element((NS_SEARCH, "search"))
        search_elt.addChild(form.toElement())
        if self.rsm is not None:
            search_elt.addChild(self.rsm.to_element())

        return search_elt


class SearchItem(BaseModel):
    """Represents a single channel search result."""

    address: JIDType
    name: str | None = None
    description: str | None = None
    language: str | None = None
    nusers: int | None = Field(default=None, ge=0)
    service_type: str | None = None
    is_open: bool | None = None
    anonymity_mode: str | None = None

    @classmethod
    def from_element(cls, element: domish.Element) -> Self:
        """Parse from <item> element."""
        if not (element.name == "item" and element.uri == NS_SEARCH):
            raise ValueError("Invalid channel item element")

        address = element.getAttribute("address")
        if not address:
            raise ValueError("Missing required address attribute")

        data: dict[str, Any] = {"address": jid.JID(address)}

        for child in element.elements():
            if child.uri != NS_SEARCH:
                continue

            content = str(child)
            match (name := child.name.replace("-", "_")):
                case "nusers":
                    data[name] = int(content)
                case "is_open":
                    data[name] = content.lower() == "true"
                case "service_type" | "anonymity_mode" if content:
                    data[name] = content
                case _:
                    data[name] = content

        return cls(**data)

    def to_element(self) -> domish.Element:
        """Convert to <item> element."""
        item = domish.Element((NS_SEARCH, "item"))
        item["address"] = str(self.address)

        field_mappings = {
            "name": "name",
            "description": "description",
            "language": "language",
            "nusers": "nusers",
            "service_type": "service-type",
            "anonymity_mode": "anonymity-mode",
        }

        for field, element_name in field_mappings.items():
            value = getattr(self, field)
            if value is not None:
                elem = item.addElement((NS_SEARCH, element_name))
                elem.addContent(
                    str(value).lower() if isinstance(value, bool) else str(value)
                )

        if self.is_open is not None:
            item.addElement(
                (NS_SEARCH, "is-open"), content="true" if self.is_open else "false"
            )

        return item


class SearchItems(RootModel):
    """List-like collection of ``SearchItem``, serialised as a ``<result>`` element."""

    root: list[SearchItem]

    def __iter__(self) -> Iterator[SearchItem]:  # type: ignore
        return iter(self.root)

    def __getitem__(self, item) -> SearchItem:
        # Delegates to the underlying list (a slice would return a list).
        return self.root[item]

    def __len__(self) -> int:
        return len(self.root)

    def append(self, item: SearchItem) -> None:
        """Append ``item`` to the collection."""
        self.root.append(item)

    def sort(self, key=None, reverse=False) -> None:
        """Sort the collection in place, with the same arguments as ``list.sort``."""
        self.root.sort(key=key, reverse=reverse)  # type: ignore

    @classmethod
    def from_element(cls, element: domish.Element) -> Self:
        """Parse search results.

        @param element: either the ``<result>`` element itself or a parent
            element (e.g. the IQ result) having it as a direct child.
        @return: items found in the ``<result>`` element.
        @raise exceptions.NotFound: no ``<result>`` element found.
        @raise ValueError: an ``<item>`` is invalid.
        """
        if element.name == "result" and element.uri == NS_SEARCH:
            result_elt = element
        else:
            try:
                result_elt = next(element.elements(NS_SEARCH, "result"))
            except StopIteration:
                raise exceptions.NotFound("No <result> element found.")
        items = []
        for item_elt in result_elt.elements(NS_SEARCH, "item"):
            items.append(SearchItem.from_element(item_elt))
        return cls(items)

    def to_element(self) -> domish.Element:
        """Convert to a ``<result>`` element containing one ``<item>`` per entry."""
        result_elt = domish.Element((NS_SEARCH, "result"))
        for search_item in self.root:
            result_elt.addChild(search_item.to_element())
        return result_elt


class XEP_0433:
    """Implementation of XEP-0433 Extended Channel Search (requesting side).

    Provides ``search`` to query a channel search service, the
    ``extended_search_request`` bridge method wrapping it, and a
    ``JID_SEARCH_perform_search`` trigger adding external group chat results
    to JID searches when ``allow_external_search`` is enabled.
    """

    namespace: Final[str] = NS_CHANNEL_SEARCH

    def __init__(self, host: Any):
        log.info(f"Plugin {PLUGIN_INFO[C.PI_NAME]!r} initialization.")
        self.host = host
        host.trigger.add(
            "JID_SEARCH_perform_search", self.jid_search_perform_search_trigger
        )

        # Both settings are read from the "common" configuration section.
        common_conf = G.config.common
        self.allow_external = common_conf.allow_external_search
        default_service = common_conf.group_chat_search_default_jid
        self.group_chat_search_default_jid = jid.JID(default_service)

        host.bridge.add_method(
            "extended_search_request",
            ".plugin",
            in_sign="sss",
            out_sign="s",
            method=self._search,
            async_=True,
        )

    async def jid_search_perform_search_trigger(
        self,
        client: SatXMPPEntity,
        search_term: str,
        options: plugin_misc_jid_search.Options,
        sequence_matcher: difflib.SequenceMatcher,
        matches: plugin_misc_jid_search.SearchItems,
    ) -> bool:
        """Add results from the default external search service to ``matches``.

        Only active when ``options.groupchat`` is set and external search is
        allowed. A failing search is logged as a warning and otherwise ignored.
        ``sequence_matcher`` is not used here.

        @return: always True (presumably lets the trigger chain continue —
            confirm against the trigger manager).
        """
        if not (options.groupchat and self.allow_external):
            return True

        service = self.group_chat_search_default_jid
        log.debug(f"Search {search_term!r} at {service}.")
        try:
            found = await self.search(client, service, SearchRequest(q=search_term))
        except Exception as e:
            log.warning(f"Can't do external search: {e}.")
            return True

        for found_item in found:
            room_jid = found_item.address
            # Fall back on the room's local part, then on its full JID, when
            # the service gives no name.
            display_name = found_item.name or room_jid.user or room_jid.full()
            matches.append(
                plugin_misc_jid_search.RoomSearchItem(
                    entity=room_jid,
                    name=display_name,
                    local=False,
                    service_type=found_item.service_type,
                    is_open=found_item.is_open,
                    anonymity_mode=found_item.anonymity_mode,
                    description=found_item.description,
                    language=found_item.language,
                    nusers=found_item.nusers,
                )
            )
        return True

    def _search(
        self, target: str, search_request: str, profile: str
    ) -> defer.Deferred[str]:
        """Bridge entry point for ``search`` using serialised data.

        @param target: JID of the search service, as a string.
        @param search_request: JSON serialisation of a ``SearchRequest``.
        @param profile: profile to use.
        @return: Deferred firing with the JSON serialised ``SearchItems``,
            ``None`` values excluded.
        """
        client = self.host.get_client(profile)
        service_jid = jid.JID(target)
        request = SearchRequest.model_validate_json(search_request)
        d = defer.ensureDeferred(self.search(client, service_jid, request))

        def serialise(search_items: SearchItems) -> str:
            return search_items.model_dump_json(exclude_none=True)

        d.addCallback(serialise)
        return cast(defer.Deferred[str], d)

    async def search(
        self, client: SatXMPPEntity, target: jid.JID, search_request: SearchRequest
    ) -> SearchItems:
        """Send a search request to ``target`` and parse its answer.

        @param client: client used to send the request.
        @param target: search service to query.
        @param search_request: search parameters.
        @return: items found in the ``<result>`` of the IQ answer.
        @raise exceptions.NotFound: the answer has no ``<result>`` element.
        """
        request_iq = client.IQ("get")
        request_iq.addChild(search_request.to_element())
        answer_elt = await request_iq.send(target.full())
        return SearchItems.from_element(answer_elt)
