[docs]class Engine(object):
def __init__(self):
self.hooks = {}
[docs] def hook(self, name, state):
r"""Registers a backward hook.
The hook will be called every time a gradient with respect to the
Tensor is computed. The hook should have the following signature::
hook (grad) -> Tensor or None
The hook should not modify its argument, but it can optionally return
a new gradient which will be used in place of :attr:`grad`.
This function returns a handle with a method ``handle.remove()``
that removes the hook from the module.
Example:
>>> v = torch.tensor([0., 0., 0.], requires_grad=True)
>>> h = v.register_hook(lambda grad: grad * 2) # double the gradient
>>> v.backward(torch.tensor([1., 2., 3.]))
>>> v.grad
2
4
6
[torch.FloatTensor of size (3,)]
>>> h.remove() # removes the hook
"""
if name in self.hooks:
self.hooks[name](state)
[docs] def train(self, network, iterator, maxepoch, optimizer):
state = {
'network': network,
'iterator': iterator,
'maxepoch': maxepoch,
'optimizer': optimizer,
'epoch': 0,
't': 0,
'train': True,
}
self.hook('on_start', state)
while state['epoch'] < state['maxepoch']:
self.hook('on_start_epoch', state)
for sample in state['iterator']:
state['sample'] = sample
self.hook('on_sample', state)
def closure():
loss, output = state['network'](state['sample'])
state['output'] = output
state['loss'] = loss
loss.backward()
self.hook('on_forward', state)
# to free memory in save_for_backward
state['output'] = None
state['loss'] = None
return loss
state['optimizer'].zero_grad()
state['optimizer'].step(closure)
self.hook('on_update', state)
state['t'] += 1
state['epoch'] += 1
self.hook('on_end_epoch', state)
self.hook('on_end', state)
return state
[docs] def test(self, network, iterator):
state = {
'network': network,
'iterator': iterator,
't': 0,
'train': False,
}
self.hook('on_start', state)
for sample in state['iterator']:
state['sample'] = sample
self.hook('on_sample', state)
def closure():
loss, output = state['network'](state['sample'])
state['output'] = output
state['loss'] = loss
self.hook('on_forward', state)
# to free memory in save_for_backward
state['output'] = None
state['loss'] = None
closure()
state['t'] += 1
self.hook('on_end', state)
return state