MLIR 23.0.0git
Vectorization.cpp
Go to the documentation of this file.
1//===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the linalg dialect Vectorization transformations.
10//
11//===----------------------------------------------------------------------===//
13
28#include "mlir/IR/AffineExpr.h"
29#include "mlir/IR/AffineMap.h"
30#include "mlir/IR/Builders.h"
35#include "mlir/IR/Value.h"
36#include "mlir/Support/LLVM.h"
38#include "llvm/ADT/STLExtras.h"
39#include "llvm/ADT/Sequence.h"
40#include "llvm/ADT/SmallVector.h"
41#include "llvm/ADT/SmallVectorExtras.h"
42#include "llvm/ADT/TypeSwitch.h"
43#include "llvm/Support/DebugLog.h"
44#include "llvm/Support/InterleavedRange.h"
45#include "llvm/Support/MathExtras.h"
46#include "llvm/Support/raw_ostream.h"
47#include <optional>
48
49using namespace mlir;
50using namespace mlir::linalg;
51
52#define DEBUG_TYPE "linalg-vectorization"
53
54/// Try to vectorize `convOp` as a convolution.
55static FailureOr<Operation *>
56vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
57 ArrayRef<int64_t> inputVecSizes = {},
58 ArrayRef<bool> inputVecScalableFlags = {},
59 bool flatten1DDepthwiseConv = false);
60
61/// Vectorize tensor::InsertSliceOp with:
62/// * vector::TransferReadOp + vector::TransferWriteOp
63/// The vector sizes are either:
64/// * user-provided in `inputVectorSizes`, or
65/// * inferred from the static dims in the input and output tensors.
66/// Bails out if:
67/// * vector sizes are not user-provided, and
68/// * at least one dim is dynamic (in both the input and output tensors).
69///
70/// Before:
71/// !t_in_type = tensor<1x2x3xf32>
72/// !t_out_type = tensor<9x8x7x1x2x3xf32>
73/// !v_type = vector<1x2x3xf32>
74/// %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
75/// into !t_out_type
76/// After:
77/// %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
78/// %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
79static LogicalResult
80vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
81 ArrayRef<int64_t> inputVectorSizes,
82 SmallVectorImpl<Value> &newResults);
83
84/// Returns the effective Pad value for the input op, provided it's a scalar.
85///
86/// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
87/// this Op performs padding, retrieve the padding value provided that it's
88/// a scalar and static/fixed for all the padded values. Returns an empty value
89/// otherwise.
91
92/// Return the unique instance of OpType in `block` if it is indeed unique.
93/// Return null if none or more than 1 instances exist.
94template <typename OpType>
95static OpType getSingleOpOfType(Block &block) {
96 OpType res;
97 block.walk([&](OpType op) {
98 if (res) {
99 res = nullptr;
100 return WalkResult::interrupt();
101 }
102 res = op;
103 return WalkResult::advance();
104 });
105 return res;
106}
107
/// Helper function to extract the input slices after filter is unrolled along
/// kw.
// NOTE(review): the start of this function's signature and the declaration of
// the local `result` vector are not visible in this extracted view; the
// parameters below are the tail of a declaration that starts above.
                                       int64_t nSize, int64_t wSize, int64_t cSize,
                                       int64_t kwSize, int strideW, int dilationW,
                                       int64_t wSizeStep, bool isSingleChanneled) {
  if (isSingleChanneled) {
    // Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled
    // convolution.
    SmallVector<int64_t> sizes = {wSizeStep};
    SmallVector<int64_t> strides = {1};
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        // 1-D slice starting at the kernel-shifted position `w + kw`.
        result.push_back(vector::ExtractStridedSliceOp::create(
            rewriter, loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw}, sizes,
            strides));
      }
    }
  } else {
    // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0]
    // for channeled convolution.
    SmallVector<int64_t> sizes = {nSize, wSizeStep, cSize};
    SmallVector<int64_t> strides = {1, 1, 1};
    for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
        // The W offset accounts for both the stride and the dilation.
        result.push_back(vector::ExtractStridedSliceOp::create(
            rewriter, loc, input,
            /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
            sizes, strides));
      }
    }
  }
  return result;
}
144
/// Helper function to extract the filter slices after filter is unrolled along
/// kw.
// NOTE(review): the start of this function's signature and the declaration of
// the local `result` vector are not visible in this extracted view.
                                        Location loc, Value filter,
                                        int64_t kwSize) {
  // Extract rhs slice of size [{c, f} for channeled convolutions and {1} for
  // non-channeled convolution] @ [kw].
  for (int64_t kw = 0; kw < kwSize; ++kw) {
    // One filter slice per unrolled kernel-width position.
    result.push_back(vector::ExtractOp::create(
        rewriter, loc, filter, /*offsets=*/ArrayRef<int64_t>{kw}));
  }
  return result;
}
159
/// Helper function to extract the result slices after filter is unrolled along
/// kw.
// NOTE(review): the start of this function's signature and the declaration of
// the local `result` vector are not visible in this extracted view; the
// parameters below are the tail of a declaration that starts above.
                                         int64_t nSize, int64_t wSize, int64_t fSize,
                                         int64_t wSizeStep, bool isSingleChanneled) {
  if (isSingleChanneled) {
    // Extract res slice: {wSizeStep} @ [w] for non-channeled convolution.
    SmallVector<int64_t> sizes = {wSizeStep};
    SmallVector<int64_t> strides = {1};
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      result.push_back(vector::ExtractStridedSliceOp::create(
          rewriter, loc, res, /*offsets=*/ArrayRef<int64_t>{w}, sizes,
          strides));
    }
  } else {
    // Extract res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
    // convolution.
    SmallVector<int64_t> sizes = {nSize, wSizeStep, fSize};
    SmallVector<int64_t> strides = {1, 1, 1};
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      result.push_back(vector::ExtractStridedSliceOp::create(
          rewriter, loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes,
          strides));
    }
  }
  return result;
}
189
/// Helper function to insert the computed result slices.
// NOTE(review): the start of this function's signature is not visible in this
// extracted view; the parameters below are the tail of a declaration that
// starts above.
                                   Value res, int64_t wSize, int64_t wSizeStep,
                                   SmallVectorImpl<Value> &resVals,
                                   bool isSingleChanneled) {

  if (isSingleChanneled) {
    // Write back res slice: {wSizeStep} @ [w] for non-channeled convolution.
    // This does not depend on kw.
    SmallVector<int64_t> strides = {1};
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      // Each computed slice overwrites the matching window of `res`; the
      // accumulated value is threaded through the loop.
      res = vector::InsertStridedSliceOp::create(
          rewriter, loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w},
          strides);
    }
  } else {
    // Write back res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
    // convolution. This does not depend on kw.
    SmallVector<int64_t> strides = {1, 1, 1};
    for (int64_t w = 0; w < wSize; w += wSizeStep) {
      res = vector::InsertStridedSliceOp::create(
          rewriter, loc, resVals[w], res,
          /*offsets=*/ArrayRef<int64_t>{0, w, 0}, strides);
    }
  }
  return res;
}
217
/// Contains the vectorization state and related methods used across the
/// vectorization process of a given operation.
// NOTE(review): the `struct VectorizationState {` header line is not visible
// in this extracted view.
  /// The constructor only installs an insertion-point guard for `rewriter`;
  /// all real setup happens in `initState`.
  VectorizationState(RewriterBase &rewriter) : rewriterGuard(rewriter) {}

  /// Initializes the vectorization state, including the computation of the
  /// canonical vector shape for vectorization.
  LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
                          ArrayRef<int64_t> inputVectorSizes,
                          ArrayRef<bool> inputScalableVecDims,
                          bool assumeDynamicDimsMatchVecSizes = false);

  /// Returns the canonical vector shape used to vectorize the iteration space.
  ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }

  /// Returns the vector dimensions that are scalable in the canonical vector
  /// shape.
  ArrayRef<bool> getScalableVecDims() const { return scalableVecDims; }

  /// Returns a vector type of the provided `elementType` with the canonical
  /// vector shape and the corresponding fixed/scalable dimensions bit. If
  /// `dimPermutation` is provided, the canonical vector dimensions are permuted
  /// accordingly.
  // NOTE(review): this method's signature line and the declaration/assignment
  // of the local `vectorShape` are not visible in this extracted view.
      Type elementType,
      std::optional<AffineMap> dimPermutation = std::nullopt) const {
    SmallVector<bool> scalableDims;
    if (dimPermutation.has_value()) {
          applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
      scalableDims =
          applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
    } else {
      // No permutation requested: use the canonical shape/flags verbatim.
      vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
      scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
    }

    return VectorType::get(vectorShape, elementType, scalableDims);
  }

  /// Masks an operation with the canonical vector mask if the operation needs
  /// masking. Returns the masked operation or the original operation if masking
  /// is not needed. If provided, the canonical mask for this operation is
  /// permuted using `maybeIndexingMap`.
  Operation *
  maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
                std::optional<AffineMap> maybeIndexingMap = std::nullopt);

private:
  /// Initializes the iteration space static sizes using the Linalg op
  /// information. This may become more complicated in the future.
  void initIterSpaceStaticSizes(LinalgOp linalgOp) {
    iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
  }

  /// Generates 'arith.constant' and 'tensor/memref.dim' operations for
  /// all the static and dynamic dimensions of the iteration space to be
  /// vectorized and store them in `iterSpaceValueSizes`.
  LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
                                              LinalgOp linalgOp);

  /// Create or retrieve an existing mask value to mask `opToMask` in the
  /// canonical vector iteration space. If `maybeMaskingMap` the mask is
  /// permuted using that permutation map. If a new mask is created, it will be
  /// cached for future users.
  Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
                           LinalgOp linalgOp,
                           std::optional<AffineMap> maybeMaskingMap);

  /// Check whether this permutation map can be used for masking. At the
  /// moment we only make sure that there are no broadcast dimensions, but this
  /// might change if indexing maps evolve.
  bool isValidMaskingMap(AffineMap maskingMap) {
    return maskingMap.getBroadcastDims().empty();
  }

  /// Turn the input indexing map into a valid masking map.
  ///
  /// The input indexing map may contain "zero" results, e.g.:
  ///   (d0, d1, d2, d3) -> (d2, d1, d0, 0)
  /// Applying such maps to canonical vector shapes like this one:
  ///   (1, 16, 16, 4)
  /// would yield an invalid vector shape like this:
  ///   (16, 16, 1, 0)
  /// Instead, drop the broadcasting dims that make no sense for masking perm.
  /// maps:
  ///   (d0, d1, d2, d3) -> (d2, d1, d0)
  /// This way, the corresponding vector/mask type will be:
  ///   vector<16x16x1xty>
  /// rather than this invalid Vector type:
  ///   vector<16x16x1x0xty>
  AffineMap getMaskingMapFromIndexingMap(AffineMap &indexingMap) {
    return indexingMap.dropZeroResults();
  }

  // Holds the compile-time static sizes of the iteration space to vectorize.
  // Dynamic dimensions are represented using ShapedType::kDynamic.
  SmallVector<int64_t> iterSpaceStaticSizes;

  /// Holds the value sizes of the iteration space to vectorize. Static
  /// dimensions are represented by 'arith.constant' and dynamic
  /// dimensions by 'tensor/memref.dim'.
  SmallVector<Value> iterSpaceValueSizes;

  /// Holds the canonical vector shape used to vectorize the iteration space.
  SmallVector<int64_t> canonicalVecShape;

  /// Holds the vector dimensions that are scalable in the canonical vector
  /// shape.
  SmallVector<bool> scalableVecDims;

  /// Holds the active masks for permutations of the canonical vector iteration
  /// space.
  DenseMap<AffineMap, Value> activeMaskCache;

  /// Global vectorization guard for the incoming rewriter. It's initialized
  /// when the vectorization state is initialized.
  OpBuilder::InsertionGuard rewriterGuard;

  /// Do all dynamic dims match the corresponding vector sizes?
  ///
  /// When a dynamic tensor/memref dimension matches the corresponding vector
  /// dimension, masking can be safely skipped, despite the presence of dynamic
  /// shapes. Use this flag with care and only for cases where you are
  /// confident the assumption holds.
  bool assumeDynamicDimsMatchVecSizes = false;
};
346
347LogicalResult
348VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
349 LinalgOp linalgOp) {
350 // TODO: Support 0-d vectors.
351 for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
352 if (ShapedType::isStatic(iterSpaceStaticSizes[vecDim])) {
353 // Create constant index op for static dimensions.
354 iterSpaceValueSizes.push_back(arith::ConstantIndexOp::create(
355 rewriter, linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
356 continue;
357 }
358
359 // Find an operand defined on this dimension of the iteration space to
360 // extract the runtime dimension size.
361 Value operand;
362 unsigned operandDimPos;
363 if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
364 operandDimPos)))
365 return failure();
366
367 Value dynamicDim =
368 linalgOp.hasPureTensorSemantics()
369 ? (Value)tensor::DimOp::create(rewriter, linalgOp.getLoc(), operand,
370 operandDimPos)
371 : (Value)memref::DimOp::create(rewriter, linalgOp.getLoc(), operand,
372 operandDimPos);
373 iterSpaceValueSizes.push_back(dynamicDim);
374 }
375
376 return success();
377}
378
/// Initializes the vectorization state, including the computation of the
/// canonical vector shape for vectorization.
// TODO: Move this to the constructor when we can remove the failure cases.
// NOTE(review): the first line of this definition's signature (presumably
// declaring the `rewriter` parameter) is not visible in this extracted view.
                                        LinalgOp linalgOp,
                                        ArrayRef<int64_t> inputVectorSizes,
                                        ArrayRef<bool> inputScalableVecDims,
                                        bool assumeDimsMatchVec) {
  assumeDynamicDimsMatchVecSizes = assumeDimsMatchVec;
  // Initialize the insertion point.
  rewriter.setInsertionPoint(linalgOp);

  if (!inputVectorSizes.empty()) {
    // Get the canonical vector shape from the input vector sizes provided. This
    // path should be taken to vectorize code with dynamic shapes and when using
    // vector sizes greater than the iteration space sizes.
    canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
    scalableVecDims.append(inputScalableVecDims.begin(),
                           inputScalableVecDims.end());
  } else {
    // Compute the canonical vector shape from the operation shape. If there are
    // dynamic shapes, the operation won't be vectorized. We assume all the
    // vector dimensions are fixed.
    canonicalVecShape = linalgOp.getStaticLoopRanges();
    scalableVecDims.append(linalgOp.getNumLoops(), false);
  }

  LDBG() << "Canonical vector shape: " << llvm::interleaved(canonicalVecShape);
  LDBG() << "Scalable vector dims: " << llvm::interleaved(scalableVecDims);

  // Bail out if the canonical shape is still dynamic at this point: without
  // user-provided vector sizes there is no static shape to vectorize to.
  if (ShapedType::isDynamicShape(canonicalVecShape))
    return failure();

  // Initialize iteration space static sizes.
  initIterSpaceStaticSizes(linalgOp);

  // Generate 'arith.constant' and 'tensor/memref.dim' operations for
  // all the static and dynamic dimensions of the iteration space, needed to
  // compute a mask during vectorization.
  if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
    return failure();

  return success();
}
423
/// Create or retrieve an existing mask value to mask `opToMask` in the
/// canonical vector iteration space. If `maybeMaskingMap` the mask is permuted
/// using that permutation map. If a new mask is created, it will be cached for
/// future users.
Value VectorizationState::getOrCreateMaskFor(
    RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
    std::optional<AffineMap> maybeMaskingMap) {

  assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) &&
         "Ill-formed masking map.");

  // No mask is needed if the operation is not maskable.
  auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
  if (!maskableOp)
    return Value();

  assert(!maskableOp.isMasked() &&
         "Masking an operation that is already masked");

  // If no masking map was provided, use an identity map with the loop dims.
  assert((!maybeMaskingMap || *maybeMaskingMap) &&
         "Unexpected null mask permutation map");
  // NOTE(review): the fallback branch of this ternary (constructing the
  // identity map over the loop dims) is split across a source line that is not
  // visible in this extracted view.
  AffineMap maskingMap =
      maybeMaskingMap ? *maybeMaskingMap
                          linalgOp.getNumLoops(), rewriter.getContext());

  LDBG() << "Masking map: " << maskingMap;

  // Return the active mask for the masking map of this operation if it was
  // already created.
  auto activeMaskIt = activeMaskCache.find(maskingMap);
  if (activeMaskIt != activeMaskCache.end()) {
    Value mask = activeMaskIt->second;
    LDBG() << "Reusing mask: " << mask;
    return mask;
  }

  // Compute permuted projection of the iteration space to be masked and the
  // corresponding mask shape. If the resulting iteration space dimensions are
  // static and identical to the mask shape, masking is not needed for this
  // operation.
  // TODO: Improve this check. Only projected permutation indexing maps are
  // supported.
  SmallVector<int64_t> permutedStaticSizes =
      applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
  auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
  auto maskShape = maskType.getShape();

  LDBG() << "Mask shape: " << llvm::interleaved(maskShape);

  if (permutedStaticSizes == maskShape) {
    LDBG() << "Masking is not needed for masking map: " << maskingMap;
    // Cache the "no mask needed" decision as a null Value.
    activeMaskCache[maskingMap] = Value();
    return Value();
  }

  if (assumeDynamicDimsMatchVecSizes) {
    // While for _dynamic_ dim sizes we can _assume_ that the corresponding
    // vector sizes match, we still need to check the _static_ dim sizes. Only
    // then we can be 100% sure that masking is not required.
    if (llvm::all_of(llvm::zip(permutedStaticSizes, maskType.getShape()),
                     [](auto it) {
                       return std::get<0>(it) == ShapedType::kDynamic
                                  ? true
                                  : std::get<0>(it) == std::get<1>(it);
                     })) {
      LDBG()
          << "Dynamic + static dimensions match vector sizes, masking is not "
             "required.";
      activeMaskCache[maskingMap] = Value();
      return Value();
    }
  }

  // Permute the iteration space value sizes to compute the mask upper bounds.
  SmallVector<Value> upperBounds =
      applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
  assert(!maskShape.empty() && !upperBounds.empty() &&
         "Masked 0-d vectors are not supported yet");

  // Create the mask based on the dimension values and cache it for reuse.
  Value mask = vector::CreateMaskOp::create(rewriter, linalgOp.getLoc(),
                                            maskType, upperBounds);
  LDBG() << "Creating new mask: " << mask;
  activeMaskCache[maskingMap] = mask;
  return mask;
}
512
// NOTE(review): the first line of this definition's signature (declaring the
// `rewriter` and `opToMask` parameters) is not visible in this extracted view.
Operation *
              LinalgOp linalgOp,
              std::optional<AffineMap> maybeIndexingMap) {
  LDBG() << "Trying to mask: " << *opToMask;

  // Derive the masking map from the indexing map, dropping "zero" results
  // that make no sense for masking (see getMaskingMapFromIndexingMap).
  std::optional<AffineMap> maybeMaskingMap = std::nullopt;
  if (maybeIndexingMap)
    maybeMaskingMap = getMaskingMapFromIndexingMap(*maybeIndexingMap);

  // Create or retrieve mask for this operation.
  Value mask =
      getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);

  if (!mask) {
    LDBG() << "No mask required";
    if (assumeDynamicDimsMatchVecSizes) {
      // NOTE(review): the head of this TypeSwitch expression (one source line)
      // is not visible in this extracted view.
          .Case<vector::TransferReadOp, vector::TransferWriteOp>(
              [&](auto xferOp) {
                // For vector.transfer_read and vector.transfer_write, there is
                // also the `in-bounds` attribute that has to be set explicitly
                // to true. Otherwise, "out-of-bounds" access will be assumed
                // and masks will be generated while lowering these.
                LDBG() << "Assuming dynamic dimensions match vector sizes and "
                          "setting their in-bounds to true!";
                SmallVector<bool> inBoundsMap = xferOp.getInBoundsValues();
                ShapedType xferType = xferOp.getShapedType();
                AffineMap permMap = xferOp.getPermutationMap();
                // Only set the in-bounds values to true for dynamic dims.
                // Different mechanisms will set these accordingly for the
                // static dims.
                for (unsigned i = 0; i < xferOp.getTransferRank(); i++) {
                  auto dimExpr = dyn_cast<AffineDimExpr>(permMap.getResult(i));
                  // Skip broadcast dimensions.
                  if (!dimExpr)
                    continue;
                  unsigned pos = dimExpr.getPosition();
                  if (xferType.isDynamicDim(pos))
                    inBoundsMap[i] = true;
                }
                rewriter.modifyOpInPlace(xferOp, [&]() {
                  xferOp.setInBoundsAttr(
                      rewriter.getBoolArrayAttr(inBoundsMap));
                });
              })
          .Default([](Operation *op) {
            // No-op if the operation is not an xfer read or write.
          });
    }
    return opToMask;
  }

  // Wrap the operation with a new `vector.mask` and update D-U chain.
  assert(opToMask && "Expected a valid operation to mask");
  auto maskOp = cast<vector::MaskOp>(
      mlir::vector::maskOperation(rewriter, opToMask, mask));
  Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();

  // Redirect all external uses of the original results to the mask op's
  // results, keeping the in-region terminator use intact.
  for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
    rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
                                  maskOpTerminator);

  LDBG() << "Masked operation: " << *maskOp;
  return maskOp;
}
579
/// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
/// projectedPermutation, compress the unused dimensions to serve as a
/// permutation_map for a vector transfer operation.
/// For example, given a linalg op such as:
///
/// ```
///   %0 = linalg.generic {
///         indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>,
///         indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>
///       }
///       ins(%0 : tensor<2x3x4xf32>)
///       outs(%1 : tensor<5x6xf32>)
/// ```
///
/// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine
/// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second
/// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`.
// NOTE(review): the signature line of this helper is not visible in this
// extracted view.
  assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
         "expected projected permutation");
  auto res = compressUnusedDims(map);
  // After compression, every remaining (non-zero) result must map 1:1 to a dim.
  assert(res.getNumDims() ==
             (res.getNumResults() - res.getNumOfZeroResults()) &&
         "expected reindexed map with same number of dims and results");
  return res;
}
606
/// Helper enum to represent conv1d input traversal order.
enum class Conv1DOpOrder {
  W,   // Corresponds to non-channeled 1D convolution operation.
  Ncw, // Corresponds to operation that traverses the input in (n, c, w) order.
  Nwc  // Corresponds to operation that traverses the input in (n, w, c) order.
};
613
/// Helper data structure to represent the result of vectorization for a single
/// operation. In certain specific cases, like terminators, we do not want to
/// propagate.
// NOTE(review): the enum header and the enumerator names themselves are not
// visible in this extracted view; only their documentation lines survived.
// Later code in this file refers to VectorizationHookStatus::NoReplace and
// VectorizationHookStatus::NewOp — confirm the full enumerator list upstream.
  /// Op failed to vectorize.
  /// Op vectorized and custom function took care of replacement logic
  /// Op vectorized into a new Op whose results will replace original Op's
  /// results.
  // TODO: support values if Op vectorized to Many-Ops whose results we need to
  // aggregate for replacement.
};
/// VectorizationHookResult contains the vectorized op returned from a
/// CustomVectorizationHook. This is an internal implementation detail of
/// linalg vectorization, not to be confused with VectorizationResult.
// NOTE(review): the struct header and the two member declarations are not
// visible in this extracted view; only their doc comments survived extraction.
  /// Return status from vectorizing the current op.
  /// New vectorized operation to replace the current op.
  /// Replacement behavior is specified by `status`.
};
638
/// Map `combinerOp` to the matching vector::CombiningKind, or std::nullopt if
/// it is null or not a recognized combiner operation.
// NOTE(review): the line completing this signature and the head of the
// TypeSwitch expression are not visible in this extracted view.
std::optional<vector::CombiningKind>
  using ::mlir::vector::CombiningKind;

  if (!combinerOp)
    return std::nullopt;
      .Case<arith::AddIOp, arith::AddFOp>(
          [&](auto op) { return CombiningKind::ADD; })
      .Case([&](arith::AndIOp op) { return CombiningKind::AND; })
      .Case([&](arith::MaxSIOp op) { return CombiningKind::MAXSI; })
      .Case([&](arith::MaxUIOp op) { return CombiningKind::MAXUI; })
      .Case([&](arith::MaximumFOp op) { return CombiningKind::MAXIMUMF; })
      .Case([&](arith::MaxNumFOp op) { return CombiningKind::MAXNUMF; })
      .Case([&](arith::MinSIOp op) { return CombiningKind::MINSI; })
      .Case([&](arith::MinUIOp op) { return CombiningKind::MINUI; })
      .Case([&](arith::MinimumFOp op) { return CombiningKind::MINIMUMF; })
      .Case([&](arith::MinNumFOp op) { return CombiningKind::MINNUMF; })
      .Case<arith::MulIOp, arith::MulFOp>(
          [&](auto op) { return CombiningKind::MUL; })
      .Case([&](arith::OrIOp op) { return CombiningKind::OR; })
      .Case([&](arith::XOrIOp op) { return CombiningKind::XOR; })
      .Default(std::nullopt);
}
663
664/// Check whether `outputOperand` is a reduction with a single combiner
665/// operation. Return the combiner operation of the reduction. Return
666/// nullptr otherwise. Multiple reduction operations would impose an
667/// ordering between reduction dimensions and is currently unsupported in
668/// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) !=
669/// max(min(X))
670// TODO: use in LinalgOp verification, there is a circular dependency atm.
671static Operation *matchLinalgReduction(OpOperand *outputOperand) {
672 auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
673 unsigned outputPos =
674 outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
675 // Only single combiner operations are supported for now.
676 SmallVector<Operation *, 4> combinerOps;
677 if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
678 combinerOps.size() != 1)
679 return nullptr;
680
681 // Return the combiner operation.
682 return combinerOps[0];
683}
684
/// Broadcast `value` to a vector of `shape` if possible. Return value
/// otherwise.
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
  auto dstVecType = dyn_cast<VectorType>(dstType);
  // If no shape to broadcast to, just return `value`.
  // NOTE(review): `dstVecType` would be null if `dstType` is not a VectorType;
  // presumably callers only pass vector types here — confirm at call sites.
  if (dstVecType.getRank() == 0)
    return value;
  // NOTE(review): the right-hand side of this comparison (one source line,
  // presumably a BroadcastableToResult check) is not visible in this view.
  if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
    return value;
  Location loc = b.getInsertionPoint()->getLoc();
  return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
}
698
/// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
/// assumes that `reductionOp` has two operands and one of them is the reduction
/// initial value.
// Note: this is a true builder that notifies the OpBuilder listener.
// TODO: Consider moving as a static helper on the ReduceOp.
// NOTE(review): the first line of the signature is not visible in this
// extracted view; the parameters below continue a declaration starting above.
                                 Value valueToReduce, Value acc,
                                 ArrayRef<bool> dimsToMask) {
  // The combining kind must exist; callers are expected to have validated
  // `reduceOp` against getCombinerOpKind beforehand.
  auto maybeKind = getCombinerOpKind(reduceOp);
  assert(maybeKind && "Failed precondition: could not get reduction kind");
  return vector::MultiDimReductionOp::create(
      b, reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
}
712
/// Return one flag per loop of `linalgOp`, derived from its iterator types
/// (presumably true for reduction iterators — the predicate is not visible).
// NOTE(review): the predicate lambda passed to map_to_vector (one source line)
// is not visible in this extracted view.
static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
  return llvm::map_to_vector(linalgOp.getIteratorTypesArray(),
}
717
718/// Check if `op` is a linalg.reduce or a linalg.generic that has at least one
719/// reduction iterator.
720static bool hasReductionIterator(LinalgOp &op) {
721 return isa<linalg::ReduceOp>(op) ||
722 (isa<linalg::GenericOp>(op) &&
723 llvm::any_of(op.getIteratorTypesArray(), isReductionIterator));
724}
725
/// Build a vector.transfer_write of `value` into `outputOperand` at indices set
/// to all `0`; where `outputOperand` is an output operand of the LinalgOp
/// currently being vectorized. If `dest` has null rank, build an memref.store.
/// Return the produced value or null if no value is produced.
// Note: this is a true builder that notifies the OpBuilder listener.
// TODO: Consider moving as a static helper on the ReduceOp.
static Value buildVectorWrite(RewriterBase &rewriter, Value value,
                              OpOperand *outputOperand,
                              VectorizationState &state) {
  Location loc = value.getLoc();
  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
  AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);

  // Compute the vector type of the value to store. This type should be an
  // identity or projection of the canonical vector type without any permutation
  // applied, given that any permutation in a transfer write happens as part of
  // the write itself.
  // NOTE(review): the line introducing the local `vectorTypeMap` (presumably a
  // filtered-identity AffineMap construction, given the arguments below) is
  // not visible in this extracted view.
      opOperandMap.getContext(), opOperandMap.getNumInputs(),
      [&](AffineDimExpr dimExpr) -> bool {
        return llvm::is_contained(opOperandMap.getResults(), dimExpr);
      });
  auto vectorType = state.getCanonicalVecType(
      getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);

  // All write indices are 0: the vectorized write covers the whole output.
  SmallVector<Value> indices(linalgOp.getRank(outputOperand),
                             arith::ConstantIndexOp::create(rewriter, loc, 0));

  Operation *write;
  if (vectorType.getRank() > 0) {
    AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
    value = broadcastIfNeeded(rewriter, value, vectorType);
    assert(value.getType() == vectorType && "Incorrect type");
    write = vector::TransferWriteOp::create(
        rewriter, loc, value, outputOperand->get(), indices, writeMap);
  } else {
    // 0-d case is still special: do not invert the reindexing writeMap.
    if (!isa<VectorType>(value.getType()))
      value = vector::BroadcastOp::create(rewriter, loc, vectorType, value);
    assert(value.getType() == vectorType && "Incorrect type");
    write = vector::TransferWriteOp::create(rewriter, loc, value,
                                            outputOperand->get(), indices);
  }

  write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);

  // If masked, set in-bounds to true. Masking guarantees that the access will
  // be in-bounds.
  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
    auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
    SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
    maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
  }

  LDBG() << "vectorized op: " << *write;
  if (!write->getResults().empty())
    return write->getResult(0);
  return Value();
}
785
// Custom vectorization precondition function type. This is intended to be used
// with CustomVectorizationHook. Returns success if the corresponding custom
// hook can vectorize the op.
// NOTE(review): the `using CustomVectorizationPrecondition =` line is not
// visible in this extracted view.
    std::function<LogicalResult(Operation *, bool)>;

// Custom vectorization function type. Produce a vector form of Operation*
// assuming all its vectorized operands are already in the IRMapping.
// Return nullptr if the Operation cannot be vectorized.
// NOTE(review): the `using CustomVectorizationHook =` line is not visible in
// this extracted view.
    std::function<VectorizationHookResult(Operation *, const IRMapping &)>;
797
/// Helper function to vectorize the terminator of a `linalgOp`. New result
/// vector values are appended to `newResults`. Return
/// VectorizationHookStatus::NoReplace to signal the vectorization algorithm
/// that it should not try to map produced operations and instead return the
/// results using the `newResults` vector making them available to the
/// vectorization algorithm for RAUW. This function is meant to be used as a
/// CustomVectorizationHook.
// NOTE(review): the signature lines, the early-exit `return` for non-yield
// ops, and the final `return` statement are not visible in this extracted
// view.
    const IRMapping &bvm, VectorizationState &state,
    LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
  auto yieldOp = dyn_cast<linalg::YieldOp>(op);
  if (!yieldOp)
  for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
    // TODO: Scan for an opportunity for reuse.
    // TODO: use a map.
    // Look up the already-vectorized value for this yielded scalar.
    Value vectorValue = bvm.lookup(output.value());
    // Write it into the matching output (init) operand; buildVectorWrite
    // returns a null Value when the write produces no result.
    Value newResult =
        buildVectorWrite(rewriter, vectorValue,
                         linalgOp.getDpsInitOperand(output.index()), state);
    if (newResult)
      newResults.push_back(newResult);
  }

}
825
826 /// Helper function to vectorize the index operations of a `linalgOp`. Return
827 /// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it
828 /// should map the produced operations. This function is meant to be used as a
829 /// CustomVectorizationHook.
831 VectorizationState &state,
832 Operation *op,
833 LinalgOp linalgOp) {
834 IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
835 if (!indexOp)
837 auto loc = indexOp.getLoc();
838 // Compute the static loop sizes of the index op.
839 ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
840 auto dim = indexOp.getDim();
841 // Compute a one-dimensional index vector for the index op dimension.
842 auto indexVectorType =
843 VectorType::get({targetShape[dim]}, rewriter.getIndexType(),
844 state.getScalableVecDims()[dim]);
  // `vector.step` materialises the 1-D sequence of loop indices for `dim`.
845 auto indexSteps = vector::StepOp::create(rewriter, loc, indexVectorType);
846 // Return the one-dimensional index vector if it lives in the trailing
847 // dimension of the iteration space since the vectorization algorithm in this
848 // case can handle the broadcast.
849 if (dim == targetShape.size() - 1)
851 // Otherwise permute the targetShape to move the index dimension last,
852 // broadcast the one-dimensional index vector to the permuted shape, and
853 // finally transpose the broadcasted index vector to undo the permutation.
854 auto permPattern =
855 llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
856 std::swap(permPattern[dim], permPattern.back());
857 auto permMap =
858 AffineMap::getPermutationMap(permPattern, linalgOp.getContext());
859
860 auto broadCastOp = vector::BroadcastOp::create(
861 rewriter, loc,
862 state.getCanonicalVecType(rewriter.getIndexType(), permMap), indexSteps);
  // Swap the trailing dim back into position `dim` to undo the permutation.
863 SmallVector<int64_t> transposition =
864 llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
865 std::swap(transposition.back(), transposition[dim]);
866 auto transposeOp =
867 vector::TransposeOp::create(rewriter, loc, broadCastOp, transposition);
869 }
870
871/// Helper function to check if the tensor.extract can be vectorized by the
872/// custom hook vectorizeTensorExtract.
873static LogicalResult
875 tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
876 if (!extractOp)
877 return failure();
878
879 if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
880 return failure();
881
882 // Check the index type, but only for non 0-d tensors (for which we do need
883 // access indices).
884 if (not extractOp.getIndices().empty()) {
885 if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
886 return failure();
887 }
888
889 if (!llvm::all_of(extractOp->getResultTypes(),
890 VectorType::isValidElementType)) {
891 return failure();
892 }
893
894 return success();
895}
896
897/// Calculates the offsets (`$index_vec`) for `vector.gather` operations
898/// generated from `tensor.extract`. The offset is calculated as follows
899/// (example using scalar values):
900///
901/// offset = extractOp.indices[0]
902/// for (i = 1; i < numIndices; i++)
903/// offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
904///
905/// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to:
906/// offset = ( ( 1 ) * 80 + 2 ) * 15 + 3
908 VectorizationState &state,
909 tensor::ExtractOp extractOp,
910 const IRMapping &bvm) {
911 // The vector of indices for GatherOp should be shaped as the output vector.
912 auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
913 auto loc = extractOp.getLoc();
914
915 Value offset = broadcastIfNeeded(
916 rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);
917
918 const size_t numIndices = extractOp.getIndices().size();
919 for (size_t i = 1; i < numIndices; i++) {
920 Value dimIdx = arith::ConstantIndexOp::create(rewriter, loc, i);
921
922 auto dimSize = broadcastIfNeeded(
923 rewriter,
924 tensor::DimOp::create(rewriter, loc, extractOp.getTensor(), dimIdx),
925 indexVecType);
926
927 offset = arith::MulIOp::create(rewriter, loc, offset, dimSize);
928
929 auto extractOpIndex = broadcastIfNeeded(
930 rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);
931
932 offset = arith::AddIOp::create(rewriter, loc, extractOpIndex, offset);
933 }
934
935 return offset;
936}
937
939
940/// Find the index of the trailing non-unit dim in linalgOp. This hook is used
941/// when checking whether `tensor.extract` Op (within a `linalg.generic` Op)
942/// represents a contiguous load operation.
943///
944/// Note that when calling this hook, it is assumed that the output vector is
945/// effectively 1D. Other cases (i.e. reading n-D vectors) should've been
946/// labelled as a gather load before entering this method.
947///
948/// Following on from the above, it is assumed that:
949/// * for statically shaped loops, when no masks are used, only one dim is !=
950/// 1 (that's what the shape of the output vector is based on).
951/// * for dynamically shaped loops, there might be more non-unit dims
952/// as the output vector type is user-specified.
953///
954/// TODO: Statically shaped loops + vector masking
955static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp) {
956 SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges();
957 assert(
958 (linalgOp.hasDynamicShape() ||
959 llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) == 1) &&
960 "For statically shaped Linalg Ops, only one "
961 "non-unit loop dim is expected");
962 assert(!loopRanges.empty() && "Empty loops, nothing to analyse.");
963
964 size_t idx = loopRanges.size() - 1;
965 for (; idx != 0; idx--)
966 if (loopRanges[idx] != 1)
967 break;
968
969 return idx;
970}
971
972/// Checks whether `val` can be used for calculating a loop invariant index.
973static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
974 VectorType resType) {
975
976 assert(((llvm::count_if(resType.getShape(),
977 [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
978 "n-D vectors are not yet supported");
979
980 // Blocks outside _this_ linalg.generic are effectively loop invariant.
981 // However, analysing block arguments for _this_ linalg.generic Op is a bit
982 // tricky. Just bail out in the latter case.
983 // TODO: We could try analysing the corresponding affine map here.
984 auto *block = linalgOp.getBlock();
985 if (isa<BlockArgument>(val))
986 return !llvm::is_contained(block->getArguments(), val);
987
988 Operation *defOp = val.getDefiningOp();
989 assert(defOp && "This is neither a block argument nor an operation result");
990
991 // IndexOp is loop invariant as long as its result remains constant across
992 // iterations. Note that for dynamic shapes, the corresponding dim will also
993 // be conservatively treated as != 1.
994 if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
995 return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
996 }
997
998 auto *ancestor = block->findAncestorOpInBlock(*defOp);
999
1000 // Values define outside `linalgOp` are loop invariant.
1001 if (!ancestor)
1002 return true;
1003
1004 // Values defined inside `linalgOp`, which are constant, are loop invariant.
1005 if (isa<arith::ConstantOp>(ancestor))
1006 return true;
1007
1008 bool result = true;
1009 for (auto op : ancestor->getOperands())
1010 result &= isLoopInvariantIdx(linalgOp, op, resType);
1011
1012 return result;
1013}
1014
1015/// Check whether `val` could be used for calculating the trailing index for a
1016/// contiguous load operation.
1017///
1018/// There are currently 3 types of values that are allowed here:
1019/// 1. loop-invariant values,
1020/// 2. values that increment by 1 with every loop iteration,
1021/// 3. results of basic arithmetic operations (linear and continuous)
1022/// involving 1., 2. and 3.
1023/// This method returns True if indeed only such values are used in calculating
1024/// `val.`
1025///
1026/// Additionally, the trailing index for a contiguous load operation should
1027/// increment by 1 with every loop iteration, i.e. be based on:
1028/// * `linalg.index <dim>` ,
1029/// where <dim> is the trailing non-unit dim of the iteration space (this way,
1030/// `linalg.index <dim>` increments by 1 with every loop iteration).
1031/// `foundIndexOp` is updated to `true` when such Op is found.
1032static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
1033 bool &foundIndexOp, VectorType resType) {
1034
1035 assert(((llvm::count_if(resType.getShape(),
1036 [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
1037 "n-D vectors are not yet supported");
1038
1039 // Blocks outside _this_ linalg.generic are effectively loop invariant.
1040 // However, analysing block arguments for _this_ linalg.generic Op is a bit
1041 // tricky. Just bail out in the latter case.
1042 // TODO: We could try analysing the corresponding affine map here.
1043 auto *block = linalgOp.getBlock();
1044 if (isa<BlockArgument>(val))
1045 return !llvm::is_contained(block->getArguments(), val);
1046
1047 Operation *defOp = val.getDefiningOp();
1048 assert(defOp && "This is neither a block argument nor an operation result");
1049
1050 if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
1051 auto loopDimThatIncrementsByOne = getTrailingNonUnitLoopDimIdx(linalgOp);
1052
1053 foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
1054 return true;
1055 }
1056
1057 auto *ancestor = block->findAncestorOpInBlock(*defOp);
1058
1059 if (!ancestor)
1060 return false;
1061
1062 // Conservatively reject Ops that could lead to indices with stride other
1063 // than 1.
1064 if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
1065 return false;
1066
1067 bool result = false;
1068 for (auto op : ancestor->getOperands())
1069 result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType);
1070
1071 return result;
1072}
1073
1074 /// Infer the memory access pattern for the input ExtractOp
1075 ///
1076 /// Based on the ExtractOp result shape and the access indices, decides whether
1077 /// this Op corresponds to a contiguous load (including a broadcast of a scalar)
1078 /// or a gather load. When analysing the ExtractOp indices (to identify
1079 /// contiguous loads), this method looks for "loop" invariant indices (e.g.
1080 /// block arguments) and indices that change linearly (e.g. via `linalg.index`
1081 /// Op).
1082 ///
1083 /// Note that it is always safe to use gather load operations for contiguous
1084 /// loads (albeit slow), but not vice-versa. When in doubt, bail out and assume
1085 /// that `extractOp` is a gather load.
1087 getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
1088 LinalgOp &linalgOp, VectorType resType) {
1089
  // Despite the name, this is the full ShapedType of the source tensor.
1090 auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
1091
1092 // 0. Is this a 0-D vector? If yes then this is a scalar broadcast.
1093 if (inputShape.getShape().empty())
1095
1096 // 0a. Is the result a 0-D vector? If yes, there are no iteration dimensions
1097 // so the tensor.extract is a single scalar load regardless of the index.
1098 if (resType.getRank() == 0)
1100
1101 // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false
1102 // otherwise.
1103 bool isOutput1DVector =
1104 (llvm::count_if(resType.getShape(),
1105 [](int64_t dimSize) { return dimSize > 1; }) == 1);
1106 // 1. Assume that it's a gather load when reading non-1D vector.
1107 if (!isOutput1DVector)
1109
1110 bool leadingIdxsLoopInvariant = true;
1111
1112 // 2. Analyze the leading indices of `extractOp`.
1113 // Look at the way each index is calculated and decide whether it is suitable
1114 // for a contiguous load, i.e. whether it's loop invariant. If not, it's a
1115 // gather load.
1116 auto indices = extractOp.getIndices();
1117 auto leadIndices = indices.drop_back(1);
1118
1119 for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
  // Indexing into a unit dim can only ever produce 0 - skip the analysis.
1120 if (inputShape.getShape()[i] == 1)
1121 continue;
1122
1123 leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType);
1124 }
1125
1126 if (!leadingIdxsLoopInvariant) {
1127 LDBG() << "Found gather load: " << extractOp;
1129 }
1130
1131 // 3. Analyze the trailing index for `extractOp`.
1132 // At this point we know that the leading indices are loop invariant. This
1133 // means that it is potentially a scalar or a contiguous load. We can decide
1134 // based on the trailing idx.
1135 auto extractOpTrailingIdx = indices.back();
1136
1137 // 3a. Scalar broadcast load
1138 // If the trailing index is loop invariant then this is a scalar load.
1139 if (leadingIdxsLoopInvariant &&
1140 isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) {
1141 LDBG() << "Found scalar broadcast load: " << extractOp;
1142
1144 }
1145
1146 // 3b. Contiguous loads
1147 // The trailing `extractOp` index should increment with every loop iteration.
1148 // This effectively means that it must be based on the trailing loop index.
1149 // This is what the following bool captures.
1150 bool foundIndexOp = false;
1151 bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
1152 foundIndexOp, resType);
1153 // TODO: Support generating contiguous loads for column vectors - that will
1154 // require adding a permutation map to transfer_read Ops.
1155 bool isRowVector = resType.getShape().back() != 1;
1156 isContiguousLoad &= (foundIndexOp && isRowVector);
1157
1158 if (isContiguousLoad) {
1159 LDBG() << "Found contigous load: " << extractOp;
1161 }
1162
1163 // 4. Fallback case - gather load.
1164 LDBG() << "Found gather load: " << extractOp;
1166 }
1167
1168 /// Helper function to vectorize the tensor.extract operations. Returns
1169 /// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it
1170 /// should map the produced operations. This function is meant to be used as a
1171 /// CustomVectorizationHook.
1172 static VectorizationHookResult
1173 vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
1174 Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
1175 tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
1176 if (!extractOp)
1178 auto loc = extractOp.getLoc();
1179
1180 // Compute the static loop sizes of the extract op.
1181 auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
  // All-true mask and zero pass-through vector for the gather path below.
1182 auto maskConstantOp = arith::ConstantOp::create(
1183 rewriter, loc,
1184 DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()),
1185 /*value=*/true));
1186 auto passThruConstantOp = arith::ConstantOp::create(
1187 rewriter, loc, rewriter.getZeroAttr(resultType));
1188
1189 // Base indices are currently set to 0. We will need to re-visit if more
1190 // generic scenarios are to be supported.
1191 SmallVector<Value> baseIndices(
1192 extractOp.getIndices().size(),
1193 arith::ConstantIndexOp::create(rewriter, loc, 0));
1194
1195 VectorMemoryAccessKind memAccessKind =
1196 getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType);
1197
1198 // 1. Handle gather access
1199 if (memAccessKind == VectorMemoryAccessKind::Gather) {
1200 Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);
1201
1202 // Generate the gather load
1203 Operation *gatherOp = vector::GatherOp::create(
1204 rewriter, loc, resultType, extractOp.getTensor(), baseIndices, offset,
1205 maskConstantOp, passThruConstantOp);
1206 gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
1207
1208 LDBG() << "Vectorised as gather load: " << extractOp;
1210 }
1211
1212 // 2. Handle:
1213 // a. scalar loads + broadcast,
1214 // b. contiguous loads.
1215 // Both cases use vector.transfer_read.
1216
1217 // Collect indices for `vector.transfer_read`. At this point, the indices will
1218 // either be scalars or would have been broadcast to vectors matching the
1219 // result type. For indices that are vectors, there are two options:
1220 // * for non-trailing indices, all elements are identical (contiguous
1221 // loads are identified by looking for non-trailing indices that are
1222 // invariant with respect to the corresponding linalg.generic), or
1223 // * for trailing indices, the index vector will contain values with stride
1224 // one, but for `vector.transfer_read` only the first (i.e. 0th) index is
1225 // needed.
1226 // This means that
1227 // * for scalar indices - just re-use it,
1228 // * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
1229 // (0th) element and use that.
1230 SmallVector<Value> transferReadIdxs;
1231 for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
1232 Value idx = bvm.lookup(extractOp.getIndices()[i]);
1233 if (idx.getType().isIndex()) {
1234 transferReadIdxs.push_back(idx);
1235 continue;
1236 }
1237
  // Flatten the vectorized index to 1-D, then take its 0th element.
1238 auto indexAs1dVector = vector::ShapeCastOp::create(
1239 rewriter, loc,
1240 VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
1241 resultType.getScalableDims().back()),
1242 idx);
1243 transferReadIdxs.push_back(
1244 vector::ExtractOp::create(rewriter, loc, indexAs1dVector, 0));
1245 }
1246
1247 // `tensor.extract_element` is always in-bounds, hence the following holds.
1248 auto dstRank = resultType.getRank();
1249 auto srcRank = extractOp.getTensor().getType().getRank();
1250 SmallVector<bool> inBounds(dstRank, true);
1251
1252 // 2a. Handle scalar broadcast access.
1253 if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
1254 MLIRContext *ctx = rewriter.getContext();
  // All-zero permutation map: read the same (scalar) element for every lane.
1255 SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
1256 auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);
1257
1258 auto transferReadOp = vector::TransferReadOp::create(
1259 rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs,
1260 /*padding=*/std::nullopt, permutationMap, inBounds);
1261
1262 Operation *readOrMaskedReadOp = transferReadOp;
1263 if (dstRank > 0) {
1264 // Mask this broadcasting xfer_read here rather than relying on the
1265 // generic path (the generic path assumes identity masking map, which
1266 // wouldn't be valid here).
1267 SmallVector<int64_t> readMaskShape = {1};
1268 auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
1269 auto allTrue = vector::ConstantMaskOp::create(
1270 rewriter, loc, readMaskType, vector::ConstantMaskKind::AllTrue);
1271 readOrMaskedReadOp =
1272 mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);
1273 }
1274
1275 LDBG() << "Vectorised as scalar broadcast load: " << extractOp;
1277 readOrMaskedReadOp};
1278 }
1279
1280 // 2b. Handle contiguous access.
1281 auto permutationMap = AffineMap::getMinorIdentityMap(
1282 srcRank, std::min(dstRank, srcRank), rewriter.getContext());
1283
1284 int32_t rankDiff = dstRank - srcRank;
1285 // When dstRank > srcRank, broadcast the source tensor to the unitary leading
1286 // dims so that the ranks match. This is done by extending the map with 0s.
1287 // For example, for dstRank = 3, srcRank = 2, the following map created
1288 // above:
1289 // (d0, d1) --> (d0, d1)
1290 // is extended as:
1291 // (d0, d1) --> (0, d0, d1)
1292 while (rankDiff > 0) {
1293 permutationMap = permutationMap.insertResult(
1294 mlir::getAffineConstantExpr(0, rewriter.getContext()), 0);
1295 rankDiff--;
1296 }
1297
1298 auto transferReadOp = vector::TransferReadOp::create(
1299 rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs,
1300 /*padding=*/std::nullopt, permutationMap, inBounds);
1301
1302 LDBG() << "Vectorised as contiguous load: " << extractOp;
1304 transferReadOp};
1305 }
1306
1307/// Emit reduction operations if the shapes of the value to reduce is different
1308/// that the result shape.
1309// Note: this is a true builder that notifies the OpBuilder listener.
1310// TODO: Consider moving as a static helper on the ReduceOp.
1311static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
1312 Value reduceValue, Value initialValue,
1313 const IRMapping &bvm) {
1314 Value reduceVec = bvm.lookup(reduceValue);
1315 Value outputVec = bvm.lookup(initialValue);
1316 auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
1317 auto outputType = dyn_cast<VectorType>(outputVec.getType());
1318 // Reduce only if needed as the value may already have been reduce for
1319 // contraction vectorization.
1320 if (!reduceType ||
1321 (outputType && reduceType.getShape() == outputType.getShape()))
1322 return nullptr;
1323 SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
1324 return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
1325}
1326
1327 /// Generic vectorization for a single operation `op`, given already vectorized
1328 /// operands carried by `bvm`. Vectorization occurs as follows:
1329 /// 1. Try to apply any of the `customVectorizationHooks` and return its
1330 /// result on success.
1331 /// 2. Clone any constant in the current scope without vectorization: each
1332 /// consumer of the constant will later determine the shape to which the
1333 /// constant needs to be broadcast to.
1334 /// 3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
1335 /// of the `customVectorizationHooks` to cover such cases.
1336 /// 4. Clone `op` in vector form to a vector of shape prescribed by the first
1337 /// operand of maximal rank. Other operands have smaller rank and are
1338 /// broadcast accordingly. It is assumed this broadcast is always legal,
1339 /// otherwise, it means one of the `customVectorizationHooks` is incorrect.
1340 ///
1341 /// This function assumes all operands of `op` have been vectorized and are in
1342 /// the `bvm` mapping. As a consequence, this function is meant to be called on
1343 /// a topologically-sorted list of ops.
1344 /// This function does not update `bvm` but returns a VectorizationHookStatus
1345 /// that instructs the caller what `bvm` update needs to occur.
1346 static VectorizationHookResult
1347 vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
1348 LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
1349 ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
1350 LDBG() << "vectorize op " << *op;
1351
1352 // 1. Try to apply any CustomVectorizationHook.
1353 if (!customVectorizationHooks.empty()) {
1354 for (auto &customFunc : customVectorizationHooks) {
1355 VectorizationHookResult result = customFunc(op, bvm);
1357 continue;
1358 return result;
1359 }
1360 }
1361
1362 // 2. Constant ops don't get vectorized but rather broadcasted at their users.
1363 // Clone so that the constant is not confined to the linalgOp block .
1364 if (isa<arith::ConstantOp, func::ConstantOp>(op))
1366 rewriter.clone(*op)};
1367
1368 // 3. Only ElementwiseMappable are allowed in the generic vectorization.
1371
1372 // 4. Check if the operation is a reduction.
1373 SmallVector<std::pair<Value, Value>> reductionOperands;
1374 for (Value operand : op->getOperands()) {
  // Reduction candidates are exactly the block arguments that correspond to
  // DPS init (output) operands of `linalgOp`.
1375 auto blockArg = dyn_cast<BlockArgument>(operand);
1376 if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
1377 blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
1378 continue;
1379 SmallVector<Operation *> reductionOps;
1380 Value reduceValue = matchReduction(
1381 linalgOp.getRegionOutputArgs(),
1382 blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
1383 if (!reduceValue)
1384 continue;
1385 reductionOperands.push_back(std::make_pair(reduceValue, operand));
1386 }
  // Note: at most one reduction operand pair is supported (asserted below).
1387 if (!reductionOperands.empty()) {
1388 assert(reductionOperands.size() == 1);
1389 Operation *reduceOp =
1390 reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
1391 reductionOperands[0].second, bvm);
1392 if (reduceOp)
1394 }
1395
1396 // 5. Generic vectorization path for ElementwiseMappable ops.
1397 // a. Get the first max ranked shape.
1398 VectorType firstMaxRankedType;
1399 for (Value operand : op->getOperands()) {
1400 auto vecOperand = bvm.lookup(operand);
1401 assert(vecOperand && "Vector operand couldn't be found");
1402
1403 auto vecType = dyn_cast<VectorType>(vecOperand.getType());
1404 if (vecType && (!firstMaxRankedType ||
1405 firstMaxRankedType.getRank() < vecType.getRank()))
1406 firstMaxRankedType = vecType;
1407 }
1408 // b. Broadcast each op if needed.
1409 SmallVector<Value> vecOperands;
1410 for (Value scalarOperand : op->getOperands()) {
1411 Value vecOperand = bvm.lookup(scalarOperand);
1412 assert(vecOperand && "Vector operand couldn't be found");
1413
1414 if (firstMaxRankedType) {
1415 auto vecType = VectorType::get(firstMaxRankedType.getShape(),
1416 getElementTypeOrSelf(vecOperand.getType()),
1417 firstMaxRankedType.getScalableDims());
1418 vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
1419 } else {
1420 vecOperands.push_back(vecOperand);
1421 }
1422 }
1423 // c. for elementwise, the result is the vector with the firstMaxRankedShape
1424 SmallVector<Type> resultTypes;
1425 for (Type resultType : op->getResultTypes()) {
1426 resultTypes.push_back(
1427 firstMaxRankedType
1428 ? VectorType::get(firstMaxRankedType.getShape(), resultType,
1429 firstMaxRankedType.getScalableDims())
1430 : resultType);
1431 }
1432 // d. Build and return the new op.
1435 rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
1436 resultTypes, op->getAttrs())};
1437 }
1438
1439 /// Generic vectorization function that rewrites the body of a `linalgOp` into
1440 /// vector form. Generic vectorization proceeds as follows:
1441 /// 1. Verify the `linalgOp` has one non-empty region.
1442 /// 2. Values defined above the region are mapped to themselves and will be
1443 /// broadcasted on a per-need basis by their consumers.
1444 /// 3. Each region argument is vectorized into a vector.transfer_read (or 0-d
1445 /// load).
1446 /// TODO: Reuse opportunities for RAR dependencies.
1447 /// 4a. Register CustomVectorizationHook for YieldOp to capture the results.
1448 /// 4b. Register CustomVectorizationHook for IndexOp to access the
1449 /// iteration indices.
1450 /// 5. Iteratively call vectorizeOneOp on the region operations.
1451 ///
1452 /// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is
1453 /// performed to the maximal common vector size implied by the `linalgOp`
1454 /// iteration space. This eager broadcasting is introduced in the
1455 /// permutation_map of the vector.transfer_read operations. The eager
1456 /// broadcasting makes it trivial to determine where broadcast, transposes and
1457 /// reductions should occur, without any bookkeeping. The tradeoff is that, in
1458 /// the absence of good canonicalizations, the amount of work increases.
1459 /// This is not deemed a problem as we expect canonicalizations and foldings to
1460 /// aggressively clean up the useless work.
1461 static LogicalResult
1462 vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
1463 LinalgOp linalgOp,
1464 SmallVectorImpl<Value> &newResults) {
  // NOTE(review): the trailing "/n" in this debug string looks like a typo for
  // "\n" -- confirm and fix separately (string left untouched here).
1465 LDBG() << "Vectorizing operation as linalg generic/n";
1466 Block *block = linalgOp.getBlock();
1467
1468 // 2. Values defined above the region can only be broadcast for now. Make them
1469 // map to themselves.
1470 IRMapping bvm;
1471 SetVector<Value> valuesSet;
1472 mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
1473 bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
1474
1475 if (linalgOp.getNumDpsInits() == 0)
1476 return failure();
1477
1478 // 3. Turn all BBArgs into vector.transfer_read / load.
1479 Location loc = linalgOp.getLoc();
1480 Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
1481 for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1482 BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
  // Scalar operands are used directly; no read is required.
1483 if (linalgOp.isScalar(opOperand)) {
1484 bvm.map(bbarg, opOperand->get());
1485 continue;
1486 }
1487
1488 // 3.a. Convert the indexing map for this input/output to a transfer read
1489 // permutation map and masking map.
1490 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
1491
1492 AffineMap readMap;
1493 VectorType readType;
1494 Type elemType = getElementTypeOrSelf(opOperand->get());
1495 if (linalgOp.isDpsInput(opOperand)) {
1496 // 3.a.i. For input reads we use the canonical vector shape.
1497 readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
1498 readType = state.getCanonicalVecType(elemType);
1499 } else {
1500 // 3.a.ii. For output reads (iteration-carried dependence, e.g.,
1501 // reductions), the vector shape is computed by mapping the canonical
1502 // vector shape to the output domain and back to the canonical domain.
1503 readMap = inversePermutation(reindexIndexingMap(indexingMap));
1504 readType =
1505 state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
1506 }
1507
1508 SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
1509
1510 Operation *read = vector::TransferReadOp::create(
1511 rewriter, loc, readType, opOperand->get(), indices,
1512 /*padding=*/std::nullopt, readMap);
1513 read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
1514 Value readValue = read->getResult(0);
1515
1516 // 3.b. If masked, set in-bounds to true. Masking guarantees that the access
1517 // will be in-bounds.
1518 if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
1519 SmallVector<bool> inBounds(readType.getRank(), true);
1520 cast<vector::TransferReadOp>(maskOp.getMaskableOp())
1521 .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
1522 }
1523
1524 // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
1525 // TODO: remove this.
1526 if (readType.getRank() == 0)
1527 readValue = vector::ExtractOp::create(rewriter, loc, readValue,
1529
1530 LDBG() << "New vectorized bbarg(" << bbarg.getArgNumber()
1531 << "): " << readValue;
  // Map both the bbarg and the underlying operand to the vectorized read.
1532 bvm.map(bbarg, readValue);
1533 bvm.map(opOperand->get(), readValue);
1534 }
1535
1537 // 4a. Register CustomVectorizationHook for yieldOp.
1538 CustomVectorizationHook vectorizeYield =
1539 [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1540 return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
1541 };
1542 hooks.push_back(vectorizeYield);
1543
1544 // 4b. Register CustomVectorizationHook for indexOp.
1545 CustomVectorizationHook vectorizeIndex =
1546 [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1547 return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
1548 };
1549 hooks.push_back(vectorizeIndex);
1550
1551 // 4c. Register CustomVectorizationHook for extractOp.
1552 CustomVectorizationHook vectorizeExtract =
1553 [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1554 return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
1555 };
1556 hooks.push_back(vectorizeExtract);
1557
1558 // 5. Iteratively call `vectorizeOneOp` to each op in the slice.
1559 for (Operation &op : block->getOperations()) {
1561 vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
1563 LDBG() << "failed to vectorize: " << op;
1564 return failure();
1565 }
1566 if (result.status == VectorizationHookStatus::NewOp) {
  // Newly created ops may themselves need masking before being recorded.
1567 Operation *maybeMaskedOp =
1568 state.maskOperation(rewriter, result.newOp, linalgOp);
1569 LDBG() << "New vector op: " << *maybeMaskedOp;
1570 bvm.map(op.getResults(), maybeMaskedOp->getResults());
1571 }
1572 }
1573
1574 return success();
1575 }
1576
1577/// Determines whether a mask for xfer_write is trivially "all true"
1578///
1579/// Given all the inputs required to generate a mask (mask sizes and shapes),
1580/// and an xfer_write operation (write indices and the destination tensor
1581/// shape), determines whether the corresponding mask would be trivially
1582/// foldable (i.e., trivially "all true").
1583///
1584/// Use this method to avoid generating spurious masks and relaying on
1585/// vectorization post-processing to remove them.
1586///
1587/// Pre-conditions for a mask to be trivially foldable:
1588/// * All involved shapes (mask + destination tensor) are static.
1589/// * All write indices are constant.
1590/// * All mask sizes are constant (including `arith.constant`).
1591///
1592/// If the pre-conditions are met, the method checks for each destination
1593/// dimension `d`:
1594/// (1) destDimSize[rankDiff + d] <= maskShape[d]
1595/// (2) destDimSize[rankDiff + d] <= writeIndex[d] + maskSize[d]
1596///
1597/// rankDiff = rank(dest) - rank(mask).
1598///
1599/// This method takes a conservative view: it may return false even if the mask
1600/// is technically foldable.
1601///
1602/// EXAMPLE 1 (trivially foldable, all shapes match, mask sizes match the shape
1603/// of the dest tensor):
1604/// %c0 = arith.constant 0 : index
1605/// %mask = vector.create_mask 5, 1
1606/// vector.mask %mask {
1607/// vector.transfer_write %vecToStore_1, %dest{[%c0, %c0]
1608/// {in_bounds = [true, true]}
1609/// : vector<5x1xi32>, tensor<5x1xi32>
1610/// }
1611///
1612/// EXAMPLE 2 (not trivially foldable - vector shape exceeds the tensor shape,
1613/// mask is required to avoid out-of-bounds write):
1614/// %c0 = arith.constant 0 : index
1615/// %mask = vector.create_mask 5, 1
1616/// vector.mask %mask {
1617/// vector.transfer_write %vecToStore_2, %dest[%c0, %c0]
1618/// {in_bounds = [true, true]}
1619/// : vector<8x1xi32>, tensor<5x1xi32>
1620/// }
1621///
1622/// TODO: Re-use in createReadOrMaskedRead
1624 SmallVector<Value> &writeIdxs,
1625 ArrayRef<int64_t> destShape,
1626 ArrayRef<int64_t> maskShape) {
1627 // Masking is unavoidable in the case of dynamic tensors.
1628 if (ShapedType::isDynamicShape(destShape))
1629 return false;
1630
1631 // Collect all constant mask sizes.
1632 SmallVector<int64_t, 4> cstMaskSizes;
1633 for (auto [i, dimSize] : llvm::enumerate(maskSizes)) {
1634 if (auto intSize = getConstantIntValue(dimSize)) {
1635 cstMaskSizes.push_back(*intSize);
1636 }
1637 }
1638
1639 // If any of the mask sizes is non-constant, bail out.
1640 if (cstMaskSizes.size() != maskShape.size())
1641 return false;
1642
1643 // Collect all constant write indices.
1644 SmallVector<int64_t, 4> cstWriteIdxs;
1645 for (auto [i, idx] : llvm::enumerate(writeIdxs)) {
1646 APSInt intVal;
1647 if (matchPattern(idx, m_ConstantInt(&intVal))) {
1648 cstWriteIdxs.push_back(intVal.getSExtValue());
1649 }
1650 }
1651
1652 // If any of the write indices is non-constant, bail out.
1653 if (cstWriteIdxs.size() != destShape.size())
1654 return false;
1655
1656 // Go over all destination dims and check (1) and (2). Take into account that:
1657 // * The number of mask sizes will match the rank of the vector to store.
1658 // This could be lower than the rank of the destination tensor.
1659 // * Mask sizes could be larger than the corresponding mask shape (hence
1660 // `clamp`).
1661 // TODO: The 2nd item should be rejected by the verifier.
1662 int64_t rankDiff = destShape.size() - cstMaskSizes.size();
1663 for (auto [i, idx] : llvm::enumerate(cstMaskSizes)) {
1664 if (/*(1)*/ maskShape[i] > destShape[rankDiff + i] ||
1665 /*(2)*/ destShape[rankDiff + i] <
1666 (std::clamp(cstMaskSizes[i], int64_t(0), maskShape[i]) +
1667 cstWriteIdxs[i]))
1668 return false;
1669 }
1670
1671 return true;
1672}
1673
1674/// Creates an optionally masked TransferWriteOp
1675///
1676/// Generates the following operation:
1677/// %res = vector.transfer_write %vecToStore into %dest
1678///
1679/// If shape(vecToStore) != shape(dest), masking is used to ensure correctness:
1680///
1681/// %mask = vector.create_mask(%destShape) : %vecToStoreShape
1682/// %res = vector.mask %mask {
1683/// vector.transfer_write %vecToStore into %dest
1684/// }
1685///
1686/// The mask shape is identical to `vecToStore` (with the element type ==
1687/// i1), and the mask values are based on the shape of the `dest` tensor.
1688///
1689/// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
1690/// is used instead of masking:
1691///
1692/// %write = vector.transfer_write %vecToStore into %dest
1693/// in_bounds_flags = (...)
1694/// %res = vector.transfer_write %input into %dest
1695/// {in_bounds = in_bounds_flags}
1696///
1697/// Finally, `writeIndices` specifies the offsets to use. If empty, all indices
1698/// are set to 0.
1699static Operation *
1701 Value dest, SmallVector<Value> writeIndices = {},
1702 bool useInBoundsInsteadOfMasking = false) {
1703
1704 ShapedType destType = cast<ShapedType>(dest.getType());
1705 int64_t destRank = destType.getRank();
1706 auto destShape = destType.getShape();
1707
1708 VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
1709 int64_t vecToStoreRank = vecToStoreType.getRank();
1710 auto vecToStoreShape = vecToStoreType.getShape();
1711
1712 // Compute the in_bounds attribute
1713 SmallVector<bool> inBoundsVal(vecToStoreRank, true);
1714 if (useInBoundsInsteadOfMasking) {
1715 // Update the inBounds attribute.
1716 // FIXME: This computation is too weak - it ignores the write indices.
1717 for (unsigned i = 0; i < vecToStoreRank; i++)
1718 inBoundsVal[i] =
1719 (destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) &&
1720 ShapedType::isStatic(destShape[destRank - vecToStoreRank + i]);
1721 }
1722
1723 // If missing, initialize the write indices to 0.
1724 bool useDefaultWriteIdxs = writeIndices.empty();
1725 assert((useDefaultWriteIdxs ||
1726 writeIndices.size() == static_cast<size_t>(destRank)) &&
1727 "Invalid number of write indices!");
1728 if (writeIndices.empty()) {
1729 auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
1730 writeIndices.assign(destRank, zero);
1731 }
1732
1733 // Generate the xfer_write Op
1734 Operation *write = vector::TransferWriteOp::create(builder, loc,
1735 /*vector=*/vecToStore,
1736 /*source=*/dest,
1737 /*indices=*/writeIndices,
1738 /*inBounds=*/inBoundsVal);
1739
1740 // If masking is disabled, exit.
1741 if (useInBoundsInsteadOfMasking)
1742 return write;
1743
1744 // Check if masking is needed. If not, exit.
1745 if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
1746 return write;
1747
1748 // Compute the mask and mask the write Op.
1749 auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type(),
1750 vecToStoreType.getScalableDims());
1751
1752 SmallVector<OpFoldResult> destSizes =
1753 isa<MemRefType>(dest.getType())
1754 ? memref::getMixedSizes(builder, loc, dest)
1755 : tensor::getMixedSizes(builder, loc, dest);
1756
1757 // Compute sizes for write-mask
1758 SmallVector<OpFoldResult> maskSizes;
1759 if (useDefaultWriteIdxs) {
1760 maskSizes = SmallVector<OpFoldResult>(destSizes.end() - vecToStoreRank,
1761 destSizes.end());
1762 } else {
1763 size_t diff = destShape.size() - vecToStoreRank;
1764 for (int64_t idx = 0; idx < vecToStoreRank; idx++) {
1765 auto value =
1766 getValueOrCreateConstantIndexOp(builder, loc, destSizes[diff + idx]);
1767 auto neg =
1768 builder.createOrFold<arith::SubIOp>(loc, value, writeIndices[idx]);
1769 maskSizes.push_back(OpFoldResult(neg));
1770 }
1771 }
1772
1773 if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
1774 vecToStoreShape))
1775 return write;
1776
1777 Value maskForWrite =
1778 builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
1779 return mlir::vector::maskOperation(builder, write, maskForWrite);
1780}
1781
1782/// Given the re-associations, "collapses" the input Vector type
1783///
1784/// This is similar to CollapseShapeOp::inferCollapsedType with two notable
1785/// differences:
1786/// * We can safely assume that there are no dynamic sizes.
1787/// * Scalable flags are updated alongside regular dims.
1788///
1789/// When collapsing scalable flags, conservatively avoids cases with two
1790/// scalable dims. We could re-visit this in the future.
1791///
1792/// EXAMPLE:
1793/// type = vector<4x16x[8]x16xf32>
1794/// reassociation = [(d0, d1, d2, d3) -> (d0, d1),
1795/// (d0, d1, d2, d3) -> (d2, d3)]
1796/// Result:
1797/// vector<64x[128]xf32>
1798static VectorType getCollapsedVecType(VectorType type,
1799 ArrayRef<AffineMap> reassociation) {
1800 assert(type.getNumScalableDims() < 2 &&
1801 "Collapsing more than 1 scalable dim is not supported ATM");
1802
1803 // Use the fact that reassociation is valid to simplify the logic: only use
1804 // each map's rank.
1805 assert(isReassociationValid(reassociation) && "invalid reassociation");
1806
1807 auto shape = type.getShape();
1808 auto scalableFlags = type.getScalableDims();
1809 SmallVector<int64_t> newShape;
1810 SmallVector<bool> newScalableFlags;
1811
1812 unsigned currentDim = 0;
1813 for (AffineMap m : reassociation) {
1814 unsigned dim = m.getNumResults();
1815 int64_t size = 1;
1816 bool flag = false;
1817 for (unsigned d = 0; d < dim; ++d) {
1818 size *= shape[currentDim + d];
1819 flag |= scalableFlags[currentDim + d];
1820 }
1821 newShape.push_back(size);
1822 newScalableFlags.push_back(flag);
1823 currentDim += dim;
1824 }
1825
1826 return VectorType::get(newShape, type.getElementType(), newScalableFlags);
1827}
1828
1829/// Vectorize `linalg.pack` as:
1830/// * xfer_read -> shape_cast -> transpose -> xfer_write
1831///
1832/// The input-vector-sizes specify the _write_ vector sizes (i.e. the vector
1833/// sizes for the xfer_write operation). This is sufficient to infer the other
1834/// vector sizes required here.
1835///
1836/// If the vector sizes are not provided:
1837/// * the vector sizes are determined from the destination tensor static shape.
1838/// * the inBounds attribute is used instead of masking.
1839///
1840/// EXAMPLE (no vector sizes):
1841/// ```
///   %pack = linalg.pack %src
///     inner_dims_pos = [2, 1]
///     inner_tiles = [16, 2]
///     into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
/// ```
/// is vectorized as:
1848/// ```
1849/// %read = vector.transfer_read %src
///      : tensor<32x8x16xf32>, vector<32x8x16xf32>
1851/// %sc = vector.shape_cast %read
1852/// : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
1853/// %tr = vector.transpose %sc, [0, 1, 3, 4, 2]
1854/// : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
1855/// %write = vector.transfer_write %tr into %dest
1856/// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
1857/// ```
1858static LogicalResult
1859vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
1860 ArrayRef<int64_t> inputVectorSizes,
1861 SmallVectorImpl<Value> &newResults) {
1862 if (!inputVectorSizes.empty()) {
1863 assert(inputVectorSizes.size() == packOp.getDestRank() &&
1864 "Invalid number of input vector sizes!");
1865 }
1866
1867 // TODO: Introduce a parent class that will handle the insertion point update.
1868 OpBuilder::InsertionGuard g(rewriter);
1869 rewriter.setInsertionPoint(packOp);
1870
1871 Location loc = packOp.getLoc();
1872 std::optional<Value> padValue = packOp.getPaddingValue()
1873 ? std::optional(packOp.getPaddingValue())
1874 : std::nullopt;
1875
1876 SmallVector<int64_t> destShape =
1877 SmallVector<int64_t>(packOp.getDestType().getShape());
1878
1879 // This is just a convenience alias to clearly communicate that the input
1880 // vector sizes determine the _write_ sizes.
1881 ArrayRef<int64_t> &writeVectorSizes = inputVectorSizes;
1882
1883 // In the absence of input-vector-sizes, use the _static_ input tensor shape.
1884 // In addition, use the inBounds attribute instead of masking.
1885 bool useInBoundsInsteadOfMasking = false;
1886 if (writeVectorSizes.empty()) {
1887 if (ShapedType::isDynamicShape(destShape))
1888 return rewriter.notifyMatchFailure(packOp,
1889 "unable to infer vector sizes");
1890
1891 writeVectorSizes = destShape;
1892 useInBoundsInsteadOfMasking = true;
1893 }
1894
1895 // Compute pre-transpose-write-vector-type, i.e. the write vector type
1896 // _before_ the transposition (i.e. before dimension permutation). This is
1897 // done by inverting the permutation/transposition that's part of the Pack
1898 // operation. This type is required to:
1899 // 1) compute the read vector type for masked-read below, and
1900 // 2) generate shape-cast Op below that expands the read vector type.
1901 PackingMetadata packMetadata;
1902 SmallVector<int64_t> preTransposeWriteVecSizses(writeVectorSizes);
1903 auto destInvPermutation = getPackInverseDestPerm(packOp, packMetadata);
1904 applyPermutationToVector(preTransposeWriteVecSizses, destInvPermutation);
1905 auto preTransposeWriteVecType =
1906 VectorType::get(preTransposeWriteVecSizses,
1907 packOp.getResult().getType().getElementType());
1908
  // Compute vector type for the _read_ operation. This is simply
1910 // pre-transpose-write-vector-type with the dimensions collapsed
1911 // as per the Pack operation.
1912 VectorType readVecType = getCollapsedVecType(
1913 preTransposeWriteVecType,
1915 rewriter.getContext(), packMetadata.reassociations)));
1916
1917 // Create masked TransferReadOp.
1918 auto maskedRead = vector::createReadOrMaskedRead(
1919 rewriter, loc, packOp.getSource(), readVecType, padValue,
1920 useInBoundsInsteadOfMasking);
1921
1922 // Create ShapeCastOp.
1923 auto shapeCastOp = vector::ShapeCastOp::create(
1924 rewriter, loc, preTransposeWriteVecType, maskedRead);
1925
1926 // Create TransposeOp.
1927 auto destPermutation = invertPermutationVector(destInvPermutation);
1928 auto transposeOp = vector::TransposeOp::create(
1929 rewriter, loc, shapeCastOp.getResult(), destPermutation);
1930
1931 // Create TransferWriteOp.
1932 Operation *write = createWriteOrMaskedWrite(
1933 rewriter, loc, transposeOp.getResult(), packOp.getDest());
1934 newResults.push_back(write->getResult(0));
1935 return success();
1936}
1937
1938/// Vectorize `linalg.unpack` as:
1939/// * xfer_read -> vector.transpose -> vector.shape_cast -> xfer_write
1940///
1941/// The input-vector-sizes specify the _read_ vector sizes (i.e. the vector
1942/// sizes for the xfer_read operation). This is sufficient to infer the other
1943/// vector sizes required here.
1944///
1945/// If the vector sizes are not provided:
1946/// * the vector sizes are determined from the input tensor static shape.
1947/// * the inBounds attribute is used instead of masking.
1948///
1949/// EXAMPLE (no vector sizes):
1950/// ```
1951/// %unpack = linalg.unpack %src
1952/// inner_dims_pos = [0, 1]
1953/// inner_tiles = [8, 8]
1954/// into %dest : tensor<1x1x8x8xf32> -> tensor<8x8xf32>
1955/// ```
1956/// is vectorized as:
1957/// ```
1958/// %read = vector.transfer_read %src
1959/// : tensor<1x1x8x8xf32>, vector<1x1x8x8xf32>
1960/// %tr = vector.transpose %read, [0, 2, 1, 3]
1961/// : vector<1x1x8x8xf32> to vector<1x8x1x8xf32>
1962/// %sc = vector.shape_cast %tr
1963/// : vector<1x8x1x8xf32> to vector<8x8xf32>
1964/// %vector = vector.transfer_write %sc into %dest
1965/// : vector<8x8xf32>, tensor<8x8xf32>
1966/// ```
1967static LogicalResult
1968vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
1969 ArrayRef<int64_t> inputVectorSizes,
1970 ArrayRef<bool> inputScalableVecDims,
1971 SmallVectorImpl<Value> &newResults) {
1972 if (!inputVectorSizes.empty()) {
1973 assert(inputVectorSizes.size() == unpackOp.getSourceRank() &&
1974 "Invalid number of input vector sizes!");
1975 assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
1976 "Incompatible number of vector sizes and vector scalable flags!");
1977 }
1978
1979 // TODO: Introduce a parent class that will handle the insertion point update.
1980 OpBuilder::InsertionGuard g(rewriter);
1981 rewriter.setInsertionPoint(unpackOp);
1982
1983 ShapedType unpackTensorType = unpackOp.getSourceType();
1984
1985 ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
1986 bool useInBoundsInsteadOfMasking = false;
1987
1988 Location loc = unpackOp->getLoc();
1989
1990 // Obtain vector sizes for the read operation.
1991 SmallVector<int64_t> readVectorSizes(inputVectorSizes);
1992 SmallVector<bool> readScalableVectorFlags(inputScalableVecDims);
1993
1994 // In the absence of input-vector-sizes, use the _static_ input tensor shape.
1995 if (inputVectorSizes.empty()) {
1996 if (ShapedType::isDynamicShape(sourceShape))
1997 return rewriter.notifyMatchFailure(unpackOp,
1998 "Unable to infer vector sizes!");
1999
2000 readVectorSizes.assign(sourceShape.begin(), sourceShape.end());
2001 useInBoundsInsteadOfMasking = true;
2002 }
2003
2004 // -- Generate the read operation --
2005 VectorType readVecType =
2006 VectorType::get(readVectorSizes, unpackTensorType.getElementType(),
2007 readScalableVectorFlags);
2008 Value readResult = vector::createReadOrMaskedRead(
2009 rewriter, loc, unpackOp.getSource(), readVecType, std::nullopt,
2010 useInBoundsInsteadOfMasking);
2011
2012 // -- Generate the transpose operation --
2013 PackingMetadata packMetadata;
2014 SmallVector<int64_t> lastDimToInsertPosPerm =
2015 getUnPackInverseSrcPerm(unpackOp, packMetadata);
2016 vector::TransposeOp transposeOp = vector::TransposeOp::create(
2017 rewriter, loc, readResult, lastDimToInsertPosPerm);
2018
2019 // -- Generate the shape_cast operation --
2020 VectorType collapsedVecType = getCollapsedVecType(
2021 transposeOp.getType(),
2023 rewriter.getContext(), packMetadata.reassociations)));
2024 vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create(
2025 rewriter, loc, collapsedVecType, transposeOp->getResult(0));
2026
2027 // -- Generate the write operation --
2028 Operation *write = createWriteOrMaskedWrite(
2029 rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
2030 /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
2031
2032 newResults.push_back(write->getResult(0));
2033 return success();
2034}
2035
2036/// Vectorize a `padOp` with (1) static result type, (2) constant padding value
2037/// and (3) all-zero lowPad to
2038/// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
2039static LogicalResult
2040vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
2041 ArrayRef<int64_t> inputVectorSizes,
2042 SmallVectorImpl<Value> &newResults) {
2043 auto padValue = padOp.getConstantPaddingValue();
2044 Location loc = padOp.getLoc();
2045
2046 // TODO: Introduce a parent class that will handle the insertion point update.
2047 OpBuilder::InsertionGuard g(rewriter);
2048 rewriter.setInsertionPoint(padOp);
2049
2050 ReifiedRankedShapedTypeDims reifiedReturnShapes;
2051 LogicalResult status =
2052 cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
2053 .reifyResultShapes(rewriter, reifiedReturnShapes);
2054 (void)status; // prevent unused variable warning on non-assert builds
2055 assert(succeeded(status) && "failed to reify result shapes");
2056 auto readType = VectorType::get(inputVectorSizes, padValue.getType());
2057 auto maskedRead = vector::createReadOrMaskedRead(
2058 rewriter, loc, padOp.getSource(), readType, padValue,
2059 /*useInBoundsInsteadOfMasking=*/false);
2060
2061 // Create Xfer write Op
2062 Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0],
2063 padOp.getResultType().getElementType());
2064 Operation *write = createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest);
2065 newResults.push_back(write->getResult(0));
2066 return success();
2067}
2068
2069// TODO: probably need some extra checks for reduction followed by consumer
2070// ops that may not commute (e.g. linear reduction + non-linear instructions).
2071static LogicalResult reductionPreconditions(LinalgOp op) {
2072 if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
2073 LDBG() << "reduction precondition failed: no reduction iterator";
2074 return failure();
2075 }
2076 for (OpOperand &opOperand : op.getDpsInitsMutable()) {
2077 AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
2078 if (indexingMap.isPermutation())
2079 continue;
2080
2081 Operation *reduceOp = matchLinalgReduction(&opOperand);
2082 if (!reduceOp || !getCombinerOpKind(reduceOp)) {
2083 LDBG() << "reduction precondition failed: reduction detection failed";
2084 return failure();
2085 }
2086 }
2087 return success();
2088}
2089
2090static LogicalResult
2091vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
2092 bool flatten1DDepthwiseConv) {
2093 if (flatten1DDepthwiseConv) {
2094 LDBG() << "Vectorization of flattened convs with dynamic shapes is not "
2095 "supported";
2096 return failure();
2097 }
2098
2100 LDBG() << "Not a 1D depth-wise WC conv, dynamic shapes are not supported";
2101 return failure();
2102 }
2103
2104 // Support dynamic shapes in 1D depthwise convolution, but only in the
2105 // _channel_ dimension.
2106 Value lhs = conv.getDpsInputOperand(0)->get();
2107 ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
2108 auto shapeWithoutCh = lhsShape.drop_back(1);
2109 if (ShapedType::isDynamicShape(shapeWithoutCh)) {
2110 LDBG() << "Dynamically-shaped op vectorization precondition failed: only "
2111 "channel dim can be dynamic";
2112 return failure();
2113 }
2114
2115 return success();
2116}
2117
2118static LogicalResult
2119vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
2120 bool flatten1DDepthwiseConv) {
2122 return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);
2123
2124 if (hasReductionIterator(op))
2125 return reductionPreconditions(op);
2126
2127 // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
2128 // linalg.copy ops and ops that implement ContractionOpInterface for now.
2129 if (!isElementwise(op) &&
2130 !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
2131 op.getOperation()))
2132 return failure();
2133
2134 LDBG() << "Dynamically-shaped op meets vectorization pre-conditions";
2135 return success();
2136}
2137
/// This hook considers two cases:
/// (1) If the input-vector-sizes are empty, then the vector sizes will be
///     inferred. This is only possible when all shapes are static.
2141/// (2) If the input-vector-sizes are non-empty (i.e. user provided), then
2142/// carry out basic sanity-checking.
2143static LogicalResult
2144vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
2145 ArrayRef<int64_t> inputVectorSizes) {
2146 // TODO: Support Memref UnPackOp. Temporarily return failure.
2147 if (!unpackOp.hasPureTensorSemantics())
2148 return failure();
2149
2150 // If there are no input vector sizes and all shapes are static, there is
2151 // nothing left to check.
2152 if (inputVectorSizes.empty() && unpackOp.getDestType().hasStaticShape() &&
2153 unpackOp.getSourceType().hasStaticShape())
2154 return success();
2155
2156 // The number of input vector sizes must be equal to:
2157 // * read-vector-rank
2158 if (!inputVectorSizes.empty() &&
2159 (inputVectorSizes.size() != unpackOp.getSourceRank())) {
2160 LDBG() << "Incorrect number of input vector sizes";
2161 return failure();
2162 }
2163
2164 // Check the vector sizes for the read operation.
2166 unpackOp.getSourceType().getShape(), inputVectorSizes))) {
2167 LDBG() << "Invalid vector sizes for the read operation";
2168 return failure();
2169 }
2170
2171 return success();
2172}
2173
2174static LogicalResult
2175vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
2176 ArrayRef<int64_t> inputVectorSizes) {
2177
2178 TypedValue<RankedTensorType> source = sliceOp.getSource();
2179 auto sourceType = source.getType();
2180 if (!VectorType::isValidElementType(sourceType.getElementType()))
2181 return failure();
2182
2183 // Get the pad value.
2184 // TransferReadOp (which is used to vectorize InsertSliceOp), requires a
2185 // scalar padding value. Note that:
2186 // * for in-bounds accesses,
2187 // the value is actually irrelevant. There are 2 cases in which xfer.read
2188 // accesses are known to be in-bounds:
2189 // 1. The source shape is static (output vector sizes would be based on
2190 // the source shape and hence all memory accesses would be in-bounds),
2191 // 2. Masking is used, i.e. the output vector sizes are user-provided. In
2192 // this case it is safe to assume that all memory accesses are in-bounds.
2193 //
2194 // When the value is not known and not needed, use 0. Otherwise, bail out.
2195 Value padValue = getStaticPadVal(sliceOp);
2196 bool isOutOfBoundsRead =
2197 !sourceType.hasStaticShape() && inputVectorSizes.empty();
2198
2199 if (!padValue && isOutOfBoundsRead) {
2200 LDBG() << "Failed to get a pad value for out-of-bounds read access";
2201 return failure();
2202 }
2203 return success();
2204}
2205
2206/// Vectorize a named linalg contraction op into:
2207/// vector::TransferReadOp - Reads vectors from the operands
2208/// vector::ContractionOp - Performs contraction
2209/// vector::TransferWriteOp - Write the result vector back to the
2210/// destination
2211/// The operands shapes are preserved and loaded directly into vectors.
2212/// Any further permutations or numerical casting remain within contraction op.
2213static LogicalResult
2214vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
2215 LinalgOp linalgOp,
2216 SmallVectorImpl<Value> &newResults) {
2217 Location loc = linalgOp.getLoc();
2218 MLIRContext *ctx = linalgOp.getContext();
2219
2220 // For simplicity, contraction vectorization is limited to linalg named ops.
2221 // Generic op is ignored as not every arbitrary contraction body can be
2222 // expressed by a vector.contract.
2223 if (!isa<ContractionOpInterface>(linalgOp.getOperation()))
2224 return failure();
2225
2226 OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
2227 Operation *reduceOp = matchLinalgReduction(outOperand);
2228 auto maybeKind = getCombinerOpKind(reduceOp);
2229 if (!maybeKind) {
2230 LDBG() << "Failed to determine contraction combining kind.";
2231 return failure();
2232 }
2233
2234 // Check that all dimensions are present in the input operands.
2235 // Arbitrary broadcasts are not supported by the vector contraction.
2236 // Broadcasts are expected to be decomposed before vectorization.
2237 AffineMap lhsMap = linalgOp.getIndexingMapsArray()[0];
2238 AffineMap rhsMap = linalgOp.getIndexingMapsArray()[1];
2239 if (getUnusedDimsBitVector({lhsMap, rhsMap}).any()) {
2240 LDBG() << "Contractions with broadcasts are not supported.";
2241 return failure();
2242 }
2243
2244 // Load operands.
2245 SmallVector<Value> vecOperands;
2246 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2247 // The operand vector shape is computed by mapping the canonical vector
2248 // shape to the operand's domain. Further permutations are left as a part of
2249 // the contraction.
2250 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
2251 AffineMap readMap = AffineMap::getMultiDimIdentityMap(
2252 indexingMap.getNumResults(), rewriter.getContext());
2253 Type elemType = getElementTypeOrSelf(opOperand.get());
2254 VectorType readType =
2255 state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
2256
2258 rewriter, loc, opOperand.get(), readType,
2259 /*padding=*/arith::getZeroConstant(rewriter, loc, elemType),
2260 /*useInBoundsInsteadOfMasking=*/false);
2261 vecOperands.push_back(read);
2262 }
2263
2264 // Remap iterators from linalg to vector.
2265 SmallVector<Attribute> iterAttrs;
2266 auto iterators = linalgOp.getIteratorTypesArray();
2267 for (utils::IteratorType iter : iterators) {
2268 auto vecIter = iter == utils::IteratorType::parallel
2269 ? vector::IteratorType::parallel
2270 : vector::IteratorType::reduction;
2271 iterAttrs.push_back(vector::IteratorTypeAttr::get(ctx, vecIter));
2272 }
2273
2274 // Create contraction.
2275 Operation *contractOp = vector::ContractionOp::create(
2276 rewriter, loc, /*lhs=*/vecOperands[0],
2277 /*rhs=*/vecOperands[1], /*acc=*/vecOperands[2],
2278 linalgOp.getIndexingMaps(), rewriter.getArrayAttr(iterAttrs), *maybeKind);
2279 contractOp = state.maskOperation(rewriter, contractOp, linalgOp);
2280
2281 // Store result.
2282 Operation *write = createWriteOrMaskedWrite(
2283 rewriter, loc, contractOp->getResult(0), outOperand->get());
2284
2285 // Finalize.
2286 if (!write->getResults().empty())
2287 newResults.push_back(write->getResult(0));
2288
2289 return success();
2290}
2291
namespace {
/// Whether the payload of a matched linalg op computes a convolution (Conv:
/// multiply-accumulate style body) or a pooling (Pool: the reduction combines
/// a block argument, possibly via a cast) - see `getConvOperationKind` below.
enum class ConvOperationKind { Conv, Pool };
} // namespace
2295
2296static bool isCastOfBlockArgument(Operation *op) {
2297 return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
2298 isa<BlockArgument>(op->getOperand(0));
2299}
2300
2301// Returns the ConvOperationKind of the op using reduceOp of the generic
2302// payload. If it is neither a convolution nor a pooling, it returns
2303// std::nullopt.
2304//
2305// If (region has 2 ops (reduction + yield) or 3 ops (extension + reduction
2306// + yield) and rhs is not used) then it is the body of a pooling
2307// If conv, check for single `mul` predecessor. The `mul` operands must be
2308// block arguments or extension of block arguments.
2309// Otherwise, check for one or zero `ext` predecessor. The `ext` operands
2310// must be block arguments or extension of block arguments.
2311static std::optional<ConvOperationKind>
2312getConvOperationKind(Operation *reduceOp) {
2313 int numBlockArguments =
2314 llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
2315
2316 switch (numBlockArguments) {
2317 case 1: {
2318 // Will be convolution if feeder is a MulOp.
2319 // A strength reduced version of MulOp for i1 type is AndOp which is also
2320 // supported. Otherwise, it can be pooling. This strength reduction logic
2321 // is in `buildBinaryFn` helper in the Linalg dialect.
2322 auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
2323 llvm::IsaPred<BlockArgument>);
2324 assert(feedValIt != reduceOp->operand_end() &&
2325 "Expected a non-block argument operand");
2326 Operation *feedOp = (*feedValIt).getDefiningOp();
2327 if (isCastOfBlockArgument(feedOp)) {
2328 return ConvOperationKind::Pool;
2329 }
2330
2331 if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
2332 (isa<arith::AndIOp>(feedOp) &&
2333 feedOp->getResultTypes()[0].isInteger(1))) &&
2334 llvm::all_of(feedOp->getOperands(), [](Value v) {
2335 if (isa<BlockArgument>(v))
2336 return true;
2337 if (Operation *op = v.getDefiningOp())
2338 return isCastOfBlockArgument(op);
2339 return false;
2340 }))) {
2341 return std::nullopt;
2342 }
2343
2344 return ConvOperationKind::Conv;
2345 }
2346 case 2:
2347 // Must be pooling
2348 return ConvOperationKind::Pool;
2349 default:
2350 return std::nullopt;
2351 }
2352}
2353
2354static bool isSupportedPoolKind(vector::CombiningKind kind) {
2355 switch (kind) {
2356 case vector::CombiningKind::ADD:
2357 case vector::CombiningKind::MAXNUMF:
2358 case vector::CombiningKind::MAXIMUMF:
2359 case vector::CombiningKind::MAXSI:
2360 case vector::CombiningKind::MAXUI:
2361 case vector::CombiningKind::MINNUMF:
2362 case vector::CombiningKind::MINIMUMF:
2363 case vector::CombiningKind::MINSI:
2364 case vector::CombiningKind::MINUI:
2365 return true;
2366 default:
2367 return false;
2368 }
2369}
2370
2371static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
2372 auto getOperandType = [&](auto operand) {
2373 return dyn_cast<ShapedType>((operand->get()).getType());
2374 };
2375 ShapedType lhsShapedType = getOperandType(convOp.getDpsInputOperand(0));
2376 ShapedType rhsShapedType = getOperandType(convOp.getDpsInputOperand(1));
2377 ShapedType resShapedType = getOperandType(convOp.getDpsInitOperand(0));
2378 // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
2379 // (non-channeled convolution -> LHS and RHS both have single dimensions).
2380 // Note that this also ensures 2D and 3D convolutions are rejected.
2381 if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
2382 (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
2383 return failure();
2384
2385 Operation *reduceOp = matchLinalgReduction(convOp.getDpsInitOperand(0));
2386 if (!reduceOp)
2387 return failure();
2388
2389 auto maybeOper = getConvOperationKind(reduceOp);
2390 if (!maybeOper.has_value())
2391 return failure();
2392
2393 auto maybeKind = getCombinerOpKind(reduceOp);
2394 // Typically convolution will have a `Add` CombiningKind but for i1 type it
2395 // can get strength reduced to `OR` which is also supported. This strength
2396 // reduction logic is in `buildBinaryFn` helper in the Linalg dialect.
2397 if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
2398 *maybeKind != vector::CombiningKind::OR) &&
2399 (*maybeOper != ConvOperationKind::Pool ||
2400 !isSupportedPoolKind(*maybeKind)))) {
2401 return failure();
2402 }
2403
2404 auto rhsRank = rhsShapedType.getRank();
2405 if (*maybeOper == ConvOperationKind::Pool) {
2406 if (rhsRank != 1)
2407 return failure();
2408 } else {
2409 if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
2410 return failure();
2411 }
2412
2413 return success();
2414}
2415
2416static LogicalResult vectorizeLinalgOpPrecondition(
2417 LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
2418 bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
2419 // tensor with dimension of 0 cannot be vectorized.
2420 if (llvm::any_of(linalgOp->getOpOperands(), [&](OpOperand &operand) {
2421 return llvm::is_contained(linalgOp.getShape(&operand), 0);
2422 }))
2423 return failure();
2424 // Check API contract for input vector sizes.
2425 if (!inputVectorSizes.empty() &&
2426 failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
2427 inputVectorSizes)))
2428 return failure();
2429
2430 if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition(
2431 linalgOp, flatten1DDepthwiseConv))) {
2432 LDBG() << "Dynamically-shaped op failed vectorization pre-conditions";
2433 return failure();
2434 }
2435
2436 SmallVector<CustomVectorizationPrecondition> customPreconditions;
2437
2438 // Register CustomVectorizationPrecondition for extractOp.
2439 customPreconditions.push_back(tensorExtractVectorizationPrecondition);
2440
2441 // All types in the body should be a supported element type for VectorType.
2442 for (Operation &innerOp : linalgOp->getRegion(0).front()) {
2443 // Check if any custom hook can vectorize the inner op.
2444 if (llvm::any_of(
2445 customPreconditions,
2446 [&](const CustomVectorizationPrecondition &customPrecondition) {
2447 return succeeded(
2448 customPrecondition(&innerOp, vectorizeNDExtract));
2449 })) {
2450 continue;
2451 }
2452 if (!llvm::all_of(innerOp.getOperandTypes(),
2453 VectorType::isValidElementType)) {
2454 return failure();
2455 }
2456 if (!llvm::all_of(innerOp.getResultTypes(),
2457 VectorType::isValidElementType)) {
2458 return failure();
2459 }
2460 }
2461 if (isElementwise(linalgOp))
2462 return success();
2463
2464 // Check for both named as well as generic convolution ops.
2465 if (isaConvolutionOpInterface(linalgOp))
2466 return vectorizeConvOpPrecondition(linalgOp);
2467
2468 // TODO: the common vector shape is equal to the static loop sizes only when
2469 // all indexing maps are projected permutations. For convs and stencils the
2470 // logic will need to evolve.
2471 if (!allIndexingsAreProjectedPermutation(linalgOp)) {
2472 LDBG() << "precondition failed: not projected permutations";
2473 return failure();
2474 }
2475 if (failed(reductionPreconditions(linalgOp))) {
2476 LDBG() << "precondition failed: reduction preconditions";
2477 return failure();
2478 }
2479 return success();
2480}
2481
/// Checks vectorization preconditions for a `linalg.pack` op, with optional
/// user-provided vector sizes in `inputVectorSizes`.
static LogicalResult
vectorizePackOpPrecondition(linalg::PackOp packOp,
                            ArrayRef<int64_t> inputVectorSizes) {
  // TODO: Support Memref PackOp. Temporarily return failure.
  if (!packOp.hasPureTensorSemantics())
    return failure();

  auto padValue = packOp.getPaddingValue();
  Attribute cstAttr;
  // TODO: Relax this condition
  if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
    LDBG() << "pad value is not constant: " << packOp;
    return failure();
  }

  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
  // When no vector sizes are provided, both the source and the destination
  // must have fully static shapes so that the sizes can be inferred.
  bool satisfyEmptyCond = true;
  if (inputVectorSizes.empty()) {
    if (!packOp.getDestType().hasStaticShape() ||
        !packOp.getSourceType().hasStaticShape())
      satisfyEmptyCond = false;
  }

  if (!satisfyEmptyCond &&
          resultTensorShape.take_front(packOp.getSourceRank()),
          inputVectorSizes)))
    return failure();

  // Dynamic inner tiles are not supported.
  if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
        return !getConstantIntValue(v).has_value();
      })) {
    LDBG() << "inner_tiles must be constant: " << packOp;
    return failure();
  }

  return success();
}
2520
2521static LogicalResult
2522vectorizePadOpPrecondition(tensor::PadOp padOp,
2523 ArrayRef<int64_t> inputVectorSizes) {
2524 auto padValue = padOp.getConstantPaddingValue();
2525 if (!padValue) {
2526 LDBG() << "pad value is not constant: " << padOp;
2527 return failure();
2528 }
2529
2530 ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
2531 if (failed(vector::isValidMaskedInputVector(resultTensorShape,
2532 inputVectorSizes)))
2533 return failure();
2534
2535 // Padding with non-zero low pad values is not supported, unless the
2536 // corresponding result dim is 1 as this would require shifting the results to
2537 // the right for the low padded dims by the required amount of low padding.
2538 // However, we do support low padding if the dims being low padded have result
2539 // sizes of 1. The reason is when we have a low pad on a unit result dim, the
2540 // input size of that dimension will be dynamically zero (as the sum of the
2541 // low pad and input dim size has to be one) and hence we will create a zero
2542 // mask as the lowering logic just makes the mask one for the input dim size -
2543 // which is zero here. Hence we will load the pad value which is what we want
2544 // in this case. If the low pad is dynamically zero then the lowering is
2545 // correct as well as no shifts are necessary.
2546 if (llvm::any_of(llvm::enumerate(padOp.getMixedLowPad()),
2547 [&](const auto &en) {
2548 OpFoldResult padValue = en.value();
2549 unsigned pos = en.index();
2550 std::optional<int64_t> pad = getConstantIntValue(padValue);
2551 return (!pad.has_value() || pad.value() != 0) &&
2552 resultTensorShape[pos] != 1;
2553 })) {
2554 LDBG() << "low pad must all be zero for all non unit dims: " << padOp;
2555 return failure();
2556 }
2557
2558 return success();
2559}
2560
/// Preconditions for scalable vectors.
///
/// For Ops implementing the LinalgOp interface, this is quite restrictive - it
/// models the fact that in practice we would only make selected dimensions
/// scalable. For other Ops (e.g. `linalg.unpack`), this will succeed
/// unconditionally - we are yet to identify meaningful conditions.
static LogicalResult
vectorizeScalableVectorPrecondition(Operation *op,
                                    ArrayRef<int64_t> inputVectorSizes,
                                    ArrayRef<bool> inputScalableVecDims) {
  assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
         "Number of input vector sizes and scalable dims doesn't match");

  size_t numOfScalableDims =
      llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });

  // No scalable dims requested - nothing to check.
  if (numOfScalableDims == 0)
    return success();

  auto linalgOp = dyn_cast<LinalgOp>(op);

  // Cond 1: Reject Ops that don't implement the LinalgOp interface, with the
  // exception of UnpackOp for which there is a dedicated hook.
  if (!linalgOp) {
    return success(isa<linalg::UnPackOp>(op));
  }

  // Cond 2: There's been no need for more than 2 scalable dims so far
  if (numOfScalableDims > 2)
    return failure();

  // Cond 3: Look at the configuration in `inputScalableVecDims` and verify that
  // it matches one of the supported cases:
  //  1. Exactly 1 dim is scalable and that's the _last_ non-unit parallel dim
  //    (*).
  //  2. Exactly 2 dims are scalable and those are the _last two adjacent_
  //     parallel dims.
  //  3. Exactly 1 reduction dim is scalable and that's the last (innermost)
  //     dim.
  // The 2nd restriction above means that only Matmul-like Ops are supported
  // when 2 dims are scalable, e.g. :
  //    * iterators = [parallel, parallel, reduction]
  //    * scalable flags = [true, true, false]
  //
  // (*) Non-unit dims get folded away in practice.
  // TODO: Relax these conditions as good motivating examples are identified.

  // Find the first scalable flag, scanning backwards from the trailing dim
  // and dropping trailing non-scalable dims as we go.
  bool seenNonUnitParallel = false;
  auto iterators = linalgOp.getIteratorTypesArray();
  SmallVector<bool> scalableFlags(inputScalableVecDims);
  int64_t idx = scalableFlags.size() - 1;
  while (!scalableFlags[idx]) {
    bool isNonUnitDim = (inputVectorSizes[idx] != 1);
    seenNonUnitParallel |=
        (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim);

    iterators.pop_back();
    scalableFlags.pop_back();
    --idx;
  }

  // Analyze the iterator corresponding to the first scalable dim.
  switch (iterators.back()) {
  case utils::IteratorType::reduction: {
    // Check 3. above is met.
    if (iterators.size() != inputVectorSizes.size()) {
      LDBG() << "Non-trailing reduction dim requested for scalable "
                "vectorization";
      return failure();
    }
    if (isa<linalg::MatmulOp>(op)) {
      LDBG()
          << "Scalable vectorization of the reduction dim in Matmul-like ops "
             "is not supported";
      return failure();
    }
    break;
  }
  case utils::IteratorType::parallel: {
    // Check 1. and 2. above are met.
    if (seenNonUnitParallel) {
      LDBG() << "Inner parallel dim not requested for scalable "
                "vectorization";
      return failure();
    }
    break;
  }
  }

  // If present, check the 2nd scalable dim. ATM, only Matmul-like Ops are
  // supported for which we expect the following config:
  //    * iterators = [parallel, parallel, reduction]
  //    * scalable flags = [true, true, false]
  if (numOfScalableDims == 2) {
    // Disallow below case which breaks 3. above:
    //    * iterators = [..., parallel, reduction]
    //    * scalable flags = [..., true, true]
    if (iterators.back() == utils::IteratorType::reduction) {
      LDBG() << "Higher dim than the trailing reduction dim requested for "
                "scalable "
                "vectorizatio";
      return failure();
    }
    scalableFlags.pop_back();
    iterators.pop_back();

    if (!scalableFlags.back() ||
        (iterators.back() != utils::IteratorType::parallel))
      return failure();
  }

  // Cond 4: Only the following ops are supported in the
  // presence of scalable vectors
  return success(
      isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
      isa<linalg::BatchMatmulOp>(op) ||
      isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
      isa<linalg::BatchMmt4DOp>(op) || hasReductionIterator(linalgOp));
}
2682
    Operation *op, ArrayRef<int64_t> inputVectorSizes,
    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
    bool flatten1DDepthwiseConv) {

  // Bail early on op kinds with no vectorization implementation at all.
  if (!hasVectorizationImpl(op))
    return failure();

  // Scalable-vector requests have their own, more restrictive checks.
  if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
                                                 inputScalableVecDims)))
    return failure();

      .Case([&](linalg::LinalgOp linalgOp) {
        return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
                                             vectorizeNDExtract,
                                             flatten1DDepthwiseConv);
      })
      .Case([&](tensor::PadOp padOp) {
        return vectorizePadOpPrecondition(padOp, inputVectorSizes);
      })
      .Case([&](linalg::PackOp packOp) {
        return vectorizePackOpPrecondition(packOp, inputVectorSizes);
      })
      .Case([&](linalg::UnPackOp unpackOp) {
        return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
      })
      .Case([&](tensor::InsertSliceOp sliceOp) {
        return vectorizeInsertSliceOpPrecondition(sliceOp, inputVectorSizes);
      })
      // Any other op kind is unsupported.
      .Default(failure());
}
2715
2716/// Converts affine.apply Ops to arithmetic operations.
2717static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
2718 OpBuilder::InsertionGuard g(rewriter);
2719 auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
2720
2721 for (auto op : make_early_inc_range(toReplace)) {
2722 rewriter.setInsertionPoint(op);
2723 auto expanded = affine::expandAffineExpr(
2724 rewriter, op->getLoc(), op.getAffineMap().getResult(0),
2725 op.getOperands().take_front(op.getAffineMap().getNumDims()),
2726 op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
2727 rewriter.replaceOp(op, expanded);
2728 }
2729}
2730
/// Returns true if `op` is one of the op kinds that
/// `mlir::linalg::vectorize` has an implementation for.
bool mlir::linalg::hasVectorizationImpl(Operation *op) {
  return isa<linalg::LinalgOp, tensor::PadOp, linalg::PackOp, linalg::UnPackOp,
             tensor::InsertSliceOp>(op);
}
2735
FailureOr<VectorizationResult> mlir::linalg::vectorize(
    RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
    bool flatten1DDepthwiseConv, bool assumeDynamicDimsMatchVecSizes,
    bool createNamedContraction) {
  LDBG() << "Attempting to vectorize: " << *op;
  LDBG() << "Input vector sizes: " << llvm::interleaved(inputVectorSizes);
  LDBG() << "Input scalable vector dims: "
         << llvm::interleaved(inputScalableVecDims);

  // Verify all preconditions before rewriting anything.
  if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
                                     vectorizeNDExtract,
                                     flatten1DDepthwiseConv))) {
    LDBG() << "Vectorization pre-conditions failed";
    return failure();
  }

  // Initialize vectorization state.
  VectorizationState state(rewriter);
  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
    if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
                               inputScalableVecDims,
                               assumeDynamicDimsMatchVecSizes))) {
      LDBG() << "Vectorization state couldn't be initialized";
      return failure();
    }
  }

  // Results of the vectorized op, collected by the per-kind handlers below.
  SmallVector<Value> results;
  auto vectorizeResult =
      .Case([&](linalg::LinalgOp linalgOp) {
        // Check for both named as well as generic convolution ops.
        if (isaConvolutionOpInterface(linalgOp)) {
          FailureOr<Operation *> convOr = vectorizeConvolution(
              rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
              flatten1DDepthwiseConv);
          if (succeeded(convOr)) {
            llvm::append_range(results, (*convOr)->getResults());
            return success();
          }

          LDBG() << "Unsupported convolution can't be vectorized.";
          return failure();
        }

        if (createNamedContraction &&
            isa<ContractionOpInterface>(linalgOp.getOperation()))
          return vectorizeAsLinalgContraction(rewriter, state, linalgOp,
                                              results);

        LDBG()
            << "Vectorize generic by broadcasting to the canonical vector "
               "shape";

        // Pre-process before proceeding.
        convertAffineApply(rewriter, linalgOp);

        // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
        // to 'OpBuilder' when it is passed over to some methods like
        // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
        // erase an op within these methods, the actual rewriter won't be
        // notified and we will end up with read-after-free issues!
        return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
      })
      .Case([&](tensor::PadOp padOp) {
        return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
                                      results);
      })
      .Case([&](linalg::PackOp packOp) {
        return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
                                       results);
      })
      .Case([&](linalg::UnPackOp unpackOp) {
        return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
                                         inputVectorSizes,
                                         inputScalableVecDims, results);
      })
      .Case([&](tensor::InsertSliceOp sliceOp) {
        return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
                                        results);
      })
      .Default(failure());

  if (failed(vectorizeResult)) {
    LDBG() << "Vectorization failed";
    return failure();
  }

  return VectorizationResult{results};
}
2827
2828LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
2829 memref::CopyOp copyOp) {
2830 auto srcType = cast<MemRefType>(copyOp.getSource().getType());
2831 auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
2832 if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
2833 return failure();
2834
2835 auto srcElementType = getElementTypeOrSelf(srcType);
2836 auto dstElementType = getElementTypeOrSelf(dstType);
2837 if (!VectorType::isValidElementType(srcElementType) ||
2838 !VectorType::isValidElementType(dstElementType))
2839 return failure();
2840
2841 auto readType = VectorType::get(srcType.getShape(), srcElementType);
2842 auto writeType = VectorType::get(dstType.getShape(), dstElementType);
2843
2844 Location loc = copyOp->getLoc();
2845 Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
2846 SmallVector<Value> indices(srcType.getRank(), zero);
2847
2848 Value readValue = vector::TransferReadOp::create(
2849 rewriter, loc, readType, copyOp.getSource(), indices,
2850 /*padding=*/std::nullopt,
2851 rewriter.getMultiDimIdentityMap(srcType.getRank()));
2852 if (cast<VectorType>(readValue.getType()).getRank() == 0) {
2853 readValue = vector::ExtractOp::create(rewriter, loc, readValue,
2854 ArrayRef<int64_t>());
2855 readValue =
2856 vector::BroadcastOp::create(rewriter, loc, writeType, readValue);
2857 }
2858 Operation *writeValue = vector::TransferWriteOp::create(
2859 rewriter, loc, readValue, copyOp.getTarget(), indices,
2860 rewriter.getMultiDimIdentityMap(srcType.getRank()));
2861 rewriter.replaceOp(copyOp, writeValue->getResults());
2862 return success();
2863}
2864
2865//----------------------------------------------------------------------------//
2866// Misc. vectorization patterns.
2867//----------------------------------------------------------------------------//
2868/// Base pattern for rewriting tensor::PadOps whose result is consumed by a
2869/// given operation type OpTy.
2870template <typename OpTy>
2871struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
2872 using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
2873
2874 LogicalResult matchAndRewrite(tensor::PadOp padOp,
2875 PatternRewriter &rewriter) const final {
2876 bool changed = false;
2877 // Insert users in vector, because some users may be replaced/removed.
2878 for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
2879 if (auto op = dyn_cast<OpTy>(user))
2880 changed |= rewriteUser(rewriter, padOp, op).succeeded();
2881 return success(changed);
2882 }
2883
2884protected:
2885 virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
2886 tensor::PadOp padOp, OpTy op) const = 0;
2887};
2888
2889/// Rewrite use of tensor::PadOp result in TransferReadOp. E.g.:
2890/// ```
2891/// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2892/// %r = vector.transfer_read %0[%c0, %c0], %cst
2893/// {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32>
2894/// ```
2895/// is rewritten to:
2896/// ```
2897/// %r = vector.transfer_read %src[%c0, %c0], %padding
2898/// {in_bounds = [true, true]}
2899/// : tensor<?x?xf32>, vector<17x5xf32>
2900/// ```
2901/// Note: By restricting this pattern to in-bounds TransferReadOps, we can be
2902/// sure that the original padding value %cst was never used.
2903///
2904/// This rewrite is possible if:
2905/// - `xferOp` has no out-of-bounds dims or mask.
2906/// - Low padding is static 0.
2907/// - Single, scalar padding value.
2908struct PadOpVectorizationWithTransferReadPattern
2909 : public VectorizePadOpUserPattern<vector::TransferReadOp> {
2910 using VectorizePadOpUserPattern<
2911 vector::TransferReadOp>::VectorizePadOpUserPattern;
2912
2913 LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2914 vector::TransferReadOp xferOp) const override {
2915 // Low padding must be static 0.
2916 if (!padOp.hasZeroLowPad())
2917 return failure();
2918 // Pad value must be a constant.
2919 auto padValue = padOp.getConstantPaddingValue();
2920 if (!padValue)
2921 return failure();
2922 // Padding value of existing `xferOp` is unused.
2923 if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
2924 return failure();
2925
2926 rewriter.modifyOpInPlace(xferOp, [&]() {
2927 SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2928 xferOp->setAttr(xferOp.getInBoundsAttrName(),
2929 rewriter.getBoolArrayAttr(inBounds));
2930 xferOp.getBaseMutable().assign(padOp.getSource());
2931 xferOp.getPaddingMutable().assign(padValue);
2932 });
2933
2934 return success();
2935 }
2936};
2937
/// Rewrite use of tensor::PadOp result in TransferWriteOp.
/// This pattern rewrites TransferWriteOps that write to a padded tensor
/// value, where the same amount of padding is immediately removed again after
/// the write. In such cases, the TransferWriteOp can write to the non-padded
/// tensor value and apply out-of-bounds masking. E.g.:
/// ```
/// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
///     : tensor<...> to tensor<?x?xf32>
/// %1 = tensor.pad %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
/// %2 = vector.transfer_write %vec, %1[...]
///     : vector<17x5xf32>, tensor<17x5xf32>
/// %r = tensor.extract_slice %2[0, 0] [%s0, %s1] [1, 1]
///     : tensor<17x5xf32> to tensor<?x?xf32>
/// ```
/// is rewritten to:
/// ```
/// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
///     : tensor<...> to tensor<?x?xf32>
/// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>,
/// tensor<?x?xf32>
/// ```
/// Note: It is important that the ExtractSliceOp %r resizes the result of the
/// TransferWriteOp to the same size as the input of the TensorPadOp (or an
/// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
/// from %r's old dimensions.
///
/// This rewrite is possible if:
/// - Low padding is static 0.
/// - `xferOp` has exactly one use, which is an ExtractSliceOp. This
///   ExtractSliceOp trims the same amount of padding that was added
///   beforehand.
/// - Single, scalar padding value.
struct PadOpVectorizationWithTransferWritePattern
    : public VectorizePadOpUserPattern<vector::TransferWriteOp> {
  using VectorizePadOpUserPattern<
      vector::TransferWriteOp>::VectorizePadOpUserPattern;

  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            vector::TransferWriteOp xferOp) const override {
    // TODO: support 0-d corner case.
    if (xferOp.getTransferRank() == 0)
      return failure();

    // Low padding must be static 0.
    if (!padOp.hasZeroLowPad())
      return failure();
    // Pad value must be a constant.
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
    if (!xferOp->hasOneUse())
      return failure();
    auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
    if (!trimPadding)
      return failure();
    // Only static zero offsets supported when trimming padding.
    if (!trimPadding.hasZeroOffset())
      return failure();
    // trimPadding must remove the amount of padding that was added earlier.
    if (!hasSameTensorSize(padOp.getSource(), trimPadding))
      return failure();

    // Insert the new TransferWriteOp at position of the old TransferWriteOp.
    rewriter.setInsertionPoint(xferOp);

    // Writing into the unpadded tensor may run out-of-bounds in every dim.
    SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
    auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        xferOp, padOp.getSource().getType(), xferOp.getVector(),
        padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
        xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
    rewriter.replaceOp(trimPadding, newXferOp->getResult(0));

    return success();
  }

  /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
  /// i.e., same dimensions.
  ///
  /// Dimensions may be static, dynamic or mix of both. In case of dynamic
  /// dimensions, this function tries to infer the (static) tensor size by
  /// looking at the defining op and utilizing op-specific knowledge.
  ///
  /// This is a conservative analysis. In case equal tensor sizes cannot be
  /// proven statically, this analysis returns `false` even though the tensor
  /// sizes may turn out to be equal at runtime.
  bool hasSameTensorSize(Value beforePadding,
                         tensor::ExtractSliceOp afterTrimming) const {
    // If the input to tensor::PadOp is a CastOp, try with both CastOp
    // result and CastOp operand.
    if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
      if (hasSameTensorSize(castOp.getSource(), afterTrimming))
        return true;

    auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
    auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
    // Only RankedTensorType supported.
    if (!t1 || !t2)
      return false;
    // Rank of both values must be the same.
    if (t1.getRank() != t2.getRank())
      return false;

    // All static dimensions must be the same. Mixed cases (e.g., dimension
    // static in `t1` but dynamic in `t2`) are not supported.
    for (unsigned i = 0; i < t1.getRank(); ++i) {
      if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
        return false;
      if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
        return false;
    }

    // Nothing more to check if all dimensions are static.
    if (t1.getNumDynamicDims() == 0)
      return true;

    // All dynamic sizes must be the same. The only supported case at the
    // moment is when `beforePadding` is an ExtractSliceOp (or a cast
    // thereof).

    // Apart from CastOp, only ExtractSliceOp is supported.
    auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
    if (!beforeSlice)
      return false;

    assert(static_cast<size_t>(t1.getRank()) ==
           beforeSlice.getMixedSizes().size());
    assert(static_cast<size_t>(t2.getRank()) ==
           afterTrimming.getMixedSizes().size());

    // Compare the dynamic sizes of the two slices dim by dim.
    for (unsigned i = 0; i < t1.getRank(); ++i) {
      // Skip static dimensions.
      if (!t1.isDynamicDim(i))
        continue;
      auto size1 = beforeSlice.getMixedSizes()[i];
      auto size2 = afterTrimming.getMixedSizes()[i];

      // Case 1: Same value or same constant int.
      if (isEqualConstantIntOrValue(size1, size2))
        continue;

      // Other cases: Take a deeper look at defining ops of values.
      auto v1 = llvm::dyn_cast_if_present<Value>(size1);
      auto v2 = llvm::dyn_cast_if_present<Value>(size2);
      if (!v1 || !v2)
        return false;

      // Case 2: Both values are identical AffineMinOps. (Should not happen if
      // CSE is run.)
      auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
      auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
      if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
          minOp1.getOperands() == minOp2.getOperands())
        continue;

      // Add additional cases as needed.
    }

    // All tests passed.
    return true;
  }
};
3100
3101/// Returns the effective Pad value for the input op, provided it's a scalar.
3102///
3103/// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
3104/// this Op performs padding, retrieve the padding value provided that it's
3105/// a scalar and static/fixed for all the padded values. Returns an empty value
3106/// otherwise.
3107///
3108/// TODO: This is used twice (when checking vectorization pre-conditions and
3109/// when vectorizing). Cache results instead of re-running.
3110static Value getStaticPadVal(Operation *op) {
3111 if (!op)
3112 return {};
3113
3114 // 1. vector.broadcast (f32 -> vector <...xf32>) - return the value that's
3115 // being broadcast, provided that it's a scalar.
3116 if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
3117 auto source = bcast.getSource();
3118 if (llvm::dyn_cast<VectorType>(source.getType()))
3119 return {};
3120
3121 return source;
3122 }
3123
3124 // 2. linalg.fill - use the scalar input value that used to fill the output
3125 // tensor.
3126 if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
3127 return fill.getInputs()[0];
3128 }
3129
3130 // 3. tensor.generateOp - can't guarantee the value is fixed without
3131 // analysing, bail out.
3132 if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {
3133 return {};
3134 }
3135
3136 // 4. vector.transfer_write - inspect the input vector that's written from. If
3137 // if contains a single value that has been broadcast (e.g. via
3138 // vector.broadcast), extract it, fail otherwise.
3139 if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
3140 return getStaticPadVal(xferWrite.getVector().getDefiningOp());
3141
3142 // 5. tensor.insert_slice - inspect the destination tensor. If it's larger
3143 // than the input tensor, then, provided it's constant, we'll extract the
3144 // value that was used to generate it (via e.g. linalg.fill), fail otherwise.
3145 // TODO: Clarify the semantics when the input tensor is larger than the
3146 // destination.
3147 if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
3148 return getStaticPadVal(slice.getDest().getDefiningOp());
3149
3150 return {};
3151}
3152
/// Vectorizes `sliceOp` as a transfer_read of the source followed by a
/// transfer_write into the destination (see the documented declaration near
/// the top of this file for the full contract).
static LogicalResult
vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
                         ArrayRef<int64_t> inputVectorSizes,
                         SmallVectorImpl<Value> &newResults) {
  // TODO: Introduce a parent class that will handle the insertion point update.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(sliceOp);

  TypedValue<RankedTensorType> source = sliceOp.getSource();
  auto sourceType = source.getType();
  auto resultType = sliceOp.getResultType();

  // 1. Get the pad value; if none can be determined statically, fall back to
  // a zero constant of the element type (only used for out-of-bounds lanes).
  Value padValue = getStaticPadVal(sliceOp);

  if (!padValue) {
    auto elemType = sourceType.getElementType();
    padValue = arith::ConstantOp::create(rewriter, sliceOp.getLoc(), elemType,
                                         rewriter.getZeroAttr(elemType));
  }

  // 2. Get the vector shape
  SmallVector<int64_t> vecShape;
  size_t rankDiff = resultType.getRank() - sourceType.getRank();
  for (int64_t i = 0, end = sourceType.getRank(); i < end; ++i) {
    if (!inputVectorSizes.empty()) {
      vecShape.push_back(inputVectorSizes[i]);
    } else if (!sourceType.isDynamicDim(i)) {
      vecShape.push_back(sourceType.getDimSize(i));
    } else if (!resultType.isDynamicDim(i)) {
      // Source shape is not statically known, but result shape is.
      // Vectorize with size of result shape. This may be larger than the
      // source size.
      // FIXME: Using rankDiff implies that the source tensor is inserted at
      // the end of the destination tensor. However, that's not required.
      vecShape.push_back(resultType.getDimSize(rankDiff + i));
    } else {
      // Neither source nor result dim of padOp is static. Cannot vectorize
      // the copy.
      return failure();
    }
  }
  auto vecType = VectorType::get(vecShape, sourceType.getElementType());

  // 3. Generate TransferReadOp + TransferWriteOp
  auto loc = sliceOp.getLoc();

  // Create read
  SmallVector<Value> readIndices(
      vecType.getRank(), arith::ConstantIndexOp::create(rewriter, loc, 0));
      rewriter, loc, source, vecType, padValue,
      /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());

  // Create write
  auto writeIndices =
      getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
  Operation *write =
      createWriteOrMaskedWrite(rewriter, loc, read, sliceOp.getDest(),
                               writeIndices, inputVectorSizes.empty());

  // 4. Finalize
  newResults.push_back(write->getResult(0));

  return success();
}
3218
/// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
/// ```
/// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
/// %r = tensor.insert_slice %0
///     into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
///     : tensor<17x5xf32> into tensor<?x?x17x5xf32>
/// ```
/// is rewritten to:
/// ```
/// %0 = vector.transfer_read %src[%c0, %c0], %padding
///     : tensor<?x?xf32>, vector<17x5xf32>
/// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
///     {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
/// ```
///
/// This rewrite is possible if:
/// - Low padding is static 0.
/// - `padOp` result shape is static.
/// - The entire padded tensor is inserted.
///   (Implies that sizes of `insertOp` are all static.)
/// - Only unit strides in `insertOp`.
/// - Single, scalar padding value.
/// - `padOp` result not used as destination.
struct PadOpVectorizationWithInsertSlicePattern
    : public VectorizePadOpUserPattern<tensor::InsertSliceOp> {
  using VectorizePadOpUserPattern<
      tensor::InsertSliceOp>::VectorizePadOpUserPattern;

  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            tensor::InsertSliceOp insertOp) const override {
    // Low padding must be static 0.
    if (!padOp.hasZeroLowPad())
      return failure();
    // Only unit stride supported.
    if (!insertOp.hasUnitStride())
      return failure();
    // Pad value must be a constant.
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    // Dynamic shapes not supported: the vector type below requires a fully
    // static pad-result shape.
    if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
      return failure();
    // Pad result not used as destination (otherwise the read and the write
    // would alias through the pad result).
    if (insertOp.getDest() == padOp.getResult())
      return failure();

    auto vecType = VectorType::get(padOp.getType().getShape(),
                                   padOp.getType().getElementType());
    unsigned vecRank = vecType.getRank();
    unsigned tensorRank = insertOp.getType().getRank();

    // Check if sizes match: Insert the entire tensor into most minor dims.
    // (No permutations allowed.) The leading `tensorRank - vecRank` sizes
    // must all be 1, the trailing ones must equal the pad-result shape.
    SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
    expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
    if (!llvm::all_of(
            llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
              return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
            }))
      return failure();

    // Insert the TransferReadOp and TransferWriteOp at the position of the
    // InsertSliceOp.
    rewriter.setInsertionPoint(insertOp);

    // Generate TransferReadOp: Read entire source tensor and add high
    // padding. The read is sized by the (static) pad-result shape, so
    // out-of-bounds lanes of a smaller source are filled with `padValue`.
    SmallVector<Value> readIndices(
        vecRank, arith::ConstantIndexOp::create(rewriter, padOp.getLoc(), 0));
    auto read = vector::TransferReadOp::create(rewriter, padOp.getLoc(),
                                               vecType, padOp.getSource(),
                                               readIndices, padValue);

    // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
    // specified offsets. Write is fully in-bounds because a InsertSliceOp's
    // source must fit into the destination at the specified offsets.
    auto writeIndices = getValueOrCreateConstantIndexOp(
        rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
    SmallVector<bool> inBounds(vecRank, true);
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        insertOp, read, insertOp.getDest(), writeIndices,
        ArrayRef<bool>{inBounds});

    return success();
  }
};
3306
3308 RewritePatternSet &patterns, PatternBenefit baseBenefit) {
3309 patterns.add<PadOpVectorizationWithTransferReadPattern,
3310 PadOpVectorizationWithTransferWritePattern,
3311 PadOpVectorizationWithInsertSlicePattern>(
3312 patterns.getContext(), baseBenefit.getBenefit() + 1);
3313}
3314
3315//----------------------------------------------------------------------------//
3316// Forwarding patterns
3317//----------------------------------------------------------------------------//
3318
3319/// Check whether there is any interleaved use of any `values` between
3320/// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
3321/// is in a different block.
3322static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
3323 ValueRange values) {
3324 if (firstOp->getBlock() != secondOp->getBlock() ||
3325 !firstOp->isBeforeInBlock(secondOp)) {
3326 LDBG() << "interleavedUses precondition failed, firstOp: " << *firstOp
3327 << ", second op: " << *secondOp;
3328 return true;
3329 }
3330 for (auto v : values) {
3331 for (auto &u : v.getUses()) {
3332 Operation *owner = u.getOwner();
3333 if (owner == firstOp || owner == secondOp)
3334 continue;
3335 // TODO: this is too conservative, use dominance info in the future.
3336 if (owner->getBlock() == firstOp->getBlock() &&
3337 (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
3338 continue;
3339 LDBG() << " found interleaved op " << *owner << ", firstOp: " << *firstOp
3340 << ", second op: " << *secondOp;
3341 return true;
3342 }
3343 }
3344 return false;
3345}
3346
3347/// Return the unique subview use of `v` if it is indeed unique, null
3348/// otherwise.
3349static memref::SubViewOp getSubViewUseIfUnique(Value v) {
3350 memref::SubViewOp subViewOp;
3351 for (auto &u : v.getUses()) {
3352 if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
3353 if (subViewOp)
3354 return memref::SubViewOp();
3355 subViewOp = newSubViewOp;
3356 }
3357 }
3358 return subViewOp;
3359}
3360
3361/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
3362/// when available.
3364 vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
3365
3366 // TODO: support mask.
3367 if (xferOp.getMask())
3368 return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
3369
3370 // Transfer into `view`.
3371 Value viewOrAlloc = xferOp.getBase();
3372 if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
3373 !viewOrAlloc.getDefiningOp<memref::AllocOp>())
3374 return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
3375
3376 // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
3377 memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
3378 if (!subViewOp)
3379 return rewriter.notifyMatchFailure(xferOp, "no subview found");
3380 Value subView = subViewOp.getResult();
3381
3382 // Find the copy into `subView` without interleaved uses.
3383 memref::CopyOp copyOp;
3384 for (auto &u : subView.getUses()) {
3385 if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
3386 assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
3387 if (newCopyOp.getTarget() != subView)
3388 continue;
3389 if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
3390 continue;
3391 copyOp = newCopyOp;
3392 break;
3393 }
3394 }
3395 if (!copyOp)
3396 return rewriter.notifyMatchFailure(xferOp, "no copy found");
3397
3398 // Find the fill into `viewOrAlloc` without interleaved uses before the
3399 // copy.
3400 FillOp maybeFillOp;
3401 for (auto &u : viewOrAlloc.getUses()) {
3402 if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
3403 assert(isa<MemRefType>(newFillOp.output().getType()));
3404 if (newFillOp.output() != viewOrAlloc)
3405 continue;
3406 if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
3407 continue;
3408 maybeFillOp = newFillOp;
3409 break;
3410 }
3411 }
3412 // Ensure padding matches.
3413 if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
3414 return rewriter.notifyMatchFailure(xferOp,
3415 "padding value does not match fill");
3416
3417 // `in` is the subview that memref.copy reads. Replace it.
3418 Value in = copyOp.getSource();
3419
3420 // memref.copy + linalg.fill can be used to create a padded local buffer.
3421 // The `masked` attribute is only valid on this padded buffer.
3422 // When forwarding to vector.transfer_read, the attribute must be reset
3423 // conservatively.
3424 auto vectorType = xferOp.getVectorType();
3425 Value res = vector::TransferReadOp::create(
3426 rewriter, xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
3427 xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
3428 rewriter.getBoolArrayAttr(
3429 SmallVector<bool>(vectorType.getRank(), false)));
3430
3431 if (maybeFillOp)
3432 rewriter.eraseOp(maybeFillOp);
3433 rewriter.eraseOp(copyOp);
3434 rewriter.replaceOp(xferOp, res);
3435
3436 return success();
3437}
3438
3439/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
3440/// when available.
3442 vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
3443 // TODO: support mask.
3444 if (xferOp.getMask())
3445 return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
3446
3447 // Transfer into `viewOrAlloc`.
3448 Value viewOrAlloc = xferOp.getBase();
3449 if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
3450 !viewOrAlloc.getDefiningOp<memref::AllocOp>())
3451 return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
3452
3453 // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
3454 memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
3455 if (!subViewOp)
3456 return rewriter.notifyMatchFailure(xferOp, "no subview found");
3457 Value subView = subViewOp.getResult();
3458
3459 // Find the copy from `subView` without interleaved uses.
3460 memref::CopyOp copyOp;
3461 for (auto &u : subViewOp.getResult().getUses()) {
3462 if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
3463 if (newCopyOp.getSource() != subView)
3464 continue;
3465 if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
3466 continue;
3467 copyOp = newCopyOp;
3468 break;
3469 }
3470 }
3471 if (!copyOp)
3472 return rewriter.notifyMatchFailure(xferOp, "no copy found");
3473
3474 // `out` is the subview copied into that we replace.
3475 assert(isa<MemRefType>(copyOp.getTarget().getType()));
3476 Value out = copyOp.getTarget();
3477
3478 // Forward vector.transfer into copy.
3479 // memref.copy + linalg.fill can be used to create a padded local buffer.
3480 // The `masked` attribute is only valid on this padded buffer.
3481 // When forwarding to vector.transfer_write, the attribute must be reset
3482 // conservatively.
3483 auto vector = xferOp.getVector();
3484 vector::TransferWriteOp::create(
3485 rewriter, xferOp.getLoc(), vector, out, xferOp.getIndices(),
3486 xferOp.getPermutationMapAttr(), xferOp.getMask(),
3487 rewriter.getBoolArrayAttr(SmallVector<bool>(
3488 dyn_cast<VectorType>(vector.getType()).getRank(), false)));
3489
3490 rewriter.eraseOp(copyOp);
3491 rewriter.eraseOp(xferOp);
3492
3493 return success();
3494}
3495
3496//===----------------------------------------------------------------------===//
3497// Convolution vectorization patterns
3498//===----------------------------------------------------------------------===//
3499
/// Recursion base case: no more values to bind.
template <int N>
static void bindShapeDims(ShapedType shapedType) {}

/// Recursive step: bind dimension `N` of `shapedType` to `val`, then recurse
/// on the remaining pack starting at dimension `N + 1`.
template <int N, typename IntTy, typename... IntTy2>
static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
  val = shapedType.getShape()[N];
  bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
}

/// Bind a pack of int& to the leading dimensions of shapedType.getShape().
template <typename... IntTy>
static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
  bindShapeDims<0>(shapedType, vals...);
}
3514
3515/// Match 1D convolution or pooling operations and return their dilations and
3516/// strides. Returns std::nullopt for unrecognized ops.
3517static std::optional<DilationsAndStrides> match1DConvPoolOp(LinalgOp op) {
3518#define MATCH_1D_CONV_POOL_OP(ConvOpTy) \
3519 if (auto convParams = matchConvolutionOpOfType<ConvOpTy>(op)) \
3520 return convParams;
3521
3522 // 1D Convolution ops.
3523 MATCH_1D_CONV_POOL_OP(linalg::Conv1DOp);
3524 MATCH_1D_CONV_POOL_OP(linalg::Conv1DNwcWcfOp);
3525 MATCH_1D_CONV_POOL_OP(linalg::Conv1DNcwFcwOp);
3526 // Depthwise 1D Convolution ops.
3527 // Note: Only NWC layout without channel multiplier is supported.
3528 // DepthwiseConv1DNcwCwOp (NCW) and DepthwiseConv1DNwcWcmOp (with multiplier)
3529 // are not supported.
3530 MATCH_1D_CONV_POOL_OP(linalg::DepthwiseConv1DNwcWcOp);
3531 // 1D Pooling ops (NWC layout).
3532 MATCH_1D_CONV_POOL_OP(linalg::PoolingNwcSumOp);
3533 MATCH_1D_CONV_POOL_OP(linalg::PoolingNwcMaxOp);
3534 MATCH_1D_CONV_POOL_OP(linalg::PoolingNwcMaxUnsignedOp);
3535 MATCH_1D_CONV_POOL_OP(linalg::PoolingNwcMinOp);
3536 MATCH_1D_CONV_POOL_OP(linalg::PoolingNwcMinUnsignedOp);
3537 // 1D Pooling ops (NCW layout).
3538 MATCH_1D_CONV_POOL_OP(linalg::PoolingNcwSumOp);
3539 MATCH_1D_CONV_POOL_OP(linalg::PoolingNcwMaxOp);
3540
3541#undef MATCH_1D_CONV_POOL_OP
3542
3543 return std::nullopt;
3544}
3545
3546namespace {
3547/// Generate a vector implementation for either:
3548/// ```
3549/// Op def: ( w, kw )
3550/// Iters: ({Par(), Red()})
3551/// Layout: {{w + kw}, {kw}, {w}}
3552/// ```
3553/// kw is unrolled.
3554///
3555/// or
3556///
3557/// ```
3558/// Op def: ( n, w, c, kw, f )
3559/// Iters: ({Par(), Par(), Par(), Red(), Red()})
3560/// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3561/// ```
3562/// kw is unrolled, w is unrolled iff dilationW > 1.
3563///
3564/// or
3565///
3566/// ```
3567/// Op def: ( n, c, w, f, kw )
3568/// Iters: ({Par(), Par(), Par(), Red(), Red()})
3569/// Layout: {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
3570/// ```
3571/// kw is unrolled, w is unrolled iff dilationW > 1.
3572///
3573/// or
3574///
3575/// ```
3576/// Op def: ( n, w, c, kw )
3577/// Iters: ({Par(), Par(), Par(), Red()})
3578/// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3579/// ```
3580/// kw is unrolled, w is unrolled iff dilationW > 1.
3581struct Conv1DGenerator
3582 : public StructuredGenerator<LinalgOp, utils::IteratorType> {
3583 /// Factory method to create a Conv1DGenerator. Returns failure if the
3584 /// operation doesn't have valid strides/dilations.
3585 static FailureOr<Conv1DGenerator> create(RewriterBase &rewriter,
3586 LinalgOp linalgOp) {
3587 // Try to match a 1D conv/pool op using matchConvolutionOpOfType. This
3588 // works for both named ops and generic ops that match their semantics.
3589 std::optional<DilationsAndStrides> convParams = match1DConvPoolOp(linalgOp);
3590 if (!convParams)
3591 return failure();
3592
3593 int strideW = static_cast<int>(convParams->strides.front());
3594 int dilationW = static_cast<int>(convParams->dilations.front());
3595 return Conv1DGenerator(rewriter, linalgOp, strideW, dilationW);
3596 }
3597
private:
  /// Private constructor; use the `create` factory, which validates that
  /// `linalgOp` is a recognized 1D conv/pool with valid strides/dilations.
  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp, int strideW,
                  int dilationW)
      : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp),
        strideW(strideW), dilationW(dilationW) {

    // Cache the operands (input, filter, output) and their shaped types.
    lhsShaped = linalgOp.getDpsInputOperand(0)->get();
    rhsShaped = linalgOp.getDpsInputOperand(1)->get();
    resShaped = linalgOp.getDpsInitOperand(0)->get();
    lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
    rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
    resShapedType = dyn_cast<ShapedType>(resShaped.getType());

    // Record the reduction op feeding the init operand; its name is re-used
    // when re-creating the reduction on vectors.
    Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
    redOp = reduceOp->getName().getIdentifier();

    setConvOperationKind(reduceOp);

    // `value()` asserts a combiner kind was found; presumably guaranteed
    // because `create` only constructs this for matched conv/pool ops —
    // TODO confirm.
    auto maybeKind = getCombinerOpKind(reduceOp);
    reductionKind = maybeKind.value();
  }
3619
3620public:
3621 /// Generate a vector implementation for:
3622 /// ```
3623 /// Op def: ( w, kw )
3624 /// Iters: ({Par(), Red()})
3625 /// Layout: {{w + kw}, {kw}, {w}}
3626 /// ```
3627 /// kw is always unrolled.
3628 ///
3629 /// or
3630 ///
3631 /// ```
3632 /// Op def: ( n, w, c, kw, f )
3633 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
3634 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3635 /// ```
3636 /// kw is always unrolled.
3637 /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
3638 /// > 1.
3639 FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
3640 int64_t nSize, wSize, cSize, kwSize, fSize;
3641 SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
3642 bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
3643 switch (conv1DOpOrder) {
3644 case Conv1DOpOrder::W:
3645 // Initialize unused dimensions
3646 nSize = fSize = cSize = 0;
3647 // out{W}
3648 bindShapeDims(resShapedType, wSize);
3649 // kernel{kw}
3650 bindShapeDims(rhsShapedType, kwSize);
3651 lhsShape = {// iw = ow + kw - 1
3652 // (i.e. 16 convolved with 3 -> 14)
3653 (wSize + kwSize - 1)};
3654 rhsShape = {kwSize};
3655 resShape = {wSize};
3656 break;
3657 case Conv1DOpOrder::Nwc:
3658 // out{n, w, f}
3659 bindShapeDims(resShapedType, nSize, wSize, fSize);
3660 switch (oper) {
3661 case ConvOperationKind::Conv:
3662 // kernel{kw, c, f}
3663 bindShapeDims(rhsShapedType, kwSize, cSize);
3664 break;
3665 case ConvOperationKind::Pool:
3666 // kernel{kw}
3667 bindShapeDims(rhsShapedType, kwSize);
3668 cSize = fSize;
3669 break;
3670 }
3671 lhsShape = {nSize,
3672 // iw = ow * sw + kw * dw - 1
3673 // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3674 // Perform the proper inclusive -> exclusive -> inclusive.
3675 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3676 1,
3677 cSize};
3678 switch (oper) {
3679 case ConvOperationKind::Conv:
3680 rhsShape = {kwSize, cSize, fSize};
3681 break;
3682 case ConvOperationKind::Pool:
3683 rhsShape = {kwSize};
3684 break;
3685 }
3686 resShape = {nSize, wSize, fSize};
3687 break;
3688 case Conv1DOpOrder::Ncw:
3689 // out{n, f, w}
3690 bindShapeDims(resShapedType, nSize, fSize, wSize);
3691 switch (oper) {
3692 case ConvOperationKind::Conv:
3693 // kernel{f, c, kw}
3694 bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
3695 break;
3696 case ConvOperationKind::Pool:
3697 // kernel{kw}
3698 bindShapeDims(rhsShapedType, kwSize);
3699 cSize = fSize;
3700 break;
3701 }
3702 lhsShape = {nSize, cSize,
3703 // iw = ow * sw + kw * dw - 1
3704 // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3705 // Perform the proper inclusive -> exclusive -> inclusive.
3706 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3707 1};
3708 switch (oper) {
3709 case ConvOperationKind::Conv:
3710 rhsShape = {fSize, cSize, kwSize};
3711 break;
3712 case ConvOperationKind::Pool:
3713 rhsShape = {kwSize};
3714 break;
3715 }
3716 resShape = {nSize, fSize, wSize};
3717 break;
3718 }
3719
3720 vector::TransferWriteOp write;
3721 Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
3722
3723 // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3724 // When strideW == 1, we can batch the contiguous loads and avoid
3725 // unrolling
3726 int64_t wSizeStep = strideW == 1 ? wSize : 1;
3727
3728 Type lhsEltType = lhsShapedType.getElementType();
3729 Type rhsEltType = rhsShapedType.getElementType();
3730 Type resEltType = resShapedType.getElementType();
3731 auto lhsType = VectorType::get(lhsShape, lhsEltType);
3732 auto rhsType = VectorType::get(rhsShape, rhsEltType);
3733 auto resType = VectorType::get(resShape, resEltType);
3734 // Zero padding with the corresponding dimensions for lhs, rhs and res.
3735 SmallVector<Value> lhsPadding(lhsShape.size(), zero);
3736 SmallVector<Value> rhsPadding(rhsShape.size(), zero);
3737 SmallVector<Value> resPadding(resShape.size(), zero);
3738
3739 // Read the whole lhs, rhs and res in one shot (with zero padding).
3740 Value lhs = vector::TransferReadOp::create(
3741 rewriter, loc, lhsType, lhsShaped, lhsPadding,
3742 /*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
3743 // This is needed only for Conv.
3744 Value rhs = nullptr;
3745 if (oper == ConvOperationKind::Conv)
3746 rhs = vector::TransferReadOp::create(
3747 rewriter, loc, rhsType, rhsShaped, rhsPadding,
3748 /*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
3749 Value res = vector::TransferReadOp::create(
3750 rewriter, loc, resType, resShaped, resPadding,
3751 /*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
3752
3753 // The base vectorization case for channeled convolution is input:
3754 // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
3755 // vectorization case, we do pre transpose on input, weight, and output.
3756 switch (conv1DOpOrder) {
3757 case Conv1DOpOrder::W:
3758 case Conv1DOpOrder::Nwc:
3759 // Base case, so no transposes necessary.
3760 break;
3761 case Conv1DOpOrder::Ncw: {
3762 // To match base vectorization case, we pre-transpose current case.
3763 // ncw -> nwc
3764 static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
3765 lhs = vector::TransposeOp::create(rewriter, loc, lhs, permLhs);
3766 // fcw -> wcf
3767 static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
3768
3769 // This is needed only for Conv.
3770 if (oper == ConvOperationKind::Conv)
3771 rhs = vector::TransposeOp::create(rewriter, loc, rhs, permRhs);
3772 // nfw -> nwf
3773 static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
3774 res = vector::TransposeOp::create(rewriter, loc, res, permRes);
3775 break;
3776 }
3777 }
3778
3779 //===------------------------------------------------------------------===//
3780 // Begin vector-only rewrite part
3781 //===------------------------------------------------------------------===//
3782 // Unroll along kw and read slices of lhs and rhs.
3783 SmallVector<Value> lhsVals, rhsVals, resVals;
3784 lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
3785 kwSize, strideW, dilationW, wSizeStep,
3786 isSingleChanneled);
3787 // Do not do for pooling.
3788 if (oper == ConvOperationKind::Conv)
3789 rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
3790 resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
3791 wSizeStep, isSingleChanneled);
3792
3793 auto linearIndex = [&](int64_t kw, int64_t w) {
3794 return kw * (wSize / wSizeStep) + w;
3795 };
3796
3797 // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f}
3798 // or perform outerproduct for non-channeled convolution or perform simple
3799 // arith operation for pooling
3800 for (int64_t kw = 0; kw < kwSize; ++kw) {
3801 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3802 switch (oper) {
3803 case ConvOperationKind::Conv:
3804 if (isSingleChanneled) {
3805 resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
3806 lhsVals[linearIndex(kw, w)],
3807 rhsVals[kw], resVals[w]);
3808 } else {
3809 resVals[w] = conv1dSliceAsContraction(rewriter, loc,
3810 lhsVals[linearIndex(kw, w)],
3811 rhsVals[kw], resVals[w]);
3812 }
3813 break;
3814 case ConvOperationKind::Pool:
3815 resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
3816 resVals[w]);
3817 break;
3818 }
3819 }
3820 }
3821
3822 res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
3823 isSingleChanneled);
3824 //===------------------------------------------------------------------===//
3825 // End vector-only rewrite part
3826 //===------------------------------------------------------------------===//
3827
3828 // The base vectorization case for channeled convolution is output:
3829 // {n,w,f} To reuse the result from base pattern vectorization case, we
3830 // post transpose the base case result.
3831 switch (conv1DOpOrder) {
3832 case Conv1DOpOrder::W:
3833 case Conv1DOpOrder::Nwc:
3834 // Base case, so no transposes necessary.
3835 break;
3836 case Conv1DOpOrder::Ncw: {
3837 // nwf -> nfw
3838 static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
3839 res = vector::TransposeOp::create(rewriter, loc, res, perm);
3840 break;
3841 }
3842 }
3843
3844 return vector::TransferWriteOp::create(rewriter, loc, res, resShaped,
3845 resPadding)
3846 .getOperation();
3847 }
3848
3849 // Take a value and widen to have the same element type as `ty`.
3850 Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
3851 const Type srcElementType = getElementTypeOrSelf(val.getType());
3852 const Type dstElementType = getElementTypeOrSelf(ty);
3853 assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
3854 if (srcElementType == dstElementType)
3855 return val;
3856
3857 const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
3858 const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
3859 // Handle both shaped as well as scalar types.
3860 Type dstType;
3861 if (auto shapedType = dyn_cast<ShapedType>(val.getType()))
3862 dstType = shapedType.cloneWith(std::nullopt, dstElementType);
3863 else
3864 dstType = dstElementType;
3865
3866 if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
3867 return arith::SIToFPOp::create(rewriter, loc, dstType, val);
3868 }
3869
3870 if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
3871 srcWidth < dstWidth)
3872 return arith::ExtFOp::create(rewriter, loc, dstType, val);
3873
3874 if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
3875 srcWidth < dstWidth)
3876 return arith::ExtSIOp::create(rewriter, loc, dstType, val);
3877
3878 assert(false && "unhandled promotion case");
3879 return nullptr;
3880 }
3881
3882 // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
3883 Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
3884 Value lhs, Value rhs, Value res) {
3885 vector::IteratorType par = vector::IteratorType::parallel;
3886 vector::IteratorType red = vector::IteratorType::reduction;
3887 AffineExpr n, w, f, c;
3888 bindDims(ctx, n, w, f, c);
3889 lhs = promote(rewriter, loc, lhs, res.getType());
3890 rhs = promote(rewriter, loc, rhs, res.getType());
3891 auto contrationOp = vector::ContractionOp::create(
3892 rewriter, loc, lhs, rhs, res,
3893 /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
3894 /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
3895 contrationOp.setKind(reductionKind);
3896 return contrationOp;
3897 }
3898
3899 // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel
3900 // convolution.
3901 Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
3902 Value lhs, Value rhs, Value res) {
3903 lhs = promote(rewriter, loc, lhs, res.getType());
3904 rhs = promote(rewriter, loc, rhs, res.getType());
3905 return vector::OuterProductOp::create(rewriter, loc, res.getType(), lhs,
3906 rhs, res, vector::CombiningKind::ADD);
3907 }
3908
  // Create a reduction: lhs{n, w, c} -> res{n, w, c}
  Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
                    Value res) {
    // When the matched pooling body widens the input (isPoolExt), first
    // re-create the extension op (poolExtOp) on the vector slice —
    // presumably set up by setConvOperationKind; TODO confirm.
    if (isPoolExt)
      lhs = rewriter.create(loc, poolExtOp, lhs, res.getType())->getResult(0);
    // Re-create the matched combiner op (redOp: max/min/add/...) on
    // vectors, dynamically by OperationName.
    return rewriter
        .create(loc, redOp, ArrayRef<Value>{lhs, res}, res.getType())
        ->getResult(0);
  }
3918
3919 /// Generate a vector implementation for:
3920 /// ```
3921 /// Op def: ( n, w, c, kw)
3922 /// Iters: ({Par(), Par(), Par(), Red()})
3923 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3924 /// ```
3925 /// kw is always unrolled.
3926 /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
3927 /// > 1.
3928 FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
3929 bool channelDimScalableFlag,
3930 bool flatten) {
3931 bool scalableChDim = false;
3932 bool useMasking = false;
3933 int64_t nSize, wSize, cSize, kwSize;
3934 // kernel{kw, c}
3935 bindShapeDims(rhsShapedType, kwSize, cSize);
3936 if (ShapedType::isDynamic(cSize)) {
3937 assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
3938 cSize = channelDimVecSize;
3939 // Scalable vectors are only used when both conditions are met:
3940 // 1. channel dim is dynamic
3941 // 2. channelDimScalableFlag is set
3942 scalableChDim = channelDimScalableFlag;
3943 useMasking = true;
3944 }
3945
3946 assert(!(useMasking && flatten) &&
3947 "Unsupported flattened conv with dynamic shapes");
3948
3949 // out{n, w, c}
3950 bindShapeDims(resShapedType, nSize, wSize);
3951
3952 vector::TransferWriteOp write;
3953 Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
3954
3955 // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3956 // When strideW == 1, we can batch the contiguous loads and avoid
3957 // unrolling
3958 int64_t wSizeStep = strideW == 1 ? wSize : 1;
3959
3960 Type lhsEltType = lhsShapedType.getElementType();
3961 Type rhsEltType = rhsShapedType.getElementType();
3962 Type resEltType = resShapedType.getElementType();
3963 VectorType lhsType = VectorType::get(
3964 {nSize,
3965 // iw = ow * sw + kw * dw - 1
3966 // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3967 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
3968 cSize},
3969 lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
3970 VectorType rhsType =
3971 VectorType::get({kwSize, cSize}, rhsEltType,
3972 /*scalableDims=*/{false, scalableChDim});
3973 VectorType resType =
3974 VectorType::get({nSize, wSize, cSize}, resEltType,
3975 /*scalableDims=*/{false, false, scalableChDim});
3976
3977 // Masks the input xfer Op along the channel dim, iff the corresponding
3978 // scalable flag is set.
3979 auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
3980 ArrayRef<bool> scalableDims,
3981 Operation *opToMask) {
3982 if (!useMasking)
3983 return opToMask;
3984 auto maskType =
3985 VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
3986
3987 SmallVector<bool> inBounds(maskShape.size(), true);
3988 auto xferOp = cast<VectorTransferOpInterface>(opToMask);
3989 xferOp->setAttr(xferOp.getInBoundsAttrName(),
3990 rewriter.getBoolArrayAttr(inBounds));
3991
3992 SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
3993 cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);
3994
3995 Value maskOp =
3996 vector::CreateMaskOp::create(rewriter, loc, maskType, mixedDims);
3997
3998 return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
3999 };
4000
4001 // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
4002 // 0].
4003 Value lhs = vector::TransferReadOp::create(
4004 rewriter, loc, lhsType, lhsShaped, ValueRange{zero, zero, zero},
4005 /*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
4006 auto *maybeMaskedLhs = maybeMaskXferOp(
4007 lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
4008
4009 // Read rhs slice of size {kw, c} @ [0, 0].
4010 Value rhs = vector::TransferReadOp::create(
4011 rewriter, loc, rhsType, rhsShaped, ValueRange{zero, zero},
4012 /*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
4013 auto *maybeMaskedRhs = maybeMaskXferOp(
4014 rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
4015
4016 // Read res slice of size {n, w, c} @ [0, 0, 0].
4017 Value res = vector::TransferReadOp::create(
4018 rewriter, loc, resType, resShaped, ValueRange{zero, zero, zero},
4019 /*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
4020 auto *maybeMaskedRes = maybeMaskXferOp(
4021 resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
4022
4023 //===------------------------------------------------------------------===//
4024 // Begin vector-only rewrite part
4025 //===------------------------------------------------------------------===//
4026 // Unroll along kw and read slices of lhs and rhs.
4027 SmallVector<Value> lhsVals, rhsVals, resVals;
4028 SmallVector<int64_t> inOutSliceSizes = {nSize, wSizeStep, cSize};
4029 SmallVector<int64_t> inOutStrides = {1, 1, 1};
4030
4031 // Extract lhs slice of size {n, wSizeStep, c}
4032 // @ [0, sw * w + dw * kw, 0].
4033 for (int64_t kw = 0; kw < kwSize; ++kw) {
4034 for (int64_t w = 0; w < wSize; w += wSizeStep) {
4035 lhsVals.push_back(vector::ExtractStridedSliceOp::create(
4036 rewriter, loc, maybeMaskedLhs->getResult(0),
4037 /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
4038 inOutSliceSizes, inOutStrides));
4039 }
4040 }
4041 // Extract rhs slice of size {c} @ [kw].
4042 for (int64_t kw = 0; kw < kwSize; ++kw) {
4043 rhsVals.push_back(
4044 vector::ExtractOp::create(rewriter, loc, maybeMaskedRhs->getResult(0),
4045 /*offsets=*/ArrayRef<int64_t>{kw}));
4046 }
4047 // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
4048 for (int64_t w = 0; w < wSize; w += wSizeStep) {
4049 resVals.push_back(vector::ExtractStridedSliceOp::create(
4050 rewriter, loc, maybeMaskedRes->getResult(0),
4051 /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
4052 inOutStrides));
4053 }
4054
4055 auto linearIndex = [&](int64_t kw, int64_t w) {
4056 return kw * (wSize / wSizeStep) + w;
4057 };
4058
4059 // Note - the scalable flags are ignored as flattening combined with
4060 // scalable vectorization is not supported.
4061 SmallVector<int64_t> inOutFlattenSliceSizes = {nSize, wSizeStep * cSize};
4062 auto lhsTypeAfterFlattening =
4063 VectorType::get(inOutFlattenSliceSizes, lhsEltType);
4064 auto resTypeAfterFlattening =
4065 VectorType::get(inOutFlattenSliceSizes, resEltType);
4066
4067 // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
4068 for (int64_t kw = 0; kw < kwSize; ++kw) {
4069 for (int64_t w = 0; w < wSize; w += wSizeStep) {
4070 Value lhsVal = lhsVals[linearIndex(kw, w)];
4071 Value resVal = resVals[w];
4072 if (flatten) {
4073 // Flatten the input and output vectors (collapse the channel
4074 // dimension)
4075 lhsVal =
4076 vector::ShapeCastOp::create(rewriter, loc, lhsTypeAfterFlattening,
4077 lhsVals[linearIndex(kw, w)]);
4078 resVal = vector::ShapeCastOp::create(
4079 rewriter, loc, resTypeAfterFlattening, resVals[w]);
4080 }
4081 resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
4082 rhsVals[kw], resVal, flatten);
4083 if (flatten) {
4084 // Un-flatten the output vector (restore the channel dimension)
4085 resVals[w] = vector::ShapeCastOp::create(
4086 rewriter, loc, VectorType::get(inOutSliceSizes, resEltType),
4087 resVals[w]);
4088 }
4089 }
4090 }
4091
4092 // Its possible we failed to create the Fma.
4093 if (!llvm::all_of(resVals, [](Value v) { return v; })) {
4094 // Manually revert (in reverse order) to avoid leaving a bad IR state.
4095 for (auto &collection :
4096 {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
4097 for (Value v : collection)
4098 rewriter.eraseOp(v.getDefiningOp());
4099 return rewriter.notifyMatchFailure(op, "failed to create FMA");
4100 }
4101
4102 // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
4103 // This does not depend on kw.
4104 for (int64_t w = 0; w < wSize; w += wSizeStep) {
4105 maybeMaskedRes = vector::InsertStridedSliceOp::create(
4106 rewriter, loc, resVals[w], maybeMaskedRes->getResult(0),
4107 /*offsets=*/ArrayRef<int64_t>{0, w, 0},
4108 /*strides=*/ArrayRef<int64_t>{1, 1, 1});
4109 }
4110 //===------------------------------------------------------------------===//
4111 // End vector-only rewrite part
4112 //===------------------------------------------------------------------===//
4113
4114 // Write back res slice of size {n, w, c} @ [0, 0, 0].
4115 Operation *resOut = vector::TransferWriteOp::create(
4116 rewriter, loc, maybeMaskedRes->getResult(0), resShaped,
4117 ValueRange{zero, zero, zero});
4118 return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
4119 resOut);
4120 }
4121
4122 /// Lower:
4123 /// * lhs{n, w, c} * rhs{c} -> res{n, w, c} (flatten = false)
4124 /// * lhs{n, w * c} * rhs{c} -> res{n, w * c} (flatten = true)
4125 /// to MulAcc.
4126 Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
4127 Value lhs, Value rhs, Value res,
4128 bool flatten) {
4129 auto rhsTy = cast<ShapedType>(rhs.getType());
4130 auto resTy = cast<ShapedType>(res.getType());
4131
4132 // TODO(suderman): Change this to use a vector.ima intrinsic.
4133 lhs = promote(rewriter, loc, lhs, resTy);
4134
4135 if (flatten) {
4136 // NOTE: This following logic won't work for scalable vectors. For this
4137 // reason, "flattening" is not supported when shapes are dynamic (this
4138 // should be captured by one of the pre-conditions).
4139
4140 // There are two options for handling the filter:
4141 // * shape_cast(broadcast(filter))
4142 // * broadcast(shuffle(filter))
4143 // Opt for the option without shape_cast to simplify the codegen.
4144 auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
4145 auto resSize = cast<VectorType>(res.getType()).getShape()[1];
4146
4147 SmallVector<int64_t, 16> indices;
4148 for (int i = 0; i < resSize / rhsSize; ++i) {
4149 for (int j = 0; j < rhsSize; ++j)
4150 indices.push_back(j);
4151 }
4152
4153 rhs = vector::ShuffleOp::create(rewriter, loc, rhs, rhs, indices);
4154 }
4155 // Broadcast the filter to match the output vector
4156 rhs = vector::BroadcastOp::create(rewriter, loc,
4157 resTy.clone(rhsTy.getElementType()), rhs);
4158
4159 rhs = promote(rewriter, loc, rhs, resTy);
4160
4161 if (!lhs || !rhs)
4162 return nullptr;
4163
4164 if (isa<FloatType>(resTy.getElementType()))
4165 return vector::FMAOp::create(rewriter, loc, lhs, rhs, res);
4166
4167 auto mul = arith::MulIOp::create(rewriter, loc, lhs, rhs);
4168 return arith::AddIOp::create(rewriter, loc, mul, res);
4169 }
4170
4171 /// Entry point for non-channeled convolution:
4172 /// {{w + kw}, {kw}, {w}}
4173 FailureOr<Operation *> generateNonChanneledConv() {
4174 AffineExpr w, kw;
4175 bindDims(ctx, w, kw);
4176 if (!iters({Par(), Red()}))
4177 return rewriter.notifyMatchFailure(op,
4178 "failed to match conv::W 1-par 1-red");
4179
4180 // No transposition needed.
4181 if (layout({/*lhsIndex*/ {w + kw},
4182 /*rhsIndex*/ {kw},
4183 /*resIndex*/ {w}}))
4184 return conv(Conv1DOpOrder::W);
4185
4186 return rewriter.notifyMatchFailure(op, "not a conv::W layout");
4187 }
4188
4189 /// Entry point that transposes into the common form:
4190 /// {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
4191 FailureOr<Operation *> generateNwcConv() {
4192 AffineExpr n, w, f, kw, c;
4193 bindDims(ctx, n, w, f, kw, c);
4194 if (!iters({Par(), Par(), Par(), Red(), Red()}))
4195 return rewriter.notifyMatchFailure(
4196 op, "failed to match conv::Nwc 3-par 2-red");
4197
4198 // No transposition needed.
4199 if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
4200 /*rhsIndex*/ {kw, c, f},
4201 /*resIndex*/ {n, w, f}}))
4202 return conv(Conv1DOpOrder::Nwc);
4203
4204 return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
4205 }
4206
4207 /// Entry point that transposes into the common form:
4208 /// {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
4209 FailureOr<Operation *> generateNcwConv() {
4210 AffineExpr n, w, f, kw, c;
4211 bindDims(ctx, n, f, w, c, kw);
4212 if (!iters({Par(), Par(), Par(), Red(), Red()}))
4213 return rewriter.notifyMatchFailure(
4214 op, "failed to match conv::Ncw 3-par 2-red");
4215
4216 if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
4217 /*rhsIndex*/ {f, c, kw},
4218 /*resIndex*/ {n, f, w}}))
4219 return conv(Conv1DOpOrder::Ncw);
4220
4221 return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
4222 }
4223
4224 /// Entry point that transposes into the common form:
4225 /// {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}} for pooling
4226 FailureOr<Operation *> generateNwcPooling() {
4227 AffineExpr n, w, c, kw;
4228 bindDims(ctx, n, w, c, kw);
4229 if (!iters({Par(), Par(), Par(), Red()}))
4230 return rewriter.notifyMatchFailure(op,
4231 "failed to match pooling 3-par 1-red");
4232
4233 // No transposition needed.
4234 if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
4235 /*rhsIndex*/ {kw},
4236 /*resIndex*/ {n, w, c}}))
4237 return conv(Conv1DOpOrder::Nwc);
4238
4239 return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
4240 }
4241
4242 /// Entry point that transposes into the common form:
4243 /// {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}} for pooling
4244 FailureOr<Operation *> generateNcwPooling() {
4245 AffineExpr n, w, c, kw;
4246 bindDims(ctx, n, c, w, kw);
4247 if (!iters({Par(), Par(), Par(), Red()}))
4248 return rewriter.notifyMatchFailure(op,
4249 "failed to match pooling 3-par 1-red");
4250
4251 if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
4252 /*rhsIndex*/ {kw},
4253 /*resIndex*/ {n, c, w}}))
4254 return conv(Conv1DOpOrder::Ncw);
4255
4256 return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
4257 }
4258
4259 /// Entry point that transposes into the common form:
4260 /// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
4261 FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
4262 bool vecChDimScalableFlag = false,
4263 bool flatten = false) {
4264 AffineExpr n, w, c, kw;
4265 bindDims(ctx, n, w, c, kw);
4266 if (!iters({Par(), Par(), Par(), Red()}))
4267 return rewriter.notifyMatchFailure(
4268 op, "failed to match depthwise::Nwc conv 3-par 1-red");
4269
4270 // No transposition needed.
4271 if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
4272 /*rhsIndex*/ {kw, c},
4273 /*resIndex*/ {n, w, c}}))
4274 return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);
4275
4276 return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
4277 }
4278
private:
  // Conv vs pooling; set by `setConvOperationKind` from the reduction body.
  ConvOperationKind oper = ConvOperationKind::Conv;
  // Name of the reduction op. NOTE(review): not referenced in the visible
  // portion of this class — presumably set during construction; confirm.
  StringAttr redOp;
  // Name of the cast op extending the pooled value (e.g. ext*), valid only
  // when `isPoolExt` is true; set by `setConvOperationKind`.
  StringAttr poolExtOp;
  // True when the pooling op extends its input through a cast of the block
  // argument; set by `setConvOperationKind`.
  bool isPoolExt = false;
  // Convolution stride and dilation along the W dim, used when matching the
  // indexing-map layouts (strideW * w + dilationW * kw).
  int strideW, dilationW;
  // The matched op's input, filter and output operands.
  Value lhsShaped, rhsShaped, resShaped;
  // Shaped types of the operands above.
  ShapedType lhsShapedType, rhsShapedType, resShapedType;
  // Vector combining kind of the reduction. NOTE(review): not referenced in
  // the visible portion of this class — confirm against the full file.
  vector::CombiningKind reductionKind;
4288
4289 // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
4290 void setConvOperationKind(Operation *reduceOp) {
4291 int numBlockArguments =
4292 llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
4293 if (numBlockArguments == 1) {
4294 // Will be convolution if feeder is a MulOp.
4295 // A strength reduced version of MulOp for i1 type is AndOp which is also
4296 // supported. Otherwise, it can be pooling. This strength reduction logic
4297 // is in `buildBinaryFn` helper in the Linalg dialect.
4298 auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
4299 llvm::IsaPred<BlockArgument>);
4300 Operation *feedOp = (*feedValIt).getDefiningOp();
4301 if (isCastOfBlockArgument(feedOp)) {
4302 oper = ConvOperationKind::Pool;
4303 isPoolExt = true;
4304 poolExtOp = feedOp->getName().getIdentifier();
4305 return;
4306 }
4307 oper = ConvOperationKind::Conv;
4308 return;
4309 }
4310 // numBlockArugments == 2 and this is a pooling op.
4311 oper = ConvOperationKind::Pool;
4312 isPoolExt = false;
4313 }
4314};
4315} // namespace
4316
4317/// Helper function to vectorize a LinalgOp with convolution semantics.
4318// TODO: extend the generic vectorization to support windows and drop this.
4319static FailureOr<Operation *> vectorizeConvolution(
4320 RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
4321 ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
4322 FailureOr<Conv1DGenerator> conv1dGen = Conv1DGenerator::create(rewriter, op);
4323 if (failed(conv1dGen))
4324 return failure();
4325 auto res = conv1dGen->generateNonChanneledConv();
4326 if (succeeded(res))
4327 return res;
4328 res = conv1dGen->generateNwcConv();
4329 if (succeeded(res))
4330 return res;
4331 res = conv1dGen->generateNcwConv();
4332 if (succeeded(res))
4333 return res;
4334 res = conv1dGen->generateNwcPooling();
4335 if (succeeded(res))
4336 return res;
4337 res = conv1dGen->generateNcwPooling();
4338 if (succeeded(res))
4339 return res;
4340
4341 // Only depthwise 1D NWC convs are left - these can be vectorized using masks
4342 // and scalable vectors. Note that ATM the only dim that can be dynamic (i.e.
4343 // masked/scalable) is the channel dim (i.e. the trailing dim).
4344 uint64_t vecChDimSize = ShapedType::kDynamic;
4345 bool vecChDimScalableFlag = false;
4346 if (!inputVecSizes.empty()) {
4347 // Only use the input vector size corresponding to the channel dim. Other
4348 // vector dims will be inferred from the Ops.
4351 "Not a 1D depthwise conv!");
4352 size_t chDimIdx = 0;
4354 chDimIdx = 2;
4356 chDimIdx = 1;
4357
4358 vecChDimSize = inputVecSizes[chDimIdx];
4359 vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
4360 }
4361 return conv1dGen->generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
4362 flatten1DDepthwiseConv);
4363}
4364
4365struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
4367
4368 LogicalResult matchAndRewrite(LinalgOp op,
4369 PatternRewriter &rewriter) const override {
4370 FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
4371 if (failed(resultOrFail))
4372 return failure();
4373 Operation *newOp = *resultOrFail;
4374 if (newOp->getNumResults() == 0) {
4375 rewriter.eraseOp(op.getOperation());
4376 return success();
4377 }
4378 assert(newOp->getNumResults() == 1 && "expected single result");
4379 rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
4380 return success();
4381 }
4382};
4383
4385 RewritePatternSet &patterns, PatternBenefit benefit) {
4386 patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
4387}
return success()
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static std::optional< VectorShape > vectorShape(Type type)
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val, VectorType resType)
Checks whether val can be used for calculating a loop invariant index.
static Value insertConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t wSize, int64_t wSizeStep, SmallVectorImpl< Value > &resVals, bool isSingleChanneled)
Helper function to insert the computed result slices.
static SmallVector< bool > getDimsToReduce(LinalgOp linalgOp)
static VectorMemoryAccessKind getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp, LinalgOp &linalgOp, VectorType resType)
Infer the memory access pattern for the input ExtractOp.
static bool isMaskTriviallyFoldable(SmallVector< OpFoldResult > &maskSizes, SmallVector< Value > &writeIdxs, ArrayRef< int64_t > destShape, ArrayRef< int64_t > maskShape)
Determines whether a mask for xfer_write is trivially "all true".
static SmallVector< Value > extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input, int64_t nSize, int64_t wSize, int64_t cSize, int64_t kwSize, int strideW, int dilationW, int64_t wSizeStep, bool isSingleChanneled)
Helper function to extract the input slices after filter is unrolled along kw.
VectorMemoryAccessKind
@ Contiguous
@ Gather
@ ScalarBroadcast
static VectorizationHookResult vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, Operation *op, LinalgOp linalgOp, const IRMapping &bvm)
Helper function to vectorize the tensor.extract operations.
static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter, VectorizationState &state, Operation *op, LinalgOp linalgOp)
Helper function to vectorize the index operations of a linalgOp.
static LogicalResult vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize tensor::InsertSliceOp with:
static Operation * createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore, Value dest, SmallVector< Value > writeIndices={}, bool useInBoundsInsteadOfMasking=false)
Creates an optionally masked TransferWriteOp.
static FailureOr< Operation * > vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp, ArrayRef< int64_t > inputVecSizes={}, ArrayRef< bool > inputVecScalableFlags={}, bool flatten1DDepthwiseConv=false)
Try to vectorize convOp as a convolution.
static LogicalResult vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Generic vectorization function that rewrites the body of a linalgOp into vector form.
#define MATCH_1D_CONV_POOL_OP(ConvOpTy)
static VectorizationHookResult vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, Operation *op, const IRMapping &bvm, ArrayRef< CustomVectorizationHook > customVectorizationHooks)
Generic vectorization for a single operation op, given already vectorized operands carried by bvm.
static Operation * matchLinalgReduction(OpOperand *outputOperand)
Check whether outputOperand is a reduction with a single combiner operation.
static Value buildVectorWrite(RewriterBase &rewriter, Value value, OpOperand *outputOperand, VectorizationState &state)
Build a vector.transfer_write of value into outputOperand at indices set to all 0; where outputOperan...
static Value getStaticPadVal(Operation *op)
Returns the effective Pad value for the input op, provided it's a scalar.
static SmallVector< Value > extractConvFilterSlices(RewriterBase &rewriter, Location loc, Value filter, int64_t kwSize)
Helper function to extract the filter slices after filter is unrolled along kw.
static bool hasReductionIterator(LinalgOp &op)
Check if op is a linalg.reduce or a linalg.generic that has at least one reduction iterator.
std::function< LogicalResult(Operation *, bool)> CustomVectorizationPrecondition
static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp)
Find the index of the trailing non-unit dim in linalgOp.
static VectorType getCollapsedVecType(VectorType type, ArrayRef< AffineMap > reassociation)
Given the re-associations, "collapses" the input Vector type.
Conv1DOpOrder
Helper enum to represent conv1d input traversal order.
VectorizationHookStatus
Helper data structure to represent the result of vectorization for a single operation.
@ Failure
Op failed to vectorize.
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
@ NoReplace
Op vectorized and custom function took care of replacement logic.
static Operation * reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op, Value reduceValue, Value initialValue, const IRMapping &bvm)
Emit reduction operations if the shapes of the value to reduce is different that the result shape.
static OpType getSingleOpOfType(Block &block)
Return the unique instance of OpType in block if it is indeed unique.
std::function< VectorizationHookResult(Operation *, const IRMapping &)> CustomVectorizationHook
static AffineMap reindexIndexingMap(AffineMap map)
Given an indexing map coming from a LinalgOp indexing, restricted to a projectedPermutation,...
static LogicalResult tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract)
Helper function to check if the tensor.extract can be vectorized by the custom hook vectorizeTensorEx...
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType)
Broadcast value to a vector of shape if possible.
static Value calculateGatherOffset(RewriterBase &rewriter, VectorizationState &state, tensor::ExtractOp extractOp, const IRMapping &bvm)
Calculates the offsets ($index_vec) for vector.gather operations generated from tensor....
static SmallVector< Value > extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t nSize, int64_t wSize, int64_t fSize, int64_t wSizeStep, bool isSingleChanneled)
Helper function to extract the result slices after filter is unrolled along kw.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val, bool &foundIndexOp, VectorType resType)
Check whether val could be used for calculating the trailing index for a contiguous load operation.
static VectorizationHookResult vectorizeLinalgYield(RewriterBase &rewriter, Operation *op, const IRMapping &bvm, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Helper function to vectorize the terminator of a linalgOp.
static Operation * buildMultiDimReduce(OpBuilder &b, Operation *reduceOp, Value valueToReduce, Value acc, ArrayRef< bool > dimsToMask)
Create MultiDimReductionOp to compute the reduction for reductionOp.
#define mul(a, b)
A dimensional identifier appearing in an affine expression.
Definition AffineExpr.h:223
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
MLIRContext * getContext() const
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumResults() const
unsigned getNumInputs() const
AffineExpr getResult(unsigned idx) const
static AffineMap getFilteredIdentityMap(MLIRContext *ctx, unsigned numDims, llvm::function_ref< bool(AffineDimExpr)> keepDimFilter)
Returns an identity affine map with numDims input dimensions and filtered results using keepDimFilter...
AffineMap dropZeroResults()
Returns the AffineMap resulting from removing "zero" results (constant values == 0) from this map.
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
This class represents an argument of a Block.
Definition Value.h:306
unsigned getArgNumber() const
Returns the number of this argument.
Definition Value.h:318
Block represents an ordered list of Operations.
Definition Block.h:33
OpListType & getOperations()
Definition Block.h:147
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
Definition Block.h:318
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition Builders.cpp:391
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:328
IntegerType getI1Type()
Definition Builders.cpp:57
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:270
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:55
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
Definition Builders.cpp:274
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:209
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:566
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition Builders.cpp:461
This class represents an operand of an operation.
Definition Value.h:254
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:379
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:541
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:234
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:436
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:244
unsigned getNumOperands()
Definition Operation.h:375
operand_iterator operand_end()
Definition Operation.h:404
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
result_type_range getResultTypes()
Definition Operation.h:457
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:407
result_range getResults()
Definition Operation.h:444
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:433
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:56
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
bool hasVectorizationImpl(Operation *)
Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op,...
SmallVector< int64_t > getUnPackInverseSrcPerm(linalg::UnPackOp, PackingMetadata &metadata)
Compute inverse permutation for the source tensor (i.e.
bool allIndexingsAreProjectedPermutation(LinalgOp op)
Check if all indexing maps are projected permutations.
Definition Utils.cpp:197
FailureOr< VectorizationResult > vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false, bool assumeDynamicDimsMatchVecSizes=false, bool createNamedContraction=false)
Returns a VectorizationResult containing the results of the vectorized op, or failure if the transfor...
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
Definition Utils.cpp:236
bool isaConvolutionOpInterface(LinalgOp linalgOp, bool allowEmptyConvolvedDims=false)
Checks whether linalgOp conforms to ConvolutionOpInterface.
void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns for vectorizing low-D convolution ops.
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
Definition Utils.cpp:217
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp)
Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeOpPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Return success if the operation can be vectorized.
SmallVector< int64_t > getPackInverseDestPerm(linalg::PackOp packOp, PackingMetadata &metadata)
Compute inverse permutation for the destination tensor (i.e.
bool isaConvolutionOpOfType(LinalgOp op)
Returns true if the linalg op is a convolution op of type ConvOpTy.
Definition Utils.h:126
std::optional< vector::CombiningKind > getCombinerOpKind(Operation *combinerOp)
Return vector::CombiningKind for the given op.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Definition MemRefOps.cpp:79
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Definition SCF.cpp:748
std::enable_if_t<!is_complex< V >::value, V > readValue(char **linePtr)
Returns an element-value of non-complex type.
Definition File.h:43
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
LogicalResult isValidMaskedInputVector(ArrayRef< int64_t > shape, ArrayRef< int64_t > inputVectorSizes)
Returns success if inputVectorSizes is a valid masking configuraion for given shape,...
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source, const VectorType &vecToReadTy, std::optional< Value > padValue=std::nullopt, bool useInBoundsInsteadOfMasking=false)
Creates a TransferReadOp from source.
SmallVector< OpFoldResult > getMixedSizesXfer(bool hasTensorSemantics, Operation *xfer, RewriterBase &rewriter)
A wrapper for getMixedSizes for vector.transfer_read and vector.transfer_write Ops (for source and de...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map)
Return the reverse map of a projected permutation where the projected dimensions are transformed into...
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:125
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:139
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition AffineMap.h:675
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
VectorizationHookResult contains the vectorized op returned from a CustomVectorizationHook.
enum VectorizationHookStatus status
Return status from vectorizing the current op.
Operation * newOp
New vectorized operation to replace the current op.
ArrayRef< int64_t > getCanonicalVecShape() const
Returns the canonical vector shape used to vectorize the iteration space.
LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims, bool assumeDynamicDimsMatchVecSizes=false)
Initializes the vectorization state, including the computation of the canonical vector shape for vect...
Operation * maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp, std::optional< AffineMap > maybeIndexingMap=std::nullopt)
Masks an operation with the canonical vector mask if the operation needs masking.
VectorType getCanonicalVecType(Type elementType, std::optional< AffineMap > dimPermutation=std::nullopt) const
Returns a vector type of the provided elementType with the canonical vector shape and the correspondi...
ArrayRef< bool > getScalableVecDims() const
Returns the vector dimensions that are scalable in the canonical vector shape.
VectorizationState(RewriterBase &rewriter)
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(vector::TransferReadOp xferOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const override