from functools import cache
from typing import Annotated, Any, Literal, overload

from pydantic import BaseModel, Discriminator, Field, Tag, TypeAdapter
[docs]
class ValueMatcher(BaseModel):
    """Common base for matchers that compare an argument against a ``value``.

    It declares no fields of its own; subclasses such as ``EqualsMatcher`` and
    ``FreeTextMatcher`` add a ``match_as`` discriminator and a ``value``. It is also the
    upper bound of ``OptionalMatcher``'s type parameter.
    """
[docs]
class EqualsMatcher(ValueMatcher):
    """Matches an argument if it equals a given value."""

    # Tag value "equality" selects this model inside discriminated matcher unions.
    match_as: Literal["equality"] = "equality"
    """Discriminator field."""

    # Only scalar values are accepted for an exact comparison.
    value: str | int | float | bool
    """The value the argument must be equal to."""
[docs]
class FreeTextMatcher(ValueMatcher):
    """Matches an argument if semantically it achieves the same goal as the given free text."""

    # Tag value "free_text" selects this model inside discriminated matcher unions.
    match_as: Literal["free_text"] = "free_text"
    """Discriminator field."""

    value: str
    """Free-text description the argument is compared against."""
[docs]
class DateTimeMatcher(ValueMatcher):
    """Matches an argument if it matches a given date and / or time.

    The match is determined as semantic equivalence, given the actual argument value, and the request time as provided
    in the output environment.
    """

    # Tag value "date_time" selects this model inside discriminated matcher unions.
    match_as: Literal["date_time"] = "date_time"
    """Discriminator field."""

    # Natural language rather than a parsed timestamp (e.g. a relative description).
    value: str
    """A natural language description of the date and / or time to compare the argument to."""
[docs]
class EmailMatcher(ValueMatcher):
    """Matches an argument if it matches the given email address."""

    # Tag value "email" selects this model inside discriminated matcher unions.
    match_as: Literal["email"] = "email"
    """Discriminator field."""

    # Stored as a plain string; no email-format validation is applied here.
    value: str
    """The email address the argument is compared against."""
[docs]
class MissingMatcher(BaseModel):
    """Matches an argument if it is missing (not provided)."""

    # Derives from BaseModel rather than ValueMatcher: it carries no `value`, and so it
    # does not satisfy OptionalMatcher's `T: ValueMatcher` bound.
    match_as: Literal["missing"] = "missing"
    """Discriminator field."""
[docs]
class OptionalMatcher[T: ValueMatcher](BaseModel):
"""Matches an argument either if it is missing (not provided) or its value matches the given matcher."""
match_as: Literal["optional"] = "optional"
"""Discriminator field"""
default: T
"""The matcher to use to match the argument."""
[docs]
class OneParameterAssertion(BaseModel):
    """Passes when a single argument matches the given matcher."""

    param: str
    """The name of the parameter to match."""

    matcher: Annotated[
        EqualsMatcher
        | FreeTextMatcher
        | DateTimeMatcher
        | EmailMatcher
        | MissingMatcher
        # Parameterize explicitly, as ParameterGroupAssertion does. A bare `OptionalMatcher`
        # validates `default` against its `ValueMatcher` bound, which declares no fields, so a
        # wrapped matcher loaded from a dict/JSON would lose its type and `value`.
        # MissingMatcher is excluded here because it is not a ValueMatcher.
        | OptionalMatcher[
            Annotated[
                EqualsMatcher | FreeTextMatcher | DateTimeMatcher | EmailMatcher,
                Field(discriminator="match_as"),
            ]
        ],
        Field(discriminator="match_as"),
    ]
    """The matcher representing a passing assertion."""

    @property
    def params(self) -> list[str]:
        """The asserted parameter name, wrapped in a one-element list.

        This gives the same ``params: list[str]`` shape that ``ParameterGroupAssertion`` exposes.
        """
        return [self.param]
[docs]
class ParameterGroupAssertion(BaseModel):
    """Passes when a group of arguments matches the given matcher."""

    params: list[str]
    """The names of the parameters to assert."""

    # Restricted to free-text and date/time matchers, optionally wrapped in an
    # explicitly parameterized OptionalMatcher over those same two kinds.
    matcher: Annotated[
        FreeTextMatcher
        | DateTimeMatcher
        | OptionalMatcher[
            Annotated[
                FreeTextMatcher | DateTimeMatcher,
                Field(discriminator="match_as"),
            ]
        ],
        Field(discriminator="match_as"),
    ]
    """The matcher representing a passing assertion."""
# Matchers that may be wrapped by OptionalMatcher within a group assertion.
GroupMatcherDefault = Annotated[
    FreeTextMatcher | DateTimeMatcher,
    Field(discriminator="match_as"),
]

# Types of the `matcher` argument in the `parameter_assertion` overloads below.
OneParameterMatcher = (
    EqualsMatcher
    | FreeTextMatcher
    | DateTimeMatcher
    | EmailMatcher
    | MissingMatcher
    | OptionalMatcher
)
ParameterGroupMatcher = (
    FreeTextMatcher
    | DateTimeMatcher
    | OptionalMatcher[GroupMatcherDefault]
)
def _parameter_assertion_discriminator(value: Any) -> str:
if isinstance(value, OneParameterAssertion):
return "one"
if isinstance(value, ParameterGroupAssertion):
return "group"
if isinstance(value, dict):
has_param = "param" in value
has_params = "params" in value
if has_param ^ has_params:
return "one" if has_param else "group"
if has_param:
raise ValueError("ParameterAssertion must include exactly one of 'param' or 'params'.")
raise ValueError("ParameterAssertion must include either 'param' or 'params'.")
ParameterAssertion = Annotated[
    Annotated[OneParameterAssertion, Tag("one")] | Annotated[ParameterGroupAssertion, Tag("group")],
    Discriminator(_parameter_assertion_discriminator),
]
"""Discriminated union of either an assertion about a single argument or a group of arguments.

Discriminated by the presence of either the `param` or the `params` field.
"""
# Single-parameter form: the discriminator tags this as "one" (OneParameterAssertion).
@overload
def parameter_assertion(*, param: str, matcher: OneParameterMatcher) -> ParameterAssertion: ...


# Group form: the discriminator tags this as "group" (ParameterGroupAssertion).
@overload
def parameter_assertion(
    *,
    params: list[str],
    matcher: ParameterGroupMatcher,
) -> ParameterAssertion: ...
[docs]
def parameter_assertion(**data: Any) -> ParameterAssertion:
    """A convenience function to create a parameter assertion.

    Pass exactly one of ``param`` (a single parameter name) or ``params`` (a list of
    names), together with a ``matcher``. The result is a ``OneParameterAssertion`` or a
    ``ParameterGroupAssertion`` respectively.

    Examples:
        ```python
        parameter_assertion(param="name", matcher=EqualsMatcher(value="John"))
        parameter_assertion(params=["name", "age"], matcher=FreeTextMatcher(value="John, 20 years old"))
        ```

    Raises:
        ValueError: if ``data`` does not describe a valid parameter assertion
            (pydantic's ``ValidationError`` is a ``ValueError`` subclass).
    """
    return _parameter_assertion_adapter().validate_python(data)


@cache
def _parameter_assertion_adapter() -> TypeAdapter[ParameterAssertion]:
    """Return the shared ``TypeAdapter`` for ``ParameterAssertion``.

    Constructing a ``TypeAdapter`` builds a full validation schema. It is therefore
    created once, on first use, and reused, rather than rebuilt on every
    ``parameter_assertion`` call. This helper is defined after the implementation so
    the ``@overload`` group above stays contiguous.
    """
    return TypeAdapter(ParameterAssertion)
# Union of assertion kinds, discriminated by each model's `assert_that` field.
# NOTE(review): ToolCallAssertion / NoToolCallAssertion are not visible in this excerpt;
# presumably they are defined earlier in the module — confirm.
Assertion = Annotated[
    ToolCallAssertion | NoToolCallAssertion,
    Field(discriminator="assert_that"),
]
[docs]
class Expectations(BaseModel, extra="allow"):
    """Agent expectations description."""

    # With `extra="allow"`, keys not declared below are kept on the instance instead of dropped.
    expected_response: str | None = None
    """The expected response to the user's question.

    If not provided, the agent's response will not be evaluated against any expected response.
    """

    # default_factory gives every instance its own fresh, empty list.
    assertions: list[Assertion] = Field(default_factory=list)
    """Assertions about the agent's output, such as tool call correctness, abstention, guidelines, etc.

    Currently, only tool call correctness assertions are supported.
    """