Vectorization.cpp
1//===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the linalg dialect Vectorization transformations.
10//
11//===----------------------------------------------------------------------===//
13
28#include "mlir/IR/AffineExpr.h"
29#include "mlir/IR/AffineMap.h"
30#include "mlir/IR/Builders.h"
35#include "mlir/IR/Value.h"
36#include "mlir/Support/LLVM.h"
38#include "llvm/ADT/STLExtras.h"
39#include "llvm/ADT/Sequence.h"
40#include "llvm/ADT/SmallVector.h"
41#include "llvm/ADT/TypeSwitch.h"
42#include "llvm/Support/DebugLog.h"
43#include "llvm/Support/InterleavedRange.h"
44#include "llvm/Support/MathExtras.h"
45#include "llvm/Support/raw_ostream.h"
46#include <optional>
47
48using namespace mlir;
49using namespace mlir::linalg;
50
51#define DEBUG_TYPE "linalg-vectorization"
52
53/// Try to vectorize `convOp` as a convolution.
54static FailureOr<Operation *>
55vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
56 ArrayRef<int64_t> inputVecSizes = {},
57 ArrayRef<bool> inputVecScalableFlags = {},
58 bool flatten1DDepthwiseConv = false);
59
60/// Vectorize tensor::InsertSliceOp with:
61/// * vector::TransferReadOp + vector::TransferWriteOp
62/// The vector sizes are either:
63/// * user-provided in `inputVectorSizes`, or
64/// * inferred from the static dims in the input and output tensors.
65/// Bails out if:
66/// * vector sizes are not user-provided, and
67/// * at least one dim is dynamic (in both the input and output tensors).
68///
69/// Before:
70/// !t_in_type = tensor<1x2x3xf32>
71/// !t_out_type = tensor<9x8x7x1x2x3xf32>
72/// !v_type = vector<1x2x3xf32>
73/// %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
74/// into !t_out_type
75/// After:
76/// %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
77/// %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
78static LogicalResult
79vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
80 ArrayRef<int64_t> inputVectorSizes,
81 SmallVectorImpl<Value> &newResults);
82
83/// Returns the effective Pad value for the input op, provided it's a scalar.
84///
85/// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
86/// this Op performs padding, retrieve the padding value provided that it's
87/// a scalar and static/fixed for all the padded values. Returns an empty value
88/// otherwise.
 89static Value getStaticPadVal(Operation *op);
 90
91/// Return the unique instance of OpType in `block` if it is indeed unique.
92/// Return null if none or more than 1 instances exist.
93template <typename OpType>
94static OpType getSingleOpOfType(Block &block) {
95 OpType res;
96 block.walk([&](OpType op) {
97 if (res) {
98 res = nullptr;
99 return WalkResult::interrupt();
100 }
101 res = op;
102 return WalkResult::advance();
103 });
104 return res;
105}
106
107/// Helper function to extract the input slices after filter is unrolled along
108/// kw.
109static SmallVector<Value>
110extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
111 int64_t nSize, int64_t wSize, int64_t cSize,
112 int64_t kwSize, int strideW, int dilationW,
113 int64_t wSizeStep, bool isSingleChanneled) {
114 SmallVector<Value> result;
115 if (isSingleChanneled) {
116 // Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled
117 // convolution.
118 SmallVector<int64_t> sizes = {wSizeStep};
119 SmallVector<int64_t> strides = {1};
120 for (int64_t kw = 0; kw < kwSize; ++kw) {
121 for (int64_t w = 0; w < wSize; w += wSizeStep) {
122 result.push_back(vector::ExtractStridedSliceOp::create(
123 rewriter, loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw}, sizes,
124 strides));
125 }
126 }
127 } else {
128 // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0]
129 // for channeled convolution.
130 SmallVector<int64_t> sizes = {nSize, wSizeStep, cSize};
131 SmallVector<int64_t> strides = {1, 1, 1};
132 for (int64_t kw = 0; kw < kwSize; ++kw) {
133 for (int64_t w = 0; w < wSize; w += wSizeStep) {
134 result.push_back(vector::ExtractStridedSliceOp::create(
135 rewriter, loc, input,
136 /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
137 sizes, strides));
138 }
139 }
140 }
141 return result;
142}
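// For illustration only (hypothetical sizes, not taken from any caller): with
// kwSize = 3, wSize = 4 and wSizeStep = 2, the single-channeled branch above
// extracts one slice of size {2} per (kw, w) pair, at offsets
//   kw = 0: [0], [2]
//   kw = 1: [1], [3]
//   kw = 2: [2], [4]
// i.e. offset w + kw; the channeled branch instead uses offsets
// [0, w * strideW + kw * dilationW, 0] with sizes {n, wSizeStep, c}.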
143
144/// Helper function to extract the filter slices after filter is unrolled along
145/// kw.
146static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
147 Location loc, Value filter,
148 int64_t kwSize) {
149 SmallVector<Value> result;
150 // Extract rhs slice of size [{c, f} for channeled convolutions and {1} for
151 // non-channeled convolution] @ [kw].
152 for (int64_t kw = 0; kw < kwSize; ++kw) {
153 result.push_back(vector::ExtractOp::create(
154 rewriter, loc, filter, /*offsets=*/ArrayRef<int64_t>{kw}));
155 }
156 return result;
157}
158
159/// Helper function to extract the result slices after filter is unrolled along
160/// kw.
161static SmallVector<Value>
162extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
163 int64_t nSize, int64_t wSize, int64_t fSize,
164 int64_t wSizeStep, bool isSingleChanneled) {
165 SmallVector<Value> result;
166 if (isSingleChanneled) {
167 // Extract res slice: {wSizeStep} @ [w] for non-channeled convolution.
168 SmallVector<int64_t> sizes = {wSizeStep};
169 SmallVector<int64_t> strides = {1};
170 for (int64_t w = 0; w < wSize; w += wSizeStep) {
171 result.push_back(vector::ExtractStridedSliceOp::create(
172 rewriter, loc, res, /*offsets=*/ArrayRef<int64_t>{w}, sizes,
173 strides));
174 }
175 } else {
176 // Extract res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
177 // convolution.
178 SmallVector<int64_t> sizes = {nSize, wSizeStep, fSize};
179 SmallVector<int64_t> strides = {1, 1, 1};
180 for (int64_t w = 0; w < wSize; w += wSizeStep) {
181 result.push_back(vector::ExtractStridedSliceOp::create(
182 rewriter, loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes,
183 strides));
184 }
185 }
186 return result;
187}
188
189/// Helper function to insert the computed result slices.
190static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
191 Value res, int64_t wSize, int64_t wSizeStep,
192 SmallVectorImpl<Value> &resVals,
193 bool isSingleChanneled) {
194
195 if (isSingleChanneled) {
196 // Write back res slice: {wSizeStep} @ [w] for non-channeled convolution.
197 // This does not depend on kw.
198 SmallVector<int64_t> strides = {1};
199 for (int64_t w = 0; w < wSize; w += wSizeStep) {
200 res = vector::InsertStridedSliceOp::create(
201 rewriter, loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w},
202 strides);
203 }
204 } else {
205 // Write back res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
206 // convolution. This does not depend on kw.
207 SmallVector<int64_t> strides = {1, 1, 1};
208 for (int64_t w = 0; w < wSize; w += wSizeStep) {
209 res = vector::InsertStridedSliceOp::create(
210 rewriter, loc, resVals[w], res,
211 /*offsets=*/ArrayRef<int64_t>{0, w, 0}, strides);
212 }
213 }
214 return res;
215}
216
217/// Contains the vectorization state and related methods used across the
218/// vectorization process of a given operation.
219struct VectorizationState {
220 VectorizationState(RewriterBase &rewriter) : rewriterGuard(rewriter) {}
221
222 /// Initializes the vectorization state, including the computation of the
223 /// canonical vector shape for vectorization.
224 LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
225 ArrayRef<int64_t> inputVectorSizes,
226 ArrayRef<bool> inputScalableVecDims,
227 bool assumeDynamicDimsMatchVecSizes = false);
228
229 /// Returns the canonical vector shape used to vectorize the iteration space.
230 ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
231
232 /// Returns the vector dimensions that are scalable in the canonical vector
233 /// shape.
234 ArrayRef<bool> getScalableVecDims() const { return scalableVecDims; }
235
236 /// Returns a vector type of the provided `elementType` with the canonical
237 /// vector shape and the corresponding fixed/scalable dimensions bit. If
238 /// `dimPermutation` is provided, the canonical vector dimensions are permuted
239 /// accordingly.
240 VectorType getCanonicalVecType(
241 Type elementType,
242 std::optional<AffineMap> dimPermutation = std::nullopt) const {
243 SmallVector<int64_t> vectorShape;
244 SmallVector<bool> scalableDims;
245 if (dimPermutation.has_value()) {
246 vectorShape =
247 applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
248 scalableDims =
249 applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
250 } else {
251 vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
252 scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
253 }
254
255 return VectorType::get(vectorShape, elementType, scalableDims);
256 }
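  // For illustration only (hypothetical shapes): with canonicalVecShape =
  // {8, 16, 4}, scalableVecDims = {false, false, true} and a permutation map
  // (d0, d1, d2) -> (d2, d0), getCanonicalVecType(f32, map) returns
  // vector<[4]x8xf32>, i.e. the canonical dims re-ordered by the map with
  // their scalable flags carried along.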
257
258 /// Masks an operation with the canonical vector mask if the operation needs
259 /// masking. Returns the masked operation or the original operation if masking
260 /// is not needed. If provided, the canonical mask for this operation is
261 /// permuted using `maybeIndexingMap`.
262 Operation *
263 maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
264 std::optional<AffineMap> maybeIndexingMap = std::nullopt);
265
266private:
267 /// Initializes the iteration space static sizes using the Linalg op
268 /// information. This may become more complicated in the future.
269 void initIterSpaceStaticSizes(LinalgOp linalgOp) {
270 iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
271 }
272
273 /// Generates 'arith.constant' and 'tensor/memref.dim' operations for
274 /// all the static and dynamic dimensions of the iteration space to be
275 /// vectorized and stores them in `iterSpaceValueSizes`.
276 LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
277 LinalgOp linalgOp);
278
279 /// Create or retrieve an existing mask value to mask `opToMask` in the
280 /// canonical vector iteration space. If `maybeMaskingMap` is provided, the
281 /// mask is permuted using that permutation map. If a new mask is created, it
282 /// will be cached for future users.
283 Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
284 LinalgOp linalgOp,
285 std::optional<AffineMap> maybeMaskingMap);
286
287 /// Check whether this permutation map can be used for masking. At the
288 /// moment we only make sure that there are no broadcast dimensions, but this
289 /// might change if indexing maps evolve.
290 bool isValidMaskingMap(AffineMap maskingMap) {
291 return maskingMap.getBroadcastDims().empty();
292 }
293
294 /// Turn the input indexing map into a valid masking map.
295 ///
296 /// The input indexing map may contain "zero" results, e.g.:
297 /// (d0, d1, d2, d3) -> (d2, d1, d0, 0)
298 /// Applying such maps to canonical vector shapes like this one:
299 /// (1, 16, 16, 4)
300 /// would yield an invalid vector shape like this:
301 /// (16, 16, 1, 0)
302 /// Instead, drop the broadcasting dims that make no sense for masking perm.
303 /// maps:
304 /// (d0, d1, d2, d3) -> (d2, d1, d0)
305 /// This way, the corresponding vector/mask type will be:
306 /// vector<16x16x1xty>
307 /// rather than this invalid Vector type:
308 /// vector<16x16x1x0xty>
309 AffineMap getMaskingMapFromIndexingMap(AffineMap &indexingMap) {
310 return indexingMap.dropZeroResults();
311 }
312
313 // Holds the compile-time static sizes of the iteration space to vectorize.
314 // Dynamic dimensions are represented using ShapedType::kDynamic.
315 SmallVector<int64_t> iterSpaceStaticSizes;
316
317 /// Holds the value sizes of the iteration space to vectorize. Static
318 /// dimensions are represented by 'arith.constant' and dynamic
319 /// dimensions by 'tensor/memref.dim'.
320 SmallVector<Value> iterSpaceValueSizes;
321
322 /// Holds the canonical vector shape used to vectorize the iteration space.
323 SmallVector<int64_t> canonicalVecShape;
324
325 /// Holds the vector dimensions that are scalable in the canonical vector
326 /// shape.
327 SmallVector<bool> scalableVecDims;
328
329 /// Holds the active masks for permutations of the canonical vector iteration
330 /// space.
331 DenseMap<AffineMap, Value> activeMaskCache;
332
333 /// Global vectorization guard for the incoming rewriter. It's initialized
334 /// when the vectorization state is initialized.
335 OpBuilder::InsertionGuard rewriterGuard;
336
337 /// Do all dynamic dims match the corresponding vector sizes?
338 ///
339 /// When a dynamic tensor/memref dimension matches the corresponding vector
340 /// dimension, masking can be safely skipped, despite the presence of dynamic
341 /// shapes. Use this flag with care and only for cases where you are
342 /// confident the assumption holds.
343 bool assumeDynamicDimsMatchVecSizes = false;
344};
345
346LogicalResult
347VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
348 LinalgOp linalgOp) {
349 // TODO: Support 0-d vectors.
350 for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
351 if (ShapedType::isStatic(iterSpaceStaticSizes[vecDim])) {
352 // Create constant index op for static dimensions.
353 iterSpaceValueSizes.push_back(arith::ConstantIndexOp::create(
354 rewriter, linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
355 continue;
356 }
357
358 // Find an operand defined on this dimension of the iteration space to
359 // extract the runtime dimension size.
360 Value operand;
361 unsigned operandDimPos;
362 if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
363 operandDimPos)))
364 return failure();
365
366 Value dynamicDim =
367 linalgOp.hasPureTensorSemantics()
368 ? (Value)tensor::DimOp::create(rewriter, linalgOp.getLoc(), operand,
369 operandDimPos)
370 : (Value)memref::DimOp::create(rewriter, linalgOp.getLoc(), operand,
371 operandDimPos);
372 iterSpaceValueSizes.push_back(dynamicDim);
373 }
374
375 return success();
376}
377
378/// Initializes the vectorization state, including the computation of the
379/// canonical vector shape for vectorization.
380// TODO: Move this to the constructor when we can remove the failure cases.
381LogicalResult VectorizationState::initState(RewriterBase &rewriter,
382 LinalgOp linalgOp,
383 ArrayRef<int64_t> inputVectorSizes,
384 ArrayRef<bool> inputScalableVecDims,
385 bool assumeDimsMatchVec) {
386 assumeDynamicDimsMatchVecSizes = assumeDimsMatchVec;
387 // Initialize the insertion point.
388 rewriter.setInsertionPoint(linalgOp);
389
390 if (!inputVectorSizes.empty()) {
391 // Get the canonical vector shape from the input vector sizes provided. This
392 // path should be taken to vectorize code with dynamic shapes and when using
393 // vector sizes greater than the iteration space sizes.
394 canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
395 scalableVecDims.append(inputScalableVecDims.begin(),
396 inputScalableVecDims.end());
397 } else {
398 // Compute the canonical vector shape from the operation shape. If there are
399 // dynamic shapes, the operation won't be vectorized. We assume all the
400 // vector dimensions are fixed.
401 canonicalVecShape = linalgOp.getStaticLoopRanges();
402 scalableVecDims.append(linalgOp.getNumLoops(), false);
403 }
404
405 LDBG() << "Canonical vector shape: " << llvm::interleaved(canonicalVecShape);
406 LDBG() << "Scalable vector dims: " << llvm::interleaved(scalableVecDims);
407
408 if (ShapedType::isDynamicShape(canonicalVecShape))
409 return failure();
410
411 // Initialize iteration space static sizes.
412 initIterSpaceStaticSizes(linalgOp);
413
414 // Generate 'arith.constant' and 'tensor/memref.dim' operations for
415 // all the static and dynamic dimensions of the iteration space, needed to
416 // compute a mask during vectorization.
417 if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
418 return failure();
419
420 return success();
421}
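// For illustration only (hypothetical op): vectorizing a linalg op with static
// loop ranges {4, 8, 2} and no user-provided sizes yields canonicalVecShape =
// {4, 8, 2} with all dims fixed; passing inputVectorSizes = {4, 16, 2} together
// with inputScalableVecDims = {false, true, false} instead takes the user
// sizes and marks the second dim as scalable.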
422
423/// Create or retrieve an existing mask value to mask `opToMask` in the
424/// canonical vector iteration space. If `maybeMaskingMap` is provided, the
425/// mask is permuted using that permutation map. If a new mask is created, it
426/// will be cached for future users.
427Value VectorizationState::getOrCreateMaskFor(
428 RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
429 std::optional<AffineMap> maybeMaskingMap) {
430
431 assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) &&
432 "Ill-formed masking map.");
433
434 // No mask is needed if the operation is not maskable.
435 auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
436 if (!maskableOp)
437 return Value();
438
439 assert(!maskableOp.isMasked() &&
440 "Masking an operation that is already masked");
441
442 // If no masking map was provided, use an identity map with the loop dims.
443 assert((!maybeMaskingMap || *maybeMaskingMap) &&
444 "Unexpected null mask permutation map");
445 AffineMap maskingMap =
446 maybeMaskingMap ? *maybeMaskingMap
447 : AffineMap::getMultiDimIdentityMap(
448 linalgOp.getNumLoops(), rewriter.getContext());
449
450 LDBG() << "Masking map: " << maskingMap;
451
452 // Return the active mask for the masking map of this operation if it was
453 // already created.
454 auto activeMaskIt = activeMaskCache.find(maskingMap);
455 if (activeMaskIt != activeMaskCache.end()) {
456 Value mask = activeMaskIt->second;
457 LDBG() << "Reusing mask: " << mask;
458 return mask;
459 }
460
461 // Compute permuted projection of the iteration space to be masked and the
462 // corresponding mask shape. If the resulting iteration space dimensions are
463 // static and identical to the mask shape, masking is not needed for this
464 // operation.
465 // TODO: Improve this check. Only projected permutation indexing maps are
466 // supported.
467 SmallVector<int64_t> permutedStaticSizes =
468 applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
469 auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
470 auto maskShape = maskType.getShape();
471
472 LDBG() << "Mask shape: " << llvm::interleaved(maskShape);
473
474 if (permutedStaticSizes == maskShape) {
475 LDBG() << "Masking is not needed for masking map: " << maskingMap;
476 activeMaskCache[maskingMap] = Value();
477 return Value();
478 }
479
480 if (assumeDynamicDimsMatchVecSizes) {
481 // While for _dynamic_ dim sizes we can _assume_ that the corresponding
482 // vector sizes match, we still need to check the _static_ dim sizes. Only
483 // then we can be 100% sure that masking is not required.
484 if (llvm::all_of(llvm::zip(permutedStaticSizes, maskType.getShape()),
485 [](auto it) {
486 return std::get<0>(it) == ShapedType::kDynamic
487 ? true
488 : std::get<0>(it) == std::get<1>(it);
489 })) {
490 LDBG()
491 << "Dynamic + static dimensions match vector sizes, masking is not "
492 "required.";
493 activeMaskCache[maskingMap] = Value();
494 return Value();
495 }
496 }
497
498 // Permute the iteration space value sizes to compute the mask upper bounds.
499 SmallVector<Value> upperBounds =
500 applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
501 assert(!maskShape.empty() && !upperBounds.empty() &&
502 "Masked 0-d vectors are not supported yet");
503
504 // Create the mask based on the dimension values.
505 Value mask = vector::CreateMaskOp::create(rewriter, linalgOp.getLoc(),
506 maskType, upperBounds);
507 LDBG() << "Creating new mask: " << mask;
508 activeMaskCache[maskingMap] = mask;
509 return mask;
510}
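// For illustration only (hypothetical shapes): for an iteration space of
// (?, 32) vectorized with vector sizes (8, 32) and an identity masking map,
// the code above emits roughly:
//
//   %dim = tensor.dim %src, %c0 : tensor<?x32xf32>
//   %mask = vector.create_mask %dim, %c32 : vector<8x32xi1>
//
// and caches %mask so later ops with the same masking map reuse it.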
511
512Operation *
513VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
514 LinalgOp linalgOp,
515 std::optional<AffineMap> maybeIndexingMap) {
516 LDBG() << "Trying to mask: " << *opToMask;
517
518 std::optional<AffineMap> maybeMaskingMap = std::nullopt;
519 if (maybeIndexingMap)
520 maybeMaskingMap = getMaskingMapFromIndexingMap(*maybeIndexingMap);
521
522 // Create or retrieve mask for this operation.
523 Value mask =
524 getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
525
526 if (!mask) {
527 LDBG() << "No mask required";
528 if (assumeDynamicDimsMatchVecSizes) {
529 llvm::TypeSwitch<Operation *>(opToMask)
530 .Case<vector::TransferReadOp, vector::TransferWriteOp>(
531 [&](auto xferOp) {
532 // For vector.transfer_read and vector.transfer_write, there is
533 // also the `in-bounds` attribute that has to be set explicitly
534 // to true. Otherwise, "out-of-bounds" access will be assumed
535 // and masks will be generated while lowering these.
536 LDBG() << "Assuming dynamic dimensions match vector sizes and "
537 "setting their in-bounds to true!";
538 SmallVector<bool> inBoundsMap = xferOp.getInBoundsValues();
539 ShapedType xferType = xferOp.getShapedType();
540 AffineMap permMap = xferOp.getPermutationMap();
541 // Only set the in-bounds values to true for dynamic dims.
542 // Different mechanisms will set these accordingly for the
543 // static dims.
544 for (unsigned i = 0; i < xferOp.getTransferRank(); i++) {
545 auto dimExpr = dyn_cast<AffineDimExpr>(permMap.getResult(i));
546 // Skip broadcast dimensions.
547 if (!dimExpr)
548 continue;
549 unsigned pos = dimExpr.getPosition();
550 if (xferType.isDynamicDim(pos))
551 inBoundsMap[i] = true;
552 }
553 rewriter.modifyOpInPlace(xferOp, [&]() {
554 xferOp.setInBoundsAttr(
555 rewriter.getBoolArrayAttr(inBoundsMap));
556 });
557 })
558 .Default([](Operation *op) {
559 // No-op if the operation is not an xfer read or write.
560 });
561 }
562 return opToMask;
563 }
564
565 // Wrap the operation with a new `vector.mask` and update D-U chain.
566 assert(opToMask && "Expected a valid operation to mask");
567 auto maskOp = cast<vector::MaskOp>(
568 mlir::vector::maskOperation(rewriter, opToMask, mask));
569 Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
570
571 for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
572 rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
573 maskOpTerminator);
574
575 LDBG() << "Masked operation: " << *maskOp;
576 return maskOp;
577}
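// For illustration only (hypothetical operands): masking a transfer_read with
// the mask computed above produces IR along the lines of:
//
//   %read = vector.mask %mask {
//     vector.transfer_read %src[%c0, %c0], %pad
//         : tensor<?x32xf32>, vector<8x32xf32>
//   } : vector<8x32xi1> -> vector<8x32xf32>
//
// with all uses of the original result (except the mask terminator) redirected
// to the vector.mask result.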
578
579/// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
580/// projectedPermutation, compress the unused dimensions to serve as a
581/// permutation_map for a vector transfer operation.
582/// For example, given a linalg op such as:
583///
584/// ```
585/// %0 = linalg.generic {
586/// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>,
587/// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>
588/// }
589/// ins(%0 : tensor<2x3x4xf32>)
590/// outs(%1 : tensor<5x6xf32>)
591/// ```
592///
593/// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine
594/// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second
595/// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`.
596static AffineMap reindexIndexingMap(AffineMap map) {
597 assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
598 "expected projected permutation");
599 auto res = compressUnusedDims(map);
600 assert(res.getNumDims() ==
601 (res.getNumResults() - res.getNumOfZeroResults()) &&
602 "expected reindexed map with same number of dims and results");
603 return res;
604}
605
606/// Helper enum to represent conv1d input traversal order.
607enum class Conv1DOpOrder {
608 W, // Corresponds to non-channeled 1D convolution operation.
609 Ncw, // Corresponds to operation that traverses the input in (n, c, w) order.
610 Nwc // Corresponds to operation that traverses the input in (n, w, c) order.
611};
612
613/// Helper data structure to represent the result of vectorization for a single
614/// operation. In certain specific cases, like terminators, we do not want to
615/// propagate.
616enum VectorizationHookStatus {
617 /// Op failed to vectorize.
618 Failure = 0,
619 /// Op vectorized and custom function took care of replacement logic
620 NoReplace,
621 /// Op vectorized into a new Op whose results will replace original Op's
622 /// results.
623 NewOp
624 // TODO: support values if Op vectorized to Many-Ops whose results we need to
625 // aggregate for replacement.
626};
627/// VectorizationHookResult contains the vectorized op returned from a
628/// CustomVectorizationHook. This is an internal implementation detail of
629/// linalg vectorization, not to be confused with VectorizationResult.
630struct VectorizationHookResult {
631 /// Return status from vectorizing the current op.
632 enum VectorizationHookStatus status = VectorizationHookStatus::Failure;
633 /// New vectorized operation to replace the current op.
634 /// Replacement behavior is specified by `status`.
635 Operation *newOp;
636};
637
638std::optional<vector::CombiningKind>
639mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
640 using ::mlir::vector::CombiningKind;
641
642 if (!combinerOp)
643 return std::nullopt;
644 return TypeSwitch<Operation *, std::optional<CombiningKind>>(combinerOp)
645 .Case<arith::AddIOp, arith::AddFOp>(
646 [&](auto op) { return CombiningKind::ADD; })
647 .Case([&](arith::AndIOp op) { return CombiningKind::AND; })
648 .Case([&](arith::MaxSIOp op) { return CombiningKind::MAXSI; })
649 .Case([&](arith::MaxUIOp op) { return CombiningKind::MAXUI; })
650 .Case([&](arith::MaximumFOp op) { return CombiningKind::MAXIMUMF; })
651 .Case([&](arith::MaxNumFOp op) { return CombiningKind::MAXNUMF; })
652 .Case([&](arith::MinSIOp op) { return CombiningKind::MINSI; })
653 .Case([&](arith::MinUIOp op) { return CombiningKind::MINUI; })
654 .Case([&](arith::MinimumFOp op) { return CombiningKind::MINIMUMF; })
655 .Case([&](arith::MinNumFOp op) { return CombiningKind::MINNUMF; })
656 .Case<arith::MulIOp, arith::MulFOp>(
657 [&](auto op) { return CombiningKind::MUL; })
658 .Case([&](arith::OrIOp op) { return CombiningKind::OR; })
659 .Case([&](arith::XOrIOp op) { return CombiningKind::XOR; })
660 .Default(std::nullopt);
661}
662
663/// Check whether `outputOperand` is a reduction with a single combiner
664/// operation. Return the combiner operation of the reduction. Return
665/// nullptr otherwise. Multiple reduction operations would impose an
666/// ordering between reduction dimensions, which is currently unsupported in
667/// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) !=
668/// max(min(X))
669// TODO: use in LinalgOp verification, there is a circular dependency atm.
670static Operation *matchLinalgReduction(OpOperand *outputOperand) {
671 auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
672 unsigned outputPos =
673 outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
674 // Only single combiner operations are supported for now.
675 SmallVector<Operation *, 4> combinerOps;
676 if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
677 combinerOps.size() != 1)
678 return nullptr;
679
680 // Return the combiner operation.
681 return combinerOps[0];
682}
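// For illustration only (hypothetical region): for a reduction body such as
//
//   ^bb0(%in: f32, %acc: f32):
//     %sum = arith.addf %in, %acc : f32
//     linalg.yield %sum : f32
//
// matchLinalgReduction returns the arith.addf, which getCombinerOpKind above
// maps to vector::CombiningKind::ADD.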
683
684/// Broadcast `value` to a vector of `shape` if possible. Return value
685/// otherwise.
686static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
687 auto dstVecType = dyn_cast<VectorType>(dstType);
688 // If no shape to broadcast to, just return `value`.
689 if (dstVecType.getRank() == 0)
690 return value;
691 if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
692 vector::BroadcastableToResult::Success)
693 return value;
694 Location loc = b.getInsertionPoint()->getLoc();
695 return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
696}
697
698/// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
699/// assumes that `reductionOp` has two operands and one of them is the reduction
700/// initial value.
701// Note: this is a true builder that notifies the OpBuilder listener.
702// TODO: Consider moving as a static helper on the ReduceOp.
703static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
704 Value valueToReduce, Value acc,
705 ArrayRef<bool> dimsToMask) {
706 auto maybeKind = getCombinerOpKind(reduceOp);
707 assert(maybeKind && "Failed precondition: could not get reduction kind");
708 return vector::MultiDimReductionOp::create(
709 b, reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
710}
711
712static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
713 return llvm::to_vector(
714 llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
715}
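// For illustration only: for iterator_types = ["parallel", "reduction",
// "parallel"] this returns {false, true, false}, i.e. one "reduce this dim"
// bit per loop of the iteration space.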
716
717/// Check if `op` is a linalg.reduce or a linalg.generic that has at least one
718/// reduction iterator.
719static bool hasReductionIterator(LinalgOp &op) {
720 return isa<linalg::ReduceOp>(op) ||
721 (isa<linalg::GenericOp>(op) &&
722 llvm::any_of(op.getIteratorTypesArray(), isReductionIterator));
723}
724
725/// Build a vector.transfer_write of `value` into `outputOperand` at indices set
726/// to all `0`; where `outputOperand` is an output operand of the LinalgOp
727/// currently being vectorized. If `dest` has null rank, build an memref.store.
728/// Return the produced value or null if no value is produced.
729// Note: this is a true builder that notifies the OpBuilder listener.
730// TODO: Consider moving as a static helper on the ReduceOp.
731static Value buildVectorWrite(RewriterBase &rewriter, Value value,
732 OpOperand *outputOperand,
733 VectorizationState &state) {
734 Location loc = value.getLoc();
735 auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
736 AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
737
738 // Compute the vector type of the value to store. This type should be an
739 // identity or projection of the canonical vector type without any permutation
740 // applied, given that any permutation in a transfer write happens as part of
741 // the write itself.
742 AffineMap vectorTypeMap = AffineMap::getFilteredIdentityMap(
743 opOperandMap.getContext(), opOperandMap.getNumInputs(),
744 [&](AffineDimExpr dimExpr) -> bool {
745 return llvm::is_contained(opOperandMap.getResults(), dimExpr);
746 });
747 auto vectorType = state.getCanonicalVecType(
748 getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);
749
750 SmallVector<Value> indices(linalgOp.getRank(outputOperand),
751 arith::ConstantIndexOp::create(rewriter, loc, 0));
752
753 Operation *write;
754 if (vectorType.getRank() > 0) {
755 AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
756 value = broadcastIfNeeded(rewriter, value, vectorType);
757 assert(value.getType() == vectorType && "Incorrect type");
758 write = vector::TransferWriteOp::create(
759 rewriter, loc, value, outputOperand->get(), indices, writeMap);
760 } else {
761 // 0-d case is still special: do not invert the reindexing writeMap.
762 if (!isa<VectorType>(value.getType()))
763 value = vector::BroadcastOp::create(rewriter, loc, vectorType, value);
764 assert(value.getType() == vectorType && "Incorrect type");
765 write = vector::TransferWriteOp::create(rewriter, loc, value,
766 outputOperand->get(), indices);
767 }
768
769 write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
770
771 // If masked, set in-bounds to true. Masking guarantees that the access will
772 // be in-bounds.
773 if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
774 auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
775 SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
776 maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
777 }
778
779 LDBG() << "vectorized op: " << *write;
780 if (!write->getResults().empty())
781 return write->getResult(0);
782 return Value();
783}
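// For illustration only (hypothetical shapes): writing a vector<8x32xf32>
// value into a tensor<8x32xf32> init operand produces roughly:
//
//   %w = vector.transfer_write %val, %init[%c0, %c0]
//          : vector<8x32xf32>, tensor<8x32xf32>
//
// with a vector.mask wrapper added by maskOperation when the iteration space
// requires it.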
784
785// Custom vectorization precondition function type. This is intended to be used
786// with CustomVectorizationHook. Returns success if the corresponding custom
787// hook can vectorize the op.
788using CustomVectorizationPrecondition =
789 std::function<LogicalResult(Operation *, bool)>;
790
791// Custom vectorization function type. Produce a vector form of Operation*
792// assuming all its vectorized operands are already in the IRMapping.
793// Return nullptr if the Operation cannot be vectorized.
794using CustomVectorizationHook =
795 std::function<VectorizationHookResult(Operation *, const IRMapping &)>;
796
797/// Helper function to vectorize the terminator of a `linalgOp`. New result
798/// vector values are appended to `newResults`. Return
799/// VectorizationHookStatus::NoReplace to signal the vectorization algorithm
800/// that it should not try to map produced operations and instead return the
801/// results using the `newResults` vector making them available to the
802/// vectorization algorithm for RAUW. This function is meant to be used as a
803/// CustomVectorizationHook.
804static VectorizationHookResult
805vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
806 const IRMapping &bvm, VectorizationState &state,
807 LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
808 auto yieldOp = dyn_cast<linalg::YieldOp>(op);
809 if (!yieldOp)
810 return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
811 for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
812 // TODO: Scan for an opportunity for reuse.
813 // TODO: use a map.
814 Value vectorValue = bvm.lookup(output.value());
815 Value newResult =
816 buildVectorWrite(rewriter, vectorValue,
817 linalgOp.getDpsInitOperand(output.index()), state);
818 if (newResult)
819 newResults.push_back(newResult);
820 }
821
822 return VectorizationHookResult{VectorizationHookStatus::NoReplace, nullptr};
823}
824
825/// Helper function to vectorize the index operations of a `linalgOp`. Return
826/// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it
827/// should map the produced operations. This function is meant to be used as a
828/// CustomVectorizationHook.
829static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter,
830 VectorizationState &state,
831 Operation *op,
832 LinalgOp linalgOp) {
833 IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
834 if (!indexOp)
835 return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
836 auto loc = indexOp.getLoc();
837 // Compute the static loop sizes of the index op.
838 ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
839 auto dim = indexOp.getDim();
840 // Compute a one-dimensional index vector for the index op dimension.
841 auto indexVectorType =
842 VectorType::get({targetShape[dim]}, rewriter.getIndexType(),
843 state.getScalableVecDims()[dim]);
844 auto indexSteps = vector::StepOp::create(rewriter, loc, indexVectorType);
845 // Return the one-dimensional index vector if it lives in the trailing
846 // dimension of the iteration space since the vectorization algorithm in this
847 // case can handle the broadcast.
848 if (dim == targetShape.size() - 1)
849 return VectorizationHookResult{VectorizationHookStatus::NewOp, indexSteps};
850 // Otherwise permute the targetShape to move the index dimension last,
851 // broadcast the one-dimensional index vector to the permuted shape, and
852 // finally transpose the broadcasted index vector to undo the permutation.
853 auto permPattern =
854 llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
855 std::swap(permPattern[dim], permPattern.back());
856 auto permMap =
857 AffineMap::getPermutationMap(permPattern, linalgOp.getContext());
858
859 auto broadCastOp = vector::BroadcastOp::create(
860 rewriter, loc,
861 state.getCanonicalVecType(rewriter.getIndexType(), permMap), indexSteps);
862 SmallVector<int64_t> transposition =
863 llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
864 std::swap(transposition.back(), transposition[dim]);
865 auto transposeOp =
866 vector::TransposeOp::create(rewriter, loc, broadCastOp, transposition);
867 return VectorizationHookResult{VectorizationHookStatus::NewOp, transposeOp};
868}
869
870/// Helper function to check if the tensor.extract can be vectorized by the
871/// custom hook vectorizeTensorExtract.
872static LogicalResult
873tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
874 tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
875 if (!extractOp)
876 return failure();
877
878 if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
879 return failure();
880
881 // Check the index type, but only for non 0-d tensors (for which we do need
882 // access indices).
883 if (not extractOp.getIndices().empty()) {
884 if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
885 return failure();
886 }
887
888 if (!llvm::all_of(extractOp->getResultTypes(),
889 VectorType::isValidElementType)) {
890 return failure();
891 }
892
893 return success();
894}
895
896/// Calculates the offsets (`$index_vec`) for `vector.gather` operations
897/// generated from `tensor.extract`. The offset is calculated as follows
898/// (example using scalar values):
899///
900/// offset = extractOp.indices[0]
901/// for (i = 1; i < numIndices; i++)
902/// offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
903///
904/// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to:
905/// offset = ( ( 1 ) * 80 + 2 ) * 15 + 3
906static Value calculateGatherOffset(RewriterBase &rewriter,
907 VectorizationState &state,
908 tensor::ExtractOp extractOp,
909 const IRMapping &bvm) {
910 // The vector of indices for GatherOp should be shaped as the output vector.
911 auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
912 auto loc = extractOp.getLoc();
913
914 Value offset = broadcastIfNeeded(
915 rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);
916
917 const size_t numIndices = extractOp.getIndices().size();
918 for (size_t i = 1; i < numIndices; i++) {
919 Value dimIdx = arith::ConstantIndexOp::create(rewriter, loc, i);
920
921 auto dimSize = broadcastIfNeeded(
922 rewriter,
923 tensor::DimOp::create(rewriter, loc, extractOp.getTensor(), dimIdx),
924 indexVecType);
925
926 offset = arith::MulIOp::create(rewriter, loc, offset, dimSize);
927
928 auto extractOpIndex = broadcastIfNeeded(
929 rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);
930
931 offset = arith::AddIOp::create(rewriter, loc, extractOpIndex, offset);
932 }
933
934 return offset;
935}
936
937enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
938
939/// Find the index of the trailing non-unit dim in linalgOp. This hook is used
940/// when checking whether `tensor.extract` Op (within a `linalg.generic` Op)
941/// represents a contiguous load operation.
942///
943/// Note that when calling this hook, it is assumed that the output vector is
944/// effectively 1D. Other cases (i.e. reading n-D vectors) should've been
945/// labelled as a gather load before entering this method.
946///
947/// Following on from the above, it is assumed that:
948/// * for statically shaped loops, when no masks are used, only one dim is !=
949/// 1 (that's what the shape of the output vector is based on).
950/// * for dynamically shaped loops, there might be more non-unit dims
951/// as the output vector type is user-specified.
952///
953/// TODO: Statically shaped loops + vector masking
954static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp) {
955 SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges();
956 assert(
957 (linalgOp.hasDynamicShape() ||
958 llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) == 1) &&
959 "For statically shaped Linalg Ops, only one "
960 "non-unit loop dim is expected");
961 assert(!loopRanges.empty() && "Empty loops, nothing to analyse.");
962
963 size_t idx = loopRanges.size() - 1;
964 for (; idx != 0; idx--)
965 if (loopRanges[idx] != 1)
966 break;
967
968 return idx;
969}
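// For illustration only (hypothetical loop ranges): for static loop ranges
// {1, 8, 1} the scan above skips the trailing unit dim and returns index 1.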
970
971/// Checks whether `val` can be used for calculating a loop invariant index.
972static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
973 VectorType resType) {
974
975 assert(((llvm::count_if(resType.getShape(),
976 [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
977 "n-D vectors are not yet supported");
978
979 // Blocks outside _this_ linalg.generic are effectively loop invariant.
980 // However, analysing block arguments for _this_ linalg.generic Op is a bit
981 // tricky. Just bail out in the latter case.
982 // TODO: We could try analysing the corresponding affine map here.
983 auto *block = linalgOp.getBlock();
984 if (isa<BlockArgument>(val))
985 return !llvm::is_contained(block->getArguments(), val);
986
987 Operation *defOp = val.getDefiningOp();
988 assert(defOp && "This is neither a block argument nor an operation result");
989
990 // IndexOp is loop invariant as long as its result remains constant across
991 // iterations. Note that for dynamic shapes, the corresponding dim will also
992 // be conservatively treated as != 1.
993 if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
994 return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
995 }
996
997 auto *ancestor = block->findAncestorOpInBlock(*defOp);
998
999 // Values defined outside `linalgOp` are loop invariant.
1000 if (!ancestor)
1001 return true;
1002
1003 // Values defined inside `linalgOp`, which are constant, are loop invariant.
1004 if (isa<arith::ConstantOp>(ancestor))
1005 return true;
1006
1007 bool result = true;
1008 for (auto op : ancestor->getOperands())
1009 result &= isLoopInvariantIdx(linalgOp, op, resType);
1010
1011 return result;
1012}
1013
1014/// Check whether `val` could be used for calculating the trailing index for a
1015/// contiguous load operation.
1016///
1017/// There are currently 3 types of values that are allowed here:
1018/// 1. loop-invariant values,
1019/// 2. values that increment by 1 with every loop iteration,
1020/// 3. results of basic arithmetic operations (linear and continuous)
1021/// involving 1., 2. and 3.
1022/// This method returns True if indeed only such values are used in calculating
1023/// `val.`
1024///
1025/// Additionally, the trailing index for a contiguous load operation should
1026/// increment by 1 with every loop iteration, i.e. be based on:
1027/// * `linalg.index <dim>` ,
1028/// where <dim> is the trailing non-unit dim of the iteration space (this way,
1029/// `linalg.index <dim>` increments by 1 with every loop iteration).
1030/// `foundIndexOp` is updated to `true` when such Op is found.
1031static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
1032 bool &foundIndexOp, VectorType resType) {
1033
1034 assert(((llvm::count_if(resType.getShape(),
1035 [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
1036 "n-D vectors are not yet supported");
1037
1038 // Blocks outside _this_ linalg.generic are effectively loop invariant.
1039 // However, analysing block arguments for _this_ linalg.generic Op is a bit
1040 // tricky. Just bail out in the latter case.
1041 // TODO: We could try analysing the corresponding affine map here.
1042 auto *block = linalgOp.getBlock();
1043 if (isa<BlockArgument>(val))
1044 return !llvm::is_contained(block->getArguments(), val);
1045
1046 Operation *defOp = val.getDefiningOp();
1047 assert(defOp && "This is neither a block argument nor an operation result");
1048
1049 if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
1050 auto loopDimThatIncrementsByOne = getTrailingNonUnitLoopDimIdx(linalgOp);
1051
1052 foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
1053 return true;
1054 }
1055
1056 auto *ancestor = block->findAncestorOpInBlock(*defOp);
1057
1058 if (!ancestor)
1059 return false;
1060
1061 // Conservatively reject Ops that could lead to indices with stride other
1062 // than 1.
1063 if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
1064 return false;
1065
1066 bool result = false;
1067 for (auto op : ancestor->getOperands())
1068 result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType);
1069
1070 return result;
1071}
1072
1073/// Infer the memory access pattern for the input ExtractOp
1074///
1075/// Based on the ExtractOp result shape and the access indices, decides whether
1076/// this Op corresponds to a contiguous load (including a broadcast of a scalar)
1077/// or a gather load. When analysing the ExtractOp indices (to identify
1078/// contiguous loads), this method looks for "loop" invariant indices (e.g.
1079/// block arguments) and indices that change linearly (e.g. via `linalg.index`
1080/// Op).
1081///
1082/// Note that it is always safe to use gather load operations for contiguous
1083/// loads (albeit slow), but not vice-versa. When in doubt, bail out and assume
1084/// that `extractOp` is a gather load.
1085static VectorMemoryAccessKind
1086getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
1087 LinalgOp &linalgOp, VectorType resType) {
1088
1089 auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
1090
1091 // 0. Is this a 0-D vector? If yes then this is a scalar broadcast.
1092 if (inputShape.getShape().empty())
1093 return VectorMemoryAccessKind::ScalarBroadcast;
1094
1095 // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false
1096 // otherwise.
1097 bool isOutput1DVector =
1098 (llvm::count_if(resType.getShape(),
1099 [](int64_t dimSize) { return dimSize > 1; }) == 1);
1100 // 1. Assume that it's a gather load when reading non-1D vector.
1101 if (!isOutput1DVector)
1102 return VectorMemoryAccessKind::Gather;
1103
1104 bool leadingIdxsLoopInvariant = true;
1105
1106 // 2. Analyze the leading indices of `extractOp`.
1107 // Look at the way each index is calculated and decide whether it is suitable
1108 // for a contiguous load, i.e. whether it's loop invariant. If not, it's a
1109 // gather load.
1110 auto indices = extractOp.getIndices();
1111 auto leadIndices = indices.drop_back(1);
1112
1113 for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
1114 if (inputShape.getShape()[i] == 1)
1115 continue;
1116
1117 leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType);
1118 }
1119
1120 if (!leadingIdxsLoopInvariant) {
1121 LDBG() << "Found gather load: " << extractOp;
1122 return VectorMemoryAccessKind::Gather;
1123 }
1124
1125 // 3. Analyze the trailing index for `extractOp`.
1126 // At this point we know that the leading indices are loop invariant. This
1127 // means that it is potentially a scalar or a contiguous load. We can decide
1128 // based on the trailing idx.
1129 auto extractOpTrailingIdx = indices.back();
1130
1131 // 3a. Scalar broadcast load
1132 // If the trailing index is loop invariant then this is a scalar load.
1133 if (leadingIdxsLoopInvariant &&
1134 isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) {
1135 LDBG() << "Found scalar broadcast load: " << extractOp;
1136
1137 return VectorMemoryAccessKind::ScalarBroadcast;
1138 }
1139
1140 // 3b. Contiguous loads
1141 // The trailing `extractOp` index should increment with every loop iteration.
1142 // This effectively means that it must be based on the trailing loop index.
1143 // This is what the following bool captures.
1144 bool foundIndexOp = false;
1145 bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
1146 foundIndexOp, resType);
1147 // TODO: Support generating contiguous loads for column vectors - that will
1148 // require adding a permutation map to transfer_read Ops.
1149 bool isRowVector = resType.getShape().back() != 1;
1150 isContiguousLoad &= (foundIndexOp && isRowVector);
1151
1152 if (isContiguousLoad) {
1153 LDBG() << "Found contiguous load: " << extractOp;
1154 return VectorMemoryAccessKind::Contiguous;
1155 }
1156
1157 // 4. Fallback case - gather load.
1158 LDBG() << "Found gather load: " << extractOp;
1159 return VectorMemoryAccessKind::Gather;
1160}
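// For illustration only (hypothetical generic op with an effectively 1-D
// output vector): with a leading index %row defined outside the linalg.generic
// and a trailing index produced by linalg.index on the trailing non-unit loop
// dim,
//
//   %j = linalg.index 1 : index
//   %v = tensor.extract %src[%row, %j] : tensor<8x32xf32>
//
// the access is classified as Contiguous; making the trailing index
// data-dependent instead falls back to Gather.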
1161
1162/// Helper function to vectorize the tensor.extract operations. Returns
1163/// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it
1164/// should map the produced operations. This function is meant to be used as a
1165/// CustomVectorizationHook.
1166static VectorizationHookResult
1167vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
1168 Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
1169 tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
1170 if (!extractOp)
1171 return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
1172 auto loc = extractOp.getLoc();
1173
1174 // Compute the static loop sizes of the extract op.
1175 auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
1176 auto maskConstantOp = arith::ConstantOp::create(
1177 rewriter, loc,
1178 DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()),
1179 /*value=*/true));
1180 auto passThruConstantOp = arith::ConstantOp::create(
1181 rewriter, loc, rewriter.getZeroAttr(resultType));
1182
1183 // Base indices are currently set to 0. We will need to re-visit if more
1184 // generic scenarios are to be supported.
1185 SmallVector<Value> baseIndices(
1186 extractOp.getIndices().size(),
1187 arith::ConstantIndexOp::create(rewriter, loc, 0));
1188
1189 VectorMemoryAccessKind memAccessKind =
1190 getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType);
1191
1192 // 1. Handle gather access
1193 if (memAccessKind == VectorMemoryAccessKind::Gather) {
1194 Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);
1195
1196 // Generate the gather load
1197 Operation *gatherOp = vector::GatherOp::create(
1198 rewriter, loc, resultType, extractOp.getTensor(), baseIndices, offset,
1199 maskConstantOp, passThruConstantOp);
1200 gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
1201
1202 LDBG() << "Vectorised as gather load: " << extractOp;
1203 return VectorizationHookResult{VectorizationHookStatus::NewOp, gatherOp};
1204 }
1205
1206 // 2. Handle:
1207 // a. scalar loads + broadcast,
1208 // b. contiguous loads.
1209 // Both cases use vector.transfer_read.
1210
1211 // Collect indices for `vector.transfer_read`. At this point, the indices will
1212 // either be scalars or would have been broadcast to vectors matching the
1213 // result type. For indices that are vectors, there are two options:
1214 // * for non-trailing indices, all elements are identical (contiguous
1215 // loads are identified by looking for non-trailing indices that are
1216 // invariant with respect to the corresponding linalg.generic), or
1217 // * for trailing indices, the index vector will contain values with stride
1218 // one, but for `vector.transfer_read` only the first (i.e. 0th) index is
1219 // needed.
1220 // This means that
1221 // * for scalar indices - just re-use it,
1222 // * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
1223 // (0th) element and use that.
1224 SmallVector<Value> transferReadIdxs;
1225 for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
1226 Value idx = bvm.lookup(extractOp.getIndices()[i]);
1227 if (idx.getType().isIndex()) {
1228 transferReadIdxs.push_back(idx);
1229 continue;
1230 }
1231
1232 auto indexAs1dVector = vector::ShapeCastOp::create(
1233 rewriter, loc,
1234 VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
1235 resultType.getScalableDims().back()),
1236 idx);
1237 transferReadIdxs.push_back(
1238 vector::ExtractOp::create(rewriter, loc, indexAs1dVector, 0));
1239 }
1240
1241 // `tensor.extract_element` is always in-bounds, hence the following holds.
1242 auto dstRank = resultType.getRank();
1243 auto srcRank = extractOp.getTensor().getType().getRank();
1244 SmallVector<bool> inBounds(dstRank, true);
1245
1246 // 2a. Handle scalar broadcast access.
1247 if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
1248 MLIRContext *ctx = rewriter.getContext();
1249 SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
1250 auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);
1251
1252 auto transferReadOp = vector::TransferReadOp::create(
1253 rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs,
1254 /*padding=*/std::nullopt, permutationMap, inBounds);
1255
1256 // Mask this broadcasting xfer_read here rather than relying on the generic
1257 // path (the generic path assumes identity masking map, which wouldn't be
1258 // valid here).
1259 SmallVector<int64_t> readMaskShape = {1};
1260 auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
1261 auto allTrue = vector::ConstantMaskOp::create(
1262 rewriter, loc, readMaskType, vector::ConstantMaskKind::AllTrue);
1263 auto *maskedReadOp =
1264 mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);
1265
1266 LDBG() << "Vectorised as scalar broadcast load: " << extractOp;
1267 return VectorizationHookResult{VectorizationHookStatus::NewOp,
1268 maskedReadOp};
1269 }
1270
1271 // 2b. Handle contiguous access.
1272 auto permutationMap = AffineMap::getMinorIdentityMap(
1273 srcRank, std::min(dstRank, srcRank), rewriter.getContext());
1274
1275 int32_t rankDiff = dstRank - srcRank;
1276 // When dstRank > srcRank, broadcast the source tensor to the unitary leading
1277 // dims so that the ranks match. This is done by extending the map with 0s.
1278 // For example, for dstRank = 3, srcRank = 2, the following map created
1279 // above:
1280 // (d0, d1) --> (d0, d1)
1281 // is extended as:
1282 // (d0, d1) --> (0, d0, d1)
1283 while (rankDiff > 0) {
1284 permutationMap = permutationMap.insertResult(
1285 mlir::getAffineConstantExpr(0, rewriter.getContext()), 0);
1286 rankDiff--;
1287 }
1288
1289 auto transferReadOp = vector::TransferReadOp::create(
1290 rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs,
1291 /*padding=*/std::nullopt, permutationMap, inBounds);
1292
1293 LDBG() << "Vectorised as contiguous load: " << extractOp;
1294 return VectorizationHookResult{VectorizationHookStatus::NewOp,
1295 transferReadOp};
1296}
1297
1298/// Emit reduction operations if the shape of the value to reduce is different
1299/// from the result shape.
1300// Note: this is a true builder that notifies the OpBuilder listener.
1301// TODO: Consider moving as a static helper on the ReduceOp.
1302static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
1303 Value reduceValue, Value initialValue,
1304 const IRMapping &bvm) {
1305 Value reduceVec = bvm.lookup(reduceValue);
1306 Value outputVec = bvm.lookup(initialValue);
1307 auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
1308 auto outputType = dyn_cast<VectorType>(outputVec.getType());
1309 // Reduce only if needed as the value may already have been reduced for
1310 // contraction vectorization.
1311 if (!reduceType ||
1312 (outputType && reduceType.getShape() == outputType.getShape()))
1313 return nullptr;
1314 SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
1315 return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
1316}
1317
1318/// Generic vectorization for a single operation `op`, given already vectorized
1319/// operands carried by `bvm`. Vectorization occurs as follows:
1320/// 1. Try to apply any of the `customVectorizationHooks` and return its
1321/// result on success.
1322/// 2. Clone any constant in the current scope without vectorization: each
1323/// consumer of the constant will later determine the shape to which the
1324/// constant needs to be broadcast to.
1325/// 3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
1326/// of the `customVectorizationHooks` to cover such cases.
1327/// 4. Clone `op` in vector form to a vector of shape prescribed by the first
1328/// operand of maximal rank. Other operands have smaller rank and are
1329/// broadcast accordingly. It is assumed this broadcast is always legal,
1330/// otherwise, it means one of the `customVectorizationHooks` is incorrect.
1331///
1332/// This function assumes all operands of `op` have been vectorized and are in
1333/// the `bvm` mapping. As a consequence, this function is meant to be called on
1334/// a topologically-sorted list of ops.
1335/// This function does not update `bvm` but returns a VectorizationHookStatus
1336/// that instructs the caller what `bvm` update needs to occur.
1337static VectorizationHookResult
1338vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
1339 LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
1340 ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
1341 LDBG() << "vectorize op " << *op;
1342
1343 // 1. Try to apply any CustomVectorizationHook.
1344 if (!customVectorizationHooks.empty()) {
1345 for (auto &customFunc : customVectorizationHooks) {
1346 VectorizationHookResult result = customFunc(op, bvm);
1347 if (result.status == VectorizationHookStatus::Failure)
1348 continue;
1349 return result;
1350 }
1351 }
1352
1353 // 2. Constant ops don't get vectorized but rather broadcasted at their users.
1354 // Clone so that the constant is not confined to the linalgOp block.
1355 if (isa<arith::ConstantOp, func::ConstantOp>(op))
1356 return VectorizationHookResult{VectorizationHookStatus::NewOp,
1357 rewriter.clone(*op)};
1358
1359 // 3. Only ElementwiseMappable are allowed in the generic vectorization.
1360 if (!OpTrait::hasElementwiseMappableTraits(op))
1361 return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
1362
1363 // 4. Check if the operation is a reduction.
1364 SmallVector<std::pair<Value, Value>> reductionOperands;
1365 for (Value operand : op->getOperands()) {
1366 auto blockArg = dyn_cast<BlockArgument>(operand);
1367 if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
1368 blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
1369 continue;
1370 SmallVector<Operation *> reductionOps;
1371 Value reduceValue = matchReduction(
1372 linalgOp.getRegionOutputArgs(),
1373 blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
1374 if (!reduceValue)
1375 continue;
1376 reductionOperands.push_back(std::make_pair(reduceValue, operand));
1377 }
1378 if (!reductionOperands.empty()) {
1379 assert(reductionOperands.size() == 1);
1380 Operation *reduceOp =
1381 reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
1382 reductionOperands[0].second, bvm);
1383 if (reduceOp)
1384 return VectorizationHookResult{VectorizationHookStatus::NewOp, reduceOp};
1385 }
1386
1387 // 5. Generic vectorization path for ElementwiseMappable ops.
1388 // a. Get the first max ranked shape.
1389 VectorType firstMaxRankedType;
1390 for (Value operand : op->getOperands()) {
1391 auto vecOperand = bvm.lookup(operand);
1392 assert(vecOperand && "Vector operand couldn't be found");
1393
1394 auto vecType = dyn_cast<VectorType>(vecOperand.getType());
1395 if (vecType && (!firstMaxRankedType ||
1396 firstMaxRankedType.getRank() < vecType.getRank()))
1397 firstMaxRankedType = vecType;
1398 }
1399 // b. Broadcast each op if needed.
1400 SmallVector<Value> vecOperands;
1401 for (Value scalarOperand : op->getOperands()) {
1402 Value vecOperand = bvm.lookup(scalarOperand);
1403 assert(vecOperand && "Vector operand couldn't be found");
1404
1405 if (firstMaxRankedType) {
1406 auto vecType = VectorType::get(firstMaxRankedType.getShape(),
1407 getElementTypeOrSelf(vecOperand.getType()),
1408 firstMaxRankedType.getScalableDims());
1409 vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
1410 } else {
1411 vecOperands.push_back(vecOperand);
1412 }
1413 }
1414 // c. for elementwise, the result is the vector with the firstMaxRankedShape
1415 SmallVector<Type> resultTypes;
1416 for (Type resultType : op->getResultTypes()) {
1417 resultTypes.push_back(
1418 firstMaxRankedType
1419 ? VectorType::get(firstMaxRankedType.getShape(), resultType,
1420 firstMaxRankedType.getScalableDims())
1421 : resultType);
1422 }
1423 // d. Build and return the new op.
1424 return VectorizationHookResult{
1425 VectorizationHookStatus::NewOp,
1426 rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
1427 resultTypes, op->getAttrs())};
1428}
1429
1430/// Generic vectorization function that rewrites the body of a `linalgOp` into
1431/// vector form. Generic vectorization proceeds as follows:
1432/// 1. Verify the `linalgOp` has one non-empty region.
1433/// 2. Values defined above the region are mapped to themselves and will be
1434/// broadcasted on a per-need basis by their consumers.
1435/// 3. Each region argument is vectorized into a vector.transfer_read (or 0-d
1436/// load).
1437/// TODO: Reuse opportunities for RAR dependencies.
1438/// 4a. Register CustomVectorizationHook for YieldOp to capture the results.
1439/// 4b. Register CustomVectorizationHook for IndexOp to access the
1440/// iteration indices.
1441/// 5. Iteratively call vectorizeOneOp on the region operations.
1442///
1443/// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is
1444/// performed to the maximal common vector size implied by the `linalgOp`
1445/// iteration space. This eager broadcasting is introduced in the
1446/// permutation_map of the vector.transfer_read operations. The eager
1447/// broadcasting makes it trivial to determine where broadcast, transposes and
1448/// reductions should occur, without any bookkeeping. The tradeoff is that, in
1449/// the absence of good canonicalizations, the amount of work increases.
1450/// This is not deemed a problem as we expect canonicalizations and foldings to
1451/// aggressively clean up the useless work.
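///
/// As a rough illustration only (exact attributes and permutation maps
/// omitted), an elementwise generic such as:
/// ```
///   %0 = linalg.generic ... ins(%a, %b : tensor<8xf32>, tensor<8xf32>)
///                           outs(%init : tensor<8xf32>) {
///   ^bb0(%in0: f32, %in1: f32, %out: f32):
///     %s = arith.addf %in0, %in1 : f32
///     linalg.yield %s : f32
///   } -> tensor<8xf32>
/// ```
/// is rewritten along these lines:
/// ```
///   %va = vector.transfer_read %a[%c0], %pad : tensor<8xf32>, vector<8xf32>
///   %vb = vector.transfer_read %b[%c0], %pad : tensor<8xf32>, vector<8xf32>
///   %vs = arith.addf %va, %vb : vector<8xf32>
///   %0 = vector.transfer_write %vs, %init[%c0] : vector<8xf32>, tensor<8xf32>
/// ```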
1452static LogicalResult
1453vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
1454 LinalgOp linalgOp,
1455 SmallVectorImpl<Value> &newResults) {
1456 LDBG() << "Vectorizing operation as linalg generic";
1457 Block *block = linalgOp.getBlock();
1458
1459 // 2. Values defined above the region can only be broadcast for now. Make them
1460 // map to themselves.
1461 IRMapping bvm;
1462 SetVector<Value> valuesSet;
1463 mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
1464 bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
1465
1466 if (linalgOp.getNumDpsInits() == 0)
1467 return failure();
1468
1469 // 3. Turn all BBArgs into vector.transfer_read / load.
1470 Location loc = linalgOp.getLoc();
1471 Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
1472 for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1473 BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
1474 if (linalgOp.isScalar(opOperand)) {
1475 bvm.map(bbarg, opOperand->get());
1476 continue;
1477 }
1478
1479 // 3.a. Convert the indexing map for this input/output to a transfer read
1480 // permutation map and masking map.
1481 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
1482
1483 AffineMap readMap;
1484 VectorType readType;
1485 Type elemType = getElementTypeOrSelf(opOperand->get());
1486 if (linalgOp.isDpsInput(opOperand)) {
1487 // 3.a.i. For input reads we use the canonical vector shape.
1488 readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
1489 readType = state.getCanonicalVecType(elemType);
1490 } else {
1491 // 3.a.ii. For output reads (iteration-carried dependence, e.g.,
1492 // reductions), the vector shape is computed by mapping the canonical
1493 // vector shape to the output domain and back to the canonical domain.
1494 readMap = inversePermutation(reindexIndexingMap(indexingMap));
1495 readType =
1496 state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
1497 }
1498
1499 SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
1500
1501 Operation *read = vector::TransferReadOp::create(
1502 rewriter, loc, readType, opOperand->get(), indices,
1503 /*padding=*/std::nullopt, readMap);
1504 read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
1505 Value readValue = read->getResult(0);
1506
1507 // 3.b. If masked, set in-bounds to true. Masking guarantees that the access
1508 // will be in-bounds.
1509 if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
1510 SmallVector<bool> inBounds(readType.getRank(), true);
1511 cast<vector::TransferReadOp>(maskOp.getMaskableOp())
1512 .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
1513 }
1514
1515 // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
1516 // TODO: remove this.
1517 if (readType.getRank() == 0)
1518 readValue = vector::ExtractOp::create(rewriter, loc, readValue,
1519 ArrayRef<int64_t>());
1520
1521 LDBG() << "New vectorized bbarg(" << bbarg.getArgNumber()
1522 << "): " << readValue;
1523 bvm.map(bbarg, readValue);
1524 bvm.map(opOperand->get(), readValue);
1525 }
1526
1527 SmallVector<CustomVectorizationHook> hooks;
1528 // 4a. Register CustomVectorizationHook for yieldOp.
1529 CustomVectorizationHook vectorizeYield =
1530 [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1531 return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
1532 };
1533 hooks.push_back(vectorizeYield);
1534
1535 // 4b. Register CustomVectorizationHook for indexOp.
1536 CustomVectorizationHook vectorizeIndex =
1537 [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1538 return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
1539 };
1540 hooks.push_back(vectorizeIndex);
1541
1542 // 4c. Register CustomVectorizationHook for extractOp.
1543 CustomVectorizationHook vectorizeExtract =
1544 [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1545 return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
1546 };
1547 hooks.push_back(vectorizeExtract);
1548
1549 // 5. Iteratively call `vectorizeOneOp` on each op in the block.
1550 for (Operation &op : block->getOperations()) {
1551 VectorizationHookResult result =
1552 vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
1553 if (result.status == VectorizationHookStatus::Failure) {
1554 LDBG() << "failed to vectorize: " << op;
1555 return failure();
1556 }
1557 if (result.status == VectorizationHookStatus::NewOp) {
1558 Operation *maybeMaskedOp =
1559 state.maskOperation(rewriter, result.newOp, linalgOp);
1560 LDBG() << "New vector op: " << *maybeMaskedOp;
1561 bvm.map(op.getResults(), maybeMaskedOp->getResults());
1562 }
1563 }
1564
1565 return success();
1566}
1567
1568/// Determines whether a mask for xfer_write is trivially "all true"
1569///
1570/// Given all the inputs required to generate a mask (mask sizes and shapes),
1571/// and an xfer_write operation (write indices and the destination tensor
1572/// shape), determines whether the corresponding mask would be trivially
1573/// foldable (i.e., trivially "all true").
1574///
1575/// Use this method to avoid generating spurious masks and relying on
1576/// vectorization post-processing to remove them.
1577///
1578/// Pre-conditions for a mask to be trivially foldable:
1579/// * All involved shapes (mask + destination tensor) are static.
1580/// * All write indices are constant.
1581/// * All mask sizes are constant (including `arith.constant`).
1582///
1583/// If the pre-conditions are met, the method checks for each destination
1584/// dimension `d`:
1585/// (1) destDimSize[rankDiff + d] <= maskShape[d]
1586/// (2) destDimSize[rankDiff + d] <= writeIndex[d] + maskSize[d]
1587///
1588/// rankDiff = rank(dest) - rank(mask).
1589///
1590/// This method takes a conservative view: it may return false even if the mask
1591/// is technically foldable.
1592///
1593/// EXAMPLE 1 (trivially foldable, all shapes match, mask sizes match the shape
1594/// of the dest tensor):
1595/// %c0 = arith.constant 0 : index
1596/// %mask = vector.create_mask 5, 1
1597/// vector.mask %mask {
1598/// vector.transfer_write %vecToStore_1, %dest[%c0, %c0]
1599/// {in_bounds = [true, true]}
1600/// : vector<5x1xi32>, tensor<5x1xi32>
1601/// }
1602///
1603/// EXAMPLE 2 (not trivially foldable - vector shape exceeds the tensor shape,
1604/// mask is required to avoid out-of-bounds write):
1605/// %c0 = arith.constant 0 : index
1606/// %mask = vector.create_mask 5, 1
1607/// vector.mask %mask {
1608/// vector.transfer_write %vecToStore_2, %dest[%c0, %c0]
1609/// {in_bounds = [true, true]}
1610/// : vector<8x1xi32>, tensor<5x1xi32>
1611/// }
1612///
1613/// TODO: Re-use in createReadOrMaskedRead
1614static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
1615 SmallVector<Value> &writeIdxs,
1616 ArrayRef<int64_t> destShape,
1617 ArrayRef<int64_t> maskShape) {
1618 // Masking is unavoidable in the case of dynamic tensors.
1619 if (ShapedType::isDynamicShape(destShape))
1620 return false;
1621
1622 // Collect all constant mask sizes.
1623 SmallVector<int64_t, 4> cstMaskSizes;
1624 for (auto [i, dimSize] : llvm::enumerate(maskSizes)) {
1625 if (auto intSize = getConstantIntValue(dimSize)) {
1626 cstMaskSizes.push_back(*intSize);
1627 }
1628 }
1629
1630 // If any of the mask sizes is non-constant, bail out.
1631 if (cstMaskSizes.size() != maskShape.size())
1632 return false;
1633
1634 // Collect all constant write indices.
1635 SmallVector<int64_t, 4> cstWriteIdxs;
1636 for (auto [i, idx] : llvm::enumerate(writeIdxs)) {
1637 APSInt intVal;
1638 if (matchPattern(idx, m_ConstantInt(&intVal))) {
1639 cstWriteIdxs.push_back(intVal.getSExtValue());
1640 }
1641 }
1642
1643 // If any of the write indices is non-constant, bail out.
1644 if (cstWriteIdxs.size() != destShape.size())
1645 return false;
1646
1647 // Go over all destination dims and check (1) and (2). Take into account that:
1648 // * The number of mask sizes will match the rank of the vector to store.
1649 // This could be lower than the rank of the destination tensor.
1650 // * Mask sizes could be larger than the corresponding mask shape (hence
1651 // `clamp`).
1652 // TODO: The 2nd item should be rejected by the verifier.
1653 int64_t rankDiff = destShape.size() - cstMaskSizes.size();
1654 for (auto [i, idx] : llvm::enumerate(cstMaskSizes)) {
1655 if (/*(1)*/ maskShape[i] > destShape[rankDiff + i] ||
1656 /*(2)*/ destShape[rankDiff + i] <
1657 (std::clamp(cstMaskSizes[i], int64_t(0), maskShape[i]) +
1658 cstWriteIdxs[i]))
1659 return false;
1660 }
1661
1662 return true;
1663}
1664
1665/// Creates an optionally masked TransferWriteOp
1666///
1667/// Generates the following operation:
1668/// %res = vector.transfer_write %vecToStore into %dest
1669///
1670/// If shape(vecToStore) != shape(dest), masking is used to ensure correctness:
1671///
1672/// %mask = vector.create_mask(%destShape) : %vecToStoreShape
1673/// %res = vector.mask %mask {
1674/// vector.transfer_write %vecToStore into %dest
1675/// }
1676///
1677/// The mask shape is identical to `vecToStore` (with the element type ==
1678/// i1), and the mask values are based on the shape of the `dest` tensor.
1679///
1680/// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
1681/// is used instead of masking:
1682///
1683/// in_bounds_flags = (...)
1684/// %res = vector.transfer_write %vecToStore into %dest
1685/// {in_bounds = in_bounds_flags}
1687///
1688/// Finally, `writeIndices` specifies the offsets to use. If empty, all indices
1689/// are set to 0.
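///
/// As a rough sketch only (types and attributes here are illustrative
/// assumptions, not the exact output), writing a vector<4x8xf32> into a
/// tensor<?x8xf32> destination yields something like:
///
///   %d0 = tensor.dim %dest, %c0 : tensor<?x8xf32>
///   %mask = vector.create_mask %d0, %c8 : vector<4x8xi1>
///   %res = vector.mask %mask {
///     vector.transfer_write %vecToStore, %dest[%c0, %c0]
///       : vector<4x8xf32>, tensor<?x8xf32>
///   } : vector<4x8xi1> -> tensor<?x8xf32>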
1690static Operation *
1691createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
1692 Value dest, SmallVector<Value> writeIndices = {},
1693 bool useInBoundsInsteadOfMasking = false) {
1694
1695 ShapedType destType = cast<ShapedType>(dest.getType());
1696 int64_t destRank = destType.getRank();
1697 auto destShape = destType.getShape();
1698
1699 VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
1700 int64_t vecToStoreRank = vecToStoreType.getRank();
1701 auto vecToStoreShape = vecToStoreType.getShape();
1702
1703 // Compute the in_bounds attribute
1704 SmallVector<bool> inBoundsVal(vecToStoreRank, true);
1705 if (useInBoundsInsteadOfMasking) {
1706 // Update the inBounds attribute.
1707 // FIXME: This computation is too weak - it ignores the write indices.
1708 for (unsigned i = 0; i < vecToStoreRank; i++)
1709 inBoundsVal[i] =
1710 (destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) &&
1711 ShapedType::isStatic(destShape[destRank - vecToStoreRank + i]);
1712 }
1713
1714 // If missing, initialize the write indices to 0.
1715 bool useDefaultWriteIdxs = writeIndices.empty();
1716 assert((useDefaultWriteIdxs ||
1717 writeIndices.size() == static_cast<size_t>(destRank)) &&
1718 "Invalid number of write indices!");
1719 if (writeIndices.empty()) {
1720 auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
1721 writeIndices.assign(destRank, zero);
1722 }
1723
1724 // Generate the xfer_write Op
1725 Operation *write = vector::TransferWriteOp::create(builder, loc,
1726 /*vector=*/vecToStore,
1727 /*source=*/dest,
1728 /*indices=*/writeIndices,
1729 /*inBounds=*/inBoundsVal);
1730
1731 // If masking is disabled, exit.
1732 if (useInBoundsInsteadOfMasking)
1733 return write;
1734
1735 // Check if masking is needed. If not, exit.
1736 if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
1737 return write;
1738
1739 // Compute the mask and mask the write Op.
1740 auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type(),
1741 vecToStoreType.getScalableDims());
1742
1743 SmallVector<OpFoldResult> destSizes =
1744 isa<MemRefType>(dest.getType())
1745 ? memref::getMixedSizes(builder, loc, dest)
1746 : tensor::getMixedSizes(builder, loc, dest);
1747
1748 // Compute sizes for write-mask
1749 SmallVector<OpFoldResult> maskSizes;
1750 if (useDefaultWriteIdxs) {
1751 maskSizes = SmallVector<OpFoldResult>(destSizes.end() - vecToStoreRank,
1752 destSizes.end());
1753 } else {
1754 size_t diff = destShape.size() - vecToStoreRank;
1755 for (int64_t idx = 0; idx < vecToStoreRank; idx++) {
1756 auto value =
1757 getValueOrCreateConstantIndexOp(builder, loc, destSizes[diff + idx]);
1758 auto neg =
1759 builder.createOrFold<arith::SubIOp>(loc, value, writeIndices[idx]);
1760 maskSizes.push_back(OpFoldResult(neg));
1761 }
1762 }
1763
1764 if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
1765 vecToStoreShape))
1766 return write;
1767
1768 Value maskForWrite =
1769 builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
1770 return mlir::vector::maskOperation(builder, write, maskForWrite);
1771}
1772
1773/// Given the re-associations, "collapses" the input Vector type
1774///
1775/// This is similar to CollapseShapeOp::inferCollapsedType with two notable
1776/// differences:
1777/// * We can safely assume that there are no dynamic sizes.
1778/// * Scalable flags are updated alongside regular dims.
1779///
1780/// When collapsing scalable flags, conservatively avoids cases with two
1781/// scalable dims. We could re-visit this in the future.
1782///
1783/// EXAMPLE:
1784/// type = vector<4x16x[8]x16xf32>
1785/// reassociation = [(d0, d1, d2, d3) -> (d0, d1),
1786/// (d0, d1, d2, d3) -> (d2, d3)]
1787/// Result:
1788/// vector<64x[128]xf32>
1789static VectorType getCollapsedVecType(VectorType type,
1790 ArrayRef<AffineMap> reassociation) {
1791 assert(type.getNumScalableDims() < 2 &&
1792 "Collapsing more than 1 scalable dim is not supported ATM");
1793
1794 // Use the fact that reassociation is valid to simplify the logic: only use
1795 // each map's rank.
1796 assert(isReassociationValid(reassociation) && "invalid reassociation");
1797
1798 auto shape = type.getShape();
1799 auto scalableFlags = type.getScalableDims();
1800 SmallVector<int64_t> newShape;
1801 SmallVector<bool> newScalableFlags;
1802
1803 unsigned currentDim = 0;
1804 for (AffineMap m : reassociation) {
1805 unsigned dim = m.getNumResults();
1806 int64_t size = 1;
1807 bool flag = false;
1808 for (unsigned d = 0; d < dim; ++d) {
1809 size *= shape[currentDim + d];
1810 flag |= scalableFlags[currentDim + d];
1811 }
1812 newShape.push_back(size);
1813 newScalableFlags.push_back(flag);
1814 currentDim += dim;
1815 }
1816
1817 return VectorType::get(newShape, type.getElementType(), newScalableFlags);
1818}
1819
1820/// Vectorize `linalg.pack` as:
1821/// * xfer_read -> shape_cast -> transpose -> xfer_write
1822///
1823/// The input-vector-sizes specify the _write_ vector sizes (i.e. the vector
1824/// sizes for the xfer_write operation). This is sufficient to infer the other
1825/// vector sizes required here.
1826///
1827/// If the vector sizes are not provided:
1828/// * the vector sizes are determined from the destination tensor static shape.
1829/// * the inBounds attribute is used instead of masking.
1830///
1831/// EXAMPLE (no vector sizes):
1832/// ```
1833/// %pack = linalg.pack %src
1834/// inner_dims_pos = [2, 1]
1835/// inner_tiles = [16, 2]
1836/// into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
1837/// ```
1838/// is vectorized as:
1839/// ```
1840/// %read = vector.transfer_read %src
1841/// : tensor<32x8x16xf32>, vector<32x8x16xf32>
1842/// %sc = vector.shape_cast %read
1843/// : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
1844/// %tr = vector.transpose %sc, [0, 1, 3, 4, 2]
1845/// : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
1846/// %write = vector.transfer_write %tr into %dest
1847/// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
1848/// ```
1849static LogicalResult
1850vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
1851 ArrayRef<int64_t> inputVectorSizes,
1852 SmallVectorImpl<Value> &newResults) {
1853 if (!inputVectorSizes.empty()) {
1854 assert(inputVectorSizes.size() == packOp.getDestRank() &&
1855 "Invalid number of input vector sizes!");
1856 }
1857
1858 // TODO: Introduce a parent class that will handle the insertion point update.
1859 OpBuilder::InsertionGuard g(rewriter);
1860 rewriter.setInsertionPoint(packOp);
1861
1862 Location loc = packOp.getLoc();
1863 std::optional<Value> padValue = packOp.getPaddingValue()
1864 ? std::optional(packOp.getPaddingValue())
1865 : std::nullopt;
1866
1867 SmallVector<int64_t> destShape =
1868 SmallVector<int64_t>(packOp.getDestType().getShape());
1869
1870 // This is just a convenience alias to clearly communicate that the input
1871 // vector sizes determine the _write_ sizes.
1872 ArrayRef<int64_t> &writeVectorSizes = inputVectorSizes;
1873
1874 // In the absence of input-vector-sizes, use the _static_ input tensor shape.
1875 // In addition, use the inBounds attribute instead of masking.
1876 bool useInBoundsInsteadOfMasking = false;
1877 if (writeVectorSizes.empty()) {
1878 if (ShapedType::isDynamicShape(destShape))
1879 return rewriter.notifyMatchFailure(packOp,
1880 "unable to infer vector sizes");
1881
1882 writeVectorSizes = destShape;
1883 useInBoundsInsteadOfMasking = true;
1884 }
1885
1886 // Compute pre-transpose-write-vector-type, i.e. the write vector type
1887 // _before_ the transposition (i.e. before dimension permutation). This is
1888 // done by inverting the permutation/transposition that's part of the Pack
1889 // operation. This type is required to:
1890 // 1) compute the read vector type for masked-read below, and
1891 // 2) generate shape-cast Op below that expands the read vector type.
1892 PackingMetadata packMetadata;
1893 SmallVector<int64_t> preTransposeWriteVecSizes(writeVectorSizes);
1894 auto destInvPermutation = getPackInverseDestPerm(packOp, packMetadata);
1895 applyPermutationToVector(preTransposeWriteVecSizes, destInvPermutation);
1896 auto preTransposeWriteVecType =
1897 VectorType::get(preTransposeWriteVecSizes,
1898 packOp.getResult().getType().getElementType());
1899
1900 // Compute the vector type for the _read_ operation. This is simply
1901 // pre-transpose-write-vector-type with the dimensions collapsed
1902 // as per the Pack operation.
1903 VectorType readVecType = getCollapsedVecType(
1904 preTransposeWriteVecType,
1905 getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1906 rewriter.getContext(), packMetadata.reassociations)));
1907
1908 // Create masked TransferReadOp.
1909 auto maskedRead = vector::createReadOrMaskedRead(
1910 rewriter, loc, packOp.getSource(), readVecType, padValue,
1911 useInBoundsInsteadOfMasking);
1912
1913 // Create ShapeCastOp.
1914 auto shapeCastOp = vector::ShapeCastOp::create(
1915 rewriter, loc, preTransposeWriteVecType, maskedRead);
1916
1917 // Create TransposeOp.
1918 auto destPermutation = invertPermutationVector(destInvPermutation);
1919 auto transposeOp = vector::TransposeOp::create(
1920 rewriter, loc, shapeCastOp.getResult(), destPermutation);
1921
1922 // Create TransferWriteOp.
1923 Operation *write = createWriteOrMaskedWrite(
1924 rewriter, loc, transposeOp.getResult(), packOp.getDest());
1925 newResults.push_back(write->getResult(0));
1926 return success();
1927}
1928
1929/// Vectorize `linalg.unpack` as:
1930/// * xfer_read -> vector.transpose -> vector.shape_cast -> xfer_write
1931///
1932/// The input-vector-sizes specify the _read_ vector sizes (i.e. the vector
1933/// sizes for the xfer_read operation). This is sufficient to infer the other
1934/// vector sizes required here.
1935///
1936/// If the vector sizes are not provided:
1937/// * the vector sizes are determined from the input tensor static shape.
1938/// * the inBounds attribute is used instead of masking.
1939///
1940/// EXAMPLE (no vector sizes):
1941/// ```
1942/// %unpack = linalg.unpack %src
1943/// inner_dims_pos = [0, 1]
1944/// inner_tiles = [8, 8]
1945/// into %dest : tensor<1x1x8x8xf32> -> tensor<8x8xf32>
1946/// ```
1947/// is vectorized as:
1948/// ```
1949/// %read = vector.transfer_read %src
1950/// : tensor<1x1x8x8xf32>, vector<1x1x8x8xf32>
1951/// %tr = vector.transpose %read, [0, 2, 1, 3]
1952/// : vector<1x1x8x8xf32> to vector<1x8x1x8xf32>
1953/// %sc = vector.shape_cast %tr
1954/// : vector<1x8x1x8xf32> to vector<8x8xf32>
1955/// %vector = vector.transfer_write %sc into %dest
1956/// : vector<8x8xf32>, tensor<8x8xf32>
1957/// ```
1958static LogicalResult
1959vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
1960 ArrayRef<int64_t> inputVectorSizes,
1961 ArrayRef<bool> inputScalableVecDims,
1962 SmallVectorImpl<Value> &newResults) {
1963 if (!inputVectorSizes.empty()) {
1964 assert(inputVectorSizes.size() == unpackOp.getSourceRank() &&
1965 "Invalid number of input vector sizes!");
1966 assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
1967 "Incompatible number of vector sizes and vector scalable flags!");
1968 }
1969
1970 // TODO: Introduce a parent class that will handle the insertion point update.
1971 OpBuilder::InsertionGuard g(rewriter);
1972 rewriter.setInsertionPoint(unpackOp);
1973
1974 ShapedType unpackTensorType = unpackOp.getSourceType();
1975
1976 ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
1977 bool useInBoundsInsteadOfMasking = false;
1978
1979 Location loc = unpackOp->getLoc();
1980
1981 // Obtain vector sizes for the read operation.
1982 SmallVector<int64_t> readVectorSizes(inputVectorSizes);
1983 SmallVector<bool> readScalableVectorFlags(inputScalableVecDims);
1984
1985 // In the absence of input-vector-sizes, use the _static_ input tensor shape.
1986 if (inputVectorSizes.empty()) {
1987 if (ShapedType::isDynamicShape(sourceShape))
1988 return rewriter.notifyMatchFailure(unpackOp,
1989 "Unable to infer vector sizes!");
1990
1991 readVectorSizes.assign(sourceShape.begin(), sourceShape.end());
1992 useInBoundsInsteadOfMasking = true;
1993 }
1994
1995 // -- Generate the read operation --
1996 VectorType readVecType =
1997 VectorType::get(readVectorSizes, unpackTensorType.getElementType(),
1998 readScalableVectorFlags);
1999 Value readResult = vector::createReadOrMaskedRead(
2000 rewriter, loc, unpackOp.getSource(), readVecType, std::nullopt,
2001 useInBoundsInsteadOfMasking);
2002
2003 // -- Generate the transpose operation --
2004 PackingMetadata packMetadata;
2005 SmallVector<int64_t> lastDimToInsertPosPerm =
2006 getUnPackInverseSrcPerm(unpackOp, packMetadata);
2007 vector::TransposeOp transposeOp = vector::TransposeOp::create(
2008 rewriter, loc, readResult, lastDimToInsertPosPerm);
2009
2010 // -- Generate the shape_cast operation --
2011 VectorType collapsedVecType = getCollapsedVecType(
2012 transposeOp.getType(),
2013 getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
2014 rewriter.getContext(), packMetadata.reassociations)));
2015 vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create(
2016 rewriter, loc, collapsedVecType, transposeOp->getResult(0));
2017
2018 // -- Generate the write operation --
2019 Operation *write = createWriteOrMaskedWrite(
2020 rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
2021 /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
2022
2023 newResults.push_back(write->getResult(0));
2024 return success();
2025}
2026
2027/// Vectorize a `padOp` with (1) static result type, (2) constant padding value
2028/// and (3) all-zero lowPad to
2029/// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
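///
/// EXAMPLE (illustrative sketch only; in_bounds, the pad region block
/// arguments and exact attributes are omitted), with input vector sizes
/// [2, 4]:
/// ```
///   %padded = tensor.pad %src low[0, 0] high[%h0, %h1] {
///     tensor.yield %cst : f32
///   } : tensor<?x?xf32> to tensor<2x4xf32>
/// ```
/// is vectorized roughly as:
/// ```
///   %d0 = tensor.dim %src, %c0 : tensor<?x?xf32>
///   %d1 = tensor.dim %src, %c1 : tensor<?x?xf32>
///   %mask = vector.create_mask %d0, %d1 : vector<2x4xi1>
///   %read = vector.mask %mask {
///     vector.transfer_read %src[%c0, %c0], %cst
///       : tensor<?x?xf32>, vector<2x4xf32>
///   } : vector<2x4xi1> -> vector<2x4xf32>
///   %init = tensor.empty() : tensor<2x4xf32>
///   %res = vector.transfer_write %read, %init[%c0, %c0]
///     : vector<2x4xf32>, tensor<2x4xf32>
/// ```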
2030static LogicalResult
2031vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
2032 ArrayRef<int64_t> inputVectorSizes,
2033 SmallVectorImpl<Value> &newResults) {
2034 auto padValue = padOp.getConstantPaddingValue();
2035 Location loc = padOp.getLoc();
2036
2037 // TODO: Introduce a parent class that will handle the insertion point update.
2038 OpBuilder::InsertionGuard g(rewriter);
2039 rewriter.setInsertionPoint(padOp);
2040
2041 ReifiedRankedShapedTypeDims reifiedReturnShapes;
2042 LogicalResult status =
2043 cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
2044 .reifyResultShapes(rewriter, reifiedReturnShapes);
2045 (void)status; // prevent unused variable warning on non-assert builds
2046 assert(succeeded(status) && "failed to reify result shapes");
2047 auto readType = VectorType::get(inputVectorSizes, padValue.getType());
2048 auto maskedRead = vector::createReadOrMaskedRead(
2049 rewriter, loc, padOp.getSource(), readType, padValue,
2050 /*useInBoundsInsteadOfMasking=*/false);
2051
2052 // Create Xfer write Op
2053 Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0],
2054 padOp.getResultType().getElementType());
2055 Operation *write = createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest);
2056 newResults.push_back(write->getResult(0));
2057 return success();
2058}
2059
2060// TODO: probably need some extra checks for reduction followed by consumer
2061// ops that may not commute (e.g. linear reduction + non-linear instructions).
2062static LogicalResult reductionPreconditions(LinalgOp op) {
2063 if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
2064 LDBG() << "reduction precondition failed: no reduction iterator";
2065 return failure();
2066 }
2067 for (OpOperand &opOperand : op.getDpsInitsMutable()) {
2068 AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
2069 if (indexingMap.isPermutation())
2070 continue;
2071
2072 Operation *reduceOp = matchLinalgReduction(&opOperand);
2073 if (!reduceOp || !getCombinerOpKind(reduceOp)) {
2074 LDBG() << "reduction precondition failed: reduction detection failed";
2075 return failure();
2076 }
2077 }
2078 return success();
2079}
2080
2081static LogicalResult
2082vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
2083 bool flatten1DDepthwiseConv) {
2084 if (flatten1DDepthwiseConv) {
2085 LDBG() << "Vectorization of flattened convs with dynamic shapes is not "
2086 "supported";
2087 return failure();
2088 }
2089
2090 if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
2091 LDBG() << "Not a 1D depth-wise WC conv, dynamic shapes are not supported";
2092 return failure();
2093 }
2094
2095 // Support dynamic shapes in 1D depthwise convolution, but only in the
2096 // _channel_ dimension.
2097 Value lhs = conv.getDpsInputOperand(0)->get();
2098 ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
2099 auto shapeWithoutCh = lhsShape.drop_back(1);
2100 if (ShapedType::isDynamicShape(shapeWithoutCh)) {
2101 LDBG() << "Dynamically-shaped op vectorization precondition failed: only "
2102 "channel dim can be dynamic";
2103 return failure();
2104 }
2105
2106 return success();
2107}
2108
2109static LogicalResult
2110vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
2111 bool flatten1DDepthwiseConv) {
2112 if (isa<ConvolutionOpInterface>(op.getOperation()))
2113 return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);
2114
2115 if (hasReductionIterator(op))
2116 return reductionPreconditions(op);
2117
2118 // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
2119 // linalg.copy ops and ops that implement ContractionOpInterface for now.
2120 if (!isElementwise(op) &&
2121 !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
2122 op.getOperation()))
2123 return failure();
2124
2125 LDBG() << "Dynamically-shaped op meets vectorization pre-conditions";
2126 return success();
2127}
2128
2129/// This hook considers two cases:
2130/// (1) If the input-vector-sizes are empty, then the vector sizes will be
2131/// inferred. This is only possible when all shapes are static.
2132/// (2) If the input-vector-sizes are non-empty (i.e. user provided), then
2133/// carry out basic sanity-checking.
2134static LogicalResult
2135vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
2136 ArrayRef<int64_t> inputVectorSizes) {
2137 // TODO: Support Memref UnPackOp. Temporarily return failure.
2138 if (!unpackOp.hasPureTensorSemantics())
2139 return failure();
2140
2141 // If there are no input vector sizes and all shapes are static, there is
2142 // nothing left to check.
2143 if (inputVectorSizes.empty() && unpackOp.getDestType().hasStaticShape() &&
2144 unpackOp.getSourceType().hasStaticShape())
2145 return success();
2146
2147 // The number of input vector sizes must be equal to:
2148 // * read-vector-rank
2149 if (!inputVectorSizes.empty() &&
2150 (inputVectorSizes.size() != unpackOp.getSourceRank())) {
2151 LDBG() << "Incorrect number of input vector sizes";
2152 return failure();
2153 }
2154
2155 // Check the vector sizes for the read operation.
2156 if (failed(vector::isValidMaskedInputVector(
2157 unpackOp.getSourceType().getShape(), inputVectorSizes))) {
2158 LDBG() << "Invalid vector sizes for the read operation";
2159 return failure();
2160 }
2161
2162 return success();
2163}
2164
2165static LogicalResult
2166vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
2167 ArrayRef<int64_t> inputVectorSizes) {
2168
2169 TypedValue<RankedTensorType> source = sliceOp.getSource();
2170 auto sourceType = source.getType();
2171 if (!VectorType::isValidElementType(sourceType.getElementType()))
2172 return failure();
2173
2174 // Get the pad value.
2175 // TransferReadOp (which is used to vectorize InsertSliceOp), requires a
2176 // scalar padding value. Note that:
2177 // * for in-bounds accesses,
2178 // the value is actually irrelevant. There are 2 cases in which xfer.read
2179 // accesses are known to be in-bounds:
2180 // 1. The source shape is static (output vector sizes would be based on
2181 // the source shape and hence all memory accesses would be in-bounds),
2182 // 2. Masking is used, i.e. the output vector sizes are user-provided. In
2183 // this case it is safe to assume that all memory accesses are in-bounds.
2184 //
2185 // When the value is not known and not needed, use 0. Otherwise, bail out.
2186 Value padValue = getStaticPadVal(sliceOp);
2187 bool isOutOfBoundsRead =
2188 !sourceType.hasStaticShape() && inputVectorSizes.empty();
2189
2190 if (!padValue && isOutOfBoundsRead) {
2191 LDBG() << "Failed to get a pad value for out-of-bounds read access";
2192 return failure();
2193 }
2194 return success();
2195}
2196
2197/// Vectorize a named linalg contraction op into:
2198/// vector::TransferReadOp - Reads vectors from the operands
2199/// vector::ContractionOp - Performs contraction
2200/// vector::TransferWriteOp - Write the result vector back to the
2201/// destination
2202/// The operands shapes are preserved and loaded directly into vectors.
2203/// Any further permutations or numerical casting remain within contraction op.
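///
/// EXAMPLE (rough sketch only; masking and the exact indexing maps are
/// omitted): a static `linalg.matmul` on tensor<4x8xf32> x tensor<8x16xf32>
/// becomes something like:
/// ```
///   %va = vector.transfer_read %A[%c0, %c0], %cst
///       : tensor<4x8xf32>, vector<4x8xf32>
///   %vb = vector.transfer_read %B[%c0, %c0], %cst
///       : tensor<8x16xf32>, vector<8x16xf32>
///   %vc = vector.transfer_read %C[%c0, %c0], %cst
///       : tensor<4x16xf32>, vector<4x16xf32>
///   %r = vector.contract {indexing_maps = [...], iterator_types = [...],
///                         kind = #vector.kind<add>} %va, %vb, %vc
///       : vector<4x8xf32>, vector<8x16xf32> into vector<4x16xf32>
///   %res = vector.transfer_write %r, %C[%c0, %c0]
///       : vector<4x16xf32>, tensor<4x16xf32>
/// ```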
2204static LogicalResult
2205vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
2206 LinalgOp linalgOp,
2207 SmallVectorImpl<Value> &newResults) {
2208 Location loc = linalgOp.getLoc();
2209 MLIRContext *ctx = linalgOp.getContext();
2210
2211 // For simplicity, contraction vectorization is limited to linalg named ops.
2212 // Generic op is ignored as not every arbitrary contraction body can be
2213 // expressed by a vector.contract.
2214 if (!isa<ContractionOpInterface>(linalgOp.getOperation()))
2215 return failure();
2216
2217 OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
2218 Operation *reduceOp = matchLinalgReduction(outOperand);
2219 auto maybeKind = getCombinerOpKind(reduceOp);
2220 if (!maybeKind) {
2221 LDBG() << "Failed to determine contraction combining kind.";
2222 return failure();
2223 }
2224
2225 // Check that all dimensions are present in the input operands.
2226 // Arbitrary broadcasts are not supported by the vector contraction.
2227 // Broadcasts are expected to be decomposed before vectorization.
2228 AffineMap lhsMap = linalgOp.getIndexingMapsArray()[0];
2229 AffineMap rhsMap = linalgOp.getIndexingMapsArray()[1];
2230 if (getUnusedDimsBitVector({lhsMap, rhsMap}).any()) {
2231 LDBG() << "Contractions with broadcasts are not supported.";
2232 return failure();
2233 }
2234
2235 // Load operands.
2236 SmallVector<Value> vecOperands;
2237 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2238 // The operand vector shape is computed by mapping the canonical vector
2239 // shape to the operand's domain. Further permutations are left as a part of
2240 // the contraction.
2241 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
2242 AffineMap readMap = AffineMap::getMultiDimIdentityMap(
2243 indexingMap.getNumResults(), rewriter.getContext());
2244 Type elemType = getElementTypeOrSelf(opOperand.get());
2245 VectorType readType =
2246 state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
2247
2248 Value read = vector::createReadOrMaskedRead(
2249 rewriter, loc, opOperand.get(), readType,
2250 /*padding=*/arith::getZeroConstant(rewriter, loc, elemType),
2251 /*useInBoundsInsteadOfMasking=*/false);
2252 vecOperands.push_back(read);
2253 }
2254
2255 // Remap iterators from linalg to vector.
2256 SmallVector<Attribute> iterAttrs;
2257 auto iterators = linalgOp.getIteratorTypesArray();
2258 for (utils::IteratorType iter : iterators) {
2259 auto vecIter = iter == utils::IteratorType::parallel
2260 ? vector::IteratorType::parallel
2261 : vector::IteratorType::reduction;
2262 iterAttrs.push_back(vector::IteratorTypeAttr::get(ctx, vecIter));
2263 }
2264
2265 // Create contraction.
2266 Operation *contractOp = vector::ContractionOp::create(
2267 rewriter, loc, /*lhs=*/vecOperands[0],
2268 /*rhs=*/vecOperands[1], /*acc=*/vecOperands[2],
2269 linalgOp.getIndexingMaps(), rewriter.getArrayAttr(iterAttrs), *maybeKind);
2270 contractOp = state.maskOperation(rewriter, contractOp, linalgOp);
2271
2272 // Store result.
2273 Operation *write = createWriteOrMaskedWrite(
2274 rewriter, loc, contractOp->getResult(0), outOperand->get());
2275
2276 // Finalize.
2277 if (!write->getResults().empty())
2278 newResults.push_back(write->getResult(0));
2279
2280 return success();
2281}
2282
2283namespace {
2284enum class ConvOperationKind { Conv, Pool };
2285} // namespace
2286
2287static bool isCastOfBlockArgument(Operation *op) {
2288 return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
2289 isa<BlockArgument>(op->getOperand(0));
2290}
2291
2292// Returns the ConvOperationKind of the op using reduceOp of the generic
2293// payload. If it is neither a convolution nor a pooling, it returns
2294// std::nullopt.
2295//
2296// If the region has 2 ops (reduction + yield) or 3 ops (extension + reduction
2297// + yield) and the rhs is not used, then it is the body of a pooling.
2298// If conv, check for single `mul` predecessor. The `mul` operands must be
2299// block arguments or extension of block arguments.
2300// Otherwise, check for one or zero `ext` predecessor. The `ext` operands
2301// must be block arguments or extension of block arguments.
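//
// Illustrative payload bodies only (SSA names are hypothetical):
//   Convolution body:  %m = arith.mulf %in, %filter : f32
//                      %a = arith.addf %out, %m : f32
//                      linalg.yield %a : f32
//   Pooling body:      %a = arith.addf %out, %in : f32
//                      linalg.yield %a : f32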
2302static std::optional<ConvOperationKind>
2303getConvOperationKind(Operation *reduceOp) {
2304 int numBlockArguments =
2305 llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
2306
2307 switch (numBlockArguments) {
2308 case 1: {
2309 // Will be convolution if feeder is a MulOp.
2310 // A strength reduced version of MulOp for i1 type is AndOp which is also
2311 // supported. Otherwise, it can be pooling. This strength reduction logic
2312 // is in `buildBinaryFn` helper in the Linalg dialect.
2313 auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
2314 llvm::IsaPred<BlockArgument>);
2315 assert(feedValIt != reduceOp->operand_end() &&
2316 "Expected a non-block argument operand");
2317 Operation *feedOp = (*feedValIt).getDefiningOp();
2318 if (isCastOfBlockArgument(feedOp)) {
2319 return ConvOperationKind::Pool;
2320 }
2321
2322 if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
2323 (isa<arith::AndIOp>(feedOp) &&
2324 feedOp->getResultTypes()[0].isInteger(1))) &&
2325 llvm::all_of(feedOp->getOperands(), [](Value v) {
2326 if (isa<BlockArgument>(v))
2327 return true;
2328 if (Operation *op = v.getDefiningOp())
2329 return isCastOfBlockArgument(op);
2330 return false;
2331 }))) {
2332 return std::nullopt;
2333 }
2334
2335 return ConvOperationKind::Conv;
2336 }
2337 case 2:
2338 // Must be pooling
2339 return ConvOperationKind::Pool;
2340 default:
2341 return std::nullopt;
2342 }
2343}
2344
2345static bool isSupportedPoolKind(vector::CombiningKind kind) {
2346 switch (kind) {
2347 case vector::CombiningKind::ADD:
2348 case vector::CombiningKind::MAXNUMF:
2349 case vector::CombiningKind::MAXIMUMF:
2350 case vector::CombiningKind::MAXSI:
2351 case vector::CombiningKind::MAXUI:
2352 case vector::CombiningKind::MINNUMF:
2353 case vector::CombiningKind::MINIMUMF:
2354 case vector::CombiningKind::MINSI:
2355 case vector::CombiningKind::MINUI:
2356 return true;
2357 default:
2358 return false;
2359 }
2360}
2361
2362static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
2363 auto getOperandType = [&](auto operand) {
2364 return dyn_cast<ShapedType>((operand->get()).getType());
2365 };
2366 ShapedType lhsShapedType = getOperandType(convOp.getDpsInputOperand(0));
2367 ShapedType rhsShapedType = getOperandType(convOp.getDpsInputOperand(1));
2368 ShapedType resShapedType = getOperandType(convOp.getDpsInitOperand(0));
2369 // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
2370 // (non-channeled convolution -> LHS and RHS both have single dimensions).
2371 // Note that this also ensures 2D and 3D convolutions are rejected.
2372 if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
2373 (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
2374 return failure();
2375
2376 Operation *reduceOp = matchLinalgReduction(convOp.getDpsInitOperand(0));
2377 if (!reduceOp)
2378 return failure();
2379
2380 auto maybeOper = getConvOperationKind(reduceOp);
2381 if (!maybeOper.has_value())
2382 return failure();
2383
2384 auto maybeKind = getCombinerOpKind(reduceOp);
2385 // Typically convolution will have a `Add` CombiningKind but for i1 type it
2386 // can get strength reduced to `OR` which is also supported. This strength
2387 // reduction logic is in `buildBinaryFn` helper in the Linalg dialect.
2388 if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
2389 *maybeKind != vector::CombiningKind::OR) &&
2390 (*maybeOper != ConvOperationKind::Pool ||
2391 !isSupportedPoolKind(*maybeKind)))) {
2392 return failure();
2393 }
2394
2395 auto rhsRank = rhsShapedType.getRank();
2396 if (*maybeOper == ConvOperationKind::Pool) {
2397 if (rhsRank != 1)
2398 return failure();
2399 } else {
2400 if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
2401 return failure();
2402 }
2403
2404 return success();
2405}
2406
2407static LogicalResult vectorizeLinalgOpPrecondition(
2408 LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
2409 bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
2410 // tensor with dimension of 0 cannot be vectorized.
2411 if (llvm::any_of(linalgOp->getOpOperands(), [&](OpOperand &operand) {
2412 return llvm::is_contained(linalgOp.getShape(&operand), 0);
2413 }))
2414 return failure();
2415 // Check API contract for input vector sizes.
2416 if (!inputVectorSizes.empty() &&
2417 failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
2418 inputVectorSizes)))
2419 return failure();
2420
2421 if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition(
2422 linalgOp, flatten1DDepthwiseConv))) {
2423 LDBG() << "Dynamically-shaped op failed vectorization pre-conditions";
2424 return failure();
2425 }
2426
2427 SmallVector<CustomVectorizationPrecondition> customPreconditions;
2428
2429 // Register CustomVectorizationPrecondition for extractOp.
2430 customPreconditions.push_back(tensorExtractVectorizationPrecondition);
2431
2432 // All types in the body should be a supported element type for VectorType.
2433 for (Operation &innerOp : linalgOp->getRegion(0).front()) {
2434 // Check if any custom hook can vectorize the inner op.
2435 if (llvm::any_of(
2436 customPreconditions,
2437 [&](const CustomVectorizationPrecondition &customPrecondition) {
2438 return succeeded(
2439 customPrecondition(&innerOp, vectorizeNDExtract));
2440 })) {
2441 continue;
2442 }
2443 if (!llvm::all_of(innerOp.getOperandTypes(),
2444 VectorType::isValidElementType)) {
2445 return failure();
2446 }
2447 if (!llvm::all_of(innerOp.getResultTypes(),
2448 VectorType::isValidElementType)) {
2449 return failure();
2450 }
2451 }
2452 if (isElementwise(linalgOp))
2453 return success();
2454
2455 // Check for both named as well as generic convolution ops.
2456 if (isaConvolutionOpInterface(linalgOp))
2457 return vectorizeConvOpPrecondition(linalgOp);
2458
2459 // TODO: the common vector shape is equal to the static loop sizes only when
2460 // all indexing maps are projected permutations. For convs and stencils the
2461 // logic will need to evolve.
2462 if (!allIndexingsAreProjectedPermutation(linalgOp)) {
2463 LDBG() << "precondition failed: not projected permutations";
2464 return failure();
2465 }
2466 if (failed(reductionPreconditions(linalgOp))) {
2467 LDBG() << "precondition failed: reduction preconditions";
2468 return failure();
2469 }
2470 return success();
2471}
2472
2473static LogicalResult
2474vectorizePackOpPrecondition(linalg::PackOp packOp,
2475 ArrayRef<int64_t> inputVectorSizes) {
2476 // TODO: Support Memref PackOp. Temporarily return failure.
2477 if (!packOp.hasPureTensorSemantics())
2478 return failure();
2479
2480 auto padValue = packOp.getPaddingValue();
2481 Attribute cstAttr;
2482 // TODO: Relax this condition
2483 if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
2484 LDBG() << "pad value is not constant: " << packOp;
2485 return failure();
2486 }
2487
2488 ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
2489 bool satisfyEmptyCond = true;
2490 if (inputVectorSizes.empty()) {
2491 if (!packOp.getDestType().hasStaticShape() ||
2492 !packOp.getSourceType().hasStaticShape())
2493 satisfyEmptyCond = false;
2494 }
2495
2496 if (!satisfyEmptyCond &&
2497 failed(vector::isValidMaskedInputVector(
2498 resultTensorShape.take_front(packOp.getSourceRank()),
2499 inputVectorSizes)))
2500 return failure();
2501
2502 if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
2503 return !getConstantIntValue(v).has_value();
2504 })) {
2505 LDBG() << "inner_tiles must be constant: " << packOp;
2506 return failure();
2507 }
2508
2509 return success();
2510}
2511
2512static LogicalResult
2513vectorizePadOpPrecondition(tensor::PadOp padOp,
2514 ArrayRef<int64_t> inputVectorSizes) {
2515 auto padValue = padOp.getConstantPaddingValue();
2516 if (!padValue) {
2517 LDBG() << "pad value is not constant: " << padOp;
2518 return failure();
2519 }
2520
2521 ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
2522 if (failed(vector::isValidMaskedInputVector(resultTensorShape,
2523 inputVectorSizes)))
2524 return failure();
2525
2526 // Padding with non-zero low pad values is not supported, unless the
2527 // corresponding result dim is 1 as this would require shifting the results to
2528 // the right for the low padded dims by the required amount of low padding.
2529 // However, we do support low padding if the dims being low padded have result
2530 // sizes of 1. The reason is when we have a low pad on a unit result dim, the
2531 // input size of that dimension will be dynamically zero (as the sum of the
2532 // low pad and input dim size has to be one) and hence we will create a zero
2533 // mask as the lowering logic just makes the mask one for the input dim size -
2534 // which is zero here. Hence we will load the pad value which is what we want
2535 // in this case. If the low pad is dynamically zero then the lowering is
2536 // correct as well as no shifts are necessary.
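// For instance (illustration only): `tensor.pad %src low[1] high[0] ... :
// tensor<?xf32> to tensor<1xf32>` is accepted because the source dim must be
// dynamically 0, so the generated read mask is all-false and only the pad
// value is materialized, which is exactly the desired result.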
2537 if (llvm::any_of(llvm::enumerate(padOp.getMixedLowPad()),
2538 [&](const auto &en) {
2539 OpFoldResult padValue = en.value();
2540 unsigned pos = en.index();
2541 std::optional<int64_t> pad = getConstantIntValue(padValue);
2542 return (!pad.has_value() || pad.value() != 0) &&
2543 resultTensorShape[pos] != 1;
2544 })) {
2545 LDBG() << "low pad must all be zero for all non unit dims: " << padOp;
2546 return failure();
2547 }
2548
2549 return success();
2550}
2551
2552/// Preconditions for scalable vectors.
2553///
2554/// For Ops implementing the LinalgOp interface, this is quite restrictive - it
2555/// models the fact that in practice we would only make selected dimensions
2556/// scalable. For other Ops (e.g. `linalg.unpack`), this will succeed
2557/// unconditionally - we are yet to identify meaningful conditions.
2558static LogicalResult
2559vectorizeScalableVectorPrecondition(Operation *op,
2560 ArrayRef<int64_t> inputVectorSizes,
2561 ArrayRef<bool> inputScalableVecDims) {
2562 assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
2563 "Number of input vector sizes and scalable dims doesn't match");
2564
2565 size_t numOfScalableDims =
2566 llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });
2567
2568 if (numOfScalableDims == 0)
2569 return success();
2570
2571 auto linalgOp = dyn_cast<LinalgOp>(op);
2572
2573 // Cond 1: Reject Ops that don't implement the LinalgOp interface, with the
2574 // exception of UnpackOp for which there is a dedicated hook.
2575 if (!linalgOp) {
2576 return success(isa<linalg::UnPackOp>(op));
2577 }
2578
2579 // Cond 2: There's been no need for more than 2 scalable dims so far
2580 if (numOfScalableDims > 2)
2581 return failure();
2582
2583 // Cond 3: Look at the configuration in `inputScalableVecDims` and verify that
2584 // it matches one of the supported cases:
2585 // 1. Exactly 1 dim is scalable and that's the _last_ non-unit parallel dim
2586 // (*).
2587 // 2. Exactly 2 dims are scalable and those are the _last two adjacent_
2588 // parallel dims.
2589 // 3. Exactly 1 reduction dim is scalable and that's the last (innermost)
2590 // dim.
2591 // The 2nd restriction above means that only Matmul-like Ops are supported
2592 // when 2 dims are scalable, e.g. :
2593 // * iterators = [parallel, parallel, reduction]
2594 // * scalable flags = [true, true, false]
2595 //
2596 // (*) Non-unit dims get folded away in practice.
2597 // TODO: Relax these conditions as good motivating examples are identified.
2598
2599 // Find the first scalable flag.
2600 bool seenNonUnitParallel = false;
2601 auto iterators = linalgOp.getIteratorTypesArray();
2602 SmallVector<bool> scalableFlags(inputScalableVecDims);
2603 int64_t idx = scalableFlags.size() - 1;
2604 while (!scalableFlags[idx]) {
2605 bool isNonUnitDim = (inputVectorSizes[idx] != 1);
2606 seenNonUnitParallel |=
2607 (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim);
2608
2609 iterators.pop_back();
2610 scalableFlags.pop_back();
2611 --idx;
2612 }
2613
2614 // Analyze the iterator corresponding to the first scalable dim.
2615 switch (iterators.back()) {
2616 case utils::IteratorType::reduction: {
2617 // Check 3. above is met.
2618 if (iterators.size() != inputVectorSizes.size()) {
2619 LDBG() << "Non-trailing reduction dim requested for scalable "
2620 "vectorization";
2621 return failure();
2622 }
2623 if (isa<linalg::MatmulOp>(op)) {
2624 LDBG()
2625 << "Scalable vectorization of the reduction dim in Matmul-like ops "
2626 "is not supported";
2627 return failure();
2628 }
2629 break;
2630 }
2631 case utils::IteratorType::parallel: {
2632 // Check 1. and 2. above are met.
2633 if (seenNonUnitParallel) {
2634 LDBG() << "Inner parallel dim not requested for scalable "
2635 "vectorization";
2636 return failure();
2637 }
2638 break;
2639 }
2640 }
2641
2642 // If present, check the 2nd scalable dim. ATM, only Matmul-like Ops are
2643 // supported, for which we expect the following config:
2644 // * iterators = [parallel, parallel, reduction]
2645 // * scalable flags = [true, true, false]
2646 if (numOfScalableDims == 2) {
2647 // Disallow below case which breaks 3. above:
2648 // * iterators = [..., parallel, reduction]
2649 // * scalable flags = [..., true, true]
2650 if (iterators.back() == utils::IteratorType::reduction) {
2651 LDBG() << "Higher dim than the trailing reduction dim requested for "
2652 "scalable "
2653 "vectorizatio";
2654 return failure();
2655 }
2656 scalableFlags.pop_back();
2657 iterators.pop_back();
2658
2659 if (!scalableFlags.back() ||
2660 (iterators.back() != utils::IteratorType::parallel))
2661 return failure();
2662 }
2663
2664 // Cond 4: Only the following ops are supported in the
2665 // presence of scalable vectors
2666 return success(
2667 isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
2668 isa<linalg::BatchMatmulOp>(op) ||
2670 isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
2671 isa<linalg::BatchMmt4DOp>(op) || hasReductionIterator(linalgOp));
2672}
2673
2674static LogicalResult vectorizeOpPrecondition(
2675 Operation *op, ArrayRef<int64_t> inputVectorSizes,
2676 ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
2677 bool flatten1DDepthwiseConv) {
2678
2679 if (!hasVectorizationImpl(op))
2680 return failure();
2681
2682 if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
2683 inputScalableVecDims)))
2684 return failure();
2685
2686 return TypeSwitch<Operation *, LogicalResult>(op)
2687 .Case([&](linalg::LinalgOp linalgOp) {
2688 return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
2689 vectorizeNDExtract,
2690 flatten1DDepthwiseConv);
2691 })
2692 .Case([&](tensor::PadOp padOp) {
2693 return vectorizePadOpPrecondition(padOp, inputVectorSizes);
2694 })
2695 .Case([&](linalg::PackOp packOp) {
2696 return vectorizePackOpPrecondition(packOp, inputVectorSizes);
2697 })
2698 .Case([&](linalg::UnPackOp unpackOp) {
2699 return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
2700 })
2701 .Case([&](tensor::InsertSliceOp sliceOp) {
2702 return vectorizeInsertSliceOpPrecondition(sliceOp, inputVectorSizes);
2703 })
2704 .Default(failure());
2705}
2706
2707/// Converts affine.apply Ops to arithmetic operations.
2708static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
2709 OpBuilder::InsertionGuard g(rewriter);
2710 auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
2711
2712 for (auto op : make_early_inc_range(toReplace)) {
2713 rewriter.setInsertionPoint(op);
2714 auto expanded = affine::expandAffineExpr(
2715 rewriter, op->getLoc(), op.getAffineMap().getResult(0),
2716 op.getOperands().take_front(op.getAffineMap().getNumDims()),
2717 op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
2718 rewriter.replaceOp(op, expanded);
2719 }
2720}
2721
2722bool mlir::linalg::hasVectorizationImpl(Operation *op) {
2723 return isa<linalg::LinalgOp, tensor::PadOp, linalg::PackOp, linalg::UnPackOp,
2724 tensor::InsertSliceOp>(op);
2725}
2726
2727FailureOr<VectorizationResult> mlir::linalg::vectorize(
2728 RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
2729 ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
2730 bool flatten1DDepthwiseConv, bool assumeDynamicDimsMatchVecSizes,
2731 bool createNamedContraction) {
2732 LDBG() << "Attempting to vectorize: " << *op;
2733 LDBG() << "Input vector sizes: " << llvm::interleaved(inputVectorSizes);
2734 LDBG() << "Input scalable vector dims: "
2735 << llvm::interleaved(inputScalableVecDims);
2736
2737 if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
2738 vectorizeNDExtract,
2739 flatten1DDepthwiseConv))) {
2740 LDBG() << "Vectorization pre-conditions failed";
2741 return failure();
2742 }
2743
2744 // Initialize vectorization state.
2745 VectorizationState state(rewriter);
2746 if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
2747 if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
2748 inputScalableVecDims,
2749 assumeDynamicDimsMatchVecSizes))) {
2750 LDBG() << "Vectorization state couldn't be initialized";
2751 return failure();
2752 }
2753 }
2754
2755 SmallVector<Value> results;
2756 auto vectorizeResult =
2757 TypeSwitch<Operation *, LogicalResult>(op)
2758 .Case([&](linalg::LinalgOp linalgOp) {
2759 // Check for both named as well as generic convolution ops.
2760 if (isaConvolutionOpInterface(linalgOp)) {
2761 FailureOr<Operation *> convOr = vectorizeConvolution(
2762 rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
2763 flatten1DDepthwiseConv);
2764 if (succeeded(convOr)) {
2765 llvm::append_range(results, (*convOr)->getResults());
2766 return success();
2767 }
2768
2769 LDBG() << "Unsupported convolution can't be vectorized.";
2770 return failure();
2771 }
2772
2773 if (createNamedContraction &&
2774 isa<ContractionOpInterface>(linalgOp.getOperation()))
2775 return vectorizeAsLinalgContraction(rewriter, state, linalgOp,
2776 results);
2777
2778 LDBG()
2779 << "Vectorize generic by broadcasting to the canonical vector "
2780 "shape";
2781
2782 // Pre-process before proceeding.
2783 convertAffineApply(rewriter, linalgOp);
2784
2785 // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
2786 // to 'OpBuilder' when it is passed over to some methods like
2787 // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
2788 // erase an op within these methods, the actual rewriter won't be
2789 // notified and we will end up with read-after-free issues!
2790 return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
2791 })
2792 .Case([&](tensor::PadOp padOp) {
2793 return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
2794 results);
2795 })
2796 .Case([&](linalg::PackOp packOp) {
2797 return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
2798 results);
2799 })
2800 .Case([&](linalg::UnPackOp unpackOp) {
2801 return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
2802 inputVectorSizes,
2803 inputScalableVecDims, results);
2804 })
2805 .Case([&](tensor::InsertSliceOp sliceOp) {
2806 return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
2807 results);
2808 })
2809 .Default(failure());
2810
2811 if (failed(vectorizeResult)) {
2812 LDBG() << "Vectorization failed";
2813 return failure();
2814 }
2815
2816 return VectorizationResult{results};
2817}
2818
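/// Note (illustrative sketch only, exact attributes omitted): a
/// statically-shaped `memref.copy` such as
/// ```
///   memref.copy %src, %dst : memref<4x8xf32> to memref<4x8xf32>
/// ```
/// is rewritten roughly as:
/// ```
///   %v = vector.transfer_read %src[%c0, %c0], %cst
///       : memref<4x8xf32>, vector<4x8xf32>
///   vector.transfer_write %v, %dst[%c0, %c0]
///       : vector<4x8xf32>, memref<4x8xf32>
/// ```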
2819LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
2820 memref::CopyOp copyOp) {
2821 auto srcType = cast<MemRefType>(copyOp.getSource().getType());
2822 auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
2823 if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
2824 return failure();
2825
2826 auto srcElementType = getElementTypeOrSelf(srcType);
2827 auto dstElementType = getElementTypeOrSelf(dstType);
2828 if (!VectorType::isValidElementType(srcElementType) ||
2829 !VectorType::isValidElementType(dstElementType))
2830 return failure();
2831
2832 auto readType = VectorType::get(srcType.getShape(), srcElementType);
2833 auto writeType = VectorType::get(dstType.getShape(), dstElementType);
2834
2835 Location loc = copyOp->getLoc();
2836 Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
2837 SmallVector<Value> indices(srcType.getRank(), zero);
2838
2839 Value readValue = vector::TransferReadOp::create(
2840 rewriter, loc, readType, copyOp.getSource(), indices,
2841 /*padding=*/std::nullopt,
2842 rewriter.getMultiDimIdentityMap(srcType.getRank()));
2843 if (cast<VectorType>(readValue.getType()).getRank() == 0) {
2844 readValue = vector::ExtractOp::create(rewriter, loc, readValue,
2845 ArrayRef<int64_t>());
2846 readValue =
2847 vector::BroadcastOp::create(rewriter, loc, writeType, readValue);
2848 }
2849 Operation *writeValue = vector::TransferWriteOp::create(
2850 rewriter, loc, readValue, copyOp.getTarget(), indices,
2851 rewriter.getMultiDimIdentityMap(srcType.getRank()));
2852 rewriter.replaceOp(copyOp, writeValue->getResults());
2853 return success();
2854}
2855
2856//----------------------------------------------------------------------------//
2857// Misc. vectorization patterns.
2858//----------------------------------------------------------------------------//
2859/// Base pattern for rewriting tensor::PadOps whose result is consumed by a
2860/// given operation type OpTy.
2861template <typename OpTy>
2862struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
2863 using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
2864
2865 LogicalResult matchAndRewrite(tensor::PadOp padOp,
2866 PatternRewriter &rewriter) const final {
2867 bool changed = false;
2868 // Collect users into a vector first, because some users may be replaced/removed.
2869 for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
2870 if (auto op = dyn_cast<OpTy>(user))
2871 changed |= rewriteUser(rewriter, padOp, op).succeeded();
2872 return success(changed);
2873 }
2874
2875protected:
2876 virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
2877 tensor::PadOp padOp, OpTy op) const = 0;
2878};
2879
2880/// Rewrite use of tensor::PadOp result in TransferReadOp. E.g.:
2881/// ```
2882/// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2883/// %r = vector.transfer_read %0[%c0, %c0], %cst
2884/// {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32>
2885/// ```
2886/// is rewritten to:
2887/// ```
2888/// %r = vector.transfer_read %src[%c0, %c0], %padding
2889/// {in_bounds = [true, true]}
2890/// : tensor<?x?xf32>, vector<17x5xf32>
2891/// ```
2892/// Note: By restricting this pattern to in-bounds TransferReadOps, we can be
2893/// sure that the original padding value %cst was never used.
2894///
2895/// This rewrite is possible if:
2896/// - `xferOp` has no out-of-bounds dims or mask.
2897/// - Low padding is static 0.
2898/// - Single, scalar padding value.
2899struct PadOpVectorizationWithTransferReadPattern
2900 : public VectorizePadOpUserPattern<vector::TransferReadOp> {
2901 using VectorizePadOpUserPattern<
2902 vector::TransferReadOp>::VectorizePadOpUserPattern;
2903
2904 LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2905 vector::TransferReadOp xferOp) const override {
2906 // Low padding must be static 0.
2907 if (!padOp.hasZeroLowPad())
2908 return failure();
2909 // Pad value must be a constant.
2910 auto padValue = padOp.getConstantPaddingValue();
2911 if (!padValue)
2912 return failure();
2913 // Padding value of existing `xferOp` is unused.
2914 if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
2915 return failure();
2916
2917 rewriter.modifyOpInPlace(xferOp, [&]() {
2918 SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2919 xferOp->setAttr(xferOp.getInBoundsAttrName(),
2920 rewriter.getBoolArrayAttr(inBounds));
2921 xferOp.getBaseMutable().assign(padOp.getSource());
2922 xferOp.getPaddingMutable().assign(padValue);
2923 });
2924
2925 return success();
2926 }
2927};
2928
2929/// Rewrite use of tensor::PadOp result in TransferWriteOp.
2930/// This pattern rewrites TransferWriteOps that write to a padded tensor
2931/// value, where the same amount of padding is immediately removed again after
2932/// the write. In such cases, the TransferWriteOp can write to the non-padded
2933/// tensor value and apply out-of-bounds masking. E.g.:
2934/// ```
2935/// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2936/// : tensor<...> to tensor<?x?xf32>
2937/// %1 = tensor.pad %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
2938/// %2 = vector.transfer_write %vec, %1[...]
2939/// : vector<17x5xf32>, tensor<17x5xf32>
2940/// %r = tensor.extract_slice %2[0, 0] [%s0, %s1] [1, 1]
2941/// : tensor<17x5xf32> to tensor<?x?xf32>
2942/// ```
2943/// is rewritten to:
2944/// ```
2945/// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2946/// : tensor<...> to tensor<?x?xf32>
2947/// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>,
2948/// tensor<?x?xf32>
2949/// ```
2950/// Note: It is important that the ExtractSliceOp %r resizes the result of the
2951/// TransferWriteOp to the same size as the input of the TensorPadOp (or an
2952/// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
2953/// from %r's old dimensions.
2954///
2955/// This rewrite is possible if:
2956/// - Low padding is static 0.
2957/// - `xferOp` has exactly one use, which is an ExtractSliceOp. This
2958/// ExtractSliceOp trims the same amount of padding that was added
2959/// beforehand.
2960/// - Single, scalar padding value.
2961struct PadOpVectorizationWithTransferWritePattern
2962 : public VectorizePadOpUserPattern<vector::TransferWriteOp> {
2963 using VectorizePadOpUserPattern<
2964 vector::TransferWriteOp>::VectorizePadOpUserPattern;
2965
2966 LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2967 vector::TransferWriteOp xferOp) const override {
2968 // TODO: support 0-d corner case.
2969 if (xferOp.getTransferRank() == 0)
2970 return failure();
2971
2972 // Low padding must be static 0.
2973 if (!padOp.hasZeroLowPad())
2974 return failure();
2975 // Pad value must be a constant.
2976 auto padValue = padOp.getConstantPaddingValue();
2977 if (!padValue)
2978 return failure();
2979 // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
2980 if (!xferOp->hasOneUse())
2981 return failure();
2982 auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
2983 if (!trimPadding)
2984 return failure();
2985 // Only static zero offsets supported when trimming padding.
2986 if (!trimPadding.hasZeroOffset())
2987 return failure();
2988 // trimPadding must remove the amount of padding that was added earlier.
2989 if (!hasSameTensorSize(padOp.getSource(), trimPadding))
2990 return failure();
2991
2992 // Insert the new TransferWriteOp at position of the old TransferWriteOp.
2993 rewriter.setInsertionPoint(xferOp);
2994
2995 SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2996 auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2997 xferOp, padOp.getSource().getType(), xferOp.getVector(),
2998 padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
2999 xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
3000 rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
3001
3002 return success();
3003 }
3004
3005 /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
3006 /// i.e., same dimensions.
3007 ///
3008 /// Dimensions may be static, dynamic or mix of both. In case of dynamic
3009 /// dimensions, this function tries to infer the (static) tensor size by
3010 /// looking at the defining op and utilizing op-specific knowledge.
3011 ///
3012 /// This is a conservative analysis. In case equal tensor sizes cannot be
3013 /// proven statically, this analysis returns `false` even though the tensor
3014 /// sizes may turn out to be equal at runtime.
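  /// Example (illustrative): with
  /// ```
  ///   %0 = tensor.extract_slice %t[0, 0] [%s0, %s1] [1, 1]
  ///       : tensor<?x?xf32> to tensor<?x?xf32>
  ///   %1 = tensor.pad %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
  ///   ...
  ///   %2 = tensor.extract_slice %w[0, 0] [%s0, %s1] [1, 1]
  ///       : tensor<17x5xf32> to tensor<?x?xf32>
  /// ```
  /// `beforePadding` (%0) and `afterTrimming` (%2) share the dynamic sizes
  /// %s0 and %s1, so equal tensor sizes can be proven statically.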
3015 bool hasSameTensorSize(Value beforePadding,
3016 tensor::ExtractSliceOp afterTrimming) const {
3017 // If the input to tensor::PadOp is a CastOp, try with both CastOp
3018 // result and CastOp operand.
3019 if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
3020 if (hasSameTensorSize(castOp.getSource(), afterTrimming))
3021 return true;
3022
3023 auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
3024 auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
3025 // Only RankedTensorType supported.
3026 if (!t1 || !t2)
3027 return false;
3028 // Rank of both values must be the same.
3029 if (t1.getRank() != t2.getRank())
3030 return false;
3031
3032 // All static dimensions must be the same. Mixed cases (e.g., dimension
3033 // static in `t1` but dynamic in `t2`) are not supported.
3034 for (unsigned i = 0; i < t1.getRank(); ++i) {
3035 if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
3036 return false;
3037 if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
3038 return false;
3039 }
3040
3041 // Nothing more to check if all dimensions are static.
3042 if (t1.getNumDynamicDims() == 0)
3043 return true;
3044
3045 // All dynamic sizes must be the same. The only supported case at the
3046 // moment is when `beforePadding` is an ExtractSliceOp (or a cast
3047 // thereof).
3048
3049 // Apart from CastOp, only ExtractSliceOp is supported.
3050 auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
3051 if (!beforeSlice)
3052 return false;
3053
3054 assert(static_cast<size_t>(t1.getRank()) ==
3055 beforeSlice.getMixedSizes().size());
3056 assert(static_cast<size_t>(t2.getRank()) ==
3057 afterTrimming.getMixedSizes().size());
3058
3059 for (unsigned i = 0; i < t1.getRank(); ++i) {
3060 // Skip static dimensions.
3061 if (!t1.isDynamicDim(i))
3062 continue;
3063 auto size1 = beforeSlice.getMixedSizes()[i];
3064 auto size2 = afterTrimming.getMixedSizes()[i];
3065
3066 // Case 1: Same value or same constant int.
3067 if (isEqualConstantIntOrValue(size1, size2))
3068 continue;
3069
3070 // Other cases: Take a deeper look at defining ops of values.
3071 auto v1 = llvm::dyn_cast_if_present<Value>(size1);
3072 auto v2 = llvm::dyn_cast_if_present<Value>(size2);
3073 if (!v1 || !v2)
3074 return false;
3075
3076 // Case 2: Both values are identical AffineMinOps. (Should not happen if
3077 // CSE is run.)
3078 auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
3079 auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
3080 if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
3081 minOp1.getOperands() == minOp2.getOperands())
3082 continue;
3083
3084 // Add additional cases as needed.
3085 }
3086
3087 // All tests passed.
3088 return true;
3089 }
3090};
3091
3092/// Returns the effective Pad value for the input op, provided it's a scalar.
3093///
3094/// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
3095/// this Op performs padding, retrieve the padding value provided that it's
3096/// a scalar and static/fixed for all the padded values. Returns an empty value
3097/// otherwise.
3098///
3099/// TODO: This is used twice (when checking vectorization pre-conditions and
3100/// when vectorizing). Cache results instead of re-running.
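/// Example (illustrative): given
/// ```
///   %pad = arith.constant 0.0 : f32
///   %fill = linalg.fill ins(%pad : f32) outs(%init : tensor<8x8xf32>)
///           -> tensor<8x8xf32>
///   %res = tensor.insert_slice %src into %fill ...
///         : tensor<1x4xf32> into tensor<8x8xf32>
/// ```
/// calling this on the tensor.insert_slice traces through its destination
/// (case 5 below) to the linalg.fill (case 2) and returns %pad.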
3101static Value getStaticPadVal(Operation *op) {
3102 if (!op)
3103 return {};
3104
3105 // 1. vector.broadcast (f32 -> vector<...xf32>) - return the value that's
3106 // being broadcast, provided that it's a scalar.
3107 if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
3108 auto source = bcast.getSource();
3109 if (llvm::dyn_cast<VectorType>(source.getType()))
3110 return {};
3111
3112 return source;
3113 }
3114
3115 // 2. linalg.fill - use the scalar input value that used to fill the output
3116 // tensor.
3117 if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
3118 return fill.getInputs()[0];
3119 }
3120
3121 // 3. tensor.generate - can't guarantee the value is fixed without
3122 // analysing it, so bail out.
3123 if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {
3124 return {};
3125 }
3126
3127 // 4. vector.transfer_write - inspect the vector that's being written. If
3128 // it contains a single value that has been broadcast (e.g. via
3129 // vector.broadcast), extract it, fail otherwise.
3130 if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
3131 return getStaticPadVal(xferWrite.getVector().getDefiningOp());
3132
3133 // 5. tensor.insert_slice - inspect the destination tensor. If it's larger
3134 // than the input tensor and was initialized with a constant value (e.g.
3135 // via linalg.fill), extract that value; fail otherwise.
3136 // TODO: Clarify the semantics when the input tensor is larger than the
3137 // destination.
3138 if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
3139 return getStaticPadVal(slice.getDest().getDefiningOp());
3140
3141 return {};
3142}
3143
3144static LogicalResult
3145vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
3146 ArrayRef<int64_t> inputVectorSizes,
3147 SmallVectorImpl<Value> &newResults) {
3148 // TODO: Introduce a parent class that will handle the insertion point update.
3149 OpBuilder::InsertionGuard g(rewriter);
3150 rewriter.setInsertionPoint(sliceOp);
3151
3152 TypedValue<RankedTensorType> source = sliceOp.getSource();
3153 auto sourceType = source.getType();
3154 auto resultType = sliceOp.getResultType();
3155
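// 1. Get the padding value; fall back to a zero constant when the op does
// not encode one.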
3156 Value padValue = getStaticPadVal(sliceOp);
3157
3158 if (!padValue) {
3159 auto elemType = sourceType.getElementType();
3160 padValue = arith::ConstantOp::create(rewriter, sliceOp.getLoc(), elemType,
3161 rewriter.getZeroAttr(elemType));
3162 }
3163
3164 // 2. Get the vector shape
3165 SmallVector<int64_t> vecShape;
3166 size_t rankDiff = resultType.getRank() - sourceType.getRank();
3167 for (int64_t i = 0, end = sourceType.getRank(); i < end; ++i) {
3168 if (!inputVectorSizes.empty()) {
3169 vecShape.push_back(inputVectorSizes[i]);
3170 } else if (!sourceType.isDynamicDim(i)) {
3171 vecShape.push_back(sourceType.getDimSize(i));
3172 } else if (!resultType.isDynamicDim(i)) {
3173 // Source shape is not statically known, but result shape is.
3174 // Vectorize with size of result shape. This may be larger than the
3175 // source size.
3176 // FIXME: Using rankDiff implies that the source tensor is inserted at
3177 // the end of the destination tensor. However, that's not required.
3178 vecShape.push_back(resultType.getDimSize(rankDiff + i));
3179 } else {
3180 // Neither source nor result dim of sliceOp is static. Cannot vectorize
3181 // the copy.
3182 return failure();
3183 }
3184 }
3185 auto vecType = VectorType::get(vecShape, sourceType.getElementType());
3186
3187 // 3. Generate TransferReadOp + TransferWriteOp
3188 auto loc = sliceOp.getLoc();
3189
3190 // Create read
3191 SmallVector<Value> readIndices(
3192 vecType.getRank(), arith::ConstantIndexOp::create(rewriter, loc, 0));
3193 Value read = vector::createReadOrMaskedRead(
3194 rewriter, loc, source, vecType, padValue,
3195 /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());
3196
3197 // Create write
3198 auto writeIndices =
3199 getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
3200 Operation *write =
3201 createWriteOrMaskedWrite(rewriter, loc, read, sliceOp.getDest(),
3202 writeIndices, inputVectorSizes.empty());
3203
3204 // 4. Finalize
3205 newResults.push_back(write->getResult(0));
3206
3207 return success();
3208}
3209
3210/// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
3211/// ```
3212/// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
3213/// %r = tensor.insert_slice %0
3214/// into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
3215/// : tensor<17x5xf32> into tensor<?x?x17x5xf32>
3216/// ```
3217/// is rewritten to:
3218/// ```
3219/// %0 = vector.transfer_read %src[%c0, %c0], %padding
3220/// : tensor<?x?xf32>, vector<17x5xf32>
3221/// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
3222/// {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
3223/// ```
3224///
3225/// This rewrite is possible if:
3226/// - Low padding is static 0.
3227/// - `padOp` result shape is static.
3228/// - The entire padded tensor is inserted.
3229/// (Implies that sizes of `insertOp` are all static.)
3230/// - Only unit strides in `insertOp`.
3231/// - Single, scalar padding value.
3232/// - `padOp` result not used as destination.
3233struct PadOpVectorizationWithInsertSlicePattern
3234 : public VectorizePadOpUserPattern<tensor::InsertSliceOp> {
3235 using VectorizePadOpUserPattern<
3236 tensor::InsertSliceOp>::VectorizePadOpUserPattern;
3237
3238 LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
3239 tensor::InsertSliceOp insertOp) const override {
3240 // Low padding must be static 0.
3241 if (!padOp.hasZeroLowPad())
3242 return failure();
3243 // Only unit stride supported.
3244 if (!insertOp.hasUnitStride())
3245 return failure();
3246 // Pad value must be a constant.
3247 auto padValue = padOp.getConstantPaddingValue();
3248 if (!padValue)
3249 return failure();
3250 // Dynamic shapes not supported.
3251 if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
3252 return failure();
3253 // Pad result not used as destination.
3254 if (insertOp.getDest() == padOp.getResult())
3255 return failure();
3256
3257 auto vecType = VectorType::get(padOp.getType().getShape(),
3258 padOp.getType().getElementType());
3259 unsigned vecRank = vecType.getRank();
3260 unsigned tensorRank = insertOp.getType().getRank();
3261
3262 // Check if sizes match: Insert the entire tensor into most minor dims.
3263 // (No permutations allowed.)
3264 SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
3265 expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
3266 if (!llvm::all_of(
3267 llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
3268 return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
3269 }))
3270 return failure();
3271
3272 // Insert the TransferReadOp and TransferWriteOp at the position of the
3273 // InsertSliceOp.
3274 rewriter.setInsertionPoint(insertOp);
3275
3276 // Generate TransferReadOp: Read entire source tensor and add high
3277 // padding.
3278 SmallVector<Value> readIndices(
3279 vecRank, arith::ConstantIndexOp::create(rewriter, padOp.getLoc(), 0));
3280 auto read = vector::TransferReadOp::create(rewriter, padOp.getLoc(),
3281 vecType, padOp.getSource(),
3282 readIndices, padValue);
3283
3284 // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
3285 // specified offsets. The write is fully in-bounds because an InsertSliceOp's
3286 // source must fit into the destination at the specified offsets.
3287 auto writeIndices = getValueOrCreateConstantIndexOp(
3288 rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
3289 SmallVector<bool> inBounds(vecRank, true);
3290 rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
3291 insertOp, read, insertOp.getDest(), writeIndices,
3292 ArrayRef<bool>{inBounds});
3293
3294 return success();
3295 }
3296};
3297
3298void mlir::linalg::populatePadOpVectorizationPatterns(
3299 RewritePatternSet &patterns, PatternBenefit baseBenefit) {
3300 patterns.add<PadOpVectorizationWithTransferReadPattern,
3301 PadOpVectorizationWithTransferWritePattern,
3302 PadOpVectorizationWithInsertSlicePattern>(
3303 patterns.getContext(), baseBenefit.getBenefit() + 1);
3304}
3305
3306//----------------------------------------------------------------------------//
3307// Forwarding patterns
3308//----------------------------------------------------------------------------//
3309
3310/// Check whether there is any interleaved use of any `values` between
3311/// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
3312/// is in a different block.
3313static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
3314 ValueRange values) {
3315 if (firstOp->getBlock() != secondOp->getBlock() ||
3316 !firstOp->isBeforeInBlock(secondOp)) {
3317 LDBG() << "interleavedUses precondition failed, firstOp: " << *firstOp
3318 << ", second op: " << *secondOp;
3319 return true;
3320 }
3321 for (auto v : values) {
3322 for (auto &u : v.getUses()) {
3323 Operation *owner = u.getOwner();
3324 if (owner == firstOp || owner == secondOp)
3325 continue;
3326 // TODO: this is too conservative, use dominance info in the future.
3327 if (owner->getBlock() == firstOp->getBlock() &&
3328 (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
3329 continue;
3330 LDBG() << " found interleaved op " << *owner << ", firstOp: " << *firstOp
3331 << ", second op: " << *secondOp;
3332 return true;
3333 }
3334 }
3335 return false;
3336}
3337
3338/// Return the unique subview use of `v` if it is indeed unique, null
3339/// otherwise.
3340static memref::SubViewOp getSubViewUseIfUnique(Value v) {
3341 memref::SubViewOp subViewOp;
3342 for (auto &u : v.getUses()) {
3343 if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
3344 if (subViewOp)
3345 return memref::SubViewOp();
3346 subViewOp = newSubViewOp;
3347 }
3348 }
3349 return subViewOp;
3350}
3351
3352/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
3353/// when available.
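/// Illustrative sketch of the read-forwarding (types abbreviated, value names
/// hypothetical): given
/// ```
///   linalg.fill ins(%pad : f32) outs(%alloc : memref<8x8xf32>)
///   %sv = memref.subview %alloc[0, 0] [%s0, %s1] [1, 1] : ...
///   memref.copy %in, %sv : ...
///   %v = vector.transfer_read %alloc[%c0, %c0], %pad
///       : memref<8x8xf32>, vector<8x8xf32>
/// ```
/// the fill and copy are erased and the read is forwarded to the copy source:
/// ```
///   %v = vector.transfer_read %in[%c0, %c0], %pad {in_bounds = [false, false]}
///       : memref<?x?xf32>, vector<8x8xf32>
/// ```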
3354LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
3355 vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
3356
3357 // TODO: support mask.
3358 if (xferOp.getMask())
3359 return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
3360
3361 // Transfer into `view`.
3362 Value viewOrAlloc = xferOp.getBase();
3363 if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
3364 !viewOrAlloc.getDefiningOp<memref::AllocOp>())
3365 return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
3366
3367 // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
3368 memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
3369 if (!subViewOp)
3370 return rewriter.notifyMatchFailure(xferOp, "no subview found");
3371 Value subView = subViewOp.getResult();
3372
3373 // Find the copy into `subView` without interleaved uses.
3374 memref::CopyOp copyOp;
3375 for (auto &u : subView.getUses()) {
3376 if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
3377 assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
3378 if (newCopyOp.getTarget() != subView)
3379 continue;
3380 if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
3381 continue;
3382 copyOp = newCopyOp;
3383 break;
3384 }
3385 }
3386 if (!copyOp)
3387 return rewriter.notifyMatchFailure(xferOp, "no copy found");
3388
3389 // Find the fill into `viewOrAlloc` without interleaved uses before the
3390 // copy.
3391 FillOp maybeFillOp;
3392 for (auto &u : viewOrAlloc.getUses()) {
3393 if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
3394 assert(isa<MemRefType>(newFillOp.output().getType()));
3395 if (newFillOp.output() != viewOrAlloc)
3396 continue;
3397 if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
3398 continue;
3399 maybeFillOp = newFillOp;
3400 break;
3401 }
3402 }
3403 // Ensure padding matches.
3404 if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
3405 return rewriter.notifyMatchFailure(xferOp,
3406 "padding value does not match fill");
3407
3408 // `in` is the source that memref.copy reads from; read from it directly.
3409 Value in = copyOp.getSource();
3410
3411 // memref.copy + linalg.fill can be used to create a padded local buffer.
3412 // The `masked` attribute is only valid on this padded buffer.
3413 // When forwarding to vector.transfer_read, the attribute must be reset
3414 // conservatively.
3415 auto vectorType = xferOp.getVectorType();
3416 Value res = vector::TransferReadOp::create(
3417 rewriter, xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
3418 xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
3419 rewriter.getBoolArrayAttr(
3420 SmallVector<bool>(vectorType.getRank(), false)));
3421
3422 if (maybeFillOp)
3423 rewriter.eraseOp(maybeFillOp);
3424 rewriter.eraseOp(copyOp);
3425 rewriter.replaceOp(xferOp, res);
3426
3427 return success();
3428}
3429
3430/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
3431/// when available.
3432LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
3433 vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
3434 // TODO: support mask.
3435 if (xferOp.getMask())
3436 return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
3437
3438 // Transfer into `viewOrAlloc`.
3439 Value viewOrAlloc = xferOp.getBase();
3440 if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
3441 !viewOrAlloc.getDefiningOp<memref::AllocOp>())
3442 return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
3443
3444 // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
3445 memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
3446 if (!subViewOp)
3447 return rewriter.notifyMatchFailure(xferOp, "no subview found");
3448 Value subView = subViewOp.getResult();
3449
3450 // Find the copy from `subView` without interleaved uses.
3451 memref::CopyOp copyOp;
3452 for (auto &u : subViewOp.getResult().getUses()) {
3453 if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
3454 if (newCopyOp.getSource() != subView)
3455 continue;
3456 if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
3457 continue;
3458 copyOp = newCopyOp;
3459 break;
3460 }
3461 }
3462 if (!copyOp)
3463 return rewriter.notifyMatchFailure(xferOp, "no copy found");
3464
3465 // `out` is the target that the subview is copied into; write to it directly.
3466 assert(isa<MemRefType>(copyOp.getTarget().getType()));
3467 Value out = copyOp.getTarget();
3468
3469 // Forward vector.transfer into copy.
3470 // memref.copy + linalg.fill can be used to create a padded local buffer.
3471 // The `masked` attribute is only valid on this padded buffer.
3472 // When forwarding to vector.transfer_write, the attribute must be reset
3473 // conservatively.
3474 auto vector = xferOp.getVector();
3475 vector::TransferWriteOp::create(
3476 rewriter, xferOp.getLoc(), vector, out, xferOp.getIndices(),
3477 xferOp.getPermutationMapAttr(), xferOp.getMask(),
3478 rewriter.getBoolArrayAttr(SmallVector<bool>(
3479 dyn_cast<VectorType>(vector.getType()).getRank(), false)));
3480
3481 rewriter.eraseOp(copyOp);
3482 rewriter.eraseOp(xferOp);
3483
3484 return success();
3485}
3486
3487//===----------------------------------------------------------------------===//
3488// Convolution vectorization patterns
3489//===----------------------------------------------------------------------===//
3490
3491template <int N>
3492static void bindShapeDims(ShapedType shapedType) {}
3493
3494template <int N, typename IntTy, typename... IntTy2>
3495static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
3496 val = shapedType.getShape()[N];
3497 bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
3498}
3499
3500/// Bind a pack of int& to the leading dimensions of shapedType.getShape().
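///
/// E.g., for a shaped type with shape {n, w, c}:
///   int64_t n, w, c;
///   bindShapeDims(shapedType, n, w, c); // n = dim 0, w = dim 1, c = dim 2.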
3501template <typename... IntTy>
3502static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
3503 bindShapeDims<0>(shapedType, vals...);
3504}
3505
3506/// Match 1D convolution or pooling operations and return their dilations and
3507/// strides. Returns std::nullopt for unrecognized ops.
3508static std::optional<DilationsAndStrides> match1DConvPoolOp(LinalgOp op) {
3509#define MATCH_1D_CONV_POOL_OP(ConvOpTy) \
3510 if (auto convParams = matchConvolutionOpOfType<ConvOpTy>(op)) \
3511 return convParams;
3512
3513 // 1D Convolution ops.
3514 MATCH_1D_CONV_POOL_OP(linalg::Conv1DOp);
3515 MATCH_1D_CONV_POOL_OP(linalg::Conv1DNwcWcfOp);
3516 MATCH_1D_CONV_POOL_OP(linalg::Conv1DNcwFcwOp);
3517 // Depthwise 1D Convolution ops.
3518 // Note: Only NWC layout without channel multiplier is supported.
3519 // DepthwiseConv1DNcwCwOp (NCW) and DepthwiseConv1DNwcWcmOp (with multiplier)
3520 // are not supported.
3521 MATCH_1D_CONV_POOL_OP(linalg::DepthwiseConv1DNwcWcOp);
3522 // 1D Pooling ops (NWC layout).
3523 MATCH_1D_CONV_POOL_OP(linalg::PoolingNwcSumOp);
3524 MATCH_1D_CONV_POOL_OP(linalg::PoolingNwcMaxOp);
3525 MATCH_1D_CONV_POOL_OP(linalg::PoolingNwcMaxUnsignedOp);
3526 MATCH_1D_CONV_POOL_OP(linalg::PoolingNwcMinOp);
3527 MATCH_1D_CONV_POOL_OP(linalg::PoolingNwcMinUnsignedOp);
3528 // 1D Pooling ops (NCW layout).
3529 MATCH_1D_CONV_POOL_OP(linalg::PoolingNcwSumOp);
3530 MATCH_1D_CONV_POOL_OP(linalg::PoolingNcwMaxOp);
3531
3532#undef MATCH_1D_CONV_POOL_OP
3533
3534 return std::nullopt;
3535}
3536
3537namespace {
3538/// Generate a vector implementation for either:
3539/// ```
3540/// Op def: ( w, kw )
3541/// Iters: ({Par(), Red()})
3542/// Layout: {{w + kw}, {kw}, {w}}
3543/// ```
3544/// kw is unrolled.
3545///
3546/// or
3547///
3548/// ```
3549/// Op def: ( n, w, c, kw, f )
3550/// Iters: ({Par(), Par(), Par(), Red(), Red()})
3551/// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3552/// ```
3553/// kw is unrolled, w is unrolled iff dilationW > 1.
3554///
3555/// or
3556///
3557/// ```
3558/// Op def: ( n, c, w, f, kw )
3559/// Iters: ({Par(), Par(), Par(), Red(), Red()})
3560/// Layout: {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
3561/// ```
3562/// kw is unrolled, w is unrolled iff dilationW > 1.
3563///
3564/// or
3565///
3566/// ```
3567/// Op def: ( n, w, c, kw )
3568/// Iters: ({Par(), Par(), Par(), Red()})
3569/// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3570/// ```
3571/// kw is unrolled, w is unrolled iff dilationW > 1.
3572struct Conv1DGenerator
3573 : public StructuredGenerator<LinalgOp, utils::IteratorType> {
3574 /// Factory method to create a Conv1DGenerator. Returns failure if the
3575 /// operation doesn't have valid strides/dilations.
3576 static FailureOr<Conv1DGenerator> create(RewriterBase &rewriter,
3577 LinalgOp linalgOp) {
3578 // Try to match a 1D conv/pool op using matchConvolutionOpOfType. This
3579 // works for both named ops and generic ops that match their semantics.
3580 std::optional<DilationsAndStrides> convParams = match1DConvPoolOp(linalgOp);
3581 if (!convParams)
3582 return failure();
3583
3584 int strideW = static_cast<int>(convParams->strides.front());
3585 int dilationW = static_cast<int>(convParams->dilations.front());
3586 return Conv1DGenerator(rewriter, linalgOp, strideW, dilationW);
3587 }
3588
3589private:
3590 Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp, int strideW,
3591 int dilationW)
3592 : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp),
3593 strideW(strideW), dilationW(dilationW) {
3594
3595 lhsShaped = linalgOp.getDpsInputOperand(0)->get();
3596 rhsShaped = linalgOp.getDpsInputOperand(1)->get();
3597 resShaped = linalgOp.getDpsInitOperand(0)->get();
3598 lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
3599 rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
3600 resShapedType = dyn_cast<ShapedType>(resShaped.getType());
3601
3602 Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
3603 redOp = reduceOp->getName().getIdentifier();
3604
3605 setConvOperationKind(reduceOp);
3606
3607 auto maybeKind = getCombinerOpKind(reduceOp);
3608 reductionKind = maybeKind.value();
3609 }
3610
3611public:
3612 /// Generate a vector implementation for:
3613 /// ```
3614 /// Op def: ( w, kw )
3615 /// Iters: ({Par(), Red()})
3616 /// Layout: {{w + kw}, {kw}, {w}}
3617 /// ```
3618 /// kw is always unrolled.
3619 ///
3620 /// or
3621 ///
3622 /// ```
3623 /// Op def: ( n, w, c, kw, f )
3624 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
3625 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3626 /// ```
3627 /// kw is always unrolled.
3628 /// TODO: w (resp. kw) is unrolled when the strideW (resp. dilationW) is
3629 /// > 1.
3630 FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
3631 int64_t nSize, wSize, cSize, kwSize, fSize;
3632 SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
3633 bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
3634 switch (conv1DOpOrder) {
3635 case Conv1DOpOrder::W:
3636 // Initialize unused dimensions
3637 nSize = fSize = cSize = 0;
3638 // out{W}
3639 bindShapeDims(resShapedType, wSize);
3640 // kernel{kw}
3641 bindShapeDims(rhsShapedType, kwSize);
3642 lhsShape = {// iw = ow + kw - 1
3643 // (i.e. 16 convolved with 3 -> 14)
3644 (wSize + kwSize - 1)};
3645 rhsShape = {kwSize};
3646 resShape = {wSize};
3647 break;
3648 case Conv1DOpOrder::Nwc:
3649 // out{n, w, f}
3650 bindShapeDims(resShapedType, nSize, wSize, fSize);
3651 switch (oper) {
3652 case ConvOperationKind::Conv:
3653 // kernel{kw, c, f}
3654 bindShapeDims(rhsShapedType, kwSize, cSize);
3655 break;
3656 case ConvOperationKind::Pool:
3657 // kernel{kw}
3658 bindShapeDims(rhsShapedType, kwSize);
3659 cSize = fSize;
3660 break;
3661 }
3662 lhsShape = {nSize,
3663 // iw = ow * sw + kw * dw - 1
3664 // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3665 // Perform the proper inclusive -> exclusive -> inclusive.
3666 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3667 1,
3668 cSize};
3669 switch (oper) {
3670 case ConvOperationKind::Conv:
3671 rhsShape = {kwSize, cSize, fSize};
3672 break;
3673 case ConvOperationKind::Pool:
3674 rhsShape = {kwSize};
3675 break;
3676 }
3677 resShape = {nSize, wSize, fSize};
3678 break;
3679 case Conv1DOpOrder::Ncw:
3680 // out{n, f, w}
3681 bindShapeDims(resShapedType, nSize, fSize, wSize);
3682 switch (oper) {
3683 case ConvOperationKind::Conv:
3684 // kernel{f, c, kw}
3685 bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
3686 break;
3687 case ConvOperationKind::Pool:
3688 // kernel{kw}
3689 bindShapeDims(rhsShapedType, kwSize);
3690 cSize = fSize;
3691 break;
3692 }
3693 lhsShape = {nSize, cSize,
3694 // iw = ow * sw + kw * dw - 1
3695 // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3696 // Perform the proper inclusive -> exclusive -> inclusive.
3697 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3698 1};
3699 switch (oper) {
3700 case ConvOperationKind::Conv:
3701 rhsShape = {fSize, cSize, kwSize};
3702 break;
3703 case ConvOperationKind::Pool:
3704 rhsShape = {kwSize};
3705 break;
3706 }
3707 resShape = {nSize, fSize, wSize};
3708 break;
3709 }
3710
3711 vector::TransferWriteOp write;
3712 Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
3713
3714 // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3715 // When strideW == 1, we can batch the contiguous loads and avoid
3716 // unrolling.
3717 int64_t wSizeStep = strideW == 1 ? wSize : 1;
3718
3719 Type lhsEltType = lhsShapedType.getElementType();
3720 Type rhsEltType = rhsShapedType.getElementType();
3721 Type resEltType = resShapedType.getElementType();
3722 auto lhsType = VectorType::get(lhsShape, lhsEltType);
3723 auto rhsType = VectorType::get(rhsShape, rhsEltType);
3724 auto resType = VectorType::get(resShape, resEltType);
3725 // Zero padding with the corresponding dimensions for lhs, rhs and res.
3726 SmallVector<Value> lhsPadding(lhsShape.size(), zero);
3727 SmallVector<Value> rhsPadding(rhsShape.size(), zero);
3728 SmallVector<Value> resPadding(resShape.size(), zero);
3729
3730 // Read the whole lhs, rhs and res in one shot (with zero padding).
3731 Value lhs = vector::TransferReadOp::create(
3732 rewriter, loc, lhsType, lhsShaped, lhsPadding,
3733 /*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
3734 // This is needed only for Conv.
3735 Value rhs = nullptr;
3736 if (oper == ConvOperationKind::Conv)
3737 rhs = vector::TransferReadOp::create(
3738 rewriter, loc, rhsType, rhsShaped, rhsPadding,
3739 /*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
3740 Value res = vector::TransferReadOp::create(
3741 rewriter, loc, resType, resShaped, resPadding,
3742 /*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
3743
3744 // The base vectorization case for channeled convolution is input:
3745 // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
3746 // vectorization case, we pre-transpose the input, weight, and output.
3747 switch (conv1DOpOrder) {
3748 case Conv1DOpOrder::W:
3749 case Conv1DOpOrder::Nwc:
3750 // Base case, so no transposes necessary.
3751 break;
3752 case Conv1DOpOrder::Ncw: {
3753 // To match base vectorization case, we pre-transpose current case.
3754 // ncw -> nwc
3755 static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
3756 lhs = vector::TransposeOp::create(rewriter, loc, lhs, permLhs);
3757 // fcw -> wcf
3758 static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
3759
3760 // This is needed only for Conv.
3761 if (oper == ConvOperationKind::Conv)
3762 rhs = vector::TransposeOp::create(rewriter, loc, rhs, permRhs);
3763 // nfw -> nwf
3764 static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
3765 res = vector::TransposeOp::create(rewriter, loc, res, permRes);
3766 break;
3767 }
3768 }
3769
3770 //===------------------------------------------------------------------===//
3771 // Begin vector-only rewrite part
3772 //===------------------------------------------------------------------===//
3773 // Unroll along kw and read slices of lhs and rhs.
3774 SmallVector<Value> lhsVals, rhsVals, resVals;
3775 lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
3776 kwSize, strideW, dilationW, wSizeStep,
3777 isSingleChanneled);
3778 // Do not do for pooling.
3779 if (oper == ConvOperationKind::Conv)
3780 rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
3781 resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
3782 wSizeStep, isSingleChanneled);
3783
3784 auto linearIndex = [&](int64_t kw, int64_t w) {
3785 return kw * (wSize / wSizeStep) + w;
3786 };
3787
3788 // Compute the contraction O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f},
3789 // an outer product for non-channeled convolution, or a simple arith
3790 // operation for pooling.
3791 for (int64_t kw = 0; kw < kwSize; ++kw) {
3792 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3793 switch (oper) {
3794 case ConvOperationKind::Conv:
3795 if (isSingleChanneled) {
3796 resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
3797 lhsVals[linearIndex(kw, w)],
3798 rhsVals[kw], resVals[w]);
3799 } else {
3800 resVals[w] = conv1dSliceAsContraction(rewriter, loc,
3801 lhsVals[linearIndex(kw, w)],
3802 rhsVals[kw], resVals[w]);
3803 }
3804 break;
3805 case ConvOperationKind::Pool:
3806 resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
3807 resVals[w]);
3808 break;
3809 }
3810 }
3811 }
3812
3813 res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
3814 isSingleChanneled);
3815 //===------------------------------------------------------------------===//
3816 // End vector-only rewrite part
3817 //===------------------------------------------------------------------===//
3818
3819 // The base vectorization case for channeled convolution is output:
3820 // {n,w,f}. To reuse the result from the base pattern vectorization case,
3821 // we post-transpose the base case result.
3822 switch (conv1DOpOrder) {
3823 case Conv1DOpOrder::W:
3824 case Conv1DOpOrder::Nwc:
3825 // Base case, so no transposes necessary.
3826 break;
3827 case Conv1DOpOrder::Ncw: {
3828 // nwf -> nfw
3829 static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
3830 res = vector::TransposeOp::create(rewriter, loc, res, perm);
3831 break;
3832 }
3833 }
3834
3835 return vector::TransferWriteOp::create(rewriter, loc, res, resShaped,
3836 resPadding)
3837 .getOperation();
3838 }
3839
3840 // Take a value and widen to have the same element type as `ty`.
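  // E.g. (illustrative): an i8 value promoted against an i32-typed `ty` is
  // sign-extended via arith.extsi; an i8 value promoted against an f32-typed
  // `ty` goes through arith.sitofp.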
3841 Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
3842 const Type srcElementType = getElementTypeOrSelf(val.getType());
3843 const Type dstElementType = getElementTypeOrSelf(ty);
3844 assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
3845 if (srcElementType == dstElementType)
3846 return val;
3847
3848 const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
3849 const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
3850 const Type dstType =
3851 cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);
3852
3853 if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
3854 return arith::SIToFPOp::create(rewriter, loc, dstType, val);
3855 }
3856
3857 if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
3858 srcWidth < dstWidth)
3859 return arith::ExtFOp::create(rewriter, loc, dstType, val);
3860
3861 if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
3862 srcWidth < dstWidth)
3863 return arith::ExtSIOp::create(rewriter, loc, dstType, val);
3864
3865 assert(false && "unhandled promotion case");
3866 return nullptr;
3867 }
3868
3869 // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
3870 Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
3871 Value lhs, Value rhs, Value res) {
3872 vector::IteratorType par = vector::IteratorType::parallel;
3873 vector::IteratorType red = vector::IteratorType::reduction;
3874 AffineExpr n, w, f, c;
3875 bindDims(ctx, n, w, f, c);
3876 lhs = promote(rewriter, loc, lhs, res.getType());
3877 rhs = promote(rewriter, loc, rhs, res.getType());
3878 auto contractionOp = vector::ContractionOp::create(
3879 rewriter, loc, lhs, rhs, res,
3880 /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
3881 /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
3882 contractionOp.setKind(reductionKind);
3883 return contractionOp;
3884 }
3885
3886 // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel
3887 // convolution.
3888 Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
3889 Value lhs, Value rhs, Value res) {
3890 return vector::OuterProductOp::create(rewriter, loc, res.getType(), lhs,
3891 rhs, res, vector::CombiningKind::ADD);
3892 }
3893
3894 // Create a reduction: lhs{n, w, c} -> res{n, w, c}
3895 Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
3896 Value res) {
3897 if (isPoolExt)
3898 lhs = rewriter.create(loc, poolExtOp, lhs, res.getType())->getResult(0);
3899 return rewriter
3900 .create(loc, redOp, ArrayRef<Value>{lhs, res}, res.getType())
3901 ->getResult(0);
3902 }
3903
3904 /// Generate a vector implementation for:
3905 /// ```
3906 /// Op def: ( n, w, c, kw)
3907 /// Iters: ({Par(), Par(), Par(), Red()})
3908 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3909 /// ```
3910 /// kw is always unrolled.
3911 /// TODO: w (resp. kw) is unrolled when the strideW (resp. dilationW) is
3912 /// > 1.
3913 FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
3914 bool channelDimScalableFlag,
3915 bool flatten) {
3916 bool scalableChDim = false;
3917 bool useMasking = false;
3918 int64_t nSize, wSize, cSize, kwSize;
3919 // kernel{kw, c}
3920 bindShapeDims(rhsShapedType, kwSize, cSize);
3921 if (ShapedType::isDynamic(cSize)) {
3922 assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
3923 cSize = channelDimVecSize;
3924 // Scalable vectors are only used when both conditions are met:
3925 // 1. channel dim is dynamic
3926 // 2. channelDimScalableFlag is set
3927 scalableChDim = channelDimScalableFlag;
3928 useMasking = true;
3929 }
3930
3931 assert(!(useMasking && flatten) &&
3932 "Unsupported flattened conv with dynamic shapes");
3933
3934 // out{n, w, c}
3935 bindShapeDims(resShapedType, nSize, wSize);
3936
3937 vector::TransferWriteOp write;
3938 Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
3939
3940 // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3941 // When strideW == 1, we can batch the contiguous loads and avoid
3942 // unrolling.
3943 int64_t wSizeStep = strideW == 1 ? wSize : 1;
3944
3945 Type lhsEltType = lhsShapedType.getElementType();
3946 Type rhsEltType = rhsShapedType.getElementType();
3947 Type resEltType = resShapedType.getElementType();
3948 VectorType lhsType = VectorType::get(
3949 {nSize,
3950 // iw = ow * sw + kw * dw - 1
3951 // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3952 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
3953 cSize},
3954 lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
3955 VectorType rhsType =
3956 VectorType::get({kwSize, cSize}, rhsEltType,
3957 /*scalableDims=*/{false, scalableChDim});
3958 VectorType resType =
3959 VectorType::get({nSize, wSize, cSize}, resEltType,
3960 /*scalableDims=*/{false, false, scalableChDim});
3961
3962 // Masks the input xfer Op along the channel dim when masking is required,
3963 // i.e. when the channel dim is dynamic.
3964 auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
3965 ArrayRef<bool> scalableDims,
3966 Operation *opToMask) {
3967 if (!useMasking)
3968 return opToMask;
3969 auto maskType =
3970 VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
3971
3972 SmallVector<bool> inBounds(maskShape.size(), true);
3973 auto xferOp = cast<VectorTransferOpInterface>(opToMask);
3974 xferOp->setAttr(xferOp.getInBoundsAttrName(),
3975 rewriter.getBoolArrayAttr(inBounds));
3976
3977 SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
3978 cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);
3979
3980 Value maskOp =
3981 vector::CreateMaskOp::create(rewriter, loc, maskType, mixedDims);
3982
3983 return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
3984 };
3985
3986 // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
3987 // 0].
3988 Value lhs = vector::TransferReadOp::create(
3989 rewriter, loc, lhsType, lhsShaped, ValueRange{zero, zero, zero},
3990 /*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
3991 auto *maybeMaskedLhs = maybeMaskXferOp(
3992 lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
3993
3994 // Read rhs slice of size {kw, c} @ [0, 0].
3995 Value rhs = vector::TransferReadOp::create(
3996 rewriter, loc, rhsType, rhsShaped, ValueRange{zero, zero},
3997 /*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
3998 auto *maybeMaskedRhs = maybeMaskXferOp(
3999 rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
4000
4001 // Read res slice of size {n, w, c} @ [0, 0, 0].
4002 Value res = vector::TransferReadOp::create(
4003 rewriter, loc, resType, resShaped, ValueRange{zero, zero, zero},
4004 /*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
4005 auto *maybeMaskedRes = maybeMaskXferOp(
4006 resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
4007
4008 //===------------------------------------------------------------------===//
4009 // Begin vector-only rewrite part
4010 //===------------------------------------------------------------------===//
4011 // Unroll along kw and read slices of lhs and rhs.
4012 SmallVector<Value> lhsVals, rhsVals, resVals;
4013 SmallVector<int64_t> inOutSliceSizes = {nSize, wSizeStep, cSize};
4014 SmallVector<int64_t> inOutStrides = {1, 1, 1};
4015
4016 // Extract lhs slice of size {n, wSizeStep, c}
4017 // @ [0, sw * w + dw * kw, 0].
4018 for (int64_t kw = 0; kw < kwSize; ++kw) {
4019 for (int64_t w = 0; w < wSize; w += wSizeStep) {
4020 lhsVals.push_back(vector::ExtractStridedSliceOp::create(
4021 rewriter, loc, maybeMaskedLhs->getResult(0),
4022 /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
4023 inOutSliceSizes, inOutStrides));
4024 }
4025 }
4026 // Extract rhs slice of size {c} @ [kw].
4027 for (int64_t kw = 0; kw < kwSize; ++kw) {
4028 rhsVals.push_back(
4029 vector::ExtractOp::create(rewriter, loc, maybeMaskedRhs->getResult(0),
4030 /*offsets=*/ArrayRef<int64_t>{kw}));
4031 }
4032 // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
4033 for (int64_t w = 0; w < wSize; w += wSizeStep) {
4034 resVals.push_back(vector::ExtractStridedSliceOp::create(
4035 rewriter, loc, maybeMaskedRes->getResult(0),
4036 /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
4037 inOutStrides));
4038 }
4039
4040 auto linearIndex = [&](int64_t kw, int64_t w) {
4041 return kw * (wSize / wSizeStep) + w;
4042 };
4043
4044 // Note - the scalable flags are ignored as flattening combined with
4045 // scalable vectorization is not supported.
4046 SmallVector<int64_t> inOutFlattenSliceSizes = {nSize, wSizeStep * cSize};
4047 auto lhsTypeAfterFlattening =
4048 VectorType::get(inOutFlattenSliceSizes, lhsEltType);
4049 auto resTypeAfterFlattening =
4050 VectorType::get(inOutFlattenSliceSizes, resEltType);
4051
4052 // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
4053 for (int64_t kw = 0; kw < kwSize; ++kw) {
4054 for (int64_t w = 0; w < wSize; w += wSizeStep) {
4055 Value lhsVal = lhsVals[linearIndex(kw, w)];
4056 Value resVal = resVals[w];
4057 if (flatten) {
4058 // Flatten the input and output vectors (collapse the channel
4059 // dimension)
4060 lhsVal =
4061 vector::ShapeCastOp::create(rewriter, loc, lhsTypeAfterFlattening,
4062 lhsVals[linearIndex(kw, w)]);
4063 resVal = vector::ShapeCastOp::create(
4064 rewriter, loc, resTypeAfterFlattening, resVals[w]);
4065 }
4066 resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
4067 rhsVals[kw], resVal, flatten);
4068 if (flatten) {
4069 // Un-flatten the output vector (restore the channel dimension)
4070 resVals[w] = vector::ShapeCastOp::create(
4071 rewriter, loc, VectorType::get(inOutSliceSizes, resEltType),
4072 resVals[w]);
4073 }
4074 }
4075 }
4076
4077 // It's possible we failed to create the FMA.
4078 if (!llvm::all_of(resVals, [](Value v) { return v; })) {
4079 // Manually revert (in reverse order) to avoid leaving a bad IR state.
4080 for (auto &collection :
4081 {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
4082 for (Value v : collection)
4083 rewriter.eraseOp(v.getDefiningOp());
4084 return rewriter.notifyMatchFailure(op, "failed to create FMA");
4085 }
4086
4087 // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
4088 // This does not depend on kw.
4089 for (int64_t w = 0; w < wSize; w += wSizeStep) {
4090 maybeMaskedRes = vector::InsertStridedSliceOp::create(
4091 rewriter, loc, resVals[w], maybeMaskedRes->getResult(0),
4092 /*offsets=*/ArrayRef<int64_t>{0, w, 0},
4093 /*strides=*/ArrayRef<int64_t>{1, 1, 1});
4094 }
4095 //===------------------------------------------------------------------===//
4096 // End vector-only rewrite part
4097 //===------------------------------------------------------------------===//
4098
4099 // Write back res slice of size {n, w, c} @ [0, 0, 0].
4100 Operation *resOut = vector::TransferWriteOp::create(
4101 rewriter, loc, maybeMaskedRes->getResult(0), resShaped,
4102 ValueRange{zero, zero, zero});
4103 return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
4104 resOut);
4105 }
4106
4107 /// Lower:
4108 /// * lhs{n, w, c} * rhs{c} -> res{n, w, c} (flatten = false)
4109 /// * lhs{n, w * c} * rhs{c} -> res{n, w * c} (flatten = true)
4110 /// to MulAcc.
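  /// E.g. (illustrative, non-flattened f32 case): lhs vector<1x3x4xf32> and
  /// rhs vector<4xf32> lower to
  /// ```
  ///   %b = vector.broadcast %rhs : vector<4xf32> to vector<1x3x4xf32>
  ///   %r = vector.fma %lhs, %b, %res : vector<1x3x4xf32>
  /// ```
  /// Integer element types use arith.muli + arith.addi instead of vector.fma.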
4111 Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
4112 Value lhs, Value rhs, Value res,
4113 bool flatten) {
4114 auto rhsTy = cast<ShapedType>(rhs.getType());
4115 auto resTy = cast<ShapedType>(res.getType());
4116
4117 // TODO(suderman): Change this to use a vector.ima intrinsic.
4118 lhs = promote(rewriter, loc, lhs, resTy);
4119
4120 if (flatten) {
4121 // NOTE: The following logic won't work for scalable vectors. For this
4122 // reason, "flattening" is not supported when shapes are dynamic (this
4123 // should be captured by one of the pre-conditions).
4124
4125 // There are two options for handling the filter:
4126 // * shape_cast(broadcast(filter))
4127 // * broadcast(shuffle(filter))
4128 // Opt for the option without shape_cast to simplify the codegen.
4129 auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
4130 auto resSize = cast<VectorType>(res.getType()).getShape()[1];
4131
4132 SmallVector<int64_t, 16> indices;
4133 for (int i = 0; i < resSize / rhsSize; ++i) {
4134 for (int j = 0; j < rhsSize; ++j)
4135 indices.push_back(j);
4136 }
4137
4138 rhs = vector::ShuffleOp::create(rewriter, loc, rhs, rhs, indices);
4139 }
4140 // Broadcast the filter to match the output vector
4141 rhs = vector::BroadcastOp::create(rewriter, loc,
4142 resTy.clone(rhsTy.getElementType()), rhs);
4143
4144 rhs = promote(rewriter, loc, rhs, resTy);
4145
4146 if (!lhs || !rhs)
4147 return nullptr;
4148
4149 if (isa<FloatType>(resTy.getElementType()))
4150 return vector::FMAOp::create(rewriter, loc, lhs, rhs, res);
4151
4152 auto mul = arith::MulIOp::create(rewriter, loc, lhs, rhs);
4153 return arith::AddIOp::create(rewriter, loc, mul, res);
4154 }
4155
4156 /// Entry point for non-channeled convolution:
4157 /// {{w + kw}, {kw}, {w}}
4158 FailureOr<Operation *> generateNonChanneledConv() {
4159 AffineExpr w, kw;
4160 bindDims(ctx, w, kw);
4161 if (!iters({Par(), Red()}))
4162 return rewriter.notifyMatchFailure(op,
4163 "failed to match conv::W 1-par 1-red");
4164
4165 // No transposition needed.
4166 if (layout({/*lhsIndex*/ {w + kw},
4167 /*rhsIndex*/ {kw},
4168 /*resIndex*/ {w}}))
4169 return conv(Conv1DOpOrder::W);
4170
4171 return rewriter.notifyMatchFailure(op, "not a conv::W layout");
4172 }
4173
4174 /// Entry point that transposes into the common form:
4175 /// {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
4176 FailureOr<Operation *> generateNwcConv() {
4177 AffineExpr n, w, f, kw, c;
4178 bindDims(ctx, n, w, f, kw, c);
4179 if (!iters({Par(), Par(), Par(), Red(), Red()}))
4180 return rewriter.notifyMatchFailure(
4181 op, "failed to match conv::Nwc 3-par 2-red");
4182
4183 // No transposition needed.
4184 if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
4185 /*rhsIndex*/ {kw, c, f},
4186 /*resIndex*/ {n, w, f}}))
4187 return conv(Conv1DOpOrder::Nwc);
4188
4189 return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
4190 }
4191
4192 /// Entry point that transposes into the common form:
4193 /// {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
4194 FailureOr<Operation *> generateNcwConv() {
4195 AffineExpr n, w, f, kw, c;
4196 bindDims(ctx, n, f, w, c, kw);
4197 if (!iters({Par(), Par(), Par(), Red(), Red()}))
4198 return rewriter.notifyMatchFailure(
4199 op, "failed to match conv::Ncw 3-par 2-red");
4200
4201 if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
4202 /*rhsIndex*/ {f, c, kw},
4203 /*resIndex*/ {n, f, w}}))
4204 return conv(Conv1DOpOrder::Ncw);
4205
4206 return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
4207 }
4208
4209 /// Entry point that transposes into the common form:
4210 /// {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}} for pooling
4211 FailureOr<Operation *> generateNwcPooling() {
4212 AffineExpr n, w, c, kw;
4213 bindDims(ctx, n, w, c, kw);
4214 if (!iters({Par(), Par(), Par(), Red()}))
4215 return rewriter.notifyMatchFailure(op,
4216 "failed to match pooling 3-par 1-red");
4217
4218 // No transposition needed.
4219 if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
4220 /*rhsIndex*/ {kw},
4221 /*resIndex*/ {n, w, c}}))
4222 return conv(Conv1DOpOrder::Nwc);
4223
4224 return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
4225 }
4226
4227 /// Entry point that transposes into the common form:
4228 /// {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}} for pooling
4229 FailureOr<Operation *> generateNcwPooling() {
4230 AffineExpr n, w, c, kw;
4231 bindDims(ctx, n, c, w, kw);
4232 if (!iters({Par(), Par(), Par(), Red()}))
4233 return rewriter.notifyMatchFailure(op,
4234 "failed to match pooling 3-par 1-red");
4235
4236 if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
4237 /*rhsIndex*/ {kw},
4238 /*resIndex*/ {n, c, w}}))
4239 return conv(Conv1DOpOrder::Ncw);
4240
4241 return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
4242 }
4243
4244 /// Entry point that transposes into the common form:
4245 /// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
4246 FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
4247 bool vecChDimScalableFlag = false,
4248 bool flatten = false) {
4249 AffineExpr n, w, c, kw;
4250 bindDims(ctx, n, w, c, kw);
4251 if (!iters({Par(), Par(), Par(), Red()}))
4252 return rewriter.notifyMatchFailure(
4253 op, "failed to match depthwise::Nwc conv 3-par 1-red");
4254
4255 // No transposition needed.
4256 if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
4257 /*rhsIndex*/ {kw, c},
4258 /*resIndex*/ {n, w, c}}))
4259 return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);
4260
4261 return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
4262 }
4263
4264private:
4265 ConvOperationKind oper = ConvOperationKind::Conv;
4266 StringAttr redOp;
4267 StringAttr poolExtOp;
4268 bool isPoolExt = false;
4269 int strideW, dilationW;
4270 Value lhsShaped, rhsShaped, resShaped;
4271 ShapedType lhsShapedType, rhsShapedType, resShapedType;
4272 vector::CombiningKind reductionKind;
4273
4274 // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
4275 void setConvOperationKind(Operation *reduceOp) {
4276 int numBlockArguments =
4277 llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
4278 if (numBlockArguments == 1) {
4279 // Will be convolution if feeder is a MulOp.
4280 // A strength reduced version of MulOp for i1 type is AndOp which is also
4281 // supported. Otherwise, it can be pooling. This strength reduction logic
4282 // is in `buildBinaryFn` helper in the Linalg dialect.
4283 auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
4284 llvm::IsaPred<BlockArgument>);
4285 Operation *feedOp = (*feedValIt).getDefiningOp();
4286 if (isCastOfBlockArgument(feedOp)) {
4287 oper = ConvOperationKind::Pool;
4288 isPoolExt = true;
4289 poolExtOp = feedOp->getName().getIdentifier();
4290 return;
4291 }
4292 oper = ConvOperationKind::Conv;
4293 return;
4294 }
4295 // numBlockArguments == 2 and this is a pooling op.
4296 oper = ConvOperationKind::Pool;
4297 isPoolExt = false;
4298 }
4299};
4300} // namespace
4301
4302/// Helper function to vectorize a LinalgOp with convolution semantics.
4303// TODO: extend the generic vectorization to support windows and drop this.
4304static FailureOr<Operation *> vectorizeConvolution(
4305 RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
4306 ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
4307 FailureOr<Conv1DGenerator> conv1dGen = Conv1DGenerator::create(rewriter, op);
4308 if (failed(conv1dGen))
4309 return failure();
4310 auto res = conv1dGen->generateNonChanneledConv();
4311 if (succeeded(res))
4312 return res;
4313 res = conv1dGen->generateNwcConv();
4314 if (succeeded(res))
4315 return res;
4316 res = conv1dGen->generateNcwConv();
4317 if (succeeded(res))
4318 return res;
4319 res = conv1dGen->generateNwcPooling();
4320 if (succeeded(res))
4321 return res;
4322 res = conv1dGen->generateNcwPooling();
4323 if (succeeded(res))
4324 return res;
4325
4326 // Only depthwise 1D convs are left - these can be vectorized using masks
4327 // and scalable vectors. Note that, at the moment, the only dim that can be
4328 // dynamic (i.e. masked/scalable) is the channel dim.
4329 uint64_t vecChDimSize = ShapedType::kDynamic;
4330 bool vecChDimScalableFlag = false;
4331 if (!inputVecSizes.empty()) {
4332 // Only use the input vector size corresponding to the channel dim. Other
4333 // vector dims will be inferred from the Ops.
4336 "Not a 1D depthwise conv!");
4337 size_t chDimIdx = 0;
4339 chDimIdx = 2;
4341 chDimIdx = 1;
4342
4343 vecChDimSize = inputVecSizes[chDimIdx];
4344 vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
4345 }
4346 return conv1dGen->generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
4347 flatten1DDepthwiseConv);
4348}
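As a hedged usage illustration (the wrapper and the concrete vector sizes below are assumptions, not taken from this file), the depthwise branch above is normally reached through the public linalg::vectorize entry point; only the channel entry of the user-provided vector sizes is consulted here:

// Sketch only: request masked vectorization with a scalable channel dimension
// for a linalg.depthwise_conv_1d_nwc_wc op (channel is the trailing dim, so
// index 2 of the vector sizes is the one picked up above).
static mlir::LogicalResult
vectorizeDepthwiseNwcWithScalableChannel(mlir::RewriterBase &rewriter,
                                         mlir::linalg::LinalgOp dwConv) {
  llvm::SmallVector<int64_t> vecSizes = {1, 8, 4};             // {n, w, c}
  llvm::SmallVector<bool> scalableDims = {false, false, true}; // c is scalable
  auto result = mlir::linalg::vectorize(rewriter, dwConv.getOperation(),
                                        vecSizes, scalableDims,
                                        /*vectorizeNDExtract=*/false,
                                        /*flatten1DDepthwiseConv=*/false);
  // Replacing the original op's results with the vectorized values is omitted
  // in this sketch.
  return mlir::failure(mlir::failed(result));
}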
4349
4350struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
4351 using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
4352
4353 LogicalResult matchAndRewrite(LinalgOp op,
4354 PatternRewriter &rewriter) const override {
4355 FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
4356 if (failed(resultOrFail))
4357 return failure();
4358 Operation *newOp = *resultOrFail;
4359 if (newOp->getNumResults() == 0) {
4360 rewriter.eraseOp(op.getOperation());
4361 return success();
4362 }
4363 assert(newOp->getNumResults() == 1 && "expected single result");
4364 rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
4365 return success();
4366 }
4367};
4368
4369void mlir::linalg::populateConvolutionVectorizationPatterns(
4370 RewritePatternSet &patterns, PatternBenefit benefit) {
4371 patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
4372}
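For completeness, a hedged sketch of how a client typically consumes the hook above (the driver set-up and benefit value are assumptions; header paths follow the usual MLIR layout):

// Sketch only: register the convolution vectorization pattern and apply it
// greedily. `root` is assumed to be an isolated-from-above op such as a
// func.func.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static mlir::LogicalResult runConvVectorization(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  mlir::linalg::populateConvolutionVectorizationPatterns(patterns,
                                                         /*benefit=*/2);
  return mlir::applyPatternsGreedily(root, std::move(patterns));
}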