Files
vath/analysis/gpt4o/analysis_batch.py
2026-05-06 13:29:59 -04:00

557 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
analysis_batch.py — OpenAI Batch API pipeline
Commands (run manually in order):
submit <input_jsonl> [--model gpt-4o] [--limit N]
— build request file, upload, create batch
status [run_id] — check batch status, update manifest
download [run_id] — download + normalize output, update manifest
run_id defaults to the most recent run in runs/ when omitted.
File layout (all under analysis/gpt4o/):
requests/<run_id>.jsonl — batch input sent to OpenAI
raw/<run_id>.jsonl — raw batch output from OpenAI
runs/<run_id>.json — run manifest
<run_id>_<model>.jsonl — normalized output (same schema as realtime)
"""
import argparse
import hashlib
import json
import os
import sys
from datetime import datetime, timezone
from pathlib import Path
from dotenv import load_dotenv
try:
import openai
except ImportError:
sys.exit("openai package not installed. Run: pip install openai")
# ---------------------------------------------------------------------------
# Model limits and token estimation
# Max enqueued tokens across ALL concurrent batches for this model
# (docs/openai.md pricing table, updated 2026-05-05).
# NOTE: your org tier may be lower — if a submit fails, use --limit to reduce chunk size.
# NOTE(review): OpenAI's published id for the last entry appears to be "o4-mini"
# rather than "gpt-o4-mini" — confirm the key matches what --model receives;
# an unlisted model name falls back to _DEFAULT_TOKEN_LIMIT.
MODEL_LIMITS: dict[str, int] = {
    "gpt-5.5": 900_000,
    "gpt-5.4": 900_000,
    "gpt-5.4-mini": 2_000_000,
    "gpt-5.4-nano": 200_000,
    "gpt-4o": 900_000,
    "gpt-4o-mini": 2_000_000,
    "gpt-o4-mini": 2_000_000,
}
# Limit used for any model name that is not a key of MODEL_LIMITS.
_DEFAULT_TOKEN_LIMIT = 900_000
# tiktoken encoding per model family; unknown models fall back to o200k_base
_MODEL_ENCODING: dict[str, str] = {
    "gpt-5.5": "o200k_base",
    "gpt-5.4": "o200k_base",
    "gpt-5.4-mini": "o200k_base",
    "gpt-5.4-nano": "o200k_base",
    "gpt-4o": "o200k_base",
    "gpt-4o-mini": "o200k_base",
    "gpt-o4-mini": "o200k_base",
}
# Leave 10% headroom below the published limit
# (the per-model limit is multiplied by this factor before chunking).
_LIMIT_BUFFER = 0.90
def estimate_tokens(messages: list[dict], model: str) -> int:
    """Estimate the token count of a chat ``messages`` list.

    Every message is charged a flat 4-token overhead plus the size of its
    ``"content"`` string.  When ``tiktoken`` is importable the content is
    measured with the encoding registered for ``model`` in ``_MODEL_ENCODING``
    (``o200k_base`` for unlisted models); otherwise it is approximated as
    ``len(content) // 3``.

    Args:
        messages: Chat messages; each dict must carry a string ``"content"``.
        model: Model name used to select the tiktoken encoding.

    Returns:
        The estimated total number of tokens (0 for an empty list).
    """
    try:
        import tiktoken

        enc = tiktoken.get_encoding(_MODEL_ENCODING.get(model, "o200k_base"))
        # Comment bodies are untrusted text and may contain literal
        # special-token markers such as "<|endoftext|>".  By default encode()
        # raises ValueError on those; disallowed_special=() makes it treat
        # them as ordinary text so one odd comment cannot abort a submit.
        return sum(
            4 + len(enc.encode(m["content"], disallowed_special=()))
            for m in messages
        )
    except ImportError:
        return sum(4 + len(m["content"]) // 3 for m in messages)
def chunk_comments_by_tokens(
    comments: list[dict], forum: dict | None, model: str
) -> list[list[dict]]:
    """Greedily pack comments, in input order, into token-bounded chunks.

    The budget per chunk is the model's entry in ``MODEL_LIMITS`` (or
    ``_DEFAULT_TOKEN_LIMIT``) scaled by ``_LIMIT_BUFFER``.  A comment whose own
    estimate exceeds the budget still gets a chunk to itself.

    Returns:
        A list of non-empty chunks; empty when ``comments`` is empty.
    """
    budget = int(MODEL_LIMITS.get(model, _DEFAULT_TOKEN_LIMIT) * _LIMIT_BUFFER)
    batches: list[list[dict]] = []
    pending: list[dict] = []
    used = 0
    for item in comments:
        cost = estimate_tokens(build_messages(item, forum)[0], model)
        # Close the open chunk first when this comment would overflow it.
        if pending and used + cost > budget:
            batches.append(pending)
            pending, used = [], 0
        pending.append(item)
        used += cost
    if pending:
        batches.append(pending)
    return batches
# ---------------------------------------------------------------------------
# Prompt
# Prompt file used when --prompt is not given: prompt-1.txt in the directory
# one level above this script's directory.
_DEFAULT_PROMPT_FILE = Path(__file__).parent.parent / "prompt-1.txt"
# Read at import time, so importing this module fails if the file is missing.
SYSTEM_PROMPT = _DEFAULT_PROMPT_FILE.read_text(encoding="utf-8").strip()
# Short identifier for the prompt text: first 7 hex digits of its SHA-256.
PROMPT_VERSION = hashlib.sha256(SYSTEM_PROMPT.encode("utf-8")).hexdigest()[:7]
def _load_prompt(path: Path) -> None:
    """Make the file at ``path`` the active system prompt.

    Rebinds the module-level ``SYSTEM_PROMPT`` to the stripped file contents
    and ``PROMPT_VERSION`` to the first 7 hex digits of that text's SHA-256.
    """
    global SYSTEM_PROMPT, PROMPT_VERSION
    text = path.read_text(encoding="utf-8").strip()
    digest = hashlib.sha256(text.encode("utf-8")).hexdigest()
    SYSTEM_PROMPT = text
    PROMPT_VERSION = digest[:7]
# User-message template filled in by build_messages() via str.format();
# placeholders: reg_title, reg_desc, comment_id, comment_title, comment_text.
# The backslashes after the opening quotes and at the end of the last line
# suppress the leading and trailing newline of the literal.
USER_TEMPLATE = """\
## Proposed Regulation
Title: {reg_title}
Description: {reg_desc}
---
## Public Comment
Comment ID: {comment_id}
Title: {comment_title}
Body:
{comment_text}
---
Classify this comment per the instructions. Return only JSON.\
"""
# Comment bodies longer than this many characters are cut by build_messages().
MAX_COMMENT_CHARS = 6000
# ---------------------------------------------------------------------------
# Directories
# All pipeline artifacts are stored relative to the directory holding this script.
_SCRIPT_DIR = Path(__file__).parent
REQUESTS_DIR = _SCRIPT_DIR / "requests"  # batch input files written by submit
RAW_DIR = _SCRIPT_DIR / "raw"  # raw batch output saved by download
RUNS_DIR = _SCRIPT_DIR / "runs"  # one JSON manifest per run
# ---------------------------------------------------------------------------
# Core functions (importable for tests)
def load_items(path: Path) -> tuple[dict | None, list[dict]]:
"""Read a scraped JSONL file. Returns (forum_item_or_None, [comment_items])."""
forum = None
comments = []
with open(path, encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
item = json.loads(line)
if "comment_id" in item:
comments.append(item)
elif "reg_title" in item:
forum = item
return forum, comments
def custom_id_from(comment_id: str) -> str:
    """Return the batch ``custom_id`` for a comment: ``comment_<comment_id>``."""
    return "comment_{}".format(comment_id)
def parse_custom_id(custom_id: str) -> str:
    """Strip one leading ``comment_`` from ``custom_id``, if present."""
    prefix = "comment_"
    if custom_id.startswith(prefix):
        return custom_id[len(prefix):]
    return custom_id
def build_messages(comment: dict, forum: dict | None) -> tuple[list, bool]:
    """Assemble the system + user chat messages for a single comment.

    The comment's ``"text"`` is stripped; an empty body is replaced by a
    placeholder and a body longer than ``MAX_COMMENT_CHARS`` is cut to that
    length with a ``... [truncated]`` marker appended.  Missing regulation
    fields are rendered as ``[unknown]``.

    Returns:
        ``(messages, truncated)`` where ``truncated`` tells whether the body
        was cut.
    """
    regulation = forum or {}
    text = (comment.get("text") or "").strip()
    was_cut = bool(text) and len(text) > MAX_COMMENT_CHARS
    if was_cut:
        text = text[:MAX_COMMENT_CHARS] + "... [truncated]"
    elif not text:
        text = "[No body text provided]"
    prompt = USER_TEMPLATE.format(
        reg_title=regulation.get("reg_title", "[unknown]"),
        reg_desc=regulation.get("reg_desc", "[unknown]"),
        comment_id=comment.get("comment_id", ""),
        comment_title=comment.get("title", ""),
        comment_text=text,
    )
    system_msg = {"role": "system", "content": SYSTEM_PROMPT}
    user_msg = {"role": "user", "content": prompt}
    return [system_msg, user_msg], was_cut
def build_batch_request_line(comment: dict, forum: dict | None, model: str) -> dict:
    """Return the batch-input record (one JSONL line) for a single comment.

    The record targets ``POST /v1/chat/completions`` with JSON-object output
    and temperature 0, and is keyed by the comment's ``custom_id``.
    """
    chat_messages = build_messages(comment, forum)[0]
    payload = {
        "model": model,
        "messages": chat_messages,
        "response_format": {"type": "json_object"},
        "temperature": 0.0,
    }
    request = {
        "custom_id": custom_id_from(comment["comment_id"]),
        "method": "POST",
        "url": "/v1/chat/completions",
    }
    request["body"] = payload
    return request
def normalize_output_line(
    raw_line: dict,
    comment_lookup: dict,
    run_id: str,
    analyzed_at: str,
    model: str,
    prompt_version: str,
) -> dict:
    """Convert one raw batch output line into a normalized analysis record.

    Args:
        raw_line: One parsed line of the batch output (or error) file.
        comment_lookup: ``{comment_id: CommentItem dict}`` from the input file.
        run_id: Identifier of the run this record belongs to.
        analyzed_at: Timestamp string stored on the record.
        model: Model name stored on the record.
        prompt_version: Taken from the run manifest so it reflects what was
            submitted.

    Returns:
        A dict with the run/comment metadata, the five analysis fields
        (``stance``, ``stance_confidence``, ``stance_rationale``, ``tone``,
        ``tags``) and ``error``.  On any failure the analysis fields are
        ``None`` and ``error`` holds a message; on success ``error`` is
        ``None``.
    """
    fields = ("stance", "stance_confidence", "stance_rationale", "tone", "tags")
    comment_id = parse_custom_id(raw_line.get("custom_id", ""))
    comment = comment_lookup.get(comment_id, {})
    # build_messages() truncates the *stripped* body, so measure the same
    # string here; otherwise whitespace-padded comments are flagged as
    # truncated although their full text was sent.
    body = (comment.get("text") or "").strip()
    base = {
        "run_id": run_id,
        "forum_id": comment.get("forum_id", ""),
        "comment_id": comment_id,
        "analyzed_at": analyzed_at,
        "model": model,
        "prompt_version": prompt_version,
        "input_title": comment.get("title", ""),
        "truncated": len(body) > MAX_COMMENT_CHARS,
    }

    def failed(message: str) -> dict:
        # Same key order as a successful record, with every field blank.
        return {**base, **dict.fromkeys(fields), "error": message}

    # Check for outer-level batch error (e.g. batch_expired)
    err = raw_line.get("error")
    if err:
        return failed(err.get("message", str(err)) if isinstance(err, dict) else str(err))

    response = raw_line.get("response") or {}
    if response.get("status_code") != 200:
        return failed(f"status {response.get('status_code')}")

    try:
        content = response["body"]["choices"][0]["message"]["content"]
        data = json.loads(content)
        parsed = {k: data.get(k) for k in fields}
    except (KeyError, IndexError, TypeError, ValueError, AttributeError) as exc:
        # Malformed response shape, non-JSON content, or JSON that is not an object.
        return failed(str(exc))
    return {**base, **parsed, "error": None}
def make_manifest(
    run_id: str,
    input_filename: str,
    input_sha256: str,
    model: str,
    batch_id: str,
    records_submitted: int,
    request_filename: str,
) -> dict:
    """Create the manifest dict for a freshly submitted run.

    Result fields (completed/failed counts, output filenames, ``completed_at``)
    start as ``None``.  ``prompt_hash`` is the current ``PROMPT_VERSION`` and
    ``created_at`` is the current UTC time in ISO format.
    """
    manifest = {
        "run_id": run_id,
        "input_filename": input_filename,
        "input_sha256": input_sha256,
        "prompt_hash": PROMPT_VERSION,
        "model": model,
        "batch_id": batch_id,
        "records_submitted": records_submitted,
    }
    # Result fields are placeholders at submit time.
    manifest.update(
        records_completed=None,
        records_failed=None,
        request_filename=request_filename,
        raw_output_filename=None,
        normalized_output_filename=None,
    )
    manifest["created_at"] = datetime.now(timezone.utc).isoformat()
    manifest["completed_at"] = None
    return manifest
def _latest_run_id() -> str:
    """Return the run_id whose manifest file in ``RUNS_DIR`` was modified last.

    Exits the process with a message when no ``*.json`` manifest exists.
    """
    newest = None
    newest_mtime = float("-inf")
    if RUNS_DIR.exists():
        for candidate in RUNS_DIR.glob("*.json"):
            stamp = candidate.stat().st_mtime
            # Strict comparison keeps the first file on a timestamp tie.
            if newest is None or stamp > newest_mtime:
                newest, newest_mtime = candidate, stamp
    if newest is None:
        sys.exit(f"No runs found in {RUNS_DIR}. Submit a batch first.")
    return newest.stem
def load_manifest(run_id: str) -> dict:
    """Read and parse ``RUNS_DIR/<run_id>.json``."""
    manifest_file = RUNS_DIR / f"{run_id}.json"
    with open(manifest_file, encoding="utf-8") as handle:
        return json.load(handle)
def save_manifest(manifest: dict) -> None:
    """Write ``manifest`` to ``RUNS_DIR/<run_id>.json``, creating the directory if needed."""
    RUNS_DIR.mkdir(parents=True, exist_ok=True)
    target = RUNS_DIR / "{}.json".format(manifest["run_id"])
    payload = json.dumps(manifest, indent=2, ensure_ascii=False)
    target.write_text(payload, encoding="utf-8")
# ---------------------------------------------------------------------------
# Subcommand: submit
def _submit_chunk(
    chunk: list[dict],
    forum: dict | None,
    input_path: Path,
    input_sha256: str,
    model: str,
    client,
    chunk_index: int,
    total_chunks: int,
) -> str:
    """Write, upload and enqueue one chunk of comments as a batch job.

    Steps: generate a fresh UUID run_id, write the request JSONL under
    ``REQUESTS_DIR``, upload it with purpose ``batch``, create a 24h
    chat-completions batch, and save the run manifest.  Progress is printed
    to stderr.

    Returns:
        The run_id of the new run.
    """
    import uuid

    run_id = str(uuid.uuid4())
    if total_chunks > 1:
        tag = f"chunk {chunk_index + 1}/{total_chunks}"
    else:
        tag = "single batch"

    # 1. Request file on disk.
    REQUESTS_DIR.mkdir(parents=True, exist_ok=True)
    request_path = REQUESTS_DIR / f"{run_id}.jsonl"
    with open(request_path, "w", encoding="utf-8") as sink:
        sink.writelines(
            json.dumps(build_batch_request_line(item, forum, model), ensure_ascii=False) + "\n"
            for item in chunk
        )
    print(f"[{tag}] Wrote {len(chunk)} requests → {request_path}", file=sys.stderr)

    # 2. Upload.
    with open(request_path, "rb") as source:
        uploaded = client.files.create(file=source, purpose="batch")
    print(f"[{tag}] Uploaded: {uploaded.id}", file=sys.stderr)

    # 3. Create the batch job.
    batch_metadata = {"run_id": run_id, "input_filename": str(input_path)}
    batch = client.batches.create(
        input_file_id=uploaded.id,
        endpoint="/v1/chat/completions",
        completion_window="24h",
        metadata=batch_metadata,
    )
    print(f"[{tag}] Batch created: {batch.id} status={batch.status}", file=sys.stderr)

    # 4. Persist the manifest so status/download can find the batch.
    save_manifest(
        make_manifest(
            run_id=run_id,
            input_filename=str(input_path),
            input_sha256=input_sha256,
            model=model,
            batch_id=batch.id,
            records_submitted=len(chunk),
            request_filename=str(request_path),
        )
    )
    return run_id
def cmd_submit(args, client) -> None:
    """Handle ``submit``: load input, chunk by token budget, submit the first chunk.

    Reads ``args.prompt``, ``args.input``, ``args.model`` and ``args.limit``.
    Progress and follow-up instructions go to stderr; the run_id of the
    submitted batch is printed alone on stdout for scripting.

    Exits with a message when the prompt or input file is missing, when the
    input holds no comments, or when ``--limit`` is not a positive integer.
    """
    prompt_path = Path(args.prompt)
    if not prompt_path.is_file():
        # Same clean exit as for a missing input file, instead of a traceback.
        sys.exit(f"Prompt file not found: {prompt_path}")
    _load_prompt(prompt_path)
    print(f"Prompt: {args.prompt} (version {PROMPT_VERSION})", file=sys.stderr)
    input_path = Path(args.input)
    if not input_path.exists():
        sys.exit(f"File not found: {input_path}")
    print(f"Reading {input_path} ...", file=sys.stderr)
    forum, comments = load_items(input_path)
    if not comments:
        sys.exit("No comment items found in input file.")
    if forum is None:
        print("Warning: no ForumItem found — regulation context will be [unknown].", file=sys.stderr)
    if args.limit is not None:
        # A truthiness test would silently ignore 0, and a negative value
        # would slice from the end of the list; reject both explicitly.
        if args.limit <= 0:
            sys.exit("--limit must be a positive integer.")
        comments = comments[:args.limit]
        print(f"Limiting to {len(comments)} comments (--limit {args.limit}).", file=sys.stderr)
    token_limit = int(MODEL_LIMITS.get(args.model, _DEFAULT_TOKEN_LIMIT) * _LIMIT_BUFFER)
    chunks = chunk_comments_by_tokens(comments, forum, args.model)
    total = len(chunks)
    print(
        f"Model: {args.model} token limit: {token_limit:,} "
        f"{len(comments)} comments split into {total} chunk(s).",
        file=sys.stderr,
    )
    input_sha256 = hashlib.sha256(input_path.read_bytes()).hexdigest()
    # Submit only the first chunk — the enqueued token limit is a TOTAL across all
    # concurrent batches, so stacking multiple submissions will exceed the quota.
    # Wait for each batch to complete before submitting the next.
    run_id = _submit_chunk(chunks[0], forum, input_path, input_sha256, args.model, client, 0, total)
    print(f"\nBatch 1/{total} submitted.", file=sys.stderr)
    print(f" status: python analysis/gpt4o/analysis_batch.py status {run_id}", file=sys.stderr)
    print(f" download: python analysis/gpt4o/analysis_batch.py download {run_id}", file=sys.stderr)
    if total > 1:
        remaining = sum(len(c) for c in chunks[1:])
        print(f"\n{total - 1} more chunk(s) remaining ({remaining} comments).", file=sys.stderr)
        # NOTE(review): --limit only truncates from the front and this command
        # always submits chunks[0]; there is no offset option, so the ranges
        # below are informational — confirm how later chunks are meant to be run.
        print("After this batch completes and is downloaded, rerun submit with --limit to get the next chunk:", file=sys.stderr)
        offset = len(chunks[0])
        for idx, chunk in enumerate(chunks[1:], start=2):
            # Inclusive, zero-based index range of the comments in this chunk.
            print(f" chunk {idx}/{total}: comments {offset}-{offset + len(chunk) - 1}", file=sys.stderr)
            offset += len(chunk)
    print(run_id)  # stdout for scripting
# ---------------------------------------------------------------------------
# Subcommand: status
def cmd_status(args, client) -> None:
    """Handle ``status``: print the batch state and refresh the manifest counts.

    Uses ``args.run_id`` or, when omitted, the latest run.  Status and counts
    are printed to stdout; the manifest's ``records_completed`` and
    ``records_failed`` are updated when the API reports request counts.
    """
    run_id = args.run_id or _latest_run_id()
    if not args.run_id:
        print(f"(using latest run: {run_id})", file=sys.stderr)
    manifest = load_manifest(run_id)
    batch = client.batches.retrieve(manifest["batch_id"])
    counts = batch.request_counts
    print(f"status: {batch.status}")
    if counts is None:
        # request_counts is optional on the batch object; without it there is
        # nothing to report, so keep whatever the manifest already holds.
        print("completed: n/a")
        print("failed: n/a")
    else:
        print(f"completed: {counts.completed}/{counts.total}")
        print(f"failed: {counts.failed}")
        manifest["records_completed"] = counts.completed
        manifest["records_failed"] = counts.failed
    save_manifest(manifest)
    if batch.status == "completed":
        print("\nReady to download. Run:")
        print(f" python analysis/gpt4o/analysis_batch.py download {run_id}")
# ---------------------------------------------------------------------------
# Subcommand: download
def cmd_download(args, client) -> None:
    """Handle ``download``: fetch batch results, normalize them, update the manifest.

    Uses ``args.run_id`` or, when omitted, the latest run.  Exits when the
    batch is not ``completed`` or exposes neither an output nor an error file.
    The raw text is saved under ``RAW_DIR``; one normalized record per raw line
    is written next to this script as ``<run_id>_<model>.jsonl``.
    """
    run_id = args.run_id or _latest_run_id()
    if not args.run_id:
        print(f"(using latest run: {run_id})", file=sys.stderr)
    manifest = load_manifest(run_id)
    batch = client.batches.retrieve(manifest["batch_id"])
    if batch.status != "completed":
        sys.exit(f"Batch not complete yet (status={batch.status}). Run 'status' to check.")
    run_id = manifest["run_id"]
    model = manifest["model"]
    model_slug = model.replace("/", "-")

    # Download raw output.  Requests that failed individually are delivered in
    # a separate error file, and output_file_id is unset when every request
    # failed, so collect whichever of the two files exist.
    file_ids = [
        fid
        for fid in (batch.output_file_id, getattr(batch, "error_file_id", None))
        if fid
    ]
    if not file_ids:
        sys.exit(f"Batch {manifest['batch_id']} completed but has no output or error file.")
    raw_text = ""
    for fid in file_ids:
        part = client.files.content(fid).text
        # Keep each file's text verbatim; only guarantee a line break between files.
        if raw_text and not raw_text.endswith("\n"):
            raw_text += "\n"
        raw_text += part
    RAW_DIR.mkdir(parents=True, exist_ok=True)
    raw_path = RAW_DIR / f"{run_id}.jsonl"
    raw_path.write_text(raw_text, encoding="utf-8")
    print(f"Raw output → {raw_path}", file=sys.stderr)

    # Build comment lookup from original input for reconciliation.  Keys are
    # stringified because the comment_id parsed back out of custom_id is
    # always a string, even if the input file stores ids as numbers.
    input_path = Path(manifest["input_filename"])
    _, comments = load_items(input_path)
    comment_lookup = {str(c["comment_id"]): c for c in comments}

    # Normalize
    completed_at = datetime.now(timezone.utc).isoformat()
    if batch.completed_at:
        completed_at = datetime.fromtimestamp(batch.completed_at, tz=timezone.utc).isoformat()
    normalized_path = _SCRIPT_DIR / f"{run_id}_{model_slug}.jsonl"
    n_ok = n_err = 0
    with open(normalized_path, "w", encoding="utf-8") as out:
        # Split on "\n" only: str.splitlines() also breaks on characters such
        # as U+2028, which may legally appear unescaped inside a JSON string.
        for line in raw_text.split("\n"):
            if not line.strip():
                continue
            raw_line = json.loads(line)
            record = normalize_output_line(
                raw_line, comment_lookup, run_id, completed_at, model, manifest["prompt_hash"]
            )
            out.write(json.dumps(record, ensure_ascii=False) + "\n")
            if record["error"]:
                n_err += 1
            else:
                n_ok += 1
    print(f"Normalized → {normalized_path} ({n_ok} ok, {n_err} errors)", file=sys.stderr)

    manifest["records_completed"] = n_ok
    manifest["records_failed"] = n_err
    manifest["raw_output_filename"] = str(raw_path)
    manifest["normalized_output_filename"] = str(normalized_path)
    manifest["completed_at"] = completed_at
    save_manifest(manifest)
    print(f"Manifest updated → {RUNS_DIR / run_id}.json", file=sys.stderr)
# ---------------------------------------------------------------------------
# CLI
def main() -> None:
    """CLI entry point: parse arguments, build the OpenAI client, dispatch.

    Arguments are parsed before the API key is checked so that ``--help`` and
    usage errors work without credentials.  Exits with a message when
    ``OPENAI_API_KEY`` is not available after loading ``.env``.
    """
    parser = argparse.ArgumentParser(
        description="Public comment batch analysis pipeline.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    sub = parser.add_subparsers(dest="command", required=True)

    p_submit = sub.add_parser("submit", help="Build and submit a batch job")
    p_submit.add_argument("input", help="Path to scraped JSONL file")
    p_submit.add_argument("--model", default="gpt-4o", help="OpenAI model (default: gpt-4o)")
    p_submit.add_argument(
        "--prompt",
        default=str(_DEFAULT_PROMPT_FILE),
        help="Path to system prompt file (default: analysis/prompt-1.txt)",
    )
    p_submit.add_argument(
        "--limit", type=int, default=None, metavar="N",
        help="Submit only the first N comments (useful for staying under token quota)",
    )

    p_status = sub.add_parser("status", help="Check batch status")
    p_status.add_argument("run_id", nargs="?", default=None,
                          help="run_id from submit (default: most recent run)")

    p_download = sub.add_parser("download", help="Download and normalize completed batch")
    p_download.add_argument("run_id", nargs="?", default=None,
                            help="run_id from submit (default: most recent run)")

    args = parser.parse_args()

    # Credentials are only needed once a real command is about to run.
    load_dotenv()
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        sys.exit("OPENAI_API_KEY not set. Create a .env file or export the variable.")
    client = openai.OpenAI(api_key=api_key)

    handlers = {
        "submit": cmd_submit,
        "status": cmd_status,
        "download": cmd_download,
    }
    handlers[args.command](args, client)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()