https://colab.research.google.com/drive/1wwzoc1OhHEMLnC0iHTqjBR1XJjHCzzkR
This notebook will guide you through the complete process of building, training, and evaluating an image recognition model (a Convolutional Neural Network, or CNN) to identify plant diseases from images.
We will use TensorFlow and Keras to:
Load Data: Download a sample dataset of potato plant leaves (Healthy, Early Blight, Late Blight).
Preprocess: Set up a data pipeline to load, augment, and batch images.
Build Model: Define a CNN architecture.
Train Model: Train the model on the data and validate its performance.
Evaluate: Visualize the model's accuracy/loss and test it on new images.
For this to run quickly, make sure you have a GPU enabled.
Go to Runtime in the menu bar.
Select Change runtime type.
Choose GPU from the "Hardware accelerator" dropdown.
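Once the runtime has restarted, a quick check like the one below can confirm that TensorFlow actually sees the GPU. This is a minimal sketch: the helper name gpu_status is made up for illustration, and the import guard is only there so the snippet also runs where TensorFlow is not installed.

```python
import importlib.util


def gpu_status():
    """Return a short message describing whether TensorFlow can see a GPU.

    `gpu_status` is a hypothetical helper name, not part of TensorFlow.
    """
    # Guard the import so the check degrades gracefully outside Colab.
    if importlib.util.find_spec("tensorflow") is None:
        return "TensorFlow is not installed"
    import tensorflow as tf

    gpus = tf.config.list_physical_devices("GPU")
    if gpus:
        return f"{len(gpus)} GPU(s) visible to TensorFlow"
    return "No GPU visible; training will fall back to the CPU"


print(gpu_status())
```

If the message reports no GPU, repeat the runtime steps above before training.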
(TEXT CELL 2)
First, we'll import all the necessary libraries. This includes tensorflow and keras for building the model, matplotlib for plotting, and utilities for downloading and managing files.
(CODE CELL 1)
Python
# ==============================================================================
# 1. SETUP AND IMPORTS
# ==============================================================================
# Import necessary libraries for deep learning, data handling, and visualization
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
import matplotlib.pyplot as plt
import numpy as np
import os
import zipfile
import requests
from io import BytesIO
print(f"TensorFlow Version: {tf.__version__}")
(TEXT CELL 3)
We'll download a sample dataset of potato plant diseases. This dataset is conveniently packaged in a .zip file and contains three folders, one for each class:
Potato___Early_blight
Potato___Late_blight
Potato___healthy
Keras utilities can read images directly from folders like this, using the folder names as the labels.
Note: If you have your own images for three classes, you can skip the download and upload your data instead. Just make sure it follows this directory structure:
/your_data_folder
    /class_1
        img1.jpg
        ...
    /class_2
        imgA.jpg
        ...
    /class_3
        imgX.jpg
        ...
You would then change the base_dir variable in the next step to point to /your_data_folder.
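Before pointing the notebook at your own folder, it can help to confirm that the layout is what Keras expects. The sketch below counts image files per class subfolder using only the standard library. The function name count_images_per_class, the extension list, and the demo folder are assumptions made for illustration; the check only looks one level deep.

```python
import os
import tempfile

# Extensions treated as images in this sketch (an assumption, adjust as needed).
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".gif")


def count_images_per_class(data_dir):
    """Map each class subfolder of data_dir to the number of image files in it."""
    counts = {}
    for entry in sorted(os.listdir(data_dir)):
        class_dir = os.path.join(data_dir, entry)
        if not os.path.isdir(class_dir):
            continue  # Ignore stray files next to the class folders
        counts[entry] = sum(
            1 for name in os.listdir(class_dir)
            if name.lower().endswith(IMAGE_EXTENSIONS)
        )
    return counts


# Demo on a throwaway folder with hypothetical class names and empty files.
with tempfile.TemporaryDirectory() as demo_dir:
    for class_name, n_images in [("class_1", 3), ("class_2", 2), ("class_3", 1)]:
        os.makedirs(os.path.join(demo_dir, class_name))
        for i in range(n_images):
            open(os.path.join(demo_dir, class_name, f"img{i}.jpg"), "w").close()
    demo_counts = count_images_per_class(demo_dir)

print(demo_counts)
```

A class that shows up with zero images, or a class that is missing altogether, usually means the folders are nested one level too deep or too shallow.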
(CODE CELL 2)
Python
# ==============================================================================
# 2. DOWNLOAD AND PREPARE THE DATASET
# ==============================================================================
# URL of the dataset zip file
dataset_url = "https://storage.googleapis.com/plantdata/PlantVillagePotato.zip"
print("Downloading sample dataset...")
# Download and extract the data
response = requests.get(dataset_url, timeout=60)
response.raise_for_status()  # Stop here if the download failed (e.g., a 404)
with zipfile.ZipFile(BytesIO(response.content)) as z:
    z.extractall("plant_diseases")
print("Dataset downloaded and extracted successfully.")
# Define the base directory where the images are stored
base_dir = "plant_diseases/PlantVillage"
print(f"Dataset path: {base_dir}")
print("Classes found:", os.listdir(base_dir))
(TEXT CELL 4)
Now we'll use a Keras utility, image_dataset_from_directory, to load our data. It returns a tf.data.Dataset that reads images from disk in batches and labels each one by its folder name.
We'll define a few key parameters:
BATCH_SIZE: How many images to process at one time.
IMG_SIZE: The resolution to which all images will be resized (e.g., 256x256 pixels).
VALIDATION_SPLIT: We'll hold back 20% (0.2) of the data to test the model's performance on images it hasn't "seen" during training.
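We also call .cache() and .prefetch() on the datasets. Caching keeps the decoded images in memory after the first epoch, and prefetching prepares the next batch while the current one is being trained on, so the GPU spends less time waiting for data.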
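To see how these parameters interact, the arithmetic can be worked through by hand. The sketch below assumes the validation share is rounded down and that the final, smaller batch is kept rather than dropped. The function name split_sizes and the dataset size of 1000 are made up for illustration.

```python
import math


def split_sizes(num_images, validation_split, batch_size):
    """Return image and batch counts for the training and validation subsets."""
    num_val = int(validation_split * num_images)  # Validation share, rounded down
    num_train = num_images - num_val
    return {
        "train_images": num_train,
        "val_images": num_val,
        # The last batch may be smaller than batch_size, hence the ceiling.
        "train_batches": math.ceil(num_train / batch_size),
        "val_batches": math.ceil(num_val / batch_size),
    }


# 1000 is a made-up dataset size, not the size of the potato dataset.
sizes = split_sizes(num_images=1000, validation_split=0.2, batch_size=32)
print(sizes)
```

The number of training batches is what Keras reports as steps per epoch in the training log.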
(CODE CELL 3)
Python
# ==============================================================================
# 3. DEFINE PARAMETERS AND LOAD DATA
# ==============================================================================
# Set key parameters for the model and data loaders
BATCH_SIZE = 32
IMG_SIZE = (256, 256) # All images will be resized to this size
VALIDATION_SPLIT = 0.2 # Use 20% of the data for validation
# Load the training dataset from the directory, splitting it for validation
print("\nLoading training and validation datasets...")
train_dataset = tf.keras.utils.image_dataset_from_directory(
    base_dir,
    validation_split=VALIDATION_SPLIT,
    subset="training",
    seed=123, # Using a seed ensures the split is the same every time
    image_size=IMG_SIZE,
    batch_size=BATCH_SIZE
)
# Load the validation dataset
validation_dataset = tf.keras.utils.image_dataset_from_directory(
    base_dir,
    validation_split=VALIDATION_SPLIT,
    subset="validation",
    seed=123,
    image_size=IMG_SIZE,
    batch_size=BATCH_SIZE
)
# Get the class names from the directory structure
class_names = train_dataset.class_names
num_classes = len(class_names)
print(f"Found {num_classes} classes: {class_names}")
# Configure the dataset for performance by caching and prefetching
AUTOTUNE = tf.data.AUTOTUNE
train_dataset = train_dataset.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
validation_dataset = validation_dataset.cache().prefetch(buffer_size=AUTOTUNE)
(TEXT CELL 5)
It's time to build our model. We will use a Sequential model, which is a simple stack of layers.
We'll add data_augmentation layers first. These randomly flip, rotate, and zoom images during training. This helps prevent overfitting by teaching the model to recognize a diseased leaf even if it's mirrored, tilted, or shown at a slightly different scale.
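To make "flip" concrete, here is what a horizontal flip does to the pixel array itself. This is a NumPy illustration only, not how Keras implements RandomFlip, which applies the flip at random during training. The tiny image and its values are arbitrary.

```python
import numpy as np

# A tiny 2x3 "image" with one channel; values are arbitrary.
image = np.array([[[1], [2], [3]],
                  [[4], [5], [6]]])

# Reversing the width axis mirrors the image left to right.
flipped = image[:, ::-1, :]

print(flipped[..., 0])
```

The label stays the same after the flip, which is why augmentation adds variety without needing any new labeled data.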
The core of our model consists of:
Rescaling: Normalizes pixel values from [0, 255] to [0, 1], which is better for the network.
Conv2D: These are convolutional layers that find patterns (like edges, spots, or textures).
MaxPooling2D: These layers downsample the image, keeping the most important features and reducing computation.
Dropout: Randomly "turns off" some neurons during training to further prevent overfitting.
Flatten: Converts the 3D stack of feature maps (height x width x channels) into a 1D vector.
Dense: A fully connected layer that combines the extracted features to decide on a class.
Output Layer: The final Dense layer has num_classes (3, in our case) nodes and a softmax activation, which outputs a probability for each class (e.g., 80% Late Blight, 15% Early Blight, 5% Healthy).
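The effect of these layers on tensor shapes and parameter counts can be worked out by hand. The sketch below assumes the architecture used in this notebook: 256x256 RGB input, three 3x3 convolutional blocks with 32, 64, and 128 filters and 'same' padding, default 2x2 max pooling, a 256-unit Dense layer, and 3 output classes. The helper names conv_params and dense_params are made up for illustration.

```python
def conv_params(kernel, in_channels, filters):
    """Weights plus one bias per filter for a Conv2D layer."""
    return (kernel * kernel * in_channels + 1) * filters


def dense_params(inputs, units):
    """Weights plus one bias per unit for a Dense layer."""
    return (inputs + 1) * units


size, channels = 256, 3           # Input: 256x256 RGB
total = 0
for filters in (32, 64, 128):     # The three convolutional blocks
    total += conv_params(3, channels, filters)
    channels = filters            # padding='same' keeps height and width
    size //= 2                    # 2x2 max pooling halves them

flat = size * size * channels     # Length of the vector produced by Flatten
total += dense_params(flat, 256)  # Hidden Dense layer
total += dense_params(256, 3)     # Output layer, assuming 3 classes

print(size, flat, total)
```

Nearly all of the parameters sit in the first Dense layer, because Flatten hands it a vector of 32 x 32 x 128 values. The augmentation, Rescaling, MaxPooling2D, Dropout, and Flatten layers add no trainable parameters.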
(CODE CELL 4)
Python
# ==============================================================================
# 4. BUILD THE CONVOLUTIONAL NEURAL NETWORK (CNN) MODEL
# ==============================================================================
print("\nBuilding the CNN model...")
# Define data augmentation layers to prevent overfitting
data_augmentation = keras.Sequential(
    [
        layers.RandomFlip("horizontal", input_shape=(IMG_SIZE[0], IMG_SIZE[1], 3)),
        layers.RandomRotation(0.1),
        layers.RandomZoom(0.1),
    ]
)
model = Sequential([
    # Start with data augmentation
    data_augmentation,
    # Rescale pixel values from [0, 255] to [0, 1]
    layers.Rescaling(1./255),
    # First convolutional block
    layers.Conv2D(32, (3, 3), padding='same', activation='relu'),
    layers.MaxPooling2D(),
    # Second convolutional block
    layers.Conv2D(64, (3, 3), padding='same', activation='relu'),
    layers.MaxPooling2D(),
    # Third convolutional block
    layers.Conv2D(128, (3, 3), padding='same', activation='relu'),
    layers.MaxPooling2D(),
    # Add a dropout layer to reduce overfitting
    layers.Dropout(0.2),
    # Flatten the results to feed into a dense layer
    layers.Flatten(),
    # Fully connected layer
    layers.Dense(256, activation='relu'),
    # Output layer with a node for each class
    layers.Dense(num_classes, activation='softmax')
])
(TEXT CELL 6)
Before we can train the model, we need to compile it. This step configures the training process.
optimizer: We use 'adam', which is an efficient and popular algorithm for adjusting the model's internal parameters.
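loss: This is the function we want to minimize. SparseCategoricalCrossentropy is the standard choice for multi-class classification when labels are integer class indices (0, 1, 2), which is what image_dataset_from_directory produces by default.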
metrics: We tell the model to report its 'accuracy' at each step.
We'll also print a model.summary() to see the architecture and the number of parameters.
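For a single image, the sparse categorical cross-entropy loss comes down to the negative log of the probability the model assigned to the true class. The sketch below works this out by hand. It is a simplified illustration rather than the Keras implementation, and the probabilities are hypothetical.

```python
import math


def sparse_categorical_crossentropy(probs, true_index):
    """Loss for one example: negative log of the probability of the true class."""
    return -math.log(probs[true_index])


probs = [0.80, 0.15, 0.05]  # Hypothetical softmax output for one image

# Low loss when the true class is the one the model favors...
loss_if_right = sparse_categorical_crossentropy(probs, true_index=0)
# ...and high loss when the true class received only 5%.
loss_if_wrong = sparse_categorical_crossentropy(probs, true_index=2)

print(f"{loss_if_right:.3f} {loss_if_wrong:.3f}")
```

The loss reported during training is the average of this value over the images in each batch.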
(CODE CELL 5)
Python
# ==============================================================================
# 5. COMPILE THE MODEL
# ==============================================================================
# Compiling configures the model for training.
print("Compiling the model...")
model.compile(
    optimizer='adam', # Adam is a popular and effective optimizer
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy'] # We want to monitor the accuracy
)
# Print a summary of the model architecture
model.summary()
(TEXT CELL 7)
This is the main event! We'll call model.fit() to start training.
train_dataset: The data to train on.
validation_data: The data to test against (the 20% we held back).
epochs: An "epoch" is one full pass over the entire dataset. We'll do 15 passes.
Watch the output: accuracy should go up, and loss should go down. We also want to see val_accuracy (validation accuracy) increase, which shows the model is generalizing well.
(CODE CELL 6)
Python
# ==============================================================================
# 6. TRAIN THE MODEL
# ==============================================================================
print("\nStarting model training...")
EPOCHS = 15 # The number of times the model will see the entire dataset
# The `fit` method trains the model
# This will take a few minutes on a GPU and considerably longer on a CPU
history = model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=EPOCHS
)
print("Model training finished.")
(TEXT CELL 8)
Training is done! But how well did it do?
A "good" training run is one where the validation accuracy (the orange line, with matplotlib's default colors) tracks closely with the training accuracy (the blue line), and both are high.
If the validation accuracy flatlines or goes down while training accuracy keeps rising, it's a sign of overfitting. Our plots should look pretty good, thanks to data augmentation and dropout.
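The same check can be done numerically rather than by eye. The sketch below finds the epoch with the lowest validation loss and the final gap between training and validation accuracy. The function name summarize_history and the example numbers are made up for illustration; in the notebook you would pass history.history instead.

```python
def summarize_history(history_dict):
    """Report the best validation epoch and the train/validation accuracy gap."""
    val_loss = history_dict["val_loss"]
    best_epoch = min(range(len(val_loss)), key=val_loss.__getitem__)
    final_gap = history_dict["accuracy"][-1] - history_dict["val_accuracy"][-1]
    return {
        "best_epoch": best_epoch + 1,  # 1-based, like the Keras training log
        "epochs_since_best": len(val_loss) - 1 - best_epoch,
        "final_accuracy_gap": final_gap,
    }


# Made-up numbers for illustration, shaped like history.history from model.fit().
example_history = {
    "accuracy":     [0.60, 0.80, 0.90, 0.95, 0.98],
    "val_accuracy": [0.62, 0.78, 0.85, 0.84, 0.83],
    "loss":         [0.90, 0.50, 0.30, 0.15, 0.08],
    "val_loss":     [0.88, 0.55, 0.40, 0.45, 0.52],
}
summary = summarize_history(example_history)
print(summary)
```

In this made-up run the validation loss bottoms out at epoch 3 while training accuracy keeps climbing, which is the overfitting pattern described above.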
(CODE CELL 7)
Python
# ==============================================================================
# 7. VISUALIZE TRAINING RESULTS
# ==============================================================================
# Plotting the training history helps us see how the model learned over time.
print("\nVisualizing training results...")
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs_range = range(EPOCHS)
plt.figure(figsize=(12, 5))
# Plot training and validation accuracy
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
# Plot training and validation loss
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.suptitle("Model Training Performance", fontsize=16)
plt.show()
(TEXT CELL 9)
Finally, let's test our trained model on an image from the validation set.
The code will:
Grab a batch of images.
Use model.predict_on_batch to get predictions.
Display the first image from the batch.
Print the model's prediction (and confidence) alongside the true, correct label.
The title will be green if the prediction is correct and red if it's wrong.
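A single image says little about overall quality, so it can be useful to score the whole batch as well. The sketch below computes the fraction of correct predictions in a batch with NumPy. The function name batch_accuracy and the example arrays are hypothetical.

```python
import numpy as np


def batch_accuracy(predictions, labels):
    """Fraction of rows whose highest-probability class matches the label."""
    predicted = np.argmax(predictions, axis=1)
    return float(np.mean(predicted == np.asarray(labels)))


# Hypothetical softmax outputs for four images and three classes.
example_predictions = np.array([
    [0.80, 0.15, 0.05],
    [0.10, 0.70, 0.20],
    [0.30, 0.30, 0.40],
    [0.55, 0.40, 0.05],
])
example_labels = np.array([0, 1, 2, 1])  # The last image is misclassified

print(batch_accuracy(example_predictions, example_labels))
```

In the notebook, the same function could be called with the model's predictions for a batch and that batch's labels.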
(CODE CELL 8)
Python
# ==============================================================================
# 8. MAKE A PREDICTION ON A VALIDATION IMAGE
# ==============================================================================
# Let's test the model on one of the validation images.
# Retrieve a batch of images from the validation set
image_batch, label_batch = next(validation_dataset.as_numpy_iterator())
predictions = model.predict_on_batch(image_batch)
# Display the first image from the batch, its true label, and the model's prediction
plt.figure(figsize=(10, 10))
ax = plt.subplot(1, 1, 1)
# Get the first image and its true label
first_image = image_batch[0]
true_label = class_names[label_batch[0]]
# Get the model's prediction
predicted_class_index = np.argmax(predictions[0])
predicted_class = class_names[predicted_class_index]
confidence = 100 * np.max(predictions[0])
# Display the image and labels
plt.imshow(first_image.astype("uint8"))
plt.title(f"Prediction: {predicted_class} ({confidence:.2f}%)\nTrue Label: {true_label}",
          color=("green" if predicted_class == true_label else "red"))
plt.axis("off")
plt.show()