Coverage for stackone_ai / models.py: 97%
262 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-04-02 08:51 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-04-02 08:51 +0000
1from __future__ import annotations
3import base64
4import json
5import logging
6from collections.abc import Sequence
7from datetime import datetime, timezone
8from enum import Enum
9from typing import Annotated, Any, ClassVar, TypeAlias, cast
10from urllib.parse import quote
12import httpx
13from langchain_core.tools import BaseTool
14from pydantic import BaseModel, BeforeValidator, Field, PrivateAttr
16# Type aliases for common types
17JsonDict: TypeAlias = dict[str, Any]
18Headers: TypeAlias = dict[str, str]
21logger = logging.getLogger("stackone.tools")
class StackOneError(Exception):
    """Root of the StackOne exception hierarchy.

    Carries no state of its own; catch it to handle any error raised by
    this module's StackOne-specific code paths.
    """
class StackOneAPIError(StackOneError):
    """Error raised when the StackOne API returns an error response.

    Attributes:
        status_code: HTTP status code passed in by the raiser.
        response_body: Response payload passed in by the raiser (any type).
    """

    def __init__(self, message: str, status_code: int, response_body: Any) -> None:
        # Keep the HTTP details on the instance so callers can inspect them;
        # the human-readable message is stored by the base exception.
        self.status_code = status_code
        self.response_body = response_body
        super().__init__(message)
class ParameterLocation(str, Enum):
    """Valid locations for parameters in requests"""

    # Members are plain strings as well as enum members (str mixin), so they
    # compare equal to their literal values.
    HEADER = "header"
    QUERY = "query"
    PATH = "path"
    BODY = "body"
    # Used for file uploads.
    FILE = "file"
def validate_method(v: str) -> str:
    """Normalize an HTTP method name to upper case and check it is supported.

    Args:
        v: HTTP method name in any letter case (e.g. ``"get"``).

    Returns:
        The upper-cased method name.

    Raises:
        ValueError: If ``v`` is not a string, or is not one of GET, POST,
            PUT, DELETE or PATCH.
    """
    # This runs as a "before" validator, i.e. on raw, unchecked input. Raise
    # ValueError (not AttributeError from ``.upper()``) for non-string input
    # so the failure surfaces as a normal validation error.
    if not isinstance(v, str):
        raise ValueError(f"HTTP method must be a string, got {type(v).__name__}")

    method = v.upper()
    if method not in {"GET", "POST", "PUT", "DELETE", "PATCH"}:
        raise ValueError(f"Unsupported HTTP method: {method}")
    return method
class ExecuteConfig(BaseModel):
    """Configuration for executing a tool against an API endpoint"""

    # Static headers attached to the request.
    headers: Headers = Field(
        default_factory=dict,
        description="HTTP headers to include in the request",
    )
    # Normalized to upper case and checked by validate_method before
    # pydantic's own string validation runs.
    method: Annotated[str, BeforeValidator(validate_method)] = Field(
        description="HTTP method to use",
    )
    # May contain "{param}" placeholders that are filled from call arguments.
    url: str = Field(description="API endpoint URL")
    name: str = Field(description="Tool name")
    # None means "not specified".
    body_type: str | None = Field(
        default=None,
        description="Content type for request body",
    )
    # Explicit routing of individual parameters (header/query/path/body/file).
    parameter_locations: dict[str, ParameterLocation] = Field(
        default_factory=dict,
        description="Maps parameter names to their location in the request",
    )
class ToolParameters(BaseModel):
    """Schema definition for tool parameters"""

    # JSON Schema "type" keyword for the parameter container.
    type: str = Field(
        description="JSON Schema type",
    )
    # Mapping of parameter name to its JSON Schema fragment.
    properties: JsonDict = Field(
        description="JSON Schema properties",
    )
class ToolDefinition(BaseModel):
    """Complete definition of a tool including its schema and execution config"""

    # Human-readable summary of the tool.
    description: str = Field(
        description="Tool description",
    )
    # Schema of the arguments the tool accepts.
    parameters: ToolParameters = Field(
        description="Tool parameter schema",
    )
    # How to turn a call into an HTTP request.
    execute: ExecuteConfig = Field(
        description="Tool execution configuration",
    )
class StackOneTool(BaseModel):
    """Base class for all StackOne tools. Provides functionality for executing API calls
    and converting to various formats (OpenAI, LangChain)."""

    name: str = Field(description="Tool name")
    description: str = Field(description="Tool description")
    parameters: ToolParameters = Field(description="Tool parameters")
    # Private state: how to call the API, and the credentials to call it with.
    _execute_config: ExecuteConfig = PrivateAttr()
    _api_key: str = PrivateAttr()
    _account_id: str | None = PrivateAttr(default=None)
    # Argument keys that _split_feedback_options treats as feedback options.
    _FEEDBACK_OPTION_KEYS: ClassVar[set[str]] = {
        "feedback_session_id",
        "feedback_user_id",
        "feedback_metadata",
    }

    @property
    def connector(self) -> str:
        """Extract connector from tool name.

        Tool names follow the format: {connector}_{action}_{entity}
        e.g., 'bamboohr_create_employee' -> 'bamboohr'

        Returns:
            Connector name in lowercase
        """
        return self.name.split("_")[0].lower()

    def __init__(
        self,
        description: str,
        parameters: ToolParameters,
        _execute_config: ExecuteConfig,
        _api_key: str,
        _account_id: str | None = None,
    ) -> None:
        """Create a tool; its public name is taken from ``_execute_config.name``.

        Args:
            description: Tool description.
            parameters: Tool parameter schema.
            _execute_config: Request configuration (method, URL, headers, ...).
            _api_key: API key used to build the Authorization header.
            _account_id: Optional account ID sent as ``x-account-id``.
        """
        super().__init__(
            name=_execute_config.name,
            description=description,
            parameters=parameters,
        )
        # Private attributes can only be assigned after BaseModel.__init__ ran.
        self._execute_config = _execute_config
        self._api_key = _api_key
        self._account_id = _account_id

    @classmethod
    def _split_feedback_options(cls, params: JsonDict, options: JsonDict | None) -> tuple[JsonDict, JsonDict]:
        """Separate feedback option keys from regular tool parameters.

        Keys listed in ``_FEEDBACK_OPTION_KEYS`` are moved out of ``params``
        unless ``options`` already provides them (``options`` wins and the key
        then stays in the returned parameters). Inputs are not mutated.

        Args:
            params: Tool parameters, possibly containing feedback keys.
            options: Explicit feedback options, or None.

        Returns:
            Tuple of (remaining parameters, feedback options).
        """
        merged_params = dict(params)
        feedback_options = dict(options or {})
        for key in cls._FEEDBACK_OPTION_KEYS:
            if key in merged_params and key not in feedback_options:
                feedback_options[key] = merged_params.pop(key)
        return merged_params, feedback_options

    def _prepare_headers(self) -> Headers:
        """Prepare headers for the API request

        Returns:
            Headers to use in the request
        """
        # HTTP Basic auth with the API key as user name and an empty password.
        auth_string = base64.b64encode(f"{self._api_key}:".encode()).decode()
        headers: Headers = {
            "Authorization": f"Basic {auth_string}",
            "User-Agent": "stackone-python/1.0.0",
        }

        if self._account_id:
            headers["x-account-id"] = self._account_id

        # Predefined headers are applied last, so they override the defaults above.
        headers.update(self._execute_config.headers)
        return headers

    def _prepare_request_params(self, kwargs: JsonDict) -> tuple[str, JsonDict, JsonDict]:
        """Prepare URL and parameters for the API request

        Each argument is routed by its configured location. Arguments without
        a handled location fill a matching ``{name}`` URL placeholder if one
        exists; otherwise they go to the query string for GET/DELETE and to
        the body for every other method.

        Args:
            kwargs: Arguments to process

        Returns:
            Tuple of (url, body_params, query_params)
        """
        url = self._execute_config.url
        body_params: JsonDict = {}
        query_params: JsonDict = {}

        for key, value in kwargs.items():
            param_location = self._execute_config.parameter_locations.get(key)
            placeholder = f"{{{key}}}"

            if param_location == ParameterLocation.PATH:
                # Percent-encode everything (including "/") so a value cannot
                # inject extra path segments or a query string.
                url = url.replace(placeholder, quote(str(value), safe=""))
            elif param_location == ParameterLocation.QUERY:
                query_params[key] = value
            elif param_location in (ParameterLocation.BODY, ParameterLocation.FILE):
                body_params[key] = value
            # Default behavior for parameters without a handled location.
            # NOTE(review): ParameterLocation.HEADER has no branch of its own
            # and therefore also ends up here - confirm this is intended.
            elif placeholder in url:
                url = url.replace(placeholder, quote(str(value), safe=""))
            elif self._execute_config.method in {"GET", "DELETE"}:
                query_params[key] = value
            else:
                body_params[key] = value

        return url, body_params, query_params

    def execute(
        self, arguments: str | JsonDict | None = None, *, options: JsonDict | None = None
    ) -> JsonDict:
        """Execute the tool with the given parameters

        Args:
            arguments: Tool arguments as a JSON string or dict; None means no arguments
            options: Execution options (e.g. feedback metadata). Accepted for
                interface compatibility; this method does not currently read it.

        Returns:
            API response as dict. A JSON response that is not an object is
            wrapped as ``{"result": <value>}``.

        Raises:
            StackOneAPIError: If the API responds with an error status
            StackOneError: If the request itself fails (httpx.RequestError)
            ValueError: If the arguments are not a JSON object, or the
                response body is not valid JSON
        """
        # Decode the arguments. The try is kept this narrow on purpose so that
        # only argument decoding is reported as an argument error.
        if isinstance(arguments, str):
            try:
                parsed_arguments = json.loads(arguments)
            except json.JSONDecodeError as exc:
                raise ValueError(f"Invalid JSON in arguments: {exc}") from exc
        else:
            parsed_arguments = arguments or {}

        if not isinstance(parsed_arguments, dict):
            raise ValueError("Tool arguments must be a JSON object")

        headers = self._prepare_headers()
        url, body_params, query_params = self._prepare_request_params(parsed_arguments)

        request_kwargs: dict[str, Any] = {
            "method": self._execute_config.method,
            "url": url,
            "headers": headers,
        }

        if body_params:
            body_type = self._execute_config.body_type or "json"
            if body_type == "json":
                request_kwargs["json"] = body_params
            elif body_type == "form":
                request_kwargs["data"] = body_params
            # Any other body_type: body parameters are not sent.

        if query_params:
            request_kwargs["params"] = query_params

        try:
            response = httpx.request(**request_kwargs)
            response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            # Attach the decoded error payload when there is one; fall back to
            # the raw text if the body is not JSON.
            response_body: Any = None
            if exc.response.text:
                try:
                    response_body = exc.response.json()
                except json.JSONDecodeError:
                    response_body = exc.response.text
            raise StackOneAPIError(
                str(exc),
                exc.response.status_code,
                response_body,
            ) from exc
        except httpx.RequestError as exc:
            raise StackOneError(f"Request failed: {exc}") from exc

        try:
            result = response.json()
        except json.JSONDecodeError as exc:
            # Same exception type as before, but the message now names the
            # response (not the arguments) as the source of the bad JSON.
            raise ValueError(f"Invalid JSON in response: {exc}") from exc

        return cast(JsonDict, result) if isinstance(result, dict) else {"result": result}

    def call(self, *args: Any, options: JsonDict | None = None, **kwargs: Any) -> JsonDict:
        """Call the tool with the given arguments

        This method provides a more intuitive way to execute tools directly.

        Args:
            *args: If a single argument is provided, it's treated as the full arguments dict/string
            **kwargs: Keyword arguments to pass to the tool
            options: Optional execution options, forwarded to :meth:`execute`

        Returns:
            API response as dict

        Raises:
            StackOneAPIError: If the API request fails
            ValueError: If the arguments are invalid

        Examples:
            >>> tool.call({"name": "John", "email": "john@example.com"})
            >>> tool.call(name="John", email="john@example.com")
        """
        if args and kwargs:
            raise ValueError("Cannot provide both positional and keyword arguments")

        if args:
            if len(args) > 1:
                raise ValueError("Only one positional argument is allowed")
            return self.execute(args[0], options=options)

        return self.execute(kwargs if kwargs else None, options=options)

    def __call__(self, *args: Any, options: JsonDict | None = None, **kwargs: Any) -> JsonDict:
        """Make the tool directly callable.

        Alias for :meth:`call` so that ``tool(query="…")`` works.
        """
        return self.call(*args, options=options, **kwargs)

    def to_openai_function(self) -> JsonDict:
        """Convert this tool to OpenAI's function format

        Only the ``type``, ``description`` and ``enum`` keywords are kept for
        each property (plus filtered ``items`` / ``properties`` for arrays and
        objects). Every property not marked ``nullable`` is listed as required.

        Returns:
            Tool definition in OpenAI function format
        """
        properties = {}
        required = []

        for name, prop in self.parameters.properties.items():
            if isinstance(prop, dict):
                # Only keep standard JSON Schema properties
                cleaned_prop = {}

                if "type" in prop:
                    cleaned_prop["type"] = prop["type"]
                if "description" in prop:
                    cleaned_prop["description"] = prop["description"]
                if "enum" in prop:
                    cleaned_prop["enum"] = prop["enum"]

                # Handle array types
                if cleaned_prop.get("type") == "array" and "items" in prop:
                    if isinstance(prop["items"], dict):
                        cleaned_prop["items"] = {
                            k: v for k, v in prop["items"].items() if k in ("type", "description", "enum")
                        }

                # Handle object types (one level of nested properties)
                if cleaned_prop.get("type") == "object" and "properties" in prop:
                    cleaned_prop["properties"] = {
                        k: {sk: sv for sk, sv in v.items() if sk in ("type", "description", "enum")}
                        for k, v in prop["properties"].items()
                    }

                # Handle required fields - if not explicitly nullable
                if not prop.get("nullable", False):
                    required.append(name)

                properties[name] = cleaned_prop
            else:
                # Non-dict schema entries are exposed as required strings.
                properties[name] = {"type": "string"}
                required.append(name)

        # Create the OpenAI function schema
        parameters = {
            "type": "object",
            "properties": properties,
        }

        # Only include required if there are required fields
        if required:
            parameters["required"] = required

        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": parameters,
            },
        }

    def to_langchain(self) -> BaseTool:
        """Convert this tool to LangChain format

        Builds a pydantic args schema from the tool's parameter properties and
        wraps :meth:`execute` in a ``BaseTool`` subclass.

        Returns:
            Tool in LangChain format
        """
        # Create properly annotated schema for the tool
        schema_props: dict[str, Any] = {}
        annotations: dict[str, Any] = {}

        for name, details in self.parameters.properties.items():
            python_type: type = str  # Default to str
            is_nullable = False
            if isinstance(details, dict):
                type_str = details.get("type", "string")
                is_nullable = details.get("nullable", False)
                if type_str == "number":
                    python_type = float
                elif type_str == "integer":
                    python_type = int
                elif type_str == "boolean":
                    python_type = bool
                elif type_str == "object":
                    python_type = dict
                elif type_str == "array":
                    python_type = list

                # Nullable parameters become optional fields defaulting to None.
                if is_nullable:
                    field = Field(default=None, description=details.get("description", ""))
                else:
                    field = Field(description=details.get("description", ""))
            else:
                field = Field(description="")

            schema_props[name] = field
            if is_nullable:
                annotations[name] = python_type | None
            else:
                annotations[name] = python_type

        # Create the schema class with proper annotations
        schema_class = type(
            f"{self.name.title()}Args",
            (BaseModel,),
            {
                "__annotations__": annotations,
                "__module__": __name__,
                **schema_props,
            },
        )

        parent_tool = self

        class StackOneLangChainTool(BaseTool):
            name: str = parent_tool.name
            description: str = parent_tool.description
            args_schema: type[BaseModel] = schema_class  # ty: ignore[invalid-assignment]
            func = staticmethod(parent_tool.execute)  # Required by CrewAI

            def _run(self, **kwargs: Any) -> Any:
                return parent_tool.execute(kwargs)

        return StackOneLangChainTool()

    def set_account_id(self, account_id: str | None) -> None:
        """Set the account ID for this tool

        Args:
            account_id: The account ID to use, or None to clear it
        """
        self._account_id = account_id

    def get_account_id(self) -> str | None:
        """Get the current account ID for this tool

        Returns:
            Current account ID or None if not set
        """
        return self._account_id
class Tools:
    """Container for Tool instances with lookup capabilities"""

    def __init__(
        self,
        tools: list[StackOneTool],
    ) -> None:
        """Store the given tools and index them by name.

        Args:
            tools: List of Tool instances to manage
        """
        self.tools = tools
        # Name -> tool index; when names repeat, the later tool wins.
        by_name = {}
        for item in tools:
            by_name[item.name] = item
        self._tool_map = by_name

    def __getitem__(self, index: int) -> StackOneTool:
        """Return the tool stored at position ``index``."""
        return self.tools[index]

    def __len__(self) -> int:
        """Return how many tools the container holds."""
        return len(self.tools)

    def __iter__(self) -> Any:
        """Iterate over the contained tools in order."""
        return iter(self.tools)

    def to_list(self) -> list[StackOneTool]:
        """Return the tools as a new list.

        Returns:
            A shallow copy of the internal tool list
        """
        return [*self.tools]

    def get_tool(self, name: str) -> StackOneTool | None:
        """Look up a tool by name.

        Args:
            name: Name of the tool to retrieve

        Returns:
            The matching tool, or None when no tool has that name
        """
        return self._tool_map.get(name)

    def set_account_id(self, account_id: str | None) -> None:
        """Apply one account ID to every tool in the collection.

        Args:
            account_id: The account ID to use, or None to clear it
        """
        for member in self.tools:
            member.set_account_id(account_id)

    def get_account_id(self) -> str | None:
        """Report the account ID in use by this collection.

        Returns:
            The first string account ID found among the tools (in order),
            or None if no tool has one
        """
        candidates = (member.get_account_id() for member in self.tools)
        return next((found for found in candidates if isinstance(found, str)), None)

    def get_connectors(self) -> set[str]:
        """Collect the distinct connector names of all tools.

        Returns:
            Set of each tool's ``connector`` value

        Example:
            tools = toolset.fetch_tools()
            connectors = tools.get_connectors()
            # {'bamboohr', 'hibob', 'slack', ...}
        """
        return set(member.connector for member in self.tools)

    def to_openai(self) -> list[JsonDict]:
        """Convert every tool to OpenAI function format.

        Returns:
            List of tools in OpenAI function format, in container order
        """
        converted = []
        for member in self.tools:
            converted.append(member.to_openai_function())
        return converted

    def to_langchain(self) -> Sequence[BaseTool]:
        """Convert every tool to LangChain format.

        Returns:
            Sequence of tools in LangChain format, in container order
        """
        converted = []
        for member in self.tools:
            converted.append(member.to_langchain())
        return converted