# # Step 2: Import necessary libraries
# import gradio as gr
# from PIL import Image
# from transformers import CLIPProcessor, CLIPModel, AutoTokenizer, AutoModelForCausalLM
# from peft import PeftConfig, PeftModel
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from transformers.cache_utils import DynamicCache
# import json
# import os

# # Step 3: Set device and default dtype
# DEVICE = torch.device("cpu")  # Explicitly set to CPU
# torch.set_default_dtype(torch.float32)  # Use float32 for CPU compatibility (float16 is less reliable on CPU)

# # Step 4: Load CLIP model and processor
# clip_model = CLIPModel.from_pretrained(
#     "openai/clip-vit-base-patch32",
#     torch_dtype=torch.float32  # Use float32 instead of float16
# ).to(DEVICE)
# clip_processor = CLIPProcessor.from_pretrained(
#     "openai/clip-vit-base-patch32",
#     use_fast=True
# )

# # Step 5: Define the MultiModalModel class
# class MultiModalModel(nn.Module):
#     def __init__(self, phi_model_name="microsoft/phi-3-mini-4k-instruct", clip_model_name="openai/clip-vit-base-patch32"):
#         super().__init__()
#         self.phi = None  # Will be set after loading the PEFT model
#         self.tokenizer = AutoTokenizer.from_pretrained(phi_model_name, trust_remote_code=True)
#         self.tokenizer.add_special_tokens({"additional_special_tokens": ["[IMG]"], "pad_token": "<pad>"})
#         self.clip = CLIPModel.from_pretrained(
#             clip_model_name,
#             torch_dtype=torch.float32  # Use float32 for CPU
#         ).eval().to(DEVICE)
#         image_embedding_dim = self.clip.config.projection_dim
#         phi_hidden_size = 3072  # Hardcoded for Phi-3 mini
#         self.image_projection = nn.Sequential(
#             nn.Linear(image_embedding_dim, phi_hidden_size, dtype=torch.float32),  # Use float32
#             nn.LayerNorm(phi_hidden_size, dtype=torch.float32),
#             nn.Dropout(0.1)
#         ).to(DEVICE)
#         nn.init.xavier_uniform_(self.image_projection[0].weight, gain=1.0)
#         nn.init.zeros_(self.image_projection[0].bias)

#     def forward(self, text_input_ids, attention_mask=None, image_embedding=None):
#         image_embedding = torch.clamp(image_embedding, min=-1e4, max=1e4)
#         image_embedding = F.normalize(image_embedding, dim=-1, eps=1e-5).to(torch.float32)  # Use float32
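#         # Clamp the projection layer's weight and bias in-place to [-1, 1] before projecting the image embedding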
#         with torch.no_grad():
#             self.image_projection[0].weight.clamp_(-1.0, 1.0)
#             self.image_projection[0].bias.clamp_(-1.0, 1.0)
#         projected_image = 1.0 * self.image_projection(image_embedding)
#         projected_image = torch.clamp(projected_image, min=-1e4, max=1e4)
#         if torch.isnan(projected_image).any() or torch.isinf(projected_image).any():
#             print("Warning: Projected image contains NaN or Inf values after clamping, replacing with zeros")
#             projected_image = torch.where(
#                 torch.logical_or(torch.isnan(projected_image), torch.isinf(projected_image)),
#                 torch.zeros_like(projected_image),
#                 projected_image
#             )
#         if projected_image.dim() == 2:
#             projected_image = projected_image.unsqueeze(1)
#         text_embeddings = self.phi.get_input_embeddings()(text_input_ids)
#         fused_embeddings = text_embeddings.clone()
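#         # Locate the [IMG] placeholder token and overwrite its embedding with the projected CLIP image vector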
#         img_token_id = self.tokenizer.convert_tokens_to_ids("[IMG]")
#         img_token_mask = (text_input_ids == img_token_id)
#         for i in range(fused_embeddings.shape[0]):
#             img_positions = img_token_mask[i].nonzero(as_tuple=True)[0]
#             if img_positions.numel() > 0:
#                 fused_embeddings[i, img_positions[0], :] = projected_image[i, 0, :]
#         if torch.isnan(fused_embeddings).any() or torch.isinf(fused_embeddings).any():
#             print("Warning: Fused embeddings contain NaN or Inf values, replacing with zeros")
#             fused_embeddings = torch.where(
#                 torch.logical_or(torch.isnan(fused_embeddings), torch.isinf(fused_embeddings)),
#                 torch.zeros_like(fused_embeddings),
#                 fused_embeddings
#             )
#         return fused_embeddings

# # Step 6: Load the fine-tuned model weights from Epoch_0
# def load_model():
#     # 1. Load PEFT Config
#     peft_model_id = "finalmodel_v2"  # Path to the saved PEFT directory
#     # Load the config.json file
#     config_path = os.path.join(peft_model_id, "config.json")
#     with open(config_path, "r") as f:
#         peft_config_dict = json.load(f)
    
#     # Check if 'eva_config' exists and remove it if it's not needed
#     if "eva_config" in peft_config_dict:
#         print("Found 'eva_config' in the config. Removing it...")
#         del peft_config_dict["eva_config"]
    
#     # Save the modified config back
#     with open(config_path, "w") as f:
#         json.dump(peft_config_dict, f, indent=2)

#     # Now load the config with PeftConfig
#     config = PeftConfig.from_pretrained(peft_model_id)  # Use the config to determine the base model

#     attn_implementation = "eager"
#     cache = DynamicCache()

#     # Load base model without quantization (CPU-compatible)
#     base_model = AutoModelForCausalLM.from_pretrained(
#         config.base_model_name_or_path,
#         return_dict=True,
#         device_map="cpu",  # Explicitly set to CPU
#         trust_remote_code=False,
#         torch_dtype=torch.float32,  # Use float32 for CPU
#         attn_implementation="eager"
#     )
#     base_model.gradient_checkpointing_enable()
#     peft_model = PeftModel.from_pretrained(base_model, peft_model_id)
#     tokenizer = AutoTokenizer.from_pretrained(peft_model_id)
#     special_tokens = {"additional_special_tokens": ["[IMG]"], "pad_token": "<pad>"}
#     tokenizer.add_special_tokens(special_tokens)
#     peft_model.resize_token_embeddings(len(tokenizer))
#     tokenizer.pad_token = tokenizer.eos_token

#     model = MultiModalModel(phi_model_name=config.base_model_name_or_path)
#     model.phi = peft_model
#     model.to(DEVICE)
#     model.eval()
#     return model, tokenizer

# # Load the model
# model, tokenizer = load_model()

# # Step 7: Simple captioning function (for demonstration, assuming you have one)
# def generate_caption(image, model, tokenizer):
#     try:
#         if not isinstance(image, Image.Image):
#             return "Error: Input must be a valid image."
#         if image.mode != "RGB":
#             image = image.convert("RGB")

#         # Process image with CLIP
#         image_inputs = clip_processor(images=image, return_tensors="pt").to(DEVICE)
#         with torch.no_grad():
#             image_embedding = clip_model.get_image_features(**image_inputs).to(torch.float32)

#         # Prepare prompt
#         prompt = "Caption this image: [IMG]"
#         inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
#         input_ids = inputs["input_ids"].to(DEVICE)
#         attention_mask = inputs["attention_mask"].to(DEVICE)

#         # Generate fused embeddings
#         with torch.no_grad():
#             fused_embedding = model(input_ids, attention_mask, image_embedding)

#         # Generate caption
#         with torch.no_grad():
#             generated_ids = model.phi.generate(
#                 inputs_embeds=fused_embedding,
#                 attention_mask=attention_mask,
#                 max_new_tokens=50,
#                 min_length=10,
#                 num_beams=3,  # Reduced for CPU speed
#                 repetition_penalty=1.2,
#                 do_sample=False
#             )
#         caption = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
#         return caption.strip()

#     except Exception as e:
#         return f"Error generating caption: {str(e)}"

# # Step 8: Gradio interface (same as your previous request)
# with gr.Blocks(title="CPU-Based Image Captioning") as interface:
#     with gr.Row():
#         with gr.Column():
#             image_input = gr.Image(type="pil", label="Upload an Image", sources=["upload"])
#         with gr.Column():
#             gr.Markdown("**The image depicts** a bustling cityscape at dusk, with towering skyscrapers reflecting the orange and pink hues of the setting sun. Streets are lined with a variety of vehicles, including cars, buses, and bicycles. Pedestrians can be seen walking along the sidewalks, some carrying shopping bags, while others are engrossed in their smartphones. Streetlights cast a warm glow on the urban scene.")
#             caption_output = gr.Textbox(label="Caption:", placeholder="A vibrant cityscape at dusk, with skyscrapers reflecting the sunset", lines=2)
    
#     with gr.Row():
#         clear_button = gr.Button("Clear")
#         submit_button = gr.Button("Submit", variant="primary")

#     def update_caption(image):
#         if image is None:
#             return "Please upload an image."
#         caption = generate_caption(image, model, tokenizer)
#         return caption

#     submit_button.click(
#         fn=update_caption,
#         inputs=image_input,
#         outputs=caption_output
#     )

#     clear_button.click(
#         fn=lambda: "",
#         outputs=caption_output
#     )

# interface.launch(debug=True)
#!/usr/bin/env python

# #!/usr/bin/env python

# import os
# import torch
# from multimodal_app import create_interface

# # Optional: Set torch threads to limit CPU usage
# torch.set_num_threads(4)

# # Create and launch the interface
# demo = create_interface()
# demo.queue()  # Enable queuing for better handling of multiple requests
# demo.launch()
#!/usr/bin/env python

import os
import torch
import sys
import argparse

# Set environment variables for better compatibility
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Set torch threads to limit CPU usage
torch.set_num_threads(4)

# Parse command line arguments
parser = argparse.ArgumentParser(description="Multimodal Image Description App")
parser.add_argument("--peft-model", type=str, default="model_V1",
                   help="Path to PEFT model")
parser.add_argument("--port", type=int, default=7860,
                   help="Port to run the server on")
args = parser.parse_args()

try:
    from multimodal_app import create_interface, load_model
    
    # Preload the model with PEFT path
    print(f"Preloading model with PEFT path: {args.peft_model}")
    load_model(args.peft_model)
    
    # Create and launch the interface
    demo = create_interface()
    
    # Launch with proper settings for stability
    demo.launch(
        share=False,              # Set to True if you want a public link
        debug=True,               # Enable debug for better error messages
        server_name="0.0.0.0",    # Listen on all interfaces  
        server_port=args.port,    # Port from arguments
        show_api=False            # Hide API docs for simplicity
    )
except Exception as e:
    print(f"Error starting application: {e}")
    import traceback
    traceback.print_exc()
    sys.exit(1)
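
# Example invocation (a sketch; the script filename "app.py" is an assumption,
# not given in this file):
#
#   python app.py --peft-model model_V1 --port 7860
#
# "--peft-model" should point at a local PEFT adapter directory (default:
# "model_V1"); "--port" selects the Gradio server port (default: 7860).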