# python
#
# exercises
#
# exercises.py
"""
Python Security - Exercises
============================

Practice security concepts with hands-on exercises.
"""

import hashlib
import secrets
import os
import re
import html
from typing import Optional
from pathlib import Path
from datetime import datetime, timedelta


# ============================================================
# EXERCISE 1: Password Hasher Class
# ============================================================
"""
Create a PasswordManager class that:
1. Hashes passwords using PBKDF2 with SHA-256
2. Uses secure random salt (16 bytes)
3. Has configurable iteration count
4. Stores hash and salt together
5. Can verify passwords

Expected usage:
    pm = PasswordManager(iterations=100000)
    stored = pm.hash_password("mysecret")
    assert pm.verify_password("mysecret", stored)
    assert not pm.verify_password("wrong", stored)
"""

class PasswordManager:
    """Hash and verify passwords with salted PBKDF2-HMAC-SHA256.

    Stored values look like ``"<salt hex>:<derived key hex>"``. The iteration
    count is not embedded in that string, so a stored value only verifies
    with a manager configured with the same ``iterations`` that created it.
    """

    def __init__(self, iterations: int = 100000):
        # Work factor forwarded to hashlib.pbkdf2_hmac.
        self.iterations = iterations

    def _derive(self, password: str, salt: bytes) -> bytes:
        """Run PBKDF2-HMAC-SHA256 over the UTF-8 password with *salt*."""
        return hashlib.pbkdf2_hmac(
            'sha256', password.encode('utf-8'), salt, self.iterations
        )

    def hash_password(self, password: str) -> str:
        """
        Hash *password* under a fresh 16-byte random salt.

        Returns: "salt_hex:hash_hex" format string
        """
        salt = os.urandom(16)
        digest = self._derive(password, salt)
        return ':'.join((salt.hex(), digest.hex()))

    def verify_password(self, password: str, stored: str) -> bool:
        """Return True iff *password* matches a value from hash_password.

        Malformed stored values (wrong number of parts, non-hex data, or a
        non-string) yield False rather than raising.
        """
        try:
            salt_hex, expected_hex = stored.split(':')
            salt, expected = bytes.fromhex(salt_hex), bytes.fromhex(expected_hex)
            # Constant-time comparison so timing does not leak the match prefix.
            return secrets.compare_digest(self._derive(password, salt), expected)
        except (ValueError, AttributeError):
            return False


# Test your implementation
def test_password_manager():
    """Check PasswordManager's output format, verification, and per-hash salts."""
    pm = PasswordManager(iterations=50000)  # Lower for testing

    # Test hashing
    password = "SuperSecret123!"
    stored = pm.hash_password(password)

    assert stored is not None, "hash_password should return a string"
    assert ':' in stored, "Format should be 'salt:hash'"

    # Test verification
    assert pm.verify_password(password, stored), "Should verify correct password"
    assert not pm.verify_password("wrong", stored), "Should reject wrong password"
    assert not pm.verify_password("", stored), "Should reject empty password"

    # Test different salts: two hashes of the same password must differ
    stored2 = pm.hash_password(password)
    assert stored != stored2, "Same password should have different salts"
    assert pm.verify_password(password, stored2), "Both hashes should verify"

    print("✓ Exercise 1 passed!")


# Uncomment to test:
# test_password_manager()


# ============================================================
# EXERCISE 2: Secure Token Generator
# ============================================================
"""
Create a TokenGenerator class that generates various secure tokens:
1. API keys with configurable prefix
2. Session tokens (URL-safe)
3. OTP (One-Time Password) codes
4. Verification codes (numeric, configurable length)
"""

class TokenGenerator:
    """Generate various types of secure tokens.

    All randomness comes from the ``secrets`` module. Methods that take a
    ``length`` reject values below 1, because an empty token or code would be
    trivially guessable.
    """

    @staticmethod
    def generate_api_key(prefix: str = "api") -> str:
        """
        Generate an API key like ``api_aBcD1234...``.

        The part after ``"<prefix>_"`` is ``secrets.token_urlsafe(32)``:
        32 random bytes, base64url-encoded without padding, which is always
        43 URL-safe characters.
        """
        token = secrets.token_urlsafe(32)  # 43 chars
        return f"{prefix}_{token}"

    @staticmethod
    def generate_session_token(length: int = 32) -> str:
        """
        Generate a URL-safe session token.

        ``length`` is the number of random *bytes*, not output characters.
        The base64url result is about 4/3 as long. ``None`` defers to the
        ``secrets`` module's default size.

        Raises:
            ValueError: if ``length`` is less than 1.
        """
        if length is not None and length < 1:
            raise ValueError("length must be at least 1")
        return secrets.token_urlsafe(length)

    @staticmethod
    def generate_otp(length: int = 6) -> str:
        """
        Generate a numeric OTP code of ``length`` digits (e.g. "123456").

        Raises:
            ValueError: if ``length`` is less than 1.
        """
        if length < 1:
            raise ValueError("length must be at least 1")
        # One independent CSPRNG draw per digit keeps leading zeros possible.
        return ''.join(str(secrets.randbelow(10)) for _ in range(length))

    @staticmethod
    def generate_verification_code(length: int = 8) -> str:
        """
        Generate an alphanumeric verification code.

        Uses only uppercase ASCII letters and digits, for readability.

        Raises:
            ValueError: if ``length`` is less than 1.
        """
        import string
        if length < 1:
            raise ValueError("length must be at least 1")
        alphabet = string.ascii_uppercase + string.digits
        return ''.join(secrets.choice(alphabet) for _ in range(length))


# Test your implementation
def test_token_generator():
    """Check TokenGenerator's API keys, session tokens, OTPs, and codes."""
    gen = TokenGenerator()

    # Test API key
    key = gen.generate_api_key("sk_live")
    assert key.startswith("sk_live_"), "Should start with prefix"
    assert len(key) > 50, "Should be sufficiently long"

    # Test session token
    token = gen.generate_session_token(32)
    assert len(token) >= 32, "Should be at least 32 chars"
    assert token.replace('-', '').replace('_', '').isalnum(), "Should be URL-safe"

    # Test OTP
    otp = gen.generate_otp(6)
    assert len(otp) == 6, "Should be 6 digits"
    assert otp.isdigit(), "Should be numeric"

    # Test verification code
    code = gen.generate_verification_code(8)
    assert len(code) == 8, "Should be 8 characters"
    assert code.isupper() or code.isdigit(), "Should be uppercase"

    # Test uniqueness
    tokens = {gen.generate_session_token() for _ in range(100)}
    assert len(tokens) == 100, "Tokens should be unique"

    print("✓ Exercise 2 passed!")


# Uncomment to test:
# test_token_generator()


# ============================================================
# EXERCISE 3: Input Validator
# ============================================================
"""
Create a comprehensive InputValidator class with methods to validate:
1. Email addresses
2. Usernames (letters, digits, and underscores; 3-20 chars; must start with a letter)
3. Passwords (min 8 chars, upper, lower, digit, special)
4. URLs (http/https only)
5. Phone numbers (10-15 digits, optional +)

Each method should return (is_valid, error_message or None)
"""

class InputValidator:
    """Validate various types of user input.

    Every ``validate_*`` method returns ``(is_valid, error_message)``, where
    ``error_message`` is ``None`` when the input is valid.

    Patterns are applied with ``re.fullmatch``. Using ``re.match`` with a
    ``$`` anchor would be unsafe, because ``$`` also matches just before a
    trailing newline and would accept e.g. a username followed by a newline.
    """

    @staticmethod
    def validate_email(email: str) -> tuple[bool, Optional[str]]:
        """
        Validate email address format.
        Returns (is_valid, error_message or None)
        """
        if not email:
            return False, "Invalid email format"
        # Check the length first so oversized input is rejected before the
        # regex runs (254 is the commonly cited maximum address length).
        if len(email) > 254:
            return False, "Email too long"
        if not re.fullmatch(r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}', email):
            return False, "Invalid email format"
        return True, None

    @staticmethod
    def validate_username(username: str) -> tuple[bool, Optional[str]]:
        """
        Validate username.
        Rules: 3-20 chars, ASCII letters/digits/underscore, starts with letter.
        """
        if not username:
            return False, "Username is required"
        if len(username) < 3:
            return False, "Username must be at least 3 characters"
        if len(username) > 20:
            return False, "Username must be at most 20 characters"
        if not username[0].isalpha():
            return False, "Username must start with a letter"
        if not re.fullmatch(r'[a-zA-Z][a-zA-Z0-9_]*', username):
            return False, "Username can only contain letters, numbers, and underscores"
        return True, None

    @staticmethod
    def validate_password(password: str) -> tuple[bool, Optional[str]]:
        """
        Validate password strength.
        Rules: 8+ chars, uppercase, lowercase, digit, and at least one ASCII
        punctuation character (anything in ``string.punctuation``).
        """
        import string

        if len(password) < 8:
            return False, "Password must be at least 8 characters"
        if not re.search(r'[a-z]', password):
            return False, "Password must contain a lowercase letter"
        if not re.search(r'[A-Z]', password):
            return False, "Password must contain an uppercase letter"
        if not re.search(r'\d', password):
            return False, "Password must contain a digit"
        # Any ASCII punctuation counts as special, including '-', '_', '+', '='.
        if not any(ch in string.punctuation for ch in password):
            return False, "Password must contain a special character"
        return True, None

    @staticmethod
    def validate_url(url: str) -> tuple[bool, Optional[str]]:
        """Validate URL (http/https only; hostname, localhost, or IPv4)."""
        pattern = (
            r'https?://'
            r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?|'
            r'localhost|'
            r'[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3})'
            r'(?::[0-9]+)?'
            r'(?:/?|[/?]\S+)?'
        )
        if not url or not re.fullmatch(pattern, url, re.IGNORECASE):
            return False, "Invalid URL format (must be http or https)"
        return True, None

    @staticmethod
    def validate_phone(phone: str) -> tuple[bool, Optional[str]]:
        """
        Validate phone number.
        Rules: 10-15 ASCII digits, optional leading +. Spaces, dashes,
        parentheses and dots are ignored.
        """
        cleaned = re.sub(r'[\s\-\(\)\.]', '', phone)
        # [0-9] rather than \d: \d also matches non-ASCII Unicode digits.
        if not re.fullmatch(r'\+?[0-9]{10,15}', cleaned):
            return False, "Phone must be 10-15 digits, optional + prefix"
        return True, None


# Test your implementation
def test_input_validator():
    """Check InputValidator on valid and invalid emails, usernames, passwords, URLs, phones."""
    v = InputValidator()

    # Email tests
    assert v.validate_email("user@example.com")[0], "Valid email"
    assert not v.validate_email("invalid-email")[0], "Invalid email"
    assert not v.validate_email("@example.com")[0], "Missing user"

    # Username tests
    assert v.validate_username("john_doe")[0], "Valid username"
    assert not v.validate_username("ab")[0], "Too short"
    assert not v.validate_username("123user")[0], "Starts with digit"
    assert not v.validate_username("user@name")[0], "Invalid char"

    # Password tests
    assert v.validate_password("Str0ng@Pass")[0], "Strong password"
    assert not v.validate_password("weak")[0], "Too short"
    assert not v.validate_password("alllowercase123!")[0], "No uppercase"

    # URL tests
    assert v.validate_url("https://example.com")[0], "Valid HTTPS"
    assert v.validate_url("http://localhost:8080/path")[0], "With port"
    assert not v.validate_url("ftp://example.com")[0], "Wrong protocol"

    # Phone tests
    assert v.validate_phone("+1234567890")[0], "With plus"
    assert v.validate_phone("1234567890")[0], "Without plus"
    assert not v.validate_phone("123")[0], "Too short"

    print("✓ Exercise 3 passed!")


# Uncomment to test:
# test_input_validator()


# ============================================================
# EXERCISE 4: HTML Sanitizer
# ============================================================
"""
Create an HTMLSanitizer class that:
1. Escapes HTML special characters
2. Removes script tags entirely
3. Removes event handlers (onclick, onerror, etc.)
4. Allows only safe tags (p, b, i, em, strong, br)
5. Has a strict mode that removes ALL HTML

Note: For production, use a library like bleach!
"""

class HTMLSanitizer:
    """Sanitize HTML to prevent XSS.

    Regex-based and deliberately simple; production code should use a
    dedicated HTML sanitizer library.
    """

    SAFE_TAGS = {'p', 'b', 'i', 'em', 'strong', 'br'}

    # Any start/end tag. Group 1 is "/" for a closing tag (else ""), and
    # group 2 is the tag name.
    _TAG_RE = re.compile(r'<(/?)([a-zA-Z][a-zA-Z0-9]*)[^>]*>')

    @staticmethod
    def escape(text: str) -> str:
        """Escape HTML special characters (&, <, >, and both quote kinds)."""
        return html.escape(text)

    @staticmethod
    def remove_scripts(html_text: str) -> str:
        """Remove script tags and their content."""
        # Remove <script ...>...</script> pairs, case-insensitively and
        # across newlines.
        result = re.sub(
            r'<script[^>]*>.*?</script>',
            '',
            html_text,
            flags=re.IGNORECASE | re.DOTALL
        )
        # Also remove self-closing and unclosed script tags.
        result = re.sub(r'<script[^>]*/?>', '', result, flags=re.IGNORECASE)
        return result

    @staticmethod
    def remove_event_handlers(html_text: str) -> str:
        """Remove event handler attributes (on*), quoted or unquoted."""
        # Each value form is matched explicitly: "double", 'single', or an
        # unquoted run up to whitespace or '>'.
        return re.sub(
            r'\s+on\w+\s*=\s*(?:"[^"]*"|\'[^\']*\'|[^\s>]+)',
            '',
            html_text,
            flags=re.IGNORECASE
        )

    @staticmethod
    def _escape_text(text: str) -> str:
        """Neutralise markup in a text segment without touching entities."""
        return text.replace('<', '&lt;').replace('>', '&gt;')

    @classmethod
    def sanitize(cls, html_text: str, strict: bool = False) -> str:
        """
        Sanitize HTML.

        If strict=True, escape all HTML.
        If strict=False, keep only SAFE_TAGS, rebuilt without any attributes.
        Other tags are dropped, but their text content is kept. Any stray
        '<' or '>' in text is escaped.
        """
        if strict:
            return cls.escape(html_text)

        # Remove dangerous content first (script bodies, handler attributes).
        cleaned = cls.remove_scripts(html_text)
        cleaned = cls.remove_event_handlers(cleaned)

        pieces = []
        pos = 0
        for match in cls._TAG_RE.finditer(cleaned):
            pieces.append(cls._escape_text(cleaned[pos:match.start()]))
            closing, name = match.group(1), match.group(2).lower()
            if name in cls.SAFE_TAGS:
                # Rebuild the tag so no attribute (style, href, on*=) survives.
                pieces.append(f"<{closing}{name}>")
            pos = match.end()
        pieces.append(cls._escape_text(cleaned[pos:]))
        return ''.join(pieces)


# Test your implementation
def test_html_sanitizer():
    """Check HTMLSanitizer escaping, script/handler removal, and strict mode."""
    s = HTMLSanitizer()

    # Test escaping
    assert '&lt;' in s.escape('<'), "Should escape <"
    assert '&gt;' in s.escape('>'), "Should escape >"
    assert '&amp;' in s.escape('&'), "Should escape &"

    # Test script removal
    result = s.remove_scripts('<p>Hello</p><script>alert("XSS")</script>')
    assert '<script>' not in result, "Should remove script tags"
    assert 'Hello' in result, "Should keep content"

    # Test event handler removal
    result = s.remove_event_handlers('<img src="x" onerror="alert(1)">')
    assert 'onerror' not in result, "Should remove onerror"

    # Test strict sanitization
    strict = s.sanitize('<b>Bold</b><script>bad</script>', strict=True)
    assert '<b>' not in strict, "Strict should escape all HTML"

    print("✓ Exercise 4 passed!")


# Uncomment to test:
# test_html_sanitizer()


# ============================================================
# EXERCISE 5: Secure Path Handler
# ============================================================
"""
Create a SecurePathHandler class that:
1. Prevents directory traversal attacks
2. Validates file extensions
3. Generates safe filenames from user input
4. Checks if a path is within allowed directory
"""

class SecurePathHandler:
    """Handle file paths securely."""
    
    def __init__(self, base_dir: str, allowed_extensions: set[str] | None = None):
        self.base_dir = Path(base_dir).resolve()
        self.allowed_extensions = allowed_extensions or {'.txt', '.pdf', '.jpg', '.png'}
    
    def is_safe_path(self, path: str) -> bool:
        """Check if path is within base directory."""
        try:
            # Don't allow absolute paths
            if os.path.isabs(path):
                return False
            
            # Check for traversal attempts
            if '..' in path:
                return False
            
            # Resolve and check if within base
            target = (self.base_dir / path).resolve()
            return str(target).startswith(str(self.base_dir))
        except (ValueError, OSError):
            return False
    
    def safe_join(self, *paths) -> Optional[Path]:
        """
        Safely join paths, return None if traversal detected.
        """
        try:
            # Join all paths
            joined = Path(*paths)
            target = (self.base_dir / joined).resolve()
            
            # Verify result is under base_dir
            if str(target).startswith(str(self.base_dir)):
                return target
            return None
        except (ValueError, OSError):
            return None
    
    def sanitize_filename(self, filename: str) -> str:
        """
        Sanitize filename:
        - Remove path separators
        - Remove dangerous characters
        - Limit length
        - Ensure valid extension
        """
        # Get just the filename, no path
        filename = os.path.basename(filename)
        
        # Remove potentially dangerous characters
        filename = re.sub(r'[^\w\s\-\.]', '', filename)
        
        # Remove leading dots (hidden files)
        filename = filename.lstrip('.')
        
        # Split name and extension
        name, ext = os.path.splitext(filename)
        
        # Limit lengths
        name = name[:100] if name else 'unnamed'
        ext = ext[:10]
        
        return f"{name}{ext}" if ext else name
    
    def validate_extension(self, filename: str) -> bool:
        """Check if file extension is allowed."""
        ext = os.path.splitext(filename)[1].lower()
        return ext in self.allowed_extensions


# Test your implementation
def test_secure_path_handler():
    """Check SecurePathHandler traversal blocking, filename cleanup, and extensions."""
    handler = SecurePathHandler("/var/uploads", {'.txt', '.pdf'})

    # Test traversal prevention
    assert not handler.is_safe_path("../../../etc/passwd"), "Should block traversal"
    assert not handler.is_safe_path("/etc/passwd"), "Should block absolute paths"

    # Test safe join
    assert handler.safe_join("docs", "file.txt") is not None, "Valid path"
    assert handler.safe_join("..", "etc", "passwd") is None, "Traversal blocked"

    # Test filename sanitization
    safe = handler.sanitize_filename("../../../etc/passwd")
    assert ".." not in safe, "Should remove traversal"
    assert "/" not in safe, "Should remove separators"

    # Test extension validation
    assert handler.validate_extension("file.txt"), "Allowed extension"
    assert handler.validate_extension("file.pdf"), "Allowed extension"
    assert not handler.validate_extension("file.exe"), "Blocked extension"

    print("✓ Exercise 5 passed!")


# Uncomment to test:
# test_secure_path_handler()


# ============================================================
# EXERCISE 6: Rate Limiter
# ============================================================
"""
Create a RateLimiter class that:
1. Limits requests per time window
2. Supports different limits per endpoint
3. Returns remaining requests and reset time
4. Can be reset manually
"""

class RateLimiter:
    """Rate limit requests by identifier."""
    
    def __init__(self, default_limit: int = 100, window_seconds: int = 60):
        self.default_limit = default_limit
        self.window = timedelta(seconds=window_seconds)
        self.requests: dict[str, list[datetime]] = {}  # key: "identifier:endpoint"
        self.limits: dict[str, int] = {}  # endpoint -> limit
    
    def check(self, identifier: str, endpoint: str = "default") -> dict:
        """
        Check if request is allowed.
        Returns {
            'allowed': bool,
            'remaining': int,
            'reset_at': datetime
        }
        """
        now = datetime.now()
        key = f"{identifier}:{endpoint}"
        window_start = now - self.window
        
        # Get limit for endpoint
        limit = self.limits.get(endpoint, self.default_limit)
        
        # Clean old requests
        if key in self.requests:
            self.requests[key] = [
                req for req in self.requests[key]
                if req > window_start
            ]
        else:
            self.requests[key] = []
        
        current_count = len(self.requests[key])
        reset_at = now + self.window
        
        if current_count >= limit:
            return {
                'allowed': False,
                'remaining': 0,
                'reset_at': reset_at
            }
        
        # Record this request
        self.requests[key].append(now)
        
        return {
            'allowed': True,
            'remaining': limit - current_count - 1,
            'reset_at': reset_at
        }
    
    def set_limit(self, endpoint: str, limit: int):
        """Set custom limit for endpoint."""
        self.limits[endpoint] = limit
    
    def reset(self, identifier: str, endpoint: str | None = None):
        """Reset rate limit for identifier."""
        if endpoint:
            key = f"{identifier}:{endpoint}"
            self.requests.pop(key, None)
        else:
            # Reset all endpoints for this identifier
            keys_to_remove = [
                k for k in self.requests 
                if k.startswith(f"{identifier}:")
            ]
            for k in keys_to_remove:
                del self.requests[k]


# Test your implementation
def test_rate_limiter():
    """Check RateLimiter counting, per-user isolation, reset, and endpoint limits."""
    limiter = RateLimiter(default_limit=5, window_seconds=60)

    # Test basic limiting
    for i in range(5):
        result = limiter.check("user1")
        assert result['allowed'], f"Request {i+1} should be allowed"
        assert result['remaining'] == 4 - i, f"Remaining should be {4-i}"

    # 6th request should be blocked
    result = limiter.check("user1")
    assert not result['allowed'], "Should be rate limited"
    assert result['remaining'] == 0

    # Different user should have separate limit
    result = limiter.check("user2")
    assert result['allowed'], "Different user should be allowed"

    # Test reset
    limiter.reset("user1")
    result = limiter.check("user1")
    assert result['allowed'], "Should be allowed after reset"

    # Test custom endpoint limits
    limiter.set_limit("api/expensive", 2)
    limiter.check("user1", "api/expensive")
    limiter.check("user1", "api/expensive")
    result = limiter.check("user1", "api/expensive")
    assert not result['allowed'], "Custom limit should apply"

    print("✓ Exercise 6 passed!")


# Uncomment to test:
# test_rate_limiter()


# ============================================================
# EXERCISE 7: Secure Session Manager
# ============================================================
"""
Create a SecureSessionManager that:
1. Creates sessions with secure tokens
2. Validates sessions with expiration
3. Supports session data storage
4. Can invalidate individual or all sessions for a user
5. Prevents session fixation
"""

class SecureSessionManager:
    """Manage user sessions securely."""
    
    def __init__(self, expiry_minutes: int = 30):
        self.expiry = timedelta(minutes=expiry_minutes)
        self.sessions: dict[str, dict] = {}
    
    def create_session(self, user_id: str, data: dict | None = None) -> str:
        """
        Create new session, return token.
        Store user_id, creation time, expiry, and optional data.
        """
        token = secrets.token_urlsafe(32)
        now = datetime.now()
        
        self.sessions[token] = {
            'user_id': user_id,
            'created_at': now,
            'expires_at': now + self.expiry,
            'data': data or {}
        }
        
        return token
    
    def validate_session(self, token: str) -> Optional[dict]:
        """
        Validate session token.
        Return session data if valid, None if invalid/expired.
        """
        session = self.sessions.get(token)
        
        if not session:
            return None
        
        if datetime.now() > session['expires_at']:
            del self.sessions[token]
            return None
        
        return {
            'user_id': session['user_id'],
            'data': session['data'],
            'created_at': session['created_at'],
            'expires_at': session['expires_at']
        }
    
    def update_session_data(self, token: str, data: dict) -> bool:
        """Update data for existing session."""
        if token not in self.sessions:
            return False
        
        if datetime.now() > self.sessions[token]['expires_at']:
            del self.sessions[token]
            return False
        
        self.sessions[token]['data'] = data
        return True
    
    def invalidate_session(self, token: str):
        """Invalidate single session."""
        self.sessions.pop(token, None)
    
    def invalidate_all_sessions(self, user_id: str):
        """Invalidate all sessions for a user."""
        tokens_to_remove = [
            token for token, session in self.sessions.items()
            if session['user_id'] == user_id
        ]
        for token in tokens_to_remove:
            del self.sessions[token]
    
    def regenerate_token(self, old_token: str) -> Optional[str]:
        """
        Regenerate session token (prevent fixation).
        Keep session data, create new token.
        """
        session = self.sessions.get(old_token)
        
        if not session:
            return None
        
        if datetime.now() > session['expires_at']:
            del self.sessions[old_token]
            return None
        
        # Create new token with same session data
        new_token = secrets.token_urlsafe(32)
        self.sessions[new_token] = session.copy()
        
        # Invalidate old token
        del self.sessions[old_token]
        
        return new_token


# Test your implementation
def test_secure_session_manager():
    """Check SecureSessionManager creation, validation, updates, invalidation, and regeneration."""
    mgr = SecureSessionManager(expiry_minutes=30)

    # Test session creation
    token = mgr.create_session("user123", {"role": "admin"})
    assert token is not None, "Should return token"
    assert len(token) >= 32, "Token should be long enough"

    # Test validation
    session = mgr.validate_session(token)
    assert session is not None, "Should return session data"
    assert session.get("user_id") == "user123"
    assert session.get("data", {}).get("role") == "admin"

    # Test invalid token
    assert mgr.validate_session("invalid") is None

    # Test update
    assert mgr.update_session_data(token, {"role": "user"})
    session = mgr.validate_session(token)
    assert session["data"]["role"] == "user"

    # Test invalidation
    mgr.invalidate_session(token)
    assert mgr.validate_session(token) is None

    # Test invalidate all
    token1 = mgr.create_session("user456")
    token2 = mgr.create_session("user456")
    mgr.invalidate_all_sessions("user456")
    assert mgr.validate_session(token1) is None
    assert mgr.validate_session(token2) is None

    # Test token regeneration
    token = mgr.create_session("user789", {"test": "data"})
    new_token = mgr.regenerate_token(token)
    assert new_token != token, "Should be different token"
    assert mgr.validate_session(token) is None, "Old token should be invalid"
    session = mgr.validate_session(new_token)
    assert session["data"]["test"] == "data", "Data should be preserved"

    print("✓ Exercise 7 passed!")


# Uncomment to test:
# test_secure_session_manager()


# ============================================================
# EXERCISE 8: Audit Logger
# ============================================================
"""
Create a SecureAuditLogger that:
1. Logs security events with timestamps
2. Redacts sensitive information (passwords, tokens)
3. Includes event type, user, IP, and details
4. Prevents log injection attacks
5. Can export logs in a secure format
"""

class SecureAuditLogger:
    """Log security events securely."""
    
    SENSITIVE_KEYS = {'password', 'token', 'secret', 'api_key', 'credit_card'}
    
    def __init__(self):
        self.logs = []
    
    def _redact_sensitive(self, data: dict) -> dict:
        """Redact sensitive information from data."""
        if not data:
            return {}
        
        redacted = {}
        for key, value in data.items():
            if any(sensitive in key.lower() for sensitive in self.SENSITIVE_KEYS):
                redacted[key] = '[REDACTED]'
            elif isinstance(value, dict):
                redacted[key] = self._redact_sensitive(value)
            else:
                redacted[key] = value
        return redacted
    
    def _sanitize_log_data(self, text: str) -> str:
        """Prevent log injection by sanitizing text."""
        # Remove newlines, carriage returns, and other control characters
        sanitized = re.sub(r'[\n\r\t]', ' ', str(text))
        # Limit length
        return sanitized[:500]
    
    def log(self, event_type: str, user: str, ip: str, details: dict | None = None):
        """
        Log a security event.
        Auto-redacts sensitive info and sanitizes input.
        """
        log_entry = {
            'timestamp': datetime.now().isoformat(),
            'event_type': self._sanitize_log_data(event_type),
            'user': self._sanitize_log_data(user),
            'ip': self._sanitize_log_data(ip),
            'details': self._redact_sensitive(details) if details else {}
        }
        self.logs.append(log_entry)
    
    def get_logs(self, event_type: str | None = None, user: str | None = None) -> list:
        """Get logs, optionally filtered."""
        result = self.logs
        
        if event_type:
            result = [log for log in result if log['event_type'] == event_type]
        
        if user:
            result = [log for log in result if log['user'] == user]
        
        return result
    
    def export_logs(self) -> str:
        """Export logs as JSON string."""
        import json
        return json.dumps(self.logs, indent=2, default=str)


# Test your implementation
def test_secure_audit_logger():
    """Check SecureAuditLogger redaction, injection sanitizing, filtering, and export."""
    import json

    logger = SecureAuditLogger()

    # Test logging
    logger.log("LOGIN", "user1", "192.168.1.1", {"password": "secret123"})
    logger.log("API_CALL", "user1", "192.168.1.1", {"token": "abc123", "endpoint": "/api/data"})

    # Test redaction
    logs = logger.get_logs()
    for log in logs:
        if 'password' in str(log):
            assert 'secret123' not in str(log), "Password should be redacted"
        if 'token' in str(log):
            assert 'abc123' not in str(log), "Token should be redacted"

    # Test log injection prevention
    logger.log("TEST", "user\nFAKE_LOG: admin", "127.0.0.1")
    logs = logger.get_logs()
    for log in logs:
        assert '\n' not in str(log.get('user', '')), "Newlines should be sanitized"

    # Test filtering
    logger.log("LOGOUT", "user2", "192.168.1.2")
    user1_logs = logger.get_logs(user="user1")
    assert len(user1_logs) == 2, "Should filter by user"

    login_logs = logger.get_logs(event_type="LOGIN")
    assert len(login_logs) == 1, "Should filter by event type"

    # Test export
    exported = logger.export_logs()
    parsed = json.loads(exported)
    assert isinstance(parsed, list), "Export should be valid JSON"

    print("✓ Exercise 8 passed!")


# Uncomment to test:
# test_secure_audit_logger()


# ============================================================
# EXERCISE 9: CSRF Token Manager
# ============================================================
"""
Create a CSRFTokenManager that:
1. Generates CSRF tokens per session
2. Validates tokens with constant-time comparison
3. Has configurable expiration
4. Supports one-time tokens (deleted after use)
"""

class CSRFTokenManager:
    """Manage CSRF tokens for form protection.

    Holds one token per session id; generating again replaces the previous
    token. Tokens expire after ``expiry_minutes``. With ``one_time=True``, a
    token is deleted after its first successful validation.
    """

    def __init__(self, expiry_minutes: int = 60, one_time: bool = False):
        self.expiry = timedelta(minutes=expiry_minutes)
        self.one_time = one_time
        self.tokens: dict[str, dict] = {}  # session_id -> {token, expires_at}

    def generate_token(self, session_id: str) -> str:
        """Generate (or replace) the CSRF token for *session_id*."""
        token = secrets.token_urlsafe(32)
        now = datetime.now()

        self.tokens[session_id] = {
            'token': token,
            'expires_at': now + self.expiry
        }

        return token

    def validate_token(self, session_id: str, token: str) -> bool:
        """
        Validate a submitted CSRF token for *session_id*.

        Uses a constant-time comparison. Expired tokens are deleted and
        rejected. In one_time mode, a token that validates is deleted.
        Non-string or otherwise malformed input returns False; it never
        raises.
        """
        token_data = self.tokens.get(session_id)

        if not token_data:
            return False

        # Check expiration
        if datetime.now() > token_data['expires_at']:
            del self.tokens[session_id]
            return False

        # compare_digest raises TypeError for str arguments containing
        # non-ASCII characters (and for non-str input). The submitted token is
        # attacker-controlled, so type-check it and compare UTF-8 bytes
        # instead. surrogatepass keeps the encoding total and injective.
        if not isinstance(token, str):
            return False
        is_valid = secrets.compare_digest(
            token_data['token'].encode('utf-8', 'surrogatepass'),
            token.encode('utf-8', 'surrogatepass'),
        )

        # Delete if one-time mode and valid
        if is_valid and self.one_time:
            del self.tokens[session_id]

        return is_valid

    def invalidate_tokens(self, session_id: str):
        """Invalidate the token for *session_id* (no-op if none exists)."""
        self.tokens.pop(session_id, None)


# Test your implementation
def test_csrf_token_manager():
    """Exercise CSRFTokenManager in reusable mode, one-time mode, and revocation."""
    # --- Reusable tokens ---------------------------------------------------
    reusable = CSRFTokenManager(expiry_minutes=60, one_time=False)

    first_token = reusable.generate_token("session1")
    assert first_token is not None
    assert len(first_token) >= 32

    assert reusable.validate_token("session1", first_token), "Valid token"
    assert not reusable.validate_token("session1", "wrong"), "Invalid token"
    assert not reusable.validate_token("wrong_session", first_token), "Wrong session"

    # A reusable token must survive a successful validation.
    assert reusable.validate_token("session1", first_token), "Should still be valid"

    # --- One-time tokens ---------------------------------------------------
    single_use = CSRFTokenManager(expiry_minutes=60, one_time=True)

    burn_token = single_use.generate_token("session2")
    assert single_use.validate_token("session2", burn_token), "First use valid"
    assert not single_use.validate_token("session2", burn_token), "Second use invalid (one-time)"

    # --- Explicit revocation -----------------------------------------------
    revoked = reusable.generate_token("session3")
    reusable.invalidate_tokens("session3")
    assert not reusable.validate_token("session3", revoked), "Invalidated"

    print("āœ“ Exercise 9 passed!")


# Uncomment to test:
# test_csrf_token_manager()


# ============================================================
# EXERCISE 10: Secure Configuration Manager
# ============================================================
"""
Create a SecureConfigManager that:
1. Loads configuration from environment variables
2. Supports default values
3. Validates required fields
4. Encrypts sensitive values in memory
5. Never exposes secrets in __repr__ or logs
"""

class SecureConfigManager:
    """Manage application configuration securely.

    Values are loaded from environment variables according to a schema. A key
    is treated as sensitive when its schema entry sets ``'sensitive': True``,
    or, if the entry omits that option, when the key name contains one of
    ``SENSITIVE_FIELDS`` (case-insensitive). Sensitive values are redacted in
    ``repr()`` and in ``to_dict()`` unless explicitly requested.
    """

    # Name fragments that mark a key as sensitive when the schema does not
    # say otherwise; matched case-insensitively against the key name.
    SENSITIVE_FIELDS = {'password', 'secret', 'api_key', 'token', 'private_key'}

    def __init__(self):
        self._config = {}
        self._sensitive_keys = set()
        # Use a runtime key for obfuscation (not real encryption).
        # NOTE: not yet used by any method; values are stored as plain strings.
        self._mask_key = secrets.token_bytes(32)

    @classmethod
    def _name_looks_sensitive(cls, key: str) -> bool:
        """Return True if ``key`` contains any SENSITIVE_FIELDS fragment."""
        lowered = key.lower()
        return any(fragment in lowered for fragment in cls.SENSITIVE_FIELDS)

    def load_from_env(self, schema: dict) -> list[str]:
        """
        Load config from environment.
        Schema: {
            'DATABASE_URL': {'required': True},
            'DEBUG': {'required': False, 'default': 'false'},
            'API_KEY': {'required': True, 'sensitive': True}
        }
        A missing required key is reported and not stored. A missing optional
        key falls back to its 'default' if one is given; otherwise it is
        skipped. An explicit 'sensitive' option decides redaction; without
        it, the key name is checked against SENSITIVE_FIELDS.

        Returns list of validation errors.
        """
        errors = []

        for key, options in schema.items():
            value = os.environ.get(key)

            if value is None:
                if options.get('required', False):
                    errors.append(f"Missing required environment variable: {key}")
                elif 'default' in options:
                    value = options['default']

            if value is None:
                continue

            self._config[key] = value
            # Without this fallback, SENSITIVE_FIELDS was never consulted and
            # e.g. DB_PASSWORD leaked into repr() unless flagged explicitly.
            if options.get('sensitive', self._name_looks_sensitive(key)):
                self._sensitive_keys.add(key)

        return errors

    def get(self, key: str, default=None):
        """Get configuration value (unredacted), or ``default`` if absent."""
        return self._config.get(key, default)

    def is_sensitive(self, key: str) -> bool:
        """Check whether a loaded key was marked as holding sensitive data."""
        return key in self._sensitive_keys

    def __repr__(self):
        """Safe representation that hides sensitive values."""
        return f"SecureConfigManager({self.to_dict()})"

    def to_dict(self, include_sensitive: bool = False) -> dict:
        """Export config as dict, replacing sensitive values with '[REDACTED]'
        unless ``include_sensitive`` is true."""
        return {
            key: '[REDACTED]'
            if key in self._sensitive_keys and not include_sensitive
            else value
            for key, value in self._config.items()
        }


# Test your implementation
def test_secure_config_manager():
    """Exercise SecureConfigManager against a controlled set of env variables.

    Every environment variable the test touches is saved first and restored
    in a ``finally`` block. A failing assertion therefore neither leaks test
    values into the process nor deletes variables that existed beforehand.
    """
    # Desired state during the test: a string sets the variable, None unsets
    # it (the test relies on TEST_DEBUG / TEST_MISSING / NONEXISTENT_REQUIRED
    # being absent).
    desired_env = {
        'TEST_DB_URL': 'postgresql://localhost/test',
        'TEST_API_KEY': 'sk_test_secret123',
        'TEST_DEBUG': None,
        'TEST_MISSING': None,
        'NONEXISTENT_REQUIRED': None,
    }
    saved_env = {name: os.environ.get(name) for name in desired_env}

    def _apply(env_state):
        """Set or unset each variable in ``env_state`` (None means unset)."""
        for name, value in env_state.items():
            if value is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = value

    _apply(desired_env)
    try:
        config = SecureConfigManager()

        schema = {
            'TEST_DB_URL': {'required': True},
            'TEST_API_KEY': {'required': True, 'sensitive': True},
            'TEST_DEBUG': {'required': False, 'default': 'false'},
            'TEST_MISSING': {'required': False}
        }

        errors = config.load_from_env(schema)
        assert len(errors) == 0, f"Should have no errors: {errors}"

        # Test getting values
        assert config.get('TEST_DB_URL') == 'postgresql://localhost/test'
        assert config.get('TEST_API_KEY') == 'sk_test_secret123'
        assert config.get('TEST_DEBUG') == 'false'
        assert config.get('TEST_MISSING') is None
        assert config.get('TEST_MISSING', 'default') == 'default'

        # Test sensitive detection
        assert config.is_sensitive('TEST_API_KEY')
        assert not config.is_sensitive('TEST_DB_URL')

        # Test repr hides sensitive values
        repr_str = repr(config)
        assert 'sk_test_secret123' not in repr_str, "Should not expose secret in repr"

        # Test to_dict
        safe_dict = config.to_dict(include_sensitive=False)
        assert 'sk_test_secret123' not in str(safe_dict), "Should mask in dict"

        full_dict = config.to_dict(include_sensitive=True)
        assert config.get('TEST_API_KEY') in str(full_dict), "Should include when requested"

        # Test required validation
        schema_missing = {
            'NONEXISTENT_REQUIRED': {'required': True}
        }
        errors = config.load_from_env(schema_missing)
        assert len(errors) > 0, "Should report missing required field"
    finally:
        # Cleanup runs even when an assertion above fails.
        _apply(saved_env)

    print("āœ“ Exercise 10 passed!")


# Uncomment to test:
# test_secure_config_manager()


# ============================================================
# Run all tests
# ============================================================

def run_all_tests():
    """Run every exercise test, print a summary, and report overall success.

    Each test is isolated: any exception it raises is reported and counted as
    a failure, and the remaining tests still run.

    Returns:
        True if every test passed, False otherwise.
    """
    print("Running Security Exercises Tests")
    print("=" * 50)

    tests = [
        ("Password Manager", test_password_manager),
        ("Token Generator", test_token_generator),
        ("Input Validator", test_input_validator),
        ("HTML Sanitizer", test_html_sanitizer),
        ("Secure Path Handler", test_secure_path_handler),
        ("Rate Limiter", test_rate_limiter),
        ("Session Manager", test_secure_session_manager),
        ("Audit Logger", test_secure_audit_logger),
        ("CSRF Token Manager", test_csrf_token_manager),
        ("Secure Config Manager", test_secure_config_manager),
    ]

    passed = 0
    failed = 0

    for name, test_func in tests:
        try:
            test_func()
        except Exception as e:  # AssertionError is already an Exception subclass
            # Bare `assert` statements carry no message, so include the
            # exception type to keep the failure identifiable.
            print(f"āœ— {name}: {type(e).__name__}: {e}")
            failed += 1
        else:
            passed += 1

    print("\n" + "=" * 50)
    print(f"Results: {passed} passed, {failed} failed")

    if failed == 0:
        print("\nšŸŽ‰ Congratulations! You've completed the Security module!")
        print("You now understand Python security best practices!")

    return failed == 0


# Uncomment to run all tests:
# run_all_tests()


# ============================================================
# BONUS CHALLENGE: Build a Complete Auth System
# ============================================================
"""
CHALLENGE: Create a complete authentication system that includes:
1. User registration with password hashing
2. Login with rate limiting
3. Session management with secure tokens
4. CSRF protection
5. Password reset with time-limited tokens
6. Audit logging

This is a comprehensive exercise that combines all security concepts!
"""
# Exercises - Python Tutorial | DeepML