#   Copyright 2022 - 2025 The PyMC Labs Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
"""Multivariate Interrupted Time Series Analysis for Product Incrementality."""
import json
from typing import Any, Self, cast
import arviz as az
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import numpy as np
import pandas as pd
import pymc as pm
from matplotlib.axes import Axes
from pymc_extras.prior import Prior
from xarray import DataArray
from pymc_marketing.model_builder import ModelBuilder, create_idata_accessor
from pymc_marketing.model_config import parse_model_config
# Opacity (matplotlib ``alpha``) of the HDI bands drawn by the MVITS plotting methods.
HDI_ALPHA = 0.5
class MVITS(ModelBuilder):
    """Multivariate Interrupted Time Series class.

    Class to perform a multivariate interrupted time series analysis with the
    specific intent of determining where the sales of a new product came from.

    The sales of each existing product are modelled as a per-product intercept
    minus a share (``beta``) of the new product's sales. ``beta`` is drawn from a
    Dirichlet prior. In a saturated market the shares sum to 1. Otherwise an extra
    "new" source absorbs the remainder (stored as the ``"new sales"``
    deterministic), so the existing-product shares sum to less than 1.

    Parameters
    ----------
    existing_sales : list of str
        The names of the existing products.
    saturated_market : bool, optional
        Whether the market is saturated or not. If True, the sum of the beta's will be
        1. Else, the sum of the beta's will be less than 1 with the remaining sales
        attributed to the new product.
    model_config : dict, optional
        The model configuration. If None, the default model configuration will be used.
    sampler_config : dict, optional
        The sampler configuration. If None, the default sampler configuration will be used.
    """

    _model_type = "Multivariate Interrupted Time Series"
    version = "0.1.0"

    def __init__(
        self,
        existing_sales: list[str],
        saturated_market: bool = True,
        model_config: dict | None = None,
        sampler_config: dict | None = None,
    ):
        # These must be set before ``super().__init__``: ``default_model_config``
        # reads them (presumably the parent constructor consults it -- confirm).
        # Copy the list so later mutation by the caller cannot desync the coords.
        self.existing_sales = list(existing_sales)
        self.saturated_market = saturated_market
        model_config = parse_model_config(model_config or {})
        super().__init__(model_config=model_config, sampler_config=sampler_config)
        self._distribution_checks()

    def _distribution_checks(self) -> None:
        """Validate the ``market_distribution`` prior of the model configuration.

        Raises
        ------
        ValueError
            If the prior is not a Dirichlet, or if its dims do not include
            ``"existing_product"`` (saturated market) / ``"all_sources"``
            (non-saturated market).
        """
        market_distribution = self.model_config["market_distribution"]
        if market_distribution.distribution != "Dirichlet":
            raise ValueError("market_distribution must be a Dirichlet distribution")
        dims = "existing_product" if self.saturated_market else "all_sources"
        if dims not in market_distribution.dims:
            raise ValueError(
                f"market_distribution must have dims='{dims}', not {market_distribution.dims}"
            )

    def create_idata_attrs(self) -> dict[str, str]:
        """Create the attributes for the InferenceData object.

        Extends the parent attributes with JSON-encoded ``existing_sales`` and
        ``saturated_market`` so that :meth:`attrs_to_init_kwargs` can rebuild them.

        Returns
        -------
        dict[str, str]
            The attributes for the InferenceData object.
        """
        attrs = super().create_idata_attrs()
        attrs["existing_sales"] = json.dumps(self.existing_sales)
        attrs["saturated_market"] = json.dumps(self.saturated_market)
        return attrs

    @classmethod
    def attrs_to_init_kwargs(cls, attrs) -> dict[str, Any]:
        """Convert the attributes of the InferenceData object to the __init__ kwargs.

        Parameters
        ----------
        attrs : dict
            The attributes of the InferenceData object. Each relevant value is
            a JSON string.

        Returns
        -------
        dict
            The __init__ kwargs for the class.
        """
        return {
            "model_config": json.loads(attrs["model_config"]),
            "sampler_config": json.loads(attrs["sampler_config"]),
            "existing_sales": json.loads(attrs["existing_sales"]),
            "saturated_market": json.loads(attrs["saturated_market"]),
        }

    @property
    def default_model_config(self) -> dict:
        """Default model configuration.

        This is a TruncatedNormal (lower=0) likelihood with a HalfNormal sigma, a
        Normal intercept, and a Dirichlet market distribution with concentration
        0.5 for every source. There is one source per existing product, plus one
        "new" source when the market is not saturated.

        Returns
        -------
        dict
            The default model configuration.
        """
        if self.saturated_market:
            n_sources = len(self.existing_sales)
            dims = "existing_product"
        else:
            n_sources = len(self.existing_sales) + 1
            dims = "all_sources"
        market_distribution = Prior("Dirichlet", a=np.full(n_sources, 0.5), dims=dims)
        return {
            "intercept": Prior("Normal", dims="existing_product"),
            "likelihood": Prior(
                "TruncatedNormal",
                lower=0,
                sigma=Prior("HalfNormal", dims="existing_product"),
                dims=("time", "existing_product"),
            ),
            "market_distribution": market_distribution,
        }

    @property
    def default_sampler_config(self) -> dict:
        """Default sampler configuration (empty: the sampler's own defaults apply)."""
        return {}

    @property
    def output_var(self) -> str:
        """The output variable of the model."""
        return "y"

    @property
    def _serializable_model_config(self) -> dict[str, int | float | dict]:  # type: ignore
        """JSON-friendly version of the model configuration (each Prior as a dict)."""
        result: dict[str, int | float | dict] = {
            "intercept": self.model_config["intercept"].to_dict(),
            "likelihood": self.model_config["likelihood"].to_dict(),
            "market_distribution": self.model_config["market_distribution"].to_dict(),
        }
        return result

    def _generate_and_preprocess_model_data(
        self,
        X: pd.DataFrame | pd.Series,
        y: np.ndarray,
    ) -> None:
        """Store the model inputs and build the model coordinates.

        Sets ``self.X`` (only the ``existing_sales`` columns, in that order),
        ``self.y`` (new-product sales as a Series on ``X``'s index, named after
        :attr:`output_var`) and ``self.coords``.

        Raises
        ------
        ValueError
            If ``X`` is a Series.
        """
        if isinstance(X, pd.Series):
            raise ValueError("X must be a DataFrame, not a Series")  # pragma: no cover
        self.X = X[self.existing_sales]
        # Align ``y`` with ``X`` by position, the same way ``build_model`` feeds
        # it to the model. ``pd.Series(series, index=...)`` would instead reindex
        # by label and silently produce NaNs when ``y`` has a different index.
        self.y = pd.Series(np.asarray(y), index=X.index, name=self.output_var)
        # note: type hints for coords required for mypy to not get confused
        self.coords: dict[str, list[str]] = {
            "existing_product": list(self.existing_sales),
            "time": list(X.index.values),
            "all_sources": [
                *list(self.existing_sales),
                "new",
            ],
        }

    def build_model(
        self,
        X: pd.DataFrame,
        y: pd.Series | np.ndarray,
        **kwargs,
    ) -> None:
        """Build a PyMC model for a multivariate interrupted time series analysis.

        Parameters
        ----------
        X : pd.DataFrame
            The data for the existing products. Must contain a column for every
            name in ``existing_sales``; any other columns are ignored.
        y : np.ndarray | pd.Series
            The data for the new product (used positionally, one value per row of X).
        """
        self._generate_and_preprocess_model_data(X, y)  # type: ignore
        with pm.Model(coords=self.coords) as model:
            # data
            # Use the column subset stored above so extra or reordered columns in
            # ``X`` cannot misalign the data with the "existing_product" coordinate.
            existing_sales_data = pm.Data(
                "existing_sales",
                self.X.values,  # type: ignore
                dims=("time", "existing_product"),
            )
            treatment_sales = pm.Data(
                "treatment_sales",
                np.asarray(y),
                dims="time",
            )

            # priors
            intercept = self.model_config["intercept"].create_variable(name="intercept")

            if self.saturated_market:
                # Saturated market: the betas sum to 1, so the reduction in sales of
                # existing products equals the increase in sales of the new product
                # and total sales remain constant.
                beta = self.model_config["market_distribution"].create_variable("beta")
            else:
                # Non-saturated market: the last Dirichlet component is the share of
                # genuinely new sales, so the existing-product betas sum to less
                # than 1 and their losses are smaller than the new product's gain.
                beta_all = self.model_config["market_distribution"].create_variable(
                    "beta_all",
                )
                beta = pm.Deterministic(
                    "beta",
                    beta_all[:-1],
                    dims="existing_product",
                )
                pm.Deterministic("new sales", beta_all[-1])

            # expectation: each existing product loses its share of the new sales
            mu = pm.Deterministic(
                "mu",
                intercept[None, :] - treatment_sales[:, None] * beta[None, :],
                dims=("time", "existing_product"),
            )

            # likelihood
            self.model_config["likelihood"].create_likelihood_variable(
                name=self.output_var,
                mu=mu,
                observed=existing_sales_data,
            )

        self.model = model

    def _data_setter(
        self,
        X: np.ndarray | pd.DataFrame,
        y: np.ndarray | pd.Series | None = None,
    ) -> None:
        """Set the data.

        Required from the parent class. This implementation is deliberately a
        no-op: the ``X``/``y`` passed here are ignored, so the model's ``pm.Data``
        containers keep the values set in :meth:`build_model`.
        """
        return None

    def calculate_counterfactual(
        self,
        random_seed: np.random.Generator | int | None = None,
    ) -> None:
        """Calculate the counterfactual scenario of never releasing the new product.

        Sets ``treatment_sales`` to zero with ``pm.do``, then samples ``mu`` and
        :attr:`output_var` from that model using the fitted posterior. The result
        is stored in the ``predictions`` group of ``self.idata``. An existing
        ``predictions`` group is replaced, so re-running refreshes the result.

        Parameters
        ----------
        random_seed : np.random.Generator | int, optional
            The random seed for the sampling.

        Raises
        ------
        RuntimeError
            If the model has not been built yet.
        """
        if not hasattr(self, "model"):
            raise RuntimeError("Call the 'fit' method first.")
        zero_sales = np.zeros_like(self.y, dtype=np.int32)
        self.counterfactual_model = pm.do(self.model, {"treatment_sales": zero_sales})
        with self.counterfactual_model:
            counterfactual = pm.sample_posterior_predictive(
                self.posterior,
                var_names=["mu", self.output_var],
                random_seed=random_seed,
                predictions=True,
            )
        # join="right" lets the new predictions replace a previous group; the
        # default ("left") would silently keep stale predictions on re-runs.
        self.idata.extend(counterfactual, join="right")  # type: ignore

    def sample(
        self,
        X,
        y,
        random_seed: np.random.Generator | int | None = None,
        sample_prior_predictive_kwargs: dict | None = None,
        fit_kwargs: dict | None = None,
        sample_posterior_predictive_kwargs: dict | None = None,
    ) -> Self:
        """Sample all the things.

        Run all of the sample methods in the sequence:

        - :meth:`sample_prior_predictive`
        - :meth:`fit`
        - :meth:`sample_posterior_predictive`
        - :meth:`calculate_counterfactual`

        Parameters
        ----------
        X : pd.DataFrame
            The data for the existing products.
        y : np.ndarray | pd.Series
            The data for the new product.
        random_seed : np.random.Generator | int, optional
            The random seed for each stage of sampling.
        sample_prior_predictive_kwargs : dict, optional
            The keyword arguments for the sample_prior_predictive method.
            A ``random_seed`` given here overrides the shared one.
        fit_kwargs : dict, optional
            The keyword arguments for the fit method.
            A ``random_seed`` given here overrides the shared one.
        sample_posterior_predictive_kwargs : dict, optional
            The keyword arguments for the sample_posterior_predictive method.
            ``random_seed`` and ``var_names`` (default ``[output_var, "mu"]``) may
            be overridden here. Note that the causal-impact and plotting methods
            read ``mu``/``y`` from the posterior predictive group.

        Returns
        -------
        Self
            The model instance.
        """
        # User-supplied kwargs are merged over the defaults instead of being passed
        # alongside them, which would raise "got multiple values for keyword".
        prior_kwargs = {
            "random_seed": random_seed,
            **(sample_prior_predictive_kwargs or {}),
        }
        fit_call_kwargs = {"random_seed": random_seed, **(fit_kwargs or {})}
        posterior_kwargs = {
            "random_seed": random_seed,
            "var_names": [self.output_var, "mu"],
            **(sample_posterior_predictive_kwargs or {}),
        }

        self.sample_prior_predictive(X, y, **prior_kwargs)
        self.fit(X, y, **fit_call_kwargs)
        self.sample_posterior_predictive(X, **posterior_kwargs)
        self.calculate_counterfactual(random_seed=random_seed)
        return self

    @staticmethod
    def _check_variable_name(variable: str) -> None:
        """Raise ``ValueError`` unless ``variable`` is ``"mu"`` or ``"y"``."""
        if variable not in ["mu", "y"]:
            raise ValueError(
                f"variable must be either 'mu' or 'y', not {variable}"
            )  # pragma: no cover

    def _plot_hdi_by_product(self, samples: DataArray, ax: Axes) -> None:
        """Draw one HDI band per existing product on ``ax``.

        ``samples`` needs ``time`` and ``existing_product`` dimensions. The bands
        use colour ``C{i}`` for the i-th product and are labelled per product.
        """
        x = self.X.index.values  # type: ignore
        for i, existing_product in enumerate(self.coords["existing_product"]):
            az.plot_hdi(
                x,
                samples.transpose(..., "time").sel(existing_product=existing_product),
                fill_kwargs={
                    "alpha": HDI_ALPHA,
                    "color": f"C{i}",
                    "label": f"{existing_product} - Posterior predictive (HDI)",
                },
                smooth=False,
                # Without ``ax`` the band goes to the *current* axes, which is not
                # necessarily the one the caller passed in.
                ax=ax,
            )

    def causal_impact(self, variable: str = "mu") -> DataArray:
        """Calculate the causal impact of the new product on the existing products.

        Note: if we compare "mu" then we are comparing the expected sales, if we compare
        "y" then we are comparing the actual sales

        Parameters
        ----------
        variable : str, optional
            The variable to compare. Either "mu" or "y".

        Returns
        -------
        xr.DataArray
            The causal impact of the new product on the existing products:
            posterior predictive minus counterfactual predictions.

        Raises
        ------
        ValueError
            If ``variable`` is not "mu" or "y".
        """
        self._check_variable_name(variable)
        return self.posterior_predictive[variable] - self.predictions[variable]

    def plot_fit(
        self,
        variable: str = "mu",
        plot_total_sales: bool = True,
        ax: Axes | None = None,
    ):
        """Plot the model fit (posterior predictive) of the existing products.

        Parameters
        ----------
        variable : str, optional
            The variable to compare. Either "mu" or "y".
        plot_total_sales : bool, optional
            Whether to plot the total sales or not.
        ax : plt.Axes, optional
            The matplotlib axes.

        Returns
        -------
        plt.Axes
            The new or modified matplotlib axes.
        """
        # Validate before creating a figure so a bad argument leaves no empty plot.
        self._check_variable_name(variable)
        if ax is None:
            _, ax = plt.subplots()
        ax = cast(Axes, ax)

        # plot data
        self.plot_data(ax=ax, plot_total_sales=plot_total_sales)

        # plot posterior predictive distribution of sales for each of the existing products
        self._plot_hdi_by_product(self.posterior_predictive[variable], ax=ax)  # type: ignore

        # formatting
        ax.legend()
        ax.set(title="Model fit of sales of existing products", ylabel="Sales")
        return ax

    def plot_counterfactual(
        self,
        variable: str = "mu",
        plot_total_sales: bool = True,
        ax: Axes | None = None,
    ):
        """Plot counterfactual scenario.

        Plot the predicted sales of the existing products under the counterfactual
        scenario of never releasing the new product.

        Parameters
        ----------
        variable : str, optional
            The variable to compare. Either "mu" or "y".
        plot_total_sales : bool, optional
            Whether to plot the total sales or not.
        ax : plt.Axes, optional
            The matplotlib axes.

        Returns
        -------
        plt.Axes
            The new or modified matplotlib axes.
        """
        self._check_variable_name(variable)
        if ax is None:
            _, ax = plt.subplots()
        ax = cast(Axes, ax)

        # plot data
        self.plot_data(ax=ax, plot_total_sales=plot_total_sales)

        # plot counterfactual predictions of sales for each of the existing products
        self._plot_hdi_by_product(self.predictions[variable], ax=ax)  # type: ignore

        # formatting
        ax.legend()
        ax.set(
            title="Model predictions under the counterfactual scenario", ylabel="Sales"
        )
        return ax

    def plot_causal_impact_sales(self, variable: str = "mu", ax: Axes | None = None):
        """Plot causal impact of sales.

        Plot the inferred causal impact of the new product on the sales of the
        existing products.

        Note: if we compare "mu" then we are comparing the expected sales, if we compare
        "y" then we are comparing the actual sales

        Parameters
        ----------
        variable : str, optional
            The variable to compare. Either "mu" or "y".
        ax : plt.Axes, optional
            The matplotlib axes.

        Returns
        -------
        plt.Axes
            The new or modified matplotlib axes.
        """
        self._check_variable_name(variable)
        if ax is None:
            _, ax = plt.subplots()
        ax = cast(Axes, ax)

        self._plot_hdi_by_product(self.causal_impact(variable=variable), ax=ax)

        # formatting
        ax.legend()
        ax.set(
            title="Estimated causal impact of new product upon existing products",
            ylabel="Change in sales caused by new product",
        )
        return ax

    def plot_causal_impact_market_share(
        self, variable: str = "mu", ax: Axes | None = None
    ):
        """Plot the inferred causal impact of the new product on the existing products.

        The change in sales of each product is divided by the total counterfactual
        sales of all existing products and shown as a percentage.

        Note: if we compare "mu" then we are comparing the expected sales, if we compare
        "y" then we are comparing the actual sales

        Parameters
        ----------
        variable : str, optional
            The variable to compare. Either "mu" or "y".
        ax : plt.Axes, optional
            The matplotlib axes.

        Returns
        -------
        plt.Axes
            The new or modified matplotlib axes.
        """
        # The previous implementation overwrote ``variable`` with "mu", silently
        # ignoring the argument; it is now honoured.
        self._check_variable_name(variable)
        if ax is None:
            _, ax = plt.subplots()
        ax = cast(Axes, ax)

        # Loop-invariant: compute once for all products.
        impact = self.causal_impact(variable=variable)
        total_sales = self.predictions[variable].sum(dim="existing_product")  # type: ignore
        impact_market_share = (impact / total_sales) * 100

        self._plot_hdi_by_product(impact_market_share, ax=ax)

        ax.set(ylabel="Change in market share caused by new product")
        ax.yaxis.set_major_formatter(mtick.PercentFormatter())
        # formatting
        ax.legend()
        ax.set(title="Estimated causal impact of new product upon existing products")
        return ax

    def plot_data(self, plot_total_sales: bool = True, ax: Axes | None = None):
        """Plot the observed data.

        Wrapper around the plot_product function, applied to the existing-product
        columns of ``X`` together with the new-product series ``y``.

        Parameters
        ----------
        plot_total_sales : bool, optional
            Whether to plot the total sales or not.
        ax : plt.Axes, optional
            The matplotlib axes.

        Returns
        -------
        plt.Axes
            The new or modified matplotlib axes.
        """
        data = pd.concat([self.X, self.y], axis=1)  # type: ignore
        return plot_product(data=data, ax=ax, plot_total_sales=plot_total_sales)

    # Accessor for the ``predictions`` group of ``idata``; the message tells the
    # user which method populates it.
    predictions = create_idata_accessor(
        "predictions",
        "Call the 'calculate_counterfactual' method first.",
    )
def plot_product(
    data: pd.DataFrame,
    plot_total_sales: bool = True,
    ax: Axes | None = None,
) -> Axes:
    """Plot the sales of every product (column) in ``data``.

    Each column of ``data`` is drawn as its own line via ``DataFrame.plot``.
    Optionally, the row-wise sum of all columns is drawn in black with the
    label "total sales". The y-axis lower limit is fixed at 0.

    Parameters
    ----------
    data : pd.DataFrame
        The sales data, one column per product.
    plot_total_sales : bool, optional
        Whether to plot the total sales (sum over all columns) or not.
    ax : plt.Axes, optional
        The matplotlib axes. A new figure and axes are created when omitted.

    Returns
    -------
    plt.Axes
        The new or modified matplotlib axes.
    """
    if ax is None:
        _, ax = plt.subplots()
    ax = cast(Axes, ax)
    data.plot(ax=ax)
    if plot_total_sales:
        data.sum(axis=1).plot(label="total sales", color="black", ax=ax)
    ax.set_ylim(bottom=0)
    ax.set(ylabel="Sales")
    return ax