from starlette.requests import Request from starlette.datastructures import Headers from urllib.parse import urlparse, urlunparse import copy from typing import Optional, Dict, Any import threading class ModifiableRequest: """可修改的请求对象,由 Request 对象修改而来""" # 类级别的 round-robin 计数器 _counter = 0 _counter_lock = threading.Lock() def __init__(self, hosts: list[str], request: Request): """ 初始化可修改请求对象 Args: hosts: 主机列表,用于负载均衡 request: 原始 Starlette Request 对象 """ self._original_request = request # 复制原始请求的属性 self._method = request.method self._hosts = hosts self._url = request.url self._headers = dict(request.headers) self._body: Optional[bytes] = None # 延迟加载 self._query_params = dict(request.query_params) self._path_params = dict(request.path_params) # 标记是否已修改 self._modified = False def _get_next_host(self) -> str: """ 使用 round-robin 算法获取下一个主机 Returns: 下一个主机的 URL """ if not self._hosts: raise ValueError("No hosts available for load balancing") with self._counter_lock: host = self._hosts[ModifiableRequest._counter % len(self._hosts)] ModifiableRequest._counter += 1 return host def _build_url_with_host(self, path: str) -> str: """ 使用选定的主机构建完整的 URL Args: path: 请求路径 Returns: 完整的 URL """ host = self._get_next_host() # 确保 host 有协议前缀 if not host.startswith(("http://", "https://")): host = "http://" + host # 确保路径以 / 开头 if not path.startswith("/"): path = "/" + path return host + path def rewrite_uri(self, uri: str): """ 重写请求 URI Args: uri: 新的 URI """ self._url = uri self._modified = True return self def rewrite_method(self, method: str): """ 重写请求方法 Args: method: 新的 HTTP 方法 """ self._method = method.upper() self._modified = True return self def rewrite_headers(self, headers: Dict[str, str]): """ 重写请求头 Args: headers: 新的请求头字典 """ self._headers.update(headers) self._modified = True return self def rewrite_body(self, body: bytes): """ 重写请求体 Args: body: 新的请求体 """ self._body = body self._modified = True return self def add_header(self, key: str, value: str): """ 添加单个请求头 Args: key: 请求头名称 value: 请求头值 """ self._headers[key] = value self._modified = True return self def remove_header(self, key: str): """ 移除请求头 Args: key: 要移除的请求头名称 """ if key in self._headers: del self._headers[key] self._modified = True return self def set_query_param(self, key: str, value: str): """ 设置查询参数 Args: key: 参数名 value: 参数值 """ self._query_params[key] = value self._modified = True return self def remove_query_param(self, key: str): """ 移除查询参数 Args: key: 要移除的参数名 """ if key in self._query_params: del self._query_params[key] self._modified = True return self async def _get_body(self) -> bytes: """获取请求体,如果未修改则从原始请求获取""" if self._body is not None: return self._body return await self._original_request.body() def build(self) -> Request: """ 构建新的 Request 对象 Returns: 修改后的 Request 对象 """ if not self._modified: # 如果没有修改,直接返回原始请求 return self._original_request # 构建新的 URL parsed_url = urlparse(str(self._url)) # 如果 URL 中没有主机(只有路径),则从 hosts 中选择一个进行负载均衡 if not parsed_url.netloc: # 使用 round-robin 选择主机 new_url = self._build_url_with_host(parsed_url.path) else: # URL 已经包含主机,直接使用 new_url = str(self._url) # 重新解析 URL 以处理查询参数 parsed_url = urlparse(new_url) # 添加查询参数 query_params = [] for key, value in self._query_params.items(): query_params.append(f"{key}={value}") if query_params: query_string = "&".join(query_params) # 如果 URL 已经有查询参数,则合并 if parsed_url.query: query_string = f"{parsed_url.query}&{query_string}" parsed_url = parsed_url._replace(query=query_string) final_url = urlunparse(parsed_url) # 创建新的 Request 对象 # 注意:这里我们需要创建一个模拟的 Request 对象 # 因为 Starlette 的 Request 对象是不可变的 # 我们返回一个包含修改后数据的字典,供代理处理器使用 # 为了保持兼容性,我们创建一个包含所有必要信息的字典 # 然后在代理处理器中使用这些信息 request_info = { "method": self._method, "url": final_url, "headers": Headers(self._headers), "query_params": self._query_params, "path_params": self._path_params, "body": self._body, "original_request": self._original_request, } # 返回一个包装对象,模拟 Request 接口 return RequestWrapper(request_info) # type: ignore class RequestWrapper: """Request 包装器,提供修改后的请求信息""" def __init__(self, request_info: Dict[str, Any]): self._info = request_info self._original_request = request_info["original_request"] @property def method(self) -> str: return self._info["method"] @property def url(self): return self._info["url"] @property def headers(self): return self._info["headers"] @property def query_params(self): return self._info["query_params"] @property def path_params(self): return self._info["path_params"] async def body(self) -> bytes: if self._info["body"] is not None: return self._info["body"] return await self._original_request.body() # 代理其他必要的属性到原始请求 def __getattr__(self, name): return getattr(self._original_request, name)