help@rskworld.in +91 93305 39277
RSK World
  • Home
  • Development
    • Web Development
    • Mobile Apps
    • Software
    • Games
    • Project
  • Technologies
    • Data Science
    • AI Development
    • Cloud Development
    • Blockchain
    • Cyber Security
    • Dev Tools
    • Testing Tools
  • About
  • Contact

Theme Settings

Color Scheme
Display Options
Font Size
100%
Back to Project
RSK World
text-classification
/
scripts
RSK World
text-classification
Text Classification Dataset - NLP + Multi-Class Classification + Machine Learning
scripts
  • __init__.py (2.3 KB)
  • active_learning.py (26.8 KB)
  • api_server.py (12.7 KB)
  • batch_processor.py (16.4 KB)
  • data_augmentation.py (18.2 KB)
  • data_quality.py (20 KB)
  • deep_learning.py (24.2 KB)
  • hyperparameter_tuning.py (22.5 KB)
  • model_explainability.py (17.9 KB)
  • preprocessing.py (8.7 KB)
  • train_classifier.py (13.8 KB)
  • train_transformers.py (12.5 KB)
  • visualizations.py (19 KB)
model_explainability.py · train_classifier.py
scripts/model_explainability.py
Raw Download
Find: Go to:
"""
================================================================================
Text Classification Dataset - Model Explainability Module
================================================================================
Project: Text Classification Dataset
Category: Text Data / NLP

Author: Molla Samser
Designer & Tester: Rima Khatun
Website: https://rskworld.in
Email: help@rskworld.in | support@rskworld.in
Phone: +91 93305 39277

Copyright (c) 2026 RSK World - All Rights Reserved
Content used for educational purposes only.

Features:
- LIME Text Explainer
- Feature Attribution
- Word Importance Highlighting
- Prediction Confidence Analysis
- Attention Visualization (for transformers)

Created: December 2026
================================================================================
"""

import re
import string
from typing import List, Dict, Tuple, Optional, Callable
import numpy as np

# Project information (author, website, contact email)
__author__, __website__, __email__ = (
    "Molla Samser",
    "https://rskworld.in",
    "help@rskworld.in",
)

# Category mapping: integer label -> human-readable category name.
# The label id of each category is simply its position in this list.
CATEGORIES = dict(enumerate([
    'Technology',
    'Sports',
    'Politics',
    'Entertainment',
    'Business',
    'Science',
]))


class TextExplainer:
    """
    Text classification model explainer using a LIME-like approach.

    Explains a single prediction by randomly removing words from the input,
    re-scoring every perturbed text with ``classifier_fn`` and fitting a
    weighted ridge regression of the target-class probability on the binary
    "word kept" mask. The regression coefficients are the per-word
    importances.

    Author: Molla Samser | RSK World (https://rskworld.in)
    """

    def __init__(
        self,
        classifier_fn: Callable,
        class_names: List[str] = None,
        num_samples: int = 1000,
        random_state: int = 42
    ):
        """
        Initialize the explainer.

        Args:
            classifier_fn: Function that takes a list of texts and returns an
                array-like of per-class probabilities, one row per text
            class_names: List of class names (defaults to the CATEGORIES names)
            num_samples: Number of perturbations to generate; the first one is
                always the unmodified text
            random_state: Seed for this explainer's private random generator
        """
        self.classifier_fn = classifier_fn
        self.class_names = class_names or list(CATEGORIES.values())
        self.num_samples = num_samples
        # Use a private generator rather than np.random.seed(): seeding the
        # global NumPy state here would silently reset randomness for every
        # other user of np.random. RandomState(seed) yields the same draws.
        self._rng = np.random.RandomState(random_state)

    def _tokenize(self, text: str) -> List[str]:
        """Lower-case the text, strip punctuation and split on whitespace."""
        text = re.sub(r'[^\w\s]', '', text.lower())
        return text.split()

    def _perturb_text(
        self,
        words: List[str],
        num_samples: int
    ) -> Tuple[List[str], np.ndarray]:
        """
        Generate perturbed versions of text by randomly removing words.

        Args:
            words: List of words in original text
            num_samples: Number of perturbations (must be >= 1)

        Returns:
            Tuple of (perturbed_texts, perturbation_matrix), where row i of the
            matrix is the keep(1)/drop(0) mask that produced perturbed_texts[i]

        Raises:
            ValueError: If num_samples < 1
        """
        if num_samples < 1:
            raise ValueError("num_samples must be at least 1")

        num_words = len(words)

        # Binary mask: one row per perturbation, one column per word position.
        perturbation_matrix = self._rng.binomial(1, 0.5, size=(num_samples, num_words))

        # Row 0 is always the original text, so predictions[0] is the
        # prediction being explained.
        perturbation_matrix[0] = 1

        # An all-zero row would produce an empty text; keep the first word
        # instead and record that in the mask so the regression features
        # describe the text that is actually scored.
        empty_rows = ~perturbation_matrix.any(axis=1)
        perturbation_matrix[empty_rows, 0] = 1

        perturbed_texts = [
            ' '.join(word for word, keep in zip(words, row) if keep)
            for row in perturbation_matrix
        ]
        return perturbed_texts, perturbation_matrix

    def _compute_weights(
        self,
        perturbation_matrix: np.ndarray,
        kernel_width: float = 25.0
    ) -> np.ndarray:
        """
        Compute weights for each perturbation based on distance from original.

        Args:
            perturbation_matrix: Binary matrix of perturbations
            kernel_width: Width of the exponential kernel

        Returns:
            Array of weights, 1.0 for the unmodified text and decaying as more
            words are removed
        """
        # Distance is the number of removed words.
        distances = np.sum(perturbation_matrix == 0, axis=1)
        return np.exp(-distances ** 2 / kernel_width ** 2)

    def _fit_linear_model(
        self,
        perturbation_matrix: np.ndarray,
        predictions: np.ndarray,
        weights: np.ndarray,
        target_class: int
    ) -> np.ndarray:
        """
        Fit a weighted ridge regression to get feature importance.

        Args:
            perturbation_matrix: Binary feature matrix
            predictions: Model predictions for perturbations (2-D)
            weights: Sample weights
            target_class: Class to explain

        Returns:
            One importance score (regression coefficient) per word position
        """
        from sklearn.linear_model import Ridge

        y = np.asarray(predictions)[:, target_class]

        model = Ridge(alpha=1.0)
        model.fit(perturbation_matrix, y, sample_weight=weights)

        return model.coef_

    def explain(
        self,
        text: str,
        num_features: int = 10,
        target_class: Optional[int] = None
    ) -> Dict:
        """
        Explain a prediction for the given text.

        Args:
            text: Text to explain
            num_features: Number of top features to return (<= 0 returns none)
            target_class: Class to explain (None = predicted class)

        Returns:
            Dictionary with explanation results, or {'error': ...} if the text
            has no words after tokenization
        """
        words = self._tokenize(text)
        if not words:
            return {'error': 'Empty text after tokenization'}

        perturbed_texts, perturbation_matrix = self._perturb_text(
            words, self.num_samples
        )

        # Accept list outputs as well as arrays; column indexing needs 2-D.
        predictions = np.asarray(self.classifier_fn(perturbed_texts), dtype=float)

        original_probs = predictions[0]
        predicted_class = int(np.argmax(original_probs))

        if target_class is None:
            target_class = predicted_class

        weights = self._compute_weights(perturbation_matrix)
        importances = self._fit_linear_model(
            perturbation_matrix, predictions, weights, target_class
        )

        # Rank positions by |importance|, largest first. Taking [:n] of the
        # descending order (rather than [-n:] of the ascending one) keeps
        # num_features=0 from selecting everything, since [-0:] is the full array.
        descending = np.argsort(np.abs(importances))[::-1]
        top_indices = descending[:max(num_features, 0)]

        word_importance = [
            {
                'word': words[i],
                'importance': float(importances[i]),
                'direction': 'positive' if importances[i] > 0 else 'negative'
            }
            for i in top_indices
        ]

        return {
            'text': text,
            'predicted_class': predicted_class,
            'predicted_category': self.class_names[predicted_class],
            'explained_class': int(target_class),
            'explained_category': self.class_names[target_class],
            'confidence': float(original_probs[predicted_class]),
            'probabilities': {
                self.class_names[i]: float(p)
                for i, p in enumerate(original_probs)
            },
            'word_importance': word_importance,
            'top_positive_words': [
                w for w in word_importance if w['direction'] == 'positive'
            ][:5],
            'top_negative_words': [
                w for w in word_importance if w['direction'] == 'negative'
            ][:5]
        }

    def explain_with_html(
        self,
        text: str,
        num_features: int = 10,
        target_class: Optional[int] = None
    ) -> str:
        """
        Generate HTML visualization of explanation.

        All user-supplied text is HTML-escaped before it is embedded.

        Args:
            text: Text to explain
            num_features: Number of features
            target_class: Class to explain

        Returns:
            HTML string with highlighted text
        """
        from html import escape

        explanation = self.explain(text, num_features, target_class)

        if 'error' in explanation:
            return f"<p>Error: {escape(str(explanation['error']))}</p>"

        # A word may occur at several positions. word_importance is sorted by
        # |importance| descending, so setdefault keeps the strongest score.
        word_scores = {}
        for item in explanation['word_importance']:
            word_scores.setdefault(item['word'].lower(), item['importance'])

        html_parts = ['<div style="font-family: Arial, sans-serif; line-height: 1.8;">']

        for word in text.split():
            safe_word = escape(word)
            clean_word = re.sub(r'[^\w]', '', word.lower())
            if clean_word not in word_scores:
                html_parts.append(f' {safe_word}')
                continue

            score = word_scores[clean_word]
            alpha = min(1, abs(score) * 3)
            if score > 0:
                color = f'rgba(34, 197, 94, {alpha})'  # Green: supports the class
            else:
                color = f'rgba(239, 68, 68, {alpha})'  # Red: against the class

            html_parts.append(
                f'<span style="background-color: {color}; padding: 2px 4px; '
                f'border-radius: 3px; margin: 0 2px;">{safe_word}</span>'
            )

        html_parts.append('</div>')

        # Prediction summary box
        html_parts.append('<div style="margin-top: 20px; padding: 15px; '
                          'background: #1a1333; border-radius: 8px; color: #f8fafc;">')
        html_parts.append(f'<strong>Prediction:</strong> '
                          f'{escape(str(explanation["predicted_category"]))} '
                          f'({explanation["confidence"]:.1%} confidence)')
        html_parts.append('</div>')

        return '\n'.join(html_parts)


class AttentionVisualizer:
    """
    Visualize attention weights from transformer models.

    Produces one score per input token: the attention that token receives,
    averaged over all layers, heads and query positions.

    Author: Molla Samser | RSK World (https://rskworld.in)
    """

    def __init__(self, model, tokenizer):
        """
        Initialize with transformer model and tokenizer.

        Args:
            model: Transformer model; must accept ``output_attentions=True``
                and return an object with an ``attentions`` attribute
            tokenizer: Tokenizer matching the model
        """
        self.model = model
        self.tokenizer = tokenizer

    def get_attention_weights(self, text: str) -> Dict:
        """
        Extract attention weights for input text.

        Args:
            text: Input text

        Returns:
            Dictionary with:
              - 'tokens': tokens with the first and last position removed
              - 'attention_scores': one float per entry of 'tokens'
              - 'top_attended': up to 10 (token, score) tuples, highest first
        """
        import torch

        inputs = self.tokenizer(
            text,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=512
        )

        with torch.no_grad():
            outputs = self.model(**inputs, output_attentions=True)

        # outputs.attentions holds one (batch, heads, query, key) tensor per
        # layer. Stacking gives (layers, batch, heads, query, key); averaging
        # every axis except the key axis leaves one "attention received"
        # score per token (a 1-D vector, not a query x key matrix).
        stacked = torch.stack(outputs.attentions)
        received = stacked.mean(dim=(0, 1, 2, 3))

        tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

        # Drop the first and last positions, which are [CLS]/[SEP] for
        # BERT-style tokenizers. NOTE(review): tokenizers that add no special
        # tokens lose real tokens here - confirm for the model in use.
        tokens = list(tokens[1:-1])
        # .cpu() so this also works when the model runs on a GPU; plain floats
        # keep the result JSON-serializable.
        scores = [float(s) for s in received.detach().cpu().numpy()[1:-1]]

        ranked = sorted(zip(tokens, scores), key=lambda pair: pair[1], reverse=True)

        return {
            'tokens': tokens,
            'attention_scores': scores,
            'top_attended': ranked[:10]
        }


class PredictionAnalyzer:
    """
    Analyze prediction confidence and uncertainty.

    Author: Molla Samser | RSK World (https://rskworld.in)
    """

    def __init__(
        self,
        classifier_fn: Callable,
        class_names: List[str] = None,
        confidence_margin: float = 0.3,
        uncertainty_threshold: float = 0.7
    ):
        """
        Initialize analyzer.

        Args:
            classifier_fn: Function that takes a list of texts and returns
                per-class probabilities, one row per text
            class_names: List of class names (defaults to the CATEGORIES names)
            confidence_margin: Top-1 minus top-2 probability above which a
                prediction is flagged ``is_confident``
            uncertainty_threshold: Normalized entropy above which a prediction
                is flagged ``is_uncertain``
        """
        self.classifier_fn = classifier_fn
        self.class_names = class_names or list(CATEGORIES.values())
        self.confidence_margin = confidence_margin
        self.uncertainty_threshold = uncertainty_threshold

    def analyze(self, text: str) -> Dict:
        """
        Comprehensive prediction analysis.

        Args:
            text: Input text

        Returns:
            Analysis results. All numbers are Python floats/ints and all flags
            are Python bools, so the result is JSON-serializable.
        """
        # asarray: accept list outputs from classifier_fn as well as arrays.
        probs = np.asarray(self.classifier_fn([text])[0], dtype=float)

        predicted_class = int(np.argmax(probs))
        sorted_indices = np.argsort(probs)[::-1]

        # Entropy (uncertainty measure), normalized by its maximum log(K).
        entropy = float(-np.sum(probs * np.log(probs + 1e-10)))
        max_entropy = float(np.log(len(probs)))
        # A single-class model has log(1) == 0; report zero uncertainty there.
        normalized_entropy = entropy / max_entropy if max_entropy > 0 else 0.0

        # Margin between the top two predictions (top-1 alone if only one class).
        runner_up = probs[sorted_indices[1]] if len(probs) > 1 else 0.0
        margin = float(probs[sorted_indices[0]] - runner_up)

        return {
            'text': text[:200] + '...' if len(text) > 200 else text,
            'prediction': {
                'class': predicted_class,
                'category': self.class_names[predicted_class],
                'confidence': float(probs[predicted_class])
            },
            'all_probabilities': {
                self.class_names[i]: float(probs[i])
                for i in sorted_indices
            },
            'uncertainty': {
                'entropy': entropy,
                'normalized_entropy': float(normalized_entropy),
                'margin': margin,
                'is_confident': bool(margin > self.confidence_margin),
                'is_uncertain': bool(normalized_entropy > self.uncertainty_threshold)
            },
            'ranked_predictions': [
                {
                    'rank': rank + 1,
                    'category': self.class_names[idx],
                    'probability': float(probs[idx])
                }
                for rank, idx in enumerate(sorted_indices)
            ]
        }


def _pipeline_proba_fn(pipeline) -> Callable:
    """
    Build a ``texts -> per-class probabilities`` function for a fitted pipeline.

    Uses ``predict_proba`` when the pipeline provides it. Otherwise (e.g. a
    LinearSVC final estimator) it falls back to a softmax over
    ``decision_function`` scores, so the result still has one row per text
    summing to 1. Those fallback values rank classes but are not calibrated
    probabilities.
    """
    if hasattr(pipeline, 'predict_proba'):
        return pipeline.predict_proba

    def proba_from_scores(texts):
        scores = np.asarray(pipeline.decision_function(texts), dtype=float)
        if scores.ndim == 1:
            # Binary classifiers return a single score per text.
            scores = np.column_stack([-scores, scores])
        scores = scores - scores.max(axis=1, keepdims=True)  # numerical stability
        exp_scores = np.exp(scores)
        return exp_scores / exp_scores.sum(axis=1, keepdims=True)

    return proba_from_scores


def explain_prediction(
    text: str,
    model_path: str,
    output_html: str = 'explanation.html'
):
    """
    Generate visual explanation for a prediction.

    Loads a joblib file holding a dict with a 'pipeline' entry (the format
    written by TextClassifier.save_model). Prints the predicted category and
    the most important words, then writes a standalone HTML report.

    Args:
        text: Text to explain
        model_path: Path to trained model
        output_html: Output HTML file

    Author: Molla Samser | RSK World (https://rskworld.in)
    """
    import joblib

    print(f"\n{'='*60}")
    print("Model Explainability - RSK World")
    print(f"Author: {__author__} | Website: {__website__}")
    print(f"{'='*60}\n")

    # NOTE(review): joblib.load unpickles the file - only load trusted models.
    model_data = joblib.load(model_path)
    classifier_fn = _pipeline_proba_fn(model_data['pipeline'])

    explainer = TextExplainer(classifier_fn)

    print(f"Analyzing text: {text[:50]}...")
    explanation = explainer.explain(text)

    if 'error' in explanation:
        print(f"\nCannot explain this text: {explanation['error']}")
        return

    print(f"\nPredicted: {explanation['predicted_category']} "
          f"({explanation['confidence']:.1%} confidence)")
    print("\nTop important words:")
    for item in explanation['word_importance'][:10]:
        direction = "↑" if item['direction'] == 'positive' else "↓"
        print(f"  {direction} {item['word']}: {item['importance']:.4f}")

    # explain_with_html() draws a new set of random perturbations. A fresh
    # explainer with the same default seed replays the perturbations used
    # above, so (given a deterministic model) the HTML highlights match the
    # printed importances instead of coming from a different random draw.
    html = TextExplainer(classifier_fn).explain_with_html(text)

    # Create full HTML document
    full_html = f"""
<!DOCTYPE html>
<html>
<head>
    <title>Prediction Explanation - RSK World</title>
    <meta charset="UTF-8">
    <style>
        body {{
            font-family: 'Segoe UI', Arial, sans-serif;
            background: #0f0a1f;
            color: #f8fafc;
            padding: 40px;
            max-width: 800px;
            margin: 0 auto;
        }}
        h1 {{ color: #dc2626; }}
        .container {{
            background: #1a1333;
            padding: 30px;
            border-radius: 12px;
            border: 1px solid #352d54;
        }}
        .footer {{
            margin-top: 30px;
            text-align: center;
            color: #6b6882;
            font-size: 14px;
        }}
        .legend {{
            margin-top: 20px;
            padding: 15px;
            background: #231d3a;
            border-radius: 8px;
        }}
        .legend span {{
            display: inline-block;
            margin-right: 20px;
        }}
        .positive {{ color: #22c55e; }}
        .negative {{ color: #ef4444; }}
    </style>
</head>
<body>
    <h1>🔍 Prediction Explanation</h1>
    <div class="container">
        {html}
        <div class="legend">
            <strong>Legend:</strong>
            <span class="positive">■ Supports prediction</span>
            <span class="negative">■ Against prediction</span>
        </div>
    </div>
    <div class="footer">
        <p>Generated by RSK World Text Classification</p>
        <p>Author: {__author__} | Website: <a href="{__website__}" style="color: #dc2626;">{__website__}</a></p>
    </div>
</body>
</html>
"""

    with open(output_html, 'w', encoding='utf-8') as f:
        f.write(full_html)

    print(f"\nExplanation saved to: {output_html}")


if __name__ == "__main__":
    import sys

    # Expected CLI: <text> <model.joblib>; anything shorter falls back to demo mode.
    cli_args = sys.argv[1:]
    if len(cli_args) < 2:
        print("Usage: python model_explainability.py 'Your text here' model.joblib")
        print("\nDemo mode:")

        # Demo without model
        banner = '=' * 60
        print(f"\n{banner}")
        print("Explainability Module Demo - RSK World")
        print(f"Author: {__author__} | Website: {__website__}")
        print(f"{banner}")
    else:
        input_text, saved_model = cli_args[0], cli_args[1]
        explain_prediction(input_text, saved_model)

575 lines•17.9 KB
python
scripts/train_classifier.py
Raw Download
Find: Go to:
"""
================================================================================
Text Classification Dataset - Model Training Script
================================================================================
Project: Text Classification Dataset
Category: Text Data / NLP

Author: Molla Samser
Designer & Tester: Rima Khatun
Website: https://rskworld.in
Email: help@rskworld.in | support@rskworld.in
Phone: +91 93305 39277

Copyright (c) 2026 RSK World - All Rights Reserved
Content used for educational purposes only.

Created: December 2026
================================================================================
"""

import os
import json
import argparse
from datetime import datetime
from typing import Dict, List, Tuple, Optional

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    classification_report, confusion_matrix
)
from sklearn.pipeline import Pipeline
import joblib

# Project information (author, website, contact email)
__author__, __website__, __email__ = (
    "Molla Samser",
    "https://rskworld.in",
    "help@rskworld.in",
)


class TextClassifier:
    """
    A flexible text classification system supporting multiple algorithms.

    A TF-IDF vectorizer and one of the supported classifiers are combined in
    a single sklearn Pipeline, so raw texts go in and integer labels come out.

    Supported algorithms:
    - Naive Bayes (MultinomialNB)
    - Logistic Regression
    - Support Vector Machine (LinearSVC)
    - Random Forest

    Author: Molla Samser | RSK World (https://rskworld.in)
    """

    ALGORITHMS = {
        'naive_bayes': MultinomialNB,
        'logistic_regression': LogisticRegression,
        'svm': LinearSVC,
        'random_forest': RandomForestClassifier
    }

    # Integer label -> category name. Its key order is also the row/column
    # order of the report and confusion matrix produced by evaluate().
    CATEGORIES = {
        0: 'Technology',
        1: 'Sports',
        2: 'Politics',
        3: 'Entertainment',
        4: 'Business',
        5: 'Science'
    }

    def __init__(
        self,
        algorithm: str = 'logistic_regression',
        max_features: int = 10000,
        ngram_range: Tuple[int, int] = (1, 2),
        random_state: int = 42
    ):
        """
        Initialize the TextClassifier.

        Args:
            algorithm: Classification algorithm to use (a key of ALGORITHMS)
            max_features: Maximum number of TF-IDF features
            ngram_range: Range of n-grams for feature extraction
            random_state: Random seed for reproducibility

        Raises:
            ValueError: If ``algorithm`` is not a key of ALGORITHMS
        """
        self.algorithm = algorithm
        self.max_features = max_features
        self.ngram_range = ngram_range
        self.random_state = random_state

        if algorithm not in self.ALGORITHMS:
            raise ValueError(
                f"Unknown algorithm: {algorithm}. "
                f"Available: {list(self.ALGORITHMS.keys())}"
            )

        self.vectorizer = TfidfVectorizer(
            max_features=max_features,
            ngram_range=ngram_range,
            stop_words='english',
            sublinear_tf=True
        )

        # class_weight='balanced' compensates for uneven category sizes.
        if algorithm == 'logistic_regression':
            self.classifier = LogisticRegression(
                max_iter=1000,
                random_state=random_state,
                class_weight='balanced'
            )
        elif algorithm == 'svm':
            self.classifier = LinearSVC(
                max_iter=1000,
                random_state=random_state,
                class_weight='balanced'
            )
        elif algorithm == 'random_forest':
            self.classifier = RandomForestClassifier(
                n_estimators=100,
                random_state=random_state,
                class_weight='balanced',
                n_jobs=-1
            )
        else:
            self.classifier = MultinomialNB()

        self.pipeline = Pipeline([
            ('vectorizer', self.vectorizer),
            ('classifier', self.classifier)
        ])

        self.is_trained = False
        self.training_history = {}

    def train(
        self,
        X_train: List[str],
        y_train: List[int],
        X_val: Optional[List[str]] = None,
        y_val: Optional[List[int]] = None
    ) -> Dict:
        """
        Train the classification model.

        Args:
            X_train: Training texts
            y_train: Training labels
            X_val: Validation texts (optional)
            y_val: Validation labels (optional)

        Returns:
            Dictionary containing training metrics (also stored in
            ``training_history``)
        """
        print(f"\n{'='*60}")
        print("Training Text Classifier - RSK World")
        print(f"Algorithm: {self.algorithm}")
        print(f"Author: {__author__} | Website: {__website__}")
        print(f"{'='*60}\n")

        print(f"Training on {len(X_train)} samples...")
        self.pipeline.fit(X_train, y_train)
        self.is_trained = True

        train_preds = self.pipeline.predict(X_train)
        train_accuracy = accuracy_score(y_train, train_preds)

        results = {
            'algorithm': self.algorithm,
            'train_samples': len(X_train),
            'train_accuracy': train_accuracy,
            'timestamp': datetime.now().isoformat()
        }

        print(f"Training Accuracy: {train_accuracy:.4f}")

        if X_val is not None and y_val is not None:
            val_preds = self.pipeline.predict(X_val)
            val_accuracy = accuracy_score(y_val, val_preds)
            val_f1 = f1_score(y_val, val_preds, average='weighted')

            results['val_samples'] = len(X_val)
            results['val_accuracy'] = val_accuracy
            results['val_f1'] = val_f1

            print(f"Validation Accuracy: {val_accuracy:.4f}")
            print(f"Validation F1 Score: {val_f1:.4f}")

        self.training_history = results
        return results

    def evaluate(self, X_test: List[str], y_test: List[int]) -> Dict:
        """
        Evaluate the model on test data.

        Args:
            X_test: Test texts
            y_test: Test labels

        Returns:
            Dictionary containing evaluation metrics. The report and confusion
            matrix always cover every label in CATEGORIES, in key order.

        Raises:
            RuntimeError: If the model has not been trained
        """
        if not self.is_trained:
            raise RuntimeError("Model must be trained before evaluation")

        print(f"\n{'='*60}")
        print("Model Evaluation")
        print(f"{'='*60}\n")

        predictions = self.pipeline.predict(X_test)

        accuracy = accuracy_score(y_test, predictions)
        precision = precision_score(y_test, predictions, average='weighted')
        recall = recall_score(y_test, predictions, average='weighted')
        f1 = f1_score(y_test, predictions, average='weighted')

        # Pass the fixed label set explicitly. Otherwise sklearn infers the
        # classes from the data and raises ValueError when a category is
        # missing from both y_test and the predictions, because target_names
        # would have more entries than the inferred classes.
        label_ids = list(self.CATEGORIES.keys())
        target_names = list(self.CATEGORIES.values())

        report = classification_report(
            y_test, predictions,
            labels=label_ids,
            target_names=target_names,
            output_dict=True,
            zero_division=0
        )

        conf_matrix = confusion_matrix(y_test, predictions, labels=label_ids)

        results = {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1_score': f1,
            'classification_report': report,
            'confusion_matrix': conf_matrix.tolist()
        }

        print(f"Test Accuracy:  {accuracy:.4f}")
        print(f"Test Precision: {precision:.4f}")
        print(f"Test Recall:    {recall:.4f}")
        print(f"Test F1 Score:  {f1:.4f}")
        print("\nClassification Report:")
        print(classification_report(
            y_test, predictions,
            labels=label_ids,
            target_names=target_names,
            zero_division=0
        ))

        return results

    def predict(self, texts: List[str]) -> List[int]:
        """
        Predict categories for new texts.

        Args:
            texts: List of texts to classify

        Returns:
            Predicted labels, as returned by the pipeline's ``predict``

        Raises:
            RuntimeError: If the model has not been trained
        """
        if not self.is_trained:
            raise RuntimeError("Model must be trained before prediction")

        return self.pipeline.predict(texts)

    def predict_with_labels(self, texts: List[str]) -> List[Dict]:
        """
        Predict categories with human-readable labels.

        Args:
            texts: List of texts to classify

        Returns:
            One dict per text with a (possibly truncated) text, the integer
            label and the category name
        """
        predictions = self.predict(texts)
        results = []

        for text, pred in zip(texts, predictions):
            label = int(pred)
            results.append({
                'text': text[:100] + '...' if len(text) > 100 else text,
                'predicted_label': label,
                'predicted_category': self.CATEGORIES[label]
            })

        return results

    def save_model(self, filepath: str):
        """
        Save the trained model and its configuration to disk with joblib.

        Raises:
            RuntimeError: If the model has not been trained
        """
        if not self.is_trained:
            raise RuntimeError("Model must be trained before saving")

        model_data = {
            'pipeline': self.pipeline,
            'algorithm': self.algorithm,
            'max_features': self.max_features,
            'ngram_range': self.ngram_range,
            'random_state': self.random_state,
            'training_history': self.training_history,
            'metadata': {
                'author': __author__,
                'website': __website__,
                'email': __email__,
                'saved_at': datetime.now().isoformat()
            }
        }

        joblib.dump(model_data, filepath)
        print(f"Model saved to: {filepath}")

    @classmethod
    def load_model(cls, filepath: str) -> 'TextClassifier':
        """
        Load a model written by ``save_model``.

        Files saved before ``random_state`` was stored fall back to 42.
        NOTE(review): joblib unpickles the file - only load trusted models.
        """
        model_data = joblib.load(filepath)

        classifier = cls(
            algorithm=model_data['algorithm'],
            max_features=model_data['max_features'],
            ngram_range=model_data['ngram_range'],
            random_state=model_data.get('random_state', 42)
        )

        classifier.pipeline = model_data['pipeline']
        # Point the convenience attributes at the fitted steps of the loaded
        # pipeline rather than the fresh, unfitted ones built by __init__.
        steps = getattr(classifier.pipeline, 'named_steps', {})
        classifier.vectorizer = steps.get('vectorizer', classifier.vectorizer)
        classifier.classifier = steps.get('classifier', classifier.classifier)
        classifier.is_trained = True
        classifier.training_history = model_data.get('training_history', {})

        print(f"Model loaded from: {filepath}")
        return classifier


def _read_csv_skipping_comment_lines(path: str) -> pd.DataFrame:
    """
    Read a UTF-8 CSV file, ignoring whole lines that start with '#'.

    pandas' ``comment='#'`` option also discards everything after a '#'
    inside an unquoted field, which silently truncates texts containing
    hashtags or names like "C#". Only full-line comments are dropped here.
    NOTE(review): a quoted multi-line field whose continuation line starts
    with '#' would also be dropped - confirm the data never contains one.
    """
    import io

    # utf-8-sig strips a leading BOM so a first-line '#' comment is detected.
    with open(path, encoding='utf-8-sig') as handle:
        kept_lines = [line for line in handle if not line.startswith('#')]
    return pd.read_csv(io.StringIO(''.join(kept_lines)))


def load_dataset(data_dir: str) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Load train, validation, and test datasets.

    Reads ``<data_dir>/csv/{train,validation,test}.csv``. Lines beginning with
    '#' (e.g. a license header) are skipped.

    Args:
        data_dir: Path to the data directory

    Returns:
        Tuple of (train_df, val_df, test_df)
    """
    csv_dir = os.path.join(data_dir, 'csv')

    train_df = _read_csv_skipping_comment_lines(os.path.join(csv_dir, 'train.csv'))
    val_df = _read_csv_skipping_comment_lines(os.path.join(csv_dir, 'validation.csv'))
    test_df = _read_csv_skipping_comment_lines(os.path.join(csv_dir, 'test.csv'))

    return train_df, val_df, test_df


def main():
    """Command-line entry point: load data, train, evaluate, save, then demo."""
    parser = argparse.ArgumentParser(
        description='Train Text Classification Model - RSK World'
    )
    cli_options = [
        ('--data-dir', dict(type=str, default='../data',
                            help='Path to data directory')),
        ('--algorithm', dict(type=str, default='logistic_regression',
                             choices=['naive_bayes', 'logistic_regression', 'svm', 'random_forest'],
                             help='Classification algorithm')),
        ('--max-features', dict(type=int, default=10000,
                                help='Maximum TF-IDF features')),
        ('--output', dict(type=str, default='model.joblib',
                          help='Output model path')),
    ]
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    rule = '=' * 60
    print(f"\n{rule}")
    for header_line in (
        "Text Classification Training - RSK World",
        f"Author: {__author__}",
        f"Website: {__website__}",
        f"Email: {__email__}",
    ):
        print(header_line)
    print(f"{rule}\n")

    print("Loading dataset...")
    train_df, val_df, test_df = load_dataset(args.data_dir)
    for split_name, frame in (("Train", train_df), ("Validation", val_df), ("Test", test_df)):
        print(f"{split_name} samples: {len(frame)}")

    def column(frame, name):
        # DataFrame column -> plain Python list, as the classifier API expects.
        return frame[name].tolist()

    model = TextClassifier(
        algorithm=args.algorithm,
        max_features=args.max_features
    )

    model.train(
        X_train=column(train_df, 'text'),
        y_train=column(train_df, 'label'),
        X_val=column(val_df, 'text'),
        y_val=column(val_df, 'label')
    )
    model.evaluate(
        X_test=column(test_df, 'text'),
        y_test=column(test_df, 'label')
    )
    model.save_model(args.output)

    print(f"\n{rule}")
    print("Demo Predictions")
    print(f"{rule}\n")

    sample_texts = [
        "New smartphone features revolutionary camera technology.",
        "Team wins championship in overtime thriller.",
        "Government passes new legislation on climate change."
    ]
    for result in model.predict_with_labels(sample_texts):
        print(f"Text: {result['text']}")
        print(f"Predicted: {result['predicted_category']}")
        print()


if __name__ == '__main__':
    # Run the training CLI only when executed as a script, not on import.
    main()

446 lines•13.8 KB
python

About RSK World

Founded by Molla Samser, with Designer & Tester Rima Khatun, RSK World is your one-stop destination for free programming resources, source code, and development tools.

Founder: Molla Samser
Designer & Tester: Rima Khatun

Development

  • Game Development
  • Web Development
  • Mobile Development
  • AI Development
  • Development Tools

Legal

  • Terms & Conditions
  • Privacy Policy
  • Disclaimer

Contact Info

Nutanhat, Mongolkote
Purba Burdwan, West Bengal
India, 713147

+91 93305 39277

hello@rskworld.in
support@rskworld.in

© 2026 RSK World. All rights reserved.

Content used for educational purposes only. View Disclaimer