pytorch-neuralnetworks
Neural networks with PyTorch
deploy.py
"""
Model Deployment Script
Project: PyTorch Neural Networks
Author: RSK World
Website: https://rskworld.in
Email: help@rskworld.in
Phone: +91 93305 39277
Description: Script for deploying trained models for inference
"""

import torch
import torch.nn as nn
import argparse
import os
import sys

from models.basic_nn import BasicNeuralNetwork
from models.cnn import SimpleCNN
from models.rnn import SimpleRNN


def load_model(model_type, model_path, **kwargs):
    """
    Load a trained model
    
    Args:
        model_type: Type of model ('basic', 'cnn', 'rnn')
        model_path: Path to saved model
        **kwargs: Additional model parameters
        
    Returns:
        Loaded model
    """
    if model_type == 'basic':
        model = BasicNeuralNetwork(
            input_size=kwargs.get('input_size', 20),
            hidden_size=kwargs.get('hidden_size', 64),
            output_size=kwargs.get('output_size', 3)
        )
    elif model_type == 'cnn':
        model = SimpleCNN(num_classes=kwargs.get('num_classes', 10))
    elif model_type == 'rnn':
        model = SimpleRNN(
            input_size=kwargs.get('input_size', 5),
            hidden_size=kwargs.get('hidden_size', 64),
            num_layers=kwargs.get('num_layers', 2),
            num_classes=kwargs.get('num_classes', 3)
        )
    else:
        raise ValueError(f"Unknown model type: {model_type}")
    
    # Load state dict
    checkpoint = torch.load(model_path, map_location='cpu')
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        model.load_state_dict(checkpoint)
    
    model.eval()
    return model


def predict(model, input_data, device='cpu'):
    """
    Make predictions using the model
    
    Args:
        model: Trained model
        input_data: Input tensor
        device: Device to run on
        
    Returns:
        Predictions and probabilities
    """
    model = model.to(device)
    input_data = input_data.to(device)
    
    with torch.no_grad():
        outputs = model(input_data)
        probabilities = torch.softmax(outputs, dim=1)
        _, predicted = torch.max(outputs, 1)
    
    return predicted, probabilities


def main():
    parser = argparse.ArgumentParser(description='Deploy PyTorch Model - RSK World')
    parser.add_argument('--model_type', type=str, required=True,
                       choices=['basic', 'cnn', 'rnn'],
                       help='Type of model to deploy')
    parser.add_argument('--model_path', type=str, required=True,
                       help='Path to saved model')
    parser.add_argument('--input_file', type=str,
                       help='Path to an input tensor saved with torch.save '
                            '(random sample data is used if omitted)')
    parser.add_argument('--device', type=str, default='cpu',
                       choices=['cpu', 'cuda'],
                       help='Device to run inference on')
    
    args = parser.parse_args()
    
    # Check if model file exists
    if not os.path.exists(args.model_path):
        print(f"Error: Model file not found at {args.model_path}")
        sys.exit(1)
    
    # Load model
    print(f"Loading {args.model_type} model from {args.model_path}...")
    model = load_model(args.model_type, args.model_path)
    print("Model loaded successfully!")
    
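    # Load input data or create sample
    if args.input_file:
        # Fail loudly instead of silently falling back to random sample data
        if not os.path.exists(args.input_file):
            print(f"Error: Input file not found at {args.input_file}")
            sys.exit(1)
        # Load onto the CPU first; predict() moves the tensor to the target device
        input_data = torch.load(args.input_file, map_location='cpu')
        print(f"Loaded input data from {args.input_file}")
    else: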
        # Create sample input based on model type
        if args.model_type == 'basic':
            input_data = torch.randn(1, 20)
        elif args.model_type == 'cnn':
            input_data = torch.randn(1, 1, 28, 28)
        elif args.model_type == 'rnn':
            input_data = torch.randn(1, 10, 5)
        print("Using sample input data")
    
    # Make prediction
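    # Fall back to the CPU, with a warning, when CUDA is requested but unavailable
    if args.device == 'cuda' and not torch.cuda.is_available():
        print("Warning: CUDA is not available, falling back to CPU")
        device = torch.device('cpu')
    else:
        device = torch.device(args.device)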
    predicted, probabilities = predict(model, input_data, device)
    
    print(f"\nInput shape: {input_data.shape}")
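    # .item() only works for a single sample; handle batched input files too
    classes = predicted.tolist()
    print(f"Predicted class: {classes[0] if len(classes) == 1 else classes}")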
    print(f"Probabilities: {probabilities.squeeze().tolist()}")


if __name__ == '__main__':
    main()
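
The script accepts two checkpoint layouts (a bare state dict, or a dict with a 'model_state_dict' key) and reads input tensors with torch.load. The sketch below shows one way such files might be produced and consumed. It is illustrative only: TinyClassifier and load_weights are hypothetical stand-ins, not part of this project, and the layer sizes merely echo the defaults that load_model passes to BasicNeuralNetwork, whose real architecture is not shown here.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyClassifier(nn.Module):
    """Hypothetical stand-in model (20 inputs, 64 hidden units, 3 classes)."""

    def __init__(self, input_size=20, hidden_size=64, output_size=3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
        )

    def forward(self, x):
        return self.net(x)


def load_weights(model, path):
    """Hypothetical helper mirroring the checkpoint handling in load_model()."""
    checkpoint = torch.load(path, map_location="cpu")
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        model.load_state_dict(checkpoint["model_state_dict"])
    else:
        model.load_state_dict(checkpoint)
    model.eval()
    return model


trained = TinyClassifier()
workdir = tempfile.mkdtemp()

# Layout 1: a bare state dict
plain_path = os.path.join(workdir, "plain.pth")
torch.save(trained.state_dict(), plain_path)

# Layout 2: a training checkpoint that wraps the state dict
wrapped_path = os.path.join(workdir, "wrapped.pth")
torch.save({"model_state_dict": trained.state_dict(), "epoch": 5}, wrapped_path)

# An input file of the kind --input_file expects: a tensor written by torch.save
input_path = os.path.join(workdir, "input.pt")
torch.save(torch.randn(4, 20), input_path)

model_a = load_weights(TinyClassifier(), plain_path)
model_b = load_weights(TinyClassifier(), wrapped_path)
batch = torch.load(input_path, map_location="cpu")

# Same steps as predict(): no gradients, softmax over the class dimension
with torch.no_grad():
    logits_a = model_a(batch)
    logits_b = model_b(batch)
    probabilities = torch.softmax(logits_a, dim=1)
    predicted = logits_a.argmax(dim=1)

print(f"Predicted classes: {predicted.tolist()}")
```

Both layouts restore the same weights, so either file could be passed as --model_path. The input here holds four samples, so the prediction is a list of four class indices rather than a single value.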


About RSK World

Founded by Molla Samser, with Designer & Tester Rima Khatun, RSK World is your one-stop destination for free programming resources, source code, and development tools.

© 2026 RSK World. All rights reserved.

Content used for educational purposes only. View Disclaimer