36 lines
911 B
Python
Executable File
36 lines
911 B
Python
Executable File
from torch import nn
|
|
|
|
|
|
class ModuleWrapper(nn.Module):
    """nn.Module base class adding flag propagation and a generic forward pass.

    ``set_flag`` stores an arbitrary attribute on this module and on every
    direct child that also exposes ``set_flag``. ``forward`` feeds the input
    through the direct children, in the order ``self.children()`` yields
    them, and returns the result together with the summed ``kl_loss()`` of
    every module from ``self.modules()`` that defines one.
    """

    def __init__(self):
        super().__init__()

    def set_flag(self, flag_name, value):
        """Set attribute ``flag_name`` to ``value`` here and on flag-aware children.

        Direct children without a ``set_flag`` attribute are skipped, so this
        call does not descend through them.
        """
        setattr(self, flag_name, value)
        flag_aware = (
            child for child in self.children() if hasattr(child, 'set_flag')
        )
        for child in flag_aware:
            child.set_flag(flag_name, value)

    def forward(self, x):
        """Chain ``x`` through the direct children, then total the KL terms.

        Returns:
            A ``(output, kl)`` tuple. ``kl`` starts at ``0.0`` and accumulates
            ``kl_loss()`` from each module yielded by ``self.modules()`` that
            has such an attribute.
        """
        out = x
        for layer in self.children():
            out = layer(out)

        # Explicit left-to-right accumulation keeps the result type and
        # rounding exactly as produced by repeated ``+``.
        kl_total = 0.0
        for sub in self.modules():
            if not hasattr(sub, 'kl_loss'):
                continue
            kl_total = kl_total + sub.kl_loss()

        return out, kl_total
|
|
|
|
|
|
class FlattenLayer(ModuleWrapper):
    """Reshape the input into a 2-D tensor with ``num_features`` columns."""

    def __init__(self, num_features):
        """Store the column count that ``forward`` flattens to."""
        super().__init__()
        self.num_features = num_features

    def forward(self, x):
        """Return ``x`` reshaped to ``(-1, num_features)``.

        ``reshape`` is used rather than ``view`` so that inputs which are not
        contiguous in memory (e.g. the result of a transpose or permute) are
        flattened instead of raising ``RuntimeError``. When a view is
        possible, ``reshape`` returns one, so contiguous inputs behave as
        before.
        """
        return x.reshape(-1, self.num_features)
|