
Using meta tensors to load models that don't fit in memory

PyTorch recently implemented a feature called meta tensors. These are tensors without any data. As I understand it, a meta tensor is just a shape (plus dtype and similar metadata) and some hooks for recording operations performed on it. If you add, subtract, multiply, etc. two meta tensors, you get another meta tensor. Probably combining a meta tensor with a real tensor also produces a meta tensor.
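
Here's a minimal sketch of what that looks like in practice (the shapes and layer sizes are arbitrary):

```python
import torch

# A meta tensor has a shape and dtype but no storage behind it.
a = torch.empty(1024, 1024, device="meta")
b = torch.empty(1024, 1024, device="meta")

# Operations on meta tensors run "shape-only": the result is another meta
# tensor with the right shape, and no actual arithmetic happens.
c = a @ b
print(c.shape, c.device)  # torch.Size([1024, 1024]) meta

# A whole nn.Module can be built on the meta device, so even a huge model
# skeleton costs essentially no memory until its weights are materialized.
with torch.device("meta"):
    layer = torch.nn.Linear(4096, 4096)
print(layer.weight.device)  # meta
```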

Meta tensors enable one nice feature: the ability to load, in each process, only the rank-relevant weights. This matters for large models. The full model might fit in CPU memory, and it might fit in GPU memory once partitioned across all K of our GPUs, but we can't fit K complete copies of it in memory at once. I'm in this situation right now at work. So each process has to load only the weights belonging to its own partition of the model, rather than a full copy.

DeepSpeed supports loading only the rank-relevant weights. You initialize the model from its config, creating the tensors as meta tensors using something like deepspeed.OnDevice(device="meta"). This gives you a model object you can pass to DeepSpeed. You then pass a checkpoint base path and a checkpoint config path to DeepSpeed, and each rank loads only its own parameters.
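
Here's a rough sketch of how the pieces fit together, loosely based on the DeepSpeed inference examples. The model name, shard filenames, and paths are placeholders, and the exact init_inference arguments may vary across DeepSpeed versions:

```python
import json
import torch
import deepspeed
from transformers import AutoConfig, AutoModelForCausalLM

model_name = "bigscience/bloom-7b1"  # placeholder; any sharded HF checkpoint

# Build the model skeleton on the meta device: no weights are loaded yet.
config = AutoConfig.from_pretrained(model_name)
with deepspeed.OnDevice(dtype=torch.float16, device="meta"):
    model = AutoModelForCausalLM.from_config(config)

# The checkpoint config tells DeepSpeed where the shard files live so each
# rank can read just the pieces it needs. Filenames here are illustrative.
checkpoint_config = {
    "type": "BLOOM",
    "checkpoints": [
        "/path/to/checkpoint/pytorch_model-00001-of-00002.bin",
        "/path/to/checkpoint/pytorch_model-00002-of-00002.bin",
    ],
    "version": 1.0,
}
with open("checkpoints.json", "w") as f:
    json.dump(checkpoint_config, f)

# init_inference partitions the model across the tensor-parallel ranks and
# has each rank materialize only its own slice of the weights.
model = deepspeed.init_inference(
    model,
    tensor_parallel={"tp_size": torch.cuda.device_count()},
    dtype=torch.float16,
    checkpoint="checkpoints.json",
    replace_with_kernel_inject=False,
)
```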

I don't quite understand how the DeepSpeed feature works, but here's my guess. It might work by "recording" the tensor-parallelization operations on the meta tensors. I imagine that AutoTP's parallelization operations amount to "partition the columns/rows of this matrix among the processes, with each process rank getting a disjoint subset of columns/rows." The recording lets each process figure out which weights it needs to load from the checkpoint shard files.
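
To make that guess concrete, here's a toy illustration (not DeepSpeed's actual code) of what "each rank loads a disjoint slice of a weight matrix" could look like:

```python
import torch

def rank_shard(checkpoint_tensor, rank, world_size, dim=0):
    """Toy illustration: return the slice of a checkpoint weight that one
    tensor-parallel rank would own, partitioning along `dim`."""
    size = checkpoint_tensor.shape[dim]
    shard = size // world_size  # assumes size divides evenly
    return checkpoint_tensor.narrow(dim, rank * shard, shard).clone()

# An 8x4 weight split row-wise across 2 ranks: each rank gets 4 rows.
w = torch.arange(32.0).reshape(8, 4)
print(rank_shard(w, rank=0, world_size=2).shape)  # torch.Size([4, 4])
print(rank_shard(w, rank=1, world_size=2).shape)  # torch.Size([4, 4])
```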

For an example usage, see the DeepSpeed example code here. [ETA: Especially see the checkpoint JSON generation code, which demonstrates the format of the checkpoint config file you have to pass to DeepSpeed.] There's also a PyTorch Lightning blog post about doing this with Lightning models, if you happen to use Lightning; I don't, but what I love about that post is Phoeby Naren's clear illustration of how the partitioning works.