notebooks/text_classification_tutorial.ipynb
{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Text Classification Tutorial\n",
        "\n",
        "---\n",
        "\n",
        "## Project Information\n",
        "\n",
        "- **Project:** Text Classification Dataset\n",
        "- **Category:** Text Data / NLP\n",
        "- **Author:** Molla Samser\n",
        "- **Designer & Tester:** Rima Khatun\n",
        "- **Website:** [https://rskworld.in](https://rskworld.in)\n",
        "- **Email:** help@rskworld.in | support@rskworld.in\n",
        "- **Phone:** +91 93305 39277\n",
        "\n",
        "**Copyright (c) 2026 RSK World - All Rights Reserved**\n",
        "\n",
        "---\n",
        "\n",
        "This notebook demonstrates how to:\n",
        "1. Load and explore the text classification dataset\n",
        "2. Preprocess text data\n",
        "3. Train traditional ML classifiers\n",
        "4. Evaluate and compare model performance\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# ================================================================================\n",
        "# Text Classification Dataset - Tutorial Notebook\n",
        "# Author: Molla Samser | Website: https://rskworld.in\n",
        "# Copyright (c) 2026 RSK World - All Rights Reserved\n",
        "# ================================================================================\n",
        "\n",
        "import os\n",
        "import warnings\n",
        "warnings.filterwarnings('ignore')\n",
        "\n",
        "# Data manipulation\n",
        "import pandas as pd\n",
        "import numpy as np\n",
        "\n",
        "# Visualization\n",
        "import matplotlib.pyplot as plt\n",
        "import seaborn as sns\n",
        "\n",
        "# Text processing\n",
        "import re\n",
        "import string\n",
        "from collections import Counter\n",
        "\n",
        "# ML libraries\n",
        "from sklearn.feature_extraction.text import TfidfVectorizer\n",
        "from sklearn.naive_bayes import MultinomialNB\n",
        "from sklearn.linear_model import LogisticRegression\n",
        "from sklearn.svm import LinearSVC\n",
        "from sklearn.metrics import classification_report, confusion_matrix, accuracy_score\n",
        "\n",
        "# Set style\n",
        "plt.style.use('seaborn-v0_8-darkgrid')\n",
        "sns.set_palette('Set2')\n",
        "\n",
        "print(\"Libraries imported successfully!\")\n",
        "print(\"Author: Molla Samser | Website: https://rskworld.in\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# ================================================================================\n",
        "# Load Dataset - RSK World (https://rskworld.in)\n",
        "# ================================================================================\n",
        "\n",
        "DATA_DIR = '../data/csv'\n",
        "\n",
        "# Load datasets\n",
        "train_df = pd.read_csv(f'{DATA_DIR}/train.csv', comment='#')\n",
        "val_df = pd.read_csv(f'{DATA_DIR}/validation.csv', comment='#')\n",
        "test_df = pd.read_csv(f'{DATA_DIR}/test.csv', comment='#')\n",
        "\n",
        "# Category mapping\n",
        "CATEGORIES = {\n",
        "    0: 'Technology', 1: 'Sports', 2: 'Politics',\n",
        "    3: 'Entertainment', 4: 'Business', 5: 'Science'\n",
        "}\n",
        "\n",
        "print(\"Dataset Loaded Successfully!\")\n",
        "print(f\"Training samples: {len(train_df)}\")\n",
        "print(f\"Validation samples: {len(val_df)}\")\n",
        "print(f\"Test samples: {len(test_df)}\")\n",
        "train_df.head()\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# ================================================================================\n",
        "# Text Preprocessing - RSK World (https://rskworld.in)\n",
        "# ================================================================================\n",
        "\n",
        "def preprocess_text(text):\n",
        "    \"\"\"Preprocess text for classification.\"\"\"\n",
        "    text = text.lower()\n",
        "    text = re.sub(r'https?://\\S+|www\\.\\S+', '', text)\n",
        "    text = text.translate(str.maketrans('', '', string.punctuation))\n",
        "    text = ' '.join(text.split())\n",
        "    return text\n",
        "\n",
        "# Apply preprocessing\n",
        "train_df['processed_text'] = train_df['text'].apply(preprocess_text)\n",
        "val_df['processed_text'] = val_df['text'].apply(preprocess_text)\n",
        "test_df['processed_text'] = test_df['text'].apply(preprocess_text)\n",
        "\n",
        "print(\"Preprocessing complete!\")\n",
        "print(f\"\\nOriginal: {train_df['text'].iloc[0][:80]}...\")\n",
        "print(f\"Processed: {train_df['processed_text'].iloc[0][:80]}...\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# ================================================================================\n",
        "# Feature Extraction & Model Training - RSK World (https://rskworld.in)\n",
        "# ================================================================================\n",
        "\n",
        "# TF-IDF Vectorization\n",
        "tfidf = TfidfVectorizer(max_features=10000, ngram_range=(1, 2), stop_words='english')\n",
        "X_train = tfidf.fit_transform(train_df['processed_text'])\n",
        "X_val = tfidf.transform(val_df['processed_text'])\n",
        "X_test = tfidf.transform(test_df['processed_text'])\n",
        "\n",
        "y_train = train_df['label'].values\n",
        "y_val = val_df['label'].values\n",
        "y_test = test_df['label'].values\n",
        "\n",
        "# Train models\n",
        "models = {\n",
        "    'Naive Bayes': MultinomialNB(),\n",
        "    'Logistic Regression': LogisticRegression(max_iter=1000, class_weight='balanced'),\n",
        "    'Linear SVM': LinearSVC(max_iter=1000, class_weight='balanced')\n",
        "}\n",
        "\n",
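        "# Pick the best model on the validation split; the test split is only reported\n",
        "results = {}\n",
        "for name, model in models.items():\n",
        "    model.fit(X_train, y_train)\n",
        "    val_acc = accuracy_score(y_val, model.predict(X_val))\n",
        "    test_acc = accuracy_score(y_test, model.predict(X_test))\n",
        "    results[name] = val_acc\n",
        "    print(f\"{name}: validation {val_acc:.4f} | test {test_acc:.4f}\")\n",
        "\n",
        "print(f\"\\nBest Model (by validation accuracy): {max(results, key=results.get)}\")\n"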
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# ================================================================================\n",
        "# Sample Predictions - RSK World (https://rskworld.in)\n",
        "# ================================================================================\n",
        "\n",
        "lr_model = models['Logistic Regression']\n",
        "\n",
        "sample_texts = [\n",
        "    \"New AI-powered smartphone features revolutionary camera technology.\",\n",
        "    \"Team wins championship in overtime thriller at stadium.\",\n",
        "    \"Government announces new policy to address climate change.\"\n",
        "]\n",
        "\n",
        "print(\"Sample Predictions:\")\n",
        "print(\"=\" * 60)\n",
        "for text in sample_texts:\n",
        "    processed = preprocess_text(text)\n",
        "    features = tfidf.transform([processed])\n",
        "    prediction = lr_model.predict(features)[0]\n",
        "    print(f\"\\nText: {text}\")\n",
        "    print(f\"Predicted: {CATEGORIES[prediction]}\")\n",
        "\n",
        "print(\"\\n\" + \"=\" * 60)\n",
        "print(\"Tutorial Complete!\")\n",
        "print(\"Author: Molla Samser | Website: https://rskworld.in\")\n",
        "print(\"Copyright (c) 2026 RSK World - All Rights Reserved\")\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Advanced Visualizations\n",
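        "The following cells plot the confusion matrix and the dataset's category and word-count distributions, then evaluate the models with cross-validation, a per-class classification report, and the most predictive words per category.\n",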
        "\n",
        "**Author:** Molla Samser | **Website:** [rskworld.in](https://rskworld.in)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# ================================================================================\n",
        "# Confusion Matrix Visualization - RSK World (https://rskworld.in)\n",
        "# ================================================================================\n",
        "\n",
        "from sklearn.metrics import confusion_matrix\n",
        "import matplotlib.pyplot as plt\n",
        "import seaborn as sns\n",
        "\n",
        "# Set dark theme for visualization\n",
        "plt.style.use('dark_background')\n",
        "\n",
        "# Get test-set predictions from the Logistic Regression model\n",
        "best_model = models['Logistic Regression']\n",
        "y_pred = best_model.predict(X_test)\n",
        "\n",
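        "# Create confusion matrix over all six labels so rows/columns match the tick labels\n",
        "cm = confusion_matrix(y_test, y_pred, labels=list(CATEGORIES.keys()))\n",
        "# Row-normalize (per-class recall); rows with no samples stay at zero instead of NaN\n",
        "row_sums = cm.sum(axis=1, keepdims=True)\n",
        "cm_normalized = np.divide(cm, row_sums, out=np.zeros(cm.shape), where=row_sums != 0)\n",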
        "\n",
        "# Plot\n",
        "fig, axes = plt.subplots(1, 2, figsize=(16, 6))\n",
        "\n",
        "# Raw counts\n",
        "sns.heatmap(cm, annot=True, fmt='d', cmap='Reds', \n",
        "            xticklabels=CATEGORIES.values(), yticklabels=CATEGORIES.values(),\n",
        "            ax=axes[0], linewidths=0.5, linecolor='#333')\n",
        "axes[0].set_title('Confusion Matrix (Counts)', fontsize=14, fontweight='bold')\n",
        "axes[0].set_xlabel('Predicted Label')\n",
        "axes[0].set_ylabel('True Label')\n",
        "\n",
        "# Normalized\n",
        "sns.heatmap(cm_normalized, annot=True, fmt='.2%', cmap='Blues',\n",
        "            xticklabels=CATEGORIES.values(), yticklabels=CATEGORIES.values(),\n",
        "            ax=axes[1], linewidths=0.5, linecolor='#333')\n",
        "axes[1].set_title('Confusion Matrix (Normalized)', fontsize=14, fontweight='bold')\n",
        "axes[1].set_xlabel('Predicted Label')\n",
        "axes[1].set_ylabel('True Label')\n",
        "\n",
        "plt.suptitle('Model Performance Analysis - RSK World | rskworld.in', \n",
        "             fontsize=16, fontweight='bold', y=1.02)\n",
        "plt.tight_layout()\n",
        "plt.show()\n",
        "\n",
        "print(f\"\\nConfusion Matrix Analysis Complete!\")\n",
        "print(f\"Author: Molla Samser | Website: https://rskworld.in\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# ================================================================================\n",
        "# Category Distribution & Word Cloud - RSK World (https://rskworld.in)\n",
        "# ================================================================================\n",
        "\n",
        "# Category Distribution\n",
        "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
        "\n",
        "# Color palette for categories\n",
        "category_colors = ['#3b82f6', '#22c55e', '#8b5cf6', '#ec4899', '#f59e0b', '#06b6d4']\n",
        "\n",
        "# Pie chart\n",
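        "# Keep categories in label order so each one gets the same color in every chart\n",
        "category_counts = (train_df['label'].map(CATEGORIES).value_counts()\n",
        "                   .reindex(list(CATEGORIES.values()), fill_value=0))\n",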
        "wedges, texts, autotexts = axes[0].pie(\n",
        "    category_counts.values, \n",
        "    labels=category_counts.index,\n",
        "    autopct='%1.1f%%',\n",
        "    colors=category_colors,\n",
        "    explode=[0.02] * len(category_counts),\n",
        "    shadow=True\n",
        ")\n",
        "axes[0].set_title('Category Distribution', fontsize=14, fontweight='bold')\n",
        "\n",
        "# Bar chart\n",
        "bars = axes[1].bar(category_counts.index, category_counts.values, color=category_colors)\n",
        "axes[1].set_xlabel('Category', fontsize=12)\n",
        "axes[1].set_ylabel('Count', fontsize=12)\n",
        "axes[1].set_title('Documents per Category', fontsize=14, fontweight='bold')\n",
        "axes[1].tick_params(axis='x', rotation=45)\n",
        "\n",
        "# Add value labels on bars\n",
        "for bar in bars:\n",
        "    height = bar.get_height()\n",
        "    axes[1].text(bar.get_x() + bar.get_width()/2., height,\n",
        "                f'{int(height)}', ha='center', va='bottom', fontsize=10, fontweight='bold')\n",
        "\n",
        "plt.suptitle('Dataset Analysis - RSK World | rskworld.in', fontsize=16, fontweight='bold', y=1.02)\n",
        "plt.tight_layout()\n",
        "plt.show()\n",
        "\n",
        "# Word Count Distribution\n",
        "train_df['word_count'] = train_df['text'].str.split().str.len()\n",
        "\n",
        "fig, ax = plt.subplots(figsize=(12, 5))\n",
        "for i, category in enumerate(CATEGORIES.values()):\n",
        "    data = train_df[train_df['label'].map(CATEGORIES) == category]['word_count']\n",
        "    ax.hist(data, bins=20, alpha=0.6, label=category, color=category_colors[i])\n",
        "\n",
        "ax.set_xlabel('Word Count', fontsize=12)\n",
        "ax.set_ylabel('Frequency', fontsize=12)\n",
        "ax.set_title('Word Count Distribution by Category', fontsize=14, fontweight='bold')\n",
        "ax.legend(loc='upper right')\n",
        "plt.tight_layout()\n",
        "plt.show()\n",
        "\n",
        "print(f\"\\nDataset statistics:\")\n",
        "print(f\"Total documents: {len(train_df)}\")\n",
        "print(f\"Average word count: {train_df['word_count'].mean():.1f}\")\n",
        "print(f\"Min word count: {train_df['word_count'].min()}\")\n",
        "print(f\"Max word count: {train_df['word_count'].max()}\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# ================================================================================\n",
        "# Advanced Model Evaluation & Feature Analysis - RSK World (https://rskworld.in)\n",
        "# ================================================================================\n",
        "\n",
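        "from sklearn.model_selection import cross_val_score, StratifiedKFold\n",
        "from sklearn.pipeline import make_pipeline\n",
        "from sklearn.metrics import classification_report\n",
        "\n",
        "# Cross-validation analysis\n",
        "print(\"=\" * 60)\n",
        "print(\"Cross-Validation Analysis\")\n",
        "print(\"=\" * 60)\n",
        "\n",
        "skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)\n",
        "\n",
        "for name, model in models.items():\n",
        "    # cross_val_score clones and refits the estimator on every fold, so the models\n",
        "    # fitted above are not reused. Putting TF-IDF inside the pipeline refits the\n",
        "    # vocabulary and IDF weights per fold, keeping the held-out fold unseen.\n",
        "    pipeline = make_pipeline(\n",
        "        TfidfVectorizer(max_features=10000, ngram_range=(1, 2), stop_words='english'),\n",
        "        model\n",
        "    )\n",
        "    scores = cross_val_score(pipeline, train_df['processed_text'], y_train,\n",
        "                             cv=skf, scoring='f1_macro')\n",
        "    print(f\"\\n{name}:\")\n",
        "    print(f\"  Mean F1 Score: {scores.mean():.4f} (+/- {scores.std() * 2:.4f})\")\n",
        "    print(f\"  All Folds: {[f'{s:.4f}' for s in scores]}\")\n",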
        "\n",
        "# Detailed classification report for Logistic Regression on the test set\n",
        "print(\"\\n\" + \"=\" * 60)\n",
        "print(\"Detailed Classification Report - Logistic Regression\")\n",
        "print(\"=\" * 60)\n",
        "print(classification_report(\n",
        "    y_test, \n",
        "    models['Logistic Regression'].predict(X_test),\n",
        "    target_names=list(CATEGORIES.values())\n",
        "))\n",
        "\n",
        "# Top features per category\n",
        "print(\"\\n\" + \"=\" * 60)\n",
        "print(\"Top Predictive Words per Category\")\n",
        "print(\"=\" * 60)\n",
        "\n",
        "feature_names = tfidf.get_feature_names_out()\n",
        "lr_model = models['Logistic Regression']\n",
        "\n",
        "# Rows of coef_ follow lr_model.classes_, so map each row back through CATEGORIES\n",
        "for row, label in enumerate(lr_model.classes_):\n",
        "    top_indices = np.argsort(lr_model.coef_[row])[-10:][::-1]\n",
        "    top_words = [feature_names[idx] for idx in top_indices]\n",
        "    print(f\"\\n{CATEGORIES[label]}:\")\n",
        "    print(f\"  {', '.join(top_words)}\")\n",
        "\n",
        "print(\"\\n\" + \"=\" * 60)\n",
        "print(\"Advanced Analysis Complete!\")\n",
        "print(\"Author: Molla Samser | Website: https://rskworld.in\")\n",
        "print(\"Copyright (c) 2026 RSK World - All Rights Reserved\")\n",
        "print(\"=\" * 60)\n"
      ]
    }
  ],
  "metadata": {
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 2
}
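To see what the notebook's `preprocess_text` cell does to a realistic headline, the sketch below copies that function unchanged and runs it on a made-up sentence (the example string and URL are illustrative, not taken from the dataset). Note that stripping punctuation also deletes hyphens, so `AI-powered` becomes the single token `aipowered` rather than two words.

```python
import re
import string


def preprocess_text(text):
    """Same steps as the notebook cell: lowercase, drop URLs, strip punctuation, collapse spaces."""
    text = text.lower()
    text = re.sub(r'https?://\S+|www\.\S+', '', text)
    text = text.translate(str.maketrans('', '', string.punctuation))
    text = ' '.join(text.split())
    return text


sample = "New AI-powered smartphone! Details at https://example.com today."
print(preprocess_text(sample))
```

If hyphenated terms matter for a category, mapping punctuation to spaces instead of deleting it, for example with `str.maketrans(string.punctuation, ' ' * len(string.punctuation))`, keeps the parts as separate tokens.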
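The feature-extraction cell fits `TfidfVectorizer` on the training text only and then calls `transform` on the validation and test text. As a rough illustration of that split, the stdlib sketch below mimics scikit-learn's default weighting (raw term counts, smoothed IDF `ln((1 + n) / (1 + df)) + 1`, then L2 normalization) for whitespace-separated unigrams only. The notebook's bigrams, English stop-word list, and `max_features` cap are deliberately left out, and `fit_idf` and `transform` are hypothetical helper names. Terms never seen in training have no IDF entry, so they are dropped at transform time.

```python
import math
from collections import Counter


def fit_idf(train_docs):
    """Learn smoothed IDF weights from whitespace-tokenized training documents."""
    n_docs = len(train_docs)
    # Document frequency: in how many training documents each term appears
    df = Counter(term for doc in train_docs for term in set(doc.split()))
    return {term: math.log((1 + n_docs) / (1 + count)) + 1 for term, count in df.items()}


def transform(doc, idf):
    """Map one document to an L2-normalized {term: weight} dict using a fitted IDF table."""
    # Out-of-vocabulary terms are ignored, as they are for a fitted vectorizer
    counts = Counter(term for term in doc.split() if term in idf)
    weights = {term: tf * idf[term] for term, tf in counts.items()}
    norm = math.sqrt(sum(w * w for w in weights.values()))
    return {term: w / norm for term, w in weights.items()} if norm else {}


train_docs = ["team wins match", "party wins vote", "team signs player"]
idf = fit_idf(train_docs)
print(transform("team wins election", idf))
```

Here `election` disappears because it only occurs in the "test" document, which is why the vectorizer must never be refit on validation or test text.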
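The confusion-matrix cell divides each row by its total, so the diagonal of the normalized panel is per-class recall. A class with no test samples makes that row total zero. The stdlib sketch below builds the matrix over a fixed label list and leaves empty rows at zero instead of dividing by zero; `confusion_counts` and `row_normalize` are illustrative names, not scikit-learn functions.

```python
def confusion_counts(y_true, y_pred, labels):
    """Rows are true labels, columns are predicted labels (scikit-learn's convention)."""
    index = {label: i for i, label in enumerate(labels)}
    matrix = [[0] * len(labels) for _ in labels]
    for t, p in zip(y_true, y_pred):
        matrix[index[t]][index[p]] += 1
    return matrix


def row_normalize(matrix):
    """Divide each row by its sum; rows that sum to zero stay all zeros."""
    normalized = []
    for row in matrix:
        total = sum(row)
        normalized.append([c / total if total else 0.0 for c in row])
    return normalized


y_true = [0, 0, 1, 1, 1]
y_pred = [0, 1, 1, 1, 0]
cm = confusion_counts(y_true, y_pred, labels=[0, 1, 2])
print(cm)
print(row_normalize(cm))
```

Label 2 never occurs, yet it still gets a row and a column, so the matrix stays aligned with a fixed list of tick labels.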
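`StratifiedKFold(n_splits=5, shuffle=True, random_state=42)` in the evaluation cell keeps each fold's class mix close to the mix of the whole training set. The sketch below shows the idea with a simple per-class round-robin assignment. It is not scikit-learn's exact algorithm, so its folds will differ from the notebook's, and `stratified_folds` is a hypothetical helper.

```python
import random
from collections import Counter, defaultdict


def stratified_folds(labels, n_splits, seed=0):
    """Return a fold id per sample, dealing each class's shuffled indices round-robin."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for i, label in enumerate(labels):
        by_class[label].append(i)
    fold_of = [None] * len(labels)
    offset = 0
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        for position, i in enumerate(indices):
            # Continue from where the previous class stopped so fold sizes stay balanced
            fold_of[i] = (offset + position) % n_splits
        offset += len(indices)
    return fold_of


labels = [0] * 10 + [1] * 5 + [2] * 5
folds = stratified_folds(labels, n_splits=5)
for k in range(5):
    print(k, Counter(labels[i] for i in range(len(labels)) if folds[i] == k))
```

With ten samples of class 0 and five each of classes 1 and 2, every fold receives two, one, and one of them respectively.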
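The cross-validation cell scores with `scoring='f1_macro'`: F1 is computed separately for each class and then averaged without weights, so a small category counts as much as a large one. A minimal stdlib version follows, assuming integer labels; `macro_f1` is an illustrative name, and it averages over every label seen in either the true or the predicted labels.

```python
from collections import Counter


def macro_f1(y_true, y_pred):
    """Unweighted mean of per-class F1 over every label seen in y_true or y_pred."""
    tp, fp, fn = Counter(), Counter(), Counter()
    for t, p in zip(y_true, y_pred):
        if t == p:
            tp[t] += 1
        else:
            fp[p] += 1
            fn[t] += 1
    labels = sorted(set(y_true) | set(y_pred))
    # F1 = 2TP / (2TP + FP + FN); the denominator is positive for every label in the union
    scores = [2 * tp[c] / (2 * tp[c] + fp[c] + fn[c]) for c in labels]
    return sum(scores) / len(scores)


y_true = [0, 0, 0, 0, 1, 2]
y_pred = [0, 0, 0, 0, 2, 2]
print(round(macro_f1(y_true, y_pred), 4))
```

On this toy example accuracy is 5/6 (about 0.83), but macro F1 is 5/9 (about 0.56) because the minority class 1 is never predicted correctly.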
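The feature-analysis cell uses `np.argsort(row)[-10:][::-1]`, which yields the indices of the ten largest coefficients in descending order. A larger coefficient raises that class's score when the term appears, so these are the words that most favor the class. A stdlib equivalent with made-up coefficients is sketched below; `top_k_terms` is a hypothetical name, and the order of tied coefficients may differ from NumPy's.

```python
def top_k_terms(coefs, feature_names, k=10):
    """Return the k feature names with the largest coefficients, largest first."""
    order = sorted(range(len(coefs)), key=lambda i: coefs[i], reverse=True)
    return [feature_names[i] for i in order[:k]]


# Illustrative values only, not coefficients from the trained model
feature_names = ["goal", "vote", "match", "chip", "league"]
sports_row = [2.1, -0.8, 1.4, -1.2, 0.9]
print(top_k_terms(sports_row, feature_names, k=3))
```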