The road to 1.0: production ready PyTorch
Dear PyTorch Users,
We would like to give you a preview of the roadmap for PyTorch 1.0, the next release of PyTorch. Over the last year, releases 0.2, 0.3 and 0.4 transformed PyTorch from a [Torch+Chainer]-like interface into something cleaner, adding double-backwards, numpy-like functions and advanced indexing, and removing Variable boilerplate. At this time, we’re confident that the API is in a reasonable and stable enough state to release a 1.0.
However, 1.0 isn’t just about stability of the interface.
One of PyTorch’s biggest strengths is its first-class Python integration, imperative style, simplicity of the API and options. These are aspects that make PyTorch good for research and hackability.
One of its biggest downsides has been production-support. What we mean by production-support is the countless things one has to do to models to run them efficiently at massive scale:
- exporting to C++-only runtimes for use in larger projects
- optimizing for mobile systems such as iPhone, Android, Qualcomm and others
- using more efficient data layouts and performing kernel fusion to do faster inference (saving 10% of speed or memory at scale is a big win)
- quantized inference (such as 8-bit inference)
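To make the last item concrete, here is a minimal sketch of 8-bit affine quantization in plain Python. The helper names (`quantize`, `dequantize`) and the per-tensor min/max scheme are assumptions chosen for illustration; they are not PyTorch or Caffe2 APIs, and production schemes are more involved.

```python
def quantize(values, num_bits=8):
    """Map floats to unsigned integers using one scale and zero point per tensor."""
    qmin, qmax = 0, 2 ** num_bits - 1
    # The representable range must contain 0.0 so that zero maps to an exact integer.
    lo, hi = min(min(values), 0.0), max(max(values), 0.0)
    scale = (hi - lo) / (qmax - qmin) or 1.0  # fall back to 1.0 for all-zero input
    zero_point = round(qmin - lo / scale)
    q = [max(qmin, min(qmax, round(v / scale) + zero_point)) for v in values]
    return q, scale, zero_point


def dequantize(q, scale, zero_point):
    """Recover approximate floats from the integer representation."""
    return [(qi - zero_point) * scale for qi in q]


weights = [-1.0, -0.25, 0.0, 0.5, 1.0]
q, scale, zero_point = quantize(weights)
restored = dequantize(q, scale, zero_point)
max_error = max(abs(a - b) for a, b in zip(weights, restored))
print(q, max_error)
```

Each value now fits in one byte instead of four, at the cost of a rounding error that stays below one quantization step (`scale`) in this example.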
Startups, large companies and anyone who wants to build a product around PyTorch have asked for production support. At Facebook (the largest stakeholder for PyTorch) we have Caffe2, which has been the production-ready platform, running in our datacenters and shipping to more than 1 billion phones spanning eight generations of iPhones and six generations of Android CPU architectures. It has server-optimized inference on Intel / ARM, TensorRT support, and all the necessary bits for production. Considering all this value locked into a platform that the PyTorch team works quite closely with, we decided to marry PyTorch and Caffe2, which gives PyTorch production-level readiness.
Supporting production features without adding usability issues for our researchers and end-users needs creative solutions.
Production != Pain for researchers
Adding production capabilities involves increasing the API complexity and number of configurable options for models. One configures memory-layouts (NCHW vs NHWC vs N,C/32,H,W,32, each providing different performance characteristics), quantization (8-bit? 3-bit?), fusion of low-level kernels (you used a Conv + BatchNorm + ReLU, let’s fuse them into a single kernel), separate backend options (MKLDNN backend for a few layers and NNPACK backend for other layers), etc.
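As a small illustration of why memory layout is a configurable option at all, the sketch below computes where the channel values of a single pixel land in a flat buffer under NCHW and NHWC. The function names are hypothetical and the code is plain Python, not a PyTorch API.

```python
def offset_nchw(n, c, h, w, shape):
    """Flat offset of element (n, c, h, w) when channels are the outer dimension."""
    _, C, H, W = shape
    return ((n * C + c) * H + h) * W + w


def offset_nhwc(n, c, h, w, shape):
    """Flat offset of element (n, c, h, w) when channels are the innermost dimension."""
    _, C, H, W = shape
    return ((n * H + h) * W + w) * C + c


shape = (1, 3, 2, 2)  # N, C, H, W

# Offsets of the three channel values of the pixel at h=0, w=1.
nchw = [offset_nchw(0, c, 0, 1, shape) for c in range(3)]
nhwc = [offset_nhwc(0, c, 0, 1, shape) for c in range(3)]
print(nchw)  # [1, 5, 9]  -> channels of one pixel are H*W elements apart
print(nhwc)  # [3, 4, 5]  -> channels of one pixel are contiguous
```

Which layout is faster depends on the kernel and the hardware, which is why this is something one would rather have a compiler choose than hand-tune in model code.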
PyTorch’s central goal is to provide a great platform for research and hackability. So, while we add all these optimizations, we’ve been working with a hard design constraint to never trade these off against usability.
To pull this off, we are introducing torch.jit, a just-in-time (JIT) compiler that at runtime takes your PyTorch models and rewrites them to run at production-efficiency. The JIT compiler can also export your model to run in a C++-only runtime based on Caffe2 bits.
In 1.0, your code continues to work as-is; we’re not making any big changes to the existing API.
Making your model production-ready is an opt-in annotation, which uses the torch.jit compiler to export your model to a Python-less environment and improve its performance. Let’s walk through the JIT compiler in detail.
torch.jit: A JIT-compiler for your models
We strongly believe that it’s hard to match the productivity you get from specifying your models directly as idiomatic Python code. This is what makes PyTorch so flexible, but it also means that PyTorch pretty much never knows the operation you’ll run next. This, however, is a big blocker for export/productionization and heavyweight automatic performance optimizations, because they need full upfront knowledge of how the computation will look before it even gets executed.
We provide two opt-in ways of recovering this information from your code: one based on tracing native Python code, and one based on compiling an annotated subset of the Python language into a Python-free intermediate representation. After thorough discussions we concluded that they’re both going to be useful in different contexts, and as such you will be able to mix and match them freely.
Tracing Mode
The PyTorch tracer, torch.jit.trace, is a function that records all the native PyTorch operations performed in a code region, along with the data dependencies between them. In fact, PyTorch has had a tracer since 0.3, which has been used for exporting models through ONNX. What changes now is that you no longer necessarily need to take the trace and run it elsewhere - PyTorch can re-execute it for you, using a carefully designed high-performance C++ runtime. As we develop PyTorch 1.0 this runtime will integrate all the optimizations and hardware integrations that Caffe2 provides.
The biggest benefit of this approach is that it doesn’t really care how your Python code is structured: you can trace through generators or coroutines, modules or pure functions. Since we only record native PyTorch operators, these details have no effect on the trace recorded. This behavior, however, is a double-edged sword. For example, if you have a loop in your model, it will get unrolled in the trace, inserting one copy of the loop body for each iteration the loop ran. This opens up opportunities for zero-cost abstraction (e.g. you can loop over modules, and the actual trace will be loop-overhead free!), but on the other hand this will also affect data-dependent loops (think of e.g. processing sequences of varying lengths), effectively hard-coding a single length into the trace.
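The unrolling behavior is easy to reproduce with a toy tracer in plain Python. This is only a sketch of the idea, not how torch.jit.trace is implemented: the `Traced` class, `toy_trace`, and both example functions are hypothetical, and operands are limited to plain numbers.

```python
class Traced:
    """Wraps a number and records every arithmetic op applied to it on a shared tape."""

    def __init__(self, value, tape):
        self.value = value
        self.tape = tape

    def _record(self, op, operand, result):
        self.tape.append((op, operand))
        return Traced(result, self.tape)

    def __add__(self, operand):
        return self._record("add", operand, self.value + operand)

    def __mul__(self, operand):
        return self._record("mul", operand, self.value * operand)

    def __lt__(self, other):
        # The comparison escapes as a plain bool, so the tracer never sees the branch.
        return self.value < other


def toy_trace(fn, example_input):
    """Run fn once on the example input and return a replay function plus the tape."""
    tape = []
    fn(Traced(example_input, tape))

    def replay(x):
        for op, operand in tape:
            x = x + operand if op == "add" else x * operand
        return x

    return replay, tape


def scale_three_times(x):
    for _ in range(3):       # fixed-length loop: safely unrolled into three "mul" ops
        x = x * 2
    return x


def grow_until_large(x):
    while x < 100:           # data-dependent loop: iteration count depends on the input
        x = x + 10
    return x


replay_scale, tape_scale = toy_trace(scale_three_times, 1)
replay_grow, tape_grow = toy_trace(grow_until_large, 75)

print(tape_scale)            # [('mul', 2), ('mul', 2), ('mul', 2)]
print(replay_scale(5))       # 40, same as running the original function
print(grow_until_large(0))   # 100 when run eagerly
print(replay_grow(0))        # 30: the trace hard-coded the three iterations seen for 75
```

The fixed-length loop replays correctly for any input, while the data-dependent loop silently keeps the iteration count observed during the example run.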
For networks that do not contain loops and if statements, tracing is non-invasive and is robust enough to handle a wide variety of coding styles. This code example illustrates what tracing looks like:
# This will run your nn.Module or regular Python function with the example
# input that you provided. The returned callable can be used to re-execute
# all operations that happened during the example run, but it will no longer
# use the Python interpreter.
from torch.jit import trace
traced_model = trace(model, example_inputs=input)
traced_fn = trace(fn, example_inputs=input)
# The training loop doesn't change. Traced model behaves exactly like an
# nn.Module, except that you can't edit what it does or change its attributes.
# Think of it as a "frozen module".
for input, target in data_loader:
    loss = loss_fn(traced_model(input), target)
Script Mode
Tracing mode is a great way to minimize the impact on your code, but we’re also very excited about the models that fundamentally make use of control flow such as RNNs. Our solution to this is a scripting mode.
In this case you write out a regular Python function, except that you can no longer use certain more complicated language features. Once you have isolated the desired functionality, you let us know that you’d like the function to get compiled by decorating it with an @script decorator. This annotation will compile your Python function directly into a form that runs in our high-performance C++ runtime. This lets us recover all the PyTorch operations along with loops and conditionals. They will be embedded into our internal representation of this function, and will be accounted for every time this function is run.
from torch.jit import script
@script
def rnn_loop(x):
    hidden = None
    # Feed one time step at a time; `model` is assumed to be defined elsewhere.
    for x_t in x.split(1):
        x, hidden = model(x_t, hidden)
    return x
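The key difference from tracing is that a compiler working from source sees the loop and the conditional as explicit nodes rather than as a flat list of executed operations. The sketch below uses Python’s standard `ast` module as a stand-in to show this; it is not how @script is implemented, and the function body in `SOURCE` is a hypothetical variant of the loop above.

```python
import ast

# Source of a function with a loop and a branch. It is only parsed, never executed.
SOURCE = """
def rnn_loop(x):
    hidden = None
    for x_t in x.split(1):
        if hidden is None:
            hidden = x_t
        else:
            hidden = hidden + x_t
    return hidden
"""

tree = ast.parse(SOURCE)

# Collect every control-flow construct that survives in the parsed representation.
control_flow = [
    type(node).__name__
    for node in ast.walk(tree)
    if isinstance(node, (ast.For, ast.While, ast.If))
]
print(control_flow)  # ['For', 'If']
```

Because the `For` and `If` nodes are still present, a representation built this way can run a different number of iterations for each input, which a recorded trace cannot.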
Optimization and Export
Regardless of whether you use tracing or @script, the result is a Python-free representation of your model, which can be used to optimize the model or to export the model from Python for use in production environments.
Extracting bigger segments of the model into an intermediate representation makes it possible to do sophisticated whole-program optimizations and to offload computation to specialized AI accelerators which operate on graphs of computation. We have already been developing the beginnings of these optimizations, including passes that fuse GPU operations together to improve the performance of smaller RNN models.
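To give a feel for what a fusion pass buys, here is a toy sketch in plain Python that fuses a chain of elementwise operations into a single per-element function. The op table and the `run_unfused`, `fuse`, and `run_fused` helpers are hypothetical; real fusers generate device kernels, but the saving comes from the same place: fewer passes over memory.

```python
import math

# A tiny table of elementwise ops, standing in for native operators.
ELEMENTWISE = {
    "double": lambda v: 2.0 * v,
    "tanh": math.tanh,
    "sigmoid": lambda v: 1.0 / (1.0 + math.exp(-v)),
}


def run_unfused(ops, data):
    """Run each op as its own 'kernel': one full pass over the data per op."""
    passes = 0
    for name in ops:
        data = [ELEMENTWISE[name](v) for v in data]
        passes += 1
    return data, passes


def fuse(ops):
    """Compose a chain of elementwise ops into a single per-element function."""
    fns = [ELEMENTWISE[name] for name in ops]

    def fused(v):
        for fn in fns:
            v = fn(v)
        return v

    return fused


def run_fused(ops, data):
    """Apply the fused function in a single pass over the data."""
    kernel = fuse(ops)
    return [kernel(v) for v in data], 1


ops = ["double", "tanh", "sigmoid"]
data = [-1.0, 0.0, 0.5]
unfused_out, unfused_passes = run_unfused(ops, data)
fused_out, fused_passes = run_fused(ops, data)
print(unfused_passes, fused_passes)  # 3 1
```

Both versions compute the same values; the fused one avoids materializing two intermediate lists, which is the kind of saving that only becomes possible once the whole chain of operations is known up front.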
It also allows us to use existing high-performance backends available in Caffe2 today to run the model efficiently. Additionally, @script functions (and modules!) can be fully exported to ONNX in a way that retains their dynamic nature, such that you can easily run them in a Python-free environment using the model executors from Caffe2 or by transferring the model to any other framework supporting ONNX.
Usability
We care deeply about maintaining our current level of usability. We know that executing code outside of Python makes debugging harder; this is something that we think about a lot, and we’re making sure that you’re not getting locked in to a completely different programming language.
First, we follow the principle of pay for what you use: if you don’t need to optimize or export your model, you do not have to use these new features and won’t see any downsides. Furthermore, use of traced or @script modules/functions can be done incrementally. For instance, all of these behaviors are allowed:
- You can trace part of your model and use the trace in a larger non-traced model.
- You can use tracing for 90% of your model, and use @script for the one sub-module that actually has some control flow in it.
- You can write a function using @script and have it call a native Python function.
- If something appears incorrect in an @script function, you can remove the annotation and the code will execute in native Python where it is easy to debug using your favorite tools and methods.
Think of tracing and @script like type annotations using MyPy or TypeScript: each additional annotation can be tested incrementally, and none are required until you want to optimize or productionize.
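The mixing described above can be sketched with the `torch.jit.trace` and `torch.jit.script` entry points as they exist in released PyTorch versions, which may differ in detail from the preview described in this post. The `Encoder` module and the `flip_if_negative` and `model` functions are hypothetical names used only for illustration.

```python
import torch
import torch.nn as nn


class Encoder(nn.Module):
    """Hypothetical straight-line sub-module: no control flow, so tracing suffices."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.linear(x))


@torch.jit.script
def flip_if_negative(x: torch.Tensor) -> torch.Tensor:
    # Data-dependent branch: scripting keeps the `if` instead of baking in one path.
    if bool(x.sum() < 0):
        y = -x
    else:
        y = x
    return y


# Trace only the part of the model that has no control flow.
traced_encoder = torch.jit.trace(Encoder(), torch.randn(2, 4))


def model(x):
    # An ordinary, non-traced Python function that uses both compiled pieces.
    return flip_if_negative(traced_encoder(x))


out = model(torch.randn(2, 4))
print(out.shape)
```

Removing the `@torch.jit.script` line leaves `flip_if_negative` as a regular Python function with the same behavior, which is the debugging escape hatch mentioned above.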
Most importantly, these modes will be built into the core of PyTorch so that mixing and matching them with your existing code can happen seamlessly.
Note: The name JIT for these components is a bit of a misnomer and comes from historical reasons. The tracing/function execution in PyTorch started out as an optimizing JIT compiler that generated fused CUDA kernels but then grew to encompass optimization, @script, and export. When it is ready for release we will likely rename this functionality to the hybrid frontend, but we wanted to present it here as it is named in the code so that you can follow along as we develop it.
Other changes and improvements
Production support is the big feature for 1.0, but we will continue optimizing and fixing other parts of PyTorch as part of the standard release process.
On the backend side of things, PyTorch will see some changes, which might affect user-written C and C++ extensions. We are replacing (or refactoring) the backend ATen library to incorporate features and optimizations from Caffe2.
Last Words
We aim to release 1.0 some time during the summer. You can follow along with our progress on the Pull Requests page.
You can read this from the perspective of the Caffe2 project at: https://caffe2.ai/blog/2018/05/02/Caffe2_PyTorch_1_0.html