# exercises/exercises.py
"""
Python Security - Exercises
============================
Practice security concepts with hands-on exercises.
"""
import hashlib
import secrets
import os
import re
import html
from typing import Optional
from pathlib import Path
from datetime import datetime, timedelta
# ============================================================
# EXERCISE 1: Password Hasher Class
# ============================================================
"""
Create a PasswordManager class that:
1. Hashes passwords using PBKDF2 with SHA-256
2. Uses secure random salt (16 bytes)
3. Has configurable iteration count
4. Stores hash and salt together
5. Can verify passwords
Expected usage:
pm = PasswordManager(iterations=100000)
stored = pm.hash_password("mysecret")
assert pm.verify_password("mysecret", stored)
assert not pm.verify_password("wrong", stored)
"""
class PasswordManager:
    """Hash and verify passwords with salted PBKDF2-HMAC-SHA256.

    Stored values use the text form ``"<salt hex>:<digest hex>"`` with a
    fresh 16-byte salt per hash.
    """

    def __init__(self, iterations: int = 100000):
        # PBKDF2 work factor; verification must use the same value.
        self.iterations = iterations

    def _derive(self, password: str, salt: bytes) -> bytes:
        """Return the PBKDF2-HMAC-SHA256 digest of *password* under *salt*."""
        return hashlib.pbkdf2_hmac(
            'sha256', password.encode('utf-8'), salt, self.iterations
        )

    def hash_password(self, password: str) -> str:
        """Hash *password* with a new random salt.

        Returns:
            ``"salt_hex:hash_hex"``.
        """
        salt = os.urandom(16)
        digest = self._derive(password, salt)
        return ':'.join((salt.hex(), digest.hex()))

    def verify_password(self, password: str, stored: str) -> bool:
        """Return True when *password* matches the *stored* salt:hash string.

        Malformed stored values (wrong shape, bad hex, non-string) are
        reported as a mismatch rather than raising.
        """
        try:
            salt_part, digest_part = stored.split(':')
            expected = bytes.fromhex(digest_part)
            candidate = self._derive(password, bytes.fromhex(salt_part))
        except (ValueError, AttributeError):
            return False
        # Constant-time comparison avoids leaking how many bytes matched.
        return secrets.compare_digest(candidate, expected)
# Test your implementation
def test_password_manager():
    """Smoke-test PasswordManager: output format, verification, salt uniqueness."""
    manager = PasswordManager(iterations=50000)  # reduced work factor keeps this quick
    secret = "SuperSecret123!"

    first = manager.hash_password(secret)
    assert first is not None, "hash_password should return a string"
    assert ':' in first, "Format should be 'salt:hash'"

    assert manager.verify_password(secret, first), "Should verify correct password"
    rejected = (
        ("wrong", "Should reject wrong password"),
        ("", "Should reject empty password"),
    )
    for guess, reason in rejected:
        assert not manager.verify_password(guess, first), reason

    second = manager.hash_password(secret)
    assert first != second, "Same password should have different salts"
    assert manager.verify_password(secret, second), "Both hashes should verify"

    print("ā Exercise 1 passed!")
# Uncomment to test:
# test_password_manager()
# ============================================================
# EXERCISE 2: Secure Token Generator
# ============================================================
"""
Create a TokenGenerator class that generates various secure tokens:
1. API keys with configurable prefix
2. Session tokens (URL-safe)
3. OTP (One-Time Password) codes
4. Verification codes (numeric, configurable length)
"""
class TokenGenerator:
    """Generate various types of secure tokens.

    All randomness comes from the ``secrets`` module (the OS CSPRNG), never
    from ``random``.
    """

    @staticmethod
    def generate_api_key(prefix: str = "api") -> str:
        """Generate an API key such as ``api_aBcD1234...``.

        The random part is ``secrets.token_urlsafe(32)``: 32 random bytes
        encoded as 43 unpadded base64url characters, placed after
        ``"<prefix>_"``.
        """
        return f"{prefix}_{secrets.token_urlsafe(32)}"

    @staticmethod
    def generate_session_token(length: int = 32) -> str:
        """Generate a URL-safe session token.

        Args:
            length: number of random BYTES of entropy (not characters); the
                base64url text is roughly 4/3 as long (32 bytes -> 43 chars).
                ``None`` defers to the ``secrets`` module default.

        Raises:
            ValueError: if *length* is less than 1 (an empty token would be
                trivially guessable).
        """
        if length is not None and length < 1:
            raise ValueError(f"length must be at least 1, got {length!r}")
        return secrets.token_urlsafe(length)

    @staticmethod
    def generate_otp(length: int = 6) -> str:
        """Generate a numeric one-time code, e.g. ``"123456"``.

        Raises:
            ValueError: if *length* is less than 1.
        """
        if length < 1:
            raise ValueError(f"length must be at least 1, got {length!r}")
        # randbelow(10) per position gives uniformly distributed digits,
        # including leading zeros.
        return ''.join(str(secrets.randbelow(10)) for _ in range(length))

    @staticmethod
    def generate_verification_code(length: int = 8) -> str:
        """Generate an uppercase alphanumeric verification code (A-Z, 0-9).

        Raises:
            ValueError: if *length* is less than 1.
        """
        import string

        if length < 1:
            raise ValueError(f"length must be at least 1, got {length!r}")
        alphabet = string.ascii_uppercase + string.digits
        return ''.join(secrets.choice(alphabet) for _ in range(length))
# Test your implementation
def test_token_generator():
    """Check shape, alphabet and uniqueness of every TokenGenerator output."""
    tg = TokenGenerator()

    api_key = tg.generate_api_key("sk_live")
    assert api_key.startswith("sk_live_"), "Should start with prefix"
    assert len(api_key) > 50, "Should be sufficiently long"

    session = tg.generate_session_token(32)
    assert len(session) >= 32, "Should be at least 32 chars"
    without_symbols = session.replace('-', '').replace('_', '')
    assert without_symbols.isalnum(), "Should be URL-safe"

    one_time = tg.generate_otp(6)
    assert len(one_time) == 6, "Should be 6 digits"
    assert one_time.isdigit(), "Should be numeric"

    verification = tg.generate_verification_code(8)
    assert len(verification) == 8, "Should be 8 characters"
    assert verification.isupper() or verification.isdigit(), "Should be uppercase"

    seen = set()
    for _ in range(100):
        seen.add(tg.generate_session_token())
    assert len(seen) == 100, "Tokens should be unique"

    print("ā Exercise 2 passed!")
# Uncomment to test:
# test_token_generator()
# ============================================================
# EXERCISE 3: Input Validator
# ============================================================
"""
Create a comprehensive InputValidator class with methods to validate:
1. Email addresses
2. Usernames (alphanumeric, 3-20 chars, must start with letter)
3. Passwords (min 8 chars, upper, lower, digit, special)
4. URLs (http/https only)
5. Phone numbers (10-15 digits, optional +)
Each method should return (is_valid, error_message or None)
"""
class InputValidator:
    """Validate various types of user input.

    Every validator returns ``(is_valid, error_message)``, where
    ``error_message`` is ``None`` on success.

    Patterns are applied with ``re.fullmatch``. With ``re.match`` a ``$``
    anchor also matches just before a trailing newline, so values such as
    ``"john\\n"`` would be accepted and could smuggle a newline into logs,
    headers or storage.
    """

    @staticmethod
    def validate_email(email: str) -> tuple[bool, Optional[str]]:
        """Validate an email address (pragmatic format check, max 254 chars).

        Returns (is_valid, error_message or None).
        """
        pattern = r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}'
        if not email or not re.fullmatch(pattern, email):
            return False, "Invalid email format"
        if len(email) > 254:
            return False, "Email too long"
        return True, None

    @staticmethod
    def validate_username(username: str) -> tuple[bool, Optional[str]]:
        """Validate a username.

        Rules: 3-20 chars, ASCII letters/digits/underscore, starts with a letter.
        """
        if not username:
            return False, "Username is required"
        if len(username) < 3:
            return False, "Username must be at least 3 characters"
        if len(username) > 20:
            return False, "Username must be at most 20 characters"
        if not username[0].isalpha():
            return False, "Username must start with a letter"
        if not re.fullmatch(r'[a-zA-Z][a-zA-Z0-9_]*', username):
            return False, "Username can only contain letters, numbers, and underscores"
        return True, None

    @staticmethod
    def validate_password(password: str) -> tuple[bool, Optional[str]]:
        """Validate password strength.

        Rules: 8+ chars, uppercase, lowercase, digit, special char.
        A missing (``None``/empty) password fails the length rule.
        """
        if not password or len(password) < 8:
            return False, "Password must be at least 8 characters"
        if not re.search(r'[a-z]', password):
            return False, "Password must contain a lowercase letter"
        if not re.search(r'[A-Z]', password):
            return False, "Password must contain an uppercase letter"
        if not re.search(r'\d', password):
            return False, "Password must contain a digit"
        if not re.search(r'[!@#$%^&*(),.?":{}|<>]', password):
            return False, "Password must contain a special character"
        return True, None

    @staticmethod
    def validate_url(url: str) -> tuple[bool, Optional[str]]:
        """Validate a URL (http/https only; domain, localhost or dotted IPv4)."""
        pattern = (
            r'https?://'
            r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?|'  # domain
            r'localhost|'
            r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})'  # dotted IPv4 (octet range not checked)
            r'(?::\d+)?'  # optional port
            r'(?:/?|[/?]\S+)?'  # optional path / query
        )
        if not url or not re.fullmatch(pattern, url, re.IGNORECASE):
            return False, "Invalid URL format (must be http or https)"
        return True, None

    @staticmethod
    def validate_phone(phone: str) -> tuple[bool, Optional[str]]:
        """Validate a phone number.

        Rules: 10-15 ASCII digits, optional leading ``+``; spaces, dashes,
        dots and parentheses are ignored.
        """
        if not phone:
            return False, "Phone must be 10-15 digits, optional + prefix"
        cleaned = re.sub(r'[\s\-\(\)\.]', '', phone)
        # [0-9] rather than \d: \d also matches non-ASCII digits such as '١'.
        if not re.fullmatch(r'\+?[0-9]{10,15}', cleaned):
            return False, "Phone must be 10-15 digits, optional + prefix"
        return True, None
# Test your implementation
def test_input_validator():
    """Table-driven checks over every InputValidator method."""
    checker = InputValidator()
    # (validator, input, should_pass, assertion message) in original order
    cases = [
        (checker.validate_email, "user@example.com", True, "Valid email"),
        (checker.validate_email, "invalid-email", False, "Invalid email"),
        (checker.validate_email, "@example.com", False, "Missing user"),
        (checker.validate_username, "john_doe", True, "Valid username"),
        (checker.validate_username, "ab", False, "Too short"),
        (checker.validate_username, "123user", False, "Starts with digit"),
        (checker.validate_username, "user@name", False, "Invalid char"),
        (checker.validate_password, "Str0ng@Pass", True, "Strong password"),
        (checker.validate_password, "weak", False, "Too short"),
        (checker.validate_password, "alllowercase123!", False, "No uppercase"),
        (checker.validate_url, "https://example.com", True, "Valid HTTPS"),
        (checker.validate_url, "http://localhost:8080/path", True, "With port"),
        (checker.validate_url, "ftp://example.com", False, "Wrong protocol"),
        (checker.validate_phone, "+1234567890", True, "With plus"),
        (checker.validate_phone, "1234567890", True, "Without plus"),
        (checker.validate_phone, "123", False, "Too short"),
    ]
    for validate, value, should_pass, message in cases:
        if should_pass:
            assert validate(value)[0], message
        else:
            assert not validate(value)[0], message
    print("ā Exercise 3 passed!")
# Uncomment to test:
# test_input_validator()
# ============================================================
# EXERCISE 4: HTML Sanitizer
# ============================================================
"""
Create an HTMLSanitizer class that:
1. Escapes HTML special characters
2. Removes script tags entirely
3. Removes event handlers (onclick, onerror, etc.)
4. Allows only safe tags (p, b, i, em, strong, br)
5. Has a strict mode that removes ALL HTML
Note: For production, use a library like bleach!
"""
class HTMLSanitizer:
    """Sanitize HTML to prevent XSS.

    Regex-based teaching implementation. For production use a real
    HTML-parser-based sanitizer such as bleach.
    """

    SAFE_TAGS = {'p', 'b', 'i', 'em', 'strong', 'br'}

    @staticmethod
    def escape(text: str) -> str:
        """Escape ``& < > " '`` so *text* renders as literal characters."""
        return html.escape(text)

    @staticmethod
    def remove_scripts(html_text: str) -> str:
        """Remove ``<script>`` elements (tags and content) and stray script tags.

        Removal repeats until the text stops changing, so fragments such as
        ``<scr<script></script>ipt>`` cannot reassemble into a new tag after
        one pass.
        """
        previous = None
        result = html_text
        while result != previous:
            previous = result
            # Whole elements, case-insensitive, content may span lines.
            result = re.sub(
                r'<script\b[^>]*>.*?</script\b[^>]*>',
                '',
                result,
                flags=re.IGNORECASE | re.DOTALL,
            )
            # Unclosed, self-closing or orphaned opening/closing script tags.
            result = re.sub(r'</?script\b[^>]*>', '', result, flags=re.IGNORECASE)
        return result

    @staticmethod
    def remove_event_handlers(html_text: str) -> str:
        """Remove ``on*`` event-handler attributes (onclick, onerror, ...).

        Handles double-quoted, single-quoted and unquoted attribute values.
        """
        return re.sub(
            r'\s+on\w+\s*=\s*(?:"[^"]*"|\'[^\']*\'|[^\s>]+)',
            '',
            html_text,
            flags=re.IGNORECASE,
        )

    @classmethod
    def sanitize(cls, html_text: str, strict: bool = False) -> str:
        """Sanitize HTML.

        If strict=True, escape all HTML. Otherwise drop scripts, event
        handlers and every tag not in SAFE_TAGS (keeping the inner text).
        Safe tags are re-emitted bare, without any attributes.
        """
        if strict:
            return cls.escape(html_text)

        result = cls.remove_scripts(html_text)
        result = cls.remove_event_handlers(result)

        def replace_tag(match: re.Match) -> str:
            closing, tag = match.group(1), match.group(2).lower()
            if tag not in cls.SAFE_TAGS:
                return ''
            # Rebuilding the tag drops all attributes, which also removes
            # style=, href="javascript:..." and handler forms the regex missed.
            return f'<{closing}{tag}>'

        return re.sub(r'<(/?)([a-zA-Z][a-zA-Z0-9]*)[^>]*>', replace_tag, result)
# Test your implementation
def test_html_sanitizer():
    """Check HTMLSanitizer escaping, script and handler removal, strict mode."""
    s = HTMLSanitizer()
    # Escaping must produce entities; the raw characters would still be
    # present only if nothing had been escaped.
    assert '&lt;' in s.escape('<'), "Should escape <"
    assert '&gt;' in s.escape('>'), "Should escape >"
    assert '&amp;' in s.escape('&'), "Should escape &"
    # Script removal keeps surrounding content.
    result = s.remove_scripts('<p>Hello</p><script>alert("XSS")</script>')
    assert '<script>' not in result, "Should remove script tags"
    assert 'Hello' in result, "Should keep content"
    # Event handler removal.
    result = s.remove_event_handlers('<img src="x" onerror="alert(1)">')
    assert 'onerror' not in result, "Should remove onerror"
    # Strict mode escapes every tag.
    strict = s.sanitize('<b>Bold</b><script>bad</script>', strict=True)
    assert '<b>' not in strict, "Strict should escape all HTML"
    print("ā Exercise 4 passed!")
# Uncomment to test:
# test_html_sanitizer()
# ============================================================
# EXERCISE 5: Secure Path Handler
# ============================================================
"""
Create a SecurePathHandler class that:
1. Prevents directory traversal attacks
2. Validates file extensions
3. Generates safe filenames from user input
4. Checks if a path is within allowed directory
"""
class SecurePathHandler:
"""Handle file paths securely."""
def __init__(self, base_dir: str, allowed_extensions: set[str] | None = None):
self.base_dir = Path(base_dir).resolve()
self.allowed_extensions = allowed_extensions or {'.txt', '.pdf', '.jpg', '.png'}
def is_safe_path(self, path: str) -> bool:
"""Check if path is within base directory."""
try:
# Don't allow absolute paths
if os.path.isabs(path):
return False
# Check for traversal attempts
if '..' in path:
return False
# Resolve and check if within base
target = (self.base_dir / path).resolve()
return str(target).startswith(str(self.base_dir))
except (ValueError, OSError):
return False
def safe_join(self, *paths) -> Optional[Path]:
"""
Safely join paths, return None if traversal detected.
"""
try:
# Join all paths
joined = Path(*paths)
target = (self.base_dir / joined).resolve()
# Verify result is under base_dir
if str(target).startswith(str(self.base_dir)):
return target
return None
except (ValueError, OSError):
return None
def sanitize_filename(self, filename: str) -> str:
"""
Sanitize filename:
- Remove path separators
- Remove dangerous characters
- Limit length
- Ensure valid extension
"""
# Get just the filename, no path
filename = os.path.basename(filename)
# Remove potentially dangerous characters
filename = re.sub(r'[^\w\s\-\.]', '', filename)
# Remove leading dots (hidden files)
filename = filename.lstrip('.')
# Split name and extension
name, ext = os.path.splitext(filename)
# Limit lengths
name = name[:100] if name else 'unnamed'
ext = ext[:10]
return f"{name}{ext}" if ext else name
def validate_extension(self, filename: str) -> bool:
"""Check if file extension is allowed."""
ext = os.path.splitext(filename)[1].lower()
return ext in self.allowed_extensions
# Test your implementation
def test_secure_path_handler():
    """Check traversal blocking, safe joins, filename cleanup and extension rules."""
    ph = SecurePathHandler("/var/uploads", {'.txt', '.pdf'})

    assert not ph.is_safe_path("../../../etc/passwd"), "Should block traversal"
    assert not ph.is_safe_path("/etc/passwd"), "Should block absolute paths"

    assert ph.safe_join("docs", "file.txt") is not None, "Valid path"
    assert ph.safe_join("..", "etc", "passwd") is None, "Traversal blocked"

    cleaned = ph.sanitize_filename("../../../etc/passwd")
    for forbidden, reason in ((
        "..", "Should remove traversal"), ("/", "Should remove separators")):
        assert forbidden not in cleaned, reason

    for permitted in ("file.txt", "file.pdf"):
        assert ph.validate_extension(permitted), "Allowed extension"
    assert not ph.validate_extension("file.exe"), "Blocked extension"

    print("ā Exercise 5 passed!")
# Uncomment to test:
# test_secure_path_handler()
# ============================================================
# EXERCISE 6: Rate Limiter
# ============================================================
"""
Create a RateLimiter class that:
1. Limits requests per time window
2. Supports different limits per endpoint
3. Returns remaining requests and reset time
4. Can be reset manually
"""
class RateLimiter:
"""Rate limit requests by identifier."""
def __init__(self, default_limit: int = 100, window_seconds: int = 60):
self.default_limit = default_limit
self.window = timedelta(seconds=window_seconds)
self.requests: dict[str, list[datetime]] = {} # key: "identifier:endpoint"
self.limits: dict[str, int] = {} # endpoint -> limit
def check(self, identifier: str, endpoint: str = "default") -> dict:
"""
Check if request is allowed.
Returns {
'allowed': bool,
'remaining': int,
'reset_at': datetime
}
"""
now = datetime.now()
key = f"{identifier}:{endpoint}"
window_start = now - self.window
# Get limit for endpoint
limit = self.limits.get(endpoint, self.default_limit)
# Clean old requests
if key in self.requests:
self.requests[key] = [
req for req in self.requests[key]
if req > window_start
]
else:
self.requests[key] = []
current_count = len(self.requests[key])
reset_at = now + self.window
if current_count >= limit:
return {
'allowed': False,
'remaining': 0,
'reset_at': reset_at
}
# Record this request
self.requests[key].append(now)
return {
'allowed': True,
'remaining': limit - current_count - 1,
'reset_at': reset_at
}
def set_limit(self, endpoint: str, limit: int):
"""Set custom limit for endpoint."""
self.limits[endpoint] = limit
def reset(self, identifier: str, endpoint: str | None = None):
"""Reset rate limit for identifier."""
if endpoint:
key = f"{identifier}:{endpoint}"
self.requests.pop(key, None)
else:
# Reset all endpoints for this identifier
keys_to_remove = [
k for k in self.requests
if k.startswith(f"{identifier}:")
]
for k in keys_to_remove:
del self.requests[k]
# Test your implementation
def test_rate_limiter():
    """Exercise basic limiting, per-user isolation, reset and endpoint limits."""
    rl = RateLimiter(default_limit=5, window_seconds=60)

    for attempt in range(1, 6):
        outcome = rl.check("user1")
        assert outcome['allowed'], f"Request {attempt} should be allowed"
        left = 5 - attempt
        assert outcome['remaining'] == left, f"Remaining should be {left}"

    # The sixth request exceeds the quota.
    outcome = rl.check("user1")
    assert not outcome['allowed'], "Should be rate limited"
    assert outcome['remaining'] == 0

    outcome = rl.check("user2")
    assert outcome['allowed'], "Different user should be allowed"

    rl.reset("user1")
    outcome = rl.check("user1")
    assert outcome['allowed'], "Should be allowed after reset"

    endpoint = "api/expensive"
    rl.set_limit(endpoint, 2)
    for _ in range(2):
        rl.check("user1", endpoint)
    outcome = rl.check("user1", endpoint)
    assert not outcome['allowed'], "Custom limit should apply"

    print("ā Exercise 6 passed!")
# Uncomment to test:
# test_rate_limiter()
# ============================================================
# EXERCISE 7: Secure Session Manager
# ============================================================
"""
Create a SecureSessionManager that:
1. Creates sessions with secure tokens
2. Validates sessions with expiration
3. Supports session data storage
4. Can invalidate individual or all sessions for a user
5. Prevents session fixation
"""
class SecureSessionManager:
"""Manage user sessions securely."""
def __init__(self, expiry_minutes: int = 30):
self.expiry = timedelta(minutes=expiry_minutes)
self.sessions: dict[str, dict] = {}
def create_session(self, user_id: str, data: dict | None = None) -> str:
"""
Create new session, return token.
Store user_id, creation time, expiry, and optional data.
"""
token = secrets.token_urlsafe(32)
now = datetime.now()
self.sessions[token] = {
'user_id': user_id,
'created_at': now,
'expires_at': now + self.expiry,
'data': data or {}
}
return token
def validate_session(self, token: str) -> Optional[dict]:
"""
Validate session token.
Return session data if valid, None if invalid/expired.
"""
session = self.sessions.get(token)
if not session:
return None
if datetime.now() > session['expires_at']:
del self.sessions[token]
return None
return {
'user_id': session['user_id'],
'data': session['data'],
'created_at': session['created_at'],
'expires_at': session['expires_at']
}
def update_session_data(self, token: str, data: dict) -> bool:
"""Update data for existing session."""
if token not in self.sessions:
return False
if datetime.now() > self.sessions[token]['expires_at']:
del self.sessions[token]
return False
self.sessions[token]['data'] = data
return True
def invalidate_session(self, token: str):
"""Invalidate single session."""
self.sessions.pop(token, None)
def invalidate_all_sessions(self, user_id: str):
"""Invalidate all sessions for a user."""
tokens_to_remove = [
token for token, session in self.sessions.items()
if session['user_id'] == user_id
]
for token in tokens_to_remove:
del self.sessions[token]
def regenerate_token(self, old_token: str) -> Optional[str]:
"""
Regenerate session token (prevent fixation).
Keep session data, create new token.
"""
session = self.sessions.get(old_token)
if not session:
return None
if datetime.now() > session['expires_at']:
del self.sessions[old_token]
return None
# Create new token with same session data
new_token = secrets.token_urlsafe(32)
self.sessions[new_token] = session.copy()
# Invalidate old token
del self.sessions[old_token]
return new_token
# Test your implementation
def test_secure_session_manager():
    """Walk a session through create, validate, update, invalidate and regenerate."""
    sm = SecureSessionManager(expiry_minutes=30)

    tok = sm.create_session("user123", {"role": "admin"})
    assert tok is not None, "Should return token"
    assert len(tok) >= 32, "Token should be long enough"

    info = sm.validate_session(tok)
    assert info is not None, "Should return session data"
    assert info.get("user_id") == "user123"
    assert info.get("data", {}).get("role") == "admin"

    assert sm.validate_session("invalid") is None

    assert sm.update_session_data(tok, {"role": "user"})
    info = sm.validate_session(tok)
    assert info["data"]["role"] == "user"

    sm.invalidate_session(tok)
    assert sm.validate_session(tok) is None

    first, second = sm.create_session("user456"), sm.create_session("user456")
    sm.invalidate_all_sessions("user456")
    assert sm.validate_session(first) is None
    assert sm.validate_session(second) is None

    tok = sm.create_session("user789", {"test": "data"})
    fresh = sm.regenerate_token(tok)
    assert fresh != tok, "Should be different token"
    assert sm.validate_session(tok) is None, "Old token should be invalid"
    info = sm.validate_session(fresh)
    assert info["data"]["test"] == "data", "Data should be preserved"

    print("ā Exercise 7 passed!")
# Uncomment to test:
# test_secure_session_manager()
# ============================================================
# EXERCISE 8: Audit Logger
# ============================================================
"""
Create a SecureAuditLogger that:
1. Logs security events with timestamps
2. Redacts sensitive information (passwords, tokens)
3. Includes event type, user, IP, and details
4. Prevents log injection attacks
5. Can export logs in a secure format
"""
class SecureAuditLogger:
"""Log security events securely."""
SENSITIVE_KEYS = {'password', 'token', 'secret', 'api_key', 'credit_card'}
def __init__(self):
self.logs = []
def _redact_sensitive(self, data: dict) -> dict:
"""Redact sensitive information from data."""
if not data:
return {}
redacted = {}
for key, value in data.items():
if any(sensitive in key.lower() for sensitive in self.SENSITIVE_KEYS):
redacted[key] = '[REDACTED]'
elif isinstance(value, dict):
redacted[key] = self._redact_sensitive(value)
else:
redacted[key] = value
return redacted
def _sanitize_log_data(self, text: str) -> str:
"""Prevent log injection by sanitizing text."""
# Remove newlines, carriage returns, and other control characters
sanitized = re.sub(r'[\n\r\t]', ' ', str(text))
# Limit length
return sanitized[:500]
def log(self, event_type: str, user: str, ip: str, details: dict | None = None):
"""
Log a security event.
Auto-redacts sensitive info and sanitizes input.
"""
log_entry = {
'timestamp': datetime.now().isoformat(),
'event_type': self._sanitize_log_data(event_type),
'user': self._sanitize_log_data(user),
'ip': self._sanitize_log_data(ip),
'details': self._redact_sensitive(details) if details else {}
}
self.logs.append(log_entry)
def get_logs(self, event_type: str | None = None, user: str | None = None) -> list:
"""Get logs, optionally filtered."""
result = self.logs
if event_type:
result = [log for log in result if log['event_type'] == event_type]
if user:
result = [log for log in result if log['user'] == user]
return result
def export_logs(self) -> str:
"""Export logs as JSON string."""
import json
return json.dumps(self.logs, indent=2, default=str)
# Test your implementation
def test_secure_audit_logger():
    """Check redaction, log-injection defense, filtering and JSON export."""
    import json
    audit = SecureAuditLogger()

    audit.log("LOGIN", "user1", "192.168.1.1", {"password": "secret123"})
    audit.log("API_CALL", "user1", "192.168.1.1", {"token": "abc123", "endpoint": "/api/data"})

    for entry in audit.get_logs():
        rendered = str(entry)
        if 'password' in rendered:
            assert 'secret123' not in str(entry), "Password should be redacted"
        if 'token' in str(entry):
            assert 'abc123' not in str(entry), "Token should be redacted"

    audit.log("TEST", "user\nFAKE_LOG: admin", "127.0.0.1")
    for entry in audit.get_logs():
        assert '\n' not in str(entry.get('user', '')), "Newlines should be sanitized"

    audit.log("LOGOUT", "user2", "192.168.1.2")
    by_user = audit.get_logs(user="user1")
    assert len(by_user) == 2, "Should filter by user"
    by_event = audit.get_logs(event_type="LOGIN")
    assert len(by_event) == 1, "Should filter by event type"

    decoded = json.loads(audit.export_logs())
    assert isinstance(decoded, list), "Export should be valid JSON"

    print("ā Exercise 8 passed!")
# Uncomment to test:
# test_secure_audit_logger()
# ============================================================
# EXERCISE 9: CSRF Token Manager
# ============================================================
"""
Create a CSRFTokenManager that:
1. Generates CSRF tokens per session
2. Validates tokens with constant-time comparison
3. Has configurable expiration
4. Supports one-time tokens (deleted after use)
"""
class CSRFTokenManager:
    """Manage per-session CSRF tokens for form protection."""

    def __init__(self, expiry_minutes: int = 60, one_time: bool = False):
        self.expiry = timedelta(minutes=expiry_minutes)
        self.one_time = one_time
        self.tokens: dict[str, dict] = {}  # session_id -> {token, expires_at}

    def generate_token(self, session_id: str) -> str:
        """Generate (and store) a fresh CSRF token for *session_id*."""
        token = secrets.token_urlsafe(32)
        now = datetime.now()
        self.tokens[session_id] = {
            'token': token,
            'expires_at': now + self.expiry,
        }
        return token

    def validate_token(self, session_id: str, token: str) -> bool:
        """Return True if *token* is the current, unexpired token for *session_id*.

        Uses a constant-time comparison. Expired tokens are deleted. In
        one_time mode a successfully validated token is deleted. Any
        malformed submission (non-string, non-ASCII) simply fails instead of
        raising.
        """
        token_data = self.tokens.get(session_id)
        if not token_data:
            return False

        if datetime.now() > token_data['expires_at']:
            del self.tokens[session_id]
            return False

        if not isinstance(token, str):
            return False
        # compare_digest raises TypeError for str arguments with non-ASCII
        # characters, so compare bytes. surrogatepass keeps lone surrogates
        # from raising UnicodeEncodeError.
        is_valid = secrets.compare_digest(
            token_data['token'].encode('utf-8', 'surrogatepass'),
            token.encode('utf-8', 'surrogatepass'),
        )

        if is_valid and self.one_time:
            del self.tokens[session_id]
        return is_valid

    def invalidate_tokens(self, session_id: str):
        """Invalidate all tokens for *session_id*."""
        self.tokens.pop(session_id, None)
# Test your implementation
def test_csrf_token_manager():
    """Check reusable tokens, one-time tokens and explicit invalidation."""
    reusable = CSRFTokenManager(expiry_minutes=60, one_time=False)
    issued = reusable.generate_token("session1")
    assert issued is not None
    assert len(issued) >= 32
    assert reusable.validate_token("session1", issued), "Valid token"
    assert not reusable.validate_token("session1", "wrong"), "Invalid token"
    assert not reusable.validate_token("wrong_session", issued), "Wrong session"
    # Reusable mode: a second validation still succeeds.
    assert reusable.validate_token("session1", issued), "Should still be valid"

    single_use = CSRFTokenManager(expiry_minutes=60, one_time=True)
    once = single_use.generate_token("session2")
    assert single_use.validate_token("session2", once), "First use valid"
    assert not single_use.validate_token("session2", once), "Second use invalid (one-time)"

    revoked = reusable.generate_token("session3")
    reusable.invalidate_tokens("session3")
    assert not reusable.validate_token("session3", revoked), "Invalidated"

    print("ā Exercise 9 passed!")
# Uncomment to test:
# test_csrf_token_manager()
# ============================================================
# EXERCISE 10: Secure Configuration Manager
# ============================================================
"""
Create a SecureConfigManager that:
1. Loads configuration from environment variables
2. Supports default values
3. Validates required fields
4. Encrypts sensitive values in memory
5. Never exposes secrets in __repr__ or logs
"""
class SecureConfigManager:
    """Manage application configuration loaded from environment variables.

    A key is treated as sensitive when the schema flags it
    (``'sensitive': True``) or when its name contains one of
    ``SENSITIVE_FIELDS`` (case-insensitive). Sensitive string values are
    kept XOR-masked in memory, and are redacted from ``repr()`` and from
    ``to_dict()`` unless explicitly requested.
    """

    SENSITIVE_FIELDS = {'password', 'secret', 'api_key', 'token', 'private_key'}

    def __init__(self):
        self._config = {}
        self._sensitive_keys = set()
        # Runtime key used to obfuscate sensitive values in memory so they
        # don't appear verbatim in object dumps. This is NOT real encryption:
        # the key lives in the same object.
        self._mask_key = secrets.token_bytes(32)
        # Keys whose value in _config is currently stored masked (bytes).
        self._masked_keys = set()

    def _xor(self, key: str, data: bytes) -> bytes:
        """XOR *data* with a keystream derived from the mask key and *key*.

        Applying it twice with the same key returns the original bytes.
        """
        if not data:
            return b''
        seed = self._mask_key + key.encode('utf-8', 'surrogateescape')
        stream = hashlib.shake_256(seed).digest(len(data))
        return bytes(a ^ b for a, b in zip(data, stream))

    def _looks_sensitive(self, key: str) -> bool:
        """Return True if *key*'s name contains a SENSITIVE_FIELDS entry."""
        lowered = key.lower()
        return any(field in lowered for field in self.SENSITIVE_FIELDS)

    def _store(self, key: str, value) -> None:
        """Store *value*, masking it when *key* is sensitive and *value* is a str."""
        if key in self._sensitive_keys and isinstance(value, str):
            # surrogateescape round-trips undecodable bytes from os.environ.
            raw = value.encode('utf-8', 'surrogateescape')
            self._config[key] = self._xor(key, raw)
            self._masked_keys.add(key)
        else:
            self._config[key] = value
            self._masked_keys.discard(key)

    def _plain(self, key: str):
        """Return the stored value for *key*, unmasking it if necessary."""
        value = self._config[key]
        if key in self._masked_keys:
            return self._xor(key, value).decode('utf-8', 'surrogateescape')
        return value

    def load_from_env(self, schema: dict) -> list[str]:
        """Load config from the environment according to *schema*.

        Schema: {
            'DATABASE_URL': {'required': True},
            'DEBUG': {'required': False, 'default': 'false'},
            'API_KEY': {'required': True, 'sensitive': True}
        }
        A missing required variable is reported (its default, if any, is
        not applied). A missing optional variable takes its default, or is
        left unset.

        Returns:
            List of validation error messages (empty when valid).
        """
        errors = []
        for key, options in schema.items():
            value = os.environ.get(key)
            if value is None:
                if options.get('required', False):
                    errors.append(f"Missing required environment variable: {key}")
                elif 'default' in options:
                    value = options['default']
            if value is None:
                continue
            if options.get('sensitive', False) or self._looks_sensitive(key):
                self._sensitive_keys.add(key)
            self._store(key, value)
        return errors

    def get(self, key: str, default=None):
        """Get a configuration value (unmasked), or *default* if not set."""
        if key not in self._config:
            return default
        return self._plain(key)

    def is_sensitive(self, key: str) -> bool:
        """Check whether *key* was loaded as sensitive."""
        return key in self._sensitive_keys

    def __repr__(self):
        """Safe representation that hides sensitive values."""
        return f"SecureConfigManager({self.to_dict()})"

    def to_dict(self, include_sensitive: bool = False) -> dict:
        """Export config as a dict; sensitive values are '[REDACTED]' unless requested."""
        result = {}
        for key in self._config:
            if key in self._sensitive_keys and not include_sensitive:
                result[key] = '[REDACTED]'
            else:
                result[key] = self._plain(key)
        return result
# Test your implementation
def test_secure_config_manager():
    """Exercise SecureConfigManager against temporary environment variables.

    The variables are restored (or removed) in a ``finally`` block, so a
    failing assertion cannot leak them into later tests or the process.
    """
    test_env = {
        'TEST_DB_URL': 'postgresql://localhost/test',
        'TEST_API_KEY': 'sk_test_secret123',
    }
    previous = {name: os.environ.get(name) for name in test_env}
    os.environ.update(test_env)
    try:
        config = SecureConfigManager()
        schema = {
            'TEST_DB_URL': {'required': True},
            'TEST_API_KEY': {'required': True, 'sensitive': True},
            'TEST_DEBUG': {'required': False, 'default': 'false'},
            'TEST_MISSING': {'required': False}
        }
        errors = config.load_from_env(schema)
        assert len(errors) == 0, f"Should have no errors: {errors}"
        # Getting values
        assert config.get('TEST_DB_URL') == 'postgresql://localhost/test'
        assert config.get('TEST_API_KEY') == 'sk_test_secret123'
        assert config.get('TEST_DEBUG') == 'false'
        assert config.get('TEST_MISSING') is None
        assert config.get('TEST_MISSING', 'default') == 'default'
        # Sensitive detection
        assert config.is_sensitive('TEST_API_KEY')
        assert not config.is_sensitive('TEST_DB_URL')
        # repr hides sensitive values
        repr_str = repr(config)
        assert 'sk_test_secret123' not in repr_str, "Should not expose secret in repr"
        # to_dict masks unless asked not to
        safe_dict = config.to_dict(include_sensitive=False)
        assert 'sk_test_secret123' not in str(safe_dict), "Should mask in dict"
        full_dict = config.to_dict(include_sensitive=True)
        assert config.get('TEST_API_KEY') in str(full_dict), "Should include when requested"
        # Required validation
        schema_missing = {
            'NONEXISTENT_REQUIRED': {'required': True}
        }
        errors = config.load_from_env(schema_missing)
        assert len(errors) > 0, "Should report missing required field"
    finally:
        # Restore whatever was there before instead of blindly deleting.
        for name, old_value in previous.items():
            if old_value is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = old_value
    print("ā Exercise 10 passed!")
# Uncomment to test:
# test_secure_config_manager()
# ============================================================
# Run all tests
# ============================================================
def run_all_tests() -> bool:
    """Run every exercise test, print a per-test and overall summary.

    Each test runs independently; an exception (including a failed
    assertion) is reported and counted, and the run continues.

    Returns:
        True if every test passed, False otherwise.
    """
    print("Running Security Exercises Tests")
    print("=" * 50)
    tests = [
        ("Password Manager", test_password_manager),
        ("Token Generator", test_token_generator),
        ("Input Validator", test_input_validator),
        ("HTML Sanitizer", test_html_sanitizer),
        ("Secure Path Handler", test_secure_path_handler),
        ("Rate Limiter", test_rate_limiter),
        ("Session Manager", test_secure_session_manager),
        ("Audit Logger", test_secure_audit_logger),
        ("CSRF Token Manager", test_csrf_token_manager),
        ("Secure Config Manager", test_secure_config_manager),
    ]
    passed = 0
    failed = 0
    for name, test_func in tests:
        try:
            test_func()
        except Exception as e:  # AssertionError is a subclass of Exception
            print(f"ā {name}: {e}")
            failed += 1
        else:
            passed += 1
    print("\n" + "=" * 50)
    print(f"Results: {passed} passed, {failed} failed")
    if failed == 0:
        print("\nš Congratulations! You've completed the Security module!")
        print("You now understand Python security best practices!")
    return failed == 0
# Uncomment to run all tests:
# run_all_tests()
# ============================================================
# BONUS CHALLENGE: Build a Complete Auth System
# ============================================================
"""
CHALLENGE: Create a complete authentication system that includes:
1. User registration with password hashing
2. Login with rate limiting
3. Session management with secure tokens
4. CSRF protection
5. Password reset with time-limited tokens
6. Audit logging
This is a comprehensive exercise that combines all security concepts!
"""