etienne87 commented on June 26, 2024

Anyway, there is a first implementation here; it works fine as long as you don't have too many incomplete experiences (of length < Config.TIME_MAX).

I "solved" the issue by padding sequences in ThreadTrainer.py.

To be optimal, we would need to batch the data dynamically after the feedforward encoder (before the LSTM), so as to feed an (N, TIME_MAX, 256) tensor to tf.nn.dynamic_rnn. However, I am not convinced the padding really slows the process down, as most experience batches should be full (sequence length equal to TIME_MAX).
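A minimal NumPy sketch of the padding idea, assuming each queue item holds one agent's rollout as an array whose first axis is time (at most TIME_MAX steps). The helper name `pad_rollouts` is hypothetical, not taken from ThreadTrainer.py. Short rollouts are zero-padded to TIME_MAX and a mask records which steps are real, so the batch can be fed as an (N, TIME_MAX, ...) array together with per-row sequence lengths.

```python
import numpy as np

TIME_MAX = 5  # assumed value standing in for Config.TIME_MAX

def pad_rollouts(rollouts, time_max=TIME_MAX):
    """Zero-pad a list of (T_i, ...) arrays to (N, time_max, ...) plus a 0/1 mask.

    Hypothetical helper: every rollout must have T_i <= time_max steps.
    """
    n = len(rollouts)
    feature_shape = rollouts[0].shape[1:]
    padded = np.zeros((n, time_max) + feature_shape, dtype=rollouts[0].dtype)
    mask = np.zeros((n, time_max), dtype=np.float32)
    lengths = np.zeros(n, dtype=np.int32)
    for i, r in enumerate(rollouts):
        t = r.shape[0]
        padded[i, :t] = r   # real steps first, padding after
        mask[i, :t] = 1.0   # 1 for real steps, 0 for padding
        lengths[i] = t
    return padded, mask, lengths

# Two full rollouts and one whose episode ended after 2 steps.
full = np.ones((TIME_MAX, 3), dtype=np.float32)
short = 2 * np.ones((2, 3), dtype=np.float32)
x, mask, lengths = pad_rollouts([full, short, full])
print(x.shape, lengths)  # (3, 5, 3) [5 2 5]
```

Only the masked steps are wasted computation, which is why mostly-full batches make the padding cheap.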

I will now test on Pong and merge with the GAE branch. If anyone wants to help me understand how to improve this, you are welcome! :-)

ricky1203 commented on June 26, 2024

@etienne87
Check the _create_rnn_from_cell() function in model.py.

Note: since the hidden states are stored in the model, an agent should predict and train on a single model (GPU device) for the whole episode.

ifrosio commented on June 26, 2024

Not immediately, but it shouldn't be hard to implement it in TF.
If you have any version with LSTM, please let us know.

ieow commented on June 26, 2024

Is it possible to implement an LSTM in the GA3C architecture?
An RNN (LSTM) requires sequential input, but in GA3C multiple agents push experiences to a shared queue, so the experiences do not arrive as one serial stream. The batch input for a training thread would therefore be mixed across agents and could not be used as RNN training input.
Correct me if I am wrong.
Thanks

adi-sharma commented on June 26, 2024

This should be straightforward, as the state for Atari games is already defined as 4 frames stacked together (see section 4.1 of the original DQN paper, https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf), and that is what GA3C uses. If you supply those frames serially, the LSTM version of GA3C will work.
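For reference, a sketch of k-frame stacking as a state: keep the last k frames in a deque and stack them along a channel axis. The class name `FrameStack` and the reset behaviour (filling the stack with copies of the first frame) are assumptions for illustration, not GA3C's Environment code.

```python
from collections import deque
import numpy as np

class FrameStack:
    """Hypothetical helper: keeps the last k frames and returns them as (H, W, k)."""

    def __init__(self, k=4):
        self.k = k
        self.frames = deque(maxlen=k)  # the oldest frame is dropped automatically

    def reset(self, first_frame):
        # Assumption: start each episode with k copies of its first frame.
        self.frames.clear()
        for _ in range(self.k):
            self.frames.append(first_frame)
        return self.state()

    def step(self, frame):
        self.frames.append(frame)
        return self.state()

    def state(self):
        return np.stack(self.frames, axis=-1)

stack = FrameStack(k=4)
s = stack.reset(np.zeros((84, 84), dtype=np.float32))
for t in range(1, 6):
    s = stack.step(np.full((84, 84), t, dtype=np.float32))
print(s.shape, s[0, 0])  # (84, 84, 4) [2. 3. 4. 5.]
```

A stacked state only covers the last k frames, whereas an LSTM hidden state can carry information across a whole episode.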

mbz commented on June 26, 2024

Whether the LSTM version can be implemented without many code changes depends on how long the training sequences need to be. If the sequences are as long as TMAX frames (which I think is the case), the current architecture works, since Trainers receive sequences of TMAX frames. But if the training sequences need to be longer (i.e. multiple TMAX segments merged together), it becomes a bit more complicated.
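To make this concrete, a small simulation (plain Python and NumPy, no GA3C code) in which several agents push TMAX-step rollouts into one shared queue. Items from different agents arrive interleaved, but each item is a contiguous sequence from a single agent, so stacking whole items as rows of an (N, TMAX, ...) batch keeps every row in temporal order. The agent and queue setup is an assumption for illustration.

```python
import queue
import random
import numpy as np

TMAX = 5
NUM_AGENTS = 3

def rollout(agent_id, start_step, tmax=TMAX):
    # Each step is tagged (agent_id, step index) so ordering can be checked.
    return np.array([[agent_id, start_step + t] for t in range(tmax)])

training_q = queue.Queue()
rng = random.Random(0)
next_step = {a: 0 for a in range(NUM_AGENTS)}
for _ in range(6):
    a = rng.randrange(NUM_AGENTS)  # agents finish rollouts in arbitrary order
    training_q.put(rollout(a, next_step[a]))
    next_step[a] += TMAX

# Trainer side: stack whole items as rows instead of concatenating along time.
items = [training_q.get() for _ in range(training_q.qsize())]
batch = np.stack(items)  # shape (N, TMAX, 2)
print(batch.shape)  # (6, 5, 2)

for row in batch:
    assert len(set(row[:, 0])) == 1         # one agent per row
    assert (np.diff(row[:, 1]) == 1).all()  # consecutive steps within the row
```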

etienne87 commented on June 26, 2024

In the case of an LSTM, shouldn't the batch be organized in (N, T, C, H, W) format?

mbz commented on June 26, 2024

@etienne87 You are correct. But please look here: what the Trainer receives is in (N, T, C, H, W) format, but it merges the T dimension to get the data in (N, C, H, W) format. In a recurrent model, these concatenations are unnecessary.
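A sketch of the two layouts, assuming every item has exactly T steps: the feedforward path flattens (N, T, C, H, W) into (N*T, C, H, W), while a recurrent model needs the time axis back. With a fixed T the conversion is a plain reshape; the shapes below are arbitrary.

```python
import numpy as np

N, T, C, H, W = 4, 5, 3, 2, 2
x = np.random.rand(N, T, C, H, W).astype(np.float32)

flat = x.reshape(N * T, C, H, W)   # layout the feedforward Trainer works with
seq = flat.reshape(N, T, C, H, W)  # layout a recurrent model would consume

# Row-major reshapes keep each item's T steps contiguous, so the round trip is exact.
assert np.array_equal(seq, x)
assert np.array_equal(flat[1 * T + 2], x[1, 2])  # step 2 of item 1
print(flat.shape, seq.shape)  # (20, 3, 2, 2) (4, 5, 3, 2, 2)
```

This only holds when all items have the same length; variable-length items need padding first.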

etienne87 commented on June 26, 2024

@mbz, thanks for pointing this out. Now I am really confused by this part of the code! Can you take a look at #6? I don't see how these concatenations work at all!

I would suggest modifying ThreadTrainer.py as follows:

    if self.server.model.rnn:
        # TODO: RNN branch (keep each rollout as its own sequence)
        print('todo')
    else:
        # feedforward branch: concatenate rollouts along the first axis
        while batch_size <= Config.TRAINING_MIN_BATCH_SIZE:
            x_, r_, a_ = self.server.training_q.get()
            if batch_size == 0:
                x__ = x_; r__ = r_; a__ = a_
            else:
                x__ = np.concatenate((x__, x_))
                r__ = np.concatenate((r__, r_))
                a__ = np.concatenate((a__, a_))
            batch_size += x_.shape[0]

etienne87 commented on June 26, 2024

Wouldn't the LSTM require a reset_state function that addresses a specific row of the batch?

class NetworkVP():
    [...]
    def reset_state(self, idx):
        # zero the LSTM cell state (c) and hidden state (h) for batch row idx
        self.lstm_state_c[idx, ...] = 0
        self.lstm_state_h[idx, ...] = 0

Sorry for the pseudocode; I am not a TF expert.
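A NumPy sketch of the same idea kept on the Python side: c and h are (num_agents, units) arrays, and a reset can either target a row index, as in reset_state, or apply a per-agent done mask to the whole batch at once. The class name `LSTMStateStore` and its methods are hypothetical.

```python
import numpy as np

class LSTMStateStore:
    """Hypothetical per-agent LSTM state holder (one row per agent)."""

    def __init__(self, num_agents, units=256):
        self.c = np.zeros((num_agents, units), dtype=np.float32)
        self.h = np.zeros((num_agents, units), dtype=np.float32)

    def reset_state(self, idx):
        # idx may be an int or an array of row indices
        self.c[idx, ...] = 0
        self.h[idx, ...] = 0

    def reset_done(self, done):
        # done: (num_agents,) booleans; zero the rows of agents whose episode ended
        keep = (~np.asarray(done, dtype=bool)).astype(np.float32)[:, None]
        self.c *= keep
        self.h *= keep

store = LSTMStateStore(num_agents=3, units=4)
store.c[:] = 1.0
store.h[:] = 1.0
store.reset_state(0)
store.reset_done([False, True, False])
print(store.c[:, 0])  # [0. 0. 1.]
```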

etienne87 commented on June 26, 2024

Another point of confusion (I have little experience with TF): it seems we need two graphs, one for prediction (using dynamic_rnn) and one (maybe using a static tf.rnn?) for backprop when feeding (N, T, C, H, W) input. Or is there a way to use a gradient applier like in miyosuda's code?
Sorry if this is not the right place to ask.

mbz commented on June 26, 2024

@etienne87 I'm not sure I understand your first question about reset_state correctly. Can you please provide more details?

About having separate graphs: there are different ways of implementing the same logic in TF. We are not using separate graphs simply because it's not necessary. Can you please elaborate on why you think having two graphs is necessary?

etienne87 commented on June 26, 2024

@mbz OK! What I mean is: in classic A3C, it seems we can simply backprop at the end of a T_MAX rollout by re-using the predictions already computed. Here, on the other hand, it seems we need to recompute the predictions from the samples and actions. In short: should X be (N, H, W, C) in the predictor thread and (N, T, H, W, C) in the train function? Or have I misunderstood something about TF's internal mechanism?
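A NumPy sketch of why one set of weights (and hence one graph) can serve both shapes. It uses a hand-written basic LSTM as a stand-in, not TF's BasicLSTMCell; the weight layout and gate order (i, f, g, o) are assumptions. The same unroll function accepts (N, 1, D) input in the predictor and (N, T, D) input in the trainer, and running T single steps while carrying (c, h) gives the same outputs as one T-step call.

```python
import numpy as np

def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

def lstm_unroll(x, c, h, W, U, b):
    """Run a basic LSTM over x of shape (N, T, D), starting from (c, h) of shape (N, H)."""
    outputs = []
    for t in range(x.shape[1]):
        z = x[:, t] @ W + h @ U + b          # (N, 4H) pre-activations
        i, f, g, o = np.split(z, 4, axis=1)  # assumed gate order
        c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
        h = sigmoid(o) * np.tanh(c)
        outputs.append(h)
    return np.stack(outputs, axis=1), c, h   # (N, T, H) outputs and final state

rng = np.random.default_rng(0)
N, T, D, H = 2, 5, 3, 4
W = rng.normal(size=(D, 4 * H))
U = rng.normal(size=(H, 4 * H))
b = np.zeros(4 * H)
x = rng.normal(size=(N, T, D))
c0, h0 = np.zeros((N, H)), np.zeros((N, H))

# Trainer-style call: the whole (N, T, D) rollout at once.
out_seq, c_seq, h_seq = lstm_unroll(x, c0, h0, W, U, b)

# Predictor-style calls: T inputs of shape (N, 1, D), carrying (c, h) between calls.
c, h, steps = c0, h0, []
for t in range(T):
    out, c, h = lstm_unroll(x[:, t:t + 1], c, h, W, U, b)
    steps.append(out)

print(np.allclose(np.concatenate(steps, axis=1), out_seq))  # True
```

The trade-off is that the trainer recomputes the T forward steps, starting from the (c, h) recorded at the beginning of the rollout.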

Also, about reset: at the beginning of each episode you probably want to reset the LSTM's c and h to zero. So, as @ppwwyyxx suggests, should the LSTM state be saved inside each ProcessAgent?

ppwwyyxx commented on June 26, 2024

@mbz I have implemented A3C-LSTM with long sequence lengths. You don't have to send the whole sequence into the graph. What I did was maintain the current LSTM hidden state of every game simulator on the Python side, and at each step feed the new frame together with that simulator's hidden state to the graph.
This way the sequence length can be as long as a whole episode.
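A minimal sketch of that bookkeeping, with the network replaced by a toy step function: the predictor keeps a dict from simulator id to (c, h), gathers the states of the simulators in the current batch, runs one step, and writes the new states back. The names `step_fn` and `BatchedPredictor` and the toy recurrence are hypothetical placeholders for the real graph call.

```python
import numpy as np

UNITS = 4

def step_fn(frames, c, h):
    # Toy stand-in for one sess.run of the LSTM graph: returns new (c, h).
    c_new = c + frames.sum(axis=1, keepdims=True)
    h_new = np.tanh(c_new)
    return c_new, h_new

class BatchedPredictor:
    def __init__(self):
        self.states = {}  # simulator id -> (c, h), each of shape (UNITS,)

    def predict(self, sim_ids, frames):
        zero = np.zeros(UNITS)
        c = np.stack([self.states.get(i, (zero, zero))[0] for i in sim_ids])
        h = np.stack([self.states.get(i, (zero, zero))[1] for i in sim_ids])
        c_new, h_new = step_fn(frames, c, h)
        for k, i in enumerate(sim_ids):
            self.states[i] = (c_new[k], h_new[k])
        return h_new

    def end_episode(self, sim_id):
        self.states.pop(sim_id, None)  # the next call starts from a zero state

pred = BatchedPredictor()
pred.predict([0, 1], np.ones((2, 3)))  # both simulators step once
pred.predict([1], np.ones((1, 3)))     # only simulator 1 steps again
pred.end_episode(0)
print(pred.states[1][0][0], 0 in pred.states)  # 6.0 False
```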

markovyao commented on June 26, 2024

@ppwwyyxx I have built an LSTM and stored the hidden states.
However, I have two questions: 1. Where should the stored states be reset or initialized before each episode? 2. How should the Experience class be handled in LSTM training?

ppwwyyxx commented on June 26, 2024

@markovyao
The states were maintained in Python, inside each agent (simulator), so you can easily set them when needed (e.g. right after the agent reaches the end of an episode).
Since each agent maintains its own hidden states, it can do the following on its own:

  1. keep the hidden states in its own experience history buffer and give them to the network for training
  2. send the hidden states to the predictor to get the next action
  3. request the predictor to send back the next hidden state, and keep it
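One way an agent could apply points 1 to 3, sketched with hypothetical names (`Agent`, `predict`, `TMAX`) and a dummy predictor: the agent records the (c, h) it held at the start of each TMAX-step segment, ships that state with the segment for training, and zeros its state when an episode ends.

```python
import numpy as np

TMAX = 3
UNITS = 2

def predict(frame, c, h):
    # Dummy predictor: returns an action and the next hidden state (c, h).
    c_next = c + frame
    return 0, c_next, np.tanh(c_next)

class Agent:
    def __init__(self):
        self.c = np.zeros(UNITS)
        self.h = np.zeros(UNITS)
        self.segment = []
        self.segment_state = (self.c.copy(), self.h.copy())
        self.training_items = []  # would be put on the training queue

    def step(self, frame, done):
        if not self.segment:
            # 1. remember the state at the start of this segment
            self.segment_state = (self.c.copy(), self.h.copy())
        # 2./3. send the state to the predictor and keep the returned state
        action, self.c, self.h = predict(frame, self.c, self.h)
        self.segment.append((frame, action))
        if done or len(self.segment) == TMAX:
            self.training_items.append((self.segment, self.segment_state))
            self.segment = []
        if done:
            self.c = np.zeros(UNITS)  # reset at the end of the episode
            self.h = np.zeros(UNITS)
        return action

agent = Agent()
for t in range(5):
    agent.step(np.ones(UNITS), done=(t == 4))
lengths = [len(seg) for seg, _ in agent.training_items]
start_c = [float(state[0][0]) for _, state in agent.training_items]
print(lengths, start_c)  # [3, 2] [0.0, 3.0]
```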

ricky1203 commented on June 26, 2024

@markovyao
An alternative implementation to @ppwwyyxx's solution:

  1. Create matrix variables to store the LSTM hidden states.
  2. Assign every agent a unique agent_index.
  3. Use tf.gather to select the hidden states from the matrix and pass them to tf.nn.dynamic_rnn:
    init_state = tf.gather(_init_state, self._input_agent_indexs)
  4. Use tf.scatter_update to reset/update the LSTM hidden states according to agent_index/last_is_over, for example:
        if tc.is_training:
            # training: write back the final states, zeroed for agents whose episode is over
            need_reset_states = tf.reshape(tf.ones_like(self._input_is_over) - self._input_is_over, (-1, 1))
            op_updates = [tf.scatter_update(initial_rnn_states[idx], self._input_agent_indexs,
                                            rnn_output_states_array[idx] * tf.cast(need_reset_states, rnn_output_states_array[idx].dtype))
                          for idx in range(len(rnn_output_states_array))]
        else:
            # in predict mode, is_over refers to the last state
            batch_size = tf.shape(self._input_agent_indexs)[0]
            op_resets = []
            op_updates = []
            for idx in range(len(initial_rnn_states)):
                shape_states = tf.shape(initial_rnn_states[idx])
                # reset op: zero the rows of the agents in this batch
                op = tf.scatter_update(initial_rnn_states[idx], self._input_agent_indexs,
                                       tf.zeros((batch_size, shape_states[1]), dtype=initial_rnn_states[idx].dtype))
                op_resets.append(op)
                # update op: store the new states of the agents in this batch
                op = tf.scatter_update(initial_rnn_states[idx], self._input_agent_indexs, rnn_output_states_array[idx])
                op_updates.append(op)
  5. In predict/train, run the update/reset ops when needed.

This implementation is useful if you have many LSTM networks or modify the model frequently during development, since it only exposes the update/reset ops outside the model.
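To illustrate the indexing in steps 3 and 4 without TF, a NumPy analogue: fancy indexing plays the role of tf.gather, and index assignment the role of tf.scatter_update on the state matrix. The matrix sizes and the toy state update are assumptions, and the sketch assumes each agent_index appears at most once per batch; in ricky1203's version these are TF variables and ops.

```python
import numpy as np

NUM_AGENTS, UNITS = 4, 3
state_matrix = np.zeros((NUM_AGENTS, UNITS), dtype=np.float32)  # one row per agent_index

def run_batch(agent_indexs, is_over):
    # Step 3: gather the current states of the agents in this batch (like tf.gather).
    init_state = state_matrix[agent_indexs]
    new_state = init_state + 1.0  # toy stand-in for the RNN update
    # Step 4: write back, zeroing rows whose episode is over (like tf.scatter_update).
    keep = (1.0 - np.asarray(is_over, dtype=np.float32))[:, None]
    state_matrix[agent_indexs] = new_state * keep
    return new_state

run_batch(np.array([0, 2]), is_over=[0, 0])
run_batch(np.array([2, 3]), is_over=[0, 1])
print(state_matrix[:, 0])  # [1. 0. 2. 0.]
```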

etienne87 commented on June 26, 2024

@ricky1203, could you perhaps provide an example or a link in context?

Golly commented on June 26, 2024

@etienne87
Have you had any success developing the LSTM version?

etienne87 commented on June 26, 2024

@Golly Not much, to be honest. I also think I first need to test the idea referred to in #16; otherwise the LSTM version will need to recompute TMAX steps before each backward pass and update.

etienne87 commented on June 26, 2024

Coming back to this problem with a slightly better understanding of variable-length RNNs:
I think the easiest way to code the LSTM version is to keep track of the c and h states in the experience queues.

In ThreadTrainer::run:

while not self.exit_flag:
    batch_size = 0
    ids = []
    lengths = []
    while batch_size <= Config.TRAINING_MIN_BATCH_SIZE:
        # each item: agent id, one rollout (x, r, a) and the LSTM state (c, h) at its start
        idx, x_, r_, a_, c_, h_ = self.server.training_q.get()
        lengths.append(x_.shape[0])
        if batch_size == 0:
            x__ = x_; r__ = r_; a__ = a_; c__ = c_; h__ = h_
        else:
            x__ = np.concatenate((x__, x_))
            r__ = np.concatenate((r__, r_))
            a__ = np.concatenate((a__, a_))
            c__ = np.concatenate((c__, c_))
            h__ = np.concatenate((h__, h_))
        ids.append(idx)
        batch_size += x_.shape[0]

    if Config.TRAIN_MODELS:
        self.server.train_model(x__, r__, a__, c__, h__, lengths)

In NetworkVP::_create_graph:

# assumes TF 1.x: import tensorflow as tf; from tensorflow.contrib import rnn
self.d1 = ...  # output of the feedforward encoder
self.lstm = rnn.BasicLSTMCell(256, state_is_tuple=True)
# sequence lengths given by ThreadTrainer; for prediction, use np.ones(batch_predict_size)
self.step_sizes = tf.placeholder(tf.int32, [None], name='stepsize')
batch_size = tf.shape(self.step_sizes)[0]
d1 = tf.reshape(self.d1, [batch_size, -1, 256])  # only valid if all sequences have the same length
self.c0 = tf.placeholder(tf.float32, [None, 256])
self.h0 = tf.placeholder(tf.float32, [None, 256])
self.initial_lstm_state = rnn.LSTMStateTuple(self.c0, self.h0)
lstm_outputs, self.lstm_state = tf.nn.dynamic_rnn(self.lstm, d1,
                                                  initial_state=self.initial_lstm_state,
                                                  sequence_length=self.step_sizes,
                                                  time_major=False)
self._state = tf.reshape(lstm_outputs, [-1, 256])  # pass this vector to the pi and v heads

In NetworkVP::predict_p_and_v:

step_sizes = np.ones((c.shape[0],), dtype=np.int32)  # a single time step per agent
feed_dict = self.__get_base_feed_dict()
feed_dict.update({self.x: x, self.step_sizes: step_sizes, self.c0: c, self.h0: h})
p, v, rnn_state = self.sess.run([self.softmax_p, self.logits_v, self.lstm_state], feed_dict=feed_dict)
return p, v, rnn_state.c, rnn_state.h  # new (c, h) to store for the next step

In NetworkVP::train:

step_sizes = np.array(lengths)
r = np.reshape(r, (r.shape[0],))  # flatten rewards to shape (batch,) before feeding
feed_dict = self.__get_base_feed_dict()
feed_dict.update({self.x: x, self.y_r: r, self.action_index: a,
                  self.step_sizes: step_sizes, self.c0: c, self.h0: h})
self.sess.run(self.train_op, feed_dict=feed_dict)

I think the only thing I am missing is how to "unpack" the sequence of encoded states in the _create_graph method:

d1 = tf.reshape(self.d1, [batch_size, -1, 256]) will not work when the sequence lengths vary. Does anybody know TF well enough to tell me how to use `step_sizes` to create a list of (nstep, 256) tensors?
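One way around the reshape problem, sketched in NumPy rather than TF: instead of splitting the flat (sum(lengths), D) array into a list, scatter it into a zero-padded (B, max_len, D) array using a boolean length mask, and gather it back with the same mask afterwards. This is a padding approach; the function names `pack_to_padded` and `unpad` are hypothetical.

```python
import numpy as np

def pack_to_padded(flat, lengths, max_len=None):
    """Scatter a flat (sum(lengths), D) array into (B, max_len, D) plus a mask."""
    lengths = np.asarray(lengths)
    max_len = int(lengths.max()) if max_len is None else max_len
    mask = np.arange(max_len)[None, :] < lengths[:, None]  # (B, max_len) booleans
    padded = np.zeros((len(lengths), max_len) + flat.shape[1:], dtype=flat.dtype)
    padded[mask] = flat  # fills row by row, in time order
    return padded, mask

def unpad(padded, mask):
    """Inverse of pack_to_padded: drop the padding steps again."""
    return padded[mask]

lengths = [5, 2, 3]
flat = np.arange(sum(lengths) * 2, dtype=np.float32).reshape(-1, 2)  # D = 2 for brevity
padded, mask = pack_to_padded(flat, lengths, max_len=5)
print(padded.shape, padded[1, :, 0])  # (3, 5, 2) [10. 12.  0.  0.  0.]
assert np.array_equal(unpad(padded, mask), flat)
```

The padded array and the lengths could then go to dynamic_rnn via sequence_length, and the same mask can be reused to drop padded outputs or mask the loss.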

etienne87 commented on June 26, 2024

Hmm, actually there was still an error in my code: I forgot to mask the loss for the padding inputs!

I propose a first fix here.

It now seems to work better (at least on CartPole-v0).
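A sketch of what masking the loss means here, assuming a padded (N, TIME_MAX) batch with a 0/1 mask built from the sequence lengths: per-step losses on padding steps are multiplied by zero, and the sum is divided by the number of real steps rather than N * TIME_MAX. The value-loss form and names are illustrative, not the exact fix linked above; the same mask would apply to the policy and entropy terms.

```python
import numpy as np

def masked_value_loss(v_pred, returns, mask):
    """Squared-error value loss averaged over real (unpadded) steps only."""
    per_step = 0.5 * (returns - v_pred) ** 2  # (N, T)
    return (per_step * mask).sum() / mask.sum()

returns = np.array([[1.0, 1.0, 1.0], [2.0, 0.0, 0.0]])
v_pred = np.array([[0.0, 0.0, 0.0], [0.0, 5.0, 5.0]])  # garbage on padding steps
mask = np.array([[1.0, 1.0, 1.0], [1.0, 0.0, 0.0]])    # second row has length 1

print(masked_value_loss(v_pred, returns, mask))  # 0.875
```

Without the mask, the two padding steps would contribute 12.5 each and dominate the loss.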

In Config.py:

    TIME_MAX = 5
    STACKED_FRAMES = 4
    IMAGE_WIDTH = 1
    IMAGE_HEIGHT = 4
    EPISODES = 4000
    ANNEALING_EPISODE_COUNT = 4000
    BETA_START = 0.01
    BETA_END = 0.01
    LEARNING_RATE_START = 0.0003
    LEARNING_RATE_END = 0.0003
    RMSPROP_DECAY = 0.99
    RMSPROP_MOMENTUM = 0.0
    RMSPROP_EPSILON = 0.1
    DUAL_RMSPROP = False
    USE_GRAD_CLIP = False
    GRAD_CLIP_NORM = 40.0 
    LOG_EPSILON = 1e-6
    TRAINING_MIN_BATCH_SIZE = 16
    USE_RNN = True
    NCELLS = 256
    MIN_POLICY = 0.0
    USE_LOG_SOFTMAX = True

[Figure: ga3c_lstm_vs_ff]

wgeul commented on June 26, 2024

TIME_MAX

Out of interest, can I ask why you removed this page? What were your findings regarding the performance of the LSTM addition?

Edit: Found your model here: https://github.com/etienne87/GA3C , thanks!
