🎉 init starlette dial
This commit is contained in:
157
pkg/dial/LOAD_BALANCING.md
Normal file
157
pkg/dial/LOAD_BALANCING.md
Normal file
@@ -0,0 +1,157 @@
|
||||
# 负载均衡功能
|
||||
|
||||
`ModifiableRequest` 类现在支持 round-robin 负载均衡功能,可以自动在多个后端服务器之间分发请求。
|
||||
|
||||
## 功能特性
|
||||
|
||||
- 🔄 **Round-Robin 轮询**: 按顺序在多个主机之间分发请求
|
||||
- 🛡️ **线程安全**: 使用锁机制确保并发安全
|
||||
- 🔧 **灵活配置**: 支持带协议和不带协议的主机配置
|
||||
- 🎯 **智能路由**: 自动判断是否需要负载均衡
|
||||
|
||||
## 使用方法
|
||||
|
||||
### 1. 基本负载均衡
|
||||
|
||||
```python
|
||||
from pkg.dial import client
|
||||
from pkg.dial.req import ModifiableRequest
|
||||
|
||||
def request_handler(request: ModifiableRequest):
|
||||
# 只指定路径,让负载均衡器自动选择主机
|
||||
request.rewrite_uri("/api/users")
|
||||
return request.build()
|
||||
|
||||
# 创建代理,支持多个后端服务器
|
||||
proxy_handler = client.proxy(
|
||||
hosts=["server1.com", "server2.com", "server3.com"],
|
||||
req_fn=request_handler
|
||||
)
|
||||
```
|
||||
|
||||
### 2. 带协议的主机配置
|
||||
|
||||
```python
|
||||
# 支持混合协议配置
|
||||
hosts = [
|
||||
"http://server1.com",
|
||||
"https://server2.com",
|
||||
"server3.com" # 自动添加 http:// 前缀
|
||||
]
|
||||
|
||||
proxy_handler = client.proxy(hosts, request_handler)
|
||||
```
|
||||
|
||||
### 3. 绕过负载均衡
|
||||
|
||||
```python
|
||||
def request_handler_fixed(request: ModifiableRequest):
|
||||
# 指定完整 URL,绕过负载均衡
|
||||
request.rewrite_uri("https://specific-server.com/api/users")
|
||||
return request.build()
|
||||
|
||||
# 即使配置了多个主机,也会使用指定的 URL
|
||||
proxy_handler = client.proxy(
|
||||
hosts=["server1.com", "server2.com"],
|
||||
req_fn=request_handler_fixed
|
||||
)
|
||||
```
|
||||
|
||||
## 工作原理
|
||||
|
||||
### URL 解析逻辑
|
||||
|
||||
1. **检查 URL 是否包含主机**:
|
||||
- 如果 URL 包含主机(如 `https://example.com/api/users`),直接使用该 URL
|
||||
- 如果 URL 只有路径(如 `/api/users`),从主机列表中选择一个
|
||||
|
||||
2. **Round-Robin 选择**:
|
||||
- 使用线程安全的计数器
|
||||
- 按顺序循环选择主机
|
||||
- 自动处理主机数量变化
|
||||
|
||||
3. **URL 构建**:
|
||||
- 自动添加协议前缀(如果缺失)
|
||||
- 确保路径格式正确
|
||||
- 合并查询参数
|
||||
|
||||
### 示例请求分发
|
||||
|
||||
假设配置了 3 个主机:`["server1.com", "server2.com", "server3.com"]`
|
||||
|
||||
```
|
||||
请求 1: http://server1.com/api/users
|
||||
请求 2: http://server2.com/api/users
|
||||
请求 3: http://server3.com/api/users
|
||||
请求 4: http://server1.com/api/users # 重新开始轮询
|
||||
请求 5: http://server2.com/api/users
|
||||
...
|
||||
```
|
||||
|
||||
## 配置示例
|
||||
|
||||
### Starlette 应用中的使用
|
||||
|
||||
```python
|
||||
from starlette.applications import Starlette
|
||||
from starlette.routing import Route
|
||||
from pkg.dial import client
|
||||
|
||||
def api_handler(request: ModifiableRequest):
|
||||
request.rewrite_uri("/api/v1/data")
|
||||
return request.build()
|
||||
|
||||
app = Starlette(routes=[
|
||||
Route(
|
||||
"/api/data",
|
||||
client.proxy(
|
||||
hosts=[
|
||||
"http://backend1:8080",
|
||||
"http://backend2:8080",
|
||||
"http://backend3:8080"
|
||||
],
|
||||
req_fn=api_handler
|
||||
)
|
||||
)
|
||||
])
|
||||
```
|
||||
|
||||
### 健康检查集成
|
||||
|
||||
```python
|
||||
def health_check_handler(request: ModifiableRequest):
|
||||
request.rewrite_uri("/health")
|
||||
return request.build()
|
||||
|
||||
# 健康检查路由
|
||||
health_proxy = client.proxy(
|
||||
hosts=["server1.com", "server2.com"],
|
||||
req_fn=health_check_handler
|
||||
)
|
||||
```
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. **主机格式**: 建议使用完整的主机名或 IP 地址
|
||||
2. **协议处理**: 如果不指定协议,默认使用 `http://`
|
||||
3. **并发安全**: 使用线程锁确保多线程环境下的安全性
|
||||
4. **错误处理**: 如果主机列表为空,会抛出 `ValueError` 异常
|
||||
5. **性能考虑**: 负载均衡器本身开销很小,适合高并发场景
|
||||
|
||||
## 故障排除
|
||||
|
||||
### 常见问题
|
||||
|
||||
1. **主机无法访问**: 确保所有配置的主机都是可访问的
|
||||
2. **协议错误**: 检查主机配置中的协议是否正确
|
||||
3. **路径问题**: 确保路径格式正确(以 `/` 开头)
|
||||
|
||||
### 调试技巧
|
||||
|
||||
```python
|
||||
def debug_handler(request: ModifiableRequest):
|
||||
# 添加调试信息
|
||||
request.add_header("X-Debug", "true")
|
||||
request.rewrite_uri("/api/debug")
|
||||
return request.build()
|
||||
```
|
||||
168
pkg/dial/README.md
Normal file
168
pkg/dial/README.md
Normal file
@@ -0,0 +1,168 @@
|
||||
# Dial - HTTP 客户端和代理工具
|
||||
|
||||
`Dial` 是一个基于 `httpx` 的异步 HTTP 客户端类,提供了连接池管理、直接请求和代理功能。
|
||||
|
||||
## 功能特性
|
||||
|
||||
- 🔄 **连接池管理**: 自动管理 HTTP 连接池,提高性能
|
||||
- 🌐 **直接请求**: 支持各种 HTTP 方法的直接请求
|
||||
- 🔀 **代理功能**: 可以创建代理处理器,支持请求和响应重写
|
||||
- ⚡ **异步支持**: 完全异步实现,与 Starlette 框架完美集成
|
||||
- 🛡️ **类型安全**: 完整的类型注解支持
|
||||
|
||||
## 安装依赖
|
||||
|
||||
确保已安装 `httpx` 依赖:
|
||||
|
||||
```bash
|
||||
pip install httpx
|
||||
```
|
||||
|
||||
或在 `pyproject.toml` 中添加:
|
||||
|
||||
```toml
|
||||
dependencies = [
|
||||
"httpx>=0.27.0",
|
||||
]
|
||||
```
|
||||
|
||||
## 基本使用
|
||||
|
||||
### 1. 直接请求
|
||||
|
||||
```python
|
||||
from pkg.dial import client
|
||||
|
||||
# 简单 GET 请求
|
||||
response = await client.dial("https://api.example.com/users")
|
||||
print(response.status_code)
|
||||
print(response.json())
|
||||
|
||||
# POST 请求带 JSON 数据
|
||||
response = await client.dial(
|
||||
url="https://api.example.com/users",
|
||||
method="POST",
|
||||
headers={"Content-Type": "application/json"},
|
||||
body={"name": "John", "email": "john@example.com"}
|
||||
)
|
||||
```
|
||||
|
||||
### 2. 自定义客户端
|
||||
|
||||
```python
|
||||
from pkg.dial import Dial
|
||||
import httpx
|
||||
|
||||
# 创建自定义客户端
|
||||
dial = Dial(
|
||||
timeout=30.0,
|
||||
limits=httpx.Limits(max_keepalive_connections=10, max_connections=50)
|
||||
)
|
||||
|
||||
# 使用上下文管理器
|
||||
async with dial as client:
|
||||
response = await client.dial("https://api.example.com/data")
|
||||
print(response.text)
|
||||
```
|
||||
|
||||
### 3. 代理功能
|
||||
|
||||
```python
|
||||
from pkg.dial import client
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
# 请求处理函数
|
||||
def request_handler(request: Request):
|
||||
return {
|
||||
'url': f'https://api.example.com{request.url.path}',
|
||||
'method': request.method,
|
||||
'headers': dict(request.headers),
|
||||
'params': dict(request.query_params)
|
||||
}
|
||||
|
||||
# 响应处理函数
|
||||
def response_handler(response):
|
||||
return JSONResponse({
|
||||
'status': response.status_code,
|
||||
'data': response.json()
|
||||
})
|
||||
|
||||
# 创建代理处理器
|
||||
proxy_handler = client.proxy(
|
||||
req_fn=request_handler,
|
||||
res_fn=response_handler
|
||||
)
|
||||
|
||||
# 在 Starlette 应用中使用
|
||||
from starlette.applications import Starlette
|
||||
from starlette.routing import Route
|
||||
|
||||
app = Starlette(routes=[
|
||||
Route('/api/{path:path}', proxy_handler)
|
||||
])
|
||||
```
|
||||
|
||||
## API 参考
|
||||
|
||||
### Dial 类
|
||||
|
||||
#### 构造函数
|
||||
|
||||
```python
|
||||
Dial(timeout: float = 30.0, limits: Optional[httpx.Limits] = None)
|
||||
```
|
||||
|
||||
- `timeout`: 请求超时时间(秒)
|
||||
- `limits`: HTTP 连接限制配置
|
||||
|
||||
#### 方法
|
||||
|
||||
##### `dial()`
|
||||
|
||||
```python
|
||||
async def dial(
|
||||
self,
|
||||
url: str,
|
||||
method: str = "GET",
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
body: Optional[Any] = None,
|
||||
params: Optional[Dict[str, Any]] = None
|
||||
) -> httpx.Response
|
||||
```
|
||||
|
||||
直接发送 HTTP 请求。
|
||||
|
||||
##### `proxy()`
|
||||
|
||||
```python
|
||||
def proxy(
|
||||
self,
|
||||
req_fn: Optional[Callable[[Request], Dict[str, Any]]] = None,
|
||||
res_fn: Optional[Callable[[httpx.Response], Response]] = None
|
||||
) -> Callable[[Request], Response]
|
||||
```
|
||||
|
||||
创建代理处理器。
|
||||
|
||||
- `req_fn`: 请求处理函数,接收 `Request` 对象,返回请求配置字典
|
||||
- `res_fn`: 响应处理函数,接收 `httpx.Response` 对象,返回 `Response` 对象
|
||||
|
||||
##### `close()`
|
||||
|
||||
```python
|
||||
async def close()
|
||||
```
|
||||
|
||||
关闭客户端连接。
|
||||
|
||||
## 示例
|
||||
|
||||
查看 `examples/dial_example.py` 文件获取完整的使用示例。
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. **连接池管理**: 客户端会自动管理连接池,无需手动管理
|
||||
2. **异步使用**: 所有方法都是异步的,需要在异步环境中使用
|
||||
3. **资源清理**: 使用上下文管理器或手动调用 `close()` 方法清理资源
|
||||
4. **类型安全**: 建议使用类型注解以获得更好的开发体验
|
||||
3
pkg/dial/__init__.py
Normal file
3
pkg/dial/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from pkg.dial.dial import Dial
|
||||
|
||||
client = Dial()
|
||||
199
pkg/dial/dial.py
Normal file
199
pkg/dial/dial.py
Normal file
@@ -0,0 +1,199 @@
|
||||
import typing
|
||||
from typing import Optional, Dict, Any, Callable
|
||||
import httpx
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from starlette.responses import StreamingResponse
|
||||
import json
|
||||
|
||||
from pkg.dial.req import ModifiableRequest
|
||||
from pkg.resp.exception import (
|
||||
BadRequestException,
|
||||
ForbiddenException,
|
||||
NotFoundException,
|
||||
ServerErrorException,
|
||||
UnauthorizationException,
|
||||
)
|
||||
|
||||
|
||||
class Dial:
|
||||
def __init__(self, timeout: float = 30.0, limits: Optional[httpx.Limits] = None):
|
||||
"""
|
||||
初始化 Dial 客户端
|
||||
|
||||
Args:
|
||||
timeout: 请求超时时间(秒)
|
||||
limits: HTTP 连接限制
|
||||
"""
|
||||
self.timeout = timeout
|
||||
self.limits = limits or httpx.Limits(
|
||||
max_keepalive_connections=20, max_connections=100
|
||||
)
|
||||
self._client: Optional[httpx.AsyncClient] = None
|
||||
|
||||
async def _get_client(self) -> httpx.AsyncClient:
|
||||
"""获取或创建 HTTP 客户端"""
|
||||
if self._client is None:
|
||||
self._client = httpx.AsyncClient(timeout=self.timeout, limits=self.limits)
|
||||
return self._client
|
||||
|
||||
async def close(self):
|
||||
"""关闭客户端连接"""
|
||||
if self._client:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
|
||||
async def dial(
|
||||
self,
|
||||
url: str,
|
||||
method: str = "GET",
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
body: Optional[Any] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> httpx.Response:
|
||||
"""
|
||||
直接请求一个 URL
|
||||
|
||||
Args:
|
||||
url: 目标 URL
|
||||
method: HTTP 方法
|
||||
headers: 请求头
|
||||
body: 请求体
|
||||
params: 查询参数
|
||||
|
||||
Returns:
|
||||
httpx.Response: HTTP 响应对象
|
||||
"""
|
||||
client = await self._get_client()
|
||||
|
||||
# 处理请求体
|
||||
if body is not None:
|
||||
if isinstance(body, (dict, list)):
|
||||
body = json.dumps(body)
|
||||
if headers is None:
|
||||
headers = {}
|
||||
headers.setdefault("Content-Type", "application/json")
|
||||
|
||||
# 发送请求
|
||||
response = await client.request(
|
||||
method=method.upper(),
|
||||
url=url,
|
||||
headers=headers or {},
|
||||
content=body,
|
||||
params=params or {},
|
||||
)
|
||||
|
||||
# 如果 response 的 status != 200
|
||||
# 尝试将 body 转换为 json
|
||||
# 获取 body.msg
|
||||
msg = ""
|
||||
status = response.status_code
|
||||
err = ""
|
||||
result = None
|
||||
|
||||
if response.status_code >= 300:
|
||||
print(f"[D] response.status_code = {response.status_code}")
|
||||
try:
|
||||
data = response.json()
|
||||
msg = data.get("msg", "")
|
||||
err = data.get("err", "")
|
||||
result = data.get("data", None)
|
||||
status = data.get("status", status)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if response.status_code == 400:
|
||||
raise BadRequestException(msg=msg, status=status, err=err, data=result)
|
||||
if response.status_code == 401:
|
||||
raise UnauthorizationException(
|
||||
msg=msg, status=status, err=err, data=result
|
||||
)
|
||||
if response.status_code == 403:
|
||||
raise ForbiddenException(msg=msg, status=status, err=err, data=result)
|
||||
if response.status_code == 404:
|
||||
raise NotFoundException(msg=msg, status=status, err=err, data=result)
|
||||
raise ServerErrorException(msg=msg, status=status, err=err, data=result)
|
||||
|
||||
return response
|
||||
|
||||
def proxy(
|
||||
self,
|
||||
host: str | list[str],
|
||||
req_fn: Optional[Callable[[ModifiableRequest], Request]] = None,
|
||||
res_fn: Optional[Callable[[httpx.Response], Response]] = None,
|
||||
):
|
||||
"""
|
||||
创建代理处理器
|
||||
|
||||
Args:
|
||||
host: 目标主机列表
|
||||
req_fn: 请求处理函数,接收 ModifiableRequest 对象,返回修改后的 Request 对象
|
||||
res_fn: 响应处理函数,接收 httpx.Response 对象,返回 Starlette Response
|
||||
|
||||
Returns:
|
||||
代理处理器函数
|
||||
"""
|
||||
|
||||
if isinstance(host, str):
|
||||
host = [host]
|
||||
|
||||
if len(host) == 0:
|
||||
raise ValueError("host is empty")
|
||||
for h in host:
|
||||
if h == "":
|
||||
raise ValueError("host is empty")
|
||||
|
||||
async def proxy_handler(request: Request) -> Response:
|
||||
# 获取请求配置
|
||||
if req_fn:
|
||||
mr = ModifiableRequest(host, request)
|
||||
modified_request = req_fn(mr)
|
||||
|
||||
# 从修改后的请求中提取信息
|
||||
target_url = str(modified_request.url)
|
||||
method = modified_request.method
|
||||
headers = dict(modified_request.headers)
|
||||
body = await modified_request.body()
|
||||
params = dict(modified_request.query_params)
|
||||
else:
|
||||
# 默认转发到相同路径
|
||||
target_url = str(request.url)
|
||||
method = request.method
|
||||
headers = dict(request.headers)
|
||||
body = await request.body()
|
||||
params = dict(request.query_params)
|
||||
|
||||
# 移除可能导致问题的头部
|
||||
headers.pop("host", None)
|
||||
headers.pop("Host", None)
|
||||
|
||||
# 发送请求
|
||||
response = await self.dial(
|
||||
url=target_url, method=method, headers=headers, body=body, params=params
|
||||
)
|
||||
|
||||
# 处理响应
|
||||
if res_fn:
|
||||
return res_fn(response)
|
||||
else:
|
||||
# 默认响应处理
|
||||
return StreamingResponse(
|
||||
response.aiter_bytes(),
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers),
|
||||
media_type=response.headers.get("content-type"),
|
||||
)
|
||||
|
||||
return proxy_handler
|
||||
|
||||
async def __aenter__(self):
|
||||
"""异步上下文管理器入口"""
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""异步上下文管理器出口"""
|
||||
await self.close()
|
||||
|
||||
|
||||
# 创建默认客户端实例
|
||||
client = Dial()
|
||||
267
pkg/dial/req.py
Normal file
267
pkg/dial/req.py
Normal file
@@ -0,0 +1,267 @@
|
||||
from starlette.requests import Request
|
||||
from starlette.datastructures import Headers
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
import copy
|
||||
from typing import Optional, Dict, Any
|
||||
import threading
|
||||
|
||||
|
||||
class ModifiableRequest:
|
||||
"""可修改的请求对象,由 Request 对象修改而来"""
|
||||
|
||||
# 类级别的 round-robin 计数器
|
||||
_counter = 0
|
||||
_counter_lock = threading.Lock()
|
||||
|
||||
def __init__(self, hosts: list[str], request: Request):
|
||||
"""
|
||||
初始化可修改请求对象
|
||||
|
||||
Args:
|
||||
hosts: 主机列表,用于负载均衡
|
||||
request: 原始 Starlette Request 对象
|
||||
"""
|
||||
self._original_request = request
|
||||
|
||||
# 复制原始请求的属性
|
||||
self._method = request.method
|
||||
self._hosts = hosts
|
||||
self._url = request.url
|
||||
self._headers = dict(request.headers)
|
||||
self._body: Optional[bytes] = None # 延迟加载
|
||||
self._query_params = dict(request.query_params)
|
||||
self._path_params = dict(request.path_params)
|
||||
|
||||
# 标记是否已修改
|
||||
self._modified = False
|
||||
|
||||
def _get_next_host(self) -> str:
|
||||
"""
|
||||
使用 round-robin 算法获取下一个主机
|
||||
|
||||
Returns:
|
||||
下一个主机的 URL
|
||||
"""
|
||||
if not self._hosts:
|
||||
raise ValueError("No hosts available for load balancing")
|
||||
|
||||
with self._counter_lock:
|
||||
host = self._hosts[ModifiableRequest._counter % len(self._hosts)]
|
||||
ModifiableRequest._counter += 1
|
||||
return host
|
||||
|
||||
def _build_url_with_host(self, path: str) -> str:
|
||||
"""
|
||||
使用选定的主机构建完整的 URL
|
||||
|
||||
Args:
|
||||
path: 请求路径
|
||||
|
||||
Returns:
|
||||
完整的 URL
|
||||
"""
|
||||
host = self._get_next_host()
|
||||
|
||||
# 确保 host 有协议前缀
|
||||
if not host.startswith(("http://", "https://")):
|
||||
host = "http://" + host
|
||||
|
||||
# 确保路径以 / 开头
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
|
||||
return host + path
|
||||
|
||||
def rewrite_uri(self, uri: str):
|
||||
"""
|
||||
重写请求 URI
|
||||
|
||||
Args:
|
||||
uri: 新的 URI
|
||||
"""
|
||||
self._url = uri
|
||||
self._modified = True
|
||||
return self
|
||||
|
||||
def rewrite_method(self, method: str):
|
||||
"""
|
||||
重写请求方法
|
||||
|
||||
Args:
|
||||
method: 新的 HTTP 方法
|
||||
"""
|
||||
self._method = method.upper()
|
||||
self._modified = True
|
||||
return self
|
||||
|
||||
def rewrite_headers(self, headers: Dict[str, str]):
|
||||
"""
|
||||
重写请求头
|
||||
|
||||
Args:
|
||||
headers: 新的请求头字典
|
||||
"""
|
||||
self._headers.update(headers)
|
||||
self._modified = True
|
||||
return self
|
||||
|
||||
def rewrite_body(self, body: bytes):
|
||||
"""
|
||||
重写请求体
|
||||
|
||||
Args:
|
||||
body: 新的请求体
|
||||
"""
|
||||
self._body = body
|
||||
self._modified = True
|
||||
return self
|
||||
|
||||
def add_header(self, key: str, value: str):
|
||||
"""
|
||||
添加单个请求头
|
||||
|
||||
Args:
|
||||
key: 请求头名称
|
||||
value: 请求头值
|
||||
"""
|
||||
self._headers[key] = value
|
||||
self._modified = True
|
||||
return self
|
||||
|
||||
def remove_header(self, key: str):
|
||||
"""
|
||||
移除请求头
|
||||
|
||||
Args:
|
||||
key: 要移除的请求头名称
|
||||
"""
|
||||
if key in self._headers:
|
||||
del self._headers[key]
|
||||
self._modified = True
|
||||
return self
|
||||
|
||||
def set_query_param(self, key: str, value: str):
|
||||
"""
|
||||
设置查询参数
|
||||
|
||||
Args:
|
||||
key: 参数名
|
||||
value: 参数值
|
||||
"""
|
||||
self._query_params[key] = value
|
||||
self._modified = True
|
||||
return self
|
||||
|
||||
def remove_query_param(self, key: str):
|
||||
"""
|
||||
移除查询参数
|
||||
|
||||
Args:
|
||||
key: 要移除的参数名
|
||||
"""
|
||||
if key in self._query_params:
|
||||
del self._query_params[key]
|
||||
self._modified = True
|
||||
return self
|
||||
|
||||
async def _get_body(self) -> bytes:
|
||||
"""获取请求体,如果未修改则从原始请求获取"""
|
||||
if self._body is not None:
|
||||
return self._body
|
||||
return await self._original_request.body()
|
||||
|
||||
def build(self) -> Request:
|
||||
"""
|
||||
构建新的 Request 对象
|
||||
|
||||
Returns:
|
||||
修改后的 Request 对象
|
||||
"""
|
||||
if not self._modified:
|
||||
# 如果没有修改,直接返回原始请求
|
||||
return self._original_request
|
||||
|
||||
# 构建新的 URL
|
||||
parsed_url = urlparse(str(self._url))
|
||||
|
||||
# 如果 URL 中没有主机(只有路径),则从 hosts 中选择一个进行负载均衡
|
||||
if not parsed_url.netloc:
|
||||
# 使用 round-robin 选择主机
|
||||
new_url = self._build_url_with_host(parsed_url.path)
|
||||
else:
|
||||
# URL 已经包含主机,直接使用
|
||||
new_url = str(self._url)
|
||||
|
||||
# 重新解析 URL 以处理查询参数
|
||||
parsed_url = urlparse(new_url)
|
||||
|
||||
# 添加查询参数
|
||||
query_params = []
|
||||
for key, value in self._query_params.items():
|
||||
query_params.append(f"{key}={value}")
|
||||
|
||||
if query_params:
|
||||
query_string = "&".join(query_params)
|
||||
# 如果 URL 已经有查询参数,则合并
|
||||
if parsed_url.query:
|
||||
query_string = f"{parsed_url.query}&{query_string}"
|
||||
parsed_url = parsed_url._replace(query=query_string)
|
||||
|
||||
final_url = urlunparse(parsed_url)
|
||||
|
||||
# 创建新的 Request 对象
|
||||
# 注意:这里我们需要创建一个模拟的 Request 对象
|
||||
# 因为 Starlette 的 Request 对象是不可变的
|
||||
# 我们返回一个包含修改后数据的字典,供代理处理器使用
|
||||
|
||||
# 为了保持兼容性,我们创建一个包含所有必要信息的字典
|
||||
# 然后在代理处理器中使用这些信息
|
||||
request_info = {
|
||||
"method": self._method,
|
||||
"url": final_url,
|
||||
"headers": Headers(self._headers),
|
||||
"query_params": self._query_params,
|
||||
"path_params": self._path_params,
|
||||
"body": self._body,
|
||||
"original_request": self._original_request,
|
||||
}
|
||||
|
||||
# 返回一个包装对象,模拟 Request 接口
|
||||
return RequestWrapper(request_info) # type: ignore
|
||||
|
||||
|
||||
class RequestWrapper:
|
||||
"""Request 包装器,提供修改后的请求信息"""
|
||||
|
||||
def __init__(self, request_info: Dict[str, Any]):
|
||||
self._info = request_info
|
||||
self._original_request = request_info["original_request"]
|
||||
|
||||
@property
|
||||
def method(self) -> str:
|
||||
return self._info["method"]
|
||||
|
||||
@property
|
||||
def url(self):
|
||||
return self._info["url"]
|
||||
|
||||
@property
|
||||
def headers(self):
|
||||
return self._info["headers"]
|
||||
|
||||
@property
|
||||
def query_params(self):
|
||||
return self._info["query_params"]
|
||||
|
||||
@property
|
||||
def path_params(self):
|
||||
return self._info["path_params"]
|
||||
|
||||
async def body(self) -> bytes:
|
||||
if self._info["body"] is not None:
|
||||
return self._info["body"]
|
||||
return await self._original_request.body()
|
||||
|
||||
# 代理其他必要的属性到原始请求
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._original_request, name)
|
||||
82
pkg/resp/exception.py
Normal file
82
pkg/resp/exception.py
Normal file
@@ -0,0 +1,82 @@
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
|
||||
class BadRequestException(Exception):
|
||||
def __init__(self, msg="", data=None, err=None, status=None):
|
||||
self.status = status or 400
|
||||
self.msg = msg or "参数错误"
|
||||
self.data = data
|
||||
self.err = err
|
||||
super().__init__(self.msg)
|
||||
|
||||
|
||||
class ServerErrorException(Exception):
|
||||
def __init__(self, msg="", data=None, err=None, status=None):
|
||||
self.status = status or 500
|
||||
self.msg = msg or "服务器开小差了"
|
||||
self.data = data
|
||||
self.err = err
|
||||
super().__init__(self.msg)
|
||||
|
||||
|
||||
class UnauthorizationException(Exception):
|
||||
def __init__(self, msg="", data=None, err=None, status=None):
|
||||
self.status = status or 401
|
||||
self.msg = msg or "登录信息不存在或已过期, 请重新登录"
|
||||
self.data = data
|
||||
self.err = err
|
||||
super().__init__(self.msg)
|
||||
|
||||
|
||||
class ForbiddenException(Exception):
|
||||
def __init__(self, msg="", data=None, err=None, status=None):
|
||||
self.status = status or 403
|
||||
self.msg = msg or "权限不足"
|
||||
self.data = data
|
||||
self.err = err
|
||||
super().__init__(self.msg)
|
||||
|
||||
|
||||
class NotFoundException(Exception):
|
||||
def __init__(self, msg="", data=None, err=None, status=None):
|
||||
self.status = status or 404
|
||||
self.msg = msg or "资源不存在"
|
||||
self.data = data
|
||||
self.err = err
|
||||
super().__init__(self.msg)
|
||||
|
||||
|
||||
def server_error_exception_handler(request: Request, exc):
|
||||
return JSONResponse(
|
||||
{"msg": exc.msg, "data": exc.data, "err": exc.err, "status": exc.status},
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
|
||||
def unauthorization_exception_handler(request: Request, exc):
|
||||
return JSONResponse(
|
||||
{"msg": exc.msg, "data": exc.data, "err": exc.err, "status": exc.status},
|
||||
status_code=401,
|
||||
)
|
||||
|
||||
|
||||
def forbidden_exception_handler(request: Request, exc):
|
||||
return JSONResponse(
|
||||
{"msg": exc.msg, "data": exc.data, "err": exc.err, "status": exc.status},
|
||||
status_code=403,
|
||||
)
|
||||
|
||||
|
||||
def bad_request_exception_handler(request: Request, exc):
|
||||
return JSONResponse(
|
||||
{"msg": exc.msg, "data": exc.data, "err": exc.err, "status": exc.status},
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
|
||||
def not_found_exception_handler(request: Request, exc):
|
||||
return JSONResponse(
|
||||
{"msg": exc.msg, "data": exc.data, "err": exc.err, "status": exc.status},
|
||||
status_code=404,
|
||||
)
|
||||
Reference in New Issue
Block a user