training/advanced_trainer.py
"""
Advanced Training Features - PyTorch Neural Networks
Project: PyTorch Neural Networks
Author: RSK World
Website: https://rskworld.in
Email: help@rskworld.in
Phone: +91 93305 39277
Description: Advanced training features (gradient clipping, mixed precision, etc.)
"""

import torch
import torch.nn as nn
# torch.cuda.amp still works, but recent PyTorch releases deprecate it in favor of torch.amp
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
from .trainer import Trainer


class AdvancedTrainer(Trainer):
    """
    Trainer subclass that adds optional gradient-norm clipping and mixed precision training
    
    Project: PyTorch Neural Networks
    Author: RSK World
    Website: https://rskworld.in
    """
    
    def __init__(self, model, criterion, optimizer, device, 
                 gradient_clip=None, use_mixed_precision=False):
        """
        Initialize advanced trainer
        
        Args:
            model: PyTorch model
            criterion: Loss function
            optimizer: Optimizer
            device: Device to train on
            gradient_clip: Maximum gradient norm for clipping (None disables clipping)
            use_mixed_precision: Whether to use mixed precision training (autocast plus GradScaler)
        """
        super().__init__(model, criterion, optimizer, device)
        self.gradient_clip = gradient_clip
        self.use_mixed_precision = use_mixed_precision
        self.scaler = GradScaler() if use_mixed_precision else None
    
    def train_epoch(self, train_loader):
        """
        Train for one epoch, applying mixed precision and gradient clipping when enabled
        
        Args:
            train_loader: DataLoader for training data
            
        Returns:
            Tuple of (average loss per batch, accuracy in percent)
        """
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        for inputs, labels in tqdm(train_loader, desc="Training"):
            inputs, labels = inputs.to(self.device), labels.to(self.device)
            
            self.optimizer.zero_grad()
            
            if self.use_mixed_precision:
                # Mixed precision training
                with autocast():
                    outputs = self.model(inputs)
                    loss = self.criterion(outputs, labels)
                
                self.scaler.scale(loss).backward()
                
                # Gradient clipping: unscale first so the threshold applies to the true gradient norms
                if self.gradient_clip:
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.gradient_clip)
                
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                # Standard training
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)
                loss.backward()
                
                # Gradient clipping
                if self.gradient_clip:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.gradient_clip)
                
                self.optimizer.step()
            
            # Statistics
            running_loss += loss.item()
            predicted = outputs.argmax(dim=1)  # predicted class index per sample
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        
        epoch_loss = running_loss / len(train_loader)
        epoch_accuracy = 100 * correct / total
        
        return epoch_loss, epoch_accuracy


class DistributedTrainer:
    """
    Multi-GPU trainer based on nn.DataParallel (single-process data parallelism; this is not DistributedDataParallel)
    
    Project: PyTorch Neural Networks
    Author: RSK World
    Website: https://rskworld.in
    """
    
    def __init__(self, model, criterion, optimizer, device):
        """
        Initialize the multi-GPU trainer
        
        Args:
            model: PyTorch model
            criterion: Loss function
            optimizer: Optimizer
            device: Device to train on
        """
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.device = device
        
        # Wrap the model in nn.DataParallel when more than one GPU is visible
        if torch.cuda.device_count() > 1:
            self.model = nn.DataParallel(self.model)
            print(f"Using {torch.cuda.device_count()} GPUs")
        
        self.model.to(device)
    
    def train_epoch(self, train_loader):
        """Train for one epoch on all visible GPUs; returns (average loss per batch, accuracy in percent)"""
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        for inputs, labels in tqdm(train_loader, desc="Training"):
            inputs, labels = inputs.to(self.device), labels.to(self.device)
            
            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.criterion(outputs, labels)
            loss.backward()
            self.optimizer.step()
            
            running_loss += loss.item()
            predicted = outputs.argmax(dim=1)  # predicted class index per sample
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        
        epoch_loss = running_loss / len(train_loader)
        epoch_accuracy = 100 * correct / total
        
        return epoch_loss, epoch_accuracy
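
The mixed precision branch of AdvancedTrainer.train_epoch follows a fixed order: scale the loss, backpropagate, unscale the gradients, clip, step, update. A minimal standalone sketch of that ordering is given below. The helper name amp_clipped_step, the tiny nn.Linear model, the random data, and the max_norm value of 0.1 are illustrative assumptions, not part of the project. Mixed precision is only enabled when a CUDA device is present; on CPU the scaler and autocast are disabled and the step reduces to plain clipped training.

```python
import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler  # same import path as the module above


def amp_clipped_step(model, criterion, optimizer, scaler,
                     inputs, labels, max_norm, use_amp):
    """Hypothetical helper: one optimization step with optional AMP and clipping."""
    optimizer.zero_grad()
    # autocast does nothing when use_amp is False
    with torch.autocast(device_type=inputs.device.type, enabled=use_amp):
        loss = criterion(model(inputs), labels)
    # With a disabled scaler, scale() returns the loss unchanged
    scaler.scale(loss).backward()
    # Unscale before clipping so max_norm refers to the true gradient norm
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    # Measure the global gradient norm after clipping, for inspection only
    grad_norm = torch.norm(
        torch.stack([p.grad.norm(2) for p in model.parameters()]), 2
    ).item()
    scaler.step(optimizer)  # skips the step if gradients contain inf/NaN
    scaler.update()
    return loss.item(), grad_norm


torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = device.type == "cuda"  # assumption: AMP only on GPU

model = nn.Linear(4, 3).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler(enabled=use_amp)

inputs = torch.randn(8, 4, device=device)
labels = torch.randint(0, 3, (8,), device=device)

loss, grad_norm = amp_clipped_step(
    model, criterion, optimizer, scaler, inputs, labels,
    max_norm=0.1, use_amp=use_amp,
)
print(f"loss={loss:.4f}, gradient norm after clipping={grad_norm:.4f}")
```

Calling unscale_ before clip_grad_norm_ matters because scaled gradients are larger than the true gradients by the current loss-scale factor, so clipping them directly would apply the threshold to the wrong magnitudes.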

