Align Meta Llama 3 to human preferences with DPO, Amazon SageMaker Studio, and Amazon SageMaker Ground Truth

In this post, we show you how to enhance the performance of Meta Llama 3 8B Instruct by fine-tuning it using direct preference optimization (DPO) on data collected with SageMaker Ground Truth.

Sep 10, 2024 - 00:00
Large language models (LLMs) have remarkable capabilities. Nevertheless, using them in customer-facing applications often requires tailoring their responses to align with your organization’s values and brand identity. In this post, we demonstrate how to use direct preference optimization (DPO), a technique that allows you to fine-tune an LLM with human preference data, together with Amazon SageMaker Studio and Amazon SageMaker Ground Truth to align the Meta Llama 3 8B Instruct model responses to your organization’s values.

Using SageMaker Studio and SageMaker Ground Truth for DPO

With DPO, you can fine-tune an LLM with human preference data such as ratings or rankings so that it generates outputs that align to end-user expectations. DPO is computationally efficient and helps enhance a model’s helpfulness, honesty, and harmlessness, divert the LLM from addressing specific subjects, and mitigate biases. In this technique, you typically start with selecting an existing or training a new supervised fine-tuned (SFT) model. You use the model to generate responses and you gather human feedback on these responses. After that, you use this feedback to perform DPO fine-tuning and align the model to human preferences.

Whether you are fine-tuning a pre-trained LLM with supervised fine-tuning (SFT) or loading an existing fine-tuned model for DPO, you typically need powerful GPUs. The same applies during DPO fine-tuning. With Amazon SageMaker, you can get started quickly and experiment rapidly by using managed Jupyter notebooks equipped with GPU instances: create a JupyterLab space in SageMaker Studio, the integrated development environment (IDE) purpose-built for machine learning (ML), and launch a JupyterLab application that runs on a GPU instance.

Orchestrating the end-to-end data collection workflow and developing an application for annotators to rate or rank model responses for DPO fine-tuning can be time-consuming. SageMaker Ground Truth offers human-in-the-loop capabilities that help you set up workflows, manage annotators, and collect consistent, high-quality feedback.

This post walks you through the steps of using DPO to align an SFT model’s responses to the values of a fictional digital bank called Example Bank. Your notebook runs in a JupyterLab space in SageMaker Studio powered by a single ml.g5.48xlarge instance (8 A10G GPUs). Optionally, you can run this notebook on a smaller instance type such as ml.g5.12xlarge (4 A10G GPUs) or ml.g6.12xlarge (4 L4 GPUs) with bitsandbytes quantization. You use Meta Llama 3 8B Instruct (the Meta Llama 3 instruction-tuned model optimized for dialogue use cases, from the Hugging Face Hub) to generate responses, SageMaker Ground Truth to collect preference data, and the DPOTrainer from the Hugging Face TRL library for DPO fine-tuning together with Parameter-Efficient Fine-Tuning (PEFT). You also deploy the aligned model to a SageMaker endpoint for real-time inference. You can use the same approach with other models.

Solution overview

The following diagram illustrates the approach.

The workflow contains the following key steps:

  1. Load the Meta Llama 3 8B Instruct model into SageMaker Studio and generate responses for a curated set of common and toxic questions. The dataset serves as the initial benchmark for the model’s performance.
  2. The generated question-answer pairs are stored in Amazon Simple Storage Service (Amazon S3). These will be presented to the human annotators later so they can rank the model responses.
  3. Create a workflow in SageMaker Ground Truth to gather human preference data for the responses. This involves creating a work team, designing a UI for feedback collection, and setting up a labeling job.
  4. Human annotators interact with the labeling portal to evaluate and rank the model’s responses based on their alignment to the organization’s values.
  5. The collected data is processed to adhere to the DPOTrainer expected format.
  6. Using the Hugging Face TRL library and the DPOTrainer, fine-tune the Llama 3 model using the processed data from the previous step.
  7. Test the fine-tuned model on a holdout evaluation dataset to assess its performance and verify it meets the desired standards.
  8. When you’re satisfied with the model performance, you can deploy it to a SageMaker endpoint for real-time inference at scale.

Prerequisites

To run the solution described in this post, you must have an AWS account set up, along with an AWS Identity and Access Management (IAM) role that grants you the necessary permissions to create and access the solution resources. If you are new to AWS and haven’t created an account yet, refer to Create a standalone AWS account.

To use SageMaker Studio, you need to have a SageMaker domain set up with a user profile that has the necessary permissions to launch the SageMaker Studio application. If you’re new to SageMaker Studio, the Quick Studio setup is the fastest way to get started. With a single click, SageMaker provisions the required domain with default presets, including setting up the user profile, IAM role, IAM authentication, and public internet access. The notebook associated with this post assumes the use of an ml.g5.48xlarge instance type. To review or increase your quota limits, navigate to the AWS Service Quotas console, choose AWS Services in the navigation pane, choose Amazon SageMaker, and refer to the value for Studio JupyterLab Apps running on ml.g5.48xlarge instances.

For experimentation, request a quota value of at least 1.

Meta Llama 3 8B Instruct is available under the Llama 3 license. To download the model from Hugging Face, you need an access token. If you don’t already have one, navigate to the Settings page on the Hugging Face website to obtain it.

Make sure that the SageMaker Studio role has the necessary permissions for SageMaker Ground Truth and Amazon S3 access. When you’re working in SageMaker Studio, you’re already using an IAM role, which you’ll need to modify for launching SageMaker Ground Truth labeling jobs. To enable SageMaker Ground Truth functionality, you should attach the AWS managed policy AmazonSageMakerGroundTruthExecution to your SageMaker Studio role. This policy provides the essential permissions for creating and managing labeling jobs.

For Amazon S3 access, scoping permissions to specific buckets and actions enhances security and aligns with best practices. This approach adheres to the principle of least privilege, reducing potential risks associated with overly permissive policies. The following is an example of a restricted Amazon S3 policy that grants only the necessary permissions:

{
    "Version": "2012-10-17",
    "Statement": [
        {
            "Effect": "Allow",
            "Action": [
                "s3:GetObject",
                "s3:PutObject",
                "s3:ListBucket"
            ],
            "Resource": [
                "arn:aws:s3:::<your-bucket-name>",
                "arn:aws:s3:::<your-bucket-name>/*"
            ]
        }
    ]
}
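
As a minimal sketch, you can also generate the scoped policy for a specific bucket instead of editing the JSON by hand. The helper name build_s3_policy and the bucket name example-bank-dpo-data are illustrative assumptions and are not part of the original solution.

```python
import json


def build_s3_policy(bucket_name: str) -> dict:
    """Return a least-privilege S3 policy document for one bucket (hypothetical helper)."""
    if not bucket_name:
        raise ValueError("bucket_name must not be empty")
    return {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Effect": "Allow",
                "Action": ["s3:GetObject", "s3:PutObject", "s3:ListBucket"],
                "Resource": [
                    f"arn:aws:s3:::{bucket_name}",    # the bucket itself (used by ListBucket)
                    f"arn:aws:s3:::{bucket_name}/*",  # the objects inside the bucket
                ],
            }
        ],
    }


# "example-bank-dpo-data" is a made-up bucket name used only for illustration
policy = build_s3_policy("example-bank-dpo-data")
print(json.dumps(policy, indent=4))
```

The printed document can be pasted into the IAM console as an inline policy.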

To add these policies to your SageMaker Studio role, complete the following steps:

  1. On the IAM console, find and choose your SageMaker Studio role (it usually starts with AmazonSageMaker-ExecutionRole-).
  2. On the Permissions tab, choose Add permissions and then Attach policies.
  3. Search for and attach AmazonSageMakerGroundTruthExecution.
  4. Create and attach the custom Amazon S3 inline policy as shown in the preceding example, if needed.

Remember to follow the principle of least privilege, granting only the permissions necessary for your specific use case. Regularly review your IAM roles and policies to validate their alignment with your security requirements. For more details on IAM policies for SageMaker Ground Truth, refer to Use IAM Managed Policies with Ground Truth.

Set up the notebook and environment

To get started, open SageMaker Studio and create a JupyterLab space. For Instance, choose ml.g5.48xlarge. Run the space, open JupyterLab, and clone the code in the following GitHub repository. You can configure the JupyterLab space to use up to 100 GB in your Amazon Elastic Block Store (Amazon EBS) volume. In addition, the ml.g5 instance family comes with NVMe SSD local storage, which you can use in the JupyterLab application. The NVMe instance store directory is mounted to the application container in /mnt/sagemaker-nvme. For this post, you use the NVMe storage available in the ml.g5.48xlarge instance.

When your space is ready, clone the GitHub repo and open the notebook llama3/rlhf-genai-studio/RLHF-with-Llama3-on-Studio-DPO.ipynb, which contains the solution code. In the pop-up, make sure that the Python 3 kernel is selected.

Let’s go through the notebook. First, import the necessary Python libraries:

import torch
import os
import sagemaker
import boto3
import datetime
from transformers import pipeline
import json
import asyncio
import aiofiles
from datasets import Dataset, load_dataset
from peft import (
    get_peft_model,
    LoraConfig,
    prepare_model_for_kbit_training,
)
import bitsandbytes as bnb
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    AutoModelForSequenceClassification
)
from IPython.display import display, HTML

The following line sets the default path where you store temporary artifacts to the location in the NVMe storage:

cache_dir = "/mnt/sagemaker-nvme"

This is local storage, which means that your data will be lost when the JupyterLab application is deleted, restarted, or patched. Alternatively, you can increase the EBS volume of your SageMaker Studio space to at least 100 GB to provide sufficient storage for the Meta Llama 3 base model, PEFT adapter, and new merged fine-tuned model.
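
If you want the same notebook to work on instance types without NVMe storage, you could pick the cache location at runtime. The following is a minimal sketch under that assumption; the helper name pick_cache_dir and the fallback folder model-cache in the home directory are hypothetical choices, not part of the original notebook.

```python
import os


def pick_cache_dir(preferred="/mnt/sagemaker-nvme", fallback=None):
    """Return the NVMe mount if it exists and is writable, otherwise a fallback path."""
    if os.path.isdir(preferred) and os.access(preferred, os.W_OK):
        return preferred
    if fallback is None:
        # Assumed fallback: a folder under the home directory (on the EBS volume in Studio)
        fallback = os.path.join(os.path.expanduser("~"), "model-cache")
    return fallback


cache_dir = pick_cache_dir()
print(f"Using cache directory: {cache_dir}")
```

The function only selects a path; create the fallback directory yourself before you download the model into it.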

Load Meta Llama 3 8B Instruct in the notebook

After you have imported the necessary libraries, you can download the Meta Llama 3 8B Instruct model and its associated tokenizer from Hugging Face:

base_model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    token=hf_access_token,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    cache_dir=cache_dir
)

model.config.use_cache = False

tokenizer = AutoTokenizer.from_pretrained(
    base_model_id,
    token=hf_access_token,
    cache_dir=cache_dir
)

Collect initial model responses for common and toxic questions

The example_bank_questions.txt file contains a list of common questions received by call centers in financial organizations combined with a list of toxic and off-topic questions.

Before you ask the model to generate answers to these questions, you need to specify the brand and core values of Example Bank. You will include these values in the prompt as context later so the model has the appropriate information it needs to respond.

company_context = """Example Bank is a next-generation digital bank on a mission to revolutionize the banking experience. Founded in 2020, we are committed to leveraging cutting-edge technology to make banking simple, accessible, and transparent for everyone. In Example Bank, we believe that banking should be seamless, intuitive, and tailored to the needs of modern consumers. Our founders, seasoned professionals from the tech and finance industries, set out to create a bank that puts people first, empowering them to take control of their finances with ease. At Example Bank, we envision a world where banking is no longer a chore but a delightful experience. We are dedicated to breaking down barriers and democratizing access to financial services. Our goal is to empower individuals and businesses alike by providing them with the tools and resources they need to thrive in an increasingly digital landscape.
Our values:
- Innovation: We embrace cutting-edge technologies and continuously seek out innovative solutions to deliver the best possible banking experience. We are a digital-only bank, which means we don't have any physical branches. Instead, we offer all of our services online or through our mobile app. This allows us to keep our costs low and pass the savings on to our customers.
- Transparency: We are committed to being direct and honest with our customers. We believe that transparency is key to building trust, and we want our customers to feel confident that they are making informed decisions about their money. That's why we provide clear and concise information about our products and services, and we are always available to answer any questions our customers may have.
- Accessibility: Our services are designed to be inclusive and user-friendly, catering to a diverse range of customers, regardless of their financial backgrounds.
- Security: We prioritize the safety and security of our customers' data and assets, employing state-of-the-art encryption and cybersecurity measures.
In addition to our core values, Example Bank offers a range of innovative financial products and services:
- Loans: Whether you’re looking to buy a home, start a business, or finance a major purchase, our flexible loan options are designed to meet your needs. With competitive interest rates and a simple application process, obtaining a loan has never been easier.
- Credit Cards: Our credit cards come with a host of benefits including cashback rewards, low-interest rates, and no annual fees. Manage your spending effortlessly with real-time notifications and intuitive budgeting tools.
- Mobile Apps: Our user-friendly apps on the Google Play Store and Apple App Store offer a seamless banking experience. From checking balances to transferring funds, our apps ensure you have complete control of your finances at your fingertips.
- Savings and Investments: Grow your wealth with our high-yield savings accounts and a variety of investment options. Our financial advisors are available to help you make informed decisions tailored to your financial goals.
- Customer Support: We provide 24/7 customer support to assist with any inquiries or issues. Our dedicated team is always ready to help, ensuring you receive the best possible service at all times.
At Example Bank, we are committed to enhancing your financial well-being through innovation, transparency, and unparalleled service. Join us today and experience the future of banking.
"""

Now you’re ready to invoke the model. For each question in the file, you construct a prompt that contains the context and the actual question. You send the prompt to the model four times to generate four different outputs and save the results in the llm_responses.json file.

questions = 'example_bank_questions.txt'
llm_responses = os.path.join(sample_files_path, 'llm_responses.json')

from timeit import default_timer as timer
import tqdm.asyncio

# Build the text-generation pipeline once and reuse it for every request
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

async def invoke_model(question, context):
    messages = [
        {"role": "user", "content": f"{context}: {question}"}
    ]

    # Stop at either the generic EOS token or the Llama 3 end-of-turn token
    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>")
    ]

    response = pipe(
        messages,
        max_new_tokens=120,
        do_sample=True,
        temperature=gl_temperature,
        top_p=gl_top_p,
        eos_token_id=terminators
    )[0]['generated_text'][-1]
    # The last message in the returned conversation is the assistant reply
    return response['content']

async def process_lines(file_path):
    results = []
    context = f"""{company_context} You are a customer service agent for {company_name}. Sometimes you are smart with your answers. Answer the following customer question in one or two sentences:
    """
    async with aiofiles.open(file_path, 'r') as file:
        lines = [line async for line in file]
        for line in tqdm.asyncio.tqdm(lines, desc="Processing Question Bank"):
            start = timer()
            responses = await asyncio.gather(*[invoke_model(line, context) for _ in range(4)])
            result = {
                'context': context,
                'question': line.strip(),
                'responses': responses
            }
            end = timer()
            results.append(result)
    return results

results = await process_lines(questions)

with open(llm_responses, 'w') as file:
    json.dump(
        results, 
        file, 
        indent=4
    )

The following is an example entry from llm_responses.json.

Set up the SageMaker Ground Truth labeling job and human preference data

To fine-tune the model using DPO, you need to gather human preference data for the generated responses. SageMaker Ground Truth helps orchestrate the data collection process. It offers customizable labeling workflows and robust workforce management features for ranking tasks. This section shows you how to set up a SageMaker Ground Truth labeling job and invite a human workforce with requisite expertise to review the LLM responses and rank them.

Set up the workforce

A private workforce in SageMaker Ground Truth consists of individuals who are specifically invited to perform data labeling tasks. These individuals can be employees or contractors who have the required expertise to evaluate the model’s responses. Setting up a private workforce helps achieve data security and quality by limiting access to trusted individuals for data labeling.

For this use case, the workforce consists of the group of people who will rank the model responses. You can set up a private workforce using the SageMaker console by creating a private team and inviting members through email. For detailed instructions, refer to Create a Private Workforce (Amazon SageMaker Console).

Create the instruction template

With the instruction template, you can manage the UI and guide human annotators in reviewing model outputs. It needs to clearly present the model responses and provide a straightforward way for the annotators to rank them. Here, you use the text ranking template. This template allows you to display the instructions for the human reviewer and the prompts with the pregenerated LLM responses. The annotator reviews the prompt and responses and ranks the latter based on their alignment to the organization’s brand.

The definition of the template is as follows. The template shows a pane on the left with instructions from the job requester, a prompt at the top, and the LLM responses in the main body. The right side of the UI is where the annotator ranks the responses from most to least preferable.

[Template listing not preserved here; only its title, "Text Ranking Tool", survives.]

The template is saved locally on your Studio JupyterLab space EBS volume as instructions.template in a temporary directory. Then you upload this template file to your designated S3 bucket using s3.upload_file(), placing it in the specified bucket and prefix. This Amazon S3 hosted template will be referenced when you create the SageMaker Ground Truth labeling job, so workers see the correct interface for the text ranking task.

Preprocess the input data

Before you create the labeling job, verify that the input data matches the format expected by SageMaker Ground Truth and is saved as a JSON file in Amazon S3. You can use the prompts and responses in the llm_responses.json file to create the manifest file inp-manifest-trank.json. Each row in the manifest file contains a JSON object (source-responses pair). The previous entry now looks like the following code.

Upload the structured data to the S3 bucket so that it can be ingested by SageMaker Ground Truth.
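
The conversion from llm_responses.json to the manifest can be sketched as follows. This is an illustration under assumptions: the key names source and responses follow the source-responses pair described earlier, while the function name to_manifest_lines, the choice to use the question as the source, and the sample entry are hypothetical.

```python
import json


def to_manifest_lines(results):
    """Convert entries shaped like llm_responses.json into JSON Lines manifest rows."""
    lines = []
    for entry in results:
        row = {
            "source": entry["question"],       # assumed: the text shown as the prompt
            "responses": entry["responses"],   # the candidate answers to rank
        }
        lines.append(json.dumps(row))
    return lines


# Hypothetical entry with placeholder text, shaped like the output of process_lines
sample_results = [
    {
        "context": "<context>",
        "question": "<question>",
        "responses": ["<response 1>", "<response 2>", "<response 3>", "<response 4>"],
    }
]

manifest_lines = to_manifest_lines(sample_results)
print("\n".join(manifest_lines))
```

Each printed line is one JSON object, which is the one-object-per-row layout a Ground Truth input manifest uses.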

Create the labeling job

Now you’re ready to configure and launch the labeling job using the SageMaker API from within the notebook. This involves specifying the work team, UI template, and data stored in the S3 bucket. By setting appropriate parameters such as task time limits and the number of workers per data object, you can run jobs efficiently and effectively. The following code shows how to start the labeling job:

sm_client.create_labeling_job(
    LabelingJobName=labeling_job_name,
    LabelAttributeName='label',
    InputConfig={
        'DataSource': {
            'S3DataSource': {
                'ManifestS3Uri': model_responses_s3_uri
            }
        }
    },
    OutputConfig={
        'S3OutputPath': 's3://{}/{}/output/'.format(bucket,prefix) #Enter S3 URI of Output folder
    },
    RoleArn=role, 
    HumanTaskConfig={
        'WorkteamArn': WORKTEAM_ARN,
        'UiConfig':{
            'UiTemplateS3Uri': UI_TEMPLATE_S3_URI
        },
        'PreHumanTaskLambdaArn': 'arn:aws:lambda:us-east-1:432418664414:function:PRE-PassThrough',
        'TaskKeywords': [
            'QnA',
        ],
        'TaskTitle': 'Rank LLM responses',
        'TaskDescription': "Rank the responses provided by the LLM",
        'NumberOfHumanWorkersPerDataObject': 1,
        'TaskTimeLimitInSeconds': 60*30,
        'TaskAvailabilityLifetimeInSeconds': 60*60*24*10,
        'MaxConcurrentTaskCount': 100,
        'AnnotationConsolidationConfig': {
            'AnnotationConsolidationLambdaArn': 'arn:aws:lambda:us-east-1:432418664414:function:ACS-PassThrough'
        } 
    }
)

As the job is launched, it’s essential to monitor its progress closely, making sure tasks are being distributed and completed as expected.
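
One way to monitor progress from the notebook is to poll the DescribeLabelingJob API. The following is a minimal sketch: the helper name summarize_labeling_job is hypothetical, and the stub client with made-up counts stands in for boto3.client("sagemaker") only so that the example runs without AWS credentials.

```python
def summarize_labeling_job(sm_client, labeling_job_name):
    """Return the job status and label counters reported by DescribeLabelingJob."""
    response = sm_client.describe_labeling_job(LabelingJobName=labeling_job_name)
    counters = response.get("LabelCounters", {})
    return {
        "status": response["LabelingJobStatus"],
        "labeled": counters.get("TotalLabeled", 0),
        "unlabeled": counters.get("Unlabeled", 0),
        "failed": counters.get("FailedNonRetryableError", 0),
    }


class _StubSageMakerClient:
    """Offline stand-in for the SageMaker client; returns made-up numbers."""

    def describe_labeling_job(self, LabelingJobName):
        return {
            "LabelingJobStatus": "InProgress",
            "LabelCounters": {
                "TotalLabeled": 12,
                "Unlabeled": 38,
                "FailedNonRetryableError": 0,
            },
        }


# In the notebook you would pass sm_client and labeling_job_name instead of the stub
summary = summarize_labeling_job(_StubSageMakerClient(), "example-ranking-job")
print(summary)
```

Calling the helper periodically shows whether the unlabeled count is decreasing as annotators work through the tasks.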

Gather human feedback through the labeling portal

When the job setup is complete, annotators can log in to the labeling portal and start ranking the model responses.

Workers can first consult the Instructions pane to understand the task, then use the main interface to evaluate and rank the model’s responses according to the given criteria. The following screenshot illustrates the UI.

The human feedback is collected and stored in an S3 bucket. This feedback will be the basis for DPO. With this data, you will fine-tune the Meta Llama 3 model and align its responses with the organization’s values, improving its overall performance.

Align Meta Llama 3 8B Instruct with the DPOTrainer

In this section, we show how to use the preference dataset that you prepared using SageMaker Ground Truth to fine-tune the model using DPO. DPO explicitly optimizes the model’s output based on human evaluations. It aligns the model’s behavior more closely with human expectations and improves its performance on tasks requiring nuanced understanding and contextual appropriateness. By integrating human preferences, DPO enhances the model’s relevance, coherence, and overall effectiveness in generating desired responses.

DPO makes preference-tuning a model more straightforward than other popular techniques such as Proximal Policy Optimization (PPO). DPO eliminates the need for a separate reward model, thereby avoiding the cost of training it. Additionally, DPO requires significantly less data to achieve performance comparable to PPO.

Fine-tuning a language model using DPO consists of two steps:

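  1. Gather a preference dataset that pairs a preferred (chosen) and a non-preferred (rejected) generation for each prompt.
  2. Optimize the model directly by minimizing the DPO loss, which is equivalent to maximizing the log-likelihood of the observed preferences.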

To learn more about the DPO algorithm, refer to the following whitepaper.
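
To make the second step concrete, the following sketch computes the sigmoid DPO loss for a single preference pair from sequence log-probabilities. It is a plain-Python illustration, not the TRL implementation; the function name dpo_sigmoid_loss and the example log-probabilities are assumptions made for this sketch.

```python
import math


def dpo_sigmoid_loss(policy_chosen_logp, policy_rejected_logp,
                     ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """Sigmoid DPO loss for one (chosen, rejected) pair of sequence log-probabilities."""
    # Implicit rewards: scaled log-ratio between the trained policy and the reference model
    chosen_reward = beta * (policy_chosen_logp - ref_chosen_logp)
    rejected_reward = beta * (policy_rejected_logp - ref_rejected_logp)
    margin = chosen_reward - rejected_reward
    # -log(sigmoid(margin)), written in a numerically stable form
    if margin > 0:
        return math.log1p(math.exp(-margin))
    return -margin + math.log1p(math.exp(margin))


# Policy identical to the reference: the margin is 0 and the loss equals log(2)
baseline = dpo_sigmoid_loss(-12.0, -12.0, -12.0, -12.0)

# Policy that raised the chosen response and lowered the rejected one (made-up values)
improved = dpo_sigmoid_loss(-10.0, -20.0, -12.0, -12.0)

print(f"baseline={baseline:.4f} improved={improved:.4f}")
```

The loss falls as the policy assigns relatively more probability to the chosen response than the reference does, and beta scales how strongly that difference counts.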

Expected data format

The DPO trainer expects a very specific format for the dataset, which contains sentence pairs where one sentence is a chosen response and the other is a rejected response. This is represented as a Python dictionary with three keys:

  • prompt – Consists of the context prompt given to a model at inference time for text generation
  • chosen – Contains the preferred generated response to the corresponding prompt
  • rejected – Contains the response that is not preferred or should not be the sampled response for the given prompt
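
The mapping from an annotator’s ranking to these three keys can be sketched without any model dependencies. In this illustration, the function name select_pair is hypothetical, rankings are assumed to be integers where 1 is the most preferred, and the response strings are placeholders.

```python
def select_pair(prompt, responses, rankings):
    """Build a prompt/chosen/rejected record from ranked responses.

    rankings[i] is the rank given to responses[i]; 1 means most preferred.
    """
    if len(responses) != len(rankings) or len(responses) < 2:
        raise ValueError("need at least two responses and one rank per response")
    chosen = responses[rankings.index(min(rankings))]    # best-ranked response
    rejected = responses[rankings.index(max(rankings))]  # worst-ranked response
    return {"prompt": prompt, "chosen": chosen, "rejected": rejected}


# Placeholder data: four responses ranked 3rd, 1st, 4th, and 2nd
record = select_pair(
    "<prompt>",
    ["response A", "response B", "response C", "response D"],
    [3, 1, 4, 2],
)
print(record)
```

Using the minimum and maximum rank rather than fixed values keeps the selection valid if you collect more or fewer than four responses per prompt.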

The following function definition illustrates how to process the data stored in Amazon S3 to create a DPO dataset using sample pairs and a prompt:

def return_prompt_and_responses(samples, index):
    # Full prompt shown to the model: company context followed by the question
    prompt = f"{samples['context']}\n\n{samples['question']}"
    # Rank 1 is the most preferred response and rank 4 the least preferred
    chosen_index = response_rankings[index]["responseRankings"].index(1)
    rejected_index = response_rankings[index]["responseRankings"].index(4)

    # apply_chat_template expects a list of messages
    prompt_messages = [
        {"role": "user", "content": prompt},
    ]
    chosen_messages = [
        {"role": "assistant", "content": samples["responses"][chosen_index]},
    ]
    rejected_messages = [
        {"role": "assistant", "content": samples["responses"][rejected_index]},
    ]

    # The prompt already starts with <|begin_of_text|>, so strip it from the completions
    return {
        "prompt": tokenizer.apply_chat_template(prompt_messages, tokenize=False),
        "chosen": tokenizer.apply_chat_template(chosen_messages, tokenize=False).replace('<|begin_of_text|>', ''),
        "rejected": tokenizer.apply_chat_template(rejected_messages, tokenize=False).replace('<|begin_of_text|>', ''),
    }

Here is an example sentence pair:

You split the DPO trainer dataset into train and test samples using an 80/20 split and tokenize the dataset in preparation for DPO fine-tuning:

dataset = prepared_dataset.train_test_split(test_size=0.2)

dataset["train"].to_json(
    os.path.join(sample_files_path, "processed_human_feedback", "train_dataset.json"),
    orient="records"
)

dataset["test"].to_json(
    os.path.join(sample_files_path, "processed_human_feedback", "test_dataset.json"),
    orient="records"
)
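
If you need the same 80/20 split on every run, fix the random seed. The following dependency-free sketch shows the idea; the function name split_records and the stand-in records are hypothetical. The train_test_split method in the Hugging Face datasets library also accepts a seed argument for the same purpose.

```python
import random


def split_records(records, test_size=0.2, seed=42):
    """Shuffle a copy of the records with a fixed seed and return (train, test) lists."""
    if not 0.0 < test_size < 1.0:
        raise ValueError("test_size must be between 0 and 1")
    shuffled = list(records)
    random.Random(seed).shuffle(shuffled)
    n_test = max(1, round(len(shuffled) * test_size))
    return shuffled[n_test:], shuffled[:n_test]


records = [{"id": i} for i in range(10)]  # stand-in for prompt/chosen/rejected records
train, test = split_records(records)
print(len(train), len(test))
```

A fixed seed makes evaluation results comparable across fine-tuning runs because the holdout examples stay the same.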

Supervised fine-tuning using DPO

Now that the dataset is formatted for the DPO trainer, you can use the train and test datasets prepared earlier to initiate the DPO model fine-tuning. Meta Llama 3 8B is small by LLM standards, but full fine-tuning in fp16 or fp32 must also hold gradients and optimizer states in GPU memory, which leaves little headroom even on a SageMaker ML instance like ml.g5.48xlarge. You can instead use PEFT with DPO to fine-tune Meta Llama 3 8B’s responses based on human preferences. PEFT is a family of fine-tuning methods that train only a small number of parameters. With LoRA, the method configured here, the pre-trained weights stay frozen and small low-rank adapter matrices attached to the selected layers are trained instead. By doing so, PEFT can significantly reduce the computation and memory required for fine-tuning. See the following code:

# configure PEFT module
peft_config = LoraConfig(
    r=512,
    lora_alpha=1024,
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
    target_modules="all-linear",
)
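
To get a feel for how the rank r affects the adapter size, the following sketch counts the parameters LoRA adds to a single linear layer. It assumes a square 4096 x 4096 projection (4096 is the hidden size of Meta Llama 3 8B); the function name lora_trainable_params is hypothetical, and real layers in the model have varying shapes.

```python
def lora_trainable_params(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters LoRA adds to one linear layer of shape (d_out, d_in)."""
    # LoRA learns two matrices: A with shape (r, d_in) and B with shape (d_out, r)
    return r * d_in + d_out * r


hidden = 4096            # hidden size of Meta Llama 3 8B
full = hidden * hidden   # parameters in the frozen 4096 x 4096 weight

for rank in (16, 512):
    added = lora_trainable_params(hidden, hidden, rank)
    print(f"r={rank}: {added:,} adapter params vs {full:,} frozen ({added / full:.2%})")
```

The adapter grows linearly with r, so lowering r is a direct way to reduce trainable parameters and memory if the configuration above does not fit on your instance.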

For a full list of LoraConfig arguments, refer to LoRA. At a high level, you need to initialize the DPOTrainer with the following components: the model you want to train, a reference model (ref_model) used to calculate the implicit rewards of the preferred and rejected responses, the beta hyperparameter, which scales the implicit rewards and controls how far the trained model may diverge from the reference model (a higher beta means less divergence), and a dataset containing prompt, chosen, and rejected responses. If ref_model=None, the trainer will create a reference model with the same architecture as the input model to be optimized. See the following code:

from trl import DPOConfig, DPOTrainer

dpo_model_dir = "/path/to/save/dpo/model"

args = DPOConfig(
    output_dir=dpo_model_dir,               # directory to save and repository id
    num_train_epochs=5,                     # number of training epochs
    per_device_train_batch_size=2,
    gradient_accumulation_steps=1,
    gradient_checkpointing=True,            # use gradient checkpointing to save memory
    optim = "adamw_torch_fused",            # use fused adamw optimizer
    learning_rate=1e-5,                     # learning rate
    max_grad_norm=0.3,                      # max gradient norm based on QLoRA paper
    warmup_ratio=0.1,                       # warmup ratio based on QLoRA paper
    lr_scheduler_type="cosine",             # use cosine learning rate scheduler
    logging_steps=10,                       
    save_steps=10,                         # when to save checkpoint
    evaluation_strategy="steps",            
    eval_steps=100,
    bf16=True,                              # use bfloat16 precision
    tf32=True,                              # use tf32 precision
    push_to_hub=False,                      # do not push the model to the hub
    report_to='tensorboard',
    remove_unused_columns=False
)

dpo_args = {
    "beta": 0.1,                            # The beta factor in DPO loss. Higher beta means less divergence
    "loss_type": "sigmoid"                  # The loss type for DPO.
}

trainer = DPOTrainer(
    model,
    ref_model=None,
    peft_config=peft_config,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    max_length=max_seq_length,
    max_prompt_length=prompt_length,
    beta=dpo_args["beta"],
    loss_type=dpo_args["loss_type"],
)

# kick-off model training
trainer.train()

Once you start the training, you can see the status in the notebook:

When model fine-tuning is complete, save the PEFT adapter model to disk and merge it with the base model to create a newly tuned model. You can use the saved model for local inference and validation or deploy it as a SageMaker endpoint after you have gained sufficient confidence in the model’s responses.

peft_output_dir = "/path/to/save/tuned/model/"
print(f"saving peft model to: {peft_output_dir}")
trainer.save_model(output_dir=peft_output_dir)
...
...
merged_model = model.merge_and_unload()
...
...
merged_model.save_pretrained(
    new_dpo_output_dir,
    safe_serialization=True,
    max_shard_size="9GB"
)

Evaluate the fine-tuned model inside a SageMaker Studio notebook

Before you host your model for inference, verify that its response optimization aligns with user preferences. You can collect the model’s response both before and after DPO fine-tuning and compare them side by side, as shown in the following table.

The DPO Model Response column indicates the RLHF aligned model’s response post-fine-tuning, and the Rejected Model Response column refers to the model’s response to the input prompt prior to fine-tuning.

Deploy the model to a SageMaker endpoint

After you have gained sufficient confidence in your model, you can deploy it to a SageMaker endpoint for real-time inference. SageMaker endpoints are fully managed and provide auto scaling capabilities. For this post, we use DJL Serving to host the fine-tuned, DPO-aligned Meta Llama 3 8B model. To learn more about hosting your LLM using DJL Serving, refer to Deploy large models on Amazon SageMaker using DJLServing and DeepSpeed model parallel inference.

To deploy an LLM directly from your SageMaker Studio notebook using DJL Serving, complete the following steps:

  1. Upload model weights and other model artifacts to Amazon S3.
  2. Create a meta-model definition file called serving.properties. This definition file dictates how the DJL Serving container is configured for inference.

engine = DeepSpeed
option.tensor_parallel_degree = 1
option.s3url = s3://<your-bucket-name>/llama3-dpo-ft/modelweights
option.hf_access_token=hf_xx1234

  3. Create a custom inference file called model.py, which defines custom inference logic:

%%writefile llama3-serving-model/model.py

from djl_python import Input, Output
...

predictor = None


def get_model(properties):

    ...
    return generator


def handle(inputs: Input) -> Output:
    ...
    outputs = predictor(message, **generation_kwargs)[0]['generated_text'][-1]
    result = {"outputs": outputs['content']}
    return Output().add(result)

  4. Deploy the DPO fine-tuned model as a SageMaker endpoint:

from sagemaker import image_uris
from sagemaker.model import Model
from datetime import datetime

inference_image_uri = image_uris.retrieve(
    framework="djl-deepspeed",
    region=region,
    version="0.23.0"
)

...

dpo_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.2xlarge",
    endpoint_name=f"ep-{dpo_model.name}",
    container_startup_health_check_timeout=900,
    wait=False, # <-- Set to True, if you would prefer to wait 6-8 minutes for the endpoint to spin up
)

  5. Invoke the hosted model for inference using the sagemaker.Predictor class:

from sagemaker import serializers, deserializers

dpo_ft_predictor = sagemaker.Predictor(
    endpoint_name="my_custom_dpo_endpoint",
    sagemaker_session=sess,
    serializer=serializers.JSONSerializer(),
    deserializer=deserializers.JSONDeserializer(),
)
...
# invoke inference
response = dpo_ft_predictor.predict(
    {
        "inputs": content,
        "parameters": parameters
    }
)
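
Step 2 in the preceding list can also be scripted so that the Hugging Face token is not hardcoded in the notebook. The following is a minimal sketch under stated assumptions: the helper name render_serving_properties, the environment variable name HF_TOKEN, and the bucket name example-bucket are hypothetical, while the property keys are the ones shown in step 2.

```python
import os


def render_serving_properties(s3_url, tensor_parallel_degree=1, engine="DeepSpeed"):
    """Render serving.properties text; the access token is read from the environment."""
    lines = [
        f"engine = {engine}",
        f"option.tensor_parallel_degree = {tensor_parallel_degree}",
        f"option.s3url = {s3_url}",
    ]
    token = os.environ.get("HF_TOKEN")  # assumed variable name for this sketch
    if token:
        lines.append(f"option.hf_access_token={token}")
    return "\n".join(lines) + "\n"


# "example-bucket" is a placeholder; use the bucket that holds your model weights
properties_text = render_serving_properties(
    "s3://example-bucket/llama3-dpo-ft/modelweights"
)
print(properties_text)
```

You can write the returned text to serving.properties next to model.py before packaging the model artifacts.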

Clean up

After you complete your tasks in the SageMaker Studio notebook, remember to stop your JupyterLab workspace to prevent incurring additional charges. You can do this by choosing Stop next to your JupyterLab space. Additionally, you have the option to set up lifecycle configuration scripts that will automatically shut down resources when they’re not in use.

If you deployed the model to a SageMaker endpoint, run the following code at the end of the notebook to delete the endpoint:

# delete your endpoint (dpo_model is the model object deployed earlier)
sm_client.delete_endpoint(EndpointName=dpo_model.endpoint_name)
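
Deleting the endpoint leaves its endpoint configuration and model objects in your account. If you also want to remove those, the following sketch looks them up and deletes them. The helper name delete_endpoint_resources is hypothetical, and the recording stub with made-up names stands in for boto3.client("sagemaker") only so that the example runs without AWS credentials.

```python
def delete_endpoint_resources(sm_client, endpoint_name):
    """Delete an endpoint, its endpoint configuration, and the models it references."""
    endpoint = sm_client.describe_endpoint(EndpointName=endpoint_name)
    config_name = endpoint["EndpointConfigName"]
    config = sm_client.describe_endpoint_config(EndpointConfigName=config_name)
    model_names = [v["ModelName"] for v in config["ProductionVariants"] if "ModelName" in v]

    sm_client.delete_endpoint(EndpointName=endpoint_name)
    sm_client.delete_endpoint_config(EndpointConfigName=config_name)
    for name in model_names:
        sm_client.delete_model(ModelName=name)
    return {"endpoint": endpoint_name, "config": config_name, "models": model_names}


class _RecordingStubClient:
    """Offline stand-in for the SageMaker client that records the delete calls."""

    def __init__(self):
        self.calls = []

    def describe_endpoint(self, EndpointName):
        return {"EndpointConfigName": "example-config"}

    def describe_endpoint_config(self, EndpointConfigName):
        return {"ProductionVariants": [{"ModelName": "example-model"}]}

    def delete_endpoint(self, EndpointName):
        self.calls.append(("delete_endpoint", EndpointName))

    def delete_endpoint_config(self, EndpointConfigName):
        self.calls.append(("delete_endpoint_config", EndpointConfigName))

    def delete_model(self, ModelName):
        self.calls.append(("delete_model", ModelName))


stub = _RecordingStubClient()
deleted = delete_endpoint_resources(stub, "example-endpoint")
print(deleted)
```

In the notebook, pass sm_client and your endpoint name instead of the stub.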

Conclusion

Amazon SageMaker offers tools to streamline the process of fine-tuning LLMs to align with human preferences. With SageMaker Studio, you can experiment interactively with different models, questions, and fine-tuning techniques. With SageMaker Ground Truth, you can set up workflows, manage teams, and collect consistent, high-quality human feedback.

In this post, we showed how to enhance the performance of Meta Llama 3 8B Instruct by fine-tuning it using DPO on data collected with SageMaker Ground Truth. To get started, launch SageMaker Studio and run the notebook available in the following GitHub repo. Share your thoughts in the comments section!


About the Authors

Anastasia Tzeveleka is a GenAI/ML Specialist Solutions Architect at AWS. As part of her work, she helps customers build foundation models and create scalable generative AI and machine learning solutions using AWS services.

Pranav Murthy is an AI/ML Specialist Solutions Architect at AWS. He focuses on helping customers build, train, deploy and migrate machine learning (ML) workloads to SageMaker. He previously worked in the semiconductor industry developing large computer vision (CV) and natural language processing (NLP) models to improve semiconductor processes. In his free time, he enjoys playing chess and traveling.

Sundar Raghavan is an AI/ML Specialist Solutions Architect at AWS, helping customers build scalable and cost-efficient AI/ML pipelines with Human in the Loop services. In his free time, Sundar loves traveling, sports and enjoying outdoor activities with his family.
