"""Ensure that Event Chains are properly queued and handled between frontend and backend."""

from __future__ import annotations

from typing import Generator

import pytest
from selenium.webdriver.common.by import By

from reflex.testing import AppHarness, WebDriver

# Number of events yielded by the `click_yield_many_events` handler and expected
# by the "yield_many_events" test case. `EventChain` defines its own copy of this
# value, so the two must be kept in sync.
MANY_EVENTS = 50


def EventChain():
    """App with chained event handlers.

    Defines a ``State`` whose handlers each append a marker string to
    ``event_order`` and then return or yield further events, plus the pages
    that trigger those handlers via ``on_click``, ``on_load``, ``on_mount``
    and ``on_unmount``.

    NOTE(review): imports and constants are local to this function, presumably
    because the harness builds the app module from this function alone (see
    the ``MANY_EVENTS`` comment below) -- confirm against AppHarness.
    """
    import asyncio
    import time
    from typing import List

    import reflex as rx

    # repeated here since the outer global isn't exported into the App module
    MANY_EVENTS = 50

    class State(rx.State):
        # Markers appended by every handler, in the order the handlers ran.
        event_order: List[str] = []
        # Value displayed in the "interim_value" input on the index page.
        interim_value: str = ""

        # --- leaf handlers: only record that they ran -----------------------

        def event_no_args(self):
            self.event_order.append("event_no_args")

        def event_arg(self, arg):
            self.event_order.append(f"event_arg:{arg}")

        def event_arg_repr_type(self, arg):
            # Records both the repr and the type name of the received argument.
            self.event_order.append(f"event_arg_repr:{arg!r}_{type(arg).__name__}")

        # --- nested chain: 1 -> 2 -> 3 -> event_no_args ---------------------
        # Each level records itself, yields the next level, then yields an
        # `event_arg` tagged with its own level.

        def event_nested_1(self):
            self.event_order.append("event_nested_1")
            yield State.event_nested_2
            yield State.event_arg("nested_1")  # type: ignore

        def event_nested_2(self):
            self.event_order.append("event_nested_2")
            yield State.event_nested_3
            # A console_log event is interleaved between the two state events.
            yield rx.console_log("event_nested_2")
            yield State.event_arg("nested_2")  # type: ignore

        def event_nested_3(self):
            self.event_order.append("event_nested_3")
            yield State.event_no_args
            yield State.event_arg("nested_3")  # type: ignore

        # --- handlers used for on_load / on_mount ---------------------------

        def on_load_return_chain(self):
            # Returns the follow-up events as a list (event_arg 1, 2, 3).
            self.event_order.append("on_load_return_chain")
            return [State.event_arg(1), State.event_arg(2), State.event_arg(3)]  # type: ignore

        def on_load_yield_chain(self):
            # Yields the follow-up events one at a time (event_arg 4, 5, 6).
            self.event_order.append("on_load_yield_chain")
            yield State.event_arg(4)  # type: ignore
            yield State.event_arg(5)  # type: ignore
            yield State.event_arg(6)  # type: ignore

        # --- handlers bound to the index page buttons -----------------------

        def click_return_event(self):
            # Returns a single follow-up event.
            self.event_order.append("click_return_event")
            return State.event_no_args

        def click_return_events(self):
            # Returns a list mixing state events with a console_log event.
            self.event_order.append("click_return_events")
            return [
                State.event_arg(7),  # type: ignore
                rx.console_log("click_return_events"),
                State.event_arg(8),  # type: ignore
                State.event_arg(9),  # type: ignore
            ]

        def click_yield_chain(self):
            # Alternates recording its own progress markers (":0" .. ":3")
            # with yielding follow-up events.
            self.event_order.append("click_yield_chain:0")
            yield State.event_arg(10)  # type: ignore
            self.event_order.append("click_yield_chain:1")
            yield rx.console_log("click_yield_chain")
            yield State.event_arg(11)  # type: ignore
            self.event_order.append("click_yield_chain:2")
            yield State.event_arg(12)  # type: ignore
            self.event_order.append("click_yield_chain:3")

        def click_yield_many_events(self):
            # Yields MANY_EVENTS pairs of (event_arg, console_log), then
            # records a "done" marker.
            self.event_order.append("click_yield_many_events")
            for ix in range(MANY_EVENTS):
                yield State.event_arg(ix)  # type: ignore
                yield rx.console_log(f"many_events_{ix}")
            self.event_order.append("click_yield_many_events_done")

        def click_yield_nested(self):
            # Entry point into the nested chain defined above.
            self.event_order.append("click_yield_nested")
            yield State.event_nested_1
            yield State.event_arg("yield_nested")  # type: ignore

        def redirect_return_chain(self):
            # Redirects to the page whose on_load is `on_load_return_chain`.
            self.event_order.append("redirect_return_chain")
            yield rx.redirect("/on-load-return-chain")

        def redirect_yield_chain(self):
            # Redirects to the page whose on_load is `on_load_yield_chain`.
            self.event_order.append("redirect_yield_chain")
            yield rx.redirect("/on-load-yield-chain")

        def click_return_int_type(self):
            # Chains to `event_arg_repr_type` with an int argument.
            self.event_order.append("click_return_int_type")
            return State.event_arg_repr_type(1)  # type: ignore

        def click_return_dict_type(self):
            # Chains to `event_arg_repr_type` with a dict argument.
            self.event_order.append("click_return_dict_type")
            return State.event_arg_repr_type({"a": 1})  # type: ignore

        async def click_yield_interim_value_async(self):
            # Sets "interim", does a bare yield, waits 0.5s without blocking
            # the event loop, then sets "final".
            # NOTE(review): the bare yield presumably pushes the interim state
            # to the frontend before the handler finishes -- confirm.
            self.interim_value = "interim"
            yield
            await asyncio.sleep(0.5)
            self.interim_value = "final"

        def click_yield_interim_value(self):
            # Synchronous counterpart of the handler above; uses a blocking
            # time.sleep instead of asyncio.sleep.
            self.interim_value = "interim"
            yield
            time.sleep(0.5)
            self.interim_value = "final"

    app = rx.App(state=rx.State)

    # Read-only input showing the session's client token; included on every
    # page below under the id "token".
    token_input = rx.input(
        value=State.router.session.client_token, is_read_only=True, id="token"
    )

    # Index page: the token input, the interim value input and one button per
    # click handler (button ids are what the tests look up).
    @app.add_page
    def index():
        return rx.fragment(
            token_input,
            rx.input(value=State.interim_value, is_read_only=True, id="interim_value"),
            rx.button(
                "Return Event",
                id="return_event",
                on_click=State.click_return_event,
            ),
            rx.button(
                "Return Events",
                id="return_events",
                on_click=State.click_return_events,
            ),
            rx.button(
                "Yield Chain",
                id="yield_chain",
                on_click=State.click_yield_chain,
            ),
            rx.button(
                "Yield Many events",
                id="yield_many_events",
                on_click=State.click_yield_many_events,
            ),
            rx.button(
                "Yield Nested",
                id="yield_nested",
                on_click=State.click_yield_nested,
            ),
            rx.button(
                "Redirect Yield Chain",
                id="redirect_yield_chain",
                on_click=State.redirect_yield_chain,
            ),
            rx.button(
                "Redirect Return Chain",
                id="redirect_return_chain",
                on_click=State.redirect_return_chain,
            ),
            # These two buttons call `event_arg_repr_type` directly from the
            # frontend (via a lambda) rather than through a chaining handler.
            rx.button(
                "Click Int Type",
                id="click_int_type",
                on_click=lambda: State.event_arg_repr_type(1),  # type: ignore
            ),
            rx.button(
                "Click Dict Type",
                id="click_dict_type",
                on_click=lambda: State.event_arg_repr_type({"a": 1}),  # type: ignore
            ),
            rx.button(
                "Return Chain Int Type",
                id="return_int_type",
                on_click=State.click_return_int_type,
            ),
            rx.button(
                "Return Chain Dict Type",
                id="return_dict_type",
                on_click=State.click_return_dict_type,
            ),
            rx.button(
                "Click Yield Interim Value (Async)",
                id="click_yield_interim_value_async",
                on_click=State.click_yield_interim_value_async,
            ),
            rx.button(
                "Click Yield Interim Value",
                id="click_yield_interim_value",
                on_click=State.click_yield_interim_value,
            ),
        )

    # Page bodies for the on_load tests; the on_load handlers themselves are
    # attached in the `app.add_page` calls at the end of this function.
    def on_load_return_chain():
        return rx.fragment(
            rx.text("return"),
            token_input,
        )

    def on_load_yield_chain():
        return rx.fragment(
            rx.text("yield"),
            token_input,
        )

    # Page whose text element runs `on_load_return_chain` on mount and
    # `event_arg("unmount")` on unmount; the "unmount" button redirects to "/".
    def on_mount_return_chain():
        return rx.fragment(
            rx.text(
                "return",
                on_mount=State.on_load_return_chain,
                on_unmount=lambda: State.event_arg("unmount"),  # type: ignore
            ),
            token_input,
            rx.button("Unmount", on_click=rx.redirect("/"), id="unmount"),
        )

    # Page whose text element runs two on_mount events (`on_load_yield_chain`
    # then `event_arg("mount")`) and `event_no_args` on unmount.
    def on_mount_yield_chain():
        return rx.fragment(
            rx.text(
                "yield",
                on_mount=[
                    State.on_load_yield_chain,
                    lambda: State.event_arg("mount"),  # type: ignore
                ],
                on_unmount=State.event_no_args,
            ),
            token_input,
            rx.button("Unmount", on_click=rx.redirect("/"), id="unmount"),
        )

    # Register the remaining pages; only the first two get an on_load handler.
    app.add_page(on_load_return_chain, on_load=State.on_load_return_chain)  # type: ignore
    app.add_page(on_load_yield_chain, on_load=State.on_load_yield_chain)  # type: ignore
    app.add_page(on_mount_return_chain)
    app.add_page(on_mount_yield_chain)


@pytest.fixture(scope="module")
def event_chain(tmp_path_factory) -> Generator[AppHarness, None, None]:
    """Run the EventChain app under an AppHarness for the whole test module.

    Args:
        tmp_path_factory: pytest tmp_path_factory fixture, used to create the
            temporary directory that serves as the app root.

    Yields:
        The AppHarness instance, active for as long as the fixture is in use.
    """
    app_root = tmp_path_factory.mktemp("event_chain")
    harness_ctx = AppHarness.create(
        root=app_root,
        app_source=EventChain,  # type: ignore
    )
    with harness_ctx as running_harness:
        yield running_harness


@pytest.fixture
def driver(event_chain: AppHarness) -> Generator[WebDriver, None, None]:
    """Provide a browser session opened on the event_chain app.

    Args:
        event_chain: harness for EventChain app

    Yields:
        WebDriver instance; it is quit once the test is done with it.
    """
    assert event_chain.app_instance is not None, "app is not running"
    browser = event_chain.frontend()
    try:
        yield browser
    finally:
        # always shut the browser down, even when the test body failed
        browser.quit()


def assert_token(event_chain: AppHarness, driver: WebDriver) -> str:
    """Read the client token from the page and build the backend state key.

    Args:
        event_chain: harness for EventChain app.
        driver: WebDriver instance.

    Returns:
        The token shown in the browser's "token" input, joined by an
        underscore with the harness's full state name for "_state".
    """
    assert event_chain.app_instance is not None

    token_element = driver.find_element(By.ID, "token")
    assert token_element

    # the input stays empty until the backend connection delivers the token
    client_token = event_chain.poll_for_value(token_element)
    assert client_token is not None

    full_state_name = event_chain.get_full_state_name(["_state"])
    return f"{client_token}_{full_state_name}"


@pytest.mark.parametrize(
    ("button_id", "exp_event_order"),
    [
        ("return_event", ["click_return_event", "event_no_args"]),
        (
            "return_events",
            ["click_return_events", "event_arg:7", "event_arg:8", "event_arg:9"],
        ),
        (
            "yield_chain",
            [
                "click_yield_chain:0",
                "click_yield_chain:1",
                "click_yield_chain:2",
                "click_yield_chain:3",
                "event_arg:10",
                "event_arg:11",
                "event_arg:12",
            ],
        ),
        (
            "yield_many_events",
            [
                "click_yield_many_events",
                "click_yield_many_events_done",
                *(f"event_arg:{n}" for n in range(MANY_EVENTS)),
            ],
        ),
        (
            "yield_nested",
            [
                "click_yield_nested",
                "event_nested_1",
                "event_arg:yield_nested",
                "event_nested_2",
                "event_arg:nested_1",
                "event_nested_3",
                "event_arg:nested_2",
                "event_no_args",
                "event_arg:nested_3",
            ],
        ),
        (
            "redirect_return_chain",
            [
                "redirect_return_chain",
                "on_load_return_chain",
                "event_arg:1",
                "event_arg:2",
                "event_arg:3",
            ],
        ),
        (
            "redirect_yield_chain",
            [
                "redirect_yield_chain",
                "on_load_yield_chain",
                "event_arg:4",
                "event_arg:5",
                "event_arg:6",
            ],
        ),
        ("click_int_type", ["event_arg_repr:1_int"]),
        ("click_dict_type", ["event_arg_repr:{'a': 1}_dict"]),
        ("return_int_type", ["click_return_int_type", "event_arg_repr:1_int"]),
        (
            "return_dict_type",
            ["click_return_dict_type", "event_arg_repr:{'a': 1}_dict"],
        ),
    ],
)
@pytest.mark.asyncio
async def test_event_chain_click(
    event_chain: AppHarness,
    driver: WebDriver,
    button_id: str,
    exp_event_order: list[str],
):
    """Press a button and verify the order in which the backend ran the events.

    Args:
        event_chain: AppHarness for the event_chain app
        driver: selenium WebDriver open to the app
        button_id: the ID of the button to click
        exp_event_order: the expected events recorded in the State
    """
    token = assert_token(event_chain, driver)
    state_name = event_chain.get_state_name("_state")
    driver.find_element(By.ID, button_id).click()

    async def _recorded_events():
        # fetch a fresh copy of the backend state on every call
        root_state = await event_chain.get_state(token)
        return root_state.substates[state_name].event_order

    async def _has_all_events():
        return len(await _recorded_events()) == len(exp_event_order)

    # wait until the number of recorded events matches, then compare content
    await AppHarness._poll_for_async(_has_all_events)
    recorded = await _recorded_events()
    assert recorded == exp_event_order


@pytest.mark.parametrize(
    ("uri", "exp_event_order"),
    [
        (
            "/on-load-return-chain",
            ["on_load_return_chain", "event_arg:1", "event_arg:2", "event_arg:3"],
        ),
        (
            "/on-load-yield-chain",
            ["on_load_yield_chain", "event_arg:4", "event_arg:5", "event_arg:6"],
        ),
    ],
)
@pytest.mark.asyncio
async def test_event_chain_on_load(
    event_chain: AppHarness,
    driver: WebDriver,
    uri: str,
    exp_event_order: list[str],
):
    """Open a page with an on_load handler and verify the resulting event order.

    Also checks that the backend state reports itself as hydrated afterwards.

    Args:
        event_chain: AppHarness for the event_chain app
        driver: selenium WebDriver open to the app
        uri: the page to load
        exp_event_order: the expected events recorded in the State
    """
    assert event_chain.frontend_url is not None
    page_url = event_chain.frontend_url + uri
    driver.get(page_url)
    token = assert_token(event_chain, driver)
    state_name = event_chain.get_state_name("_state")

    async def _substate():
        # fetch a fresh copy of the backend state on every call
        root_state = await event_chain.get_state(token)
        return root_state.substates[state_name]

    async def _has_all_events():
        current = await _substate()
        return len(current.event_order) == len(exp_event_order)

    await AppHarness._poll_for_async(_has_all_events)
    final_state = await _substate()
    assert final_state.event_order == exp_event_order
    assert final_state.is_hydrated is True


@pytest.mark.parametrize(
    ("uri", "exp_event_order"),
    [
        (
            "/on-mount-return-chain",
            [
                "on_load_return_chain",
                "event_arg:unmount",
                "on_load_return_chain",
                "event_arg:1",
                "event_arg:2",
                "event_arg:3",
                "event_arg:1",
                "event_arg:2",
                "event_arg:3",
                "event_arg:unmount",
            ],
        ),
        (
            "/on-mount-yield-chain",
            [
                "on_load_yield_chain",
                "event_arg:mount",
                "event_no_args",
                "on_load_yield_chain",
                "event_arg:mount",
                "event_arg:4",
                "event_arg:5",
                "event_arg:6",
                "event_arg:4",
                "event_arg:5",
                "event_arg:6",
                "event_no_args",
            ],
        ),
    ],
)
@pytest.mark.asyncio
async def test_event_chain_on_mount(
    event_chain: AppHarness,
    driver: WebDriver,
    uri: str,
    exp_event_order: list[str],
):
    """Open a page, click its unmount button, and verify the event order.

    The pages under test attach handlers to `on_mount` and `on_unmount`.
    Because react StrictMode is used in dev mode, those triggers fire twice
    there; in prod mode they fire only once. The expected orders in the
    parametrization reflect the doubled dev-mode behavior.

    Args:
        event_chain: AppHarness for the event_chain app
        driver: selenium WebDriver open to the app
        uri: the page to load
        exp_event_order: the expected events recorded in the State
    """
    assert event_chain.frontend_url is not None
    page_url = event_chain.frontend_url + uri
    driver.get(page_url)
    token = assert_token(event_chain, driver)
    state_name = event_chain.get_state_name("_state")

    # leaving the page via the button is what triggers the on_unmount events
    leave_button = driver.find_element(By.ID, "unmount")
    assert leave_button
    leave_button.click()

    async def _recorded_events():
        # fetch a fresh copy of the backend state on every call
        root_state = await event_chain.get_state(token)
        return root_state.substates[state_name].event_order

    async def _has_all_events():
        return len(await _recorded_events()) == len(exp_event_order)

    await AppHarness._poll_for_async(_has_all_events)
    recorded = await _recorded_events()
    assert recorded == exp_event_order


@pytest.mark.parametrize(
    ("button_id",),
    [
        ("click_yield_interim_value_async",),
        ("click_yield_interim_value",),
    ],
)
def test_yield_state_update(event_chain: AppHarness, driver: WebDriver, button_id: str):
    """Verify the page shows the interim value first and the final value after.

    Args:
        event_chain: AppHarness for the event_chain app
        driver: selenium WebDriver open to the app
        button_id: the ID of the button to click
    """
    value_input = driver.find_element(By.ID, "interim_value")
    assert_token(event_chain, driver)

    driver.find_element(By.ID, button_id).click()

    # first change away from the empty initial value
    first_seen = event_chain.poll_for_value(value_input, exp_not_equal="")
    assert first_seen == "interim"

    # next change away from the interim value
    last_seen = event_chain.poll_for_value(value_input, exp_not_equal="interim")
    assert last_seen == "final"