
Comments (7)

ZikangZhou commented on June 3, 2024

The coordinate system needs to be transformed twice. First, according to Equation 15 in the paper,

[Screenshot of Equation 15 from the paper]

the outputs are in per-agent local coordinate systems. So the first step is to convert the outputs to the global coordinate system (which is AV-centric) by using data['rotate_angles'] for rotation and data['positions'][:, current_time_step] for translation; the current_time_step for Argoverse 1 is 19. Now the outputs are in an AV-centric coordinate system, so the second step is to convert them back to the original global coordinate system using data['origin'] and data['theta'].
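
For concreteness, here is a minimal sketch of that two-step transform (not the repo's exact code; it assumes y_hat has shape [num_nodes, num_modes, horizon, 2], a single unbatched sequence, and the field names above):

    import torch

    def local_to_global(y_hat, data, current_time_step=19):
        # y_hat: [N, F, H, 2] outputs in each agent's local frame
        # (N nodes, F modes, H future steps).

        # Step 1: agent-local -> AV-centric.
        # Undo the per-agent rotation by rotate_angles (multiply by the transpose)
        # and translate by the agent's position at the current time step.
        alpha = data['rotate_angles']                                      # [N]
        rot_back_agent = torch.stack(
            [torch.stack([torch.cos(alpha), torch.sin(alpha)], dim=-1),
             torch.stack([-torch.sin(alpha), torch.cos(alpha)], dim=-1)],
            dim=-2)                                                        # [N, 2, 2]
        y_av = torch.einsum('nfhi,nij->nfhj', y_hat, rot_back_agent)
        y_av = y_av + data['positions'][:, current_time_step][:, None, None, :]

        # Step 2: AV-centric -> original global frame.
        # Undo the per-sequence normalization (xy - origin) @ R(theta).
        theta = data['theta']                                              # 0-dim tensor for one sequence
        rot_back_seq = torch.stack(
            [torch.stack([torch.cos(theta), torch.sin(theta)]),
             torch.stack([-torch.sin(theta), torch.cos(theta)])])          # [2, 2]
        return y_av @ rot_back_seq + data['origin']

Because rotation matrices are orthogonal, multiplying by the transpose inverts the rotations that were applied during preprocessing.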

By the way, normalizing the data into an AV-centric coordinate system during data preprocessing is redundant. I did this only for the convenience of ablation studies.

SwagJ commented on June 3, 2024

Hi @ZikangZhou,

Sorry to reopen this issue, but is the provided hivt-128 checkpoint the one used for reporting the test-set results in the paper? When I generated test-set predictions using the provided hivt-128 checkpoint, the results were not as good as reported in the paper: the minFDE is 1.232 and the brier-minFDE is 1.926. My prediction_step is as follows.

        y_hat, pi = self(data)
        pi = F.softmax(pi, dim=-1)                           # mode probabilities per node
        y_hat = y_hat.permute(1, 0, 2, 3)                    # [F, N, H, ...] -> [N, F, H, ...]
        y_hat_agent = y_hat[data['agent_index'], :, :, :2]   # keep only the target agents' xy
        pi_agent = pi[data['agent_index'], :]
        if self.rotate:
            data_angles = data['theta']                                        # per-sequence rotation angle
            data_origin = data['origin']                                       # per-sequence origin
            data_rotate_angle = data['rotate_angles'][data['agent_index']]     # per-agent heading angle
            data_local_origin = data['positions'][data['agent_index'], 19, :]  # agent position at t = 19 (AV-centric)
            # Inverse of the per-sequence rotation: R(-theta) = R(theta)^T
            rotate_mat = torch.empty(data['agent_index'].shape[0], 2, 2, device=self.device)
            sin_vals = torch.sin(-data_angles)
            cos_vals = torch.cos(-data_angles)
            rotate_mat[:, 0, 0] = cos_vals
            rotate_mat[:, 0, 1] = -sin_vals
            rotate_mat[:, 1, 0] = sin_vals
            rotate_mat[:, 1, 1] = cos_vals
            # Inverse of the per-agent rotation: R(-alpha) = R(alpha)^T
            rotate_local = torch.empty(data['agent_index'].shape[0], 2, 2, device=self.device)
            sin_vals_angle = torch.sin(-data_rotate_angle)
            cos_vals_angle = torch.cos(-data_rotate_angle)
            rotate_local[:, 0, 0] = cos_vals_angle
            rotate_local[:, 0, 1] = -sin_vals_angle
            rotate_local[:, 1, 0] = sin_vals_angle
            rotate_local[:, 1, 1] = cos_vals_angle
            for i in range(data['agent_index'].shape[0]):
                stacked_rotate_mat = torch.stack([rotate_mat[i]] * self.num_modes, dim=0)
                stacked_rotate_local = torch.stack([rotate_local[i]] * self.num_modes, dim=0)
                # Step 1: agent-local -> AV-centric (rotate back, then translate to the agent's position)
                y_hat_agent[i, :, :, :] = torch.bmm(y_hat_agent[i, :, :, :], stacked_rotate_local) \
                                          + data_local_origin[i].unsqueeze(0).unsqueeze(0)
                # Step 2: AV-centric -> original global frame (rotate back, then translate to the sequence origin)
                y_hat_agent[i, :, :, :] = torch.bmm(y_hat_agent[i, :, :, :], stacked_rotate_mat) \
                                          + data_origin[i].unsqueeze(0).unsqueeze(0)

        return y_hat_agent, pi_agent, data['seq_id']

Is there anything wrong with my coordinate transformation? Also, I found another odd thing: if I keep the original y in the preprocessed .pt file and transform it like this:

x[node_idx, node_steps] = torch.matmul(xy - origin, rotate_mat)
in the network's forward pass, the transformed original y and y_agent differ by about 0.002-0.006. I don't quite understand this.
I am looking forward to your reply.

Best,

SwagJ commented on June 3, 2024

Hi @ZikangZhou,

Thank you for your reply. Yes, there are two transformations: from agent-centric to AV-centric, and then to the global coordinate system. But the result is still a bit weird, so I will double-check whether there is a bug in my implementation.
Also, when I attempted to train on Waymo, the PyTorch Lightning module does not finish the whole training set, although the tqdm bar shows the correct dataloader length; e.g., training runs to iteration 200 out of 15k in each epoch and then validation starts. Have you come across this problem? Thank you in advance.

Best,

ZikangZhou commented on June 3, 2024

No, I haven't encountered such a problem. Maybe you can ask the PyTorch Lightning community.

SwagJ commented on June 3, 2024

I see. Thank you very much.

ZikangZhou commented on June 3, 2024

The code looks fine. The checkpoints are not the ones that I submitted to the server. This codebase is actually a re-implementation of HiVT for the purpose of better readability, so there might be some unintentional mismatches between this codebase and my original one. Plus, I remember that the final submissions on the leaderboard were trained for more epochs with a slightly different training recipe. But thank you for letting me know the test-set performance of the checkpoints. I may try to provide better checkpoints later.

For the second question, I guess this phenomenon is somehow related to different floating-point errors on different computing devices?
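
As a purely illustrative toy example of that guess (mine, not from the codebase): two mathematically equivalent float32 pipelines can already disagree by a small but visible amount when the raw map coordinates are large:

    import math
    import torch

    torch.manual_seed(0)
    # Argoverse global coordinates are large (hundreds to thousands of meters),
    # so float32 has limited absolute precision at that scale.
    origin = torch.tensor([2500.0, 1200.0])
    xy = origin + torch.randn(30, 2) * 50.0          # fake global positions
    theta = 0.7
    rot = torch.tensor([[math.cos(theta), -math.sin(theta)],
                        [math.sin(theta),  math.cos(theta)]])

    # Path A: subtract the origin first, then rotate.
    a = (xy - origin) @ rot
    # Path B: rotate first, then subtract the rotated origin.
    b = xy @ rot - origin @ rot

    print((a - b).abs().max())  # non-zero purely from float32 rounding

The size of the discrepancy grows with the magnitude of the raw coordinates, so transforming before versus after normalization need not give bit-identical results.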

SwagJ commented on June 3, 2024

Hi @ZikangZhou,

Thank you for your reply. I will try your suggestion on training.
Also, regarding the second question, it happens even when training on a single device.
