I am trying to train VMAS with my custom model, which takes both the 2D bird's-eye fully observable game map observation and the continuous observations as input. In the development roadmap, I noticed that you are also planning to implement the 2D bird's-eye view.
I created the following working example with a dummy env.
import ray
import gym
from gym import spaces
import numpy as np
from ray.rllib.agents.ppo import PPOTrainer, DEFAULT_CONFIG
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.models.torch.visionnet import VisionNetwork as TorchVis
from ray.tune.registry import register_env
from ray.rllib.utils.framework import try_import_torch
torch, nn = try_import_torch()
ray.init(local_mode=True)
class TorchCustomModel(TorchModelV2, nn.Module):
    """Two-branch RLlib model: an FC net over the 'fc' observation component
    and a vision net over the 'vis' image component, fused by a linear head.

    The value function is a learned linear mix of the two branches'
    value estimates.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(
            self, obs_space, action_space, num_outputs, model_config, name
        )
        nn.Module.__init__(self)
        # One sub-network per component of the Dict observation space.
        self._fc_branch = TorchFC(
            obs_space.original_space['fc'], action_space, num_outputs,
            model_config, name,
        )
        self._vis_branch = TorchVis(
            obs_space.original_space['vis'], action_space, num_outputs,
            model_config, name,
        )
        # Mixes the two scalar value estimates into a single one.
        self.value_f = nn.Linear(2, 1)
        # Projects the concatenated branch outputs to the action-distribution
        # inputs (mean + log-std per action dim for a continuous Box space).
        self.head = nn.Linear(num_outputs * 2, action_space.shape[0] * 2)

    def forward(self, input_dict, state, seq_lens):
        """Run both branches on their observation component and fuse them."""
        obs = input_dict["obs"]
        fc_features, _ = self._fc_branch({"obs": obs['fc']}, state, seq_lens)
        vis_features, _ = self._vis_branch({"obs": obs['vis']}, state, seq_lens)
        fused = torch.cat((fc_features, vis_features), -1)
        return self.head(fused), []

    def value_function(self):
        """Linear combination of the two branches' value predictions."""
        branch_values = torch.stack(
            [self._fc_branch.value_function(), self._vis_branch.value_function()],
            -1,
        )
        return self.value_f(branch_values).squeeze(-1)
# My custom environment
class MyEnv(gym.Env):
    """Dummy environment emitting a Dict observation: a flat 100-float 'fc'
    vector and a 96x96x3 'vis' image, both uniform in [0, 1).

    Rewards 1 per step and never terminates — only meant to exercise the
    custom two-branch model.
    """

    def __init__(self, env_config):
        self.observation_space = spaces.Dict({
            "fc": spaces.Box(low=0, high=1, shape=(100,)),
            "vis": spaces.Box(low=0, high=1, shape=(96, 96, 3)),
        })
        self.action_space = spaces.Box(low=-1, high=1, shape=(2,))

    def _random_obs(self):
        # Single source of truth for observation construction — this dict was
        # previously duplicated verbatim in reset() and step().
        return {"fc": np.random.rand(100), "vis": np.random.rand(96, 96, 3)}

    def reset(self):
        """Return an initial random observation."""
        return self._random_obs()

    def step(self, action):
        """Ignore the action; return (obs, reward=1, done=False, info={})."""
        return self._random_obs(), 1, False, {}
# Register the env and custom model under names the config can reference.
register_env("my_env", lambda cfg: MyEnv(cfg))
ModelCatalog.register_custom_model("my_model", TorchCustomModel)

# Start from PPO defaults and layer the experiment-specific settings on top.
config = DEFAULT_CONFIG.copy()
config.update({
    'framework': 'torch',
    'model': {
        "custom_model": "my_model",
        "dim": 96,
        # Filters sized so the vision net reduces a 96x96 input to a
        # 1x1 spatial output.
        "conv_filters": [[16, [8, 8], 4], [32, [4, 4], 2], [256, [11, 11], 2]],
    },
    'env': "my_env",
})

trainer = PPOTrainer(config=config)
for iteration in range(10):
    print(trainer.train())