|
|
|
|
|
""" |
|
|
SAM 3 目标检测与分割演示的 Hugging Face Spaces 版本。 |
|
|
|
|
|
适配 Hugging Face Spaces 部署环境: |
|
|
1. 直接从 Hugging Face Hub 下载模型和资源 |
|
|
2. 支持 ZeroGPU 或 CPU 推理 |
|
|
3. 无需本地上传额外文件 |
|
|
|
|
|
支持功能: |
|
|
1. 文本提示分割 |
|
|
2. 单框/多框提示分割 |
|
|
3. 正框/负框交互式标注(Multi Box 模式下可切换绘制正框或负框) |
|
|
""" |
|
|
|
|
|
import os |
|
|
import torch |
|
|
import numpy as np |
|
|
import gradio as gr |
|
|
from PIL import Image, ImageDraw, ImageFont |
|
|
import matplotlib.pyplot as plt |
|
|
import io |
|
|
import random |
|
|
from typing import List, Dict, Any, Tuple |
|
|
|
|
|
|
|
|
from huggingface_hub import hf_hub_download, snapshot_download |
|
|
|
|
|
|
|
|
IS_HF_SPACES = os.environ.get("SPACE_ID") is not None |
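# The Spaces runtime sets SPACE_ID, so this flag reports whether we are running on Hugging Face Spaces.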
|
|
|
|
|
|
|
|
try: |
|
|
from gradio_image_prompter import ImagePrompter |
|
|
IMAGE_PROMPTER_AVAILABLE = True |
|
|
except ImportError as e: |
|
|
print(f"ImagePrompter 不可用: {e}") |
|
|
IMAGE_PROMPTER_AVAILABLE = False |
|
|
|
|
|
|
|
|
try: |
|
|
import spaces |
|
|
SPACES_GPU_AVAILABLE = True |
|
|
except ImportError: |
|
|
SPACES_GPU_AVAILABLE = False |
|
|
print("Hugging Face Spaces GPU 模块不可用,将使用标准推理") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
SAM3_HF_REPO_ID = os.environ.get("SAM3_HF_REPO_ID", "facebook/sam3") |
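# Repository that hosts the SAM 3 checkpoint and assets; override it with the SAM3_HF_REPO_ID
# environment variable (e.g. to point at a mirror repo).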
|
|
|
|
|
|
|
|
|
|
|
SAM3_INSTALLED = False |
|
|
sam3 = None |
|
|
build_sam3_image_model = None |
|
|
box_xywh_to_cxcywh = None |
|
|
Sam3Processor = None |
|
|
normalize_bbox = None |
|
|
draw_box_on_image = None |
|
|
plot_mask = None |
|
|
plot_bbox = None |
|
|
COLORS = [(1, 0, 0), (0, 1, 0), (0, 0, 1)] |
|
|
plot_results = None |
|
|
|
|
|
try: |
|
|
import sam3 |
|
|
from sam3 import build_sam3_image_model |
|
|
SAM3_INSTALLED = True |
|
|
print("✅ sam3 库已安装") |
|
|
|
|
|
|
|
|
try: |
|
|
from sam3.model.box_ops import box_xywh_to_cxcywh |
|
|
except ImportError as e: |
|
|
print(f"⚠️ box_ops 导入失败: {e}") |
|
|
|
|
|
def box_xywh_to_cxcywh(boxes): |
|
|
"""将 XYWH 格式转换为 CXCYWH 格式""" |
|
|
x, y, w, h = boxes.unbind(-1) |
|
|
cx = x + w / 2 |
|
|
cy = y + h / 2 |
|
|
return torch.stack([cx, cy, w, h], dim=-1) |
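        # Quick sanity check of the fallback (hypothetical values):
        # XYWH [10, 20, 30, 40] -> CXCYWH [25, 40, 30, 40], i.e. the centre replaces the top-left corner.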
|
|
|
|
|
try: |
|
|
from sam3.model.sam3_image_processor import Sam3Processor |
|
|
except ImportError as e: |
|
|
print(f"⚠️ Sam3Processor 导入失败: {e}") |
|
|
Sam3Processor = None |
|
|
|
|
|
try: |
|
|
from sam3.visualization_utils import normalize_bbox, draw_box_on_image, plot_mask, plot_bbox, COLORS, plot_results |
|
|
except ImportError as e: |
|
|
print(f"⚠️ visualization_utils 导入失败: {e}") |
|
|
|
|
|
def normalize_bbox(boxes, width, height): |
|
|
"""归一化边界框坐标""" |
|
|
if isinstance(boxes, torch.Tensor): |
|
|
normalized = boxes.clone() |
|
|
normalized[..., 0] /= width |
|
|
normalized[..., 1] /= height |
|
|
normalized[..., 2] /= width |
|
|
normalized[..., 3] /= height |
|
|
return normalized |
|
|
return boxes |
|
|
|
|
|
def plot_mask(mask, color=(1, 0, 0), alpha=0.5): |
|
|
"""绘制掩码""" |
|
|
import matplotlib.pyplot as plt |
|
|
h, w = mask.shape[-2:] |
|
|
mask_image = mask.reshape(h, w, 1) * np.array(color).reshape(1, 1, -1) |
|
|
plt.imshow(mask_image, alpha=alpha) |
|
|
|
|
|
def plot_bbox(h, w, box, text="", box_format="XYXY", color=(1, 0, 0), relative_coords=False): |
|
|
"""绘制边界框""" |
|
|
import matplotlib.pyplot as plt |
|
|
import matplotlib.patches as patches |
|
|
if isinstance(box, torch.Tensor): |
|
|
box = box.tolist() |
|
|
x0, y0, x1, y1 = box |
|
|
rect = patches.Rectangle((x0, y0), x1-x0, y1-y0, linewidth=2, edgecolor=color, facecolor='none') |
|
|
plt.gca().add_patch(rect) |
|
|
if text: |
|
|
plt.text(x0, y0, text, color=color, fontsize=8) |
|
|
|
|
|
COLORS = [(1, 0, 0), (0, 1, 0), (0, 0, 1), (1, 1, 0), (1, 0, 1), (0, 1, 1)] |
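        # Matplotlib-style RGB tuples in the 0-1 range, cycled per detected object by get_result_figure.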
|
|
plot_results = None |
|
|
draw_box_on_image = None |
|
|
|
|
|
except ImportError as e: |
|
|
print(f"❌ sam3 库导入失败: {e}") |
|
|
print("请确保 requirements.txt 中包含: git+https://github.com/facebookresearch/sam3.git") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def draw_boxes_with_labels( |
|
|
image: Image.Image, |
|
|
xyxy_boxes: List[List[float]], |
|
|
box_labels: List[bool] |
|
|
) -> Image.Image: |
|
|
""" |
|
|
在图像上绘制带颜色和标签的框。 |
|
|
|
|
|
Args: |
|
|
image: 原始 PIL 图像 |
|
|
xyxy_boxes: 框坐标列表 [[x_min, y_min, x_max, y_max], ...] |
|
|
box_labels: 框标签列表 [True/False, ...],True=正框(绿色),False=负框(红色) |
|
|
|
|
|
Returns: |
|
|
带有彩色框和标签的图像 |
|
|
""" |
|
|
if image is None: |
|
|
return None |
|
|
|
|
|
|
|
|
img_draw = image.copy() |
|
|
draw = ImageDraw.Draw(img_draw) |
|
|
|
|
|
|
|
|
try: |
|
|
font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 16) |
|
|
    except Exception:
|
|
try: |
|
|
font = ImageFont.truetype("Arial.ttf", 16) |
|
|
        except Exception:
|
|
font = ImageFont.load_default() |
|
|
|
|
|
for i, (box, label) in enumerate(zip(xyxy_boxes, box_labels)): |
|
|
x_min, y_min, x_max, y_max = [int(coord) for coord in box] |
|
|
|
|
|
|
|
|
if label: |
|
|
color = (255, 0, 0) |
|
|
label_text = f"Box {i}: True (正框)" |
|
|
else: |
|
|
color = (0, 255, 0) |
|
|
label_text = f"Box {i}: False (负框)" |
|
|
|
|
|
|
|
|
draw.rectangle([x_min, y_min, x_max, y_max], outline=color, width=3) |
|
|
|
|
|
|
|
|
text_bbox = draw.textbbox((x_min, y_min - 20), label_text, font=font) |
|
|
|
|
|
text_y = max(0, y_min - 22) |
|
|
if text_y == 0: |
|
|
text_y = y_max + 2 |
|
|
|
|
|
text_bbox = draw.textbbox((x_min, text_y), label_text, font=font) |
|
|
draw.rectangle(text_bbox, fill=color) |
|
|
draw.text((x_min, text_y), label_text, fill="white", font=font) |
|
|
|
|
|
return img_draw |
|
|
|
|
|
|
|
|
def process_imageprompter_data( |
|
|
data: Any, |
|
|
box_mode_history: List[Tuple[int, str]] = None, |
|
|
verbose: bool = False |
|
|
) -> Tuple[List[List[float]], List[bool]]: |
|
|
""" |
|
|
处理 ImagePrompter 数据,提取框坐标 (XYXY 格式) 和对应的标签(正/负框)。 |
|
|
|
|
|
ImagePrompter 返回格式: |
|
|
{'image': <PIL Image>, 'points': [[x1, y1, label1, x2, y2, label2], ...]} |
|
|
|
|
|
Args: |
|
|
data: ImagePrompter 返回的数据字典 |
|
|
box_mode_history: 框模式切换历史列表,格式为 [(框索引, 模式), ...] |
|
|
例如 [(0, "positive"), (2, "negative")] 表示第0个框开始是正框,第2个框开始是负框 |
|
|
如果为 None 或空,则所有框默认为正框 |
|
|
verbose: 是否输出详细调试日志 |
|
|
|
|
|
Returns: |
|
|
tuple: (xyxy_boxes, box_labels) |
|
|
- xyxy_boxes: 框坐标列表 [[x_min, y_min, x_max, y_max], ...] |
|
|
- box_labels: 框标签列表 [True/False, ...],True=正框,False=负框 |
|
|
""" |
|
|
if data is None or not isinstance(data, dict): |
|
|
return [], [] |
|
|
|
|
|
xyxy_boxes = [] |
|
|
|
|
|
if verbose: |
|
|
print(f"\n--- Shape Parsing Debug START ---") |
|
|
print(f"Debug: Data keys = {list(data.keys())}") |
|
|
print(f"Debug: Box mode history = {box_mode_history}") |
|
|
|
|
|
|
|
|
|
|
|
if 'points' in data and data['points'] is not None: |
|
|
points_list = data['points'] |
|
|
|
|
|
for i, points in enumerate(points_list): |
|
|
if isinstance(points, (list, np.ndarray)) and len(points) >= 6: |
|
|
try: |
|
|
|
|
|
x1 = float(points[0]) |
|
|
y1 = float(points[1]) |
|
|
x2 = float(points[3]) |
|
|
y2 = float(points[4]) |
|
|
|
|
|
|
|
|
x_min = min(x1, x2) |
|
|
x_max = max(x1, x2) |
|
|
y_min = min(y1, y2) |
|
|
y_max = max(y1, y2) |
|
|
|
|
|
box = [x_min, y_min, x_max, y_max] |
|
|
xyxy_boxes.append(box) |
|
|
|
|
|
except (ValueError, TypeError, IndexError): |
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
box_labels = [] |
|
|
current_mode = "positive" |
|
|
|
|
|
|
|
|
mode_switch_points = {} |
|
|
if box_mode_history: |
|
|
for box_idx, mode in box_mode_history: |
|
|
mode_switch_points[box_idx] = mode |
|
|
|
|
|
for i in range(len(xyxy_boxes)): |
|
|
|
|
|
if i in mode_switch_points: |
|
|
current_mode = mode_switch_points[i] |
|
|
|
|
|
is_positive = (current_mode == "positive") |
|
|
box_labels.append(is_positive) |
|
|
|
|
|
if verbose: |
|
|
print(f"Total boxes: {len(xyxy_boxes)} (正框: {sum(box_labels) if box_labels else 0}, 负框: {len(box_labels) - sum(box_labels) if box_labels else 0})") |
|
|
print(f"--- Shape Parsing Debug END ---\n") |
|
|
|
|
|
return xyxy_boxes, box_labels |
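# Usage sketch (hypothetical values): two boxes drawn in the prompter, with the mode switched to
# "negative" before the second box was drawn:
#   data = {"image": img, "points": [[10, 10, 2, 50, 50, 3], [60, 60, 2, 90, 90, 3]]}
#   boxes, labels = process_imageprompter_data(data, [(0, "positive"), (1, "negative")])
#   # boxes  -> [[10.0, 10.0, 50.0, 50.0], [60.0, 60.0, 90.0, 90.0]]
#   # labels -> [True, False]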
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_boxes_to_image( |
|
|
image_pil: Image, |
|
|
tgt: Dict, |
|
|
return_point: bool = False, |
|
|
point_width: float = 1.0, |
|
|
return_score=True, |
|
|
) -> Image: |
|
|
"""Plot bounding boxes and labels on an image.""" |
|
|
boxes = tgt["boxes"] |
|
|
scores = tgt["scores"] |
|
|
|
|
|
draw = ImageDraw.Draw(image_pil) |
|
|
mask = Image.new("L", image_pil.size, 0) |
|
|
mask_draw = ImageDraw.Draw(mask) |
|
|
|
|
|
for box, score in zip(boxes, scores): |
|
|
color = tuple(np.random.randint(0, 255, size=3).tolist()) |
|
|
x0, y0, x1, y1 = box |
|
|
x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1) |
|
|
|
|
|
if return_point: |
|
|
center_x = int((x0 + x1) / 2) |
|
|
center_y = int((y0 + y1) / 2) |
|
|
draw.ellipse( |
|
|
( |
|
|
center_x - point_width, |
|
|
center_y - point_width, |
|
|
center_x + point_width, |
|
|
center_y + point_width, |
|
|
), |
|
|
fill=color, |
|
|
width=point_width, |
|
|
) |
|
|
else: |
|
|
draw.rectangle([x0, y0, x1, y1], outline=color, width=int(point_width)) |
|
|
|
|
|
if return_score: |
|
|
text = f"{score:.2f}" |
|
|
else: |
|
|
text = f"" |
|
|
font = ImageFont.load_default() |
|
|
if hasattr(font, "getbbox"): |
|
|
bbox = draw.textbbox((x0, y0), text, font) |
|
|
else: |
|
|
w, h = draw.textsize(text, font) |
|
|
bbox = (x0, y0, w + x0, y0 + h) |
|
|
if not return_point: |
|
|
draw.rectangle(bbox, fill=color) |
|
|
draw.text((x0, y0), text, fill="white") |
|
|
|
|
|
mask_draw.rectangle([x0, y0, x1, y1], fill=255, width=6) |
|
|
return image_pil, mask |
|
|
|
|
|
|
|
|
def parse_visual_prompt(points: List): |
|
|
"""Parse visual prompt points to bounding boxes (XYXY format)""" |
|
|
boxes = [] |
|
|
pos_points = [] |
|
|
neg_points = [] |
|
|
for point in points: |
|
|
if point[2] == 2 and point[-1] == 3: |
|
|
x1, y1, _, x2, y2, _ = point |
|
|
boxes.append([x1, y1, x2, y2]) |
|
|
elif point[2] == 1 and point[-1] == 4: |
|
|
x, y, _, _, _, _ = point |
|
|
pos_points.append([x, y]) |
|
|
elif point[2] == 0 and point[-1] == 4: |
|
|
x, y, _, _, _, _ = point |
|
|
neg_points.append([x, y]) |
|
|
return boxes, pos_points, neg_points |
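# Encoding assumed above: an entry with flags (2, 3) is a drawn box [x1, y1, 2, x2, y2, 3];
# entries with flags (1, 4) / (0, 4) are positive / negative click points respectively.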
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bpe_path = None |
|
|
sam3_checkpoint = None |
|
|
example_image_hf_path = None |
|
|
|
|
|
def download_resources_from_hf(): |
|
|
"""从 Hugging Face Hub 下载模型和资源文件""" |
|
|
global bpe_path, sam3_checkpoint, example_image_hf_path |
|
|
|
|
|
if not SAM3_INSTALLED: |
|
|
print("❌ sam3 库未安装,无法下载资源") |
|
|
return False |
|
|
|
|
|
try: |
|
|
|
|
|
bpe_path = hf_hub_download( |
|
|
repo_id=SAM3_HF_REPO_ID, |
|
|
filename="assets/bpe_simple_vocab_16e6.txt.gz", |
|
|
cache_dir=os.environ.get("HF_HOME", None) |
|
|
) |
|
|
print(f"✅ BPE 词汇表: {bpe_path}") |
|
|
except Exception as e: |
|
|
print(f"⚠️ 无法下载 BPE 词汇表: {e}") |
|
|
|
|
|
if sam3 is not None: |
|
|
sam3_root = os.path.join(os.path.dirname(sam3.__file__), "..") |
|
|
bpe_path = os.path.join(sam3_root, "assets", "bpe_simple_vocab_16e6.txt.gz") |
|
|
if not os.path.exists(bpe_path): |
|
|
bpe_path = None |
|
|
|
|
|
try: |
|
|
|
|
|
|
|
|
env_checkpoint = os.environ.get("SAM3_CHECKPOINT_PATH") |
|
|
if env_checkpoint and os.path.exists(env_checkpoint): |
|
|
sam3_checkpoint = env_checkpoint |
|
|
print(f"✅ 使用环境变量指定的模型: {sam3_checkpoint}") |
|
|
else: |
|
|
|
|
|
sam3_checkpoint = hf_hub_download( |
|
|
repo_id=SAM3_HF_REPO_ID, |
|
|
filename="checkpoints/sam3.pt", |
|
|
cache_dir=os.environ.get("HF_HOME", None) |
|
|
) |
|
|
print(f"✅ 模型检查点: {sam3_checkpoint}") |
|
|
except Exception as e: |
|
|
print(f"⚠️ 无法下载模型检查点: {e}") |
|
|
|
|
|
try: |
|
|
sam3_checkpoint = hf_hub_download( |
|
|
repo_id=SAM3_HF_REPO_ID, |
|
|
filename="sam3.pt", |
|
|
cache_dir=os.environ.get("HF_HOME", None) |
|
|
) |
|
|
print(f"✅ 模型检查点(备选): {sam3_checkpoint}") |
|
|
        except Exception:
|
|
sam3_checkpoint = None |
|
|
|
|
|
try: |
|
|
|
|
|
example_image_hf_path = hf_hub_download( |
|
|
repo_id=SAM3_HF_REPO_ID, |
|
|
filename="assets/images/test_image.jpg", |
|
|
cache_dir=os.environ.get("HF_HOME", None) |
|
|
) |
|
|
print(f"✅ 示例图片: {example_image_hf_path}") |
|
|
except Exception as e: |
|
|
print(f"⚠️ 无法下载示例图片: {e}") |
|
|
example_image_hf_path = None |
|
|
|
|
|
return bpe_path is not None and sam3_checkpoint is not None |
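# Resolution order used above: an explicit SAM3_CHECKPOINT_PATH, then checkpoints/sam3.pt in the repo,
# then sam3.pt at the repo root. The function returns True only when both the BPE vocabulary and a
# checkpoint were found.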
|
|
|
|
|
|
|
|
print(f"\n{'='*50}") |
|
|
print(f"正在从 Hugging Face Hub 下载资源...") |
|
|
print(f"仓库 ID: {SAM3_HF_REPO_ID}") |
|
|
print(f"{'='*50}\n") |
|
|
download_resources_from_hf() |
|
|
|
|
|
|
|
|
DEVICE = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
print(f"Using device: {DEVICE}") |
|
|
|
|
|
|
|
|
model = None |
|
|
processor = None |
|
|
autocast_ctx = None |
|
|
|
|
|
|
|
|
def load_model(): |
|
|
"""延迟加载模型(支持 ZeroGPU)""" |
|
|
global model, processor, autocast_ctx |
|
|
|
|
|
if model is not None: |
|
|
return True |
|
|
|
|
|
if not SAM3_INSTALLED: |
|
|
print("❌ sam3 库未安装") |
|
|
return False |
|
|
|
|
|
if sam3_checkpoint is None: |
|
|
print("❌ 模型检查点路径未配置") |
|
|
return False |
|
|
|
|
|
if bpe_path is None: |
|
|
print("❌ BPE 词汇表路径未配置") |
|
|
return False |
|
|
|
|
|
try: |
|
|
if DEVICE == "cuda": |
|
|
torch.backends.cuda.matmul.allow_tf32 = True |
|
|
torch.backends.cudnn.allow_tf32 = True |
|
|
autocast_ctx = torch.autocast("cuda", dtype=torch.bfloat16) |
|
|
autocast_ctx.__enter__() |
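            # The bfloat16 autocast context is entered here and never exited, so all later
            # inference in this process also runs under autocast.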
|
|
model = build_sam3_image_model(bpe_path=bpe_path, checkpoint_path=sam3_checkpoint).to(DEVICE) |
|
|
else: |
|
|
autocast_ctx = None |
|
|
model = build_sam3_image_model(bpe_path=bpe_path, checkpoint_path=sam3_checkpoint).to(DEVICE) |
|
|
|
|
|
processor = Sam3Processor(model, confidence_threshold=0.5) |
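        # Assumption: Sam3Processor drops detections scoring below confidence_threshold before returning results.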
|
|
print("✅ 模型加载成功") |
|
|
return True |
|
|
|
|
|
except Exception as e: |
|
|
print(f"❌ 模型加载失败: {e}") |
|
|
import traceback |
|
|
traceback.print_exc() |
|
|
model = None |
|
|
processor = None |
|
|
return False |
|
|
|
|
|
|
|
|
|
|
|
if not SPACES_GPU_AVAILABLE: |
|
|
load_model() |
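# Without the spaces module the model is loaded eagerly at startup; under ZeroGPU it is instead
# loaded lazily inside the GPU-decorated call.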
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_to_pil(fig): |
|
|
"""将 Matplotlib 图形转换为 PIL Image。""" |
|
|
buf = io.BytesIO() |
|
|
fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0) |
|
|
buf.seek(0) |
|
|
plt.close(fig) |
|
|
return Image.open(buf).convert("RGB") |
|
|
|
|
|
|
|
|
def get_result_figure( |
|
|
img: Image.Image, |
|
|
results: dict, |
|
|
return_point: bool = False, |
|
|
point_width: float = 3.0, |
|
|
return_score: bool = True |
|
|
) -> Tuple[plt.Figure, int]: |
|
|
"""封装原始 plot_results 逻辑,支持显示中心点和置信度控制。 |
|
|
|
|
|
Args: |
|
|
img: 输入图像 |
|
|
results: 推理结果字典 |
|
|
return_point: 是否显示中心点而不是边框 |
|
|
point_width: 中心点或边框线宽 |
|
|
return_score: 是否显示置信度分数 |
|
|
""" |
|
|
fig = plt.figure(figsize=(12, 8)) |
|
|
plt.imshow(img) |
|
|
plt.axis("off") |
|
|
|
|
|
nb_objects = len(results.get("scores", [])) |
|
|
print(f"found {nb_objects} object(s)") |
|
|
|
|
|
for i in range(nb_objects): |
|
|
color = COLORS[i % len(COLORS)] |
|
|
|
|
|
if "masks" in results and i < len(results["masks"]): |
|
|
mask_data = results["masks"][i] |
|
|
if mask_data.ndim == 3: |
|
|
mask_data = mask_data.squeeze(0) |
|
|
plot_mask(mask_data.cpu(), color=color) |
|
|
|
|
|
if "boxes" in results and i < len(results["boxes"]): |
|
|
w, h = img.size |
|
|
box = results["boxes"][i].cpu().tolist() |
|
|
prob = results["scores"][i].item() |
|
|
|
|
|
|
|
|
if return_point: |
|
|
|
|
|
x0, y0, x1, y1 = box |
|
|
center_x = (x0 + x1) / 2 |
|
|
center_y = (y0 + y1) / 2 |
|
|
circle = plt.Circle( |
|
|
(center_x, center_y), |
|
|
point_width * 2, |
|
|
color=color, |
|
|
fill=True |
|
|
) |
|
|
plt.gca().add_patch(circle) |
|
|
|
|
|
|
|
|
if return_score: |
|
|
plt.text( |
|
|
center_x + point_width * 3, |
|
|
center_y, |
|
|
f"{prob:.2f}", |
|
|
color=color, |
|
|
fontsize=10, |
|
|
fontweight='bold' |
|
|
) |
|
|
else: |
|
|
|
|
|
text = f"(id={i}, {prob:.2f})" if return_score else f"(id={i})" |
|
|
plot_bbox( |
|
|
h, |
|
|
w, |
|
|
results["boxes"][i].cpu(), |
|
|
text=text, |
|
|
box_format="XYXY", |
|
|
color=color, |
|
|
relative_coords=False, |
|
|
) |
|
|
|
|
|
return fig, nb_objects |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def sam3_segmentation_core( |
|
|
unified_image_input: Any, |
|
|
prompt_text: str, |
|
|
box_type: str, |
|
|
box_mode_history: List[Tuple[int, str]], |
|
|
return_point: bool = False, |
|
|
point_width: float = 3.0, |
|
|
return_score: bool = True |
|
|
): |
|
|
"""核心分割函数""" |
|
|
global model, processor |
|
|
|
|
|
|
|
|
if not SAM3_INSTALLED: |
|
|
return None, "❌ sam3 库未安装,请检查 requirements.txt 或 HF 仓库配置。", box_mode_history |
|
|
|
|
|
if model is None or processor is None: |
|
|
if not load_model(): |
|
|
return None, "❌ 模型未加载,请检查模型配置。可能需要设置 SAM3_HF_REPO_ID 环境变量。", box_mode_history |
|
|
|
|
|
|
|
|
image = None |
|
|
visual_prompter_data = None |
|
|
|
|
|
if IMAGE_PROMPTER_AVAILABLE and isinstance(unified_image_input, dict): |
|
|
image = unified_image_input.get('image') |
|
|
visual_prompter_data = unified_image_input |
|
|
else: |
|
|
image = unified_image_input |
|
|
visual_prompter_data = None |
|
|
|
|
|
if image is None: |
|
|
return None, "请上传图像。", box_mode_history |
|
|
|
|
|
img0 = image.copy() |
|
|
width, height = img0.size |
|
|
|
|
|
|
|
|
try: |
|
|
inference_state = processor.set_image(img0) |
|
|
except Exception as e: |
|
|
return None, f"图像处理失败: {e}", box_mode_history |
|
|
|
|
|
|
|
|
processor.reset_all_prompts(inference_state) |
|
|
found_objects = 0 |
|
|
xyxy_boxes = [] |
|
|
|
|
|
|
|
|
if box_type == "Text": |
|
|
if not prompt_text: |
|
|
return None, "文本模式下,请提供文本提示。", box_mode_history |
|
|
inference_state = processor.set_text_prompt( |
|
|
state=inference_state, |
|
|
prompt=prompt_text |
|
|
) |
|
|
caption_base = "文本提示分割" |
|
|
|
|
|
|
|
|
elif box_type in ["Single Box", "Multi Box"]: |
|
|
|
|
|
if not IMAGE_PROMPTER_AVAILABLE: |
|
|
return None, "当前环境不支持 ImagePrompter,Box 模式无法运行。", box_mode_history |
|
|
|
|
|
if visual_prompter_data: |
|
|
|
|
|
xyxy_boxes, box_labels = process_imageprompter_data(visual_prompter_data, box_mode_history, verbose=True) |
|
|
print(f"Boxes: {xyxy_boxes}") |
|
|
print(f"Labels: {box_labels}") |
|
|
|
|
|
if not xyxy_boxes: |
|
|
return None, f"请在图像上绘制至少一个矩形框作为提示(当前模式: {box_type})。", box_mode_history |
|
|
|
|
|
|
|
|
if box_type == "Single Box" and len(xyxy_boxes) > 1: |
|
|
xyxy_boxes = [xyxy_boxes[0]] |
|
|
box_labels = [box_labels[0]] if box_labels else [True] |
|
|
|
|
|
box_inputs = [] |
|
|
|
|
|
for i, (x_min, y_min, x_max, y_max) in enumerate(xyxy_boxes): |
|
|
x = x_min |
|
|
y = y_min |
|
|
w = x_max - x_min |
|
|
h = y_max - y_min |
|
|
box_inputs.append([x, y, w, h]) |
|
|
|
|
|
|
|
|
try: |
|
|
box_input_xywh = torch.tensor(box_inputs, dtype=torch.float32).view(-1, 4).to(DEVICE) |
|
|
box_input_cxcywh = box_xywh_to_cxcywh(box_input_xywh) |
|
|
norm_boxes_cxcywh = normalize_bbox(box_input_cxcywh, width, height).tolist() |
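            # Coordinate pipeline: pixel XYXY -> XYWH -> CXCYWH -> normalised [0, 1] coordinates,
            # which is the format passed to add_geometric_prompt below.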
|
|
|
|
|
for i in range(len(box_inputs)): |
|
|
norm_box = norm_boxes_cxcywh[i] |
|
|
label = box_labels[i] if i < len(box_labels) else True |
|
|
label_str = "正框" if label else "负框" |
|
|
print(f"Adding box {i}: {norm_box}, label={label} ({label_str})") |
|
|
|
|
|
|
|
|
inference_state = processor.add_geometric_prompt( |
|
|
state=inference_state, box=norm_box, label=label |
|
|
) |
|
|
|
|
|
except Exception as e: |
|
|
print(f"Error during box conversion/prompt setting: {e}") |
|
|
return None, f"框提示处理失败: {e}", box_mode_history |
|
|
|
|
|
num_positive = sum(box_labels) if box_labels else len(xyxy_boxes) |
|
|
num_negative = len(xyxy_boxes) - num_positive |
|
|
caption_base = f"使用 {len(xyxy_boxes)} 个提示框分割(正框: {num_positive}, 负框: {num_negative})" |
|
|
|
|
|
else: |
|
|
return None, "请选择有效的提示类型 (Text, Single Box, 或 Multi Box)。", box_mode_history |
|
|
|
|
|
|
|
|
fig, found_objects = get_result_figure( |
|
|
img0.copy(), |
|
|
inference_state, |
|
|
return_point=return_point, |
|
|
point_width=point_width, |
|
|
return_score=return_score |
|
|
) |
|
|
result_image = plot_to_pil(fig) |
|
|
|
|
|
return result_image, f"{caption_base}。找到 {found_objects} 个对象。", box_mode_history |
|
|
|
|
|
|
|
|
|
|
|
if SPACES_GPU_AVAILABLE: |
|
|
@spaces.GPU |
|
|
def sam3_segmentation( |
|
|
unified_image_input: Any, |
|
|
prompt_text: str, |
|
|
box_type: str, |
|
|
box_mode_history: List[Tuple[int, str]], |
|
|
return_point: bool = False, |
|
|
point_width: float = 3.0, |
|
|
return_score: bool = True |
|
|
): |
|
|
"""ZeroGPU 版本的推理函数""" |
|
|
return sam3_segmentation_core( |
|
|
unified_image_input, prompt_text, box_type, |
|
|
box_mode_history, return_point, point_width, return_score |
|
|
) |
|
|
else: |
|
|
def sam3_segmentation( |
|
|
unified_image_input: Any, |
|
|
prompt_text: str, |
|
|
box_type: str, |
|
|
box_mode_history: List[Tuple[int, str]], |
|
|
return_point: bool = False, |
|
|
point_width: float = 3.0, |
|
|
return_score: bool = True |
|
|
): |
|
|
"""标准版本的推理函数""" |
|
|
return sam3_segmentation_core( |
|
|
unified_image_input, prompt_text, box_type, |
|
|
box_mode_history, return_point, point_width, return_score |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def on_box_mode_change( |
|
|
new_mode: str, |
|
|
unified_image_input: Any, |
|
|
current_history: List[Tuple[int, str]] |
|
|
) -> Tuple[List[Tuple[int, str]], str, Image.Image]: |
|
|
""" |
|
|
当用户切换框模式时,记录当前框数量和新模式,并更新预览。 |
|
|
|
|
|
Args: |
|
|
new_mode: 新选择的模式 ("正框 (Positive)" 或 "负框 (Negative)") |
|
|
unified_image_input: 当前 ImagePrompter 的数据 |
|
|
current_history: 当前的模式切换历史 |
|
|
|
|
|
Returns: |
|
|
tuple: (更新后的历史, 状态信息文本, 预览图像) |
|
|
""" |
|
|
if current_history is None: |
|
|
current_history = [] |
|
|
|
|
|
|
|
|
current_box_count = 0 |
|
|
if unified_image_input and isinstance(unified_image_input, dict): |
|
|
points = unified_image_input.get('points', []) |
|
|
if points: |
|
|
current_box_count = len(points) |
|
|
|
|
|
|
|
|
mode_internal = "positive" if "Positive" in new_mode or "正框" in new_mode else "negative" |
|
|
|
|
|
|
|
|
|
|
|
new_history = current_history.copy() |
|
|
new_history.append((current_box_count, mode_internal)) |
|
|
|
|
|
|
|
|
mode_display = "正框" if mode_internal == "positive" else "负框" |
|
|
status = f"✅ 已切换到 {mode_display} 模式。从第 {current_box_count + 1} 个框开始将被标记为{mode_display}。" |
|
|
|
|
|
print(f"Box mode changed: {new_mode} -> {mode_internal}, at box index {current_box_count}") |
|
|
print(f"Updated history: {new_history}") |
|
|
|
|
|
|
|
|
preview_image = None |
|
|
if unified_image_input and isinstance(unified_image_input, dict): |
|
|
image = unified_image_input.get('image') |
|
|
if image is not None: |
|
|
xyxy_boxes, box_labels = process_imageprompter_data(unified_image_input, new_history, verbose=False) |
|
|
if xyxy_boxes: |
|
|
preview_image = draw_boxes_with_labels(image, xyxy_boxes, box_labels) |
|
|
else: |
|
|
preview_image = image |
|
|
|
|
|
return new_history, status, preview_image |
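# The history is a list of (first_box_index, mode) pairs; process_imageprompter_data replays it to
# assign a positive/negative label to every drawn box.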
|
|
|
|
|
|
|
|
def reset_box_mode_history( |
|
|
unified_image_input: Any |
|
|
) -> Tuple[List[Tuple[int, str]], str, Image.Image]: |
|
|
"""重置框模式历史并更新预览""" |
|
|
new_history = [(0, "positive")] |
|
|
status = "已重置,所有框将默认为正框。" |
|
|
|
|
|
|
|
|
preview_image = None |
|
|
if unified_image_input and isinstance(unified_image_input, dict): |
|
|
image = unified_image_input.get('image') |
|
|
if image is not None: |
|
|
xyxy_boxes, box_labels = process_imageprompter_data(unified_image_input, new_history, verbose=False) |
|
|
if xyxy_boxes: |
|
|
preview_image = draw_boxes_with_labels(image, xyxy_boxes, box_labels) |
|
|
else: |
|
|
preview_image = image |
|
|
|
|
|
return new_history, status, preview_image |
|
|
|
|
|
|
|
|
def get_current_box_status( |
|
|
unified_image_input: Any, |
|
|
box_mode_history: List[Tuple[int, str]] |
|
|
) -> str: |
|
|
"""获取当前框的状态信息""" |
|
|
if not unified_image_input or not isinstance(unified_image_input, dict): |
|
|
return "尚未绘制框" |
|
|
|
|
|
points = unified_image_input.get('points', []) |
|
|
if not points: |
|
|
return "尚未绘制框" |
|
|
|
|
|
num_boxes = len(points) |
|
|
|
|
|
|
|
|
if not box_mode_history: |
|
|
return f"已绘制 {num_boxes} 个框(全部为正框)" |
|
|
|
|
|
|
|
|
mode_switch_points = {} |
|
|
for box_idx, mode in box_mode_history: |
|
|
mode_switch_points[box_idx] = mode |
|
|
|
|
|
current_mode = "positive" |
|
|
positive_count = 0 |
|
|
negative_count = 0 |
|
|
|
|
|
for i in range(num_boxes): |
|
|
if i in mode_switch_points: |
|
|
current_mode = mode_switch_points[i] |
|
|
if current_mode == "positive": |
|
|
positive_count += 1 |
|
|
else: |
|
|
negative_count += 1 |
|
|
|
|
|
return f"已绘制 {num_boxes} 个框(正框: {positive_count}, 负框: {negative_count})" |
|
|
|
|
|
|
|
|
def update_box_preview( |
|
|
unified_image_input: Any, |
|
|
box_mode_history: List[Tuple[int, str]] |
|
|
) -> Tuple[Image.Image, str, str]: |
|
|
""" |
|
|
更新框预览图像,显示带颜色和标签的框。 |
|
|
|
|
|
Args: |
|
|
unified_image_input: ImagePrompter 的数据 |
|
|
box_mode_history: 框模式历史 |
|
|
|
|
|
Returns: |
|
|
tuple: (预览图像, 状态文本, 框提示参数文本) |
|
|
""" |
|
|
|
|
|
status_text = get_current_box_status(unified_image_input, box_mode_history) |
|
|
|
|
|
|
|
|
if not unified_image_input or not isinstance(unified_image_input, dict): |
|
|
return None, status_text, "None" |
|
|
|
|
|
image = unified_image_input.get('image') |
|
|
if image is None: |
|
|
return None, status_text, "None" |
|
|
|
|
|
|
|
|
xyxy_boxes, box_labels = process_imageprompter_data(unified_image_input, box_mode_history, verbose=False) |
|
|
|
|
|
if not xyxy_boxes: |
|
|
return image, status_text, "None" |
|
|
|
|
|
|
|
|
preview_image = draw_boxes_with_labels(image, xyxy_boxes, box_labels) |
|
|
|
|
|
|
|
|
boxes_int = [[int(coord) for coord in box] for box in xyxy_boxes] |
|
|
if len(xyxy_boxes) == 1: |
|
|
prompt_info_text = f"Box: {boxes_int[0]}\nLabel: {box_labels[0]}" |
|
|
else: |
|
|
prompt_info_text = f"Boxes: {boxes_int}\nLabels: {box_labels}" |
|
|
|
|
|
return preview_image, status_text, prompt_info_text |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
example_image_path = None |
|
|
example_image = None |
|
|
|
|
|
|
|
|
if example_image_hf_path and os.path.exists(example_image_hf_path): |
|
|
example_image_path = example_image_hf_path |
|
|
example_image = Image.open(example_image_hf_path) |
|
|
print(f"✅ 使用 HF Hub 下载的示例图片: {example_image_path}") |
|
|
else: |
|
|
|
|
|
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) |
|
|
sam3_asset_path = os.path.join(SCRIPT_DIR, "assets", "images", "test_image.jpg") |
|
|
|
|
|
if os.path.exists(sam3_asset_path): |
|
|
example_image_path = os.path.abspath(sam3_asset_path) |
|
|
example_image = Image.open(sam3_asset_path) |
|
|
print(f"✅ 使用本地示例图片: {example_image_path}") |
|
|
elif sam3 is not None: |
|
|
|
|
|
sam3_root = os.path.join(os.path.dirname(sam3.__file__), "..") |
|
|
sam3_asset_path = os.path.join(sam3_root, "assets", "images", "test_image.jpg") |
|
|
if os.path.exists(sam3_asset_path): |
|
|
example_image_path = os.path.abspath(sam3_asset_path) |
|
|
example_image = Image.open(sam3_asset_path) |
|
|
print(f"✅ 使用 sam3 模块示例图片: {example_image_path}") |
|
|
|
|
|
if example_image is None: |
|
|
print(f"⚠️ 示例图片未找到,使用占位图") |
|
|
example_image_path = None |
|
|
example_image = Image.new('RGB', (512, 512), color='lightgray') |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
example_prompts_info = { |
|
|
"Text": "None", |
|
|
"Single Box": "Box: [487, 302, 591, 641]\nLabel: True", |
|
|
"Multi Box": "Boxes: [[487, 302, 591, 641], [341, 275, 495, 662]]\nLabels: [False, True]" |
|
|
} |
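# The example "points" below use the raw ImagePrompter box encoding [x1, y1, 2, x2, y2, 3]
# described in process_imageprompter_data and parse_visual_prompt.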
|
|
|
|
|
if IMAGE_PROMPTER_AVAILABLE: |
|
|
|
|
|
if example_image_path: |
|
|
example_data_corrected = [ |
|
|
[{"image": example_image_path, "points": []}, "Text", example_prompts_info["Text"], "shoe"], |
|
|
[{"image": example_image_path, "points": [[487.0, 302.0, 2, 591.0, 641.0, 3]]}, "Single Box", example_prompts_info["Single Box"], ""], |
|
|
[{"image": example_image_path, "points": [[487.0, 302.0, 2, 591.0, 641.0, 3], [341.0, 275.0, 2, 495.0, 662.0, 3]]}, "Multi Box", example_prompts_info["Multi Box"], ""], |
|
|
] |
|
|
else: |
|
|
example_data_corrected = [ |
|
|
[{"image": example_image, "points": []}, "Text", example_prompts_info["Text"], "shoe"], |
|
|
[{"image": example_image, "points": [[487.0, 302.0, 2, 591.0, 641.0, 3]]}, "Single Box", example_prompts_info["Single Box"], ""], |
|
|
[{"image": example_image, "points": [[487.0, 302.0, 2, 591.0, 641.0, 3], [341.0, 275.0, 2, 495.0, 662.0, 3]]}, "Multi Box", example_prompts_info["Multi Box"], ""], |
|
|
] |
|
|
|
|
|
example_multi_box_history = [(0, "negative"), (1, "positive")] |
|
|
else: |
|
|
|
|
|
if example_image_path: |
|
|
example_data_corrected = [ |
|
|
[example_image_path, "Text", example_prompts_info["Text"], "shoe"], |
|
|
[example_image_path, "Single Box", example_prompts_info["Single Box"], ""], |
|
|
[example_image_path, "Multi Box", example_prompts_info["Multi Box"], ""], |
|
|
] |
|
|
else: |
|
|
example_data_corrected = [ |
|
|
[example_image, "Text", example_prompts_info["Text"], "shoe"], |
|
|
[example_image, "Single Box", example_prompts_info["Single Box"], ""], |
|
|
[example_image, "Multi Box", example_prompts_info["Multi Box"], ""], |
|
|
] |
|
|
example_multi_box_history = [(0, "positive")] |
|
|
|
|
|
|
|
|
def on_example_select( |
|
|
unified_image_input: Any, |
|
|
prompt_type: str |
|
|
) -> Tuple[List[Tuple[int, str]], Image.Image, str]: |
|
|
""" |
|
|
当用户选择示例时,自动更新框模式历史和预览。 |
|
|
|
|
|
Args: |
|
|
unified_image_input: ImagePrompter 的数据 |
|
|
prompt_type: 提示类型 (Text, Single Box, Multi Box) |
|
|
|
|
|
Returns: |
|
|
tuple: (框模式历史, 预览图像, 状态文本) |
|
|
""" |
|
|
|
|
|
if prompt_type == "Multi Box": |
|
|
|
|
|
box_history = [(0, "positive"), (1, "negative")] |
|
|
elif prompt_type == "Single Box": |
|
|
|
|
|
box_history = [(0, "positive")] |
|
|
else: |
|
|
|
|
|
box_history = [(0, "positive")] |
|
|
|
|
|
|
|
|
preview_image = None |
|
|
status_text = "尚未绘制框" |
|
|
|
|
|
if unified_image_input and isinstance(unified_image_input, dict): |
|
|
image = unified_image_input.get('image') |
|
|
if image is not None: |
|
|
xyxy_boxes, box_labels = process_imageprompter_data(unified_image_input, box_history, verbose=False) |
|
|
if xyxy_boxes: |
|
|
preview_image = draw_boxes_with_labels(image, xyxy_boxes, box_labels) |
|
|
num_positive = sum(box_labels) |
|
|
num_negative = len(box_labels) - num_positive |
|
|
status_text = f"已绘制 {len(xyxy_boxes)} 个框(正框: {num_positive}, 负框: {num_negative})" |
|
|
else: |
|
|
preview_image = image |
|
|
|
|
|
return box_history, preview_image, status_text |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
with gr.Blocks() as demo: |
|
|
|
|
|
box_mode_history_state = gr.State([(0, "positive")]) |
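    # Per-session state: a list of (box_index, mode) switch points, seeded so every box starts as a positive box.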
|
|
|
|
|
gr.Markdown( |
|
|
""" |
|
|
# 🎯 SAM 3 Demo |
|
|
**Segment Anything Model 3 - 目标检测与分割** |
|
|
|
|
|
> 🚀 Powered by Hugging Face Spaces |
|
|
""" |
|
|
) |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(scale=1): |
|
|
|
|
|
with gr.Accordion("📋 使用说明", open=False): |
|
|
gr.Markdown(""" |
|
|
**使用方法:** |
|
|
|
|
|
📝 **Text 模式** |
|
|
1. 选择 "Text" 模式 |
|
|
2. 上传图像 |
|
|
3. 输入文本提示词(如 "shoe", "person") |
|
|
4. 点击"运行 SAM 3 分割" |
|
|
|
|
|
⬜ **Single Box 模式** |
|
|
1. 选择 "Single Box" 模式 |
|
|
2. 上传图像 |
|
|
3. 在图像上绘制一个矩形框 |
|
|
4. 点击"运行 SAM 3 分割" |
|
|
|
|
|
🔲 **Multi Box 模式(支持正/负框)** |
|
|
1. 选择 "Multi Box" 模式 |
|
|
2. 上传图像 |
|
|
3. **默认为正框模式**,绘制的框将包含目标 |
|
|
4. 如需绘制负框(排除区域): |
|
|
- 先绘制正框 |
|
|
- 点击切换到 "负框 (Negative)" 模式 |
|
|
- 继续绘制负框 |
|
|
5. 点击「🔄 刷新预览」按钮查看框标签预览 |
|
|
6. 点击"运行 SAM 3 分割" |
|
|
|
|
|
💡 **正框 vs 负框** |
|
|
- **正框(红色)**: 告诉模型"包含这个区域的目标" |
|
|
- **负框(绿色)**: 告诉模型"排除这个区域",用于去除误检 |
|
|
- **注意**: 负框需要配合正框使用才能生效 |
|
|
|
|
|
⚙️ **显示选项** |
|
|
- **显示中心点**: 用圆点代替边框显示检测结果中心位置 |
|
|
- **显示置信度**: 在结果中显示模型的置信度分数 |
|
|
- **线条/点宽度**: 调整边框线宽或中心点大小 |
|
|
""") |
|
|
|
|
|
if IMAGE_PROMPTER_AVAILABLE: |
|
|
unified_image_input = ImagePrompter( |
|
|
label="🖼️ 示例图像", |
|
|
type="pil" |
|
|
) |
|
|
else: |
|
|
unified_image_input = gr.Image( |
|
|
label="🖼️ 示例图像", |
|
|
type="pil" |
|
|
) |
|
|
|
|
|
prompt_type = gr.Radio( |
|
|
["Text", "Single Box", "Multi Box"], |
|
|
label="提示类型", |
|
|
value="Text" |
|
|
) |
|
|
|
|
|
text_prompt_input = gr.Textbox( |
|
|
label="文本提示参数", |
|
|
value="shoe", |
|
|
visible=True |
|
|
) |
|
|
|
|
|
|
|
|
example_prompt_info_display = gr.Textbox( |
|
|
label="框提示参数", |
|
|
value="", |
|
|
interactive=False, |
|
|
lines=2, |
|
|
visible=True |
|
|
) |
|
|
|
|
|
|
|
|
with gr.Group(visible=False) as box_mode_group: |
|
|
gr.Markdown("### 🎯 框模式设置") |
|
|
box_mode_selector = gr.Radio( |
|
|
["正框 (Positive)", "负框 (Negative)"], |
|
|
label="当前绘制模式", |
|
|
value="正框 (Positive)", |
|
|
info="正框=包含目标,负框=排除区域" |
|
|
) |
|
|
with gr.Row(): |
|
|
reset_history_btn = gr.Button("🔄 重置框标签", size="sm") |
|
|
|
|
|
|
|
|
with gr.Group(visible=False) as box_preview_group: |
|
|
gr.Markdown("### 📦 框预览(红色=正框 True,绿色=负框 False)") |
|
|
box_status_text = gr.Textbox( |
|
|
label="框状态", |
|
|
value="尚未绘制框", |
|
|
interactive=False |
|
|
) |
|
|
refresh_preview_btn = gr.Button("🔄 刷新预览", size="sm", variant="secondary") |
|
|
gr.Markdown("*绘制框后点击「刷新预览」按钮查看标注效果*") |
|
|
box_preview_image = gr.Image( |
|
|
label="框标签预览", |
|
|
type="pil", |
|
|
interactive=False |
|
|
) |
|
|
|
|
|
|
|
|
gr.Markdown("### ⚙️ 显示选项") |
|
|
with gr.Row(): |
|
|
return_point = gr.Checkbox(label="显示中心点", value=False) |
|
|
return_score = gr.Checkbox(label="显示置信度", value=True) |
|
|
point_width = gr.Slider( |
|
|
label="线条/点宽度", |
|
|
value=3.0, |
|
|
minimum=0.0, |
|
|
maximum=20.0, |
|
|
step=0.1, |
|
|
) |
|
|
|
|
|
run_button = gr.Button("Run SAM3", variant="primary") |
|
|
|
|
|
with gr.Column(scale=2): |
|
|
output_image = gr.Image(label="分割结果", type="pil") |
|
|
result_info = gr.Textbox(label="结果信息", lines=2) |
|
|
|
|
|
def run_example(img, ptype, prompt_info, text): |
|
|
"""运行示例时使用正确的框模式历史""" |
|
|
if ptype == "Multi Box": |
|
|
|
|
|
history = [(0, "negative"), (1, "positive")] |
|
|
else: |
|
|
history = [(0, "positive")] |
|
|
result_img, result_text, _ = sam3_segmentation(img, text, ptype, history, False, 3.0, True) |
|
|
return result_img, result_text |
|
|
|
|
|
gr.Examples( |
|
|
examples=example_data_corrected, |
|
|
inputs=[unified_image_input, prompt_type, example_prompt_info_display, text_prompt_input], |
|
|
outputs=[output_image, result_info], |
|
|
fn=run_example, |
|
|
cache_examples=False, |
|
|
label="示例" |
|
|
) |
|
|
|
|
|
|
|
|
run_button.click( |
|
|
fn=sam3_segmentation, |
|
|
inputs=[unified_image_input, text_prompt_input, prompt_type, box_mode_history_state, return_point, point_width, return_score], |
|
|
outputs=[output_image, result_info, box_mode_history_state] |
|
|
) |
|
|
|
|
|
|
|
|
box_mode_selector.change( |
|
|
fn=on_box_mode_change, |
|
|
inputs=[box_mode_selector, unified_image_input, box_mode_history_state], |
|
|
outputs=[box_mode_history_state, box_status_text, box_preview_image] |
|
|
) |
|
|
|
|
|
|
|
|
reset_history_btn.click( |
|
|
fn=reset_box_mode_history, |
|
|
inputs=[unified_image_input], |
|
|
outputs=[box_mode_history_state, box_status_text, box_preview_image] |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
refresh_preview_btn.click( |
|
|
fn=update_box_preview, |
|
|
inputs=[unified_image_input, box_mode_history_state], |
|
|
outputs=[box_preview_image, box_status_text, example_prompt_info_display] |
|
|
) |
|
|
|
|
|
def update_inputs(p_type): |
|
|
is_text = p_type == "Text" |
|
|
is_multi_box = p_type == "Multi Box" |
|
|
is_box_mode = p_type in ["Single Box", "Multi Box"] |
|
|
return ( |
|
|
gr.update(visible=is_text), |
|
|
gr.update(visible=is_multi_box), |
|
|
gr.update(visible=is_box_mode) |
|
|
) |
|
|
|
|
|
def update_inputs_and_preview(p_type, img_input): |
|
|
""" |
|
|
更新输入组件可见性,并在示例加载时自动更新预览。 |
|
|
""" |
|
|
is_text = p_type == "Text" |
|
|
is_multi_box = p_type == "Multi Box" |
|
|
is_box_mode = p_type in ["Single Box", "Multi Box"] |
|
|
|
|
|
|
|
|
if p_type == "Multi Box": |
|
|
box_history = [(0, "negative"), (1, "positive")] |
|
|
elif p_type == "Single Box": |
|
|
box_history = [(0, "positive")] |
|
|
else: |
|
|
box_history = [(0, "positive")] |
|
|
|
|
|
|
|
|
preview_image = None |
|
|
status_text = "尚未绘制框" |
|
|
|
|
|
if img_input and isinstance(img_input, dict): |
|
|
image = img_input.get('image') |
|
|
if image is not None: |
|
|
xyxy_boxes, box_labels = process_imageprompter_data(img_input, box_history, verbose=False) |
|
|
if xyxy_boxes: |
|
|
preview_image = draw_boxes_with_labels(image, xyxy_boxes, box_labels) |
|
|
num_positive = sum(box_labels) |
|
|
num_negative = len(box_labels) - num_positive |
|
|
status_text = f"已绘制 {len(xyxy_boxes)} 个框(正框: {num_positive}, 负框: {num_negative})" |
|
|
else: |
|
|
preview_image = image |
|
|
|
|
|
return ( |
|
|
gr.update(visible=is_text), |
|
|
gr.update(visible=is_multi_box), |
|
|
gr.update(visible=is_box_mode), |
|
|
box_history, |
|
|
preview_image, |
|
|
status_text |
|
|
) |
|
|
|
|
|
prompt_type.change( |
|
|
fn=update_inputs_and_preview, |
|
|
inputs=[prompt_type, unified_image_input], |
|
|
outputs=[text_prompt_input, box_mode_group, box_preview_group, box_mode_history_state, box_preview_image, box_status_text] |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
demo.load( |
|
|
fn=update_inputs, |
|
|
inputs=[prompt_type], |
|
|
outputs=[text_prompt_input, box_mode_group, box_preview_group] |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
|
|
|
demo.launch( |
|
|
show_error=True |
|
|
) |
|
|
|