
alphaear-predictor

Market prediction skill using Kronos. Use when user needs finance market time-series forecasting or news-aware finance market adjustments.

Packaged view

This page reorganizes the original catalog entry around fit, installability, and workflow context first. The original raw source lives below.

Stars: 320.

Hot score: 99.

Updated: March 20, 2026.

Overall rating: C (3.3).

Composite score: 3.3.

Best-practice grade: B (71.9).

Install command

npx @skill-hub/cli install rkiding-awesome-finance-skills-alphaear-predictor

Repository

RKiding/Awesome-finance-skills

Skill path: skills/alphaear-predictor


Best for

Primary workflow: Ship Full Stack.

Technical facets: Full Stack.

Target audience: everyone.

License: Unknown.

Original source

Catalog source: SkillHub Club.

Repository owner: RKiding.

This is still a mirrored public skill entry. Review the repository before installing into production workflows.

What it helps with

  • Install alphaear-predictor into Claude Code, Codex CLI, Gemini CLI, or OpenCode workflows
  • Review https://github.com/RKiding/Awesome-finance-skills before adding alphaear-predictor to shared team environments
  • Use alphaear-predictor for development workflows

Works across

Claude Code, Codex CLI, Gemini CLI, OpenCode

Favorites: 0.

Sub-skills: 0.

Aggregator: No.

Original source / Raw SKILL.md

---
name: alphaear-predictor
description: Market prediction skill using Kronos. Use when user needs finance market time-series forecasting or news-aware finance market adjustments.
---

# AlphaEar Predictor Skill

## Overview

This skill uses the Kronos model (via `KronosPredictorUtility`) to produce time-series forecasts and to adjust those forecasts based on news sentiment.

## Capabilities

### 1. Forecast Market Trends

**Workflow:**
1.  **Generate Base Forecast**: Use `scripts/kronos_predictor.py` (via `KronosPredictorUtility`) to generate the technical/quantitative forecast.
2.  **Adjust Forecast (Agentic)**: Use the **Forecast Adjustment Prompt** in `references/PROMPTS.md` to subjectively adjust the numbers based on latest news/logic.

**Key Tools:**
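-   `KronosPredictorUtility.get_base_forecast(df, lookback=20, pred_len=5, news_text=None)`: Returns `List[KLinePoint]`. `df` needs `date`, `open`, `high`, `low`, `close`, and `volume` columns. The result is an empty list if the model failed to load or `df` has fewer than `lookback` rows.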

**Example Usage (Python):**

```python
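from scripts.forecast_agent import ForecastUtils
from scripts.utils.database_manager import DatabaseManager

db = DatabaseManager()
utils = ForecastUtils(db)  # wraps the KronosPredictorUtility singleton

# Base forecast: up to 20 bars of history, 5 business days ahead
forecast = utils.get_base_forecast("600519", lookback=20, pred_len=5)
print(forecast)  # list of KLinePoint; None or [] when history or the model is unavailable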
```


## Configuration

This skill requires the **Kronos** model and an embedding model.

1.  **Kronos Model**:
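    -   Ensure an `exports/models` directory exists. The loader in `scripts/kronos_predictor.py` resolves it relative to `scripts/predictor/`, not the project root.
    -   Place trained news projector weights (e.g., `kronos_news_v1.pt`) in `exports/models/`. If several `*.pt` files are present, the newest one by `os.path.getctime` is loaded.
    -   Otherwise the skill falls back to the base model (`NeoQuasar/Kronos-base`), which is downloaded automatically when it is not cached locally.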

2.  **Environment Variables**:
    -   `EMBEDDING_MODEL`: Path or name of the embedding model (default: `sentence-transformers/all-MiniLM-L6-v2`).
    -   `KRONOS_MODEL_PATH`: Optional path to override model loading. The copy of `scripts/kronos_predictor.py` included on this page does not read this variable.
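
The configuration above can be sketched as a small resolver. This is a minimal illustration, not the skill's own code: `resolve_settings` and `latest_checkpoint` are hypothetical helper names. The variable names and the default embedding model come from the list above, and the newest-file rule mirrors the `glob` plus `os.path.getctime` logic in `scripts/kronos_predictor.py`. How `KRONOS_MODEL_PATH` is consumed is an assumption, so the sketch only reads it.

```python
import glob
import os
from typing import Optional


def resolve_settings(env=None) -> dict:
    """Read the two documented variables, applying the documented default."""
    env = os.environ if env is None else env
    return {
        "embedding_model": env.get(
            "EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"
        ),
        # Optional override; None means "no override was configured".
        "kronos_model_path": env.get("KRONOS_MODEL_PATH"),
    }


def latest_checkpoint(models_dir: str) -> Optional[str]:
    """Return the newest *.pt file by getctime, or None if there is none."""
    candidates = glob.glob(os.path.join(models_dir, "*.pt"))
    return max(candidates, key=os.path.getctime) if candidates else None


print(resolve_settings())
print(latest_checkpoint("exports/models"))  # None when the directory is empty or missing
```

Passing a plain dict as `env` makes the resolver easy to exercise without touching the real process environment.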

## Dependencies

-   `torch`
-   `transformers`
-   `sentence-transformers`
-   `pandas`
-   `numpy`
-   `scikit-learn`
-   `loguru`, `python-dotenv`, `huggingface_hub`, `tqdm` (imported by the scripts included below)


---

## Referenced Files

> The following files are referenced in this skill and included for context.

### scripts/kronos_predictor.py

```python
import torch
import pandas as pd
import numpy as np
from datetime import datetime
from typing import List, Optional
from loguru import logger
from pandas.tseries.offsets import BusinessDay
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# Fix for Kronos internal imports
import sys
import os
KRONOS_DIR = os.path.join(os.path.dirname(__file__), 'predictor')
if KRONOS_DIR not in sys.path:
    sys.path.append(KRONOS_DIR)

import glob
from sentence_transformers import SentenceTransformer

from .predictor.model import Kronos, KronosTokenizer, KronosPredictor
from .schema.models import KLinePoint

class KronosPredictorUtility:
    """
    Kronos time-series forecasting utility.
    Handles model loading, inference, and data-structure conversion.
    """
    _instance = None
    _predictor = None

    def __new__(cls, *args, **kwargs):
        if not cls._instance:
            cls._instance = super(KronosPredictorUtility, cls).__new__(cls)
        return cls._instance

    def __init__(self, device: Optional[str] = None):
        if self._predictor is not None:
            return
            
        try:
            if not device:
                device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
            
            logger.info(f"🔮 Loading Kronos Model on {device}...")
            
            # 1. Load Embedder (SentenceTransformer)
            model_name = os.getenv('EMBEDDING_MODEL', 'sentence-transformers/all-MiniLM-L6-v2')  # Match training
            try:
                self.embedder = SentenceTransformer(model_name, device=device, local_files_only=True)
            except Exception:
                logger.warning(f"⚠️ Local embedder {model_name} not found. Downloading...")
                self.embedder = SentenceTransformer(model_name, device=device)

            # 2. Load Kronos Base
            try:
                tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base", local_files_only=True)
                model = Kronos.from_pretrained("NeoQuasar/Kronos-base", local_files_only=True)
            except Exception:
                logger.warning("⚠️ Local Kronos cache not found. Attempting to download...")
                tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
                model = Kronos.from_pretrained("NeoQuasar/Kronos-base")
            
            # 3. Load Trained News Projector Weights
            # Check predictor/exports/models directory
            models_dir = os.path.join(KRONOS_DIR, "exports/models")
            model_files = glob.glob(os.path.join(models_dir, "*.pt"))
            
            if model_files:
                latest_model = max(model_files, key=os.path.getctime)
                logger.info(f"🔄 Loading trained news weights from {latest_model}...")
                try:
                    checkpoint = torch.load(latest_model, map_location=device)
                    # The checkpoint contains 'news_proj_state_dict'
                    if 'news_proj_state_dict' in checkpoint:
                        if not hasattr(model, 'news_proj') or model.news_proj is None:
                            import torch.nn as nn
                            news_dim = checkpoint.get('news_dim', 384)
                            model.news_proj = nn.Linear(news_dim, model.d_model).to(device)
                        
                        model.news_proj.load_state_dict(checkpoint['news_proj_state_dict'])
                        logger.success("✅ News-Aware Projection Layer loaded!")
                        self.has_news_model = True
                    else:
                        logger.warning("⚠️ Checkpoint found but missing 'news_proj_state_dict'. Using base model.")
                        self.has_news_model = False
                except Exception as e:
                    logger.error(f"❌ Failed to load trained weights: {e}. Using base model.")
                    self.has_news_model = False
            else:
                logger.info("ℹ️ No trained news models found. Using base model.")
                self.has_news_model = False
            
            tokenizer = tokenizer.to(device)
            model = model.to(device)
            
            self._predictor = KronosPredictor(model, tokenizer, device=device, max_context=512)
            logger.info("✅ Kronos Model loaded successfully.")
        except Exception as e:
            logger.error(f"❌ Failed to load Kronos Model: {e}")
            self._predictor = None
            self.has_news_model = False

    def get_base_forecast(self, df: pd.DataFrame, lookback: int = 20, pred_len: int = 5, news_text: Optional[str] = None) -> List[KLinePoint]:
        """
        Generate the raw (unadjusted) model forecast.
        """
        if self._predictor is None:
            logger.error("Predictor not initialized.")
            return []

        if len(df) < lookback:
            logger.warning(f"Insufficient historical data ({len(df)}) for lookback ({lookback}).")
            return []

        # Take the last `lookback` rows
        x_df = df.iloc[-lookback:].copy()
        x_timestamp = pd.to_datetime(x_df['date']) # Ensure datetime
        last_date = x_timestamp.iloc[-1]
        
        # Build the future (business-day) timestamps
        future_dates = pd.date_range(start=last_date + BusinessDay(1), periods=pred_len, freq='B')
        y_timestamp = pd.Series(future_dates)

        # Embedding News if available
        news_emb = None
        if news_text and getattr(self, 'has_news_model', False) and hasattr(self, 'embedder'):
            try:
                # Truncate to avoid too long text
                emb = self.embedder.encode(news_text[:1000])
                news_emb = emb # KronosPredictor expects numpy array or tensor
            except Exception as e:
                logger.error(f"Failed to encode news: {e}")

        try:
            # Columns required for prediction
            cols = ['open', 'high', 'low', 'close', 'volume']
            pred_df = self._predictor.predict(
                df=x_df[cols],
                x_timestamp=x_timestamp,
                y_timestamp=y_timestamp,
                pred_len=pred_len,
                T=1.0, 
                top_p=0.9, 
                sample_count=1,
                verbose=False,
                news_emb=news_emb
            )
            
            # Convert to KLinePoint
            results = []
            for date, row in pred_df.iterrows():
                results.append(KLinePoint(
                    date=date.strftime("%Y-%m-%d"),
                    open=float(row['open']),
                    high=float(row['high']),
                    low=float(row['low']),
                    close=float(row['close']),
                    volume=float(row['volume'])
                ))
            return results
        except Exception as e:
            logger.error(f"Forecast generation failed: {e}")
            return []

# Singleton instance for easy access
# Usage: predictor = KronosPredictorUtility()

```
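
`get_base_forecast` builds its prediction timestamps with `pd.date_range(..., freq='B')`, which skips Saturdays and Sundays but knows nothing about exchange holidays. The same calendar rule can be sketched with the standard library alone. `next_business_days` is a hypothetical helper for illustration and not part of the skill.

```python
from datetime import date, timedelta
from typing import List


def next_business_days(last_date: date, count: int) -> List[date]:
    """Return the `count` weekdays (Mon-Fri) that follow `last_date`."""
    days: List[date] = []
    current = last_date
    while len(days) < count:
        current += timedelta(days=1)
        if current.weekday() < 5:  # 0-4 are Monday-Friday
            days.append(current)
    return days


# Last observed bar on Friday 2026-03-20, five-day horizon
for d in next_business_days(date(2026, 3, 20), 5):
    print(d.isoformat())
```

Forecast dates that fall on market holidays would need filtering against an exchange calendar, which this sketch does not attempt.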

### references/PROMPTS.md

```markdown
# AlphaEar Predictor Prompts

## Forecast Adjustment (Analyst)

**Prompt:**

```markdown
You are a senior quantitative strategy analyst.
Your task is to subjectively/logically adjust the given [Kronos Model Forecast] based on the [Latest Intelligence/News Context].

Ticker: {ticker}

【Kronos Base Forecast (OHLC)】:
{forecast_str}

【Latest Intelligence Context】:
{news_context}

**Adjustment Principles:**
1. Base forecast is technical-only.
2. Context may contain a "Quantitative Correction" from a news-aware model. **Highly respect** this unless logic is flawed.
3. Use qualitative analysis (news logic) to verify or fine-tune.
4. If no quantitative correction exists, verify trend manually against news sentiment.

**Output (Strict JSON):**
```json
{
  "adjusted_forecast": [
    {
      "date": "YYYY-MM-DD",
      "open": <float>,
      "high": <float>,
      "low": <float>,
      "close": <float>,
      "volume": <float>
    },
    ...
  ],
  "rationale": "Detailed logic..."
}
```
Ensure same number of data points as base forecast.
```

```
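
Using this prompt means filling the `{ticker}`, `{forecast_str}`, and `{news_context}` placeholders, then checking that the reply contains as many points as the base forecast. The following is a minimal sketch under stated assumptions: `Point` is a stand-in for `KLinePoint` (whose definition is not shown here), and `format_forecast`, `fill_prompt`, `check_adjusted`, and the one-line-per-day layout of `forecast_str` are illustrative choices rather than the skill's own format. The sample prices are made up.

```python
from dataclasses import asdict, dataclass
from typing import List

# Header excerpt of the prompt only. The full template also contains literal
# JSON braces, which str.format would misread as placeholders.
PROMPT_HEADER = (
    "Ticker: {ticker}\n\n"
    "【Kronos Base Forecast (OHLC)】:\n{forecast_str}\n\n"
    "【Latest Intelligence Context】:\n{news_context}\n"
)
REQUIRED_KEYS = {"date", "open", "high", "low", "close", "volume"}


@dataclass
class Point:  # hypothetical stand-in for KLinePoint
    date: str
    open: float
    high: float
    low: float
    close: float
    volume: float


def format_forecast(points: List[Point]) -> str:
    """Render one line per forecast day for the forecast_str slot."""
    return "\n".join(
        f"{p.date}: O={p.open:.2f} H={p.high:.2f} L={p.low:.2f} "
        f"C={p.close:.2f} V={p.volume:.0f}"
        for p in points
    )


def fill_prompt(ticker: str, points: List[Point], news_context: str) -> str:
    """Substitute the three placeholders in the header excerpt."""
    return PROMPT_HEADER.format(
        ticker=ticker,
        forecast_str=format_forecast(points),
        news_context=news_context,
    )


def check_adjusted(base: List[Point], reply: dict) -> List[str]:
    """Return the problems found in a parsed reply; an empty list means it passed."""
    adjusted = reply.get("adjusted_forecast", [])
    problems = []
    if len(adjusted) != len(base):
        problems.append(f"expected {len(base)} points, got {len(adjusted)}")
    for i, row in enumerate(adjusted):
        missing = REQUIRED_KEYS - set(row)
        if missing:
            problems.append(f"point {i} is missing {sorted(missing)}")
    return problems


# Made-up sample values, for illustration only
base = [
    Point("2026-03-23", 10.0, 11.0, 9.5, 10.5, 1000.0),
    Point("2026-03-24", 10.5, 11.2, 10.1, 11.0, 1200.0),
]
print(fill_prompt("600519", base, "- Example headline: example summary"))
print(check_adjusted(base, {"adjusted_forecast": [asdict(p) for p in base]}))
```

Returning a list of problems instead of raising lets the calling agent decide whether to retry the prompt or fall back to the base forecast.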



---

## Skill Companion Files

> Additional files collected from the skill directory layout.

### scripts/forecast_agent.py

```python
import json
from typing import List, Optional, Dict, Any
from datetime import datetime
from loguru import logger
import pandas as pd

from .kronos_predictor import KronosPredictorUtility
from .utils.database_manager import DatabaseManager
from .schema.models import ForecastResult, KLinePoint, InvestmentSignal

class ForecastUtils:
    """
    Forecast helper utilities (ForecastUtils).
    Provides data preparation and base-model forecasting.
    LLM adjustment logic has been handed over to the Agent (see references/PROMPTS.md).
    """
    
    def __init__(self, db: DatabaseManager):
        self.db = db
        self.predictor_util = KronosPredictorUtility() # Singleton

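    def get_base_forecast(
        self,
        ticker: str,
        signals: Optional[List[Dict]] = None,
        lookback: int = 20,
        pred_len: int = 5,
    ) -> Optional[List[KLinePoint]]:
        """
        Get the base forecast (technical signal plus the news model's quantitative correction).
        The Agent should then apply the qualitative adjustment using the instructions in PROMPTS.md.
        """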
        logger.info(f"🔮 Generating base forecast for {ticker}...")
        
        # 1. Fetch historical data
        from .stock_tools import StockTools
        stock_tools = StockTools(self.db, auto_update=False)
        
        end_date = datetime.now().strftime("%Y-%m-%d")
        # Widen the date window so it contains enough trading days
        start_date = (datetime.now() - pd.Timedelta(days=max(lookback * 4, 90))).strftime("%Y-%m-%d")
        df = stock_tools.get_stock_price(ticker, start_date=start_date, end_date=end_date)

        if df.empty or len(df) < lookback:
            # Try force sync
            df = stock_tools.get_stock_price(ticker, start_date=start_date, end_date=end_date, force_sync=True)

        if df.empty:
            logger.warning(f"⚠️ No history data for {ticker}")
            return None

        effective_lookback = lookback
        if len(df) < lookback:
            if len(df) < 10:
                logger.warning(f"⚠️ Insufficient history for {ticker}")
                return None
            effective_lookback = len(df)

        # 2. Build the signal context
        signal_lines = []
        for s in (signals or []):
            try:
                title = s.get('title', '') if isinstance(s, dict) else getattr(s, 'title', '')
                summary = s.get('summary', '') if isinstance(s, dict) else getattr(s, 'summary', '')
                if title or summary:
                    signal_lines.append(f"- {title}: {summary}")
            except Exception:
                continue

        signals_context = "\n".join(signal_lines).strip()
        
        # 3. Model forecast (News-Adjusted if context exists)
        if signals_context:
            return self.predictor_util.get_base_forecast(df, lookback=effective_lookback, pred_len=pred_len, news_text=signals_context)
        else:
            return self.predictor_util.get_base_forecast(df, lookback=effective_lookback, pred_len=pred_len, news_text=None)

```
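
The data-preparation rules in `ForecastUtils.get_base_forecast` can be isolated into three small pure functions: how far back to fetch, how to shrink the lookback when history is short, and how signals become the `news_text` string. This sketch restates the logic above without the database or model. The function names are hypothetical, and `floor=10` is the minimum row count hard-coded in the source.

```python
from typing import Any, Iterable, Optional


def history_window_days(lookback: int) -> int:
    """Calendar days to fetch so that enough trading days are covered."""
    return max(lookback * 4, 90)


def effective_lookback(rows: int, lookback: int = 20, floor: int = 10) -> Optional[int]:
    """Shrink the window to the rows available; give up below `floor` rows."""
    if rows >= lookback:
        return lookback
    return rows if rows >= floor else None


def build_signals_context(signals: Optional[Iterable[Any]]) -> str:
    """Join title/summary pairs into a bullet list; accepts dicts or objects."""
    lines = []
    for s in signals or []:
        if isinstance(s, dict):
            title, summary = s.get("title", ""), s.get("summary", "")
        else:
            title, summary = getattr(s, "title", ""), getattr(s, "summary", "")
        if title or summary:  # skip signals that carry no text at all
            lines.append(f"- {title}: {summary}")
    return "\n".join(lines)


print(history_window_days(20))
print(effective_lookback(15))
print(build_signals_context([{"title": "Earnings", "summary": "beat estimates"}]))
```

An empty context string is what makes the source pass `news_text=None`, so the forecast stays purely technical when no signals are supplied.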

### scripts/json_utils.py

```python
import ast
import json
import re
from typing import Optional, Any
from loguru import logger

def _strip_comments(text: str) -> str:
    """
    Safely remove C-style comments (// and /* */) from JSON-like text,
    preserving strings (including URLs like http://).
    """
    result = []
    i = 0
    n = len(text)
    in_string = False
    escape = False
    
    while i < n:
        char = text[i]
        
        if in_string:
            if char == '\\':
                escape = not escape
            elif char == '"' and not escape:
                in_string = False
            else:
                escape = False
            result.append(char)
            i += 1
            continue
            
        # Not in string
        if char == '"':
            in_string = True
            result.append(char)
            i += 1
            continue
            
        # Check for // comment
        if i + 1 < n and text[i:i+2] == '//':
            i += 2
            while i < n and text[i] != '\n':
                i += 1
            continue
            
        # Check for /* comment
        if i + 1 < n and text[i:i+2] == '/*':
            i += 2
            while i + 1 < n and text[i:i+2] != '*/':
                i += 1
            i += 2
            continue
            
        result.append(char)
        i += 1
        
    return ''.join(result)

def extract_json(text: str) -> Optional[Any]:
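    """
    A more robust JSON extraction helper.
    Handles:
    1. Markdown code blocks (```json ... ```)
    2. Extra leading/trailing characters
    3. Multiple JSON objects in the same text (only the first is extracted)
    4. Simple JSON repairs (trailing commas, etc.)
    5. C-style comments (// and /* */)
    """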
    if not text:
        return None
    
    # 1. Strip obvious Markdown wrapping
    text = text.strip()
    
    # First try an exact match for ```json ... ``` or ```...```
    md_match = re.search(r'```(?:json)?\s*\n?(.*?)\n?```', text, re.DOTALL)
    if md_match:
        text = md_match.group(1).strip()
    elif text.startswith("```"):
        # Fallback: the text starts with ``` but there is no complete match
        text = re.sub(r'^```[a-z]*\n?', '', text)
        text = re.sub(r'\n?```\s*$', '', text)
    
    # 2. Find the first JSON opener, { or [
    start_brace = text.find('{')
    start_bracket = text.find('[')
    
    if start_brace == -1 and start_bracket == -1:
        return None
        
    start_idx = start_brace if (start_bracket == -1 or (start_brace != -1 and start_brace < start_bracket)) else start_bracket
    
    # 2.5 Preprocessing: repair some very common LLM mistakes
    potential_json = text[start_idx:].strip()
    
    # remove comments safely
    potential_json = _strip_comments(potential_json)
    
    # b. Repair keys missing the opening quote:  nodes": [  -> "nodes": [
    # Pattern: (whitespace or newline) word immediately followed by a quote and a colon
    potential_json = re.sub(r'([\{\,]\s*)([a-zA-Z_]\w*)\"\s*:', r'\1"\2":', potential_json)
    
    # c. Repair keys missing the closing quote:  "nodes: [ -> "nodes": [
    potential_json = re.sub(r'([\{\,]\s*)\"([a-zA-Z_]\w*)\s*:', r'\1"\2":', potential_json)

    # d. Repair keys with no quotes at all: nodes: [ -> "nodes": [
    # Restricted to positions after { or , to avoid matching content such as http://
    potential_json = re.sub(r'([\{\,]\s*)([a-zA-Z_]\w*)\s*:', r'\1"\2":', potential_json)
    
    # 3. Try parsing with raw_decode
    decoder = json.JSONDecoder()
    
    # First try a plain parse (the comment stripping and key repairs above have already run)
    try:
        obj = json.loads(potential_json)
        return obj
    except json.JSONDecodeError:
        pass
    
    # Simple preprocessing: drop trailing commas before a closing ] or }
    processed_json = re.sub(r',\s*([\]}])', r'\1', potential_json)
    
    try:
        obj, end_pos = decoder.raw_decode(processed_json)
        return obj
    except json.JSONDecodeError:
        pass
    
    # e. Repair unterminated string literals: remove literal newlines inside values
    # LLMs may emit real newlines inside string values, which makes the JSON invalid
    def fix_multiline_strings(s):
        # Simple strategy: replace newlines inside string values with spaces
        lines = s.split('\n')
        result = []
        in_string = False
        for line in lines:
            # Count the unescaped quotes
            quote_count = line.count('"') - line.count('\\"')
            if in_string:
                result[-1] += ' ' + line.strip()
            else:
                result.append(line)
            
            if quote_count % 2 == 1:
                in_string = not in_string
        return '\n'.join(result)
    
    fixed_json = fix_multiline_strings(processed_json)
    
    try:
        obj, end_pos = decoder.raw_decode(fixed_json)
        return obj
    except json.JSONDecodeError:
        try:
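            # 4. Handle single quotes (JSON requires double quotes, but LLMs often emit single quotes)
            # This is a simple substitution trick, only for structures like {'key': 'value'}
            # Note: it can corrupt string values that contain single quotes, so it is a late fallback
            fix_quotes = re.sub(r"'(.*?)':", r'"\1":', processed_json) # repair keys
            fix_quotes = re.sub(r":\s*'(.*?)'", r': "\1"', fix_quotes)   # repair simple values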
            obj, end_pos = decoder.raw_decode(fix_quotes)
            return obj
        except (json.JSONDecodeError, TypeError):
            try:
                # 5. Use ast.literal_eval as the last-resort fallback (handles Python dict syntax)
                # Extract the content of the first matching bracket pair
                # Find the matching { }
                stack = []
                for i, char in enumerate(potential_json):
                    if char == '{': stack.append('{')
                    elif char == '}':
                        if stack: stack.pop()
                        if not stack:
                            content = potential_json[:i+1]
                            return ast.literal_eval(content)
            except (ValueError, SyntaxError, MemoryError) as e:
                logger.warning(f"All JSON extraction attempts failed: {e}")
            except Exception as e:
                logger.error(f"Unexpected error during JSON extraction: {e}")
    
    return None

```
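
The core of `extract_json` is a three-step pipeline: unwrap a Markdown fence, drop trailing commas, then let `json.JSONDecoder.raw_decode` parse the first value and ignore any trailing text. The sketch below is a deliberately reduced version showing only those steps. `extract_first_json` is a hypothetical name, and the key-quoting, comment-stripping, and `ast.literal_eval` fallbacks of the real helper are left out.

```python
import json
import re
from typing import Any, Optional

FENCE = "`" * 3  # a Markdown code fence, built here to keep this example readable


def extract_first_json(text: str) -> Optional[Any]:
    """Return the first JSON object or array found in an LLM reply, else None."""
    fenced = re.search(FENCE + r"(?:json)?\s*(.*?)" + FENCE, text, re.DOTALL)
    if fenced:
        text = fenced.group(1)
    starts = [i for i in (text.find("{"), text.find("[")) if i != -1]
    if not starts:
        return None
    candidate = text[min(starts):]
    # Drop trailing commas before a closing ] or }. This can also alter a
    # string value that happens to contain ",}" and is a known limitation.
    candidate = re.sub(r",\s*([\]}])", r"\1", candidate)
    try:
        # raw_decode parses one value and ignores whatever text follows it.
        obj, _end = json.JSONDecoder().raw_decode(candidate)
        return obj
    except json.JSONDecodeError:
        return None


reply = (
    "Here is the result:\n" + FENCE + "json\n"
    '{"adjusted_forecast": [{"close": 10.5,},], "rationale": "ok",}\n'
    + FENCE + "\nLet me know if you need more."
)
print(extract_first_json(reply))
```

Using `raw_decode` rather than `json.loads` is what lets the parser tolerate explanatory text after the JSON value.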

### scripts/predictor/model/__init__.py

```python
from .kronos import KronosTokenizer, Kronos, KronosPredictor

model_dict = {
    'kronos_tokenizer': KronosTokenizer,
    'kronos': Kronos,
    'kronos_predictor': KronosPredictor
}


def get_model_class(model_name):
    if model_name in model_dict:
        return model_dict[model_name]
    else:
        print(f"Model {model_name} not found in model_dict")
        raise NotImplementedError


```

### scripts/predictor/model/kronos.py

```python
import numpy as np
import pandas as pd
import torch
from huggingface_hub import PyTorchModelHubMixin
import sys

from tqdm import trange

sys.path.append("../")
from model.module import *


class KronosTokenizer(nn.Module, PyTorchModelHubMixin):
    """
    KronosTokenizer module for tokenizing input data using a hybrid quantization approach.

    This tokenizer utilizes a combination of encoder and decoder Transformer blocks
    along with the Binary Spherical Quantization (BSQuantizer) to compress and decompress input data.

    Args:
           d_in (int): Input dimension.
           d_model (int): Model dimension.
           n_heads (int): Number of attention heads.
           ff_dim (int): Feed-forward dimension.
           n_enc_layers (int): Number of encoder layers.
           n_dec_layers (int): Number of decoder layers.
           ffn_dropout_p (float): Dropout probability for feed-forward networks.
           attn_dropout_p (float): Dropout probability for attention mechanisms.
           resid_dropout_p (float): Dropout probability for residual connections.
           s1_bits (int): Number of bits for the pre token in BSQuantizer.
           s2_bits (int): Number of bits for the post token in BSQuantizer.
           beta (float): Beta parameter for BSQuantizer.
           gamma0 (float): Gamma0 parameter for BSQuantizer.
           gamma (float): Gamma parameter for BSQuantizer.
           zeta (float): Zeta parameter for BSQuantizer.
           group_size (int): Group size parameter for BSQuantizer.

    """

    def __init__(self, d_in, d_model, n_heads, ff_dim, n_enc_layers, n_dec_layers, ffn_dropout_p, attn_dropout_p, resid_dropout_p, s1_bits, s2_bits, beta, gamma0, gamma, zeta, group_size):

        super().__init__()
        self.d_in = d_in
        self.d_model = d_model
        self.n_heads = n_heads
        self.ff_dim = ff_dim
        self.enc_layers = n_enc_layers
        self.dec_layers = n_dec_layers
        self.ffn_dropout_p = ffn_dropout_p
        self.attn_dropout_p = attn_dropout_p
        self.resid_dropout_p = resid_dropout_p

        self.s1_bits = s1_bits
        self.s2_bits = s2_bits
        self.codebook_dim = s1_bits + s2_bits # Total dimension of the codebook after quantization
        self.embed = nn.Linear(self.d_in, self.d_model)
        self.head = nn.Linear(self.d_model, self.d_in)

        # Encoder Transformer Blocks
        self.encoder = nn.ModuleList([
            TransformerBlock(self.d_model, self.n_heads, self.ff_dim, self.ffn_dropout_p, self.attn_dropout_p, self.resid_dropout_p)
            for _ in range(self.enc_layers - 1)
        ])
        # Decoder Transformer Blocks
        self.decoder = nn.ModuleList([
            TransformerBlock(self.d_model, self.n_heads, self.ff_dim, self.ffn_dropout_p, self.attn_dropout_p, self.resid_dropout_p)
            for _ in range(self.dec_layers - 1)
        ])
        self.quant_embed = nn.Linear(in_features=self.d_model, out_features=self.codebook_dim) # Linear layer before quantization
        self.post_quant_embed_pre = nn.Linear(in_features=self.s1_bits, out_features=self.d_model) # Linear layer after quantization (pre part - s1 bits)
        self.post_quant_embed = nn.Linear(in_features=self.codebook_dim, out_features=self.d_model) # Linear layer after quantization (full codebook)
        self.tokenizer = BSQuantizer(self.s1_bits, self.s2_bits, beta, gamma0, gamma, zeta, group_size) # BSQuantizer module

    def forward(self, x):
        """
        Forward pass of the KronosTokenizer.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_in).

        Returns:
            tuple: A tuple containing:
                - tuple: (z_pre, z) - Reconstructed outputs from decoder with s1_bits and full codebook respectively,
                         both of shape (batch_size, seq_len, d_in).
                - torch.Tensor: bsq_loss - Loss from the BSQuantizer.
                - torch.Tensor: quantized - Quantized representation from BSQuantizer.
                - torch.Tensor: z_indices - Indices from the BSQuantizer.
        """
        z = self.embed(x)

        for layer in self.encoder:
            z = layer(z)

        z = self.quant_embed(z) # (B, T, codebook)

        bsq_loss, quantized, z_indices = self.tokenizer(z)

        quantized_pre = quantized[:, :, :self.s1_bits] # Extract the first part of quantized representation (s1_bits)
        z_pre = self.post_quant_embed_pre(quantized_pre)

        z = self.post_quant_embed(quantized)

        # Decoder layers (for pre part - s1 bits)
        for layer in self.decoder:
            z_pre = layer(z_pre)
        z_pre = self.head(z_pre)

        # Decoder layers (for full codebook)
        for layer in self.decoder:
            z = layer(z)
        z = self.head(z)

        return (z_pre, z), bsq_loss, quantized, z_indices

    def indices_to_bits(self, x, half=False):
        """
        Converts indices to bit representations and scales them.

        Args:
            x (torch.Tensor): Indices tensor.
            half (bool, optional): Whether to process only half of the codebook dimension. Defaults to False.

        Returns:
            torch.Tensor: Bit representation tensor.
        """
        if half:
            x1 = x[0] # Assuming x is a tuple of indices if half is True
            x2 = x[1]
            mask = 2 ** torch.arange(self.codebook_dim//2, device=x1.device, dtype=torch.long) # Create a mask for bit extraction
            x1 = (x1.unsqueeze(-1) & mask) != 0 # Extract bits for the first half
            x2 = (x2.unsqueeze(-1) & mask) != 0 # Extract bits for the second half
            x = torch.cat([x1, x2], dim=-1) # Concatenate the bit representations
        else:
            mask = 2 ** torch.arange(self.codebook_dim, device=x.device, dtype=torch.long) # Create a mask for bit extraction
            x = (x.unsqueeze(-1) & mask) != 0 # Extract bits

        x = x.float() * 2 - 1 # Convert boolean to bipolar (-1, 1)
        q_scale = 1. / (self.codebook_dim ** 0.5) # Scaling factor
        x = x * q_scale
        return x

    def encode(self, x, half=False):
        """
        Encodes the input data into quantized indices.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_in).
            half (bool, optional): Whether to use half quantization in BSQuantizer. Defaults to False.

        Returns:
            torch.Tensor: Quantized indices from BSQuantizer.
        """
        z = self.embed(x)
        for layer in self.encoder:
            z = layer(z)
        z = self.quant_embed(z)

        bsq_loss, quantized, z_indices = self.tokenizer(z, half=half, collect_metrics=False)
        return z_indices

    def decode(self, x, half=False):
        """
        Decodes quantized indices back to the input data space.

        Args:
            x (torch.Tensor): Quantized indices tensor.
            half (bool, optional): Whether the indices were generated with half quantization. Defaults to False.

        Returns:
            torch.Tensor: Reconstructed output tensor of shape (batch_size, seq_len, d_in).
        """
        quantized = self.indices_to_bits(x, half)
        z = self.post_quant_embed(quantized)
        for layer in self.decoder:
            z = layer(z)
        z = self.head(z)
        return z


class Kronos(nn.Module, PyTorchModelHubMixin):
    """
    Kronos Model.

    Args:
        s1_bits (int): Number of bits for pre tokens.
        s2_bits (int): Number of bits for post tokens.
        n_layers (int): Number of Transformer blocks.
        d_model (int): Dimension of the model's embeddings and hidden states.
        n_heads (int): Number of attention heads in the MultiheadAttention layers.
        ff_dim (int): Dimension of the feedforward network in the Transformer blocks.
        ffn_dropout_p (float): Dropout probability for the feedforward network.
        attn_dropout_p (float): Dropout probability for the attention layers.
        resid_dropout_p (float): Dropout probability for residual connections.
        token_dropout_p (float): Dropout probability for token embeddings.
        learn_te (bool): Whether to use learnable temporal embeddings.
    """

    def __init__(self, s1_bits, s2_bits, n_layers, d_model, n_heads, ff_dim, ffn_dropout_p, attn_dropout_p, resid_dropout_p, token_dropout_p, learn_te, news_dim=None):
        super().__init__()
        self.s1_bits = s1_bits
        self.s2_bits = s2_bits
        self.n_layers = n_layers
        self.d_model = d_model
        self.n_heads = n_heads
        self.learn_te = learn_te
        self.ff_dim = ff_dim
        self.ffn_dropout_p = ffn_dropout_p
        self.attn_dropout_p = attn_dropout_p
        self.resid_dropout_p = resid_dropout_p
        self.token_dropout_p = token_dropout_p
        self.news_dim = news_dim

        self.s1_vocab_size = 2 ** self.s1_bits
        self.token_drop = nn.Dropout(self.token_dropout_p)
        self.embedding = HierarchicalEmbedding(self.s1_bits, self.s2_bits, self.d_model)
        self.time_emb = TemporalEmbedding(self.d_model, self.learn_te)
        self.transformer = nn.ModuleList([
            TransformerBlock(self.d_model, self.n_heads, self.ff_dim, self.ffn_dropout_p, self.attn_dropout_p, self.resid_dropout_p)
            for _ in range(self.n_layers)
        ])
        self.norm = RMSNorm(self.d_model)
        self.dep_layer = DependencyAwareLayer(self.d_model)
        self.head = DualHead(self.s1_bits, self.s2_bits, self.d_model)

        if self.news_dim is not None:
            self.news_proj = nn.Linear(self.news_dim, self.d_model)
        else:
            self.news_proj = None

        self.apply(self._init_weights)

    def _init_weights(self, module):

        if isinstance(module, nn.Linear):
            nn.init.xavier_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0, std=self.embedding.d_model ** -0.5)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
        elif isinstance(module, RMSNorm):
            nn.init.ones_(module.weight)

    def forward(self, s1_ids, s2_ids, stamp=None, padding_mask=None, use_teacher_forcing=False, s1_targets=None, news_emb=None):
        """
        Args:
            s1_ids (torch.Tensor): Input tensor of s1 token IDs. Shape: [batch_size, seq_len]
            s2_ids (torch.Tensor): Input tensor of s2 token IDs. Shape: [batch_size, seq_len]
            stamp (torch.Tensor, optional): Temporal stamp tensor. Shape: [batch_size, seq_len]. Defaults to None.
            padding_mask (torch.Tensor, optional): Mask for padding tokens. Shape: [batch_size, seq_len]. Defaults to None.
            use_teacher_forcing (bool, optional): Whether to use teacher forcing for s1 decoding. Defaults to False.
            s1_targets (torch.Tensor, optional): Target s1 token IDs for teacher forcing. Shape: [batch_size, seq_len]. Defaults to None.
            news_emb (torch.Tensor, optional): News embedding tensor. Shape: [batch_size, news_dim]. Defaults to None.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                - s1 logits: Logits for s1 token predictions. Shape: [batch_size, seq_len, s1_vocab_size]
                - s2_logits: Logits for s2 token predictions, conditioned on s1. Shape: [batch_size, seq_len, s2_vocab_size]
        """
        x = self.embedding([s1_ids, s2_ids])
        if stamp is not None:
            time_embedding = self.time_emb(stamp)
            x = x + time_embedding
        x = self.token_drop(x)

        for layer in self.transformer:
            x = layer(x, key_padding_mask=padding_mask)

        x = self.norm(x)

        if news_emb is not None and self.news_proj is not None:
            news_bias = self.news_proj(news_emb).unsqueeze(1) # [B, 1, d_model]
            x = x + news_bias

        s1_logits = self.head(x)

        if use_teacher_forcing:
            sibling_embed = self.embedding.emb_s1(s1_targets)
        else:
            s1_probs = F.softmax(s1_logits.detach(), dim=-1)
            sample_s1_ids = torch.multinomial(s1_probs.view(-1, self.s1_vocab_size), 1).view(s1_ids.shape)
            sibling_embed = self.embedding.emb_s1(sample_s1_ids)

        x2 = self.dep_layer(x, sibling_embed, key_padding_mask=padding_mask) # Dependency Aware Layer: Condition on s1 embeddings
        s2_logits = self.head.cond_forward(x2)
        return s1_logits, s2_logits

    def decode_s1(self, s1_ids, s2_ids, stamp=None, padding_mask=None, news_emb=None):
        """
        Decodes only the s1 tokens.

        This method performs a forward pass to predict only s1 tokens. It returns the s1 logits
        and the context representation from the Transformer, which can be used for subsequent s2 decoding.

        Args:
            s1_ids (torch.Tensor): Input tensor of s1 token IDs. Shape: [batch_size, seq_len]
            s2_ids (torch.Tensor): Input tensor of s2 token IDs. Shape: [batch_size, seq_len]
            stamp (torch.Tensor, optional): Temporal stamp tensor. Shape: [batch_size, seq_len]. Defaults to None.
            padding_mask (torch.Tensor, optional): Mask for padding tokens. Shape: [batch_size, seq_len]. Defaults to None.
            news_emb (torch.Tensor, optional): News embedding tensor. Shape: [batch_size, news_dim]. Defaults to None.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                - s1_logits: Logits for s1 token predictions. Shape: [batch_size, seq_len, s1_vocab_size]
                - context: Context representation from the Transformer. Shape: [batch_size, seq_len, d_model]
        """
        x = self.embedding([s1_ids, s2_ids])
        if stamp is not None:
            time_embedding = self.time_emb(stamp)
            x = x + time_embedding
        x = self.token_drop(x)

        for layer in self.transformer:
            x = layer(x, key_padding_mask=padding_mask)

        x = self.norm(x)

        if news_emb is not None and self.news_proj is not None:
            news_bias = self.news_proj(news_emb).unsqueeze(1) # [B, 1, d_model]
            x = x + news_bias

        s1_logits = self.head(x)
        return s1_logits, x

    def decode_s2(self, context, s1_ids, padding_mask=None):
        """
        Decodes the s2 tokens, conditioned on the context and s1 tokens.

        This method decodes s2 tokens based on a pre-computed context representation (typically from `decode_s1`)
        and the s1 token IDs. It uses the dependency-aware layer and the conditional s2 head to predict s2 tokens.

        Args:
            context (torch.Tensor): Context representation from the transformer (output of decode_s1).
                                     Shape: [batch_size, seq_len, d_model]
            s1_ids (torch.Tensor): Input tensor of s1 token IDs. Shape: [batch_size, seq_len]
            padding_mask (torch.Tensor, optional): Mask for padding tokens. Shape: [batch_size, seq_len]. Defaults to None.

        Returns:
            torch.Tensor: s2 logits. Shape: [batch_size, seq_len, s2_vocab_size]
        """
        sibling_embed = self.embedding.emb_s1(s1_ids)
        x2 = self.dep_layer(context, sibling_embed, key_padding_mask=padding_mask)
        return self.head.cond_forward(x2)


def top_k_top_p_filtering(
        logits,
        top_k: int = 0,
        top_p: float = 1.0,
        filter_value: float = -float("Inf"),
        min_tokens_to_keep: int = 1,
):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
    Args:
        logits: logits distribution shape (batch size, vocabulary size)
        if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
        if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
            Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        Make sure we keep at least min_tokens_to_keep per batch example in the output
    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold (tokens with 0 are kept)
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value

    # Return once, after both filters, so top-k and top-p can be combined and
    # the function never returns None when neither filter is active.
    return logits


def sample_from_logits(logits, temperature=1.0, top_k=None, top_p=None, sample_logits=True):
    logits = logits / temperature
    # Treat None as "filter disabled" so the comparisons below never see None.
    top_k = 0 if top_k is None else top_k
    top_p = 1.0 if top_p is None else top_p
    if top_k > 0 or top_p < 1.0:
        logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

    probs = F.softmax(logits, dim=-1)

    if not sample_logits:
        # Greedy decoding: `top_k` is an int argument here, so call torch.topk explicitly.
        _, x = torch.topk(probs, k=1, dim=-1)
    else:
        x = torch.multinomial(probs, num_samples=1)

    return x


def auto_regressive_inference(tokenizer, model, x, x_stamp, y_stamp, max_context, pred_len, clip=5, T=1.0, top_k=0, top_p=0.99, sample_count=5, verbose=False, news_emb=None):
    with torch.no_grad():
        x = torch.clip(x, -clip, clip)

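        device = x.device
        # Per-series news embeddings must follow the same sample_count expansion as x (B -> B * sample_count).
        if news_emb is not None and news_emb.size(0) == x.size(0) and x.size(0) > 1:
            news_emb = news_emb.repeat_interleave(sample_count, dim=0)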
        x = x.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, x.size(1), x.size(2)).to(device)
        x_stamp = x_stamp.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, x_stamp.size(1), x_stamp.size(2)).to(device)
        y_stamp = y_stamp.unsqueeze(1).repeat(1, sample_count, 1, 1).reshape(-1, y_stamp.size(1), y_stamp.size(2)).to(device)

        x_token = tokenizer.encode(x, half=True)
        
        initial_seq_len = x.size(1)
        batch_size = x_token[0].size(0)
        total_seq_len = initial_seq_len + pred_len
        full_stamp = torch.cat([x_stamp, y_stamp], dim=1)

        generated_pre = x_token[0].new_empty(batch_size, pred_len)
        generated_post = x_token[1].new_empty(batch_size, pred_len)

        pre_buffer = x_token[0].new_zeros(batch_size, max_context)
        post_buffer = x_token[1].new_zeros(batch_size, max_context)
        buffer_len = min(initial_seq_len, max_context)
        if buffer_len > 0:
            start_idx = max(0, initial_seq_len - max_context)
            pre_buffer[:, :buffer_len] = x_token[0][:, start_idx:start_idx + buffer_len]
            post_buffer[:, :buffer_len] = x_token[1][:, start_idx:start_idx + buffer_len]

        if verbose:
            ran = trange
        else:
            ran = range
        for i in ran(pred_len):
            current_seq_len = initial_seq_len + i
            window_len = min(current_seq_len, max_context)

            if current_seq_len <= max_context:
                input_tokens = [
                    pre_buffer[:, :window_len],
                    post_buffer[:, :window_len]
                ]
            else:
                input_tokens = [pre_buffer, post_buffer]

            context_end = current_seq_len
            context_start = max(0, context_end - max_context)
            current_stamp = full_stamp[:, context_start:context_end, :].contiguous()

            s1_logits, context = model.decode_s1(input_tokens[0], input_tokens[1], current_stamp, news_emb=news_emb)
            s1_logits = s1_logits[:, -1, :]
            sample_pre = sample_from_logits(s1_logits, temperature=T, top_k=top_k, top_p=top_p, sample_logits=True)

            s2_logits = model.decode_s2(context, sample_pre)
            s2_logits = s2_logits[:, -1, :]
            sample_post = sample_from_logits(s2_logits, temperature=T, top_k=top_k, top_p=top_p, sample_logits=True)

            generated_pre[:, i] = sample_pre.squeeze(-1)
            generated_post[:, i] = sample_post.squeeze(-1)

            if current_seq_len < max_context:
                pre_buffer[:, current_seq_len] = sample_pre.squeeze(-1)
                post_buffer[:, current_seq_len] = sample_post.squeeze(-1)
            else:
                pre_buffer.copy_(torch.roll(pre_buffer, shifts=-1, dims=1))
                post_buffer.copy_(torch.roll(post_buffer, shifts=-1, dims=1))
                pre_buffer[:, -1] = sample_pre.squeeze(-1)
                post_buffer[:, -1] = sample_post.squeeze(-1)

        full_pre = torch.cat([x_token[0], generated_pre], dim=1)
        full_post = torch.cat([x_token[1], generated_post], dim=1)

        context_start = max(0, total_seq_len - max_context)
        input_tokens = [
            full_pre[:, context_start:total_seq_len].contiguous(),
            full_post[:, context_start:total_seq_len].contiguous()
        ]
        z = tokenizer.decode(input_tokens, half=True)
        z = z.reshape(-1, sample_count, z.size(1), z.size(2))
        preds = z.cpu().numpy()
        preds = np.mean(preds, axis=1)

        return preds


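def calc_time_stamps(x_timestamp):
    # Accept a DatetimeIndex as well as a Series; only a Series exposes the `.dt` accessor.
    x_timestamp = pd.Series(x_timestamp)
    time_df = pd.DataFrame()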
    time_df['minute'] = x_timestamp.dt.minute
    time_df['hour'] = x_timestamp.dt.hour
    time_df['weekday'] = x_timestamp.dt.weekday
    time_df['day'] = x_timestamp.dt.day
    time_df['month'] = x_timestamp.dt.month
    return time_df


class KronosPredictor:

    def __init__(self, model, tokenizer, device="cuda:0", max_context=512, clip=5):
        self.tokenizer = tokenizer
        self.model = model
        self.max_context = max_context
        self.clip = clip
        self.price_cols = ['open', 'high', 'low', 'close']
        self.vol_col = 'volume'
        self.amt_vol = 'amount'
        self.time_cols = ['minute', 'hour', 'weekday', 'day', 'month']
        self.device = device

        self.tokenizer = self.tokenizer.to(self.device)
        self.model = self.model.to(self.device)

    def generate(self, x, x_stamp, y_stamp, pred_len, T, top_k, top_p, sample_count, verbose, news_emb=None):

        x_tensor = torch.from_numpy(np.array(x).astype(np.float32)).to(self.device)
        x_stamp_tensor = torch.from_numpy(np.array(x_stamp).astype(np.float32)).to(self.device)
        y_stamp_tensor = torch.from_numpy(np.array(y_stamp).astype(np.float32)).to(self.device)

        preds = auto_regressive_inference(self.tokenizer, self.model, x_tensor, x_stamp_tensor, y_stamp_tensor, self.max_context, pred_len,
                                          self.clip, T, top_k, top_p, sample_count, verbose, news_emb=news_emb)
        preds = preds[:, -pred_len:, :]
        return preds

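    def predict(self, df, x_timestamp, y_timestamp, pred_len, T=1.0, top_k=0, top_p=0.9, sample_count=1, verbose=True, news_emb=None):
        """
        Forecast a single time series.

        Args:
            df (pd.DataFrame): Historical data with `open, high, low, close` columns; `volume` and `amount` are optional.
            x_timestamp (pd.Series): Timestamps of the historical rows.
            y_timestamp (pd.Series): Future timestamps to predict; also used as the index of the result.
            pred_len (int): Number of prediction steps.
            T (float): Sampling temperature.
            top_k (int): Top-k filtering threshold.
            top_p (float): Top-p (nucleus sampling) threshold.
            sample_count (int): Number of parallel samples, averaged internally.
            verbose (bool): Whether to display autoregressive progress.
            news_emb (array-like, optional): News embedding of shape [news_dim] or [1, news_dim]. Defaults to None.

        Returns:
            pd.DataFrame: Predicted `open, high, low, close, volume, amount` values indexed by `y_timestamp`.
        """
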
        if not isinstance(df, pd.DataFrame):
            raise ValueError("Input must be a pandas DataFrame.")

        if not all(col in df.columns for col in self.price_cols):
            raise ValueError(f"Price columns {self.price_cols} not found in DataFrame.")

        df = df.copy()
        if self.vol_col not in df.columns:
            df[self.vol_col] = 0.0  # Fill missing volume with zeros
            df[self.amt_vol] = 0.0  # Fill missing amount with zeros
        if self.amt_vol not in df.columns and self.vol_col in df.columns:
            df[self.amt_vol] = df[self.vol_col] * df[self.price_cols].mean(axis=1)

        if df[self.price_cols + [self.vol_col, self.amt_vol]].isnull().values.any():
            raise ValueError("Input DataFrame contains NaN values in price or volume columns.")

        x_time_df = calc_time_stamps(x_timestamp)
        y_time_df = calc_time_stamps(y_timestamp)

        x = df[self.price_cols + [self.vol_col, self.amt_vol]].values.astype(np.float32)
        x_stamp = x_time_df.values.astype(np.float32)
        y_stamp = y_time_df.values.astype(np.float32)

        x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)

        x = (x - x_mean) / (x_std + 1e-5)
        x = np.clip(x, -self.clip, self.clip)

        x = x[np.newaxis, :]
        x_stamp = x_stamp[np.newaxis, :]
        y_stamp = y_stamp[np.newaxis, :]

        if news_emb is not None:
            news_emb_tensor = torch.from_numpy(np.array(news_emb).astype(np.float32)).to(self.device)
            # Ensure batch dimension for news_emb if only one sample
            if news_emb_tensor.ndim == 1:
                news_emb_tensor = news_emb_tensor.unsqueeze(0)
        else:
            news_emb_tensor = None

        preds = self.generate(x, x_stamp, y_stamp, pred_len, T, top_k, top_p, sample_count, verbose, news_emb=news_emb_tensor)

        preds = preds.squeeze(0)
        preds = preds * (x_std + 1e-5) + x_mean

        pred_df = pd.DataFrame(preds, columns=self.price_cols + [self.vol_col, self.amt_vol], index=y_timestamp)
        return pred_df


    def predict_batch(self, df_list, x_timestamp_list, y_timestamp_list, pred_len, T=1.0, top_k=0, top_p=0.9, sample_count=1, verbose=True):
        """
        Perform parallel (batch) prediction on multiple time series. All series must have the same historical length and prediction length (pred_len).

        Args:
            df_list (List[pd.DataFrame]): List of input DataFrames, each containing price columns and optional volume/amount columns.
            x_timestamp_list (List[pd.DatetimeIndex or Series]): List of timestamps corresponding to historical data, length should match the number of rows in each DataFrame.
            y_timestamp_list (List[pd.DatetimeIndex or Series]): List of future prediction timestamps, length should equal pred_len.
            pred_len (int): Number of prediction steps.
            T (float): Sampling temperature.
            top_k (int): Top-k filtering threshold.
            top_p (float): Top-p (nucleus sampling) threshold.
            sample_count (int): Number of parallel samples per series, automatically averaged internally.
            verbose (bool): Whether to display autoregressive progress.

        Returns:
            List[pd.DataFrame]: List of prediction results in the same order as input, each DataFrame contains
                                `open, high, low, close, volume, amount` columns, indexed by corresponding `y_timestamp`.
        """
        # Basic validation
        if not isinstance(df_list, (list, tuple)) or not isinstance(x_timestamp_list, (list, tuple)) or not isinstance(y_timestamp_list, (list, tuple)):
            raise ValueError("df_list, x_timestamp_list, y_timestamp_list must be list or tuple types.")
        if not (len(df_list) == len(x_timestamp_list) == len(y_timestamp_list)):
            raise ValueError("df_list, x_timestamp_list, y_timestamp_list must have consistent lengths.")

        num_series = len(df_list)

        x_list = []
        x_stamp_list = []
        y_stamp_list = []
        means = []
        stds = []
        seq_lens = []
        y_lens = []

        for i in range(num_series):
            df = df_list[i]
            if not isinstance(df, pd.DataFrame):
                raise ValueError(f"Input at index {i} is not a pandas DataFrame.")
            if not all(col in df.columns for col in self.price_cols):
                raise ValueError(f"DataFrame at index {i} is missing price columns {self.price_cols}.")

            df = df.copy()
            if self.vol_col not in df.columns:
                df[self.vol_col] = 0.0
                df[self.amt_vol] = 0.0
            if self.amt_vol not in df.columns and self.vol_col in df.columns:
                df[self.amt_vol] = df[self.vol_col] * df[self.price_cols].mean(axis=1)

            if df[self.price_cols + [self.vol_col, self.amt_vol]].isnull().values.any():
                raise ValueError(f"DataFrame at index {i} contains NaN values in price or volume columns.")

            x_timestamp = x_timestamp_list[i]
            y_timestamp = y_timestamp_list[i]

            x_time_df = calc_time_stamps(x_timestamp)
            y_time_df = calc_time_stamps(y_timestamp)

            x = df[self.price_cols + [self.vol_col, self.amt_vol]].values.astype(np.float32)
            x_stamp = x_time_df.values.astype(np.float32)
            y_stamp = y_time_df.values.astype(np.float32)

            if x.shape[0] != x_stamp.shape[0]:
                raise ValueError(f"Inconsistent lengths at index {i}: x has {x.shape[0]} vs x_stamp has {x_stamp.shape[0]}.")
            if y_stamp.shape[0] != pred_len:
                raise ValueError(f"y_timestamp length at index {i} should equal pred_len={pred_len}, got {y_stamp.shape[0]}.")

            x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
            x_norm = (x - x_mean) / (x_std + 1e-5)
            x_norm = np.clip(x_norm, -self.clip, self.clip)

            x_list.append(x_norm)
            x_stamp_list.append(x_stamp)
            y_stamp_list.append(y_stamp)
            means.append(x_mean)
            stds.append(x_std)

            seq_lens.append(x_norm.shape[0])
            y_lens.append(y_stamp.shape[0])

        # Require all series to have consistent historical and prediction lengths for batch processing
        if len(set(seq_lens)) != 1:
            raise ValueError(f"Parallel prediction requires all series to have consistent historical lengths, got: {seq_lens}")
        if len(set(y_lens)) != 1:
            raise ValueError(f"Parallel prediction requires all series to have consistent prediction lengths, got: {y_lens}")

        x_batch = np.stack(x_list, axis=0).astype(np.float32)           # (B, seq_len, feat)
        x_stamp_batch = np.stack(x_stamp_list, axis=0).astype(np.float32) # (B, seq_len, time_feat)
        y_stamp_batch = np.stack(y_stamp_list, axis=0).astype(np.float32) # (B, pred_len, time_feat)

        preds = self.generate(x_batch, x_stamp_batch, y_stamp_batch, pred_len, T, top_k, top_p, sample_count, verbose)
        # preds: (B, pred_len, feat)

        pred_dfs = []
        for i in range(num_series):
            preds_i = preds[i] * (stds[i] + 1e-5) + means[i]
            pred_df = pd.DataFrame(preds_i, columns=self.price_cols + [self.vol_col, self.amt_vol], index=y_timestamp_list[i])
            pred_dfs.append(pred_df)

        return pred_dfs

```
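
The rolling context window used by `auto_regressive_inference` can be illustrated without tensors. The following is a minimal sketch under stated assumptions, not the listing's implementation: it uses plain Python lists instead of a preallocated buffer and `torch.roll`, and `roll_context` and `fake_sampler` are hypothetical names introduced only for this illustration.

```python
from typing import Callable, List


def roll_context(history: List[int], max_context: int, pred_len: int,
                 sample_next: Callable[[List[int]], int]) -> List[int]:
    """Generate pred_len tokens, feeding the sampler at most max_context tokens."""
    if max_context < 1:
        raise ValueError("max_context must be at least 1")
    # Start from the most recent max_context tokens of the history.
    window = list(history[-max_context:])
    generated = []
    for _ in range(pred_len):
        token = sample_next(window)      # stand-in for decode_s1/decode_s2 plus sampling
        generated.append(token)
        window.append(token)             # the new token becomes part of the context
        if len(window) > max_context:
            window.pop(0)                # drop the oldest token once the window is full
    return generated


def fake_sampler(window: List[int]) -> int:
    """Hypothetical deterministic 'model': last digit of the window sum."""
    return sum(window) % 10


print(roll_context([1, 2, 3, 4, 5], max_context=3, pred_len=4, sample_next=fake_sampler))
# prints [2, 1, 8, 1]
```

While the history is shorter than `max_context` the window simply grows; once it is full, each new token pushes the oldest one out, which is the effect the listing achieves by rolling the buffer one position to the left.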

### scripts/predictor/model/module.py
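
The quantizer in the listing below converts between binary codes in {-1, +1} and integer codebook indices (`codes_to_indexes` and `indexes_to_codes`), with the first dimension treated as the most significant bit. As a rough illustration of that mapping, here is a pure-Python sketch; `codes_to_index` and `index_to_codes` are hypothetical helper names, and the sketch ignores batching and the `l2_norm` scaling.

```python
from typing import List


def codes_to_index(code: List[int]) -> int:
    """Pack a +1/-1 code into an integer, most significant bit first."""
    index = 0
    for value in code:
        index = index * 2 + (1 if value > 0 else 0)
    return index


def index_to_codes(index: int, num_bits: int) -> List[int]:
    """Unpack an integer into a +1/-1 code of length num_bits."""
    return [1 if (index >> shift) & 1 else -1 for shift in range(num_bits - 1, -1, -1)]


code = [-1, 1, -1, 1]            # bits 0101
print(codes_to_index(code))      # prints 5
print(index_to_codes(5, 4))      # prints [-1, 1, -1, 1]
```

Note that `BSQuantizer.bits_to_indices` in the same listing uses the opposite bit order (least significant bit first), so indices from the two helpers are not interchangeable.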

```python
import math

from einops import rearrange, reduce
import torch
import torch.nn as nn
from torch.autograd import Function
import torch.nn.functional as F


class DifferentiableEntropyFunction(Function):
    @staticmethod
    def forward(ctx, zq, basis, K, eps):
        zb = (zq + 1) / 2
        zi = ((zb * basis).sum(-1)).to(torch.int64)
        cnt = torch.scatter_reduce(torch.zeros(2 ** K, device=zq.device, dtype=zq.dtype),
                                   0,
                                   zi.flatten(),
                                   torch.ones_like(zi.flatten()).to(zq.dtype),
                                   'sum')
        prob = (cnt + eps) / (cnt + eps).sum()
        H = -(prob * torch.log(prob)).sum()
        ctx.save_for_backward(zq, zi, prob)
        ctx.K = K
        return H

    @staticmethod
    def backward(ctx, grad_output):
        zq, zi, prob = ctx.saved_tensors
        grad_array = -grad_output * (torch.log(prob) + 1) / zi.numel() / ctx.K
        reord_grad = grad_array[zi.flatten()].reshape(zi.shape)
        grad_input = reord_grad.unsqueeze(-1) * zq
        # One gradient per forward input: (zq, basis, K, eps)
        return grad_input, None, None, None


def codebook_entropy(zq, basis, K, eps=1e-4):
    return DifferentiableEntropyFunction.apply(zq, basis, K, eps)


class BinarySphericalQuantizer(nn.Module):
    def __init__(self, embed_dim, beta, gamma0, gamma, zeta,
                 input_format='bchw',
                 soft_entropy=True, group_size=9,
                 persample_entropy_compute='analytical',
                 cb_entropy_compute='group',
                 l2_norm=True,
                 inv_temperature=1):
        """
        Paper link: https://arxiv.org/pdf/2406.07548.pdf
        Here we use the official implementation of the BinarySphericalQuantizer.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.beta = beta  # loss weight for commit loss
        self.gamma0 = gamma0  # loss weight for the per-sample entropy term
        self.gamma = gamma  # loss weight for the codebook entropy term
        self.zeta = zeta  # loss weight for entire entropy penalty
        self.input_format = input_format
        assert self.embed_dim % group_size == 0, "embed_dim must be divisible by group_size"
        self.num_groups = self.embed_dim // group_size
        self.group_size = group_size
        assert persample_entropy_compute in ['group', 'analytical'], "persample_entropy_compute must be either 'group' or 'analytical'"
        assert cb_entropy_compute in ['group', 'nce'], "cb_entropy_compute must be either 'group' or 'nce'"
        self.persample_entropy_compute = persample_entropy_compute
        self.cb_entropy_compute = cb_entropy_compute
        self.l2_norm = l2_norm
        self.inv_temperature = inv_temperature

        self.register_buffer('basis', 2 ** torch.arange(embed_dim - 1, -1, -1))
        self.register_buffer('group_basis', 2 ** torch.arange(group_size - 1, -1, -1))

        self.num_dimensions = 2 ** embed_dim
        self.bits_per_index = embed_dim

        # we only need to keep the codebook portion up to the group size
        # because we approximate the H loss with this subcode
        group_codes = torch.arange(2 ** self.group_size)
        group_codebook = self.indexes_to_codes(group_codes).float()[:, -group_size:]
        self.register_buffer('group_codebook', group_codebook, persistent=False)

        self.soft_entropy = soft_entropy  # soft_entropy: Sec 3.2 of https://arxiv.org/pdf/1911.05894.pdf

    def quantize(self, z):
        assert z.shape[-1] == self.embed_dim, f"Expected {self.embed_dim} dimensions, got {z.shape[-1]}"

        zhat = torch.where(z > 0,
                           torch.tensor(1, dtype=z.dtype, device=z.device),
                           torch.tensor(-1, dtype=z.dtype, device=z.device))
        return z + (zhat - z).detach()

    def forward(self, z, collect_metrics=True):
        # if self.input_format == 'bchw':
        #     z = rearrange(z, 'b c h w -> b h w c')
        zq = self.quantize(z)

        q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1.

        zq = zq * q_scale

        if not collect_metrics:
            return zq, zq.new_zeros(()), {}

        indices = self.codes_to_indexes(zq.detach())
        group_indices = self.codes_to_group_indexes(zq.detach())
        if not self.training:
            used_codes = torch.unique(indices, return_counts=False)
        else:
            used_codes = None

        if self.soft_entropy:
            persample_entropy, cb_entropy, avg_prob = self.soft_entropy_loss(z)
            entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy
        else:
            zb_by_sample = ((zq + 1) / 2).reshape(z.shape[0], -1, z.shape[-1]).to(torch.float32)
            persample_entropy = self.get_hard_per_sample_entropy(zb_by_sample)
            cb_entropy = codebook_entropy(zq, self.basis, self.embed_dim)
            entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy
            avg_prob = None  # only computed by the soft-entropy branch; keeps the metrics dict below valid

        # commit loss
        commit_loss = self.beta * torch.mean(((zq.detach() - z) ** 2).sum(dim=-1))

        # if self.input_format == 'bchw':
        #     zq = rearrange(zq, 'b h w c -> b c h w')

        return (
            zq,
            commit_loss + self.zeta * entropy_penalty / self.inv_temperature,
            {"H": cb_entropy, "used_codes": used_codes, "indices": indices, "group_indices": group_indices,
             "avg_prob": avg_prob}
        )

    def soft_entropy_loss(self, z):
        # if we divide the code in subgroups of size group_size, the codebook will be of size 2 ** group_size
        # the sub-code is the last group_size bits of the full code
        group_code_book = self.group_codebook / (self.embed_dim ** 0.5 if self.l2_norm else 1)
        divided_z = rearrange(z, '... (g c) -> ... g c', c=self.group_size)

        # we calculate the distance between the divided_z and the codebook for each subgroup
        distance = - 2 * torch.einsum('... g c, d c ->... g d', divided_z, group_code_book)
        prob = (-distance * self.inv_temperature).softmax(dim=-1)
        if self.persample_entropy_compute == 'analytical':
            if self.l2_norm:
                p = torch.sigmoid(-4 * z / (self.embed_dim ** 0.5) * self.inv_temperature)
            else:
                p = torch.sigmoid(-4 * z * self.inv_temperature)
            prob = torch.stack([p, 1 - p], dim=-1)
            per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean()
        else:
            per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean()

        # macro average of the probability of each subgroup
        avg_prob = reduce(prob, '... g d ->g d', 'mean')
        codebook_entropy = self.get_entropy(avg_prob, dim=-1, normalize=False)

        # the approximation of the entropy is the sum of the entropy of each subgroup
        return per_sample_entropy, codebook_entropy.sum(), avg_prob

    def get_hard_per_sample_entropy(self, zb_by_sample):
        probs_per_dim = zb_by_sample.sum(1) / zb_by_sample.shape[1]
        persample_entropy = - probs_per_dim * torch.log(probs_per_dim + 1e-8) - (1 - probs_per_dim) * torch.log(1 - probs_per_dim + 1e-8)
        persample_entropy = persample_entropy.sum(-1)
        return persample_entropy.mean()

    def codes_to_indexes(self, zhat):
        """Converts a `code` to an index in the codebook.
        Args:
            zhat: A tensor of shape (B, ..., C) containing the codes. must be in {-1, 1}
        """
        assert zhat.shape[-1] == self.embed_dim, f"Expected {self.embed_dim} dimensions, got {zhat.shape[-1]}"
        return ((zhat + 1) / 2 * self.basis).sum(axis=-1).to(torch.int64)

    def codes_to_group_indexes(self, zhat):
        """Converts a `code` to a list of indexes (in groups) in the codebook.
        Args:
            zhat: A tensor of shape (B, ..., C) containing the codes. must be in {-1, 1}
        """
        zhat_in_group = rearrange(zhat, 'b ... (g c) -> b ... g c', c=self.group_size)
        return ((zhat_in_group + 1) / 2 * self.group_basis).sum(axis=-1).to(torch.int64)

    def indexes_to_codes(self, indices):
        """Inverse of `codes_to_indexes`."""
        indices = indices.unsqueeze(-1)
        codes_non_centered = torch.remainder(
            torch.floor_divide(indices, self.basis), 2
        )
        return codes_non_centered * 2 - 1

    def group_indexes_to_codes(self, group_indices):
        """Inverse of `codes_to_group_indexes`."""
        group_indices = group_indices.unsqueeze(-1)
        codes_non_centered = torch.remainder(
            torch.floor_divide(group_indices, self.group_basis), 2
        )
        codes_non_centered = rearrange(codes_non_centered, 'b ... g c -> b ... (g c)')
        return codes_non_centered * 2 - 1

    def get_entropy(self, count, dim=-1, eps=1e-4, normalize=True):
        if normalize:
            probs = (count + eps) / (count + eps).sum(dim=dim, keepdim=True)
        else:
            probs = count
        H = -(probs * torch.log(probs + 1e-8)).sum(dim=dim)
        return H

    def get_group_codebook_entry(self, group_indices):
        z_q = self.group_indexes_to_codes(group_indices)
        q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1.
        z_q = z_q * q_scale
        if self.input_format == 'bchw':
            h = w = int(z_q.shape[1] ** 0.5)  # square layout; an int cannot be unpacked into two names
            assert h * w == z_q.shape[1], 'Invalid sequence length'
            z_q = rearrange(z_q, 'b (h w) c -> b c h w', h=h)
        return z_q

    def get_codebook_entry(self, indices):
        z_q = self.indexes_to_codes(indices)
        q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1.
        z_q = z_q * q_scale
        if self.input_format == 'bchw':
            h = w = int(z_q.shape[1] ** 0.5)  # square layout; an int cannot be unpacked into two names
            assert h * w == z_q.shape[1], 'Invalid sequence length'
            z_q = rearrange(z_q, 'b (h w) c -> b c h w', h=h)
        return z_q


class BSQuantizer(nn.Module):

    def __init__(self, s1_bits, s2_bits, beta, gamma0, gamma, zeta, group_size):
        super().__init__()
        self.codebook_dim = s1_bits + s2_bits
        self.s1_bits = s1_bits
        self.s2_bits = s2_bits
        self.bsq = BinarySphericalQuantizer(self.codebook_dim, beta, gamma0, gamma, zeta, group_size=group_size)

    def bits_to_indices(self, bits):
        bits = (bits >= 0).to(torch.long)
        indices = 2 ** torch.arange(
            0,
            bits.shape[-1],
            1,
            dtype=torch.long,
            device=bits.device,
        )
        return (bits * indices).sum(-1)

    def forward(self, z, half=False, collect_metrics=True):
        z = F.normalize(z, dim=-1)
        quantized, bsq_loss, metrics = self.bsq(z, collect_metrics=collect_metrics)
        if half:
            q_pre = quantized[:, :, :self.s1_bits]
            q_post = quantized[:, :, self.s1_bits:]
            z_indices = [self.bits_to_indices(q_pre), self.bits_to_indices(q_post)]
        else:
            z_indices = self.bits_to_indices(quantized)
        return bsq_loss, quantized, z_indices


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


class FeedForward(nn.Module):
    def __init__(self, d_model, ff_dim, ffn_dropout_p=0.0):
        super().__init__()

        self.w1 = nn.Linear(d_model, ff_dim, bias=False)
        self.w3 = nn.Linear(d_model, ff_dim, bias=False)
        self.w2 = nn.Linear(ff_dim, d_model, bias=False)
        self.ffn_dropout = nn.Dropout(ffn_dropout_p)

    def forward(self, x):
        return self.ffn_dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))


class RotaryPositionalEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def _update_cos_sin_cache(self, x, seq_len):
        if seq_len != self.seq_len_cached:
            self.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum('i,j->ij', t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.cos_cached = emb.cos()[None, None, :, :]
            self.sin_cached = emb.sin()[None, None, :, :]
        return self.cos_cached, self.sin_cached

    def forward(self, q, k):
        cos, sin = self._update_cos_sin_cache(q, q.shape[-2])
        return (
            (q * cos) + (self._rotate_half(q) * sin),
            (k * cos) + (self._rotate_half(k) * sin),
        )

    def _rotate_half(self, x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)


class MultiHeadAttentionWithRoPE(nn.Module):
    def __init__(self, d_model, n_heads, attn_dropout_p=0.0, resid_dropout_p=0.0):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.rotary = RotaryPositionalEmbedding(self.head_dim)
        self.attn_dropout_p = attn_dropout_p
        self.resid_dropout = nn.Dropout(resid_dropout_p)

    def forward(self, x, key_padding_mask=None):
        batch_size, seq_len, _ = x.shape

        q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)

        q, k = self.rotary(q, k)

        if key_padding_mask is not None:
            attn_mask = key_padding_mask.unsqueeze(1).unsqueeze(2)  # [batch, 1, 1, seq_len]
            attn_mask = attn_mask.expand(-1, self.n_heads, seq_len, -1)  # [batch, n_heads, q_len, k_len]
        else:
            attn_mask = None

        attn_output = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            dropout_p=self.attn_dropout_p if self.training else 0.0,
            is_causal=True
        )

        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.resid_dropout(self.out_proj(attn_output))


class MultiHeadCrossAttentionWithRoPE(nn.Module):
    def __init__(self, d_model, n_heads, attn_dropout_p=0.0, resid_dropout=0.0):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.rotary = RotaryPositionalEmbedding(self.head_dim)
        self.attn_dropout_p = attn_dropout_p
        self.resid_dropout = nn.Dropout(resid_dropout)

    def forward(self, query, key, value, key_padding_mask=None):
        batch_size, q_len, _ = query.shape
        _, seq_len, _ = key.shape

        q = self.q_proj(query).view(batch_size, q_len, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(key).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(value).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)

        q, k = self.rotary(q, k)

        if key_padding_mask is not None:
            attn_mask = key_padding_mask.unsqueeze(1).unsqueeze(2)
            attn_mask = attn_mask.expand(-1, self.n_heads, q_len, -1)
        else:
            attn_mask = None

        # Causal masking is requested only while the module is in training mode.
        is_causal_flag = self.training

        attn_output = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            dropout_p=self.attn_dropout_p if self.training else 0.0,
            is_causal=is_causal_flag
        )

        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, q_len, self.d_model)
        return self.resid_dropout(self.out_proj(attn_output))


class HierarchicalEmbedding(nn.Module):
    def __init__(self, s1_bits, s2_bits, d_model=256):
        super().__init__()
        self.s1_bits = s1_bits
        self.s2_bits = s2_bits

        vocab_s1 = 2 ** s1_bits
        vocab_s2 = 2 ** s2_bits

        self.emb_s1 = nn.Embedding(vocab_s1, d_model)
        self.emb_s2 = nn.Embedding(vocab_s2, d_model)
        self.d_model = d_model
        self.fusion_proj = nn.Linear(d_model * 2, d_model)

        nn.init.normal_(self.emb_s1.weight, mean=0, std=d_model ** -0.5)
        nn.init.normal_(self.emb_s2.weight, mean=0, std=d_model ** -0.5)

    def split_token(self, token_ids: torch.Tensor, s2_bits: int):
        """Inputs:
            token_ids (torch.Tensor): Composite token IDs of shape [batch_size, seq_len] or [N], each in range [0, 2^(s1_bits + s2_bits) - 1].
            s2_bits (int): Number of low bits used for the fine token (s2).
        """
        assert isinstance(s2_bits, int) and s2_bits > 0, "s2_bits must be a positive integer"

        t = token_ids.long()
        mask = (1 << s2_bits) - 1
        s2_ids = t & mask           # extract low bits
        s1_ids = t >> s2_bits       # extract high bits
        return s1_ids, s2_ids

    def forward(self, token_ids):
        """Inputs:
        token_ids:
            - tuple or list: (s1_ids, s2_ids), each of shape [batch_size, seq_len], or
            - torch.Tensor: composite token IDs of shape [batch_size, seq_len], which will be split into (s1_ids, s2_ids) internally.
        Output: [batch_size, seq_len, d_model]
        """
        if isinstance(token_ids, (tuple, list)):
            s1_ids, s2_ids = token_ids
        else:
            s1_ids, s2_ids = self.split_token(token_ids, self.s2_bits)
        s1_emb = self.emb_s1(s1_ids) * math.sqrt(self.d_model)
        s2_emb = self.emb_s2(s2_ids) * math.sqrt(self.d_model)
        return self.fusion_proj(torch.cat([s1_emb, s2_emb], dim=-1))


class DependencyAwareLayer(nn.Module):
    def __init__(self, d_model, n_heads=4, attn_dropout_p=0.0, resid_dropout=0.0):
        super().__init__()
        self.cross_attn = MultiHeadCrossAttentionWithRoPE(d_model, n_heads, attn_dropout_p, resid_dropout)
        self.norm = RMSNorm(d_model)

    def forward(self, hidden_states, sibling_embed, key_padding_mask=None):
        """hidden_states: [batch, seq_len, d_model]
        sibling_embed: Embedding from another subtoken
        """
        attn_out = self.cross_attn(
            query=sibling_embed,
            key=hidden_states,
            value=hidden_states,
            key_padding_mask=key_padding_mask
        )
        return self.norm(hidden_states + attn_out)


class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, ff_dim=1024, ffn_dropout_p=0.0, attn_dropout_p=0.0, resid_dropout_p=0.0):
        super().__init__()
        self.norm1 = RMSNorm(d_model)
        self.self_attn = MultiHeadAttentionWithRoPE(d_model, n_heads, attn_dropout_p, resid_dropout_p)
        self.norm2 = RMSNorm(d_model)
        self.ffn = FeedForward(d_model, ff_dim, ffn_dropout_p)

    def forward(self, x, key_padding_mask=None):
        residual = x
        x = self.norm1(x)
        attn_out = self.self_attn(x, key_padding_mask=key_padding_mask)
        x = residual + attn_out

        residual = x
        x = self.norm2(x)
        ffn_out = self.ffn(x)
        x = residual + ffn_out
        return x


class DualHead(nn.Module):
    def __init__(self, s1_bits, s2_bits, d_model):
        super().__init__()
        self.vocab_s1 = 2 ** s1_bits
        self.vocab_s2 = 2 ** s2_bits
        self.proj_s1 = nn.Linear(d_model, self.vocab_s1)
        self.proj_s2 = nn.Linear(d_model, self.vocab_s2)

    def compute_loss(self, s1_logits, s2_logits, s1_targets, s2_targets, padding_mask=None):
        if padding_mask is not None:
            valid_mask = (padding_mask == 0)
            s1_logits = s1_logits[valid_mask]
            s2_logits = s2_logits[valid_mask]
            s1_targets = s1_targets[valid_mask]
            s2_targets = s2_targets[valid_mask]
            ce_s1 = F.cross_entropy(s1_logits, s1_targets)
            ce_s2 = F.cross_entropy(s2_logits, s2_targets)
        else:
            ce_s1 = F.cross_entropy(s1_logits.reshape(-1, self.vocab_s1), s1_targets.reshape(-1))
            ce_s2 = F.cross_entropy(s2_logits.reshape(-1, self.vocab_s2), s2_targets.reshape(-1))
        ce_loss = (ce_s1 + ce_s2) / 2
        return ce_loss, ce_s1, ce_s2

    def forward(self, x):
        return self.proj_s1(x)

    def cond_forward(self, x2):
        return self.proj_s2(x2)


class FixedEmbedding(nn.Module):
    def __init__(self, c_in, d_model):
        super(FixedEmbedding, self).__init__()

        w = torch.zeros(c_in, d_model).float()
        w.requires_grad = False

        position = torch.arange(0, c_in).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()

        w[:, 0::2] = torch.sin(position * div_term)
        w[:, 1::2] = torch.cos(position * div_term)

        self.emb = nn.Embedding(c_in, d_model)
        self.emb.weight = nn.Parameter(w, requires_grad=False)

    def forward(self, x):
        return self.emb(x).detach()


class TemporalEmbedding(nn.Module):
    def __init__(self, d_model, learn_pe):
        super(TemporalEmbedding, self).__init__()

        minute_size = 60
        hour_size = 24
        weekday_size = 7
        day_size = 32
        month_size = 13

        Embed = FixedEmbedding if not learn_pe else nn.Embedding
        self.minute_embed = Embed(minute_size, d_model)
        self.hour_embed = Embed(hour_size, d_model)
        self.weekday_embed = Embed(weekday_size, d_model)
        self.day_embed = Embed(day_size, d_model)
        self.month_embed = Embed(month_size, d_model)

    def forward(self, x):
        x = x.long()

        minute_x = self.minute_embed(x[:, :, 0])
        hour_x = self.hour_embed(x[:, :, 1])
        weekday_x = self.weekday_embed(x[:, :, 2])
        day_x = self.day_embed(x[:, :, 3])
        month_x = self.month_embed(x[:, :, 4])

        return hour_x + weekday_x + day_x + month_x + minute_x
```
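
The bit manipulation in `bits_to_indices` and `HierarchicalEmbedding.split_token` is easier to follow with plain integers. The following is a minimal sketch in pure Python (no `torch`), assuming the same conventions as the code above: a non-negative value counts as bit 1, the first element is the least-significant bit, and the fine token (s2) occupies the low bits of a composite ID. The names `bits_to_index` and `join_token` are illustrative and not part of the module.

```python
from typing import Sequence, Tuple


def bits_to_index(bits: Sequence[float]) -> int:
    """Pack sign bits into one integer, least-significant bit first."""
    # Mirrors `(bits >= 0)` followed by a weighted sum with weights 2**i.
    return sum((1 if b >= 0 else 0) << i for i, b in enumerate(bits))


def split_token(token_id: int, s2_bits: int) -> Tuple[int, int]:
    """Split a composite ID into (s1_id, s2_id); s2 occupies the low bits."""
    mask = (1 << s2_bits) - 1
    return token_id >> s2_bits, token_id & mask


def join_token(s1_id: int, s2_id: int, s2_bits: int) -> int:
    """Illustrative inverse of split_token (hypothetical, not in the module)."""
    return (s1_id << s2_bits) | s2_id


if __name__ == "__main__":
    print(bits_to_index([0.5, -0.5, 0.5]))  # bits 1, 0, 1
    print(split_token(0b1101, s2_bits=2))
```

Splitting and re-joining a token round-trips as long as `s2_id` fits in `s2_bits` bits.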

### scripts/prompts/fin_agent.py

```python
from datetime import datetime
from .isq_prompt_generator import generate_isq_prompt_section

def get_fin_researcher_instructions() -> str:
    """Build the system instructions for the financial researcher (Researcher)."""
    current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    return f"""你是一名资深金融研究员,当前时间是 {current_time}。
你的任务是针对给定的“原始信号”进行详尽的背景调查,为后续的深度分析提供素材。

### 1. 核心职责
1. **标的识别**: 识别信号中涉及的具体上市公司。必须调用 `search_ticker` 确认代码,并调用 `get_stock_price` 获取最新价格和近 30 天走势。
2. **事实核查**: 使用 `web_search` 或 `fetch_news_content` 验证信号的真实性,并寻找更多细节(如公告原文、行业研报摘要)。
3. **产业链梳理**: 补充该信号涉及的上下游环节及竞争格局。

### 2. 工具使用规范 (CRITICAL)
- **每个提到的公司都需要调用工具**: 不能依赖记忆,必须实时查询。
- **完整呈现工具结果**: 包括具体的股价数字、代码、技术面数据等,不要缩略。
- **股价数据必需**: 当前价格、近期最高最低、技术面支撑阻力等数据是后续预测的基础。
- **信息交叉验证**: 多个来源验证关键事实。

### 3. 输出要求
你必须输出结构化的研究报告,涵盖标的基本面、股价走势、行业背景及最新进展。
"""

def get_fin_analyst_instructions(template_id: str = "default_isq_v1") -> str:
    """Build the system instructions for the financial analyst (Analyst).
    
    Args:
        template_id: ID of the ISQ template to use.
    """
    current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    isq_block = generate_isq_prompt_section(template_id=template_id)

    return f"""你是一位深耕二级市场的资深金融分析师 (FinAgent),当前时间是 {current_time}。
你的核心任务是执行“信号解析”,将研究员搜集的素材转化为具有可操作性的投资情报(ISQ 框架)。

{isq_block}

### 2. 分析约束
- **严格基于具体数据**: 必须使用研究员提供的股价、技术面、新闻等具体数据进行分析。
- **数据驱动的预测**: impact_tickers 中的权重应基于事件影响程度,不能随意赋值。
- **逻辑严密**: 传导链条必须符合金融常识,能够自圆其说。
- **技术面参考**: 如果研究员提供了股价走势,请分析当前位置相对于支撑/阻力位的关系。

### 3. 关键要求
- **title**: 必须生成一个简练、准确概括信号核心内容的标题(不超过 15 字)。
- **impact_tickers**: 必须填充具体的公司代码(6位数字)和名称,权重应该有区分。
- **transmission_chain**: 必须是对象列表,每个对象包含:
  - `node_name`: 节点名称(如“上游原材料”、“中游制造”)
  - `impact_type`: 影响类型(“利好”、“利空”、“中性”)
  - `logic`: 具体的传导逻辑描述
- **summary**: 基于分析结果总结核心观点,包含具体数字(如股价目标、预期涨跌幅等)。
- **reasoning**: 必须详细阐述推演逻辑,解释为什么得出上述结论(<200字)。

### 4. 输出格式 (严格 JSON 块)
你必须输出一个符合 InvestmentSignal 结构的 JSON 块,包含所有必需字段。
"""

def get_fin_agent_instructions() -> str:
    # Kept for backward compatibility; delegates to the analyst instructions.
    return get_fin_analyst_instructions()

def get_fin_research_task(signal_text: str) -> str:
    """Build the researcher's task description."""
    return f"请针对以下信号进行背景调查,搜集相关标的的股价、最新进展和行业背景:\n\n{signal_text}"

def format_research_context(research_data: dict) -> str:
    """Format the researcher's structured data as text the analyst can read."""
    if not research_data:
        return "(未能搜集到额外背景信息)"
        
    return f"""
### 研究背景
- **相关标的**: {research_data.get('tickers_found', [])}
- **行业背景**: {research_data.get('industry_background', '未知')}
- **最新进展**: {', '.join(research_data.get('latest_developments', []))}
- **关键风险**: {', '.join(research_data.get('key_risks', []))}
- **综合摘要**: {research_data.get('search_results_summary', '无')}
"""

def get_fin_analysis_task(signal_text: str, research_context_str: str) -> str:
    """Build the analyst's task description."""
    return f"""请基于以下信息进行深度 ISQ 分析。关键是:必须使用研究员搜集的具体数据(股价、技术面、新闻、代码等)进行分析。

=== 原始信号 ===
{signal_text}

=== 研究员搜集的背景信息 (CRITICAL DATA) ===
{research_context_str}

=== 分析要求 ===
1. 必须生成 title:简练概括信号核心(<15字)
2. 基于研究员提供的具体股价数据,分析当前定价状态(已定价/未定价/部分定价)
3. impact_tickers 中填充具体的公司代码和权重,权重基于事件影响程度
4. transmission_chain 必须是包含 node_name, impact_type, logic 的对象列表
5. summary 中包含具体数字(预期目标价、涨跌幅范围等)
6. reasoning 必须详细解释推演逻辑,不要空泛,要言之有物

请严格按 InvestmentSignal JSON 格式输出。"""

def get_tracking_analysis_task(old_signal: dict, new_research_str: str) -> str:
    """Build the task description for a signal-tracking update."""
    import json
    old_sig_str = json.dumps(old_signal, ensure_ascii=False, indent=2)
    return f"""你正在执行“信号逻辑演变追踪”任务。请基于最新的市场信息,重新评估之前的投资信号。

=== 基准信号 (上次分析) ===
{old_sig_str}

=== 最新市场追踪 (NEWS & PRICE) ===
{new_research_str}

=== 追踪分析要求 ===
1. **逻辑演变检测**:
   - 对比新旧信息,判断原逻辑 (`transmission_chain` 和 `reasoning`) 是否依然成立?
   - 如果逻辑发生变化(如利好落空、逻辑证伪、新利好出现),请在新的 `reasoning` 中明确指出“逻辑演变:...”
   - 如果逻辑未变且得到验证,请标记“逻辑维持:...”

2. **参数修正**:
   - 根据最新股价和新闻,更新 `sentiment_score` (情绪)、`confidence` (置信度) 和 `expectation_gap` (预期差)。
   - 例如:如果股价已经大涨反映了利好,`expectation_gap` 应该显著降低。

3. **输出更新后的信号**:
   - 保留原 `signal_id` 和 `title`(除非有重大变化需要改名)。
   - 输出完整的 InvestmentSignal JSON。

请重点关注:为什么变了?还是为什么没变?理由要充分。"""

```
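
The analyst prompt above places several constraints on the JSON reply: a title of at most 15 characters, 6-digit ticker codes, and `transmission_chain` nodes carrying `node_name`, `impact_type`, and `logic`. A minimal sketch of checking those constraints after the model replies might look like the following. The helper `check_signal_payload` is hypothetical, and it assumes each `impact_tickers` entry is a dict with a `ticker` key; the real `InvestmentSignal` schema is not shown in this section and may enforce more.

```python
import json
import re
from typing import List

REQUIRED_NODE_KEYS = {"node_name", "impact_type", "logic"}
ALLOWED_IMPACT_TYPES = {"利好", "利空", "中性"}  # values listed in the prompt


def check_signal_payload(raw_json: str) -> List[str]:
    """Return the problems found in an analyst JSON reply (empty if none)."""
    data = json.loads(raw_json)
    problems: List[str] = []

    title = data.get("title", "")
    if not title or len(title) > 15:
        problems.append("title must be non-empty and at most 15 characters")

    # Assumption: each entry is a dict with a "ticker" key.
    for item in data.get("impact_tickers", []):
        if not re.fullmatch(r"\d{6}", str(item.get("ticker", ""))):
            problems.append(f"ticker is not a 6-digit code: {item.get('ticker')!r}")

    chain = data.get("transmission_chain", [])
    if not isinstance(chain, list):
        problems.append("transmission_chain must be a list of objects")
        chain = []
    for node in chain:
        missing = REQUIRED_NODE_KEYS - set(node)
        if missing:
            problems.append(f"chain node is missing keys: {sorted(missing)}")
        elif node["impact_type"] not in ALLOWED_IMPACT_TYPES:
            problems.append(f"unexpected impact_type: {node['impact_type']!r}")
    return problems
```

Returning a list of problems rather than raising on the first one makes it possible to feed all issues back to the model in a single retry.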

### scripts/prompts/forecast_analyst.py

```python
from typing import List, Dict, Any
from ..schema.models import KLinePoint

def get_forecast_adjustment_instructions(ticker: str, news_context: str, model_forecast: List[KLinePoint]):
    """
    Build the instructions for the LLM forecast adjustment.
    """
    forecast_str = "\n".join([f"- {p.date}: O:{p.open}, C:{p.close}" for p in model_forecast])
    
    return f"""你是一位资深的量化策略分析师。
你的任务是:根据给定的【Kronos 模型预测结果】和【最新的基本面/新闻背景】,对模型预测进行“主观/逻辑调整”。

股票代码: {ticker}

【Kronos 模型原始预测 (OHLC)】:
{forecast_str}

【最新情报背景】:
{news_context}

调整原则:
1. 原始预测是基于历史的技术面推演。
2. 情报背景中可能包含【Kronos模型定量修正预测】,这是基于历史新闻训练的专用模型计算出的量化结果。
3. 如果存在“定量修正预测”,请**高度参考**该数值作为基础,除非你有非常确凿的逻辑认为该量化模型失效(例如遇到模型未见过的极端黑天鹅)。
4. 你的核心任务是:结合定性分析(新闻及其逻辑)来验证或微调这些数字,并给出合理的解释(Rationale)。
5. 如果没有“定量修正预测”,则你需要根据新闻信号手动大幅调整趋势。

输出要求 (严格 JSON 格式):
```json
{{
  "adjusted_forecast": [
    {{
      "date": "YYYY-MM-DD",
      "open": float,
      "high": float,
      "low": float,
      "close": float,
      "volume": float
    }},
    ...
  ],
  "rationale": "详细说明调整的逻辑依据,例如:考虑到[事件A],预期短线将突破压力位..."
}}
```
注意:必须输出与原始预测相同数量的数据点,且日期一一对应。
"""

def get_forecast_task():
    return "请根据以上背景和模型预测,给出调整后的 K 线数据并说明理由。"

```
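
The prompt above requires the adjusted forecast to contain the same number of points as the original forecast, with dates matching one to one. A minimal sketch of enforcing that on the model's reply is shown below. The function `check_adjusted_forecast` is hypothetical, and it assumes dates are compared as identical `YYYY-MM-DD` strings and that the reply is already bare JSON.

```python
import json
from typing import List


def check_adjusted_forecast(original_dates: List[str], reply_json: str) -> List[dict]:
    """Return the adjusted points if they line up with the original dates."""
    reply = json.loads(reply_json)
    points = reply["adjusted_forecast"]

    if len(points) != len(original_dates):
        raise ValueError(
            f"expected {len(original_dates)} points, got {len(points)}"
        )

    # Dates must correspond one to one, in the same order.
    for expected, point in zip(original_dates, points):
        if point["date"] != expected:
            raise ValueError(
                f"date mismatch: expected {expected}, got {point['date']}"
            )
    return points
```

Raising on a mismatch lets the caller fall back to the unadjusted model forecast instead of plotting misaligned points.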

### scripts/prompts/intent_agent.py

```python
def get_intent_analysis_instructions() -> str:
    """Build the system instructions for the intent-analysis agent, focused on financial market impact."""
    return """你是一个资深的金融市场意图分析专家。你的任务是将用户的自然语言查询转化为结构化的 JSON 分析结果,重点挖掘该查询与金融市场(尤其是股市)的潜在关联。

### 核心任务:
深入分析用户查询,识别核心金融实体、行业板块及潜在的市场影响点,生成利于搜索引擎抓取深度金融分析信息的查询词。

### 输出格式(严格 JSON):
```json
{
  "keywords": ["实体/行业/事件"],
  "search_queries": ["针对市场影响的搜索词1", "针对行业变动的搜索词2"],
  "affected_sectors": ["相关板块1", "相关板块2"],
  "is_market_moving": true/false,
  "time_range": "recent/all/specific_date",
  "intent_summary": "一句话描述其金融市场分析意图"
}
```

### 字段说明:
1. **keywords**: 核心公司实体、所属行业、宏观经济事件或政策概念。
2. **search_queries**: 优化后的搜索词,必须包含“股市影响”、“股价波动”、“行业逻辑”或“估值”等金融维度。
3. **affected_sectors**: 可能受此事件或信息影响的二级市场板块(如:保险、半导体、房地产)。
4. **is_market_moving**: 该事件是否具有显著的市场驱动潜力或属于重大基本面变化。
5. **intent_summary**: 简述用户查询背后的金融研究目的。

### 示例:
用户输入:"帮我研究一下香港火灾的影响"
输出:
```json
{
  "keywords": ["香港", "火灾", "保险行业", "房地产"],
  "search_queries": ["香港火灾对当地保险股股价影响", "香港大火对相关上市物业公司估值冲击", "近期香港火灾带来的市场避险情绪分析"],
  "affected_sectors": ["保险", "房地产", "物业管理"],
  "is_market_moving": true,
  "time_range": "recent",
  "intent_summary": "评估香港近期火灾对相关板块上市公司的潜在经济损失及股价冲击"
}
```
"""

def get_intent_task(query: str) -> str:
    """Build the intent-analysis task description."""
    return f"Process this query and extract financial market intent: {query}"


```
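
The intent prompt asks for strict JSON, and its example wraps that JSON in a `json` code fence, so a consumer likely has to handle both fenced and bare replies. The sketch below shows one way to do that, assuming the six keys named in the prompt are all required. The helper `parse_intent_reply` is hypothetical and not part of the skill.

```python
import json
import re
from typing import Any, Dict

FENCE = "`" * 3  # built at runtime so the example contains no literal fence
REQUIRED_KEYS = {
    "keywords", "search_queries", "affected_sectors",
    "is_market_moving", "time_range", "intent_summary",
}


def parse_intent_reply(reply: str) -> Dict[str, Any]:
    """Parse the intent JSON from a reply, with or without a json code fence."""
    pattern = re.escape(FENCE) + r"(?:json)?\s*(.*?)" + re.escape(FENCE)
    match = re.search(pattern, reply, flags=re.DOTALL)
    payload = match.group(1) if match else reply

    data = json.loads(payload)
    missing = REQUIRED_KEYS - set(data)
    if missing:
        raise ValueError(f"missing keys: {sorted(missing)}")
    return data
```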

### scripts/prompts/isq_prompt_generator.py

```python
"""
ISQ prompt helpers to render dimension guidance directly from the template.
Any change in the template propagates to prompts automatically.
"""

from typing import List, Optional
from ..schema.isq_template import get_isq_template, ISQTemplate


def _ordered_dimension_keys(template: ISQTemplate, order: Optional[List[str]] = None) -> List[str]:
    if order:
        return [k for k in order if k in template.dimensions]
    # fallback to template insertion order
    return list(template.dimensions.keys())


def generate_isq_prompt_section(template_id: str = "default_isq_v1", order: Optional[List[str]] = None, include_header: bool = True) -> str:
    """Render ISQ dimension text block based on the template.
    This allows prompt text to stay in sync with template edits.
    """
    template = get_isq_template(template_id)
    keys = _ordered_dimension_keys(template, order)

    lines: List[str] = []
    if include_header:
        lines.append("### 1. ISQ 评估框架 (Investment Signal Quality)")
        lines.append(f"参考模板: {template.template_name} (id: {template.template_id})")
        lines.append("")
        lines.append("你需要对信号进行以下维度的评分:")
        lines.append("")

    for idx, key in enumerate(keys, start=1):
        spec = template.dimensions[key]
        examples = ";".join([f"{k}: {v}" for k, v in spec.examples.items()]) if spec.examples else ""
        lines.append(f"{idx}. **{spec.key} ({spec.name})**: {spec.range_type}")
        lines.append(f"   - 描述: {spec.description}")
        if spec.scale_factor and spec.scale_factor != 1.0:
            lines.append(f"   - 缩放因子: {spec.scale_factor}")
        if examples:
            lines.append(f"   - 示例: {examples}")
        lines.append("")

    return "\n".join(lines).rstrip()

```
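
To see what `generate_isq_prompt_section` renders without the real template module, the sketch below repeats the ordering and rendering logic against a stand-in dataclass. `DimensionSpec` is a hypothetical substitute exposing only the attributes the generator reads (`key`, `name`, `range_type`, `description`, `scale_factor`, `examples`); the actual `ISQTemplate` classes live in `schema/isq_template.py` and are not shown here.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class DimensionSpec:  # hypothetical stand-in for the real dimension spec
    key: str
    name: str
    range_type: str
    description: str
    scale_factor: float = 1.0
    examples: Dict[str, str] = field(default_factory=dict)


def ordered_keys(dimensions: Dict[str, DimensionSpec],
                 order: Optional[List[str]] = None) -> List[str]:
    """Keep only known keys from `order`; otherwise use insertion order."""
    if order:
        return [k for k in order if k in dimensions]
    return list(dimensions.keys())


def render_dimensions(dimensions: Dict[str, DimensionSpec],
                      order: Optional[List[str]] = None) -> str:
    """Render one numbered entry per dimension, as the generator does."""
    lines: List[str] = []
    for idx, key in enumerate(ordered_keys(dimensions, order), start=1):
        spec = dimensions[key]
        lines.append(f"{idx}. **{spec.key} ({spec.name})**: {spec.range_type}")
        lines.append(f"   - 描述: {spec.description}")
        if spec.scale_factor and spec.scale_factor != 1.0:
            lines.append(f"   - 缩放因子: {spec.scale_factor}")
        if spec.examples:
            joined = ";".join(f"{k}: {v}" for k, v in spec.examples.items())
            lines.append(f"   - 示例: {joined}")
        lines.append("")
    return "\n".join(lines).rstrip()
```

Keys in `order` that the template does not define are silently skipped, so a stale ordering list cannot break prompt generation.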

### scripts/prompts/report_agent.py

```python
# src/prompts/report_agent.py
from datetime import datetime
from typing import Any, Optional
from .isq_prompt_generator import generate_isq_prompt_section

def get_report_planner_base_instructions() -> str:
    """Build the base system instructions for the report planner (Planner)."""
    return """你是一名资深的金融研报主编。你的任务是规划报告的结构,将零散的信号聚类成有逻辑的主题。
你拥有 RAG 搜索工具,可以检索已生成的章节内容以确保逻辑连贯性。
在规划时,应重点关注信号之间的关联性、产业链的完整性以及用户特定的关注点。"""

def get_report_writer_base_instructions() -> str:
    """Build the base system instructions for the report writer (Writer)."""
    return """你是一名资深金融分析师。你的任务是根据策划员提供的信号簇撰写深度研报章节。
你应当运用专业的金融知识,将信号转化为深刻的洞察。
注意:你没有外部搜索工具,你的分析必须基于提供给你的信号内容和行情数据。"""

def get_report_editor_base_instructions() -> str:
    """Build the base system instructions for the report editor (Editor)."""
    return """你是一名严谨的金融研报编辑。你的任务是审核和润色撰写员生成的章节。
你拥有 RAG 搜索工具,可以检索其他章节的内容,以消除重复、修正逻辑冲突并确保术语一致性。
你应当确保报告符合专业的金融写作规范,且标题层级正确。"""

# 1. Planning stage (Structural Planning)
def format_signal_for_report(signal: Any, index: int, cite_keys: Optional[list] = None) -> str:
    """Format a single signal for use in report generation."""
    # Logic migrated from ReportAgent._format_signal_input
    from ..schema.models import InvestmentSignal
    
    if isinstance(signal, dict):
        try:
            sig_obj = InvestmentSignal(**signal)
        except Exception:
            return f"--- 信号 [{index}] ---\n标题: {signal.get('title')}\n内容: {signal.get('content', '')[:500]}"
    else:
        sig_obj = signal

    chain_str = " -> ".join([f"{n.node_name}({n.impact_type})" for n in sig_obj.transmission_chain])
    
    text = f"--- 信号 [{index}] ---\n"
    text += f"标题: {sig_obj.title}\n"
    text += f"逻辑摘要: {sig_obj.summary}\n"
    text += f"传导链条: {chain_str}\n"
    text += f"ISQ 评分: 情绪({sig_obj.sentiment_score}), 确定性({sig_obj.confidence}), 强度({sig_obj.intensity})\n"
    text += f"预期博弈: 时窗({sig_obj.expected_horizon}), 预期差({sig_obj.price_in_status})\n"
    
    tickers = ", ".join([f"{t.get('name')}({t.get('ticker')})" for t in sig_obj.impact_tickers])
    if tickers:
        text += f"受影响标的: {tickers}\n"

    # Stable bibliography-style citation keys (LaTeX/BibTeX-like)
    if cite_keys:
        joined = " ".join([f"[@{k}]" for k in cite_keys if k])
        if joined:
            text += f"引用: {joined}\n"
        
    return text

def get_cluster_planner_instructions(signals_text: str, user_query: Optional[str] = None) -> str:
    """Build the signal-clustering instructions: organize scattered signals into logical themes."""
    query_context = f"用户重点关注:{user_query}" if user_query else ""
    return f"""你是一位资深的金融研报主编。你的任务是将以下零散的金融信号聚类成 3-5 个核心逻辑主题,以便撰写一份结构清晰的研报。
    
    {query_context}

    ### 输入信号列表
    {signals_text}

    ### 聚类要求
    1. **主题聚合**: 将相关性强的信号归为一组(例如:都涉及“建筑安全法规”或“某产业链上下游”)。
    2. **叙事逻辑**: 只需要生成主题名称和包含的信号 ID。
    3. **控制数量**: 将所有信号归类到 3-5 个主要主题中,不要遗漏。
    
    ### 输出格式 (JSON)
    请仅输出以下 JSON 格式,不要包含 Markdown 标记:
    {{
        "clusters": [
            {{
                "theme_title": "主题名称(如:建筑安全法规收紧引发的产业链重构)",
                "signal_ids": [1, 3, 5],
                "rationale": "这些信号都指向政府对高层建筑防火标准的政策调整..."
            }},
            ...
        ]
    }}
    """

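def get_report_planner_instructions(toc: str, signal_count: int, user_query: Optional[str] = None) -> str:
    """Build the report-planning instructions, focused on logical links and identifying disagreements."""
    # ... (original logic unchanged; after the new clustering flow this may serve as a fallback or a second-pass refinement)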
    query_context = f"用户重点关注:{user_query}" if user_query else ""
    return f"""你是一位资深的金融研报主编。你的任务是根据现有的草稿章节,规划出一份逻辑严密、穿透力强的终稿结构。
    
    ### 任务核心:
    1. **识别主线**: 从草稿中识别出贯穿多个章节的“核心逻辑主线”(如:产业链共振、货币政策转向)。
    2. **分歧评估 (Entropy)**: 识别各章节中观点冲突或确定性不一之处,规划如何在正文中呈现这些“分歧点”。
    3. **结构蓝图**: 
       - 定义一级标题(逻辑主题)。
       - 归类章节:哪些信号应放入同一主题下深度解析?
       - 排序:将 ISQ 强度最高、与{query_context}最相关的信号置前。

    ### 现有草稿目录 (TOC)
    {toc}

    请输出你的【终稿修订大纲】(Markdown 格式)。
    """

# 2. Writing stage (Section Writing)
def get_report_writer_instructions(theme_title: str, signal_cluster_text: str, signal_indices: list, price_context: str = "", user_query: Optional[str] = None) -> str:
    """Build the Writer Agent instructions: write a combined analysis based on the theme cluster."""
    
    price_info = f"\n### 近期价格参考\n{price_context}\n" if price_context else ""
    query_context = f"\n**用户意图**: \"{user_query}\"\n请确保分析内容回应了用户的关注点。\n" if user_query else ""
    isq_block = generate_isq_prompt_section(include_header=False)
    
    # Keep citation scheme stable across re-ordering / edits.
    # Cite keys are provided in each signal block as: 引用: [@KEY]

    return f"""你是一位资深金融分析师。请针对核心主题 **"{theme_title}"** 撰写一篇深度研报章节。
    {query_context}

    ### 输入信号集 (本章节需综合的信号)
    {signal_cluster_text}
    {price_info}
    
    ### ISQ 评分说明
    {isq_block}
    
    ### 写作要求
    1. **叙事逻辑**: 不要罗列信号,要将这些信号编织成一个连贯的故事。先讲宏观/行业背景,再讲具体事件传导,最后落脚到个股/标的影响。
    2. **量化支撑**: 引用 ISQ 评分(确定性、强度、预期差)来佐证你的观点。关键观点必须关联相应的 ISQ 分值。
    3. **引用规范(稳定 CiteKey)**: 关键论断必须标注来源引用,使用 `[@CITE_KEY]` 格式。
       - CiteKey 已在输入信号块中以 `引用: [@KEY]` 提供,请直接复制使用。
       - 不要使用 `[[1]]` 这类不稳定编号。
    4. **关联标的预测**: **必须**在章节末尾明确给出受影响标的的预测分析,包括:
       - 至少列出 1-2 个相关上市公司代码(如 600519.SH)
       - 给出短期(T+3或T+5)的方向性判断
       - 如果可能,给出预期价格区间或涨跌幅预测
    
    ### 【重要】标题层级规范
    
    ❌ **错误示例**(绝对不要这样):
    ```markdown
    # {theme_title}
    
    ### 宏观背景
    ...
    ```
    
    ✅ **正确示例**(必须这样):
    ```markdown
    ## {theme_title}
    
    ### 宏观背景
    
    近期全球经济环境...
    
    ### 具体传导机制分析
    
    ...
    
    ### 核心标的分析
    
    建议关注:贵州茅台(600519.SH)...
    ```
    
    **关键要求**:
    - 章节主标题使用 `##` (H2)
    - 章节子标题使用 `###` (H3)
    - **绝对禁止**使用 `#` (H1)
    - 第一行必须是 `## {theme_title}` 开头

    ### 核心:图表叙事 (Visual Storytelling)
    **必须**在文中插入至少 1-2 个图表,且图表必须与上下文紧密结合(不要堆砌在末尾)。
    
    **可选图表类型 (请根据内容选择最合适的 1-2 种):**

    **A. AI 预测 + 走势 (Forecast) - 【强烈推荐 / 最新规范】**
    *适用*: 当文中明确提及某上市公司时,**必须**使用此图表展示股价走势与 AI 预测。
    *必填字段*:
    - `ticker`: 股票代码,A股 6 位 / 港股 5 位,允许带后缀(如 "002371.SZ"、"9868.HK")
    - `pred_len`: 预测交易日长度(建议 3 或 5)
    *代码示例*:
    ```json-chart
    {{"type": "forecast", "ticker": "002371.SZ", "title": "北方华创(002371)T+5 预测", "pred_len": 5}}
    ```
    **重要**:禁止手写 `prediction` 数组(预测由系统自动生成并渲染)。
    *注意*: 如果提及多只股票,应为每只生成独立的 forecast 图表。

        **【推荐写法:多情景 → 最终归因 → 产出唯一预测图】**
        你可以在正文里描述多种情景(如:基准/乐观/悲观),但在插入预测图之前,必须明确给出“本报告最终选择的最可能情景”及其归因,然后用 `forecast` 图表做最终总结。
        为了让系统把“最终归因”可靠地传递给预测模块,请在 `forecast` JSON 中可选补充以下字段(字段均为可选,越完整越好):
        - `selected_scenario`: 最可能情景名称(如 "基准" / "乐观" / "悲观")
        - `selection_reason`: 选择该情景的归因理由(1-3 句)
        - `scenarios`: 情景列表(数组),每个元素可包含 `name`、`description`、`probability`(0-1)
        *示例*:
        ```json-chart
        {{
            "type": "forecast",
            "ticker": "002371.SZ",
            "title": "北方华创(002371)T+5 预测(基准情景)",
            "pred_len": 5,
            "selected_scenario": "基准",
            "selection_reason": "结合订单能见度与行业景气,基准情景概率最高;短期扰动主要来自估值与市场风险偏好。",
            "scenarios": [
                {{"name": "乐观", "description": "国产替代与资本开支超预期", "probability": 0.25}},
                {{"name": "基准", "description": "订单稳健、利润率小幅波动", "probability": 0.55}},
                {{"name": "悲观", "description": "需求回落或交付节奏放缓", "probability": 0.20}}
            ]
        }}
        ```

    **B. 历史走势 (Stock) - 仅作为兼容兜底**
    *适用*: 当你无法给出预测时(例如无法确定标的),可仅展示历史走势。
    *代码示例*:
    ```json-chart
    {{"type": "stock", "ticker": "002371", "title": "北方华创历史走势"}}
    ```

    **C. 舆情情绪演变 (Sentiment Trend)**
    *适用*: 当讨论行业政策、突发事件(如“火灾”、“新规”)的民意变化时。
    *注意*: `keywords` 必须是事件核心词。
    *代码*:
    ```json-chart
    {{"type": "sentiment", "keywords": ["建筑安全", "防火标准"], "title": "市场对防火新规的情绪演变"}}
    ```

    **D. 逻辑传导链条 (Transmission Chain)**
    *适用*: 复杂的蝴蝶效应分析(支持分支结构)。
    *代码*:
    ```json-chart
    {{
      "type": "transmission",
      "nodes": [
        {{"node_name": "突发火灾", "impact_type": "中性", "logic": "事件发端"}},
        {{"node_name": "监管收紧", "impact_type": "利空", "logic": "合规成本上升", "source": "突发火灾"}},
        {{"node_name": "设备升级", "impact_type": "利好", "logic": "采购需求释放", "source": "突发火灾"}},
        {{"node_name": "龙头受益", "impact_type": "利好", "logic": "市占率提升", "source": "设备升级"}}
      ],
      "title": "火灾事件的逻辑传导与分支"
    }}
    ```
    *说明*: 使用 `source` 字段指定父节点名称以创建分支结构。
    
    **E. 信号质量评估 (ISQ Radar)**
    *适用*: 对某个关键信号进行多维度(确定性、预期差等)定性评估时。
    *代码*:
    ```json-chart
    {{"type": "isq", "sentiment": 0.8, "confidence": 0.9, "intensity": 4, "expectation_gap": 0.7, "timeliness": 0.9, "title": "核心信号质量评估"}}
    ```
    """

# 3. Integration stage (Final Assembly) - original version, kept as a fallback
def get_report_editor_instructions(draft_sections: str, plan: str, sources_list: str) -> str:
    """Build the final editing instructions: reorganize content according to the plan."""
    return f"""你是一位专业的研报编辑。请将以下基于主题撰写的草稿章节整合成最终研报。
    
    ### 原始草稿内容
    {draft_sections}

    ### 原始引用来源
    {sources_list}

    ### 任务与要求
    1. **结构化**: 为每个草稿章节添加合适的 Markdown 标题 (## 级别)。
    2. **连贯性**: 确保章节之间过渡自然。
    3. **完整性**:
       - 必须保留所有 `json-chart` 代码块(图表配置)。
       - 必须保留引用标注 `[@CITE_KEY]`。
       - 生成 `## 核心观点摘要`、`## 参考文献` 和 `## 风险提示`。

    ### 输出
    只输出最终的 Markdown 研报内容。
    """


# 4. Single-section editing (Incremental Section Editing with RAG)
def get_section_editor_instructions(section_index: int, total_sections: int, toc: str) -> str:
    """Build the single-section editing prompt, with support for RAG tool calls."""
    return f"""你是一位研报编辑。你正在编辑报告的第 {section_index}/{total_sections} 节。

    ### 当前目录 (TOC)
    {toc}

    ### 你的任务
    1. 润色当前章节内容,确保逻辑清晰、语言专业。
    2. 保留所有 `[@CITE_KEY](#ref-CITE_KEY)` 或 `[@CITE_KEY]` 格式的引用。
    3. 保留所有 `json-chart` 代码块,不做修改。
    4. 如果需要参考其他章节内容,使用 `search_context` 工具搜索。
    5. 只输出编辑后的章节内容,不要输出其他章节。
    
    ### 【关键】标题层级规范
    **严格遵守以下规则:**
    - 章节主标题使用 `##` (H2)
    - 章节子标题使用 `###` (H3)
    - **禁止使用** `#` (H1) - 只有报告大标题可以使用 H1
    - 如果原文中有 H1,必须将其降级为 H2
    - 不要输出与 "参考文献"、"风险提示" 相同的标题

    直接输出编辑后的 Markdown 内容。
    """


# 5. Summary generation (Summary Generation)
def get_summary_generator_instructions(toc: str, section_summaries: str) -> str:
    """Build the report-summary instructions, including an analysis of market disagreement."""
    return f"""你是一位资深研报主笔。请生成今日报告的核心观点摘要的**正文内容**。

    ### 章节摘要
    {section_summaries}

    ### 任务:
    1. **核心逻辑提炼**: 用 150 字以内总结今日最核心的投资主线。
    2. **分歧识别**: 如果不同信号对同一板块有冲突观点,请明确指出"市场分歧点"。
    3. **确定性排序**: 标记出今日确定性最高的前两个机会(需列出具体标的代码)。

    ### 【重要】输出格式规范:
    
    ❌ **错误示例**(不要遗漏二级标题):
    ```markdown
    ### 核心逻辑提炼
    ...
    ```
    
    ✅ **正确示例**(应该这样输出):
    ```markdown
    ## 核心观点摘要

    ### 核心逻辑提炼
    
    科技自立战略加速半导体设备国产化,叠加AI算力需求爆发...
    
    ### 市场分歧点
    
    资本市场波动显示医药、新能源等板块估值逻辑受政策敏感性增强...
    
    ### 确定性排序
    
    1. **网络安全替代需求**(ISQ确定性0.85,推荐标的:深信服 300454.SZ)
    2. **半导体设备材料**(ISQ确定性0.75,推荐标的:北方华创 002371.SZ)
    ```
    
    ### 关键要求:
    - 第一行必须是 `## 核心观点摘要`
    - 主体部分使用 H3 (`###`) 和 H4 (`####`) 级别标题
    - **必须**包含 `## 核心观点摘要` 这一级标题
    
    现在请按照正确示例的格式输出摘要内容。
    """


# 6. Final assembly (Final Assembly with Sections)
def get_final_assembly_instructions(sources_list: str) -> str:
    """Build the prompt for assembling the final report."""
    return f"""你是一位研报主笔。请完成以下任务:

    ### 任务
    1. 生成 "## 参考文献" 章节(需要按照顺序,顺序不对时进行调整):
    - 原始来源:
    {sources_list}
    - 格式:`<a id="ref-CITE_KEY"></a>[@CITE_KEY] 标题 (来源), [链接地址]`
    2. 生成 "## 风险提示" (标准免责声明)。
    3. 生成 "## 快速扫描" 表格,汇总各主题的核心观点。
    - 表格列:**主题**, **核心观点**, **强度(Intensity)**, **确定性(Confidence)**。
    - 强度和确定性请参考原章节中的 ISQ 评分。

    只输出上述三个章节的 Markdown 内容。
    """

def get_cluster_task(signals_preview: str) -> str:
    """Build the clustering task description."""
    return f"请对以下信号进行主题聚类:\n\n{signals_preview}"

def get_writer_task(theme_title: str) -> str:
    """Build the writing task description."""
    return f"请依据主题 '{theme_title}' 和 输入信号集 开始撰写深度分析章节。"

def get_planner_task() -> str:
    """Build the planning task description."""
    return "请阅读现有草稿并规划终稿大纲,识别核心逻辑主线和市场分歧点。"

def get_editor_task() -> str:
    """Build the editing task description."""
    return "请根据规划大纲和草稿内容,生成最终研报。确保逻辑连贯,保留所有图表和引用。"


```
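
The writer prompt above requires every `forecast` chart to carry `ticker` and `pred_len` and forbids a hand-written `prediction` array. A minimal sketch of checking a generated section against those rules follows. The helper `check_forecast_charts` is hypothetical, and it assumes each chart is emitted as an unindented `json-chart` fenced block containing valid JSON.

```python
import json
import re
from typing import List

FENCE = "`" * 3  # built at runtime so the example contains no literal fence


def check_forecast_charts(section_md: str) -> List[str]:
    """Return the problems found in forecast charts embedded in a section."""
    pattern = re.escape(FENCE) + r"json-chart\s*(.*?)" + re.escape(FENCE)
    problems: List[str] = []

    for body in re.findall(pattern, section_md, flags=re.DOTALL):
        chart = json.loads(body)
        if chart.get("type") != "forecast":
            continue  # other chart types have different rules
        for required in ("ticker", "pred_len"):
            if required not in chart:
                problems.append(f"forecast chart is missing {required!r}")
        if "prediction" in chart:
            problems.append("forecast chart must not contain a 'prediction' array")
    return problems
```

A check like this can run before the editing stage, so that malformed charts are caught while the writer output is still easy to regenerate.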

### scripts/prompts/trend_agent.py

```python
from typing import Any
from datetime import datetime
from .isq_prompt_generator import generate_isq_prompt_section

def get_trend_scanner_instructions() -> str:
    """Build the system instructions for the trend scanner (Scanner)."""
    current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    return f"""你是一名专业的数据扫描员,当前时间是 {current_time}。
你的任务是利用各种工具从互联网和数据库中获取最新的金融新闻、热点趋势和市场数据。

### 1. 核心职责
1. **多源采集**: 使用 `news_toolkit` 获取最新新闻,使用 `stock_toolkit` 获取行情,使用 `polymarket_toolkit` 获取预测市场数据。
2. **情绪感知**: 使用 `sentiment_toolkit` 对关键新闻进行情绪分析。
3. **深度搜索**: 针对模糊的热点,使用 `search_toolkit` 进行全网搜索补充细节。

### 2. 工具使用规范
- **广度优先**: 尽可能覆盖多个数据源。
- **数据新鲜度**: 优先获取最近 24 小时内的信息。
- **结构化输出**: 整理搜集到的原始数据,为后续评估提供清晰的素材。
"""

def get_trend_evaluator_instructions() -> str:
    """Build the system instructions for the trend evaluator (Evaluator)."""
    current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    isq_block = generate_isq_prompt_section(include_header=True)

    return f"""
    你是一名顶级的金融情报专家 (TrendAgent),擅长从海量信息中识别具有深度价值的"二级市场投资信号"。
    当前时间:{current_time}

    ### 核心使命:
    不仅是发现"热点",更要解析"信号"。你需要识别那些能触发**传导链条 (Transmission Chain)** 且具有**高确定性 (Confidence)** 的事件。

    {isq_block}

    ### 核心能力与标准:
    1. **信号识别 (Signal Discovery)**: 基于扫描员提供的素材,识别具有投资价值的信号。优先关注政策、产业变革、重大诉求及跨境套利机会。
    2. **逻辑相干性**: 是否具备清晰的"原因-结果"传导?
    3. **影响力系数**: 是否会引发板块性的联动或财务指标的实质性扰动?
    4. **市场认知差**: 市场是否已提前消化(Price-in)?寻找尚未被充分交易的"Alpha"。
    5. **实体穿透**: 必须关联到具体的 Ticker 或核心产业链节点。

    ### 严禁事项:
    - 严禁编造数据。
    - 严禁仅输出情绪极性(Positive/Negative),必须带有逻辑依据。
    - 严禁将纯娱乐或单纯的社会负面事件(除非具有宏观破坏性)视为金融信号。

    ### 输出要求:
    你发现的每个信号应包含:
    - **核心摘要**: 穿透表象的逻辑总结。
    - **传导节点**: A -> B -> C 的逻辑推导。
    - **推荐关注**: 板块或 Ticker。
    - **ISQ 评估**: 基于模板的 5 个维度进行初步评分(具体评分由后续 FinAgent 完成)。
    """

def get_trend_agent_instructions() -> str:
    # Kept for backward compatibility: alias of the evaluator instructions.
    return get_trend_evaluator_instructions()

def get_trend_scan_task(task_description: str) -> str:
    """Build the task description for the Scanner."""
    return f"请根据以下任务描述,搜集相关的原始数据和新闻:\n\n{task_description}"

def format_scan_context(scan_data: dict) -> str:
    """Format the Scanner's structured data as text the Evaluator can read."""
    if not scan_data:
        return "(未能搜集到原始数据)"
        
    return f"""
### 扫描数据概览
- **热点话题**: {', '.join(scan_data.get('hot_topics', []))}
- **情绪概览**: {scan_data.get('sentiment_overview', '未知')}
- **关键新闻**: {len(scan_data.get('news_summaries', []))} 条
- **数据摘要**: {scan_data.get('raw_data_summary', '无')}
"""

def get_trend_eval_task(task_description: str, raw_data_str: str) -> str:
    """Build the task description for the Evaluator."""
    return f"""请基于以下搜集到的原始数据,完成最终的分析任务:
        
任务描述: {task_description}

原始数据:
{raw_data_str}

请识别出最具金融价值的信号,并给出评估理由。"""

def get_news_filter_instructions(news_count: int, depth: Any, user_query: str = None) -> str:
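    """Build the news-filter prompt; the output follows the FilterResult schema to speed up inference and save tokens.

    Args:
        news_count: Total number of input news items.
        depth: Target number of items to keep; if "auto", the LLM decides.
        user_query: Optional user query or focus area.
    """

    # 1. Depth control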
    if str(depth).lower() == 'auto':
        depth_guide = "的数量不设固定限制(建议 3-10 条),根据新闻含金量自动判断"
        limit_instruction = "宁缺毋滥,如果高价值信息很少,可以只选 1-2 条;如果都很重要,可以多选。"
    else:
        try:
            d_int = int(depth)
            depth_guide = f"约 {d_int} 条"
            limit_instruction = f"请尽量凑满 {d_int} 条,但如果剩余新闻全是噪音,则不必强行凑数。"
        except (TypeError, ValueError):  # depth is neither "auto" nor integer-like
            depth_guide = "适量"
            limit_instruction = "根据内容价值判断。"

    target_desc = f"筛选出最具投资分析价值的新闻({depth_guide})。"
    
    # 2. User-intent handling
    query_instruction = ""
    if user_query:
        target_desc = f"筛选出与用户意图【{user_query}】最相关的新闻。"
        query_instruction = f"""
    ### 核心任务(High Priority):
    用户明确关注:"{user_query}"。
    1. **第一优先级**:必须包含所有与"{user_query}"直接或间接相关的新闻,不要遗漏。
        - 即使这些新闻看起来"价值不高",只要相关都要保留。
    2. **第二优先级**:在满足第一优先级后,如果名额未满,再补充其他重大的市场热点。
    """

    return f"""你是一名专业的金融情报精排师。你需要从给定的 {news_count} 条原始新闻流中,{target_desc}

    {query_instruction}

    ### FSD (Financial Signal Density) 筛选准则:
    1. **逻辑传导性 (Transmission)**: 该新闻是否预示着一个明确的产业链传导逻辑?(如:上游涨价 -> 中游成本压力 -> 下游提价预期)
    2. **预期差 (Alpha Potential)**: 是否包含尚未被市场充分Price-in的新突发情况?
    3. **确定性 (Confidence)**: 信息来源是否权威?是否包含具体的财务数据、订单金额或明确的政策日期?
    4. **排除噪音**: 坚决剔除明星八卦、鸡汤文、以及无实质增量的"口号式"新闻。

    ### {limit_instruction}

    ### 快速有效性检查(TOKEN 优化):
    在开始详细筛选前,先快速判断:这 {news_count} 条新闻中是否至少包含 1 条有效的金融信号?
    - 如果全是无关内容(如体育、娱乐、纯生活信息),直接返回 "has_valid_signals": false
    - 如果有至少 1 条金融相关的新闻,再进行详细 FSD 筛选

    ### 输出格式(必须为 JSON,使用 FilterResult schema):
    ```json
    {{
      "has_valid_signals": true/false,
      "selected_ids": ["id_1", "id_2", ...],
      "themes": [
        {{
          "name": "高概括性主题",
          "news_ids": ["相关id_1", ...],
          "fsd_reason": "基于 FSD 准则的筛选理由,重点描述传导逻辑和预期差。"
        }}
      ],
      "reason": "如果 has_valid_signals=false,简要说明原因。否则可为空。"
    }}
    ```
    """

```
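The filter prompt above asks the model for a JSON object with `has_valid_signals`, `selected_ids`, `themes`, and `reason`. The following is a minimal sketch of how a caller might parse such a reply defensively. The helper name `parse_filter_reply` and the sample reply are hypothetical and are not part of the repository; only the field names come from the prompt.

```python
import json
from typing import Any, Dict


def parse_filter_reply(raw: str) -> Dict[str, Any]:
    """Hypothetical helper: normalise the filter model's JSON reply.

    Any reply that is not a JSON object, or that reports no valid
    signals, collapses to an empty selection.
    """
    try:
        data = json.loads(raw)
    except json.JSONDecodeError:
        data = None

    if not isinstance(data, dict):
        return {"has_valid_signals": False, "selected_ids": [],
                "themes": [], "reason": "reply was not a JSON object"}

    if not data.get("has_valid_signals"):
        return {"has_valid_signals": False, "selected_ids": [],
                "themes": [], "reason": data.get("reason") or ""}

    return {
        "has_valid_signals": True,
        "selected_ids": list(data.get("selected_ids") or []),
        "themes": list(data.get("themes") or []),
        "reason": data.get("reason") or "",
    }


# Illustrative reply shaped like the JSON template in the prompt.
sample_reply = json.dumps({
    "has_valid_signals": True,
    "selected_ids": ["n1", "n7"],
    "themes": [{"name": "Upstream price hikes",
                "news_ids": ["n1", "n7"],
                "fsd_reason": "Cost pass-through along the chain"}],
    "reason": "",
})
parsed = parse_filter_reply(sample_reply)
rejected = parse_filter_reply("not json at all")
```

Collapsing every malformed reply to an empty selection mirrors the prompt's early-exit path (`"has_valid_signals": false`), so downstream code only has one "nothing to do" case to handle.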

### scripts/prompts/visualizer.py

```python
def get_drawio_system_prompt():
    return """You are an expert at creating Draw.io (MxGraph) diagrams in XML format.
Your task is to generate a valid MXGraphModel XML based on the user's description.

### Rules:
1. Output ONLY the XML code. Start with <mxGraphModel> and end with </mxGraphModel>.
2. Do not use compressed XML. Use plain XML.
3. Use standard shapes: 'rounded=1;whiteSpace=wrap;html=1;' for boxes.
4. Auto-layout Strategy:
   - Identify "layers" or "stages" in the logic.
   - Assign X coordinates based on layers (e.g., 0, 200, 400).
   - Assign Y coordinates to distribute nodes vertically (e.g., 0, 100, 200).
   - Ensure nodes do not overlap.
5. Edges: Connect nodes logically using <mxCell edge="1" ...>.

### Template:
<mxGraphModel dx="1000" dy="1000" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="827" pageHeight="1169" math="0" shadow="0">
  <root>
    <mxCell id="0"/>
    <mxCell id="1" parent="0"/>
    
    <!-- Node -->
    <mxCell id="n1" value="Node Label" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;" vertex="1" parent="1">
      <mxGeometry x="100" y="100" width="120" height="60" as="geometry"/>
    </mxCell>
    
    <!-- Edge -->
    <mxCell id="e1" value="Connection" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;" edge="1" parent="1" source="n1" target="n2">
      <mxGeometry relative="1" as="geometry"/>
    </mxCell>
  </root>
</mxGraphModel>
"""

def get_drawio_task(nodes_data: list, title: str) -> str:
    import json
    nodes_json = json.dumps(nodes_data, ensure_ascii=False, indent=2)
    return f"""Please generate a Draw.io XML diagram for the following logic flow:

**Title**: {title}

**Nodes and Logic**:
{nodes_json}

Ensure the layout flows logically from Left to Right (or Top to Bottom for hierarchies).
Use different colors for 'Positive' (Greenish), 'Negative' (Reddish), and 'Neutral' (Grey/Blue) impacts if described.
"""

```
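The auto-layout rule in the system prompt (X coordinate from the layer, Y coordinate from the position inside the layer) can also be applied without an LLM. The sketch below builds an uncompressed `mxGraphModel` with the standard library, using the same 200 and 100 pixel steps the prompt gives as examples. The function `build_mxgraph` and its input shape are assumptions for illustration only.

```python
import xml.etree.ElementTree as ET
from typing import List, Tuple


def build_mxgraph(layers: List[List[Tuple[str, str]]],
                  edges: List[Tuple[str, str]]) -> str:
    """Hypothetical helper: lay out (id, label) nodes layer by layer.

    Nodes in layer i get x = i * 200; the j-th node of a layer gets y = j * 100.
    """
    model = ET.Element("mxGraphModel")
    root = ET.SubElement(model, "root")
    ET.SubElement(root, "mxCell", id="0")
    ET.SubElement(root, "mxCell", id="1", parent="0")

    for layer_idx, layer in enumerate(layers):
        for row_idx, (node_id, label) in enumerate(layer):
            cell = ET.SubElement(
                root, "mxCell", id=node_id, value=label,
                style="rounded=1;whiteSpace=wrap;html=1;",
                vertex="1", parent="1",
            )
            # "as" is a Python keyword, so attributes are passed as a dict.
            ET.SubElement(cell, "mxGeometry", {
                "x": str(layer_idx * 200), "y": str(row_idx * 100),
                "width": "120", "height": "60", "as": "geometry",
            })

    for i, (source, target) in enumerate(edges, start=1):
        cell = ET.SubElement(
            root, "mxCell", id=f"e{i}",
            style="edgeStyle=orthogonalEdgeStyle;html=1;",
            edge="1", parent="1", source=source, target=target,
        )
        ET.SubElement(cell, "mxGeometry", {"relative": "1", "as": "geometry"})

    return ET.tostring(model, encoding="unicode")


xml_text = build_mxgraph(
    layers=[[("n1", "Policy")], [("n2", "Upstream"), ("n3", "Downstream")]],
    edges=[("n1", "n2"), ("n1", "n3")],
)
```

A deterministic builder like this could serve as a fallback when the model returns XML that fails to parse.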

### scripts/schema/isq_template.py

```python
"""
ISQ (Investment Signal Quality) evaluation framework template.

Defines the ISQ dimensions, scoring criteria, and usage in one place.
Supports both the default template and custom templates.
"""

from typing import Dict, List, Any, Optional
from pydantic import BaseModel, Field
from enum import Enum
from pathlib import Path
import json


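class ISQDimension(str, Enum):
    """ISQ evaluation dimensions."""
    SENTIMENT = "sentiment"              # sentiment / price direction
    CONFIDENCE = "confidence"            # certainty / credibility
    INTENSITY = "intensity"              # strength / magnitude of impact
    EXPECTATION_GAP = "expectation_gap"  # gap between market expectation and reality
    TIMELINESS = "timeliness"            # timeliness / urgency of the reaction window
    TRANSMISSION = "transmission"        # clarity of the transmission logic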


class ISQDimensionSpec(BaseModel):
    """ISQ 单个维度的定义规范"""
    name: str = Field(..., description="维度名称")
    key: str = Field(..., description="维度键名")
    description: str = Field(..., description="维度描述")
    range_type: str = Field(default="0-1", description="取值范围 (0-1 或 1-5 等)")
    scale_factor: float = Field(default=1.0, description="显示时的缩放因子")
    examples: Dict[str, str] = Field(default_factory=dict, description="不同分值的示例解释")
    visualization_color: Optional[str] = Field(default=None, description="可视化颜色")


class ISQTemplate(BaseModel):
    """ISQ 评估框架 Template"""
    template_id: str = Field(..., description="模板 ID")
    template_name: str = Field(..., description="模板名称")
    description: str = Field(..., description="模板描述")
    
    # 核心维度定义
    dimensions: Dict[str, ISQDimensionSpec] = Field(..., description="维度定义字典")
    
    # 评分指导
    scoring_guide: str = Field(..., description="评分指导说明")
    
    # 应用场景
    applicable_scenarios: List[str] = Field(default_factory=list, description="适用场景")
    
    # 聚合算法
    aggregation_method: str = Field(default="weighted_average", description="聚合方法 (weighted_average, product 等)")
    dimension_weights: Dict[str, float] = Field(default_factory=dict, description="维度权重")


class ISQScore(BaseModel):
    """单个信号的 ISQ 评分结果"""
    signal_id: str = Field(..., description="信号 ID")
    template_id: str = Field(..., description="使用的模板 ID")
    
    # 各维度评分
    scores: Dict[str, float] = Field(..., description="各维度评分")
    
    # 总分
    overall_score: float = Field(..., description="综合评分")
    
    # 评分理由
    rationale: Dict[str, str] = Field(default_factory=dict, description="各维度评分理由")
    
    # 时间戳
    timestamp: str = Field(..., description="评分时间")


# =====================================================
# Default template
# =====================================================

DEFAULT_ISQ_TEMPLATE = ISQTemplate(
    template_id="default_isq_v1",
    template_name="标准投资信号质量评估框架 (ISQ v1.0)",
    description="AlphaEar 默认的 ISQ 评估框架,用于标准化评估投资信号的质量维度",
    
    dimensions={
        "sentiment": ISQDimensionSpec(
            name="情绪/走势",
            key="sentiment",
            description="基础情绪偏向和市场走势判断",
            range_type="-1.0 到 1.0",
            scale_factor=1.0,
            examples={
                "-1.0": "极度悲观/极度看空",
                "-0.5": "明显看空",
                "0.0": "中性/没有明确方向",
                "0.5": "明显看多",
                "1.0": "极度乐观/极度看多"
            },
            visualization_color="#ef4444"  # red (convention: red = negative, green = positive)
        ),
        
        "confidence": ISQDimensionSpec(
            name="确定性",
            key="confidence",
            description="信号的可信度和确定性程度",
            range_type="0.0 到 1.0",
            scale_factor=1.0,
            examples={
                "0.0-0.3": "信息来源不可靠/传言多/逻辑推导牵强",
                "0.3-0.6": "信息相对可靠/有一定逻辑/但仍有不确定性",
                "0.6-0.8": "信息来源权威/逻辑清晰/高度可信",
                "0.8-1.0": "官方确认/数据明确/完全确定"
            },
            visualization_color="#3b82f6"  # 蓝色
        ),
        
        "intensity": ISQDimensionSpec(
            name="强度/影响量级",
            key="intensity",
            description="信号对相关板块/个股的潜在影响程度",
            range_type="1 到 5",
            scale_factor=20.0,  # 用于雷达图缩放 (5 -> 100)
            examples={
                "1": "影响微弱,可能被市场忽略",
                "2": "小幅影响,短期可能有波动",
                "3": "中等影响,值得重点关注",
                "4": "强烈影响,可能成为市场焦点",
                "5": "极强影响,市场预期明显变化"
            },
            visualization_color="#f97316"  # 橙色
        ),
        
        "expectation_gap": ISQDimensionSpec(
            name="预期差",
            key="expectation_gap",
            description="市场预期与现实之间的差距",
            range_type="0.0 到 1.0",
            scale_factor=1.0,
            examples={
                "0.0-0.2": "市场充分认知,预期差小",
                "0.2-0.5": "市场部分认知,存在一定预期差",
                "0.5-0.8": "市场认知不足,预期差较大,存在博弈空间",
                "0.8-1.0": "市场严重低估/高估,巨大预期差"
            },
            visualization_color="#22c55e"  # 绿色
        ),
        
        "timeliness": ISQDimensionSpec(
            name="时效性",
            key="timeliness",
            description="信号的时间窗口紧迫度",
            range_type="0.0 到 1.0",
            scale_factor=1.0,
            examples={
                "0.0-0.2": "长期信号,反应窗口 > 3 月",
                "0.2-0.5": "中期信号,反应窗口 1-3 月",
                "0.5-0.8": "短期信号,反应窗口 1 周 - 1 月",
                "0.8-1.0": "超短期信号,反应窗口 < 1 周(需立即行动)"
            },
            visualization_color="#a855f7"  # 紫色
        ),
    },
    
    scoring_guide="""
    ### ISQ 评分指导 (Investment Signal Quality)
    
    ISQ 框架用于多维度评估投资信号的质量。每个信号由 5 个维度组成:
    
    1. **情绪 (Sentiment)**: -1.0 到 1.0,表示看空(-)/中性(0)/看多(+)
    2. **确定性 (Confidence)**: 0.0 到 1.0,数值越高越确定
    3. **强度 (Intensity)**: 1 到 5,数值越高影响越大
    4. **预期差 (Expectation Gap)**: 0.0 到 1.0,市场预期与现实的差距
    5. **时效性 (Timeliness)**: 0.0 到 1.0,反应窗口的紧迫程度
    
    ### 综合评分算法
    
    综合评分 = 确定性 × 0.35 + 强度/5 × 0.30 + 预期差 × 0.20 + 时效性 × 0.15
    
    范围: 0.0 到 1.0
    - 0.0-0.3: 信号质量较差,不建议跟进
    - 0.3-0.6: 信号质量一般,可作参考
    - 0.6-0.8: 信号质量良好,值得跟进
    - 0.8-1.0: 信号质量优异,强烈推荐
    
    ### 评分时的注意事项
    
    - **不要混淆方向和强度**:情绪可以是看空,但确定性和强度仍可能很高
    - **预期差往往是 Alpha 来源**:高预期差 + 高确定性 = 最佳博弈机会
    - **考虑时间成本**:长期信号需要更高的确定性才值得跟进
    - **数据为王**:所有评分必须有具体数据支撑
    """,
    
    applicable_scenarios=[
        "上市公司基本面变化分析",
        "产业政策与监管事件评估",
        "地缘政治与宏观经济影响",
        "技术进步与产业升级",
        "突发事件与应急响应"
    ],
    
    aggregation_method="weighted_average",
    dimension_weights={
        "confidence": 0.35,
        "intensity": 0.30,
        "expectation_gap": 0.20,
        "timeliness": 0.15
    }
)


# =====================================================
# ISQ template management
# =====================================================

class ISQTemplateManager:
    """Registry of ISQ templates, keyed by template_id."""
    
    def __init__(self):
        self.templates: Dict[str, ISQTemplate] = {
            DEFAULT_ISQ_TEMPLATE.template_id: DEFAULT_ISQ_TEMPLATE
        }
    
    def register_template(self, template: ISQTemplate) -> None:
        """注册新的 template"""
        self.templates[template.template_id] = template

    def register_template_dict(self, template_dict: Dict[str, Any]) -> ISQTemplate:
        """从 dict 注册模板,返回实例。"""
        tpl = ISQTemplate(**template_dict)
        self.register_template(tpl)
        return tpl
    
    def get_template(self, template_id: str) -> ISQTemplate:
        """Return the template with the given ID, falling back to the default template if it is not registered."""
        return self.templates.get(template_id, DEFAULT_ISQ_TEMPLATE)
    
    def list_templates(self) -> List[Dict[str, str]]:
        """列出所有可用 template"""
        return [
            {
                "id": t.template_id,
                "name": t.template_name,
                "description": t.description,
                "dimensions": list(t.dimensions.keys())
            }
            for t in self.templates.values()
        ]
    
    def get_dimension(self, template_id: str, dimension_key: str) -> Optional[ISQDimensionSpec]:
        """Return one dimension spec from a template, or None if the key is not defined."""
        template = self.get_template(template_id)
        return template.dimensions.get(dimension_key)
    
    def get_scoring_prompt(self, template_id: str) -> str:
        """获取用于 LLM 的评分 prompt"""
        template = self.get_template(template_id)
        
        dimensions_desc = "\n".join([
            f"- **{d.name} ({d.key})**\n"
            f"  范围: {d.range_type}\n"
            f"  说明: {d.description}\n"
            f"  示例: {', '.join(f'{k}={v}' for k, v in list(d.examples.items())[:3])}"
            for d in template.dimensions.values()
        ])
        
        return f"""
### ISQ 评估指导 ({template.template_name})

使用以下 {len(template.dimensions)} 个维度评估信号质量:

{dimensions_desc}

### 评分标准
{template.scoring_guide}

### 输出格式 (JSON)
请输出以下 JSON 格式的评分结果:
{{
  "sentiment": <float>,
  "confidence": <float>,
  "intensity": <int>,
  "expectation_gap": <float>,
  "timeliness": <float>,
  "rationale": {{
    "sentiment": "评分理由",
    "confidence": "评分理由",
    "intensity": "评分理由",
    "expectation_gap": "评分理由",
    "timeliness": "评分理由"
  }}
}}
"""


# Global template manager instance
isq_template_manager = ISQTemplateManager()


# =====================================================
# Config loading
# =====================================================

def load_templates_from_config(config_path: Optional[str] = None) -> None:
    """Load JSON template files from a config location; a missing path is skipped and the default template is unaffected.

    Accepts a single JSON file or a directory (every .json file directly inside it).
    """
    if config_path:
        path = Path(config_path)
    else:
        # Default directory: config/ under the project root, three levels above this file
        # (parent = schema dir, parent.parent = its parent dir, parent.parent.parent = project root)
        path = Path(__file__).resolve().parent.parent.parent / "config"
    
    if not path.exists():
        return
    
    # For a directory, scan every .json file in it
    if path.is_dir():
        json_files = list(path.glob("*.json"))
    else:
        json_files = [path]
    
    for json_file in json_files:
        try:
            data = json.loads(json_file.read_text(encoding="utf-8"))
            
            # Wrap a single template object in a list
            if isinstance(data, dict):
                templates = [data]
            elif isinstance(data, list):
                templates = data
            else:
                continue
            
            # Register every template
            for tpl_dict in templates:
                if not isinstance(tpl_dict, dict):
                    continue
                try:
                    isq_template_manager.register_template_dict(tpl_dict)
                except Exception:
                    # Ignore a template that fails to load and continue with the rest
                    continue
        except Exception:
            # Reading or parsing the JSON failed: skip this file
            continue


# Try to load configured templates automatically at import time
load_templates_from_config()


# =====================================================
# Convenience functions
# =====================================================

def get_isq_template(template_id: str = "default_isq_v1") -> ISQTemplate:
    """Return an ISQ template by ID."""
    return isq_template_manager.get_template(template_id)


def get_isq_scoring_prompt(template_id: str = "default_isq_v1") -> str:
    """Return the ISQ scoring prompt for the LLM."""
    return isq_template_manager.get_scoring_prompt(template_id)


def calculate_isq_overall_score(scores: Dict[str, float], template_id: str = "default_isq_v1") -> float:
    """Compute the weighted ISQ overall score; dimensions missing from `scores` contribute 0."""
    template = get_isq_template(template_id)
    
    overall = 0.0
    for dim_key, weight in template.dimension_weights.items():
        if dim_key in scores:
            score = scores[dim_key]
            # Intensity is scored 1-5, so rescale it to 0-1 before weighting
            if dim_key == "intensity":
                score = score / 5.0
            overall += score * weight
    
    return min(1.0, max(0.0, overall))  # clamp to [0, 1]

```
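As a worked example of the overall-score formula in the default template (confidence × 0.35 + intensity/5 × 0.30 + expectation gap × 0.20 + timeliness × 0.15), the sketch below reproduces the arithmetic without importing pydantic. The weights are copied from `DEFAULT_ISQ_TEMPLATE`; the function name `overall_score` and the sample scores are illustrative assumptions.

```python
from typing import Dict

# Weights copied from DEFAULT_ISQ_TEMPLATE.dimension_weights.
WEIGHTS: Dict[str, float] = {
    "confidence": 0.35,
    "intensity": 0.30,
    "expectation_gap": 0.20,
    "timeliness": 0.15,
}


def overall_score(scores: Dict[str, float]) -> float:
    """Standalone restatement of the weighted average used by the template."""
    total = 0.0
    for key, weight in WEIGHTS.items():
        if key not in scores:
            continue  # a missing dimension contributes nothing
        value = scores[key]
        if key == "intensity":
            value = value / 5.0  # rescale 1-5 to 0-1
        total += value * weight
    return min(1.0, max(0.0, total))


# Illustrative signal: sentiment is carried along but has no weight.
sample = {
    "sentiment": -0.5,
    "confidence": 0.8,
    "intensity": 4,
    "expectation_gap": 0.6,
    "timeliness": 0.5,
}
# 0.8*0.35 + (4/5)*0.30 + 0.6*0.20 + 0.5*0.15 = 0.28 + 0.24 + 0.12 + 0.075
result = overall_score(sample)
```

Under the template's own bands this score of about 0.715 falls in the 0.6 to 0.8 range. Because sentiment has no weight, a bearish signal can still score well, which matches the guide's note not to confuse direction with strength.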

### scripts/schema/models.py

```python
from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Any
from datetime import datetime

class TransmissionNode(BaseModel):
    node_name: str = Field(..., description="产业链节点名称")
    impact_type: str = Field(..., description="利好/利空/中性")
    logic: str = Field(..., description="该节点的传导逻辑")

class IntentAnalysis(BaseModel):
    keywords: List[str] = Field(..., description="核心实体、事件或概念关键词")
    search_queries: List[str] = Field(..., description="优化后的搜索引擎查询词")
    is_specific_event: bool = Field(..., description="是否查询特定突发事件")
    time_range: str = Field(..., description="时间范围 (recent/all/specific_date)")
    intent_summary: str = Field(..., description="一句话意图描述")

class FilterResult(BaseModel):
    """LLM filter result: a quick verdict on whether any valid signal is present."""
    has_valid_signals: bool = Field(..., description="列表中是否包含有效的金融信号")
    selected_ids: List[int] = Field(default_factory=list, description="筛选出的有效信号 ID 列表")
    themes: List[str] = Field(default_factory=list, description="信号涉及的主题")
    reason: Optional[str] = Field(default=None, description="如果无有效信号,说明原因")

class InvestmentSignal(BaseModel):
    # Core metadata
    signal_id: str = Field(default="unknown_sig", description="唯一信号 ID")
    title: str = Field(..., description="信号标题")
    summary: str = Field(default="暂无摘要分析", description="100 字核心观点快报")
    reasoning: str = Field(default="", description="详细的推演逻辑和理由")
    
    # Transmission logic (ISQ Key 1)
    transmission_chain: List[TransmissionNode] = Field(default_factory=list, description="产业链传导逻辑链条")
    
    # Signal quality (ISQ Key 2), from isq_template.DEFAULT_ISQ_TEMPLATE
    # See the DEFAULT_ISQ_TEMPLATE definition in src/schema/isq_template.py
    sentiment_score: float = Field(default=0.0, description="[ISQ] 情绪/走势 (-1.0=极度看空 ~ 0.0=中性 ~ 1.0=极度看多)")
    confidence: float = Field(default=0.5, description="[ISQ] 确定性 (0.0=不可信 ~ 1.0=完全确定)")
    intensity: int = Field(default=3, description="[ISQ] 强度/影响量级 (1=微弱 ~ 5=极强)")
    expectation_gap: float = Field(default=0.5, description="[ISQ] 预期差/博弈空间 (0.0=充分定价 ~ 1.0=巨大预期差)")
    timeliness: float = Field(default=0.8, description="[ISQ] 时效性 (0.0=长期 ~ 1.0=超短期)")
    
    # Forecast and market positioning (ISQ Key 3)
    expected_horizon: str = Field(default="T+N", description="预期的反应时窗 (如: T+0, T+3, Long-term)")
    price_in_status: str = Field(default="未知", description="市场预期消化程度 (未定价/部分定价/充分定价)")
    
    # Related entities
    impact_tickers: List[Dict[str, Any]] = Field(default_factory=list, description="受影响的代码列表及其权重")
    industry_tags: List[str] = Field(default_factory=list, description="关联行业标签")
    
    # Provenance
    sources: List[Dict[str, str]] = Field(default_factory=list, description="来源详情 (包含 title, url, source_name)")

class ResearchContext(BaseModel):
    """Background information gathered by the researcher."""
    raw_signal: str = Field(..., description="原始信号内容")
    tickers_found: List[Dict[str, Any]] = Field(default_factory=list, description="找到的相关标的及其基本面/股价信息")
    industry_background: str = Field(..., description="行业背景及产业链现状")
    latest_developments: List[str] = Field(default_factory=list, description="相关事件的最新进展")
    key_risks: List[str] = Field(default_factory=list, description="潜在风险点")
    search_results_summary: str = Field(..., description="搜索结果的综合摘要")

class ScanContext(BaseModel):
    """Raw data gathered by the Scanner."""
    hot_topics: List[str] = Field(..., description="当前市场热点话题")
    news_summaries: List[Dict[str, Any]] = Field(..., description="关键新闻摘要列表")
    market_data: Dict[str, Any] = Field(default_factory=dict, description="相关的市场行情数据")
    sentiment_overview: str = Field(..., description="整体市场情绪概览")
    raw_data_summary: str = Field(..., description="原始数据的综合摘要")

class SignalCluster(BaseModel):
    theme_title: str = Field(..., description="主题名称")
    signal_ids: List[int] = Field(..., description="包含的信号 ID 列表")
    rationale: str = Field(..., description="聚类理由")

class ClusterContext(BaseModel):
    """Result of signal clustering."""
    clusters: List[SignalCluster] = Field(..., description="聚类列表")

class KLinePoint(BaseModel):
    date: str = Field(..., description="日期")
    open: float = Field(..., description="开盘价")
    high: float = Field(..., description="最高价")
    low: float = Field(..., description="最低价")
    close: float = Field(..., description="收盘价")
    volume: float = Field(..., description="成交量")

class ForecastResult(BaseModel):
    ticker: str = Field(..., description="股票代码")
    base_forecast: List[KLinePoint] = Field(default_factory=list, description="Kronos 模型原始预测")
    adjusted_forecast: List[KLinePoint] = Field(default_factory=list, description="LLM 调整后的预测")
    rationale: str = Field(default="", description="预测调整理由及逻辑说明")
    timestamp: str = Field(default_factory=lambda: datetime.now().strftime("%Y-%m-%d %H:%M:%S"), description="生成时间")

class InvestmentReport(BaseModel):
    overall_sentiment: str = Field(..., description="整体市场情绪评价")
    market_entropy: float = Field(..., description="市场分歧度 (0-1, 1代表极高分歧)")
    signals: List[InvestmentSignal] = Field(..., description="深度解析的投资信号列表")
    forecasts: List[ForecastResult] = Field(default_factory=list, description="相关标的的预测结果")
    timestamp: str = Field(..., description="报告生成时间")
    meta_info: Optional[Dict[str, Any]] = Field(default_factory=dict, description="其他元数据")

```
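`ForecastResult` keeps the raw Kronos forecast and the LLM-adjusted forecast side by side, so a caller can measure how far the adjustment moved each bar. The sketch below computes the percentage change in close price per date. It uses a plain dataclass as a stand-in for `KLinePoint` so that it runs without pydantic; `Bar`, `close_adjustments`, and the sample prices are hypothetical and not part of the repository.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Bar:
    """Stand-in for KLinePoint with the same field names."""
    date: str
    open: float
    high: float
    low: float
    close: float
    volume: float


def close_adjustments(base: List[Bar], adjusted: List[Bar]) -> List[Tuple[str, float]]:
    """Return (date, percent change of adjusted close vs. base close).

    Dates present in only one list, or with a zero base close, are skipped.
    """
    base_by_date = {bar.date: bar for bar in base}
    changes = []
    for bar in adjusted:
        ref = base_by_date.get(bar.date)
        if ref is None or ref.close == 0:
            continue
        changes.append((bar.date, (bar.close - ref.close) / ref.close * 100.0))
    return changes


# Made-up numbers purely for illustration.
base_forecast = [
    Bar("2026-03-23", 100.0, 101.0, 99.0, 100.0, 1_000_000),
    Bar("2026-03-24", 100.0, 102.0, 99.5, 101.0, 1_100_000),
]
adjusted_forecast = [
    Bar("2026-03-23", 100.0, 103.0, 99.0, 102.0, 1_000_000),
    Bar("2026-03-25", 101.0, 104.0, 100.0, 103.0, 900_000),  # no base bar
]
deltas = close_adjustments(base_forecast, adjusted_forecast)
```

Matching bars by date rather than by list position avoids misleading comparisons when the two forecasts cover different horizons.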

### scripts/utils/__init__.py

```python
# AlphaEar utils package

```

### scripts/utils/database_manager.py

```python
import sqlite3
import json
from datetime import datetime, date
from pathlib import Path
from typing import List, Dict, Optional, Any, Union
import pandas as pd
from loguru import logger

class DatabaseManager:
    """
    AlphaEar database manager: stores hot-topic data, the search cache, and stock prices.
    Uses SQLite for persistence.
    """
    
    def __init__(self, db_path: str = "data/signal_flux.db"):
        self.db_path = Path(db_path)
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
        self.conn.row_factory = sqlite3.Row
        self._init_db()
        logger.info(f"💾 Database initialized at {self.db_path}")

    def _init_db(self):
        """Initialize the table schema."""
        cursor = self.conn.cursor()
        
        # 1. Daily hot-news table
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS daily_news (
                id TEXT PRIMARY KEY,
                source TEXT,
                rank INTEGER,
                title TEXT,
                url TEXT,
                content TEXT,
                publish_time TEXT,
                crawl_time TEXT,
                sentiment_score REAL,
                analysis TEXT,
                meta_data TEXT
            )
        """)
        
        # Add the analysis column for databases created before it existed
        try:
            cursor.execute("ALTER TABLE daily_news ADD COLUMN analysis TEXT")
        except sqlite3.OperationalError:
            pass  # column already exists

        
        # 2. Search cache table (legacy JSON cache)
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS search_cache (
                query_hash TEXT PRIMARY KEY,
                query TEXT,
                engine TEXT,
                results TEXT,
                timestamp TEXT
            )
        """)

        # 2.5 Search detail table (expanded search results)
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS search_detail (
                id TEXT,
                query_hash TEXT,
                rank INTEGER,
                title TEXT,
                url TEXT,
                content TEXT,
                publish_time TEXT,
                crawl_time TEXT,
                sentiment_score REAL,
                source TEXT,
                meta_data TEXT,
                PRIMARY KEY (query_hash, id)
            )
        """)
        
        # 3. Stock price table
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS stock_prices (
                ticker TEXT,
                date TEXT,
                open REAL,
                close REAL,
                high REAL,
                low REAL,
                volume REAL,
                change_pct REAL,
                PRIMARY KEY (ticker, date)
            )
        """)
        
        # 4. Stock list table (used for lookup)
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS stock_list (
                code TEXT PRIMARY KEY,
                name TEXT
            )
        """)
        
        # 5. Investment signal table (ISQ Framework)
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS signals (
                signal_id TEXT PRIMARY KEY,
                title TEXT,
                summary TEXT,
                transmission_chain TEXT,
                sentiment_score REAL,
                confidence REAL,
                intensity INTEGER,
                expected_horizon TEXT,
                price_in_status TEXT,
                impact_tickers TEXT,
                industry_tags TEXT,
                sources TEXT,
                user_id TEXT,
                created_at TEXT
            )
        """)
        

        
        # 6. Create indexes to speed up queries
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_news_crawl_time ON daily_news(crawl_time)")
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_news_source ON daily_news(source)")
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_search_cache_timestamp ON search_cache(timestamp)")
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_stock_prices_ticker_date ON stock_prices(ticker, date)")
        # Add the user_id column to signals for databases created before it existed
        try:
            cursor.execute("ALTER TABLE signals ADD COLUMN user_id TEXT")
        except sqlite3.OperationalError:
            pass  # column already exists
            
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_signals_user_id ON signals(user_id)")
            
        self.conn.commit()
        


    # --- News data operations ---
    
    def save_daily_news(self, news_list: List[Dict]) -> int:
        """Save hot news items, recording both publish time and crawl time; returns the number saved."""
        cursor = self.conn.cursor()
        count = 0
        crawl_time = datetime.now().isoformat()
        
        for news in news_list:
            try:
                # Sources differ in how they assign IDs: fall back to source_rank_date when none is given
                news_id = news.get('id') or f"{news.get('source')}_{news.get('rank')}_{crawl_time[:10]}"
                cursor.execute("""
                    INSERT OR REPLACE INTO daily_news 
                    (id, source, rank, title, url, content, publish_time, crawl_time, sentiment_score, meta_data)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """, (
                    news_id,
                    news.get('source'),
                    news.get('rank'),
                    news.get('title'),
                    news.get('url'),
                    news.get('content', ''),
                    news.get('publish_time'),  # publish time, when the source provides it
                    crawl_time,
                    news.get('sentiment_score'),
                    json.dumps(news.get('meta_data', {}))
                ))
                count += 1
            except sqlite3.Error as e:
                logger.error(f"Database error saving news item {news.get('title')}: {e}")
            except Exception as e:
                logger.error(f"Unexpected error saving news item {news.get('title')}: {e}")
        
        self.conn.commit()
        return count

    def get_daily_news(self, source: Optional[str] = None, limit: int = 100, days: int = 1) -> List[Dict]:
        """Return hot news from the last N days."""
        cursor = self.conn.cursor()
        # Filter on crawl_time so results stay fresh
        time_threshold = (datetime.now().timestamp() - days * 86400)
        time_threshold_str = datetime.fromtimestamp(time_threshold).isoformat()
        
        query = "SELECT * FROM daily_news WHERE crawl_time >= ?"
        params = [time_threshold_str]
        
        if source:
            query += " AND source = ?"
            params.append(source)
            
        query += " ORDER BY crawl_time DESC, rank LIMIT ?"
        params.append(limit)
        
        cursor.execute(query, params)
        return [dict(row) for row in cursor.fetchall()]

    def lookup_reference_by_url(self, url: str) -> Optional[Dict[str, Any]]:
        """Best-effort lookup of a source item by URL.

        This is used to render a stable bibliography from DB-backed metadata.
        It searches both `daily_news` and `search_detail`.
        """
        url = (url or "").strip()
        if not url:
            return None

        cursor = self.conn.cursor()

        try:
            cursor.execute(
                """
                SELECT title, source, publish_time, crawl_time, url
                FROM daily_news
                WHERE url = ?
                ORDER BY crawl_time DESC
                LIMIT 1
                """,
                (url,),
            )
            row = cursor.fetchone()
            if row:
                return dict(row)
        except Exception:
            pass

        try:
            cursor.execute(
                """
                SELECT title, source, publish_time, crawl_time, url
                FROM search_detail
                WHERE url = ?
                ORDER BY crawl_time DESC
                LIMIT 1
                """,
                (url,),
            )
            row = cursor.fetchone()
            if row:
                return dict(row)
        except Exception:
            pass

        return None

    def delete_news(self, news_id: str) -> bool:
        """Delete one news item by ID."""
        cursor = self.conn.cursor()
        cursor.execute("DELETE FROM daily_news WHERE id = ?", (news_id,))
        self.conn.commit()
        return cursor.rowcount > 0
    
    def update_news_content(self, news_id: str, content: Optional[str] = None, analysis: Optional[str] = None) -> bool:
        """Update a news item's content and/or analysis; returns False if nothing was changed."""
        cursor = self.conn.cursor()
        updates = []
        params = []
        
        if content is not None:
            updates.append("content = ?")
            params.append(content)
        if analysis is not None:
            updates.append("analysis = ?")
            params.append(analysis)
            
        if not updates:
            return False
            
        params.append(news_id)
        query = f"UPDATE daily_news SET {', '.join(updates)} WHERE id = ?"
        cursor.execute(query, params)
        self.conn.commit()
        return cursor.rowcount > 0

    # --- Search cache helpers ---
    
    def get_search_cache(self, query_hash: str, ttl_seconds: Optional[int] = None) -> Optional[Dict]:
        """Return the cached search results (search_detail is checked first)."""
        cursor = self.conn.cursor()
        
        # 1. Try the expanded, structured rows in search_detail first
        cursor.execute("""
            SELECT * FROM search_detail 
            WHERE query_hash = ? 
            ORDER BY rank
        """, (query_hash,))
        details = [dict(row) for row in cursor.fetchall()]
        
        if details:
            # 检查 TTL (取第一条的时间)
            first_time = datetime.fromisoformat(details[0]['crawl_time'])
            if ttl_seconds and (datetime.now() - first_time).total_seconds() > ttl_seconds:
                logger.info(f"⌛ Detailed cache expired for hash {query_hash}")
                # Detail rows take precedence over search_cache, so an expired
                # detail entry means the whole cache entry is treated as expired.
                return None
            
            logger.info(f"✅ Hit detailed search cache for {query_hash} ({len(details)} items)")
            # SearchTools reads the cache as json.loads(cache['results']), so wrap the
            # rows in a dict that mimics the legacy search_cache row shape.
            return {"results": json.dumps(details), "timestamp": details[0]['crawl_time']}

        # 2. Fallback to old table
        cursor.execute("SELECT * FROM search_cache WHERE query_hash = ?", (query_hash,))
        row = cursor.fetchone()
        
        if not row:
            return None
            
        row_dict = dict(row)
        if ttl_seconds:
            cache_time = datetime.fromisoformat(row_dict['timestamp'])
            if (datetime.now() - cache_time).total_seconds() > ttl_seconds:
                logger.info(f"⌛ Cache expired for hash {query_hash}")
                return None
                
        return row_dict

    def save_search_cache(self, query_hash: str, query: str, engine: str, results: Union[str, List[Dict]]):
        """Save search results to both search_cache and search_detail."""
        cursor = self.conn.cursor()
        current_time = datetime.now().isoformat()
        
        results_str = results if isinstance(results, str) else json.dumps(results)
        
        # 1. Save summary to search_cache
        cursor.execute("""
            INSERT OR REPLACE INTO search_cache (query_hash, query, engine, results, timestamp)
            VALUES (?, ?, ?, ?, ?)
        """, (query_hash, query, engine, results_str, current_time))
        
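        # 2. Save details to search_detail if results is a list
        if isinstance(results, list):
            import hashlib  # local import; only needed for the fallback id below

            for item in results:
                try:
                    # Built-in hash() is salted per process for str values, so use a
                    # stable digest of the URL as the fallback id instead.
                    item_id = item.get('id') or hashlib.md5((item.get('url') or '').encode('utf-8')).hexdigest()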
                    cursor.execute("""
                        INSERT OR REPLACE INTO search_detail
                        (id, query_hash, rank, title, url, content, publish_time, crawl_time, sentiment_score, source, meta_data)
                        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                    """, (
                        str(item_id),
                        query_hash,
                        item.get('rank', 0),
                        item.get('title'),
                        item.get('url'),
                        item.get('content', ''),
                        item.get('publish_time'),
                        item.get('crawl_time') or current_time,
                        item.get('sentiment_score'),
                        item.get('source'),
                        json.dumps(item.get('meta_data', {}))
                    ))
                except sqlite3.Error as e:
                    logger.error(f"Database error saving search detail {item.get('title')}: {e}")
                except Exception as e:
                    logger.error(f"Unexpected error saving search detail {item.get('title')}: {e}")
                    
        self.conn.commit()

    def find_similar_queries(self, query: str, limit: int = 5) -> List[Dict]:
        """Fuzzy-search previously cached queries that are similar to the given query."""
        cursor = self.conn.cursor()
        
        # Simple fuzzy match: query in cached OR cached in query
        q_wild = f"%{query}%"
        cursor.execute("""
            SELECT query, query_hash, timestamp, results 
            FROM search_cache 
            WHERE query LIKE ? OR ? LIKE ('%' || query || '%')
            ORDER BY timestamp DESC
            LIMIT ?
        """, (q_wild, query, limit))
        
        return [dict(row) for row in cursor.fetchall()]

    def search_local_news(self, query: str, limit: int = 5) -> List[Dict]:
        """Search the local daily_news table for related news."""
        cursor = self.conn.cursor()
        q_wild = f"%{query}%"
        # Search title and content
        cursor.execute("""
            SELECT * FROM daily_news
            WHERE title LIKE ? OR content LIKE ?
            ORDER BY crawl_time DESC
            LIMIT ?
        """, (q_wild, q_wild, limit))
        return [dict(row) for row in cursor.fetchall()]

    # --- Stock data operations ---

    def save_stock_list(self, df: pd.DataFrame):
        """Save the stock list to the stock_list table."""
        cursor = self.conn.cursor()
        try:
            # Clear the old rows
            cursor.execute("DELETE FROM stock_list")

            # Bulk insert
            data = df[['code', 'name']].to_dict('records')
            cursor.executemany(
                "INSERT INTO stock_list (code, name) VALUES (:code, :name)",
                data
            )
            self.conn.commit()
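        except sqlite3.Error as e:
            # Roll back so the DELETE above is not committed by a later, unrelated commit
            self.conn.rollback()
            logger.error(f"Database error saving stock list: {e}")
        except Exception as e:
            self.conn.rollback()
            logger.error(f"Unexpected error saving stock list: {e}")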

    def search_stock(self, query: str, limit: int = 5) -> List[Dict]:
        """Fuzzy-search stocks by code or name."""
        cursor = self.conn.cursor()
        wild = f"%{query}%"
        cursor.execute("""
            SELECT code, name FROM stock_list 
            WHERE code LIKE ? OR name LIKE ? 
            LIMIT ?
        """, (wild, wild, limit))
        return [dict(row) for row in cursor.fetchall()]

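    def get_stock_by_code(self, code: str) -> Optional[Dict[str, str]]:
        """Look up a stock by its exact code.

        Args:
            code: Stock code (6 digits for A-shares, 5 digits for Hong Kong stocks).
                Expected to be a digit-only string; any non-digit characters are
                stripped before the lookup.

        Returns:
            dict: {"code": str, "name": str}, or None if no match is found.
        """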
        if not code:
            return None
        clean = "".join([c for c in str(code).strip() if c.isdigit()])
        if not clean:
            return None

        cursor = self.conn.cursor()
        cursor.execute("SELECT code, name FROM stock_list WHERE code = ? LIMIT 1", (clean,))
        row = cursor.fetchone()
        return dict(row) if row else None

    def save_stock_prices(self, ticker: str, df: pd.DataFrame):
        """Save historical price data for a ticker."""
        if df.empty:
            return

        cursor = self.conn.cursor()

        # Make sure the DataFrame has the required columns
        required_cols = ['date', 'open', 'close', 'high', 'low', 'volume', 'change_pct']
        for col in required_cols:
            if col not in df.columns:
                logger.warning(f"Missing column {col} in stock data for {ticker}")
                return

        try:
            for _, row in df.iterrows():
                cursor.execute("""
                    INSERT OR REPLACE INTO stock_prices 
                    (ticker, date, open, close, high, low, volume, change_pct)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                """, (
                    ticker,
                    row['date'],
                    row['open'],
                    row['close'],
                    row['high'],
                    row['low'],
                    row['volume'],
                    row['change_pct']
                ))
            self.conn.commit()
        except sqlite3.Error as e:
            logger.error(f"Database error saving stock prices for {ticker}: {e}")
        except Exception as e:
            logger.error(f"Unexpected error saving stock prices for {ticker}: {e}")

    def get_stock_prices(self, ticker: str, start_date: str, end_date: str) -> pd.DataFrame:
        """Get price data for a ticker within a date range (both ends inclusive)."""
        cursor = self.conn.cursor()
        
        cursor.execute("""
            SELECT * FROM stock_prices 
            WHERE ticker = ? AND date >= ? AND date <= ?
            ORDER BY date
        """, (ticker, start_date, end_date))
        
        rows = cursor.fetchall()
        if not rows:
            return pd.DataFrame()
            
        columns = ['ticker', 'date', 'open', 'close', 'high', 'low', 'volume', 'change_pct']
        return pd.DataFrame([dict(row) for row in rows], columns=columns)

    def execute_query(self, query: str, params: tuple = ()) -> List[Any]:
        """Execute a custom SQL query."""
        try:
            cursor = self.conn.cursor()
            cursor.execute(query, params)
            if query.strip().upper().startswith("SELECT"):
                return cursor.fetchall()
            else:
                self.conn.commit()
                return []
        except sqlite3.Error as e:
            logger.error(f"SQL execution failed (Database error): {e}")
            return []
        except Exception as e:
            logger.error(f"SQL execution failed (Unexpected error): {e}")
            return []

    # --- Investment signal operations (ISQ Framework) ---

    def save_signal(self, signal: Dict[str, Any]):
        """Save an investment signal."""
        cursor = self.conn.cursor()
        created_at = datetime.now().isoformat()
        
        cursor.execute("""
            INSERT OR REPLACE INTO signals 
            (signal_id, title, summary, transmission_chain, sentiment_score, 
             confidence, intensity, expected_horizon, price_in_status, 
             impact_tickers, industry_tags, sources, user_id, created_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """, (
            signal.get('signal_id'),
            signal.get('title'),
            signal.get('summary'),
            json.dumps(signal.get('transmission_chain', [])),
            signal.get('sentiment_score', 0.0),
            signal.get('confidence', 0.0),
            signal.get('intensity', 1),
            signal.get('expected_horizon', 'T+0'),
            signal.get('price_in_status', '未知'),
            json.dumps(signal.get('impact_tickers', [])),
            json.dumps(signal.get('industry_tags', [])),
            json.dumps(signal.get('sources', [])),
            signal.get('user_id'),
            created_at
        ))
        self.conn.commit()

    def get_recent_signals(self, limit: int = 20, user_id: Optional[str] = None) -> List[Dict]:
        """Get the most recent investment signals."""
        cursor = self.conn.cursor()
        if user_id:
            cursor.execute("SELECT * FROM signals WHERE user_id = ? ORDER BY created_at DESC LIMIT ?", (user_id, limit))
        else:
            cursor.execute("SELECT * FROM signals ORDER BY created_at DESC LIMIT ?", (limit,))
        rows = cursor.fetchall()
        
        signals = []
        for row in rows:
            d = dict(row)
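            # Parse the JSON-encoded fields
            for field in ['transmission_chain', 'impact_tickers', 'industry_tags', 'sources']:
                if d.get(field):
                    try:
                        d[field] = json.loads(d[field])
                    except (json.JSONDecodeError, TypeError):
                        # Leave the raw value in place if it is not valid JSON
                        pass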
            signals.append(d)
        return signals

    def close(self):
        if self.conn:
            self.conn.close()
            logger.info("Database connection closed.")


```
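
The TTL check used by `get_search_cache` can be tried in isolation. The following is a minimal sketch, assuming an in-memory SQLite database and a simplified `search_cache` table; the helper names `is_fresh` and `get_cached`, the reduced table layout, and the injectable `now` argument are hypothetical and not part of the skill.

```python
import sqlite3
from datetime import datetime, timedelta
from typing import Optional


def is_fresh(timestamp_iso: str, ttl_seconds: Optional[int], now: Optional[datetime] = None) -> bool:
    """Return True while the entry is within its TTL (no TTL means always fresh)."""
    if not ttl_seconds:
        return True
    age = (now or datetime.now()) - datetime.fromisoformat(timestamp_iso)
    return age.total_seconds() <= ttl_seconds


def get_cached(conn: sqlite3.Connection, query_hash: str,
               ttl_seconds: Optional[int] = None,
               now: Optional[datetime] = None) -> Optional[dict]:
    """Look up one cache row and return None when it is missing or expired."""
    row = conn.execute(
        "SELECT * FROM search_cache WHERE query_hash = ?", (query_hash,)
    ).fetchone()
    if row is None or not is_fresh(row["timestamp"], ttl_seconds, now):
        return None
    return dict(row)


# Hypothetical, simplified schema used only for this sketch.
conn = sqlite3.connect(":memory:")
conn.row_factory = sqlite3.Row  # rows can then be converted with dict(row)
conn.execute("CREATE TABLE search_cache (query_hash TEXT PRIMARY KEY, results TEXT, timestamp TEXT)")
saved_at = datetime(2026, 1, 1, 12, 0, 0)
conn.execute("INSERT INTO search_cache VALUES (?, ?, ?)", ("abc", "[]", saved_at.isoformat()))

print(get_cached(conn, "abc", ttl_seconds=3600, now=saved_at + timedelta(minutes=30)))
print(get_cached(conn, "abc", ttl_seconds=3600, now=saved_at + timedelta(hours=2)))
```

Passing `now` explicitly keeps the expiry logic deterministic in tests; the method in the file above calls `datetime.now()` directly.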

### scripts/utils/json_utils.py

```python
import ast
import json
import re
from typing import Optional, Any
from loguru import logger

def _strip_comments(text: str) -> str:
    """
    Safely remove C-style comments (// and /* */) from JSON-like text,
    preserving strings (including URLs like http://).
    """
    result = []
    i = 0
    n = len(text)
    in_string = False
    escape = False
    
    while i < n:
        char = text[i]
        
        if in_string:
            if char == '\\':
                escape = not escape
            elif char == '"' and not escape:
                in_string = False
            else:
                escape = False
            result.append(char)
            i += 1
            continue
            
        # Not in string
        if char == '"':
            in_string = True
            result.append(char)
            i += 1
            continue
            
        # Check for // comment
        if i + 1 < n and text[i:i+2] == '//':
            i += 2
            while i < n and text[i] != '\n':
                i += 1
            continue
            
        # Check for /* comment
        if i + 1 < n and text[i:i+2] == '/*':
            i += 2
            while i + 1 < n and text[i:i+2] != '*/':
                i += 1
            i += 2
            continue
            
        result.append(char)
        i += 1
        
    return ''.join(result)

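def extract_json(text: str) -> Optional[Any]:
    """
    A more robust JSON extraction helper.
    Handles:
    1. Markdown code fences (```json ... ```)
    2. Extra characters before or after the JSON
    3. Multiple JSON objects in the same text (only the first is extracted)
    4. Simple JSON repairs (trailing commas, etc.)
    5. C-style comments (// and /* */)
    """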
    if not text:
        return None
    
    # 1. Strip obvious Markdown wrappers
    text = text.strip()

    # First try an exact match for ```json ... ``` or ```...```
    md_match = re.search(r'```(?:json)?\s*\n?(.*?)\n?```', text, re.DOTALL)
    if md_match:
        text = md_match.group(1).strip()
    elif text.startswith("```"):
        # Fallback: the text starts with ``` but no complete fence matched
        text = re.sub(r'^```[a-z]*\n?', '', text)
        text = re.sub(r'\n?```\s*$', '', text)

    # 2. Find the first JSON start character, { or [
    start_brace = text.find('{')
    start_bracket = text.find('[')
    
    if start_brace == -1 and start_bracket == -1:
        return None
        
    start_idx = start_brace if (start_bracket == -1 or (start_brace != -1 and start_brace < start_bracket)) else start_bracket
    
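    # 2.5 Pre-processing: repair some extremely common LLM mistakes
    potential_json = text[start_idx:].strip()

    # a. Remove comments safely
    potential_json = _strip_comments(potential_json)

    # 3. Parse with raw_decode, which ignores any text after the first JSON value
    decoder = json.JSONDecoder()

    # Try the comment-free text before any key repair: the regex repairs below
    # also match inside string values (e.g. "a, b: c") and would corrupt valid JSON.
    try:
        obj, end_pos = decoder.raw_decode(potential_json)
        return obj
    except json.JSONDecodeError:
        pass

    # b. Repair keys missing the opening quote:  nodes": [  -> "nodes": [
    # Pattern: after { or , (plus whitespace), a word immediately followed by a quote and a colon
    potential_json = re.sub(r'([\{\,]\s*)([a-zA-Z_]\w*)\"\s*:', r'\1"\2":', potential_json)

    # c. Repair keys missing the closing quote:  "nodes: [ -> "nodes": [
    potential_json = re.sub(r'([\{\,]\s*)\"([a-zA-Z_]\w*)\s*:', r'\1"\2":', potential_json)

    # d. Repair keys with no quotes at all: nodes: [ -> "nodes": [
    # Restricted to positions right after { or , to avoid matching things like http://
    potential_json = re.sub(r'([\{\,]\s*)([a-zA-Z_]\w*)\s*:', r'\1"\2":', potential_json)

    # Retry on the repaired text
    try:
        obj = json.loads(potential_json)
        return obj
    except json.JSONDecodeError:
        pass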
    
    # Simple pre-processing: drop trailing commas before a closing ] or }
    processed_json = re.sub(r',\s*([\]}])', r'\1', potential_json)
    
    try:
        obj, end_pos = decoder.raw_decode(processed_json)
        return obj
    except json.JSONDecodeError:
        pass
    
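    # e. Repair unterminated string literals: remove raw newlines inside values.
    # LLMs may emit real newlines inside string values, which makes the JSON invalid.
    def fix_multiline_strings(s):
        # Simple strategy: replace newlines inside string values with spaces
        lines = s.split('\n')
        result = []
        in_string = False
        for line in lines:
            # Count the unescaped quotes on this line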
            quote_count = line.count('"') - line.count('\\"')
            if in_string:
                result[-1] += ' ' + line.strip()
            else:
                result.append(line)
            
            if quote_count % 2 == 1:
                in_string = not in_string
        return '\n'.join(result)
    
    fixed_json = fix_multiline_strings(processed_json)
    
    try:
        obj, end_pos = decoder.raw_decode(fixed_json)
        return obj
    except json.JSONDecodeError:
        try:
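            # 4. Try to handle single quotes (JSON requires double quotes, but LLMs often emit single quotes).
            # This is a simple substitution that only targets structures like {'key': 'value'}.
            # Note: it can corrupt string values that contain single quotes, so it is a late fallback.
            fix_quotes = re.sub(r"'(.*?)':", r'"\1":', processed_json)  # repair keys
            fix_quotes = re.sub(r":\s*'(.*?)'", r': "\1"', fix_quotes)  # repair simple values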
            obj, end_pos = decoder.raw_decode(fix_quotes)
            return obj
        except (json.JSONDecodeError, TypeError):
            try:
                # 5. Last resort: ast.literal_eval (handles Python dict/list literals).
                # Scan for the first balanced {...} or [...] and evaluate only that prefix.
                stack = []
                closers = {'}': '{', ']': '['}
                for i, char in enumerate(potential_json):
                    if char in '{[':
                        stack.append(char)
                    elif char in closers:
                        if stack and stack[-1] == closers[char]:
                            stack.pop()
                        if not stack:
                            content = potential_json[:i+1]
                            return ast.literal_eval(content)
            except (ValueError, SyntaxError, MemoryError) as e:
                logger.warning(f"All JSON extraction attempts failed: {e}")
            except Exception as e:
                logger.error(f"Unexpected error during JSON extraction: {e}")
    
    return None

```
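
The core trick in `extract_json` is `json.JSONDecoder.raw_decode`, which parses one JSON value and ignores whatever follows it. The following is a minimal sketch of that step alone, without the repair passes; the helper name `first_json_value` is hypothetical.

```python
import json
from typing import Any, Optional


def first_json_value(text: str) -> Optional[Any]:
    """Return the first parseable JSON object or array embedded in text, else None."""
    decoder = json.JSONDecoder()
    for start, ch in enumerate(text):
        if ch not in "{[":
            continue
        try:
            # raw_decode parses one value starting at `start` and tolerates trailing text
            obj, _end = decoder.raw_decode(text, start)
            return obj
        except json.JSONDecodeError:
            # Not valid JSON at this position; try the next candidate
            continue
    return None


print(first_json_value('Sure! {"a": 1} and also {"b": 2}'))
print(first_json_value("no json here"))
```

Unlike `extract_json`, this sketch does not repair malformed input; it only skips candidates that fail to parse.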

### scripts/utils/llm/capability.py

```python
import os
from typing import Optional, List, Dict, Any
from agno.agent import Agent
from agno.models.base import Model
from loguru import logger
from ..llm.factory import get_model

def test_tool_call_support(model: Model) -> bool:
    """
    Test whether the model supports native tool calls (function calling).
    Verified by asking the model to invoke a simple weather lookup tool.
    """
    def get_current_weather(location: str):
        """获取指定地点的天气"""
        return f"{location} 的天气是晴天,25度。"

    test_agent = Agent(
        model=model,
        tools=[get_current_weather],
        instructions="请调用工具查询北京的天气,并直接返回工具的输出结果。"
    )

    try:
        # Run a simple task and check whether it triggers a tool call
        response = test_agent.run("北京天气怎么样?")
        
        # Check whether the response contains tool_calls.
        # Agno's RunResponse object usually carries messages, so scan them for tool calls.
        has_tool_call = False
        for msg in response.messages or []:
            if hasattr(msg, 'tool_calls') and msg.tool_calls:
                has_tool_call = True
                break
        
        if has_tool_call:
            logger.info(f"✅ Model {model.id} supports native tool calling.")
            return True
        else:
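            # Without tool_calls, a correct answer may mean the model emulated the tool call in plain text (ReAct)
            # or did not use the tool at all. Native support requires the tool_calls structure to be present.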
            logger.warning(f"⚠️ Model {model.id} did NOT use native tool calling structure.")
            return False
            
    except Exception as e:
        logger.error(f"❌ Error testing tool call for {model.id}: {e}")
        return False

class ModelCapabilityRegistry:
    """
    Model capability registry that caches and manages capability test results for different models.
    """
    _cache = {}

    @classmethod
    def get_capabilities(cls, provider: str, model_id: str, **kwargs) -> Dict[str, bool]:
        key = f"{provider}:{model_id}"
        if key not in cls._cache:
            logger.info(f"🔍 Testing capabilities for {key}...")
            model = get_model(provider, model_id, **kwargs)
            supports_tool_call = test_tool_call_support(model)
            cls._cache[key] = {
                "supports_tool_call": supports_tool_call
            }
        return cls._cache[key]

if __name__ == "__main__":
    # Simple test script
    from dotenv import load_dotenv
    load_dotenv()

    # Test the currently configured model
    p = os.getenv("LLM_PROVIDER", "ust")
    m = os.getenv("LLM_MODEL", "Qwen")
    
    print(f"Testing {p}/{m}...")
    res = ModelCapabilityRegistry.get_capabilities(p, m)
    print(f"Result: {res}")

```
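
`ModelCapabilityRegistry` runs the probe once per `provider:model_id` key and then serves the cached result. The following is a minimal sketch of that caching pattern with the real probe replaced by a stub, so it runs without Agno or network access; `CapabilityCache` and `fake_probe` are hypothetical names.

```python
from typing import Callable, Dict, List, Tuple


class CapabilityCache:
    """Cache probe results per 'provider:model_id' key so each model is probed once."""

    def __init__(self, probe: Callable[[str, str], bool]):
        self._probe = probe
        self._cache: Dict[str, Dict[str, bool]] = {}

    def get(self, provider: str, model_id: str) -> Dict[str, bool]:
        key = f"{provider}:{model_id}"
        if key not in self._cache:
            # Only the first lookup for a key pays the cost of the probe
            self._cache[key] = {"supports_tool_call": self._probe(provider, model_id)}
        return self._cache[key]


calls: List[Tuple[str, str]] = []


def fake_probe(provider: str, model_id: str) -> bool:
    """Stand-in for a real tool-call test; records every invocation."""
    calls.append((provider, model_id))
    return provider == "openai"


cache = CapabilityCache(fake_probe)
first = cache.get("openai", "gpt-4o")
second = cache.get("openai", "gpt-4o")  # served from the cache, no second probe
print(first, len(calls))
```

Because the cache key ignores extra keyword arguments, two configurations of the same model share one result; the registry in the file above behaves the same way.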

### scripts/utils/llm/factory.py

```python
import os
from agno.models.openai import OpenAIChat
from agno.models.ollama import Ollama
from agno.models.dashscope import DashScope
from agno.models.deepseek import DeepSeek
from agno.models.openrouter import OpenRouter

def get_model(model_provider: str, model_id: str, **kwargs):
    """
    Factory to get the appropriate LLM model.
    
    Args:
        model_provider: "openai", "ollama", "deepseek", "dashscope", "openrouter", "zai", or "ust"
        model_id: The specific model ID (e.g., "gpt-4o", "llama3", "deepseek-chat")
        **kwargs: Additional arguments for the model constructor
    """
    if model_provider == "openai":
        return OpenAIChat(id=model_id, **kwargs)
    
    elif model_provider == "ollama":
        return Ollama(id=model_id, **kwargs)
    
    elif model_provider == "deepseek":
        # DeepSeek is OpenAI compatible
        api_key = os.getenv("DEEPSEEK_API_KEY")
        if not api_key:
            print("Warning: DEEPSEEK_API_KEY not set.")
        
        return DeepSeek(
            id=model_id,
            api_key=api_key,
            **kwargs
        )
    elif model_provider == "dashscope":
        api_key = os.getenv("DASHSCOPE_API_KEY")
        if not api_key:
            print("Warning: DASHSCOPE_API_KEY not set.")
        
        return DashScope(
            id=model_id,
            base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
            api_key=api_key,
            **kwargs
        )
    elif model_provider == 'openrouter':
        api_key = os.getenv("OPENROUTER_API_KEY")
        if not api_key:
            print('Warning: OPENROUTER_API_KEY not set.')
        
        return OpenRouter(
            id=model_id,
            api_key=api_key,
            **kwargs
        )

    elif model_provider == 'zai':
        api_key = os.getenv("ZAI_KEY_API")
        if not api_key:
            print('Warning: ZAI_KEY_API not set.')

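        # Use the standard OpenAI role names (e.g. "system", "user", "assistant")
        # rather than Agno's default mapping, which maps "system" -> "developer".
        # Provide an explicit role_map to ensure compatibility.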
        default_role_map = {
            "system": "system",
            "user": "user",
            "assistant": "assistant",
            "tool": "tool",
            "model": "assistant",
        }

        # Allow callers to override role_map via kwargs, otherwise use default
        role_map = kwargs.pop("role_map", default_role_map)
        
        return OpenAIChat(
            id=model_id,
            base_url="https://api.z.ai/api/paas/v4",
            api_key=api_key,
            timeout=60,
            role_map=role_map,
            extra_body={"enable_thinking": False}, # TODO: one more setting for thinking
            **kwargs
        )
    
    elif model_provider == 'ust':
        api_key = os.getenv("UST_KEY_API")
        if not api_key:
            print('Warning: UST_KEY_API not set.')
        
        # Some UST-compatible endpoints expect the standard OpenAI role names
        # (e.g. "system", "user", "assistant") rather than Agno's default
        # mapping which maps "system" -> "developer". Provide an explicit
        # role_map to ensure compatibility.
        default_role_map = {
            "system": "system",
            "user": "user",
            "assistant": "assistant",
            "tool": "tool",
            "model": "assistant",
        }

        # Allow callers to override role_map via kwargs, otherwise use default
        role_map = kwargs.pop("role_map", default_role_map)

        return OpenAIChat(
            id=model_id,
            api_key=api_key,
            base_url=os.getenv("UST_URL"),
            role_map=role_map,
            extra_body={"enable_thinking": False}, # TODO: one more setting for thinking
            **kwargs
        )
    
    else:
        raise ValueError(f"Unknown model provider: {model_provider}")


```
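
The `zai` and `ust` branches use `kwargs.pop("role_map", default_role_map)` so that a caller-supplied `role_map` replaces the default and is not passed to the constructor twice through `**kwargs`. The following is a minimal sketch of that pattern with the model constructor replaced by a plain dict; `build_client_kwargs` is a hypothetical helper.

```python
from typing import Any, Dict

DEFAULT_ROLE_MAP = {
    "system": "system",
    "user": "user",
    "assistant": "assistant",
    "tool": "tool",
    "model": "assistant",
}


def build_client_kwargs(model_id: str, **kwargs: Any) -> Dict[str, Any]:
    """Assemble constructor arguments, letting callers override role_map."""
    # pop() removes role_map from kwargs, so the **kwargs expansion below cannot
    # supply the same keyword a second time (which would raise a TypeError)
    role_map = kwargs.pop("role_map", DEFAULT_ROLE_MAP)
    return dict(id=model_id, role_map=role_map, **kwargs)


custom = {"system": "developer"}
print(build_client_kwargs("some-model")["role_map"]["model"])
print(build_client_kwargs("some-model", role_map=custom, timeout=30))
```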

### scripts/utils/llm/router.py

```python
import os
from typing import Optional, List, Dict, Any, Union
from agno.models.base import Model
from loguru import logger
from dotenv import load_dotenv
from ..llm.factory import get_model
from ..llm.capability import ModelCapabilityRegistry

# Make sure environment variables are loaded before initialization
load_dotenv()

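class ModelRouter:
    """
    Model routing manager.

    Responsibilities:
    1. Manage the reasoning/writing model (Reasoning Model) and the tool-calling model (Tool Model).
    2. Pick the appropriate model automatically based on the task.
    """

    def __init__(self):
        # Defaults are read from environment variables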
        self.reasoning_provider = os.getenv("REASONING_MODEL_PROVIDER", os.getenv("LLM_PROVIDER", "openai"))
        self.reasoning_id = os.getenv("REASONING_MODEL_ID", os.getenv("LLM_MODEL", "gpt-4o"))
        self.reasoning_host = os.getenv("REASONING_MODEL_HOST", os.getenv("LLM_HOST"))
        
        self.tool_provider = os.getenv("TOOL_MODEL_PROVIDER", self.reasoning_provider)
        self.tool_id = os.getenv("TOOL_MODEL_ID", self.reasoning_id)
        self.tool_host = os.getenv("TOOL_MODEL_HOST", self.reasoning_host)
        
        self._reasoning_model = None
        self._tool_model = None
        
        logger.info(f"🤖 ModelRouter initialized: Reasoning={self.reasoning_id} ({self.reasoning_host or 'default'}), Tool={self.tool_id} ({self.tool_host or 'default'})")

    def get_reasoning_model(self, **kwargs) -> Model:
        if not self._reasoning_model:
            # Prefer the host configured on the router
            if self.reasoning_host and "host" not in kwargs:
                kwargs["host"] = self.reasoning_host
            self._reasoning_model = get_model(self.reasoning_provider, self.reasoning_id, **kwargs)
        return self._reasoning_model

    def get_tool_model(self, **kwargs) -> Model:
        if not self._tool_model:
            # Prefer the host configured on the router
            if self.tool_host and "host" not in kwargs:
                kwargs["host"] = self.tool_host

            # Check whether the tool model really supports tool calls
            caps = ModelCapabilityRegistry.get_capabilities(self.tool_provider, self.tool_id, **kwargs)
            if not caps["supports_tool_call"]:
                logger.warning(f"⚠️ Configured tool model {self.tool_id} might not support native tool calls! Consider using ReAct mode or a different model.")
            
            self._tool_model = get_model(self.tool_provider, self.tool_id, **kwargs)
        return self._tool_model

    def get_model_for_agent(self, has_tools: bool = False, **kwargs) -> Model:
        """
        Return the appropriate model depending on whether the agent uses tools.
        """
        if has_tools:
            return self.get_tool_model(**kwargs)
        return self.get_reasoning_model(**kwargs)

# Global singleton
router = ModelRouter()

```
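
`ModelRouter.__init__` resolves its settings through a fallback chain: the role-specific variable first, then the generic `LLM_*` variable, then a hard-coded default, with the tool model falling back to the reasoning model. The following is a minimal sketch of that resolution over a plain mapping instead of `os.environ`; `resolve_router_config` is a hypothetical helper.

```python
from typing import Dict, Mapping, Tuple


def resolve_router_config(env: Mapping[str, str]) -> Dict[str, Tuple[str, str]]:
    """Resolve (provider, model_id) pairs using the same fallback order as ModelRouter."""
    reasoning_provider = env.get("REASONING_MODEL_PROVIDER", env.get("LLM_PROVIDER", "openai"))
    reasoning_id = env.get("REASONING_MODEL_ID", env.get("LLM_MODEL", "gpt-4o"))
    # The tool model defaults to whatever the reasoning model resolved to
    tool_provider = env.get("TOOL_MODEL_PROVIDER", reasoning_provider)
    tool_id = env.get("TOOL_MODEL_ID", reasoning_id)
    return {
        "reasoning": (reasoning_provider, reasoning_id),
        "tool": (tool_provider, tool_id),
    }


print(resolve_router_config({}))
print(resolve_router_config({
    "LLM_PROVIDER": "deepseek",
    "LLM_MODEL": "deepseek-chat",
    "TOOL_MODEL_ID": "gpt-4o-mini",
}))
```

In the second call only the tool model id is overridden, so the tool provider is still inherited from `LLM_PROVIDER`; an explicit `TOOL_MODEL_PROVIDER` would be needed to route tool calls to a different provider.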

### scripts/utils/logging_setup.py

```python
import os
import sys
from datetime import datetime
from typing import Optional

from loguru import logger


def setup_file_logging(
    run_id: str,
    log_dir: str = "logs",
    level: str = "INFO",
    retention: str = "10 days",
    rotation: str = "20 MB",
) -> str:
    """Configure Loguru to log to stderr + a per-run file.

    Returns the log file path.
    """
    os.makedirs(log_dir, exist_ok=True)

    # Remove default handler to avoid duplicate logs.
    logger.remove()

    # Console
    logger.add(sys.stderr, level=level, backtrace=False, diagnose=False)

    # File (safe for multi-thread via enqueue)
    log_path = os.path.join(log_dir, f"signalflux_{run_id}.log")
    logger.add(
        log_path,
        level=level,
        rotation=rotation,
        retention=retention,
        enqueue=True,
        backtrace=True,
        diagnose=False,
        encoding="utf-8",
    )
    return log_path


def make_run_id(prefix: Optional[str] = None) -> str:
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    return f"{prefix}_{ts}" if prefix else ts

```

### scripts/utils/predictor/evaluation.py

```python
import os
import sys
import torch
import pandas as pd
import numpy as np
import glob
from loguru import logger
from datetime import datetime, timedelta

# Setup paths
KRONOS_DIR = os.path.dirname(os.path.abspath(__file__))
SRC_DIR = os.path.dirname(os.path.dirname(KRONOS_DIR))
if SRC_DIR not in sys.path:
    sys.path.insert(0, SRC_DIR)

from ..kronos.auto_synthesis_training import AutoSynthesisTrainer
from ..kronos.model import KronosPredictor
from ..visualizer import VisualizerTools
from ..schema.models import ForecastResult, KLinePoint

class NewsModelEvaluator:
    def __init__(self, model_path=None):
        self.trainer = AutoSynthesisTrainer()
        self.device = self.trainer.device
        
        if model_path is None:
            # Try to find the latest model in exports/models
            model_files = glob.glob(os.path.join(SRC_DIR, "exports/models/*.pt"))
            if not model_files:
                logger.warning("⚠️ No trained models found in exports/models/. Using base model (zero-init proj).")
            else:
                model_path = max(model_files, key=os.path.getctime)
        
        if model_path:
            self.load_weights(model_path)

    def load_weights(self, path):
        logger.info(f"🔄 Loading model weights from {path}...")
        checkpoint = torch.load(path, map_location=self.device)
        self.trainer.model.news_proj.load_state_dict(checkpoint['news_proj_state_dict'])
        logger.success("✅ News projection layer loaded.")

    def evaluate_range(self, start_idx=100, end_idx=200, pred_len=5):
        # 1. Fetch Tickers
        res = self.trainer.db.execute_query("SELECT code FROM stock_list")
        all_tickers = [row['code'] for row in res]
        test_tickers = all_tickers[start_idx:end_idx]
        
        if not test_tickers:
            logger.error(f"No tickers found in range {start_idx}-{end_idx}")
            return

        logger.info(f"🚀 Evaluating News Model on stocks {start_idx} to {end_idx}...")
        
        # 2. Discover Shocks
        shocks = self.trainer.discover_shocks(test_tickers, pred_len=pred_len)
        
        # 3. Associate News & Predict
        self.trainer.model.eval()
        predictor = KronosPredictor(self.trainer.model, self.trainer.tokenizer, device=self.device)
        
        save_dir = os.path.join(SRC_DIR, "exports/evaluation_results")
        os.makedirs(save_dir, exist_ok=True)

        count = 0
        for shock in shocks:
            summary = self.trainer.find_reason_and_verify(shock)
            if not summary:
                continue
            
            logger.info(f"📈 Testing shock: {shock['ticker']} on {shock['date']}")
            
            # Embedding news
            news_emb = self.trainer.embedder.encode(summary)
            
            # Prediction
            h = shock['history']
            t = shock['target']
            actuals = t['close'].values[:pred_len]
            
            x_ts = pd.to_datetime(h['date'])
            future_dates = pd.date_range(start=x_ts.iloc[-1] + timedelta(days=1), periods=pred_len, freq='B')
            y_ts = pd.Series(future_dates)
            
            # A. Base Prediction (No news)
            p_base = predictor.predict(h, x_ts, y_ts, pred_len=pred_len, news_emb=None, verbose=False)
            
            # B. News-Aware Prediction
            p_news = predictor.predict(h, x_ts, y_ts, pred_len=pred_len, news_emb=news_emb, verbose=False)
            
            # Calculate Improvement
            b_preds = p_base['close'].values[:len(actuals)]
            n_preds = p_news['close'].values[:len(actuals)]
            b_mae = np.mean(np.abs(b_preds - actuals))
            n_mae = np.mean(np.abs(n_preds - actuals))
            improvement = (b_mae - n_mae) / (b_mae + 1e-6) * 100

            # C. Visualize
            try:
                def to_kp_list(preds_df):
                    points = []
                    for idx, row in preds_df.iterrows():
                        points.append(KLinePoint(
                            date=str(idx)[:10], open=row['open'], high=row['high'],
                            low=row['low'], close=row['close'], volume=row.get('volume', 0)
                        ))
                    return points

                forecast_obj = ForecastResult(
                    ticker=shock['ticker'],
                    base_forecast=to_kp_list(p_base),
                    adjusted_forecast=to_kp_list(p_news),
                    rationale=summary
                )

                chart = VisualizerTools.generate_stock_chart(
                    df=h, ticker=shock['ticker'],
                    title=f"Test Eval: {shock['ticker']} ({shock['date']}) Imp: {improvement:.1f}%",
                    forecast=forecast_obj,
                    ground_truth=t[['date', 'open', 'high', 'low', 'close', 'volume']]
                )
                
                safe_date = shock['date'].replace("-", "")
                filename = f"test_{shock['ticker']}_{safe_date}.html"
                VisualizerTools.render_chart_to_file(chart, os.path.join(save_dir, filename))
                
                logger.success(f"📊 Result for {shock['ticker']} saved. Base MAE: {b_mae:.4f}, News MAE: {n_mae:.4f}")
                count += 1
            except Exception as e:
                logger.error(f"Visualization failed: {e}")

        logger.info(f"🏁 Finished evaluation. {count} cases visualized in {save_dir}")

if __name__ == "__main__":
    # If you have a specific model, pass the path here. Otherwise it picks the latest.
    evaluator = NewsModelEvaluator()
    evaluator.evaluate_range(start_idx=100, end_idx=200, pred_len=1)

```
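
The evaluation loop above scores each case by the relative drop in mean absolute error (MAE) between the baseline forecast and the news-aware forecast. A minimal standalone sketch of that calculation follows. The helper name `mae_improvement` and the sample prices are illustrative assumptions, not part of the repository, and the sketch assumes each forecast is at least as long as the list of observed closes.

```python
def mae_improvement(base_preds, news_preds, actuals, eps=1e-6):
    """Return (base_mae, news_mae, improvement_pct).

    Hypothetical helper mirroring the formula used in the evaluation loop:
    improvement = (base_mae - news_mae) / (base_mae + eps) * 100
    """
    n = len(actuals)
    if n == 0:
        raise ValueError("actuals must not be empty")
    # Truncate both forecasts to the number of observed closes
    base = base_preds[:n]
    news = news_preds[:n]
    base_mae = sum(abs(p - a) for p, a in zip(base, actuals)) / n
    news_mae = sum(abs(p - a) for p, a in zip(news, actuals)) / n
    improvement = (base_mae - news_mae) / (base_mae + eps) * 100
    return base_mae, news_mae, improvement


# Made-up closing prices, for illustration only
base_mae, news_mae, improvement = mae_improvement(
    base_preds=[10.0, 11.0], news_preds=[10.5, 10.5], actuals=[10.4, 10.6]
)
print(f"Base MAE: {base_mae:.4f}, News MAE: {news_mae:.4f}, Improvement: {improvement:+.1f}%")
```

A positive value means the news-aware forecast was closer to the realized closes. The `eps` term only guards against division by zero when the baseline error is exactly zero.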

### scripts/utils/predictor/kline_generate.py

```python
# Ref: https://github.com/shiyu-coder/Kronos

from model import Kronos, KronosTokenizer, KronosPredictor
import pandas as pd
import sqlite3
import torch
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from pandas.tseries.offsets import BusinessDay
import numpy as np

def get_device():
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    print(f"Using device: {device}")
    return device

def load_predictor():
    tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
    model = Kronos.from_pretrained("NeoQuasar/Kronos-base")
    device = get_device()
    tokenizer = tokenizer.to(device)
    model = model.to(device)
    return KronosPredictor(model, tokenizer, device=device, max_context=512)

def load_data(ticker="002111", db_path="AlphaEar/data/signal_flux.db"):
    with sqlite3.connect(db_path) as conn:
        # Bind the ticker as a query parameter instead of interpolating it into the SQL string
        df = pd.read_sql_query("SELECT * FROM stock_prices WHERE ticker = ?", conn, params=(ticker,))
    df['date'] = pd.to_datetime(df['date'])
    df = df.sort_values('date').reset_index(drop=True)
    return df

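def plot_kline_matplotlib(ax, ax_vol, dates, df, label_suffix="", color_up='#ef4444', color_down='#22c55e', alpha=1.0, is_prediction=False):
    """
    Draw candlesticks (K-lines) on `ax` and volume bars on `ax_vol`.
    Up candles default to red and down candles to green.
    """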
    # X axis mapping to integers for consistent spacing
    x = np.arange(len(dates))
    
    # K-line data
    opens = df['open'].values
    closes = df['close'].values
    highs = df['high'].values
    lows = df['low'].values
    volumes = df['volume'].values
    
    # Width of the candlestick
    width = 0.6
    
    for i in range(len(x)):
        color = color_up if closes[i] >= opens[i] else color_down
        linestyle = '--' if is_prediction else '-'
        
        # Wick
        ax.vlines(x[i], lows[i], highs[i], color=color, linewidth=1, alpha=alpha, linestyle=linestyle)
        
        # Body
        rect_bottom = min(opens[i], closes[i])
        rect_height = abs(opens[i] - closes[i])
        if rect_height == 0: rect_height = 0.001 # Visual hair
        
        ax.add_patch(plt.Rectangle((x[i] - width/2, rect_bottom), width, rect_height, 
                                 edgecolor=color, facecolor=color if not is_prediction else 'none', 
                                 alpha=alpha, linewidth=1, linestyle=linestyle))
        
        # Volume
        ax_vol.bar(x[i], volumes[i], color=color, alpha=alpha * 0.5, width=width)

def render_comparison_chart(history_df, actual_df, pred_df, title):
    """
    Render a combined chart: historical K-lines + ground-truth K-lines + predicted K-lines.
    """
    # Combine all dates for X axis
    all_dates = pd.concat([history_df['date'], actual_df['date'] if actual_df is not None else pred_df.index.to_series()]).unique()
    all_dates = sorted(all_dates)
    date_to_idx = {date: i for i, date in enumerate(all_dates)}
    
    fig = plt.figure(figsize=(14, 8), facecolor='white')
    gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1], hspace=0.1)
    ax_main = fig.add_subplot(gs[0])
    ax_vol = fig.add_subplot(gs[1], sharex=ax_main)
    
    # 1. Plot History
    hist_indices = [date_to_idx[d] for d in history_df['date']]
    # We use a custom x for plotting to ensure continuity
    plot_kline_matplotlib(ax_main, ax_vol, history_df['date'], history_df, alpha=0.8)
    
    offset = len(history_df)
    
    # 2. Plot Actual if exists
    if actual_df is not None:
        # Shift indices
        actual_x = np.arange(len(actual_df)) + offset
        # Plotting manually to handle offset
        for i in range(len(actual_df)):
            idx = actual_x[i]
            row = actual_df.iloc[i]
            color = '#ef4444' if row['close'] >= row['open'] else '#22c55e'
            ax_main.vlines(idx, row['low'], row['high'], color=color, linewidth=1, alpha=0.9)
            ax_main.add_patch(plt.Rectangle((idx - 0.3, min(row['open'], row['close'])), 0.6, abs(row['open']-row['close']), 
                                         edgecolor=color, facecolor=color, alpha=0.9))
            ax_vol.bar(idx, row['volume'], color=color, alpha=0.4)
            
    # 3. Plot Prediction
    pred_x = np.arange(len(pred_df)) + offset
    for i in range(len(pred_df)):
        idx = pred_x[i]
        row = pred_df.iloc[i]
        color = '#ff8c00' # Orange for prediction to distinguish
        ax_main.vlines(idx, row['low'], row['high'], color=color, linewidth=1.5, linestyle='--')
        ax_main.add_patch(plt.Rectangle((idx - 0.3, min(row['open'], row['close'])), 0.6, abs(row['open']-row['close']), 
                                     edgecolor=color, facecolor='none', linewidth=1.5, linestyle='--'))
        # Plot secondary prediction line for close
        if i == 0:
            # Connect to history
            ax_main.plot([offset-1, idx], [history_df['close'].iloc[-1], row['close']], color=color, linestyle='--', alpha=0.6)
        elif i > 0:
            ax_main.plot([idx-1, idx], [pred_df['close'].iloc[i-1], row['close']], color=color, linestyle='--', alpha=0.6)

    # Styling
    ax_main.set_title(title, fontsize=14, fontweight='bold')
    ax_main.grid(True, linestyle=':', alpha=0.6)
    ax_vol.grid(True, linestyle=':', alpha=0.6)
    ax_vol.set_ylabel('Volume')
    ax_main.set_ylabel('Price')
    
    # Set X ticks
    step = max(1, len(all_dates) // 10)
    ax_vol.set_xticks(np.arange(0, len(all_dates), step))
    ax_vol.set_xticklabels([all_dates[i].strftime('%Y-%m-%d') for i in range(0, len(all_dates), step)], rotation=45)
    
    plt.tight_layout()
    plt.show()
    plt.close()

def run_backtest(df, predictor, lookback, pred_len, start_index=0):
    total_len = len(df)
    history_start = start_index
    history_end = start_index + lookback 
    pred_start = history_end
    
    available_pred_len = total_len - pred_start
    if available_pred_len <= 0: return
    actual_pred_len = min(pred_len, available_pred_len)
    pred_end = pred_start + actual_pred_len
    
    x_df = df.iloc[history_start : history_end].copy()
    y_true_df = df.iloc[pred_start : pred_end].copy()
    y_timestamp = y_true_df['date']
    
    print(f"Backtesting: {x_df['date'].iloc[0].date()} to {y_timestamp.iloc[-1].date()}")
    
    pred_df = predictor.predict(
        df=x_df[['open', 'high', 'low', 'close', 'volume']],
        x_timestamp=x_df['date'],
        y_timestamp=y_timestamp,
        pred_len=actual_pred_len,
        T=1.0, top_p=0.9, sample_count=1
    )
    
    render_comparison_chart(x_df, y_true_df, pred_df, f"Backtest: {TICKER} K-Line Comparison")

def run_forecast(df, predictor, lookback, pred_len):
    if len(df) < lookback: return
    x_df = df.iloc[-lookback:].copy()
    last_date = x_df['date'].iloc[-1]
    future_dates = pd.date_range(start=last_date + BusinessDay(1), periods=pred_len, freq='B')
    future_dates = pd.Series(future_dates)
    
    print(f"Forecasting: Starting from {future_dates.iloc[0].date()}")
    
    pred_df = predictor.predict(
        df=x_df[['open', 'high', 'low', 'close', 'volume']],
        x_timestamp=x_df['date'],
        y_timestamp=future_dates,
        pred_len=pred_len,
        T=1.0, top_p=0.9, sample_count=1
    )
    
    render_comparison_chart(x_df, None, pred_df, f"Forecast: {TICKER} Future K-Line")

if __name__ == "__main__":
    LOOKBACK = 20
    PRED_LEN = 10
    TICKER = '002111'
    
    pred_model = load_predictor()
    stock_data = load_data(TICKER)
    
    total_rows = len(stock_data)
    backtest_start = max(0, total_rows - LOOKBACK - PRED_LEN - 10) # Leave some space to see trend
    
    print("\n--- Running Backtest ---")
    run_backtest(stock_data, pred_model, LOOKBACK, PRED_LEN, start_index=backtest_start)
    
    print("\n--- Running Forecast ---")
    run_forecast(stock_data, pred_model, LOOKBACK, PRED_LEN)
```
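
`load_data` pulls one ticker's rows out of the `stock_prices` table and sorts them by date. The sketch below reproduces that lookup against an in-memory SQLite database with a bound `?` parameter, so the ticker value is never spliced into the SQL text. The helper name `load_prices`, the three-column table layout, and the rows are made-up assumptions for illustration; the real schema of `signal_flux.db` is not shown in this file.

```python
import sqlite3


def load_prices(conn, ticker):
    """Return (date, close) rows for one ticker, oldest first (illustrative helper)."""
    cur = conn.execute(
        "SELECT date, close FROM stock_prices WHERE ticker = ? ORDER BY date",
        (ticker,),
    )
    return cur.fetchall()


# In-memory database with hypothetical rows
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE stock_prices (ticker TEXT, date TEXT, close REAL)")
conn.executemany(
    "INSERT INTO stock_prices VALUES (?, ?, ?)",
    [
        ("002111", "2024-01-03", 10.2),
        ("002111", "2024-01-02", 10.0),
        ("600000", "2024-01-02", 7.5),
    ],
)

rows = load_prices(conn, "002111")
# A value containing SQL syntax is treated as a plain string and matches nothing
injected = load_prices(conn, "002111' OR '1'='1")
conn.close()
print(rows, injected)
```

Binding the value keeps the query text constant, which matters as soon as the ticker comes from user input rather than a hard-coded default.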

### scripts/utils/predictor/training.py

```python
import os
import sys
import time
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import json
import random
from loguru import logger
from datetime import datetime, timedelta
from sentence_transformers import SentenceTransformer
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# Setup paths
KRONOS_DIR = os.path.dirname(os.path.abspath(__file__))
SRC_DIR = os.path.dirname(os.path.dirname(KRONOS_DIR))
if SRC_DIR not in sys.path:
    sys.path.insert(0, SRC_DIR)

from ..kronos.model import Kronos, KronosTokenizer, KronosPredictor
from ..database_manager import DatabaseManager
from ..stock_tools import StockTools
from ..search_tools import SearchTools
from ..llm.factory import get_model
from ..visualizer import VisualizerTools
from ..schema.models import ForecastResult, KLinePoint
from agno.agent import Agent

class AutoSynthesisTrainer:
    def __init__(self, news_dim=384):
        self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
        self.db = DatabaseManager()
        self.tools = StockTools(self.db)
        self.searcher = SearchTools(self.db)
        # Try loading from local cache first to avoid network timeouts
        model_name = os.getenv('EMBEDDING_MODEL', 'sentence-transformers/all-MiniLM-L6-v2')
        try:
            logger.info(f"🔄 Attempting to load {model_name} from local cache...")
            self.embedder = SentenceTransformer(model_name, device=self.device, local_files_only=True)
            logger.success("✅ Model loaded from local cache.")
        except Exception:
            logger.warning("⚠️ Local cache not found or incomplete. Attempting to download...")
            self.embedder = SentenceTransformer(model_name, device=self.device)
        self.news_dim = news_dim
        
        # Try loading from local cache first to avoid network timeouts
        try:
            logger.info("🔄 Attempting to load Kronos and Tokenizer from local cache...")
            self.tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base", local_files_only=True).to(self.device)
            base_model = Kronos.from_pretrained("NeoQuasar/Kronos-base", local_files_only=True)
            logger.success("✅ Kronos and Tokenizer loaded from local cache.")
        except Exception:
            logger.warning("⚠️ Local Kronos/Tokenizer not found or incomplete. Attempting to download...")
            self.tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base").to(self.device)
            base_model = Kronos.from_pretrained("NeoQuasar/Kronos-base")
            
        self.model = Kronos(
            base_model.s1_bits, base_model.s2_bits, base_model.n_layers, 
            base_model.d_model, base_model.n_heads, base_model.ff_dim,
            base_model.ffn_dropout_p, base_model.attn_dropout_p,
            base_model.resid_dropout_p, base_model.token_dropout_p,
            base_model.learn_te, news_dim=self.news_dim
        ).to(self.device)
        self.model.load_state_dict(base_model.state_dict(), strict=False)
        
        # LLM for causality verification
        provider = os.getenv("LLM_PROVIDER", "ust")
        model_id = os.getenv("LLM_MODEL", "Qwen")
        self.llm_agent = Agent(model=get_model(provider, model_id))

    def discover_shocks(self, ticker_list, threshold=2.0, limit_per_stock=5, days=365, pred_len=5):
        """1. Find days with significant price movements (Look back 1 year)"""
        shocks = []
        end_date = datetime.now().strftime("%Y-%m-%d")
        start_date = (datetime.now() - timedelta(days=days)).strftime("%Y-%m-%d")
        
        for ticker in ticker_list:
            df = self.tools.get_stock_price(ticker, start_date=start_date, end_date=end_date)
            if df.empty or len(df) < 60:
                continue
            
            # Look for big moves
            moves = df[df['change_pct'].abs() > threshold].copy()
            if moves.empty: continue
            
            count = 0
            for idx, row in moves.iterrows():
                # Ensure we have history before this day AND enough future days for eval
                date_idx = df.index.get_loc(idx)
                if date_idx < 50 or date_idx + pred_len > len(df): continue
                
                shocks.append({
                    'ticker': ticker,
                    'date': row['date'],
                    'change': row['change_pct'],
                    'history': df.iloc[date_idx-50:date_idx],
                    'target': df.iloc[date_idx:date_idx + pred_len] # Now capturing pred_len days
                })
                count += 1
                if count >= limit_per_stock: break
        
        logger.info(f"✨ Discovered {len(shocks)} potential price shocks over the last {days} days.")
        return shocks

    def find_reason_and_verify(self, shock):
        """2. Search for reasons and verify causality using LLM"""
        ticker_info = self.db.get_stock_by_code(shock['ticker'])
        name = ticker_info['name'] if ticker_info else shock['ticker']
        date_str = shock['date']
        
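        # Try multiple query variations. The queries stay in Chinese because they target Baidu.
        # Roughly: "<name> (<ticker>) <date> why rise/fall, reason",
        # "<name> <date> unusual movement, reason", and "<ticker> <date> news".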
        queries = [
            f"{name} ({shock['ticker']}) {date_str} 为什么涨跌 原因",
            f"{name} {date_str} 异动 原因",
            f"{shock['ticker']} {date_str} 新闻"
        ]
        
        search_results = []
        for query in queries:
            logger.info(f"🔍 Searching for reason: {query}")
            # Only Baidu is queried here; add engine names to this list to try alternates
            for engine in ["baidu"]:
                try:
                    results = self.searcher.search_list(query, engine=engine, max_results=3, enrich=False)
                    if results:
                        search_results = results
                        break
                except Exception as e:
                    logger.warning(f"Search failed for {query} on {engine}: {e}")
            
            if search_results:
                break
            time.sleep(random.uniform(1.0, 2.0))
            
        if not search_results:
            logger.warning(f"⚠️ No search results found for {name} on {date_str} after multiple attempts.")
            return None
        
        context = "\n".join([f"- {r['title']}: {r.get('content', '')[:300]}" for r in search_results])
        
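        # The prompt is written in Chinese to match the Chinese-language search results.
        # It asks the LLM whether the news happened around the given date, whether it
        # logically explains the move (earnings, favorable policy, restructuring,
        # market-wide sell-off, ...), and to return
        # JSON {"is_causal": true/false, "summary": "<cause summary within 100 characters>"}.
        prompt = f"""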
        任务:判断以下新闻是否解释了该股票在 {date_str} 的 {shock['change']:.2f}% 价格变动。
        
        股票:{name}
        日期:{date_str}
        变动:{shock['change']:.2f}%
        
        搜索结果:
        {context}
        
        要求:
        1. 该新闻是否在该日期左右发生?
        2. 该新闻是否能逻辑上解释这种大幅波动(如财报、利好政策、重组、大环境暴跌等)?
        3. 如果是,请总结一段 100 字以内的“核心推动原因”。
        4. 返回 JSON: {{"is_causal": true/false, "summary": "原因摘要"}}
        """
        
        try:
            res = self.llm_agent.run(prompt)
            data = json.loads(res.content.replace('```json', '').replace('```', '').strip())
            if data.get('is_causal'):
                logger.success(f"✅ Verified cause for {name} on {date_str}: {data['summary']}")
                return data['summary']
            else:
                logger.warning(f"❌ No causal link found for {name} on {date_str}: {data.get('summary', '')}")
                return None
        except Exception as e:
            logger.warning(f"Verification failed: {e}")
        return None

    def save_model(self, path=None):
        """Save the news_proj weights"""
        if path is None:
            save_dir = os.path.join(SRC_DIR, "exports/models")
            os.makedirs(save_dir, exist_ok=True)
            path = os.path.join(save_dir, f"kronos_news_v1_{datetime.now().strftime('%Y%m%d_%H%M')}.pt")
        
        # We only really need to save the news_proj part as it's the only one we train
        torch.save({
            'news_proj_state_dict': self.model.news_proj.state_dict(),
            'news_dim': self.news_dim,
            'd_model': self.model.d_model
        }, path)
        logger.success(f"💾 Model weights saved to {path}")
        return path

    def run_synthesis_and_train(self, tickers, pred_len=5):
        # 1. Discovery
        shocks = self.discover_shocks(tickers, pred_len=pred_len)
        print(f'find {len(shocks)} shocks')
        
        # 2. News Association & Verification
        dataset = []
        max_news_items = 200 # Limit to 200 news items per session to avoid search bans
        
        logger.info(f"🧬 Starting News Association for {len(shocks)} shocks (Max limit: {max_news_items})")
        
        for i, shock in enumerate(shocks):
            if len(dataset) >= max_news_items:
                logger.info("Reached maximum news items limit for this session.")
                break
                
            summary = self.find_reason_and_verify(shock)
            if summary:
                # 3. Embedding news
                emb = self.embedder.encode(summary)
                dataset.append({
                    'history': shock['history'],
                    'target': shock['target'],
                    'news_emb': emb,
                    'summary': summary
                })
            
            # Add delay after search with randomness to avoid being blocked
            if i < len(shocks) - 1:
                delay = random.uniform(2.0, 4.0)
                time.sleep(delay)
        
        if not dataset:
            logger.error("❌ No verified news-price pairs found. Adjust threshold or check if news is available in that period.")
            return

        # 4. Train/Val Split
        random.seed(42)
        random.shuffle(dataset)
        
        if len(dataset) < 2:
            train_set = dataset
            val_set = []
            logger.warning(f"⚠️ Only {len(dataset)} sample(s) found. Training on all, skipping validation.")
        else:
            split_idx = max(1, int(len(dataset) * 0.8))
            if split_idx >= len(dataset):
                split_idx = len(dataset) - 1
                
            train_set = dataset[:split_idx]
            val_set = dataset[split_idx:]
            logger.info(f"🏗️ Dataset Split: {len(train_set)} samples for training, {len(val_set)} for validation.")

        if not train_set:
            logger.error("❌ No samples for training.")
            return

        # 5. Training (Few-shot)
        optimizer = torch.optim.Adam(self.model.news_proj.parameters(), lr=1e-3)
        criterion = nn.CrossEntropyLoss()
        self.model.train()
        
        loss_history = []
        logger.info("🚀 Training for 30 epochs...")
        for epoch in range(30):
            total_loss = 0
            for item in train_set:
                optimizer.zero_grad()
                
                # Prep Data
                hist_df = item['history']
                # For training, we still focus on the immediate next point (teacher forcing)
                target_df = item['target'].iloc[:1]
                
                hist_raw = hist_df[['open', 'high', 'low', 'close', 'volume']].values.astype(np.float32)
                hist_raw = np.column_stack([hist_raw, hist_raw[:, 3] * hist_raw[:, 4]]) 
                
                mean, std = hist_raw.mean(axis=0), hist_raw.std(axis=0) + 1e-5
                hist_norm = torch.from_numpy((hist_raw - mean) / std).unsqueeze(0).to(self.device)
                
                target_raw = target_df[['open', 'high', 'low', 'close', 'volume']].values.astype(np.float32)
                target_raw = np.column_stack([target_raw, target_raw[:, 3] * target_raw[:, 4]])
                target_norm = torch.from_numpy((target_raw - mean) / std).unsqueeze(0).to(self.device)
                
                with torch.no_grad():
                    z_indices = self.tokenizer.encode(hist_norm, half=True)
                    t_indices = self.tokenizer.encode(target_norm, half=True)
                    s1_ids, s2_ids = z_indices[0], z_indices[1]
                    t_s1, t_s2 = t_indices[0], t_indices[1]
                
                news_t = torch.from_numpy(item['news_emb']).unsqueeze(0).to(self.device)
                s1_logits, s2_logits = self.model(s1_ids, s2_ids, news_emb=news_t, use_teacher_forcing=True, s1_targets=t_s1)
                
                loss = (criterion(s1_logits[:, -1, :], t_s1[:, 0]) + criterion(s2_logits[:, -1, :], t_s2[:, 0])) / 2
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                
            avg_epoch_loss = total_loss / max(1, len(train_set))
            loss_history.append(avg_epoch_loss)
            
            if (epoch + 1) % 10 == 0:
                logger.info(f"Epoch {epoch+1} Loss: {avg_epoch_loss:.4f}")

        # 5.1 Visualize Loss Curve
        loss_chart = VisualizerTools.generate_loss_chart(loss_history)
        VisualizerTools.render_chart_to_file(loss_chart, os.path.join(SRC_DIR, "exports/training_results/loss_curve.html"))

        # 5.2 Save final model
        self.save_model()

        # 6. Final Evaluation on Validation Set
        if not val_set:
            logger.warning("⚠️ Validation set is empty. Skipping statistical analysis.")
            return

        logger.info(f"🧪 Final Evaluation: Base vs News-Integrated ({pred_len}-day Window)")
        self.model.eval()
        predictor = KronosPredictor(self.model, self.tokenizer, device=self.device)
        
        base_maes = []
        news_maes = []
        
        print("\n" + "="*90)
        print(f"{'Date':<12} | {'Ticker':<8} | {'Base MAE':<15} | {'News MAE':<15} | {'Improvement'}")
        print("-" * 90)

        for item in val_set:
            h = item['history']
            t = item['target']
            actuals = t['close'].values[:pred_len]
            
            x_ts = pd.to_datetime(h['date'])
            # Future timestamps: handle business days if possible, or just simple offset
            future_dates = pd.date_range(start=x_ts.iloc[-1] + timedelta(days=1), periods=pred_len, freq='B')
            y_ts = pd.Series(future_dates)
            
            # A. Base Prediction
            p_base = predictor.predict(h, x_ts, y_ts, pred_len=pred_len, news_emb=None, verbose=False)
            b_preds = p_base['close'].values[:len(actuals)]
            
            # B. News-Aware Prediction
            p_news = predictor.predict(h, x_ts, y_ts, pred_len=pred_len, news_emb=item['news_emb'], verbose=False)
            n_preds = p_news['close'].values[:len(actuals)]
            
            # Calculate MAE over the window
            b_mae = np.mean(np.abs(b_preds - actuals))
            n_mae = np.mean(np.abs(n_preds - actuals))
            
            base_maes.append(b_mae)
            news_maes.append(n_mae)
            
            improvement = (b_mae - n_mae) / (b_mae + 1e-6) * 100
            
            date_str = str(t['date'].values[0])[:10]
            ticker = h.iloc[-1]['ticker'] if 'ticker' in h.columns else "Stock"
            print(f"{date_str:<12} | {ticker:<8} | {b_mae:<15.4f} | {n_mae:<15.4f} | {improvement:>+7.1f}%")

            # C. Generate Visualization for this case
            try:
                # Helper to convert DF to KLinePoints
                def to_kp_list(preds_df):
                    points = []
                    for idx, row in preds_df.iterrows():
                        points.append(KLinePoint(
                            date=str(idx)[:10],
                            open=row['open'],
                            high=row['high'],
                            low=row['low'],
                            close=row['close'],
                            volume=row['volume'] if 'volume' in row else 0
                        ))
                    return points

                forecast_obj = ForecastResult(
                    ticker=ticker,
                    base_forecast=to_kp_list(p_base),
                    adjusted_forecast=to_kp_list(p_news),
                    rationale=item['summary']
                )

                # Ground truth for visualizer expects a DataFrame with 'date' and 'close'
                gt_df = t[['date', 'open', 'high', 'low', 'close', 'volume']]
                
                chart = VisualizerTools.generate_stock_chart(
                    df=h, 
                    ticker=ticker, 
                    title=f"Training Eval: {ticker} ({date_str}) Improvement: {improvement:.1f}%",
                    forecast=forecast_obj,
                    ground_truth=gt_df
                )
                
                safe_date = date_str.replace("-", "")
                filename = f"eval_{ticker}_{safe_date}.html"
                VisualizerTools.render_chart_to_file(chart, os.path.join(SRC_DIR, f"exports/training_results/{filename}"))
            except Exception as e:
                logger.error(f"Failed to generate eval chart for {ticker}: {e}")

        # Summary Statistics
        avg_base_err = sum(base_maes) / max(1, len(base_maes))
        avg_news_err = sum(news_maes) / max(1, len(news_maes))
        overall_imp = (avg_base_err - avg_news_err) / (avg_base_err + 1e-6) * 100
        
        print("-" * 90)
        print(f"{'AVERAGE':<12} | {'-':<8} | {avg_base_err:<15.4f} | {avg_news_err:<15.4f} | {overall_imp:>+7.1f}%")
        print("="*90 + "\n")
        
        logger.success(f"🏁 Statistical Analysis Complete. Avg Error Reduction ({pred_len}-day): {overall_imp:.2f}%")
        logger.info(f"📊 Visualization results saved to: {os.path.join(SRC_DIR, 'exports/training_results/')}")

if __name__ == "__main__":
    trainer = AutoSynthesisTrainer()
    
    logger.info("📂 Fetching all stock codes from database...")
    res = trainer.db.execute_query("SELECT code FROM stock_list")
    all_tickers = [row['code'] for row in res]
    
    if not all_tickers:
        logger.warning("⚠️ No tickers found in stock_list table. Trying to sync...")
        trainer.tools._check_and_update_stock_list(force=True)
        res = trainer.db.execute_query("SELECT code FROM stock_list")
        all_tickers = [row['code'] for row in res]

    logger.info(f"🚀 Starting training on potential stocks (1-year scan)...")
    # For demonstration, scan the first 100 stocks for price shocks over the past year
    trainer.run_synthesis_and_train(all_tickers[:100], pred_len=1)

```
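
`run_synthesis_and_train` shuffles the verified samples with a fixed seed and holds out roughly 20% for validation, while making sure both sides are non-empty once at least two samples exist. A minimal sketch of that rule in isolation follows. `split_dataset` is a hypothetical name, and the sketch uses a private `random.Random` instance rather than the global seed set in the script, so shuffle order may differ from the original.

```python
import random


def split_dataset(samples, train_ratio=0.8, seed=42):
    """Shuffle a copy of `samples` and return (train_set, val_set)."""
    items = list(samples)
    random.Random(seed).shuffle(items)
    if len(items) < 2:
        # Too few samples to hold anything out for validation
        return items, []
    split_idx = max(1, int(len(items) * train_ratio))
    # Keep at least one sample on the validation side
    split_idx = min(split_idx, len(items) - 1)
    return items[:split_idx], items[split_idx:]


# (train size, validation size) for a few dataset sizes
sizes = {n: tuple(len(part) for part in split_dataset(range(n))) for n in (0, 1, 2, 5, 10)}
train_set, val_set = split_dataset(range(10))
print(sizes)
```

With very small datasets the held-out set is a single sample, so the per-case MAE table printed after training is better read as a spot check than as a statistic.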

### scripts/utils/search_tools.py

```python
import os
import hashlib
import json
import re
import requests
import time
import threading
from typing import List, Dict, Optional, Any
from agno.tools.duckduckgo import DuckDuckGoTools
from agno.tools.baidusearch import BaiduSearchTools
from agno.agent import Agent
from loguru import logger
from datetime import datetime
from .database_manager import DatabaseManager
from .content_extractor import ContentExtractor
from .llm.factory import get_model
from .hybrid_search import LocalNewsSearch

# Default search cache TTL in seconds; override it with the SEARCH_CACHE_TTL environment variable
DEFAULT_SEARCH_TTL = int(os.getenv("SEARCH_CACHE_TTL", "3600"))  # defaults to 1 hour


class JinaSearchEngine:
    """Jina Search API wrapper: runs web searches through s.jina.ai"""
    
    JINA_SEARCH_URL = "https://s.jina.ai/"
    
    # Rate-limit configuration
    _rate_limit_no_key = 10  # max requests per minute without an API key
    _rate_window = 60.0
    _min_interval = 2.0
    _request_times = []
    _last_request_time = 0.0
    _lock = threading.Lock()
    
    def __init__(self):
        self.api_key = os.getenv("JINA_API_KEY", "").strip()
        self.has_api_key = bool(self.api_key)
        if self.has_api_key:
            logger.info("✅ Jina Search API key configured")
    
    @classmethod
    def _wait_for_rate_limit(cls, has_api_key: bool) -> None:
        """Block until the rate limit allows another request"""
        if has_api_key:
            time.sleep(0.3)
            return
        
        with cls._lock:
            current_time = time.time()
            cls._request_times = [t for t in cls._request_times if current_time - t < cls._rate_window]
            
            if len(cls._request_times) >= cls._rate_limit_no_key:
                oldest = cls._request_times[0]
                wait_time = cls._rate_window - (current_time - oldest) + 1.0
                if wait_time > 0:
                    logger.warning(f"⏳ Jina Search rate limit, waiting {wait_time:.1f}s...")
                    time.sleep(wait_time)
                    current_time = time.time()
                    cls._request_times = [t for t in cls._request_times if current_time - t < cls._rate_window]
            
            time_since_last = current_time - cls._last_request_time
            if time_since_last < cls._min_interval:
                time.sleep(cls._min_interval - time_since_last)
            
            cls._request_times.append(time.time())
            cls._last_request_time = time.time()
    
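    def search(self, query: str, max_results: int = 5, _retry_on_429: bool = True) -> List[Dict]:
        """
        Run a search through the Jina Search API.
        
        Args:
            query: Search keywords.
            max_results: Number of results to return.
            _retry_on_429: Internal flag. When True, one retry is made after an HTTP 429.
            
        Returns:
            A list of search results, each containing title, url, and content.
        """
        if not query:
            return []
        
        logger.info(f"🔍 Jina Search: {query}")
        
        # Wait until the rate limit allows the request
        self._wait_for_rate_limit(self.has_api_key)
        
        headers = {
            "Accept": "application/json",
            "X-Retain-Images": "none",
        }
        
        if self.has_api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        
        try:
            # Jina Search API: https://s.jina.ai/{query}
            import urllib.parse
            encoded_query = urllib.parse.quote(query)
            url = f"{self.JINA_SEARCH_URL}{encoded_query}"
            
            response = requests.get(url, headers=headers, timeout=30)
            
            if response.status_code == 429:
                if not _retry_on_429:
                    logger.warning("⚠️ Jina Search still rate limited (429) after retry, giving up")
                    return []
                logger.warning("⚠️ Jina Search rate limited (429), waiting 30s...")
                time.sleep(30)
                # Retry once only, so a persistent 429 cannot recurse without bound
                return self.search(query, max_results, _retry_on_429=False)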
            
            if response.status_code != 200:
                logger.warning(f"Jina Search failed (Status {response.status_code})")
                return []
            
            # Parse the response
            try:
                data = response.json()
            except ValueError:
                # JSON decode errors subclass ValueError; wrap a plain-text body as a single result
                data = {"data": [{"title": "Search Result", "url": "", "content": response.text}]}
            
            results = []
            
            # Jina may return either {"data": [...]} or a bare list
            items = data.get("data", []) if isinstance(data, dict) else data
            if not isinstance(items, list):
                items = [items] if items else []
            
            for i, item in enumerate(items[:max_results]):
                if isinstance(item, dict):
                    results.append({
                        "title": item.get("title", f"Result {i+1}"),
                        "url": item.get("url", ""),
                        "href": item.get("url", ""),  # alias kept for compatibility
                        "content": item.get("content", item.get("description", "")),
                        "body": item.get("content", item.get("description", "")),  # alias kept for compatibility
                    })
                elif isinstance(item, str):
                    results.append({
                        "title": f"Result {i+1}",
                        "url": "",
                        "content": item
                    })
            
            logger.info(f"✅ Jina Search returned {len(results)} results")
            return results
            
        except requests.exceptions.Timeout:
            logger.error("Jina Search timeout")
            return []
        except requests.exceptions.RequestException as e:
            logger.error(f"Jina Search request error: {e}")
            return []
        except Exception as e:
            logger.error(f"Jina Search unexpected error: {e}")
            return []

class SearchTools:
    """Extensible search toolkit with multi-engine aggregation and result caching."""
    
    def __init__(self, db: DatabaseManager):
        self.db = db
        
        # Check whether a Jina API key is configured
        jina_api_key = os.getenv("JINA_API_KEY", "").strip()
        self._jina_enabled = bool(jina_api_key)
        
        self._engines = {
            "ddg": DuckDuckGoTools(),
            "baidu": BaiduSearchTools(),
            "local": LocalNewsSearch(db)
        }
        
        # Register the Jina engine only when JINA_API_KEY is configured
        if self._jina_enabled:
            self._engines["jina"] = JinaSearchEngine()
            logger.info("🚀 Jina Search engine enabled (JINA_API_KEY configured)")
        
        # Choose the default search engine
        self._default_engine = "jina" if self._jina_enabled else "ddg"

    def _generate_hash(self, query: str, engine: str, max_results: int) -> str:
        return hashlib.md5(f"{engine}:{query}:{max_results}".encode()).hexdigest()

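    def search(self, query: str, engine: Optional[str] = None, max_results: int = 5, ttl: Optional[int] = None) -> str:
        """
        Run a web search with the given engine. Results are cached to avoid repeat requests.
        
        Args:
            query: Search keywords, e.g. "NVIDIA earnings" or "photovoltaic industry policy".
            engine: Search engine to use. One of:
                    "jina" (Jina Search; requires JINA_API_KEY; LLM-friendly output),
                    "ddg" (DuckDuckGo; recommended for English/international queries),
                    "baidu" (Baidu; recommended for Chinese/domestic queries),
                    "local" (local historical news search, vector + BM25).
                    Default: "jina" if JINA_API_KEY is configured, otherwise "ddg".
            max_results: Number of results to request. Defaults to 5.
            ttl: Cache lifetime in seconds. Entries older than this trigger a fresh search.
                 Defaults to the SEARCH_CACHE_TTL environment variable, or 3600 seconds.
                 Set to 0 to force a refresh.
        
        Returns:
            The search results as text, including title, snippet, and link.
        """
        # Fall back to the default engine (Jina when configured)
        if engine is None:
            engine = self._default_engine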
        
        if engine not in self._engines:
            return f"Error: Unsupported engine '{engine}'. Available: {list(self._engines.keys())}"

        query_hash = self._generate_hash(query, engine, max_results)
        effective_ttl = ttl if ttl is not None else DEFAULT_SEARCH_TTL
        
        # 1. Try the cache first (the "local" engine is not cached because it already queries the database)
        if engine != "local":
            cache = self.db.get_search_cache(query_hash, ttl_seconds=effective_ttl if effective_ttl > 0 else None)
            if cache and effective_ttl != 0:
                logger.info(f"ℹ️ Found search results in cache for: {query} ({engine})")
                return cache['results']

        # 2. Run the actual search
        logger.info(f"📡 Searching {engine} for: {query}")
        try:
            tool = self._engines[engine]
            if engine == "jina":
                # Jina Search returns List[Dict]
                jina_results = tool.search(query, max_results=max_results)
                results = []
                for r in jina_results:
                    results.append({
                        "title": r.get("title", ""),
                        "href": r.get("url", ""),
                        "body": r.get("content", "")
                    })
            elif engine == "ddg":
                results = tool.duckduckgo_search(query, max_results=max_results)
            elif engine == "baidu":
                results = tool.baidu_search(query, max_results=max_results)
            elif engine == "local":
                # LocalNewsSearch returns List[Dict]
                local_results = tool.search(query, top_n=max_results)
                results = []
                for r in local_results:
                    results.append({
                        "title": r.get("title"),
                        "href": r.get("url", "local"),
                        "body": r.get("content", "")
                    })
            else:
                results = "Search not implemented for this engine."
            
            results_str = str(results)
            if engine != "local":
                self.db.save_search_cache(query_hash, query, engine, results_str)
            return results_str
            
        except Exception as e:
            # Fallback strategy when a search fails
            if engine == "jina":
                logger.warning(f"⚠️ Jina search failed, falling back to ddg: {query} ({e})")
                try:
                    return self.search(query, engine="ddg", max_results=max_results, ttl=ttl)
                except Exception as e2:
                    logger.error(f"❌ DDG fallback also failed for {query}: {e2}")
            elif engine == "ddg":
                logger.warning(f"⚠️ DDG search failed, falling back to baidu: {query} ({e})")
                try:
                    return self.search(query, engine="baidu", max_results=max_results, ttl=ttl)
                except Exception as e2:
                    logger.error(f"❌ Baidu fallback also failed for {query}: {e2}")

            logger.error(f"❌ Search failed for {query}: {e}")
            return f"Error occurred during search: {str(e)}"

    def search_list(self, query: str, engine: Optional[str] = None, max_results: int = 5, ttl: Optional[int] = None, enrich: bool = True) -> List[Dict]:
        """
        Run a search and return a structured list (List[Dict]).
        Each dict contains: title, href (or url), body (or snippet).
        
        Args:
            engine: Search engine; defaults to the configured default (Jina first).
            enrich: Whether to fetch the full article text (default True).
        """
        # Use the default engine
        if engine is None:
            engine = self._default_engine
            
        if engine not in self._engines:
            logger.error(f"Unsupported engine {engine}")
            return []
            
        # Use a distinct hash so enriched and non-enriched results are cached separately
        enrich_suffix = ":enriched" if enrich else ""
        query_hash = self._generate_hash(query, engine + enrich_suffix, max_results)
        effective_ttl = ttl if ttl is not None else DEFAULT_SEARCH_TTL
        
        # 1. Try the cache first
        cache = self.db.get_search_cache(query_hash, ttl_seconds=effective_ttl if effective_ttl > 0 else None)
        if cache and effective_ttl != 0:
            try:
                cached_data = json.loads(cache['results'])
                if isinstance(cached_data, list):
                    logger.info(f"ℹ️ Found structured search cache for: {query}")
                    return cached_data
            except (ValueError, TypeError):
                # Unreadable cache entry: fall through to a fresh search
                pass
        
        # 1.5 Smart Cache (Fuzzy + LLM)
        if effective_ttl != 0:
            try:
                # 1. Similar cached queries
                similar_queries = self.db.find_similar_queries(query, limit=3)
                # Filter by TTL
                valid_candidates = []
                for q in similar_queries:
                    if q['query'] == query:
                        continue
                    q_time = datetime.fromisoformat(q['timestamp'])
                    if effective_ttl and (datetime.now() - q_time).total_seconds() > effective_ttl:
                        continue
                    q['type'] = 'cached_search'
                    valid_candidates.append(q)

                # 2. Relevant local news (as search results)
                local_news = self.db.search_local_news(query, limit=3)
                if local_news:
                    # Group local news as a single "candidate" source? Or individual?
                    # Better to treat "Local News Database" as one candidate source that contains X items.
                    # Or just add them to candidates list?
                    # Let's package strictly relevant news as a "local_news_bundle"
                    valid_candidates.append({
                        'type': 'local_news',
                        'query': 'Local Database News',
                        'items': local_news,
                        'timestamp': datetime.now().isoformat()
                    })
                
                if valid_candidates:
                    logger.info(f"🤔 Found {len(valid_candidates)} smart cache candidates (Queries/News). Asking LLM...")
                    evaluation = self._evaluate_cache_relevance(query, valid_candidates)
                    
                    if evaluation and evaluation.get('reuse', False):
                        idx = evaluation.get('index', -1)
                        if 0 <= idx < len(valid_candidates):
                            chosen = valid_candidates[idx]
                            logger.info(f"🤖 LLM suggested reusing: '{chosen.get('query')}' ({chosen['type']})")
                            
                            if chosen['type'] == 'cached_search':
                                # Load the chosen cache
                                cache = self.db.get_search_cache(chosen['query_hash']) 
                                if cache:
                                    try:
                                        cached_data = json.loads(cache['results'])
                                        if isinstance(cached_data, list):
                                            return cached_data
                                    except (ValueError, TypeError):
                                        # Unreadable cache entry: fall through to a fresh search
                                        pass
                            elif chosen['type'] == 'local_news':
                                # Convert local news items to search result format
                                news_results = []
                                for i, news in enumerate(chosen['items'], 1):
                                    news_results.append({
                                        "id": news.get('id'),
                                        "rank": i,
                                        "title": news.get('title'),
                                        "url": news.get('url'),
                                        "content": news.get('content'),
                                        "original_snippet": news.get('content')[:200] if news.get('content') else '',
                                        "source": f"Local News ({news.get('source')})",
                                        "publish_time": news.get('publish_time'),
                                        "crawl_time": news.get('crawl_time'),
                                        "sentiment_score": news.get('sentiment_score', 0),
                                        "meta_data": {"origin": "local_db"}
                                    })
                                return news_results

            except Exception as e:
                logger.warning(f"Smart cache check failed: {e}")
        
        # 2. Run the search
        logger.info(f"📡 Searching {engine} (structured) for: {query}")
        try:
            tool = self._engines[engine]
            results = []
            if engine == "jina":
                # Jina Search already returns structured data
                jina_results = tool.search(query, max_results=max_results)
                for r in jina_results:
                    results.append({
                        "title": r.get("title", ""),
                        "url": r.get("url", ""),
                        "href": r.get("url", ""),
                        "body": r.get("content", ""),
                        "content": r.get("content", ""),
                        "source": "Jina Search"
                    })
            elif engine == "ddg":
                results = tool.duckduckgo_search(query, max_results=max_results)
            elif engine == "baidu":
                results = tool.baidu_search(query, max_results=max_results)
            elif engine == "local":
                # LocalNewsSearch returns List[Dict]
                local_results = tool.search(query, top_n=max_results)
                results = []
                for r in local_results:
                    results.append({
                        "title": r.get("title"),
                        "url": r.get("url", "local"),
                        "body": r.get("content", "")[:500],
                        "source": f"Local ({r.get('source', 'db')})",
                        "publish_time": r.get("publish_time")
                    })
            
            # Some engines return a JSON string (Baidu often does); decode it
            if isinstance(results, str) and engine not in ["local", "jina"]:
                try:
                    results = json.loads(results)
                except ValueError:
                    # Not valid JSON: keep the raw string for the fallback below
                    pass
            
            # Normalize to a unified format
            normalized_results = []
            if isinstance(results, list):
                
                for i, r in enumerate(results, 1):
                    title = r.get('title', '')
                    url = r.get('href') or r.get('url') or r.get('link', '')
                    content = r.get('body') or r.get('snippet') or r.get('abstract', '')
                    
                    if title and url:
                        normalized_results.append({
                            "id": self._generate_hash(url + query, "search_item", i),
                            "rank": i,
                            "title": title,
                            "url": url,
                            "content": content,
                            "original_snippet": content,  # keep the original snippet
                            "source": f"Search ({engine})",
                            "publish_time": datetime.now().isoformat(),  # placeholder: current time
                            "crawl_time": datetime.now().isoformat(),
                            "meta_data": {"query": query, "engine": engine}
                        })
            
            # Fallback if still string and failed to parse
            elif isinstance(results, str) and results:
                 normalized_results.append({"title": query, "url": "", "content": results, "source": engine})

            # 3. Fetch full text and compute sentiment (enrichment)
            # Note: Jina Search content is already LLM-friendly, so the extra fetch can be skipped
            skip_content_enrichment = (engine == "jina")
            
            if enrich and normalized_results:
                logger.info(f"🕸️ Enriching {len(normalized_results)} search results with Jina & Sentiment...")
                extractor = ContentExtractor()
                
                # Lazy load sentiment tool
                if not hasattr(self, 'sentiment_tool') or self.sentiment_tool is None:
                    from ..sentiment_tools import SentimentTools
                    self.sentiment_tool = SentimentTools(self.db)
                
                for item in normalized_results:
                    if item.get("url"):
                        try:
                            # Jina Search content is already good enough; skip the extra fetch
                            if skip_content_enrichment and item.get("content") and len(item.get("content", "")) > 100:
                                full_content = item["content"]
                            else:
                                # Use Jina Reader to get full content
                                full_content = extractor.extract_with_jina(item["url"], timeout=60)
                            
                            if full_content and len(full_content) > 100:
                                item["content"] = full_content
                                
                                # Calculate sentiment
                                # Use title + snippet of content for efficiency
                                text_to_analyze = f"{item['title']} {full_content[:500]}"
                                sent_result = self.sentiment_tool.analyze_sentiment(text_to_analyze)  # Using self.sentiment_tool
                                score = sent_result.get('score', 0.0)
                                item["sentiment_score"] = float(score)
                                
                                logger.info(f"  ✅ Enriched: {item['title'][:20]}... (Sentiment: {score:.2f})")
                            else:
                                # Fallback: Use snippet for sentiment
                                logger.info(f"  ⚠️ Content short/failed for {item['url']}, using snippet for sentiment.")
                                text_to_analyze = f"{item['title']} {item['content']}" # content is snippet here
                                sent_result = self.sentiment_tool.analyze_sentiment(text_to_analyze)
                                score = sent_result.get('score', 0.0)
                                item["sentiment_score"] = float(score)

                        except Exception as e:
                             # Fallback: Use snippet for sentiment on error
                            logger.warning(f"Failed to enrich {item['url']}: {e}. Using snippet.")
                            text_to_analyze = f"{item['title']} {item['content']}"
                            sent_result = self.sentiment_tool.analyze_sentiment(text_to_analyze)
                            score = sent_result.get('score', 0.0)
                            item["sentiment_score"] = float(score)
            
            # Cache the result list
            if normalized_results:
                # Pass list directly, DB manager will handle JSON dump for main cache and populate search_details
                # Only cache if NOT from local news reuse (though this logic path is for fresh search)
                self.db.save_search_cache(query_hash, query, engine, normalized_results)
            
            return normalized_results
            
        except Exception as e:
            # Fallback strategy when a search fails
            if engine == "jina":
                logger.warning(f"⚠️ Jina search_list failed, falling back to ddg: {query} ({e})")
                try:
                    return self.search_list(query, engine="ddg", max_results=max_results, ttl=ttl, enrich=enrich)
                except Exception as e2:
                    logger.error(f"❌ DDG fallback (search_list) also failed for {query}: {e2}")
            elif engine == "ddg":
                logger.warning(f"⚠️ DDG search_list failed, falling back to baidu: {query} ({e})")
                try:
                    return self.search_list(query, engine="baidu", max_results=max_results, ttl=ttl, enrich=enrich)
                except Exception as e2:
                    logger.error(f"❌ Baidu fallback (search_list) also failed for {query}: {e2}")

            logger.error(f"❌ Structured search failed for {query}: {e}")
            return []

    def _evaluate_cache_relevance(self, current_query: str, candidates: List[Dict]) -> Dict:
        """
        Use an LLM to judge whether a cached candidate is sufficient to answer the current query.
        """
        try:
            # Prepare candidates text
            candidates_desc = []
            for i, c in enumerate(candidates):
                if c['type'] == 'cached_search':
                    # Preview cached results if available? 
                    # Maybe just use the query string as a proxy for what's in there.
                    # Or peek at 'results' snippet.
                    preview = ""
                    try:
                         # Attempt to peek first result title from JSON string
                         # Note: c.get('results') might be a stringified JSON list
                         res_list = json.loads(c.get('results', '[]'))
                         if res_list and isinstance(res_list, list) and len(res_list) > 0:
                             first_item = res_list[0]
                             if isinstance(first_item, dict) and 'title' in first_item:
                                 preview = f" (Contains: {first_item.get('title', '')[:50]}...)"
                    except (ValueError, TypeError):
                        # Preview is optional; skip it if the cached results cannot be parsed
                        pass
                    candidates_desc.append(f"[{i}] Old Search Query: '{c['query']}' {preview} (Time: {c['timestamp']})")
                elif c['type'] == 'local_news':
                     # List titles of local news
                     titles = [item['title'] for item in c['items'][:3]]
                     candidates_desc.append(f"[{i}] Local Database News: {', '.join(titles)}... (Time: {c['timestamp']})")

            prompt = f"""
            Task: Decide if existing information is sufficient for the new search query.
            
            New Query: "{current_query}"
            
            Available Information Candidates:
            {chr(10).join(candidates_desc)}
            
            Instructions:
            1. Analyze if any candidate provides ENOUGH up-to-date info for the "New Query".
            2. If yes, choose the best one.
            3. If the query implies needing LATEST real-time info and candidates are old, choose none.
            4. Return strictly JSON: {{"reuse": true/false, "index": <candidate_index_int>, "reason": "short explanation"}}
            """
            # Initialize the model
            provider = os.getenv("LLM_PROVIDER", "ust")
            model_id = os.getenv("LLM_MODEL", "Qwen")
            host = os.getenv("LLM_HOST")
            if host:
                model = get_model(provider, model_id, host=host)
            else:
                model = get_model(provider, model_id)
                
            agent = Agent(model=model, markdown=True)
            
            response = agent.run(prompt)
            content = response.content
            
            # Parse JSON
            json_match = re.search(r'```json\s*(.*?)\s*```', content, re.DOTALL)
            if json_match:
                return json.loads(json_match.group(1))
            elif '{' in content:
                 # Fallback for cases where LLM doesn't wrap in ```json
                 return json.loads(content[content.find('{'):content.rfind('}')+1])
            return {"reuse": False}
            
        except Exception as e:
            logger.warning(f"LLM evaluation failed: {e}")
            return {"reuse": False}

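    def aggregate_search(self, query: str, engines: Optional[List[str]] = None, max_results: int = 5) -> str:
        """
        Query several search engines in turn and aggregate the results for broader coverage.
        
        Args:
            query: Search keywords.
            engines: Engines to query. Supported values: ["ddg", "baidu"].
                     Defaults to both ddg and baidu.
            max_results: Number of results to request from each engine.
        
        Returns:
            The aggregated results, grouped by engine.
        """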
        engines = engines or ["ddg", "baidu"]
        aggregated_results = []
        for engine in engines:
            res = self.search(query, engine=engine, max_results=max_results)
            aggregated_results.append(f"--- Results from {engine.upper()} ---\n{res}")
        
        return "\n\n".join(aggregated_results)

```
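
The reply-parsing step at the end of `_evaluate_cache_relevance` can be isolated and tested without calling a model. The following is a minimal sketch under stated assumptions, not the skill's own code: `extract_json_object`, `FENCE`, and `FENCED_JSON` are hypothetical names introduced here, and the `{"reuse": False}` default mirrors the fallback used above.

```python
import json
import re
from typing import Any, Dict

# Build the fence marker programmatically so the example contains no literal code fence.
FENCE = "`" * 3
FENCED_JSON = re.compile(FENCE + r"json\s*(.*?)\s*" + FENCE, re.DOTALL)


def extract_json_object(content: str) -> Dict[str, Any]:
    """Pull a JSON object out of an LLM reply, or return a safe default."""
    default: Dict[str, Any] = {"reuse": False}

    fenced = FENCED_JSON.search(content)
    if fenced:
        # Preferred case: the model wrapped its answer in a json code fence.
        candidate = fenced.group(1)
    elif "{" in content and "}" in content:
        # Fallback: take everything from the first "{" to the last "}".
        candidate = content[content.find("{"): content.rfind("}") + 1]
    else:
        return default

    try:
        parsed = json.loads(candidate)
    except json.JSONDecodeError:
        return default
    # Only a JSON object is a valid verdict; lists or scalars are rejected.
    return parsed if isinstance(parsed, dict) else default
```

Returning the default on any parse failure keeps the caller on the "run a fresh search" path, which is the conservative choice when the verdict is unreadable.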

### scripts/utils/stock_tools.py

```python
from datetime import datetime, timedelta
from typing import List, Dict, Optional
import akshare as ak
import pandas as pd
import re
import sqlite3
from requests.exceptions import RequestException
from loguru import logger
from .database_manager import DatabaseManager
import os
from contextlib import contextmanager

@contextmanager
def temporary_no_proxy():
    """Context manager to temporarily unset proxy environment variables."""
    proxies = {k: os.environ.get(k) for k in ['http_proxy', 'https_proxy', 'HTTP_PROXY', 'HTTPS_PROXY']}
    for k in proxies:
        if k in os.environ:
            del os.environ[k]
    try:
        yield
    finally:
        for k, v in proxies.items():
            if v is not None:
                os.environ[k] = v

class StockTools:
    """Stock tools for financial analysis, combining a high-performance database cache with incremental updates."""
    
    def __init__(self, db: DatabaseManager, auto_update: bool = True):
        """
        Initialize the stock tools.
        
        Args:
            db: Database manager.
            auto_update: Whether to refresh the stock list automatically when it is empty. Defaults to True.
        """
        self.db = db
        if auto_update:
            self._check_and_update_stock_list()

    def _check_and_update_stock_list(self, force: bool = False):
        """Check and update the stock list. Fetches from the network only when the list is empty or force=True."""
        # Count the rows in the table directly
        cursor = self.db.conn.cursor()
        cursor.execute("SELECT COUNT(*) FROM stock_list")
        count = cursor.fetchone()[0]
        
        if count > 0 and not force:
            logger.info(f"ℹ️ Stock list already cached ({count} stocks)")
            return
        
        logger.info("📡 Updating A-share and HK-share stock list from akshare...")
        
        def fetch_data():
            # A-share
            df_a = ak.stock_zh_a_spot_em()
            df_a = df_a[['代码', '名称']].copy()
            df_a.columns = ['code', 'name']
            
            # HK-share
            df_hk = ak.stock_hk_spot_em()
            df_hk = df_hk[['代码', '名称']].copy()
            df_hk.columns = ['code', 'name']
            
            # Combine
            return pd.concat([df_a, df_hk], ignore_index=True)

        try:
            try:
                df_combined = fetch_data()
            except Exception as e:
                # RequestException is a subclass of Exception, so one clause covers both
                if "proxy" in str(e).lower():
                    logger.warning(f"⚠️ Proxy error detected: {e}. Retrying with proxy disabled...")
                    with temporary_no_proxy():
                        df_combined = fetch_data()
                else:
                    raise
            
            self.db.save_stock_list(df_combined)
            logger.info(f"✅ Cached {len(df_combined)} stocks (A-share + HK) to database.")
            
        except Exception as e:
            logger.error(f"❌ Failed to sync stock list: {e}")


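    def search_ticker(self, query: str, limit: int = 5) -> List[Dict]:
        """
        Fuzzy-search A-share tickers by code or name. Common abbreviations are supported.
        """
        # Strip exchange suffixes (e.g. CATL.SZ -> CATL, 000001.SZ -> 000001)
        clean_query = re.sub(r'\.(SZ|SH|HK|US)$', '', query, flags=re.IGNORECASE)
        
        # Common aliases. Keys are upper-case because the lookup below upper-cases the query;
        # mixed-case keys such as "Moutai" would never match.
        aliases = {
            "CATL": "宁德时代",
            "BYD": "比亚迪",
            "TSLA": "特斯拉",
            "MOUTAI": "贵州茅台",
            "TENCENT": "腾讯",
            "ALIBABA": "阿里巴巴",
            "MEITUAN": "美团",
        }
        
        search_query = aliases.get(clean_query.upper(), clean_query)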
        
        # Robustness: if regex-like ticker code is embedded in query (e.g. "300364 中文在线"), try to extract it
        if not search_query.isdigit():
             # Extract explicit 5-6 digit codes
             match = re.search(r'\b(\d{5,6})\b', clean_query)
             if match:
                 search_query = match.group(1)

        return self.db.search_stock(search_query, limit)

    def get_stock_price(
        self,
        ticker: str,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        force_sync: bool = False,
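    ) -> pd.DataFrame:
        """
        Get historical price data for a stock. Reads from the local cache first and fills gaps from the network.
        
        Args:
            ticker: Stock code, e.g. "600519" (Kweichow Moutai) or "000001" (Ping An Bank).
            start_date: Start date in "YYYY-MM-DD" format. Defaults to 90 days ago.
            end_date: End date in "YYYY-MM-DD" format. Defaults to today.
            force_sync: If True, always re-sync from the network, even when cached data looks fresh.
        
        Returns:
            A DataFrame with the columns date, open, close, high, low, volume, change_pct.
        """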
        now = datetime.now()
        if not end_date:
            end_date = now.strftime('%Y-%m-%d')
        if not start_date:
            start_date = (now - timedelta(days=90)).strftime('%Y-%m-%d')

        df_db = self.db.get_stock_prices(ticker, start_date, end_date)
        
        need_update = False
        if df_db.empty:
            need_update = True
        else:
            db_latest = pd.to_datetime(df_db['date'].max())
            req_latest = pd.to_datetime(end_date)
            if (req_latest - db_latest).days > 2:
                need_update = True

        if force_sync:
            need_update = True

        if need_update:
            logger.info(f"📡 Data stale or missing for {ticker}, syncing from network...")
            
            # Reduce the ticker to digits only (akshare's A-share endpoints generally expect a numeric code)
            clean_ticker = "".join(filter(str.isdigit, ticker))
            if not clean_ticker:
                # Non A/H numeric tickers are not supported by the current data source.
                logger.warning(f"⚠️ Unsupported ticker format (A/H only): {ticker}")
                return df_db

            try:
                s_fmt = start_date.replace("-", "")
                e_fmt = end_date.replace("-", "")
                
                df_remote = None
                
                def fetch_data():
                    if len(clean_ticker) == 5:
                        # HK Stock
                        return ak.stock_hk_hist(
                            symbol=clean_ticker, period="daily",
                            start_date=s_fmt, end_date=e_fmt,
                            adjust="qfq"
                        )
                    else:
                        # A-share Stock
                        return ak.stock_zh_a_hist(
                            symbol=clean_ticker, period="daily",
                            start_date=s_fmt, end_date=e_fmt,
                            adjust="qfq"
                        )

                try:
                    df_remote = fetch_data()
                except Exception as e:
                    # RequestException is a subclass of Exception, so one clause covers both
                    if "proxy" in str(e).lower():
                        logger.warning(f"⚠️ Proxy error detected: {e}. Retrying with proxy disabled...")
                        with temporary_no_proxy():
                            df_remote = fetch_data()
                    else:
                        raise
                
                if df_remote is not None and not df_remote.empty:
                    df_remote = df_remote.rename(columns={
                        '日期': 'date', '开盘': 'open', '收盘': 'close',
                        '最高': 'high', '最低': 'low', '成交量': 'volume',
                        '涨跌幅': 'change_pct'
                    })
                    # Normalize the date format
                    df_remote['date'] = pd.to_datetime(df_remote['date']).dt.strftime('%Y-%m-%d')
                    
                    # Save only when meaningful data was fetched
                    self.db.save_stock_prices(clean_ticker, df_remote)  # store under the cleaned ticker
                    
                    # Re-query the database so the returned result is consistent
                    return self.db.get_stock_prices(clean_ticker, start_date, end_date)
                else:
                    logger.warning(f"⚠️ Akshare returned empty data for {clean_ticker}")
                    
            except KeyError as e:
                # akshare sometimes raises KeyError when a stock has no data
                logger.warning(f"⚠️ Akshare data missing for {clean_ticker}: {e}")
            except (RequestException, ConnectionError) as e:
                logger.error(f"❌ Network error during Akshare sync for {clean_ticker}: {e}")
            except sqlite3.Error as e:
                logger.error(f"❌ Database error during Akshare sync for {clean_ticker}: {e}")
            except Exception as e:
                logger.error(f"❌ Unexpected error during Akshare sync for {clean_ticker}: {e}")
        
        return df_db


def get_stock_analysis(ticker: str, db: DatabaseManager) -> str:
    """
    Generate a summary analysis report for the given stock.
    
    Args:
        ticker: Stock code.
        db: DatabaseManager instance.
    
    Returns:
        A Markdown report covering the price trend and key metrics.
    """
    tools = StockTools(db)
    df = tools.get_stock_price(ticker)
    
    if df.empty:
        return f"❌ Could not fetch price data for {ticker}."
    
    latest = df.iloc[-1]
    change = ((latest['close'] - df.iloc[0]['close']) / df.iloc[0]['close']) * 100
    
    report = [
        f"## 📊 {ticker} Analysis Report",
        f"- **Period**: {df.iloc[0]['date']} -> {latest['date']}",
        f"- **Latest close**: ¥{latest['close']:.2f}",
        f"- **Period change**: {change:+.2f}%",
        f"- **High/Low**: ¥{df['high'].max():.2f} / ¥{df['low'].min():.2f}",
        "\n### Recent trading overview",
        "```",
        df.tail(5)[['date', 'close', 'change_pct', 'volume']].to_string(index=False),
        "```"
    ]
    return "\n".join(report)

```
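
The ticker handling inside `get_stock_price` comes down to two small steps: strip everything except digits, then route 5-digit codes to the Hong Kong endpoint and all other codes to the A-share endpoint. The sketch below illustrates that routing and the date conversion on their own. `classify_ticker` and `to_akshare_date` are hypothetical helper names that do not appear in the source, and the example deliberately makes no akshare calls.

```python
from typing import Optional, Tuple


def classify_ticker(ticker: str) -> Optional[Tuple[str, str]]:
    """Return (numeric_code, market) or None when the ticker has no digits.

    Mirrors the routing rule in get_stock_price: a 5-digit code is treated
    as a Hong Kong stock, anything else as an A-share.
    """
    digits = "".join(ch for ch in ticker if ch.isdigit())
    if not digits:
        # e.g. "TSLA": purely alphabetic tickers are not supported by this data source
        return None
    market = "HK" if len(digits) == 5 else "A"
    return digits, market


def to_akshare_date(date_str: str) -> str:
    """Convert "YYYY-MM-DD" to the compact "YYYYMMDD" form used for the history requests."""
    return date_str.replace("-", "")


if __name__ == "__main__":
    for raw in ("600519.SH", "00700.HK", "TSLA"):
        print(raw, "->", classify_ticker(raw))
    print(to_akshare_date("2026-03-20"))
```

Because the rule keys on digit count alone, an input that mixes other digits into the string (for example a name with a year in it) would be misrouted, so callers may want to resolve the code with `search_ticker` first.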
