dongminlee94 / meta-learning-for-everyone
Code repository for the book "Meta-Learning for Everyone"
License: Apache License 2.0
The bug in RL^2's buffer needs to be fixed.
In the original MAML paper, the authors implement the value function (vf) as a linear function and refit it to the batch data every time.
The Learn2Learn repo implements the vf exactly as in the paper.
However, in the Ray Project repo, the vf is implemented as a 2-layer neural network and updated simultaneously with the policy network. In this case, the vf is shared across tasks and iterations.
We need to decide which of these vf implementations to adopt.
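For reference, the paper-style option can be sketched as a linear baseline refit to each batch by least squares. The snippet below is a pure-Python illustration with a single hypothetical time-step feature; a real implementation would use richer observation features and a torch least-squares solver, and all names here are illustrative.

```python
# Sketch of the paper-style value baseline: a linear function refit to
# each batch by least squares (hypothetical, time-step feature only).
def fit_linear_baseline(returns):
    """Fit v(t) = a * t + b to per-step returns via the normal equations."""
    n = len(returns)
    ts = list(range(n))
    s_t = sum(ts)
    s_tt = sum(t * t for t in ts)
    s_r = sum(returns)
    s_tr = sum(t * r for t, r in zip(ts, returns))
    det = n * s_tt - s_t * s_t
    a = (n * s_tr - s_t * s_r) / det
    b = (s_tt * s_r - s_t * s_tr) / det
    return a, b

# Returns that grow linearly in t are fit exactly: a = 1.0, b = 0.0.
a, b = fit_linear_baseline([0.0, 1.0, 2.0, 3.0])
```

Refitting per batch (rather than sharing one vf across tasks, as in the Ray implementation) keeps the baseline task-specific at the cost of an extra solve per iteration.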
self.tasks uses a dictionary unnecessarily:
directions = [-1, 1, -1, 1]
self.tasks = [{"direction": direction} for direction in directions]
https://github.com/dongminlee94/meta-rl/blob/develop/src/envs/half_cheetah_dir.py#L19
self._goal is not used in either the cheetah-dir or the cheetah-vel environment:
def __init__(self, num_tasks=2, seed=0):
super().__init__(render=False)
self.tasks = self.sample_tasks(num_tasks)
self._goal_vel = self.tasks[0].get("velocity", 0.0)
self._goal = self._goal_vel
self._task = None
self._alive = None
self.rewards = None
self.potential = None
self.seed(seed)
def reset_task(self, index):
"""Reset velocity target to index of task"""
self._task = self.tasks[index]
self._goal_vel = self._task["velocity"]
self._goal = self._goal_vel
self.reset()
To fix this, change self.tasks from a dict to a list and remove the unused self._goal attribute.
This issue is to gather the common modules (e.g., sampler, buffer) into one place, because the repo currently contains too much duplicated code.
Currently, running this code on Colab fails to import torchmeta.
Suspecting a Python version issue, I tried running it with Python 3.7;
the package installs fine, but the import still fails.
Has anyone solved this?
Installing collected packages: urllib3, typing-extensions, tqdm, Pillow, ordered-set, numpy, idna, charset-normalizer, certifi, torch, requests, h5py, torchvision, torchmeta
Successfully installed Pillow-9.5.0 certifi-2022.12.7 charset-normalizer-3.1.0 h5py-3.8.0 idna-3.4 numpy-1.21.6 ordered-set-4.1.0 requests-2.30.0 torch-1.9.1 torchmeta-1.8.0 torchvision-0.10.1 tqdm-4.65.0 typing-extensions-4.5.0 urllib3-2.0.2
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
The message argv[0]= is printed at the beginning of training.
All linters' configurations can be integrated into setup.cfg.
https://github.com/annotation-ai/python-project-template/blob/main/setup.cfg
You can see the integrated configurations for black, isort, mypy, flake8, pylint, and pytest.
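As a sketch, a consolidated setup.cfg might look like the fragment below. All values are hypothetical; also note that some tools (black in particular) read their configuration from pyproject.toml rather than setup.cfg, so full consolidation may not be possible for every linter.

```ini
; Hypothetical sketch of a consolidated linter configuration
[isort]
profile = black
line_length = 88

[flake8]
max-line-length = 88
extend-ignore = E203

[mypy]
ignore_missing_imports = True

[tool:pytest]
testpaths = tests
```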
This issue is to add type annotations to the repo's code.
The items in the list below must be added.
This issue is related to #22.
Currently, get_action is misplaced in sampler.py.
To remove it, we need to check whether there are any problems with putting the agent into sampler.py.
Recommended code: self.agent.policy = inner_policy
We need to determine how to implement the meta-test process of the MAML algorithm.
Change the code from
def collect_train_data(self, task_index, max_samples, update_posterior, add_to_enc_buffer):
"""Data collecting for meta-train"""
self.agent.encoder.clear_z()
self.agent.policy.is_deterministic = False
cur_samples = 0
while cur_samples < max_samples:
trajs, num_samples = self.sampler.obtain_samples(
max_samples=max_samples - cur_samples,
update_posterior=update_posterior,
accum_context=False,
)
cur_samples += num_samples
self.rl_replay_buffer.add_trajs(task_index, trajs)
if add_to_enc_buffer:
self.encoder_replay_buffer.add_trajs(task_index, trajs)
if update_posterior:
context_batch = self.sample_context([task_index])
self.agent.encoder.infer_posterior(context_batch)
def obtain_samples(self, max_samples, update_posterior, accum_context=True):
"""Obtain samples up to the number of maximum samples"""
trajs = []
cur_samples = 0
while cur_samples < max_samples:
traj = self.rollout(accum_context=accum_context)
trajs.append(traj)
cur_samples += len(traj["cur_obs"])
self.agent.encoder.sample_z()
if update_posterior:
break
return trajs, cur_samples
to
def collect_train_data(self, task_index, max_samples, update_posterior, add_to_enc_buffer):
"""Data collecting for meta-train"""
self.agent.encoder.clear_z()
self.agent.policy.is_deterministic = False
trajs, num_samples = self.sampler.obtain_samples(
max_samples=max_samples,
accum_context=False,
)
self.rl_replay_buffer.add_trajs(task_index, trajs)
if add_to_enc_buffer:
self.encoder_replay_buffer.add_trajs(task_index, trajs)
if update_posterior:
context_batch = self.sample_context([task_index])
self.agent.encoder.infer_posterior(context_batch)
def obtain_samples(self, max_samples, accum_context=True):
"""Obtain samples up to the number of maximum samples"""
trajs = []
cur_samples = 0
while cur_samples < max_samples:
traj = self.rollout(accum_context=accum_context)
trajs.append(traj)
cur_samples += len(traj["cur_obs"])
self.agent.encoder.sample_z()
return trajs, cur_samples
Then, run PEARL experiments to confirm the change.
A Killed error occurs on laptop notebooks; the likely cause is running out of memory.
Change from Python config files to YAML config files.
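As a sketch, a YAML config for one experiment might look like the fragment below (all keys are hypothetical), to be loaded with, e.g., yaml.safe_load:

```yaml
# Hypothetical experiment config (keys are illustrative only)
env_name: cheetah-dir
num_tasks: 2
seed: 0
train:
  max_samples: 2000
  batch_size: 256
  lr: 3.0e-4
```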
Hi, friends. I happened to find this great repo! It seems much more elegant than most other meta-RL implementations. I wonder if the Feature branch has been tested on environments like cheetah and can be readily used for other environments. Thanks!
torch.lstsq() is deprecated in favor of torch.linalg.lstsq() in PyTorch version >= 1.9.0.
For more stable code, we might modify the function for users with a higher version of PyTorch as follows (note the old and new APIs take their arguments in opposite order):
if hasattr(torch.linalg, "lstsq"):  # PyTorch >= 1.9.0
    coeffs = torch.linalg.lstsq(A, b).solution
else:  # Fall back to the deprecated API on older versions
    coeffs = torch.lstsq(b, A).solution
In MAML we use the higher module, a library providing support for higher-order optimization, developed by Facebook.
https://github.com/facebookresearch/higher
It turns existing torch.nn.Module instances "stateless", meaning that changes to their parameters can be tracked.
Therefore, we need to check whether the higher module really tracks the parameters of the policy through the outer loop and inner loop.
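What "tracking through the inner loop" means can be checked on a toy problem without higher at all: the outer gradient must differentiate through the inner SGD step itself (the second-order term), which is exactly what higher's stateless modules enable in torch. A pure-Python sketch with a hypothetical one-parameter model f(w, x) = w * x:

```python
# Toy illustration: the outer gradient must flow THROUGH the inner
# SGD step; dropping the inner-step derivative gives a wrong gradient.
def inner_update(w, x1, y1, lr):
    grad = 2 * (w * x1 - y1) * x1  # d/dw of the inner loss (w*x1 - y1)^2
    return w - lr * grad           # one SGD step, still a function of w

def outer_loss(w, x1, y1, x2, y2, lr):
    return (inner_update(w, x1, y1, lr) * x2 - y2) ** 2

def outer_grad(w, x1, y1, x2, y2, lr):
    w_ad = inner_update(w, x1, y1, lr)
    dwad_dw = 1.0 - lr * 2.0 * x1 * x1  # derivative of the inner step w.r.t. w
    return 2.0 * (w_ad * x2 - y2) * x2 * dwad_dw

# Finite-difference check: the analytic gradient (including the
# inner-step term) matches the numerical gradient of the outer loss.
args = (1.0, 2.0, 1.5, 1.0, 0.1)  # x1, y1, x2, y2, lr
w, eps = 0.5, 1e-6
numeric = (outer_loss(w + eps, *args) - outer_loss(w - eps, *args)) / (2 * eps)
analytic = outer_grad(w, *args)
```

An analogous check against the repo's policy would compare higher's outer gradient with a finite-difference estimate to confirm the inner-loop parameters are really being tracked.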
@all-contributors please add @dongminlee94 for infrastructure, tests and code