Building a deep learning karaoke application with PyTorch

March 17, 2022 | By Suraj Subramanian

This blog was written in collaboration with Jesslyn Tannady

At a glance

Demuxr is a machine learning-based web app that takes an audio clip as input and outputs the individual instrument tracks that make up the original audio.

Think of this a bit like taking a song with lyrics and removing only the voice track (or vocals “stem”) so that you are left with an instrumental-only karaoke version of the song. To do this traditionally, you’d need to filter out anything that sounds like vocals (using various audio-engineering filters)—a process that is tedious and not always successful. Demuxr replaces the song-specific audio-engineering filters with an open source ML model (Demucs).

The cool thing is that Demuxr isn’t just good for isolating voice stems—it can isolate other instruments as well. So if you’re an aspiring guitarist and you want a version of your favorite song without the guitar part so that you can play along, Demuxr is just the thing for you.

This blog post covers the inspiration for this project, the open source technologies we used to build the app, and a step-by-step walkthrough of how all the code fits together.

Let’s take a look at how this works

Demuxr is powered by Demucs, the open source ML model that does all the audio splitting under the hood.

Demucs is a deep neural network created by researchers at Meta and open sourced in 2019. It provides state-of-the-art audio splitting and generalizes surprisingly well to a variety of genres.

At a very high level, Demucs takes an audio track and converts it into an array of numbers called a “tensor”. Neural networks are massive number crunchers, and this tensor is how neural networks make sense of the sound in an audio track. Over the course of several neural layers, Demucs manipulates the numbers in the tensor to identify which numbers correspond to the bass, drums, vocals, and everything else (the “other” stem).
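
To make “an audio track as a tensor” concrete, here is an illustrative snippet (not code from Demuxr itself) that loads a song with torchaudio and prints the tensor a model like Demucs would consume:

    import torchaudio

    # Load a stereo track: `waveform` is a float tensor of shape
    # [channels, samples], e.g. [2, 7938000] for a 3-minute song at 44.1 kHz.
    waveform, sample_rate = torchaudio.load("song.wav")
    print(waveform.shape, sample_rate)

    # Demucs consumes a tensor like this and produces one tensor per stem
    # (bass, drums, vocals, other), each with the same shape.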

You can read more about how Meta developed Demucs in this blog post, or this paper that dives into technical details.

Demuxr in action

[Demo GIF: Demuxr in action]

Once you upload an audio file, Demuxr first checks its cache to see if the track has already been demuxed. If not, the ML model is invoked on the track. Once the model has separated the original input into four constituent tracks, or “stems” (bass, drums, other, and vocals), you can hit play and adjust the volume of each stem to make your own karaoke-style remix. Check it out!

Let’s dive into the code

You can find the code for the Demuxr app here. Feel free to report bugs or contribute patches!

Here’s how Demuxr turns your input into stems:

  1. The front end of the web app, built in React, prompts the user to upload an audio file.

  2. After the user hits “Go”, the app server (built in Flask) first hits the cache to see if it has already demuxed this audio before. If not, it converts the audio to OGG (a more efficient format for serving audio on the internet) and uploads it to an AWS S3 bucket. The audio is ready to be processed by the ML model.
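
To make this step concrete, here is a simplified sketch of what the app server’s route could look like. The route name, bucket name, cache layout, and the use of ffmpeg for the OGG conversion are illustrative assumptions, not the repo’s exact code:

    import hashlib
    import subprocess

    import boto3
    from flask import Flask, jsonify, request

    app = Flask(__name__)
    s3 = boto3.client("s3")
    BUCKET = "demuxr-audio"  # illustrative bucket name

    @app.route("/demux", methods=["POST"])
    def demux():
        upload = request.files["audio"]
        track_id = hashlib.sha256(upload.read()).hexdigest()
        upload.seek(0)

        # Cache check: if this track was demuxed before, its stems already live in S3.
        cached = s3.list_objects_v2(Bucket=BUCKET, Prefix=f"{track_id}/stems/")
        if cached.get("KeyCount", 0) > 0:
            return jsonify({"cached": True, "prefix": f"{track_id}/stems/"})

        # Otherwise convert the upload to OGG and hand it off to the model server via S3.
        raw_path, ogg_path = f"/tmp/{track_id}", f"/tmp/{track_id}.ogg"
        upload.save(raw_path)
        subprocess.run(["ffmpeg", "-i", raw_path, ogg_path], check=True)
        key = f"{track_id}/input.ogg"
        s3.upload_file(ogg_path, BUCKET, key)
        return jsonify({"Bucket": BUCKET, "Key": key})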

  3. The model is hosted on TorchServe, an open source model server co-developed by engineers at Meta and AWS. TorchServe makes it really easy to deploy PyTorch models in the cloud. When TorchServe starts, the first thing it does is load concurrent instances of the ML model.

    import torch
    # `Demucs` is the model class from the open source demucs package
    # (the exact import path depends on the demucs version installed).

    MODEL_WEIGHTS = "https://dl.fbaipublicfiles.com/demucs/v3.0/demucs-e07c671f.th"

    def load_model():
        # Download the pretrained weights (cached locally by torch.hub).
        state = torch.hub.load_state_dict_from_url(
            MODEL_WEIGHTS, map_location="cpu", check_hash=True
        )
        # Instantiate the network with the four stems it should separate.
        model = Demucs(["bass", "drums", "vocals", "other"])
        model.load_state_dict(state)
        model.eval()  # switch to evaluation mode (disables training-only behavior like dropout)
        return model
            
  4. Once the model server receives the input track’s S3 URL, it downloads the audio, applies basic preprocessing, and runs the model on it (“inference”).

    ...
                  
    def read_input(self, data):
        # TorchServe passes the request payload in `data`; the body contains
        # the S3 bucket and key of the uploaded OGG file.
        inp = data[0].get("data") or data[0].get("body")
        s3_folder = (inp["Bucket"], inp["Key"].split("/")[0])
        wav = read_ogg_from_s3(inp["Bucket"], inp["Key"])
        wav = wav.to(DEVICE)  # move the waveform tensor to the GPU if one is available
        return wav, s3_folder


    def preprocess(self, wav):
        # Normalize the waveform using the mean/std of its mono downmix;
        # `ref` is returned so the scaling can be undone after inference.
        ref = wav.mean(0)
        wav = (wav - ref.mean()) / ref.std()
        logger.info(f"Processed audio into tensor of size {wav.size()}")
        return wav, ref


    def inference(self, wav, ref):
        if self.model is None:
            raise RuntimeError("Model not initialized")
        # apply_model (from the demucs package) runs the network over the track
        # and returns one tensor per stem.
        demuxed = apply_model(self.model, wav, 8)
        # Undo the normalization applied in preprocess.
        demuxed = demuxed * ref.std() + ref.mean()
        return demuxed
     
    ...
                
  5. The model outputs are still tensors; they haven’t been “encoded” or converted to a human-audible format yet. Encoding is the only step that takes place on the CPU; everything else is sped up on the GPU. To avoid bottlenecking the process, we deploy the encoding logic on serverless AWS Lambda. The model server simply uploads the inference output to S3 and proceeds with the next inference job.

    def postprocess(self, inference_output):
        stems = []
        for source in inference_output:
            # Normalize to avoid clipping, then scale to the 16-bit PCM range.
            source = source / max(
                1.01 * source.abs().max(), 1
            )
            source = (source * 2 ** 15).clamp_(-(2 ** 15), 2 ** 15 - 1).short()
            source = source.cpu().numpy()
            stems.append(source)
        return stems


    def cache(self, stems, s3_folder, fmt=None):
        # Bundle all four stems into a single compressed .npz archive and
        # upload it to S3 for the encoding step to pick up.
        bucket, folder = s3_folder
        key = folder + "/model_output.npz"
        source_names = ["drums", "bass", "other", "vocals"]
        stems = dict(zip(source_names, stems))
        with io.BytesIO() as buf_:
            np.savez_compressed(buf_, **stems)
            buf_.seek(0)
            S3_CLIENT.upload_fileobj(buf_, bucket, key)
        return key
                
  6. We use sox, an open source audio processing library, to encode our tensors into OGG. Since the sox utility is not available in the Lambda runtime, we need to compile it into a binary (see `encode/sox_build.sh` in the repo) and deploy it as a Lambda layer.
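
Here is a hedged sketch of what the encoding Lambda could look like. The handler name, the /opt path where the layer exposes the sox binary, and the 44.1 kHz sample rate are illustrative assumptions; see the encode code in the repo for the real implementation:

    import io
    import subprocess

    import boto3
    import numpy as np

    s3 = boto3.client("s3")
    SOX = "/opt/bin/sox"  # assumed path to the sox binary shipped in the Lambda layer

    def handler(event, context):
        bucket, key = event["Bucket"], event["Key"]  # e.g. "<folder>/model_output.npz"
        folder = key.split("/")[0]

        # Pull the cached stem tensors written by the model server.
        buf = io.BytesIO()
        s3.download_fileobj(bucket, key, buf)
        buf.seek(0)
        stems = np.load(buf)

        for name in stems.files:  # "drums", "bass", "other", "vocals"
            pcm = stems[name]  # int16 samples, shape (channels, samples)
            raw_path, ogg_path = f"/tmp/{name}.raw", f"/tmp/{name}.ogg"
            # sox reads raw PCM as interleaved samples, so transpose before writing.
            np.ascontiguousarray(pcm.T).tofile(raw_path)
            subprocess.run(
                [SOX, "-t", "raw", "-r", "44100", "-e", "signed", "-b", "16",
                 "-c", str(pcm.shape[0]), raw_path, ogg_path],
                check=True,
            )
            s3.upload_file(ogg_path, bucket, f"{folder}/{name}.ogg")

        return {"status": "encoded", "folder": folder}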

  7. Once the tensors are encoded and stored on S3 as audio files, they are served to the client. The React-based frontend lets you play or pause the track, seek through the master track, and adjust the volume of each stem individually. Audio visualization and seeking are powered by the awesome wavesurfer.js library and its React wrapper, wavesurfer-react.
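
For a sense of how the handler methods from steps 3 through 5 fit together, here is a minimal sketch of a TorchServe custom handler. The class name and exact wiring are assumptions; the repo’s actual handler may differ:

    from ts.torch_handler.base_handler import BaseHandler

    class DemucsHandler(BaseHandler):
        def initialize(self, context):
            # Called once per worker when TorchServe starts: load the Demucs weights.
            self.model = load_model()

        def handle(self, data, context):
            # Chain the steps shown above: read -> preprocess -> inference -> postprocess -> cache.
            wav, s3_folder = self.read_input(data)
            wav, ref = self.preprocess(wav)
            demuxed = self.inference(wav, ref)
            stems = self.postprocess(demuxed)
            key = self.cache(stems, s3_folder)
            return [{"Key": key}]

TorchServe routes every inference request through handle(), so each uploaded track flows through these methods end to end.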

PyTorch makes it really easy to build apps like Demuxr

A lot of Demuxr’s main functionality was built over a weekend using PyTorch and TorchServe. As a PyTorch developer advocate I am definitely biased, but I think PyTorch has the clearest syntax for building, training, and extending ML and deep learning models into practical applications. Paired with a model server like TorchServe, developers can also take advantage of features like distributed inference, pipeline integrations, and optimizations such as quantization.

If you’re interested in starting your PyTorch journey, check out the PyTorch website where you’ll find resources like the 101-level Learn the Basics tutorial as well as more advanced guides.

Interested in working on machine learning at Meta?

Researchers and engineers at Meta are exploring the potential of machine learning to solve difficult problems like identifying misinformation and hate speech on Meta platforms. If you’re interested in working with folks at the forefront of machine learning development to solve impactful problems, check out job postings here.

To learn more about Meta Open Source, visit our website, subscribe to our YouTube channel, or follow us on Twitter, Facebook, and LinkedIn.