MLIR 23.0.0git
Vectorization.cpp
Go to the documentation of this file.
1//===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the linalg dialect Vectorization transformations.
10//
11//===----------------------------------------------------------------------===//
13
28#include "mlir/IR/AffineExpr.h"
29#include "mlir/IR/AffineMap.h"
30#include "mlir/IR/Builders.h"
35#include "mlir/IR/Value.h"
36#include "mlir/Support/LLVM.h"
38#include "llvm/ADT/STLExtras.h"
39#include "llvm/ADT/Sequence.h"
40#include "llvm/ADT/SmallVector.h"
41#include "llvm/ADT/SmallVectorExtras.h"
42#include "llvm/ADT/TypeSwitch.h"
43#include "llvm/Support/DebugLog.h"
44#include "llvm/Support/InterleavedRange.h"
45#include "llvm/Support/MathExtras.h"
46#include "llvm/Support/raw_ostream.h"
47#include <optional>
48
49using namespace mlir;
50using namespace mlir::linalg;
51
52#define DEBUG_TYPE "linalg-vectorization"
53
54/// Try to vectorize `convOp` as a convolution.
55static FailureOr<Operation *>
56vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
57 ArrayRef<int64_t> inputVecSizes = {},
58 ArrayRef<bool> inputVecScalableFlags = {},
59 bool flatten1DDepthwiseConv = false);
60
61/// Vectorize tensor::InsertSliceOp with:
62/// * vector::TransferReadOp + vector::TransferWriteOp
63/// The vector sizes are either:
64/// * user-provided in `inputVectorSizes`, or
65/// * inferred from the static dims in the input and output tensors.
66/// Bails out if:
67/// * vector sizes are not user-provided, and
68/// * at least one dim is dynamic (in both the input and output tensors).
69///
70/// Before:
71/// !t_in_type = tensor<1x2x3xf32>
72/// !t_out_type = tensor<9x8x7x1x2x3xf32>
73/// !v_type = vector<1x2x3xf32>
74/// %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
75/// into !t_out_type
76/// After:
77/// %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
78/// %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
79static LogicalResult
80vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
81 ArrayRef<int64_t> inputVectorSizes,
82 SmallVectorImpl<Value> &newResults);
83
84/// Returns the effective Pad value for the input op, provided it's a scalar.
85///
86/// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
87/// this Op performs padding, retrieve the padding value provided that it's
88/// a scalar and static/fixed for all the padded values. Returns an empty value
89/// otherwise.
91
92/// Return the unique instance of OpType in `block` if it is indeed unique.
93/// Return null if none or more than 1 instances exist.
94template <typename OpType>
95static OpType getSingleOpOfType(Block &block) {
96 OpType res;
97 block.walk([&](OpType op) {
98 if (res) {
99 res = nullptr;
100 return WalkResult::interrupt();
101 }
102 res = op;
103 return WalkResult::advance();
104 });
105 return res;
106}
107
108/// Helper function to extract the input slices after filter is unrolled along
109/// kw.
112 int64_t nSize, int64_t wSize, int64_t cSize,
113 int64_t kwSize, int strideW, int dilationW,
114 int64_t wSizeStep, bool isSingleChanneled) {
116 if (isSingleChanneled) {
117 // Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled
118 // convolution.
119 SmallVector<int64_t> sizes = {wSizeStep};
120 SmallVector<int64_t> strides = {1};
121 for (int64_t kw = 0; kw < kwSize; ++kw) {
122 for (int64_t w = 0; w < wSize; w += wSizeStep) {
123 result.push_back(vector::ExtractStridedSliceOp::create(
124 rewriter, loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw}, sizes,
125 strides));
126 }
127 }
128 } else {
129 // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0]
130 // for channeled convolution.
131 SmallVector<int64_t> sizes = {nSize, wSizeStep, cSize};
132 SmallVector<int64_t> strides = {1, 1, 1};
133 for (int64_t kw = 0; kw < kwSize; ++kw) {
134 for (int64_t w = 0; w < wSize; w += wSizeStep) {
135 result.push_back(vector::ExtractStridedSliceOp::create(
136 rewriter, loc, input,
137 /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
138 sizes, strides));
139 }
140 }
141 }
142 return result;
143}
144
145/// Helper function to extract the filter slices after filter is unrolled along
146/// kw.
148 Location loc, Value filter,
149 int64_t kwSize) {
151 // Extract rhs slice of size [{c, f} for channeled convolutions and {1} for
152 // non-chanelled convolution] @ [kw].
153 for (int64_t kw = 0; kw < kwSize; ++kw) {
154 result.push_back(vector::ExtractOp::create(
155 rewriter, loc, filter, /*offsets=*/ArrayRef<int64_t>{kw}));
156 }
157 return result;
158}
159
160/// Helper function to extract the result slices after filter is unrolled along
161/// kw.
164 int64_t nSize, int64_t wSize, int64_t fSize,
165 int64_t wSizeStep, bool isSingleChanneled) {
167 if (isSingleChanneled) {
168 // Extract res slice: {wSizeStep} @ [w] for non-channeled convolution.
169 SmallVector<int64_t> sizes = {wSizeStep};
170 SmallVector<int64_t> strides = {1};
171 for (int64_t w = 0; w < wSize; w += wSizeStep) {
172 result.push_back(vector::ExtractStridedSliceOp::create(
173 rewriter, loc, res, /*offsets=*/ArrayRef<int64_t>{w}, sizes,
174 strides));
175 }
176 } else {
177 // Extract res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
178 // convolution.
179 SmallVector<int64_t> sizes = {nSize, wSizeStep, fSize};
180 SmallVector<int64_t> strides = {1, 1, 1};
181 for (int64_t w = 0; w < wSize; w += wSizeStep) {
182 result.push_back(vector::ExtractStridedSliceOp::create(
183 rewriter, loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes,
184 strides));
185 }
186 }
187 return result;
188}
189
190/// Helper function to insert the computed result slices.
192 Value res, int64_t wSize, int64_t wSizeStep,
193 SmallVectorImpl<Value> &resVals,
194 bool isSingleChanneled) {
195
196 if (isSingleChanneled) {
197 // Write back res slice: {wSizeStep} @ [w] for non-channeled convolution.
198 // This does not depend on kw.
199 SmallVector<int64_t> strides = {1};
200 for (int64_t w = 0; w < wSize; w += wSizeStep) {
201 res = vector::InsertStridedSliceOp::create(
202 rewriter, loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w},
203 strides);
204 }
205 } else {
206 // Write back res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
207 // convolution. This does not depend on kw.
208 SmallVector<int64_t> strides = {1, 1, 1};
209 for (int64_t w = 0; w < wSize; w += wSizeStep) {
210 res = vector::InsertStridedSliceOp::create(
211 rewriter, loc, resVals[w], res,
212 /*offsets=*/ArrayRef<int64_t>{0, w, 0}, strides);
213 }
214 }
215 return res;
216}
217
218/// Contains the vectorization state and related methods used across the
219/// vectorization process of a given operation.
221 VectorizationState(RewriterBase &rewriter) : rewriterGuard(rewriter) {}
222
223 /// Initializes the vectorization state, including the computation of the
224 /// canonical vector shape for vectorization.
225 LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
226 ArrayRef<int64_t> inputVectorSizes,
227 ArrayRef<bool> inputScalableVecDims,
228 bool assumeDynamicDimsMatchVecSizes = false);
229
230 /// Returns the canonical vector shape used to vectorize the iteration space.
231 ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
232
233 /// Returns the vector dimensions that are scalable in the canonical vector
234 /// shape.
235 ArrayRef<bool> getScalableVecDims() const { return scalableVecDims; }
236
237 /// Returns a vector type of the provided `elementType` with the canonical
238 /// vector shape and the corresponding fixed/scalable dimensions bit. If
239 /// `dimPermutation` is provided, the canonical vector dimensions are permuted
240 /// accordingly.
242 Type elementType,
243 std::optional<AffineMap> dimPermutation = std::nullopt) const {
245 SmallVector<bool> scalableDims;
246 if (dimPermutation.has_value()) {
248 applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
249 scalableDims =
250 applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
251 } else {
252 vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
253 scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
254 }
255
256 return VectorType::get(vectorShape, elementType, scalableDims);
257 }
258
259 /// Masks an operation with the canonical vector mask if the operation needs
260 /// masking. Returns the masked operation or the original operation if masking
261 /// is not needed. If provided, the canonical mask for this operation is
262 /// permuted using `maybeIndexingMap`.
263 Operation *
264 maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
265 std::optional<AffineMap> maybeIndexingMap = std::nullopt);
266
267private:
268 /// Initializes the iteration space static sizes using the Linalg op
269 /// information. This may become more complicated in the future.
270 void initIterSpaceStaticSizes(LinalgOp linalgOp) {
271 iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
272 }
273
274 /// Generates 'arith.constant' and 'tensor/memref.dim' operations for
275 /// all the static and dynamic dimensions of the iteration space to be
276 /// vectorized and store them in `iterSpaceValueSizes`.
277 LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
278 LinalgOp linalgOp);
279
280 /// Create or retrieve an existing mask value to mask `opToMask` in the
281 /// canonical vector iteration space. If `maybeMaskingMap` the mask is
282 /// permuted using that permutation map. If a new mask is created, it will be
283 /// cached for future users.
284 Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
285 LinalgOp linalgOp,
286 std::optional<AffineMap> maybeMaskingMap);
287
288 /// Check whether this permutation map can be used for masking. At the
289 /// moment we only make sure that there are no broadcast dimensions, but this
290 /// might change if indexing maps evolve.
291 bool isValidMaskingMap(AffineMap maskingMap) {
292 return maskingMap.getBroadcastDims().empty();
293 }
294
295 /// Turn the input indexing map into a valid masking map.
296 ///
297 /// The input indexing map may contain "zero" results, e.g.:
298 /// (d0, d1, d2, d3) -> (d2, d1, d0, 0)
299 /// Applying such maps to canonical vector shapes like this one:
300 /// (1, 16, 16, 4)
301 /// would yield an invalid vector shape like this:
302 /// (16, 16, 1, 0)
303 /// Instead, drop the broadcasting dims that make no sense for masking perm.
304 /// maps:
305 /// (d0, d1, d2, d3) -> (d2, d1, d0)
306 /// This way, the corresponding vector/mask type will be:
307 /// vector<16x16x1xty>
308 /// rather than this invalid Vector type:
309 /// vector<16x16x1x0xty>
310 AffineMap getMaskingMapFromIndexingMap(AffineMap &indexingMap) {
311 return indexingMap.dropZeroResults();
312 }
313
314 // Holds the compile-time static sizes of the iteration space to vectorize.
315 // Dynamic dimensions are represented using ShapedType::kDynamic.
316 SmallVector<int64_t> iterSpaceStaticSizes;
317
318 /// Holds the value sizes of the iteration space to vectorize. Static
319 /// dimensions are represented by 'arith.constant' and dynamic
320 /// dimensions by 'tensor/memref.dim'.
321 SmallVector<Value> iterSpaceValueSizes;
322
323 /// Holds the canonical vector shape used to vectorize the iteration space.
324 SmallVector<int64_t> canonicalVecShape;
325
326 /// Holds the vector dimensions that are scalable in the canonical vector
327 /// shape.
328 SmallVector<bool> scalableVecDims;
329
330 /// Holds the active masks for permutations of the canonical vector iteration
331 /// space.
332 DenseMap<AffineMap, Value> activeMaskCache;
333
334 /// Global vectorization guard for the incoming rewriter. It's initialized
335 /// when the vectorization state is initialized.
336 OpBuilder::InsertionGuard rewriterGuard;
337
338 /// Do all dynamic dims match the corresponding vector sizes?
339 ///
340 /// When a dynamic tensor/memref dimension matches the corresponding vector
341 /// dimension, masking can be safely skipped, despite the presence of dynamic
342 /// shapes. Use this flag with care and only for cases where you are
343 /// confident the assumption holds.
344 bool assumeDynamicDimsMatchVecSizes = false;
345};
346
347LogicalResult
348VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
349 LinalgOp linalgOp) {
350 // TODO: Support 0-d vectors.
351 for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
352 if (ShapedType::isStatic(iterSpaceStaticSizes[vecDim])) {
353 // Create constant index op for static dimensions.
354 iterSpaceValueSizes.push_back(arith::ConstantIndexOp::create(
355 rewriter, linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
356 continue;
357 }
358
359 // Find an operand defined on this dimension of the iteration space to
360 // extract the runtime dimension size.
361 Value operand;
362 unsigned operandDimPos;
363 if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
364 operandDimPos)))
365 return failure();
366
367 Value dynamicDim =
368 linalgOp.hasPureTensorSemantics()
369 ? (Value)tensor::DimOp::create(rewriter, linalgOp.getLoc(), operand,
370 operandDimPos)
371 : (Value)memref::DimOp::create(rewriter, linalgOp.getLoc(), operand,
372 operandDimPos);
373 iterSpaceValueSizes.push_back(dynamicDim);
374 }
375
376 return success();
377}
378
379/// Initializes the vectorization state, including the computation of the
380/// canonical vector shape for vectorization.
381// TODO: Move this to the constructor when we can remove the failure cases.
383 LinalgOp linalgOp,
384 ArrayRef<int64_t> inputVectorSizes,
385 ArrayRef<bool> inputScalableVecDims,
386 bool assumeDimsMatchVec) {
387 assumeDynamicDimsMatchVecSizes = assumeDimsMatchVec;
388 // Initialize the insertion point.
389 rewriter.setInsertionPoint(linalgOp);
390
391 if (!inputVectorSizes.empty()) {
392 // Get the canonical vector shape from the input vector sizes provided. This
393 // path should be taken to vectorize code with dynamic shapes and when using
394 // vector sizes greater than the iteration space sizes.
395 canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
396 scalableVecDims.append(inputScalableVecDims.begin(),
397 inputScalableVecDims.end());
398 } else {
399 // Compute the canonical vector shape from the operation shape. If there are
400 // dynamic shapes, the operation won't be vectorized. We assume all the
401 // vector dimensions are fixed.
402 canonicalVecShape = linalgOp.getStaticLoopRanges();
403 scalableVecDims.append(linalgOp.getNumLoops(), false);
404 }
405
406 LDBG() << "Canonical vector shape: " << llvm::interleaved(canonicalVecShape);
407 LDBG() << "Scalable vector dims: " << llvm::interleaved(scalableVecDims);
408
409 if (ShapedType::isDynamicShape(canonicalVecShape))
410 return failure();
411
412 // Initialize iteration space static sizes.
413 initIterSpaceStaticSizes(linalgOp);
414
415 // Generate 'arith.constant' and 'tensor/memref.dim' operations for
416 // all the static and dynamic dimensions of the iteration space, needed to
417 // compute a mask during vectorization.
418 if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
419 return failure();
420
421 return success();
422}
423
424/// Create or retrieve an existing mask value to mask `opToMask` in the
425/// canonical vector iteration space. If `maybeMaskingMap` the mask is permuted
426/// using that permutation map. If a new mask is created, it will be cached for
427/// future users.
428Value VectorizationState::getOrCreateMaskFor(
429 RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
430 std::optional<AffineMap> maybeMaskingMap) {
431
432 assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) &&
433 "Ill-formed masking map.");
434
435 // No mask is needed if the operation is not maskable.
436 auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
437 if (!maskableOp)
438 return Value();
439
440 assert(!maskableOp.isMasked() &&
441 "Masking an operation that is already masked");
442
443 // If no masking map was provided, use an identity map with the loop dims.
444 assert((!maybeMaskingMap || *maybeMaskingMap) &&
445 "Unexpected null mask permutation map");
446 AffineMap maskingMap =
447 maybeMaskingMap ? *maybeMaskingMap
449 linalgOp.getNumLoops(), rewriter.getContext());
450
451 LDBG() << "Masking map: " << maskingMap;
452
453 // Return the active mask for the masking map of this operation if it was
454 // already created.
455 auto activeMaskIt = activeMaskCache.find(maskingMap);
456 if (activeMaskIt != activeMaskCache.end()) {
457 Value mask = activeMaskIt->second;
458 LDBG() << "Reusing mask: " << mask;
459 return mask;
460 }
461
462 // Compute permuted projection of the iteration space to be masked and the
463 // corresponding mask shape. If the resulting iteration space dimensions are
464 // static and identical to the mask shape, masking is not needed for this
465 // operation.
466 // TODO: Improve this check. Only projected permutation indexing maps are
467 // supported.
468 SmallVector<int64_t> permutedStaticSizes =
469 applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
470 auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
471 auto maskShape = maskType.getShape();
472
473 LDBG() << "Mask shape: " << llvm::interleaved(maskShape);
474
475 if (permutedStaticSizes == maskShape) {
476 LDBG() << "Masking is not needed for masking map: " << maskingMap;
477 activeMaskCache[maskingMap] = Value();
478 return Value();
479 }
480
481 if (assumeDynamicDimsMatchVecSizes) {
482 // While for _dynamic_ dim sizes we can _assume_ that the corresponding
483 // vector sizes match, we still need to check the _static_ dim sizes. Only
484 // then we can be 100% sure that masking is not required.
485 if (llvm::all_of(llvm::zip(permutedStaticSizes, maskType.getShape()),
486 [](auto it) {
487 return std::get<0>(it) == ShapedType::kDynamic
488 ? true
489 : std::get<0>(it) == std::get<1>(it);
490 })) {
491 LDBG()
492 << "Dynamic + static dimensions match vector sizes, masking is not "
493 "required.";
494 activeMaskCache[maskingMap] = Value();
495 return Value();
496 }
497 }
498
499 // Permute the iteration space value sizes to compute the mask upper bounds.
500 SmallVector<Value> upperBounds =
501 applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
502 assert(!maskShape.empty() && !upperBounds.empty() &&
503 "Masked 0-d vectors are not supported yet");
504
505 // Create the mask based on the dimension values.
506 Value mask = vector::CreateMaskOp::create(rewriter, linalgOp.getLoc(),
507 maskType, upperBounds);
508 LDBG() << "Creating new mask: " << mask;
509 activeMaskCache[maskingMap] = mask;
510 return mask;
511}
512
513Operation *
515 LinalgOp linalgOp,
516 std::optional<AffineMap> maybeIndexingMap) {
517 LDBG() << "Trying to mask: " << *opToMask;
518
519 std::optional<AffineMap> maybeMaskingMap = std::nullopt;
520 if (maybeIndexingMap)
521 maybeMaskingMap = getMaskingMapFromIndexingMap(*maybeIndexingMap);
522
523 // Create or retrieve mask for this operation.
524 Value mask =
525 getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
526
527 if (!mask) {
528 LDBG() << "No mask required";
529 if (assumeDynamicDimsMatchVecSizes) {
531 .Case<vector::TransferReadOp, vector::TransferWriteOp>(
532 [&](auto xferOp) {
533 // For vector.transfer_read and vector.transfer_write, there is
534 // also the `in-bounds` attribute that has to be set explicitly
535 // to true. Otherwise, "out-of-bounds" access will be assumed
536 // and masks will be generated while lowering these.
537 LDBG() << "Assuming dynamic dimensions match vector sizes and "
538 "setting their in-bounds to true!";
539 SmallVector<bool> inBoundsMap = xferOp.getInBoundsValues();
540 ShapedType xferType = xferOp.getShapedType();
541 AffineMap permMap = xferOp.getPermutationMap();
542 // Only set the in-bounds values to true for dynamic dims.
543 // Different mechanisms will set these accordingly for the
544 // static dims.
545 for (unsigned i = 0; i < xferOp.getTransferRank(); i++) {
546 auto dimExpr = dyn_cast<AffineDimExpr>(permMap.getResult(i));
547 // Skip broadcast dimensions.
548 if (!dimExpr)
549 continue;
550 unsigned pos = dimExpr.getPosition();
551 if (xferType.isDynamicDim(pos))
552 inBoundsMap[i] = true;
553 }
554 rewriter.modifyOpInPlace(xferOp, [&]() {
555 xferOp.setInBoundsAttr(
556 rewriter.getBoolArrayAttr(inBoundsMap));
557 });
558 })
559 .Default([](Operation *op) {
560 // No-op if the operation is not an xfer read or write.
561 });
562 }
563 return opToMask;
564 }
565
566 // Wrap the operation with a new `vector.mask` and update D-U chain.
567 assert(opToMask && "Expected a valid operation to mask");
568 auto maskOp = cast<vector::MaskOp>(
569 mlir::vector::maskOperation(rewriter, opToMask, mask));
570 Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
571
572 for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
573 rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
574 maskOpTerminator);
575
576 LDBG() << "Masked operation: " << *maskOp;
577 return maskOp;
578}
579
580/// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
581/// projectedPermutation, compress the unused dimensions to serve as a
582/// permutation_map for a vector transfer operation.
583/// For example, given a linalg op such as:
584///
585/// ```
586/// %0 = linalg.generic {
587/// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>,
588/// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>
589/// }
590/// ins(%0 : tensor<2x3x4xf32>)
591/// outs(%1 : tensor<5x6xf32>)
592/// ```
593///
594/// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine
595/// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second
596/// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`.
598 assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
599 "expected projected permutation");
600 auto res = compressUnusedDims(map);
601 assert(res.getNumDims() ==
602 (res.getNumResults() - res.getNumOfZeroResults()) &&
603 "expected reindexed map with same number of dims and results");
604 return res;
605}
606
/// Helper enum to represent conv1d input traversal order.
/// `W` is the non-channeled case; `Ncw` and `Nwc` differ only in whether the
/// channel dimension precedes or follows the spatial (w) dimension.
enum class Conv1DOpOrder {
  W, // Corresponds to non-channeled 1D convolution operation.
  Ncw, // Corresponds to operation that traverses the input in (n, c, w) order.
  Nwc // Corresponds to operation that traverses the input in (n, w, c) order.
};
613
614/// Helper data structure to represent the result of vectorization for a single
615/// operation. In certain specific cases, like terminators, we do not want to
616/// propagate.
618 /// Op failed to vectorize.
620 /// Op vectorized and custom function took care of replacement logic
622 /// Op vectorized into a new Op whose results will replace original Op's
623 /// results.
625 // TODO: support values if Op vectorized to Many-Ops whose results we need to
626 // aggregate for replacement.
627};
628/// VectorizationHookResult contains the vectorized op returned from a
629/// CustomVectorizationHook. This is an internal implementation detail of
630/// linalg vectorization, not to be confused with VectorizationResult.
632 /// Return status from vectorizing the current op.
634 /// New vectorized operation to replace the current op.
635 /// Replacement behavior is specified by `status`.
637};
638
639std::optional<vector::CombiningKind>
641 using ::mlir::vector::CombiningKind;
642
643 if (!combinerOp)
644 return std::nullopt;
646 .Case<arith::AddIOp, arith::AddFOp>(
647 [&](auto op) { return CombiningKind::ADD; })
648 .Case([&](arith::AndIOp op) { return CombiningKind::AND; })
649 .Case([&](arith::MaxSIOp op) { return CombiningKind::MAXSI; })
650 .Case([&](arith::MaxUIOp op) { return CombiningKind::MAXUI; })
651 .Case([&](arith::MaximumFOp op) { return CombiningKind::MAXIMUMF; })
652 .Case([&](arith::MaxNumFOp op) { return CombiningKind::MAXNUMF; })
653 .Case([&](arith::MinSIOp op) { return CombiningKind::MINSI; })
654 .Case([&](arith::MinUIOp op) { return CombiningKind::MINUI; })
655 .Case([&](arith::MinimumFOp op) { return CombiningKind::MINIMUMF; })
656 .Case([&](arith::MinNumFOp op) { return CombiningKind::MINNUMF; })
657 .Case<arith::MulIOp, arith::MulFOp>(
658 [&](auto op) { return CombiningKind::MUL; })
659 .Case([&](arith::OrIOp op) { return CombiningKind::OR; })
660 .Case([&](arith::XOrIOp op) { return CombiningKind::XOR; })
661 .Default(std::nullopt);
662}
663
664/// Check whether `outputOperand` is a reduction with a single combiner
665/// operation. Return the combiner operation of the reduction. Return
666/// nullptr otherwise. Multiple reduction operations would impose an
667/// ordering between reduction dimensions and is currently unsupported in
668/// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) !=
669/// max(min(X))
670// TODO: use in LinalgOp verification, there is a circular dependency atm.
671static Operation *matchLinalgReduction(OpOperand *outputOperand) {
672 auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
673 unsigned outputPos =
674 outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
675 // Only single combiner operations are supported for now.
676 SmallVector<Operation *, 4> combinerOps;
677 if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
678 combinerOps.size() != 1)
679 return nullptr;
680
681 // Return the combiner operation.
682 return combinerOps[0];
683}
684
685/// Broadcast `value` to a vector of `shape` if possible. Return value
686/// otherwise.
687static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
688 auto dstVecType = dyn_cast<VectorType>(dstType);
689 // If no shape to broadcast to, just return `value`.
690 if (dstVecType.getRank() == 0)
691 return value;
692 if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
694 return value;
695 Location loc = b.getInsertionPoint()->getLoc();
696 return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
697}
698
699/// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
700/// assumes that `reductionOp` has two operands and one of them is the reduction
701/// initial value.buildMultiDimReduce
702// Note: this is a true builder that notifies the OpBuilder listener.
703// TODO: Consider moving as a static helper on the ReduceOp.
705 Value valueToReduce, Value acc,
706 ArrayRef<bool> dimsToMask) {
707 auto maybeKind = getCombinerOpKind(reduceOp);
708 assert(maybeKind && "Failed precondition: could not get reduction kind");
709 return vector::MultiDimReductionOp::create(
710 b, reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
711}
712
713static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
714 return llvm::map_to_vector(linalgOp.getIteratorTypesArray(),
716}
717
718/// Check if `op` is a linalg.reduce or a linalg.generic that has at least one
719/// reduction iterator.
720static bool hasReductionIterator(LinalgOp &op) {
721 return isa<linalg::ReduceOp>(op) ||
722 (isa<linalg::GenericOp>(op) &&
723 llvm::any_of(op.getIteratorTypesArray(), isReductionIterator));
724}
725
726/// Build a vector.transfer_write of `value` into `outputOperand` at indices set
727/// to all `0`; where `outputOperand` is an output operand of the LinalgOp
728/// currently being vectorized. If `dest` has null rank, build an memref.store.
729/// Return the produced value or null if no value is produced.
730// Note: this is a true builder that notifies the OpBuilder listener.
731// TODO: Consider moving as a static helper on the ReduceOp.
732static Value buildVectorWrite(RewriterBase &rewriter, Value value,
733 OpOperand *outputOperand,
734 VectorizationState &state) {
735 Location loc = value.getLoc();
736 auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
737 AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
738
739 // Compute the vector type of the value to store. This type should be an
740 // identity or projection of the canonical vector type without any permutation
741 // applied, given that any permutation in a transfer write happens as part of
742 // the write itself.
744 opOperandMap.getContext(), opOperandMap.getNumInputs(),
745 [&](AffineDimExpr dimExpr) -> bool {
746 return llvm::is_contained(opOperandMap.getResults(), dimExpr);
747 });
748 auto vectorType = state.getCanonicalVecType(
749 getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);
750
751 SmallVector<Value> indices(linalgOp.getRank(outputOperand),
752 arith::ConstantIndexOp::create(rewriter, loc, 0));
753
754 Operation *write;
755 if (vectorType.getRank() > 0) {
756 AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
757 value = broadcastIfNeeded(rewriter, value, vectorType);
758 assert(value.getType() == vectorType && "Incorrect type");
759 write = vector::TransferWriteOp::create(
760 rewriter, loc, value, outputOperand->get(), indices, writeMap);
761 } else {
762 // 0-d case is still special: do not invert the reindexing writeMap.
763 if (!isa<VectorType>(value.getType()))
764 value = vector::BroadcastOp::create(rewriter, loc, vectorType, value);
765 assert(value.getType() == vectorType && "Incorrect type");
766 write = vector::TransferWriteOp::create(rewriter, loc, value,
767 outputOperand->get(), indices);
768 }
769
770 write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
771
772 // If masked, set in-bounds to true. Masking guarantees that the access will
773 // be in-bounds.
774 if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
775 auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
776 SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
777 maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
778 }
779
780 LDBG() << "vectorized op: " << *write;
781 if (!write->getResults().empty())
782 return write->getResult(0);
783 return Value();
784}
785
786// Custom vectorization precondition function type. This is intented to be used
787// with CustomVectorizationHook. Returns success if the corresponding custom
788// hook can vectorize the op.
790 std::function<LogicalResult(Operation *, bool)>;
791
792// Custom vectorization function type. Produce a vector form of Operation*
793// assuming all its vectorized operands are already in the IRMapping.
794// Return nullptr if the Operation cannot be vectorized.
796 std::function<VectorizationHookResult(Operation *, const IRMapping &)>;
797
798/// Helper function to vectorize the terminator of a `linalgOp`. New result
799/// vector values are appended to `newResults`. Return
800/// VectorizationHookStatus::NoReplace to signal the vectorization algorithm
801/// that it should not try to map produced operations and instead return the
802/// results using the `newResults` vector making them available to the
803/// vectorization algorithm for RAUW. This function is meant to be used as a
804/// CustomVectorizationHook.
807 const IRMapping &bvm, VectorizationState &state,
808 LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
809 auto yieldOp = dyn_cast<linalg::YieldOp>(op);
810 if (!yieldOp)
812 for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
813 // TODO: Scan for an opportunity for reuse.
814 // TODO: use a map.
815 Value vectorValue = bvm.lookup(output.value());
816 Value newResult =
817 buildVectorWrite(rewriter, vectorValue,
818 linalgOp.getDpsInitOperand(output.index()), state);
819 if (newResult)
820 newResults.push_back(newResult);
821 }
822
824}
825
826/// Helper function to vectorize the index operations of a `linalgOp`. Return
827/// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it
828/// should map the produced operations. This function is meant to be used as a
829/// CustomVectorizationHook.
831 VectorizationState &state,
832 Operation *op,
833 LinalgOp linalgOp) {
834 IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
835 if (!indexOp)
837 auto loc = indexOp.getLoc();
838 // Compute the static loop sizes of the index op.
839 ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
840 auto dim = indexOp.getDim();
841 // Compute a one-dimensional index vector for the index op dimension.
842 auto indexVectorType =
843 VectorType::get({targetShape[dim]}, rewriter.getIndexType(),
844 state.getScalableVecDims()[dim]);
845 auto indexSteps = vector::StepOp::create(rewriter, loc, indexVectorType);
846 // Return the one-dimensional index vector if it lives in the trailing
847 // dimension of the iteration space since the vectorization algorithm in this
848 // case can handle the broadcast.
849 if (dim == targetShape.size() - 1)
851 // Otherwise permute the targetShape to move the index dimension last,
852 // broadcast the one-dimensional index vector to the permuted shape, and
853 // finally transpose the broadcasted index vector to undo the permutation.
854 auto permPattern =
855 llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
856 std::swap(permPattern[dim], permPattern.back());
857 auto permMap =
858 AffineMap::getPermutationMap(permPattern, linalgOp.getContext());
859
860 auto broadCastOp = vector::BroadcastOp::create(
861 rewriter, loc,
862 state.getCanonicalVecType(rewriter.getIndexType(), permMap), indexSteps);
863 SmallVector<int64_t> transposition =
864 llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
865 std::swap(transposition.back(), transposition[dim]);
866 auto transposeOp =
867 vector::TransposeOp::create(rewriter, loc, broadCastOp, transposition);
869}
870
871/// Helper function to check if the tensor.extract can be vectorized by the
872/// custom hook vectorizeTensorExtract.
873static LogicalResult
875 tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
876 if (!extractOp)
877 return failure();
878
879 if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
880 return failure();
881
882 // Check the index type, but only for non 0-d tensors (for which we do need
883 // access indices).
884 if (not extractOp.getIndices().empty()) {
885 if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
886 return failure();
887 }
888
889 if (!llvm::all_of(extractOp->getResultTypes(),
890 VectorType::isValidElementType)) {
891 return failure();
892 }
893
894 return success();
895}
896
897/// Calculates the offsets (`$index_vec`) for `vector.gather` operations
898/// generated from `tensor.extract`. The offset is calculated as follows
899/// (example using scalar values):
900///
901/// offset = extractOp.indices[0]
902/// for (i = 1; i < numIndices; i++)
903/// offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
904///
905/// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to:
906/// offset = ( ( 1 ) * 80 + 2 ) * 15 + 3
908 VectorizationState &state,
909 tensor::ExtractOp extractOp,
910 const IRMapping &bvm) {
911 // The vector of indices for GatherOp should be shaped as the output vector.
912 auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
913 auto loc = extractOp.getLoc();
914
915 Value offset = broadcastIfNeeded(
916 rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);
917
918 const size_t numIndices = extractOp.getIndices().size();
919 for (size_t i = 1; i < numIndices; i++) {
920 Value dimIdx = arith::ConstantIndexOp::create(rewriter, loc, i);
921
922 auto dimSize = broadcastIfNeeded(
923 rewriter,
924 tensor::DimOp::create(rewriter, loc, extractOp.getTensor(), dimIdx),
925 indexVecType);
926
927 offset = arith::MulIOp::create(rewriter, loc, offset, dimSize);
928
929 auto extractOpIndex = broadcastIfNeeded(
930 rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);
931
932 offset = arith::AddIOp::create(rewriter, loc, extractOpIndex, offset);
933 }
934
935 return offset;
936}
937
// Memory-access pattern inferred for a `tensor.extract` inside a linalg Op;
// drives whether it lowers to a broadcasted scalar read, a contiguous
// transfer_read, or a gather.
enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
940/// Find the index of the trailing non-unit dim in linalgOp. This hook is used
941/// when checking whether `tensor.extract` Op (within a `linalg.generic` Op)
942/// represents a contiguous load operation.
943///
944/// Note that when calling this hook, it is assumed that the output vector is
945/// effectively 1D. Other cases (i.e. reading n-D vectors) should've been
946/// labelled as a gather load before entering this method.
947///
948/// Following on from the above, it is assumed that:
949/// * for statically shaped loops, when no masks are used, only one dim is !=
950/// 1 (that's what the shape of the output vector is based on).
951/// * for dynamically shaped loops, there might be more non-unit dims
952/// as the output vector type is user-specified.
953///
954/// TODO: Statically shaped loops + vector masking
955static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp) {
956 SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges();
957 assert(
958 (linalgOp.hasDynamicShape() ||
959 llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) == 1) &&
960 "For statically shaped Linalg Ops, only one "
961 "non-unit loop dim is expected");
962 assert(!loopRanges.empty() && "Empty loops, nothing to analyse.");
963
964 size_t idx = loopRanges.size() - 1;
965 for (; idx != 0; idx--)
966 if (loopRanges[idx] != 1)
967 break;
968
969 return idx;
970}
971
972/// Checks whether `val` can be used for calculating a loop invariant index.
973static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
974 VectorType resType) {
975
976 assert(((llvm::count_if(resType.getShape(),
977 [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
978 "n-D vectors are not yet supported");
979
980 // Blocks outside _this_ linalg.generic are effectively loop invariant.
981 // However, analysing block arguments for _this_ linalg.generic Op is a bit
982 // tricky. Just bail out in the latter case.
983 // TODO: We could try analysing the corresponding affine map here.
984 auto *block = linalgOp.getBlock();
985 if (isa<BlockArgument>(val))
986 return !llvm::is_contained(block->getArguments(), val);
987
988 Operation *defOp = val.getDefiningOp();
989 assert(defOp && "This is neither a block argument nor an operation result");
990
991 // IndexOp is loop invariant as long as its result remains constant across
992 // iterations. Note that for dynamic shapes, the corresponding dim will also
993 // be conservatively treated as != 1.
994 if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
995 return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
996 }
997
998 auto *ancestor = block->findAncestorOpInBlock(*defOp);
999
1000 // Values define outside `linalgOp` are loop invariant.
1001 if (!ancestor)
1002 return true;
1003
1004 // Values defined inside `linalgOp`, which are constant, are loop invariant.
1005 if (isa<arith::ConstantOp>(ancestor))
1006 return true;
1007
1008 bool result = true;
1009 for (auto op : ancestor->getOperands())
1010 result &= isLoopInvariantIdx(linalgOp, op, resType);
1011
1012 return result;
1013}
1014
/// Check whether `val` could be used for calculating the trailing index for a
/// contiguous load operation.
///
/// There are currently 3 types of values that are allowed here:
///   1. loop-invariant values,
///   2. values that increment by 1 with every loop iteration,
///   3. results of basic arithmetic operations (linear and continuous)
///      involving 1., 2. and 3.
/// This method returns True if indeed only such values are used in calculating
/// `val.`
///
/// Additionally, the trailing index for a contiguous load operation should
/// increment by 1 with every loop iteration, i.e. be based on:
///   * `linalg.index <dim>` ,
/// where <dim> is the trailing non-unit dim of the iteration space (this way,
/// `linalg.index <dim>` increments by 1 with every loop iteration).
/// `foundIndexOp` is updated to `true` when such Op is found.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
                                bool &foundIndexOp, VectorType resType) {

  // Pre-condition: the effective output vector must be 1D (non-1D outputs are
  // classified as gathers before this analysis runs).
  assert(((llvm::count_if(resType.getShape(),
                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
         "n-D vectors are not yet supported");

  // Blocks outside _this_ linalg.generic are effectively loop invariant.
  // However, analysing block arguments for _this_ linalg.generic Op is a bit
  // tricky. Just bail out in the latter case.
  // TODO: We could try analysing the corresponding affine map here.
  auto *block = linalgOp.getBlock();
  if (isa<BlockArgument>(val))
    return !llvm::is_contained(block->getArguments(), val);

  Operation *defOp = val.getDefiningOp();
  assert(defOp && "This is neither a block argument nor an operation result");

  // A `linalg.index` is a valid contiguous-index ingredient; it only provides
  // the required unit stride when it indexes the trailing non-unit loop dim.
  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
    auto loopDimThatIncrementsByOne = getTrailingNonUnitLoopDimIdx(linalgOp);

    foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
    return true;
  }

  auto *ancestor = block->findAncestorOpInBlock(*defOp);

  // Defined outside this op's body (and not an IndexOp/BlockArgument above):
  // cannot contribute a unit-stride trailing index.
  if (!ancestor)
    return false;

  // Conservatively reject Ops that could lead to indices with stride other
  // than 1.
  if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
    return false;

  // NOTE: `|=` (rather than short-circuiting ||) is deliberate: every operand
  // must be visited so that `foundIndexOp` is updated for all of them.
  bool result = false;
  for (auto op : ancestor->getOperands())
    result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType);

  return result;
}
1073
1074/// Infer the memory access pattern for the input ExtractOp
1075///
1076/// Based on the ExtratOp result shape and the access indices, decides whether
1077/// this Op corresponds to a contiguous load (including a broadcast of a scalar)
1078/// or a gather load. When analysing the ExtractOp indices (to identify
1079/// contiguous laods), this method looks for "loop" invariant indices (e.g.
1080/// block arguments) and indices that change linearly (e.g. via `linalg.index`
1081/// Op).
1082///
1083/// Note that it is always safe to use gather load operations for contiguous
1084/// loads (albeit slow), but not vice-versa. When in doubt, bail out and assume
1085/// that `extractOp` is a gather load.
1087getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
1088 LinalgOp &linalgOp, VectorType resType) {
1089
1090 auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
1091
1092 // 0. Is this a 0-D vector? If yes then this is a scalar broadcast.
1093 if (inputShape.getShape().empty())
1095
1096 // 0a. Is the result a 0-D vector? If yes, there are no iteration dimensions
1097 // so the tensor.extract is a single scalar load regardless of the index.
1098 if (resType.getRank() == 0)
1100
1101 // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false
1102 // otherwise.
1103 bool isOutput1DVector =
1104 (llvm::count_if(resType.getShape(),
1105 [](int64_t dimSize) { return dimSize > 1; }) == 1);
1106 // 1. Assume that it's a gather load when reading non-1D vector.
1107 if (!isOutput1DVector)
1109
1110 bool leadingIdxsLoopInvariant = true;
1111
1112 // 2. Analyze the leading indices of `extractOp`.
1113 // Look at the way each index is calculated and decide whether it is suitable
1114 // for a contiguous load, i.e. whether it's loop invariant. If not, it's a
1115 // gather load.
1116 auto indices = extractOp.getIndices();
1117 auto leadIndices = indices.drop_back(1);
1118
1119 for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
1120 if (inputShape.getShape()[i] == 1)
1121 continue;
1122
1123 leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType);
1124 }
1125
1126 if (!leadingIdxsLoopInvariant) {
1127 LDBG() << "Found gather load: " << extractOp;
1129 }
1130
1131 // 3. Analyze the trailing index for `extractOp`.
1132 // At this point we know that the leading indices are loop invariant. This
1133 // means that is potentially a scalar or a contiguous load. We can decide
1134 // based on the trailing idx.
1135 auto extractOpTrailingIdx = indices.back();
1136
1137 // 3a. Scalar broadcast load
1138 // If the trailing index is loop invariant then this is a scalar load.
1139 if (leadingIdxsLoopInvariant &&
1140 isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) {
1141 LDBG() << "Found scalar broadcast load: " << extractOp;
1142
1144 }
1145
1146 // 3b. Contiguous loads
1147 // The trailing `extractOp` index should increment with every loop iteration.
1148 // This effectively means that it must be based on the trailing loop index.
1149 // This is what the following bool captures.
1150 bool foundIndexOp = false;
1151 bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
1152 foundIndexOp, resType);
1153 // TODO: Support generating contiguous loads for column vectors - that will
1154 // require adding a permutation map to tranfer_read Ops.
1155 bool isRowVector = resType.getShape().back() != 1;
1156 isContiguousLoad &= (foundIndexOp && isRowVector);
1157
1158 if (isContiguousLoad) {
1159 LDBG() << "Found contigous load: " << extractOp;
1161 }
1162
1163 // 4. Fallback case - gather load.
1164 LDBG() << "Found gather load: " << extractOp;
1166}
1167
1168/// Helper function to vectorize the tensor.extract operations. Returns
1169/// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it
1170/// should map the produced operations. This function is meant to be used as a
1171/// CustomVectorizationHook.
1172static VectorizationHookResult
1173vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
1174 Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
1175 tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
1176 if (!extractOp)
1178 auto loc = extractOp.getLoc();
1179
1180 // Compute the static loop sizes of the extract op.
1181 auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
1182 auto maskConstantOp = arith::ConstantOp::create(
1183 rewriter, loc,
1184 DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()),
1185 /*value=*/true));
1186 auto passThruConstantOp = arith::ConstantOp::create(
1187 rewriter, loc, rewriter.getZeroAttr(resultType));
1188
1189 // Base indices are currently set to 0. We will need to re-visit if more
1190 // generic scenarios are to be supported.
1191 SmallVector<Value> baseIndices(
1192 extractOp.getIndices().size(),
1193 arith::ConstantIndexOp::create(rewriter, loc, 0));
1194
1195 VectorMemoryAccessKind memAccessKind =
1196 getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType);
1197
1198 // 1. Handle gather access
1199 if (memAccessKind == VectorMemoryAccessKind::Gather) {
1200 Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);
1201
1202 // Generate the gather load
1203 Operation *gatherOp = vector::GatherOp::create(
1204 rewriter, loc, resultType, extractOp.getTensor(), baseIndices, offset,
1205 maskConstantOp, passThruConstantOp);
1206 gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
1207
1208 LDBG() << "Vectorised as gather load: " << extractOp;
1210 }
1211
1212 // 2. Handle:
1213 // a. scalar loads + broadcast,
1214 // b. contiguous loads.
1215 // Both cases use vector.transfer_read.
1216
1217 // Collect indices for `vector.transfer_read`. At this point, the indices will
1218 // either be scalars or would have been broadcast to vectors matching the
1219 // result type. For indices that are vectors, there are two options:
1220 // * for non-trailing indices, all elements are identical (contiguous
1221 // loads are identified by looking for non-trailing indices that are
1222 // invariant with respect to the corresponding linalg.generic), or
1223 // * for trailing indices, the index vector will contain values with stride
1224 // one, but for `vector.transfer_read` only the first (i.e. 0th) index is
1225 // needed.
1226 // This means that
1227 // * for scalar indices - just re-use it,
1228 // * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
1229 // (0th) element and use that.
1230 SmallVector<Value> transferReadIdxs;
1231 for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
1232 Value idx = bvm.lookup(extractOp.getIndices()[i]);
1233 if (idx.getType().isIndex()) {
1234 transferReadIdxs.push_back(idx);
1235 continue;
1236 }
1237
1238 auto indexAs1dVector = vector::ShapeCastOp::create(
1239 rewriter, loc,
1240 VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
1241 resultType.getScalableDims().back()),
1242 idx);
1243 transferReadIdxs.push_back(
1244 vector::ExtractOp::create(rewriter, loc, indexAs1dVector, 0));
1245 }
1246
1247 // `tensor.extract_element` is always in-bounds, hence the following holds.
1248 auto dstRank = resultType.getRank();
1249 auto srcRank = extractOp.getTensor().getType().getRank();
1250 SmallVector<bool> inBounds(dstRank, true);
1251
1252 // 2a. Handle scalar broadcast access.
1253 if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
1254 MLIRContext *ctx = rewriter.getContext();
1255 SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
1256 auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);
1257
1258 auto transferReadOp = vector::TransferReadOp::create(
1259 rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs,
1260 /*padding=*/std::nullopt, permutationMap, inBounds);
1261
1262 Operation *readOrMaskedReadOp = transferReadOp;
1263 if (dstRank > 0) {
1264 // Mask this broadcasting xfer_read here rather than relying on the
1265 // generic path (the generic path assumes identity masking map, which
1266 // wouldn't be valid here).
1267 SmallVector<int64_t> readMaskShape = {1};
1268 auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
1269 auto allTrue = vector::ConstantMaskOp::create(
1270 rewriter, loc, readMaskType, vector::ConstantMaskKind::AllTrue);
1271 readOrMaskedReadOp =
1272 mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);
1273 }
1274
1275 LDBG() << "Vectorised as scalar broadcast load: " << extractOp;
1277 readOrMaskedReadOp};
1278 }
1279
1280 // 2b. Handle contiguous access.
1281 auto permutationMap = AffineMap::getMinorIdentityMap(
1282 srcRank, std::min(dstRank, srcRank), rewriter.getContext());
1283
1284 int32_t rankDiff = dstRank - srcRank;
1285 // When dstRank > srcRank, broadcast the source tensor to the unitary leading
1286 // dims so that the ranks match. This is done by extending the map with 0s.
1287 // For example, for dstRank = 3, srcRank = 2, the following map created
1288 // above:
1289 // (d0, d1) --> (d0, d1)
1290 // is extended as:
1291 // (d0, d1) --> (0, d0, d1)
1292 while (rankDiff > 0) {
1293 permutationMap = permutationMap.insertResult(
1294 mlir::getAffineConstantExpr(0, rewriter.getContext()), 0);
1295 rankDiff--;
1296 }
1297
1298 auto transferReadOp = vector::TransferReadOp::create(
1299 rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs,
1300 /*padding=*/std::nullopt, permutationMap, inBounds);
1301
1302 LDBG() << "Vectorised as contiguous load: " << extractOp;
1304 transferReadOp};
1305}
1306
1307/// Emit reduction operations if the shapes of the value to reduce is different
1308/// that the result shape.
1309// Note: this is a true builder that notifies the OpBuilder listener.
1310// TODO: Consider moving as a static helper on the ReduceOp.
1311static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
1312 Value reduceValue, Value initialValue,
1313 const IRMapping &bvm) {
1314 Value reduceVec = bvm.lookup(reduceValue);
1315 Value outputVec = bvm.lookup(initialValue);
1316 auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
1317 auto outputType = dyn_cast<VectorType>(outputVec.getType());
1318 // Reduce only if needed as the value may already have been reduce for
1319 // contraction vectorization.
1320 if (!reduceType ||
1321 (outputType && reduceType.getShape() == outputType.getShape()))
1322 return nullptr;
1323 SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
1324 return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
1325}
1326
1327/// Generic vectorization for a single operation `op`, given already vectorized
1328/// operands carried by `bvm`. Vectorization occurs as follows:
1329/// 1. Try to apply any of the `customVectorizationHooks` and return its
1330/// result on success.
1331/// 2. Clone any constant in the current scope without vectorization: each
1332/// consumer of the constant will later determine the shape to which the
1333/// constant needs to be broadcast to.
1334/// 3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
1335/// of the `customVectorizationHooks` to cover such cases.
1336/// 4. Clone `op` in vector form to a vector of shape prescribed by the first
1337/// operand of maximal rank. Other operands have smaller rank and are
1338/// broadcast accordingly. It is assumed this broadcast is always legal,
1339/// otherwise, it means one of the `customVectorizationHooks` is incorrect.
1340///
1341/// This function assumes all operands of `op` have been vectorized and are in
1342/// the `bvm` mapping. As a consequence, this function is meant to be called on
1343/// a topologically-sorted list of ops.
1344/// This function does not update `bvm` but returns a VectorizationHookStatus
1345/// that instructs the caller what `bvm` update needs to occur.
1346static VectorizationHookResult
1347vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
1348 LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
1349 ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
1350 LDBG() << "vectorize op " << *op;
1351
1352 // 1. Try to apply any CustomVectorizationHook.
1353 if (!customVectorizationHooks.empty()) {
1354 for (auto &customFunc : customVectorizationHooks) {
1355 VectorizationHookResult result = customFunc(op, bvm);
1357 continue;
1358 return result;
1359 }
1360 }
1361
1362 // 2. Constant ops don't get vectorized but rather broadcasted at their users.
1363 // Clone so that the constant is not confined to the linalgOp block .
1364 if (isa<arith::ConstantOp, func::ConstantOp>(op))
1366 rewriter.clone(*op)};
1367
1368 // 3. Only ElementwiseMappable are allowed in the generic vectorization.
1371
1372 // 4 . Check if the operation is a reduction.
1373 SmallVector<std::pair<Value, Value>> reductionOperands;
1374 for (Value operand : op->getOperands()) {
1375 auto blockArg = dyn_cast<BlockArgument>(operand);
1376 if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
1377 blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
1378 continue;
1379 SmallVector<Operation *> reductionOps;
1380 Value reduceValue = matchReduction(
1381 linalgOp.getRegionOutputArgs(),
1382 blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
1383 if (!reduceValue)
1384 continue;
1385 reductionOperands.push_back(std::make_pair(reduceValue, operand));
1386 }
1387 if (!reductionOperands.empty()) {
1388 assert(reductionOperands.size() == 1);
1389 Operation *reduceOp =
1390 reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
1391 reductionOperands[0].second, bvm);
1392 if (reduceOp)
1394 }
1395
1396 // 5. Generic vectorization path for ElementwiseMappable ops.
1397 // a. Get the first max ranked shape.
1398 VectorType firstMaxRankedType;
1399 for (Value operand : op->getOperands()) {
1400 auto vecOperand = bvm.lookup(operand);
1401 assert(vecOperand && "Vector operand couldn't be found");
1402
1403 auto vecType = dyn_cast<VectorType>(vecOperand.getType());
1404 if (vecType && (!firstMaxRankedType ||
1405 firstMaxRankedType.getRank() < vecType.getRank()))
1406 firstMaxRankedType = vecType;
1407 }
1408 // b. Broadcast each op if needed.
1409 SmallVector<Value> vecOperands;
1410 for (Value scalarOperand : op->getOperands()) {
1411 Value vecOperand = bvm.lookup(scalarOperand);
1412 assert(vecOperand && "Vector operand couldn't be found");
1413
1414 if (firstMaxRankedType) {
1415 auto vecType = VectorType::get(firstMaxRankedType.getShape(),
1416 getElementTypeOrSelf(vecOperand.getType()),
1417 firstMaxRankedType.getScalableDims());
1418 vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
1419 } else {
1420 vecOperands.push_back(vecOperand);
1421 }
1422 }
1423 // c. for elementwise, the result is the vector with the firstMaxRankedShape
1424 SmallVector<Type> resultTypes;
1425 for (Type resultType : op->getResultTypes()) {
1426 resultTypes.push_back(
1427 firstMaxRankedType
1428 ? VectorType::get(firstMaxRankedType.getShape(), resultType,
1429 firstMaxRankedType.getScalableDims())
1430 : resultType);
1431 }
1432 // d. Build and return the new op.
1435 rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
1436 resultTypes, op->getAttrs())};
1437}
1438
1439/// Generic vectorization function that rewrites the body of a `linalgOp` into
1440/// vector form. Generic vectorization proceeds as follows:
1441/// 1. Verify the `linalgOp` has one non-empty region.
1442/// 2. Values defined above the region are mapped to themselves and will be
1443/// broadcasted on a per-need basis by their consumers.
1444/// 3. Each region argument is vectorized into a vector.transfer_read (or 0-d
1445/// load).
1446/// TODO: Reuse opportunities for RAR dependencies.
1447/// 4a. Register CustomVectorizationHook for YieldOp to capture the results.
1448/// 4rewriter. Register CustomVectorizationHook for IndexOp to access the
1449/// iteration indices.
1450/// 5. Iteratively call vectorizeOneOp on the region operations.
1451///
1452/// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is
1453/// performed to the maximal common vector size implied by the `linalgOp`
1454/// iteration space. This eager broadcasting is introduced in the
1455/// permutation_map of the vector.transfer_read operations. The eager
1456/// broadcasting makes it trivial to determine where broadcast, transposes and
1457/// reductions should occur, without any bookkeeping. The tradeoff is that, in
1458/// the absence of good canonicalizations, the amount of work increases.
1459/// This is not deemed a problem as we expect canonicalizations and foldings to
1460/// aggressively clean up the useless work.
1461static LogicalResult
1462vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
1463 LinalgOp linalgOp,
1464 SmallVectorImpl<Value> &newResults) {
1465 LDBG() << "Vectorizing operation as linalg generic/n";
1466 Block *block = linalgOp.getBlock();
1467
1468 // 2. Values defined above the region can only be broadcast for now. Make them
1469 // map to themselves.
1470 IRMapping bvm;
1471 SetVector<Value> valuesSet;
1472 mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
1473 bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
1474
1475 if (linalgOp.getNumDpsInits() == 0)
1476 return failure();
1477
1478 // 3. Turn all BBArgs into vector.transfer_read / load.
1479 Location loc = linalgOp.getLoc();
1480 Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
1481 for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1482 BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
1483 if (linalgOp.isScalar(opOperand)) {
1484 bvm.map(bbarg, opOperand->get());
1485 continue;
1486 }
1487
1488 // 3.a. Convert the indexing map for this input/output to a transfer read
1489 // permutation map and masking map.
1490 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
1491
1492 AffineMap readMap;
1493 VectorType readType;
1494 Type elemType = getElementTypeOrSelf(opOperand->get());
1495 if (linalgOp.isDpsInput(opOperand)) {
1496 // 3.a.i. For input reads we use the canonical vector shape.
1497 readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
1498 readType = state.getCanonicalVecType(elemType);
1499 } else {
1500 // 3.a.ii. For output reads (iteration-carried dependence, e.g.,
1501 // reductions), the vector shape is computed by mapping the canonical
1502 // vector shape to the output domain and back to the canonical domain.
1503 readMap = inversePermutation(reindexIndexingMap(indexingMap));
1504 readType =
1505 state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
1506 }
1507
1508 SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
1509
1510 Operation *read = vector::TransferReadOp::create(
1511 rewriter, loc, readType, opOperand->get(), indices,
1512 /*padding=*/std::nullopt, readMap);
1513 read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
1514 Value readValue = read->getResult(0);
1515
1516 // 3.b. If masked, set in-bounds to true. Masking guarantees that the access
1517 // will be in-bounds.
1518 if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
1519 SmallVector<bool> inBounds(readType.getRank(), true);
1520 cast<vector::TransferReadOp>(maskOp.getMaskableOp())
1521 .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
1522 }
1523
1524 // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
1525 // TODO: remove this.
1526 if (readType.getRank() == 0)
1527 readValue = vector::ExtractOp::create(rewriter, loc, readValue,
1529
1530 LDBG() << "New vectorized bbarg(" << bbarg.getArgNumber()
1531 << "): " << readValue;
1532 bvm.map(bbarg, readValue);
1533 bvm.map(opOperand->get(), readValue);
1534 }
1535
1537 // 4a. Register CustomVectorizationHook for yieldOp.
1538 CustomVectorizationHook vectorizeYield =
1539 [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1540 return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
1541 };
1542 hooks.push_back(vectorizeYield);
1543
1544 // 4b. Register CustomVectorizationHook for indexOp.
1545 CustomVectorizationHook vectorizeIndex =
1546 [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1547 return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
1548 };
1549 hooks.push_back(vectorizeIndex);
1550
1551 // 4c. Register CustomVectorizationHook for extractOp.
1552 CustomVectorizationHook vectorizeExtract =
1553 [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1554 return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
1555 };
1556 hooks.push_back(vectorizeExtract);
1557
1558 // 5. Iteratively call `vectorizeOneOp` to each op in the slice.
1559 for (Operation &op : block->getOperations()) {
1561 vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
1563 LDBG() << "failed to vectorize: " << op;
1564 return failure();
1565 }
1566 if (result.status == VectorizationHookStatus::NewOp) {
1567 Operation *maybeMaskedOp =
1568 state.maskOperation(rewriter, result.newOp, linalgOp);
1569 LDBG() << "New vector op: " << *maybeMaskedOp;
1570 bvm.map(op.getResults(), maybeMaskedOp->getResults());
1571 }
1572 }
1573
1574 return success();
1575}
1576
1577/// Given the re-associations, "collapses" the input Vector type
1578///
1579/// This is similar to CollapseShapeOp::inferCollapsedType with two notable
1580/// differences:
1581/// * We can safely assume that there are no dynamic sizes.
1582/// * Scalable flags are updated alongside regular dims.
1583///
1584/// When collapsing scalable flags, conservatively avoids cases with two
1585/// scalable dims. We could re-visit this in the future.
1586///
1587/// EXAMPLE:
1588/// type = vector<4x16x[8]x16xf32>
1589/// reassociation = [(d0, d1, d2, d3) -> (d0, d1),
1590/// (d0, d1, d2, d3) -> (d2, d3)]
1591/// Result:
1592/// vector<64x[128]xf32>
1593static VectorType getCollapsedVecType(VectorType type,
1594 ArrayRef<AffineMap> reassociation) {
1595 assert(type.getNumScalableDims() < 2 &&
1596 "Collapsing more than 1 scalable dim is not supported ATM");
1597
1598 // Use the fact that reassociation is valid to simplify the logic: only use
1599 // each map's rank.
1600 assert(isReassociationValid(reassociation) && "invalid reassociation");
1601
1602 auto shape = type.getShape();
1603 auto scalableFlags = type.getScalableDims();
1604 SmallVector<int64_t> newShape;
1605 SmallVector<bool> newScalableFlags;
1606
1607 unsigned currentDim = 0;
1608 for (AffineMap m : reassociation) {
1609 unsigned dim = m.getNumResults();
1610 int64_t size = 1;
1611 bool flag = false;
1612 for (unsigned d = 0; d < dim; ++d) {
1613 size *= shape[currentDim + d];
1614 flag |= scalableFlags[currentDim + d];
1615 }
1616 newShape.push_back(size);
1617 newScalableFlags.push_back(flag);
1618 currentDim += dim;
1619 }
1620
1621 return VectorType::get(newShape, type.getElementType(), newScalableFlags);
1622}
1623
1624/// Vectorize `linalg.pack` as:
1625/// * xfer_read -> shape_cast -> transpose -> xfer_write
1626///
1627/// The input-vector-sizes specify the _write_ vector sizes (i.e. the vector
1628/// sizes for the xfer_write operation). This is sufficient to infer the other
1629/// vector sizes required here.
1630///
1631/// If the vector sizes are not provided:
1632/// * the vector sizes are determined from the destination tensor static shape.
1633/// * the inBounds attribute is used instead of masking.
1634///
1635/// EXAMPLE (no vector sizes):
1636/// ```
1637/// %pack = tensor.pack %src
1638/// inner_dims_pos = [2, 1]
1639/// inner_tiles = [16, 2]
1640/// into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
1641/// ``
1642/// is vectorizes as:
1643/// ```
1644/// %read = vector.transfer_read %src
1645/// : tensor<32x7x16xf32>, vector<32x8x16xf32>
1646/// %sc = vector.shape_cast %read
1647/// : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
1648/// %tr = vector.transpose %sc, [0, 1, 3, 4, 2]
1649/// : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
1650/// %write = vector.transfer_write %tr into %dest
1651/// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
1652/// ```
static LogicalResult
vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
                        ArrayRef<int64_t> inputVectorSizes,
                        SmallVectorImpl<Value> &newResults) {
  // When vector sizes are user-provided, they must match the rank of the
  // destination tensor (i.e. the write rank).
  if (!inputVectorSizes.empty()) {
    assert(inputVectorSizes.size() == packOp.getDestRank() &&
           "Invalid number of input vector sizes!");
  }

  // TODO: Introduce a parent class that will handle the insertion point update.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(packOp);

  Location loc = packOp.getLoc();
  // Padding value is optional - only forwarded to the read when present.
  std::optional<Value> padValue = packOp.getPaddingValue()
                                      ? std::optional(packOp.getPaddingValue())
                                      : std::nullopt;

  SmallVector<int64_t> destShape =
      SmallVector<int64_t>(packOp.getDestType().getShape());

  // This is just a convenience alias to clearly communicate that the input
  // vector sizes determine the _write_ sizes.
  ArrayRef<int64_t> &writeVectorSizes = inputVectorSizes;

  // In the absence of input-vector-sizes, use the _static_ input tensor shape.
  // In addition, use the inBounds attribute instead of masking.
  bool useInBoundsInsteadOfMasking = false;
  if (writeVectorSizes.empty()) {
    if (ShapedType::isDynamicShape(destShape))
      return rewriter.notifyMatchFailure(packOp,
                                         "unable to infer vector sizes");

    writeVectorSizes = destShape;
    useInBoundsInsteadOfMasking = true;
  }

  // Compute pre-transpose-write-vector-type, i.e. the write vector type
  // _before_ the transposition (i.e. before dimension permutation). This is
  // done by inverting the permutation/transposition that's part of the Pack
  // operation. This type is required to:
  // 1) compute the read vector type for masked-read below, and
  // 2) generate shape-cast Op below that expands the read vector type.
  PackingMetadata packMetadata;
  SmallVector<int64_t> preTransposeWriteVecSizses(writeVectorSizes);
  auto destInvPermutation = getPackInverseDestPerm(packOp, packMetadata);
  applyPermutationToVector(preTransposeWriteVecSizses, destInvPermutation);
  auto preTransposeWriteVecType =
      VectorType::get(preTransposeWriteVecSizses,
                      packOp.getResult().getType().getElementType());

  // Compute vector type for the _read_ operation. This is simply
  // pre-transpose-write-vector-type with the dimensions collapsed
  // as per the Pack operation.
  VectorType readVecType = getCollapsedVecType(
      preTransposeWriteVecType,
      rewriter.getContext(), packMetadata.reassociations)));

  // Create masked TransferReadOp.
  auto maskedRead = vector::createReadOrMaskedRead(
      rewriter, loc, packOp.getSource(), readVecType, padValue,
      useInBoundsInsteadOfMasking);

  // Create ShapeCastOp - expands the read vector to the pre-transpose write
  // shape.
  auto shapeCastOp = vector::ShapeCastOp::create(
      rewriter, loc, preTransposeWriteVecType, maskedRead);

  // Create TransposeOp - applies the Pack op's dim permutation.
  auto destPermutation = invertPermutationVector(destInvPermutation);
  auto transposeOp = vector::TransposeOp::create(
      rewriter, loc, shapeCastOp.getResult(), destPermutation);

  // Create TransferWriteOp.
  Operation *write = vector::createWriteOrMaskedWrite(
      rewriter, loc, transposeOp.getResult(), packOp.getDest());
  newResults.push_back(write->getResult(0));
  return success();
}
1732
1733/// Vectorize `linalg.unpack` as:
1734/// * xfer_read -> vector.transpose -> vector.shape_cast -> xfer_write
1735///
1736/// The input-vector-sizes specify the _read_ vector sizes (i.e. the vector
1737/// sizes for the xfer_read operation). This is sufficient to infer the other
1738/// vector sizes required here.
1739///
1740/// If the vector sizes are not provided:
1741/// * the vector sizes are determined from the input tensor static shape.
1742/// * the inBounds attribute is used instead of masking.
1743///
1744/// EXAMPLE (no vector sizes):
1745/// ```
1746/// %unpack = linalg.unpack %src
1747/// inner_dims_pos = [0, 1]
1748/// inner_tiles = [8, 8]
1749/// into %dest : tensor<1x1x8x8xf32> -> tensor<8x8xf32>
1750/// ```
1751/// is vectorized as:
1752/// ```
1753/// %read = vector.transfer_read %src
1754/// : tensor<1x1x8x8xf32>, vector<1x1x8x8xf32>
1755/// %tr = vector.transpose %read, [0, 2, 1, 3]
1756/// : vector<1x1x8x8xf32> to vector<1x8x1x8xf32>
1757/// %sc = vector.shape_cast %tr
1758/// : vector<1x8x1x8xf32> to vector<8x8xf32>
1759/// %vector = vector.transfer_write %sc into %dest
1760/// : vector<8x8xf32>, tensor<8x8xf32>
1761/// ```
static LogicalResult
vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
                          ArrayRef<int64_t> inputVectorSizes,
                          ArrayRef<bool> inputScalableVecDims,
                          SmallVectorImpl<Value> &newResults) {
  // When vector sizes are user-provided, they must match the rank of the
  // source tensor (i.e. the read rank), and each size needs a matching
  // scalability flag.
  if (!inputVectorSizes.empty()) {
    assert(inputVectorSizes.size() == unpackOp.getSourceRank() &&
           "Invalid number of input vector sizes!");
    assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
           "Incompatible number of vector sizes and vector scalable flags!");
  }

  // TODO: Introduce a parent class that will handle the insertion point update.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(unpackOp);

  ShapedType unpackTensorType = unpackOp.getSourceType();

  ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
  bool useInBoundsInsteadOfMasking = false;

  Location loc = unpackOp->getLoc();

  // Obtain vector sizes for the read operation.
  SmallVector<int64_t> readVectorSizes(inputVectorSizes);
  SmallVector<bool> readScalableVectorFlags(inputScalableVecDims);

  // In the absence of input-vector-sizes, use the _static_ input tensor shape.
  if (inputVectorSizes.empty()) {
    if (ShapedType::isDynamicShape(sourceShape))
      return rewriter.notifyMatchFailure(unpackOp,
                                         "Unable to infer vector sizes!");

    readVectorSizes.assign(sourceShape.begin(), sourceShape.end());
    useInBoundsInsteadOfMasking = true;
  }

  // -- Generate the read operation --
  // No padding value: masking (or inBounds) guarantees legal accesses.
  VectorType readVecType =
      VectorType::get(readVectorSizes, unpackTensorType.getElementType(),
                      readScalableVectorFlags);
  Value readResult = vector::createReadOrMaskedRead(
      rewriter, loc, unpackOp.getSource(), readVecType, std::nullopt,
      useInBoundsInsteadOfMasking);

  // -- Generate the transpose operation --
  // Moves the inner tiles next to the dims they were split from.
  PackingMetadata packMetadata;
  SmallVector<int64_t> lastDimToInsertPosPerm =
      getUnPackInverseSrcPerm(unpackOp, packMetadata);
  vector::TransposeOp transposeOp = vector::TransposeOp::create(
      rewriter, loc, readResult, lastDimToInsertPosPerm);

  // -- Generate the shape_cast operation --
  // Collapses each (outer-dim, inner-tile) pair into a single result dim.
  VectorType collapsedVecType = getCollapsedVecType(
      transposeOp.getType(),
      rewriter.getContext(), packMetadata.reassociations)));
  vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create(
      rewriter, loc, collapsedVecType, transposeOp->getResult(0));

  // -- Generate the write operation --
  Operation *write = vector::createWriteOrMaskedWrite(
      rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
      /*writeIndices=*/{}, useInBoundsInsteadOfMasking);

  newResults.push_back(write->getResult(0));
  return success();
}
1830
1831/// Vectorize a `padOp` with (1) static result type, (2) constant padding value
1832/// and (3) all-zero lowPad to
1833/// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
1834static LogicalResult
1835vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
1836 ArrayRef<int64_t> inputVectorSizes,
1837 SmallVectorImpl<Value> &newResults) {
1838 auto padValue = padOp.getConstantPaddingValue();
1839 Location loc = padOp.getLoc();
1840
1841 // TODO: Introduce a parent class that will handle the insertion point update.
1842 OpBuilder::InsertionGuard g(rewriter);
1843 rewriter.setInsertionPoint(padOp);
1844
1845 ReifiedRankedShapedTypeDims reifiedReturnShapes;
1846 LogicalResult status =
1847 cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
1848 .reifyResultShapes(rewriter, reifiedReturnShapes);
1849 (void)status; // prevent unused variable warning on non-assert builds
1850 assert(succeeded(status) && "failed to reify result shapes");
1851 auto readType = VectorType::get(inputVectorSizes, padValue.getType());
1852 auto maskedRead = vector::createReadOrMaskedRead(
1853 rewriter, loc, padOp.getSource(), readType, padValue,
1854 /*useInBoundsInsteadOfMasking=*/false);
1855
1856 // Create Xfer write Op
1857 Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0],
1858 padOp.getResultType().getElementType());
1859 Operation *write =
1860 vector::createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest);
1861 newResults.push_back(write->getResult(0));
1862 return success();
1863}
1864
1865// TODO: probably need some extra checks for reduction followed by consumer
1866// ops that may not commute (e.g. linear reduction + non-linear instructions).
1867static LogicalResult reductionPreconditions(LinalgOp op) {
1868 if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
1869 LDBG() << "reduction precondition failed: no reduction iterator";
1870 return failure();
1871 }
1872 for (OpOperand &opOperand : op.getDpsInitsMutable()) {
1873 AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
1874 if (indexingMap.isPermutation())
1875 continue;
1876
1877 Operation *reduceOp = matchLinalgReduction(&opOperand);
1878 if (!reduceOp || !getCombinerOpKind(reduceOp)) {
1879 LDBG() << "reduction precondition failed: reduction detection failed";
1880 return failure();
1881 }
1882 }
1883 return success();
1884}
1885
static LogicalResult
vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
                                   bool flatten1DDepthwiseConv) {
  // Flattening requires static shapes, so reject the combination outright.
  if (flatten1DDepthwiseConv) {
    LDBG() << "Vectorization of flattened convs with dynamic shapes is not "
              "supported";
    return failure();
  }

    LDBG() << "Not a 1D depth-wise WC conv, dynamic shapes are not supported";
    return failure();
  }

  // Support dynamic shapes in 1D depthwise convolution, but only in the
  // _channel_ dimension.
  Value lhs = conv.getDpsInputOperand(0)->get();
  ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
  // Drop the trailing (channel) dim; all remaining dims must be static.
  auto shapeWithoutCh = lhsShape.drop_back(1);
  if (ShapedType::isDynamicShape(shapeWithoutCh)) {
    LDBG() << "Dynamically-shaped op vectorization precondition failed: only "
              "channel dim can be dynamic";
    return failure();
  }

  return success();
}
1913
static LogicalResult
vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
                                     bool flatten1DDepthwiseConv) {
  // Convolutions with dynamic shapes have their own, stricter preconditions.
    return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);

  // Dynamically-shaped reductions are gated on the generic reduction
  // preconditions.
  if (hasReductionIterator(op))
    return reductionPreconditions(op);

  // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
  // linalg.copy ops and ops that implement ContractionOpInterface for now.
  if (!isElementwise(op) &&
      !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
          op.getOperation()))
    return failure();

  LDBG() << "Dynamically-shaped op meets vectorization pre-conditions";
  return success();
}
1933
/// This hook considers two cases:
/// (1) If the input-vector-sizes are empty, then the vector sizes will be
/// inferred. This is only possible when all shapes are static.
1937/// (2) If the input-vector-sizes are non-empty (i.e. user provided), then
1938/// carry out basic sanity-checking.
static LogicalResult
vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
                              ArrayRef<int64_t> inputVectorSizes) {
  // TODO: Support Memref UnPackOp. Temporarily return failure.
  if (!unpackOp.hasPureTensorSemantics())
    return failure();

  // If there are no input vector sizes and all shapes are static, there is
  // nothing left to check.
  if (inputVectorSizes.empty() && unpackOp.getDestType().hasStaticShape() &&
      unpackOp.getSourceType().hasStaticShape())
    return success();

  // The number of input vector sizes must be equal to:
  //  * read-vector-rank
  // (i.e. the rank of the source tensor being read).
  if (!inputVectorSizes.empty() &&
      (inputVectorSizes.size() != unpackOp.getSourceRank())) {
    LDBG() << "Incorrect number of input vector sizes";
    return failure();
  }

  // Check the vector sizes for the read operation.
      unpackOp.getSourceType().getShape(), inputVectorSizes))) {
    LDBG() << "Invalid vector sizes for the read operation";
    return failure();
  }

  return success();
}
1969
1970static LogicalResult
1971vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
1972 ArrayRef<int64_t> inputVectorSizes) {
1973
1974 TypedValue<RankedTensorType> source = sliceOp.getSource();
1975 auto sourceType = source.getType();
1976 if (!VectorType::isValidElementType(sourceType.getElementType()))
1977 return failure();
1978
1979 // Get the pad value.
1980 // TransferReadOp (which is used to vectorize InsertSliceOp), requires a
1981 // scalar padding value. Note that:
1982 // * for in-bounds accesses,
1983 // the value is actually irrelevant. There are 2 cases in which xfer.read
1984 // accesses are known to be in-bounds:
1985 // 1. The source shape is static (output vector sizes would be based on
1986 // the source shape and hence all memory accesses would be in-bounds),
1987 // 2. Masking is used, i.e. the output vector sizes are user-provided. In
1988 // this case it is safe to assume that all memory accesses are in-bounds.
1989 //
1990 // When the value is not known and not needed, use 0. Otherwise, bail out.
1991 Value padValue = getStaticPadVal(sliceOp);
1992 bool isOutOfBoundsRead =
1993 !sourceType.hasStaticShape() && inputVectorSizes.empty();
1994
1995 if (!padValue && isOutOfBoundsRead) {
1996 LDBG() << "Failed to get a pad value for out-of-bounds read access";
1997 return failure();
1998 }
1999 return success();
2000}
2001
2002/// Vectorize a named linalg contraction op into:
2003/// vector::TransferReadOp - Reads vectors from the operands
2004/// vector::ContractionOp - Performs contraction
2005/// vector::TransferWriteOp - Write the result vector back to the
2006/// destination
2007/// The operands shapes are preserved and loaded directly into vectors.
2008/// Any further permutations or numerical casting remain within contraction op.
static LogicalResult
vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
                             LinalgOp linalgOp,
                             SmallVectorImpl<Value> &newResults) {
  Location loc = linalgOp.getLoc();
  MLIRContext *ctx = linalgOp.getContext();

  // For simplicity, contraction vectorization is limited to linalg named ops.
  // Generic op is ignored as not every arbitrary contraction body can be
  // expressed by a vector.contract.
  if (!isa<ContractionOpInterface>(linalgOp.getOperation()))
    return failure();

  // The combining kind (e.g. add/mul) is derived from the op reducing into
  // the (single) init operand.
  OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
  Operation *reduceOp = matchLinalgReduction(outOperand);
  auto maybeKind = getCombinerOpKind(reduceOp);
  if (!maybeKind) {
    LDBG() << "Failed to determine contraction combining kind.";
    return failure();
  }

  // Check that all dimensions are present in the input operands.
  // Arbitrary broadcasts are not supported by the vector contraction.
  // Broadcasts are expected to be decomposed before vectorization.
  AffineMap lhsMap = linalgOp.getIndexingMapsArray()[0];
  AffineMap rhsMap = linalgOp.getIndexingMapsArray()[1];
  if (getUnusedDimsBitVector({lhsMap, rhsMap}).any()) {
    LDBG() << "Contractions with broadcasts are not supported.";
    return failure();
  }

  // Load operands.
  SmallVector<Value> vecOperands;
  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
    // The operand vector shape is computed by mapping the canonical vector
    // shape to the operand's domain. Further permutations are left as a part of
    // the contraction.
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
    AffineMap readMap = AffineMap::getMultiDimIdentityMap(
        indexingMap.getNumResults(), rewriter.getContext());
    Type elemType = getElementTypeOrSelf(opOperand.get());
    VectorType readType =
        state.getCanonicalVecType(elemType, readMap.compose(indexingMap));

        rewriter, loc, opOperand.get(), readType,
        /*padding=*/arith::getZeroConstant(rewriter, loc, elemType),
        /*useInBoundsInsteadOfMasking=*/false);
    vecOperands.push_back(read);
  }

  // Remap iterators from linalg to vector.
  SmallVector<Attribute> iterAttrs;
  auto iterators = linalgOp.getIteratorTypesArray();
  for (utils::IteratorType iter : iterators) {
    // linalg only distinguishes parallel vs reduction; map accordingly.
    auto vecIter = iter == utils::IteratorType::parallel
                       ? vector::IteratorType::parallel
                       : vector::IteratorType::reduction;
    iterAttrs.push_back(vector::IteratorTypeAttr::get(ctx, vecIter));
  }

  // Create contraction.
  Operation *contractOp = vector::ContractionOp::create(
      rewriter, loc, /*lhs=*/vecOperands[0],
      /*rhs=*/vecOperands[1], /*acc=*/vecOperands[2],
      linalgOp.getIndexingMaps(), rewriter.getArrayAttr(iterAttrs), *maybeKind);
  contractOp = state.maskOperation(rewriter, contractOp, linalgOp);

  // Store result.
  Operation *write = vector::createWriteOrMaskedWrite(
      rewriter, loc, contractOp->getResult(0), outOperand->get());

  // Finalize.
  if (!write->getResults().empty())
    newResults.push_back(write->getResult(0));

  return success();
}
2087
namespace {
/// Classifies the payload of a (generic) linalg op as either a convolution
/// or a pooling operation during conv vectorization.
enum class ConvOperationKind { Conv, Pool };
} // namespace
2091
2092static bool isCastOfBlockArgument(Operation *op) {
2093 return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
2094 isa<BlockArgument>(op->getOperand(0));
2095}
2096
2097// Returns the ConvOperationKind of the op using reduceOp of the generic
2098// payload. If it is neither a convolution nor a pooling, it returns
2099// std::nullopt.
2100//
2101// If (region has 2 ops (reduction + yield) or 3 ops (extension + reduction
2102// + yield) and rhs is not used) then it is the body of a pooling
2103// If conv, check for single `mul` predecessor. The `mul` operands must be
2104// block arguments or extension of block arguments.
2105// Otherwise, check for one or zero `ext` predecessor. The `ext` operands
2106// must be block arguments or extension of block arguments.
2107static std::optional<ConvOperationKind>
2108getConvOperationKind(Operation *reduceOp) {
2109 int numBlockArguments =
2110 llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
2111
2112 switch (numBlockArguments) {
2113 case 1: {
2114 // Will be convolution if feeder is a MulOp.
2115 // A strength reduced version of MulOp for i1 type is AndOp which is also
2116 // supported. Otherwise, it can be pooling. This strength reduction logic
2117 // is in `buildBinaryFn` helper in the Linalg dialect.
2118 auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
2119 llvm::IsaPred<BlockArgument>);
2120 assert(feedValIt != reduceOp->operand_end() &&
2121 "Expected a non-block argument operand");
2122 Operation *feedOp = (*feedValIt).getDefiningOp();
2123 if (isCastOfBlockArgument(feedOp)) {
2124 return ConvOperationKind::Pool;
2125 }
2126
2127 if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
2128 (isa<arith::AndIOp>(feedOp) &&
2129 feedOp->getResultTypes()[0].isInteger(1))) &&
2130 llvm::all_of(feedOp->getOperands(), [](Value v) {
2131 if (isa<BlockArgument>(v))
2132 return true;
2133 if (Operation *op = v.getDefiningOp())
2134 return isCastOfBlockArgument(op);
2135 return false;
2136 }))) {
2137 return std::nullopt;
2138 }
2139
2140 return ConvOperationKind::Conv;
2141 }
2142 case 2:
2143 // Must be pooling
2144 return ConvOperationKind::Pool;
2145 default:
2146 return std::nullopt;
2147 }
2148}
2149
2150static bool isSupportedPoolKind(vector::CombiningKind kind) {
2151 switch (kind) {
2152 case vector::CombiningKind::ADD:
2153 case vector::CombiningKind::MAXNUMF:
2154 case vector::CombiningKind::MAXIMUMF:
2155 case vector::CombiningKind::MAXSI:
2156 case vector::CombiningKind::MAXUI:
2157 case vector::CombiningKind::MINNUMF:
2158 case vector::CombiningKind::MINIMUMF:
2159 case vector::CombiningKind::MINSI:
2160 case vector::CombiningKind::MINUI:
2161 return true;
2162 default:
2163 return false;
2164 }
2165}
2166
2167static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
2168 auto getOperandType = [&](auto operand) {
2169 return dyn_cast<ShapedType>((operand->get()).getType());
2170 };
2171 ShapedType lhsShapedType = getOperandType(convOp.getDpsInputOperand(0));
2172 ShapedType rhsShapedType = getOperandType(convOp.getDpsInputOperand(1));
2173 ShapedType resShapedType = getOperandType(convOp.getDpsInitOperand(0));
2174 // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
2175 // (non-channeled convolution -> LHS and RHS both have single dimensions).
2176 // Note that this also ensures 2D and 3D convolutions are rejected.
2177 if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
2178 (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
2179 return failure();
2180
2181 Operation *reduceOp = matchLinalgReduction(convOp.getDpsInitOperand(0));
2182 if (!reduceOp)
2183 return failure();
2184
2185 auto maybeOper = getConvOperationKind(reduceOp);
2186 if (!maybeOper.has_value())
2187 return failure();
2188
2189 auto maybeKind = getCombinerOpKind(reduceOp);
2190 // Typically convolution will have a `Add` CombiningKind but for i1 type it
2191 // can get strength reduced to `OR` which is also supported. This strength
2192 // reduction logic is in `buildBinaryFn` helper in the Linalg dialect.
2193 if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
2194 *maybeKind != vector::CombiningKind::OR) &&
2195 (*maybeOper != ConvOperationKind::Pool ||
2196 !isSupportedPoolKind(*maybeKind)))) {
2197 return failure();
2198 }
2199
2200 auto rhsRank = rhsShapedType.getRank();
2201 if (*maybeOper == ConvOperationKind::Pool) {
2202 if (rhsRank != 1)
2203 return failure();
2204 } else {
2205 if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
2206 return failure();
2207 }
2208
2209 return success();
2210}
2211
2212static LogicalResult vectorizeLinalgOpPrecondition(
2213 LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
2214 bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
2215 // tensor with dimension of 0 cannot be vectorized.
2216 if (llvm::any_of(linalgOp->getOpOperands(), [&](OpOperand &operand) {
2217 return llvm::is_contained(linalgOp.getShape(&operand), 0);
2218 }))
2219 return failure();
2220 // Check API contract for input vector sizes.
2221 if (!inputVectorSizes.empty() &&
2222 failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
2223 inputVectorSizes)))
2224 return failure();
2225
2226 if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition(
2227 linalgOp, flatten1DDepthwiseConv))) {
2228 LDBG() << "Dynamically-shaped op failed vectorization pre-conditions";
2229 return failure();
2230 }
2231
2232 SmallVector<CustomVectorizationPrecondition> customPreconditions;
2233
2234 // Register CustomVectorizationPrecondition for extractOp.
2235 customPreconditions.push_back(tensorExtractVectorizationPrecondition);
2236
2237 // All types in the body should be a supported element type for VectorType.
2238 for (Operation &innerOp : linalgOp->getRegion(0).front()) {
2239 // Check if any custom hook can vectorize the inner op.
2240 if (llvm::any_of(
2241 customPreconditions,
2242 [&](const CustomVectorizationPrecondition &customPrecondition) {
2243 return succeeded(
2244 customPrecondition(&innerOp, vectorizeNDExtract));
2245 })) {
2246 continue;
2247 }
2248 if (!llvm::all_of(innerOp.getOperandTypes(),
2249 VectorType::isValidElementType)) {
2250 return failure();
2251 }
2252 if (!llvm::all_of(innerOp.getResultTypes(),
2253 VectorType::isValidElementType)) {
2254 return failure();
2255 }
2256 }
2257 if (isElementwise(linalgOp))
2258 return success();
2259
2260 // Check for both named as well as generic convolution ops.
2261 if (isaConvolutionOpInterface(linalgOp))
2262 return vectorizeConvOpPrecondition(linalgOp);
2263
2264 // TODO: the common vector shape is equal to the static loop sizes only when
2265 // all indexing maps are projected permutations. For convs and stencils the
2266 // logic will need to evolve.
2267 if (!allIndexingsAreProjectedPermutation(linalgOp)) {
2268 LDBG() << "precondition failed: not projected permutations";
2269 return failure();
2270 }
2271 if (failed(reductionPreconditions(linalgOp))) {
2272 LDBG() << "precondition failed: reduction preconditions";
2273 return failure();
2274 }
2275 return success();
2276}
2277
static LogicalResult
vectorizePackOpPrecondition(linalg::PackOp packOp,
                            ArrayRef<int64_t> inputVectorSizes) {
  // TODO: Support Memref PackOp. Temporarily return failure.
  if (!packOp.hasPureTensorSemantics())
    return failure();

  // If present, the padding value must be a compile-time constant.
  auto padValue = packOp.getPaddingValue();
  Attribute cstAttr;
  // TODO: Relax this condition
  if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
    LDBG() << "pad value is not constant: " << packOp;
    return failure();
  }

  // Without user-provided vector sizes, both source and destination shapes
  // must be fully static so the sizes can be inferred.
  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
  bool satisfyEmptyCond = true;
  if (inputVectorSizes.empty()) {
    if (!packOp.getDestType().hasStaticShape() ||
        !packOp.getSourceType().hasStaticShape())
      satisfyEmptyCond = false;
  }

  if (!satisfyEmptyCond &&
          resultTensorShape.take_front(packOp.getSourceRank()),
          inputVectorSizes)))
    return failure();

  // Inner tile sizes must be compile-time constants.
  if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
        return !getConstantIntValue(v).has_value();
      })) {
    LDBG() << "inner_tiles must be constant: " << packOp;
    return failure();
  }

  return success();
}
2316
2317static LogicalResult
2318vectorizePadOpPrecondition(tensor::PadOp padOp,
2319 ArrayRef<int64_t> inputVectorSizes) {
2320 auto padValue = padOp.getConstantPaddingValue();
2321 if (!padValue) {
2322 LDBG() << "pad value is not constant: " << padOp;
2323 return failure();
2324 }
2325
2326 ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
2327 if (failed(vector::isValidMaskedInputVector(resultTensorShape,
2328 inputVectorSizes)))
2329 return failure();
2330
2331 // Padding with non-zero low pad values is not supported, unless the
2332 // corresponding result dim is 1 as this would require shifting the results to
2333 // the right for the low padded dims by the required amount of low padding.
2334 // However, we do support low padding if the dims being low padded have result
2335 // sizes of 1. The reason is when we have a low pad on a unit result dim, the
2336 // input size of that dimension will be dynamically zero (as the sum of the
2337 // low pad and input dim size has to be one) and hence we will create a zero
2338 // mask as the lowering logic just makes the mask one for the input dim size -
2339 // which is zero here. Hence we will load the pad value which is what we want
2340 // in this case. If the low pad is dynamically zero then the lowering is
2341 // correct as well as no shifts are necessary.
2342 if (llvm::any_of(llvm::enumerate(padOp.getMixedLowPad()),
2343 [&](const auto &en) {
2344 OpFoldResult padValue = en.value();
2345 unsigned pos = en.index();
2346 std::optional<int64_t> pad = getConstantIntValue(padValue);
2347 return (!pad.has_value() || pad.value() != 0) &&
2348 resultTensorShape[pos] != 1;
2349 })) {
2350 LDBG() << "low pad must all be zero for all non unit dims: " << padOp;
2351 return failure();
2352 }
2353
2354 return success();
2355}
2356
2357/// Preconditions for scalable vectors.
2358///
2359/// For Ops implementing the LinalgOp interface, this is quite restrictive - it
2360/// models the fact that in practice we would only make selected dimensions
2361/// scalable. For other Ops (e.g. `linalg.unpack`), this will succeed
2362/// unconditionally - we are yet to identify meaningful conditions.
2363static LogicalResult
2364vectorizeScalableVectorPrecondition(Operation *op,
2365 ArrayRef<int64_t> inputVectorSizes,
2366 ArrayRef<bool> inputScalableVecDims) {
2367 assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
2368 "Number of input vector sizes and scalable dims doesn't match");
2369
2370 size_t numOfScalableDims =
2371 llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });
2372
2373 if (numOfScalableDims == 0)
2374 return success();
2375
2376 auto linalgOp = dyn_cast<LinalgOp>(op);
2377
2378 // Cond 1: Reject Ops that don't implement the LinalgOp interface, with the
2379 // exception of UnpackOp for which there is a dedicated hook.
2380 if (!linalgOp) {
2381 return success(isa<linalg::UnPackOp>(op));
2382 }
2383
2384 // Cond 2: There's been no need for more than 2 scalable dims so far
2385 if (numOfScalableDims > 2)
2386 return failure();
2387
2388 // Cond 3: Look at the configuration in `inputScalableVecDims` and verify that
2389 // it matches one of the supported cases:
2390 // 1. Exactly 1 dim is scalable and that's the _last_ non-unit parallel dim
2391 // (*).
2392 // 2. Exactly 2 dims are scalable and those are the _last two adjacent_
2393 // parallel dims.
2394 // 3. Exactly 1 reduction dim is scalable and that's the last (innermost)
2395 // dim.
2396 // The 2nd restriction above means that only Matmul-like Ops are supported
2397 // when 2 dims are scalable, e.g. :
2398 // * iterators = [parallel, parallel, reduction]
2399 // * scalable flags = [true, true, false]
2400 //
2401 // (*) Non-unit dims get folded away in practice.
2402 // TODO: Relax these conditions as good motivating examples are identified.
2403
2404 // Find the first scalable flag.
2405 bool seenNonUnitParallel = false;
2406 auto iterators = linalgOp.getIteratorTypesArray();
2407 SmallVector<bool> scalableFlags(inputScalableVecDims);
2408 int64_t idx = scalableFlags.size() - 1;
2409 while (!scalableFlags[idx]) {
2410 bool isNonUnitDim = (inputVectorSizes[idx] != 1);
2411 seenNonUnitParallel |=
2412 (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim);
2413
2414 iterators.pop_back();
2415 scalableFlags.pop_back();
2416 --idx;
2417 }
2418
2419 // Analyze the iterator corresponding to the first scalable dim.
2420 switch (iterators.back()) {
2421 case utils::IteratorType::reduction: {
2422 // Check 3. above is met.
2423 if (iterators.size() != inputVectorSizes.size()) {
2424 LDBG() << "Non-trailing reduction dim requested for scalable "
2425 "vectorization";
2426 return failure();
2427 }
2428 if (isa<linalg::MatmulOp>(op)) {
2429 LDBG()
2430 << "Scalable vectorization of the reduction dim in Matmul-like ops "
2431 "is not supported";
2432 return failure();
2433 }
2434 break;
2435 }
2436 case utils::IteratorType::parallel: {
2437 // Check 1. and 2. above are met.
2438 if (seenNonUnitParallel) {
2439 LDBG() << "Inner parallel dim not requested for scalable "
2440 "vectorization";
2441 return failure();
2442 }
2443 break;
2444 }
2445 }
2446
2447 // If present, check the 2nd scalable dim. ATM, only Matmul-like Ops are
2448 // supported for which expect the folowing config:
2449 // * iterators = [parallel, parallel, reduction]
2450 // * scalable flags = [true, true, false]
2451 if (numOfScalableDims == 2) {
2452 // Disallow below case which breaks 3. above:
2453 // * iterators = [..., parallel, reduction]
2454 // * scalable flags = [..., true, true]
2455 if (iterators.back() == utils::IteratorType::reduction) {
2456 LDBG() << "Higher dim than the trailing reduction dim requested for "
2457 "scalable "
2458 "vectorizatio";
2459 return failure();
2460 }
2461 scalableFlags.pop_back();
2462 iterators.pop_back();
2463
2464 if (!scalableFlags.back() ||
2465 (iterators.back() != utils::IteratorType::parallel))
2466 return failure();
2467 }
2468
2469 // Cond 4: Only the following ops are supported in the
2470 // presence of scalable vectors
2471 return success(
2472 isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
2473 isa<linalg::BatchMatmulOp>(op) ||
2475 isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
2476 isa<linalg::BatchMmt4DOp>(op) || hasReductionIterator(linalgOp));
2477}
2478
    Operation *op, ArrayRef<int64_t> inputVectorSizes,
    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
    bool flatten1DDepthwiseConv) {

  // Only ops with a dedicated vectorization implementation may proceed.
  if (!hasVectorizationImpl(op))
    return failure();

  // Scalable-vector configurations are checked separately; see
  // vectorizeScalableVectorPrecondition for the supported cases.
  if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
                                                 inputScalableVecDims)))
    return failure();

  // Dispatch to the op-specific precondition hook.
      .Case([&](linalg::LinalgOp linalgOp) {
        return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
                                             vectorizeNDExtract,
                                             flatten1DDepthwiseConv);
      })
      .Case([&](tensor::PadOp padOp) {
        return vectorizePadOpPrecondition(padOp, inputVectorSizes);
      })
      .Case([&](linalg::PackOp packOp) {
        return vectorizePackOpPrecondition(packOp, inputVectorSizes);
      })
      .Case([&](linalg::UnPackOp unpackOp) {
        return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
      })
      .Case([&](tensor::InsertSliceOp sliceOp) {
        return vectorizeInsertSliceOpPrecondition(sliceOp, inputVectorSizes);
      })
      .Default(failure());
}
2511
2512/// Converts affine.apply Ops to arithmetic operations.
2513static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
2514 OpBuilder::InsertionGuard g(rewriter);
2515 auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
2516
2517 for (auto op : make_early_inc_range(toReplace)) {
2518 rewriter.setInsertionPoint(op);
2519 auto expanded = affine::expandAffineExpr(
2520 rewriter, op->getLoc(), op.getAffineMap().getResult(0),
2521 op.getOperands().take_front(op.getAffineMap().getNumDims()),
2522 op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
2523 rewriter.replaceOp(op, expanded);
2524 }
2525}
2526
2527bool mlir::linalg::hasVectorizationImpl(Operation *op) {
2528 return isa<linalg::LinalgOp, tensor::PadOp, linalg::PackOp, linalg::UnPackOp,
2529 tensor::InsertSliceOp>(op);
2530}
2531
FailureOr<VectorizationResult> mlir::linalg::vectorize(
    RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
    bool flatten1DDepthwiseConv, bool assumeDynamicDimsMatchVecSizes,
    bool createNamedContraction) {
  LDBG() << "Attempting to vectorize: " << *op;
  LDBG() << "Input vector sizes: " << llvm::interleaved(inputVectorSizes);
  LDBG() << "Input scalable vector dims: "
         << llvm::interleaved(inputScalableVecDims);

  // Bail out before any IR is modified if the op or the requested
  // configuration is unsupported.
  if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
                                     vectorizeNDExtract,
                                     flatten1DDepthwiseConv))) {
    LDBG() << "Vectorization pre-conditions failed";
    return failure();
  }

  // Initialize vectorization state.
  VectorizationState state(rewriter);
  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
    if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
                               inputScalableVecDims,
                               assumeDynamicDimsMatchVecSizes))) {
      LDBG() << "Vectorization state couldn't be initialized";
      return failure();
    }
  }

  // Dispatch on the op kind; each case appends the vectorized op's results
  // to `results`.
  SmallVector<Value> results;
  auto vectorizeResult =
          .Case([&](linalg::LinalgOp linalgOp) {
            // Check for both named as well as generic convolution ops.
            if (isaConvolutionOpInterface(linalgOp)) {
              FailureOr<Operation *> convOr = vectorizeConvolution(
                  rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
                  flatten1DDepthwiseConv);
              if (succeeded(convOr)) {
                llvm::append_range(results, (*convOr)->getResults());
                return success();
              }

              LDBG() << "Unsupported convolution can't be vectorized.";
              return failure();
            }

            if (createNamedContraction &&
                isa<ContractionOpInterface>(linalgOp.getOperation()))
              return vectorizeAsLinalgContraction(rewriter, state, linalgOp,
                                                  results);

            LDBG()
                << "Vectorize generic by broadcasting to the canonical vector "
                   "shape";

            // Pre-process before proceeding.
            convertAffineApply(rewriter, linalgOp);

            // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
            // to 'OpBuilder' when it is passed over to some methods like
            // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
            // erase an op within these methods, the actual rewriter won't be
            // notified and we will end up with read-after-free issues!
            return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
          })
          .Case([&](tensor::PadOp padOp) {
            return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
                                          results);
          })
          .Case([&](linalg::PackOp packOp) {
            return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
                                           results);
          })
          .Case([&](linalg::UnPackOp unpackOp) {
            return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
                                             inputVectorSizes,
                                             inputScalableVecDims, results);
          })
          .Case([&](tensor::InsertSliceOp sliceOp) {
            return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
                                            results);
          })
          .Default(failure());

  if (failed(vectorizeResult)) {
    LDBG() << "Vectorization failed";
    return failure();
  }

  return VectorizationResult{results};
}
2623
2624LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
2625 memref::CopyOp copyOp) {
2626 auto srcType = cast<MemRefType>(copyOp.getSource().getType());
2627 auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
2628 if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
2629 return failure();
2630
2631 auto srcElementType = getElementTypeOrSelf(srcType);
2632 auto dstElementType = getElementTypeOrSelf(dstType);
2633 if (!VectorType::isValidElementType(srcElementType) ||
2634 !VectorType::isValidElementType(dstElementType))
2635 return failure();
2636
2637 auto readType = VectorType::get(srcType.getShape(), srcElementType);
2638 auto writeType = VectorType::get(dstType.getShape(), dstElementType);
2639
2640 Location loc = copyOp->getLoc();
2641 Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
2642 SmallVector<Value> indices(srcType.getRank(), zero);
2643
2644 Value readValue = vector::TransferReadOp::create(
2645 rewriter, loc, readType, copyOp.getSource(), indices,
2646 /*padding=*/std::nullopt,
2647 rewriter.getMultiDimIdentityMap(srcType.getRank()));
2648 if (cast<VectorType>(readValue.getType()).getRank() == 0) {
2649 readValue = vector::ExtractOp::create(rewriter, loc, readValue,
2650 ArrayRef<int64_t>());
2651 readValue =
2652 vector::BroadcastOp::create(rewriter, loc, writeType, readValue);
2653 }
2654 Operation *writeValue = vector::TransferWriteOp::create(
2655 rewriter, loc, readValue, copyOp.getTarget(), indices,
2656 rewriter.getMultiDimIdentityMap(srcType.getRank()));
2657 rewriter.replaceOp(copyOp, writeValue->getResults());
2658 return success();
2659}
2660
2661//----------------------------------------------------------------------------//
2662// Misc. vectorization patterns.
2663//----------------------------------------------------------------------------//
2664/// Base pattern for rewriting tensor::PadOps whose result is consumed by a
2665/// given operation type OpTy.
2666template <typename OpTy>
2667struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
2668 using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
2669
2670 LogicalResult matchAndRewrite(tensor::PadOp padOp,
2671 PatternRewriter &rewriter) const final {
2672 bool changed = false;
2673 // Insert users in vector, because some users may be replaced/removed.
2674 for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
2675 if (auto op = dyn_cast<OpTy>(user))
2676 changed |= rewriteUser(rewriter, padOp, op).succeeded();
2677 return success(changed);
2678 }
2679
2680protected:
2681 virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
2682 tensor::PadOp padOp, OpTy op) const = 0;
2683};
2684
2685/// Rewrite use of tensor::PadOp result in TransferReadOp. E.g.:
2686/// ```
2687/// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2688/// %r = vector.transfer_read %0[%c0, %c0], %cst
2689/// {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32>
2690/// ```
2691/// is rewritten to:
2692/// ```
2693/// %r = vector.transfer_read %src[%c0, %c0], %padding
2694/// {in_bounds = [true, true]}
2695/// : tensor<?x?xf32>, vector<17x5xf32>
2696/// ```
2697/// Note: By restricting this pattern to in-bounds TransferReadOps, we can be
2698/// sure that the original padding value %cst was never used.
2699///
2700/// This rewrite is possible if:
2701/// - `xferOp` has no out-of-bounds dims or mask.
2702/// - Low padding is static 0.
2703/// - Single, scalar padding value.
2704struct PadOpVectorizationWithTransferReadPattern
2705 : public VectorizePadOpUserPattern<vector::TransferReadOp> {
2706 using VectorizePadOpUserPattern<
2707 vector::TransferReadOp>::VectorizePadOpUserPattern;
2708
2709 LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2710 vector::TransferReadOp xferOp) const override {
2711 // Low padding must be static 0.
2712 if (!padOp.hasZeroLowPad())
2713 return failure();
2714 // Pad value must be a constant.
2715 auto padValue = padOp.getConstantPaddingValue();
2716 if (!padValue)
2717 return failure();
2718 // Padding value of existing `xferOp` is unused.
2719 if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
2720 return failure();
2721
2722 rewriter.modifyOpInPlace(xferOp, [&]() {
2723 SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2724 xferOp->setAttr(xferOp.getInBoundsAttrName(),
2725 rewriter.getBoolArrayAttr(inBounds));
2726 xferOp.getBaseMutable().assign(padOp.getSource());
2727 xferOp.getPaddingMutable().assign(padValue);
2728 });
2729
2730 return success();
2731 }
2732};
2733
2734/// Rewrite use of tensor::PadOp result in TransferWriteOp.
2735/// This pattern rewrites TransferWriteOps that write to a padded tensor
2736/// value, where the same amount of padding is immediately removed again after
2737/// the write. In such cases, the TransferWriteOp can write to the non-padded
2738/// tensor value and apply out-of-bounds masking. E.g.:
2739/// ```
2740/// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2741/// : tensor<...> to tensor<?x?xf32>
2742/// %1 = tensor.pad %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
2743/// %2 = vector.transfer_write %vec, %1[...]
2744/// : vector<17x5xf32>, tensor<17x5xf32>
2745/// %r = tensor.extract_slice %2[0, 0] [%s0, %s1] [1, 1]
2746/// : tensor<17x5xf32> to tensor<?x?xf32>
2747/// ```
2748/// is rewritten to:
2749/// ```
2750/// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2751/// : tensor<...> to tensor<?x?xf32>
2752/// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>,
2753/// tensor<?x?xf32>
2754/// ```
2755/// Note: It is important that the ExtractSliceOp %r resizes the result of the
2756/// TransferWriteOp to the same size as the input of the TensorPadOp (or an
2757/// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
2758/// from %r's old dimensions.
2759///
2760/// This rewrite is possible if:
2761/// - Low padding is static 0.
2762/// - `xferOp` has exactly one use, which is an ExtractSliceOp. This
2763/// ExtractSliceOp trims the same amount of padding that was added
2764/// beforehand.
2765/// - Single, scalar padding value.
struct PadOpVectorizationWithTransferWritePattern
    : public VectorizePadOpUserPattern<vector::TransferWriteOp> {
  using VectorizePadOpUserPattern<
      vector::TransferWriteOp>::VectorizePadOpUserPattern;

  LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
                            vector::TransferWriteOp xferOp) const override {
    // TODO: support 0-d corner case.
    if (xferOp.getTransferRank() == 0)
      return failure();

    // Low padding must be static 0.
    if (!padOp.hasZeroLowPad())
      return failure();
    // Pad value must be a constant.
    auto padValue = padOp.getConstantPaddingValue();
    if (!padValue)
      return failure();
    // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
    if (!xferOp->hasOneUse())
      return failure();
    auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
    if (!trimPadding)
      return failure();
    // Only static zero offsets supported when trimming padding.
    if (!trimPadding.hasZeroOffset())
      return failure();
    // trimPadding must remove the amount of padding that was added earlier.
    if (!hasSameTensorSize(padOp.getSource(), trimPadding))
      return failure();

    // Insert the new TransferWriteOp at position of the old TransferWriteOp.
    rewriter.setInsertionPoint(xferOp);

    // The write now targets the un-padded source, so every dim is marked
    // out-of-bounds (masked) to stay within the smaller tensor.
    SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
    auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        xferOp, padOp.getSource().getType(), xferOp.getVector(),
        padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
        xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
    // The trimming ExtractSliceOp becomes a no-op: forward the new write's
    // result to its uses.
    rewriter.replaceOp(trimPadding, newXferOp->getResult(0));

    return success();
  }

  /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
  /// i.e., same dimensions.
  ///
  /// Dimensions may be static, dynamic or mix of both. In case of dynamic
  /// dimensions, this function tries to infer the (static) tensor size by
  /// looking at the defining op and utilizing op-specific knowledge.
  ///
  /// This is a conservative analysis. In case equal tensor sizes cannot be
  /// proven statically, this analysis returns `false` even though the tensor
  /// sizes may turn out to be equal at runtime.
  bool hasSameTensorSize(Value beforePadding,
                         tensor::ExtractSliceOp afterTrimming) const {
    // If the input to tensor::PadOp is a CastOp, try with both CastOp
    // result and CastOp operand.
    if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
      if (hasSameTensorSize(castOp.getSource(), afterTrimming))
        return true;

    auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
    auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
    // Only RankedTensorType supported.
    if (!t1 || !t2)
      return false;
    // Rank of both values must be the same.
    if (t1.getRank() != t2.getRank())
      return false;

    // All static dimensions must be the same. Mixed cases (e.g., dimension
    // static in `t1` but dynamic in `t2`) are not supported.
    for (unsigned i = 0; i < t1.getRank(); ++i) {
      if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
        return false;
      if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
        return false;
    }

    // Nothing more to check if all dimensions are static.
    if (t1.getNumDynamicDims() == 0)
      return true;

    // All dynamic sizes must be the same. The only supported case at the
    // moment is when `beforePadding` is an ExtractSliceOp (or a cast
    // thereof).

    // Apart from CastOp, only ExtractSliceOp is supported.
    auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
    if (!beforeSlice)
      return false;

    assert(static_cast<size_t>(t1.getRank()) ==
           beforeSlice.getMixedSizes().size());
    assert(static_cast<size_t>(t2.getRank()) ==
           afterTrimming.getMixedSizes().size());

    // Compare the dynamic sizes of the two slices dim by dim.
    for (unsigned i = 0; i < t1.getRank(); ++i) {
      // Skip static dimensions.
      if (!t1.isDynamicDim(i))
        continue;
      auto size1 = beforeSlice.getMixedSizes()[i];
      auto size2 = afterTrimming.getMixedSizes()[i];

      // Case 1: Same value or same constant int.
      if (isEqualConstantIntOrValue(size1, size2))
        continue;

      // Other cases: Take a deeper look at defining ops of values.
      auto v1 = llvm::dyn_cast_if_present<Value>(size1);
      auto v2 = llvm::dyn_cast_if_present<Value>(size2);
      if (!v1 || !v2)
        return false;

      // Case 2: Both values are identical AffineMinOps. (Should not happen if
      // CSE is run.)
      auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
      auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
      if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
          minOp1.getOperands() == minOp2.getOperands())
        continue;

      // Add additional cases as needed.
    }

    // All tests passed.
    return true;
  }
};
2896
2897/// Returns the effective Pad value for the input op, provided it's a scalar.
2898///
2899/// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
2900/// this Op performs padding, retrieve the padding value provided that it's
2901/// a scalar and static/fixed for all the padded values. Returns an empty value
2902/// otherwise.
2903///
2904/// TODO: This is used twice (when checking vectorization pre-conditions and
2905/// when vectorizing). Cache results instead of re-running.
2906static Value getStaticPadVal(Operation *op) {
2907 if (!op)
2908 return {};
2909
2910 // 1. vector.broadcast (f32 -> vector <...xf32>) - return the value that's
2911 // being broadcast, provided that it's a scalar.
2912 if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
2913 auto source = bcast.getSource();
2914 if (llvm::dyn_cast<VectorType>(source.getType()))
2915 return {};
2916
2917 return source;
2918 }
2919
2920 // 2. linalg.fill - use the scalar input value that used to fill the output
2921 // tensor.
2922 if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
2923 return fill.getInputs()[0];
2924 }
2925
2926 // 3. tensor.generateOp - can't guarantee the value is fixed without
2927 // analysing, bail out.
2928 if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {
2929 return {};
2930 }
2931
2932 // 4. vector.transfer_write - inspect the input vector that's written from. If
2933 // if contains a single value that has been broadcast (e.g. via
2934 // vector.broadcast), extract it, fail otherwise.
2935 if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
2936 return getStaticPadVal(xferWrite.getVector().getDefiningOp());
2937
2938 // 5. tensor.insert_slice - inspect the destination tensor. If it's larger
2939 // than the input tensor, then, provided it's constant, we'll extract the
2940 // value that was used to generate it (via e.g. linalg.fill), fail otherwise.
2941 // TODO: Clarify the semantics when the input tensor is larger than the
2942 // destination.
2943 if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
2944 return getStaticPadVal(slice.getDest().getDefiningOp());
2945
2946 return {};
2947}
2948
static LogicalResult
vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
                         ArrayRef<int64_t> inputVectorSizes,
                         SmallVectorImpl<Value> &newResults) {
  // TODO: Introduce a parent class that will handle the insertion point update.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(sliceOp);

  TypedValue<RankedTensorType> source = sliceOp.getSource();
  auto sourceType = source.getType();
  auto resultType = sliceOp.getResultType();

  // 1. Get the pad value. If the producers don't reveal a static pad value,
  // fall back to zero of the element type.
  Value padValue = getStaticPadVal(sliceOp);

  if (!padValue) {
    auto elemType = sourceType.getElementType();
    padValue = arith::ConstantOp::create(rewriter, sliceOp.getLoc(), elemType,
                                         rewriter.getZeroAttr(elemType));
  }

  // 2. Get the vector shape
  SmallVector<int64_t> vecShape;
  size_t rankDiff = resultType.getRank() - sourceType.getRank();
  for (int64_t i = 0, end = sourceType.getRank(); i < end; ++i) {
    // User-provided sizes take precedence over inferred ones.
    if (!inputVectorSizes.empty()) {
      vecShape.push_back(inputVectorSizes[i]);
    } else if (!sourceType.isDynamicDim(i)) {
      vecShape.push_back(sourceType.getDimSize(i));
    } else if (!resultType.isDynamicDim(i)) {
      // Source shape is not statically known, but result shape is.
      // Vectorize with size of result shape. This may be larger than the
      // source size.
      // FIXME: Using rankDiff implies that the source tensor is inserted at
      // the end of the destination tensor. However, that's not required.
      vecShape.push_back(resultType.getDimSize(rankDiff + i));
    } else {
      // Neither source nor result dim of padOp is static. Cannot vectorize
      // the copy.
      return failure();
    }
  }
  auto vecType = VectorType::get(vecShape, sourceType.getElementType());

  // 3. Generate TransferReadOp + TransferWriteOp
  auto loc = sliceOp.getLoc();

  // Create read
  SmallVector<Value> readIndices(
      vecType.getRank(), arith::ConstantIndexOp::create(rewriter, loc, 0));
      rewriter, loc, source, vecType, padValue,
      /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());

  // Create write
  auto writeIndices =
      getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
  Operation *write =
      vector::createWriteOrMaskedWrite(rewriter, loc, read, sliceOp.getDest(),
                                       writeIndices, inputVectorSizes.empty());

  // 4. Finalize
  newResults.push_back(write->getResult(0));

  return success();
}
3014
3015/// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
3016/// ```
3017/// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
3018/// %r = tensor.insert_slice %0
3019/// into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
3020/// : tensor<17x5xf32> into tensor<?x?x17x5xf32>
3021/// ```
3022/// is rewritten to:
3023/// ```
3024/// %0 = vector.transfer_read %src[%c0, %c0], %padding
3025/// : tensor<?x?xf32>, vector<17x5xf32>
3026/// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
3027/// {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
3028/// ```
3029///
3030/// This rewrite is possible if:
3031/// - Low padding is static 0.
3032/// - `padOp` result shape is static.
3033/// - The entire padded tensor is inserted.
3034/// (Implies that sizes of `insertOp` are all static.)
3035/// - Only unit strides in `insertOp`.
3036/// - Single, scalar padding value.
3037/// - `padOp` result not used as destination.
3038struct PadOpVectorizationWithInsertSlicePattern
3039 : public VectorizePadOpUserPattern<tensor::InsertSliceOp> {
3040 using VectorizePadOpUserPattern<
3041 tensor::InsertSliceOp>::VectorizePadOpUserPattern;
3042
3043 LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
3044 tensor::InsertSliceOp insertOp) const override {
3045 // Low padding must be static 0.
3046 if (!padOp.hasZeroLowPad())
3047 return failure();
3048 // Only unit stride supported.
3049 if (!insertOp.hasUnitStride())
3050 return failure();
3051 // Pad value must be a constant.
3052 auto padValue = padOp.getConstantPaddingValue();
3053 if (!padValue)
3054 return failure();
3055 // Dynamic shapes not supported.
3056 if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
3057 return failure();
3058 // Pad result not used as destination.
3059 if (insertOp.getDest() == padOp.getResult())
3060 return failure();
3061
3062 auto vecType = VectorType::get(padOp.getType().getShape(),
3063 padOp.getType().getElementType());
3064 unsigned vecRank = vecType.getRank();
3065 unsigned tensorRank = insertOp.getType().getRank();
3066
3067 // Check if sizes match: Insert the entire tensor into most minor dims.
3068 // (No permutations allowed.)
3069 SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
3070 expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
3071 if (!llvm::all_of(
3072 llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
3073 return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
3074 }))
3075 return failure();
3076
3077 // Insert the TransferReadOp and TransferWriteOp at the position of the
3078 // InsertSliceOp.
3079 rewriter.setInsertionPoint(insertOp);
3080
3081 // Generate TransferReadOp: Read entire source tensor and add high
3082 // padding.
3083 SmallVector<Value> readIndices(
3084 vecRank, arith::ConstantIndexOp::create(rewriter, padOp.getLoc(), 0));
3085 auto read = vector::TransferReadOp::create(rewriter, padOp.getLoc(),
3086 vecType, padOp.getSource(),
3087 readIndices, padValue);
3088
3089 // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
3090 // specified offsets. Write is fully in-bounds because a InsertSliceOp's
3091 // source must fit into the destination at the specified offsets.
3092 auto writeIndices = getValueOrCreateConstantIndexOp(
3093 rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
3094 SmallVector<bool> inBounds(vecRank, true);
3095 rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
3096 insertOp, read, insertOp.getDest(), writeIndices,
3097 ArrayRef<bool>{inBounds});
3098
3099 return success();
3100 }
3101};
3102
    RewritePatternSet &patterns, PatternBenefit baseBenefit) {
  // Register the targeted pad-user patterns with a benefit one above the
  // base so they are tried before patterns registered at `baseBenefit`.
  patterns.add<PadOpVectorizationWithTransferReadPattern,
               PadOpVectorizationWithTransferWritePattern,
               PadOpVectorizationWithInsertSlicePattern>(
      patterns.getContext(), baseBenefit.getBenefit() + 1);
}
3110
3111//----------------------------------------------------------------------------//
3112// Forwarding patterns
3113//----------------------------------------------------------------------------//
3114
3115/// Check whether there is any interleaved use of any `values` between
3116/// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
3117/// is in a different block.
3118static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
3119 ValueRange values) {
3120 if (firstOp->getBlock() != secondOp->getBlock() ||
3121 !firstOp->isBeforeInBlock(secondOp)) {
3122 LDBG() << "interleavedUses precondition failed, firstOp: " << *firstOp
3123 << ", second op: " << *secondOp;
3124 return true;
3125 }
3126 for (auto v : values) {
3127 for (auto &u : v.getUses()) {
3128 Operation *owner = u.getOwner();
3129 if (owner == firstOp || owner == secondOp)
3130 continue;
3131 // TODO: this is too conservative, use dominance info in the future.
3132 if (owner->getBlock() == firstOp->getBlock() &&
3133 (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
3134 continue;
3135 LDBG() << " found interleaved op " << *owner << ", firstOp: " << *firstOp
3136 << ", second op: " << *secondOp;
3137 return true;
3138 }
3139 }
3140 return false;
3141}
3142
3143/// Return the unique subview use of `v` if it is indeed unique, null
3144/// otherwise.
3145static memref::SubViewOp getSubViewUseIfUnique(Value v) {
3146 memref::SubViewOp subViewOp;
3147 for (auto &u : v.getUses()) {
3148 if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
3149 if (subViewOp)
3150 return memref::SubViewOp();
3151 subViewOp = newSubViewOp;
3152 }
3153 }
3154 return subViewOp;
3155}
3156
/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
/// when available.
    vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {

  // TODO: support mask.
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp, "unsupported mask");

  // Transfer into `view`.
  Value viewOrAlloc = xferOp.getBase();
  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
      !viewOrAlloc.getDefiningOp<memref::AllocOp>())
    return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return rewriter.notifyMatchFailure(xferOp, "no subview found");
  Value subView = subViewOp.getResult();

  // Find the copy into `subView` without interleaved uses.
  memref::CopyOp copyOp;
  for (auto &u : subView.getUses()) {
    if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
      assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
      if (newCopyOp.getTarget() != subView)
        continue;
      // Any other use of the buffers between the copy and the read would
      // make forwarding unsound.
      if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return rewriter.notifyMatchFailure(xferOp, "no copy found");

  // Find the fill into `viewOrAlloc` without interleaved uses before the
  // copy.
  FillOp maybeFillOp;
  for (auto &u : viewOrAlloc.getUses()) {
    if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
      assert(isa<MemRefType>(newFillOp.output().getType()));
      if (newFillOp.output() != viewOrAlloc)
        continue;
      if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
        continue;
      maybeFillOp = newFillOp;
      break;
    }
  }
  // Ensure padding matches.
  if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
    return rewriter.notifyMatchFailure(xferOp,
                                       "padding value does not match fill");

  // `in` is the subview that memref.copy reads. Replace it.
  Value in = copyOp.getSource();

  // memref.copy + linalg.fill can be used to create a padded local buffer.
  // The `masked` attribute is only valid on this padded buffer.
  // When forwarding to vector.transfer_read, the attribute must be reset
  // conservatively.
  auto vectorType = xferOp.getVectorType();
  Value res = vector::TransferReadOp::create(
      rewriter, xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
      xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
      rewriter.getBoolArrayAttr(
          SmallVector<bool>(vectorType.getRank(), false)));

  // The fill and copy are now dead: the new read pulls directly from the
  // copy's source.
  if (maybeFillOp)
    rewriter.eraseOp(maybeFillOp);
  rewriter.eraseOp(copyOp);
  rewriter.replaceOp(xferOp, res);

  return success();
}
3234
/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
/// when available.
    vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
  // TODO: support mask.
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp, "unsupported mask");

  // Transfer into `viewOrAlloc`. Only memref.view / memref.alloc sources
  // are handled: for other producers we cannot locally reason about
  // aliasing.
  Value viewOrAlloc = xferOp.getBase();
  if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
      !viewOrAlloc.getDefiningOp<memref::AllocOp>())
    return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");

  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
  memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
  if (!subViewOp)
    return rewriter.notifyMatchFailure(xferOp, "no subview found");
  Value subView = subViewOp.getResult();

  // Find the copy from `subView` without interleaved uses (checked via
  // `mayExistInterleavedUses` on `viewOrAlloc` and `subView`).
  memref::CopyOp copyOp;
  for (auto &u : subViewOp.getResult().getUses()) {
    if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
      if (newCopyOp.getSource() != subView)
        continue;
      if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
        continue;
      copyOp = newCopyOp;
      break;
    }
  }
  if (!copyOp)
    return rewriter.notifyMatchFailure(xferOp, "no copy found");

  // `out` is the subview copied into that we replace.
  assert(isa<MemRefType>(copyOp.getTarget().getType()));
  Value out = copyOp.getTarget();

  // Forward vector.transfer into copy.
  // memref.copy + linalg.fill can be used to create a padded local buffer.
  // The `masked` attribute is only valid on this padded buffer.
  // When forwarding to vector.transfer_write, the attribute must be reset
  // conservatively.
  auto vector = xferOp.getVector();
  vector::TransferWriteOp::create(
      rewriter, xferOp.getLoc(), vector, out, xferOp.getIndices(),
      xferOp.getPermutationMapAttr(), xferOp.getMask(),
      rewriter.getBoolArrayAttr(SmallVector<bool>(
          dyn_cast<VectorType>(vector.getType()).getRank(), false)));

  // The intermediate write + copy are replaced by the single forwarded
  // write above.
  rewriter.eraseOp(copyOp);
  rewriter.eraseOp(xferOp);

  return success();
}
3291
3292//===----------------------------------------------------------------------===//
3293// Convolution vectorization patterns
3294//===----------------------------------------------------------------------===//
3295
3296template <int N>
3297static void bindShapeDims(ShapedType shapedType) {}
3298
3299template <int N, typename IntTy, typename... IntTy2>
3300static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
3301 val = shapedType.getShape()[N];
3302 bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
3303}
3304
3305/// Bind a pack of int& to the leading dimensions of shapedType.getShape().
3306template <typename... IntTy>
3307static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
3308 bindShapeDims<0>(shapedType, vals...);
3309}
3310
3311/// Match 1D convolution or pooling operations and return their dilations and
3312/// strides. Returns std::nullopt for unrecognized ops.
3313static std::optional<DilationsAndStrides> match1DConvPoolOp(LinalgOp op) {
3314#define MATCH_1D_CONV_POOL_OP(ConvOpTy) \
3315 if (auto convParams = matchConvolutionOpOfType<ConvOpTy>(op)) \
3316 return convParams;
3317
3318 // 1D Convolution ops.
3319 MATCH_1D_CONV_POOL_OP(linalg::Conv1DOp);
3320 MATCH_1D_CONV_POOL_OP(linalg::Conv1DNwcWcfOp);
3321 MATCH_1D_CONV_POOL_OP(linalg::Conv1DNcwFcwOp);
3322 // Depthwise 1D Convolution ops.
3323 // Note: Only NWC layout without channel multiplier is supported.
3324 // DepthwiseConv1DNcwCwOp (NCW) and DepthwiseConv1DNwcWcmOp (with multiplier)
3325 // are not supported.
3326 MATCH_1D_CONV_POOL_OP(linalg::DepthwiseConv1DNwcWcOp);
3327 // 1D Pooling ops (NWC layout).
3328 MATCH_1D_CONV_POOL_OP(linalg::PoolingNwcSumOp);
3329 MATCH_1D_CONV_POOL_OP(linalg::PoolingNwcMaxOp);
3330 MATCH_1D_CONV_POOL_OP(linalg::PoolingNwcMaxUnsignedOp);
3331 MATCH_1D_CONV_POOL_OP(linalg::PoolingNwcMinOp);
3332 MATCH_1D_CONV_POOL_OP(linalg::PoolingNwcMinUnsignedOp);
3333 // 1D Pooling ops (NCW layout).
3334 MATCH_1D_CONV_POOL_OP(linalg::PoolingNcwSumOp);
3335 MATCH_1D_CONV_POOL_OP(linalg::PoolingNcwMaxOp);
3336
3337#undef MATCH_1D_CONV_POOL_OP
3338
3339 return std::nullopt;
3340}
3341
3342namespace {
3343/// Generate a vector implementation for either:
3344/// ```
3345/// Op def: ( w, kw )
3346/// Iters: ({Par(), Red()})
3347/// Layout: {{w + kw}, {kw}, {w}}
3348/// ```
3349/// kw is unrolled.
3350///
3351/// or
3352///
3353/// ```
3354/// Op def: ( n, w, c, kw, f )
3355/// Iters: ({Par(), Par(), Par(), Red(), Red()})
3356/// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3357/// ```
3358/// kw is unrolled, w is unrolled iff dilationW > 1.
3359///
3360/// or
3361///
3362/// ```
3363/// Op def: ( n, c, w, f, kw )
3364/// Iters: ({Par(), Par(), Par(), Red(), Red()})
3365/// Layout: {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
3366/// ```
3367/// kw is unrolled, w is unrolled iff dilationW > 1.
3368///
3369/// or
3370///
3371/// ```
3372/// Op def: ( n, w, c, kw )
3373/// Iters: ({Par(), Par(), Par(), Red()})
3374/// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3375/// ```
3376/// kw is unrolled, w is unrolled iff dilationW > 1.
3377struct Conv1DGenerator
3378 : public StructuredGenerator<LinalgOp, utils::IteratorType> {
3379 /// Factory method to create a Conv1DGenerator. Returns failure if the
3380 /// operation doesn't have valid strides/dilations.
3381 static FailureOr<Conv1DGenerator> create(RewriterBase &rewriter,
3382 LinalgOp linalgOp) {
3383 // Try to match a 1D conv/pool op using matchConvolutionOpOfType. This
3384 // works for both named ops and generic ops that match their semantics.
3385 std::optional<DilationsAndStrides> convParams = match1DConvPoolOp(linalgOp);
3386 if (!convParams)
3387 return failure();
3388
3389 int strideW = static_cast<int>(convParams->strides.front());
3390 int dilationW = static_cast<int>(convParams->dilations.front());
3391 return Conv1DGenerator(rewriter, linalgOp, strideW, dilationW);
3392 }
3393
private:
  // Private: construction must go through `create`, which validates the op
  // and extracts the stride/dilation along the spatial dimension.
  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp, int strideW,
                  int dilationW)
      : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp),
        strideW(strideW), dilationW(dilationW) {

    // Cache operands and their shaped types: DPS inputs 0/1 are the image
    // (lhs) and filter (rhs); the sole DPS init is the result.
    lhsShaped = linalgOp.getDpsInputOperand(0)->get();
    rhsShaped = linalgOp.getDpsInputOperand(1)->get();
    resShaped = linalgOp.getDpsInitOperand(0)->get();
    lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
    rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
    resShapedType = dyn_cast<ShapedType>(resShaped.getType());

    // The reduction feeding the init operand determines the generator's
    // configuration (conv vs pool, combiner op name).
    Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
    redOp = reduceOp->getName().getIdentifier();

    setConvOperationKind(reduceOp);

    auto maybeKind = getCombinerOpKind(reduceOp);
    // NOTE(review): `.value()` assumes the reduction always has a
    // recognizable combiner kind here — presumably guaranteed by the
    // matching done in `create`; verify.
    reductionKind = maybeKind.value();
  }
3415
3416public:
3417 /// Generate a vector implementation for:
3418 /// ```
3419 /// Op def: ( w, kw )
3420 /// Iters: ({Par(), Red()})
3421 /// Layout: {{w + kw}, {kw}, {w}}
3422 /// ```
3423 /// kw is always unrolled.
3424 ///
3425 /// or
3426 ///
3427 /// ```
3428 /// Op def: ( n, w, c, kw, f )
3429 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
3430 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3431 /// ```
3432 /// kw is always unrolled.
3433 /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
3434 /// > 1.
3435 FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
3436 int64_t nSize, wSize, cSize, kwSize, fSize;
3437 SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
3438 bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
3439 switch (conv1DOpOrder) {
3440 case Conv1DOpOrder::W:
3441 // Initialize unused dimensions
3442 nSize = fSize = cSize = 0;
3443 // out{W}
3444 bindShapeDims(resShapedType, wSize);
3445 // kernel{kw}
3446 bindShapeDims(rhsShapedType, kwSize);
3447 lhsShape = {// iw = ow + kw - 1
3448 // (i.e. 16 convolved with 3 -> 14)
3449 (wSize + kwSize - 1)};
3450 rhsShape = {kwSize};
3451 resShape = {wSize};
3452 break;
3453 case Conv1DOpOrder::Nwc:
3454 // out{n, w, f}
3455 bindShapeDims(resShapedType, nSize, wSize, fSize);
3456 switch (oper) {
3457 case ConvOperationKind::Conv:
3458 // kernel{kw, c, f}
3459 bindShapeDims(rhsShapedType, kwSize, cSize);
3460 break;
3461 case ConvOperationKind::Pool:
3462 // kernel{kw}
3463 bindShapeDims(rhsShapedType, kwSize);
3464 cSize = fSize;
3465 break;
3466 }
3467 lhsShape = {nSize,
3468 // iw = ow * sw + kw * dw - 1
3469 // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3470 // Perform the proper inclusive -> exclusive -> inclusive.
3471 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3472 1,
3473 cSize};
3474 switch (oper) {
3475 case ConvOperationKind::Conv:
3476 rhsShape = {kwSize, cSize, fSize};
3477 break;
3478 case ConvOperationKind::Pool:
3479 rhsShape = {kwSize};
3480 break;
3481 }
3482 resShape = {nSize, wSize, fSize};
3483 break;
3484 case Conv1DOpOrder::Ncw:
3485 // out{n, f, w}
3486 bindShapeDims(resShapedType, nSize, fSize, wSize);
3487 switch (oper) {
3488 case ConvOperationKind::Conv:
3489 // kernel{f, c, kw}
3490 bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
3491 break;
3492 case ConvOperationKind::Pool:
3493 // kernel{kw}
3494 bindShapeDims(rhsShapedType, kwSize);
3495 cSize = fSize;
3496 break;
3497 }
3498 lhsShape = {nSize, cSize,
3499 // iw = ow * sw + kw * dw - 1
3500 // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3501 // Perform the proper inclusive -> exclusive -> inclusive.
3502 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3503 1};
3504 switch (oper) {
3505 case ConvOperationKind::Conv:
3506 rhsShape = {fSize, cSize, kwSize};
3507 break;
3508 case ConvOperationKind::Pool:
3509 rhsShape = {kwSize};
3510 break;
3511 }
3512 resShape = {nSize, fSize, wSize};
3513 break;
3514 }
3515
3516 vector::TransferWriteOp write;
3517 Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
3518
3519 // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3520 // When strideW == 1, we can batch the contiguous loads and avoid
3521 // unrolling
3522 int64_t wSizeStep = strideW == 1 ? wSize : 1;
3523
3524 Type lhsEltType = lhsShapedType.getElementType();
3525 Type rhsEltType = rhsShapedType.getElementType();
3526 Type resEltType = resShapedType.getElementType();
3527 auto lhsType = VectorType::get(lhsShape, lhsEltType);
3528 auto rhsType = VectorType::get(rhsShape, rhsEltType);
3529 auto resType = VectorType::get(resShape, resEltType);
3530 // Zero padding with the corresponding dimensions for lhs, rhs and res.
3531 SmallVector<Value> lhsPadding(lhsShape.size(), zero);
3532 SmallVector<Value> rhsPadding(rhsShape.size(), zero);
3533 SmallVector<Value> resPadding(resShape.size(), zero);
3534
3535 // Read the whole lhs, rhs and res in one shot (with zero padding).
3536 Value lhs = vector::TransferReadOp::create(
3537 rewriter, loc, lhsType, lhsShaped, lhsPadding,
3538 /*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
3539 // This is needed only for Conv.
3540 Value rhs = nullptr;
3541 if (oper == ConvOperationKind::Conv)
3542 rhs = vector::TransferReadOp::create(
3543 rewriter, loc, rhsType, rhsShaped, rhsPadding,
3544 /*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
3545 Value res = vector::TransferReadOp::create(
3546 rewriter, loc, resType, resShaped, resPadding,
3547 /*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
3548
3549 // The base vectorization case for channeled convolution is input:
3550 // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
3551 // vectorization case, we do pre transpose on input, weight, and output.
3552 switch (conv1DOpOrder) {
3553 case Conv1DOpOrder::W:
3554 case Conv1DOpOrder::Nwc:
3555 // Base case, so no transposes necessary.
3556 break;
3557 case Conv1DOpOrder::Ncw: {
3558 // To match base vectorization case, we pre-transpose current case.
3559 // ncw -> nwc
3560 static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
3561 lhs = vector::TransposeOp::create(rewriter, loc, lhs, permLhs);
3562 // fcw -> wcf
3563 static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
3564
3565 // This is needed only for Conv.
3566 if (oper == ConvOperationKind::Conv)
3567 rhs = vector::TransposeOp::create(rewriter, loc, rhs, permRhs);
3568 // nfw -> nwf
3569 static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
3570 res = vector::TransposeOp::create(rewriter, loc, res, permRes);
3571 break;
3572 }
3573 }
3574
3575 //===------------------------------------------------------------------===//
3576 // Begin vector-only rewrite part
3577 //===------------------------------------------------------------------===//
3578 // Unroll along kw and read slices of lhs and rhs.
3579 SmallVector<Value> lhsVals, rhsVals, resVals;
3580 lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
3581 kwSize, strideW, dilationW, wSizeStep,
3582 isSingleChanneled);
3583 // Do not do for pooling.
3584 if (oper == ConvOperationKind::Conv)
3585 rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
3586 resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
3587 wSizeStep, isSingleChanneled);
3588
3589 auto linearIndex = [&](int64_t kw, int64_t w) {
3590 return kw * (wSize / wSizeStep) + w;
3591 };
3592
3593 // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f}
3594 // or perform outerproduct for non-channeled convolution or perform simple
3595 // arith operation for pooling
3596 for (int64_t kw = 0; kw < kwSize; ++kw) {
3597 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3598 switch (oper) {
3599 case ConvOperationKind::Conv:
3600 if (isSingleChanneled) {
3601 resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
3602 lhsVals[linearIndex(kw, w)],
3603 rhsVals[kw], resVals[w]);
3604 } else {
3605 resVals[w] = conv1dSliceAsContraction(rewriter, loc,
3606 lhsVals[linearIndex(kw, w)],
3607 rhsVals[kw], resVals[w]);
3608 }
3609 break;
3610 case ConvOperationKind::Pool:
3611 resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
3612 resVals[w]);
3613 break;
3614 }
3615 }
3616 }
3617
3618 res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
3619 isSingleChanneled);
3620 //===------------------------------------------------------------------===//
3621 // End vector-only rewrite part
3622 //===------------------------------------------------------------------===//
3623
3624 // The base vectorization case for channeled convolution is output:
3625 // {n,w,f} To reuse the result from base pattern vectorization case, we
3626 // post transpose the base case result.
3627 switch (conv1DOpOrder) {
3628 case Conv1DOpOrder::W:
3629 case Conv1DOpOrder::Nwc:
3630 // Base case, so no transposes necessary.
3631 break;
3632 case Conv1DOpOrder::Ncw: {
3633 // nwf -> nfw
3634 static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
3635 res = vector::TransposeOp::create(rewriter, loc, res, perm);
3636 break;
3637 }
3638 }
3639
3640 return vector::TransferWriteOp::create(rewriter, loc, res, resShaped,
3641 resPadding)
3642 .getOperation();
3643 }
3644
3645 // Take a value and widen to have the same element type as `ty`.
3646 Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
3647 const Type srcElementType = getElementTypeOrSelf(val.getType());
3648 const Type dstElementType = getElementTypeOrSelf(ty);
3649 assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
3650 if (srcElementType == dstElementType)
3651 return val;
3652
3653 const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
3654 const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
3655 // Handle both shaped as well as scalar types.
3656 Type dstType;
3657 if (auto shapedType = dyn_cast<ShapedType>(val.getType()))
3658 dstType = shapedType.cloneWith(std::nullopt, dstElementType);
3659 else
3660 dstType = dstElementType;
3661
3662 if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
3663 return arith::SIToFPOp::create(rewriter, loc, dstType, val);
3664 }
3665
3666 if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
3667 srcWidth < dstWidth)
3668 return arith::ExtFOp::create(rewriter, loc, dstType, val);
3669
3670 if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
3671 srcWidth < dstWidth)
3672 return arith::ExtSIOp::create(rewriter, loc, dstType, val);
3673
3674 assert(false && "unhandled promotion case");
3675 return nullptr;
3676 }
3677
3678 // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
3679 Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
3680 Value lhs, Value rhs, Value res) {
3681 vector::IteratorType par = vector::IteratorType::parallel;
3682 vector::IteratorType red = vector::IteratorType::reduction;
3683 AffineExpr n, w, f, c;
3684 bindDims(ctx, n, w, f, c);
3685 lhs = promote(rewriter, loc, lhs, res.getType());
3686 rhs = promote(rewriter, loc, rhs, res.getType());
3687 auto contrationOp = vector::ContractionOp::create(
3688 rewriter, loc, lhs, rhs, res,
3689 /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
3690 /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
3691 contrationOp.setKind(reductionKind);
3692 return contrationOp;
3693 }
3694
3695 // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel
3696 // convolution.
3697 Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
3698 Value lhs, Value rhs, Value res) {
3699 lhs = promote(rewriter, loc, lhs, res.getType());
3700 rhs = promote(rewriter, loc, rhs, res.getType());
3701 return vector::OuterProductOp::create(rewriter, loc, res.getType(), lhs,
3702 rhs, res, vector::CombiningKind::ADD);
3703 }
3704
3705 // Create a reduction: lhs{n, w, c} -> res{n, w, c}
3706 Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
3707 Value res) {
3708 if (isPoolExt)
3709 lhs = rewriter.create(loc, poolExtOp, lhs, res.getType())->getResult(0);
3710 return rewriter
3711 .create(loc, redOp, ArrayRef<Value>{lhs, res}, res.getType())
3712 ->getResult(0);
3713 }
3714
3715 /// Generate a vector implementation for:
3716 /// ```
3717 /// Op def: ( n, w, c, kw)
3718 /// Iters: ({Par(), Par(), Par(), Red()})
3719 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3720 /// ```
3721 /// kw is always unrolled.
3722 /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
3723 /// > 1.
3724 FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
3725 bool channelDimScalableFlag,
3726 bool flatten) {
3727 bool scalableChDim = false;
3728 bool useMasking = false;
3729 int64_t nSize, wSize, cSize, kwSize;
3730 // kernel{kw, c}
3731 bindShapeDims(rhsShapedType, kwSize, cSize);
3732 if (ShapedType::isDynamic(cSize)) {
3733 assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
3734 cSize = channelDimVecSize;
3735 // Scalable vectors are only used when both conditions are met:
3736 // 1. channel dim is dynamic
3737 // 2. channelDimScalableFlag is set
3738 scalableChDim = channelDimScalableFlag;
3739 useMasking = true;
3740 }
3741
3742 assert(!(useMasking && flatten) &&
3743 "Unsupported flattened conv with dynamic shapes");
3744
3745 // out{n, w, c}
3746 bindShapeDims(resShapedType, nSize, wSize);
3747
3748 vector::TransferWriteOp write;
3749 Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
3750
3751 // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3752 // When strideW == 1, we can batch the contiguous loads and avoid
3753 // unrolling
3754 int64_t wSizeStep = strideW == 1 ? wSize : 1;
3755
3756 Type lhsEltType = lhsShapedType.getElementType();
3757 Type rhsEltType = rhsShapedType.getElementType();
3758 Type resEltType = resShapedType.getElementType();
3759 VectorType lhsType = VectorType::get(
3760 {nSize,
3761 // iw = ow * sw + kw * dw - 1
3762 // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3763 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
3764 cSize},
3765 lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
3766 VectorType rhsType =
3767 VectorType::get({kwSize, cSize}, rhsEltType,
3768 /*scalableDims=*/{false, scalableChDim});
3769 VectorType resType =
3770 VectorType::get({nSize, wSize, cSize}, resEltType,
3771 /*scalableDims=*/{false, false, scalableChDim});
3772
3773 // Masks the input xfer Op along the channel dim, iff the corresponding
3774 // scalable flag is set.
3775 auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
3776 ArrayRef<bool> scalableDims,
3777 Operation *opToMask) {
3778 if (!useMasking)
3779 return opToMask;
3780 auto maskType =
3781 VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
3782
3783 SmallVector<bool> inBounds(maskShape.size(), true);
3784 auto xferOp = cast<VectorTransferOpInterface>(opToMask);
3785 xferOp->setAttr(xferOp.getInBoundsAttrName(),
3786 rewriter.getBoolArrayAttr(inBounds));
3787
3788 SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
3789 cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);
3790
3791 Value maskOp =
3792 vector::CreateMaskOp::create(rewriter, loc, maskType, mixedDims);
3793
3794 return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
3795 };
3796
3797 // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
3798 // 0].
3799 Value lhs = vector::TransferReadOp::create(
3800 rewriter, loc, lhsType, lhsShaped, ValueRange{zero, zero, zero},
3801 /*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
3802 auto *maybeMaskedLhs = maybeMaskXferOp(
3803 lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
3804
3805 // Read rhs slice of size {kw, c} @ [0, 0].
3806 Value rhs = vector::TransferReadOp::create(
3807 rewriter, loc, rhsType, rhsShaped, ValueRange{zero, zero},
3808 /*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
3809 auto *maybeMaskedRhs = maybeMaskXferOp(
3810 rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
3811
3812 // Read res slice of size {n, w, c} @ [0, 0, 0].
3813 Value res = vector::TransferReadOp::create(
3814 rewriter, loc, resType, resShaped, ValueRange{zero, zero, zero},
3815 /*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
3816 auto *maybeMaskedRes = maybeMaskXferOp(
3817 resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
3818
3819 //===------------------------------------------------------------------===//
3820 // Begin vector-only rewrite part
3821 //===------------------------------------------------------------------===//
3822 // Unroll along kw and read slices of lhs and rhs.
3823 SmallVector<Value> lhsVals, rhsVals, resVals;
3824 SmallVector<int64_t> inOutSliceSizes = {nSize, wSizeStep, cSize};
3825 SmallVector<int64_t> inOutStrides = {1, 1, 1};
3826
3827 // Extract lhs slice of size {n, wSizeStep, c}
3828 // @ [0, sw * w + dw * kw, 0].
3829 for (int64_t kw = 0; kw < kwSize; ++kw) {
3830 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3831 lhsVals.push_back(vector::ExtractStridedSliceOp::create(
3832 rewriter, loc, maybeMaskedLhs->getResult(0),
3833 /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
3834 inOutSliceSizes, inOutStrides));
3835 }
3836 }
3837 // Extract rhs slice of size {c} @ [kw].
3838 for (int64_t kw = 0; kw < kwSize; ++kw) {
3839 rhsVals.push_back(
3840 vector::ExtractOp::create(rewriter, loc, maybeMaskedRhs->getResult(0),
3841 /*offsets=*/ArrayRef<int64_t>{kw}));
3842 }
3843 // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
3844 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3845 resVals.push_back(vector::ExtractStridedSliceOp::create(
3846 rewriter, loc, maybeMaskedRes->getResult(0),
3847 /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
3848 inOutStrides));
3849 }
3850
3851 auto linearIndex = [&](int64_t kw, int64_t w) {
3852 return kw * (wSize / wSizeStep) + w;
3853 };
3854
3855 // Note - the scalable flags are ignored as flattening combined with
3856 // scalable vectorization is not supported.
3857 SmallVector<int64_t> inOutFlattenSliceSizes = {nSize, wSizeStep * cSize};
3858 auto lhsTypeAfterFlattening =
3859 VectorType::get(inOutFlattenSliceSizes, lhsEltType);
3860 auto resTypeAfterFlattening =
3861 VectorType::get(inOutFlattenSliceSizes, resEltType);
3862
3863 // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
3864 for (int64_t kw = 0; kw < kwSize; ++kw) {
3865 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3866 Value lhsVal = lhsVals[linearIndex(kw, w)];
3867 Value resVal = resVals[w];
3868 if (flatten) {
3869 // Flatten the input and output vectors (collapse the channel
3870 // dimension)
3871 lhsVal =
3872 vector::ShapeCastOp::create(rewriter, loc, lhsTypeAfterFlattening,
3873 lhsVals[linearIndex(kw, w)]);
3874 resVal = vector::ShapeCastOp::create(
3875 rewriter, loc, resTypeAfterFlattening, resVals[w]);
3876 }
3877 resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
3878 rhsVals[kw], resVal, flatten);
3879 if (flatten) {
3880 // Un-flatten the output vector (restore the channel dimension)
3881 resVals[w] = vector::ShapeCastOp::create(
3882 rewriter, loc, VectorType::get(inOutSliceSizes, resEltType),
3883 resVals[w]);
3884 }
3885 }
3886 }
3887
3888 // Its possible we failed to create the Fma.
3889 if (!llvm::all_of(resVals, [](Value v) { return v; })) {
3890 // Manually revert (in reverse order) to avoid leaving a bad IR state.
3891 for (auto &collection :
3892 {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
3893 for (Value v : collection)
3894 rewriter.eraseOp(v.getDefiningOp());
3895 return rewriter.notifyMatchFailure(op, "failed to create FMA");
3896 }
3897
3898 // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
3899 // This does not depend on kw.
3900 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3901 maybeMaskedRes = vector::InsertStridedSliceOp::create(
3902 rewriter, loc, resVals[w], maybeMaskedRes->getResult(0),
3903 /*offsets=*/ArrayRef<int64_t>{0, w, 0},
3904 /*strides=*/ArrayRef<int64_t>{1, 1, 1});
3905 }
3906 //===------------------------------------------------------------------===//
3907 // End vector-only rewrite part
3908 //===------------------------------------------------------------------===//
3909
3910 // Write back res slice of size {n, w, c} @ [0, 0, 0].
3911 Operation *resOut = vector::TransferWriteOp::create(
3912 rewriter, loc, maybeMaskedRes->getResult(0), resShaped,
3913 ValueRange{zero, zero, zero});
3914 return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
3915 resOut);
3916 }
3917
3918 /// Lower:
3919 /// * lhs{n, w, c} * rhs{c} -> res{n, w, c} (flatten = false)
3920 /// * lhs{n, w * c} * rhs{c} -> res{n, w * c} (flatten = true)
3921 /// to MulAcc.
3922 Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
3923 Value lhs, Value rhs, Value res,
3924 bool flatten) {
3925 auto rhsTy = cast<ShapedType>(rhs.getType());
3926 auto resTy = cast<ShapedType>(res.getType());
3927
3928 // TODO(suderman): Change this to use a vector.ima intrinsic.
3929 lhs = promote(rewriter, loc, lhs, resTy);
3930
3931 if (flatten) {
3932 // NOTE: This following logic won't work for scalable vectors. For this
3933 // reason, "flattening" is not supported when shapes are dynamic (this
3934 // should be captured by one of the pre-conditions).
3935
3936 // There are two options for handling the filter:
3937 // * shape_cast(broadcast(filter))
3938 // * broadcast(shuffle(filter))
3939 // Opt for the option without shape_cast to simplify the codegen.
3940 auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
3941 auto resSize = cast<VectorType>(res.getType()).getShape()[1];
3942
3943 SmallVector<int64_t, 16> indices;
3944 for (int i = 0; i < resSize / rhsSize; ++i) {
3945 for (int j = 0; j < rhsSize; ++j)
3946 indices.push_back(j);
3947 }
3948
3949 rhs = vector::ShuffleOp::create(rewriter, loc, rhs, rhs, indices);
3950 }
3951 // Broadcast the filter to match the output vector
3952 rhs = vector::BroadcastOp::create(rewriter, loc,
3953 resTy.clone(rhsTy.getElementType()), rhs);
3954
3955 rhs = promote(rewriter, loc, rhs, resTy);
3956
3957 if (!lhs || !rhs)
3958 return nullptr;
3959
3960 if (isa<FloatType>(resTy.getElementType()))
3961 return vector::FMAOp::create(rewriter, loc, lhs, rhs, res);
3962
3963 auto mul = arith::MulIOp::create(rewriter, loc, lhs, rhs);
3964 return arith::AddIOp::create(rewriter, loc, mul, res);
3965 }
3966
3967 /// Entry point for non-channeled convolution:
3968 /// {{w + kw}, {kw}, {w}}
3969 FailureOr<Operation *> generateNonChanneledConv() {
3970 AffineExpr w, kw;
3971 bindDims(ctx, w, kw);
3972 if (!iters({Par(), Red()}))
3973 return rewriter.notifyMatchFailure(op,
3974 "failed to match conv::W 1-par 1-red");
3975
3976 // No transposition needed.
3977 if (layout({/*lhsIndex*/ {w + kw},
3978 /*rhsIndex*/ {kw},
3979 /*resIndex*/ {w}}))
3980 return conv(Conv1DOpOrder::W);
3981
3982 return rewriter.notifyMatchFailure(op, "not a conv::W layout");
3983 }
3984
3985 /// Entry point that transposes into the common form:
3986 /// {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3987 FailureOr<Operation *> generateNwcConv() {
3988 AffineExpr n, w, f, kw, c;
3989 bindDims(ctx, n, w, f, kw, c);
3990 if (!iters({Par(), Par(), Par(), Red(), Red()}))
3991 return rewriter.notifyMatchFailure(
3992 op, "failed to match conv::Nwc 3-par 2-red");
3993
3994 // No transposition needed.
3995 if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3996 /*rhsIndex*/ {kw, c, f},
3997 /*resIndex*/ {n, w, f}}))
3998 return conv(Conv1DOpOrder::Nwc);
3999
4000 return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
4001 }
4002
4003 /// Entry point that transposes into the common form:
4004 /// {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
4005 FailureOr<Operation *> generateNcwConv() {
4006 AffineExpr n, w, f, kw, c;
4007 bindDims(ctx, n, f, w, c, kw);
4008 if (!iters({Par(), Par(), Par(), Red(), Red()}))
4009 return rewriter.notifyMatchFailure(
4010 op, "failed to match conv::Ncw 3-par 2-red");
4011
4012 if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
4013 /*rhsIndex*/ {f, c, kw},
4014 /*resIndex*/ {n, f, w}}))
4015 return conv(Conv1DOpOrder::Ncw);
4016
4017 return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
4018 }
4019
4020 /// Entry point that transposes into the common form:
4021 /// {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}} for pooling
4022 FailureOr<Operation *> generateNwcPooling() {
4023 AffineExpr n, w, c, kw;
4024 bindDims(ctx, n, w, c, kw);
4025 if (!iters({Par(), Par(), Par(), Red()}))
4026 return rewriter.notifyMatchFailure(op,
4027 "failed to match pooling 3-par 1-red");
4028
4029 // No transposition needed.
4030 if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
4031 /*rhsIndex*/ {kw},
4032 /*resIndex*/ {n, w, c}}))
4033 return conv(Conv1DOpOrder::Nwc);
4034
4035 return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
4036 }
4037
4038 /// Entry point that transposes into the common form:
4039 /// {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}} for pooling
4040 FailureOr<Operation *> generateNcwPooling() {
4041 AffineExpr n, w, c, kw;
4042 bindDims(ctx, n, c, w, kw);
4043 if (!iters({Par(), Par(), Par(), Red()}))
4044 return rewriter.notifyMatchFailure(op,
4045 "failed to match pooling 3-par 1-red");
4046
4047 if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
4048 /*rhsIndex*/ {kw},
4049 /*resIndex*/ {n, c, w}}))
4050 return conv(Conv1DOpOrder::Ncw);
4051
4052 return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
4053 }
4054
4055 /// Entry point that transposes into the common form:
4056 /// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
4057 FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
4058 bool vecChDimScalableFlag = false,
4059 bool flatten = false) {
4060 AffineExpr n, w, c, kw;
4061 bindDims(ctx, n, w, c, kw);
4062 if (!iters({Par(), Par(), Par(), Red()}))
4063 return rewriter.notifyMatchFailure(
4064 op, "failed to match depthwise::Nwc conv 3-par 1-red");
4065
4066 // No transposition needed.
4067 if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
4068 /*rhsIndex*/ {kw, c},
4069 /*resIndex*/ {n, w, c}}))
4070 return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);
4071
4072 return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
4073 }
4074
4075private:
  // Kind of op being vectorized; defaults to plain convolution and is
  // updated by setConvOperationKind().
4076  ConvOperationKind oper = ConvOperationKind::Conv;
  // NOTE(review): initialized outside this excerpt (presumably in create());
  // looks like the name of the reduction op -- confirm against the factory.
4077  StringAttr redOp;
  // Name of the cast op feeding the pooling reduction; only meaningful when
  // isPoolExt is true (see setConvOperationKind()).
4078  StringAttr poolExtOp;
  // True when the pooling op extends (casts) a block argument before the
  // reduction.
4079  bool isPoolExt = false;
  // Stride and dilation along the W dimension, used in the affine index
  // expression strideW * w + dilationW * kw.
4080  int strideW, dilationW;
  // Operands of the op: lhs = input, rhs = filter, res = output (per the
  // lhsIndex/rhsIndex/resIndex layout maps above).
4081  Value lhsShaped, rhsShaped, resShaped;
  // Shaped types of the three operands above.
4082  ShapedType lhsShapedType, rhsShapedType, resShapedType;
  // NOTE(review): vector combining kind for the reduction; initialized
  // outside this excerpt -- confirm.
4083  vector::CombiningKind reductionKind;
4084
4085 // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
4086 void setConvOperationKind(Operation *reduceOp) {
4087 int numBlockArguments =
4088 llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
4089 if (numBlockArguments == 1) {
4090 // Will be convolution if feeder is a MulOp.
4091 // A strength reduced version of MulOp for i1 type is AndOp which is also
4092 // supported. Otherwise, it can be pooling. This strength reduction logic
4093 // is in `buildBinaryFn` helper in the Linalg dialect.
4094 auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
4095 llvm::IsaPred<BlockArgument>);
4096 Operation *feedOp = (*feedValIt).getDefiningOp();
4097 if (isCastOfBlockArgument(feedOp)) {
4098 oper = ConvOperationKind::Pool;
4099 isPoolExt = true;
4100 poolExtOp = feedOp->getName().getIdentifier();
4101 return;
4102 }
4103 oper = ConvOperationKind::Conv;
4104 return;
4105 }
4106 // numBlockArugments == 2 and this is a pooling op.
4107 oper = ConvOperationKind::Pool;
4108 isPoolExt = false;
4109 }
4110};
4111} // namespace
4112
4113 /// Helper function to vectorize a LinalgOp with convolution semantics.
4114 // TODO: extend the generic vectorization to support windows and drop this.
// Tries each 1-D conv/pooling form in turn (non-channeled, NWC, NCW,
// NWC-pooling, NCW-pooling) and finally falls back to the depthwise dilated
// conv path, which supports masking/scalable vectors on the channel dim.
4115 static FailureOr<Operation *> vectorizeConvolution(
4116     RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
4117     ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
4118   FailureOr<Conv1DGenerator> conv1dGen = Conv1DGenerator::create(rewriter, op);
4119   if (failed(conv1dGen))
4120     return failure();
   // First success wins; each generator bails with a match failure otherwise.
4121   auto res = conv1dGen->generateNonChanneledConv();
4122   if (succeeded(res))
4123     return res;
4124   res = conv1dGen->generateNwcConv();
4125   if (succeeded(res))
4126     return res;
4127   res = conv1dGen->generateNcwConv();
4128   if (succeeded(res))
4129     return res;
4130   res = conv1dGen->generateNwcPooling();
4131   if (succeeded(res))
4132     return res;
4133   res = conv1dGen->generateNcwPooling();
4134   if (succeeded(res))
4135     return res;
4136
4137   // Only depthwise 1D NWC convs are left - these can be vectorized using masks
4138   // and scalable vectors. Note that ATM the only dim that can be dynamic (i.e.
4139   // masked/scalable) is the channel dim (i.e. the trailing dim).
4140   uint64_t vecChDimSize = ShapedType::kDynamic;
4141   bool vecChDimScalableFlag = false;
4142   if (!inputVecSizes.empty()) {
4143     // Only use the input vector size corresponding to the channel dim. Other
4144     // vector dims will be inferred from the Ops.
     // NOTE(review): the assert condition (original lines 4145-4146) was lost
     // in extraction; the string below is that assert's failure message.
4147            "Not a 1D depthwise conv!");
4148     size_t chDimIdx = 0;
     // NOTE(review): the `if` conditions on original lines 4149 and 4151 were
     // lost in extraction; they presumably distinguish the depthwise conv
     // flavor (e.g. NWC vs CNW) to select the channel-dim index -- confirm
     // against the upstream source.
4150       chDimIdx = 2;
4152       chDimIdx = 1;
4153
4154     vecChDimSize = inputVecSizes[chDimIdx];
4155     vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
4156   }
4157   return conv1dGen->generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
4158                                         flatten1DDepthwiseConv);
4159 }
4160
4161struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
  // NOTE(review): the content of original line 4163 was lost in extraction;
  // it is presumably the inherited-constructor declaration
  //   using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
  // -- confirm against the upstream source.
4163
  // Rewrites a LinalgOp with convolution semantics into vector form, using
  // default (empty) input vector sizes.
4164  LogicalResult matchAndRewrite(LinalgOp op,
4165                                PatternRewriter &rewriter) const override {
4166    FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
4167    if (failed(resultOrFail))
4168      return failure();
4169    Operation *newOp = *resultOrFail;
    // No results to replace: just erase the original op.
4170    if (newOp->getNumResults() == 0) {
4171      rewriter.eraseOp(op.getOperation());
4172      return success();
4173    }
4174    assert(newOp->getNumResults() == 1 && "expected single result");
4175    rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
4176    return success();
4177  }
4178};
4179
4181 RewritePatternSet &patterns, PatternBenefit benefit) {
4182 patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
4183}
return success()
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static std::optional< VectorShape > vectorShape(Type type)
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val, VectorType resType)
Checks whether val can be used for calculating a loop invariant index.
static Value insertConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t wSize, int64_t wSizeStep, SmallVectorImpl< Value > &resVals, bool isSingleChanneled)
Helper function to insert the computed result slices.
static SmallVector< bool > getDimsToReduce(LinalgOp linalgOp)
static VectorMemoryAccessKind getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp, LinalgOp &linalgOp, VectorType resType)
Infer the memory access pattern for the input ExtractOp.
static SmallVector< Value > extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input, int64_t nSize, int64_t wSize, int64_t cSize, int64_t kwSize, int strideW, int dilationW, int64_t wSizeStep, bool isSingleChanneled)
Helper function to extract the input slices after filter is unrolled along kw.
VectorMemoryAccessKind
@ Contiguous
@ Gather
@ ScalarBroadcast
static VectorizationHookResult vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, Operation *op, LinalgOp linalgOp, const IRMapping &bvm)
Helper function to vectorize the tensor.extract operations.
static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter, VectorizationState &state, Operation *op, LinalgOp linalgOp)
Helper function to vectorize the index operations of a linalgOp.
static LogicalResult vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp, ArrayRef< int64_t > inputVectorSizes, SmallVectorImpl< Value > &newResults)
Vectorize tensor::InsertSliceOp with:
static FailureOr< Operation * > vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp, ArrayRef< int64_t > inputVecSizes={}, ArrayRef< bool > inputVecScalableFlags={}, bool flatten1DDepthwiseConv=false)
Try to vectorize convOp as a convolution.
static LogicalResult vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Generic vectorization function that rewrites the body of a linalgOp into vector form.
#define MATCH_1D_CONV_POOL_OP(ConvOpTy)
static VectorizationHookResult vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, Operation *op, const IRMapping &bvm, ArrayRef< CustomVectorizationHook > customVectorizationHooks)
Generic vectorization for a single operation op, given already vectorized operands carried by bvm.
static Operation * matchLinalgReduction(OpOperand *outputOperand)
Check whether outputOperand is a reduction with a single combiner operation.
static Value buildVectorWrite(RewriterBase &rewriter, Value value, OpOperand *outputOperand, VectorizationState &state)
Build a vector.transfer_write of value into outputOperand at indices set to all 0; where outputOperan...
static Value getStaticPadVal(Operation *op)
Returns the effective Pad value for the input op, provided it's a scalar.
static SmallVector< Value > extractConvFilterSlices(RewriterBase &rewriter, Location loc, Value filter, int64_t kwSize)
Helper function to extract the filter slices after filter is unrolled along kw.
static bool hasReductionIterator(LinalgOp &op)
Check if op is a linalg.reduce or a linalg.generic that has at least one reduction iterator.
std::function< LogicalResult(Operation *, bool)> CustomVectorizationPrecondition
static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp)
Find the index of the trailing non-unit dim in linalgOp.
static VectorType getCollapsedVecType(VectorType type, ArrayRef< AffineMap > reassociation)
Given the re-associations, "collapses" the input Vector type.
Conv1DOpOrder
Helper enum to represent conv1d input traversal order.
VectorizationHookStatus
Helper data structure to represent the result of vectorization for a single operation.
@ Failure
Op failed to vectorize.
@ NewOp
Op vectorized into a new Op whose results will replace original Op's results.
@ NoReplace
Op vectorized and custom function took care of replacement logic.
static Operation * reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op, Value reduceValue, Value initialValue, const IRMapping &bvm)
Emit reduction operations if the shapes of the value to reduce is different that the result shape.
static OpType getSingleOpOfType(Block &block)
Return the unique instance of OpType in block if it is indeed unique.
std::function< VectorizationHookResult(Operation *, const IRMapping &)> CustomVectorizationHook
static AffineMap reindexIndexingMap(AffineMap map)
Given an indexing map coming from a LinalgOp indexing, restricted to a projectedPermutation,...
static LogicalResult tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract)
Helper function to check if the tensor.extract can be vectorized by the custom hook vectorizeTensorEx...
static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType)
Broadcast value to a vector of shape if possible.
static Value calculateGatherOffset(RewriterBase &rewriter, VectorizationState &state, tensor::ExtractOp extractOp, const IRMapping &bvm)
Calculates the offsets ($index_vec) for vector.gather operations generated from tensor....
static SmallVector< Value > extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res, int64_t nSize, int64_t wSize, int64_t fSize, int64_t wSizeStep, bool isSingleChanneled)
Helper function to extract the result slices after filter is unrolled along kw.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val, bool &foundIndexOp, VectorType resType)
Check whether val could be used for calculating the trailing index for a contiguous load operation.
static VectorizationHookResult vectorizeLinalgYield(RewriterBase &rewriter, Operation *op, const IRMapping &bvm, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl< Value > &newResults)
Helper function to vectorize the terminator of a linalgOp.
static Operation * buildMultiDimReduce(OpBuilder &b, Operation *reduceOp, Value valueToReduce, Value acc, ArrayRef< bool > dimsToMask)
Create MultiDimReductionOp to compute the reduction for reductionOp.
#define mul(a, b)
A dimensional identifier appearing in an affine expression.
Definition AffineExpr.h:223
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
MLIRContext * getContext() const
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumResults() const
unsigned getNumInputs() const
AffineExpr getResult(unsigned idx) const
static AffineMap getFilteredIdentityMap(MLIRContext *ctx, unsigned numDims, llvm::function_ref< bool(AffineDimExpr)> keepDimFilter)
Returns an identity affine map with numDims input dimensions and filtered results using keepDimFilter...
AffineMap dropZeroResults()
Returns the AffineMap resulting from removing "zero" results (constant values == 0) from this map.
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
This class represents an argument of a Block.
Definition Value.h:306
unsigned getArgNumber() const
Returns the number of this argument.
Definition Value.h:318
Block represents an ordered list of Operations.
Definition Block.h:33
OpListType & getOperations()
Definition Block.h:147
RetT walk(FnT &&callback)
Walk all nested operations, blocks (including this block) or regions, depending on the type of callba...
Definition Block.h:318
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition Builders.cpp:391
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:328
IntegerType getI1Type()
Definition Builders.cpp:57
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:270
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:55
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
Definition Builders.cpp:274
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:209
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:566
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition Builders.cpp:461
This class represents an operand of an operation.
Definition Value.h:254
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:376
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:538
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:231
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:433
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
unsigned getNumOperands()
Definition Operation.h:372
operand_iterator operand_end()
Definition Operation.h:401
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:116
result_type_range getResultTypes()
Definition Operation.h:454
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:404
result_range getResults()
Definition Operation.h:441
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:430
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:56
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:369
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
bool hasVectorizationImpl(Operation *)
Return true if there's dedicated logic in the Linalg Vectorizer to vectorize this Op,...
SmallVector< int64_t > getUnPackInverseSrcPerm(linalg::UnPackOp, PackingMetadata &metadata)
Compute inverse permutation for the source tensor (i.e.
bool allIndexingsAreProjectedPermutation(LinalgOp op)
Check if all indexing maps are projected permutations.
Definition Utils.cpp:197
FailureOr< VectorizationResult > vectorize(RewriterBase &rewriter, Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false, bool assumeDynamicDimsMatchVecSizes=false, bool createNamedContraction=false)
Returns a VectorizationResult containing the results of the vectorized op, or failure if the transfor...
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
Definition Utils.cpp:236
bool isaConvolutionOpInterface(LinalgOp linalgOp, bool allowEmptyConvolvedDims=false)
Checks whether linalgOp conforms to ConvolutionOpInterface.
void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns for vectorizing low-D convolution ops.
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
Definition Utils.cpp:217
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp)
Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeOpPrecondition(Operation *op, ArrayRef< int64_t > inputVectorSizes={}, ArrayRef< bool > inputScalableVecDims={}, bool vectorizeNDExtract=false, bool flatten1DDepthwiseConv=false)
Return success if the operation can be vectorized.
SmallVector< int64_t > getPackInverseDestPerm(linalg::PackOp packOp, PackingMetadata &metadata)
Compute inverse permutation for the destination tensor (i.e.
bool isaConvolutionOpOfType(LinalgOp op)
Returns true if the linalg op is a convolution op of type ConvOpTy.
Definition Utils.h:126
std::optional< vector::CombiningKind > getCombinerOpKind(Operation *combinerOp)
Return vector::CombiningKind for the given op.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Definition SCF.cpp:748
std::enable_if_t<!is_complex< V >::value, V > readValue(char **linePtr)
Returns an element-value of non-complex type.
Definition File.h:43
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
LogicalResult isValidMaskedInputVector(ArrayRef< int64_t > shape, ArrayRef< int64_t > inputVectorSizes)
Returns success if inputVectorSizes is a valid masking configuraion for given shape,...
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
Operation * createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore, Value dest, SmallVector< Value > writeIndices={}, bool useInBoundsInsteadOfMasking=false)
Create a TransferWriteOp of vecToStore into dest.
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source, const VectorType &vecToReadTy, std::optional< Value > padValue=std::nullopt, bool useInBoundsInsteadOfMasking=false)
Creates a TransferReadOp from source.
SmallVector< OpFoldResult > getMixedSizesXfer(bool hasTensorSemantics, Operation *xfer, RewriterBase &rewriter)
A wrapper for getMixedSizes for vector.transfer_read and vector.transfer_write Ops (for source and de...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map)
Return the reverse map of a projected permutation where the projected dimensions are transformed into...
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:125
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:139
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition AffineMap.h:675
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
VectorizationHookResult contains the vectorized op returned from a CustomVectorizationHook.
enum VectorizationHookStatus status
Return status from vectorizing the current op.
Operation * newOp
New vectorized operation to replace the current op.
ArrayRef< int64_t > getCanonicalVecShape() const
Returns the canonical vector shape used to vectorize the iteration space.
LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< int64_t > inputVectorSizes, ArrayRef< bool > inputScalableVecDims, bool assumeDynamicDimsMatchVecSizes=false)
Initializes the vectorization state, including the computation of the canonical vector shape for vect...
Operation * maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp, std::optional< AffineMap > maybeIndexingMap=std::nullopt)
Masks an operation with the canonical vector mask if the operation needs masking.
VectorType getCanonicalVecType(Type elementType, std::optional< AffineMap > dimPermutation=std::nullopt) const
Returns a vector type of the provided elementType with the canonical vector shape and the correspondi...
ArrayRef< bool > getScalableVecDims() const
Returns the vector dimensions that are scalable in the canonical vector shape.
VectorizationState(RewriterBase &rewriter)
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(vector::TransferReadOp xferOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const override