Comments (5)
I haven't run a full test on Waymo, but I recently applied the same idea to Argoverse 2 and achieved pretty strong results, so I don't think the model is overfitting to a single dataset. Your results look like random predictions to me. Perhaps you should first check whether you have converted the prediction results back to the original coordinate system. Generally speaking, you should be able to get a minADE below 2.0 after training for one epoch.
from hivt.
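To make "converting back" concrete: if each agent's frame is built by translating the scene to the agent's current position and rotating it so the agent's heading points along +x, then a prediction made in that frame has to be rotated and translated back before it is compared with global-frame ground truth. The sketch below assumes the convention `local = (p - origin) @ R(theta)` with `R(theta) = [[cos, -sin], [sin, cos]]`; the function names are hypothetical, so check the dataset preprocessing for the exact convention it actually uses.

```python
import math

import torch


def rotation_matrix(theta: torch.Tensor) -> torch.Tensor:
    """Build [N, 2, 2] matrices [[cos, -sin], [sin, cos]] from N headings (radians)."""
    cos, sin = torch.cos(theta), torch.sin(theta)
    return torch.stack([torch.stack([cos, -sin], dim=-1),
                        torch.stack([sin, cos], dim=-1)], dim=-2)


def to_local(traj: torch.Tensor, origin: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
    """Global [N, T, 2] -> agent-centric frame, assuming local = (p - origin) @ R."""
    return torch.bmm(traj - origin.unsqueeze(1), rotation_matrix(theta))


def to_global(traj: torch.Tensor, origin: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
    """Inverse of to_local: p = local @ R^T + origin (R is orthonormal, so R^-1 = R^T)."""
    return torch.bmm(traj, rotation_matrix(theta).transpose(1, 2)) + origin.unsqueeze(1)


# Toy example: one agent at (1, 2) heading along +y (theta = pi / 2).
traj = torch.tensor([[[1.0, 2.0], [3.0, 5.0]]])
origin = torch.tensor([[1.0, 2.0]])
theta = torch.tensor([math.pi / 2])
local = to_local(traj, origin, theta)        # approx. [[[0, 0], [3, -2]]]: 3 m ahead, 2 m to the right
restored = to_global(local, origin, theta)   # approx. the original traj
```

The inverse step only matters when predictions are compared against ground truth in the global frame, e.g. for an online submission; metrics computed against ground truth that is already in the local frame need no conversion.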
To get strong results on Waymo or Argoverse 2, the model may need a larger receptive field, since the prediction horizons in these newer datasets are fairly long. On Argoverse 2 (which requires predicting 6 seconds into the future), I crop a local map for each agent with a radius of 150 m.
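A per-agent local map crop can be sketched as a simple distance filter. This assumes the map is given as a tensor of lane-point (or lane-segment centroid) positions `map_pos` in the same frame as the agents, and that each crop is centred on the agent's current position; `crop_local_map` and `agent_map_edges` are hypothetical helpers, and the 150 m default only mirrors the radius mentioned above.

```python
import torch


def crop_local_map(agent_pos: torch.Tensor,
                   map_pos: torch.Tensor,
                   radius: float = 150.0) -> torch.Tensor:
    """Return a [num_agents, num_map_points] boolean mask of map points within `radius`.

    agent_pos: [num_agents, 2] current agent positions (assumed layout).
    map_pos:   [num_map_points, 2] lane-point positions in the same frame.
    """
    # Pairwise Euclidean distances between every agent and every map point.
    dist = torch.cdist(agent_pos, map_pos)
    return dist <= radius


def agent_map_edges(mask: torch.Tensor) -> torch.Tensor:
    """Turn the mask into a [2, num_edges] edge index (row 0: map point, row 1: agent)."""
    agent_idx, map_idx = mask.nonzero(as_tuple=True)
    return torch.stack([map_idx, agent_idx], dim=0)


agents = torch.tensor([[0.0, 0.0], [500.0, 0.0]])
points = torch.tensor([[100.0, 0.0], [0.0, 200.0], [450.0, 0.0]])
mask = crop_local_map(agents, points)  # agent 0 keeps point 0 (100 m); agent 1 keeps point 2 (50 m)
edges = agent_map_edges(mask)          # tensor([[0, 2], [0, 1]])
```

Enlarging `radius` grows the number of agent-map edges roughly with the area of the crop, so a longer horizon trades receptive field for memory and compute.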
Thanks for your reply.
But in the code, data.y is normalized per agent (agent-centric), and the prediction is in the same coordinate system as data.y, so the loss can be computed directly.
So I don't understand what you mean by "convert the prediction results back to the original coordinate system".
Could you explain this in more detail?
Thanks
I'm not sure how you evaluate the model. If you evaluate it by submitting to the online evaluation server, pay attention to the coordinate system and make sure it is consistent with the one used by the evaluation server. If you evaluate it offline using the metrics implemented in this codebase, it should be fine, since I have converted the ground truth into agent-centric local coordinate systems.
One more thing to note: some datasets have missing values in the ground truth, and you may get abnormal results if you don't mask out those missing values correctly when computing the metrics. The metrics in this codebase assume that the ground-truth trajectories are always complete, since Argoverse 1 doesn't have this problem.
For example, the code snippet below is an implementation of ADE that takes missing values into account:
```python
from typing import Optional

import torch
from torchmetrics import Metric


class ADE(Metric):
    def __init__(self, **kwargs) -> None:
        super(ADE, self).__init__(**kwargs)
        self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')

    def update(self,
               pred: torch.Tensor,
               target: torch.Tensor,
               valid_mask: Optional[torch.Tensor] = None) -> None:
        # pred, target: [num_agents, num_steps, 2]; valid_mask: [num_agents, num_steps]
        if valid_mask is None:
            # No mask given: treat every ground-truth step as valid.
            valid_mask = target.new_ones(target.size()[:-1], dtype=torch.bool)
        # Zero out errors at missing steps, then average over each agent's valid steps.
        self.sum += ((torch.norm(pred - target, p=2, dim=-1) * valid_mask).sum(dim=-1) / valid_mask.sum(dim=-1)).sum()
        self.count += pred.size(0)

    def compute(self) -> torch.Tensor:
        return self.sum / self.count
```
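Since the discussion is about minADE on multi-modal predictions, the same masking idea can be carried over to a best-of-K metric. This is a minimal sketch, not the codebase's metric: it assumes predictions shaped `[num_agents, K, num_steps, 2]`, selects the best mode by masked ADE (an assumption; some implementations pick the mode with the lowest final displacement error instead), and skips agents with no valid future step, which would otherwise produce a 0/0 division in the `ADE` snippet above. `masked_min_ade` is a hypothetical name.

```python
import torch


def masked_min_ade(pred: torch.Tensor,
                   target: torch.Tensor,
                   valid_mask: torch.Tensor) -> torch.Tensor:
    """minADE over K modes, ignoring missing ground-truth steps.

    pred:       [num_agents, K, num_steps, 2] multi-modal predictions (assumed layout).
    target:     [num_agents, num_steps, 2] ground truth in the same frame.
    valid_mask: [num_agents, num_steps] True where the ground truth exists.
    """
    # Per-step error for every mode: [num_agents, K, num_steps].
    err = torch.norm(pred - target.unsqueeze(1), p=2, dim=-1)
    mask = valid_mask.unsqueeze(1).to(err.dtype)
    num_valid = valid_mask.sum(dim=-1)  # [num_agents]
    # Mean error over valid steps; clamp avoids 0/0 for agents with no valid step.
    ade = (err * mask).sum(dim=-1) / num_valid.clamp(min=1).unsqueeze(-1)  # [num_agents, K]
    min_ade = ade.min(dim=-1).values  # best mode per agent
    keep = num_valid > 0  # agents with no observed future are excluded
    return min_ade[keep].mean()


target = torch.zeros(2, 3, 2)
valid = torch.tensor([[True, True, False], [False, False, False]])
pred = torch.zeros(2, 2, 3, 2)
pred[0, 0] = torch.tensor([[3.0, 4.0], [3.0, 4.0], [100.0, 0.0]])  # mode 0: error 5 on valid steps
pred[0, 1] = torch.tensor([[1.0, 0.0], [0.0, 1.0], [100.0, 0.0]])  # mode 1: error 1 on valid steps
result = masked_min_ade(pred, target, valid)  # tensor(1.): the 100 m error at the missing step is ignored
```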
Thanks for your reply.
I changed the ADE metric and also enlarged the receptive field to 200 m, but the results barely changed.
minADE still converges to about 4 and does not decrease further.
I also tried to overfit the network, but minADE still stays at 4.
Do you plan to run experiments on the Waymo dataset?
I'm not sure where the problem is.
Related Issues (20)
- A question about AAEncoder
- Qualitative Results Visualization
- Test set performance
- What do "semantic attributes" represent?
- torch.load bottleneck?
- Prediction Results for non-agent objects
- Code for Qualitative Results analysis
- About the loss: reg loss and cls loss
- How to obtain the ADE/FDE/MR result of test set?
- About no requirements.txt
- Code Question
- Error raised by the evaluation metrics
- about how to generate test result
- The meaning about bos_mask
- Bit-wise NOT operation "~" for padding_mask in data.
- What does 'parallel' in args mean?
- HiVT map-free customization
- "solving environment killed" while "conda install pytorch-lightning==1.5.2 -c conda-forge"
- cuda and torch version violate
- When running a second time without regenerating the .pt files, what code should be commented out or what else needs to be done?