How to train neural network over days?






I need to train a CNN that will take 1-2 days to train on a remotely accessed GPU server.



Will I simply need to leave my laptop on overnight for the training to complete, or is there a way to save the state of the training and resume from there the next day?



(Implementation in PyTorch)




2 Answers



If you need to keep training the model that you are about to save, you need to save more than just the model's weights. You also need to save the state of the optimizer, the current epoch, the score, etc. You would do it like this:


import torch

state = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    # ... plus anything else you need to resume, e.g. the best score so far
}

torch.save(state, filepath)
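A minimal, self-contained sketch of the save step, assuming a toy nn.Linear model, SGD with momentum (so the optimizer actually has per-parameter state worth saving) and a hypothetical helper save_checkpoint. Writing to a temporary file and then renaming it with os.replace is an extra precaution, not something the answer requires: if the job dies mid-save, the previous checkpoint is left intact.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_checkpoint(path, model, optimizer, epoch, best_score):
    """Hypothetical helper: save everything needed to resume training."""
    state = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'best_score': best_score,
    }
    # Write to a temporary file first, then rename it over the target, so an
    # interrupted save never leaves a half-written checkpoint at `path`.
    tmp_path = path + '.tmp'
    torch.save(state, tmp_path)
    os.replace(tmp_path, path)


# Toy model and optimizer, purely for illustration.
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# One training step so the optimizer has momentum buffers to save.
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

ckpt_path = os.path.join(tempfile.mkdtemp(), 'checkpoint.pt')
save_checkpoint(ckpt_path, model, optimizer, epoch=0, best_score=loss.item())
```

The momentum buffers are exactly the kind of state that is lost if you save only model.state_dict(): without them the optimizer restarts from zero momentum after resuming.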



To resume training, load the checkpoint with torch.load(filepath) and then restore the state of each individual object, something like this:


state = torch.load(filepath)

model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])
start_epoch = state['epoch'] + 1  # continue from the epoch after the saved one
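Putting saving and loading together, a training loop can check for an existing checkpoint at startup and continue from where it stopped. This is a sketch under stated assumptions: make_model_and_optimizer and train_one_epoch are hypothetical stand-ins (a toy model and a single random-data step), and the checkpoint is rewritten after every epoch. If the checkpoint was written on a GPU and you load it elsewhere, torch.load also accepts a map_location argument.

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_model_and_optimizer():
    # Hypothetical toy setup; in practice rebuild exactly the model you trained.
    model = nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    return model, optimizer


def train_one_epoch(model, optimizer):
    # Hypothetical stand-in for a real epoch: one step on random data.
    optimizer.zero_grad()
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()


def train(ckpt_path, num_epochs):
    model, optimizer = make_model_and_optimizer()
    start_epoch = 0
    if os.path.exists(ckpt_path):
        # Resume: restore weights, optimizer state and the epoch counter.
        state = torch.load(ckpt_path)
        model.load_state_dict(state['state_dict'])
        optimizer.load_state_dict(state['optimizer'])
        start_epoch = state['epoch'] + 1
    model.train()
    for epoch in range(start_epoch, num_epochs):
        train_one_epoch(model, optimizer)
        # Checkpoint after every epoch so at most one epoch of work is lost.
        torch.save({'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()}, ckpt_path)
    return start_epoch


ckpt_path = os.path.join(tempfile.mkdtemp(), 'checkpoint.pt')
first_start = train(ckpt_path, num_epochs=3)   # "day 1": runs epochs 0-2
second_start = train(ckpt_path, num_epochs=5)  # "day 2": resumes at epoch 3
```

Because the loop always starts by looking for the checkpoint, the same command works both for the first launch and for every restart.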



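Since you are resuming training, DO NOT call model.eval() after restoring the states. eval() switches layers such as dropout and batch normalization to their inference behavior; call model.train() instead so they keep behaving as they did during training.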
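To see why the mode matters, here is a small sketch using a toy model with a dropout layer (the layer sizes and dropout rate are arbitrary illustrative choices). In eval mode dropout is disabled, so the same input always gives the same output; in train mode dropout randomly zeroes activations on every forward pass, which is the behavior training expects.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Toy model: a linear layer followed by dropout.
model = nn.Sequential(nn.Linear(64, 64), nn.Dropout(p=0.5))
x = torch.randn(16, 64)

with torch.no_grad():
    model.eval()
    # Dropout is off: two forward passes produce identical outputs.
    eval_same = torch.equal(model(x), model(x))

    model.train()
    # Dropout is on: each pass samples a new random mask.
    train_same = torch.equal(model(x), model(x))
```

If you forget to switch back with model.train() after loading, the resumed run silently trains without dropout and with frozen batch-norm statistics.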



To read more about this or see actual examples: https://www.programcreek.com/python/example/101175/torch.save



I assume you ssh into your remote server. When you start training by running your script, say $ python train.py, simply prepend nohup:




$ nohup python train.py



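This makes your process ignore the hangup signal (SIGHUP) that is sent when you exit the ssh session, so training keeps running after you log out and shut down your laptop. Output that would normally go to the terminal is appended to a file named nohup.out instead. You can also add a trailing & ($ nohup python train.py &) to run the job in the background and get your shell back immediately.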
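Once the job is detached, you will check on it by reconnecting and reading its output. A minimal sketch, assuming you want progress in a dedicated log file (the helper make_logger, the logger name 'train' and the log path are hypothetical): the standard logging module's FileHandler flushes after every record, so running tail -f on the file shows progress right away, whereas plain print() output redirected to nohup.out is block-buffered by Python and can show up late.

```python
import logging
import os
import tempfile


def make_logger(log_path):
    """Hypothetical helper: send training progress to a file."""
    logger = logging.getLogger('train')
    logger.setLevel(logging.INFO)
    logger.propagate = False  # don't also print via any root handlers
    handler = logging.FileHandler(log_path)  # flushes after each record
    handler.setFormatter(logging.Formatter('%(asctime)s %(message)s'))
    logger.addHandler(handler)
    return logger


log_path = os.path.join(tempfile.mkdtemp(), 'train.log')
logger = make_logger(log_path)
for epoch in range(3):
    # ... one epoch of training would go here ...
    logger.info('finished epoch %d', epoch)
```

If you prefer to keep using print(), running the script as $ nohup python -u train.py makes Python's stdout unbuffered, so nohup.out is updated as training progresses.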





