Coverage for stackone_ai / models.py: 98%

276 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-05-01 15:10 +0000

1from __future__ import annotations 

2 

3import base64 

4import json 

5import logging 

6from collections.abc import Sequence 

7from datetime import datetime, timezone 

8from enum import Enum 

9from typing import TYPE_CHECKING, Annotated, Any, ClassVar, TypeAlias, cast 

10from urllib.parse import quote 

11 

12import httpx 

13from langchain_core.tools import BaseTool 

14from pydantic import BaseModel, BeforeValidator, Field, PrivateAttr 

15 

16if TYPE_CHECKING: 

17 from pydantic_ai.tools import Tool as PydanticAITool 

18 

19# Type aliases for common types 

20JsonDict: TypeAlias = dict[str, Any] 

21Headers: TypeAlias = dict[str, str] 

22 

23 

24logger = logging.getLogger("stackone.tools") 

25 

26 

class StackOneError(Exception):
    """Base exception for StackOne errors."""


class StackOneAPIError(StackOneError):
    """Raised when the StackOne API returns an error.

    Attributes:
        status_code: HTTP status code of the failed response.
        response_body: Whatever payload the raiser supplied (the tool executor
            passes the decoded JSON body, the raw response text, or ``None``).
    """

    def __init__(self, message: str, status_code: int, response_body: Any) -> None:
        super().__init__(message)
        self.status_code = status_code
        self.response_body = response_body

    def __reduce__(self) -> tuple[Any, ...]:
        # BaseException pickles itself as ``cls(*self.args)``; ``args`` only holds
        # the message, so unpickling would call __init__ without status_code /
        # response_body and fail. Re-supply all constructor arguments, and pass
        # __dict__ as state so any extra attributes survive the round trip.
        message = self.args[0] if self.args else ""
        return (type(self), (message, self.status_code, self.response_body), self.__dict__)

40 

41 

class ParameterLocation(str, Enum):
    """Where a tool argument is placed in the outgoing HTTP request.

    Members are ``str`` subclasses, so they compare equal to their raw values.
    """

    HEADER = "header"  # sent as an HTTP header
    QUERY = "query"  # appended to the query string
    PATH = "path"  # substituted into a ``{name}`` URL placeholder
    BODY = "body"  # part of the request body
    FILE = "file"  # For file uploads

50 

51 

# HTTP methods a tool may be configured with.
_SUPPORTED_HTTP_METHODS = frozenset({"GET", "POST", "PUT", "DELETE", "PATCH"})


def validate_method(v: Any) -> str:
    """Normalize an HTTP method to upper case and check that it is supported.

    This is used as a pydantic ``BeforeValidator``, so it receives the raw,
    unvalidated input. Non-string input is rejected with ``ValueError`` (instead
    of letting ``AttributeError`` escape from ``.upper()``) so pydantic reports
    it as an ordinary validation error.

    Args:
        v: The configured HTTP method, in any letter case.

    Returns:
        The upper-cased method name.

    Raises:
        ValueError: If ``v`` is not a string or is not a supported method.
    """
    if not isinstance(v, str):
        raise ValueError(f"HTTP method must be a string, got {type(v).__name__}")
    method = v.upper()
    if method not in _SUPPORTED_HTTP_METHODS:
        raise ValueError(f"Unsupported HTTP method: {method}")
    return method

58 

59 

class ExecuteConfig(BaseModel):
    """How a tool call is turned into an HTTP request.

    Holds the HTTP method, URL template, static headers, body encoding and the
    mapping that says where each argument goes in the request.
    """

    # Static headers added to every request made for this tool.
    headers: Headers = Field(
        default_factory=dict,
        description="HTTP headers to include in the request",
    )
    # Upper-cased and checked by ``validate_method`` before type validation.
    method: Annotated[str, BeforeValidator(validate_method)] = Field(
        description="HTTP method to use",
    )
    url: str = Field(
        description="API endpoint URL",
    )
    name: str = Field(
        description="Tool name",
    )
    body_type: str | None = Field(
        default=None,
        description="Content type for request body",
    )
    # Argument name -> location; unlisted arguments have their location inferred.
    parameter_locations: dict[str, ParameterLocation] = Field(
        default_factory=dict,
        description="Maps parameter names to their location in the request",
    )
    timeout: float = Field(
        default=60.0,
        description="Request timeout in seconds",
    )

72 

73 

class ToolParameters(BaseModel):
    """JSON-Schema description of the arguments a tool accepts."""

    # Top-level schema type of the arguments object.
    type: str = Field(
        description="JSON Schema type",
    )
    # Property name -> JSON-Schema fragment for that argument.
    properties: JsonDict = Field(
        description="JSON Schema properties",
    )

79 

80 

class ToolDefinition(BaseModel):
    """Everything needed to build a tool: its description, argument schema
    and HTTP execution configuration."""

    description: str = Field(
        description="Tool description",
    )
    parameters: ToolParameters = Field(
        description="Tool parameter schema",
    )
    execute: ExecuteConfig = Field(
        description="Tool execution configuration",
    )

87 

88 

class StackOneTool(BaseModel):
    """Base class for all StackOne tools.

    Provides functionality for executing API calls and converting to various
    formats (OpenAI, LangChain, Pydantic AI).
    """

    name: str = Field(description="Tool name")
    description: str = Field(description="Tool description")
    parameters: ToolParameters = Field(description="Tool parameters")
    _execute_config: ExecuteConfig = PrivateAttr()
    _api_key: str = PrivateAttr()
    _account_id: str | None = PrivateAttr(default=None)
    # Keys that carry feedback metadata rather than API arguments
    # (see ``_split_feedback_options``).
    _FEEDBACK_OPTION_KEYS: ClassVar[set[str]] = {
        "feedback_session_id",
        "feedback_user_id",
        "feedback_metadata",
    }

    @property
    def connector(self) -> str:
        """Extract connector from tool name.

        Tool names follow the format: {connector}_{action}_{entity}
        e.g., 'bamboohr_create_employee' -> 'bamboohr'

        Returns:
            Connector name in lowercase
        """
        return self.name.split("_")[0].lower()

    def __init__(
        self,
        description: str,
        parameters: ToolParameters,
        _execute_config: ExecuteConfig,
        _api_key: str,
        _account_id: str | None = None,
    ) -> None:
        """Create a tool from its description, schema and execution config.

        Args:
            description: Human-readable tool description.
            parameters: JSON-Schema description of the tool arguments.
            _execute_config: HTTP execution config; its ``name`` becomes the tool name.
            _api_key: API key used for HTTP Basic authentication.
            _account_id: Optional account ID sent as the ``x-account-id`` header.
        """
        super().__init__(
            name=_execute_config.name,
            description=description,
            parameters=parameters,
        )
        self._execute_config = _execute_config
        self._api_key = _api_key
        self._account_id = _account_id

    @classmethod
    def _split_feedback_options(cls, params: JsonDict, options: JsonDict | None) -> tuple[JsonDict, JsonDict]:
        """Move feedback keys out of ``params`` into a feedback-options dict.

        A feedback key is moved only when it is not already present in
        ``options``; otherwise it is left in the returned params unchanged.
        Neither input dict is mutated.

        Returns:
            Tuple of (remaining params, feedback options).
        """
        merged_params = dict(params)
        feedback_options = dict(options or {})
        for key in cls._FEEDBACK_OPTION_KEYS:
            if key in merged_params and key not in feedback_options:
                feedback_options[key] = merged_params.pop(key)
        return merged_params, feedback_options

    def _prepare_headers(self) -> Headers:
        """Prepare headers for the API request

        Returns:
            Headers to use in the request
        """
        # Basic auth with the API key as username and an empty password.
        credentials = base64.b64encode(f"{self._api_key}:".encode()).decode()
        headers: Headers = {
            "Authorization": f"Basic {credentials}",
            "User-Agent": "stackone-python/1.0.0",
        }

        if self._account_id:
            headers["x-account-id"] = self._account_id

        # Tool-specific headers are applied last, so they override the defaults above.
        headers.update(self._execute_config.headers)
        return headers

    @staticmethod
    def _substitute_path_param(url: str, placeholder: str, value: Any) -> str:
        """Replace ``placeholder`` in ``url`` with the percent-encoded ``value``."""
        # safe="" also encodes "/" so a value cannot alter the path structure
        # (prevents SSRF / path traversal through path parameters).
        return url.replace(placeholder, quote(str(value), safe=""))

    def _prepare_request_params(self, kwargs: JsonDict) -> tuple[str, JsonDict, JsonDict]:
        """Prepare URL and parameters for the API request

        Arguments with an explicit PATH/QUERY/BODY/FILE location are routed
        there. Any other argument has its location inferred: a matching
        ``{name}`` URL placeholder makes it a path parameter; otherwise it goes
        to the query string for GET/DELETE and to the body for other methods.

        Args:
            kwargs: Arguments to process

        Returns:
            Tuple of (url, body_params, query_params)
        """
        url = self._execute_config.url
        body_params: JsonDict = {}
        query_params: JsonDict = {}

        for key, value in kwargs.items():
            location = self._execute_config.parameter_locations.get(key)
            placeholder = f"{{{key}}}"

            if location == ParameterLocation.PATH:
                url = self._substitute_path_param(url, placeholder, value)
            elif location == ParameterLocation.QUERY:
                query_params[key] = value
            elif location in (ParameterLocation.BODY, ParameterLocation.FILE):
                body_params[key] = value
            # No explicit routing from here on. HEADER-located parameters also
            # land here and are not sent as headers.
            # NOTE(review): confirm HEADER parameters should not go into headers.
            elif placeholder in url:
                url = self._substitute_path_param(url, placeholder, value)
            elif self._execute_config.method in {"GET", "DELETE"}:
                query_params[key] = value
            else:
                body_params[key] = value

        return url, body_params, query_params

    @staticmethod
    def _parse_arguments(arguments: str | JsonDict | None) -> JsonDict:
        """Decode tool arguments into a dict.

        Raises:
            ValueError: If a string is not valid JSON or does not decode to an object.
        """
        if isinstance(arguments, str):
            try:
                parsed = json.loads(arguments)
            except json.JSONDecodeError as exc:
                raise ValueError(f"Invalid JSON in arguments: {exc}") from exc
        else:
            parsed = arguments or {}

        if not isinstance(parsed, dict):
            raise ValueError("Tool arguments must be a JSON object")
        return parsed

    @staticmethod
    def _error_body(response: httpx.Response) -> Any:
        """Return the decoded JSON body of an error response, its raw text, or None if empty."""
        if not response.text:
            return None
        try:
            return response.json()
        except json.JSONDecodeError:
            return response.text

    @staticmethod
    def _parse_response(response: httpx.Response) -> JsonDict:
        """Decode a successful response body into a dict.

        An empty body (e.g. 204 No Content) yields ``{}``. A non-object JSON
        value is wrapped as ``{"result": value}``.

        Raises:
            ValueError: If the body is non-empty but not valid JSON.
        """
        if not response.content:
            return {}
        try:
            result = response.json()
        except json.JSONDecodeError as exc:
            raise ValueError(f"Invalid JSON in response: {exc}") from exc
        return cast(JsonDict, result) if isinstance(result, dict) else {"result": result}

    def execute(
        self, arguments: str | JsonDict | None = None, *, options: JsonDict | None = None
    ) -> JsonDict:
        """Execute the tool with the given parameters

        Args:
            arguments: Tool arguments as a JSON string or dict
            options: Execution options (e.g. feedback metadata). Accepted for API
                compatibility; not currently used by this implementation.

        Returns:
            API response as dict

        Raises:
            StackOneAPIError: If the API responds with an error status
            StackOneError: If the request could not be sent or completed
            ValueError: If the arguments (or a successful response body) are not valid JSON
        """
        kwargs = self._parse_arguments(arguments)

        headers = self._prepare_headers()
        url, body_params, query_params = self._prepare_request_params(kwargs)

        request_kwargs: dict[str, Any] = {
            "method": self._execute_config.method,
            "url": url,
            "headers": headers,
        }

        if body_params:
            body_type = self._execute_config.body_type or "json"
            if body_type == "json":
                request_kwargs["json"] = body_params
            elif body_type == "form":
                request_kwargs["data"] = body_params
            else:
                # The body cannot be encoded for this type; make the drop visible.
                logger.warning(
                    "Unsupported body_type %r for tool %s; request body was not sent",
                    body_type,
                    self.name,
                )

        if query_params:
            request_kwargs["params"] = query_params

        try:
            response = httpx.request(**request_kwargs, timeout=self._execute_config.timeout)
            response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            raise StackOneAPIError(
                str(exc),
                exc.response.status_code,
                self._error_body(exc.response),
            ) from exc
        except httpx.RequestError as exc:
            raise StackOneError(f"Request failed: {exc}") from exc

        return self._parse_response(response)

    def call(self, *args: Any, options: JsonDict | None = None, **kwargs: Any) -> JsonDict:
        """Call the tool with the given arguments

        This method provides a more intuitive way to execute tools directly.

        Args:
            *args: If a single argument is provided, it's treated as the full arguments dict/string
            **kwargs: Keyword arguments to pass to the tool
            options: Optional execution options, forwarded to :meth:`execute`

        Returns:
            API response as dict

        Raises:
            StackOneAPIError: If the API request fails
            ValueError: If the arguments are invalid

        Examples:
            >>> tool.call({"name": "John", "email": "john@example.com"})
            >>> tool.call(name="John", email="john@example.com")
        """
        if args and kwargs:
            raise ValueError("Cannot provide both positional and keyword arguments")
        if len(args) > 1:
            raise ValueError("Only one positional argument is allowed")

        payload = args[0] if args else (kwargs or None)
        # Only pass ``options`` when given, so execute() overrides that do not
        # accept it keep working for plain calls.
        if options is None:
            return self.execute(payload)
        return self.execute(payload, options=options)

    def __call__(self, *args: Any, options: JsonDict | None = None, **kwargs: Any) -> JsonDict:
        """Make the tool directly callable.

        Alias for :meth:`call` so that ``tool(query="…")`` works.
        """
        return self.call(*args, options=options, **kwargs)

    @staticmethod
    def _simple_schema(schema: JsonDict) -> JsonDict:
        """Keep only the ``type``/``description``/``enum`` keys of a schema fragment."""
        return {k: v for k, v in schema.items() if k in ("type", "description", "enum")}

    def to_openai_function(self) -> JsonDict:
        """Convert this tool to OpenAI's function format

        Properties are marked required unless their schema sets ``nullable``.

        Returns:
            Tool definition in OpenAI function format
        """
        properties: JsonDict = {}
        required: list[str] = []

        for name, prop in self.parameters.properties.items():
            if not isinstance(prop, dict):
                properties[name] = {"type": "string"}
                required.append(name)
                continue

            # Only keep standard JSON Schema keys, in a fixed order.
            cleaned_prop: JsonDict = {
                key: prop[key] for key in ("type", "description", "enum") if key in prop
            }

            # Array types: keep a simplified item schema.
            if cleaned_prop.get("type") == "array" and isinstance(prop.get("items"), dict):
                cleaned_prop["items"] = self._simple_schema(prop["items"])

            # Object types: simplify each nested property schema. Non-dict
            # nested schemas fall back to string, like top-level properties.
            if cleaned_prop.get("type") == "object" and "properties" in prop:
                cleaned_prop["properties"] = {
                    key: self._simple_schema(sub) if isinstance(sub, dict) else {"type": "string"}
                    for key, sub in prop["properties"].items()
                }

            if not prop.get("nullable", False):
                required.append(name)

            properties[name] = cleaned_prop

        parameters: JsonDict = {
            "type": "object",
            "properties": properties,
        }
        # Only include required if there are required fields
        if required:
            parameters["required"] = required

        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": parameters,
            },
        }

    def to_langchain(self) -> BaseTool:
        """Convert this tool to LangChain format

        Returns:
            Tool in LangChain format
        """
        # JSON-Schema scalar/container types -> Python annotation; anything else is str.
        json_to_python: dict[str, type] = {
            "number": float,
            "integer": int,
            "boolean": bool,
            "object": dict,
            "array": list,
        }
        schema_props: dict[str, Any] = {}
        annotations: dict[str, Any] = {}

        for name, details in self.parameters.properties.items():
            python_type: type = str
            is_nullable = False
            if isinstance(details, dict):
                type_str = details.get("type", "string")
                # ``type`` may legally be a list in JSON Schema; only map plain strings.
                if isinstance(type_str, str):
                    python_type = json_to_python.get(type_str, str)
                is_nullable = details.get("nullable", False)
                description = details.get("description", "")
                if is_nullable:
                    field = Field(default=None, description=description)
                else:
                    field = Field(description=description)
            else:
                field = Field(description="")

            schema_props[name] = field
            annotations[name] = (python_type | None) if is_nullable else python_type

        # Build the args schema class dynamically from the collected annotations.
        schema_class = type(
            f"{self.name.title()}Args",
            (BaseModel,),
            {
                "__annotations__": annotations,
                "__module__": __name__,
                **schema_props,
            },
        )

        parent_tool = self

        class StackOneLangChainTool(BaseTool):
            name: str = parent_tool.name
            description: str = parent_tool.description
            args_schema: type[BaseModel] = schema_class  # ty: ignore[invalid-assignment]
            func = staticmethod(parent_tool.execute)  # Required by CrewAI

            def _run(self, **kwargs: Any) -> Any:
                return parent_tool.execute(kwargs)

        return StackOneLangChainTool()

    def to_pydantic_ai_tool(self) -> PydanticAITool:
        """Convert this tool to a Pydantic AI ``Tool``.

        Requires ``stackone-ai[pydantic-ai]`` (installs ``pydantic-ai-slim``).

        Returns:
            A ``pydantic_ai.tools.Tool`` ready to pass to ``Agent(tools=[...])``.

        Raises:
            ImportError: If ``pydantic-ai-slim`` is not installed.
        """
        try:
            from pydantic_ai.tools import Tool
        except ImportError as e:
            raise ImportError(
                "Install `pydantic-ai-slim` (or `stackone-ai[pydantic-ai]`) "
                "to use the Pydantic AI integration."
            ) from e

        # Reuse the OpenAI-cleaned schema as the tool's JSON schema.
        json_schema = self.to_openai_function()["function"]["parameters"]
        parent_tool = self

        def implementation(**kwargs: Any) -> Any:
            return parent_tool.execute(kwargs)

        return Tool.from_schema(
            function=implementation,
            name=self.name,
            description=self.description,
            json_schema=json_schema,
        )

    def set_account_id(self, account_id: str | None) -> None:
        """Set the account ID for this tool

        Args:
            account_id: The account ID to use, or None to clear it
        """
        self._account_id = account_id

    def get_account_id(self) -> str | None:
        """Get the current account ID for this tool

        Returns:
            Current account ID or None if not set
        """
        return self._account_id

518 

519 

class Tools:
    """Ordered collection of StackOneTool objects with lookup by tool name."""

    def __init__(
        self,
        tools: list[StackOneTool],
    ) -> None:
        """Wrap a list of tools.

        Args:
            tools: The tools to manage; the list is stored as given (not copied).
        """
        self.tools = tools
        # Name index for get_tool(); with duplicate names the later tool wins.
        by_name: dict[str, StackOneTool] = {}
        for item in tools:
            by_name[item.name] = item
        self._tool_map = by_name

    def __getitem__(self, index: int) -> StackOneTool:
        """Return the tool at ``index`` in the underlying list."""
        return self.tools[index]

    def __len__(self) -> int:
        """Return how many tools are in the collection."""
        return len(self.tools)

    def __iter__(self) -> Any:
        """Iterate over the tools in order."""
        return iter(self.tools)

    def to_list(self) -> list[StackOneTool]:
        """Return the tools as a new list.

        Returns:
            A shallow copy of the underlying tools as a list.
        """
        return [*self.tools]

    def get_tool(self, name: str) -> StackOneTool | None:
        """Look up a tool by name.

        Args:
            name: Name of the wanted tool.

        Returns:
            The matching tool, or None when no tool has that name.
        """
        return self._tool_map.get(name)

    def set_account_id(self, account_id: str | None) -> None:
        """Apply one account ID to every tool in the collection.

        Args:
            account_id: The account ID to use, or None to clear it.
        """
        for item in self.tools:
            item.set_account_id(account_id)

    def get_account_id(self) -> str | None:
        """Return the collection's account ID.

        Returns:
            The first string account ID found, scanning tools in order, or None.
        """
        candidates = (item.get_account_id() for item in self.tools)
        return next((found for found in candidates if isinstance(found, str)), None)

    def get_connectors(self) -> set[str]:
        """Collect the distinct connector names of all tools.

        Returns:
            Set of connector names (lowercase).

        Example:
            tools = toolset.fetch_tools()
            connectors = tools.get_connectors()
            # {'bamboohr', 'hibob', 'slack', ...}
        """
        return {item.connector for item in self.tools}

    def to_openai(self) -> list[JsonDict]:
        """Convert every tool to the OpenAI function format.

        Returns:
            One OpenAI function definition per tool, in order.
        """
        converted: list[JsonDict] = []
        for item in self.tools:
            converted.append(item.to_openai_function())
        return converted

    def to_langchain(self) -> Sequence[BaseTool]:
        """Convert every tool to a LangChain tool.

        Returns:
            LangChain tools, in the same order as the collection.
        """
        converted = []
        for item in self.tools:
            converted.append(item.to_langchain())
        return converted

    def to_pydantic_ai(self) -> list[PydanticAITool]:
        """Convert every tool to a Pydantic AI ``Tool``.

        Requires ``stackone-ai[pydantic-ai]`` (installs ``pydantic-ai-slim``).

        Returns:
            ``pydantic_ai.tools.Tool`` objects ready to pass to ``Agent(tools=[...])``.
        """
        converted = []
        for item in self.tools:
            converted.append(item.to_pydantic_ai_tool())
        return converted