PyTorch 1.5.1 Release Notes
- Backwards Incompatible Changes
- Known Issues and Workarounds
- Critical Fixes
- Crashes and Error Fixes
- Other Fixes
Backwards Incompatible Changes
Autograd: Operations that return integer-type tensors now always return tensors that don’t require grad (#37789).
This most notably affects `torch.argmin`, `torch.argmax`, and `torch.argsort`.
This change is BC-breaking because previously one could obtain an integer-type tensor that requires grad in 1.5.0. However, said tensors were not usable by autograd; calling `.backward()` on them resulted in an error, so most users are unlikely to have been relying on this behavior.
Version 1.5.0:
>>> tensor = torch.randn(3, requires_grad=True)
>>> torch.argmax(tensor).requires_grad
True

Version 1.5.1:
>>> tensor = torch.randn(3, requires_grad=True)
>>> torch.argmax(tensor).requires_grad
False
Known Issues and Workarounds
When using multiprocessing, PyTorch 1.5.1 and 1.5.0 may error out with complaints about incompatibility between MKL and libgomp (#37377)
You may see error messages like the following when using the `torch.multiprocessing` package. This bug has primarily affected users with AMD CPUs.
`Error: mkl-service + Intel(R) MKL: MKL_THREADING_LAYER=INTEL is incompatible with libgomp.so.1 library.
Try to import numpy first or set the threading layer accordingly. Set MKL_SERVICE_FORCE_INTEL to force it.`
You can get rid of the error and the error message by setting the environment variable `MKL_THREADING_LAYER=GNU`. This can be done either by including the following in your Python code:
import os
os.environ['MKL_THREADING_LAYER'] = 'GNU'
or by specifying the environment variable when running your script:
MKL_THREADING_LAYER=GNU python my_script.py
To learn more about what triggers this bug and other workarounds if the above isn’t working, please read this comment on the issue.
Critical Fixes
`torch.multinomial`: Fixed a bug where CUDA `multinomial` generated the same sequence over and over again with a shift of 4. (#38046)
`nn.Conv2d`: Fixed a bug where circular padding applied padding across the wrong dimension (#37881)
Version 1.5.0:
>>> circular = nn.Conv2d(6, 1, (3, 3), padding=(0, 1), padding_mode='circular')
>>> circular(torch.zeros(1, 6, 10, 10)).shape
# Notice the padding is incorrectly on the H dimension, not the W dimension.
torch.Size([1, 1, 10, 8])

Version 1.5.1:
>>> circular = nn.Conv2d(6, 1, (3, 3), padding=(0, 1), padding_mode='circular')
>>> circular(torch.zeros(1, 6, 10, 10)).shape
torch.Size([1, 1, 8, 10])
Fixed a bug where asserts in CUDA kernels were mistakenly disabled, leading to many silent kernel errors. (#38943, #39047, #39218)
`torch.gather`, `torch.scatter`: added checks for illegal input dtypes that caused silently incorrect behavior (#38025, #38646)
`torch.argmin`, `torch.argmax`: Fixed silently incorrect results for inputs with more than 2^32 elements (#39212)
C++ Custom Operators: fixed a bug where custom operators stopped working with autograd and ignored the `requires_grad=True` flag. (#37355)
Crashes and Error Fixes
Fixed CUDA reduction operations on inputs with more than 2^32 elements (#37788)
Version 1.5.0:
>>> torch.zeros(5, 14400, 14400, device='cuda').sum(0)
RuntimeError: sub_iter.strides(0)[0] == 0 INTERNAL ASSERT FAILED at /pytorch/aten/src/ATen/native/cuda/Reduce.cuh:706, please report a bug to PyTorch.

Version 1.5.1:
>>> torch.zeros(5, 14400, 14400, device='cuda').sum(0)
# No problem
Fixed pickling of PyTorch operators (#38033)
Version 1.5.0:
>>> pickle.dumps(torch.tanh)
PicklingError: Can't pickle : it's not the same object as torch._C._VariableFunctions

Version 1.5.1:
>>> pickle.dumps(torch.tanh)
# No problem
`nn.LeakyReLU`: Fixed a bug where using autograd with in-place `nn.LeakyReLU` with a slope of 0 incorrectly errored out. (#37453, #37559)
Version 1.5.0:
>>> tensor = torch.randn(3, requires_grad=True)
>>> other = tensor + 1
>>> output = nn.LeakyReLU(0, inplace=True)(other)
>>> output.sum().backward()
RuntimeError: In-place leakyReLu backward calculation is triggered with a non-positive slope which is not supported. This is caused by calling in-place forward function with a non-positive slope, please call out-of-place version instead.

Version 1.5.1:
>>> tensor = torch.randn(3, requires_grad=True)
>>> other = tensor + 1
>>> output = nn.LeakyReLU(0, inplace=True)(other)
>>> output.sum().backward()
# No error
`torch.as_strided`: Fixed crash when passed `sizes` and `strides` of different lengths. (#39301)
`nn.SyncBatchNorm.convert_sync_batchnorm`: Fixed bug where it did not respect the devices of the original BatchNorm module, resulting in device mismatch errors (#39344)
`nn.utils.clip_grad_norm_`: Fixed ability to operate on tensors on different devices (#38615)
`torch.min`, `torch.max`: added check for illegal output dtypes (#38850)
MacOS: Fixed `import torch` error (#36941).
C++ Extensions: fixed compilation error when building with older versions of nvcc (#37221)
This bug mainly affected users of Ubuntu 16.04. We’re certain it affected the following configurations:
- Ubuntu 16.04 + CUDA 9.2 + gcc 5
- Ubuntu 16.04 + CUDA 9.2 + gcc 7
- Ubuntu 16.04 + CUDA 10.0 + gcc 5
C++ Extensions: fixed ability to compile with paths that include spaces (#38860, #38670)
C++ Extensions: fixed ability to compile with relative `include_dirs` for ahead-of-time compilation (#38264)
Other Fixes
`nn.Conv1d`, `nn.Conv2d`, `nn.Conv3d`: Fixed a bug where convolutions were using more memory than previous versions of PyTorch. (#38674)
Fixed in-place floor division magic method (#38695)
In 1.5.0, the in-place floor division magic method mistakenly performed the floor division out-of-place. We’ve fixed this in 1.5.1.
Version 1.5.0:
>>> tensor = torch.ones(1)
>>> expected_data_ptr = tensor.data_ptr()
>>> tensor //= 1
>>> tensor.data_ptr() == expected_data_ptr
False

Version 1.5.1:
>>> tensor = torch.ones(1)
>>> expected_data_ptr = tensor.data_ptr()
>>> tensor //= 1
>>> tensor.data_ptr() == expected_data_ptr
True
Documentation: fixed link to Java docs. (#39039)
Quantization: Fixed weight quantization inaccuracies for LSTM (#35961)
Weight quantization was done incorrectly for LSTMs: the statistics for all weights (across layers) were combined in the observer, so weights for later layers in an LSTM would use sub-optimal scales, impacting accuracy. The problem gets worse as the number of layers increases.
DistributedDataParallel: Fixed single-process multi-GPU use case (#36503)
RPC: Fixed future callbacks not capturing and restoring the autograd context id (#38512)
TorchScript: Fixed support for `torch.unique` (#38156)
ONNX: Fixed `pow` operator export (#39791)
PyTorch 1.5.0 Release Notes
- Highlights
- Known Issues
- Backwards Incompatible Changes
- Python
- C++ API
- JIT
- Quantization
- RPC
- New Features
- Improvements
- Bug Fixes
- Performance
- Documentation
- Deprecations
- Python
- C++ API
- Miscellaneous
Highlights
This release includes several major new API additions and improvements. These include new APIs for autograd allowing for easy computation of Hessians and Jacobians, a significant update to the C++ frontend, ‘channels last’ memory format for more performant computer vision models, a stable release of the distributed RPC framework used for model-parallel training, and a new API, inspired by pybind11, that allows for the creation of Custom C++ Classes. Additionally, torch_xla 1.5 is now available and tested with the PyTorch 1.5 release, providing a mature Cloud TPU experience.
C++ Frontend API [Now Stable]
The C++ frontend API is now at parity with Python, and the features overall have been moved to ‘stable’ (previously tagged as experimental). Some of the major highlights include:
- C++ torch::nn module/functional are now at ~100% parity with the Python API, with appropriate documentation. Users can now easily translate their model from the Python API to the C++ API, making the model authoring experience much smoother.
- C++ optimizers now behave identically to the Python API. In the past, optimizers in C++ had deviated from the Python equivalent: C++ optimizers couldn’t take parameter groups as input while the Python ones could, and the step function implementations were not exactly the same. With the 1.5 release, C++ optimizers will always behave the same as the Python equivalent.
- New C++ tensor multi-dim indexing API which looks and behaves similarly to the Python API. The previous workaround was to use a combination of `narrow` / `select` / `index_select` / `masked_select`, which is clunky and error-prone compared to the Python API’s elegant `tensor[:, 0, ..., mask]` syntax. With the 1.5 release, users can use `tensor.index({Slice(), 0, "...", mask})` to achieve the same result.
Channels last memory format for Computer Vision models [Experimental]
Channels Last memory format is an alternative way of ordering NCHW tensors in memory while preserving the NCHW semantic dimensions ordering. Channels Last tensors are ordered in memory in such a way that channels become the densest dimension (aka storing images pixel-per-pixel).
Channels Last memory format unlocks the ability to use performance-efficient convolution algorithms and hardware (NVIDIA’s Tensor Cores, FBGEMM, QNNPACK). Additionally, it was designed to automatically propagate through the operators, which allows easy switching between memory layouts.
Learn more here on how to write memory format aware operators.
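As a rough, hedged illustration (not taken from the release notes), opting a model and its input into the channels-last layout only requires a `memory_format` argument, and the layout then propagates through supported operators:

```python
import torch
import torch.nn as nn

# Convert the input and the convolution weights to the channels-last layout.
x = torch.randn(8, 3, 224, 224).to(memory_format=torch.channels_last)
model = nn.Conv2d(3, 16, kernel_size=3, padding=1).to(memory_format=torch.channels_last)

out = model(x)
# The memory format propagates through the operator.
print(out.is_contiguous(memory_format=torch.channels_last))  # True
```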
Custom C++ Classes [Experimental]
This release adds a new API for binding custom C++ classes into TorchScript and Python simultaneously. This API is almost identical in syntax to pybind11. It allows users to expose their C++ class and its methods to the TorchScript type system and runtime system such that they can instantiate and manipulate arbitrary C++ objects from TorchScript and Python. An example C++ binding:
template <class T>
struct MyStackClass : torch::CustomClassHolder {
  std::vector<T> stack_;
  MyStackClass(std::vector<T> init) : stack_(std::move(init)) {}

  void push(T x) {
    stack_.push_back(x);
  }
  T pop() {
    auto val = stack_.back();
    stack_.pop_back();
    return val;
  }
};

static auto testStack =
  torch::class_<MyStackClass<std::string>>("myclasses", "MyStackClass")
      .def(torch::init<std::vector<std::string>>())
      .def("push", &MyStackClass<std::string>::push)
      .def("pop", &MyStackClass<std::string>::pop)
      .def("size", [](const c10::intrusive_ptr<MyStackClass<std::string>>& self) {
        return self->stack_.size();
      });
Which exposes a class you can use in Python and TorchScript like so:
@torch.jit.script
def do_stacks(s: torch.classes.myclasses.MyStackClass):
    s2 = torch.classes.myclasses.MyStackClass(["hi", "mom"])
    print(s2.pop())  # "mom"
    s2.push("foobar")
    return s2  # ["hi", "foobar"]
You can try it out in the tutorial here.
Distributed RPC framework APIs [Now Stable]
The `torch.distributed.rpc` package aims at supporting a wide range of distributed training paradigms that do not fit into `DistributedDataParallel`. Examples include parameter server training, distributed model parallelism, and distributed pipeline parallelism. Features in the `torch.distributed.rpc` package can be categorized into four main sets of APIs.
- The RPC API allows running a function on a specified destination worker with given arguments and fetches the return value or creates a distributed reference to the return value.
- The RRef (Remote REFerence) serves as a reference to an object on another worker. A worker holding an RRef can explicitly request copies of the object, and it can also share the light-weight RRef with other workers without worrying about reference counting. This is especially useful when multiple workers need to repeatedly access different versions of the same remote object.
- With Distributed Autograd, applications can automatically compute gradients even if a model is split on multiple workers using RPC. This is achieved by stitching together local autograd graphs at RPC boundaries in the forward pass and reaching out to participants to transparently launch local autograd in the backward pass.
- The Distributed Optimizer uses gradients computed by Distributed Autograd to update model parameters. Its constructor takes a local optimizer (e.g., `SGD`, `Adagrad`, etc.) and a list of parameter RRefs, and its `step()` function automatically uses the local optimizer to update parameters on all distinct RRef owner workers.
Learn more here.
torch_xla 1.5 now available
torch_xla is a Python package that uses the XLA linear algebra compiler to accelerate the PyTorch deep learning framework on Cloud TPUs and Cloud TPU Pods. torch_xla aims to give PyTorch users the ability to do everything they can do on GPUs on Cloud TPUs as well while minimizing changes to the user experience. This release of torch_xla is aligned and tested with PyTorch 1.5 to reduce friction for developers and to provide a stable and mature PyTorch/XLA stack for training models using Cloud TPU hardware. You can try it for free in your browser on an 8-core Cloud TPU device with Google Colab, and you can use it at a much larger scale on Google Cloud.
See the full torch_xla release notes here and the full docs here.
New High level autograd API [Experimental]
PyTorch 1.5 brings new functions including `jacobian`, `hessian`, `jvp`, `vjp`, `hvp` and `vhp` to the `torch.autograd.functional.*` submodule. This feature builds on the current API and allows the user to easily compute these quantities.
See the full docs here.
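For illustration, a minimal sketch (not taken from the docs) of computing a Jacobian and a Hessian with the new submodule:

```python
import torch
from torch.autograd.functional import jacobian, hessian

def f(x):
    # A simple scalar-valued function of a vector input.
    return (x ** 3).sum()

x = torch.randn(3)
print(jacobian(f, x))  # elementwise 3 * x**2
print(hessian(f, x))   # diagonal matrix with entries 6 * x
```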
Python 2 no longer supported
For PyTorch 1.5.0 we will no longer support Python 2, specifically version 2.7. Going forward support for Python will be limited to Python 3, specifically Python 3.5, 3.6, 3.7 and 3.8 (first enabled in PyTorch 1.4.0).
Known Issues
`torch.nn.parallel.DistributedDataParallel` does not work in Single-Process Multi-GPU mode.
`DistributedDataParallel` (DDP) used to support two modes:
- Single-Process Multi-GPU (SPMG): In this mode, each DDP process replicates the input `module` to all specified devices and trains on all `module` replicas. This mode is enabled when the application passes in a `device_ids` argument that contains multiple devices, or when `device_ids` is not provided, in which case DDP will try to use all available devices.
- Multi-Process Single-GPU (MPSG): This is the recommended mode, as it is faster than SPMG. In this mode, each DDP process works directly on the provided `module` without creating additional replicas. This mode is enabled when `device_ids` contains only a single device or when there is only one visible device (e.g., by setting `CUDA_VISIBLE_DEVICES`).
A recent change (#33907) in torch.nn.parallel.replicate
breaks DDP’s assumption on replicated modules and leads to failures in
the SPMG mode. However, since SPMG is known to be slower due to GIL
contention and additional overhead caused by scattering input and
gathering output, we are planning to retire this mode in future releases
and make MPSG the only supported mode in DDP. The code below shows an
example of the recommended way to construct DDP.
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
# use "cuda:1" as the target device
target_device = 1
local_model = torch.nn.Linear(2, 2).to(target_device)
ddp_model = DDP(local_model, device_ids=[target_device])
See #36268 for more discussion.
`Tensor.exponential_(0)` used to return `Inf`; now it incorrectly returns 0
Previously, in 1.4.0, `x.exponential_(0)` gave a tensor full of `inf`. In 1.5.0, it wrongly gives a tensor full of zeros.
Version 1.4.0:
>>> torch.randn(3).exponential_(0)
tensor([inf, inf, inf])

Version 1.5.0:
>>> torch.randn(3).exponential_(0)
# This is wrong!
tensor([0., 0., 0.])
See #36798 for more details
Backwards Incompatible Changes
Python
`Tensor.clone`, `Tensor.to`, `Tensor.empty_like`, and similar functions preserve stride information instead of returning contiguous tensors
The `clone`, `to`, `type`, `cuda`, `cpu`, `byte`, `char`, `double`, `bool`, `half`, `int`, `long`, `short`, `float`, `bfloat16`, `empty_like`, `full_like`, `ones_like`, `zeros_like`, `rand_like`, `randn_like`, and `randint_like` operators now propagate the memory format (roughly, the strides) of the input tensor to the output tensor.
Since PyTorch operators generally support non-contiguous tensors, this should have no functional effect on most PyTorch programs.
The most common incompatibility with Python programs is with the view
operator, which has specific stride requirements. If these requirements
are no longer met as a result of this change, you will get an error
message indicating that you should use reshape instead, i.e.
"RuntimeError: view size is not compatible with input tensor's size and
stride (at least one dimension spans across two contiguous subspaces).
Use .reshape(...) instead."
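A hedged illustration of the failure mode described above (not an example from the release notes): a tensor whose non-contiguous strides are now preserved by `clone` can no longer be flattened with `view`, while `reshape` still works:

```python
import torch

t = torch.randn(4, 6).t()   # non-contiguous (transposed) tensor
c = t.clone()               # on 1.5.0, clone preserves the non-contiguous layout
# c.view(-1)                # would raise the RuntimeError quoted above
flat = c.reshape(-1)        # reshape copies when necessary and always succeeds
```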
Another possible incompatibility is if you have a (usually C++) operator implementation that works directly on memory (i.e., calls `data_ptr` and relies on the strides being contiguous).
In the following example, we go through the implementation of a simple clone
operation and see how it needs to change between versions.
// Version 1.4.0
Tensor simple_clone(const Tensor& input) {
  TORCH_CHECK(input.dim() == 1);
  auto output = at::empty_like(input);
  auto input_stride = input.strides()[0];
  auto* output_ptr = output.data_ptr<float>();
  auto* input_ptr = input.data_ptr<float>();
  // Before 1.5.0, the result of `empty_like` is always contiguous.
  for (int64_t idx = 0; idx < input.size(0); idx++) {
    output_ptr[idx] = input_ptr[idx * input_stride];
  }
  return output;
}

// Version 1.5.0
Tensor simple_clone(const Tensor& input) {
  TORCH_CHECK(input.dim() == 1);
  // From 1.5.0 on, the result of `empty_like` may not be contiguous.
  auto output = at::empty_like(input);
  // As a result, we need to keep track of the output stride as well.
  auto input_stride = input.strides()[0];
  auto output_stride = output.strides()[0];
  auto* output_ptr = output.data_ptr<float>();
  auto* input_ptr = input.data_ptr<float>();
  for (int64_t idx = 0; idx < input.size(0); idx++) {
    output_ptr[idx * output_stride] = input_ptr[idx * input_stride];
  }
  return output;
}
The inferred dtype of np.float_ and np.float64 scalars in tensor constructors (e.g. torch.tensor(...), torch.as_tensor(...)) is now torch.float64 instead of the default dtype (usually torch.float32). (#30486)
Please explicitly pass in the desired dtype when constructing tensors with NumPy float64 scalars to get the old behavior.
Version 1.4.0:
# Old behavior: return torch.float32 tensor (by default)
>>> torch.tensor(np.float64(0))
tensor(0.)

Version 1.5.0:
# To keep the old behavior, please explicitly pass the dtype
>>> torch.tensor(np.float64(0), dtype=torch.get_default_dtype())
tensor(0.)
This can cause your program to execute in torch.float64, potentially slowing it down, or can lead to errors for operators that don't support torch.float64 or mixed dtypes.
numpy integer scalars are now treated as integers for the purposes of type promotion (#30486)
Previously, in 1.4.0, they were mistakenly treated as floats (so, for example, `torch.ones(3) * np.int64(3)` would return a float32 tensor). In 1.5.0, we’ve fixed that behavior; `torch.ones(3) * np.int64(3)` returns an int32 tensor.
This can cause your code to fail if you performed operations between PyTorch tensors and numpy scalars and then passed the result into an operation that does not support integral types or mixed types. To fix your code, please cast the resulting tensor to the desired dtype.
Version 1.4.0:
>>> torch.ones(3) * np.int64(3)
tensor([3., 3., 3.])

Version 1.5.0:
>>> (torch.ones(3) * np.int64(3)).float()
tensor([3., 3., 3.])
`torch.autograd.Function`: dropped support for old-style Functions (#33956).
In previous versions of PyTorch, there were two ways to write autograd Functions. We deprecated one of them in 1.3.0 and dropped support for it entirely in 1.5.0. Old-style autograd Functions will no longer work in user code.
These Functions can be identified by not having `staticmethod` `forward` and `backward` functions (see the example below). Please see the current documentation for how to write new-style Functions.
# Version 1.4.0
class Exp(torch.autograd.Function):
    def forward(self, i):
        result = i.exp()
        self.save_for_backward(result)
        return result

    def backward(self, grad_output):
        result, = self.saved_tensors
        return grad_output * result

Exp()(torch.tensor(1.))

# Version 1.5.0
class Exp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        result = i.exp()
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        result, = ctx.saved_tensors
        return grad_output * result

Exp.apply(torch.tensor(1.))
`torch.optim` optimizers changed to fix in-place checks for the changes made by the optimizer (#33640, #34211)
If this causes your code to fail, there are two possible reasons:
Reason 1: The value of that parameter was actually saved and used and we were computing incorrect gradients in previous versions of PyTorch. This would result in an error message mentioning incorrect version numbers. You should replace code that uses `self.my_param` with `self.my_param.clone()` to make sure the saved version is different from the one that is modified by the optimizer. For example:
Before 1.5.0, the following may have worked.
def model(input, target, param):
    return (input * param ** 2 - target).norm()
param = torch.randn(2, requires_grad=True)
input = torch.randn(2)
target = torch.randn(2)
sgd = optim.SGD([param], lr=0.001)
loss = model(input, target, param)
loss.backward(retain_graph=True)
sgd.step()
loss.backward()
param.grad
If after upgrading to 1.5.0, the above fails due to a version counter error, then that means the gradient computed was incorrect. To remedy this, clone `param` before using it in the model:
def model(input, target, param):
    return (input * param ** 2 - target).norm()
param = torch.randn(2, requires_grad=True)
input = torch.randn(2)
target = torch.randn(2)
sgd = optim.SGD([param], lr=0.001)
loss = model(input, target, param.clone())
loss.backward(retain_graph=True)
sgd.step()
loss.backward()
param.grad
Reason 2: You know what you're doing and change the values back to the right thing before the next backward. However, you're running into an error because the version counter cannot be decremented. Open an issue with your particular use case and we will help you to work around the version counter issue.
`utils.cpp_extensions` now uses `ninja` as the default compilation backend (#32495)
`ninja` enables parallel compilation of your C++ extension, greatly speeding up compilation. This change will not break most user code; if you do not have `ninja` installed, we fall back to the old `distutils` backend.
However, if you do have `ninja` installed, it is possible that this change will cause your C++ extension build to fail by oversubscribing your system with too many worker processes. There are two potential workarounds for this.
Method 1: If a previously succeeding `python setup.py install` now fails, try setting the `MAX_JOBS` environment variable.
Version 1.4.0:
python setup.py install

Version 1.5.0:
MAX_JOBS=2 python setup.py install
Method 2: Switch back to the old `distutils` backend inside your `setup.py`:
Version 1.4.0:
cmdclass={'clean': clean,
          'build_ext': BuildExtension},

Version 1.5.0:
cmdclass={'clean': clean,
          'build_ext': BuildExtension.with_options(use_ninja=False)},
`torch.optim.Adam`, `torch.optim.SGD` changed to not modify gradients in-place (#30257)
In previous versions of PyTorch, the Adam and SGD optimizers modified gradients (e.g. `param.grad`) in-place via in-place addition of `params.grad += weight_decay * param`.
To make this consistent with the behavior of other optimizers and to
prevent surprises about the behavior, we’ve changed them to stop
modifying gradients in-place.
This should not have an effect on most PyTorch programs unless they relied on this behavior. The easiest way to replicate the old behavior is to create a custom optimizer that implements it.
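A small, hedged illustration of the new behavior (assumed, not from the release notes): after `step()`, the stored gradient is left untouched even with weight decay enabled:

```python
import torch
from torch import optim

param = torch.randn(2, requires_grad=True)
param.grad = torch.ones(2)
opt = optim.SGD([param], lr=0.1, weight_decay=0.5)

grad_before = param.grad.clone()
opt.step()
# On 1.5.0 the gradient is unchanged; earlier versions added the
# weight-decay term into param.grad in-place.
print(torch.equal(param.grad, grad_before))  # True on 1.5.0
```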
`torch.masked_select` now always returns a 1D tensor (#29923)
The behavior of `torch.masked_select` when both "self" and "mask" are 0-dimensional was changed. In previous versions of PyTorch, this would return a 0-dimensional tensor. Now, we return a 1-dimensional tensor to be consistent with other input sizes and our documentation.
Version 1.4.0:
>>> torch.masked_select(torch.tensor(0), torch.tensor(True))
tensor(0)

Version 1.5.0:
>>> torch.masked_select(torch.tensor(0), torch.tensor(True))
tensor([0])
`torch.index_select` on a 0-d tensor now returns a 0-d tensor. (#30790)
In previous versions of PyTorch, the output of `torch.index_select` on a 0D input tensor produced a 1D tensor. This was inconsistent with our documentation, which stated "The returned tensor has the same number of dimensions as the original tensor (input)." Now, we return a 0D tensor.
Version 1.4.0:
>>> torch.index_select(torch.tensor(5), 0, torch.tensor([0]))
tensor([5])

Version 1.5.0:
>>> torch.index_select(torch.tensor(5), 0, torch.tensor([0]))
tensor(5)
`nn.MultiLabelMarginLoss`: 'none' reduction on 1D tensor now returns a 0D tensor (#30768)
In previous versions of PyTorch, the output of `nn.MultiLabelMarginLoss` on 1D and 0D tensors incorrectly produced 1-D tensors. Now, those cases return a 0D tensor to be consistent with the 2-D tensor case.
Version 1.4.0:
>>> nn.MultiLabelMarginLoss(reduction='none')(torch.randn(3), torch.zeros(3, dtype=torch.long))
tensor([0.2959])

Version 1.5.0:
>>> nn.MultiLabelMarginLoss(reduction='none')(torch.randn(3), torch.zeros(3, dtype=torch.long))
tensor(0.2959)
`nn.MultiMarginLoss`: 'none' reduction on 1D target now returns a 1D tensor (#30826)
In previous versions of PyTorch, the output of `nn.MultiMarginLoss` on a 1D `target` tensor produced a 0D output. We changed this to return a 1D output tensor, to make it consistent with other input sizes, which return an output that matches the target shape.
Version 1.4.0:
>>> nn.MultiMarginLoss(reduction='none')(torch.tensor([1.]), torch.tensor([0]))
tensor(0.)

Version 1.5.0:
>>> nn.MultiMarginLoss(reduction='none')(torch.tensor([1.]), torch.tensor([0]))
tensor([0.])
`Tensor.exponential_(lambda)` no longer supports `lambda < 0` (#32501)
`lambda`, the rate parameter of the exponential distribution, mathematically must be greater than 0. We’ve disabled support for `lambda < 0` to be mathematically correct; most users will not have been using a lambda less than zero.
Version 1.4.0:
tensor = torch.empty(3).exponential_(-1.5)

Version 1.5.0:
# Negative lambda not supported!
`nn.BCELoss`, `nn.functional.binary_cross_entropy` no longer accept inputs with the same number of elements that are not broadcastable (#31365)
Previously, we accepted inputs with the same number of elements. However, this behavior was deprecated and we removed it in 1.5.0. In order to replicate the old behavior, please explicitly reshape your input and target tensors to have the same shape.
Version 1.4.0:
>>> input = torch.rand(3, 3)
>>> target = torch.randn(9)
>>> torch.nn.functional.binary_cross_entropy(input, target)

Version 1.5.0:
>>> input = torch.rand(3, 3)
>>> target = torch.randn(9)
>>> torch.nn.functional.binary_cross_entropy(input, target.reshape_as(input))
`torch.normal`: the `out` argument is now required to have the same size as the computed output (#32031)
Previously, on CPU devices, `torch.normal(mean, std, out=out)` would resize `out` to the correct size. To be consistent with the CUDA implementation, we’ve changed it so that `out` must either already have the correct size, or be an empty tensor with size `[0]`. To work around this, please ensure that your `out` tensor has the correct size.
Version 1.4.0:
>>> torch.normal(torch.zeros(3), torch.ones(3), out=torch.randn(2))
tensor([ 0.0300, 0.7830, -1.3579])

Version 1.5.0:
>>> torch.normal(torch.zeros(3), torch.ones(3), out=torch.randn(2))
RuntimeError: inconsistent tensor, output size ([2]) is not the same as broadcasted mean and std size (3)
`Tensor.geometric_` no longer supports integral Tensors (#31878)
Previously, on CPU devices, `Tensor.geometric_` supported Tensors with integral dtype. Now, it only supports floating point. We removed support for this because it doesn’t make sense for `geometric_` to operate on integral dtypes.
Changed `torch.floor_divide` `input` positional argument name to `self` (#34552)
Before PyTorch 1.5, `torch.floor_divide` took two positional arguments: `torch.floor_divide(input, other)`. We’ve changed the name of the `input` argument to `self`; this will break code that called `torch.floor_divide` via keyword argument. For example:
Version 1.4.0:
torch.floor_divide(input=x, other=y)

Version 1.5.0:
# Either of the following works.
torch.floor_divide(self=x, other=y)
torch.floor_divide(x, y)
C++ API
RNN / GRU / LSTM layers (#34322)
- Instead of returning `RNNOutput`, the RNN / GRU `forward` method now returns `std::tuple<Tensor, Tensor>`, and the LSTM `forward` method now returns `std::tuple<Tensor, std::tuple<Tensor, Tensor>>`, matching the Python API.
- The LSTM forward method’s hidden state parameter now has type `torch::optional<std::tuple<Tensor, Tensor>>`, matching the Python API.
- RNN / LSTM / GRU layers now have a `forward_with_packed_input` method which accepts a `PackedSequence` as input and optionally a hidden state, matching the `forward(PackedSequence, ...)` variant in the Python API.
- RNN / LSTM / GRU layers no longer have these fields: `w_ih` / `w_hh` / `b_ih` / `b_hh`. Instead, to access the weights and biases of the gates, users should do e.g. `rnn->named_parameters()["weight_ih_l0"]`, which mirrors the Python API `rnn.weight_ih_l0`.
- In `RNNOptions`:
  - `tanh()` / `relu()` / `activation` are removed. Instead, `nonlinearity` is added, which takes either `torch::kTanh` or `torch::kReLU`
  - `layers` is renamed to `num_layers`
  - `with_bias` is renamed to `bias`
- In `LSTMOptions`:
  - `layers` is renamed to `num_layers`
  - `with_bias` is renamed to `bias`
- In `GRUOptions`:
  - `layers` is renamed to `num_layers`
  - `with_bias` is renamed to `bias`
Upsample layer / F::interpolate function (#35025)
- There are changes to `UpsampleOptions` and `InterpolateFuncOptions`:
  - `size` is changed from `std::vector<int64_t>` to `c10::optional<std::vector<int64_t>>`. If you want to pass a list of `int64_t` to this argument, you must pass it as `std::vector<int64_t>`.
  - `scale_factor` is changed from `std::vector<double>` to `c10::optional<std::vector<double>>`. If you want to pass a list of `double` to this argument, you must pass it as `std::vector<double>`.
- F::multilabel_margin_loss / F::multilabel_soft_margin_loss functions (#35163)
  - `torch::nn::functional::MultiLabelMarginLossFuncOptions` is renamed to `torch::nn::functional::MultilabelMarginLossFuncOptions`
  - `torch::nn::functional::MultiLabelSoftMarginLossFuncOptions` is renamed to `torch::nn::functional::MultilabelSoftMarginLossFuncOptions`
- The deprecated `torch::nn::BatchNorm` is removed in favor of `torch::nn::BatchNorm{1,2,3}d`
- The deprecated `torch::nn::FeatureDropout` is removed in favor of `torch::nn::Dropout{2,3}d`
- The deprecated `torch::nn::modules_ordered_dict` is removed. Users should do `Sequential sequential({{"m1", MyModule(1)}, {"m2", MyModule(2)}})` instead.
- The deprecated `torch::nn::init::Nonlinearity` is removed, in favor of these enums: `torch::kLinear` / `torch::kConv1D` / `torch::kConv2D` / `torch::kConv3D` / `torch::kConvTranspose1D` / `torch::kConvTranspose2D` / `torch::kConvTranspose3D` / `torch::kSigmoid` / `torch::kTanh` / `torch::kReLU` / `torch::kLeakyReLU`
- The deprecated `torch::nn::init::FanMode` is removed, in favor of these enums: `torch::kFanIn` / `torch::kFanOut`
Optimizers
`Optimizer::step` now accepts a closure function as optional input and returns a tensor, and `LossClosureOptimizer` is removed (#34790) (#34957). If you had a custom optimizer class defined as:
struct MyOptimizer : Optimizer {
using Optimizer::Optimizer;
void step() override {...}
};
you would need to update your optimizer class definition as follows:
struct MyOptimizer : Optimizer {
using Optimizer::Optimizer;
torch::Tensor step(LossClosure closure = nullptr) override {
...
// return `torch::Tensor()` if `closure` is nullptr
// (i.e. we are not computing the loss)
return torch::Tensor();
}
};
- Adagrad (#29335)
  - In `AdagradOptions`, `learning_rate` is renamed to `lr`.
  - In `Adagrad`, `sum_buffers` and `step_buffers` are now removed, and parameter state should be accessed by calling the accessors on the parameter’s corresponding state object. For example:

    auto& param_state = static_cast<AdagradParamState&>(
        *optimizer.state()[c10::guts::to_string(parameter.unsafeGetTensorImpl())]);
    // Use the following to access parameter state:
    //
    // param_state.sum()
    // param_state.step()
- SGD (#32592)
  - In `SGDOptions`, `learning_rate` is renamed to `lr`.
  - In `SGD`, `momentum_buffers` is now removed, and parameter state should be accessed by calling the accessors on the parameter’s corresponding state object. For example:

    auto& param_state = static_cast<SGDParamState&>(
        *optimizer.state()[c10::guts::to_string(parameter.unsafeGetTensorImpl())]);
    // Use the following to access parameter state:
    //
    // param_state.momentum_buffer()
- Adam (#33730)
  - In `AdamOptions`:
    - `learning_rate` is renamed to `lr`
    - `beta1` and `beta2` are replaced by a tuple `betas`
  - In `Adam`, `step_buffers`, `exp_average_buffers`, `exp_average_sq_buffers` and `max_exp_average_sq_buffers` are now removed, and parameter state should be accessed by calling the accessors on the parameter’s corresponding state object. For example:

    auto& param_state = static_cast<AdamParamState&>(
        *optimizer.state()[c10::guts::to_string(parameter.unsafeGetTensorImpl())]);
    // Use the following to access parameter state:
    //
    // param_state.step()
    // param_state.exp_avg()
    // param_state.exp_avg_sq()
    // param_state.max_exp_avg_sq()
- RMSprop (#33450)
  - In `RMSpropOptions`, `learning_rate` is renamed to `lr`.
  - In `RMSprop`, `square_average_buffers`, `momentum_buffers` and `grad_average_buffers` are now removed, and parameter state should be accessed by calling the accessors on the parameter’s corresponding state object. For example:

    auto& param_state = static_cast<RMSpropParamState&>(
        *optimizer.state()[c10::guts::to_string(parameter.unsafeGetTensorImpl())]);
    // Use the following to access parameter state:
    //
    // param_state.square_avg()
    // param_state.momentum_buffer()
    // param_state.grad_avg()
- LBFGS
  - In `LBFGSOptions`:
    - `learning_rate` is renamed to `lr`
    - `max_eval`’s type is changed from `int64_t` to `c10::optional<int64_t>`
    - `tolerance_grad`’s type is changed from `float` to `double`
    - `tolerance_change`’s type is changed from `float` to `double`
    - `history_size`’s type is changed from `size_t` to `int64_t`
  - In `LBFGS`, `d`, `H_diag`, `prev_flat_grad`, `t`, `prev_loss`, `ro`, `al`, `old_dirs`, `old_stps`, `func_evals` and `state_n_iter` are now removed, and parameter state should be accessed by calling the accessors on the parameter’s corresponding state object. For example:

    auto& param_state = static_cast<LBFGSParamState&>(
        *optimizer.state()[c10::guts::to_string(parameter.unsafeGetTensorImpl())]);
    // Use the following to access parameter state:
    //
    // param_state.d()
    // param_state.H_diag()
    // param_state.prev_flat_grad()
    // param_state.t()
    // param_state.prev_loss()
    // param_state.ro()
    // param_state.al()
    // param_state.old_dirs()
    // param_state.old_stps()
    // param_state.func_evals()
    // param_state.n_iter()
Removed `AutoGIL`/`AutoNoGIL` in favor of `pybind11::gil_scoped_*` functions (#34301)
If your code released or acquired the GIL via `AutoNoGIL` or `AutoGIL`, please change the invocations to `pybind11::gil_scoped_release` or `pybind11::gil_scoped_acquire`, respectively.
Others
- `torch::tensor(floating-point values)` will always produce a tensor of default dtype, and `torch::tensor(integer values)` will always produce a tensor of `torch::kLong` dtype, matching the Python API behavior (#32367).
- `torch::Tensor::base()` is renamed to `torch::Tensor::_base()`, matching the Python API. (#33316)
- Renamed TensorTypeId to DispatchKey (#32154)
- Throw an error if nbytes is called on a sparse tensor. (#33897)
JIT
Simple Executor Is Now On By Default
The simple executor skips a number of fusion-related passes and analyses that are very time-consuming. Disabling these optimizations fixes pathologically long compilation times. Users that rely on GPU fusion for their desired performance profile should turn on the profiling executor. We provide C++ and Python APIs to enable the profiling executor (a short Python sketch follows the list):
- In Python, call `torch._C._jit_set_profiling_mode(True)` before you call your model for the first time.
- In C++, include `#include <torch/csrc/jit/runtime/graph_executor.h>` and set `getProfilingMode() = true` before you invoke your model for the first time.
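For example, a minimal Python sketch using the flag named in the list above:

```python
import torch

# Opt back into the profiling executor before the model runs for the first time.
torch._C._jit_set_profiling_mode(True)

@torch.jit.script
def fused(x, y):
    return torch.relu(x + y) * 2.0

out = fused(torch.randn(4), torch.randn(4))
```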
Quantization
Remove qconfig_dict in top level eager mode quantization API (#31972).
In eager mode quantization, one needs to manually insert quant and dequant stubs in a model to specify where activations are quantized. Having a qconfig_dict that specifies the quantization configuration for each module is not useful as one needs to manually modify the model with quant/dequant stubs. The new API makes it explicit that the model needs to be manually modified for quantization.
# previously qconfig_dict was an optional argument to prepare
def prepare(model, qconfig_dict=None, inplace=False):
# now replaced with
def prepare(model, inplace=False):
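A hedged end-to-end sketch of the eager-mode flow under the new API (the module `M` and its layers are illustrative): quant/dequant stubs are inserted manually and `qconfig` is set on the model itself rather than passed as a dict:

```python
import torch
import torch.nn as nn
from torch.quantization import QuantStub, DeQuantStub, default_qconfig, prepare, convert

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()      # marks where activations start being quantized
        self.conv = nn.Conv2d(1, 1, 1)
        self.dequant = DeQuantStub()  # marks where they are dequantized again

    def forward(self, x):
        return self.dequant(self.conv(self.quant(x)))

m = M().eval()
m.qconfig = default_qconfig
prepare(m, inplace=True)          # note: no qconfig_dict argument anymore
m(torch.randn(1, 1, 4, 4))        # calibration pass
convert(m, inplace=True)
```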
RPC
Functional API for Distributed Autograd and Distributed Optimizer
More specifically, callers must pass `context_id` to `torch.distributed.autograd.backward()` and `torch.distributed.optim.DistributedOptimizer.step()`. (#33711)
# Before
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
from torch import optim
from torch.distributed.optim import DistributedOptimizer

with dist_autograd.context() as context_id:
    # Forward pass.
    rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
    rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
    loss = rref1.to_here() + rref2.to_here()
    # Backward pass.
    dist_autograd.backward([loss.sum()])
    # Optimizer.
    dist_optim = DistributedOptimizer(
        optim.SGD,
        [rref1, rref2],
        lr=0.05,
    )
    dist_optim.step()

# After
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
from torch import optim
from torch.distributed.optim import DistributedOptimizer

with dist_autograd.context() as context_id:
    # Forward pass.
    rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
    rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
    loss = rref1.to_here() + rref2.to_here()
    # Backward pass.
    dist_autograd.backward(context_id, [loss.sum()])
    # Optimizer.
    dist_optim = DistributedOptimizer(
        optim.SGD,
        [rref1, rref2],
        lr=0.05,
    )
    dist_optim.step(context_id)
Disallow sending CUDA tensors over RPC
The motivation is to prevent potential invalid device errors when the number of devices on the sender and the receiver does not match. However, applications can always move CUDA tensors to CPU before sending (#33604).
Version 1.4.0:
import torch
import torch.distributed.rpc as rpc
rpc.init_rpc("worker0", rank=0, world_size=2)
x = torch.zeros(2, device=0)
ret = rpc.rpc_sync("worker1", torch.add, args=(x, 3))
rpc.shutdown()

Version 1.5.0:
import torch
import torch.distributed.rpc as rpc
rpc.init_rpc("worker0", rank=0, world_size=2)
x = torch.zeros(2, device=0)
ret = rpc.rpc_sync("worker1", torch.add, args=(x.cpu(), 3))
rpc.shutdown()
New Features
Python
Added new functional autograd API (#34066)
- See Highlights for more details
New `__torch_function__` API Override Mechanism (#30730, #32194, #32799, #34240, #34303).
We introduced `__torch_function__`, an API override mechanism for subclassing `torch.Tensor` in Python. This is useful for creating custom objects that implement the `torch.*` APIs. These currently support overriding most `torch.*` and `torch.nn.functional` APIs; we’ve also planned future support for subclassing `torch.Tensor` (see tracking issue #22402).
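A minimal, hedged sketch of the protocol (the wrapper class and its names are illustrative, not part of the PyTorch API); the object intercepts `torch.*` calls made on it and unwraps itself before deferring to the regular implementation:

```python
import torch

class DiagonalTensor(object):
    """Illustrative wrapper representing value * torch.eye(n)."""
    def __init__(self, n, value):
        self._n = n
        self._value = value

    def tensor(self):
        return self._value * torch.eye(self._n)

    def __torch_function__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # Unwrap our objects and fall back to the standard implementation.
        args = [a.tensor() if isinstance(a, DiagonalTensor) else a for a in args]
        return func(*args, **kwargs)

d = DiagonalTensor(3, 2.0)
print(torch.mean(d))                  # dispatched through __torch_function__
print(torch.add(d, torch.ones(3, 3)))
```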
New Operators
- `torch.logical_and` and `torch.logical_or` operations added (#30521).
- `torch.square` added (#30719).
- `torch.bitwise_and` added (#31104).
- `torch.cummax`, `torch.cummin` added (#32169, #32238, #32537, #33492).
- `torch.floor_divide`, `Tensor.floor_divide` added (#30493, #34552).
- `torch.true_divide`, `Tensor.true_divide` added, analogous to Python's and NumPy's (true) division (#34236, #34794).
- `nn.functional.hardsigmoid` added (#34545).
- Added PCA and SVD for low-rank matrices (`torch.pca_lowrank`, `torch.svd_lowrank`), and `torch.lobpcg` for the positive-definite generalized eigenvalue problem (#34721).
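A few brief, hedged usage examples of the newly added operators (outputs in comments are indicative):

```python
import torch

a = torch.tensor([True, False, True])
b = torch.tensor([True, True, False])
torch.logical_and(a, b)                                   # tensor([ True, False, False])
torch.logical_or(a, b)                                    # tensor([ True,  True,  True])

x = torch.tensor([3., -1., 2.])
torch.square(x)                                           # tensor([9., 1., 4.])
torch.cummax(x, dim=0)                                    # running maximum values and indices

torch.true_divide(torch.tensor([5]), torch.tensor([2]))   # tensor([2.5000])
torch.floor_divide(torch.tensor([5]), torch.tensor([2]))  # tensor([2])
```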
Distributions
- `distributions.von_mises` added (#33418).
- `distributions.mixture_same_family`: Added support for mixture distributions (#22742, #33408).
- `distributions.transforms.TanhTransform` added (#19785).
- `distributions.continuous_bernoulli` added (#34619).
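Hedged usage sketches for the new distributions:

```python
import torch
from torch import distributions as D

vm = D.VonMises(loc=torch.tensor(0.0), concentration=torch.tensor(1.0))
vm.sample((4,))                     # angles on the circle

cb = D.ContinuousBernoulli(probs=torch.tensor(0.3))
cb.rsample((4,))                    # reparameterized samples in (0, 1)

tanh = D.transforms.TanhTransform()
tanh(torch.randn(4))                # maps real values into (-1, 1)
```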
C++ API
- NN modules / functionals
- C++ tensor indexing (#30424, #32841, #30427, #34255)
  - Please see docs: https://pytorch.org/cppdocs/notes/tensor_indexing.html
- Operators
  - C++ API parity: `isinf` (#31099).
- Autograd
  - Add `at::Tensor::retain_grad` API (#33349).
- C++ extensions
Distributed
- Allows Python applications to create a subclass of the C++ `c10d.Store` using a pybind11 trampoline class (#30415).
Mobile
Quantization
- qnnpack TanH (#31013).
- Adding quantized clamp kernel (#30541).
- Quantized H Tangent function (#31031).
- QNNPACK: Add support for dynamic quantization. (#31896).
- Add operator support for dynamic quant on mobile (#32479).
- Adding native qconcat (#32252).
- FP16 dynamic quantized Linear (#32331).
- Add support for Dynamic LSTM quantization on Mobile (#32757).
- Quantized sigmoid function (#31851).
- Quantized leaky relu (#33004).
- Add a quantized batch_norm operator (#33080).
- Add Quantized BatchNorm2d module (#33109).
- Add the 3d avg pool for video related model (#33339).
- Add quantized_hardtanh (#34097).
- Add quantized ELU activation (#34267).
- Add the 3d upsample quantized op for video model (#34594).
- Add the quantized batch_norm3d and also batch_norm3d fused with relu operators (#34702).
- Add quantized implementation of hard sigmoid (#34607).
RPC
- [Experimental] Enable autograd profiler to work with RPC (#31381, #34398, #30677, #31346, #31380).
- [Experimental] Allow calling remote TorchScript functions using RPC (#32466, #33190, #32990, #32959, #33526, #33992, #33582, #32197, #33329, #34183).
Improvements
AMD/ROCm
nn.RNN
: Ensure MIOpen is called on same stream as operator (#30672)- Fixed asserts in CUDA kernels (#31276, #31297).
- Enable BFloat16 support for convolutions (#30948).
- Abstract atomic add calls (#31992).
- Install complete set of headers for ROCm build (#32076).
- Adjust
elementwise_kernel
settings on ROCm (#32609). nn.BatchNorm{1,2,3}d
: UseC10_WARP_SIZE
to fix functionality on HIP vs CUDA for gradient computation (#33098).- Enabled Bfloat16 type for activation functions and
batch_norm
(#32065). - Added ability to enable/disable MIOpen at runtime (#33118).
- Enable BFloat16 type for pooling ops (#34166).
torch.pdist
: improved precision by enabling double__shfl_down
(#34103).- Enabled BFloat16 type for loss functions and few misc ops required for resnet50 (#34469).
- Enabled BFloat16 type for EmbeddingBag, Index, and Sigmoid ops (#34630).
- Enabled 3D batch norms through MIOpen (#33262).
- Enabled 3D convolutions through ROCm (#33067).
nn.RNN
: Check if weights need to be flattened (#34265).
C++ API
- NN modules / functionals
- Allow skipping default arguments in module's forward method when module is used in
torch::nn::Sequential
(#33027) (#33718) - Make
torch::nn::Sequential::push_back(AnyModule)
methods public (#34208). - Refactor RNN / GRU / LSTM layers to match Python API (#34322).
- For
Conv{1,2,3}d
,padding_mode
now acceptstorch::kZeros
/torch::kReflect
/torch::kReplicate
/torch::kCircular
, matching Python API behavior. (#35023) - Fix
F::interpolate
andtorch::nn::Upsample
implementation to match Python API behavior (#35025) (#36274) - Renaming: MultiLabelMarginLossFuncOptions -> MultilabelMarginLossFuncOptions, MultiLabelSoftMarginLossFuncOptions -> MultilabelSoftMarginLossFuncOptions (#35163)
- Allow skipping default arguments in module's forward method when module is used in
- Optimizers
- All existing optimizers in the C++ API (Adagrad / SGD / Adam /
RMSprop / LBFGS) have the following changes to achieve parity with the
Python API: (#29335) (#30739) (#32592) (#33730) (#33450) (#34790) (#34564) (#34957) (#35001) (#36033) (#36245)
- step function implementation is changed to behave the same as Python equivalent
- Constructor now accepts
std::vector<OptimizerParamGroup>
as input optimizer.add_param_group(...)
can be used to add parameter group to an existing optimizeroptimizer.state()
should be used to access parameter state
- All existing optimizers in the C++ API (Adagrad / SGD / Adam /
RMSprop / LBFGS) have the following changes to achieve parity with the
Python API: (#29335) (#30739) (#32592) (#33730) (#33450) (#34790) (#34564) (#34957) (#35001) (#36033) (#36245)
- autograd
- Renamed
at::Tensor::base()
to_base()
, matching Python API (#33316)
- Renamed
Distributed
- Allow TCPStore to pick a port to bind to (#31674).
- Enhance NCCL watchdog to actively abort communicators for timed out ops (#32338).
- Adding DDP Design Note (#32158).
- Recommend using DDP over DataParallel (#35063)
Distributions
distributions.independent
: added explicit string representation (#33676).categorical.sample
: Reduced memory overhead (#34900).distributions.MultivariateNormal
: improved numeric stability and performance (#32092).
Mobile
- Add module level qpl logging. (#30906).
- Expose setNumThreads to android api (#31033).
- remove unused SparseCPUType from mobile build (#33517).
- make sure mobile build work with dynamic dispatch (#34038).
- support for custom mobile build with dynamic dispatch (#34055).
- Add watchOS support (#33318).
- speed_benchmark_torch switch to log latency from dataset level to row level (#34598).
ONNX
Exporting More Torch Operators to ONNX
In PyTorch 1.5, we have added support for 10 additional operators and also enhanced support for another set of 10+ existing operators. We have also added support for exporting large models (> 2GB) to ONNX. Additionally, we have made enhancements and optimizations to the export of ScriptModules and will continue to do that in the next release. We have also made improvements to the custom op export experience.
- Export dynamic unbind, split and getitem (#29136).
- Export torch.new_zeros (#34077).
- Export Im2col (#30972).
- Export bitwise_not for bool (#28439).
- Export logsoftmax with dim != -1 (#30433).
- Export einsum (#32716).
- Export aten::copy_ and aten::index_put to ONNX opset 11 (#26941).
- Export floor_divide (#31081).
- Export one_hot (#34454).
- Export torch.take (#33061).
- Export bool type index mask (#32445).
- Export split with list of sizes (#33161).
- Export scalar tensor for split (#32493).
- Export flatten to accept negative indices in opset 11 (#30751).
- Export sort with negative axes (#31971).
- Export Interpolate to support scale (#28324, #31526, #32554).
- Export quantized concat (#30887).
Enhancing the Support for ScriptModule
- Fixed access to element in size tensor for scripting (#32652).
- Export Conv in TorchScript module (#30618).
- Export Dim operation in TorchScript module (#31928).
- Export randnlike in TorchScript module (#32830).
- Partially support tensor lists in loop/concat/stack (#30126)
Enhancing Existing Export Logic
- Updating ONNX checker logic. (#33522).
- Adding ONNX large model export support in exporter (#33062).
- Extend op registration (#32943).
- Support op registration if name starts with underscore (#32017).
Optimizing Exported ONNX Graph
- Try exporting ONNX with force_outplace=False (#29466).
- Enable constant folding (#29834).
- Added cons folding for ONNX mul, div, sqrt ops (#32077).
- Enable constant folding for Reshape (#31054).
Adding Utility Functions and Refactoring
- Added ONNX model checker to ONNX export (#32298).
- Export custom ops (#29752).
- Upgrade exported ONNX IR version to 6 (#31025).
- Provide names for operator nodes in ONNX exported graph (#27342).
- Update ONNX landing page since 1.3 (#32805).
- Turn ONNX_ML into a proper build option (#33424).
Operator Benchmark
- Added small input shapes to test operator overhead (#30617).
- Added
binary_test
to benchmark binary ops (#31326). - Added
Tensor.copy_
operator (#31327). - Removed option to wipe cache because it did not help with variance (#31334).
- Added
torch.diag
(#32597).
Quantization
- Guard against copying from quantized Tensor to non-quantized Tensor (#29660).
- Add assert for min, max, qmin, qmax for ChooseQuantizationParams (#32739).
- Support broadcast for quantized mul kernel (#30442).
- Make FakeQuant use
REGISTER_DISPATCH
(#33682). - Set alias analysis kind to
FROM_SCHEMA
for qadd, qmul, qclamp, qconcat (#33359). - Migrate
fake_quant_slice
to TensorIterator (#33744). - Parallelize quantize and dequantize (#33765).
- Make FP16 RNN use new prepack op (#34339).
- Refactor QAT Conv module for better extensibility (#30362).
- Use non-inplace for insert observer pass (#34190).
RPC
- Add default arguments for
init_method
(#30208). - By default ignore RRef leaks during shutdown (#30217).
- Robustify
rpc_agent
handlers with generic Future (#31224). - Fix error message in incorrect
rref.localValue()
call (#31199). - Add
RpcAgent::getWorkerInfos()
API to return allWorkInfo
s in the group (#30241). - Add local shutdown to process group agent (#30330).
- Add
RRef.str()
API to return a string representation of the RRef (#30609). - Adding Debug Info for RRef Context (#30610).
- Add
get_metrics
andget_debug_info
to RPC agent (#30833). - Adding debugging metrics to process group agent (#30884).
- Add glue code to collect debug info from all components (#30888).
- Make RRef leak detection always print a warning log (#31922).
- Allow multiple backward passes to accumulate gradients. (#32506).
- Allow RRef local creation with IValue objects (#33263).
- Improve ProcessGroup
RpcBackendOptions
Constructor API (#34081). - Enhanced Error Reporting in Dist Autograd/RPC (#34179).
- Delete all user forks tracked in
RRefContext
before graceful shutdown (#31893). - Best-effort Error Detection for Using Deleted UserRRefs (#34673).
- Don't run user function until all UserRRefs in the args are confirmed (#34497).
- Support using self as the destination in
rpc.remote
for builtin operators (#34931). - Add debug info API for distributed autograd. (#30642).
- Propagate errors in
clearAndWaitForOutstandingRpcsAsync
. (#32952).
Type Hints
- DataLoader
default_collate
type hint added (#28935). Tensor.rsub, Tensor.rpow, Tensor.rtruediv, Tensor.map_
type hints were added (#30576).torch.optim
: added more missing type hints (#31130).nn.functional.grid_sample
,nn.functional.affine_grid
: added missing align_corners annotation (#32492).torch.nn.Parameter
constructor type hint was fixed (#32617).nn.MultiheadAttention
,nn.Transformer
: added type hints (#28396).torch.optim.LambdaLR
constructor type hint was fixed (#33271).torch.optim
: added missing default value forLRScheduler.step()
(#32411).- Make type of
Tensor.type()
more specific (#32353). torch.optim.optimizer.Optimizer
type hints were fixed (#32900).optim.AdamW
type hints were fixed (#34299).torch.utils.data.Sampler
subclasses type hints were added (#33679).nn.Sequential
,nn.ModuleList
,nn.ParameterList
,nn.ParameterDict
type hints were fixed (#33686).Tensor.bfloat16()
type hint was added (#33747).- Binary operator type hints were fixed (#33748).
torch.bfloat16
,nn.Module.training
,Tensor.cuda
, and 10s of other type hints added (#33762).torch.add
type hint was fixed(#33935).Tensor.shape
type hint was fixed (#34595).- Fixed
utils.data
imports (#33543). Tensor.__radd__
type hint was fixed (#35231)
Other
autograd.detect_anomaly
: added support for Sparse Tensors (#29803).autograd.detect_anomaly
: Error messages now print the current Node name (#33875).autograd.profiler
: added better error message when crashing while profiling multi-worker DataLoader (#31473).autograd.profiler
Enable usingtorch.autograd.profiler.record_function
as decorator (#30861).autograd.profiler
Speed upexport_chrome_trace
by up to 4x (#30724).torch.autograd
: added better error message when attempting to fork (#33885).torch.cuda.memory.caching_allocator_alloc
,torch.cuda.memory.caching_allocator_delete
exposed in Python API (#33860).torch.roll
: added bool tensor support (#31194).torch.flip
: added support for bool tensors (#31267).torch.equal
: added support for bfloat16 CPU scalar types (#30817).torch.save
,torch.load
: added error message for minimum dill version support (#30985).torch.diagonal
: added named tensor support(#30193).torch.linspace
: added support for integral types on CPU (#32218).torch.eig
: Added autograd support in the case where eigenvalues are real (#33090).torch.mvlgamma
: improved error message (#32665).torch.no_grad
,torch.enable_grad
: added support for decorating generator functions (#31792).torch.narrow
: added Tensor overload forstart
(#34317).Tensor.random_
: enabled support for half on CPU (#34030).Tensor.grad
: added warnings when accessing it if it won't be populated for known reasons (#30531).torch.cuda.comm.gather
: improved error message (#27456).nn.functional.max_pool{1,2,3}d
: added named tensor support (#31669).nn.Module.load_state_dict
: Include the contents of the exception in error messages (#32693).nn.MultiheadAttention
: add support for 3D attention mask (#31996).nn.MSELoss
: Added performance warning for using CPU Half (#33021).nn.ModuleList
,nn.ParameterDict
,nn.ParameterDict
: added more descriptive error messages when attempting to call these like Modules (#29991).nn.init.dirac_
: Addedgroups
option for compatibility with initializing group convolutions (#32825).- Added error message to indicate that reduction operations are not supported for dim >= 64 (#31476).
- Type Promotion: added supports for sparse tensors and arithmetic operations (#30429).
- Enabled indexing for bfloat16 tensors (#31692).
- Add 64-bit indexing support for CUDA Tensors (#33405).
- Added warning when converting a read-only NumPy array to
torch.Tensor
(#33615). - Set rpath for JNI library on Mac (#32247).
- Updated MAGMA to 2.5.2 for Windows (#30513, #34205).
- Marked PyTorch incompatible with Python-3.6.0 (#34724).
- Consider
hub_dir
alongsideTORCH_HOME
env variable for storing hub models (#32844). - Improved dll loading logic on Windows (#33856).
- Error out if legacy
Tensor.new
is called on alternate layouts or dtypes (#31485). utils.checkpoint.checkpoint_sequential
: Removed deprecated variadic arguments behavior (#25985).
Bug Fixes
C++ API
- NN modules / functionals
output_ratio
forFractionalMaxPool{2,3}d
module andfractional_max_pool{2,3}d
functional should accept double as data type (#33304)- For
AdaptiveAvgPool{2,3}d
andAdaptiveMaxPool{2,3}d
,output_size
is changed to acceptc10::nullopt
in its elements, matching Python API behavior. (#35022) - Fix bug in
fractional_max_pool3d_with_indices
implementation (#35024) - Remove
namespace F = torch::nn::functional
from torch/nn/modules/batchhnorm.h, so that people don't have to useF
to aliastorch::nn::functional
if they don't want to (#30684)
- autograd
- For
AutogradContext
,get_dirty()
is removed andget_and_bump_dirty()
is added, and the latter always bumps the version counter of the returned tensors (#33068) - Fix allow_unused checking for C++ API (#34035)
- Remove
using namespace torch::autograd
fromtorch/csrc/api/include/torch/nn/modules/_functions.h
(#34423)
- For
- Operators
torch::tensor(floating-point values)
will always produce tensor of default dtype, andtorch::tensor(integer values)
will always produce tensor oftorch::kLong
dtype, matching Python API behavior (#32367)- Fix
torch::allclose
to handlestd::numeric_limits::lowest()
for integral types (#32978) - Switch
torch::empty_like
to usemerge_in
to process TensorOptions (#33505)
Distributed
- Allow DDP to detect globally unused parameters (#28883).
- Accept url query when rank or world_size is specified in the Process Group init_method URL (#32016).
- Add ability to abort NCCL communicators from the store. (#32895).
- Fix timeout support when initializing process group with TCP store (#33434).
- Abort NCCL communicators before throwing operation timed out (#31128).
- Fix logging for aborted communicators in ProcessGroupNCCL (#33147).
- Fix handling of replica parameters in DataParallel (#33907).
- Specify requires_grad for Parameter replica so it's not always set to True by default (#32356)
- Put sparse allreduce results to input tensors (#32226)
- Issue a warning when zero_grad is used in DataParallel (#33064)
JIT
- TorchScript compilation fixed for (#33783):
  - torch.stft
  - torch.lu, torch.lu_unpack
  - torch.cdist
  - torch.norm
- tensor.tolist() compilation now supported; requires an output type annotation (#33472)

  def foo(float_matrix, scalar_ten):
      # type: (Tensor, Tensor) -> Tuple[List[List[float]], bool]
      out1 : List[List[float]] = float_matrix.tolist()
      out2 = torch.jit.annotate(bool, scalar_ten.tolist())
      return out1, out2

- torch.rand_like and other _like constructors no longer require additional arguments in TorchScript
- Compilation for nn.Module APIs added (#29495):
  - children
  - named_children
  - modules
  - named_modules
- Support for ModuleList Indexing with Integer Literal (#29236)
- Fixed flipped outputs for PackedSequence (#32955)
- Support index and type properties on Device (#32953)
  - device.index
  - device.type
- Add remaining Tensor properties (#33906)
  - tensor.ndim
  - tensor.T
  - tensor.name
  - tensor.is_leaf
- Fix augmented assignment to non-tensor attributes (#32993)
- Fixed type resolution for function arguments (#29623)
  - Previously we resolved types by parsing their names directly, but now TorchScript uses the value of the type directly from Python
  - This allows types like torch.device to be used
- len on tuples containing different types (#35768)
Mobile
- Fix exception message in Java Tensor (#30205).
- Fix crashes when C++ is not able to find the Java class through JNI (#30390).
- Add @DoNotStrip to nativeNewTensor method. (#30472).
- GenericDict/List type use unshapedType() (#30428).
- Support tensors with a storage offset in Java (#31584).
- Fix SIGABORT caused by double exception in PyTorchStreamReader when file not found. (#33243).
- Fix SELECTED_OP_LIST file path issue (#33942).
- Fix for handling batch size 0. (#34599).
- Fixed AutoGradMode/AutoNonVariableTypeMode uses for mobile callsites.
- Use gettimeofday on iOS (#30361).
ONNX
- Fix weight_norm export for dim=0 (#31015).
- Fix for constant folding flaky tests (#32546).
- Fix export for avg_pool with default stride (#33017).
- Fix ONNX CI by moving test data to AWS (#33200).
- Fix for random generators export (#33789).
- Fix export of index_put (#31552).
- Fix for expand -1 dim value (#34069).
- Reduce ONNX test time on CI (#33242).
- ONNX Error Message on Missing Op (#33593).
- Fix exporting copy_ with index as tensor input (#32801).
- Fix for rand_like as well (#33095).
- Added torchvision tests as part of ORT tests (#31835).
- Remove non-ASCII character from torch/onnx/symbolic_opset11.py (#31814).
- Add flag to enable script tests (#32654).
- Skip same tests in ONNX Python3 CI as in Python2 (#31827).
- Fixed torch.mm export (#34794)
- Fixed aten::size for opset 11 (#35984)
Quantization
- Bug fix: Handle missing keys in observer state dict during load (#30357).
- Fix BC for quantized linear (#30481).
- Fix mapping white list to avoid attaching qconfig for DeQuantStub (#30636).
- Fix default instantiation of dynamic quantized LSTM (#31433).
- Use default scale/zero_point in fake_quantize module instead of None (#32318).
- Fix ASAN / potential segfault in quantized Tensor memory allocations. (#29882).
- Don't serialize None values in observer (#32733).
- Enable inplace relu fusion for training (#33105).
- Bug fix in dynamic quantization kernels + better test coverage. (#33320).
- Run weight_post_process for QAT (#33852).
- Fix histogram observer to work with QAT on GPU (#34232).
- Fix the quantized batchnorm2d (#34579).
- Move QScheme ops to c10 (#30134).
- Remove incorrect fp16 dynamic linear/relu op (#32774).
RPC
- Fix serialization memory lifetime issue. (#30603).
- Don't crash callee when function does not exist on it, instead return an Exception (#32726).
- Throw the correct Exception on local client based on the RemoteException (#32936).
- Attach autograd edges only for tensors requiring grad. (#30904).
- WireSerializer should check has_storage() (#34626).
- Fixed potential deadlock in python exception handling (#35283)
Other
- torch.split: Fixed incorrect gradient computation that assumed the output was not a view (#32044).
- Allowed numpy integer types to be used where we accept Python integers (#30486).
- torch.unique, torch.unique_consecutive: fixed bug with zero-element input support (#31211).
- Tensor.to_sparse: fixed backward in the non-contiguous tensor case (#31223).
- torch.index_put: Added error checks for input tensors’ devices (#31280).
- Ensure we switch the CUDA stream correctly in CUDA operations (#31537, #31538, #31541).
- torch.SparseTensor: ensure the legacy sparse constructor doesn't interpret Python data as tensor data. (#31490).
- torch.argmax, torch.argmin: Fixed incorrect behavior on large tensors (#33310).
- torch.div: Fixed to throw an error when dividing by integer zero on CPU (#32629).
- torch.cos: Fixed incorrect gradient computation caused by not properly initializing temporary vectors in avx2 code (#32722, #34281).
- torch.logspace: Added missing integer dtype support, fixed precision issues in floating-point implementation (#32744).
- torch.prod: Fixed behavior when passed a torch.half input tensor and torch.float output tensor (#32831).
- torch.max, torch.min: Fixed NaN handling (#32541).
- torch.max, torch.min: Added error check that operand and outputs are on the same device type (#32862).
- torch.stack: Added missing input size checks (#32931).
- torch.add: Fixed memory leak on certain platforms (#32478).
- torch.normal: Fixed shape checks (#33050).
- torch.cumsum: fixed to handle inputs with zero-sized dimensions correctly (#31694).
- torch.device: Disallow incorrectly formatted device strings (#29087).
- torch.cat: Disallow passing out as one of the input tensors (#30577).
- torch.pdist: Added support for large batch sizes (#31593).
- torch.stft: Fixed crash when used with nn.DataParallel (#31861).
- torch.autograd: Ensure the original grad mode is restored during backward (#31884).
- torch.autograd: Fixed a race condition by locking graph_task before writing leaf_streams. (#31995).
- torch.tensordot: Fixed support for negative dimensions (#31954).
- torch.cumprod: Fixed to handle inputs with zero-sized dimensions correctly (#32070).
- torch.pow: Fixed the gradient computation when the base is a Tensor or Scalar of zeros (#32062, #32063).
- torch.baddbmm: Fixed bug in corner case (#33538).
- torch.where: Added check for consistent devices (#33432).
- torch.cdist: Fixed gradient computation for p=2 and large inputs (#31167).
- torch.mv: Fixed NaN handling (#31666).
- torch.index_put: Added handling for large input tensors (#33753).
- torch.addmm: Fixed incorrect output when using BLAS backend (#33819).
- torch.topk: fixed double backward when input has non-finite values (#35253)
- torch.load: Avoid problematic pickle usages on Python 3.8.0 and 3.8.1 (#33824).
- Tensor.to: Fixed race condition for gradient computation that spans CUDA devices (#31930).
- Tensor.random_: added check that from and to are within the Tensor’s dtype bounds (#34033).
- Tensor.copy_: Fixed memory overlap check and allowed outputs to be zero-strided tensors if the size is <= 1 along that dimension (#34100).
- nn.BatchNorm{1,2,3}d: fixed gradient computation for empty inputs (#32820).
- nn.BatchNorm: Fixed behavior for inputs with large batch sizes (#32763).
- nn.Conv2d: Fixed 5d weight handling with MKLDNN backend (#34115).
- nn.Conv3d: Fixed unstable gradient computation (#34358).
- nn.Conv{1,2,3}d: added support for empty batch size (#32709).
- nn.Conv{1,2,3}d: fixed CUDNN_STATUS_NOT_SUPPORTED errors by trying multiple algorithms (#33073).
- nn.Conv{1,2,3}d: fixed padding mode support and added additional padding modes (reflection and replication) (#31784).
- nn.Conv2d, nn.Conv3d, nn.Conv1d, nn.ConvTranspose2d: Fixed support for batch sizes greater than 2^32 (#31383, #31379, #31889, #34407, #31510).
- nn.InstanceNorm, nn.GroupNorm: Added error check for input with exactly one element (#29082).
- nn.RNN: Fixed moving RNNs to a device after applying weight norm (#32563, #32989).
- nn.MultiLabelMarginLoss: added support for 0-d tensors (#30765).
- nn.GroupNorm: added support for empty batch (#32401).
- nn.NLLLoss: fixed to support empty tensors on CUDA (#31491).
- nn.GroupNorm: corrected input size check (#33008)
- nn.MultiLabelMarginLoss: fixed memory leak on CUDA (#30767).
- nn.MultiMarginLoss: fixed error checking on CUDA for the 1D case. (#30825).
- nn.Softmax: Fixed half->float case of softmax backward (#30838).
- nn.Softshrink: Added check that lambda is no less than zero (#33201).
- nn.functional.interpolate: added support for empty batch size input for interpolate. (#32400).
- nn.functional.pad: Also return a new tensor instead of sometimes returning a view (#32350).
- nn.functional.grid_sample: Fixed gradient computation at image borders (#32829).
- nn.functional.leaky_relu_: disabled incorrect leaky_relu_ negative slope backward calculation (#33639).
- optim.LambdaLR: removed unintentional side effects (#32848).
- optim.Adam, optim.AdamW: Added missing weight_decay parameter validation (#33126).
- optim.MultiStepLR: Fix “unbound local variable” error by removing return value for __exit__ (#32997).
- optim.MultiStepLR: Fixed broken step() method (#33356).
- torch.autograd: added new error message if incorrect usage would cause a deadlock (#32295).
- torch.autograd: Prohibited copying autograd engines (#34567).
- torch.autograd: Fixed incorrect handling of functions that return multiple views (#32790).
- autograd.Function: Fixed error if Function returned a view in a torch.no_grad block (#33896).
- autograd.Function: Added more error checks for incorrect behavior (#33069).
- autograd.Function: Added nice error message if missing overrides (#33142).
- autograd.Function: Fixed version check for grad_fn for views (#34145).
- autograd.profiler: Fix incorrect chrome trace formatting output for CUDA traces (#33987).
- multiprocessing.util.register_after_fork: fixed crash on Windows (#30809).
- utils.data.DataLoader: Fixed potential hang when exiting main process (#33721).
- utils.tensorboard.SummaryWriter: fixed scale_factor calculation for uint8 tensor (#31778).
- utils.tensorboard: Fix for when PyTorch model trace has RecursiveScriptModules (#30430).
- Fixed CPU_INTEL flag error on Windows (#30564).
- Don't use RTLD_GLOBAL to load _C, resolving a multitude of weird segfaults and crashes when PyTorch is imported along with other packages (#31162).
- Fixed dll load logic for Python 3.8 on Windows (#32215).
- quasirandom.SobolEngine: Fixed crash when default tensor type is CUDA (#32496).
- Fixed error message when converting NumPy array with negative strides to a torch.Tensor (#33254).
- Fixed crash when indexing a torch.Tensor with a single-element array (#33456).
- Fixed crash when converting CUDA tensors and non-strided tensors to NumPy arrays (#33612).
- Prevented crash on exit from static destructor race on Windows (#33955).
- Fixed uncaught std::domain_error on macOS (#34301).
- Don’t reset worker affinity when using operators that call into OpenMP (#29006).
- torch.backends.mkldnn: changed to be usable without import (#32055).
Performance
Mobile
- Java Tensor hybrid, owns at::Tensor, no memcopy for java outputs. (#30501).
- Tensor prep from image in native (#31426).
- Pass to remove prepacking ops. (#34319).
Quantization
- Per channel quantization performance improvement (#33772).
- Speed up per-channel min-max observer (#34118).
- Vectorized qmul and more methods on qint data types (#34376).
RPC
- Improve ProcessGroupAgent serialization speed (#29785).
- Avoid sending large unneeded data over wire in ProcessGroupAgent. (#31357).
- Integrate async mode for autograd engine with distributed autograd. (#31508).
- Make handling of FORWARD_AUTOGRAD_REQ in request_callback_impl nonblocking (#32476).
Other
- Major multithreaded performance regression when doing operator calls resolved (#30333)
- Improved performance of comparison ops on CUDA (#29743).
- Tensor.view: improved performance (#30554).
- Improved tensor creation overhead (#30452, #30709)
- nn.SmoothL1Loss: vectorized gradient computation on CPU. (#30046).
- nn.EmbeddingBag: improved performance on CPU (#30701, #27477).
- nn.LayerNorm: optimized with explicit vectorization using Vec256 (#31127).
- Tensor.copy_: fixed kernel speed regression introduced in #29631 (#31279).
- Moved a number of debug asserts to not compile in release builds (#31240).
- Tensor::has_names sped up for unnamed tensors (#31436).
- torch.index_select: optimized performance on CPU (#30598).
- nn.Conv{1,2,3}d: Improved performance by refactoring bias handling for cuDNN backend (#31524).
- torch.norm: Optimized case where p = 2 (#31903).
- nn.utils.clip_grad_norm_: Refactored the computation for more performance (#32020).
- Made an assert on a hotpath trigger only in DEBUG mode (#32117).
- First steps toward TensorIterator unrolling and vectorized load (#31974).
- nn.functional.normalize: changed to use clamp_min_ (#32360).
- Stopped refreshing numel on a stride update (#32116).
- nn.functional.softplus: vectorized operator and gradient computation on CPU (#32944).
- torch.gather: regression fixed by not materializing loop vars in error message (#33108).
- nn.ELU: forward and backward vectorized on CPU (#32985, #32986)
- torch.cat: optimized performance on CPU (#30806, #33534).
- torch.conv3d: optimized Unfold3d to improve performance (#33191).
- Workaround performance bug and memory leak in GOMP for AMD CPUs (#32875).
- Improved TensorIterator overhead (#33165).
- torch.conv3d: optimized Unfold3dAcc to improve gradient computation performance (#33317).
- torch.roll: improved performance (#33623).
- Bounds checking for functor execution in vectorized/unrolled kernels (#33642).
- nn.EmbeddingBag: improved performance on CUDA (#33589).
- Remove unnecessary tensor copies while calling operators (#33732).
- clang intrinsics targeting on Windows (#33958).
- nn.Dropout: added vectorized CUDA implementation (#33879).
- nn.UpSampleNearest{1, 2, 3}d: performance on CPU optimized (#31452).
- Remove cudaMemcpy on full memory overlap (#34548).
- CUDA Loops: move address computation into policy, make policy.load load all arguments (#33720).
- nn.BatchNorm{1, 2, 3}d: contiguous case's performance improved (#34530).
- Add the build for runtime dispatch for AVX, AVX2 instruction set (#26125).
- nn.RReLU: performance improved up to 5x for inference on CPU (#31094).
- nn.LogSigmoid: performance improved up to 10x on CPU (#30958).
- torch.dist: performance improved up to 2x (#29714).
- torch.max, torch.min: performance improved up to 1.5x on CPU (#33936).
- nn.GLU: performance improved up to 1.5x on CPU (#33179).
- nn.LeakyReLU: performance improved up to 4x (#29899).
- nn.HardTanh: performance improved up to 5x (#30152).
Documentation
Python
- Added documentation for nn.functional.softplus (#30055, #32945).
- torch.max: Added warning about different, nondeterministic behavior on CPU and CUDA (#31115).
- Clarified the documentation for nn.NLLLoss (#31488).
- Exclude generated source docs from Google search indexing (#31484).
- torch.poisson: docstring added to documentation (#31667).
- torch.eq: fixed incorrect examples in documentation (#32399).
- torch.load: added warning regarding pickle insecurity (#32593).
- optim.CosineAnnealingLR: fixed the usage in examples (#31358).
- Added doc previewing instructions (#31905).
- Removed legacy .data usages from the torch.nn documentation (#31481).
- Fixed description of convolution modules (#30079).
- Tensor.t(), Tensor.permute(), Tensor.unfold(), and Tensor.select() clarified to note that they return views (#32512).
- torch.multiprocessing: Updated documentation indicating that start_method is ignored for mp.spawn() (#33070).
- Improved CPU threading documentation (#33083).
- nn.BCELoss: documented how it avoids infinite results (#33160).
- nn.utils.rnn.pack_padded_sequence: Improved the description of enforce_sorted (#33617).
- nn.utils.pad_packed_sequence: doc improvement (#33768).
- nn.LPPool{1,2}d: removed nonexistent parameter (#33714).
- Created a Tensor View documentation page that documents all PyTorch operations that return views (#32560).
- Added grad context manager doc to top level torch module. (#33877).
- Enhanced reproducibility documentation (#33795).
- Numerous typo fixes (#30448, #30518, #30614, #30464, #30608, #24335, #34581, #34624, #34008, #31395, #31677, #31617, #31973, #32068, #33689, #30385, #32003, #31682, #30846, #33478, #33549, #32307, #33144, #33805, #33836, #34053).
- Numerous formatting and/or rendering fixes (#30377, #30779, #32667, #34027, #32911, #30814, #30815, #31760, #34503).
C++ API
- Fix at::Tensor docs generation and make it accessible again at https://pytorch.org/cppdocs/api/classat_1_1_tensor.html (#34467)
- Add docs for all torch::nn modules and functionals (#34522) (#34688) (#34752)
- Improve C++ autograd and tensor indexing docs (#35919)
- Fix example in torch::nn::ModuleList docs (#34463)
RPC
- Reorganize RPC API doc and add introduction (#30491, #35109).
- Make doc source format consistent in rpc/init.cpp (#30515).
- Add examples to RRef doc (#30516).
- Add more details to explain the rpc_backend_options arg in init_rpc (#30855).
- Fix examples in API doc (#30856).
- Fix examples in RRef API doc (#30857).
- Document WorkerInfo and RpcBackendOptions structures in RPC docs. (#31077).
- Explain RPC behavior when using Tensor as arg or return value (#31968).
- Update RPC docs to reflect correct use of dist_autograd backwards and dist_optim step() (#34670).
- Minor doc tweak to use mp.spawn in example (#30381).
- Update distributed autograd note (#34657).
Mobile
- Add info about transitive dependencies in case of using local aars (#30128).
- Update Docs for building PyTorch for Android. (#32578).
- Javadoc changes (#31956).
Quantization
- Updates to quantization documentation (#30288).
- Fix docs so that the example works (#30120).
- Add the explicit per-tensor/per-channel quant info when we print the module (#30591).
- Fixed typos in quantization docs / docstrings (#34182).
- Docs entry for the is_quantized (#32075).
Deprecations
Python
How to figure out which line in your code is raising a warning
Attempting to use deprecated behavior will raise warnings. Unfortunately, sometimes it is not entirely obvious what line of code the warning corresponds to, especially if the warning comes from our C++ backend. For example, given a file named foo.py with the following contents:
import torch
# This is newly deprecated behavior, see the next section
torch.tensor(1) / torch.tensor(2)
running it doesn’t give us the location of the warning:
> python foo.py
../aten/src/ATen/native/BinaryOps.cpp:81: UserWarning: Integer division of tensors using div or / is deprecated, and in a future release div will perform true
division as in Python 3. Use true_divide or floor_divide (// in Python) instead.
We can use the warnings module to tell us where the warning is by asking it to treat warnings as errors:
import torch
import warnings
warnings.filterwarnings('error', message='Integer division')
# This is newly deprecated behavior, see the next section
torch.tensor(1) / torch.tensor(2)
Running the file now tells us exactly where the warning is:
> python foo.py
Traceback (most recent call last):
File "foo.py", line 5, in <module>
torch.tensor(1) / torch.tensor(2)
UserWarning: Integer division of tensors using div or / is deprecated, and in a future release div will perform true division as in Python 3. Use true_divide
or floor_divide (// in Python) instead.
Deprecated torch.div and torch.addcdiv integer floor division behavior (#34570)
In 1.5.0 and older PyTorch releases, torch.div and the / operator perform integer floor division. In a future PyTorch release, torch.div (including the / operator) will perform "true" division, as in Python 3 and NumPy. To floor divide integer tensors, please use torch.floor_divide instead.
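As a quick illustration (a minimal sketch; the tensor values are chosen only for demonstration), the two replacement functions let you spell out which behavior you want:

import torch

a, b = torch.tensor(3), torch.tensor(2)

torch.floor_divide(a, b)  # tensor(1)   -- today's integer behavior of a / b
torch.true_divide(a, b)   # tensor(1.5) -- what a / b will mean in a future release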
Before | After |
---|---|
>>> torch.tensor(3) / torch.tensor(2)
../aten/src/ATen/native/BinaryOps.cpp:81: UserWarning: Integer division of tensors using div or / is deprecated, and in a future release div will perform true division as in Python 3. Use true_divide or floor_divide (// in Python) instead.
tensor(1)
|
>>> # NB: the following is equivalent to torch.floor_divide(torch.tensor(3), torch.tensor(2))
>>> torch.tensor(3) // torch.tensor(2)
tensor(1)
|
The fix for torch.addcdiv is similar.
Before | After |
---|---|
>>> input = torch.tensor(0)
>>> tensor = torch.tensor(1)
>>> other = torch.tensor(3)
>>> value = 1
>>> torch.addcdiv(input, tensor, other, value=value)
../aten/src/ATen/native/PointwiseOps.cpp:81: UserWarning: Integer division with addcdiv is deprecated, and in a future release addcdiv will perform a true division of tensor1 and tensor2. The current addcdiv behavior can be replicated using floor_divide for integral inputs (self + value * tensor1 // tensor2) and division for float inputs (self + value * tensor1 / tensor2). The new addcdiv behavior can be implemented with true_divide (self + value * torch.true_divide(tensor1, tensor2).
tensor(0)
|
>>> input = torch.tensor(0)
>>> tensor = torch.tensor(1)
>>> other = torch.tensor(3)
>>> value = 1
>>> (input + torch.floor_divide(value * tensor, other))
tensor(0)
|
Deprecated torch.full returning float tensors if no dtype is specified (#34709).
In a future PyTorch release, torch.full will infer its dtype from its fill value when the optional dtype and out parameters are unspecified, matching NumPy's inference for numpy.full. For example, torch.full(size, 1) will return a tensor of torch.long dtype, unlike today where it returns a tensor of torch.float dtype.
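A minimal sketch (the size value below is illustrative): passing dtype explicitly avoids the deprecation warning and keeps the result the same across releases.

import torch

size = (2, 3)

torch.full(size, 1)                      # 1.5.x: float tensor plus a deprecation warning;
                                         # future releases: torch.long inferred from the fill value
torch.full(size, 1, dtype=torch.long)    # unambiguous integer tensor in every release
torch.full(size, 1, dtype=torch.float)   # explicitly keeps today's default behavior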
Deprecated torch.nn.modules.conv._ConvTransposeMixin (#31784).
This is an internal-facing class that is not a part of our public API. We’ve refactored some PyTorch internals to work without it and will remove it in a future release.
Deprecated positional args in multiple torch function signatures (#32009, #33428)
Below please find a list of deprecated signatures and what to change them to.
- torch.add(self: Tensor, alpha: Scalar, other: Tensor), torch.sub(self: Tensor, alpha: Scalar, other: Tensor): please use alpha as a keyword-only arg instead of a positional arg.
- torch.addbmm(beta: Scalar, self: Tensor, alpha: Scalar, batch1: Tensor, batch2: Tensor): please use alpha and beta as keyword-only args instead of positional args.
- torch.addcdiv(self: Tensor, value: Scalar, tensor1: Tensor, tensor2: Tensor), torch.addcmul(self: Tensor, value: Scalar, tensor1: Tensor, tensor2: Tensor): please use value as a keyword-only arg.
- torch.addmm(beta: Scalar, self: Tensor, alpha: Scalar, mat1: Tensor, mat2: Tensor), torch.sspaddmm(beta: Scalar, self: Tensor, alpha: Scalar, mat1: Tensor, mat2: Tensor): please use alpha and beta as keyword-only args instead of positional args.
- torch.addmv(beta: Scalar, self: Tensor, alpha: Scalar, mat: Tensor, vec: Tensor): please use alpha and beta as keyword-only args instead of positional args.
- torch.addr(beta: Scalar, self: Tensor, alpha: Scalar, vec1: Tensor, vec2: Scalar): please use alpha and beta as keyword-only args instead of positional args.
- torch.baddbmm(beta: Scalar, self: Tensor, alpha: Scalar, batch1: Tensor, batch2: Tensor): please use alpha and beta as keyword-only args instead of positional args.
Before | After |
---|---|
>>> torch.zeros(2,3).add(2, torch.ones(2, 3))
../torch/csrc/utils/python_arg_parser.cpp:750: UserWarning: This overload of add is deprecated:
add(Number alpha, Tensor other)
Consider using one of the following signatures instead:
add(Tensor other, Number alpha)
tensor([[2., 2., 2.],
[2., 2., 2.]])
|
>>> torch.zeros(2, 3).add(torch.ones(2, 3), alpha=2)
tensor([[2., 2., 2.],
[2., 2., 2.]])
|
Deprecate modifying in-place a view returned by a custom autograd Function (#32839).
Modifying in-place a view that was created by a custom Function leads to the custom backward not being called or being called with a partial gradient. This behavior will be removed in 1.6.
Please clone() the output of the Function to avoid incorrect gradient computation.
from torch.autograd import Function

class Id(Function):
    @staticmethod
    def forward(ctx, input):
        return input.view_as(input)

    @staticmethod
    def backward(ctx, grad_input):
        return grad_input
Deprecated (version 1.5.0) | Recommended (version 1.5.0) |
---|---|
>>> input = torch.randn(3, requires_grad=True)
>>> other = torch.randn(3)
>>> output = Id.apply(input)
>>> output.copy_(other)
# Warning: Incorrect gradients
|
>>> input = torch.randn(3, requires_grad=True)
>>> other = torch.randn(3)
>>> output = Id.apply(input).clone()
>>> output.copy_(other)
|
Deprecate modifying in-place a view created inside a no_grad block (#32839)
Modifying in-place a view created inside a no_grad block is ambiguous and error-prone so we have deprecated it.
Here is an example of some code that we’ve deprecated. In previous versions of PyTorch, the following code threw a non-descriptive error message, but we've added a deprecation in 1.5.0.
>>> base = torch.rand(10, requires_grad=True)
>>> var = torch.rand([], requires_grad=True)
>>> with torch.no_grad():
>>> view = base[1]
>>> view.copy_(var)
>>> torch.autograd.grad(base.sum(), var)
RuntimeError: A view was created in no_grad mode and is being modified inplace with grad mode enabled. Given that this use case is ambiguous and error-prone,
it is deprecated and will be forbidden starting 1.6 (see https://github.com/pytorch/pytorch/pull/32839 for more details about this). You can clarify your code and remove this warning by moving both the view and the inplace either both inside the no_grad block (if you don't want the inplace to be tracked) or both outside (if you want the inplace to be tracked).
If you want to differentiate, you should change the above code to
>>> base = torch.rand(10, requires_grad=True)
>>> var = torch.rand([], requires_grad=True)
>>> view = base[1]
>>> view.copy_(var)
>>> torch.autograd.grad(base.sum(), var)
(tensor(1.),)
If you don’t want to differentiate, you should change it to
>>> base = torch.rand(10, requires_grad=True)
>>> var = torch.rand([], requires_grad=True)
>>> with torch.no_grad():
>>> view = base[1]
>>> view.copy_(var)
C++ API
Deprecated Tensor.type() (#30281)
Please use Tensor.options() instead.
Miscellaneous
nairbv
released this
PyTorch 1.4.0 Release Notes
- Highlights
- Backwards Incompatible Changes
- Python
- JIT
- C++
- New Features
- torch.optim
- Distributed
- RPC [Experimental]
- JIT
- Mobile
- Improvements
- Distributed
- JIT
- Mobile
- Named Tensors
- C++ API
- AMD Support
- ONNX
- Quantization
- Visualization
- Other Improvements
- Bug Fixes
- Distributed
- RPC
- C++ API
- JIT
- Quantization
- Mobile
- Other Bug fixes
- Deprecations
- Performance
The PyTorch v1.4.0 release is now available.
The release contains over 1,500 commits and a significant amount of effort in existing areas such as JIT, ONNX, Distributed, performance, and the eager frontend, as well as improvements to experimental areas like mobile and quantization. It also contains new experimental features, including RPC-based model parallel distributed training and language bindings for the Java language (inference only).
PyTorch 1.4 is the last release that supports Python 2. For the C++ API, it is the last release that supports C++11: you should start migrating to Python 3 and building with C++14 to make the future transition from 1.4 to 1.5 easier.
Highlights
PyTorch Mobile - Build level customization
Following the experimental release of PyTorch Mobile in the 1.3 release, PyTorch 1.4 adds additional mobile support including the ability to customize build scripts at a fine-grain level. This allows mobile developers to optimize library size by only including the operators used by their models and, in the process, reduce their on device footprint significantly. Initial results show that, for example, a customized MobileNetV2 is 40% to 50% smaller than the prebuilt PyTorch mobile library. Learn more about how to create your own custom builds, and please engage with the community on the PyTorch forums to provide any feedback you have.
Distributed Model Parallel Training [Experimental]
With the scale of models, such as RoBERTa, continuing to increase into the billions of parameters, model parallel training has become ever more important to help researchers push the limits. This release provides a distributed RPC framework to support distributed model parallel training. It allows for running functions remotely and referencing remote objects without copying the real data around, and provides autograd and optimizer APIs to transparently run backwards and update parameters across RPC boundaries.
To learn more about the APIs and the design of this feature, see the links below:
For the full tutorials, see the links below:
- A full RPC tutorial
- Examples using model parallel training for reinforcement learning and with an LSTM
As always, you can connect with community members and discuss more on the forums.
Java bindings [Experimental]
In addition to supporting Python and C++, this release adds experimental support for Java bindings. Based on the interface developed for Android in PyTorch Mobile, the new bindings allow you to invoke TorchScript models from any Java program. Note that the Java bindings are only available for Linux for this release, and for inference only. We expect support to expand in subsequent releases. See the code snippet below for how to use PyTorch within Java:
Learn more about how to use PyTorch from Java here, and see the full Javadocs API documentation here.
Pruning
Pruning functionalities have been added to PyTorch in the nn.utils.prune
module. This provides out-of-the-box support for common magnitude-based
and random pruning techniques, both structured and unstructured, both
layer-wise and global, and it also enables custom pruning from
user-provided masks.
To prune a tensor, first select a pruning technique among those available in nn.utils.prune (or implement your own by subclassing BasePruningMethod).
import torch
from torch.nn.utils import prune

t = torch.rand(2, 5)                    # tensor to prune
p = prune.L1Unstructured(amount=0.7)    # prune 70% of entries by L1 magnitude
pruned_tensor = p.prune(t)              # returns a pruned copy of t
To prune a module, select one of the pruning functions available in nn.utils.prune (or implement your own) and specify which module and which parameter within that module pruning should act on.
from torch import nn

m = nn.Conv2d(3, 1, 2)
# prune 2 of the 3 input-channel slices of m.weight, ranked by their L2 (n=2) norm
prune.ln_structured(module=m, name='weight', amount=2, n=2, dim=1)
Pruning reparametrizes the module by turning weight (in the example above) from a parameter to an attribute, and replacing it with a new parameter called weight_orig (i.e. appending "_orig" to the initial parameter name) that stores the unpruned version of the tensor. The pruning mask is stored as a buffer named weight_mask (i.e. appending "_mask" to the initial parameter name). Pruning is applied prior to each forward pass by recomputing weight through a multiplication with the updated mask using PyTorch's forward_pre_hooks.
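Continuing the Conv2d example above, a quick way to see this reparametrization (attribute names follow the "_orig"/"_mask" convention just described):

print([name for name, _ in m.named_parameters()])  # includes 'weight_orig' alongside 'bias'
print([name for name, _ in m.named_buffers()])     # includes 'weight_mask'
print(m.weight)  # the pruned weight, recomputed as weight_orig * weight_mask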
Iterative pruning is seamlessly enabled by repeatedly calling pruning functions on the same parameter (this automatically handles the combination of successive masks by making use of a PruningContainer under the hood).
nn.utils.prune is easily extensible to support new pruning functions by subclassing the BasePruningMethod base class and implementing the compute_mask method with the instructions to compute the mask according to the logic of the new pruning technique.
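As a rough sketch of that extension point (the class name ThresholdPruning and its pruning rule are made up for illustration), a custom method only needs to define PRUNING_TYPE and compute_mask:

import torch
from torch.nn.utils import prune

class ThresholdPruning(prune.BasePruningMethod):
    """Hypothetical custom method: prunes every entry whose magnitude is below a threshold."""
    PRUNING_TYPE = "unstructured"

    def __init__(self, threshold):
        self.threshold = threshold

    def compute_mask(self, t, default_mask):
        # Combine with the default mask so iterative pruning keeps earlier masks.
        return default_mask * (t.abs() > self.threshold).to(dtype=default_mask.dtype)

# Applied to a module parameter via the generic helper, e.g.:
# m = torch.nn.Linear(4, 4)
# ThresholdPruning.apply(m, name="weight", threshold=0.5)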
Backwards Incompatible Changes
Python
torch.optim: It is no longer supported to use Scheduler.get_lr() to obtain the last computed learning rate. To get the last computed learning rate, call Scheduler.get_last_lr() instead. (26423)
Learning rate schedulers are now “chainable,” as mentioned in the New Features section below. Scheduler.get_lr was sometimes used for monitoring purposes to obtain the current learning rate. But since Scheduler.get_lr is also used internally for computing new learning rates, this actually returns a value that is “one step ahead.” To get the last computed learning rate, use Scheduler.get_last_lr instead.
Note that optimizer.param_groups[0]['lr'] was in version 1.3.1 and remains in 1.4.0 a way of getting the current learning rate used in the optimizer.
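As a minimal sketch of the replacement call (the optimizer, parameter, and scheduler settings below are purely illustrative):

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = SGD(params, lr=0.1)
scheduler = StepLR(optimizer, step_size=1, gamma=0.5)

optimizer.step()
scheduler.step()
print(scheduler.get_last_lr())          # last computed lr, e.g. [0.05]
print(optimizer.param_groups[0]['lr'])  # still works for the current lr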
Tensor.unfold on a 0-dimensional Tensor now properly returns a 1-dimensional Tensor.
Version 1.3.1 | Version 1.4.0 |
---|---|
>>> torch.tensor(5).unfold(dimension=0, size=1, step=1)
tensor(5)
|
>>> torch.tensor(5).unfold(dimension=0, size=1, step=1)
tensor([5])
|
torch.symeig now returns a 0-element eigenvectors tensor when eigenvectors=False (the default).
Version 1.3.1 | Version 1.4.0 |
---|---|
>>> torch.symeig(torch.randn(3,3)).eigenvectors.shape
torch.Size([3, 3])
|
>>> torch.symeig(torch.randn(3,3)).eigenvectors.shape
torch.Size([0])
|
JIT
- Make
torch.jit.get_trace_graph
private (it is nowtorch.jit._get_trace_graph
) (29149)- This function was intended only for ONNX integration; use
traced_module.graph
instead, like: - traced_module = torch.jit.trace(my_module, example_inputs)
traced_graph = traced_module.graph
- This function was intended only for ONNX integration; use
@property
onScriptModule
s has been disabled (28395)- Scripted
@property
accesses were silently broken before, where we would evaluate the theget
function once and store that as the attribute permanently. They properly error now; a workaround is to make your@property
a regular method.
- Scripted
- Custom ops:
torch::jit::RegisterOperators
has been removed, usetorch::RegisterOperators
instead (28229). The usage and behavior should remain the same. - Remove
torch.jit._register_*
bindings from Python (e.g.torch.jit._register_attribute
). These were private functions that were not intended to be used. (29499)
C++
[C++] The distinction between Tensor and Variable has been eliminated at the C++ level. (28287)
This change simplifies our C++ API and matches previous changes we did at the python level that merged Tensors and Variables into a single type.
This change is unlikely to affect user code; the most likely exceptions are:
-
Argument-dependent lookup for
torch::autograd
may no longer work. This can break because Variable is now defined as an alias for Tensor (using Variable = Tensor;
). In this case, you must explicitly qualify the calls totorch::autograd
functions. -
Because
Variable
andTensor
are now the same type, code which assumes that they are different types (e.g., for the purposes of templating, orstd::enable_if
checks) will not work until you delete the (now) redundant overload/specialization. -
Some operators may trace differently. If this happens, please file a bug. The most likely situations are:
- There are now more operations in your trace than before (usually, calls to
aten::empty
) - There are now less operations in your trace than before (e.g., the trace complains that
"there is no observable dependence"
with the inputs)
[C++] arguments in torch::nn::LinearOptions
are renamed to match the Python API. (27382)
- Arguments that are renamed:
in
->in_features
out
->out_features
with_bias
->bias
[C++] arguments in torch::nn::Conv{1,2,3}dOptions
are renamed to match the Python API. (28917) (29838)
- Arguments that are renamed:
input_channels
->in_channels
output_channels
->out_channels
with_bias
->bias
[C++] torch::nn::Conv{1,2,3}dOptions
no longer has the transposed
argument. (31005)
- If users have
transposed
originally set totrue
intorch::nn::Conv{1,2,3}dOptions
, they should migrate their code to usetorch::nn::ConvTranspose{1,2,3}d
layers instead.
[C++] All Reduction enums for torch::nn
layers and functionals are changed to have torch::KEnumNAME
syntax. (27942, 26837)
- Example: previously, to specify “mean” as the reduction method in a torch::nn layer or functional, we would use
torch::Reduction::Mean
. Now,torch::Reduction::Mean
has been renamed to the shortertorch::kMean
.
[C++] torch::tensor
constructor is improved to match Python API behavior. (28523) (29632) (29066)
- Shape checking fixes
- Example 1: previously,
torch::tensor({{1}, {2}})
produced a tensor of sizes{2}
. Now, it produces a tensor of sizes{2, 1}
. - Example 2: previously,
torch::tensor(1.1)
produced a 1-dim tensor. Now it produces a 0-dim tensor.
- Example 1: previously,
- Type inference improvements
- Example 1: previously, C++
torch::tensor
with a double (e.g.torch::tensor(1.1)
) or a (nested) braced-init-list of doubles (e.g.torch::tensor({{1.1, 2.2}})
produces a tensor with dtypetorch::kDouble
. Now it produces a tensor with dtypetorch::get_default_dtype()
. - Example 2: previously, C++
torch::tensor
with an integer type (e.g.torch::tensor(1)
) or a (nested) braced-init-list of integer types (e.g.torch::tensor({{1, 2}})
) produces a tensor with the same dtype. Now it always produces a tensor of dtypetorch::kLong
(aka.int64_t
). - Example 3: previously, when passed a
TensorOptions
without a dtype set to thetorch::tensor
constructor, it always produces a tensor of dtypetorch::get_default_dtype()
. Now it produces a tensor of different dtypes based on the dtype of the braced-init-list and the default dtype.
- Example 1: previously, C++
- Passing a
std::initializer_list
(NOT braced-init-list) totorch::tensor
will no longer compile, and the user should pass the equivalent braced-init-list totorch::tensor
instead. For example, writetorch::tensor({1.1, 1.2})
instead oftorch::tensor(std::initializer_list<double>({1.1, 1.2}))
.
[C++] Some activation modules’ forward
function now take Tensor
instead of Tensor&
as input. (28501)
torch::nn
layers affected: ELU
/ SELU
/ Hardtanh
/ LeakyReLU
/ ReLU
/ ReLU6
/ RReLU
/ CELU
This change ensures that the above layers can be used in a torch::nn::Sequential
module. If your C++ model uses any of the above layers, you must recompile your C++ code with the new libtorch binary.
New Features
torch.optim
Learning rate schedulers (torch.optim.lr_scheduler
) now
support “chaining.” This means that two schedulers can be defined and
stepped one after the other to compound their effect, see example below.
Previously, the schedulers would overwrite each other.
>>> import torch
>>> from torch.optim import SGD
>>> from torch.optim.lr_scheduler import ExponentialLR, StepLR
>>>
>>> model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
>>> optimizer = SGD(model, 0.1)
>>>
>>> scheduler1 = ExponentialLR(optimizer, gamma=0.9)
>>> scheduler2 = StepLR(optimizer, step_size=3, gamma=0.1)
>>>
>>> for epoch in range(5):
>>> print(epoch, scheduler2.get_last_lr()[0])
>>>
>>> optimizer.step()
>>> scheduler1.step()
>>> scheduler2.step()
0 0.1
1 0.09000000000000001
2 0.08100000000000002
3 0.00729000000000002
4 0.00656100000000002
Distributed
- Add
allgather_coalesced
API toProcessGroup
(28634,29059) - Add
abort
API inProcessGroupGloo
Send/Recv Work (29928). - Add
--no_python
flag to allow using a bash script wrapper in the launch command (29144).
RPC [Experimental]
torch.distributed.rpc
is a newly introduced package. It
contains basic building blocks to run functions remotely in model
training and inference, which will be useful for scenarios like
distributed model parallel or implementing parameter server frameworks.
More specifically, it contains four pillars: RPC, Remote Reference,
Distributed Autograd, and Distributed Optimizer. Please refer to the documentation and the tutorial for more details.
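As a rough, minimal sketch of how these pieces fit together (the worker names, port, and tensor values are made up for illustration; the individual APIs are listed below):

import os
import torch
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp

def run(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)
    if rank == 0:
        # Synchronously run torch.add on worker1 and get the result back.
        result = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))
        # Create the value remotely and only hold a reference (RRef) to it.
        rref = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
        print(result, rref.to_here())
    rpc.shutdown()

if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2)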
- Add
rpc_sync
andrpc_async
for builtin operators and Python user functions (23228, 23569, 28392). - Add
remote
andRRef
for builtin operators and Python user functions (25169, 25499). - Distributed Autograd - FAST mode backward pass implementation. (27022, 27576).
- Integrate
remote
andRRef
with distributed autograd (28630, 28656). - Add a distributed optimizer (29304, 30062).
- Add python API for
get_gradients()
method to retrieve gradients from distributed autograd context. (28926). - Support creating local
RRef
s on local values and to-selfremote
calls (28948, 29634). - Support custom pickler for RPC (30185).
- Add default RPC agent options based on the backend type (30201).
- Add local
shutdown
toProcessGroup
agent (30330).
JIT
script::Module
: implement more of of the nn.Module API (28828)- In particular, adds the (optionally recursive) methods that iterate over submodules, parameters, etc.
- Adds a pybind-like
attr()
method to simplify attribute access.
- Add support for
@staticmethod
onScriptModule
s (27163) - Support Module Containers as Iterables (26465)
- Support Iterables In List Comprehensions (26768)
- Dictionaries now preserve insertion order, and
OrderedDict
is supported (26465) - Add support for
hasattr()
(29332) - TorchScript classes can now be callable (26743)
- Add
clone_instance
forScriptModule
s (30168) - Add
torch.memory_format
support to the TorchScript (28544) - Custom
forward()
is now allowed on container modules (28988) - Calls to submodules are now preserved in the traced graph (29261)
- Add support for module containers to be used as iterables (28255)
- Make JIT Serialization support arbitrary std::function<> IO (28039)
- Support
layout()
in script (27100) - Methods and functions are no longer inlined in the serialized file format (26706)
Mobile
- Build level customization
Improvements
Distributed
Improvements
- Add timeout support in
ProcessGroupNCCL
(27224). - Ensure that DDP wrapped module has parameters that require gradients (25858).
- Making
torch/csrc/cuda
NCCL usage safe for NCCL 2.5 (29014). - Enable
test_distributed
for ROCm but only with NCCL backend (28814).
RPC Improvements
- Separate out RPC to
rpc_sync
andrpc_async
APIs (26570). - Make python user function serialization format to be consistent with builtin operators (27136).
- Clean up distributed autograd context on all participants on exit (27951).
- Improve error handling for distributed autograd engine. (27940).
- Scope pybind11 functions to
torch.distributed.{autograd,rpc}
(27529). - Lift
rpc_timeout
toRpcAgent
to make it reusable for otherRpcAgent
implementations. (29341). - Support sending message to self in
process_group_agent
(29253). - Properly shutdown RPC even in the case of
clean_shutdown=False
. (29148). - Ensure
initializedContextIds_
map is cleaned up appropriately in distributed autograd engine. (29787). - Add hash and equality operators for
WorkerInfo
(29958). - Add
RpcAgentOptions
struct type to bundle arguments for differentRpcAgent
s (29972). - Mark timeout
FutureMessage
s and throw exceptions inProcessGroupAgent
(29601). - Re-throw python remote exception when using remote reference to itself (29930).
- By default ignore
RRef
leaks during shutdown (30217).
Documentation
- Add Design doc for Distributed Autograd Engine (29175, 30068, 29927)
- Add Design doc for Remote Reference (30066).
- Add documentation page for
torch.distrbuted.rpc
(29276, 28030, 29971, 30160, 30050, 30069, 30179, 30218, 30240, 30243, 30259).
MISC
- Add known worker IDs to distributed autograd context (26324).
- Minor tweaks to RPC message API (28326).
- Rename
PythonUDF{Call,Resp}
(27530). - Use
std::shared_ptr
forDistAutogradContext
(29770). - Mark
c10d::~NCCLUtils
as noexcept (29118).
JIT
- Move custom passes to last optimization step (29256)
- Represent the original Python name of a module type the same way in traced and scripted modules. (29912)
- Only print original SourceRange on highlight (29708)
- Error message and ergonomic improvements:
- Show full call stack in TorchScript exception even when calls were inlined. (29911)
- Reduce error context from 10 -> 3 (26765)
- Fix error report highlight for unmatched type annotation (27195)
- Make default string arguments in schemas human readable (27088)
- Print which output didn't have dependence during trace checking. (29047)
- Improvements to save/load and serialization performance:
- Modules can now share JIT types if their implementation is the same, improving save/load performance (26666)
- Improve float pickling speed. (28553)
- Pickler: convert
std::stringstream
cases for improved performance. (29351) - Buffer to speed Unpickler (27727)
- Buffer in Pickler to improve performance. (27720)
- In
torch::save()
avoid zip compressing small header records. (28180) - String optimizations related to serialization. (28230)
- Clean up serialized source format (28129)
- API for finding a common ancestor block for a pair of nodes (28864)
- Make inserted child module names unique (27237)
- Better hashing for constant pool (27733)
- Improve error messages when a method or attribute is missing (27110)
- Display original source range in
Node::print
(27524) - Always use the closure to resolve variable names (27515)
Mobile
- Improve Java API / JNI
- Add module method to allow explicitly destructing native part (27090).
- Add methods to write image tensor content to buffer (27359).
- Various improvements to Android API (27454, 27455).
- Add support for PyTorch JNI build (29412, 42faf961c8, d22f61432d).
- Various fixes to PyTorch JNI (29350, 29861, 30206, 30207).
- Improve support for older Android NDK
- Improve error message, documentation, debuggability
- Improve support for benchmark and profiling
- Improve build / CI
- Improve Android Gradle build and publishing (26833, 27389, 29262, 29738).
- Misc fixes to the Android test project (27453).
- Improve XCode build script (27358, 28996, 29002).
- Add testing code to iOS CI jobs (27593, 27594, 27784, 30133).
- Misc fixes to the iOS TestApp (27591, 28356, 28809, 29247, 29962, 29963).
- Add support for host build to pytorch_android (27662,27664).
- Add host build Gradle publishing (29749).
- Add mobile build CI with host toolchain (30292).
Named Tensors
torch.addcdiv
,torch.addcmul
Added named tensor support (28975).torch.{ones,zeros,full,rand,randn}_like
Added named tensor support (28981).torch.cdist
Added named tensor support (29129).torch.equal
Added named tensor support (29322).- Added named tensor support for comparison ops (27162).
Tensor.align_to
Fixed error message (27221).Tensor.align_to
Make method-only. (27304).Tensor.align_to
Accept partially named tensors (27308).torch.mean(Tensor, Dimname)
Fixed autograd support (29199).Tensor.unflatten
Fix when dim is a negative integer (#31208) (31432).- Fix type errors in examples about Named Tensor (27828).
C++ API
New torch::nn modules
- Convolution layers
- Pooling layers
- Loss layers
- torch::nn::HingeEmbeddingLoss / CosineEmbeddingLoss /MultiMarginLoss (27101) (27345) (27424) (27770).
- torch::nn::TripletMarginLoss / SoftMarginloss / MultiLabelMargin / MarginRankingLoss / MultiLabelSoftMarginLoss (27713, 27956) (27660) (27659) (29000) (27669).
- torch::nn::MSELoss / KLDivLoss / BCELoss / SmoothL1Loss / PoissonNLLLoss / BCEWithLogitsLoss (27156) (28806) (30146) (27661) (28755) (28783).
- torch::nn::NLLLoss / CrossEntropyLoss / CTCLoss (29812) (28654).
- Normalization Layers
- Activation Layers
- torch::nn::ELU / LeakyReLU / SELU / PReLU / ReLU / ReLU6 / RRelu / CELU / GLU (27028) (27059) (27434) (27429) (27435) (27436) (27437) (27487) (29922).
- torch::nn::Sigmoid / LogSigmoid / LogSoftmax / Softmax / Softmax2d / Softplus / Softmin / Softsign / Softshrink / Hardshrink / Hardtanh / Tanh / Threshold (27488) (27060) (27462) (27446) (27509) (27489) (27459) (27535) (27534) (27035) (27537) (27038) (27536) (27538).
- Dropout Layers
- Padding Layers
- Embedding layers
- torch::nn::Embedding / EmbeddingBag (26358).
- Linear layers
- Vision layers
New torch::nn::functional functions
- Convolution functions
- Pooling functions
- Loss functions
- torch::nn::functional::hinge_embedding_loss / multi_margin_loss / multilabel_soft_margin_loss / triplet_margin_loss / soft_margin_loss / margin_ranking_loss (27101) (27424) (27669) (27713) (27660) (29000).
- torch::nn::functional::poisson_nll_loss / nll_loss / cross_entropy / binary_cross_entropy_with_logits (28755) (29812) (28783).
- torch::nn::functional::l1_loss / kl_div / mse_loss / binary_cross_entropy / smooth_l1_loss / ctc_loss (27156) (28806) (30146) (27661) (28654).
- Normalization functions
- Activation functions
- torch::nn::functional::elu / leaky_relu / selu / prelu / relu / relu6 / rrelu / celu / glu / gelu (27028) (27059) (27434) (27429) (27435) (27436) (27437) (27487) (29922) (28433).
- torch::nn::functional:: log_sigmoid/ log_softmax / softmax / softplus / softmin / softsign / softshrink / hardshrink / tanhshrink / hardtanh / gumbel_softmax / threshold (27060) (27462) (27446) (27489) (27459) (27535) (27534) (27035) (27537) (27038) (28121) (27538).
- Embedding functions
- Linear functions
- Padding functions
- Vision functions
- Distance functions
- torch::nn::functional::pdist (27122).
- Utility functions
AMD Support
- New features integration
- Build/CI
ONNX
In PyTorch 1.4, we have mainly focused on expanding the coverage for ONNX Opset 11, and enabling exporting torchvision models. Most of the torchvision models can be exported to ONNX (Opset 11, with fixed input size), including FasterRCNN, MaskRCNN, and KeypointRCNN. We have also enhanced export support for some tensor indexing scenarios, with more enhancements to come in the next release. In addition, 20+ new PyTorch operators are enabled in ONNX exporter.
Expanding Coverage for ONNX Opset 11
torch.sort/torch.topk
are supported in Opset 11 (25739)torch.size/torch.squeeze/torch.unsqueeze/torch.mm/torch.index_fill/torch.index_copy
are supported in Opset 11 (27578)torch.masked_select/torch.masked_scatter
are supported in Opset 11 (25949)torch.arange
is supported in Opset 11 (26875)avg_pool, constant_pad_nd, reflection_pad, replication_pad
Support enhanced in Opset 11 (28225)torch.hardtanh
is supported in Opset 11 (30169)- Enable ONNX constant folding for opset 11 (29011)
Exporting More Torch Operators/Models to ONNX
torch.remainder
is enabled in exporter (24410)torch.unfold
is enabled in exporter (24970)torch.slice/torch.select
with negative index are enabled in exporter (25273, 26549)torch.ones/torch.ones_like/torch.zeros/torch.zeros_like/torch.full/torch.full_like
with default dtype are enabled in exporter (27577)torch.unbind
is enabled in exporter (27247)torch.nn.functional.interpolate
export is enhanced (27179, 27566, 28560, 29489)torch.det
is enabled in exporter (26958)torch.group_norm
is enabled in exporter (27071)torch.meshgrid
is enabled in exporter (26037)torch.randn/torch.randn_like
are enabled in exporter (28470, 29354)torch.weight_norm
enabled in exporter (28618)torch.scalar_tensor
is enabled in exporter (28713)torch.logdet
is enabled in exporter (29767)torch.batch_norm
2D with affine=False is enabled in exporter (29458)torch.bitshift
is enabled in exporter (28210)
Enhancing Export/Test Infra
- Use deepcopy inputs in ONNX ORT test cases (27186)
- Return NotImplemented from all binary math ops (27423).
- Disabling ONNX IR v4 sematics for opset 8 or lower (28990)
- Add ONNX tests for torchvision models (30121)
- Keep output type information while exporting ONNX graph (25906)
Quantization
Quantization updates correspond to a mix of bug-fixes and feature improvements, with feature improvements adding improved operator coverage and performance improvements. We have also made a lot of progress towards enabling graph mode quantization support.
- Feature improvements:
- Enabling intra-op parallelism (26692).
- Enabling inplace relu (28710).
- Quantized Tensor support copy (28612).
- Add quantized torch mean implementation (27675).
- Add quantized avg_pool2d for pytorch mobile (27631).
- Add nn.quantized.Conv3d (29813).
- Adding inplace quantized relu6 (29245).
- Fast histogram observer (29790).
- PackedSequence support for quantized LSTM (29585).
- Improve legacy QuantizedLinear functions to reduce overhead (29773).
- Add support for quantized operator conversion from PT to C2 via ONNX (29694).
- enable per channel dynamic quantization (30122).
- Scripting support:
Visualization
- Fixed graph visualization: displaying proper names after recent JIT changes (30244)
- Support logging embedding for TensorBoard visualizations to generic filesystem (27716)
Other Improvements
torch.argmax/argmin
Allow half type (28787).torch.cuda.memory_stats / memory_summary
instrumentation for CUDA memory allocator (27361).torch.set_num_threads
Allow calling multiple times with TBB (27190).torch.set_num_threads
Allow calling multiple times in parallel native (27947).torch.logical_xor
Allow non-bool tensors (27248).torch.promote_types
Nicer error message. (27941).torch.batch_norm_elemt
Add an out-variant (27621).torch.lerp
Implement derivative with respect to weight (28219).torch.complex32
Add type promotion support (27929).torch.unique
Support bool tensors (28374).torch.reshape
Improve backward for viewable geometries (28901).torch.lu
Generalized factorization (28608).torch.equal
Add the intra-op parallelism (28810).torch.randint
Accept generator=None (29748).torch.bfloat16
Enabled for cuda (27259).torch.multinomial
Enable for torch.half (29266).nn.RNN
Respect the current stream in cudnn (27026).nn.RNN
Preserve nonlinearity attribute (28058).nn.Linear
Support 0-batch size. (27211).nn.functional.binary_cross_entropy
implement double backwards (26983).nn.AdaptiveAvgPool2d
Add support for NHWC memory format (24396).nn.GELU
Add GELU activation (28944).nn.LayerNorm
Handle batch size of zero (28614).nn.BatchNorm
Add NHWC support on cudnn (23861).nn.BatchNorm2d
support torch.channels_last (28982).nn.BatchNorm2d
Handle empty inputs (30035).nn.LayerNorm
Enable the intra-op parallelism (28464).nn.utils.prune
Add pruning functionality (24076).nn.Sequential
Make iterable (28987).dtype.is_signed
Ability to differentiate signed dtypes (29511).optim.lr_scheduler.MultiplicativeLR
Add new multiplicative learning rate scheduler. (27254).cuda.comm.scatter, gather
Add channel-last support (28077).at::parallel_for
Choose number of OMP threads based on GRAIN_SIZE (26963).- Return NotImplemented from unsupported tensor arithmetic operators (26507).
- Automatically select proper tqdm submodule (27108).
- Pickle support for sparse tensors (27062).
- Vectorized complex unary and binary op support. (26500).
- Complex support for reduce and linpack ops on CPU (27653).
- Complex support for compare and pointwise ops on CPU (28735).
- Make PyTorch Python 3.8 compatible (29302).
- Buffer python warning to avoid deadlocks (26613).
- Use NNPACK for strided convolutions. (29084).
Bug Fixes
Distributed
- Ensure NCCL error handling code is disabled for NCCL versions < 2.4 (27124).
- Fix segmentation fault in
FileStore
with concurrent accesses. (28812). - Fix DDP incompatibility issue with
nn.MultiheadAttention
(26826).
RPC
- Add
ProcessGroupAgent
termination detection algorithm (26984). - Fix pybind11 warnings in Python RPC handler implementation (27284).
- Defer creating
ProcessGroupAgent
listener thread until contexts are initialized (28013). - Fix Python RPC handler exit crash (27251).
- Fix distributed autograd initialization (29069).
- Always include autograd context id in
rpc_*
/remote
requests (29781). - Make
RRefContext
singleton leaky, deal with module destruct order race. (30172).
C++ API Bug Fixes
- at::Tensor::requires_grad_ now supported (26332).
- torch::isfinite now supported (30083).
- torch::nn::modules_ordered_dict is deprecated (28774).
- Add reset_parameters to torch::nn modules (29832).
- Allow passing undefined Tensor to Module::register_parameter (27948).
- Exclude undefined tensors in the result of Module::parameters() / named_paramters() / buffers() / named_buffers() (30626).
- Include hierarchy information in C++ API loading error messages (28499).
- Fix a bug: the C++ L-BFGS optimizer does not work properly if there are one or more registered tensors with no grad in the model (27606).
- Use c10::variant-based enums for Nonlinearity and FanMode (27933). Support for
torch::nn::init::Nonlinearity
andtorch::nn::init::FanMode
will be removed in 1.5.
JIT
- Make dropout properly condition on training. (29436)
- Fix aten::grad to return optional list (29577)
- Fix
torch.arange
dtype - Fix type sharing on loaded ScriptModules (29826)
- Fix type sharing between traced modules (29583)
- Check for mutable default parameters (29833)
- Fix tracing of autograd functions (29791)
- Check for unrolled loop in break & continue (29474)
- Fix negative string indexing (22700)
- Make jit.trace_module reentrant (29411)
- Fix jit outplace tracing and reapply changes to _like operators. (28839)
- Properly guard against inheritance on TorchScript classes (28407)
- Fix when giving jit format warning about unsupported options (28616)
- Fix handling of function attributes. (28569)
- Fix pushLong() issue in pickler. (28057)
- Fix broken name mangling (27511)
- Fix segfault while printing value type for an error msg in emitListComprehension (27261)
- Fix `toIValue` dict iteration (26856)
- Fix race condition in Function::optimized_graph(). (27012)
- Sanitize module names on legacy import (27764)
- Python None should have its type inferred as NoneType (26665)
- Properly set existing attributes under recursive script (27514)
Quantization
- Skip copy_same_type_transpose_ for quantized tensor (29609).
- Add note that cuda quantization is not supported (27829).
- Rename _intrinsic to intrinsic (27194).
- Better error message for quantized dispatch (28635).
- Update the misleading comments for zero_points and scale in dynamic quant linear module [1/2] (28767).
- Avoid the misleading zero_point and scale [2/2] (28827).
- Add the warning message for API with linear modules (28766).
- Do not insert observers for empty sequential modules (28384).
- Fix the padding issue of quantized average pool operator (28260).
Mobile
Other Bug Fixes
- `torch.kthvalue`: Fix CUDA shared memory out of bound access in findPattern (28989).
- `torch.save`: Fix source files not being saved (28965).
- `torch.load`: Fix OSError loading files larger than 2GB. (27125).
- `torch.linspace`: clearer error message for negative step sizes. (28274).
- `torch.histc`: Add range checks to avoid segfaults (27712).
- `torch.lu`: Fix thread
- `torch.max_pool2d`: Limit tensor size to max CUDA grid size (28931).
- `torch.renorm`: Fix a memory leak in CUDA renorm. (29873).
- `torch.index_add`: Fix bug in atomicAdd on CUDA for some dtypes (29231).
- `torch.addmm`: Fix handling of empty tensors (28613).
- `nn.CTCLoss`: Fix incorrect gradient for large target sizes (27460).
- `nn.functional.ctc_loss`: Fix incorrect gradient on cudnn (27039).
- `nn.Embedding`: Incorrect gradient at padding_idx in cuda kernel. (27731).
- `nn.LayerNorm`: Fix an illegal memory access error (28196).
- `nn.Conv2d`: handle zero stride (28784).
- `nn.PoissonNLLLoss`: Fix incorrect result with `full=True` (28637).
- `nn.AvgPool2d`: fix an overflow for 2^31-1 sized inputs (30793).
- `nn.RNNBase`: Fix an issue with use of children of RNN third party device types (28562).
- `nn.Upsample`: Fix “invalid configuration argument” error (28927).
- `nn.Upsample`: Fix a CUDA launch config failure (29016).
- `optim.lr_scheduler.OneCycleLR`: Correctly handle div_factor parameter (28217).
- `PackedSequence.to`: Ensure all tensors are moved (27245).
- `EventList.total_average`: Fix a regression caused by missing iadd (27498).
- `Tensor.record_stream`: Ensure stream is recorded for shifted view tensors (27371).
- `torch.hub`: Handle branch names containing a slash. (27960).
- Fix error handling in Magma kernels (29003).
- Fix avx for c++14 (28207).
- Fix illegal memory access thread safety issue in sparse CUDA (29426).
- `__cuda_array_interface__`: Fix stride calculation (31450).
Deprecations
Python 2 support is deprecated and will not be supported in the 1.5 release.
`torch.optim`: `Scheduler.step(epoch)` is now deprecated; use `Scheduler.step()` instead. (26432)
For example:
>>> for epoch in range(10):
>>>     optimizer.step()
>>>     scheduler.step(epoch)
DeprecationWarning: The epoch parameter in `scheduler.step()` was not necessary and is being deprecated where possible. Please use `scheduler.step()` to step the scheduler. During the deprecation, if epoch is different from None, the closed form is used instead of the new chainable form, where available. Please open an issue if you are unable to replicate your use case: https://github.com/pytorch/pytorch/issues/new/choose.
warnings.warn(EPOCH_DEPRECATION_WARNING, DeprecationWarning)
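For comparison, here is a minimal sketch of the now-recommended pattern with the epoch argument dropped; the model, optimizer, and scheduler below are arbitrary placeholders, not part of the deprecation itself:

```python
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

for epoch in range(10):
    # ... forward pass, loss.backward(), etc. ...
    optimizer.step()
    scheduler.step()  # chainable form: no epoch argument
```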
[C++] C++11 is deprecated and will not be supported in the 1.5 release.
[C++] `Tensor::is_variable()` has been deprecated. As noted in the Backwards Incompatible Changes section, the distinction between variable and non-variable has been eliminated, so this check is no longer meaningful. Generally, `is_variable()` will now return true except in some special circumstances (see 29653 for more details). (29653)
[C++] `torch::nn::modules_ordered_dict` has been deprecated. It is generally no longer necessary and can just be removed. (28774)
`torch.jit.quantized` API has been deprecated in favor of `torch.quantization.quantize_dynamic` (28766)
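A minimal sketch of the replacement API; the model below is an arbitrary placeholder used only to show the call:

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))

# Dynamically quantize the Linear layers to int8 weights; activations are
# quantized on the fly at inference time.
quantized_model = torch.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)
print(quantized_model)
```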
Performance
A benchmark suite is available to easily measure the performance of operators with a range of input shapes. The generated benchmark data fully characterize the performance of operators in terms of execution time. For more details see README.md in the benchmarks/operator_benchmark directory.
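As a rough sketch of how a benchmark in that suite is defined, following the pattern described in the benchmarks/operator_benchmark README; the exact helper names and options may differ between releases, so treat this as an assumption rather than the canonical API:

```python
import operator_benchmark as op_bench  # from benchmarks/operator_benchmark
import torch

# Input shapes to sweep over; "tags" lets the runner filter configurations.
add_configs = op_bench.cross_product_configs(
    M=[8, 64], N=[8, 64], tags=["short"]
)

class AddBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, M, N):
        # Called once per configuration to build the inputs.
        self.input_one = torch.rand(M, N)
        self.input_two = torch.rand(M, N)
        self.set_module_name("add")

    def forward(self):
        # The body that gets timed.
        return torch.add(self.input_one, self.input_two)

op_bench.generate_pt_test(add_configs, AddBenchmark)

if __name__ == "__main__":
    op_bench.benchmark_runner.main()
```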
- `torch.nn.functional.threshold`, `torch.nn.functional.layer_norm`, `torch.cdist`: Performance of threshold (CPU), layer norm (CUDA) and cdist operations was improved (27155, 27634, 25799).
- `torch.Tensor.fill_`: Performance for half and bfloat16 types on CPU was improved (28397).
- `torch.nn.MaxPool2d`: implementation for channels_last format was added (24872).
- There is a fast pass reducing the overheads of pointwise operations relying on TensorIterator under certain conditions (contiguous inputs, no broadcast) (29180).
- Overheads of operations with scalars/number literals were reduced (29915).
- In case of type promotion on the GPU, the values are converted on the fly, without explicit casting of the full tensor (30018).
- reorder_dimensions in TensorIterator favors output write locality, improving overall performance when operating on discontiguous tensors (28615).
- Float pickling speed was improved (28553).
- GRAIN_SIZE for intra-op parallelization was unified between TH and ATen operations (28770)
- `tensor.numel`: devirtualized, improving performance (27294)