222 lines
7.5 KiB
Python
222 lines
7.5 KiB
Python
import subprocess
|
|
import matplotlib.colors as mcolors
|
|
import matplotlib as mpl
|
|
import seaborn as sns
|
|
import tempfile
|
|
import litellm
|
|
import time
|
|
import math
|
|
from tqdm import tqdm
|
|
from typing import Any, List, Dict
|
|
from .logger import logger
|
|
|
|
# Human-readable labels keyed by two-digit occupation major-group code.
# Keys are strings, not ints.
# NOTE(review): the codes and labels look like SOC major groups -- confirm
# against the data source before relying on that.
OCCUPATION_MAJOR_CODES = {
    "11": "Management",
    "13": "Business & Financial",
    "15": "Computer & Mathematical",
    "17": "Architecture & Engineering",
    "19": "Life, Physical, & Social Science",
    "21": "Community & Social Service",
    "23": "Legal",
    "25": "Education, Training, & Library",
    "27": "Arts, Design, & Media",
    "29": "Healthcare Practitioners",
    "31": "Healthcare Support",
    "33": "Protective Service",
    "35": "Food Preparation & Serving",
    "37": "Building & Grounds Maintenance",
    "39": "Personal Care & Service",
    "41": "Sales & Related",
    "43": "Office & Admin Support",
    "45": "Farming, Fishing, & Forestry",
    "47": "Construction & Extraction",
    "49": "Installation, Maintenance, & Repair",
    "51": "Production",
    "53": "Transportation & Material Moving",
    "55": "Military Specific",
}
|
|
|
|
# Gray colour ramp as hex strings, keyed by shade number (as a string).
# '50' is the lightest shade and '950' the darkest.
GRAY = {
    "50": "#f8fafc",
    "100": "#f1f5f9",
    "200": "#e2e8f0",
    "300": "#cbd5e1",
    "400": "#94a3b8",
    "500": "#64748b",
    "600": "#475569",
    "700": "#334155",
    "800": "#1e293b",
    "900": "#0f172a",
    "950": "#020617",
}
|
|
|
|
# Lime-green colour ramp as hex strings, keyed by shade number (as a string).
# '50' is the lightest shade and '950' the darkest.
LIME = {
    "50": "#f7fee7",
    "100": "#ecfcca",
    "200": "#d8f999",
    "300": "#bbf451",
    "400": "#9ae600",
    "500": "#83cd00",
    "600": "#64a400",
    "700": "#497d00",
    "800": "#3c6300",
    "900": "#35530e",
    "950": "#192e03",
}
|
|
|
|
|
|
def convert_to_minutes(qty, unit):
    """Return ``qty`` expressed in minutes, given that it is measured in ``unit``.

    ``unit`` must be one of the singular names in the table below; any other
    value (including plurals such as ``"hours"``) raises ``KeyError``.
    Calendar units use fixed lengths: a month is 30 days, a trimester 90,
    a semester 180 and a year 365.
    """
    minutes_per_day = 60 * 24
    minutes_per_unit = {
        "minute": 1,
        "hour": 60,
        "day": minutes_per_day,
        "week": minutes_per_day * 7,
        "month": minutes_per_day * 30,
        "trimester": minutes_per_day * 90,
        "semester": minutes_per_day * 180,
        "year": minutes_per_day * 365,
    }
    factor = minutes_per_unit[unit]
    return qty * factor
|
|
|
|
|
|
def pretty_display(df):
    """Print ``df`` to standard output and return ``None``.

    Earlier revisions kept an HTML-in-browser preview after an unconditional
    ``return``. That code could never execute and relied on a browser path
    and profile specific to one machine, so it has been dropped; the
    observable behaviour (a plain ``print``) is unchanged.
    """
    print(df)
|
|
|
|
|
|
def enrich(
    model: str,
    rpm: int,  # Requests per minute
    messages_to_process: List[List[Dict[str, str]]],
    schema: Dict[str, Any],
    chunk_size: int = 100,
) -> List[Any]:
    """Run every conversation through ``litellm.batch_completion`` in throttled chunks.

    Args:
        model: Model identifier forwarded to ``litellm.batch_completion``.
        rpm: Target requests per minute. A value ``<= 0`` disables throttling.
        messages_to_process: One message list per request.
        schema: Passed as the ``json_schema`` entry of ``response_format``.
        chunk_size: Maximum number of requests sent per batch call.

    Returns:
        A list with one entry per input conversation, in input order. An
        entry is the response object, or ``None`` when that request (or the
        whole batch call for its chunk) failed. An empty input returns ``[]``.

    Raises:
        ValueError: If there is work to do and ``chunk_size`` is smaller than 1.

    Throttling is applied per chunk: each chunk is sent as one batch and the
    function then sleeps until ``len(chunk) * 60 / rpm`` seconds have passed
    since the chunk started. This happens after every chunk, including the
    last one.
    """
    all_results: List[Any] = []
    num_messages = len(messages_to_process)
    if num_messages == 0:
        return all_results

    # A zero chunk size used to surface as ZeroDivisionError and a negative
    # one silently returned no results at all; fail loudly instead.
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size!r}")

    num_chunks = math.ceil(num_messages / chunk_size)
    logger.info(f"Starting enrichment for {num_messages} messages, in {num_chunks} chunks of up to {chunk_size} each.")

    # Time budget per request needed to respect the RPM limit.
    time_per_request = 60.0 / rpm if rpm > 0 else 0

    chunk_starts = range(0, num_messages, chunk_size)
    for i, start_index in enumerate(tqdm(chunk_starts, desc="Enriching data in chunks")):
        # Monotonic clock: elapsed-time measurement must not be affected by
        # system clock adjustments.
        chunk_start_time = time.monotonic()
        message_chunk = messages_to_process[start_index:start_index + chunk_size]

        try:
            # Send requests for the entire chunk in a batch for better performance.
            responses = litellm.batch_completion(
                model=model,
                messages=message_chunk,
                response_format={
                    "type": "json_schema",
                    "json_schema": schema,
                },
            )

            # batch_completion returns the response or an exception object for
            # each message; exceptions become None for the callers.
            for response in responses:
                if isinstance(response, Exception):
                    logger.error(f"API call within batch failed: {response}")
                    all_results.append(None)
                else:
                    all_results.append(response)

        except Exception as e:
            # Catastrophic failure of batch_completion itself (e.g., auth):
            # keep the output aligned with the input by padding with None.
            logger.error(f"litellm.batch_completion call failed for chunk {i+1}/{num_chunks}: {e}")
            all_results.extend([None] * len(message_chunk))

        elapsed_time = time.monotonic() - chunk_start_time

        # Enforce the rate limit: work out how long the chunk *should* have
        # taken and sleep for whatever is left of that budget.
        if time_per_request > 0:
            expected_duration_for_chunk = len(message_chunk) * time_per_request
            if elapsed_time < expected_duration_for_chunk:
                sleep_duration = expected_duration_for_chunk - elapsed_time
                logger.debug(f"Chunk processed in {elapsed_time:.2f}s. Sleeping for {sleep_duration:.2f}s to respect RPM.")
                time.sleep(sleep_duration)

    return all_results
|
|
|
|
def get_contrasting_text_color(bg_color_hex_or_rgba):
    """Pick ``'black'`` or ``'white'`` as the text colour for a background.

    Args:
        bg_color_hex_or_rgba: Either a colour string, which is converted with
            ``matplotlib.colors.to_rgba``, or an iterable of channel values
            ``(r, g, b)`` or ``(r, g, b, a)``. Non-string inputs are used
            as-is -- presumably floats in the 0-1 range like ``to_rgba``
            output; confirm against callers.

    Returns:
        ``'black'`` when the weighted channel sum exceeds 0.55, otherwise
        ``'white'``.
    """
    if isinstance(bg_color_hex_or_rgba, str):
        rgba = mcolors.to_rgba(bg_color_hex_or_rgba)
    else:
        rgba = bg_color_hex_or_rgba

    # Only the colour channels are used; alpha is optional and ignored, so
    # plain RGB triples are accepted as well as RGBA.
    r, g, b, *_ = rgba

    # Weighted sum with green dominating (the weights look like the Rec. 709
    # luma coefficients); it is applied directly to the given channel values.
    luminance = 0.2126 * r + 0.7152 * g + 0.0722 * b
    return 'black' if luminance > 0.55 else 'white'
|
|
|
|
|
|
def style_plot():
    """
    Install the project's plot styling globally.

    Updates matplotlib's ``rcParams`` and then sets seaborn's default
    palette and style, so every plot created afterwards is affected.
    """
    figure_rc = {
        'figure.facecolor': GRAY['50'],
        'figure.edgecolor': 'none',
        'figure.figsize': (12, 8),
        'figure.dpi': 150,
    }
    axes_rc = {
        'axes.facecolor': GRAY['50'],
        'axes.edgecolor': GRAY['300'],
        'axes.grid': True,
        'axes.labelcolor': GRAY['800'],
        'axes.titlecolor': GRAY['900'],
        'axes.titlesize': 18,
        'axes.titleweight': 'bold',
        'axes.titlepad': 20,
        'axes.labelsize': 14,
        'axes.labelweight': 'semibold',
        'axes.labelpad': 10,
        'axes.spines.top': False,
        'axes.spines.right': False,
        'axes.spines.left': True,
        'axes.spines.bottom': True,
    }
    text_rc = {
        'text.color': GRAY['700'],
    }
    tick_rc = {
        'xtick.color': GRAY['600'],
        'ytick.color': GRAY['600'],
        'xtick.labelsize': 12,
        'ytick.labelsize': 12,
        'xtick.major.size': 0,
        'ytick.major.size': 0,
        'xtick.minor.size': 0,
        'ytick.minor.size': 0,
        'xtick.major.pad': 8,
        'ytick.major.pad': 8,
    }
    grid_rc = {
        'grid.color': GRAY['200'],
        'grid.linestyle': '--',
        'grid.linewidth': 1,
    }
    legend_rc = {
        'legend.frameon': False,
        'legend.fontsize': 12,
        'legend.title_fontsize': 14,
        'legend.facecolor': 'inherit',
    }
    font_rc = {
        'font.family': 'sans-serif',
        'font.sans-serif': ['Inter'],
        'font.weight': 'normal',
    }
    lines_rc = {
        'lines.linewidth': 2,
        'lines.markersize': 6,
    }

    # Merge the groups in declaration order and apply them in one update.
    mpl.rcParams.update({
        **figure_rc,
        **axes_rc,
        **text_rc,
        **tick_rc,
        **grid_rc,
        **legend_rc,
        **font_rc,
        **lines_rc,
    })

    # Seaborn palette: LIME shades ordered from darkest to lightest, with
    # both the two lightest ('50', '100') and the four darkest ('700'-'950')
    # shades left out.
    excluded_shades = ('50', '100', '700', '800', '900', '950')
    shades_dark_to_light = sorted(LIME, key=int, reverse=True)
    lime_palette = [
        LIME[shade]
        for shade in shades_dark_to_light
        if shade not in excluded_shades
    ]
    sns.set_palette(lime_palette)

    # NOTE(review): sns.set_style also writes matplotlib rcParams; confirm
    # whether the "whitegrid" preset overrides some values applied above
    # (e.g. face colours, spines, fonts) and whether that is intended.
    seaborn_overrides = {
        'axes.edgecolor': GRAY['300'],
        'grid.color': GRAY['200'],
        'grid.linestyle': '--',
    }
    sns.set_style("whitegrid", seaborn_overrides)
|