#!/usr/bin/env python3
"""
Ollama API 全过程测试脚本
用于检测thinking模式和工作状态
"""

import asyncio
import json
import time
from typing import Any, Callable, Dict, List, Optional, Tuple

import httpx

class OllamaTester:
    """End-to-end checks for an Ollama server.

    Uses the OpenAI-compatible endpoints under ``base_url`` (``/models``,
    ``/chat/completions``) and Ollama's native ``/api/*`` endpoints on the
    same host. Each ``test_*`` coroutine prints a human-readable report,
    appends one record to ``self.test_results`` and returns ``True`` on
    success or ``False`` on failure; exceptions are caught and recorded.
    """

    def __init__(
        self,
        base_url: str = "http://localhost:11435/v1",
        timeout: float = 120.0,
        api_key: str = "ollama",
    ):
        """Create the shared async HTTP client.

        Args:
            base_url: root of the OpenAI-compatible API, normally ending in
                ``/v1``.
            timeout: request timeout in seconds.
            api_key: sent as a Bearer token (presumably ignored by a local
                Ollama, but expected by OpenAI-style proxies).
        """
        self.client = httpx.AsyncClient(
            base_url=base_url,
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {api_key}",
            },
            timeout=timeout,
        )
        # One record per executed test; dumped to JSON by run_all_tests().
        self.test_results: List[Dict[str, Any]] = []

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _native_url(self, endpoint: str):
        """Return an absolute URL for a native Ollama endpoint (e.g. ``/api/show``).

        The native API lives at the server root, not under ``/v1``. Passing a
        relative path to the client would resolve it to ``/v1/api/...``.
        """
        url = self.client.base_url
        prefix = url.path.rstrip("/")
        if prefix.endswith("/v1"):
            prefix = prefix[: -len("/v1")]
        return url.copy_with(path=prefix + endpoint)

    @staticmethod
    def _model_names(listing: Dict[str, Any]) -> List[str]:
        """Extract model names from a model-listing response.

        Supports the OpenAI-compatible shape (``{"data": [{"id": ...}]}``) and
        the native ``/api/tags`` shape (``{"models": [{"name": ...}]}``).
        """
        if "data" in listing:
            return [model.get("id", "") for model in listing.get("data") or []]
        return [model.get("name", "") for model in listing.get("models") or []]

    @staticmethod
    def _parse_stream_line(line: str) -> Any:
        """Decode one line of a streamed response.

        Accepts SSE lines (``data: {...}``) and bare JSON lines (NDJSON).

        Returns:
            The decoded JSON value, or ``None`` for blank lines, SSE comments
            and non-data fields, and the ``[DONE]`` sentinel.

        Raises:
            json.JSONDecodeError: if the payload is not valid JSON.
        """
        line = line.strip()
        if not line or line.startswith(":"):
            return None
        if line.startswith("data:"):
            line = line[len("data:"):].strip()
        elif line.startswith(("event:", "id:", "retry:")):
            return None
        if not line or line == "[DONE]":
            return None
        return json.loads(line)

    @staticmethod
    def _extract_delta(chunk: Any) -> Tuple[str, str]:
        """Return ``(reasoning, content)`` text carried by one response chunk.

        Handles OpenAI-compatible chunks (``choices[0].delta`` when streaming,
        ``choices[0].message`` otherwise) and native Ollama chunks
        (``message``). Servers name the reasoning field differently, so
        ``reasoning``, ``reasoning_content`` and ``thinking`` are all
        accepted. Missing or ``null`` fields yield empty strings.
        """
        if not isinstance(chunk, dict):
            return "", ""
        choices = chunk.get("choices")
        if choices:
            first = choices[0] if isinstance(choices[0], dict) else {}
            part = first.get("delta") or first.get("message") or {}
        else:
            part = chunk.get("message") or {}
        reasoning = (
            part.get("reasoning")
            or part.get("reasoning_content")
            or part.get("thinking")
            or ""
        )
        return reasoning, part.get("content") or ""

    @staticmethod
    def _split_think(text: str) -> Tuple[str, str]:
        """Separate an inline ``<think>...</think>`` block from the answer.

        Returns ``(thinking, answer)``, both stripped. If there is no think
        tag at all, thinking is empty. An unterminated block (e.g. the stream
        was cut off mid-thought) counts entirely as thinking. A closing tag
        with no opening tag treats everything before it as thinking.
        """
        before, open_tag, after = text.partition("<think>")
        if not open_tag:
            head, close_tag, tail = text.partition("</think>")
            if close_tag:
                return head.strip(), tail.strip()
            return "", text.strip()
        thinking, close_tag, rest = after.partition("</think>")
        if not close_tag:
            return thinking.strip(), before.strip()
        return thinking.strip(), (before + rest).strip()

    @staticmethod
    def _inside_think(text: str) -> bool:
        """Return True if ``text`` has an opened ``<think>`` block not yet closed."""
        return text.rfind("<think>") > text.rfind("</think>")

    async def _stream_chat(
        self,
        payload: Dict[str, Any],
        on_chunk: Optional[Callable[[int, str, bool], None]] = None,
        verbose_errors: bool = True,
    ) -> Dict[str, Any]:
        """Send a streaming chat request and collect the streamed deltas.

        Uses ``client.stream`` so lines are consumed as they arrive; a plain
        ``client.post`` would buffer the whole body before returning.

        Args:
            payload: request body for ``/chat/completions`` (should set
                ``"stream": True``).
            on_chunk: optional callback ``(chunk_index, text, is_thinking)``
                invoked for every non-empty reasoning or content delta.
            verbose_errors: print undecodable lines instead of skipping them
                silently.

        Returns:
            dict with ``content`` (raw concatenated content, including any
            inline ``<think>`` block), ``reasoning`` (text from a separate
            reasoning field), ``thinking`` (both sources combined),
            ``answer`` (content with the think block removed) and
            ``chunk_count`` (number of decoded JSON chunks).

        Raises:
            httpx.HTTPStatusError: on a non-2xx status.
            RuntimeError: if the stream carries an ``error`` object.
        """
        content = ""
        reasoning = ""
        chunk_count = 0

        async with self.client.stream("POST", "/chat/completions", json=payload) as response:
            response.raise_for_status()
            async for line in response.aiter_lines():
                try:
                    chunk = self._parse_stream_line(line)
                except json.JSONDecodeError as e:
                    if verbose_errors:
                        print(f"❌ JSON解析错误: {e}, 原始数据: {line}")
                    continue
                if chunk is None:
                    continue
                if isinstance(chunk, dict) and chunk.get("error"):
                    raise RuntimeError(f"服务器返回错误: {chunk['error']}")

                chunk_count += 1
                reasoning_delta, content_delta = self._extract_delta(chunk)
                if reasoning_delta:
                    reasoning += reasoning_delta
                    if on_chunk:
                        on_chunk(chunk_count, reasoning_delta, True)
                if content_delta:
                    # A delta is "thinking" if a <think> block was open before
                    # it or is open after it (covers the tag tokens themselves).
                    was_thinking = self._inside_think(content)
                    content += content_delta
                    if on_chunk:
                        on_chunk(
                            chunk_count,
                            content_delta,
                            was_thinking or self._inside_think(content),
                        )

        inline_thinking, answer = self._split_think(content)
        thinking = "\n".join(
            part for part in (reasoning.strip(), inline_thinking) if part
        )
        return {
            "content": content,
            "reasoning": reasoning,
            "thinking": thinking,
            "answer": answer,
            "chunk_count": chunk_count,
        }

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    async def test_connection(self):
        """Check that the server answers and list the available models."""
        print("=" * 60)
        print("1. 测试Ollama连接状态")
        print("=" * 60)

        try:
            response = await self.client.get("/models")
            response.raise_for_status()
            models = response.json()

            print("✅ 连接成功!")
            print(f"📊 可用模型: {self._model_names(models)}")
            self.test_results.append({"test": "connection", "status": "success", "models": models})
            return True

        except Exception as e:
            print(f"❌ 连接失败: {e}")
            self.test_results.append({"test": "connection", "status": "failed", "error": str(e)})
            return False

    async def test_model_info(self, model_name: str = "qwen3:0.6b"):
        """Fetch model metadata from the native ``/api/show`` endpoint."""
        print("\n" + "=" * 60)
        print("2. 测试模型信息")
        print("=" * 60)

        try:
            # Native API: must be addressed at the server root, not under /v1.
            # Both keys are sent because older servers read "name" and newer
            # ones read "model".
            response = await self.client.post(
                self._native_url("/api/show"),
                json={"model": model_name, "name": model_name},
            )
            response.raise_for_status()
            model_info = response.json()

            print(f"✅ 模型 '{model_name}' 信息获取成功!")
            print(f"📝 模板: {str(model_info.get('template', ''))[:100]}...")
            print(f"📋 参数: {str(model_info.get('parameters', ''))[:100]}...")
            print(f"📦 模型文件: {str(model_info.get('modelfile', ''))[:100]}...")

            self.test_results.append({
                "test": "model_info",
                "status": "success",
                "model": model_name,
                "info": model_info,
            })
            return True

        except Exception as e:
            print(f"❌ 模型信息获取失败: {e}")
            self.test_results.append({
                "test": "model_info",
                "status": "failed",
                "error": str(e),
            })
            return False

    async def test_simple_chat(self, model_name: str = "qwen3:0.6b"):
        """Send one non-streaming chat request and print the reply."""
        print("\n" + "=" * 60)
        print("3. 测试简单聊天（非流式）")
        print("=" * 60)

        test_prompt = "你好，请简单介绍一下你自己 /no_think"

        try:
            payload = {
                "model": model_name,
                "messages": [{"role": "user", "content": test_prompt}],
                "stream": False,
                # OpenAI-compatible endpoints read sampling params at the top
                # level; an "options" object is not part of that schema.
                "temperature": 0.1,
            }

            print(f"📤 发送请求: {test_prompt}")
            print("⏳ 等待响应...")

            start_time = time.perf_counter()
            response = await self.client.post("/chat/completions", json=payload)
            response.raise_for_status()
            elapsed = time.perf_counter() - start_time

            data = response.json()
            # KeyError/IndexError on a malformed body is reported as a failure below.
            message = data["choices"][0]["message"]
            result = (message.get("content") or "").strip()

            print(f"✅ 响应成功 (耗时: {elapsed:.2f}s)")
            print(f"📥 完整响应: {json.dumps(data, ensure_ascii=False, indent=2)}")
            print(f"💬 回复内容: {result}")

            self.test_results.append({
                "test": "simple_chat",
                "status": "success",
                "time": elapsed,
                "response": data,
            })
            return True

        except Exception as e:
            print(f"❌ 聊天测试失败: {e}")
            self.test_results.append({
                "test": "simple_chat",
                "status": "failed",
                "error": str(e),
            })
            return False

    async def test_streaming_chat(self, model_name: str = "qwen3:0.6b"):
        """Stream a reasoning-style prompt and report whether thinking output appears.

        Thinking is detected from explicit signals only: a separate reasoning
        field in the streamed deltas, or a non-empty inline ``<think>`` block.
        """
        print("\n" + "=" * 60)
        print("4. 测试流式聊天（检测thinking模式）")
        print("=" * 60)

        test_prompt = "请思考一下这个问题：2+2等于多少？然后给出答案"

        def show_chunk(index: int, text: str, is_thinking: bool) -> None:
            if is_thinking:
                print(f"🧠 THINKING: {text}")
            else:
                print(f"📝 CHUNK {index}: {text}")

        try:
            payload = {
                "model": model_name,
                "messages": [{"role": "user", "content": test_prompt}],
                "stream": True,
                "temperature": 0.1,
            }

            print(f"📤 发送请求: {test_prompt}")
            print("🔄 开始流式响应...")
            print("\n📊 流式响应 chunks:")
            print("-" * 40)

            start_time = time.perf_counter()
            stream = await self._stream_chat(payload, on_chunk=show_chunk)
            elapsed = time.perf_counter() - start_time

            full_response = stream["content"]
            thinking_content = stream["thinking"]
            thinking_detected = bool(thinking_content)

            print("\n" + "-" * 40)
            print(f"✅ 流式响应完成 (chunks: {stream['chunk_count']}, 耗时: {elapsed:.2f}s)")
            print(f"💬 完整回复: {full_response}")

            if thinking_detected:
                print("🎯 检测到THINKING模式!")
                print(f"🧠 Thinking内容: {thinking_content}")
            else:
                print("🎯 未检测到THINKING模式")

            self.test_results.append({
                "test": "streaming_chat",
                "status": "success",
                "thinking_detected": thinking_detected,
                "thinking_content": thinking_content,
                "chunk_count": stream["chunk_count"],
                "time": elapsed,
                "full_response": full_response,
                "answer": stream["answer"],
            })
            return True

        except Exception as e:
            print(f"❌ 流式聊天测试失败: {e}")
            self.test_results.append({
                "test": "streaming_chat",
                "status": "failed",
                "error": str(e),
            })
            return False

    async def test_with_different_prompts(self, model_name: str = "qwen3:0.6b"):
        """Run several prompts that encourage or discourage thinking and compare results.

        Returns True if at least one prompt completed successfully.
        """
        print("\n" + "=" * 60)
        print("5. 多种prompt测试thinking模式")
        print("=" * 60)

        test_cases = [
            {"prompt": "直接回答：2+2=?", "description": "直接指令"},
            {"prompt": "请思考后回答：2+2=?", "description": "思考指令"},
            {"prompt": "Think step by step: 2+2=?", "description": "英文思考指令"},
            {"prompt": "无需思考，直接回答：2+2=?", "description": "禁用思考指令"},
        ]

        results = []

        for i, test_case in enumerate(test_cases, 1):
            print(f"\n🔹 测试 {i}/{len(test_cases)}: {test_case['description']}")
            print(f"   Prompt: {test_case['prompt']}")

            try:
                payload = {
                    "model": model_name,
                    "messages": [{"role": "user", "content": test_case['prompt']}],
                    "stream": True,
                    "temperature": 0.1,
                }

                # Undecodable lines are skipped quietly here to keep the
                # per-prompt summary compact.
                stream = await self._stream_chat(payload, verbose_errors=False)
                thinking_detected = bool(stream["thinking"])

                result = {
                    "test_case": test_case['description'],
                    "prompt": test_case['prompt'],
                    "thinking_detected": thinking_detected,
                    "thinking_content": stream["thinking"],
                    "response": stream["content"],
                    "answer": stream["answer"],
                    "status": "success",
                }

                print(f"   ✅ 结果: {'有Thinking' if thinking_detected else '无Thinking'}")
                print(f"   💬 回复: {stream['answer'][:50]}...")

            except Exception as e:
                result = {
                    "test_case": test_case['description'],
                    "prompt": test_case['prompt'],
                    "thinking_detected": False,
                    "error": str(e),
                    "status": "failed",
                }
                print(f"   ❌ 失败: {e}")

            results.append(result)
            await asyncio.sleep(1)  # avoid firing requests back-to-back

        self.test_results.append({
            "test": "multiple_prompts",
            "results": results,
        })

        return any(result['status'] == 'success' for result in results)

    async def run_all_tests(self, results_path: str = "ollama_test_results.json"):
        """Run every test in order, print a summary and save detailed results.

        Args:
            results_path: JSON file that receives ``self.test_results``.

        Returns:
            True only if every test returned True.
        """
        print("🚀 开始Ollama API全面测试")
        print("=" * 60)

        tests = [
            self.test_connection,
            self.test_model_info,
            self.test_simple_chat,
            self.test_streaming_chat,
            self.test_with_different_prompts,
        ]

        results = []
        for test in tests:
            success = await test()
            results.append(success)
            await asyncio.sleep(1)  # pause between tests

        print("\n" + "=" * 60)
        print("📊 测试报告摘要")
        print("=" * 60)

        successful_tests = sum(results)
        total_tests = len(results)

        print(f"✅ 成功测试: {successful_tests}/{total_tests}")
        print(f"📈 成功率: {(successful_tests/total_tests)*100:.1f}%")

        with open(results_path, "w", encoding="utf-8") as f:
            json.dump(self.test_results, f, ensure_ascii=False, indent=2)

        print(f"📁 详细结果已保存到: {results_path}")

        return all(results)

    async def close(self):
        """Close the underlying HTTP client."""
        await self.client.aclose()

async def main():
    """Run the whole test suite once and report the overall verdict.

    Returns True when every test passed. Any exception escaping the suite
    is printed and turned into False; the HTTP client is closed either way.
    """
    tester = OllamaTester()

    try:
        all_passed = await tester.run_all_tests()
        verdict = (
            "\n🎉 所有测试通过！Ollama工作正常"
            if all_passed
            else "\n⚠️  部分测试失败，请检查Ollama配置"
        )
        print(verdict)
        return all_passed
    except Exception as exc:
        print(f"\n❌ 测试过程中出现错误: {exc}")
        return False
    finally:
        await tester.close()

if __name__ == "__main__":
    # Overview of the steps main() runs; the numbering matches the section
    # headers printed by the OllamaTester.test_* methods.
    print("Ollama API测试脚本")
    print("这个脚本将测试以下内容:")
    for item in (
        "1. ✅ 连接状态",
        "2. 📋 模型信息",
        "3. 💬 简单聊天",
        "4. 🔄 流式聊天（检测thinking）",
        "5. 🧪 多种prompt测试",
    ):
        print(item)
    print()

    result = asyncio.run(main())
    # SystemExit instead of the site-module exit(), which is meant for the
    # interactive shell and is unavailable when Python runs with -S.
    raise SystemExit(0 if result else 1)
