import logging
import tempfile
from pathlib import Path
from typing import Optional

import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import pandas as pd
import seaborn as sns

# Assuming data.py is in the same package and provides this function
from ..data import get_db_connection

# This mapping helps translate the O*NET 2-digit major group codes
# into human-readable labels for the plot's y-axis.
OCCUPATION_MAJOR_CODES = {
    '11': 'Management',
    '13': 'Business & Financial',
    '15': 'Computer & Mathematical',
    '17': 'Architecture & Engineering',
    '19': 'Life, Physical, & Social Science',
    '21': 'Community & Social Service',
    '23': 'Legal',
    '25': 'Education, Training, & Library',
    '27': 'Arts, Design, & Media',
    '29': 'Healthcare Practitioners',
    '31': 'Healthcare Support',
    '33': 'Protective Service',
    '35': 'Food Preparation & Serving',
    '37': 'Building & Grounds Maintenance',
    '39': 'Personal Care & Service',
    '41': 'Sales & Related',
    '43': 'Office & Admin Support',
    '45': 'Farming, Fishing, & Forestry',
    '47': 'Construction & Extraction',
    '49': 'Installation, Maintenance, & Repair',
    '51': 'Production',
    '53': 'Transportation & Material Moving',
    '55': 'Military Specific',
}


def _format_wage_tick(x, pos):
    """Format an x-axis tick value as a compact dollar amount.

    Values of at least 1e12 are shown in trillions with one decimal
    (e.g. "$2.0T"), values of at least 1e9 in whole billions, and anything
    smaller in whole millions. The axis origin is rendered as "$0" rather
    than "$0M".

    Args:
        x: The tick value in USD.
        pos: The tick position supplied by matplotlib (unused).

    Returns:
        str: The formatted tick label.
    """
    if x == 0:
        return '$0'
    if x >= 1e12:
        return f'${x*1e-12:.1f}T'
    if x >= 1e9:
        return f'${x*1e-9:.0f}B'
    return f'${x*1e-6:.0f}M'


def generate(processed_df: pd.DataFrame) -> Optional[Path]:
    """
    Generates a bar plot of the total wage bill per major occupation group.

    This corresponds to the first 'cell11' from the original analysis notebook.
    It calculates the total wage bill (Total Employment * Annual Mean Wage)
    for each occupation and aggregates it by major occupation group.

    This generator loads its data directly from the database connection
    returned by ``get_db_connection`` (table ``occupation_level_metadata``).

    Args:
        processed_df (pd.DataFrame): The preprocessed data (not used in this
                                     generator, but required by the function
                                     signature).

    Returns:
        Path: The path to the generated temporary image file, or None on
        failure. Failures (no connection, query/pivot errors, no plottable
        data, or any other exception) are logged and never propagated.
    """
    logging.info("Generating plot of total wage bill by occupation...")
    conn = None
    # Handle to the figure created by this call, so that cleanup only ever
    # closes our own figure and never one owned by other code.
    fig = None
    try:
        # --- Data Loading ---
        # This generator needs specific data that is not in the main
        # preprocessed_df. It loads occupational employment and wage data
        # directly from the database.
        conn = get_db_connection()
        if conn is None:
            raise ConnectionError("Could not get database connection.")

        # This data is stored in a long format in the
        # `occupation_level_metadata` table. We need to query this table and
        # pivot it to get employment and wage columns.
        query = "SELECT onetsoc_code, item, response FROM occupation_level_metadata WHERE item IN ('Employment', 'Annual Mean Wage')"

        try:
            df_meta = pd.read_sql_query(query, conn)
            # Pivot the table to create 'Employment' and 'Annual Mean Wage' columns
            df_oesm = df_meta.pivot(index='onetsoc_code', columns='item', values='response').reset_index()
            logging.info("Pivoted occupation metadata. Columns are: %s", df_oesm.columns.tolist())

            # Rename for consistency with the original notebook's code
            df_oesm.rename(columns={
                'onetsoc_code': 'OCC_CODE',
                'Employment': 'TOT_EMP',
                'Annual Mean Wage': 'A_MEAN'
            }, inplace=True)
        except (pd.io.sql.DatabaseError, KeyError) as e:
            logging.error("Failed to query or pivot occupation metadata: %s", e, exc_info=True)
            return None

        # --- Data Preparation ---
        # Create a 'major group' code from the first two digits of the SOC code
        df_oesm['onetsoc_major'] = df_oesm['OCC_CODE'].str[:2]

        # Ensure wage and employment columns are numeric, coercing errors to NaN
        df_oesm['TOT_EMP'] = pd.to_numeric(df_oesm['TOT_EMP'], errors='coerce')
        df_oesm['A_MEAN'] = pd.to_numeric(df_oesm['A_MEAN'], errors='coerce')

        # Drop rows with missing data in critical columns
        df_oesm.dropna(subset=['TOT_EMP', 'A_MEAN', 'onetsoc_major'], inplace=True)

        # Calculate the wage bill for each occupation
        df_oesm['wage_bill'] = df_oesm['TOT_EMP'] * df_oesm['A_MEAN']

        # Aggregate the wage bill by major occupation group
        df_wage_bill_major = df_oesm.groupby('onetsoc_major')['wage_bill'].sum().reset_index()

        # Map the major codes to readable titles for plotting
        df_wage_bill_major['OCC_TITLE_MAJOR'] = df_wage_bill_major['onetsoc_major'].map(OCCUPATION_MAJOR_CODES)
        # Drop groups whose code has no entry in OCCUPATION_MAJOR_CODES
        df_wage_bill_major.dropna(subset=['OCC_TITLE_MAJOR'], inplace=True)

        # Sort by wage bill for a more informative plot
        df_wage_bill_major = df_wage_bill_major.sort_values('wage_bill', ascending=False)

        if df_wage_bill_major.empty:
            logging.warning("No data available to generate the wage bill plot.")
            return None

        # --- Plotting ---
        fig = plt.figure(figsize=(12, 10))
        ax = sns.barplot(x='wage_bill', y='OCC_TITLE_MAJOR', data=df_wage_bill_major, palette="viridis", orient='h')

        ax.set_title('Total Wage Bill per Major Occupation Group', fontsize=16, pad=15)
        ax.set_xlabel('Total Wage Bill (in USD)', fontsize=12)
        ax.set_ylabel('Major Occupation Group', fontsize=12)
        ax.grid(axis='x', linestyle='--', alpha=0.7)

        # Format the x-axis to be more readable (e.g., "$2.0T" for trillions)
        ax.xaxis.set_major_formatter(mticker.FuncFormatter(_format_wage_tick))

        fig.tight_layout()

        # --- File Saving ---
        temp_dir = tempfile.gettempdir()
        temp_path = Path(temp_dir) / "wage_bill_by_occupation.png"
        fig.savefig(temp_path, dpi=300)

        logging.info("Successfully saved plot to temporary file: %s", temp_path)
        return temp_path

    except Exception as e:
        # Top-level boundary: log with traceback and signal failure via None.
        logging.error("An error occurred while generating the wage bill plot: %s", e, exc_info=True)
        return None
    finally:
        # Close only the figure this call created; a bare plt.close() would
        # close whatever figure happens to be current, even if it is not ours.
        if fig is not None:
            plt.close(fig)
        if conn is not None:
            conn.close()