middleware-patterns.md 8.8 KB

FastAPI Middleware Patterns

Request/response processing, CORS, security, and error handling.

Basic Middleware

from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware
import time

app = FastAPI()

# Function-based middleware
@app.middleware("http")
async def add_timing_header(request: Request, call_next):
    """Report how long downstream handling took, in seconds.

    The elapsed time (measured with ``time.perf_counter``) is written to the
    ``X-Process-Time`` response header with four decimal places.
    """
    t0 = time.perf_counter()
    response = await call_next(request)
    response.headers["X-Process-Time"] = f"{time.perf_counter() - t0:.4f}"
    return response


# Class-based middleware
class TimingMiddleware(BaseHTTPMiddleware):
    """Class-based timing middleware: stamps ``X-Process-Time`` in seconds."""

    async def dispatch(self, request: Request, call_next):
        began_at = time.perf_counter()
        response = await call_next(request)
        elapsed = time.perf_counter() - began_at
        response.headers["X-Process-Time"] = f"{elapsed:.4f}"
        return response

# In a real app register only one timing middleware (function- or
# class-based); both write the same header, so the outer one wins.
app.add_middleware(TimingMiddleware)

CORS Configuration

from fastapi.middleware.cors import CORSMiddleware

# NOTE: assumes a `settings` object with a boolean `debug` attribute (your
# app's config); it is not defined in this snippet.
#
# Register CORSMiddleware exactly once. Adding a second, debug-only
# CORSMiddleware on top of the first stacks two CORS layers: the outer one
# answers preflight requests by itself, and both rewrite Access-Control-*
# headers on ordinary responses, so the effective policy becomes unclear.
app.add_middleware(
    CORSMiddleware,
    # Development: allow all origins. With allow_credentials=True, Starlette
    # answers a "*" policy by echoing the caller's Origin, so any site can
    # make credentialed requests. Never enable this outside local development.
    allow_origins=(
        ["*"]
        if settings.debug
        else [
            "http://localhost:3000",
            "https://myapp.com",
        ]
    ),
    allow_credentials=True,
    allow_methods=["*"],  # Or specific: ["GET", "POST"]
    allow_headers=["*"],
    expose_headers=["X-Request-ID"],
    max_age=600,  # Cache preflight for 10 minutes
)

Security Headers

from starlette.middleware.base import BaseHTTPMiddleware

class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Add a baseline set of security headers to every response.

    Values already set on the response by a route are overwritten.
    """

    async def dispatch(self, request: Request, call_next):
        response = await call_next(request)

        # Security headers
        response.headers["X-Content-Type-Options"] = "nosniff"
        response.headers["X-Frame-Options"] = "DENY"
        # "0" switches off the legacy browser XSS auditor. "1; mode=block" is
        # deprecated, and the auditor itself could be abused to introduce XSS;
        # rely on Content-Security-Policy instead.
        response.headers["X-XSS-Protection"] = "0"
        response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"

        # CSP - customize for your app. 'unsafe-inline' in script-src allows
        # inline <script> and largely cancels CSP's XSS protection; switch to
        # nonces or hashes once inline scripts are removed.
        response.headers["Content-Security-Policy"] = (
            "default-src 'self'; "
            "script-src 'self' 'unsafe-inline'; "
            "style-src 'self' 'unsafe-inline'"
        )

        # HSTS only on HTTPS responses (browsers ignore it over plain HTTP).
        # Behind a TLS-terminating proxy the scheme seen here is "http" unless
        # the server trusts forwarded headers (e.g. uvicorn --proxy-headers).
        if request.url.scheme == "https":
            response.headers["Strict-Transport-Security"] = (
                "max-age=31536000; includeSubDomains"
            )

        return response

app.add_middleware(SecurityHeadersMiddleware)

Request ID Tracking

from uuid import uuid4
from contextvars import ContextVar

request_id_ctx: ContextVar[str] = ContextVar("request_id", default="")

class RequestIDMiddleware(BaseHTTPMiddleware):
    """Give each request an ID and echo it back as ``X-Request-ID``.

    A client-supplied ``X-Request-ID`` is reused only if it is short and made
    of safe characters; otherwise a fresh UUID4 is generated. The ID is put
    on ``request.state.request_id``, held in ``request_id_ctx`` while the rest
    of the stack runs, and returned in the ``X-Request-ID`` response header.
    """

    # The incoming header is untrusted and is copied into logs and response
    # headers, so only a bounded, conservative character set is accepted.
    _MAX_ID_LENGTH = 128
    _ALLOWED_ID_CHARS = frozenset(
        "abcdefghijklmnopqrstuvwxyz"
        "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
        "0123456789-_."
    )

    @classmethod
    def _client_request_id(cls, value):
        """Return *value* if it is an acceptable client-supplied ID, else None."""
        if not value or len(value) > cls._MAX_ID_LENGTH:
            return None
        if not set(value) <= cls._ALLOWED_ID_CHARS:
            return None
        return value

    async def dispatch(self, request: Request, call_next):
        # Use existing (if well-formed) or generate new
        request_id = (
            self._client_request_id(request.headers.get("X-Request-ID"))
            or str(uuid4())
        )

        # Store in context for logging; restore the previous value afterwards
        # so this request's ID does not linger in the current context.
        token = request_id_ctx.set(request_id)
        request.state.request_id = request_id
        try:
            response = await call_next(request)
        finally:
            request_id_ctx.reset(token)

        response.headers["X-Request-ID"] = request_id
        return response

app.add_middleware(RequestIDMiddleware)


# Access in endpoints
@app.get("/trace")
async def trace(request: Request):
    """Return the request ID that RequestIDMiddleware stored on request.state."""
    rid = request.state.request_id
    return {"request_id": rid}

Logging Middleware

import logging
import time
from starlette.middleware.base import BaseHTTPMiddleware

logger = logging.getLogger(__name__)

class LoggingMiddleware(BaseHTTPMiddleware):
    """Log a record when each request starts and another when it ends.

    Normal completion is logged as "Request completed" with status and
    duration. If downstream raises, a "Request failed" record (with
    traceback) is logged instead and the exception is re-raised unchanged.
    """

    async def dispatch(self, request: Request, call_next):
        start = time.perf_counter()

        # Log request
        logger.info(
            "Request started",
            extra={
                "method": request.method,
                "path": request.url.path,
                "client": request.client.host if request.client else None,
            },
        )

        try:
            response = await call_next(request)
        except Exception:
            # Without this, a crash downstream leaves a "started" record with
            # no matching end record and no duration.
            duration = time.perf_counter() - start
            logger.exception(
                "Request failed",
                extra={
                    "method": request.method,
                    "path": request.url.path,
                    "duration": f"{duration:.3f}s",
                },
            )
            raise

        # Log response
        duration = time.perf_counter() - start
        logger.info(
            "Request completed",
            extra={
                "method": request.method,
                "path": request.url.path,
                "status": response.status_code,
                "duration": f"{duration:.3f}s",
            },
        )

        return response

app.add_middleware(LoggingMiddleware)

Error Handling Middleware

from fastapi import Request
from fastapi.responses import JSONResponse
import traceback

class ErrorHandlingMiddleware(BaseHTTPMiddleware):
    """Turn any unhandled exception from downstream into a generic JSON 500.

    The exception is logged with its traceback. The client only receives a
    fixed message plus the request ID (set on ``request.state`` by
    RequestIDMiddleware, if that ran first), so internals are not exposed.
    """

    async def dispatch(self, request: Request, call_next):
        try:
            return await call_next(request)
        except Exception:
            request_id = getattr(request.state, "request_id", None)

            # logger.exception already attaches the active traceback via
            # exc_info, so it is not duplicated into `extra` as a string.
            logger.exception(
                "Unhandled exception",
                extra={
                    "path": request.url.path,
                    "method": request.method,
                    "request_id": request_id,
                },
            )

            # Return generic error (hide details in production)
            return JSONResponse(
                status_code=500,
                content={
                    "detail": "Internal server error",
                    "request_id": request_id,
                },
            )

app.add_middleware(ErrorHandlingMiddleware)

Rate Limiting

from collections import defaultdict
from datetime import datetime, timedelta
import asyncio

class RateLimitMiddleware(BaseHTTPMiddleware):
    """In-memory sliding-window rate limiter keyed by client IP.

    Each client may make at most ``requests`` requests in any rolling
    ``window``-second span; further requests receive a 429 with
    ``Retry-After``. State lives on this instance, so the limit applies per
    worker process and is not shared between processes or hosts.

    Args:
        app: The downstream ASGI app.
        requests: Maximum requests allowed per client within the window.
        window: Window length in seconds.
    """

    def __init__(self, app, requests: int = 100, window: int = 60):
        super().__init__(app)
        self.requests = requests
        self.window = window
        # client IP -> time.monotonic() timestamps of its requests, oldest
        # first. A monotonic clock is immune to wall-clock jumps (NTP sync,
        # manual changes) that would corrupt a datetime.now()-based window.
        self.clients: dict[str, list[float]] = defaultdict(list)
        self.lock = asyncio.Lock()
        self._last_sweep = time.monotonic()

    def _evict_idle_clients(self, now: float) -> None:
        """Forget clients with no request inside the window (caller holds the lock).

        Without this, every IP ever seen keeps a dict entry forever. The sweep
        runs at most once per window, so its O(clients) cost is amortized.
        """
        if now - self._last_sweep < self.window:
            return
        cutoff = now - self.window
        idle = [
            ip for ip, hits in self.clients.items()
            if not hits or hits[-1] <= cutoff
        ]
        for ip in idle:
            del self.clients[ip]
        self._last_sweep = now

    async def dispatch(self, request: Request, call_next):
        # NOTE: behind a reverse proxy this is the proxy's address unless the
        # server trusts forwarded headers; all address-less clients share the
        # "unknown" bucket.
        client_ip = request.client.host if request.client else "unknown"
        now = time.monotonic()
        window_start = now - self.window

        async with self.lock:
            self._evict_idle_clients(now)

            # Remove old requests
            hits = [t for t in self.clients[client_ip] if t > window_start]
            self.clients[client_ip] = hits

            if len(hits) >= self.requests:
                # Seconds until the oldest counted request leaves the window,
                # rounded up so a client that obeys it is not rejected again.
                if hits:
                    retry_after = max(1, int(hits[0] + self.window - now) + 1)
                else:
                    retry_after = self.window
                return JSONResponse(
                    status_code=429,
                    content={"detail": "Rate limit exceeded"},
                    headers={
                        "Retry-After": str(retry_after),
                        "X-RateLimit-Limit": str(self.requests),
                        "X-RateLimit-Remaining": "0",
                    },
                )

            hits.append(now)
            remaining = self.requests - len(hits)

        response = await call_next(request)
        response.headers["X-RateLimit-Limit"] = str(self.requests)
        response.headers["X-RateLimit-Remaining"] = str(remaining)
        return response

app.add_middleware(RateLimitMiddleware, requests=100, window=60)

GZip Compression

from fastapi.middleware.gzip import GZipMiddleware

app.add_middleware(
    GZipMiddleware,
    # Bodies shorter than 1000 bytes are sent uncompressed. Only clients that
    # send "Accept-Encoding: gzip" receive compressed responses at all.
    minimum_size=1000,
)

Trusted Host Validation

from fastapi.middleware.trustedhost import TrustedHostMiddleware

app.add_middleware(
    TrustedHostMiddleware,
    # A "*." pattern matches subdomains only, so the bare domain is listed
    # too; any other Host header is rejected with a 400.
    allowed_hosts=[
        "example.com",
        "*.example.com",
    ],
)

Middleware Order

# Middleware executes in REVERSE order of addition
# Last added = First to process request, Last to process response
# (Starlette's built-in ServerErrorMiddleware always sits outside all of these.)

app = FastAPI()

# 1. Error handling (INNERMOST - added first, so it wraps the app directly
#    and turns unhandled route exceptions into a JSON 500 that every
#    middleware added after it still sees and decorates)
app.add_middleware(ErrorHandlingMiddleware)

# 2. Logging (wraps error handling, so failed requests are logged as 500s)
app.add_middleware(LoggingMiddleware)

# 3. Request ID (wraps logging, so the ID is already set when logging runs)
app.add_middleware(RequestIDMiddleware)

# 4. Security (headers land on every response, including error 500s)
app.add_middleware(SecurityHeadersMiddleware)

# 5. CORS (near the outside so preflight OPTIONS requests are answered early)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://myapp.com"],  # full options: see CORS Configuration
)

# 6. GZip (OUTERMOST - added last, compresses the final response)
app.add_middleware(GZipMiddleware, minimum_size=1000)

# Request flow: GZip → CORS → Security → RequestID → Logging → Error → App
# Response flow: App → Error → Logging → RequestID → Security → CORS → GZip

Quick Reference

Middleware Purpose
CORSMiddleware Cross-origin requests
GZipMiddleware Response compression
TrustedHostMiddleware Host validation
BaseHTTPMiddleware Custom middleware base
@app.middleware("http") Simple function middleware
Order Position (by add_middleware call) Middleware Type
First added (innermost, next to the app) Error handling
Early Logging, tracing
Middle Auth, rate limiting
Last added (outermost) CORS, compression