scripts/model_explainability.py
"""
================================================================================
Text Classification Dataset - Model Explainability Module
================================================================================
Project: Text Classification Dataset
Category: Text Data / NLP

Author: Molla Samser
Designer & Tester: Rima Khatun
Website: https://rskworld.in
Email: help@rskworld.in | support@rskworld.in
Phone: +91 93305 39277

Copyright (c) 2026 RSK World - All Rights Reserved
Content used for educational purposes only.

Features:
- LIME Text Explainer
- Feature Attribution
- Word Importance Highlighting
- Prediction Confidence Analysis
- Attention Visualization (for transformers)

Created: December 2026
================================================================================
"""

import re
from typing import List, Dict, Tuple, Optional, Callable
import numpy as np

# Project information
__author__ = "Molla Samser"
__website__ = "https://rskworld.in"
__email__ = "help@rskworld.in"

# Category mapping
CATEGORIES = {
    0: 'Technology',
    1: 'Sports',
    2: 'Politics',
    3: 'Entertainment',
    4: 'Business',
    5: 'Science'
}


class TextExplainer:
    """
    Text classification model explainer using a LIME-like approach.
    
    Explains predictions by analyzing word importance through
    perturbation-based feature attribution.
    
    Author: Molla Samser | RSK World (https://rskworld.in)
    """
    
    def __init__(
        self,
        classifier_fn: Callable,
        class_names: Optional[List[str]] = None,
        num_samples: int = 1000,
        random_state: int = 42
    ):
        """
        Initialize the explainer.
        
        Args:
            classifier_fn: Function that takes a list of texts and returns
                an array of class probabilities, one row per text
            class_names: List of class names
            num_samples: Number of perturbations to generate
            random_state: Random seed
        """
        self.classifier_fn = classifier_fn
        self.class_names = class_names or list(CATEGORIES.values())
        self.num_samples = num_samples
        np.random.seed(random_state)
    
    def _tokenize(self, text: str) -> List[str]:
        """Simple word tokenization."""
        # Remove punctuation and split
        text = re.sub(r'[^\w\s]', '', text.lower())
        return text.split()
    
    def _perturb_text(
        self,
        words: List[str],
        num_samples: int
    ) -> Tuple[List[str], np.ndarray]:
        """
        Generate perturbed versions of text by randomly removing words.
        
        Args:
            words: List of words in original text
            num_samples: Number of perturbations
            
        Returns:
            Tuple of (perturbed_texts, perturbation_matrix)
        """
        num_words = len(words)
        
        # Generate binary mask matrix (1 = keep word, 0 = remove)
        # Each row is a perturbation
        perturbation_matrix = np.random.binomial(1, 0.5, size=(num_samples, num_words))
        
        # Always include original text
        perturbation_matrix[0] = 1
        
        # Keep at least one word per sample so the mask always matches the text
        empty_rows = perturbation_matrix.sum(axis=1) == 0
        perturbation_matrix[empty_rows, 0] = 1
        
        perturbed_texts = [
            ' '.join(w for w, keep in zip(words, row) if keep)
            for row in perturbation_matrix
        ]
        
        return perturbed_texts, perturbation_matrix
    
    def _compute_weights(
        self,
        perturbation_matrix: np.ndarray,
        kernel_width: float = 25.0
    ) -> np.ndarray:
        """
        Compute weights for each perturbation based on distance from original.
        
        Args:
            perturbation_matrix: Binary matrix of perturbations
            kernel_width: Width of exponential kernel
            
        Returns:
            Array of weights
        """
        # Distance is number of removed words
        distances = np.sum(perturbation_matrix == 0, axis=1)
        weights = np.exp(-distances ** 2 / kernel_width ** 2)
        return weights
    
    def _fit_linear_model(
        self,
        perturbation_matrix: np.ndarray,
        predictions: np.ndarray,
        weights: np.ndarray,
        target_class: int
    ) -> np.ndarray:
        """
        Fit weighted linear model to get feature importance.
        
        Args:
            perturbation_matrix: Binary feature matrix
            predictions: Model predictions for perturbations
            weights: Sample weights
            target_class: Class to explain
            
        Returns:
            Feature importance scores
        """
        from sklearn.linear_model import Ridge
        
        # Get predictions for target class
        y = predictions[:, target_class]
        
        # Fit weighted ridge regression
        model = Ridge(alpha=1.0)
        model.fit(perturbation_matrix, y, sample_weight=weights)
        
        return model.coef_
    
    def explain(
        self,
        text: str,
        num_features: int = 10,
        target_class: Optional[int] = None
    ) -> Dict:
        """
        Explain a prediction for the given text.
        
        Args:
            text: Text to explain
            num_features: Number of top features to return
            target_class: Class to explain (None = predicted class)
            
        Returns:
            Dictionary with explanation results
        """
        # Tokenize
        words = self._tokenize(text)
        
        if len(words) == 0:
            return {'error': 'Empty text after tokenization'}
        
        # Generate perturbations
        perturbed_texts, perturbation_matrix = self._perturb_text(
            words, self.num_samples
        )
        
        # Get predictions for all perturbations
        # Convert to an array so column indexing works even if a list is returned
        predictions = np.asarray(self.classifier_fn(perturbed_texts))
        
        # Get original prediction
        original_probs = predictions[0]
        predicted_class = np.argmax(original_probs)
        
        if target_class is None:
            target_class = predicted_class
        
        # Compute weights
        weights = self._compute_weights(perturbation_matrix)
        
        # Fit linear model
        importances = self._fit_linear_model(
            perturbation_matrix, predictions, weights, target_class
        )
        
        # Get top features
        top_indices = np.argsort(np.abs(importances))[-num_features:][::-1]
        
        word_importance = [
            {
                'word': words[i],
                'importance': float(importances[i]),
                'direction': 'positive' if importances[i] > 0 else 'negative'
            }
            for i in top_indices if i < len(words)
        ]
        
        return {
            'text': text,
            'predicted_class': int(predicted_class),
            'predicted_category': self.class_names[predicted_class],
            'explained_class': int(target_class),
            'explained_category': self.class_names[target_class],
            'confidence': float(original_probs[predicted_class]),
            'probabilities': {
                self.class_names[i]: float(p)
                for i, p in enumerate(original_probs)
            },
            'word_importance': word_importance,
            'top_positive_words': [
                w for w in word_importance if w['direction'] == 'positive'
            ][:5],
            'top_negative_words': [
                w for w in word_importance if w['direction'] == 'negative'
            ][:5]
        }
    
    def explain_with_html(
        self,
        text: str,
        num_features: int = 10,
        target_class: Optional[int] = None
    ) -> str:
        """
        Generate HTML visualization of explanation.
        
        Args:
            text: Text to explain
            num_features: Number of features
            target_class: Class to explain
            
        Returns:
            HTML string with highlighted text
        """
        explanation = self.explain(text, num_features, target_class)
        
        if 'error' in explanation:
            return f"<p>Error: {explanation['error']}</p>"
        
        # Create word importance lookup
        word_scores = {
            item['word'].lower(): item['importance']
            for item in explanation['word_importance']
        }
        
        # Build HTML
        html_parts = ['<div style="font-family: Arial, sans-serif; line-height: 1.8;">']
        
        from html import escape
        
        words = text.split()
        for word in words:
            clean_word = re.sub(r'[^\w]', '', word.lower())
            safe_word = escape(word)  # Avoid injecting raw markup from the input text
            if clean_word in word_scores:
                score = word_scores[clean_word]
                # Opacity scales with importance; green supports, red opposes
                alpha = min(1, abs(score) * 3)
                if score > 0:
                    color = f'rgba(34, 197, 94, {alpha})'  # Green
                else:
                    color = f'rgba(239, 68, 68, {alpha})'  # Red
                
                html_parts.append(
                    f'<span style="background-color: {color}; padding: 2px 4px; '
                    f'border-radius: 3px; margin: 0 2px;">{safe_word}</span>'
                )
            else:
                html_parts.append(f' {safe_word}')
        
        html_parts.append('</div>')
        
        # Add prediction info
        html_parts.append('<div style="margin-top: 20px; padding: 15px; '
                         'background: #1a1333; border-radius: 8px; color: #f8fafc;">')
        html_parts.append(f'<strong>Prediction:</strong> {explanation["predicted_category"]} '
                         f'({explanation["confidence"]:.1%} confidence)')
        html_parts.append('</div>')
        
        return '\n'.join(html_parts)


class AttentionVisualizer:
    """
    Visualize attention weights from transformer models.
    
    Author: Molla Samser | RSK World (https://rskworld.in)
    """
    
    def __init__(self, model, tokenizer):
        """
        Initialize with transformer model and tokenizer.
        
        Args:
            model: Transformer model
            tokenizer: Tokenizer
        """
        self.model = model
        self.tokenizer = tokenizer
    
    def get_attention_weights(self, text: str) -> Dict:
        """
        Extract attention weights for input text.
        
        Args:
            text: Input text
            
        Returns:
            Dictionary with attention analysis
        """
        import torch
        
        # Tokenize
        inputs = self.tokenizer(
            text,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=512
        )
        
        # Get attention
        with torch.no_grad():
            outputs = self.model(**inputs, output_attentions=True)
        
        # Average attention across layers, batch, and heads
        attentions = outputs.attentions  # Tuple (one per layer) of (batch, heads, seq, seq)
        avg_attention = torch.stack(attentions).mean(dim=(0, 1, 2))  # (seq, seq)
        
        # Attention received by each token, averaged over query positions
        received = avg_attention.mean(dim=0)  # (seq,)
        
        # Get tokens
        tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
        
        # Attention scores per token (excluding special tokens)
        # Assumes BERT-style inputs with one special token at each end
        token_scores = received[1:-1].cpu().numpy()  # Remove [CLS] and [SEP]
        tokens = tokens[1:-1]
        
        return {
            'tokens': tokens,
            'attention_scores': token_scores.tolist(),
            'top_attended': sorted(
                zip(tokens, token_scores.tolist()),
                key=lambda x: x[1],
                reverse=True
            )[:10]
        }


class PredictionAnalyzer:
    """
    Analyze prediction confidence and uncertainty.
    
    Author: Molla Samser | RSK World (https://rskworld.in)
    """
    
    def __init__(self, classifier_fn: Callable, class_names: Optional[List[str]] = None):
        """
        Initialize analyzer.
        
        Args:
            classifier_fn: Function that returns probabilities
            class_names: List of class names
        """
        self.classifier_fn = classifier_fn
        self.class_names = class_names or list(CATEGORIES.values())
    
    def analyze(self, text: str) -> Dict:
        """
        Comprehensive prediction analysis.
        
        Args:
            text: Input text
            
        Returns:
            Analysis results
        """
        # Convert to an array so the entropy arithmetic works even if a list is returned
        probs = np.asarray(self.classifier_fn([text])[0])
        
        predicted_class = np.argmax(probs)
        sorted_indices = np.argsort(probs)[::-1]
        
        # Entropy (uncertainty measure)
        entropy = -np.sum(probs * np.log(probs + 1e-10))
        max_entropy = np.log(len(probs))
        normalized_entropy = entropy / max_entropy
        
        # Margin (difference between top 2 predictions)
        margin = probs[sorted_indices[0]] - probs[sorted_indices[1]]
        
        return {
            'text': text[:200] + '...' if len(text) > 200 else text,
            'prediction': {
                'class': int(predicted_class),
                'category': self.class_names[predicted_class],
                'confidence': float(probs[predicted_class])
            },
            'all_probabilities': {
                self.class_names[i]: float(probs[i])
                for i in sorted_indices
            },
            'uncertainty': {
                'entropy': float(entropy),
                'normalized_entropy': float(normalized_entropy),
                'margin': float(margin),
                # Cast NumPy booleans to plain bool so the result is JSON-serializable
                'is_confident': bool(margin > 0.3),
                'is_uncertain': bool(normalized_entropy > 0.7)
            },
            'ranked_predictions': [
                {
                    'rank': rank + 1,
                    'category': self.class_names[idx],
                    'probability': float(probs[idx])
                }
                for rank, idx in enumerate(sorted_indices)
            ]
        }


def explain_prediction(
    text: str,
    model_path: str,
    output_html: str = 'explanation.html'
):
    """
    Generate visual explanation for a prediction.
    
    Args:
        text: Text to explain
        model_path: Path to trained model
        output_html: Output HTML file
        
    Author: Molla Samser | RSK World (https://rskworld.in)
    """
    import joblib
    
    print(f"\n{'='*60}")
    print("Model Explainability - RSK World")
    print(f"Author: {__author__} | Website: {__website__}")
    print(f"{'='*60}\n")
    
    # Load model
    model_data = joblib.load(model_path)
    pipeline = model_data['pipeline']
    
    # Create classifier function
    def classifier_fn(texts):
        return pipeline.predict_proba(texts)
    
    # Create explainer
    explainer = TextExplainer(classifier_fn)
    
    # Generate explanation
    print(f"Analyzing text: {text[:50]}...")
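    explanation = explainer.explain(text)
    
    # explain() returns only an 'error' key when the text has no tokens
    if 'error' in explanation:
        print(f"Error: {explanation['error']}")
        return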
    
    # Print results
    print(f"\nPredicted: {explanation['predicted_category']} "
          f"({explanation['confidence']:.1%} confidence)")
    print("\nTop important words:")
    for item in explanation['word_importance'][:10]:
        direction = "↑" if item['direction'] == 'positive' else "↓"
        print(f"  {direction} {item['word']}: {item['importance']:.4f}")
    
    # Generate HTML
    html = explainer.explain_with_html(text)
    
    # Create full HTML document
    full_html = f"""
<!DOCTYPE html>
<html>
<head>
    <title>Prediction Explanation - RSK World</title>
    <meta charset="UTF-8">
    <style>
        body {{
            font-family: 'Segoe UI', Arial, sans-serif;
            background: #0f0a1f;
            color: #f8fafc;
            padding: 40px;
            max-width: 800px;
            margin: 0 auto;
        }}
        h1 {{ color: #dc2626; }}
        .container {{
            background: #1a1333;
            padding: 30px;
            border-radius: 12px;
            border: 1px solid #352d54;
        }}
        .footer {{
            margin-top: 30px;
            text-align: center;
            color: #6b6882;
            font-size: 14px;
        }}
        .legend {{
            margin-top: 20px;
            padding: 15px;
            background: #231d3a;
            border-radius: 8px;
        }}
        .legend span {{
            display: inline-block;
            margin-right: 20px;
        }}
        .positive {{ color: #22c55e; }}
        .negative {{ color: #ef4444; }}
    </style>
</head>
<body>
    <h1>🔍 Prediction Explanation</h1>
    <div class="container">
        {html}
        <div class="legend">
            <strong>Legend:</strong>
            <span class="positive">■ Supports prediction</span>
            <span class="negative">■ Against prediction</span>
        </div>
    </div>
    <div class="footer">
        <p>Generated by RSK World Text Classification</p>
        <p>Author: {__author__} | Website: <a href="{__website__}" style="color: #dc2626;">{__website__}</a></p>
    </div>
</body>
</html>
"""
    
    with open(output_html, 'w', encoding='utf-8') as f:
        f.write(full_html)
    
    print(f"\nExplanation saved to: {output_html}")


if __name__ == "__main__":
    import sys
    
    if len(sys.argv) > 2:
        text = sys.argv[1]
        model_path = sys.argv[2]
        explain_prediction(text, model_path)
    else:
        print("Usage: python model_explainability.py 'Your text here' model.joblib")
        print("\nDemo mode:")
        
        # Demo without model
        print(f"\n{'='*60}")
        print("Explainability Module Demo - RSK World")
        print(f"Author: {__author__} | Website: {__website__}")
        print(f"{'='*60}")
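
The uncertainty measures in `PredictionAnalyzer.analyze` (entropy, normalized entropy, and the top-two margin) can be checked by hand. The following is a minimal standalone sketch, not part of the module: `uncertainty_summary` is a hypothetical helper name, it uses only the standard library, and it skips zero probabilities instead of adding the `1e-10` smoothing term used in `analyze`. The 0.3 and 0.7 thresholds are the ones `analyze` applies.

```python
import math
from typing import Dict, Sequence


def uncertainty_summary(probs: Sequence[float]) -> Dict[str, float]:
    """Entropy, normalized entropy, and top-two margin for one probability vector.

    Hypothetical helper for illustration; mirrors the arithmetic in
    PredictionAnalyzer.analyze but is not part of the module.
    """
    if len(probs) < 2:
        raise ValueError("need at least two classes")

    # Shannon entropy in nats; terms with p == 0 contribute nothing
    entropy = -sum(p * math.log(p) for p in probs if p > 0)

    # Divide by the maximum possible entropy, log(num_classes), to get [0, 1]
    normalized = entropy / math.log(len(probs))

    # Margin between the two most probable classes
    top1, top2 = sorted(probs, reverse=True)[:2]

    return {
        "entropy": entropy,
        "normalized_entropy": normalized,
        "margin": top1 - top2,
    }


# One dominant class versus a uniform distribution over six classes
confident = uncertainty_summary([0.9, 0.02, 0.02, 0.02, 0.02, 0.02])
uniform = uncertainty_summary([1 / 6] * 6)

for name, result in (("confident", confident), ("uniform", uniform)):
    is_confident = result["margin"] > 0.3
    is_uncertain = result["normalized_entropy"] > 0.7
    print(f"{name}: margin={result['margin']:.2f}, "
          f"normalized_entropy={result['normalized_entropy']:.2f}, "
          f"is_confident={is_confident}, is_uncertain={is_uncertain}")
```

With these inputs the dominant-class vector is flagged as confident and the uniform vector as uncertain, which matches the intent of the two flags: the margin looks only at the top two classes, while normalized entropy reflects how spread out the whole distribution is.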
