from typing import Union

import torch
import torch.nn as nn

# NOTE: get_multiplication_matrices, phm_init, glorot_normal, glorot_uniform
# and matvec_product are helper functions defined elsewhere in this repository;
# they are assumed to be imported here (their definitions are not shown).


class PHMLinear(torch.nn.Module):
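    """Linear layer based on parameterized hypercomplex multiplication (PHM).

    Instead of one dense weight matrix, the layer holds phm_dim weight blocks
    W_i of shape (out_features, in_features) and phm_dim "contribution"
    matrices A_i of shape (phm_dim, phm_dim), the phm_rule. Under the PHM
    formulation the effective weight is the Kronecker-product sum
    sum_i A_i (kron) W_i, so the layer maps phm_dim * in_features inputs to
    phm_dim * out_features outputs. The contribution matrices can be learned
    per layer or shared across layers by passing them in via phm_rule.
    """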
    def __init__(self, in_features: int, out_features: int,
                 phm_dim: int,
                 phm_rule: Union[None, nn.Parameter, nn.ParameterList, list, torch.Tensor] = None,
                 bias: bool = True, w_init: str = "phm", c_init: str = "standard",
                 learn_phm: bool = True) -> None:
        super(PHMLinear, self).__init__()
        assert w_init in ["phm", "glorot-normal", "glorot-uniform"]
        assert c_init in ["standard", "random"]
        self.in_features = in_features
        self.out_features = out_features
        self.learn_phm = learn_phm
        self.phm_dim = phm_dim

        # Contribution matrices passed in from outside are shared with other
        # layers; otherwise the layer creates (and optionally learns) its own.
        self.shared_phm = False
        if phm_rule is not None:
            self.shared_phm = True
            self.phm_rule = phm_rule
            if not isinstance(phm_rule, nn.ParameterList) and learn_phm:
                self.phm_rule = nn.ParameterList(
                    [nn.Parameter(mat, requires_grad=learn_phm) for mat in self.phm_rule]
                )
        else:
            self.phm_rule = get_multiplication_matrices(phm_dim, type=c_init)
            self.phm_rule = nn.ParameterList(
                [nn.Parameter(mat, requires_grad=learn_phm) for mat in self.phm_rule]
            )

        self.bias_flag = bias
        self.w_init = w_init
        self.c_init = c_init
        # One (out_features x in_features) weight block per hypercomplex component.
        self.W = nn.ParameterList(
            [nn.Parameter(torch.empty(out_features, in_features), requires_grad=True)
             for _ in range(phm_dim)]
        )
        if self.bias_flag:
            self.b = nn.ParameterList(
                [nn.Parameter(torch.empty(out_features), requires_grad=True)
                 for _ in range(phm_dim)]
            )
        else:
            self.register_parameter("b", None)

        self.reset_parameters()

    def reset_parameters(self):
        if self.w_init == "phm":
            W_init = phm_init(phm_dim=self.phm_dim, in_features=self.in_features,
                              out_features=self.out_features)
            for W_param, W_i in zip(self.W, W_init):
                W_param.data = W_i.data
        elif self.w_init == "glorot-normal":
            for i in range(self.phm_dim):
                self.W[i] = glorot_normal(self.W[i])
        elif self.w_init == "glorot-uniform":
            for i in range(self.phm_dim):
                self.W[i] = glorot_uniform(self.W[i])
        else:
            raise ValueError(f"Unknown w_init: {self.w_init}")
        if self.bias_flag:
            # Zero bias on the first component, a small constant on the rest.
            self.b[0].data.fill_(0.0)
            for bias in self.b[1:]:
                bias.data.fill_(0.2)

        # Re-initialise the contribution matrices only if this layer owns them,
        # so shared matrices are not silently overwritten.
        if not self.shared_phm:
            phm_rule = get_multiplication_matrices(phm_dim=self.phm_dim, type=self.c_init)
            for i, init_data in enumerate(phm_rule):
                self.phm_rule[i].data = init_data

    def forward(self, x: torch.Tensor,
                phm_rule: Union[None, nn.ParameterList] = None) -> torch.Tensor:
        # TODO: extend forward() so it can use shared phm_rule contribution
        # matrices passed in via the argument; currently the layer's own
        # self.phm_rule is always applied.
        return matvec_product(W=self.W, x=x, bias=self.b, phm_rule=self.phm_rule)

    def __repr__(self):
        return '{}(in_features={}, out_features={}, ' \
               'phm_dim={}, ' \
               'bias={}, w_init={}, c_init={}, ' \
               'learn_phm={})'.format(self.__class__.__name__,
                                      self.in_features,
                                      self.out_features,
                                      self.phm_dim,
                                      self.bias_flag,
                                      self.w_init,
                                      self.c_init,
                                      self.learn_phm)
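

# A minimal sketch of the PHM matrix-vector product, assuming the standard PHM
# formulation y = x @ (sum_i A_i kron W_i)^T + b. This is a hypothetical
# stand-in for the repository's matvec_product helper, not its actual
# implementation; matvec_product_sketch and the shapes below are assumptions.
def matvec_product_sketch(W, x, bias, phm_rule):
    # W: phm_dim weight blocks, each of shape (out_features, in_features)
    # x: input of shape (batch, phm_dim * in_features)
    # phm_rule: phm_dim contribution matrices, each (phm_dim, phm_dim)
    H = sum(torch.kron(A_i, W_i) for A_i, W_i in zip(phm_rule, W))
    y = x.matmul(H.t())
    if bias is not None:
        # Concatenate the per-component bias vectors into one long bias.
        y = y + torch.cat(list(bias))
    return y


if __name__ == "__main__":
    # Usage sketch, assuming the helper functions noted at the top of the file
    # are available: with phm_dim=4, the layer maps 4 * 16 = 64 input features
    # to 4 * 8 = 32 output features under the PHM formulation above.
    layer = PHMLinear(in_features=16, out_features=8, phm_dim=4)
    x = torch.randn(2, 4 * 16)
    print(layer(x).shape)  # expected: torch.Size([2, 32])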