Félix Dorn 2025-07-03 19:40:35 +02:00
parent b7c94590f9
commit f9f9825abb
9 changed files with 941 additions and 42 deletions

pipeline/__init__.py (new, empty file)
@@ -7,8 +7,22 @@ from .run import Run
 import pandas as pd


+def enrich_with_task_estimateability(run: Run) -> pd.DataFrame:
+    """
+    TODO: check run.cache_dir / computed_task_estimateability.parquet; if it
+    exists, load it and return it instead of recomputing. Otherwise call
+    enrich() with the right parameters, save the output to the cache dir,
+    and return it.
+    """
+    raise NotImplementedError
+
+
+def enrich_with_task_estimates(run: Run) -> pd.DataFrame:
+    """
+    TODO: check run.cache_dir / computed_task_estimates.parquet; if it exists,
+    load it and return it instead of recomputing. Otherwise call enrich() with
+    the right parameters, save the output to the cache dir, and return it.
+    """
+    raise NotImplementedError
+
+
+def enrich(model: str, system_prompt: str, schema: Any, rpm: int, messages: Any, chunk_size: int = 100):
+    raise NotImplementedError
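Editorial aside: the caching behavior both TODOs describe could look roughly like the sketch below. This is a hypothetical illustration, not the committed code; the enrich() parameters are deliberately left elided, as in the TODO, and the parquet filename comes straight from the docstring.

    def enrich_with_task_estimateability(run: Run) -> pd.DataFrame:
        cache_path = run.cache_dir / "computed_task_estimateability.parquet"
        if cache_path.exists():
            # Cache hit: reuse the previously computed result, skip recomputing
            return pd.read_parquet(cache_path)
        df = enrich(...)  # parameters elided, as in the TODO
        df.to_parquet(cache_path)  # persist so the next run short-circuits
        return df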

@@ -6,36 +6,40 @@ import sqlite3
 from typing import Tuple

 import pandas as pd
 import requests
-import hashlib
 import io
 import zipfile

-from .run import Run
-from .logger import logger
+from pipeline.run import Run
+from pipeline.logger import logger


 def fetch_onet_database(run: Run) -> Tuple[sqlite3.Connection, str]:
     """
     Downloads the O*NET database, creates a local SQLite file from it, and returns a connection.
-    The version is the sha256 of the downloaded zip file.
     """
-    url = "https://www.onetcenter.org/dl_files/database/db_29_1_mysql.zip"
-    logger.info(f"Downloading O*NET database from {url}")
-    response = requests.get(url, stream=True)
-    response.raise_for_status()
-    # Read content into memory
-    zip_content = response.content
-    version = hashlib.sha256(zip_content).hexdigest()
-    logger.info(f"O*NET database version (sha256): {version}")
+    version = "29_1"
+    url = f"https://www.onetcenter.org/dl_files/database/db_{version}_mysql.zip"
+    db_path = run.cache_dir / f"onet_{version}.db"
+    run.meta.fetchers['onet'] = {
+        'url': url,
+        'version': version,
+        'db_path': str(db_path),
+    }
+    if db_path.exists():
+        logger.info(f"Using cached O*NET database: {db_path}")
+        conn = sqlite3.connect(db_path)
+        # Set PRAGMA for foreign keys on every connection
+        conn.execute("PRAGMA foreign_keys = ON;")
+        return conn, version
+    logger.info(f"Downloading O*NET database from {url}")
+    response = requests.get(url, stream=True, headers={
+        "User-Agent": "econ-agent/1.0"
+    })
+    response.raise_for_status()
+    # Read content into memory
+    zip_content = response.content
-    db_path = run.cache_dir / f"onet_{version}.db"
     logger.info(f"Creating new O*NET database: {db_path}")
     conn = sqlite3.connect(db_path)
@@ -84,22 +88,28 @@ def fetch_onet_database(run: Run) -> Tuple[sqlite3.Connection, str]:
 def fetch_oesm_data(run: Run) -> Tuple[pd.DataFrame, str]:
     """
     Downloads the OESM national data from the BLS website.
-    The version is the sha256 of the downloaded zip file.
     """
-    url = "https://www.bls.gov/oes/special-requests/oesm23nat.zip"
-    logger.info(f"Downloading OESM data from {url}")
-    response = requests.get(url)
-    response.raise_for_status()
+    version = "23"
+    url = f"https://www.bls.gov/oes/special-requests/oesm{version}nat.zip"
+    parquet_path = run.cache_dir / "oesm.parquet"
+    run.meta.fetchers['oesm'] = {
+        'url': url,
+        'version': version,
+        'parquet_path': str(parquet_path),
+    }
-    zip_content = response.content
-    version = hashlib.sha256(zip_content).hexdigest()
-    logger.info(f"OESM data version (sha256): {version}")
-    parquet_path = run.cache_dir / f"oesm_{version}.parquet"
     if parquet_path.exists():
         logger.info(f"Using cached OESM data: {parquet_path}")
         return pd.read_parquet(parquet_path), version
+    logger.info(f"Downloading OESM data from {url}")
+    headers = {'User-Agent': 'econ-agent/1.0'}
+    response = requests.get(url, headers=headers)
+    response.raise_for_status()
+    zip_content = response.content
+    logger.info(f"OESM data version: {version}")
     logger.info(f"Creating new OESM data cache: {parquet_path}")
     with zipfile.ZipFile(io.BytesIO(zip_content)) as z:
         # Find the excel file in the zip
@@ -115,7 +125,7 @@ def fetch_oesm_data(run: Run) -> Tuple[pd.DataFrame, str]:
         logger.info(f"Reading {excel_filename} from zip archive.")
         with z.open(excel_filename) as f:
-            df = pd.read_excel(f, engine='openpyxl')
+            df = pd.read_excel(f, engine='openpyxl', na_values=['*', '#'])  # '*' and '#' are BLS special markers, read them as NaN
     df.to_parquet(parquet_path)
     logger.info(f"Saved OESM data to cache: {parquet_path}")
@@ -124,25 +134,30 @@ def fetch_oesm_data(run: Run) -> Tuple[pd.DataFrame, str]:
 def fetch_epoch_remote_data(run: Run) -> Tuple[pd.DataFrame, str]:
     """
     Downloads the EPOCH AI remote work task data.
-    The version is the sha256 of the downloaded CSV file.
     """
     # This is the direct download link constructed from the Google Drive share link
+    version = "latest"
     url = "https://drive.google.com/uc?export=download&id=1GrHhuYIgaCCgo99dZ_40BWraz-fzo76r"
+    parquet_path = run.cache_dir / f"epoch_remote_{version}.parquet"
+    run.meta.fetchers['epoch_remote'] = {
+        'url': url,
+        'version': version,
+        'parquet_path': str(parquet_path),
+    }
+    if parquet_path.exists():
+        logger.info(f"Using cached EPOCH remote data: {parquet_path}")
+        return pd.read_parquet(parquet_path), version
     logger.info(f"Downloading EPOCH remote data from Google Drive: {url}")
     # Need to handle potential cookies/redirects from Google Drive
     session = requests.Session()
+    session.headers.update({"User-Agent": "econ-agent/1.0"})
     response = session.get(url, stream=True)
     response.raise_for_status()
     csv_content = response.content
-    version = hashlib.sha256(csv_content).hexdigest()
-    logger.info(f"EPOCH remote data version (sha256): {version}")
-    parquet_path = run.cache_dir / f"epoch_remote_{version}.parquet"
-    if parquet_path.exists():
-        logger.info(f"Using cached EPOCH remote data: {parquet_path}")
-        return pd.read_parquet(parquet_path), version
     logger.info(f"Creating new EPOCH remote data cache: {parquet_path}")
     df = pd.read_csv(io.BytesIO(csv_content))
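Editorial aside: the "cookies/redirects" comment refers to Google Drive's interstitial for large files, where Drive answers with an HTML warning page and sets a "download_warning..." cookie whose value must be echoed back as a confirm parameter. A hypothetical sketch of that dance (not part of this commit):

    import requests

    def _download_from_google_drive(session: requests.Session, url: str) -> bytes:
        response = session.get(url, stream=True)
        response.raise_for_status()
        # Drive sets this cookie for files too large to virus-scan
        token = next((v for k, v in response.cookies.items()
                      if k.startswith("download_warning")), None)
        if token is not None:
            # Retry with the confirmation token to receive the actual bytes
            response = session.get(url, params={"confirm": token}, stream=True)
            response.raise_for_status()
        return response.content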

@@ -26,4 +26,16 @@ def _get_current_commit() -> str:
     """
     Returns the current git commit hash, or "unknown" / "errored" depending on why the commit could not be retrieved.
     """
-    raise NotImplementedError
+    import subprocess
+
+    try:
+        # Get the current commit hash
+        commit_hash = subprocess.check_output(
+            ["git", "rev-parse", "HEAD"], stderr=subprocess.PIPE, text=True
+        ).strip()
+        return commit_hash
+    except subprocess.CalledProcessError:
+        # The git command failed (e.g., not a git repository)
+        return "errored"
+    except FileNotFoundError:
+        # git is not installed
+        return "unknown"

@@ -6,6 +6,7 @@ from typing import Optional
 from .metadata import Metadata

 class Run(BaseModel):
+    model_config = {"arbitrary_types_allowed": True}
     # === FETCHERS ===
     onet_conn: Optional[sqlite3.Connection] = None
     onet_version: Optional[str] = None
@@ -5,22 +5,29 @@ from .postprocessors import check_for_insanity, create_df_tasks
 from .generators import GENERATORS
 from .run import Run
 from .constants import GRAY
+import argparse
 import platformdirs
 import seaborn as sns
 import matplotlib as mpl
 from pathlib import Path
-from typings import Optional
+from typing import Optional

 CACHE_DIR = platformdirs.user_cache_dir("econtai")

 def run(output_dir: Optional[str] = None):
-    if output_dir is None:
-        output_dir = Path(".")
     load_dotenv()
     _setup_graph_rendering()

+    if output_dir is None:
+        output_dir = Path("dist/")
+    else:
+        output_dir = Path(output_dir).resolve()
+    output_dir.mkdir(parents=True, exist_ok=True)

     current_run = Run(output_dir=output_dir, cache_dir=CACHE_DIR)
+    current_run.cache_dir.mkdir(parents=True, exist_ok=True)

     # Fetchers (fetchers.py)
     current_run.onet_conn, current_run.onet_version = fetch_onet_database(current_run)
@@ -54,3 +61,14 @@ def _setup_graph_rendering():
     sns.set_style("white")


+def main():
+    parser = argparse.ArgumentParser(description="Run the econtai pipeline.")
+    parser.add_argument("--output-dir", type=str, help="The directory to write output files to.")
+    args = parser.parse_args()
+    run(output_dir=args.output_dir)
+
+
+if __name__ == "__main__":
+    main()
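Editorial aside: the new entry point can be driven from the command line or programmatically. The module path below is a guess, since the file name is not shown in this diff:

    from pipeline.main import run  # hypothetical module path

    run(output_dir="dist/")  # same effect as: python -m pipeline.main --output-dir dist/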