# Listing metadata: 229 lines, 8.6 KiB, Python
from pathlib import Path
|
|
from typing import Generator, Dict, Tuple, Optional
|
|
import pandas as pd
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
import matplotlib.ticker as mticker
|
|
from scipy.stats import linregress
|
|
from datetime import datetime
|
|
from ..utils import style_plot, LIME
|
|
|
|
def _sum_wages_within_horizon(df_with_wages, duration_col, horizons):
    """Return, for each horizon, the summed `wage_per_task` of tasks with duration <= horizon.

    Equivalent to evaluating
    ``df_with_wages.loc[df_with_wages[duration_col] <= h, 'wage_per_task'].sum()``
    for every ``h`` in ``horizons``, but done with one sort, one cumulative sum
    and a binary search instead of a full boolean mask per horizon.

    Args:
        df_with_wages: DataFrame holding ``duration_col`` and ``wage_per_task``.
        duration_col: Name of the column with each task's duration estimate.
        horizons: Array-like of horizon values to evaluate.

    Returns:
        A float ndarray with one total per entry of ``horizons``.
    """
    durations = df_with_wages[duration_col].to_numpy(dtype=float, na_value=np.nan)
    wages = df_with_wages["wage_per_task"].to_numpy(dtype=float, na_value=np.nan)
    horizons = np.asarray(horizons, dtype=float)

    # A NaN duration never satisfies `duration <= h`, and pandas' sum() skips
    # NaN wages; mirror both so the totals match the mask-and-sum formulation.
    usable = ~np.isnan(durations)
    durations = durations[usable]
    wages = np.where(np.isnan(wages[usable]), 0.0, wages[usable])

    order = np.argsort(durations, kind="stable")
    # running_total[k] is the wage sum of the k shortest tasks (running_total[0] == 0).
    running_total = np.concatenate(([0.0], np.cumsum(wages[order])))
    # side="right" counts tasks whose duration is <= the horizon (ties included).
    tasks_within = np.searchsorted(durations[order], horizons, side="right")

    # `duration <= NaN` is always False, so a NaN horizon automates nothing.
    return np.where(np.isnan(horizons), 0.0, running_total[tasks_within])


def _generate_wage_projection_data(
    metr_results: Dict,
    df_with_wages: pd.DataFrame,
    percentile_key: str,
    doubling_time_modifier: float,
    projection_end: str = "2035-01-01",
) -> Optional[Tuple[pd.DataFrame, pd.DataFrame, float]]:
    """
    Generates wage projection data for different AI progress scenarios.

    A log-linear trend is fitted to the horizon estimates over release date,
    its doubling time is scaled by ``doubling_time_modifier``, and the
    resulting horizon is projected on month-end dates up to ``projection_end``.
    For each projected horizon the wages of all tasks whose duration estimate
    fits within that horizon are summed.

    Args:
        metr_results: The METR benchmark data. Read as
            ``results -> <model> -> {"release_date", "agents"}`` and
            ``agents -> <agent> -> <percentile_key> -> "estimate"``.
        df_with_wages: DataFrame containing tasks with their estimated wage value.
            Must provide ``lb_estimate_in_minutes``, ``estimate_midpoint``,
            ``ub_estimate_in_minutes`` and ``wage_per_task``.
        percentile_key: The percentile to use from METR data (e.g., 'p50_horizon_length').
        doubling_time_modifier: Multiplier for the doubling time (e.g., 1.0 for baseline,
            0.5 for optimistic, 2.0 for pessimistic). Must be positive.
        projection_end: Last date (inclusive bound) of the month-end projection grid.

    Returns:
        A tuple of (metr_df, projection_df, doubling_time_days), or None if data is
        insufficient: no usable entries, fewer than two positive horizons, or all
        remaining release dates identical (no trend can be fitted).
        A non-positive fitted slope is not special-cased; it yields a negative or
        infinite doubling time.

    Raises:
        ValueError: If ``doubling_time_modifier`` is not strictly positive.
    """
    if not doubling_time_modifier > 0:
        raise ValueError(
            f"doubling_time_modifier must be positive, got {doubling_time_modifier!r}"
        )

    all_model_data = []
    for data in metr_results.get("results", {}).values():
        release_date_str = data.get("release_date")
        if not release_date_str:
            continue
        for agent_data in data.get("agents", {}).values():
            horizon = agent_data.get(percentile_key, {}).get("estimate")
            if horizon is not None:
                all_model_data.append({
                    "release_date": release_date_str,
                    "horizon_minutes": horizon,
                })

    if not all_model_data:
        return None

    metr_df = pd.DataFrame(all_model_data)
    # Parse before sorting so the order is chronological rather than lexical.
    metr_df["release_date"] = pd.to_datetime(metr_df["release_date"])
    # The fit is done in log space, so only strictly positive horizons are usable.
    metr_df = (
        metr_df[metr_df["horizon_minutes"] > 0]
        .sort_values("release_date", kind="stable")
        .reset_index(drop=True)
    )

    if len(metr_df) < 2:
        return None

    start_date = metr_df["release_date"].min()
    metr_df["days_since_start"] = (metr_df["release_date"] - start_date).dt.days

    # linregress raises ValueError when every x value is identical.
    if metr_df["days_since_start"].nunique() < 2:
        return None

    fit = linregress(metr_df["days_since_start"], np.log(metr_df["horizon_minutes"]))
    slope, intercept = fit.slope, fit.intercept

    # Apply the scenario modifier to the doubling time
    base_doubling_time_days = np.log(2) / slope
    modified_doubling_time_days = base_doubling_time_days * doubling_time_modifier
    # Same value as log(2) / modified_doubling_time_days, without dividing by
    # an infinite doubling time when the fitted slope is zero.
    modified_slope = slope / doubling_time_modifier

    # MonthEnd() is the offset behind the "ME" alias; passing the object keeps
    # the grid identical while not depending on the alias being available.
    future_dates = pd.date_range(
        start=start_date, end=projection_end, freq=pd.offsets.MonthEnd()
    )
    future_days = (future_dates - start_date).days.to_numpy()

    projected_horizon_minutes = np.exp(intercept + modified_slope * future_days)

    projection_df = pd.DataFrame({
        "date": future_dates,
        "projected_coherence_minutes": projected_horizon_minutes,
    })

    # Calculate the total wage bill of tasks automated over time
    duration_columns = {
        "lb": "lb_estimate_in_minutes",
        "mid": "estimate_midpoint",
        "ub": "ub_estimate_in_minutes",
    }
    for bound, col_name in duration_columns.items():
        projection_df[f"automatable_wage_bill_{bound}"] = _sum_wages_within_horizon(
            df_with_wages, col_name, projected_horizon_minutes
        )

    # Also calculate for the actual METR data points for plotting
    metr_df["automatable_wage_bill_mid"] = _sum_wages_within_horizon(
        df_with_wages, "estimate_midpoint", metr_df["horizon_minutes"]
    )

    return metr_df, projection_df, float(modified_doubling_time_days)
|
|
|
|
|
|
def _plot_scenario(ax, projection_df, metr_df, label, color, line_style='-'):
    """Helper function to draw a single projection scenario on a given axis.

    Draws three layers on ``ax``: the projected mid-estimate wage bill as a
    line, the lower/upper-bound range as a translucent band, and the observed
    data points as a scatter.

    Args:
        ax: Axis-like object providing ``plot``, ``fill_between`` and ``scatter``.
        projection_df: Table with ``date``, ``automatable_wage_bill_mid``,
            ``automatable_wage_bill_lb`` and ``automatable_wage_bill_ub``.
        metr_df: Table with ``release_date`` and ``automatable_wage_bill_mid``.
        label: Legend label for the projection line.
        color: Colour shared by the line, the band and the scatter points.
        line_style: Line style of the projection line.

    Note:
        Every call adds a scatter series labelled "Model Capabilities (P50)",
        so calling this once per scenario produces repeated legend entries.
    """
    # Plot the projected wage bill
    ax.plot(
        projection_df["date"],
        projection_df["automatable_wage_bill_mid"],
        label=label,
        color=color,
        linewidth=2.5,
        linestyle=line_style,
        zorder=3,
    )
    # Plot the shaded range for lower/upper bounds
    ax.fill_between(
        projection_df["date"],
        projection_df["automatable_wage_bill_lb"],
        projection_df["automatable_wage_bill_ub"],
        color=color,
        alpha=0.15,
        zorder=2,
    )
    # Plot the actual METR data points against the wage bill
    ax.scatter(
        metr_df['release_date'],
        metr_df['automatable_wage_bill_mid'],
        color=color,
        edgecolor='black',
        s=60,
        zorder=4,
        label="Model Capabilities (P50)",
    )
|
|
|
|
|
|
def generate_projected_automatable_wage_bill(
    output_dir: Path,
    df: pd.DataFrame,
    task_summary_by_occupation_df: pd.DataFrame,
    metr_results: Dict,
    **kwargs,
) -> Generator[Path, None, None]:
    """
    Generates a plot projecting the automatable wage bill under different
    AI progress scenarios (optimistic, baseline, pessimistic).

    Args:
        output_dir: Directory the PNG is written to (created if missing).
        df: Task-level DataFrame. Must contain ``onetsoc_code`` plus the
            duration columns read by ``_generate_wage_projection_data``.
        task_summary_by_occupation_df: Per-occupation summary with
            ``onetsoc_code``, ``wage_bill`` and ``total_tasks``.
        metr_results: The METR benchmark data passed through to
            ``_generate_wage_projection_data``.
        **kwargs: Accepted and ignored.

    Yields:
        The path of the saved plot. Nothing is yielded when no scenario
        could be projected (a warning is printed instead).
    """
    style_plot()
    output_path = output_dir / "projected_automatable_wage_bill_sensitivity.png"

    # 1. Calculate wage_per_task for each occupation
    wage_bill_info = task_summary_by_occupation_df[['onetsoc_code', 'wage_bill', 'total_tasks']].copy()
    wage_per_task = wage_bill_info['wage_bill'] / wage_bill_info['total_tasks']
    # Division by zero tasks yields +/-inf; treat those occupations as zero wage.
    wage_bill_info['wage_per_task'] = wage_per_task.replace([np.inf, -np.inf], 0)
    wage_bill_info = wage_bill_info[['onetsoc_code', 'wage_per_task']]

    # 2. Merge wage_per_task into the main task dataframe
    df_with_wages = pd.merge(df, wage_bill_info, on='onetsoc_code', how='left')
    # Assign the result back: a chained `.fillna(..., inplace=True)` on the
    # column does not update the frame under pandas Copy-on-Write.
    df_with_wages['wage_per_task'] = df_with_wages['wage_per_task'].fillna(0)

    # 3. Generate data for all three scenarios
    scenarios = {
        "Optimistic": {"modifier": 0.5, "color": "tab:green", "style": "--"},
        "Baseline": {"modifier": 1.0, "color": LIME['600'], "style": "-"},
        "Pessimistic": {"modifier": 2.0, "color": "tab:red", "style": ":"},
    }

    projection_results = {}
    for name, config in scenarios.items():
        result = _generate_wage_projection_data(
            metr_results, df_with_wages, 'p50_horizon_length', config['modifier']
        )
        if result is not None:
            projection_results[name] = result

    if not projection_results:
        print("Warning: Could not generate any projection data. Skipping wage bill plot.")
        return

    # 4. Create the plot
    fig, ax = plt.subplots(figsize=(14, 9))
    try:
        # We only need to plot the scatter points once, let's use the baseline ones.
        if "Baseline" in projection_results:
            baseline_metr_df, _, _ = projection_results["Baseline"]
            ax.scatter(
                baseline_metr_df['release_date'],
                baseline_metr_df['automatable_wage_bill_mid'],
                color='black',
                s=80,
                zorder=5,
                label="Model Capabilities (P50)",
            )

        legend_lines = []
        for name, (_, proj_df, doubling_time) in projection_results.items():
            config = scenarios[name]
            ax.plot(
                proj_df["date"],
                proj_df["automatable_wage_bill_mid"],
                color=config['color'],
                linestyle=config['style'],
                linewidth=2.5,
                zorder=3,
            )
            ax.fill_between(
                proj_df["date"],
                proj_df["automatable_wage_bill_lb"],
                proj_df["automatable_wage_bill_ub"],
                color=config['color'],
                alpha=0.15,
                zorder=2,
            )
            # Create a custom line for the legend
            legend_lines.append(
                plt.Line2D(
                    [0], [0],
                    color=config['color'],
                    linestyle=config['style'],
                    lw=2.5,
                    label=f'{name} (Doubling Time: {doubling_time:.0f} days)',
                )
            )

        # 5. Styling and annotations
        ax.set_title("Projected Automatable Wage Bill (P50 Coherence)", fontsize=18, pad=20)
        ax.set_xlabel("Year", fontsize=12)
        ax.set_ylabel("Automatable Annual Wage Bill (Trillions of USD)", fontsize=12)

        # Format Y-axis to show trillions
        def trillions_formatter(x, pos):
            return f'${x / 1e12:.1f}T'

        ax.yaxis.set_major_formatter(mticker.FuncFormatter(trillions_formatter))

        total_wage_bill = df_with_wages['wage_per_task'].sum()
        # Identical lower/upper limits are degenerate; keep autoscaling instead.
        if total_wage_bill > 0:
            ax.set_ylim(0, total_wage_bill * 1.05)

        if "Baseline" in projection_results:
            _, baseline_proj_df, _ = projection_results["Baseline"]
            ax.set_xlim(datetime(2022, 1, 1), baseline_proj_df["date"].max())

        # Create the legend from the custom lines and the scatter plot
        scatter_handles, _ = ax.get_legend_handles_labels()
        ax.legend(handles=legend_lines + scatter_handles, loc="upper left", fontsize=11)

        ax.grid(True, which="both", linestyle="--", linewidth=0.5)
        fig.tight_layout()

        output_dir.mkdir(parents=True, exist_ok=True)
        fig.savefig(output_path)
    finally:
        # Release the figure even if styling or saving fails.
        plt.close(fig)

    print(f"Generated sensitivity analysis plot: {output_path}")
    yield output_path
|