Inference AudioCraft MusicGen models using Amazon SageMaker

Aug 6, 2024 - 16:00

Music generation models have emerged as powerful tools that transform natural language text into musical compositions. Originating from advancements in artificial intelligence (AI) and deep learning, these models are designed to understand and translate descriptive text into coherent, aesthetically pleasing music. Their ability to democratize music production allows individuals without formal training to create high-quality music by simply describing their desired outcomes.

Generative AI models are revolutionizing music creation and consumption. Companies can take advantage of this technology to develop new products, streamline processes, and explore untapped potential, yielding significant business impact. Such music generation models enable diverse applications, from personalized soundtracks for multimedia and gaming to educational resources for students exploring musical styles and structures. They also assist artists and composers by providing new ideas and compositions, fostering creativity and collaboration.

One prominent example of a music generation model is AudioCraft MusicGen by Meta. The MusicGen code is released under the MIT license, and the model weights are released under CC-BY-NC 4.0. MusicGen can create music based on text or melody inputs, giving you better control over the output. The following diagram shows how MusicGen, a single-stage autoregressive Transformer model, can generate high-quality music based on text descriptions or audio prompts.

Music Generation Models - MusicGen Input Output flow

MusicGen uses cutting-edge AI technology to generate diverse musical styles and genres, catering to various creative needs. Unlike traditional approaches that cascade several models, for example hierarchically or through upsampling, MusicGen is a single language model that operates over several streams of compressed, discrete music representations (tokens). This streamlined approach gives users precise control over generating high-quality mono and stereo samples tailored to their preferences.

MusicGen models can be used across education, content creation, and music composition. They can enable students to experiment with diverse musical styles, generate custom soundtracks for multimedia projects, and create personalized music compositions. Additionally, MusicGen can assist musicians and composers, fostering creativity and innovation.

This post demonstrates how to deploy MusicGen, a music generation model, on Amazon SageMaker using asynchronous inference. We specifically focus on text-conditioned generation of music samples using MusicGen models.

Solution overview

Generative AI models that produce audio, music, or video can be computationally intensive and time-consuming. Such models are a good fit for asynchronous inference, which queues incoming requests and processes them asynchronously. Our solution deploys the AudioCraft MusicGen model, sourced from the Hugging Face Model Hub, to SageMaker infrastructure behind a SageMaker asynchronous inference endpoint.

The following solution architecture diagram shows how a user can generate music using natural language text as an input prompt by using AudioCraft MusicGen models deployed on SageMaker.

MusicGen on Amazon SageMaker Asynchronous Inference

The following steps describe the workflow from the moment the user enters the input to the point where music is generated as output:

  1. The user invokes the SageMaker asynchronous endpoint using an Amazon SageMaker Studio notebook.
  2. The input payload is uploaded to an Amazon Simple Storage Service (Amazon S3) bucket for inference. The payload consists of both the prompt and the music generation parameters. The generated music will be downloaded from the S3 bucket.
  3. The facebook/musicgen-large model is deployed to a SageMaker asynchronous endpoint, which runs inference for music generation.
  4. The Hugging Face Inference Containers image is used as a base image. We use an image that supports PyTorch 2.1.0 with the Hugging Face Transformers framework.
  5. The SageMaker HuggingFaceModel is deployed to a SageMaker asynchronous endpoint.
  6. The Hugging Face model (facebook/musicgen-large) is uploaded to Amazon S3 during deployment. Also, during inference, the generated outputs are uploaded to Amazon S3.
  7. We use Amazon Simple Notification Service (Amazon SNS) topics to send success and failure notifications, as defined in the SageMaker asynchronous inference configuration.

Prerequisites

Make sure you have the following prerequisites in place:

  1. Confirm you have access to the AWS Management Console to create and manage resources in SageMaker, AWS Identity and Access Management (IAM), and other AWS services.
  2. If you’re using SageMaker Studio for the first time, create a SageMaker domain. Refer to Quick setup to Amazon SageMaker to create a SageMaker domain with default settings.
  3. Obtain the AWS Deep Learning Containers for Large Model Inference from pre-built HuggingFace Inference Containers.

Deploy the solution

To deploy the AudioCraft MusicGen model to a SageMaker asynchronous inference endpoint, complete the following steps:

  1. Create a model serving package for MusicGen.
  2. Create a Hugging Face model.
  3. Define asynchronous inference configuration.
  4. Deploy the model on SageMaker.

We detail each of the steps and show how to deploy the MusicGen model onto SageMaker. For the sake of brevity, only significant code snippets are included. The full source code for deploying the MusicGen model is available in the GitHub repo.

Create a model serving package for MusicGen

To deploy MusicGen, we first create a model serving package. The model package contains a requirements.txt file that lists the necessary Python packages to be installed to serve the MusicGen model. The model package also contains an inference.py script that holds the logic for serving the MusicGen model.

Let’s look at the key functions used in serving the MusicGen model for inference on SageMaker:

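from transformers import MusicgenForConditionalGeneration

def model_fn(model_dir):
    '''Loads the MusicGen model from the Hugging Face Model Hub.'''
    # model_dir is not used here: the weights are downloaded from the Hub at load time
    model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-large")
    return model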

The model_fn function loads the MusicGen model facebook/musicgen-large from the Hugging Face Model Hub. We rely on the MusicgenForConditionalGeneration class from the Transformers library to load the pre-trained MusicGen model.

You can also refer to musicgen-large-load-from-s3/deploy-musicgen-large-from-s3.ipynb, which demonstrates the best practice of downloading the model from the Hugging Face Hub to Amazon S3 once and reusing the model artifacts for future deployments. Instead of downloading the model from Hugging Face every time we deploy or scale out, we store it in Amazon S3 and reuse it for deployment and during scaling activities. This can speed up model loading, especially for large models, because the artifacts are fetched from within AWS rather than over the internet from a website outside of AWS. This practice also keeps deployments consistent: the same model artifacts in Amazon S3 can be deployed across staging and production environments.
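
With the S3-based approach, the main change on the serving side is where model_fn loads the weights from. The following is a minimal sketch of that choice, assuming the model.tar.gz in S3 contains a saved Transformers checkpoint (including config.json) that SageMaker extracts into model_dir; resolve_model_source is a hypothetical helper, not code from the repo.

```python
import os

HUB_MODEL_ID = "facebook/musicgen-large"


def resolve_model_source(model_dir, hub_model_id=HUB_MODEL_ID):
    """Return what from_pretrained should load: local artifacts or a Hub model ID.

    Hypothetical helper: prefers a checkpoint shipped inside model.tar.gz
    (which SageMaker extracts into model_dir) and falls back to the Hub.
    """
    # save_pretrained always writes config.json, so use it as the marker
    if model_dir and os.path.isfile(os.path.join(model_dir, "config.json")):
        return model_dir
    return hub_model_id


# Inside model_fn (requires transformers):
#   model = MusicgenForConditionalGeneration.from_pretrained(resolve_model_source(model_dir))
```

Falling back to the Hub keeps the same inference script usable for both deployment variants.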

The predict_fn function uses the data provided during the inference request and the model loaded through model_fn:

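from transformers import AutoProcessor

# _process_input is a helper from the repo that extracts the prompts and generation parameters
texts, generation_params = _process_input(data)
processor = AutoProcessor.from_pretrained("facebook/musicgen-large")
inputs = processor(
    text=texts,
    padding=True,
    return_tensors="pt",
)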

Using the information available in the data dictionary, we process the input data to obtain the prompt and generation parameters used to generate the music. We discuss the generation parameters in more detail later in this post.
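
_process_input itself is not shown in this post. The following hypothetical sketch shows what such a helper might do, assuming the payload shape used later in this post (a texts list plus an optional generation_params dictionary) and falling back to the deployment defaults listed in the generation-parameter section. The repo's actual helper may differ.

```python
# Defaults as described for this deployment (guidance_scale, max_new_tokens,
# do_sample, temperature); the dictionary itself is illustrative.
DEFAULT_GENERATION_PARAMS = {
    "guidance_scale": 3,
    "max_new_tokens": 256,
    "do_sample": True,
    "temperature": 1,
}


def process_input(data):
    """Hypothetical stand-in for _process_input: returns (texts, generation_params)."""
    texts = data.get("texts")
    if isinstance(texts, str):
        texts = [texts]  # accept a single prompt as a convenience
    if not texts or not all(isinstance(t, str) and t.strip() for t in texts):
        raise ValueError("'texts' must be a non-empty list of non-empty strings")
    # Caller-supplied values override the deployment defaults
    params = {**DEFAULT_GENERATION_PARAMS, **(data.get("generation_params") or {})}
    return texts, params
```

For example, process_input({"texts": ["lo-fi piano"], "generation_params": {"max_new_tokens": 512}}) keeps the caller's 512 tokens and fills in guidance_scale=3 from the defaults.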

import torch

# Use the GPU when one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
audio_values = model.generate(**inputs.to(device),
                              **generation_params)

We move the model to the device (a GPU if one is available) and pass the processed inputs and the generation parameters to model.generate. This produces the music as a three-dimensional Torch tensor of shape (batch_size, num_channels, sequence_length).

sampling_rate = model.config.audio_encoder.sampling_rate
disk_wav_locations = _write_wavs_to_disk(sampling_rate, audio_values)
# Upload wavs to S3
result_dict["generated_outputs_s3"] = _upload_wav_files(disk_wav_locations, bucket_name)
# Clean up disk
for wav_on_disk in disk_wav_locations:
    _delete_file_on_disk(wav_on_disk)

We then write the tensor to .wav files, upload these files to Amazon S3, and delete the local copies. The S3 URIs of the uploaded .wav files are returned in the response.
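
The _write_wavs_to_disk helper is defined in the repo. As a hedged illustration of that step only, the following sketch writes an array shaped (batch_size, num_channels, sequence_length) to 16-bit PCM .wav files with Python's standard wave module. The function name and file names are hypothetical, and the repo may use a different audio library.

```python
import os
import struct
import tempfile
import wave


def write_wavs_to_disk(sampling_rate, audio_values, out_dir=None):
    """Hypothetical stand-in for _write_wavs_to_disk using only the standard library.

    audio_values: nested sequence shaped (batch_size, num_channels, sequence_length)
    holding floats in [-1.0, 1.0]. Returns one .wav path per batch item.
    """
    out_dir = out_dir or tempfile.mkdtemp()
    paths = []
    for i, channels in enumerate(audio_values):
        num_channels = len(channels)
        num_samples = len(channels[0])
        frames = bytearray()
        # WAV stores interleaved frames: sample t of every channel, then t + 1
        for t in range(num_samples):
            for ch in range(num_channels):
                x = max(-1.0, min(1.0, float(channels[ch][t])))  # clip to [-1, 1]
                frames += struct.pack("<h", int(x * 32767))  # little-endian int16
        path = os.path.join(out_dir, f"musicgen_out_{i}.wav")
        with wave.open(path, "wb") as wf:
            wf.setnchannels(num_channels)
            wf.setsampwidth(2)  # 2 bytes per sample = 16-bit PCM
            wf.setframerate(sampling_rate)
            wf.writeframes(bytes(frames))
        paths.append(path)
    return paths
```

With the real model output you would pass audio_values.cpu().numpy() and model.config.audio_encoder.sampling_rate. The per-sample Python loop is slow for 30-second stereo clips, so a vectorized library is the better choice in production.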

We now upload the archive of the inference scripts (model.tar.gz) to the S3 bucket:

import boto3

musicgen_prefix = 'musicgen_large'
s3_model_key = f'{musicgen_prefix}/model/model.tar.gz'
s3_model_location = f"s3://{sagemaker_session_bucket}/{s3_model_key}"
s3 = boto3.resource("s3")
s3.Bucket(sagemaker_session_bucket).upload_file("model.tar.gz", s3_model_key)

The uploaded URI of this object on Amazon S3 will later be used to create the Hugging Face model.
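
The upload step assumes that model.tar.gz already exists locally. The following minimal sketch builds it with the standard tarfile module. It assumes the layout that the SageMaker Hugging Face inference toolkit looks for, with custom code under a code/ directory (code/inference.py and code/requirements.txt). The source directory argument is a hypothetical local folder that holds the two files.

```python
import os
import tarfile


def build_model_archive(src_dir, archive_path="model.tar.gz"):
    """Package inference.py and requirements.txt under code/ in a gzipped tarball."""
    with tarfile.open(archive_path, "w:gz") as tar:
        for name in ("inference.py", "requirements.txt"):
            path = os.path.join(src_dir, name)
            if not os.path.isfile(path):
                raise FileNotFoundError(path)
            # arcname sets the path inside the archive, independent of src_dir
            tar.add(path, arcname=f"code/{name}")
    return archive_path
```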

Create the Hugging Face model

Now we initialize HuggingFaceModel with the necessary arguments. During deployment, the model serving artifacts stored in s3_model_location are deployed. Before serving starts, the MusicGen model is downloaded from Hugging Face, as implemented in model_fn.

from sagemaker.huggingface import HuggingFaceModel

huggingface_model = HuggingFaceModel(
    name=async_endpoint_name,
    model_data=s3_model_location,  # path to your model artifacts
    role=role,  # IAM role with permissions to create an endpoint
    env={
        'TS_MAX_REQUEST_SIZE': '100000000',
        'TS_MAX_RESPONSE_SIZE': '100000000',
        'TS_DEFAULT_RESPONSE_TIMEOUT': '3600'
    },
    transformers_version="4.37",  # transformers version used
    pytorch_version="2.1",  # pytorch version used
    py_version="py310",  # python version used
)

The env argument accepts a dictionary of environment variables for the serving container. TS_MAX_REQUEST_SIZE and TS_MAX_RESPONSE_SIZE set the maximum request and response payload sizes in bytes (100,000,000 bytes, or about 100 MB, in this example). TS_DEFAULT_RESPONSE_TIMEOUT sets the response timeout, in seconds, for the model server inside the container (3,600 seconds, or 1 hour, in this example).
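
Because these values are long strings of digits that are easy to mistype, a small helper can derive them from explicit units. The following sketch is hypothetical (not in the repo). It reproduces the dictionary used above and converts every value to a string, because container environment variables are passed as strings.

```python
def model_server_env(max_request_bytes=100_000_000,
                     max_response_bytes=100_000_000,
                     response_timeout_s=3600):
    """Build the env dict for HuggingFaceModel (hypothetical helper).

    Sizes are in bytes and the timeout is in seconds; every value is
    converted to a string explicitly.
    """
    return {
        "TS_MAX_REQUEST_SIZE": str(int(max_request_bytes)),
        "TS_MAX_RESPONSE_SIZE": str(int(max_response_bytes)),
        "TS_DEFAULT_RESPONSE_TIMEOUT": str(int(response_timeout_s)),
    }
```

You could then pass env=model_server_env() to HuggingFaceModel instead of the literal dictionary.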

You can run MusicGen with the Hugging Face Transformers library from version 4.31.0 onwards; here we set transformers_version to 4.37. MusicGen requires PyTorch 2.1 or later, so we set pytorch_version to 2.1.

Define asynchronous inference configuration

Music generation using a text prompt as input can be both computationally intensive and time-consuming. When working with music generation models, it’s important to note that a single request can often take more than 60 seconds to complete, which is longer than SageMaker real-time endpoints allow for a response. Asynchronous inference in SageMaker is designed to address these demands.

SageMaker asynchronous inference queues incoming requests and processes them asynchronously, making it ideal for requests with large payload sizes (up to 1 GB), long processing times (up to 1 hour), and near real-time latency requirements. By queuing incoming requests and processing them asynchronously, this capability efficiently handles the extended processing times inherent in music generation tasks. Moreover, asynchronous inference enables seamless auto scaling, making sure that resources are allocated only when needed, leading to cost savings.
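
To make the trade-off concrete, the following hypothetical sketch encodes the asynchronous inference limits mentioned here (payloads up to 1 GB, processing up to 1 hour) alongside the real-time endpoint limits of 60 seconds and 6 MB, which come from the SageMaker documentation rather than this post. For simplicity, it treats MB and GB as binary units.

```python
# Async limits are stated in this post; the real-time limits (60 s, 6 MB)
# come from the SageMaker documentation.
MB = 1024 * 1024
REALTIME_MAX_SECONDS = 60
REALTIME_MAX_PAYLOAD_BYTES = 6 * MB
ASYNC_MAX_SECONDS = 60 * 60
ASYNC_MAX_PAYLOAD_BYTES = 1024 * MB


def choose_inference_mode(expected_seconds, payload_bytes):
    """Return 'real-time' or 'asynchronous' for a request of the given duration and size."""
    if expected_seconds <= REALTIME_MAX_SECONDS and payload_bytes <= REALTIME_MAX_PAYLOAD_BYTES:
        return "real-time"
    if expected_seconds <= ASYNC_MAX_SECONDS and payload_bytes <= ASYNC_MAX_PAYLOAD_BYTES:
        return "asynchronous"
    raise ValueError("request exceeds the asynchronous inference limits")
```

A MusicGen request that takes 90 seconds falls outside the real-time window even with a tiny JSON payload, which is why this post uses asynchronous inference.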

Before we proceed with the asynchronous inference configuration, we create SNS topics for success and failure notifications that can be used to trigger downstream tasks:

import time
import boto3
from utils.sns_client import SnsClient
sns_client = SnsClient(boto3.client("sns"))
timestamp = time.time_ns()
topic_names = [f"musicgen-large-topic-SuccessTopic-{timestamp}", f"musicgen-large-topic-ErrorTopic-{timestamp}"]

topic_arns = []
for topic_name in topic_names:
    print(f"Creating topic {topic_name}.")
    response = sns_client.create_topic(topic_name)
    topic_arns.append(response.get('TopicArn'))
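
SNS topic names must be 1 to 256 characters long and contain only ASCII letters, digits, hyphens, and underscores. The timestamped names above satisfy this, but a custom prefix might not. The following hypothetical helper reproduces the naming pattern used above and checks the result before any API call is made.

```python
import re
import time

# Standard SNS topic names: 1-256 ASCII letters, digits, hyphens, or underscores
_TOPIC_NAME_RE = re.compile(r"[A-Za-z0-9_-]{1,256}")


def make_topic_names(prefix="musicgen-large-topic", timestamp=None):
    """Return [success_name, error_name] following the pattern used above."""
    ts = time.time_ns() if timestamp is None else timestamp
    names = [f"{prefix}-SuccessTopic-{ts}", f"{prefix}-ErrorTopic-{ts}"]
    for name in names:
        if not _TOPIC_NAME_RE.fullmatch(name):
            raise ValueError(f"invalid SNS topic name: {name!r}")
    return names
```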

We now create an asynchronous inference endpoint configuration by specifying the AsyncInferenceConfig object:

from sagemaker.async_inference import AsyncInferenceConfig
from sagemaker.s3 import s3_path_join

# create async endpoint configuration
async_config = AsyncInferenceConfig(
    output_path=s3_path_join(
        "s3://", sagemaker_session_bucket, "musicgen_large/async_inference/output"
    ),  # Where our results will be stored
    # Add notification SNS topics if needed
    notification_config={
        "SuccessTopic": topic_arns[0],
        "ErrorTopic": topic_arns[1],
    },  # Notification configuration
)

The arguments to the AsyncInferenceConfig are detailed as follows:

  • output_path – The location where the output of the asynchronous inference endpoint will be stored. The files in this location will have an .out extension and will contain the details of the asynchronous inference performed by the MusicGen model.
  • notification_config – Optionally, you can associate success and error SNS topics. Dependent workflows can subscribe to these topics (for example, with an AWS Lambda function or an Amazon SQS queue) and act on the inference outcomes.
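
A subscriber typically needs only the request status and the result location from each notification. The following hypothetical sketch extracts them. The field names (inferenceId, invocationStatus, responseParameters.outputLocation, failureReason) are our reading of the message format in the SageMaker asynchronous inference documentation and should be checked against a real notification. Missing fields come back as None rather than raising an error.

```python
import json


def summarize_async_notification(sns_message):
    """Extract status and output location from an async inference SNS message.

    Accepts the raw JSON string (as delivered in the SNS Message field) or an
    already-parsed dict. Field names are assumptions to verify.
    """
    msg = json.loads(sns_message) if isinstance(sns_message, str) else sns_message
    return {
        "inference_id": msg.get("inferenceId"),
        "status": msg.get("invocationStatus"),
        "output_location": (msg.get("responseParameters") or {}).get("outputLocation"),
        "failure_reason": msg.get("failureReason"),
    }
```

In a Lambda function subscribed to the topic, the message string is available as event["Records"][0]["Sns"]["Message"].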

Deploy the model on SageMaker

With the asynchronous inference configuration defined, we can deploy the Hugging Face model, setting initial_instance_count to 1:

# deploy the endpoint
async_predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,
    async_inference_config=async_config,
    endpoint_name=async_endpoint_name,
)

After the endpoint is deployed, you can optionally configure automatic scaling for it. With asynchronous inference, you can also scale your endpoint’s instance count down to zero when there are no requests to process.
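
Automatic scaling for an asynchronous endpoint is configured through Application Auto Scaling, commonly with target tracking on the ApproximateBacklogSizePerInstance metric. The following sketch only builds the request arguments, so it runs without AWS access. It assumes the default AllTraffic variant created by the SageMaker Python SDK. The function name, policy name, capacities, target value, and cooldowns are illustrative choices.

```python
def async_autoscaling_requests(endpoint_name, variant_name="AllTraffic",
                               min_capacity=0, max_capacity=2,
                               target_backlog_per_instance=5.0):
    """Build kwargs for register_scalable_target and put_scaling_policy (hypothetical helper)."""
    resource_id = f"endpoint/{endpoint_name}/variant/{variant_name}"
    dimension = "sagemaker:variant:DesiredInstanceCount"
    register_kwargs = {
        "ServiceNamespace": "sagemaker",
        "ResourceId": resource_id,
        "ScalableDimension": dimension,
        "MinCapacity": min_capacity,  # 0 allows scaling in to zero instances
        "MaxCapacity": max_capacity,
    }
    policy_kwargs = {
        "PolicyName": f"{endpoint_name}-backlog-scaling",
        "ServiceNamespace": "sagemaker",
        "ResourceId": resource_id,
        "ScalableDimension": dimension,
        "PolicyType": "TargetTrackingScaling",
        "TargetTrackingScalingPolicyConfiguration": {
            "TargetValue": target_backlog_per_instance,
            "CustomizedMetricSpecification": {
                "MetricName": "ApproximateBacklogSizePerInstance",
                "Namespace": "AWS/SageMaker",
                "Dimensions": [{"Name": "EndpointName", "Value": endpoint_name}],
                "Statistic": "Average",
            },
            "ScaleInCooldown": 600,   # seconds to wait before removing capacity
            "ScaleOutCooldown": 300,  # seconds to wait before adding more capacity
        },
    }
    return register_kwargs, policy_kwargs


# Usage (requires boto3 and AWS credentials):
#   client = boto3.client("application-autoscaling")
#   reg, pol = async_autoscaling_requests(async_endpoint_name)
#   client.register_scalable_target(**reg)
#   client.put_scaling_policy(**pol)
```

A backlog-per-instance target cannot trigger scale-out while zero instances are running. To scale out from zero, the SageMaker documentation pairs this with a step scaling policy on the HasBacklogWithoutCapacity metric, which is not shown here.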

Next, we show how to invoke the asynchronous endpoint for music generation.

Inference

In this section, we show how to perform inference using an asynchronous inference endpoint with the MusicGen model. For the sake of brevity, only significant code snippets are included. The full source code for running inference with the MusicGen model is available in the GitHub repo. The following diagram explains the sequence of steps to invoke the asynchronous inference endpoint.

MusicGen - Amazon SageMaker Async Inference Sequence Diagram

We detail the steps to invoke the SageMaker asynchronous inference endpoint for MusicGen with an English natural-language prompt that describes the desired mood. We then demonstrate how to download and play the .wav files generated from the prompt. Finally, we cover how to clean up the resources created as part of this deployment.

Prepare prompt and instructions

For controlled music generation using MusicGen models, it’s important to understand various generation parameters:

generation_params = { 
    'guidance_scale': 3,
    'max_new_tokens': 1200, 
    'do_sample': True, 
    'temperature': 1 
}

From the preceding code, let’s understand the generation parameters:

  • guidance_scale – The guidance_scale is used in classifier-free guidance (CFG), setting the weighting between the conditional logits (predicted from the text prompts) and the unconditional logits (predicted from an unconditional or ‘null’ prompt). A higher guidance scale encourages the model to generate samples that are more closely linked to the input prompt, usually at the expense of poorer audio quality. CFG is enabled by setting guidance_scale > 1. For best results, use guidance_scale = 3. Our deployment defaults to 3.
  • max_new_tokens – The max_new_tokens parameter specifies the number of new tokens to generate. Generation is limited by the sinusoidal positional embeddings to 30-second inputs, meaning MusicGen can’t generate more than 30 seconds of audio (1,503 tokens). Our deployment defaults to 256.
  • do_sample – When set to True, the model samples each new token from its predicted probability distribution instead of using greedy decoding. Sampling generally produces noticeably better results than greedy decoding for MusicGen. Our deployment defaults to True.
  • temperature – This is the softmax temperature parameter, which takes effect when do_sample is True. A higher temperature increases the randomness of the output, making it more diverse. Our deployment defaults to 1.
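
These constraints can be checked on the client before a request is queued. The following hypothetical sketch rejects values outside the limits described above and estimates the clip length. The 1,503-token ceiling and the guidance_scale > 1 threshold come from the list above. The rate of 50 tokens per second is an assumption, consistent with 1,503 tokens covering about 30 seconds, so check it against the audio encoder configuration of the checkpoint you deploy.

```python
# The constants restate the constraints above; FRAME_RATE_HZ is an assumption.
MAX_NEW_TOKENS = 1503  # ~30 s ceiling from the sinusoidal positional embeddings
FRAME_RATE_HZ = 50     # assumed number of tokens generated per second of audio


def check_generation_params(params):
    """Validate generation_params; return (approx_seconds, cfg_enabled)."""
    tokens = params.get("max_new_tokens", 256)
    if not 1 <= tokens <= MAX_NEW_TOKENS:
        raise ValueError(f"max_new_tokens must be between 1 and {MAX_NEW_TOKENS}, got {tokens}")
    if params.get("do_sample", True) and params.get("temperature", 1) <= 0:
        raise ValueError("temperature must be positive when sampling")
    # Classifier-free guidance is only applied when guidance_scale > 1
    cfg_enabled = params.get("guidance_scale", 3) > 1
    return tokens / FRAME_RATE_HZ, cfg_enabled
```

Under this assumption, the deployment default of 256 tokens yields about 5 seconds of audio, and the 1,200 tokens in the preceding example yield about 24 seconds.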

Let’s look at how to build a prompt to send to the MusicGen model:

data = {
    "texts": [
        "Warm and vibrant weather on a sunny day, feeling the vibes of hip hop and synth",
    ],
    "bucket_name": sagemaker_session_bucket,
    "generation_params": generation_params
}

The preceding code is the payload, which is saved as a JSON file and uploaded to an S3 bucket. We then provide the S3 URI of the input payload, along with other arguments, when invoking the asynchronous inference endpoint.

The texts key accepts an array of texts, which may contain the mood you want to reflect in your generated music. You can include musical instruments in the text prompt to the MusicGen model to generate music featuring those instruments.
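
The following sketch shows the save step under stated assumptions: the payload is written to a uniquely named local JSON file, and an S3 key is derived for it. The function name and the input prefix are hypothetical; any S3 location that the endpoint’s execution role can read will work. The upload itself needs AWS credentials, so it is shown only as a comment.

```python
import json
import os
import tempfile
import uuid


def save_payload(data, local_dir=None, prefix="musicgen_large/async_inference/input"):
    """Write the request payload to a uniquely named JSON file.

    Returns (local_path, s3_key). The key layout is hypothetical.
    """
    local_dir = local_dir or tempfile.mkdtemp()
    name = f"{uuid.uuid4()}.json"  # unique name so concurrent requests don't collide
    local_path = os.path.join(local_dir, name)
    with open(local_path, "w", encoding="utf-8") as f:
        json.dump(data, f)
    return local_path, f"{prefix}/{name}"


# Upload and build the InputLocation (requires boto3 and credentials):
#   path, key = save_payload(data)
#   boto3.client("s3").upload_file(path, sagemaker_session_bucket, key)
#   input_s3_location = f"s3://{sagemaker_session_bucket}/{key}"
```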

We invoke the endpoint with invoke_endpoint_async, which returns a dictionary of response metadata:

response = sagemaker_runtime.invoke_endpoint_async(
    EndpointName=endpoint_name,
    InputLocation=input_s3_location,
    ContentType="application/json",
    InvocationTimeoutSeconds=3600
)

OutputLocation in the response metadata is the Amazon S3 URI where the inference response payload will be stored.
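
OutputLocation is an s3:// URI, but S3 API calls such as head_object and get_object take the bucket and key as separate arguments. The following minimal helper (hypothetical, not from the repo) splits the URI. It uses plain string handling rather than URL parsing, so keys that contain characters such as # or ? are kept intact.

```python
def split_s3_uri(uri):
    """Split 's3://bucket/key' into (bucket, key)."""
    prefix = "s3://"
    if not uri.startswith(prefix):
        raise ValueError(f"not an S3 URI: {uri!r}")
    bucket, _, key = uri[len(prefix):].partition("/")
    if not bucket or not key:
        raise ValueError(f"S3 URI must include a bucket and an object key: {uri!r}")
    return bucket, key
```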

Asynchronous music generation

After the response metadata is returned to the client, the request waits in the endpoint’s queue and the music is generated asynchronously. Generation runs on the instance type chosen when the MusicGen model was deployed to the SageMaker asynchronous inference endpoint, as detailed in the deployment section.

Continuous polling and obtaining music files

While the music generation is in progress, we poll the S3 location given by the response metadata parameter OutputLocation:

from utils.inference_utils import get_output
output = get_output(sm_session, response.get('OutputLocation'))

The get_output function keeps polling until an object exists at OutputLocation and then returns the inference output, which includes the S3 URIs of the generated .wav music files under the generated_outputs_s3 key.
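
get_output is a helper from the repo. The following hypothetical sketch shows the general polling pattern with exponential backoff. fetch is any callable that returns None while the output is not yet available, and sleep is injectable so that the logic can be tested without waiting.

```python
import time


def poll_until_ready(fetch, timeout_s=3600, initial_delay_s=5, max_delay_s=60, sleep=time.sleep):
    """Call fetch() until it returns something other than None, backing off exponentially."""
    deadline = time.monotonic() + timeout_s
    delay = initial_delay_s
    while True:
        result = fetch()
        if result is not None:
            return result
        # Give up if the next wait would overrun the deadline
        if time.monotonic() + delay > deadline:
            raise TimeoutError(f"output not ready after {timeout_s} s")
        sleep(delay)
        delay = min(delay * 2, max_delay_s)  # 5, 10, 20, 40, 60, 60, ...
```

In practice, fetch would call get_object on the bucket and key from OutputLocation and return None while S3 reports the object as missing. A failed request never produces an object at OutputLocation, so in this sketch failures surface only as a timeout. The SNS error topic gives a faster signal.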

Audio output

Lastly, we download the files from Amazon S3 and play the output using the following logic:

from utils.inference_utils import play_output_audios
music_files = []
for s3_url in output.get('generated_outputs_s3'):
    if s3_url is not None:
        music_files.append(download_from_s3(s3_url))
play_output_audios(music_files, data.get('texts'))

You now have access to the .wav files and can try changing the generation parameters to experiment with various text prompts.

The following payload produces another music sample using different generation parameters:

generation_params = { 'guidance_scale': 5, 'max_new_tokens': 1503, 'do_sample': True, 'temperature': 0.9 }
data = {
    "texts": [
        "Catchy funky beats with drums and bass, synthesized pop for an upbeat pop game",
    ],
    "bucket_name": sagemaker_session_bucket,
    "generation_params": generation_params
}

Clean up

To avoid incurring unnecessary charges, you can clean up using the following code:

import boto3

cleanup = False  # <- Set this to True to clean up resources.
endpoint_name = "<your-musicgen-endpoint-name>"  # replace with your endpoint name

sm_client = boto3.client('sagemaker')
sns_client = boto3.client('sns')
endpoint = sm_client.describe_endpoint(EndpointName=endpoint_name)
endpoint_config_name = endpoint['EndpointConfigName']
endpoint_config = sm_client.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
model_name = endpoint_config['ProductionVariants'][0]['ModelName']
# NotificationConfig is optional, so fall back to an empty dict
notification_config = endpoint_config['AsyncInferenceConfig']['OutputConfig'].get('NotificationConfig') or {}
topic_arns = [notification_config[k] for k in ('SuccessTopic', 'ErrorTopic') if notification_config.get(k)]
print(f"""
About to delete the following sagemaker resources:
Endpoint: {endpoint_name}
Endpoint Config: {endpoint_config_name}
Model: {model_name}
""")
for topic_arn in topic_arns:
    print(f'About to delete SNS topic with ARN: {topic_arn}')

if cleanup:
    # delete endpoint
    sm_client.delete_endpoint(EndpointName=endpoint_name)
    # delete endpoint config
    sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
    # delete model
    sm_client.delete_model(ModelName=model_name)
    # delete the SNS topics referenced in the notification configuration
    for topic_arn in topic_arns:
        sns_client.delete_topic(TopicArn=topic_arn)
    print('deleted model, config, endpoint, and SNS topics')

The preceding cleanup routine deletes the SageMaker endpoint, endpoint configuration, and model associated with the MusicGen model, so that you avoid incurring unnecessary charges. Make sure to set the cleanup variable to True and set endpoint_name to the actual name of the MusicGen endpoint deployed on SageMaker. Alternatively, you can use the console to delete the endpoint and its associated resources that were created while running the code in this post.

Conclusion

In this post, we learned how to use SageMaker asynchronous inference to deploy the AudioCraft MusicGen model. We started by exploring how the MusicGen models work and covered various use cases for deploying MusicGen models. We also explored how you can benefit from capabilities such as auto scaling and the integration of asynchronous endpoints with Amazon SNS to power downstream tasks. We then took a deep dive into the deployment and inference workflow of MusicGen models on SageMaker, using the AWS Deep Learning Containers for HuggingFace inference and the MusicGen model sourced from the Hugging Face Hub.

Get started with generating music using your creative prompts by signing up for AWS. The full source code is available on the official GitHub repository.

About the Authors

Pavan Kumar Rao Navule is a Solutions Architect at Amazon Web Services, where he works with ISVs in India to help them innovate on the AWS platform. He specializes in architecting AI/ML and generative AI services at AWS. Pavan is the published author of the book “Getting Started with V Programming.” In his free time, Pavan enjoys listening to the great magical voices of Sia and Rihanna.

David John Chakram is a Principal Solutions Architect at AWS. He specializes in building data platforms and architecting seamless data ecosystems. With a profound passion for databases, data analytics, and machine learning, he excels at transforming complex data challenges into innovative solutions and driving businesses forward with data-driven insights.

Sudhanshu Hate is a principal AI/ML specialist with AWS and works with clients to advise them on their MLOps and generative AI journey. In his previous role before Amazon, he conceptualized, created, and led teams to build ground-up open source-based AI and gamification platforms, and successfully commercialized them with over 100 clients. Sudhanshu has to his credit a couple of patents, has written two books and several papers and blogs, and has presented his points of view in various technical forums. He has been a thought leader and speaker, and has been in the industry for nearly 25 years. He has worked with Fortune 1000 clients across the globe and most recently with digital native clients in India.

Rupesh Bajaj is a Solutions Architect at Amazon Web Services, where he collaborates with ISVs in India to help them leverage AWS for innovation. He specializes in providing guidance on cloud adoption through well-architected solutions and holds seven AWS certifications. With 5 years of AWS experience, Rupesh is also a Gen AI Ambassador. In his free time, he enjoys playing chess.
