Hi,
RUMpy is a great framework! I successfully trained the EDSR model by following the documentation, but when I used the same dataset to train the Q-EDSR model, I got a runtime error. It seems to be a problem with the model structure.
The Q-EDSR training configuration file is based on "Documentation/sample_config_files/div2k/q-edsr.toml". Here is the full output:
```
Running epoch 0
Training cleared to run.
Training Run:
train_data <torch.utils.data.dataloader.DataLoader object at 0x7f2ab9585f10>
0%| | 0/500 [00:01<?, ?it/s]
Traceback (most recent call last):
  File "/home/ts03/anaconda3/envs/rumpy/bin/train_sisr", line 33, in <module>
    sys.exit(load_entry_point('RUMpy', 'console_scripts', 'train_sisr')())
  File "/home/ts03/anaconda3/envs/rumpy/lib/python3.7/site-packages/click/core.py", line 1130, in __call__
    return self.main(*args, **kwargs)
  File "/home/ts03/anaconda3/envs/rumpy/lib/python3.7/site-packages/click/core.py", line 1055, in main
    rv = self.invoke(ctx)
  File "/home/ts03/anaconda3/envs/rumpy/lib/python3.7/site-packages/click/core.py", line 1404, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/home/ts03/anaconda3/envs/rumpy/lib/python3.7/site-packages/click/core.py", line 760, in invoke
    return __callback(*args, **kwargs)
  File "/home/ts03/workspace/ricky/rumpy/RUMpy/rumpy/shared_framework/net_train.py", line 95, in experiment_setup
    experiment.run_experiment()  # initiates model training
  File "/home/ts03/workspace/ricky/rumpy/RUMpy/rumpy/shared_framework/training/base_handler.py", line 345, in run_experiment
    training_loss = self.train()
  File "/home/ts03/workspace/ricky/rumpy/RUMpy/rumpy/shared_framework/training/base_handler.py", line 220, in train
    losses, _ = self.model.train_batch(**batch)  # entire training scheme occurs here
  File "/home/ts03/workspace/ricky/rumpy/RUMpy/rumpy/SISR/models/interface.py", line 101, in train_batch
    return self.model.run_train(x=lr, y=hr, **kwargs)
  File "/home/ts03/workspace/ricky/rumpy/RUMpy/rumpy/SISR/models/attention_manipulators/__init__.py", line 186, in run_train
    input_data, extra_channels = self.channel_concat_logic(x, extra_channels, metadata, metadata_keys)
  File "/home/ts03/workspace/ricky/rumpy/RUMpy/rumpy/SISR/models/attention_manipulators/__init__.py", line 152, in channel_concat_logic
    extra_channels = self.generate_channels(x, metadata, metadata_keys)
  File "/home/ts03/workspace/ricky/rumpy/RUMpy/rumpy/SISR/models/attention_manipulators/__init__.py", line 103, in generate_channels
    extra_channels[index, ...] = extra_channels[index, :] * added_info
RuntimeError: The size of tensor a (10) must match the size of tensor b (0) at non-singleton dimension 0
```
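For context, PyTorch raises this error when an elementwise multiply pairs a dimension of size 10 with a dimension of size 0. Neither size is 1, so broadcasting cannot reconcile them. The sketch below reproduces that in isolation using only `torch`. One possible reading (a guess, not confirmed) is that `added_info` in `generate_channels` arrives empty, for example if the dataset does not supply the metadata that Q-EDSR builds its extra channels from. The `check_added_info` helper is hypothetical and not part of RUMpy. It only shows where an earlier, clearer check could go.

```python
import torch

# Reproduce the broadcasting failure: 10 entries vs 0 entries along dim 0.
a = torch.ones(10)
b = torch.empty(0)

try:
    a * b
    message = ""
except RuntimeError as err:
    message = str(err)
    print(message)


def check_added_info(added_info: torch.Tensor, index: int) -> None:
    # Hypothetical helper (assumption, not RUMpy's API): fail with a descriptive
    # message instead of the opaque broadcasting error when metadata is empty.
    if added_info.numel() == 0:
        raise ValueError(
            f"added_info for batch item {index} is empty; check that the "
            "dataset provides the metadata keys this model expects"
        )
```

Calling `check_added_info(added_info, index)` just before the multiplication in `generate_channels` would, under this assumption, show whether the metadata is missing rather than misshapen.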