Rust Web Server - High-Performance Async Web Server + WebSocket Support + JWT Authentication + File Upload + Memory Safety + Educational Design
src
  • auth.rs (15.7 KB)
  • config.rs (2.9 KB)
  • error.rs (5.2 KB)
  • file_upload.rs (19 KB)
  • handlers.rs (12.8 KB)
  • lib.rs (1.8 KB)
  • main.rs (6 KB)
  • middleware.rs (6.2 KB)
  • static_files.rs (9.9 KB)
  • utils.rs (9.6 KB)
  • websocket.rs (15.3 KB)
src/auth.rs
/*
 * Authentication Module - Rust Web Server
 * 
 * Created by RSK World (https://rskworld.in)
 * Founder: Molla Samser
 * Designer & Tester: Rima Khatun
 * 
 * Contact:
 * - Email: hello@rskworld.in, support@rskworld.in
 * - Phone: +91 93305 39277
 * - Address: Nutanhat, Mongolkote, Purba Burdwan, West Bengal, India, 713147
 * 
 * © 2026 RSK World. All rights reserved.
 * Content used for educational purposes only.
 */

use bcrypt::{hash, verify, DEFAULT_COST};
use chrono::{Duration, Utc};
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
use uuid::Uuid;

use crate::error::{ServerError, ServerResult};

/// User role types
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum UserRole {
    Admin,
    User,
    Guest,
}

/// User information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct User {
    pub id: String,
    pub username: String,
    pub email: String,
    pub password_hash: String,
    pub role: UserRole,
    pub created_at: String,
    pub last_login: Option<String>,
    pub is_active: bool,
}

/// Login request
#[derive(Debug, Deserialize)]
pub struct LoginRequest {
    pub username: String,
    pub password: String,
}

/// Register request
#[derive(Debug, Deserialize)]
pub struct RegisterRequest {
    pub username: String,
    pub email: String,
    pub password: String,
    pub confirm_password: String,
}

/// Authentication response
#[derive(Debug, Serialize)]
pub struct AuthResponse {
    pub token: String,
    pub user: UserInfo,
    pub expires_in: i64,
}

/// Public user information
#[derive(Debug, Serialize)]
pub struct UserInfo {
    pub id: String,
    pub username: String,
    pub email: String,
    pub role: UserRole,
    pub created_at: String,
}

/// JWT Claims
#[derive(Debug, Serialize, Deserialize)]
pub struct Claims {
    pub sub: String, // Subject (user ID)
    pub username: String,
    pub email: String,
    pub role: UserRole,
    pub exp: i64, // Expiration time
    pub iat: i64, // Issued at
    pub jti: String, // JWT ID
}

/// Session information
#[derive(Debug, Clone)]
pub struct Session {
    pub user_id: String,
    pub token: String,
    pub expires_at: chrono::DateTime<Utc>,
    pub created_at: chrono::DateTime<Utc>,
}

/// Authentication manager
pub struct AuthManager {
    /// In-memory user storage (in production, use a database)
    users: Arc<RwLock<HashMap<String, User>>>,
    /// Active sessions
    sessions: Arc<RwLock<HashMap<String, Session>>>,
    /// JWT secret key
    jwt_secret: String,
    /// Token expiration time (hours)
    token_expiration_hours: i64,
}

impl AuthManager {
    /// Create a new authentication manager
    pub fn new(jwt_secret: String, token_expiration_hours: i64) -> Self {
        Self {
            users: Arc::new(RwLock::new(HashMap::new())),
            sessions: Arc::new(RwLock::new(HashMap::new())),
            jwt_secret,
            token_expiration_hours,
        }
    }

    /// Register a new user
    pub async fn register(&self, request: RegisterRequest) -> ServerResult<AuthResponse> {
        // Validate request
        if request.password != request.confirm_password {
            return Err(ServerError::BadRequest("Passwords do not match".to_string()));
        }

        if request.password.len() < 8 {
            return Err(ServerError::BadRequest("Password must be at least 8 characters".to_string()));
        }

        let mut users = self.users.write().await;

        // Check if username already exists
        if users.values().any(|u| u.username == request.username) {
            return Err(ServerError::BadRequest("Username already exists".to_string()));
        }

        // Check if email already exists
        if users.values().any(|u| u.email == request.email) {
            return Err(ServerError::BadRequest("Email already exists".to_string()));
        }

        // Hash password
        let password_hash = hash(&request.password, DEFAULT_COST)
            .map_err(|e| ServerError::Internal(format!("Failed to hash password: {}", e)))?;

        // Create user
        let user = User {
            id: Uuid::new_v4().to_string(),
            username: request.username.clone(),
            email: request.email.clone(),
            password_hash,
            role: UserRole::User,
            created_at: Utc::now().to_rfc3339(),
            last_login: None,
            is_active: true,
        };

        users.insert(user.id.clone(), user.clone());
        drop(users);

        info!("User registered: {}", request.username);

        // Generate token
        let token = self.generate_token(&user)?;
        let expires_in = self.token_expiration_hours * 3600; // Convert to seconds

        // Store session
        let session = Session {
            user_id: user.id.clone(),
            token: token.clone(),
            expires_at: Utc::now() + Duration::hours(self.token_expiration_hours),
            created_at: Utc::now(),
        };

        let mut sessions = self.sessions.write().await;
        sessions.insert(token.clone(), session);
        drop(sessions);

        Ok(AuthResponse {
            token,
            user: UserInfo {
                id: user.id,
                username: user.username,
                email: user.email,
                role: user.role,
                created_at: user.created_at,
            },
            expires_in,
        })
    }

    /// Login user
    pub async fn login(&self, request: LoginRequest) -> ServerResult<AuthResponse> {
        let users = self.users.read().await;

        // Find user by username. Clone the record so the read lock can be
        // released before the (deliberately slow) bcrypt check and before the
        // write lock taken below to update `last_login`.
        let user = users
            .values()
            .find(|u| u.username == request.username && u.is_active)
            .cloned()
            .ok_or_else(|| ServerError::Unauthorized("Invalid credentials".to_string()))?;
        drop(users);

        // Verify password
        if !verify(&request.password, &user.password_hash).unwrap_or(false) {
            return Err(ServerError::Unauthorized("Invalid credentials".to_string()));
        }

        info!("User logged in: {}", request.username);

        // Generate token
        let token = self.generate_token(&user)?;
        let expires_in = self.token_expiration_hours * 3600; // Convert to seconds

        // Store session
        let session = Session {
            user_id: user.id.clone(),
            token: token.clone(),
            expires_at: Utc::now() + Duration::hours(self.token_expiration_hours),
            created_at: Utc::now(),
        };

        let mut sessions = self.sessions.write().await;
        sessions.insert(token.clone(), session);
        drop(sessions);

        // Update last login
        let mut users = self.users.write().await;
        if let Some(user_mut) = users.get_mut(&user.id) {
            user_mut.last_login = Some(Utc::now().to_rfc3339());
        }

        Ok(AuthResponse {
            token,
            user: UserInfo {
                id: user.id,
                username: user.username.clone(),
                email: user.email,
                role: user.role,
                created_at: user.created_at,
            },
            expires_in,
        })
    }

    /// Logout user
    pub async fn logout(&self, token: &str) -> ServerResult<()> {
        let mut sessions = self.sessions.write().await;
        
        if let Some(session) = sessions.remove(token) {
            info!("User logged out: {}", session.user_id);
            Ok(())
        } else {
            Err(ServerError::BadRequest("Invalid session".to_string()))
        }
    }

    /// Validate token and return user
    pub async fn validate_token(&self, token: &str) -> ServerResult<User> {
        // Check if session exists and is not expired
        let sessions = self.sessions.read().await;
        
        if let Some(session) = sessions.get(token) {
            if session.expires_at < Utc::now() {
                return Err(ServerError::Unauthorized("Token expired".to_string()));
            }
        } else {
            return Err(ServerError::Unauthorized("Invalid token".to_string()));
        }
        drop(sessions);

        // Decode JWT
        let token_data = decode::<Claims>(
            token,
            &DecodingKey::from_secret(self.jwt_secret.as_ref()),
            &Validation::default(),
        )
        .map_err(|e| ServerError::Unauthorized(format!("Invalid token: {}", e)))?;

        // Get user
        let users = self.users.read().await;
        let user = users
            .get(&token_data.claims.sub)
            .filter(|u| u.is_active)
            .ok_or_else(|| ServerError::Unauthorized("User not found or inactive".to_string()))?;

        Ok(user.clone())
    }

    /// Generate JWT token
    fn generate_token(&self, user: &User) -> ServerResult<String> {
        let now = Utc::now();
        let exp = now + Duration::hours(self.token_expiration_hours);

        let claims = Claims {
            sub: user.id.clone(),
            username: user.username.clone(),
            email: user.email.clone(),
            role: user.role.clone(),
            exp: exp.timestamp(),
            iat: now.timestamp(),
            jti: Uuid::new_v4().to_string(),
        };

        encode(
            &Header::default(),
            &claims,
            &EncodingKey::from_secret(self.jwt_secret.as_ref()),
        )
        .map_err(|e| ServerError::Internal(format!("Failed to generate token: {}", e)))
    }

    /// Create admin user (for initial setup)
    pub async fn create_admin_user(&self, username: &str, email: &str, password: &str) -> ServerResult<()> {
        let mut users = self.users.write().await;

        // Check if admin already exists
        if users.values().any(|u| u.role == UserRole::Admin) {
            warn!("Admin user already exists");
            return Ok(());
        }

        // Hash password
        let password_hash = hash(password, DEFAULT_COST)
            .map_err(|e| ServerError::Internal(format!("Failed to hash password: {}", e)))?;

        // Create admin user
        let admin = User {
            id: Uuid::new_v4().to_string(),
            username: username.to_string(),
            email: email.to_string(),
            password_hash,
            role: UserRole::Admin,
            created_at: Utc::now().to_rfc3339(),
            last_login: None,
            is_active: true,
        };

        users.insert(admin.id.clone(), admin);
        drop(users);

        info!("Admin user created: {}", username);
        Ok(())
    }

    /// Get user by ID
    pub async fn get_user(&self, user_id: &str) -> ServerResult<User> {
        let users = self.users.read().await;
        users
            .get(user_id)
            .filter(|u| u.is_active)
            .cloned()
            .ok_or_else(|| ServerError::NotFound("User not found".to_string()))
    }

    /// Update user role (admin only)
    pub async fn update_user_role(&self, user_id: &str, new_role: UserRole) -> ServerResult<()> {
        let mut users = self.users.write().await;
        
        if let Some(user) = users.get_mut(user_id) {
            // Log before assigning: `new_role` is moved into `user.role`.
            info!("User role updated: {} -> {:?}", user_id, new_role);
            user.role = new_role;
            Ok(())
        } else {
            Err(ServerError::NotFound("User not found".to_string()))
        }
    }

    /// Clean up expired sessions
    pub async fn cleanup_expired_sessions(&self) {
        let mut sessions = self.sessions.write().await;
        let now = Utc::now();

        // Keep only sessions that have not expired yet.
        let before = sessions.len();
        sessions.retain(|_, session| session.expires_at >= now);
        let removed = before - sessions.len();

        if removed > 0 {
            debug!("Cleaned up {} expired sessions", removed);
        }
    }

    /// Get active sessions count
    pub async fn get_active_sessions_count(&self) -> usize {
        self.sessions.read().await.len()
    }

    /// Get total users count
    pub async fn get_total_users_count(&self) -> usize {
        self.users.read().await.len()
    }
}

/// Middleware for authentication
pub async fn auth_middleware(
    auth_manager: Arc<AuthManager>,
    token: Option<String>,
) -> ServerResult<User> {
    // Only an explicitly supplied token is handled here; reading it from an
    // `Authorization: Bearer <token>` header is left to the caller.
    let token = token.ok_or_else(|| ServerError::Unauthorized("No token provided".to_string()))?;

    auth_manager.validate_token(&token).await
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_user_registration() {
        let auth = AuthManager::new("test_secret".to_string(), 24);
        
        let request = RegisterRequest {
            username: "testuser".to_string(),
            email: "test@example.com".to_string(),
            password: "password123".to_string(),
            confirm_password: "password123".to_string(),
        };

        let result = auth.register(request).await;
        assert!(result.is_ok());
        
        let response = result.unwrap();
        assert!(!response.token.is_empty());
        assert_eq!(response.user.username, "testuser");
        assert_eq!(response.user.role, UserRole::User);
    }

    #[tokio::test]
    async fn test_user_login() {
        let auth = AuthManager::new("test_secret".to_string(), 24);
        
        // Register user first
        let register_request = RegisterRequest {
            username: "testuser".to_string(),
            email: "test@example.com".to_string(),
            password: "password123".to_string(),
            confirm_password: "password123".to_string(),
        };
        auth.register(register_request).await.unwrap();

        // Login
        let login_request = LoginRequest {
            username: "testuser".to_string(),
            password: "password123".to_string(),
        };

        let result = auth.login(login_request).await;
        assert!(result.is_ok());
        
        let response = result.unwrap();
        assert!(!response.token.is_empty());
        assert_eq!(response.user.username, "testuser");
    }

    #[tokio::test]
    async fn test_invalid_login() {
        let auth = AuthManager::new("test_secret".to_string(), 24);
        
        let login_request = LoginRequest {
            username: "nonexistent".to_string(),
            password: "wrongpassword".to_string(),
        };

        let result = auth.login(login_request).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_token_validation() {
        let auth = AuthManager::new("test_secret".to_string(), 24);
        
        // Register and login user
        let register_request = RegisterRequest {
            username: "testuser".to_string(),
            email: "test@example.com".to_string(),
            password: "password123".to_string(),
            confirm_password: "password123".to_string(),
        };
        let auth_response = auth.register(register_request).await.unwrap();

        // Validate token
        let result = auth.validate_token(&auth_response.token).await;
        assert!(result.is_ok());
        
        let user = result.unwrap();
        assert_eq!(user.username, "testuser");
    }
}
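The session bookkeeping in `AuthManager` (store a session with an expiry at login, reject unknown or expired tokens in `validate_token`, prune them in `cleanup_expired_sessions`) can be modeled without the async and JWT layers. The following is a minimal std-only sketch under stated assumptions: `SessionStore`, `insert`, `user_for`, and `cleanup` are hypothetical names, times are plain Unix seconds passed in explicitly instead of `chrono::Utc::now()`, and the JWT signature check is left out entirely.

```rust
use std::collections::HashMap;

/// Hypothetical std-only model of `AuthManager`'s session map:
/// token -> (user_id, expires_at), with times in Unix seconds.
struct SessionStore {
    sessions: HashMap<String, (String, u64)>,
}

impl SessionStore {
    fn new() -> Self {
        Self { sessions: HashMap::new() }
    }

    /// Record a session that expires `ttl_hours` after `now` (as `login` does).
    fn insert(&mut self, token: &str, user_id: &str, now: u64, ttl_hours: u64) {
        let expires_at = now + ttl_hours * 3600; // hours -> seconds
        self.sessions
            .insert(token.to_string(), (user_id.to_string(), expires_at));
    }

    /// Same order of checks as `validate_token`: unknown token first, then expiry.
    fn user_for(&self, token: &str, now: u64) -> Result<&str, &'static str> {
        match self.sessions.get(token) {
            None => Err("Invalid token"),
            Some(session) if session.1 < now => Err("Token expired"),
            Some(session) => Ok(session.0.as_str()),
        }
    }

    /// Drop expired sessions and return how many were removed.
    fn cleanup(&mut self, now: u64) -> usize {
        let before = self.sessions.len();
        self.sessions.retain(|_, session| session.1 >= now);
        before - self.sessions.len()
    }
}
```

With a one-hour lifetime starting at t = 0, a token is accepted at t = 100, reported as expired at t = 7200, and removed by a cleanup run at that time. Passing `now` as a parameter keeps the expiry logic deterministic and easy to test; the real module reads the clock itself.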
src/file_upload.rs
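`FileManager::handle_upload` reads a `multipart/form-data` body in three steps: take the boundary from the `Content-Type` header, split the body on `--boundary` delimiters, and separate each part's headers from its bytes. The sketch below shows those steps with the standard library only. It is an illustration under assumptions, not the module's code: `extract_boundary`, `find_bytes`, `split_parts`, and `safe_filename` are hypothetical free functions, each part is returned as a `(headers, data)` pair, and `safe_filename` keeps only the last path component of the client's `filename=` value.

```rust
use std::path::Path;

/// Return the boundary of a multipart/form-data Content-Type, accepting
/// both `boundary=abc` and the quoted form `boundary="abc"`.
fn extract_boundary(content_type: &str) -> Option<String> {
    if !content_type.to_ascii_lowercase().starts_with("multipart/form-data") {
        return None;
    }
    content_type
        .split(';')
        .filter_map(|param| param.trim().strip_prefix("boundary="))
        .map(|value| value.trim().trim_matches('"').to_string())
        .find(|value| !value.is_empty())
}

/// Position of `needle` in `haystack` at or after `from`.
fn find_bytes(haystack: &[u8], needle: &[u8], from: usize) -> Option<usize> {
    haystack[from..]
        .windows(needle.len())
        .position(|window| window == needle)
        .map(|pos| pos + from)
}

/// Split a multipart body into (headers, data) pairs. The CRLF before each
/// delimiter belongs to the delimiter, so it is excluded from `data`.
fn split_parts<'a>(body: &'a [u8], boundary: &str) -> Vec<(String, &'a [u8])> {
    let delimiter = format!("--{}", boundary).into_bytes();
    let mut parts = Vec::new();
    let mut pos = match find_bytes(body, &delimiter, 0) {
        Some(p) => p + delimiter.len(),
        None => return parts,
    };
    // "--" right after a delimiter marks the closing delimiter.
    while !body[pos..].starts_with(b"--") {
        if body[pos..].starts_with(b"\r\n") {
            pos += 2; // end of the delimiter line
        }
        let next = match find_bytes(body, &delimiter, pos) {
            Some(p) => p,
            None => break, // truncated body: no further delimiter
        };
        let mut end = next;
        if end >= pos + 2 && &body[end - 2..end] == b"\r\n" {
            end -= 2;
        }
        let part = &body[pos..end];
        // Headers end at the first blank line (CRLF CRLF).
        if let Some(h) = find_bytes(part, b"\r\n\r\n", 0) {
            let headers = String::from_utf8_lossy(&part[..h]).into_owned();
            parts.push((headers, &part[h + 4..]));
        }
        pos = next + delimiter.len();
    }
    parts
}

/// Read `filename=` from Content-Disposition and keep only the final path
/// component, so "../../x.txt" becomes "x.txt".
fn safe_filename(headers: &str) -> Option<String> {
    let raw = headers
        .lines()
        .find(|line| line.to_ascii_lowercase().starts_with("content-disposition:"))?
        .split("filename=")
        .nth(1)?
        .trim()
        .trim_matches('"');
    Path::new(raw)
        .file_name()
        .and_then(|name| name.to_str())
        .map(String::from)
}
```

For a body holding one file part and one plain form field, `split_parts` returns two parts; the file data excludes the CRLF that precedes the next delimiter, and `safe_filename` turns `../../x.txt` into `x.txt` while returning `None` for the field without a filename. Searching raw bytes rather than decoded text keeps binary uploads intact.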
/*
 * File Upload Module - Rust Web Server
 * 
 * Created by RSK World (https://rskworld.in)
 * Founder: Molla Samser
 * Designer & Tester: Rima Khatun
 * 
 * Contact:
 * - Email: hello@rskworld.in, support@rskworld.in
 * - Phone: +91 93305 39277
 * - Address: Nutanhat, Mongolkote, Purba Burdwan, West Bengal, India, 713147
 * 
 * © 2026 RSK World. All rights reserved.
 * Content used for educational purposes only.
 */

use hyper::Body;
use hyper::Response;
use hyper::StatusCode;
use mime_guess::from_path;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::io::Write;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tokio::fs;
use tokio::io::AsyncReadExt;
use tokio::sync::RwLock;
use tracing::{debug, error, info, warn};
use uuid::Uuid;

use crate::error::{ServerError, ServerResult};
use crate::utils::get_content_type;

/// File upload configuration
#[derive(Debug, Clone)]
pub struct UploadConfig {
    /// Maximum file size in bytes
    pub max_file_size: usize,
    /// Allowed file extensions
    pub allowed_extensions: Vec<String>,
    /// Upload directory
    pub upload_dir: PathBuf,
    /// Whether to generate unique filenames
    pub generate_unique_names: bool,
    /// Maximum number of files per request
    pub max_files_per_request: usize,
}

impl Default for UploadConfig {
    fn default() -> Self {
        Self {
            max_file_size: 10 * 1024 * 1024, // 10MB
            allowed_extensions: vec![
                "jpg".to_string(), "jpeg".to_string(), "png".to_string(), "gif".to_string(),
                "pdf".to_string(), "txt".to_string(), "doc".to_string(), "docx".to_string(),
                "zip".to_string(), "rar".to_string(), "mp4".to_string(), "mp3".to_string(),
            ],
            upload_dir: PathBuf::from("uploads"),
            generate_unique_names: true,
            max_files_per_request: 5,
        }
    }
}

/// Uploaded file information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UploadedFile {
    pub id: String,
    pub original_name: String,
    pub filename: String,
    pub content_type: String,
    pub size: u64,
    pub upload_date: String,
    pub path: String,
    pub url: String,
}

/// File upload response
#[derive(Debug, Serialize)]
pub struct UploadResponse {
    pub success: bool,
    pub message: String,
    pub files: Vec<UploadedFile>,
    pub errors: Vec<String>,
}

/// File manager
pub struct FileManager {
    config: UploadConfig,
    uploaded_files: Arc<RwLock<HashMap<String, UploadedFile>>>,
}

impl FileManager {
    /// Create a new file manager
    pub fn new(config: UploadConfig) -> Self {
        Self {
            config,
            uploaded_files: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Initialize upload directory
    pub async fn initialize(&self) -> ServerResult<()> {
        if !self.config.upload_dir.exists() {
            fs::create_dir_all(&self.config.upload_dir).await
                .map_err(|e| ServerError::StaticFile(format!("Failed to create upload directory: {}", e)))?;
            info!("Created upload directory: {:?}", self.config.upload_dir);
        }
        Ok(())
    }

    /// Handle file upload from multipart form data
    pub async fn handle_upload(&self, body: Body, content_type: Option<&str>) -> ServerResult<UploadResponse> {
        let mut response = UploadResponse {
            success: true,
            message: "Files uploaded successfully".to_string(),
            files: Vec::new(),
            errors: Vec::new(),
        };

        // Convert body to bytes
        let body_bytes = hyper::body::to_bytes(body).await
            .map_err(|e| ServerError::BadRequest(format!("Failed to read request body: {}", e)))?;

        if body_bytes.is_empty() {
            response.errors.push("No data received".to_string());
            response.success = false;
            response.message = "Upload failed: No data received".to_string();
            return Ok(response);
        }

        // Extract boundary from Content-Type header
        let boundary = self.extract_boundary(content_type)?;
        
        // Parse multipart data
        let parts = self.parse_multipart(&body_bytes, &boundary)?;

        if parts.is_empty() {
            response.errors.push("No valid files found in upload".to_string());
            response.success = false;
            response.message = "Upload failed: No valid files".to_string();
            return Ok(response);
        }

        if parts.len() > self.config.max_files_per_request {
            response.errors.push(format!(
                "Too many files. Maximum allowed: {}", 
                self.config.max_files_per_request
            ));
            response.success = false;
            response.message = "Upload failed due to too many files".to_string();
            return Ok(response);
        }

        for part in parts {
            match self.process_file_part(part).await {
                Ok(file_info) => {
                    response.files.push(file_info);
                }
                Err(e) => {
                    response.errors.push(format!("Failed to upload file: {}", e));
                    response.success = false;
                }
            }
        }

        if response.files.is_empty() {
            // Every part failed validation or could not be written.
            response.success = false;
            response.message = "No valid files were uploaded".to_string();
        } else if !response.errors.is_empty() {
            response.message = format!("Upload completed with {} errors", response.errors.len());
        }

        Ok(response)
    }

    /// Extract boundary from multipart content type
    fn extract_boundary(&self, content_type: Option<&str>) -> ServerResult<String> {
        let content_type = content_type.ok_or_else(|| {
            ServerError::BadRequest("Missing Content-Type header".to_string())
        })?;

        // Media types are case-insensitive.
        if !content_type.to_ascii_lowercase().starts_with("multipart/form-data") {
            return Err(ServerError::BadRequest("Invalid Content-Type, expected multipart/form-data".to_string()));
        }

        // Extract boundary from Content-Type header
        // Format: multipart/form-data; boundary=----WebKitFormBoundary7MA4YWxkTrZu0gW
        // The value may also be quoted: boundary="----WebKitFormBoundary7MA4YWxkTrZu0gW"
        for part in content_type.split(';') {
            if let Some(boundary) = part.trim().strip_prefix("boundary=") {
                let boundary = boundary.trim().trim_matches('"');
                if boundary.is_empty() {
                    return Err(ServerError::BadRequest("Invalid boundary format".to_string()));
                }
                return Ok(boundary.to_string());
            }
        }

        Err(ServerError::BadRequest("Boundary not found in Content-Type".to_string()))
    }

    /// Parse multipart data (basic implementation).
    ///
    /// Parts are separated by `--{boundary}` and the body ends with
    /// `--{boundary}--`. The CRLF in front of each delimiter belongs to the
    /// delimiter, so it is stripped from the part content.
    fn parse_multipart(&self, body: &[u8], boundary: &str) -> ServerResult<Vec<FilePart>> {
        // Byte-level search for `needle` in `haystack`, starting at `from`.
        fn find_bytes(haystack: &[u8], needle: &[u8], from: usize) -> Option<usize> {
            haystack[from..]
                .windows(needle.len())
                .position(|window| window == needle)
                .map(|pos| pos + from)
        }

        let mut parts = Vec::new();
        let delimiter = format!("--{}", boundary).into_bytes();

        // Skip any preamble before the first delimiter.
        let mut pos = match find_bytes(body, &delimiter, 0) {
            Some(p) => p + delimiter.len(),
            None => return Ok(parts),
        };

        // "--" directly after a delimiter marks the closing delimiter.
        while !body[pos..].starts_with(b"--") {
            // Skip the CRLF that ends the delimiter line.
            if body[pos..].starts_with(b"\r\n") {
                pos += 2;
            }

            // Each part runs up to the next delimiter; stop if the body is truncated.
            let next = match find_bytes(body, &delimiter, pos) {
                Some(p) => p,
                None => break,
            };
            let mut end = next;
            if end >= pos + 2 && &body[end - 2..end] == b"\r\n" {
                end -= 2;
            }

            if let Some(file_part) = self.parse_file_part(&body[pos..end])? {
                parts.push(file_part);
            }

            pos = next + delimiter.len();
        }

        Ok(parts)
    }
    
    /// Parse an individual part: header lines, a blank line, then the raw file bytes.
    fn parse_file_part(&self, content: &[u8]) -> ServerResult<Option<FilePart>> {
        // Headers end at the first blank line (CRLF CRLF). Only the header block
        // is decoded as text; the file bytes are sliced from the raw input, so
        // binary data is kept intact and the offset cannot run past the end.
        let header_end = match content.windows(4).position(|window| window == b"\r\n\r\n") {
            Some(pos) => pos,
            None => return Ok(None),
        };
        let headers = String::from_utf8_lossy(&content[..header_end]);
        let file_data = &content[header_end + 4..];

        let mut filename = None;
        let mut content_type = "application/octet-stream".to_string();

        for line in headers.lines() {
            let lower = line.to_ascii_lowercase();

            if lower.starts_with("content-disposition:") {
                if let Some(filename_part) = line.split("filename=").nth(1) {
                    filename = Some(filename_part.trim().trim_matches('"').to_string());
                }
            } else if lower.starts_with("content-type:") {
                if let Some(ct) = line.splitn(2, ':').nth(1) {
                    content_type = ct.trim().to_string();
                }
            }
        }

        // Skip parts without a filename (plain form fields, not files)
        let filename = match filename {
            Some(name) if !name.is_empty() => name,
            _ => return Ok(None),
        };

        Ok(Some(FilePart {
            filename,
            data: file_data.to_vec(),
            content_type,
        }))
    }

    /// Process a single file part
    async fn process_file_part(&self, part: FilePart) -> ServerResult<UploadedFile> {
        // Validate file size
        if part.data.len() > self.config.max_file_size {
            return Err(ServerError::BadRequest(
                format!("File too large. Maximum size: {} bytes", self.config.max_file_size)
            ));
        }

        // Validate file extension
        let extension = Path::new(&part.filename)
            .extension()
            .and_then(|ext| ext.to_str())
            .unwrap_or("")
            .to_lowercase();

        if !self.config.allowed_extensions.contains(&extension) {
            return Err(ServerError::BadRequest(
                format!("File type not allowed: {}", extension)
            ));
        }

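        // Generate filename. A client-supplied name is reduced to its final
        // path component so names like "../../x.png" cannot escape upload_dir.
        let filename = if self.config.generate_unique_names {
            let uuid = Uuid::new_v4();
            format!("{}_{}.{}",
                uuid.to_string().replace("-", ""),
                chrono::Utc::now().timestamp(),
                extension
            )
        } else {
            Path::new(&part.filename)
                .file_name()
                .and_then(|name| name.to_str())
                .map(String::from)
                .ok_or_else(|| ServerError::BadRequest("Invalid filename".to_string()))?
        };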

        // Create file path
        let file_path = self.config.upload_dir.join(&filename);

        // Write file to disk
        fs::write(&file_path, &part.data).await
            .map_err(|e| ServerError::StaticFile(format!("Failed to write file: {}", e)))?;

        // Create file info. The content type is computed first because
        // `part.filename` is moved into `original_name` below.
        let content_type = get_content_type(&part.filename);
        let file_info = UploadedFile {
            id: Uuid::new_v4().to_string(),
            original_name: part.filename,
            filename: filename.clone(),
            content_type,
            size: part.data.len() as u64,
            upload_date: chrono::Utc::now().to_rfc3339(),
            path: file_path.to_string_lossy().to_string(),
            url: format!("/uploads/{}", filename),
        };

        // Store file info
        let mut uploaded_files = self.uploaded_files.write().await;
        uploaded_files.insert(file_info.id.clone(), file_info.clone());

        info!("File uploaded: {} ({} bytes)", file_info.filename, file_info.size);

        Ok(file_info)
    }

    /// Get uploaded file by ID
    pub async fn get_file(&self, file_id: &str) -> ServerResult<UploadedFile> {
        let uploaded_files = self.uploaded_files.read().await;
        uploaded_files
            .get(file_id)
            .cloned()
            .ok_or_else(|| ServerError::NotFound("File not found".to_string()))
    }

    /// List all uploaded files
    pub async fn list_files(&self) -> Vec<UploadedFile> {
        let uploaded_files = self.uploaded_files.read().await;
        uploaded_files.values().cloned().collect()
    }

    /// Delete uploaded file
    pub async fn delete_file(&self, file_id: &str) -> ServerResult<()> {
        let file_info = self.get_file(file_id).await?;
        
        // Delete from filesystem
        fs::remove_file(&file_info.path).await
            .map_err(|e| ServerError::StaticFile(format!("Failed to delete file: {}", e)))?;

        // Remove from memory
        let mut uploaded_files = self.uploaded_files.write().await;
        uploaded_files.remove(file_id);

        info!("File deleted: {}", file_info.filename);
        Ok(())
    }

    /// Serve uploaded file
    pub async fn serve_file(&self, filename: &str) -> ServerResult<Response<Body>> {
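        // Accept only a bare file name: "../secret" or "a/b" could otherwise
        // resolve to files outside the upload directory.
        let is_bare_name = Path::new(filename).file_name().and_then(|n| n.to_str()) == Some(filename);
        if !is_bare_name {
            return Err(ServerError::BadRequest("Invalid filename".to_string()));
        }
        let file_path = self.config.upload_dir.join(filename);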
        
        if !file_path.exists() {
            return Err(ServerError::NotFound("File not found".to_string()));
        }

        // Read file content
        let content = fs::read(&file_path).await
            .map_err(|e| ServerError::StaticFile(format!("Failed to read file: {}", e)))?;

        // Determine content type
        let content_type = get_content_type(filename);

        // Create response
        let response = Response::builder()
            .status(StatusCode::OK)
            .header("content-type", content_type)
            .header("content-disposition", format!("inline; filename=\"{}\"", filename))
            .body(Body::from(content))
            .map_err(|e| ServerError::Internal(format!("Failed to create response: {}", e)))?;

        Ok(response)
    }

    /// Get upload statistics
    pub async fn get_stats(&self) -> UploadStats {
        let uploaded_files = self.uploaded_files.read().await;
        let total_files = uploaded_files.len();
        let total_size: u64 = uploaded_files.values().map(|f| f.size).sum();

        let mut extension_counts = HashMap::new();
        for file in uploaded_files.values() {
            let extension = Path::new(&file.filename)
                .extension()
                .and_then(|ext| ext.to_str())
                .unwrap_or("unknown")
                .to_lowercase();
            
            *extension_counts.entry(extension).or_insert(0) += 1;
        }

        UploadStats {
            total_files,
            total_size,
            extension_counts,
            upload_dir: self.config.upload_dir.to_string_lossy().to_string(),
        }
    }

    /// Delete uploaded files whose upload date is more than `days_old` days ago.
    /// Returns the number of files actually removed.
    pub async fn cleanup_old_files(&self, days_old: u64) -> ServerResult<usize> {
        let cutoff_date = chrono::Utc::now() - chrono::Duration::days(days_old as i64);
        let mut files_to_remove = Vec::new();

        {
            let uploaded_files = self.uploaded_files.read().await;
            for (file_id, file_info) in uploaded_files.iter() {
                if let Ok(upload_date) = chrono::DateTime::parse_from_rfc3339(&file_info.upload_date) {
                    // chrono compares DateTimes by instant, regardless of their offsets
                    if upload_date < cutoff_date {
                        files_to_remove.push(file_id.clone());
                    }
                }
            }
        }

        let mut removed_count = 0;
        for file_id in files_to_remove {
            if self.delete_file(&file_id).await.is_ok() {
                removed_count += 1;
            }
        }

        if removed_count > 0 {
            info!("Cleaned up {} old files", removed_count);
        }

        Ok(removed_count)
    }
}

/// File part from multipart data
#[derive(Debug, Clone)]
struct FilePart {
    filename: String,
    data: Vec<u8>,
    content_type: String,
}

/// Upload statistics
#[derive(Debug, Serialize)]
pub struct UploadStats {
    pub total_files: usize,
    pub total_size: u64,
    pub extension_counts: HashMap<String, usize>,
    pub upload_dir: String,
}

/// Build the JSON HTTP response for an upload: 200 OK on success, 400 Bad Request otherwise.
pub fn create_upload_response(response: UploadResponse) -> Response<Body> {
    let status = if response.success {
        StatusCode::OK
    } else {
        StatusCode::BAD_REQUEST
    };

    Response::builder()
        .status(status)
        .header("content-type", "application/json")
        .body(Body::from(serde_json::to_string_pretty(&response).unwrap()))
        .unwrap()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_file_manager_creation() {
        let config = UploadConfig::default();
        let manager = FileManager::new(config);
        
        let stats = manager.get_stats().await;
        assert_eq!(stats.total_files, 0);
        assert_eq!(stats.total_size, 0);
    }

    #[test]
    fn test_upload_config_validation() {
        let config = UploadConfig::default();
        
        assert!(config.allowed_extensions.contains(&"jpg".to_string()));
        assert!(config.allowed_extensions.contains(&"png".to_string()));
        assert!(!config.allowed_extensions.contains(&"exe".to_string()));
        assert_eq!(config.max_file_size, 10 * 1024 * 1024); // 10MB
    }

    #[test]
    fn test_file_info_serialization() {
        let file_info = UploadedFile {
            id: "123".to_string(),
            original_name: "test.jpg".to_string(),
            filename: "unique_test.jpg".to_string(),
            content_type: "image/jpeg".to_string(),
            size: 1024,
            upload_date: "2023-01-01T00:00:00Z".to_string(),
            path: "/uploads/unique_test.jpg".to_string(),
            url: "/uploads/unique_test.jpg".to_string(),
        };

        let json = serde_json::to_string(&file_info).unwrap();
        let deserialized: UploadedFile = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.filename, "unique_test.jpg");
        assert_eq!(deserialized.size, 1024);
    }
}
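Because `serve_file` builds its path with `upload_dir.join(filename)`, any `..` segment or path separator in `filename` can point outside the upload directory. Below is a minimal std-only sketch of a guard, assuming stored names are generated by the server (as the separate `filename` and `original_name` fields in the tests suggest). The helper `is_plain_file_name` is hypothetical and not part of this crate.

```rust
use std::ffi::OsStr;
use std::path::Path;

/// Hypothetical helper: true only when `name` is a single, ordinary path
/// component, so `upload_dir.join(name)` stays inside `upload_dir`.
fn is_plain_file_name(name: &str) -> bool {
    // `file_name()` returns the last normal component, or None for "", "." and "..".
    // Requiring it to equal the whole input rejects separators ("a/b.png")
    // and absolute paths ("/etc/passwd").
    Path::new(name).file_name() == Some(OsStr::new(name))
}

fn main() {
    for name in ["unique_test.jpg", "../config.toml", "/etc/passwd", "a/b.png", "..", ""] {
        println!("{:<16} -> {}", name, is_plain_file_name(name));
    }
}
```

Rejecting suspicious names outright, rather than trying to sanitize them, keeps the rule easy to audit: a legitimate request for a server-generated name never needs a separator.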

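`cleanup_old_files` collects expired IDs inside an inner block so the read guard on `uploaded_files` is dropped before `delete_file` takes the write lock on the same map; holding both in one task would deadlock. The sketch below shows that two-phase pattern with `std::sync::RwLock` (where the same mistake can deadlock or panic). It assumes upload times are stored as Unix seconds instead of RFC 3339 strings, and the `Store` type and its methods are hypothetical.

```rust
use std::collections::HashMap;
use std::sync::RwLock;

/// Hypothetical in-memory store: file id -> upload time in Unix seconds.
struct Store {
    files: RwLock<HashMap<String, u64>>,
}

impl Store {
    /// Takes the write lock, so no read guard may be alive in this thread.
    fn delete(&self, id: &str) -> bool {
        self.files.write().unwrap().remove(id).is_some()
    }

    /// Remove entries uploaded strictly before `cutoff` (Unix seconds).
    fn cleanup_older_than(&self, cutoff: u64) -> usize {
        // Phase 1: collect ids under the read lock; the guard is dropped
        // at the end of this block.
        let expired: Vec<String> = {
            let files = self.files.read().unwrap();
            files
                .iter()
                .filter(|&(_, &uploaded_at)| uploaded_at < cutoff)
                .map(|(id, _)| id.clone())
                .collect()
        };

        // Phase 2: delete one by one; each call acquires the write lock itself.
        let mut removed = 0;
        for id in &expired {
            if self.delete(id) {
                removed += 1;
            }
        }
        removed
    }
}

fn main() {
    let store = Store {
        files: RwLock::new(HashMap::from([
            ("a".to_string(), 100),
            ("b".to_string(), 500),
            ("c".to_string(), 900),
        ])),
    };
    let removed = store.cleanup_older_than(600);
    // prints "removed 2 files, 1 left"
    println!("removed {} files, {} left", removed, store.files.read().unwrap().len());
}
```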