sprint-econtai/pipeline/utils.py
Félix Dorn 65dc648797 wip
2025-07-15 00:34:54 +02:00

222 lines
7.5 KiB
Python

import subprocess
import matplotlib.colors as mcolors
import matplotlib as mpl
import seaborn as sns
import tempfile
import litellm
import time
import math
from tqdm import tqdm
from typing import Any, List, Dict
from .logger import logger
# Human-readable label for each two-digit occupation major-group code.
# Keys are strings, not ints, so they can be matched against code prefixes.
# NOTE(review): these look like SOC major-group codes (as used by O*NET) —
# confirm against the data source before relying on that.
OCCUPATION_MAJOR_CODES = {
    '11': 'Management',
    '13': 'Business & Financial',
    '15': 'Computer & Mathematical',
    '17': 'Architecture & Engineering',
    '19': 'Life, Physical, & Social Science',

    '21': 'Community & Social Service',
    '23': 'Legal',
    '25': 'Education, Training, & Library',
    '27': 'Arts, Design, & Media',
    '29': 'Healthcare Practitioners',

    '31': 'Healthcare Support',
    '33': 'Protective Service',
    '35': 'Food Preparation & Serving',
    '37': 'Building & Grounds Maintenance',
    '39': 'Personal Care & Service',

    '41': 'Sales & Related',
    '43': 'Office & Admin Support',
    '45': 'Farming, Fishing, & Forestry',
    '47': 'Construction & Extraction',
    '49': 'Installation, Maintenance, & Repair',

    '51': 'Production',
    '53': 'Transportation & Material Moving',
    '55': 'Military Specific',
}
# Neutral gray ramp used for backgrounds, axes, grid lines and text.
# Keys are shade numbers as strings ('50' is the lightest, '950' the darkest);
# values are hex colour strings.
GRAY = {
    '50': '#f8fafc',
    '100': '#f1f5f9',
    '200': '#e2e8f0',
    '300': '#cbd5e1',
    '400': '#94a3b8',
    '500': '#64748b',
    '600': '#475569',
    '700': '#334155',
    '800': '#1e293b',
    '900': '#0f172a',
    '950': '#020617',
}
# Lime/green accent ramp; a subset of these shades forms the plotting palette.
# Keys are shade numbers as strings ('50' is the lightest, '950' the darkest);
# values are hex colour strings.
LIME = {
    '50': '#f7fee7',
    '100': '#ecfcca',
    '200': '#d8f999',
    '300': '#bbf451',
    '400': '#9ae600',
    '500': '#83cd00',
    '600': '#64a400',
    '700': '#497d00',
    '800': '#3c6300',
    '900': '#35530e',
    '950': '#192e03',
}
# Number of minutes in one of each supported unit. Built once at import time
# instead of on every call. Months, trimesters, semesters and years use fixed
# 30/90/180/365-day approximations.
_MINUTES_PER_UNIT = {
    "minute": 1,
    "hour": 60,
    "day": 60 * 24,
    "week": 60 * 24 * 7,
    "month": 60 * 24 * 30,
    "trimester": 60 * 24 * 90,
    "semester": 60 * 24 * 180,
    "year": 60 * 24 * 365,
}


def convert_to_minutes(qty, unit):
    """Converts a quantity in a given unit to minutes.

    Args:
        qty: Numeric amount expressed in ``unit``.
        unit: One of "minute", "hour", "day", "week", "month", "trimester",
            "semester" or "year". Matching ignores case, surrounding
            whitespace and a single trailing plural "s" (so "Hours" works).

    Returns:
        ``qty`` multiplied by the number of minutes in one ``unit``.

    Raises:
        KeyError: If ``unit`` is not a supported unit.
    """
    key = unit
    if isinstance(unit, str):
        key = unit.strip().lower()
        # No canonical unit name ends in "s", so dropping one trailing "s"
        # can only turn a plural into its singular form.
        if key not in _MINUTES_PER_UNIT and key.endswith("s"):
            key = key[:-1]

    try:
        minutes_per_unit = _MINUTES_PER_UNIT[key]
    except KeyError:
        # Same exception type as a plain dict lookup, with a clearer message.
        raise KeyError(
            f"Unknown time unit {unit!r}; expected one of {sorted(_MINUTES_PER_UNIT)}"
        ) from None
    return qty * minutes_per_unit
def pretty_display(df, *, in_browser=False):
    """Show ``df`` to the user, on stdout or as an HTML table in a browser.

    Args:
        df: Object to display. With ``in_browser=True`` it must provide
            ``to_html(index=False)`` (e.g. a pandas DataFrame).
        in_browser: When False (the default) the object is printed and the
            function returns immediately. When True the object is rendered
            to a temporary HTML file, opened in the system's default web
            browser, and the function blocks until the user presses Enter.

    Returns:
        None.
    """
    if not in_browser:
        print(df)
        return

    # Imported locally: only the opt-in browser path needs these modules.
    import os
    import webbrowser
    from pathlib import Path

    html_output = df.to_html(index=False)

    # delete=False so the file survives the `with` block: leaving the block
    # flushes and closes it, and only then is it handed to the browser.
    with tempfile.NamedTemporaryFile(
        mode='w', suffix=".html", encoding="utf-8", delete=False
    ) as temp_file:
        temp_file.write(html_output)
        temp_file_path = temp_file.name

    try:
        # Default browser instead of a machine-specific executable path.
        webbrowser.open(Path(temp_file_path).as_uri())
        input("Press Enter to continue after reviewing the HTML output...")
    finally:
        # Clean up the temp file once the user is done reviewing it.
        os.unlink(temp_file_path)
def enrich(
    model: str,
    rpm: int,  # Requests per minute
    messages_to_process: List[List[Dict[str, str]]],
    schema: Dict[str, Any],
    chunk_size: int = 100,
):
    """Run a batch of chat completions through litellm under an RPM budget.

    The conversations are sent in chunks via ``litellm.batch_completion``,
    each requesting a JSON-schema-constrained response. After every chunk the
    function sleeps for whatever is left of the time that many requests are
    allowed to take at ``rpm`` requests per minute.

    Args:
        model: Model identifier passed through to litellm.
        rpm: Maximum requests per minute. A value <= 0 disables throttling.
        messages_to_process: One message list (conversation) per request.
        schema: JSON schema passed as ``response_format["json_schema"]``.
        chunk_size: Maximum number of requests sent per batch. When ``rpm``
            is positive the effective chunk size is capped at ``rpm``, so a
            single burst never exceeds the per-minute budget.

    Returns:
        A list with one entry per input conversation, in input order: the
        litellm response, or ``None`` if that request (or its whole batch)
        failed.

    Raises:
        ValueError: If ``chunk_size`` is not positive and there is at least
            one message to process.
    """
    all_results = []
    num_messages = len(messages_to_process)
    if num_messages == 0:
        return all_results

    # A zero chunk size used to crash with ZeroDivisionError and a negative
    # one silently produced no results at all; reject both explicitly.
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size!r}")

    if rpm > 0:
        # Time that should be allocated per request to respect the RPM limit.
        time_per_request = 60.0 / rpm
        # All requests of a chunk are fired together, so a chunk larger than
        # the per-minute budget would break the limit before any sleep occurs.
        chunk_size = min(chunk_size, max(1, int(rpm)))
    else:
        time_per_request = 0

    num_chunks = math.ceil(num_messages / chunk_size)
    logger.info(f"Starting enrichment for {num_messages} messages, in {num_chunks} chunks of up to {chunk_size} each.")

    for i in tqdm(range(num_chunks), desc="Enriching data in chunks"):
        # Monotonic clock: elapsed time must not be affected by system clock
        # adjustments.
        chunk_start_time = time.monotonic()
        start_index = i * chunk_size
        message_chunk = messages_to_process[start_index:start_index + chunk_size]

        try:
            # Send requests for the entire chunk in a batch for better performance.
            responses = litellm.batch_completion(
                model=model,
                messages=message_chunk,
                response_format={
                    "type": "json_schema",
                    "json_schema": schema,
                },
            )
        except Exception as e:
            # This catches catastrophic failures in batch_completion itself (e.g., auth).
            # Pad with None so results stay aligned with the inputs.
            logger.error(f"litellm.batch_completion call failed for chunk {i+1}/{num_chunks}: {e}")
            all_results.extend([None] * len(message_chunk))
        else:
            # batch_completion returns the response or an exception object for each message.
            # Exceptions are replaced with None as expected by the calling functions.
            for response in responses:
                if isinstance(response, Exception):
                    logger.error(f"API call within batch failed: {response}")
                    all_results.append(None)
                else:
                    all_results.append(response)

        elapsed_time = time.monotonic() - chunk_start_time

        # To enforce the rate limit, work out how long the chunk *should* have
        # taken and sleep for the remainder of that time. This is applied
        # after every chunk, including the last one.
        if time_per_request > 0:
            expected_duration_for_chunk = len(message_chunk) * time_per_request
            if elapsed_time < expected_duration_for_chunk:
                sleep_duration = expected_duration_for_chunk - elapsed_time
                logger.debug(f"Chunk processed in {elapsed_time:.2f}s. Sleeping for {sleep_duration:.2f}s to respect RPM.")
                time.sleep(sleep_duration)

    return all_results
def get_contrasting_text_color(bg_color_hex_or_rgba):
    """Pick 'black' or 'white' text for readability on a background colour.

    Args:
        bg_color_hex_or_rgba: Either a colour string understood by
            ``matplotlib.colors.to_rgba`` (e.g. a hex code), or an iterable
            of 3 (RGB) or 4 (RGBA) channel values. Channels are expected in
            the 0-1 range, as returned by ``to_rgba``.

    Returns:
        'black' if the weighted brightness of the background is above 0.55,
        otherwise 'white'. The alpha channel, if present, is ignored.

    Raises:
        ValueError: If a non-string argument does not have 3 or 4 channels.
    """
    if isinstance(bg_color_hex_or_rgba, str):
        rgba = mcolors.to_rgba(bg_color_hex_or_rgba)
    else:
        rgba = bg_color_hex_or_rgba

    channels = tuple(rgba)
    if len(channels) not in (3, 4):
        raise ValueError(
            f"Expected an RGB or RGBA colour, got {len(channels)} channel(s): {bg_color_hex_or_rgba!r}"
        )
    # Only the colour channels matter, so plain RGB triples are accepted too.
    r, g, b = channels[:3]

    # Weighted sum of the channels (green contributes most to perceived
    # brightness); 0.55 is the cut-off between "light" and "dark".
    luminance = 0.2126 * r + 0.7152 * g + 0.0722 * b
    return 'black' if luminance > 0.55 else 'white'
def style_plot():
    """
    Applies a consistent and professional style to all plots.
    This function sets matplotlib's rcParams for a global effect.

    The seaborn style preset is applied first and the project's own rcParams
    second, so that the explicit values below take precedence over the
    preset wherever both define the same key.
    """
    # Seaborn specific styles.
    # set_style() writes a whole family of rcParams (face colors, text and
    # tick colors, spines, fonts, ...). It must run BEFORE the explicit
    # rcParams update below, otherwise it overwrites those custom values.
    sns.set_style("whitegrid", {
        'axes.edgecolor': GRAY['300'],
        'grid.color': GRAY['200'],
        'grid.linestyle': '--',
    })

    mpl.rcParams.update({
        # Figure
        'figure.facecolor': GRAY['50'],
        'figure.edgecolor': 'none',
        'figure.figsize': (12, 8),
        'figure.dpi': 150,

        # Axes: background, titles and labels
        'axes.facecolor': GRAY['50'],
        'axes.edgecolor': GRAY['300'],
        'axes.grid': True,
        'axes.labelcolor': GRAY['800'],
        'axes.titlecolor': GRAY['900'],
        'axes.titlesize': 18,
        'axes.titleweight': 'bold',
        'axes.titlepad': 20,
        'axes.labelsize': 14,
        'axes.labelweight': 'semibold',
        'axes.labelpad': 10,

        # Spines: keep only the left and bottom axis lines
        'axes.spines.top': False,
        'axes.spines.right': False,
        'axes.spines.left': True,
        'axes.spines.bottom': True,

        # Text and ticks (tick marks hidden, labels kept)
        'text.color': GRAY['700'],
        'xtick.color': GRAY['600'],
        'ytick.color': GRAY['600'],
        'xtick.labelsize': 12,
        'ytick.labelsize': 12,
        'xtick.major.size': 0,
        'ytick.major.size': 0,
        'xtick.minor.size': 0,
        'ytick.minor.size': 0,
        'xtick.major.pad': 8,
        'ytick.major.pad': 8,

        # Grid
        'grid.color': GRAY['200'],
        'grid.linestyle': '--',
        'grid.linewidth': 1,

        # Legend
        'legend.frameon': False,
        'legend.fontsize': 12,
        'legend.title_fontsize': 14,
        'legend.facecolor': 'inherit',

        # Fonts
        'font.family': 'sans-serif',
        'font.sans-serif': ['Inter'],
        'font.weight': 'normal',

        # Lines
        'lines.linewidth': 2,
        'lines.markersize': 6,
    })

    # Use mid-range shades of LIME as the primary color palette, darkest first.
    # The lightest shades (50, 100) and the darkest ones (700-950) are left out.
    lime_palette = [LIME[shade] for shade in ('600', '500', '400', '300', '200')]
    sns.set_palette(lime_palette)