MLIR  20.0.0git
Namespaces | Macros | Functions
SuperVectorize.cpp File Reference
#include "mlir/Dialect/Affine/Passes.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/NestedMatcher.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include <optional>
#include "mlir/Dialect/Affine/Passes.h.inc"

Go to the source code of this file.

Namespaces

 mlir
 Include the generated interface declarations.
 
 mlir::affine
 

Macros

#define GEN_PASS_DEF_AFFINEVECTORIZE
 
#define DEBUG_TYPE   "early-vect"
 Implements a high-level vectorization strategy on a Function. More...
 

Functions

static FilterFunctionType isVectorizableLoopPtrFactory (const DenseSet< Operation * > &parallelLoops, int fastestVaryingMemRefDimension)
 Forward declaration. More...
 
static std::optional< NestedPattern > makePattern (const DenseSet< Operation * > &parallelLoops, int vectorRank, ArrayRef< int64_t > fastestVaryingPattern)
 Creates a vectorization pattern from the command line arguments. More...
 
static NestedPattern & vectorTransferPattern ()
 
static void vectorizeLoopIfProfitable (Operation *loop, unsigned depthInPattern, unsigned patternDepth, VectorizationStrategy *strategy)
 
static LogicalResult analyzeProfitability (ArrayRef< NestedMatch > matches, unsigned depthInPattern, unsigned patternDepth, VectorizationStrategy *strategy)
 Implements a simple strawman strategy for vectorization. More...
 
static void eraseLoopNest (AffineForOp forOp)
 Erases a loop nest, including all its nested operations. More...
 
static void computeMemoryOpIndices (Operation *op, AffineMap map, ValueRange mapOperands, VectorizationState &state, SmallVectorImpl< Value > &results)
 
static VectorType getVectorType (Type scalarTy, const VectorizationStrategy *strategy)
 Returns the vector type resulting from applying the provided vectorization strategy on the scalar type. More...
 
static arith::ConstantOp vectorizeConstant (arith::ConstantOp constOp, VectorizationState &state)
 Tries to transform a scalar constant into a vector constant. More...
 
static Operation * vectorizeAffineApplyOp (AffineApplyOp applyOp, VectorizationState &state)
 We have no need to vectorize affine.apply. More...
 
static arith::ConstantOp createInitialVector (arith::AtomicRMWKind reductionKind, Value oldOperand, VectorizationState &state)
 Creates a constant vector filled with the neutral elements of the given reduction. More...
 
static Value createMask (AffineForOp vecForOp, VectorizationState &state)
 Creates a mask used to filter out garbage elements in the last iteration of unaligned loops. More...
 
static bool isUniformDefinition (Value value, const VectorizationStrategy *strategy)
 Returns true if the provided value is vector uniform given the vectorization strategy. More...
 
static Operation * vectorizeUniform (Value uniformVal, VectorizationState &state)
 Generates a broadcast op for the provided uniform value using the vectorization strategy in 'state'. More...
 
static Value vectorizeOperand (Value operand, VectorizationState &state)
 Tries to vectorize a given operand by applying the following logic: More...
 
static Operation * vectorizeAffineLoad (AffineLoadOp loadOp, VectorizationState &state)
 Vectorizes an affine load with the vectorization strategy in 'state' by generating a 'vector.transfer_read' op with the proper permutation map inferred from the indices of the load. More...
 
static Operation * vectorizeAffineStore (AffineStoreOp storeOp, VectorizationState &state)
 Vectorizes an affine store with the vectorization strategy in 'state' by generating a 'vector.transfer_write' op with the proper permutation map inferred from the indices of the store. More...
 
static bool isNeutralElementConst (arith::AtomicRMWKind reductionKind, Value value, VectorizationState &state)
 Returns true if value is a constant equal to the neutral element of the given vectorizable reduction. More...
 
static Operation * vectorizeAffineForOp (AffineForOp forOp, VectorizationState &state)
 Vectorizes a loop with the vectorization strategy in 'state'. More...
 
static Operation * widenOp (Operation *op, VectorizationState &state)
 Vectorizes arbitrary operation by plain widening. More...
 
static Operation * vectorizeAffineYieldOp (AffineYieldOp yieldOp, VectorizationState &state)
 Vectorizes a yield operation by widening its types. More...
 
static Operation * vectorizeOneOperation (Operation *op, VectorizationState &state)
 Encodes Operation-specific behavior for vectorization. More...
 
static void getMatchedAffineLoopsRec (NestedMatch match, unsigned currentLevel, std::vector< SmallVector< AffineForOp, 2 >> &loops)
 Recursive implementation to convert all the nested loops in 'match' to a 2D vector container that preserves the relative nesting level of each loop with respect to the others in 'match'. More...
 
static void getMatchedAffineLoops (NestedMatch match, std::vector< SmallVector< AffineForOp, 2 >> &loops)
 Converts all the nested loops in 'match' to a 2D vector container that preserves the relative nesting level of each loop with respect to the others in 'match'. More...
 
static LogicalResult vectorizeLoopNest (std::vector< SmallVector< AffineForOp, 2 >> &loops, const VectorizationStrategy &strategy)
 Internal implementation to vectorize affine loops from a single loop nest using an n-D vectorization strategy. More...
 
static LogicalResult vectorizeRootMatch (NestedMatch m, const VectorizationStrategy &strategy)
 Extracts the matched loops and vectorizes them following a topological order. More...
 
static void computeIntersectionBuckets (ArrayRef< NestedMatch > matches, std::vector< SmallVector< NestedMatch, 8 >> &intersectionBuckets)
 Traverses all the loop matches and classifies them into intersection buckets. More...
 
static void vectorizeLoops (Operation *parentOp, DenseSet< Operation * > &loops, ArrayRef< int64_t > vectorSizes, ArrayRef< int64_t > fastestVaryingPattern, const ReductionLoopMap &reductionLoops)
 Internal implementation to vectorize affine loops in 'loops' using the n-D vectorization factors in 'vectorSizes'. More...
 
static LogicalResult verifyLoopNesting (const std::vector< SmallVector< AffineForOp, 2 >> &loops)
 Verify that affine loops in 'loops' meet the nesting criteria expected by SuperVectorizer: More...
 

Macro Definition Documentation

◆ DEBUG_TYPE

#define DEBUG_TYPE   "early-vect"

Implements a high-level vectorization strategy on a Function.

The abstraction used is that of super-vectors, which provide a single, compact, representation in the vector types, information that is expected to reduce the impact of the phase ordering problem.

Vector granularity:

This pass is designed to perform vectorization at a super-vector granularity. A super-vector is loosely defined as a vector type that is a multiple of a "good" vector size so the HW can efficiently implement a set of high-level primitives. Multiple is understood along any dimension; e.g. both vector<16xf32> and vector<2x8xf32> are valid super-vectors for a vector<8xf32> HW vector. Note that a "good vector size so the HW can efficiently implement a set of high-level primitives" is not necessarily an integer multiple of actual hardware registers. We leave details of this distinction unspecified for now.

Some may prefer the term "tile of HW vectors". In this case, one should note that super-vectors implement an "always full tile" abstraction. They guarantee no partial-tile separation is necessary by relying on a high-level copy-reshape abstraction that we call vector.transfer. This copy-reshape operation is also responsible for performing layout transposition if necessary. In the general case this will require a scoped allocation in some notional local memory.

Whatever the mental model one prefers to use for this abstraction, the key point is that we burn into a single, compact, representation in the vector types, information that is expected to reduce the impact of the phase ordering problem. Indeed, a vector type conveys information that:

  1. the associated loops have dependency semantics that do not prevent vectorization;
  2. the associated loops have been sliced in chunks of static sizes that are compatible with vector sizes (i.e. similar to unroll-and-jam);
  3. the inner loops, in the unroll-and-jam analogy of 2, are captured by the vector type and no vectorization hampering transformations can be applied to them anymore;
  4. the underlying memrefs are accessed in some notional contiguous way that allows loading into vectors with some amount of spatial locality.

In other words, super-vectorization provides a level of separation of concern by way of opacity to subsequent passes. This has the effect of encapsulating and propagating vectorization constraints down the list of passes until we are ready to lower further.

For a particular target, a notion of minimal n-d vector size will be specified and vectorization targets a multiple of those. In the following paragraph, let "k ." represent "a multiple of", to be understood as a multiple in the same dimension (e.g. vector<16 x k . 128> summarizes vector<16 x 128>, vector<16 x 256>, vector<16 x 1024>, etc).

Some non-exhaustive notable super-vector sizes of interest include:

  • CPU: vector<k . HW_vector_size>, vector<k' . core_count x k . HW_vector_size>, vector<socket_count x k' . core_count x k . HW_vector_size>;
  • GPU: vector<k . warp_size>, vector<k . warp_size x float2>, vector<k . warp_size x float4>, vector<k . warp_size x 4 x 4 x 4> (for tensor_core sizes).

Loops and operations are emitted that operate on those super-vector shapes. Subsequent lowering passes will materialize to actual HW vector sizes. These passes are expected to be (gradually) more target-specific.

At a high level, a vectorized load in a loop will resemble:

affine.for %i = ? to ? step ? {
%v_a = vector.transfer_read A[%i] : memref<?xf32>, vector<128xf32>
}

It is the responsibility of the implementation of vector.transfer_read to materialize vector registers from the original scalar memrefs. A later (more target-dependent) lowering pass will materialize to actual HW vector sizes. This lowering may occur at different times:

  1. at the MLIR level into a combination of loops, unrolling, DmaStartOp + DmaWaitOp + vectorized operations for data transformations and shuffle; thus opening opportunities for unrolling and pipelining. This is an instance of library call "whiteboxing"; or
  2. later in a target-specific lowering pass or hand-written library call, achieving full separation of concerns. This is an instance of a library call; or
  3. a mix of both, e.g. based on a model. In the future, these operations will expose a contract to constrain the search on vectorization patterns and sizes.

Occurrence of super-vectorization in the compiler flow:

This is an active area of investigation. We start with 2 remarks to position super-vectorization in the context of existing ongoing work: LLVM VPLAN and LLVM SLP Vectorizer.

LLVM VPLAN:

The astute reader may have noticed that, in the limit, super-vectorization can be applied at a similar time and with similar objectives as VPLAN. This would be the case, for instance, in a traditional polyhedral compilation flow (e.g. the PPCG project uses ISL to provide dependence analysis, multi-level (scheduling + tiling), lifting footprint to fast memory, communication synthesis, mapping, register optimizations), before unrolling. When vectorization is applied at this late level in a typical polyhedral flow, and is instantiated with actual hardware vector sizes, super-vectorization is expected to match (or subsume) the type of patterns that LLVM's VPLAN aims at targeting. The main difference here is that MLIR is higher level and our implementation should be significantly simpler. Also note that in this mode, recursive patterns are probably a bit of an overkill although it is reasonable to expect that mixing a bit of outer loop and inner loop vectorization + unrolling will provide interesting choices to MLIR.

LLVM SLP Vectorizer:

Super-vectorization however is not meant to be usable in a similar fashion to the SLP vectorizer. The main difference lies in the information that both vectorizers use: super-vectorization examines contiguity of memory references along fastest varying dimensions and loops with recursive nested patterns capturing imperfectly-nested loop nests; the SLP vectorizer, on the other hand, performs flat pattern matching inside a single unrolled loop body and stitches together pieces of load and store operations into full 1-D vectors. We envision that the SLP vectorizer is a good way to capture innermost loop, control-flow dependent patterns that super-vectorization may not be able to capture easily. In other words, super-vectorization does not aim at replacing the SLP vectorizer and the two solutions are complementary.

Ongoing investigations:

We discuss the following early places where super-vectorization is applicable and touch on the expected benefits and risks. We list the opportunities in the context of the traditional polyhedral compiler flow described in PPCG. There are essentially 6 places in the MLIR pass pipeline we expect to experiment with super-vectorization:

  1. Right after language lowering to MLIR: this is the earliest time where super-vectorization is expected to be applied. At this level, all the language/user/library-level annotations are available and can be fully exploited. Examples include loop-type annotations (such as parallel, reduction, scan, dependence distance vector, vectorizable) as well as memory access annotations (such as guaranteed non-aliasing writes, indirect accesses that are permutations by construction, or a particular operation being prescribed atomic by the user). At this level, anything that enriches what dependence analysis can do should be aggressively exploited. At this level we are close to having explicit vector types in the language, except we do not impose that burden on the programmer/library: we derive information from scalar code + annotations.
  2. After dependence analysis and before polyhedral scheduling: the information that supports vectorization does not need to be supplied by a higher level of abstraction. Traditional dependence analysis is available in MLIR and will be used to drive vectorization and cost models.

Let's pause here and remark that applying super-vectorization as described in 1. and 2. presents clear opportunities and risks:

  • the opportunity is that vectorization is burned in the type system and is protected from the adverse effect of loop scheduling, tiling, loop interchange and all passes downstream. Provided that subsequent passes are able to operate on vector types, the vector shapes, associated loop iterator properties, alignment, and contiguity of fastest varying dimensions are preserved until we lower the super-vector types. We expect this to significantly rein in the adverse effects of phase ordering.
  • the risks are that a. all passes after super-vectorization have to work on elemental vector types (note that this is always true, wherever vectorization is applied) and b. that imposing vectorization constraints too early may be overall detrimental to loop fusion, tiling and other transformations because the dependence distances are coarsened when operating on elemental vector types. For this reason, the pattern profitability analysis should include a component that also captures the maximal amount of fusion available under a particular pattern. This is still at the stage of rough ideas but in this context, search is our friend as the Tensor Comprehensions and auto-TVM contributions demonstrated previously. The bottom line is that we do not yet have good answers for the above but aim at making it easy to answer such questions.

Back to our listing, the last places where early super-vectorization makes sense are:

  3. right after polyhedral-style scheduling: PLUTO-style algorithms are known to improve locality, parallelism and be configurable (e.g. max-fuse, smart-fuse etc). They can also have adverse effects on contiguity properties that are required for vectorization but the vector.transfer copy-reshape-pad-transpose abstraction is expected to help recapture these properties.
  4. right after polyhedral-style scheduling+tiling;
  5. right after scheduling+tiling+rescheduling: points 4 and 5 probably represent the most promising places because applying tiling achieves a separation of concerns that allows rescheduling to worry less about locality and more about parallelism and distribution (e.g. min-fuse).

At these levels the risk-reward looks different: on one hand we probably lost a good deal of language/user/library-level annotation; on the other hand we gained parallelism and locality through scheduling and tiling. However we probably want to ensure tiling is compatible with the full-tile-only abstraction used in super-vectorization or suffer the consequences. It is too early to place bets on what will win but we expect super-vectorization to be the right abstraction to allow exploring at all these levels. And again, search is our friend.

Lastly, we mention it again here:

  6. as an MLIR-based alternative to VPLAN.

Lowering, unrolling, pipelining:

TODO: point to the proper places.

Algorithm:

The algorithm proceeds in a few steps:

  1. defining super-vectorization patterns and matching them on the tree of AffineForOp. A super-vectorization pattern is defined as a recursive data structure that matches and captures nested, imperfectly-nested loops that have a. conformable loop annotations attached (e.g. parallel, reduction, vectorizable, ...) as well as b. all contiguous load/store operations along a specified minor dimension (not necessarily the fastest varying);
  2. analyzing those patterns for profitability (TODO: and interference);
  3. then, for each pattern in order:
     a. applying iterative rewriting of the loops and all their nested operations in topological order. Rewriting is implemented by coarsening the loops and converting operations and operands to their vector forms. Processing operations in topological order is relatively simple due to the structured nature of the control-flow representation. This order ensures that all the operands of a given operation have been vectorized before the operation itself in a single traversal, except for operands defined outside of the loop nest. The algorithm can convert the following operations to their vector form:
    • Affine load and store operations are converted to opaque vector transfer read and write operations.
    • Scalar constant operations/operands are converted to vector constant operations (splat).
    • Uniform operands (only induction variables of loops not mapped to a vector dimension, or operands defined outside of the loop nest for now) are broadcast to a vector. TODO: Support more uniform cases.
    • Affine for operations with 'iter_args' are vectorized by vectorizing their 'iter_args' operands and results. TODO: Support more complex loops with divergent lbs and/or ubs.
    • The remaining operations in the loop nest are vectorized by widening their scalar types to vector types.
     b. if everything under the root AffineForOp in the current pattern is vectorized properly, we commit that loop to the IR and remove the scalar loop. Otherwise, we discard the vectorized loop and keep the original scalar loop.
     c. vectorization is applied on the next pattern in the list. Because pattern interference avoidance is not yet implemented and we do not support further vectorizing an already-vectorized load, we need to re-verify that the pattern is still vectorizable. This is expected to make cost models more difficult to write and is subject to improvement in the future.

Choice of loop transformation to support the algorithm:

The choice of loop transformation to apply for coarsening vectorized loops is still subject to exploratory tradeoffs. In particular, say we want to vectorize by a factor of 128; we want to transform the following input:

affine.for %i = %M to %N {
%a = affine.load %A[%i] : memref<?xf32>
}

Traditionally, one would vectorize late (after scheduling, tiling, memory promotion etc) say after stripmining (and potentially unrolling in the case of LLVM's SLP vectorizer):

affine.for %i = floor(%M, 128) to ceil(%N, 128) {
affine.for %ii = max(%M, 128 * %i) to min(%N, 128 * %i + 128) {
%a = affine.load %A[%ii] : memref<?xf32>
}
}

Instead, we seek to vectorize early and freeze vector types before scheduling, so we want to generate a pattern that resembles:

affine.for %i = ? to ? step ? {
%v_a = vector.transfer_read %A[%i] : memref<?xf32>, vector<128xf32>
}

i. simply dividing the lower / upper bounds by 128 creates issues when representing expressions such as ii + 1 because now we only have access to original values that have been divided. Additional information is needed to specify accesses at below-128 granularity; ii. another alternative is to coarsen the loop step but this may have consequences on dependence analysis and fusability of loops: fusable loops probably need to have the same step (because we don't want to stripmine/unroll to enable fusion). As a consequence, we choose to represent the coarsening using the loop step for now and reevaluate in the future. Note that we can renormalize loop steps later if/when we have evidence that they are problematic.

For the simple strawman example above, vectorizing for a 1-D vector abstraction of size 128 returns code similar to:

affine.for %i = %M to %N step 128 {
%v_a = vector.transfer_read %A[%i] : memref<?xf32>, vector<128xf32>
}

Unsupported cases, extensions, and work in progress (help welcome :-) ):

  1. lowering to concrete vector types for various HW;
  2. reduction support for n-D vectorization and non-unit steps;
  3. non-effecting padding during vector.transfer_read and filter during vector.transfer_write;
  4. misalignment support vector.transfer_read / vector.transfer_write (hopefully without read-modify-writes);
  5. control-flow support;
  6. cost-models, heuristics and search;
  7. Op implementation, extensions and implication on memref views;
  8. many TODOs left around.

Examples:

Consider the following Function:

func @vector_add_2d(%M : index, %N : index) -> f32 {
%A = alloc (%M, %N) : memref<?x?xf32, 0>
%B = alloc (%M, %N) : memref<?x?xf32, 0>
%C = alloc (%M, %N) : memref<?x?xf32, 0>
%f1 = arith.constant 1.0 : f32
%f2 = arith.constant 2.0 : f32
affine.for %i0 = 0 to %M {
affine.for %i1 = 0 to %N {
// non-scoped %f1
affine.store %f1, %A[%i0, %i1] : memref<?x?xf32, 0>
}
}
affine.for %i2 = 0 to %M {
affine.for %i3 = 0 to %N {
// non-scoped %f2
affine.store %f2, %B[%i2, %i3] : memref<?x?xf32, 0>
}
}
affine.for %i4 = 0 to %M {
affine.for %i5 = 0 to %N {
%a5 = affine.load %A[%i4, %i5] : memref<?x?xf32, 0>
%b5 = affine.load %B[%i4, %i5] : memref<?x?xf32, 0>
%s5 = arith.addf %a5, %b5 : f32
// non-scoped %f1
%s6 = arith.addf %s5, %f1 : f32
// non-scoped %f2
%s7 = arith.addf %s5, %f2 : f32
// diamond dependency.
%s8 = arith.addf %s7, %s6 : f32
affine.store %s8, %C[%i4, %i5] : memref<?x?xf32, 0>
}
}
%c7 = arith.constant 7 : index
%c42 = arith.constant 42 : index
%res = load %C[%c7, %c42] : memref<?x?xf32, 0>
return %res : f32
}

The -affine-super-vectorize pass with the following arguments:

-affine-super-vectorize="virtual-vector-size=256 test-fastest-varying=0"

produces this standard innermost-loop vectorized code:

func @vector_add_2d(%arg0 : index, %arg1 : index) -> f32 {
%0 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
%1 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
%2 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
%cst = arith.constant 1.0 : f32
%cst_0 = arith.constant 2.0 : f32
affine.for %i0 = 0 to %arg0 {
affine.for %i1 = 0 to %arg1 step 256 {
%cst_1 = arith.constant dense<1.0> :
vector<256xf32>
vector.transfer_write %cst_1, %0[%i0, %i1] :
vector<256xf32>, memref<?x?xf32>
}
}
affine.for %i2 = 0 to %arg0 {
affine.for %i3 = 0 to %arg1 step 256 {
%cst_2 = arith.constant dense<2.0> :
vector<256xf32>
vector.transfer_write %cst_2, %1[%i2, %i3] :
vector<256xf32>, memref<?x?xf32>
}
}
affine.for %i4 = 0 to %arg0 {
affine.for %i5 = 0 to %arg1 step 256 {
%3 = vector.transfer_read %0[%i4, %i5] :
memref<?x?xf32>, vector<256xf32>
%4 = vector.transfer_read %1[%i4, %i5] :
memref<?x?xf32>, vector<256xf32>
%5 = arith.addf %3, %4 : vector<256xf32>
%cst_3 = arith.constant dense<1.0> :
vector<256xf32>
%6 = arith.addf %5, %cst_3 : vector<256xf32>
%cst_4 = arith.constant dense<2.0> :
vector<256xf32>
%7 = arith.addf %5, %cst_4 : vector<256xf32>
%8 = arith.addf %7, %6 : vector<256xf32>
vector.transfer_write %8, %2[%i4, %i5] :
vector<256xf32>, memref<?x?xf32>
}
}
%c7 = arith.constant 7 : index
%c42 = arith.constant 42 : index
%9 = load %2[%c7, %c42] : memref<?x?xf32>
return %9 : f32
}

The -affine-super-vectorize pass with the following arguments:

-affine-super-vectorize="virtual-vector-size=32,256 \
test-fastest-varying=1,0"

produces this more interesting mixed outer-innermost-loop vectorized code:

func @vector_add_2d(%arg0 : index, %arg1 : index) -> f32 {
%0 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
%1 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
%2 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
%cst = arith.constant 1.0 : f32
%cst_0 = arith.constant 2.0 : f32
affine.for %i0 = 0 to %arg0 step 32 {
affine.for %i1 = 0 to %arg1 step 256 {
%cst_1 = arith.constant dense<1.0> :
vector<32x256xf32>
vector.transfer_write %cst_1, %0[%i0, %i1] :
vector<32x256xf32>, memref<?x?xf32>
}
}
affine.for %i2 = 0 to %arg0 step 32 {
affine.for %i3 = 0 to %arg1 step 256 {
%cst_2 = arith.constant dense<2.0> :
vector<32x256xf32>
vector.transfer_write %cst_2, %1[%i2, %i3] :
vector<32x256xf32>, memref<?x?xf32>
}
}
affine.for %i4 = 0 to %arg0 step 32 {
affine.for %i5 = 0 to %arg1 step 256 {
%3 = vector.transfer_read %0[%i4, %i5] :
memref<?x?xf32>, vector<32x256xf32>
%4 = vector.transfer_read %1[%i4, %i5] :
memref<?x?xf32>, vector<32x256xf32>
%5 = arith.addf %3, %4 : vector<32x256xf32>
%cst_3 = arith.constant dense<1.0> :
vector<32x256xf32>
%6 = arith.addf %5, %cst_3 : vector<32x256xf32>
%cst_4 = arith.constant dense<2.0> :
vector<32x256xf32>
%7 = arith.addf %5, %cst_4 : vector<32x256xf32>
%8 = arith.addf %7, %6 : vector<32x256xf32>
vector.transfer_write %8, %2[%i4, %i5] :
vector<32x256xf32>, memref<?x?xf32>
}
}
%c7 = arith.constant 7 : index
%c42 = arith.constant 42 : index
%9 = load %2[%c7, %c42] : memref<?x?xf32>
return %9 : f32
}

Of course, much more intricate n-D imperfectly-nested patterns can be vectorized too and specified in a fully declarative fashion.

Reduction:

Vectorizing reduction loops along the reduction dimension is supported if:

  • the reduction kind is supported,
  • the vectorization is 1-D, and
  • the step size of the loop equals one.

Compared to the non-vector-dimension case, two additional things are done during vectorization of such loops:

  • The resulting vector returned from the loop is reduced to a scalar using vector.reduction.
  • In some cases a mask is applied to the vector yielded at the end of the loop to prevent garbage values from being written to the accumulator.

Reduction vectorization is switched off by default; it can be enabled by passing a map from loops to reductions to the utility functions, or by passing vectorize-reductions=true to the vectorization pass.

Consider the following example:

func @vecred(%in: memref<512xf32>) -> f32 {
%cst = arith.constant 0.000000e+00 : f32
%sum = affine.for %i = 0 to 500 iter_args(%part_sum = %cst) -> (f32) {
%ld = affine.load %in[%i] : memref<512xf32>
%cos = math.cos %ld : f32
%add = arith.addf %part_sum, %cos : f32
affine.yield %add : f32
}
return %sum : f32
}

The -affine-super-vectorize pass with the following arguments:

-affine-super-vectorize="virtual-vector-size=128 test-fastest-varying=0 \
vectorize-reductions=true"

produces the following output:

#map = affine_map<(d0) -> (-d0 + 500)>
func @vecred(%arg0: memref<512xf32>) -> f32 {
%cst = arith.constant 0.000000e+00 : f32
%cst_0 = arith.constant dense<0.000000e+00> : vector<128xf32>
%0 = affine.for %arg1 = 0 to 500 step 128 iter_args(%arg2 = %cst_0)
-> (vector<128xf32>) {
// %2 is the number of iterations left in the original loop.
%2 = affine.apply #map(%arg1)
%3 = vector.create_mask %2 : vector<128xi1>
%cst_1 = arith.constant 0.000000e+00 : f32
%4 = vector.transfer_read %arg0[%arg1], %cst_1 :
memref<512xf32>, vector<128xf32>
%5 = math.cos %4 : vector<128xf32>
%6 = arith.addf %arg2, %5 : vector<128xf32>
// We filter out the effect of last 12 elements using the mask.
%7 = arith.select %3, %6, %arg2 : vector<128xi1>, vector<128xf32>
affine.yield %7 : vector<128xf32>
}
%1 = vector.reduction <add>, %0 : vector<128xf32> into f32
return %1 : f32
}

Note that because of loop misalignment we needed to apply a mask to prevent the last 12 elements from affecting the final result. The mask is full of ones in every iteration except for the last one, in which it has the form 11...100...0 with 116 ones and 12 zeros.

Definition at line 575 of file SuperVectorize.cpp.

◆ GEN_PASS_DEF_AFFINEVECTORIZE

#define GEN_PASS_DEF_AFFINEVECTORIZE

Definition at line 35 of file SuperVectorize.cpp.

Function Documentation

◆ analyzeProfitability()

static LogicalResult analyzeProfitability ( ArrayRef< NestedMatch >  matches,
unsigned  depthInPattern,
unsigned  patternDepth,
VectorizationStrategy *  strategy 
)
static

Implements a simple strawman strategy for vectorization.

Given a matched pattern matches of depth patternDepth, this strategy greedily assigns the fastest varying dimension *of the vector* to the innermost loop in the pattern. When coupled with a pattern that looks for the fastest varying dimension in load/store MemRefs, this creates a generic vectorization strategy that works for any loop in a hierarchy (outermost, innermost or intermediate).

TODO: In the future we should additionally increase the power of the profitability analysis along 3 directions:

  1. account for loop extents (both static and parametric + annotations);
  2. account for data layout permutations;
  3. account for the impact of vectorization on maximal loop fusion.

Then we can quantify the above to build a cost model and search over strategies.

Definition at line 657 of file SuperVectorize.cpp.

Referenced by vectorizeLoops().

◆ computeIntersectionBuckets()

static void computeIntersectionBuckets ( ArrayRef< NestedMatch >  matches,
std::vector< SmallVector< NestedMatch, 8 >> &  intersectionBuckets 
)
static

Traverses all the loop matches and classifies them into intersection buckets.

Two matches intersect if any of them encloses the other one. A match intersects with a bucket if the match intersects with the root (outermost) loop in that bucket.

Definition at line 1644 of file SuperVectorize.cpp.

Referenced by vectorizeLoops().

◆ computeMemoryOpIndices()

static void computeMemoryOpIndices ( Operation op,
AffineMap  map,
ValueRange  mapOperands,
VectorizationState state,
SmallVectorImpl< Value > &  results 
)
static

Definition at line 901 of file SuperVectorize.cpp.

◆ createInitialVector()

static arith::ConstantOp createInitialVector ( arith::AtomicRMWKind  reductionKind,
Value  oldOperand,
VectorizationState state 
)
static

Creates a constant vector filled with the neutral elements of the given reduction.

The scalar type of vector elements will be taken from oldOperand.

Definition at line 1001 of file SuperVectorize.cpp.

◆ createMask()

static Value createMask ( AffineForOp  vecForOp,
VectorizationState state 
)
static

Creates a mask used to filter out garbage elements in the last iteration of unaligned loops.

If a mask is not required then nullptr is returned. The mask will be a vector of booleans representing meaningful vector elements in the current iteration. It is filled with ones for each iteration except for the last one, where it has the form 11...100...0 with the number of ones equal to the number of meaningful elements (i.e. the number of iterations that would be left in the original loop).

Definition at line 1025 of file SuperVectorize.cpp.

◆ eraseLoopNest()

static void eraseLoopNest ( AffineForOp  forOp)
static

Erases a loop nest, including all its nested operations.

Definition at line 889 of file SuperVectorize.cpp.

◆ getMatchedAffineLoops()

static void getMatchedAffineLoops ( NestedMatch  match,
std::vector< SmallVector< AffineForOp, 2 >> &  loops 
)
static

Converts all the nested loops in 'match' to a 2D vector container that preserves the relative nesting level of each loop with respect to the others in 'match'.

This means that every loop in 'loops[i]' will have a parent loop in 'loops[i-1]'. A loop in 'loops[i]' may or may not have a child loop in 'loops[i+1]'.

Definition at line 1554 of file SuperVectorize.cpp.

References getMatchedAffineLoopsRec().

◆ getMatchedAffineLoopsRec()

static void getMatchedAffineLoopsRec ( NestedMatch  match,
unsigned  currentLevel,
std::vector< SmallVector< AffineForOp, 2 >> &  loops 
)
static

Recursive implementation to convert all the nested loops in 'match' to a 2D vector container that preserves the relative nesting level of each loop with respect to the others in 'match'.

'currentLevel' is the nesting level that will be assigned to the loop in the current 'match'.

Definition at line 1534 of file SuperVectorize.cpp.

References mlir::affine::NestedMatch::getMatchedChildren(), and mlir::affine::NestedMatch::getMatchedOperation().

Referenced by getMatchedAffineLoops().

◆ getVectorType()

static VectorType getVectorType ( Type  scalarTy,
const VectorizationStrategy strategy 
)
static

Returns the vector type resulting from applying the provided vectorization strategy on the scalar type.

Definition at line 936 of file SuperVectorize.cpp.

References mlir::get(), and mlir::affine::VectorizationStrategy::vectorSizes.

◆ isNeutralElementConst()

static bool isNeutralElementConst ( arith::AtomicRMWKind  reductionKind,
Value  value,
VectorizationState state 
)
static

Returns true if value is a constant equal to the neutral element of the given vectorizable reduction.

Definition at line 1281 of file SuperVectorize.cpp.

◆ isUniformDefinition()

static bool isUniformDefinition ( Value  value,
const VectorizationStrategy strategy 
)
static

Returns true if the provided value is vector uniform given the vectorization strategy.

Definition at line 1098 of file SuperVectorize.cpp.

References mlir::affine::getForInductionVarOwner(), and mlir::affine::VectorizationStrategy::loopToVectorDim.

◆ isVectorizableLoopPtrFactory()

static FilterFunctionType isVectorizableLoopPtrFactory ( const DenseSet< Operation * > &  parallelLoops,
int  fastestVaryingMemRefDimension 
)
static

Forward declaration.

Returns a FilterFunctionType that can be used in NestedPattern to match a loop whose underlying load/store accesses are either invariant or all varying along the fastestVaryingMemRefDimension.

Definition at line 918 of file SuperVectorize.cpp.

References mlir::affine::isVectorizableLoopBody(), and vectorTransferPattern().

Referenced by makePattern().

◆ makePattern()

static std::optional<NestedPattern> makePattern ( const DenseSet< Operation * > &  parallelLoops,
int  vectorRank,
ArrayRef< int64_t >  fastestVaryingPattern 
)
static

Creates a vectorization pattern from the command line arguments.

Up to 3-D patterns are supported. If the command line argument requests a pattern of higher order, returns an empty pattern list which will conservatively result in no vectorization.

Definition at line 589 of file SuperVectorize.cpp.

References mlir::affine::matcher::For(), and isVectorizableLoopPtrFactory().

Referenced by vectorizeLoops().

◆ vectorizeAffineApplyOp()

static Operation* vectorizeAffineApplyOp ( AffineApplyOp  applyOp,
VectorizationState state 
)
static

We have no need to vectorize affine.apply.

However, we still need to generate it and replace the operands with values in valueScalarReplacement.

Definition at line 973 of file SuperVectorize.cpp.

◆ vectorizeAffineForOp()

static Operation* vectorizeAffineForOp ( AffineForOp  forOp,
VectorizationState state 
)
static

Vectorizes a loop with the vectorization strategy in 'state'.

A new loop is created and registered as replacement for the scalar loop. The builder's insertion point is set to the new loop's body so that subsequent vectorized operations are inserted into the new loop. If the loop is a vector dimension, the step of the newly created loop will reflect the vectorization factor used to vectorize that dimension.

Definition at line 1299 of file SuperVectorize.cpp.

◆ vectorizeAffineLoad()

static Operation* vectorizeAffineLoad ( AffineLoadOp  loadOp,
VectorizationState state 
)
static

Vectorizes an affine load with the vectorization strategy in 'state' by generating a 'vector.transfer_read' op with the proper permutation map inferred from the indices of the load.

The new 'vector.transfer_read' is registered as replacement of the scalar load. Returns the newly created 'vector.transfer_read' if vectorization was successful. Returns nullptr, otherwise.

Definition at line 1190 of file SuperVectorize.cpp.

◆ vectorizeAffineStore()

static Operation* vectorizeAffineStore ( AffineStoreOp  storeOp,
VectorizationState state 
)
static

Vectorizes an affine store with the vectorization strategy in 'state' by generating a 'vector.transfer_write' op with the proper permutation map inferred from the indices of the store.

The new 'vector.transfer_write' is registered as replacement of the scalar store. Returns the newly created 'vector.transfer_write' if vectorization was successful. Returns nullptr, otherwise.

Definition at line 1240 of file SuperVectorize.cpp.

◆ vectorizeAffineYieldOp()

static Operation* vectorizeAffineYieldOp ( AffineYieldOp  yieldOp,
VectorizationState state 
)
static

Vectorizes a yield operation by widening its types.

The builder's insertion point is set after the vectorized parent op to continue vectorizing the operations after the parent op. When vectorizing a reduction loop a mask may be used to prevent adding garbage values to the accumulator.

Definition at line 1453 of file SuperVectorize.cpp.

◆ vectorizeConstant()

static arith::ConstantOp vectorizeConstant ( arith::ConstantOp  constOp,
VectorizationState state 
)
static

Tries to transform a scalar constant into a vector constant.

Returns the vector constant if the scalar type is a valid vector element type. Returns nullptr, otherwise.

Definition at line 945 of file SuperVectorize.cpp.

◆ vectorizeLoopIfProfitable()

static void vectorizeLoopIfProfitable ( Operation *  loop,
unsigned  depthInPattern,
unsigned  patternDepth,
VectorizationStrategy *  strategy 
)
static

◆ vectorizeLoopNest()

static LogicalResult vectorizeLoopNest ( std::vector< SmallVector< AffineForOp, 2 >> &  loops,
const VectorizationStrategy strategy 
)
static

Internal implementation to vectorize affine loops from a single loop nest using an n-D vectorization strategy.

Definition at line 1562 of file SuperVectorize.cpp.

Referenced by mlir::affine::vectorizeAffineLoopNest().

◆ vectorizeLoops()

static void vectorizeLoops ( Operation parentOp,
DenseSet< Operation * > &  loops,
ArrayRef< int64_t >  vectorSizes,
ArrayRef< int64_t >  fastestVaryingPattern,
const ReductionLoopMap reductionLoops 
)
static

Internal implementation to vectorize affine loops in 'loops' using the n-D vectorization factors in 'vectorSizes'.

By default, each vectorization factor is applied inner-to-outer to the loops of each loop nest. 'fastestVaryingPattern' can be optionally used to provide a different loop vectorization order. reductionLoops can be provided to specify loops which can be vectorized along the reduction dimension.

Definition at line 1688 of file SuperVectorize.cpp.

References analyzeProfitability(), computeIntersectionBuckets(), makePattern(), mlir::affine::VectorizationStrategy::reductionLoops, vectorizeLoopIfProfitable(), vectorizeRootMatch(), and mlir::affine::VectorizationStrategy::vectorSizes.

◆ vectorizeOneOperation()

static Operation* vectorizeOneOperation ( Operation op,
VectorizationState state 
)
static

Encodes Operation-specific behavior for vectorization.

In general we assume that all operands of an op must be vectorized but this is not always true. In the future, it would be nice to have a trait that describes how a particular operation vectorizes. For now we implement the case distinction here. Returns a vectorized form of an operation or nullptr if vectorization fails.

Definition at line 1501 of file SuperVectorize.cpp.

◆ vectorizeOperand()

static Value vectorizeOperand ( Value  operand,
VectorizationState state 
)
static

Tries to vectorize a given operand by applying the following logic:

  1. if the defining operation has been already vectorized, operand is already in the proper vector form;
  2. if the operand is a constant, returns the vectorized form of the constant;
  3. if the operand is uniform, returns a vector broadcast of the op;
  4. otherwise, the vectorization of operand is not supported.

Newly created vector operations are registered in state as replacement for their scalar counterparts. In particular, this logic captures some of the use cases where definitions that are not scoped under the current pattern are needed to vectorize. One such example is top-level function constants that need to be splatted.

Returns an operand that has been vectorized to match state's strategy if vectorization is possible with the above logic. Returns nullptr otherwise.

TODO: handle more complex cases.

Definition at line 1145 of file SuperVectorize.cpp.

◆ vectorizeRootMatch()

static LogicalResult vectorizeRootMatch ( NestedMatch  m,
const VectorizationStrategy strategy 
)
static

Extracts the matched loops and vectorizes them following a topological order.

A new vector loop nest will be created if vectorization succeeds. The original loop nest won't be modified in any case.

Definition at line 1633 of file SuperVectorize.cpp.

Referenced by vectorizeLoops().

◆ vectorizeUniform()

static Operation* vectorizeUniform ( Value  uniformVal,
VectorizationState state 
)
static

Generates a broadcast op for the provided uniform value using the vectorization strategy in 'state'.

Definition at line 1114 of file SuperVectorize.cpp.

◆ vectorTransferPattern()

static NestedPattern& vectorTransferPattern ( )
static

Definition at line 611 of file SuperVectorize.cpp.

References mlir::affine::matcher::Op().

Referenced by isVectorizableLoopPtrFactory().

◆ verifyLoopNesting()

static LogicalResult verifyLoopNesting ( const std::vector< SmallVector< AffineForOp, 2 >> &  loops)
static

Verify that affine loops in 'loops' meet the nesting criteria expected by SuperVectorizer:

  • There must be at least one loop.
  • There must be a single root loop (nesting level 0).
  • Each loop at a given nesting level must be nested in a loop from a previous nesting level.

Definition at line 1800 of file SuperVectorize.cpp.

Referenced by mlir::affine::vectorizeAffineLoopNest().

◆ widenOp()

static Operation* widenOp ( Operation op,
VectorizationState state 
)
static

Vectorizes arbitrary operation by plain widening.

We apply generic type widening of all its results and retrieve the vector counterparts for all its operands.

Definition at line 1421 of file SuperVectorize.cpp.