1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63
from dataclasses import dataclass, field
from typing import Any, Literal, Optional, Protocol
@dataclass()
class Message:
    """One entry in the agent's conversation transcript.

    Fields are stored exactly as given; no cross-field validation is done.
    """

    # Speaker role of the message.
    role: Literal["user", "assistant", "system", "tool", "developer"]
    # Text body of the message.
    content: str
    # Optional speaker name; defaults to None.
    name: Optional[str] = field(default=None)
    # Optional tool-call identifier attached to the message; defaults to None.
    tool_call_id: Optional[str] = field(default=None)
@dataclass()
class ToolCall:
    """A request to invoke a named tool with a set of arguments."""

    # Identifier that ToolResult.tool_call_id values are matched against.
    id: str
    # Name of the tool being invoked.
    name: str
    # Arguments for the tool, as a str-keyed mapping.
    args: dict[str, Any]
@dataclass()
class ToolResult:
    """Output recorded for a previously issued tool call."""

    # Id of the ToolCall this result belongs to.
    tool_call_id: str
    # Output text produced for the call.
    content: str
    # Flag marking this result as an error; defaults to False.
    is_error: bool = field(default=False)
@dataclass
class AgentState:
    """Mutable state of a tool-using agent conversation.

    Holds the message transcript, the tool calls that were issued, and the
    results recorded for those calls.
    """

    messages: list[Message] = field(default_factory=list)
    tool_calls: list[ToolCall] = field(default_factory=list)
    tool_results: list[ToolResult] = field(default_factory=list)
    # Step counter; no method of this class advances it.
    current_step: int = 0

    def add_message(
        self,
        role: Literal["user", "assistant", "system", "tool", "developer"],
        content: str,
        *,
        name: Optional[str] = None,
        tool_call_id: Optional[str] = None,
    ) -> None:
        """Append a ``Message`` to the transcript.

        Args:
            role: Speaker role; accepts every role that ``Message`` accepts
                (the previous annotation omitted ``"developer"``).
            content: Message text.
            name: Optional speaker name stored on the message.
            tool_call_id: Optional tool-call id stored on the message.
        """
        self.messages.append(
            Message(role=role, content=content, name=name, tool_call_id=tool_call_id)
        )

    def add_tool_call(
        self,
        tool_call_id: str,
        name: str,
        args: Optional[dict[str, Any]] = None,
    ) -> None:
        """Record a tool call so that later results can reference its id.

        Args:
            tool_call_id: Identifier for the call; stored as ``ToolCall.id``.
            name: Name of the tool being invoked.
            args: Arguments for the tool; ``None`` is stored as an empty dict.
        """
        self.tool_calls.append(
            ToolCall(id=tool_call_id, name=name, args={} if args is None else args)
        )

    def add_tool_result(self, tool_call_id: str, content: str, is_error: bool = False) -> None:
        """Append a ``ToolResult`` for the call identified by *tool_call_id*.

        No check is made here that a matching tool call was recorded.

        Args:
            tool_call_id: Id of the tool call this result answers.
            content: Output text of the tool.
            is_error: Whether the result should be flagged as an error.
        """
        self.tool_results.append(
            ToolResult(tool_call_id=tool_call_id, content=content, is_error=is_error)
        )
class Tool(Protocol):
    """Structural interface for a tool the agent can call."""

    @property
    def name(self) -> str:
        """Name used to refer to the tool."""
        ...

    @property
    def description(self) -> str:
        """Text describing the tool."""
        ...

    def execute(self, args: dict[str, Any]) -> str:
        """Run the tool with *args* and return its output as a string."""
        ...

    def validate_args(self, args: dict[str, Any]) -> bool:
        """Return whether *args* is acceptable for this tool."""
        ...
def validate_agent_state(state: AgentState) -> bool:
    """Check at runtime that the state is internally consistent.

    Returns True when every ``tool_call_id`` in ``state.tool_results`` equals
    the ``id`` of some entry in ``state.tool_calls``. With no results, this is
    vacuously True. Otherwise it returns False.
    """
    known_ids = {call.id for call in state.tool_calls}
    return all(result.tool_call_id in known_ids for result in state.tool_results)
# Demo: a user question answered through a single tool call.
state = AgentState()
state.add_message("user", "What's 2+2?")
state.add_message("assistant", "Let me calculate...")
# Record the call that was issued. validate_agent_state only accepts results
# whose tool_call_id matches a recorded ToolCall; without this line, the
# demo reported the state as invalid.
state.tool_calls.append(ToolCall(id="call_1", name="calculator", args={"expression": "2+2"}))
# The tool's reply carries the id of the call it answers.
state.add_message("tool", "4", tool_call_id="call_1")
state.add_tool_result("call_1", "4")

print(f"State valid: {validate_agent_state(state)}")
print(f"Messages: {len(state.messages)}")
|