#   Copyright 2022 - 2025 The PyMC Labs Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
"""Plot distributions stored in xarray.DataArray across coordinates.
Used to plot the prior and posterior of the various MMM components.
See the :func:`plot_curve` function for more information.
"""
import warnings
from collections.abc import Callable, Generator, Iterable, MutableMapping, Sequence
from itertools import product, repeat
from typing import Any, Concatenate, ParamSpec, cast
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import pandas as pd
import xarray as xr
from matplotlib.axes import Axes
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
# Values of a single coordinate: any sequence or a NumPy array.
Values = Sequence[Any] | npt.NDArray[Any]
# Mapping from coordinate name to the values of that coordinate.
Coords = dict[str, Values]
[docs]
def get_plot_coords(coords: Coords, non_grid_names: set[str]) -> Coords:
    """Select the coordinates that span the plot grid.

    Parameters
    ----------
    coords : Coords
        All available coordinates, keyed by name.
    non_grid_names : set[str]
        Coordinate names that must not become part of the grid.

    Returns
    -------
    Coords
        The remaining coordinates, in their original order, with each set of
        values converted to a NumPy array.

    """
    return {
        name: np.array(coords[name])
        for name in coords
        if name not in non_grid_names
    }
[docs]
def drop_scalar_coords(curve: xr.DataArray) -> xr.DataArray:
    """Return ``curve`` without its non-index, single-valued coordinates.

    A coordinate is dropped when it holds exactly one value and is not one
    of the array's indexes. The input is not modified; the result of
    ``reset_coords(..., drop=True)`` is returned.

    Parameters
    ----------
    curve : xr.DataArray
        The array whose scalar coordinates should be removed.

    Returns
    -------
    xr.DataArray
        The array with the scalar coordinates removed.

    """
    indexed = curve.indexes
    droppable = [
        name
        for name, values in curve.coords.items()
        if values.size == 1 and name not in indexed
    ]
    return curve.reset_coords(droppable, drop=True)
[docs]
def get_total_coord_size(coords: Coords) -> int:
    """Get the total number of coordinate combinations.

    Parameters
    ----------
    coords : Coords
        The coordinates to get the total size of.

    Returns
    -------
    int
        The product of the lengths of all coordinates, as a builtin ``int``.
        An empty mapping has a total size of 1.

    Warns
    -----
    UserWarning
        If the total size is 12 or more.

    """
    sizes = [len(values) for values in coords.values()]
    # ``np.prod`` returns a NumPy scalar; coerce it so the annotated ``int``
    # return type actually holds.
    total_size = int(np.prod(sizes)) if sizes else 1

    if total_size >= 12:
        warnings.warn("Large number of coordinates!", stacklevel=2)

    return total_size
[docs]
def create_legend_handles(
    colors: Iterable[str],
    alpha: float = 0.5,
    line: bool = True,
    patch: bool = True,
) -> list[Line2D | Patch | tuple[Line2D, Patch]]:
    """Build one legend handle per color.

    Parameters
    ----------
    colors : Iterable[str]
        The colors to create handles for.
    alpha : float, optional
        Transparency applied to the patches, by default 0.5.
    line : bool, optional
        Whether each handle contains a line, by default True.
    patch : bool, optional
        Whether each handle contains a patch, by default True.

    Returns
    -------
    list[Line2D | Patch | tuple[Line2D, Patch]]
        A ``(line, patch)`` tuple per color when both are requested,
        otherwise the single requested artist per color.

    Raises
    ------
    ValueError
        If both ``line`` and ``patch`` are False.

    """
    if not (line or patch):
        raise ValueError("At least one of line or patch must be True")

    def line_handle(color: str) -> Line2D:
        return Line2D([0], [0], color=color)

    def patch_handle(color: str) -> Patch:
        return Patch(color=color, alpha=alpha)

    if line and patch:
        return [(line_handle(color), patch_handle(color)) for color in colors]
    if line:
        return [line_handle(color) for color in colors]
    return [patch_handle(color) for color in colors]
[docs]
def set_subplot_kwargs_defaults(
    subplot_kwargs: MutableMapping[str, Any],
    total_size: int,
) -> None:
    """Fill in ``ncols`` and ``nrows`` of the subplot kwargs in place.

    If neither is given, a single row with ``total_size`` columns is used.
    If one is given, the other is the smallest count that gives the grid at
    least ``total_size`` axes.

    Parameters
    ----------
    subplot_kwargs : MutableMapping[str, Any]
        The subplot kwargs to set the defaults for. Modified in place.
    total_size : int
        The number of axes required.

    Raises
    ------
    ValueError
        If both `ncols` and `nrows` are specified.

    """
    has_ncols = "ncols" in subplot_kwargs
    has_nrows = "nrows" in subplot_kwargs

    if has_ncols and has_nrows:
        raise ValueError("Only specify one of ncols or nrows")

    # ``-(-a // b)`` is integer ceiling division. Rounding up (rather than
    # down) guarantees nrows * ncols >= total_size, so no panel is silently
    # dropped and the derived dimension is never 0.
    if has_nrows:
        subplot_kwargs["ncols"] = -(-total_size // subplot_kwargs["nrows"])
        return

    ncols = subplot_kwargs.setdefault("ncols", total_size)
    subplot_kwargs["nrows"] = -(-total_size // ncols)
# One point of the coordinate grid: coordinate name -> selected value.
Selection = dict[str, Any]
[docs]
def selections(
    coords: Coords,
) -> Generator[Selection, None, None]:
    """Iterate over every combination of coordinate values.

    Parameters
    ----------
    coords : Coords
        The coordinates to combine.

    Yields
    ------
    dict[str, Any]
        One selection per element of the Cartesian product of the coordinate
        values, mapping each coordinate name to its chosen value.

    """
    for combination in product(*coords.values()):
        yield dict(zip(coords, combination, strict=True))
# Parameter specification for the extra arguments a plotting callable accepts.
P = ParamSpec("P")
# Turns the curve into the array that is actually plotted.
GetPlotData = Callable[[xr.DataArray], xr.DataArray]
# Extracts the data of one selection as a DataFrame.
MakeSelection = Callable[[xr.DataArray, Selection], pd.DataFrame]
# Draws a DataFrame on an Axes in a given color (plus extra arguments) and
# returns the Axes.
PlotSelection = Callable[Concatenate[pd.DataFrame, Axes, str, P], Axes]
def _get_sample_plot_data(data):
    return data
def _create_make_sample_selection(
    rng,
    n: int,
    n_chains: int,
    n_draws: int,
) -> MakeSelection:
    """Create the selection function used for sample plots.

    The ``(chain, draw)`` pairs are drawn once, here, so every selection
    produced by the returned function uses the same pairs.

    Parameters
    ----------
    rng
        Random number generator; a falsy value falls back to
        ``np.random.default_rng()``.
    n : int
        Number of pairs to draw.
    n_chains : int
        Number of chains to draw from.
    n_draws : int
        Number of draws to draw from.

    Returns
    -------
    MakeSelection
        Function taking the data and a selection and returning the sampled
        rows, transposed.

    """
    generator = rng or np.random.default_rng()
    sampled_pairs = random_samples(
        generator,
        n=n,
        n_chains=n_chains,
        n_draws=n_draws,
    )

    def make_sample_selection(data, sel):
        wide = data.sel(sel).to_series().unstack()
        # NOTE(review): ``.loc`` looks the pairs up by label, so this
        # presumably expects chain/draw labels to be 0..size-1 - confirm.
        return wide.loc[sampled_pairs, :].T

    return make_sample_selection
def _plot_sample_selection(df, ax: Axes, color: str, **plot_kwargs) -> Axes:
    """Plot ``df`` on ``ax`` in ``color`` through its ``plot`` method.

    Extra keyword arguments are forwarded to ``df.plot``, whose return value
    is returned.
    """
    drawn_on = df.plot(ax=ax, color=color, **plot_kwargs)
    return drawn_on
def _create_get_hdi_plot_data(hdi_kwargs) -> GetPlotData:
    """Create a function computing the HDI of a data array.

    ``hdi_kwargs`` are forwarded to ``az.hdi`` on every call of the returned
    function.
    """

    def get_plot_data(data: xr.DataArray) -> xr.DataArray:
        intervals: xr.Dataset = az.hdi(data, **hdi_kwargs)
        # ``az.hdi`` is annotated as returning a Dataset here; pick out the
        # variable named after the input array.
        return intervals[data.name]

    return get_plot_data
def _make_hdi_selection(data: xr.DataArray, sel: dict[str, Any]) -> pd.DataFrame:
    """Select ``sel`` from ``data`` and unstack the resulting series."""
    selected = data.sel(sel)
    return selected.to_series().unstack()
def _plot_hdi_selection(
    df: pd.DataFrame,
    ax: Axes,
    color: str,
    **plot_kwargs,
) -> Axes:
    """Shade the band between the ``lower`` and ``higher`` columns of ``df``.

    The frame's index is used as the x values. Extra keyword arguments are
    forwarded to ``ax.fill_between``. The same ``ax`` is returned.
    """
    lower_bound, upper_bound = df["lower"], df["higher"]
    ax.fill_between(
        x=df.index,
        y1=lower_bound,
        y2=upper_bound,
        color=color,
        **plot_kwargs,
    )
    return ax
# Formats a selection as the text used for subplot titles and legend labels.
SelToString = Callable[[Selection], str]
[docs]
def random_samples(
    rng: np.random.Generator,
    n: int,
    n_chains: int,
    n_draws: int,
) -> list[tuple[int, int]]:
    """Draw distinct ``(chain, draw)`` index pairs at random.

    Parameters
    ----------
    rng : np.random.Generator
        Random number generator
    n : int
        Number of pairs to draw
    n_chains : int
        Number of chains
    n_draws : int
        Number of draws

    Returns
    -------
    list[tuple[int, int]]
        ``n`` pairs chosen without replacement from all combinations of
        ``range(n_chains)`` and ``range(n_draws)``.

    """
    chain_draw_pairs = list(product(range(n_chains), range(n_draws)))
    chosen = rng.choice(chain_draw_pairs, size=n, replace=False)
    return [tuple(row) for row in chosen]
[docs]
def generate_colors(n: int, start: int = 0) -> list[str]:
    """Build a list of consecutive ``"C<i>"`` color names.

    Parameters
    ----------
    n : int
        Number of colors to generate
    start : int, optional
        Index of the first color, by default 0

    Returns
    -------
    list[str]
        The names ``"C<start>"`` up to ``"C<start + n - 1>"``

    Examples
    --------
    Generate 5 colors starting from index 1

    .. code-block:: python

        colors = generate_colors(5, start=1)
        print(colors)
        # ['C1', 'C2', 'C3', 'C4', 'C5']

    """
    return [f"C{start + offset}" for offset in range(n)]
def _plot_across_coord(
    curve: xr.DataArray,
    non_grid_names: set[str],
    get_plot_data: GetPlotData,
    make_selection: MakeSelection,
    plot_selection: PlotSelection,
    subplot_kwargs: dict | None = None,
    axes: npt.NDArray[Axes] | None = None,
    same_axes: bool = False,
    colors: Iterable[str] | None = None,
    legend: bool = False,
    plot_kwargs: dict[str, Any] | None = None,
    patch: bool = True,
    line: bool = True,
    sel_to_string: SelToString | None = None,
) -> tuple[plt.Figure, npt.NDArray[Axes]]:
    """Plot data array across coords.

    Commonality used for the `plot_samples` and `plot_hdi` functions.
    Differences depending on the `get_plot_data`, `make_selection` and
    `plot_selection` functions passed.

    Allows for plotting each coordinate combination on a separate axis
    or on the same axis.

    Parameters
    ----------
    curve : xr.DataArray
        Data to plot. Scalar coordinates are dropped first.
    non_grid_names : set[str]
        Coordinate names excluded from the grid. ``chain``, ``draw`` and
        ``hdi`` are always excluded as well.
    get_plot_data : GetPlotData
        Applied once to the curve to obtain the array that is plotted.
    make_selection : MakeSelection
        Extracts the frame of one coordinate combination.
    plot_selection : PlotSelection
        Draws one frame on one axis.
    subplot_kwargs : dict, optional
        Passed to ``plt.subplots`` when new, separate axes are created.
        ``sharex`` and ``sharey`` default to True.
    axes : npt.NDArray[Axes], optional
        Existing axes to draw on. With ``same_axes`` only the first is used.
    same_axes : bool
        Draw every combination on one axis instead of one axis each.
    colors : Iterable[str], optional
        One color per combination. Defaults to ``C0``, ``C1``, ...
    legend : bool
        Add a legend; only has an effect together with ``same_axes``.
    plot_kwargs : dict, optional
        Extra keyword arguments for ``plot_selection``.
    patch, line : bool
        Which artists the legend handles contain.
    sel_to_string : SelToString, optional
        Formats a combination for titles and legend labels. Defaults to
        ``"key=value"`` pairs joined by ``", "``.

    Returns
    -------
    tuple[plt.Figure, npt.NDArray[Axes]]
        The figure and the axes.

    """
    if sel_to_string is None:

        def sel_to_string(sel):
            return ", ".join(f"{key}={value}" for key, value in sel.items())

    def no_title(sel):
        return ""

    # The default is None, which cannot be unpacked with ``**`` below.
    plot_kwargs = plot_kwargs or {}

    curve = drop_scalar_coords(curve)
    data = get_plot_data(curve)
    plot_coords = get_plot_coords(
        data.coords,
        non_grid_names=non_grid_names.union({"chain", "draw", "hdi"}),
    )
    total_size = get_total_coord_size(plot_coords)

    # Set only when everything is drawn on one axis; also hosts the legend.
    shared_ax = None

    if axes is None and not same_axes:
        # New figure with one axis per combination.
        subplot_kwargs = {"sharey": True, "sharex": True, **(subplot_kwargs or {})}
        set_subplot_kwargs_defaults(subplot_kwargs, total_size)
        fig, axes = plt.subplots(**subplot_kwargs)
        axes_iter = np.ravel(axes)
        return_axes = axes
        create_title = sel_to_string
        create_legend_label = None
    elif axes is not None and same_axes:
        # Existing axes, everything on the first one. ``np.ravel`` makes this
        # work for a bare Axes and for grids of any dimensionality, where
        # plain ``axes[0]`` would fail or return a whole row.
        fig = plt.gcf()
        shared_ax = np.ravel(axes)[0]  # type: ignore
        axes_iter = repeat(shared_ax, total_size)  # type: ignore
        return_axes = np.array([axes]) if not isinstance(axes, np.ndarray) else axes
        create_title = no_title
        create_legend_label = sel_to_string
    elif axes is None and same_axes:
        # New figure with a single axis shared by all combinations.
        fig, shared_ax = plt.subplots(ncols=1, nrows=1)
        axes_iter = repeat(shared_ax, total_size)  # type: ignore
        return_axes = np.array([shared_ax])
        create_title = no_title
        create_legend_label = sel_to_string
    else:
        # Existing axes, one per combination.
        fig = plt.gcf()
        axes_iter = np.ravel(axes)  # type: ignore
        return_axes = np.array([axes]) if not isinstance(axes, np.ndarray) else axes
        create_title = sel_to_string  # type: ignore
        create_legend_label = None

    # Materialise the colors: they are iterated twice (plots and legend), so
    # a one-shot iterator would be exhausted, and a NumPy array has no
    # unambiguous truth value.
    colors = list(colors) if colors is not None else []
    if not colors:
        colors = generate_colors(n=total_size, start=0)

    for color, ax, sel in zip(colors, axes_iter, selections(plot_coords), strict=False):
        ax = data.pipe(make_selection, sel=sel).pipe(
            plot_selection,
            ax=ax,
            color=color,
            **plot_kwargs,
        )
        title = create_title(sel)
        ax.set_title(title)

    if (
        same_axes
        and legend
        and create_legend_label is not None
        and shared_ax is not None
    ):
        handles = create_legend_handles(colors, patch=patch, line=line)
        labels = [create_legend_label(sel) for sel in selections(plot_coords)]
        shared_ax.legend(handles=handles, labels=labels)

    return fig, return_axes
[docs]
def plot_hdi(
    curve: xr.DataArray,
    non_grid_names: str | set[str],
    hdi_prob: float | None = None,
    hdi_kwargs: dict | None = None,
    subplot_kwargs: dict[str, Any] | None = None,
    plot_kwargs: dict[str, Any] | None = None,
    axes: npt.NDArray[Axes] | None = None,
    same_axes: bool = False,
    colors: Iterable[str] | None = None,
    legend: bool = False,
    sel_to_string: SelToString | None = None,
) -> tuple[plt.Figure, npt.NDArray[Axes]]:
    """Plot hdi of the curve across coords.

    Parameters
    ----------
    curve : xr.DataArray
        Curve to plot
    non_grid_names : str | set[str]
        The names to exclude from the grid. chain, draw and hdi are
        excluded automatically
    hdi_prob : float, optional
        Probability of the interval, forwarded to ``az.hdi``
    hdi_kwargs : dict, optional
        Additional kwargs for ``az.hdi``. Entries here take precedence
        over ``hdi_prob``
    subplot_kwargs : dict, optional
        Additional kwargs to while creating the fig and axes
    plot_kwargs : dict, optional
        Kwargs for the ``fill_between`` call. ``alpha`` defaults to 0.25
    axes : npt.NDArray[plt.Axes], optional
        Axes to plot on
    same_axes : bool
        All of the plots in the same axis
    colors : Iterable[str], optional
        Colors for the plots
    legend : bool
        Whether to add a legend. Only has an effect with ``same_axes``
    sel_to_string : Callable[[Selection], str], optional
        Function to convert a selection to a string

    Returns
    -------
    tuple[plt.Figure, npt.NDArray[plt.Axes]]
        Figure and the axes

    """
    if isinstance(non_grid_names, str):
        non_grid_names = {non_grid_names}

    # Later keys win, so user-supplied entries override the defaults.
    merged_hdi_kwargs = {"hdi_prob": hdi_prob, **(hdi_kwargs or {})}
    band_kwargs = {"alpha": 0.25, **(plot_kwargs or {})}

    return _plot_across_coord(
        curve=curve,
        non_grid_names=non_grid_names,
        get_plot_data=_create_get_hdi_plot_data(merged_hdi_kwargs),
        make_selection=_make_hdi_selection,
        plot_selection=_plot_hdi_selection,
        subplot_kwargs=subplot_kwargs,
        same_axes=same_axes,
        axes=axes,
        colors=colors,
        legend=legend,
        plot_kwargs=band_kwargs,
        patch=True,
        line=False,
        sel_to_string=sel_to_string,
    )
[docs]
def plot_samples(
    curve: xr.DataArray,
    non_grid_names: str | set[str],
    n: int = 10,
    rng: np.random.Generator | None = None,
    axes: npt.NDArray[Axes] | None = None,
    subplot_kwargs: dict[str, Any] | None = None,
    plot_kwargs: dict[str, Any] | None = None,
    same_axes: bool = False,
    colors: Iterable[str] | None = None,
    legend: bool = False,
    sel_to_string: SelToString | None = None,
) -> tuple[plt.Figure, npt.NDArray[Axes]]:
    """Plot n samples of the curve across coords.

    Parameters
    ----------
    curve : xr.DataArray
        Curve to plot. Must have ``chain`` and ``draw`` dimensions
    non_grid_names : str | set[str]
        The names to exclude from the grid. chain and draw are
        excluded automatically
    n : int, optional
        Number of samples to plot
    rng : np.random.Generator, optional
        Random number generator
    axes : npt.NDArray[plt.Axes], optional
        Axes to plot on
    subplot_kwargs : dict, optional
        Additional kwargs to while creating the fig and axes
    plot_kwargs : dict, optional
        Kwargs for the plot function. ``alpha`` defaults to 0.3 and
        ``legend`` to False
    same_axes : bool
        All of the plots in the same axis
    colors : Iterable[str], optional
        Colors for the plots
    legend : bool
        Whether to add a legend. Only has an effect with ``same_axes``
    sel_to_string : Callable[[Selection], str], optional
        Function to convert a selection to a string

    Returns
    -------
    tuple[plt.Figure, npt.NDArray[plt.Axes]]
        Figure and the axes

    """
    if isinstance(non_grid_names, str):
        non_grid_names = {non_grid_names}

    dim_sizes = curve.sizes
    sample_selector = _create_make_sample_selection(
        rng=rng,
        n=n,
        n_chains=dim_sizes["chain"],
        n_draws=dim_sizes["draw"],
    )

    # Later keys win, so user-supplied entries override the defaults.
    line_kwargs = {"alpha": 0.3, "legend": False, **(plot_kwargs or {})}

    return _plot_across_coord(
        curve=curve,
        non_grid_names=non_grid_names,
        get_plot_data=_get_sample_plot_data,
        make_selection=sample_selector,
        plot_selection=_plot_sample_selection,
        subplot_kwargs=subplot_kwargs,
        plot_kwargs=line_kwargs,
        same_axes=same_axes,
        axes=axes,
        colors=colors,
        legend=legend,
        patch=False,
        line=True,
        sel_to_string=sel_to_string,
    )
[docs]
def plot_curve(
    curve: xr.DataArray,
    non_grid_names: str | set[str],
    n_samples: int = 10,
    hdi_probs: float | list[float] | None = None,
    random_seed: np.random.Generator | None = None,
    subplot_kwargs: dict | None = None,
    sample_kwargs: dict | None = None,
    hdi_kwargs: dict | None = None,
    axes: npt.NDArray[Axes] | None = None,
    same_axes: bool = False,
    colors: Iterable[str] | None = None,
    legend: bool | None = None,
    sel_to_string: SelToString | None = None,
) -> tuple[plt.Figure, npt.NDArray[Axes]]:
    """Plot HDI with samples of the curve across coords.

    Parameters
    ----------
    curve : xr.DataArray
        Curve to plot
    non_grid_names : str | set[str]
        The names to exclude from the grid. HDI and samples both
        have defaults of hdi and chain, draw, respectively
    n_samples : int, optional
        Number of samples
    hdi_probs : float | list[float], optional
        HDI probabilities; a list or tuple plots one HDI per entry.
        Defaults to None which uses arviz default for
        stats.ci_prob which is 94%
    random_seed : np.random.Generator, optional
        Random number generator. Defaults to None which uses
        np.random.default_rng()
    subplot_kwargs : dict, optional
        Additional kwargs to while creating the fig and axes
    sample_kwargs : dict, optional
        Kwargs for the :func:`plot_samples` function
    hdi_kwargs : dict, optional
        Kwargs for the :func:`plot_hdi` function. The dictionary passed
        in is not modified
    axes : npt.NDArray[plt.Axes], optional
        Axes to plot on
    same_axes : bool
        If all of the plots are on the same axis
    colors : Iterable[str], optional
        Colors for the plots
    legend : bool, optional
        If to include a legend. Defaults to True if same_axes
    sel_to_string : Callable[[Selection], str], optional
        Function to convert selection to a string. Defaults to
        ", ".join(f"{key}={value}" for key, value in sel.items())

    Returns
    -------
    tuple[plt.Figure, npt.NDArray[plt.Axes]]
        Figure and the axes

    Examples
    --------
    Plot prior for arbitrary Deterministic in PyMC model

    .. plot::
        :include-source: True
        :context: reset

        import numpy as np
        import pandas as pd
        import pymc as pm
        import matplotlib.pyplot as plt
        from pymc_marketing.plot import plot_curve

        seed = sum(map(ord, "Arbitrary curve"))
        rng = np.random.default_rng(seed)

        dates = pd.date_range("2024-01-01", periods=52, freq="W")
        coords = {"date": dates, "product": ["A", "B"]}

        with pm.Model(coords=coords) as model:
            data = pm.Normal(
                "data",
                mu=[-0.5, 0.5],
                sigma=1,
                dims=("date", "product"),
            )
            cumsum = pm.Deterministic(
                "cumsum",
                data.cumsum(axis=0),
                dims=("date", "product"),
            )
            idata = pm.sample_prior_predictive(random_seed=rng)

        curve = idata.prior["cumsum"]

        fig, axes = plot_curve(
            curve,
            "date",
            subplot_kwargs={"figsize": (15, 5)},
            random_seed=rng,
        )
        plt.show()

    Choose the HDI intervals and number of samples

    .. plot::
        :include-source: True
        :context: reset

        fig, axes = plot_curve(
            curve,
            "date",
            n_samples=3,
            hdi_probs=[0.5, 0.95],
            random_seed=rng,
        )
        fig.suptitle("Same data but fewer lines and more HDIs")
        plt.show()

    Plot same curve on same axes with custom colors

    .. plot::
        :include-source: True
        :context: close-figs

        colors = ["red", "blue"]
        fig, axes = plot_curve(
            curve,
            "date",
            same_axes=True,
            colors=colors,
            random_seed=rng,
        )
        axes[0].set(title="Same data but on same axes and custom colors")
        plt.show()

    """
    curve = drop_scalar_coords(curve)

    # Normalise to a list of probabilities. Any falsy value (None, 0, empty)
    # becomes a single None, i.e. one HDI at the arviz default probability.
    hdi_probs = hdi_probs or None
    if isinstance(hdi_probs, (list, tuple)):
        hdi_prob_list = list(hdi_probs)
    else:
        hdi_prob_list = [hdi_probs]

    # Work on copies: keys are added below and the caller's dictionaries
    # must not be modified as a side effect.
    hdi_kwargs = dict(hdi_kwargs or {})
    sample_kwargs = {"n": n_samples, "rng": random_seed, **(sample_kwargs or {})}

    sample_kwargs.setdefault("subplot_kwargs", subplot_kwargs)
    sample_kwargs.setdefault("axes", axes)

    if same_axes:
        sample_kwargs["same_axes"] = True
        # Only the HDI layer draws the legend, to avoid duplicate entries.
        sample_kwargs["legend"] = False
        hdi_kwargs["same_axes"] = True
        hdi_kwargs["legend"] = legend if isinstance(legend, bool) else True

    if colors is not None:
        sample_kwargs["colors"] = colors
        hdi_kwargs["colors"] = colors

    if sel_to_string is not None:
        sample_kwargs["sel_to_string"] = sel_to_string
        hdi_kwargs["sel_to_string"] = sel_to_string

    fig, axes = plot_samples(
        curve,
        non_grid_names=non_grid_names,
        **sample_kwargs,
    )

    # Every HDI is layered onto the axes the samples were drawn on.
    for hdi_prob in hdi_prob_list:
        fig, axes = plot_hdi(
            curve,
            hdi_prob=hdi_prob,
            non_grid_names=non_grid_names,
            axes=axes,
            **hdi_kwargs,
        )

    return fig, axes