Python VAE and Generating X-Ray Data

John Thuma
4 min read · Jul 17, 2024


Python! You beautiful, wonderful, brilliant, flexible BEAST! I love you! :)

Okay, settle down!

A colleague asked me the other day to build some synthetic data and show how it can be done. “Do it quickly!” they cried. So I started playing around with the cloud providers’ LLMs and the small LMs on Hugging Face, and said to myself: “Sure, I can do it this way, but I would rather not. I want to do it myself. I want to tune and control. I don’t want to use GPUs and cloud technologies. I also may want to keep the IP to myself or share it with my pals.”

So I went on Kaggle, found some X-ray data, and started to consider how I could generate synthetic X-rays from just a handful of labelled training JPEG files, like the one below:

[Image: a training X-ray labelled “person1 Bacteria”]

I started researching how to create new images from just a few samples. Along the way I stumbled upon the VAE, or Variational Autoencoder. A VAE is like a magical data-making machine: it learns from examples, compresses each one into a short code (a point in a small "latent" space), and then decodes new codes into new, similar images or data. It’s a way for computers to learn and create that is a bit like how our brains might do it!
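For readers who want the math behind that description, here it is written to match the code further down. The encoder maps an image x to a mean μ (fc21) and a log-variance log σ² (fc22) for each of the d = 100 latent dimensions, a code z is sampled with the reparameterization trick, and the decoder output x̂ = decoder(z) is compared with the D = 256 × 256 = 65,536 input pixels. The loss is the same BCE + KLD as loss_function in the code (which additionally sums over the batch):

```latex
% Reparameterization trick: sample z while keeping gradients w.r.t. \mu and \log\sigma^2
z = \mu + \sigma \odot \varepsilon, \qquad \sigma = \exp\!\big(\tfrac{1}{2}\log\sigma^{2}\big), \qquad \varepsilon \sim \mathcal{N}(0, I)

% Per-image loss: reconstruction (binary cross-entropy) plus KL divergence to \mathcal{N}(0, I)
\mathcal{L}(x) =
\underbrace{-\sum_{i=1}^{D}\Big[x_i \log \hat{x}_i + (1 - x_i)\log\big(1 - \hat{x}_i\big)\Big]}_{\text{BCE}}
\;+\;
\underbrace{\Big(-\tfrac{1}{2}\sum_{j=1}^{d}\big(1 + \log\sigma_j^{2} - \mu_j^{2} - \sigma_j^{2}\big)\Big)}_{\text{KLD}}

% Generating a new image: draw a code from the prior and decode it
\hat{x}_{\text{new}} = \operatorname{decoder}(z), \qquad z \sim \mathcal{N}(0, I)
```

Roughly speaking, the KL term is what makes the last line work: it pulls the codes of real images toward a standard normal, so random draws from that normal tend to decode into plausible images.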

SOUNDS TECHNICAL, CAN YOU EXPLAIN IT TO THE REST OF US, THUMA?

For business people, Variational Autoencoders (VAEs) are valuable because they learn a compact representation of large and complex datasets. By exploring that representation, organizations can uncover hidden patterns in their data that might inspire new product ideas, innovative marketing strategies, or improvements in their operations.

Additionally, VAEs can generate synthetic data for testing and development, saving both time and resources. Overall, VAEs offer a competitive edge by helping businesses understand their data more deeply and make better-informed decisions.

So I put 12 images in a folder and wrote the following code using PyTorch and a few other Python libraries. I did all the work on my MacBook.

[Image: the MacBook used for this project, “Pretty Nice Computer”]

The code I used is below. I spent most of my time trying to improve the quality of the generated images by experimenting with the number of epochs, the loss function, and other settings.
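When tuning the loss, it can help to log the reconstruction term and the KL term separately instead of only their sum. The snippet below is a minimal sketch, not the author's code: vae_loss_terms is a hypothetical helper, and its beta weight on the KL term is an assumption borrowed from the β-VAE idea. With beta = 1.0 it reproduces the article's BCE + KLD; larger values push harder toward a smooth latent space, typically at the cost of blurrier reconstructions. The random tensors only stand in for a real batch, and the last lines check the closed-form KL against torch.distributions.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Normal, kl_divergence


def vae_loss_terms(recon_x, x, mu, logvar, beta=1.0):
    """Hypothetical helper: return (bce, kld, total) so each term can be logged.

    beta=1.0 gives the plain BCE + KLD loss; other values are a tuning knob.
    """
    bce = F.binary_cross_entropy(recon_x, x.view(recon_x.shape), reduction="sum")
    # Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over latent dims and batch
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce, kld, bce + beta * kld


# Random tensors standing in for a batch of 4 tiny 8x8 images and a 3-dim latent space
torch.manual_seed(0)
x = torch.rand(4, 1, 8, 8)                 # targets in [0, 1]
recon = torch.sigmoid(torch.randn(4, 64))  # decoder-style outputs in (0, 1)
mu, logvar = torch.randn(4, 3), torch.randn(4, 3)

bce, kld, total = vae_loss_terms(recon, x, mu, logvar, beta=4.0)

# Sanity check: the closed form agrees with torch.distributions
reference = kl_divergence(Normal(mu, torch.exp(0.5 * logvar)), Normal(0.0, 1.0)).sum()
print(f"BCE={bce.item():.1f}  KLD={kld.item():.3f}  reference KLD={reference.item():.3f}")
```

Printing the two terms every few epochs shows whether a change to the settings is mostly affecting reconstruction quality or the latent regularization.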

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import matplotlib.pyplot as plt
import os
from PIL import Image

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Custom Dataset Class
class XrayDataset(Dataset):
    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform
        # Sort for a deterministic order; non-image files (e.g. .DS_Store) are skipped
        self.images = sorted(
            os.path.join(root, fname) for fname in os.listdir(root) if self.is_image_file(fname)
        )
        print(f"Found {len(self.images)} images in {root}")

    def is_image_file(self, filename):
        # Case-insensitive check so files such as SCAN.JPEG are also picked up
        return filename.lower().endswith(('.jpeg', '.jpg', '.png'))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img_path = self.images[index]
        with open(img_path, 'rb') as f:
            img = Image.open(f).convert('L')  # Convert to grayscale
        if self.transform:
            img = self.transform(img)
        return img, 0  # Return a dummy label since we are not using labels

# Data loading and transformation
data_dir = '/Users/john.thuma/Desktop/FILES/CODE/xray_data'
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # Resize every image to 256x256
    transforms.ToTensor(),  # Convert to a tensor with pixel values in [0, 1]
    # No Normalize((0.5,), (0.5,)) here: the decoder ends in a sigmoid and
    # binary cross-entropy expects targets in [0, 1], not [-1, 1]
])

dataset = XrayDataset(root=data_dir, transform=transform)
train_loader = DataLoader(dataset, batch_size=64, shuffle=True)
print(f"Data loader created with {len(train_loader)} batches")

# Define the VAE model: 65,536 pixels -> 512 hidden units -> 100-dim latent code
class VAE(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(256*256, 512)
        self.fc21 = nn.Linear(512, 100)  # Mean
        self.fc22 = nn.Linear(512, 100)  # Log variance
        self.fc3 = nn.Linear(100, 512)
        self.fc4 = nn.Linear(512, 256*256)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        # z = mu + sigma * eps keeps sampling differentiable w.r.t. mu and logvar
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))  # Pixel values in (0, 1)

    def forward(self, x):
        x = x.view(-1, 256*256)  # Flatten each 1x256x256 image
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# Reconstruction loss + KL divergence loss
def loss_function(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 256*256), reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

# Training loop
vae = VAE().to(device)
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)  # Adjusted learning rate

##########################################################
epochs = 200  # Reduced number of epochs for demonstration
##########################################################

for epoch in range(epochs):
    vae.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)  # Already in [0, 1], so no clamping is needed
        optimizer.zero_grad()
        recon_batch, mu, logvar = vae(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

    print(f'Epoch {epoch + 1}, Loss: {train_loss / len(train_loader.dataset)}')

# Generate new X-ray images
vae.eval()
with torch.no_grad():
    z = torch.randn(64, 100).to(device)  # Sample 64 codes from the 100-dim standard normal prior
    sample = vae.decode(z).cpu()
    sample = sample.view(64, 1, 256, 256)

# Plot the generated images in an 8x8 grid
fig, axes = plt.subplots(8, 8, figsize=(12, 12))
for i, ax in enumerate(axes.flatten()):
    ax.imshow(sample[i].squeeze(), cmap='gray')
    ax.axis('off')
plt.show()

I wrote the VAE script above in under 4 hours, and once trained, it generated 64 images in seconds using PyTorch. The same approach can be applied to other forms of unstructured data, as well as to structured (tabular) data.

Below is the result: 64 brand-spanking-new X-ray images generated with the VAE script.

[Image: Generated X-ray Images, an 8x8 grid of 64 samples]
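On the point that the same process also works for structured data, here is a minimal sketch, not taken from the original project, that reuses the same encode/reparameterize/decode layout on a numeric table. It assumes every column has already been standardized (mean 0, standard deviation 1), so the decoder has a plain linear output and the reconstruction term is a summed mean-squared error instead of BCE. The class name TabularVAE, the layer sizes, and the random 200 × 6 table are hypothetical placeholders.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TabularVAE(nn.Module):
    """Hypothetical: same encode/reparameterize/decode structure, sized for a table."""

    def __init__(self, n_features, hidden=32, latent=4):
        super().__init__()
        self.fc1 = nn.Linear(n_features, hidden)
        self.fc_mu = nn.Linear(hidden, latent)
        self.fc_logvar = nn.Linear(hidden, latent)
        self.fc3 = nn.Linear(latent, hidden)
        self.fc4 = nn.Linear(hidden, n_features)  # linear output: standardized values are unbounded

    def encode(self, x):
        h = F.relu(self.fc1(x))
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        return mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)

    def decode(self, z):
        return self.fc4(F.relu(self.fc3(z)))

    def forward(self, x):
        mu, logvar = self.encode(x)
        return self.decode(self.reparameterize(mu, logvar)), mu, logvar


torch.manual_seed(0)
table = torch.randn(200, 6)  # placeholder for a real, standardized table (200 rows, 6 columns)
model = TabularVAE(n_features=6)
opt = torch.optim.Adam(model.parameters(), lr=1e-2)

for epoch in range(100):  # full-batch training on the small table
    opt.zero_grad()
    recon, mu, logvar = model(table)
    mse = F.mse_loss(recon, table, reduction="sum")  # reconstruction term for real-valued columns
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    loss = mse + kld
    loss.backward()
    opt.step()

# Sample synthetic rows by decoding draws from the prior, as in the image example
model.eval()
with torch.no_grad():
    synthetic_rows = model.decode(torch.randn(10, 4))
print(synthetic_rows.shape)  # torch.Size([10, 6])
```

Categorical columns would need their own encoding and loss terms, which this sketch leaves out; the generated rows would also need to be un-standardized before use.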

Conclusion: I am sure it can be better, and there are many different ways to approach this kind of generative use case. I would love to hear from you.


John Thuma

Experienced Data and Analytics guru. 30 years of hands-on keyboard experience. Love hiking, writing, reading, and constant learning. All content is my opinion.