"""
================================================================================
Text Classification Dataset - Data Quality Analyzer
================================================================================
Project: Text Classification Dataset
Category: Text Data / NLP

Author: Molla Samser
Designer & Tester: Rima Khatun
Website: https://rskworld.in
Email: help@rskworld.in | support@rskworld.in
Phone: +91 93305 39277

Copyright (c) 2026 RSK World - All Rights Reserved
Content used for educational purposes only.

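Features:
- Duplicate Detection (exact rows, identical texts, normalized texts)
- Missing Value Analysis
- Class Imbalance Detection
- Text Quality Metrics
- Outlier Detection (IQR rule on word counts)
- Language Detection (planned; not yet implemented in this module)
- Noise Indicators (special-character ratio, URL count, uppercase ratio)
- Data Completeness Scoring
- Potential Label Error Detection (keyword heuristics)
- Automated Recommendations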

Created: December 2026
================================================================================
"""

import re
import string
from typing import Dict, List, Tuple, Optional, Set
from collections import Counter
from datetime import datetime

import numpy as np
import pandas as pd

# Project metadata; author and website appear in the console banner and in
# exported JSON reports.
__author__ = "Molla Samser"
__website__ = "https://rskworld.in"
__email__ = "help@rskworld.in"

# Integer class label -> human-readable category name. Ids are assigned in
# the order listed below, starting at 0.
CATEGORIES = dict(enumerate((
    'Technology',
    'Sports',
    'Politics',
    'Entertainment',
    'Business',
    'Science',
)))


class DataQualityAnalyzer:
    """
    Comprehensive data quality analyzer for text classification datasets.
    
    Analyzes:
    - Data completeness (missing texts/labels, blank texts)
    - Duplicate records (exact rows, identical texts, normalized texts)
    - Class balance (imbalance ratio and normalized label entropy)
    - Text quality metrics (length, special characters, URLs, uppercase)
    - Potential label errors (keyword heuristic)
    - Length outliers (IQR rule on word counts)
    
    Every call to ``analyze`` starts from a clean state, so a single
    instance can be reused for several DataFrames.
    
    Author: Molla Samser | RSK World (https://rskworld.in)
    """

    # Keywords per category id used by the potential-label-error heuristic.
    CATEGORY_KEYWORDS: Dict[int, List[str]] = {
        0: ['apple', 'google', 'microsoft', 'ai', 'tech', 'software', 'app'],
        1: ['football', 'basketball', 'tennis', 'game', 'team', 'player', 'win'],
        2: ['government', 'president', 'congress', 'election', 'vote', 'policy'],
        3: ['movie', 'film', 'music', 'actor', 'netflix', 'entertainment'],
        4: ['stock', 'market', 'business', 'company', 'economy', 'profit'],
        5: ['scientist', 'research', 'nasa', 'space', 'climate', 'discovery']
    }

    # Keywords shorter than this must match a whole word (optionally with a
    # plural "s"); longer keywords may match the start of a word, so "tech"
    # still counts for "technology". Plain substring matching made "ai" hit
    # "said"/"main" and "app" hit "happy", producing spurious flags.
    MIN_PREFIX_KEYWORD_LENGTH = 4

    # Any character that is not an ASCII letter, digit or whitespace.
    _SPECIAL_CHAR_PATTERN = r'[^a-zA-Z0-9\s]'
    _URL_PATTERN = r'https?://\S+|www\.\S+'
    
    def __init__(self, verbose: bool = True):
        """
        Initialize the analyzer.
        
        Args:
            verbose: Print detailed analysis
        """
        self.verbose = verbose
        self.report: Dict = {}
        self.issues: List[str] = []           # problems that hurt quality
        self.warnings: List[str] = []         # softer concerns worth reviewing
        self.recommendations: List[str] = []  # suggested remediation steps
    
    def analyze(
        self,
        df: pd.DataFrame,
        text_column: str = 'text',
        label_column: str = 'label'
    ) -> Dict:
        """
        Perform comprehensive data quality analysis.
        
        State from any previous call (report, issues, warnings,
        recommendations) is discarded first, so the result describes ``df``
        only.
        
        Args:
            df: DataFrame to analyze (must contain at least one row)
            text_column: Name of text column
            label_column: Name of label column
            
        Returns:
            Comprehensive quality report (also stored on ``self.report``)
        
        Raises:
            KeyError: If ``text_column`` or ``label_column`` is not in ``df``.
            ValueError: If ``df`` has no rows.
        """
        missing_columns = [c for c in (text_column, label_column) if c not in df.columns]
        if missing_columns:
            raise KeyError(f"Column(s) not found in DataFrame: {missing_columns}")
        if df.empty:
            raise ValueError("Cannot analyze an empty DataFrame")

        # Reset accumulated state so repeated calls don't duplicate findings.
        self.report = {}
        self.issues = []
        self.warnings = []
        self.recommendations = []

        if self.verbose:
            print(f"\n{'='*60}")
            print("Data Quality Analysis - RSK World")
            print(f"Author: {__author__} | Website: {__website__}")
            print(f"{'='*60}\n")
        
        self._analyze_basic_info(df, text_column, label_column)
        self._analyze_missing_values(df, text_column, label_column)
        self._analyze_duplicates(df, text_column)
        self._analyze_class_balance(df, label_column)
        self._analyze_text_quality(df, text_column)
        self._analyze_outliers(df, text_column)
        self._analyze_potential_label_errors(df, text_column, label_column)
        self._generate_recommendations()
        self._calculate_quality_score()
        
        if self.verbose:
            self._print_summary()
        
        return self.report

    # ------------------------------------------------------------------
    # Small shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _percent(part, total) -> float:
        """Return ``part / total * 100`` as a float, or 0.0 when ``total`` is 0."""
        return float(part) / total * 100 if total else 0.0

    @staticmethod
    def _scalar(value, cast=float):
        """Convert a pandas/numpy statistic to a plain Python number; NaN maps to 0."""
        return cast(value) if pd.notna(value) else cast(0)

    @staticmethod
    def _present_texts(df: pd.DataFrame, text_col: str) -> pd.Series:
        """Return the non-missing values of ``text_col`` converted to ``str``.

        Converting makes the ``.str`` accessor work for object, ``string`` and
        non-string dtypes alike.
        """
        return df[text_col].dropna().astype(str)

    # ------------------------------------------------------------------
    # Individual analyses
    # ------------------------------------------------------------------
    
    def _analyze_basic_info(self, df: pd.DataFrame, text_col: str, label_col: str):
        """Record row count, column names and deep memory usage."""
        self.report['basic_info'] = {
            'total_samples': len(df),
            'columns': list(df.columns),
            'text_column': text_col,
            'label_column': label_col,
            'memory_usage_mb': float(df.memory_usage(deep=True).sum() / 1024**2)
        }
        
        if self.verbose:
            print("1. Basic Information")
            print("-" * 40)
            print(f"   Total samples: {len(df):,}")
            print(f"   Columns: {list(df.columns)}")
            print(f"   Memory usage: {self.report['basic_info']['memory_usage_mb']:.2f} MB\n")
    
    def _analyze_missing_values(self, df: pd.DataFrame, text_col: str, label_col: str):
        """Count missing texts and labels, plus texts that are blank after stripping."""
        missing_text = int(df[text_col].isna().sum())
        missing_label = int(df[label_col].isna().sum())
        empty_text = int(self._present_texts(df, text_col).str.strip().eq('').sum())
        
        self.report['missing_values'] = {
            'missing_text': missing_text,
            'missing_labels': missing_label,
            'empty_text': empty_text,
            # Share of missing cells across the two analyzed columns.
            'missing_percentage': self._percent(missing_text + missing_label, len(df) * 2)
        }
        
        if missing_text > 0 or missing_label > 0:
            self.issues.append(f"Found {missing_text + missing_label} missing values")
        if empty_text > 0:
            self.warnings.append(f"Found {empty_text} empty text entries")
        
        if self.verbose:
            print("2. Missing Values Analysis")
            print("-" * 40)
            print(f"   Missing text: {missing_text}")
            print(f"   Missing labels: {missing_label}")
            print(f"   Empty text: {empty_text}\n")
    
    def _analyze_duplicates(self, df: pd.DataFrame, text_col: str):
        """Count exact-row, identical-text and normalized-text duplicates.

        Missing texts are excluded from the text-level checks; they are already
        reported by the missing-value analysis and are not duplicates of each
        other in any meaningful sense.
        """
        exact_duplicates = int(df.duplicated().sum())

        texts = self._present_texts(df, text_col)
        text_duplicates = int(texts.duplicated().sum())

        # Near-duplicate approximation: case-folded, trimmed, whitespace-collapsed.
        normalized_texts = texts.str.lower().str.strip().str.replace(r'\s+', ' ', regex=True)
        normalized_duplicates = int(normalized_texts.duplicated().sum())
        
        self.report['duplicates'] = {
            'exact_duplicates': exact_duplicates,
            'text_duplicates': text_duplicates,
            'normalized_duplicates': normalized_duplicates,
            'duplicate_percentage': self._percent(text_duplicates, len(df))
        }
        
        if text_duplicates > 0:
            self.warnings.append(f"Found {text_duplicates} duplicate texts")
        
        if self.verbose:
            print("3. Duplicate Analysis")
            print("-" * 40)
            print(f"   Exact duplicates: {exact_duplicates}")
            print(f"   Text duplicates: {text_duplicates}")
            print(f"   After normalization: {normalized_duplicates}\n")
    
    def _analyze_class_balance(self, df: pd.DataFrame, label_col: str):
        """Analyze the distribution of non-missing labels.

        Percentages and entropy are computed over labeled rows only, so missing
        labels do not make the proportions sum to less than 1.
        """
        class_counts = df[label_col].value_counts().sort_index()
        labeled_total = int(class_counts.sum())
        
        # Ratio of largest to smallest class; infinite when a class is empty
        # (e.g. unused categorical levels) or there are no labels at all.
        min_class = class_counts.min() if not class_counts.empty else 0
        imbalance_ratio = float(class_counts.max() / min_class) if min_class > 0 else float('inf')
        
        # Shannon entropy normalized by log2(k): 1.0 means perfectly balanced
        # across the k observed classes.
        proportions = class_counts.to_numpy(dtype=float) / max(labeled_total, 1)
        proportions = proportions[proportions > 0]
        entropy = float(-np.sum(proportions * np.log2(proportions)))
        max_entropy = float(np.log2(len(class_counts))) if len(class_counts) > 1 else 0.0
        normalized_entropy = entropy / max_entropy if max_entropy > 0 else 0.0
        
        self.report['class_balance'] = {
            'class_distribution': {k: int(v) for k, v in class_counts.items()},
            'class_percentages': {k: self._percent(v, labeled_total) for k, v in class_counts.items()},
            'imbalance_ratio': imbalance_ratio,
            'entropy': entropy,
            'normalized_entropy': float(normalized_entropy),
            # Plain bool so JSON export writes true/false, not "True"/"False".
            'is_balanced': bool(imbalance_ratio < 2.0)
        }
        
        if class_counts.empty:
            self.issues.append("No labeled samples found")
        elif imbalance_ratio > 3.0:
            self.issues.append(f"Severe class imbalance (ratio: {imbalance_ratio:.1f}:1)")
        elif imbalance_ratio > 2.0:
            self.warnings.append(f"Moderate class imbalance (ratio: {imbalance_ratio:.1f}:1)")
        
        if self.verbose:
            print("4. Class Balance Analysis")
            print("-" * 40)
            for label, count in class_counts.items():
                cat_name = CATEGORIES.get(label, f'Class {label}')
                print(f"   {cat_name}: {count} ({self._percent(count, labeled_total):.1f}%)")
            print(f"   Imbalance ratio: {imbalance_ratio:.2f}:1")
            print(f"   Balance score: {normalized_entropy*100:.1f}%\n")
    
    def _analyze_text_quality(self, df: pd.DataFrame, text_col: str):
        """Compute length, character-mix and URL statistics over non-missing texts."""
        texts = self._present_texts(df, text_col)
        
        word_counts = texts.str.split().str.len()
        char_counts = texts.str.len()
        # Share of special characters per text; empty texts divide by 1.
        special_char_ratios = texts.str.count(self._SPECIAL_CHAR_PATTERN) / char_counts.clip(lower=1)
        has_url = texts.str.contains(self._URL_PATTERN, regex=True)
        uppercase_ratios = texts.map(lambda s: sum(ch.isupper() for ch in s) / max(len(s), 1))
        
        num = self._scalar  # NaN-safe conversion (stats of an empty Series are NaN)
        very_short = int(word_counts.lt(5).sum())
        texts_with_urls = int(has_url.sum())
        
        self.report['text_quality'] = {
            'word_count': {
                'mean': num(word_counts.mean()),
                'std': num(word_counts.std()),
                'min': num(word_counts.min(), int),
                'max': num(word_counts.max(), int),
                'median': num(word_counts.median())
            },
            'char_count': {
                'mean': num(char_counts.mean()),
                'std': num(char_counts.std()),
                'min': num(char_counts.min(), int),
                'max': num(char_counts.max(), int)
            },
            'special_char_ratio': {
                'mean': num(special_char_ratios.mean()),
                'max': num(special_char_ratios.max())
            },
            'texts_with_urls': texts_with_urls,
            'avg_uppercase_ratio': num(uppercase_ratios.mean()),
            'very_short_texts': very_short,
            'very_long_texts': int(word_counts.gt(100).sum())
        }
        
        if very_short > len(df) * 0.1:
            self.warnings.append(f"{very_short} texts have fewer than 5 words")
        
        if self.verbose:
            stats = self.report['text_quality']['word_count']
            print("5. Text Quality Metrics")
            print("-" * 40)
            print(f"   Average word count: {stats['mean']:.1f}")
            print(f"   Word count std: {stats['std']:.1f}")
            print(f"   Min/Max words: {stats['min']} / {stats['max']}")
            print(f"   Texts with URLs: {texts_with_urls}")
            print(f"   Very short texts (<5 words): {very_short}\n")
    
    def _analyze_outliers(self, df: pd.DataFrame, text_col: str):
        """Flag texts whose word count lies outside the 1.5*IQR fences."""
        word_counts = self._present_texts(df, text_col).str.split().str.len()
        
        if word_counts.empty:
            lower_bound = upper_bound = 0.0
            outliers_low = outliers_high = 0
        else:
            q1, q3 = word_counts.quantile([0.25, 0.75])
            iqr = q3 - q1
            lower_bound = q1 - 1.5 * iqr
            upper_bound = q3 + 1.5 * iqr
            outliers_low = int(word_counts.lt(lower_bound).sum())
            outliers_high = int(word_counts.gt(upper_bound).sum())
        
        self.report['outliers'] = {
            'lower_bound': float(max(0, lower_bound)),
            'upper_bound': float(upper_bound),
            'outliers_low': outliers_low,
            'outliers_high': outliers_high,
            'total_outliers': outliers_low + outliers_high,
            'outlier_percentage': self._percent(outliers_low + outliers_high, len(df))
        }
        
        if outliers_high > len(df) * 0.05:
            self.warnings.append(f"{outliers_high} texts are unusually long")
        
        if self.verbose:
            print("6. Outlier Detection")
            print("-" * 40)
            print(f"   Normal range: {max(0, lower_bound):.0f} - {upper_bound:.0f} words")
            print(f"   Outliers (too short): {outliers_low}")
            print(f"   Outliers (too long): {outliers_high}\n")

    @staticmethod
    def _label_to_id(label) -> Optional[int]:
        """Map a raw label to an integer category id, or None if that is impossible.

        Accepts integral numbers (``1``, ``1.0``), integer strings (``"1"``)
        and category names from ``CATEGORIES`` (case-insensitive). Missing
        values and anything else yield None.
        """
        if isinstance(label, str):
            name = label.strip()
            for cat_id, cat_name in CATEGORIES.items():
                if cat_name.lower() == name.lower():
                    return cat_id
            try:
                return int(name)
            except ValueError:
                return None
        try:
            if pd.isna(label):
                return None
            number = float(label)
        except (TypeError, ValueError):
            return None
        return int(number) if number.is_integer() else None

    @classmethod
    def _keyword_patterns(cls) -> Dict[int, List]:
        """Compile one regex per keyword in ``CATEGORY_KEYWORDS``."""
        patterns = {}
        for cat, keywords in cls.CATEGORY_KEYWORDS.items():
            compiled = []
            for kw in keywords:
                escaped = re.escape(kw)
                if len(kw) < cls.MIN_PREFIX_KEYWORD_LENGTH:
                    compiled.append(re.compile(rf'\b{escaped}s?\b'))
                else:
                    compiled.append(re.compile(rf'\b{escaped}'))
            patterns[cat] = compiled
        return patterns
    
    def _analyze_potential_label_errors(self, df: pd.DataFrame, text_col: str, label_col: str):
        """Detect potential label errors using keyword analysis.

        A row is flagged when another category matches more than two of its
        keywords while the row's own category matches none. Rows with a
        missing text, or a label that cannot be mapped to a category id, are
        skipped instead of crashing the analysis.
        """
        patterns = self._keyword_patterns()
        potential_errors = []
        
        for idx, text, label in zip(df.index, df[text_col], df[label_col]):
            label_id = self._label_to_id(label)
            if label_id is None or pd.isna(text):
                continue
            
            lowered = str(text).lower()
            # Number of distinct keywords of each category present in the text.
            scores = {
                cat: sum(1 for pattern in cat_patterns if pattern.search(lowered))
                for cat, cat_patterns in patterns.items()
            }
            
            best_match = max(scores, key=scores.get)
            if scores[best_match] > 2 and best_match != label_id and scores.get(label_id, 0) == 0:
                potential_errors.append({
                    'index': idx,
                    'current_label': label_id,
                    'suggested_label': int(best_match),
                    'confidence': scores[best_match]
                })
        
        error_rate = self._percent(len(potential_errors), len(df))
        self.report['potential_label_errors'] = {
            'count': len(potential_errors),
            'percentage': error_rate,
            'examples': potential_errors[:10]  # First 10
        }
        
        if len(potential_errors) > len(df) * 0.05:
            self.warnings.append(f"Found {len(potential_errors)} potential label errors")
        
        if self.verbose:
            print("7. Potential Label Errors")
            print("-" * 40)
            print(f"   Potential errors found: {len(potential_errors)}")
            print(f"   Error rate: {error_rate:.1f}%\n")

    # ------------------------------------------------------------------
    # Aggregation and output
    # ------------------------------------------------------------------
    
    def _generate_recommendations(self):
        """Turn the collected metrics into actionable recommendations."""
        r = self.report
        missing = r['missing_values']
        
        if missing['missing_text'] > 0 or missing['empty_text'] > 0:
            self.recommendations.append("Remove or impute missing/empty text values")
        if missing['missing_labels'] > 0:
            self.recommendations.append("Label or remove samples with missing labels")
        
        if r['duplicates']['text_duplicates'] > 0:
            self.recommendations.append("Consider removing duplicate texts to prevent data leakage")
        
        if r['class_balance']['imbalance_ratio'] > 2.0:
            self.recommendations.append("Apply class balancing techniques (oversampling, undersampling, or class weights)")
        
        if r['text_quality']['very_short_texts'] > 10:
            self.recommendations.append("Review very short texts - they may lack sufficient information")
        
        if r['outliers']['total_outliers'] > 10:
            self.recommendations.append("Review outlier texts - consider truncation or removal")
        
        if r['potential_label_errors']['count'] > 5:
            self.recommendations.append("Review flagged potential label errors for correction")
        
        self.report['recommendations'] = self.recommendations
    
    def _calculate_quality_score(self):
        """Score five quality dimensions on 0-100 and average them into a grade."""
        r = self.report
        
        scores = [
            # Each percentage is scaled by a penalty factor and capped at 100.
            ('Completeness', 100 - min(100, r['missing_values']['missing_percentage'] * 10)),
            ('Uniqueness', 100 - min(100, r['duplicates']['duplicate_percentage'] * 5)),
            ('Balance', r['class_balance']['normalized_entropy'] * 100),
            ('Consistency', 100 - min(100, r['outliers']['outlier_percentage'] * 5)),
            ('Label Quality', 100 - min(100, r['potential_label_errors']['percentage'] * 10)),
        ]
        
        overall_score = sum(score for _, score in scores) / len(scores)
        
        self.report['quality_scores'] = {
            'dimensions': dict(scores),
            'overall_score': float(overall_score),
            'grade': self._score_to_grade(overall_score)
        }
    
    def _score_to_grade(self, score: float) -> str:
        """Convert a 0-100 score to a letter grade (A >= 90 ... F < 60)."""
        for threshold, grade in ((90, 'A'), (80, 'B'), (70, 'C'), (60, 'D')):
            if score >= threshold:
                return grade
        return 'F'
    
    def _print_summary(self):
        """Print per-dimension score bars, issues, warnings and recommendations."""
        print("=" * 60)
        print("QUALITY SCORE SUMMARY")
        print("=" * 60)
        
        scores = self.report['quality_scores']
        for dim, score in scores['dimensions'].items():
            filled = int(score // 10)
            bar = '█' * filled + '░' * (10 - filled)
            print(f"   {dim:15} [{bar}] {score:.0f}%")
        
        print("-" * 60)
        print(f"   Overall Score: {scores['overall_score']:.1f}% (Grade: {scores['grade']})")
        
        if self.issues:
            print(f"\n🔴 Issues ({len(self.issues)}):")
            for issue in self.issues:
                print(f"   • {issue}")
        
        if self.warnings:
            print(f"\n🟡 Warnings ({len(self.warnings)}):")
            for warning in self.warnings:
                print(f"   • {warning}")
        
        if self.recommendations:
            print(f"\n💡 Recommendations ({len(self.recommendations)}):")
            for rec in self.recommendations:
                print(f"   • {rec}")
        
        print(f"\n{'='*60}")
        print(f"Analysis complete! | Author: {__author__}")
        print(f"{'='*60}")
    
    def export_report(self, output_path: str = 'quality_report.json'):
        """
        Export the most recent report, issues and warnings to a JSON file.
        
        Values JSON cannot encode natively are written via ``str``.
        
        Args:
            output_path: Destination file path (overwritten if it exists)
        """
        import json
        
        export_data = {
            'metadata': {
                'author': __author__,
                'website': __website__,
                'generated_at': datetime.now().isoformat()
            },
            'report': self.report,
            'issues': self.issues,
            'warnings': self.warnings
        }
        
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(export_data, f, indent=2, default=str)
        
        if self.verbose:
            print(f"Report exported to: {output_path}")


def load_csv_dataset(csv_path: str, encoding: str = 'utf-8') -> pd.DataFrame:
    """
    Read a CSV file, skipping lines that start with ``#``.
    
    Unlike ``pd.read_csv(..., comment='#')``, a ``#`` inside a data field
    (e.g. a hashtag in unquoted text such as ``Great game #NBA``) does not
    truncate the rest of the line; only whole lines beginning with ``#``
    are treated as comments.
    
    Args:
        csv_path: Path to CSV file
        encoding: Text encoding of the file (default ``'utf-8'``)
        
    Returns:
        Parsed DataFrame
    
    Raises:
        FileNotFoundError: If ``csv_path`` does not exist.
    """
    import io
    
    with open(csv_path, encoding=encoding) as fh:
        content = ''.join(line for line in fh if not line.startswith('#'))
    return pd.read_csv(io.StringIO(content))


def analyze_dataset(
    csv_path: str,
    *,
    text_column: str = 'text',
    label_column: str = 'label',
    verbose: bool = True,
) -> Dict:
    """
    Quick function to analyze a CSV dataset.
    
    Lines beginning with ``#`` are skipped, so files with a commented header
    block can be read directly.
    
    Args:
        csv_path: Path to CSV file
        text_column: Name of the text column (default ``'text'``)
        label_column: Name of the label column (default ``'label'``)
        verbose: Print the analysis while it runs (default ``True``)
        
    Returns:
        Quality report dictionary
    
    Raises:
        FileNotFoundError: If ``csv_path`` does not exist.
    """
    df = load_csv_dataset(csv_path)
    analyzer = DataQualityAnalyzer(verbose=verbose)
    return analyzer.analyze(df, text_column=text_column, label_column=label_column)


if __name__ == "__main__":
    # Demo: analyze the training split. An explicit CSV path may be given as
    # the first command-line argument. Otherwise the default is resolved
    # relative to this script (../data/csv/train.csv), not the current
    # working directory, so the demo works wherever it is launched from.
    import sys
    from pathlib import Path
    
    default_csv = Path(__file__).resolve().parent.parent / 'data' / 'csv' / 'train.csv'
    csv_file = Path(sys.argv[1]) if len(sys.argv) > 1 else default_csv
    try:
        analyze_dataset(str(csv_file))
    except FileNotFoundError:
        print(f"Dataset not found: {csv_file}")
        print("Please ensure train.csv exists in ../data/csv/ or pass a CSV path as an argument.")

