Coverage for stackone_ai / toolset.py: 100%
179 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-08 18:25 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-08 18:25 +0000
1from __future__ import annotations
3import asyncio
4import base64
5import fnmatch
6import json
7import os
8import threading
9from collections.abc import Coroutine
10from dataclasses import dataclass
11from importlib import metadata
12from typing import Any, TypeVar
14from stackone_ai.models import (
15 ExecuteConfig,
16 ParameterLocation,
17 StackOneTool,
18 ToolParameters,
19 Tools,
20)
# Resolve the installed package version for User-Agent / MCP client_info;
# fall back to "dev" when the distribution metadata is not available.
try:
    _SDK_VERSION = metadata.version("stackone-ai")
except metadata.PackageNotFoundError:  # pragma: no cover - best-effort fallback when running from source
    _SDK_VERSION = "dev"

# Base URL used by StackOneToolSet when no base_url override is given.
DEFAULT_BASE_URL = "https://api.stackone.com"

# Every top-level RPC argument is serialized into the JSON request body.
_RPC_PARAMETER_LOCATIONS = dict.fromkeys(
    ("action", "body", "headers", "path", "query"),
    ParameterLocation.BODY,
)

_USER_AGENT = "stackone-ai-python/" + _SDK_VERSION

# Generic result type for _run_async.
T = TypeVar("T")
40@dataclass
41class _McpToolDefinition:
42 name: str
43 description: str | None
44 input_schema: dict[str, Any]
class ToolsetError(Exception):
    """Root of the toolset exception hierarchy; catch this to handle any toolset failure."""


class ToolsetConfigError(ToolsetError):
    """Raised when the toolset is misconfigured, e.g. no API key or missing optional MCP dependencies."""


class ToolsetLoadError(ToolsetError):
    """Raised when tools cannot be loaded from the StackOne service."""
65def _run_async(awaitable: Coroutine[Any, Any, T]) -> T:
66 """Run a coroutine, even when called from an existing event loop."""
68 try:
69 asyncio.get_running_loop()
70 except RuntimeError:
71 return asyncio.run(awaitable)
73 result: dict[str, T] = {}
74 error: dict[str, BaseException] = {}
76 def runner() -> None:
77 try:
78 result["value"] = asyncio.run(awaitable)
79 except BaseException as exc: # pragma: no cover - surfaced in caller context
80 error["error"] = exc
82 thread = threading.Thread(target=runner, daemon=True)
83 thread.start()
84 thread.join()
86 if "error" in error:
87 raise error["error"]
89 return result["value"]
92def _build_auth_header(api_key: str) -> str:
93 token = base64.b64encode(f"{api_key}:".encode()).decode()
94 return f"Basic {token}"
def _fetch_mcp_tools(endpoint: str, headers: dict[str, str]) -> list[_McpToolDefinition]:
    """Return every tool advertised by the MCP server at *endpoint*.

    Opens a streamable-HTTP MCP session with *headers*, performs the
    ``initialize`` handshake, then follows ``list_tools`` pagination until the
    server returns no ``nextCursor``.

    Args:
        endpoint: URL of the MCP endpoint.
        headers: HTTP headers sent on the MCP transport.

    Returns:
        One ``_McpToolDefinition`` per listed tool, in the order the server returned them.

    Raises:
        ToolsetConfigError: If the optional ``mcp`` dependencies are not installed.
    """
    try:
        from mcp import types as mcp_types  # ty: ignore[unresolved-import]
        from mcp.client.session import ClientSession  # ty: ignore[unresolved-import]
        from mcp.client.streamable_http import streamablehttp_client  # ty: ignore[unresolved-import]
    except ImportError as exc:  # pragma: no cover - depends on optional extra
        raise ToolsetConfigError(
            "MCP dependencies are required for fetch_tools. Install with 'pip install \"stackone-ai[mcp]\"'."
        ) from exc

    async def _collect_all() -> list[_McpToolDefinition]:
        definitions: list[_McpToolDefinition] = []
        async with streamablehttp_client(endpoint, headers=headers) as (reader, writer, _session_id_getter):
            session = ClientSession(
                reader,
                writer,
                client_info=mcp_types.Implementation(name="stackone-ai-python", version=_SDK_VERSION),
            )
            async with session:
                await session.initialize()
                page_cursor: str | None = None
                while True:
                    page = await session.list_tools(page_cursor)
                    definitions.extend(
                        _McpToolDefinition(
                            name=tool.name,
                            description=tool.description,
                            # Copy so later mutation cannot touch the SDK object; a missing schema becomes {}.
                            input_schema=dict(tool.inputSchema or {}),
                        )
                        for tool in page.tools
                    )
                    page_cursor = page.nextCursor
                    if page_cursor is None:
                        break
        return definitions

    return _run_async(_collect_all())
class _StackOneRpcTool(StackOneTool):
    """RPC-backed tool wired to the StackOne actions RPC endpoint.

    Every invocation is sent as a ``POST`` to ``{base_url}/actions/rpc``. The
    tool's name is sent as ``action``, and the caller's arguments are re-packed
    into ``body`` / ``headers`` / ``path`` / ``query`` fields of a JSON request
    body. All of these RPC fields are mapped to ``ParameterLocation.BODY``.
    """

    def __init__(
        self,
        *,
        name: str,
        description: str,
        parameters: ToolParameters,
        api_key: str,
        base_url: str,
        account_id: str | None,
    ) -> None:
        """Create an RPC tool.

        Args:
            name: Action name; also sent as ``action`` in every RPC payload.
            description: Human-readable tool description.
            parameters: Argument schema exposed for this tool.
            api_key: StackOne API key passed to the base tool.
            base_url: API base URL; a trailing slash is tolerated.
            account_id: Optional account the tool is bound to.
        """
        execute_config = ExecuteConfig(
            method="POST",
            url=f"{base_url.rstrip('/')}/actions/rpc",
            name=name,
            headers={},
            body_type="json",
            # Copy so a per-tool mutation cannot leak into the shared module constant.
            parameter_locations=dict(_RPC_PARAMETER_LOCATIONS),
        )
        super().__init__(
            description=description,
            parameters=parameters,
            _execute_config=execute_config,
            _api_key=api_key,
            _account_id=account_id,
        )

    def execute(
        self, arguments: str | dict[str, Any] | None = None, *, options: dict[str, Any] | None = None
    ) -> dict[str, Any]:
        """Invoke the action through the RPC endpoint.

        Args:
            arguments: A JSON-object string, a dict, or ``None``. The reserved keys
                ``body``, ``headers``, ``path`` and ``query`` are removed from the
                arguments. Their values are used only when they are dicts; other
                values for those keys are discarded. All remaining keys are merged
                into the RPC body and override same-named keys taken from ``body``.
            options: Forwarded unchanged to ``StackOneTool.execute``.

        Returns:
            The result of ``StackOneTool.execute`` for the assembled RPC payload.

        Raises:
            ValueError: If *arguments* does not decode to a JSON object
                (``json.JSONDecodeError`` is a ``ValueError`` subclass).
        """
        parsed_arguments = self._parse_arguments(arguments)

        body_payload = self._extract_record(parsed_arguments.pop("body", None))
        headers_payload = self._extract_record(parsed_arguments.pop("headers", None))
        path_payload = self._extract_record(parsed_arguments.pop("path", None))
        query_payload = self._extract_record(parsed_arguments.pop("query", None))

        # Loose top-level arguments win over keys supplied inside ``body``.
        rpc_body: dict[str, Any] = dict(body_payload or {})
        rpc_body.update(parsed_arguments)

        payload: dict[str, Any] = {
            "action": self.name,
            "body": rpc_body,
            "headers": self._build_action_headers(headers_payload),
        }
        # ``path``/``query`` are only sent when non-empty.
        if path_payload:
            payload["path"] = path_payload
        if query_payload:
            payload["query"] = query_payload

        return super().execute(payload, options=options)

    def _parse_arguments(self, arguments: str | dict[str, Any] | None) -> dict[str, Any]:
        """Normalize *arguments* into a fresh dict; ``None`` yields ``{}``.

        Raises:
            ValueError: If the decoded value is not a JSON object.
        """
        if arguments is None:
            return {}
        if isinstance(arguments, str):
            parsed = json.loads(arguments)
        else:
            parsed = arguments
        if not isinstance(parsed, dict):
            raise ValueError("Tool arguments must be a JSON object")
        # Copy so the pops in execute() never mutate the caller's dict.
        return dict(parsed)

    @staticmethod
    def _extract_record(value: Any) -> dict[str, Any] | None:
        """Return a shallow copy of *value* if it is a dict, otherwise ``None``."""
        if isinstance(value, dict):
            return dict(value)
        return None

    def _build_action_headers(self, additional_headers: dict[str, Any] | None) -> dict[str, str]:
        """Build the headers forwarded to the action inside the RPC payload.

        The ``x-account-id`` header is set from ``get_account_id()`` when present.
        Caller headers are then applied: ``None`` values are skipped, keys and
        values are stringified, and a caller may override ``x-account-id``.
        An ``Authorization`` header is always dropped, matched case-insensitively
        because HTTP field names are case-insensitive.
        """
        headers: dict[str, str] = {}
        account_id = self.get_account_id()
        if account_id:
            headers["x-account-id"] = account_id

        if additional_headers:
            for key, value in additional_headers.items():
                if value is None:
                    continue
                header_name = str(key)
                # Previously only the exact spelling "Authorization" was removed,
                # so "authorization" or "AUTHORIZATION" slipped through.
                if header_name.lower() == "authorization":
                    continue
                headers[header_name] = str(value)

        return headers
class StackOneToolSet:
    """Main class for accessing StackOne tools"""

    def __init__(
        self,
        api_key: str | None = None,
        account_id: str | None = None,
        base_url: str | None = None,
    ) -> None:
        """Initialize StackOne tools with authentication

        Args:
            api_key: Optional API key. If not provided, will try to get from STACKONE_API_KEY env var
            account_id: Optional account ID
            base_url: Optional base URL override for API requests

        Raises:
            ToolsetConfigError: If no API key is provided or found in environment
        """
        api_key_value = api_key or os.getenv("STACKONE_API_KEY")
        if not api_key_value:
            raise ToolsetConfigError(
                "API key must be provided either through api_key parameter or "
                "STACKONE_API_KEY environment variable"
            )
        self.api_key: str = api_key_value
        self.account_id = account_id
        self.base_url = base_url or DEFAULT_BASE_URL
        self._account_ids: list[str] = []

    def set_accounts(self, account_ids: list[str]) -> StackOneToolSet:
        """Set account IDs for filtering tools

        Args:
            account_ids: List of account IDs to filter tools by. ``None`` or an
                empty list clears any previously set accounts.

        Returns:
            This toolset instance for chaining
        """
        # Store a copy so later mutation of the caller's list cannot silently
        # change which accounts fetch_tools() queries.
        self._account_ids = list(account_ids or [])
        return self

    def _filter_by_provider(self, tool_name: str, providers: list[str]) -> bool:
        """Check if a tool name matches any of the provider filters

        Args:
            tool_name: Name of the tool to check
            providers: List of provider names (case-insensitive)

        Returns:
            True if the tool matches any provider, False otherwise
        """
        # Extract provider from tool name (assuming format: provider_action)
        provider = tool_name.split("_")[0].lower()
        provider_set = {p.lower() for p in providers}
        return provider in provider_set

    def _filter_by_action(self, tool_name: str, actions: list[str]) -> bool:
        """Check if a tool name matches any of the action patterns

        Matching is case-sensitive on every platform. ``fnmatch.fnmatchcase`` is
        used because ``fnmatch.fnmatch`` applies ``os.path.normcase``, which on
        Windows lowercases names and rewrites ``/``. That is wrong for tool names,
        which are not file paths.

        Args:
            tool_name: Name of the tool to check
            actions: List of action patterns (supports glob patterns)

        Returns:
            True if the tool matches any action pattern, False otherwise
        """
        return any(fnmatch.fnmatchcase(tool_name, pattern) for pattern in actions)

    def fetch_tools(
        self,
        *,
        account_ids: list[str] | None = None,
        providers: list[str] | None = None,
        actions: list[str] | None = None,
    ) -> Tools:
        """Fetch tools with optional filtering by account IDs, providers, and actions

        Args:
            account_ids: Optional list of account IDs to filter by.
                If not provided, uses accounts set via set_accounts()
            providers: Optional list of provider names (e.g., ['hibob', 'bamboohr']).
                Case-insensitive matching.
            actions: Optional list of action patterns with glob support
                (e.g., ['*_list_employees', 'hibob_create_employees'])

        Returns:
            Collection of tools matching the filter criteria

        Raises:
            ToolsetLoadError: If there is an error loading the tools

        Examples:
            # Filter by account IDs
            tools = toolset.fetch_tools(account_ids=['123', '456'])

            # Filter by providers
            tools = toolset.fetch_tools(providers=['hibob', 'bamboohr'])

            # Filter by actions with glob patterns
            tools = toolset.fetch_tools(actions=['*_list_employees'])

            # Combine filters
            tools = toolset.fetch_tools(
                account_ids=['123'],
                providers=['hibob'],
                actions=['*_list_*']
            )

            # Use set_accounts() for account filtering
            toolset.set_accounts(['123', '456'])
            tools = toolset.fetch_tools()
        """
        try:
            # Precedence: explicit argument, then set_accounts(), then the constructor's account_id.
            effective_account_ids = account_ids or self._account_ids
            if not effective_account_ids and self.account_id:
                effective_account_ids = [self.account_id]

            # De-duplicate while preserving order; ``[None]`` means a single
            # request sent without an ``x-account-id`` header.
            if effective_account_ids:
                account_scope: list[str | None] = list(dict.fromkeys(effective_account_ids))
            else:
                account_scope = [None]

            endpoint = f"{self.base_url.rstrip('/')}/mcp"
            all_tools: list[StackOneTool] = []

            # One catalog fetch per account; each tool is bound to the account
            # it was listed under.
            for account in account_scope:
                headers = self._build_mcp_headers(account)
                catalog = _fetch_mcp_tools(endpoint, headers)
                for tool_def in catalog:
                    all_tools.append(self._create_rpc_tool(tool_def, account))

            if providers:
                all_tools = [tool for tool in all_tools if self._filter_by_provider(tool.name, providers)]

            if actions:
                all_tools = [tool for tool in all_tools if self._filter_by_action(tool.name, actions)]

            return Tools(all_tools)

        except ToolsetError:
            raise
        except Exception as exc:  # pragma: no cover - unexpected runtime errors
            raise ToolsetLoadError(f"Error fetching tools: {exc}") from exc

    def _build_mcp_headers(self, account_id: str | None) -> dict[str, str]:
        """Return the MCP transport headers: Basic auth, User-Agent, and ``x-account-id`` when *account_id* is set."""
        headers = {
            "Authorization": _build_auth_header(self.api_key),
            "User-Agent": _USER_AGENT,
        }
        if account_id:
            headers["x-account-id"] = account_id
        return headers

    def _create_rpc_tool(self, tool_def: _McpToolDefinition, account_id: str | None) -> StackOneTool:
        """Wrap an MCP tool definition as an RPC tool bound to *account_id*.

        The schema's ``type`` defaults to ``"object"`` when missing or empty, and
        a missing description becomes ``""``.
        """
        schema = tool_def.input_schema or {}
        parameters = ToolParameters(
            type=str(schema.get("type") or "object"),
            properties=self._normalize_schema_properties(schema),
        )
        return _StackOneRpcTool(
            name=tool_def.name,
            description=tool_def.description or "",
            parameters=parameters,
            api_key=self.api_key,
            base_url=self.base_url,
            account_id=account_id,
        )

    def _normalize_schema_properties(self, schema: dict[str, Any]) -> dict[str, Any]:
        """Return a copy of the JSON-schema ``properties`` with a ``nullable`` flag on each entry.

        - Properties listed in ``required`` default to ``nullable=False`` and all
          others to ``nullable=True``. An explicit ``nullable`` in the schema is kept.
        - Non-dict property definitions become ``{"description": str(value)}``.
        - Returns ``{}`` when ``properties`` is missing or not a dict.
        - A ``required`` value that is not a list or tuple (e.g. ``null``) is
          treated as empty.
        """
        properties = schema.get("properties")
        if not isinstance(properties, dict):
            return {}

        raw_required = schema.get("required")
        # ``"required": null`` used to raise TypeError, and a string was iterated
        # character by character; only a list/tuple is meaningful here.
        if isinstance(raw_required, (list, tuple)):
            required_fields = {str(name) for name in raw_required}
        else:
            required_fields = set()

        normalized: dict[str, Any] = {}

        for name, details in properties.items():
            key = str(name)
            if isinstance(details, dict):
                prop = dict(details)
            else:
                prop = {"description": str(details)}

            prop.setdefault("nullable", key not in required_fields)

            normalized[key] = prop

        return normalized