Coverage for stackone_ai / models.py: 98%
251 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-08 18:25 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-08 18:25 +0000
1from __future__ import annotations
3import base64
4import json
5import logging
6from collections.abc import Sequence
7from datetime import datetime, timezone
8from enum import Enum
9from typing import Annotated, Any, ClassVar, TypeAlias, cast
10from urllib.parse import quote
12import httpx
13from langchain_core.tools import BaseTool
14from pydantic import BaseModel, BeforeValidator, Field, PrivateAttr
16# Type aliases for common types
17JsonDict: TypeAlias = dict[str, Any]
18Headers: TypeAlias = dict[str, str]
21logger = logging.getLogger("stackone.tools")
class StackOneError(Exception):
    """Base exception for StackOne errors.

    Adds no behaviour of its own; it exists so callers can catch every error
    raised by this module with a single ``except StackOneError`` clause.
    """
class StackOneAPIError(StackOneError):
    """Raised when the StackOne API returns an error.

    Attributes are set directly on the instance:
        status_code: HTTP status code supplied by the raiser.
        response_body: Whatever body the raiser captured (any type).
    """

    def __init__(self, message: str, status_code: int, response_body: Any) -> None:
        """Store the error details alongside the human-readable message.

        Args:
            message: Text passed to the base exception (becomes ``str(error)``).
            status_code: HTTP status code associated with the failure.
            response_body: Response payload to keep for later inspection.
        """
        self.status_code = status_code
        self.response_body = response_body
        super().__init__(message)
class ParameterLocation(str, Enum):
    """Valid locations for parameters in requests.

    The ``str`` mixin makes every member compare equal to its plain string
    value (for example ``ParameterLocation.QUERY == "query"``).
    """

    HEADER = "header"
    QUERY = "query"
    PATH = "path"
    BODY = "body"
    # File uploads
    FILE = "file"
def validate_method(v: str) -> str:
    """Validate HTTP method is uppercase and supported.

    The value is upper-cased and checked against the supported verbs
    (GET, POST, PUT, DELETE, PATCH).

    Args:
        v: HTTP method name in any letter case. Because this function is used
            as a "before" validator it may receive un-coerced input, so
            non-string values are rejected explicitly.

    Returns:
        The upper-cased method name.

    Raises:
        ValueError: If ``v`` is not a string or names an unsupported method.
    """
    # Raise ValueError (not AttributeError from ``.upper()``) so validation
    # frameworks that translate ValueError can report a normal validation error.
    if not isinstance(v, str):
        raise ValueError(f"HTTP method must be a string, got {type(v).__name__}")

    method = v.upper()
    if method not in {"GET", "POST", "PUT", "DELETE", "PATCH"}:
        raise ValueError(f"Unsupported HTTP method: {method}")
    return method
class ExecuteConfig(BaseModel):
    """Configuration for executing a tool against an API endpoint.

    ``method`` is passed through ``validate_method`` before pydantic's own
    validation runs; the remaining fields are plain declarative fields.
    """

    # Static headers; defaults to a fresh empty dict per instance.
    headers: Headers = Field(
        default_factory=dict,
        description="HTTP headers to include in the request",
    )
    # Required; normalised and checked by validate_method.
    method: Annotated[str, BeforeValidator(validate_method)] = Field(
        description="HTTP method to use",
    )
    url: str = Field(description="API endpoint URL")
    name: str = Field(description="Tool name")
    # Optional; None means no body type was specified.
    body_type: str | None = Field(
        default=None,
        description="Content type for request body",
    )
    # Parameter name -> where that parameter belongs in the request.
    parameter_locations: dict[str, ParameterLocation] = Field(
        default_factory=dict,
        description="Maps parameter names to their location in the request",
    )
class ToolParameters(BaseModel):
    """Schema definition for tool parameters.

    Holds a JSON-Schema-style description: a ``type`` string plus a mapping
    of property names to their individual schemas.
    """

    # Both fields are required (no defaults).
    type: str = Field(description="JSON Schema type")
    properties: JsonDict = Field(description="JSON Schema properties")
class ToolDefinition(BaseModel):
    """Complete definition of a tool including its schema and execution config.

    Bundles the human-readable description, the parameter schema and the
    HTTP execution settings into one record. All three fields are required.
    """

    description: str = Field(description="Tool description")
    parameters: ToolParameters = Field(description="Tool parameter schema")
    execute: ExecuteConfig = Field(description="Tool execution configuration")
class StackOneTool(BaseModel):
    """Base class for all StackOne tools.

    Provides functionality for executing API calls and converting to various
    formats (OpenAI, LangChain).

    Public fields are ``name``, ``description`` and ``parameters``. The
    execution config, API key and account id are kept as private attributes.
    """

    name: str = Field(description="Tool name")
    description: str = Field(description="Tool description")
    parameters: ToolParameters = Field(description="Tool parameters")
    _execute_config: ExecuteConfig = PrivateAttr()
    _api_key: str = PrivateAttr()
    _account_id: str | None = PrivateAttr(default=None)
    # Keys that _split_feedback_options moves out of the tool parameters.
    _FEEDBACK_OPTION_KEYS: ClassVar[set[str]] = {
        "feedback_session_id",
        "feedback_user_id",
        "feedback_metadata",
    }

    def __init__(
        self,
        description: str,
        parameters: ToolParameters,
        _execute_config: ExecuteConfig,
        _api_key: str,
        _account_id: str | None = None,
    ) -> None:
        """Create a tool; the tool name is taken from ``_execute_config.name``.

        Args:
            description: Human-readable tool description.
            parameters: Parameter schema for the tool.
            _execute_config: HTTP execution settings (also supplies the name).
            _api_key: API key used to build the Authorization header.
            _account_id: Optional account id sent as ``x-account-id``.
        """
        super().__init__(
            name=_execute_config.name,
            description=description,
            parameters=parameters,
        )
        self._execute_config = _execute_config
        self._api_key = _api_key
        self._account_id = _account_id

    @classmethod
    def _split_feedback_options(cls, params: JsonDict, options: JsonDict | None) -> tuple[JsonDict, JsonDict]:
        """Separate feedback keys from tool parameters.

        Works on copies, so neither input mapping is mutated. A feedback key
        found in ``params`` is moved into the options only when ``options``
        does not already define it; otherwise it is left in the parameters.

        Args:
            params: Raw tool parameters.
            options: Explicit options, or None.

        Returns:
            Tuple of (remaining_params, feedback_options).
        """
        merged_params = dict(params)
        feedback_options = dict(options or {})
        for key in cls._FEEDBACK_OPTION_KEYS:
            if key in merged_params and key not in feedback_options:
                feedback_options[key] = merged_params.pop(key)
        return merged_params, feedback_options

    def _prepare_headers(self) -> Headers:
        """Prepare headers for the API request

        Returns:
            Headers to use in the request
        """
        # HTTP Basic auth with the API key as username and an empty password.
        auth_string = base64.b64encode(f"{self._api_key}:".encode()).decode()
        headers: Headers = {
            "Authorization": f"Basic {auth_string}",
            "User-Agent": "stackone-python/1.0.0",
        }

        if self._account_id:
            headers["x-account-id"] = self._account_id

        # Add predefined headers (applied last, so they win on key collisions)
        headers.update(self._execute_config.headers)
        return headers

    def _prepare_request_params(self, kwargs: JsonDict) -> tuple[str, JsonDict, JsonDict]:
        """Prepare URL and parameters for the API request

        Each argument is routed by its configured location. Arguments without
        a QUERY/BODY/FILE/PATH location fall back to: fill a matching
        ``{name}`` URL placeholder, else query string for GET/DELETE, else
        request body.

        Args:
            kwargs: Arguments to process

        Returns:
            Tuple of (url, body_params, query_params)
        """
        url = self._execute_config.url
        locations = self._execute_config.parameter_locations
        body_params: JsonDict = {}
        query_params: JsonDict = {}

        for key, value in kwargs.items():
            location = locations.get(key)
            placeholder = f"{{{key}}}"

            if location == ParameterLocation.QUERY:
                query_params[key] = value
            elif location in (ParameterLocation.BODY, ParameterLocation.FILE):
                body_params[key] = value
            elif location == ParameterLocation.PATH or placeholder in url:
                # Safely encode path parameters to prevent SSRF attacks:
                # safe="" also percent-encodes "/" so a value cannot add segments.
                url = url.replace(placeholder, quote(str(value), safe=""))
            # NOTE(review): HEADER-located parameters also reach the fallbacks
            # below and are not added to the request headers — confirm intended.
            elif self._execute_config.method in {"GET", "DELETE"}:
                query_params[key] = value
            else:
                body_params[key] = value

        return url, body_params, query_params

    def execute(
        self, arguments: str | JsonDict | None = None, *, options: JsonDict | None = None
    ) -> JsonDict:
        """Execute the tool with the given parameters

        Args:
            arguments: Tool arguments as string or dict
            options: Execution options (e.g. feedback metadata). Accepted for
                interface compatibility; this method does not currently use it.

        Returns:
            API response as dict. A non-object JSON response is wrapped as
            ``{"result": <value>}``; an empty response body yields ``{}``.

        Raises:
            StackOneAPIError: If the API responds with an error status, or
                with a success status whose body is not valid JSON
            StackOneError: If the request could not be sent
            ValueError: If the arguments are invalid
        """
        # Only argument parsing is guarded here, so a JSON failure while
        # decoding the *response* is never reported as bad arguments.
        if isinstance(arguments, str):
            try:
                parsed_arguments = json.loads(arguments)
            except json.JSONDecodeError as exc:
                raise ValueError(f"Invalid JSON in arguments: {exc}") from exc
        else:
            parsed_arguments = arguments or {}

        if not isinstance(parsed_arguments, dict):
            raise ValueError("Tool arguments must be a JSON object")

        headers = self._prepare_headers()
        url_used, body_params, query_params = self._prepare_request_params(parsed_arguments)

        request_kwargs: dict[str, Any] = {
            "method": self._execute_config.method,
            "url": url_used,
            "headers": headers,
        }

        if body_params:
            body_type = self._execute_config.body_type or "json"
            if body_type == "json":
                request_kwargs["json"] = body_params
            elif body_type == "form":
                request_kwargs["data"] = body_params
            else:
                # The body is still omitted for unknown types, but say so
                # instead of dropping the payload silently.
                logger.warning(
                    "Unsupported body_type %r for tool %s; request body was not sent",
                    body_type,
                    self.name,
                )

        if query_params:
            request_kwargs["params"] = query_params

        try:
            response = httpx.request(**request_kwargs)
            response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            # Prefer the decoded JSON error payload; fall back to raw text.
            response_body: Any = None
            if exc.response.text:
                try:
                    response_body = exc.response.json()
                except ValueError:
                    response_body = exc.response.text
            raise StackOneAPIError(
                str(exc),
                exc.response.status_code,
                response_body,
            ) from exc
        except httpx.RequestError as exc:
            raise StackOneError(f"Request failed: {exc}") from exc

        # Successful responses may have no body (e.g. 204 No Content).
        if not response.content:
            return {}

        try:
            result = response.json()
        except ValueError as exc:
            raise StackOneAPIError(
                f"Response body is not valid JSON: {exc}",
                response.status_code,
                response.text,
            ) from exc

        return cast(JsonDict, result) if isinstance(result, dict) else {"result": result}

    def call(self, *args: Any, options: JsonDict | None = None, **kwargs: Any) -> JsonDict:
        """Call the tool with the given arguments

        This method provides a more intuitive way to execute tools directly.

        Args:
            *args: If a single argument is provided, it's treated as the full arguments dict/string
            **kwargs: Keyword arguments to pass to the tool
            options: Optional execution options, forwarded to ``execute``

        Returns:
            API response as dict

        Raises:
            StackOneAPIError: If the API request fails
            ValueError: If the arguments are invalid

        Examples:
            >>> tool.call({"name": "John", "email": "john@example.com"})
            >>> tool.call(name="John", email="john@example.com")
        """
        if args and kwargs:
            raise ValueError("Cannot provide both positional and keyword arguments")

        if args:
            if len(args) > 1:
                raise ValueError("Only one positional argument is allowed")
            return self.execute(args[0], options=options)

        return self.execute(kwargs if kwargs else None, options=options)

    def to_openai_function(self) -> JsonDict:
        """Convert this tool to OpenAI's function format

        Only the ``type``, ``description`` and ``enum`` keys of each property
        are kept (plus filtered ``items`` for arrays and filtered nested
        ``properties`` for objects). A property is listed as required unless
        it sets a truthy ``nullable``. Property entries that are not dicts
        are emitted as required ``{"type": "string"}``.

        Returns:
            Tool definition in OpenAI function format
        """
        schema_keys = ("type", "description", "enum")
        properties: JsonDict = {}
        required: list[str] = []

        for name, prop in self.parameters.properties.items():
            if not isinstance(prop, dict):
                properties[name] = {"type": "string"}
                required.append(name)
                continue

            # Only keep standard JSON Schema properties
            cleaned_prop: JsonDict = {key: prop[key] for key in schema_keys if key in prop}
            prop_type = cleaned_prop.get("type")

            # Handle array types
            if prop_type == "array" and isinstance(prop.get("items"), dict):
                cleaned_prop["items"] = {k: v for k, v in prop["items"].items() if k in schema_keys}

            # Handle object types; malformed nested entries fall back to a
            # string schema instead of raising AttributeError.
            if prop_type == "object" and isinstance(prop.get("properties"), dict):
                cleaned_prop["properties"] = {
                    k: (
                        {sk: sv for sk, sv in v.items() if sk in schema_keys}
                        if isinstance(v, dict)
                        else {"type": "string"}
                    )
                    for k, v in prop["properties"].items()
                }

            # Handle required fields - if not explicitly nullable
            if not prop.get("nullable", False):
                required.append(name)

            properties[name] = cleaned_prop

        # Create the OpenAI function schema
        parameters: JsonDict = {
            "type": "object",
            "properties": properties,
        }

        # Only include required if there are required fields
        if required:
            parameters["required"] = required

        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": parameters,
            },
        }

    def to_langchain(self) -> BaseTool:
        """Convert this tool to LangChain format

        Builds a pydantic args schema from the parameter properties
        (``number`` -> float, ``integer`` -> int, ``boolean`` -> bool,
        everything else -> str) and wraps ``execute`` in a ``BaseTool``
        subclass.

        Returns:
            Tool in LangChain format
        """
        # Create properly annotated schema for the tool
        schema_props: dict[str, Any] = {}
        annotations: dict[str, Any] = {}

        for name, details in self.parameters.properties.items():
            python_type: type = str  # Default to str
            if isinstance(details, dict):
                type_str = details.get("type", "string")
                if type_str == "number":
                    python_type = float
                elif type_str == "integer":
                    python_type = int
                elif type_str == "boolean":
                    python_type = bool

                field = Field(description=details.get("description", ""))
            else:
                field = Field(description="")

            schema_props[name] = field
            annotations[name] = python_type

        # Create the schema class with proper annotations
        schema_class = type(
            f"{self.name.title()}Args",
            (BaseModel,),
            {
                "__annotations__": annotations,
                "__module__": __name__,
                **schema_props,
            },
        )

        parent_tool = self

        class StackOneLangChainTool(BaseTool):
            name: str = parent_tool.name
            description: str = parent_tool.description
            args_schema: type[BaseModel] = schema_class  # ty: ignore[invalid-assignment]
            func = staticmethod(parent_tool.execute)  # Required by CrewAI

            def _run(self, **kwargs: Any) -> Any:
                return parent_tool.execute(kwargs)

        return StackOneLangChainTool()

    def set_account_id(self, account_id: str | None) -> None:
        """Set the account ID for this tool

        Args:
            account_id: The account ID to use, or None to clear it
        """
        self._account_id = account_id

    def get_account_id(self) -> str | None:
        """Get the current account ID for this tool

        Returns:
            Current account ID or None if not set
        """
        return self._account_id
class Tools:
    """Container for Tool instances with lookup capabilities.

    Keeps the tools in their original order and additionally indexes them by
    name. When several tools share a name, the last one wins in the index.
    """

    def __init__(self, tools: list[StackOneTool]) -> None:
        """Initialize Tools container

        Args:
            tools: List of Tool instances to manage (stored as given, not copied)
        """
        self.tools = tools
        self._tool_map = {entry.name: entry for entry in tools}

    def __getitem__(self, index: int) -> StackOneTool:
        """Return the tool at position ``index``."""
        return self.tools[index]

    def __len__(self) -> int:
        """Return how many tools the container holds."""
        return len(self.tools)

    def __iter__(self) -> Any:
        """Iterate over the tools in order."""
        return iter(self.tools)

    def to_list(self) -> list[StackOneTool]:
        """Return the tools as a new list

        Returns:
            A shallow copy of the managed tools
        """
        return [*self.tools]

    def get_tool(self, name: str) -> StackOneTool | None:
        """Look a tool up by name

        Args:
            name: Name of the tool to retrieve

        Returns:
            The matching tool, or None when no tool has that name
        """
        return self._tool_map.get(name)

    def set_account_id(self, account_id: str | None) -> None:
        """Apply one account ID to every tool in the collection

        Args:
            account_id: The account ID to use, or None to clear it
        """
        for entry in self.tools:
            entry.set_account_id(account_id)

    def get_account_id(self) -> str | None:
        """Report the account ID used by this collection

        Returns:
            The first string account ID found among the tools (in order),
            or None if no tool has one
        """
        # Lazy generator: tools after the first match are never queried.
        candidates = (entry.get_account_id() for entry in self.tools)
        return next((found for found in candidates if isinstance(found, str)), None)

    def to_openai(self) -> list[JsonDict]:
        """Convert every tool to OpenAI function format

        Returns:
            List of tools in OpenAI function format
        """
        converted = []
        for entry in self.tools:
            converted.append(entry.to_openai_function())
        return converted

    def to_langchain(self) -> Sequence[BaseTool]:
        """Convert every tool to LangChain format

        Returns:
            Sequence of tools in LangChain format
        """
        converted = []
        for entry in self.tools:
            converted.append(entry.to_langchain())
        return converted

    def utility_tools(self, hybrid_alpha: float | None = None) -> Tools:
        """Return utility tools for tool discovery and execution

        Builds a search index over the managed tools and returns a new
        collection holding a search tool and an execute tool.

        Args:
            hybrid_alpha: Weight for BM25 in hybrid search (0-1). Passed to
                ``ToolIndex`` unchanged; when None, ``ToolIndex`` chooses its
                own default (documented upstream as
                ``ToolIndex.DEFAULT_HYBRID_ALPHA``, 0.2 — defined outside
                this module).

        Returns:
            Tools collection containing tool_search and tool_execute

        Note:
            This feature is in beta and may change in future versions
        """
        # Imported inside the method rather than at module level.
        from stackone_ai.utility_tools import (
            ToolIndex,
            create_tool_execute,
            create_tool_search,
        )

        search_index = ToolIndex(self.tools, hybrid_alpha=hybrid_alpha)
        return Tools(
            [
                create_tool_search(search_index),
                create_tool_execute(self),
            ]
        )