MLIR  20.0.0git
VectorUtils.h
Go to the documentation of this file.
1 //===- VectorUtils.h - Vector Utilities -------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
10 #define MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
11 
18 #include "mlir/Support/LLVM.h"
19 
20 #include "llvm/ADT/DenseMap.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 
23 namespace mlir {
24 
25 // Forward declarations.
26 class AffineMap;
27 class Block;
28 class Location;
29 class OpBuilder;
30 class Operation;
31 class ShapedType;
32 class Value;
33 class VectorType;
34 class VectorTransferOpInterface;
35 
36 namespace affine {
37 class AffineApplyOp;
38 class AffineForOp;
39 } // namespace affine
40 
41 namespace vector {
42 /// Helper function that creates a memref::DimOp or tensor::DimOp depending on
43 /// the type of `source`.
44 Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);
45 
46 /// Returns two dims that are greater than one if the transposition is applied
47 /// on a 2D slice. Otherwise, returns failure.
48 FailureOr<std::pair<int, int>> isTranspose2DSlice(vector::TransposeOp op);
49 
50 /// Return true if `vectorType` is a contiguous slice of `memrefType`.
51 ///
52 /// Only the N = vectorType.getRank() trailing dims of `memrefType` are
53 /// checked (the other dims are not relevant). Note that for `vectorType` to be
54 /// a contiguous slice of `memrefType`, the trailing dims of the latter have
55 /// to be contiguous - this is checked by looking at the corresponding strides.
56 ///
57 /// There may be some restrictions on the leading dims of `vectorType`:
58 ///
59 /// Case 1. If all the trailing dims of `vectorType` match the trailing dims
60 /// of `memrefType` then the leading dim of `vectorType` can be
61 /// arbitrary.
62 ///
63 /// Ex. 1.1 contiguous slice, perfect match
64 /// vector<4x3x2xi32> from memref<5x4x3x2xi32>
65 /// Ex. 1.2 contiguous slice, the leading dim does not match (2 != 4)
66 /// vector<2x3x2xi32> from memref<5x4x3x2xi32>
67 ///
68 /// Case 2. If an "internal" dim of `vectorType` does not match the
69 /// corresponding trailing dim in `memrefType` then the remaining
70 /// leading dims of `vectorType` have to be 1 (the first non-matching
71 /// dim can be arbitrary).
72 ///
73 /// Ex. 2.1 non-contiguous slice, 2 != 3 and the leading dim != <1>
74 /// vector<2x2x2xi32> from memref<5x4x3x2xi32>
75 /// Ex. 2.2 contiguous slice, 2 != 3 and the leading dim == <1>
76 /// vector<1x2x2xi32> from memref<5x4x3x2xi32>
77 /// Ex. 2.3. contiguous slice, 2 != 3 and the leading dims == <1x1>
78 /// vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
79 /// Ex. 2.4. non-contiguous slice, 2 != 3 and the leading dims != <1x1>
80 /// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>
81 bool isContiguousSlice(MemRefType memrefType, VectorType vectorType);
82 
83 /// Returns an iterator for all positions in the leading dimensions of `vType`
84 /// up to the `targetRank`. If any leading dimension before the `targetRank` is
85 /// scalable (and so cannot be unrolled), it will return an iterator for
86 /// positions up to the first scalable dimension.
87 ///
88 /// If no leading dimensions can be unrolled, an empty optional is returned.
89 ///
90 /// Examples:
91 ///
92 /// For vType = vector<2x3x4> and targetRank = 1
93 ///
94 /// The resulting iterator will yield:
95 /// [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]
96 ///
97 /// For vType = vector<3x[4]x5> and targetRank = 0
98 ///
99 /// The scalable dimension blocks unrolling so the iterator yields only:
100 /// [0], [1], [2]
101 ///
102 std::optional<StaticTileOffsetRange>
103 createUnrollIterator(VectorType vType, int64_t targetRank = 1);
104 
105 /// Returns a functor (int64_t -> Value) which returns a constant vscale
106 /// multiple.
///
/// The `vector::VectorScaleOp` is created lazily on the first call and cached
/// inside the functor. Every call then emits a fresh `arith::ConstantIndexOp`
/// holding `multiplier` and an `arith::MulIOp` through `rewriter`. The functor
/// captures `rewriter` by reference, so it must not outlive it.
/// NOTE(review): because the vscale value is cached, later calls presumably
/// need insertion points dominated by the first call's - confirm at call sites.
107 ///
108 /// Example:
109 /// ```c++
110 /// auto createVscaleMultiple = makeVscaleConstantBuilder(rewriter, loc);
111 /// auto c4Vscale = createVscaleMultiple(4); // 4 * vector.vscale
112 /// ```
114  Value vscale = nullptr;
115  return [loc, vscale, &rewriter](int64_t multiplier) mutable {
  // Materialize vector.vscale only once per functor; reuse it afterwards.
116  if (!vscale)
117  vscale = rewriter.create<vector::VectorScaleOp>(loc);
118  return rewriter.create<arith::MulIOp>(
119  loc, vscale, rewriter.create<arith::ConstantIndexOp>(loc, multiplier));
120  };
121 }
122 
123 /// Returns a range over the dims (size and scalability) of a VectorType.
124 inline auto getDims(VectorType vType) {
125  return llvm::zip_equal(vType.getShape(), vType.getScalableDims());
126 }
127 
128 /// A wrapper for getMixedSizes for vector.transfer_read and
129 /// vector.transfer_write Ops (for source and destination, respectively).
130 ///
131 /// Tensor and MemRef types implement their own, very similar versions of
132 /// getMixedSizes. This method will call the appropriate version (depending on
133 /// `hasTensorSemantics`). It will also automatically extract the operand on
134 /// which to call it (source for "read" and destination for "write" ops).
136  Operation *xfer,
137  RewriterBase &rewriter);
138 
139 /// A pattern for ops that implement `MaskableOpInterface` and that _might_ be
140 /// masked (i.e. inside `vector.mask` Op region). In particular:
141 /// 1. Matches `SourceOp` operation, Op.
142 /// 2.1. If Op is masked, retrieves the masking Op, maskOp, and updates the
143 /// insertion point to avoid inserting new ops into the `vector.mask` Op
144 /// region (which only allows one Op).
145 /// 2.2 If Op is not masked, this step is skipped.
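146 /// 3. Invokes `matchAndRewriteMaskableOp` on Op and optionally maskOp if
147 /// found in step 2.1.
/// 4. On success, replaces maskOp (if found) or otherwise Op with the returned
/// value, or erases that op if it has no results.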
148 ///
149 /// This wrapper frees patterns from re-implementing the logic to update the
150 /// insertion point when a maskable Op is masked. Such patterns are still
151 /// responsible for providing an updated ("rewritten") version of:
152 /// a. the source Op when mask _is not_ present,
153 /// b. the source Op and the masking Op when mask _is_ present.
154 /// To use this pattern, implement `matchAndRewriteMaskableOp`. Note that
155 /// the return value will depend on the case above.
156 template <class SourceOp>
159 
160 private:
161  LogicalResult matchAndRewrite(SourceOp sourceOp,
162  PatternRewriter &rewriter) const final {
163  auto maskableOp = dyn_cast<MaskableOpInterface>(sourceOp.getOperation());
164  if (!maskableOp)
165  return failure();
166 
167  Operation *rootOp = sourceOp;
168 
169  // If this Op is masked, update the insertion point to avoid inserting into
170  // the vector.mask Op region.
171  OpBuilder::InsertionGuard guard(rewriter);
172  MaskingOpInterface maskOp;
173  if (maskableOp.isMasked()) {
174  maskOp = maskableOp.getMaskingOp();
175  rewriter.setInsertionPoint(maskOp);
176  rootOp = maskOp;
177  }
178 
179  FailureOr<Value> newOp =
180  matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter);
181  if (failed(newOp))
182  return failure();
183 
184  // Rewriting succeeded but there are no values to replace.
185  if (rootOp->getNumResults() == 0) {
186  rewriter.eraseOp(rootOp);
187  } else {
188  assert(*newOp != Value() &&
189  "Cannot replace an op's use with an empty value.");
190  rewriter.replaceOp(rootOp, *newOp);
191  }
192  return success();
193  }
194 
195 public:
196  // Matches `sourceOp` that can potentially be masked with `maskingOp`. If the
197  // latter is present, returns a replacement for `maskingOp`. Otherwise,
198  // returns a replacement for `sourceOp`.
 //
 // `maskingOp` is null when `sourceOp` is not masked. When it is non-null, the
 // rewriter's insertion point has already been moved in front of it. Return
 // failure() if the op does not match. On success, matchAndRewrite replaces
 // the root op (maskingOp if present, else sourceOp) with the returned value.
 // If the root op has no results, it is erased instead, and the returned value
 // may be null. Otherwise the returned value must be non-null (asserted).
199  virtual FailureOr<Value>
200  matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp,
201  PatternRewriter &rewriter) const = 0;
202 };
203 
204 /// Returns true if the input Vector type can be linearized.
205 ///
206 /// Linearization is meant in the sense of flattening vectors, e.g.:
207 /// * vector<NxMxKxi32> -> vector<N*M*Kxi32>
208 /// In this sense, vectors that either:
209 /// * are already linearized, or
210 /// * contain more than one scalable dimension,
211 /// are not linearizable.
212 bool isLinearizableVector(VectorType type);
213 
214 /// Create a TransferReadOp from `source` with static shape `readShape`. If the
215 /// vector type for the read is not the same as the type of `source`, then a
216 /// mask is created on the read if use of a mask is specified or the bounds on
217 /// a dimension are different.
218 ///
219 /// `useInBoundsInsteadOfMasking`: if false, the in_bounds values of the read
220 /// are set based on the ranked dimensions of the source and destination
221 /// tensors, and that is what determines whether masking is done.
222 /// NOTE(review): the parameter name suggests the opposite polarity (when true,
/// in_bounds is used instead of a mask) - confirm against the implementation.
/// `padValue` is presumably the padding value for out-of-bounds elements -
/// TODO confirm.
223 ///
224 /// Note that the internal `vector::TransferReadOp` always reads at index zero
225 /// for each dimension of the passed-in tensor.
226 Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
227  ArrayRef<int64_t> readShape, Value padValue,
228  bool useInBoundsInsteadOfMasking);
229 
230 /// Returns success if `inputVectorSizes` is a valid masking configuration for
231 /// the given `shape`, i.e., if it meets all of the following (failure otherwise):
232 /// 1. The numbers of elements in both arrays are equal.
233 /// 2. `inputVectorSizes` does not have dynamic dimensions.
234 /// 3. All the values in `inputVectorSizes` are greater than or equal to the
235 /// corresponding static sizes in `shape`.
236 LogicalResult isValidMaskedInputVector(ArrayRef<int64_t> shape,
237  ArrayRef<int64_t> inputVectorSizes);
238 } // namespace vector
239 
240 /// Constructs a permutation map of invariant memref indices to vector
241 /// dimension.
242 ///
243 /// If no index is found to be invariant, 0 is added to the permutation_map and
244 /// corresponds to a vector broadcast along that dimension.
245 ///
246 /// The implementation uses the knowledge of the mapping of loops to
247 /// vector dimension. `loopToVectorDim` carries this information as a map with:
248 /// - keys representing "vectorized enclosing loops";
249 /// - values representing the corresponding vector dimension.
250 /// Note that loopToVectorDim is a whole function map from which only enclosing
251 /// loop information is extracted.
252 ///
253 /// Prerequisites: `indices` belong to a vectorizable load or store operation
254 /// (i.e. at most one invariant index along each AffineForOp of
255 /// `loopToVectorDim`). `insertPoint` is the insertion point for the vectorized
256 /// load or store operation.
257 ///
258 /// Example 1:
259 /// The following MLIR snippet:
260 ///
261 /// ```mlir
262 /// affine.for %i3 = 0 to %0 {
263 /// affine.for %i4 = 0 to %1 {
264 /// affine.for %i5 = 0 to %2 {
265 /// %a5 = load %arg0[%i4, %i5, %i3] : memref<?x?x?xf32>
266 /// }}}
267 /// ```
268 ///
269 /// may vectorize with {permutation_map: (d0, d1, d2) -> (d2, d1)} into:
270 ///
271 /// ```mlir
272 /// affine.for %i3 = 0 to %0 step 32 {
273 /// affine.for %i4 = 0 to %1 {
274 /// affine.for %i5 = 0 to %2 step 256 {
275 /// %4 = vector.transfer_read %arg0, %i4, %i5, %i3
276 /// {permutation_map: (d0, d1, d2) -> (d2, d1)} :
277 /// (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32>
278 /// }}}
279 /// ```
280 ///
281 /// Meaning that vector.transfer_read will be responsible for reading the slice:
282 /// `%arg0[%i4, %i5:%i5+256, %i3:%i3+32]` into vector<32x256xf32>.
283 ///
284 /// Example 2:
285 /// The following MLIR snippet:
286 ///
287 /// ```mlir
288 /// %cst0 = arith.constant 0 : index
289 /// affine.for %i0 = 0 to %0 {
290 /// %a0 = load %arg0[%cst0, %cst0] : memref<?x?xf32>
291 /// }
292 /// ```
293 ///
294 /// may vectorize with {permutation_map: (d0, d1) -> (0)} into:
295 ///
296 /// ```mlir
297 /// affine.for %i0 = 0 to %0 step 128 {
298 /// %3 = vector.transfer_read %arg0, %c0_0, %c0_0
299 /// {permutation_map: (d0, d1) -> (0)} :
300 /// (memref<?x?xf32>, index, index) -> vector<128xf32>
301 /// }
302 /// ```
303 ///
304 /// Meaning that vector.transfer_read will be responsible for reading the slice
305 /// `%arg0[%c0, %c0]` into vector<128xf32> which needs a 1-D vector broadcast.
306 ///
/// Two overloads are provided: `insertPoint` is given either as a `Block *` or
/// as an `Operation *`.
307 AffineMap
308 makePermutationMap(Block *insertPoint, ArrayRef<Value> indices,
309  const DenseMap<Operation *, unsigned> &loopToVectorDim);
310 AffineMap
311 makePermutationMap(Operation *insertPoint, ArrayRef<Value> indices,
312  const DenseMap<Operation *, unsigned> &loopToVectorDim);
313 
314 namespace matcher {
315 
316 /// Matches vector.transfer_read, vector.transfer_write and ops that return a
317 /// vector type that is a multiple of the sub-vector type. This allows passing
318 /// over other smaller vector types in the function and avoids interfering with
319 /// operations on those.
320 /// This is a first approximation; it can easily be extended in the future.
321 /// TODO: this could all be much simpler if we added a bit to the vector type to
322 /// mark that a vector is a strict super-vector, but that still does not warrant
323 /// adding even 1 extra bit in the IR for now.
324 bool operatesOnSuperVectorsOf(Operation &op, VectorType subVectorType);
325 
326 } // namespace matcher
327 } // namespace mlir
328 
329 #endif // MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
static AffineMap makePermutationMap(ArrayRef< Value > indices, const DenseMap< Operation *, unsigned > &enclosingLoopToVectorDim)
Constructs a permutation map from memref indices to vector dimension.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
Block represents an ordered list of Operations.
Definition: Block.h:33
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:356
This class helps build Operations.
Definition: Builders.h:215
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:406
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:93
bool hasTensorSemantics(Operation *op)
Return "true" if the given op has tensor semantics and should be bufferized.
bool isContiguousSlice(MemRefType memrefType, VectorType vectorType)
Return true if vectorType is a contiguous slice of memrefType.
auto getDims(VectorType vType)
Returns a range over the dims (size and scalability) of a VectorType.
Definition: VectorUtils.h:124
LogicalResult isValidMaskedInputVector(ArrayRef< int64_t > shape, ArrayRef< int64_t > inputVectorSizes)
Returns success if inputVectorSizes is a valid masking configuration for the given shape,...
FailureOr< std::pair< int, int > > isTranspose2DSlice(vector::TransposeOp op)
Returns two dims that are greater than one if the transposition is applied on a 2D slice.
Definition: VectorUtils.cpp:84
std::optional< StaticTileOffsetRange > createUnrollIterator(VectorType vType, int64_t targetRank=1)
Returns an iterator for all positions in the leading dimensions of vType up to the targetRank.
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim)
Helper function that creates a memref::DimOp or tensor::DimOp depending on the type of source.
Definition: VectorUtils.cpp:41
bool isLinearizableVector(VectorType type)
Returns true if the input Vector type can be linearized.
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source, ArrayRef< int64_t > readShape, Value padValue, bool useInBoundsInsteadOfMasking)
Create a TransferReadOp from source with static shape readShape.
auto makeVscaleConstantBuilder(PatternRewriter &rewriter, Location loc)
Returns a functor (int64_t -> Value) which returns a constant vscale multiple.
Definition: VectorUtils.h:113
SmallVector< OpFoldResult > getMixedSizesXfer(bool hasTensorSemantics, Operation *xfer, RewriterBase &rewriter)
A wrapper for getMixedSizes for vector.transfer_read and vector.transfer_write Ops (for source and de...
Include the generated interface declarations.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e.
Definition: VectorUtils.h:157
virtual FailureOr< Value > matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp, PatternRewriter &rewriter) const =0