🎉 init starlette dial
This commit is contained in:
267
pkg/dial/req.py
Normal file
267
pkg/dial/req.py
Normal file
@@ -0,0 +1,267 @@
|
||||
from starlette.requests import Request
|
||||
from starlette.datastructures import Headers
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
import copy
|
||||
from typing import Optional, Dict, Any
|
||||
import threading
|
||||
|
||||
|
||||
class ModifiableRequest:
|
||||
"""可修改的请求对象,由 Request 对象修改而来"""
|
||||
|
||||
# 类级别的 round-robin 计数器
|
||||
_counter = 0
|
||||
_counter_lock = threading.Lock()
|
||||
|
||||
def __init__(self, hosts: list[str], request: Request):
|
||||
"""
|
||||
初始化可修改请求对象
|
||||
|
||||
Args:
|
||||
hosts: 主机列表,用于负载均衡
|
||||
request: 原始 Starlette Request 对象
|
||||
"""
|
||||
self._original_request = request
|
||||
|
||||
# 复制原始请求的属性
|
||||
self._method = request.method
|
||||
self._hosts = hosts
|
||||
self._url = request.url
|
||||
self._headers = dict(request.headers)
|
||||
self._body: Optional[bytes] = None # 延迟加载
|
||||
self._query_params = dict(request.query_params)
|
||||
self._path_params = dict(request.path_params)
|
||||
|
||||
# 标记是否已修改
|
||||
self._modified = False
|
||||
|
||||
def _get_next_host(self) -> str:
|
||||
"""
|
||||
使用 round-robin 算法获取下一个主机
|
||||
|
||||
Returns:
|
||||
下一个主机的 URL
|
||||
"""
|
||||
if not self._hosts:
|
||||
raise ValueError("No hosts available for load balancing")
|
||||
|
||||
with self._counter_lock:
|
||||
host = self._hosts[ModifiableRequest._counter % len(self._hosts)]
|
||||
ModifiableRequest._counter += 1
|
||||
return host
|
||||
|
||||
def _build_url_with_host(self, path: str) -> str:
|
||||
"""
|
||||
使用选定的主机构建完整的 URL
|
||||
|
||||
Args:
|
||||
path: 请求路径
|
||||
|
||||
Returns:
|
||||
完整的 URL
|
||||
"""
|
||||
host = self._get_next_host()
|
||||
|
||||
# 确保 host 有协议前缀
|
||||
if not host.startswith(("http://", "https://")):
|
||||
host = "http://" + host
|
||||
|
||||
# 确保路径以 / 开头
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
|
||||
return host + path
|
||||
|
||||
def rewrite_uri(self, uri: str):
|
||||
"""
|
||||
重写请求 URI
|
||||
|
||||
Args:
|
||||
uri: 新的 URI
|
||||
"""
|
||||
self._url = uri
|
||||
self._modified = True
|
||||
return self
|
||||
|
||||
def rewrite_method(self, method: str):
|
||||
"""
|
||||
重写请求方法
|
||||
|
||||
Args:
|
||||
method: 新的 HTTP 方法
|
||||
"""
|
||||
self._method = method.upper()
|
||||
self._modified = True
|
||||
return self
|
||||
|
||||
def rewrite_headers(self, headers: Dict[str, str]):
|
||||
"""
|
||||
重写请求头
|
||||
|
||||
Args:
|
||||
headers: 新的请求头字典
|
||||
"""
|
||||
self._headers.update(headers)
|
||||
self._modified = True
|
||||
return self
|
||||
|
||||
def rewrite_body(self, body: bytes):
|
||||
"""
|
||||
重写请求体
|
||||
|
||||
Args:
|
||||
body: 新的请求体
|
||||
"""
|
||||
self._body = body
|
||||
self._modified = True
|
||||
return self
|
||||
|
||||
def add_header(self, key: str, value: str):
|
||||
"""
|
||||
添加单个请求头
|
||||
|
||||
Args:
|
||||
key: 请求头名称
|
||||
value: 请求头值
|
||||
"""
|
||||
self._headers[key] = value
|
||||
self._modified = True
|
||||
return self
|
||||
|
||||
def remove_header(self, key: str):
|
||||
"""
|
||||
移除请求头
|
||||
|
||||
Args:
|
||||
key: 要移除的请求头名称
|
||||
"""
|
||||
if key in self._headers:
|
||||
del self._headers[key]
|
||||
self._modified = True
|
||||
return self
|
||||
|
||||
def set_query_param(self, key: str, value: str):
|
||||
"""
|
||||
设置查询参数
|
||||
|
||||
Args:
|
||||
key: 参数名
|
||||
value: 参数值
|
||||
"""
|
||||
self._query_params[key] = value
|
||||
self._modified = True
|
||||
return self
|
||||
|
||||
def remove_query_param(self, key: str):
|
||||
"""
|
||||
移除查询参数
|
||||
|
||||
Args:
|
||||
key: 要移除的参数名
|
||||
"""
|
||||
if key in self._query_params:
|
||||
del self._query_params[key]
|
||||
self._modified = True
|
||||
return self
|
||||
|
||||
async def _get_body(self) -> bytes:
|
||||
"""获取请求体,如果未修改则从原始请求获取"""
|
||||
if self._body is not None:
|
||||
return self._body
|
||||
return await self._original_request.body()
|
||||
|
||||
def build(self) -> Request:
|
||||
"""
|
||||
构建新的 Request 对象
|
||||
|
||||
Returns:
|
||||
修改后的 Request 对象
|
||||
"""
|
||||
if not self._modified:
|
||||
# 如果没有修改,直接返回原始请求
|
||||
return self._original_request
|
||||
|
||||
# 构建新的 URL
|
||||
parsed_url = urlparse(str(self._url))
|
||||
|
||||
# 如果 URL 中没有主机(只有路径),则从 hosts 中选择一个进行负载均衡
|
||||
if not parsed_url.netloc:
|
||||
# 使用 round-robin 选择主机
|
||||
new_url = self._build_url_with_host(parsed_url.path)
|
||||
else:
|
||||
# URL 已经包含主机,直接使用
|
||||
new_url = str(self._url)
|
||||
|
||||
# 重新解析 URL 以处理查询参数
|
||||
parsed_url = urlparse(new_url)
|
||||
|
||||
# 添加查询参数
|
||||
query_params = []
|
||||
for key, value in self._query_params.items():
|
||||
query_params.append(f"{key}={value}")
|
||||
|
||||
if query_params:
|
||||
query_string = "&".join(query_params)
|
||||
# 如果 URL 已经有查询参数,则合并
|
||||
if parsed_url.query:
|
||||
query_string = f"{parsed_url.query}&{query_string}"
|
||||
parsed_url = parsed_url._replace(query=query_string)
|
||||
|
||||
final_url = urlunparse(parsed_url)
|
||||
|
||||
# 创建新的 Request 对象
|
||||
# 注意:这里我们需要创建一个模拟的 Request 对象
|
||||
# 因为 Starlette 的 Request 对象是不可变的
|
||||
# 我们返回一个包含修改后数据的字典,供代理处理器使用
|
||||
|
||||
# 为了保持兼容性,我们创建一个包含所有必要信息的字典
|
||||
# 然后在代理处理器中使用这些信息
|
||||
request_info = {
|
||||
"method": self._method,
|
||||
"url": final_url,
|
||||
"headers": Headers(self._headers),
|
||||
"query_params": self._query_params,
|
||||
"path_params": self._path_params,
|
||||
"body": self._body,
|
||||
"original_request": self._original_request,
|
||||
}
|
||||
|
||||
# 返回一个包装对象,模拟 Request 接口
|
||||
return RequestWrapper(request_info) # type: ignore
|
||||
|
||||
|
||||
class RequestWrapper:
|
||||
"""Request 包装器,提供修改后的请求信息"""
|
||||
|
||||
def __init__(self, request_info: Dict[str, Any]):
|
||||
self._info = request_info
|
||||
self._original_request = request_info["original_request"]
|
||||
|
||||
@property
|
||||
def method(self) -> str:
|
||||
return self._info["method"]
|
||||
|
||||
@property
|
||||
def url(self):
|
||||
return self._info["url"]
|
||||
|
||||
@property
|
||||
def headers(self):
|
||||
return self._info["headers"]
|
||||
|
||||
@property
|
||||
def query_params(self):
|
||||
return self._info["query_params"]
|
||||
|
||||
@property
|
||||
def path_params(self):
|
||||
return self._info["path_params"]
|
||||
|
||||
async def body(self) -> bytes:
|
||||
if self._info["body"] is not None:
|
||||
return self._info["body"]
|
||||
return await self._original_request.body()
|
||||
|
||||
# 代理其他必要的属性到原始请求
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._original_request, name)
|
||||
Reference in New Issue
Block a user