MLIR  22.0.0git
VectorUtils.h
Go to the documentation of this file.
1 //===- VectorUtils.h - Vector Utilities -------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
10 #define MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
11 
19 #include "mlir/Support/LLVM.h"
20 
21 #include "llvm/ADT/DenseMap.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 
24 namespace mlir {
25 
26 // Forward declarations.
27 class AffineMap;
28 class Block;
29 class Location;
30 class OpBuilder;
31 class Operation;
32 class ShapedType;
33 class Value;
34 class VectorType;
35 class VectorTransferOpInterface;
36 
37 namespace affine {
38 class AffineApplyOp;
39 class AffineForOp;
40 } // namespace affine
41 
42 namespace vector {
43 /// Helper function that creates a memref::DimOp or tensor::DimOp depending on
44 /// the type of `source`.
45 Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);
46 
47 /// Returns two dims that are greater than one if the transposition is applied
48 /// on a 2D slice. Otherwise, returns a failure.
49 FailureOr<std::pair<int, int>> isTranspose2DSlice(vector::TransposeOp op);
50 
51 /// Return true if `vectorType` is a contiguous slice of `memrefType`,
52 /// in the sense that it can be read/written from/to a contiguous area
53 /// of the memref.
54 ///
55 /// The leading unit dimensions of the vector type are ignored as they
56 /// are not relevant to the result. Let N be the number of the vector
57 /// dimensions after ignoring a leading sequence of unit ones.
58 ///
59 /// For `vectorType` to be a contiguous slice of `memrefType`
60 /// a) the N trailing dimensions of `memrefType` must be contiguous, and
61 /// b) the N-1 trailing dimensions of `vectorType` and `memrefType`
62 /// must match.
63 ///
64 /// Examples:
65 ///
66 /// Ex.1 contiguous slice, perfect match
67 /// vector<4x3x2xi32> from memref<5x4x3x2xi32>
68 /// Ex.2 contiguous slice, the leading dim does not match (2 != 4)
69 /// vector<2x3x2xi32> from memref<5x4x3x2xi32>
70 /// Ex.3 non-contiguous slice, 2 != 3
71 /// vector<2x2x2xi32> from memref<5x4x3x2xi32>
72 /// Ex.4 contiguous slice, leading unit dimension of the vector ignored,
73 /// 2 != 3 (allowed)
74 /// vector<1x2x2xi32> from memref<5x4x3x2xi32>
75 /// Ex.5. contiguous slice, leading two unit dims of the vector ignored,
76 /// 2 != 3 (allowed)
77 /// vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
78 /// Ex.6. non-contiguous slice, 2 != 3, no leading sequence of unit dims
79 /// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>)
80 /// Ex.7 contiguous slice, memref needs to be contiguous only in the last
81 /// dimension
82 /// vector<1x1x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>>
83 /// Ex.8 non-contiguous slice, memref needs to be contiguous in the last
84 /// two dimensions, and it isn't
85 /// vector<1x2x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>>
86 bool isContiguousSlice(MemRefType memrefType, VectorType vectorType);
87 
88 /// Returns an iterator for all positions in the leading dimensions of `vType`
89 /// up to the `targetRank`. If any leading dimension before the `targetRank` is
90 /// scalable (so cannot be unrolled), it will return an iterator for positions
91 /// up to the first scalable dimension.
92 ///
93 /// If no leading dimensions can be unrolled an empty optional will be returned.
94 ///
95 /// Examples:
96 ///
97 /// For vType = vector<2x3x4> and targetRank = 1
98 ///
99 /// The resulting iterator will yield:
100 /// [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]
101 ///
102 /// For vType = vector<3x[4]x5> and targetRank = 0
103 ///
104 /// The scalable dimension blocks unrolling so the iterator yields only:
105 /// [0], [1], [2]
106 ///
107 std::optional<StaticTileOffsetRange>
108 createUnrollIterator(VectorType vType, int64_t targetRank = 1);
109 
110 /// Returns a functor (int64_t -> Value) which returns a constant vscale
111 /// multiple.
112 ///
113 /// Example:
114 /// ```c++
115 /// auto createVscaleMultiple = makeVscaleConstantBuilder(rewriter, loc);
116 /// auto c4Vscale = createVscaleMultiple(4); // 4 * vector.vscale
117 /// ```
119  Value vscale = nullptr;
120  return [loc, vscale, &rewriter](int64_t multiplier) mutable {
121  if (!vscale)
122  vscale = vector::VectorScaleOp::create(rewriter, loc);
123  return arith::MulIOp::create(
124  rewriter, loc, vscale,
125  arith::ConstantIndexOp::create(rewriter, loc, multiplier));
126  };
127 }
128 
129 /// Returns a range over the dims (size and scalability) of a VectorType.
130 inline auto getDims(VectorType vType) {
131  return llvm::zip_equal(vType.getShape(), vType.getScalableDims());
132 }
133 
134 /// A wrapper for getMixedSizes for vector.transfer_read and
135 /// vector.transfer_write Ops (for source and destination, respectively).
136 ///
137 /// Tensor and MemRef types implement their own, very similar version of
138 /// getMixedSizes. This method will call the appropriate version (depending on
139 /// `hasTensorSemantics`). It will also automatically extract the operand for
140 /// which to call it on (source for "read" and destination for "write" ops).
142  Operation *xfer,
143  RewriterBase &rewriter);
144 
145 /// A pattern for ops that implement `MaskableOpInterface` and that _might_ be
146 /// masked (i.e. inside `vector.mask` Op region). In particular:
147 /// 1. Matches `SourceOp` operation, Op.
148 /// 2.1. If Op is masked, retrieves the masking Op, maskOp, and updates the
149 /// insertion point to avoid inserting new ops into the `vector.mask` Op
150 /// region (which only allows one Op).
151 /// 2.2 If Op is not masked, this step is skipped.
152 /// 3. Invokes `matchAndRewriteMaskableOp` on Op and optionally maskOp if
153 /// found in step 2.1.
154 ///
155 /// This wrapper frees patterns from re-implementing the logic to update the
156 /// insertion point when a maskable Op is masked. Such patterns are still
157 /// responsible for providing an updated ("rewritten") version of:
158 /// a. the source Op when mask _is not_ present,
159 /// b. the source Op and the masking Op when mask _is_ present.
160 /// To use this pattern, implement `matchAndRewriteMaskableOp`. Note that
161 /// the return value will depend on the case above.
162 template <class SourceOp>
165 
166 private:
167  LogicalResult matchAndRewrite(SourceOp sourceOp,
168  PatternRewriter &rewriter) const final {
169  auto maskableOp = dyn_cast<MaskableOpInterface>(sourceOp.getOperation());
170  if (!maskableOp)
171  return failure();
172 
173  Operation *rootOp = sourceOp;
174 
175  // If this Op is masked, update the insertion point to avoid inserting into
176  // the vector.mask Op region.
177  OpBuilder::InsertionGuard guard(rewriter);
178  MaskingOpInterface maskOp;
179  if (maskableOp.isMasked()) {
180  maskOp = maskableOp.getMaskingOp();
181  rewriter.setInsertionPoint(maskOp);
182  rootOp = maskOp;
183  }
184 
185  FailureOr<Value> newOp =
186  matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter);
187  if (failed(newOp))
188  return failure();
189 
190  // Rewriting succeeded but there are no values to replace.
191  if (rootOp->getNumResults() == 0) {
192  rewriter.eraseOp(rootOp);
193  } else {
194  assert(*newOp != Value() &&
195  "Cannot replace an op's use with an empty value.");
196  rewriter.replaceOp(rootOp, *newOp);
197  }
198  return success();
199  }
200 
201 public:
202  // Matches `sourceOp` that can potentially be masked with `maskingOp`. If the
203  // latter is present, returns a replacement for `maskingOp`. Otherwise,
204  // returns a replacement for `sourceOp`.
205  virtual FailureOr<Value>
206  matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp,
207  PatternRewriter &rewriter) const = 0;
208 };
209 
210 /// Returns true if the input Vector type can be linearized.
211 ///
212 /// Linearization is meant in the sense of flattening vectors, e.g.:
213 /// * vector<NxMxKxi32> -> vector<N*M*Kxi32>
214 /// In this sense, Vectors that are either:
215 /// * already linearized, or
216 /// * contain more than 1 scalable dimensions,
217 /// are not linearizable.
218 bool isLinearizableVector(VectorType type);
219 
220 /// Creates a TransferReadOp from `source`.
221 ///
222 /// The shape of the vector to read is specified via `inputVectorSizes`. If the
223 /// shape of the output vector differs from the shape of the value being read,
224 /// masking is used to avoid out-of-bounds accesses. Set
225 /// `useInBoundsInsteadOfMasking` to `true` to use the "in_bounds" attribute
226 /// instead of explicit masks.
227 ///
228 /// Note: all read offsets are set to 0.
229 Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
230  ArrayRef<int64_t> inputVectorSizes, Value padValue,
231  bool useInBoundsInsteadOfMasking = false,
232  ArrayRef<bool> inputScalableVecDims = {});
233 
234 /// Returns success if `inputVectorSizes` is a valid masking configuraion for
235 /// given `shape`, i.e., it meets:
236 /// 1. The numbers of elements in both array are equal.
237 /// 2. `inputVectorSizes` does not have dynamic dimensions.
238 /// 3. All the values in `inputVectorSizes` are greater than or equal to
239 /// static sizes in `shape`.
240 LogicalResult isValidMaskedInputVector(ArrayRef<int64_t> shape,
241  ArrayRef<int64_t> inputVectorSizes);
242 
243 /// Generic utility for unrolling n-D vector operations to (n-1)-D operations.
244 /// This handles the common pattern of:
245 /// 1. Check if already 1-D. If so, return failure.
246 /// 2. Check for scalable dimensions. If so, return failure.
247 /// 3. Create poison initialized result.
248 /// 4. Loop through the outermost dimension, execute the UnrollVectorOpFn to
249 /// create sub vectors.
250 /// 5. Insert the sub vectors back into the final vector.
251 /// 6. Replace the original op with the new result.
253  function_ref<Value(PatternRewriter &, Location, VectorType, int64_t)>;
254 
255 LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter,
256  UnrollVectorOpFn unrollFn);
257 
258 } // namespace vector
259 
260 /// Constructs a permutation map of invariant memref indices to vector
261 /// dimension.
262 ///
263 /// If no index is found to be invariant, 0 is added to the permutation_map and
264 /// corresponds to a vector broadcast along that dimension.
265 ///
266 /// The implementation uses the knowledge of the mapping of loops to
267 /// vector dimension. `loopToVectorDim` carries this information as a map with:
268 /// - keys representing "vectorized enclosing loops";
269 /// - values representing the corresponding vector dimension.
270 /// Note that loopToVectorDim is a whole function map from which only enclosing
271 /// loop information is extracted.
272 ///
273 /// Prerequisites: `indices` belong to a vectorizable load or store operation
274 /// (i.e. at most one invariant index along each AffineForOp of
275 /// `loopToVectorDim`). `insertPoint` is the insertion point for the vectorized
276 /// load or store operation.
277 ///
278 /// Example 1:
279 /// The following MLIR snippet:
280 ///
281 /// ```mlir
282 /// affine.for %i3 = 0 to %0 {
283 /// affine.for %i4 = 0 to %1 {
284 /// affine.for %i5 = 0 to %2 {
285 /// %a5 = load %arg0[%i4, %i5, %i3] : memref<?x?x?xf32>
286 /// }}}
287 /// ```
288 ///
289 /// may vectorize with {permutation_map: (d0, d1, d2) -> (d2, d1)} into:
290 ///
291 /// ```mlir
292 /// affine.for %i3 = 0 to %0 step 32 {
293 /// affine.for %i4 = 0 to %1 {
294 /// affine.for %i5 = 0 to %2 step 256 {
295 /// %4 = vector.transfer_read %arg0, %i4, %i5, %i3
296 /// {permutation_map: (d0, d1, d2) -> (d2, d1)} :
297 /// (memref<?x?x?xf32>, index, index) -> vector<32x256xf32>
298 /// }}}
299 /// ```
300 ///
301 /// Meaning that vector.transfer_read will be responsible for reading the slice:
302 /// `%arg0[%i4, %i5:%15+256, %i3:%i3+32]` into vector<32x256xf32>.
303 ///
304 /// Example 2:
305 /// The following MLIR snippet:
306 ///
307 /// ```mlir
308 /// %cst0 = arith.constant 0 : index
309 /// affine.for %i0 = 0 to %0 {
310 /// %a0 = load %arg0[%cst0, %cst0] : memref<?x?xf32>
311 /// }
312 /// ```
313 ///
314 /// may vectorize with {permutation_map: (d0) -> (0)} into:
315 ///
316 /// ```mlir
317 /// affine.for %i0 = 0 to %0 step 128 {
318 /// %3 = vector.transfer_read %arg0, %c0_0, %c0_0
319 /// {permutation_map: (d0, d1) -> (0)} :
320 /// (memref<?x?xf32>, index, index) -> vector<128xf32>
321 /// }
322 /// ````
323 ///
324 /// Meaning that vector.transfer_read will be responsible of reading the slice
325 /// `%arg0[%c0, %c0]` into vector<128xf32> which needs a 1-D vector broadcast.
326 ///
327 AffineMap
328 makePermutationMap(Block *insertPoint, ArrayRef<Value> indices,
329  const DenseMap<Operation *, unsigned> &loopToVectorDim);
330 AffineMap
331 makePermutationMap(Operation *insertPoint, ArrayRef<Value> indices,
332  const DenseMap<Operation *, unsigned> &loopToVectorDim);
333 
334 namespace matcher {
335 
336 /// Matches vector.transfer_read, vector.transfer_write and ops that return a
337 /// vector type that is a multiple of the sub-vector type. This allows passing
338 /// over other smaller vector types in the function and avoids interfering with
339 /// operations on those.
340 /// This is a first approximation, it can easily be extended in the future.
341 /// TODO: this could all be much simpler if we added a bit that a vector type to
342 /// mark that a vector is a strict super-vector but it still does not warrant
343 /// adding even 1 extra bit in the IR for now.
344 bool operatesOnSuperVectorsOf(Operation &op, VectorType subVectorType);
345 
346 } // namespace matcher
347 } // namespace mlir
348 
349 #endif // MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
static AffineMap makePermutationMap(ArrayRef< Value > indices, const DenseMap< Operation *, unsigned > &enclosingLoopToVectorDim)
Constructs a permutation map from memref indices to vector dimension.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
Block represents an ordered list of Operations.
Definition: Block.h:33
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
bool hasTensorSemantics(Operation *op)
Return "true" if the given op has tensor semantics and should be bufferized.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
bool isContiguousSlice(MemRefType memrefType, VectorType vectorType)
Return true if vectorType is a contiguous slice of memrefType, in the sense that it can be read/writt...
auto getDims(VectorType vType)
Returns a range over the dims (size and scalability) of a VectorType.
Definition: VectorUtils.h:130
LogicalResult isValidMaskedInputVector(ArrayRef< int64_t > shape, ArrayRef< int64_t > inputVectorSizes)
Returns success if inputVectorSizes is a valid masking configuration for given shape,...
FailureOr< std::pair< int, int > > isTranspose2DSlice(vector::TransposeOp op)
Returns two dims that are greater than one if the transposition is applied on a 2D slice.
Definition: VectorUtils.cpp:82
std::optional< StaticTileOffsetRange > createUnrollIterator(VectorType vType, int64_t targetRank=1)
Returns an iterator for all positions in the leading dimensions of vType up to the targetRank.
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim)
Helper function that creates a memref::DimOp or tensor::DimOp depending on the type of source.
Definition: VectorUtils.cpp:39
bool isLinearizableVector(VectorType type)
Returns true if the input Vector type can be linearized.
auto makeVscaleConstantBuilder(PatternRewriter &rewriter, Location loc)
Returns a functor (int64_t -> Value) which returns a constant vscale multiple.
Definition: VectorUtils.h:118
SmallVector< OpFoldResult > getMixedSizesXfer(bool hasTensorSemantics, Operation *xfer, RewriterBase &rewriter)
A wrapper for getMixedSizes for vector.transfer_read and vector.transfer_write Ops (for source and de...
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source, ArrayRef< int64_t > inputVectorSizes, Value padValue, bool useInBoundsInsteadOfMasking=false, ArrayRef< bool > inputScalableVecDims={})
Creates a TransferReadOp from source.
LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter, UnrollVectorOpFn unrollFn)
Include the generated interface declarations.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e.
Definition: VectorUtils.h:163
virtual FailureOr< Value > matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp, PatternRewriter &rewriter) const =0