
coordinated-multi-agent-imitation-learning's People

Contributors

samshipengs


coordinated-multi-agent-imitation-learning's Issues

Could not read index 1 twice because it was cleared after a previous read (perhaps try setting clear_after_read = false?)

In JointTraining, the nested policy with nested horizon produces the error below:

Wroking on policy 0
Horizon 0 ==========
Epoch 0 | loss: 85.31 | time took: 1.91s | validation loss: 36.48
Total time took: 0.05hrs
Horizon 2 ==========

InvalidArgumentError Traceback (most recent call last)
C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\client\session.py in _do_call(self, fn, *args)
1360 try:
-> 1361 return fn(*args)
1362 except errors.OpError as e:

C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\client\session.py in _run_fn(session, feed_dict, fetch_list, target_list, options, run_metadata)
1339 return tf_session.TF_Run(session, options, feed_dict, fetch_list,
-> 1340 target_list, status, run_metadata)
1341

C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\framework\errors_impl.py in __exit__(self, type_arg, value_arg, traceback_arg)
515 compat.as_text(c_api.TF_Message(self.status.status)),
--> 516 c_api.TF_GetCode(self.status.status))
517 # Delete the underlying status object from memory otherwise it stays alive

InvalidArgumentError: TensorArray TensorArray_3809: Could not read index 1 twice because it was cleared after a previous read (perhaps try setting clear_after_read = false?).
[[Node: rnn/while/cond/cond/TensorArrayReadV3_2 = TensorArrayReadV3[dtype=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"](rnn/while/cond/cond/TensorArrayReadV3_1/Switch, rnn/while/cond/cond/TensorArrayReadV3_1/Switch_1, rnn/while/cond/cond/TensorArrayReadV3_1/Switch_2)]]

During handling of the above exception, another exception occurred:

InvalidArgumentError Traceback (most recent call last)
in ()
1 batch_size = 32
----> 2 train_all_single_policies(batch_size, sequence_length, train_game, train_target, test_game, test_target, models_path)

C:\Users\sshi\Desktop\raptors\code\train.py in train_all_single_policies(batch_size, sequence_length, train_game, train_target, test_game, test_target, models_path)
31 for batch in iterate_minibatches(train_game, train_target, batch_size, shuffle=False):
32 train_xi, train_yi = batch
---> 33 p, l, _, train_sum = model.train(train_xi, train_yi, k)
34 model.train_writer.add_summary(train_sum, train_step)
35 epoch_loss += l/n_train_batch

C:\Users\sshi\Desktop\raptors\code\model.py in train(self, train_xi, train_yi, k)
124 def train(self, train_xi, train_yi, k):
125 return self.sess.run([self.pred, self.loss, self.opt, self.train_summary],
--> 126 feed_dict={self.X: train_xi, self.Y: train_yi, self.h: k})
127
128 def validate(self, val_xi, val_yi, k):

C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\client\session.py in run(self, fetches, feed_dict, options, run_metadata)
903 try:
904 result = self._run(None, fetches, feed_dict, options_ptr,
--> 905 run_metadata_ptr)
906 if run_metadata:
907 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\client\session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
1135 if final_fetches or final_targets or (handle and feed_dict_tensor):
1136 results = self._do_run(handle, final_targets, final_fetches,
-> 1137 feed_dict_tensor, options, run_metadata)
1138 else:
1139 results = []

C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\client\session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
1353 if handle is None:
1354 return self._do_call(_run_fn, self._session, feeds, fetches, targets,
-> 1355 options, run_metadata)
1356 else:
1357 return self._do_call(_prun_fn, self._session, handle, feeds, fetches)

C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\client\session.py in _do_call(self, fn, *args)
1372 except KeyError:
1373 pass
-> 1374 raise type(e)(node_def, op, message)
1375
1376 def _extend_graph(self):

InvalidArgumentError: TensorArray TensorArray_3809: Could not read index 1 twice because it was cleared after a previous read (perhaps try setting clear_after_read = false?).
[[Node: rnn/while/cond/cond/TensorArrayReadV3_2 = TensorArrayReadV3[dtype=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"](rnn/while/cond/cond/TensorArrayReadV3_1/Switch, rnn/while/cond/cond/TensorArrayReadV3_1/Switch_1, rnn/while/cond/cond/TensorArrayReadV3_1/Switch_2)]]

Caused by op 'rnn/while/cond/cond/TensorArrayReadV3_2', defined at:
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\runpy.py", line 193, in _run_module_as_main
"__main__", mod_spec)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\runpy.py", line 85, in _run_code
exec(code, run_globals)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\ipykernel_launcher.py", line 16, in <module>
app.launch_new_instance()
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\traitlets\config\application.py", line 658, in launch_instance
app.start()
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\ipykernel\kernelapp.py", line 477, in start
ioloop.IOLoop.instance().start()
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\zmq\eventloop\ioloop.py", line 177, in start
super(ZMQIOLoop, self).start()
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tornado\ioloop.py", line 888, in start
handler_func(fd_obj, events)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tornado\stack_context.py", line 277, in null_wrapper
return fn(*args, **kwargs)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\zmq\eventloop\zmqstream.py", line 440, in _handle_events
self._handle_recv()
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\zmq\eventloop\zmqstream.py", line 472, in _handle_recv
self._run_callback(callback, msg)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\zmq\eventloop\zmqstream.py", line 414, in _run_callback
callback(*args, **kwargs)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tornado\stack_context.py", line 277, in null_wrapper
return fn(*args, **kwargs)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\ipykernel\kernelbase.py", line 283, in dispatcher
return self.dispatch_shell(stream, msg)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\ipykernel\kernelbase.py", line 235, in dispatch_shell
handler(stream, idents, msg)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\ipykernel\kernelbase.py", line 399, in execute_request
user_expressions, allow_stdin)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\ipykernel\ipkernel.py", line 196, in do_execute
res = shell.run_cell(code, store_history=store_history, silent=silent)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\ipykernel\zmqshell.py", line 533, in run_cell
return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\IPython\core\interactiveshell.py", line 2717, in run_cell
interactivity=interactivity, compiler=compiler, result=result)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\IPython\core\interactiveshell.py", line 2827, in run_ast_nodes
if self.run_code(code, result):
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\IPython\core\interactiveshell.py", line 2881, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "", line 2, in
train_all_single_policies(batch_size, sequence_length, train_game, train_target, test_game, test_target, models_path)
File "C:\Users\sshi\Desktop\raptors\code\train.py", line 14, in train_all_single_policies
learning_rate=0.01, seq_len=sequence_length-1)
File "C:\Users\sshi\Desktop\raptors\code\model.py", line 85, in __init__
policy_number=self.policy_number)
File "C:\Users\sshi\Desktop\raptors\code\model.py", line 41, in dynamic_raw_rnn
outputs_ta, last_state, _ = tf.nn.raw_rnn(cell, loop_fn)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\ops\rnn.py", line 1154, in raw_rnn
swap_memory=swap_memory)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 3096, in while_loop
result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2874, in BuildLoop
pred, body, original_loop_vars, loop_vars, shape_invariants)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2814, in _BuildLoop
body_result = body(*packed_vars_for_body)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\ops\rnn.py", line 1115, in body
next_time, next_output, cell_state, loop_state)
File "C:\Users\sshi\Desktop\raptors\code\model.py", line 33, in loop_fn
lambda: tf.cond(tf.equal(tf.mod(time, horizon+1), tf.constant(0)),
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\util\deprecation.py", line 432, in new_func
return func(*args, **kwargs)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2027, in cond
orig_res_f, res_f = context_f.BuildCondBranch(false_fn)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 1868, in BuildCondBranch
original_result = fn()
File "C:\Users\sshi\Desktop\raptors\code\model.py", line 35, in <lambda>
lambda: tf.concat((inputs_ta.read(time)[:, :policy_number*player_fts],
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\util\deprecation.py", line 432, in new_func
return func(*args, **kwargs)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2027, in cond
orig_res_f, res_f = context_f.BuildCondBranch(false_fn)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 1868, in BuildCondBranch
original_result = fn()
File "C:\Users\sshi\Desktop\raptors\code\model.py", line 37, in <lambda>
inputs_ta.read(time)[:, policy_number*player_fts+2:]), axis=1)))
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\util\tf_should_use.py", line 58, in fn
return method(self, *args, **kwargs)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\util\tf_should_use.py", line 58, in fn
return method(self, *args, **kwargs)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\util\tf_should_use.py", line 58, in fn
return method(self, *args, **kwargs)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\ops\tensor_array_ops.py", line 861, in read
return self._implementation.read(index, name=name)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\ops\tensor_array_ops.py", line 260, in read
name=name)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\ops\gen_data_flow_ops.py", line 4970, in _tensor_array_read_v3
dtype=dtype, name=name)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 787, in _apply_op_helper
op_def=op_def)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py", line 3271, in create_op
op_def=op_def)
File "C:\Users\sshi\AppData\Local\Continuum\Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py", line 1650, in __init__
self._traceback = self._graph._extract_stack() # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): TensorArray TensorArray_3809: Could not read index 1 twice because it was cleared after a previous read (perhaps try setting clear_after_read = false?).
[[Node: rnn/while/cond/cond/TensorArrayReadV3_2 = TensorArrayReadV3[dtype=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"](rnn/while/cond/cond/TensorArrayReadV3_1/Switch, rnn/while/cond/cond/TensorArrayReadV3_1/Switch_1, rnn/while/cond/cond/TensorArrayReadV3_1/Switch_2)]]
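The root cause is that the loop_fn reads the same TensorArray index in more than one cond branch, and a TF 1.x TensorArray by default frees each element after its first read. A toy pure-Python model of that read-once semantics (not the TensorFlow API itself) illustrates the failure mode and the clear_after_read=False fix:

```python
class ReadOnceArray:
    """Toy model of TF 1.x TensorArray read semantics (illustrative only)."""

    def __init__(self, values, clear_after_read=True):
        self._values = list(values)
        self._clear_after_read = clear_after_read

    def read(self, index):
        value = self._values[index]
        if value is None:
            raise ValueError(
                f"Could not read index {index} twice because it was cleared "
                "after a previous read (perhaps try setting clear_after_read = false?)")
        if self._clear_after_read:
            self._values[index] = None  # element is freed after its first read
        return value

ta = ReadOnceArray([10, 20, 30])          # default: read-once
ta.read(1)                                # first read is fine
# ta.read(1)                              # second read would raise, as in the traceback

ta_safe = ReadOnceArray([10, 20, 30], clear_after_read=False)
assert ta_safe.read(1) == 20 and ta_safe.read(1) == 20  # re-reads allowed
```

In the model itself, the likely fixes are either constructing inputs_ta with tf.TensorArray(..., clear_after_read=False), or calling inputs_ta.read(time) once, storing the result, and reusing it in both branches of the tf.cond.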

Roll out horizon

Currently the model just uses a plain dynamic RNN; the roll-out implemented with raw_rnn is not correct and needs to be fixed.

Using roll-out with raw_rnn at horizon=0 should be equivalent to using a regular RNN, but the two produce different results:

  1. Using regular dynamic_rnn

Horizon Tensor("Placeholder_1:0", dtype=int32) ==========
Epoch 0 | loss: 231.87 | time took: 0.72s | validation loss: 158.90
Epoch 100 | loss: 7.85 | time took: 0.54s | validation loss: 10.80
Epoch 200 | loss: 6.92 | time took: 0.54s | validation loss: 10.02
Epoch 300 | loss: 6.06 | time took: 0.54s | validation loss: 8.95
Epoch 400 | loss: 5.92 | time took: 0.54s | validation loss: 8.73
Epoch 500 | loss: 5.38 | time took: 0.54s | validation loss: 8.24
Epoch 600 | loss: 5.32 | time took: 0.54s | validation loss: 8.67
Epoch 700 | loss: 5.01 | time took: 0.55s | validation loss: 8.05
Epoch 800 | loss: 5.33 | time took: 0.54s | validation loss: 9.63
Epoch 900 | loss: 4.78 | time took: 0.54s | validation loss: 7.76
Total time took: 0.15hrs

  2. raw_rnn with horizon=0

Epoch 0 | loss: 229.10 | time took: 0.83s | validation loss: 154.11
Epoch 100 | loss: 16.10 | time took: 0.67s | validation loss: 32.90
Epoch 200 | loss: 11.69 | time took: 0.67s | validation loss: 32.35
Epoch 300 | loss: 11.01 | time took: 0.67s | validation loss: 31.40
Epoch 400 | loss: 8.76 | time took: 0.67s | validation loss: 28.07
Epoch 500 | loss: 7.99 | time took: 0.67s | validation loss: 26.93
Epoch 600 | loss: 7.22 | time took: 0.67s | validation loss: 29.64
Epoch 700 | loss: 8.19 | time took: 0.67s | validation loss: 32.17
Epoch 800 | loss: 10.46 | time took: 0.67s | validation loss: 29.54
Epoch 900 | loss: 9.72 | time took: 0.67s | validation loss: 27.48
Total time took: 0.19hrs

This suggests there is something off in the raw_rnn implementation.
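For reference, the schedule the loop_fn appears to implement (judging from the tf.cond on tf.mod(time, horizon+1) in the traceback above) can be written down in plain Python, with a hypothetical step function standing in for the RNN cell. The horizon=0 case must reduce to teacher forcing, which gives a cheap unit test for the raw_rnn version:

```python
def rollout(inputs, step, horizon):
    """Feed ground truth every (horizon + 1) steps; otherwise feed back
    the previous prediction. horizon=0 reduces to teacher forcing."""
    outputs, prev_pred = [], None
    for t, x in enumerate(inputs):
        fed = x if t % (horizon + 1) == 0 else prev_pred
        prev_pred = step(fed)
        outputs.append(prev_pred)
    return outputs

double = lambda v: 2 * v  # stand-in for the RNN cell

# horizon=0: every step sees ground truth, i.e. ordinary dynamic_rnn behaviour
assert rollout([1, 2, 3], double, horizon=0) == [2, 4, 6]
# horizon=2: steps 1 and 2 see the model's own predictions instead
assert rollout([1, 2, 3], double, horizon=2) == [2, 4, 8]
```

If the raw_rnn graph matched this schedule, its horizon=0 losses should track the dynamic_rnn losses above closely; the gap (8–10 vs. 27–32 validation loss) is strong evidence the loop_fn feeds the wrong tensor on some steps.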

Multiple players share the same (conventional) role

There are times when several players on the court are all F. Initially I tried to add extra empty slots to compensate, i.e. create two more player slots per role,
(p1x,p1y, p2x,p2y, ..., p7x,p7y), with the first 3 slots reserved for F, so that when several players share the same role they each have a place to go.

However, after some visualization I realized that whenever the extra slots are unused they are filled with zeros, which is misleading to the model: the zeros carry no trajectory information, they only indicate that the lineup lacks duplicates of that role.

So for now it seems we should keep the number of roles at 5. This probably hurts learning: when several players on the court share a role, the role assignment has to map the duplicates onto other roles even when they are not very similar.
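With the roles fixed at 5, each frame's players must be matched to 5 distinct role templates even when conventional positions repeat. A brute-force sketch of that assignment (names like role_templates are illustrative, not from the repo code):

```python
from itertools import permutations

def assign_roles(players, role_templates):
    """Assign each player (x, y) to a distinct role template (x, y),
    minimizing total squared distance. Brute force: 5! = 120 permutations."""
    def cost(perm):
        return sum((px - rx) ** 2 + (py - ry) ** 2
                   for (px, py), (rx, ry) in zip(players, perm))
    best = min(permutations(role_templates), key=cost)
    return [role_templates.index(r) for r in best]  # template index per player

templates = [(0, 0), (10, 0), (5, 5)]   # toy 3-role example
players = [(9, 1), (1, 1), (4, 6)]
assert assign_roles(players, templates) == [1, 0, 2]
```

When two players share a conventional role, the nearest one claims the matching template and the other is forced onto the next-best one, which is exactly the negative effect described above. For 5 players the 5! scan is cheap; a larger problem would call for the Hungarian algorithm.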

Play by play data is not accurate

For example, in game 0021500196, event 2 has 'time_left': [705, 704, 685, 684] and 'event_str': ['miss', 'rebound', 'miss', 'rebound'].
Matching these against the shot clock and the court visualization, the time_left values in the play-by-play do not segment the events correctly.

At 685.0 the shot clock reads 21.77, so by that time the shot had already been missed for a while, and the defending team had the rebound and was already transitioning to offense. The miss should be marked right when the shot clock resets to 24.
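Such mislabeled timestamps could be flagged automatically by cross-checking each event's time_left against the shot-clock series: a jump back up toward 24 marks when the play actually ended. A rough sketch under an assumed (time_left, shot_clock) sample format:

```python
def suspicious_events(event_times, clock_series, tol=1.0):
    """Flag annotated event times that fall more than `tol` seconds after
    the nearest preceding shot-clock reset (the clocks count down, so a
    smaller time_left means later in the game)."""
    resets = [t_after for (t_b, c_b), (t_after, c_after)
              in zip(clock_series, clock_series[1:])
              if c_after > c_b + 1.0]  # shot clock jumped back up -> reset
    flagged = []
    for ev in event_times:
        prior = [t for t in resets if t > ev]  # resets before the event
        if prior and min(prior) - ev > tol:
            flagged.append(ev)
    return flagged

clock = [(690, 3.0), (688, 1.0), (687, 24.0), (685, 21.77)]
assert suspicious_events([685], clock) == [685]   # miss labeled ~2s late
assert suspicious_events([687], clock) == []      # labeled at the reset
```

This only catches events annotated late relative to a reset, but that matches the failure mode observed here.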

Add more game data

Currently only a single game's data is used; all the games need to be included.

Overfitting

2018-05-19 20:15:18,769 | INFO : Training with hyper parameters:
{'use_model': 'dynamic_rnn_layer_norm', 'batch_size': 64, 'sequence_length': 50, 'overlap': 25, 'state_size': [128, 128], 'use_peepholes': None, 'input_dim': 179, 'dropout_rate': 0.6, 'learning_rate': 0.0001, 'n_epoch': 1000}

2018-05-19 20:15:22,532 | INFO : Horizon 0 ==========
2018-05-19 20:15:58,637 | INFO : Epoch 0 | loss: 424.64 | time took: 34.02s | validation loss: 370.09
2018-05-19 20:21:29,900 | INFO : Epoch 10 | loss: 13.52 | time took: 32.61s | validation loss: 10.31
2018-05-19 20:26:59,878 | INFO : Epoch 20 | loss: 3.62 | time took: 32.94s | validation loss: 3.74
2018-05-19 20:32:30,784 | INFO : Epoch 30 | loss: 3.08 | time took: 32.63s | validation loss: 2.86
2018-05-19 20:37:59,144 | INFO : Epoch 40 | loss: 2.90 | time took: 32.77s | validation loss: 2.74
2018-05-19 20:43:26,475 | INFO : Epoch 50 | loss: 2.79 | time took: 32.46s | validation loss: 2.58
2018-05-19 20:48:53,346 | INFO : Epoch 60 | loss: 2.70 | time took: 32.50s | validation loss: 2.70
2018-05-19 20:54:20,320 | INFO : Epoch 70 | loss: 2.63 | time took: 32.46s | validation loss: 2.56
2018-05-19 20:59:47,275 | INFO : Epoch 80 | loss: 2.57 | time took: 32.45s | validation loss: 2.60
2018-05-19 21:05:14,118 | INFO : Epoch 90 | loss: 2.50 | time took: 32.50s | validation loss: 2.59
2018-05-19 21:10:43,281 | INFO : Epoch 100 | loss: 2.43 | time took: 32.98s | validation loss: 2.55
2018-05-19 21:16:13,949 | INFO : Epoch 110 | loss: 2.37 | time took: 33.29s | validation loss: 2.57
2018-05-19 21:21:45,966 | INFO : Epoch 120 | loss: 2.30 | time took: 32.93s | validation loss: 2.51
2018-05-19 21:27:17,978 | INFO : Epoch 130 | loss: 2.24 | time took: 32.95s | validation loss: 2.44
2018-05-19 21:32:49,640 | INFO : Epoch 140 | loss: 2.18 | time took: 32.94s | validation loss: 2.72
2018-05-19 21:38:21,280 | INFO : Epoch 150 | loss: 2.12 | time took: 32.89s | validation loss: 2.48
2018-05-19 21:43:51,814 | INFO : Epoch 160 | loss: 2.07 | time took: 32.71s | validation loss: 2.62
2018-05-19 21:49:22,015 | INFO : Epoch 170 | loss: 2.02 | time took: 32.80s | validation loss: 2.57
2018-05-19 21:54:51,845 | INFO : Epoch 180 | loss: 1.96 | time took: 32.79s | validation loss: 2.57
2018-05-19 22:00:21,549 | INFO : Epoch 190 | loss: 1.91 | time took: 32.69s | validation loss: 2.62
2018-05-19 22:05:51,438 | INFO : Epoch 200 | loss: 1.86 | time took: 32.75s | validation loss: 2.52
2018-05-19 22:11:21,148 | INFO : Epoch 210 | loss: 1.82 | time took: 32.73s | validation loss: 2.60
2018-05-19 22:16:50,945 | INFO : Epoch 220 | loss: 1.78 | time took: 32.74s | validation loss: 2.67
2018-05-19 22:22:21,564 | INFO : Epoch 230 | loss: 1.74 | time took: 32.80s | validation loss: 2.65
2018-05-19 22:27:51,485 | INFO : Epoch 240 | loss: 1.70 | time took: 32.78s | validation loss: 2.84
2018-05-19 22:33:21,232 | INFO : Epoch 250 | loss: 1.66 | time took: 32.75s | validation loss: 2.59
2018-05-19 22:38:51,611 | INFO : Epoch 260 | loss: 1.62 | time took: 33.05s | validation loss: 2.66
2018-05-19 22:44:21,437 | INFO : Epoch 270 | loss: 1.59 | time took: 32.74s | validation loss: 2.68
2018-05-19 22:49:51,152 | INFO : Epoch 280 | loss: 1.55 | time took: 32.75s | validation loss: 2.61
2018-05-19 22:55:20,930 | INFO : Epoch 290 | loss: 1.52 | time took: 32.71s | validation loss: 2.55
2018-05-19 23:00:50,665 | INFO : Epoch 300 | loss: 1.49 | time took: 32.75s | validation loss: 2.74
2018-05-19 23:06:20,603 | INFO : Epoch 310 | loss: 1.46 | time took: 32.75s | validation loss: 2.80
2018-05-19 23:11:50,676 | INFO : Epoch 320 | loss: 1.43 | time took: 32.75s | validation loss: 2.75
2018-05-19 23:17:20,512 | INFO : Epoch 330 | loss: 1.40 | time took: 32.76s | validation loss: 2.78
2018-05-19 23:22:50,682 | INFO : Epoch 340 | loss: 1.37 | time took: 32.76s | validation loss: 3.07
2018-05-19 23:28:20,826 | INFO : Epoch 350 | loss: 1.34 | time took: 32.75s | validation loss: 2.82
2018-05-19 23:33:50,696 | INFO : Epoch 360 | loss: 1.32 | time took: 32.74s | validation loss: 2.77
2018-05-19 23:39:20,657 | INFO : Epoch 370 | loss: 1.30 | time took: 32.94s | validation loss: 2.82
2018-05-19 23:44:50,417 | INFO : Epoch 380 | loss: 1.27 | time took: 32.84s | validation loss: 2.94
2018-05-19 23:50:20,729 | INFO : Epoch 390 | loss: 1.24 | time took: 32.97s | validation loss: 2.79
2018-05-19 23:55:55,456 | INFO : Epoch 400 | loss: 1.22 | time took: 33.06s | validation loss: 2.88
2018-05-20 00:01:30,456 | INFO : Epoch 410 | loss: 1.20 | time took: 32.58s | validation loss: 2.94
2018-05-20 00:07:07,730 | INFO : Epoch 420 | loss: 1.18 | time took: 33.55s | validation loss: 3.09
2018-05-20 00:12:44,634 | INFO : Epoch 430 | loss: 1.16 | time took: 33.37s | validation loss: 2.78
2018-05-20 00:18:14,081 | INFO : Epoch 440 | loss: 1.14 | time took: 32.67s | validation loss: 3.27
2018-05-20 00:23:47,005 | INFO : Epoch 450 | loss: 1.12 | time took: 33.36s | validation loss: 2.82
2018-05-20 00:29:12,431 | INFO : Epoch 460 | loss: 1.10 | time took: 32.09s | validation loss: 3.01
2018-05-20 00:34:35,043 | INFO : Epoch 470 | loss: 1.08 | time took: 32.03s | validation loss: 2.63
2018-05-20 00:39:58,028 | INFO : Epoch 480 | loss: 1.07 | time took: 32.10s | validation loss: 2.92
2018-05-20 00:45:20,698 | INFO : Epoch 490 | loss: 1.05 | time took: 32.07s | validation loss: 3.15
2018-05-20 00:50:43,453 | INFO : Epoch 500 | loss: 1.03 | time took: 32.09s | validation loss: 2.89
2018-05-20 00:56:06,164 | INFO : Epoch 510 | loss: 1.01 | time took: 32.08s | validation loss: 2.87
2018-05-20 01:01:28,880 | INFO : Epoch 520 | loss: 1.00 | time took: 32.04s | validation loss: 3.02
2018-05-20 01:06:52,427 | INFO : Epoch 530 | loss: 0.98 | time took: 32.94s | validation loss: 2.94
2018-05-20 01:12:15,404 | INFO : Epoch 540 | loss: 0.96 | time took: 32.09s | validation loss: 2.90
2018-05-20 01:17:38,158 | INFO : Epoch 550 | loss: 0.96 | time took: 32.09s | validation loss: 2.80
2018-05-20 01:23:01,047 | INFO : Epoch 560 | loss: 0.93 | time took: 32.13s | validation loss: 3.14
2018-05-20 01:28:23,858 | INFO : Epoch 570 | loss: 0.93 | time took: 32.08s | validation loss: 2.94
2018-05-20 01:33:46,562 | INFO : Epoch 580 | loss: 0.91 | time took: 32.07s | validation loss: 2.79
2018-05-20 01:39:09,246 | INFO : Epoch 590 | loss: 0.90 | time took: 32.07s | validation loss: 3.16
2018-05-20 01:44:31,918 | INFO : Epoch 600 | loss: 0.89 | time took: 32.08s | validation loss: 2.85
2018-05-20 01:49:54,587 | INFO : Epoch 610 | loss: 0.87 | time took: 32.09s | validation loss: 2.90
2018-05-20 01:55:17,323 | INFO : Epoch 620 | loss: 0.86 | time took: 32.03s | validation loss: 2.67
2018-05-20 02:00:40,017 | INFO : Epoch 630 | loss: 0.85 | time took: 32.07s | validation loss: 3.18
2018-05-20 02:06:02,918 | INFO : Epoch 640 | loss: 0.84 | time took: 32.09s | validation loss: 3.01
2018-05-20 02:11:25,634 | INFO : Epoch 650 | loss: 0.83 | time took: 32.05s | validation loss: 2.93
2018-05-20 02:16:48,348 | INFO : Epoch 660 | loss: 0.82 | time took: 32.06s | validation loss: 3.02
2018-05-20 02:22:11,146 | INFO : Epoch 670 | loss: 0.81 | time took: 32.07s | validation loss: 2.79
2018-05-20 02:27:33,876 | INFO : Epoch 680 | loss: 0.80 | time took: 32.12s | validation loss: 2.88
2018-05-20 02:32:56,727 | INFO : Epoch 690 | loss: 0.79 | time took: 32.13s | validation loss: 2.95
2018-05-20 02:38:19,538 | INFO : Epoch 700 | loss: 0.77 | time took: 32.07s | validation loss: 2.66
2018-05-20 02:43:42,278 | INFO : Epoch 710 | loss: 0.76 | time took: 32.05s | validation loss: 2.89
2018-05-20 02:49:05,039 | INFO : Epoch 720 | loss: 0.76 | time took: 32.11s | validation loss: 2.81
2018-05-20 02:54:27,839 | INFO : Epoch 730 | loss: 0.75 | time took: 32.08s | validation loss: 2.98
2018-05-20 02:59:50,629 | INFO : Epoch 740 | loss: 0.74 | time took: 32.08s | validation loss: 2.97
2018-05-20 03:05:13,247 | INFO : Epoch 750 | loss: 0.73 | time took: 32.07s | validation loss: 2.82
2018-05-20 03:10:35,949 | INFO : Epoch 760 | loss: 0.72 | time took: 32.04s | validation loss: 2.81
2018-05-20 03:15:58,732 | INFO : Epoch 770 | loss: 0.72 | time took: 32.07s | validation loss: 2.86
2018-05-20 03:21:21,429 | INFO : Epoch 780 | loss: 0.70 | time took: 32.07s | validation loss: 2.81
2018-05-20 03:26:44,166 | INFO : Epoch 790 | loss: 0.70 | time took: 32.07s | validation loss: 2.73
2018-05-20 03:32:06,811 | INFO : Epoch 800 | loss: 0.69 | time took: 32.07s | validation loss: 2.93
2018-05-20 03:37:29,600 | INFO : Epoch 810 | loss: 0.68 | time took: 32.13s | validation loss: 3.20
2018-05-20 03:42:52,533 | INFO : Epoch 820 | loss: 0.67 | time took: 32.09s | validation loss: 2.85
2018-05-20 03:48:15,239 | INFO : Epoch 830 | loss: 0.66 | time took: 32.05s | validation loss: 2.91
2018-05-20 03:53:38,068 | INFO : Epoch 840 | loss: 0.66 | time took: 32.11s | validation loss: 2.95
2018-05-20 03:59:00,843 | INFO : Epoch 850 | loss: 0.65 | time took: 32.06s | validation loss: 2.93
2018-05-20 04:04:23,524 | INFO : Epoch 860 | loss: 0.64 | time took: 32.05s | validation loss: 3.17
2018-05-20 04:09:46,182 | INFO : Epoch 870 | loss: 0.64 | time took: 32.06s | validation loss: 2.82
2018-05-20 04:15:08,897 | INFO : Epoch 880 | loss: 0.63 | time took: 32.09s | validation loss: 2.85
2018-05-20 04:20:31,599 | INFO : Epoch 890 | loss: 0.63 | time took: 32.07s | validation loss: 2.86
2018-05-20 04:25:54,164 | INFO : Epoch 900 | loss: 0.62 | time took: 32.08s | validation loss: 2.72
2018-05-20 04:31:16,932 | INFO : Epoch 910 | loss: 0.61 | time took: 32.08s | validation loss: 2.88
2018-05-20 04:36:39,548 | INFO : Epoch 920 | loss: 0.61 | time took: 32.05s | validation loss: 2.95
2018-05-20 04:42:02,369 | INFO : Epoch 930 | loss: 0.60 | time took: 32.03s | validation loss: 2.87
2018-05-20 04:47:25,195 | INFO : Epoch 940 | loss: 0.60 | time took: 32.07s | validation loss: 2.83
2018-05-20 04:52:47,975 | INFO : Epoch 950 | loss: 0.59 | time took: 32.13s | validation loss: 2.86
2018-05-20 04:58:10,736 | INFO : Epoch 960 | loss: 0.59 | time took: 32.06s | validation loss: 3.23
2018-05-20 05:03:33,409 | INFO : Epoch 970 | loss: 0.58 | time took: 32.09s | validation loss: 3.18
2018-05-20 05:08:56,194 | INFO : Epoch 980 | loss: 0.58 | time took: 32.10s | validation loss: 3.03
2018-05-20 05:14:18,889 | INFO : Epoch 990 | loss: 0.57 | time took: 32.08s | validation loss: 2.96
2018-05-20 05:19:07,667 | INFO : Total time took: 9.06hrs
2018-05-20 05:19:07,824 | INFO : Done saving model for policy 0
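In the run above the training loss falls monotonically (424.64 → 0.57) while the validation loss bottoms out around epoch 130 (≈2.44) and then drifts upward: a clear overfitting signature. A patience-based early-stopping helper (a sketch, none of these names come from the training code) would have cut the run short:

```python
class EarlyStopping:
    """Stop when validation loss hasn't improved for `patience` checks."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience, self.min_delta = patience, min_delta
        self.best, self.bad_checks = float("inf"), 0

    def should_stop(self, val_loss):
        if val_loss < self.best - self.min_delta:
            self.best, self.bad_checks = val_loss, 0  # new best: reset counter
        else:
            self.bad_checks += 1
        return self.bad_checks >= self.patience

stopper = EarlyStopping(patience=3)
# validation losses from the log above, epochs 0..170 (every 10 epochs)
val = [370.09, 10.31, 3.74, 2.86, 2.74, 2.58, 2.70, 2.56, 2.60,
       2.59, 2.55, 2.57, 2.51, 2.44, 2.72, 2.48, 2.62, 2.57]
stop_epoch = next(10 * i for i, v in enumerate(val) if stopper.should_stop(v))
assert stop_epoch == 160  # training could have stopped hours earlier
```

Alongside (or instead of) stopping, checkpointing the best-validation model and raising dropout_rate or shrinking state_size are the usual next steps.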

To dos

  1. Add basketball trajectory visualization with players.
  2. Randomize the train/test split.
  3. Double-check the shuffling in iterate_minibatches.
  4. Maybe try a different train/test split, e.g. reduce the test size to 1:9.
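Items 2 and 4 amount to a seeded shuffle before splitting, with a configurable test fraction (0.1 for the 1:9 split mentioned); a minimal sketch:

```python
import random

def train_test_split_indices(n, test_fraction=0.1, seed=0):
    """Shuffle sequence indices deterministically, then split off the test set."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)  # seeded so the split is reproducible
    n_test = max(1, int(n * test_fraction))
    return idx[n_test:], idx[:n_test]

train_idx, test_idx = train_test_split_indices(100, test_fraction=0.1, seed=42)
assert len(train_idx) == 90 and len(test_idx) == 10
assert not set(train_idx) & set(test_idx)  # disjoint sets
```

Splitting at the sequence (or game) level, rather than the frame level, also avoids leaking overlapping windows (overlap=25 above) between train and test.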
