progress
This commit is contained in:
		
							parent
							
								
									b7c94590f9
								
							
						
					
					
						commit
						f9f9825abb
					
				
					 9 changed files with 941 additions and 42 deletions
				
			
		
							
								
								
									
										0
									
								
								pipeline/__init__.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								pipeline/__init__.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -7,8 +7,22 @@ from  .run  import Run | |||
| import pandas as pd | ||||
| 
 | ||||
| def enrich_with_task_estimateability(run: Run) -> pd.DataFrame: | ||||
|     run.metadata. | ||||
|     """ | ||||
|     TODO: check run.cache_dir / computed_task_estimateability.parquet, if it exists, load it, return it, and don't compute this | ||||
| 
 | ||||
|     call enrich with the right parameters, save the output to cache dir, | ||||
|      return it | ||||
|     """ | ||||
|     raise NotImplementedError | ||||
| 
 | ||||
| def enrich_with_task_estimates(run: Run) -> pd.DataFrame: | ||||
|     """ | ||||
|     TODO: check run.cache_dir / computed_task_estimates.parquet, if it exists, load it, return it, and don't compute this | ||||
| 
 | ||||
|     call enrich with the right parameters, save the output to cache dir, | ||||
|      return it | ||||
|     """ | ||||
|     raise NotImplementedError | ||||
| 
 | ||||
| def enrich(model: str, system_prompt: str, schema: Any, rpm: int, chunk_size: int = 100, messages: Any): | ||||
|     raise NotImplementedError | ||||
|  |  | |||
|  | @ -6,36 +6,40 @@ import sqlite3 | |||
| from typing import Tuple | ||||
| import pandas as pd | ||||
| import requests | ||||
| import hashlib | ||||
| import io | ||||
| import zipfile | ||||
| from .run import Run | ||||
| from .logger import logger | ||||
| from pipeline.run import Run | ||||
| from pipeline.logger import logger | ||||
| 
 | ||||
| def fetch_onet_database(run: Run) -> Tuple[sqlite3.Connection, str]: | ||||
|     """ | ||||
|     Downloads the O*NET database, creates a local SQLite file from it, and returns a connection. | ||||
|     The version is the sha256 of the downloaded zip file. | ||||
|     """ | ||||
|     url = "https://www.onetcenter.org/dl_files/database/db_29_1_mysql.zip" | ||||
|     logger.info(f"Downloading O*NET database from {url}") | ||||
|     response = requests.get(url, stream=True) | ||||
|     response.raise_for_status() | ||||
| 
 | ||||
|     # Read content into memory | ||||
|     zip_content = response.content | ||||
|     version = hashlib.sha256(zip_content).hexdigest() | ||||
|     logger.info(f"O*NET database version (sha256): {version}") | ||||
| 
 | ||||
|     version  = "29_1" | ||||
|     url = f"https://www.onetcenter.org/dl_files/database/db_{version}_mysql.zip" | ||||
|     db_path = run.cache_dir / f"onet_{version}.db" | ||||
|     run.meta.fetchers['onet'] = { | ||||
|         'url': url, | ||||
|         'version': version, | ||||
|         'db_path': str(db_path), | ||||
|     } | ||||
| 
 | ||||
|     if db_path.exists(): | ||||
|         logger.info(f"Using cached O*NET database: {db_path}") | ||||
|         conn = sqlite3.connect(db_path) | ||||
|         # Set PRAGMA for foreign keys on every connection | ||||
|         conn.execute("PRAGMA foreign_keys = ON;") | ||||
|         return conn, version | ||||
| 
 | ||||
|     logger.info(f"Downloading O*NET database from {url}") | ||||
|     response = requests.get(url, stream=True, headers={ | ||||
|         "User-Agent": "econ-agent/1.0" | ||||
|     }) | ||||
|     response.raise_for_status() | ||||
| 
 | ||||
|     # Read content into memory | ||||
|     zip_content = response.content | ||||
| 
 | ||||
|     db_path = run.cache_dir / f"onet_{version}.db" | ||||
| 
 | ||||
|     logger.info(f"Creating new O*NET database: {db_path}") | ||||
|     conn = sqlite3.connect(db_path) | ||||
| 
 | ||||
|  | @ -84,22 +88,28 @@ def fetch_onet_database(run: Run) -> Tuple[sqlite3.Connection, str]: | |||
| def fetch_oesm_data(run: Run) -> Tuple[pd.DataFrame, str]: | ||||
|     """ | ||||
|     Downloads the OESM national data from the BLS website. | ||||
|     The version is the sha256 of the downloaded zip file. | ||||
|     """ | ||||
|     url = "https://www.bls.gov/oes/special-requests/oesm23nat.zip" | ||||
|     logger.info(f"Downloading OESM data from {url}") | ||||
|     response = requests.get(url) | ||||
|     response.raise_for_status() | ||||
|     version = "23" | ||||
|     url = f"https://www.bls.gov/oes/special-requests/oesm{version}nat.zip" | ||||
|     parquet_path = run.cache_dir / "oesm.parquet" | ||||
|     run.meta.fetchers['oesm'] = { | ||||
|         'url': url, | ||||
|         'version': version, | ||||
|         'parquet_path': str(parquet_path), | ||||
|     } | ||||
| 
 | ||||
|     zip_content = response.content | ||||
|     version = hashlib.sha256(zip_content).hexdigest() | ||||
|     logger.info(f"OESM data version (sha256): {version}") | ||||
| 
 | ||||
|     parquet_path = run.cache_dir / f"oesm_{version}.parquet" | ||||
|     if parquet_path.exists(): | ||||
|         logger.info(f"Using cached OESM data: {parquet_path}") | ||||
|         return pd.read_parquet(parquet_path), version | ||||
| 
 | ||||
|     logger.info(f"Downloading OESM data from {url}") | ||||
|     headers = {'User-Agent': 'econ-agent/1.0'} | ||||
|     response = requests.get(url, headers=headers) | ||||
|     response.raise_for_status() | ||||
| 
 | ||||
|     zip_content = response.content | ||||
|     logger.info(f"OESM data version: {version}") | ||||
| 
 | ||||
|     logger.info(f"Creating new OESM data cache: {parquet_path}") | ||||
|     with zipfile.ZipFile(io.BytesIO(zip_content)) as z: | ||||
|         # Find the excel file in the zip | ||||
|  | @ -115,7 +125,7 @@ def fetch_oesm_data(run: Run) -> Tuple[pd.DataFrame, str]: | |||
| 
 | ||||
|         logger.info(f"Reading {excel_filename} from zip archive.") | ||||
|         with z.open(excel_filename) as f: | ||||
|             df = pd.read_excel(f, engine='openpyxl') | ||||
|             df = pd.read_excel(f, engine='openpyxl', na_values=['*', '#']) | ||||
| 
 | ||||
|     df.to_parquet(parquet_path) | ||||
|     logger.info(f"Saved OESM data to cache: {parquet_path}") | ||||
|  | @ -124,25 +134,30 @@ def fetch_oesm_data(run: Run) -> Tuple[pd.DataFrame, str]: | |||
| def fetch_epoch_remote_data(run: Run) -> Tuple[pd.DataFrame, str]: | ||||
|     """ | ||||
|     Downloads the EPOCH AI remote work task data. | ||||
|     The version is the sha256 of the downloaded CSV file. | ||||
|     """ | ||||
|     # This is the direct download link constructed from the Google Drive share link | ||||
|     version = "latest" | ||||
|     url = "https://drive.google.com/uc?export=download&id=1GrHhuYIgaCCgo99dZ_40BWraz-fzo76r" | ||||
|     parquet_path = run.cache_dir / f"epoch_remote_{version}.parquet" | ||||
|     run.meta.fetchers['epoch_remote'] = { | ||||
|         'url': url, | ||||
|         'version': version, | ||||
|         'parquet_path': str(parquet_path), | ||||
|     } | ||||
| 
 | ||||
|     if parquet_path.exists(): | ||||
|         logger.info(f"Using cached EPOCH remote data: {parquet_path}") | ||||
|         return pd.read_parquet(parquet_path), version | ||||
| 
 | ||||
|     logger.info(f"Downloading EPOCH remote data from Google Drive: {url}") | ||||
| 
 | ||||
|     # Need to handle potential cookies/redirects from Google Drive | ||||
|     session = requests.Session() | ||||
|     session.headers.update({"User-Agent": "econ-agent/1.0"}) | ||||
|     response = session.get(url, stream=True) | ||||
|     response.raise_for_status() | ||||
| 
 | ||||
|     csv_content = response.content | ||||
|     version = hashlib.sha256(csv_content).hexdigest() | ||||
|     logger.info(f"EPOCH remote data version (sha256): {version}") | ||||
| 
 | ||||
|     parquet_path = run.cache_dir / f"epoch_remote_{version}.parquet" | ||||
|     if parquet_path.exists(): | ||||
|         logger.info(f"Using cached EPOCH remote data: {parquet_path}") | ||||
|         return pd.read_parquet(parquet_path), version | ||||
| 
 | ||||
|     logger.info(f"Creating new EPOCH remote data cache: {parquet_path}") | ||||
|     df = pd.read_csv(io.BytesIO(csv_content)) | ||||
|  |  | |||
|  | @ -26,4 +26,16 @@ def _get_current_commit() -> str: | |||
|     """ | ||||
|     Returns the current git commit hash, "unknown", or "errored" depending on why the commit could not be retrieved. | ||||
|     """ | ||||
|     raise NotImplementedError | ||||
|     import subprocess | ||||
|     try: | ||||
|         # Get the current commit hash | ||||
|         commit_hash = subprocess.check_output( | ||||
|             ["git", "rev-parse", "HEAD"], stderr=subprocess.PIPE, text=True | ||||
|         ).strip() | ||||
|         return commit_hash | ||||
|     except subprocess.CalledProcessError: | ||||
|         # If git command fails (e.g., not a git repository) | ||||
|         return "errored" | ||||
|     except FileNotFoundError: | ||||
|         # If git is not installed | ||||
|         return "unknown" | ||||
|  |  | |||
|  | @ -6,6 +6,7 @@ from typing import Optional | |||
| from .metadata import Metadata | ||||
| 
 | ||||
| class Run(BaseModel): | ||||
|     model_config = {"arbitrary_types_allowed": True} | ||||
|     # === FETCHERS === | ||||
|     onet_conn: Optional[sqlite3.Connection] = None | ||||
|     onet_version: Optional[str] = None | ||||
|  |  | |||
|  | @ -5,22 +5,29 @@ from .postprocessors import check_for_insanity, create_df_tasks | |||
| from .generators import GENERATORS | ||||
| from .run import Run | ||||
| from .constants import GRAY | ||||
| import argparse | ||||
| import platformdirs | ||||
| import seaborn as sns | ||||
| import matplotlib as mpl | ||||
| from pathlib import Path | ||||
| from typings import Optional | ||||
| from typing import Optional | ||||
| 
 | ||||
| CACHE_DIR = platformdirs.user_cache_dir("econtai") | ||||
| 
 | ||||
| def run(output_dir: Optional[str] = None): | ||||
|     if output_dir is None: | ||||
|         output_dir = Path(".") | ||||
| 
 | ||||
|     load_dotenv() | ||||
|     _setup_graph_rendering() | ||||
| 
 | ||||
|     if output_dir is None: | ||||
|         output_dir = Path("dist/") | ||||
|     else: | ||||
|         output_dir = Path(output_dir).resolve() | ||||
| 
 | ||||
|     output_dir.mkdir(parents=True, exist_ok=True) | ||||
| 
 | ||||
| 
 | ||||
|     current_run = Run(output_dir=output_dir, cache_dir=CACHE_DIR) | ||||
|     current_run.cache_dir.mkdir(parents=True, exist_ok=True) | ||||
| 
 | ||||
|     # Fetchers (fetchers.py) | ||||
|     current_run.onet_conn, current_run.onet_version = fetch_onet_database(current_run) | ||||
|  | @ -54,3 +61,14 @@ def _setup_graph_rendering(): | |||
| 
 | ||||
| 
 | ||||
|     sns.set_style("white") | ||||
| 
 | ||||
| 
 | ||||
| def main(): | ||||
|     parser = argparse.ArgumentParser(description="Run the econtai pipeline.") | ||||
|     parser.add_argument("--output-dir", type=str, help="The directory to write output files to.") | ||||
|     args = parser.parse_args() | ||||
|     run(output_dir=args.output_dir) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|     main() | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Félix Dorn
						Félix Dorn