classification_nn.py
- class classification_nn.MyDataset(data_dir, before_trans)
Bases: Dataset
A custom dataset class to handle image data for training and testing.
- Attributes:
imgs (list): A list of preprocessed image tensors.
labels (list): A list of corresponding labels as tensors.
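- Args:
data_dir (str): The root directory containing one subdirectory per label.
before_trans (callable): Preprocessing transformations applied to each image (e.g., weights.transforms()).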
- Methods:
__getitem__(idx): Returns the image tensor and corresponding label at the given index.
__len__(): Returns the total number of samples in the dataset.
- classification_nn.augment_data(weights: VGG16_Weights) → tuple
Creates a set of image transformations for data augmentation.
This function applies a series of random transformations to input images to enhance dataset variability. It starts with the default preprocessing transformations from the given VGG16 weights and adds random augmentations such as flipping, rotation, color jittering, and cropping.
- Args:
weights (VGG16_Weights): The pretrained VGG16 weights, which include default image preprocessing transformations.
- Returns:
- tuple:
before_trans (Callable): Default preprocessing transformations (e.g., resizing and normalization).
random_trans (torchvision.transforms.Compose): Composed transformations that include random augmentations for training.
- Transformations Applied:
RandomHorizontalFlip: Randomly flips the image horizontally.
RandomRotation: Rotates the image by a random angle up to ±20 degrees.
ColorJitter: Randomly adjusts brightness, contrast, saturation, and hue.
RandomResizedCrop: Randomly crops and resizes the image with a scale between 80% and 100%.
RandomVerticalFlip: Randomly flips the image vertically.
Normalize: Normalizes the image using ImageNet mean and standard deviation.
- Example:
>>> weights = VGG16_Weights.DEFAULT
>>> before_trans, random_trans = augment_data(weights)
>>> print(before_trans)
>>> print(random_trans)
- classification_nn.compile_model(fruit_model: Sequential) → tuple
Prepares and compiles a model for training.
This function sets up the model by defining the loss function and optimizer and by applying torch.compile for optimized execution. The model is moved to the appropriate device (CPU or GPU).
- Args:
fruit_model (nn.Sequential): The neural network model to be compiled.
- Returns:
- tuple:
fruit_model (torch.nn.Module): The compiled model optimized for faster execution.
loss_function (torch.nn.CrossEntropyLoss): The loss function used for classification tasks.
optimizer (torch.optim.Adam): The Adam optimizer for updating model weights.
- Example:
>>> model, loss_fn, optimizer = compile_model(fruit_model)
>>> print(model)
>>> print(loss_fn)
>>> print(optimizer)
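A minimal sketch of these steps. The Adam learning rate is left at PyTorch's default (1e-3), which is an assumption, and the use_compile flag is hypothetical: it only exists so the sketch can be exercised on setups where torch.compile is unavailable.

```python
import torch
from torch import nn


def compile_model(fruit_model: nn.Sequential, use_compile: bool = True) -> tuple:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    fruit_model = fruit_model.to(device)
    loss_function = nn.CrossEntropyLoss()
    # The optimizer is built on the uncompiled module's parameters; the compiled
    # wrapper shares the same parameter tensors, so updates apply to both.
    optimizer = torch.optim.Adam(fruit_model.parameters())  # default lr (assumed)
    if use_compile:
        fruit_model = torch.compile(fruit_model)  # compilation happens lazily on first call
    return fruit_model, loss_function, optimizer
```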
- classification_nn.create_extended_model(vgg_model: Sequential) → Sequential
Extends a pretrained VGG model to customize the classification head for a specific task.
This function modifies a VGG model to adapt it for a custom classification problem with 6 classes. It retains the original feature extraction layers and adds new fully connected layers to replace part of the classifier.
- Args:
vgg_model (nn.Sequential): A pretrained VGG model, typically loaded using torchvision.
- Returns:
nn.Sequential: The extended model with a custom classification head.
- Example:
>>> from torchvision.models import vgg16, VGG16_Weights
>>> weights = VGG16_Weights.DEFAULT
>>> vgg_model = vgg16(weights=weights)
>>> extended_model = create_extended_model(vgg_model)
>>> print(extended_model)
Sequential(
  (0): Sequential( ... )  # Feature extraction layers
  (1): AdaptiveAvgPool2d( ... )
  (2): Flatten(start_dim=1, end_dim=-1)
  (3): Sequential( ... )  # Part of VGG classifier
  (4): Linear(in_features=4096, out_features=500, bias=True)
  (5): ReLU()
  (6): Linear(in_features=500, out_features=6, bias=True)
)
- classification_nn.download_dataset_from_cloud(url: str, folder_name: str) → None
Downloads and extracts a dataset from a cloud storage URL.
- Args:
url (str): The URL of the file to download.
folder_name (str): The folder where the dataset will be extracted.
- Raises:
FileNotFoundError: If the downloaded file cannot be found.
zipfile.BadZipFile: If the ZIP file is invalid or corrupted.
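A standard-library sketch of the download-and-extract behavior described above. Saving the archive as dataset.zip inside folder_name and deleting it afterwards are assumptions. zipfile.ZipFile raises zipfile.BadZipFile on its own when the archive is invalid.

```python
import os
import urllib.request
import zipfile


def download_dataset_from_cloud(url: str, folder_name: str) -> None:
    os.makedirs(folder_name, exist_ok=True)
    archive_path = os.path.join(folder_name, "dataset.zip")  # hypothetical file name
    urllib.request.urlretrieve(url, archive_path)
    if not os.path.exists(archive_path):
        raise FileNotFoundError(f"Downloaded file not found: {archive_path}")
    # Raises zipfile.BadZipFile if the archive is invalid or corrupted.
    with zipfile.ZipFile(archive_path) as zf:
        zf.extractall(folder_name)
    os.remove(archive_path)  # keep only the extracted contents
```

Because urlretrieve also accepts file:// URLs, the function can be tried against a local archive before pointing it at cloud storage.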
- classification_nn.download_pretrained_model() → tuple
Downloads and initializes the VGG16 model with pretrained weights.
This function uses the default weights of the VGG16 model provided by PyTorch’s torchvision.models library. It initializes the VGG16 model and prints its architecture.
- Returns:
- tuple:
vgg_model (torchvision.models.VGG): The initialized VGG16 model with pretrained weights.
weights (torchvision.models.VGG16_Weights): The default weights used for the VGG16 model.
- Example:
>>> model, weights = download_pretrained_model()
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    ...
  )
)
- classification_nn.image_inference(model: Sequential, before_trans: callable, image: str) → int
Performs image classification inference using a trained model.
This function reads an input image, applies the necessary preprocessing transformations, and uses the trained model to predict the class of the image.
- Args:
model (torch.nn.Module): The trained neural network model for inference.
before_trans (callable): Preprocessing transformations (e.g., normalization, resizing) compatible with the model’s input requirements.
image (str): Path to the image file to be classified.
- Returns:
int: The predicted class index of the input image.
- Process:
The image is loaded using torchvision.io.read_image in RGB format.
The preprocessing transformations (before_trans) are applied to the image.
The image tensor is sent to the device (CPU/GPU) and passed through the model.
The class with the highest probability is extracted using torch.argmax.
- Notes:
The model is set to evaluation mode (model.eval()), which disables dropout and makes batch normalization layers use their running statistics.
The input image tensor is expanded to include a batch dimension using unsqueeze(0).
- Example:
>>> model = ...  # Load a trained model
>>> weights = VGG16_Weights.DEFAULT
>>> before_trans = weights.transforms()  # Preprocessing transformations
>>> predicted_class = image_inference(model, before_trans, "data/fruits/apple.jpg")
>>> print(f"Predicted class index: {predicted_class}")
- classification_nn.main()
- classification_nn.prepare_dataset(before_trans: VGG16_Weights) → tuple
Prepares the training and validation datasets with data loaders.
This function loads the training and validation datasets from specified directory paths using a custom dataset class (MyDataset). It applies the default transformations provided by the VGG16 weights to preprocess the images and creates PyTorch DataLoaders for efficient batching.
- Args:
before_trans (callable): Default preprocessing transformations obtained from the VGG16 pretrained weights (e.g., weights.transforms()).
- Returns:
- tuple:
train_loader (torch.utils.data.DataLoader): DataLoader for the training set.
valid_loader (torch.utils.data.DataLoader): DataLoader for the validation set.
train_N (int): Number of samples in the training dataset.
valid_N (int): Number of samples in the validation dataset.
- Notes:
Training data is shuffled to introduce randomness in batches.
Validation data is not shuffled to maintain consistent evaluation.
- Example:
>>> weights = VGG16_Weights.DEFAULT
>>> train_loader, valid_loader, train_N, valid_N = prepare_dataset(weights.transforms())
>>> print(f"Training samples: {train_N}, Validation samples: {valid_N}")
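The loader setup can be sketched independently of how the datasets are read from disk. The helper name make_loaders and the batch size of 32 are hypothetical; in the module, the datasets would come from MyDataset built on the training and validation directories, whose exact paths are not given here.

```python
from torch.utils.data import DataLoader, Dataset


def make_loaders(train_data: Dataset, valid_data: Dataset, batch_size: int = 32) -> tuple:
    # Shuffle training batches each epoch; keep validation order fixed.
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(valid_data, batch_size=batch_size, shuffle=False)
    return train_loader, valid_loader, len(train_data), len(valid_data)
```

Returning the dataset sizes alongside the loaders lets the training loop compute per-sample accuracy without iterating over the data again.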
- classification_nn.train_model(vgg_model: Sequential, fruit_model: Sequential, train_loader, valid_loader, train_N: int, valid_N: int, random_trans, optimizer, loss_function) → Sequential
Trains and fine-tunes a neural network model using a two-phase approach.
This function trains a fruit classification model in two phases:
1. Initial Training Phase: Trains the top layers of the model while keeping the base model frozen.
2. Fine-Tuning Phase: Unfreezes the base model (VGG16) and continues training with a lower learning rate.
- Args:
vgg_model (torch.nn.Module): Pretrained VGG16 model used as the base model.
fruit_model (torch.nn.Module): Extended model for fruit classification.
train_loader (torch.utils.data.DataLoader): DataLoader for the training dataset.
valid_loader (torch.utils.data.DataLoader): DataLoader for the validation dataset.
train_N (int): Total number of samples in the training dataset.
valid_N (int): Total number of samples in the validation dataset.
random_trans (torchvision.transforms.Compose): Data augmentation transformations applied during training.
optimizer (torch.optim.Optimizer): Optimizer for updating model weights.
loss_function (torch.nn.Module): Loss function to compute training and validation loss.
- Returns:
torch.nn.Module: The trained and fine-tuned fruit classification model.
- Process:
The model is trained for 12 epochs with the base VGG model frozen.
The base VGG model is unfrozen, and the model is fine-tuned for 4 additional epochs with a reduced learning rate.
Validation is performed at the end of each epoch to monitor performance.
- Example:
>>> fruit_model = train_model(vgg_model, fruit_model, train_loader, valid_loader,
...                           train_N, valid_N, random_trans, optimizer, loss_function)
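A sketch of the two-phase schedule described above, assuming fruit_model shares its base layers with vgg_model (as when it is built from vgg_model's submodules). The 12 and 4 epoch counts come from the process description; the fine-tuning learning rate of 1e-4 and the helper run_epoch are assumptions. Freezing relies on the fact that Adam skips parameters whose gradient is None.

```python
import torch


def run_epoch(model, loader, n_samples, loss_function, optimizer=None, random_trans=None):
    """One pass over loader: trains when an optimizer is given, otherwise evaluates.

    Returns the sum of per-batch mean losses and the accuracy over n_samples.
    """
    training = optimizer is not None
    model.train(training)
    device = next(model.parameters()).device
    total_loss, correct = 0.0, 0
    with torch.set_grad_enabled(training):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            if training and random_trans is not None:
                x = random_trans(x)  # augment training batches only
            output = model(x)
            loss = loss_function(output, y)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
            correct += (output.argmax(dim=1) == y).sum().item()
    return total_loss, correct / n_samples


def train_model(vgg_model, fruit_model, train_loader, valid_loader, train_N, valid_N,
                random_trans, optimizer, loss_function,
                frozen_epochs=12, finetune_epochs=4, finetune_lr=1e-4):
    # Phase 1: freeze the pretrained base so only the new head is updated.
    vgg_model.requires_grad_(False)
    for epoch in range(frozen_epochs):
        run_epoch(fruit_model, train_loader, train_N, loss_function, optimizer, random_trans)
        loss, acc = run_epoch(fruit_model, valid_loader, valid_N, loss_function)
        print(f"[frozen] epoch {epoch}: valid loss {loss:.4f}, accuracy {acc:.4f}")

    # Phase 2: unfreeze the base and fine-tune everything with a lower learning rate.
    vgg_model.requires_grad_(True)
    for group in optimizer.param_groups:
        group["lr"] = finetune_lr
    for epoch in range(finetune_epochs):
        run_epoch(fruit_model, train_loader, train_N, loss_function, optimizer, random_trans)
        loss, acc = run_epoch(fruit_model, valid_loader, valid_N, loss_function)
        print(f"[fine-tune] epoch {epoch}: valid loss {loss:.4f}, accuracy {acc:.4f}")
    return fruit_model
```

Lowering the learning rate before unfreezing keeps the large gradient updates of the freshly initialized head from disrupting the pretrained features.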