Comments (7)
The coordinate system needs to be transformed twice. According to Equation 15 in the paper, the outputs are in per-agent local coordinate systems, so the first step is to convert the outputs to the global coordinate system (which is AV-centric) by using data['rotate_angles'] for rotation and data['positions'][:, 'current_time_step'] for translation. The 'current_time_step' for Argoverse 1 is 19. Now the outputs are in an AV-centric coordinate system, so the second step is to convert them back to the original global coordinate system using data['origin'] and data['theta'].
By the way, normalizing the data into an AV-centric coordinate system during data preprocessing is redundant. I did this only for the convenience of ablation studies.
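The two-step conversion described above can be sketched in plain Python for a single 2D point. This is only an illustration, not the repository's code: the helper names `to_local` and `to_parent` and all scene values are hypothetical, and the sign convention (normalization subtracts the origin and rotates by the negative angle) is an assumption that should be checked against the dataset preprocessing.

```python
import math

def to_local(point, origin, angle):
    """Normalize a point: subtract the origin, then rotate by -angle (assumed convention)."""
    dx, dy = point[0] - origin[0], point[1] - origin[1]
    c, s = math.cos(angle), math.sin(angle)
    return (c * dx + s * dy, -s * dx + c * dy)

def to_parent(point, origin, angle):
    """Undo to_local: rotate by +angle, then add the origin back."""
    c, s = math.cos(angle), math.sin(angle)
    x, y = point
    return (c * x - s * y + origin[0], s * x + c * y + origin[1])

# Hypothetical scene values, for illustration only.
av_origin, av_theta = (100.0, 50.0), 0.3   # play the role of data['origin'], data['theta']
agent_pos_global, agent_heading_global = (110.0, 55.0), 1.2
future_point_global = (115.0, 62.0)

# Preprocessing direction: global -> AV-centric -> agent-local.
agent_pos_av = to_local(agent_pos_global, av_origin, av_theta)  # like data['positions'][i, 19]
rotate_angle = agent_heading_global - av_theta                  # like data['rotate_angles'][i]
point_av = to_local(future_point_global, av_origin, av_theta)
point_local = to_local(point_av, agent_pos_av, rotate_angle)

# Post-processing direction: agent-local -> AV-centric -> original global.
step1 = to_parent(point_local, agent_pos_av, rotate_angle)
recovered = to_parent(step1, av_origin, av_theta)
```

If the convention matches the preprocessing, `recovered` should equal `future_point_global` up to floating-point error. A round trip like this on ground-truth points is a cheap way to check a post-processing routine before submitting predictions.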
from hivt.
Hi @ZikangZhou,
Sorry to reopen this issue, but is the provided hivt-128 checkpoint the one used to report the test set results in the paper? When I generated test set predictions using the provided hivt-128, the results were not as good as reported in the paper: the minFDE is 1.232 and the brier-minFDE is 1.926. The prediction step is as follows.
y_hat, pi = self(data)
pi = F.softmax(pi)
y_hat = y_hat.permute(1, 0, 2, 3)
y_hat_agent = y_hat[data['agent_index'], :, :, :2]
pi_agent = pi[data['agent_index'], :]
if self.rotate:
    data_angles = data['theta']
    data_origin = data['origin']
    data_rotate_angle = data['rotate_angles'][data['agent_index']]
    data_local_origin = data.positions[data['agent_index'], 19, :]
    rotate_mat = torch.empty(data['agent_index'].shape[0], 2, 2, device=self.device)
    sin_vals = torch.sin(-data_angles)
    cos_vals = torch.cos(-data_angles)
    rotate_mat[:, 0, 0] = cos_vals
    rotate_mat[:, 0, 1] = -sin_vals
    rotate_mat[:, 1, 0] = sin_vals
    rotate_mat[:, 1, 1] = cos_vals
    rotate_local = torch.empty(data['agent_index'].shape[0], 2, 2, device=self.device)
    sin_vals_angle = torch.sin(-data_rotate_angle)
    cos_vals_angle = torch.cos(-data_rotate_angle)
    rotate_local[:, 0, 0] = cos_vals_angle
    rotate_local[:, 0, 1] = -sin_vals_angle
    rotate_local[:, 1, 0] = sin_vals_angle
    rotate_local[:, 1, 1] = cos_vals_angle
    for i in range(data['agent_index'].shape[0]):
        stacked_rotate_mat = torch.stack([rotate_mat[i]] * self.num_modes, dim=0)
        stacked_rotate_local = torch.stack([rotate_local[i]] * self.num_modes, dim=0)
        # print("input shape:", y_hat_agent[i, :, :, :].shape, stacked_rotate_mat.shape, data_origin[i].shape)
        y_hat_agent[i, :, :, :] = torch.bmm(y_hat_agent[i, :, :, :], stacked_rotate_local) \
            + data_local_origin[i].unsqueeze(0).unsqueeze(0)
        y_hat_agent[i, :, :, :] = torch.bmm(y_hat_agent[i, :, :, :], stacked_rotate_mat) \
            + data_origin[i].unsqueeze(0).unsqueeze(0)
return y_hat_agent, pi_agent, data['seq_id']
Is there anything wrong with my coordinate transformation? Also, I found another weird thing: if I keep the original y in the preprocessed .pt file and transform it following this:
HiVT/datasets/argoverse_v1_dataset.py
Line 131 in 8a2e8cd
I am looking forward to your reply.
Best,
Hi @ZikangZhou,
Thank you for your reply. Yes, there are two transformations: from agent-centric to AV-centric, then to the global coordinate system. But the result is still a bit weird. I will double-check whether there is a bug in my implementation.
Also, when I attempted to train on Waymo, the PyTorch Lightning module did not go through the whole training set, although the tqdm bar shows the correct dataloader length. For example, training runs to about iteration 200 out of 15k each time, and then validation starts. Have you come across this problem? Thank you in advance.
Best,
No, I haven't encountered such a problem. Maybe you can ask the PyTorch Lightning community.
I see. Thank you very much.
The code looks fine. The checkpoints are not the ones that I submitted to the server. This codebase is actually a re-implementation of HiVT for the purpose of better readability, so there might exist some unintentional mismatch between this codebase and my original one. Plus, I remember that the final submissions on the leaderboard were trained for more epochs with a slightly different training recipe. But thank you for letting me know the test set performance of the checkpoints. I may try to provide better checkpoints later.
For the second question, I guess this phenomenon is somehow related to the different floating point errors on different computing devices?
Hi @ZikangZhou,
Thank you for your reply. I will try your suggestion on training.
Also, regarding the second question, it happens when training on a single device.