progress
commit f9f9825abb (parent b7c94590f9)
9 changed files with 941 additions and 42 deletions
pipeline/__init__.py (new empty file)
@@ -7,8 +7,22 @@ from .run import Run
 import pandas as pd


 def enrich_with_task_estimateability(run: Run) -> pd.DataFrame:
-    run.metadata.
+    """
+    TODO: check run.cache_dir / computed_task_estimateability.parquet; if it exists, load it, return it, and don't compute this
+
+    call enrich with the right parameters, save the output to cache dir,
+    return it
+    """
     raise NotImplementedError


 def enrich_with_task_estimates(run: Run) -> pd.DataFrame:
+    """
+    TODO: check run.cache_dir / computed_task_estimates.parquet; if it exists, load it, return it, and don't compute this
+
+    call enrich with the right parameters, save the output to cache dir,
+    return it
+    """
     raise NotImplementedError
+
+
+def enrich(model: str, system_prompt: str, schema: Any, rpm: int, messages: Any, chunk_size: int = 100):
+    raise NotImplementedError
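The TODO docstrings above describe a cache-then-compute pattern around enrich(). A minimal sketch of what that could look like, assuming run.cache_dir is a pathlib.Path and enrich() returns a DataFrame; the model/prompt/schema arguments below are placeholders, not values from this repo:

    import pandas as pd

    from pipeline.run import Run


    def enrich_with_task_estimateability(run: Run) -> pd.DataFrame:
        # Return the cached result if a previous run already produced it.
        cache_path = run.cache_dir / "computed_task_estimateability.parquet"
        if cache_path.exists():
            return pd.read_parquet(cache_path)

        # Otherwise compute via enrich() and persist for the next run.
        # All arguments below are hypothetical placeholders.
        df = enrich(
            model="some-model-name",      # hypothetical
            system_prompt="Estimate ...", # hypothetical
            schema=None,                  # hypothetical
            rpm=60,
            messages=[],
        )
        df.to_parquet(cache_path)
        return df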
@@ -6,36 +6,40 @@ import sqlite3
 from typing import Tuple
 import pandas as pd
 import requests
 import hashlib
 import io
 import zipfile
-from .run import Run
-from .logger import logger
+from pipeline.run import Run
+from pipeline.logger import logger


 def fetch_onet_database(run: Run) -> Tuple[sqlite3.Connection, str]:
     """
     Downloads the O*NET database, creates a local SQLite file from it, and returns a connection.
     The version is the sha256 of the downloaded zip file.
     """
-    url = "https://www.onetcenter.org/dl_files/database/db_29_1_mysql.zip"
-    logger.info(f"Downloading O*NET database from {url}")
-    response = requests.get(url, stream=True)
-    response.raise_for_status()
-
-    # Read content into memory
-    zip_content = response.content
-    version = hashlib.sha256(zip_content).hexdigest()
-    logger.info(f"O*NET database version (sha256): {version}")
+    version = "29_1"
+    url = f"https://www.onetcenter.org/dl_files/database/db_{version}_mysql.zip"
+    db_path = run.cache_dir / f"onet_{version}.db"
+    run.meta.fetchers['onet'] = {
+        'url': url,
+        'version': version,
+        'db_path': str(db_path),
+    }
+
+    if db_path.exists():
+        logger.info(f"Using cached O*NET database: {db_path}")
+        conn = sqlite3.connect(db_path)
+        # Set PRAGMA for foreign keys on every connection
+        conn.execute("PRAGMA foreign_keys = ON;")
+        return conn, version
+
+    logger.info(f"Downloading O*NET database from {url}")
+    response = requests.get(url, stream=True, headers={
+        "User-Agent": "econ-agent/1.0"
+    })
+    response.raise_for_status()
+
+    # Read content into memory
+    zip_content = response.content

-    db_path = run.cache_dir / f"onet_{version}.db"

     logger.info(f"Creating new O*NET database: {db_path}")
     conn = sqlite3.connect(db_path)
@@ -84,22 +88,28 @@ def fetch_onet_database(run: Run) -> Tuple[sqlite3.Connection, str]:
 def fetch_oesm_data(run: Run) -> Tuple[pd.DataFrame, str]:
     """
     Downloads the OESM national data from the BLS website.
     The version is the sha256 of the downloaded zip file.
     """
-    url = "https://www.bls.gov/oes/special-requests/oesm23nat.zip"
-    logger.info(f"Downloading OESM data from {url}")
-    response = requests.get(url)
-    response.raise_for_status()
+    version = "23"
+    url = f"https://www.bls.gov/oes/special-requests/oesm{version}nat.zip"
+    parquet_path = run.cache_dir / "oesm.parquet"
+    run.meta.fetchers['oesm'] = {
+        'url': url,
+        'version': version,
+        'parquet_path': str(parquet_path),
+    }

-    zip_content = response.content
-    version = hashlib.sha256(zip_content).hexdigest()
-    logger.info(f"OESM data version (sha256): {version}")
-
-    parquet_path = run.cache_dir / f"oesm_{version}.parquet"
     if parquet_path.exists():
         logger.info(f"Using cached OESM data: {parquet_path}")
         return pd.read_parquet(parquet_path), version

+    logger.info(f"Downloading OESM data from {url}")
+    headers = {'User-Agent': 'econ-agent/1.0'}
+    response = requests.get(url, headers=headers)
+    response.raise_for_status()
+
+    zip_content = response.content
+    logger.info(f"OESM data version: {version}")

     logger.info(f"Creating new OESM data cache: {parquet_path}")
     with zipfile.ZipFile(io.BytesIO(zip_content)) as z:
         # Find the excel file in the zip
@@ -115,7 +125,7 @@ def fetch_oesm_data(run: Run) -> Tuple[pd.DataFrame, str]:

         logger.info(f"Reading {excel_filename} from zip archive.")
         with z.open(excel_filename) as f:
-            df = pd.read_excel(f, engine='openpyxl')
+            df = pd.read_excel(f, engine='openpyxl', na_values=['*', '#'])

     df.to_parquet(parquet_path)
     logger.info(f"Saved OESM data to cache: {parquet_path}")
@@ -124,25 +134,30 @@ def fetch_oesm_data(run: Run) -> Tuple[pd.DataFrame, str]:
 def fetch_epoch_remote_data(run: Run) -> Tuple[pd.DataFrame, str]:
     """
     Downloads the EPOCH AI remote work task data.
     The version is the sha256 of the downloaded CSV file.
     """
     # This is the direct download link constructed from the Google Drive share link
+    version = "latest"
     url = "https://drive.google.com/uc?export=download&id=1GrHhuYIgaCCgo99dZ_40BWraz-fzo76r"
+    parquet_path = run.cache_dir / f"epoch_remote_{version}.parquet"
+    run.meta.fetchers['epoch_remote'] = {
+        'url': url,
+        'version': version,
+        'parquet_path': str(parquet_path),
+    }
+
+    if parquet_path.exists():
+        logger.info(f"Using cached EPOCH remote data: {parquet_path}")
+        return pd.read_parquet(parquet_path), version

     logger.info(f"Downloading EPOCH remote data from Google Drive: {url}")

     # Need to handle potential cookies/redirects from Google Drive
     session = requests.Session()
+    session.headers.update({"User-Agent": "econ-agent/1.0"})
     response = session.get(url, stream=True)
     response.raise_for_status()

     csv_content = response.content
-    version = hashlib.sha256(csv_content).hexdigest()
-    logger.info(f"EPOCH remote data version (sha256): {version}")
-
-    parquet_path = run.cache_dir / f"epoch_remote_{version}.parquet"
-    if parquet_path.exists():
-        logger.info(f"Using cached EPOCH remote data: {parquet_path}")
-        return pd.read_parquet(parquet_path), version
-
     logger.info(f"Creating new EPOCH remote data cache: {parquet_path}")
     df = pd.read_csv(io.BytesIO(csv_content))
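On the "cookies/redirects" comment above: for large files, Google Drive has historically returned an HTML interstitial instead of the file, with a confirmation token delivered as a cookie. A sketch of one common workaround, assuming that classic download_warning cookie flow (Drive has changed this mechanism over time, so treat this as illustrative); _google_drive_download is a hypothetical helper, not part of this repo:

    import requests


    def _google_drive_download(file_id: str) -> bytes:
        # First request may return an HTML virus-scan warning page; the
        # confirm token (if any) arrives as a download_warning* cookie.
        session = requests.Session()
        session.headers.update({"User-Agent": "econ-agent/1.0"})
        url = "https://drive.google.com/uc?export=download"
        response = session.get(url, params={"id": file_id}, stream=True)
        token = next(
            (v for k, v in response.cookies.items() if k.startswith("download_warning")),
            None,
        )
        if token is not None:
            # Re-request with the confirm token to get the real content.
            response = session.get(url, params={"id": file_id, "confirm": token}, stream=True)
        response.raise_for_status()
        return response.content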
@@ -26,4 +26,16 @@ def _get_current_commit() -> str:
     """
     Returns the current git commit hash, "unknown", or "errored" depending on why the commit could not be retrieved.
     """
-    raise NotImplementedError
+    import subprocess
+    try:
+        # Get the current commit hash
+        commit_hash = subprocess.check_output(
+            ["git", "rev-parse", "HEAD"], stderr=subprocess.PIPE, text=True
+        ).strip()
+        return commit_hash
+    except subprocess.CalledProcessError:
+        # If git command fails (e.g., not a git repository)
+        return "errored"
+    except FileNotFoundError:
+        # If git is not installed
+        return "unknown"
@@ -6,6 +6,7 @@ from typing import Optional
 from .metadata import Metadata


 class Run(BaseModel):
+    model_config = {"arbitrary_types_allowed": True}
     # === FETCHERS ===
     onet_conn: Optional[sqlite3.Connection] = None
     onet_version: Optional[str] = None
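For context on the one-line addition above: pydantic (v2 here, given the model_config dict) rejects field types it has no validation schema for, such as sqlite3.Connection, unless arbitrary types are explicitly allowed, in which case it falls back to an isinstance check. A minimal standalone sketch:

    import sqlite3
    from typing import Optional

    from pydantic import BaseModel


    class Run(BaseModel):
        # Without this, model construction raises a schema error because
        # sqlite3.Connection has no built-in pydantic schema.
        model_config = {"arbitrary_types_allowed": True}

        onet_conn: Optional[sqlite3.Connection] = None
        onet_version: Optional[str] = None


    run = Run(onet_conn=sqlite3.connect(":memory:"), onet_version="29_1")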
@@ -5,22 +5,29 @@ from .postprocessors import check_for_insanity, create_df_tasks
 from .generators import GENERATORS
 from .run import Run
 from .constants import GRAY
+import argparse
 import platformdirs
 import seaborn as sns
 import matplotlib as mpl
 from pathlib import Path
-from typings import Optional
+from typing import Optional

 CACHE_DIR = platformdirs.user_cache_dir("econtai")


 def run(output_dir: Optional[str] = None):
-    if output_dir is None:
-        output_dir = Path(".")
-
     load_dotenv()
     _setup_graph_rendering()

+    if output_dir is None:
+        output_dir = Path("dist/")
+    else:
+        output_dir = Path(output_dir).resolve()
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+
     current_run = Run(output_dir=output_dir, cache_dir=CACHE_DIR)
+    current_run.cache_dir.mkdir(parents=True, exist_ok=True)

     # Fetchers (fetchers.py)
     current_run.onet_conn, current_run.onet_version = fetch_onet_database(current_run)
@@ -54,3 +61,14 @@ def _setup_graph_rendering():

     sns.set_style("white")


+def main():
+    parser = argparse.ArgumentParser(description="Run the econtai pipeline.")
+    parser.add_argument("--output-dir", type=str, help="The directory to write output files to.")
+    args = parser.parse_args()
+    run(output_dir=args.output_dir)
+
+
+if __name__ == "__main__":
+    main()