Source code for mb.entities.expectations

from typing import Annotated, Any, Literal, overload

from pydantic import BaseModel, Discriminator, Field, Tag, TypeAdapter


# Common base for matchers that compare an argument against a concrete value.
# It declares no fields of its own; ``OptionalMatcher`` uses it as the upper
# bound of its type parameter.
class ValueMatcher(BaseModel):
    pass
class EqualsMatcher(ValueMatcher):
    """Matches an argument if it equals a given value."""

    match_as: Literal["equality"] = "equality"
    """Discriminator field; always ``"equality"`` for this matcher."""

    value: str | int | float | bool
    """The value to compare the argument to."""
class FreeTextMatcher(ValueMatcher):
    """Matches an argument if semantically it achieves the same goal as the given free text."""

    match_as: Literal["free_text"] = "free_text"
    """Discriminator field; always ``"free_text"`` for this matcher."""

    value: str
    """The free text to compare the argument to."""
class DateTimeMatcher(ValueMatcher):
    """Matches an argument if it matches a given date and / or time.

    The match is determined as semantic equivalence, given the actual argument value,
    and the request time as provided in the output environment.
    """

    match_as: Literal["date_time"] = "date_time"
    """Discriminator field; always ``"date_time"`` for this matcher."""

    value: str
    """A natural language description of the date and / or time to compare the argument to."""
class EmailMatcher(ValueMatcher):
    """Matches an argument if it matches the given email address."""

    match_as: Literal["email"] = "email"
    """Discriminator field; always ``"email"`` for this matcher."""

    value: str
    """The email address to compare the argument to."""
# NOTE: derives from ``BaseModel`` directly rather than ``ValueMatcher``, so it
# does not satisfy the ``ValueMatcher`` bound of ``OptionalMatcher``'s type parameter.
class MissingMatcher(BaseModel):
    """Matches an argument if it is missing (not provided)."""

    match_as: Literal["missing"] = "missing"
    """Discriminator field; always ``"missing"`` for this matcher."""
[docs] class OptionalMatcher[T: ValueMatcher](BaseModel): """Matches an argument either if it is missing (not provided) or its value matches the given matcher.""" match_as: Literal["optional"] = "optional" """Discriminator field""" default: T """The matcher to use to match the argument."""
class OneParameterAssertion(BaseModel):
    """Passes when a single argument matches the given matcher."""

    param: str
    """The name of the parameter to match."""

    # ``OptionalMatcher`` is parametrized explicitly. Left bare, pydantic
    # validates ``default`` against the type variable's bound (``ValueMatcher``,
    # which declares no fields), so a nested matcher loaded from a dict / JSON
    # document would silently lose its ``match_as`` and ``value``.
    # ``MissingMatcher`` is not a ``ValueMatcher`` and is therefore not a valid
    # ``default``.
    matcher: Annotated[
        EqualsMatcher
        | FreeTextMatcher
        | DateTimeMatcher
        | EmailMatcher
        | MissingMatcher
        | OptionalMatcher[
            Annotated[
                EqualsMatcher | FreeTextMatcher | DateTimeMatcher | EmailMatcher,
                Field(discriminator="match_as"),
            ]
        ],
        Field(discriminator="match_as"),
    ]
    """The matcher representing a passing assertion."""

    @property
    def params(self) -> list[str]:
        """Return the asserted parameter name wrapped in a one-element list.

        Gives this model the same ``params`` attribute that
        ``ParameterGroupAssertion`` declares as a field.
        """
        return [self.param]
class ParameterGroupAssertion(BaseModel):
    """Passes when a group of arguments matches the given matcher."""

    params: list[str]
    """The names of the parameters to assert."""

    # Only the free-text and date/time matchers (directly, or wrapped in an
    # ``OptionalMatcher``) are accepted for a group of parameters.
    matcher: Annotated[
        FreeTextMatcher
        | DateTimeMatcher
        | OptionalMatcher[
            Annotated[
                FreeTextMatcher | DateTimeMatcher,
                Field(discriminator="match_as"),
            ]
        ],
        Field(discriminator="match_as"),
    ]
    """The matcher representing a passing assertion."""
GroupMatcherDefault = Annotated[FreeTextMatcher | DateTimeMatcher, Field(discriminator="match_as")] OneParameterMatcher = ( EqualsMatcher | FreeTextMatcher | DateTimeMatcher | EmailMatcher | MissingMatcher | OptionalMatcher ) ParameterGroupMatcher = FreeTextMatcher | DateTimeMatcher | OptionalMatcher[GroupMatcherDefault] def _parameter_assertion_discriminator(value: Any) -> str: if isinstance(value, OneParameterAssertion): return "one" if isinstance(value, ParameterGroupAssertion): return "group" if isinstance(value, dict): has_param = "param" in value has_params = "params" in value if has_param ^ has_params: return "one" if has_param else "group" if has_param: raise ValueError("ParameterAssertion must include exactly one of 'param' or 'params'.") raise ValueError("ParameterAssertion must include either 'param' or 'params'.") ParameterAssertion = Annotated[ Annotated[OneParameterAssertion, Tag("one")] | Annotated[ParameterGroupAssertion, Tag("group")], Discriminator(_parameter_assertion_discriminator), ] """Discriminated union of either an assertion about a single argument or a group of arguments. Discriminated by the presense of either `param` or `params` field. """ @overload def parameter_assertion(*, param: str, matcher: OneParameterMatcher) -> ParameterAssertion: ... @overload def parameter_assertion(*, params: list[str], matcher: ParameterGroupMatcher) -> ParameterAssertion: ...
def parameter_assertion(**data: Any) -> ParameterAssertion:
    """A convenience function to create a parameter assertion.

    The keyword arguments are validated against the ``ParameterAssertion``
    union: pass ``param`` for a single-argument assertion or ``params`` for a
    group assertion, together with ``matcher``.

    Examples:
        ```python
        parameter_assertion(param="name", matcher=EqualsMatcher(value="John"))
        parameter_assertion(params=["name", "age"], matcher=FreeTextMatcher(value="John, 20 years old"))
        ```

    Raises:
        ValueError: If the arguments name both or neither of ``param`` /
            ``params``, or otherwise fail validation (pydantic's
            ``ValidationError`` is a ``ValueError`` subclass).
    """
    return _PARAMETER_ASSERTION_ADAPTER.validate_python(data)


# Built once and reused: constructing a ``TypeAdapter`` rebuilds the validation
# schema, so creating one per call is wasted work. It is defined after the
# function (and only looked up at call time) so the ``@overload`` stubs stay
# directly adjacent to the implementation.
_PARAMETER_ASSERTION_ADAPTER = TypeAdapter(ParameterAssertion)
class ToolCallAssertion(BaseModel):
    """Assert that a tool call was made with given parameters."""

    # Discriminator tag that selects this model within the ``Assertion`` union.
    assert_that: Literal["tool_called"] = "tool_called"

    # Name of the tool the assertion is about.
    tool: str

    # Assertions on the call's arguments; empty by default.
    parameters: list[ParameterAssertion] = []
class NoToolCallAssertion(BaseModel):
    """Assert that no tool call was made."""

    # Discriminator tag that selects this model within the ``Assertion`` union.
    assert_that: Literal["no_tool_called"] = "no_tool_called"
# Union of every supported assertion model, discriminated by the value of
# each model's ``assert_that`` field.
Assertion = Annotated[
    ToolCallAssertion | NoToolCallAssertion,
    Field(discriminator="assert_that"),
]
# ``extra="allow"`` keeps fields that are not declared below instead of
# rejecting or dropping them.
class Expectations(BaseModel, extra="allow"):
    """Agent expectations description."""

    expected_response: str | None = None
    """The expected response to the user's question.

    If not provided, the agent's response will not be evaluated against any expected response.
    """

    assertions: list[Assertion] = Field(default_factory=list)
    """Assertions about the agent's output, such as tool call correctness, abstention, guidelines, etc...

    Currently, only tool call correctness assertions are supported.
    """