Import a question answering fine-tuned model into Amazon Bedrock as a custom model

In this post, we provide a step-by-step approach to fine-tuning a Mistral model using SageMaker and importing it into Amazon Bedrock using the Custom Model Import feature.

Sep 30, 2024 - 22:00

Amazon Bedrock is a fully managed service that offers a choice of high-performing foundation models (FMs) from leading AI companies like AI21 Labs, Anthropic, Cohere, Meta, Mistral AI, Stability AI, and Amazon through a single API, along with a broad set of capabilities to build generative AI applications with security, privacy, and responsible AI.

Common generative AI use cases, including but not limited to chatbots, virtual assistants, conversational search, and agent assistants, use FMs to provide responses. Retrieval Augmented Generation (RAG) is a technique that optimizes the output of FMs by providing context around the questions in these use cases. Fine-tuning the FM is recommended to further optimize the output so that it follows the brand and industry voice or vocabulary.

Custom Model Import for Amazon Bedrock, in preview now, allows you to import customized FMs created in other environments, such as Amazon SageMaker, Amazon Elastic Compute Cloud (Amazon EC2) instances, and on premises, into Amazon Bedrock. This post is part of a series that demonstrates various architecture patterns for importing fine-tuned FMs into Amazon Bedrock.

In this post, we provide a step-by-step approach to fine-tuning a Mistral model using SageMaker and importing it into Amazon Bedrock using the Custom Model Import feature. We use the OpenOrca dataset to fine-tune the Mistral model and the SageMaker FMEval library to evaluate the fine-tuned model after it is imported into Amazon Bedrock.

Key features

Some of the key features of Custom Model Import for Amazon Bedrock are:

  1. This feature allows you to bring your fine-tuned models and use the fully managed, serverless capabilities of Amazon Bedrock.
  2. Currently, the feature supports the Llama 2, Llama 3, Flan, and Mistral model architectures at FP32, FP16, and BF16 precision, with further quantizations coming soon.
  3. To use this feature, you run the import process (covered later in this post) with your model weights stored in Amazon Simple Storage Service (Amazon S3).
  4. You can also import models created using Amazon SageMaker by referencing the SageMaker model Amazon Resource Name (ARN), which provides seamless integration with SageMaker.
  5. Amazon Bedrock automatically scales your model as traffic increases and, when the model is not in use, scales it down to zero, reducing your costs.
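Because only certain architectures and precisions are supported, it can help to check a model's Hugging Face config.json before starting an import. The following is a minimal sketch under stated assumptions: the mapping from the architectures listed above to Hugging Face class names (LlamaForCausalLM for Llama 2 and Llama 3, MistralForCausalLM, and T5ForConditionalGeneration for Flan-T5) is our assumption, and check_import_compatibility and load_config are hypothetical helpers, not part of any AWS SDK.

```python
import json
from pathlib import Path

# Assumption: Hugging Face class names for the supported architectures.
# Llama 2 and Llama 3 both use LlamaForCausalLM; Flan-T5 uses T5ForConditionalGeneration.
SUPPORTED_ARCHITECTURES = {
    "LlamaForCausalLM",
    "MistralForCausalLM",
    "T5ForConditionalGeneration",
}
# Supported precisions, spelled as they appear in the config.json "torch_dtype" field
SUPPORTED_DTYPES = {"float32", "float16", "bfloat16"}


def load_config(model_dir):
    """Read config.json from a local model directory (hypothetical helper)."""
    return json.loads((Path(model_dir) / "config.json").read_text())


def check_import_compatibility(config):
    """Return a list of problems found in a parsed config.json (empty if none)."""
    problems = []
    architectures = config.get("architectures") or []
    if not any(arch in SUPPORTED_ARCHITECTURES for arch in architectures):
        problems.append(f"unsupported architecture: {architectures}")
    dtype = config.get("torch_dtype")
    if dtype not in SUPPORTED_DTYPES:
        problems.append(f"unsupported precision: {dtype}")
    return problems
```

For example, check_import_compatibility(load_config("merged-model")) returns an empty list for a merged Mistral model saved in bfloat16 (the directory name is hypothetical).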

Let's dive into a use case and see how easy it is to use this feature.

Solution overview

At the time of writing, the Custom Model Import feature in Amazon Bedrock supports models following the architectures and patterns in the following figure.

In this post, we walk through the following high-level steps:

  1. Fine-tune the model using SageMaker.
  2. Import the fine-tuned model into Amazon Bedrock.
  3. Test the imported model.
  4. Evaluate the imported model using the FMEval library.

The following diagram illustrates the solution architecture.

The process includes the following steps:

  1. From a SageMaker JupyterLab notebook, we submit a SageMaker training job to fine-tune the model. This training job reads the dataset from Amazon S3 and writes the model back into Amazon S3. This model is then imported into Amazon Bedrock.
  2. To import the fine-tuned model, you can use the Amazon Bedrock console, the Boto3 library, or APIs.
  3. An import job orchestrates the process of importing the model and making it available from your AWS account.
    1. The import job copies all the model artifacts from your account into an AWS-managed S3 bucket.
  4. When the import job is complete, the fine-tuned model is made available for invocation from your AWS account.
  5. We use the SageMaker FMEval library in a SageMaker notebook to evaluate the imported model.

The copied model artifacts will remain in the Amazon Bedrock account until the custom imported model is deleted from Amazon Bedrock. Deleting model artifacts in your AWS account S3 bucket doesn’t delete the model or the related artifacts in the Amazon Bedrock managed account. You can delete an imported model from Amazon Bedrock along with all the copied artifacts using either the Amazon Bedrock console, Boto3 library, or APIs.

Additionally, all data (including the model) remains within the selected AWS Region. The model artifacts are imported into the AWS operated deployment account using a virtual private cloud (VPC) endpoint, and you can encrypt your model data using an AWS Key Management Service (AWS KMS) customer managed key.
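Encryption with a customer managed key is chosen when you create the import job. The following sketch builds the keyword arguments for the Boto3 Bedrock create_model_import_job call; it assumes the optional importedModelKmsKeyId parameter takes the KMS key ARN, and the helper name build_import_job_request is hypothetical.

```python
def build_import_job_request(job_name, model_name, role_arn, s3_uri, kms_key_arn=None):
    """Build kwargs for bedrock.create_model_import_job (hypothetical helper)."""
    request = {
        "jobName": job_name,
        "importedModelName": model_name,
        "roleArn": role_arn,
        # The model weights are read from an uncompressed directory in Amazon S3
        "modelDataSource": {"s3DataSource": {"s3Uri": s3_uri}},
    }
    if kms_key_arn:
        # Assumed parameter: encrypt the imported model with a customer managed KMS key
        request["importedModelKmsKeyId"] = kms_key_arn
    return request
```

You would pass the result as boto3.client("bedrock").create_model_import_job(**build_import_job_request(...)); leaving kms_key_arn unset keeps the default encryption.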

In the following sections, we dive deep into each of these steps to deploy, test, and evaluate the model.

Prerequisites

We use Mistral-7B-v0.3 in this post because it has an extended vocabulary compared to the previous version of the model produced by Mistral AI. This model is straightforward to fine-tune, and Mistral AI has provided example fine-tuned models. The model also supports a 32,000-token context window and is fluent in English, French, Italian, German, Spanish, and programming languages.

Mistral-7B-v0.3 is a gated model on the Hugging Face model repository. You need to review the terms and conditions and request access to the model by submitting your details.

We use Amazon SageMaker Studio to preprocess the data and fine-tune the Mistral model using a SageMaker training job. To set up SageMaker Studio, refer to Launch Amazon SageMaker Studio. Refer to the SageMaker JupyterLab documentation to set up and launch a JupyterLab notebook. You submit a SageMaker training job to fine-tune the Mistral model from the SageMaker JupyterLab notebook, which can be found in the GitHub repo.

Fine-tune the model using QLoRA

To fine-tune the Mistral model, we apply QLoRA, a Parameter-Efficient Fine-Tuning (PEFT) technique. In the provided notebook, you use the PyTorch Fully Sharded Data Parallel (FSDP) API to perform distributed model tuning, and you use supervised fine-tuning (SFT) to fine-tune the Mistral model.

Prepare the dataset

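The first step in the fine-tuning process is to prepare and format the dataset. You keep the FLAN records from OpenOrca, convert each record into a list of chat messages, and upload the train and test splits as JSON Lines files into the S3 bucket used by the SageMaker session, as shown in the following code. The conversion into the Mistral Default Instruct format happens later, inside the training job: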

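from datasets import load_dataset

# Load the OpenOrca dataset from the Hugging Face Hub
dataset = load_dataset("Open-Orca/OpenOrca")
# Keep only the FLAN-derived records (their IDs contain "flan")
flan_dataset = dataset.filter(lambda example: "flan" in example["id"])
# Sample 3.5% of the records for training and 1% for testing
flan_dataset = flan_dataset["train"].train_test_split(test_size=0.01, train_size=0.035)

# Convert each record into a list of chat messages and drop the original columns;
# create_conversation and training_input_path (an s3:// prefix) are defined elsewhere in the notebook
columns_to_remove = list(dataset["train"].features)
flan_dataset = flan_dataset.map(create_conversation, remove_columns=columns_to_remove, batched=False)

# Save the train and test splits to S3 as JSON Lines files
flan_dataset["train"].to_json(f"{training_input_path}/train_dataset.json", orient="records", force_ascii=False)
flan_dataset["test"].to_json(f"{training_input_path}/test_dataset.json", orient="records", force_ascii=False)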
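The post does not show create_conversation. The following hypothetical sketch shows one way it could map an OpenOrca record (fields id, system_prompt, question, and response) to the messages column that the training script reads. Folding the system prompt into the user turn is our assumption, chosen because the Mistral instruct format used here has only user and assistant turns.

```python
def create_conversation(sample):
    """Map one OpenOrca record to {"messages": [...]} (hypothetical sketch)."""
    user_content = sample["question"]
    # Assumption: prepend the optional system prompt to the user turn
    if sample.get("system_prompt"):
        user_content = f"{sample['system_prompt']}\n\n{user_content}"
    return {
        "messages": [
            {"role": "user", "content": user_content},
            {"role": "assistant", "content": sample["response"]},
        ]
    }
```

Because the map call above uses batched=False, this function receives one record at a time and returns one conversation per record.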

Within the SageMaker training job, the training script (run_fsdp_qlora.py) loads these files and applies a Mistral chat template to transform the messages into the Mistral Default Instruct format:

    ################
    # Dataset
    ################

    # The "json" loader exposes a single "train" split, so both files use split="train"
    train_dataset = load_dataset(
        "json",
        data_files=os.path.join(script_args.dataset_path, "train_dataset.json"),
        split="train",
    )
    test_dataset = load_dataset(
        "json",
        data_files=os.path.join(script_args.dataset_path, "test_dataset.json"),
        split="train",
    )

    ################
    # Model & Tokenizer
    ################

    # Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(script_args.model_id, use_fast=True)
    # Reuse the end-of-sequence token as the padding token
    tokenizer.pad_token = tokenizer.eos_token
    # MISTRAL_CHAT_TEMPLATE is a Jinja chat template defined elsewhere in the script
    tokenizer.chat_template = MISTRAL_CHAT_TEMPLATE

    # Render each list of messages into a single [INST] ... [/INST] training string
    def template_dataset(examples):
        return {"text": tokenizer.apply_chat_template(examples["messages"], tokenize=False)}

    train_dataset = train_dataset.map(template_dataset, remove_columns=["messages"])
    test_dataset = test_dataset.map(template_dataset, remove_columns=["messages"])
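The MISTRAL_CHAT_TEMPLATE string itself is not shown in the post. As a rough illustration of what the rendered text looks like, the following pure-Python sketch approximates the Mistral Default Instruct layout (an assumption: exact whitespace and special tokens in the real template may differ). render_mistral_instruct is a hypothetical helper, not the tokenizer's API.

```python
def render_mistral_instruct(messages, bos="<s>", eos="</s>"):
    """Approximate Mistral instruct rendering of alternating user/assistant turns."""
    text = bos
    for message in messages:
        if message["role"] == "user":
            text += f"[INST] {message['content']} [/INST]"
        elif message["role"] == "assistant":
            # Each completed assistant turn is closed with the end-of-sequence token
            text += f" {message['content']}{eos}"
        else:
            raise ValueError(f"unsupported role: {message['role']}")
    return text


print(render_mistral_instruct([
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello"},
]))  # <s>[INST] Hi [/INST] Hello</s>
```

Because the template already inserts the special tokens, the trainer is later configured with add_special_tokens set to False, and the same [INST] ... [/INST] wrapping is used when invoking the imported model.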

Optimize fine-tuning using QLoRA

You optimize your fine-tuning with QLoRA, using the precision passed to the training script as SageMaker training job parameters. QLoRA is an efficient fine-tuning approach that reduces memory usage enough to fine-tune a 65-billion-parameter model on a single 48 GB GPU while preserving full 16-bit fine-tuning task performance. In the training script, you use the bitsandbytes library to set up the quantization configuration, as shown in the following code:

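    # Model
    # Compute in bfloat16 when the bf16 training argument is set, otherwise in float32
    torch_dtype = torch.bfloat16 if training_args.bf16 else torch.float32
    # Store the packed 4-bit weights in a bfloat16 container so FSDP can shard them with the other parameters
    quant_storage_dtype = torch.bfloat16

    if script_args.use_qlora:
        print(f"Using QLoRA - {torch_dtype}")
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,                           # load the frozen base weights in 4-bit
            bnb_4bit_use_double_quant=True,              # also quantize the quantization constants
            bnb_4bit_quant_type="nf4",                   # 4-bit NormalFloat data type from the QLoRA paper
            bnb_4bit_compute_dtype=torch_dtype,          # dtype used for matrix multiplications
            bnb_4bit_quant_storage=quant_storage_dtype,
        )
    else:
        quantization_config = None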
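The 65-billion-parameter figure mentioned earlier comes down to simple arithmetic on weight storage. The following sketch counts only the bytes needed to hold the weights; it ignores quantization constants, activations, LoRA parameters, and optimizer state, which is why a 48 GB GPU needs headroom beyond the 4-bit total. weight_memory_gb is a hypothetical helper.

```python
def weight_memory_gb(num_params, bits_per_param):
    """Approximate memory (GB, 1 GB = 1e9 bytes) needed just to hold the weights."""
    return num_params * bits_per_param / 8 / 1e9


# A 65B-parameter model: 16-bit weights vs. 4-bit NF4 weights
print(weight_memory_gb(65e9, 16))  # 130.0
print(weight_memory_gb(65e9, 4))   # 32.5
```

At 16 bits the weights alone exceed a 48 GB GPU several times over, while the 4-bit copy fits with room left for training state.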

You use a LoRA configuration based on the QLoRA paper and Sebastian Raschka's experiments, as shown in the following code. Two key takeaways from Raschka's experiments are that QLoRA offers 33% memory savings at the cost of a 39% increase in runtime, and that LoRA should be applied to all layers to maximize model performance.

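################
# PEFT
################
# LoRA config based on QLoRA paper & Sebastian Raschka experiment
peft_config = LoraConfig(
    lora_alpha=8,                 # adapter output is scaled by lora_alpha / r
    lora_dropout=0.05,            # dropout applied to the adapter input
    r=16,                         # rank of the low-rank update matrices
    bias="none",                  # do not train any bias terms
    target_modules="all-linear",  # apply LoRA to every linear layer except the output head
    task_type="CAUSAL_LM",
)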
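To see what r=16 with target_modules="all-linear" means in practice, the following sketch counts the trainable LoRA parameters. Each adapted linear layer of shape (in, out) gets an A matrix of shape (r, in) and a B matrix of shape (out, r), so it adds r * (in + out) parameters. The layer shapes are our assumption, taken from the published Mistral-7B configuration (hidden size 4096, MLP size 14336, 8 key/value heads of 128 dimensions, 32 decoder layers).

```python
# Assumption: Mistral-7B shapes (hidden 4096, MLP 14336, 8 KV heads x 128 dims, 32 layers)
HIDDEN, INTERMEDIATE, KV_DIM, NUM_LAYERS = 4096, 14336, 1024, 32

# (in_features, out_features) of each linear layer in one decoder block
LINEAR_SHAPES = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}


def lora_trainable_params(r, shapes=LINEAR_SHAPES, num_layers=NUM_LAYERS):
    """Count LoRA parameters: each adapted layer adds A (r x in) and B (out x r)."""
    per_layer = sum(r * (d_in + d_out) for d_in, d_out in shapes.values())
    return per_layer * num_layers


print(lora_trainable_params(r=16))  # 41943040
scaling = 8 / 16  # lora_alpha / r = 0.5, the factor applied to the adapter output
```

Roughly 42 million trainable parameters is well under 1% of the base model, which is why only the adapters, not the full weights, need gradients and optimizer state.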

You use the SFTTrainer class from the Hugging Face TRL library to fine-tune the Mistral model:

    ################
    # Training
    ################
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        dataset_text_field="text",
        eval_dataset=test_dataset,
        peft_config=peft_config,
        max_seq_length=script_args.max_seq_length,
        tokenizer=tokenizer,
        packing=True,
        dataset_kwargs={
            "add_special_tokens": False,  # We template with special tokens
            "append_concat_token": False,  # No need to add additional separator token
        },
    )
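Setting packing=True concatenates several short examples into each training sequence instead of padding every example to max_seq_length. The following is a simplified illustration of the idea, not TRL's implementation (which buffers examples and may handle the leftover tokens differently). No separator token is inserted between examples, matching append_concat_token=False, because each templated example already starts with <s> and ends with </s>.

```python
def pack_sequences(token_id_lists, max_seq_length):
    """Concatenate tokenized examples and cut the stream into full max_seq_length blocks."""
    stream = [token for ids in token_id_lists for token in ids]
    # Drop the incomplete tail so every block has exactly max_seq_length tokens
    num_blocks = len(stream) // max_seq_length
    return [stream[i * max_seq_length:(i + 1) * max_seq_length] for i in range(num_blocks)]


print(pack_sequences([[1, 2, 3], [4, 5], [6, 7, 8, 9]], 4))  # [[1, 2, 3, 4], [5, 6, 7, 8]]
```

Packing wastes no compute on padding tokens, at the cost of examples sometimes being split across sequence boundaries.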

At the time of writing, the Custom Model Import feature for Amazon Bedrock supports only models whose adapters have been merged into the base model. Let's look at how to merge the adapter with the base model next.

Merge the adapters

Adapters are small trainable modules added to the layers of a pre-trained network. In QLoRA, they are trained by backpropagating gradients through a frozen, 4-bit quantized pre-trained language model into the low-rank adapters. To import the Mistral model into Amazon Bedrock, the adapters need to be merged with the base model and the result saved in Safetensors format. Use the following code to merge the adapters into the base model and save the merged model in Safetensors format:

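        # Reload the base model plus the trained LoRA adapter in fp16 (unquantized) so the weights can be merged
        model = AutoPeftModelForCausalLM.from_pretrained(
            training_args.output_dir,
            low_cpu_mem_usage=True,
            torch_dtype=torch.float16
        )
        # Fold the LoRA weights into the base weights and remove the PEFT wrappers
        model = model.merge_and_unload()
        # Write the merged model as .safetensors shards of at most 2 GB each
        model.save_pretrained(
            sagemaker_save_dir, safe_serialization=True, max_shard_size="2GB"
        )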
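Conceptually, merging replaces each adapted weight W with W + (lora_alpha / r) * B @ A, so the merged layer produces the same output as the base layer plus the scaled adapter path. The following toy sketch (pure Python lists, a 2x2 layer, and a rank-1 adapter; all values are made up for illustration) checks that equivalence. It is an illustration of the LoRA update, not PEFT's merge_and_unload code.

```python
def matmul(a, b):
    """Multiply matrices stored as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def matadd(a, b, scale=1.0):
    """Return a + scale * b, element-wise."""
    return [[x + scale * y for x, y in zip(row_a, row_b)] for row_a, row_b in zip(a, b)]


def merge_lora(weight, lora_a, lora_b, lora_alpha, r):
    """Fold an adapter into the base weight: W' = W + (lora_alpha / r) * (B @ A)."""
    return matadd(weight, matmul(lora_b, lora_a), scale=lora_alpha / r)


# Toy layer with 2 inputs and 2 outputs; rank-1 adapter with lora_alpha=2, so the scale is 2.0
W = [[1.0, 0.0], [0.0, 1.0]]
A = [[1.0, 2.0]]    # shape (r, in_features)
B = [[0.5], [1.0]]  # shape (out_features, r)
x = [[1.0], [1.0]]  # input column vector

merged = merge_lora(W, A, B, lora_alpha=2, r=1)
# Unmerged forward pass: base output plus scaled adapter output
unmerged_out = matadd(matmul(W, x), matmul(B, matmul(A, x)), scale=2 / 1)
print(merged)                             # [[2.0, 2.0], [2.0, 5.0]]
print(matmul(merged, x) == unmerged_out)  # True
```

Because the merged weights have the same shapes as the original ones, the saved model is a plain Mistral checkpoint with no PEFT dependency, which is what the import job expects.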

To import the Mistral model into Amazon Bedrock, the model needs to be in an uncompressed directory within an S3 bucket accessible by the Amazon Bedrock service role used in the import job.
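SageMaker training jobs package the model output as model.tar.gz by default, so it is worth checking the merged directory before pointing an import job at it. The following sketch inspects a directory listing. The list of expected files (config.json, tokenizer_config.json, a tokenizer.json or tokenizer.model file, and .safetensors weights) is our assumption; confirm it against the Custom Model Import documentation. find_import_problems is a hypothetical helper.

```python
import fnmatch

# Assumption: files the import job expects in the model directory
REQUIRED_FILES = ("config.json", "tokenizer_config.json")
TOKENIZER_FILES = ("tokenizer.json", "tokenizer.model")


def find_import_problems(filenames):
    """Return problems found in a merged-model directory listing (empty if it looks ready)."""
    names = set(filenames)
    problems = [f"missing {name}" for name in REQUIRED_FILES if name not in names]
    if not names.intersection(TOKENIZER_FILES):
        problems.append("missing tokenizer.json or tokenizer.model")
    if not fnmatch.filter(names, "*.safetensors"):
        problems.append("missing *.safetensors weight files")
    # A compressed archive (such as SageMaker's default model.tar.gz) must be extracted first
    if fnmatch.filter(names, "*.tar.gz"):
        problems.append("found a .tar.gz archive; upload the extracted files instead")
    return problems
```

For example, find_import_problems(os.listdir(sagemaker_save_dir)) can run at the end of the training script, before the files are uploaded to S3.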

Import the fine-tuned model into Amazon Bedrock

Now that you have fine-tuned the model, you can import the model into Amazon Bedrock. In this section, we demonstrate how to import the model using the Amazon Bedrock console or the SDK.

Import the model using the Amazon Bedrock console

To import the model using the Amazon Bedrock console, see Import a model with Custom Model Import. You use the Import model page as shown in the following screenshot to import the model from the S3 bucket.

After you successfully import the fine-tuned model, you can see the model listed on the Amazon Bedrock console.

Import the model using the SDK

The AWS Boto3 library supports importing custom models into Amazon Bedrock. You can use the following code to import a fine-tuned model from within the notebook into Amazon Bedrock. The call is asynchronous: create_model_import_job starts an import job and returns its ARN without waiting for the import to finish.

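import datetime

import boto3

# Fill in your AWS Region, a name for the imported model, and the ARN of the
# service role that Amazon Bedrock assumes to read the model from Amazon S3
br_client = boto3.client('bedrock', region_name='')
pt_model_nm = ""
role_arn = "<>"
# Unique job name: model name plus a YYYYMMDDHHMMSS timestamp
pt_imp_jb_nm = f"{pt_model_nm}-{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}"
# pt_pubmed_model_s3_path holds the S3 URI of the uncompressed, merged model directory
pt_model_src = {"s3DataSource": {"s3Uri": pt_pubmed_model_s3_path}}
resp = br_client.create_model_import_job(jobName=pt_imp_jb_nm,
                                         importedModelName=pt_model_nm,
                                         roleArn=role_arn,
                                         modelDataSource=pt_model_src)
# The job runs asynchronously; keep its ARN to track progress
job_arn = resp["jobArn"]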
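Because the import runs asynchronously, you typically poll the job until it finishes before testing the model. The following sketch assumes the Boto3 Bedrock client's get_model_import_job(jobIdentifier=...) call and a "status" field that stays "InProgress" until the job becomes "Completed" or "Failed". The client is passed in as a parameter so the helper (wait_for_import_job, a hypothetical name) can be exercised without AWS access.

```python
import time


def wait_for_import_job(bedrock_client, job_arn, poll_seconds=60,
                        timeout_seconds=3600, sleep=time.sleep):
    """Poll an import job until it leaves InProgress; return the final status."""
    deadline = time.monotonic() + timeout_seconds
    while True:
        status = bedrock_client.get_model_import_job(jobIdentifier=job_arn)["status"]
        if status != "InProgress":
            return status
        if time.monotonic() >= deadline:
            raise TimeoutError(f"import job {job_arn} still InProgress after {timeout_seconds}s")
        sleep(poll_seconds)
```

In the notebook, wait_for_import_job(br_client, resp["jobArn"]) blocks until the job completes; a "Failed" status means the job response should be inspected for the failure reason.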

Test the imported model

Now that you have imported the fine-tuned model into Amazon Bedrock, you can test the model. In this section, we demonstrate how to test the model using the Amazon Bedrock console or the SDK.

Test the model on the Amazon Bedrock console

You can test the imported model using an Amazon Bedrock playground, as illustrated in the following screenshot.

Test the model using the SDK

You can also use the Amazon Bedrock InvokeModel API to run the fine-tuned imported model, as shown in the following code:

import json

import boto3
from botocore.exceptions import ClientError

client = boto3.client("bedrock-runtime", region_name="us-west-2")
model_id = "<>"  # ARN of the imported model


def call_invoke_model_and_print(native_request):
    """Send a native request to the imported model and print the generated text."""
    request = json.dumps(native_request)

    try:
        # Invoke the model with the request.
        response = client.invoke_model(modelId=model_id, body=request)
        model_response = json.loads(response["body"].read())

        # The generated text is returned under outputs[0].text
        response_text = model_response["outputs"][0]["text"]
        print(response_text)
    except ClientError as e:
        print(f"ERROR: Can't invoke '{model_id}'. Reason: {e}")
        raise


prompt = "will there be a season 5 of shadowhunters"
# Wrap the prompt in the Mistral instruct tags used during fine-tuning
formatted_prompt = f"[INST] {prompt} [/INST]"
native_request = {
    "prompt": formatted_prompt,
    "max_tokens": 64,
    "top_p": 0.9,
    "temperature": 0.91,
}
call_invoke_model_and_print(native_request)

The custom Mistral model that you imported into Amazon Bedrock supports the temperature, top_p, and max_gen_len parameters when you invoke the model for inference. The inference parameters top_k, max_seq_len, max_batch_size, and max_new_tokens are not supported for a custom Mistral fine-tuned model.
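Because Amazon Bedrock scales an imported model down to zero when it is not in use, the first request after an idle period may fail while the model loads. The following sketch retries with exponential backoff. It assumes such failures surface as a botocore ClientError whose error code is "ModelNotReadyException"; verify the code against the Amazon Bedrock Runtime API reference, and note that Boto3 may already retry some errors depending on your retry configuration. Both helper names are hypothetical.

```python
import time


def is_model_not_ready(error):
    """Return True for the error code assumed to signal a still-loading imported model."""
    code = getattr(error, "response", {}).get("Error", {}).get("Code")
    return code == "ModelNotReadyException"


def invoke_with_retry(invoke, is_retryable=is_model_not_ready, max_attempts=5,
                      base_delay=2.0, sleep=time.sleep):
    """Call invoke(); after a retryable error, wait base_delay * 2**attempt seconds and retry."""
    for attempt in range(max_attempts):
        try:
            return invoke()
        except Exception as error:
            # Give up on non-retryable errors or once the attempts are used up
            if not is_retryable(error) or attempt == max_attempts - 1:
                raise
            sleep(base_delay * 2 ** attempt)
```

For example, invoke_with_retry(lambda: client.invoke_model(modelId=model_id, body=json.dumps(native_request))) waits 2, 4, 8, and 16 seconds between attempts before giving up.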

Evaluate the imported model

Now that you have imported and tested the model, let's evaluate it using the SageMaker FMEval library. For more details, refer to Evaluate Bedrock Imported Models. The FMEval library provides out-of-the-box evaluation algorithms for metrics such as accuracy, QA Accuracy, and others detailed in the FMEval documentation. Because you fine-tuned the Mistral model for question answering, you can use the QA Accuracy algorithm, which reports F1 Score, Exact Match Score, Quasi Exact Match Score, Precision Over Words, and Recall Over Words. The key metrics for question answering are Exact Match, Quasi-Exact Match, and F1 over words, each computed by comparing the model's predicted answers against the ground truth answers. The following code runs the QA Accuracy evaluation:

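from fmeval.constants import MIME_TYPE_JSONLINES
from fmeval.data_loaders.data_config import DataConfig
from fmeval.eval_algorithms.qa_accuracy import QAAccuracy
from fmeval.model_runners.bedrock_model_runner import BedrockModelRunner

# Each JSON Lines record provides a "question" (model input) and an "answer" (ground truth)
config = DataConfig(
    dataset_name="trex_sample",
    dataset_uri="data/test_dataset.json",
    dataset_mime_type=MIME_TYPE_JSONLINES,
    model_input_location="question",
    target_output_location="answer"
)
# Send each prompt to the imported model and read the answer from outputs[0].text
bedrock_model_runner = BedrockModelRunner(
    model_id=model_id,
    output='outputs[0].text',
    content_template='{"prompt": $prompt, "max_tokens": 500}',
)

eval_algo = QAAccuracy()
# Wrap each question in Mistral instruct tags and save the per-record results
eval_output = eval_algo.evaluate(model=bedrock_model_runner, dataset_config=config,
                                 prompt_template="[INST]$model_input[/INST]", save=True)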

You can get the consolidated metrics for the imported model as follows:

for op in eval_output:
    print(f"Eval Name: {op.eval_name}")
    for score in op.dataset_scores:
        print(f"{score.name} : {score.value}")
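To interpret these scores, the following sketch shows what exact match, quasi-exact match, and F1 over words compute for a single prediction. It uses SQuAD-style normalization (lowercasing and removing punctuation, the articles a/an/the, and extra whitespace) as an approximation; FMEval's own normalization details may differ. The function names are hypothetical.

```python
import re
import string
from collections import Counter


def normalize(text):
    """SQuAD-style normalization: lowercase, drop punctuation and articles, collapse spaces."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction, target):
    return float(prediction == target)


def quasi_exact_match(prediction, target):
    return float(normalize(prediction) == normalize(target))


def f1_over_words(prediction, target):
    """Harmonic mean of precision over words and recall over words."""
    pred_tokens, target_tokens = normalize(prediction).split(), normalize(target).split()
    overlap = sum((Counter(pred_tokens) & Counter(target_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(target_tokens)
    return 2 * precision * recall / (precision + recall)
```

For instance, "The Paris." scores 0 on exact match against "paris" but 1 on quasi-exact match, and "the cat sat" against "cat sat down" has precision 1.0, recall 2/3, and an F1 of about 0.8. The dataset scores printed above are averages of such per-record values.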

Clean up

To delete the imported model from Amazon Bedrock, navigate to the model on the Amazon Bedrock console. On the options menu (three dots), choose Delete.
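You can also delete the imported model with the SDK. Deleting the source weights in your own S3 bucket does not remove the copy that Amazon Bedrock manages, so this call is what releases the imported model. The sketch assumes the Boto3 Bedrock client's delete_imported_model(modelIdentifier=...) operation, which we assume accepts the imported model's name or ARN; the client is passed in so the helper (a hypothetical name) can be tested without AWS access.

```python
def delete_imported_model(bedrock_client, model_identifier):
    """Delete an imported model (and the artifacts Bedrock copied) by name or ARN."""
    bedrock_client.delete_imported_model(modelIdentifier=model_identifier)
    print(f"Deleted imported model: {model_identifier}")
```

In the notebook, delete_imported_model(br_client, pt_model_nm) removes the model imported earlier.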

To delete the SageMaker domain along with the SageMaker JupyterLab space, refer to Delete an Amazon SageMaker domain. You may also want to delete the S3 buckets where the data and model are stored. For instructions, see Deleting a bucket.

Conclusion

In this post, we explained the different aspects of fine-tuning a Mistral model using SageMaker, importing the model into Amazon Bedrock, invoking the model using both an Amazon Bedrock playground and Boto3, and then evaluating the imported model using the FMEval library. You can use this feature to import base FMs or FMs fine-tuned either on premises, on SageMaker, or on Amazon EC2 into Amazon Bedrock and use the models without any heavy lifting in your generative AI applications. Explore the Custom Model Import feature for Amazon Bedrock to deploy FMs fine-tuned for code generation tasks in a secure and scalable manner. Visit our GitHub repository to explore samples prepared for fine-tuning and importing models from various families.


About the Authors

Jay Pillai is a Principal Solutions Architect at Amazon Web Services. In this role, he functions as the Lead Architect, helping partners ideate, build, and launch Partner Solutions. As an Information Technology Leader, Jay specializes in artificial intelligence, generative AI, data integration, business intelligence, and user interface domains. He has 23 years of experience working with clients across the supply chain, legal technology, real estate, financial services, insurance, payments, and market research business domains.

Rupinder Grewal is a Senior AI/ML Specialist Solutions Architect with AWS. He currently focuses on model serving and MLOps on Amazon SageMaker. Prior to this role, he worked as a Machine Learning Engineer building and hosting models. Outside of work, he enjoys playing tennis and biking on mountain trails.

Evandro Franco is a Sr. AI/ML Specialist Solutions Architect at Amazon Web Services. He helps AWS customers overcome business challenges related to AI/ML on top of AWS. He has more than 18 years of experience working with technology, from software development, infrastructure, serverless, to machine learning.

Felipe Lopez is a Senior AI/ML Specialist Solutions Architect at AWS. Prior to joining AWS, Felipe worked with GE Digital and SLB, where he focused on modeling and optimization products for industrial applications.

Sandeep Singh is a Senior Generative AI Data Scientist at Amazon Web Services, helping businesses innovate with generative AI. He specializes in generative AI, artificial intelligence, machine learning, and system design. He is passionate about developing state-of-the-art AI/ML-powered solutions to solve complex business problems for diverse industries, optimizing efficiency and scalability.

Ragha Prasad is a Principal Engineer and a founding member of Amazon Bedrock, where he has had the privilege to listen to customer needs first-hand and understands what it takes to build and launch scalable and secure Gen AI products. Prior to Bedrock, he worked on numerous products in Amazon, ranging from devices to Ads to Robotics.

Paras Mehra is a Senior Product Manager at AWS. He is focused on helping build Amazon SageMaker Training and Processing. In his spare time, Paras enjoys spending time with his family and road biking around the Bay Area.
