wip
This commit is contained in:
parent
62296e1b69
commit
65dc648797
37 changed files with 1413 additions and 2433 deletions
222
pipeline/utils.py
Normal file
222
pipeline/utils.py
Normal file
|
@ -0,0 +1,222 @@
|
|||
import subprocess
|
||||
import matplotlib.colors as mcolors
|
||||
import matplotlib as mpl
|
||||
import seaborn as sns
|
||||
import tempfile
|
||||
import litellm
|
||||
import time
|
||||
import math
|
||||
from tqdm import tqdm
|
||||
from typing import Any, List, Dict
|
||||
from .logger import logger
|
||||
|
||||
# Two-digit occupation major-group codes (as strings) mapped to short
# display labels.
# NOTE(review): these look like SOC major groups — confirm against the data source.
OCCUPATION_MAJOR_CODES = {
    "11": "Management",
    "13": "Business & Financial",
    "15": "Computer & Mathematical",
    "17": "Architecture & Engineering",
    "19": "Life, Physical, & Social Science",
    "21": "Community & Social Service",
    "23": "Legal",
    "25": "Education, Training, & Library",
    "27": "Arts, Design, & Media",
    "29": "Healthcare Practitioners",
    "31": "Healthcare Support",
    "33": "Protective Service",
    "35": "Food Preparation & Serving",
    "37": "Building & Grounds Maintenance",
    "39": "Personal Care & Service",
    "41": "Sales & Related",
    "43": "Office & Admin Support",
    "45": "Farming, Fishing, & Forestry",
    "47": "Construction & Extraction",
    "49": "Installation, Maintenance, & Repair",
    "51": "Production",
    "53": "Transportation & Material Moving",
    "55": "Military Specific",
}
|
||||
|
||||
# Gray shade scale keyed by weight string; '50' is the lightest hex value
# and '950' the darkest.
GRAY = {
    "50": "#f8fafc",
    "100": "#f1f5f9",
    "200": "#e2e8f0",
    "300": "#cbd5e1",
    "400": "#94a3b8",
    "500": "#64748b",
    "600": "#475569",
    "700": "#334155",
    "800": "#1e293b",
    "900": "#0f172a",
    "950": "#020617",
}
|
||||
|
||||
# Lime shade scale keyed by weight string; '50' is the lightest hex value
# and '950' the darkest.
LIME = {
    "50": "#f7fee7",
    "100": "#ecfcca",
    "200": "#d8f999",
    "300": "#bbf451",
    "400": "#9ae600",
    "500": "#83cd00",
    "600": "#64a400",
    "700": "#497d00",
    "800": "#3c6300",
    "900": "#35530e",
    "950": "#192e03",
}
|
||||
|
||||
|
||||
# Minutes contained in one of each supported unit. Months, trimesters,
# semesters and years use fixed lengths of 30, 90, 180 and 365 days.
_MINUTES_PER_UNIT = {
    "minute": 1,
    "hour": 60,
    "day": 60 * 24,
    "week": 60 * 24 * 7,
    "month": 60 * 24 * 30,
    "trimester": 60 * 24 * 90,
    "semester": 60 * 24 * 180,
    "year": 60 * 24 * 365,
}


def convert_to_minutes(qty, unit):
    """Convert a quantity expressed in ``unit`` to minutes.

    Args:
        qty: Numeric amount of ``unit``.
        unit: One of "minute", "hour", "day", "week", "month", "trimester",
            "semester" or "year". String units are also accepted with
            surrounding whitespace, in any letter case, and in plural form
            (e.g. "Hours", " weeks ").

    Returns:
        ``qty`` multiplied by the number of minutes in one ``unit``.

    Raises:
        KeyError: If ``unit`` is not a supported unit.
    """
    # Exact match first so the canonical spellings take the fastest path.
    factor = _MINUTES_PER_UNIT.get(unit)

    if factor is None and isinstance(unit, str):
        normalized = unit.strip().lower()
        factor = _MINUTES_PER_UNIT.get(normalized)
        # No supported unit ends in "s", so a trailing "s" can only be a plural.
        if factor is None and normalized.endswith("s"):
            factor = _MINUTES_PER_UNIT.get(normalized[:-1])

    if factor is None:
        raise KeyError(
            f"Unsupported time unit {unit!r}; "
            f"expected one of {sorted(_MINUTES_PER_UNIT)}"
        )

    return qty * factor
|
||||
|
||||
|
||||
# Command (without the file argument) used to open the rendered HTML when
# the browser view is requested and no other command is supplied.
_REVIEW_BROWSER_CMD = (
    "/home/felix/.nix-profile/bin/firefox-devedition",
    "-p",
    "Work (YouthAI)",
)


def pretty_display(df, open_in_browser=False, browser_cmd=_REVIEW_BROWSER_CMD):
    """Show ``df`` for manual review.

    By default the object is printed to stdout and the function returns.
    The HTML/browser view used to sit behind an unconditional ``return`` and
    could never run; it is now reachable by passing ``open_in_browser=True``.

    Args:
        df: Object to display. It is passed to ``print``; the browser view
            additionally calls ``df.to_html(index=False)``.
        open_in_browser: When true, render ``df`` to a temporary HTML file,
            open it with ``browser_cmd`` and wait for the user to press Enter.
        browser_cmd: Sequence of command-line arguments; the temporary file
            path is appended as the last argument.

    Returns:
        None.

    Raises:
        FileNotFoundError: From ``subprocess.run`` if the executable named in
            ``browser_cmd`` does not exist (browser view only).
    """
    if not open_in_browser:
        print(df)
        return

    html_output = df.to_html(index=False)

    # The temporary file is removed when the ``with`` block exits, so the
    # browser launch and the confirmation prompt both happen inside it.
    with tempfile.NamedTemporaryFile(mode='w', suffix=".html", encoding="utf-8") as temp_file:
        temp_file.write(html_output)
        # Push the buffered HTML to disk before another process reads it.
        temp_file.flush()
        subprocess.run(
            [*browser_cmd, temp_file.name],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
        input("Press Enter to continue after reviewing the HTML output...")
|
||||
|
||||
|
||||
def enrich(
    model: str,
    rpm: int,  # Requests per minute
    messages_to_process: List[List[Dict[str, str]]],
    schema: Dict[str, Any],
    chunk_size: int = 100,
) -> List[Any]:
    """Run batched, rate-limited completions over a list of conversations.

    The conversations are sent to ``litellm.batch_completion`` in consecutive
    chunks of at most ``chunk_size``. After each chunk the function sleeps
    long enough that the chunk took at least ``len(chunk) * 60 / rpm``
    seconds overall.

    Args:
        model: Model identifier forwarded to ``litellm.batch_completion``.
        rpm: Target requests per minute. A value <= 0 disables throttling.
        messages_to_process: One message list per request.
        schema: Passed as the ``json_schema`` entry of ``response_format``.
        chunk_size: Maximum number of requests per batch; must be >= 1.

    Returns:
        A list with one entry per input conversation, in input order: the
        response object on success, or ``None`` when that request (or the
        whole batch call) failed.

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1 and there is at least
            one message to process.
    """
    all_results: List[Any] = []
    num_messages = len(messages_to_process)
    if num_messages == 0:
        return all_results

    # A zero chunk size would divide by zero below, and a negative one would
    # yield no chunks at all and silently return an empty result list.
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size!r}")

    num_chunks = math.ceil(num_messages / chunk_size)
    logger.info(f"Starting enrichment for {num_messages} messages, in {num_chunks} chunks of up to {chunk_size} each.")

    # Calculate the time that should be allocated per request to respect the RPM limit.
    time_per_request = 60.0 / rpm if rpm > 0 else 0

    for i in tqdm(range(num_chunks), desc="Enriching data in chunks"):
        # Monotonic clock: elapsed time must not be affected by wall-clock changes.
        chunk_start_time = time.monotonic()

        start_index = i * chunk_size
        message_chunk = messages_to_process[start_index:start_index + chunk_size]

        try:
            # Send requests for the entire chunk in a batch for better performance.
            responses = litellm.batch_completion(
                model=model,
                messages=message_chunk,
                response_format={
                    "type": "json_schema",
                    "json_schema": schema,
                },
            )
        except Exception as e:
            # This catches catastrophic failures in batch_completion itself (e.g., auth)
            logger.error(f"litellm.batch_completion call failed for chunk {i+1}/{num_chunks}: {e}")
            all_results.extend([None] * len(message_chunk))
        else:
            # batch_completion returns the response or an exception object for each message.
            # We'll replace exceptions with None as expected by the calling functions.
            for response in responses:
                if isinstance(response, Exception):
                    logger.error(f"API call within batch failed: {response}")
                    all_results.append(None)
                else:
                    all_results.append(response)

        elapsed_time = time.monotonic() - chunk_start_time

        # To enforce the rate limit, we calculate how long the chunk *should* have taken
        # and sleep for the remainder of that time (this also applies after the final chunk).
        if time_per_request > 0:
            expected_duration_for_chunk = len(message_chunk) * time_per_request
            if elapsed_time < expected_duration_for_chunk:
                sleep_duration = expected_duration_for_chunk - elapsed_time
                logger.debug(f"Chunk processed in {elapsed_time:.2f}s. Sleeping for {sleep_duration:.2f}s to respect RPM.")
                time.sleep(sleep_duration)

    return all_results
|
||||
|
||||
def get_contrasting_text_color(bg_color_hex_or_rgba):
    """Pick 'black' or 'white' text for readability on a background colour.

    Args:
        bg_color_hex_or_rgba: Either a colour string understood by
            ``matplotlib.colors.to_rgba`` (e.g. a hex code), or a sequence of
            3 (RGB) or 4 (RGBA) channel values. The alpha channel, if
            present, is ignored.
            NOTE(review): channel values are assumed to be floats in 0-1, as
            returned by ``to_rgba`` — confirm for non-string callers.

    Returns:
        'black' when the weighted channel sum exceeds 0.55, else 'white'.

    Raises:
        ValueError: If a non-string colour does not have 3 or 4 channels.
    """
    if isinstance(bg_color_hex_or_rgba, str):
        channels = mcolors.to_rgba(bg_color_hex_or_rgba)
    else:
        channels = tuple(bg_color_hex_or_rgba)

    # Accept RGB as well as RGBA so that 3-tuples (such as the entries of a
    # seaborn palette) work without the caller appending an alpha value.
    if len(channels) not in (3, 4):
        raise ValueError(
            f"Expected 3 (RGB) or 4 (RGBA) channels, got {len(channels)}"
        )
    r, g, b = channels[:3]

    # Weighted sum of the channels as given (no gamma correction is applied).
    luminance = 0.2126 * r + 0.7152 * g + 0.0722 * b
    return 'black' if luminance > 0.55 else 'white'
|
||||
|
||||
|
||||
def style_plot():
    """
    Applies a consistent and professional style to all plots.

    This function sets matplotlib's rcParams for a global effect and
    configures seaborn's style and colour palette. It takes no arguments and
    returns None.

    The seaborn style is applied first: ``sns.set_style`` writes a full set
    of style rcParams (face colours, text/tick colours, spines, fonts), so
    calling it after the custom update would overwrite those custom values.
    """
    # Seaborn base style first, so the explicit rcParams below take precedence.
    sns.set_style("whitegrid", {
        'axes.edgecolor': GRAY['300'],
        'grid.color': GRAY['200'],
        'grid.linestyle': '--',
    })

    mpl.rcParams.update({
        'figure.facecolor': GRAY['50'],
        'figure.edgecolor': 'none',
        'figure.figsize': (12, 8),
        'figure.dpi': 150,

        'axes.facecolor': GRAY['50'],
        'axes.edgecolor': GRAY['300'],
        'axes.grid': True,
        'axes.labelcolor': GRAY['800'],
        'axes.titlecolor': GRAY['900'],
        'axes.titlesize': 18,
        'axes.titleweight': 'bold',
        'axes.titlepad': 20,
        'axes.labelsize': 14,
        'axes.labelweight': 'semibold',
        'axes.labelpad': 10,
        'axes.spines.top': False,
        'axes.spines.right': False,
        'axes.spines.left': True,
        'axes.spines.bottom': True,

        'text.color': GRAY['700'],

        'xtick.color': GRAY['600'],
        'ytick.color': GRAY['600'],
        'xtick.labelsize': 12,
        'ytick.labelsize': 12,
        'xtick.major.size': 0,
        'ytick.major.size': 0,
        'xtick.minor.size': 0,
        'ytick.minor.size': 0,
        'xtick.major.pad': 8,
        'ytick.major.pad': 8,

        'grid.color': GRAY['200'],
        'grid.linestyle': '--',
        'grid.linewidth': 1,

        'legend.frameon': False,
        'legend.fontsize': 12,
        'legend.title_fontsize': 14,
        'legend.facecolor': 'inherit',

        'font.family': 'sans-serif',
        'font.sans-serif': ['Inter'],
        'font.weight': 'normal',

        'lines.linewidth': 2,
        'lines.markersize': 6,
    })

    # Use shades of LIME as the primary color palette, ordered by the integer
    # value of the keys with darker shades first. Both ends of the scale are
    # left out: the lightest shades ('50', '100') and the darkest ones
    # ('700' to '950'), which leaves '600' down to '200'.
    excluded_shades = {'50', '100', '700', '800', '900', '950'}
    lime_palette = [
        LIME[k]
        for k in sorted(LIME.keys(), key=int, reverse=True)
        if k not in excluded_shades
    ]
    sns.set_palette(lime_palette)
|
Loading…
Add table
Add a link
Reference in a new issue