from __future__ import annotations

import torch

# No-op torch.jit.script so that dependencies which use it as a decorator
# skip TorchScript compilation in this environment.
torch.jit.script = lambda f: f

import shlex
import subprocess
import time

import spaces  # Hugging Face Spaces GPU support (ZeroGPU)
import gradio as gr
from PIL import Image


# Install prebuilt CUDA wheels for the Mamba dependencies; building
# causal_conv1d / mamba_ssm from source on the Space would be slow and
# require a matching CUDA toolchain, so the wheels are expected to ship
# alongside this file.
def install():
    print("Install personal packages", flush=True)
    subprocess.run(shlex.split("pip install causal_conv1d-1.2.0.post1-cp310-cp310-linux_x86_64.whl"))
    subprocess.run(shlex.split("pip install mamba_ssm-1.2.0.post1-cp310-cp310-linux_x86_64.whl"))


install()
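# Hedged post-install check (a sketch, not part of the original app): confirm
# the freshly installed wheels actually import before loading the model, so a
# bad wheel fails here with a clear message rather than deep inside cobra.
try:
    import causal_conv1d  # noqa: F401
    import mamba_ssm  # noqa: F401
except ImportError as err:
    print(f"Warning: mamba dependencies failed to import: {err}", flush=True)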
from cobra import load

# Load the Cobra 3B vision-language model.
vlm = load("cobra+3b")

# Prefer GPU + bfloat16 when available; fall back to CPU + float32.
if torch.cuda.is_available():
    DEVICE = "cuda"
    DTYPE = torch.bfloat16
else:
    DEVICE = "cpu"
    DTYPE = torch.float32

vlm.to(DEVICE, dtype=DTYPE)

prompt_builder = vlm.get_prompt_builder()
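# Note: prompt_builder is module-level state shared by every visitor to the
# demo; it is reset only when a conversation starts with empty history (see
# bot_streaming below), so concurrent chats can interleave turns. Acceptable
# for a single-user demo, but per-session state would be safer.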
def bot_streaming(message, history, temperature, top_k, max_new_tokens):
    # A fresh conversation: reset the shared prompt builder.
    if len(history) == 0:
        prompt_builder.prompt, prompt_builder.turn_count = "", 0

    image = None
    if message["files"]:
        image = message["files"][-1]["path"]
    else:
        # If there's no image uploaded for this turn, look for images in the
        # past turns, kept inside tuples; take the last one.
        for hist in history:
            if isinstance(hist[0], tuple):
                image = hist[0][0]
    if image is not None:
        image = Image.open(image).convert("RGB")

    prompt_builder.add_turn(role="human", message=message["text"])
    prompt_text = prompt_builder.get_prompt()

    # Generate from the VLM (the whole completion is produced in one call).
    with torch.no_grad():
        generated_text = vlm.generate(
            image,
            prompt_text,
            use_cache=True,
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
            max_new_tokens=max_new_tokens,
        )
    prompt_builder.add_turn(role="gpt", message=generated_text)

    time.sleep(0.04)
    yield generated_text
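# Despite its name, bot_streaming yields the reply once, after generation has
# finished. A minimal cosmetic alternative (a sketch, not part of the original
# app, making no assumptions about cobra's generate API) is to yield growing
# prefixes of the finished string so gr.ChatInterface animates the reply:
def stream_prefixes(text, delay=0.01):
    buffer = ""
    for ch in text:
        buffer += ch
        time.sleep(delay)
        yield buffer
# Wiring it in would replace the final `yield generated_text` above with
# `yield from stream_prefixes(generated_text)`.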
demo = gr.ChatInterface(
    fn=bot_streaming,
    additional_inputs=[
        gr.Slider(0, 1, value=0.2, label="Temperature"),
        gr.Slider(1, 3, value=1, step=1, label="Top k"),
        gr.Slider(1, 2048, value=256, step=1, label="Max New Tokens"),
    ],
    title="Cobra",
    description="Try [Cobra](https://huggingface.co/papers/2403.14520) in this demo. Upload an image and start chatting about it.",
    stop_btn="Stop Generation",
    multimodal=True,
)
demo.launch(debug=True)