diff --git a/beta_vae_torch.py b/beta_vae_torch.py new file mode 100644 index 000000000..74870fb12 --- /dev/null +++ b/beta_vae_torch.py @@ -0,0 +1,55 @@ +import torch +import torch.nn as nn + + +class BetaVAE(nn.Module): + def __init__(self, input_dim, latent_dim, beta=1.0): + super(BetaVAE, self).__init__() + self.beta = beta + + # Encoder + self.encoder = nn.Sequential( + nn.Linear(input_dim, 512), + nn.ReLU(), + nn.Linear(512, 256), + nn.ReLU() + ) + + self.fc_mu = nn.Linear(256, latent_dim) + self.fc_logvar = nn.Linear(256, latent_dim) + + # Decoder + self.decoder = nn.Sequential( + nn.Linear(latent_dim, 256), + nn.ReLU(), + nn.Linear(256, 512), + nn.ReLU(), + nn.Linear(512, input_dim), + nn.Sigmoid() + ) + + def encode(self, x): + h = self.encoder(x) + mu = self.fc_mu(h) + logvar = self.fc_logvar(h) + return mu, logvar + + def decode(self, z): + return self.decoder(z) + + def reparameterize(self, mu, logvar): + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + return mu + eps * std + + def forward(self, x): + mu, logvar = self.encode(x) + z = self.reparameterize(mu, logvar) + return self.decode(z), mu, logvar + + +def loss_function(recon_x, x, mu, logvar, beta): + BCE = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum') + KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + return BCE + beta * KLD + diff --git a/pyod/models/beta_vae_torch.py b/pyod/models/beta_vae_torch.py new file mode 100644 index 000000000..74870fb12 --- /dev/null +++ b/pyod/models/beta_vae_torch.py @@ -0,0 +1,55 @@ +import torch +import torch.nn as nn + + +class BetaVAE(nn.Module): + def __init__(self, input_dim, latent_dim, beta=1.0): + super(BetaVAE, self).__init__() + self.beta = beta + + # Encoder + self.encoder = nn.Sequential( + nn.Linear(input_dim, 512), + nn.ReLU(), + nn.Linear(512, 256), + nn.ReLU() + ) + + self.fc_mu = nn.Linear(256, latent_dim) + self.fc_logvar = nn.Linear(256, latent_dim) + + # Decoder + self.decoder = nn.Sequential( + nn.Linear(latent_dim, 256), + nn.ReLU(), + nn.Linear(256, 512), + nn.ReLU(), + nn.Linear(512, input_dim), + nn.Sigmoid() + ) + + def encode(self, x): + h = self.encoder(x) + mu = self.fc_mu(h) + logvar = self.fc_logvar(h) + return mu, logvar + + def decode(self, z): + return self.decoder(z) + + def reparameterize(self, mu, logvar): + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + return mu + eps * std + + def forward(self, x): + mu, logvar = self.encode(x) + z = self.reparameterize(mu, logvar) + return self.decode(z), mu, logvar + + +def loss_function(recon_x, x, mu, logvar, beta): + BCE = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum') + KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + return BCE + beta * KLD +