@yrcong Thanks for the code example. I was just wondering how you compute `matching_indices`. Is it related to the `_compute_pred_matches` function in `sg_eval.py`? Or does computing `matching_indices` mean getting the IoU overlaps between predictions and targets?
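For context on the `_compute_pred_matches` question: in neural-motifs-style `sg_eval.py` code, that function is, to our understanding, an evaluation-time check that decides which ground-truth triplets each prediction recalls. It runs on final predictions rather than producing a training-style assignment. The sketch below re-implements that kind of check from scratch with NumPy; `box_iou`, `pred_to_gt_matches`, the 0.5 threshold, and the plain IoU formula (no +1 pixel convention) are illustrative assumptions, not the repository's code.

```python
import numpy as np


def box_iou(a, b):
    """Pairwise IoU between boxes a [N, 4] and b [M, 4] in (x1, y1, x2, y2) format."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    lt = np.maximum(a[:, None, :2], b[None, :, :2])  # top-left corner of each intersection
    rb = np.minimum(a[:, None, 2:], b[None, :, 2:])  # bottom-right corner of each intersection
    wh = np.clip(rb - lt, 0, None)                   # zero when boxes do not overlap
    inter = wh[..., 0] * wh[..., 1]
    area_a = (a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1])
    area_b = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
    union = area_a[:, None] + area_b[None, :] - inter
    return inter / np.maximum(union, 1e-12)


def pred_to_gt_matches(gt_triplets, pred_triplets, gt_sub, gt_obj,
                       pred_sub, pred_obj, iou_thresh=0.5):
    """For each predicted triplet, list the indices of the GT triplets it matches.

    Triplets are label rows in any consistent order, e.g. (subject class,
    predicate, object class). A match needs identical labels AND both the
    subject and the object box IoU >= iou_thresh.
    """
    gt_triplets = np.asarray(gt_triplets)
    pred_triplets = np.asarray(pred_triplets)
    same_label = (gt_triplets[:, None, :] == pred_triplets[None, :, :]).all(-1)  # [G, P]
    sub_ok = box_iou(gt_sub, pred_sub) >= iou_thresh
    obj_ok = box_iou(gt_obj, pred_obj) >= iou_thresh
    keep = same_label & sub_ok & obj_ok
    return [np.nonzero(keep[:, p])[0].tolist() for p in range(pred_triplets.shape[0])]
```

Note that this only tells you which predictions count as hits for recall; it does not produce the proposal-to-relation assignment that the PredCLS code in this thread consumes.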
from reltr.
We are glad our work is helpful to you.
As discussed in the paper, unlike two-stage methods, RelTR cannot be given the ground-truth bounding boxes and categories of entities directly. Therefore, when evaluating RelTR on PredCLS/SGCLS, we assign the ground-truth information to the matched triplet proposals. The triplet features themselves are not replaced.
For example, suppose a triplet proposal is assigned to a ground-truth relation. For PredCLS, the subject class/box and object class/box are replaced with the ground truth, while the relationship is still predicted from the original features (the feature map of the predicted boxes).
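The assignment described above can be sketched with NumPy. `assign_gt_to_proposals` is a hypothetical helper, not a RelTR function. It assumes `gt_relations` rows are (subject entity index, object entity index, predicate), as in the evaluation snippet in this thread, and that `pred_idx[k]` is the proposal matched to GT relation `gt_idx[k]`. The SGCLS branch is our assumption based on the usual task definition (ground-truth boxes given, entity classes predicted); the thread itself only shows PredCLS code.

```python
import numpy as np


def assign_gt_to_proposals(sub_boxes, obj_boxes, sub_classes, obj_classes,
                           sub_scores, obj_scores, gt_boxes, gt_classes,
                           gt_relations, pred_idx, gt_idx, mode="predcls"):
    """Overwrite matched triplet proposals with ground-truth entity information.

    Relation scores are deliberately not an input: predicates keep coming
    from the model's own features.
    """
    # Work on copies so the caller's arrays are not modified in place.
    sub_boxes, obj_boxes = sub_boxes.copy(), obj_boxes.copy()
    sub_classes, obj_classes = sub_classes.copy(), obj_classes.copy()
    sub_scores, obj_scores = sub_scores.copy(), obj_scores.copy()

    pred_idx = np.asarray(pred_idx, dtype=int)
    matched = gt_relations[np.asarray(gt_idx, dtype=int)]  # stays 2-D, even for one match
    sub_ent, obj_ent = matched[:, 0], matched[:, 1]

    # Both PredCLS and SGCLS receive ground-truth boxes.
    sub_boxes[pred_idx] = gt_boxes[sub_ent]
    obj_boxes[pred_idx] = gt_boxes[obj_ent]
    if mode == "predcls":
        # PredCLS also receives ground-truth entity classes, treated as certain.
        sub_classes[pred_idx] = gt_classes[sub_ent]
        obj_classes[pred_idx] = gt_classes[obj_ent]
        sub_scores[pred_idx] = 1.0
        obj_scores[pred_idx] = 1.0
    # Assumption: for "sgcls", the predicted classes and scores are kept.
    return sub_boxes, obj_boxes, sub_classes, obj_classes, sub_scores, obj_scores
```

Unmatched proposals keep their predicted boxes and classes, so they still compete in the ranking; only the entity side of matched proposals is "oracle" information.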
You can try the following code for reference (we haven't cleaned it up yet):
```python
import numpy as np
import torch


def box_cxcywh_to_xyxy(x):
    # (center_x, center_y, w, h) -> (x1, y1, x2, y2)
    x_c, y_c, w, h = x.unbind(1)
    return torch.stack([x_c - 0.5 * w, y_c - 0.5 * h, x_c + 0.5 * w, y_c + 0.5 * h], dim=1)


def rescale_bboxes(out_bbox, size):
    # Assumed DETR-style helper: normalized cxcywh boxes -> absolute xyxy, size = (img_w, img_h).
    img_w, img_h = size
    scale = torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
    return box_cxcywh_to_xyxy(out_bbox) * scale


def evaluate_batch_predcls(outputs, targets, evaluator, matching_indices, evaluator_list):
    """PredCLS: matched triplet proposals receive GT entity boxes/classes; predicates stay predicted."""
    for batch, target in enumerate(targets):
        size_wh = torch.flip(target['orig_size'], dims=[0]).cpu()  # orig_size is (h, w)
        # Ground-truth boxes recovered at the original image size.
        target_bboxes_scaled = rescale_bboxes(target['boxes'].cpu(), size_wh).clone().numpy()
        gt_entry = {'gt_classes': target['labels'].cpu().clone().numpy(),
                    'gt_relations': target['rel_annotations'].cpu().clone().numpy(),
                    'gt_boxes': target_bboxes_scaled}

        # index1: matched triplet proposals; index2: the GT relations assigned to them.
        index1, index2 = matching_indices[1][batch]
        # NumPy index arrays keep gt_relations[index2] 2-D even for a single match,
        # which replaces the original special case for images with one relation.
        index1 = torch.as_tensor(index1).cpu().numpy().astype(np.int64)
        index2 = torch.as_tensor(index2).cpu().numpy().astype(np.int64)

        sub_bboxes_scaled = rescale_bboxes(outputs['sub_boxes'][batch].cpu(), size_wh).clone().numpy()
        obj_bboxes_scaled = rescale_bboxes(outputs['obj_boxes'][batch].cpu(), size_wh).clone().numpy()
        # Most likely entity class per proposal, ignoring the last ("no object") logit.
        pred_sub_scores, pred_sub_classes = torch.max(outputs['sub_logits'][batch].softmax(-1)[:, :-1], dim=1)
        pred_obj_scores, pred_obj_classes = torch.max(outputs['obj_logits'][batch].softmax(-1)[:, :-1], dim=1)
        # Predicate scores from the model's own features (first and last logits dropped).
        rel_scores = outputs['rel_logits'][batch][:, 1:-1].softmax(-1).cpu().clone().numpy()
        pred_sub_classes = pred_sub_classes.cpu().clone().numpy()
        pred_sub_scores = pred_sub_scores.cpu().clone().numpy()
        pred_obj_classes = pred_obj_classes.cpu().clone().numpy()
        pred_obj_scores = pred_obj_scores.cpu().clone().numpy()

        # Replace entity boxes/classes of matched proposals with GT; rel_scores is untouched.
        matched = gt_entry['gt_relations'][index2]  # rows: (subject idx, object idx, predicate)
        sub_bboxes_scaled[index1] = gt_entry['gt_boxes'][matched[:, 0]]
        obj_bboxes_scaled[index1] = gt_entry['gt_boxes'][matched[:, 1]]
        pred_sub_classes[index1] = gt_entry['gt_classes'][matched[:, 0]]
        pred_obj_classes[index1] = gt_entry['gt_classes'][matched[:, 1]]
        pred_sub_scores[index1] = 1
        pred_obj_scores[index1] = 1

        # Kept from the original snippet: drop proposals whose subject and object share a
        # class, but only if at most 198 proposals would remain.
        mask = pred_sub_classes != pred_obj_classes
        if mask.sum() <= 198:
            sub_bboxes_scaled = sub_bboxes_scaled[mask]
            pred_sub_classes = pred_sub_classes[mask]
            pred_sub_scores = pred_sub_scores[mask]
            obj_bboxes_scaled = obj_bboxes_scaled[mask]
            pred_obj_classes = pred_obj_classes[mask]
            pred_obj_scores = pred_obj_scores[mask]
            rel_scores = rel_scores[mask]

        pred_entry = {'sub_boxes': sub_bboxes_scaled,
                      'sub_classes': pred_sub_classes,
                      'sub_scores': pred_sub_scores,
                      'obj_boxes': obj_bboxes_scaled,
                      'obj_classes': pred_obj_classes,
                      'obj_scores': pred_obj_scores,
                      'rel_scores': rel_scores}
        evaluator['predcls'].evaluate_scene_graph_entry(gt_entry, pred_entry)

        # Optional per-predicate evaluation: keep only GT relations with a given predicate.
        if evaluator_list is not None:
            for pred_id, _, evaluator_rel in evaluator_list:
                gt_entry_rel = gt_entry.copy()
                mask = np.isin(gt_entry_rel['gt_relations'][:, -1], pred_id)
                gt_entry_rel['gt_relations'] = gt_entry_rel['gt_relations'][mask, :]
                if gt_entry_rel['gt_relations'].shape[0] == 0:
                    continue  # no GT relation with this predicate in the image
                evaluator_rel['predcls'].evaluate_scene_graph_entry(gt_entry_rel, pred_entry)
```
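The `evaluator_list` loop appears to feed per-predicate evaluators, as used for mean recall (mR@K). A minimal sketch of that aggregation, assuming the hit flags (whether each GT relation was recalled within the top K) are already computed: recall is taken per image, only over images that contain the predicate, averaged over those images, and then averaged over predicates. `mean_recall` is a hypothetical name, and leaving out predicates with no GT relations at all is our assumption.

```python
import numpy as np


def mean_recall(images, predicate_ids):
    """Mean recall over predicates from per-image GT predicates and hit flags.

    images: list of (gt_predicates, hits) pairs, one per image. gt_predicates[k]
    is the predicate of GT relation k; hits[k] is True if that relation was
    recalled (e.g. within the top-K predictions).
    """
    per_predicate = {}
    for pred_id in predicate_ids:
        recalls = []
        for gt_predicates, hits in images:
            mask = np.asarray(gt_predicates) == pred_id
            if not mask.any():
                continue  # image has no GT relation with this predicate
            recalls.append(np.asarray(hits, dtype=float)[mask].mean())
        if recalls:  # assumption: predicates never seen in GT are left out
            per_predicate[pred_id] = float(np.mean(recalls))
    return float(np.mean(list(per_predicate.values()))), per_predicate
```

Averaging per predicate first keeps frequent predicates (e.g. spatial ones) from dominating the score, which is the point of reporting mR alongside plain recall.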
Thank you very much!
Have you figured it out? I think it sounds reasonable to compute `matching_indices` from the IoU overlaps.
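If `matching_indices` is built from IoU overlaps as suggested, one simple option is a greedy one-to-one assignment. The sketch below is an illustration only, not RelTR's matcher: it scores each (prediction, GT) pair by the smaller of the subject IoU and the object IoU, then accepts pairs in descending order, returning `(pred_idx, gt_idx)` arrays shaped like the `index1, index2 = matching_indices[1][batch]` pair in the evaluation code. `greedy_triplet_matching` and the 0.5 threshold are assumptions; an optimal assignment over the same scores (Hungarian algorithm, e.g. SciPy's `linear_sum_assignment`) would be the more standard choice.

```python
import numpy as np


def box_iou(a, b):
    """Pairwise IoU between boxes a [N, 4] and b [M, 4] in (x1, y1, x2, y2) format."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    lt = np.maximum(a[:, None, :2], b[None, :, :2])
    rb = np.minimum(a[:, None, 2:], b[None, :, 2:])
    wh = np.clip(rb - lt, 0, None)
    inter = wh[..., 0] * wh[..., 1]
    area_a = (a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1])
    area_b = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
    return inter / np.maximum(area_a[:, None] + area_b[None, :] - inter, 1e-12)


def greedy_triplet_matching(pred_sub, pred_obj, gt_sub, gt_obj, iou_thresh=0.5):
    """Greedily match predicted triplets to GT relations by box overlap.

    Returns (pred_idx, gt_idx): pred_idx[k] is matched to GT relation gt_idx[k].
    """
    # Pair score: both the subject and the object box must overlap well.
    score = np.minimum(box_iou(pred_sub, gt_sub), box_iou(pred_obj, gt_obj))  # [P, G]
    pred_idx, gt_idx = [], []
    used_pred, used_gt = set(), set()
    # Visit (prediction, GT) pairs from highest to lowest score.
    for flat in np.argsort(-score, axis=None):
        p, g = (int(i) for i in np.unravel_index(flat, score.shape))
        if score[p, g] < iou_thresh:
            break  # every remaining pair scores even lower
        if p in used_pred or g in used_gt:
            continue  # keep the assignment one-to-one
        used_pred.add(p)
        used_gt.add(g)
        pred_idx.append(p)
        gt_idx.append(g)
    return np.array(pred_idx, dtype=int), np.array(gt_idx, dtype=int)
```

Taking the minimum of the two IoUs means a proposal is only assigned when both its subject and object boxes overlap the GT pair; an average would let one well-placed box compensate for a poor one.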
@yrcong Could you give some details about how `matching_indices` is computed? Thanks! :)