Coverage for stackone_ai/models.py: 95%
252 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-12-24 09:48 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-12-24 09:48 +0000
1# TODO: Remove when Python 3.9 support is dropped
2from __future__ import annotations
4import base64
5import json
6import logging
7from collections.abc import Sequence
8from datetime import datetime, timezone
9from enum import Enum
10from typing import Annotated, Any, ClassVar, cast
11from urllib.parse import quote
13import httpx
14from langchain_core.tools import BaseTool
15from pydantic import BaseModel, BeforeValidator, Field, PrivateAttr
17# TODO: Remove when Python 3.9 support is dropped
18from typing_extensions import TypeAlias
20# Type aliases for common types
21JsonDict: TypeAlias = dict[str, Any]
22Headers: TypeAlias = dict[str, str]
25logger = logging.getLogger("stackone.tools")
class StackOneError(Exception):
    """Root exception type for errors raised by the StackOne tooling."""
class StackOneAPIError(StackOneError):
    """Error carrying the HTTP status code and body of a failed API response.

    Attributes are set from the constructor arguments:
        status_code: HTTP status code passed by the raiser.
        response_body: Response payload passed by the raiser (any type).
    """

    def __init__(self, message: str, status_code: int, response_body: Any) -> None:
        # Record the response details first; the message itself is handled
        # by the base exception class.
        self.status_code = status_code
        self.response_body = response_body
        super().__init__(message)
class ParameterLocation(str, Enum):
    """Places in an HTTP request where a tool parameter can be sent.

    The ``str`` mixin makes each member compare equal to its plain string
    value (for example ``ParameterLocation.QUERY == "query"``).
    """

    HEADER = "header"
    QUERY = "query"
    PATH = "path"
    BODY = "body"
    # Used for file uploads
    FILE = "file"
def validate_method(v: object) -> str:
    """Normalise an HTTP method name to upper case and check it is supported.

    This function is used as a "before" validator, so it may be handed raw,
    not-yet-validated input. Non-string input is therefore rejected with a
    ``ValueError`` rather than failing with ``AttributeError`` on ``.upper()``.

    Args:
        v: Candidate HTTP method name, in any letter case.

    Returns:
        The upper-cased method name.

    Raises:
        ValueError: If ``v`` is not a string, or is not one of GET, POST,
            PUT, DELETE or PATCH.
    """
    if not isinstance(v, str):
        raise ValueError(f"HTTP method must be a string, got {type(v).__name__}")

    method = v.upper()
    if method not in {"GET", "POST", "PUT", "DELETE", "PATCH"}:
        raise ValueError(f"Unsupported HTTP method: {method}")
    return method
class ExecuteConfig(BaseModel):
    """Describes how a tool call is turned into an HTTP request.

    ``method`` is passed through ``validate_method`` (via ``BeforeValidator``)
    so the stored value is an upper-cased, supported HTTP verb.
    """

    headers: Headers = Field(
        default_factory=dict,
        description="HTTP headers to include in the request",
    )
    method: Annotated[str, BeforeValidator(validate_method)] = Field(
        description="HTTP method to use",
    )
    url: str = Field(description="API endpoint URL")
    name: str = Field(description="Tool name")
    body_type: str | None = Field(
        default=None,
        description="Content type for request body",
    )
    # Per-parameter placement; parameters without an entry are placed by the
    # executing tool's default rules.
    parameter_locations: dict[str, ParameterLocation] = Field(
        default_factory=dict,
        description="Maps parameter names to their location in the request",
    )
class ToolParameters(BaseModel):
    """JSON-Schema-style description of the arguments a tool accepts."""

    # Schema type of the parameter container.
    type: str = Field(description="JSON Schema type")
    # Mapping of parameter name to its schema fragment.
    properties: JsonDict = Field(description="JSON Schema properties")
class ToolDefinition(BaseModel):
    """Bundles a tool's description, parameter schema and execution config."""

    description: str = Field(description="Tool description")
    parameters: ToolParameters = Field(description="Tool parameter schema")
    execute: ExecuteConfig = Field(description="Tool execution configuration")
class StackOneTool(BaseModel):
    """Base class for all StackOne tools.

    Provides functionality for executing API calls and converting the tool
    to other formats (OpenAI function schema, LangChain tool).

    The public model fields (``name``, ``description``, ``parameters``)
    describe the tool; the execution config, API key and account ID are kept
    in private attributes.
    """

    name: str = Field(description="Tool name")
    description: str = Field(description="Tool description")
    parameters: ToolParameters = Field(description="Tool parameters")
    _execute_config: ExecuteConfig = PrivateAttr()
    _api_key: str = PrivateAttr()
    _account_id: str | None = PrivateAttr(default=None)
    # Argument keys that are treated as feedback options, not API parameters,
    # by _split_feedback_options.
    _FEEDBACK_OPTION_KEYS: ClassVar[set[str]] = {
        "feedback_session_id",
        "feedback_user_id",
        "feedback_metadata",
    }

    def __init__(
        self,
        description: str,
        parameters: ToolParameters,
        _execute_config: ExecuteConfig,
        _api_key: str,
        _account_id: str | None = None,
    ) -> None:
        """Create a tool; its name is taken from ``_execute_config.name``.

        Args:
            description: Tool description
            parameters: Tool parameter schema
            _execute_config: How to execute the tool against the API
            _api_key: API key used for the Authorization header
            _account_id: Optional account ID sent as ``x-account-id``
        """
        super().__init__(
            name=_execute_config.name,
            description=description,
            parameters=parameters,
        )
        self._execute_config = _execute_config
        self._api_key = _api_key
        self._account_id = _account_id

    @classmethod
    def _split_feedback_options(cls, params: JsonDict, options: JsonDict | None) -> tuple[JsonDict, JsonDict]:
        """Separate feedback option keys from regular tool parameters

        Neither input mapping is mutated; copies are returned.

        Args:
            params: Tool arguments, possibly containing feedback keys
            options: Explicit options; values here take precedence over
                same-named keys found in ``params``

        Returns:
            Tuple of (params without moved feedback keys, feedback options)
        """
        merged_params = dict(params)
        feedback_options = dict(options or {})
        for key in cls._FEEDBACK_OPTION_KEYS:
            # A key already supplied via options is left untouched in params.
            if key in merged_params and key not in feedback_options:
                feedback_options[key] = merged_params.pop(key)
        return merged_params, feedback_options

    def _prepare_headers(self) -> Headers:
        """Prepare headers for the API request

        Returns:
            Headers to use in the request
        """
        # Basic auth with the API key as the user name and an empty password.
        auth_string = base64.b64encode(f"{self._api_key}:".encode()).decode()
        headers: Headers = {
            "Authorization": f"Basic {auth_string}",
            "User-Agent": "stackone-python/1.0.0",
        }

        if self._account_id:
            headers["x-account-id"] = self._account_id

        # Predefined headers are applied last, so they override the defaults above.
        headers.update(self._execute_config.headers)
        return headers

    def _prepare_request_params(self, kwargs: JsonDict) -> tuple[str, JsonDict, JsonDict]:
        """Prepare URL and parameters for the API request

        Args:
            kwargs: Arguments to process

        Returns:
            Tuple of (url, body_params, query_params)
        """
        url = self._execute_config.url
        body_params: JsonDict = {}
        query_params: JsonDict = {}

        for key, value in kwargs.items():
            param_location = self._execute_config.parameter_locations.get(key)

            if param_location == ParameterLocation.PATH:
                # Percent-encode every reserved character (safe="") so the
                # value cannot introduce extra path segments or a query string.
                encoded_value = quote(str(value), safe="")
                url = url.replace(f"{{{key}}}", encoded_value)
            elif param_location == ParameterLocation.QUERY:
                query_params[key] = value
            elif param_location in (ParameterLocation.BODY, ParameterLocation.FILE):
                body_params[key] = value
            else:
                # Default behavior for keys with no mapping (and for any
                # location not handled above, e.g. HEADER): fill a matching
                # URL placeholder, otherwise choose by HTTP method.
                if f"{{{key}}}" in url:
                    encoded_value = quote(str(value), safe="")
                    url = url.replace(f"{{{key}}}", encoded_value)
                elif self._execute_config.method in {"GET", "DELETE"}:
                    query_params[key] = value
                else:
                    body_params[key] = value

        return url, body_params, query_params

    def execute(
        self, arguments: str | JsonDict | None = None, *, options: JsonDict | None = None
    ) -> JsonDict:
        """Execute the tool with the given parameters

        Args:
            arguments: Tool arguments as a JSON string or dict
            options: Execution options (e.g. feedback metadata). Accepted for
                API compatibility; this method does not currently read it.

        Returns:
            API response as dict. A JSON response that is not an object is
            wrapped as ``{"result": <value>}``.

        Raises:
            StackOneAPIError: If the API responds with an error status
            StackOneError: If the request itself fails (e.g. connection error)
            ValueError: If the arguments are invalid, or the response body is
                not valid JSON
        """
        # Arguments are parsed in their own try block so that only malformed
        # *arguments* are reported as such; response decoding is handled
        # separately below.
        if isinstance(arguments, str):
            try:
                parsed_arguments = json.loads(arguments)
            except json.JSONDecodeError as exc:
                raise ValueError(f"Invalid JSON in arguments: {exc}") from exc
        else:
            parsed_arguments = arguments or {}

        if not isinstance(parsed_arguments, dict):
            raise ValueError("Tool arguments must be a JSON object")

        headers = self._prepare_headers()
        url, body_params, query_params = self._prepare_request_params(parsed_arguments)

        request_kwargs: dict[str, Any] = {
            "method": self._execute_config.method,
            "url": url,
            "headers": headers,
        }

        if body_params:
            body_type = self._execute_config.body_type or "json"
            if body_type == "json":
                request_kwargs["json"] = body_params
            elif body_type == "form":
                request_kwargs["data"] = body_params
            # Any other body_type: no body is attached (unchanged behavior).

        if query_params:
            request_kwargs["params"] = query_params

        try:
            response = httpx.request(**request_kwargs)
            response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            # Prefer the decoded JSON error body; fall back to the raw text.
            response_body = None
            if exc.response.text:
                try:
                    response_body = exc.response.json()
                except json.JSONDecodeError:
                    response_body = exc.response.text
            raise StackOneAPIError(
                str(exc),
                exc.response.status_code,
                response_body,
            ) from exc
        except httpx.RequestError as exc:
            raise StackOneError(f"Request failed: {exc}") from exc

        try:
            result = response.json()
        except json.JSONDecodeError as exc:
            # Still a ValueError (as before), but no longer mislabelled as an
            # arguments problem.
            raise ValueError(f"Invalid JSON in response: {exc}") from exc

        return cast(JsonDict, result) if isinstance(result, dict) else {"result": result}

    def call(self, *args: Any, options: JsonDict | None = None, **kwargs: Any) -> JsonDict:
        """Call the tool with the given arguments

        This method provides a more intuitive way to execute tools directly.

        Args:
            *args: If a single argument is provided, it's treated as the full arguments dict/string
            **kwargs: Keyword arguments to pass to the tool
            options: Optional execution options, forwarded to ``execute``

        Returns:
            API response as dict

        Raises:
            StackOneAPIError: If the API request fails
            ValueError: If the arguments are invalid

        Examples:
            >>> tool.call({"name": "John", "email": "john@example.com"})
            >>> tool.call(name="John", email="john@example.com")
        """
        if args and kwargs:
            raise ValueError("Cannot provide both positional and keyword arguments")

        if args:
            if len(args) > 1:
                raise ValueError("Only one positional argument is allowed")
            return self.execute(args[0], options=options)

        return self.execute(kwargs if kwargs else None, options=options)

    def to_openai_function(self) -> JsonDict:
        """Convert this tool to OpenAI's function format

        Only the ``type``, ``description`` and ``enum`` keys of each property
        are kept (plus filtered ``items`` for arrays and ``properties`` for
        objects). Every property is listed as required unless it is a dict
        with a truthy ``nullable`` entry.

        Returns:
            Tool definition in OpenAI function format
        """
        kept_keys = ("type", "description", "enum")
        properties: JsonDict = {}
        required: list[str] = []

        for name, prop in self.parameters.properties.items():
            if not isinstance(prop, dict):
                # Non-dict schema entries are exposed as required strings.
                properties[name] = {"type": "string"}
                required.append(name)
                continue

            cleaned_prop: JsonDict = {key: prop[key] for key in kept_keys if key in prop}

            # Handle array types
            if cleaned_prop.get("type") == "array" and "items" in prop:
                if isinstance(prop["items"], dict):
                    cleaned_prop["items"] = {k: v for k, v in prop["items"].items() if k in kept_keys}

            # Handle object types
            if cleaned_prop.get("type") == "object" and "properties" in prop:
                cleaned_prop["properties"] = {
                    k: {sk: sv for sk, sv in v.items() if sk in kept_keys}
                    for k, v in prop["properties"].items()
                }

            # Handle required fields - if not explicitly nullable
            if not prop.get("nullable", False):
                required.append(name)

            properties[name] = cleaned_prop

        # Create the OpenAI function schema
        parameters: JsonDict = {
            "type": "object",
            "properties": properties,
        }

        # Only include required if there are required fields
        if required:
            parameters["required"] = required

        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": parameters,
            },
        }

    def to_langchain(self) -> BaseTool:
        """Convert this tool to LangChain format

        An args schema model is generated from the parameter properties:
        ``number`` maps to ``float``, ``integer`` to ``int``, ``boolean`` to
        ``bool`` and everything else to ``str``.

        Returns:
            Tool in LangChain format
        """
        # Create properly annotated schema for the tool
        schema_props: dict[str, Any] = {}
        annotations: dict[str, Any] = {}

        for name, details in self.parameters.properties.items():
            python_type: type = str  # Default to str
            if isinstance(details, dict):
                type_str = details.get("type", "string")
                if type_str == "number":
                    python_type = float
                elif type_str == "integer":
                    python_type = int
                elif type_str == "boolean":
                    python_type = bool

                field = Field(description=details.get("description", ""))
            else:
                field = Field(description="")

            schema_props[name] = field
            annotations[name] = python_type

        # Create the schema class with proper annotations
        schema_class = type(
            f"{self.name.title()}Args",
            (BaseModel,),
            {
                "__annotations__": annotations,
                "__module__": __name__,
                **schema_props,
            },
        )

        parent_tool = self

        class StackOneLangChainTool(BaseTool):
            name: str = parent_tool.name
            description: str = parent_tool.description
            args_schema: type[BaseModel] = schema_class  # ty: ignore[invalid-assignment]
            func = staticmethod(parent_tool.execute)  # Required by CrewAI

            def _run(self, **kwargs: Any) -> Any:
                return parent_tool.execute(kwargs)

        return StackOneLangChainTool()

    def set_account_id(self, account_id: str | None) -> None:
        """Set the account ID for this tool

        Args:
            account_id: The account ID to use, or None to clear it
        """
        self._account_id = account_id

    def get_account_id(self) -> str | None:
        """Get the current account ID for this tool

        Returns:
            Current account ID or None if not set
        """
        return self._account_id
class Tools:
    """Collection of tools supporting indexing, iteration and lookup by name"""

    def __init__(self, tools: list[StackOneTool]) -> None:
        """Build the collection and its name lookup table

        Args:
            tools: Tool instances to manage; the list object itself is stored
        """
        self.tools = tools
        # When names repeat, the tool appearing later in the list wins.
        name_lookup = {}
        for entry in tools:
            name_lookup[entry.name] = entry
        self._tool_map = name_lookup

    def __getitem__(self, index: int) -> StackOneTool:
        return self.tools[index]

    def __len__(self) -> int:
        return len(self.tools)

    def __iter__(self) -> Any:
        """Iterate over the managed tools in order"""
        return iter(self.tools)

    def to_list(self) -> list[StackOneTool]:
        """Return the tools as a new list

        Returns:
            Shallow copy of the managed tools
        """
        return [*self.tools]

    def get_tool(self, name: str) -> StackOneTool | None:
        """Look a tool up by name

        Args:
            name: Name of the wanted tool

        Returns:
            The matching tool, or None when no tool has that name
        """
        return self._tool_map.get(name)

    def set_account_id(self, account_id: str | None) -> None:
        """Apply one account ID to every tool in the collection

        Args:
            account_id: The account ID to use, or None to clear it
        """
        for entry in self.tools:
            entry.set_account_id(account_id)

    def get_account_id(self) -> str | None:
        """Report the account ID in use by this collection

        Returns:
            The first string account ID reported by a tool, in list order,
            or None if no tool has one
        """
        candidates = (entry.get_account_id() for entry in self.tools)
        return next((value for value in candidates if isinstance(value, str)), None)

    def to_openai(self) -> list[JsonDict]:
        """Convert every tool to OpenAI function format

        Returns:
            One OpenAI function definition per tool, in list order
        """
        converted = []
        for entry in self.tools:
            converted.append(entry.to_openai_function())
        return converted

    def to_langchain(self) -> Sequence[BaseTool]:
        """Convert every tool to LangChain format

        Returns:
            One LangChain tool per managed tool, in list order
        """
        converted = []
        for entry in self.tools:
            converted.append(entry.to_langchain())
        return converted

    def meta_tools(self, hybrid_alpha: float | None = None) -> Tools:
        """Return meta tools for tool discovery and execution

        Builds a ``ToolIndex`` over the managed tools and wraps it, together
        with this collection, in a search tool and an execute tool.

        Args:
            hybrid_alpha: Weight for BM25 in hybrid search (0-1), passed
                through to ToolIndex. If not provided, ToolIndex falls back to
                its own default (ToolIndex.DEFAULT_HYBRID_ALPHA, documented
                as 0.2).
                NOTE(review): earlier documentation described this default as
                favouring BM25, which seems at odds with a BM25 weight of 0.2;
                confirm against ToolIndex.

        Returns:
            Tools collection containing meta_search_tools and meta_execute_tool

        Note:
            This feature is in beta and may change in future versions
        """
        # Imported here rather than at module level.
        from stackone_ai.meta_tools import (
            ToolIndex,
            create_meta_execute_tool,
            create_meta_search_tools,
        )

        search_index = ToolIndex(self.tools, hybrid_alpha=hybrid_alpha)
        search_tool = create_meta_search_tools(search_index)
        runner_tool = create_meta_execute_tool(self)

        return Tools([search_tool, runner_tool])