Coverage for stackone_ai / models.py: 98%
276 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-05-01 15:10 +0000
1from __future__ import annotations
3import base64
4import json
5import logging
6from collections.abc import Sequence
7from datetime import datetime, timezone
8from enum import Enum
9from typing import TYPE_CHECKING, Annotated, Any, ClassVar, TypeAlias, cast
10from urllib.parse import quote
12import httpx
13from langchain_core.tools import BaseTool
14from pydantic import BaseModel, BeforeValidator, Field, PrivateAttr
16if TYPE_CHECKING:
17 from pydantic_ai.tools import Tool as PydanticAITool
19# Type aliases for common types
20JsonDict: TypeAlias = dict[str, Any]
21Headers: TypeAlias = dict[str, str]
24logger = logging.getLogger("stackone.tools")
class StackOneError(Exception):
    """Root of the StackOne SDK exception hierarchy.

    Catch this to handle any error raised by the SDK itself.
    """
class StackOneAPIError(StackOneError):
    """Raised when the StackOne API answers with an error status.

    Attributes:
        status_code: HTTP status code of the failed response.
        response_body: Decoded JSON error body, the raw text if it was not
            JSON, or ``None`` when the body was empty.
    """

    def __init__(self, message: str, status_code: int, response_body: Any) -> None:
        """Store the HTTP details alongside the human-readable message."""
        self.status_code = status_code
        self.response_body = response_body
        super().__init__(message)
class ParameterLocation(str, Enum):
    """Where a tool parameter is placed in the outgoing HTTP request.

    Members are also plain strings, so they compare equal to their values
    (e.g. ``ParameterLocation.PATH == "path"``).
    """

    HEADER = "header"  # request header
    QUERY = "query"  # URL query string
    PATH = "path"  # "{name}" placeholder inside the URL template
    BODY = "body"  # request body
    FILE = "file"  # For file uploads
def validate_method(v: str) -> str:
    """Validate HTTP method is uppercase and supported.

    Used as a pydantic ``BeforeValidator`` on ``ExecuteConfig.method``.

    Args:
        v: HTTP method name in any letter case.

    Returns:
        The method name in upper case.

    Raises:
        ValueError: If ``v`` is not a string or is not one of GET, POST, PUT,
            DELETE or PATCH.
    """
    # Without this guard a non-str value would fail inside ``.upper()`` with
    # AttributeError, which pydantic does not convert into a ValidationError.
    if not isinstance(v, str):
        raise ValueError(f"Unsupported HTTP method: {v!r}")
    method = v.upper()
    if method not in {"GET", "POST", "PUT", "DELETE", "PATCH"}:
        raise ValueError(f"Unsupported HTTP method: {method}")
    return method
class ExecuteConfig(BaseModel):
    """Configuration for executing a tool against an API endpoint"""

    # Static headers; StackOneTool._prepare_headers applies them over its defaults.
    headers: Headers = Field(description="HTTP headers to include in the request", default_factory=dict)
    # Upper-cased and restricted to the supported verbs by validate_method.
    method: Annotated[str, BeforeValidator(validate_method)] = Field(description="HTTP method to use")
    # URL template; "{param}" placeholders are filled from path parameters.
    url: str = Field(description="API endpoint URL")
    # StackOneTool uses this value as the tool's public name.
    name: str = Field(description="Tool name")
    # Body encoding used by StackOneTool.execute; None is treated as "json".
    body_type: str | None = Field(description="Content type for request body", default=None)
    # Parameters missing from this map are routed by StackOneTool's default rules.
    parameter_locations: dict[str, ParameterLocation] = Field(
        description="Maps parameter names to their location in the request",
        default_factory=dict,
    )
    # Passed to httpx as the per-request timeout.
    timeout: float = Field(description="Request timeout in seconds", default=60.0)
class ToolParameters(BaseModel):
    """Schema definition for tool parameters"""

    # Top-level JSON Schema "type" of the argument object.
    type: str = Field(description="JSON Schema type")
    # Parameter name -> JSON Schema fragment. Non-dict entries are tolerated by
    # the converters, which treat them as plain strings.
    properties: JsonDict = Field(description="JSON Schema properties")
class ToolDefinition(BaseModel):
    """Complete definition of a tool including its schema and execution config"""

    # What the tool does, for the model/agent.
    description: str = Field(description="Tool description")
    # Accepted arguments.
    parameters: ToolParameters = Field(description="Tool parameter schema")
    # How to turn a call into an HTTP request.
    execute: ExecuteConfig = Field(description="Tool execution configuration")
class StackOneTool(BaseModel):
    """Base class for all StackOne tools. Provides functionality for executing API calls
    and converting to various formats (OpenAI, LangChain)."""

    name: str = Field(description="Tool name")
    description: str = Field(description="Tool description")
    parameters: ToolParameters = Field(description="Tool parameters")
    _execute_config: ExecuteConfig = PrivateAttr()
    _api_key: str = PrivateAttr()
    _account_id: str | None = PrivateAttr(default=None)
    # Argument keys that carry SDK-side feedback metadata rather than API
    # parameters; ``execute`` strips them before building the HTTP request.
    _FEEDBACK_OPTION_KEYS: ClassVar[set[str]] = {
        "feedback_session_id",
        "feedback_user_id",
        "feedback_metadata",
    }

    @property
    def connector(self) -> str:
        """Extract connector from tool name.

        Tool names follow the format: {connector}_{action}_{entity}
        e.g., 'bamboohr_create_employee' -> 'bamboohr'

        Returns:
            Connector name in lowercase
        """
        return self.name.split("_")[0].lower()

    def __init__(
        self,
        description: str,
        parameters: ToolParameters,
        _execute_config: ExecuteConfig,
        _api_key: str,
        _account_id: str | None = None,
    ) -> None:
        """Create a tool bound to an execution config and credentials.

        Args:
            description: Human-readable tool description.
            parameters: JSON Schema describing the tool's arguments.
            _execute_config: How to call the API; its ``name`` becomes the tool name.
            _api_key: StackOne API key, sent as the HTTP Basic auth username.
            _account_id: Optional account ID, sent as the ``x-account-id`` header.
        """
        super().__init__(
            name=_execute_config.name,
            description=description,
            parameters=parameters,
        )
        self._execute_config = _execute_config
        self._api_key = _api_key
        self._account_id = _account_id

    @classmethod
    def _split_feedback_options(cls, params: JsonDict, options: JsonDict | None) -> tuple[JsonDict, JsonDict]:
        """Separate feedback-only keys from the API parameters.

        Args:
            params: Tool arguments as supplied by the caller.
            options: Explicit execution options. Values here take precedence over
                feedback keys that also appear in ``params``.

        Returns:
            Tuple of (params without any feedback keys, feedback options).
            Neither input dict is mutated.
        """
        merged_params = dict(params)
        feedback_options = dict(options or {})
        for key in cls._FEEDBACK_OPTION_KEYS:
            if key in merged_params:
                # Always remove the key from the API params so it cannot leak into
                # the request; an explicit value in ``options`` wins.
                value = merged_params.pop(key)
                feedback_options.setdefault(key, value)
        return merged_params, feedback_options

    @staticmethod
    def _parse_arguments(arguments: str | JsonDict | None) -> JsonDict:
        """Normalise raw tool arguments into a dict.

        Raises:
            ValueError: If a string is not valid JSON, or the arguments are not
                a JSON object.
        """
        if isinstance(arguments, str):
            try:
                parsed = json.loads(arguments)
            except json.JSONDecodeError as exc:
                raise ValueError(f"Invalid JSON in arguments: {exc}") from exc
        else:
            parsed = arguments or {}

        if not isinstance(parsed, dict):
            raise ValueError("Tool arguments must be a JSON object")
        return parsed

    def _prepare_headers(self) -> Headers:
        """Prepare headers for the API request

        Returns:
            Headers to use in the request. Static headers from the execute
            config are applied last and therefore override the defaults.
        """
        auth_string = base64.b64encode(f"{self._api_key}:".encode()).decode()
        headers: Headers = {
            "Authorization": f"Basic {auth_string}",
            "User-Agent": "stackone-python/1.0.0",
        }

        if self._account_id:
            headers["x-account-id"] = self._account_id

        # Add predefined headers
        headers.update(self._execute_config.headers)
        return headers

    @staticmethod
    def _fill_path_placeholder(url: str, key: str, value: Any) -> str:
        """Replace ``{key}`` in ``url`` with the percent-encoded ``value``."""
        # safe="" also encodes "/", so a value cannot inject extra path segments
        # (SSRF guard).
        return url.replace(f"{{{key}}}", quote(str(value), safe=""))

    def _prepare_request_params(self, kwargs: JsonDict) -> tuple[str, JsonDict, JsonDict]:
        """Prepare URL and parameters for the API request

        Parameters with an explicit location go there. The others are routed by
        default: into a matching ``{name}`` URL placeholder if one exists, else
        into the query string for GET/DELETE, else into the body.

        Args:
            kwargs: Arguments to process

        Returns:
            Tuple of (url, body_params, query_params)
        """
        url = self._execute_config.url
        body_params: JsonDict = {}
        query_params: JsonDict = {}
        locations = self._execute_config.parameter_locations

        for key, value in kwargs.items():
            location = locations.get(key)
            # NOTE(review): ParameterLocation.HEADER has no dedicated branch, so
            # header parameters fall through to the default routing below.
            if location == ParameterLocation.PATH:
                url = self._fill_path_placeholder(url, key, value)
            elif location == ParameterLocation.QUERY:
                query_params[key] = value
            elif location in (ParameterLocation.BODY, ParameterLocation.FILE):
                body_params[key] = value
            elif f"{{{key}}}" in url:
                url = self._fill_path_placeholder(url, key, value)
            elif self._execute_config.method in {"GET", "DELETE"}:
                query_params[key] = value
            else:
                body_params[key] = value

        return url, body_params, query_params

    @staticmethod
    def _error_body(response: httpx.Response) -> Any:
        """Decode an error response body: JSON if possible, else raw text, else None."""
        if not response.text:
            return None
        try:
            return response.json()
        except json.JSONDecodeError:
            return response.text

    @staticmethod
    def _parse_response(response: httpx.Response) -> JsonDict:
        """Decode a successful response body into a dict.

        An empty body (e.g. 204 No Content) yields ``{}``. A non-JSON body yields
        ``{"result": <text>}``. A JSON value that is not an object is wrapped as
        ``{"result": value}``.
        """
        if not response.content:
            return {}
        try:
            result = response.json()
        except json.JSONDecodeError:
            return {"result": response.text}
        return cast(JsonDict, result) if isinstance(result, dict) else {"result": result}

    def execute(
        self, arguments: str | JsonDict | None = None, *, options: JsonDict | None = None
    ) -> JsonDict:
        """Execute the tool with the given parameters

        Args:
            arguments: Tool arguments as a JSON string or dict (``None`` means no arguments)
            options: Execution options (e.g. feedback metadata). Feedback keys found
                in ``arguments`` are treated as options and are not sent to the API.

        Returns:
            API response as dict. A JSON value that is not an object is wrapped as
            ``{"result": value}``. An empty body yields ``{}``, and a non-JSON body
            yields ``{"result": <response text>}``.

        Raises:
            StackOneAPIError: If the API responds with an error status
            StackOneError: If the request could not be sent (e.g. network error, timeout)
            ValueError: If the arguments are invalid
        """
        parsed_arguments = self._parse_arguments(arguments)
        # Feedback options are not reported anywhere at present. They are split
        # off so that they never leak into the HTTP request as API parameters.
        kwargs, _feedback_options = self._split_feedback_options(parsed_arguments, options)

        url, body_params, query_params = self._prepare_request_params(kwargs)
        request_kwargs: dict[str, Any] = {
            "method": self._execute_config.method,
            "url": url,
            "headers": self._prepare_headers(),
        }

        if body_params:
            body_type = self._execute_config.body_type or "json"
            if body_type == "form":
                request_kwargs["data"] = body_params
            else:
                if body_type != "json":
                    # Never drop the body silently: warn, then fall back to JSON.
                    logger.warning(
                        "Unsupported body_type %r for tool %s; sending body as JSON",
                        body_type,
                        self.name,
                    )
                request_kwargs["json"] = body_params

        if query_params:
            request_kwargs["params"] = query_params

        # Only the HTTP exchange is guarded here. Decoding the response is done
        # separately so that its errors are not mistaken for bad arguments.
        try:
            response = httpx.request(**request_kwargs, timeout=self._execute_config.timeout)
            response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            raise StackOneAPIError(
                str(exc),
                exc.response.status_code,
                self._error_body(exc.response),
            ) from exc
        except httpx.RequestError as exc:
            raise StackOneError(f"Request failed: {exc}") from exc

        return self._parse_response(response)

    def call(self, *args: Any, options: JsonDict | None = None, **kwargs: Any) -> JsonDict:
        """Call the tool with the given arguments

        This method provides a more intuitive way to execute tools directly.

        Args:
            *args: If a single argument is provided, it's treated as the full arguments dict/string
            **kwargs: Keyword arguments to pass to the tool
            options: Optional execution options, forwarded to :meth:`execute`

        Returns:
            API response as dict

        Raises:
            StackOneAPIError: If the API request fails
            ValueError: If the arguments are invalid

        Examples:
            >>> tool.call({"name": "John", "email": "john@example.com"})
            >>> tool.call(name="John", email="john@example.com")
        """
        if args and kwargs:
            raise ValueError("Cannot provide both positional and keyword arguments")

        if args:
            if len(args) > 1:
                raise ValueError("Only one positional argument is allowed")
            return self.execute(args[0], options=options)

        return self.execute(kwargs if kwargs else None, options=options)

    def __call__(self, *args: Any, options: JsonDict | None = None, **kwargs: Any) -> JsonDict:
        """Make the tool directly callable.

        Alias for :meth:`call` so that ``tool(query="…")`` works.
        """
        return self.call(*args, options=options, **kwargs)

    def to_openai_function(self) -> JsonDict:
        """Convert this tool to OpenAI's function format

        Only the ``type``, ``description`` and ``enum`` keywords are kept. Array
        ``items`` and one level of object ``properties`` are cleaned the same way.
        Properties not marked ``nullable`` are listed as required.

        Returns:
            Tool definition in OpenAI function format
        """
        basic_keys = ("type", "description", "enum")

        def _filter(schema: JsonDict) -> JsonDict:
            # Keep only the basic keywords, preserving the schema's own key order.
            return {k: v for k, v in schema.items() if k in basic_keys}

        properties: JsonDict = {}
        required: list[str] = []

        for name, prop in self.parameters.properties.items():
            if not isinstance(prop, dict):
                # Unknown/non-schema entries are exposed as required strings.
                properties[name] = {"type": "string"}
                required.append(name)
                continue

            cleaned_prop: JsonDict = {key: prop[key] for key in basic_keys if key in prop}
            prop_type = cleaned_prop.get("type")

            # Handle array types
            items = prop.get("items")
            if prop_type == "array" and isinstance(items, dict):
                cleaned_prop["items"] = _filter(items)

            # Handle object types; non-dict nested entries fall back to strings,
            # mirroring the top-level handling instead of raising AttributeError.
            nested = prop.get("properties")
            if prop_type == "object" and isinstance(nested, dict):
                cleaned_prop["properties"] = {
                    k: _filter(v) if isinstance(v, dict) else {"type": "string"} for k, v in nested.items()
                }

            # Handle required fields - if not explicitly nullable
            if not prop.get("nullable", False):
                required.append(name)

            properties[name] = cleaned_prop

        # Create the OpenAI function schema
        parameters: JsonDict = {
            "type": "object",
            "properties": properties,
        }

        # Only include required if there are required fields
        if required:
            parameters["required"] = required

        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": parameters,
            },
        }

    def to_langchain(self) -> BaseTool:
        """Convert this tool to LangChain format

        A pydantic ``args_schema`` is generated from the JSON Schema properties.
        Nullable properties become optional with a ``None`` default.

        Returns:
            Tool in LangChain format
        """
        # Create properly annotated schema for the tool
        schema_props: dict[str, Any] = {}
        annotations: dict[str, Any] = {}

        for name, details in self.parameters.properties.items():
            python_type: type = str  # Default to str
            is_nullable = False
            if isinstance(details, dict):
                type_str = details.get("type", "string")
                is_nullable = details.get("nullable", False)
                if type_str == "number":
                    python_type = float
                elif type_str == "integer":
                    python_type = int
                elif type_str == "boolean":
                    python_type = bool
                elif type_str == "object":
                    python_type = dict
                elif type_str == "array":
                    python_type = list

                if is_nullable:
                    field = Field(default=None, description=details.get("description", ""))
                else:
                    field = Field(description=details.get("description", ""))
            else:
                field = Field(description="")

            schema_props[name] = field
            if is_nullable:
                annotations[name] = python_type | None
            else:
                annotations[name] = python_type

        # Create the schema class with proper annotations
        schema_class = type(
            f"{self.name.title()}Args",
            (BaseModel,),
            {
                "__annotations__": annotations,
                "__module__": __name__,
                **schema_props,
            },
        )

        parent_tool = self

        class StackOneLangChainTool(BaseTool):
            name: str = parent_tool.name
            description: str = parent_tool.description
            args_schema: type[BaseModel] = schema_class  # ty: ignore[invalid-assignment]
            func = staticmethod(parent_tool.execute)  # Required by CrewAI

            def _run(self, **kwargs: Any) -> Any:
                return parent_tool.execute(kwargs)

        return StackOneLangChainTool()

    def to_pydantic_ai_tool(self) -> PydanticAITool:
        """Convert this tool to a Pydantic AI ``Tool``.

        Requires ``stackone-ai[pydantic-ai]`` (installs ``pydantic-ai-slim``).

        Returns:
            A ``pydantic_ai.tools.Tool`` ready to pass to ``Agent(tools=[...])``.

        Raises:
            ImportError: If ``pydantic_ai`` is not installed.
        """
        try:
            from pydantic_ai.tools import Tool
        except ImportError as e:
            raise ImportError(
                "Install `pydantic-ai-slim` (or `stackone-ai[pydantic-ai]`) "
                "to use the Pydantic AI integration."
            ) from e

        # Reuse the cleaned OpenAI JSON Schema as the Pydantic AI parameter schema.
        openai_function = self.to_openai_function()
        json_schema = openai_function["function"]["parameters"]
        parent_tool = self

        def implementation(**kwargs: Any) -> Any:
            return parent_tool.execute(kwargs)

        return Tool.from_schema(
            function=implementation,
            name=self.name,
            description=self.description,
            json_schema=json_schema,
        )

    def set_account_id(self, account_id: str | None) -> None:
        """Set the account ID for this tool

        Args:
            account_id: The account ID to use, or None to clear it
        """
        self._account_id = account_id

    def get_account_id(self) -> str | None:
        """Get the current account ID for this tool

        Returns:
            Current account ID or None if not set
        """
        return self._account_id
class Tools:
    """Ordered collection of StackOne tools with lookup by name.

    Supports ``len()``, integer indexing and iteration. It can also convert
    every tool at once to the OpenAI, LangChain or Pydantic AI formats.
    """

    def __init__(
        self,
        tools: list[StackOneTool],
    ) -> None:
        """Keep the given tools and index them by name.

        Args:
            tools: Tools to manage. The sequence is stored as-is; the name index
                is built once, here (a later duplicate name replaces an earlier one).
        """
        self.tools = tools
        name_index: dict[str, StackOneTool] = {}
        for item in tools:
            name_index[item.name] = item
        self._tool_map = name_index

    def __getitem__(self, index: int) -> StackOneTool:
        """Return the tool at position ``index``."""
        return self.tools[index]

    def __len__(self) -> int:
        """Return how many tools are held."""
        return len(self.tools)

    def __iter__(self) -> Any:
        """Iterate over the tools in their original order."""
        return iter(self.tools)

    def to_list(self) -> list[StackOneTool]:
        """Return a new list holding the same tools.

        Returns:
            A shallow copy, so callers may modify it without affecting this collection
        """
        return [*self.tools]

    def get_tool(self, name: str) -> StackOneTool | None:
        """Look a tool up by its exact name.

        Args:
            name: Name of the wanted tool

        Returns:
            The matching tool, or None when no tool has that name
        """
        try:
            return self._tool_map[name]
        except KeyError:
            return None

    def set_account_id(self, account_id: str | None) -> None:
        """Apply one account ID to every tool in the collection.

        Args:
            account_id: Account ID to set, or None to clear it on every tool
        """
        for member in self.tools:
            member.set_account_id(account_id)

    def get_account_id(self) -> str | None:
        """Report the collection's account ID.

        Returns:
            The account ID of the first tool that has one set, else None
        """
        # Lazily scan the tools and stop at the first string account ID.
        candidates = (member.get_account_id() for member in self.tools)
        return next((acct for acct in candidates if isinstance(acct, str)), None)

    def get_connectors(self) -> set[str]:
        """Collect the distinct connector names across all tools.

        Returns:
            Set of lowercase connector names

        Example:
            tools = toolset.fetch_tools()
            connectors = tools.get_connectors()
            # {'bamboohr', 'hibob', 'slack', ...}
        """
        found: set[str] = set()
        for member in self.tools:
            found.add(member.connector)
        return found

    def to_openai(self) -> list[JsonDict]:
        """Render every tool in OpenAI function-calling format.

        Returns:
            One OpenAI function definition per tool, in collection order
        """
        return [member.to_openai_function() for member in self.tools]

    def to_langchain(self) -> Sequence[BaseTool]:
        """Render every tool as a LangChain tool.

        Returns:
            One LangChain tool per StackOne tool, in collection order
        """
        return [member.to_langchain() for member in self.tools]

    def to_pydantic_ai(self) -> list[PydanticAITool]:
        """Render every tool as a Pydantic AI ``Tool``.

        Requires ``stackone-ai[pydantic-ai]`` (installs ``pydantic-ai-slim``).

        Returns:
            List of ``pydantic_ai.tools.Tool`` ready to pass to ``Agent(tools=[...])``.
        """
        return [member.to_pydantic_ai_tool() for member in self.tools]