MLIR  20.0.0git
VectorUtils.h
Go to the documentation of this file.
//===- VectorUtils.h - Vector Utilities -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
#define MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
11 
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Support/LLVM.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"

23 namespace mlir {
24 
25 // Forward declarations.
26 class AffineMap;
27 class Block;
28 class Location;
29 class OpBuilder;
30 class Operation;
31 class ShapedType;
32 class Value;
33 class VectorType;
34 class VectorTransferOpInterface;
35 
36 namespace affine {
37 class AffineApplyOp;
38 class AffineForOp;
39 } // namespace affine
40 
41 namespace vector {
42 /// Helper function that creates a memref::DimOp or tensor::DimOp depending on
43 /// the type of `source`.
44 Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);
45 
46 /// Returns two dims that are greater than one if the transposition is applied
47 /// on a 2D slice. Otherwise, returns a failure.
48 FailureOr<std::pair<int, int>> isTranspose2DSlice(vector::TransposeOp op);
49 
50 /// Return true if `vectorType` is a contiguous slice of `memrefType`.
51 ///
52 /// Only the N = vectorType.getRank() trailing dims of `memrefType` are
53 /// checked (the other dims are not relevant). Note that for `vectorType` to be
54 /// a contiguous slice of `memrefType`, the trailing dims of the latter have
55 /// to be contiguous - this is checked by looking at the corresponding strides.
56 ///
57 /// There might be some restriction on the leading dim of `VectorType`:
58 ///
59 /// Case 1. If all the trailing dims of `vectorType` match the trailing dims
60 /// of `memrefType` then the leading dim of `vectorType` can be
61 /// arbitrary.
62 ///
63 /// Ex. 1.1 contiguous slice, perfect match
64 /// vector<4x3x2xi32> from memref<5x4x3x2xi32>
65 /// Ex. 1.2 contiguous slice, the leading dim does not match (2 != 4)
66 /// vector<2x3x2xi32> from memref<5x4x3x2xi32>
67 ///
68 /// Case 2. If an "internal" dim of `vectorType` does not match the
69 /// corresponding trailing dim in `memrefType` then the remaining
70 /// leading dims of `vectorType` have to be 1 (the first non-matching
71 /// dim can be arbitrary).
72 ///
73 /// Ex. 2.1 non-contiguous slice, 2 != 3 and the leading dim != <1>
74 /// vector<2x2x2xi32> from memref<5x4x3x2xi32>
75 /// Ex. 2.2 contiguous slice, 2 != 3 and the leading dim == <1>
76 /// vector<1x2x2xi32> from memref<5x4x3x2xi32>
77 /// Ex. 2.3. contiguous slice, 2 != 3 and the leading dims == <1x1>
78 /// vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
79 /// Ex. 2.4. non-contiguous slice, 2 != 3 and the leading dims != <1x1>
80 /// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>)
81 bool isContiguousSlice(MemRefType memrefType, VectorType vectorType);
82 
83 /// Returns an iterator for all positions in the leading dimensions of `vType`
84 /// up to the `targetRank`. If any leading dimension before the `targetRank` is
85 /// scalable (so cannot be unrolled), it will return an iterator for positions
86 /// up to the first scalable dimension.
87 ///
88 /// If no leading dimensions can be unrolled an empty optional will be returned.
89 ///
90 /// Examples:
91 ///
92 /// For vType = vector<2x3x4> and targetRank = 1
93 ///
94 /// The resulting iterator will yield:
95 /// [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]
96 ///
97 /// For vType = vector<3x[4]x5> and targetRank = 0
98 ///
99 /// The scalable dimension blocks unrolling so the iterator yields only:
100 /// [0], [1], [2]
101 ///
102 std::optional<StaticTileOffsetRange>
103 createUnrollIterator(VectorType vType, int64_t targetRank = 1);
104 
105 /// Returns a functor (int64_t -> Value) which returns a constant vscale
106 /// multiple.
107 ///
108 /// Example:
109 /// ```c++
110 /// auto createVscaleMultiple = makeVscaleConstantBuilder(rewriter, loc);
111 /// auto c4Vscale = createVscaleMultiple(4); // 4 * vector.vscale
112 /// ```
114  Value vscale = nullptr;
115  return [loc, vscale, &rewriter](int64_t multiplier) mutable {
116  if (!vscale)
117  vscale = rewriter.create<vector::VectorScaleOp>(loc);
118  return rewriter.create<arith::MulIOp>(
119  loc, vscale, rewriter.create<arith::ConstantIndexOp>(loc, multiplier));
120  };
121 }
122 
123 /// A wrapper for getMixedSizes for vector.transfer_read and
124 /// vector.transfer_write Ops (for source and destination, respectively).
125 ///
126 /// Tensor and MemRef types implement their own, very similar version of
127 /// getMixedSizes. This method will call the appropriate version (depending on
128 /// `hasTensorSemantics`). It will also automatically extract the operand for
129 /// which to call it on (source for "read" and destination for "write" ops).
131  Operation *xfer,
132  RewriterBase &rewriter);
133 
134 /// A pattern for ops that implement `MaskableOpInterface` and that _might_ be
135 /// masked (i.e. inside `vector.mask` Op region). In particular:
136 /// 1. Matches `SourceOp` operation, Op.
137 /// 2.1. If Op is masked, retrieves the masking Op, maskOp, and updates the
138 /// insertion point to avoid inserting new ops into the `vector.mask` Op
139 /// region (which only allows one Op).
140 /// 2.2 If Op is not masked, this step is skipped.
141 /// 3. Invokes `matchAndRewriteMaskableOp` on Op and optionally maskOp if
142 /// found in step 2.1.
143 ///
144 /// This wrapper frees patterns from re-implementing the logic to update the
145 /// insertion point when a maskable Op is masked. Such patterns are still
146 /// responsible for providing an updated ("rewritten") version of:
147 /// a. the source Op when mask _is not_ present,
148 /// b. the source Op and the masking Op when mask _is_ present.
149 /// To use this pattern, implement `matchAndRewriteMaskableOp`. Note that
150 /// the return value will depend on the case above.
151 template <class SourceOp>
154 
155 private:
156  LogicalResult matchAndRewrite(SourceOp sourceOp,
157  PatternRewriter &rewriter) const final {
158  auto maskableOp = dyn_cast<MaskableOpInterface>(sourceOp.getOperation());
159  if (!maskableOp)
160  return failure();
161 
162  Operation *rootOp = sourceOp;
163 
164  // If this Op is masked, update the insertion point to avoid inserting into
165  // the vector.mask Op region.
166  OpBuilder::InsertionGuard guard(rewriter);
167  MaskingOpInterface maskOp;
168  if (maskableOp.isMasked()) {
169  maskOp = maskableOp.getMaskingOp();
170  rewriter.setInsertionPoint(maskOp);
171  rootOp = maskOp;
172  }
173 
174  FailureOr<Value> newOp =
175  matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter);
176  if (failed(newOp))
177  return failure();
178 
179  // Rewriting succeeded but there are no values to replace.
180  if (rootOp->getNumResults() == 0) {
181  rewriter.eraseOp(rootOp);
182  } else {
183  assert(*newOp != Value() &&
184  "Cannot replace an op's use with an empty value.");
185  rewriter.replaceOp(rootOp, *newOp);
186  }
187  return success();
188  }
189 
190 public:
191  // Matches `sourceOp` that can potentially be masked with `maskingOp`. If the
192  // latter is present, returns a replacement for `maskingOp`. Otherwise,
193  // returns a replacement for `sourceOp`.
194  virtual FailureOr<Value>
195  matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp,
196  PatternRewriter &rewriter) const = 0;
197 };
198 
199 /// Returns true if the input Vector type can be linearized.
200 ///
201 /// Linearization is meant in the sense of flattening vectors, e.g.:
202 /// * vector<NxMxKxi32> -> vector<N*M*Kxi32>
203 /// In this sense, Vectors that are either:
204 /// * already linearized, or
205 /// * contain more than 1 scalable dimensions,
206 /// are not linearizable.
207 bool isLinearizableVector(VectorType type);
208 
209 /// Create a TransferReadOp from `source` with static shape `readShape`. If the
210 /// vector type for the read is not the same as the type of `source`, then a
211 /// mask is created on the read, if use of mask is specified or the bounds on a
212 /// dimension are different.
213 ///
214 /// `useInBoundsInsteadOfMasking` if false, the inBoundsVal values are set
215 /// properly, based on
216 /// the rank dimensions of the source and destination tensors. And that is
217 /// what determines if masking is done.
218 ///
219 /// Note that the internal `vector::TransferReadOp` always read at indices zero
220 /// for each dimension of the passed in tensor.
221 Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
222  ArrayRef<int64_t> readShape, Value padValue,
223  bool useInBoundsInsteadOfMasking);
224 
225 /// Returns success if `inputVectorSizes` is a valid masking configuraion for
226 /// given `shape`, i.e., it meets:
227 /// 1. The numbers of elements in both array are equal.
228 /// 2. `inputVectorSizes` does not have dynamic dimensions.
229 /// 3. All the values in `inputVectorSizes` are greater than or equal to
230 /// static sizes in `shape`.
231 LogicalResult isValidMaskedInputVector(ArrayRef<int64_t> shape,
232  ArrayRef<int64_t> inputVectorSizes);
233 } // namespace vector
234 
235 /// Constructs a permutation map of invariant memref indices to vector
236 /// dimension.
237 ///
238 /// If no index is found to be invariant, 0 is added to the permutation_map and
239 /// corresponds to a vector broadcast along that dimension.
240 ///
241 /// The implementation uses the knowledge of the mapping of loops to
242 /// vector dimension. `loopToVectorDim` carries this information as a map with:
243 /// - keys representing "vectorized enclosing loops";
244 /// - values representing the corresponding vector dimension.
245 /// Note that loopToVectorDim is a whole function map from which only enclosing
246 /// loop information is extracted.
247 ///
248 /// Prerequisites: `indices` belong to a vectorizable load or store operation
249 /// (i.e. at most one invariant index along each AffineForOp of
250 /// `loopToVectorDim`). `insertPoint` is the insertion point for the vectorized
251 /// load or store operation.
252 ///
253 /// Example 1:
254 /// The following MLIR snippet:
255 ///
256 /// ```mlir
257 /// affine.for %i3 = 0 to %0 {
258 /// affine.for %i4 = 0 to %1 {
259 /// affine.for %i5 = 0 to %2 {
260 /// %a5 = load %arg0[%i4, %i5, %i3] : memref<?x?x?xf32>
261 /// }}}
262 /// ```
263 ///
264 /// may vectorize with {permutation_map: (d0, d1, d2) -> (d2, d1)} into:
265 ///
266 /// ```mlir
267 /// affine.for %i3 = 0 to %0 step 32 {
268 /// affine.for %i4 = 0 to %1 {
269 /// affine.for %i5 = 0 to %2 step 256 {
270 /// %4 = vector.transfer_read %arg0, %i4, %i5, %i3
271 /// {permutation_map: (d0, d1, d2) -> (d2, d1)} :
272 /// (memref<?x?x?xf32>, index, index) -> vector<32x256xf32>
273 /// }}}
274 /// ```
275 ///
276 /// Meaning that vector.transfer_read will be responsible for reading the slice:
277 /// `%arg0[%i4, %i5:%15+256, %i3:%i3+32]` into vector<32x256xf32>.
278 ///
279 /// Example 2:
280 /// The following MLIR snippet:
281 ///
282 /// ```mlir
283 /// %cst0 = arith.constant 0 : index
284 /// affine.for %i0 = 0 to %0 {
285 /// %a0 = load %arg0[%cst0, %cst0] : memref<?x?xf32>
286 /// }
287 /// ```
288 ///
289 /// may vectorize with {permutation_map: (d0) -> (0)} into:
290 ///
291 /// ```mlir
292 /// affine.for %i0 = 0 to %0 step 128 {
293 /// %3 = vector.transfer_read %arg0, %c0_0, %c0_0
294 /// {permutation_map: (d0, d1) -> (0)} :
295 /// (memref<?x?xf32>, index, index) -> vector<128xf32>
296 /// }
297 /// ````
298 ///
299 /// Meaning that vector.transfer_read will be responsible of reading the slice
300 /// `%arg0[%c0, %c0]` into vector<128xf32> which needs a 1-D vector broadcast.
301 ///
302 AffineMap
303 makePermutationMap(Block *insertPoint, ArrayRef<Value> indices,
304  const DenseMap<Operation *, unsigned> &loopToVectorDim);
305 AffineMap
306 makePermutationMap(Operation *insertPoint, ArrayRef<Value> indices,
307  const DenseMap<Operation *, unsigned> &loopToVectorDim);
308 
309 namespace matcher {
310 
311 /// Matches vector.transfer_read, vector.transfer_write and ops that return a
312 /// vector type that is a multiple of the sub-vector type. This allows passing
313 /// over other smaller vector types in the function and avoids interfering with
314 /// operations on those.
315 /// This is a first approximation, it can easily be extended in the future.
316 /// TODO: this could all be much simpler if we added a bit that a vector type to
317 /// mark that a vector is a strict super-vector but it still does not warrant
318 /// adding even 1 extra bit in the IR for now.
319 bool operatesOnSuperVectorsOf(Operation &op, VectorType subVectorType);
320 
321 } // namespace matcher
322 } // namespace mlir
323 
324 #endif // MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
static AffineMap makePermutationMap(ArrayRef< Value > indices, const DenseMap< Operation *, unsigned > &enclosingLoopToVectorDim)
Constructs a permutation map from memref indices to vector dimension.
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
Definition: AffineMap.h:46
Block represents an ordered list of Operations.
Definition: Block.h:31
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:351
This class helps build Operations.
Definition: Builders.h:210
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:401
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:468
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:92
bool hasTensorSemantics(Operation *op)
Return "true" if the given op has tensor semantics and should be bufferized.
bool isContiguousSlice(MemRefType memrefType, VectorType vectorType)
Return true if vectorType is a contiguous slice of memrefType.
LogicalResult isValidMaskedInputVector(ArrayRef< int64_t > shape, ArrayRef< int64_t > inputVectorSizes)
Returns success if inputVectorSizes is a valid masking configuration for given shape,...
FailureOr< std::pair< int, int > > isTranspose2DSlice(vector::TransposeOp op)
Returns two dims that are greater than one if the transposition is applied on a 2D slice.
Definition: VectorUtils.cpp:84
std::optional< StaticTileOffsetRange > createUnrollIterator(VectorType vType, int64_t targetRank=1)
Returns an iterator for all positions in the leading dimensions of vType up to the targetRank.
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim)
Helper function that creates a memref::DimOp or tensor::DimOp depending on the type of source.
Definition: VectorUtils.cpp:41
bool isLinearizableVector(VectorType type)
Returns true if the input Vector type can be linearized.
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source, ArrayRef< int64_t > readShape, Value padValue, bool useInBoundsInsteadOfMasking)
Create a TransferReadOp from source with static shape readShape.
auto makeVscaleConstantBuilder(PatternRewriter &rewriter, Location loc)
Returns a functor (int64_t -> Value) which returns a constant vscale multiple.
Definition: VectorUtils.h:113
SmallVector< OpFoldResult > getMixedSizesXfer(bool hasTensorSemantics, Operation *xfer, RewriterBase &rewriter)
A wrapper for getMixedSizes for vector.transfer_read and vector.transfer_write Ops (for source and de...
Include the generated interface declarations.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e.
Definition: VectorUtils.h:152
virtual FailureOr< Value > matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp, PatternRewriter &rewriter) const =0