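"""Gradio demo: image captioning with apple/FastVLM-0.5B, CPU-only.

Upload an image (optionally with a custom prompt) and the app returns a
generated caption. Requires gradio, torch, transformers, and pillow.
"""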
import gradio as gr
import torch
from PIL import Image
from transformers import AutoTokenizer, AutoModelForCausalLM

# ---- CPU-only config ----
MID = "apple/FastVLM-0.5B"
IMAGE_TOKEN_INDEX = -200  # special image token id used by FastVLM
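# The model's remote code (LLaVA-style) scans input_ids for this sentinel and
# splices the projected vision features in at that position, which is why the
# placeholder is inserted below as a raw id rather than as tokenized text.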

tok = None
model = None

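# Lazy singleton: from_pretrained downloads the weights from the Hugging Face
# Hub on the first call and reuses the local cache afterwards.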
def load_model():
    global tok, model
    if tok is None or model is None:
        print("Loading model (CPU)…")
        tok = AutoTokenizer.from_pretrained(MID, trust_remote_code=True)
        # Force CPU + float32 (fp16 is poorly supported on CPU)
        model = AutoModelForCausalLM.from_pretrained(
            MID,
            torch_dtype=torch.float32,
            device_map="cpu",
            trust_remote_code=True,
        )
        print("Model loaded successfully on CPU!")
    return tok, model

def caption_image(image, custom_prompt=None):
    """
    Generate a caption for the input image (CPU-only).
    """
    if image is None:
        return "Please upload an image first."

    try:
        tok, model = load_model()

        # Convert image to RGB if needed
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Treat an empty or whitespace-only custom prompt as "use the default"
        prompt = (custom_prompt or "").strip() or "Describe this image in detail."

        # Single-turn chat with an <image> placeholder
        messages = [{"role": "user", "content": f"<image>\n{prompt}"}]
        rendered = tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)

        # Split around the literal "<image>"
        pre, post = rendered.split("<image>", 1)

        # Tokenize text around the image token
        pre_ids = tok(pre, return_tensors="pt", add_special_tokens=False).input_ids
        post_ids = tok(post, return_tensors="pt", add_special_tokens=False).input_ids

        # Derive device/dtype from the loaded model (CPU here, but future-proof)
        model_device = next(model.parameters()).device
        model_dtype = next(model.parameters()).dtype

        # Insert IMAGE token id at placeholder position
        img_tok = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=pre_ids.dtype, device=model_device)
        input_ids = torch.cat(
            [pre_ids.to(model_device), img_tok, post_ids.to(model_device)],
            dim=1
        )
        attention_mask = torch.ones_like(input_ids, device=model_device)

        # Preprocess image using model's vision tower
        px = model.get_vision_tower().image_processor(
            images=image, return_tensors="pt"
        )["pixel_values"].to(device=model_device, dtype=model_dtype)

        # Generate caption (deterministic)
        with torch.no_grad():
            out = model.generate(
                inputs=input_ids,
                attention_mask=attention_mask,
                images=px,
                max_new_tokens=128,
                do_sample=False,  # temperature is ignored when sampling is off
            )

        # Decode and slice to the assistant part if present
        generated_text = tok.decode(out[0], skip_special_tokens=True)
        if "Assistant:" in generated_text:
            response = generated_text.split("Assistant:", 1)[-1].strip()
        elif "assistant" in generated_text:
            response = generated_text.split("assistant", 1)[-1].strip()
        else:
            response = generated_text.strip()

        return response

    except Exception as e:
        return f"Error generating caption: {str(e)}"

# ---- Gradio UI (CPU) ----
with gr.Blocks(title="FastVLM Image Captioning (CPU)") as demo:
    gr.Markdown(
        """
        # 🖼️ FastVLM Image Captioning (CPU)
        Upload an image to generate a detailed caption using Apple's FastVLM-0.5B.
        This build runs on **CPU only**, so expect slower generation than on a GPU.
        """
    )

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image", elem_id="image-upload")
            custom_prompt = gr.Textbox(
                label="Custom Prompt (Optional)",
                placeholder="Leave empty for default: 'Describe this image in detail.'",
                lines=2
            )
            with gr.Row():
                clear_btn = gr.ClearButton([image_input, custom_prompt])
                generate_btn = gr.Button("Generate Caption", variant="primary")

        with gr.Column():
            output = gr.Textbox(
                label="Generated Caption",
                lines=8,
                max_lines=15,
                show_copy_button=True
            )

    generate_btn.click(fn=caption_image, inputs=[image_input, custom_prompt], outputs=output)

    # Also generate on image upload if no custom prompt
    def _auto_caption(img, prompt):
        return caption_image(img, prompt) if (img is not None and not prompt) else None

    image_input.change(fn=_auto_caption, inputs=[image_input, custom_prompt], outputs=output)

    gr.Markdown(
        """
        ---
        **Model:** [apple/FastVLM-0.5B](https://huggingface.co/apple/FastVLM-0.5B)  
        **Note:** CPU-only run. For speed, switch to a CUDA GPU build or a GPU Space.
        """
    )

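# Run locally with e.g. `python app.py` (filename assumed); the UI is then
# served on http://0.0.0.0:7860.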
if __name__ == "__main__":
    demo.launch(
        share=False,
        show_error=True,
        server_name="0.0.0.0",
        server_port=7860
    )