Vectorization.cpp
1//===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the linalg dialect Vectorization transformations.
10//
11//===----------------------------------------------------------------------===//
13
28#include "mlir/IR/AffineExpr.h"
29#include "mlir/IR/AffineMap.h"
30#include "mlir/IR/Builders.h"
35#include "mlir/IR/Value.h"
36#include "mlir/Support/LLVM.h"
38#include "llvm/ADT/STLExtras.h"
39#include "llvm/ADT/Sequence.h"
40#include "llvm/ADT/SmallVector.h"
41#include "llvm/ADT/SmallVectorExtras.h"
42#include "llvm/ADT/TypeSwitch.h"
43#include "llvm/Support/DebugLog.h"
44#include "llvm/Support/InterleavedRange.h"
45#include "llvm/Support/MathExtras.h"
46#include "llvm/Support/raw_ostream.h"
47#include <optional>
48
49using namespace mlir;
50using namespace mlir::linalg;
51
52#define DEBUG_TYPE "linalg-vectorization"
53
54/// Try to vectorize `convOp` as a convolution.
55static FailureOr<Operation *>
56vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
57 ArrayRef<int64_t> inputVecSizes = {},
58 ArrayRef<bool> inputVecScalableFlags = {},
59 bool flatten1DDepthwiseConv = false);
60
61/// Vectorize tensor::InsertSliceOp with:
62/// * vector::TransferReadOp + vector::TransferWriteOp
63/// The vector sizes are either:
64/// * user-provided in `inputVectorSizes`, or
65/// * inferred from the static dims in the input and output tensors.
66/// Bails out if:
67/// * vector sizes are not user-provided, and
68/// * at least one dim is dynamic (in both the input and output tensors).
69///
70/// Before:
71/// !t_in_type = tensor<1x2x3xf32>
72/// !t_out_type = tensor<9x8x7x1x2x3xf32>
73/// !v_type = vector<1x2x3xf32>
74/// %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
75/// into !t_out_type
76/// After:
77/// %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
78/// %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
79static LogicalResult
80vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
81 ArrayRef<int64_t> inputVectorSizes,
82 SmallVectorImpl<Value> &newResults);
83
84/// Returns the effective Pad value for the input op, provided it's a scalar.
85///
86/// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
87/// this Op performs padding, retrieve the padding value provided that it's
88/// a scalar and static/fixed for all the padded values. Returns an empty value
89/// otherwise.
90static Value getStaticPadVal(Operation *op);
91
92/// Return the unique instance of OpType in `block` if it is indeed unique.
93/// Return null if none or more than 1 instances exist.
94template <typename OpType>
95static OpType getSingleOpOfType(Block &block) {
96 OpType res;
97 block.walk([&](OpType op) {
98 if (res) {
99 res = nullptr;
100 return WalkResult::interrupt();
101 }
102 res = op;
103 return WalkResult::advance();
104 });
105 return res;
106}
107
108/// Helper function to extract the input slices after filter is unrolled along
109/// kw.
110static SmallVector<Value>
111extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
112 int64_t nSize, int64_t wSize, int64_t cSize,
113 int64_t kwSize, int strideW, int dilationW,
114 int64_t wSizeStep, bool isSingleChanneled) {
115 SmallVector<Value> result;
116 if (isSingleChanneled) {
117 // Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled
118 // convolution.
119 SmallVector<int64_t> sizes = {wSizeStep};
120 SmallVector<int64_t> strides = {1};
121 for (int64_t kw = 0; kw < kwSize; ++kw) {
122 for (int64_t w = 0; w < wSize; w += wSizeStep) {
123 result.push_back(vector::ExtractStridedSliceOp::create(
124 rewriter, loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw}, sizes,
125 strides));
126 }
127 }
128 } else {
129 // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0]
130 // for channeled convolution.
131 SmallVector<int64_t> sizes = {nSize, wSizeStep, cSize};
132 SmallVector<int64_t> strides = {1, 1, 1};
133 for (int64_t kw = 0; kw < kwSize; ++kw) {
134 for (int64_t w = 0; w < wSize; w += wSizeStep) {
135 result.push_back(vector::ExtractStridedSliceOp::create(
136 rewriter, loc, input,
137 /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
138 sizes, strides));
139 }
140 }
141 }
142 return result;
143}
144
145/// Helper function to extract the filter slices after filter is unrolled along
146/// kw.
147static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
148 Location loc, Value filter,
149 int64_t kwSize) {
150 SmallVector<Value> result;
151 // Extract rhs slice of size [{c, f} for channeled convolutions and {1} for
152 // non-channeled convolution] @ [kw].
153 for (int64_t kw = 0; kw < kwSize; ++kw) {
154 result.push_back(vector::ExtractOp::create(
155 rewriter, loc, filter, /*offsets=*/ArrayRef<int64_t>{kw}));
156 }
157 return result;
158}
159
160/// Helper function to extract the result slices after filter is unrolled along
161/// kw.
162static SmallVector<Value>
163extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
164 int64_t nSize, int64_t wSize, int64_t fSize,
165 int64_t wSizeStep, bool isSingleChanneled) {
166 SmallVector<Value> result;
167 if (isSingleChanneled) {
168 // Extract res slice: {wSizeStep} @ [w] for non-channeled convolution.
169 SmallVector<int64_t> sizes = {wSizeStep};
170 SmallVector<int64_t> strides = {1};
171 for (int64_t w = 0; w < wSize; w += wSizeStep) {
172 result.push_back(vector::ExtractStridedSliceOp::create(
173 rewriter, loc, res, /*offsets=*/ArrayRef<int64_t>{w}, sizes,
174 strides));
175 }
176 } else {
177 // Extract res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
178 // convolution.
179 SmallVector<int64_t> sizes = {nSize, wSizeStep, fSize};
180 SmallVector<int64_t> strides = {1, 1, 1};
181 for (int64_t w = 0; w < wSize; w += wSizeStep) {
182 result.push_back(vector::ExtractStridedSliceOp::create(
183 rewriter, loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes,
184 strides));
185 }
186 }
187 return result;
188}
189
190/// Helper function to insert the computed result slices.
191static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
192 Value res, int64_t wSize, int64_t wSizeStep,
193 SmallVectorImpl<Value> &resVals,
194 bool isSingleChanneled) {
195
196 if (isSingleChanneled) {
197 // Write back res slice: {wSizeStep} @ [w] for non-channeled convolution.
198 // This does not depend on kw.
199 SmallVector<int64_t> strides = {1};
200 for (int64_t w = 0; w < wSize; w += wSizeStep) {
201 res = vector::InsertStridedSliceOp::create(
202 rewriter, loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w},
203 strides);
204 }
205 } else {
206 // Write back res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
207 // convolution. This does not depend on kw.
208 SmallVector<int64_t> strides = {1, 1, 1};
209 for (int64_t w = 0; w < wSize; w += wSizeStep) {
210 res = vector::InsertStridedSliceOp::create(
211 rewriter, loc, resVals[w], res,
212 /*offsets=*/ArrayRef<int64_t>{0, w, 0}, strides);
213 }
214 }
215 return res;
216}
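// Illustrative example (hypothetical sizes, channeled case): with wSize = 4,
// wSizeStep = 2 and kwSize = 2, the helpers above produce input slices of
// shape {n, 2, c} at offsets [0, w * strideW + kw * dilationW, 0] for
// (w, kw) in {0, 2} x {0, 1}, one filter slice per kw, and result slices of
// shape {n, 2, f} at [0, w, 0] that insertConvResultSlices writes back after
// the per-(w, kw) contractions.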
217
218/// Contains the vectorization state and related methods used across the
219/// vectorization process of a given operation.
220struct VectorizationState {
221 VectorizationState(RewriterBase &rewriter) : rewriterGuard(rewriter) {}
222
223 /// Initializes the vectorization state, including the computation of the
224 /// canonical vector shape for vectorization.
225 LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
226 ArrayRef<int64_t> inputVectorSizes,
227 ArrayRef<bool> inputScalableVecDims,
228 bool assumeDynamicDimsMatchVecSizes = false);
229
230 /// Returns the canonical vector shape used to vectorize the iteration space.
231 ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
232
233 /// Returns the vector dimensions that are scalable in the canonical vector
234 /// shape.
235 ArrayRef<bool> getScalableVecDims() const { return scalableVecDims; }
236
237 /// Returns a vector type of the provided `elementType` with the canonical
238 /// vector shape and the corresponding fixed/scalable dimensions bit. If
239 /// `dimPermutation` is provided, the canonical vector dimensions are permuted
240 /// accordingly.
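 /// For illustration (hypothetical shapes): with a canonical vector shape of
 /// [4, 8] and `dimPermutation = affine_map<(d0, d1) -> (d1, d0)>`, an f32
 /// element type yields vector<8x4xf32>; without a permutation it yields
 /// vector<4x8xf32>.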
241 VectorType getCanonicalVecType(
242 Type elementType,
243 std::optional<AffineMap> dimPermutation = std::nullopt) const {
244 SmallVector<int64_t> vectorShape;
245 SmallVector<bool> scalableDims;
246 if (dimPermutation.has_value()) {
247 vectorShape =
248 applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
249 scalableDims =
250 applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
251 } else {
252 vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
253 scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
254 }
255
256 return VectorType::get(vectorShape, elementType, scalableDims);
257 }
258
259 /// Masks an operation with the canonical vector mask if the operation needs
260 /// masking. Returns the masked operation or the original operation if masking
261 /// is not needed. If provided, the canonical mask for this operation is
262 /// permuted using `maybeIndexingMap`.
263 Operation *
264 maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
265 std::optional<AffineMap> maybeIndexingMap = std::nullopt);
266
267private:
268 /// Initializes the iteration space static sizes using the Linalg op
269 /// information. This may become more complicated in the future.
270 void initIterSpaceStaticSizes(LinalgOp linalgOp) {
271 iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
272 }
273
274 /// Generates 'arith.constant' and 'tensor/memref.dim' operations for
275 /// all the static and dynamic dimensions of the iteration space to be
276 /// vectorized and store them in `iterSpaceValueSizes`.
277 LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
278 LinalgOp linalgOp);
279
280 /// Create or retrieve an existing mask value to mask `opToMask` in the
281 /// canonical vector iteration space. If `maybeMaskingMap` is provided, the
282 /// mask is permuted using that permutation map. If a new mask is created, it
283 /// will be cached for future users.
284 Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
285 LinalgOp linalgOp,
286 std::optional<AffineMap> maybeMaskingMap);
287
288 /// Check whether this permutation map can be used for masking. At the
289 /// moment we only make sure that there are no broadcast dimensions, but this
290 /// might change if indexing maps evolve.
291 bool isValidMaskingMap(AffineMap maskingMap) {
292 return maskingMap.getBroadcastDims().empty();
293 }
294
295 /// Turn the input indexing map into a valid masking map.
296 ///
297 /// The input indexing map may contain "zero" results, e.g.:
298 /// (d0, d1, d2, d3) -> (d2, d1, d0, 0)
299 /// Applying such maps to canonical vector shapes like this one:
300 /// (1, 16, 16, 4)
301 /// would yield an invalid vector shape like this:
302 /// (16, 16, 1, 0)
303 /// Instead, drop the broadcasting dims that make no sense for masking perm.
304 /// maps:
305 /// (d0, d1, d2, d3) -> (d2, d1, d0)
306 /// This way, the corresponding vector/mask type will be:
307 /// vector<16x16x1xty>
308 /// rather than this invalid Vector type:
309 /// vector<16x16x1x0xty>
310 AffineMap getMaskingMapFromIndexingMap(AffineMap &indexingMap) {
311 return indexingMap.dropZeroResults();
312 }
313
314 // Holds the compile-time static sizes of the iteration space to vectorize.
315 // Dynamic dimensions are represented using ShapedType::kDynamic.
316 SmallVector<int64_t> iterSpaceStaticSizes;
317
318 /// Holds the value sizes of the iteration space to vectorize. Static
319 /// dimensions are represented by 'arith.constant' and dynamic
320 /// dimensions by 'tensor/memref.dim'.
321 SmallVector<Value> iterSpaceValueSizes;
322
323 /// Holds the canonical vector shape used to vectorize the iteration space.
324 SmallVector<int64_t> canonicalVecShape;
325
326 /// Holds the vector dimensions that are scalable in the canonical vector
327 /// shape.
328 SmallVector<bool> scalableVecDims;
329
330 /// Holds the active masks for permutations of the canonical vector iteration
331 /// space.
332 DenseMap<AffineMap, Value> activeMaskCache;
333
334 /// Global vectorization guard for the incoming rewriter. It's initialized
335 /// when the vectorization state is initialized.
336 OpBuilder::InsertionGuard rewriterGuard;
337
338 /// Do all dynamic dims match the corresponding vector sizes?
339 ///
340 /// When a dynamic tensor/memref dimension matches the corresponding vector
341 /// dimension, masking can be safely skipped, despite the presence of dynamic
342 /// shapes. Use this flag with care and only for cases where you are
343 /// confident the assumption holds.
344 bool assumeDynamicDimsMatchVecSizes = false;
345};
346
347LogicalResult
348VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
349 LinalgOp linalgOp) {
350 // TODO: Support 0-d vectors.
351 for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
352 if (ShapedType::isStatic(iterSpaceStaticSizes[vecDim])) {
353 // Create constant index op for static dimensions.
354 iterSpaceValueSizes.push_back(arith::ConstantIndexOp::create(
355 rewriter, linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
356 continue;
357 }
358
359 // Find an operand defined on this dimension of the iteration space to
360 // extract the runtime dimension size.
361 Value operand;
362 unsigned operandDimPos;
363 if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
364 operandDimPos)))
365 return failure();
366
367 Value dynamicDim =
368 linalgOp.hasPureTensorSemantics()
369 ? (Value)tensor::DimOp::create(rewriter, linalgOp.getLoc(), operand,
370 operandDimPos)
371 : (Value)memref::DimOp::create(rewriter, linalgOp.getLoc(), operand,
372 operandDimPos);
373 iterSpaceValueSizes.push_back(dynamicDim);
374 }
375
376 return success();
377}
378
379/// Initializes the vectorization state, including the computation of the
380/// canonical vector shape for vectorization.
381// TODO: Move this to the constructor when we can remove the failure cases.
382LogicalResult VectorizationState::initState(RewriterBase &rewriter,
383 LinalgOp linalgOp,
384 ArrayRef<int64_t> inputVectorSizes,
385 ArrayRef<bool> inputScalableVecDims,
386 bool assumeDimsMatchVec) {
387 assumeDynamicDimsMatchVecSizes = assumeDimsMatchVec;
388 // Initialize the insertion point.
389 rewriter.setInsertionPoint(linalgOp);
390
391 if (!inputVectorSizes.empty()) {
392 // Get the canonical vector shape from the input vector sizes provided. This
393 // path should be taken to vectorize code with dynamic shapes and when using
394 // vector sizes greater than the iteration space sizes.
395 canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
396 scalableVecDims.append(inputScalableVecDims.begin(),
397 inputScalableVecDims.end());
398 } else {
399 // Compute the canonical vector shape from the operation shape. If there are
400 // dynamic shapes, the operation won't be vectorized. We assume all the
401 // vector dimensions are fixed.
402 canonicalVecShape = linalgOp.getStaticLoopRanges();
403 scalableVecDims.append(linalgOp.getNumLoops(), false);
404 }
405
406 LDBG() << "Canonical vector shape: " << llvm::interleaved(canonicalVecShape);
407 LDBG() << "Scalable vector dims: " << llvm::interleaved(scalableVecDims);
408
409 if (ShapedType::isDynamicShape(canonicalVecShape))
410 return failure();
411
412 // Initialize iteration space static sizes.
413 initIterSpaceStaticSizes(linalgOp);
414
415 // Generate 'arith.constant' and 'tensor/memref.dim' operations for
416 // all the static and dynamic dimensions of the iteration space, needed to
417 // compute a mask during vectorization.
418 if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
419 return failure();
420
421 return success();
422}
423
424/// Create or retrieve an existing mask value to mask `opToMask` in the
425/// canonical vector iteration space. If `maybeMaskingMap` is provided, the mask
426/// is permuted using that permutation map. If a new mask is created, it will be
427/// cached for future users.
428Value VectorizationState::getOrCreateMaskFor(
429 RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
430 std::optional<AffineMap> maybeMaskingMap) {
431
432 assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) &&
433 "Ill-formed masking map.");
434
435 // No mask is needed if the operation is not maskable.
436 auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
437 if (!maskableOp)
438 return Value();
439
440 assert(!maskableOp.isMasked() &&
441 "Masking an operation that is already masked");
442
443 // If no masking map was provided, use an identity map with the loop dims.
444 assert((!maybeMaskingMap || *maybeMaskingMap) &&
445 "Unexpected null mask permutation map");
446 AffineMap maskingMap =
447 maybeMaskingMap ? *maybeMaskingMap
448 : AffineMap::getMultiDimIdentityMap(
449 linalgOp.getNumLoops(), rewriter.getContext());
450
451 LDBG() << "Masking map: " << maskingMap;
452
453 // Return the active mask for the masking map of this operation if it was
454 // already created.
455 auto activeMaskIt = activeMaskCache.find(maskingMap);
456 if (activeMaskIt != activeMaskCache.end()) {
457 Value mask = activeMaskIt->second;
458 LDBG() << "Reusing mask: " << mask;
459 return mask;
460 }
461
462 // Compute permuted projection of the iteration space to be masked and the
463 // corresponding mask shape. If the resulting iteration space dimensions are
464 // static and identical to the mask shape, masking is not needed for this
465 // operation.
466 // TODO: Improve this check. Only projected permutation indexing maps are
467 // supported.
468 SmallVector<int64_t> permutedStaticSizes =
469 applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
470 auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
471 auto maskShape = maskType.getShape();
472
473 LDBG() << "Mask shape: " << llvm::interleaved(maskShape);
474
475 if (permutedStaticSizes == maskShape) {
476 LDBG() << "Masking is not needed for masking map: " << maskingMap;
477 activeMaskCache[maskingMap] = Value();
478 return Value();
479 }
480
481 if (assumeDynamicDimsMatchVecSizes) {
482 // While for _dynamic_ dim sizes we can _assume_ that the corresponding
483 // vector sizes match, we still need to check the _static_ dim sizes. Only
484 // then can we be 100% sure that masking is not required.
485 if (llvm::all_of(llvm::zip(permutedStaticSizes, maskType.getShape()),
486 [](auto it) {
487 return std::get<0>(it) == ShapedType::kDynamic
488 ? true
489 : std::get<0>(it) == std::get<1>(it);
490 })) {
491 LDBG()
492 << "Dynamic + static dimensions match vector sizes, masking is not "
493 "required.";
494 activeMaskCache[maskingMap] = Value();
495 return Value();
496 }
497 }
498
499 // Permute the iteration space value sizes to compute the mask upper bounds.
500 SmallVector<Value> upperBounds =
501 applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
502 assert(!maskShape.empty() && !upperBounds.empty() &&
503 "Masked 0-d vectors are not supported yet");
504
505 // Create the mask based on the dimension values.
506 Value mask = vector::CreateMaskOp::create(rewriter, linalgOp.getLoc(),
507 maskType, upperBounds);
508 LDBG() << "Creating new mask: " << mask;
509 activeMaskCache[maskingMap] = mask;
510 return mask;
511}
512
513Operation *
514VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
515 LinalgOp linalgOp,
516 std::optional<AffineMap> maybeIndexingMap) {
517 LDBG() << "Trying to mask: " << *opToMask;
518
519 std::optional<AffineMap> maybeMaskingMap = std::nullopt;
520 if (maybeIndexingMap)
521 maybeMaskingMap = getMaskingMapFromIndexingMap(*maybeIndexingMap);
522
523 // Create or retrieve mask for this operation.
524 Value mask =
525 getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
526
527 if (!mask) {
528 LDBG() << "No mask required";
529 if (assumeDynamicDimsMatchVecSizes) {
530 llvm::TypeSwitch<Operation *>(opToMask)
531 .Case<vector::TransferReadOp, vector::TransferWriteOp>(
532 [&](auto xferOp) {
533 // For vector.transfer_read and vector.transfer_write, there is
534 // also the `in-bounds` attribute that has to be set explicitly
535 // to true. Otherwise, "out-of-bounds" access will be assumed
536 // and masks will be generated while lowering these.
537 LDBG() << "Assuming dynamic dimensions match vector sizes and "
538 "setting their in-bounds to true!";
539 SmallVector<bool> inBoundsMap = xferOp.getInBoundsValues();
540 ShapedType xferType = xferOp.getShapedType();
541 AffineMap permMap = xferOp.getPermutationMap();
542 // Only set the in-bounds values to true for dynamic dims.
543 // Different mechanisms will set these accordingly for the
544 // static dims.
545 for (unsigned i = 0; i < xferOp.getTransferRank(); i++) {
546 auto dimExpr = dyn_cast<AffineDimExpr>(permMap.getResult(i));
547 // Skip broadcast dimensions.
548 if (!dimExpr)
549 continue;
550 unsigned pos = dimExpr.getPosition();
551 if (xferType.isDynamicDim(pos))
552 inBoundsMap[i] = true;
553 }
554 rewriter.modifyOpInPlace(xferOp, [&]() {
555 xferOp.setInBoundsAttr(
556 rewriter.getBoolArrayAttr(inBoundsMap));
557 });
558 })
559 .Default([](Operation *op) {
560 // No-op if the operation is not an xfer read or write.
561 });
562 }
563 return opToMask;
564 }
565
566 // Wrap the operation with a new `vector.mask` and update D-U chain.
567 assert(opToMask && "Expected a valid operation to mask");
568 auto maskOp = cast<vector::MaskOp>(
569 mlir::vector::maskOperation(rewriter, opToMask, mask));
570 Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
571
572 for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
573 rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
574 maskOpTerminator);
575
576 LDBG() << "Masked operation: " << *maskOp;
577 return maskOp;
578}
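// For illustration (hypothetical shapes and values): masking a transfer_read
// over a ?x8 tensor with a canonical vector shape of [4, 8] might produce:
//
//   %mask = vector.create_mask %dim0, %c8 : vector<4x8xi1>
//   %read = vector.mask %mask {
//       vector.transfer_read %src[%c0, %c0], %pad
//           : tensor<?x8xf32>, vector<4x8xf32>
//     } : vector<4x8xi1> -> vector<4x8xf32>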
579
580/// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
581/// projectedPermutation, compress the unused dimensions to serve as a
582/// permutation_map for a vector transfer operation.
583/// For example, given a linalg op such as:
584///
585/// ```
586/// %0 = linalg.generic {
587/// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>,
588/// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>
589/// }
590/// ins(%0 : tensor<2x3x4xf32>)
591/// outs(%1 : tensor<5x6xf32>)
592/// ```
593///
594/// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine
595/// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second
596/// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`.
597static AffineMap reindexIndexingMap(AffineMap map) {
598 assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
599 "expected projected permutation");
600 auto res = compressUnusedDims(map);
601 assert(res.getNumDims() ==
602 (res.getNumResults() - res.getNumOfZeroResults()) &&
603 "expected reindexed map with same number of dims and results");
604 return res;
605}
606
607/// Helper enum to represent conv1d input traversal order.
608enum class Conv1DOpOrder {
609 W, // Corresponds to non-channeled 1D convolution operation.
610 Ncw, // Corresponds to operation that traverses the input in (n, c, w) order.
611 Nwc // Corresponds to operation that traverses the input in (n, w, c) order.
612};
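// For example, linalg.conv_1d corresponds to W, linalg.conv_1d_ncw_fcw to Ncw,
// and linalg.conv_1d_nwc_wcf to Nwc (illustrative mapping of named ops to the
// traversal orders above).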
613
614/// Helper data structure to represent the result of vectorization for a single
615/// operation. In certain specific cases, like terminators, we do not want to
616/// propagate.
617enum VectorizationHookStatus {
618 /// Op failed to vectorize.
619 Failure = 0,
620 /// Op vectorized and custom function took care of replacement logic
621 NoReplace,
622 /// Op vectorized into a new Op whose results will replace original Op's
623 /// results.
624 NewOp
625 // TODO: support values if Op vectorized to Many-Ops whose results we need to
626 // aggregate for replacement.
627};
628/// VectorizationHookResult contains the vectorized op returned from a
629/// CustomVectorizationHook. This is an internal implementation detail of
630/// linalg vectorization, not to be confused with VectorizationResult.
631struct VectorizationHookResult {
632 /// Return status from vectorizing the current op.
633 enum VectorizationHookStatus status = VectorizationHookStatus::Failure;
634 /// New vectorized operation to replace the current op.
635 /// Replacement behavior is specified by `status`.
636 Operation *newOp;
637};
638
639std::optional<vector::CombiningKind>
640mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
641 using ::mlir::vector::CombiningKind;
642
643 if (!combinerOp)
644 return std::nullopt;
645 return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(combinerOp)
646 .Case<arith::AddIOp, arith::AddFOp>(
647 [&](auto op) { return CombiningKind::ADD; })
648 .Case([&](arith::AndIOp op) { return CombiningKind::AND; })
649 .Case([&](arith::MaxSIOp op) { return CombiningKind::MAXSI; })
650 .Case([&](arith::MaxUIOp op) { return CombiningKind::MAXUI; })
651 .Case([&](arith::MaximumFOp op) { return CombiningKind::MAXIMUMF; })
652 .Case([&](arith::MaxNumFOp op) { return CombiningKind::MAXNUMF; })
653 .Case([&](arith::MinSIOp op) { return CombiningKind::MINSI; })
654 .Case([&](arith::MinUIOp op) { return CombiningKind::MINUI; })
655 .Case([&](arith::MinimumFOp op) { return CombiningKind::MINIMUMF; })
656 .Case([&](arith::MinNumFOp op) { return CombiningKind::MINNUMF; })
657 .Case<arith::MulIOp, arith::MulFOp>(
658 [&](auto op) { return CombiningKind::MUL; })
659 .Case([&](arith::OrIOp op) { return CombiningKind::OR; })
660 .Case([&](arith::XOrIOp op) { return CombiningKind::XOR; })
661 .Default(std::nullopt);
662}
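// For illustration: for a reduction region that combines with
// `%sum = arith.addf %in, %acc : f32`, passing that arith.addf here returns
// CombiningKind::ADD; an arith.maximumf combiner returns CombiningKind::MAXIMUMF.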
663
664/// Check whether `outputOperand` is a reduction with a single combiner
665/// operation. Return the combiner operation of the reduction. Return
666/// nullptr otherwise. Multiple reduction operations would impose an
667/// ordering between reduction dimensions, which is currently unsupported in
668/// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) !=
669/// max(min(X))
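/// For illustration (hypothetical op), the combiner returned for the init
/// operand of the following sum-reduction is the arith.addf in its region:
///
/// ```
/// %0 = linalg.generic {
///       indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>],
///       iterator_types = ["reduction"]}
///     ins(%in : tensor<8xf32>) outs(%init : tensor<f32>) {
///   ^bb0(%a: f32, %acc: f32):
///     %sum = arith.addf %a, %acc : f32
///     linalg.yield %sum : f32
/// } -> tensor<f32>
/// ```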
670// TODO: use in LinalgOp verification, there is a circular dependency atm.
671static Operation *matchLinalgReduction(OpOperand *outputOperand) {
672 auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
673 unsigned outputPos =
674 outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
675 // Only single combiner operations are supported for now.
676 SmallVector<Operation *, 4> combinerOps;
677 if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
678 combinerOps.size() != 1)
679 return nullptr;
680
681 // Return the combiner operation.
682 return combinerOps[0];
683}
684
685/// Broadcast `value` to a vector of `shape` if possible. Return value
686/// otherwise.
687static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
688 auto dstVecType = dyn_cast<VectorType>(dstType);
689 // If no shape to broadcast to, just return `value`.
690 if (dstVecType.getRank() == 0)
691 return value;
692 if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
693 vector::BroadcastableToResult::Success)
694 return value;
695 Location loc = b.getInsertionPoint()->getLoc();
696 return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
697}
698
699/// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
700/// assumes that `reductionOp` has two operands and one of them is the reduction
701/// initial value.
702// Note: this is a true builder that notifies the OpBuilder listener.
703// TODO: Consider moving as a static helper on the ReduceOp.
704static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
705 Value valueToReduce, Value acc,
706 ArrayRef<bool> dimsToMask) {
707 auto maybeKind = getCombinerOpKind(reduceOp);
708 assert(maybeKind && "Failed precondition: could not get reduction kind");
709 return vector::MultiDimReductionOp::create(
710 b, reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
711}
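// For illustration (hypothetical shapes): reducing vector<4x8xf32> over its
// trailing dimension with an ADD combiner builds
//   %r = vector.multi_reduction <add>, %v, %acc [1]
//       : vector<4x8xf32> to vector<4xf32>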
712
713static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
714 return llvm::map_to_vector(linalgOp.getIteratorTypesArray(),
715 isReductionIterator);
716}
717
718/// Check if `op` is a linalg.reduce or a linalg.generic that has at least one
719/// reduction iterator.
720static bool hasReductionIterator(LinalgOp &op) {
721 return isa<linalg::ReduceOp>(op) ||
722 (isa<linalg::GenericOp>(op) &&
723 llvm::any_of(op.getIteratorTypesArray(), isReductionIterator));
724}
725
726/// Build a vector.transfer_write of `value` into `outputOperand` at indices set
727/// to all `0`; where `outputOperand` is an output operand of the LinalgOp
728/// currently being vectorized. If `dest` has null rank, build a memref.store.
729/// Return the produced value or null if no value is produced.
730// Note: this is a true builder that notifies the OpBuilder listener.
731// TODO: Consider moving as a static helper on the ReduceOp.
732static Value buildVectorWrite(RewriterBase &rewriter, Value value,
733 OpOperand *outputOperand,
734 VectorizationState &state) {
735 Location loc = value.getLoc();
736 auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
737 AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
738
739 // Compute the vector type of the value to store. This type should be an
740 // identity or projection of the canonical vector type without any permutation
741 // applied, given that any permutation in a transfer write happens as part of
742 // the write itself.
743 AffineMap vectorTypeMap = AffineMap::getFilteredIdentityMap(
744 opOperandMap.getContext(), opOperandMap.getNumInputs(),
745 [&](AffineDimExpr dimExpr) -> bool {
746 return llvm::is_contained(opOperandMap.getResults(), dimExpr);
747 });
748 auto vectorType = state.getCanonicalVecType(
749 getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);
750
751 SmallVector<Value> indices(linalgOp.getRank(outputOperand),
752 arith::ConstantIndexOp::create(rewriter, loc, 0));
753
754 Operation *write;
755 if (vectorType.getRank() > 0) {
756 AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
757 value = broadcastIfNeeded(rewriter, value, vectorType);
758 assert(value.getType() == vectorType && "Incorrect type");
759 write = vector::TransferWriteOp::create(
760 rewriter, loc, value, outputOperand->get(), indices, writeMap);
761 } else {
762 // 0-d case is still special: do not invert the reindexing writeMap.
763 if (!isa<VectorType>(value.getType()))
764 value = vector::BroadcastOp::create(rewriter, loc, vectorType, value);
765 assert(value.getType() == vectorType && "Incorrect type");
766 write = vector::TransferWriteOp::create(rewriter, loc, value,
767 outputOperand->get(), indices);
768 }
769
770 write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
771
772 // If masked, set in-bounds to true. Masking guarantees that the access will
773 // be in-bounds.
774 if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
775 auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
776 SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
777 maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
778 }
779
780 LDBG() << "vectorized op: " << *write;
781 if (!write->getResults().empty())
782 return write->getResult(0);
783 return Value();
784}
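// For illustration (hypothetical 2-D identity output map), the write built
// above may look like:
//   %w = vector.transfer_write %vec, %out[%c0, %c0] {in_bounds = [true, true]}
//       : vector<4x8xf32>, tensor<4x8xf32>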
785
786// Custom vectorization precondition function type. This is intended to be used
787// with CustomVectorizationHook. Returns success if the corresponding custom
788// hook can vectorize the op.
789using CustomVectorizationPrecondition =
790 std::function<LogicalResult(Operation *, bool)>;
791
792// Custom vectorization function type. Produce a vector form of Operation*
793// assuming all its vectorized operands are already in the IRMapping.
794// Return nullptr if the Operation cannot be vectorized.
795using CustomVectorizationHook =
796 std::function<VectorizationHookResult(Operation *, const IRMapping &)>;
797
798/// Helper function to vectorize the terminator of a `linalgOp`. New result
799/// vector values are appended to `newResults`. Return
800/// VectorizationHookStatus::NoReplace to signal the vectorization algorithm
801/// that it should not try to map produced operations and instead return the
802/// results using the `newResults` vector making them available to the
803/// vectorization algorithm for RAUW. This function is meant to be used as a
804/// CustomVectorizationHook.
805static VectorizationHookResult
806vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
807 const IRMapping &bvm, VectorizationState &state,
808 LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
809 auto yieldOp = dyn_cast<linalg::YieldOp>(op);
810 if (!yieldOp)
811 return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
812 for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
813 // TODO: Scan for an opportunity for reuse.
814 // TODO: use a map.
815 Value vectorValue = bvm.lookup(output.value());
816 Value newResult =
817 buildVectorWrite(rewriter, vectorValue,
818 linalgOp.getDpsInitOperand(output.index()), state);
819 if (newResult)
820 newResults.push_back(newResult);
821 }
822
823 return VectorizationHookResult{VectorizationHookStatus::NoReplace, nullptr};
824}
825
826/// Helper function to vectorize the index operations of a `linalgOp`. Return
827/// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it
828/// should map the produced operations. This function is meant to be used as a
829/// CustomVectorizationHook.
830static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter,
831 VectorizationState &state,
832 Operation *op,
833 LinalgOp linalgOp) {
834 IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
835 if (!indexOp)
836 return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
837 auto loc = indexOp.getLoc();
838 // Compute the static loop sizes of the index op.
839 ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
840 auto dim = indexOp.getDim();
841 // Compute a one-dimensional index vector for the index op dimension.
842 auto indexVectorType =
843 VectorType::get({targetShape[dim]}, rewriter.getIndexType(),
844 state.getScalableVecDims()[dim]);
845 auto indexSteps = vector::StepOp::create(rewriter, loc, indexVectorType);
846 // Return the one-dimensional index vector if it lives in the trailing
847 // dimension of the iteration space since the vectorization algorithm in this
848 // case can handle the broadcast.
849 if (dim == targetShape.size() - 1)
850 return VectorizationHookResult{VectorizationHookStatus::NewOp, indexSteps};
851 // Otherwise permute the targetShape to move the index dimension last,
852 // broadcast the one-dimensional index vector to the permuted shape, and
853 // finally transpose the broadcasted index vector to undo the permutation.
854 auto permPattern =
855 llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
856 std::swap(permPattern[dim], permPattern.back());
857 auto permMap =
858 AffineMap::getPermutationMap(permPattern, linalgOp.getContext());
859
860 auto broadCastOp = vector::BroadcastOp::create(
861 rewriter, loc,
862 state.getCanonicalVecType(rewriter.getIndexType(), permMap), indexSteps);
863 SmallVector<int64_t> transposition =
864 llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
865 std::swap(transposition.back(), transposition[dim]);
866 auto transposeOp =
867 vector::TransposeOp::create(rewriter, loc, broadCastOp, transposition);
868 return VectorizationHookResult{VectorizationHookStatus::NewOp, transposeOp};
869}
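// For illustration (canonical vector shape [4, 8], hypothetical):
//   * linalg.index 1 (trailing dim) becomes vector.step : vector<8xindex>;
//   * linalg.index 0 becomes vector.step : vector<4xindex>, broadcast to
//     vector<8x4xindex> and transposed back to vector<4x8xindex>.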
870
871/// Helper function to check if the tensor.extract can be vectorized by the
872/// custom hook vectorizeTensorExtract.
873static LogicalResult
874tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
875 tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
876 if (!extractOp)
877 return failure();
878
879 if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
880 return failure();
881
882 // Check the index type, but only for non 0-d tensors (for which we do need
883 // access indices).
884 if (not extractOp.getIndices().empty()) {
885 if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
886 return failure();
887 }
888
889 if (!llvm::all_of(extractOp->getResultTypes(),
890 VectorType::isValidElementType)) {
891 return failure();
892 }
893
894 return success();
895}
896
897/// Calculates the offsets (`$index_vec`) for `vector.gather` operations
898/// generated from `tensor.extract`. The offset is calculated as follows
899/// (example using scalar values):
900///
901/// offset = extractOp.indices[0]
902/// for (i = 1; i < numIndices; i++)
903/// offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
904///
905/// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to:
906/// offset = ( ( 1 ) * 80 + 2 ) * 15 + 3
907static Value calculateGatherOffset(RewriterBase &rewriter,
908 VectorizationState &state,
909 tensor::ExtractOp extractOp,
910 const IRMapping &bvm) {
911 // The vector of indices for GatherOp should be shaped as the output vector.
912 auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
913 auto loc = extractOp.getLoc();
914
915 Value offset = broadcastIfNeeded(
916 rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);
917
918 const size_t numIndices = extractOp.getIndices().size();
919 for (size_t i = 1; i < numIndices; i++) {
920 Value dimIdx = arith::ConstantIndexOp::create(rewriter, loc, i);
921
922 auto dimSize = broadcastIfNeeded(
923 rewriter,
924 tensor::DimOp::create(rewriter, loc, extractOp.getTensor(), dimIdx),
925 indexVecType);
926
927 offset = arith::MulIOp::create(rewriter, loc, offset, dimSize);
928
929 auto extractOpIndex = broadcastIfNeeded(
930 rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);
931
932 offset = arith::AddIOp::create(rewriter, loc, extractOpIndex, offset);
933 }
934
935 return offset;
936}
937
938enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
939
940/// Find the index of the trailing non-unit dim in linalgOp. This hook is used
941/// when checking whether `tensor.extract` Op (within a `linalg.generic` Op)
942/// represents a contiguous load operation.
943///
944/// Note that when calling this hook, it is assumed that the output vector is
945/// effectively 1D. Other cases (i.e. reading n-D vectors) should've been
946/// labelled as a gather load before entering this method.
947///
948/// Following on from the above, it is assumed that:
949/// * for statically shaped loops, when no masks are used, only one dim is !=
950/// 1 (that's what the shape of the output vector is based on).
951/// * for dynamically shaped loops, there might be more non-unit dims
952/// as the output vector type is user-specified.
953///
954/// TODO: Statically shaped loops + vector masking
955static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp) {
956 SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges();
957 assert(
958 (linalgOp.hasDynamicShape() ||
959 llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) == 1) &&
960 "For statically shaped Linalg Ops, only one "
961 "non-unit loop dim is expected");
962 assert(!loopRanges.empty() && "Empty loops, nothing to analyse.");
963
964 size_t idx = loopRanges.size() - 1;
965 for (; idx != 0; idx--)
966 if (loopRanges[idx] != 1)
967 break;
968
969 return idx;
970}
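// For example (hypothetical static loop ranges): [1, 8, 1, 1] returns 1 and
// [1, 1, 1, 4] returns 3.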
971
972/// Checks whether `val` can be used for calculating a loop invariant index.
973static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
974 VectorType resType) {
975
976 assert(((llvm::count_if(resType.getShape(),
977 [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
978 "n-D vectors are not yet supported");
979
980 // Blocks outside _this_ linalg.generic are effectively loop invariant.
981 // However, analysing block arguments for _this_ linalg.generic Op is a bit
982 // tricky. Just bail out in the latter case.
983 // TODO: We could try analysing the corresponding affine map here.
984 auto *block = linalgOp.getBlock();
985 if (isa<BlockArgument>(val))
986 return !llvm::is_contained(block->getArguments(), val);
987
988 Operation *defOp = val.getDefiningOp();
989 assert(defOp && "This is neither a block argument nor an operation result");
990
991 // IndexOp is loop invariant as long as its result remains constant across
992 // iterations. Note that for dynamic shapes, the corresponding dim will also
993 // be conservatively treated as != 1.
994 if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
995 return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
996 }
997
998 auto *ancestor = block->findAncestorOpInBlock(*defOp);
999
1000 // Values defined outside `linalgOp` are loop invariant.
1001 if (!ancestor)
1002 return true;
1003
1004 // Values defined inside `linalgOp`, which are constant, are loop invariant.
1005 if (isa<arith::ConstantOp>(ancestor))
1006 return true;
1007
1008 bool result = true;
1009 for (auto op : ancestor->getOperands())
1010 result &= isLoopInvariantIdx(linalgOp, op, resType);
1011
1012 return result;
1013}
1014
1015/// Check whether `val` could be used for calculating the trailing index for a
1016/// contiguous load operation.
1017///
1018/// There are currently 3 types of values that are allowed here:
1019/// 1. loop-invariant values,
1020/// 2. values that increment by 1 with every loop iteration,
1021/// 3. results of basic arithmetic operations (linear and continuous)
1022/// involving 1., 2. and 3.
1023/// This method returns True if indeed only such values are used in calculating
1024/// `val`.
1025///
1026/// Additionally, the trailing index for a contiguous load operation should
1027/// increment by 1 with every loop iteration, i.e. be based on:
1028/// * `linalg.index <dim>` ,
1029/// where <dim> is the trailing non-unit dim of the iteration space (this way,
1030/// `linalg.index <dim>` increments by 1 with every loop iteration).
1031/// `foundIndexOp` is updated to `true` when such Op is found.
1032static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
1033 bool &foundIndexOp, VectorType resType) {
1034
1035 assert(((llvm::count_if(resType.getShape(),
1036 [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
1037 "n-D vectors are not yet supported");
1038
1039 // Blocks outside _this_ linalg.generic are effectively loop invariant.
1040 // However, analysing block arguments for _this_ linalg.generic Op is a bit
1041 // tricky. Just bail out in the latter case.
1042 // TODO: We could try analysing the corresponding affine map here.
1043 auto *block = linalgOp.getBlock();
1044 if (isa<BlockArgument>(val))
1045 return !llvm::is_contained(block->getArguments(), val);
1046
1047 Operation *defOp = val.getDefiningOp();
1048 assert(defOp && "This is neither a block argument nor an operation result");
1049
1050 if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
1051 auto loopDimThatIncrementsByOne = getTrailingNonUnitLoopDimIdx(linalgOp);
1052
1053 foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
1054 return true;
1055 }
1056
1057 auto *ancestor = block->findAncestorOpInBlock(*defOp);
1058
1059 if (!ancestor)
1060 return false;
1061
1062 // Conservatively reject Ops that could lead to indices with stride other
1063 // than 1.
1064 if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
1065 return false;
1066
1067 bool result = false;
1068 for (auto op : ancestor->getOperands())
1069 result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType);
1070
1071 return result;
1072}
1073
1074/// Infer the memory access pattern for the input ExtractOp
1075///
1076/// Based on the ExtractOp result shape and the access indices, decides whether
1077/// this Op corresponds to a contiguous load (including a broadcast of a scalar)
1078/// or a gather load. When analysing the ExtractOp indices (to identify
1079/// contiguous loads), this method looks for "loop" invariant indices (e.g.
1080/// block arguments) and indices that change linearly (e.g. via `linalg.index`
1081/// Op).
1082///
1083/// Note that it is always safe to use gather load operations for contiguous
1084/// loads (albeit slow), but not vice-versa. When in doubt, bail out and assume
1085/// that `extractOp` is a gather load.
1086static VectorMemoryAccessKind
1087getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
1088 LinalgOp &linalgOp, VectorType resType) {
1089
1090 auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
1091
1092 // 0. Is this a 0-D vector? If yes then this is a scalar broadcast.
1093 if (inputShape.getShape().empty())
1094 return VectorMemoryAccessKind::ScalarBroadcast;
1095
1096 // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false
1097 // otherwise.
1098 bool isOutput1DVector =
1099 (llvm::count_if(resType.getShape(),
1100 [](int64_t dimSize) { return dimSize > 1; }) == 1);
1101 // 1. Assume that it's a gather load when reading non-1D vector.
1102 if (!isOutput1DVector)
1103 return VectorMemoryAccessKind::Gather;
1104
1105 bool leadingIdxsLoopInvariant = true;
1106
1107 // 2. Analyze the leading indices of `extractOp`.
1108 // Look at the way each index is calculated and decide whether it is suitable
1109 // for a contiguous load, i.e. whether it's loop invariant. If not, it's a
1110 // gather load.
1111 auto indices = extractOp.getIndices();
1112 auto leadIndices = indices.drop_back(1);
1113
1114 for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
1115 if (inputShape.getShape()[i] == 1)
1116 continue;
1117
1118 leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType);
1119 }
1120
1121 if (!leadingIdxsLoopInvariant) {
1122 LDBG() << "Found gather load: " << extractOp;
1123 return VectorMemoryAccessKind::Gather;
1124 }
1125
1126 // 3. Analyze the trailing index for `extractOp`.
1127 // At this point we know that the leading indices are loop invariant. This
1128 // means that it is potentially a scalar or a contiguous load. We can decide
1129 // based on the trailing idx.
1130 auto extractOpTrailingIdx = indices.back();
1131
1132 // 3a. Scalar broadcast load
1133 // If the trailing index is loop invariant then this is a scalar load.
1134 if (leadingIdxsLoopInvariant &&
1135 isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) {
1136 LDBG() << "Found scalar broadcast load: " << extractOp;
1137
1138 return VectorMemoryAccessKind::ScalarBroadcast;
1139 }
1140
1141 // 3b. Contiguous loads
1142 // The trailing `extractOp` index should increment with every loop iteration.
1143 // This effectively means that it must be based on the trailing loop index.
1144 // This is what the following bool captures.
1145 bool foundIndexOp = false;
1146 bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
1147 foundIndexOp, resType);
1148 // TODO: Support generating contiguous loads for column vectors - that will
1149 // require adding a permutation map to transfer_read Ops.
1150 bool isRowVector = resType.getShape().back() != 1;
1151 isContiguousLoad &= (foundIndexOp && isRowVector);
1152
1153 if (isContiguousLoad) {
1154 LDBG() << "Found contigous load: " << extractOp;
1156 }
1157
1158 // 4. Fallback case - gather load.
1159 LDBG() << "Found gather load: " << extractOp;
1160 return VectorMemoryAccessKind::Gather;
1161}
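// For illustration (hypothetical generic over a 1x4 iteration space):
// `tensor.extract %src[%c0, %i] : tensor<3x4xf32>` with %i = linalg.index 1 is
// classified as Contiguous; if %i were itself loaded from another tensor, the
// access would be classified as Gather instead.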
1162
1163/// Helper function to vectorize the tensor.extract operations. Returns
1164/// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it
1165/// should map the produced operations. This function is meant to be used as a
1166/// CustomVectorizationHook.
1167static VectorizationHookResult
1168vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
1169 Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
1170 tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
1171 if (!extractOp)
1172 return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
1173 auto loc = extractOp.getLoc();
1174
1175 // Compute the static loop sizes of the extract op.
1176 auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
1177 auto maskConstantOp = arith::ConstantOp::create(
1178 rewriter, loc,
1179 DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()),
1180 /*value=*/true));
1181 auto passThruConstantOp = arith::ConstantOp::create(
1182 rewriter, loc, rewriter.getZeroAttr(resultType));
1183
1184 // Base indices are currently set to 0. We will need to re-visit if more
1185 // generic scenarios are to be supported.
1186 SmallVector<Value> baseIndices(
1187 extractOp.getIndices().size(),
1188 arith::ConstantIndexOp::create(rewriter, loc, 0));
1189
1190 VectorMemoryAccessKind memAccessKind =
1191 getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType);
1192
1193 // 1. Handle gather access
1194 if (memAccessKind == VectorMemoryAccessKind::Gather) {
1195 Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);
1196
1197 // Generate the gather load
1198 Operation *gatherOp = vector::GatherOp::create(
1199 rewriter, loc, resultType, extractOp.getTensor(), baseIndices, offset,
1200 maskConstantOp, passThruConstantOp);
1201 gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
1202
1203 LDBG() << "Vectorised as gather load: " << extractOp;
1204 return VectorizationHookResult{VectorizationHookStatus::NewOp, gatherOp};
1205 }
1206
1207 // 2. Handle:
1208 // a. scalar loads + broadcast,
1209 // b. contiguous loads.
1210 // Both cases use vector.transfer_read.
1211
1212 // Collect indices for `vector.transfer_read`. At this point, the indices will
1213 // either be scalars or would have been broadcast to vectors matching the
1214 // result type. For indices that are vectors, there are two options:
1215 // * for non-trailing indices, all elements are identical (contiguous
1216 // loads are identified by looking for non-trailing indices that are
1217 // invariant with respect to the corresponding linalg.generic), or
1218 // * for trailing indices, the index vector will contain values with stride
1219 // one, but for `vector.transfer_read` only the first (i.e. 0th) index is
1220 // needed.
1221 // This means that
1222 // * for scalar indices - just re-use it,
1223 // * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
1224 // (0th) element and use that.
1225 SmallVector<Value> transferReadIdxs;
1226 for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
1227 Value idx = bvm.lookup(extractOp.getIndices()[i]);
1228 if (idx.getType().isIndex()) {
1229 transferReadIdxs.push_back(idx);
1230 continue;
1231 }
1232
1233 auto indexAs1dVector = vector::ShapeCastOp::create(
1234 rewriter, loc,
1235 VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
1236 resultType.getScalableDims().back()),
1237 idx);
1238 transferReadIdxs.push_back(
1239 vector::ExtractOp::create(rewriter, loc, indexAs1dVector, 0));
1240 }
1241
1242 // `tensor.extract` is always in-bounds, hence the following holds.
1243 auto dstRank = resultType.getRank();
1244 auto srcRank = extractOp.getTensor().getType().getRank();
1245 SmallVector<bool> inBounds(dstRank, true);
1246
1247 // 2a. Handle scalar broadcast access.
1248 if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
1249 MLIRContext *ctx = rewriter.getContext();
1250 SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
1251 auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);
1252
1253 auto transferReadOp = vector::TransferReadOp::create(
1254 rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs,
1255 /*padding=*/std::nullopt, permutationMap, inBounds);
1256
1257 // Mask this broadcasting xfer_read here rather than relying on the generic
1258 // path (the generic path assumes identity masking map, which wouldn't be
1259 // valid here).
1260 SmallVector<int64_t> readMaskShape = {1};
1261 auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
1262 auto allTrue = vector::ConstantMaskOp::create(
1263 rewriter, loc, readMaskType, vector::ConstantMaskKind::AllTrue);
1264 auto *maskedReadOp =
1265 mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);
1266
1267 LDBG() << "Vectorised as scalar broadcast load: " << extractOp;
1268 return VectorizationHookResult{VectorizationHookStatus::NewOp,
1269 maskedReadOp};
1270 }
1271
1272 // 2b. Handle contiguous access.
1273 auto permutationMap = AffineMap::getMinorIdentityMap(
1274 srcRank, std::min(dstRank, srcRank), rewriter.getContext());
1275
1276 int32_t rankDiff = dstRank - srcRank;
1277 // When dstRank > srcRank, broadcast the source tensor to the unitary leading
1278 // dims so that the ranks match. This is done by extending the map with 0s.
1279 // For example, for dstRank = 3, srcRank = 2, the following map created
1280 // above:
1281 // (d0, d1) --> (d0, d1)
1282 // is extended as:
1283 // (d0, d1) --> (0, d0, d1)
1284 while (rankDiff > 0) {
1285 permutationMap = permutationMap.insertResult(
1286 mlir::getAffineConstantExpr(0, rewriter.getContext()), 0);
1287 rankDiff--;
1288 }
1289
1290 auto transferReadOp = vector::TransferReadOp::create(
1291 rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs,
1292 /*padding=*/std::nullopt, permutationMap, inBounds);
1293
1294 LDBG() << "Vectorised as contiguous load: " << extractOp;
1295 return VectorizationHookResult{VectorizationHookStatus::NewOp,
1296 transferReadOp};
1297}
1298
1299/// Emit reduction operations if the shape of the value to reduce is different
1300/// from the result shape.
1301// Note: this is a true builder that notifies the OpBuilder listener.
1302// TODO: Consider moving as a static helper on the ReduceOp.
1303static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
1304 Value reduceValue, Value initialValue,
1305 const IRMapping &bvm) {
1306 Value reduceVec = bvm.lookup(reduceValue);
1307 Value outputVec = bvm.lookup(initialValue);
1308 auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
1309 auto outputType = dyn_cast<VectorType>(outputVec.getType());
1310 // Reduce only if needed as the value may already have been reduced for
1311 // contraction vectorization.
1312 if (!reduceType ||
1313 (outputType && reduceType.getShape() == outputType.getShape()))
1314 return nullptr;
1315 SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
1316 return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
1317}
1318
1319/// Generic vectorization for a single operation `op`, given already vectorized
1320/// operands carried by `bvm`. Vectorization occurs as follows:
1321/// 1. Try to apply any of the `customVectorizationHooks` and return its
1322/// result on success.
1323/// 2. Clone any constant in the current scope without vectorization: each
1324/// consumer of the constant will later determine the shape to which the
1325/// constant needs to be broadcast.
1326/// 3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
1327/// of the `customVectorizationHooks` to cover such cases.
1328/// 4. Clone `op` in vector form to a vector of shape prescribed by the first
1329/// operand of maximal rank. Other operands have smaller rank and are
1330/// broadcast accordingly. It is assumed this broadcast is always legal,
1331/// otherwise, it means one of the `customVectorizationHooks` is incorrect.
1332///
1333/// This function assumes all operands of `op` have been vectorized and are in
1334/// the `bvm` mapping. As a consequence, this function is meant to be called on
1335/// a topologically-sorted list of ops.
1336/// This function does not update `bvm` but returns a VectorizationHookStatus
1337/// that instructs the caller what `bvm` update needs to occur.
1338static VectorizationHookResult
1339vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
1340 LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
1341 ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
1342 LDBG() << "vectorize op " << *op;
1343
1344 // 1. Try to apply any CustomVectorizationHook.
1345 if (!customVectorizationHooks.empty()) {
1346 for (auto &customFunc : customVectorizationHooks) {
1347 VectorizationHookResult result = customFunc(op, bvm);
1348 if (result.status == VectorizationHookStatus::Failure)
1349 continue;
1350 return result;
1351 }
1352 }
1353
1354 // 2. Constant ops don't get vectorized but rather broadcasted at their users.
1355 // Clone so that the constant is not confined to the linalgOp block.
1356 if (isa<arith::ConstantOp, func::ConstantOp>(op))
1357 return VectorizationHookResult{VectorizationHookStatus::NewOp,
1358 rewriter.clone(*op)};
1359
1360 // 3. Only ElementwiseMappable are allowed in the generic vectorization.
1361 if (!OpTrait::hasElementwiseMappableTraits(op))
1362 return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
1363
1364 // 4. Check if the operation is a reduction.
1365 SmallVector<std::pair<Value, Value>> reductionOperands;
1366 for (Value operand : op->getOperands()) {
1367 auto blockArg = dyn_cast<BlockArgument>(operand);
1368 if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
1369 blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
1370 continue;
1371 SmallVector<Operation *> reductionOps;
1372 Value reduceValue = matchReduction(
1373 linalgOp.getRegionOutputArgs(),
1374 blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
1375 if (!reduceValue)
1376 continue;
1377 reductionOperands.push_back(std::make_pair(reduceValue, operand));
1378 }
1379 if (!reductionOperands.empty()) {
1380 assert(reductionOperands.size() == 1);
1381 Operation *reduceOp =
1382 reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
1383 reductionOperands[0].second, bvm);
1384 if (reduceOp)
1385 return VectorizationHookResult{VectorizationHookStatus::NewOp, reduceOp};
1386 }
1387
1388 // 5. Generic vectorization path for ElementwiseMappable ops.
1389 // a. Get the first max ranked shape.
1390 VectorType firstMaxRankedType;
1391 for (Value operand : op->getOperands()) {
1392 auto vecOperand = bvm.lookup(operand);
1393 assert(vecOperand && "Vector operand couldn't be found");
1394
1395 auto vecType = dyn_cast<VectorType>(vecOperand.getType());
1396 if (vecType && (!firstMaxRankedType ||
1397 firstMaxRankedType.getRank() < vecType.getRank()))
1398 firstMaxRankedType = vecType;
1399 }
1400 // b. Broadcast each op if needed.
1401 SmallVector<Value> vecOperands;
1402 for (Value scalarOperand : op->getOperands()) {
1403 Value vecOperand = bvm.lookup(scalarOperand);
1404 assert(vecOperand && "Vector operand couldn't be found");
1405
1406 if (firstMaxRankedType) {
1407 auto vecType = VectorType::get(firstMaxRankedType.getShape(),
1408 getElementTypeOrSelf(vecOperand.getType()),
1409 firstMaxRankedType.getScalableDims());
1410 vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
1411 } else {
1412 vecOperands.push_back(vecOperand);
1413 }
1414 }
1415 // c. For elementwise ops, the result is the vector with the firstMaxRankedShape
1416 SmallVector<Type> resultTypes;
1417 for (Type resultType : op->getResultTypes()) {
1418 resultTypes.push_back(
1419 firstMaxRankedType
1420 ? VectorType::get(firstMaxRankedType.getShape(), resultType,
1421 firstMaxRankedType.getScalableDims())
1422 : resultType);
1423 }
1424  // d. Build and return the new op.
1425  return VectorizationHookResult{
1426      VectorizationHookStatus::NewOp,
1427      rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
1428                      resultTypes, op->getAttrs())};
1429}
1430
1431/// Generic vectorization function that rewrites the body of a `linalgOp` into
1432/// vector form. Generic vectorization proceeds as follows:
1433/// 1. Verify the `linalgOp` has one non-empty region.
1434/// 2. Values defined above the region are mapped to themselves and will be
1435/// broadcasted on a per-need basis by their consumers.
1436/// 3. Each region argument is vectorized into a vector.transfer_read (or 0-d
1437/// load).
1438/// TODO: Reuse opportunities for RAR dependencies.
1439/// 4a. Register CustomVectorizationHook for YieldOp to capture the results.
1440/// 4b. Register CustomVectorizationHook for IndexOp to access the
1441/// iteration indices.
1442/// 5. Iteratively call vectorizeOneOp on the region operations.
1443///
1444/// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is
1445/// performed to the maximal common vector size implied by the `linalgOp`
1446/// iteration space. This eager broadcasting is introduced in the
1447/// permutation_map of the vector.transfer_read operations. The eager
1448/// broadcasting makes it trivial to determine where broadcast, transposes and
1449/// reductions should occur, without any bookkeeping. The tradeoff is that, in
1450/// the absence of good canonicalizations, the amount of work increases.
1451/// This is not deemed a problem as we expect canonicalizations and foldings to
1452/// aggressively clean up the useless work.
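///
/// ILLUSTRATIVE EXAMPLE (not from the source; a minimal sketch of the result
/// of this routine on an element-wise linalg.generic with static shapes and
/// identity indexing maps; names are hypothetical):
/// ```
///   %add = linalg.generic {iterator_types = ["parallel", "parallel"], ...}
///     ins(%A, %B : tensor<8x16xf32>, tensor<8x16xf32>)
///     outs(%C : tensor<8x16xf32>) { ... arith.addf ... }
/// ```
/// is vectorized, roughly, as:
/// ```
///   %va = vector.transfer_read %A[%c0, %c0], %cst
///     : tensor<8x16xf32>, vector<8x16xf32>
///   %vb = vector.transfer_read %B[%c0, %c0], %cst
///     : tensor<8x16xf32>, vector<8x16xf32>
///   %vadd = arith.addf %va, %vb : vector<8x16xf32>
///   %res = vector.transfer_write %vadd, %C[%c0, %c0]
///     : vector<8x16xf32>, tensor<8x16xf32>
/// ```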
1453static LogicalResult
1454vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
1455 LinalgOp linalgOp,
1456 SmallVectorImpl<Value> &newResults) {
1457  LDBG() << "Vectorizing operation as linalg generic";
1458 Block *block = linalgOp.getBlock();
1459
1460 // 2. Values defined above the region can only be broadcast for now. Make them
1461 // map to themselves.
1462 IRMapping bvm;
1463 SetVector<Value> valuesSet;
1464 mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
1465 bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
1466
1467 if (linalgOp.getNumDpsInits() == 0)
1468 return failure();
1469
1470 // 3. Turn all BBArgs into vector.transfer_read / load.
1471 Location loc = linalgOp.getLoc();
1472 Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
1473 for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1474 BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
1475 if (linalgOp.isScalar(opOperand)) {
1476 bvm.map(bbarg, opOperand->get());
1477 continue;
1478 }
1479
1480 // 3.a. Convert the indexing map for this input/output to a transfer read
1481 // permutation map and masking map.
1482 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
1483
1484 AffineMap readMap;
1485 VectorType readType;
1486 Type elemType = getElementTypeOrSelf(opOperand->get());
1487 if (linalgOp.isDpsInput(opOperand)) {
1488 // 3.a.i. For input reads we use the canonical vector shape.
1489 readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
1490 readType = state.getCanonicalVecType(elemType);
1491 } else {
1492 // 3.a.ii. For output reads (iteration-carried dependence, e.g.,
1493 // reductions), the vector shape is computed by mapping the canonical
1494 // vector shape to the output domain and back to the canonical domain.
1495 readMap = inversePermutation(reindexIndexingMap(indexingMap));
1496 readType =
1497 state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
1498 }
1499
1500 SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
1501
1502 Operation *read = vector::TransferReadOp::create(
1503 rewriter, loc, readType, opOperand->get(), indices,
1504 /*padding=*/std::nullopt, readMap);
1505 read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
1506 Value readValue = read->getResult(0);
1507
1508 // 3.b. If masked, set in-bounds to true. Masking guarantees that the access
1509 // will be in-bounds.
1510 if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
1511 SmallVector<bool> inBounds(readType.getRank(), true);
1512 cast<vector::TransferReadOp>(maskOp.getMaskableOp())
1513 .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
1514 }
1515
1516 // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
1517 // TODO: remove this.
1518 if (readType.getRank() == 0)
1519      readValue = vector::ExtractOp::create(rewriter, loc, readValue,
1520                                            ArrayRef<int64_t>());
1521
1522 LDBG() << "New vectorized bbarg(" << bbarg.getArgNumber()
1523 << "): " << readValue;
1524 bvm.map(bbarg, readValue);
1525 bvm.map(opOperand->get(), readValue);
1526 }
1527
1528  SmallVector<CustomVectorizationHook> hooks;
1529  // 4a. Register CustomVectorizationHook for yieldOp.
1530 CustomVectorizationHook vectorizeYield =
1531 [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1532 return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
1533 };
1534 hooks.push_back(vectorizeYield);
1535
1536 // 4b. Register CustomVectorizationHook for indexOp.
1537 CustomVectorizationHook vectorizeIndex =
1538 [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1539 return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
1540 };
1541 hooks.push_back(vectorizeIndex);
1542
1543 // 4c. Register CustomVectorizationHook for extractOp.
1544 CustomVectorizationHook vectorizeExtract =
1545 [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1546 return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
1547 };
1548 hooks.push_back(vectorizeExtract);
1549
1550  // 5. Iteratively call `vectorizeOneOp` on each op in the slice.
1551 for (Operation &op : block->getOperations()) {
1552    VectorizationHookResult result =
1553        vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
1554    if (result.status == VectorizationHookStatus::Failure) {
1555      LDBG() << "failed to vectorize: " << op;
1556 return failure();
1557 }
1558 if (result.status == VectorizationHookStatus::NewOp) {
1559 Operation *maybeMaskedOp =
1560 state.maskOperation(rewriter, result.newOp, linalgOp);
1561 LDBG() << "New vector op: " << *maybeMaskedOp;
1562 bvm.map(op.getResults(), maybeMaskedOp->getResults());
1563 }
1564 }
1565
1566 return success();
1567}
1568
1569/// Determines whether a mask for xfer_write is trivially "all true"
1570///
1571/// Given all the inputs required to generate a mask (mask sizes and shapes),
1572/// and an xfer_write operation (write indices and the destination tensor
1573/// shape), determines whether the corresponding mask would be trivially
1574/// foldable (i.e., trivially "all true").
1575///
1576/// Use this method to avoid generating spurious masks and relying on
1577/// vectorization post-processing to remove them.
1578///
1579/// Pre-conditions for a mask to be trivially foldable:
1580/// * All involved shapes (mask + destination tensor) are static.
1581/// * All write indices are constant.
1582/// * All mask sizes are constant (including `arith.constant`).
1583///
1584/// If the pre-conditions are met, the method checks for each destination
1585/// dimension `d`:
1586/// (1) destDimSize[rankDiff + d] <= maskShape[d]
1587/// (2) destDimSize[rankDiff + d] <= writeIndex[d] + maskSize[d]
1588///
1589/// rankDiff = rank(dest) - rank(mask).
1590///
1591/// This method takes a conservative view: it may return false even if the mask
1592/// is technically foldable.
1593///
1594/// EXAMPLE 1 (trivially foldable, all shapes match, mask sizes match the shape
1595/// of the dest tensor):
1596/// %c0 = arith.constant 0 : index
1597/// %mask = vector.create_mask 5, 1
1598/// vector.mask %mask {
1599///    vector.transfer_write %vecToStore_1, %dest[%c0, %c0]
1600/// {in_bounds = [true, true]}
1601/// : vector<5x1xi32>, tensor<5x1xi32>
1602/// }
1603///
1604/// EXAMPLE 2 (not trivially foldable - vector shape exceeds the tensor shape,
1605/// mask is required to avoid out-of-bounds write):
1606/// %c0 = arith.constant 0 : index
1607/// %mask = vector.create_mask 5, 1
1608/// vector.mask %mask {
1609/// vector.transfer_write %vecToStore_2, %dest[%c0, %c0]
1610/// {in_bounds = [true, true]}
1611/// : vector<8x1xi32>, tensor<5x1xi32>
1612/// }
1613///
1614/// TODO: Re-use in createReadOrMaskedRead
1615static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
1616                                    SmallVector<Value> &writeIdxs,
1617 ArrayRef<int64_t> destShape,
1618 ArrayRef<int64_t> maskShape) {
1619 // Masking is unavoidable in the case of dynamic tensors.
1620 if (ShapedType::isDynamicShape(destShape))
1621 return false;
1622
1623 // Collect all constant mask sizes.
1624 SmallVector<int64_t, 4> cstMaskSizes;
1625 for (auto [i, dimSize] : llvm::enumerate(maskSizes)) {
1626 if (auto intSize = getConstantIntValue(dimSize)) {
1627 cstMaskSizes.push_back(*intSize);
1628 }
1629 }
1630
1631 // If any of the mask sizes is non-constant, bail out.
1632 if (cstMaskSizes.size() != maskShape.size())
1633 return false;
1634
1635 // Collect all constant write indices.
1636 SmallVector<int64_t, 4> cstWriteIdxs;
1637 for (auto [i, idx] : llvm::enumerate(writeIdxs)) {
1638 APSInt intVal;
1639 if (matchPattern(idx, m_ConstantInt(&intVal))) {
1640 cstWriteIdxs.push_back(intVal.getSExtValue());
1641 }
1642 }
1643
1644 // If any of the write indices is non-constant, bail out.
1645 if (cstWriteIdxs.size() != destShape.size())
1646 return false;
1647
1648 // Go over all destination dims and check (1) and (2). Take into account that:
1649 // * The number of mask sizes will match the rank of the vector to store.
1650 // This could be lower than the rank of the destination tensor.
1651 // * Mask sizes could be larger than the corresponding mask shape (hence
1652 // `clamp`).
1653 // TODO: The 2nd item should be rejected by the verifier.
1654 int64_t rankDiff = destShape.size() - cstMaskSizes.size();
1655 for (auto [i, idx] : llvm::enumerate(cstMaskSizes)) {
1656 if (/*(1)*/ maskShape[i] > destShape[rankDiff + i] ||
1657 /*(2)*/ destShape[rankDiff + i] <
1658 (std::clamp(cstMaskSizes[i], int64_t(0), maskShape[i]) +
1659 cstWriteIdxs[i]))
1660 return false;
1661 }
1662
1663 return true;
1664}
1665
1666/// Creates an optionally masked TransferWriteOp
1667///
1668/// Generates the following operation:
1669/// %res = vector.transfer_write %vecToStore into %dest
1670///
1671/// If shape(vecToStore) != shape(dest), masking is used to ensure correctness:
1672///
1673/// %mask = vector.create_mask(%destShape) : %vecToStoreShape
1674/// %res = vector.mask %mask {
1675/// vector.transfer_write %vecToStore into %dest
1676/// }
1677///
1678/// The mask shape is identical to `vecToStore` (with the element type ==
1679/// i1), and the mask values are based on the shape of the `dest` tensor.
1680///
1681/// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
1682/// is used instead of masking:
1683///
1684/// %write = vector.transfer_write %vecToStore into %dest
1685/// in_bounds_flags = (...)
1686/// %res = vector.transfer_write %input into %dest
1687/// {in_bounds = in_bounds_flags}
1688///
1689/// Finally, `writeIndices` specifies the offsets to use. If empty, all indices
1690/// are set to 0.
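///
/// ILLUSTRATIVE EXAMPLE (not from the source; a sketch of the masked case for
/// a hypothetical `%vecToStore : vector<4x8xf32>` written into
/// `%dest : tensor<?x8xf32>` at offset 0):
/// ```
///   %d0 = tensor.dim %dest, %c0 : tensor<?x8xf32>
///   %mask = vector.create_mask %d0, %c8 : vector<4x8xi1>
///   %res = vector.mask %mask {
///     vector.transfer_write %vecToStore, %dest[%c0, %c0]
///       {in_bounds = [true, true]} : vector<4x8xf32>, tensor<?x8xf32>
///   } : vector<4x8xi1> -> tensor<?x8xf32>
/// ```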
1691static Operation *
1692createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
1693                         Value dest, SmallVector<Value> writeIndices = {},
1694 bool useInBoundsInsteadOfMasking = false) {
1695
1696 ShapedType destType = cast<ShapedType>(dest.getType());
1697 int64_t destRank = destType.getRank();
1698 auto destShape = destType.getShape();
1699
1700 VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
1701 int64_t vecToStoreRank = vecToStoreType.getRank();
1702 auto vecToStoreShape = vecToStoreType.getShape();
1703
1704 // Compute the in_bounds attribute
1705 SmallVector<bool> inBoundsVal(vecToStoreRank, true);
1706 if (useInBoundsInsteadOfMasking) {
1707 // Update the inBounds attribute.
1708 // FIXME: This computation is too weak - it ignores the write indices.
1709 for (unsigned i = 0; i < vecToStoreRank; i++)
1710 inBoundsVal[i] =
1711 (destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) &&
1712 ShapedType::isStatic(destShape[destRank - vecToStoreRank + i]);
1713 }
1714
1715 // If missing, initialize the write indices to 0.
1716 bool useDefaultWriteIdxs = writeIndices.empty();
1717 assert((useDefaultWriteIdxs ||
1718 writeIndices.size() == static_cast<size_t>(destRank)) &&
1719 "Invalid number of write indices!");
1720 if (writeIndices.empty()) {
1721 auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
1722 writeIndices.assign(destRank, zero);
1723 }
1724
1725 // Generate the xfer_write Op
1726 Operation *write = vector::TransferWriteOp::create(builder, loc,
1727 /*vector=*/vecToStore,
1728 /*source=*/dest,
1729 /*indices=*/writeIndices,
1730 /*inBounds=*/inBoundsVal);
1731
1732 // If masking is disabled, exit.
1733 if (useInBoundsInsteadOfMasking)
1734 return write;
1735
1736 // Check if masking is needed. If not, exit.
1737 if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
1738 return write;
1739
1740 // Compute the mask and mask the write Op.
1741 auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type(),
1742 vecToStoreType.getScalableDims());
1743
1744 SmallVector<OpFoldResult> destSizes =
1745 isa<MemRefType>(dest.getType())
1746 ? memref::getMixedSizes(builder, loc, dest)
1747 : tensor::getMixedSizes(builder, loc, dest);
1748
1749 // Compute sizes for write-mask
1750 SmallVector<OpFoldResult> maskSizes;
1751 if (useDefaultWriteIdxs) {
1752 maskSizes = SmallVector<OpFoldResult>(destSizes.end() - vecToStoreRank,
1753 destSizes.end());
1754 } else {
1755 size_t diff = destShape.size() - vecToStoreRank;
1756 for (int64_t idx = 0; idx < vecToStoreRank; idx++) {
1757 auto value =
1758 getValueOrCreateConstantIndexOp(builder, loc, destSizes[diff + idx]);
1759 auto neg =
1760 builder.createOrFold<arith::SubIOp>(loc, value, writeIndices[idx]);
1761 maskSizes.push_back(OpFoldResult(neg));
1762 }
1763 }
1764
1765 if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
1766 vecToStoreShape))
1767 return write;
1768
1769 Value maskForWrite =
1770 builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
1771 return mlir::vector::maskOperation(builder, write, maskForWrite);
1772}
1773
1774/// Given the re-associations, "collapses" the input Vector type
1775///
1776/// This is similar to CollapseShapeOp::inferCollapsedType with two notable
1777/// differences:
1778/// * We can safely assume that there are no dynamic sizes.
1779/// * Scalable flags are updated alongside regular dims.
1780///
1781/// When collapsing scalable flags, this conservatively avoids cases with two
1782/// scalable dims. We could re-visit this in the future.
1783///
1784/// EXAMPLE:
1785/// type = vector<4x16x[8]x16xf32>
1786/// reassociation = [(d0, d1, d2, d3) -> (d0, d1),
1787/// (d0, d1, d2, d3) -> (d2, d3)]
1788/// Result:
1789/// vector<64x[128]xf32>
1790static VectorType getCollapsedVecType(VectorType type,
1791 ArrayRef<AffineMap> reassociation) {
1792 assert(type.getNumScalableDims() < 2 &&
1793 "Collapsing more than 1 scalable dim is not supported ATM");
1794
1795 // Use the fact that reassociation is valid to simplify the logic: only use
1796 // each map's rank.
1797 assert(isReassociationValid(reassociation) && "invalid reassociation");
1798
1799 auto shape = type.getShape();
1800 auto scalableFlags = type.getScalableDims();
1801 SmallVector<int64_t> newShape;
1802 SmallVector<bool> newScalableFlags;
1803
1804 unsigned currentDim = 0;
1805 for (AffineMap m : reassociation) {
1806 unsigned dim = m.getNumResults();
1807 int64_t size = 1;
1808 bool flag = false;
1809 for (unsigned d = 0; d < dim; ++d) {
1810 size *= shape[currentDim + d];
1811 flag |= scalableFlags[currentDim + d];
1812 }
1813 newShape.push_back(size);
1814 newScalableFlags.push_back(flag);
1815 currentDim += dim;
1816 }
1817
1818 return VectorType::get(newShape, type.getElementType(), newScalableFlags);
1819}
1820
1821/// Vectorize `linalg.pack` as:
1822/// * xfer_read -> shape_cast -> transpose -> xfer_write
1823///
1824/// The input-vector-sizes specify the _write_ vector sizes (i.e. the vector
1825/// sizes for the xfer_write operation). This is sufficient to infer the other
1826/// vector sizes required here.
1827///
1828/// If the vector sizes are not provided:
1829/// * the vector sizes are determined from the destination tensor static shape.
1830/// * the inBounds attribute is used instead of masking.
1831///
1832/// EXAMPLE (no vector sizes):
1833/// ```
1834/// %pack = tensor.pack %src
1835/// inner_dims_pos = [2, 1]
1836/// inner_tiles = [16, 2]
1837/// into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
1838/// ```
1839/// is vectorized as:
1840/// ```
1841/// %read = vector.transfer_read %src
1842///    : tensor<32x8x16xf32>, vector<32x8x16xf32>
1843/// %sc = vector.shape_cast %read
1844/// : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
1845/// %tr = vector.transpose %sc, [0, 1, 3, 4, 2]
1846/// : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
1847/// %write = vector.transfer_write %tr into %dest
1848/// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
1849/// ```
1850static LogicalResult
1851vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
1852 ArrayRef<int64_t> inputVectorSizes,
1853 SmallVectorImpl<Value> &newResults) {
1854 if (!inputVectorSizes.empty()) {
1855 assert(inputVectorSizes.size() == packOp.getDestRank() &&
1856 "Invalid number of input vector sizes!");
1857 }
1858
1859 // TODO: Introduce a parent class that will handle the insertion point update.
1860 OpBuilder::InsertionGuard g(rewriter);
1861 rewriter.setInsertionPoint(packOp);
1862
1863 Location loc = packOp.getLoc();
1864 std::optional<Value> padValue = packOp.getPaddingValue()
1865 ? std::optional(packOp.getPaddingValue())
1866 : std::nullopt;
1867
1868 SmallVector<int64_t> destShape =
1869 SmallVector<int64_t>(packOp.getDestType().getShape());
1870
1871 // This is just a convenience alias to clearly communicate that the input
1872 // vector sizes determine the _write_ sizes.
1873 ArrayRef<int64_t> &writeVectorSizes = inputVectorSizes;
1874
1875 // In the absence of input-vector-sizes, use the _static_ input tensor shape.
1876 // In addition, use the inBounds attribute instead of masking.
1877 bool useInBoundsInsteadOfMasking = false;
1878 if (writeVectorSizes.empty()) {
1879 if (ShapedType::isDynamicShape(destShape))
1880 return rewriter.notifyMatchFailure(packOp,
1881 "unable to infer vector sizes");
1882
1883 writeVectorSizes = destShape;
1884 useInBoundsInsteadOfMasking = true;
1885 }
1886
1887 // Compute pre-transpose-write-vector-type, i.e. the write vector type
1888 // _before_ the transposition (i.e. before dimension permutation). This is
1889 // done by inverting the permutation/transposition that's part of the Pack
1890 // operation. This type is required to:
1891 // 1) compute the read vector type for masked-read below, and
1892 // 2) generate shape-cast Op below that expands the read vector type.
1893 PackingMetadata packMetadata;
1894  SmallVector<int64_t> preTransposeWriteVecSizes(writeVectorSizes);
1895  auto destInvPermutation = getPackInverseDestPerm(packOp, packMetadata);
1896  applyPermutationToVector(preTransposeWriteVecSizes, destInvPermutation);
1897  auto preTransposeWriteVecType =
1898      VectorType::get(preTransposeWriteVecSizes,
1899 packOp.getResult().getType().getElementType());
1900
1901  // Compute vector type for the _read_ operation. This is simply
1902 // pre-transpose-write-vector-type with the dimensions collapsed
1903 // as per the Pack operation.
1904 VectorType readVecType = getCollapsedVecType(
1905 preTransposeWriteVecType,
1906      getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1907          rewriter.getContext(), packMetadata.reassociations)));
1908
1909 // Create masked TransferReadOp.
1910 auto maskedRead = vector::createReadOrMaskedRead(
1911 rewriter, loc, packOp.getSource(), readVecType, padValue,
1912 useInBoundsInsteadOfMasking);
1913
1914 // Create ShapeCastOp.
1915 auto shapeCastOp = vector::ShapeCastOp::create(
1916 rewriter, loc, preTransposeWriteVecType, maskedRead);
1917
1918 // Create TransposeOp.
1919 auto destPermutation = invertPermutationVector(destInvPermutation);
1920 auto transposeOp = vector::TransposeOp::create(
1921 rewriter, loc, shapeCastOp.getResult(), destPermutation);
1922
1923 // Create TransferWriteOp.
1924 Operation *write = createWriteOrMaskedWrite(
1925 rewriter, loc, transposeOp.getResult(), packOp.getDest());
1926 newResults.push_back(write->getResult(0));
1927 return success();
1928}
1929
1930/// Vectorize `linalg.unpack` as:
1931/// * xfer_read -> vector.transpose -> vector.shape_cast -> xfer_write
1932///
1933/// The input-vector-sizes specify the _read_ vector sizes (i.e. the vector
1934/// sizes for the xfer_read operation). This is sufficient to infer the other
1935/// vector sizes required here.
1936///
1937/// If the vector sizes are not provided:
1938/// * the vector sizes are determined from the input tensor static shape.
1939/// * the inBounds attribute is used instead of masking.
1940///
1941/// EXAMPLE (no vector sizes):
1942/// ```
1943/// %unpack = linalg.unpack %src
1944/// inner_dims_pos = [0, 1]
1945/// inner_tiles = [8, 8]
1946/// into %dest : tensor<1x1x8x8xf32> -> tensor<8x8xf32>
1947/// ```
1948/// is vectorized as:
1949/// ```
1950/// %read = vector.transfer_read %src
1951/// : tensor<1x1x8x8xf32>, vector<1x1x8x8xf32>
1952/// %tr = vector.transpose %read, [0, 2, 1, 3]
1953/// : vector<1x1x8x8xf32> to vector<1x8x1x8xf32>
1954/// %sc = vector.shape_cast %tr
1955/// : vector<1x8x1x8xf32> to vector<8x8xf32>
1956/// %vector = vector.transfer_write %sc into %dest
1957/// : vector<8x8xf32>, tensor<8x8xf32>
1958/// ```
1959static LogicalResult
1960vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
1961 ArrayRef<int64_t> inputVectorSizes,
1962 ArrayRef<bool> inputScalableVecDims,
1963 SmallVectorImpl<Value> &newResults) {
1964 if (!inputVectorSizes.empty()) {
1965 assert(inputVectorSizes.size() == unpackOp.getSourceRank() &&
1966 "Invalid number of input vector sizes!");
1967 assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
1968 "Incompatible number of vector sizes and vector scalable flags!");
1969 }
1970
1971 // TODO: Introduce a parent class that will handle the insertion point update.
1972 OpBuilder::InsertionGuard g(rewriter);
1973 rewriter.setInsertionPoint(unpackOp);
1974
1975 ShapedType unpackTensorType = unpackOp.getSourceType();
1976
1977 ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
1978 bool useInBoundsInsteadOfMasking = false;
1979
1980 Location loc = unpackOp->getLoc();
1981
1982 // Obtain vector sizes for the read operation.
1983 SmallVector<int64_t> readVectorSizes(inputVectorSizes);
1984 SmallVector<bool> readScalableVectorFlags(inputScalableVecDims);
1985
1986 // In the absence of input-vector-sizes, use the _static_ input tensor shape.
1987 if (inputVectorSizes.empty()) {
1988 if (ShapedType::isDynamicShape(sourceShape))
1989 return rewriter.notifyMatchFailure(unpackOp,
1990 "Unable to infer vector sizes!");
1991
1992 readVectorSizes.assign(sourceShape.begin(), sourceShape.end());
1993 useInBoundsInsteadOfMasking = true;
1994 }
1995
1996 // -- Generate the read operation --
1997 VectorType readVecType =
1998 VectorType::get(readVectorSizes, unpackTensorType.getElementType(),
1999 readScalableVectorFlags);
2000 Value readResult = vector::createReadOrMaskedRead(
2001 rewriter, loc, unpackOp.getSource(), readVecType, std::nullopt,
2002 useInBoundsInsteadOfMasking);
2003
2004 // -- Generate the transpose operation --
2005 PackingMetadata packMetadata;
2006 SmallVector<int64_t> lastDimToInsertPosPerm =
2007 getUnPackInverseSrcPerm(unpackOp, packMetadata);
2008 vector::TransposeOp transposeOp = vector::TransposeOp::create(
2009 rewriter, loc, readResult, lastDimToInsertPosPerm);
2010
2011 // -- Generate the shape_cast operation --
2012 VectorType collapsedVecType = getCollapsedVecType(
2013 transposeOp.getType(),
2014      getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
2015          rewriter.getContext(), packMetadata.reassociations)));
2016 vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create(
2017 rewriter, loc, collapsedVecType, transposeOp->getResult(0));
2018
2019 // -- Generate the write operation --
2020 Operation *write = createWriteOrMaskedWrite(
2021 rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
2022 /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
2023
2024 newResults.push_back(write->getResult(0));
2025 return success();
2026}
2027
2028/// Vectorize a `padOp` with (1) static result type, (2) constant padding value
2029/// and (3) all-zero lowPad to
2030/// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
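/// ILLUSTRATIVE EXAMPLE (not from the source; a sketch assuming input vector
/// sizes [8, 16] and a scalar constant pad value %cst):
/// ```
///   %pad = tensor.pad %src low[0, 0] high[%h0, %h1] {
///     ^bb0(%i: index, %j: index):
///       tensor.yield %cst : f32
///   } : tensor<?x?xf32> to tensor<8x16xf32>
/// ```
/// becomes, roughly (with %mask built from the dynamic sizes of %src):
/// ```
///   %read = vector.mask %mask {
///     vector.transfer_read %src[%c0, %c0], %cst
///       : tensor<?x?xf32>, vector<8x16xf32>
///   } : vector<8x16xi1> -> vector<8x16xf32>
///   %init = tensor.empty() : tensor<8x16xf32>
///   %res = vector.transfer_write %read, %init[%c0, %c0]
///     {in_bounds = [true, true]} : vector<8x16xf32>, tensor<8x16xf32>
/// ```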
2031static LogicalResult
2032vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
2033 ArrayRef<int64_t> inputVectorSizes,
2034 SmallVectorImpl<Value> &newResults) {
2035 auto padValue = padOp.getConstantPaddingValue();
2036 Location loc = padOp.getLoc();
2037
2038 // TODO: Introduce a parent class that will handle the insertion point update.
2039 OpBuilder::InsertionGuard g(rewriter);
2040 rewriter.setInsertionPoint(padOp);
2041
2042 ReifiedRankedShapedTypeDims reifiedReturnShapes;
2043 LogicalResult status =
2044 cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
2045 .reifyResultShapes(rewriter, reifiedReturnShapes);
2046 (void)status; // prevent unused variable warning on non-assert builds
2047 assert(succeeded(status) && "failed to reify result shapes");
2048 auto readType = VectorType::get(inputVectorSizes, padValue.getType());
2049 auto maskedRead = vector::createReadOrMaskedRead(
2050 rewriter, loc, padOp.getSource(), readType, padValue,
2051 /*useInBoundsInsteadOfMasking=*/false);
2052
2053 // Create Xfer write Op
2054 Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0],
2055 padOp.getResultType().getElementType());
2056 Operation *write = createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest);
2057 newResults.push_back(write->getResult(0));
2058 return success();
2059}
2060
2061// TODO: probably need some extra checks for reduction followed by consumer
2062// ops that may not commute (e.g. linear reduction + non-linear instructions).
2063static LogicalResult reductionPreconditions(LinalgOp op) {
2064 if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
2065 LDBG() << "reduction precondition failed: no reduction iterator";
2066 return failure();
2067 }
2068 for (OpOperand &opOperand : op.getDpsInitsMutable()) {
2069 AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
2070 if (indexingMap.isPermutation())
2071 continue;
2072
2073 Operation *reduceOp = matchLinalgReduction(&opOperand);
2074 if (!reduceOp || !getCombinerOpKind(reduceOp)) {
2075 LDBG() << "reduction precondition failed: reduction detection failed";
2076 return failure();
2077 }
2078 }
2079 return success();
2080}
2081
2082static LogicalResult
2083vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
2084 bool flatten1DDepthwiseConv) {
2085 if (flatten1DDepthwiseConv) {
2086 LDBG() << "Vectorization of flattened convs with dynamic shapes is not "
2087 "supported";
2088 return failure();
2089 }
2090
2091  if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
2092    LDBG() << "Not a 1D depth-wise WC conv, dynamic shapes are not supported";
2093 return failure();
2094 }
2095
2096 // Support dynamic shapes in 1D depthwise convolution, but only in the
2097 // _channel_ dimension.
2098 Value lhs = conv.getDpsInputOperand(0)->get();
2099 ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
2100 auto shapeWithoutCh = lhsShape.drop_back(1);
2101 if (ShapedType::isDynamicShape(shapeWithoutCh)) {
2102 LDBG() << "Dynamically-shaped op vectorization precondition failed: only "
2103 "channel dim can be dynamic";
2104 return failure();
2105 }
2106
2107 return success();
2108}
2109
2110static LogicalResult
2111vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
2112 bool flatten1DDepthwiseConv) {
2113 if (isa<ConvolutionOpInterface>(op.getOperation()))
2114 return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);
2115
2116 if (hasReductionIterator(op))
2117 return reductionPreconditions(op);
2118
2119 // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
2120 // linalg.copy ops and ops that implement ContractionOpInterface for now.
2121 if (!isElementwise(op) &&
2122 !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
2123 op.getOperation()))
2124 return failure();
2125
2126 LDBG() << "Dynamically-shaped op meets vectorization pre-conditions";
2127 return success();
2128}
2129
2130/// This hook considers two cases:
2131/// (1) If the input-vector-sizes are empty, then the vector sizes will be
2132///     inferred. This is only possible when all shapes are static.
2133/// (2) If the input-vector-sizes are non-empty (i.e. user-provided), then
2134///     carry out basic sanity-checking.
2135static LogicalResult
2136vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
2137 ArrayRef<int64_t> inputVectorSizes) {
2138 // TODO: Support Memref UnPackOp. Temporarily return failure.
2139 if (!unpackOp.hasPureTensorSemantics())
2140 return failure();
2141
2142 // If there are no input vector sizes and all shapes are static, there is
2143 // nothing left to check.
2144 if (inputVectorSizes.empty() && unpackOp.getDestType().hasStaticShape() &&
2145 unpackOp.getSourceType().hasStaticShape())
2146 return success();
2147
2148 // The number of input vector sizes must be equal to:
2149 // * read-vector-rank
2150 if (!inputVectorSizes.empty() &&
2151 (inputVectorSizes.size() != unpackOp.getSourceRank())) {
2152 LDBG() << "Incorrect number of input vector sizes";
2153 return failure();
2154 }
2155
2156 // Check the vector sizes for the read operation.
2157  if (failed(vector::isValidMaskedInputVector(
2158          unpackOp.getSourceType().getShape(), inputVectorSizes))) {
2159 LDBG() << "Invalid vector sizes for the read operation";
2160 return failure();
2161 }
2162
2163 return success();
2164}
2165
2166static LogicalResult
2167vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
2168 ArrayRef<int64_t> inputVectorSizes) {
2169
2170 TypedValue<RankedTensorType> source = sliceOp.getSource();
2171 auto sourceType = source.getType();
2172 if (!VectorType::isValidElementType(sourceType.getElementType()))
2173 return failure();
2174
2175 // Get the pad value.
2176 // TransferReadOp (which is used to vectorize InsertSliceOp), requires a
2177 // scalar padding value. Note that:
2178 // * for in-bounds accesses,
2179 // the value is actually irrelevant. There are 2 cases in which xfer.read
2180 // accesses are known to be in-bounds:
2181 // 1. The source shape is static (output vector sizes would be based on
2182 // the source shape and hence all memory accesses would be in-bounds),
2183 // 2. Masking is used, i.e. the output vector sizes are user-provided. In
2184 // this case it is safe to assume that all memory accesses are in-bounds.
2185 //
2186 // When the value is not known and not needed, use 0. Otherwise, bail out.
2187 Value padValue = getStaticPadVal(sliceOp);
2188 bool isOutOfBoundsRead =
2189 !sourceType.hasStaticShape() && inputVectorSizes.empty();
2190
2191 if (!padValue && isOutOfBoundsRead) {
2192 LDBG() << "Failed to get a pad value for out-of-bounds read access";
2193 return failure();
2194 }
2195 return success();
2196}
2197
2198/// Vectorize a named linalg contraction op into:
2199/// vector::TransferReadOp - Reads vectors from the operands
2200/// vector::ContractionOp - Performs contraction
2201/// vector::TransferWriteOp - Write the result vector back to the
2202/// destination
2203/// The operands shapes are preserved and loaded directly into vectors.
2204/// Any further permutations or numerical casting remain within contraction op.
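///
/// ILLUSTRATIVE EXAMPLE (not from the source; a sketch for a static-shape
/// linalg.matmul, with indexing maps elided):
/// ```
///   linalg.matmul ins(%A, %B : tensor<4x8xf32>, tensor<8x16xf32>)
///                 outs(%C : tensor<4x16xf32>) -> tensor<4x16xf32>
/// ```
/// becomes, roughly:
/// ```
///   %va = vector.transfer_read %A[...] : tensor<4x8xf32>, vector<4x8xf32>
///   %vb = vector.transfer_read %B[...] : tensor<8x16xf32>, vector<8x16xf32>
///   %vc = vector.transfer_read %C[...] : tensor<4x16xf32>, vector<4x16xf32>
///   %r  = vector.contract {indexing_maps = [...],
///           iterator_types = ["parallel", "parallel", "reduction"],
///           kind = #vector.kind<add>}
///           %va, %vb, %vc
///         : vector<4x8xf32>, vector<8x16xf32> into vector<4x16xf32>
///   %res = vector.transfer_write %r, %C[...]
///         : vector<4x16xf32>, tensor<4x16xf32>
/// ```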
2205static LogicalResult
2206vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
2207 LinalgOp linalgOp,
2208 SmallVectorImpl<Value> &newResults) {
2209 Location loc = linalgOp.getLoc();
2210 MLIRContext *ctx = linalgOp.getContext();
2211
2212 // For simplicity, contraction vectorization is limited to linalg named ops.
2213 // Generic op is ignored as not every arbitrary contraction body can be
2214 // expressed by a vector.contract.
2215 if (!isa<ContractionOpInterface>(linalgOp.getOperation()))
2216 return failure();
2217
2218 OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
2219 Operation *reduceOp = matchLinalgReduction(outOperand);
2220 auto maybeKind = getCombinerOpKind(reduceOp);
2221 if (!maybeKind) {
2222 LDBG() << "Failed to determine contraction combining kind.";
2223 return failure();
2224 }
2225
2226 // Check that all dimensions are present in the input operands.
2227 // Arbitrary broadcasts are not supported by the vector contraction.
2228 // Broadcasts are expected to be decomposed before vectorization.
2229 AffineMap lhsMap = linalgOp.getIndexingMapsArray()[0];
2230 AffineMap rhsMap = linalgOp.getIndexingMapsArray()[1];
2231 if (getUnusedDimsBitVector({lhsMap, rhsMap}).any()) {
2232 LDBG() << "Contractions with broadcasts are not supported.";
2233 return failure();
2234 }
2235
2236 // Load operands.
2237 SmallVector<Value> vecOperands;
2238 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2239 // The operand vector shape is computed by mapping the canonical vector
2240 // shape to the operand's domain. Further permutations are left as a part of
2241 // the contraction.
2242 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
2243 AffineMap readMap = AffineMap::getMultiDimIdentityMap(
2244 indexingMap.getNumResults(), rewriter.getContext());
2245 Type elemType = getElementTypeOrSelf(opOperand.get());
2246 VectorType readType =
2247 state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
2248
2249    Value read = vector::createReadOrMaskedRead(
2250        rewriter, loc, opOperand.get(), readType,
2251 /*padding=*/arith::getZeroConstant(rewriter, loc, elemType),
2252 /*useInBoundsInsteadOfMasking=*/false);
2253 vecOperands.push_back(read);
2254 }
2255
2256 // Remap iterators from linalg to vector.
2257 SmallVector<Attribute> iterAttrs;
2258 auto iterators = linalgOp.getIteratorTypesArray();
2259 for (utils::IteratorType iter : iterators) {
2260 auto vecIter = iter == utils::IteratorType::parallel
2261 ? vector::IteratorType::parallel
2262 : vector::IteratorType::reduction;
2263 iterAttrs.push_back(vector::IteratorTypeAttr::get(ctx, vecIter));
2264 }
2265
2266 // Create contraction.
2267 Operation *contractOp = vector::ContractionOp::create(
2268 rewriter, loc, /*lhs=*/vecOperands[0],
2269 /*rhs=*/vecOperands[1], /*acc=*/vecOperands[2],
2270 linalgOp.getIndexingMaps(), rewriter.getArrayAttr(iterAttrs), *maybeKind);
2271 contractOp = state.maskOperation(rewriter, contractOp, linalgOp);
2272
2273 // Store result.
2274 Operation *write = createWriteOrMaskedWrite(
2275 rewriter, loc, contractOp->getResult(0), outOperand->get());
2276
2277 // Finalize.
2278 if (!write->getResults().empty())
2279 newResults.push_back(write->getResult(0));
2280
2281 return success();
2282}
2283
2284namespace {
2285enum class ConvOperationKind { Conv, Pool };
2286} // namespace
2287
2288static bool isCastOfBlockArgument(Operation *op) {
2289 return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
2290 isa<BlockArgument>(op->getOperand(0));
2291}
2292
2293// Returns the ConvOperationKind of the op using reduceOp of the generic
2294// payload. If it is neither a convolution nor a pooling, it returns
2295// std::nullopt.
2296//
2297// If (region has 2 ops (reduction + yield) or 3 ops (extension + reduction
2298// + yield) and rhs is not used) then it is the body of a pooling
2299// If conv, check for single `mul` predecessor. The `mul` operands must be
2300// block arguments or extension of block arguments.
2301// Otherwise, check for one or zero `ext` predecessor. The `ext` operands
2302// must be block arguments or extension of block arguments.
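//
// ILLUSTRATIVE EXAMPLE (not from the source; sketches of generic bodies this
// classification distinguishes):
//   Conv-like body (one block-argument operand feeds the reduction, the other
//   operand comes from a mul):
//     ^bb0(%in: f32, %flt: f32, %acc: f32):
//       %m = arith.mulf %in, %flt : f32
//       %a = arith.addf %acc, %m : f32
//       linalg.yield %a : f32
//   Pool-like body (two block-argument operands, rhs unused):
//     ^bb0(%in: f32, %unused: f32, %acc: f32):
//       %a = arith.addf %acc, %in : f32
//       linalg.yield %a : f32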
2303static std::optional<ConvOperationKind>
2304getConvOperationKind(Operation *reduceOp) {
2305 int numBlockArguments =
2306 llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
2307
2308 switch (numBlockArguments) {
2309 case 1: {
2310 // Will be convolution if feeder is a MulOp.
2311 // A strength reduced version of MulOp for i1 type is AndOp which is also
2312 // supported. Otherwise, it can be pooling. This strength reduction logic
2313 // is in `buildBinaryFn` helper in the Linalg dialect.
2314 auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
2315 llvm::IsaPred<BlockArgument>);
2316 assert(feedValIt != reduceOp->operand_end() &&
2317 "Expected a non-block argument operand");
2318 Operation *feedOp = (*feedValIt).getDefiningOp();
2319 if (isCastOfBlockArgument(feedOp)) {
2320 return ConvOperationKind::Pool;
2321 }
2322
2323 if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
2324 (isa<arith::AndIOp>(feedOp) &&
2325 feedOp->getResultTypes()[0].isInteger(1))) &&
2326 llvm::all_of(feedOp->getOperands(), [](Value v) {
2327 if (isa<BlockArgument>(v))
2328 return true;
2329 if (Operation *op = v.getDefiningOp())
2330 return isCastOfBlockArgument(op);
2331 return false;
2332 }))) {
2333 return std::nullopt;
2334 }
2335
2336 return ConvOperationKind::Conv;
2337 }
2338 case 2:
2339 // Must be pooling
2340 return ConvOperationKind::Pool;
2341 default:
2342 return std::nullopt;
2343 }
2344}
2345
2346static bool isSupportedPoolKind(vector::CombiningKind kind) {
2347 switch (kind) {
2348 case vector::CombiningKind::ADD:
2349 case vector::CombiningKind::MAXNUMF:
2350 case vector::CombiningKind::MAXIMUMF:
2351 case vector::CombiningKind::MAXSI:
2352 case vector::CombiningKind::MAXUI:
2353 case vector::CombiningKind::MINNUMF:
2354 case vector::CombiningKind::MINIMUMF:
2355 case vector::CombiningKind::MINSI:
2356 case vector::CombiningKind::MINUI:
2357 return true;
2358 default:
2359 return false;
2360 }
2361}
2362
2363static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
2364 auto getOperandType = [&](auto operand) {
2365 return dyn_cast<ShapedType>((operand->get()).getType());
2366 };
2367 ShapedType lhsShapedType = getOperandType(convOp.getDpsInputOperand(0));
2368 ShapedType rhsShapedType = getOperandType(convOp.getDpsInputOperand(1));
2369 ShapedType resShapedType = getOperandType(convOp.getDpsInitOperand(0));
2370 // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
2371 // (non-channeled convolution -> LHS and RHS both have single dimensions).
2372 // Note that this also ensures 2D and 3D convolutions are rejected.
2373 if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
2374 (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
2375 return failure();
2376
2377 Operation *reduceOp = matchLinalgReduction(convOp.getDpsInitOperand(0));
2378 if (!reduceOp)
2379 return failure();
2380
2381 auto maybeOper = getConvOperationKind(reduceOp);
2382 if (!maybeOper.has_value())
2383 return failure();
2384
2385 auto maybeKind = getCombinerOpKind(reduceOp);
2386 // Typically convolution will have a `Add` CombiningKind but for i1 type it
2387 // can get strength reduced to `OR` which is also supported. This strength
2388 // reduction logic is in `buildBinaryFn` helper in the Linalg dialect.
2389 if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
2390 *maybeKind != vector::CombiningKind::OR) &&
2391 (*maybeOper != ConvOperationKind::Pool ||
2392 !isSupportedPoolKind(*maybeKind)))) {
2393 return failure();
2394 }
2395
2396 auto rhsRank = rhsShapedType.getRank();
2397 if (*maybeOper == ConvOperationKind::Pool) {
2398 if (rhsRank != 1)
2399 return failure();
2400 } else {
2401 if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
2402 return failure();
2403 }
2404
2405 return success();
2406}
2407
2408static LogicalResult vectorizeLinalgOpPrecondition(
2409 LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
2410 bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
2411 // tensor with dimension of 0 cannot be vectorized.
2412 if (llvm::any_of(linalgOp->getOpOperands(), [&](OpOperand &operand) {
2413 return llvm::is_contained(linalgOp.getShape(&operand), 0);
2414 }))
2415 return failure();
2416 // Check API contract for input vector sizes.
2417 if (!inputVectorSizes.empty() &&
2418 failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
2419 inputVectorSizes)))
2420 return failure();
2421
2422 if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition(
2423 linalgOp, flatten1DDepthwiseConv))) {
2424 LDBG() << "Dynamically-shaped op failed vectorization pre-conditions";
2425 return failure();
2426 }
2427
2428 SmallVector<CustomVectorizationPrecondition> customPreconditions;
2429
2430 // Register CustomVectorizationPrecondition for extractOp.
2431 customPreconditions.push_back(tensorExtractVectorizationPrecondition);
2432
2433 // All types in the body should be a supported element type for VectorType.
2434 for (Operation &innerOp : linalgOp->getRegion(0).front()) {
2435 // Check if any custom hook can vectorize the inner op.
2436 if (llvm::any_of(
2437 customPreconditions,
2438 [&](const CustomVectorizationPrecondition &customPrecondition) {
2439 return succeeded(
2440 customPrecondition(&innerOp, vectorizeNDExtract));
2441 })) {
2442 continue;
2443 }
2444 if (!llvm::all_of(innerOp.getOperandTypes(),
2445 VectorType::isValidElementType)) {
2446 return failure();
2447 }
2448 if (!llvm::all_of(innerOp.getResultTypes(),
2449 VectorType::isValidElementType)) {
2450 return failure();
2451 }
2452 }
2453 if (isElementwise(linalgOp))
2454 return success();
2455
2456 // Check for both named as well as generic convolution ops.
2457 if (isaConvolutionOpInterface(linalgOp))
2458 return vectorizeConvOpPrecondition(linalgOp);
2459
2460 // TODO: the common vector shape is equal to the static loop sizes only when
2461 // all indexing maps are projected permutations. For convs and stencils the
2462 // logic will need to evolve.
2463 if (!allIndexingsAreProjectedPermutation(linalgOp)) {
2464 LDBG() << "precondition failed: not projected permutations";
2465 return failure();
2466 }
2467 if (failed(reductionPreconditions(linalgOp))) {
2468 LDBG() << "precondition failed: reduction preconditions";
2469 return failure();
2470 }
2471 return success();
2472}
2473
2474static LogicalResult
2475vectorizePackOpPrecondition(linalg::PackOp packOp,
2476 ArrayRef<int64_t> inputVectorSizes) {
2477 // TODO: Support Memref PackOp. Temporarily return failure.
2478 if (!packOp.hasPureTensorSemantics())
2479 return failure();
2480
2481 auto padValue = packOp.getPaddingValue();
2482 Attribute cstAttr;
2483  // TODO: Relax this condition
2484 if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
2485 LDBG() << "pad value is not constant: " << packOp;
2486 return failure();
2487 }
2488
2489 ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
2490 bool satisfyEmptyCond = true;
2491 if (inputVectorSizes.empty()) {
2492 if (!packOp.getDestType().hasStaticShape() ||
2493 !packOp.getSourceType().hasStaticShape())
2494 satisfyEmptyCond = false;
2495 }
2496
2497 if (!satisfyEmptyCond &&
2498      failed(vector::isValidMaskedInputVector(
2499          resultTensorShape.take_front(packOp.getSourceRank()),
2500 inputVectorSizes)))
2501 return failure();
2502
2503 if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
2504 return !getConstantIntValue(v).has_value();
2505 })) {
2506 LDBG() << "inner_tiles must be constant: " << packOp;
2507 return failure();
2508 }
2509
2510 return success();
2511}
2512
2513static LogicalResult
2514vectorizePadOpPrecondition(tensor::PadOp padOp,
2515 ArrayRef<int64_t> inputVectorSizes) {
2516 auto padValue = padOp.getConstantPaddingValue();
2517 if (!padValue) {
2518 LDBG() << "pad value is not constant: " << padOp;
2519 return failure();
2520 }
2521
2522 ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
2523 if (failed(vector::isValidMaskedInputVector(resultTensorShape,
2524 inputVectorSizes)))
2525 return failure();
2526
2527 // Padding with non-zero low pad values is not supported, unless the
2528 // corresponding result dim is 1 as this would require shifting the results to
2529 // the right for the low padded dims by the required amount of low padding.
2530 // However, we do support low padding if the dims being low padded have result
2531 // sizes of 1. The reason is when we have a low pad on a unit result dim, the
2532 // input size of that dimension will be dynamically zero (as the sum of the
2533 // low pad and input dim size has to be one) and hence we will create a zero
2534 // mask as the lowering logic just makes the mask one for the input dim size -
2535 // which is zero here. Hence we will load the pad value which is what we want
2536 // in this case. If the low pad is dynamically zero then the lowering is
2537 // correct as well as no shifts are necessary.
2538 if (llvm::any_of(llvm::enumerate(padOp.getMixedLowPad()),
2539 [&](const auto &en) {
2540 OpFoldResult padValue = en.value();
2541 unsigned pos = en.index();
2542 std::optional<int64_t> pad = getConstantIntValue(padValue);
2543 return (!pad.has_value() || pad.value() != 0) &&
2544 resultTensorShape[pos] != 1;
2545 })) {
2546 LDBG() << "low pad must all be zero for all non unit dims: " << padOp;
2547 return failure();
2548 }
2549
2550 return success();
2551}
2552
2553/// Preconditions for scalable vectors.
2554///
2555/// For Ops implementing the LinalgOp interface, this is quite restrictive - it
2556/// models the fact that in practice we would only make selected dimensions
2557/// scalable. For other Ops (e.g. `linalg.unpack`), this will succeed
2558/// unconditionally - we are yet to identify meaningful conditions.
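///
/// ILLUSTRATIVE EXAMPLE (not from the source): for a linalg.matmul with
/// iterators [parallel, parallel, reduction], requesting vector sizes
/// [8, [16], 1] (i.e. scalable flags [false, true, false]) is accepted,
/// e.g. via a hypothetical transform-dialect snippet such as:
/// ```
///   transform.structured.vectorize %matmul vector_sizes [8, [16], 1]
///     : !transform.any_op
/// ```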
2559static LogicalResult
2560vectorizeScalableVectorPrecondition(Operation *op,
2561 ArrayRef<int64_t> inputVectorSizes,
2562 ArrayRef<bool> inputScalableVecDims) {
2563 assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
2564 "Number of input vector sizes and scalable dims doesn't match");
2565
2566 size_t numOfScalableDims =
2567 llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });
2568
2569 if (numOfScalableDims == 0)
2570 return success();
2571
2572 auto linalgOp = dyn_cast<LinalgOp>(op);
2573
2574 // Cond 1: Reject Ops that don't implement the LinalgOp interface, with the
2575 // exception of UnpackOp for which there is a dedicated hook.
2576 if (!linalgOp) {
2577 return success(isa<linalg::UnPackOp>(op));
2578 }
2579
2580 // Cond 2: There's been no need for more than 2 scalable dims so far
2581 if (numOfScalableDims > 2)
2582 return failure();
2583
2584 // Cond 3: Look at the configuration in `inputScalableVecDims` and verify that
2585 // it matches one of the supported cases:
2586 // 1. Exactly 1 dim is scalable and that's the _last_ non-unit parallel dim
2587 // (*).
2588 // 2. Exactly 2 dims are scalable and those are the _last two adjacent_
2589 // parallel dims.
2590 // 3. Exactly 1 reduction dim is scalable and that's the last (innermost)
2591 // dim.
2592 // The 2nd restriction above means that only Matmul-like Ops are supported
2593 // when 2 dims are scalable, e.g. :
2594 // * iterators = [parallel, parallel, reduction]
2595 // * scalable flags = [true, true, false]
2596 //
2597  //  (*) Unit dims get folded away in practice.
2598 // TODO: Relax these conditions as good motivating examples are identified.
2599
2600 // Find the first scalable flag.
2601 bool seenNonUnitParallel = false;
2602 auto iterators = linalgOp.getIteratorTypesArray();
2603 SmallVector<bool> scalableFlags(inputScalableVecDims);
2604 int64_t idx = scalableFlags.size() - 1;
2605 while (!scalableFlags[idx]) {
2606 bool isNonUnitDim = (inputVectorSizes[idx] != 1);
2607 seenNonUnitParallel |=
2608 (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim);
2609
2610 iterators.pop_back();
2611 scalableFlags.pop_back();
2612 --idx;
2613 }
2614
2615 // Analyze the iterator corresponding to the first scalable dim.
2616 switch (iterators.back()) {
2617 case utils::IteratorType::reduction: {
2618 // Check 3. above is met.
2619 if (iterators.size() != inputVectorSizes.size()) {
2620 LDBG() << "Non-trailing reduction dim requested for scalable "
2621 "vectorization";
2622 return failure();
2623 }
2624 if (isa<linalg::MatmulOp>(op)) {
2625 LDBG()
2626 << "Scalable vectorization of the reduction dim in Matmul-like ops "
2627 "is not supported";
2628 return failure();
2629 }
2630 break;
2631 }
2632 case utils::IteratorType::parallel: {
2633 // Check 1. and 2. above are met.
2634 if (seenNonUnitParallel) {
2635 LDBG() << "Inner parallel dim not requested for scalable "
2636 "vectorization";
2637 return failure();
2638 }
2639 break;
2640 }
2641 }
2642
2643 // If present, check the 2nd scalable dim. ATM, only Matmul-like Ops are
2644  // supported, for which we expect the following config:
2645 // * iterators = [parallel, parallel, reduction]
2646 // * scalable flags = [true, true, false]
2647 if (numOfScalableDims == 2) {
2648 // Disallow below case which breaks 3. above:
2649 // * iterators = [..., parallel, reduction]
2650 // * scalable flags = [..., true, true]
2651 if (iterators.back() == utils::IteratorType::reduction) {
2652 LDBG() << "Higher dim than the trailing reduction dim requested for "
2653 "scalable "
2654                "vectorization";
2655 return failure();
2656 }
2657 scalableFlags.pop_back();
2658 iterators.pop_back();
2659
2660 if (!scalableFlags.back() ||
2661 (iterators.back() != utils::IteratorType::parallel))
2662 return failure();
2663 }
2664
2665 // Cond 4: Only the following ops are supported in the
2666 // presence of scalable vectors
2667 return success(
2668 isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
2669 isa<linalg::BatchMatmulOp>(op) ||
2671 isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
2672 isa<linalg::BatchMmt4DOp>(op) || hasReductionIterator(linalgOp));
2673}
2674
2675LogicalResult mlir::linalg::vectorizeOpPrecondition(
2676    Operation *op, ArrayRef<int64_t> inputVectorSizes,
2677 ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
2678 bool flatten1DDepthwiseConv) {
2679
2680 if (!hasVectorizationImpl(op))
2681 return failure();
2682
2683 if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
2684 inputScalableVecDims)))
2685 return failure();
2686
2688 .Case([&](linalg::LinalgOp linalgOp) {
2689 return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
2690 vectorizeNDExtract,
2691 flatten1DDepthwiseConv);
2692 })
2693 .Case([&](tensor::PadOp padOp) {
2694 return vectorizePadOpPrecondition(padOp, inputVectorSizes);
2695 })
2696 .Case([&](linalg::PackOp packOp) {
2697 return vectorizePackOpPrecondition(packOp, inputVectorSizes);
2698 })
2699 .Case([&](linalg::UnPackOp unpackOp) {
2700 return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
2701 })
2702 .Case([&](tensor::InsertSliceOp sliceOp) {
2703 return vectorizeInsertSliceOpPrecondition(sliceOp, inputVectorSizes);
2704 })
2705 .Default(failure());
2706}
2707
2708/// Converts affine.apply Ops to arithmetic operations.
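/// ILLUSTRATIVE EXAMPLE (not from the source; a minimal sketch):
/// ```
///   %0 = affine.apply affine_map<(d0) -> (d0 + 1)>(%i)
/// ```
/// is expanded to, roughly:
/// ```
///   %c1 = arith.constant 1 : index
///   %0 = arith.addi %i, %c1 : index
/// ```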
2709static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
2710 OpBuilder::InsertionGuard g(rewriter);
2711 auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
2712
2713 for (auto op : make_early_inc_range(toReplace)) {
2714 rewriter.setInsertionPoint(op);
2715 auto expanded = affine::expandAffineExpr(
2716 rewriter, op->getLoc(), op.getAffineMap().getResult(0),
2717 op.getOperands().take_front(op.getAffineMap().getNumDims()),
2718 op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
2719 rewriter.replaceOp(op, expanded);
2720 }
2721}
2722
2723bool mlir::linalg::hasVectorizationImpl(Operation *op) {
2724 return isa<linalg::LinalgOp, tensor::PadOp, linalg::PackOp, linalg::UnPackOp,
2725 tensor::InsertSliceOp>(op);
2726}
2727
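// ILLUSTRATIVE USAGE (not from the source; a minimal sketch of how a caller
// might drive this entry point from a pass or pattern; `rewriter`, `targetOp`
// and the `replacements` member of the result are assumptions for
// illustration, not part of this file):
//
//   FailureOr<VectorizationResult> maybeVectorized =
//       linalg::vectorize(rewriter, targetOp, /*inputVectorSizes=*/{},
//                         /*inputScalableVecDims=*/{});
//   if (succeeded(maybeVectorized)) {
//     // Replace the original op with the returned replacement values.
//     rewriter.replaceOp(targetOp, maybeVectorized->replacements);
//   }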
2728FailureOr<VectorizationResult> mlir::linalg::vectorize(
2729 RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
2730 ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
2731 bool flatten1DDepthwiseConv, bool assumeDynamicDimsMatchVecSizes,
2732 bool createNamedContraction) {
2733 LDBG() << "Attempting to vectorize: " << *op;
2734 LDBG() << "Input vector sizes: " << llvm::interleaved(inputVectorSizes);
2735 LDBG() << "Input scalable vector dims: "
2736 << llvm::interleaved(inputScalableVecDims);
2737
2738 if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
2739 vectorizeNDExtract,
2740 flatten1DDepthwiseConv))) {
2741 LDBG() << "Vectorization pre-conditions failed";
2742 return failure();
2743 }
2744
2745 // Initialize vectorization state.
2746 VectorizationState state(rewriter);
2747 if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
2748 if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
2749 inputScalableVecDims,
2750 assumeDynamicDimsMatchVecSizes))) {
2751 LDBG() << "Vectorization state couldn't be initialized";
2752 return failure();
2753 }
2754 }
2755
2756 SmallVector<Value> results;
2757 auto vectorizeResult =
2758      TypeSwitch<Operation *, LogicalResult>(op)
2759          .Case([&](linalg::LinalgOp linalgOp) {
2760 // Check for both named as well as generic convolution ops.
2761 if (isaConvolutionOpInterface(linalgOp)) {
2762 FailureOr<Operation *> convOr = vectorizeConvolution(
2763 rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
2764 flatten1DDepthwiseConv);
2765 if (succeeded(convOr)) {
2766 llvm::append_range(results, (*convOr)->getResults());
2767 return success();
2768 }
2769
2770 LDBG() << "Unsupported convolution can't be vectorized.";
2771 return failure();
2772 }
2773
2774 if (createNamedContraction &&
2775 isa<ContractionOpInterface>(linalgOp.getOperation()))
2776 return vectorizeAsLinalgContraction(rewriter, state, linalgOp,
2777 results);
2778
2779 LDBG()
2780 << "Vectorize generic by broadcasting to the canonical vector "
2781 "shape";
2782
2783 // Pre-process before proceeding.
2784 convertAffineApply(rewriter, linalgOp);
2785
2786 // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
2787 // to 'OpBuilder' when it is passed over to some methods like
2788 // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
2789 // erase an op within these methods, the actual rewriter won't be
2790 // notified and we will end up with read-after-free issues!
2791 return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
2792 })
2793 .Case([&](tensor::PadOp padOp) {
2794 return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
2795 results);
2796 })
2797 .Case([&](linalg::PackOp packOp) {
2798 return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
2799 results);
2800 })
2801 .Case([&](linalg::UnPackOp unpackOp) {
2802 return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
2803 inputVectorSizes,
2804 inputScalableVecDims, results);
2805 })
2806 .Case([&](tensor::InsertSliceOp sliceOp) {
2807 return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
2808 results);
2809 })
2810 .Default(failure());
2811
2812 if (failed(vectorizeResult)) {
2813 LDBG() << "Vectorization failed";
2814 return failure();
2815 }
2816
2817 return VectorizationResult{results};
2818}
2819
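/// ILLUSTRATIVE EXAMPLE (not from the source; a sketch of what the copy
/// vectorization below produces for a static-shape memref.copy):
/// ```
///   memref.copy %src, %dst : memref<4x8xf32> to memref<4x8xf32>
/// ```
/// becomes, roughly:
/// ```
///   %v = vector.transfer_read %src[%c0, %c0], %cst
///     : memref<4x8xf32>, vector<4x8xf32>
///   vector.transfer_write %v, %dst[%c0, %c0]
///     : vector<4x8xf32>, memref<4x8xf32>
/// ```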
2820LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
2821 memref::CopyOp copyOp) {
2822 auto srcType = cast<MemRefType>(copyOp.getSource().getType());
2823 auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
2824 if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
2825 return failure();
2826
2827 auto srcElementType = getElementTypeOrSelf(srcType);
2828 auto dstElementType = getElementTypeOrSelf(dstType);
2829 if (!VectorType::isValidElementType(srcElementType) ||
2830 !VectorType::isValidElementType(dstElementType))
2831 return failure();
2832
2833 auto readType = VectorType::get(srcType.getShape(), srcElementType);
2834 auto writeType = VectorType::get(dstType.getShape(), dstElementType);
2835
2836 Location loc = copyOp->getLoc();
2837 Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
2838 SmallVector<Value> indices(srcType.getRank(), zero);
2839
2840 Value readValue = vector::TransferReadOp::create(
2841 rewriter, loc, readType, copyOp.getSource(), indices,
2842 /*padding=*/std::nullopt,
2843 rewriter.getMultiDimIdentityMap(srcType.getRank()));
2844 if (cast<VectorType>(readValue.getType()).getRank() == 0) {
2845 readValue = vector::ExtractOp::create(rewriter, loc, readValue,
2846 ArrayRef<int64_t>());
2847 readValue =
2848 vector::BroadcastOp::create(rewriter, loc, writeType, readValue);
2849 }
2850 Operation *writeValue = vector::TransferWriteOp::create(
2851 rewriter, loc, readValue, copyOp.getTarget(), indices,
2852 rewriter.getMultiDimIdentityMap(srcType.getRank()));
2853 rewriter.replaceOp(copyOp, writeValue->getResults());
2854 return success();
2855}
2856
2857//----------------------------------------------------------------------------//
2858// Misc. vectorization patterns.
2859//----------------------------------------------------------------------------//
2860/// Base pattern for rewriting tensor::PadOps whose result is consumed by a
2861/// given operation type OpTy.
2862template <typename OpTy>
2863struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
2864 using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
2865
2866 LogicalResult matchAndRewrite(tensor::PadOp padOp,
2867 PatternRewriter &rewriter) const final {
2868 bool changed = false;
2869 // Insert users in vector, because some users may be replaced/removed.
2870 for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
2871 if (auto op = dyn_cast<OpTy>(user))
2872 changed |= rewriteUser(rewriter, padOp, op).succeeded();
2873 return success(changed);
2874 }
2875
2876protected:
2877 virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
2878 tensor::PadOp padOp, OpTy op) const = 0;
2879};
2880
2881/// Rewrite use of tensor::PadOp result in TransferReadOp. E.g.:
2882/// ```
2883/// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2884/// %r = vector.transfer_read %0[%c0, %c0], %cst
2885/// {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32>
2886/// ```
2887/// is rewritten to:
2888/// ```
2889/// %r = vector.transfer_read %src[%c0, %c0], %padding
2890/// {in_bounds = [true, true]}
2891/// : tensor<?x?xf32>, vector<17x5xf32>
2892/// ```
2893/// Note: By restricting this pattern to in-bounds TransferReadOps, we can be
2894/// sure that the original padding value %cst was never used.
2895///
2896/// This rewrite is possible if:
2897/// - `xferOp` has no out-of-bounds dims or mask.
2898/// - Low padding is static 0.
2899/// - Single, scalar padding value.
2900struct PadOpVectorizationWithTransferReadPattern
2901 : public VectorizePadOpUserPattern<vector::TransferReadOp> {
2902 using VectorizePadOpUserPattern<
2903 vector::TransferReadOp>::VectorizePadOpUserPattern;
2904
2905 LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2906 vector::TransferReadOp xferOp) const override {
2907 // Low padding must be static 0.
2908 if (!padOp.hasZeroLowPad())
2909 return failure();
2910 // Pad value must be a constant.
2911 auto padValue = padOp.getConstantPaddingValue();
2912 if (!padValue)
2913 return failure();
2914 // Padding value of existing `xferOp` is unused.
2915 if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
2916 return failure();
2917
2918 rewriter.modifyOpInPlace(xferOp, [&]() {
2919 SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2920 xferOp->setAttr(xferOp.getInBoundsAttrName(),
2921 rewriter.getBoolArrayAttr(inBounds));
2922 xferOp.getBaseMutable().assign(padOp.getSource());
2923 xferOp.getPaddingMutable().assign(padValue);
2924 });
2925
2926 return success();
2927 }
2928};
2929
2930/// Rewrite use of tensor::PadOp result in TransferWriteOp.
2931/// This pattern rewrites TransferWriteOps that write to a padded tensor
2932/// value, where the same amount of padding is immediately removed again after
2933/// the write. In such cases, the TransferWriteOp can write to the non-padded
2934/// tensor value and apply out-of-bounds masking. E.g.:
2935/// ```
2936/// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2937/// : tensor<...> to tensor<?x?xf32>
2938/// %1 = tensor.pad %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
2939/// %2 = vector.transfer_write %vec, %1[...]
2940/// : vector<17x5xf32>, tensor<17x5xf32>
2941/// %r = tensor.extract_slice %2[0, 0] [%s0, %s1] [1, 1]
2942/// : tensor<17x5xf32> to tensor<?x?xf32>
2943/// ```
2944/// is rewritten to:
2945/// ```
2946/// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2947/// : tensor<...> to tensor<?x?xf32>
2948/// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>,
2949/// tensor<?x?xf32>
2950/// ```
2951/// Note: It is important that the ExtractSliceOp %r resizes the result of the
2952/// TransferWriteOp to the same size as the input of the TensorPadOp (or an
2953/// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
2954/// from %r's old dimensions.
2955///
2956/// This rewrite is possible if:
2957/// - Low padding is static 0.
2958/// - `xferOp` has exactly one use, which is an ExtractSliceOp. This
2959/// ExtractSliceOp trims the same amount of padding that was added
2960/// beforehand.
2961/// - Single, scalar padding value.
2962struct PadOpVectorizationWithTransferWritePattern
2963 : public VectorizePadOpUserPattern<vector::TransferWriteOp> {
2964 using VectorizePadOpUserPattern<
2965 vector::TransferWriteOp>::VectorizePadOpUserPattern;
2966
2967 LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2968 vector::TransferWriteOp xferOp) const override {
2969 // TODO: support 0-d corner case.
2970 if (xferOp.getTransferRank() == 0)
2971 return failure();
2972
2973 // Low padding must be static 0.
2974 if (!padOp.hasZeroLowPad())
2975 return failure();
2976 // Pad value must be a constant.
2977 auto padValue = padOp.getConstantPaddingValue();
2978 if (!padValue)
2979 return failure();
2980 // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
2981 if (!xferOp->hasOneUse())
2982 return failure();
2983 auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
2984 if (!trimPadding)
2985 return failure();
2986 // Only static zero offsets supported when trimming padding.
2987 if (!trimPadding.hasZeroOffset())
2988 return failure();
2989 // trimPadding must remove the amount of padding that was added earlier.
2990 if (!hasSameTensorSize(padOp.getSource(), trimPadding))
2991 return failure();
2992
2993 // Insert the new TransferWriteOp at position of the old TransferWriteOp.
2994 rewriter.setInsertionPoint(xferOp);
2995
2996 SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2997 auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2998 xferOp, padOp.getSource().getType(), xferOp.getVector(),
2999 padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
3000 xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
3001 rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
3002
3003 return success();
3004 }
3005
3006 /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
3007 /// i.e., same dimensions.
3008 ///
3009 /// Dimensions may be static, dynamic or mix of both. In case of dynamic
3010 /// dimensions, this function tries to infer the (static) tensor size by
3011 /// looking at the defining op and utilizing op-specific knowledge.
3012 ///
3013 /// This is a conservative analysis. In case equal tensor sizes cannot be
3014 /// proven statically, this analysis returns `false` even though the tensor
3015 /// sizes may turn out to be equal at runtime.
3016 bool hasSameTensorSize(Value beforePadding,
3017 tensor::ExtractSliceOp afterTrimming) const {
3018 // If the input to tensor::PadOp is a CastOp, try with both CastOp
3019 // result and CastOp operand.
3020 if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
3021 if (hasSameTensorSize(castOp.getSource(), afterTrimming))
3022 return true;
3023
3024 auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
3025 auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
3026 // Only RankedTensorType supported.
3027 if (!t1 || !t2)
3028 return false;
3029 // Rank of both values must be the same.
3030 if (t1.getRank() != t2.getRank())
3031 return false;
3032
3033 // All static dimensions must be the same. Mixed cases (e.g., dimension
3034 // static in `t1` but dynamic in `t2`) are not supported.
3035 for (unsigned i = 0; i < t1.getRank(); ++i) {
3036 if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
3037 return false;
3038 if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
3039 return false;
3040 }
3041
3042 // Nothing more to check if all dimensions are static.
3043 if (t1.getNumDynamicDims() == 0)
3044 return true;
3045
3046 // All dynamic sizes must be the same. The only supported case at the
3047 // moment is when `beforePadding` is an ExtractSliceOp (or a cast
3048 // thereof).
3049
3050 // Apart from CastOp, only ExtractSliceOp is supported.
3051 auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
3052 if (!beforeSlice)
3053 return false;
3054
3055 assert(static_cast<size_t>(t1.getRank()) ==
3056 beforeSlice.getMixedSizes().size());
3057 assert(static_cast<size_t>(t2.getRank()) ==
3058 afterTrimming.getMixedSizes().size());
3059
3060 for (unsigned i = 0; i < t1.getRank(); ++i) {
3061 // Skip static dimensions.
3062 if (!t1.isDynamicDim(i))
3063 continue;
3064 auto size1 = beforeSlice.getMixedSizes()[i];
3065 auto size2 = afterTrimming.getMixedSizes()[i];
3066
3067 // Case 1: Same value or same constant int.
3068 if (isEqualConstantIntOrValue(size1, size2))
3069 continue;
3070
3071 // Other cases: Take a deeper look at defining ops of values.
3072 auto v1 = llvm::dyn_cast_if_present<Value>(size1);
3073 auto v2 = llvm::dyn_cast_if_present<Value>(size2);
3074 if (!v1 || !v2)
3075 return false;
3076
3077 // Case 2: Both values are identical AffineMinOps. (Should not happen if
3078 // CSE is run.)
3079 auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
3080 auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
3081 if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
3082 minOp1.getOperands() == minOp2.getOperands())
3083 continue;
3084
3085 // Add additional cases as needed.
3086 }
3087
3088 // All tests passed.
3089 return true;
3090 }
3091};
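// Illustrative sketch of the dynamic-size comparison above (hypothetical IR):
// when CSE has not been run, the trimming slice may recompute the same
// affine.min; identical maps and operands are accepted as equal sizes.
// ```
//   %sz0 = affine.min affine_map<(d0) -> (d0, 16)>(%ub)
//   %0 = tensor.extract_slice %t[0, 0] [%sz0, 5] [1, 1]
//       : tensor<?x5xf32> to tensor<?x5xf32>
//   %1 = tensor.pad %0 ... : tensor<?x5xf32> to tensor<17x5xf32>
//   %2 = vector.transfer_write %vec, %1[...]
//       : vector<17x5xf32>, tensor<17x5xf32>
//   %sz1 = affine.min affine_map<(d0) -> (d0, 16)>(%ub)
//   %r = tensor.extract_slice %2[0, 0] [%sz1, 5] [1, 1]
//       : tensor<17x5xf32> to tensor<?x5xf32>
// ```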
3092
3093/// Returns the effective Pad value for the input op, provided it's a scalar.
3094///
3095/// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
3096/// this Op performs padding, retrieve the padding value provided that it's
3097/// a scalar and static/fixed for all the padded values. Returns an empty value
3098/// otherwise.
3099///
3100/// TODO: This is used twice (when checking vectorization pre-conditions and
3101/// when vectorizing). Cache results instead of re-running.
3102static Value getStaticPadVal(Operation *op) {
3103 if (!op)
3104 return {};
3105
3106 // 1. vector.broadcast (f32 -> vector <...xf32>) - return the value that's
3107 // being broadcast, provided that it's a scalar.
3108 if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
3109 auto source = bcast.getSource();
3110 if (llvm::dyn_cast<VectorType>(source.getType()))
3111 return {};
3112
3113 return source;
3114 }
3115
3116 // 2. linalg.fill - use the scalar input value that used to fill the output
3117 // tensor.
3118 if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
3119 return fill.getInputs()[0];
3120 }
3121
3122  // 3. tensor.generate - can't guarantee the value is fixed without further
3123  // analysis, so bail out.
3124 if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {
3125 return {};
3126 }
3127
3128  // 4. vector.transfer_write - inspect the input vector that is written. If
3129  // it contains a single value that has been broadcast (e.g. via
3130  // vector.broadcast), extract it, fail otherwise.
3131 if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
3132 return getStaticPadVal(xferWrite.getVector().getDefiningOp());
3133
3134 // 5. tensor.insert_slice - inspect the destination tensor. If it's larger
3135 // than the input tensor, then, provided it's constant, we'll extract the
3136 // value that was used to generate it (via e.g. linalg.fill), fail otherwise.
3137 // TODO: Clarify the semantics when the input tensor is larger than the
3138 // destination.
3139 if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
3140 return getStaticPadVal(slice.getDest().getDefiningOp());
3141
3142 return {};
3143}
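// Illustrative sketch for the linalg.fill and tensor.insert_slice cases above
// (hypothetical IR): walking the destination chain reaches the fill, so the
// effective pad value is the fill scalar %pad.
// ```
//   %pad = arith.constant 0.0 : f32
//   %fill = linalg.fill ins(%pad : f32) outs(%dest : tensor<8x16xf32>)
//       -> tensor<8x16xf32>
//   %r = tensor.insert_slice %src into %fill[0, 0] [4, 8] [1, 1]
//       : tensor<4x8xf32> into tensor<8x16xf32>
// ```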
3144
3145static LogicalResult
3146vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
3147 ArrayRef<int64_t> inputVectorSizes,
3148 SmallVectorImpl<Value> &newResults) {
3149 // TODO: Introduce a parent class that will handle the insertion point update.
3150 OpBuilder::InsertionGuard g(rewriter);
3151 rewriter.setInsertionPoint(sliceOp);
3152
3153 TypedValue<RankedTensorType> source = sliceOp.getSource();
3154 auto sourceType = source.getType();
3155 auto resultType = sliceOp.getResultType();
3156
3157 Value padValue = getStaticPadVal(sliceOp);
3158
3159 if (!padValue) {
3160 auto elemType = sourceType.getElementType();
3161 padValue = arith::ConstantOp::create(rewriter, sliceOp.getLoc(), elemType,
3162 rewriter.getZeroAttr(elemType));
3163 }
3164
3165 // 2. Get the vector shape
3166 SmallVector<int64_t> vecShape;
3167 size_t rankDiff = resultType.getRank() - sourceType.getRank();
3168 for (int64_t i = 0, end = sourceType.getRank(); i < end; ++i) {
3169 if (!inputVectorSizes.empty()) {
3170 vecShape.push_back(inputVectorSizes[i]);
3171 } else if (!sourceType.isDynamicDim(i)) {
3172 vecShape.push_back(sourceType.getDimSize(i));
3173 } else if (!resultType.isDynamicDim(i)) {
3174 // Source shape is not statically known, but result shape is.
3175 // Vectorize with size of result shape. This may be larger than the
3176 // source size.
3177 // FIXME: Using rankDiff implies that the source tensor is inserted at
3178 // the end of the destination tensor. However, that's not required.
3179 vecShape.push_back(resultType.getDimSize(rankDiff + i));
3180 } else {
3181      // Neither the source nor the result dim of the sliceOp is static.
3182      // Cannot vectorize the copy.
3183 return failure();
3184 }
3185 }
3186 auto vecType = VectorType::get(vecShape, sourceType.getElementType());
3187
3188 // 3. Generate TransferReadOp + TransferWriteOp
3189 auto loc = sliceOp.getLoc();
3190
3191 // Create read
3192 SmallVector<Value> readIndices(
3193 vecType.getRank(), arith::ConstantIndexOp::create(rewriter, loc, 0));
3194  Value read = vector::createReadOrMaskedRead(
3195      rewriter, loc, source, vecType, padValue,
3196 /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());
3197
3198 // Create write
3199 auto writeIndices =
3200 getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
3201 Operation *write =
3202 createWriteOrMaskedWrite(rewriter, loc, read, sliceOp.getDest(),
3203 writeIndices, inputVectorSizes.empty());
3204
3205 // 4. Finalize
3206 newResults.push_back(write->getResult(0));
3207
3208 return success();
3209}
3210
3211/// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
3212/// ```
3213/// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
3214/// %r = tensor.insert_slice %0
3215/// into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
3216/// : tensor<17x5xf32> into tensor<?x?x17x5xf32>
3217/// ```
3218/// is rewritten to:
3219/// ```
3220/// %0 = vector.transfer_read %src[%c0, %c0], %padding
3221/// : tensor<?x?xf32>, vector<17x5xf32>
3222/// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
3223/// {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
3224/// ```
3225///
3226/// This rewrite is possible if:
3227/// - Low padding is static 0.
3228/// - `padOp` result shape is static.
3229/// - The entire padded tensor is inserted.
3230/// (Implies that sizes of `insertOp` are all static.)
3231/// - Only unit strides in `insertOp`.
3232/// - Single, scalar padding value.
3233/// - `padOp` result not used as destination.
3234struct PadOpVectorizationWithInsertSlicePattern
3235 : public VectorizePadOpUserPattern<tensor::InsertSliceOp> {
3236 using VectorizePadOpUserPattern<
3237 tensor::InsertSliceOp>::VectorizePadOpUserPattern;
3238
3239 LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
3240 tensor::InsertSliceOp insertOp) const override {
3241 // Low padding must be static 0.
3242 if (!padOp.hasZeroLowPad())
3243 return failure();
3244 // Only unit stride supported.
3245 if (!insertOp.hasUnitStride())
3246 return failure();
3247 // Pad value must be a constant.
3248 auto padValue = padOp.getConstantPaddingValue();
3249 if (!padValue)
3250 return failure();
3251 // Dynamic shapes not supported.
3252 if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
3253 return failure();
3254 // Pad result not used as destination.
3255 if (insertOp.getDest() == padOp.getResult())
3256 return failure();
3257
3258 auto vecType = VectorType::get(padOp.getType().getShape(),
3259 padOp.getType().getElementType());
3260 unsigned vecRank = vecType.getRank();
3261 unsigned tensorRank = insertOp.getType().getRank();
3262
3263 // Check if sizes match: Insert the entire tensor into most minor dims.
3264 // (No permutations allowed.)
3265 SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
3266 expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
3267 if (!llvm::all_of(
3268 llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
3269 return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
3270 }))
3271 return failure();
3272
3273 // Insert the TransferReadOp and TransferWriteOp at the position of the
3274 // InsertSliceOp.
3275 rewriter.setInsertionPoint(insertOp);
3276
3277 // Generate TransferReadOp: Read entire source tensor and add high
3278 // padding.
3279 SmallVector<Value> readIndices(
3280 vecRank, arith::ConstantIndexOp::create(rewriter, padOp.getLoc(), 0));
3281 auto read = vector::TransferReadOp::create(rewriter, padOp.getLoc(),
3282 vecType, padOp.getSource(),
3283 readIndices, padValue);
3284
3285 // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
3286    // specified offsets. Write is fully in-bounds because an InsertSliceOp's
3287 // source must fit into the destination at the specified offsets.
3288 auto writeIndices = getValueOrCreateConstantIndexOp(
3289 rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
3290 SmallVector<bool> inBounds(vecRank, true);
3291 rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
3292 insertOp, read, insertOp.getDest(), writeIndices,
3293 ArrayRef<bool>{inBounds});
3294
3295 return success();
3296 }
3297};
3298
3299void mlir::linalg::populatePadOpVectorizationPatterns(
3300    RewritePatternSet &patterns, PatternBenefit baseBenefit) {
3301 patterns.add<PadOpVectorizationWithTransferReadPattern,
3302 PadOpVectorizationWithTransferWritePattern,
3303 PadOpVectorizationWithInsertSlicePattern>(
3304 patterns.getContext(), baseBenefit.getBenefit() + 1);
3305}
3306
3307//----------------------------------------------------------------------------//
3308// Forwarding patterns
3309//----------------------------------------------------------------------------//
3310
3311/// Check whether there is any interleaved use of any `values` between
3312/// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
3313/// is in a different block.
3314static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
3315 ValueRange values) {
3316 if (firstOp->getBlock() != secondOp->getBlock() ||
3317 !firstOp->isBeforeInBlock(secondOp)) {
3318 LDBG() << "interleavedUses precondition failed, firstOp: " << *firstOp
3319 << ", second op: " << *secondOp;
3320 return true;
3321 }
3322 for (auto v : values) {
3323 for (auto &u : v.getUses()) {
3324 Operation *owner = u.getOwner();
3325 if (owner == firstOp || owner == secondOp)
3326 continue;
3327 // TODO: this is too conservative, use dominance info in the future.
3328 if (owner->getBlock() == firstOp->getBlock() &&
3329 (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
3330 continue;
3331 LDBG() << " found interleaved op " << *owner << ", firstOp: " << *firstOp
3332 << ", second op: " << *secondOp;
3333 return true;
3334 }
3335 }
3336 return false;
3337}
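// Illustrative sketch (hypothetical IR): any use of the tracked values between
// the two anchor ops counts as interleaved, e.g. a store into the subview
// between a copy and the transfer it could otherwise be forwarded into.
// ```
//   memref.copy %in, %subview : ...
//   memref.store %x, %subview[%c0, %c0] : ...  // interleaved use of %subview
//   %v = vector.transfer_read %alloc[%c0, %c0], %cst : ...
// ```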
3338
3339/// Return the unique subview use of `v` if it is indeed unique, null
3340/// otherwise.
3341static memref::SubViewOp getSubViewUseIfUnique(Value v) {
3342 memref::SubViewOp subViewOp;
3343 for (auto &u : v.getUses()) {
3344 if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
3345 if (subViewOp)
3346 return memref::SubViewOp();
3347 subViewOp = newSubViewOp;
3348 }
3349 }
3350 return subViewOp;
3351}
3352
3353/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
3354/// when available.
3355LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
3356    vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
3357
3358 // TODO: support mask.
3359 if (xferOp.getMask())
3360 return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
3361
3362 // Transfer into `view`.
3363 Value viewOrAlloc = xferOp.getBase();
3364 if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
3365 !viewOrAlloc.getDefiningOp<memref::AllocOp>())
3366 return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
3367
3368 // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
3369 memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
3370 if (!subViewOp)
3371 return rewriter.notifyMatchFailure(xferOp, "no subview found");
3372 Value subView = subViewOp.getResult();
3373
3374 // Find the copy into `subView` without interleaved uses.
3375 memref::CopyOp copyOp;
3376 for (auto &u : subView.getUses()) {
3377 if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
3378 assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
3379 if (newCopyOp.getTarget() != subView)
3380 continue;
3381 if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
3382 continue;
3383 copyOp = newCopyOp;
3384 break;
3385 }
3386 }
3387 if (!copyOp)
3388 return rewriter.notifyMatchFailure(xferOp, "no copy found");
3389
3390 // Find the fill into `viewOrAlloc` without interleaved uses before the
3391 // copy.
3392 FillOp maybeFillOp;
3393 for (auto &u : viewOrAlloc.getUses()) {
3394 if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
3395 assert(isa<MemRefType>(newFillOp.output().getType()));
3396 if (newFillOp.output() != viewOrAlloc)
3397 continue;
3398 if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
3399 continue;
3400 maybeFillOp = newFillOp;
3401 break;
3402 }
3403 }
3404 // Ensure padding matches.
3405 if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
3406 return rewriter.notifyMatchFailure(xferOp,
3407 "padding value does not match fill");
3408
3409 // `in` is the subview that memref.copy reads. Replace it.
3410 Value in = copyOp.getSource();
3411
3412 // memref.copy + linalg.fill can be used to create a padded local buffer.
3413 // The `masked` attribute is only valid on this padded buffer.
3414 // When forwarding to vector.transfer_read, the attribute must be reset
3415 // conservatively.
3416 auto vectorType = xferOp.getVectorType();
3417 Value res = vector::TransferReadOp::create(
3418 rewriter, xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
3419 xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
3420 rewriter.getBoolArrayAttr(
3421 SmallVector<bool>(vectorType.getRank(), false)));
3422
3423 if (maybeFillOp)
3424 rewriter.eraseOp(maybeFillOp);
3425 rewriter.eraseOp(copyOp);
3426 rewriter.replaceOp(xferOp, res);
3427
3428 return success();
3429}
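// Illustrative sketch of the forwarding above (hypothetical IR): the temporary
// padded buffer (fill + copy into a subview) is bypassed and the read goes
// straight to the original source, padding with the fill value.
// ```
//   // Before:
//   %sv = memref.subview %alloc[0, 0] [%s0, %s1] [1, 1]
//       : memref<32x32xf32> to memref<?x?xf32, strided<[32, 1]>>
//   linalg.fill ins(%cst : f32) outs(%alloc : memref<32x32xf32>)
//   memref.copy %in, %sv
//       : memref<?x?xf32> to memref<?x?xf32, strided<[32, 1]>>
//   %v = vector.transfer_read %alloc[%c0, %c0], %cst
//       : memref<32x32xf32>, vector<32x32xf32>
//
//   // After (roughly): fill and copy are erased; the read is deliberately
//   // marked out-of-bounds (in_bounds = [false, false]) and pads with %cst.
//   %v = vector.transfer_read %in[%c0, %c0], %cst
//       : memref<?x?xf32>, vector<32x32xf32>
// ```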
3430
3431/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
3432/// when available.
3433LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
3434    vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
3435 // TODO: support mask.
3436 if (xferOp.getMask())
3437 return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
3438
3439 // Transfer into `viewOrAlloc`.
3440 Value viewOrAlloc = xferOp.getBase();
3441 if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
3442 !viewOrAlloc.getDefiningOp<memref::AllocOp>())
3443 return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
3444
3445 // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
3446 memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
3447 if (!subViewOp)
3448 return rewriter.notifyMatchFailure(xferOp, "no subview found");
3449 Value subView = subViewOp.getResult();
3450
3451 // Find the copy from `subView` without interleaved uses.
3452 memref::CopyOp copyOp;
3453 for (auto &u : subViewOp.getResult().getUses()) {
3454 if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
3455 if (newCopyOp.getSource() != subView)
3456 continue;
3457 if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
3458 continue;
3459 copyOp = newCopyOp;
3460 break;
3461 }
3462 }
3463 if (!copyOp)
3464 return rewriter.notifyMatchFailure(xferOp, "no copy found");
3465
3466 // `out` is the subview copied into that we replace.
3467 assert(isa<MemRefType>(copyOp.getTarget().getType()));
3468 Value out = copyOp.getTarget();
3469
3470 // Forward vector.transfer into copy.
3471 // memref.copy + linalg.fill can be used to create a padded local buffer.
3472 // The `masked` attribute is only valid on this padded buffer.
3473 // When forwarding to vector.transfer_write, the attribute must be reset
3474 // conservatively.
3475 auto vector = xferOp.getVector();
3476 vector::TransferWriteOp::create(
3477 rewriter, xferOp.getLoc(), vector, out, xferOp.getIndices(),
3478 xferOp.getPermutationMapAttr(), xferOp.getMask(),
3479 rewriter.getBoolArrayAttr(SmallVector<bool>(
3480 dyn_cast<VectorType>(vector.getType()).getRank(), false)));
3481
3482 rewriter.eraseOp(copyOp);
3483 rewriter.eraseOp(xferOp);
3484
3485 return success();
3486}
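// Illustrative sketch of the mirror rewrite for writes (hypothetical IR): the
// vector is written directly into the copy's target and the intermediate
// buffer traffic disappears.
// ```
//   // Before:
//   %sv = memref.subview %alloc[0, 0] [%s0, %s1] [1, 1]
//       : memref<32x32xf32> to memref<?x?xf32, strided<[32, 1]>>
//   vector.transfer_write %v, %alloc[%c0, %c0]
//       : vector<32x32xf32>, memref<32x32xf32>
//   memref.copy %sv, %out
//       : memref<?x?xf32, strided<[32, 1]>> to memref<?x?xf32>
//
//   // After (roughly):
//   vector.transfer_write %v, %out[%c0, %c0]
//       {in_bounds = [false, false]} : vector<32x32xf32>, memref<?x?xf32>
// ```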
3487
3488//===----------------------------------------------------------------------===//
3489// Convolution vectorization patterns
3490//===----------------------------------------------------------------------===//
3491
3492template <int N>
3493static void bindShapeDims(ShapedType shapedType) {}
3494
3495template <int N, typename IntTy, typename... IntTy2>
3496static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
3497 val = shapedType.getShape()[N];
3498 bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
3499}
3500
3501/// Bind a pack of int& to the leading dimensions of shapedType.getShape().
3502template <typename... IntTy>
3503static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
3504 bindShapeDims<0>(shapedType, vals...);
3505}
3506
3507/// Match 1D convolution or pooling operations and return their dilations and
3508/// strides. Returns std::nullopt for unrecognized ops.
3509static std::optional<DilationsAndStrides> match1DConvPoolOp(LinalgOp op) {
3510#define MATCH_1D_CONV_POOL_OP(ConvOpTy) \
3511 if (auto convParams = matchConvolutionOpOfType<ConvOpTy>(op)) \
3512 return convParams;
3513
3514 // 1D Convolution ops.
3515 MATCH_1D_CONV_POOL_OP(linalg::Conv1DOp);
3516 MATCH_1D_CONV_POOL_OP(linalg::Conv1DNwcWcfOp);
3517 MATCH_1D_CONV_POOL_OP(linalg::Conv1DNcwFcwOp);
3518 // Depthwise 1D Convolution ops.
3519 // Note: Only NWC layout without channel multiplier is supported.
3520 // DepthwiseConv1DNcwCwOp (NCW) and DepthwiseConv1DNwcWcmOp (with multiplier)
3521 // are not supported.
3522 MATCH_1D_CONV_POOL_OP(linalg::DepthwiseConv1DNwcWcOp);
3523 // 1D Pooling ops (NWC layout).
3524 MATCH_1D_CONV_POOL_OP(linalg::PoolingNwcSumOp);
3525 MATCH_1D_CONV_POOL_OP(linalg::PoolingNwcMaxOp);
3526 MATCH_1D_CONV_POOL_OP(linalg::PoolingNwcMaxUnsignedOp);
3527 MATCH_1D_CONV_POOL_OP(linalg::PoolingNwcMinOp);
3528 MATCH_1D_CONV_POOL_OP(linalg::PoolingNwcMinUnsignedOp);
3529 // 1D Pooling ops (NCW layout).
3530 MATCH_1D_CONV_POOL_OP(linalg::PoolingNcwSumOp);
3531 MATCH_1D_CONV_POOL_OP(linalg::PoolingNcwMaxOp);
3532
3533#undef MATCH_1D_CONV_POOL_OP
3534
3535 return std::nullopt;
3536}
3537
3538namespace {
3539/// Generate a vector implementation for either:
3540/// ```
3541/// Op def: ( w, kw )
3542/// Iters: ({Par(), Red()})
3543/// Layout: {{w + kw}, {kw}, {w}}
3544/// ```
3545/// kw is unrolled.
3546///
3547/// or
3548///
3549/// ```
3550/// Op def: ( n, w, c, kw, f )
3551/// Iters: ({Par(), Par(), Par(), Red(), Red()})
3552/// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3553/// ```
3554/// kw is unrolled, w is unrolled iff dilationW > 1.
3555///
3556/// or
3557///
3558/// ```
3559/// Op def: ( n, c, w, f, kw )
3560/// Iters: ({Par(), Par(), Par(), Red(), Red()})
3561/// Layout: {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
3562/// ```
3563/// kw is unrolled, w is unrolled iff dilationW > 1.
3564///
3565/// or
3566///
3567/// ```
3568/// Op def: ( n, w, c, kw )
3569/// Iters: ({Par(), Par(), Par(), Red()})
3570/// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3571/// ```
3572/// kw is unrolled, w is unrolled iff dilationW > 1.
3573struct Conv1DGenerator
3574 : public StructuredGenerator<LinalgOp, utils::IteratorType> {
3575 /// Factory method to create a Conv1DGenerator. Returns failure if the
3576 /// operation doesn't have valid strides/dilations.
3577 static FailureOr<Conv1DGenerator> create(RewriterBase &rewriter,
3578 LinalgOp linalgOp) {
3579 // Try to match a 1D conv/pool op using matchConvolutionOpOfType. This
3580 // works for both named ops and generic ops that match their semantics.
3581 std::optional<DilationsAndStrides> convParams = match1DConvPoolOp(linalgOp);
3582 if (!convParams)
3583 return failure();
3584
3585 int strideW = static_cast<int>(convParams->strides.front());
3586 int dilationW = static_cast<int>(convParams->dilations.front());
3587 return Conv1DGenerator(rewriter, linalgOp, strideW, dilationW);
3588 }
3589
3590private:
3591 Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp, int strideW,
3592 int dilationW)
3593 : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp),
3594 strideW(strideW), dilationW(dilationW) {
3595
3596 lhsShaped = linalgOp.getDpsInputOperand(0)->get();
3597 rhsShaped = linalgOp.getDpsInputOperand(1)->get();
3598 resShaped = linalgOp.getDpsInitOperand(0)->get();
3599 lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
3600 rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
3601 resShapedType = dyn_cast<ShapedType>(resShaped.getType());
3602
3603 Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
3604 redOp = reduceOp->getName().getIdentifier();
3605
3606 setConvOperationKind(reduceOp);
3607
3608 auto maybeKind = getCombinerOpKind(reduceOp);
3609 reductionKind = maybeKind.value();
3610 }
3611
3612public:
3613 /// Generate a vector implementation for:
3614 /// ```
3615 /// Op def: ( w, kw )
3616 /// Iters: ({Par(), Red()})
3617 /// Layout: {{w + kw}, {kw}, {w}}
3618 /// ```
3619 /// kw is always unrolled.
3620 ///
3621 /// or
3622 ///
3623 /// ```
3624 /// Op def: ( n, w, c, kw, f )
3625 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
3626 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3627 /// ```
3628 /// kw is always unrolled.
3629  /// TODO: w (resp. kw) is unrolled when strideW (resp. dilationW) is
3630  /// > 1.
3631 FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
3632 int64_t nSize, wSize, cSize, kwSize, fSize;
3633 SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
3634 bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
3635 switch (conv1DOpOrder) {
3636 case Conv1DOpOrder::W:
3637 // Initialize unused dimensions
3638 nSize = fSize = cSize = 0;
3639 // out{W}
3640 bindShapeDims(resShapedType, wSize);
3641 // kernel{kw}
3642 bindShapeDims(rhsShapedType, kwSize);
3643 lhsShape = {// iw = ow + kw - 1
3644 // (i.e. 16 convolved with 3 -> 14)
3645 (wSize + kwSize - 1)};
3646 rhsShape = {kwSize};
3647 resShape = {wSize};
3648 break;
3649 case Conv1DOpOrder::Nwc:
3650 // out{n, w, f}
3651 bindShapeDims(resShapedType, nSize, wSize, fSize);
3652 switch (oper) {
3653 case ConvOperationKind::Conv:
3654 // kernel{kw, c, f}
3655 bindShapeDims(rhsShapedType, kwSize, cSize);
3656 break;
3657 case ConvOperationKind::Pool:
3658 // kernel{kw}
3659 bindShapeDims(rhsShapedType, kwSize);
3660 cSize = fSize;
3661 break;
3662 }
3663 lhsShape = {nSize,
3664 // iw = ow * sw + kw * dw - 1
3665 // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3666 // Perform the proper inclusive -> exclusive -> inclusive.
3667 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3668 1,
3669 cSize};
3670 switch (oper) {
3671 case ConvOperationKind::Conv:
3672 rhsShape = {kwSize, cSize, fSize};
3673 break;
3674 case ConvOperationKind::Pool:
3675 rhsShape = {kwSize};
3676 break;
3677 }
3678 resShape = {nSize, wSize, fSize};
3679 break;
3680 case Conv1DOpOrder::Ncw:
3681 // out{n, f, w}
3682 bindShapeDims(resShapedType, nSize, fSize, wSize);
3683 switch (oper) {
3684 case ConvOperationKind::Conv:
3685 // kernel{f, c, kw}
3686 bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
3687 break;
3688 case ConvOperationKind::Pool:
3689 // kernel{kw}
3690 bindShapeDims(rhsShapedType, kwSize);
3691 cSize = fSize;
3692 break;
3693 }
3694 lhsShape = {nSize, cSize,
3695 // iw = ow * sw + kw * dw - 1
3696 // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3697 // Perform the proper inclusive -> exclusive -> inclusive.
3698 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3699 1};
3700 switch (oper) {
3701 case ConvOperationKind::Conv:
3702 rhsShape = {fSize, cSize, kwSize};
3703 break;
3704 case ConvOperationKind::Pool:
3705 rhsShape = {kwSize};
3706 break;
3707 }
3708 resShape = {nSize, fSize, wSize};
3709 break;
3710 }
3711
3712 vector::TransferWriteOp write;
3713 Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
3714
3715 // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3716 // When strideW == 1, we can batch the contiguous loads and avoid
3717 // unrolling
3718 int64_t wSizeStep = strideW == 1 ? wSize : 1;
3719
3720 Type lhsEltType = lhsShapedType.getElementType();
3721 Type rhsEltType = rhsShapedType.getElementType();
3722 Type resEltType = resShapedType.getElementType();
3723 auto lhsType = VectorType::get(lhsShape, lhsEltType);
3724 auto rhsType = VectorType::get(rhsShape, rhsEltType);
3725 auto resType = VectorType::get(resShape, resEltType);
3726 // Zero padding with the corresponding dimensions for lhs, rhs and res.
3727 SmallVector<Value> lhsPadding(lhsShape.size(), zero);
3728 SmallVector<Value> rhsPadding(rhsShape.size(), zero);
3729 SmallVector<Value> resPadding(resShape.size(), zero);
3730
3731 // Read the whole lhs, rhs and res in one shot (with zero padding).
3732 Value lhs = vector::TransferReadOp::create(
3733 rewriter, loc, lhsType, lhsShaped, lhsPadding,
3734 /*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
3735 // This is needed only for Conv.
3736 Value rhs = nullptr;
3737 if (oper == ConvOperationKind::Conv)
3738 rhs = vector::TransferReadOp::create(
3739 rewriter, loc, rhsType, rhsShaped, rhsPadding,
3740 /*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
3741 Value res = vector::TransferReadOp::create(
3742 rewriter, loc, resType, resShaped, resPadding,
3743 /*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
3744
3745 // The base vectorization case for channeled convolution is input:
3746 // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
3747 // vectorization case, we do pre transpose on input, weight, and output.
3748 switch (conv1DOpOrder) {
3749 case Conv1DOpOrder::W:
3750 case Conv1DOpOrder::Nwc:
3751 // Base case, so no transposes necessary.
3752 break;
3753 case Conv1DOpOrder::Ncw: {
3754 // To match base vectorization case, we pre-transpose current case.
3755 // ncw -> nwc
3756 static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
3757 lhs = vector::TransposeOp::create(rewriter, loc, lhs, permLhs);
3758 // fcw -> wcf
3759 static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
3760
3761 // This is needed only for Conv.
3762 if (oper == ConvOperationKind::Conv)
3763 rhs = vector::TransposeOp::create(rewriter, loc, rhs, permRhs);
3764 // nfw -> nwf
3765 static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
3766 res = vector::TransposeOp::create(rewriter, loc, res, permRes);
3767 break;
3768 }
3769 }
3770
3771 //===------------------------------------------------------------------===//
3772 // Begin vector-only rewrite part
3773 //===------------------------------------------------------------------===//
3774 // Unroll along kw and read slices of lhs and rhs.
3775 SmallVector<Value> lhsVals, rhsVals, resVals;
3776 lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
3777 kwSize, strideW, dilationW, wSizeStep,
3778 isSingleChanneled);
3779 // Do not do for pooling.
3780 if (oper == ConvOperationKind::Conv)
3781 rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
3782 resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
3783 wSizeStep, isSingleChanneled);
3784
3785 auto linearIndex = [&](int64_t kw, int64_t w) {
3786 return kw * (wSize / wSizeStep) + w;
3787 };
3788
3789 // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f}
3790 // or perform outerproduct for non-channeled convolution or perform simple
3791 // arith operation for pooling
3792 for (int64_t kw = 0; kw < kwSize; ++kw) {
3793 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3794 switch (oper) {
3795 case ConvOperationKind::Conv:
3796 if (isSingleChanneled) {
3797 resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
3798 lhsVals[linearIndex(kw, w)],
3799 rhsVals[kw], resVals[w]);
3800 } else {
3801 resVals[w] = conv1dSliceAsContraction(rewriter, loc,
3802 lhsVals[linearIndex(kw, w)],
3803 rhsVals[kw], resVals[w]);
3804 }
3805 break;
3806 case ConvOperationKind::Pool:
3807 resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
3808 resVals[w]);
3809 break;
3810 }
3811 }
3812 }
3813
3814 res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
3815 isSingleChanneled);
3816 //===------------------------------------------------------------------===//
3817 // End vector-only rewrite part
3818 //===------------------------------------------------------------------===//
3819
3820 // The base vectorization case for channeled convolution is output:
3821 // {n,w,f} To reuse the result from base pattern vectorization case, we
3822 // post transpose the base case result.
3823 switch (conv1DOpOrder) {
3824 case Conv1DOpOrder::W:
3825 case Conv1DOpOrder::Nwc:
3826 // Base case, so no transposes necessary.
3827 break;
3828 case Conv1DOpOrder::Ncw: {
3829 // nwf -> nfw
3830 static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
3831 res = vector::TransposeOp::create(rewriter, loc, res, perm);
3832 break;
3833 }
3834 }
3835
3836 return vector::TransferWriteOp::create(rewriter, loc, res, resShaped,
3837 resPadding)
3838 .getOperation();
3839 }
3840
3841 // Take a value and widen to have the same element type as `ty`.
3842 Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
3843 const Type srcElementType = getElementTypeOrSelf(val.getType());
3844 const Type dstElementType = getElementTypeOrSelf(ty);
3845 assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
3846 if (srcElementType == dstElementType)
3847 return val;
3848
3849 const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
3850 const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
3851 // Handle both shaped as well as scalar types.
3852 Type dstType;
3853 if (auto shapedType = dyn_cast<ShapedType>(val.getType()))
3854 dstType = shapedType.cloneWith(std::nullopt, dstElementType);
3855 else
3856 dstType = dstElementType;
3857
3858 if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
3859 return arith::SIToFPOp::create(rewriter, loc, dstType, val);
3860 }
3861
3862 if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
3863 srcWidth < dstWidth)
3864 return arith::ExtFOp::create(rewriter, loc, dstType, val);
3865
3866 if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
3867 srcWidth < dstWidth)
3868 return arith::ExtSIOp::create(rewriter, loc, dstType, val);
3869
3870 assert(false && "unhandled promotion case");
3871 return nullptr;
3872 }
3873
3874 // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
3875 Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
3876 Value lhs, Value rhs, Value res) {
3877 vector::IteratorType par = vector::IteratorType::parallel;
3878 vector::IteratorType red = vector::IteratorType::reduction;
3879 AffineExpr n, w, f, c;
3880 bindDims(ctx, n, w, f, c);
3881 lhs = promote(rewriter, loc, lhs, res.getType());
3882 rhs = promote(rewriter, loc, rhs, res.getType());
3883    auto contractionOp = vector::ContractionOp::create(
3884        rewriter, loc, lhs, rhs, res,
3885        /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
3886        /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
3887    contractionOp.setKind(reductionKind);
3888    return contractionOp;
3889 }
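  // Illustrative sketch of one such per-kw slice contraction (assumed shapes
  // n = 1, w = 4, c = 8, f = 16, additive combining kind):
  // ```
  //   %acc = vector.contract {
  //       indexing_maps = [affine_map<(n, w, f, c) -> (n, w, c)>,
  //                        affine_map<(n, w, f, c) -> (c, f)>,
  //                        affine_map<(n, w, f, c) -> (n, w, f)>],
  //       iterator_types = ["parallel", "parallel", "parallel", "reduction"],
  //       kind = #vector.kind<add>}
  //     %lhs, %rhs, %res
  //     : vector<1x4x8xf32>, vector<8x16xf32> into vector<1x4x16xf32>
  // ```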
3890
3891 // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel
3892 // convolution.
3893 Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
3894 Value lhs, Value rhs, Value res) {
3895 lhs = promote(rewriter, loc, lhs, res.getType());
3896 rhs = promote(rewriter, loc, rhs, res.getType());
3897 return vector::OuterProductOp::create(rewriter, loc, res.getType(), lhs,
3898 rhs, res, vector::CombiningKind::ADD);
3899 }
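  // Illustrative sketch (assumed shapes): in the single-channel case the rhs
  // slice is a scalar, so this is the AXPY form of vector.outerproduct:
  // ```
  //   %res = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>}
  //       : vector<16xf32>, f32
  // ```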
3900
3901 // Create a reduction: lhs{n, w, c} -> res{n, w, c}
3902 Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
3903 Value res) {
3904 if (isPoolExt)
3905 lhs = rewriter.create(loc, poolExtOp, lhs, res.getType())->getResult(0);
3906 return rewriter
3907 .create(loc, redOp, ArrayRef<Value>{lhs, res}, res.getType())
3908 ->getResult(0);
3909 }
3910
3911 /// Generate a vector implementation for:
3912 /// ```
3913 /// Op def: ( n, w, c, kw)
3914 /// Iters: ({Par(), Par(), Par(), Red()})
3915 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3916 /// ```
3917 /// kw is always unrolled.
3918  /// TODO: w (resp. kw) is unrolled when strideW (resp. dilationW) is
3919  /// > 1.
3920 FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
3921 bool channelDimScalableFlag,
3922 bool flatten) {
3923 bool scalableChDim = false;
3924 bool useMasking = false;
3925 int64_t nSize, wSize, cSize, kwSize;
3926 // kernel{kw, c}
3927 bindShapeDims(rhsShapedType, kwSize, cSize);
3928 if (ShapedType::isDynamic(cSize)) {
3929 assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
3930 cSize = channelDimVecSize;
3931 // Scalable vectors are only used when both conditions are met:
3932 // 1. channel dim is dynamic
3933 // 2. channelDimScalableFlag is set
3934 scalableChDim = channelDimScalableFlag;
3935 useMasking = true;
3936 }
3937
3938 assert(!(useMasking && flatten) &&
3939 "Unsupported flattened conv with dynamic shapes");
3940
3941 // out{n, w, c}
3942 bindShapeDims(resShapedType, nSize, wSize);
3943
3944 vector::TransferWriteOp write;
3945 Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
3946
3947 // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3948 // When strideW == 1, we can batch the contiguous loads and avoid
3949 // unrolling
3950 int64_t wSizeStep = strideW == 1 ? wSize : 1;
3951
3952 Type lhsEltType = lhsShapedType.getElementType();
3953 Type rhsEltType = rhsShapedType.getElementType();
3954 Type resEltType = resShapedType.getElementType();
3955 VectorType lhsType = VectorType::get(
3956 {nSize,
3957 // iw = ow * sw + kw * dw - 1
3958 // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3959 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
3960 cSize},
3961 lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
3962 VectorType rhsType =
3963 VectorType::get({kwSize, cSize}, rhsEltType,
3964 /*scalableDims=*/{false, scalableChDim});
3965 VectorType resType =
3966 VectorType::get({nSize, wSize, cSize}, resEltType,
3967 /*scalableDims=*/{false, false, scalableChDim});
3968
3969 // Masks the input xfer Op along the channel dim, iff the corresponding
3970 // scalable flag is set.
3971 auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
3972 ArrayRef<bool> scalableDims,
3973 Operation *opToMask) {
3974 if (!useMasking)
3975 return opToMask;
3976 auto maskType =
3977 VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
3978
3979 SmallVector<bool> inBounds(maskShape.size(), true);
3980 auto xferOp = cast<VectorTransferOpInterface>(opToMask);
3981 xferOp->setAttr(xferOp.getInBoundsAttrName(),
3982 rewriter.getBoolArrayAttr(inBounds));
3983
3984 SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
3985 cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);
3986
3987 Value maskOp =
3988 vector::CreateMaskOp::create(rewriter, loc, maskType, mixedDims);
3989
3990 return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
3991 };
3992
3993 // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
3994 // 0].
3995 Value lhs = vector::TransferReadOp::create(
3996 rewriter, loc, lhsType, lhsShaped, ValueRange{zero, zero, zero},
3997 /*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
3998 auto *maybeMaskedLhs = maybeMaskXferOp(
3999 lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
4000
4001 // Read rhs slice of size {kw, c} @ [0, 0].
4002 Value rhs = vector::TransferReadOp::create(
4003 rewriter, loc, rhsType, rhsShaped, ValueRange{zero, zero},
4004 /*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
4005 auto *maybeMaskedRhs = maybeMaskXferOp(
4006 rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
4007
4008 // Read res slice of size {n, w, c} @ [0, 0, 0].
4009 Value res = vector::TransferReadOp::create(
4010 rewriter, loc, resType, resShaped, ValueRange{zero, zero, zero},
4011 /*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
4012 auto *maybeMaskedRes = maybeMaskXferOp(
4013 resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
4014
4015 //===------------------------------------------------------------------===//
4016 // Begin vector-only rewrite part
4017 //===------------------------------------------------------------------===//
4018 // Unroll along kw and read slices of lhs and rhs.
4019 SmallVector<Value> lhsVals, rhsVals, resVals;
4020 SmallVector<int64_t> inOutSliceSizes = {nSize, wSizeStep, cSize};
4021 SmallVector<int64_t> inOutStrides = {1, 1, 1};
4022
4023 // Extract lhs slice of size {n, wSizeStep, c}
4024 // @ [0, sw * w + dw * kw, 0].
4025 for (int64_t kw = 0; kw < kwSize; ++kw) {
4026 for (int64_t w = 0; w < wSize; w += wSizeStep) {
4027 lhsVals.push_back(vector::ExtractStridedSliceOp::create(
4028 rewriter, loc, maybeMaskedLhs->getResult(0),
4029 /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
4030 inOutSliceSizes, inOutStrides));
4031 }
4032 }
4033 // Extract rhs slice of size {c} @ [kw].
4034 for (int64_t kw = 0; kw < kwSize; ++kw) {
4035 rhsVals.push_back(
4036 vector::ExtractOp::create(rewriter, loc, maybeMaskedRhs->getResult(0),
4037 /*offsets=*/ArrayRef<int64_t>{kw}));
4038 }
4039 // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
4040 for (int64_t w = 0; w < wSize; w += wSizeStep) {
4041 resVals.push_back(vector::ExtractStridedSliceOp::create(
4042 rewriter, loc, maybeMaskedRes->getResult(0),
4043 /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
4044 inOutStrides));
4045 }
4046
4047 auto linearIndex = [&](int64_t kw, int64_t w) {
4048 return kw * (wSize / wSizeStep) + w;
4049 };
4050
4051 // Note - the scalable flags are ignored as flattening combined with
4052 // scalable vectorization is not supported.
4053 SmallVector<int64_t> inOutFlattenSliceSizes = {nSize, wSizeStep * cSize};
4054 auto lhsTypeAfterFlattening =
4055 VectorType::get(inOutFlattenSliceSizes, lhsEltType);
4056 auto resTypeAfterFlattening =
4057 VectorType::get(inOutFlattenSliceSizes, resEltType);
4058
4059 // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
4060 for (int64_t kw = 0; kw < kwSize; ++kw) {
4061 for (int64_t w = 0; w < wSize; w += wSizeStep) {
4062 Value lhsVal = lhsVals[linearIndex(kw, w)];
4063 Value resVal = resVals[w];
4064 if (flatten) {
4065 // Flatten the input and output vectors (collapse the channel
4066 // dimension)
4067 lhsVal =
4068 vector::ShapeCastOp::create(rewriter, loc, lhsTypeAfterFlattening,
4069 lhsVals[linearIndex(kw, w)]);
4070 resVal = vector::ShapeCastOp::create(
4071 rewriter, loc, resTypeAfterFlattening, resVals[w]);
4072 }
4073 resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
4074 rhsVals[kw], resVal, flatten);
4075 if (flatten) {
4076 // Un-flatten the output vector (restore the channel dimension)
4077 resVals[w] = vector::ShapeCastOp::create(
4078 rewriter, loc, VectorType::get(inOutSliceSizes, resEltType),
4079 resVals[w]);
4080 }
4081 }
4082 }
4083
4084    // It's possible we failed to create the FMA.
4085 if (!llvm::all_of(resVals, [](Value v) { return v; })) {
4086 // Manually revert (in reverse order) to avoid leaving a bad IR state.
4087 for (auto &collection :
4088 {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
4089 for (Value v : collection)
4090 rewriter.eraseOp(v.getDefiningOp());
4091 return rewriter.notifyMatchFailure(op, "failed to create FMA");
4092 }
4093
4094 // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
4095 // This does not depend on kw.
4096 for (int64_t w = 0; w < wSize; w += wSizeStep) {
4097 maybeMaskedRes = vector::InsertStridedSliceOp::create(
4098 rewriter, loc, resVals[w], maybeMaskedRes->getResult(0),
4099 /*offsets=*/ArrayRef<int64_t>{0, w, 0},
4100 /*strides=*/ArrayRef<int64_t>{1, 1, 1});
4101 }
4102 //===------------------------------------------------------------------===//
4103 // End vector-only rewrite part
4104 //===------------------------------------------------------------------===//
4105
4106 // Write back res slice of size {n, w, c} @ [0, 0, 0].
4107 Operation *resOut = vector::TransferWriteOp::create(
4108 rewriter, loc, maybeMaskedRes->getResult(0), resShaped,
4109 ValueRange{zero, zero, zero});
4110 return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
4111 resOut);
4112 }
4113
4114 /// Lower:
4115 /// * lhs{n, w, c} * rhs{c} -> res{n, w, c} (flatten = false)
4116 /// * lhs{n, w * c} * rhs{c} -> res{n, w * c} (flatten = true)
4117 /// to MulAcc.
4118 Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
4119 Value lhs, Value rhs, Value res,
4120 bool flatten) {
4121 auto rhsTy = cast<ShapedType>(rhs.getType());
4122 auto resTy = cast<ShapedType>(res.getType());
4123
4124 // TODO(suderman): Change this to use a vector.ima intrinsic.
4125 lhs = promote(rewriter, loc, lhs, resTy);
4126
4127 if (flatten) {
4128      // NOTE: The following logic won't work for scalable vectors. For this
4129 // reason, "flattening" is not supported when shapes are dynamic (this
4130 // should be captured by one of the pre-conditions).
4131
4132 // There are two options for handling the filter:
4133 // * shape_cast(broadcast(filter))
4134 // * broadcast(shuffle(filter))
4135 // Opt for the option without shape_cast to simplify the codegen.
4136 auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
4137 auto resSize = cast<VectorType>(res.getType()).getShape()[1];
4138
4139 SmallVector<int64_t, 16> indices;
4140 for (int i = 0; i < resSize / rhsSize; ++i) {
4141 for (int j = 0; j < rhsSize; ++j)
4142 indices.push_back(j);
4143 }
4144
4145 rhs = vector::ShuffleOp::create(rewriter, loc, rhs, rhs, indices);
4146 }
4147 // Broadcast the filter to match the output vector
4148 rhs = vector::BroadcastOp::create(rewriter, loc,
4149 resTy.clone(rhsTy.getElementType()), rhs);
4150
4151 rhs = promote(rewriter, loc, rhs, resTy);
4152
4153 if (!lhs || !rhs)
4154 return nullptr;
4155
4156 if (isa<FloatType>(resTy.getElementType()))
4157 return vector::FMAOp::create(rewriter, loc, lhs, rhs, res);
4158
4159 auto mul = arith::MulIOp::create(rewriter, loc, lhs, rhs);
4160 return arith::AddIOp::create(rewriter, loc, mul, res);
4161 }
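  // Illustrative sketch (assumed shapes, non-flattened f32 case): the filter
  // slice is broadcast to the output shape and fused into a single FMA; for
  // integer element types the same slice lowers to arith.muli + arith.addi.
  // ```
  //   %b = vector.broadcast %rhs : vector<8xf32> to vector<1x4x8xf32>
  //   %res = vector.fma %lhs, %b, %acc : vector<1x4x8xf32>
  // ```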
4162
4163 /// Entry point for non-channeled convolution:
4164 /// {{w + kw}, {kw}, {w}}
4165 FailureOr<Operation *> generateNonChanneledConv() {
4166 AffineExpr w, kw;
4167 bindDims(ctx, w, kw);
4168 if (!iters({Par(), Red()}))
4169 return rewriter.notifyMatchFailure(op,
4170 "failed to match conv::W 1-par 1-red");
4171
4172 // No transposition needed.
4173 if (layout({/*lhsIndex*/ {w + kw},
4174 /*rhsIndex*/ {kw},
4175 /*resIndex*/ {w}}))
4176 return conv(Conv1DOpOrder::W);
4177
4178 return rewriter.notifyMatchFailure(op, "not a conv::W layout");
4179 }
4180
4181 /// Entry point that transposes into the common form:
4182 /// {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
4183 FailureOr<Operation *> generateNwcConv() {
4184 AffineExpr n, w, f, kw, c;
4185 bindDims(ctx, n, w, f, kw, c);
4186 if (!iters({Par(), Par(), Par(), Red(), Red()}))
4187 return rewriter.notifyMatchFailure(
4188 op, "failed to match conv::Nwc 3-par 2-red");
4189
4190 // No transposition needed.
4191 if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
4192 /*rhsIndex*/ {kw, c, f},
4193 /*resIndex*/ {n, w, f}}))
4194 return conv(Conv1DOpOrder::Nwc);
4195
4196 return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
4197 }
4198
4199 /// Entry point that transposes into the common form:
4200 /// {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
4201 FailureOr<Operation *> generateNcwConv() {
4202 AffineExpr n, w, f, kw, c;
4203 bindDims(ctx, n, f, w, c, kw);
4204 if (!iters({Par(), Par(), Par(), Red(), Red()}))
4205 return rewriter.notifyMatchFailure(
4206 op, "failed to match conv::Ncw 3-par 2-red");
4207
4208 if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
4209 /*rhsIndex*/ {f, c, kw},
4210 /*resIndex*/ {n, f, w}}))
4211 return conv(Conv1DOpOrder::Ncw);
4212
4213 return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
4214 }
4215
4216 /// Entry point that transposes into the common form:
4217 /// {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}} for pooling
4218 FailureOr<Operation *> generateNwcPooling() {
4219 AffineExpr n, w, c, kw;
4220 bindDims(ctx, n, w, c, kw);
4221 if (!iters({Par(), Par(), Par(), Red()}))
4222 return rewriter.notifyMatchFailure(op,
4223 "failed to match pooling 3-par 1-red");
4224
4225 // No transposition needed.
4226 if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
4227 /*rhsIndex*/ {kw},
4228 /*resIndex*/ {n, w, c}}))
4229 return conv(Conv1DOpOrder::Nwc);
4230
4231 return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
4232 }
4233
4234 /// Entry point that transposes into the common form:
4235 /// {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}} for pooling
4236 FailureOr<Operation *> generateNcwPooling() {
4237 AffineExpr n, w, c, kw;
4238 bindDims(ctx, n, c, w, kw);
4239 if (!iters({Par(), Par(), Par(), Red()}))
4240 return rewriter.notifyMatchFailure(op,
4241 "failed to match pooling 3-par 1-red");
4242
4243 if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
4244 /*rhsIndex*/ {kw},
4245 /*resIndex*/ {n, c, w}}))
4246 return conv(Conv1DOpOrder::Ncw);
4247
4248 return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
4249 }
4250
4251 /// Entry point that transposes into the common form:
4252 /// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
4253 FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
4254 bool vecChDimScalableFlag = false,
4255 bool flatten = false) {
4256 AffineExpr n, w, c, kw;
4257 bindDims(ctx, n, w, c, kw);
4258 if (!iters({Par(), Par(), Par(), Red()}))
4259 return rewriter.notifyMatchFailure(
4260 op, "failed to match depthwise::Nwc conv 3-par 1-red");
4261
4262 // No transposition needed.
4263 if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
4264 /*rhsIndex*/ {kw, c},
4265 /*resIndex*/ {n, w, c}}))
4266 return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);
4267
4268 return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
4269 }
4270
4271private:
4272 ConvOperationKind oper = ConvOperationKind::Conv;
4273 StringAttr redOp;
4274 StringAttr poolExtOp;
4275 bool isPoolExt = false;
4276 int strideW, dilationW;
4277 Value lhsShaped, rhsShaped, resShaped;
4278 ShapedType lhsShapedType, rhsShapedType, resShapedType;
4279 vector::CombiningKind reductionKind;
4280
4281 // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
4282 void setConvOperationKind(Operation *reduceOp) {
4283 int numBlockArguments =
4284 llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
4285 if (numBlockArguments == 1) {
4286 // Will be convolution if feeder is a MulOp.
4287 // A strength reduced version of MulOp for i1 type is AndOp which is also
4288 // supported. Otherwise, it can be pooling. This strength reduction logic
4289 // is in `buildBinaryFn` helper in the Linalg dialect.
4290 auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
4291 llvm::IsaPred<BlockArgument>);
4292 Operation *feedOp = (*feedValIt).getDefiningOp();
4293 if (isCastOfBlockArgument(feedOp)) {
4294 oper = ConvOperationKind::Pool;
4295 isPoolExt = true;
4296 poolExtOp = feedOp->getName().getIdentifier();
4297 return;
4298 }
4299 oper = ConvOperationKind::Conv;
4300 return;
4301 }
4302    // numBlockArguments == 2 and this is a pooling op.
4303 oper = ConvOperationKind::Pool;
4304 isPoolExt = false;
4305 }
4306};
4307} // namespace
4308
4309/// Helper function to vectorize a LinalgOp with convolution semantics.
4310// TODO: extend the generic vectorization to support windows and drop this.
4311static FailureOr<Operation *> vectorizeConvolution(
4312 RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
4313 ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
4314 FailureOr<Conv1DGenerator> conv1dGen = Conv1DGenerator::create(rewriter, op);
4315 if (failed(conv1dGen))
4316 return failure();
4317 auto res = conv1dGen->generateNonChanneledConv();
4318 if (succeeded(res))
4319 return res;
4320 res = conv1dGen->generateNwcConv();
4321 if (succeeded(res))
4322 return res;
4323 res = conv1dGen->generateNcwConv();
4324 if (succeeded(res))
4325 return res;
4326 res = conv1dGen->generateNwcPooling();
4327 if (succeeded(res))
4328 return res;
4329 res = conv1dGen->generateNcwPooling();
4330 if (succeeded(res))
4331 return res;
4332
4333 // Only depthwise 1D NWC convs are left - these can be vectorized using masks
4334 // and scalable vectors. Note that ATM the only dim that can be dynamic (i.e.
4335 // masked/scalable) is the channel dim (i.e. the trailing dim).
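// As an illustrative example (hedged, not from a test): for a
// linalg.depthwise_conv_1d_nwc_wc vectorized with inputVecSizes = {1, 8, 4}
// and inputScalableVecDims = {false, false, true}, the channel dim is the
// trailing one, so chDimIdx = 2, vecChDimSize = 4 and the channel dim is
// treated as scalable; the n and w vector sizes are inferred from the op.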
4336 uint64_t vecChDimSize = ShapedType::kDynamic;
4337 bool vecChDimScalableFlag = false;
4338 if (!inputVecSizes.empty()) {
4339 // Only use the input vector size corresponding to the channel dim. Other
4340 // vector dims will be inferred from the Ops.
4343 "Not a 1D depthwise conv!");
4344 size_t chDimIdx = 0;
4346 chDimIdx = 2;
4348 chDimIdx = 1;
4349
4350 vecChDimSize = inputVecSizes[chDimIdx];
4351 vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
4352 }
4353 return conv1dGen->generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
4354 flatten1DDepthwiseConv);
4355}
4356
4357 struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
4358 using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
4359
4360 LogicalResult matchAndRewrite(LinalgOp op,
4361 PatternRewriter &rewriter) const override {
4362 FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
4363 if (failed(resultOrFail))
4364 return failure();
4365 Operation *newOp = *resultOrFail;
4366 if (newOp->getNumResults() == 0) {
4367 rewriter.eraseOp(op.getOperation());
4368 return success();
4369 }
4370 assert(newOp->getNumResults() == 1 && "expected single result");
4371 rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
4372 return success();
4373 }
4374};
4375
4376 void mlir::linalg::populateConvolutionVectorizationPatterns(
4377 RewritePatternSet &patterns, PatternBenefit benefit) {
4378 patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
4379}