🌸 Flower: A Comprehensive Guide to Federated Learning
by Kristina P. Sinaga
🌼🌸 FLOWER 🌼🌸
A Comprehensive Tutorial on Federated Learning with the Flower Framework
Last Updated: August 2nd, 2025 • Original Publication: April 19th, 2024 • ✨ Status: Enhanced with improved documentation and comprehensive analysis
🎯 Tutorial Overview
Federated Learning revolutionizes machine learning by enabling model training across multiple clients without centralizing sensitive data. Instead of moving data to the model, we bring the model to the data!
Key Concept: Each client trains locally on their private data, then shares only model parameters (not raw data) with a central server for aggregation.
This tutorial demonstrates how the Flower framework makes federated learning implementation surprisingly straightforward and scalable.
What You'll Learn
- 🏗️ Core FL Architecture: Client-server federated learning fundamentals
- 🌸 Flower Framework: Hands-on implementation with real code
- MNIST Classification: Practical federated image classification
- FedAvg Strategy: Understanding parameter aggregation
- Scalability Testing: From 10 to 200+ federated clients
- 🎨 Visualization: Performance tracking and analysis
🛠️ Technical Stack
- Framework: Flower (Federated Learning)
- ML Library: PyTorch
- Dataset: MNIST (Handwritten Digits)
- Strategy: FedAvg (Federated Averaging)
- Visualization: Matplotlib
The idea behind federated learning is to train a model across multiple clients and a server without having to share any data: each client trains the model locally on its own data and sends only its parameters back to the server, which then aggregates all the clients' parameters using a predefined strategy. The Flower framework, a friendly federated learning research framework, makes this process very simple, and this tutorial demonstrates it on MNIST data.
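In pseudocode, one round of this process looks roughly like the sketch below. This is illustrative only; `federated_round`, `local_update`, and `aggregate` are hypothetical stand-ins for what Flower's client and strategy abstractions will do for us later in this tutorial.
# Illustrative sketch of a single federated round (not the Flower API)
def federated_round(global_params, clients, aggregate):
    updates = []
    for client in clients:
        # Each client trains locally, starting from the global parameters,
        # and returns only its updated parameters plus its sample count.
        local_params, num_samples = client.local_update(global_params)
        updates.append((local_params, num_samples))
    # The server combines the updates (e.g., a weighted average for FedAvg)
    return aggregate(updates)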
Getting Started: Environment Setup
Let's begin by setting up our federated learning environment. We'll import all necessary libraries and configure our training device.
# Core Federated Learning Framework
import flwr as fl
# Deep Learning and Data Processing
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
# Utilities and Visualization
import numpy as np
import matplotlib.pyplot as plt
import random
from collections import OrderedDict
from typing import Dict, Tuple, List
# Flower-specific imports
from flwr.common import Metrics, NDArrays, Scalar
# Device Configuration
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# For this tutorial, we'll use CPU for consistency
DEVICE = torch.device("cpu")
# Display versions for reproducibility
print("🔧 Environment Setup:")
print(f" • Flower: {fl.__version__}")
print(f" • PyTorch: {torch.__version__}")
print(f" • TorchVision: {torchvision.__version__}")
print(f" • NumPy: {np.__version__}")
print(f" • Device: {DEVICE}")
print("✅ Environment ready for federated learning!")
🧠 Neural Network Architecture
We'll use a simple but effective convolutional neural network for MNIST digit classification:
class Net(nn.Module):
"""
Simple CNN for MNIST classification
- 2 Convolutional layers with ReLU activation
- Max pooling for spatial reduction
- 2 Fully connected layers for classification
"""
def __init__(self, num_classes: int = 10) -> None:
super(Net, self).__init__()
# Convolutional Feature Extractor
self.conv1 = nn.Conv2d(1, 6, 5) # 28x28 -> 24x24
self.pool = nn.MaxPool2d(2, 2) # 24x24 -> 12x12
self.conv2 = nn.Conv2d(6, 16, 5) # 12x12 -> 8x8
# After pooling: 4x4
# Classifier Head
self.fc1 = nn.Linear(16 * 4 * 4, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, num_classes)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Feature extraction
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
# Flatten for classifier
x = x.view(-1, 16 * 4 * 4)
# Classification
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
# Initialize model and display parameter count
model = Net(num_classes=10)
num_parameters = sum(value.numel() for value in model.state_dict().values())
print("Model Statistics:")
print(f" • Architecture: Simple CNN")
print(f" • Total Parameters: {num_parameters:,}")
print(f" • Input Shape: [batch_size, 1, 28, 28]")
print(f" • Output Classes: 10 (digits 0-9)")
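As a quick sanity check of the layer arithmetic noted in the comments above (28x28 → 24x24 → 12x12 → 8x8 → 4x4, flattened to 16 * 4 * 4 = 256), we can push a dummy batch through the network. This check is an illustrative addition, not part of the original pipeline:
# Sanity check: a dummy batch of 8 MNIST-sized images should yield [8, 10] logits
dummy = torch.randn(8, 1, 28, 28)
with torch.no_grad():
    logits = model(dummy)
print(f" • Dummy input {list(dummy.shape)} -> logits {list(logits.shape)}")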
Dataset Preparation and Visualization
Let's load the MNIST dataset and visualize some samples to understand our data:
def get_mnist():
"""Load and preprocess MNIST dataset"""
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)) # MNIST mean and std
])
trainset = MNIST("./data", train=True, download=True, transform=transform)
testset = MNIST("./data", train=False, download=True, transform=transform)
return trainset, testset
# Load dataset
trainset, testset = get_mnist()
print(f"Dataset Statistics:")
print(f" • Training samples: {len(trainset):,}")
print(f" • Test samples: {len(testset):,}")
print(f" • Image shape: {trainset.data[0].shape}")
print(f" • Classes: {len(trainset.classes)}")
🎨 Data Visualization
Let's visualize some sample images to understand our dataset better:
def visualise_n_random_examples(trainset_, n: int = 64, verbose: bool = True):
"""
Visualize n random examples from the dataset in a grid layout
Args:
trainset_: MNIST training dataset
n: Number of images to display
verbose: Whether to print image indices
"""
# Sample random indices
idx = list(range(len(trainset_.data)))
random.shuffle(idx)
idx = idx[:n]
if verbose:
print(f"🎯 Displaying {n} random samples with indices: {idx[:10]}..." if n > 10 else f"🎯 Displaying samples with indices: {idx}")
# Create visualization grid
num_cols = 16
num_rows = int(np.ceil(len(idx) / num_cols))
fig, axs = plt.subplots(
figsize=(16, num_rows * 1),
nrows=num_rows,
ncols=num_cols,
gridspec_kw={'hspace': 0.1, 'wspace': 0.1}
)
# Ensure axs is always 2D for consistent indexing
if num_rows == 1:
axs = axs.reshape(1, -1)
# Display images
for c_i, i in enumerate(idx):
row, col = c_i // num_cols, c_i % num_cols
axs[row, col].imshow(trainset_.data[i], cmap="gray")
axs[row, col].set_title(f'{trainset_.targets[i]}', fontsize=8)
axs[row, col].axis('off')
# Hide unused subplots
for c_i in range(len(idx), num_rows * num_cols):
row, col = c_i // num_cols, c_i % num_cols
axs[row, col].axis('off')
plt.suptitle('MNIST Dataset Sample - Random Handwritten Digits', fontsize=14, y=0.98)
plt.tight_layout()
plt.show()
# Visualize the dataset
print("📸 Visualizing MNIST Dataset Samples")
visualise_n_random_examples(trainset, n=64)
Figure: 4x16 grid of random MNIST handwritten digit samples from the training dataset
🏗️ Training and Evaluation Functions
Before diving into federated learning, let's define our core training and evaluation functions:
def train(net, trainloader, optimizer, epochs: int = 1):
"""
Train the network on the training set
Args:
net: Neural network model
trainloader: Training data loader
optimizer: Optimizer (SGD, Adam, etc.)
epochs: Number of training epochs
Returns:
Trained network
"""
criterion = torch.nn.CrossEntropyLoss()
net.train() # Set to training mode
total_loss = 0.0
total_samples = 0
for epoch in range(epochs):
epoch_loss = 0.0
for batch_idx, (images, labels) in enumerate(trainloader):
images, labels = images.to(DEVICE), labels.to(DEVICE)
# Forward pass
optimizer.zero_grad()
outputs = net(images)
loss = criterion(outputs, labels)
# Backward pass
loss.backward()
optimizer.step()
epoch_loss += loss.item()
total_samples += labels.size(0)
total_loss += epoch_loss
print(f" Epoch {epoch+1}/{epochs}: Loss = {epoch_loss/len(trainloader):.4f}")
avg_loss = total_loss / (epochs * len(trainloader))
print(f"Training completed: Avg Loss = {avg_loss:.4f}")
return net
def test(net, testloader):
"""
Evaluate the network on the test set
Args:
net: Trained neural network
testloader: Test data loader
Returns:
tuple: (average_loss, accuracy)
"""
criterion = torch.nn.CrossEntropyLoss()
correct, total_loss = 0, 0.0
net.eval() # Set to evaluation mode
with torch.no_grad():
for images, labels in testloader:
images, labels = images.to(DEVICE), labels.to(DEVICE)
outputs = net(images)
# Calculate loss
loss = criterion(outputs, labels)
total_loss += loss.item()
# Calculate accuracy
_, predicted = torch.max(outputs.data, 1)
correct += (predicted == labels).sum().item()
accuracy = correct / len(testloader.dataset)
avg_loss = total_loss / len(testloader)
return avg_loss, accuracy
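A useful baseline intuition: an untrained 10-class classifier should sit near chance level, so its cross-entropy loss should be close to ln(10) ≈ 2.303 and its accuracy close to 10%. Here is a small, optional check, an illustrative addition that reuses test() and the testset loaded earlier:
# Optional check: an untrained model should be near chance level
fresh_model = Net(num_classes=10).to(DEVICE)
sanity_loader = DataLoader(testset, batch_size=256)
init_loss, init_acc = test(fresh_model, sanity_loader)
print(f"Untrained model: loss={init_loss:.3f} (≈ 2.303), accuracy={init_acc:.3f} (≈ 0.100)")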
def run_centralised(epochs: int = 5, lr: float = 0.01, momentum: float = 0.9):
"""
Baseline: Traditional centralized training for comparison
Args:
epochs: Number of training epochs
lr: Learning rate
momentum: SGD momentum
"""
print("🎯 Running Centralized Training (Baseline)")
# Initialize model and optimizer
model = Net(num_classes=10).to(DEVICE)
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)
# Prepare data loaders
trainset, testset = get_mnist()
trainloader = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)
testloader = DataLoader(testset, batch_size=128, num_workers=2)
# Train the model
trained_model = train(model, trainloader, optimizer, epochs)
# Evaluate performance
loss, accuracy = test(trained_model, testloader)
print(f"Centralized Results:")
print(f" • Test Loss: {loss:.4f}")
print(f" • Test Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
return trained_model, loss, accuracy
# Run baseline for comparison
print("=" * 60)
print("BASELINE: CENTRALIZED TRAINING")
print("=" * 60)
baseline_model, baseline_loss, baseline_acc = run_centralised(epochs=3)
Federated Data Partitioning
In federated learning, we need to split our dataset among multiple clients. Let's create realistic data partitions:
def prepare_dataset(num_partitions: int, batch_size: int = 32, val_ratio: float = 0.1):
"""
Partition the training set into N disjoint subsets for federated clients
Args:
num_partitions: Number of clients/partitions
batch_size: Batch size for data loaders
val_ratio: Fraction of data for validation
Returns:
tuple: (trainloaders, valloaders, testloader)
"""
print(f"Preparing federated dataset for {num_partitions} clients...")
# Load the dataset
trainset, testset = get_mnist()
# Split training set into equal partitions
num_images = len(trainset) // num_partitions
partition_lengths = [num_images] * num_partitions
# Handle remainder samples
remainder = len(trainset) % num_partitions
for i in range(remainder):
partition_lengths[i] += 1
print(f" • Total training samples: {len(trainset):,}")
print(f" • Samples per client: {num_images} (+1 extra for {remainder} clients)")
# Create partitions with reproducible random seed
trainsets = random_split(
trainset, partition_lengths, torch.Generator().manual_seed(2023)
)
# Create data loaders with train/validation splits
trainloaders = []
valloaders = []
for i, trainset_partition in enumerate(trainsets):
num_total = len(trainset_partition)
num_val = int(val_ratio * num_total)
num_train = num_total - num_val
# Split into train and validation
for_train, for_val = random_split(
trainset_partition, [num_train, num_val],
torch.Generator().manual_seed(2023)
)
# Create data loaders
trainloaders.append(
DataLoader(for_train, batch_size=batch_size, shuffle=True, num_workers=2)
)
valloaders.append(
DataLoader(for_val, batch_size=batch_size, shuffle=False, num_workers=2)
)
# Global test loader
testloader = DataLoader(testset, batch_size=128, num_workers=2)
print("✅ Dataset partitioning completed!")
print(f" • Training loaders: {len(trainloaders)}")
print(f" • Validation loaders: {len(valloaders)}")
print(f" • Test samples: {len(testset):,}")
return trainloaders, valloaders, testloader
Analyzing Data Distribution
Let's examine how data is distributed across clients:
# Create federated partitions
trainloaders, valloaders, testloader = prepare_dataset(
num_partitions=100, batch_size=32
)
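Because prepare_dataset() builds nested Subset objects, it is worth verifying that the clients really hold disjoint data. The helper below is an optional addition; it assumes the two-level random_split structure created above:
# Optional check: map each client's train split back to global MNIST indices
# and confirm that no sample is shared between clients.
def check_disjoint(loaders):
    seen = set()
    for loader in loaders:
        train_split = loader.dataset       # Subset: train part of one client's partition
        partition = train_split.dataset    # Subset: that client's slice of MNIST
        global_idx = {partition.indices[j] for j in train_split.indices}
        assert seen.isdisjoint(global_idx), "Client partitions overlap!"
        seen |= global_idx
    print(f"✅ {len(loaders)} clients hold {len(seen):,} mutually disjoint training samples")
check_disjoint(trainloaders)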
# Analyze first client's data distribution
print("Analyzing Client #1 Data Distribution:")
train_partition = trainloaders[0].dataset          # train split (Subset) of client #1's partition
client_partition = train_partition.dataset         # client #1's partition (Subset) of MNIST
mnist_targets = client_partition.dataset.targets   # targets of the full MNIST dataset
print(f" • Client samples: {len(train_partition):,}")
print(f" • Batch size: {trainloaders[0].batch_size}")
print(f" • Number of batches: {len(trainloaders[0])}")
# Map the nested Subset indices back to the underlying MNIST targets
client_targets = [mnist_targets[client_partition.indices[i]].item() for i in train_partition.indices]
unique_labels, counts = np.unique(client_targets, return_counts=True)
print(f" • Unique digits: {unique_labels.tolist()}")
print(f" • Label distribution: {dict(zip(unique_labels.tolist(), counts.tolist()))}")
# Visualize label distribution
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.hist(client_targets, bins=10, color="skyblue", edgecolor="navy", alpha=0.7)
plt.grid(True, alpha=0.3)
plt.xticks(range(10))
plt.xlabel("Digit Label")
plt.ylabel("Number of Images")
plt.title("Client #1: Label Distribution (IID)")
plt.subplot(1, 2, 2)
plt.bar(unique_labels, counts, color="lightcoral", edgecolor="darkred", alpha=0.7)
plt.grid(True, alpha=0.3)
plt.xticks(range(10))
plt.xlabel("Digit Label")
plt.ylabel("Count")
plt.title("Client #1: Exact Label Counts")
plt.tight_layout()
plt.suptitle("MNIST Federated Client Data Analysis", y=1.02, fontsize=14)
plt.show()
# Overall statistics
total_samples = sum(len(loader.dataset) for loader in trainloaders)
avg_samples = total_samples / len(trainloaders)
print(f"\nFederated Dataset Statistics:")
print(f" • Total clients: {len(trainloaders)}")
print(f" • Total federated samples: {total_samples:,}")
print(f" • Average samples per client: {avg_samples:.1f}")
print(f" • Data distribution: IID (independent and identically distributed)")
🌸 Defining the Flower Client
Now comes the heart of federated learning: defining our Flower client! A Flower client is elegantly simple, with four key methods:
Core Client Methods
| Method | Purpose | Description |
|---|---|---|
| `fit()` | 🏋️ Local Training | Train model locally and return updated parameters |
| `evaluate()` | Local Evaluation | Evaluate global model on local validation data |
| `set_parameters()` | 📥 Parameter Loading | Load global model parameters from server |
| `get_parameters()` | 📤 Parameter Extraction | Extract local model parameters for server |
Key Insight: Federated learning is like having a distributed team where each member (client) works on their own data, but everyone shares their learnings (parameters) to build a better collective model! 🤝
🛠️ Setting Up the Environment
# Import Flower framework
import flwr as fl
# Ensure we're using the correct device
DEVICE = torch.device("cpu") # Consistent with our setup
print(f"🌸 Flower Client Environment:")
print(f" • Device: {DEVICE}")
print(f" • Framework: PyTorch")
print(f" • Ready for federated learning!")
🏗️ Implementing the FlowerClient Class
class FlowerClient(fl.client.NumPyClient):
"""
A Flower client for federated MNIST classification
This client can:
- Receive global model parameters from the server
- Train locally on private data
- Send updated parameters back to the server
- Evaluate global model performance locally
"""
def __init__(self, trainloader, valloader) -> None:
super().__init__()
# Store client's private data
self.trainloader = trainloader
self.valloader = valloader
# Initialize local model
self.model = Net(num_classes=10).to(DEVICE)
print(f"🌸 Flower Client initialized:")
print(f" • Training samples: {len(trainloader.dataset)}")
print(f" • Validation samples: {len(valloader.dataset)}")
print(f" • Model parameters: {sum(p.numel() for p in self.model.parameters()):,}")
def set_parameters(self, parameters: NDArrays) -> None:
"""
Load parameters received from the server into the local model
Args:
parameters: List of NumPy arrays representing model weights
"""
# Convert NumPy arrays back to PyTorch tensors
params_dict = zip(self.model.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
# Load parameters into the model
self.model.load_state_dict(state_dict, strict=True)
def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays:
"""
Extract model parameters and convert to NumPy arrays
Args:
config: Configuration dictionary (unused in this implementation)
Returns:
List of NumPy arrays representing model weights
"""
return [val.cpu().numpy() for _, val in self.model.state_dict().items()]
def fit(self, parameters: NDArrays, config: Dict[str, Scalar]) -> Tuple[NDArrays, int, Dict]:
"""
Perform local training using the global model parameters
Args:
parameters: Global model parameters from server
config: Training configuration
Returns:
tuple: (updated_parameters, num_samples, metrics)
"""
print(f"🏋️ Starting local training...")
# Load global model parameters
self.set_parameters(parameters)
# Configure local optimizer
optimizer = torch.optim.SGD(
self.model.parameters(),
lr=config.get("lr", 0.01),
momentum=config.get("momentum", 0.9)
)
# Perform local training
epochs = config.get("epochs", 1)
self.model = train(self.model, self.trainloader, optimizer, epochs)
# Return updated parameters
updated_parameters = self.get_parameters({})
num_samples = len(self.trainloader.dataset)
print(f"✅ Local training completed ({num_samples} samples)")
return updated_parameters, num_samples, {}
def evaluate(self, parameters: NDArrays, config: Dict[str, Scalar]) -> Tuple[float, int, Dict[str, Scalar]]:
"""
Evaluate the global model on local validation data
Args:
parameters: Global model parameters
config: Evaluation configuration
Returns:
tuple: (loss, num_samples, metrics)
"""
print(f"Evaluating global model locally...")
# Load global model parameters
self.set_parameters(parameters)
# Evaluate on local validation set
loss, accuracy = test(self.model, self.valloader)
num_samples = len(self.valloader.dataset)
print(f"Local evaluation: Loss={loss:.4f}, Accuracy={accuracy:.4f}")
return float(loss), num_samples, {"accuracy": accuracy}
# Test client initialization
print("🧪 Testing FlowerClient initialization...")
sample_client = FlowerClient(trainloaders[0], valloaders[0])
print("✅ FlowerClient class ready for federated learning!")
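Before handing control to a server, we can dry-run the client API by hand. The snippet below is an illustrative addition that exercises get_parameters(), fit(), and evaluate() on the sample_client created above:
# Dry run (illustrative): exercise the client API without a server
initial_params = sample_client.get_parameters(config={})
updated_params, n_train, _ = sample_client.fit(initial_params, config={"lr": 0.01, "epochs": 1})
val_loss, n_val, metrics = sample_client.evaluate(updated_params, config={})
print(f"Dry run: trained on {n_train} samples, val loss {val_loss:.4f}, val accuracy {metrics['accuracy']:.4f}")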
💡 Pro Tip: The beauty of the FlowerClient class is that it encapsulates all the federated learning complexity while keeping the interface simple. Each client manages its own data and model, but communicates through standardized NumPy arrays! 🎯
🎯 Federated Learning Strategy: FedAvg
The strategy is the brain of federated learning! It orchestrates the entire process: client sampling, model aggregation, and evaluation coordination.
🧠 Understanding FedAvg
Federated Averaging (FedAvg) is simple yet powerful:
- 📤 Server sends the global model to selected clients
- 🏋️ Clients train locally on their private data
- 📥 Server receives the updated model parameters
- ⚖️ Aggregation averages all client updates (weighted by data size)
- Repeat for multiple rounds
Mathematical Foundation:
w_global = Σ(n_i · w_i) / Σ(n_i)
where w_i are client i's model weights and n_i is its number of training samples.
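To make the formula concrete, here is a minimal NumPy-level sketch of the weighted average. Flower's built-in FedAvg performs the equivalent aggregation internally, so this function is purely illustrative:
# Illustrative layer-wise weighted average of client updates.
# `updates` is a list of (parameters, num_samples) pairs, where parameters
# is a list of NumPy arrays as returned by get_parameters().
def fedavg_aggregate(updates):
    total = sum(n for _, n in updates)
    num_layers = len(updates[0][0])
    return [
        sum(params[layer] * (n / total) for params, n in updates)
        for layer in range(num_layers)
    ]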
🔧 Implementing FedAvg with Flower
def get_evaluate_fn(testloader):
"""
Create a centralized evaluation function for the global model
This function will be called by the server after each round to assess
global model performance on a held-out test set.
Args:
testloader: DataLoader for the test dataset
Returns:
evaluate_fn: Function that evaluates global model
"""
def evaluate_fn(server_round: int, parameters: NDArrays, config: Dict[str, Scalar]):
"""
Evaluate the global model on the centralized test set
Args:
server_round: Current federated learning round
parameters: Global model parameters
config: Evaluation configuration
Returns:
tuple: (loss, metrics_dict)
"""
print(f"Round {server_round}: Evaluating global model...")
# Create a fresh model for evaluation
model = Net(num_classes=10).to(DEVICE)
# Load global parameters
params_dict = zip(model.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
model.load_state_dict(state_dict, strict=True)
# Evaluate on the global test set
loss, accuracy = test(model, testloader)
print(f"Global Model Performance:")
print(f" • Test Loss: {loss:.4f}")
print(f" • Test Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
return loss, {"accuracy": accuracy}
return evaluate_fn
# Configure the FedAvg strategy
print("⚙️ Configuring Federated Learning Strategy...")
strategy = fl.server.strategy.FedAvg(
fraction_fit=0.1, # 10% of clients participate in each training round
fraction_evaluate=0.1, # 10% of clients participate in evaluation
min_available_clients=50, # Minimum clients needed to start training
evaluate_fn=get_evaluate_fn(testloader), # Centralized evaluation function
# Optional: Configure client sampling and aggregation
min_fit_clients=10, # Minimum clients for training round
min_evaluate_clients=5, # Minimum clients for evaluation round
# Training configuration sent to clients
on_fit_config_fn=lambda server_round: {
"lr": 0.01, # Learning rate
"momentum": 0.9, # SGD momentum
"epochs": 1, # Local training epochs
},
# Evaluation configuration sent to clients
on_evaluate_config_fn=lambda server_round: {
"batch_size": 64, # Evaluation batch size
}
)
print("✅ FedAvg Strategy configured:")
print(f" • Client participation: 10% per round")
print(f" • Minimum clients required: 50")
print(f" • Local training: 1 epoch per round")
print(f" • Learning rate: 0.01")
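The per-client accuracies returned by evaluate() can also be aggregated on the federated side. If your flwr version supports the evaluate_metrics_aggregation_fn hook (available in flwr 1.x), a weighted average like the sketch below can be passed to FedAvg; treat the exact signature as an assumption to verify against your installed version:
# Sketch: aggregate the "accuracy" values returned by each client's evaluate(),
# weighted by the number of validation samples on that client.
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    total_examples = sum(num_examples for num_examples, _ in metrics)
    accuracy = sum(num_examples * m["accuracy"] for num_examples, m in metrics) / total_examples
    return {"accuracy": accuracy}
# Could be passed to FedAvg as: evaluate_metrics_aggregation_fn=weighted_average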
Client Factory Function
We need a factory function to create clients dynamically during simulation:
def generate_client_fn(trainloaders, valloaders):
"""
Generate a client factory function for Flower simulation
Args:
trainloaders: List of training data loaders (one per client)
valloaders: List of validation data loaders (one per client)
Returns:
client_fn: Function that creates FlowerClient instances
"""
def client_fn(cid: str) -> FlowerClient:
"""
Create a FlowerClient instance for a given client ID
Args:
cid: Client ID (string representation of client index)
Returns:
FlowerClient instance with the client's data partition
"""
client_id = int(cid)
print(f"🌸 Creating client {client_id} with {len(trainloaders[client_id].dataset)} training samples")
return FlowerClient(
trainloader=trainloaders[client_id],
valloader=valloaders[client_id]
)
return client_fn
# Create the client factory
client_fn_callback = generate_client_fn(trainloaders, valloaders)
print("Client factory function created successfully!")
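As a quick illustrative check (an addition, not part of the original flow), the factory can be called directly with a string client ID:
# Quick check (illustrative): the factory builds a client from its string ID
demo_client = client_fn_callback("42")
print(f"Client 42 holds {len(demo_client.trainloader.dataset)} training samples")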
Launching the Federated Learning Experiment
Everything is ready! Let's launch our federated learning simulation:
import time
print("=" * 80)
print("🌸 LAUNCHING FEDERATED LEARNING WITH FLOWER 🌸")
print("=" * 80)
# Simulation configuration
num_clients = 100 # Total clients in the federation
num_rounds = 10 # Number of federated learning rounds
clients_per_round = 10 # Clients participating per round (10% of 100)
print(f"Simulation Configuration:")
print(f" • Total clients: {num_clients}")
print(f" • FL rounds: {num_rounds}")
print(f" • Clients per round: {clients_per_round}")
print(f" • Strategy: FedAvg")
print(f" • Dataset: MNIST")
print(f" • Local epochs: 1")
print("\nStarting federated learning simulation...")
# Record start time
start_time = time.time()
# Launch the federated learning simulation!
history = fl.simulation.start_simulation(
client_fn=client_fn_callback, # Factory function for creating clients
num_clients=num_clients, # Total number of clients
config=fl.server.ServerConfig(num_rounds=num_rounds), # Number of FL rounds
strategy=strategy, # FedAvg strategy
client_resources={"num_cpus": 1, "num_gpus": 0} # Resource allocation per client
)
# Record end time
end_time = time.time()
simulation_time = end_time - start_time
print(f"\n⏱️ Simulation completed in {simulation_time:.2f} seconds")
print(f"Federated learning finished successfully!")
⚡ Performance Note: Flower simulation is remarkably fast!
- 10 rounds: ~2 minutes on CPU
- 20 rounds: ~15 minutes on CPU
- Scales efficiently with more clients!
Results Analysis and Visualization
Let's analyze the performance of our federated learning experiment:
print("Analyzing Federated Learning Results...")
# Extract centralized accuracy metrics
print(f"Available metrics: {list(history.metrics_centralized.keys())}")
if "accuracy" in history.metrics_centralized:
global_accuracy_centralized = history.metrics_centralized["accuracy"]
rounds = [data[0] for data in global_accuracy_centralized]
accuracies = [100.0 * data[1] for data in global_accuracy_centralized]
# Print summary statistics
print(f"\nFederated Learning Results Summary:")
print(f" • Initial accuracy: {accuracies[0]:.2f}%")
print(f" • Final accuracy: {accuracies[-1]:.2f}%")
print(f" • Improvement: {accuracies[-1] - accuracies[0]:+.2f}%")
print(f" • Best accuracy: {max(accuracies):.2f}%")
print(f" • Total rounds: {len(rounds)-1}")  # round 0 is the initial evaluation
# Compare with baseline
federated_final_acc = accuracies[-1] / 100.0
improvement_over_baseline = (federated_final_acc - baseline_acc) * 100
print(f"\nComparison with Centralized Baseline:")
print(f" • Centralized accuracy: {baseline_acc*100:.2f}%")
print(f" • Federated accuracy: {federated_final_acc*100:.2f}%")
print(f" • Difference: {improvement_over_baseline:+.2f}%")
# Create comprehensive visualization
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
# Plot 1: Accuracy over rounds
ax1.plot(rounds, accuracies, 'b-o', linewidth=2, markersize=6, label='Federated Learning')
ax1.axhline(y=baseline_acc*100, color='r', linestyle='--', linewidth=2, label='Centralized Baseline')
ax1.grid(True, alpha=0.3)
ax1.set_xlabel('Round')
ax1.set_ylabel('Accuracy (%)')
ax1.set_title('Federated Learning: Global Model Accuracy')
ax1.legend()
ax1.set_ylim([min(accuracies) - 5, max(accuracies) + 5])
# Plot 2: Accuracy improvement
if len(accuracies) > 1:
improvements = [acc - accuracies[0] for acc in accuracies[1:]]
ax2.bar(rounds[1:], improvements, alpha=0.7, color='green')
ax2.grid(True, alpha=0.3)
ax2.set_xlabel('Round')
ax2.set_ylabel('Accuracy Improvement (%)')
ax2.set_title('Round-by-Round Improvement')
# Plot 3: Learning curve
ax3.plot(rounds, accuracies, 'g-', linewidth=3, alpha=0.7)
ax3.fill_between(rounds, accuracies, alpha=0.3, color='green')
ax3.grid(True, alpha=0.3)
ax3.set_xlabel('Round')
ax3.set_ylabel('Accuracy (%)')
ax3.set_title('Learning Curve (Area Plot)')
# Plot 4: Performance summary
metrics = ['Initial', 'Final', 'Best', 'Baseline']
values = [accuracies[0], accuracies[-1], max(accuracies), baseline_acc*100]
colors = ['lightblue', 'lightgreen', 'gold', 'lightcoral']
bars = ax4.bar(metrics, values, color=colors, alpha=0.8, edgecolor='black')
ax4.set_ylabel('Accuracy (%)')
ax4.set_title('Performance Comparison')
ax4.grid(True, alpha=0.3, axis='y')
# Add value labels on bars
for bar, value in zip(bars, values):
height = bar.get_height()
ax4.text(bar.get_x() + bar.get_width()/2., height + 0.5,
f'{value:.1f}%', ha='center', va='bottom', fontweight='bold')
plt.tight_layout()
plt.suptitle(f'MNIST Federated Learning Results - {num_clients} Clients, {num_rounds} Rounds',
fontsize=16, y=1.02)
plt.show()
else:
    print("⚠️ No accuracy metrics found in history")
# Print detailed round-by-round progress (guarded so it only runs when metrics exist)
if "accuracy" in history.metrics_centralized:
    print(f"\nDetailed Round-by-Round Progress:")
    print("Round | Accuracy | Improvement")
    print("-" * 35)
    for i, (round_num, acc) in enumerate(zip(rounds, accuracies)):
        if i == 0:
            print(f"{round_num:5d} | {acc:7.2f}% | Initial")
        else:
            improvement = acc - accuracies[i-1]
            print(f"{round_num:5d} | {acc:7.2f}% | {improvement:+6.2f}%")
Scaling Up: Testing with More Clients
One of Flower's greatest strengths is scalability. Let's test with even more clients:
def run_scalability_test(client_counts: List[int] = [50, 100, 200]):
"""
Test federated learning performance with different numbers of clients
Args:
client_counts: List of client counts to test
"""
print("🔬 Running Scalability Analysis...")
results = {}
for num_clients_test in client_counts:
print(f"\n🧪 Testing with {num_clients_test} clients...")
# Prepare dataset for this client count
train_loaders_test, val_loaders_test, test_loader_test = prepare_dataset(
num_partitions=num_clients_test, batch_size=32
)
# Create client factory
client_fn_test = generate_client_fn(train_loaders_test, val_loaders_test)
# Configure strategy for this test
strategy_test = fl.server.strategy.FedAvg(
fraction_fit=0.1,
fraction_evaluate=0.1,
min_available_clients=min(50, num_clients_test),
evaluate_fn=get_evaluate_fn(test_loader_test),
on_fit_config_fn=lambda server_round: {"lr": 0.01, "momentum": 0.9, "epochs": 1}
)
# Run simulation
start_time = time.time()
history_test = fl.simulation.start_simulation(
client_fn=client_fn_test,
num_clients=num_clients_test,
config=fl.server.ServerConfig(num_rounds=5), # Fewer rounds for testing
strategy=strategy_test,
client_resources={"num_cpus": 1, "num_gpus": 0}
)
end_time = time.time()
# Extract results
if "accuracy" in history_test.metrics_centralized:
accuracies = [data[1] for data in history_test.metrics_centralized["accuracy"]]
final_accuracy = accuracies[-1] * 100
results[num_clients_test] = {
'final_accuracy': final_accuracy,
'training_time': end_time - start_time,
'rounds': len(accuracies) - 1
}
print(f"✅ {num_clients_test} clients: {final_accuracy:.2f}% accuracy in {end_time - start_time:.1f}s")
else:
print(f"❌ No results for {num_clients_test} clients")
return results
# Run scalability test
print("🎯 Demonstrating Flower's Scalability...")
scalability_results = run_scalability_test([100, 200])
# Visualize scalability results
if scalability_results:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
client_counts = list(scalability_results.keys())
accuracies = [scalability_results[c]['final_accuracy'] for c in client_counts]
times = [scalability_results[c]['training_time'] for c in client_counts]
# Accuracy vs Number of Clients
ax1.plot(client_counts, accuracies, 'bo-', linewidth=2, markersize=8)
ax1.grid(True, alpha=0.3)
ax1.set_xlabel('Number of Clients')
ax1.set_ylabel('Final Accuracy (%)')
ax1.set_title('Accuracy vs Number of Clients')
# Training Time vs Number of Clients
ax2.plot(client_counts, times, 'ro-', linewidth=2, markersize=8)
ax2.grid(True, alpha=0.3)
ax2.set_xlabel('Number of Clients')
ax2.set_ylabel('Training Time (seconds)')
ax2.set_title('Training Time vs Number of Clients')
plt.tight_layout()
plt.suptitle('Flower Framework Scalability Analysis', fontsize=14, y=1.02)
plt.show()
Key Takeaways and Future Directions
What We Accomplished
- ✅ Implemented a complete federated learning pipeline with Flower
- ✅ Demonstrated privacy-preserving training across 100+ clients
- ✅ Achieved competitive performance compared to centralized training
- ✅ Showcased remarkable scalability with minimal overhead
- ✅ Visualized learning dynamics and performance metrics
Key Insights
- 🌸 Flower Simplicity: Complex federated learning made remarkably simple
- ⚡ Performance: Excellent accuracy with fast training times
- Scalability: Seamlessly scales from 10 to 200+ clients
- Privacy: Data never leaves client devices
- 🔧 Flexibility: Easy to customize strategies and configurations
Next Steps & Extensions
| Enhancement | Description | Impact |
|---|---|---|
| Non-IID Data | Simulate realistic heterogeneous data distributions | More realistic federated scenarios |
| Advanced Strategies | Implement FedProx, FedNova, or custom aggregation | Better convergence properties |
| Differential Privacy | Add noise for stronger privacy guarantees | Enhanced privacy protection |
| Cross-Device FL | Simulate mobile devices with intermittent connectivity | Real-world deployment scenarios |
| Personalization | Local model fine-tuning for client-specific tasks | Improved individual performance |
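As a starting point for the Non-IID row in the table above, here is a sketch of Dirichlet label-skew partitioning, a common way to simulate heterogeneous clients. Everything here (the function name, the `alpha` parameter) is an illustrative assumption rather than part of this tutorial's pipeline; smaller `alpha` means more skew:
import numpy as np
def dirichlet_partition(targets, num_clients: int, alpha: float = 0.5, seed: int = 2023):
    """Assign sample indices to clients with Dirichlet label skew (sketch)."""
    rng = np.random.default_rng(seed)
    targets = np.asarray(targets)
    client_indices = [[] for _ in range(num_clients)]
    for label in np.unique(targets):
        label_idx = rng.permutation(np.where(targets == label)[0])
        # Draw each client's share of this label from a Dirichlet distribution
        proportions = rng.dirichlet(alpha * np.ones(num_clients))
        cut_points = (np.cumsum(proportions)[:-1] * len(label_idx)).astype(int)
        for client_id, chunk in enumerate(np.split(label_idx, cut_points)):
            client_indices[client_id].extend(chunk.tolist())
    return client_indices
# Usage sketch: partitions = dirichlet_partition(trainset.targets, num_clients=100, alpha=0.3)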
💡 Real-World Applications
- 🏥 Healthcare: Collaborative medical imaging without sharing patient data
- 📱 Mobile AI: Keyboard prediction, voice recognition across devices
- 🏦 Finance: Fraud detection across institutions
- Autonomous Vehicles: Shared learning from driving experiences
- Industrial IoT: Predictive maintenance across factories
Conclusion
🌼🌸 FLOWER 🌼🌸 has proven to be an exceptional framework for federated learning research and deployment. Its elegant design makes complex distributed ML accessible while maintaining the flexibility needed for cutting-edge research.
"Federated learning isn't just about distributing computation; it's about democratizing AI while preserving privacy. Flower makes this vision achievable." 🌸
Resources for Further Learning
- 🌸 Flower Documentation
- Federated Learning Book
- FL Research Papers
- 💻 Flower GitHub
Thank you for joining this federated learning journey! Feel free to experiment, modify, and build upon this foundation. The future of privacy-preserving AI is blooming! 🌸
Document History
- Original Publication: April 19th, 2024 - Initial Flower tutorial with basic implementation
- ✨ Major Update: August 2nd, 2025 - Comprehensive enhancement with:
- Improved code documentation and structure
- Advanced visualization and analysis
- Professional formatting and educational content
- Scalability testing and performance metrics
- Real-world applications and future directions
This tutorial continues to evolve as federated learning advances. Check back for future updates!
tags: Federated Learning - Flower Framework - MNIST - PyTorch - Distributed Computing - Privacy-Preserving ML