"""Shared helpers: unit conversion, batched LLM enrichment, and plot styling."""

import math
import subprocess
import tempfile
import time
from typing import Any, Dict, List

import litellm
import matplotlib as mpl
import matplotlib.colors as mcolors
import seaborn as sns
from tqdm import tqdm

from .logger import logger

# Two-digit occupation major-group codes mapped to display labels.
# NOTE(review): these look like SOC major groups -- confirm against the data source.
OCCUPATION_MAJOR_CODES = {
    '11': 'Management',
    '13': 'Business & Financial',
    '15': 'Computer & Mathematical',
    '17': 'Architecture & Engineering',
    '19': 'Life, Physical, & Social Science',
    '21': 'Community & Social Service',
    '23': 'Legal',
    '25': 'Education, Training, & Library',
    '27': 'Arts, Design, & Media',
    '29': 'Healthcare Practitioners',
    '31': 'Healthcare Support',
    '33': 'Protective Service',
    '35': 'Food Preparation & Serving',
    '37': 'Building & Grounds Maintenance',
    '39': 'Personal Care & Service',
    '41': 'Sales & Related',
    '43': 'Office & Admin Support',
    '45': 'Farming, Fishing, & Forestry',
    '47': 'Construction & Extraction',
    '49': 'Installation, Maintenance, & Repair',
    '51': 'Production',
    '53': 'Transportation & Material Moving',
    '55': 'Military Specific',
}

# Color scales keyed by shade (as strings): '50' is the lightest, '950' the darkest.
GRAY = {
    '50': '#f8fafc', '100': '#f1f5f9', '200': '#e2e8f0',
    '300': '#cbd5e1', '400': '#94a3b8', '500': '#64748b',
    '600': '#475569', '700': '#334155', '800': '#1e293b',
    '900': '#0f172a', '950': '#020617',
}

LIME = {
    '50': '#f7fee7', '100': '#ecfcca', '200': '#d8f999',
    '300': '#bbf451', '400': '#9ae600', '500': '#83cd00',
    '600': '#64a400', '700': '#497d00', '800': '#3c6300',
    '900': '#35530e', '950': '#192e03',
}

# Minutes contained in one of each supported unit. Months, trimesters,
# semesters and years use fixed 30/90/180/365-day approximations.
_MINUTES_PER_UNIT = {
    "minute": 1,
    "hour": 60,
    "day": 60 * 24,
    "week": 60 * 24 * 7,
    "month": 60 * 24 * 30,
    "trimester": 60 * 24 * 90,
    "semester": 60 * 24 * 180,
    "year": 60 * 24 * 365,
}

# Machine-specific browser invocation used by pretty_display's opt-in preview.
_BROWSER_COMMAND = [
    "/home/felix/.nix-profile/bin/firefox-devedition",
    "-p",
    "Work (YouthAI)",
]


def convert_to_minutes(qty, unit):
    """Converts a quantity in a given unit to minutes.

    Args:
        qty: Amount expressed in ``unit``.
        unit: One of "minute", "hour", "day", "week", "month", "trimester",
            "semester" or "year".

    Returns:
        ``qty`` multiplied by the number of minutes in one ``unit``.

    Raises:
        KeyError: If ``unit`` is not one of the supported names.
    """
    return qty * _MINUTES_PER_UNIT[unit]


def pretty_display(df, *, open_in_browser: bool = False):
    """Print ``df`` and, optionally, preview it as HTML in a browser.

    By default only ``print(df)`` runs, which is what this function has
    always done for callers. The HTML preview used to sit behind an
    unconditional ``return`` and was unreachable; it is now an explicit
    opt-in.

    Args:
        df: Object to display. The preview path calls
            ``df.to_html(index=False)`` on it.
        open_in_browser: When true, also write the HTML rendering to a
            temporary file, launch the configured browser on it, and block
            until the user presses Enter.
    """
    print(df)
    if not open_in_browser:
        return

    html_output = df.to_html(index=False)

    # Create a temporary HTML file. It is deleted when the ``with`` block
    # exits, so the browser launch and the prompt must stay inside it.
    with tempfile.NamedTemporaryFile(mode='w', suffix=".html", encoding="utf-8") as temp_file:
        temp_file.write(html_output)
        # Push the buffered content to disk before another process reads it.
        temp_file.flush()
        temp_file_path = temp_file.name

        subprocess.run(
            [*_BROWSER_COMMAND, temp_file_path],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
        input("Press Enter to continue after reviewing the HTML output...")


def enrich(
    model: str,
    rpm: int,  # Requests per minute
    messages_to_process: List[List[Dict[str, str]]],
    schema: Dict[str, Any],
    chunk_size: int = 100,
) -> List[Any]:
    """Run a batch of chat completions in chunks while pacing to an RPM budget.

    Each chunk is sent through ``litellm.batch_completion`` with a
    ``json_schema`` response format. After a chunk completes, the function
    sleeps long enough that the chunk took at least
    ``len(chunk) * 60 / rpm`` seconds overall.

    Args:
        model: Model identifier passed through to litellm.
        rpm: Requests-per-minute budget. A value of zero or less disables
            pacing.
        messages_to_process: One message list per request.
        schema: Passed as the ``json_schema`` entry of ``response_format``.
        chunk_size: Maximum number of requests sent per batch call.

    Returns:
        A list of results in input order. An entry is ``None`` when its
        request came back as an exception object, and a whole chunk is
        filled with ``None`` when the batch call itself raised.
        (Assumes batch_completion yields one entry per message -- confirm.)

    Raises:
        ValueError: If ``chunk_size`` is not positive and there is at least
            one message to process.
    """
    all_results: List[Any] = []
    num_messages = len(messages_to_process)
    if num_messages == 0:
        return all_results

    # A non-positive chunk size used to either crash with ZeroDivisionError
    # or silently return no results at all; fail loudly instead.
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    num_chunks = math.ceil(num_messages / chunk_size)
    logger.info(f"Starting enrichment for {num_messages} messages, in {num_chunks} chunks of up to {chunk_size} each.")

    # Calculate the time that should be allocated per request to respect the RPM limit.
    time_per_request = 60.0 / rpm if rpm > 0 else 0

    chunk_starts = range(0, num_messages, chunk_size)
    for chunk_number, start_index in enumerate(
        tqdm(chunk_starts, desc="Enriching data in chunks"), start=1
    ):
        # Monotonic clock: elapsed time must not be affected by wall-clock changes.
        chunk_start_time = time.monotonic()
        message_chunk = messages_to_process[start_index:start_index + chunk_size]

        try:
            # Send requests for the entire chunk in a batch for better performance.
            responses = litellm.batch_completion(
                model=model,
                messages=message_chunk,
                response_format={
                    "type": "json_schema",
                    "json_schema": schema,
                },
            )
            # batch_completion returns the response or an exception object for each message.
            # We'll replace exceptions with None as expected by the calling functions.
            for response in responses:
                if isinstance(response, Exception):
                    logger.error(f"API call within batch failed: {response}")
                    all_results.append(None)
                else:
                    all_results.append(response)
        except Exception as e:
            # This catches catastrophic failures in batch_completion itself (e.g., auth)
            logger.error(f"litellm.batch_completion call failed for chunk {chunk_number}/{num_chunks}: {e}")
            all_results.extend([None] * len(message_chunk))

        elapsed_time = time.monotonic() - chunk_start_time

        # To enforce the rate limit, we calculate how long the chunk *should* have taken
        # and sleep for the remainder of that time. This also runs after the last chunk
        # so that back-to-back calls stay within the budget.
        if time_per_request > 0:
            expected_duration_for_chunk = len(message_chunk) * time_per_request
            if elapsed_time < expected_duration_for_chunk:
                sleep_duration = expected_duration_for_chunk - elapsed_time
                logger.debug(f"Chunk processed in {elapsed_time:.2f}s. Sleeping for {sleep_duration:.2f}s to respect RPM.")
                time.sleep(sleep_duration)

    return all_results


def get_contrasting_text_color(bg_color_hex_or_rgba):
    """Pick 'black' or 'white' text for the given background color.

    Args:
        bg_color_hex_or_rgba: Either a color string understood by
            ``matplotlib.colors.to_rgba`` or an iterable of at least three
            channel values (RGB or RGBA). Non-string inputs are used as
            given; presumably floats in the 0-1 range -- confirm with callers.

    Returns:
        'black' when the weighted channel sum exceeds 0.55, else 'white'.
    """
    if isinstance(bg_color_hex_or_rgba, str):
        rgba = mcolors.to_rgba(bg_color_hex_or_rgba)
    else:
        rgba = bg_color_hex_or_rgba

    # Alpha (and anything after it) is irrelevant here; this also lets plain
    # RGB triples through instead of failing to unpack.
    r, g, b, *_ = rgba
    luminance = 0.2126 * r + 0.7152 * g + 0.0722 * b
    return 'black' if luminance > 0.55 else 'white'


def style_plot():
    """
    Applies a consistent and professional style to all plots.
    This function sets matplotlib's rcParams for a global effect.

    The seaborn style preset is applied first because it rewrites many
    rcParams (figure/axes face colors, text and tick colors, spines, fonts).
    Applying the custom values afterwards ensures they are the ones in effect.
    """
    # Seaborn base style; the custom rcParams below deliberately override it.
    sns.set_style("whitegrid")

    mpl.rcParams.update({
        'figure.facecolor': GRAY['50'],
        'figure.edgecolor': 'none',
        'figure.figsize': (12, 8),
        'figure.dpi': 150,

        'axes.facecolor': GRAY['50'],
        'axes.edgecolor': GRAY['300'],
        'axes.grid': True,
        'axes.labelcolor': GRAY['800'],
        'axes.titlecolor': GRAY['900'],
        'axes.titlesize': 18,
        'axes.titleweight': 'bold',
        'axes.titlepad': 20,
        'axes.labelsize': 14,
        'axes.labelweight': 'semibold',
        'axes.labelpad': 10,
        'axes.spines.top': False,
        'axes.spines.right': False,
        'axes.spines.left': True,
        'axes.spines.bottom': True,

        'text.color': GRAY['700'],

        'xtick.color': GRAY['600'],
        'ytick.color': GRAY['600'],
        'xtick.labelsize': 12,
        'ytick.labelsize': 12,
        'xtick.major.size': 0,
        'ytick.major.size': 0,
        'xtick.minor.size': 0,
        'ytick.minor.size': 0,
        'xtick.major.pad': 8,
        'ytick.major.pad': 8,

        'grid.color': GRAY['200'],
        'grid.linestyle': '--',
        'grid.linewidth': 1,

        'legend.frameon': False,
        'legend.fontsize': 12,
        'legend.title_fontsize': 14,
        'legend.facecolor': 'inherit',

        'font.family': 'sans-serif',
        'font.sans-serif': ['Inter'],
        'font.weight': 'normal',

        'lines.linewidth': 2,
        'lines.markersize': 6,
    })

    # Use mid-range shades of LIME as the primary color palette, darkest first.
    # The two lightest shades and the four darkest ones are left out, which
    # keeps 600, 500, 400, 300 and 200.
    excluded_shades = {'50', '100', '700', '800', '900', '950'}
    lime_palette = [
        LIME[shade]
        for shade in sorted(LIME, key=int, reverse=True)
        if shade not in excluded_shades
    ]
    sns.set_palette(lime_palette)