Coverage for stackone_ai/models.py: 95%

252 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-12-24 09:48 +0000

1# TODO: Remove when Python 3.9 support is dropped 

2from __future__ import annotations 

3 

4import base64 

5import json 

6import logging 

7from collections.abc import Sequence 

8from datetime import datetime, timezone 

9from enum import Enum 

10from typing import Annotated, Any, ClassVar, cast 

11from urllib.parse import quote 

12 

13import httpx 

14from langchain_core.tools import BaseTool 

15from pydantic import BaseModel, BeforeValidator, Field, PrivateAttr 

16 

17# TODO: Remove when Python 3.9 support is dropped 

18from typing_extensions import TypeAlias 

19 

20# Type aliases for common types 

21JsonDict: TypeAlias = dict[str, Any] 

22Headers: TypeAlias = dict[str, str] 

23 

24 

25logger = logging.getLogger("stackone.tools") 

26 

27 

class StackOneError(Exception):
    """Root of the StackOne-specific exception hierarchy (``StackOneAPIError`` derives from it)."""

32 

33 

class StackOneAPIError(StackOneError):
    """Raised when a StackOne API call comes back with an error HTTP status.

    Attributes:
        status_code: HTTP status code of the failed response.
        response_body: Error payload as given by the caller (decoded JSON, raw text,
            or ``None``).
    """

    def __init__(self, message: str, status_code: int, response_body: Any) -> None:
        # Attribute assignment is independent of Exception initialisation, so order is free.
        self.status_code = status_code
        self.response_body = response_body
        super().__init__(message)

41 

42 

class ParameterLocation(str, Enum):
    """Declared location of a tool parameter within an HTTP request."""

    HEADER = "header"
    QUERY = "query"
    PATH = "path"  # substituted into a ``{name}`` placeholder in the URL
    BODY = "body"
    FILE = "file"  # file uploads; routed into the request body alongside BODY params

51 

52 

def validate_method(v: str) -> str:
    """Normalize an HTTP method name to upper case and check that it is supported.

    Used as a pydantic ``BeforeValidator``, so ``v`` is the raw, not-yet-validated input.

    Args:
        v: HTTP method name in any letter case (e.g. ``"get"``).

    Returns:
        The upper-cased method name.

    Raises:
        ValueError: If ``v`` is not a string or is not one of GET, POST, PUT, DELETE, PATCH.
    """
    # Before-validators receive raw input, which need not be a string. Raising ValueError
    # (instead of letting ``.upper()`` fail with AttributeError) lets pydantic report it
    # as an ordinary validation error.
    if not isinstance(v, str):
        raise ValueError(f"HTTP method must be a string, got {type(v).__name__}")
    method = v.upper()
    if method not in {"GET", "POST", "PUT", "DELETE", "PATCH"}:
        raise ValueError(f"Unsupported HTTP method: {method}")
    return method

59 

60 

class ExecuteConfig(BaseModel):
    """How a tool maps onto an HTTP request: endpoint, verb, headers and parameter routing."""

    # Static headers merged into every request built from this config.
    headers: Headers = Field(
        default_factory=dict,
        description="HTTP headers to include in the request",
    )
    # Upper-cased and checked against the supported verbs before field validation.
    method: Annotated[str, BeforeValidator(validate_method)] = Field(
        description="HTTP method to use",
    )
    # May contain ``{param}`` placeholders that are filled in from path parameters.
    url: str = Field(description="API endpoint URL")
    name: str = Field(description="Tool name")
    # ``None`` is treated as JSON when sending a body; ``"form"`` sends form-encoded data.
    body_type: str | None = Field(None, description="Content type for request body")
    # Parameters missing from this mapping are routed by a fallback rule at request time.
    parameter_locations: dict[str, ParameterLocation] = Field(
        default_factory=dict,
        description="Maps parameter names to their location in the request",
    )

72 

73 

class ToolParameters(BaseModel):
    """JSON-Schema description of the input object a tool accepts."""

    # Top-level JSON Schema ``type`` of the input, stored verbatim.
    type: str = Field(description="JSON Schema type")
    # Parameter name -> schema fragment for that parameter (normally a dict).
    properties: JsonDict = Field(description="JSON Schema properties")

79 

80 

class ToolDefinition(BaseModel):
    """A tool's full definition: description, input schema and HTTP execution config."""

    description: str = Field(description="Tool description")
    # Input schema presented to callers / LLMs.
    parameters: ToolParameters = Field(description="Tool parameter schema")
    # Describes the HTTP request made when the tool runs.
    execute: ExecuteConfig = Field(description="Tool execution configuration")

87 

88 

class StackOneTool(BaseModel):
    """Base class for all StackOne tools.

    Wraps a single API endpoint: executes HTTP calls against it and converts the tool
    definition into other formats (OpenAI function schema, LangChain ``BaseTool``).
    """

    name: str = Field(description="Tool name")
    description: str = Field(description="Tool description")
    parameters: ToolParameters = Field(description="Tool parameters")
    _execute_config: ExecuteConfig = PrivateAttr()
    _api_key: str = PrivateAttr()
    _account_id: str | None = PrivateAttr(default=None)
    # Argument keys that carry feedback options rather than API parameters; they are
    # removed from tool arguments before the request is built.
    _FEEDBACK_OPTION_KEYS: ClassVar[set[str]] = {
        "feedback_session_id",
        "feedback_user_id",
        "feedback_metadata",
    }

    def __init__(
        self,
        description: str,
        parameters: ToolParameters,
        _execute_config: ExecuteConfig,
        _api_key: str,
        _account_id: str | None = None,
    ) -> None:
        """Create a tool bound to an endpoint and credentials.

        Args:
            description: Human-readable tool description.
            parameters: JSON-Schema parameters of the tool.
            _execute_config: Request configuration; its ``name`` becomes the tool name.
            _api_key: StackOne API key, sent as the HTTP Basic auth username.
            _account_id: Optional account ID, sent as the ``x-account-id`` header.
        """
        super().__init__(
            name=_execute_config.name,
            description=description,
            parameters=parameters,
        )
        self._execute_config = _execute_config
        self._api_key = _api_key
        self._account_id = _account_id

    @classmethod
    def _split_feedback_options(cls, params: JsonDict, options: JsonDict | None) -> tuple[JsonDict, JsonDict]:
        """Separate feedback keys from API parameters.

        Args:
            params: Tool arguments, possibly containing feedback keys.
            options: Explicit execution options; they take precedence over feedback keys
                of the same name found in ``params``.

        Returns:
            Tuple of ``(api_params, feedback_options)``. Neither input is mutated.
        """
        merged_params = dict(params)
        feedback_options = dict(options or {})
        for key in cls._FEEDBACK_OPTION_KEYS:
            if key in merged_params:
                # Always remove the key so it is never forwarded to the API; an explicit
                # option of the same name wins over the argument value.
                value = merged_params.pop(key)
                feedback_options.setdefault(key, value)
        return merged_params, feedback_options

    def _prepare_headers(self) -> Headers:
        """Prepare headers for the API request.

        Uses HTTP Basic auth with the API key as username and an empty password, adds
        ``x-account-id`` when an account is set, then applies the config's static headers
        (which override the defaults on name clashes).

        Returns:
            Headers to use in the request
        """
        auth_string = base64.b64encode(f"{self._api_key}:".encode()).decode()
        headers: Headers = {
            "Authorization": f"Basic {auth_string}",
            "User-Agent": "stackone-python/1.0.0",
        }

        if self._account_id:
            headers["x-account-id"] = self._account_id

        # Add predefined headers
        headers.update(self._execute_config.headers)
        return headers

    def _prepare_request_params(self, kwargs: JsonDict) -> tuple[str, JsonDict, JsonDict]:
        """Route arguments into the URL path, the query string or the request body.

        Parameters with an explicit PATH/QUERY/BODY/FILE location go there. Any other
        parameter fills a matching ``{key}`` URL placeholder if present; otherwise it is
        sent as a query parameter for GET/DELETE and in the body for other methods.

        Args:
            kwargs: Arguments to process

        Returns:
            Tuple of (url, body_params, query_params)
        """
        url = self._execute_config.url
        locations = self._execute_config.parameter_locations
        body_params: JsonDict = {}
        query_params: JsonDict = {}

        for key, value in kwargs.items():
            param_location = locations.get(key)
            placeholder = f"{{{key}}}"

            # NOTE(review): HEADER-located parameters fall through to the default routing
            # below and are not sent as HTTP headers — confirm this is intended.
            if param_location == ParameterLocation.PATH:
                # Encode every character (safe="") so path values cannot inject extra
                # path segments or query strings (SSRF protection).
                url = url.replace(placeholder, quote(str(value), safe=""))
            elif param_location == ParameterLocation.QUERY:
                query_params[key] = value
            elif param_location in (ParameterLocation.BODY, ParameterLocation.FILE):
                body_params[key] = value
            elif placeholder in url:
                url = url.replace(placeholder, quote(str(value), safe=""))
            elif self._execute_config.method in {"GET", "DELETE"}:
                query_params[key] = value
            else:
                body_params[key] = value

        return url, body_params, query_params

    @staticmethod
    def _parse_arguments(arguments: str | JsonDict | None) -> JsonDict:
        """Decode tool arguments into a dict; ``None``/empty input means no arguments.

        Raises:
            ValueError: If a string is not valid JSON or the arguments are not a JSON object.
        """
        if isinstance(arguments, str):
            try:
                parsed = json.loads(arguments)
            except json.JSONDecodeError as exc:
                raise ValueError(f"Invalid JSON in arguments: {exc}") from exc
        else:
            parsed = arguments or {}

        if not isinstance(parsed, dict):
            raise ValueError("Tool arguments must be a JSON object")
        return cast(JsonDict, parsed)

    def _build_request_kwargs(self, url: str, body_params: JsonDict, query_params: JsonDict) -> dict[str, Any]:
        """Assemble the keyword arguments passed to ``httpx.request``."""
        request_kwargs: dict[str, Any] = {
            "method": self._execute_config.method,
            "url": url,
            "headers": self._prepare_headers(),
        }

        if body_params:
            body_type = self._execute_config.body_type or "json"
            if body_type == "json":
                request_kwargs["json"] = body_params
            elif body_type == "form":
                request_kwargs["data"] = body_params
            else:
                # Previously these parameters vanished without a trace; make that visible.
                # Only parameter names are logged, never their values.
                logger.warning(
                    "Unsupported body_type %r for tool %s; body parameters %s were not sent",
                    body_type,
                    self.name,
                    sorted(body_params),
                )

        if query_params:
            request_kwargs["params"] = query_params
        return request_kwargs

    @staticmethod
    def _parse_response(response: httpx.Response) -> JsonDict:
        """Decode a successful response body into a dict.

        Raises:
            StackOneError: If a non-empty body is not valid JSON.
        """
        # Empty bodies (e.g. 204 No Content after a DELETE) are not JSON; report them as an
        # empty result instead of a decode failure.
        if not response.content.strip():
            return {}
        try:
            result = response.json()
        except ValueError as exc:  # json.JSONDecodeError or UnicodeDecodeError
            raise StackOneError(f"Invalid JSON in response: {exc}") from exc
        return cast(JsonDict, result) if isinstance(result, dict) else {"result": result}

    @staticmethod
    def _error_response_body(response: httpx.Response) -> Any:
        """Best-effort decode of an error response: JSON if possible, else text, else None."""
        if not response.text:
            return None
        try:
            return response.json()
        except ValueError:
            return response.text

    def execute(
        self, arguments: str | JsonDict | None = None, *, options: JsonDict | None = None
    ) -> JsonDict:
        """Execute the tool with the given parameters

        Args:
            arguments: Tool arguments as a JSON string or dict; ``None`` means no arguments.
            options: Execution options (e.g. feedback metadata). Feedback keys
                (``feedback_session_id``, ``feedback_user_id``, ``feedback_metadata``)
                found in ``arguments`` are treated as options and never sent to the API.

        Returns:
            API response as a dict. Non-object JSON responses are wrapped as
            ``{"result": value}``; an empty response body yields ``{}``.

        Raises:
            StackOneAPIError: If the API responds with an error status
            StackOneError: If the request cannot be sent or the response is not valid JSON
            ValueError: If the arguments are invalid
        """
        started_at = datetime.now(timezone.utc)
        response_status: int | None = None
        status = "error"  # switched to "success" only once a result has been produced
        url_used = self._execute_config.url

        try:
            parsed_arguments = self._parse_arguments(arguments)
            # Feedback options are split off so they never reach the API; they are not
            # otherwise consumed here (implicit feedback reporting was removed).
            request_arguments, _ = self._split_feedback_options(parsed_arguments, options)

            url_used, body_params, query_params = self._prepare_request_params(request_arguments)
            request_kwargs = self._build_request_kwargs(url_used, body_params, query_params)

            response = httpx.request(**request_kwargs)
            response_status = response.status_code
            response.raise_for_status()

            result = self._parse_response(response)
            status = "success"
            return result
        except httpx.HTTPStatusError as exc:
            raise StackOneAPIError(
                str(exc),
                exc.response.status_code,
                self._error_response_body(exc.response),
            ) from exc
        except httpx.RequestError as exc:
            raise StackOneError(f"Request failed: {exc}") from exc
        finally:
            elapsed = (datetime.now(timezone.utc) - started_at).total_seconds()
            logger.debug(
                "%s %s -> status_code=%s status=%s (%.3fs)",
                self._execute_config.method,
                url_used,
                response_status,
                status,
                elapsed,
            )

    def call(self, *args: Any, options: JsonDict | None = None, **kwargs: Any) -> JsonDict:
        """Call the tool with the given arguments

        This method provides a more intuitive way to execute tools directly.

        Args:
            *args: If a single argument is provided, it's treated as the full arguments dict/string
            **kwargs: Keyword arguments to pass to the tool
            options: Optional execution options, forwarded to ``execute``

        Returns:
            API response as dict

        Raises:
            StackOneAPIError: If the API request fails
            ValueError: If the arguments are invalid

        Examples:
            >>> tool.call({"name": "John", "email": "john@example.com"})
            >>> tool.call(name="John", email="john@example.com")
        """
        if args and kwargs:
            raise ValueError("Cannot provide both positional and keyword arguments")
        if len(args) > 1:
            raise ValueError("Only one positional argument is allowed")

        arguments = args[0] if args else (kwargs or None)
        return self.execute(arguments, options=options)

    def to_openai_function(self) -> JsonDict:
        """Convert this tool to OpenAI's function format

        Only ``type``, ``description`` and ``enum`` are kept from each property (and from
        array ``items`` / object sub-properties). Every property is marked required unless
        it declares ``nullable: true``; non-dict property schemas become required strings.

        Returns:
            Tool definition in OpenAI function format
        """
        schema_keys = ("type", "description", "enum")
        properties: JsonDict = {}
        required: list[str] = []

        for name, prop in self.parameters.properties.items():
            if not isinstance(prop, dict):
                properties[name] = {"type": "string"}
                required.append(name)
                continue

            cleaned_prop: JsonDict = {key: prop[key] for key in schema_keys if key in prop}

            items = prop.get("items")
            if cleaned_prop.get("type") == "array" and isinstance(items, dict):
                cleaned_prop["items"] = {k: v for k, v in items.items() if k in schema_keys}

            if cleaned_prop.get("type") == "object" and "properties" in prop:
                # Non-dict sub-schemas fall back to a plain string, mirroring the top level.
                cleaned_prop["properties"] = {
                    k: (
                        {sk: sv for sk, sv in v.items() if sk in schema_keys}
                        if isinstance(v, dict)
                        else {"type": "string"}
                    )
                    for k, v in prop["properties"].items()
                }

            # Handle required fields - if not explicitly nullable
            if not prop.get("nullable", False):
                required.append(name)

            properties[name] = cleaned_prop

        schema: JsonDict = {
            "type": "object",
            "properties": properties,
        }
        # Only include required if there are required fields
        if required:
            schema["required"] = required

        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": schema,
            },
        }

    def to_langchain(self) -> BaseTool:
        """Convert this tool to LangChain format

        Builds a pydantic args schema from the JSON Schema properties. JSON types map to
        ``float``/``int``/``bool``/``list``/``dict``; anything else (including non-dict
        property schemas) becomes ``str``.

        Returns:
            Tool in LangChain format
        """
        json_to_python: dict[str, type] = {
            "number": float,
            "integer": int,
            "boolean": bool,
            # Without these, array/object arguments were typed as str and rejected by
            # the args schema when the model passed a list or dict.
            "array": list,
            "object": dict,
        }
        schema_props: dict[str, Any] = {}
        annotations: dict[str, Any] = {}

        for name, details in self.parameters.properties.items():
            python_type: type = str  # Default to str
            description: Any = ""
            if isinstance(details, dict):
                type_str = details.get("type", "string")
                # ``type`` may be a list (e.g. ["string", "null"]); keep str for those.
                if isinstance(type_str, str):
                    python_type = json_to_python.get(type_str, str)
                description = details.get("description", "")

            schema_props[name] = Field(description=description)
            annotations[name] = python_type

        # Create the schema class with proper annotations
        schema_class = type(
            f"{self.name.title()}Args",
            (BaseModel,),
            {
                "__annotations__": annotations,
                "__module__": __name__,
                **schema_props,
            },
        )

        parent_tool = self

        class StackOneLangChainTool(BaseTool):
            name: str = parent_tool.name
            description: str = parent_tool.description
            args_schema: type[BaseModel] = schema_class  # ty: ignore[invalid-assignment]
            func = staticmethod(parent_tool.execute)  # Required by CrewAI

            def _run(self, **kwargs: Any) -> Any:
                return parent_tool.execute(kwargs)

        return StackOneLangChainTool()

    def set_account_id(self, account_id: str | None) -> None:
        """Set the account ID for this tool

        Args:
            account_id: The account ID to use, or None to clear it
        """
        self._account_id = account_id

    def get_account_id(self) -> str | None:
        """Get the current account ID for this tool

        Returns:
            Current account ID or None if not set
        """
        return self._account_id

457 

458 

class Tools:
    """Ordered collection of ``StackOneTool`` objects with name-based lookup."""

    def __init__(self, tools: list[StackOneTool]) -> None:
        """Wrap ``tools`` (kept as the same list object) and index them by name.

        Args:
            tools: Tools to manage. If two tools share a name, ``get_tool`` returns
                the later one.
        """
        self.tools = tools
        self._tool_map = {entry.name: entry for entry in tools}

    def __getitem__(self, index: int) -> StackOneTool:
        return self.tools[index]

    def __len__(self) -> int:
        return len(self.tools)

    def __iter__(self) -> Any:
        """Iterate over the tools in their original order."""
        return iter(self.tools)

    def to_list(self) -> list[StackOneTool]:
        """Return a new list holding the managed tools.

        Returns:
            Shallow copy of the underlying list of StackOneTool instances
        """
        return [*self.tools]

    def get_tool(self, name: str) -> StackOneTool | None:
        """Look up a tool by name.

        Args:
            name: Name of the tool to retrieve

        Returns:
            The matching tool, or None when no tool has that name
        """
        return self._tool_map.get(name)

    def set_account_id(self, account_id: str | None) -> None:
        """Apply one account ID to every tool in the collection.

        Args:
            account_id: The account ID to use, or None to clear it
        """
        for entry in self.tools:
            entry.set_account_id(account_id)

    def get_account_id(self) -> str | None:
        """Report the collection's account ID.

        Returns:
            The first string account ID found (in tool order), or None if no tool has one
        """
        candidates = (entry.get_account_id() for entry in self.tools)
        return next((acct for acct in candidates if isinstance(acct, str)), None)

    def to_openai(self) -> list[JsonDict]:
        """Render every tool as an OpenAI function definition.

        Returns:
            List of tools in OpenAI function format
        """
        return [entry.to_openai_function() for entry in self.tools]

    def to_langchain(self) -> Sequence[BaseTool]:
        """Render every tool as a LangChain tool.

        Returns:
            Sequence of tools in LangChain format
        """
        return [entry.to_langchain() for entry in self.tools]

    def meta_tools(self, hybrid_alpha: float | None = None) -> Tools:
        """Build the meta tools used for tool discovery and execution.

        Meta tools enable dynamic tool discovery and execution based on natural language
        queries, using hybrid BM25 + TF-IDF search over this collection.

        Args:
            hybrid_alpha: Weight for BM25 in hybrid search (0-1). If not provided, uses
                ToolIndex.DEFAULT_HYBRID_ALPHA (0.2), which gives more weight to BM25 scoring
                and has been shown to provide better tool discovery accuracy
                (10.8% improvement in validation testing).

        Returns:
            Tools collection containing meta_search_tools and meta_execute_tool

        Note:
            This feature is in beta and may change in future versions
        """
        # Imported lazily, presumably to keep the dependency optional / avoid import cycles.
        from stackone_ai.meta_tools import (
            ToolIndex,
            create_meta_execute_tool,
            create_meta_search_tools,
        )

        search_index = ToolIndex(self.tools, hybrid_alpha=hybrid_alpha)
        search_tool = create_meta_search_tools(search_index)
        exec_tool = create_meta_execute_tool(self)
        return Tools([search_tool, exec_tool])