🎉 init starlette dial
This commit is contained in:
10
.gitignore
vendored
Normal file
10
.gitignore
vendored
Normal file
@ -0,0 +1,10 @@
|
||||
# Python-generated files
|
||||
__pycache__/
|
||||
*.py[oc]
|
||||
build/
|
||||
dist/
|
||||
wheels/
|
||||
*.egg-info
|
||||
|
||||
# Virtual environments
|
||||
.venv
|
1
.python-version
Normal file
1
.python-version
Normal file
@ -0,0 +1 @@
|
||||
3.13
|
19
go.mod
Normal file
19
go.mod
Normal file
@ -0,0 +1,19 @@
|
||||
module xtest
|
||||
|
||||
go 1.24.2
|
||||
|
||||
require github.com/gofiber/fiber/v2 v2.52.8
|
||||
|
||||
require (
|
||||
github.com/andybalholm/brotli v1.1.0 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/klauspost/compress v1.17.9 // indirect
|
||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.16 // indirect
|
||||
github.com/rivo/uniseg v0.2.0 // indirect
|
||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||
github.com/valyala/fasthttp v1.51.0 // indirect
|
||||
github.com/valyala/tcplisten v1.0.0 // indirect
|
||||
golang.org/x/sys v0.28.0 // indirect
|
||||
)
|
27
go.sum
Normal file
27
go.sum
Normal file
@ -0,0 +1,27 @@
|
||||
github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M=
|
||||
github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY=
|
||||
github.com/gofiber/fiber/v2 v2.52.8 h1:xl4jJQ0BV5EJTA2aWiKw/VddRpHrKeZLF0QPUxqn0x4=
|
||||
github.com/gofiber/fiber/v2 v2.52.8/go.mod h1:YEcBbO/FB+5M1IZNBP9FO3J9281zgPAreiI1oqg8nDw=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
|
||||
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
|
||||
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
|
||||
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
|
||||
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
||||
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||
github.com/valyala/fasthttp v1.51.0 h1:8b30A5JlZ6C7AS81RsWjYMQmrZG6feChmgAolCl1SqA=
|
||||
github.com/valyala/fasthttp v1.51.0/go.mod h1:oI2XroL+lI7vdXyYoQk03bXBThfFl2cVdIA3Xl7cH8g=
|
||||
github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8=
|
||||
github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
|
||||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
38
handler/error.py
Normal file
38
handler/error.py
Normal file
@ -0,0 +1,38 @@
|
||||
from starlette.requests import Request
|
||||
|
||||
from pkg.dial import client
|
||||
from pkg.resp.exception import (
|
||||
BadRequestException,
|
||||
ForbiddenException,
|
||||
ServerErrorException,
|
||||
UnauthorizationException,
|
||||
)
|
||||
|
||||
|
||||
async def handle_400(request: Request):
|
||||
raise BadRequestException(
|
||||
msg="参数错误", data={"name": "张三"}, err={"code": 400, "msg": "参数错误"}
|
||||
)
|
||||
|
||||
|
||||
async def handle_500(request: Request):
|
||||
raise ServerErrorException(
|
||||
msg="服务器内部错误", data=None, err={"code": 500, "msg": "服务器内部错误"}
|
||||
)
|
||||
|
||||
|
||||
async def handle_401(request: Request):
|
||||
raise UnauthorizationException(
|
||||
msg="未授权", data=None, err={"code": 401, "msg": "未授权"}
|
||||
)
|
||||
|
||||
|
||||
async def handle_403(request: Request):
|
||||
raise ForbiddenException(
|
||||
msg="禁止访问", data=None, err={"code": 403, "msg": "禁止访问"}
|
||||
)
|
||||
|
||||
|
||||
async def handle_dail_401(request: Request):
|
||||
res = await client.dial("http://127.0.0.1:3000/api/v3/auth/401", "GET")
|
||||
return res
|
5
handler/home.py
Normal file
5
handler/home.py
Normal file
@ -0,0 +1,5 @@
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import HTMLResponse
|
||||
|
||||
async def handle_home(request: Request):
|
||||
return HTMLResponse('<h1>Hello World</h1>')
|
22
handler/other.py
Normal file
22
handler/other.py
Normal file
@ -0,0 +1,22 @@
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
from pkg.dial import client
|
||||
|
||||
|
||||
async def handle_login(request: Request):
|
||||
response = await client.dial(
|
||||
"http://127.0.0.1:3000/api/v3/auth/login",
|
||||
"POST",
|
||||
body={"username": "admin", "password": "123456"},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
|
||||
print(f"[D] data = {data}")
|
||||
|
||||
return JSONResponse(data)
|
||||
|
||||
|
||||
async def handle_auth_check(request: Request):
|
||||
pass
|
62
main.py
Normal file
62
main.py
Normal file
@ -0,0 +1,62 @@
|
||||
from starlette.applications import Starlette
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.routing import Route
|
||||
from starlette.requests import Request
|
||||
|
||||
from handler.error import (
|
||||
handle_400,
|
||||
handle_401,
|
||||
handle_403,
|
||||
handle_500,
|
||||
handle_dail_401,
|
||||
)
|
||||
from handler.home import handle_home
|
||||
from handler.other import handle_auth_check, handle_login
|
||||
from pkg.dial import client
|
||||
from pkg.dial.req import ModifiableRequest
|
||||
from pkg.resp.exception import (
|
||||
BadRequestException,
|
||||
ForbiddenException,
|
||||
NotFoundException,
|
||||
ServerErrorException,
|
||||
UnauthorizationException,
|
||||
bad_request_exception_handler,
|
||||
forbidden_exception_handler,
|
||||
not_found_exception_handler,
|
||||
server_error_exception_handler,
|
||||
unauthorization_exception_handler,
|
||||
)
|
||||
|
||||
|
||||
def req_fn(request: ModifiableRequest):
|
||||
request.rewrite_uri("/api/v3/auth/check")
|
||||
return request.build()
|
||||
|
||||
|
||||
app = Starlette(
|
||||
debug=True,
|
||||
routes=[
|
||||
Route("/", handle_home),
|
||||
Route("/400", handle_400),
|
||||
Route("/500", handle_500),
|
||||
Route("/401", handle_401),
|
||||
Route("/403", handle_403),
|
||||
Route("/api/login", handle_login, methods=["POST"]),
|
||||
Route(
|
||||
"/api/check", client.proxy("127.0.0.1:3000", req_fn, None), methods=["GET"]
|
||||
),
|
||||
Route("/api/dial/401", handle_dail_401, methods=["GET"]),
|
||||
],
|
||||
exception_handlers={
|
||||
ServerErrorException: server_error_exception_handler,
|
||||
UnauthorizationException: unauthorization_exception_handler,
|
||||
ForbiddenException: forbidden_exception_handler,
|
||||
BadRequestException: bad_request_exception_handler,
|
||||
NotFoundException: not_found_exception_handler,
|
||||
},
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
157
pkg/dial/LOAD_BALANCING.md
Normal file
157
pkg/dial/LOAD_BALANCING.md
Normal file
@ -0,0 +1,157 @@
|
||||
# 负载均衡功能
|
||||
|
||||
`ModifiableRequest` 类现在支持 round-robin 负载均衡功能,可以自动在多个后端服务器之间分发请求。
|
||||
|
||||
## 功能特性
|
||||
|
||||
- 🔄 **Round-Robin 轮询**: 按顺序在多个主机之间分发请求
|
||||
- 🛡️ **线程安全**: 使用锁机制确保并发安全
|
||||
- 🔧 **灵活配置**: 支持带协议和不带协议的主机配置
|
||||
- 🎯 **智能路由**: 自动判断是否需要负载均衡
|
||||
|
||||
## 使用方法
|
||||
|
||||
### 1. 基本负载均衡
|
||||
|
||||
```python
|
||||
from pkg.dial import client
|
||||
from pkg.dial.req import ModifiableRequest
|
||||
|
||||
def request_handler(request: ModifiableRequest):
|
||||
# 只指定路径,让负载均衡器自动选择主机
|
||||
request.rewrite_uri("/api/users")
|
||||
return request.build()
|
||||
|
||||
# 创建代理,支持多个后端服务器
|
||||
proxy_handler = client.proxy(
|
||||
hosts=["server1.com", "server2.com", "server3.com"],
|
||||
req_fn=request_handler
|
||||
)
|
||||
```
|
||||
|
||||
### 2. 带协议的主机配置
|
||||
|
||||
```python
|
||||
# 支持混合协议配置
|
||||
hosts = [
|
||||
"http://server1.com",
|
||||
"https://server2.com",
|
||||
"server3.com" # 自动添加 http:// 前缀
|
||||
]
|
||||
|
||||
proxy_handler = client.proxy(hosts, request_handler)
|
||||
```
|
||||
|
||||
### 3. 绕过负载均衡
|
||||
|
||||
```python
|
||||
def request_handler_fixed(request: ModifiableRequest):
|
||||
# 指定完整 URL,绕过负载均衡
|
||||
request.rewrite_uri("https://specific-server.com/api/users")
|
||||
return request.build()
|
||||
|
||||
# 即使配置了多个主机,也会使用指定的 URL
|
||||
proxy_handler = client.proxy(
|
||||
hosts=["server1.com", "server2.com"],
|
||||
req_fn=request_handler_fixed
|
||||
)
|
||||
```
|
||||
|
||||
## 工作原理
|
||||
|
||||
### URL 解析逻辑
|
||||
|
||||
1. **检查 URL 是否包含主机**:
|
||||
- 如果 URL 包含主机(如 `https://example.com/api/users`),直接使用该 URL
|
||||
- 如果 URL 只有路径(如 `/api/users`),从主机列表中选择一个
|
||||
|
||||
2. **Round-Robin 选择**:
|
||||
- 使用线程安全的计数器
|
||||
- 按顺序循环选择主机
|
||||
- 自动处理主机数量变化
|
||||
|
||||
3. **URL 构建**:
|
||||
- 自动添加协议前缀(如果缺失)
|
||||
- 确保路径格式正确
|
||||
- 合并查询参数
|
||||
|
||||
### 示例请求分发
|
||||
|
||||
假设配置了 3 个主机:`["server1.com", "server2.com", "server3.com"]`
|
||||
|
||||
```
|
||||
请求 1: http://server1.com/api/users
|
||||
请求 2: http://server2.com/api/users
|
||||
请求 3: http://server3.com/api/users
|
||||
请求 4: http://server1.com/api/users # 重新开始轮询
|
||||
请求 5: http://server2.com/api/users
|
||||
...
|
||||
```
|
||||
|
||||
## 配置示例
|
||||
|
||||
### Starlette 应用中的使用
|
||||
|
||||
```python
|
||||
from starlette.applications import Starlette
|
||||
from starlette.routing import Route
|
||||
from pkg.dial import client
|
||||
|
||||
def api_handler(request: ModifiableRequest):
|
||||
request.rewrite_uri("/api/v1/data")
|
||||
return request.build()
|
||||
|
||||
app = Starlette(routes=[
|
||||
Route(
|
||||
"/api/data",
|
||||
client.proxy(
|
||||
hosts=[
|
||||
"http://backend1:8080",
|
||||
"http://backend2:8080",
|
||||
"http://backend3:8080"
|
||||
],
|
||||
req_fn=api_handler
|
||||
)
|
||||
)
|
||||
])
|
||||
```
|
||||
|
||||
### 健康检查集成
|
||||
|
||||
```python
|
||||
def health_check_handler(request: ModifiableRequest):
|
||||
request.rewrite_uri("/health")
|
||||
return request.build()
|
||||
|
||||
# 健康检查路由
|
||||
health_proxy = client.proxy(
|
||||
hosts=["server1.com", "server2.com"],
|
||||
req_fn=health_check_handler
|
||||
)
|
||||
```
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. **主机格式**: 建议使用完整的主机名或 IP 地址
|
||||
2. **协议处理**: 如果不指定协议,默认使用 `http://`
|
||||
3. **并发安全**: 使用线程锁确保多线程环境下的安全性
|
||||
4. **错误处理**: 如果主机列表为空,会抛出 `ValueError` 异常
|
||||
5. **性能考虑**: 负载均衡器本身开销很小,适合高并发场景
|
||||
|
||||
## 故障排除
|
||||
|
||||
### 常见问题
|
||||
|
||||
1. **主机无法访问**: 确保所有配置的主机都是可访问的
|
||||
2. **协议错误**: 检查主机配置中的协议是否正确
|
||||
3. **路径问题**: 确保路径格式正确(以 `/` 开头)
|
||||
|
||||
### 调试技巧
|
||||
|
||||
```python
|
||||
def debug_handler(request: ModifiableRequest):
|
||||
# 添加调试信息
|
||||
request.add_header("X-Debug", "true")
|
||||
request.rewrite_uri("/api/debug")
|
||||
return request.build()
|
||||
```
|
168
pkg/dial/README.md
Normal file
168
pkg/dial/README.md
Normal file
@ -0,0 +1,168 @@
|
||||
# Dial - HTTP 客户端和代理工具
|
||||
|
||||
`Dial` 是一个基于 `httpx` 的异步 HTTP 客户端类,提供了连接池管理、直接请求和代理功能。
|
||||
|
||||
## 功能特性
|
||||
|
||||
- 🔄 **连接池管理**: 自动管理 HTTP 连接池,提高性能
|
||||
- 🌐 **直接请求**: 支持各种 HTTP 方法的直接请求
|
||||
- 🔀 **代理功能**: 可以创建代理处理器,支持请求和响应重写
|
||||
- ⚡ **异步支持**: 完全异步实现,与 Starlette 框架完美集成
|
||||
- 🛡️ **类型安全**: 完整的类型注解支持
|
||||
|
||||
## 安装依赖
|
||||
|
||||
确保已安装 `httpx` 依赖:
|
||||
|
||||
```bash
|
||||
pip install httpx
|
||||
```
|
||||
|
||||
或在 `pyproject.toml` 中添加:
|
||||
|
||||
```toml
|
||||
dependencies = [
|
||||
"httpx>=0.27.0",
|
||||
]
|
||||
```
|
||||
|
||||
## 基本使用
|
||||
|
||||
### 1. 直接请求
|
||||
|
||||
```python
|
||||
from pkg.dial import client
|
||||
|
||||
# 简单 GET 请求
|
||||
response = await client.dial("https://api.example.com/users")
|
||||
print(response.status_code)
|
||||
print(response.json())
|
||||
|
||||
# POST 请求带 JSON 数据
|
||||
response = await client.dial(
|
||||
url="https://api.example.com/users",
|
||||
method="POST",
|
||||
headers={"Content-Type": "application/json"},
|
||||
body={"name": "John", "email": "john@example.com"}
|
||||
)
|
||||
```
|
||||
|
||||
### 2. 自定义客户端
|
||||
|
||||
```python
|
||||
from pkg.dial import Dial
|
||||
import httpx
|
||||
|
||||
# 创建自定义客户端
|
||||
dial = Dial(
|
||||
timeout=30.0,
|
||||
limits=httpx.Limits(max_keepalive_connections=10, max_connections=50)
|
||||
)
|
||||
|
||||
# 使用上下文管理器
|
||||
async with dial as client:
|
||||
response = await client.dial("https://api.example.com/data")
|
||||
print(response.text)
|
||||
```
|
||||
|
||||
### 3. 代理功能
|
||||
|
||||
```python
|
||||
from pkg.dial import client
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
# 请求处理函数
|
||||
def request_handler(request: Request):
|
||||
return {
|
||||
'url': f'https://api.example.com{request.url.path}',
|
||||
'method': request.method,
|
||||
'headers': dict(request.headers),
|
||||
'params': dict(request.query_params)
|
||||
}
|
||||
|
||||
# 响应处理函数
|
||||
def response_handler(response):
|
||||
return JSONResponse({
|
||||
'status': response.status_code,
|
||||
'data': response.json()
|
||||
})
|
||||
|
||||
# 创建代理处理器
|
||||
proxy_handler = client.proxy(
|
||||
req_fn=request_handler,
|
||||
res_fn=response_handler
|
||||
)
|
||||
|
||||
# 在 Starlette 应用中使用
|
||||
from starlette.applications import Starlette
|
||||
from starlette.routing import Route
|
||||
|
||||
app = Starlette(routes=[
|
||||
Route('/api/{path:path}', proxy_handler)
|
||||
])
|
||||
```
|
||||
|
||||
## API 参考
|
||||
|
||||
### Dial 类
|
||||
|
||||
#### 构造函数
|
||||
|
||||
```python
|
||||
Dial(timeout: float = 30.0, limits: Optional[httpx.Limits] = None)
|
||||
```
|
||||
|
||||
- `timeout`: 请求超时时间(秒)
|
||||
- `limits`: HTTP 连接限制配置
|
||||
|
||||
#### 方法
|
||||
|
||||
##### `dial()`
|
||||
|
||||
```python
|
||||
async def dial(
|
||||
self,
|
||||
url: str,
|
||||
method: str = "GET",
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
body: Optional[Any] = None,
|
||||
params: Optional[Dict[str, Any]] = None
|
||||
) -> httpx.Response
|
||||
```
|
||||
|
||||
直接发送 HTTP 请求。
|
||||
|
||||
##### `proxy()`
|
||||
|
||||
```python
|
||||
def proxy(
|
||||
self,
|
||||
req_fn: Optional[Callable[[Request], Dict[str, Any]]] = None,
|
||||
res_fn: Optional[Callable[[httpx.Response], Response]] = None
|
||||
) -> Callable[[Request], Response]
|
||||
```
|
||||
|
||||
创建代理处理器。
|
||||
|
||||
- `req_fn`: 请求处理函数,接收 `Request` 对象,返回请求配置字典
|
||||
- `res_fn`: 响应处理函数,接收 `httpx.Response` 对象,返回 `Response` 对象
|
||||
|
||||
##### `close()`
|
||||
|
||||
```python
|
||||
async def close()
|
||||
```
|
||||
|
||||
关闭客户端连接。
|
||||
|
||||
## 示例
|
||||
|
||||
查看 `examples/dial_example.py` 文件获取完整的使用示例。
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. **连接池管理**: 客户端会自动管理连接池,无需手动管理
|
||||
2. **异步使用**: 所有方法都是异步的,需要在异步环境中使用
|
||||
3. **资源清理**: 使用上下文管理器或手动调用 `close()` 方法清理资源
|
||||
4. **类型安全**: 建议使用类型注解以获得更好的开发体验
|
3
pkg/dial/__init__.py
Normal file
3
pkg/dial/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from pkg.dial.dial import Dial
|
||||
|
||||
client = Dial()
|
199
pkg/dial/dial.py
Normal file
199
pkg/dial/dial.py
Normal file
@ -0,0 +1,199 @@
|
||||
import typing
|
||||
from typing import Optional, Dict, Any, Callable
|
||||
import httpx
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from starlette.responses import StreamingResponse
|
||||
import json
|
||||
|
||||
from pkg.dial.req import ModifiableRequest
|
||||
from pkg.resp.exception import (
|
||||
BadRequestException,
|
||||
ForbiddenException,
|
||||
NotFoundException,
|
||||
ServerErrorException,
|
||||
UnauthorizationException,
|
||||
)
|
||||
|
||||
|
||||
class Dial:
|
||||
def __init__(self, timeout: float = 30.0, limits: Optional[httpx.Limits] = None):
|
||||
"""
|
||||
初始化 Dial 客户端
|
||||
|
||||
Args:
|
||||
timeout: 请求超时时间(秒)
|
||||
limits: HTTP 连接限制
|
||||
"""
|
||||
self.timeout = timeout
|
||||
self.limits = limits or httpx.Limits(
|
||||
max_keepalive_connections=20, max_connections=100
|
||||
)
|
||||
self._client: Optional[httpx.AsyncClient] = None
|
||||
|
||||
async def _get_client(self) -> httpx.AsyncClient:
|
||||
"""获取或创建 HTTP 客户端"""
|
||||
if self._client is None:
|
||||
self._client = httpx.AsyncClient(timeout=self.timeout, limits=self.limits)
|
||||
return self._client
|
||||
|
||||
async def close(self):
|
||||
"""关闭客户端连接"""
|
||||
if self._client:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
|
||||
async def dial(
|
||||
self,
|
||||
url: str,
|
||||
method: str = "GET",
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
body: Optional[Any] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> httpx.Response:
|
||||
"""
|
||||
直接请求一个 URL
|
||||
|
||||
Args:
|
||||
url: 目标 URL
|
||||
method: HTTP 方法
|
||||
headers: 请求头
|
||||
body: 请求体
|
||||
params: 查询参数
|
||||
|
||||
Returns:
|
||||
httpx.Response: HTTP 响应对象
|
||||
"""
|
||||
client = await self._get_client()
|
||||
|
||||
# 处理请求体
|
||||
if body is not None:
|
||||
if isinstance(body, (dict, list)):
|
||||
body = json.dumps(body)
|
||||
if headers is None:
|
||||
headers = {}
|
||||
headers.setdefault("Content-Type", "application/json")
|
||||
|
||||
# 发送请求
|
||||
response = await client.request(
|
||||
method=method.upper(),
|
||||
url=url,
|
||||
headers=headers or {},
|
||||
content=body,
|
||||
params=params or {},
|
||||
)
|
||||
|
||||
# 如果 response 的 status != 200
|
||||
# 尝试将 body 转换为 json
|
||||
# 获取 body.msg
|
||||
msg = ""
|
||||
status = response.status_code
|
||||
err = ""
|
||||
result = None
|
||||
|
||||
if response.status_code >= 300:
|
||||
print(f"[D] response.status_code = {response.status_code}")
|
||||
try:
|
||||
data = response.json()
|
||||
msg = data.get("msg", "")
|
||||
err = data.get("err", "")
|
||||
result = data.get("data", None)
|
||||
status = data.get("status", status)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if response.status_code == 400:
|
||||
raise BadRequestException(msg=msg, status=status, err=err, data=result)
|
||||
if response.status_code == 401:
|
||||
raise UnauthorizationException(
|
||||
msg=msg, status=status, err=err, data=result
|
||||
)
|
||||
if response.status_code == 403:
|
||||
raise ForbiddenException(msg=msg, status=status, err=err, data=result)
|
||||
if response.status_code == 404:
|
||||
raise NotFoundException(msg=msg, status=status, err=err, data=result)
|
||||
raise ServerErrorException(msg=msg, status=status, err=err, data=result)
|
||||
|
||||
return response
|
||||
|
||||
def proxy(
|
||||
self,
|
||||
host: str | list[str],
|
||||
req_fn: Optional[Callable[[ModifiableRequest], Request]] = None,
|
||||
res_fn: Optional[Callable[[httpx.Response], Response]] = None,
|
||||
):
|
||||
"""
|
||||
创建代理处理器
|
||||
|
||||
Args:
|
||||
host: 目标主机列表
|
||||
req_fn: 请求处理函数,接收 ModifiableRequest 对象,返回修改后的 Request 对象
|
||||
res_fn: 响应处理函数,接收 httpx.Response 对象,返回 Starlette Response
|
||||
|
||||
Returns:
|
||||
代理处理器函数
|
||||
"""
|
||||
|
||||
if isinstance(host, str):
|
||||
host = [host]
|
||||
|
||||
if len(host) == 0:
|
||||
raise ValueError("host is empty")
|
||||
for h in host:
|
||||
if h == "":
|
||||
raise ValueError("host is empty")
|
||||
|
||||
async def proxy_handler(request: Request) -> Response:
|
||||
# 获取请求配置
|
||||
if req_fn:
|
||||
mr = ModifiableRequest(host, request)
|
||||
modified_request = req_fn(mr)
|
||||
|
||||
# 从修改后的请求中提取信息
|
||||
target_url = str(modified_request.url)
|
||||
method = modified_request.method
|
||||
headers = dict(modified_request.headers)
|
||||
body = await modified_request.body()
|
||||
params = dict(modified_request.query_params)
|
||||
else:
|
||||
# 默认转发到相同路径
|
||||
target_url = str(request.url)
|
||||
method = request.method
|
||||
headers = dict(request.headers)
|
||||
body = await request.body()
|
||||
params = dict(request.query_params)
|
||||
|
||||
# 移除可能导致问题的头部
|
||||
headers.pop("host", None)
|
||||
headers.pop("Host", None)
|
||||
|
||||
# 发送请求
|
||||
response = await self.dial(
|
||||
url=target_url, method=method, headers=headers, body=body, params=params
|
||||
)
|
||||
|
||||
# 处理响应
|
||||
if res_fn:
|
||||
return res_fn(response)
|
||||
else:
|
||||
# 默认响应处理
|
||||
return StreamingResponse(
|
||||
response.aiter_bytes(),
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers),
|
||||
media_type=response.headers.get("content-type"),
|
||||
)
|
||||
|
||||
return proxy_handler
|
||||
|
||||
async def __aenter__(self):
|
||||
"""异步上下文管理器入口"""
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""异步上下文管理器出口"""
|
||||
await self.close()
|
||||
|
||||
|
||||
# 创建默认客户端实例
|
||||
client = Dial()
|
267
pkg/dial/req.py
Normal file
267
pkg/dial/req.py
Normal file
@ -0,0 +1,267 @@
|
||||
from starlette.requests import Request
|
||||
from starlette.datastructures import Headers
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
import copy
|
||||
from typing import Optional, Dict, Any
|
||||
import threading
|
||||
|
||||
|
||||
class ModifiableRequest:
|
||||
"""可修改的请求对象,由 Request 对象修改而来"""
|
||||
|
||||
# 类级别的 round-robin 计数器
|
||||
_counter = 0
|
||||
_counter_lock = threading.Lock()
|
||||
|
||||
def __init__(self, hosts: list[str], request: Request):
|
||||
"""
|
||||
初始化可修改请求对象
|
||||
|
||||
Args:
|
||||
hosts: 主机列表,用于负载均衡
|
||||
request: 原始 Starlette Request 对象
|
||||
"""
|
||||
self._original_request = request
|
||||
|
||||
# 复制原始请求的属性
|
||||
self._method = request.method
|
||||
self._hosts = hosts
|
||||
self._url = request.url
|
||||
self._headers = dict(request.headers)
|
||||
self._body: Optional[bytes] = None # 延迟加载
|
||||
self._query_params = dict(request.query_params)
|
||||
self._path_params = dict(request.path_params)
|
||||
|
||||
# 标记是否已修改
|
||||
self._modified = False
|
||||
|
||||
def _get_next_host(self) -> str:
|
||||
"""
|
||||
使用 round-robin 算法获取下一个主机
|
||||
|
||||
Returns:
|
||||
下一个主机的 URL
|
||||
"""
|
||||
if not self._hosts:
|
||||
raise ValueError("No hosts available for load balancing")
|
||||
|
||||
with self._counter_lock:
|
||||
host = self._hosts[ModifiableRequest._counter % len(self._hosts)]
|
||||
ModifiableRequest._counter += 1
|
||||
return host
|
||||
|
||||
def _build_url_with_host(self, path: str) -> str:
|
||||
"""
|
||||
使用选定的主机构建完整的 URL
|
||||
|
||||
Args:
|
||||
path: 请求路径
|
||||
|
||||
Returns:
|
||||
完整的 URL
|
||||
"""
|
||||
host = self._get_next_host()
|
||||
|
||||
# 确保 host 有协议前缀
|
||||
if not host.startswith(("http://", "https://")):
|
||||
host = "http://" + host
|
||||
|
||||
# 确保路径以 / 开头
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
|
||||
return host + path
|
||||
|
||||
def rewrite_uri(self, uri: str):
|
||||
"""
|
||||
重写请求 URI
|
||||
|
||||
Args:
|
||||
uri: 新的 URI
|
||||
"""
|
||||
self._url = uri
|
||||
self._modified = True
|
||||
return self
|
||||
|
||||
def rewrite_method(self, method: str):
|
||||
"""
|
||||
重写请求方法
|
||||
|
||||
Args:
|
||||
method: 新的 HTTP 方法
|
||||
"""
|
||||
self._method = method.upper()
|
||||
self._modified = True
|
||||
return self
|
||||
|
||||
def rewrite_headers(self, headers: Dict[str, str]):
|
||||
"""
|
||||
重写请求头
|
||||
|
||||
Args:
|
||||
headers: 新的请求头字典
|
||||
"""
|
||||
self._headers.update(headers)
|
||||
self._modified = True
|
||||
return self
|
||||
|
||||
def rewrite_body(self, body: bytes):
|
||||
"""
|
||||
重写请求体
|
||||
|
||||
Args:
|
||||
body: 新的请求体
|
||||
"""
|
||||
self._body = body
|
||||
self._modified = True
|
||||
return self
|
||||
|
||||
def add_header(self, key: str, value: str):
|
||||
"""
|
||||
添加单个请求头
|
||||
|
||||
Args:
|
||||
key: 请求头名称
|
||||
value: 请求头值
|
||||
"""
|
||||
self._headers[key] = value
|
||||
self._modified = True
|
||||
return self
|
||||
|
||||
def remove_header(self, key: str):
|
||||
"""
|
||||
移除请求头
|
||||
|
||||
Args:
|
||||
key: 要移除的请求头名称
|
||||
"""
|
||||
if key in self._headers:
|
||||
del self._headers[key]
|
||||
self._modified = True
|
||||
return self
|
||||
|
||||
def set_query_param(self, key: str, value: str):
|
||||
"""
|
||||
设置查询参数
|
||||
|
||||
Args:
|
||||
key: 参数名
|
||||
value: 参数值
|
||||
"""
|
||||
self._query_params[key] = value
|
||||
self._modified = True
|
||||
return self
|
||||
|
||||
def remove_query_param(self, key: str):
|
||||
"""
|
||||
移除查询参数
|
||||
|
||||
Args:
|
||||
key: 要移除的参数名
|
||||
"""
|
||||
if key in self._query_params:
|
||||
del self._query_params[key]
|
||||
self._modified = True
|
||||
return self
|
||||
|
||||
async def _get_body(self) -> bytes:
|
||||
"""获取请求体,如果未修改则从原始请求获取"""
|
||||
if self._body is not None:
|
||||
return self._body
|
||||
return await self._original_request.body()
|
||||
|
||||
def build(self) -> Request:
|
||||
"""
|
||||
构建新的 Request 对象
|
||||
|
||||
Returns:
|
||||
修改后的 Request 对象
|
||||
"""
|
||||
if not self._modified:
|
||||
# 如果没有修改,直接返回原始请求
|
||||
return self._original_request
|
||||
|
||||
# 构建新的 URL
|
||||
parsed_url = urlparse(str(self._url))
|
||||
|
||||
# 如果 URL 中没有主机(只有路径),则从 hosts 中选择一个进行负载均衡
|
||||
if not parsed_url.netloc:
|
||||
# 使用 round-robin 选择主机
|
||||
new_url = self._build_url_with_host(parsed_url.path)
|
||||
else:
|
||||
# URL 已经包含主机,直接使用
|
||||
new_url = str(self._url)
|
||||
|
||||
# 重新解析 URL 以处理查询参数
|
||||
parsed_url = urlparse(new_url)
|
||||
|
||||
# 添加查询参数
|
||||
query_params = []
|
||||
for key, value in self._query_params.items():
|
||||
query_params.append(f"{key}={value}")
|
||||
|
||||
if query_params:
|
||||
query_string = "&".join(query_params)
|
||||
# 如果 URL 已经有查询参数,则合并
|
||||
if parsed_url.query:
|
||||
query_string = f"{parsed_url.query}&{query_string}"
|
||||
parsed_url = parsed_url._replace(query=query_string)
|
||||
|
||||
final_url = urlunparse(parsed_url)
|
||||
|
||||
# 创建新的 Request 对象
|
||||
# 注意:这里我们需要创建一个模拟的 Request 对象
|
||||
# 因为 Starlette 的 Request 对象是不可变的
|
||||
# 我们返回一个包含修改后数据的字典,供代理处理器使用
|
||||
|
||||
# 为了保持兼容性,我们创建一个包含所有必要信息的字典
|
||||
# 然后在代理处理器中使用这些信息
|
||||
request_info = {
|
||||
"method": self._method,
|
||||
"url": final_url,
|
||||
"headers": Headers(self._headers),
|
||||
"query_params": self._query_params,
|
||||
"path_params": self._path_params,
|
||||
"body": self._body,
|
||||
"original_request": self._original_request,
|
||||
}
|
||||
|
||||
# 返回一个包装对象,模拟 Request 接口
|
||||
return RequestWrapper(request_info) # type: ignore
|
||||
|
||||
|
||||
class RequestWrapper:
|
||||
"""Request 包装器,提供修改后的请求信息"""
|
||||
|
||||
def __init__(self, request_info: Dict[str, Any]):
|
||||
self._info = request_info
|
||||
self._original_request = request_info["original_request"]
|
||||
|
||||
@property
|
||||
def method(self) -> str:
|
||||
return self._info["method"]
|
||||
|
||||
@property
|
||||
def url(self):
|
||||
return self._info["url"]
|
||||
|
||||
@property
|
||||
def headers(self):
|
||||
return self._info["headers"]
|
||||
|
||||
@property
|
||||
def query_params(self):
|
||||
return self._info["query_params"]
|
||||
|
||||
@property
|
||||
def path_params(self):
|
||||
return self._info["path_params"]
|
||||
|
||||
async def body(self) -> bytes:
|
||||
if self._info["body"] is not None:
|
||||
return self._info["body"]
|
||||
return await self._original_request.body()
|
||||
|
||||
# 代理其他必要的属性到原始请求
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._original_request, name)
|
82
pkg/resp/exception.py
Normal file
82
pkg/resp/exception.py
Normal file
@ -0,0 +1,82 @@
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
|
||||
class BadRequestException(Exception):
|
||||
def __init__(self, msg="", data=None, err=None, status=None):
|
||||
self.status = status or 400
|
||||
self.msg = msg or "参数错误"
|
||||
self.data = data
|
||||
self.err = err
|
||||
super().__init__(self.msg)
|
||||
|
||||
|
||||
class ServerErrorException(Exception):
|
||||
def __init__(self, msg="", data=None, err=None, status=None):
|
||||
self.status = status or 500
|
||||
self.msg = msg or "服务器开小差了"
|
||||
self.data = data
|
||||
self.err = err
|
||||
super().__init__(self.msg)
|
||||
|
||||
|
||||
class UnauthorizationException(Exception):
|
||||
def __init__(self, msg="", data=None, err=None, status=None):
|
||||
self.status = status or 401
|
||||
self.msg = msg or "登录信息不存在或已过期, 请重新登录"
|
||||
self.data = data
|
||||
self.err = err
|
||||
super().__init__(self.msg)
|
||||
|
||||
|
||||
class ForbiddenException(Exception):
|
||||
def __init__(self, msg="", data=None, err=None, status=None):
|
||||
self.status = status or 403
|
||||
self.msg = msg or "权限不足"
|
||||
self.data = data
|
||||
self.err = err
|
||||
super().__init__(self.msg)
|
||||
|
||||
|
||||
class NotFoundException(Exception):
|
||||
def __init__(self, msg="", data=None, err=None, status=None):
|
||||
self.status = status or 404
|
||||
self.msg = msg or "资源不存在"
|
||||
self.data = data
|
||||
self.err = err
|
||||
super().__init__(self.msg)
|
||||
|
||||
|
||||
def server_error_exception_handler(request: Request, exc):
|
||||
return JSONResponse(
|
||||
{"msg": exc.msg, "data": exc.data, "err": exc.err, "status": exc.status},
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
|
||||
def unauthorization_exception_handler(request: Request, exc):
|
||||
return JSONResponse(
|
||||
{"msg": exc.msg, "data": exc.data, "err": exc.err, "status": exc.status},
|
||||
status_code=401,
|
||||
)
|
||||
|
||||
|
||||
def forbidden_exception_handler(request: Request, exc):
|
||||
return JSONResponse(
|
||||
{"msg": exc.msg, "data": exc.data, "err": exc.err, "status": exc.status},
|
||||
status_code=403,
|
||||
)
|
||||
|
||||
|
||||
def bad_request_exception_handler(request: Request, exc):
|
||||
return JSONResponse(
|
||||
{"msg": exc.msg, "data": exc.data, "err": exc.err, "status": exc.status},
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
|
||||
def not_found_exception_handler(request: Request, exc):
|
||||
return JSONResponse(
|
||||
{"msg": exc.msg, "data": exc.data, "err": exc.err, "status": exc.status},
|
||||
status_code=404,
|
||||
)
|
11
pyproject.toml
Normal file
11
pyproject.toml
Normal file
@ -0,0 +1,11 @@
|
||||
[project]
|
||||
name = "python-starlette"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.13"
|
||||
dependencies = [
|
||||
"starlette>=0.47.1",
|
||||
"uvicorn>=0.34.3",
|
||||
"httpx>=0.27.0",
|
||||
]
|
57
test_dial.py
Normal file
57
test_dial.py
Normal file
@ -0,0 +1,57 @@
|
||||
"""
|
||||
测试 client.dial 功能
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from pkg.dial import client
|
||||
|
||||
|
||||
async def test_dial():
|
||||
"""测试 client.dial 的基本功能"""
|
||||
print("=== 测试 client.dial ===")
|
||||
|
||||
try:
|
||||
# 测试 GET 请求到 httpbin.org
|
||||
print("测试 GET 请求...")
|
||||
response = await client.dial("https://httpbin.org/get")
|
||||
print(f"状态码: {response.status_code}")
|
||||
print(f"响应头: {dict(response.headers)}")
|
||||
data = response.json()
|
||||
print(f"响应数据: {data.get('url', 'N/A')}")
|
||||
print()
|
||||
|
||||
# 测试 POST 请求
|
||||
print("测试 POST 请求...")
|
||||
response = await client.dial(
|
||||
url="https://httpbin.org/post",
|
||||
method="POST",
|
||||
headers={"Content-Type": "application/json"},
|
||||
body={"test": "data", "message": "Hello, World!"},
|
||||
)
|
||||
print(f"状态码: {response.status_code}")
|
||||
data = response.json()
|
||||
print(f"响应数据: {data.get('json', 'N/A')}")
|
||||
print()
|
||||
|
||||
# 测试带查询参数的请求
|
||||
print("测试带查询参数的请求...")
|
||||
response = await client.dial(
|
||||
url="https://httpbin.org/get",
|
||||
params={"param1": "value1", "param2": "value2"},
|
||||
)
|
||||
print(f"状态码: {response.status_code}")
|
||||
data = response.json()
|
||||
print(f"查询参数: {data.get('args', 'N/A')}")
|
||||
print()
|
||||
|
||||
print("✅ 所有测试通过!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 测试失败: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_dial())
|
143
uv.lock
generated
Normal file
143
uv.lock
generated
Normal file
@ -0,0 +1,143 @@
|
||||
version = 1
|
||||
revision = 1
|
||||
requires-python = ">=3.13"
|
||||
|
||||
[[package]]
|
||||
name = "anyio"
|
||||
version = "4.9.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "idna" },
|
||||
{ name = "sniffio" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/95/7d/4c1bd541d4dffa1b52bd83fb8527089e097a106fc90b467a7313b105f840/anyio-4.9.0.tar.gz", hash = "sha256:673c0c244e15788651a4ff38710fea9675823028a6f08a5eda409e0c9840a028", size = 190949 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/a1/ee/48ca1a7c89ffec8b6a0c5d02b89c305671d5ffd8d3c94acf8b8c408575bb/anyio-4.9.0-py3-none-any.whl", hash = "sha256:9f76d541cad6e36af7beb62e978876f3b41e3e04f2c1fbf0884604c0a9c4d93c", size = 100916 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "certifi"
|
||||
version = "2025.6.15"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/73/f7/f14b46d4bcd21092d7d3ccef689615220d8a08fb25e564b65d20738e672e/certifi-2025.6.15.tar.gz", hash = "sha256:d747aa5a8b9bbbb1bb8c22bb13e22bd1f18e9796defa16bab421f7f7a317323b", size = 158753 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/84/ae/320161bd181fc06471eed047ecce67b693fd7515b16d495d8932db763426/certifi-2025.6.15-py3-none-any.whl", hash = "sha256:2e0c7ce7cb5d8f8634ca55d2ba7e6ec2689a2fd6537d8dec1296a477a4910057", size = 157650 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "click"
|
||||
version = "8.2.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/60/6c/8ca2efa64cf75a977a0d7fac081354553ebe483345c734fb6b6515d96bbc/click-8.2.1.tar.gz", hash = "sha256:27c491cc05d968d271d5a1db13e3b5a184636d9d930f148c50b038f0d0646202", size = 286342 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/85/32/10bb5764d90a8eee674e9dc6f4db6a0ab47c8c4d0d83c27f7c39ac415a4d/click-8.2.1-py3-none-any.whl", hash = "sha256:61a3265b914e850b85317d0b3109c7f8cd35a670f963866005d6ef1d5175a12b", size = 102215 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "colorama"
|
||||
version = "0.4.6"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "h11"
|
||||
version = "0.16.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/01/ee/02a2c011bdab74c6fb3c75474d40b3052059d95df7e73351460c8588d963/h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1", size = 101250 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "httpcore"
|
||||
version = "1.0.9"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "certifi" },
|
||||
{ name = "h11" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/06/94/82699a10bca87a5556c9c59b5963f2d039dbd239f25bc2a63907a05a14cb/httpcore-1.0.9.tar.gz", hash = "sha256:6e34463af53fd2ab5d807f399a9b45ea31c3dfa2276f15a2c3f00afff6e176e8", size = 85484 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/7e/f5/f66802a942d491edb555dd61e3a9961140fd64c90bce1eafd741609d334d/httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55", size = 78784 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "httpx"
|
||||
version = "0.28.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "anyio" },
|
||||
{ name = "certifi" },
|
||||
{ name = "httpcore" },
|
||||
{ name = "idna" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/b1/df/48c586a5fe32a0f01324ee087459e112ebb7224f646c0b5023f5e79e9956/httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc", size = 141406 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "idna"
|
||||
version = "3.10"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f1/70/7703c29685631f5a7590aa73f1f1d3fa9a380e654b86af429e0934a32f7d/idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9", size = 190490 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "python-starlette"
|
||||
version = "0.1.0"
|
||||
source = { virtual = "." }
|
||||
dependencies = [
|
||||
{ name = "httpx" },
|
||||
{ name = "starlette" },
|
||||
{ name = "uvicorn" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "httpx", specifier = ">=0.27.0" },
|
||||
{ name = "starlette", specifier = ">=0.47.1" },
|
||||
{ name = "uvicorn", specifier = ">=0.34.3" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sniffio"
|
||||
version = "1.3.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/a2/87/a6771e1546d97e7e041b6ae58d80074f81b7d5121207425c964ddf5cfdbd/sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc", size = 20372 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "starlette"
|
||||
version = "0.47.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "anyio" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/0a/69/662169fdb92fb96ec3eaee218cf540a629d629c86d7993d9651226a6789b/starlette-0.47.1.tar.gz", hash = "sha256:aef012dd2b6be325ffa16698f9dc533614fb1cebd593a906b90dc1025529a79b", size = 2583072 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/82/95/38ef0cd7fa11eaba6a99b3c4f5ac948d8bc6ff199aabd327a29cc000840c/starlette-0.47.1-py3-none-any.whl", hash = "sha256:5e11c9f5c7c3f24959edbf2dffdc01bba860228acf657129467d8a7468591527", size = 72747 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "uvicorn"
|
||||
version = "0.34.3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "click" },
|
||||
{ name = "h11" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/de/ad/713be230bcda622eaa35c28f0d328c3675c371238470abdea52417f17a8e/uvicorn-0.34.3.tar.gz", hash = "sha256:35919a9a979d7a59334b6b10e05d77c1d0d574c50e0fc98b8b1a0f165708b55a", size = 76631 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/6d/0d/8adfeaa62945f90d19ddc461c55f4a50c258af7662d34b6a3d5d1f8646f6/uvicorn-0.34.3-py3-none-any.whl", hash = "sha256:16246631db62bdfbf069b0645177d6e8a77ba950cfedbfd093acef9444e4d885", size = 62431 },
|
||||
]
|
104
xtest/main.go
Normal file
104
xtest/main.go
Normal file
@ -0,0 +1,104 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand/v2"
|
||||
"sync"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/middleware/logger"
|
||||
)
|
||||
|
||||
func main() {
|
||||
app := fiber.New()
|
||||
app.Use(logger.New())
|
||||
|
||||
app.Post("/api/v3/auth/login", handleLogin)
|
||||
app.Get("/api/v3/auth/check", handleAuthCheck)
|
||||
app.Get("/api/v3/auth/401", func(c *fiber.Ctx) error {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
|
||||
"err": "custom err 401",
|
||||
})
|
||||
})
|
||||
|
||||
log.Fatal(app.Listen(":3000"))
|
||||
}
|
||||
|
||||
type User struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
var (
|
||||
store = &sync.Map{}
|
||||
letters = []byte("0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
|
||||
mockUsers = map[string]User{
|
||||
"admin": {
|
||||
Username: "admin",
|
||||
Password: "123456",
|
||||
},
|
||||
"user": {
|
||||
Username: "user",
|
||||
Password: "12345678",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
func randomString(n int) string {
|
||||
b := make([]byte, n)
|
||||
for i := range b {
|
||||
b[i] = letters[rand.IntN(len(letters))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func handleLogin(c *fiber.Ctx) error {
|
||||
type Req struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
var (
|
||||
err error
|
||||
req Req
|
||||
)
|
||||
|
||||
if err = c.BodyParser(&req); err != nil {
|
||||
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
user, ok := mockUsers[req.Username]
|
||||
if !ok {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
|
||||
"error": "user not found",
|
||||
})
|
||||
}
|
||||
|
||||
session := randomString(16)
|
||||
store.Store(session, user)
|
||||
c.Response().Header.Set("Set-Cookie", fmt.Sprintf("session=%s; Path=/; HttpOnly; SameSite=Lax", session))
|
||||
c.Response().Header.Set("Server", "Fiber")
|
||||
|
||||
return c.Status(fiber.StatusOK).JSON(user)
|
||||
}
|
||||
|
||||
func handleAuthCheck(c *fiber.Ctx) error {
|
||||
session := c.Cookies("session")
|
||||
if session == "" {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
|
||||
"error": "unauthorized",
|
||||
})
|
||||
}
|
||||
|
||||
user, ok := store.Load(session)
|
||||
if !ok {
|
||||
return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{
|
||||
"error": "unauthorized",
|
||||
})
|
||||
}
|
||||
|
||||
return c.Status(fiber.StatusOK).JSON(user)
|
||||
}
|
Reference in New Issue
Block a user