Coverage for stackone_ai / toolset.py: 100%

179 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-02-08 18:25 +0000

1from __future__ import annotations 

2 

3import asyncio 

4import base64 

5import fnmatch 

6import json 

7import os 

8import threading 

9from collections.abc import Coroutine 

10from dataclasses import dataclass 

11from importlib import metadata 

12from typing import Any, TypeVar 

13 

14from stackone_ai.models import ( 

15 ExecuteConfig, 

16 ParameterLocation, 

17 StackOneTool, 

18 ToolParameters, 

19 Tools, 

20) 

21 

# Version of the installed distribution; used in the User-Agent and MCP client info.
try:
    _SDK_VERSION = metadata.version("stackone-ai")
except metadata.PackageNotFoundError:  # pragma: no cover - best-effort fallback when running from source
    _SDK_VERSION = "dev"

DEFAULT_BASE_URL = "https://api.stackone.com"

# All top-level RPC arguments travel in the JSON body of the /actions/rpc request.
_RPC_PARAMETER_LOCATIONS = dict.fromkeys(
    ("action", "body", "headers", "path", "query"),
    ParameterLocation.BODY,
)

_USER_AGENT = "stackone-ai-python/" + _SDK_VERSION

# Generic result type for _run_async.
T = TypeVar("T")

38 

39 

@dataclass(init=True, repr=True, eq=True)
class _McpToolDefinition:
    """Plain record describing one tool returned by the MCP tool listing."""

    # Tool identifier as reported by the server.
    name: str
    # Human-readable description; may be absent.
    description: str | None
    # JSON Schema for the tool's input arguments.
    input_schema: dict[str, Any]

45 

46 

class ToolsetError(Exception):
    """Root of the toolset exception hierarchy; catch it to handle any toolset failure."""

51 

52 

class ToolsetConfigError(ToolsetError):
    """Signals an invalid or incomplete toolset configuration (e.g. a missing API key)."""

57 

58 

class ToolsetLoadError(ToolsetError):
    """Signals a failure while loading tools from the StackOne API."""

63 

64 

def _run_async(awaitable: Coroutine[Any, Any, T]) -> T:
    """Drive *awaitable* to completion and return its result synchronously.

    If no event loop is running in the current thread, the coroutine is passed
    straight to ``asyncio.run``. If a loop *is* running (where ``asyncio.run``
    cannot be called), the coroutine is run by ``asyncio.run`` on a short-lived
    daemon thread. The caller blocks until that thread finishes, and any
    exception the coroutine raised is re-raised here.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop in this thread: asyncio.run can be used directly.
        return asyncio.run(awaitable)

    # The worker appends exactly one (succeeded, payload) pair.
    outcome: list[tuple[bool, Any]] = []

    def _drive() -> None:
        try:
            outcome.append((True, asyncio.run(awaitable)))
        except BaseException as exc:  # pragma: no cover - surfaced in caller context
            outcome.append((False, exc))

    worker = threading.Thread(target=_drive, daemon=True)
    worker.start()
    worker.join()

    succeeded, payload = outcome[0]
    if not succeeded:
        raise payload
    return payload

90 

91 

def _build_auth_header(api_key: str) -> str:
    """Return an HTTP Basic ``Authorization`` value: *api_key* as username, empty password."""
    credentials = f"{api_key}:".encode()
    encoded = base64.b64encode(credentials).decode()
    return "Basic " + encoded

95 

96 

def _fetch_mcp_tools(endpoint: str, headers: dict[str, str]) -> list[_McpToolDefinition]:
    """List every tool exposed by an MCP endpoint.

    Opens a streamable-HTTP MCP client session against *endpoint*, sending
    *headers* on the transport. It initializes the session, then follows
    ``nextCursor`` pagination until the server stops returning a cursor.

    Args:
        endpoint: Full URL of the MCP endpoint.
        headers: HTTP headers passed to the streamable-HTTP transport.

    Returns:
        One ``_McpToolDefinition`` per listed tool, in server order. Each input
        schema is shallow-copied into a plain dict (empty when the server sends
        none).

    Raises:
        ToolsetConfigError: If the optional ``mcp`` dependencies are not installed.
        ToolsetLoadError: If the server hands back a pagination cursor it already
            returned, which would otherwise make the listing loop forever.
    """
    try:
        from mcp import types as mcp_types  # ty: ignore[unresolved-import]
        from mcp.client.session import ClientSession  # ty: ignore[unresolved-import]
        from mcp.client.streamable_http import streamablehttp_client  # ty: ignore[unresolved-import]
    except ImportError as exc:  # pragma: no cover - depends on optional extra
        raise ToolsetConfigError(
            "MCP dependencies are required for fetch_tools. Install with 'pip install \"stackone-ai[mcp]\"'."
        ) from exc

    async def _list() -> list[_McpToolDefinition]:
        async with streamablehttp_client(endpoint, headers=headers) as (read_stream, write_stream, _):
            session = ClientSession(
                read_stream,
                write_stream,
                client_info=mcp_types.Implementation(name="stackone-ai-python", version=_SDK_VERSION),
            )
            async with session:
                await session.initialize()
                cursor: str | None = None
                seen_cursors: set[str] = set()
                collected: list[_McpToolDefinition] = []
                while True:
                    result = await session.list_tools(cursor)
                    collected.extend(
                        _McpToolDefinition(
                            name=tool.name,
                            description=tool.description,
                            input_schema=dict(tool.inputSchema or {}),
                        )
                        for tool in result.tools
                    )
                    cursor = result.nextCursor
                    # A missing or empty cursor means the listing is complete.
                    if not cursor:
                        break
                    # Guard against a server that keeps returning the same page.
                    if cursor in seen_cursors:  # pragma: no cover - defensive guard
                        raise ToolsetLoadError(f"MCP server returned a repeated pagination cursor: {cursor!r}")
                    seen_cursors.add(cursor)
                return collected

    return _run_async(_list())

135 

136 

class _StackOneRpcTool(StackOneTool):
    """RPC-backed tool wired to the StackOne actions RPC endpoint.

    Each execution is sent as ``POST {base_url}/actions/rpc`` with a JSON body
    of the form ``{"action": ..., "body": ..., "headers": ..., "path": ..., "query": ...}``.
    ``path`` and ``query`` are included only when non-empty.
    """

    def __init__(
        self,
        *,
        name: str,
        description: str,
        parameters: ToolParameters,
        api_key: str,
        base_url: str,
        account_id: str | None,
    ) -> None:
        """Build the RPC execute configuration and initialise the base tool.

        Args:
            name: Action name, stored on the execute configuration.
            description: Human-readable tool description.
            parameters: Parameter schema exposed to callers.
            api_key: StackOne API key passed to the base tool.
            base_url: API base URL; any trailing slash is stripped before
                ``/actions/rpc`` is appended.
            account_id: Optional account this tool is scoped to.
        """
        execute_config = ExecuteConfig(
            method="POST",
            url=f"{base_url.rstrip('/')}/actions/rpc",
            name=name,
            headers={},
            body_type="json",
            # Copy so per-tool configs never share the module-level mapping.
            parameter_locations=dict(_RPC_PARAMETER_LOCATIONS),
        )
        super().__init__(
            description=description,
            parameters=parameters,
            _execute_config=execute_config,
            _api_key=api_key,
            _account_id=account_id,
        )

    def execute(
        self, arguments: str | dict[str, Any] | None = None, *, options: dict[str, Any] | None = None
    ) -> dict[str, Any]:
        """Execute the action through the RPC endpoint.

        Args:
            arguments: Tool arguments as a dict, a JSON-object string, or None.
                The ``body``, ``headers``, ``path`` and ``query`` entries are
                taken as nested records; non-dict values for them are dropped.
                Every other top-level key is merged into the RPC ``body`` and
                overrides a same-named key from ``body``.
            options: Passed through unchanged to the base ``execute``.

        Returns:
            The result of the base ``execute`` for the assembled RPC payload.

        Raises:
            ValueError: If *arguments* does not decode to a JSON object.
                ``json.JSONDecodeError`` subclasses ``ValueError``, so malformed
                JSON surfaces the same way.
        """
        parsed_arguments = self._parse_arguments(arguments)

        body_payload = self._extract_record(parsed_arguments.pop("body", None))
        headers_payload = self._extract_record(parsed_arguments.pop("headers", None))
        path_payload = self._extract_record(parsed_arguments.pop("path", None))
        query_payload = self._extract_record(parsed_arguments.pop("query", None))

        # Remaining top-level arguments take precedence over keys from "body".
        rpc_body: dict[str, Any] = dict(body_payload or {})
        rpc_body.update(parsed_arguments)

        payload: dict[str, Any] = {
            "action": self.name,
            "body": rpc_body,
            "headers": self._build_action_headers(headers_payload),
        }
        if path_payload:
            payload["path"] = path_payload
        if query_payload:
            payload["query"] = query_payload

        return super().execute(payload, options=options)

    def _parse_arguments(self, arguments: str | dict[str, Any] | None) -> dict[str, Any]:
        """Normalise *arguments* into a fresh dict.

        None becomes ``{}``; a string is JSON-decoded.

        Raises:
            ValueError: If the decoded value is not a JSON object.
        """
        if arguments is None:
            return {}
        if isinstance(arguments, str):
            parsed = json.loads(arguments)
        else:
            parsed = arguments
        if not isinstance(parsed, dict):
            raise ValueError("Tool arguments must be a JSON object")
        return dict(parsed)

    @staticmethod
    def _extract_record(value: Any) -> dict[str, Any] | None:
        """Return a shallow copy of *value* if it is a dict, otherwise None."""
        if isinstance(value, dict):
            return dict(value)
        return None

    def _build_action_headers(self, additional_headers: dict[str, Any] | None) -> dict[str, str]:
        """Build the headers forwarded inside the RPC payload.

        Starts with ``x-account-id`` when the tool has an account, then layers
        the caller's headers on top. Keys and values are stringified and None
        values are skipped.

        Authorization headers are never forwarded, whatever their letter case.
        HTTP header names are case-insensitive, so an exact-match removal of
        ``"Authorization"`` would let ``authorization`` slip through.
        """
        headers: dict[str, str] = {}
        account_id = self.get_account_id()
        if account_id:
            headers["x-account-id"] = account_id

        if additional_headers:
            for key, value in additional_headers.items():
                if value is None:
                    continue
                header_name = str(key)
                if header_name.lower() == "authorization":
                    continue
                headers[header_name] = str(value)

        return headers

223 

224 

class StackOneToolSet:
    """Main class for accessing StackOne tools

    Tools are discovered through the MCP endpoint at ``{base_url}/mcp``. Each
    one is wrapped as an RPC-backed tool scoped to the account it was listed
    for.
    """

    def __init__(
        self,
        api_key: str | None = None,
        account_id: str | None = None,
        base_url: str | None = None,
    ) -> None:
        """Initialize StackOne tools with authentication

        Args:
            api_key: Optional API key. If not provided, will try to get from STACKONE_API_KEY env var
            account_id: Optional account ID, used by fetch_tools() when no account IDs
                are passed in or set via set_accounts()
            base_url: Optional base URL override for API requests (defaults to DEFAULT_BASE_URL)

        Raises:
            ToolsetConfigError: If no API key is provided or found in environment
        """
        api_key_value = api_key or os.getenv("STACKONE_API_KEY")
        if not api_key_value:
            raise ToolsetConfigError(
                "API key must be provided either through api_key parameter or "
                "STACKONE_API_KEY environment variable"
            )
        self.api_key: str = api_key_value
        self.account_id = account_id
        self.base_url = base_url or DEFAULT_BASE_URL
        self._account_ids: list[str] = []

    def set_accounts(self, account_ids: list[str]) -> StackOneToolSet:
        """Set account IDs for filtering tools

        Args:
            account_ids: List of account IDs to filter tools by. A copy is stored,
                so later changes to the caller's list do not affect this toolset.

        Returns:
            This toolset instance for chaining
        """
        self._account_ids = list(account_ids)
        return self

    def _filter_by_provider(self, tool_name: str, providers: list[str]) -> bool:
        """Check if a tool name matches any of the provider filters

        Args:
            tool_name: Name of the tool to check
            providers: List of provider names (case-insensitive)

        Returns:
            True if the tool matches any provider, False otherwise
        """
        # The provider is the text before the first underscore (provider_action format).
        provider = tool_name.split("_")[0].lower()
        provider_set = {p.lower() for p in providers}
        return provider in provider_set

    def _filter_by_action(self, tool_name: str, actions: list[str]) -> bool:
        """Check if a tool name matches any of the action patterns

        Args:
            tool_name: Name of the tool to check
            actions: List of action patterns (shell-style glob patterns, case-sensitive)

        Returns:
            True if the tool matches any action pattern, False otherwise
        """
        # fnmatchcase, not fnmatch: fnmatch applies os.path.normcase, which would
        # make tool-name matching case-insensitive on Windows only.
        return any(fnmatch.fnmatchcase(tool_name, pattern) for pattern in actions)

    def fetch_tools(
        self,
        *,
        account_ids: list[str] | None = None,
        providers: list[str] | None = None,
        actions: list[str] | None = None,
    ) -> Tools:
        """Fetch tools with optional filtering by account IDs, providers, and actions

        Args:
            account_ids: Optional list of account IDs to filter by.
                If not provided, uses accounts set via set_accounts(), then the
                account_id given to the constructor. Duplicates are fetched once.
            providers: Optional list of provider names (e.g., ['hibob', 'bamboohr']).
                Case-insensitive matching.
            actions: Optional list of action patterns with glob support
                (e.g., ['*_list_employees', 'hibob_create_employees'])

        Returns:
            Collection of tools matching the filter criteria

        Raises:
            ToolsetError: Toolset errors raised while loading (e.g. ToolsetConfigError
                when the MCP extra is missing) propagate unchanged
            ToolsetLoadError: If there is any other error loading the tools

        Examples:
            # Filter by account IDs
            tools = toolset.fetch_tools(account_ids=['123', '456'])

            # Filter by providers
            tools = toolset.fetch_tools(providers=['hibob', 'bamboohr'])

            # Filter by actions with glob patterns
            tools = toolset.fetch_tools(actions=['*_list_employees'])

            # Combine filters
            tools = toolset.fetch_tools(
                account_ids=['123'],
                providers=['hibob'],
                actions=['*_list_*']
            )

            # Use set_accounts() for account filtering
            toolset.set_accounts(['123', '456'])
            tools = toolset.fetch_tools()
        """
        try:
            effective_account_ids = account_ids or self._account_ids
            if not effective_account_ids and self.account_id:
                effective_account_ids = [self.account_id]

            # De-duplicate while preserving order; [None] means one unscoped fetch.
            if effective_account_ids:
                account_scope: list[str | None] = list(dict.fromkeys(effective_account_ids))
            else:
                account_scope = [None]

            endpoint = f"{self.base_url.rstrip('/')}/mcp"
            all_tools: list[StackOneTool] = []

            for account in account_scope:
                headers = self._build_mcp_headers(account)
                catalog = _fetch_mcp_tools(endpoint, headers)
                for tool_def in catalog:
                    all_tools.append(self._create_rpc_tool(tool_def, account))

            if providers:
                all_tools = [tool for tool in all_tools if self._filter_by_provider(tool.name, providers)]

            if actions:
                all_tools = [tool for tool in all_tools if self._filter_by_action(tool.name, actions)]

            return Tools(all_tools)

        except ToolsetError:
            raise
        except Exception as exc:  # pragma: no cover - unexpected runtime errors
            raise ToolsetLoadError(f"Error fetching tools: {exc}") from exc

    def _build_mcp_headers(self, account_id: str | None) -> dict[str, str]:
        """Return MCP request headers: Basic auth, User-Agent, and x-account-id when set."""
        headers = {
            "Authorization": _build_auth_header(self.api_key),
            "User-Agent": _USER_AGENT,
        }
        if account_id:
            headers["x-account-id"] = account_id
        return headers

    def _create_rpc_tool(self, tool_def: _McpToolDefinition, account_id: str | None) -> StackOneTool:
        """Wrap an MCP tool definition as an RPC-backed tool bound to *account_id*.

        The schema's ``type`` defaults to ``"object"`` and a missing description
        becomes ``""``.
        """
        schema = tool_def.input_schema or {}
        parameters = ToolParameters(
            type=str(schema.get("type") or "object"),
            properties=self._normalize_schema_properties(schema),
        )
        return _StackOneRpcTool(
            name=tool_def.name,
            description=tool_def.description or "",
            parameters=parameters,
            api_key=self.api_key,
            base_url=self.base_url,
            account_id=account_id,
        )

    def _normalize_schema_properties(self, schema: dict[str, Any]) -> dict[str, Any]:
        """Return normalized copies of the schema's ``properties``.

        Dict-valued properties are shallow-copied; any other value is wrapped as
        ``{"description": str(value)}``. ``nullable`` defaults to False for names
        listed in ``required`` and True otherwise, while an explicit ``nullable``
        is kept. A missing or non-dict ``properties`` yields ``{}``.
        """
        properties = schema.get("properties")
        if not isinstance(properties, dict):
            return {}

        required = schema.get("required")
        # JSON Schema defines "required" as an array of names. Tolerate null (which
        # would raise TypeError when iterated) and a bare string (which would be
        # split into single characters) by treating them as "nothing required".
        if not isinstance(required, (list, tuple, set, frozenset)):
            required = ()
        required_fields = {str(field) for field in required}
        normalized: dict[str, Any] = {}

        for name, details in properties.items():
            if isinstance(details, dict):
                prop = dict(details)
            else:
                prop = {"description": str(details)}

            key = str(name)
            prop.setdefault("nullable", key not in required_fields)
            normalized[key] = prop

        return normalized