Custom middleware lets you implement application-specific behavior that runs on every request and response. This chapter covers designing, implementing, and testing custom middleware for a variety of use cases.
# myapp/middleware.py
import time
import logging
from django.utils.deprecation import MiddlewareMixin
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class BasicCustomMiddleware:
    """Template for custom middleware.

    Times everything downstream of this middleware (the rest of the chain plus
    the view) and reports the elapsed seconds in an ``X-Processing-Time``
    response header, formatted with four decimal places.
    """
    def __init__(self, get_response):
        """
        One-time configuration and initialization.

        Args:
            get_response: the next middleware (or the view) in the chain.
        """
        self.get_response = get_response
        # One-time initialization code here
        logger.info("BasicCustomMiddleware initialized")
    def __call__(self, request):
        """
        Run the rest of the chain for ``request`` and time it.

        Returns:
            The downstream response with ``X-Processing-Time`` set.
        """
        # perf_counter() is monotonic; time.time() follows the wall clock and can
        # jump (e.g. NTP adjustments), producing wrong or even negative durations.
        start_time = time.perf_counter()
        # Call the next middleware or view
        response = self.get_response(request)
        processing_time = time.perf_counter() - start_time
        # Add custom header
        response['X-Processing-Time'] = f"{processing_time:.4f}"
        return response
class FullFeaturedMiddleware:
    """Middleware demonstrating all available hooks.

    ``process_request``/``process_response`` are driven from ``__call__``;
    ``process_view``, ``process_exception`` and ``process_template_response``
    are hook methods that Django's handler invokes when they are defined.
    """
    def __init__(self, get_response):
        self.get_response = get_response
    def __call__(self, request):
        # This code is executed before the view
        self.process_request(request)
        response = self.get_response(request)
        # This code is executed after the view
        return self.process_response(request, response)
    def process_request(self, request):
        """Stamp the request with a start time and a short request id, and log it."""
        request.custom_timestamp = time.time()
        request.custom_id = self.generate_request_id()
        logger.info("Processing request %s to %s", request.custom_id, request.path)
    def process_view(self, request, view_func, view_args, view_kwargs):
        """Called just before Django calls the view.

        Returning an HttpResponse here would short-circuit the view; this
        implementation only logs and returns None.
        """
        # functools.partial objects and some callable instances have no __name__.
        view_name = getattr(view_func, '__name__', repr(view_func))
        logger.info("About to call view: %s", view_name)
        return None
    def process_exception(self, request, exception):
        """Called when a view raises; logs it and lets Django handle it (returns None)."""
        logger.error(
            "Exception in request %s: %s",
            getattr(request, 'custom_id', 'unknown'),
            exception,
        )
        return None
    def process_template_response(self, request, response):
        """Expose the request id to templates as ``request_id``."""
        if hasattr(response, 'context_data'):
            # A TemplateResponse built without a context has context_data None;
            # indexing into it would raise TypeError.
            if response.context_data is None:
                response.context_data = {}
            response.context_data['request_id'] = getattr(request, 'custom_id', None)
        return response
    def process_response(self, request, response):
        """Add ``X-Request-ID`` and, if the start time is known, ``X-Processing-Time``."""
        response['X-Request-ID'] = getattr(request, 'custom_id', 'unknown')
        if hasattr(request, 'custom_timestamp'):
            processing_time = time.time() - request.custom_timestamp
            response['X-Processing-Time'] = f"{processing_time:.4f}"
        return response
    def generate_request_id(self):
        """Return the first 8 characters of a random UUID4 string."""
        import uuid
        return str(uuid.uuid4())[:8]
import json
import logging
from django.utils import timezone
# Dedicated logger for request/response records. It is bound to its own name
# rather than ``logger`` so that defining it does not replace the module-level
# ``logger`` that the other middleware classes in this module log through.
request_logger = logging.getLogger('request_logger')
class RequestLoggingMiddleware:
    """Comprehensive request logging middleware.

    Emits one ``REQUEST: <json>`` record before the rest of the chain runs and
    one ``RESPONSE: <json>`` record after it, both at INFO on the
    ``request_logger`` logger. Headers listed in ``sensitive_headers`` are
    omitted from the request record.
    """
    def __init__(self, get_response):
        self.get_response = get_response
        # request.META keys whose values must never be written to the log.
        self.sensitive_headers = [
            'HTTP_AUTHORIZATION',
            'HTTP_COOKIE',
            'HTTP_X_API_KEY',
            'HTTP_X_CSRFTOKEN',
            'HTTP_PROXY_AUTHORIZATION',
        ]
    def __call__(self, request):
        # Skip data extraction and JSON encoding when INFO records would be
        # discarded anyway.
        if not request_logger.isEnabledFor(logging.INFO):
            return self.get_response(request)
        request_data = self.extract_request_data(request)
        request_logger.info("REQUEST: %s", json.dumps(request_data))
        # perf_counter() is monotonic, so durations are immune to clock jumps.
        start_time = time.perf_counter()
        response = self.get_response(request)
        elapsed = time.perf_counter() - start_time
        response_data = self.extract_response_data(request, response, elapsed)
        request_logger.info("RESPONSE: %s", json.dumps(response_data))
        return response
    def extract_request_data(self, request):
        """Return a JSON-serialisable summary of the incoming request."""
        return {
            'timestamp': timezone.now().isoformat(),
            'method': request.method,
            'path': request.path,
            'query_params': dict(request.GET),
            # request.user only exists once AuthenticationMiddleware has run.
            'user': str(request.user) if hasattr(request, 'user') else 'Anonymous',
            'ip_address': self.get_client_ip(request),
            'user_agent': request.META.get('HTTP_USER_AGENT', ''),
            'content_type': request.META.get('CONTENT_TYPE', ''),
            'content_length': request.META.get('CONTENT_LENGTH', 0),
            'headers': self.get_safe_headers(request)
        }
    def extract_response_data(self, request, response, processing_time):
        """Return a JSON-serialisable summary of the response."""
        return {
            'timestamp': timezone.now().isoformat(),
            'status_code': response.status_code,
            'content_type': response.get('Content-Type', ''),
            # Streaming responses expose no .content; report 0 for them.
            'content_length': len(response.content) if hasattr(response, 'content') else 0,
            'processing_time': round(processing_time, 4),
            'path': request.path
        }
    def get_client_ip(self, request):
        """Return the first X-Forwarded-For entry, else REMOTE_ADDR (may be None).

        NOTE: X-Forwarded-For is client-controlled unless a trusted proxy sets it.
        """
        x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
        if x_forwarded_for:
            return x_forwarded_for.split(',')[0].strip()
        return request.META.get('REMOTE_ADDR')
    def get_safe_headers(self, request):
        """Return HTTP headers as ``Title-Case`` names, excluding sensitive ones."""
        excluded = frozenset(self.sensitive_headers)
        return {
            key[5:].replace('_', '-').title(): value
            for key, value in request.META.items()
            if key.startswith('HTTP_') and key not in excluded
        }
from django.core.cache import cache
# django.http provides no HttpResponseTooManyRequests class; 429 responses are
# built with HttpResponse(status=429) instead.
from django.http import HttpResponse
import json
class RateLimitMiddleware:
    """Per-client, fixed-window rate limiting for configured path prefixes.

    ``rate_limits`` maps a path prefix to ``{'requests': N, 'window': seconds}``.
    A request is governed by the LONGEST prefix its path starts with, and each
    prefix keeps its own counter per client IP in Django's default cache.
    Requests over the limit get a JSON 429 response with ``Retry-After``.
    """
    def __init__(self, get_response):
        self.get_response = get_response
        self.rate_limits = {
            '/api/': {'requests': 100, 'window': 3600},  # 100 requests per hour
            '/api/auth/': {'requests': 10, 'window': 300},  # 10 requests per 5 minutes
        }
    def __call__(self, request):
        if not self.should_rate_limit(request):
            return self.get_response(request)
        if self.is_rate_limited(request):
            return self.rate_limit_response(request)
        response = self.get_response(request)
        self.add_rate_limit_headers(request, response)
        return response
    def should_rate_limit(self, request):
        """Return True if any configured prefix applies to this request."""
        return self._match_prefix(request) is not None
    def get_rate_limit_config(self, request):
        """Return the config of the most specific matching prefix, or None."""
        prefix = self._match_prefix(request)
        return None if prefix is None else self.rate_limits[prefix]
    def is_rate_limited(self, request):
        """Count this request and return True if it exceeds its prefix's limit."""
        prefix = self._match_prefix(request)
        if prefix is None:
            return False
        config = self.rate_limits[prefix]
        cache_key = self._cache_key(request, prefix)
        # add() does nothing if the key already exists, so the window starts at
        # the first request and later requests do not extend it. incr() is atomic
        # on backends with native increment (e.g. memcached); on others Django
        # implements it as get()+set(), which is neither atomic nor guaranteed to
        # keep the original timeout.
        cache.add(cache_key, 0, timeout=config['window'])
        try:
            current_requests = cache.incr(cache_key)
        except ValueError:
            # The key expired between add() and incr(): start a fresh window.
            cache.set(cache_key, 1, timeout=config['window'])
            current_requests = 1
        return current_requests > config['requests']
    def rate_limit_response(self, request):
        """Return the JSON 429 response for a request that exceeded its limit."""
        config = self.get_rate_limit_config(request)
        response_data = {
            'error': 'Rate limit exceeded',
            'message': f"Maximum {config['requests']} requests per {config['window']} seconds",
            'retry_after': config['window']
        }
        response = HttpResponse(
            json.dumps(response_data),
            content_type='application/json',
            status=429,
        )
        response['Retry-After'] = str(config['window'])
        return response
    def add_rate_limit_headers(self, request, response):
        """Add X-RateLimit-Limit/Remaining/Reset headers to ``response``."""
        prefix = self._match_prefix(request)
        if prefix is None:
            return
        config = self.rate_limits[prefix]
        current_requests = cache.get(self._cache_key(request, prefix), 0)
        response['X-RateLimit-Limit'] = str(config['requests'])
        response['X-RateLimit-Remaining'] = str(max(0, config['requests'] - current_requests))
        # This is the window length in seconds, not the time at which it resets.
        response['X-RateLimit-Reset'] = str(config['window'])
    def get_client_ip(self, request):
        """Return the first X-Forwarded-For entry, else REMOTE_ADDR (may be None).

        NOTE: unless a trusted proxy overwrites X-Forwarded-For, clients can set
        it themselves and so rotate their apparent IP to evade the limit.
        """
        x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
        if x_forwarded_for:
            return x_forwarded_for.split(',')[0].strip()
        return request.META.get('REMOTE_ADDR')
    def _match_prefix(self, request):
        """Return the longest configured prefix that request.path starts with, or None."""
        matches = [prefix for prefix in self.rate_limits if request.path.startswith(prefix)]
        return max(matches, key=len, default=None)
    def _cache_key(self, request, prefix):
        """Return the per-client, per-prefix counter key."""
        return f"rate_limit:{self.get_client_ip(request)}:{prefix}"
from django.contrib.auth.models import AnonymousUser
from django.utils import timezone
from .models import UserActivity
class UserActivityMiddleware:
    """Track user activity and update last seen timestamp.

    After the downstream response is produced, and only if its status is below
    400, records a ``UserActivity`` row for authenticated users on tracked
    paths. It also updates ``last_activity`` when the user object has one.
    """
    def __init__(self, get_response):
        self.get_response = get_response
        self.tracked_paths = ['/dashboard/', '/profile/', '/api/']
        self.excluded_paths = ['/static/', '/media/', '/favicon.ico']
    def __call__(self, request):
        response = self.get_response(request)
        # Track activity after successful response
        if response.status_code < 400 and self.should_track_activity(request):
            self.track_user_activity(request, response)
        return response
    def should_track_activity(self, request):
        """Return True for authenticated users on tracked, non-excluded paths."""
        # str.startswith accepts a tuple of prefixes.
        if request.path.startswith(tuple(self.excluded_paths)):
            return False
        user = getattr(request, 'user', None)
        # is_authenticated covers AnonymousUser as well as custom anonymous classes.
        if user is None or not user.is_authenticated:
            return False
        # Only track specific paths, or everything if no paths are configured
        if self.tracked_paths:
            return request.path.startswith(tuple(self.tracked_paths))
        return True
    def track_user_activity(self, request, response):
        """Record user activity; failures are logged and never break the request."""
        try:
            now = timezone.now()
            UserActivity.objects.create(
                user=request.user,
                path=request.path,
                method=request.method,
                ip_address=self.get_client_ip(request),
                user_agent=request.META.get('HTTP_USER_AGENT', ''),
                timestamp=now,
                status_code=response.status_code,
            )
            # The stock auth User has no last_activity field, and
            # save(update_fields=[...]) raises for unknown fields, so only
            # update user models that actually define it.
            if hasattr(request.user, 'last_activity'):
                request.user.last_activity = now
                request.user.save(update_fields=['last_activity'])
        except Exception as e:
            # Best-effort: log with traceback but don't break the request
            logger.exception("Error tracking user activity: %s", e)
    def get_client_ip(self, request):
        """Return the first X-Forwarded-For entry, else REMOTE_ADDR (may be None).

        NOTE: X-Forwarded-For is client-controlled unless a trusted proxy sets it.
        """
        x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
        if x_forwarded_for:
            return x_forwarded_for.split(',')[0].strip()
        return request.META.get('REMOTE_ADDR')
# models.py
# NOTE(review): this snippet assumes `models` (django.db) and `User` are
# imported in models.py; those imports are not shown here.
class UserActivity(models.Model):
    """Model to store user activity.

    One row per tracked request, written by UserActivityMiddleware.
    """
    # Rows are deleted together with their user (CASCADE).
    user = models.ForeignKey(User, on_delete=models.CASCADE)
    # NOTE(review): 255 chars; longer request paths will fail to insert on
    # databases that enforce the length — confirm whether truncation is wanted.
    path = models.CharField(max_length=255)
    method = models.CharField(max_length=10)
    # Not nullable. NOTE(review): the middleware's get_client_ip() can return
    # None when neither X-Forwarded-For nor REMOTE_ADDR is present.
    ip_address = models.GenericIPAddressField()
    user_agent = models.TextField()
    timestamp = models.DateTimeField()
    status_code = models.IntegerField()
    class Meta:
        # Newest activity first by default.
        ordering = ['-timestamp']
        # Presumably for per-user history and time-range queries — confirm
        # against actual query patterns.
        indexes = [
            models.Index(fields=['user', 'timestamp']),
            models.Index(fields=['timestamp']),
        ]
class CSPMiddleware:
    """Content Security Policy middleware.

    Sets a ``Content-Security-Policy`` header on every response. The policy
    starts from ``default_policy`` and is adjusted by path: ``/admin/`` also
    allows ``'unsafe-eval'`` scripts, and ``/api/`` forbids scripts and styles.
    """
    def __init__(self, get_response):
        self.get_response = get_response
        self.default_policy = {
            'default-src': ["'self'"],
            'script-src': ["'self'", "'unsafe-inline'"],
            'style-src': ["'self'", "'unsafe-inline'"],
            'img-src': ["'self'", "data:", "https:"],
            'font-src': ["'self'"],
            'connect-src': ["'self'"],
            'frame-src': ["'none'"],
            'object-src': ["'none'"],
            'base-uri': ["'self'"],
            'form-action': ["'self'"]
        }
    def __call__(self, request):
        response = self.get_response(request)
        # Add CSP header
        response['Content-Security-Policy'] = self.build_csp_header(request)
        return response
    def build_csp_header(self, request):
        """Build the CSP header string for ``request`` without mutating defaults."""
        # Copy every source list. A shallow dict.copy() would share the lists with
        # default_policy, so the admin append() below would permanently add
        # 'unsafe-eval' (once per admin request) to the policy of ALL requests.
        policy = {directive: list(sources) for directive, sources in self.default_policy.items()}
        # Customize policy based on request path
        if request.path.startswith('/admin/'):
            # More permissive policy for admin
            policy['script-src'].append("'unsafe-eval'")
        elif request.path.startswith('/api/'):
            # Strict policy for API
            policy['script-src'] = ["'none'"]
            policy['style-src'] = ["'none'"]
        return '; '.join(f"{directive} {' '.join(sources)}" for directive, sources in policy.items())
import re
class MobileDetectionMiddleware:
    """Detect mobile devices and add context.

    From the User-Agent header, sets ``request.is_mobile``, ``request.is_tablet``,
    ``request.is_desktop`` and ``request.device_info`` before the view runs. It
    echoes the device type in an ``X-Device-Type`` response header.
    """
    def __init__(self, get_response):
        self.get_response = get_response
        # Mobile user agent patterns. 'iPad' is also a tablet pattern, so an iPad
        # reports is_mobile and is_tablet both True; tablet takes precedence below.
        self.mobile_patterns = [
            re.compile(r'Mobile', re.IGNORECASE),
            re.compile(r'Android', re.IGNORECASE),
            re.compile(r'iPhone', re.IGNORECASE),
            re.compile(r'iPad', re.IGNORECASE),
            re.compile(r'Windows Phone', re.IGNORECASE),
            re.compile(r'BlackBerry', re.IGNORECASE),
        ]
        # Tablet patterns
        self.tablet_patterns = [
            re.compile(r'iPad', re.IGNORECASE),
            re.compile(r'Android.*Tablet', re.IGNORECASE),
            re.compile(r'Kindle', re.IGNORECASE),
        ]
    def __call__(self, request):
        # Detect device type
        user_agent = request.META.get('HTTP_USER_AGENT', '')
        request.is_mobile = self.is_mobile(user_agent)
        request.is_tablet = self.is_tablet(user_agent)
        request.is_desktop = not (request.is_mobile or request.is_tablet)
        # Add device info to request
        request.device_info = {
            'type': self.get_device_type(request),
            'user_agent': user_agent,
            'screen_size': self.estimate_screen_size(request)
        }
        response = self.get_response(request)
        # Add device type header
        response['X-Device-Type'] = request.device_info['type']
        return response
    def is_mobile(self, user_agent):
        """Check if user agent indicates mobile device"""
        return any(pattern.search(user_agent) for pattern in self.mobile_patterns)
    def is_tablet(self, user_agent):
        """Check if user agent indicates tablet device"""
        return any(pattern.search(user_agent) for pattern in self.tablet_patterns)
    def get_device_type(self, request):
        """Return 'tablet', 'mobile' or 'desktop' (tablet wins when both match)."""
        if request.is_tablet:
            return 'tablet'
        if request.is_mobile:
            return 'mobile'
        return 'desktop'
    def estimate_screen_size(self, request):
        """Return 'medium', 'small' or 'large', consistent with get_device_type."""
        # Tablet must be checked first, as in get_device_type; otherwise an iPad
        # (which matches both pattern lists) is typed 'tablet' but sized 'small'.
        if request.is_tablet:
            return 'medium'
        if request.is_mobile:
            return 'small'
        return 'large'
class ConditionalMiddleware:
    """Middleware that runs conditionally based on settings.

    Reads ``settings.CONDITIONAL_MIDDLEWARE_CONFIG`` with these keys:
    ``enabled``, ``debug``, ``allowed_ips`` and ``excluded_paths``.
    ``allowed_ips`` entries may be single addresses or CIDR networks (e.g.
    ``'192.168.1.0/24'``). When processing applies, the response gets an
    ``X-Processed-By`` header.
    """
    def __init__(self, get_response):
        self.get_response = get_response
        # Load configuration
        from django.conf import settings
        self.config = getattr(settings, 'CONDITIONAL_MIDDLEWARE_CONFIG', {})
        self.enabled = self.config.get('enabled', True)
        self.debug_mode = self.config.get('debug', False)
        self.allowed_ips = self.config.get('allowed_ips', [])
        self.excluded_paths = self.config.get('excluded_paths', [])
    def __call__(self, request):
        # Check if middleware should run
        if not self.should_process(request):
            return self.get_response(request)
        if self.debug_mode:
            print(f"ConditionalMiddleware processing: {request.path}")
        response = self.get_response(request)
        response['X-Processed-By'] = 'ConditionalMiddleware'
        return response
    def should_process(self, request):
        """Determine if middleware should process this request"""
        if not self.enabled:
            return False
        # Check IP allow-list (addresses and CIDR networks)
        if self.allowed_ips and not self._is_ip_allowed(self.get_client_ip(request)):
            return False
        # Check excluded paths
        if any(request.path.startswith(path) for path in self.excluded_paths):
            return False
        return True
    def _is_ip_allowed(self, client_ip):
        """Return True if client_ip equals or falls inside any allowed_ips entry."""
        import ipaddress
        if not client_ip:
            return False
        try:
            address = ipaddress.ip_address(client_ip)
        except ValueError:
            # Not a parseable IP: fall back to exact string matching.
            return client_ip in self.allowed_ips
        for entry in self.allowed_ips:
            try:
                # strict=False also accepts plain addresses and host bits set;
                # an IPv4 address is simply never "in" an IPv6 network.
                if address in ipaddress.ip_network(entry, strict=False):
                    return True
            except ValueError:
                continue  # malformed allow-list entry; ignore it
        return False
    def get_client_ip(self, request):
        """Return the first X-Forwarded-For entry, else REMOTE_ADDR (may be None).

        NOTE: X-Forwarded-For is client-controlled unless a trusted proxy sets
        it, so the allow-list can be bypassed by a spoofed header.
        """
        x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
        if x_forwarded_for:
            return x_forwarded_for.split(',')[0].strip()
        return request.META.get('REMOTE_ADDR')
from .models import RequestLog, BlockedIP
class DatabaseIntegratedMiddleware:
    """Middleware that integrates with database models.

    Rejects clients whose IP has an active ``BlockedIP`` row with a 403.
    Otherwise it writes a ``RequestLog`` row before the view runs and fills in
    the response status and size afterwards. Database errors are logged and do
    not fail the request (IP blocking fails open).
    """
    def __init__(self, get_response):
        self.get_response = get_response
    def __call__(self, request):
        client_ip = self.get_client_ip(request)
        if self.is_ip_blocked(client_ip):
            # Imported locally: no module-level import brings this name into
            # scope here, and a bare reference raised NameError for blocked IPs.
            from django.http import HttpResponseForbidden
            return HttpResponseForbidden('IP address blocked')
        # Log request to database
        request_log = self.create_request_log(request, client_ip=client_ip)
        response = self.get_response(request)
        # Update request log with response info
        self.update_request_log(request_log, response)
        return response
    def is_ip_blocked(self, ip_address):
        """Return True if an active BlockedIP row exists for ip_address.

        Fails open: on a database error the error is logged and False is returned.
        """
        try:
            return BlockedIP.objects.filter(
                ip_address=ip_address,
                is_active=True
            ).exists()
        except Exception as e:
            logger.error(f"Error checking blocked IP {ip_address}: {e}")
            return False
    def create_request_log(self, request, client_ip=None):
        """Create and return a RequestLog row, or None if the insert fails.

        Args:
            request: the incoming request.
            client_ip: pre-computed client IP; derived from request when None.
        """
        if client_ip is None:
            client_ip = self.get_client_ip(request)
        user = getattr(request, 'user', None)
        try:
            return RequestLog.objects.create(
                ip_address=client_ip,
                path=request.path,
                method=request.method,
                user_agent=request.META.get('HTTP_USER_AGENT', ''),
                user=user if user is not None and user.is_authenticated else None,
                timestamp=timezone.now()
            )
        except Exception as e:
            logger.error(f"Error creating request log: {e}")
            return None
    def update_request_log(self, request_log, response):
        """Store the response status and body size on request_log (no-op if None)."""
        if not request_log:
            return
        try:
            request_log.status_code = response.status_code
            # Streaming responses expose no .content; record 0 for them.
            request_log.response_size = len(response.content) if hasattr(response, 'content') else 0
            request_log.save(update_fields=['status_code', 'response_size'])
        except Exception as e:
            logger.error(f"Error updating request log: {e}")
    def get_client_ip(self, request):
        """Return the first X-Forwarded-For entry, else REMOTE_ADDR (may be None).

        NOTE: X-Forwarded-For is client-controlled unless a trusted proxy sets
        it, so a blocked client could evade the block by spoofing it.
        """
        x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
        if x_forwarded_for:
            return x_forwarded_for.split(',')[0].strip()
        return request.META.get('REMOTE_ADDR')
# models.py
# NOTE(review): this snippet assumes `models` (django.db) and `User` are
# imported in models.py; those imports are not shown here.
class RequestLog(models.Model):
    """One row per request that DatabaseIntegratedMiddleware lets through.

    The row is created before the view runs; status_code and response_size are
    nullable and are filled in once the response exists.
    """
    # Not nullable. NOTE(review): get_client_ip() can return None when neither
    # X-Forwarded-For nor REMOTE_ADDR is present.
    ip_address = models.GenericIPAddressField()
    path = models.CharField(max_length=255)
    method = models.CharField(max_length=10)
    user_agent = models.TextField()
    # Anonymous requests store NULL; deleting a user keeps their log rows.
    user = models.ForeignKey(User, on_delete=models.SET_NULL, null=True, blank=True)
    timestamp = models.DateTimeField()
    status_code = models.IntegerField(null=True, blank=True)
    response_size = models.IntegerField(null=True, blank=True)
class BlockedIP(models.Model):
    """An IP address the middleware rejects with a 403 while is_active is True."""
    ip_address = models.GenericIPAddressField(unique=True)
    reason = models.CharField(max_length=255)
    # Set automatically when the row is first created.
    blocked_at = models.DateTimeField(auto_now_add=True)
    # Deactivate instead of deleting to lift a block while keeping the record.
    is_active = models.BooleanField(default=True)
from django.test import TestCase, RequestFactory
from django.http import HttpResponse
from django.contrib.auth.models import User, AnonymousUser
from myapp.middleware import RequestLoggingMiddleware, RateLimitMiddleware
class CustomMiddlewareTests(TestCase):
    """Unit tests that drive individual middleware with RequestFactory requests."""
    def setUp(self):
        # Rate-limit counters live in the cache and would leak between tests.
        from django.core.cache import cache
        cache.clear()
        self.factory = RequestFactory()
        self.user = User.objects.create_user(
            username='testuser',
            email='test@example.com',
            password='testpass123'
        )
    def test_request_logging_middleware(self):
        """Test request logging middleware"""
        def get_response(request):
            return HttpResponse("Test response")
        middleware = RequestLoggingMiddleware(get_response)
        request = self.factory.get('/test/')
        request.user = self.user
        with self.assertLogs('request_logger', level='INFO') as cm:
            response = middleware(request)
        self.assertEqual(response.status_code, 200)
        self.assertEqual(len(cm.output), 2)  # Request and response logs
        self.assertIn('REQUEST:', cm.output[0])
        self.assertIn('RESPONSE:', cm.output[1])
    def test_rate_limiting_middleware(self):
        """Requests beyond the limit are rejected with 429 and Retry-After."""
        def get_response(request):
            return HttpResponse("Test response")
        middleware = RateLimitMiddleware(get_response)
        # A tiny limit keeps the test fast; the real cache is used, no mocking.
        middleware.rate_limits = {'/api/': {'requests': 2, 'window': 60}}
        for _ in range(2):
            response = middleware(self.factory.get('/api/test/'))
            self.assertEqual(response.status_code, 200)
        response = middleware(self.factory.get('/api/test/'))
        self.assertEqual(response.status_code, 429)
        self.assertEqual(response['Retry-After'], '60')
    def test_mobile_detection_middleware(self):
        """Test mobile detection middleware"""
        # Not among this module's top-level imports, so import it here.
        from myapp.middleware import MobileDetectionMiddleware
        def get_response(request):
            return HttpResponse("Test response")
        middleware = MobileDetectionMiddleware(get_response)
        # Test mobile user agent
        request = self.factory.get('/test/')
        request.META['HTTP_USER_AGENT'] = 'Mozilla/5.0 (iPhone; CPU iPhone OS 14_0 like Mac OS X)'
        response = middleware(request)
        self.assertTrue(request.is_mobile)
        self.assertFalse(request.is_tablet)
        self.assertFalse(request.is_desktop)
        self.assertEqual(response['X-Device-Type'], 'mobile')
from django.test import TestCase, Client
from django.urls import reverse
class MiddlewareIntegrationTests(TestCase):
    """Integration tests that send requests through the configured middleware stack."""
    def setUp(self):
        self.client = Client()
    def test_middleware_chain_execution(self):
        """Test that middleware chain executes correctly"""
        # X-Processing-Time and X-Request-ID are set by FullFeaturedMiddleware,
        # which the MIDDLEWARE setting does not list, so enable it for this test.
        with self.modify_settings(MIDDLEWARE={'append': 'myapp.middleware.FullFeaturedMiddleware'}):
            # A fresh client loads its middleware chain on its first request,
            # i.e. under the modified setting.
            response = Client().get('/test/')
        self.assertIn('X-Processing-Time', response)
        self.assertIn('X-Request-ID', response)
        self.assertIn('X-Device-Type', response)
    def test_middleware_with_authentication(self):
        """Test middleware behavior with authenticated users"""
        # Neither name is imported at the top of this test module.
        from django.contrib.auth.models import User
        from myapp.models import UserActivity
        user = User.objects.create_user(
            username='testuser',
            password='testpass123'
        )
        self.assertTrue(self.client.login(username='testuser', password='testpass123'))
        response = self.client.get('/dashboard/')
        # UserActivityMiddleware only records responses with status < 400.
        self.assertLess(response.status_code, 400)
        self.assertTrue(UserActivity.objects.filter(user=user).exists())
# settings.py
# Add custom middleware to MIDDLEWARE list
# Order matters: on the way in, requests pass through this list top to bottom;
# responses pass back bottom to top.
# NOTE(review): RequestLoggingMiddleware and RateLimitMiddleware sit before
# AuthenticationMiddleware. So request.user is not yet set when
# RequestLoggingMiddleware builds its REQUEST record, and its 'user' field
# falls back to 'Anonymous'. Confirm this placement is intended.
MIDDLEWARE = [
    'django.middleware.security.SecurityMiddleware',
    'myapp.middleware.RequestLoggingMiddleware',
    'myapp.middleware.RateLimitMiddleware',
    'django.contrib.sessions.middleware.SessionMiddleware',
    'django.middleware.common.CommonMiddleware',
    'django.middleware.csrf.CsrfViewMiddleware',
    'django.contrib.auth.middleware.AuthenticationMiddleware',
    'myapp.middleware.UserActivityMiddleware',
    'django.contrib.messages.middleware.MessageMiddleware',
    'myapp.middleware.MobileDetectionMiddleware',
    'django.middleware.clickjacking.XFrameOptionsMiddleware',
]
# Custom middleware configuration
# Read by ConditionalMiddleware, which is not listed in MIDDLEWARE above, so
# these values have no effect until it is added there.
# NOTE(review): '192.168.1.0/24' is a CIDR network. Confirm that the
# middleware's IP check matches networks and not only exact address strings.
CONDITIONAL_MIDDLEWARE_CONFIG = {
    'enabled': True,
    'debug': False,
    'allowed_ips': ['127.0.0.1', '192.168.1.0/24'],
    'excluded_paths': ['/static/', '/media/', '/health/']
}
# Logging configuration for middleware
LOGGING = {
    'version': 1,
    'disable_existing_loggers': False,
    'handlers': {
        'request_file': {
            'level': 'INFO',
            'class': 'logging.handlers.RotatingFileHandler',
            # Relative path: resolved against the process's working directory.
            # The handler does not create the 'logs/' directory. TODO confirm it
            # exists at startup, or build the path from the project base dir.
            'filename': 'logs/requests.log',
            'maxBytes': 10 * 1024 * 1024,  # 10MB
            'backupCount': 5,
        },
    },
    'loggers': {
        # Must match the name RequestLoggingMiddleware logs to.
        'request_logger': {
            'handlers': ['request_file'],
            'level': 'INFO',
            # Keep request records out of ancestor (e.g. root/console) handlers.
            'propagate': False,
        },
    },
}
Now that you know how to create custom middleware, let's explore middleware ordering and understand how the sequence of middleware affects your application's behavior.
Built-in Middleware
Django comes with several built-in middleware components that handle common web application concerns. Understanding these middleware components helps you leverage Django's capabilities and serves as examples for creating your own middleware.
Middleware Ordering
The order of middleware in Django's MIDDLEWARE setting is crucial for proper application behavior. This chapter explains how middleware ordering affects request/response processing and provides guidelines for optimal middleware arrangement.