FastAPI Middleware Patterns
Request/response processing, CORS, security, and error handling.
Basic Middleware
from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware
import time
app = FastAPI()


# Function-based middleware
@app.middleware("http")
async def add_timing_header(request: Request, call_next):
    """Stamp each response with its handling time in seconds (X-Process-Time).

    The time is measured from entry until ``call_next`` returns and is
    formatted with four decimal places.
    """
    began = time.perf_counter()
    response = await call_next(request)
    elapsed = time.perf_counter() - began
    response.headers["X-Process-Time"] = "{:.4f}".format(elapsed)
    return response
# Class-based middleware
class TimingMiddleware(BaseHTTPMiddleware):
    """Class-based timing middleware: sets X-Process-Time (seconds, 4 d.p.).

    NOTE: if the function-based ``add_timing_header`` is installed as well,
    both write the same header; this class is registered later (outer
    layer), so its value is the one that reaches the client.
    """

    async def dispatch(self, request: Request, call_next):
        t0 = time.perf_counter()
        response = await call_next(request)
        # The right-hand side is evaluated before the header assignment.
        response.headers["X-Process-Time"] = f"{time.perf_counter() - t0:.4f}"
        return response


app.add_middleware(TimingMiddleware)
CORS Configuration
from fastapi.middleware.cors import CORSMiddleware
# Origins allowed to make (credentialed) cross-origin requests.
ALLOWED_ORIGINS = [
    "http://localhost:3000",
    "https://myapp.com",
]

# Development: additionally accept any localhost / 127.0.0.1 port.
# Do NOT use allow_origins=["*"] here: a wildcard origin must not be combined
# with allow_credentials=True (browsers reject "Access-Control-Allow-Origin: *"
# on credentialed requests, and FastAPI's docs warn against it). A regex keeps
# cookies working for local dev servers while still refusing other sites.
dev_origin_regex = (
    r"https?://(localhost|127\.0\.0\.1)(:\d+)?$" if settings.debug else None
)

# Register CORSMiddleware exactly ONCE. Adding a second instance for debug
# would wrap this one; the outer instance would then answer every preflight
# and silently override the origins, expose_headers and max_age set here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_origin_regex=dev_origin_regex,
    allow_credentials=True,
    allow_methods=["*"],  # Or specific: ["GET", "POST"]
    allow_headers=["*"],
    expose_headers=["X-Request-ID"],
    max_age=600,  # Cache preflight for 10 minutes
)
Security Headers
from starlette.middleware.base import BaseHTTPMiddleware
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Attach a baseline set of security-related headers to every response."""

    async def dispatch(self, request: Request, call_next):
        response = await call_next(request)

        # Security headers
        response.headers["X-Content-Type-Options"] = "nosniff"
        response.headers["X-Frame-Options"] = "DENY"
        # "0" disables the legacy browser XSS auditor. Current guidance (OWASP,
        # MDN) is to turn it off: modern browsers removed the auditor, and
        # "1; mode=block" could itself be abused to leak information from a
        # page. Rely on the Content-Security-Policy below instead.
        response.headers["X-XSS-Protection"] = "0"
        response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"

        # CSP - customize for your app. NOTE: 'unsafe-inline' in script-src
        # permits inline <script> blocks, which weakens XSS protection; prefer
        # nonces or hashes when the app allows it.
        response.headers["Content-Security-Policy"] = (
            "default-src 'self'; "
            "script-src 'self' 'unsafe-inline'; "
            "style-src 'self' 'unsafe-inline'"
        )

        # HSTS is ignored by browsers when received over plain HTTP, so send it
        # only on HTTPS requests. Behind a TLS-terminating proxy the scheme seen
        # here is "http" unless the server is configured to trust the proxy's
        # forwarded-protocol header.
        if request.url.scheme == "https":
            response.headers["Strict-Transport-Security"] = (
                "max-age=31536000; includeSubDomains"
            )
        return response


app.add_middleware(SecurityHeadersMiddleware)
Request ID Tracking
from uuid import uuid4
from contextvars import ContextVar
import re

request_id_ctx: ContextVar[str] = ContextVar("request_id", default="")

# Client-supplied IDs are accepted only if short and made of safe characters.
# Otherwise a caller could inject newlines/control characters into log lines
# or send an arbitrarily large value that is echoed back in the response.
_REQUEST_ID_RE = re.compile(r"[A-Za-z0-9._-]{1,128}")


class RequestIDMiddleware(BaseHTTPMiddleware):
    """Assign every request an ID and echo it in the X-Request-ID header.

    A well-formed incoming ``X-Request-ID`` is reused (so IDs can be
    correlated across services). A missing or malformed one is replaced
    with a fresh UUID4. The ID is stored in ``request_id_ctx`` and on
    ``request.state.request_id``.
    """

    async def dispatch(self, request: Request, call_next):
        incoming = request.headers.get("X-Request-ID", "")
        if _REQUEST_ID_RE.fullmatch(incoming):
            request_id = incoming
        else:
            request_id = str(uuid4())

        # Store in context for logging; restore the previous value afterwards
        # so the variable does not leak past this request's handling.
        token = request_id_ctx.set(request_id)
        request.state.request_id = request_id
        try:
            response = await call_next(request)
        finally:
            request_id_ctx.reset(token)

        response.headers["X-Request-ID"] = request_id
        return response


app.add_middleware(RequestIDMiddleware)
# Access in endpoints
@app.get("/trace")
async def trace(request: Request):
    """Return the request ID that RequestIDMiddleware placed on request.state."""
    rid = request.state.request_id
    return {"request_id": rid}
Logging Middleware
import logging
import time
from starlette.middleware.base import BaseHTTPMiddleware
logger = logging.getLogger(__name__)


class LoggingMiddleware(BaseHTTPMiddleware):
    """Log when a request starts and when it completes or fails.

    ``duration`` is measured until ``call_next`` returns, i.e. until the
    response starts. It does not include the time spent streaming the body.
    """

    async def dispatch(self, request: Request, call_next):
        start = time.perf_counter()
        method = request.method
        path = request.url.path

        # Log request
        logger.info(
            "Request started",
            extra={
                "method": method,
                "path": path,
                "client": request.client.host if request.client else None,
            },
        )

        try:
            response = await call_next(request)
        except Exception:
            # Without this branch an exception escaping the inner layers skips
            # the completion log entirely. The timing would then be lost for
            # exactly the requests that most need investigating.
            logger.exception(
                "Request failed",
                extra={
                    "method": method,
                    "path": path,
                    "duration": f"{time.perf_counter() - start:.3f}s",
                },
            )
            raise

        # Log response
        duration = time.perf_counter() - start
        logger.info(
            "Request completed",
            extra={
                "method": method,
                "path": path,
                "status": response.status_code,
                "duration": f"{duration:.3f}s",
            },
        )
        return response


app.add_middleware(LoggingMiddleware)
Error Handling Middleware
from fastapi import Request
from fastapi.responses import JSONResponse
import traceback
class ErrorHandlingMiddleware(BaseHTTPMiddleware):
    """Convert unhandled exceptions from inner layers into a generic JSON 500.

    The response body hides exception details and includes the request ID
    when an inner-running middleware has set ``request.state.request_id``
    (otherwise ``None``).
    """

    async def dispatch(self, request: Request, call_next):
        try:
            return await call_next(request)
        except Exception:
            # logger.exception() already attaches the active exception and its
            # traceback via exc_info. Copying traceback.format_exc() into
            # `extra` as well would duplicate it in every record.
            logger.exception(
                "Unhandled exception",
                extra={
                    "path": request.url.path,
                    "method": request.method,
                },
            )
            # Return generic error (hide details in production)
            return JSONResponse(
                status_code=500,
                content={
                    "detail": "Internal server error",
                    "request_id": getattr(request.state, "request_id", None),
                },
            )


app.add_middleware(ErrorHandlingMiddleware)
Rate Limiting
from collections import defaultdict
from datetime import datetime, timedelta
import asyncio
import math
import time
from collections import deque


class RateLimitMiddleware(BaseHTTPMiddleware):
    """Per-client-IP sliding-window rate limiter held in process memory.

    Counters live on this middleware instance, so each worker process
    enforces its own limit; with several workers the effective limit is
    multiplied. Use a shared store if limits must hold across processes.

    Args:
        app: The ASGI application to wrap.
        requests: Maximum number of requests a client may make per window.
        window: Sliding-window length in seconds.
    """

    def __init__(self, app, requests: int = 100, window: int = 60):
        super().__init__(app)
        self.requests = requests
        self.window = window
        # client IP -> monotonic timestamps of accepted requests, oldest first.
        self.clients: dict[str, deque[float]] = {}
        self.lock = asyncio.Lock()
        # Monotonic time at which idle clients are next purged from `clients`.
        self._next_sweep = 0.0

    def _sweep(self, window_start: float) -> None:
        """Drop clients that have no request inside the current window.

        Without this purge every IP ever seen keeps an entry forever and
        ``clients`` grows without bound.
        """
        idle = [
            ip
            for ip, hits in self.clients.items()
            if not hits or hits[-1] <= window_start
        ]
        for ip in idle:
            del self.clients[ip]

    async def dispatch(self, request: Request, call_next):
        # NOTE: behind a reverse proxy this is the proxy's address unless the
        # server is configured to trust forwarded-for headers.
        client_ip = request.client.host if request.client else "unknown"
        # Monotonic clock: wall-clock jumps (NTP corrections, manual changes)
        # cannot shrink or stretch the window.
        now = time.monotonic()
        window_start = now - self.window

        async with self.lock:
            # Amortized cleanup: at most one full scan per window.
            if now >= self._next_sweep:
                self._sweep(window_start)
                self._next_sweep = now + self.window

            hits = self.clients.setdefault(client_ip, deque())
            # Timestamps are appended in order, so expired ones sit at the left.
            while hits and hits[0] <= window_start:
                hits.popleft()

            if len(hits) >= self.requests:
                # Seconds until the oldest counted request leaves the window
                # (falls back to the full window when nothing is counted, e.g.
                # requests=0).
                if hits:
                    retry_after = max(1, math.ceil(hits[0] + self.window - now))
                else:
                    retry_after = self.window
                return JSONResponse(
                    status_code=429,
                    content={"detail": "Rate limit exceeded"},
                    headers={
                        "Retry-After": str(retry_after),
                        "X-RateLimit-Limit": str(self.requests),
                        "X-RateLimit-Remaining": "0",
                    },
                )

            hits.append(now)
            remaining = self.requests - len(hits)

        response = await call_next(request)
        response.headers["X-RateLimit-Limit"] = str(self.requests)
        response.headers["X-RateLimit-Remaining"] = str(remaining)
        return response


app.add_middleware(RateLimitMiddleware, requests=100, window=60)
GZip Compression
from fastapi.middleware.gzip import GZipMiddleware
# Non-streamed response bodies shorter than minimum_size bytes
# (1000 bytes, not 1 KiB) are sent uncompressed.
app.add_middleware(GZipMiddleware, minimum_size=1000)
Trusted Host Validation
from fastapi.middleware.trustedhost import TrustedHostMiddleware
# Requests whose Host header matches none of these patterns are rejected.
# "*.example.com" covers subdomains only, so the apex domain is listed too.
app.add_middleware(TrustedHostMiddleware, allowed_hosts=["example.com", "*.example.com"])
Middleware Order
# Middleware executes in REVERSE order of addition:
# the LAST one added is the OUTERMOST layer (first to see the request, last to
# touch the response); the FIRST one added sits closest to the app.
app = FastAPI()

# 1. Error handling - added first, so it is the INNERMOST layer. It turns
#    exceptions raised by the app into a 500 response that every outer layer
#    still sees: logging records the status, and request-ID, security and CORS
#    headers are still attached to the error response.
app.add_middleware(ErrorHandlingMiddleware)

# 2. Logging - wraps error handling, so failed requests are logged with their
#    500 status rather than escaping as an exception.
app.add_middleware(LoggingMiddleware)

# 3. Request ID - wraps logging, so request.state.request_id is already set
#    while the inner layers and the endpoint run.
app.add_middleware(RequestIDMiddleware)

# 4. Security headers - applied to every response produced further in,
#    including error responses.
app.add_middleware(SecurityHeadersMiddleware)

# 5. CORS - near the outside so preflight (OPTIONS) requests are answered
#    before they reach any inner layer. Pass real arguments here: a bare
#    `...` placeholder would be forwarded to CORSMiddleware as allow_origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://myapp.com"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# 6. GZip - outermost, so it compresses the final response body after all
#    other layers are done with it.
app.add_middleware(GZipMiddleware, minimum_size=1000)

# Request flow:  GZip → CORS → Security → RequestID → Logging → Error → App
# Response flow: App → Error → Logging → RequestID → Security → CORS → GZip
Quick Reference
| Middleware | Purpose |
|---|---|
| `CORSMiddleware` | Cross-origin requests |
| `GZipMiddleware` | Response compression |
| `TrustedHostMiddleware` | Host validation |
| `BaseHTTPMiddleware` | Custom middleware base |
| `@app.middleware("http")` | Simple function middleware |
| Order Position (order of `add_middleware` calls) | Middleware Type |
|---|---|
| First added (innermost) | Error handling |
| Early | Logging, tracing |
| Middle | Auth, rate limiting |
| Last added (outermost) | CORS, compression |