from typing import Any, Dict, Optional, Sequence

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from .logger import logger
[docs]
def plot_power_analysis(
    power_df: pd.DataFrame, save_path: Optional[str] = None
) -> None:
    """
    Plot the results of a power analysis as a heatmap.

    Rows of the heatmap are test durations (``test_weeks``), columns are
    assumed effect sizes (``effect_size``), and each cell is annotated with
    the corresponding ``power`` value rounded to two decimals.

    Parameters
    ----------
    power_df : pd.DataFrame
        The DataFrame returned by the run_power_analysis function.
        It must contain 'effect_size', 'test_weeks', and 'power' columns,
        with at most one row per (test_weeks, effect_size) pair.
    save_path : str, optional
        If provided (and non-empty), saves the plot to this path.
        Otherwise, displays the plot.

    Raises
    ------
    ValueError
        If any of the required columns is missing (the message lists them),
        or, raised by ``DataFrame.pivot``, if a (test_weeks, effect_size)
        pair appears more than once.
    """
    required_cols = ["effect_size", "test_weeks", "power"]
    missing = [col for col in required_cols if col not in power_df.columns]
    if missing:
        raise ValueError(
            "Input DataFrame must contain 'effect_size', 'test_weeks', and 'power' "
            f"columns; missing: {missing}."
        )

    # One row per test duration, one column per effect size.
    power_pivot = power_df.pivot(
        index="test_weeks", columns="effect_size", values="power"
    )

    fig = plt.figure(figsize=(10, 7))
    try:
        sns.heatmap(
            power_pivot,
            annot=True,
            fmt=".2f",
            cmap="viridis",
            cbar_kws={"label": "Statistical Power"},
        )
        plt.title("Power Analysis Results", fontsize=16)
        plt.xlabel("Assumed Effect Size (Lift)")
        plt.ylabel("Test Duration (Weeks)")
        # Fit the title and axis labels inside the figure, as the other plot
        # functions in this module do before saving/showing.
        plt.tight_layout()
        if save_path:
            plt.savefig(save_path)
            logger.info(f"Plot saved to {save_path}")
        else:
            plt.show()
    finally:
        # Release the figure even if plotting or saving raises, so repeated
        # calls do not accumulate open figures.
        plt.close(fig)
[docs]
def plot_lift_distribution(
    analysis_df: pd.DataFrame,
    assignment_col: str,
    lift_col: str = "lift_index",
    save_path: Optional[str] = None,
    order: Optional[Sequence[str]] = ("Control", "Treatment"),
) -> None:
    """
    Plot the distribution of the lift index by group as box plots.

    Parameters
    ----------
    analysis_df : pd.DataFrame
        The DataFrame containing the per-geo lift calculations.
        This is returned by the `run_primary_analysis` function.
    assignment_col : str
        The name of the column with group assignments.
    lift_col : str, optional
        The name of the lift index column, by default 'lift_index'.
    save_path : str, optional
        If provided (and non-empty), saves the plot to this path.
        Otherwise, displays the plot.
    order : sequence of str, optional
        Order of the group boxes along the x-axis, by default
        ('Control', 'Treatment'). Pass None to let seaborn infer the groups
        from the data. Groups missing from this sequence are not drawn.
    """
    from matplotlib.ticker import PercentFormatter

    # seaborn expects a list-like of levels; None means "infer from data".
    box_order = list(order) if order is not None else None

    fig = plt.figure(figsize=(10, 7))
    try:
        sns.boxplot(
            data=analysis_df, x=assignment_col, y=lift_col, order=box_order
        )
        plt.title("Distribution of Lift Index by Group", fontsize=16)
        plt.xlabel("Group")
        plt.ylabel("Lift Index")
        plt.grid(True, which="both", linestyle="--", linewidth=0.5)
        # PercentFormatter(1) maps a data value of 1.0 to 100%; this assumes
        # lift values are fractions (0.05 -> 5%) — confirm against the caller.
        plt.gca().yaxis.set_major_formatter(PercentFormatter(1))
        plt.tight_layout()
        if save_path:
            plt.savefig(save_path)
            logger.info(f"Plot saved to {save_path}")
        else:
            plt.show()
    finally:
        # Release the figure even if plotting or saving raises.
        plt.close(fig)
def plot_did_results(
    results: Dict[str, Any],
    df: pd.DataFrame,
    date_col: str,
    kpi_col: str,
    assignment_col: str,
    test_period_start: str,
    save_path: Optional[str] = None,
) -> None:
    """
    Plot the time-series KPI trend for each assignment group.

    The KPI is averaged per (date, group) and drawn as one line per group,
    with a dashed red vertical line marking the start of the test period.
    The input DataFrame is not modified.

    Parameters
    ----------
    results : dict
        The results of the DID analysis. Not read by this function.
    df : pd.DataFrame
        The DataFrame containing the full experimental data.
    date_col, kpi_col, assignment_col : str
        Names of the date, KPI, and group-assignment columns in ``df``.
        ``date_col`` must be parseable by ``pd.to_datetime``.
    test_period_start : str
        The start date of the test period (YYYY-MM-DD), used to draw a vertical line.
    save_path : str, optional
        If provided (and non-empty), saves the plot to this path.
        Otherwise, displays the plot.
    """
    # Convert dates on a narrow copy so the caller's DataFrame is left untouched.
    plot_df = df[[date_col, assignment_col, kpi_col]].copy()
    plot_df[date_col] = pd.to_datetime(plot_df[date_col])
    start_date = pd.to_datetime(test_period_start)

    # Mean KPI per date and assignment group.
    agg_df = (
        plot_df.groupby([date_col, assignment_col])[kpi_col].mean().reset_index()
    )

    fig = plt.figure(figsize=(12, 8))
    try:
        sns.lineplot(data=agg_df, x=date_col, y=kpi_col, hue=assignment_col)
        plt.axvline(
            x=start_date, color="red", linestyle="--", label="Test Start"
        )
        plt.title("KPI Trend by Group", fontsize=16)
        plt.xlabel("Date")
        plt.ylabel(f"Average {kpi_col}")
        # Rebuild the legend so it includes the "Test Start" line.
        plt.legend()
        plt.grid(True, which="both", linestyle="--", linewidth=0.5)
        plt.tight_layout()
        if save_path:
            plt.savefig(save_path)
            logger.info(f"Plot saved to {save_path}")
        else:
            plt.show()
    finally:
        # Release the figure even if plotting or saving raises.
        plt.close(fig)