Files
starlette-dial/pkg/dial/req.py
2025-06-28 19:03:29 +08:00

268 lines
7.2 KiB
Python

from starlette.requests import Request
from starlette.datastructures import Headers
from urllib.parse import urlparse, urlunparse
import copy
from typing import Optional, Dict, Any
import threading
class ModifiableRequest:
"""可修改的请求对象,由 Request 对象修改而来"""
# 类级别的 round-robin 计数器
_counter = 0
_counter_lock = threading.Lock()
def __init__(self, hosts: list[str], request: Request):
"""
初始化可修改请求对象
Args:
hosts: 主机列表,用于负载均衡
request: 原始 Starlette Request 对象
"""
self._original_request = request
# 复制原始请求的属性
self._method = request.method
self._hosts = hosts
self._url = request.url
self._headers = dict(request.headers)
self._body: Optional[bytes] = None # 延迟加载
self._query_params = dict(request.query_params)
self._path_params = dict(request.path_params)
# 标记是否已修改
self._modified = False
def _get_next_host(self) -> str:
"""
使用 round-robin 算法获取下一个主机
Returns:
下一个主机的 URL
"""
if not self._hosts:
raise ValueError("No hosts available for load balancing")
with self._counter_lock:
host = self._hosts[ModifiableRequest._counter % len(self._hosts)]
ModifiableRequest._counter += 1
return host
def _build_url_with_host(self, path: str) -> str:
"""
使用选定的主机构建完整的 URL
Args:
path: 请求路径
Returns:
完整的 URL
"""
host = self._get_next_host()
# 确保 host 有协议前缀
if not host.startswith(("http://", "https://")):
host = "http://" + host
# 确保路径以 / 开头
if not path.startswith("/"):
path = "/" + path
return host + path
def rewrite_uri(self, uri: str):
"""
重写请求 URI
Args:
uri: 新的 URI
"""
self._url = uri
self._modified = True
return self
def rewrite_method(self, method: str):
"""
重写请求方法
Args:
method: 新的 HTTP 方法
"""
self._method = method.upper()
self._modified = True
return self
def rewrite_headers(self, headers: Dict[str, str]):
"""
重写请求头
Args:
headers: 新的请求头字典
"""
self._headers.update(headers)
self._modified = True
return self
def rewrite_body(self, body: bytes):
"""
重写请求体
Args:
body: 新的请求体
"""
self._body = body
self._modified = True
return self
def add_header(self, key: str, value: str):
"""
添加单个请求头
Args:
key: 请求头名称
value: 请求头值
"""
self._headers[key] = value
self._modified = True
return self
def remove_header(self, key: str):
"""
移除请求头
Args:
key: 要移除的请求头名称
"""
if key in self._headers:
del self._headers[key]
self._modified = True
return self
def set_query_param(self, key: str, value: str):
"""
设置查询参数
Args:
key: 参数名
value: 参数值
"""
self._query_params[key] = value
self._modified = True
return self
def remove_query_param(self, key: str):
"""
移除查询参数
Args:
key: 要移除的参数名
"""
if key in self._query_params:
del self._query_params[key]
self._modified = True
return self
async def _get_body(self) -> bytes:
"""获取请求体,如果未修改则从原始请求获取"""
if self._body is not None:
return self._body
return await self._original_request.body()
def build(self) -> Request:
"""
构建新的 Request 对象
Returns:
修改后的 Request 对象
"""
if not self._modified:
# 如果没有修改,直接返回原始请求
return self._original_request
# 构建新的 URL
parsed_url = urlparse(str(self._url))
# 如果 URL 中没有主机(只有路径),则从 hosts 中选择一个进行负载均衡
if not parsed_url.netloc:
# 使用 round-robin 选择主机
new_url = self._build_url_with_host(parsed_url.path)
else:
# URL 已经包含主机,直接使用
new_url = str(self._url)
# 重新解析 URL 以处理查询参数
parsed_url = urlparse(new_url)
# 添加查询参数
query_params = []
for key, value in self._query_params.items():
query_params.append(f"{key}={value}")
if query_params:
query_string = "&".join(query_params)
# 如果 URL 已经有查询参数,则合并
if parsed_url.query:
query_string = f"{parsed_url.query}&{query_string}"
parsed_url = parsed_url._replace(query=query_string)
final_url = urlunparse(parsed_url)
# 创建新的 Request 对象
# 注意:这里我们需要创建一个模拟的 Request 对象
# 因为 Starlette 的 Request 对象是不可变的
# 我们返回一个包含修改后数据的字典,供代理处理器使用
# 为了保持兼容性,我们创建一个包含所有必要信息的字典
# 然后在代理处理器中使用这些信息
request_info = {
"method": self._method,
"url": final_url,
"headers": Headers(self._headers),
"query_params": self._query_params,
"path_params": self._path_params,
"body": self._body,
"original_request": self._original_request,
}
# 返回一个包装对象,模拟 Request 接口
return RequestWrapper(request_info) # type: ignore
class RequestWrapper:
"""Request 包装器,提供修改后的请求信息"""
def __init__(self, request_info: Dict[str, Any]):
self._info = request_info
self._original_request = request_info["original_request"]
@property
def method(self) -> str:
return self._info["method"]
@property
def url(self):
return self._info["url"]
@property
def headers(self):
return self._info["headers"]
@property
def query_params(self):
return self._info["query_params"]
@property
def path_params(self):
return self._info["path_params"]
async def body(self) -> bytes:
if self._info["body"] is not None:
return self._info["body"]
return await self._original_request.body()
# 代理其他必要的属性到原始请求
def __getattr__(self, name):
return getattr(self._original_request, name)