大家好,我是何三,独立开发者

今天给大家介绍一个在目标检测领域非常有趣的工作——DEIM(DETR with Improved Matching)。这是 CVPR 2025 接收的一篇论文,通过改进 DETR 的匹配机制,实现了更快的收敛速度和更高的检测精度。

背景引入

目标检测是计算机视觉领域的核心任务之一,从传统的 R-CNN 系列到后来的 YOLO、SSD,再到基于 Transformer 的 DETR,这个领域一直在不断演进。

DETR(DEtection TRansformer)的出现是一个重要的里程碑,它首次将 Transformer 架构引入目标检测,彻底改变了传统的检测范式。但是,DETR 也有一个明显的问题:收敛速度慢,需要很长的训练时间才能达到理想的性能。

DEIM匹配机制示意图

核心原理讲解

DEIM 的核心创新在于改进的匹配机制。在传统的 DETR 中,使用的是二分图匹配算法,这种方式虽然理论上可行,但在实际训练中往往效率不高。

DEIM 提出了两个关键的改进:

1. Dense O2O(One-to-One)匹配

传统的 DETR 使用匈牙利算法进行匹配,计算复杂度高。DEIM 提出了 Dense O2O 匹配,通过更高效的匹配策略,大幅提升了训练效率。

2. 多阶段优化

DEIM 引入了多阶段的匹配优化策略,在不同的训练阶段使用不同的匹配策略,从而实现更快的收敛。

简单来说,DEIM 就像是给 DETR 换了一个更聪明的"匹配引擎",让它能够更快地学会如何准确地检测目标。

性能对比

DEIM 在 COCO 数据集上表现非常出色,我们来看看具体的性能数据:

DEIM性能对比

从表格中可以看到:

  • DEIM-N:仅 4M 参数,延迟 2.12ms,AP 达到 43.0
  • DEIM-S:10M 参数,延迟 3.49ms,AP 达到 49.0
  • DEIM-M:19M 参数,延迟 5.62ms,AP 达到 52.7
  • DEIM-L:31M 参数,延迟 8.07ms,AP 达到 54.7
  • DEIM-X:62M 参数,延迟 12.89ms,AP 达到 56.5

这些数据表明,DEIM 不仅收敛速度快,而且在精度上也超越了传统方法。

代码实战

下面我们来看看如何使用 DEIM 进行目标检测。

环境配置

首先,我们需要安装必要的依赖:

# 创建虚拟环境
conda create -n deim python=3.11.9
conda activate deim

# 安装依赖
pip install -r requirements.txt

数据准备

DEIM 支持 COCO 格式的数据集。如果你有自己的数据集,需要将其转换为 COCO 格式:

import json

def convert_to_coco(input_annotations, output_annotations):
    """
    将自定义标注转换为 COCO 格式
    """
    coco_format = {
        "images": [],
        "annotations": [],
        "categories": []
    }

    # 转换图像信息
    for img_id, img_info in enumerate(input_annotations["images"]):
        coco_format["images"].append({
            "id": img_id,
            "file_name": img_info["file_name"],
            "width": img_info["width"],
            "height": img_info["height"]
        })

    # 转换标注信息
    for ann_id, ann_info in enumerate(input_annotations["annotations"]):
        coco_format["annotations"].append({
            "id": ann_id,
            "image_id": ann_info["image_id"],
            "category_id": ann_info["category_id"],
            "bbox": ann_info["bbox"],
            "area": ann_info["area"],
            "iscrowd": 0
        })

    # 转换类别信息
    for cat_id, cat_name in enumerate(input_annotations["categories"]):
        coco_format["categories"].append({
            "id": cat_id,
            "name": cat_name
        })

    # 保存为 JSON 文件
    with open(output_annotations, 'w') as f:
        json.dump(coco_format, f)

if __name__ == "__main__":
    convert_to_coco('your_annotations.json', 'coco_format.json')

配置文件

创建配置文件 custom_detection.yml

task: detection

evaluator:
  type: CocoEvaluator
  iou_types: ['bbox']

num_classes: 80  # 根据你的数据集调整
remap_mscoco_category: False

train_dataloader:
  type: DataLoader
  dataset:
    type: CocoDetection
    img_folder: /data/yourdataset/train
    ann_file: /data/yourdataset/train/train.json
    return_masks: False
    transforms:
      type: Compose
      ops: ~
  shuffle: True
  num_workers: 4
  drop_last: True
  collate_fn:
    type: BatchImageCollateFunction

val_dataloader:
  type: DataLoader
  dataset:
    type: CocoDetection
    img_folder: /data/yourdataset/val
    ann_file: /data/yourdataset/val/val.json
    return_masks: False
    transforms:
      type: Compose
      ops: ~
  shuffle: False
  num_workers: 4
  drop_last: False
  collate_fn:
    type: BatchImageCollateFunction

训练模型

from deim import DEIMTrainer

# 初始化训练器
trainer = DEIMTrainer(
    config_path='custom_detection.yml',
    checkpoint_path='pretrained/deim_s.pth'
)

# 开始训练
trainer.train(
    epochs=50,
    batch_size=8,
    learning_rate=1e-4
)

# 评估模型
results = trainer.evaluate()
print(f"mAP: {results['map']:.4f}")

推理示例

import torch
from PIL import Image
from torchvision import transforms

# 加载模型
model = torch.load('checkpoint.pth')
model.eval()

# 图像预处理
transform = transforms.Compose([
    transforms.Resize((640, 640)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                       std=[0.229, 0.224, 0.225])
])

# 加载图像
image = Image.open('test.jpg')
image_tensor = transform(image).unsqueeze(0)

# 推理
with torch.no_grad():
    predictions = model(image_tensor)

# 解析结果
boxes = predictions[0]['boxes']
scores = predictions[0]['scores']
labels = predictions[0]['labels']

# 过滤低置信度结果
threshold = 0.5
keep = scores > threshold
boxes = boxes[keep]
scores = scores[keep]
labels = labels[keep]

print(f"检测到 {len(boxes)} 个目标")

模型架构

DEIM 的模型架构基于 DETR,但在匹配机制上做了重要改进:

DEIM模型架构

从架构图中可以看到,DEIM 包含以下主要组件:

  1. Backbone:特征提取网络(如 ResNet、Swin Transformer)
  2. Feature Pyramid Network:多尺度特征融合
  3. Transformer Encoder:特征增强
  4. Object Queries:可学习的查询向量
  5. Transformer Decoder:解码预测
  6. DEIM Matching:改进的匹配机制(核心创新)
  7. Prediction Heads:边界框和类别预测

应用场景

DEIM 的快速推理和高精度特性使其非常适合以下场景:

  1. 实时检测:延迟低至 2.12ms,适合实时应用
  2. 移动端部署:轻量级模型(如 DEIM-N)适合移动设备
  3. 工业检测:高精度检测适合工业质检
  4. 自动驾驶:快速准确的检测对自动驾驶至关重要

总结

DEIM 通过改进的匹配机制,在保持 DETR 简洁性的同时,大幅提升了收敛速度和检测精度。它的主要优势包括:

  • 快速收敛:相比传统 DETR,训练效率大幅提升
  • 高精度:在 COCO 数据集上达到 56.5 AP
  • 快速推理:延迟低至 2.12ms,适合实时应用
  • 轻量级:提供多种尺寸模型,适合不同场景

如果你对目标检测感兴趣,DEIM 绝对是一个值得尝试的框架。它不仅性能出色,而且代码开源,易于使用和扩展。