Coverage for stackone_ai / local_search.py: 92%

72 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-04-02 08:51 +0000

1"""Local BM25 + TF-IDF hybrid keyword search for tool discovery.""" 

2 

3from __future__ import annotations 

4 

5import bm25s 

6import numpy as np 

7from pydantic import BaseModel 

8 

9from stackone_ai.constants import DEFAULT_HYBRID_ALPHA 

10from stackone_ai.models import StackOneTool 

11from stackone_ai.utils.tfidf_index import TfidfDocument, TfidfIndex 

12 

13 

class ToolSearchResult(BaseModel):
    """One ranked hit produced by a tool search."""

    # Name of the matched tool.
    name: str
    # Description text of the matched tool.
    description: str
    # Relevance score assigned to this match by the search.
    score: float

20 

21 

class ToolIndex:
    """Hybrid BM25 + TF-IDF tool search index.

    Every tool is indexed twice: once in a ``bm25s.BM25`` retriever and once in
    a ``TfidfIndex``. ``search`` queries both indexes and blends the two scores
    with ``hybrid_alpha``.
    """

    # Verbs recognised inside tool-name segments; matching segments are indexed
    # as additional action keywords. Defined once instead of per tool.
    _ACTION_TYPES = frozenset({"create", "update", "delete", "get", "list", "search"})

    def __init__(self, tools: list[StackOneTool], hybrid_alpha: float | None = None) -> None:
        """Initialize tool index with hybrid search.

        Args:
            tools: List of tools to index.
            hybrid_alpha: Weight of the BM25 score in the fused score; the
                TF-IDF score receives ``1 - hybrid_alpha``. Values outside
                [0, 1] are clamped. If not provided, DEFAULT_HYBRID_ALPHA is
                used (documented as 0.2, i.e. 20% BM25 / 80% TF-IDF, so TF-IDF
                dominates); that default is reported to improve tool discovery
                accuracy by 10.8% in validation testing.
        """
        self.tools = tools
        self.tool_map = {tool.name: tool for tool in tools}
        # Use default if not provided, then clamp to [0, 1]
        alpha = hybrid_alpha if hybrid_alpha is not None else DEFAULT_HYBRID_ALPHA
        self.hybrid_alpha = max(0.0, min(1.0, alpha))

        # Parallel lists: position i of `corpus` / `self.tool_names` refers to
        # the same tool, so BM25 document indices can be mapped back to names.
        corpus = []
        tfidf_docs = []
        self.tool_names = []

        for tool in tools:
            # The first underscore-separated segment of the name is treated as
            # the category; segments that are known verbs are the actions.
            parts = tool.name.split("_")
            category = parts[0] if parts else ""
            actions = [p for p in parts if p in self._ACTION_TYPES]

            parts_text = " ".join(parts)
            actions_text = " ".join(actions)

            # For TF-IDF: weighted approach similar to Node.js - the name is
            # repeated three times to boost its term frequency.
            tfidf_text = " ".join(
                [
                    f"{tool.name} {tool.name} {tool.name}",  # boost name
                    f"{category} {actions_text}",
                    tool.description,
                    parts_text,
                ]
            )

            # For BM25: simpler approach, each field appears once.
            bm25_text = " ".join(
                [
                    tool.name,
                    tool.description,
                    category,
                    parts_text,
                    actions_text,
                ]
            )

            corpus.append(bm25_text)
            tfidf_docs.append(TfidfDocument(id=tool.name, text=tfidf_text))
            self.tool_names.append(tool.name)

        # Create BM25 index (left un-indexed when there are no tools)
        self.bm25_retriever = bm25s.BM25()
        if corpus:
            corpus_tokens = bm25s.tokenize(corpus, stemmer=None, show_progress=False)  # ty: ignore[invalid-argument-type]
            self.bm25_retriever.index(corpus_tokens)

        # Create TF-IDF index (left un-built when there are no tools)
        self.tfidf_index = TfidfIndex()
        if tfidf_docs:
            self.tfidf_index.build(tfidf_docs)

    def search(self, query: str, limit: int = 5, min_score: float = 0.0) -> list[ToolSearchResult]:
        """Search for relevant tools using hybrid BM25 + TF-IDF.

        Args:
            query: Natural language query.
            limit: Maximum number of results. A value of zero or less yields
                an empty list.
            min_score: Minimum fused relevance score a result must reach.

        Returns:
            List of search results sorted by fused score, highest first.
        """
        # Nothing can be returned for a non-positive limit or an empty index.
        # Without the limit guard the first hit would be appended before the
        # limit check ran, so limit=0 used to return one result.
        if limit <= 0 or not self.tools:
            return []

        # Get more results initially to have better candidate pool for fusion,
        # but never ask for more documents than are indexed.
        fetch_limit = max(50, limit)
        k = min(fetch_limit, len(self.tools))

        # Tokenize query for BM25
        query_tokens = bm25s.tokenize([query], stemmer=None, show_progress=False)  # ty: ignore[invalid-argument-type]

        # Search with BM25 and with TF-IDF
        bm25_results, bm25_scores = self.bm25_retriever.retrieve(query_tokens, k=k)
        tfidf_results = self.tfidf_index.search(query, k=k)

        # Build score map for fusion: tool name -> {"bm25": ..., "tfidf": ...}
        score_map: dict[str, dict[str, float]] = {}

        # Add BM25 scores (row 0 holds the results for the single query)
        for idx, score in zip(bm25_results[0], bm25_scores[0], strict=True):
            tool_name = self.tool_names[idx]
            # Squash the raw BM25 score with a logistic curve, then clamp to [0, 1]
            normalized_score = float(1 / (1 + np.exp(-score / 10)))
            score_map[tool_name] = {"bm25": max(0.0, min(1.0, normalized_score))}

        # Add TF-IDF scores, creating an entry for tools BM25 did not return
        for result in tfidf_results:
            score_map.setdefault(result.id, {})["tfidf"] = result.score

        # Fuse scores: hybrid_score = alpha * bm25 + (1 - alpha) * tfidf,
        # with a missing component counted as 0.0
        fused_results: list[tuple[str, float]] = []
        for tool_name, scores in score_map.items():
            bm25_score = scores.get("bm25", 0.0)
            tfidf_score = scores.get("tfidf", 0.0)
            hybrid_score = self.hybrid_alpha * bm25_score + (1 - self.hybrid_alpha) * tfidf_score
            fused_results.append((tool_name, hybrid_score))

        # Sort by score descending
        fused_results.sort(key=lambda x: x[1], reverse=True)

        # Build final results, dropping low scores and unknown names
        search_results = []
        for tool_name, score in fused_results:
            if score < min_score:
                continue

            tool = self.tool_map.get(tool_name)
            if tool is None:
                continue

            search_results.append(
                ToolSearchResult(
                    name=tool.name,
                    description=tool.description,
                    score=score,
                )
            )

            if len(search_results) >= limit:
                break

        return search_results