# app.py: KOSMOS-2.5 Document AI demo (Hugging Face Space "kosmos-2.5-demo")
import spaces  # imported first, as recommended for ZeroGPU Spaces

import re

import gradio as gr
import torch
from PIL import Image, ImageDraw
from transformers import AutoProcessor, Kosmos2_5ForConditionalGeneration

# Select device and dtype: CUDA with bfloat16 when available, else CPU with float32.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

# Initialize models and processors lazily
base_model = None
base_processor = None
chat_model = None
chat_processor = None
def load_base_model():
    """Load the base KOSMOS-2.5 model and processor on first use."""
    global base_model, base_processor
    if base_model is None:
        base_repo = "microsoft/kosmos-2.5"
        base_model = Kosmos2_5ForConditionalGeneration.from_pretrained(
            base_repo,
            device_map=device,
            dtype=dtype,
            attn_implementation="flash_attention_2" if torch.cuda.is_available() else None,
        )
        base_processor = AutoProcessor.from_pretrained(base_repo)
    return base_model, base_processor

def load_chat_model():
    """Load the KOSMOS-2.5 chat model and processor on first use."""
    global chat_model, chat_processor
    if chat_model is None:
        chat_repo = "microsoft/kosmos-2.5-chat"
        chat_model = Kosmos2_5ForConditionalGeneration.from_pretrained(
            chat_repo,
            device_map=device,
            dtype=dtype,
            attn_implementation="flash_attention_2" if torch.cuda.is_available() else None,
        )
        chat_processor = AutoProcessor.from_pretrained(chat_repo)
    return chat_model, chat_processor

def post_process_ocr(y, scale_height, scale_width, prompt="<ocr>"):
    """Strip the task prompt and convert <bbox> tokens into comma-separated
    quadrilateral coordinates in the original image's pixel space."""
    y = y.replace(prompt, "")
    if "<md>" in prompt:
        return y
    pattern = r"<bbox><x_\d+><y_\d+><x_\d+><y_\d+></bbox>"
    bboxs_raw = re.findall(pattern, y)
    lines = re.split(pattern, y)[1:]
    bboxs = [re.findall(r"\d+", i) for i in bboxs_raw]
    bboxs = [[int(j) for j in i] for i in bboxs]
    info = ""
    for i in range(len(lines)):
        if i < len(bboxs):
            box = bboxs[i]
            x0, y0, x1, y1 = box
            # Skip degenerate boxes, then rescale from processor space to pixels.
            if not (x0 >= x1 or y0 >= y1):
                x0 = int(x0 * scale_width)
                y0 = int(y0 * scale_height)
                x1 = int(x1 * scale_width)
                y1 = int(y1 * scale_height)
                info += f"{x0},{y0},{x1},{y0},{x1},{y1},{x0},{y1},{lines[i]}\n"
    return info.strip()

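# Illustrative example (fabricated output, not from a real model run): with
# scale_height = scale_width = 1.0, the raw string
#   "<bbox><x_10><y_20><x_110><y_40></bbox>Hello"
# becomes the line
#   "10,20,110,20,110,40,10,40,Hello"
# i.e. the four corners of the box in clockwise order, followed by the text.
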
@spaces.GPU
def generate_markdown(image):
    """Transcribe a document image into markdown with the "<md>" task prompt."""
    if image is None:
        return "Please upload an image."
    model, processor = load_base_model()
    prompt = "<md>"
    inputs = processor(text=prompt, images=image, return_tensors="pt")
    # height/width are only needed for bbox rescaling; pop them so they are
    # not passed to generate().
    inputs.pop("height")
    inputs.pop("width")
    inputs = {k: v.to(device) if v is not None else None for k, v in inputs.items()}
    inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=1024,
        )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
    result = generated_text[0].replace(prompt, "").strip()
    return result

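# Usage sketch outside Gradio ("sample_doc.png" is a hypothetical local file):
#   print(generate_markdown(Image.open("sample_doc.png").convert("RGB")))
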
@spaces.GPU
def generate_ocr(image):
    """Run OCR with the "<ocr>" task prompt and draw the detected text boxes."""
    if image is None:
        return "Please upload an image.", None
    model, processor = load_base_model()
    prompt = "<ocr>"
    inputs = processor(text=prompt, images=image, return_tensors="pt")
    height, width = inputs.pop("height"), inputs.pop("width")
    raw_width, raw_height = image.size
    scale_height = raw_height / height
    scale_width = raw_width / width
    inputs = {k: v.to(device) if v is not None else None for k, v in inputs.items()}
    inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=1024,
        )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
    # Convert <bbox> tokens into pixel-space coordinates.
    output_text = post_process_ocr(generated_text[0], scale_height, scale_width)
    # Draw each detected text box on a copy of the input image.
    vis_image = image.copy()
    draw = ImageDraw.Draw(vis_image)
    for line in output_text.split("\n"):
        if not line.strip():
            continue
        parts = line.split(",")
        if len(parts) >= 8:
            try:
                coords = list(map(int, parts[:8]))
                draw.polygon(coords, outline="red", width=2)
            except ValueError:
                continue
    return output_text, vis_image

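# Downstream parsing sketch (hypothetical; shows how a caller could split the
# OCR output above into (polygon, text) pairs, rejoining the tail because the
# recognized text may itself contain commas):
#   for line in output_text.split("\n"):
#       parts = line.split(",")
#       polygon, text = list(map(int, parts[:8])), ",".join(parts[8:])
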
@spaces.GPU
def generate_chat_response(image, question):
    """Answer a question about the document image with KOSMOS-2.5 Chat."""
    if image is None:
        return "Please upload an image."
    if not question.strip():
        return "Please ask a question."
    model, processor = load_chat_model()
    template = (
        "<md>A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions. "
        "USER: {} ASSISTANT:"
    )
    prompt = template.format(question)
    inputs = processor(text=prompt, images=image, return_tensors="pt")
    # The resized height/width are not needed here; pop them so they are not
    # passed to generate().
    inputs.pop("height")
    inputs.pop("width")
    inputs = {k: v.to(device) if v is not None else None for k, v in inputs.items()}
    inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=1024,
        )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
    # Keep only the assistant's part of the decoded conversation.
    result = generated_text[0]
    if "ASSISTANT:" in result:
        result = result.split("ASSISTANT:")[-1].strip()
    return result

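# The three @spaces.GPU handlers above share one preprocess -> generate -> decode
# pipeline. A sketch of how it could be factored out (hypothetical helper, shown
# for clarity only; the app keeps the explicit per-task functions and never
# calls this):
def _run_kosmos(model, processor, image, prompt, max_new_tokens=1024):
    inputs = processor(text=prompt, images=image, return_tensors="pt")
    # The processor reports the resized height/width; keep the ratios so bbox
    # outputs can be mapped back to the original image's pixel space.
    height, width = inputs.pop("height"), inputs.pop("width")
    raw_width, raw_height = image.size
    scales = (raw_height / height, raw_width / width)
    inputs = {k: v.to(device) for k, v in inputs.items() if v is not None}
    inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return text, scales
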
# Create Gradio interface
with gr.Blocks(title="KOSMOS-2.5 Document AI Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # KOSMOS-2.5 Document AI Demo

    Explore Microsoft's KOSMOS-2.5, a multimodal model for reading text-intensive images.
    This demo showcases three capabilities:

    1. **Markdown Generation**: Convert document images to markdown format
    2. **OCR with Bounding Boxes**: Extract text with spatial coordinates
    3. **Document Q&A**: Ask questions about document content using KOSMOS-2.5 Chat

    Upload a document image (receipt, form, article, etc.) and try the different tasks!
    """)

    with gr.Tabs():
        # Markdown Generation Tab
        with gr.TabItem("📝 Markdown Generation"):
            with gr.Row():
                with gr.Column():
                    md_image = gr.Image(type="pil", label="Upload Document Image")
                    md_button = gr.Button("Generate Markdown", variant="primary")
                with gr.Column():
                    md_output = gr.Textbox(
                        label="Generated Markdown",
                        lines=15,
                        max_lines=20,
                        show_copy_button=True,
                    )

        # OCR Tab
        with gr.TabItem("🔍 OCR with Bounding Boxes"):
            with gr.Row():
                with gr.Column():
                    ocr_image = gr.Image(type="pil", label="Upload Document Image")
                    ocr_button = gr.Button("Extract Text with Coordinates", variant="primary")
                with gr.Column():
                    ocr_text = gr.Textbox(
                        label="Extracted Text with Coordinates",
                        lines=10,
                        show_copy_button=True,
                    )
                    ocr_vis = gr.Image(label="Visualization (red boxes show detected text)")

        # Chat Tab
        with gr.TabItem("💬 Document Q&A (Chat)"):
            with gr.Row():
                with gr.Column():
                    chat_image = gr.Image(type="pil", label="Upload Document Image")
                    chat_question = gr.Textbox(
                        label="Ask a question about the document",
                        placeholder="e.g., What is the total amount on this receipt?",
                        lines=2,
                    )
                    chat_button = gr.Button("Get Answer", variant="primary")
                with gr.Column():
                    chat_output = gr.Textbox(
                        label="Answer",
                        lines=8,
                        show_copy_button=True,
                    )

    # Event handlers
    md_button.click(
        fn=generate_markdown,
        inputs=[md_image],
        outputs=[md_output],
    )
    ocr_button.click(
        fn=generate_ocr,
        inputs=[ocr_image],
        outputs=[ocr_text, ocr_vis],
    )
    chat_button.click(
        fn=generate_chat_response,
        inputs=[chat_image, chat_question],
        outputs=[chat_output],
    )

    # Examples section
    gr.Markdown("""
    ## Example Use Cases
    - **Receipts**: Extract itemized information or ask about totals
    - **Forms**: Convert to structured format or answer specific questions
    - **Articles**: Get markdown format or ask about content
    - **Screenshots**: Extract text or get information about specific elements

    ## Note
    This is a generative model and may occasionally hallucinate. Results should be verified for accuracy.
    """)

if __name__ == "__main__":
    demo.launch()