Question Graph
Example of a graph for asking and evaluating questions.
Demonstrates:
Running the Example
With dependencies installed and environment variables set, run:

    python -m pydantic_ai_examples.question_graph

or, if you are using uv:

    uv run -m pydantic_ai_examples.question_graph
Example Code
question_graph.py
from __future__ import annotations as _annotations

from dataclasses import dataclass, field
from pathlib import Path

import logfire
# BaseModel must come from pydantic itself: the groq client only vendors a
# private pydantic base for its own wire types, and `use_attribute_docstrings`
# below relies on pydantic's public model configuration.
from pydantic import BaseModel
from pydantic_graph import (
    BaseNode,
    End,
    Graph,
    GraphRunContext,
)
from pydantic_graph.persistence.file import FileStatePersistence

from pydantic_ai import Agent
from pydantic_ai.format_as_xml import format_as_xml
from pydantic_ai.messages import ModelMessage

# 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured
logfire.configure(send_to_logfire='if-token-present')

# Agent that invents quiz questions; its output is a plain string.
ask_agent = Agent('openai:gpt-4o', result_type=str, instrument=True)
@dataclass
class QuestionState:
    """Shared state carried through every node of the question graph."""

    # The question currently awaiting an answer; None until Ask runs
    # (and reset to None by Reprimand).
    question: str | None = None
    # Rolling message histories so each agent keeps conversational context
    # across repeated runs within one graph execution.
    ask_agent_messages: list[ModelMessage] = field(default_factory=list)
    evaluate_agent_messages: list[ModelMessage] = field(default_factory=list)
@dataclass
class Ask(BaseNode[QuestionState]):
    """Have the ask agent generate a question, then pass it to Answer."""

    async def run(self, ctx: GraphRunContext[QuestionState]) -> Answer:
        run_result = await ask_agent.run(
            'Ask a simple question with a single correct answer.',
            message_history=ctx.state.ask_agent_messages,
        )
        # Accumulate the exchange so the next question builds on prior context.
        ctx.state.ask_agent_messages.extend(run_result.all_messages())
        question_text = run_result.data
        ctx.state.question = question_text
        return Answer(question_text)
@dataclass
class Answer(BaseNode[QuestionState]):
    """Prompt the user on stdin and forward their reply to Evaluate."""

    question: str

    async def run(self, ctx: GraphRunContext[QuestionState]) -> Evaluate:
        prompt = f'{self.question}: '
        user_reply = input(prompt)
        return Evaluate(user_reply)
# Structured output schema for the evaluator agent.
# NOTE: with `use_attribute_docstrings=True`, the attribute docstrings below
# become the fields' descriptions in the generated schema — they are part of
# the model's behavior, not mere comments, and must not be removed.
class EvaluationResult(BaseModel, use_attribute_docstrings=True):
    correct: bool
    """Whether the answer is correct."""
    comment: str
    """Comment on the answer, reprimand the user if the answer is wrong."""
# Agent that judges the user's answer, returning a structured EvaluationResult.
evaluate_agent = Agent(
    'openai:gpt-4o',
    result_type=EvaluationResult,
    system_prompt='Given a question and answer, evaluate if the answer is correct.',
)
@dataclass
class Evaluate(BaseNode[QuestionState, None, str]):
    """Judge the user's answer: end the run if correct, otherwise reprimand."""

    answer: str

    async def run(
        self,
        ctx: GraphRunContext[QuestionState],
    ) -> End[str] | Reprimand:
        # A question must have been asked before we can evaluate an answer.
        assert ctx.state.question is not None
        evaluation = await evaluate_agent.run(
            format_as_xml({'question': ctx.state.question, 'answer': self.answer}),
            message_history=ctx.state.evaluate_agent_messages,
        )
        ctx.state.evaluate_agent_messages.extend(evaluation.all_messages())
        verdict = evaluation.data
        if verdict.correct:
            return End(verdict.comment)
        return Reprimand(verdict.comment)
@dataclass
class Reprimand(BaseNode[QuestionState]):
    """Show the evaluator's comment, then loop back to ask a new question."""

    comment: str

    async def run(self, ctx: GraphRunContext[QuestionState]) -> Ask:
        print(f'Comment: {self.comment}')
        # Clear the stale question so Ask generates a fresh one.
        ctx.state.question = None
        return Ask()
# The assembled graph: Ask -> Answer -> Evaluate -> (End on success,
# or Reprimand -> Ask to try again).
question_graph = Graph(
    nodes=(Ask, Answer, Evaluate, Reprimand), state_type=QuestionState
)
async def run_as_continuous():
    """Run the graph to completion in a single process, starting at Ask."""
    initial_state = QuestionState()
    result = await question_graph.run(Ask(), state=initial_state)
    print('END:', result.output)
async def run_as_cli(answer: str | None):
    """Advance the graph one user-visible step per invocation, persisting to disk.

    Resumes from the saved snapshot (evaluating the supplied answer) when one
    exists, otherwise starts a fresh run; then executes nodes until the graph
    ends or needs user input again.
    """
    persistence = FileStatePersistence(Path('question_graph.json'))
    persistence.set_graph_types(question_graph)

    snapshot = await persistence.load_next()
    if snapshot is not None:
        # A previous invocation stopped at an Answer node waiting for input.
        state = snapshot.state
        assert answer is not None, (
            'answer required, usage "uv run -m pydantic_ai_examples.question_graph cli <answer>"'
        )
        node = Evaluate(answer)
    else:
        state = QuestionState()
        node = Ask()
    # debug(state, node)

    async with question_graph.iter(node, state=state, persistence=persistence) as run:
        while True:
            node = await run.next()
            if isinstance(node, End):
                print('END:', node.data)
                history = await persistence.load_all()
                print('history:', '\n'.join(str(e.node) for e in history), sep='\n')
                print('Finished!')
                break
            if isinstance(node, Answer):
                # Stop here; the next CLI invocation supplies the answer.
                print(node.question)
                break
            # otherwise just continue
if __name__ == '__main__':
    import asyncio
    import sys

    # Validate the sub-command explicitly rather than with `assert`: asserts
    # are stripped under `python -O`, which would silently skip this check.
    sub_command = sys.argv[1] if len(sys.argv) > 1 else None
    if sub_command not in ('continuous', 'cli', 'mermaid'):
        print(
            'Usage:\n'
            ' uv run -m pydantic_ai_examples.question_graph mermaid\n'
            'or:\n'
            ' uv run -m pydantic_ai_examples.question_graph continuous\n'
            'or:\n'
            ' uv run -m pydantic_ai_examples.question_graph cli [answer]',
            file=sys.stderr,
        )
        sys.exit(1)

    if sub_command == 'mermaid':
        # Print the graph's mermaid diagram instead of running it.
        print(question_graph.mermaid_code(start_node=Ask))
    elif sub_command == 'continuous':
        asyncio.run(run_as_continuous())
    else:
        # 'cli': the answer argument is optional on the first invocation.
        a = sys.argv[2] if len(sys.argv) > 2 else None
        asyncio.run(run_as_cli(a))
The mermaid diagram generated in this example looks like this:
---
title: question_graph
---
stateDiagram-v2
[*] --> Ask
Ask --> Answer: ask the question
Answer --> Evaluate: answer the question
Evaluate --> [*]: success
Evaluate --> Reprimand
Reprimand --> Ask: try again