refactor(conv): refactor the conversation manager to use LCEL chains
- Updated LangChain import paths to use langchain_core
- Implemented an LCEL chain with RAG for conversation handling
- Added a method that builds the conversation history as LangChain messages
- Implemented context retrieval from the knowledge base
- Removed the old string-based conversation history builder
- Updated the chat method to use the new LCEL chain
- Fixed code to match the new import structure
parent 9bd715d080
commit 4bd343c237
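The shape of the refactor, as a minimal runnable sketch: a RunnableLambda injects retrieved context into the input dict, and RunnablePassthrough.assign adds the model's response while preserving the other keys. FakeListLLM and the stub retriever are illustrative stand-ins, not part of this commit.

from langchain_core.language_models.fake import FakeListLLM
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableLambda, RunnablePassthrough

llm = FakeListLLM(responses=["LCEL composes runnables with the | operator."])

prompt = ChatPromptTemplate.from_messages([
    ("system", "Use the following context to answer.\n\nContext:\n{context}"),
    MessagesPlaceholder(variable_name="chat_history", optional=True),
    ("human", "{input}"),
])

def retrieve_context(inputs: dict) -> dict:
    # Stand-in for the knowledge-base lookup; returns the enriched input dict.
    return {**inputs,
            "context": "[1] LCEL is LangChain Expression Language.",
            "_sources": [{"doc": "demo"}]}

chain = RunnableLambda(retrieve_context) | RunnablePassthrough.assign(
    response=prompt | llm | StrOutputParser()
)

result = chain.invoke({"input": "What is LCEL?", "chat_history": []})
print(result["response"])   # the parsed model output
print(result["_sources"])   # retrieval metadata survives alongside it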
@@ -11,7 +11,9 @@
       "Bash(git init:*)",
       "Bash(git remote add:*)",
       "Bash(git add:*)",
-      "Bash(rm:*)"
+      "Bash(rm:*)",
+      "Bash(git commit:*)",
+      "Bash(git push:*)"
     ],
     "deny": [],
     "ask": []
@@ -5,7 +5,7 @@ from typing import List, Dict, Any, Optional
 from datetime import datetime
 from sqlalchemy.orm import Session
 from langchain.agents import AgentExecutor, create_openai_tools_agent
-from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

 from app.db.models import ToolCall
 from app.services.model_manager import ModelManager
@@ -1,8 +1,12 @@
 """
 Multi-turn conversation manager with RAG support.
 """
-from typing import List, Dict, Any, Optional
+from typing import List, Dict, Any, Optional, Tuple
 from sqlalchemy.orm import Session
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.runnables import RunnablePassthrough, RunnableLambda
+from langchain_core.messages import HumanMessage, AIMessage, SystemMessage

 from app.db.models import Conversation, Message
 from app.services.model_manager import ModelManager
@@ -162,6 +166,107 @@ class ConversationManager:

         return messages

+    def _build_chat_history_messages(self, conversation_id: int, limit: int = 10) -> List:
+        """
+        Build the conversation history as a list of LangChain messages.
+
+        Args:
+            conversation_id: Conversation ID
+            limit: Maximum number of messages to include
+
+        Returns:
+            List: A list of LangChain message objects
+        """
+        messages = self.get_messages(conversation_id, limit=limit)
+
+        history_messages = []
+        for msg in messages[:-1]:  # Exclude the latest user message (added separately)
+            if msg.role == "user":
+                history_messages.append(HumanMessage(content=msg.content))
+            elif msg.role == "assistant":
+                history_messages.append(AIMessage(content=msg.content))
+
+        return history_messages
+
+    def _retrieve_context(self, kb_id: int, query: str, k: int = 3) -> Tuple[str, List[Dict]]:
+        """
+        Retrieve context from the knowledge base.
+
+        Args:
+            kb_id: Knowledge base ID
+            query: Query text
+            k: Number of documents to return
+
+        Returns:
+            Tuple: (formatted context string, list of source metadata)
+        """
+        try:
+            results = self.kb_manager.query_kb(kb_id, query, k=k)
+            context_parts = []
+            sources = []
+
+            for i, result in enumerate(results, 1):
+                content = result['content']
+                if len(content) > 500:
+                    content = f"{content[:500]}..."
+                context_parts.append(f"[{i}] {content}")
+                sources.append(result['metadata'])
+
+            context = "\n\n".join(context_parts)
+            logger.info("rag_context_retrieved", kb_id=kb_id, source_count=len(results))
+            return context, sources
+        except Exception as e:
+            logger.warning("rag_context_failed", error=str(e))
+            return "(No relevant context found)", []
+
+    def _create_rag_chain(self, llm, kb_id: Optional[int] = None):
+        """
+        Create an LCEL chain, with RAG when a knowledge base is given.
+
+        Args:
+            llm: LLM instance
+            kb_id: Optional knowledge base ID
+
+        Returns:
+            Runnable: The LCEL chain
+        """
+        if kb_id:
+            # Prompt template with RAG context
+            prompt = ChatPromptTemplate.from_messages([
+                ("system", "You are a helpful AI assistant. Use the following context from the knowledge base to answer the question.\n\nContext:\n{context}"),
+                MessagesPlaceholder(variable_name="chat_history", optional=True),
+                ("human", "{input}"),
+            ])
+
+            # Retriever function
+            def retrieve_context(inputs):
+                context, sources = self._retrieve_context(kb_id, inputs["input"])
+                return {
+                    **inputs,
+                    "context": context,
+                    "_sources": sources  # Keep the sources for later use
+                }
+
+            # LCEL chain: retrieve -> prompt -> LLM -> parse output
+            chain = (
+                RunnableLambda(retrieve_context)
+                | RunnablePassthrough.assign(
+                    response=prompt | llm | StrOutputParser()
+                )
+            )
+        else:
+            # Simple prompt template without RAG
+            prompt = ChatPromptTemplate.from_messages([
+                ("system", "You are a helpful AI assistant."),
+                MessagesPlaceholder(variable_name="chat_history", optional=True),
+                ("human", "{input}"),
+            ])

+            # LCEL chain: prompt -> LLM -> parse output
+            chain = prompt | llm | StrOutputParser()
+
+        return chain
+
     async def chat(
         self,
         conversation_id: int,
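A note on the RunnablePassthrough.assign step above: assign runs its keyword runnables against the input dict and merges their outputs back into it, which is what lets "_sources" (added by retrieve_context) survive alongside the generated "response". A self-contained sketch, independent of this codebase:

from langchain_core.runnables import RunnableLambda, RunnablePassthrough

enrich = RunnableLambda(lambda d: {**d, "_sources": ["doc-1", "doc-2"]})
respond = RunnablePassthrough.assign(
    response=RunnableLambda(lambda d: d["input"].upper())  # toy "model"
)

print((enrich | respond).invoke({"input": "hi"}))
# {'input': 'hi', '_sources': ['doc-1', 'doc-2'], 'response': 'HI'}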
@@ -171,7 +276,7 @@ class ConversationManager:
         use_rag: bool = True,
     ) -> Message:
         """
-        Process the user message and generate an assistant reply.
+        Process the user message and generate an assistant reply (using LCEL).

         Args:
             conversation_id: Conversation ID
@@ -199,57 +304,39 @@
         # Save the user message
         self.add_message(conversation_id, "user", user_input)

-        # Build the conversation history
-        history = self._build_history(conversation_id, limit=10)
-
         # Get the LLM
         llm = self.model_manager.get_llm(llm_name)

-        # If RAG is enabled, build context from the knowledge base
-        context = ""
-        sources = []
-        if use_rag and kb_id:
-            try:
-                results = self.kb_manager.query_kb(kb_id, user_input, k=3)
-                context_parts = []
-                for i, result in enumerate(results, 1):
-                    context_parts.append(
-                        f"[{i}] {result['content'][:500]}..."
-                        if len(result['content']) > 500
-                        else f"[{i}] {result['content']}"
-                    )
-                    sources.append(result['metadata'])
-                context = "\n\n".join(context_parts)
-                logger.info("rag_context_retrieved", kb_id=kb_id, source_count=len(results))
-            except Exception as e:
-                logger.warning("rag_context_failed", error=str(e))
-                context = "(No relevant context found)"
-
-        # Build the prompt
-        if context:
-            prompt = f"""You are a helpful AI assistant. Use the following context from the knowledge base to answer the question.
-
-Context:
-{context}
-
-Conversation history:
-{history}
-
-User: {user_input}
-Assistant:"""
-        else:
-            prompt = f"""{history}
-
-User: {user_input}
-Assistant:"""
-
-        # Generate the response
-        try:
-            response = llm.predict(prompt)
+        # Get the conversation history (LangChain message format)
+        chat_history = self._build_chat_history_messages(conversation_id, limit=10)
+
+        # Create the LCEL chain
+        kb_id_for_rag = kb_id if use_rag else None
+        chain = self._create_rag_chain(llm, kb_id_for_rag)
+
+        # Prepare the chain input
+        chain_input = {
+            "input": user_input,
+            "chat_history": chat_history,
+        }
+
+        # Run the chain
+        try:
+            if kb_id_for_rag:
+                # The RAG chain returns a dict
+                result = await chain.ainvoke(chain_input)
+                response = result.get("response", "")
+                sources = result.get("_sources", [])
+            else:
+                # The simple chain returns a string directly
+                response = await chain.ainvoke(chain_input)
+                sources = []
             logger.info("llm_response_generated", conversation_id=conversation_id)
         except Exception as e:
             logger.error("llm_generation_failed", error=str(e))
             response = "I apologize, but I encountered an error generating a response. Please try again."
+            sources = []

         # Save the assistant message
         metadata = {}
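The subtle change in this hunk is the history format: the removed _build_history produced a flat "User:/Assistant:" string baked into the prompt, while _build_chat_history_messages produces typed message objects that MessagesPlaceholder splices into the prompt. A standalone sketch of that mapping (plain tuples stand in for the Message ORM rows):

from langchain_core.messages import AIMessage, HumanMessage

rows = [("user", "Hi"), ("assistant", "Hello!"), ("user", "What is RAG?")]

history = []
for role, content in rows[:-1]:  # the latest user turn goes in as {input}
    if role == "user":
        history.append(HumanMessage(content=content))
    elif role == "assistant":
        history.append(AIMessage(content=content))

print([type(m).__name__ for m in history])  # ['HumanMessage', 'AIMessage']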
@@ -265,28 +352,6 @@ Assistant:"""
         logger.info("chat_complete", conversation_id=conversation_id)
         return assistant_message

-    def _build_history(self, conversation_id: int, limit: int = 10) -> str:
-        """
-        Build the conversation history as a string.
-
-        Args:
-            conversation_id: Conversation ID
-            limit: Maximum number of messages to include
-
-        Returns:
-            str: Formatted conversation history
-        """
-        messages = self.get_messages(conversation_id, limit=limit)
-
-        history_parts = []
-        for msg in messages[:-1]:  # Exclude the latest user message (already in the prompt)
-            if msg.role == "user":
-                history_parts.append(f"User: {msg.content}")
-            elif msg.role == "assistant":
-                history_parts.append(f"Assistant: {msg.content}")
-
-        return "\n".join(history_parts) if history_parts else "(No previous messages)"
-
     def delete_conversation(self, conversation_id: int) -> bool:
         """
         Delete a conversation and all of its messages.
@@ -3,7 +3,7 @@
 """
 from typing import List, Dict, Any, Optional
 from sqlalchemy.orm import Session
-from langchain.schema import Document as LangChainDocument
+from langchain_core.documents import Document as LangChainDocument

 from app.db.models import KnowledgeBase, Document
 from app.services.model_manager import ModelManager
@@ -6,8 +6,8 @@ import re
 from typing import Optional, List, Dict, Any

 from sqlalchemy.orm import Session
-from langchain.llms.base import BaseLLM
-from langchain.embeddings.base import Embeddings
+from langchain_core.language_models.llms import BaseLLM
+from langchain_core.embeddings import Embeddings
 from langchain_openai import ChatOpenAI, OpenAIEmbeddings

 from app.db.models import Model
@@ -3,7 +3,7 @@
 """
 from typing import Type
 from pydantic import BaseModel, Field
-from langchain.tools import BaseTool
+from langchain_core.tools import BaseTool

 from app.utils.logger import get_logger

@@ -3,7 +3,7 @@
 """
 from typing import Optional, Type
 from pydantic import BaseModel, Field
-from langchain.tools import BaseTool
+from langchain_core.tools import BaseTool

 from app.services.kb_manager import KnowledgeBaseManager
 from app.utils.logger import get_logger
@@ -5,8 +5,8 @@ import os
 from typing import List, Optional

 from langchain_community.vectorstores import FAISS
-from langchain.schema import Document as LangChainDocument
-from langchain.embeddings.base import Embeddings
+from langchain_core.documents import Document as LangChainDocument
+from langchain_core.embeddings import Embeddings

 from app.utils.exceptions import VectorStoreError
 from app.utils.logger import get_logger
@@ -3,8 +3,8 @@
 """
 from typing import List

-from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain.schema import Document as LangChainDocument
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from langchain_core.documents import Document as LangChainDocument


 class TextSplitter:
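A quick sanity check for the migrated import paths (assumes the split packages langchain-core, langchain-text-splitters, and langchain-community are installed):

# Every post-migration path touched by this commit, in one importable file.
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.language_models.llms import BaseLLM
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.tools import BaseTool
from langchain_text_splitters import RecursiveCharacterTextSplitter

print(Document(page_content="ok").page_content)  # prints "ok" if imports resolve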