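"""Fetch-and-cache helpers for the external datasets used by econ-agent:
the O*NET database, BLS OESM wage data, the EPOCH remote-work dataset, and
METR benchmark results. Each function downloads on first use and reads from
cache_dir thereafter.
"""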
import io
import sqlite3
import zipfile
from pathlib import Path
from typing import Dict

import pandas as pd
import requests
import yaml

from .logger import logger

ONET_VERSION = "29_1"
ONET_URL = f"https://www.onetcenter.org/dl_files/database/db_{ONET_VERSION}_mysql.zip"


def fetch_onet_database(cache_dir: Path) -> sqlite3.Connection:
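    """Download the O*NET database and build a local SQLite cache from it.

    The archive's SQL files are executed as-is against SQLite, so the
    MySQL-format dump is assumed to be SQLite-compatible. cache_dir must
    already exist. Returns an open connection to the cached database.
    """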
    db_path = cache_dir / f"onet_{ONET_VERSION}.db"

    if db_path.exists():
        logger.info(f"Using cached O*NET database: {db_path}")
        return sqlite3.connect(db_path)

    logger.info(f"Downloading O*NET database from {ONET_URL}")
    response = requests.get(ONET_URL, headers={"User-Agent": "econ-agent/1.0"})
    response.raise_for_status()

    # Note: if the build below fails partway, a partial database file is left
    # behind and would be mistaken for a valid cache on the next call.
    conn = sqlite3.connect(db_path)

    # Relax durability and locking while bulk-loading; reset below.
    conn.executescript("""
        PRAGMA journal_mode = OFF;
        PRAGMA synchronous = 0;
        PRAGMA cache_size = 1000000;
        PRAGMA locking_mode = EXCLUSIVE;
        PRAGMA temp_store = MEMORY;
        PRAGMA foreign_keys = ON;
    """)

    with zipfile.ZipFile(io.BytesIO(response.content)) as z:
        sql_scripts = [
            z.read(filename).decode("utf-8")
            for filename in sorted(z.namelist())
            if filename.endswith(".sql")
        ]

    if not sql_scripts:
        raise RuntimeError("No SQL files found in the O*NET zip archive.")

    logger.info("Executing SQL files in alphabetical order (single transaction mode)")
    full_script = "BEGIN TRANSACTION;\n" + "\n".join(sql_scripts) + "\nCOMMIT;"
    conn.executescript(full_script)

    # Restore safe defaults for normal use, then compact the file.
    conn.executescript("""
        PRAGMA journal_mode = WAL;
        PRAGMA synchronous = NORMAL;
        PRAGMA locking_mode = NORMAL;
        PRAGMA temp_store = DEFAULT;
        PRAGMA foreign_keys = ON;
        PRAGMA optimize;
    """)
    conn.execute("VACUUM;")
    conn.commit()

    return conn


def fetch_oesm_data(cache_dir: Path) -> pd.DataFrame:
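    """Download the BLS OESM (OEWS) national employment and wage data and
    cache it as Parquet. Returns the table as a DataFrame.
    """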
    version = "23"
    url = f"https://www.bls.gov/oes/special-requests/oesm{version}nat.zip"
    # Note: the cache file name is not version-stamped, so bumping the
    # version above will not invalidate an existing cache.
    data_path = cache_dir / "oesm.parquet"

    if data_path.exists():
        logger.info(f"Using cached OESM data: {data_path}")
        return pd.read_parquet(data_path)

    logger.info(f"Downloading OESM data from {url}")
    response = requests.get(url, headers={"User-Agent": "econ-agent/1.0"})
    response.raise_for_status()

    logger.info(f"Creating new OESM data cache: {data_path}")
    with zipfile.ZipFile(io.BytesIO(response.content)) as z:
        with z.open(f"oesm{version}national.xlsx") as f:
            # BLS marks unavailable and top-coded estimates with '*' and '#';
            # read both as NaN.
            df = pd.read_excel(f, engine="openpyxl", na_values=["*", "#"])

    df.to_parquet(data_path)
    logger.info(f"Saved OESM data to cache: {data_path}")
    return df


def fetch_epoch_remote_data(cache_dir: Path) -> pd.DataFrame:
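    """Download the EPOCH remote-work CSV from Google Drive and cache it as
    Parquet. Returns the table as a DataFrame.
    """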
    url = "https://drive.google.com/uc?export=download&id=1GrHhuYIgaCCgo99dZ_40BWraz-fzo76r"
    data_path = cache_dir / "epoch_remote_latest.parquet"

    if data_path.exists():
        logger.info(f"Using cached EPOCH remote data: {data_path}")
        return pd.read_parquet(data_path)

    logger.info(f"Downloading EPOCH remote data from Google Drive: {url}")
    session = requests.Session()
    session.headers.update({"User-Agent": "econ-agent/1.0"})
    # Assumes the file is small enough for a direct download; for larger
    # files Google Drive returns an HTML confirmation page instead of the CSV.
    response = session.get(url)
    response.raise_for_status()

    logger.info(f"Creating new EPOCH remote data cache: {data_path}")
    df = pd.read_csv(io.BytesIO(response.content))
    df.to_parquet(data_path)

    return df


def fetch_metr_data(cache_dir: Path) -> Dict:
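    """Download METR benchmark results (YAML), cache the raw file, and
    return the parsed mapping.
    """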
    url = "https://metr.org/assets/benchmark_results.yaml"
    data_path = cache_dir / "metr_benchmark_results.yaml"

    if data_path.exists():
        logger.info(f"Using cached METR data: {data_path}")
        with open(data_path, "r") as f:
            return yaml.safe_load(f)

    logger.info(f"Downloading METR data from {url}")
    response = requests.get(url, headers={"User-Agent": "econ-agent/1.0"})
    response.raise_for_status()

    logger.info(f"Creating new METR data cache: {data_path}")
    with open(data_path, "wb") as f:
        f.write(response.content)

    return yaml.safe_load(response.content)
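

# A minimal usage sketch, not part of the module: every fetcher takes the same
# cache directory and is idempotent, downloading only when its cache file is
# missing. The ".cache" path and the mkdir call are assumptions for the
# example; the fetchers themselves expect cache_dir to already exist.
#
#     from pathlib import Path
#
#     cache_dir = Path(".cache")
#     cache_dir.mkdir(parents=True, exist_ok=True)
#     onet_conn = fetch_onet_database(cache_dir)
#     oesm_df = fetch_oesm_data(cache_dir)
#     remote_df = fetch_epoch_remote_data(cache_dir)
#     metr_results = fetch_metr_data(cache_dir)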