This blog explores fine-tuning a Swin Transformer model for agricultural plant classification using transfer learning.
Plant classification is important in various industries, including agriculture, environmental monitoring, and research. Traditionally, plant identification relied on manual methods, which can be slow, labor-intensive, and prone to errors. AI-driven models, particularly in computer vision, have made a significant impact by automating this process, improving both speed and accuracy.
Convolutional Neural Networks (CNNs) have been widely used for plant classification tasks, but they can struggle to capture long-range patterns in complex, large-scale datasets. Vision Transformers (ViTs) address some of these challenges by modeling global relationships across an image, yet they can be computationally intensive at high resolutions.
The Swin Transformer overcomes these limitations by using a hierarchical structure and shifted window attention. This helps it to efficiently process high-resolution images while maintaining strong performance. By fine-tuning a Swin Transformer, we can train it to recognize distinct plant features, making it ideal for tasks such as species identification and disease detection.
This model is particularly beneficial in fields like agriculture and agritech, where accurate, fast, and scalable plant classification is essential for improving crop yields, monitoring plant health, and enhancing environmental sustainability.
In this guide, we will walk you through how to implement the Swin Transformer for plant classification, explore its architecture, and discuss how it is helping industries that rely on plant data and analysis.
The Swin Transformer is a hierarchical vision transformer designed to improve image recognition through locality, scalability, and computational efficiency. The Swin Transformer variant used here incorporates several key features:
- A hierarchical structure that builds feature maps at multiple scales by progressively merging image patches, much like the feature pyramids of CNNs.
- Window-based self-attention, which computes attention within local, non-overlapping windows, keeping complexity linear in image size.
- Shifted windows in alternating blocks, which move the window grid so that information can flow across window boundaries.
The Swin Transformer’s efficient image recognition capabilities make it highly effective for a wide range of applications. In industries like agritech and agriculture, it plays a crucial role in plant detection and plant disease detection.
Its hierarchical structure and efficient handling of spatial relationships enable it to accurately classify plant species and identify potential diseases, making it a valuable tool for improving agricultural practices. The Swin Transformer’s ability to capture both fine-grained and large-scale features also enhances its use in automated systems for crop monitoring, precision farming, and environmental analysis.
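To make the shifted-window idea concrete, here is a minimal, self-contained sketch of how a feature map is split into attention windows and cyclically shifted. This is an illustration only, not the model's actual implementation; the timm model we use later handles all of this internally.
import torch

def window_partition(x, window_size):
    """Split a (B, H, W, C) feature map into non-overlapping windows."""
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    # -> (num_windows * B, window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C)

x = torch.randn(1, 56, 56, 96)      # stage-1 feature map for a 224x224 input
windows = window_partition(x, 7)    # self-attention runs inside each 7x7 window
print(windows.shape)                # torch.Size([64, 7, 7, 96])

# In alternating blocks the map is cyclically shifted by half a window before
# partitioning, so neighboring windows exchange information across boundaries.
shifted = torch.roll(x, shifts=(-3, -3), dims=(1, 2))
print(window_partition(shifted, 7).shape)  # torch.Size([64, 7, 7, 96])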
We structured the training process into four key stages: dataset preparation, data preprocessing, model training and fine-tuning, and evaluation. Each stage plays a crucial role in building an accurate plant classification model.
Before training any model, the first step is choosing the right dataset. For this project, we used PlantNet-300K, a large-scale plant image dataset designed for classification tasks. It contains 306,146 images covering 1,081 plant species, with a long-tailed class distribution that mirrors real-world observations, making it ideal for training AI models to recognize different plants accurately.
Before we begin training our Swin Transformer model for plant classification, the first step is to get the data ready. This involves downloading and extracting a large dataset of plant images. We’ll be using the PlantNet 300K dataset, which contains a wealth of plant species data, ideal for training a robust classification model.
We need to install a few dependencies to make sure our setup works smoothly. First, run the following command to install necessary libraries:
!pip install wget timm torch torchvision
Once the dependencies are installed, the next step is to download the dataset. This might take a little time. To download the dataset, use the following wget command:
!wget https://zenodo.org/records/5645731/files/plantnet_300K.zip
Now that we have the dataset, we need to extract its contents. The PlantNet 300K dataset is pretty large, so this part might take 4–5 hours. Be patient as the dataset unzips:
!unzip -n plantnet_300K.zip -d plantnet_300K
Now that we have the dataset, it’s time to start setting things up for the training process. We will go through a few steps to import the necessary libraries, set up file paths, and organize the dataset into species-specific folders.
The first step is to import all the libraries we need to make things work. Here’s what we’ll be using:
import torch
import torch.nn as nn
import torch.optim as optim
import timm
import json
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from tqdm import tqdm
import os
from torchvision import datasets
import shutil
import string
Next, we’ll define the paths where the dataset is stored and where we want to organize the images. Make sure to replace any directory paths as necessary based on where your dataset is saved.
DATASET_DIR = "plantnet_300K/plantnet_300K"
IMAGES_DIR = os.path.join(DATASET_DIR, "images/test")
TRAIN_DIR = os.path.join(DATASET_DIR, "train")
VAL_DIR = os.path.join(DATASET_DIR, "val")
SPECIES_MAP_FILE = os.path.join(DATASET_DIR, "plantnet300K_species_id_2_name.json")
We’ll now load the mapping from species IDs to species names. This is crucial to ensure each plant species is categorized correctly. The file plantnet300K_species_id_2_name.json contains this mapping in JSON format.
with open(SPECIES_MAP_FILE, "r") as f:
    species_map = json.load(f)  # dictionary: species_id -> species name
This loads the mapping into the species_map dictionary. Each plant ID will be mapped to its corresponding species name.
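As a quick, optional sanity check, you can peek at a few entries to confirm the mapping loaded correctly:
# Inspect a few (species_id -> species name) entries
print(f"Total species in mapping: {len(species_map)}")
for plant_id, name in list(species_map.items())[:3]:
    print(plant_id, "->", name)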
We’ll create a function that organizes the images into species-specific folders. The images will be moved from their current directory into a new folder named after the species. If a folder for a species doesn’t exist yet, it will be created automatically.
def organize_images(base_dir):
    for plant_id in tqdm(os.listdir(base_dir),
                         desc=f"Organizing {os.path.basename(base_dir)}"):
        plant_path = os.path.join(base_dir, plant_id)
        # Skip if not a directory
        if not os.path.isdir(plant_path):
            continue
        species_name = species_map.get(plant_id, "Unknown_Species").replace(" ", "_")
        # Create species folder inside base_dir
        species_folder = os.path.join(base_dir, species_name)
        os.makedirs(species_folder, exist_ok=True)
        # Move images into species folder
        for img in os.listdir(plant_path):
            shutil.move(os.path.join(plant_path, img), os.path.join(species_folder, img))
        # Remove the now-empty plant_id folder
        os.rmdir(plant_path)
# Organize images in train and val directories
organize_images(TRAIN_DIR)
organize_images(VAL_DIR)
print("✅ Images successfully organized into species-specific folders!")
Once the images are organized, let’s check how many species are in the train and validation directories. This will help us confirm that everything has been organized properly.
train_species = os.listdir("plantnet_300K/plantnet_300K/train")
val_species = os.listdir("plantnet_300K/plantnet_300K/val")
print(f"📂 Train Folders: {len(train_species)} species")
print(f"📂 Validation Folders: {len(val_species)} species")
Now that the dataset is organized into species-specific folders, we can move forward with data preprocessing. We'll make sure the dataset is clean and ready for training by removing empty species folders, keeping the train and validation class sets consistent, defining image transformations, and filtering out invalid files.
If you haven’t already set the directories for your train and validation datasets, you can define them here:
TRAIN_DIR = "plantnet_300K/plantnet_300K/train"
VAL_DIR = "plantnet_300K/plantnet_300K/val"
To ease the training process, we need to ensure there are no empty species folders in both the training and validation directories. The following function checks for empty folders and returns them along with their count:
def identify_empty_folders(root_dir):
    empty_folders = []
    for species_folder in os.listdir(root_dir):
        species_folder_path = os.path.join(root_dir, species_folder)
        # Only proceed if it's a directory
        if os.path.isdir(species_folder_path):
            # Check if the folder is empty
            if not any(os.scandir(species_folder_path)):
                empty_folders.append(species_folder_path)
    # Return the list of empty folders and their count
    return empty_folders, len(empty_folders)
# Identify empty folders in both training and validation directories
empty_train_folders, empty_train_count = identify_empty_folders(TRAIN_DIR)
empty_val_folders, empty_val_count = identify_empty_folders(VAL_DIR)
# Print the results
if empty_train_folders:
    print("Empty folders in train directory:")
    for folder in empty_train_folders:
        print(folder)
    print(f"Total empty folders in train: {empty_train_count}")
else:
    print("No empty folders found in train.")

if empty_val_folders:
    print("Empty folders in val directory:")
    for folder in empty_val_folders:
        print(folder)
    print(f"Total empty folders in val: {empty_val_count}")
else:
    print("No empty folders found in val.")
We found 321 empty folders in our train directory. A class with no training images cannot be learned, and ImageFolder will complain about classes with no valid files, so we first delete the empty train folders and then remove any folders from the val directory whose class no longer exists in train, keeping the two class sets consistent.
# Delete the empty species folders from train (they are empty, so rmdir suffices)
for folder in empty_train_folders:
    os.rmdir(folder)

# Get the remaining class folders in train and val
train_folders = set(os.listdir(TRAIN_DIR))
val_folders = set(os.listdir(VAL_DIR))

# Remove val folders whose class is not present in train
folders_to_remove = val_folders - train_folders
for folder in folders_to_remove:
    folder_path = os.path.join(VAL_DIR, folder)
    shutil.rmtree(folder_path)
    print(f"Removed: {folder_path}")
print("Completed folder removal.")
After removing the empty folders, let’s count how many species folders remain in both the train and validation directories to confirm everything is correct:
def count_remaining_folders(directory):
    # Count the number of directories (species folders)
    return len([folder for folder in os.listdir(directory) if os.path.isdir(os.path.join(directory, folder))])
# Count remaining folders in train and val directories
train_remaining_count = count_remaining_folders(TRAIN_DIR)
val_remaining_count = count_remaining_folders(VAL_DIR)
# Display the count
print(f"Remaining folders in TRAIN directory: {train_remaining_count}")
print(f"Remaining folders in VAL directory: {val_remaining_count}")
Now, let’s set up image transformations for preprocessing. This will help standardize the images before feeding them into the model. The preprocessing steps include resizing the images, converting them to tensors, and normalizing the pixel values.
# Image Preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
Each step serves a purpose:
- transforms.Resize((224, 224)): resizes every image to the 224×224 input size expected by the Swin model.
- transforms.ToTensor(): converts the PIL image to a tensor and scales pixel values to [0, 1].
- transforms.Normalize(mean, std): normalizes pixel values with the given mean and standard deviation (these are the standard ImageNet statistics used by pre-trained models like ResNet and Swin).
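As an optional variant (not used in the training run below), you could add light augmentation to the training transform only, which often improves generalization; the validation transform should stay deterministic:
# Hypothetical training-only augmentation; keep the deterministic `transform`
# above for validation and test.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])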
Sometimes filenames contain special characters that can cause issues, so we sanitize them to make sure they're valid and safe for processing. We also define a function that checks whether a file is a valid image, so that only image files are processed.
def sanitize_filename(filename):
    # Strip problematic punctuation characters
    # (note: string.punctuation includes '_', so this is used only for counting/reporting,
    # not for renaming the actual folders)
    return filename.translate(str.maketrans('', '', string.punctuation))

# Function to check if a file is a valid image
def is_valid_file(x):
    valid_extensions = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp']
    return any(x.lower().endswith(ext) for ext in valid_extensions)

# Collect sanitized class (folder) names, keeping only non-empty folders
def load_dataset_with_sanitization(root_dir):
    sanitized_classes = []
    for class_name in os.listdir(root_dir):
        class_path = os.path.join(root_dir, class_name)
        if os.path.isdir(class_path) and any(is_valid_file(f) for f in os.listdir(class_path)):
            sanitized_classes.append(sanitize_filename(class_name))
    return sanitized_classes
# Collect sanitized, non-empty class lists for train and validation
train_classes = load_dataset_with_sanitization(TRAIN_DIR)
val_classes = load_dataset_with_sanitization(VAL_DIR)
# Print the number of classes (non-empty and sanitized)
print(f"Number of classes in TRAIN_DIR: {len(train_classes)}")
print(f"Number of classes in VAL_DIR: {len(val_classes)}")
At this stage, we've successfully prepared the dataset for training: images are organized into species folders, empty classes are removed, the train and val class sets match, and the preprocessing transforms are defined.
You’re now ready to move on to model training!
This section will walk through the process of setting up the training environment, defining key hyperparameters, initializing the model, and conducting the training and fine-tuning process for the Swin Transformer model.
Before starting the training process, it is essential to define the key hyperparameters. These values determine how the model will be trained and help in controlling the model’s performance. Here are the hyperparameters we’ll use:
# Hyperparameters
BATCH_SIZE = 32
EPOCHS = 20
LEARNING_RATE = 1e-4
VAL_FREQ = 5  # defined for reference; the loop below validates every epoch

criterion = nn.CrossEntropyLoss()
We use cross-entropy loss for this multi-class problem. The Adam optimizer is created right after the model is initialized below, since it needs the model's parameters.
The training and validation datasets are loaded with the transformations applied earlier. This ensures that images are resized, normalized, and converted to tensors.
Next, we create DataLoader instances for both the training and validation datasets. These will handle batching and shuffling for the training process.
train_dataset = datasets.ImageFolder(root=TRAIN_DIR, transform=transform, is_valid_file=is_valid_file)
val_dataset = datasets.ImageFolder(root=VAL_DIR, transform=transform, is_valid_file=is_valid_file)
# DataLoader
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
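Before training, it can help to pull a single batch and confirm that the shapes match what the model expects (a quick, optional check):
# Fetch one batch to verify shapes
images, labels = next(iter(train_loader))
print(images.shape)  # expected: torch.Size([32, 3, 224, 224])
print(labels.shape)  # expected: torch.Size([32])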
It’s important to verify that the labels in the training dataset are valid and in the correct range.
labels = train_dataset.targets
num_classes = len(train_dataset.classes)
min_label = min(labels)
max_label = max(labels)

if min_label < 0 or max_label >= num_classes:
    print(f"Invalid labels! Labels should be in range [0, {num_classes - 1}]")
else:
    print(f"Labels are valid. Range: [0, {num_classes - 1}]")
We use the timm library to initialize the Swin Transformer model, pre-trained on ImageNet, and adapt it to our specific dataset by setting the number of output classes.
model = timm.create_model('swin_large_patch4_window7_224', pretrained=True, num_classes=len(train_dataset.classes))

# Move model to device (GPU if available)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Create the optimizer now that the model's parameters exist
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
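Since swin_large is a big model, a quick optional check of the parameter count helps gauge memory requirements before training:
# Report total and trainable parameter counts
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params / 1e6:.1f}M, trainable: {trainable_params / 1e6:.1f}M")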
Next, we set up variables to track the best model based on validation accuracy and implement early stopping to avoid overfitting.
best_val_accuracy = 0.0 # Track the best validation accuracy
best_model_path = "best_model.pth" # Path to save the best model
patience = 5 # Number of epochs to wait for improvement
epochs_without_improvement = 0 # Counter for early stopping
The training loop runs for the specified number of epochs. For each epoch, the model is trained and evaluated on the validation set. The best model is saved whenever there is an improvement in validation accuracy.
for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch + 1}/{EPOCHS}")
    model.train()
    train_loss, train_accuracy = 0.0, 0.0

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        if labels.min() < 0 or labels.max() >= num_classes:
            print(f"Error in batch! Min: {labels.min()}, Max: {labels.max()}")
            break
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        train_accuracy += (predicted == labels).sum().item()

    train_loss /= len(train_loader)
    train_accuracy /= len(train_loader.dataset)
    print(f"Training loss: {train_loss:.4f}, Training accuracy: {train_accuracy:.4f}")

    model.eval()
    val_loss, val_accuracy = 0.0, 0.0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            val_accuracy += (predicted == labels).sum().item()

    val_loss /= len(val_loader)
    val_accuracy /= len(val_loader.dataset)
    print(f"Validation loss: {val_loss:.4f}, Validation accuracy: {val_accuracy:.4f}")

    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        torch.save(model.state_dict(), best_model_path)
        print(f"Best model updated with accuracy: {best_val_accuracy:.4f}")
        epochs_without_improvement = 0  # Reset counter if validation accuracy improves
    else:
        epochs_without_improvement += 1

    # Early stopping check
    if epochs_without_improvement >= patience:
        print(f"Early stopping after {epoch + 1} epochs without improvement.")
        break

print("Training complete! Best model saved.")
After training, we observed:
Training Accuracy: 0.97
Validation Accuracy: 0.74
The gap between training and validation accuracy points to overfitting, which is exactly why we checkpoint the best model and stop early rather than always running all 20 epochs.
For future use, save the mapping between class names and class indices for the training dataset. This is useful for consistency when the model is deployed or further processed.
# Path to save the mapping
class_to_idx_path = "class_to_idx.json"
# Save train_dataset.class_to_idx as JSON
with open(class_to_idx_path, "w") as f:
    json.dump(train_dataset.class_to_idx, f)

print(f"✅ Saved class_to_idx mapping to {class_to_idx_path}")
In this section, we will focus on evaluating the model’s performance on a separate test dataset, and we will also test the model’s prediction for an individual image.
We need several libraries to handle image processing, dataset loading, model inference, and evaluation metrics:
import os
import json
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from sklearn.metrics import precision_score, recall_score, f1_score
import timm
import matplotlib.pyplot as plt
from PIL import Image
Here, we define important paths and constants, including the directory for the test images, the path for the saved model, and the class-to-index mapping file. We also define a label threshold to filter labels that are less than or equal to a specified value.
TEST_DIR = "plantnet_300K/plantnet_300K/images/test"  # Update with the correct path
MODEL_PATH = "best_model.pth"
CLASS_TO_IDX_PATH = "class_to_idx.json"
BATCH_SIZE = 32
LABEL_THRESHOLD = 758 # Keep labels <= 758
Next, we load the class-to-index mapping from the previously saved JSON file. Then, we define the image transformation steps, which resize, normalize, and convert images into tensor format.
with open(CLASS_TO_IDX_PATH, "r") as f:
class_to_idx = json.load(f)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
Further, we load the test dataset and filter out images with labels exceeding the specified label threshold. Then, we create a custom FilteredDataset class to hold the filtered data.
full_test_dataset = datasets.ImageFolder(root=TEST_DIR, transform=transform)
full_test_dataset.class_to_idx = class_to_idx

# Keep only samples whose label is within the trained class range
# (note: this iterates the full dataset once, so it can take a while)
filtered_data = [(img, label) for img, label in full_test_dataset if label <= LABEL_THRESHOLD]

class FilteredDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

filtered_test_dataset = FilteredDataset(filtered_data)
filtered_test_loader = DataLoader(filtered_test_dataset, batch_size=BATCH_SIZE, shuffle=False)
Next, we load the trained weights into the model and set it to evaluation mode for testing.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = timm.create_model("swin_large_patch4_window7_224", pretrained=False, num_classes=len(class_to_idx))
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model.to(device)
model.eval()
criterion = torch.nn.CrossEntropyLoss()
We now evaluate the model on the test dataset. The model’s performance is calculated using test loss, accuracy, and additional metrics like precision, recall, and F1 score.
test_loss, test_accuracy = 0.0, 0.0
all_preds, all_labels = [], []

with torch.no_grad():
    for inputs, labels in filtered_test_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        test_loss += loss.item()

        _, predicted = torch.max(outputs, 1)
        test_accuracy += (predicted == labels).sum().item()

        # Store predictions and true labels
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
The test loss and accuracy are averaged over the entire dataset, while precision, recall, and F1 score are computed using the predicted and actual labels, providing a comprehensive evaluation of the model’s performance.
test_loss /= len(filtered_test_loader)
test_accuracy /= len(filtered_test_loader.dataset)
precision = precision_score(all_labels, all_preds, average="weighted", zero_division=0)
recall = recall_score(all_labels, all_preds, average="weighted", zero_division=0)
f1 = f1_score(all_labels, all_preds, average="weighted", zero_division=0)
print(f" Test Results:")
print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}") print(f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}")
Next, we will test the model’s prediction on a single image. We load the image, apply the same transformations, and pass it through the model to get the predicted label.
# Load image
image_path = "money_plant.jpeg" # path to image
image = Image.open(image_path).convert("RGB")
image = transform(image).unsqueeze(0).to(device)
Next, we perform inference by passing the image through the model and this prints the predicted label and species name.
with torch.no_grad():
    outputs = model(image)
    _, predicted = torch.max(outputs, 1)
    predicted_label = predicted.item()

# Map the predicted class index back to the species (folder) name.
# Note: the model outputs an index into class_to_idx, not a PlantNet species ID,
# so we invert class_to_idx rather than looking up species_map directly.
idx_to_class = {idx: name for name, idx in class_to_idx.items()}
species_name = idx_to_class.get(predicted_label, "Unknown_Species")

# Print result
print(f"Predicted Label: {predicted_label}, Species: {species_name}")
Finally, we visualize the input image and display the predicted label on it.
plt.imshow(Image.open(image_path))
plt.title(f"Predicted: {species_name}")
plt.axis('off')
plt.show()
After completing the testing phase, we see the following evaluation results:
Test Results:
Test Loss: 0.4564, Test Accuracy: 0.8767
Precision: 0.8825, Recall: 0.8767, F1 Score: 0.8646
We also tested the model on the sample image above (money_plant.jpeg), with the predicted species displayed as the plot title.
Training a model like the Swin Transformer requires technical expertise and careful setup. At Superteams.ai, we help businesses like yours ease the process of training models for real-world applications. Whether you’re focused on agriculture or other industries, we provide the support and resources needed to get the best results from your AI projects.
This blog was first published on Medium by Prachi Desai. It is meant to be a project that focuses on the environmental use cases of AI. You can see the link here: https://medium.com/@bishtprachi2003/fine-tuned-swin-transformer-for-plant-classification-25b674478f98