Middleware

Creating Custom Middleware

Custom middleware allows you to implement application-specific functionality that runs for every request. This chapter covers designing, implementing, and testing custom middleware for various use cases.

Basic Custom Middleware Structure

Simple Middleware Template

# myapp/middleware.py
import time
import logging
# NOTE(review): MiddlewareMixin is not used by the middleware classes shown in
# this snippet; confirm whether it is still needed.
from django.utils.deprecation import MiddlewareMixin

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

class BasicCustomMiddleware:
    """Template for custom middleware.

    Times the rest of the middleware chain plus the view and reports the
    elapsed seconds in an ``X-Processing-Time`` response header.
    """

    def __init__(self, get_response):
        """
        One-time configuration and initialization.

        Args:
            get_response: The next middleware (or the view) in the chain.
        """
        self.get_response = get_response
        # One-time initialization code here
        logger.info("BasicCustomMiddleware initialized")

    def __call__(self, request):
        """
        Handle one request: time the downstream chain and annotate the response.

        Args:
            request: The incoming request.

        Returns:
            The downstream response with ``X-Processing-Time`` (seconds,
            4 decimal places) set.
        """
        # perf_counter() is monotonic, so the measured duration cannot go
        # negative or jump if the wall clock is adjusted mid-request.
        start_time = time.perf_counter()

        # Call the next middleware or view
        response = self.get_response(request)

        processing_time = time.perf_counter() - start_time

        # Add custom header
        response['X-Processing-Time'] = f"{processing_time:.4f}"

        return response

Middleware with All Hooks

class FullFeaturedMiddleware:
    """Middleware demonstrating all available hooks.

    ``__call__`` drives ``process_request`` and ``process_response`` itself;
    ``process_view``, ``process_exception`` and ``process_template_response``
    are the hooks Django calls at the corresponding points of request handling.
    """

    def __init__(self, get_response):
        self.get_response = get_response

    def __call__(self, request):
        # This code is executed before the view
        self.process_request(request)

        response = self.get_response(request)

        # This code is executed after the view
        return self.process_response(request, response)

    def process_request(self, request):
        """Process the request before it reaches the view.

        Attaches ``custom_timestamp`` (wall-clock seconds since the epoch) and
        ``custom_id`` (a short request ID) to *request*.
        """
        # Add custom attributes to request
        request.custom_timestamp = time.time()
        # Separate monotonic start mark used only for measuring duration, so
        # custom_timestamp can stay a real wall-clock timestamp.
        request._custom_perf_start = time.perf_counter()
        request.custom_id = self.generate_request_id()

        logger.info("Processing request %s to %s", request.custom_id, request.path)

    def process_view(self, request, view_func, view_args, view_kwargs):
        """Called just before Django calls the view.

        Returns:
            ``None`` so processing continues to the view; returning an
            HttpResponse instead would short-circuit view execution.
        """
        # Not every callable has __name__ (e.g. functools.partial objects),
        # so fall back to its repr rather than raising AttributeError.
        view_name = getattr(view_func, '__name__', repr(view_func))
        logger.info("About to call view: %s", view_name)
        return None

    def process_exception(self, request, exception):
        """Called when a view raises an exception.

        Returns:
            ``None`` so normal exception handling continues; returning an
            HttpResponse instead would handle the exception.
        """
        logger.error(
            "Exception in request %s: %s",
            getattr(request, 'custom_id', 'unknown'),
            exception,
        )
        return None

    def process_template_response(self, request, response):
        """Called for responses with a render() method.

        Adds ``request_id`` to the response's template context.
        """
        if hasattr(response, 'context_data'):
            # context_data may be None (e.g. a TemplateResponse created
            # without a context); item assignment on None would raise.
            if response.context_data is None:
                response.context_data = {}
            response.context_data['request_id'] = getattr(request, 'custom_id', None)

        return response

    def process_response(self, request, response):
        """Process the response before it's sent to the client.

        Sets ``X-Request-ID`` and, when timing data is available on the
        request, ``X-Processing-Time`` in seconds (4 decimal places).
        """
        response['X-Request-ID'] = getattr(request, 'custom_id', 'unknown')

        # Prefer the monotonic start mark; fall back to the wall-clock
        # timestamp for requests annotated without it.
        perf_start = getattr(request, '_custom_perf_start', None)
        if perf_start is not None:
            processing_time = time.perf_counter() - perf_start
            response['X-Processing-Time'] = f"{processing_time:.4f}"
        elif hasattr(request, 'custom_timestamp'):
            processing_time = time.time() - request.custom_timestamp
            response['X-Processing-Time'] = f"{processing_time:.4f}"

        return response

    def generate_request_id(self):
        """Generate a short request ID: the first 8 characters of a random UUID4."""
        import uuid
        return str(uuid.uuid4())[:8]

Practical Middleware Examples

Request Logging Middleware

import json
import logging
from django.utils import timezone

# NOTE(review): if this snippet lives in the same module as the earlier one,
# this rebinds `logger`, so every class defined after this point logs to
# 'request_logger' instead of the module-named logger; confirm that is intended.
logger = logging.getLogger('request_logger')

class RequestLoggingMiddleware:
    """Comprehensive request logging middleware.

    Emits one ``REQUEST: <json>`` and one ``RESPONSE: <json>`` INFO record per
    request on the module-level ``logger``.
    """

    def __init__(self, get_response):
        self.get_response = get_response
        # META keys whose values are never copied into the log: credentials
        # and tokens would otherwise end up in plain-text log files.
        self.sensitive_headers = [
            'HTTP_AUTHORIZATION',
            'HTTP_COOKIE',
            'HTTP_X_API_KEY',
            'HTTP_PROXY_AUTHORIZATION',
            'HTTP_X_CSRFTOKEN',
        ]

    def __call__(self, request):
        # Log request. This runs before get_response(), so attributes set by
        # later middleware (e.g. request.user) may not exist yet.
        request_data = self.extract_request_data(request)
        logger.info("REQUEST: %s", json.dumps(request_data))

        # perf_counter() is monotonic, so durations are immune to clock changes.
        start_time = time.perf_counter()
        response = self.get_response(request)
        processing_time = time.perf_counter() - start_time

        # Log response
        response_data = self.extract_response_data(request, response, processing_time)
        logger.info("RESPONSE: %s", json.dumps(response_data))

        return response

    def extract_request_data(self, request):
        """Extract relevant request data for logging (JSON-serialisable dict)."""
        return {
            'timestamp': timezone.now().isoformat(),
            'method': request.method,
            'path': request.path,
            'query_params': dict(request.GET),
            'user': str(request.user) if hasattr(request, 'user') else 'Anonymous',
            'ip_address': self.get_client_ip(request),
            'user_agent': request.META.get('HTTP_USER_AGENT', ''),
            'content_type': request.META.get('CONTENT_TYPE', ''),
            'content_length': request.META.get('CONTENT_LENGTH', 0),
            'headers': self.get_safe_headers(request)
        }

    def extract_response_data(self, request, response, processing_time):
        """Extract relevant response data for logging.

        ``content_length`` is 0 for responses without a ``content``
        attribute (such as streaming responses).
        """
        return {
            'timestamp': timezone.now().isoformat(),
            'status_code': response.status_code,
            'content_type': response.get('Content-Type', ''),
            'content_length': len(response.content) if hasattr(response, 'content') else 0,
            'processing_time': round(processing_time, 4),
            'path': request.path
        }

    def get_client_ip(self, request):
        """Return the client IP, preferring the first X-Forwarded-For entry.

        NOTE: X-Forwarded-For is client-controlled and can be forged unless a
        trusted reverse proxy overwrites it; treat the value as informational.
        """
        x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
        if x_forwarded_for:
            # Entries are comma-separated and may carry surrounding spaces.
            return x_forwarded_for.split(',')[0].strip()
        return request.META.get('REMOTE_ADDR')

    def get_safe_headers(self, request):
        """Return HTTP headers as ``{'Header-Name': value}``, minus sensitive ones."""
        return {
            # 'HTTP_ACCEPT_LANGUAGE' -> 'Accept-Language'
            key[5:].replace('_', '-').title(): value
            for key, value in request.META.items()
            if key.startswith('HTTP_') and key not in self.sensitive_headers
        }

API Rate Limiting Middleware

from django.core.cache import cache
from django.http import HttpResponse
import json

class RateLimitMiddleware:
    """Rate limiting middleware for API endpoints.

    ``rate_limits`` maps a path prefix to a fixed-window limit
    (``requests`` allowed per ``window`` seconds). When several prefixes
    match a path, the longest (most specific) one applies, and every prefix
    keeps its own per-client counter in the Django cache.
    """

    def __init__(self, get_response):
        self.get_response = get_response
        self.rate_limits = {
            '/api/': {'requests': 100, 'window': 3600},  # 100 requests per hour
            '/api/auth/': {'requests': 10, 'window': 300},  # 10 requests per 5 minutes
        }

    def __call__(self, request):
        # Paths outside every configured prefix pass straight through.
        if not self.should_rate_limit(request):
            return self.get_response(request)

        if self.is_rate_limited(request):
            return self.rate_limit_response(request)

        response = self.get_response(request)
        self.add_rate_limit_headers(request, response)
        return response

    def _matching_prefix(self, path):
        """Return the longest configured prefix that *path* starts with, or None."""
        matches = [prefix for prefix in self.rate_limits if path.startswith(prefix)]
        return max(matches, key=len) if matches else None

    def _cache_key(self, request, prefix):
        """Build the counter key for this client and rate-limit prefix."""
        return f"rate_limit:{self.get_client_ip(request)}:{prefix}"

    def should_rate_limit(self, request):
        """Determine if request should be rate limited"""
        return self._matching_prefix(request.path) is not None

    def get_rate_limit_config(self, request):
        """Get the rate limit configuration for request (most specific prefix wins)."""
        prefix = self._matching_prefix(request.path)
        return self.rate_limits[prefix] if prefix is not None else None

    def is_rate_limited(self, request):
        """Count this request and report whether it exceeds the limit.

        Returns:
            True once the client has used up the current window's quota.
        """
        prefix = self._matching_prefix(request.path)
        if prefix is None:
            return False
        config = self.rate_limits[prefix]
        cache_key = self._cache_key(request, prefix)

        # add() only writes when the key is absent, so the window's expiry is
        # fixed by the first request instead of being pushed back by every
        # later one, which a get()+set() pair would do. incr() is atomic on
        # cache backends that support atomic increments, closing the
        # read-modify-write race between concurrent requests.
        cache.add(cache_key, 0, config['window'])
        try:
            current_requests = cache.incr(cache_key)
        except ValueError:
            # The key expired between add() and incr(): start a new window.
            cache.set(cache_key, 1, config['window'])
            current_requests = 1

        return current_requests > config['requests']

    def rate_limit_response(self, request):
        """Return a 429 Too Many Requests JSON response."""
        config = self.get_rate_limit_config(request)

        response_data = {
            'error': 'Rate limit exceeded',
            'message': f"Maximum {config['requests']} requests per {config['window']} seconds",
            'retry_after': config['window']
        }

        # Django ships no HttpResponseTooManyRequests class; set 429 explicitly.
        response = HttpResponse(
            json.dumps(response_data),
            content_type='application/json',
            status=429,
        )
        response['Retry-After'] = str(config['window'])

        return response

    def add_rate_limit_headers(self, request, response):
        """Add rate limit information to response headers"""
        prefix = self._matching_prefix(request.path)
        if prefix is None:
            return
        config = self.rate_limits[prefix]
        current_requests = cache.get(self._cache_key(request, prefix), 0)

        response['X-RateLimit-Limit'] = str(config['requests'])
        response['X-RateLimit-Remaining'] = str(max(0, config['requests'] - current_requests))
        # NOTE: this is the window length in seconds, not the moment the
        # current window resets.
        response['X-RateLimit-Reset'] = str(config['window'])

    def get_client_ip(self, request):
        """Return the client IP, preferring the first X-Forwarded-For entry.

        SECURITY NOTE: X-Forwarded-For is client-controlled; unless a trusted
        reverse proxy overwrites it, a client can rotate it to evade limits.
        """
        x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
        if x_forwarded_for:
            return x_forwarded_for.split(',')[0].strip()
        return request.META.get('REMOTE_ADDR')

User Activity Tracking Middleware

from django.contrib.auth.models import AnonymousUser
from django.utils import timezone
from .models import UserActivity

class UserActivityMiddleware:
    """Track user activity and update last seen timestamp.

    After a response with status < 400 to a tracked path by an authenticated
    user, writes one ``UserActivity`` row. It also updates the user's
    ``last_activity`` when the user object has that attribute.
    """

    def __init__(self, get_response):
        self.get_response = get_response
        self.tracked_paths = ['/dashboard/', '/profile/', '/api/']
        self.excluded_paths = ['/static/', '/media/', '/favicon.ico']

    def __call__(self, request):
        response = self.get_response(request)

        # Track activity after successful response
        if response.status_code < 400 and self.should_track_activity(request):
            self.track_user_activity(request, response)

        return response

    def should_track_activity(self, request):
        """Determine if activity should be tracked for *request*."""
        # Don't track static files
        if any(request.path.startswith(path) for path in self.excluded_paths):
            return False

        # Only track authenticated users. is_authenticated is the supported
        # check and also covers user-like objects that do not subclass
        # AnonymousUser.
        user = getattr(request, 'user', None)
        if user is None or not user.is_authenticated:
            return False

        # Only track specific paths, or everything if no paths are configured
        if self.tracked_paths:
            return any(request.path.startswith(path) for path in self.tracked_paths)

        return True

    def track_user_activity(self, request, response):
        """Record user activity.

        Best effort: any exception is logged with its traceback and swallowed,
        so tracking can never break the response.
        """
        try:
            # One timestamp for both the activity row and last_activity.
            now = timezone.now()
            UserActivity.objects.create(
                user=request.user,
                path=request.path,
                method=request.method,
                ip_address=self.get_client_ip(request),
                user_agent=request.META.get('HTTP_USER_AGENT', ''),
                timestamp=now,
                status_code=response.status_code,
            )

            # The stock django.contrib.auth User has no last_activity field;
            # only update user models that define it instead of failing (and
            # logging an error) on every tracked request.
            if hasattr(request.user, 'last_activity'):
                request.user.last_activity = now
                request.user.save(update_fields=['last_activity'])

        except Exception as e:
            # Log error but don't break the request
            logger.exception("Error tracking user activity: %s", e)

    def get_client_ip(self, request):
        """Return the client IP, preferring the first X-Forwarded-For entry.

        NOTE: X-Forwarded-For is client-controlled and can be forged unless a
        trusted reverse proxy overwrites it.
        """
        x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
        if x_forwarded_for:
            return x_forwarded_for.split(',')[0].strip()
        return request.META.get('REMOTE_ADDR')

# models.py
class UserActivity(models.Model):
    """Model to store user activity.

    One row per request recorded by ``UserActivityMiddleware``.
    """
    # Rows are deleted together with their user (CASCADE).
    user = models.ForeignKey(User, on_delete=models.CASCADE)
    # NOTE(review): paths longer than 255 characters may be rejected by the
    # database; confirm whether callers should truncate.
    path = models.CharField(max_length=255)
    # HTTP method, e.g. "GET".
    method = models.CharField(max_length=10)
    # Not nullable, so a request whose client IP cannot be determined
    # cannot be stored.
    ip_address = models.GenericIPAddressField()
    user_agent = models.TextField()
    timestamp = models.DateTimeField()
    # HTTP status code of the tracked response.
    status_code = models.IntegerField()
    
    class Meta:
        # Newest activity first by default.
        ordering = ['-timestamp']
        indexes = [
            # Per-user lookups filtered or ordered by time.
            models.Index(fields=['user', 'timestamp']),
            # Time-range queries across all users.
            models.Index(fields=['timestamp']),
        ]

Content Security Policy Middleware

class CSPMiddleware:
    """Content Security Policy middleware.

    Sets a ``Content-Security-Policy`` header on every response. The policy
    starts from ``default_policy`` and is adjusted for ``/admin/`` and
    ``/api/`` paths.
    """

    def __init__(self, get_response):
        self.get_response = get_response
        self.default_policy = {
            'default-src': ["'self'"],
            'script-src': ["'self'", "'unsafe-inline'"],
            'style-src': ["'self'", "'unsafe-inline'"],
            'img-src': ["'self'", "data:", "https:"],
            'font-src': ["'self'"],
            'connect-src': ["'self'"],
            'frame-src': ["'none'"],
            'object-src': ["'none'"],
            'base-uri': ["'self'"],
            'form-action': ["'self'"]
        }

    def __call__(self, request):
        response = self.get_response(request)

        # Add CSP header
        response['Content-Security-Policy'] = self.build_csp_header(request)

        return response

    def build_csp_header(self, request):
        """Build the CSP header value for *request*.

        Returns:
            Directives joined as ``"<directive> <source> <source>; ..."``.
        """
        # Copy every source list as well: dict.copy() is shallow, so appending
        # to policy['script-src'] would mutate default_policy itself. That
        # leaked 'unsafe-eval' into all later responses and grew the list on
        # every admin request.
        policy = {directive: list(sources) for directive, sources in self.default_policy.items()}

        # Customize policy based on request path
        if request.path.startswith('/admin/'):
            # More permissive policy for admin
            policy['script-src'].append("'unsafe-eval'")
        elif request.path.startswith('/api/'):
            # Strict policy for API
            policy['script-src'] = ["'none'"]
            policy['style-src'] = ["'none'"]

        return '; '.join(
            f"{directive} {' '.join(sources)}" for directive, sources in policy.items()
        )

Mobile Detection Middleware

import re

class MobileDetectionMiddleware:
    """Detect mobile devices and add context.

    Sets ``is_mobile``, ``is_tablet``, ``is_desktop`` and ``device_info`` on
    each request, and an ``X-Device-Type`` header on each response. The two
    pattern lists overlap ('iPad' appears in both), so a tablet can have
    ``is_mobile`` and ``is_tablet`` both True. The tablet classification
    therefore takes precedence when deriving device type and screen size.
    """

    def __init__(self, get_response):
        self.get_response = get_response

        # Mobile user agent patterns
        self.mobile_patterns = [
            re.compile(r'Mobile', re.IGNORECASE),
            re.compile(r'Android', re.IGNORECASE),
            re.compile(r'iPhone', re.IGNORECASE),
            re.compile(r'iPad', re.IGNORECASE),
            re.compile(r'Windows Phone', re.IGNORECASE),
            re.compile(r'BlackBerry', re.IGNORECASE),
        ]

        # Tablet patterns
        self.tablet_patterns = [
            re.compile(r'iPad', re.IGNORECASE),
            re.compile(r'Android.*Tablet', re.IGNORECASE),
            re.compile(r'Kindle', re.IGNORECASE),
        ]

    def __call__(self, request):
        # Detect device type
        user_agent = request.META.get('HTTP_USER_AGENT', '')

        request.is_mobile = self.is_mobile(user_agent)
        request.is_tablet = self.is_tablet(user_agent)
        request.is_desktop = not (request.is_mobile or request.is_tablet)

        # Add device info to request
        request.device_info = {
            'type': self.get_device_type(request),
            'user_agent': user_agent,
            'screen_size': self.estimate_screen_size(request)
        }

        response = self.get_response(request)

        # Add device type header
        response['X-Device-Type'] = request.device_info['type']

        return response

    def is_mobile(self, user_agent):
        """Check if user agent indicates mobile device"""
        return any(pattern.search(user_agent) for pattern in self.mobile_patterns)

    def is_tablet(self, user_agent):
        """Check if user agent indicates tablet device"""
        return any(pattern.search(user_agent) for pattern in self.tablet_patterns)

    def get_device_type(self, request):
        """Return 'tablet', 'mobile' or 'desktop' (tablet checked first)."""
        if request.is_tablet:
            return 'tablet'
        elif request.is_mobile:
            return 'mobile'
        else:
            return 'desktop'

    def estimate_screen_size(self, request):
        """Estimate screen size category: 'medium', 'small' or 'large'.

        Tablets are checked before phones, matching ``get_device_type``.
        Tablet user agents such as iPad also match the mobile patterns, so
        checking ``is_mobile`` first would wrongly report them as 'small'.
        """
        if request.is_tablet:
            return 'medium'
        if request.is_mobile:
            return 'small'
        return 'large'

Advanced Middleware Patterns

Conditional Middleware

class ConditionalMiddleware:
    """Middleware that runs conditionally based on settings.

    Configuration comes from ``settings.CONDITIONAL_MIDDLEWARE_CONFIG``, a
    dict with the keys ``enabled``, ``debug``, ``allowed_ips`` and
    ``excluded_paths``. ``allowed_ips`` entries may be single addresses or
    CIDR networks such as ``'192.168.1.0/24'``.
    """

    def __init__(self, get_response):
        self.get_response = get_response

        # Load configuration
        from django.conf import settings
        self.config = getattr(settings, 'CONDITIONAL_MIDDLEWARE_CONFIG', {})

        self.enabled = self.config.get('enabled', True)
        self.debug_mode = self.config.get('debug', False)
        self.allowed_ips = self.config.get('allowed_ips', [])
        self.excluded_paths = self.config.get('excluded_paths', [])
        # Parse the whitelist once at startup instead of on every request.
        self._allowed_networks, self._allowed_literals = self._parse_allowed_ips(self.allowed_ips)

    def __call__(self, request):
        # Check if middleware should run
        if not self.should_process(request):
            return self.get_response(request)

        # Process request
        if self.debug_mode:
            print(f"ConditionalMiddleware processing: {request.path}")

        response = self.get_response(request)

        # Process response
        response['X-Processed-By'] = 'ConditionalMiddleware'

        return response

    def should_process(self, request):
        """Determine if middleware should process this request"""
        if not self.enabled:
            return False

        # Check IP whitelist (an empty whitelist allows everyone)
        if self.allowed_ips and not self._ip_allowed(self.get_client_ip(request)):
            return False

        # Check excluded paths
        if any(request.path.startswith(path) for path in self.excluded_paths):
            return False

        return True

    @staticmethod
    def _parse_allowed_ips(entries):
        """Split whitelist entries into parsed networks and raw strings.

        Returns:
            ``(networks, literals)``. *networks* holds an ``ipaddress``
            network for every entry that parses as an address or CIDR block.
            *literals* is the set of all raw entries, kept so that non-IP
            entries still match by exact string as before.
        """
        import ipaddress

        networks = []
        for entry in entries:
            try:
                # strict=False tolerates host bits (e.g. '10.0.0.5/24'); a bare
                # address becomes a single-host network.
                networks.append(ipaddress.ip_network(entry, strict=False))
            except (TypeError, ValueError):
                continue
        return networks, set(entries)

    def _ip_allowed(self, client_ip):
        """Return True if *client_ip* matches a whitelisted address or network."""
        import ipaddress

        if not client_ip:
            return False
        client_ip = client_ip.strip()
        if client_ip in self._allowed_literals:
            return True
        try:
            address = ipaddress.ip_address(client_ip)
        except ValueError:
            return False
        # Membership between different IP versions is simply False, so a
        # mixed IPv4/IPv6 whitelist is safe.
        return any(address in network for network in self._allowed_networks)

    def get_client_ip(self, request):
        """Return the client IP, preferring the first X-Forwarded-For entry.

        SECURITY NOTE: X-Forwarded-For is client-controlled; behind no trusted
        proxy, a client can forge it to pass the IP whitelist.
        """
        x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
        if x_forwarded_for:
            return x_forwarded_for.split(',')[0].strip()
        return request.META.get('REMOTE_ADDR')

Middleware with Database Integration

from .models import RequestLog, BlockedIP

class DatabaseIntegratedMiddleware:
    """Middleware that integrates with database models.

    Rejects requests from IPs with an active ``BlockedIP`` row (HTTP 403).
    Every other request is recorded in ``RequestLog``, and the row's status
    code and response size are filled in once the response exists.
    """

    def __init__(self, get_response):
        self.get_response = get_response

    def __call__(self, request):
        # Local import keeps this class self-contained; without it a blocked
        # IP would raise NameError instead of returning 403.
        from django.http import HttpResponseForbidden

        # Check if IP is blocked
        client_ip = self.get_client_ip(request)
        if self.is_ip_blocked(client_ip):
            return HttpResponseForbidden('IP address blocked')

        # Log request to database
        request_log = self.create_request_log(request)

        response = self.get_response(request)

        # Update request log with response info
        self.update_request_log(request_log, response)

        return response

    def is_ip_blocked(self, ip_address):
        """Return True if *ip_address* has an active ``BlockedIP`` entry.

        Fails open: on a database error the request is allowed, but the error
        is logged instead of being silently discarded.
        """
        if not ip_address:
            return False
        try:
            return BlockedIP.objects.filter(
                ip_address=ip_address,
                is_active=True
            ).exists()
        except Exception:
            logger.exception("Error checking blocked IP %s", ip_address)
            return False

    def create_request_log(self, request):
        """Create a request log entry; return it, or None if the insert failed."""
        try:
            user = getattr(request, 'user', None)
            return RequestLog.objects.create(
                ip_address=self.get_client_ip(request),
                path=request.path,
                method=request.method,
                user_agent=request.META.get('HTTP_USER_AGENT', ''),
                user=user if user is not None and user.is_authenticated else None,
                timestamp=timezone.now()
            )
        except Exception as e:
            logger.exception("Error creating request log: %s", e)
            return None

    def update_request_log(self, request_log, response):
        """Update request log with response information (no-op if *request_log* is None)."""
        if request_log:
            try:
                request_log.status_code = response.status_code
                # Responses without a `content` attribute (e.g. streaming) record 0.
                request_log.response_size = len(response.content) if hasattr(response, 'content') else 0
                request_log.save(update_fields=['status_code', 'response_size'])
            except Exception as e:
                logger.exception("Error updating request log: %s", e)

    def get_client_ip(self, request):
        """Return the client IP, preferring the first X-Forwarded-For entry.

        SECURITY NOTE: X-Forwarded-For is client-controlled; without a trusted
        proxy rewriting it, a blocked client can evade the block by forging it.
        """
        x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
        if x_forwarded_for:
            return x_forwarded_for.split(',')[0].strip()
        return request.META.get('REMOTE_ADDR')

# models.py
class RequestLog(models.Model):
    """One row per request handled by ``DatabaseIntegratedMiddleware``.

    The row is created before the view runs; ``status_code`` and
    ``response_size`` are filled in afterwards, which is why they are nullable.
    """
    # Not nullable: requests with no resolvable client IP cannot be logged.
    ip_address = models.GenericIPAddressField()
    # NOTE(review): paths over 255 characters may be rejected by the database.
    path = models.CharField(max_length=255)
    method = models.CharField(max_length=10)
    user_agent = models.TextField()
    # Null for anonymous requests; kept (set to NULL) if the user is deleted.
    user = models.ForeignKey(User, on_delete=models.SET_NULL, null=True, blank=True)
    timestamp = models.DateTimeField()
    status_code = models.IntegerField(null=True, blank=True)
    # Length of the response body in bytes (0 when there is no `content`).
    response_size = models.IntegerField(null=True, blank=True)

class BlockedIP(models.Model):
    """An IP address denied by ``DatabaseIntegratedMiddleware`` while active."""
    # At most one row per address.
    ip_address = models.GenericIPAddressField(unique=True)
    # Free-text explanation for operators.
    reason = models.CharField(max_length=255)
    # Set automatically when the row is created.
    blocked_at = models.DateTimeField(auto_now_add=True)
    # Only active rows block; deactivate instead of deleting to keep history.
    is_active = models.BooleanField(default=True)

Testing Custom Middleware

Unit Tests for Middleware

from django.test import TestCase, RequestFactory
from django.http import HttpResponse
from django.contrib.auth.models import User, AnonymousUser
from django.core.cache import cache
from myapp.middleware import (
    MobileDetectionMiddleware,
    RateLimitMiddleware,
    RequestLoggingMiddleware,
)

class CustomMiddlewareTests(TestCase):
    """Test custom middleware functionality"""

    def setUp(self):
        self.factory = RequestFactory()
        self.user = User.objects.create_user(
            username='testuser',
            email='test@example.com',
            password='testpass123'
        )
        # RateLimitMiddleware keeps its counters in the cache, which is not
        # rolled back between tests the way the database is.
        cache.clear()

    def test_request_logging_middleware(self):
        """Test request logging middleware"""
        def get_response(request):
            return HttpResponse("Test response")

        middleware = RequestLoggingMiddleware(get_response)
        request = self.factory.get('/test/')
        request.user = self.user

        with self.assertLogs('request_logger', level='INFO') as cm:
            response = middleware(request)

        self.assertEqual(response.status_code, 200)
        self.assertEqual(len(cm.output), 2)  # Request and response logs
        self.assertIn('REQUEST:', cm.output[0])
        self.assertIn('RESPONSE:', cm.output[1])

    def test_rate_limiting_middleware(self):
        """Requests beyond the configured limit get HTTP 429."""
        def get_response(request):
            return HttpResponse("Test response")

        middleware = RateLimitMiddleware(get_response)
        # Shrink the limit so the test only needs a handful of requests.
        middleware.rate_limits = {'/api/': {'requests': 2, 'window': 60}}

        for _ in range(2):
            response = middleware(self.factory.get('/api/test/'))
            self.assertEqual(response.status_code, 200)

        response = middleware(self.factory.get('/api/test/'))
        self.assertEqual(response.status_code, 429)
        self.assertEqual(response['Retry-After'], '60')

    def test_mobile_detection_middleware(self):
        """Test mobile detection middleware"""
        def get_response(request):
            return HttpResponse("Test response")

        middleware = MobileDetectionMiddleware(get_response)

        # Test mobile user agent
        request = self.factory.get('/test/')
        request.META['HTTP_USER_AGENT'] = 'Mozilla/5.0 (iPhone; CPU iPhone OS 14_0 like Mac OS X)'

        response = middleware(request)

        self.assertTrue(request.is_mobile)
        self.assertFalse(request.is_tablet)
        self.assertFalse(request.is_desktop)
        self.assertEqual(request.device_info['screen_size'], 'small')
        self.assertEqual(response['X-Device-Type'], 'mobile')

Integration Tests

from django.contrib.auth.models import User
from django.test import TestCase, Client
from django.urls import reverse
from myapp.models import UserActivity

class MiddlewareIntegrationTests(TestCase):
    """Integration tests for middleware.

    NOTE(review): these tests assume '/test/' and '/dashboard/' are routed URLs
    that return a status below 400. They also assume that every middleware
    whose header is asserted below is listed in settings.MIDDLEWARE. Confirm
    both against the project's URLconf and settings.
    """

    def setUp(self):
        self.client = Client()

    def test_middleware_chain_execution(self):
        """Test that middleware chain executes correctly"""
        response = self.client.get('/test/')

        # Check that all middleware added their headers
        self.assertIn('X-Processing-Time', response)
        self.assertIn('X-Request-ID', response)
        self.assertIn('X-Device-Type', response)

    def test_middleware_with_authentication(self):
        """Test middleware behavior with authenticated users"""
        user = User.objects.create_user(
            username='testuser',
            password='testpass123'
        )

        # Fail here if login does not work, rather than later with a
        # confusing "no activity recorded" assertion.
        self.assertTrue(self.client.login(username='testuser', password='testpass123'))
        response = self.client.get('/dashboard/')

        # UserActivityMiddleware only records responses with status < 400.
        self.assertLess(response.status_code, 400)
        # Check that user activity was tracked
        self.assertTrue(UserActivity.objects.filter(user=user).exists())

Middleware Configuration

Settings Configuration

# settings.py

# Add custom middleware to MIDDLEWARE list.
# Django runs the request phase in list order (top to bottom) and the
# response phase in reverse, so position decides what each middleware sees.
MIDDLEWARE = [
    'django.middleware.security.SecurityMiddleware',
    # NOTE: listed before AuthenticationMiddleware, so request.user is not set
    # yet when this middleware logs the incoming request; the REQUEST entry
    # will record the user as 'Anonymous'.
    'myapp.middleware.RequestLoggingMiddleware',
    # Keys its counters on client IP rather than the user, so it does not
    # need to run after authentication.
    'myapp.middleware.RateLimitMiddleware',
    'django.contrib.sessions.middleware.SessionMiddleware',
    'django.middleware.common.CommonMiddleware',
    'django.middleware.csrf.CsrfViewMiddleware',
    'django.contrib.auth.middleware.AuthenticationMiddleware',
    'myapp.middleware.UserActivityMiddleware',
    'django.contrib.messages.middleware.MessageMiddleware',
    'myapp.middleware.MobileDetectionMiddleware',
    'django.middleware.clickjacking.XFrameOptionsMiddleware',
]

# Custom middleware configuration, read by ConditionalMiddleware (which is
# not itself listed in MIDDLEWARE above).
CONDITIONAL_MIDDLEWARE_CONFIG = {
    'enabled': True,
    'debug': False,
    # NOTE(review): the CIDR entry only takes effect if ConditionalMiddleware
    # matches client IPs against networks; an exact string comparison never
    # matches it. Confirm the middleware version in use.
    'allowed_ips': ['127.0.0.1', '192.168.1.0/24'],
    'excluded_paths': ['/static/', '/media/', '/health/']
}

# Logging configuration for middleware
LOGGING = {
    'version': 1,
    'disable_existing_loggers': False,
    'handlers': {
        'request_file': {
            'level': 'INFO',
            # Rotates at maxBytes, keeping backupCount older files.
            'class': 'logging.handlers.RotatingFileHandler',
            # Relative to the process working directory; the logs/ directory
            # must already exist or logging configuration fails at startup.
            'filename': 'logs/requests.log',
            'maxBytes': 10 * 1024 * 1024,  # 10MB
            'backupCount': 5,
        },
    },
    'loggers': {
        # The logger RequestLoggingMiddleware writes to.
        'request_logger': {
            'handlers': ['request_file'],
            'level': 'INFO',
            # Keep request logs out of ancestor (root) handlers.
            'propagate': False,
        },
    },
}

Best Practices for Custom Middleware

Performance Considerations

  • Keep middleware lightweight and fast
  • Avoid database queries in middleware when possible
  • Use caching for expensive operations
  • Profile middleware performance impact

Error Handling

  • Always handle exceptions gracefully
  • Don't let middleware errors break requests
  • Log errors appropriately
  • Provide fallback behavior

Security

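  • Validate all inputs and outputs; in particular, treat client-supplied headers such as X-Forwarded-For as untrusted unless your own reverse proxy sets them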
  • Be careful with request/response modification
  • Implement proper authentication checks
  • Log security-relevant events

Maintainability

  • Keep middleware focused and single-purpose
  • Make middleware configurable
  • Write comprehensive tests
  • Document middleware behavior

Next Steps

Now that you know how to create custom middleware, let's explore middleware ordering and understand how the sequence of middleware affects your application's behavior.