Source code for torchnet.engine.engine

class Engine(object):
    def __init__(self):
        self.hooks = {}
    def hook(self, name, state):
        r"""Fires the hook registered under ``name``, if any, passing it ``state``.

        Hooks are plain callables stored in ``self.hooks``; names with no
        registered callable are silently ignored. :meth:`train` fires
        ``'on_start'``, ``'on_start_epoch'``, ``'on_sample'``, ``'on_forward'``,
        ``'on_update'``, ``'on_end_epoch'`` and ``'on_end'``; :meth:`test` fires
        ``'on_start'``, ``'on_sample'``, ``'on_forward'`` and ``'on_end'``.
        """
        if name in self.hooks:
            self.hooks[name](state)
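There is no registration helper; a hook is attached by assigning a callable to the ``hooks`` dict under one of the names above. A minimal sketch (the ``engine`` instance and the printing body are illustrative, not part of this module):

engine = Engine()

def print_loss(state):
    # 'on_forward' fires while state['loss'] is still populated;
    # the engine clears it immediately afterwards to free memory.
    print('t=%d loss=%.4f' % (state['t'], state['loss'].item()))

engine.hooks['on_forward'] = print_loss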
    def train(self, network, iterator, maxepoch, optimizer):
        state = {
            'network': network,
            'iterator': iterator,
            'maxepoch': maxepoch,
            'optimizer': optimizer,
            'epoch': 0,
            't': 0,
            'train': True,
        }

        self.hook('on_start', state)
        while state['epoch'] < state['maxepoch']:
            self.hook('on_start_epoch', state)
            for sample in state['iterator']:
                state['sample'] = sample
                self.hook('on_sample', state)

                def closure():
                    # network maps one sample to a (loss, output) pair
                    loss, output = state['network'](state['sample'])
                    state['output'] = output
                    state['loss'] = loss
                    loss.backward()
                    self.hook('on_forward', state)
                    # to free memory in save_for_backward
                    state['output'] = None
                    state['loss'] = None
                    return loss

                state['optimizer'].zero_grad()
                state['optimizer'].step(closure)
                self.hook('on_update', state)
                state['t'] += 1
            state['epoch'] += 1
            self.hook('on_end_epoch', state)
        self.hook('on_end', state)
        return state
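``train`` treats ``network`` as a callable that maps one sample to a ``(loss, output)`` pair and drives the optimizer through its closure-based ``step``. A sketch of a compatible callable, assuming ``(input, target)`` samples and a hypothetical model/criterion pair (none of these names come from this module):

import torch.nn as nn

model = nn.Linear(10, 2)           # hypothetical model
criterion = nn.CrossEntropyLoss()  # hypothetical loss

def network(sample):
    # The engine passes each item of `iterator` through unchanged,
    # so the sample layout is whatever the iterator yields.
    inputs, targets = sample
    outputs = model(inputs)
    return criterion(outputs, targets), outputs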
    def test(self, network, iterator):
        state = {
            'network': network,
            'iterator': iterator,
            't': 0,
            'train': False,
        }

        self.hook('on_start', state)
        for sample in state['iterator']:
            state['sample'] = sample
            self.hook('on_sample', state)

            def closure():
                # forward pass only: no backward and no optimizer step
                loss, output = state['network'](state['sample'])
                state['output'] = output
                state['loss'] = loss
                self.hook('on_forward', state)
                # to free memory in save_for_backward
                state['output'] = None
                state['loss'] = None

            closure()
            state['t'] += 1
        self.hook('on_end', state)
        return state
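Continuing the sketches above, a complete run could look like the following (the random data and SGD settings are illustrative stand-ins for a real dataset and tuned optimizer):

import torch
import torch.optim as optim

# ten random (input, target) batches standing in for a real iterator
data = [(torch.randn(4, 10), torch.randint(0, 2, (4,))) for _ in range(10)]

engine.hooks['on_end_epoch'] = lambda state: print('epoch %d done' % state['epoch'])

optimizer = optim.SGD(model.parameters(), lr=0.1)
engine.train(network, data, maxepoch=5, optimizer=optimizer)
engine.test(network, data)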