# AI-PDF-Tool / main.py
# (repository page residue: "moazx's picture", commit message "update", revision 443e99e)
import os
import json
import signal
import sys
from pathlib import Path
from typing import List, Dict, Tuple, Optional, Sequence, Set, Any
from multiprocessing import Pool, cpu_count
from functools import partial
import fitz # PyMuPDF (Still needed for drawing output PDF)
import pypdfium2 as pdfium
import torch
from doclayout_yolo import YOLOv10
from huggingface_hub import hf_hub_download
from loguru import logger
from PIL import Image
import numpy as np
try:
import pymupdf4llm # type: ignore
except ImportError: # pragma: no cover - optional dependency
pymupdf4llm = None # type: ignore
# ----------------------------------------------------------------------
# CONFIGURATION
# ----------------------------------------------------------------------
# Device string handed to model.predict(); CUDA is chosen whenever torch reports a GPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Model options
# MODEL_SIZE is used both as the `imgsz` argument of predict() and as part of
# the weights filename downloaded from the Hugging Face Hub.
MODEL_SIZE = 1024
REPO_ID = "juliozhao/DocLayout-YOLO-DocStructBench"
WEIGHTS_FILE = f"doclayout_yolo_docstructbench_imgsz{MODEL_SIZE}.pt"
# Detection settings
# Passed to predict() as `conf`.
CONF_THRESHOLD = 0.25
# Multiprocessing settings
NUM_WORKERS = None # None = auto (cpu_count - 1), or set to specific number like 4
USE_MULTIPROCESSING = True # Set to False to disable parallel processing entirely
# ----------------------------------------------------------------------
# Color map for the layout classes
# ----------------------------------------------------------------------
# RGB triples in the 0-255 range; draw_layout_pdf() divides them by 255 and
# falls back to black for any class name that has no entry here.
# NOTE(review): the keys must match the names reported by the model
# (results[0].names) for the colors to apply - verify against the weights in use.
# "list" and "table_footnote" currently share the same purple.
CLASS_COLORS = {
    "text": (0, 128, 0), # Dark Green
    "title": (192, 0, 0), # Dark Red
    "figure": (0, 0, 192), # Dark Blue
    "table": (218, 165, 32), # Goldenrod (Dark Yellow)
    "list": (128, 0, 128), # Purple
    "header": (0, 128, 128), # Teal
    "footer": (100, 100, 100), # Dark Gray
    "figure_caption": (0, 0, 128), # Navy
    "table_caption": (139, 69, 19), # Saddle Brown
    "table_footnote": (128, 0, 128), # Purple
}
# Global model instance (will be None in worker processes until loaded)
_model = None
# Set to True by signal_handler() on the first interrupt; polled by the
# processing functions to stop early.
_shutdown_requested = False
# ----------------------------------------------------------------------
# Signal handler for graceful shutdown
# ----------------------------------------------------------------------
def signal_handler(signum, frame):
    """React to an interrupt signal.

    The first signal only raises the module-level ``_shutdown_requested``
    flag so that running work can stop at the next checkpoint; a second
    signal exits the process immediately with status 1.
    """
    global _shutdown_requested
    if _shutdown_requested:
        # Second interrupt: the user does not want to wait any longer.
        logger.error("\n❌ Force quit requested. Exiting immediately.")
        sys.exit(1)
    _shutdown_requested = True
    for notice in (
        "\n⚠️ Interrupt received! Finishing current page and shutting down gracefully...",
        "Press Ctrl+C again to force quit (may leave incomplete files)",
    ):
        logger.warning(notice)
def setup_signal_handlers():
    """Route SIGINT and SIGTERM to ``signal_handler``."""
    for signo in (signal.SIGINT, signal.SIGTERM):
        signal.signal(signo, signal_handler)
# ----------------------------------------------------------------------
# Model loader function
# ----------------------------------------------------------------------
def get_model():
    """Return the process-wide YOLO model, loading it on first use.

    The weights file is fetched with ``hf_hub_download`` and the resulting
    model is cached in the module-level ``_model`` so each process loads it
    at most once.
    """
    global _model
    if _model is None:
        weights_path = hf_hub_download(repo_id=REPO_ID, filename=WEIGHTS_FILE)
        _model = YOLOv10(weights_path)
        # This function is also called from the main process in serial mode,
        # so the message does not claim to run in a worker.
        logger.info(f"✓ Model loaded in process (PID: {os.getpid()})")
    return _model
# ----------------------------------------------------------------------
# Worker initialization function
# ----------------------------------------------------------------------
def init_worker():
    """Initialize a pool worker: ignore Ctrl+C and load the model once.

    Raises:
        Exception: whatever ``get_model`` raised, after logging it.
    """
    # Ctrl+C is delivered to every process in the foreground group. Shutdown is
    # coordinated by the parent's signal handler, so workers ignore SIGINT and
    # are not killed in the middle of a page. SIGTERM keeps its default action
    # so the parent can still terminate the pool.
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    try:
        get_model()
        logger.success(f"Worker {os.getpid()} ready")
    except Exception as e:
        logger.error(f"Failed to initialize worker {os.getpid()}: {e}")
        raise
# ----------------------------------------------------------------------
# Run layout detection on a single page image (YOLO)
# ----------------------------------------------------------------------
def detect_page(pil_img: Image.Image) -> List[dict]:
    """Detect layout elements on one rendered page with the YOLO model.

    Args:
        pil_img: Rendered page image.

    Returns:
        One dict per detected box with the keys ``name`` (class name),
        ``bbox`` (``[x0, y0, x1, y1]`` as returned by the model's ``xyxy``),
        ``conf``, ``source`` (always ``"yolo"``) and ``index`` (position of
        the box in the model output).
    """
    model = get_model()  # Will return already-loaded model in worker
    # The predictor treats NumPy input as BGR (OpenCV convention), while PIL
    # yields RGB. Normalise the mode first (covers grayscale/RGBA/CMYK pages)
    # and flip the channel axis so colours are not swapped.
    rgb = np.asarray(pil_img.convert("RGB"))
    img_cv = np.ascontiguousarray(rgb[:, :, ::-1])
    results = model.predict(
        img_cv,
        imgsz=MODEL_SIZE,
        conf=CONF_THRESHOLD,
        device=DEVICE,
        verbose=False,
    )
    page_result = results[0]
    class_names = page_result.names
    dets = []
    for i, box in enumerate(page_result.boxes):
        x0, y0, x1, y1 = box.xyxy[0].cpu().numpy().tolist()
        dets.append({
            "name": class_names[int(box.cls.item())],
            "bbox": [x0, y0, x1, y1],
            "conf": float(box.conf.item()),
            "source": "yolo",
            "index": i,
        })
    return dets
# ----------------------------------------------------------------------
# Crop & save figure/table regions (with captions)
# ----------------------------------------------------------------------
def get_union_box(box1: List[float], box2: List[float]) -> List[float]:
    """Return the smallest ``[x0, y0, x1, y1]`` box containing both inputs."""
    left, top = (min(box1[axis], box2[axis]) for axis in (0, 1))
    right, bottom = (max(box1[axis], box2[axis]) for axis in (2, 3))
    return [left, top, right, bottom]
def collect_caption_elements(
    element: Dict,
    all_dets: List[Dict],
    target_name: str,
    max_vertical_gap: float = 60.0,
    min_overlap: float = 0.25,
) -> List[Dict]:
    """
    Gather a chain of caption detections located below a figure/table.

    Only detections named ``target_name`` whose top edge is at most 5 units
    above the element's bottom edge are considered, ordered top to bottom.
    The first caption is matched against the element itself; each further
    caption is matched against the previously accepted caption. A candidate
    is accepted when its horizontal overlap ratio with that reference is at
    least ``min_overlap``.

    NOTE(review): ``max_vertical_gap`` is only enforced between consecutive
    captions, not between the element and the first caption - confirm that
    an arbitrarily distant first caption is intended.
    """
    anchor_box = element["bbox"]
    lowest_allowed_top = anchor_box[3] - 5
    below = sorted(
        (
            det
            for det in all_dets
            if det["name"] == target_name and det["bbox"][1] >= lowest_allowed_top
        ),
        key=lambda det: det["bbox"][1],
    )
    chain: List[Dict] = []
    for det in below:
        box = det["bbox"]
        if chain:
            previous_box = chain[-1]["bbox"]
            # Candidates are sorted, so the first oversized gap ends the chain.
            if box[1] - previous_box[3] > max_vertical_gap:
                break
            reference_box = previous_box
        else:
            reference_box = anchor_box
        if _horizontal_overlap_ratio(reference_box, box) < min_overlap:
            continue
        chain.append(det)
    return chain
def collect_title_and_text_segments(
    element: Dict,
    all_dets: List[Dict],
    processed_indices: Set[int],
    settings: Optional[Dict[str, float]] = None,
) -> Tuple[List[Dict], List[Dict]]:
    """
    Find the first title beneath ``element`` plus the text blocks under that title.

    Detections already listed in ``processed_indices`` (or lacking a bbox) are
    ignored. Scanning top to bottom, the first ``title`` that starts no more
    than 5 units above the element's bottom edge, lies within
    ``settings["max_title_gap"]`` of it and overlaps it horizontally by at
    least ``settings["min_overlap"]`` is selected. Following ``text``
    detections are then collected until another title appears or the vertical
    gap exceeds ``settings["max_text_gap"]``.

    Returns:
        ``(titles, texts)``; ``titles`` holds at most one detection. Both are
        empty when no title qualifies or the element has no bbox.
    """
    cfg = TITLE_TEXT_ASSOCIATION if settings is None else settings
    if not element.get("bbox"):
        return [], []
    anchor_box = element["bbox"]
    anchor_bottom = anchor_box[3]
    ordered = sorted(
        (
            det
            for det in all_dets
            if det.get("bbox") and det["index"] not in processed_indices
        ),
        key=lambda det: det["bbox"][1],
    )
    for position, det in enumerate(ordered):
        if det["name"] != "title":
            continue
        title_box = det["bbox"]
        if title_box[1] < anchor_bottom - 5:
            continue
        # Sorted by top edge: once one title is too far, all later ones are too.
        if title_box[1] - anchor_bottom > cfg["max_title_gap"]:
            break
        if _horizontal_overlap_ratio(anchor_box, title_box) < cfg["min_overlap"]:
            continue

        body: List[Dict] = []
        cursor = title_box[3]
        for nxt in ordered[position + 1 :]:
            kind = nxt["name"]
            if kind == "title":
                break
            if kind != "text":
                continue
            nxt_box = nxt["bbox"]
            if nxt_box[1] < title_box[1]:
                continue
            if nxt_box[1] - cursor > cfg["max_text_gap"]:
                break
            if _horizontal_overlap_ratio(title_box, nxt_box) < cfg["min_overlap"]:
                continue
            body.append(nxt)
            cursor = nxt_box[3]
        # Only the first qualifying title is used.
        return [det], body
    return [], []
def _segment_records(segments: List[Dict], page_number: int) -> List[Dict]:
    """Serialize attached caption/title/text detections for the JSON output."""
    return [
        {
            "bbox": seg["bbox"],
            "conf": seg.get("conf"),
            "index": seg["index"],
            "source": seg.get("source"),
            "page": page_number,
        }
        for seg in segments
    ]


def save_layout_elements(pil_img: Image.Image, page_num: int,
                         dets: List[dict], out_dir: Path) -> List[dict]:
    """Save figure and table crops, merging captions.

    For every ``figure``/``table`` detection the bounding box is grown to
    include its captions (and, for figures, a title with text underneath),
    the region is cropped from ``pil_img`` and written as PNG under
    ``out_dir/figures`` or ``out_dir/tables``.

    Args:
        pil_img: Rendered page image the detections refer to.
        page_num: Zero-based page index; output uses ``page_num + 1``.
        dets: Detections of this page (dicts with ``name``, ``bbox``,
            ``conf``, ``index`` and optionally ``source``).
        out_dir: Output directory of the current PDF.

    Returns:
        One metadata dict per saved crop.
    """
    fig_dir = out_dir / "figures"
    tab_dir = out_dir / "tables"
    fig_dir.mkdir(parents=True, exist_ok=True)
    tab_dir.mkdir(parents=True, exist_ok=True)
    page_number = page_num + 1
    infos = []
    fig_count = 0
    tab_count = 0
    # Indices of detections already merged into an earlier element.
    processed_indices: Set[int] = set()
    for d in dets:
        if d["index"] in processed_indices:
            continue
        elem_type = d["name"].lower()
        if elem_type not in ("figure", "table"):
            continue
        final_box = d["bbox"]
        title_segments: List[Dict] = []
        text_segments: List[Dict] = []
        if elem_type == "figure":
            path_template = fig_dir / f"page_{page_number}_fig_{fig_count}.png"
            fig_count += 1
            caption_segments = collect_caption_elements(d, dets, "figure_caption")
        else:
            path_template = tab_dir / f"page_{page_number}_tab_{tab_count}.png"
            tab_count += 1
            caption_segments = collect_caption_elements(d, dets, "table_caption")
        for cap in caption_segments:
            final_box = get_union_box(final_box, cap["bbox"])
            processed_indices.add(cap["index"])
        if elem_type == "figure":
            # Captions are claimed first so they cannot be picked up again here.
            title_segments, text_segments = collect_title_and_text_segments(
                d, dets, processed_indices
            )
            for seg in title_segments + text_segments:
                final_box = get_union_box(final_box, seg["bbox"])
                processed_indices.add(seg["index"])

        x0, y0, x1, y1 = map(int, final_box)
        if x1 <= x0 or y1 <= y0:
            # A zero-area crop cannot be written as PNG; skip it instead of
            # failing the whole page.
            logger.warning(
                f"Skipping degenerate {elem_type} box on page {page_number}: {final_box}"
            )
            continue
        crop = pil_img.crop((x0, y0, x1, y1))
        if crop.mode == "CMYK":
            crop = crop.convert("RGB")
        crop.save(path_template)

        info_data = {
            "type": elem_type,
            "page": page_number,
            "bbox_pixels": final_box,
            "conf": d["conf"],
            "source": d.get("source", "yolo"),
            "image_path": str(path_template.relative_to(out_dir)),
            "width": int(x1 - x0),
            "height": int(y1 - y0),
            "page_width": pil_img.width,
            "page_height": pil_img.height,
        }
        # Optional keys are only present when something was attached.
        for key, segments in (
            ("captions", caption_segments),
            ("titles", title_segments),
            ("texts", text_segments),
        ):
            if segments:
                info_data[key] = _segment_records(segments, page_number)
        infos.append(info_data)
    return infos
# Tolerances used by merge_spanning_tables() when comparing a table with the
# tables of the following page. Values are compared against `bbox_pixels`
# coordinates, i.e. pixels of the rendered page image.
TABLE_STITCH_TOLERANCES = {
    "x_tol": 60,  # max difference of the left edges (vertical stitch)
    "y_tol": 60,  # max difference of the top edges (horizontal stitch)
    "width_tol": 120,  # max difference of the right edges (x + w)
    "height_tol": 120,  # max difference of the bottom edges (y + h)
}
# Limits used by attach_cross_page_figure_captions() to decide whether a
# caption/title at the top of the next page belongs to a figure.
CROSS_PAGE_CAPTION_THRESHOLDS = {
    "max_top_ratio": 0.35,  # fraction of the next page's height ...
    "max_top_pixels": 220,  # ... and absolute cap; the smaller of the two applies
    "x_tol": 120,  # left-edge tolerance, checked when horizontal overlap is too small
    "width_tol": 200,  # width tolerance, checked when horizontal overlap is too small
    "min_overlap": 0.05,  # horizontal overlap ratio below which the tolerances apply
}
# Defaults for associating a title and its text with a figure
# (collect_title_and_text_segments / _collect_text_under_title_cross_page).
TITLE_TEXT_ASSOCIATION = {
    "max_title_gap": 220,  # max vertical distance figure bottom -> title top
    "max_text_gap": 160,  # max vertical distance between consecutive blocks
    "min_overlap": 0.2,  # min horizontal overlap ratio
}
def _horizontal_overlap_ratio(box1: List[float], box2: List[float]) -> float:
"""Compute horizontal overlap ratio between two bounding boxes."""
x_left = max(box1[0], box2[0])
x_right = min(box1[2], box2[2])
overlap = max(0.0, x_right - x_left)
if overlap <= 0:
return 0.0
width_union = max(box1[2], box2[2]) - min(box1[0], box2[0])
if width_union <= 0:
return 0.0
return overlap / width_union
def _bbox_to_rect(bbox: List[float]) -> Tuple[int, int, int, int]:
"""Convert [x0, y0, x1, y1] into (x, y, w, h)."""
x0, y0, x1, y1 = bbox
return int(x0), int(y0), int(x1 - x0), int(y1 - y0)
def _open_table_image(elem: Dict, out_dir: Path) -> Optional[Image.Image]:
    """Load a table crop (path relative to ``out_dir``) as an RGB image.

    Returns:
        A fully loaded RGB copy of the image, or ``None`` (with a warning)
        when the file does not exist.
    """
    image_path = out_dir / elem["image_path"]
    if not image_path.exists():
        logger.warning(f"Missing table crop for stitching: {image_path}")
        return None
    # Image.open() is lazy and keeps the file open. convert() returns an
    # independent in-memory image, so the handle can be released right away
    # (the caller deletes the file after stitching).
    with Image.open(image_path) as img:
        return img.convert("RGB")
def _pad_width(img: Image.Image, target_width: int) -> Image.Image:
    """Extend ``img`` to ``target_width`` with white on the right.

    The input is returned unchanged when it is already wide enough.
    """
    missing = target_width - img.width
    if missing <= 0:
        return img
    padded = Image.new("RGB", (target_width, img.height), color=(255, 255, 255))
    padded.paste(img, (0, 0))
    return padded
def _pad_height(img: Image.Image, target_height: int) -> Image.Image:
    """Extend ``img`` to ``target_height`` with white at the bottom.

    The input is returned unchanged when it is already tall enough.
    """
    missing = target_height - img.height
    if missing <= 0:
        return img
    padded = Image.new("RGB", (img.width, target_height), color=(255, 255, 255))
    padded.paste(img, (0, 0))
    return padded
def _append_segment_image(
    base_img: Image.Image,
    segment_img: Image.Image,
    resize_to_base: bool = False,
) -> Image.Image:
    """Stack ``segment_img`` underneath ``base_img`` on a white canvas.

    Both images are converted to RGB if needed. With ``resize_to_base`` the
    segment is rescaled (aspect ratio kept) to the width of the base image;
    otherwise the narrower image is padded with white on the right.
    """
    upper = base_img if base_img.mode == "RGB" else base_img.convert("RGB")
    lower = segment_img if segment_img.mode == "RGB" else segment_img.convert("RGB")

    if resize_to_base and lower.width > 0 and upper.width > 0:
        scaled_size = (
            upper.width,
            max(1, int(lower.height * (upper.width / lower.width))),
        )
        lower = lower.resize(scaled_size, Image.Resampling.LANCZOS)

    canvas_width = max(upper.width, lower.width)
    upper = _pad_width(upper, canvas_width)
    lower = _pad_width(lower, canvas_width)

    canvas = Image.new(
        "RGB",
        (canvas_width, upper.height + lower.height),
        color=(255, 255, 255),
    )
    for piece, y_offset in ((upper, 0), (lower, upper.height)):
        canvas.paste(piece, (0, y_offset))
    return canvas
def _render_pdf_page(
    pdf_doc: pdfium.PdfDocument,
    page_index: int,
    scale: float,
    cache: Dict[int, Image.Image],
) -> Optional[Image.Image]:
    """Render a PDF page to a PIL image with caching.

    Args:
        pdf_doc: Open pdfium document.
        page_index: Zero-based page index.
        scale: Render scale passed to ``page.render``.
        cache: Mapping of page index to rendered image; updated in place.

    Returns:
        The rendered image, or ``None`` if rendering failed (the error is logged).
    """
    if page_index in cache:
        return cache[page_index]
    try:
        page = pdf_doc[page_index]
        try:
            pil_img = page.render(scale=scale).to_pil()
        finally:
            # Release the page even when rendering raises.
            page.close()
    except Exception as exc:
        logger.error(f"Failed to render page {page_index + 1} for caption stitching: {exc}")
        return None
    cache[page_index] = pil_img
    return pil_img
def _crop_pdf_region(
    page_img: Optional[Image.Image], bbox: List[float]
) -> Optional[Image.Image]:
    """Cut ``bbox`` out of a rendered page, clamped to the page bounds.

    Returns ``None`` when no page image is given or when the clamped box is
    empty. CMYK crops are converted to RGB.
    """
    if page_img is None:
        return None
    left, top, right, bottom = (int(value) for value in bbox)
    left = max(0, left)
    top = max(0, top)
    # Force at least one pixel in each direction before clamping to the page.
    right = min(page_img.width, max(left + 1, right))
    bottom = min(page_img.height, max(top + 1, bottom))
    if right <= left or bottom <= top:
        return None
    region = page_img.crop((left, top, right, bottom))
    return region.convert("RGB") if region.mode == "CMYK" else region
def write_markdown_document(pdf_path: Path, out_dir: Path) -> Optional[Path]:
    """
    Extract markdown text from a PDF using PyMuPDF4LLM and write it to disk.

    Args:
        pdf_path: PDF to convert.
        out_dir: Directory receiving ``<pdf stem>.md``.

    Returns:
        Path of the written markdown file, or ``None`` when pymupdf4llm is not
        installed, conversion fails, or no text was extracted.
    """
    if pymupdf4llm is None:
        # loguru formats messages with str.format-style braces, not "%s",
        # so the file name is interpolated directly.
        logger.warning(
            f"Skipping markdown extraction for {pdf_path.name} "
            "because pymupdf4llm is not installed."
        )
        return None
    try:
        markdown_content = pymupdf4llm.to_markdown(str(pdf_path))
    except Exception as exc:
        logger.error(f" Failed to create markdown for {pdf_path.name}: {exc}")
        return None
    if isinstance(markdown_content, list):
        # Keep only the string parts of a list result.
        markdown_content = "\n\n".join(
            part for part in markdown_content if isinstance(part, str)
        )
    if not isinstance(markdown_content, str):
        logger.error(
            f" Unexpected markdown output type {type(markdown_content)} for {pdf_path.name}"
        )
        return None
    markdown_content = markdown_content.strip()
    if not markdown_content:
        logger.warning(f" No textual content extracted from {pdf_path.name}")
        return None
    # strip() removed any trailing newline; end the file with exactly one.
    markdown_content += "\n"
    md_path = out_dir / f"{pdf_path.stem}.md"
    md_path.write_text(markdown_content, encoding="utf-8")
    logger.info(f" Saved markdown to {md_path.name}")
    return md_path
def _collect_text_under_title_cross_page(
title_det: Dict,
sorted_dets: List[Dict],
start_idx: int,
page_idx: int,
used_indices: Set[Tuple[int, int]],
settings: Optional[Dict[str, float]] = None,
) -> List[Dict]:
"""Collect text elements directly below a title on the next page."""
if settings is None:
settings = TITLE_TEXT_ASSOCIATION
texts: List[Dict] = []
title_box = title_det["bbox"]
last_bottom = title_box[3]
for follower in sorted_dets[start_idx + 1 :]:
det_index = follower.get("index")
if det_index is None or (page_idx, det_index) in used_indices:
continue
if follower["name"] == "title":
break
if follower["name"] != "text":
continue
text_box = follower["bbox"]
if text_box[1] < title_box[1]:
continue
gap = text_box[1] - last_bottom
if gap > settings["max_text_gap"]:
break
if _horizontal_overlap_ratio(title_box, text_box) < settings["min_overlap"]:
continue
texts.append(follower)
last_bottom = text_box[3]
return texts
def attach_cross_page_figure_captions(
    elements: List[Dict],
    all_dets: Sequence[Optional[List[Dict[str, Any]]]],
    pdf_bytes: bytes,
    out_dir: Path,
    scale: float,
) -> List[Dict]:
    """
    If a figure caption appears on the next page, stitch it to the prior figure.

    For every figure element the detections of the following page are
    searched, near the top of that page, for an unused ``figure_caption``
    and/or ``title`` (plus the text below the title). Matching regions are
    cropped from the rendered next page, appended beneath the saved figure
    image (which is overwritten), and recorded on the element under
    ``captions``/``titles``/``texts`` together with an updated ``page_span``.

    Args:
        elements: Element dicts as produced by ``save_layout_elements``;
            figure entries are modified in place.
        all_dets: Per-page detection lists indexed by zero-based page number;
            entries may be ``None``.
        pdf_bytes: Raw PDF content, reopened for rendering.
        out_dir: Output directory that ``image_path`` values are relative to.
        scale: Render scale; should match the scale the detections were made at.

    Returns:
        The same ``elements`` list.
    """
    figures = [elem for elem in elements if elem.get("type") == "figure"]
    if not figures or not all_dets:
        return elements
    try:
        pdf_doc = pdfium.PdfDocument(pdf_bytes)
    except Exception as exc:
        logger.error(f"Unable to reopen PDF for figure caption stitching: {exc}")
        return elements

    def _record(det: Dict, det_index: Optional[int], page_number: int) -> Dict:
        """Build the JSON record stored for an attached detection."""
        return {
            "bbox": det["bbox"],
            "conf": det.get("conf"),
            "index": det_index,
            "source": det.get("source"),
            "page": page_number,
        }

    thresholds = CROSS_PAGE_CAPTION_THRESHOLDS
    page_cache: Dict[int, Image.Image] = {}
    # (zero-based page index, detection index) pairs already owned by a figure.
    used_following_ids: Set[Tuple[int, int]] = set()
    # Mark existing caption/title/text detections as used
    for elem in figures:
        for key in ("captions", "titles", "texts"):
            for seg in elem.get(key, []) or []:
                idx = seg.get("index")
                page_no = seg.get("page")
                if idx is None or page_no is None:
                    continue
                used_following_ids.add((page_no - 1, idx))

    try:
        for elem in figures:
            page_no = elem.get("page")
            bbox = elem.get("bbox_pixels")
            if page_no is None or bbox is None:
                continue
            # `page` is 1-based, so it equals the 0-based index of the next page.
            next_idx = (page_no - 1) + 1
            next_page_number = next_idx + 1
            if next_idx >= len(all_dets):
                continue
            next_dets = all_dets[next_idx]
            if not next_dets:
                continue
            page_img = _render_pdf_page(pdf_doc, next_idx, scale, page_cache)
            if page_img is None:
                continue
            # Only detections starting near the top of the next page qualify.
            max_top_allowed = min(
                thresholds["max_top_pixels"],
                int(page_img.height * thresholds["max_top_ratio"]),
            )
            sorted_next = sorted(
                [det for det in next_dets if det.get("bbox")],
                key=lambda det: det["bbox"][1],
            )

            # --- caption: best-scoring unused figure_caption ---------------
            caption_candidates = []
            for det in sorted_next:
                if det.get("name") != "figure_caption":
                    continue
                det_index = det.get("index")
                if det_index is None or (next_idx, det_index) in used_following_ids:
                    continue
                det_bbox = det.get("bbox")
                if not det_bbox or det_bbox[1] > max_top_allowed:
                    continue
                overlap = _horizontal_overlap_ratio(bbox, det_bbox)
                x_diff = abs(bbox[0] - det_bbox[0])
                width_diff = abs((bbox[2] - bbox[0]) - (det_bbox[2] - det_bbox[0]))
                # With too little overlap, fall back to position/width tolerances.
                if overlap < thresholds["min_overlap"] and (
                    x_diff > thresholds["x_tol"]
                    or width_diff > thresholds["width_tol"]
                ):
                    continue
                score = width_diff + 0.5 * x_diff
                caption_candidates.append((score, det, det_index))
            caption_candidate: Optional[Tuple[Dict, int]] = None
            if caption_candidates:
                # min() returns the first of equal scores, i.e. the topmost one.
                _, best_det, best_index = min(
                    caption_candidates, key=lambda item: item[0]
                )
                caption_candidate = (best_det, best_index)

            # --- title: first unused title near the top, plus its text -----
            title_candidate: Optional[Tuple[Dict, int]] = None
            title_texts: List[Dict] = []
            for idx_sorted, det in enumerate(sorted_next):
                if det.get("name") != "title":
                    continue
                det_index = det.get("index")
                if det_index is None or (next_idx, det_index) in used_following_ids:
                    continue
                det_bbox = det.get("bbox")
                if not det_bbox or det_bbox[1] > max_top_allowed:
                    continue
                overlap = _horizontal_overlap_ratio(bbox, det_bbox)
                x_diff = abs(bbox[0] - det_bbox[0])
                if (
                    overlap < TITLE_TEXT_ASSOCIATION["min_overlap"]
                    and x_diff > thresholds["x_tol"]
                ):
                    continue
                title_candidate = (det, det_index)
                title_texts = _collect_text_under_title_cross_page(
                    det, sorted_next, idx_sorted, next_idx, used_following_ids
                )
                break

            # title_texts can only be non-empty together with a title candidate.
            if caption_candidate is None and title_candidate is None:
                continue
            figure_path = out_dir / elem["image_path"]
            if not figure_path.exists():
                continue
            # Load eagerly and release the file handle; the file is rewritten
            # below. _append_segment_image works on RGB anyway.
            with Image.open(figure_path) as opened:
                figure_img = opened.convert("RGB")

            segments_added = False
            if caption_candidate:
                cap_det, cap_index = caption_candidate
                caption_crop = _crop_pdf_region(page_img, cap_det["bbox"])
                if caption_crop is not None:
                    figure_img = _append_segment_image(
                        figure_img, caption_crop, resize_to_base=True
                    )
                    elem.setdefault("captions", []).append(
                        _record(cap_det, cap_index, next_page_number)
                    )
                    used_following_ids.add((next_idx, cap_index))
                    segments_added = True
            if title_candidate:
                title_det, title_index = title_candidate
                title_crop = _crop_pdf_region(page_img, title_det["bbox"])
                if title_crop is not None:
                    figure_img = _append_segment_image(figure_img, title_crop)
                    elem.setdefault("titles", []).append(
                        _record(title_det, title_index, next_page_number)
                    )
                    used_following_ids.add((next_idx, title_index))
                    segments_added = True
            for text_det in title_texts:
                text_index = text_det.get("index")
                text_crop = _crop_pdf_region(page_img, text_det["bbox"])
                if text_crop is None:
                    continue
                figure_img = _append_segment_image(figure_img, text_crop)
                elem.setdefault("texts", []).append(
                    _record(text_det, text_index, next_page_number)
                )
                if text_index is not None:
                    used_following_ids.add((next_idx, text_index))
                segments_added = True
            if not segments_added:
                continue

            figure_img.save(figure_path)
            elem["width"] = figure_img.width
            elem["height"] = figure_img.height
            span = elem.get("page_span")
            if span:
                if next_page_number not in span:
                    span.append(next_page_number)
            else:
                base_page = elem.get("page")
                elem["page_span"] = [
                    page for page in (base_page, next_page_number) if page is not None
                ]
    finally:
        # Always release the document, also when an image operation raises.
        pdf_doc.close()
    return elements
def _stitch_table_pair(
    base_elem: Dict,
    candidate_elem: Dict,
    out_dir: Path,
    merge_index: int,
    stitch_type: str,
) -> Optional[Dict]:
    """Combine two table crops into one image and one element record.

    ``stitch_type == "vertical"`` places the candidate below the base; any
    other value places it to the right. The merged PNG is written to
    ``out_dir/tables`` and both source crops are deleted. Returns ``None``
    if either crop cannot be opened.
    """
    first = _open_table_image(base_elem, out_dir)
    second = _open_table_image(candidate_elem, out_dir)
    if first is None or second is None:
        return None

    tables_dir = out_dir / "tables"
    tables_dir.mkdir(parents=True, exist_ok=True)

    white = (255, 255, 255)
    if stitch_type == "vertical":
        common_width = max(first.width, second.width)
        first = _pad_width(first, common_width)
        second = _pad_width(second, common_width)
        canvas_size = (common_width, first.height + second.height)
        second_offset = (0, first.height)
    else:
        common_height = max(first.height, second.height)
        first = _pad_height(first, common_height)
        second = _pad_height(second, common_height)
        canvas_size = (first.width + second.width, common_height)
        second_offset = (first.width, 0)
    stitched = Image.new("RGB", canvas_size, color=white)
    stitched.paste(first, (0, 0))
    stitched.paste(second, second_offset)

    merged_name = (
        f"page_{base_elem['page']}_to_{candidate_elem['page']}_"
        f"table_merged_{merge_index}.png"
    )
    merged_path = tables_dir / merged_name
    stitched.save(merged_path)

    # Remove original partial crops to avoid duplicates
    for part in (base_elem, candidate_elem):
        (out_dir / part["image_path"]).unlink(missing_ok=True)

    base_box = base_elem["bbox_pixels"]
    cand_box = candidate_elem["bbox_pixels"]
    merged_elem = base_elem.copy()
    merged_elem.update(
        {
            "page_span": [base_elem["page"], candidate_elem["page"]],
            "box_refs": [
                {"page": base_elem["page"], "image_path": base_elem["image_path"]},
                {"page": candidate_elem["page"], "image_path": candidate_elem["image_path"]},
            ],
            "bbox_pixels": [
                min(base_box[0], cand_box[0]),
                min(base_box[1], cand_box[1]),
                max(base_box[2], cand_box[2]),
                max(base_box[3], cand_box[3]),
            ],
            "image_path": str(merged_path.relative_to(out_dir)),
            "width": stitched.width,
            "height": stitched.height,
            "page_height": stitched.height,
            "conf": min(base_elem.get("conf", 1.0), candidate_elem.get("conf", 1.0)),
        }
    )
    return merged_elem
def merge_spanning_tables(elements: List[Dict], out_dir: Path) -> List[Dict]:
    """
    Stitch table crops that appear to continue on the directly following page.

    Tables are grouped by page. Each table is compared with the not yet
    consumed tables of the next page: matching left/right edges trigger a
    vertical stitch, otherwise matching top/bottom edges trigger a horizontal
    stitch (tolerances from ``TABLE_STITCH_TOLERANCES``). A table is merged
    with at most one partner.

    Returns:
        Tables (merged or untouched) in page order, followed by all elements
        that are not tables or have no integer ``page``. An empty input is
        returned as is.
    """
    if not elements:
        return elements

    tables_by_page: Dict[int, List[Dict]] = {}
    passthrough: List[Dict] = []
    for elem in elements:
        page = elem.get("page")
        if elem.get("type") == "table" and isinstance(page, int):
            tables_by_page.setdefault(page, []).append(elem)
        else:
            passthrough.append(elem)

    output: List[Dict] = []
    # page -> positions (within that page's table list) already merged upward
    consumed: Dict[int, set] = {}
    merge_counter = 0
    for page in sorted(tables_by_page):
        following = tables_by_page.get(page + 1, [])
        taken_next = consumed.setdefault(page + 1, set())
        taken_here = consumed.get(page, set())
        for position, table_elem in enumerate(tables_by_page[page]):
            if position in taken_here:
                continue
            merged_elem = None
            if following:
                x, y, w, h = _bbox_to_rect(table_elem["bbox_pixels"])
                for idx, candidate in enumerate(following):
                    if idx in taken_next or candidate.get("type") != "table":
                        continue
                    cx, cy, cw, ch = _bbox_to_rect(candidate["bbox_pixels"])
                    if (
                        abs(x - cx) <= TABLE_STITCH_TOLERANCES["x_tol"]
                        and abs((x + w) - (cx + cw)) <= TABLE_STITCH_TOLERANCES["width_tol"]
                    ):
                        stitch_type = "vertical"
                    elif (
                        abs(y - cy) <= TABLE_STITCH_TOLERANCES["y_tol"]
                        and abs((y + h) - (cy + ch)) <= TABLE_STITCH_TOLERANCES["height_tol"]
                    ):
                        stitch_type = "horizontal"
                    else:
                        continue
                    # The counter advances even when stitching fails below.
                    merge_counter += 1
                    attempt = _stitch_table_pair(
                        table_elem, candidate, out_dir, merge_counter, stitch_type
                    )
                    if attempt is None:
                        continue
                    taken_next.add(idx)
                    merged_elem = attempt
                    break
            output.append(table_elem if merged_elem is None else merged_elem)

    output.extend(passthrough)
    return output
# ----------------------------------------------------------------------
# Draw layout boxes on the original PDF
# ----------------------------------------------------------------------
def draw_layout_pdf(pdf_bytes: bytes, all_dets: List[List[dict]],
                    scale: float, out_path: Path):
    """Annotate PDF with semi-transparent bounding boxes and labels.

    Args:
        pdf_bytes: Raw PDF content.
        all_dets: Detections per page; position ``i`` is drawn on page ``i``.
            Empty or ``None`` entries are skipped.
        scale: Render scale used for detection; bbox coordinates are divided
            by it to convert them to page coordinates.
        out_path: Destination of the annotated PDF.
    """
    doc = fitz.open(stream=pdf_bytes, filetype="pdf")
    try:
        page_total = len(doc)
        for page_no, dets in enumerate(all_dets):
            if not dets:
                continue
            if page_no >= page_total:
                break
            page = doc[page_no]
            for d in dets:
                rgb = CLASS_COLORS.get(d["name"], (0, 0, 0))
                rect = fitz.Rect([c / scale for c in d["bbox"]])
                box_color = [c / 255 for c in rgb]
                page.draw_rect(
                    rect,
                    color=box_color,
                    fill=box_color,
                    width=1.5,
                    overlay=True,
                    fill_opacity=0.15,
                )
                label = f"{d['name']} {d['conf']:.2f}"
                if d.get("source"):
                    label += f" [{d['source'][0].upper()}]"
                text_bg = fitz.Rect(rect.x0, rect.y0 - 10, rect.x0 + 60, rect.y0)
                # A 4-component color is interpreted as CMYK, so (1, 1, 1, 0.6)
                # would be a near-black fill. Use white RGB with explicit opacity.
                page.draw_rect(
                    text_bg,
                    color=None,
                    fill=(1, 1, 1),
                    fill_opacity=0.6,
                    overlay=True,
                )
                # NOTE(review): insert_text positions the text baseline; at
                # y0 - 8 the glyphs may sit partly above text_bg - confirm.
                page.insert_text(
                    (rect.x0 + 2, rect.y0 - 8),
                    label,
                    fontsize=6.5,
                    color=box_color,
                    overlay=True,
                )
        doc.save(str(out_path))
    finally:
        doc.close()
# ----------------------------------------------------------------------
# Process a single PDF Page (for parallel execution)
# ----------------------------------------------------------------------
def process_page(task_data: Tuple[int, bytes, float, Path, str]) -> Optional[Tuple[int, List[dict], List[dict]]]:
    """
    Process a single page of a PDF in a worker process.

    Args:
        task_data: ``(page_index, pdf_bytes, scale, out_dir, pdf_name)``.

    Returns: (page_number, detections, elements) or None on failure
    (``None`` is also returned when a shutdown was requested in this process).
    """
    pno, pdf_bytes, scale, out_dir, pdf_name = task_data
    if _shutdown_requested:
        return None
    try:
        pdf_pdfium = pdfium.PdfDocument(pdf_bytes)
        try:
            page = pdf_pdfium[pno]
            try:
                pil = page.render(scale=scale).to_pil()
                dets = detect_page(pil)
                elements = save_layout_elements(pil, pno, dets, out_dir)
                page_figures = sum(1 for d in dets if d["name"] == "figure")
                page_tables = sum(1 for d in dets if d["name"] == "table")
                logger.info(f" [{pdf_name}] Page {pno + 1}: {page_figures} figs, {page_tables} tables")
                return (pno, dets, elements)
            finally:
                # Release the page on success and on failure alike.
                page.close()
        finally:
            pdf_pdfium.close()
    except Exception as e:
        logger.error(f"Failed to process page {pno + 1} of {pdf_name}: {e}")
        return None
# ----------------------------------------------------------------------
# Process a full PDF using the persistent worker pool
# ----------------------------------------------------------------------
def process_pdf_with_pool(
    pdf_path: Path,
    out_dir: Path,
    pool: Optional[Pool] = None,
    *,
    extract_images: bool = True,
    extract_markdown: bool = True,
):
    """
    Main processing pipeline for a PDF file.
    If pool is provided, uses it. Otherwise processes serially.

    Args:
        pdf_path: PDF to process.
        out_dir: Directory that receives crops, JSON, annotated PDF and markdown.
        pool: Worker pool; used only when ``USE_MULTIPROCESSING`` is true.
        extract_images: Run layout detection and save figure/table crops,
            ``<stem>_content_list.json`` and ``<stem>_layout.pdf``.
        extract_markdown: Write ``<stem>.md`` via ``write_markdown_document``.
    """
    if _shutdown_requested:
        logger.warning(f"Skipping {pdf_path.name} due to shutdown request")
        return
    stem = pdf_path.stem
    logger.info(f"Processing {pdf_path.name}")
    pdf_bytes = pdf_path.read_bytes()

    # Open once only to learn the page count.
    doc = None
    try:
        doc = pdfium.PdfDocument(pdf_bytes)
        page_count = len(doc)
    except Exception as e:
        logger.error(f"Failed to open PDF {pdf_path.name}: {e}. Skipping.")
        return
    finally:
        if doc is not None:
            doc.close()

    scale = 2.0
    all_elements: List[Dict] = []
    if extract_images:
        # Indexed by page number; None marks pages that failed or were skipped.
        all_dets: List[Optional[List[dict]]] = [None] * page_count
        if pool is not None and USE_MULTIPROCESSING:
            logger.info(f" Using worker pool for {page_count} pages...")
            tasks = [
                (pno, pdf_bytes, scale, out_dir, pdf_path.name)
                for pno in range(page_count)
            ]
            try:
                for res in pool.map(process_page, tasks):
                    if res:
                        pno, dets, elements = res
                        all_dets[pno] = dets
                        all_elements.extend(elements)
            except KeyboardInterrupt:
                logger.warning("Processing interrupted during parallel execution")
                raise
        else:
            logger.info("Using serial processing...")
            pdf_pdfium = None
            try:
                pdf_pdfium = pdfium.PdfDocument(pdf_bytes)
                for pno in range(page_count):
                    if _shutdown_requested:
                        logger.warning(
                            f"Stopping at page {pno + 1}/{page_count} due to shutdown request"
                        )
                        break
                    try:
                        logger.info(f" Processing page {pno + 1}/{page_count}")
                        page = pdf_pdfium[pno]
                        try:
                            pil = page.render(scale=scale).to_pil()
                            dets = detect_page(pil)
                            all_dets[pno] = dets
                            elements = save_layout_elements(pil, pno, dets, out_dir)
                            all_elements.extend(elements)
                            page_figures = sum(1 for d in dets if d["name"] == "figure")
                            page_tables = sum(1 for d in dets if d["name"] == "table")
                            logger.info(
                                f" Found {page_figures} figures and {page_tables} tables"
                            )
                        finally:
                            # Release the page even if detection/saving failed.
                            page.close()
                    except Exception as e:
                        logger.error(f"Failed to process page {pno + 1}: {e}. Skipping page.")
            except Exception as e:
                logger.error(f"Fatal error processing {pdf_path.name}: {e}")
                return
            finally:
                if pdf_pdfium is not None:
                    pdf_pdfium.close()

        if all_elements:
            all_elements = merge_spanning_tables(all_elements, out_dir)
            all_elements = attach_cross_page_figure_captions(
                all_elements, list(all_dets), pdf_bytes, out_dir, scale
            )
        if all_elements:
            content_list_path = out_dir / f"{stem}_content_list.json"
            with open(content_list_path, "w", encoding="utf-8") as f:
                json.dump(all_elements, f, ensure_ascii=False, indent=4)
            logger.info(f" Saved {len(all_elements)} elements to JSON")

        if any(dets is not None for dets in all_dets):
            # Keep one entry per page (empty for failed pages). Dropping the
            # None entries would shift later detections onto the wrong pages.
            layout_dets: List[List[dict]] = [
                dets if dets is not None else [] for dets in all_dets
            ]
            draw_layout_pdf(
                pdf_bytes, layout_dets, scale, out_dir / f"{stem}_layout.pdf"
            )
            logger.info(" Generated annotated PDF")
        else:
            logger.warning(f"No detections found for {stem}. Skipping layout PDF.")
    else:
        logger.info(" Image extraction skipped per configuration.")

    if extract_markdown:
        markdown_path = write_markdown_document(pdf_path, out_dir)
        if markdown_path is None:
            logger.warning(f" Markdown extraction yielded no content for {stem}.")

    if _shutdown_requested:
        logger.warning(f"⚠️ Partial results saved for {stem} → {out_dir}")
    elif extract_images:
        logger.success(
            f"✓ {stem} → {out_dir} ({len(all_elements)} elements extracted)"
        )
    else:
        logger.success(f"✓ {stem} → {out_dir} (image extraction skipped)")
# ----------------------------------------------------------------------
# Main
# ----------------------------------------------------------------------
if __name__ == "__main__":
    # Script entry point: process every PDF in ./pdfs into ./output/<stem>/.
    # Important for multiprocessing on Windows/macOS
    torch.multiprocessing.set_start_method('spawn', force=True)
    # Setup signal handlers for graceful shutdown
    setup_signal_handlers()
    INPUT_DIR = Path("./pdfs")
    OUTPUT_DIR = Path("./output")
    os.makedirs(INPUT_DIR, exist_ok=True)
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    # Sorted so that the processing order is the same on every run/platform.
    pdf_files = sorted(INPUT_DIR.glob("*.pdf"))
    if not pdf_files:
        logger.warning("No PDF files found in ./pdfs")
        logger.info("Please add PDF files to the ./pdfs directory")
        logger.info("The script will exit gracefully. No errors occurred.")
        sys.exit(0)
    logger.info(f"Found {len(pdf_files)} PDF file(s) to process")
    logger.info(f"Settings: MODEL_SIZE={MODEL_SIZE}, CONF={CONF_THRESHOLD}")
    # Determine worker count
    total_cpus = cpu_count()
    if NUM_WORKERS is None:
        num_workers = max(1, total_cpus - 1)
    else:
        num_workers = max(1, min(NUM_WORKERS, total_cpus))
    # Decide whether to use multiprocessing (CPU only, and only with >= 4 cores)
    use_pool = USE_MULTIPROCESSING and DEVICE == "cpu" and total_cpus >= 4
    if use_pool:
        logger.info(f"🚀 Creating persistent worker pool with {num_workers} workers...")
    elif not USE_MULTIPROCESSING:
        logger.info("Multiprocessing disabled by configuration")
    elif DEVICE != "cpu":
        logger.info(f"Using serial GPU processing (device: {DEVICE})")
    else:
        logger.info(f"Using serial CPU processing (CPU count {total_cpus} too low)")

    pool = None
    interrupted = False
    finished_cleanly = False
    try:
        # Create persistent pool ONCE for all PDFs
        if use_pool:
            pool = Pool(processes=num_workers, initializer=init_worker)
            logger.success(f"✓ Worker pool ready with {num_workers} workers\n")
        else:
            # Load model in main process for serial execution
            logger.info("Initializing model in main process...")
            get_model()
            logger.success(f"✓ Model loaded (device: {DEVICE})\n")
        # Process all PDFs using the same pool
        for i, pdf_path in enumerate(pdf_files, 1):
            if _shutdown_requested:
                logger.warning(f"\nShutdown requested. Processed {i-1}/{len(pdf_files)} files.")
                break
            logger.info(f"\n{'='*60}")
            logger.info(f"📄 File {i}/{len(pdf_files)}: {pdf_path.name}")
            logger.info(f"{'='*60}")
            sub_out = OUTPUT_DIR / pdf_path.stem
            os.makedirs(sub_out, exist_ok=True)
            try:
                process_pdf_with_pool(pdf_path, sub_out, pool)
            except KeyboardInterrupt:
                logger.warning(f"\nInterrupted while processing {pdf_path.name}")
                interrupted = True
                break
            except Exception as e:
                logger.error(f"Error processing {pdf_path.name}: {e}")
                if _shutdown_requested:
                    break
                logger.info("Continuing with next file...")
                continue
        if _shutdown_requested or interrupted:
            logger.warning(f"\n⚠️ Processing interrupted. Partial results saved in {OUTPUT_DIR}")
        else:
            logger.success(f"\n✨ All done! Results are in {OUTPUT_DIR}")
            finished_cleanly = True
    except KeyboardInterrupt:
        logger.error("\n❌ Processing interrupted by user")
        sys.exit(1)
    except Exception as e:
        logger.error(f"\n❌ Fatal error: {e}")
        sys.exit(1)
    finally:
        # Clean up pool if it exists
        if pool is not None:
            logger.info("\n🧹 Shutting down worker pool...")
            if finished_cleanly:
                pool.close()
                pool.join()
                logger.success("✓ Worker pool closed cleanly")
            else:
                # After an interrupt or error, do not wait for queued work:
                # close() + join() would block until every pending page is done.
                pool.terminate()
                pool.join()
                logger.warning("Worker pool terminated")