
Comments (5)

yearing1017 commented on August 23, 2024

The source I found: torchvision.models.detection.fasterrcnn_resnet50_fpn

The docstring in the source explains the network's inputs and outputs during training and testing:

Implements Faster R-CNN.
    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
    image, and should be in 0-1 range. Different images can have different sizes.
    The behavior of the model changes depending if it is in training or evaluation mode.
    During training, the model expects both the input tensors, as well as a targets (list of dictionary),
    containing:
        - boxes (FloatTensor[N, 4]): the ground-truth boxes in [x1, y1, x2, y2] format, with values of x
          between 0 and W and values of y between 0 and H
        - labels (Int64Tensor[N]): the class label for each ground-truth box
    The model returns a Dict[Tensor] during training, containing the classification and regression
    losses for both the RPN and the R-CNN.
    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
    follows:
        - boxes (FloatTensor[N, 4]): the predicted boxes in [x1, y1, x2, y2] format, with values of x
          between 0 and W and values of y between 0 and H
        - labels (Int64Tensor[N]): the predicted labels for each image
        - scores (Tensor[N]): the scores of each prediction
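As a rough illustration of the input format described in the docstring, the sketch below builds an image list and a matching target list and checks them against the stated constraints. The helper names `make_target` and `check_sample` are hypothetical and not part of torchvision, and the positive-size check is an extra sanity check that the docstring does not mention.

```python
import torch


def make_target(boxes, labels):
    """Build one target dict with the dtypes the docstring asks for (hypothetical helper)."""
    return {
        "boxes": torch.as_tensor(boxes, dtype=torch.float32),
        "labels": torch.as_tensor(labels, dtype=torch.int64),
    }


def check_sample(image, target):
    """Raise ValueError if an image/target pair violates the documented format."""
    if image.dim() != 3:
        raise ValueError("image must have shape [C, H, W]")
    if image.min() < 0 or image.max() > 1:
        raise ValueError("image values must lie in the 0-1 range")
    _, height, width = image.shape
    boxes, labels = target["boxes"], target["labels"]
    if boxes.dim() != 2 or boxes.shape[1] != 4:
        raise ValueError("boxes must have shape [N, 4]")
    if boxes.shape[0] != labels.shape[0]:
        raise ValueError("need exactly one label per box")
    x1, y1, x2, y2 = boxes.unbind(dim=1)
    if (x1 < 0).any() or (x2 > width).any() or (y1 < 0).any() or (y2 > height).any():
        raise ValueError("x must lie in [0, W] and y in [0, H]")
    # Extra sanity check, not stated in the docstring: reject degenerate boxes.
    if (x2 <= x1).any() or (y2 <= y1).any():
        raise ValueError("boxes must have positive width and height")


# Two images of different sizes, which the docstring explicitly allows.
images = [torch.rand(3, 100, 200), torch.rand(3, 64, 64)]
targets = [
    make_target([[10, 20, 110, 80]], [1]),
    make_target([[0, 0, 32, 32], [16, 16, 64, 64]], [1, 1]),
]
for image, target in zip(images, targets):
    check_sample(image, target)
```

Lists in this shape are what the training call `model(images, targets)` expects; in evaluation mode only `images` would be passed.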

from global-wheat-detection.

yearing1017 commented on August 23, 2024

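The Dataset class for object detection in the PyTorch tutorials: TorchVision Object Detection Finetuning Tutorial

The reference scripts for training object detection, instance segmentation and person keypoint detection make it easy to add new custom datasets. The dataset should inherit from the standard torch.utils.data.Dataset class and implement __len__ and __getitem__.

The only specific requirement is that the dataset's __getitem__ should return: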

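  • image: a PIL image of size (H, W)
  • target: a dict containing the following fields
    • boxes (FloatTensor[N, 4]): the coordinates of the N bounding boxes in [x0, y0, x1, y1] format, ranging from 0 to W and from 0 to H
    • labels (Int64Tensor[N]): the label for each bounding box. 0 represents the background class
    • image_id (Int64Tensor[1]): an image identifier. It should be unique across all images in the dataset and is used during evaluation
    • area (Tensor[N]): the area of each bounding box. It is used during evaluation with the COCO metric to separate the metric scores for small, medium and large boxes.
    • iscrowd (UInt8Tensor[N]): instances with iscrowd = True are ignored during evaluation.
    • (optional) masks (UInt8Tensor[N, H, W]): the segmentation mask for each object
    • (optional) keypoints (FloatTensor[N, K, 3]): for each of the N objects, the K keypoints in [x, y, visibility] format that define the object. visibility = 0 means the keypoint is not visible. Note that for data augmentation, what it means to flip a keypoint depends on the data representation, and you should probably adapt references/detection/transforms.py to your new keypoint representation

If your dataset returns the fields above, it will work for both training and evaluation, and evaluation will use the scripts from pycocotools.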
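The required fields can be sketched with a small in-memory dataset. This is only an illustration under stated assumptions: `ToyDetectionDataset`, its `samples` argument and the single foreground class are hypothetical, and a real dataset would load images and annotations from disk instead.

```python
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset


class ToyDetectionDataset(Dataset):
    """Hypothetical in-memory dataset returning (image, target) pairs."""

    def __init__(self, samples):
        # samples: list of (PIL image, list of [x0, y0, x1, y1] boxes)
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        image, box_list = self.samples[idx]
        # Explicit reshape keeps the [N, 4] shape even when an image has no boxes.
        boxes = torch.tensor(box_list, dtype=torch.float32).reshape(len(box_list), 4)
        num_boxes = boxes.shape[0]
        target = {
            "boxes": boxes,
            # Assumption: one foreground class labelled 1; 0 stays reserved for background.
            "labels": torch.ones((num_boxes,), dtype=torch.int64),
            "image_id": torch.tensor([idx]),
            "area": (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]),
            "iscrowd": torch.zeros((num_boxes,), dtype=torch.uint8),
        }
        return image, target


def collate_fn(batch):
    # Image sizes and box counts differ per sample, so keep tuples instead of stacking.
    return tuple(zip(*batch))


samples = [
    (Image.new("RGB", (200, 100)), [[10, 20, 110, 80]]),
    (Image.new("RGB", (64, 64)), [[0, 0, 32, 32], [16, 16, 64, 64]]),
]
dataset = ToyDetectionDataset(samples)
image, target = dataset[1]

loader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)
images, targets = next(iter(loader))
```

The custom `collate_fn` is needed because the default collation would try to stack images and targets of different sizes into single tensors.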


yearing1017 commented on August 23, 2024

Error 01

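  • I wanted to compute the loss during validation the same way as during training and use it as the validation metric, with the following code: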
for epo in range(epoch):
        loss_hist.reset()
        val_loss_hist.reset()
        model.train()
        for images, targets, image_ids in train_data_loader:
            
            images = list(image.to(device) for image in images)
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            # The loss computation is wrapped inside the model; in training mode
            # model(images, targets) returns a dict of losses
            loss_dict = model(images, targets)
            # Sum the individual losses for this batch
            losses = sum(loss for loss in loss_dict.values())
            # .item() extracts the Python number
            loss_value = losses.item()
            # Record the loss with the helper class
            loss_hist.send(loss_value)

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()

            if itr % 50 == 0:
                print(f"Iteration #{itr} loss: {loss_value}")
            itr += 1

        # Validation
        model.eval()
        with torch.no_grad():
            for val_images, val_targets, val_image_ids in valid_data_loader:
                images = list(image.to(device) for image in val_images)
                targets = [{k: v.to(device) for k, v in t.items()} for t in val_targets]

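                # In eval mode the model returns a list of predictions, not a loss dict,
                # so the .values() call on the next line is where the error is raised
                val_loss_dict = model(images, targets)
                val_losses = sum(loss for loss in val_loss_dict.values())
                loss_value = val_losses.item()
                val_loss_hist.send(loss_value)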

        # Decide whether this is a better model, using the validation loss as the criterion
        if val_loss_hist.value<least_loss:
            least_loss = val_loss_hist.value
            lval=int(least_loss*1000)/1000
            torch.save(model.state_dict(), f'fasterrcnn_custom_test_ep{epo}_loss{lval}.pth')
            
        else:
            if lr_scheduler is not None:
                lr_scheduler.step()
        print(f"Epoch #{epo} train_loss: {loss_hist.value} val_loss: {val_loss_hist.value}")
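  • The code fails when it reaches the validation part with: ‘List’ has no attribute values. In other words, the model returned a list, and that list contains no loss values.

  • I printed the return value in the validation part and it contains no loss. An experiment showed that removing the eval() call brings the loss back, so I plan to use metrics such as IoU and mAP for validation instead.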
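The observation that the loss reappears without eval() suggests one possible workaround: keep the model in training mode for the validation pass, but wrap it in torch.no_grad() so no gradients are computed. The sketch below is an assumption-laden illustration, not the author's final approach. `ToyDetector` is a hypothetical stand-in that only mimics the mode-dependent return type of the detection model, and `validation_loss` assumes each batch is an `(images, targets)` pair rather than the three-element tuples used in the code above.

```python
import torch
from torch import nn


class ToyDetector(nn.Module):
    """Hypothetical stand-in: returns a loss dict in train mode, predictions in eval mode."""

    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.tensor(1.0))

    def forward(self, images, targets=None):
        if self.training:
            # Training mode: a dict of loss tensors.
            err = torch.stack([(self.scale * img).mean() for img in images]).mean()
            return {"loss_classifier": err, "loss_box_reg": 0.5 * err}
        # Eval mode: one prediction dict per image and no losses at all.
        return [
            {
                "boxes": torch.zeros((0, 4)),
                "labels": torch.zeros((0,), dtype=torch.int64),
                "scores": torch.zeros((0,)),
            }
            for _ in images
        ]


def validation_loss(model, batches, device="cpu"):
    """Average the summed loss over batches, in train mode but without gradients."""
    was_training = model.training
    model.train()  # required so the forward pass returns the loss dict
    total, count = 0.0, 0
    with torch.no_grad():  # no autograd graph is built during validation
        for images, targets in batches:
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            loss_dict = model(images, targets)
            total += sum(loss for loss in loss_dict.values()).item()
            count += 1
    model.train(was_training)  # restore whatever mode the caller had set
    return total / max(count, 1)


model = ToyDetector().eval()
empty_target = {"boxes": torch.zeros((0, 4)), "labels": torch.zeros((0,), dtype=torch.int64)}
batches = [([torch.ones(3, 4, 4)], [empty_target])]
val_loss = validation_loss(model, batches)
```

One caveat with this approach: in training mode, dropout stays active and any batch-norm layers that are not frozen keep updating their running statistics, so the value is only an approximation of a true evaluation-mode loss. Metrics such as IoU or mAP computed from eval-mode predictions avoid that issue.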


yearing1017 commented on August 23, 2024

The trained frc_55_0629 model scored LB: 0.5399 on Kaggle.


yearing1017 commented on August 23, 2024

Later I added a Pseudo Labeling strategy and continued training the network from the previous model; the resulting score was 0.6912.

