From 4bd343c23780ccaadd1c1483e3973bca9a6baacd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=A8=E7=85=9C?= Date: Thu, 2 Oct 2025 18:10:53 +0800 Subject: [PATCH] =?UTF-8?q?refactor(conv):=E9=87=8D=E6=9E=84=E5=AF=B9?= =?UTF-8?q?=E8=AF=9D=E7=AE=A1=E7=90=86=E5=99=A8=E4=BB=A5=E4=BD=BF=E7=94=A8?= =?UTF-8?q?=20LCEL=20=E9=93=BE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 更新了 LangChain 导入路径以使用 langchain_core - 实现了带 RAG 的 LCEL 链用于对话处理 - 添加了构建对话历史消息的方法 - 实现了从知识库检索上下文的逻辑 - 移除了旧的对话历史构建方法- 更新了聊天方法以使用新的 LCEL 链- 修复了代码以适应新的导入结构 --- .claude/settings.local.json | 4 +- app/services/agent_orchestrator.py | 2 +- app/services/conv_manager.py | 193 +++++++++++++++++++---------- app/services/kb_manager.py | 2 +- app/services/model_manager.py | 4 +- app/tools/calculator.py | 2 +- app/tools/retriever.py | 2 +- app/utils/faiss_helper.py | 4 +- app/utils/text_splitter.py | 4 +- 9 files changed, 142 insertions(+), 75 deletions(-) diff --git a/.claude/settings.local.json b/.claude/settings.local.json index 8642a26..6349764 100644 --- a/.claude/settings.local.json +++ b/.claude/settings.local.json @@ -11,7 +11,9 @@ "Bash(git init:*)", "Bash(git remote add:*)", "Bash(git add:*)", - "Bash(rm:*)" + "Bash(rm:*)", + "Bash(git commit:*)", + "Bash(git push:*)" ], "deny": [], "ask": [] diff --git a/app/services/agent_orchestrator.py b/app/services/agent_orchestrator.py index aa2f34a..2f3b6de 100644 --- a/app/services/agent_orchestrator.py +++ b/app/services/agent_orchestrator.py @@ -5,7 +5,7 @@ from typing import List, Dict, Any, Optional from datetime import datetime from sqlalchemy.orm import Session from langchain.agents import AgentExecutor, create_openai_tools_agent -from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from app.db.models import ToolCall from app.services.model_manager import ModelManager diff --git a/app/services/conv_manager.py b/app/services/conv_manager.py index ca9d75a..8c53813 100644 --- a/app/services/conv_manager.py +++ b/app/services/conv_manager.py @@ -1,8 +1,12 @@ """ 支持 RAG 的多轮对话管理器。 """ -from typing import List, Dict, Any, Optional +from typing import List, Dict, Any, Optional, Tuple from sqlalchemy.orm import Session +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from langchain_core.output_parsers import StrOutputParser +from langchain_core.runnables import RunnablePassthrough, RunnableLambda +from langchain_core.messages import HumanMessage, AIMessage, SystemMessage from app.db.models import Conversation, Message from app.services.model_manager import ModelManager @@ -162,6 +166,107 @@ class ConversationManager: return messages + def _build_chat_history_messages(self, conversation_id: int, limit: int = 10) -> List: + """ + 构建对话历史消息列表(LangChain 消息格式)。 + + Args: + conversation_id: 对话 ID + limit: 要包含的最大消息数 + + Returns: + List: LangChain 消息对象列表 + """ + messages = self.get_messages(conversation_id, limit=limit) + + history_messages = [] + for msg in messages[:-1]: # 排除最后一条用户消息(将单独添加) + if msg.role == "user": + history_messages.append(HumanMessage(content=msg.content)) + elif msg.role == "assistant": + history_messages.append(AIMessage(content=msg.content)) + + return history_messages + + def _retrieve_context(self, kb_id: int, query: str, k: int = 3) -> Tuple[str, List[Dict]]: + """ + 从知识库检索上下文。 + + Args: + kb_id: 知识库 ID + query: 查询文本 + k: 返回的文档数量 + + Returns: + Tuple: (格式化的上下文字符串, 来源元数据列表) + """ + try: + results = self.kb_manager.query_kb(kb_id, query, k=k) + context_parts = [] + sources = [] + + for i, result in enumerate(results, 1): + content = result['content'] + if len(content) > 500: + content = f"{content[:500]}..." + context_parts.append(f"[{i}] {content}") + sources.append(result['metadata']) + + context = "\n\n".join(context_parts) + logger.info("rag_context_retrieved", kb_id=kb_id, source_count=len(results)) + return context, sources + except Exception as e: + logger.warning("rag_context_failed", error=str(e)) + return "(No relevant context found)", [] + + def _create_rag_chain(self, llm, kb_id: Optional[int] = None): + """ + 创建带 RAG 的 LCEL 链。 + + Args: + llm: LLM 实例 + kb_id: 可选的知识库 ID + + Returns: + Runnable: LCEL 链 + """ + if kb_id: + # 带 RAG 的提示模板 + prompt = ChatPromptTemplate.from_messages([ + ("system", "You are a helpful AI assistant. Use the following context from the knowledge base to answer the question.\n\nContext:\n{context}"), + MessagesPlaceholder(variable_name="chat_history", optional=True), + ("human", "{input}"), + ]) + + # 创建检索器函数 + def retrieve_context(inputs): + context, sources = self._retrieve_context(kb_id, inputs["input"]) + return { + **inputs, + "context": context, + "_sources": sources # 保存来源以供后续使用 + } + + # LCEL 链: 检索 -> 提示 -> LLM -> 解析输出 + chain = ( + RunnableLambda(retrieve_context) + | RunnablePassthrough.assign( + response=prompt | llm | StrOutputParser() + ) + ) + else: + # 无 RAG 的简单提示模板 + prompt = ChatPromptTemplate.from_messages([ + ("system", "You are a helpful AI assistant."), + MessagesPlaceholder(variable_name="chat_history", optional=True), + ("human", "{input}"), + ]) + + # LCEL 链: 提示 -> LLM -> 解析输出 + chain = prompt | llm | StrOutputParser() + + return chain + async def chat( self, conversation_id: int, @@ -171,7 +276,7 @@ class ConversationManager: use_rag: bool = True, ) -> Message: """ - 处理用户消息并生成助手回复。 + 处理用户消息并生成助手回复(使用 LCEL)。 Args: conversation_id: 对话 ID @@ -199,57 +304,39 @@ class ConversationManager: # 保存用户消息 self.add_message(conversation_id, "user", user_input) - # 获取对话历史 - history = self._build_history(conversation_id, limit=10) - # 获取 LLM llm = self.model_manager.get_llm(llm_name) - # 如果启用了 RAG,从知识库构建上下文 - context = "" - sources = [] - if use_rag and kb_id: - try: - results = self.kb_manager.query_kb(kb_id, user_input, k=3) - context_parts = [] - for i, result in enumerate(results, 1): - context_parts.append( - f"[{i}] {result['content'][:500]}..." - if len(result['content']) > 500 - else f"[{i}] {result['content']}" - ) - sources.append(result['metadata']) - context = "\n\n".join(context_parts) - logger.info("rag_context_retrieved", kb_id=kb_id, source_count=len(results)) - except Exception as e: - logger.warning("rag_context_failed", error=str(e)) - context = "(No relevant context found)" + # 获取对话历史(LangChain 消息格式) + chat_history = self._build_chat_history_messages(conversation_id, limit=10) - # 构建提示 - if context: - prompt = f"""You are a helpful AI assistant. Use the following context from the knowledge base to answer the question. + # 创建 LCEL 链 + kb_id_for_rag = kb_id if use_rag else None + chain = self._create_rag_chain(llm, kb_id_for_rag) -Context: -{context} + # 准备输入 + chain_input = { + "input": user_input, + "chat_history": chat_history, + } -Conversation history: -{history} - -User: {user_input} -Assistant:""" - else: - prompt = f"""{history} - -User: {user_input} -Assistant:""" - - # 生成回复 + # 执行链 try: - response = llm.predict(prompt) + if kb_id_for_rag: + # 带 RAG 的链返回字典 + result = await chain.ainvoke(chain_input) + response = result.get("response", "") + sources = result.get("_sources", []) + else: + # 简单链直接返回字符串 + response = await chain.ainvoke(chain_input) + sources = [] + logger.info("llm_response_generated", conversation_id=conversation_id) except Exception as e: logger.error("llm_generation_failed", error=str(e)) response = "I apologize, but I encountered an error generating a response. Please try again." + sources = [] # 保存助手消息 metadata = {} @@ -265,28 +352,6 @@ Assistant:""" logger.info("chat_complete", conversation_id=conversation_id) return assistant_message - def _build_history(self, conversation_id: int, limit: int = 10) -> str: - """ - 构建对话历史字符串。 - - Args: - conversation_id: 对话 ID - limit: 要包含的最大消息数 - - Returns: - str: 格式化的对话历史 - """ - messages = self.get_messages(conversation_id, limit=limit) - - history_parts = [] - for msg in messages[:-1]: # 排除最后一条用户消息(已在提示中) - if msg.role == "user": - history_parts.append(f"User: {msg.content}") - elif msg.role == "assistant": - history_parts.append(f"Assistant: {msg.content}") - - return "\n".join(history_parts) if history_parts else "(No previous messages)" - def delete_conversation(self, conversation_id: int) -> bool: """ 删除对话及其所有消息。 diff --git a/app/services/kb_manager.py b/app/services/kb_manager.py index 249b386..76ee879 100644 --- a/app/services/kb_manager.py +++ b/app/services/kb_manager.py @@ -3,7 +3,7 @@ """ from typing import List, Dict, Any, Optional from sqlalchemy.orm import Session -from langchain.schema import Document as LangChainDocument +from langchain_core.documents import Document as LangChainDocument from app.db.models import KnowledgeBase, Document from app.services.model_manager import ModelManager diff --git a/app/services/model_manager.py b/app/services/model_manager.py index cce6616..c4cdb0f 100644 --- a/app/services/model_manager.py +++ b/app/services/model_manager.py @@ -6,8 +6,8 @@ import re from typing import Optional, List, Dict, Any from sqlalchemy.orm import Session -from langchain.llms.base import BaseLLM -from langchain.embeddings.base import Embeddings +from langchain_core.language_models.llms import BaseLLM +from langchain_core.embeddings import Embeddings from langchain_openai import ChatOpenAI, OpenAIEmbeddings from app.db.models import Model diff --git a/app/tools/calculator.py b/app/tools/calculator.py index c245237..e3025d0 100644 --- a/app/tools/calculator.py +++ b/app/tools/calculator.py @@ -3,7 +3,7 @@ """ from typing import Type from pydantic import BaseModel, Field -from langchain.tools import BaseTool +from langchain_core.tools import BaseTool from app.utils.logger import get_logger diff --git a/app/tools/retriever.py b/app/tools/retriever.py index 3457e8c..dce690f 100644 --- a/app/tools/retriever.py +++ b/app/tools/retriever.py @@ -3,7 +3,7 @@ """ from typing import Optional, Type from pydantic import BaseModel, Field -from langchain.tools import BaseTool +from langchain_core.tools import BaseTool from app.services.kb_manager import KnowledgeBaseManager from app.utils.logger import get_logger diff --git a/app/utils/faiss_helper.py b/app/utils/faiss_helper.py index d0d621d..ae2b1eb 100644 --- a/app/utils/faiss_helper.py +++ b/app/utils/faiss_helper.py @@ -5,8 +5,8 @@ import os from typing import List, Optional from langchain_community.vectorstores import FAISS -from langchain.schema import Document as LangChainDocument -from langchain.embeddings.base import Embeddings +from langchain_core.documents import Document as LangChainDocument +from langchain_core.embeddings import Embeddings from app.utils.exceptions import VectorStoreError from app.utils.logger import get_logger diff --git a/app/utils/text_splitter.py b/app/utils/text_splitter.py index 0efa528..fac6756 100644 --- a/app/utils/text_splitter.py +++ b/app/utils/text_splitter.py @@ -3,8 +3,8 @@ """ from typing import List -from langchain.text_splitter import RecursiveCharacterTextSplitter -from langchain.schema import Document as LangChainDocument +from langchain_text_splitters import RecursiveCharacterTextSplitter +from langchain_core.documents import Document as LangChainDocument class TextSplitter: