Compile with mixed precision (Beta)
Mixed precision combines the use of different numerical formats (such as FP32 and BF16) to reduce memory footprint and speed up large neural network workloads.
SambaFlow provides several convenience methods for enabling mixed precision in the model.
Graph Automatic Mixed Precision (GraphAMP)
The GraphAMP feature is available starting with SambaFlow 1.18. It is a convenient way to choose how aggressively the compiler will downcast operations in the model.
You use the --graphamp-preset compiler argument to switch between different presets. The available options, shown in the following table, offer a tradeoff between accuracy and performance:
| GraphAMP options | Description |
|---|---|
| mp0 | GraphAMP is disabled and no optimization is applied to the model. |
| mp1 | Downcasts the inputs of GEMM-like operators to BF16. |
| mp4 | Downcasts all inputs and operators of the model to BF16. |
| Default | Forces all inputs and operators of the model to BF16 until it encounters a floating point conversion. |
Example
Below is a simple example that we will use to explain what happens when you compile the model with different --graphamp-preset values.
import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def forward(self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        matmul_1 = torch.matmul(a, b)
        act_1 = F.relu(matmul_1)
        matmul_2 = torch.matmul(act_1, c)
        out = F.relu(matmul_2)
        return out

mymodel = Model()
GraphAMP mp0 mode
With mp0, GraphAMP is disabled. The model that the compiler lowers is the same model you defined in your code.
$ python mymodel.py compile --graphamp-preset mp0
GraphAMP mp1 mode
With mp1, GraphAMP downcasts the inputs of GEMM-like operators, in this example the two matmul operators. In the resulting model, those matmul operators run in mixed precision: they take BF16 inputs and produce FP32 outputs.
$ python mymodel.py compile --graphamp-preset mp1
NOTE: With only GEMM operators in mixed precision, mp1 offers a conservative tradeoff between accuracy and performance. This mode is the recommended preset.
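To make this concrete, the sketch below shows roughly what mp1 corresponds to if the downcasts were written by hand in the example model. This is only an illustration, not code that the compiler generates, and the class name is ours; on RDU the downcast matmuls still produce FP32 outputs, which plain PyTorch does not reproduce here.
import torch
import torch.nn as nn
import torch.nn.functional as F

class ModelAsSeenByMp1(nn.Module):
    # Illustrative sketch only: mp1 behaves roughly as if the GEMM inputs were cast by hand.
    def forward(self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        # The inputs of each GEMM-like operator are downcast to BF16.
        matmul_1 = torch.matmul(a.bfloat16(), b.bfloat16())
        act_1 = F.relu(matmul_1)
        matmul_2 = torch.matmul(act_1.bfloat16(), c.bfloat16())
        out = F.relu(matmul_2)
        return out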
GraphAMP mp4 mode
With mp4, GraphAMP aggressively downcasts all inputs and operators to BF16. The resulting model uses BF16 throughout.
$ python mymodel.py compile --graphamp-preset mp4
Overrides with disable_graphamp()
SambaFlow provides context managers and decorators that let you disable mixed precision in regions of your model. This gives you the flexibility to improve the performance of your model with GraphAMP while preserving floating-point precision for specific operator instances.
In these regions, operators run in an operator-specific datatype chosen by the user to maintain accuracy.
Use disable_graphamp() to wrap only the forward pass(es) of your network, including the loss computation(s). Backward operations run with the same datatype as forward operations.
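As a minimal sketch of that guidance, the forward pass and loss computation can be scoped as shown below. Here mymodel, loss_fn, inputs, and targets are placeholder names, not SambaFlow APIs, and sambaflow.samba is assumed to be imported as samba, as in the example that follows.
# Sketch: scope only the forward pass and the loss under disable_graphamp().
with samba.session.disable_graphamp():
    outputs = mymodel(inputs)          # forward pass
    loss = loss_fn(outputs, targets)   # loss computation
# Backward operations reuse the datatypes chosen for the forward operations.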
Example:
import sambaflow.samba as samba
import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def forward(self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        matmul_1 = torch.matmul(a, b)
        act_1 = F.relu(matmul_1)
        # Keep this region out of GraphAMP; the explicit casts choose the datatype.
        with samba.session.disable_graphamp():
            matmul_2 = torch.matmul(act_1.bfloat16(), c.bfloat16())
            out = F.relu(matmul_2)
        return out

mymodel = Model()
$ python mymodel.py compile --graphamp-preset mp1
In this example, matmul_2 and its relu are inside the disable_graphamp() context manager, while the rest of the graph uses mp1 mode. The resulting graph is described below.
- With full mp1 mode, both matmul_1 and matmul_2 are in mixed precision.
- In this example, the resulting graph has matmul_1 in mixed precision and matmul_2 in BF16 precision.
Operator-specific behavior
- Matmul and Linear. These are GEMM-like operators that support internal mixed-precision compute and can take BF16 inputs and output FP32 tensors directly (see the sketch after this list).
- Refer to the list of supported PyTorch operators for details.
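The snippet below emulates that numerical behavior on CPU for illustration: the operands are rounded to BF16 values, but the computation and the output stay in FP32. It is not the RDU kernel, just a way to reason about the precision pattern.
import torch

a = torch.randn(64, 128)
b = torch.randn(128, 256)

# Emulate a mixed-precision GEMM: operands rounded to BF16, compute and output in FP32.
out_fp32 = torch.matmul(a.bfloat16().float(), b.bfloat16().float())
print(out_fp32.dtype)  # torch.float32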
Additional precision arguments
You can customize compiler behavior with the following compiler arguments.
--tiling-accum
The accumulation operation associated with tiling can be sensitive to the precision setting. By default, accumulation happens in FP32. Use --tiling-accum to change accumulation to BF16 with stochastic rounding, which improves performance at the cost of some accuracy.
$ compile --tiling-accum ["fp32" | "bf16sr"]
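The toy example below shows why accumulation precision matters: accumulating many small values in plain BF16 (without stochastic rounding) loses most of the contributions once the running sum grows, while FP32 accumulation stays close to the exact result. Stochastic rounding, as used by the bf16sr option, is designed to preserve such small contributions on average. This is a generic numerical illustration, not SambaFlow code.
import torch

# Exact sum of 10,000 increments of 1e-4 is 1.0.
v_bf16 = torch.tensor(1e-4, dtype=torch.bfloat16)
v_fp32 = torch.tensor(1e-4, dtype=torch.float32)
acc_bf16 = torch.tensor(0.0, dtype=torch.bfloat16)
acc_fp32 = torch.tensor(0.0, dtype=torch.float32)
for _ in range(10_000):
    acc_bf16 = acc_bf16 + v_bf16  # small increments are eventually rounded away
    acc_fp32 = acc_fp32 + v_fp32
print(acc_bf16.item())  # far below 1.0
print(acc_fp32.item())  # close to 1.0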
--weight-grad-reduce
Weight gradient reduction combines weight gradient values that are computed in parallel. By default, this reduction happens in FP32. Use the --weight-grad-reduce argument to change the reduction to BF16 with stochastic rounding, which improves performance at the cost of some accuracy.
$ compile --weight-grad-reduce ["fp32" | "bf16sr"]
NOTE: The --weight-grad-reduce argument is applied only to linear weight gradient updates.
--fp32-params
By default, if the forward pass for a particular operator has BF16 inputs, the backward pass for that operator produces BF16 gradients. Gradient values with small magnitudes may not be representable in BF16, causing underflow, so the update for the corresponding parameters is lost.
The --fp32-params argument addresses this issue by having the optimizer maintain two copies of each weight, one in BF16 and one in FP32 precision. The BF16 copy is sent to the next operator, while the FP32 copy is used to update the trainable parameter.
NOTE: When using full precision weight update mode (--fp32-params), the model parameters are expected to be initialized in BF16.
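The sketch below illustrates this full-precision weight update pattern in plain PyTorch. It is a conceptual illustration only; on RDU, the optimizer manages both copies for you when --fp32-params is used.
import torch

# Parameters are initialized in BF16; an FP32 master copy is kept for updates.
param_bf16 = torch.randn(1024, dtype=torch.bfloat16)  # copy sent to the next operator
master_fp32 = param_bf16.float()                       # copy used to update the parameter

lr = 1e-3
grad_bf16 = torch.randn(1024, dtype=torch.bfloat16)    # BF16 gradient from the backward pass

# Update in FP32 so small gradient contributions are not lost to underflow...
master_fp32 = master_fp32 - lr * grad_bf16.float()
# ...then refresh the BF16 working copy used for compute.
param_bf16 = master_fp32.bfloat16()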
--enable-mixed-precision-ops
By default, an operator computes and returns in the same precision as its input. However, for some operations, you might want to put certain operators in mixed precision without relying on GraphAMP.
The --enable-mixed-precision-ops argument lets you specify operator types to run in mixed precision across the whole graph. Two types of operations are supported: GEMM-like operations (matmul and linear) and softmax.
$ compile --enable-mixed-precision-ops gemm softmax
NOTE: For GEMM-like operations, --enable-mixed-precision-ops enables BF16 input and FP32 output. For softmax operations, --enable-mixed-precision-ops enables FP32 input and BF16 output.
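For illustration, the softmax pattern can be emulated in plain PyTorch as shown below (the GEMM pattern is the reverse, as sketched earlier): the exponentiation and reduction run on an FP32 input and only the result is downcast to BF16. This is a sketch of the numerical behavior, not the RDU implementation.
import torch
import torch.nn.functional as F

logits = torch.randn(4, 1000)                  # FP32 input
probs = F.softmax(logits, dim=-1).bfloat16()   # BF16 output
print(probs.dtype)  # torch.bfloat16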