Comments (8)
@bzeni1 Hi, could you share a minimal example that lets me reproduce your issue? It sounds like your code is simply incorrect.
btw, when you instantiate algorithms, you need to do as follows:
ddpg = d3rlpy.algos.DDPGConfig().create()
from d3rlpy.
@takuseno Hi, find my code below. What could be the problem? Thank you in advance for your assistance on this matter.
processed_data = race_data.copy()
settings_columns = [...]  # 22 selected columns from my dataset
processed_data['reward'] = processed_data['ACCELERATION_m_s2']
for col in settings_columns:
    processed_data[f'state_{col}'] = processed_data[col]
for col in settings_columns:
    processed_data[f'action_{col}'] = processed_data[col].diff().fillna(0)
for col in settings_columns:
    processed_data[f'next_{col}'] = processed_data[col].shift(-1)
# end of an episode (race)
processed_data['done'] = processed_data['race_num'].diff(-1) != 0
processed_data = processed_data[processed_data['done'] == False]
print("After filtering rows:", processed_data.shape)
print("States shape:", processed_data[settings_columns].shape)
print("Actions shape:", processed_data[[f'action_{col}' for col in settings_columns]].shape)
print("Rewards shape:", processed_data['reward'].shape)
print("Next states shape:", processed_data[[f'next_{col}' for col in settings_columns]].shape)
print("Dones shape:", processed_data['done'].shape)
**Output**
States shape: (17642, 22)
Actions shape: (17642, 22)
Rewards shape: (17642,)
Next states shape: (17642, 22)
Dones shape: (17642,)
states = processed_data[[f'state_{col}' for col in settings_columns if f'state_{col}' in processed_data.columns]].to_numpy()
actions = processed_data[[col for col in processed_data.columns if col.startswith('action_')]].to_numpy()
rewards = processed_data['reward'].to_numpy()
next_states = processed_data[[col for col in processed_data.columns if col.startswith('next_')]].to_numpy()
dones = processed_data['done'].to_numpy()
#next step:
dataset = MDPDataset(states, actions, rewards, next_states, dones)
#ValueError: operands could not be broadcast together with shapes (388124,) (17642,)
Thanks for sharing your code. It looks like `next_states` is unnecessary. It needs to be as follows:
dataset = MDPDataset(states, actions, rewards, dones)
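A note on the shapes in the error above: 388124 equals 17642 × 22, which is the size of a flattened (17642, 22) array. That is consistent with `next_states` landing in a positional slot where a per-step flag array is expected, though the thread does not confirm the exact code path. The NumPy-only sketch below reproduces the same message under that assumption; the flattening and the `np.logical_and` call are stand-ins chosen for illustration, not d3rlpy's actual code.

```python
import numpy as np

n_steps, n_features = 17642, 22

# Stand-ins for the arrays in the question (zeros, only the shapes matter).
next_states = np.zeros((n_steps, n_features), dtype=np.float32)
dones = np.zeros(n_steps, dtype=bool)

# Assumption: an argument expected to hold per-step flags is flattened before use.
flattened = next_states.reshape(-1)
print(flattened.shape)  # (388124,)

try:
    # Any elementwise operation between these two 1-D arrays fails to broadcast.
    np.logical_and(flattened, dones)
    message = ""
except ValueError as exc:
    message = str(exc)

print(message)
```

Passing only `states, actions, rewards, dones`, as suggested above, keeps every positional argument at one entry per step.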
Thanks for your advice. After removing next_states, I am encountering a new issue:
ValueError: Either episodes or env must be provided to determine signatures. Or specify signatures directly.
I already defined the segments with the 'done' flags, but I still don't know how to determine the episodes. What do you think?
My guess is that `dones` is all zeros, so no episodes could be found. You need to set up `dones` correctly.
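In the code posted earlier, the line `processed_data = processed_data[processed_data['done'] == False]` keeps only the rows where `done` is False, so the `dones` array built afterwards contains no True values. Below is a minimal sketch of one way to derive the flags without dropping them, using NumPy only. The `race_num` array is a small made-up stand-in for that column, and treating the last row of each race as terminal is an assumption about what should end an episode.

```python
import numpy as np

# Hypothetical episode ids, standing in for the 'race_num' column.
race_num = np.array([1, 1, 1, 2, 2, 3, 3, 3, 3])

# A step is terminal when the next row belongs to a different race;
# the final row closes the last race.
terminals = np.append(race_num[1:] != race_num[:-1], True).astype(np.float32)

print(terminals)             # [0. 0. 1. 0. 1. 0. 0. 0. 1.]
print(int(terminals.sum()))  # 3

# Keeping only rows with terminals == 0 (as the filter in the question does)
# leaves no episode boundaries at all.
filtered = terminals[terminals == 0]
print(bool(filtered.any()))  # False
```

Checking `terminals.sum()` before building the dataset is a quick way to confirm that at least one episode boundary survived preprocessing.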
Hi
I think I am running into a similar issue. I have two datasets. For both of them, all the dimensions are the same:
observations: (5000, 4), actions: (5000, 2), rewards: (5000,), terminals: (5000,)
But with one of the datasets, the fit function for IQL fails, although with a different error. I can see that both datasets have some terminals = 1.
Any suggestions for where an error like this might come up?
`
File "/home/rohan/anaconda3/envs/franka/lib/python3.10/site-packages/d3rlpy/algos/qlearning/base.py", line 409, in fit
results = list(
File "/home/rohan/anaconda3/envs/franka/lib/python3.10/site-packages/d3rlpy/algos/qlearning/base.py", line 543, in fitter
loss = self.update(batch)
File "/home/rohan/anaconda3/envs/franka/lib/python3.10/site-packages/d3rlpy/algos/qlearning/base.py", line 863, in update
loss = self._impl.update(torch_batch, self._grad_step)
File "/home/rohan/anaconda3/envs/franka/lib/python3.10/site-packages/d3rlpy/torch_utility.py", line 365, in wrapper
return f(self, *args, **kwargs) # type: ignore
File "/home/rohan/anaconda3/envs/franka/lib/python3.10/site-packages/d3rlpy/algos/qlearning/base.py", line 70, in update
return self.inner_update(batch, grad_step)
File "/home/rohan/anaconda3/envs/franka/lib/python3.10/site-packages/d3rlpy/algos/qlearning/torch/ddpg_impl.py", line 118, in inner_update
metrics.update(self.update_critic(batch))
File "/home/rohan/anaconda3/envs/franka/lib/python3.10/site-packages/d3rlpy/algos/qlearning/torch/ddpg_impl.py", line 84, in update_critic
loss = self.compute_critic_loss(batch, q_tpn)
File "/home/rohan/anaconda3/envs/franka/lib/python3.10/site-packages/d3rlpy/algos/qlearning/torch/iql_impl.py", line 73, in compute_critic_loss
q_loss = self._q_func_forwarder.compute_error(
File "/home/rohan/anaconda3/envs/franka/lib/python3.10/site-packages/d3rlpy/models/torch/q_functions/ensemble_q_function.py", line 256, in compute_error
return compute_ensemble_q_function_error(
File "/home/rohan/anaconda3/envs/franka/lib/python3.10/site-packages/d3rlpy/models/torch/q_functions/ensemble_q_function.py", line 96, in compute_ensemble_q_function_error
loss = forwarder.compute_error(
File "/home/rohan/anaconda3/envs/franka/lib/python3.10/site-packages/d3rlpy/models/torch/q_functions/mean_q_function.py", line 130, in compute_error
value = self._q_func(observations, actions).q_value
File "/home/rohan/anaconda3/envs/franka/lib/python3.10/site-packages/d3rlpy/models/torch/q_functions/base.py", line 35, in __call__
return super().__call__(x, action) # type: ignore
File "/home/rohan/anaconda3/envs/franka/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/rohan/anaconda3/envs/franka/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/home/rohan/anaconda3/envs/franka/lib/python3.10/site-packages/d3rlpy/models/torch/q_functions/mean_q_function.py", line 99, in forward
q_value=self._fc(self._encoder(x, action)),
File "/home/rohan/anaconda3/envs/franka/lib/python3.10/site-packages/d3rlpy/models/torch/encoders.py", line 41, in __call__
return super().__call__(x, action)
File "/home/rohan/anaconda3/envs/franka/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/rohan/anaconda3/envs/franka/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/home/rohan/anaconda3/envs/franka/lib/python3.10/site-packages/d3rlpy/models/torch/encoders.py", line 284, in forward
return self._layers(x)
File "/home/rohan/anaconda3/envs/franka/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/rohan/anaconda3/envs/franka/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/home/rohan/anaconda3/envs/franka/lib/python3.10/site-packages/torch/nn/modules/container.py", line 217, in forward
input = module(input)
File "/home/rohan/anaconda3/envs/franka/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/rohan/anaconda3/envs/franka/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/home/rohan/anaconda3/envs/franka/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 116, in forward
return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (256x6 and 5x256)
`
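Reading the last line of the traceback: `mat1` has shape 256x6, which would match a batch of 256 rows with 4 observation values and 2 action values concatenated, while `mat2` is 5x256, the shape of a transposed weight for a linear layer expecting 5 inputs. One possible reading, not confirmed anywhere in the thread, is that the network was built for a different observation or action size than the batches it later received. The NumPy sketch below only reproduces the shape mismatch; the batch size and hidden size are read off the error message, and everything else is illustrative.

```python
import numpy as np

batch_size = 256
observations = np.zeros((batch_size, 4), dtype=np.float32)
actions = np.zeros((batch_size, 2), dtype=np.float32)

# Encoder input: observation and action concatenated per row.
x = np.concatenate([observations, actions], axis=1)
print(x.shape)  # (256, 6)

# Transposed weight of a first layer built for 5 input features.
weight_t = np.zeros((5, 256), dtype=np.float32)

try:
    x @ weight_t
    failed = False
except ValueError:
    # 6 input columns cannot be multiplied against 5 weight rows.
    failed = True

print(failed)  # True
```

Comparing the observation and action shapes of the two datasets against the shapes the algorithm was built with would be a reasonable first check.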
@rohanblueboybaijal Sorry for the late response. Could you share a minimal example that lets me reproduce your error?
Let me close this issue since the initial question should be resolved. Feel free to open a new issue to follow up.