
Comments (4)

yrcong commented on July 19, 2024

We are glad that our work is helpful to you.
As discussed in the paper, unlike with two-stage methods, the ground-truth bounding boxes and categories of entities
cannot be given to RelTR directly. Therefore, when evaluating RelTR on PredCLS/SGCLS, we assign the ground-truth
information to the matched triplet proposals. The triplet features themselves are not replaced.

For example, suppose a triplet proposal is assigned to a ground-truth relation. For PredCLS, the subject class/box and object class/box are replaced with the ground truth, while the relationship is still computed from the original features (the feature map of the predicted boxes).
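
The replacement step described above can be illustrated with a small NumPy sketch on toy data. This is not the RelTR evaluation code: the function name `assign_gt_to_matched` and the index arrays `proposal_idx` / `entity_idx` are hypothetical names chosen for illustration, and the matching between proposals and ground-truth entities is assumed to be already available.

```python
import numpy as np

def assign_gt_to_matched(pred_classes, pred_boxes, pred_scores,
                         gt_classes, gt_boxes, proposal_idx, entity_idx):
    """Overwrite matched proposals with ground-truth entity info (PredCLS-style).

    proposal_idx[i] is a proposal that was matched to ground-truth entity
    entity_idx[i]. Unmatched proposals keep their predicted values.
    """
    # Work on copies so the raw predictions stay untouched.
    pred_classes = pred_classes.copy()
    pred_boxes = pred_boxes.copy()
    pred_scores = pred_scores.copy()

    pred_classes[proposal_idx] = gt_classes[entity_idx]
    pred_boxes[proposal_idx] = gt_boxes[entity_idx]
    pred_scores[proposal_idx] = 1.0  # ground truth is treated as fully confident
    return pred_classes, pred_boxes, pred_scores

# Toy data: 4 proposals, 3 ground-truth entities (all values are made up).
pred_classes = np.array([3, 7, 5, 2])
pred_boxes = np.zeros((4, 4), dtype=np.float32)
pred_scores = np.array([0.4, 0.9, 0.6, 0.3], dtype=np.float32)
gt_classes = np.array([10, 11, 12])
gt_boxes = np.array([[0, 0, 10, 10],
                     [5, 5, 20, 20],
                     [1, 2, 3, 4]], dtype=np.float32)
proposal_idx = np.array([2, 0])  # hypothetical matched proposals
entity_idx = np.array([1, 2])    # hypothetical matched ground-truth entities

new_classes, new_boxes, new_scores = assign_gt_to_matched(
    pred_classes, pred_boxes, pred_scores,
    gt_classes, gt_boxes, proposal_idx, entity_idx)
```

Only the entity classes, boxes, and scores are overwritten here; the relationship scores would be left as the model produced them, in line with the description above.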


yrcong commented on July 19, 2024

You could try the following code for reference (we haven't cleaned it up, so it looks a bit rough):

import numpy as np
import torch

# rescale_bboxes is not defined in this snippet; it is assumed to be the helper
# from the RelTR code base that maps boxes back to the original image size.


def evaluate_batch_predcls(outputs, targets, evaluator, matching_indices, evaluator_list):
    # TODO
    for batch, target in enumerate(targets):
        # (width, height) of the original image
        orig_size = torch.flip(target['orig_size'], dims=[0]).cpu()

        # recovered boxes with original size
        target_bboxes_scaled = rescale_bboxes(target['boxes'].cpu(), orig_size).clone().numpy()

        gt_entry = {'gt_classes': target['labels'].cpu().clone().numpy(),
                    'gt_relations': target['rel_annotations'].cpu().clone().numpy(),
                    'gt_boxes': target_bboxes_scaled}

        # index1: matched triplet proposals, index2: matched ground-truth relations
        index1, index2 = matching_indices[1][batch]

        sub_bboxes_scaled = rescale_bboxes(outputs['sub_boxes'][batch].cpu(), orig_size).clone().numpy()
        obj_bboxes_scaled = rescale_bboxes(outputs['obj_boxes'][batch].cpu(), orig_size).clone().numpy()

        # the last logit is dropped before taking the best entity class
        pred_sub_scores, pred_sub_classes = torch.max(outputs['sub_logits'][batch].softmax(-1)[:, :-1], dim=1)
        pred_obj_scores, pred_obj_classes = torch.max(outputs['obj_logits'][batch].softmax(-1)[:, :-1], dim=1)
        rel_scores = outputs['rel_logits'][batch][:, 1:-1].softmax(-1)

        pred_sub_classes = pred_sub_classes.cpu().clone().numpy()
        pred_sub_scores = pred_sub_scores.cpu().clone().numpy()
        pred_obj_classes = pred_obj_classes.cpu().clone().numpy()
        pred_obj_scores = pred_obj_scores.cpu().clone().numpy()

        # replace the subject/object boxes and classes of the matched proposals
        # with the ground truth; the relationship scores are left untouched
        if len(gt_entry['gt_relations']) > 1:
            sub_bboxes_scaled[index1] = gt_entry['gt_boxes'][gt_entry['gt_relations'][index2][:, 0]]
            obj_bboxes_scaled[index1] = gt_entry['gt_boxes'][gt_entry['gt_relations'][index2][:, 1]]
            pred_sub_classes[index1] = gt_entry['gt_classes'][gt_entry['gt_relations'][index2][:, 0]]
            pred_obj_classes[index1] = gt_entry['gt_classes'][gt_entry['gt_relations'][index2][:, 1]]
            pred_sub_scores[index1] = 1
            pred_obj_scores[index1] = 1
        else:
            sub_bboxes_scaled[index1] = gt_entry['gt_boxes'][gt_entry['gt_relations'][index2][0]]
            obj_bboxes_scaled[index1] = gt_entry['gt_boxes'][gt_entry['gt_relations'][index2][1]]
            pred_sub_classes[index1] = gt_entry['gt_classes'][gt_entry['gt_relations'][index2][0]]
            pred_obj_classes[index1] = gt_entry['gt_classes'][gt_entry['gt_relations'][index2][1]]
            pred_sub_scores[index1] = 1
            pred_obj_scores[index1] = 1

        # drop proposals whose subject and object classes are identical
        mask = (pred_sub_classes - pred_obj_classes != 0)
        if mask.sum() <= 198:
            sub_bboxes_scaled = sub_bboxes_scaled[mask]
            pred_sub_classes = pred_sub_classes[mask]
            pred_sub_scores = pred_sub_scores[mask]
            obj_bboxes_scaled = obj_bboxes_scaled[mask]
            pred_obj_classes = pred_obj_classes[mask]
            pred_obj_scores = pred_obj_scores[mask]
            rel_scores = rel_scores[mask]

        pred_entry = {'sub_boxes': sub_bboxes_scaled,
                      'sub_classes': pred_sub_classes,
                      'sub_scores': pred_sub_scores,
                      'obj_boxes': obj_bboxes_scaled,
                      'obj_classes': pred_obj_classes,
                      'obj_scores': pred_obj_scores,
                      'rel_scores': rel_scores.cpu().clone().numpy()}

        evaluator['predcls'].evaluate_scene_graph_entry(gt_entry, pred_entry)

        # optional per-predicate evaluation
        if evaluator_list is not None:
            for pred_id, _, evaluator_rel in evaluator_list:
                gt_entry_rel = gt_entry.copy()
                # np.isin replaces np.in1d, which is deprecated in NumPy 2.0
                mask = np.isin(gt_entry_rel['gt_relations'][:, -1], pred_id)
                gt_entry_rel['gt_relations'] = gt_entry_rel['gt_relations'][mask, :]
                if gt_entry_rel['gt_relations'].shape[0] == 0:
                    continue
                evaluator_rel['predcls'].evaluate_scene_graph_entry(gt_entry_rel, pred_entry)


mohammedessamtga commented on July 19, 2024

Thank you very much!


rebelzion commented on July 19, 2024

@yrcong Thanks for the code example. I was just wondering: how do you compute matching_indices? Is it related to the _compute_pred_matches function in sg_eval.py?

Or does computing matching_indices mean getting the IoU overlaps between predictions and targets?

