Source code for geostep.visualizer

import matplotlib.pyplot as plt
import pandas as pd
from typing import Any, Dict, Optional

import seaborn as sns

from .logger import logger


def plot_power_analysis(
    power_df: pd.DataFrame, save_path: Optional[str] = None
) -> None:
    """
    Plot the results of a power analysis as a heatmap.

    Rows of the heatmap are test durations, columns are assumed effect
    sizes, and each cell is annotated with the corresponding power value.

    Parameters
    ----------
    power_df : pd.DataFrame
        The DataFrame returned by the run_power_analysis function. It should
        contain 'effect_size', 'test_weeks', and 'power' columns.
    save_path : str, optional
        If provided, saves the plot to this path. Otherwise, displays the
        plot.

    Raises
    ------
    ValueError
        If any of the required columns is missing from ``power_df``.
    """
    required_cols = ("effect_size", "test_weeks", "power")
    if not all(col in power_df.columns for col in required_cols):
        raise ValueError(
            "Input DataFrame must contain 'effect_size', 'test_weeks', and 'power' "
            "columns."
        )

    # Reshape long -> wide: one row per duration, one column per effect size.
    # NOTE: DataFrame.pivot itself raises if a (test_weeks, effect_size) pair
    # appears more than once; that error is propagated unchanged.
    power_pivot = power_df.pivot(
        index="test_weeks", columns="effect_size", values="power"
    )

    fig = plt.figure(figsize=(10, 7))
    try:
        sns.heatmap(
            power_pivot,
            annot=True,
            fmt=".2f",
            cmap="viridis",
            cbar_kws={"label": "Statistical Power"},
        )
        plt.title("Power Analysis Results", fontsize=16)
        plt.xlabel("Assumed Effect Size (Lift)")
        plt.ylabel("Test Duration (Weeks)")

        if save_path:
            plt.savefig(save_path)
            logger.info(f"Plot saved to {save_path}")
        else:
            plt.show()
    finally:
        # Always release the figure, even when drawing or saving fails;
        # pyplot otherwise keeps it registered for the life of the process.
        plt.close(fig)
def plot_lift_distribution(
    analysis_df: pd.DataFrame,
    assignment_col: str,
    lift_col: str = "lift_index",
    save_path: Optional[str] = None,
) -> None:
    """
    Plot the distribution of the lift index by group as a box plot.

    Parameters
    ----------
    analysis_df : pd.DataFrame
        The DataFrame containing the per-geo lift calculations. This is
        returned by the `run_primary_analysis` function.
    assignment_col : str
        The name of the column with group assignments. Boxes are drawn for
        the groups "Control" and "Treatment", in that order.
    lift_col : str, optional
        The name of the lift index column, by default 'lift_index'.
    save_path : str, optional
        If provided, saves the plot to this path. Otherwise, displays the
        plot.

    Raises
    ------
    ValueError
        If ``assignment_col`` or ``lift_col`` is neither a column nor a
        named index level of ``analysis_df``.
    """
    # Fail fast, before any figure is opened, with a message that names the
    # missing fields. Named index levels are accepted as well as columns so
    # frames indexed by group keep working.
    known_names = set(analysis_df.columns) | {
        name for name in analysis_df.index.names if name is not None
    }
    missing = [col for col in (assignment_col, lift_col) if col not in known_names]
    if missing:
        raise ValueError(
            f"Input DataFrame is missing required column(s): {missing}"
        )

    # Local import: only this function needs the formatter.
    from matplotlib.ticker import PercentFormatter

    fig = plt.figure(figsize=(10, 7))
    try:
        sns.boxplot(
            data=analysis_df,
            x=assignment_col,
            y=lift_col,
            order=["Control", "Treatment"],
        )
        plt.title("Distribution of Lift Index by Group", fontsize=16)
        plt.xlabel("Group")
        plt.ylabel("Lift Index")
        plt.grid(True, which="both", linestyle="--", linewidth=0.5)

        # Format y-axis as percentage; xmax=1 means a value of 1.0 is
        # rendered as "100%".
        plt.gca().yaxis.set_major_formatter(PercentFormatter(1))
        plt.tight_layout()

        if save_path:
            plt.savefig(save_path)
            logger.info(f"Plot saved to {save_path}")
        else:
            plt.show()
    finally:
        # Always release the figure, even when drawing or saving fails.
        plt.close(fig)
def plot_did_results(
    results: Dict[str, Any],
    df: pd.DataFrame,
    date_col: str,
    kpi_col: str,
    assignment_col: str,
    test_period_start: str,
    save_path: Optional[str] = None,
) -> None:
    """
    Plot the KPI time-series trends for the Treatment and Control groups.

    The KPI is averaged per date and assignment group, drawn as one line per
    group, and a dashed vertical line marks the start of the test period.

    Parameters
    ----------
    results : dict
        The results of the DID analysis. Not read by this function; kept in
        the signature for interface compatibility.
    df : pd.DataFrame
        The DataFrame containing the full experimental data. It is not
        modified: date parsing is done on an internal copy.
    date_col, kpi_col, assignment_col : str
        Names of the relevant columns in ``df``.
    test_period_start : str
        The start date of the test period (YYYY-MM-DD), used to draw a
        vertical line.
    save_path : str, optional
        If provided, saves the plot to this path. Otherwise, displays the
        plot.

    Raises
    ------
    KeyError
        If any of ``date_col``, ``kpi_col`` or ``assignment_col`` is not a
        column of ``df``.
    """
    # Copy only the columns we need so that parsing the dates does not
    # overwrite the caller's column in place.
    plot_df = df[[date_col, assignment_col, kpi_col]].copy()
    plot_df[date_col] = pd.to_datetime(plot_df[date_col])
    start_date = pd.to_datetime(test_period_start)

    # Aggregate data by date and assignment group
    agg_df = (
        plot_df.groupby([date_col, assignment_col])[kpi_col].mean().reset_index()
    )

    fig = plt.figure(figsize=(12, 8))
    try:
        sns.lineplot(data=agg_df, x=date_col, y=kpi_col, hue=assignment_col)
        plt.axvline(
            x=start_date, color="red", linestyle="--", label="Test Start"
        )
        plt.title("KPI Trend by Group", fontsize=16)
        plt.xlabel("Date")
        plt.ylabel(f"Average {kpi_col}")
        plt.legend()
        plt.grid(True, which="both", linestyle="--", linewidth=0.5)
        plt.tight_layout()

        if save_path:
            plt.savefig(save_path)
            logger.info(f"Plot saved to {save_path}")
        else:
            plt.show()
    finally:
        # Always release the figure, even when drawing or saving fails.
        plt.close(fig)