🎉 init starlette dial

This commit is contained in:
loveuer
2025-06-28 19:03:29 +08:00
commit 739a518e51
19 changed files with 1375 additions and 0 deletions

10
.gitignore vendored Normal file
View File

@ -0,0 +1,10 @@
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info
# Virtual environments
.venv

1
.python-version Normal file
View File

@ -0,0 +1 @@
3.13

0
README.md Normal file
View File

19
go.mod Normal file
View File

@ -0,0 +1,19 @@
module xtest
go 1.24.2
require github.com/gofiber/fiber/v2 v2.52.8
require (
github.com/andybalholm/brotli v1.1.0 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/klauspost/compress v1.17.9 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-runewidth v0.0.16 // indirect
github.com/rivo/uniseg v0.2.0 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fasthttp v1.51.0 // indirect
github.com/valyala/tcplisten v1.0.0 // indirect
golang.org/x/sys v0.28.0 // indirect
)

27
go.sum Normal file
View File

@ -0,0 +1,27 @@
github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M=
github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY=
github.com/gofiber/fiber/v2 v2.52.8 h1:xl4jJQ0BV5EJTA2aWiKw/VddRpHrKeZLF0QPUxqn0x4=
github.com/gofiber/fiber/v2 v2.52.8/go.mod h1:YEcBbO/FB+5M1IZNBP9FO3J9281zgPAreiI1oqg8nDw=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasthttp v1.51.0 h1:8b30A5JlZ6C7AS81RsWjYMQmrZG6feChmgAolCl1SqA=
github.com/valyala/fasthttp v1.51.0/go.mod h1:oI2XroL+lI7vdXyYoQk03bXBThfFl2cVdIA3Xl7cH8g=
github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8=
github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=

38
handler/error.py Normal file
View File

@ -0,0 +1,38 @@
from starlette.requests import Request
from pkg.dial import client
from pkg.resp.exception import (
BadRequestException,
ForbiddenException,
ServerErrorException,
UnauthorizationException,
)
async def handle_400(request: Request):
raise BadRequestException(
msg="参数错误", data={"name": "张三"}, err={"code": 400, "msg": "参数错误"}
)
async def handle_500(request: Request):
raise ServerErrorException(
msg="服务器内部错误", data=None, err={"code": 500, "msg": "服务器内部错误"}
)
async def handle_401(request: Request):
raise UnauthorizationException(
msg="未授权", data=None, err={"code": 401, "msg": "未授权"}
)
async def handle_403(request: Request):
raise ForbiddenException(
msg="禁止访问", data=None, err={"code": 403, "msg": "禁止访问"}
)
async def handle_dail_401(request: Request):
res = await client.dial("http://127.0.0.1:3000/api/v3/auth/401", "GET")
return res

5
handler/home.py Normal file
View File

@ -0,0 +1,5 @@
from starlette.requests import Request
from starlette.responses import HTMLResponse
async def handle_home(request: Request):
return HTMLResponse('<h1>Hello World</h1>')

22
handler/other.py Normal file
View File

@ -0,0 +1,22 @@
from starlette.requests import Request
from starlette.responses import JSONResponse
from pkg.dial import client
async def handle_login(request: Request):
response = await client.dial(
"http://127.0.0.1:3000/api/v3/auth/login",
"POST",
body={"username": "admin", "password": "123456"},
)
data = response.json()
print(f"[D] data = {data}")
return JSONResponse(data)
async def handle_auth_check(request: Request):
pass

62
main.py Normal file
View File

@ -0,0 +1,62 @@
from starlette.applications import Starlette
from starlette.responses import JSONResponse
from starlette.routing import Route
from starlette.requests import Request
from handler.error import (
handle_400,
handle_401,
handle_403,
handle_500,
handle_dail_401,
)
from handler.home import handle_home
from handler.other import handle_auth_check, handle_login
from pkg.dial import client
from pkg.dial.req import ModifiableRequest
from pkg.resp.exception import (
BadRequestException,
ForbiddenException,
NotFoundException,
ServerErrorException,
UnauthorizationException,
bad_request_exception_handler,
forbidden_exception_handler,
not_found_exception_handler,
server_error_exception_handler,
unauthorization_exception_handler,
)
def req_fn(request: ModifiableRequest):
request.rewrite_uri("/api/v3/auth/check")
return request.build()
app = Starlette(
debug=True,
routes=[
Route("/", handle_home),
Route("/400", handle_400),
Route("/500", handle_500),
Route("/401", handle_401),
Route("/403", handle_403),
Route("/api/login", handle_login, methods=["POST"]),
Route(
"/api/check", client.proxy("127.0.0.1:3000", req_fn, None), methods=["GET"]
),
Route("/api/dial/401", handle_dail_401, methods=["GET"]),
],
exception_handlers={
ServerErrorException: server_error_exception_handler,
UnauthorizationException: unauthorization_exception_handler,
ForbiddenException: forbidden_exception_handler,
BadRequestException: bad_request_exception_handler,
NotFoundException: not_found_exception_handler,
},
)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)

157
pkg/dial/LOAD_BALANCING.md Normal file
View File

@ -0,0 +1,157 @@
# 负载均衡功能
`ModifiableRequest` 类现在支持 round-robin 负载均衡功能,可以自动在多个后端服务器之间分发请求。
## 功能特性
- 🔄 **Round-Robin 轮询**: 按顺序在多个主机之间分发请求
- 🛡️ **线程安全**: 使用锁机制确保并发安全
- 🔧 **灵活配置**: 支持带协议和不带协议的主机配置
- 🎯 **智能路由**: 自动判断是否需要负载均衡
## 使用方法
### 1. 基本负载均衡
```python
from pkg.dial import client
from pkg.dial.req import ModifiableRequest
def request_handler(request: ModifiableRequest):
# 只指定路径,让负载均衡器自动选择主机
request.rewrite_uri("/api/users")
return request.build()
# 创建代理,支持多个后端服务器
proxy_handler = client.proxy(
hosts=["server1.com", "server2.com", "server3.com"],
req_fn=request_handler
)
```
### 2. 带协议的主机配置
```python
# 支持混合协议配置
hosts = [
"http://server1.com",
"https://server2.com",
"server3.com" # 自动添加 http:// 前缀
]
proxy_handler = client.proxy(hosts, request_handler)
```
### 3. 绕过负载均衡
```python
def request_handler_fixed(request: ModifiableRequest):
# 指定完整 URL绕过负载均衡
request.rewrite_uri("https://specific-server.com/api/users")
return request.build()
# 即使配置了多个主机,也会使用指定的 URL
proxy_handler = client.proxy(
hosts=["server1.com", "server2.com"],
req_fn=request_handler_fixed
)
```
## 工作原理
### URL 解析逻辑
1. **检查 URL 是否包含主机**:
- 如果 URL 包含主机(如 `https://example.com/api/users`),直接使用该 URL
- 如果 URL 只有路径(如 `/api/users`),从主机列表中选择一个
2. **Round-Robin 选择**:
- 使用线程安全的计数器
- 按顺序循环选择主机
- 自动处理主机数量变化
3. **URL 构建**:
- 自动添加协议前缀(如果缺失)
- 确保路径格式正确
- 合并查询参数
### 示例请求分发
假设配置了 3 个主机:`["server1.com", "server2.com", "server3.com"]`
```
请求 1: http://server1.com/api/users
请求 2: http://server2.com/api/users
请求 3: http://server3.com/api/users
请求 4: http://server1.com/api/users # 重新开始轮询
请求 5: http://server2.com/api/users
...
```
## 配置示例
### Starlette 应用中的使用
```python
from starlette.applications import Starlette
from starlette.routing import Route
from pkg.dial import client
def api_handler(request: ModifiableRequest):
request.rewrite_uri("/api/v1/data")
return request.build()
app = Starlette(routes=[
Route(
"/api/data",
client.proxy(
hosts=[
"http://backend1:8080",
"http://backend2:8080",
"http://backend3:8080"
],
req_fn=api_handler
)
)
])
```
### 健康检查集成
```python
def health_check_handler(request: ModifiableRequest):
request.rewrite_uri("/health")
return request.build()
# 健康检查路由
health_proxy = client.proxy(
hosts=["server1.com", "server2.com"],
req_fn=health_check_handler
)
```
## 注意事项
1. **主机格式**: 建议使用完整的主机名或 IP 地址
2. **协议处理**: 如果不指定协议,默认使用 `http://`
3. **并发安全**: 使用线程锁确保多线程环境下的安全性
4. **错误处理**: 如果主机列表为空,会抛出 `ValueError` 异常
5. **性能考虑**: 负载均衡器本身开销很小,适合高并发场景
## 故障排除
### 常见问题
1. **主机无法访问**: 确保所有配置的主机都是可访问的
2. **协议错误**: 检查主机配置中的协议是否正确
3. **路径问题**: 确保路径格式正确(以 `/` 开头)
### 调试技巧
```python
def debug_handler(request: ModifiableRequest):
# 添加调试信息
request.add_header("X-Debug", "true")
request.rewrite_uri("/api/debug")
return request.build()
```

168
pkg/dial/README.md Normal file
View File

@ -0,0 +1,168 @@
# Dial - HTTP 客户端和代理工具
`Dial` 是一个基于 `httpx` 的异步 HTTP 客户端类,提供了连接池管理、直接请求和代理功能。
## 功能特性
- 🔄 **连接池管理**: 自动管理 HTTP 连接池,提高性能
- 🌐 **直接请求**: 支持各种 HTTP 方法的直接请求
- 🔀 **代理功能**: 可以创建代理处理器,支持请求和响应重写
-**异步支持**: 完全异步实现,与 Starlette 框架完美集成
- 🛡️ **类型安全**: 完整的类型注解支持
## 安装依赖
确保已安装 `httpx` 依赖:
```bash
pip install httpx
```
或在 `pyproject.toml` 中添加:
```toml
dependencies = [
"httpx>=0.27.0",
]
```
## 基本使用
### 1. 直接请求
```python
from pkg.dial import client
# 简单 GET 请求
response = await client.dial("https://api.example.com/users")
print(response.status_code)
print(response.json())
# POST 请求带 JSON 数据
response = await client.dial(
url="https://api.example.com/users",
method="POST",
headers={"Content-Type": "application/json"},
body={"name": "John", "email": "john@example.com"}
)
```
### 2. 自定义客户端
```python
from pkg.dial import Dial
import httpx
# 创建自定义客户端
dial = Dial(
timeout=30.0,
limits=httpx.Limits(max_keepalive_connections=10, max_connections=50)
)
# 使用上下文管理器
async with dial as client:
response = await client.dial("https://api.example.com/data")
print(response.text)
```
### 3. 代理功能
```python
from pkg.dial import client
from starlette.requests import Request
from starlette.responses import JSONResponse
# 请求处理函数
def request_handler(request: Request):
return {
'url': f'https://api.example.com{request.url.path}',
'method': request.method,
'headers': dict(request.headers),
'params': dict(request.query_params)
}
# 响应处理函数
def response_handler(response):
return JSONResponse({
'status': response.status_code,
'data': response.json()
})
# 创建代理处理器
proxy_handler = client.proxy(
req_fn=request_handler,
res_fn=response_handler
)
# 在 Starlette 应用中使用
from starlette.applications import Starlette
from starlette.routing import Route
app = Starlette(routes=[
Route('/api/{path:path}', proxy_handler)
])
```
## API 参考
### Dial 类
#### 构造函数
```python
Dial(timeout: float = 30.0, limits: Optional[httpx.Limits] = None)
```
- `timeout`: 请求超时时间(秒)
- `limits`: HTTP 连接限制配置
#### 方法
##### `dial()`
```python
async def dial(
self,
url: str,
method: str = "GET",
headers: Optional[Dict[str, str]] = None,
body: Optional[Any] = None,
params: Optional[Dict[str, Any]] = None
) -> httpx.Response
```
直接发送 HTTP 请求。
##### `proxy()`
```python
def proxy(
self,
req_fn: Optional[Callable[[Request], Dict[str, Any]]] = None,
res_fn: Optional[Callable[[httpx.Response], Response]] = None
) -> Callable[[Request], Response]
```
创建代理处理器。
- `req_fn`: 请求处理函数,接收 `Request` 对象,返回请求配置字典
- `res_fn`: 响应处理函数,接收 `httpx.Response` 对象,返回 `Response` 对象
##### `close()`
```python
async def close()
```
关闭客户端连接。
## 示例
查看 `examples/dial_example.py` 文件获取完整的使用示例。
## 注意事项
1. **连接池管理**: 客户端会自动管理连接池,无需手动管理
2. **异步使用**: 所有方法都是异步的,需要在异步环境中使用
3. **资源清理**: 使用上下文管理器或手动调用 `close()` 方法清理资源
4. **类型安全**: 建议使用类型注解以获得更好的开发体验

3
pkg/dial/__init__.py Normal file
View File

@ -0,0 +1,3 @@
from pkg.dial.dial import Dial
client = Dial()

199
pkg/dial/dial.py Normal file
View File

@ -0,0 +1,199 @@
import typing
from typing import Optional, Dict, Any, Callable
import httpx
from starlette.requests import Request
from starlette.responses import Response
from starlette.responses import StreamingResponse
import json
from pkg.dial.req import ModifiableRequest
from pkg.resp.exception import (
BadRequestException,
ForbiddenException,
NotFoundException,
ServerErrorException,
UnauthorizationException,
)
class Dial:
def __init__(self, timeout: float = 30.0, limits: Optional[httpx.Limits] = None):
"""
初始化 Dial 客户端
Args:
timeout: 请求超时时间(秒)
limits: HTTP 连接限制
"""
self.timeout = timeout
self.limits = limits or httpx.Limits(
max_keepalive_connections=20, max_connections=100
)
self._client: Optional[httpx.AsyncClient] = None
async def _get_client(self) -> httpx.AsyncClient:
"""获取或创建 HTTP 客户端"""
if self._client is None:
self._client = httpx.AsyncClient(timeout=self.timeout, limits=self.limits)
return self._client
async def close(self):
"""关闭客户端连接"""
if self._client:
await self._client.aclose()
self._client = None
async def dial(
self,
url: str,
method: str = "GET",
headers: Optional[Dict[str, str]] = None,
body: Optional[Any] = None,
params: Optional[Dict[str, Any]] = None,
) -> httpx.Response:
"""
直接请求一个 URL
Args:
url: 目标 URL
method: HTTP 方法
headers: 请求头
body: 请求体
params: 查询参数
Returns:
httpx.Response: HTTP 响应对象
"""
client = await self._get_client()
# 处理请求体
if body is not None:
if isinstance(body, (dict, list)):
body = json.dumps(body)
if headers is None:
headers = {}
headers.setdefault("Content-Type", "application/json")
# 发送请求
response = await client.request(
method=method.upper(),
url=url,
headers=headers or {},
content=body,
params=params or {},
)
# 如果 response 的 status != 200
# 尝试将 body 转换为 json
# 获取 body.msg
msg = ""
status = response.status_code
err = ""
result = None
if response.status_code >= 300:
print(f"[D] response.status_code = {response.status_code}")
try:
data = response.json()
msg = data.get("msg", "")
err = data.get("err", "")
result = data.get("data", None)
status = data.get("status", status)
except Exception:
pass
if response.status_code == 400:
raise BadRequestException(msg=msg, status=status, err=err, data=result)
if response.status_code == 401:
raise UnauthorizationException(
msg=msg, status=status, err=err, data=result
)
if response.status_code == 403:
raise ForbiddenException(msg=msg, status=status, err=err, data=result)
if response.status_code == 404:
raise NotFoundException(msg=msg, status=status, err=err, data=result)
raise ServerErrorException(msg=msg, status=status, err=err, data=result)
return response
def proxy(
self,
host: str | list[str],
req_fn: Optional[Callable[[ModifiableRequest], Request]] = None,
res_fn: Optional[Callable[[httpx.Response], Response]] = None,
):
"""
创建代理处理器
Args:
host: 目标主机列表
req_fn: 请求处理函数,接收 ModifiableRequest 对象,返回修改后的 Request 对象
res_fn: 响应处理函数,接收 httpx.Response 对象,返回 Starlette Response
Returns:
代理处理器函数
"""
if isinstance(host, str):
host = [host]
if len(host) == 0:
raise ValueError("host is empty")
for h in host:
if h == "":
raise ValueError("host is empty")
async def proxy_handler(request: Request) -> Response:
# 获取请求配置
if req_fn:
mr = ModifiableRequest(host, request)
modified_request = req_fn(mr)
# 从修改后的请求中提取信息
target_url = str(modified_request.url)
method = modified_request.method
headers = dict(modified_request.headers)
body = await modified_request.body()
params = dict(modified_request.query_params)
else:
# 默认转发到相同路径
target_url = str(request.url)
method = request.method
headers = dict(request.headers)
body = await request.body()
params = dict(request.query_params)
# 移除可能导致问题的头部
headers.pop("host", None)
headers.pop("Host", None)
# 发送请求
response = await self.dial(
url=target_url, method=method, headers=headers, body=body, params=params
)
# 处理响应
if res_fn:
return res_fn(response)
else:
# 默认响应处理
return StreamingResponse(
response.aiter_bytes(),
status_code=response.status_code,
headers=dict(response.headers),
media_type=response.headers.get("content-type"),
)
return proxy_handler
async def __aenter__(self):
"""异步上下文管理器入口"""
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""异步上下文管理器出口"""
await self.close()
# 创建默认客户端实例
client = Dial()

267
pkg/dial/req.py Normal file
View File

@ -0,0 +1,267 @@
from starlette.requests import Request
from starlette.datastructures import Headers
from urllib.parse import urlparse, urlunparse
import copy
from typing import Optional, Dict, Any
import threading
class ModifiableRequest:
"""可修改的请求对象,由 Request 对象修改而来"""
# 类级别的 round-robin 计数器
_counter = 0
_counter_lock = threading.Lock()
def __init__(self, hosts: list[str], request: Request):
"""
初始化可修改请求对象
Args:
hosts: 主机列表,用于负载均衡
request: 原始 Starlette Request 对象
"""
self._original_request = request
# 复制原始请求的属性
self._method = request.method
self._hosts = hosts
self._url = request.url
self._headers = dict(request.headers)
self._body: Optional[bytes] = None # 延迟加载
self._query_params = dict(request.query_params)
self._path_params = dict(request.path_params)
# 标记是否已修改
self._modified = False
def _get_next_host(self) -> str:
"""
使用 round-robin 算法获取下一个主机
Returns:
下一个主机的 URL
"""
if not self._hosts:
raise ValueError("No hosts available for load balancing")
with self._counter_lock:
host = self._hosts[ModifiableRequest._counter % len(self._hosts)]
ModifiableRequest._counter += 1
return host
def _build_url_with_host(self, path: str) -> str:
"""
使用选定的主机构建完整的 URL
Args:
path: 请求路径
Returns:
完整的 URL
"""
host = self._get_next_host()
# 确保 host 有协议前缀
if not host.startswith(("http://", "https://")):
host = "http://" + host
# 确保路径以 / 开头
if not path.startswith("/"):
path = "/" + path
return host + path
def rewrite_uri(self, uri: str):
"""
重写请求 URI
Args:
uri: 新的 URI
"""
self._url = uri
self._modified = True
return self
def rewrite_method(self, method: str):
"""
重写请求方法
Args:
method: 新的 HTTP 方法
"""
self._method = method.upper()
self._modified = True
return self
def rewrite_headers(self, headers: Dict[str, str]):
"""
重写请求头
Args:
headers: 新的请求头字典
"""
self._headers.update(headers)
self._modified = True
return self
def rewrite_body(self, body: bytes):
"""
重写请求体
Args:
body: 新的请求体
"""
self._body = body
self._modified = True
return self
def add_header(self, key: str, value: str):
"""
添加单个请求头
Args:
key: 请求头名称
value: 请求头值
"""
self._headers[key] = value
self._modified = True
return self
def remove_header(self, key: str):
"""
移除请求头
Args:
key: 要移除的请求头名称
"""
if key in self._headers:
del self._headers[key]
self._modified = True
return self
def set_query_param(self, key: str, value: str):
"""
设置查询参数
Args:
key: 参数名
value: 参数值
"""
self._query_params[key] = value
self._modified = True
return self
def remove_query_param(self, key: str):
"""
移除查询参数
Args:
key: 要移除的参数名
"""
if key in self._query_params:
del self._query_params[key]
self._modified = True
return self
async def _get_body(self) -> bytes:
"""获取请求体,如果未修改则从原始请求获取"""
if self._body is not None:
return self._body
return await self._original_request.body()
def build(self) -> Request:
"""
构建新的 Request 对象
Returns:
修改后的 Request 对象
"""
if not self._modified:
# 如果没有修改,直接返回原始请求
return self._original_request
# 构建新的 URL
parsed_url = urlparse(str(self._url))
# 如果 URL 中没有主机(只有路径),则从 hosts 中选择一个进行负载均衡
if not parsed_url.netloc:
# 使用 round-robin 选择主机
new_url = self._build_url_with_host(parsed_url.path)
else:
# URL 已经包含主机,直接使用
new_url = str(self._url)
# 重新解析 URL 以处理查询参数
parsed_url = urlparse(new_url)
# 添加查询参数
query_params = []
for key, value in self._query_params.items():
query_params.append(f"{key}={value}")
if query_params:
query_string = "&".join(query_params)
# 如果 URL 已经有查询参数,则合并
if parsed_url.query:
query_string = f"{parsed_url.query}&{query_string}"
parsed_url = parsed_url._replace(query=query_string)
final_url = urlunparse(parsed_url)
# 创建新的 Request 对象
# 注意:这里我们需要创建一个模拟的 Request 对象
# 因为 Starlette 的 Request 对象是不可变的
# 我们返回一个包含修改后数据的字典,供代理处理器使用
# 为了保持兼容性,我们创建一个包含所有必要信息的字典
# 然后在代理处理器中使用这些信息
request_info = {
"method": self._method,
"url": final_url,
"headers": Headers(self._headers),
"query_params": self._query_params,
"path_params": self._path_params,
"body": self._body,
"original_request": self._original_request,
}
# 返回一个包装对象,模拟 Request 接口
return RequestWrapper(request_info) # type: ignore
class RequestWrapper:
"""Request 包装器,提供修改后的请求信息"""
def __init__(self, request_info: Dict[str, Any]):
self._info = request_info
self._original_request = request_info["original_request"]
@property
def method(self) -> str:
return self._info["method"]
@property
def url(self):
return self._info["url"]
@property
def headers(self):
return self._info["headers"]
@property
def query_params(self):
return self._info["query_params"]
@property
def path_params(self):
return self._info["path_params"]
async def body(self) -> bytes:
if self._info["body"] is not None:
return self._info["body"]
return await self._original_request.body()
# 代理其他必要的属性到原始请求
def __getattr__(self, name):
return getattr(self._original_request, name)

82
pkg/resp/exception.py Normal file
View File

@ -0,0 +1,82 @@
from starlette.requests import Request
from starlette.responses import JSONResponse
class BadRequestException(Exception):
def __init__(self, msg="", data=None, err=None, status=None):
self.status = status or 400
self.msg = msg or "参数错误"
self.data = data
self.err = err
super().__init__(self.msg)
class ServerErrorException(Exception):
def __init__(self, msg="", data=None, err=None, status=None):
self.status = status or 500
self.msg = msg or "服务器开小差了"
self.data = data
self.err = err
super().__init__(self.msg)
class UnauthorizationException(Exception):
def __init__(self, msg="", data=None, err=None, status=None):
self.status = status or 401
self.msg = msg or "登录信息不存在或已过期, 请重新登录"
self.data = data
self.err = err
super().__init__(self.msg)
class ForbiddenException(Exception):
def __init__(self, msg="", data=None, err=None, status=None):
self.status = status or 403
self.msg = msg or "权限不足"
self.data = data
self.err = err
super().__init__(self.msg)
class NotFoundException(Exception):
def __init__(self, msg="", data=None, err=None, status=None):
self.status = status or 404
self.msg = msg or "资源不存在"
self.data = data
self.err = err
super().__init__(self.msg)
def server_error_exception_handler(request: Request, exc):
return JSONResponse(
{"msg": exc.msg, "data": exc.data, "err": exc.err, "status": exc.status},
status_code=500,
)
def unauthorization_exception_handler(request: Request, exc):
return JSONResponse(
{"msg": exc.msg, "data": exc.data, "err": exc.err, "status": exc.status},
status_code=401,
)
def forbidden_exception_handler(request: Request, exc):
return JSONResponse(
{"msg": exc.msg, "data": exc.data, "err": exc.err, "status": exc.status},
status_code=403,
)
def bad_request_exception_handler(request: Request, exc):
return JSONResponse(
{"msg": exc.msg, "data": exc.data, "err": exc.err, "status": exc.status},
status_code=400,
)
def not_found_exception_handler(request: Request, exc):
return JSONResponse(
{"msg": exc.msg, "data": exc.data, "err": exc.err, "status": exc.status},
status_code=404,
)

11
pyproject.toml Normal file
View File

@ -0,0 +1,11 @@
[project]
name = "python-starlette"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.13"
dependencies = [
"starlette>=0.47.1",
"uvicorn>=0.34.3",
"httpx>=0.27.0",
]

57
test_dial.py Normal file
View File

@ -0,0 +1,57 @@
"""
测试 client.dial 功能
"""
import asyncio
from pkg.dial import client
async def test_dial():
"""测试 client.dial 的基本功能"""
print("=== 测试 client.dial ===")
try:
# 测试 GET 请求到 httpbin.org
print("测试 GET 请求...")
response = await client.dial("https://httpbin.org/get")
print(f"状态码: {response.status_code}")
print(f"响应头: {dict(response.headers)}")
data = response.json()
print(f"响应数据: {data.get('url', 'N/A')}")
print()
# 测试 POST 请求
print("测试 POST 请求...")
response = await client.dial(
url="https://httpbin.org/post",
method="POST",
headers={"Content-Type": "application/json"},
body={"test": "data", "message": "Hello, World!"},
)
print(f"状态码: {response.status_code}")
data = response.json()
print(f"响应数据: {data.get('json', 'N/A')}")
print()
# 测试带查询参数的请求
print("测试带查询参数的请求...")
response = await client.dial(
url="https://httpbin.org/get",
params={"param1": "value1", "param2": "value2"},
)
print(f"状态码: {response.status_code}")
data = response.json()
print(f"查询参数: {data.get('args', 'N/A')}")
print()
print("✅ 所有测试通过!")
except Exception as e:
print(f"❌ 测试失败: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
asyncio.run(test_dial())

143
uv.lock generated Normal file
View File

@ -0,0 +1,143 @@
version = 1
revision = 1
requires-python = ">=3.13"
[[package]]
name = "anyio"
version = "4.9.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "idna" },
{ name = "sniffio" },
]
sdist = { url = "https://files.pythonhosted.org/packages/95/7d/4c1bd541d4dffa1b52bd83fb8527089e097a106fc90b467a7313b105f840/anyio-4.9.0.tar.gz", hash = "sha256:673c0c244e15788651a4ff38710fea9675823028a6f08a5eda409e0c9840a028", size = 190949 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/a1/ee/48ca1a7c89ffec8b6a0c5d02b89c305671d5ffd8d3c94acf8b8c408575bb/anyio-4.9.0-py3-none-any.whl", hash = "sha256:9f76d541cad6e36af7beb62e978876f3b41e3e04f2c1fbf0884604c0a9c4d93c", size = 100916 },
]
[[package]]
name = "certifi"
version = "2025.6.15"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/73/f7/f14b46d4bcd21092d7d3ccef689615220d8a08fb25e564b65d20738e672e/certifi-2025.6.15.tar.gz", hash = "sha256:d747aa5a8b9bbbb1bb8c22bb13e22bd1f18e9796defa16bab421f7f7a317323b", size = 158753 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/84/ae/320161bd181fc06471eed047ecce67b693fd7515b16d495d8932db763426/certifi-2025.6.15-py3-none-any.whl", hash = "sha256:2e0c7ce7cb5d8f8634ca55d2ba7e6ec2689a2fd6537d8dec1296a477a4910057", size = 157650 },
]
[[package]]
name = "click"
version = "8.2.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "colorama", marker = "sys_platform == 'win32'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/60/6c/8ca2efa64cf75a977a0d7fac081354553ebe483345c734fb6b6515d96bbc/click-8.2.1.tar.gz", hash = "sha256:27c491cc05d968d271d5a1db13e3b5a184636d9d930f148c50b038f0d0646202", size = 286342 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/85/32/10bb5764d90a8eee674e9dc6f4db6a0ab47c8c4d0d83c27f7c39ac415a4d/click-8.2.1-py3-none-any.whl", hash = "sha256:61a3265b914e850b85317d0b3109c7f8cd35a670f963866005d6ef1d5175a12b", size = 102215 },
]
[[package]]
name = "colorama"
version = "0.4.6"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335 },
]
[[package]]
name = "h11"
version = "0.16.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/01/ee/02a2c011bdab74c6fb3c75474d40b3052059d95df7e73351460c8588d963/h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1", size = 101250 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515 },
]
[[package]]
name = "httpcore"
version = "1.0.9"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "certifi" },
{ name = "h11" },
]
sdist = { url = "https://files.pythonhosted.org/packages/06/94/82699a10bca87a5556c9c59b5963f2d039dbd239f25bc2a63907a05a14cb/httpcore-1.0.9.tar.gz", hash = "sha256:6e34463af53fd2ab5d807f399a9b45ea31c3dfa2276f15a2c3f00afff6e176e8", size = 85484 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/7e/f5/f66802a942d491edb555dd61e3a9961140fd64c90bce1eafd741609d334d/httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55", size = 78784 },
]
[[package]]
name = "httpx"
version = "0.28.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio" },
{ name = "certifi" },
{ name = "httpcore" },
{ name = "idna" },
]
sdist = { url = "https://files.pythonhosted.org/packages/b1/df/48c586a5fe32a0f01324ee087459e112ebb7224f646c0b5023f5e79e9956/httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc", size = 141406 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517 },
]
[[package]]
name = "idna"
version = "3.10"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/f1/70/7703c29685631f5a7590aa73f1f1d3fa9a380e654b86af429e0934a32f7d/idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9", size = 190490 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442 },
]
[[package]]
name = "python-starlette"
version = "0.1.0"
source = { virtual = "." }
dependencies = [
{ name = "httpx" },
{ name = "starlette" },
{ name = "uvicorn" },
]
[package.metadata]
requires-dist = [
{ name = "httpx", specifier = ">=0.27.0" },
{ name = "starlette", specifier = ">=0.47.1" },
{ name = "uvicorn", specifier = ">=0.34.3" },
]
[[package]]
name = "sniffio"
version = "1.3.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/a2/87/a6771e1546d97e7e041b6ae58d80074f81b7d5121207425c964ddf5cfdbd/sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc", size = 20372 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235 },
]
[[package]]
name = "starlette"
version = "0.47.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio" },
]
sdist = { url = "https://files.pythonhosted.org/packages/0a/69/662169fdb92fb96ec3eaee218cf540a629d629c86d7993d9651226a6789b/starlette-0.47.1.tar.gz", hash = "sha256:aef012dd2b6be325ffa16698f9dc533614fb1cebd593a906b90dc1025529a79b", size = 2583072 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/82/95/38ef0cd7fa11eaba6a99b3c4f5ac948d8bc6ff199aabd327a29cc000840c/starlette-0.47.1-py3-none-any.whl", hash = "sha256:5e11c9f5c7c3f24959edbf2dffdc01bba860228acf657129467d8a7468591527", size = 72747 },
]
[[package]]
name = "uvicorn"
version = "0.34.3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "click" },
{ name = "h11" },
]
sdist = { url = "https://files.pythonhosted.org/packages/de/ad/713be230bcda622eaa35c28f0d328c3675c371238470abdea52417f17a8e/uvicorn-0.34.3.tar.gz", hash = "sha256:35919a9a979d7a59334b6b10e05d77c1d0d574c50e0fc98b8b1a0f165708b55a", size = 76631 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/6d/0d/8adfeaa62945f90d19ddc461c55f4a50c258af7662d34b6a3d5d1f8646f6/uvicorn-0.34.3-py3-none-any.whl", hash = "sha256:16246631db62bdfbf069b0645177d6e8a77ba950cfedbfd093acef9444e4d885", size = 62431 },
]

104
xtest/main.go Normal file
View File

@ -0,0 +1,104 @@
package main
import (
"fmt"
"log"
"math/rand/v2"
"sync"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/logger"
)
func main() {
app := fiber.New()
app.Use(logger.New())
app.Post("/api/v3/auth/login", handleLogin)
app.Get("/api/v3/auth/check", handleAuthCheck)
app.Get("/api/v3/auth/401", func(c *fiber.Ctx) error {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
"err": "custom err 401",
})
})
log.Fatal(app.Listen(":3000"))
}
type User struct {
Username string `json:"username"`
Password string `json:"password"`
}
var (
store = &sync.Map{}
letters = []byte("0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
mockUsers = map[string]User{
"admin": {
Username: "admin",
Password: "123456",
},
"user": {
Username: "user",
Password: "12345678",
},
}
)
func randomString(n int) string {
b := make([]byte, n)
for i := range b {
b[i] = letters[rand.IntN(len(letters))]
}
return string(b)
}
func handleLogin(c *fiber.Ctx) error {
type Req struct {
Username string `json:"username"`
Password string `json:"password"`
}
var (
err error
req Req
)
if err = c.BodyParser(&req); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
"error": err.Error(),
})
}
user, ok := mockUsers[req.Username]
if !ok {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
"error": "user not found",
})
}
session := randomString(16)
store.Store(session, user)
c.Response().Header.Set("Set-Cookie", fmt.Sprintf("session=%s; Path=/; HttpOnly; SameSite=Lax", session))
c.Response().Header.Set("Server", "Fiber")
return c.Status(fiber.StatusOK).JSON(user)
}
func handleAuthCheck(c *fiber.Ctx) error {
session := c.Cookies("session")
if session == "" {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
"error": "unauthorized",
})
}
user, ok := store.Load(session)
if !ok {
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
"error": "unauthorized",
})
}
return c.Status(fiber.StatusOK).JSON(user)
}