Federated Multi-channel VAEs on MNIST
In this notebook we illustrate a simple example of multi-channel data, generated by rotating the MNIST digits and treating each rotation angle as a separate data view (channel).
import torch
from torch.utils.data import Subset, DataLoader, random_split
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
The parameters used in this tutorial can be configured as follows:
N_CENTERS = 4  # Number of federated centers (clients)
N_ROUNDS = 10  # Communication rounds: local training at every center, then one aggregation
N_EPOCHS = 15  # Local training epochs each center runs before every aggregation
BATCH_SIZE = 48  # NOTE(review): not referenced by the code in this notebook
LR = 1e-3  # Learning rate
N_CHANNELS = 4  # Number of rotated views (channels) per digit

# Train on the GPU whenever one is visible to PyTorch.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print('Device:', DEVICE)
Device: cuda
To create the new channels, the dataset has to be transformed. One way to do this is to define a custom transform class that receives a set of angles (one per channel) and applies each rotation to the original image.
import torchvision.transforms.functional as TF
class MultiChannel:
    """Transform that expands one image into several rotated views.

    Calling the transform returns a list containing the input rotated by each
    configured angle (in degrees), in the same order as ``angles``.
    """

    def __init__(self, angles):
        # Sequence of rotation angles; one output view is produced per angle.
        self.angles = angles

    def __call__(self, x):
        views = []
        for angle in self.angles:
            views.append(TF.rotate(x, angle))
        return views
# One rotation angle per channel: N_CHANNELS evenly spaced values from 0 to 300
# degrees inclusive (e.g. 0, 100, 200, 300 for four channels).
angles = torch.linspace(0, 300, N_CHANNELS).tolist()
# Convert each image to a tensor first, then expand it into its rotated views.
transform = transforms.Compose(
    [transforms.ToTensor(), MultiChannel(angles)]
)
dataset = datasets.MNIST(
    '~/data/',
    download=True,
    train=True,
    transform=transform,
)
We can then visualize the results of the transformation for a couple of observations.
def plot_digit(mnist_point, angles=None):
    """Plot every channel of a multi-channel digit side by side.

    Args:
        mnist_point: ``(images, label)`` pair. ``images`` is a sequence with one
            image per channel; each is reshaped to 28x28 for display.
        angles: optional sequence of per-channel angles used as subplot titles.

    Returns:
        The last matplotlib Axes that was drawn on.
    """
    images, label = mnist_point
    # squeeze=False keeps ``axs`` an array even for a single channel, so the
    # zip below never receives a bare Axes object.
    fig, axs = plt.subplots(ncols=len(images), squeeze=False,
                            subplot_kw={'xticks': [], 'yticks': []})
    axs = axs.ravel()
    for ax, image in zip(axs, images):
        ax.imshow(image.reshape(28, 28), cmap='gray')
    if angles is not None:
        for ax, angle in zip(axs, angles):
            # Raw f-string: '\c' is not a valid escape in a normal string literal.
            ax.set_title(rf'{angle}$^\circ$')
    fig.suptitle(f'Label: {label}', size=24, weight='bold')
    fig.tight_layout()
    return ax
# Show all rotated views of two training digits.
for sample_idx in (0, 2):
    plot_digit(dataset[sample_idx], angles)
plt.show()
Federated Averaging using IID splitting
def split_iid(dataset, n_centers, generator=None):
    """Split a PyTorch dataset uniformly at random into ``n_centers`` subsets.

    Every observation is assigned to exactly one center. When ``len(dataset)``
    is not divisible by ``n_centers``, the remainder is spread one observation
    at a time over the first centers, so subset sizes differ by at most one.

    Args:
        dataset: any dataset supporting ``len`` and integer indexing.
        n_centers: number of centers (subsets) to create; must be >= 1.
        generator: optional ``torch.Generator`` for a reproducible split;
            when omitted the global RNG is used.

    Returns:
        List of ``torch.utils.data.Subset`` objects, one per center.

    Raises:
        ValueError: if ``n_centers`` is smaller than 1.
    """
    if n_centers < 1:
        raise ValueError(f'n_centers must be a positive integer, got {n_centers}.')
    base, remainder = divmod(len(dataset), n_centers)
    # random_split requires the lengths to sum to len(dataset) exactly.
    n_obs_per_center = [base + (1 if i < remainder else 0) for i in range(n_centers)]
    if generator is None:
        return random_split(dataset, n_obs_per_center)
    return random_split(dataset, n_obs_per_center, generator=generator)
def federated_averaging(models, n_obs_per_client):
    """Aggregate client models with Federated Averaging (FedAvg).

    Each floating-point state-dict entry becomes the average of the clients'
    entries, weighted by each client's share of the total number of
    observations. Non-floating entries (e.g. integer counters such as
    BatchNorm's ``num_batches_tracked``) cannot be averaged and are copied
    from the first model.

    Args:
        models: non-empty list of ``torch.nn.Module`` with identical state-dict keys.
        n_obs_per_client: number of training observations of each client, in
            the same order as ``models``.

    Returns:
        A deep copy of ``models[0]`` loaded with the averaged parameters.
        The input models are not modified.

    Raises:
        ValueError: if ``models`` is empty, the two lists differ in length, or
            the total number of observations is not positive.
    """
    import copy  # local import: the notebook only imports `copy` further down

    if len(models) == 0:
        raise ValueError('An empty list of models was passed.')
    if len(n_obs_per_client) != len(models):
        raise ValueError('List with number of observations must have '
                         'the same number of elements that list of models.')
    n_obs = sum(n_obs_per_client)
    if n_obs <= 0:
        raise ValueError('The total number of observations must be positive.')
    proportions = [n_k / n_obs for n_k in n_obs_per_client]

    # Fetch every state dict once instead of once per parameter key.
    state_dicts = [model.state_dict() for model in models]

    # state_dict() returns a fresh dict, so overwriting its values never touches models[0].
    avg_params = models[0].state_dict()
    for key, ref in avg_params.items():
        if not torch.is_floating_point(ref):
            # Integer buffers cannot hold a weighted float average; keep the first client's.
            avg_params[key] = ref.clone()
            continue
        avg = torch.zeros_like(ref)
        for state, proportion in zip(state_dicts, proportions):
            avg += proportion * state[key]
        avg_params[key] = avg

    # Copy one of the models and load the averaged params into it.
    avg_model = copy.deepcopy(models[0])
    avg_model.load_state_dict(avg_params)
    return avg_model
# Partition the training set into disjoint random shards, one per center.
federated_dataset = split_iid(n_centers=N_CENTERS, dataset=dataset)
print('Number of centers:', len(federated_dataset))
Number of centers: 4
Defining and distributing a model: Multi-channel Variational Autoencoder
In this exercise we will use the Multi-channel Variational Autoencoder proposed by Antelmi et al.
!pip install -q git+https://gitlab.inria.fr/epione_ML/mcvae.git
Building wheel for mcvae (setup.py) ... ?25l?25hdone
import copy
from mcvae.models import Mcvae, ThreeLayersVAE, VAE
uuid ... memory.total
index ...
0 GPU-15b8e8d9-4cfc-b229-ed4c-a9c8fdd3adc9 ... 11441
[1 rows x 6 columns]
Then we define a set of parameters necessary to instantiate the model.
N_FEATURES = 784  # 28 x 28 pixels of a flattened MNIST image
lat_dim = 3  # Size of the latent space of this autoencoder
# One all-zero row per channel; only used to size the model's input layers.
dummy_data = [torch.zeros(1, N_FEATURES) for _ in range(N_CHANNELS)]
vae_class = ThreeLayersVAE  # Per-channel architecture (VAE would give a single layer)
A model is then defined and sent (copied) to the clients for training on their local data.
# Build the global model once; every center later receives a deep copy of it.
model = Mcvae(data=dummy_data, lat_dim=lat_dim, vaeclass=vae_class).to(DEVICE)
# Use the LR hyperparameter declared with the others instead of a duplicated literal.
model.optimizer = torch.optim.Adam(params=model.parameters(), lr=LR)
model.init_loss()
Now replicate the model across the different centers.
# One independent deep copy of the global model per center.
models = []
for _center in range(N_CENTERS):
    models.append(copy.deepcopy(model))
# Local dataset sizes, used as FedAvg weights.
n_obs_per_client = list(map(len, federated_dataset))
Train in a federated fashion
def get_data(subset, shuffle=True):
    """Load a whole dataset (e.g. a ``Subset``) as one single batch.

    Args:
        subset: any map-style PyTorch dataset.
        shuffle: whether the returned batch is in random order.

    Returns:
        The single batch produced by the DataLoader's default collation,
        e.g. an ``(X, y)`` pair for datasets yielding ``(input, label)``.

    Raises:
        ValueError: if ``subset`` is empty (a DataLoader cannot use batch_size=0).
    """
    if len(subset) == 0:
        raise ValueError('Cannot load data from an empty dataset.')
    loader = DataLoader(subset, batch_size=len(subset), shuffle=shuffle)
    # Use the builtin next(): the iterator's ``.next()`` method was a Python 2
    # leftover and is not available in recent PyTorch releases.
    return next(iter(loader))
# Global parameters broadcast to every center at the start of each round.
init_params = model.state_dict()
for round_i in range(N_ROUNDS):
    print(f'=============================== ROUND {round_i + 1} ===============================')
    for center_data, center_model in zip(federated_dataset, models):
        # Whole local dataset as one batch; labels are not needed for the VAE.
        features, _labels = get_data(center_data)
        # Flatten every channel so it can feed the Linear input layers.
        center_model.data = [view.view(-1, N_FEATURES).to(DEVICE) for view in features]
        # Start local training from the current global parameters.
        center_model.load_state_dict(init_params)
        center_model.optimize(epochs=N_EPOCHS, data=center_model.data)
    # Aggregate the locally trained models, weighting by local dataset size.
    trained_model = federated_averaging(models, n_obs_per_client)
    init_params = trained_model.state_dict()
=============================== ROUND 1 ===============================
====> Epoch: 0/15 (0%) Loss: 8461.6523 LL: -8461.6309 KL: 0.0211 LL/KL: -400548.4064
====> Epoch: 10/15 (67%) Loss: 1105.4454 LL: -1083.1682 KL: 22.2773 LL/KL: -48.6221
====> Epoch: 0/15 (0%) Loss: 8462.0410 LL: -8462.0195 KL: 0.0211 LL/KL: -401301.4389
====> Epoch: 10/15 (67%) Loss: 1113.8654 LL: -1091.7914 KL: 22.0739 LL/KL: -49.4606
====> Epoch: 0/15 (0%) Loss: 8522.7480 LL: -8522.7266 KL: 0.0211 LL/KL: -404286.2857
====> Epoch: 10/15 (67%) Loss: 1139.8148 LL: -1117.2366 KL: 22.5782 LL/KL: -49.4829
====> Epoch: 0/15 (0%) Loss: 8539.5166 LL: -8539.4951 KL: 0.0211 LL/KL: -404330.8462
====> Epoch: 10/15 (67%) Loss: 1144.1252 LL: -1121.5454 KL: 22.5799 LL/KL: -49.6702
=============================== ROUND 2 ===============================
====> Epoch: 20/30 (67%) Loss: 673.3018 LL: -594.1547 KL: 79.1472 LL/KL: -7.5070
====> Epoch: 20/30 (67%) Loss: 682.7709 LL: -603.8767 KL: 78.8942 LL/KL: -7.6543
====> Epoch: 20/30 (67%) Loss: 700.8270 LL: -621.7188 KL: 79.1082 LL/KL: -7.8591
====> Epoch: 20/30 (67%) Loss: 704.5477 LL: -625.6259 KL: 78.9218 LL/KL: -7.9272
=============================== ROUND 3 ===============================
====> Epoch: 30/45 (67%) Loss: 421.6263 LL: -346.3453 KL: 75.2809 LL/KL: -4.6007
====> Epoch: 40/45 (89%) Loss: 0.4174 LL: 48.2205 KL: 48.6379 LL/KL: 0.9914
====> Epoch: 30/45 (67%) Loss: 426.5792 LL: -351.2788 KL: 75.3004 LL/KL: -4.6650
====> Epoch: 40/45 (89%) Loss: -3.4588 LL: 51.9528 KL: 48.4940 LL/KL: 1.0713
====> Epoch: 30/45 (67%) Loss: 442.3616 LL: -366.7449 KL: 75.6168 LL/KL: -4.8500
====> Epoch: 40/45 (89%) Loss: 32.3111 LL: 16.3418 KL: 48.6530 LL/KL: 0.3359
====> Epoch: 30/45 (67%) Loss: 445.7666 LL: -370.0311 KL: 75.7355 LL/KL: -4.8858
====> Epoch: 40/45 (89%) Loss: 27.1707 LL: 21.3217 KL: 48.4924 LL/KL: 0.4397
=============================== ROUND 4 ===============================
====> Epoch: 50/60 (83%) Loss: -371.7141 LL: 411.0286 KL: 39.3145 LL/KL: 10.4549
====> Epoch: 50/60 (83%) Loss: -380.0650 LL: 419.4107 KL: 39.3457 LL/KL: 10.6596
====> Epoch: 50/60 (83%) Loss: -355.8457 LL: 395.3784 KL: 39.5327 LL/KL: 10.0013
====> Epoch: 50/60 (83%) Loss: -365.0667 LL: 404.5761 KL: 39.5094 LL/KL: 10.2400
=============================== ROUND 5 ===============================
====> Epoch: 60/75 (80%) Loss: -733.7194 LL: 781.1589 KL: 47.4396 LL/KL: 16.4664
====> Epoch: 70/75 (93%) Loss: -1148.3826 LL: 1198.3474 KL: 49.9649 LL/KL: 23.9838
====> Epoch: 60/75 (80%) Loss: -738.9148 LL: 786.3655 KL: 47.4507 LL/KL: 16.5722
====> Epoch: 70/75 (93%) Loss: -1158.9563 LL: 1208.7041 KL: 49.7478 LL/KL: 24.2966
====> Epoch: 60/75 (80%) Loss: -718.3889 LL: 766.0316 KL: 47.6427 LL/KL: 16.0787
====> Epoch: 70/75 (93%) Loss: -1131.4376 LL: 1181.5215 KL: 50.0838 LL/KL: 23.5909
====> Epoch: 60/75 (80%) Loss: -731.1516 LL: 778.8954 KL: 47.7438 LL/KL: 16.3141
====> Epoch: 70/75 (93%) Loss: -1149.3083 LL: 1199.0266 KL: 49.7183 LL/KL: 24.1164
=============================== ROUND 6 ===============================
====> Epoch: 80/90 (89%) Loss: -1368.0773 LL: 1415.7279 KL: 47.6507 LL/KL: 29.7105
====> Epoch: 80/90 (89%) Loss: -1347.8184 LL: 1394.8962 KL: 47.0778 LL/KL: 29.6296
====> Epoch: 80/90 (89%) Loss: -1350.2051 LL: 1397.8102 KL: 47.6051 LL/KL: 29.3626
====> Epoch: 80/90 (89%) Loss: -1357.5592 LL: 1404.9619 KL: 47.4027 LL/KL: 29.6388
=============================== ROUND 7 ===============================
====> Epoch: 90/105 (86%) Loss: -1572.0276 LL: 1619.9266 KL: 47.8990 LL/KL: 33.8196
====> Epoch: 100/105 (95%) Loss: -1771.4706 LL: 1819.5625 KL: 48.0920 LL/KL: 37.8351
====> Epoch: 90/105 (86%) Loss: -1572.1785 LL: 1620.1028 KL: 47.9243 LL/KL: 33.8054
====> Epoch: 100/105 (95%) Loss: -1769.1110 LL: 1816.9020 KL: 47.7910 LL/KL: 38.0177
====> Epoch: 90/105 (86%) Loss: -1546.8544 LL: 1594.9036 KL: 48.0492 LL/KL: 33.1931
====> Epoch: 100/105 (95%) Loss: -1743.5269 LL: 1791.6266 KL: 48.0997 LL/KL: 37.2482
====> Epoch: 90/105 (86%) Loss: -1567.8221 LL: 1615.9977 KL: 48.1755 LL/KL: 33.5440
====> Epoch: 100/105 (95%) Loss: -1767.1917 LL: 1814.9788 KL: 47.7871 LL/KL: 37.9805
=============================== ROUND 8 ===============================
====> Epoch: 110/120 (92%) Loss: -1950.5330 LL: 1998.4330 KL: 47.9001 LL/KL: 41.7209
====> Epoch: 110/120 (92%) Loss: -1951.3871 LL: 1999.3215 KL: 47.9344 LL/KL: 41.7095
====> Epoch: 110/120 (92%) Loss: -1925.1748 LL: 1973.2444 KL: 48.0695 LL/KL: 41.0498
====> Epoch: 110/120 (92%) Loss: -1950.7429 LL: 1998.8005 KL: 48.0576 LL/KL: 41.5917
=============================== ROUND 9 ===============================
====> Epoch: 120/135 (89%) Loss: -2126.3816 LL: 2174.9695 KL: 48.5878 LL/KL: 44.7637
====> Epoch: 130/135 (96%) Loss: -2296.7185 LL: 2345.7883 KL: 49.0697 LL/KL: 47.8052
====> Epoch: 120/135 (89%) Loss: -2126.6194 LL: 2175.2815 KL: 48.6620 LL/KL: 44.7018
====> Epoch: 130/135 (96%) Loss: -2298.8037 LL: 2347.8977 KL: 49.0940 LL/KL: 47.8245
====> Epoch: 120/135 (89%) Loss: -2103.1416 LL: 2151.8955 KL: 48.7539 LL/KL: 44.1379
====> Epoch: 130/135 (96%) Loss: -2272.0200 LL: 2321.1653 KL: 49.1453 LL/KL: 47.2307
====> Epoch: 120/135 (89%) Loss: -2126.0889 LL: 2175.0093 KL: 48.9204 LL/KL: 44.4602
====> Epoch: 130/135 (96%) Loss: -2296.1450 LL: 2345.0369 KL: 48.8918 LL/KL: 47.9638
=============================== ROUND 10 ===============================
====> Epoch: 140/150 (93%) Loss: -2429.8711 LL: 2479.5564 KL: 49.6852 LL/KL: 49.9053
====> Epoch: 140/150 (93%) Loss: -2442.0671 LL: 2491.3926 KL: 49.3255 LL/KL: 50.5092
====> Epoch: 140/150 (93%) Loss: -2408.2605 LL: 2457.9480 KL: 49.6874 LL/KL: 49.4683
====> Epoch: 140/150 (93%) Loss: -2433.4390 LL: 2483.2117 KL: 49.7728 LL/KL: 49.8909
Results visualization
Using the final parameters, we can evaluate the model by projecting the test set onto the latent space. In the multi-channel scenario, the \(j\)-th latent variable in the latent space of channel \(i\) is denoted \(Z_{ij}\).
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
sns.set()
dataset_test = datasets.MNIST('~/data/', train=False, download=True, transform=transform)
X_test, y_test = get_data(dataset_test)
X_test = [x.view(-1, N_FEATURES).to(DEVICE) for x in X_test]  # Flatten for Linear layers
# Inference only: avoid building an autograd graph over the whole test set.
with torch.no_grad():
    q_test = trained_model.encode(X_test)
    # Concatenate the per-channel latent means column-wise, channels in order.
    Z_test = np.hstack([z.loc.cpu().numpy() for z in q_test])
# Column Z_{cd} holds latent dimension d of channel c (both 1-based).
col_names = [f'$Z_{{{(i // lat_dim) + 1}{(i % lat_dim) + 1}}}$' for i in range(Z_test.shape[1])]
latent_df = pd.DataFrame(Z_test, columns=col_names)
latent_df['label'] = y_test.numpy()  # labels come from the DataLoader as a CPU tensor
latent_df['label'] = latent_df['label'].astype('category')
latent_df.head()
$Z_{11}$ | $Z_{12}$ | $Z_{13}$ | $Z_{21}$ | $Z_{22}$ | $Z_{23}$ | $Z_{31}$ | $Z_{32}$ | $Z_{33}$ | $Z_{41}$ | $Z_{42}$ | $Z_{43}$ | label | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | -1.836733 | -0.541675 | 2.626518 | -1.834759 | -0.427387 | 2.642610 | -1.892036 | -0.469842 | 2.679921 | -1.833463 | -0.503372 | 2.641706 | 5 |
1 | -0.052134 | -0.643257 | 2.186315 | 0.026131 | -0.627934 | 2.161103 | 0.003671 | -0.775653 | 2.187668 | -0.072231 | -0.613178 | 2.200200 | 9 |
2 | 0.295430 | -0.565104 | 0.808192 | 0.297953 | -0.577041 | 0.845502 | 0.339431 | -0.616785 | 0.818862 | 0.314601 | -0.576570 | 0.782091 | 2 |
3 | 0.292785 | -0.775920 | 1.084410 | 0.530444 | -0.809846 | 1.136328 | 0.477896 | -0.708132 | 1.191237 | 0.295386 | -0.821254 | 1.176053 | 6 |
4 | 1.955340 | 3.049021 | 1.573791 | 2.166283 | 3.061774 | 1.487722 | 2.132793 | 3.041331 | 1.513418 | 2.023544 | 3.144786 | 1.534572 | 0 |
Representation of the test set in the latent space
# Pairwise scatter plots of every latent coordinate, coloured by digit label.
pair_grid = sns.pairplot(latent_df, hue='label', corner=True)
plt.show()
Evaluation of the reconstruction
Another way to evaluate the performance of this model is to evaluate the reconstruction.
# Reconstruct the first test digit from all of its rotated views.
# The inputs must live on the same device as the model.
sample_test = [x.reshape(-1, N_FEATURES).to(DEVICE) for x in dataset_test[0][0]]
with torch.no_grad():
    sample_pred = trained_model.reconstruct(sample_test)
# Reshape the *predictions* (not the inputs) into images on the CPU for matplotlib.
# NOTE(review): reconstruct's return type (tensor vs NumPy array) is not visible
# here; torch.as_tensor accepts either.
sample_pred = [torch.as_tensor(x).detach().cpu().reshape(-1, 28, 28) for x in sample_pred]
plot_digit(dataset_test[0])
plot_digit((sample_pred, "reconstruction"))
plt.show()
As observed, the reconstruction is quite good in all the views.
Generation of new datapoints
We can also generate new random digits, together with their rotated views:
By construction, the model aims to capture all the relationships in the data in a latent space constrained to follow a standard normal distribution, \(Z \sim \mathcal{N}(0, I)\). We can therefore draw a sample from this distribution, use it as the latent code of every channel, and decode it into new digits.
from torch.distributions import Normal
# Draw one latent code from the standard normal prior and reuse it for every channel.
prior = Normal(torch.zeros(1, lat_dim), torch.ones(1, lat_dim))
rsample = prior.sample().to(DEVICE)
z_rsample = [rsample] * N_CHANNELS
# Reconstruction process.
# Taken from: https://gitlab.inria.fr/epione_ML/mcvae/-/blob/master/src/mcvae/models/mcvae.py#L162
p = trained_model.decode_in_reconstruction(z_rsample)
x_hat = []
for c in range(N_CHANNELS):
    if c not in trained_model.dec_channels:
        # Channels without a decoder have no reconstruction.
        x_hat.append(None)
        continue
    # Average the decoded means along the stacking direction.
    per_source = [p[c][e].loc.detach() for e in range(N_CHANNELS)]
    x_hat.append(torch.stack(per_source).mean(0))
generated_sample = [x.reshape(-1, 28, 28).detach().cpu() for x in x_hat]
plot_digit((generated_sample, "random"))
plt.show()