Tool Selection Evaluation
Tool selection refers to the ability of an LLM to select the appropriate tools from a list in order to respond to a user query.
This notebook walks through how to measure tool-selection precision. It then adds an optional follow-up step that tries to automatically rewrite the tool descriptions to address errors found in the first pass.
We will use a subset of the ToolBench dataset in these examples.
%pip install -U langchain langchain_openai
import os
# Update with your API URL if using a self-hosted instance of LangSmith.
os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
os.environ["LANGCHAIN_API_KEY"] = "YOUR API KEY" # Update with your API key
# Optional: "default" is used if not set
os.environ["LANGCHAIN_PROJECT"] = "Tool Selection"
dev_dataset_url = (
"https://smith.langchain.com/public/bdf7611c-3420-4c71-a492-42715a32d61e/d"
)
dataset_name = "Tool Selection (Logistics) dev"
import langsmith
client = langsmith.Client()
client.clone_public_dataset(dev_dataset_url)
Define Metrics
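We will compute the precision of the tools selected for the first logical step. Precision is the fraction of predicted tools that appear in the expected tool set. It does not penalize expected tools that the model failed to call.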
from typing import Set
from langchain.smith import RunEvalConfig
from langsmith.evaluation import run_evaluator
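@run_evaluator
def selected_tools_precision(run, example):
    # Reference: a list of tool-name lists; flatten it into one set of expected tools
    expected_tools: Set[str] = {
        tool for tools in example.outputs["expected"] for tool in tools
    }
    # Prediction: JsonOutputToolsParser returns [{"type": <tool name>, "args": {...}}, ...]
    predicted_tools: Set[str] = {tool["type"] for tool in run.outputs["output"]}
    true_positives = predicted_tools & expected_tools
    if len(predicted_tools) == 0:
        # Selecting no tools is only correct when no tools were expected
        score = 1 if len(expected_tools) == 0 else 0
    else:
        # Precision: fraction of selected tools that were expected
        score = len(true_positives) / len(predicted_tools)
    return {"key": "tool_selection_precision", "score": score}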
eval_config = RunEvalConfig(
custom_evaluators=[selected_tools_precision],
)
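To sanity-check the scoring rule offline, here is a minimal sketch that mirrors `selected_tools_precision` on plain Python data. It assumes, as the evaluator's flattening suggests, that the reference is a list of per-step tool-name lists. `tool_precision` is a hypothetical helper, not part of LangSmith.

```python
from typing import Iterable, List, Set


def tool_precision(predicted: Iterable[str], expected_steps: List[List[str]]) -> float:
    """Same rule as selected_tools_precision, without LangSmith run/example objects."""
    # Flatten the per-step expected tool lists into one set, as the evaluator does
    expected: Set[str] = {tool for step in expected_steps for tool in step}
    predicted_set: Set[str] = set(predicted)
    if not predicted_set:
        # Calling no tools is correct only when no tools were expected
        return 0.0 if expected else 1.0
    return len(predicted_set & expected) / len(predicted_set)


# One expected tool plus one spurious tool
print(tool_precision(["CEPBrazil", "SQUAKE"], [["CEPBrazil"], ["PackAndSend"]]))  # 0.5
# Missing an expected tool is not penalized by precision
print(tool_precision(["CEPBrazil"], [["CEPBrazil", "PackAndSend"]]))  # 1.0
print(tool_precision([], []))  # 1.0
```

The second call shows why a precision-only metric can look optimistic: a model that calls one correct tool and skips the rest still scores 1.0.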
Create model
We will perform simple function calling using the tools appropriate for this dataset.
import json
with open("./data/tools.json") as f:
tools = json.load(f)
# Example tool
tools[0]
{'type': 'function',
'function': {'name': 'TransportistasdeArgentina',
'description': 'Quote for postcode in OCA e-Pack.',
'parameters': {'type': 'object',
'properties': {'postCodeDst': {'type': 'number',
'description': 'Postcode Destination'},
'cuit': {'type': 'string',
'description': 'CUIT of your account in OCA e-Pack'},
'operativa': {'type': 'string',
'description': 'Operativa number of your account in OCA e-Pack'},
'cost': {'type': 'number', 'description': 'Cost of products in ARS'},
'postCodeSrc': {'type': 'number', 'description': 'Postcode Source'},
'volume': {'type': 'number', 'description': 'Volume in cm3'},
'weight': {'type': 'number', 'description': 'Weight in KG'}},
'required': ['postCodeDst',
'cuit',
'operativa',
'cost',
'postCodeSrc',
'volume',
'weight']}}}
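Tool selection depends on what the model sees in these definitions, so it can help to lint `tools.json` before evaluating. Below is a minimal sketch, assuming every entry follows the OpenAI-style shape shown above. `check_tool_schema` is a hypothetical helper, not a LangChain or OpenAI API.

```python
from typing import Any, Dict, List


def check_tool_schema(tool: Dict[str, Any]) -> List[str]:
    """Return a list of problems found in one OpenAI-style tool definition."""
    problems: List[str] = []
    if tool.get("type") != "function":
        problems.append("top-level 'type' should be 'function'")
    fn = tool.get("function", {})
    name = fn.get("name", "?")
    if not fn.get("name"):
        problems.append("missing function name")
    if not fn.get("description"):
        problems.append(f"{name}: empty description")
    params = fn.get("parameters", {})
    properties = params.get("properties", {})
    # Every required parameter should also be declared under 'properties'
    for req in params.get("required", []):
        if req not in properties:
            problems.append(f"{name}: required '{req}' not in properties")
    return problems


example = {
    "type": "function",
    "function": {
        "name": "TransportistasdeArgentina",
        "description": "Quote for postcode in OCA e-Pack.",
        "parameters": {
            "type": "object",
            "properties": {"postCodeDst": {"type": "number"}},
            "required": ["postCodeDst", "weight"],
        },
    },
}
print(check_tool_schema(example))
# ["TransportistasdeArgentina: required 'weight' not in properties"]
```

In the notebook, `[p for t in tools for p in check_tool_schema(t)]` would collect problems across every loaded tool.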
from langchain_core.output_parsers.openai_tools import JsonOutputToolsParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
model = "gpt-3.5-turbo"
assistant_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"You are a helpful assistant. Respond to the user's query using the provided tools",
),
("user", "{query}"),
]
)
llm = ChatOpenAI(model=model).bind_tools(tools)
chain = assistant_prompt | llm | JsonOutputToolsParser()
Evaluate
test_results = client.run_on_dataset(
dataset_name=dataset_name,
llm_or_chain_factory=chain,
evaluation=eval_config,
verbose=True,
project_metadata={
"model": model,
"tool_variant": 0,
},
)
View the evaluation results for project 'clear-jet-37' at:
https://smith.langchain.com/o/30239cd8-922f-4722-808d-897e1e722845/datasets/462d8386-60c8-4cb3-84eb-6efeae3a1293/compare?selectedSessions=8b95a94e-c05f-4ecf-b749-aeaef3ff3327
View all tests for Dataset Tool Selection (Logistics) dev at:
https://smith.langchain.com/o/30239cd8-922f-4722-808d-897e1e722845/datasets/462d8386-60c8-4cb3-84eb-6efeae3a1293
[------------------------------------------------->] 100/100
Experiment Results:
| | feedback.tool_selection_precision | error | execution_time | run_id |
|---|---|---|---|---|
| count | 100.000000 | 0 | 100.000000 | 100 |
| unique | NaN | 0 | NaN | 100 |
| top | NaN | NaN | NaN | 827e2f98-bcb1-4940-aa16-5a7d0eca80ff |
| freq | NaN | NaN | NaN | 1 |
| mean | 0.636667 | NaN | 1.417737 | NaN |
| std | 0.370322 | NaN | 0.581734 | NaN |
| min | 0.000000 | NaN | 0.468482 | NaN |
| 25% | 0.333333 | NaN | 1.141958 | NaN |
| 50% | 0.500000 | NaN | 1.331713 | NaN |
| 75% | 1.000000 | NaN | 1.576078 | NaN |
| max | 1.000000 | NaN | 4.320643 | NaN |
After evaluating, we recommend reviewing the results and manually identifying issues you can fix. This is a noisy dataset that we haven't cleaned yet, so you will likely want to correct the labels before treating them as ground truth.
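When reviewing failures by hand, a useful first cut is to count which tools the model picks wrongly and which expected tools it misses. Below is a minimal sketch over plain records. The `failures` list is made up for illustration. In the notebook, you would build it from `test_results.to_dataframe()`, and the column names there are not assumed here.

```python
from collections import Counter
from typing import Dict, List, Tuple


def confusion_counts(rows: List[Dict[str, List[str]]]) -> Tuple[Counter, Counter]:
    """Count tools that were wrongly selected and expected tools that were missed."""
    wrongly_selected: Counter = Counter()
    missed: Counter = Counter()
    for row in rows:
        predicted, expected = set(row["predicted"]), set(row["expected"])
        wrongly_selected.update(predicted - expected)  # false positives
        missed.update(expected - predicted)  # false negatives
    return wrongly_selected, missed


# Hypothetical failure records, for illustration only
failures = [
    {"predicted": ["PackAndSend"], "expected": ["CEPBrazil"]},
    {"predicted": ["PackAndSend", "suivi-colis"], "expected": ["suivi-colis"]},
    {"predicted": ["TrackingMore_v2"], "expected": ["CEPBrazil"]},
]
wrong, missed = confusion_counts(failures)
print(wrong.most_common(2))  # [('PackAndSend', 2), ('TrackingMore_v2', 1)]
print(missed.most_common(1))  # [('CEPBrazil', 2)]
```

Tools that rank high in both counters are good candidates for clearer descriptions or for a closer look at their labels.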
If you want to try something more automated (but less reliable), read on.
Prompt Improver
We'll take the lazy approach and ask an LLM to recommend an improved set of tool descriptions that it "thinks" will improve tool selection.
It's a basic map-reduce-style operation:
- Map: ask an LLM to propose description updates for each failed case
- Reduce: group the proposed updates by tool name
- Distill: merge each tool's candidate updates into one final description
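The three steps can be sketched in plain Python with the LLM calls swapped out for stand-in functions. `improve_descriptions`, `fake_propose`, and `fake_distill` are hypothetical names used only to show the data flow. They are not the chains defined in this notebook.

```python
from collections import defaultdict
from typing import Callable, Dict, List, Tuple

# An update is a (tool_name, proposed_description) pair
Update = Tuple[str, str]


def improve_descriptions(
    failures: List[dict],
    propose: Callable[[dict], List[Update]],
    distill: Callable[[str, List[str]], str],
) -> Dict[str, str]:
    # Map: each failed case yields zero or more proposed updates
    proposals = [update for failure in failures for update in propose(failure)]
    # Reduce: group candidate descriptions by tool name
    by_tool: Dict[str, List[str]] = defaultdict(list)
    for name, description in proposals:
        by_tool[name].append(description)
    # Distill: collapse each tool's candidates into one final description
    return {name: distill(name, candidates) for name, candidates in by_tool.items()}


# Stand-ins for the LLM calls, for illustration only
def fake_propose(failure: dict) -> List[Update]:
    return [(tool, f"Not for: {failure['query']}") for tool in failure["predicted"]]


def fake_distill(name: str, candidates: List[str]) -> str:
    return " ".join(sorted(set(candidates)))


result = improve_descriptions(
    [
        {"query": "hotels in Rio", "predicted": ["CEPBrazil"]},
        {"query": "track my parcel", "predicted": ["CEPBrazil", "SQUAKE"]},
    ],
    fake_propose,
    fake_distill,
)
print(result)
```

Separating the steps this way means only the map and distill steps need LLM calls. The reduce step is an ordinary group-by.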
from typing import List
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import chain as as_runnable
from langchain_openai import ChatOpenAI
class FunctionUpdate(BaseModel):
    name: str = Field(
        description="The name of the tool whose description you'd like to update"
    )
    updated_description: str = Field(
        description="The updated description that would make it clear when and why to invoke this function."
    )


class ImproveToolDocumentation(BaseModel):
    """Called to update the docstrings and other information about a given tool
    so that the user has an easier time selecting."""

    updates: List[FunctionUpdate] = Field(
        description="The updates to make, one for each tool description you'd like to change"
    )
improver_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"You are an API documentation assistant tasked with meticulously improving the descriptions of our API docs."
" Our AI assistant is trying to assist users by calling APIs, but it continues to invoke the wrong ones."
" You must improve their documentation to remove ambiguity so that the assistant will no longer make any mistakes.\n\n"
"## Valid APIs\nBelow are the existing APIs the assistant is choosing between:\n```apis.json\n{apis}\n```\n\n"
"## Failure Case\nBelow is a user query, expected API calls, and actual API calls."
" Use this failure case to make motivated doc changes.\n\n```failure_case.json\n{failure}\n```",
),
(
"user",
"Respond with the updated tool descriptions to clear up"
" whatever ambiguity caused the failure case above."
" Feel free to mention what it is NOT appropriate for (if that's causing issues), like 'don't use this for x'."
" The updated description should reflect WHY the assistant got it wrong in the first place.",
),
]
)
llm = ChatOpenAI(model="gpt-3.5-turbo").with_structured_output(ImproveToolDocumentation)
improver_chain = improver_prompt | llm
apis = json.dumps(tools, indent=2)
df = test_results.to_dataframe()
# Filter out success cases
df = df[df["feedback.tool_selection_precision"] < 1]
def format_inputs(series):
    return {
        "apis": apis,
        "failure": json.dumps(
            {
                "query": series["inputs.query"],
                "predicted": [out["type"] for out in series["output"]],
                "expected": series["reference.expected"][0],
            }
        ),
    }
improver_inputs = df.apply(format_inputs, axis=1).tolist()
Map errors -> updates
# This is the basic "Map" step
all_updates = improver_chain.batch(improver_inputs, return_exceptions=True)
# Just in case one of the runs failed (OAI downtime, LLM error, etc.)
all_updates = [u for u in all_updates if isinstance(u, ImproveToolDocumentation)]
Reduce updates per tool
from collections import defaultdict
toolwise_updates = defaultdict(list)
for updates in all_updates:
    for tool_update in updates.updates:
        toolwise_updates[tool_update.name].append(tool_update.updated_description)
Distill updates into a final description
distill_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"You are an API documentation assistant tasked with meticulously improving the descriptions of our API docs."
" Our AI assistant is trying to help users by calling APIs, but it continues to invoke the wrong ones."
" You are tasked with updating the {target_api} description.\n\n"
"## Current APIs\n"
"Below is a list of the current APIs and descriptions.\n"
"```apis.json\n{apis}\n```\n\n"
"## Candidates\n"
"Here are some candidate description improvements:\n{candidates}\n"
"Consider the above feedback in your updated description.",
),
(
"user",
"Respond with the updated description for the {target_api} API."
" It should distill and incorporate the candidate descriptions to"
" clear up whatever ambiguity is causing our AI assistant to fail.",
),
]
).partial(apis=apis)
distill_llm = ChatOpenAI(model=model).with_structured_output(FunctionUpdate)
distill_chain = distill_prompt | distill_llm
distill_inputs = [
{
"target_api": name,
"candidates": "\n".join(["- " + c for c in candidates]),
}
for name, candidates in toolwise_updates.items()
]
updated_descriptions = distill_chain.batch(distill_inputs)
updates_dict = {upd.name: upd.updated_description for upd in updated_descriptions}
updates_dict
{'TransportistasdeArgentina': 'Get a shipping quote for sending products within Argentina using OCA e-Pack. Provide destination and source postcodes, CUIT, operativa number, cost, volume, and weight details for accurate pricing.',
'TurkeyPostalCodes': 'Retrieve Turkish plate numbers (1 to 81) based on the city code. This API is specifically designed to provide details about Turkish plates and is not intended for tracking packages or obtaining postal codes for cities in Argentina.',
'CEPBrazil': 'Retrieve address details based on a Brazilian CEP number. This function is NOT intended for tracking package locations or statuses, tracking travel documents, or providing non-address related information. Use this API specifically for address lookup using CEP numbers in Brazil.',
'PridnestroviePost': 'Get track information by providing a track number for international shipments. Use this API specifically for tracking packages and shipments.',
'PackAndSend': 'If you have a Pack & Send Reference Number, use this API to track the delivery status and retrieve relevant information about the package. This API is specifically designed for tracking packages using the Pack & Send Reference Number, and it is not intended for providing postal code information for cities in a specific state.',
'TrackingMore_v2': 'List all supported carriers for package tracking. This API provides a comprehensive overview of available carriers for tracking packages. It is not intended for tracking specific package details, but rather for identifying carriers for package tracking services.',
'SQUAKE': 'This function does not have a defined purpose or parameters. Avoid using it as it does not serve any specific functionality. It is not suitable for tracking packages or retrieving package information.',
'AmexAustraliaFastwayAustraliaTracking': "Track a package's shipping details specifically within Australia using a package tracking number. This API is designed for tracking packages shipped via the AmexAustraliaFastway service and is not suitable for tracking international shipments or non-package related inquiries.",
'suivi-colis': 'Retrieve the current status (i.e., the latest status) of a package by providing the package ID. This function is suitable for tracking package statuses and obtaining real-time updates on delivery progress.',
'CreateContainerTracking': 'Retrieve data related to a container using the container ID provided. This API is suitable for tracking containers and retrieving their details.',
'Transitaires': 'Retrieve details about a specific transit company. This function is NOT designed for tracking packages or event planning.',
'KargomNerede': 'Retrieve a list of shipping companies.',
'GS1Parser': 'Parse machine- or human-readable GS1 barcode data.'}
from copy import deepcopy
new_tools = deepcopy(tools)
for tool in new_tools:
    name = tool["function"]["name"]
    if name in updates_dict:
        updated = updates_dict[name]
        tool["function"]["description"] = updated
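The distill step returns tool names chosen by the LLM. If a name does not exactly match an existing tool, the loop above silently skips that update. Below is a minimal sketch that applies the same update logic and also reports what changed and which returned names matched no tool. `apply_description_updates` is a hypothetical helper, and the demo tools are trimmed for brevity.

```python
from copy import deepcopy
from typing import Any, Dict, List, Set, Tuple

Tool = Dict[str, Any]


def apply_description_updates(
    tools: List[Tool], updates: Dict[str, str]
) -> Tuple[List[Tool], Dict[str, Tuple[str, str]], Set[str]]:
    """Copy `tools`, swap in updated descriptions, and report what changed."""
    new_tools = deepcopy(tools)  # the original definitions stay untouched
    changed: Dict[str, Tuple[str, str]] = {}
    for tool in new_tools:
        fn = tool["function"]
        new_desc = updates.get(fn["name"])
        if new_desc is not None and new_desc != fn["description"]:
            changed[fn["name"]] = (fn["description"], new_desc)  # (old, new)
            fn["description"] = new_desc
    # Names the LLM returned that match no tool would otherwise be dropped silently
    unmatched = set(updates) - {t["function"]["name"] for t in tools}
    return new_tools, changed, unmatched


tools_demo = [
    {"type": "function", "function": {"name": "KargomNerede", "description": "Companies"}},
    {"type": "function", "function": {"name": "GS1Parser", "description": "Parse GS1 data."}},
]
updated, changed, unmatched = apply_description_updates(
    tools_demo,
    {"KargomNerede": "Retrieve a list of shipping companies.", "GS1parser": "wrong casing"},
)
print(unmatched)  # {'GS1parser'}
print(tools_demo[0]["function"]["description"])  # Companies
```

Printing `changed` before re-evaluating gives a quick diff to review, which matters here because some generated descriptions (for example, the one for SQUAKE) may be wrong rather than merely clearer.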
Re-Evaluate
We will use the same LLM and prompt, updating only the tool descriptions (which are injected into the prompt on OpenAI's server).
model = "gpt-3.5-turbo"
llm = ChatOpenAI(model=model).bind_tools(new_tools)
updated_chain = assistant_prompt | llm | JsonOutputToolsParser()
updated_test_results = client.run_on_dataset(
dataset_name=dataset_name,
llm_or_chain_factory=updated_chain,
evaluation=eval_config,
project_metadata={
"model": model,
# Mark that this is a new tool description version
"tool_variant": 2,
},
verbose=True,
)
View the evaluation results for project 'ordinary-step-81' at:
https://smith.langchain.com/o/30239cd8-922f-4722-808d-897e1e722845/datasets/462d8386-60c8-4cb3-84eb-6efeae3a1293/compare?selectedSessions=a4204d34-4d08-42fa-a84d-19b850ad920e
View all tests for Dataset Tool Selection (Logistics) dev at:
https://smith.langchain.com/o/30239cd8-922f-4722-808d-897e1e722845/datasets/462d8386-60c8-4cb3-84eb-6efeae3a1293
[------------------------------------------------->] 99/100
Chain failed for example 033fd6d7-6c80-4ef2-ab26-e4116e4da24a with inputs {'query': "I'm planning a family vacation to Brazil and I need to find a hotel in Rio de Janeiro. Can you provide me with a list of available hotels in Rio de Janeiro downtown? Additionally, I would like to know the current health status of the CEP Brazil API and if it's functioning properly."}
Error Type: InternalServerError, Message: Error code: 500 - {'error': {'message': 'The server had an error processing your request. Sorry about that! You can retry your request, or contact us through our help center at help.openai.com if you keep seeing this error. (Please include the request ID req_35b74dde88208be45493f9827dc88674 in your email.)', 'type': 'server_error', 'param': None, 'code': None}}
[------------------------------------------------->] 100/100
The metrics look slightly better, but the difference is within what you would expect from run-to-run noise. You can investigate the quality of the dataset and of the model outputs and refine either to make the process more useful.
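One way to check whether a difference is beyond noise is a paired bootstrap over per-example scores. This is a generic statistical technique, not something LangSmith computes for you here. The sketch below assumes you have already aligned the two runs' `tool_selection_precision` scores by example ID. The score lists in the usage example are hypothetical.

```python
import random
from statistics import mean
from typing import List, Tuple


def paired_bootstrap_diff(
    baseline: List[float], candidate: List[float], n_resamples: int = 2000, seed: int = 0
) -> Tuple[float, float, float]:
    """Mean difference (candidate - baseline) with a 95% percentile bootstrap interval.

    Scores must be paired: index i in both lists refers to the same dataset example.
    """
    if len(baseline) != len(candidate) or not baseline:
        raise ValueError("need two non-empty, equally long lists of paired scores")
    diffs = [c - b for b, c in zip(baseline, candidate)]
    rng = random.Random(seed)  # fixed seed so the interval is reproducible
    # Resample examples with replacement and record the mean difference each time
    boot_means = sorted(mean(rng.choices(diffs, k=len(diffs))) for _ in range(n_resamples))
    lower = boot_means[int(0.025 * n_resamples)]
    upper = boot_means[int(0.975 * n_resamples) - 1]
    return mean(diffs), lower, upper


# Hypothetical per-example precision scores, aligned by example ID
before = [1.0, 0.5, 0.0, 1.0, 0.5, 1.0]
after = [1.0, 1.0, 0.0, 1.0, 0.5, 0.5]
diff, lo, hi = paired_bootstrap_diff(before, after)
print(f"mean diff {diff:+.3f}, 95% CI [{lo:+.3f}, {hi:+.3f}]")
```

If the interval contains zero, as it does for these made-up scores, the data does not show a reliable improvement. Pairing by example removes per-example difficulty from the comparison, which usually gives a tighter interval than comparing the two means independently.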
Test
We've been hill climbing on a single dataset, which means information about the evaluation data is seeping into the model definition (here, the tool descriptions). Before concluding that the updated descriptions (variant B) are better than the originals (variant A), you should benchmark both on a held-out set. Below are datasets that were sampled from the same ToolBench dataset.
dataset_urls = {
# Dev is same as above
"dev": dev_dataset_url,
"test": "https://smith.langchain.com/public/a5fd6197-36ed-4d06-993a-89929dded399/d",
"train": "https://smith.langchain.com/public/cf5a1de8-68f0-4170-9bcc-f263c1abb063/d",
}
import langsmith
client = langsmith.Client()
client.clone_public_dataset(dataset_urls["test"])
test_dataset_name = "Tool Selection (Logistics) test"
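# Evaluate the original (variant 0) and updated (variant 2) descriptions on the held-out test set
for tool_variant, target_chain in [(0, chain), (2, updated_chain)]:
    client.run_on_dataset(
        dataset_name=test_dataset_name,
        llm_or_chain_factory=target_chain,
        evaluation=eval_config,
        project_metadata={
            "model": model,
            # Record which tool description version this run used
            "tool_variant": tool_variant,
        },
    )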
View the evaluation results for project 'definite-coach-89' at:
https://smith.langchain.com/o/30239cd8-922f-4722-808d-897e1e722845/datasets/ddc1bcf7-c3fb-4669-824d-eb2e23af93d0/compare?selectedSessions=2b7204c8-7f07-4c2c-b798-d9005a059ce0
View all tests for Dataset Tool Selection (Logistics) test at:
https://smith.langchain.com/o/30239cd8-922f-4722-808d-897e1e722845/datasets/ddc1bcf7-c3fb-4669-824d-eb2e23af93d0
[------------------------------------------------->] 234/234
View the evaluation results for project 'sparkling-doctor-64' at:
https://smith.langchain.com/o/30239cd8-922f-4722-808d-897e1e722845/datasets/ddc1bcf7-c3fb-4669-824d-eb2e23af93d0/compare?selectedSessions=b9d4bd07-d96b-4da8-97df-279158ffafa1
View all tests for Dataset Tool Selection (Logistics) test at:
https://smith.langchain.com/o/30239cd8-922f-4722-808d-897e1e722845/datasets/ddc1bcf7-c3fb-4669-824d-eb2e23af93d0
[------------------------------------------------> ] 231/234