claude-code-openai-wrapper/rate_limiter.py
2025-08-11 20:22:47 +05:30

89 lines
No EOL
2.9 KiB
Python

import os
from typing import Optional
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
from fastapi import Request, HTTPException
from fastapi.responses import JSONResponse
def get_rate_limit_key(request: Request) -> str:
"""Get the rate limiting key (IP address) from the request."""
return get_remote_address(request)
def create_rate_limiter() -> Optional[Limiter]:
"""Create and configure the rate limiter based on environment variables."""
rate_limit_enabled = os.getenv('RATE_LIMIT_ENABLED', 'true').lower() in ('true', '1', 'yes', 'on')
if not rate_limit_enabled:
return None
# Create limiter with IP-based identification
limiter = Limiter(
key_func=get_rate_limit_key,
default_limits=[] # We'll apply limits per endpoint
)
return limiter
def rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
"""Custom rate limit exceeded handler that returns JSON error response."""
# Calculate retry after based on rate limit window (default 60 seconds)
retry_after = 60
response = JSONResponse(
status_code=429,
content={
"error": {
"message": f"Rate limit exceeded. Try again in {retry_after} seconds.",
"type": "rate_limit_exceeded",
"code": "too_many_requests",
"retry_after": retry_after
}
},
headers={"Retry-After": str(retry_after)}
)
return response
def get_rate_limit_for_endpoint(endpoint: str) -> str:
"""Get rate limit string for specific endpoint based on environment variables."""
# Default rate limits
defaults = {
"chat": "10/minute",
"debug": "2/minute",
"auth": "10/minute",
"session": "15/minute",
"health": "30/minute",
"general": "30/minute"
}
# Environment variable mappings
env_mappings = {
"chat": "RATE_LIMIT_CHAT_PER_MINUTE",
"debug": "RATE_LIMIT_DEBUG_PER_MINUTE",
"auth": "RATE_LIMIT_AUTH_PER_MINUTE",
"session": "RATE_LIMIT_SESSION_PER_MINUTE",
"health": "RATE_LIMIT_HEALTH_PER_MINUTE",
"general": "RATE_LIMIT_PER_MINUTE"
}
# Get rate limit from environment or use default
env_var = env_mappings.get(endpoint, "RATE_LIMIT_PER_MINUTE")
rate_per_minute = int(os.getenv(env_var, defaults.get(endpoint, "30").split("/")[0]))
return f"{rate_per_minute}/minute"
def rate_limit_endpoint(endpoint: str):
"""Decorator factory for applying rate limits to endpoints."""
def decorator(func):
if limiter:
return limiter.limit(get_rate_limit_for_endpoint(endpoint))(func)
return func
return decorator
# Create the global limiter instance
limiter = create_rate_limiter()