1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193
import asyncio
import json
import os
import sys
from contextlib import AsyncExitStack
from typing import Optional

from dotenv import load_dotenv
from loguru import logger
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from openai import OpenAI
# Merge variables from a local .env file (when present) into os.environ.
load_dotenv()

# Settings for the OpenAI-compatible chat endpoint. Each one is required:
# the first missing variable (checked in this order) raises KeyError at
# import time.
api_key, base_url, model_name = (
    os.environ[env_var]
    for env_var in ("DS_API_KEY", "DS_API_BASE", "API_MODEL_NAME")
)

# Tool-call budget consulted by MCPClient.process_query().
max_tool_calls_allowed = 5

# NOTE(review): this message reads "FastMCP server starting..." although this
# module is the MCP *client* -- confirm the intended wording.
logger.debug("FastMCP 服务器启动中...")
class MCPClient:
    """Bridge an MCP server's tools to an OpenAI-compatible chat model.

    The client starts an MCP server script as a subprocess over stdio, exposes
    the server's tools to the model as function tools, runs whatever tool calls
    the model requests, and feeds the results back until the model answers in
    plain text.
    """

    def __init__(self):
        # Set by connect_to_server(); None until a server is connected.
        self.session: Optional[ClientSession] = None
        # Owns the stdio transport and the ClientSession; closed by cleanup().
        self.exit_stack = AsyncExitStack()
        self.openai = OpenAI(api_key=api_key, base_url=base_url)

    async def connect_to_server(self, server_script_path: str):
        """Launch an MCP server script over stdio and open a session to it.

        Args:
            server_script_path: Path to the server script. A ``.py`` file is
                run with ``python``; a ``.js`` file is run with ``node``.

        Raises:
            ValueError: If the path ends in neither ``.py`` nor ``.js``.
        """
        is_python = server_script_path.endswith(".py")
        is_js = server_script_path.endswith(".js")
        if not (is_python or is_js):
            raise ValueError("服务器脚本必须是 .py 或 .js 文件")

        command = "python" if is_python else "node"
        server_params = StdioServerParameters(
            command=command, args=[server_script_path], env=None
        )

        # Both contexts live on the exit stack; it unwinds LIFO, so cleanup()
        # closes the session before the stdio transport underneath it.
        stdio_transport = await self.exit_stack.enter_async_context(
            stdio_client(server_params)
        )
        self.stdio, self.write = stdio_transport
        self.session = await self.exit_stack.enter_async_context(
            ClientSession(self.stdio, self.write)
        )

        await self.session.initialize()

        response = await self.session.list_tools()
        tools = response.tools
        print(
            "\n成功连接到服务器,检测到的工具:",
            [[tool.name, tool.description, tool.inputSchema] for tool in tools],
        )

    async def process_query(self, query: str) -> str:
        """Answer one user query, letting the model call MCP tools as needed.

        Each round asks the model for a reply. If the reply requests tool
        calls, each call is executed through the MCP session and its result is
        appended as a ``tool`` message, and the model is asked again. After
        ``max_tool_calls_allowed`` tool executions, any further requested
        calls are answered with an error placeholder. The next request then
        sets ``tool_choice="none"`` so the model must reply in text.

        Args:
            query: The user's question.

        Returns:
            The model's final text reply ("" if it returned no content).

        Raises:
            RuntimeError: If called before connect_to_server().
        """
        if self.session is None:
            raise RuntimeError(
                "MCP session is not connected; call connect_to_server() first"
            )

        messages = [{"role": "user", "content": query}]

        response = await self.session.list_tools()
        available_tools = [
            {
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": getattr(tool, "inputSchema", {}),
                },
            }
            for tool in response.tools
        ]

        current_tool_calls_count = 0
        while True:
            budget_exhausted = current_tool_calls_count >= max_tool_calls_allowed

            request = {
                "model": model_name,
                "messages": messages,
                "max_tokens": 1000,
            }
            # An empty ``tools`` array may be rejected by the API, so only
            # send it when the server actually exposes tools.
            if available_tools:
                request["tools"] = available_tools
                if budget_exhausted:
                    # Keep the tool schemas (earlier turns reference them) but
                    # forbid new calls so the model has to answer in text.
                    request["tool_choice"] = "none"

            # The OpenAI client is synchronous; run it in a worker thread so
            # the event loop (and the MCP session's stdio tasks) stay live.
            model_response = await asyncio.to_thread(
                self.openai.chat.completions.create, **request
            )

            assistant_message = model_response.choices[0].message
            logger.debug(f"助手返回消息: {assistant_message}")

            tool_calls = getattr(assistant_message, "tool_calls", None) or []
            if not tool_calls or budget_exhausted:
                return assistant_message.content or ""

            # Re-serialize the tool calls as plain dicts for the next request.
            messages.append(
                {
                    "role": "assistant",
                    "content": assistant_message.content or "",
                    "tool_calls": [
                        {
                            "id": call.id,
                            "type": "function",
                            "function": {
                                "name": call.function.name,
                                "arguments": call.function.arguments,
                            },
                        }
                        for call in tool_calls
                    ],
                }
            )

            for tool_call in tool_calls:
                if current_tool_calls_count >= max_tool_calls_allowed:
                    # Every tool_call_id needs a matching tool message or the
                    # next request is rejected, so answer skipped calls too.
                    messages.append(
                        {
                            "role": "tool",
                            "content": "Error: tool call limit reached; call skipped.",
                            "tool_call_id": tool_call.id,
                        }
                    )
                    continue

                messages.append(
                    {
                        "role": "tool",
                        "content": await self._run_tool_call(tool_call),
                        "tool_call_id": tool_call.id,
                    }
                )
                current_tool_calls_count += 1
                if current_tool_calls_count >= max_tool_calls_allowed:
                    logger.warning("工具调用次数过多,停止调用。")

    async def _run_tool_call(self, tool_call) -> str:
        """Execute one model-requested tool call and return text for the model.

        Failures are logged and returned as an ``Error: ...`` string instead of
        being raised, so one bad tool call does not abort the whole query.
        Failures include bad JSON arguments and errors raised by the MCP call.
        """
        try:
            tool_name = tool_call.function.name
            tool_args = self._parse_tool_arguments(tool_call.function.arguments)
            logger.debug(f"执行工具: {tool_name},参数: {tool_args}")
            result = await self.session.call_tool(tool_name, tool_args)
            logger.debug(f"工具返回结果: {result}")
            return self._format_tool_result(result)
        except Exception as e:
            error_msg = f"工具调用失败: {str(e)}"
            logger.error(error_msg)
            return f"Error: {str(e)}"

    @staticmethod
    def _parse_tool_arguments(raw_arguments) -> dict:
        """Decode the model's JSON tool arguments into a dict.

        ``None``, empty, or whitespace-only input means "no arguments" and
        yields ``{}``.

        Raises:
            ValueError: If the text is not valid JSON (``json.JSONDecodeError``)
                or does not decode to a JSON object.
        """
        if not raw_arguments or not raw_arguments.strip():
            return {}
        args = json.loads(raw_arguments)
        if not isinstance(args, dict):
            raise ValueError(
                f"tool arguments must be a JSON object, got {type(args).__name__}"
            )
        return args

    @staticmethod
    def _format_tool_result(result) -> str:
        """Flatten a tool result into the text sent back to the model.

        - ``str`` is returned as-is; ``bytes`` are decoded as UTF-8, with
          invalid bytes replaced.
        - An object whose ``content`` is a list or tuple contributes each
          item's ``text`` (or ``str(item)`` when an item has no string
          ``text``), joined with newlines. The result is prefixed with
          ``Error: `` when the object's ``isError`` is truthy.
        - Anything else falls back to ``str(result)``.
        """
        if isinstance(result, bytes):
            return result.decode("utf-8", errors="replace")
        if isinstance(result, str):
            return result

        content = getattr(result, "content", None)
        if not isinstance(content, (list, tuple)):
            return str(result)

        parts = []
        for item in content:
            text = getattr(item, "text", None)
            parts.append(text if isinstance(text, str) else str(item))
        joined = "\n".join(parts)
        if getattr(result, "isError", False):
            return f"Error: {joined}"
        return joined

    async def chat_loop(self):
        """Read queries from stdin and print the answers.

        Stops on ``quit`` (case-insensitive) or when stdin is closed (EOF).
        Blank input is ignored. Errors raised while answering a query are
        logged and printed, and the loop continues.
        """
        print("\nMCP 客户端已启动!")
        print("输入你的问题,或输入 'quit' 退出。")

        while True:
            try:
                query = input("\nQuery: ").strip()
            except EOFError:
                # stdin closed (Ctrl-D / exhausted pipe): input() would raise
                # again on every iteration, so stop instead of spinning.
                break

            if query.lower() == "quit":
                break
            if not query:
                continue

            try:
                response = await self.process_query(query)
                print("\n" + response)
            except Exception as e:
                logger.exception("聊天循环中出错")
                print(f"\n出错了: {str(e)}")

    async def cleanup(self):
        """Close every context entered on the exit stack (MCP session, stdio transport)."""
        await self.exit_stack.aclose()
async def main(argv: Optional[list[str]] = None):
    """Connect to the MCP server named on the command line and start chatting.

    Args:
        argv: Command-line arguments including the program name; defaults to
            ``sys.argv``. ``argv[1]`` is the server script path.

    Exits with status 1 (usage printed to stderr) when no script path is given.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) < 2:
        print("用法: python client.py <服务器脚本路径>", file=sys.stderr)
        sys.exit(1)

    client = MCPClient()
    try:
        await client.connect_to_server(argv[1])
        await client.chat_loop()
    finally:
        # Always release the session/subprocess, even if connecting failed.
        await client.cleanup()
# Script entry point. `sys` is imported at module top so main() also works
# when this module is imported rather than executed.
if __name__ == "__main__":
    asyncio.run(main())
|