from datetime import datetime
from pathlib import Path
from typing import Dict, Generator, Optional, Tuple

import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np
import pandas as pd
from scipy.stats import linregress

from ..utils import style_plot, LIME


def _wage_bill_at_horizons(
    horizons: pd.Series,
    df_with_wages: pd.DataFrame,
    estimate_col: str,
) -> pd.Series:
    """
    Sum ``wage_per_task`` over the tasks that fit within each horizon.

    Args:
        horizons: Horizon lengths to evaluate, one result per entry.
        df_with_wages: Task DataFrame with a 'wage_per_task' column and the
            column named by ``estimate_col``.
        estimate_col: Column holding the per-task duration estimate that is
            compared against each horizon.

    Returns:
        A Series aligned with ``horizons`` holding, for each horizon ``h``,
        the sum of 'wage_per_task' over rows where ``estimate_col <= h``.
    """
    task_estimates = df_with_wages[estimate_col]
    task_wages = df_with_wages['wage_per_task']
    return horizons.apply(lambda h: task_wages[task_estimates <= h].sum())


def _generate_wage_projection_data(
    metr_results: Dict,
    df_with_wages: pd.DataFrame,
    percentile_key: str,
    doubling_time_modifier: float,
) -> Optional[Tuple[pd.DataFrame, pd.DataFrame, float]]:
    """
    Generates wage projection data for different AI progress scenarios.

    A log-linear trend is fitted to horizon length versus release date, the
    resulting doubling time is scaled by ``doubling_time_modifier``, and the
    trend is extrapolated month by month up to 2035-01-01. For every projected
    horizon the wage of all tasks whose estimate fits within it is summed.

    Args:
        metr_results: The METR benchmark data. Read as
            ``results -> <model> -> {release_date, agents -> <agent> ->
            {<percentile_key>: {estimate}}}``.
        df_with_wages: DataFrame containing tasks with their estimated wage
            value ('wage_per_task') and the duration columns
            'lb_estimate_in_minutes', 'estimate_midpoint' and
            'ub_estimate_in_minutes'.
        percentile_key: The percentile to use from METR data
            (e.g., 'p50_horizon_length').
        doubling_time_modifier: Multiplier for the doubling time (e.g., 1.0
            for baseline, 0.5 for optimistic, 2.0 for pessimistic). Expected
            to be positive; it is used as a divisor of the fitted slope.

    Returns:
        A tuple of (metr_df, projection_df, doubling_time_days), or None if
        data is insufficient (fewer than two usable points, or all points
        sharing one release date so no trend can be fitted).
    """
    all_model_data = []
    for model_data in metr_results.get("results", {}).values():
        release_date_str = model_data.get("release_date")
        if not release_date_str:
            continue
        for agent_data in model_data.get("agents", {}).values():
            # `or {}` also covers a percentile entry that is present but None.
            horizon = (agent_data.get(percentile_key) or {}).get("estimate")
            if horizon is not None:
                all_model_data.append({
                    "release_date": release_date_str,
                    "horizon_minutes": horizon,
                })

    if not all_model_data:
        return None

    metr_df = pd.DataFrame(all_model_data)
    # Convert before sorting so ordering is chronological, not lexicographic.
    metr_df['release_date'] = pd.to_datetime(metr_df['release_date'])
    # Non-positive horizons cannot be log-transformed for the fit below.
    metr_df = (
        metr_df[metr_df['horizon_minutes'] > 0]
        .sort_values("release_date")
        .reset_index(drop=True)
    )

    if len(metr_df) < 2:
        return None

    start_date = metr_df['release_date'].min()
    metr_df['days_since_start'] = (metr_df['release_date'] - start_date).dt.days

    # linregress raises ValueError when every x value is identical; treat
    # that as insufficient data instead of crashing.
    if metr_df['days_since_start'].nunique() < 2:
        return None

    fit = linregress(metr_df['days_since_start'], np.log(metr_df['horizon_minutes']))

    # Apply the scenario modifier to the doubling time
    base_doubling_time_days = np.log(2) / fit.slope
    modified_doubling_time_days = base_doubling_time_days * doubling_time_modifier
    modified_slope = np.log(2) / modified_doubling_time_days

    future_dates = pd.date_range(start=start_date, end="2035-01-01", freq="ME")
    future_days = (future_dates - start_date).days.to_numpy()

    # Every scenario keeps the fitted intercept, so the scenarios diverge
    # from the first release date onwards.
    projected_log_horizon = fit.intercept + modified_slope * future_days

    projection_df = pd.DataFrame({
        "date": future_dates,
        "projected_coherence_minutes": np.exp(projected_log_horizon),
    })

    # Calculate the total wage bill of tasks automated over time
    estimate_columns = {
        "lb": "lb_estimate_in_minutes",
        "mid": "estimate_midpoint",
        "ub": "ub_estimate_in_minutes",
    }
    for bound, estimate_col in estimate_columns.items():
        projection_df[f"automatable_wage_bill_{bound}"] = _wage_bill_at_horizons(
            projection_df["projected_coherence_minutes"], df_with_wages, estimate_col
        )

    # Also calculate for the actual METR data points for plotting
    metr_df["automatable_wage_bill_mid"] = _wage_bill_at_horizons(
        metr_df["horizon_minutes"], df_with_wages, "estimate_midpoint"
    )

    return metr_df, projection_df, modified_doubling_time_days


def _plot_scenario(ax, projection_df, metr_df, label, color, line_style='-'):
    """
    Helper function to draw a single projection scenario on a given axis.

    Draws the mid projection line, the lower/upper-bound band and the METR
    data points, all in ``color``.

    NOTE(review): nothing in this module calls this helper; the public
    generator below draws its scenarios inline. Confirm there are no external
    users before removing it.

    Args:
        ax: Matplotlib axis to draw on.
        projection_df: Projection data with 'date' and
            'automatable_wage_bill_{lb,mid,ub}' columns.
        metr_df: Observed data with 'release_date' and
            'automatable_wage_bill_mid' columns.
        label: Legend label for the projection line.
        color: Colour used for the line, band and points.
        line_style: Matplotlib line style for the projection line.
    """
    # Plot the projected wage bill
    ax.plot(
        projection_df["date"],
        projection_df["automatable_wage_bill_mid"],
        label=label,
        color=color,
        linewidth=2.5,
        linestyle=line_style,
        zorder=3,
    )

    # Plot the shaded range for lower/upper bounds
    ax.fill_between(
        projection_df["date"],
        projection_df["automatable_wage_bill_lb"],
        projection_df["automatable_wage_bill_ub"],
        color=color,
        alpha=0.15,
        zorder=2,
    )

    # Plot the actual METR data points against the wage bill
    ax.scatter(
        metr_df['release_date'],
        metr_df['automatable_wage_bill_mid'],
        color=color,
        edgecolor='black',
        s=60,
        zorder=4,
        label="Model Capabilities (P50)",
    )


def _trillions_formatter(x, pos):
    """Format an axis tick value in USD as trillions, e.g. ``$1.5T``."""
    return f'${x / 1e12:.1f}T'


def generate_projected_automatable_wage_bill(
    output_dir: Path,
    df: pd.DataFrame,
    task_summary_by_occupation_df: pd.DataFrame,
    metr_results: Dict,
    **kwargs,
) -> Generator[Path, None, None]:
    """
    Generates a plot projecting the automatable wage bill under different AI
    progress scenarios (optimistic, baseline, pessimistic).

    Args:
        output_dir: Directory the PNG is written to.
        df: Task-level DataFrame; must carry 'onetsoc_code' and the duration
            estimate columns used by the projection.
        task_summary_by_occupation_df: Per-occupation summary with
            'onetsoc_code', 'wage_bill' and 'total_tasks' columns.
        metr_results: The METR benchmark data.
        **kwargs: Accepted and ignored.

    Yields:
        The path of the saved plot. Nothing is yielded when no projection
        could be generated.
    """
    style_plot()
    output_path = output_dir / "projected_automatable_wage_bill_sensitivity.png"

    # 1. Calculate wage_per_task for each occupation
    wage_bill_info = task_summary_by_occupation_df[['onetsoc_code', 'wage_bill', 'total_tasks']].copy()
    # A zero task count yields +/-inf; count those occupations as zero wage.
    wage_bill_info['wage_per_task'] = (
        wage_bill_info['wage_bill'] / wage_bill_info['total_tasks']
    ).replace([np.inf, -np.inf], 0)
    wage_bill_info = wage_bill_info.drop(columns=['wage_bill', 'total_tasks'])

    # 2. Merge wage_per_task into the main task dataframe
    df_with_wages = pd.merge(df, wage_bill_info, on='onetsoc_code', how='left')
    # Assign the result back: an in-place fillna on a selected column does not
    # reliably update the parent frame (no-op under pandas Copy-on-Write).
    df_with_wages['wage_per_task'] = df_with_wages['wage_per_task'].fillna(0)

    # 3. Generate data for all three scenarios
    scenarios = {
        "Optimistic": {"modifier": 0.5, "color": "tab:green", "style": "--"},
        "Baseline": {"modifier": 1.0, "color": LIME['600'], "style": "-"},
        "Pessimistic": {"modifier": 2.0, "color": "tab:red", "style": ":"},
    }

    projection_results = {}
    for name, config in scenarios.items():
        result = _generate_wage_projection_data(
            metr_results, df_with_wages, 'p50_horizon_length', config['modifier']
        )
        if result is not None:
            projection_results[name] = result

    if not projection_results:
        print("Warning: Could not generate any projection data. Skipping wage bill plot.")
        return

    baseline = projection_results.get("Baseline")

    # 4. Create the plot
    fig, ax = plt.subplots(figsize=(14, 9))
    try:
        # We only need to plot the scatter points once, let's use the baseline ones.
        if baseline is not None:
            baseline_metr_df = baseline[0]
            ax.scatter(
                baseline_metr_df['release_date'],
                baseline_metr_df['automatable_wage_bill_mid'],
                color='black',
                s=80,
                zorder=5,
                label="Model Capabilities (P50)",
            )

        legend_lines = []
        for name, (_, proj_df, doubling_time) in projection_results.items():
            config = scenarios[name]
            ax.plot(
                proj_df["date"],
                proj_df["automatable_wage_bill_mid"],
                color=config['color'],
                linestyle=config['style'],
                linewidth=2.5,
                zorder=3,
            )
            ax.fill_between(
                proj_df["date"],
                proj_df["automatable_wage_bill_lb"],
                proj_df["automatable_wage_bill_ub"],
                color=config['color'],
                alpha=0.15,
                zorder=2,
            )
            # Create a custom line for the legend
            legend_lines.append(
                plt.Line2D(
                    [0], [0],
                    color=config['color'],
                    linestyle=config['style'],
                    lw=2.5,
                    label=f'{name} (Doubling Time: {doubling_time:.0f} days)',
                )
            )

        # 5. Styling and annotations
        ax.set_title("Projected Automatable Wage Bill (P50 Coherence)", fontsize=18, pad=20)
        ax.set_xlabel("Year", fontsize=12)
        ax.set_ylabel("Automatable Annual Wage Bill (Trillions of USD)", fontsize=12)

        # Format Y-axis to show trillions
        ax.yaxis.set_major_formatter(mticker.FuncFormatter(_trillions_formatter))

        # Cap the axis slightly above the wage bill of all tasks combined.
        total_wage_bill = df_with_wages['wage_per_task'].sum()
        ax.set_ylim(0, total_wage_bill * 1.05)

        if baseline is not None:
            ax.set_xlim(datetime(2022, 1, 1), baseline[1]["date"].max())

        # Create the legend from the custom lines and the scatter plot
        scatter_legend = ax.get_legend_handles_labels()[0]
        ax.legend(handles=legend_lines + scatter_legend, loc="upper left", fontsize=11)

        ax.grid(True, which="both", linestyle="--", linewidth=0.5)
        fig.tight_layout()
        fig.savefig(output_path)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)

    print(f"Generated sensitivity analysis plot: {output_path}")
    yield output_path