写了几百行 __init__、__repr__、__eq__,只为定义一个「数据容器」?Dataclass 让 Python 程序员从重复劳动中解放出来,专注于业务逻辑本身。
前言
在 AI 开发中,我们经常需要定义:
- Agent 配置:name、model、temperature、max_tokens…
- 消息对象:role、content、timestamp…
- 工具定义:name、description、parameters…
传统写法:
1 2 3 4 5 6 7 8 9 10 11
| class Message: def __init__(self, role: str, content: str, timestamp: float): self.role = role self.content = content self.timestamp = timestamp def __repr__(self): return f"Message(role={self.role!r}, content={self.content!r}, timestamp={self.timestamp!r})" def __eq__(self, other): if not isinstance(other, Message): return False return self.role == other.role and self.content == other.content and self.timestamp == other.timestamp
|
四五十行代码,只为定义一个数据结构? Dataclass 登场。
一、@dataclass 基础:自动生成魔法方法
1 2 3 4 5 6 7 8 9 10 11 12 13
| from dataclasses import dataclass, field from typing import Optional, ClassVar import time
@dataclass class Message: role: str content: str timestamp: float = field(default_factory=time.time)
msg = Message(role="user", content="Hello") print(msg)
|
输出:
1
| Message(role='user', content='Hello', timestamp=1745312400.123)
|
@dataclass 自动生成:
| 魔法方法 | 作用 |
|---|
__init__(self, ...) | 自动根据字段生成 |
__repr__(self) | 可读性良好的字符串 |
__eq__(self, other) | 按值比较(而非引用) |
💡 提示:field(default_factory=...) 用于可变默认值(如 list、dict),避免共享引用 bug。
二、post_init:初始化后验证
dataclass 生成 __init__ 后,会调用 __post_init__(如果存在)。这是验证和二次初始化的好地方。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
| @dataclass class AgentConfig: name: str model: str temperature: float = 0.7 max_tokens: int = 2048 def __post_init__(self): if not 0.0 <= self.temperature <= 2.0: raise ValueError(f"Temperature 0-2 required, got {self.temperature}") if self.max_tokens <= 0: raise ValueError(f"max_tokens must be positive, got {self.max_tokens}")
config = AgentConfig(name="Assistant", model="gpt-4") print(config)
try: bad_config = AgentConfig(name="Test", model="gpt-4", temperature=5.0) except ValueError as e: print(f"❌ Validation error: {e}")
|
适用场景:
三、field 配置详解
field() 是 dataclass 的精细化配置神器:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
| @dataclass class Conversation: messages: list[dict] = field(default_factory=list) system_prompt: str = "" max_history: int = field(default=100, compare=False) metadata: dict = field(default_factory=dict, metadata={"alias": "元数据"}) def add(self, role: str, content: str): self.messages.append({"role": role, "content": content}) while len(self.messages) > self.max_history: self.messages.pop(0)
conv = Conversation(system_prompt="You are helpful.") conv.add("user", "Hi") conv.add("assistant", "Hello!") print(f"Messages: {len(conv.messages)}") print(f"Conv: {conv}")
|
常用 field 参数
| 参数 | 作用 | 示例 |
|---|
default | 简单默认值 | field(default=42) |
default_factory | 可变对象工厂 | field(default_factory=list) |
compare | 是否参与 __eq__ | field(compare=False) |
metadata | 附加元数据 | field(metadata={"alias": "x"}) |
四、slots 模式:性能飞跃
Python 3.10+ 支持 @dataclass(slots=True),为 dataclass 启用 slots 优化。
1 2 3 4 5 6 7 8 9 10 11
| @dataclass(slots=True) class ToolResult: tool_name: str result: str success: bool = True error: Optional[str] = None
r = ToolResult("calc", "42") print(r)
|
slots vs 普通 dataclass
| 维度 | slots=True | 默认 |
|---|
| 内存占用 | 🟢 低(约减少 40%) | 🔴 较高 |
| 实例创建 | 🟢 快 | 🟡 略慢 |
| 动态字段 | ❌ 不支持 | ✅ 支持 |
| 继承 | ⚠️ 有限制 | ✅ 完全支持 |
AI 场景推荐:对于频繁创建的 ToolResult、Message 等轻量对象,启用 slots=True 可显著降低内存占用。
五、attrs 库对比
attrs 是 Python 另一个流行的数据类库,比 dataclass 更早出现,功能也更丰富。
1 2 3 4 5 6 7 8 9 10 11 12 13 14
| import attr
@attr.s(auto_attribs=True) class Message: role: str content: str timestamp: float = attr.Factory(time.time) @attr.s class Validator: @staticmethod def validate_message(msg): return len(msg.content) > 0
|
dataclass vs attrs
| 特性 | dataclass | attrs |
|---|
| 标准库 | ✅ 内置 | ❌ 需安装 |
slots 支持 | ✅ 3.10+ | ✅ @attr.s(slots=True) |
validators | ❌ 需手动 | ✅ 内置 |
converters | ❌ 无 | ✅ 内置 |
| 复杂度 | 简单 | 功能更多 |
建议:AI 开发中,优先使用标准库 dataclass;若需要 validators/converters,再考虑 attrs。
六、AI 实战:完整 Agent 示例
案例 1:消息与会话管理
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30
| from dataclasses import dataclass, field from typing import Optional
@dataclass class Message: role: str content: str timestamp: float = field(default_factory=time.time) def to_dict(self) -> dict: return {"role": self.role, "content": self.content}
@dataclass class Conversation: messages: list[Message] = field(default_factory=list) system_prompt: str = "" max_history: int = 100 def add(self, role: str, content: str) -> None: msg = Message(role=role, content=content) self.messages.append(msg) while len(self.messages) > self.max_history: self.messages.pop(0) def get_context(self) -> list[dict]: ctx = [] if self.system_prompt: ctx.append({"role": "system", "content": self.system_prompt}) ctx.extend(m.to_dict() for m in self.messages) return ctx
|
案例 2:Agent 配置与工具定义
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69
| @dataclass class AgentConfig: name: str model: str temperature: float = 0.7 max_tokens: int = 2048 def __post_init__(self): if not 0.0 <= self.temperature <= 2.0: raise ValueError(f"Temperature must be 0-2, got {self.temperature}")
@dataclass class Tool: name: str description: str parameters: dict required: list[str] = field(default_factory=list) def validate(self, args: dict) -> bool: return all(k in args for k in self.required) def get_schema(self) -> dict: return { "name": self.name, "description": self.description, "parameters": self.parameters }
@dataclass class Agent: name: str config: AgentConfig tools: list[Tool] = field(default_factory=list) conversation: Conversation = field(default_factory=Conversation) def execute_tool(self, tool_name: str, args: dict) -> str: for tool in self.tools: if tool.name == tool_name and tool.validate(args): return f"Executed {tool_name}: {args}" return f"Tool {tool_name} not found or invalid args" def add_message(self, role: str, content: str): self.conversation.add(role, content) def get_context(self) -> list[dict]: return self.conversation.get_context()
tool = Tool( name="calculator", description="Perform math calculations", parameters={ "type": "object", "properties": { "expr": {"type": "string", "description": "Math expression"} } }, required=["expr"] )
agent = Agent( name="MathBot", config=AgentConfig(name="MathBot", model="gpt-4", temperature=0.5), tools=[tool] )
agent.add_message("user", "What's 2+2?") print(f"Context: {agent.get_context()}") print(agent.execute_tool("calculator", {"expr": "2+2"}))
|
输出:
1 2
| Context: [{'role': 'user', 'content': "What's 2+2?"}] Executed calculator: {'expr': '2+2'}
|
七、高级技巧
7.1 类变量 vs 实例字段
1 2 3 4 5 6 7 8 9 10
| @dataclass class ModelInfo: supported_models: ClassVar[list[str]] = ["gpt-4", "gpt-3.5", "claude-3"] name: str version: str
print(ModelInfo.supported_models)
|
7.2 继承注意事项
1 2 3 4 5 6 7 8 9
| @dataclass class BaseResponse: request_id: str
@dataclass class LLMResponse(BaseResponse): content: str model: str tokens_used: int = 0
|
7.3 字段排序
1 2 3 4 5 6
| from dataclasses import dataclass, field
@dataclass(order=True) class PriorityMessage: priority: int content: str = field(compare=False)
|
八、常见错误与最佳实践
| 错误 | 正确做法 | 原因 |
|---|
default=[] | field(default_factory=list) | 避免可变默认值共享 |
__init__ 中验证 | __post_init__ 中验证 | dataclass 不调用自定义 __init__ |
slots=True + 继承 | 谨慎使用 | 复杂继承链有问题 |
忘记 from dataclasses import field | 显式导入 | 避免运行时错误 |
九、总结
graph TB
DC["🟡 @dataclass"]
INIT["📗 自动 __init__"]
REPR["📗 自动 __repr__"]
EQ["📗 自动 __eq__"]
POST["📗 __post_init__"]
FIELD["📗 field() 精细配置"]
SLOTS["📗 slots=True 性能优化"]
DC --> INIT
DC --> REPR
DC --> EQ
DC --> POST
DC --> FIELD
DC --> SLOTS
INIT -->|"AI 场景"| AGENT["🤖 Agent 配置"]
REPR -->|"AI 场景"| MSG["💬 Message 对象"]
FIELD -->|"AI 场景"| CONV["📜 Conversation"]
SLOTS -->|"AI 场景"| RESULT["⚡ ToolResult"]
style DC fill:#FFF9C4,stroke:#F9A825,color:#333
style INIT fill:#C7CEEA,stroke:#9FA8DA,color:#333
style REPR fill:#B5EAD7,stroke:#80CBC4,color:#333
style EQ fill:#FFDAB9,stroke:#FFAB76,color:#333
style POST fill:#FFB3C6,stroke:#F48FB1,color:#333
style FIELD fill:#E8D5F5,stroke:#CE93D8,color:#333
style SLOTS fill:#C7CEEA,stroke:#9FA8DA,color:#333
style AGENT fill:#B5EAD7,stroke:#80CBC4,color:#333
style MSG fill:#B5EAD7,stroke:#80CBC4,color:#333
style CONV fill:#B5EAD7,stroke:#80CBC4,color:#333
style RESULT fill:#B5EAD7,stroke:#80CBC4,color:#333
核心要点:
@dataclass 自动生成 __init__、__repr__、__eq__field(default_factory=...) 处理可变默认值__post_init__ 是验证和二次初始化的好地方slots=True 显著提升大量创建时的内存效率- AI 开发中,dataclass 是定义 Agent、Tool、Message 的首选
下期预告:【Python AI教程】(六)异步编程:让 AI 调用不阻塞——asyncio 在 AI 开发中的实战,批量并发调用 LLM。
代码已通过 Python 3.11+ 验证。
📚 Python AI教程 系列导航
本文是《Python AI教程》系列第 5/14 篇。
📖 全部 14 篇目录(点击展开)
- (一)闭包与装饰器
- (二)上下文管理器
- (三)生成器与迭代器
- (四)类型提示
- (五)Dataclass 与 attrs ← 当前
- (六)async/await
- (七)Threading 与 Multiprocessing
- (八)函数式编程
- (九)描述符协议
- (十)元类
- (十一)Protocol与结构化类型
- (十二)异常链与日志
- (十三)缓存艺术
- (十四)组合模式实战