Coverage for stackone_ai / models.py: 98%

251 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-02-08 18:25 +0000

from __future__ import annotations

import base64
import json
import logging
from collections.abc import Iterator, Sequence
from datetime import datetime, timezone
from enum import Enum
from typing import Annotated, Any, ClassVar, TypeAlias, cast
from urllib.parse import quote

import httpx
from langchain_core.tools import BaseTool
from pydantic import BaseModel, BeforeValidator, Field, PrivateAttr

15 

16# Type aliases for common types 

17JsonDict: TypeAlias = dict[str, Any] 

18Headers: TypeAlias = dict[str, str] 

19 

20 

21logger = logging.getLogger("stackone.tools") 

22 

23 

class StackOneError(Exception):
    """Base class for exceptions raised by the StackOne client."""

28 

29 

class StackOneAPIError(StackOneError):
    """Signals that the StackOne API answered a request with an error status.

    Attributes:
        status_code: HTTP status code of the error response.
        response_body: Error payload as supplied by the raiser. Within this
            module it is the decoded JSON body, the raw response text, or None.
    """

    def __init__(self, message: str, status_code: int, response_body: Any) -> None:
        self.status_code = status_code
        self.response_body = response_body
        super().__init__(message)

37 

38 

class ParameterLocation(str, Enum):
    """Placement of a tool parameter within the outgoing HTTP request.

    Being a ``str`` mixin, each member compares equal to its lowercase value
    (e.g. ``ParameterLocation.PATH == "path"``).
    """

    HEADER = "header"
    QUERY = "query"  # sent in the URL query string
    PATH = "path"  # substituted into a ``{name}`` placeholder in the URL
    BODY = "body"  # sent in the request body
    FILE = "file"  # For file uploads

47 

48 

def validate_method(v: str) -> str:
    """Normalize and validate an HTTP method name.

    Used as a pydantic ``BeforeValidator`` on ``ExecuteConfig.method``, so it
    receives the raw input before pydantic's own ``str`` check has run.

    Args:
        v: HTTP method name in any letter case (e.g. ``"get"``).

    Returns:
        The method name upper-cased.

    Raises:
        ValueError: If ``v`` is not a string, or is not one of GET, POST, PUT,
            DELETE or PATCH.
    """
    # A before-validator can be handed a non-string. Raise ValueError, which
    # pydantic turns into a ValidationError, instead of letting ``.upper()``
    # escape as a raw AttributeError.
    if not isinstance(v, str):
        raise ValueError(f"HTTP method must be a string, got {type(v).__name__}")
    method = v.upper()
    if method not in {"GET", "POST", "PUT", "DELETE", "PATCH"}:
        raise ValueError(f"Unsupported HTTP method: {method}")
    return method

55 

56 

class ExecuteConfig(BaseModel):
    """How a tool invocation is turned into an HTTP request.

    Holds the HTTP method, the endpoint URL (which may contain ``{name}``
    placeholders), static headers, the body encoding, and per-parameter
    placement rules.
    """

    headers: Headers = Field(
        default_factory=dict,
        description="HTTP headers to include in the request",
    )
    # Upper-cased and checked against the supported verbs by validate_method.
    method: Annotated[str, BeforeValidator(validate_method)] = Field(
        ...,
        description="HTTP method to use",
    )
    url: str = Field(..., description="API endpoint URL")
    name: str = Field(..., description="Tool name")
    body_type: str | None = Field(
        default=None,
        description="Content type for request body",
    )
    parameter_locations: dict[str, ParameterLocation] = Field(
        default_factory=dict,
        description="Maps parameter names to their location in the request",
    )

68 

69 

class ToolParameters(BaseModel):
    """JSON Schema fragment describing the arguments a tool accepts."""

    # Top-level JSON Schema "type" keyword.
    type: str = Field(..., description="JSON Schema type")
    # Mapping of argument name -> per-argument JSON Schema.
    properties: JsonDict = Field(..., description="JSON Schema properties")

75 

76 

class ToolDefinition(BaseModel):
    """Full tool specification: description, argument schema and HTTP execution config."""

    description: str = Field(..., description="Tool description")
    parameters: ToolParameters = Field(..., description="Tool parameter schema")
    execute: ExecuteConfig = Field(..., description="Tool execution configuration")

83 

84 

class StackOneTool(BaseModel):
    """Base class for all StackOne tools.

    Provides functionality for executing API calls and converting the tool to
    other formats (OpenAI, LangChain).
    """

    name: str = Field(description="Tool name")
    description: str = Field(description="Tool description")
    parameters: ToolParameters = Field(description="Tool parameters")
    _execute_config: ExecuteConfig = PrivateAttr()
    _api_key: str = PrivateAttr()
    _account_id: str | None = PrivateAttr(default=None)
    # Argument keys that _split_feedback_options moves from the call
    # parameters into the feedback options.
    _FEEDBACK_OPTION_KEYS: ClassVar[set[str]] = {
        "feedback_session_id",
        "feedback_user_id",
        "feedback_metadata",
    }

    def __init__(
        self,
        description: str,
        parameters: ToolParameters,
        _execute_config: ExecuteConfig,
        _api_key: str,
        _account_id: str | None = None,
    ) -> None:
        """Create a tool bound to an execution config and credentials.

        Args:
            description: Human-readable tool description.
            parameters: JSON Schema describing the tool's arguments.
            _execute_config: How to build the HTTP request. Its ``name``
                becomes the tool name.
            _api_key: API key, sent as the Basic-auth username.
            _account_id: Optional account ID, sent as the ``x-account-id`` header.
        """
        super().__init__(
            name=_execute_config.name,
            description=description,
            parameters=parameters,
        )
        self._execute_config = _execute_config
        self._api_key = _api_key
        self._account_id = _account_id

    @classmethod
    def _split_feedback_options(cls, params: JsonDict, options: JsonDict | None) -> tuple[JsonDict, JsonDict]:
        """Move feedback keys out of ``params`` and into the options dict.

        Neither input is mutated. A feedback key already present in ``options``
        takes precedence, and the same key is then left in the returned params.

        Returns:
            Tuple of (params without the moved keys, merged feedback options).
        """
        merged_params = dict(params)
        feedback_options = dict(options or {})
        for key in cls._FEEDBACK_OPTION_KEYS:
            if key in merged_params and key not in feedback_options:
                feedback_options[key] = merged_params.pop(key)
        return merged_params, feedback_options

    def _prepare_headers(self) -> Headers:
        """Build the request headers.

        Returns:
            Basic auth (API key as username, empty password), User-Agent, the
            optional ``x-account-id``, then the config's static headers. The
            static headers are applied last, so they override the others.
        """
        auth_string = base64.b64encode(f"{self._api_key}:".encode()).decode()
        headers: Headers = {
            "Authorization": f"Basic {auth_string}",
            "User-Agent": "stackone-python/1.0.0",
        }

        if self._account_id:
            headers["x-account-id"] = self._account_id

        # Add predefined headers
        headers.update(self._execute_config.headers)
        return headers

    @staticmethod
    def _substitute_path_param(url: str, key: str, value: Any) -> str:
        """Replace every ``{key}`` placeholder in ``url`` with the percent-encoded value."""
        # safe="" also encodes "/", so a value cannot inject extra path
        # segments (guards against SSRF via path parameters).
        return url.replace(f"{{{key}}}", quote(str(value), safe=""))

    def _prepare_request_params(self, kwargs: JsonDict) -> tuple[str, JsonDict, JsonDict]:
        """Distribute arguments between the URL path, the body and the query string.

        Parameters with an explicit PATH/QUERY/BODY/FILE location go there
        (FILE is sent in the body). Any other parameter fills a matching
        ``{key}`` placeholder in the URL if one exists. Otherwise it goes to
        the query string for GET/DELETE and to the body for other methods.

        Args:
            kwargs: Arguments to process.

        Returns:
            Tuple of (url, body_params, query_params).
        """
        url = self._execute_config.url
        locations = self._execute_config.parameter_locations
        body_params: JsonDict = {}
        query_params: JsonDict = {}

        for key, value in kwargs.items():
            location = locations.get(key)
            # NOTE(review): HEADER-located parameters are not sent as headers;
            # they fall through to the default placement below.
            if location == ParameterLocation.PATH:
                url = self._substitute_path_param(url, key, value)
            elif location == ParameterLocation.QUERY:
                query_params[key] = value
            elif location in (ParameterLocation.BODY, ParameterLocation.FILE):
                body_params[key] = value
            elif f"{{{key}}}" in url:
                url = self._substitute_path_param(url, key, value)
            elif self._execute_config.method in {"GET", "DELETE"}:
                query_params[key] = value
            else:
                body_params[key] = value

        return url, body_params, query_params

    @staticmethod
    def _parse_arguments(arguments: str | JsonDict | None) -> JsonDict:
        """Decode tool arguments into a dict.

        Raises:
            ValueError: If a string is not valid JSON, or the result is not a
                JSON object.
        """
        if isinstance(arguments, str):
            try:
                parsed = json.loads(arguments)
            except json.JSONDecodeError as exc:
                raise ValueError(f"Invalid JSON in arguments: {exc}") from exc
        else:
            parsed = arguments or {}

        if not isinstance(parsed, dict):
            raise ValueError("Tool arguments must be a JSON object")
        return parsed

    @staticmethod
    def _error_response_body(response: httpx.Response) -> Any:
        """Return an error response's body: decoded JSON, the raw text if it isn't JSON, or None if empty."""
        if not response.text:
            return None
        try:
            return response.json()
        except json.JSONDecodeError:
            return response.text

    @staticmethod
    def _decode_response(response: httpx.Response) -> JsonDict:
        """Decode a successful response body into a dict.

        An empty body (e.g. 204 No Content) yields ``{}``. A non-dict JSON
        value is wrapped as ``{"result": value}``.

        Raises:
            StackOneError: If a non-empty body is not valid JSON.
        """
        if not response.content:
            return {}
        try:
            result = response.json()
        except ValueError as exc:  # JSONDecodeError and UnicodeDecodeError
            raise StackOneError(f"Invalid JSON in response: {exc}") from exc
        return cast(JsonDict, result) if isinstance(result, dict) else {"result": result}

    def execute(
        self, arguments: str | JsonDict | None = None, *, options: JsonDict | None = None
    ) -> JsonDict:
        """Execute the tool with the given parameters.

        Args:
            arguments: Tool arguments, as a JSON object string or a dict.
                ``None`` (or any falsy non-string) means no arguments.
            options: Execution options (e.g. feedback metadata). Accepted for
                API compatibility; this method does not currently read it.

        Returns:
            The decoded JSON response as a dict. Non-dict JSON is wrapped as
            ``{"result": value}``; an empty response body yields ``{}``.

        Raises:
            ValueError: If the arguments are invalid (bad JSON or not an object).
            StackOneAPIError: If the API responds with an HTTP error status.
            StackOneError: If the request fails at the transport level, or a
                successful response body is not valid JSON.
        """
        kwargs = self._parse_arguments(arguments)
        # NOTE(review): feedback_* keys (see _FEEDBACK_OPTION_KEYS) are not
        # stripped here; they are forwarded to the API like other arguments.

        headers = self._prepare_headers()
        url, body_params, query_params = self._prepare_request_params(kwargs)

        request_kwargs: dict[str, Any] = {
            "method": self._execute_config.method,
            "url": url,
            "headers": headers,
        }

        if body_params:
            body_type = self._execute_config.body_type or "json"
            if body_type == "json":
                request_kwargs["json"] = body_params
            elif body_type == "form":
                request_kwargs["data"] = body_params
            else:
                # Previously dropped silently; make the lost body visible.
                logger.warning(
                    "Unsupported body_type %r for tool %s; request body was not sent",
                    body_type,
                    self.name,
                )

        if query_params:
            request_kwargs["params"] = query_params

        try:
            response = httpx.request(**request_kwargs)
            response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            raise StackOneAPIError(
                str(exc),
                exc.response.status_code,
                self._error_response_body(exc.response),
            ) from exc
        except httpx.RequestError as exc:
            raise StackOneError(f"Request failed: {exc}") from exc

        # Decoded outside the try above, so a bad response body is never
        # misreported as "Invalid JSON in arguments".
        return self._decode_response(response)

    def call(self, *args: Any, options: JsonDict | None = None, **kwargs: Any) -> JsonDict:
        """Call the tool with the given arguments.

        This method provides a more intuitive way to execute tools directly.

        Args:
            *args: If a single argument is provided, it's treated as the full
                arguments dict/string.
            **kwargs: Keyword arguments to pass to the tool.
            options: Optional execution options, forwarded to ``execute``.

        Returns:
            API response as dict.

        Raises:
            StackOneAPIError: If the API request fails.
            ValueError: If the arguments are invalid, or both positional and
                keyword arguments (or more than one positional) are given.

        Examples:
            >>> tool.call({"name": "John", "email": "john@example.com"})
            >>> tool.call(name="John", email="john@example.com")
        """
        if args and kwargs:
            raise ValueError("Cannot provide both positional and keyword arguments")
        if len(args) > 1:
            raise ValueError("Only one positional argument is allowed")
        if args:
            return self.execute(args[0], options=options)
        return self.execute(kwargs or None, options=options)

    def to_openai_function(self) -> JsonDict:
        """Convert this tool to OpenAI's function-calling format.

        Each property is reduced to its ``type``, ``description`` and ``enum``
        keys. The same filter is applied one level down to array ``items`` and
        object ``properties``; non-dict nested schemas are omitted. A property
        is listed in ``required`` unless its ``nullable`` entry is truthy.
        Non-dict property specs become required strings.

        Returns:
            Tool definition in OpenAI function format.
        """
        allowed = ("type", "description", "enum")
        properties: JsonDict = {}
        required: list[str] = []

        for name, prop in self.parameters.properties.items():
            if not isinstance(prop, dict):
                properties[name] = {"type": "string"}
                required.append(name)
                continue

            # Only keep standard JSON Schema keys, in a fixed order.
            cleaned_prop: JsonDict = {key: prop[key] for key in allowed if key in prop}

            # Handle array types
            if cleaned_prop.get("type") == "array" and isinstance(prop.get("items"), dict):
                cleaned_prop["items"] = {k: v for k, v in prop["items"].items() if k in allowed}

            # Handle object types (skip malformed, non-dict sub-schemas)
            nested = prop.get("properties")
            if cleaned_prop.get("type") == "object" and isinstance(nested, dict):
                cleaned_prop["properties"] = {
                    key: {k: v for k, v in sub.items() if k in allowed}
                    for key, sub in nested.items()
                    if isinstance(sub, dict)
                }

            if not prop.get("nullable", False):
                required.append(name)

            properties[name] = cleaned_prop

        parameters: JsonDict = {
            "type": "object",
            "properties": properties,
        }
        # Only include required if there are required fields
        if required:
            parameters["required"] = required

        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": parameters,
            },
        }

    def to_langchain(self) -> BaseTool:
        """Convert this tool to a LangChain ``BaseTool``.

        Builds a pydantic args schema from the parameter properties. JSON
        Schema ``number``, ``integer``, ``boolean``, ``array`` and ``object``
        map to ``float``, ``int``, ``bool``, ``list`` and ``dict``. Any other,
        missing or non-string type maps to ``str``.

        Returns:
            Tool in LangChain format whose ``_run`` delegates to ``execute``.
        """
        json_to_python: dict[str, type] = {
            "number": float,
            "integer": int,
            "boolean": bool,
            "array": list,
            "object": dict,
        }
        schema_props: dict[str, Any] = {}
        annotations: dict[str, Any] = {}

        for name, details in self.parameters.properties.items():
            python_type: type = str
            description = ""
            if isinstance(details, dict):
                type_str = details.get("type", "string")
                # JSON Schema also allows a list of types; those fall back to str.
                if isinstance(type_str, str):
                    python_type = json_to_python.get(type_str, str)
                description = details.get("description", "")
            schema_props[name] = Field(description=description)
            annotations[name] = python_type

        # Create the schema class with proper annotations
        schema_class = type(
            f"{self.name.title()}Args",
            (BaseModel,),
            {
                "__annotations__": annotations,
                "__module__": __name__,
                **schema_props,
            },
        )

        parent_tool = self

        class StackOneLangChainTool(BaseTool):
            name: str = parent_tool.name
            description: str = parent_tool.description
            args_schema: type[BaseModel] = schema_class  # ty: ignore[invalid-assignment]
            func = staticmethod(parent_tool.execute)  # Required by CrewAI

            def _run(self, **kwargs: Any) -> Any:
                return parent_tool.execute(kwargs)

        return StackOneLangChainTool()

    def set_account_id(self, account_id: str | None) -> None:
        """Set the account ID for this tool.

        Args:
            account_id: The account ID to use, or None to clear it.
        """
        self._account_id = account_id

    def get_account_id(self) -> str | None:
        """Get the current account ID for this tool.

        Returns:
            Current account ID, or None if not set.
        """
        return self._account_id

453 

454 

class Tools:
    """Container for StackOneTool instances with lookup by name."""

    def __init__(self, tools: list[StackOneTool]) -> None:
        """Initialize the Tools container.

        Args:
            tools: Tool instances to manage. The list is copied, so later
                changes to the caller's list cannot leave ``tools`` out of sync
                with the name index used by ``get_tool``.
        """
        self.tools = list(tools)
        # Name -> tool index; if names repeat, the last tool wins.
        self._tool_map = {tool.name: tool for tool in self.tools}

    def __getitem__(self, index: int) -> StackOneTool:
        """Return the tool at ``index``."""
        return self.tools[index]

    def __len__(self) -> int:
        """Return the number of tools."""
        return len(self.tools)

    def __iter__(self) -> Iterator[StackOneTool]:
        """Iterate over the tools in their stored order."""
        return iter(self.tools)

    def to_list(self) -> list[StackOneTool]:
        """Convert to a list of tools.

        Returns:
            A new list of the StackOneTool instances.
        """
        return list(self.tools)

    def get_tool(self, name: str) -> StackOneTool | None:
        """Get a tool by its name.

        Args:
            name: Name of the tool to retrieve.

        Returns:
            The tool if found, None otherwise.
        """
        return self._tool_map.get(name)

    def set_account_id(self, account_id: str | None) -> None:
        """Set the account ID for all tools in this collection.

        Args:
            account_id: The account ID to use, or None to clear it.
        """
        for tool in self.tools:
            tool.set_account_id(account_id)

    def get_account_id(self) -> str | None:
        """Get the current account ID for this collection.

        Returns:
            The first string account ID found on a tool, or None if none is set.
        """
        for tool in self.tools:
            account_id = tool.get_account_id()
            if isinstance(account_id, str):
                return account_id
        return None

    def to_openai(self) -> list[JsonDict]:
        """Convert all tools to OpenAI function format.

        Returns:
            List of tools in OpenAI function format.
        """
        return [tool.to_openai_function() for tool in self.tools]

    def to_langchain(self) -> Sequence[BaseTool]:
        """Convert all tools to LangChain format.

        Returns:
            Sequence of tools in LangChain format.
        """
        return [tool.to_langchain() for tool in self.tools]

    def utility_tools(self, hybrid_alpha: float | None = None) -> Tools:
        """Return utility tools for tool discovery and execution.

        Utility tools enable dynamic tool discovery and execution based on
        natural language queries using hybrid BM25 + TF-IDF search.

        Args:
            hybrid_alpha: Weight for BM25 in hybrid search (0-1). If not
                provided, uses ToolIndex.DEFAULT_HYBRID_ALPHA (0.2), which gives
                more weight to BM25 scoring and has been shown to provide better
                tool discovery accuracy (10.8% improvement in validation testing).

        Returns:
            Tools collection containing tool_search and tool_execute.

        Note:
            This feature is in beta and may change in future versions.
        """
        # Imported lazily; presumably keeps the search dependencies optional or
        # avoids an import cycle - confirm before moving to module level.
        from stackone_ai.utility_tools import (
            ToolIndex,
            create_tool_execute,
            create_tool_search,
        )

        # Create search index with hybrid search
        index = ToolIndex(self.tools, hybrid_alpha=hybrid_alpha)

        # Create utility tools
        filter_tool = create_tool_search(index)
        execute_tool = create_tool_execute(self)

        return Tools([filter_tool, execute_tool])