OpenAI Fine-Tuning

Once you've captured run traces from your deployment (production or beta), you'll likely want to use that data to fine-tune a model. This walkthrough shows a quick way to do so.

Steps:

  1. Query runs (optionally filtering by project, time, tags, etc.)
    • [Optional] Create a 'training' dataset to keep track of the data used for this model.
  2. Convert runs to OpenAI messages (or another format).
  3. Fine-tune and use the new model.

from langsmith import Client

client = Client()

1. Query runs

LangSmith saves traces for each runnable component in your LLM application. You can then query these runs in a variety of ways to construct a training dataset. We show a few common patterns below.

For examples of more 'advanced' filtering, check out the filtering guide in the LangSmith docs.

List all LLM runs for a specific project.

The simplest query is just listing "llm" runs in your project (filtering out runs with errors). Below is an example where we list all LLM runs in the default project.

import datetime

project_name = "default"
run_type = "llm"
end_time = datetime.datetime.now()

runs = client.list_runs(
    project_name=project_name,
    run_type=run_type,
    error=False,
)

Filter by feedback

Depending on how you're fine-tuning, you'll likely want to filter out 'bad' examples and keep the 'good' ones.

You can list runs directly by feedback. Feedback is usually assigned to the root run of a trace, so this takes two queries: one to find the root runs with the desired feedback, and a second to find the LLM run inside each of those traces.

from langchain import chains, chat_models, prompts, schema, callbacks

chain = (
prompts.ChatPromptTemplate.from_template("Tell a joke for:\n{input}")
| chat_models.ChatAnthropic(tags=["my-anthropic-run"])
| schema.output_parser.StrOutputParser()
)

with callbacks.collect_runs() as cb:
    chain.invoke({"input": "foo"})
# Assume feedback is logged
run = cb.traced_runs[0]
client.create_feedback(run.id, key="user_click", score=1)
project_name = "default"
end_time = datetime.datetime.now()

runs = client.list_runs(
    project_name=project_name,
    execution_order=1,  # Only return root runs
    filter='and(eq(feedback_key, "user_click"), eq(feedback_score, 1))',
    # For continuous scores, you can filter with gt/lt/gte/lte, e.g. gt(feedback_score, 0.9):
    # filter='and(eq(feedback_key, "user_click"), gt(feedback_score, 0.9))',
    error=False,
)
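The filter strings above are plain text in LangSmith's filter query language, so they get fiddly to write by hand as conditions pile up. Below is a minimal sketch of hypothetical helper functions (eq, gt, has, and and_ are defined here, not imported from langsmith) that build the same strings used in this guide. It assumes string arguments can be written as JSON-style double-quoted literals, as in the examples above.

```python
import json


def _arg(value) -> str:
    # Assumption: strings become JSON-style double-quoted literals
    # (json.dumps escapes embedded quotes); numbers are written bare.
    if isinstance(value, str):
        return json.dumps(value)
    return str(value)


def eq(field: str, value) -> str:
    return f"eq({field}, {_arg(value)})"


def gt(field: str, value) -> str:
    return f"gt({field}, {_arg(value)})"


def has(field: str, value) -> str:
    return f"has({field}, {_arg(value)})"


def and_(*conditions: str) -> str:
    # Combine any number of conditions into a single and(...) expression
    return f"and({', '.join(conditions)})"


good_feedback = and_(eq("feedback_key", "user_click"), eq("feedback_score", 1))
print(good_feedback)
# and(eq(feedback_key, "user_click"), eq(feedback_score, 1))
```

You could then pass the result straight to the query above, e.g. client.list_runs(project_name=project_name, execution_order=1, filter=good_feedback, error=False).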

Once you have these root runs, you can find the corresponding LLM run, provided it is a direct child of the root or carries a tag you applied to that trace.

llm_runs = []
for run in runs:
    llm_run = next(
        client.list_runs(
            project_name=project_name, run_type="llm", parent_run_id=run.id
        )
    )
    llm_runs.append(llm_run)

llm_runs[0].tags
['my-anthropic-run']
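To make the parent/child relationship concrete without a LangSmith connection, here is a minimal in-memory sketch of the same lookup. FakeRun is a hypothetical stand-in for the few Run fields used here (id, run_type, parent_run_id, tags), and llm_children is a hypothetical helper, not part of the langsmith client.

```python
import uuid
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class FakeRun:
    # Hypothetical stand-in exposing only the Run fields this sketch needs
    id: uuid.UUID
    run_type: str
    parent_run_id: Optional[uuid.UUID] = None
    tags: List[str] = field(default_factory=list)


def llm_children(roots: List[FakeRun], all_runs: List[FakeRun]) -> List[FakeRun]:
    """Return the first direct LLM child of each root, skipping roots that have none."""
    first_llm_child = {}
    for run in all_runs:
        if run.run_type == "llm" and run.parent_run_id is not None:
            # setdefault keeps the first LLM child seen for each parent
            first_llm_child.setdefault(run.parent_run_id, run)
    return [first_llm_child[root.id] for root in roots if root.id in first_llm_child]


fake_root = FakeRun(uuid.uuid4(), "chain")
fake_prompt = FakeRun(uuid.uuid4(), "prompt", parent_run_id=fake_root.id)
fake_llm = FakeRun(uuid.uuid4(), "llm", parent_run_id=fake_root.id, tags=["my-anthropic-run"])
fake_root_without_llm = FakeRun(uuid.uuid4(), "chain")

children = llm_children(
    [fake_root, fake_root_without_llm],
    [fake_root, fake_prompt, fake_llm, fake_root_without_llm],
)
print(children[0].tags)  # ['my-anthropic-run']
```

One design difference: the loop above calls next() on each query, which raises StopIteration if a root run has no LLM child, whereas this sketch silently skips such roots.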

Filter by tags

It's common to have multiple chain types in a single project, meaning that the LLM calls may span multiple tasks and domains. Tags are a useful way to organize runs by task, component, test variant, etc., so you can curate a coherent dataset.

Below is a quick example. Please also reference the Tracing FAQs for more information on tagging.

import uuid

# A unique tag for this call, so we can query for exactly these runs later
unique_tag = f"call:{uuid.uuid4()}"

# For any Chain object (here, an LLMChain), you can add tags in the constructor
chain = chains.LLMChain(
    llm=chat_models.ChatAnthropic(
        tags=["my-cool-llm-tag"]
    ),  # This tag will only be applied to the LLM
    prompt=prompts.ChatPromptTemplate.from_template(
        "Tell a joke based on the following prompt:\n\nPrompt:{input}"
    ),
    tags=["my-tag"],
)

# You can also define tags at call time for the call/invoke/batch methods.
# This tag will be propagated to all child calls
print(chain({"input": "podcasting these days"}, tags=[unique_tag]))

# If you're defining the chain with Runnables (LangChain Expression Language)
runnable = (
    prompts.ChatPromptTemplate.from_template(
        "Tell a joke based on the following prompt:\n\nPrompt:{input}"
    )
    | chat_models.ChatAnthropic(
        tags=["my-cool-llm-tag"]
    )  # This tag will only be applied to the LLM
    | schema.StrOutputParser(tags=["some-parser-tag"])
)

# Again, you can tag at call time as well. This tag will be propagated to all child calls
print(runnable.invoke({"input": "podcasting these days"}, {"tags": [unique_tag]}))
{'input': 'podcasting these days', 'text': ' Here\'s a joke about podcasting these days:\n\nEveryone seems to have a podcast these days. My neighbor started one where he just reads his grocery lists out loud. It\'s called "What\'s in Store?" and it\'s surprisingly addicting. I also heard about a podcast where someone literally just snoozes and snores into the mic for an hour. They call it "Podsleeping." And don\'t even get me started on my cousin\'s mumbling podcast. You can\'t understand a word she\'s saying, but she swears it\'s riveting stuff. I guess when it comes to podcasts, you can record and release pretty much anything nowadays. The bar is so low, it\'s practically underground at this point!'}
Here's a joke about podcasting these days:

It seems like everyone has a podcast now. My grandma just started one called "The Quilting Hour" where she talks about different stitch patterns while soft piano music plays in the background. Then my dog launched his own podcast where he just barks and pants into the microphone for 30 minutes. At this rate, my goldfish is going to start a podcast where he blows bubbles near the mic. Podcasting has really jumped the shark. Pretty soon we'll have more podcasts than actual listeners!
project_name = "default"
end_time = datetime.datetime.now()

runs = client.list_runs(
    project_name=project_name,
    execution_order=1,  # Only return the root trace
    filter=f'has(tags, "{unique_tag}")',
)
len(list(runs))
2
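Because a run can carry several tags, you can also split an already-fetched list of runs into per-task groups on the client side. A minimal sketch, assuming runs are represented as plain dicts with a "tags" list (LangSmith's Run objects expose the same information as a tags attribute, as in llm_runs[0].tags above); group_by_tag is a hypothetical helper:

```python
from collections import defaultdict


def group_by_tag(runs: list, wanted_tags: set) -> dict:
    """Map each wanted tag to the runs carrying it.

    A run with several wanted tags lands in several groups.
    """
    groups = defaultdict(list)
    for run in runs:
        for tag in run.get("tags", []):
            if tag in wanted_tags:
                groups[tag].append(run)
    return dict(groups)


fake_runs = [
    {"id": "a", "tags": ["my-tag", "my-cool-llm-tag"]},
    {"id": "b", "tags": ["my-cool-llm-tag"]},
    {"id": "c", "tags": []},
]
groups = group_by_tag(fake_runs, {"my-tag", "my-cool-llm-tag"})
print({tag: [run["id"] for run in grouped] for tag, grouped in groups.items()})
# {'my-tag': ['a'], 'my-cool-llm-tag': ['a', 'b']}
```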

Filter by run name.

By default, the run name is the class of the object being traced. You can filter by run name to narrow your search by, e.g., the LLM class.

Below, we list all runs sent to a "ChatAnthropic" LLM.

project_name = "default"
run_type = "llm"
end_time = datetime.datetime.now()

runs = client.list_runs(
    project_name=project_name,
    run_type=run_type,
    filter='eq(name, "ChatAnthropic")',
    error=False,
)

Retrieve prompt inputs directly

If you fetch the LLM or chat run directly, its input will be the formatted prompt, with the template values already injected. You may want to separate the injected values from the prompt template so that the fine-tuned model needs less (or no) instruction prompting to produce the desired prediction.

If your chain is composed as runnables (for instance, if you use LangChain Expression Language), each prompt runnable will be given its own run trace. You can fetch the inputs to the prompt template directly so that when you fine-tune, you can elide the other template content and train directly on the input values and LLM outputs.

Take the following chain, for instance, which is promoted to a RunnableSequence via the piping operation.

# Example chain for the following query
from langchain import prompts, chat_models

chain = (
    prompts.ChatPromptTemplate.from_template(
        "Summarize the following chat log: {input}"
    )
    | chat_models.ChatOpenAI()
)

chain.invoke({"input": "hi there, hello...."})
AIMessage(content='The chat log consists of a simple greeting exchange, with one person saying "hi there" and the other responding with "hello."', additional_kwargs={}, example=False)
import datetime

project_name = "default"
run_type = "prompt"
end_time = datetime.datetime.now()

runs = client.list_runs(
    project_name=project_name,
    run_type=run_type,
    end_time=end_time,
    error=False,
)
# You can then get a sibling LLM run by searching by parent_run_id and including other criteria
for prompt_run in runs:
    llm_run = next(
        client.list_runs(
            project_name=project_name,
            run_type="llm",
            parent_run_id=prompt_run.parent_run_id,
        )
    )
    inputs, outputs = prompt_run.inputs, llm_run.outputs
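As a sketch of what eliding the template looks like, the hypothetical helper below turns the prompt run's template variables and the model's reply into one chat-format training record, dropping the "Summarize the following chat log:" instruction entirely. It assumes you have already pulled the assistant's reply text out of llm_run.outputs (the exact structure of outputs depends on the model and is not shown here), and that the template has a single input variable.

```python
import json


def to_training_record(template_inputs: dict, reply: str, input_key: str = "input") -> dict:
    """Build one chat-format training record from raw template variables.

    The prompt template's instruction text is dropped: only the injected
    value (as the user turn) and the model's reply (as the assistant turn)
    are kept.
    """
    return {
        "messages": [
            {"role": "user", "content": template_inputs[input_key]},
            {"role": "assistant", "content": reply},
        ]
    }


record = to_training_record(
    {"input": "hi there, hello...."},
    "The chat log consists of a simple greeting exchange.",
)
print(json.dumps(record))
```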

Create a 'training' dataset (optional)

While not necessary for the fast path to your first fine-tuned model, datasets let you work in a more principled way by tracking the exact data used to train a given model. They are also a natural place to add manual review or spot checks in the web app.

dataset = client.create_dataset(
    dataset_name="Fine-Tuning Dataset Example",
    description=f"Chat logs taken from project {project_name} for fine-tuning",
    data_type="chat",
)
# The runs iterator above has already been consumed, so query the LLM runs again
runs = client.list_runs(
    project_name=project_name,
    run_type="llm",
    error=False,
)
for run in runs:
    if "messages" not in run.inputs or not run.outputs:
        # Filter out non-chat runs
        continue
    try:
        # Convenience method for creating a chat example
        client.create_example_from_run(
            dataset_id=dataset.id,
            run=run,
        )
        # Or if you want to select certain keys/values in inputs
        # inputs = convert_inputs(run.inputs)
        # outputs = convert_outputs(run.outputs)
        # client.create_example(
        #     dataset_id=dataset.id,
        #     inputs=inputs,
        #     outputs=outputs,
        #     run=run,
        # )
    except Exception:
        # Duplicate inputs raise an exception
        pass

2. Load examples as messages

We will first load the messages as LangChain objects then take advantage of the OpenAI adapter helper to convert these to dictionaries in the form expected by OpenAI's fine-tuning endpoint.

from langsmith import schemas
from langchain import load


def convert_messages(example: schemas.Example) -> dict:
    messages = load.load(example.inputs)["messages"]
    message_chunk = load.load(example.outputs)["generations"][0]["message"]
    return {"messages": messages + [message_chunk]}


messages = [
    convert_messages(example)
    for example in client.list_examples(dataset_name="Fine-Tuning Dataset Example")
]

Now that we have the traces back as LangChain message objects, you can use the adapters to convert to other formats, such as OpenAI's fine-tuning format.

from langchain.adapters import openai as openai_adapter

finetuning_messages = openai_adapter.convert_messages_for_finetuning(messages)
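Before uploading, it can help to sanity-check each converted conversation. The sketch below is a hypothetical helper, and its checks are assumptions tailored to this guide rather than OpenAI's full validation rules: it only accepts system, user, and assistant roles, expects string content, flags function_call entries (which the next step filters out), and expects the conversation to end with the assistant reply that convert_messages appended.

```python
ALLOWED_ROLES = {"system", "user", "assistant"}


def check_finetuning_example(messages: list) -> list:
    """Return a list of problems found in one conversation (empty if none were found)."""
    if not messages:
        return ["conversation is empty"]
    problems = []
    for i, message in enumerate(messages):
        if message.get("role") not in ALLOWED_ROLES:
            problems.append(f"message {i}: unexpected role {message.get('role')!r}")
        if not isinstance(message.get("content"), str):
            problems.append(f"message {i}: content is not a string")
        if "function_call" in message:
            problems.append(f"message {i}: contains function_call")
    # convert_messages appends the model's reply last, so expect an assistant turn there
    if messages[-1].get("role") != "assistant":
        problems.append("last message is not from the assistant")
    return problems


ok_example = [
    {"role": "user", "content": "Tell a joke for:\nfoo"},
    {"role": "assistant", "content": "Why did the foo cross the road?"},
]
function_call_example = [
    {"role": "user", "content": "What's the weather?"},
    {"role": "assistant", "content": None, "function_call": {"name": "get_weather", "arguments": "{}"}},
]
print(check_finetuning_example(ok_example))  # []
print(check_finetuning_example(function_call_example))
# ['message 1: content is not a string', 'message 1: contains function_call']
```

For example, kept = [group for group in finetuning_messages if not check_finetuning_example(group)] keeps only conversations with no reported problems.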

3. Fine-tune

Now you can use these message dictionaries for downstream tasks like fine-tuning. Note that, at the time of writing, the OpenAI API doesn't support the 'function_call' argument when fine-tuning, so we filter out those examples first. This restriction may have been relaxed by the time you read this guide.

import time
import json
import io

import openai

my_file = io.BytesIO()
for group in finetuning_messages:
    # Skip conversations that contain function calls (see the note above)
    if any("function_call" in message for message in group):
        continue
    my_file.write((json.dumps({"messages": group}) + "\n").encode("utf-8"))

my_file.seek(0)
training_file = openai.File.create(file=my_file, purpose="fine-tune")

# Wait while the file is processed
status = openai.File.retrieve(training_file.id).status
start_time = time.time()
while status != "processed":
    print(f"Status=[{status}]... {time.time() - start_time:.2f}s", end="\r", flush=True)
    time.sleep(5)
    status = openai.File.retrieve(training_file.id).status
print(f"File {training_file.id} ready after {time.time() - start_time:.2f} seconds.")
File file-ixTtVCKDGrZ7PiVZFszQr6kN ready after 30.55 seconds.

Next, fine-tune the model. This could take 10+ minutes depending on the server's load and your dataset size.

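job = openai.FineTuningJob.create(
    training_file=training_file.id,
    model="gpt-3.5-turbo",
)

# It may take 10-20+ minutes to complete training.
# Refresh the job object so job.fine_tuned_model is populated once training succeeds
job = openai.FineTuningJob.retrieve(job.id)
status = job.status
start_time = time.time()
while status != "succeeded":
    # Stop polling instead of looping forever if the job ends without succeeding
    if status in ("failed", "cancelled"):
        raise RuntimeError(f"Fine-tuning job {job.id} ended with status [{status}]")
    print(f"Status=[{status}]... {time.time() - start_time:.2f}s", end="\r", flush=True)
    time.sleep(5)
    job = openai.FineTuningJob.retrieve(job.id)
    status = job.status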
Status=[running]... 743.73s
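Both waiting loops above follow the same pattern, and the training loop in particular can run for a long time. A minimal sketch of a reusable polling helper (wait_for_status is hypothetical, not part of the openai package) that also stops on failure statuses and after a timeout:

```python
import time
from typing import Callable, Iterable


def wait_for_status(
    fetch_status: Callable[[], str],
    done: str,
    failed: Iterable[str] = (),
    interval: float = 5.0,
    timeout: float = 3600.0,
) -> str:
    """Poll fetch_status() until it returns `done`.

    Raises RuntimeError if a status in `failed` is seen, and TimeoutError
    if `timeout` seconds pass without reaching `done`.
    """
    failed_set = set(failed)
    start = time.monotonic()
    while True:
        status = fetch_status()
        if status == done:
            return status
        if status in failed_set:
            raise RuntimeError(f"ended with status {status!r}")
        if time.monotonic() - start > timeout:
            raise TimeoutError(f"still {status!r} after {timeout:.0f}s")
        time.sleep(interval)


# Simulated status sequence standing in for repeated API calls
fake_statuses = iter(["validating_files", "running", "succeeded"])
print(wait_for_status(lambda: next(fake_statuses), done="succeeded", interval=0))  # succeeded
```

Against the API objects used above, this might look like wait_for_status(lambda: openai.FineTuningJob.retrieve(job.id).status, done="succeeded", failed={"failed", "cancelled"}), or done="processed" for the uploaded file. The failure status names are assumptions based on OpenAI's fine-tuning job statuses; check the API reference for your SDK version.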

Now you can use the model within LangChain.

from langchain import chat_models, prompts

model_name = job.fine_tuned_model
# Example: ft:gpt-3.5-turbo-0613:personal::5mty86jb
model = chat_models.ChatOpenAI(model=model_name)
# Call the fine-tuned model itself (the earlier `chain` still uses the base model)
model.invoke("Who are you designed to assist?")
AIMessage(content='I am designed to assist anyone who has questions or needs help related to LangChain. This could include developers, users, or anyone else who is interested in or using LangChain.', additional_kwargs={}, example=False)

Conclusion

Congratulations! You've fine-tuned a model on your traced LLM runs.

This is an extremely simple recipe that demonstrates the end-to-end workflow. It is likely that you will want to use various methods to filter, fix, and observe the data you choose to fine-tune the model on. We welcome additional recipes of things that have worked for you!

