
social-nce's Introduction

Social-NCE + CrowdNav

Website | Paper | Video | Social NCE + Trajectron | Social NCE + STGCNN

This is the official implementation of
Social NCE: Contrastive Learning of Socially-aware Motion Representations
Yuejiang Liu, Qi Yan, Alexandre Alahi, ICCV 2021

TL;DR: Contrastive Representation Learning + Negative Data Augmentations 🡲 Robust Neural Motion Models
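The TL;DR mentions contrastive representation learning. As a rough illustration, the sketch below computes a generic InfoNCE-style loss for a single query in pure Python: the similarity to the positive competes against the similarities to the negatives in a temperature-scaled softmax. It is not the repository's implementation. The function name `info_nce`, the dot-product similarity, and the default temperature of 0.2 (mirroring `temperature-0.20` in the output paths of the training commands) are assumptions.

```python
import math


def info_nce(query, positive, negatives, temperature=0.2):
    """Generic InfoNCE loss for one query (illustrative sketch, not the repo's code).

    The positive sits at index 0 of the logits; the loss is the cross-entropy
    of picking it among all candidates.
    """
    def dot(a, b):
        return sum(x * y for x, y in zip(a, b))

    # Temperature-scaled similarities: positive first, then each negative.
    logits = [dot(query, positive) / temperature]
    logits += [dot(query, neg) / temperature for neg in negatives]

    # Numerically stable log-sum-exp over all candidates.
    peak = max(logits)
    log_sum_exp = peak + math.log(sum(math.exp(v - peak) for v in logits))

    # Equivalent to -log(softmax(logits)[0]).
    return log_sum_exp - logits[0]


# Hypothetical 2D embeddings: the positive is roughly aligned with the query.
query = [1.0, 0.0]
positive = [0.9, 0.1]
negatives = [[-1.0, 0.0], [0.0, 1.0]]
print(info_nce(query, positive, negatives))
```

Lowering the temperature sharpens the softmax, so negatives that look similar to the query are penalized more strongly.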

[New] Our more recent work on this topic:
Towards Robust and Adaptive Motion Forecasting: A Causal Representation Perspective, CVPR 2022.

Preparation

Set up the environment by following SETUP.md.

Training & Evaluation

  • Behavioral Cloning (Vanilla)
    python imitate.py --contrast_weight=0.0 --gpu
    python test.py --policy='sail' --circle --model_file=data/output/imitate-baseline-data-0.50/policy_net.pth
    
  • Social-NCE + Conventional Negative Sampling (Local)
    python imitate.py --contrast_weight=2.0 --contrast_sampling='local' --gpu
    python test.py --policy='sail' --circle --model_file=data/output/imitate-local-data-0.50-weight-2.0-horizon-4-temperature-0.20-nboundary-0-range-2.00/policy_net.pth
    
  • Social-NCE + Safety-driven Negative Sampling (Ours)
    python imitate.py --contrast_weight=2.0 --contrast_sampling='event' --gpu
    python test.py --policy='sail' --circle --model_file=data/output/imitate-event-data-0.50-weight-2.0-horizon-4-temperature-0.20-nboundary-0/policy_net.pth
    
  • Method Comparison
    bash script/run_vanilla.sh && bash script/run_local.sh && bash script/run_snce.sh
    python utils/compare.py
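The training commands write to output directories whose names encode the hyperparameters, and the test commands must point at the matching path. The hypothetical helper `model_path` below rebuilds the three paths shown above from their parts. The naming pattern is inferred from these three examples only, so other flags may change the directory name in ways this sketch does not cover.

```python
def model_path(sampling, data=0.50, weight=2.0, horizon=4, temperature=0.20,
               nboundary=0, sample_range=None):
    """Rebuild a policy_net.pth path following the README's examples (hypothetical helper)."""
    if sampling == "baseline":
        # Vanilla behavioral cloning: only the data fraction appears in the name.
        name = f"imitate-baseline-data-{data:.2f}"
    else:
        name = (f"imitate-{sampling}-data-{data:.2f}-weight-{weight:.1f}"
                f"-horizon-{horizon}-temperature-{temperature:.2f}-nboundary-{nboundary}")
        # The 'local' example additionally carries a sampling range suffix.
        if sample_range is not None:
            name += f"-range-{sample_range:.2f}"
    return f"data/output/{name}/policy_net.pth"


# Build the evaluation command for the safety-driven (event) model.
print(f"python test.py --policy='sail' --circle --model_file={model_path('event')}")
```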
    

Basic Results

Results of behavioral cloning with different methods.

Averaged results from the 150th to 200th epochs.

Method    Collision       Reward
Vanilla   12.7% ± 3.8%    0.274 ± 0.019
Local     19.3% ± 4.2%    0.240 ± 0.021
Ours       2.0% ± 0.6%    0.331 ± 0.003
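To put the table in perspective, the snippet below computes the relative change in mean collision rate against the Vanilla baseline. It uses only the means and ignores the reported spreads.

```python
# Mean collision rates (%) from the table above.
collision = {"Vanilla": 12.7, "Local": 19.3, "Ours": 2.0}


def relative_reduction(baseline, method):
    """Relative reduction of `method` w.r.t. `baseline`, in percent (negative = increase)."""
    return 100.0 * (baseline - method) / baseline


for name in ("Local", "Ours"):
    change = relative_reduction(collision["Vanilla"], collision[name])
    print(f"{name}: {change:.1f}%")
# Local: -52.0%
# Ours: 84.3%
```

In other words, safety-driven sampling cuts the mean collision rate by roughly 84% relative to vanilla behavioral cloning, while local sampling raises it by about 52%.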

Citation

If you find this code useful for your research, please cite our papers:

@inproceedings{liu2021social,
  title={Social nce: Contrastive learning of socially-aware motion representations},
  author={Liu, Yuejiang and Yan, Qi and Alahi, Alexandre},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
  pages={15118--15129},
  year={2021}
}
@inproceedings{chen2019crowd,
  title={Crowd-robot interaction: Crowd-aware robot navigation with attention-based deep reinforcement learning},
  author={Chen, Changan and Liu, Yuejiang and Kreiss, Sven and Alahi, Alexandre},
  booktitle={International Conference on Robotics and Automation (ICRA)},
  pages={6015--6022},
  year={2019}
}

social-nce's Issues

Some questions about a training error

Hello, thanks for making your code available.
I get the following error when I run "python imitate.py --contrast_weight=2.0 --contrast_sampling='local' --gpu" in the crowd_nav directory:
RuntimeError: CUDA error: CUBLAS_STATUS_INVALID_VALUE when calling cublasSgemm( handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)
It seems that the linear layer in the attention module receives inputs with the wrong shapes. What should I do?
Thanks :)

I cannot find "Imitator" in crowd_nav.utils.trainer

Thank you for your hard work.
In the file "crowd_nav/utils/demonstrate.py":

    from crowd_nav.utils.frames import FrameStack
    from crowd_nav.utils.trainer import Imitator
    from crowd_nav.utils.memory import ReplayMemory

I cannot find "Imitator" in crowd_nav.utils.trainer.
Could you help me?

Error when testing with the visualize option

I tried to run the test.py script with the --visualize option:

    python test.py --policy='sail' --circle --model_file=data/output/imitate-event-data-0.50-weight-2.0-horizon-4-temperature-0.20-nboundary-0/policy_net.pth --visualize

and got the error below:

Traceback (most recent call last):
  File "test.py", line 126, in <module>
    main()
  File "test.py", line 105, in main
    action = robot.act(ob)
  File "/home/vale/github/mpc-nav/social-nce/crowd_sim/envs/utils/robot.py", line 13, in act
    action = self.policy.predict(state)
  File "/home/vale/github/mpc-nav/social-nce/crowd_nav/policy/sail.py", line 122, in predict
    self.last_state = self.transform(state)
  File "/home/vale/github/mpc-nav/social-nce/crowd_nav/policy/sail.py", line 134, in transform
    num_human = len(state.human_states)
TypeError: object of type 'ObservableState' has no len()

Can anyone help me with that issue?

Training on TrajNet++

Hi, I am about to try your code on TrajNet++.

I am trying to reproduce the FDE score of 1.14.

Did you train on the whole training dataset (with cff)?
For how many epochs?
And with which parameters?

If I reach your performance, I will delete the submission if desired.

Thanks in advance
Many greetings
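For reference, FDE (final displacement error) is the Euclidean distance between the predicted and ground-truth positions at the last predicted step, and ADE (average displacement error) averages that distance over all predicted steps. The sketch below computes both for a single trajectory using `math.dist` (Python 3.8+). The TrajNet++ evaluator may aggregate differently, for example over scenes or multiple predicted samples, so treat this as an illustration rather than the benchmark's scoring code.

```python
import math


def ade_fde(pred, gt):
    """Average and final displacement error between two 2D trajectories.

    `pred` and `gt` are equal-length sequences of (x, y) positions over the
    prediction horizon.
    """
    if not pred or len(pred) != len(gt):
        raise ValueError("trajectories must be non-empty and of equal length")
    # Per-step Euclidean distance between prediction and ground truth.
    dists = [math.dist(p, g) for p, g in zip(pred, gt)]
    return sum(dists) / len(dists), dists[-1]


# Toy example with made-up coordinates.
ade, fde = ade_fde([(0, 0), (1, 0), (2, 0)], [(0, 0), (1, 1), (2, 2)])
print(ade, fde)  # 1.0 2.0
```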

Got an error when training the Social-NCE model

Hello,

I tried to train a Social-NCE model on the GPU using python imitate.py --contrast_weight=2.0 --contrast_sampling='event' --gpu, and got the error below:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mask in method wrapper__masked_select)

Can anyone help me fix this issue? Thanks in advance.

Some issues when testing with the square and mixed scenarios

Hello, my team and I are analyzing your work (both CrowdNav and Social-NCE) as a term project.
While running test.py, I ran into two problems.
1st: In the square scenario, performance drops sharply. Why?

!python test.py --policy='sail' --square --model_file=data/output/imitate-event-data-0.50-weight-2.0-horizon-4-temperature-0.20-nboundary-0/policy_net.pth
(skip)
2021-12-05 08:55:59, INFO: TEST success: 0.58, collision: 0.41, nav time: 10.43, reward: 0.1095 +- 0.2644
2021-12-05 08:55:59, INFO: Frequency of being in danger: 1.23

2nd: In the mixed scenario (after adding the corresponding argument to the parser), an error occurs:
Traceback (most recent call last):
  File "test.py", line 129, in <module>
    main()
  File "test.py", line 126, in main
    explorer.run_k_episodes(env.case_size[args.phase], args.phase, print_failure=False)
  File "/content/drive/My Drive/ColabFiles/robotvision/social-nce/crowd_nav/utils/explorer.py", line 60, in run_k_episodes
    action = self.robot.act(ob)
  File "/content/drive/My Drive/ColabFiles/robotvision/social-nce/crowd_sim/envs/utils/robot.py", line 13, in act
    action = self.policy.predict(state)
  File "/content/drive/My Drive/ColabFiles/robotvision/social-nce/crowd_nav/policy/sail.py", line 123, in predict
    action = self.model(self.last_state[0], self.last_state[1])[0].squeeze()
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/content/drive/My Drive/ColabFiles/robotvision/social-nce/crowd_nav/policy/sail.py", line 73, in forward
    human_state = self.transform.transform_frame(crowd_obsv)
  File "/content/drive/My Drive/ColabFiles/robotvision/social-nce/crowd_nav/utils/transform.py", line 14, in transform_frame
    state = torch.cat([frame, relative], axis=2)
RuntimeError: Sizes of tensors must match except in dimension 2. Expected size 1 but got size 5 for tensor number 1 in the list.

If you already know about these problems, please share the answers. Thank you for reading.
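The last line of that traceback comes from `torch.cat`, which requires every input to have the same size in all dimensions except the one being concatenated. The pure-Python function `cat_shape` below (a hypothetical helper) mimics that rule on shape tuples to show how the message arises. The shapes in the example are made up (a size-1 dimension next to a size-5 one) and are not the actual shapes inside `transform_frame`.

```python
def cat_shape(shapes, dim):
    """Shape that torch.cat(tensors, dim) would return, given the input shapes.

    Mimics PyTorch's size check in pure Python; raises ValueError (PyTorch
    raises RuntimeError) when a non-concatenated dimension does not match.
    """
    first = shapes[0]
    for idx, shape in enumerate(shapes[1:], start=1):
        if len(shape) != len(first):
            raise ValueError(f"tensor {idx} has {len(shape)} dims, expected {len(first)}")
        for d, (expected, got) in enumerate(zip(first, shape)):
            if d != dim and expected != got:
                raise ValueError(
                    f"Sizes of tensors must match except in dimension {dim}. "
                    f"Expected size {expected} but got size {got} "
                    f"for tensor number {idx} in the list.")
    # Only the concatenated dimension grows.
    out = list(first)
    out[dim] = sum(shape[dim] for shape in shapes)
    return tuple(out)


print(cat_shape([(1, 5, 4), (1, 5, 2)], dim=2))  # (1, 5, 6)
try:
    cat_shape([(1, 1, 4), (1, 5, 2)], dim=2)
except ValueError as err:
    print(err)
```

Under this reading, `frame` and `relative` disagree in a dimension other than 2, for instance in the number of agents. Printing `frame.shape` and `relative.shape` just before the `torch.cat` call would confirm or rule this out.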

An issue about negative samples in logits

Thanks for your great work!

logits = torch.cat([sim_pos.view(-1).unsqueeze(1), sim_neg.view(sim_neg.size(0), -1).repeat_interleave(self.horizon, dim=0)], dim=1) / self.temperature

I am not sure why the code uses "sim_neg.view(sim_neg.size(0), -1).repeat_interleave(self.horizon, dim=0)" rather than "sim_neg.view(sim_neg.size(0) * sim_neg.size(1), -1)". Should the number of negative samples corresponding to one positive sample be (N-1)*8 or H*(N-1)*8?
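To make the two candidate reshapes concrete, the pure-Python sketch below tracks which negatives each logits row would see. It assumes `sim_pos` has shape (batch, horizon) and `sim_neg` has shape (batch, horizon, num_neg), as the question implies; it does not call PyTorch, and `negatives_per_row` is a hypothetical name. Rows follow the row-major order of `sim_pos.view(-1)`, i.e. one row per (sample, step).

```python
def negatives_per_row(batch, horizon, num_neg, variant):
    """List, for each (sample, step) logits row, the (sample, step, k) negatives it contains."""
    # Stand-in for sim_neg (assumed shape: batch x horizon x num_neg); each entry is its own index.
    sim_neg = [[[(b, t, k) for k in range(num_neg)] for t in range(horizon)]
               for b in range(batch)]
    rows = []
    if variant == "repeat_interleave":
        # sim_neg.view(batch, -1): flatten each sample's horizon x num_neg block into one row.
        flat = [[e for step in sim_neg[b] for e in step] for b in range(batch)]
        # .repeat_interleave(horizon, dim=0): repeat each row `horizon` times, consecutively.
        for b in range(batch):
            rows.extend(list(flat[b]) for _ in range(horizon))
    elif variant == "per_step":
        # sim_neg.view(batch * horizon, -1): one row per (sample, step).
        for b in range(batch):
            for t in range(horizon):
                rows.append(list(sim_neg[b][t]))
    else:
        raise ValueError(f"unknown variant: {variant}")
    return rows


for variant in ("repeat_interleave", "per_step"):
    rows = negatives_per_row(batch=2, horizon=3, num_neg=4, variant=variant)
    print(variant, len(rows), "rows x", len(rows[0]), "negatives")
```

Under these assumptions, the `repeat_interleave` form gives each positive horizon × num_neg negatives drawn from every step of the same sample, whereas `view(batch * horizon, -1)` gives num_neg negatives from the matching step only.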

What does SAIL refer to?

Thanks for your wonderful repo!
May I ask what the SAIL policy refers to?
In the paper I only find SARL, not SAIL. Is SAIL an upgraded version of SARL?
Thanks!

Error when downloading data

When I download the data with "pip install gdown && gdown https://drive.google.com/uc?id=1D2guAxD_EgrKnJFMcLSBkf10SOagz0mr", the following error occurs:
Traceback (most recent call last):
  File "/home/lyl/miniconda3/lib/python3.8/site-packages/urllib3/connection.py", line 169, in _new_conn
    conn = connection.create_connection(
  File "/home/lyl/miniconda3/lib/python3.8/site-packages/urllib3/util/connection.py", line 96, in create_connection
    raise err
  File "/home/lyl/miniconda3/lib/python3.8/site-packages/urllib3/util/connection.py", line 86, in create_connection
    sock.connect(sa)
OSError: [Errno 101] Network is unreachable

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/lyl/miniconda3/lib/python3.8/site-packages/urllib3/connectionpool.py", line 699, in urlopen
    httplib_response = self._make_request(
  File "/home/lyl/miniconda3/lib/python3.8/site-packages/urllib3/connectionpool.py", line 382, in _make_request
    self._validate_conn(conn)
  File "/home/lyl/miniconda3/lib/python3.8/site-packages/urllib3/connectionpool.py", line 1010, in _validate_conn
    conn.connect()
  File "/home/lyl/miniconda3/lib/python3.8/site-packages/urllib3/connection.py", line 353, in connect
    conn = self._new_conn()
  File "/home/lyl/miniconda3/lib/python3.8/site-packages/urllib3/connection.py", line 181, in _new_conn
    raise NewConnectionError(
urllib3.exceptions.NewConnectionError: <urllib3.connection.HTTPSConnection object at 0x7f96a64c1100>: Failed to establish a new connection: [Errno 101] Network is unreachable

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/lyl/miniconda3/lib/python3.8/site-packages/requests/adapters.py", line 439, in send
    resp = conn.urlopen(
  File "/home/lyl/miniconda3/lib/python3.8/site-packages/urllib3/connectionpool.py", line 755, in urlopen
    retries = retries.increment(
  File "/home/lyl/miniconda3/lib/python3.8/site-packages/urllib3/util/retry.py", line 573, in increment
    raise MaxRetryError(_pool, url, error or ResponseError(cause))
urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='drive.google.com', port=443): Max retries exceeded with url: /uc?id=1awXDsRQcmgacj7nUhPzwb5UMNZCJCjvu (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x7f96a64c1100>: Failed to establish a new connection: [Errno 101] Network is unreachable'))

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/lyl/miniconda3/bin/gdown", line 8, in <module>
    sys.exit(main())
  File "/home/lyl/miniconda3/lib/python3.8/site-packages/gdown/cli.py", line 95, in main
    download(
  File "/home/lyl/miniconda3/lib/python3.8/site-packages/gdown/download.py", line 77, in download
    res = sess.get(url, stream=True)
  File "/home/lyl/miniconda3/lib/python3.8/site-packages/requests/sessions.py", line 555, in get
    return self.request('GET', url, **kwargs)
  File "/home/lyl/miniconda3/lib/python3.8/site-packages/requests/sessions.py", line 542, in request
    resp = self.send(prep, **send_kwargs)
  File "/home/lyl/miniconda3/lib/python3.8/site-packages/requests/sessions.py", line 655, in send
    r = adapter.send(request, **kwargs)
  File "/home/lyl/miniconda3/lib/python3.8/site-packages/requests/adapters.py", line 516, in send
    raise ConnectionError(e, request=request)
requests.exceptions.ConnectionError: HTTPSConnectionPool(host='drive.google.com', port=443): Max retries exceeded with url: /uc?id=1awXDsRQcmgacj7nUhPzwb5UMNZCJCjvu (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x7f96a64c1100>: Failed to establish a new connection: [Errno 101] Network is unreachable'))

Can you fix it? Thanks.

About offline RL

I have read your paper (Social-NCE) carefully and am very interested in the part about offline reinforcement learning, but I can't find the corresponding part in this code. Can you tell me how to get it, or recommend some code for it?

Question regarding training

Hello,

First of all, thank you for making the code publicly available! I have two questions to check my understanding of the training:

  1. As far as I understand, you determine the positive and negative samples based on the ground truth at time t, with a horizon of e.g. 5 time steps. What about possible collisions before and after that single point?
  2. What is the ground truth in reinforcement learning? Do you use a linear model here?

Best regards

What does 'sample_boundary.append(neg_seed * alpha + pos_seed.unsqueeze(2) * (1-alpha))' mean?

Thanks for your hard work.

I have some questions about this code:

    for alpha in alpha_list:
        sample_boundary.append(neg_seed * alpha + pos_seed.unsqueeze(2) * (1-alpha))

It looks like α · neg_seed + (1 − α) · pos_seed. What is the purpose of this code?

From sampling.py:

    # primary-neighbor boundary
    if self.num_boundary > 0:
        alpha_list = torch.linspace(self.ratio_boundary, 1.0, steps=self.num_boundary)  # e.g. 0.5 to 1, split into 0-9 steps, set by the parameters
        sample_boundary = []
        for alpha in alpha_list:
            sample_boundary.append(neg_seed * alpha + pos_seed.unsqueeze(2) * (1-alpha))
        sample_boundary = torch.cat(sample_boundary, axis=2)
        sample_neg = torch.cat([sample_boundary, sample_territory], axis=2)
    else:
        sample_neg = sample_territory

Thanks!
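One way to read that branch: each value of `alpha` places a sample on the straight segment between the primary agent's positive seed (`alpha = 0`) and a neighbor-based negative seed (`alpha = 1`), so the boundary samples are negatives shifted toward the positive. The pure-Python sketch below reproduces only that interpolation for one pair of 2D points. `linspace` stands in for `torch.linspace`, `boundary_samples` is a hypothetical name, and the broadcasting that `pos_seed.unsqueeze(2)` performs over neighbors is left out. The output paths in the README all contain `nboundary-0`, which suggests this branch is skipped in those runs.

```python
def linspace(start, end, steps):
    """Pure-Python stand-in for torch.linspace(start, end, steps)."""
    if steps == 1:
        return [start]
    return [start + (end - start) * i / (steps - 1) for i in range(steps)]


def boundary_samples(pos_seed, neg_seed, ratio_boundary, num_boundary):
    """Points alpha * neg + (1 - alpha) * pos for alpha in linspace(ratio_boundary, 1, num_boundary)."""
    samples = []
    for alpha in linspace(ratio_boundary, 1.0, num_boundary):
        # Convex combination: alpha = 1 gives the negative seed itself.
        samples.append(tuple(alpha * n + (1 - alpha) * p for p, n in zip(pos_seed, neg_seed)))
    return samples


# Positive seed at the origin, negative seed two units away along x.
print(boundary_samples((0.0, 0.0), (2.0, 0.0), ratio_boundary=0.5, num_boundary=3))
# [(1.0, 0.0), (1.5, 0.0), (2.0, 0.0)]
```

With `ratio_boundary = 0.5`, no boundary sample lies closer to the positive than the midpoint, which keeps these negatives from collapsing onto the positive itself.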
