# pyre-ignore-all-errors
"""
Oaxaca-Blinder Decomposition Analysis

This module provides a generalized framework for conducting Oaxaca-Blinder decomposition
analysis across different categorical breakdowns (e.g., product, vertical, product-vertical, etc.).

The core analysis decomposes performance gaps between regions into two components:
1. Construct Gap (Mix Effect): Differences in the composition/mix of categories
2. Performance Gap (Rate Effect): Differences in the efficiency/rates within categories

Input DataFrame Requirements:
============================
The input DataFrame should contain the following columns:
- region_column: Region identifier (string)
- numerator_column: Numerator for rate calculation (numeric)
- denominator_column: Denominator for rate calculation (numeric)
- category_columns: One or more categorical columns defining the breakdown levels
  (e.g., 'product', 'vertical', ['product', 'vertical'])

Example DataFrame structure for product analysis:
    territory_l4_name | product    | cli_won_cnt | cli_cnt
    AM-EMEA           | Product_A  | 25          | 100
    AM-EMEA           | Product_B  | 45          | 150
    AM-APAC           | Product_A  | 16          | 80
    ...

Example DataFrame structure for product-vertical analysis:
    territory_l4_name | product    | vertical | cli_won_cnt | cli_cnt
    AM-EMEA           | Product_A  | Tech     | 12          | 50
    AM-EMEA           | Product_A  | Finance  | 15          | 50
    AM-EMEA           | Product_B  | Tech     | 26          | 75
    ...

The analysis will:
1. Calculate rates on-the-fly from numerator/denominator columns
2. Calculate mix percentages for each region vs rest-of-world (using denominator as volume)
3. Calculate performance rates for each region vs rest-of-world
4. Decompose the total gap into construct and performance components
5. Provide detailed attribution at the category level
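
Example usage (a minimal sketch; the DataFrame values are illustrative and the
functions referenced are defined later in this module):

    df = pd.DataFrame({
        "territory_l4_name": ["AM-EMEA", "AM-EMEA", "AM-APAC", "AM-APAC"],
        "product": ["Product_A", "Product_B", "Product_A", "Product_B"],
        "cli_won_cnt": [25, 45, 16, 30],
        "cli_cnt": [100, 150, 80, 120],
    })

    decomposition = oaxaca_blinder_decomposition(df, category_columns="product")
    regional_gaps = calculate_regional_gaps(decomposition)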
"""

import logging
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd

logger = logging.getLogger(__name__)


# Centralized Analysis Thresholds and Constants
class AnalysisThresholds:
    """Centralized constants for all analysis thresholds to eliminate magic numbers."""

    # Gap significance thresholds
    BUSINESS_SIGNIFICANT_GAP = 0.01      # 1pp - business significant gap
    MINOR_DETECTABLE_GAP = 0.005         # 0.5pp - minor but detectable gap
    MATHEMATICAL_TOLERANCE = 0.001       # 0.1pp - mathematical tolerance for equality

    # Mix and business presence thresholds
    MEANINGFUL_MIX_THRESHOLD = 0.05      # 5% - meaningful business presence

    # Contradiction and paradox thresholds
    SUBSTANTIAL_CONTRADICTION_RATIO = 0.3  # 30% - substantial magnitude ratio
    MAJORITY_BUSINESS_WEIGHT = 0.5         # 50% - majority of business weight

    # Gap magnitude assessment thresholds (in percentage points)
    HIGH_IMPACT_GAP = 2.0               # 2pp - high business impact
    MEDIUM_IMPACT_GAP = 1.0             # 1pp - medium business impact
    SIGNIFICANT_RATE_DIFFERENCE = 5.0    # 5pp - significant rate difference

    # Business conclusion thresholds
    CONSTRUCT_DOMINANCE_RATIO = 1.2     # 20% more to be considered dominant

    # Business narrative generation thresholds
    UNDER_ALLOCATION_THRESHOLD = 0.8    # 80% of benchmark - significantly under-allocated
    OVER_ALLOCATION_THRESHOLD = 1.2     # 120% of benchmark - significantly over-allocated
    HIGH_PERFORMANCE_THRESHOLD = 0.35   # 35% rate threshold - high-performing category

    # Mathematical cancellation detection thresholds
    HIGH_MIX_VARIATION_THRESHOLD = 0.3  # 30pp difference - high impact mix variation
    MEDIUM_MIX_VARIATION_THRESHOLD = 0.15  # 15pp difference - medium impact mix variation
    HIGH_PERFORMING_SEGMENT_THRESHOLD = 0.4   # 40% rate - high-performing segment
    LOW_PERFORMING_SEGMENT_THRESHOLD = 0.2    # 20% rate - low-performing segment
    SIGNIFICANT_MIX_GAP_THRESHOLD = 20.0      # 20pp - significant mix gap
    SIGNIFICANT_ALLOCATION_IMPLICATION_THRESHOLD = 0.4  # 40% rate - high-performing allocation threshold
    SIGNIFICANT_MIX_DIFFERENCE_THRESHOLD = 20.0  # 20pp - significant mix difference threshold
    LOW_PERFORMING_ALLOCATION_THRESHOLD = 0.2  # 20% rate - low-performing allocation threshold


# Centralized Business Logic Functions
def assess_gap_significance(gap_magnitude_pp: float) -> str:
    """
    Single source of truth for gap significance assessment.

    Args:
        gap_magnitude_pp: Gap magnitude in percentage points (already converted from decimal)

    Returns:
        Significance level: 'high', 'medium', or 'low'
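
    Example (illustrative values; thresholds from AnalysisThresholds):
        assess_gap_significance(2.5)   # -> "high"   (above HIGH_IMPACT_GAP = 2.0pp)
        assess_gap_significance(1.5)   # -> "medium" (above MEDIUM_IMPACT_GAP = 1.0pp)
        assess_gap_significance(0.4)   # -> "low"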
    """
    if gap_magnitude_pp > AnalysisThresholds.HIGH_IMPACT_GAP:
        return "high"
    elif gap_magnitude_pp > AnalysisThresholds.MEDIUM_IMPACT_GAP:
        return "medium"
    else:
        return "low"


def determine_business_focus(construct_gap: float, performance_gap: float) -> Dict:
    """
    Single source of truth for business conclusions and primary drivers.

    Args:
        construct_gap: Construct (composition) gap contribution
        performance_gap: Performance (execution) gap contribution

    Returns:
        Dictionary with business_conclusion and primary_driver
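
    Example (illustrative values):
        determine_business_focus(construct_gap=-0.02, performance_gap=0.005)
        # -> business_conclusion="composition_driven", primary_driver="allocation",
        #    construct_dominance=True (0.02 > 0.005 * 1.2), performance_dominance=False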
    """
    abs_construct = abs(construct_gap)
    abs_performance = abs(performance_gap)

    # Decisive classification: focus on the larger absolute contributor
    if abs_construct > abs_performance:
        business_conclusion = "composition_driven"
        primary_driver = "allocation"
    else:
        business_conclusion = "performance_driven"
        primary_driver = "execution"

    return {
        "business_conclusion": business_conclusion,
        "primary_driver": primary_driver,
        "construct_dominance": abs_construct > abs_performance * AnalysisThresholds.CONSTRUCT_DOMINANCE_RATIO,
        "performance_dominance": abs_performance > abs_construct * AnalysisThresholds.CONSTRUCT_DOMINANCE_RATIO
    }


def oaxaca_blinder_decomposition(
    df: pd.DataFrame,
    region_column: str = "territory_l4_name",
    numerator_column: str = "cli_won_cnt",
    denominator_column: str = "cli_cnt",
    category_columns: Union[str, List[str]] = "product",
    decomposition_method: str = "two_part",
    baseline_region: str = "rest_of_world",
) -> pd.DataFrame:
    """
    Perform Oaxaca-Blinder decomposition analysis for any categorical breakdown.

    This function decomposes performance gaps between regions into:
    1. Construct Gap: (region_mix - rest_mix) * rest_rate
    2. Performance Gap: region_mix * (region_rate - rest_rate)
    3. Net Gap: region_mix * region_rate - rest_mix * rest_rate (per category;
       summed across categories this equals the overall rate gap, region_rate - rest_rate)

    Args:
        df: Input DataFrame with regional performance data
        region_column: Column name containing region identifiers
        numerator_column: Column name containing numerator for rate calculation
        denominator_column: Column name containing denominator for rate calculation (also used for mix)
        category_columns: Column name(s) defining the categorical breakdown.
                         Can be a single string or list of strings for multi-level analysis.
        decomposition_method: Decomposition variant. One of "two_part" (rest-of-world
                         rates as baseline, default), "reverse" (region rates as baseline),
                         or "symmetric" (midpoint mix and rates).
        baseline_region: Baseline to compare each region against. One of "rest_of_world"
                         (default), "global_average", "top_performer", "previous_period",
                         or a specific region name present in region_column.

    Returns:
        DataFrame with decomposition results containing:
        - region: Region identifier
        - category_*: Category identifiers (one column per category level)
        - region_mix_pct: Region's mix percentage for this category
        - rest_mix_pct: Rest-of-world mix percentage for this category
        - region_rate: Region's performance rate for this category
        - rest_rate: Rest-of-world performance rate for this category
        - construct_gap_contribution: Category's contribution to construct gap
        - performance_gap_contribution: Category's contribution to performance gap
        - net_gap: Category's contribution to the overall gap
          (region_mix * region_rate - rest_mix * rest_rate)
        - performance_index: Region rate / Rest rate (relative performance)
        - weighted_impact: region_mix_pct * (region_rate - rest_rate)

    Example:
        # Single-level analysis (product only)
        result = oaxaca_blinder_decomposition(
            df=df,
            numerator_column="cli_won_cnt",
            denominator_column="cli_cnt",
            category_columns="product"
        )

        # Multi-level analysis (product-vertical)
        result = oaxaca_blinder_decomposition(
            df=df,
            numerator_column="cli_won_cnt",
            denominator_column="cli_cnt",
            category_columns=["product", "vertical"]
        )
    """

    # Ensure category_columns is a list
    if isinstance(category_columns, str):
        category_columns = [category_columns]

    # Validate required columns exist
    required_columns = [
        region_column,
        numerator_column,
        denominator_column,
    ] + category_columns
    missing_columns = [col for col in required_columns if col not in df.columns]
    if missing_columns:
        raise ValueError(f"Missing required columns: {missing_columns}")

    # Initialize list to store decomposition results
    decomposition_results = []

    # Get unique regions
    regions = df[region_column].unique()

    # Process each region
    for region in regions:
        # Get region data using utility function
        region_data = _filter_region_data(df, region_column, region)

        # Calculate region mix percentages using utility function
        region_mix_pct = _calculate_region_mix(
            region_data, category_columns, denominator_column
        )

        # Calculate region rates (numerator/denominator aggregated by category)
        region_rates = _calculate_rates_from_components(
            region_data, category_columns, numerator_column, denominator_column
        )

        # Get baseline data based on baseline_region parameter
        try:
            rest_mix_pct, rest_rates = _get_baseline_data(
                df,
                region_column,
                baseline_region,
                region,
                numerator_column,
                denominator_column,
                category_columns,
            )

        except (ValueError, NotImplementedError) as e:
            logger.warning(f"Error getting baseline data for region {region}: {e}")
            continue

        if len(rest_mix_pct) > 0:

            # Process each category combination
            for category_combo in region_mix_pct.index:
                # Store original key for lookups
                lookup_key = category_combo

                # Ensure category_combo is a tuple for consistent handling in result record
                if not isinstance(category_combo, tuple):
                    category_combo = (category_combo,)

                # Get values for this category using the original key
                region_mix_val = region_mix_pct.get(lookup_key, 0)
                rest_mix_val = rest_mix_pct.get(lookup_key, 0)
                region_rate_val = region_rates.get(lookup_key, 0)
                rest_rate_val = rest_rates.get(lookup_key, 0)

                # Calculate gap components based on decomposition method
                if decomposition_method == "two_part":
                    # B-baseline (ROW baseline)
                    construct_gap = (region_mix_val - rest_mix_val) * rest_rate_val
                    performance_gap = region_mix_val * (region_rate_val - rest_rate_val)
                elif decomposition_method == "reverse":
                    # A-baseline (region baseline)
                    construct_gap = (region_mix_val - rest_mix_val) * region_rate_val
                    performance_gap = rest_mix_val * (region_rate_val - rest_rate_val)
                elif decomposition_method == "symmetric":
                    # Symmetric (path-independent using midpoints)
                    avg_rate = (region_rate_val + rest_rate_val) / 2
                    avg_mix = (region_mix_val + rest_mix_val) / 2
                    construct_gap = (region_mix_val - rest_mix_val) * avg_rate
                    performance_gap = avg_mix * (region_rate_val - rest_rate_val)
                else:
                    raise ValueError(
                        f"Invalid decomposition_method: {decomposition_method}. "
                        f"Must be 'two_part', 'reverse', or 'symmetric'"
                    )

                # Calculate net gap as the weighted contribution to overall gap
                # This is the fundamental Oaxaca-Blinder identity:
                # gap_i = region_mix_i * region_rate_i - rest_mix_i * rest_rate_i
                net_gap = (
                    region_mix_val * region_rate_val - rest_mix_val * rest_rate_val
                )
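
                # Worked example (two_part, illustrative numbers):
                #   region_mix=0.40, rest_mix=0.25, region_rate=0.20, rest_rate=0.30
                #   construct_gap   = (0.40 - 0.25) * 0.30  =  0.045
                #   performance_gap = 0.40 * (0.20 - 0.30)  = -0.040
                #   net_gap         = 0.40*0.20 - 0.25*0.30 =  0.005
                #   and construct_gap + performance_gap == net_gap (0.045 - 0.040 = 0.005)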

                # Calculate additional metrics
                performance_index = (
                    region_rate_val / rest_rate_val if rest_rate_val != 0 else np.nan
                )
                weighted_impact = region_mix_val * (region_rate_val - rest_rate_val)

                # Create result record
                result_record = {
                    "region": region,
                    "region_mix_pct": region_mix_val,
                    "rest_mix_pct": rest_mix_val,
                    "region_rate": region_rate_val,
                    "rest_rate": rest_rate_val,
                    "construct_gap_contribution": construct_gap,
                    "performance_gap_contribution": performance_gap,
                    "net_gap": net_gap,
                    "performance_index": performance_index,
                    "weighted_impact": weighted_impact,
                }

                # Add category columns
                for i, col_name in enumerate(category_columns):
                    result_record[f"category_{col_name}"] = category_combo[i]

                decomposition_results.append(result_record)

    # Convert to DataFrame
    result_df = pd.DataFrame(decomposition_results)

    return result_df


def _get_baseline_data(
    df: pd.DataFrame,
    region_column: str,
    baseline_region: str,
    current_region: str,
    numerator_column: str,
    denominator_column: str,
    category_columns: List[str],
) -> Tuple[pd.Series, pd.Series]:
    """
    Get baseline data based on baseline_region specification.

    Args:
        df: Input DataFrame
        region_column: Column name containing region identifiers
        baseline_region: Type of baseline to use
        current_region: Current region being analyzed
        numerator_column: Column name for numerator
        denominator_column: Column name for denominator
        category_columns: List of category column names

    Returns:
        Tuple of (baseline_mix_pct, baseline_rates)
    """

    if baseline_region == "rest_of_world":
        # Default behavior - exclude current region
        baseline_data = _filter_region_data(df, region_column, current_region, exclude=True)

    elif baseline_region == "global_average":
        # Use all data as baseline (including current region)
        baseline_data = df

    elif baseline_region == "top_performer":
        # Create synthetic "top performer" baseline using best rate for each category
        # This avoids the issue of one region being best overall but not best in each category

        # Exclude current region from consideration
        other_regions_data = df[df[region_column] != current_region]

        if len(other_regions_data) == 0:
            raise ValueError(
                f"No other regions available for top_performer baseline for region {current_region}"
            )

        # For each category, find the best performing region and use that rate
        # Calculate rates by category for all other regions
        category_rates_by_region = {}
        for region in other_regions_data[region_column].unique():
            region_data = other_regions_data[
                other_regions_data[region_column] == region
            ]
            region_rates = _calculate_rates_from_components(
                region_data, category_columns, numerator_column, denominator_column
            )
            category_rates_by_region[region] = region_rates

        # Find best rate for each category across all regions
        all_categories = set()
        for rates in category_rates_by_region.values():
            all_categories.update(rates.index)

        best_rates = {}

        for category in all_categories:
            category_rates = []
            for region, rates in category_rates_by_region.items():
                if category in rates and not pd.isna(rates[category]):
                    category_rates.append(rates[category])

            if category_rates:
                best_rates[category] = max(category_rates)

        # Calculate mix using all other regions (not just the best performer)
        baseline_mix_pct = _calculate_region_mix(
            other_regions_data, category_columns, denominator_column
        )
        baseline_rates = pd.Series(best_rates)

        logger.info(
            f"Top performer baseline created using best rates across categories for {current_region}"
        )
        return baseline_mix_pct, baseline_rates

    elif baseline_region == "previous_period":
        # This would require time-series data with a period column
        # For now, raise an error with guidance
        raise NotImplementedError(
            "previous_period baseline requires time-series data. "
            "Use oaxaca_blinder_time_series function instead."
        )

    elif baseline_region in df[region_column].unique():
        # Specific region baseline
        baseline_data = df[df[region_column] == baseline_region]

    else:
        raise ValueError(
            f"Invalid baseline_region: {baseline_region}. "
            f"Must be 'rest_of_world', 'global_average', 'top_performer', "
            f"'previous_period', or a valid region name from: {df[region_column].unique()}"
        )

    if len(baseline_data) == 0:
        raise ValueError(f"No data available for baseline_region: {baseline_region}")

    # Calculate baseline mix and rates using utility function
    baseline_mix_pct = _calculate_region_mix(
        baseline_data, category_columns, denominator_column
    )
    baseline_rates = _calculate_rates_from_components(
        baseline_data, category_columns, numerator_column, denominator_column
    )

    return baseline_mix_pct, baseline_rates


def _calculate_rates_from_components(
    data: pd.DataFrame,
    category_columns: List[str],
    numerator_column: str,
    denominator_column: str,
) -> pd.Series:
    """
    Calculate rates for each category combination from numerator and denominator columns.

    Args:
        data: DataFrame containing the data
        category_columns: List of category column names
        numerator_column: Column name for numerator
        denominator_column: Column name for denominator

    Returns:
        Series with category combinations as index and rates as values
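
    Example (illustrative; column names follow the module docstring):
        data = pd.DataFrame({"product": ["A", "A", "B"],
                             "cli_won_cnt": [10, 5, 30],
                             "cli_cnt": [50, 25, 100]})
        _calculate_rates_from_components(data, ["product"], "cli_won_cnt", "cli_cnt")
        # -> rate 0.2 for A ((10+5)/(50+25)) and 0.3 for B (30/100)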
    """
    # Handle case where no category columns are provided (overall rate calculation)
    if not category_columns:
        total_numerator = data[numerator_column].sum()
        total_denominator = data[denominator_column].sum()
        overall_rate = total_numerator / total_denominator if total_denominator > 0 else 0
        # Return as Series with single value
        return pd.Series([overall_rate], index=[0])

    # Group by categories and sum numerator and denominator
    grouped = data.groupby(category_columns).agg(
        {numerator_column: "sum", denominator_column: "sum"}
    )

    # Calculate rates
    rates = grouped[numerator_column] / grouped[denominator_column]

    # Replace any NaN or inf values with 0
    rates = rates.fillna(0).replace([np.inf, -np.inf], 0)

    return rates


def _calculate_region_mix(
    data: pd.DataFrame,
    category_columns: List[str],
    denominator_column: str,
) -> pd.Series:
    """
    Calculate mix percentages for each category combination.

    Args:
        data: DataFrame containing the data
        category_columns: List of category column names
        denominator_column: Column name for denominator (volume)

    Returns:
        Series with category combinations as index and mix percentages as values
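
    Example (illustrative; same shape of data as in _calculate_rates_from_components):
        data = pd.DataFrame({"product": ["A", "A", "B"], "cli_cnt": [50, 25, 100]})
        _calculate_region_mix(data, ["product"], "cli_cnt")
        # -> mix of 75/175 ~= 0.43 for A and 100/175 ~= 0.57 for B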
    """
    # Handle case where no category columns are provided (overall calculation)
    if not category_columns:
        # Return as Series with single value (a single overall bucket holds 100% of volume)
        return pd.Series([1.0], index=[0])

    region_mix = data.groupby(category_columns)[denominator_column].sum()
    return region_mix / region_mix.sum()


def _filter_region_data(
    df: pd.DataFrame,
    region_column: str,
    target_region: str,
    exclude: bool = False,
) -> pd.DataFrame:
    """
    Filter DataFrame for specific region data.

    Args:
        df: Input DataFrame
        region_column: Column name containing region identifiers
        target_region: Region to filter for/against
        exclude: If True, exclude the target_region; if False, include only target_region

    Returns:
        Filtered DataFrame
    """
    if exclude:
        return df[df[region_column] != target_region]
    else:
        return df[df[region_column] == target_region]






def calculate_regional_gaps(
    decomposition_df: pd.DataFrame, region_column: str = "region"
) -> pd.DataFrame:
    """
    Calculate aggregate gaps at the regional level from decomposition results.

    Args:
        decomposition_df: Output from oaxaca_blinder_decomposition
        region_column: Column name containing region identifiers

    Returns:
        DataFrame with regional-level gap analysis:
        - region: Region identifier
        - total_construct_gap: Sum of construct gap contributions
        - total_performance_gap: Sum of performance gap contributions
        - total_net_gap: Sum of net gaps (should equal total_construct_gap + total_performance_gap)
        - expected_rate: What the rate would be with region mix but rest-of-world rates
        - actual_rate: Actual regional rate
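
    Example (a minimal sketch, chained off oaxaca_blinder_decomposition; df as
    described in the module docstring):
        decomposition = oaxaca_blinder_decomposition(df, category_columns="product")
        regional_gaps = calculate_regional_gaps(decomposition)
        # gap_validation should be ~0 for every region, since construct + performance
        # equals net_gap by construction for each category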
    """

    # Filter out rows with missing gap data
    analysis_df = decomposition_df[
        decomposition_df["construct_gap_contribution"].notna()
    ].copy()

    # Calculate regional aggregates
    regional_gaps = (
        analysis_df.groupby(region_column)
        .agg(
            {
                "construct_gap_contribution": "sum",
                "performance_gap_contribution": "sum",
                "net_gap": "sum",
                "region_mix_pct": "sum",  # Should sum to 1.0 for validation
                "rest_mix_pct": "sum",  # Should sum to 1.0 for validation
            }
        )
        .reset_index()
    )

    # Rename columns for clarity
    regional_gaps = regional_gaps.rename(
        columns={
            "construct_gap_contribution": "total_construct_gap",
            "performance_gap_contribution": "total_performance_gap",
            "net_gap": "total_net_gap",
        }
    )

    # Calculate expected and actual rates
    # Expected rate = sum(region_mix * rest_rate)
    # Actual rate = sum(region_mix * region_rate)
    expected_rates = []
    actual_rates = []

    for region in regional_gaps[region_column]:
        region_data = analysis_df[analysis_df[region_column] == region]

        expected_rate = (region_data["region_mix_pct"] * region_data["rest_rate"]).sum()
        actual_rate = (region_data["region_mix_pct"] * region_data["region_rate"]).sum()

        expected_rates.append(expected_rate)
        actual_rates.append(actual_rate)

    regional_gaps["expected_rate"] = expected_rates
    regional_gaps["actual_rate"] = actual_rates

    # Add validation columns
    regional_gaps["gap_validation"] = (
        regional_gaps["total_construct_gap"] + regional_gaps["total_performance_gap"]
    ) - regional_gaps["total_net_gap"]

    regional_gaps["rate_validation"] = (
        regional_gaps["actual_rate"] - regional_gaps["expected_rate"]
    ) - regional_gaps["total_net_gap"]

    return regional_gaps


def find_top_drivers(
    decomposition_df: pd.DataFrame,
    region: str,
    metric_type: str = "weighted_impact",
    top_n: int = 5,
    category_columns: Optional[List[str]] = None,
) -> pd.DataFrame:
    """
    Find the top positive and negative drivers for a specific region.

    Args:
        decomposition_df: Output from oaxaca_blinder_decomposition
        region: Region to analyze
        metric_type: Metric to use for ranking ("weighted_impact", "construct_gap_contribution",
                    "performance_gap_contribution", or "net_gap")
        top_n: Number of top drivers to return
        category_columns: List of category column names to include in output

    Returns:
        DataFrame with top drivers, sorted by absolute impact
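
    Example (a minimal sketch; "AM-EMEA" is the illustrative region from the module
    docstring and decomposition_df comes from oaxaca_blinder_decomposition):
        top = find_top_drivers(decomposition_df, region="AM-EMEA",
                               metric_type="weighted_impact", top_n=3,
                               category_columns=["product"])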
    """

    # Filter for the specified region and remove missing values
    region_data = decomposition_df[
        (decomposition_df["region"] == region) & (decomposition_df[metric_type].notna())
    ].copy()

    if len(region_data) == 0:
        logger.warning(f"No data found for region {region}")
        return pd.DataFrame()

    # Sort by absolute value of the metric
    region_data["abs_impact"] = region_data[metric_type].abs()
    top_drivers = region_data.nlargest(top_n, "abs_impact")

    # Select relevant columns for output
    output_columns = [
        "region",
        metric_type,
        "region_mix_pct",
        "rest_mix_pct",
        "region_rate",
        "rest_rate",
        "performance_index",
    ]

    # Add category columns if specified
    if category_columns:
        category_cols = [f"category_{col}" for col in category_columns]
        output_columns = category_cols + output_columns
    else:
        # Include all category columns found in the data
        category_cols = [
            col for col in top_drivers.columns if col.startswith("category_")
        ]
        output_columns = category_cols + output_columns

    # Filter to existing columns
    output_columns = [col for col in output_columns if col in top_drivers.columns]

    return top_drivers[output_columns].reset_index(drop=True)


def summarize_decomposition(
    decomposition_df: pd.DataFrame, regional_gaps_df: Optional[pd.DataFrame] = None
) -> Dict:
    """
    Generate a summary of the Oaxaca-Blinder decomposition analysis.

    Args:
        decomposition_df: Output from oaxaca_blinder_decomposition
        regional_gaps_df: Output from calculate_regional_gaps (optional)

    Returns:
        Dictionary containing summary statistics and insights
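
    Example (a minimal sketch, chained off oaxaca_blinder_decomposition):
        decomposition = oaxaca_blinder_decomposition(df, category_columns="product")
        summary = summarize_decomposition(decomposition)
        # summary["construct_gap_dominance"] is the share of regions whose construct
        # gap outweighs their performance gap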
    """

    if regional_gaps_df is None:
        regional_gaps_df = calculate_regional_gaps(decomposition_df)

    # Calculate summary statistics
    summary = {
        "total_regions": len(regional_gaps_df),
        "total_categories": len(
            decomposition_df[decomposition_df["construct_gap_contribution"].notna()]
        ),
        "avg_construct_gap": regional_gaps_df["total_construct_gap"].mean(),
        "avg_performance_gap": regional_gaps_df["total_performance_gap"].mean(),
        "construct_gap_dominance": (
            regional_gaps_df["total_construct_gap"].abs()
            > regional_gaps_df["total_performance_gap"].abs()
        ).mean(),
        "largest_positive_gap_region": (
            regional_gaps_df.loc[regional_gaps_df["total_net_gap"].idxmax(), "region"]
            if len(regional_gaps_df) > 0
            else None
        ),
        "largest_negative_gap_region": (
            regional_gaps_df.loc[regional_gaps_df["total_net_gap"].idxmin(), "region"]
            if len(regional_gaps_df) > 0
            else None
        ),
    }

    return summary


def detect_aggregation_bias(
    df: pd.DataFrame,
    region_column: str,
    numerator_column: str,
    denominator_column: str,
    category_columns: Optional[List[str]] = None,
    subcategory_columns: Optional[List[str]] = None,
    top_n_focus: int = 3,
    **kwargs,
) -> pd.DataFrame:
    """
    Detect Simpson's paradox that would lead to wrong business recommendations.

    This function focuses on detecting when we'd make incorrect recommendations about
    which categories (products/verticals) to fix. It only checks the top contributors
    that would actually drive business decisions.

    Args:
        df: Input DataFrame with regional performance data
        region_column: Column name containing region identifiers
        numerator_column: Column name containing numerator for rate calculation
        denominator_column: Column name containing denominator for rate calculation
        category_columns: List of category column names for aggregate analysis (None for region-level)
        subcategory_columns: List of subcategory column names for detailed analysis
        top_n_focus: Number of top categories to focus paradox detection on (default: 3)
        **kwargs: Additional arguments passed to oaxaca_blinder_decomposition

    Returns:
        DataFrame with detected Simpson's paradox cases that would change business recommendations

    Example:
        # Detect if we'd wrongly blame top products when the issue is actually vertical mix
        paradox_cases = detect_aggregation_bias(
            df=df, category_columns=["product"], subcategory_columns=["vertical"], top_n_focus=2
        )
    """

    # Handle region-level Simpson's paradox detection
    if category_columns is None:
        return _detect_region_level_simpson(
            df,
            region_column,
            numerator_column,
            denominator_column,
            subcategory_columns,
            **kwargs,
        )

    # Handle category-level Simpson's paradox detection
    if subcategory_columns is None:
        raise ValueError(
            "subcategory_columns must be provided for category-level paradox detection"
        )

    return _detect_category_level_simpson(
        df,
        region_column,
        numerator_column,
        denominator_column,
        category_columns,
        subcategory_columns,
        top_n_focus,
        **kwargs,
    )


def _detect_category_level_simpson(
    df: pd.DataFrame,
    region_column: str,
    numerator_column: str,
    denominator_column: str,
    category_columns: List[str],
    subcategory_columns: List[str],
    top_n_focus: int,
    **kwargs,
) -> pd.DataFrame:
    """
    Detect category-level Simpson's paradox focusing on top business drivers.

    This checks if our top-N worst performing categories would actually be
    recommended for different actions at the subcategory level.
    """

    # Get detailed decomposition first (this is the source of truth)
    detailed_results = oaxaca_blinder_decomposition(
        df,
        region_column=region_column,
        numerator_column=numerator_column,
        denominator_column=denominator_column,
        category_columns=category_columns + subcategory_columns,
        **kwargs,
    )

    # Use the existing decomposition function with correct category level
    # This avoids reinventing the wheel and ensures consistency
    aggregate_results = oaxaca_blinder_decomposition(
        df,
        region_column=region_column,
        numerator_column=numerator_column,
        denominator_column=denominator_column,
        category_columns=category_columns,
        **kwargs,
    )

    paradox_cases = []

    for region in aggregate_results["region"].unique():
        region_agg = aggregate_results[aggregate_results["region"] == region]

        # Focus only on top-N worst performers (most negative net_gap)
        # These are the categories we'd actually recommend fixing
        top_problems = region_agg.nsmallest(top_n_focus, "net_gap")

        for _, agg_row in top_problems.iterrows():
            category = agg_row[f"category_{category_columns[0]}"]

            # Get subcategory breakdown for this problem category
            detailed_subset = detailed_results[
                (detailed_results["region"] == region)
                & (detailed_results[f"category_{category_columns[0]}"] == category)
            ]

            if len(detailed_subset) <= 1:
                continue

            # Check if this "problem category" actually performs well in subcategories
            # Prepare data for unified paradox detection
            agg_data = {
                "construct_gap": agg_row["construct_gap_contribution"],
                "performance_gap": agg_row["performance_gap_contribution"],
                "net_gap": agg_row["net_gap"]
            }
            detailed_data = {
                "construct_gap": detailed_subset["construct_gap_contribution"].sum(),
                "performance_gap": detailed_subset["performance_gap_contribution"].sum(),
                "net_gap": detailed_subset["net_gap"].sum(),
                "weighted_gap": detailed_subset["weighted_impact"].sum()
            }

            paradox_detected = _unified_paradox_detection(agg_data, detailed_data, "category")

            if paradox_detected:
                # Prepare data for unified analysis function
                agg_data = {
                    "construct_gap_contribution": agg_row["construct_gap_contribution"],
                    "performance_gap_contribution": agg_row["performance_gap_contribution"],
                    "net_gap": agg_row["net_gap"]
                }
                detailed_data = {
                    "construct_gap_contribution": detailed_subset["construct_gap_contribution"].sum(),
                    "performance_gap_contribution": detailed_subset["performance_gap_contribution"].sum(),
                    "net_gap": detailed_subset["net_gap"].sum()
                }

                paradox_info = analyze_paradox_impact(
                    agg_data=agg_data,
                    detailed_data=detailed_data,
                    analysis_type="category",
                    category_name=category_columns[0]
                )
                paradox_info.update(
                    {
                        "region": region,
                        "category": category,
                        "aggregate_net_gap": agg_row["net_gap"],
                    }
                )
                paradox_cases.append(paradox_info)

    return pd.DataFrame(paradox_cases)




def _unified_paradox_detection(agg_data: Dict, detailed_data: Dict, analysis_type: str) -> bool:
    """
    Unified paradox detection logic for both category-level and region-level analysis.

    Args:
        agg_data: Dictionary with aggregate-level data
        detailed_data: Dictionary with detailed-level data
        analysis_type: Type of analysis ("category" or "region")

    Returns:
        True if paradox is detected, False otherwise
    """

    # Extract data based on analysis type
    if analysis_type == "category":
        agg_construct_gap = agg_data["construct_gap"]
        agg_performance_gap = agg_data["performance_gap"]
        agg_net_gap = agg_data["net_gap"]

        sub_construct_gap = detailed_data["construct_gap"]
        sub_performance_gap = detailed_data["performance_gap"]
        sub_net_gap = detailed_data["net_gap"]
        weighted_gap = detailed_data["weighted_gap"]

    elif analysis_type == "region":
        # For region-level, we get overall gap vs product-level gaps
        agg_construct_gap = 0  # Overall regional analysis doesn't decompose
        agg_performance_gap = agg_data["overall_gap"]  # Treat as performance gap
        agg_net_gap = agg_data["overall_gap"]

        sub_construct_gap = detailed_data["product_construct_gap"]
        sub_performance_gap = detailed_data["product_performance_gap"]
        sub_net_gap = detailed_data["product_net_gap"]
        weighted_gap = detailed_data.get("weighted_gap", sub_net_gap)

    else:
        raise ValueError(f"Invalid analysis_type: {analysis_type}. Must be 'category' or 'region'")

    # Use centralized business focus determination
    agg_business_focus = determine_business_focus(agg_construct_gap, agg_performance_gap)
    sub_business_focus = determine_business_focus(sub_construct_gap, sub_performance_gap)

    agg_recommendation = agg_business_focus["business_conclusion"]
    sub_recommendation = sub_business_focus["business_conclusion"]

    # Only flag as paradox if:
    # 1. Recommendations are different AND
    # 2. The gap magnitude is significant AND
    # 3. The direction of the overall performance contradicts subcategory patterns

    recommendation_conflict = agg_recommendation != sub_recommendation

    # Use centralized significance assessment
    agg_gap_significance = assess_gap_significance(abs(agg_net_gap) * 100)
    sub_gap_significance = assess_gap_significance(abs(sub_net_gap) * 100)

    agg_significant = agg_gap_significance in ["medium", "high"]
    sub_significant = sub_gap_significance in ["medium", "high"]

    if not (agg_significant and sub_significant):
        return False  # Gaps too small to be meaningful

    # Check for direction contradiction
    if analysis_type == "region":
        # For region-level Simpson's paradox:
        # - Overall gap shows region underperforming (negative)
        # - But product-level analysis shows region should be outperforming (positive net gap)
        # This indicates the region outperforms in individual products but loses due to mix

        # The key insight: if overall gap is negative but product analysis shows positive net gap,
        # then the region is being hurt by having bad product mix (Simpson's paradox)
        opposite_directions = (agg_net_gap < 0) and (sub_net_gap > 0)

        # Ensure both gaps are significant enough to matter
        agg_significant = abs(agg_net_gap) > AnalysisThresholds.BUSINESS_SIGNIFICANT_GAP
        sub_significant = abs(sub_net_gap) > AnalysisThresholds.MINOR_DETECTABLE_GAP

        direction_paradox = opposite_directions and agg_significant and sub_significant
    else:
        # Original category-level logic
        opposite_directions = (agg_net_gap > 0) != (weighted_gap > 0)

        # Ensure the contradiction is substantial (not just marginal)
        magnitude_ratio = (
            abs(weighted_gap) / abs(agg_net_gap) if abs(agg_net_gap) > 0 else 0
        )
        substantial_contradiction = magnitude_ratio > AnalysisThresholds.SUBSTANTIAL_CONTRADICTION_RATIO

        direction_paradox = opposite_directions and substantial_contradiction

    # For region-level analysis, we care more about direction paradox than recommendation conflict
    # because region-level analysis doesn't have the same construct/performance decomposition
    if analysis_type == "region":
        return direction_paradox
    else:
        return recommendation_conflict and agg_significant and direction_paradox


def analyze_paradox_impact(
    agg_data: Dict,
    detailed_data: Dict,
    analysis_type: str = "category",
    category_name: str = "Unknown"
) -> Dict:
    """
    Unified function for analyzing business impact of detected Simpson's paradox.

    Works for both category-level and region-level paradox analysis by accepting
    standardized data dictionaries instead of specific DataFrame structures.

    Args:
        agg_data: Dictionary with aggregate-level analysis data
        detailed_data: Dictionary with detailed-level analysis data
        analysis_type: Type of analysis ("category" or "region")
        category_name: Name of the category being analyzed

    Returns:
        Dictionary with standardized paradox impact analysis
    """

    # Extract data based on analysis type
    if analysis_type == "category":
        agg_construct = agg_data["construct_gap_contribution"]
        agg_performance = agg_data["performance_gap_contribution"]
        agg_net = agg_data["net_gap"]

        sub_construct = detailed_data["construct_gap_contribution"]
        sub_performance = detailed_data["performance_gap_contribution"]
        sub_net = detailed_data["net_gap"]

    elif analysis_type == "region":
        # For region-level, we get overall gap vs product-level gaps
        agg_construct = 0  # Overall regional analysis doesn't decompose
        agg_performance = agg_data["overall_gap"]  # Treat as performance gap
        agg_net = agg_data["overall_gap"]

        sub_construct = detailed_data["product_construct_gap"]
        sub_performance = detailed_data["product_performance_gap"]
        sub_net = detailed_data["product_net_gap"]
    else:
        raise ValueError(f"Invalid analysis_type: {analysis_type}. Must be 'category' or 'region'")

    # Use centralized business focus determination
    agg_business_focus = determine_business_focus(agg_construct, agg_performance)
    sub_business_focus = determine_business_focus(sub_construct, sub_performance)

    agg_recommendation = agg_business_focus["business_conclusion"]
    sub_recommendation = sub_business_focus["business_conclusion"]

    # Calculate gap magnitudes using centralized logic
    agg_gap_magnitude = (abs(agg_construct) + abs(agg_performance)) * 100
    sub_gap_magnitude = (abs(sub_construct) + abs(sub_performance)) * 100

    # Use centralized gap significance assessment
    gap_significance = assess_gap_significance(agg_gap_magnitude)

    # Determine conflicts
    recommendation_conflict = agg_recommendation != sub_recommendation
    driver_conflict = agg_business_focus["primary_driver"] != sub_business_focus["primary_driver"]

    # Generate analysis-type specific insights
    if analysis_type == "category":
        aggregate_description = agg_recommendation
        subcategory_description = sub_recommendation
    else:  # region
        overall_direction = "outperforming" if agg_net > 0 else "underperforming"
        product_direction = "outperforming" if sub_net > 0 else "underperforming"
        aggregate_description = f"Region appears to be {overall_direction} overall"
        subcategory_description = f"Product analysis shows {product_direction} with {sub_recommendation} focus"

    return {
        "aggregate_recommendation": aggregate_description,
        "subcategory_recommendation": subcategory_description,
        "aggregate_gap_magnitude": agg_gap_magnitude,
        "subcategory_gap_magnitude": sub_gap_magnitude,
        "recommendation_conflict": recommendation_conflict,
        "driver_conflict": driver_conflict,
        "gap_significance": gap_significance,
        "aggregate_primary_driver": agg_business_focus["primary_driver"],
        "subcategory_primary_driver": sub_business_focus["primary_driver"],
        "aggregate_construct_gap": agg_construct,
        "aggregate_performance_gap": agg_performance,
        "subcategory_construct_gap": sub_construct,
        "subcategory_performance_gap": sub_performance,
        "analysis_type": analysis_type,
        "category_name": category_name
    }




def _detect_region_level_simpson(
    df: pd.DataFrame,
    region_column: str,
    numerator_column: str,
    denominator_column: str,
    subcategory_columns: List[str],
    **kwargs,
) -> pd.DataFrame:
    """
    Detect region-level Simpson's paradox using business recommendation conflicts.

    Uses the same philosophy as category-level detection: focuses on cases where
    overall regional analysis would lead to different business recommendations
    than product-level analysis. Scalable to any number of regions.

    Args:
        df: Input DataFrame
        region_column: Column name containing region identifiers
        numerator_column: Column name for numerator
        denominator_column: Column name for denominator
        subcategory_columns: List of subcategory columns (e.g., ["product"])
        **kwargs: Additional arguments

    Returns:
        DataFrame with region-level Simpson's paradox cases
    """

    paradox_cases = []

    # Get product-level decomposition using existing function
    product_level_results = oaxaca_blinder_decomposition(
        df,
        region_column=region_column,
        numerator_column=numerator_column,
        denominator_column=denominator_column,
        category_columns=subcategory_columns,
        **kwargs,
    )

    # Check each region for recommendation conflicts
    for region in df[region_column].unique():
        # Calculate overall gap using existing utility functions
        region_data = _filter_region_data(df, region_column, region)
        rest_data = _filter_region_data(df, region_column, region, exclude=True)

        if len(rest_data) == 0:
            continue

        region_rates = _calculate_rates_from_components(
            region_data, [], numerator_column, denominator_column
        )
        rest_rates = _calculate_rates_from_components(
            rest_data, [], numerator_column, denominator_column
        )

        region_rate = region_rates.iloc[0] if len(region_rates) > 0 else 0
        rest_rate = rest_rates.iloc[0] if len(rest_rates) > 0 else 0
        overall_gap = region_rate - rest_rate

        # Use centralized significance assessment
        gap_significance = assess_gap_significance(abs(overall_gap) * 100)
        if gap_significance == "low":
            continue

        # Get product-level analysis for this region
        region_products = product_level_results[product_level_results["region"] == region]
        if len(region_products) <= 1:
            continue

        # Calculate product-level metrics using existing functions
        product_construct_gap = region_products["construct_gap_contribution"].sum()
        product_performance_gap = region_products["performance_gap_contribution"].sum()
        product_business_focus = determine_business_focus(product_construct_gap, product_performance_gap)

        # Calculate product-level advantage
        product_wins = 0
        total_products = 0
        product_level_advantage = 0

        for _, product_row in region_products.iterrows():
            if product_row["region_rate"] > product_row["rest_rate"]:
                product_wins += 1
                product_level_advantage += (product_row["region_rate"] - product_row["rest_rate"]) * product_row["region_mix_pct"]
            total_products += 1

        # If region wins in majority of products, treat as positive product net gap
        if product_wins > total_products / 2:
            product_net_gap = abs(product_level_advantage)  # Positive when region wins in products
        else:
            product_net_gap = -abs(product_level_advantage)  # Negative when region loses in products

        # Check for paradox using centralized detection
        overall_gap_significance = assess_gap_significance(abs(overall_gap) * 100)
        product_gap_significance = assess_gap_significance(abs(product_net_gap) * 100)

        region_paradox_conditions = [
            (overall_gap < 0),  # Region appears worse overall
            (product_net_gap > 0),  # But wins in products
            (overall_gap_significance in ["medium", "high"]),  # Significant overall gap
            (product_gap_significance in ["medium", "high"])  # Significant product advantage
        ]

        paradox_detected = all(region_paradox_conditions)

        if paradox_detected:
            # Generate paradox info using unified function
            agg_data = {
                "overall_gap": overall_gap
            }
            detailed_data = {
                "product_construct_gap": product_construct_gap,
                "product_performance_gap": product_performance_gap,
                "product_net_gap": product_net_gap
            }

            paradox_info = analyze_paradox_impact(
                agg_data=agg_data,
                detailed_data=detailed_data,
                analysis_type="region",
                category_name=subcategory_columns[0]
            )

            paradox_info.update({
                "region": region,
                "category": "OVERALL",  # Region-level paradox
                "aggregate_gap": overall_gap,
                "product_recommendation": product_business_focus["business_conclusion"],
                "paradox_type": "region_level",
            })

            paradox_cases.append(paradox_info)

    return pd.DataFrame(paradox_cases)






def _assess_region_paradox_severity(gap_magnitude: float, num_contradicting_products: int) -> str:
    """Assess severity of region-level Simpson's paradox based on business impact."""

    if gap_magnitude > AnalysisThresholds.HIGH_IMPACT_GAP * 1.5 and num_contradicting_products >= 2:
        return "High"
    elif gap_magnitude > AnalysisThresholds.MEDIUM_IMPACT_GAP * 1.5 or num_contradicting_products >= 2:
        return "Medium"
    else:
        return "Low"


def enhanced_rca_analysis(
    df: pd.DataFrame,
    region_column: str = "territory_l4_name",
    numerator_column: str = "cli_won_cnt",
    denominator_column: str = "cli_cnt",
    category_columns: Union[str, List[str]] = "product",
    subcategory_columns: Optional[List[str]] = None,
    **kwargs,
) -> Dict:
    """
    Run comprehensive RCA with automatic detection of mathematical edge cases.

    This function provides a complete root cause analysis by:
    1. Running decomposition at the specified category level
    2. Automatically detecting mathematical cancellation scenarios (identical rates/mix)
    3. Automatically checking for Simpson's paradox per region if subcategories are provided
    4. Providing region-specific narratives with implicit edge case handling
    5. Providing actionable insights for business decision-making

    Args:
        df: Input DataFrame with regional performance data
        region_column: Column name containing region identifiers
        numerator_column: Column name containing numerator for rate calculation
        denominator_column: Column name containing denominator for rate calculation
        category_columns: Column name(s) defining the categorical breakdown
        subcategory_columns: List of subcategory column names for automatic paradox detection
        **kwargs: Additional arguments passed to oaxaca_blinder_decomposition

    Returns:
        Dictionary containing:
        - primary_analysis: Main decomposition results (automatically uses subcategory if paradox detected)
        - regional_analysis: Region-specific analysis with implicit narratives
        - regional_summary: Regional-level gap summary
        - top_drivers: Top drivers by region
        - actual_vs_decomposed: Validation of decomposition accuracy
        - mathematical_edge_cases: Detection results for identical rates/mix scenarios
        - simpson_paradox_cases: Detection results for Simpson's paradox

    Example:
        # Basic analysis
        results = enhanced_rca_analysis(df=df, category_columns="product")

        # Analysis with automatic edge case detection
        results = enhanced_rca_analysis(
            df=df, category_columns=["product"], subcategory_columns=["vertical"]
        )
    """

    # Ensure category_columns is a list
    if isinstance(category_columns, str):
        category_columns = [category_columns]

    # Step 1: Calculate actual regional rates for validation
    actual_regional_rates = {}
    for region in df[region_column].unique():
        region_data = _filter_region_data(df, region_column, region)
        rest_data = _filter_region_data(df, region_column, region, exclude=True)

        if len(rest_data) == 0:
            continue

        region_rates = _calculate_rates_from_components(
            region_data, [], numerator_column, denominator_column
        )
        rest_rates = _calculate_rates_from_components(
            rest_data, [], numerator_column, denominator_column
        )

        region_rate = region_rates.iloc[0] if len(region_rates) > 0 else 0
        rest_rate = rest_rates.iloc[0] if len(rest_rates) > 0 else 0

        actual_regional_rates[region] = {
            "region_rate": region_rate,
            "rest_rate": rest_rate,
            "actual_gap": region_rate - rest_rate,
        }

    # Step 2: Detect mathematical cancellation scenarios FIRST
    mathematical_edge_cases = detect_mathematical_cancellation_scenarios(
        df=df,
        region_column=region_column,
        numerator_column=numerator_column,
        denominator_column=denominator_column,
        category_columns=category_columns,
        tolerance=AnalysisThresholds.MATHEMATICAL_TOLERANCE,
    )

    # Step 3: Category-level decomposition
    category_results = oaxaca_blinder_decomposition(
        df,
        region_column=region_column,
        numerator_column=numerator_column,
        denominator_column=denominator_column,
        category_columns=category_columns,
        **kwargs,
    )

    # Step 4: Subcategory-level decomposition (if provided)
    subcategory_results = None
    if subcategory_columns is not None:
        subcategory_results = oaxaca_blinder_decomposition(
            df,
            region_column=region_column,
            numerator_column=numerator_column,
            denominator_column=denominator_column,
            category_columns=category_columns + subcategory_columns,
            **kwargs,
        )

    # Step 5: Simpson's paradox detection
    simpson_check = pd.DataFrame()
    if subcategory_columns is not None:
        simpson_check = detect_aggregation_bias(
            df,
            region_column=region_column,
            numerator_column=numerator_column,
            denominator_column=denominator_column,
            category_columns=category_columns,
            subcategory_columns=subcategory_columns,
            **kwargs,
        )

    # Step 6: Region-specific analysis with comprehensive edge case handling
    regional_analysis = {}

    for region in category_results["region"].unique():
        # Check if this specific region has Simpson's paradox
        region_paradox = (
            simpson_check[simpson_check["region"] == region]
            if len(simpson_check) > 0
            else pd.DataFrame()
        )

        has_simpson_paradox = len(region_paradox) > 0

        # Check if this region is affected by mathematical cancellation
        region_affected_by_cancellation = _check_region_affected_by_cancellation(
            region, mathematical_edge_cases
        )

        # Choose analysis level based on edge case detection
        if has_simpson_paradox and subcategory_results is not None:
            # Use subcategory analysis for the ENTIRE region when Simpson's paradox is detected
            region_data_for_analysis = subcategory_results[
                subcategory_results["region"] == region
            ]
            analysis_note = "detailed subcategory analysis used due to Simpson's paradox"
        elif region_affected_by_cancellation and subcategory_results is not None:
            # Use subcategory analysis when mathematical cancellation is detected
            region_data_for_analysis = subcategory_results[
                subcategory_results["region"] == region
            ]
            analysis_note = "detailed subcategory analysis used due to mathematical cancellation"
        else:
            # Use category analysis when no edge cases
            region_data_for_analysis = category_results[
                category_results["region"] == region
            ]
            analysis_note = "standard category analysis"

        # Get actual regional gap for reconciliation
        actual_gap = actual_regional_rates[region]["actual_gap"]

        # Generate enhanced business narrative with edge case awareness
        business_narrative = _generate_enhanced_business_narrative(
            region_data_for_analysis,
            region,
            has_simpson_paradox,
            region_affected_by_cancellation,
            mathematical_edge_cases,
            actual_gap
        )

        # Determine business conclusion based on gaps
        total_construct = region_data_for_analysis["construct_gap_contribution"].sum()
        total_performance = region_data_for_analysis[
            "performance_gap_contribution"
        ].sum()
        business_focus = determine_business_focus(total_construct, total_performance)
        business_conclusion = business_focus["business_conclusion"]

        regional_analysis[region] = {
            "region": region,
            "business_narrative": business_narrative,
            "business_conclusion": business_conclusion,
            "construct_gap": total_construct,
            "performance_gap": total_performance,
            "total_gap": region_data_for_analysis["net_gap"].sum(),
            "actual_gap": actual_gap,
            "analysis_data": region_data_for_analysis,
            "analysis_note": analysis_note,
            "has_simpson_paradox": has_simpson_paradox,
            "affected_by_cancellation": region_affected_by_cancellation,
        }

    # Step 7: Determine overall primary analysis
    # Use the most detailed level available
    if subcategory_results is not None:
        primary_analysis = subcategory_results
    else:
        primary_analysis = category_results

    # Step 8: Validate decomposition accuracy
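    # Sanity check on the decomposition: the actual regional gap should be reproduced,
    # up to rounding, by the sum of per-category net_gap values (construct plus
    # performance contributions); any residual is surfaced as validation_error below.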
    regional_summary = calculate_regional_gaps(primary_analysis)
    actual_vs_decomposed = []

    for region in regional_summary["region"]:
        actual_data = actual_regional_rates[region]
        decomposed_gap = regional_summary[regional_summary["region"] == region][
            "total_net_gap"
        ].iloc[0]

        actual_vs_decomposed.append(
            {
                "region": region,
                "actual_gap": actual_data["actual_gap"],
                "decomposed_gap": decomposed_gap,
                "validation_error": abs(actual_data["actual_gap"] - decomposed_gap),
                "region_rate": actual_data["region_rate"],
                "rest_rate": actual_data["rest_rate"],
            }
        )

    # Step 9: Add top drivers
    top_drivers = {}
    for region in primary_analysis["region"].unique():
        top_drivers[region] = find_top_drivers(
            primary_analysis, region=region, metric_type="weighted_impact", top_n=5
        )

    results = {
        "primary_analysis": primary_analysis,
        "regional_analysis": regional_analysis,
        "regional_summary": regional_summary,
        "top_drivers": top_drivers,
        "actual_vs_decomposed": pd.DataFrame(actual_vs_decomposed),
        "mathematical_edge_cases": mathematical_edge_cases,
        "simpson_paradox_cases": simpson_check,
    }

    return results


def _check_region_affected_by_cancellation(region: str, edge_cases: Dict) -> bool:
    """Check if a region is affected by mathematical cancellation scenarios."""

    # Check identical rates scenarios
    for scenario in edge_cases.get("identical_rates_detected", []):
        if region in scenario.get("regions_affected", []):
            return True

    # Check identical mix scenarios
    for scenario in edge_cases.get("identical_mix_detected", []):
        if region in scenario.get("regions_affected", []):
            return True

    return False


def _generate_enhanced_business_narrative(
    region_data: pd.DataFrame,
    region: str,
    has_simpson_paradox: bool = False,
    affected_by_cancellation: bool = False,
    edge_cases: Optional[Dict] = None,
    actual_gap: Optional[float] = None,
) -> str:
    """Generate enhanced business narrative with edge case awareness."""

    if len(region_data) == 0:
        return f"No data available for {region}"

    # Handle mathematical cancellation scenarios first
    if affected_by_cancellation and edge_cases:
        cancellation_insights = _extract_cancellation_insights_for_region(region, edge_cases)
        if cancellation_insights:
            headline_gap = actual_gap if actual_gap is not None else region_data["net_gap"].sum()
            headline_magnitude = abs(headline_gap) * 100
            direction = "outperforms" if headline_gap > 0 else "underperforms"
            return f"{region} {direction} by {headline_magnitude:.1f}pp. {cancellation_insights}"

    # Generate standard narrative using decomposed functions
    headline_narrative = _generate_headline_narrative(region_data, region, actual_gap)
    root_cause_narrative = _generate_root_cause_narrative(region_data, actual_gap)
    category_insights = _generate_category_insights(region_data, actual_gap)

    return headline_narrative + root_cause_narrative + category_insights


def _generate_headline_narrative(region_data: pd.DataFrame, region: str, actual_gap: Optional[float] = None) -> str:
    """Generate the headline performance narrative for a region."""
    total_gap = region_data["net_gap"].sum()
    headline_gap = actual_gap if actual_gap is not None else total_gap
    headline_magnitude = abs(headline_gap) * 100
    direction = "outperforms" if headline_gap > 0 else "underperforms"

    return f"{region} {direction} by {headline_magnitude:.1f}pp"


def _generate_root_cause_narrative(region_data: pd.DataFrame, actual_gap: Optional[float] = None) -> str:
    """Generate the root cause explanation narrative."""
    total_construct = region_data["construct_gap_contribution"].sum()
    total_performance = region_data["performance_gap_contribution"].sum()
    total_gap = region_data["net_gap"].sum()
    headline_gap = actual_gap if actual_gap is not None else total_gap

    business_focus = determine_business_focus(total_construct, total_performance)
    business_conclusion = business_focus["business_conclusion"]

    if business_conclusion == "composition_driven":
        if headline_gap < 0:
            primary_issue = "portfolio mix needs rebalancing"
        else:
            primary_issue = "strong portfolio positioning"
    elif business_conclusion == "performance_driven":
        if headline_gap < 0:
            primary_issue = "execution gaps within current focus areas"
        else:
            primary_issue = "strong execution across focus areas"
    else:
        if headline_gap < 0:
            primary_issue = "both allocation and execution opportunities"
        else:
            primary_issue = "balanced strength in allocation and execution"

    if headline_gap < 0:
        return f". Root cause: {primary_issue}"
    else:
        return f". Success driver: {primary_issue}"


def _generate_category_insights(region_data: pd.DataFrame, actual_gap: Optional[float] = None) -> str:
    """Generate specific category-level insights."""
    total_gap = region_data["net_gap"].sum()
    headline_gap = actual_gap if actual_gap is not None else total_gap

    # Find significant categories
    significant_categories = region_data[
        abs(region_data["net_gap"]) > AnalysisThresholds.MINOR_DETECTABLE_GAP
    ].copy()
    significant_categories = significant_categories.sort_values("net_gap", ascending=True)

    # Get top 2 most impactful categories
    top_impact_categories = (
        significant_categories.head(2) if headline_gap < 0 else significant_categories.tail(2)
    )

    key_insights = []
    for _, row in top_impact_categories.iterrows():
        if abs(row["net_gap"]) > AnalysisThresholds.MINOR_DETECTABLE_GAP:
            insight = _generate_single_category_insight(row)
            key_insights.append(insight)

    # Build insights text
    if key_insights:
        if headline_gap < 0:
            return f". Key areas: {'; '.join(key_insights[:2])}"
        else:
            return f". Key strengths: {'; '.join(key_insights[:2])}"
    else:
        return ""


def _generate_single_category_insight(row: pd.Series) -> str:
    """Generate insight for a single category."""
    category_name = CategoryDataExtractor.extract_name(row)
    construct_impact = row["construct_gap_contribution"]
    performance_impact = row["performance_gap_contribution"]

    # Use existing business focus determination
    category_business_focus = determine_business_focus(construct_impact, performance_impact)

    if category_business_focus["construct_dominance"]:
        # Allocation-focused insight using centralized thresholds
        region_mix = row["region_mix_pct"]
        rest_mix = row["rest_mix_pct"]
        rest_rate = row["rest_rate"]

        if region_mix < rest_mix * AnalysisThresholds.UNDER_ALLOCATION_THRESHOLD:
            if rest_rate > AnalysisThresholds.HIGH_PERFORMANCE_THRESHOLD:
                return f"under-allocated to high-performing {category_name} ({region_mix:.0%} vs {rest_mix:.0%} benchmark) - missing growth opportunity"
            else:
                return f"under-allocated to {category_name} ({region_mix:.0%} vs {rest_mix:.0%} benchmark) - but this may be strategic"
        elif region_mix > rest_mix * AnalysisThresholds.OVER_ALLOCATION_THRESHOLD:
            if rest_rate < AnalysisThresholds.HIGH_PERFORMANCE_THRESHOLD:
                return f"over-allocated to low-performing {category_name} ({region_mix:.0%} vs {rest_mix:.0%} benchmark) - dragging down overall results"
            else:
                return f"over-allocated to high-performing {category_name} ({region_mix:.0%} vs {rest_mix:.0%} benchmark) - good strategic focus"
        else:
            return f"allocation gap in {category_name} ({region_mix:.0%} vs {rest_mix:.0%} benchmark)"
    else:
        # Performance-focused insight
        region_rate = row["region_rate"]
        rest_rate = row["rest_rate"]
        net_gap = row["net_gap"]

        if net_gap < 0:
            return f"underperforming in {category_name} ({region_rate:.0%} vs {rest_rate:.0%} benchmark)"
        else:
            return f"strong execution in {category_name} ({region_rate:.0%} vs {rest_rate:.0%} benchmark)"


def _extract_cancellation_insights_for_region(region: str, edge_cases: Dict) -> str:
    """Extract specific insights for a region affected by mathematical cancellation."""

    insights = []

    # Check for relevant alternative insights
    for insight in edge_cases.get("alternative_insights", []):
        if region in insight.get("specific_insight", ""):
            if insight["type"] == "identical_rates_strategic_insight":
                insights.append(f"Strategic insight: {insight['specific_insight']} - {insight['business_implication']}")
            elif insight["type"] == "identical_mix_performance_insight":
                insights.append(f"Execution insight: {insight['specific_insight']} - {insight['business_implication']}")

    if insights:
        return " ".join(insights)

    # Fallback for detected scenarios without specific insights
    for scenario in edge_cases.get("identical_rates_detected", []):
        if region in scenario.get("regions_affected", []):
            return f"Note: Standard decomposition shows mathematical cancellation due to identical {scenario['identical_rate']:.1%} rates across segments. Consider volume-weighted strategic analysis."

    for scenario in edge_cases.get("identical_mix_detected", []):
        if region in scenario.get("regions_affected", []):
            return f"Note: Standard decomposition shows mathematical cancellation due to identical {scenario['identical_mix']:.1%} allocation across segments. Focus on execution differences."

    return ""




class CategoryDataExtractor:
    """Unified class for standardized category data extraction and processing."""

    @staticmethod
    def extract_name(row: pd.Series) -> str:
        """Extract clean category name from decomposition row."""
        category_cols = [col for col in row.index if col.startswith("category_")]
        if len(category_cols) == 1:
            return str(row[category_cols[0]])
        elif len(category_cols) > 1:
            parts = [str(row[col]) for col in sorted(category_cols)]
            return "-".join(parts)
        else:
            return "Unknown"

    @staticmethod
    def filter_data(df: pd.DataFrame, category_columns: List[str], target_category) -> pd.DataFrame:
        """Unified utility function for filtering DataFrame by category values."""
        filtered_data = df.copy()

        for i, col in enumerate(category_columns):
            if isinstance(target_category, tuple):
                filtered_data = filtered_data[filtered_data[col] == target_category[i]]
            else:
                filtered_data = filtered_data[filtered_data[col] == target_category]

        return filtered_data

    @staticmethod
    def calculate_metrics(
        df: pd.DataFrame,
        region_column: str,
        numerator_column: str,
        denominator_column: str,
        category_columns: List[str],
        target_category,
        region: str,
    ) -> Dict:
        """Unified utility function for calculating category-specific metrics for a region."""
        # Get region data
        region_data = _filter_region_data(df, region_column, region)

        # Filter for specific category
        category_data = CategoryDataExtractor.filter_data(region_data, category_columns, target_category)

        # Calculate metrics
        total_volume = region_data[denominator_column].sum()
        category_volume = category_data[denominator_column].sum()
        category_numerator = category_data[numerator_column].sum()

        # Calculate rate and mix
        category_rate = category_numerator / category_volume if category_volume > 0 else 0
        category_mix = category_volume / total_volume if total_volume > 0 else 0

        return {
            "rate": category_rate,
            "volume": category_volume,
            "numerator": category_numerator,
            "mix_percentage": category_mix,
            "total_volume": total_volume
        }




def detect_mathematical_cancellation_scenarios(
    df: pd.DataFrame,
    region_column: str,
    numerator_column: str,
    denominator_column: str,
    category_columns: List[str],
    tolerance: float = 0.001,
) -> Dict:
    """
    Detect scenarios where mathematical cancellation masks business insights.

    This function identifies two critical edge cases:
    1. Identical rates scenario: When subcategories have identical performance rates
       across regions, construct gaps cancel out mathematically
    2. Identical mix scenario: When regions have identical percentage shares for
       certain categories, similar cancellation can occur

    Args:
        df: Input DataFrame with regional performance data
        region_column: Column name containing region identifiers
        numerator_column: Column name containing numerator for rate calculation
        denominator_column: Column name containing denominator for rate calculation
        category_columns: List of category column names
        tolerance: Tolerance for considering rates/mixes as "identical" (default: 0.001)

    Returns:
        Dictionary containing:
        - identical_rates_detected: List of scenarios where rates are identical
        - identical_mix_detected: List of scenarios where mix is identical
        - alternative_insights: Business insights when mathematical cancellation occurs
        - recommendations: Specific recommendations for handling these scenarios
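
    Example (illustrative; column names are placeholders, not a fixed schema):

        edge_cases = detect_mathematical_cancellation_scenarios(
            df,
            region_column="territory_l4_name",
            numerator_column="cli_won_cnt",
            denominator_column="cli_cnt",
            category_columns=["product"],
        )
        if edge_cases["identical_rates_detected"]:
            ...  # supplement with volume-weighted / strategic analysis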
    """

    results = {
        "identical_rates_detected": [],
        "identical_mix_detected": [],
        "alternative_insights": [],
        "recommendations": []
    }

    # Get unique regions
    regions = df[region_column].unique()
    if len(regions) < 2:
        return results

    # Check for identical rates scenario
    identical_rates_scenarios = _detect_identical_rates_scenario(
        df, region_column, numerator_column, denominator_column,
        category_columns, tolerance
    )
    results["identical_rates_detected"] = identical_rates_scenarios

    # Check for identical mix scenario
    identical_mix_scenarios = _detect_identical_mix_scenario(
        df, region_column, denominator_column, category_columns, tolerance
    )
    results["identical_mix_detected"] = identical_mix_scenarios

    # Generate alternative insights for detected scenarios
    if identical_rates_scenarios or identical_mix_scenarios:
        alternative_insights = _generate_alternative_insights(
            df, region_column, numerator_column, denominator_column,
            category_columns, identical_rates_scenarios, identical_mix_scenarios
        )
        results["alternative_insights"] = alternative_insights

        # Generate recommendations
        recommendations = _generate_cancellation_recommendations(
            identical_rates_scenarios, identical_mix_scenarios
        )
        results["recommendations"] = recommendations

    return results


def _detect_identical_rates_scenario(
    df: pd.DataFrame,
    region_column: str,
    numerator_column: str,
    denominator_column: str,
    category_columns: List[str],
    tolerance: float,
) -> List[Dict]:
    """Detect when subcategories have identical rates across regions."""

    identical_rates_scenarios = []

    # Calculate rates by region and category using existing utility function
    rates_by_region_category = {}

    for region in df[region_column].unique():
        region_data = _filter_region_data(df, region_column, region)
        region_rates = _calculate_rates_from_components(
            region_data, category_columns, numerator_column, denominator_column
        )
        rates_by_region_category[region] = region_rates

    # Get all unique categories
    all_categories = set()
    for rates in rates_by_region_category.values():
        all_categories.update(rates.index)

    # Check each category for identical rates across regions
    for category in all_categories:
        category_rates = []
        regions_with_data = []

        for region, rates in rates_by_region_category.items():
            if category in rates and not pd.isna(rates[category]):
                category_rates.append(rates[category])
                regions_with_data.append(region)

        if len(category_rates) >= 2:
            # Check if all rates are identical (within tolerance)
            rate_differences = [abs(rate - category_rates[0]) for rate in category_rates]
            max_difference = max(rate_differences)

            if max_difference <= tolerance:
                # Calculate mix differences to see if there's still business insight
                mix_data = _calculate_mix_differences_for_category(
                    df, region_column, denominator_column, category_columns,
                    category, regions_with_data
                )

                identical_rates_scenarios.append({
                    "category": category,
                    "regions_affected": regions_with_data,
                    "identical_rate": category_rates[0],
                    "max_rate_difference": max_difference,
                    "mix_data": mix_data,
                    "business_impact": _assess_identical_rates_business_impact(mix_data)
                })

    return identical_rates_scenarios


def _detect_identical_mix_scenario(
    df: pd.DataFrame,
    region_column: str,
    denominator_column: str,
    category_columns: List[str],
    tolerance: float,
) -> List[Dict]:
    """Detect when regions have identical mix percentages for certain categories."""

    identical_mix_scenarios = []

    # Calculate mix percentages by region using existing utility function
    mix_by_region = {}

    for region in df[region_column].unique():
        region_data = _filter_region_data(df, region_column, region)
        region_mix_pct = _calculate_region_mix(region_data, category_columns, denominator_column)
        mix_by_region[region] = region_mix_pct

    # Get all unique categories
    all_categories = set()
    for mix_pct in mix_by_region.values():
        all_categories.update(mix_pct.index)

    # Check each category for identical mix across regions
    for category in all_categories:
        category_mixes = []
        regions_with_data = []

        for region, mix_pct in mix_by_region.items():
            if category in mix_pct and not pd.isna(mix_pct[category]):
                category_mixes.append(mix_pct[category])
                regions_with_data.append(region)

        if len(category_mixes) >= 2:
            # Check if all mixes are identical (within tolerance)
            mix_differences = [abs(mix - category_mixes[0]) for mix in category_mixes]
            max_difference = max(mix_differences)

            if max_difference <= tolerance:
                identical_mix_scenarios.append({
                    "category": category,
                    "regions_affected": regions_with_data,
                    "identical_mix": category_mixes[0],
                    "max_mix_difference": max_difference,
                })

    return identical_mix_scenarios


def _calculate_mix_differences_for_category(
    df: pd.DataFrame,
    region_column: str,
    denominator_column: str,
    category_columns: List[str],
    target_category,
    regions: List[str],
) -> Dict:
    """Calculate mix allocation differences for a specific category across regions."""

    mix_data = {}

    for region in regions:
        # Only mix/volume metrics are needed here, so reuse the denominator column as
        # the numerator to avoid a lookup on a non-existent "dummy_numerator" column.
        metrics = CategoryDataExtractor.calculate_metrics(
            df, region_column, denominator_column, denominator_column,
            category_columns, target_category, region
        )

        mix_data[region] = {
            "mix_percentage": metrics["mix_percentage"],
            "absolute_volume": metrics["volume"],
            "total_volume": metrics["total_volume"]
        }

    return mix_data


def _assess_identical_rates_business_impact(mix_data: Dict) -> str:
    """Assess the business impact when rates are identical but mix differs."""

    if not mix_data or len(mix_data) < 2:
        return "low"

    # Calculate mix variation
    mix_percentages = [data["mix_percentage"] for data in mix_data.values()]
    min_mix = min(mix_percentages)
    max_mix = max(mix_percentages)
    mix_range = max_mix - min_mix

    # Assess impact based on mix variation using centralized thresholds
    if mix_range > AnalysisThresholds.HIGH_MIX_VARIATION_THRESHOLD:
        return "high"
    elif mix_range > AnalysisThresholds.MEDIUM_MIX_VARIATION_THRESHOLD:
        return "medium"
    else:
        return "low"


def _generate_alternative_insights(
    df: pd.DataFrame,
    region_column: str,
    numerator_column: str,
    denominator_column: str,
    category_columns: List[str],
    identical_rates_scenarios: List[Dict],
    identical_mix_scenarios: List[Dict],
) -> List[Dict]:
    """Generate alternative business insights when mathematical cancellation occurs."""

    insights = []

    # Generate insights for identical rates scenarios
    for scenario in identical_rates_scenarios:
        if scenario["business_impact"] in ["medium", "high"]:
            category = scenario["category"]
            mix_data = scenario["mix_data"]

            # Find regions with highest and lowest allocation
            region_allocations = [(region, data["mix_percentage"])
                                for region, data in mix_data.items()]
            region_allocations.sort(key=lambda x: x[1])

            lowest_region, lowest_mix = region_allocations[0]
            highest_region, highest_mix = region_allocations[-1]

            mix_gap = (highest_mix - lowest_mix) * 100  # Convert to percentage points

            insight = {
                "type": "identical_rates_strategic_insight",
                "category": category,
                "identical_rate": scenario["identical_rate"],
                "key_finding": f"Despite identical {scenario['identical_rate']:.1%} performance rates, "
                             f"allocation varies significantly across regions",
                "specific_insight": f"{highest_region} allocates {mix_gap:.1f}pp more to {category} "
                                 f"than {lowest_region} ({highest_mix:.1%} vs {lowest_mix:.1%})",
                "business_implication": _determine_allocation_implication(
                    scenario["identical_rate"], mix_gap
                ),
                "recommended_analysis": "Consider segment value weighting or strategic importance analysis"
            }
            insights.append(insight)

    # Generate insights for identical mix scenarios
    for scenario in identical_mix_scenarios:
        category = scenario["category"]
        regions = scenario["regions_affected"]

        # Calculate rate differences for this category across regions using existing utility
        rate_differences = []
        for region in regions:
            # Use consolidated utility function
            metrics = CategoryDataExtractor.calculate_metrics(
                df, region_column, numerator_column, denominator_column,
                category_columns, category, region
            )
            rate_differences.append((region, metrics["rate"]))

        if len(rate_differences) >= 2:
            rate_differences.sort(key=lambda x: x[1])
            lowest_region, lowest_rate = rate_differences[0]
            highest_region, highest_rate = rate_differences[-1]

            rate_gap = (highest_rate - lowest_rate) * 100

            if rate_gap > AnalysisThresholds.SIGNIFICANT_RATE_DIFFERENCE:  # Use centralized threshold
                insight = {
                    "type": "identical_mix_performance_insight",
                    "category": category,
                    "identical_mix": scenario["identical_mix"],
                    "key_finding": f"Despite identical {scenario['identical_mix']:.1%} allocation, "
                                 f"performance varies significantly across regions",
                    "specific_insight": f"{highest_region} outperforms {lowest_region} by "
                                      f"{rate_gap:.1f}pp in {category} ({highest_rate:.1%} vs {lowest_rate:.1%})",
                    "business_implication": "Execution gap - opportunity to replicate best practices",
                    "recommended_analysis": "Deep-dive into operational differences and best practice sharing"
                }
                insights.append(insight)

    return insights


def _determine_allocation_implication(rate: float, mix_gap: float) -> str:
    """Determine business implication of allocation differences given performance rate."""

    if rate > AnalysisThresholds.SIGNIFICANT_ALLOCATION_IMPLICATION_THRESHOLD:  # High-performing segment
        if mix_gap > AnalysisThresholds.SIGNIFICANT_MIX_DIFFERENCE_THRESHOLD:
            return "Significant missed opportunity - under-allocated region missing high-value growth"
        else:
            return "Moderate opportunity - consider increasing allocation to high-performing segment"
    elif rate < AnalysisThresholds.LOW_PERFORMING_ALLOCATION_THRESHOLD:  # Low-performing segment
        if mix_gap > AnalysisThresholds.SIGNIFICANT_MIX_DIFFERENCE_THRESHOLD:
            return "Strategic concern - over-allocated region may be dragging down results"
        else:
            return "Minor strategic difference - monitor allocation efficiency"
    else:  # Medium-performing segment
        return "Strategic choice difference - evaluate based on broader portfolio strategy"


def _generate_cancellation_recommendations(
    identical_rates_scenarios: List[Dict],
    identical_mix_scenarios: List[Dict],
) -> List[str]:
    """Generate specific recommendations for handling mathematical cancellation scenarios."""

    recommendations = []

    if identical_rates_scenarios:
        recommendations.append(
            "Identical rates detected: Supplement standard decomposition with volume-weighted "
            "strategic analysis to capture allocation impact"
        )

        high_impact_scenarios = [s for s in identical_rates_scenarios
                               if s["business_impact"] == "high"]
        if high_impact_scenarios:
            recommendations.append(
                "High-impact allocation differences found: Consider segment value weighting "
                "or strategic importance scoring to quantify true business impact"
            )

    if identical_mix_scenarios:
        recommendations.append(
            "Identical mix detected: Focus analysis on execution differences and "
            "best practice identification across regions"
        )

    if identical_rates_scenarios and identical_mix_scenarios:
        recommendations.append(
            "Multiple cancellation scenarios detected: Consider multi-dimensional analysis "
            "combining volume, value, and strategic importance metrics"
        )

    return recommendations
