bayesiancnn/layers/misc.py

36 lines
911 B
Python
Raw Normal View History

2024-05-10 09:59:24 +00:00
from torch import nn
class ModuleWrapper(nn.Module):
    """Wrapper for nn.Module with support for arbitrary flags and a universal forward pass.

    Subclasses register their layers as attributes. The default ``forward``
    applies the direct children in registration order and also returns the
    sum of ``kl_loss()`` over every submodule that defines it.
    """

    def __init__(self):
        super().__init__()

    def set_flag(self, flag_name, value):
        """Set attribute ``flag_name`` to ``value`` on this module and its descendants.

        The attribute is set on ``self`` and forwarded to every descendant
        that exposes ``set_flag``. Plain containers without ``set_flag``
        (e.g. ``nn.Sequential``) are walked through rather than skipped, so
        wrappers nested inside them also receive the flag; the containers
        themselves are not given the attribute.

        Args:
            flag_name: Name of the attribute to set.
            value: Value to assign to that attribute.
        """
        setattr(self, flag_name, value)

        def _propagate(parent):
            for child in parent.children():
                if hasattr(child, 'set_flag'):
                    # The child is responsible for its own subtree.
                    child.set_flag(flag_name, value)
                else:
                    # Look through containers so nested wrappers are reached,
                    # mirroring how forward() gathers kl_loss via modules().
                    _propagate(child)

        _propagate(self)

    def forward(self, x):
        """Run ``x`` through the direct children in order and collect KL terms.

        Args:
            x: Input passed to the first child module.

        Returns:
            Tuple ``(output, kl)``: ``output`` is the last child's output (or
            ``x`` unchanged if there are no children), and ``kl`` is the sum of
            ``kl_loss()`` over every module in ``self.modules()`` (which
            includes ``self``) that defines it, or ``0.0`` if none do.
        """
        for module in self.children():
            # NOTE: each child must return a single value; a nested
            # ModuleWrapper relying on this default forward returns a
            # (x, kl) tuple and would break the chain.
            x = module(x)
        kl = sum(
            (module.kl_loss() for module in self.modules() if hasattr(module, 'kl_loss')),
            0.0,
        )
        return x, kl
class FlattenLayer(ModuleWrapper):
    """Reshape the input to a 2-D tensor of shape ``(-1, num_features)``.

    Args:
        num_features: Size of the trailing dimension of the output. The
            input's total element count must be divisible by it.
    """

    def __init__(self, num_features):
        super().__init__()
        self.num_features = num_features

    def forward(self, x):
        """Return ``x`` reshaped to ``(-1, num_features)``.

        ``reshape`` returns a view whenever one is possible (the same result
        the former ``view`` call produced) and only copies when ``x`` is
        non-contiguous, e.g. after ``permute``/``transpose``, where ``view``
        would raise.
        """
        # NOTE(review): the leading dimension is inferred, so a wrong
        # num_features that still divides x.numel() silently changes the
        # batch size instead of raising; confirm callers pass the
        # per-sample feature count.
        return x.reshape(-1, self.num_features)