sprint-econtai/analysis/generators/wage_bill_by_occupation.py
Félix Dorn 43076bcbb1 old
2025-07-15 00:41:05 +02:00

150 lines
6.1 KiB
Python

import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import pandas as pd
from pathlib import Path
import tempfile
import logging
# get_db_connection is provided by data.py in the parent package (hence the `..` relative import)
from ..data import get_db_connection
# Lookup table from the O*NET 2-digit major group code (the first two
# characters of an occupation code) to the human-readable label shown on
# the plot's y-axis.
OCCUPATION_MAJOR_CODES = {
    "11": "Management",
    "13": "Business & Financial",
    "15": "Computer & Mathematical",
    "17": "Architecture & Engineering",
    "19": "Life, Physical, & Social Science",
    "21": "Community & Social Service",
    "23": "Legal",
    "25": "Education, Training, & Library",
    "27": "Arts, Design, & Media",
    "29": "Healthcare Practitioners",
    "31": "Healthcare Support",
    "33": "Protective Service",
    "35": "Food Preparation & Serving",
    "37": "Building & Grounds Maintenance",
    "39": "Personal Care & Service",
    "41": "Sales & Related",
    "43": "Office & Admin Support",
    "45": "Farming, Fishing, & Forestry",
    "47": "Construction & Extraction",
    "49": "Installation, Maintenance, & Repair",
    "51": "Production",
    "53": "Transportation & Material Moving",
    "55": "Military Specific",
}
def _format_usd(x, pos):
    """Format an x-axis tick value as a compact dollar amount.

    The ``(value, position)`` signature is the one expected by
    ``matplotlib.ticker.FuncFormatter``; ``pos`` is unused.
    """
    if x == 0:
        # Avoid labelling the origin as "$0M".
        return '$0'
    if x >= 1e12:
        return f'${x*1e-12:.1f}T'
    if x >= 1e9:
        return f'${x*1e-9:.0f}B'
    return f'${x*1e-6:.0f}M'


def _load_employment_and_wages(conn):
    """Load per-occupation employment and mean wage in wide format.

    The `occupation_level_metadata` table stores one row per
    (occupation, item); this pivots it so each occupation has one row.

    Args:
        conn: An open database connection accepted by ``pd.read_sql_query``.

    Returns:
        A DataFrame with columns ``OCC_CODE``, ``TOT_EMP`` and ``A_MEAN``,
        or None if the query/pivot fails or a required item is absent.
    """
    query = "SELECT onetsoc_code, item, response FROM occupation_level_metadata WHERE item IN ('Employment', 'Annual Mean Wage')"
    try:
        df_meta = pd.read_sql_query(query, conn)
        # pivot() raises ValueError when an (onetsoc_code, item) pair is
        # duplicated, so that is handled here together with query errors.
        df_wide = df_meta.pivot(index='onetsoc_code', columns='item', values='response').reset_index()
    except (pd.io.sql.DatabaseError, KeyError, ValueError) as e:
        logging.error("Failed to query or pivot occupation metadata: %s", e, exc_info=True)
        return None
    logging.info("Pivoted occupation metadata. Columns are: %s", df_wide.columns.tolist())

    # rename() silently ignores absent columns, so verify both items exist
    # instead of failing later with an opaque KeyError.
    missing = [item for item in ('Employment', 'Annual Mean Wage') if item not in df_wide.columns]
    if missing:
        logging.error("Occupation metadata is missing required item(s): %s", missing)
        return None

    # Rename for consistency with the original notebook's code.
    return df_wide.rename(columns={
        'onetsoc_code': 'OCC_CODE',
        'Employment': 'TOT_EMP',
        'Annual Mean Wage': 'A_MEAN',
    })


def _wage_bill_by_major_group(df_oesm):
    """Aggregate the wage bill (employment * mean wage) per major group.

    Args:
        df_oesm: DataFrame with ``OCC_CODE``, ``TOT_EMP`` and ``A_MEAN``
            columns. It is not modified; a copy is used internally.

    Returns:
        A DataFrame with columns ``onetsoc_major``, ``wage_bill`` and
        ``OCC_TITLE_MAJOR``, sorted by ``wage_bill`` in descending order.
        Groups without a label in ``OCCUPATION_MAJOR_CODES`` are dropped.
    """
    df = df_oesm.copy()
    # The major group is the first two characters of the occupation code.
    df['onetsoc_major'] = df['OCC_CODE'].str[:2]
    # Non-numeric values become NaN and are removed just below.
    df['TOT_EMP'] = pd.to_numeric(df['TOT_EMP'], errors='coerce')
    df['A_MEAN'] = pd.to_numeric(df['A_MEAN'], errors='coerce')
    df.dropna(subset=['TOT_EMP', 'A_MEAN', 'onetsoc_major'], inplace=True)

    df['wage_bill'] = df['TOT_EMP'] * df['A_MEAN']
    totals = df.groupby('onetsoc_major')['wage_bill'].sum().reset_index()

    totals['OCC_TITLE_MAJOR'] = totals['onetsoc_major'].map(OCCUPATION_MAJOR_CODES)
    # Drop groups whose code has no entry in OCCUPATION_MAJOR_CODES.
    totals.dropna(subset=['OCC_TITLE_MAJOR'], inplace=True)
    # Largest wage bill first for a more informative plot.
    return totals.sort_values('wage_bill', ascending=False)


def generate(processed_df: pd.DataFrame):
    """
    Generates a bar plot of the total wage bill per major occupation group.

    This corresponds to the first 'cell11' from the original analysis notebook.
    It calculates the total wage bill (Total Employment * Annual Mean Wage) for
    each occupation and aggregates it by major occupation group. This generator
    loads its data directly from the database via ``get_db_connection``.

    Args:
        processed_df (pd.DataFrame): The preprocessed data (not used in this
            generator, but required by the function signature).

    Returns:
        Path: The path to the generated image file inside the system temporary
        directory, or None on failure or when there is no data to plot. All
        exceptions are logged and swallowed; none propagate to the caller.
    """
    logging.info("Generating plot of total wage bill by occupation...")
    conn = None
    fig = None
    try:
        # --- Data Loading ---
        # This generator needs data that is not in the main preprocessed_df.
        conn = get_db_connection()
        if conn is None:
            raise ConnectionError("Could not get database connection.")

        df_oesm = _load_employment_and_wages(conn)
        if df_oesm is None:
            return None

        # --- Data Preparation ---
        df_wage_bill_major = _wage_bill_by_major_group(df_oesm)
        if df_wage_bill_major.empty:
            logging.warning("No data available to generate the wage bill plot.")
            return None

        # --- Plotting ---
        # Keep a handle on the figure so cleanup closes exactly this figure.
        fig, ax = plt.subplots(figsize=(12, 10))
        # Passing `palette` without `hue` is deprecated in newer seaborn, so
        # the y variable is also used as hue. dodge=False keeps one full-width
        # bar per category.
        sns.barplot(
            x='wage_bill',
            y='OCC_TITLE_MAJOR',
            hue='OCC_TITLE_MAJOR',
            data=df_wage_bill_major,
            palette="viridis",
            orient='h',
            dodge=False,
            ax=ax,
        )
        # Some seaborn versions add a (redundant) legend when hue is set.
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()

        ax.set_title('Total Wage Bill per Major Occupation Group', fontsize=16, pad=15)
        ax.set_xlabel('Total Wage Bill (in USD)', fontsize=12)
        ax.set_ylabel('Major Occupation Group', fontsize=12)
        ax.grid(axis='x', linestyle='--', alpha=0.7)
        # Format the x-axis to be more readable (e.g., "$2.0T" for trillions).
        ax.xaxis.set_major_formatter(mticker.FuncFormatter(_format_usd))
        fig.tight_layout()

        # --- File Saving ---
        # NOTE: the file name is fixed, so each run overwrites the previous plot.
        temp_path = Path(tempfile.gettempdir()) / "wage_bill_by_occupation.png"
        fig.savefig(temp_path, dpi=300)
        logging.info("Successfully saved plot to temporary file: %s", temp_path)
        return temp_path
    except Exception as e:
        # Top-level boundary for this generator: log with traceback and
        # signal failure through the None return value.
        logging.error("An error occurred while generating the wage bill plot: %s", e, exc_info=True)
        return None
    finally:
        # Only close the figure created here; a bare plt.close() would close
        # whichever figure happens to be current, even one owned by other code.
        if fig is not None:
            plt.close(fig)
        if conn is not None:
            conn.close()