Coverage for stackone_ai / models.py: 97%

262 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-04-02 08:51 +0000

1from __future__ import annotations 

2 

3import base64 

4import json 

5import logging 

6from collections.abc import Sequence 

7from datetime import datetime, timezone 

8from enum import Enum 

9from typing import Annotated, Any, ClassVar, TypeAlias, cast 

10from urllib.parse import quote 

11 

12import httpx 

13from langchain_core.tools import BaseTool 

14from pydantic import BaseModel, BeforeValidator, Field, PrivateAttr 

15 

16# Type aliases for common types 

17JsonDict: TypeAlias = dict[str, Any] 

18Headers: TypeAlias = dict[str, str] 

19 

20 

21logger = logging.getLogger("stackone.tools") 

22 

23 

class StackOneError(Exception):
    """Root of the StackOne exception hierarchy.

    Adds no behaviour of its own; it exists so callers can catch every
    StackOne-specific failure with a single ``except`` clause.
    """

28 

29 

class StackOneAPIError(StackOneError):
    """Error raised for an unsuccessful response from the StackOne API.

    Attributes:
        status_code: HTTP status code supplied by the raiser.
        response_body: Response payload supplied by the raiser (any type).
    """

    def __init__(self, message: str, status_code: int, response_body: Any) -> None:
        # Only the message goes to Exception, so str(error) is the message alone;
        # the HTTP details are kept as attributes.
        super().__init__(message)
        self.status_code, self.response_body = status_code, response_body

37 

38 

class ParameterLocation(str, Enum):
    """Where a tool parameter is placed in the outgoing HTTP request.

    Members are also ``str`` instances, so they compare equal to their plain
    string values.
    """

    HEADER = "header"
    QUERY = "query"
    PATH = "path"
    BODY = "body"
    # Used for file uploads
    FILE = "file"

47 

48 

def validate_method(v: str) -> str:
    """Normalise an HTTP method name to upper case and check that it is supported.

    This function is used as a pydantic ``BeforeValidator``, so it can be handed
    raw, not-yet-validated input rather than a guaranteed ``str``.

    Args:
        v: HTTP method name in any letter case.

    Returns:
        The upper-cased method name.

    Raises:
        ValueError: If ``v`` is not a string, or is not one of GET, POST, PUT,
            DELETE or PATCH.
    """
    if not isinstance(v, str):
        # Fail with ValueError rather than the AttributeError ``.upper()`` would
        # raise, so the failure is reported as a validation error.
        raise ValueError(f"HTTP method must be a string, got {type(v).__name__}")

    method = v.upper()
    if method not in {"GET", "POST", "PUT", "DELETE", "PATCH"}:
        raise ValueError(f"Unsupported HTTP method: {method}")
    return method

55 

56 

class ExecuteConfig(BaseModel):
    """HTTP execution settings describing how a tool calls its API endpoint."""

    # Predefined headers for the request
    headers: Headers = Field(
        default_factory=dict,
        description="HTTP headers to include in the request",
    )
    # Passed through validate_method before pydantic's own type validation
    method: Annotated[str, BeforeValidator(validate_method)] = Field(
        description="HTTP method to use",
    )
    url: str = Field(description="API endpoint URL")
    name: str = Field(description="Tool name")
    body_type: str | None = Field(
        default=None,
        description="Content type for request body",
    )
    # Parameters without an entry here have no explicit location
    parameter_locations: dict[str, ParameterLocation] = Field(
        default_factory=dict,
        description="Maps parameter names to their location in the request",
    )

68 

69 

class ToolParameters(BaseModel):
    """JSON-Schema-style description of the arguments a tool accepts."""

    type: str = Field(
        description="JSON Schema type",
    )
    # Keyed by parameter name
    properties: JsonDict = Field(
        description="JSON Schema properties",
    )

75 

76 

class ToolDefinition(BaseModel):
    """Bundle of a tool's description, parameter schema and execution config."""

    description: str = Field(
        description="Tool description",
    )
    parameters: ToolParameters = Field(
        description="Tool parameter schema",
    )
    execute: ExecuteConfig = Field(
        description="Tool execution configuration",
    )

83 

84 

class StackOneTool(BaseModel):
    """A single StackOne API operation exposed as a callable tool.

    Holds the public schema (name, description, parameters) together with the
    private settings needed to perform the HTTP call, and can convert itself to
    the OpenAI function format or to a LangChain tool.
    """

    name: str = Field(description="Tool name")
    description: str = Field(description="Tool description")
    parameters: ToolParameters = Field(description="Tool parameters")
    _execute_config: ExecuteConfig = PrivateAttr()
    _api_key: str = PrivateAttr()
    _account_id: str | None = PrivateAttr(default=None)
    # Keys that _split_feedback_options moves out of the tool parameters.
    _FEEDBACK_OPTION_KEYS: ClassVar[set[str]] = {
        "feedback_session_id",
        "feedback_user_id",
        "feedback_metadata",
    }

    @property
    def connector(self) -> str:
        """Extract connector from tool name.

        Tool names follow the format: {connector}_{action}_{entity}
        e.g., 'bamboohr_create_employee' -> 'bamboohr'

        Returns:
            Connector name in lowercase
        """
        return self.name.split("_")[0].lower()

    def __init__(
        self,
        description: str,
        parameters: ToolParameters,
        _execute_config: ExecuteConfig,
        _api_key: str,
        _account_id: str | None = None,
    ) -> None:
        """Create a tool; its public name is taken from ``_execute_config.name``.

        Args:
            description: Tool description
            parameters: Tool parameter schema
            _execute_config: How to perform the HTTP call
            _api_key: API key used to build the Basic auth header
            _account_id: Optional account ID sent as the ``x-account-id`` header
        """
        super().__init__(
            name=_execute_config.name,
            description=description,
            parameters=parameters,
        )
        self._execute_config = _execute_config
        self._api_key = _api_key
        self._account_id = _account_id

    @classmethod
    def _split_feedback_options(cls, params: JsonDict, options: JsonDict | None) -> tuple[JsonDict, JsonDict]:
        """Separate feedback keys from tool parameters without mutating the inputs.

        Args:
            params: Tool parameters that may contain feedback keys
            options: Explicit options; a key already present here wins and the
                same key in ``params`` is left where it is

        Returns:
            Tuple of (params without the moved keys, merged feedback options)
        """
        merged_params = dict(params)
        feedback_options = dict(options or {})
        for key in cls._FEEDBACK_OPTION_KEYS:
            if key in merged_params and key not in feedback_options:
                feedback_options[key] = merged_params.pop(key)
        return merged_params, feedback_options

    def _prepare_headers(self) -> Headers:
        """Prepare headers for the API request

        Returns:
            Headers to use in the request
        """
        # HTTP Basic auth with the API key as user name and an empty password
        auth_string = base64.b64encode(f"{self._api_key}:".encode()).decode()
        headers: Headers = {
            "Authorization": f"Basic {auth_string}",
            "User-Agent": "stackone-python/1.0.0",
        }

        if self._account_id:
            headers["x-account-id"] = self._account_id

        # Predefined headers are applied last, so they override the ones above
        headers.update(self._execute_config.headers)
        return headers

    @staticmethod
    def _substitute_path_param(url: str, key: str, value: Any) -> str:
        """Replace the ``{key}`` placeholder in ``url`` with the fully percent-encoded value."""
        # safe="" also encodes "/", so a value cannot add extra path segments
        return url.replace(f"{{{key}}}", quote(str(value), safe=""))

    def _prepare_request_params(self, kwargs: JsonDict) -> tuple[str, JsonDict, JsonDict]:
        """Prepare URL and parameters for the API request

        Args:
            kwargs: Arguments to process

        Returns:
            Tuple of (url, body_params, query_params)
        """
        url = self._execute_config.url
        body_params: JsonDict = {}
        query_params: JsonDict = {}
        locations = self._execute_config.parameter_locations
        # Without an explicit placement, GET/DELETE arguments go to the query string
        default_to_query = self._execute_config.method in {"GET", "DELETE"}

        for key, value in kwargs.items():
            location = locations.get(key)

            if location == ParameterLocation.PATH:
                url = self._substitute_path_param(url, key, value)
            elif location == ParameterLocation.QUERY:
                query_params[key] = value
            elif location in (ParameterLocation.BODY, ParameterLocation.FILE):
                body_params[key] = value
            # Default behavior: no configured location, or one not handled above
            # (HEADER gets no special treatment here).
            elif f"{{{key}}}" in url:
                url = self._substitute_path_param(url, key, value)
            elif default_to_query:
                query_params[key] = value
            else:
                body_params[key] = value

        return url, body_params, query_params

    def execute(
        self, arguments: str | JsonDict | None = None, *, options: JsonDict | None = None
    ) -> JsonDict:
        """Execute the tool with the given parameters

        Args:
            arguments: Tool arguments as string or dict
            options: Execution options (e.g. feedback metadata). Accepted for
                interface compatibility; not used when building the request.

        Returns:
            API response as dict. A non-dict JSON response is wrapped as
            ``{"result": ...}``; a successful response with an empty body
            yields ``{}``.

        Raises:
            StackOneAPIError: If the API responds with an error status
            StackOneError: If the request could not be sent
            ValueError: If the arguments are invalid, or a successful response
                body is not valid JSON
        """
        # Arguments are parsed outside the HTTP try-block so that a JSON error in
        # the *response* can never be reported as an argument error.
        if isinstance(arguments, str):
            try:
                parsed_arguments = json.loads(arguments)
            except json.JSONDecodeError as exc:
                raise ValueError(f"Invalid JSON in arguments: {exc}") from exc
        else:
            parsed_arguments = arguments or {}

        if not isinstance(parsed_arguments, dict):
            raise ValueError("Tool arguments must be a JSON object")

        url, body_params, query_params = self._prepare_request_params(parsed_arguments)
        request_kwargs: dict[str, Any] = {
            "method": self._execute_config.method,
            "url": url,
            "headers": self._prepare_headers(),
        }

        if body_params:
            body_type = self._execute_config.body_type or "json"
            if body_type == "json":
                request_kwargs["json"] = body_params
            elif body_type == "form":
                request_kwargs["data"] = body_params
            else:
                # The body is still omitted, as before, but no longer silently.
                logger.warning(
                    "Unsupported body_type %r for tool %s; request body omitted",
                    body_type,
                    self.name,
                )

        if query_params:
            request_kwargs["params"] = query_params

        try:
            response = httpx.request(**request_kwargs)
            response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            # Prefer the decoded JSON error body; fall back to the raw text
            response_body = None
            if exc.response.text:
                try:
                    response_body = exc.response.json()
                except ValueError:
                    response_body = exc.response.text
            raise StackOneAPIError(
                str(exc),
                exc.response.status_code,
                response_body,
            ) from exc
        except httpx.RequestError as exc:
            raise StackOneError(f"Request failed: {exc}") from exc

        if not response.content:
            # e.g. 204 No Content: success with nothing to decode
            return {}

        try:
            result = response.json()
        except ValueError as exc:
            raise ValueError(f"Invalid JSON in response: {exc}") from exc

        return cast(JsonDict, result) if isinstance(result, dict) else {"result": result}

    def call(self, *args: Any, options: JsonDict | None = None, **kwargs: Any) -> JsonDict:
        """Call the tool with the given arguments

        This method provides a more intuitive way to execute tools directly.

        Args:
            *args: If a single argument is provided, it's treated as the full arguments dict/string
            **kwargs: Keyword arguments to pass to the tool
            options: Optional execution options, forwarded to :meth:`execute`

        Returns:
            API response as dict

        Raises:
            StackOneAPIError: If the API request fails
            ValueError: If the arguments are invalid

        Examples:
            >>> tool.call({"name": "John", "email": "john@example.com"})
            >>> tool.call(name="John", email="john@example.com")
        """
        if args and kwargs:
            raise ValueError("Cannot provide both positional and keyword arguments")

        if args:
            if len(args) > 1:
                raise ValueError("Only one positional argument is allowed")
            return self.execute(args[0], options=options)

        return self.execute(kwargs or None, options=options)

    def __call__(self, *args: Any, options: JsonDict | None = None, **kwargs: Any) -> JsonDict:
        """Make the tool directly callable.

        Alias for :meth:`call` so that ``tool(query="…")`` works.
        """
        return self.call(*args, options=options, **kwargs)

    def to_openai_function(self) -> JsonDict:
        """Convert this tool to OpenAI's function format

        Only the ``type``, ``description`` and ``enum`` keys of each property are
        kept (plus cleaned ``items`` / ``properties`` for arrays / objects).
        Every property that is not marked ``nullable`` is listed as required.

        Returns:
            Tool definition in OpenAI function format
        """
        schema_keys = ("type", "description", "enum")
        properties: JsonDict = {}
        required: list[str] = []

        for name, prop in self.parameters.properties.items():
            if not isinstance(prop, dict):
                # A property without a schema dict is exposed as a required string
                properties[name] = {"type": "string"}
                required.append(name)
                continue

            cleaned_prop: JsonDict = {key: prop[key] for key in schema_keys if key in prop}

            # Handle array types
            items = prop.get("items")
            if cleaned_prop.get("type") == "array" and isinstance(items, dict):
                cleaned_prop["items"] = {k: v for k, v in items.items() if k in schema_keys}

            # Handle object types; a nested entry that is not a dict falls back to
            # a plain string schema instead of raising AttributeError
            nested = prop.get("properties")
            if cleaned_prop.get("type") == "object" and isinstance(nested, dict):
                cleaned_prop["properties"] = {
                    k: (
                        {sk: sv for sk, sv in v.items() if sk in schema_keys}
                        if isinstance(v, dict)
                        else {"type": "string"}
                    )
                    for k, v in nested.items()
                }

            # Handle required fields - if not explicitly nullable
            if not prop.get("nullable", False):
                required.append(name)

            properties[name] = cleaned_prop

        parameters: JsonDict = {
            "type": "object",
            "properties": properties,
        }

        # Only include required if there are required fields
        if required:
            parameters["required"] = required

        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": parameters,
            },
        }

    def to_langchain(self) -> BaseTool:
        """Convert this tool to LangChain format

        Builds a pydantic args schema from the tool parameters, then returns an
        instance of a ``BaseTool`` subclass whose ``_run`` delegates to
        :meth:`execute`.

        Returns:
            Tool in LangChain format
        """
        # JSON Schema type name -> Python type; anything else is treated as str
        json_types: dict[str, type] = {
            "number": float,
            "integer": int,
            "boolean": bool,
            "object": dict,
            "array": list,
        }

        schema_props: dict[str, Any] = {}
        annotations: dict[str, Any] = {}

        for name, details in self.parameters.properties.items():
            python_type: type = str  # Default to str
            is_nullable = False
            field_description = ""
            if isinstance(details, dict):
                type_str = details.get("type", "string")
                # Guard the lookup: a non-string "type" value may be unhashable
                if isinstance(type_str, str):
                    python_type = json_types.get(type_str, str)
                is_nullable = details.get("nullable", False)
                field_description = details.get("description", "")

            if is_nullable:
                # Nullable parameters are optional and default to None
                schema_props[name] = Field(default=None, description=field_description)
                annotations[name] = python_type | None
            else:
                schema_props[name] = Field(description=field_description)
                annotations[name] = python_type

        # Create the schema class with proper annotations
        schema_class = type(
            f"{self.name.title()}Args",
            (BaseModel,),
            {
                "__annotations__": annotations,
                "__module__": __name__,
                **schema_props,
            },
        )

        parent_tool = self

        class StackOneLangChainTool(BaseTool):
            name: str = parent_tool.name
            description: str = parent_tool.description
            args_schema: type[BaseModel] = schema_class  # ty: ignore[invalid-assignment]
            func = staticmethod(parent_tool.execute)  # Required by CrewAI

            def _run(self, **kwargs: Any) -> Any:
                return parent_tool.execute(kwargs)

        return StackOneLangChainTool()

    def set_account_id(self, account_id: str | None) -> None:
        """Set the account ID for this tool

        Args:
            account_id: The account ID to use, or None to clear it
        """
        self._account_id = account_id

    def get_account_id(self) -> str | None:
        """Get the current account ID for this tool

        Returns:
            Current account ID or None if not set
        """
        return self._account_id

484 

485 

class Tools:
    """Ordered collection of StackOneTool objects with name lookup and bulk helpers."""

    def __init__(
        self,
        tools: list[StackOneTool],
    ) -> None:
        """Store the tools and index them by name.

        Args:
            tools: Tool instances to manage; the list is kept by reference
        """
        self.tools = tools
        # When names repeat, the tool that appears last keeps the entry
        name_index = {}
        for item in tools:
            name_index[item.name] = item
        self._tool_map = name_index

    def __getitem__(self, index: int) -> StackOneTool:
        """Return the tool stored at position ``index``."""
        return self.tools[index]

    def __len__(self) -> int:
        """Return how many tools the collection holds."""
        return len(self.tools)

    def __iter__(self) -> Any:
        """Iterate over the tools in their stored order."""
        return iter(self.tools)

    def to_list(self) -> list[StackOneTool]:
        """Return the tools as a new list.

        Returns:
            A shallow copy of the underlying tool list
        """
        return [*self.tools]

    def get_tool(self, name: str) -> StackOneTool | None:
        """Look a tool up by name.

        Args:
            name: Name of the wanted tool

        Returns:
            The matching tool, or None when no tool has that name
        """
        return self._tool_map.get(name)

    def set_account_id(self, account_id: str | None) -> None:
        """Apply one account ID to every tool in the collection.

        Args:
            account_id: The account ID to use, or None to clear it
        """
        for item in self.tools:
            item.set_account_id(account_id)

    def get_account_id(self) -> str | None:
        """Report the account ID in use by this collection.

        Returns:
            The first string account ID found among the tools, else None
        """
        # Lazy generator: stops asking tools as soon as one returns a string
        candidates = (item.get_account_id() for item in self.tools)
        return next((found for found in candidates if isinstance(found, str)), None)

    def get_connectors(self) -> set[str]:
        """Collect the distinct connector names of all tools.

        Returns:
            Set of connector names (lowercase)

        Example:
            tools = toolset.fetch_tools()
            connectors = tools.get_connectors()
            # {'bamboohr', 'hibob', 'slack', ...}
        """
        connectors: set[str] = set()
        for item in self.tools:
            connectors.add(item.connector)
        return connectors

    def to_openai(self) -> list[JsonDict]:
        """Convert every tool to the OpenAI function format.

        Returns:
            One OpenAI function definition per tool, in stored order
        """
        converted = []
        for item in self.tools:
            converted.append(item.to_openai_function())
        return converted

    def to_langchain(self) -> Sequence[BaseTool]:
        """Convert every tool to a LangChain tool.

        Returns:
            One LangChain tool per tool, in stored order
        """
        converted = []
        for item in self.tools:
            converted.append(item.to_langchain())
        return converted