local function main()
  --torch.setnumthreads(10)
  --print('threads: ', torch.getnumthreads())
  g_make_deterministic(1)

  -- Build the three dataset states as globals; fp/bp/run_valid/run_test
  -- and reset_state all operate on these tables.
  state_train = {data=ptb.traindataset(params.batch_size)}
  state_valid = {data=ptb.validdataset(params.batch_size)}
  state_test = {data=ptb.testdataset(params.batch_size)}
  params.vocab_size = ptb.vocab_size()
  print('Network parameters')
  print(params)
  local states = {state_train, state_valid, state_test}
  for _, state in pairs(states) do
    reset_state(state)
  end
  setup()

  -- Resume from a saved checkpoint when one exists. io.open is used purely
  -- as an existence probe; torch.load performs the actual deserialization.
  local saved_model
  local file = io.open(params.model_path, "rb")
  if file then
    file:close()
    saved_model = torch.load(params.model_path)
    print('load from previous saved model')
  end
  model = saved_model or model
  collectgarbage()

  -- Restore training progress persisted inside the checkpoint. On a fresh
  -- model these fields are nil, so the `or` defaults apply.
  state_train.pos = model.state_train_pos or 1
  params.lr = model.lr or params.lr
  local step = model.step or 0
  local epoch = model.epoch or 0
  local total_cases = model.total_cases or 0
  local tics = model.tics or 0
  -- Shift both clocks back by the time already spent so words-per-second
  -- and the elapsed-minutes display stay continuous across restarts.
  local beginning_time = torch.tic() - tics
  local start_time = torch.tic() - tics

  print('Starting training')
  -- Number of target words consumed by one fp/bp step.
  local words_per_step = params.seq_length * params.batch_size
  local epoch_size = torch.floor(state_train.data:size(1) / params.seq_length)
  -- Loop-invariant reporting interval, hoisted out of the training loop.
  -- math.max(1, ...) avoids a modulo-by-zero (nan) when epoch_size < 5.
  -- NOTE(review): with the original `== 10` comparison below, no progress
  -- line and no checkpoint is ever written when this interval is <= 10;
  -- kept as-is to preserve behavior, but worth confirming the intent.
  local report_interval = math.max(1, torch.round(epoch_size / 10))
  local perps  -- rolling window of the last epoch_size per-step perplexities

  while epoch < params.max_max_epoch do
    local perp = fp(state_train)
    if perps == nil then
      -- First step: seed the whole window with the first perplexity so the
      -- reported mean is not dragged toward zero by unfilled slots.
      perps = torch.zeros(epoch_size):add(perp)
    end
    perps[step % epoch_size + 1] = perp
    step = step + 1
    bp(state_train)
    total_cases = total_cases + words_per_step
    epoch = step / epoch_size

    if step % report_interval == 10 then
      local wps = torch.floor(total_cases / torch.toc(start_time))
      local since_beginning = g_d(torch.toc(beginning_time)/60)
      print('epoch = ' .. g_f3(epoch) ..
        ', train perp. = ' .. g_f3(torch.exp(perps:mean()))..
        ', wps = ' .. wps ..
        ', dw:norm() = ' .. g_f3(model.norm_dw) ..
        ', lr = ' .. g_f3(params.lr) ..
        ', since beginning = '..since_beginning..' mins')
      -- Persist progress inside the model table so a restart resumes
      -- from this exact point (see the restore section above).
      model.step = step
      model.epoch = epoch
      model.total_cases = total_cases
      model.tics = torch.tic() - beginning_time
      model.state_train_pos = state_train.pos
      model.lr = params.lr
      --clear_state()
      torch.save(params.model_path, model)
    end

    -- End of a full pass over the training data: validate, then start
    -- annealing the learning rate once past max_epoch.
    if step % epoch_size == 0 then
      run_valid()
      if epoch > params.max_epoch then
        params.lr = params.lr / params.decay
      end
    end

    if step % 33 == 0 then
      collectgarbage()
    end
  end

  run_test()
  print('training is over.')
end