training/callbacks.py
"""
Training Callbacks - PyTorch Neural Networks
Project: PyTorch Neural Networks
Author: RSK World
Website: https://rskworld.in
Email: help@rskworld.in
Phone: +91 93305 39277
Description: Callback utilities for training (early stopping, model checkpointing, etc.)
"""

import torch


class EarlyStopping:
    """
    Early stopping callback to stop training when validation loss stops improving
    
    Project: PyTorch Neural Networks
    Author: RSK World
    Website: https://rskworld.in
    """
    
    def __init__(self, patience=7, min_delta=0, restore_best_weights=True, verbose=True):
        """
        Initialize early stopping
        
        Args:
            patience: Number of epochs with no improvement to wait before stopping
            min_delta: Minimum change to qualify as an improvement
            restore_best_weights: Whether to restore best model weights
            verbose: Whether to print messages
        """
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.verbose = verbose
        self.best_loss = None
        self.counter = 0
        self.best_weights = None
        self.early_stop = False
    
    def __call__(self, val_loss, model):
        """
        Check if training should stop
        
        Args:
            val_loss: Current validation loss
            model: Model to save weights from
            
        Returns:
            True if training should stop, False otherwise
        """
        if self.best_loss is None:
            self.best_loss = val_loss
            self.save_checkpoint(model)
        elif val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            self.save_checkpoint(model)
        else:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter}/{self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
                if self.restore_best_weights:
                    model.load_state_dict(self.best_weights)
                if self.verbose:
                    print(f'Early stopping triggered. Best loss: {self.best_loss:.4f}')
        
        return self.early_stop
    
    def save_checkpoint(self, model):
        """Snapshot the current weights as the best seen so far"""
        # Clone each tensor: a plain dict.copy() would keep references to the
        # live parameters, which continue to change as training proceeds.
        self.best_weights = {k: v.detach().clone()
                             for k, v in model.state_dict().items()}
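
# Example usage (illustrative sketch, not part of this module): the names
# `train_one_epoch`, `evaluate`, `model`, `optimizer`, and the loaders are
# assumed to be defined by the caller's training script.
#
#     early_stopping = EarlyStopping(patience=5, min_delta=1e-4)
#     for epoch in range(max_epochs):
#         train_one_epoch(model, train_loader, optimizer)
#         val_loss = evaluate(model, val_loader)
#         if early_stopping(val_loss, model):
#             break  # best weights have been restored if requested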


class ModelCheckpoint:
    """
    Model checkpoint callback to save model at specified intervals
    
    Project: PyTorch Neural Networks
    Author: RSK World
    Website: https://rskworld.in
    """
    
    def __init__(self, filepath, monitor='val_loss', save_best_only=True, verbose=True):
        """
        Initialize model checkpoint
        
        Args:
            filepath: Path to save the model
            monitor: Metric to monitor ('val_loss', 'val_accuracy', etc.)
            save_best_only: Whether to save only the best model
            verbose: Whether to print messages
        """
        self.filepath = filepath
        self.monitor = monitor
        self.save_best_only = save_best_only
        self.verbose = verbose
        self.best_score = None
        self.monitor_op = None
    
    def __call__(self, score, model, optimizer, epoch):
        """
        Save model checkpoint
        
        Args:
            score: Current score for monitored metric
            model: Model to save
            optimizer: Optimizer to save
            epoch: Current epoch number
        """
        # Decide, on first call, whether lower or higher is better
        if self.monitor_op is None:
            if 'loss' in self.monitor:
                self.monitor_op = lambda new, best: new < best
                self.best_score = float('inf')
            else:
                self.monitor_op = lambda new, best: new > best
                self.best_score = float('-inf')

        is_better = self.monitor_op(score, self.best_score)
        if is_better:
            # Update the best score only on a genuine improvement
            self.best_score = score

        if is_better or not self.save_best_only:
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'score': score,
            }
            torch.save(checkpoint, self.filepath)
            if self.verbose:
                print(f'Model checkpoint saved: {self.filepath} (score: {score:.4f})')
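
# Example usage (illustrative sketch; `model`, `optimizer`, `evaluate`, and
# the loader come from the surrounding training script and are not defined
# here):
#
#     checkpoint_cb = ModelCheckpoint('checkpoints/best.pt', monitor='val_loss')
#     for epoch in range(max_epochs):
#         val_loss = evaluate(model, val_loader)
#         checkpoint_cb(val_loss, model, optimizer, epoch)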


class LearningRateScheduler:
    """
    Learning rate scheduler callback
    
    Project: PyTorch Neural Networks
    Author: RSK World
    Website: https://rskworld.in
    """
    
    def __init__(self, scheduler, verbose=True):
        """
        Initialize learning rate scheduler
        
        Args:
            scheduler: PyTorch learning rate scheduler
            verbose: Whether to print learning rate changes
        """
        self.scheduler = scheduler
        self.verbose = verbose
    
    def step(self, metrics=None):
        """
        Step the scheduler
        
        Args:
            metrics: Optional metrics for ReduceLROnPlateau
        """
        if metrics is not None:
            self.scheduler.step(metrics)
        else:
            self.scheduler.step()
        
        if self.verbose:
            current_lr = self.scheduler.optimizer.param_groups[0]['lr']
            print(f'Learning rate: {current_lr:.6f}')
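
# Example usage (illustrative sketch): wrapping a ReduceLROnPlateau scheduler,
# which expects a metric on every step. `model`, `evaluate`, and the loader
# are assumed to exist in the caller's loop.
#
#     optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
#     plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min')
#     lr_cb = LearningRateScheduler(plateau)
#     for epoch in range(max_epochs):
#         val_loss = evaluate(model, val_loader)
#         lr_cb.step(val_loss)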
