Middleware can significantly impact application performance and introduce complex debugging challenges. This chapter covers techniques for optimizing middleware performance and debugging middleware-related issues.
import contextvars
import logging
import re
import time

from django.conf import settings
# Module-wide logger shared by the middleware classes in this chapter.
logger = logging.getLogger(name='middleware_performance')
class PerformanceMonitoringMiddleware:
    """Monitor and log total request time.

    Active only when ``settings.MIDDLEWARE_PERFORMANCE_MONITORING`` is truthy
    (read once, at construction). When active it:

    * resets ``request._middleware_timings`` to a fresh list and appends one
      ``view_and_middleware_chain`` entry timing the downstream chain;
    * logs the collected data (WARNING above ``slow_request_threshold``
      seconds, INFO otherwise);
    * sets ``X-Total-Time`` (seconds, 4 decimals) and ``X-Middleware-Count``
      (number of timing entries) on the response.
    """

    # Requests slower than this many seconds are logged at WARNING level.
    slow_request_threshold = 1.0

    def __init__(self, get_response):
        self.get_response = get_response
        self.enabled = getattr(settings, 'MIDDLEWARE_PERFORMANCE_MONITORING', False)

    def __call__(self, request):
        if not self.enabled:
            return self.get_response(request)

        # perf_counter is monotonic: unlike time.time() it cannot jump when the
        # system clock is adjusted, so measured durations are never negative.
        start_time = time.perf_counter()
        request._middleware_timings = []

        chain_start = time.perf_counter()
        response = self.get_response(request)
        chain_duration = time.perf_counter() - chain_start
        # Attribute is re-read here, so a list replaced downstream is the one used.
        request._middleware_timings.append({
            'name': 'view_and_middleware_chain',
            'duration': chain_duration,
        })

        total_time = time.perf_counter() - start_time
        self.log_performance_data(request, total_time)

        response['X-Total-Time'] = f"{total_time:.4f}"
        response['X-Middleware-Count'] = str(len(request._middleware_timings))
        return response

    def log_performance_data(self, request, total_time):
        """Log path, method, total time and collected middleware timings.

        Args:
            request: the request being measured.
            total_time: elapsed seconds for the whole request.
        """
        performance_data = {
            'path': request.path,
            'method': request.method,
            'total_time': round(total_time, 4),
            'middleware_timings': getattr(request, '_middleware_timings', []),
        }
        # Lazy %-formatting: the dict is only stringified if the record is emitted.
        if total_time > self.slow_request_threshold:
            logger.warning("Slow request detected: %s", performance_data)
        else:
            logger.info("Request performance: %s", performance_data)
class ProfilingMiddleware:
    """Time the downstream chain and record it under this class's name.

    Each call appends ``{'name', 'duration'}`` to
    ``request._middleware_timings`` (creating the list if missing), sets an
    ``X-<ClassName>-Time`` response header, and warns when the chain took
    longer than 100 ms.
    """

    def __init__(self, get_response):
        self.get_response = get_response
        # Resolved on the concrete class, so subclasses report their own name.
        self.middleware_name = type(self).__name__

    def __call__(self, request):
        began = time.perf_counter()
        response = self.get_response(request)
        elapsed = time.perf_counter() - began

        if not hasattr(request, '_middleware_timings'):
            request._middleware_timings = []
        request._middleware_timings.append({'name': self.middleware_name, 'duration': elapsed})

        response[f'X-{self.middleware_name}-Time'] = f"{elapsed:.6f}"

        if elapsed > 0.1:  # slower than 100 ms
            logger.warning(f"{self.middleware_name} took {elapsed:.4f}s for {request.path}")
        return response
# Use by inheriting from ProfilingMiddleware: the subclass's own class name is
# what appears in the timing header and in request._middleware_timings.
class MyCustomMiddleware(ProfilingMiddleware):
    def __call__(self, request):
        # Your middleware logic here; the parent call times the downstream chain.
        response = super().__call__(request)
        return response
from django.db import connection
from django.conf import settings
class DatabaseOptimizedMiddleware:
    """Attach cached per-user data to the request and warn on high query counts.

    For authenticated users, preferences/permissions/groups are cached under
    ``user_data:<id>`` for ``cache_timeout`` seconds and exposed as
    ``request.cached_user_data``. Requests without ``request.user`` (normally
    set by an earlier authentication middleware) or with an unauthenticated
    user get no cached data.
    """

    # Warn when the downstream chain issues more than this many queries (DEBUG only).
    query_warning_threshold = 10

    def __init__(self, get_response):
        self.get_response = get_response
        self.cache_timeout = 300  # seconds (5 minutes)

    def __call__(self, request):
        # Read DEBUG once so the before/after counts are taken under the same
        # setting; connection.queries is only populated when DEBUG is on.
        track_queries = settings.DEBUG
        initial_queries = len(connection.queries) if track_queries else 0

        user_data = self.get_cached_user_data(request)
        if user_data:
            request.cached_user_data = user_data

        response = self.get_response(request)

        if track_queries:
            # NOTE: the count spans the whole downstream chain (view included),
            # not only this middleware's own queries.
            query_count = len(connection.queries) - initial_queries
            if query_count > self.query_warning_threshold:
                logger.warning(
                    "High query count in middleware: %s queries for %s",
                    query_count, request.path,
                )
        return response

    def get_cached_user_data(self, request):
        """Return cached user data for the request's user, or None.

        Returns None when the request has no ``user`` attribute or the user is
        not authenticated. On a cache miss the data is built (this hits the
        database) and stored for ``cache_timeout`` seconds.
        """
        if not hasattr(request, 'user') or not request.user.is_authenticated:
            return None

        from django.core.cache import cache

        cache_key = f"user_data:{request.user.id}"
        user_data = cache.get(cache_key)
        if user_data is None:
            # Only hit the database on a cache miss.
            user_data = {
                'preferences': self.get_user_preferences(request.user),
                'permissions': list(request.user.get_all_permissions()),
                'groups': list(request.user.groups.values_list('name', flat=True)),
            }
            cache.set(cache_key, user_data, self.cache_timeout)
        return user_data

    def get_user_preferences(self, user):
        """Return ``user.profile.preferences``, or ``{}`` if unavailable.

        Accessing ``user.profile`` may itself run a query; no select_related
        is applied here. NOTE(review): a missing reverse one-to-one profile in
        Django raises an AttributeError subclass, which this catches — confirm
        for your profile relation.
        """
        try:
            return user.profile.preferences
        except AttributeError:
            return {}
class ConditionalPerformanceMiddleware:
    """Choose skip / lightweight / heavy processing based on the request path.

    Paths starting with a ``skip_paths`` prefix bypass the middleware entirely.
    Paths starting with a ``heavy_processing_paths`` prefix get full analysis
    (``request.analysis``). Everything else takes the lightweight path.
    """

    def __init__(self, get_response):
        self.get_response = get_response
        # Path prefixes; kept as lists so callers/subclasses can extend them.
        self.skip_paths = ['/static/', '/media/', '/favicon.ico']
        self.heavy_processing_paths = ['/api/reports/', '/admin/']

    def __call__(self, request):
        if self.should_skip_processing(request):
            return self.get_response(request)
        if self.requires_heavy_processing(request):
            return self.heavy_processing(request)
        return self.lightweight_processing(request)

    def should_skip_processing(self, request):
        """Return True if the path starts with any prefix in ``skip_paths``."""
        # str.startswith accepts a tuple of prefixes: one C-level call.
        return request.path.startswith(tuple(self.skip_paths))

    def requires_heavy_processing(self, request):
        """Return True if the path starts with any ``heavy_processing_paths`` prefix."""
        return request.path.startswith(tuple(self.heavy_processing_paths))

    def lightweight_processing(self, request):
        """Fast path: tag the request and response, nothing else."""
        request.processing_type = 'lightweight'
        response = self.get_response(request)
        response['X-Processing-Type'] = 'lightweight'
        return response

    def heavy_processing(self, request):
        """Full path: analyze the request before handing it downstream."""
        request.processing_type = 'heavy'
        self.analyze_request(request)
        response = self.get_response(request)
        response['X-Processing-Type'] = 'heavy'
        response['X-Analysis-Complete'] = 'true'
        return response

    def analyze_request(self, request):
        """Populate ``request.analysis`` using the three analysis hooks below."""
        request.analysis = {
            'user_agent_parsed': self.parse_user_agent(request),
            'geo_location': self.get_geo_location(request),
            'security_score': self.calculate_security_score(request),
        }

    # Analysis hooks. analyze_request() depends on all three, so cheap default
    # implementations are provided; override them in a subclass for real work.
    def parse_user_agent(self, request):
        """Return User-Agent information; default is the raw header (or '')."""
        return request.META.get('HTTP_USER_AGENT', '')

    def get_geo_location(self, request):
        """Return geo-location data; default performs no lookup (None)."""
        return None

    def calculate_security_score(self, request):
        """Return a security score; default computes nothing (None)."""
        return None
import functools
import traceback
from django.http import HttpResponse
def debug_middleware(middleware_class):
    """Class decorator that adds stdout tracing to a middleware class.

    Returns a subclass of *middleware_class* that prints a line on
    construction, before and after each request, and on exceptions. When the
    wrapped call raises, the traceback is printed. With ``settings.DEBUG`` on,
    a plain-text 500 ``HttpResponse`` is returned; otherwise the exception is
    re-raised.

    The returned class copies ``__module__``, ``__name__``, ``__qualname__``
    and ``__doc__`` from *middleware_class* (and sets ``__wrapped__``).
    Anything reading ``type(self).__name__`` — e.g. a base class that names
    headers or log lines after the class — therefore still sees the original
    name rather than ``DebuggedMiddleware``.
    """

    class DebuggedMiddleware(middleware_class):
        def __init__(self, get_response):
            self.middleware_name = middleware_class.__name__
            print(f"[DEBUG] Initializing {self.middleware_name}")
            super().__init__(get_response)

        def __call__(self, request):
            print(f"[DEBUG] {self.middleware_name}: Processing request to {request.path}")
            try:
                # Call original middleware
                response = super().__call__(request)
                print(f"[DEBUG] {self.middleware_name}: Successfully processed request")
                return response
            except Exception as e:
                print(f"[ERROR] {self.middleware_name}: Exception occurred: {e}")
                print(f"[ERROR] Traceback: {traceback.format_exc()}")
                # Show the error inline only during development.
                if settings.DEBUG:
                    return HttpResponse(
                        f"Middleware Error in {self.middleware_name}: {e}",
                        status=500
                    )
                raise

    # Copy identity metadata onto the subclass. updated=() because a class
    # __dict__ is a read-only mappingproxy and cannot be .update()d.
    functools.update_wrapper(
        DebuggedMiddleware,
        middleware_class,
        assigned=('__module__', '__name__', '__qualname__', '__doc__'),
        updated=(),
    )
    return DebuggedMiddleware
# Usage:
@debug_middleware
class MyCustomMiddleware:
    """Minimal pass-through middleware, wrapped for debug tracing."""

    def __init__(self, get_response):
        self.get_response = get_response

    def __call__(self, request):
        # Your middleware logic goes here.
        return self.get_response(request)
class RequestStateInspector:
    """Log how the request changed while passing through the downstream chain.

    Snapshots the request before and after ``get_response`` and logs added and
    removed request attributes, added and removed session keys, and changes to
    ``str(request.user)``.

    NOTE(review): ``str(request.user)`` and ``request.session.keys()`` may force
    a lazily-loaded user/session to load on every request; intended as a
    debugging aid.
    """

    def __init__(self, get_response):
        self.get_response = get_response

    def __call__(self, request):
        initial_state = self.capture_request_state(request)
        response = self.get_response(request)
        final_state = self.capture_request_state(request)
        self.log_state_changes(request, initial_state, final_state)
        return response

    def capture_request_state(self, request):
        """Return a snapshot dict of the request's inspectable state.

        Keys: ``attributes``, ``session_keys``, ``user`` (string form),
        ``META_keys`` and ``custom_attributes``.
        """
        state = {
            'attributes': list(vars(request).keys()),
            'session_keys': list(request.session.keys()) if hasattr(request, 'session') else [],
            'user': str(getattr(request, 'user', 'No user attribute')),
            'META_keys': list(request.META.keys()),
        }
        # Attributes that are private or all-lowercase (typical for middleware-added ones).
        custom_attrs = [attr for attr in vars(request) if attr.startswith('_') or attr.islower()]
        state['custom_attributes'] = custom_attrs
        return state

    def log_state_changes(self, request, initial_state, final_state):
        """Log differences between two snapshots from ``capture_request_state``."""
        attrs_before = set(initial_state['attributes'])
        attrs_after = set(final_state['attributes'])
        added_attrs = attrs_after - attrs_before
        if added_attrs:
            logger.info("New request attributes added: %s", added_attrs)
        removed_attrs = attrs_before - attrs_after
        if removed_attrs:
            logger.info("Request attributes removed: %s", removed_attrs)

        keys_before = set(initial_state['session_keys'])
        keys_after = set(final_state['session_keys'])
        added_keys = keys_after - keys_before
        if added_keys:
            logger.info("New session keys added: %s", added_keys)
        removed_keys = keys_before - keys_after
        if removed_keys:
            logger.info("Session keys removed: %s", removed_keys)

        if initial_state['user'] != final_state['user']:
            logger.info("User changed from %s to %s", initial_state['user'], final_state['user'])
class MiddlewareChainTracer:
    """Trace enter/exit/error events of every tracing middleware in the chain.

    All tracers share ``request._middleware_trace`` (created by the first one
    to see the request), holding the ordered ``chain`` of events, per-class
    ``timings`` in seconds, and ``errors``. The complete trace is logged once,
    when the outermost tracer exits; errors are logged where they pass through.
    """

    def __init__(self, get_response):
        self.get_response = get_response
        self.middleware_name = self.__class__.__name__

    def __call__(self, request):
        if not hasattr(request, '_middleware_trace'):
            request._middleware_trace = {
                'chain': [],
                'timings': {},
                'errors': [],
            }
        trace = request._middleware_trace

        # perf_counter is monotonic, so durations cannot go negative.
        entry_time = time.perf_counter()
        trace['chain'].append(f"{self.middleware_name}:enter")
        try:
            response = self.get_response(request)
        except Exception as e:
            trace['errors'].append(f"{self.middleware_name}: {e}")
            trace['chain'].append(f"{self.middleware_name}:error")
            self.log_error_trace(request, e)
            raise

        trace['chain'].append(f"{self.middleware_name}:exit")
        trace['timings'][self.middleware_name] = time.perf_counter() - entry_time
        if self.is_final_middleware(request):
            self.log_complete_trace(request)
        return response

    def is_final_middleware(self, request):
        """Return True when every ``:enter`` event has been closed.

        Tracers nest (A enters, B enters, B exits, A exits), so the trace is
        complete only once the number of closing events (``:exit`` or
        ``:error``) equals the number of ``:enter`` events, i.e. when the
        outermost tracer is exiting. Counting ``:error`` as a close keeps this
        correct when an inner error is turned into a response upstream.
        """
        opened = closed = 0
        for event in request._middleware_trace['chain']:
            if event.endswith(':enter'):
                opened += 1
            else:
                closed += 1
        return opened == closed

    def log_complete_trace(self, request):
        """Log the complete middleware execution trace."""
        trace = request._middleware_trace
        logger.info(f"Middleware trace for {request.path}:")
        logger.info(f"  Chain: {' → '.join(trace['chain'])}")
        logger.info(f"  Timings: {trace['timings']}")
        if trace['errors']:
            logger.error(f"  Errors: {trace['errors']}")

    def log_error_trace(self, request, exception):
        """Log the trace collected up to the point of *exception*."""
        trace = request._middleware_trace
        logger.error(f"Middleware error in {self.middleware_name} for {request.path}:")
        logger.error(f"  Exception: {exception}")
        logger.error(f"  Chain so far: {' → '.join(trace['chain'])}")
        logger.error(f"  Timings: {trace['timings']}")
# Base class for traceable middleware
class TraceableMiddleware(MiddlewareChainTracer):
    """Base class that adds tracing to any middleware.

    Subclasses override ``__call__`` and delegate to ``super().__call__``.
    """
# Usage:
class MyMiddleware(TraceableMiddleware):
    """Example middleware whose calls are recorded by the tracing base class."""

    def __call__(self, request):
        # Your middleware logic (before the rest of the chain runs)
        result = super().__call__(request)
        # More logic (after the response comes back)
        return result
class BottleneckDetector:
    """Detect slow downstream chains and report them on the request/response.

    The first detector to see a request creates ``request._performance_data``
    (``start_time``, ``middleware_times``, ``bottlenecks``); later detectors
    reuse it. Each detector records how long its downstream chain took and
    flags it as ``warning`` (> ``slow_threshold`` s) or ``critical``
    (> ``very_slow_threshold`` s). The response gets ``X-Total-Time`` and, when
    any bottleneck has been recorded, ``X-Bottlenecks``.
    """

    def __init__(self, get_response):
        self.get_response = get_response
        self.slow_threshold = 0.1  # seconds (100 ms)
        self.very_slow_threshold = 0.5  # seconds (500 ms)

    def __call__(self, request):
        if not hasattr(request, '_performance_data'):
            request._performance_data = {
                # Wall-clock start, shared by all detectors for X-Total-Time.
                'start_time': time.time(),
                'middleware_times': [],
                'bottlenecks': [],
            }
        perf = request._performance_data
        name = self.__class__.__name__

        # Monotonic clock for the per-middleware duration.
        started = time.perf_counter()
        response = self.get_response(request)
        duration = time.perf_counter() - started

        perf['middleware_times'].append({
            'middleware': name,
            'duration': duration,
        })

        severity = self._classify(duration)
        if severity is not None:
            perf['bottlenecks'].append({
                'middleware': name,
                'duration': duration,
                'severity': severity,
            })
            if severity == 'critical':
                logger.critical("Critical bottleneck detected: %s took %.4fs for %s",
                                name, duration, request.path)
            else:
                logger.warning("Slow middleware detected: %s took %.4fs for %s",
                               name, duration, request.path)

        # _performance_data always exists here (created above if missing).
        total_time = time.time() - perf['start_time']
        response['X-Total-Time'] = f"{total_time:.4f}"
        if perf['bottlenecks']:
            response['X-Bottlenecks'] = str(len(perf['bottlenecks']))
        return response

    def _classify(self, duration):
        """Return 'critical', 'warning', or None for *duration* seconds."""
        if duration > self.very_slow_threshold:
            return 'critical'
        if duration > self.slow_threshold:
            return 'warning'
        return None
import psutil
import os
class MemoryMonitoringMiddleware:
    """Report the process RSS change observed across each request.

    Caveats: RSS is measured for the whole process, so under a threaded server
    the delta attributed to one request can include allocations made by
    concurrent requests. The ``X-Memory-*`` headers also expose server memory
    figures to every client; consider disabling them outside development.
    """

    # Log a warning when RSS grows by more than this many MB during one request.
    memory_warning_threshold_mb = 10

    def __init__(self, get_response):
        self.get_response = get_response
        self.process = psutil.Process(os.getpid())

    def _rss_mb(self):
        """Current resident set size of this process, in MB (1024 * 1024 bytes)."""
        return self.process.memory_info().rss / 1024 / 1024

    def __call__(self, request):
        initial_memory = self._rss_mb()
        response = self.get_response(request)
        final_memory = self._rss_mb()
        memory_delta = final_memory - initial_memory

        if memory_delta > self.memory_warning_threshold_mb:
            logger.warning("High memory usage increase: %.2fMB for %s", memory_delta, request.path)

        response['X-Memory-Usage'] = f"{final_memory:.2f}MB"
        response['X-Memory-Delta'] = f"{memory_delta:.2f}MB"
        return response
import gc
import threading
class ResourceLeakDetector:
    """Periodically log signs of resource leaks (object, thread and gc counts).

    Every ``check_interval`` requests the check runs inline, before that
    request is handled, so that request pays for a full ``gc.get_objects()``
    scan plus a ``gc.collect()``.
    NOTE(review): ``request_count`` is not synchronized; under a threaded
    server the check cadence is approximate.
    """

    # Warning thresholds; override on a subclass or instance as needed.
    object_count_threshold = 10000
    thread_count_threshold = 50
    gc_log_threshold = 1000

    def __init__(self, get_response):
        self.get_response = get_response
        self.check_interval = 100  # run check_resources() every N requests
        self.request_count = 0

    def __call__(self, request):
        self.request_count += 1
        if self.request_count % self.check_interval == 0:
            self.check_resources()
        return self.get_response(request)

    def check_resources(self):
        """Log high per-type object counts and thread count, then force a gc pass."""
        from collections import Counter

        object_counts = Counter(type(obj).__name__ for obj in gc.get_objects())
        for obj_type, count in object_counts.items():
            if count > self.object_count_threshold:
                logger.warning("High object count: %s instances of %s", count, obj_type)

        thread_count = threading.active_count()
        if thread_count > self.thread_count_threshold:
            logger.warning("High thread count: %s active threads", thread_count)

        collected = gc.collect()
        if collected > self.gc_log_threshold:
            logger.info("Garbage collection freed %s objects", collected)
import uuid
class CorrelationIDMiddleware:
    """Tag each request, its log records and its response with a correlation ID.

    The ID is taken from the incoming ``X-Correlation-ID`` header or generated
    as a UUID4. It is stored on ``request.correlation_id`` and echoed in the
    ``X-Correlation-ID`` response header. Log records created while the request
    is being processed get a ``correlation_id`` attribute.

    The log-record factory is process-global, so it is wrapped only once. The
    per-request value lives in a ContextVar, which is isolated per thread and
    per asyncio task. Swapping the global factory per request would instead let
    concurrent requests overwrite each other's IDs.
    """

    # Per-request correlation ID; each thread / asyncio task sees its own value.
    _correlation_id_var = contextvars.ContextVar('correlation_id', default=None)
    # Set once the global LogRecord factory has been wrapped.
    _factory_installed = False

    def __init__(self, get_response):
        self.get_response = get_response
        self._install_log_record_factory()

    @staticmethod
    def _install_log_record_factory():
        """Wrap the current global LogRecord factory, at most once per process."""
        if CorrelationIDMiddleware._factory_installed:
            return
        previous_factory = logging.getLogRecordFactory()
        correlation_var = CorrelationIDMiddleware._correlation_id_var

        def record_factory(*args, **kwargs):
            record = previous_factory(*args, **kwargs)
            correlation_id = correlation_var.get()
            # Only tag records produced while a request is in flight.
            if correlation_id is not None:
                record.correlation_id = correlation_id
            return record

        logging.setLogRecordFactory(record_factory)
        CorrelationIDMiddleware._factory_installed = True

    def __call__(self, request):
        correlation_id = (
            request.META.get('HTTP_X_CORRELATION_ID') or
            str(uuid.uuid4())
        )
        request.correlation_id = correlation_id

        token = self._correlation_id_var.set(correlation_id)
        try:
            response = self.get_response(request)
            response['X-Correlation-ID'] = correlation_id
            return response
        finally:
            # Restore the previous value for this context only.
            self._correlation_id_var.reset(token)
class ErrorContextMiddleware:
"""Collect detailed context information for errors"""
def __init__(self, get_response):
self.get_response = get_response
def __call__(self, request):
try:
response = self.get_response(request)
return response
except Exception as e:
# Collect error context
error_context = self.collect_error_context(request, e)
# Log detailed error information
logger.error(f"Middleware error context: {error_context}", exc_info=True)
# Store context for error reporting
if hasattr(request, 'sentry') or hasattr(request, 'rollbar'):
# Add context to error reporting service
pass
# Re-raise the exception
raise
def collect_error_context(self, request, exception):
"""Collect comprehensive error context"""
context = {
'request': {
'path': request.path,
'method': request.method,
'GET': dict(request.GET),
'POST': dict(request.POST) if request.method == 'POST' else {},
'user': str(getattr(request, 'user', 'Anonymous')),
'session_key': getattr(request.session, 'session_key', None) if hasattr(request, 'session') else None,
},
'exception': {
'type': type(exception).__name__,
'message': str(exception),
'module': getattr(exception, '__module__', None),
},
'middleware': {
'chain': getattr(request, '_middleware_trace', {}).get('chain', []),
'timings': getattr(request, '_middleware_trace', {}).get('timings', {}),
},
'server': {
'time': time.time(),
'process_id': os.getpid(),
'thread_id': threading.get_ident(),
}
}
return context
from django.test import TestCase, RequestFactory
from django.test.utils import override_settings
import time
class MiddlewarePerformanceTests(TestCase):
    """Coarse performance checks for ``MyCustomMiddleware``.

    NOTE: ``MyCustomMiddleware`` is defined more than once in this module;
    these tests use whichever definition is bound last. The timing bounds are
    wall-clock based and may be flaky on heavily loaded CI machines.
    """

    def setUp(self):
        self.factory = RequestFactory()

    def test_middleware_performance_under_load(self):
        """100 sequential requests stay within average/maximum time budgets."""
        def get_response(request):
            time.sleep(0.01)  # Simulate ~10 ms of view processing
            return HttpResponse("OK")

        middleware = MyCustomMiddleware(get_response)

        times = []
        for i in range(100):
            request = self.factory.get(f'/test/{i}/')
            # perf_counter is monotonic, so measurements cannot go negative.
            start_time = time.perf_counter()
            response = middleware(request)
            times.append(time.perf_counter() - start_time)
            # A fast but broken middleware must not pass this test.
            self.assertEqual(response.status_code, 200)

        avg_time = sum(times) / len(times)
        max_time = max(times)
        self.assertLess(avg_time, 0.05, "Average response time too high")
        self.assertLess(max_time, 0.1, "Maximum response time too high")

    def test_middleware_memory_usage(self):
        """1000 requests grow process RSS by less than 50 MB."""
        import psutil
        import os

        process = psutil.Process(os.getpid())
        initial_memory = process.memory_info().rss

        def get_response(request):
            return HttpResponse("OK")

        middleware = MyCustomMiddleware(get_response)
        for i in range(1000):
            request = self.factory.get(f'/test/{i}/')
            response = middleware(request)
            self.assertEqual(response.status_code, 200)

        memory_increase = process.memory_info().rss - initial_memory
        self.assertLess(memory_increase, 50 * 1024 * 1024, "Memory usage increased by more than 50MB")
By following these performance optimization and debugging techniques, you can ensure your middleware operates efficiently and provides valuable insights when issues occur in production environments.
Middleware Ordering
The order of middleware in Django's MIDDLEWARE setting is crucial for proper application behavior. This chapter explains how middleware ordering affects request/response processing and provides guidelines for optimal middleware arrangement.
Security
Security is a fundamental aspect of web application development, and Django provides robust built-in protections against common web vulnerabilities. This comprehensive guide covers Django's security features and best practices for building secure applications.