kopia lustrzana https://github.com/lzzcd001/MeshDiffusion
98 wiersze
3.3 KiB
Python
98 wiersze
3.3 KiB
Python
|
# Modified from https://raw.githubusercontent.com/fadel/pytorch_ema/master/torch_ema/ema.py
|
||
|
|
||
|
from __future__ import division
|
||
|
from __future__ import unicode_literals
|
||
|
|
||
|
import torch
|
||
|
|
||
|
|
||
|
# Partially based on: https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/moving_averages.py
|
||
|
class ExponentialMovingAverage:
    """
    Maintains (exponential) moving average of a set of parameters.

    Only parameters with ``requires_grad=True`` are tracked. The same filter
    is applied in ``update`` and ``copy_to``, so the iterables passed to those
    methods should yield the parameters in the same order as the iterable
    given to the constructor.
    """

    def __init__(self, parameters, decay, use_num_updates=True):
        """
        Args:
          parameters: Iterable of `torch.nn.Parameter`; usually the result of
            `model.parameters()`.
          decay: The exponential decay, in [0, 1].
          use_num_updates: Whether to use number of updates when computing
            averages. When True, the effective decay after n updates is
            min(decay, (1 + n) / (10 + n)).

        Raises:
          ValueError: if `decay` is not in [0, 1] (NaN is rejected too).
        """
        self._check_decay(decay)
        self.decay = decay
        # None disables the update-count-dependent decay cap in `update`.
        self.num_updates = 0 if use_num_updates else None
        # Detached copies of the trainable parameters hold the running averages.
        self.shadow_params = [p.clone().detach()
                              for p in parameters if p.requires_grad]
        # Filled by `store`, consumed by `restore`.
        self.collected_params = []

    @staticmethod
    def _check_decay(decay):
        """Raise ValueError unless 0 <= decay <= 1."""
        # Phrased as a positive range test: every comparison with NaN is
        # False, so NaN is rejected here rather than slipping through.
        if not 0.0 <= decay <= 1.0:
            raise ValueError('Decay must be between 0 and 1')

    def update(self, parameters):
        """
        Update currently maintained parameters.

        Call this every time the parameters are updated, such as the result of
        the `optimizer.step()` call.

        Args:
          parameters: Iterable of `torch.nn.Parameter`; usually the same set of
            parameters used to initialize this object.
        """
        decay = self.decay
        if self.num_updates is not None:
            self.num_updates += 1
            # The effective decay starts small and rises towards self.decay
            # as the update count grows.
            decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates))
        one_minus_decay = 1.0 - decay
        with torch.no_grad():
            parameters = [p for p in parameters if p.requires_grad]
            # NOTE: zip stops at the shorter list; a length mismatch between
            # shadow and live parameters is not detected.
            for s_param, param in zip(self.shadow_params, parameters):
                # s <- s - (1 - decay) * (s - p)  ==  decay * s + (1 - decay) * p
                s_param.sub_(one_minus_decay * (s_param - param))

    def copy_to(self, parameters):
        """
        Copy current parameters into given collection of parameters.

        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            updated with the stored moving averages. Parameters without
            `requires_grad` are skipped, mirroring the constructor's filter.
        """
        parameters = [p for p in parameters if p.requires_grad]
        for s_param, param in zip(self.shadow_params, parameters):
            param.data.copy_(s_param.data)

    def store(self, parameters):
        """
        Save the current parameters for restoring later.

        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            temporarily stored. All parameters are stored, including those
            that do not require grad.
        """
        # detach() first so the snapshots are plain tensors that do not
        # require grad or record a clone node in the autograd graph.
        self.collected_params = [param.detach().clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.

        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            updated with the stored parameters (same order as passed to
            `store`).
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)

    def state_dict(self):
        """Return the decay, update count and shadow parameters as a dict."""
        return dict(decay=self.decay, num_updates=self.num_updates,
                    shadow_params=self.shadow_params)

    def load_state_dict(self, state_dict):
        """
        Load state previously produced by `state_dict`.

        Raises:
          ValueError: if the stored decay is not in [0, 1].
        """
        # Apply the same validation as the constructor.
        self._check_decay(state_dict['decay'])
        self.decay = state_dict['decay']
        self.num_updates = state_dict['num_updates']
        self.shadow_params = state_dict['shadow_params']
|