36 lines
911 B
Python
36 lines
911 B
Python
|
from torch import nn
|
||
|
|
||
|
|
||
|
class ModuleWrapper(nn.Module):
    """Base ``nn.Module`` offering recursive flag setting and a generic forward pass.

    The generic ``forward`` chains the direct children one after another and
    additionally reports the summed ``kl_loss()`` of every module in the tree
    that provides such a method.
    """

    def __init__(self):
        super().__init__()

    def set_flag(self, flag_name, value):
        """Store ``value`` under ``flag_name`` on this module and pass it down.

        Only direct children that expose a ``set_flag`` method are visited;
        each of them is responsible for forwarding the flag to its own
        children. Children without ``set_flag`` are left untouched.
        """
        setattr(self, flag_name, value)
        flag_aware = (
            child for child in self.children() if hasattr(child, 'set_flag')
        )
        for child in flag_aware:
            child.set_flag(flag_name, value)

    def forward(self, x):
        """Run ``x`` through the direct children and collect the KL term.

        Returns:
            A ``(output, kl)`` tuple. ``output`` is the result of applying
            each direct child in the order ``children()`` yields them;
            ``kl`` starts at ``0.0`` and has ``kl_loss()`` added for every
            module yielded by ``modules()`` that has a ``kl_loss`` attribute.
        """
        out = x
        for layer in self.children():
            out = layer(out)

        # modules() walks the whole module tree (not just direct children),
        # so KL contributions of nested layers are picked up as well.
        total_kl = 0.0
        for submodule in self.modules():
            if not hasattr(submodule, 'kl_loss'):
                continue
            total_kl = total_kl + submodule.kl_loss()

        return out, total_kl
|
||
|
|
||
|
|
||
|
class FlattenLayer(ModuleWrapper):
    """Reshape the input to two dimensions: ``(-1, num_features)``.

    Unlike ``ModuleWrapper.forward``, this layer's ``forward`` returns only
    the reshaped tensor, not an ``(output, kl)`` tuple.
    """

    def __init__(self, num_features):
        """Remember the size of the trailing dimension.

        Args:
            num_features: Size of the second dimension of the output; the
                first dimension is inferred from the number of elements.
        """
        super().__init__()
        self.num_features = num_features

    def forward(self, x):
        """Return ``x`` reshaped to ``(-1, self.num_features)``.

        ``reshape`` is used instead of ``view`` so that non-contiguous inputs
        (e.g. transposed, permuted or channels_last tensors) are flattened
        rather than raising ``RuntimeError``. Whenever a view is possible,
        ``reshape`` still returns one, so contiguous inputs are not copied.
        """
        return x.reshape(-1, self.num_features)
|