1//===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the linalg dialect Vectorization transformations.
10//
11//===----------------------------------------------------------------------===//
13
27#include "mlir/IR/AffineExpr.h"
28#include "mlir/IR/AffineMap.h"
29#include "mlir/IR/Builders.h"
34#include "mlir/IR/Value.h"
35#include "mlir/Support/LLVM.h"
37#include "llvm/ADT/STLExtras.h"
38#include "llvm/ADT/Sequence.h"
39#include "llvm/ADT/SmallVector.h"
40#include "llvm/ADT/TypeSwitch.h"
41#include "llvm/Support/DebugLog.h"
42#include "llvm/Support/InterleavedRange.h"
43#include "llvm/Support/MathExtras.h"
44#include "llvm/Support/raw_ostream.h"
45#include <optional>
46
47using namespace mlir;
48using namespace mlir::linalg;
49
50#define DEBUG_TYPE "linalg-vectorization"
51
52/// Try to vectorize `convOp` as a convolution.
53static FailureOr<Operation *>
54vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
55 ArrayRef<int64_t> inputVecSizes = {},
56 ArrayRef<bool> inputVecScalableFlags = {},
57 bool flatten1DDepthwiseConv = false);
58
59/// Vectorize tensor::InsertSliceOp with:
60/// * vector::TransferReadOp + vector::TransferWriteOp
61/// The vector sizes are either:
62/// * user-provided in `inputVectorSizes`, or
63/// * inferred from the static dims in the input and output tensors.
64/// Bails out if:
65/// * vector sizes are not user-provided, and
66/// * at least one dim is dynamic (in both the input and output tensors).
67///
68/// Before:
69/// !t_in_type = tensor<1x2x3xf32>
70/// !t_out_type = tensor<9x8x7x1x2x3xf32>
71/// !v_type = vector<1x2x3xf32>
72/// %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
73/// into !t_out_type
74/// After:
75/// %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
76/// %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
77static LogicalResult
78vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
79 ArrayRef<int64_t> inputVectorSizes,
80 SmallVectorImpl<Value> &newResults);
81
82/// Returns the effective Pad value for the input op, provided it's a scalar.
83///
84/// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
85/// this Op performs padding, retrieve the padding value provided that it's
86/// a scalar and static/fixed for all the padded values. Returns an empty value
87/// otherwise.
88static Value getStaticPadVal(Operation *op);
89
90/// Return the unique instance of OpType in `block` if it is indeed unique.
91/// Return null if none or more than 1 instances exist.
92template <typename OpType>
93static OpType getSingleOpOfType(Block &block) {
94 OpType res;
95 block.walk([&](OpType op) {
96 if (res) {
97 res = nullptr;
98 return WalkResult::interrupt();
99 }
100 res = op;
101 return WalkResult::advance();
102 });
103 return res;
104}
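// Example (illustrative, hypothetical call site): fetch the unique arith.addf
// in a linalg.generic body, or a null op if there are zero or several:
//   arith::AddFOp addf = getSingleOpOfType<arith::AddFOp>(*linalgOp.getBlock());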
105
106/// Helper function to extract the input slices after filter is unrolled along
107/// kw.
108static SmallVector<Value>
109extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input,
110                       int64_t nSize, int64_t wSize, int64_t cSize,
111 int64_t kwSize, int strideW, int dilationW,
112 int64_t wSizeStep, bool isSingleChanneled) {
113  SmallVector<Value> result;
114  if (isSingleChanneled) {
115 // Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled
116 // convolution.
117 SmallVector<int64_t> sizes = {wSizeStep};
118 SmallVector<int64_t> strides = {1};
119 for (int64_t kw = 0; kw < kwSize; ++kw) {
120 for (int64_t w = 0; w < wSize; w += wSizeStep) {
121 result.push_back(vector::ExtractStridedSliceOp::create(
122 rewriter, loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw}, sizes,
123 strides));
124 }
125 }
126 } else {
127 // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0]
128 // for channeled convolution.
129 SmallVector<int64_t> sizes = {nSize, wSizeStep, cSize};
130 SmallVector<int64_t> strides = {1, 1, 1};
131 for (int64_t kw = 0; kw < kwSize; ++kw) {
132 for (int64_t w = 0; w < wSize; w += wSizeStep) {
133 result.push_back(vector::ExtractStridedSliceOp::create(
134 rewriter, loc, input,
135 /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
136 sizes, strides));
137 }
138 }
139 }
140 return result;
141}
142
143/// Helper function to extract the filter slices after filter is unrolled along
144/// kw.
145static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter,
146                                                  Location loc, Value filter,
147 int64_t kwSize) {
148  SmallVector<Value> result;
149  // Extract rhs slice of size [{c, f} for channeled convolutions and {1} for
150  // non-channeled convolution] @ [kw].
151 for (int64_t kw = 0; kw < kwSize; ++kw) {
152 result.push_back(vector::ExtractOp::create(
153 rewriter, loc, filter, /*offsets=*/ArrayRef<int64_t>{kw}));
154 }
155 return result;
156}
157
158/// Helper function to extract the result slices after filter is unrolled along
159/// kw.
160static SmallVector<Value>
161extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res,
162                        int64_t nSize, int64_t wSize, int64_t fSize,
163                        int64_t wSizeStep, bool isSingleChanneled) {
164  SmallVector<Value> result;
165  if (isSingleChanneled) {
166 // Extract res slice: {wSizeStep} @ [w] for non-channeled convolution.
167 SmallVector<int64_t> sizes = {wSizeStep};
168 SmallVector<int64_t> strides = {1};
169 for (int64_t w = 0; w < wSize; w += wSizeStep) {
170 result.push_back(vector::ExtractStridedSliceOp::create(
171 rewriter, loc, res, /*offsets=*/ArrayRef<int64_t>{w}, sizes,
172 strides));
173 }
174 } else {
175 // Extract res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
176 // convolution.
177 SmallVector<int64_t> sizes = {nSize, wSizeStep, fSize};
178 SmallVector<int64_t> strides = {1, 1, 1};
179 for (int64_t w = 0; w < wSize; w += wSizeStep) {
180 result.push_back(vector::ExtractStridedSliceOp::create(
181 rewriter, loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes,
182 strides));
183 }
184 }
185 return result;
186}
187
188/// Helper function to insert the computed result slices.
189static Value insertConvResultSlices(RewriterBase &rewriter, Location loc,
190                                    Value res, int64_t wSize, int64_t wSizeStep,
191 SmallVectorImpl<Value> &resVals,
192 bool isSingleChanneled) {
193
194 if (isSingleChanneled) {
195 // Write back res slice: {wSizeStep} @ [w] for non-channeled convolution.
196 // This does not depend on kw.
197 SmallVector<int64_t> strides = {1};
198 for (int64_t w = 0; w < wSize; w += wSizeStep) {
199 res = vector::InsertStridedSliceOp::create(
200 rewriter, loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w},
201 strides);
202 }
203 } else {
204 // Write back res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
205 // convolution. This does not depend on kw.
206 SmallVector<int64_t> strides = {1, 1, 1};
207 for (int64_t w = 0; w < wSize; w += wSizeStep) {
208 res = vector::InsertStridedSliceOp::create(
209 rewriter, loc, resVals[w], res,
210 /*offsets=*/ArrayRef<int64_t>{0, w, 0}, strides);
211 }
212 }
213 return res;
214}
215
216/// Contains the vectorization state and related methods used across the
217/// vectorization process of a given operation.
218struct VectorizationState {
219  VectorizationState(RewriterBase &rewriter) : rewriterGuard(rewriter) {}
220
221 /// Initializes the vectorization state, including the computation of the
222 /// canonical vector shape for vectorization.
223 LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
224 ArrayRef<int64_t> inputVectorSizes,
225 ArrayRef<bool> inputScalableVecDims,
226 bool assumeDynamicDimsMatchVecSizes = false);
227
228 /// Returns the canonical vector shape used to vectorize the iteration space.
229 ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
230
231 /// Returns the vector dimensions that are scalable in the canonical vector
232 /// shape.
233 ArrayRef<bool> getScalableVecDims() const { return scalableVecDims; }
234
235 /// Returns a vector type of the provided `elementType` with the canonical
236 /// vector shape and the corresponding fixed/scalable dimensions bit. If
237 /// `dimPermutation` is provided, the canonical vector dimensions are permuted
238 /// accordingly.
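  /// For example (illustrative only), with canonical vector shape (8, [16], 4)
  /// (where [16] is scalable) and `dimPermutation` = (d0, d1, d2) -> (d2, d0),
  /// an f32 element type yields vector<4x8xf32>; with no permutation it yields
  /// vector<8x[16]x4xf32>.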
239  VectorType getCanonicalVecType(
240      Type elementType,
241      std::optional<AffineMap> dimPermutation = std::nullopt) const {
242    SmallVector<int64_t> vectorShape;
243    SmallVector<bool> scalableDims;
244    if (dimPermutation.has_value()) {
245      vectorShape =
246          applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
247 scalableDims =
248 applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
249 } else {
250 vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
251 scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
252 }
253
254 return VectorType::get(vectorShape, elementType, scalableDims);
255 }
256
257 /// Masks an operation with the canonical vector mask if the operation needs
258 /// masking. Returns the masked operation or the original operation if masking
259 /// is not needed. If provided, the canonical mask for this operation is
260 /// permuted using `maybeIndexingMap`.
261 Operation *
262 maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
263 std::optional<AffineMap> maybeIndexingMap = std::nullopt);
264
265private:
266 /// Initializes the iteration space static sizes using the Linalg op
267 /// information. This may become more complicated in the future.
268 void initIterSpaceStaticSizes(LinalgOp linalgOp) {
269 iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
270 }
271
272 /// Generates 'arith.constant' and 'tensor/memref.dim' operations for
273 /// all the static and dynamic dimensions of the iteration space to be
274 /// vectorized and store them in `iterSpaceValueSizes`.
275 LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
276 LinalgOp linalgOp);
277
278 /// Create or retrieve an existing mask value to mask `opToMask` in the
279  /// canonical vector iteration space. If `maybeMaskingMap` is provided, the mask is
280 /// permuted using that permutation map. If a new mask is created, it will be
281 /// cached for future users.
282 Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
283 LinalgOp linalgOp,
284 std::optional<AffineMap> maybeMaskingMap);
285
286 /// Check whether this permutation map can be used for masking. At the
287 /// moment we only make sure that there are no broadcast dimensions, but this
288 /// might change if indexing maps evolve.
289 bool isValidMaskingMap(AffineMap maskingMap) {
290 return maskingMap.getBroadcastDims().empty();
291 }
292
293 /// Turn the input indexing map into a valid masking map.
294 ///
295 /// The input indexing map may contain "zero" results, e.g.:
296 /// (d0, d1, d2, d3) -> (d2, d1, d0, 0)
297 /// Applying such maps to canonical vector shapes like this one:
298 /// (1, 16, 16, 4)
299 /// would yield an invalid vector shape like this:
300 /// (16, 16, 1, 0)
301 /// Instead, drop the broadcasting dims that make no sense for masking perm.
302 /// maps:
303 /// (d0, d1, d2, d3) -> (d2, d1, d0)
304 /// This way, the corresponding vector/mask type will be:
305 /// vector<16x16x1xty>
306 /// rather than this invalid Vector type:
307 /// vector<16x16x1x0xty>
308 AffineMap getMaskingMapFromIndexingMap(AffineMap &indexingMap) {
309 return indexingMap.dropZeroResults();
310 }
311
312 // Holds the compile-time static sizes of the iteration space to vectorize.
313 // Dynamic dimensions are represented using ShapedType::kDynamic.
314 SmallVector<int64_t> iterSpaceStaticSizes;
315
316 /// Holds the value sizes of the iteration space to vectorize. Static
317 /// dimensions are represented by 'arith.constant' and dynamic
318 /// dimensions by 'tensor/memref.dim'.
319 SmallVector<Value> iterSpaceValueSizes;
320
321 /// Holds the canonical vector shape used to vectorize the iteration space.
322 SmallVector<int64_t> canonicalVecShape;
323
324 /// Holds the vector dimensions that are scalable in the canonical vector
325 /// shape.
326 SmallVector<bool> scalableVecDims;
327
328 /// Holds the active masks for permutations of the canonical vector iteration
329 /// space.
330 DenseMap<AffineMap, Value> activeMaskCache;
331
332 /// Global vectorization guard for the incoming rewriter. It's initialized
333 /// when the vectorization state is initialized.
334 OpBuilder::InsertionGuard rewriterGuard;
335
336 /// Do all dynamic dims match the corresponding vector sizes?
337 ///
338 /// When a dynamic tensor/memref dimension matches the corresponding vector
339 /// dimension, masking can be safely skipped, despite the presence of dynamic
340 /// shapes. Use this flag with care and only for cases where you are
341 /// confident the assumption holds.
342 bool assumeDynamicDimsMatchVecSizes = false;
343};
344
345LogicalResult
346VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
347 LinalgOp linalgOp) {
348 // TODO: Support 0-d vectors.
349 for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
350 if (ShapedType::isStatic(iterSpaceStaticSizes[vecDim])) {
351 // Create constant index op for static dimensions.
352 iterSpaceValueSizes.push_back(arith::ConstantIndexOp::create(
353 rewriter, linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
354 continue;
355 }
356
357 // Find an operand defined on this dimension of the iteration space to
358 // extract the runtime dimension size.
359 Value operand;
360 unsigned operandDimPos;
361 if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
362 operandDimPos)))
363 return failure();
364
365 Value dynamicDim =
366 linalgOp.hasPureTensorSemantics()
367 ? (Value)tensor::DimOp::create(rewriter, linalgOp.getLoc(), operand,
368 operandDimPos)
369 : (Value)memref::DimOp::create(rewriter, linalgOp.getLoc(), operand,
370 operandDimPos);
371 iterSpaceValueSizes.push_back(dynamicDim);
372 }
373
374 return success();
375}
376
377/// Initializes the vectorization state, including the computation of the
378/// canonical vector shape for vectorization.
379// TODO: Move this to the constructor when we can remove the failure cases.
380LogicalResult VectorizationState::initState(RewriterBase &rewriter,
381                                            LinalgOp linalgOp,
382 ArrayRef<int64_t> inputVectorSizes,
383 ArrayRef<bool> inputScalableVecDims,
384 bool assumeDimsMatchVec) {
385 assumeDynamicDimsMatchVecSizes = assumeDimsMatchVec;
386 // Initialize the insertion point.
387 rewriter.setInsertionPoint(linalgOp);
388
389 if (!inputVectorSizes.empty()) {
390 // Get the canonical vector shape from the input vector sizes provided. This
391 // path should be taken to vectorize code with dynamic shapes and when using
392 // vector sizes greater than the iteration space sizes.
393 canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
394 scalableVecDims.append(inputScalableVecDims.begin(),
395 inputScalableVecDims.end());
396 } else {
397 // Compute the canonical vector shape from the operation shape. If there are
398 // dynamic shapes, the operation won't be vectorized. We assume all the
399 // vector dimensions are fixed.
400 canonicalVecShape = linalgOp.getStaticLoopRanges();
401 scalableVecDims.append(linalgOp.getNumLoops(), false);
402 }
403
404 LDBG() << "Canonical vector shape: " << llvm::interleaved(canonicalVecShape);
405 LDBG() << "Scalable vector dims: " << llvm::interleaved(scalableVecDims);
406
407 if (ShapedType::isDynamicShape(canonicalVecShape))
408 return failure();
409
410 // Initialize iteration space static sizes.
411 initIterSpaceStaticSizes(linalgOp);
412
413 // Generate 'arith.constant' and 'tensor/memref.dim' operations for
414 // all the static and dynamic dimensions of the iteration space, needed to
415 // compute a mask during vectorization.
416 if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
417 return failure();
418
419 return success();
420}
421
422/// Create or retrieve an existing mask value to mask `opToMask` in the
423/// canonical vector iteration space. If `maybeMaskingMap` is provided, the mask is
424/// permuted using that permutation map. If a new mask is created, it will be cached for
425/// future users.
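/// For example (illustrative only), with canonical vector shape (8, 16), an
/// identity masking map and iteration-space sizes %sz0 and %sz1 that do not
/// statically match the vector sizes, the mask materializes roughly as:
///   %mask = vector.create_mask %sz0, %sz1 : vector<8x16xi1>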
426Value VectorizationState::getOrCreateMaskFor(
427 RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
428 std::optional<AffineMap> maybeMaskingMap) {
429
430 assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) &&
431 "Ill-formed masking map.");
432
433 // No mask is needed if the operation is not maskable.
434 auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
435 if (!maskableOp)
436 return Value();
437
438 assert(!maskableOp.isMasked() &&
439 "Masking an operation that is already masked");
440
441 // If no masking map was provided, use an identity map with the loop dims.
442 assert((!maybeMaskingMap || *maybeMaskingMap) &&
443 "Unexpected null mask permutation map");
444 AffineMap maskingMap =
445 maybeMaskingMap ? *maybeMaskingMap
446                        : AffineMap::getMultiDimIdentityMap(
447                              linalgOp.getNumLoops(), rewriter.getContext());
448
449 LDBG() << "Masking map: " << maskingMap;
450
451 // Return the active mask for the masking map of this operation if it was
452 // already created.
453 auto activeMaskIt = activeMaskCache.find(maskingMap);
454 if (activeMaskIt != activeMaskCache.end()) {
455 Value mask = activeMaskIt->second;
456 LDBG() << "Reusing mask: " << mask;
457 return mask;
458 }
459
460 // Compute permuted projection of the iteration space to be masked and the
461 // corresponding mask shape. If the resulting iteration space dimensions are
462 // static and identical to the mask shape, masking is not needed for this
463 // operation.
464 // TODO: Improve this check. Only projected permutation indexing maps are
465 // supported.
466 SmallVector<int64_t> permutedStaticSizes =
467 applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
468 auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
469 auto maskShape = maskType.getShape();
470
471 LDBG() << "Mask shape: " << llvm::interleaved(maskShape);
472
473 if (permutedStaticSizes == maskShape) {
474 LDBG() << "Masking is not needed for masking map: " << maskingMap;
475 activeMaskCache[maskingMap] = Value();
476 return Value();
477 }
478
479 if (assumeDynamicDimsMatchVecSizes) {
480 // While for _dynamic_ dim sizes we can _assume_ that the corresponding
481 // vector sizes match, we still need to check the _static_ dim sizes. Only
482 // then we can be 100% sure that masking is not required.
483 if (llvm::all_of(llvm::zip(permutedStaticSizes, maskType.getShape()),
484 [](auto it) {
485 return std::get<0>(it) == ShapedType::kDynamic
486 ? true
487 : std::get<0>(it) == std::get<1>(it);
488 })) {
489 LDBG()
490 << "Dynamic + static dimensions match vector sizes, masking is not "
491 "required.";
492 activeMaskCache[maskingMap] = Value();
493 return Value();
494 }
495 }
496
497 // Permute the iteration space value sizes to compute the mask upper bounds.
498 SmallVector<Value> upperBounds =
499 applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
500 assert(!maskShape.empty() && !upperBounds.empty() &&
501 "Masked 0-d vectors are not supported yet");
502
503 // Create the mask based on the dimension values.
504 Value mask = vector::CreateMaskOp::create(rewriter, linalgOp.getLoc(),
505 maskType, upperBounds);
506 LDBG() << "Creating new mask: " << mask;
507 activeMaskCache[maskingMap] = mask;
508 return mask;
509}
510
511Operation *
512VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
513                                  LinalgOp linalgOp,
514 std::optional<AffineMap> maybeIndexingMap) {
515 LDBG() << "Trying to mask: " << *opToMask;
516
517 std::optional<AffineMap> maybeMaskingMap = std::nullopt;
518 if (maybeIndexingMap)
519 maybeMaskingMap = getMaskingMapFromIndexingMap(*maybeIndexingMap);
520
521 // Create or retrieve mask for this operation.
522 Value mask =
523 getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
524
525 if (!mask) {
526 LDBG() << "No mask required";
527 if (assumeDynamicDimsMatchVecSizes) {
528    llvm::TypeSwitch<Operation *>(opToMask)
529        .Case<vector::TransferReadOp, vector::TransferWriteOp>(
530 [&](auto xferOp) {
531 // For vector.transfer_read and vector.transfer_write, there is
532 // also the `in-bounds` attribute that has to be set explicitly
533 // to true. Otherwise, "out-of-bounds" access will be assumed
534 // and masks will be generated while lowering these.
535 LDBG() << "Assuming dynamic dimensions match vector sizes and "
536 "setting their in-bounds to true!";
537 SmallVector<bool> inBoundsMap = xferOp.getInBoundsValues();
538 ShapedType xferType = xferOp.getShapedType();
539 AffineMap permMap = xferOp.getPermutationMap();
540 // Only set the in-bounds values to true for dynamic dims.
541 // Different mechanisms will set these accordingly for the
542 // static dims.
543 for (unsigned i = 0; i < xferOp.getTransferRank(); i++) {
544 auto dimExpr = dyn_cast<AffineDimExpr>(permMap.getResult(i));
545 // Skip broadcast dimensions.
546 if (!dimExpr)
547 continue;
548 unsigned pos = dimExpr.getPosition();
549 if (xferType.isDynamicDim(pos))
550 inBoundsMap[i] = true;
551 }
552 rewriter.modifyOpInPlace(xferOp, [&]() {
553 xferOp.setInBoundsAttr(
554 rewriter.getBoolArrayAttr(inBoundsMap));
555 });
556 })
557 .Default([](Operation *op) {
558 // No-op if the operation is not an xfer read or write.
559 });
560 }
561 return opToMask;
562 }
563
564 // Wrap the operation with a new `vector.mask` and update D-U chain.
565 assert(opToMask && "Expected a valid operation to mask");
566 auto maskOp = cast<vector::MaskOp>(
567 mlir::vector::maskOperation(rewriter, opToMask, mask));
568 Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
569
570 for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
571 rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
572 maskOpTerminator);
573
574 LDBG() << "Masked operation: " << *maskOp;
575 return maskOp;
576}
577
578/// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
579/// projectedPermutation, compress the unused dimensions to serve as a
580/// permutation_map for a vector transfer operation.
581/// For example, given a linalg op such as:
582///
583/// ```
584/// %0 = linalg.generic {
585/// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>,
586/// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>
587/// }
588/// ins(%0 : tensor<2x3x4xf32>)
589/// outs(%1 : tensor<5x6xf32>)
590/// ```
591///
592/// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine
593/// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second
594/// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`.
595static AffineMap reindexIndexingMap(AffineMap map) {
596  assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
597 "expected projected permutation");
598 auto res = compressUnusedDims(map);
599 assert(res.getNumDims() ==
600 (res.getNumResults() - res.getNumOfZeroResults()) &&
601 "expected reindexed map with same number of dims and results");
602 return res;
603}
604
605/// Helper enum to represent conv1d input traversal order.
606enum class Conv1DOpOrder {
607 W, // Corresponds to non-channeled 1D convolution operation.
608 Ncw, // Corresponds to operation that traverses the input in (n, c, w) order.
609 Nwc // Corresponds to operation that traverses the input in (n, w, c) order.
610};
611
612/// Helper data structure to represent the result of vectorization for a single
613/// operation. In certain specific cases, like terminators, we do not want to
614/// propagate.
615enum VectorizationHookStatus {
616  /// Op failed to vectorize.
617  Failure = 0,
618  /// Op vectorized and custom function took care of replacement logic
619  NoReplace,
620  /// Op vectorized into a new Op whose results will replace original Op's
621  /// results.
622  NewOp
623  // TODO: support values if Op vectorized to Many-Ops whose results we need to
624 // aggregate for replacement.
625};
626/// VectorizationHookResult contains the vectorized op returned from a
627/// CustomVectorizationHook. This is an internal implementation detail of
628/// linalg vectorization, not to be confused with VectorizationResult.
629struct VectorizationHookResult {
630  /// Return status from vectorizing the current op.
631  enum VectorizationHookStatus status;
632  /// New vectorized operation to replace the current op.
633  /// Replacement behavior is specified by `status`.
634  Operation *newOp;
635};
636
637std::optional<vector::CombiningKind>
638mlir::linalg::getCombinerOpKind(Operation *combinerOp) {
639  using ::mlir::vector::CombiningKind;
640
641  if (!combinerOp)
642    return std::nullopt;
643  return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(combinerOp)
644      .Case<arith::AddIOp, arith::AddFOp>(
645 [&](auto op) { return CombiningKind::ADD; })
646 .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
647 .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
648 .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
649 .Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; })
650 .Case<arith::MaxNumFOp>([&](auto op) { return CombiningKind::MAXNUMF; })
651 .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
652 .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
653 .Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; })
654 .Case<arith::MinNumFOp>([&](auto op) { return CombiningKind::MINNUMF; })
655 .Case<arith::MulIOp, arith::MulFOp>(
656 [&](auto op) { return CombiningKind::MUL; })
657 .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
658 .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
659 .Default(std::nullopt);
660}
661
662/// Check whether `outputOperand` is a reduction with a single combiner
663/// operation. Return the combiner operation of the reduction. Return
664/// nullptr otherwise. Multiple reduction operations would impose an
665/// ordering between reduction dimensions and are currently unsupported in
666/// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) !=
667/// max(min(X))
668// TODO: use in LinalgOp verification, there is a circular dependency atm.
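/// For example (illustrative only), for a reduction whose body is:
///   ^bb0(%in: f32, %out: f32):
///     %0 = arith.addf %in, %out : f32
///     linalg.yield %0 : f32
/// this returns the `arith.addf` as the unique combiner of the output operand.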
669static Operation *matchLinalgReduction(OpOperand *outputOperand) {
670 auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
671 unsigned outputPos =
672 outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
673 // Only single combiner operations are supported for now.
674 SmallVector<Operation *, 4> combinerOps;
675 if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
676 combinerOps.size() != 1)
677 return nullptr;
678
679 // Return the combiner operation.
680 return combinerOps[0];
681}
682
683/// Broadcast `value` to a vector of `shape` if possible. Return value
684/// otherwise.
685static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
686 auto dstVecType = dyn_cast<VectorType>(dstType);
687 // If no shape to broadcast to, just return `value`.
688 if (dstVecType.getRank() == 0)
689 return value;
690 if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
691      vector::BroadcastableToResult::Success)
692    return value;
693 Location loc = b.getInsertionPoint()->getLoc();
694 return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
695}
696
697/// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
698/// assumes that `reductionOp` has two operands and one of them is the reduction
699/// initial value.
700// Note: this is a true builder that notifies the OpBuilder listener.
701// TODO: Consider moving as a static helper on the ReduceOp.
702static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
703                                      Value valueToReduce, Value acc,
704 ArrayRef<bool> dimsToMask) {
705 auto maybeKind = getCombinerOpKind(reduceOp);
706 assert(maybeKind && "Failed precondition: could not get reduction kind");
707 return vector::MultiDimReductionOp::create(
708 b, reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
709}
710
711static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
712 return llvm::to_vector(
713 llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
714}
715
716/// Check if `op` is a linalg.reduce or a linalg.generic that has at least one
717/// reduction iterator.
718static bool hasReductionIterator(LinalgOp &op) {
719 return isa<linalg::ReduceOp>(op) ||
720 (isa<linalg::GenericOp>(op) &&
721 llvm::any_of(op.getIteratorTypesArray(), isReductionIterator));
722}
723
724/// Build a vector.transfer_write of `value` into `outputOperand` at indices set
725/// to all `0`; where `outputOperand` is an output operand of the LinalgOp
726/// currently being vectorized. If `dest` has null rank, build a memref.store.
727/// Return the produced value or null if no value is produced.
728// Note: this is a true builder that notifies the OpBuilder listener.
729// TODO: Consider moving as a static helper on the ReduceOp.
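// For example (illustrative only, exact attributes elided), writing a
// vector<8x16xf32> value into a rank-2 tensor init operand produces roughly:
//   %c0 = arith.constant 0 : index
//   %w = vector.transfer_write %val, %init[%c0, %c0]
//          : vector<8x16xf32>, tensor<8x16xf32>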
730static Value buildVectorWrite(RewriterBase &rewriter, Value value,
731 OpOperand *outputOperand,
732 VectorizationState &state) {
733 Location loc = value.getLoc();
734 auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
735 AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
736
737 // Compute the vector type of the value to store. This type should be an
738 // identity or projection of the canonical vector type without any permutation
739 // applied, given that any permutation in a transfer write happens as part of
740 // the write itself.
741  AffineMap vectorTypeMap = AffineMap::getFilteredIdentityMap(
742      opOperandMap.getContext(), opOperandMap.getNumInputs(),
743 [&](AffineDimExpr dimExpr) -> bool {
744 return llvm::is_contained(opOperandMap.getResults(), dimExpr);
745 });
746 auto vectorType = state.getCanonicalVecType(
747 getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);
748
749 SmallVector<Value> indices(linalgOp.getRank(outputOperand),
750 arith::ConstantIndexOp::create(rewriter, loc, 0));
751
752 Operation *write;
753 if (vectorType.getRank() > 0) {
754 AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
755 value = broadcastIfNeeded(rewriter, value, vectorType);
756 assert(value.getType() == vectorType && "Incorrect type");
757 write = vector::TransferWriteOp::create(
758 rewriter, loc, value, outputOperand->get(), indices, writeMap);
759 } else {
760 // 0-d case is still special: do not invert the reindexing writeMap.
761 if (!isa<VectorType>(value.getType()))
762 value = vector::BroadcastOp::create(rewriter, loc, vectorType, value);
763 assert(value.getType() == vectorType && "Incorrect type");
764 write = vector::TransferWriteOp::create(rewriter, loc, value,
765 outputOperand->get(), indices);
766 }
767
768 write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
769
770 // If masked, set in-bounds to true. Masking guarantees that the access will
771 // be in-bounds.
772 if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
773 auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
774 SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
775 maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
776 }
777
778 LDBG() << "vectorized op: " << *write;
779 if (!write->getResults().empty())
780 return write->getResult(0);
781 return Value();
782}
783
784// Custom vectorization precondition function type. This is intended to be used
785// with CustomVectorizationHook. Returns success if the corresponding custom
786// hook can vectorize the op.
787using CustomVectorizationPrecondition =
788    std::function<LogicalResult(Operation *, bool)>;
789
790// Custom vectorization function type. Produce a vector form of Operation*
791// assuming all its vectorized operands are already in the IRMapping.
792// Return nullptr if the Operation cannot be vectorized.
793using CustomVectorizationHook =
794    std::function<VectorizationHookResult(Operation *, const IRMapping &)>;
795
796/// Helper function to vectorize the terminator of a `linalgOp`. New result
797/// vector values are appended to `newResults`. Return
798/// VectorizationHookStatus::NoReplace to signal the vectorization algorithm
799/// that it should not try to map produced operations and instead return the
800/// results using the `newResults` vector making them available to the
801/// vectorization algorithm for RAUW. This function is meant to be used as a
802/// CustomVectorizationHook.
803static VectorizationHookResult
804vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
805                     const IRMapping &bvm, VectorizationState &state,
806 LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
807 auto yieldOp = dyn_cast<linalg::YieldOp>(op);
808 if (!yieldOp)
809    return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
810  for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
811 // TODO: Scan for an opportunity for reuse.
812 // TODO: use a map.
813 Value vectorValue = bvm.lookup(output.value());
814 Value newResult =
815 buildVectorWrite(rewriter, vectorValue,
816 linalgOp.getDpsInitOperand(output.index()), state);
817 if (newResult)
818 newResults.push_back(newResult);
819 }
820
821  return VectorizationHookResult{VectorizationHookStatus::NoReplace, nullptr};
822}
823
824/// Helper function to vectorize the index operations of a `linalgOp`. Return
825/// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it
826/// should map the produced operations. This function is meant to be used as a
827/// CustomVectorizationHook.
828static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter,
829                                                    VectorizationState &state,
830 Operation *op,
831 LinalgOp linalgOp) {
832 IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
833 if (!indexOp)
834    return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
835  auto loc = indexOp.getLoc();
836 // Compute the static loop sizes of the index op.
837 ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
838 auto dim = indexOp.getDim();
839 // Compute a one-dimensional index vector for the index op dimension.
840 auto indexVectorType =
841 VectorType::get({targetShape[dim]}, rewriter.getIndexType(),
842 state.getScalableVecDims()[dim]);
843 auto indexSteps = vector::StepOp::create(rewriter, loc, indexVectorType);
844 // Return the one-dimensional index vector if it lives in the trailing
845 // dimension of the iteration space since the vectorization algorithm in this
846 // case can handle the broadcast.
847 if (dim == targetShape.size() - 1)
848    return VectorizationHookResult{VectorizationHookStatus::NewOp, indexSteps};
849  // Otherwise permute the targetShape to move the index dimension last,
850 // broadcast the one-dimensional index vector to the permuted shape, and
851 // finally transpose the broadcasted index vector to undo the permutation.
852 auto permPattern =
853 llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
854 std::swap(permPattern[dim], permPattern.back());
855 auto permMap =
856 AffineMap::getPermutationMap(permPattern, linalgOp.getContext());
857
858 auto broadCastOp = vector::BroadcastOp::create(
859 rewriter, loc,
860 state.getCanonicalVecType(rewriter.getIndexType(), permMap), indexSteps);
861 SmallVector<int64_t> transposition =
862 llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
863 std::swap(transposition.back(), transposition[dim]);
864 auto transposeOp =
865 vector::TransposeOp::create(rewriter, loc, broadCastOp, transposition);
866  return VectorizationHookResult{VectorizationHookStatus::NewOp, transposeOp};
867}
868
869/// Helper function to check if the tensor.extract can be vectorized by the
870/// custom hook vectorizeTensorExtract.
871static LogicalResult
872tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
873  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
874 if (!extractOp)
875 return failure();
876
877 if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
878 return failure();
879
880 // Check the index type, but only for non 0-d tensors (for which we do need
881 // access indices).
882 if (not extractOp.getIndices().empty()) {
883 if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
884 return failure();
885 }
886
887 if (!llvm::all_of(extractOp->getResultTypes(),
888 VectorType::isValidElementType)) {
889 return failure();
890 }
891
892 return success();
893}
894
895/// Calculates the offsets (`$index_vec`) for `vector.gather` operations
896/// generated from `tensor.extract`. The offset is calculated as follows
897/// (example using scalar values):
898///
899/// offset = extractOp.indices[0]
900/// for (i = 1; i < numIndices; i++)
901/// offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
902///
903/// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to:
904/// offset = ( ( 1 ) * 80 + 2 ) * 15 + 3
905static Value calculateGatherOffset(RewriterBase &rewriter,
906                                   VectorizationState &state,
907 tensor::ExtractOp extractOp,
908 const IRMapping &bvm) {
909 // The vector of indices for GatherOp should be shaped as the output vector.
910 auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
911 auto loc = extractOp.getLoc();
912
913 Value offset = broadcastIfNeeded(
914 rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);
915
916 const size_t numIndices = extractOp.getIndices().size();
917 for (size_t i = 1; i < numIndices; i++) {
918 Value dimIdx = arith::ConstantIndexOp::create(rewriter, loc, i);
919
920 auto dimSize = broadcastIfNeeded(
921 rewriter,
922 tensor::DimOp::create(rewriter, loc, extractOp.getTensor(), dimIdx),
923 indexVecType);
924
925 offset = arith::MulIOp::create(rewriter, loc, offset, dimSize);
926
927 auto extractOpIndex = broadcastIfNeeded(
928 rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);
929
930 offset = arith::AddIOp::create(rewriter, loc, extractOpIndex, offset);
931 }
932
933 return offset;
934}
935
937
938/// Find the index of the trailing non-unit dim in linalgOp. This hook is used
939/// when checking whether `tensor.extract` Op (within a `linalg.generic` Op)
940/// represents a contiguous load operation.
941///
942/// Note that when calling this hook, it is assumed that the output vector is
943/// effectively 1D. Other cases (i.e. reading n-D vectors) should've been
944/// labelled as a gather load before entering this method.
945///
946/// Following on from the above, it is assumed that:
947/// * for statically shaped loops, when no masks are used, only one dim is !=
948/// 1 (that's what the shape of the output vector is based on).
949/// * for dynamically shaped loops, there might be more non-unit dims
950/// as the output vector type is user-specified.
951///
952/// TODO: Statically shaped loops + vector masking
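/// For example (illustrative only), static loop ranges (1, 8, 1) yield index 1,
/// while (1, 1, 4) yield index 2 (the trailing dim itself).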
953static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp) {
954 SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges();
955 assert(
956 (linalgOp.hasDynamicShape() ||
957 llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) == 1) &&
958 "For statically shaped Linalg Ops, only one "
959 "non-unit loop dim is expected");
960 assert(!loopRanges.empty() && "Empty loops, nothing to analyse.");
961
962 size_t idx = loopRanges.size() - 1;
963 for (; idx != 0; idx--)
964 if (loopRanges[idx] != 1)
965 break;
966
967 return idx;
968}
969
970/// Checks whether `val` can be used for calculating a loop invariant index.
971static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
972 VectorType resType) {
973
974 assert(((llvm::count_if(resType.getShape(),
975 [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
976 "n-D vectors are not yet supported");
977
978 // Blocks outside _this_ linalg.generic are effectively loop invariant.
979 // However, analysing block arguments for _this_ linalg.generic Op is a bit
980 // tricky. Just bail out in the latter case.
981 // TODO: We could try analysing the corresponding affine map here.
982 auto *block = linalgOp.getBlock();
983 if (isa<BlockArgument>(val))
984 return !llvm::is_contained(block->getArguments(), val);
985
986 Operation *defOp = val.getDefiningOp();
987 assert(defOp && "This is neither a block argument nor an operation result");
988
989 // IndexOp is loop invariant as long as its result remains constant across
990 // iterations. Note that for dynamic shapes, the corresponding dim will also
991 // be conservatively treated as != 1.
992 if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
993 return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
994 }
995
996 auto *ancestor = block->findAncestorOpInBlock(*defOp);
997
998  // Values defined outside `linalgOp` are loop invariant.
999 if (!ancestor)
1000 return true;
1001
1002 // Values defined inside `linalgOp`, which are constant, are loop invariant.
1003 if (isa<arith::ConstantOp>(ancestor))
1004 return true;
1005
1006 bool result = true;
1007 for (auto op : ancestor->getOperands())
1008 result &= isLoopInvariantIdx(linalgOp, op, resType);
1009
1010 return result;
1011}
1012
1013/// Check whether `val` could be used for calculating the trailing index for a
1014/// contiguous load operation.
1015///
1016/// There are currently 3 types of values that are allowed here:
1017/// 1. loop-invariant values,
1018/// 2. values that increment by 1 with every loop iteration,
1019/// 3. results of basic arithmetic operations (linear and continuous)
1020/// involving 1., 2. and 3.
1021/// This method returns True if indeed only such values are used in calculating
1022/// `val.`
1023///
1024/// Additionally, the trailing index for a contiguous load operation should
1025/// increment by 1 with every loop iteration, i.e. be based on:
1026/// * `linalg.index <dim>` ,
1027/// where <dim> is the trailing non-unit dim of the iteration space (this way,
1028/// `linalg.index <dim>` increments by 1 with every loop iteration).
1029/// `foundIndexOp` is updated to `true` when such Op is found.
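/// For example (illustrative only), inside a linalg.generic whose trailing
/// non-unit loop dim is d2, an index computed as:
///   %i = linalg.index 2 : index
///   %c1 = arith.constant 1 : index
///   %idx = arith.addi %i, %c1 : index
/// is accepted (stride-1 in d2), whereas an index derived from arith.muli is
/// rejected by the conservative allow-list below.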
1030static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
1031 bool &foundIndexOp, VectorType resType) {
1032
1033 assert(((llvm::count_if(resType.getShape(),
1034 [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
1035 "n-D vectors are not yet supported");
1036
1037 // Blocks outside _this_ linalg.generic are effectively loop invariant.
1038 // However, analysing block arguments for _this_ linalg.generic Op is a bit
1039 // tricky. Just bail out in the latter case.
1040 // TODO: We could try analysing the corresponding affine map here.
1041 auto *block = linalgOp.getBlock();
1042 if (isa<BlockArgument>(val))
1043 return !llvm::is_contained(block->getArguments(), val);
1044
1045 Operation *defOp = val.getDefiningOp();
1046 assert(defOp && "This is neither a block argument nor an operation result");
1047
1048 if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
1049 auto loopDimThatIncrementsByOne = getTrailingNonUnitLoopDimIdx(linalgOp);
1050
1051 foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
1052 return true;
1053 }
1054
1055 auto *ancestor = block->findAncestorOpInBlock(*defOp);
1056
1057 if (!ancestor)
1058 return false;
1059
1060 // Conservatively reject Ops that could lead to indices with stride other
1061 // than 1.
1062 if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
1063 return false;
1064
1065 bool result = false;
1066 for (auto op : ancestor->getOperands())
1067 result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType);
1068
1069 return result;
1070}
1071
1072/// Infer the memory access pattern for the input ExtractOp
1073///
1074/// Based on the ExtractOp result shape and the access indices, decides whether
1075/// this Op corresponds to a contiguous load (including a broadcast of a scalar)
1076/// or a gather load. When analysing the ExtractOp indices (to identify
1077/// contiguous loads), this method looks for "loop" invariant indices (e.g.
1078/// block arguments) and indices that change linearly (e.g. via `linalg.index`
1079/// Op).
1080///
1081/// Note that it is always safe to use gather load operations for contiguous
1082/// loads (albeit slow), but not vice-versa. When in doubt, bail out and assume
1083/// that `extractOp` is a gather load.
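/// For example (illustrative only), reading into an effectively-1D
/// vector<1x4xf32> from tensor<3x8xf32>:
///   %v = tensor.extract %src[%lead, %trail] : tensor<3x8xf32>
/// is classified as a contiguous load when %trail is based on `linalg.index`
/// of the trailing non-unit loop dim (and %lead is loop invariant), as a
/// scalar broadcast when both indices are loop invariant, and as a gather
/// load otherwise.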
1084static VectorMemoryAccessKind
1085getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
1086 LinalgOp &linalgOp, VectorType resType) {
1087
1088 auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
1089
1090 // 0. Is this a 0-D vector? If yes then this is a scalar broadcast.
1091 if (inputShape.getShape().empty())
1092    return VectorMemoryAccessKind::ScalarBroadcast;
1093
1094 // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false
1095 // otherwise.
1096 bool isOutput1DVector =
1097 (llvm::count_if(resType.getShape(),
1098 [](int64_t dimSize) { return dimSize > 1; }) == 1);
1099 // 1. Assume that it's a gather load when reading non-1D vector.
1100 if (!isOutput1DVector)
1101    return VectorMemoryAccessKind::Gather;
1102
1103 bool leadingIdxsLoopInvariant = true;
1104
1105 // 2. Analyze the leading indices of `extractOp`.
1106 // Look at the way each index is calculated and decide whether it is suitable
1107 // for a contiguous load, i.e. whether it's loop invariant. If not, it's a
1108 // gather load.
1109 auto indices = extractOp.getIndices();
1110 auto leadIndices = indices.drop_back(1);
1111
1112 for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
1113 if (inputShape.getShape()[i] == 1)
1114 continue;
1115
1116 leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType);
1117 }
1118
1119 if (!leadingIdxsLoopInvariant) {
1120 LDBG() << "Found gather load: " << extractOp;
1121    return VectorMemoryAccessKind::Gather;
1122  }
1123
1124 // 3. Analyze the trailing index for `extractOp`.
1125 // At this point we know that the leading indices are loop invariant. This
1126  // means that it is potentially a scalar or a contiguous load. We can decide
1127 // based on the trailing idx.
1128 auto extractOpTrailingIdx = indices.back();
1129
1130 // 3a. Scalar broadcast load
1131 // If the trailing index is loop invariant then this is a scalar load.
1132 if (leadingIdxsLoopInvariant &&
1133 isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) {
1134 LDBG() << "Found scalar broadcast load: " << extractOp;
1135
1136    return VectorMemoryAccessKind::ScalarBroadcast;
1137  }
1138
1139 // 3b. Contiguous loads
1140 // The trailing `extractOp` index should increment with every loop iteration.
1141 // This effectively means that it must be based on the trailing loop index.
1142 // This is what the following bool captures.
1143 bool foundIndexOp = false;
1144 bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
1145 foundIndexOp, resType);
1146 // TODO: Support generating contiguous loads for column vectors - that will
1147  // require adding a permutation map to transfer_read Ops.
1148 bool isRowVector = resType.getShape().back() != 1;
1149 isContiguousLoad &= (foundIndexOp && isRowVector);
1150
1151 if (isContiguousLoad) {
1152    LDBG() << "Found contiguous load: " << extractOp;
1153    return VectorMemoryAccessKind::Contiguous;
1154  }
1155
1156 // 4. Fallback case - gather load.
1157 LDBG() << "Found gather load: " << extractOp;
1158  return VectorMemoryAccessKind::Gather;
1159}
1160
1161/// Helper function to vectorize the tensor.extract operations. Returns
1162/// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it
1163/// should map the produced operations. This function is meant to be used as a
1164/// CustomVectorizationHook.
1165static VectorizationHookResult
1166vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
1167 Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
1168 tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
1169 if (!extractOp)
1170    return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
1171  auto loc = extractOp.getLoc();
1172
1173 // Compute the static loop sizes of the extract op.
1174 auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
1175 auto maskConstantOp = arith::ConstantOp::create(
1176 rewriter, loc,
1177 DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()),
1178 /*value=*/true));
1179 auto passThruConstantOp = arith::ConstantOp::create(
1180 rewriter, loc, rewriter.getZeroAttr(resultType));
1181
1182 // Base indices are currently set to 0. We will need to re-visit if more
1183 // generic scenarios are to be supported.
1184 SmallVector<Value> baseIndices(
1185 extractOp.getIndices().size(),
1186 arith::ConstantIndexOp::create(rewriter, loc, 0));
1187
1188 VectorMemoryAccessKind memAccessKind =
1189 getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType);
1190
1191 // 1. Handle gather access
1192 if (memAccessKind == VectorMemoryAccessKind::Gather) {
1193 Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);
1194
1195 // Generate the gather load
1196 Operation *gatherOp = vector::GatherOp::create(
1197 rewriter, loc, resultType, extractOp.getTensor(), baseIndices, offset,
1198 maskConstantOp, passThruConstantOp);
1199 gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
1200
1201 LDBG() << "Vectorised as gather load: " << extractOp;
1202    return VectorizationHookResult{VectorizationHookStatus::NewOp, gatherOp};
1203  }
1204
1205 // 2. Handle:
1206 // a. scalar loads + broadcast,
1207 // b. contiguous loads.
1208 // Both cases use vector.transfer_read.
1209
1210 // Collect indices for `vector.transfer_read`. At this point, the indices will
1211 // either be scalars or would have been broadcast to vectors matching the
1212 // result type. For indices that are vectors, there are two options:
1213 // * for non-trailing indices, all elements are identical (contiguous
1214 // loads are identified by looking for non-trailing indices that are
1215 // invariant with respect to the corresponding linalg.generic), or
1216 // * for trailing indices, the index vector will contain values with stride
1217 // one, but for `vector.transfer_read` only the first (i.e. 0th) index is
1218 // needed.
1219 // This means that
1220 // * for scalar indices - just re-use it,
1221 // * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
1222 // (0th) element and use that.
1223 SmallVector<Value> transferReadIdxs;
1224 for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
1225 Value idx = bvm.lookup(extractOp.getIndices()[i]);
1226 if (idx.getType().isIndex()) {
1227 transferReadIdxs.push_back(idx);
1228 continue;
1229 }
1230
1231 auto indexAs1dVector = vector::ShapeCastOp::create(
1232 rewriter, loc,
1233 VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
1234 resultType.getScalableDims().back()),
1235 idx);
1236 transferReadIdxs.push_back(
1237 vector::ExtractOp::create(rewriter, loc, indexAs1dVector, 0));
1238 }
1239
1240 // `tensor.extract_element` is always in-bounds, hence the following holds.
1241 auto dstRank = resultType.getRank();
1242 auto srcRank = extractOp.getTensor().getType().getRank();
1243 SmallVector<bool> inBounds(dstRank, true);
1244
1245 // 2a. Handle scalar broadcast access.
1246 if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
1247 MLIRContext *ctx = rewriter.getContext();
1248 SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
1249 auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);
1250
1251 auto transferReadOp = vector::TransferReadOp::create(
1252 rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs,
1253 /*padding=*/std::nullopt, permutationMap, inBounds);
1254
1255 // Mask this broadcasting xfer_read here rather than relying on the generic
1256 // path (the generic path assumes identity masking map, which wouldn't be
1257 // valid here).
1258 SmallVector<int64_t> readMaskShape = {1};
1259 auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
1260 auto allTrue = vector::ConstantMaskOp::create(
1261 rewriter, loc, readMaskType, vector::ConstantMaskKind::AllTrue);
1262 auto *maskedReadOp =
1263 mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);
1264
1265 LDBG() << "Vectorised as scalar broadcast load: " << extractOp;
1266    return VectorizationHookResult{VectorizationHookStatus::NewOp,
1267                                   maskedReadOp};
1268 }
1269
1270 // 2b. Handle contiguous access.
1271 auto permutationMap = AffineMap::getMinorIdentityMap(
1272 srcRank, std::min(dstRank, srcRank), rewriter.getContext());
1273
1274 int32_t rankDiff = dstRank - srcRank;
1275 // When dstRank > srcRank, broadcast the source tensor to the unitary leading
1276 // dims so that the ranks match. This is done by extending the map with 0s.
1277 // For example, for dstRank = 3, srcRank = 2, the following map created
1278 // above:
1279 // (d0, d1) --> (d0, d1)
1280 // is extended as:
1281 // (d0, d1) --> (0, d0, d1)
1282 while (rankDiff > 0) {
1283 permutationMap = permutationMap.insertResult(
1284 mlir::getAffineConstantExpr(0, rewriter.getContext()), 0);
1285 rankDiff--;
1286 }
1287
1288 auto transferReadOp = vector::TransferReadOp::create(
1289 rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs,
1290 /*padding=*/std::nullopt, permutationMap, inBounds);
1291
1292 LDBG() << "Vectorised as contiguous load: " << extractOp;
1293  return VectorizationHookResult{VectorizationHookStatus::NewOp,
1294                                 transferReadOp};
1295}
1296
1297/// Emit reduction operations if the shape of the value to reduce is different
1298/// from the result shape.
1299// Note: this is a true builder that notifies the OpBuilder listener.
1300// TODO: Consider moving as a static helper on the ReduceOp.
1301static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
1302 Value reduceValue, Value initialValue,
1303 const IRMapping &bvm) {
1304 Value reduceVec = bvm.lookup(reduceValue);
1305 Value outputVec = bvm.lookup(initialValue);
1306 auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
1307 auto outputType = dyn_cast<VectorType>(outputVec.getType());
1308  // Reduce only if needed as the value may already have been reduced for
1309 // contraction vectorization.
1310 if (!reduceType ||
1311 (outputType && reduceType.getShape() == outputType.getShape()))
1312 return nullptr;
1313 SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
1314 return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
1315}
1316
1317/// Generic vectorization for a single operation `op`, given already vectorized
1318/// operands carried by `bvm`. Vectorization occurs as follows:
1319/// 1. Try to apply any of the `customVectorizationHooks` and return its
1320/// result on success.
1321/// 2. Clone any constant in the current scope without vectorization: each
1322/// consumer of the constant will later determine the shape to which the
1323/// constant needs to be broadcast to.
1324/// 3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
1325/// of the `customVectorizationHooks` to cover such cases.
1326/// 4. Clone `op` in vector form to a vector of shape prescribed by the first
1327/// operand of maximal rank. Other operands have smaller rank and are
1328/// broadcast accordingly. It is assumed this broadcast is always legal,
1329/// otherwise, it means one of the `customVectorizationHooks` is incorrect.
1330///
1331/// This function assumes all operands of `op` have been vectorized and are in
1332/// the `bvm` mapping. As a consequence, this function is meant to be called on
1333/// a topologically-sorted list of ops.
1334/// This function does not update `bvm` but returns a VectorizationHookStatus
1335/// that instructs the caller what `bvm` update needs to occur.
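/// For example (illustrative only), an `arith.addf %a, %b : f32` inside the
/// body, with `%a` mapped to a vector<8x16xf32> and `%b` mapped to a scalar
/// f32, is rebuilt as `arith.addf` on two vector<8x16xf32> operands after the
/// scalar is broadcast (step 4 above).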
1336static VectorizationHookResult
1337vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
1338 LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
1339 ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
1340 LDBG() << "vectorize op " << *op;
1341
1342 // 1. Try to apply any CustomVectorizationHook.
1343 if (!customVectorizationHooks.empty()) {
1344 for (auto &customFunc : customVectorizationHooks) {
1345 VectorizationHookResult result = customFunc(op, bvm);
1346      if (result.status == VectorizationHookStatus::Failure)
1347        continue;
1348 return result;
1349 }
1350 }
1351
1352 // 2. Constant ops don't get vectorized but rather broadcasted at their users.
1353  // Clone so that the constant is not confined to the linalgOp block.
1354 if (isa<arith::ConstantOp, func::ConstantOp>(op))
1355    return VectorizationHookResult{VectorizationHookStatus::NewOp,
1356                                   rewriter.clone(*op)};
1357
1358 // 3. Only ElementwiseMappable are allowed in the generic vectorization.
1359  if (!OpTrait::hasElementwiseMappableTraits(op))
1360    return VectorizationHookResult{VectorizationHookStatus::Failure, nullptr};
1361
1362 // 4 . Check if the operation is a reduction.
1363 SmallVector<std::pair<Value, Value>> reductionOperands;
1364 for (Value operand : op->getOperands()) {
1365 auto blockArg = dyn_cast<BlockArgument>(operand);
1366 if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
1367 blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
1368 continue;
1369 SmallVector<Operation *> reductionOps;
1370 Value reduceValue = matchReduction(
1371 linalgOp.getRegionOutputArgs(),
1372 blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
1373 if (!reduceValue)
1374 continue;
1375 reductionOperands.push_back(std::make_pair(reduceValue, operand));
1376 }
1377 if (!reductionOperands.empty()) {
1378 assert(reductionOperands.size() == 1);
1379 Operation *reduceOp =
1380 reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
1381 reductionOperands[0].second, bvm);
1382 if (reduceOp)
1383      return VectorizationHookResult{VectorizationHookStatus::NewOp, reduceOp};
1384  }
1385
1386 // 5. Generic vectorization path for ElementwiseMappable ops.
1387 // a. Get the first max ranked shape.
1388 VectorType firstMaxRankedType;
1389 for (Value operand : op->getOperands()) {
1390 auto vecOperand = bvm.lookup(operand);
1391 assert(vecOperand && "Vector operand couldn't be found");
1392
1393 auto vecType = dyn_cast<VectorType>(vecOperand.getType());
1394 if (vecType && (!firstMaxRankedType ||
1395 firstMaxRankedType.getRank() < vecType.getRank()))
1396 firstMaxRankedType = vecType;
1397 }
1398 // b. Broadcast each op if needed.
1399 SmallVector<Value> vecOperands;
1400 for (Value scalarOperand : op->getOperands()) {
1401 Value vecOperand = bvm.lookup(scalarOperand);
1402 assert(vecOperand && "Vector operand couldn't be found");
1403
1404 if (firstMaxRankedType) {
1405 auto vecType = VectorType::get(firstMaxRankedType.getShape(),
1406 getElementTypeOrSelf(vecOperand.getType()),
1407 firstMaxRankedType.getScalableDims());
1408 vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
1409 } else {
1410 vecOperands.push_back(vecOperand);
1411 }
1412 }
1413 // c. for elementwise, the result is the vector with the firstMaxRankedShape
1414 SmallVector<Type> resultTypes;
1415 for (Type resultType : op->getResultTypes()) {
1416 resultTypes.push_back(
1417 firstMaxRankedType
1418 ? VectorType::get(firstMaxRankedType.getShape(), resultType,
1419 firstMaxRankedType.getScalableDims())
1420 : resultType);
1421 }
1422 // d. Build and return the new op.
1423  return VectorizationHookResult{
1424      VectorizationHookStatus::NewOp,
1425      rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
1426 resultTypes, op->getAttrs())};
1427}
1428
1429/// Generic vectorization function that rewrites the body of a `linalgOp` into
1430/// vector form. Generic vectorization proceeds as follows:
1431/// 1. Verify the `linalgOp` has one non-empty region.
1432/// 2. Values defined above the region are mapped to themselves and will be
1433/// broadcasted on a per-need basis by their consumers.
1434/// 3. Each region argument is vectorized into a vector.transfer_read (or 0-d
1435/// load).
1436/// TODO: Reuse opportunities for RAR dependencies.
1437/// 4a. Register CustomVectorizationHook for YieldOp to capture the results.
1438/// 4b. Register CustomVectorizationHook for IndexOp to access the
1439/// iteration indices.
1440/// 5. Iteratively call vectorizeOneOp on the region operations.
1441///
1442/// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is
1443/// performed to the maximal common vector size implied by the `linalgOp`
1444/// iteration space. This eager broadcasting is introduced in the
1445/// permutation_map of the vector.transfer_read operations. The eager
1446/// broadcasting makes it trivial to determine where broadcast, transposes and
1447/// reductions should occur, without any bookkeeping. The tradeoff is that, in
1448/// the absence of good canonicalizations, the amount of work increases.
1449/// This is not deemed a problem as we expect canonicalizations and foldings to
1450/// aggressively clean up the useless work.
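///
/// EXAMPLE (an illustrative sketch; the shapes and the elementwise body are
/// assumptions chosen for exposition, and the indexing maps are elided): an
/// elementwise op such as
///   %0 = linalg.generic {indexing_maps = [...],
///                        iterator_types = ["parallel"]}
///     ins(%a, %b : tensor<8xf32>, tensor<8xf32>)
///     outs(%init : tensor<8xf32>) {
///   ^bb0(%x: f32, %y: f32, %out: f32):
///     %s = arith.addf %x, %y : f32
///     linalg.yield %s : f32
///   } -> tensor<8xf32>
/// is roughly rewritten into
///   %va = vector.transfer_read %a[%c0], ... : tensor<8xf32>, vector<8xf32>
///   %vb = vector.transfer_read %b[%c0], ... : tensor<8xf32>, vector<8xf32>
///   %vs = arith.addf %va, %vb : vector<8xf32>
///   %1 = vector.transfer_write %vs, %init[%c0]
///     : vector<8xf32>, tensor<8xf32>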
1451static LogicalResult
1452vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
1453 LinalgOp linalgOp,
1454 SmallVectorImpl<Value> &newResults) {
1455  LDBG() << "Vectorizing operation as linalg generic";
1456 Block *block = linalgOp.getBlock();
1457
1458 // 2. Values defined above the region can only be broadcast for now. Make them
1459 // map to themselves.
1460 IRMapping bvm;
1461 SetVector<Value> valuesSet;
1462 mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
1463 bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
1464
1465 if (linalgOp.getNumDpsInits() == 0)
1466 return failure();
1467
1468 // 3. Turn all BBArgs into vector.transfer_read / load.
1469 Location loc = linalgOp.getLoc();
1470 Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
1471 for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1472 BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
1473 if (linalgOp.isScalar(opOperand)) {
1474 bvm.map(bbarg, opOperand->get());
1475 continue;
1476 }
1477
1478 // 3.a. Convert the indexing map for this input/output to a transfer read
1479 // permutation map and masking map.
1480 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
1481
1482 AffineMap readMap;
1483 VectorType readType;
1484 Type elemType = getElementTypeOrSelf(opOperand->get());
1485 if (linalgOp.isDpsInput(opOperand)) {
1486 // 3.a.i. For input reads we use the canonical vector shape.
1487 readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
1488 readType = state.getCanonicalVecType(elemType);
1489 } else {
1490 // 3.a.ii. For output reads (iteration-carried dependence, e.g.,
1491 // reductions), the vector shape is computed by mapping the canonical
1492 // vector shape to the output domain and back to the canonical domain.
1493 readMap = inversePermutation(reindexIndexingMap(indexingMap));
1494 readType =
1495 state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
1496 }
1497
1498 SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
1499
1500 Operation *read = vector::TransferReadOp::create(
1501 rewriter, loc, readType, opOperand->get(), indices,
1502 /*padding=*/std::nullopt, readMap);
1503 read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
1504 Value readValue = read->getResult(0);
1505
1506 // 3.b. If masked, set in-bounds to true. Masking guarantees that the access
1507 // will be in-bounds.
1508 if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
1509 SmallVector<bool> inBounds(readType.getRank(), true);
1510 cast<vector::TransferReadOp>(maskOp.getMaskableOp())
1511 .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
1512 }
1513
1514 // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
1515 // TODO: remove this.
1516 if (readType.getRank() == 0)
1517 readValue = vector::ExtractOp::create(rewriter, loc, readValue,
1518                                            ArrayRef<int64_t>());
1519
1520 LDBG() << "New vectorized bbarg(" << bbarg.getArgNumber()
1521 << "): " << readValue;
1522 bvm.map(bbarg, readValue);
1523 bvm.map(opOperand->get(), readValue);
1524 }
1525
1526  SmallVector<CustomVectorizationHook> hooks;
1527  // 4a. Register CustomVectorizationHook for yieldOp.
1528 CustomVectorizationHook vectorizeYield =
1529 [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1530 return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
1531 };
1532 hooks.push_back(vectorizeYield);
1533
1534 // 4b. Register CustomVectorizationHook for indexOp.
1535 CustomVectorizationHook vectorizeIndex =
1536 [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1537 return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
1538 };
1539 hooks.push_back(vectorizeIndex);
1540
1541 // 4c. Register CustomVectorizationHook for extractOp.
1542 CustomVectorizationHook vectorizeExtract =
1543 [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1544 return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
1545 };
1546 hooks.push_back(vectorizeExtract);
1547
1548  // 5. Iteratively call `vectorizeOneOp` on each op in the block.
1549  for (Operation &op : block->getOperations()) {
1550    VectorizationHookResult result =
1551        vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
1552    if (result.status == VectorizationHookStatus::Failure) {
1553      LDBG() << "failed to vectorize: " << op;
1554 return failure();
1555 }
1556 if (result.status == VectorizationHookStatus::NewOp) {
1557 Operation *maybeMaskedOp =
1558 state.maskOperation(rewriter, result.newOp, linalgOp);
1559 LDBG() << "New vector op: " << *maybeMaskedOp;
1560 bvm.map(op.getResults(), maybeMaskedOp->getResults());
1561 }
1562 }
1563
1564 return success();
1565}
1566
1567/// Determines whether a mask for xfer_write is trivially "all true"
1568///
1569/// Given all the inputs required to generate a mask (mask sizes and shapes),
1570/// and an xfer_write operation (write indices and the destination tensor
1571/// shape), determines whether the corresponding mask would be trivially
1572/// foldable (i.e., trivially "all true").
1573///
1574/// Use this method to avoid generating spurious masks and relying on
1575/// vectorization post-processing to remove them.
1576///
1577/// Pre-conditions for a mask to be trivially foldable:
1578/// * All involved shapes (mask + destination tensor) are static.
1579/// * All write indices are constant.
1580/// * All mask sizes are constant (including `arith.constant`).
1581///
1582/// If the pre-conditions are met, the method checks for each destination
1583/// dimension `d`:
1584///   (1) maskShape[d] <= destDimSize[rankDiff + d]
1585///   (2) writeIndex[d] + min(maskSize[d], maskShape[d]) <= destDimSize[rankDiff + d]
1586///
1587/// rankDiff = rank(dest) - rank(mask).
1588///
1589/// This method takes a conservative view: it may return false even if the mask
1590/// is technically foldable.
1591///
1592/// EXAMPLE 1 (trivially foldable, all shapes match, mask sizes match the shape
1593/// of the dest tensor):
1594/// %c0 = arith.constant 0 : index
1595/// %mask = vector.create_mask 5, 1
1596/// vector.mask %mask {
1597///      vector.transfer_write %vecToStore_1, %dest[%c0, %c0]
1598/// {in_bounds = [true, true]}
1599/// : vector<5x1xi32>, tensor<5x1xi32>
1600/// }
1601///
1602/// EXAMPLE 2 (not trivially foldable - vector shape exceeds the tensor shape,
1603/// mask is required to avoid out-of-bounds write):
1604/// %c0 = arith.constant 0 : index
1605/// %mask = vector.create_mask 5, 1
1606/// vector.mask %mask {
1607/// vector.transfer_write %vecToStore_2, %dest[%c0, %c0]
1608/// {in_bounds = [true, true]}
1609/// : vector<8x1xi32>, tensor<5x1xi32>
1610/// }
1611///
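/// Walking through the two examples above with checks (1) and (2) (a worked
/// sketch re-using only the shapes already shown):
///   EXAMPLE 1: maskShape = [5, 1], destShape = [5, 1], writeIndex = [0, 0],
///              maskSize = [5, 1]. Check (1): 5 <= 5, 1 <= 1. Check (2):
///              0 + 5 <= 5, 0 + 1 <= 1. The mask is trivially "all true".
///   EXAMPLE 2: maskShape = [8, 1], destShape = [5, 1]. Check (1) fails for
///              d = 0 (8 > 5), so the mask must be kept.
///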
1612/// TODO: Re-use in createReadOrMaskedRead
1613static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
1614                                    SmallVector<Value> &writeIdxs,
1615 ArrayRef<int64_t> destShape,
1616 ArrayRef<int64_t> maskShape) {
1617 // Masking is unavoidable in the case of dynamic tensors.
1618 if (ShapedType::isDynamicShape(destShape))
1619 return false;
1620
1621 // Collect all constant mask sizes.
1622 SmallVector<int64_t, 4> cstMaskSizes;
1623 for (auto [i, dimSize] : llvm::enumerate(maskSizes)) {
1624 if (auto intSize = getConstantIntValue(dimSize)) {
1625 cstMaskSizes.push_back(*intSize);
1626 }
1627 }
1628
1629 // If any of the mask sizes is non-constant, bail out.
1630 if (cstMaskSizes.size() != maskShape.size())
1631 return false;
1632
1633 // Collect all constant write indices.
1634 SmallVector<int64_t, 4> cstWriteIdxs;
1635 for (auto [i, idx] : llvm::enumerate(writeIdxs)) {
1636 APSInt intVal;
1637 if (matchPattern(idx, m_ConstantInt(&intVal))) {
1638 cstWriteIdxs.push_back(intVal.getSExtValue());
1639 }
1640 }
1641
1642 // If any of the write indices is non-constant, bail out.
1643 if (cstWriteIdxs.size() != destShape.size())
1644 return false;
1645
1646 // Go over all destination dims and check (1) and (2). Take into account that:
1647 // * The number of mask sizes will match the rank of the vector to store.
1648 // This could be lower than the rank of the destination tensor.
1649 // * Mask sizes could be larger than the corresponding mask shape (hence
1650 // `clamp`).
1651 // TODO: The 2nd item should be rejected by the verifier.
1652 int64_t rankDiff = destShape.size() - cstMaskSizes.size();
1653 for (auto [i, idx] : llvm::enumerate(cstMaskSizes)) {
1654 if (/*(1)*/ maskShape[i] > destShape[rankDiff + i] ||
1655 /*(2)*/ destShape[rankDiff + i] <
1656 (std::clamp(cstMaskSizes[i], int64_t(0), maskShape[i]) +
1657 cstWriteIdxs[i]))
1658 return false;
1659 }
1660
1661 return true;
1662}
1663
1664/// Creates an optionally masked TransferWriteOp
1665///
1666/// Generates the following operation:
1667/// %res = vector.transfer_write %vecToStore into %dest
1668///
1669/// If shape(vecToStore) != shape(dest), masking is used to ensure correctness:
1670///
1671/// %mask = vector.create_mask(%destShape) : %vecToStoreShape
1672/// %res = vector.mask %mask {
1673/// vector.transfer_write %vecToStore into %dest
1674/// }
1675///
1676/// The mask shape is identical to `vecToStore` (with the element type ==
1677/// i1), and the mask values are based on the shape of the `dest` tensor.
1678///
1679/// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
1680/// is used instead of masking:
1681///
1682///    in_bounds_flags = (...)
1683///    %res = vector.transfer_write %vecToStore into %dest
1684///        {in_bounds = in_bounds_flags}
1685///
1686///
1687/// Finally, `writeIndices` specifies the offsets to use. If empty, all indices
1688/// are set to 0.
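///
/// EXAMPLE (an illustrative sketch; the shapes are assumptions chosen to show
/// the masked path, not taken from a specific caller): storing a
/// vector<8x1xi32> into a tensor<5x1xi32> at offsets [0, 0] produces
///   %mask = vector.create_mask %c5, %c1 : vector<8x1xi1>
///   %res = vector.mask %mask {
///     vector.transfer_write %vecToStore, %dest[%c0, %c0]
///       : vector<8x1xi32>, tensor<5x1xi32>
///   }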
1689static Operation *
1690createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
1691                         Value dest, SmallVector<Value> writeIndices = {},
1692 bool useInBoundsInsteadOfMasking = false) {
1693
1694 ShapedType destType = cast<ShapedType>(dest.getType());
1695 int64_t destRank = destType.getRank();
1696 auto destShape = destType.getShape();
1697
1698 VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
1699 int64_t vecToStoreRank = vecToStoreType.getRank();
1700 auto vecToStoreShape = vecToStoreType.getShape();
1701
1702 // Compute the in_bounds attribute
1703 SmallVector<bool> inBoundsVal(vecToStoreRank, true);
1704 if (useInBoundsInsteadOfMasking) {
1705 // Update the inBounds attribute.
1706 // FIXME: This computation is too weak - it ignores the write indices.
1707 for (unsigned i = 0; i < vecToStoreRank; i++)
1708 inBoundsVal[i] =
1709 (destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) &&
1710 ShapedType::isStatic(destShape[destRank - vecToStoreRank + i]);
1711 }
1712
1713 // If missing, initialize the write indices to 0.
1714 assert((writeIndices.empty() ||
1715 writeIndices.size() == static_cast<size_t>(destRank)) &&
1716 "Invalid number of write indices!");
1717 if (writeIndices.empty()) {
1718 auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
1719 writeIndices.assign(destRank, zero);
1720 }
1721
1722 // Generate the xfer_write Op
1723 Operation *write = vector::TransferWriteOp::create(builder, loc,
1724 /*vector=*/vecToStore,
1725 /*source=*/dest,
1726 /*indices=*/writeIndices,
1727 /*inBounds=*/inBoundsVal);
1728
1729 // If masking is disabled, exit.
1730 if (useInBoundsInsteadOfMasking)
1731 return write;
1732
1733 // Check if masking is needed. If not, exit.
1734 if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
1735 return write;
1736
1737 // Compute the mask and mask the write Op.
1738 auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type(),
1739 vecToStoreType.getScalableDims());
1740
1741 SmallVector<OpFoldResult> destSizes =
1742 isa<MemRefType>(dest.getType())
1743 ? memref::getMixedSizes(builder, loc, dest)
1744 : tensor::getMixedSizes(builder, loc, dest);
1745 SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
1746 destSizes.end());
1747
1748 if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
1749 vecToStoreShape))
1750 return write;
1751
1752 Value maskForWrite =
1753 builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
1754 return mlir::vector::maskOperation(builder, write, maskForWrite);
1755}
1756
1757/// Given the re-associations, "collapses" the input Vector type
1758///
1759/// This is similar to CollapseShapeOp::inferCollapsedType with two notable
1760/// differences:
1761/// * We can safely assume that there are no dynamic sizes.
1762/// * Scalable flags are updated alongside regular dims.
1763///
1764/// When collapsing scalable flags, conservatively avoids cases with two
1765/// scalable dims. We could re-visit this in the future.
1766///
1767/// EXAMPLE:
1768/// type = vector<4x16x[8]x16xf32>
1769/// reassociation = [(d0, d1, d2, d3) -> (d0, d1),
1770/// (d0, d1, d2, d3) -> (d2, d3)]
1771/// Result:
1772/// vector<64x[128]xf32>
1773static VectorType getCollapsedVecType(VectorType type,
1774 ArrayRef<AffineMap> reassociation) {
1775 assert(type.getNumScalableDims() < 2 &&
1776 "Collapsing more than 1 scalable dim is not supported ATM");
1777
1778 // Use the fact that reassociation is valid to simplify the logic: only use
1779 // each map's rank.
1780 assert(isReassociationValid(reassociation) && "invalid reassociation");
1781
1782 auto shape = type.getShape();
1783 auto scalableFlags = type.getScalableDims();
1784 SmallVector<int64_t> newShape;
1785 SmallVector<bool> newScalableFlags;
1786
1787 unsigned currentDim = 0;
1788 for (AffineMap m : reassociation) {
1789 unsigned dim = m.getNumResults();
1790 int64_t size = 1;
1791 bool flag = false;
1792 for (unsigned d = 0; d < dim; ++d) {
1793 size *= shape[currentDim + d];
1794 flag |= scalableFlags[currentDim + d];
1795 }
1796 newShape.push_back(size);
1797 newScalableFlags.push_back(flag);
1798 currentDim += dim;
1799 }
1800
1801 return VectorType::get(newShape, type.getElementType(), newScalableFlags);
1802}
1803
1804/// Vectorize `linalg.pack` as:
1805/// * xfer_read -> shape_cast -> transpose -> xfer_write
1806///
1807/// The input-vector-sizes specify the _write_ vector sizes (i.e. the vector
1808/// sizes for the xfer_write operation). This is sufficient to infer the other
1809/// vector sizes required here.
1810///
1811/// If the vector sizes are not provided:
1812/// * the vector sizes are determined from the destination tensor static shape.
1813/// * the inBounds attribute is used instead of masking.
1814///
1815/// EXAMPLE (no vector sizes):
1816/// ```
1817///   %pack = linalg.pack %src
1818/// inner_dims_pos = [2, 1]
1819/// inner_tiles = [16, 2]
1820/// into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
1821/// ```
1822/// is vectorized as:
1823/// ```
1824/// %read = vector.transfer_read %src
1825///      : tensor<32x8x16xf32>, vector<32x8x16xf32>
1826/// %sc = vector.shape_cast %read
1827/// : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
1828/// %tr = vector.transpose %sc, [0, 1, 3, 4, 2]
1829/// : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
1830/// %write = vector.transfer_write %tr into %dest
1831/// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
1832/// ```
1833static LogicalResult
1834vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
1835 ArrayRef<int64_t> inputVectorSizes,
1836 SmallVectorImpl<Value> &newResults) {
1837 if (!inputVectorSizes.empty()) {
1838 assert(inputVectorSizes.size() == packOp.getDestRank() &&
1839 "Invalid number of input vector sizes!");
1840 }
1841
1842 // TODO: Introduce a parent class that will handle the insertion point update.
1843 OpBuilder::InsertionGuard g(rewriter);
1844 rewriter.setInsertionPoint(packOp);
1845
1846 Location loc = packOp.getLoc();
1847 std::optional<Value> padValue = packOp.getPaddingValue()
1848 ? std::optional(packOp.getPaddingValue())
1849 : std::nullopt;
1850
1851 SmallVector<int64_t> destShape =
1852 SmallVector<int64_t>(packOp.getDestType().getShape());
1853
1854 // This is just a convenience alias to clearly communicate that the input
1855 // vector sizes determine the _write_ sizes.
1856 ArrayRef<int64_t> &writeVectorSizes = inputVectorSizes;
1857
1858 // In the absence of input-vector-sizes, use the _static_ input tensor shape.
1859 // In addition, use the inBounds attribute instead of masking.
1860 bool useInBoundsInsteadOfMasking = false;
1861 if (writeVectorSizes.empty()) {
1862 if (ShapedType::isDynamicShape(destShape))
1863 return rewriter.notifyMatchFailure(packOp,
1864 "unable to infer vector sizes");
1865
1866 writeVectorSizes = destShape;
1867 useInBoundsInsteadOfMasking = true;
1868 }
1869
1870 // Compute pre-transpose-write-vector-type, i.e. the write vector type
1871 // _before_ the transposition (i.e. before dimension permutation). This is
1872 // done by inverting the permutation/transposition that's part of the Pack
1873 // operation. This type is required to:
1874 // 1) compute the read vector type for masked-read below, and
1875 // 2) generate shape-cast Op below that expands the read vector type.
1876 PackingMetadata packMetadata;
1877  SmallVector<int64_t> preTransposeWriteVecSizes(writeVectorSizes);
1878  auto destInvPermutation = getPackInverseDestPerm(packOp, packMetadata);
1879  applyPermutationToVector(preTransposeWriteVecSizes, destInvPermutation);
1880  auto preTransposeWriteVecType = VectorType::get(
1881      preTransposeWriteVecSizes, packOp.getType().getElementType());
1882
1883  // Compute the vector type for the _read_ operation. This is simply
1884 // pre-transpose-write-vector-type with the dimensions collapsed
1885 // as per the Pack operation.
1886 VectorType readVecType = getCollapsedVecType(
1887 preTransposeWriteVecType,
1888      getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1889          rewriter.getContext(), packMetadata.reassociations)));
1890
1891 // Create masked TransferReadOp.
1892 auto maskedRead = vector::createReadOrMaskedRead(
1893 rewriter, loc, packOp.getSource(), readVecType, padValue,
1894 useInBoundsInsteadOfMasking);
1895
1896 // Create ShapeCastOp.
1897 auto shapeCastOp = vector::ShapeCastOp::create(
1898 rewriter, loc, preTransposeWriteVecType, maskedRead);
1899
1900 // Create TransposeOp.
1901 auto destPermutation = invertPermutationVector(destInvPermutation);
1902 auto transposeOp = vector::TransposeOp::create(
1903 rewriter, loc, shapeCastOp.getResult(), destPermutation);
1904
1905 // Create TransferWriteOp.
1906 Operation *write = createWriteOrMaskedWrite(
1907 rewriter, loc, transposeOp.getResult(), packOp.getDest());
1908 newResults.push_back(write->getResult(0));
1909 return success();
1910}
1911
1912/// Vectorize `linalg.unpack` as:
1913/// * xfer_read -> vector.transpose -> vector.shape_cast -> xfer_write
1914///
1915/// The input-vector-sizes specify the _read_ vector sizes (i.e. the vector
1916/// sizes for the xfer_read operation). This is sufficient to infer the other
1917/// vector sizes required here.
1918///
1919/// If the vector sizes are not provided:
1920/// * the vector sizes are determined from the input tensor static shape.
1921/// * the inBounds attribute is used instead of masking.
1922///
1923/// EXAMPLE (no vector sizes):
1924/// ```
1925/// %unpack = linalg.unpack %src
1926/// inner_dims_pos = [0, 1]
1927/// inner_tiles = [8, 8]
1928/// into %dest : tensor<1x1x8x8xf32> -> tensor<8x8xf32>
1929/// ```
1930/// is vectorized as:
1931/// ```
1932/// %read = vector.transfer_read %src
1933/// : tensor<1x1x8x8xf32>, vector<1x1x8x8xf32>
1934/// %tr = vector.transpose %read, [0, 2, 1, 3]
1935/// : vector<1x1x8x8xf32> to vector<1x8x1x8xf32>
1936/// %sc = vector.shape_cast %tr
1937/// : vector<1x8x1x8xf32> to vector<8x8xf32>
1938/// %vector = vector.transfer_write %sc into %dest
1939/// : vector<8x8xf32>, tensor<8x8xf32>
1940/// ```
1941static LogicalResult
1942vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
1943 ArrayRef<int64_t> inputVectorSizes,
1944 ArrayRef<bool> inputScalableVecDims,
1945 SmallVectorImpl<Value> &newResults) {
1946 if (!inputVectorSizes.empty()) {
1947 assert(inputVectorSizes.size() == unpackOp.getSourceRank() &&
1948 "Invalid number of input vector sizes!");
1949 assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
1950 "Incompatible number of vector sizes and vector scalable flags!");
1951 }
1952
1953 // TODO: Introduce a parent class that will handle the insertion point update.
1954 OpBuilder::InsertionGuard g(rewriter);
1955 rewriter.setInsertionPoint(unpackOp);
1956
1957 RankedTensorType unpackTensorType = unpackOp.getSourceType();
1958
1959 ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
1960 bool useInBoundsInsteadOfMasking = false;
1961
1962 Location loc = unpackOp->getLoc();
1963
1964 // Obtain vector sizes for the read operation.
1965 SmallVector<int64_t> readVectorSizes(inputVectorSizes);
1966 SmallVector<bool> readScalableVectorFlags(inputScalableVecDims);
1967
1968 // In the absence of input-vector-sizes, use the _static_ input tensor shape.
1969 if (inputVectorSizes.empty()) {
1970 if (ShapedType::isDynamicShape(sourceShape))
1971 return rewriter.notifyMatchFailure(unpackOp,
1972 "Unable to infer vector sizes!");
1973
1974 readVectorSizes.assign(sourceShape.begin(), sourceShape.end());
1975 useInBoundsInsteadOfMasking = true;
1976 }
1977
1978 // -- Generate the read operation --
1979 VectorType readVecType =
1980 VectorType::get(readVectorSizes, unpackTensorType.getElementType(),
1981 readScalableVectorFlags);
1982 Value readResult = vector::createReadOrMaskedRead(
1983 rewriter, loc, unpackOp.getSource(), readVecType, std::nullopt,
1984 useInBoundsInsteadOfMasking);
1985
1986 // -- Generate the transpose operation --
1987 PackingMetadata packMetadata;
1988 SmallVector<int64_t> lastDimToInsertPosPerm =
1989 getUnPackInverseSrcPerm(unpackOp, packMetadata);
1990 vector::TransposeOp transposeOp = vector::TransposeOp::create(
1991 rewriter, loc, readResult, lastDimToInsertPosPerm);
1992
1993 // -- Generate the shape_cast operation --
1994 VectorType collapsedVecType = getCollapsedVecType(
1995 transposeOp.getType(),
1996      getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1997          rewriter.getContext(), packMetadata.reassociations)));
1998 vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create(
1999 rewriter, loc, collapsedVecType, transposeOp->getResult(0));
2000
2001 // -- Generate the write operation --
2002 Operation *write = createWriteOrMaskedWrite(
2003 rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
2004 /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
2005
2006 newResults.push_back(write->getResult(0));
2007 return success();
2008}
2009
2010/// Vectorize a `padOp` with (1) static result type, (2) constant padding value
2011/// and (3) all-zero lowPad to
2012/// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
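///
/// EXAMPLE (an illustrative sketch; the shapes and %pad are assumptions chosen
/// for exposition):
///   %res = tensor.pad %src low[0, 0] high[2, 3] {
///   ^bb0(%i: index, %j: index):
///     tensor.yield %pad : f32
///   } : tensor<6x5xf32> to tensor<8x8xf32>
/// is roughly rewritten into
///   %mask = vector.create_mask %c6, %c5 : vector<8x8xi1>
///   %read = vector.mask %mask {
///     vector.transfer_read %src[%c0, %c0], %pad
///       : tensor<6x5xf32>, vector<8x8xf32>
///   }
///   %empty = tensor.empty() : tensor<8x8xf32>
///   %res = vector.transfer_write %read, %empty[%c0, %c0]
///     {in_bounds = [true, true]} : vector<8x8xf32>, tensor<8x8xf32>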
2013static LogicalResult
2014vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
2015 ArrayRef<int64_t> inputVectorSizes,
2016 SmallVectorImpl<Value> &newResults) {
2017 auto padValue = padOp.getConstantPaddingValue();
2018 Location loc = padOp.getLoc();
2019
2020 // TODO: Introduce a parent class that will handle the insertion point update.
2021 OpBuilder::InsertionGuard g(rewriter);
2022 rewriter.setInsertionPoint(padOp);
2023
2024 ReifiedRankedShapedTypeDims reifiedReturnShapes;
2025 LogicalResult status =
2026 cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
2027 .reifyResultShapes(rewriter, reifiedReturnShapes);
2028 (void)status; // prevent unused variable warning on non-assert builds
2029 assert(succeeded(status) && "failed to reify result shapes");
2030 auto readType = VectorType::get(inputVectorSizes, padValue.getType());
2031 auto maskedRead = vector::createReadOrMaskedRead(
2032 rewriter, loc, padOp.getSource(), readType, padValue,
2033 /*useInBoundsInsteadOfMasking=*/false);
2034
2035 // Create Xfer write Op
2036 Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0],
2037 padOp.getResultType().getElementType());
2038 Operation *write = createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest);
2039 newResults.push_back(write->getResult(0));
2040 return success();
2041}
2042
2043// TODO: probably need some extra checks for reduction followed by consumer
2044// ops that may not commute (e.g. linear reduction + non-linear instructions).
2045static LogicalResult reductionPreconditions(LinalgOp op) {
2046 if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
2047 LDBG() << "reduction precondition failed: no reduction iterator";
2048 return failure();
2049 }
2050 for (OpOperand &opOperand : op.getDpsInitsMutable()) {
2051 AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
2052 if (indexingMap.isPermutation())
2053 continue;
2054
2055 Operation *reduceOp = matchLinalgReduction(&opOperand);
2056 if (!reduceOp || !getCombinerOpKind(reduceOp)) {
2057 LDBG() << "reduction precondition failed: reduction detection failed";
2058 return failure();
2059 }
2060 }
2061 return success();
2062}
2063
2064static LogicalResult
2065vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
2066 bool flatten1DDepthwiseConv) {
2067 if (flatten1DDepthwiseConv) {
2068 LDBG() << "Vectorization of flattened convs with dynamic shapes is not "
2069 "supported";
2070 return failure();
2071 }
2072
2073 if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
2074 LDBG() << "Not a 1D depth-wise WC conv, dynamic shapes are not supported";
2075 return failure();
2076 }
2077
2078 // Support dynamic shapes in 1D depthwise convolution, but only in the
2079 // _channel_ dimension.
2080 Value lhs = conv.getDpsInputOperand(0)->get();
2081 ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
2082 auto shapeWithoutCh = lhsShape.drop_back(1);
2083 if (ShapedType::isDynamicShape(shapeWithoutCh)) {
2084 LDBG() << "Dynamically-shaped op vectorization precondition failed: only "
2085 "channel dim can be dynamic";
2086 return failure();
2087 }
2088
2089 return success();
2090}
2091
2092static LogicalResult
2093vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
2094 bool flatten1DDepthwiseConv) {
2095 if (isa<ConvolutionOpInterface>(op.getOperation()))
2096 return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);
2097
2098 if (hasReductionIterator(op))
2099 return reductionPreconditions(op);
2100
2101 // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
2102 // linalg.copy ops and ops that implement ContractionOpInterface for now.
2103 if (!isElementwise(op) &&
2104 !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
2105 op.getOperation()))
2106 return failure();
2107
2108 LDBG() << "Dynamically-shaped op meets vectorization pre-conditions";
2109 return success();
2110}
2111
2112/// This hook considers two cases:
2113/// (1) If the input-vector-sizes are empty, then the vector sizes will be
2114///     inferred. This is only possible when all shapes are static.
2115/// (2) If the input-vector-sizes are non-empty (i.e. user provided), then
2116/// carry out basic sanity-checking.
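///
/// For instance (an illustrative sketch re-using the shapes from the
/// `vectorizeAsTensorUnpackOp` example above): unpacking
/// tensor<1x1x8x8xf32> -> tensor<8x8xf32> requires either no vector sizes at
/// all (everything is static), or exactly 4 sizes, e.g. [1, 1, 8, 8], i.e. one
/// size per dimension of the _source_ (read) tensor.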
2117static LogicalResult
2118vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
2119 ArrayRef<int64_t> inputVectorSizes) {
2120 // If there are no input vector sizes and all shapes are static, there is
2121 // nothing left to check.
2122 if (inputVectorSizes.empty() && unpackOp.getDestType().hasStaticShape() &&
2123 unpackOp.getSourceType().hasStaticShape())
2124 return success();
2125
2126 // The number of input vector sizes must be equal to:
2127 // * read-vector-rank
2128 if (!inputVectorSizes.empty() &&
2129 (inputVectorSizes.size() != unpackOp.getSourceRank())) {
2130 LDBG() << "Incorrect number of input vector sizes";
2131 return failure();
2132 }
2133
2134 // Check the vector sizes for the read operation.
2135  if (failed(vector::isValidMaskedInputVector(
2136          unpackOp.getSourceType().getShape(), inputVectorSizes))) {
2137 LDBG() << "Invalid vector sizes for the read operation";
2138 return failure();
2139 }
2140
2141 return success();
2142}
2143
2144static LogicalResult
2145vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
2146 ArrayRef<int64_t> inputVectorSizes) {
2147
2148 TypedValue<RankedTensorType> source = sliceOp.getSource();
2149 auto sourceType = source.getType();
2150 if (!VectorType::isValidElementType(sourceType.getElementType()))
2151 return failure();
2152
2153 // Get the pad value.
2154 // TransferReadOp (which is used to vectorize InsertSliceOp), requires a
2155 // scalar padding value. Note that:
2156 // * for in-bounds accesses,
2157 // the value is actually irrelevant. There are 2 cases in which xfer.read
2158 // accesses are known to be in-bounds:
2159 // 1. The source shape is static (output vector sizes would be based on
2160 // the source shape and hence all memory accesses would be in-bounds),
2161 // 2. Masking is used, i.e. the output vector sizes are user-provided. In
2162 // this case it is safe to assume that all memory accesses are in-bounds.
2163 //
2164 // When the value is not known and not needed, use 0. Otherwise, bail out.
2165 Value padValue = getStaticPadVal(sliceOp);
2166 bool isOutOfBoundsRead =
2167 !sourceType.hasStaticShape() && inputVectorSizes.empty();
2168
2169 if (!padValue && isOutOfBoundsRead) {
2170 LDBG() << "Failed to get a pad value for out-of-bounds read access";
2171 return failure();
2172 }
2173 return success();
2174}
2175
2176/// Vectorize a named linalg contraction op into:
2177/// vector::TransferReadOp - Reads vectors from the operands
2178/// vector::ContractionOp - Performs contraction
2179/// vector::TransferWriteOp - Write the result vector back to the
2180/// destination
2181/// The operands shapes are preserved and loaded directly into vectors.
2182/// Any further permutations or numerical casting remain within contraction op.
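///
/// EXAMPLE (an illustrative sketch; the matmul shapes and the elided indexing
/// maps are assumptions chosen for exposition):
///   %res = linalg.matmul ins(%A, %B : tensor<4x8xf32>, tensor<8x16xf32>)
///                        outs(%C : tensor<4x16xf32>) -> tensor<4x16xf32>
/// is roughly rewritten into
///   %va = vector.transfer_read %A[...] : tensor<4x8xf32>, vector<4x8xf32>
///   %vb = vector.transfer_read %B[...] : tensor<8x16xf32>, vector<8x16xf32>
///   %vc = vector.transfer_read %C[...] : tensor<4x16xf32>, vector<4x16xf32>
///   %vr = vector.contract {indexing_maps = [...],
///       iterator_types = ["parallel", "parallel", "reduction"],
///       kind = #vector.kind<add>}
///       %va, %vb, %vc : vector<4x8xf32>, vector<8x16xf32> into vector<4x16xf32>
///   %res = vector.transfer_write %vr, %C[...]
///       : vector<4x16xf32>, tensor<4x16xf32>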
2183static LogicalResult
2184vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
2185 LinalgOp linalgOp,
2186 SmallVectorImpl<Value> &newResults) {
2187 Location loc = linalgOp.getLoc();
2188 MLIRContext *ctx = linalgOp.getContext();
2189
2190 // For simplicity, contraction vectorization is limited to linalg named ops.
2191 // Generic op is ignored as not every arbitrary contraction body can be
2192 // expressed by a vector.contract.
2193 if (!isa<ContractionOpInterface>(linalgOp.getOperation()))
2194 return failure();
2195
2196 OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
2197 Operation *reduceOp = matchLinalgReduction(outOperand);
2198 auto maybeKind = getCombinerOpKind(reduceOp);
2199 if (!maybeKind) {
2200 LDBG() << "Failed to determine contraction combining kind.";
2201 return failure();
2202 }
2203
2204 // Check that all dimensions are present in the input operands.
2205 // Arbitrary broadcasts are not supported by the vector contraction.
2206 // Broadcasts are expected to be decomposed before vectorization.
2207 AffineMap lhsMap = linalgOp.getIndexingMapsArray()[0];
2208 AffineMap rhsMap = linalgOp.getIndexingMapsArray()[1];
2209 if (getUnusedDimsBitVector({lhsMap, rhsMap}).any()) {
2210 LDBG() << "Contractions with broadcasts are not supported.";
2211 return failure();
2212 }
2213
2214 // Load operands.
2215 SmallVector<Value> vecOperands;
2216 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2217 // The operand vector shape is computed by mapping the canonical vector
2218 // shape to the operand's domain. Further permutations are left as a part of
2219 // the contraction.
2220 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
2221 AffineMap readMap = AffineMap::getMultiDimIdentityMap(
2222 indexingMap.getNumResults(), rewriter.getContext());
2223 Type elemType = getElementTypeOrSelf(opOperand.get());
2224 VectorType readType =
2225 state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
2226
2228 rewriter, loc, opOperand.get(), readType,
2229 /*padding=*/arith::getZeroConstant(rewriter, loc, elemType),
2230 /*useInBoundsInsteadOfMasking=*/false);
2231 vecOperands.push_back(read);
2232 }
2233
2234 // Remap iterators from linalg to vector.
2235 SmallVector<Attribute> iterAttrs;
2236 auto iterators = linalgOp.getIteratorTypesArray();
2237 for (utils::IteratorType iter : iterators) {
2238 auto vecIter = iter == utils::IteratorType::parallel
2239 ? vector::IteratorType::parallel
2240 : vector::IteratorType::reduction;
2241 iterAttrs.push_back(vector::IteratorTypeAttr::get(ctx, vecIter));
2242 }
2243
2244 // Create contraction.
2245 Operation *contractOp = vector::ContractionOp::create(
2246 rewriter, loc, /*lhs=*/vecOperands[0],
2247 /*rhs=*/vecOperands[1], /*acc=*/vecOperands[2],
2248 linalgOp.getIndexingMaps(), rewriter.getArrayAttr(iterAttrs), *maybeKind);
2249 contractOp = state.maskOperation(rewriter, contractOp, linalgOp);
2250
2251 // Store result.
2252 Operation *write = createWriteOrMaskedWrite(
2253 rewriter, loc, contractOp->getResult(0), outOperand->get());
2254
2255 // Finalize.
2256 if (!write->getResults().empty())
2257 newResults.push_back(write->getResult(0));
2258
2259 return success();
2260}
2261
2262namespace {
2263enum class ConvOperationKind { Conv, Pool };
2264} // namespace
2265
2266static bool isCastOfBlockArgument(Operation *op) {
2267 return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
2268 isa<BlockArgument>(op->getOperand(0));
2269}
2270
2271// Returns the ConvOperationKind of the op using reduceOp of the generic
2272// payload. If it is neither a convolution nor a pooling, it returns
2273// std::nullopt.
2274//
2275// If (region has 2 ops (reduction + yield) or 3 ops (extension + reduction
2276// + yield) and rhs is not used) then it is the body of a pooling
2277// If conv, check for single `mul` predecessor. The `mul` operands must be
2278// block arguments or extension of block arguments.
2279// Otherwise, check for one or zero `ext` predecessor. The `ext` operands
2280// must be block arguments or extension of block arguments.
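//
// EXAMPLE (an illustrative sketch; the value names are assumptions):
//  * Convolution body:  %m = arith.mulf %in, %filter : f32
//                       %r = arith.addf %m, %acc : f32
//    The reduction's single non-block-argument operand is produced by a
//    `mul`, so this is ConvOperationKind::Conv.
//  * Pooling body:      %r = arith.maximumf %in, %acc : f32
//    Both reduction operands are block arguments (case 2 below), so this is
//    ConvOperationKind::Pool.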
2281static std::optional<ConvOperationKind>
2282getConvOperationKind(Operation *reduceOp) {
2283 int numBlockArguments =
2284 llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
2285
2286 switch (numBlockArguments) {
2287 case 1: {
2288 // Will be convolution if feeder is a MulOp.
2289 // A strength reduced version of MulOp for i1 type is AndOp which is also
2290 // supported. Otherwise, it can be pooling. This strength reduction logic
2291 // is in `buildBinaryFn` helper in the Linalg dialect.
2292 auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
2293 llvm::IsaPred<BlockArgument>);
2294 assert(feedValIt != reduceOp->operand_end() &&
2295 "Expected a non-block argument operand");
2296 Operation *feedOp = (*feedValIt).getDefiningOp();
2297 if (isCastOfBlockArgument(feedOp)) {
2298 return ConvOperationKind::Pool;
2299 }
2300
2301 if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
2302 (isa<arith::AndIOp>(feedOp) &&
2303 feedOp->getResultTypes()[0].isInteger(1))) &&
2304 llvm::all_of(feedOp->getOperands(), [](Value v) {
2305 if (isa<BlockArgument>(v))
2306 return true;
2307 if (Operation *op = v.getDefiningOp())
2308 return isCastOfBlockArgument(op);
2309 return false;
2310 }))) {
2311 return std::nullopt;
2312 }
2313
2314 return ConvOperationKind::Conv;
2315 }
2316 case 2:
2317 // Must be pooling
2318 return ConvOperationKind::Pool;
2319 default:
2320 return std::nullopt;
2321 }
2322}
2323
2324static bool isSupportedPoolKind(vector::CombiningKind kind) {
2325 switch (kind) {
2326 case vector::CombiningKind::ADD:
2327 case vector::CombiningKind::MAXNUMF:
2328 case vector::CombiningKind::MAXIMUMF:
2329 case vector::CombiningKind::MAXSI:
2330 case vector::CombiningKind::MAXUI:
2331 case vector::CombiningKind::MINNUMF:
2332 case vector::CombiningKind::MINIMUMF:
2333 case vector::CombiningKind::MINSI:
2334 case vector::CombiningKind::MINUI:
2335 return true;
2336 default:
2337 return false;
2338 }
2339}
2340
2341static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
2342 auto getOperandType = [&](auto operand) {
2343 return dyn_cast<ShapedType>((operand->get()).getType());
2344 };
2345 ShapedType lhsShapedType = getOperandType(convOp.getDpsInputOperand(0));
2346 ShapedType rhsShapedType = getOperandType(convOp.getDpsInputOperand(1));
2347 ShapedType resShapedType = getOperandType(convOp.getDpsInitOperand(0));
2348 // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
2349 // (non-channeled convolution -> LHS and RHS both have single dimensions).
2350 // Note that this also ensures 2D and 3D convolutions are rejected.
2351 if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
2352 (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
2353 return failure();
2354
2355 Operation *reduceOp = matchLinalgReduction(convOp.getDpsInitOperand(0));
2356 if (!reduceOp)
2357 return failure();
2358
2359 auto maybeOper = getConvOperationKind(reduceOp);
2360 if (!maybeOper.has_value())
2361 return failure();
2362
2363 auto maybeKind = getCombinerOpKind(reduceOp);
2364 // Typically convolution will have a `Add` CombiningKind but for i1 type it
2365 // can get strength reduced to `OR` which is also supported. This strength
2366 // reduction logic is in `buildBinaryFn` helper in the Linalg dialect.
2367 if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
2368 *maybeKind != vector::CombiningKind::OR) &&
2369 (*maybeOper != ConvOperationKind::Pool ||
2370 !isSupportedPoolKind(*maybeKind)))) {
2371 return failure();
2372 }
2373
2374 auto rhsRank = rhsShapedType.getRank();
2375 if (*maybeOper == ConvOperationKind::Pool) {
2376 if (rhsRank != 1)
2377 return failure();
2378 } else {
2379 if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
2380 return failure();
2381 }
2382
2383 return success();
2384}
2385
2386static LogicalResult vectorizeLinalgOpPrecondition(
2387 LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
2388 bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
2389 // tensor with dimension of 0 cannot be vectorized.
2390 if (llvm::any_of(linalgOp->getOpOperands(), [&](OpOperand &operand) {
2391 return llvm::is_contained(linalgOp.getShape(&operand), 0);
2392 }))
2393 return failure();
2394 // Check API contract for input vector sizes.
2395 if (!inputVectorSizes.empty() &&
2396 failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
2397 inputVectorSizes)))
2398 return failure();
2399
2400 if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition(
2401 linalgOp, flatten1DDepthwiseConv))) {
2402 LDBG() << "Dynamically-shaped op failed vectorization pre-conditions";
2403 return failure();
2404 }
2405
2406 SmallVector<CustomVectorizationPrecondition> customPreconditions;
2407
2408 // Register CustomVectorizationPrecondition for extractOp.
2409 customPreconditions.push_back(tensorExtractVectorizationPrecondition);
2410
2411 // All types in the body should be a supported element type for VectorType.
2412 for (Operation &innerOp : linalgOp->getRegion(0).front()) {
2413 // Check if any custom hook can vectorize the inner op.
2414 if (llvm::any_of(
2415 customPreconditions,
2416 [&](const CustomVectorizationPrecondition &customPrecondition) {
2417 return succeeded(
2418 customPrecondition(&innerOp, vectorizeNDExtract));
2419 })) {
2420 continue;
2421 }
2422 if (!llvm::all_of(innerOp.getOperandTypes(),
2423 VectorType::isValidElementType)) {
2424 return failure();
2425 }
2426 if (!llvm::all_of(innerOp.getResultTypes(),
2427 VectorType::isValidElementType)) {
2428 return failure();
2429 }
2430 }
2431 if (isElementwise(linalgOp))
2432 return success();
2433
2434 // TODO: isaConvolutionOpInterface that can also infer from generic
2435 // features. But we will still need stride/dilation attributes that will be
2436 // annoying to reverse-engineer...
2437 if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
2438 return vectorizeConvOpPrecondition(linalgOp);
2439
2440 // TODO: the common vector shape is equal to the static loop sizes only when
2441 // all indexing maps are projected permutations. For convs and stencils the
2442 // logic will need to evolve.
2443 if (!allIndexingsAreProjectedPermutation(linalgOp)) {
2444 LDBG() << "precondition failed: not projected permutations";
2445 return failure();
2446 }
2447 if (failed(reductionPreconditions(linalgOp))) {
2448 LDBG() << "precondition failed: reduction preconditions";
2449 return failure();
2450 }
2451 return success();
2452}
2453
2454static LogicalResult
2455vectorizePackOpPrecondition(linalg::PackOp packOp,
2456 ArrayRef<int64_t> inputVectorSizes) {
2457 auto padValue = packOp.getPaddingValue();
2458 Attribute cstAttr;
2459  // TODO: Relax this condition
2460 if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
2461 LDBG() << "pad value is not constant: " << packOp;
2462 return failure();
2463 }
2464
2465 ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
2466 bool satisfyEmptyCond = true;
2467 if (inputVectorSizes.empty()) {
2468 if (!packOp.getDestType().hasStaticShape() ||
2469 !packOp.getSourceType().hasStaticShape())
2470 satisfyEmptyCond = false;
2471 }
2472
2473 if (!satisfyEmptyCond &&
2474      failed(vector::isValidMaskedInputVector(
2475          resultTensorShape.take_front(packOp.getSourceRank()),
2476 inputVectorSizes)))
2477 return failure();
2478
2479 if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
2480 return !getConstantIntValue(v).has_value();
2481 })) {
2482 LDBG() << "inner_tiles must be constant: " << packOp;
2483 return failure();
2484 }
2485
2486 return success();
2487}
2488
2489static LogicalResult
2490vectorizePadOpPrecondition(tensor::PadOp padOp,
2491 ArrayRef<int64_t> inputVectorSizes) {
2492 auto padValue = padOp.getConstantPaddingValue();
2493 if (!padValue) {
2494 LDBG() << "pad value is not constant: " << padOp;
2495 return failure();
2496 }
2497
2498 ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
2499 if (failed(vector::isValidMaskedInputVector(resultTensorShape,
2500 inputVectorSizes)))
2501 return failure();
2502
2503 // Padding with non-zero low pad values is not supported, unless the
2504 // corresponding result dim is 1 as this would require shifting the results to
2505 // the right for the low padded dims by the required amount of low padding.
2506 // However, we do support low padding if the dims being low padded have result
2507 // sizes of 1. The reason is when we have a low pad on a unit result dim, the
2508 // input size of that dimension will be dynamically zero (as the sum of the
2509 // low pad and input dim size has to be one) and hence we will create a zero
2510 // mask as the lowering logic just makes the mask one for the input dim size -
2511 // which is zero here. Hence we will load the pad value which is what we want
2512 // in this case. If the low pad is dynamically zero then the lowering is
2513 // correct as well as no shifts are necessary.
2514 if (llvm::any_of(llvm::enumerate(padOp.getLow()), [&](const auto &en) {
2515 Value padValue = en.value();
2516 unsigned pos = en.index();
2517 std::optional<int64_t> pad = getConstantIntValue(padValue);
2518 return (!pad.has_value() || pad.value() != 0) &&
2519 resultTensorShape[pos] != 1;
2520 })) {
2521 LDBG() << "low pad must all be zero for all non unit dims: " << padOp;
2522 return failure();
2523 }
2524
2525 return success();
2526}
2527
2528/// Preconditions for scalable vectors.
2529///
2530/// For Ops implementing the LinalgOp interface, this is quite restrictive - it
2531/// models the fact that in practice we would only make selected dimensions
2532/// scalable. For other Ops (e.g. `linalg.unpack`), this will succeed
2533/// unconditionally - we are yet to identify meaningful conditions.
2534static LogicalResult
2535vectorizeScalableVectorPrecondition(Operation *op,
2536 ArrayRef<int64_t> inputVectorSizes,
2537 ArrayRef<bool> inputScalableVecDims) {
2538 assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
2539 "Number of input vector sizes and scalable dims doesn't match");
2540
2541 size_t numOfScalableDims =
2542 llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });
2543
2544 if (numOfScalableDims == 0)
2545 return success();
2546
2547 auto linalgOp = dyn_cast<LinalgOp>(op);
2548
2549 // Cond 1: Reject Ops that don't implement the LinalgOp interface, with the
2550 // exception of UnpackOp for which there is a dedicated hook.
2551 if (!linalgOp) {
2552 return success(isa<linalg::UnPackOp>(op));
2553 }
2554
2555 // Cond 2: There's been no need for more than 2 scalable dims so far
2556 if (numOfScalableDims > 2)
2557 return failure();
2558
2559 // Cond 3: Look at the configuration in `inputScalableVecDims` and verify that
2560 // it matches one of the supported cases:
2561 // 1. Exactly 1 dim is scalable and that's the _last_ non-unit parallel dim
2562 // (*).
2563 // 2. Exactly 2 dims are scalable and those are the _last two adjacent_
2564 // parallel dims.
2565 // 3. Exactly 1 reduction dim is scalable and that's the last (innermost)
2566 // dim.
2567 // The 2nd restriction above means that only Matmul-like Ops are supported
2568 // when 2 dims are scalable, e.g. :
2569 // * iterators = [parallel, parallel, reduction]
2570 // * scalable flags = [true, true, false]
2571 //
2572  // (*) Unit dims get folded away in practice.
2573 // TODO: Relax these conditions as good motivating examples are identified.
2574
2575 // Find the first scalable flag.
2576 bool seenNonUnitParallel = false;
2577 auto iterators = linalgOp.getIteratorTypesArray();
2578 SmallVector<bool> scalableFlags(inputScalableVecDims);
2579 int64_t idx = scalableFlags.size() - 1;
2580 while (!scalableFlags[idx]) {
2581 bool isNonUnitDim = (inputVectorSizes[idx] != 1);
2582 seenNonUnitParallel |=
2583 (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim);
2584
2585 iterators.pop_back();
2586 scalableFlags.pop_back();
2587 --idx;
2588 }
2589
2590 // Analyze the iterator corresponding to the first scalable dim.
2591 switch (iterators.back()) {
2592 case utils::IteratorType::reduction: {
2593 // Check 3. above is met.
2594 if (iterators.size() != inputVectorSizes.size()) {
2595 LDBG() << "Non-trailing reduction dim requested for scalable "
2596 "vectorization";
2597 return failure();
2598 }
2599 if (isa<linalg::MatmulOp>(op)) {
2600 LDBG()
2601 << "Scalable vectorization of the reduction dim in Matmul-like ops "
2602 "is not supported";
2603 return failure();
2604 }
2605 break;
2606 }
2607 case utils::IteratorType::parallel: {
2608 // Check 1. and 2. above are met.
2609 if (seenNonUnitParallel) {
2610 LDBG() << "Inner parallel dim not requested for scalable "
2611 "vectorization";
2612 return failure();
2613 }
2614 break;
2615 }
2616 }
2617
2618 // If present, check the 2nd scalable dim. ATM, only Matmul-like Ops are
2619  // supported, for which we expect the following config:
2620 // * iterators = [parallel, parallel, reduction]
2621 // * scalable flags = [true, true, false]
2622 if (numOfScalableDims == 2) {
2623 // Disallow below case which breaks 3. above:
2624 // * iterators = [..., parallel, reduction]
2625 // * scalable flags = [..., true, true]
2626 if (iterators.back() == utils::IteratorType::reduction) {
2627      LDBG() << "Higher dim than the trailing reduction dim requested for "
2628                "scalable "
2629                "vectorization";
2630 return failure();
2631 }
2632 scalableFlags.pop_back();
2633 iterators.pop_back();
2634
2635 if (!scalableFlags.back() ||
2636 (iterators.back() != utils::IteratorType::parallel))
2637 return failure();
2638 }
2639
2640 // Cond 4: Only the following ops are supported in the
2641 // presence of scalable vectors
2642 return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
2643 isa<linalg::BatchMatmulOp>(op) ||
2644 isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
2645 isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
2646 isa<linalg::BatchMmt4DOp>(op) ||
2647 hasReductionIterator(linalgOp));
2648}
2649
2650LogicalResult mlir::linalg::vectorizeOpPrecondition(
2651    Operation *op, ArrayRef<int64_t> inputVectorSizes,
2652 ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
2653 bool flatten1DDepthwiseConv) {
2654
2655 if (!hasVectorizationImpl(op))
2656 return failure();
2657
2658 if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
2659 inputScalableVecDims)))
2660 return failure();
2661
2662  return TypeSwitch<Operation *, LogicalResult>(op)
2663      .Case<linalg::LinalgOp>([&](auto linalgOp) {
2664 return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
2665 vectorizeNDExtract,
2666 flatten1DDepthwiseConv);
2667 })
2668 .Case<tensor::PadOp>([&](auto padOp) {
2669 return vectorizePadOpPrecondition(padOp, inputVectorSizes);
2670 })
2671 .Case<linalg::PackOp>([&](auto packOp) {
2672 return vectorizePackOpPrecondition(packOp, inputVectorSizes);
2673 })
2674 .Case<linalg::UnPackOp>([&](auto unpackOp) {
2675 return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
2676 })
2677 .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
2678 return vectorizeInsertSliceOpPrecondition(sliceOp, inputVectorSizes);
2679 })
2680 .Default(failure());
2681}
2682
2683/// Converts affine.apply Ops to arithmetic operations.
2684static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
2685 OpBuilder::InsertionGuard g(rewriter);
2686 auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
2687
2688 for (auto op : make_early_inc_range(toReplace)) {
2689 rewriter.setInsertionPoint(op);
2690 auto expanded = affine::expandAffineExpr(
2691 rewriter, op->getLoc(), op.getAffineMap().getResult(0),
2692 op.getOperands().take_front(op.getAffineMap().getNumDims()),
2693 op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
2694 rewriter.replaceOp(op, expanded);
2695 }
2696}
2697
2698bool mlir::linalg::hasVectorizationImpl(Operation *op) {
2699 return isa<linalg::LinalgOp, tensor::PadOp, linalg::PackOp, linalg::UnPackOp,
2700 tensor::InsertSliceOp>(op);
2701}
2702
2703FailureOr<VectorizationResult> mlir::linalg::vectorize(
2704 RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
2705 ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
2706 bool flatten1DDepthwiseConv, bool assumeDynamicDimsMatchVecSizes,
2707 bool createNamedContraction) {
2708 LDBG() << "Attempting to vectorize: " << *op;
2709 LDBG() << "Input vector sizes: " << llvm::interleaved(inputVectorSizes);
2710 LDBG() << "Input scalable vector dims: "
2711 << llvm::interleaved(inputScalableVecDims);
2712
2713 if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
2714 vectorizeNDExtract,
2715 flatten1DDepthwiseConv))) {
2716 LDBG() << "Vectorization pre-conditions failed";
2717 return failure();
2718 }
2719
2720 // Initialize vectorization state.
2721 VectorizationState state(rewriter);
2722 if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
2723 if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
2724 inputScalableVecDims,
2725 assumeDynamicDimsMatchVecSizes))) {
2726 LDBG() << "Vectorization state couldn't be initialized";
2727 return failure();
2728 }
2729 }
2730
2731 SmallVector<Value> results;
2732 auto vectorizeResult =
2733      TypeSwitch<Operation *, LogicalResult>(op)
2734          .Case<linalg::LinalgOp>([&](auto linalgOp) {
2735 // TODO: isaConvolutionOpInterface that can also infer from
2736 // generic features. Will require stride/dilation attributes
2737 // inference.
2738 if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
2739 FailureOr<Operation *> convOr = vectorizeConvolution(
2740 rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
2741 flatten1DDepthwiseConv);
2742 if (succeeded(convOr)) {
2743 llvm::append_range(results, (*convOr)->getResults());
2744 return success();
2745 }
2746
2747 LDBG() << "Unsupported convolution can't be vectorized.";
2748 return failure();
2749 }
2750
2751 if (createNamedContraction &&
2752 isa<ContractionOpInterface>(linalgOp.getOperation()))
2753 return vectorizeAsLinalgContraction(rewriter, state, linalgOp,
2754 results);
2755
2756 LDBG()
2757 << "Vectorize generic by broadcasting to the canonical vector "
2758 "shape";
2759
2760 // Pre-process before proceeding.
2761 convertAffineApply(rewriter, linalgOp);
2762
2763 // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
2764 // to 'OpBuilder' when it is passed over to some methods like
2765 // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
2766 // erase an op within these methods, the actual rewriter won't be
2767 // notified and we will end up with read-after-free issues!
2768 return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
2769 })
2770 .Case<tensor::PadOp>([&](auto padOp) {
2771 return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
2772 results);
2773 })
2774 .Case<linalg::PackOp>([&](auto packOp) {
2775 return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
2776 results);
2777 })
2778 .Case<linalg::UnPackOp>([&](auto unpackOp) {
2779 return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
2780 inputVectorSizes,
2781 inputScalableVecDims, results);
2782 })
2783 .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
2784 return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
2785 results);
2786 })
2787 .Default(failure());
2788
2789 if (failed(vectorizeResult)) {
2790 LDBG() << "Vectorization failed";
2791 return failure();
2792 }
2793
2794 return VectorizationResult{results};
2795}
2796
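/// EXAMPLE (an illustrative sketch; the memref shapes are assumptions chosen
/// for exposition): a static memref.copy such as
///   memref.copy %src, %dst : memref<4x8xf32> to memref<4x8xf32>
/// is rewritten into
///   %v = vector.transfer_read %src[%c0, %c0], %cst
///     : memref<4x8xf32>, vector<4x8xf32>
///   vector.transfer_write %v, %dst[%c0, %c0]
///     : vector<4x8xf32>, memref<4x8xf32>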
2797LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
2798 memref::CopyOp copyOp) {
2799 auto srcType = cast<MemRefType>(copyOp.getSource().getType());
2800 auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
2801 if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
2802 return failure();
2803
2804 auto srcElementType = getElementTypeOrSelf(srcType);
2805 auto dstElementType = getElementTypeOrSelf(dstType);
2806 if (!VectorType::isValidElementType(srcElementType) ||
2807 !VectorType::isValidElementType(dstElementType))
2808 return failure();
2809
2810 auto readType = VectorType::get(srcType.getShape(), srcElementType);
2811 auto writeType = VectorType::get(dstType.getShape(), dstElementType);
2812
2813 Location loc = copyOp->getLoc();
2814 Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
2815 SmallVector<Value> indices(srcType.getRank(), zero);
2816
2817 Value readValue = vector::TransferReadOp::create(
2818 rewriter, loc, readType, copyOp.getSource(), indices,
2819 /*padding=*/std::nullopt,
2820 rewriter.getMultiDimIdentityMap(srcType.getRank()));
2821 if (cast<VectorType>(readValue.getType()).getRank() == 0) {
2822 readValue = vector::ExtractOp::create(rewriter, loc, readValue,
2823 ArrayRef<int64_t>());
2824 readValue =
2825 vector::BroadcastOp::create(rewriter, loc, writeType, readValue);
2826 }
2827 Operation *writeValue = vector::TransferWriteOp::create(
2828 rewriter, loc, readValue, copyOp.getTarget(), indices,
2829 rewriter.getMultiDimIdentityMap(srcType.getRank()));
2830 rewriter.replaceOp(copyOp, writeValue->getResults());
2831 return success();
2832}
2833
2834//----------------------------------------------------------------------------//
2835// Misc. vectorization patterns.
2836//----------------------------------------------------------------------------//
2837/// Base pattern for rewriting tensor::PadOps whose result is consumed by a
2838/// given operation type OpTy.
2839template <typename OpTy>
2840struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
2841 using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
2842
2843 LogicalResult matchAndRewrite(tensor::PadOp padOp,
2844 PatternRewriter &rewriter) const final {
2845 bool changed = false;
2846 // Collect users in a vector first, because some users may be replaced/removed.
2847 for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
2848 if (auto op = dyn_cast<OpTy>(user))
2849 changed |= rewriteUser(rewriter, padOp, op).succeeded();
2850 return success(changed);
2851 }
2852
2853protected:
2854 virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
2855 tensor::PadOp padOp, OpTy op) const = 0;
2856};
2857
2858/// Rewrite use of tensor::PadOp result in TransferReadOp. E.g.:
2859/// ```
2860/// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2861/// %r = vector.transfer_read %0[%c0, %c0], %cst
2862/// {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32>
2863/// ```
2864/// is rewritten to:
2865/// ```
2866/// %r = vector.transfer_read %src[%c0, %c0], %padding
2867/// {in_bounds = [true, true]}
2868/// : tensor<?x?xf32>, vector<17x5xf32>
2869/// ```
2870/// Note: By restricting this pattern to in-bounds TransferReadOps, we can be
2871/// sure that the original padding value %cst was never used.
2872///
2873/// This rewrite is possible if:
2874/// - `xferOp` has no out-of-bounds dims or mask.
2875/// - Low padding is static 0.
2876/// - Single, scalar padding value.
2877struct PadOpVectorizationWithTransferReadPattern
2878 : public VectorizePadOpUserPattern<vector::TransferReadOp> {
2879 using VectorizePadOpUserPattern<
2880 vector::TransferReadOp>::VectorizePadOpUserPattern;
2881
2882 LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2883 vector::TransferReadOp xferOp) const override {
2884 // Low padding must be static 0.
2885 if (!padOp.hasZeroLowPad())
2886 return failure();
2887 // Pad value must be a constant.
2888 auto padValue = padOp.getConstantPaddingValue();
2889 if (!padValue)
2890 return failure();
2891 // Padding value of existing `xferOp` is unused.
2892 if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
2893 return failure();
2894
2895 rewriter.modifyOpInPlace(xferOp, [&]() {
2896 SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2897 xferOp->setAttr(xferOp.getInBoundsAttrName(),
2898 rewriter.getBoolArrayAttr(inBounds));
2899 xferOp.getBaseMutable().assign(padOp.getSource());
2900 xferOp.getPaddingMutable().assign(padValue);
2901 });
2902
2903 return success();
2904 }
2905};
2906
2907/// Rewrite use of tensor::PadOp result in TransferWriteOp.
2908/// This pattern rewrites TransferWriteOps that write to a padded tensor
2909/// value, where the same amount of padding is immediately removed again after
2910/// the write. In such cases, the TransferWriteOp can write to the non-padded
2911/// tensor value and apply out-of-bounds masking. E.g.:
2912/// ```
2913/// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2914/// : tensor<...> to tensor<?x?xf32>
2915/// %1 = tensor.pad %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
2916/// %2 = vector.transfer_write %vec, %1[...]
2917/// : vector<17x5xf32>, tensor<17x5xf32>
2918/// %r = tensor.extract_slice %2[0, 0] [%s0, %s1] [1, 1]
2919/// : tensor<17x5xf32> to tensor<?x?xf32>
2920/// ```
2921/// is rewritten to:
2922/// ```
2923/// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2924/// : tensor<...> to tensor<?x?xf32>
2925/// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>,
2926/// tensor<?x?xf32>
2927/// ```
2928/// Note: It is important that the ExtractSliceOp %r resizes the result of the
2929/// TransferWriteOp to the same size as the input of the TensorPadOp (or an
2930/// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
2931/// from %r's old dimensions.
2932///
2933/// This rewrite is possible if:
2934/// - Low padding is static 0.
2935/// - `xferOp` has exactly one use, which is an ExtractSliceOp. This
2936/// ExtractSliceOp trims the same amount of padding that was added
2937/// beforehand.
2938/// - Single, scalar padding value.
2939struct PadOpVectorizationWithTransferWritePattern
2940 : public VectorizePadOpUserPattern<vector::TransferWriteOp> {
2941 using VectorizePadOpUserPattern<
2942 vector::TransferWriteOp>::VectorizePadOpUserPattern;
2943
2944 LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2945 vector::TransferWriteOp xferOp) const override {
2946 // TODO: support 0-d corner case.
2947 if (xferOp.getTransferRank() == 0)
2948 return failure();
2949
2950 // Low padding must be static 0.
2951 if (!padOp.hasZeroLowPad())
2952 return failure();
2953 // Pad value must be a constant.
2954 auto padValue = padOp.getConstantPaddingValue();
2955 if (!padValue)
2956 return failure();
2957 // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
2958 if (!xferOp->hasOneUse())
2959 return failure();
2960 auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
2961 if (!trimPadding)
2962 return failure();
2963 // Only static zero offsets supported when trimming padding.
2964 if (!trimPadding.hasZeroOffset())
2965 return failure();
2966 // trimPadding must remove the amount of padding that was added earlier.
2967 if (!hasSameTensorSize(padOp.getSource(), trimPadding))
2968 return failure();
2969
2970 // Insert the new TransferWriteOp at position of the old TransferWriteOp.
2971 rewriter.setInsertionPoint(xferOp);
2972
2973 SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2974 auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2975 xferOp, padOp.getSource().getType(), xferOp.getVector(),
2976 padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
2977 xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
2978 rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
2979
2980 return success();
2981 }
2982
2983 /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
2984 /// i.e., same dimensions.
2985 ///
2986 /// Dimensions may be static, dynamic or mix of both. In case of dynamic
2987 /// dimensions, this function tries to infer the (static) tensor size by
2988 /// looking at the defining op and utilizing op-specific knowledge.
2989 ///
2990 /// This is a conservative analysis. In case equal tensor sizes cannot be
2991 /// proven statically, this analysis returns `false` even though the tensor
2992 /// sizes may turn out to be equal at runtime.
2993 bool hasSameTensorSize(Value beforePadding,
2994 tensor::ExtractSliceOp afterTrimming) const {
2995 // If the input to tensor::PadOp is a CastOp, try with both CastOp
2996 // result and CastOp operand.
2997 if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
2998 if (hasSameTensorSize(castOp.getSource(), afterTrimming))
2999 return true;
3000
3001 auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
3002 auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
3003 // Only RankedTensorType supported.
3004 if (!t1 || !t2)
3005 return false;
3006 // Rank of both values must be the same.
3007 if (t1.getRank() != t2.getRank())
3008 return false;
3009
3010 // All static dimensions must be the same. Mixed cases (e.g., dimension
3011 // static in `t1` but dynamic in `t2`) are not supported.
3012 for (unsigned i = 0; i < t1.getRank(); ++i) {
3013 if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
3014 return false;
3015 if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
3016 return false;
3017 }
3018
3019 // Nothing more to check if all dimensions are static.
3020 if (t1.getNumDynamicDims() == 0)
3021 return true;
3022
3023 // All dynamic sizes must be the same. The only supported case at the
3024 // moment is when `beforePadding` is an ExtractSliceOp (or a cast
3025 // thereof).
3026
3027 // Apart from CastOp, only ExtractSliceOp is supported.
3028 auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
3029 if (!beforeSlice)
3030 return false;
3031
3032 assert(static_cast<size_t>(t1.getRank()) ==
3033 beforeSlice.getMixedSizes().size());
3034 assert(static_cast<size_t>(t2.getRank()) ==
3035 afterTrimming.getMixedSizes().size());
3036
3037 for (unsigned i = 0; i < t1.getRank(); ++i) {
3038 // Skip static dimensions.
3039 if (!t1.isDynamicDim(i))
3040 continue;
3041 auto size1 = beforeSlice.getMixedSizes()[i];
3042 auto size2 = afterTrimming.getMixedSizes()[i];
3043
3044 // Case 1: Same value or same constant int.
3045 if (isEqualConstantIntOrValue(size1, size2))
3046 continue;
3047
3048 // Other cases: Take a deeper look at defining ops of values.
3049 auto v1 = llvm::dyn_cast_if_present<Value>(size1);
3050 auto v2 = llvm::dyn_cast_if_present<Value>(size2);
3051 if (!v1 || !v2)
3052 return false;
3053
3054 // Case 2: Both values are identical AffineMinOps. (Should not happen if
3055 // CSE is run.)
3056 auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
3057 auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
3058 if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
3059 minOp1.getOperands() == minOp2.getOperands())
3060 continue;
3061
3062 // Add additional cases as needed.
3063 }
3064
3065 // All tests passed.
3066 return true;
3067 }
3068};
3069
3070/// Returns the effective Pad value for the input op, provided it's a scalar.
3071///
3072/// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
3073/// this Op performs padding, retrieve the padding value provided that it's
3074/// a scalar and static/fixed for all the padded values. Returns an empty value
3075/// otherwise.
3076///
3077/// TODO: This is used twice (when checking vectorization pre-conditions and
3078/// when vectorizing). Cache results instead of re-running.
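///
/// For example (illustrative IR), given:
///   %pad = arith.constant 0.0 : f32
///   %fill = linalg.fill ins(%pad : f32) outs(%init : tensor<8x8xf32>)
///           -> tensor<8x8xf32>
///   %res = tensor.insert_slice %src into %fill[0, 0] [4, 4] [1, 1]
///           : tensor<4x4xf32> into tensor<8x8xf32>
/// the effective pad value for %res is %pad, recovered by walking the
/// destination chain through the linalg.fill.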
3079static Value getStaticPadVal(Operation *op) {
3080 if (!op)
3081 return {};
3082
3083 // 1. vector.broadcast (f32 -> vector <...xf32>) - return the value that's
3084 // being broadcast, provided that it's a scalar.
3085 if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
3086 auto source = bcast.getSource();
3087 if (llvm::dyn_cast<VectorType>(source.getType()))
3088 return {};
3089
3090 return source;
3091 }
3092
3093 // 2. linalg.fill - use the scalar input value that used to fill the output
3094 // tensor.
3095 if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
3096 return fill.getInputs()[0];
3097 }
3098
3099 // 3. tensor.generate - we can't guarantee that the value is fixed without
3100 // further analysis, so bail out.
3101 if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {
3102 return {};
3103 }
3104
3105 // 4. vector.transfer_write - inspect the vector that is written. If it
3106 // contains a single value that has been broadcast (e.g. via
3107 // vector.broadcast), extract it; fail otherwise.
3108 if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
3109 return getStaticPadVal(xferWrite.getVector().getDefiningOp());
3110
3111 // 5. tensor.insert_slice - inspect the destination tensor. If it's larger
3112 // than the input tensor, then, provided it's constant, we'll extract the
3113 // value that was used to generate it (via e.g. linalg.fill), fail otherwise.
3114 // TODO: Clarify the semantics when the input tensor is larger than the
3115 // destination.
3116 if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
3117 return getStaticPadVal(slice.getDest().getDefiningOp());
3118
3119 return {};
3120}
3121
3122static LogicalResult
3123vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
3124 ArrayRef<int64_t> inputVectorSizes,
3125 SmallVectorImpl<Value> &newResults) {
3126 // TODO: Introduce a parent class that will handle the insertion point update.
3127 OpBuilder::InsertionGuard g(rewriter);
3128 rewriter.setInsertionPoint(sliceOp);
3129
3130 TypedValue<RankedTensorType> source = sliceOp.getSource();
3131 auto sourceType = source.getType();
3132 auto resultType = sliceOp.getResultType();
3133
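// 1. Get the pad value. If it cannot be determined statically, default to a
// zero constant of the source element type.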
3134 Value padValue = getStaticPadVal(sliceOp);
3135
3136 if (!padValue) {
3137 auto elemType = sourceType.getElementType();
3138 padValue = arith::ConstantOp::create(rewriter, sliceOp.getLoc(), elemType,
3139 rewriter.getZeroAttr(elemType));
3140 }
3141
3142 // 2. Get the vector shape
3143 SmallVector<int64_t> vecShape;
3144 size_t rankDiff = resultType.getRank() - sourceType.getRank();
3145 for (int64_t i = 0, end = sourceType.getRank(); i < end; ++i) {
3146 if (!inputVectorSizes.empty()) {
3147 vecShape.push_back(inputVectorSizes[i]);
3148 } else if (!sourceType.isDynamicDim(i)) {
3149 vecShape.push_back(sourceType.getDimSize(i));
3150 } else if (!resultType.isDynamicDim(i)) {
3151 // Source shape is not statically known, but result shape is.
3152 // Vectorize with size of result shape. This may be larger than the
3153 // source size.
3154 // FIXME: Using rankDiff implies that the source tensor is inserted at
3155 // the end of the destination tensor. However, that's not required.
3156 vecShape.push_back(resultType.getDimSize(rankDiff + i));
3157 } else {
3158 // Neither the source nor the result dim is static. Cannot vectorize
3159 // the copy.
3160 return failure();
3161 }
3162 }
3163 auto vecType = VectorType::get(vecShape, sourceType.getElementType());
3164
3165 // 3. Generate TransferReadOp + TransferWriteOp
3166 auto loc = sliceOp.getLoc();
3167
3168 // Create read
3169 SmallVector<Value> readIndices(
3170 vecType.getRank(), arith::ConstantIndexOp::create(rewriter, loc, 0));
3171 Value read = vector::createReadOrMaskedRead(
3172 rewriter, loc, source, vecType, padValue,
3173 /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());
3174
3175 // Create write
3176 auto writeIndices =
3177 getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
3178 Operation *write =
3179 createWriteOrMaskedWrite(rewriter, loc, read, sliceOp.getDest(),
3180 writeIndices, inputVectorSizes.empty());
3181
3182 // 4. Finalize
3183 newResults.push_back(write->getResult(0));
3184
3185 return success();
3186}
3187
3188/// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
3189/// ```
3190/// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
3191/// %r = tensor.insert_slice %0
3192/// into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
3193/// : tensor<17x5xf32> into tensor<?x?x17x5xf32>
3194/// ```
3195/// is rewritten to:
3196/// ```
3197/// %0 = vector.transfer_read %src[%c0, %c0], %padding
3198/// : tensor<?x?xf32>, vector<17x5xf32>
3199/// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
3200/// {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
3201/// ```
3202///
3203/// This rewrite is possible if:
3204/// - Low padding is static 0.
3205/// - `padOp` result shape is static.
3206/// - The entire padded tensor is inserted.
3207/// (Implies that sizes of `insertOp` are all static.)
3208/// - Only unit strides in `insertOp`.
3209/// - Single, scalar padding value.
3210/// - `padOp` result not used as destination.
3211struct PadOpVectorizationWithInsertSlicePattern
3212 : public VectorizePadOpUserPattern<tensor::InsertSliceOp> {
3213 using VectorizePadOpUserPattern<
3214 tensor::InsertSliceOp>::VectorizePadOpUserPattern;
3215
3216 LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
3217 tensor::InsertSliceOp insertOp) const override {
3218 // Low padding must be static 0.
3219 if (!padOp.hasZeroLowPad())
3220 return failure();
3221 // Only unit stride supported.
3222 if (!insertOp.hasUnitStride())
3223 return failure();
3224 // Pad value must be a constant.
3225 auto padValue = padOp.getConstantPaddingValue();
3226 if (!padValue)
3227 return failure();
3228 // Dynamic shapes not supported.
3229 if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
3230 return failure();
3231 // Pad result not used as destination.
3232 if (insertOp.getDest() == padOp.getResult())
3233 return failure();
3234
3235 auto vecType = VectorType::get(padOp.getType().getShape(),
3236 padOp.getType().getElementType());
3237 unsigned vecRank = vecType.getRank();
3238 unsigned tensorRank = insertOp.getType().getRank();
3239
3240 // Check if sizes match: Insert the entire tensor into most minor dims.
3241 // (No permutations allowed.)
3242 SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
3243 expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
3244 if (!llvm::all_of(
3245 llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
3246 return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
3247 }))
3248 return failure();
3249
3250 // Insert the TransferReadOp and TransferWriteOp at the position of the
3251 // InsertSliceOp.
3252 rewriter.setInsertionPoint(insertOp);
3253
3254 // Generate TransferReadOp: Read entire source tensor and add high
3255 // padding.
3256 SmallVector<Value> readIndices(
3257 vecRank, arith::ConstantIndexOp::create(rewriter, padOp.getLoc(), 0));
3258 auto read = vector::TransferReadOp::create(rewriter, padOp.getLoc(),
3259 vecType, padOp.getSource(),
3260 readIndices, padValue);
3261
3262 // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
3263 // specified offsets. Write is fully in-bounds because an InsertSliceOp's
3264 // source must fit into the destination at the specified offsets.
3265 auto writeIndices = getValueOrCreateConstantIndexOp(
3266 rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
3267 SmallVector<bool> inBounds(vecRank, true);
3268 rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
3269 insertOp, read, insertOp.getDest(), writeIndices,
3270 ArrayRef<bool>{inBounds});
3271
3272 return success();
3273 }
3274};
3275
3276 void mlir::linalg::populatePadOpVectorizationPatterns(
3277 RewritePatternSet &patterns, PatternBenefit baseBenefit) {
3278 patterns.add<PadOpVectorizationWithTransferReadPattern,
3279 PadOpVectorizationWithTransferWritePattern,
3280 PadOpVectorizationWithInsertSlicePattern>(
3281 patterns.getContext(), baseBenefit.getBenefit() + 1);
3282}
3283
3284//----------------------------------------------------------------------------//
3285// Forwarding patterns
3286//----------------------------------------------------------------------------//
3287
3288/// Check whether there is any interleaved use of any `values` between
3289/// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
3290/// is in a different block.
3291static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
3292 ValueRange values) {
3293 if (firstOp->getBlock() != secondOp->getBlock() ||
3294 !firstOp->isBeforeInBlock(secondOp)) {
3295 LDBG() << "interleavedUses precondition failed, firstOp: " << *firstOp
3296 << ", second op: " << *secondOp;
3297 return true;
3298 }
3299 for (auto v : values) {
3300 for (auto &u : v.getUses()) {
3301 Operation *owner = u.getOwner();
3302 if (owner == firstOp || owner == secondOp)
3303 continue;
3304 // TODO: this is too conservative, use dominance info in the future.
3305 if (owner->getBlock() == firstOp->getBlock() &&
3306 (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
3307 continue;
3308 LDBG() << " found interleaved op " << *owner << ", firstOp: " << *firstOp
3309 << ", second op: " << *secondOp;
3310 return true;
3311 }
3312 }
3313 return false;
3314}
3315
3316/// Return the unique subview use of `v` if it is indeed unique, null
3317/// otherwise.
3318static memref::SubViewOp getSubViewUseIfUnique(Value v) {
3319 memref::SubViewOp subViewOp;
3320 for (auto &u : v.getUses()) {
3321 if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
3322 if (subViewOp)
3323 return memref::SubViewOp();
3324 subViewOp = newSubViewOp;
3325 }
3326 }
3327 return subViewOp;
3328}
3329
3330/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
3331/// when available.
3332 LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
3333 vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
3334
3335 // TODO: support mask.
3336 if (xferOp.getMask())
3337 return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
3338
3339 // Transfer into `view`.
3340 Value viewOrAlloc = xferOp.getBase();
3341 if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
3342 !viewOrAlloc.getDefiningOp<memref::AllocOp>())
3343 return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
3344
3345 // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
3346 memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
3347 if (!subViewOp)
3348 return rewriter.notifyMatchFailure(xferOp, "no subview found");
3349 Value subView = subViewOp.getResult();
3350
3351 // Find the copy into `subView` without interleaved uses.
3352 memref::CopyOp copyOp;
3353 for (auto &u : subView.getUses()) {
3354 if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
3355 assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
3356 if (newCopyOp.getTarget() != subView)
3357 continue;
3358 if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
3359 continue;
3360 copyOp = newCopyOp;
3361 break;
3362 }
3363 }
3364 if (!copyOp)
3365 return rewriter.notifyMatchFailure(xferOp, "no copy found");
3366
3367 // Find the fill into `viewOrAlloc` without interleaved uses before the
3368 // copy.
3369 FillOp maybeFillOp;
3370 for (auto &u : viewOrAlloc.getUses()) {
3371 if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
3372 assert(isa<MemRefType>(newFillOp.output().getType()));
3373 if (newFillOp.output() != viewOrAlloc)
3374 continue;
3375 if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
3376 continue;
3377 maybeFillOp = newFillOp;
3378 break;
3379 }
3380 }
3381 // Ensure padding matches.
3382 if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
3383 return rewriter.notifyMatchFailure(xferOp,
3384 "padding value does not match fill");
3385
3386 // `in` is the subview that memref.copy reads. Replace it.
3387 Value in = copyOp.getSource();
3388
3389 // memref.copy + linalg.fill can be used to create a padded local buffer.
3390 // The `masked` attribute is only valid on this padded buffer.
3391 // When forwarding to vector.transfer_read, the attribute must be reset
3392 // conservatively.
3393 auto vectorType = xferOp.getVectorType();
3394 Value res = vector::TransferReadOp::create(
3395 rewriter, xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
3396 xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
3397 rewriter.getBoolArrayAttr(
3398 SmallVector<bool>(vectorType.getRank(), false)));
3399
3400 if (maybeFillOp)
3401 rewriter.eraseOp(maybeFillOp);
3402 rewriter.eraseOp(copyOp);
3403 rewriter.replaceOp(xferOp, res);
3404
3405 return success();
3406}
3407
3408/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
3409/// when available.
3410 LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
3411 vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
3412 // TODO: support mask.
3413 if (xferOp.getMask())
3414 return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
3415
3416 // Transfer into `viewOrAlloc`.
3417 Value viewOrAlloc = xferOp.getBase();
3418 if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
3419 !viewOrAlloc.getDefiningOp<memref::AllocOp>())
3420 return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
3421
3422 // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
3423 memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
3424 if (!subViewOp)
3425 return rewriter.notifyMatchFailure(xferOp, "no subview found");
3426 Value subView = subViewOp.getResult();
3427
3428 // Find the copy from `subView` without interleaved uses.
3429 memref::CopyOp copyOp;
3430 for (auto &u : subViewOp.getResult().getUses()) {
3431 if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
3432 if (newCopyOp.getSource() != subView)
3433 continue;
3434 if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
3435 continue;
3436 copyOp = newCopyOp;
3437 break;
3438 }
3439 }
3440 if (!copyOp)
3441 return rewriter.notifyMatchFailure(xferOp, "no copy found");
3442
3443 // `out` is the subview that the copy writes into; the forwarded transfer_write targets it directly.
3444 assert(isa<MemRefType>(copyOp.getTarget().getType()));
3445 Value out = copyOp.getTarget();
3446
3447 // Forward vector.transfer into copy.
3448 // memref.copy + linalg.fill can be used to create a padded local buffer.
3449 // The `masked` attribute is only valid on this padded buffer.
3450 // When forwarding to vector.transfer_write, the attribute must be reset
3451 // conservatively.
3452 auto vector = xferOp.getVector();
3453 vector::TransferWriteOp::create(
3454 rewriter, xferOp.getLoc(), vector, out, xferOp.getIndices(),
3455 xferOp.getPermutationMapAttr(), xferOp.getMask(),
3456 rewriter.getBoolArrayAttr(SmallVector<bool>(
3457 dyn_cast<VectorType>(vector.getType()).getRank(), false)));
3458
3459 rewriter.eraseOp(copyOp);
3460 rewriter.eraseOp(xferOp);
3461
3462 return success();
3463}
3464
3465//===----------------------------------------------------------------------===//
3466// Convolution vectorization patterns
3467//===----------------------------------------------------------------------===//
3468
3469template <int N>
3470static void bindShapeDims(ShapedType shapedType) {}
3471
3472template <int N, typename IntTy, typename... IntTy2>
3473static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
3474 val = shapedType.getShape()[N];
3475 bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
3476}
3477
3478/// Bind a pack of int& to the leading dimensions of shapedType.getShape().
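/// For example (illustrative):
///   int64_t n, w, c;
///   bindShapeDims(lhsShapedType, n, w, c); // n = dim 0, w = dim 1, c = dim 2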
3479template <typename... IntTy>
3480static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
3481 bindShapeDims<0>(shapedType, vals...);
3482}
3483
3484namespace {
3485/// Generate a vector implementation for either:
3486/// ```
3487/// Op def: ( w, kw )
3488/// Iters: ({Par(), Red()})
3489/// Layout: {{w + kw}, {kw}, {w}}
3490/// ```
3491/// kw is unrolled.
3492///
3493/// or
3494///
3495/// ```
3496/// Op def: ( n, w, c, kw, f )
3497/// Iters: ({Par(), Par(), Par(), Red(), Red()})
3498/// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3499/// ```
3500/// kw is unrolled, w is unrolled iff dilationW > 1.
3501///
3502/// or
3503///
3504/// ```
3505/// Op def: ( n, c, w, f, kw )
3506/// Iters: ({Par(), Par(), Par(), Red(), Red()})
3507/// Layout: {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
3508/// ```
3509/// kw is unrolled, w is unrolled iff dilationW > 1.
3510///
3511/// or
3512///
3513/// ```
3514/// Op def: ( n, w, c, kw )
3515/// Iters: ({Par(), Par(), Par(), Red()})
3516/// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3517/// ```
3518/// kw is unrolled, w is unrolled iff dilationW > 1.
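/// (These layouts correspond to ops such as linalg.conv_1d,
/// linalg.conv_1d_nwc_wcf, linalg.conv_1d_ncw_fcw and
/// linalg.depthwise_conv_1d_nwc_wc, respectively.)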
3519struct Conv1DGenerator
3520 : public StructuredGenerator<LinalgOp, utils::IteratorType> {
3521 Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp)
3522 : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp) {
3523
3524 lhsShaped = linalgOp.getDpsInputOperand(0)->get();
3525 rhsShaped = linalgOp.getDpsInputOperand(1)->get();
3526 resShaped = linalgOp.getDpsInitOperand(0)->get();
3527 lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
3528 rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
3529 resShapedType = dyn_cast<ShapedType>(resShaped.getType());
3530
3531 Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
3532 redOp = reduceOp->getName().getIdentifier();
3533
3534 setConvOperationKind(reduceOp);
3535
3536 auto maybeKind = getCombinerOpKind(reduceOp);
3537 reductionKind = maybeKind.value();
3538
3539 // The ConvolutionOpInterface gives us guarantees of existence for
3540 // strides/dilations. However, we do not need to rely on those: we simply
3541 // use them if present, otherwise fall back to the defaults and let the
3542 // generic conv matcher in the Conv1DGenerator succeed or fail.
3543 auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
3544 auto dilations = linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
3545 strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
3546 dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
3547 }
3548
3549 /// Generate a vector implementation for:
3550 /// ```
3551 /// Op def: ( w, kw )
3552 /// Iters: ({Par(), Red()})
3553 /// Layout: {{w + kw}, {kw}, {w}}
3554 /// ```
3555 /// kw is always unrolled.
3556 ///
3557 /// or
3558 ///
3559 /// ```
3560 /// Op def: ( n, w, c, kw, f )
3561 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
3562 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3563 /// ```
3564 /// kw is always unrolled.
3565 /// TODO: w (resp. kw) is unrolled when strideW (resp. dilationW) is
3566 /// > 1.
3567 FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
3568 int64_t nSize, wSize, cSize, kwSize, fSize;
3569 SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
3570 bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
3571 switch (conv1DOpOrder) {
3572 case Conv1DOpOrder::W:
3573 // Initialize unused dimensions
3574 nSize = fSize = cSize = 0;
3575 // out{W}
3576 bindShapeDims(resShapedType, wSize);
3577 // kernel{kw}
3578 bindShapeDims(rhsShapedType, kwSize);
3579 lhsShape = {// iw = ow + kw - 1
3580 // (i.e. 16 convolved with 3 -> 14)
3581 (wSize + kwSize - 1)};
3582 rhsShape = {kwSize};
3583 resShape = {wSize};
3584 break;
3585 case Conv1DOpOrder::Nwc:
3586 // out{n, w, f}
3587 bindShapeDims(resShapedType, nSize, wSize, fSize);
3588 switch (oper) {
3589 case ConvOperationKind::Conv:
3590 // kernel{kw, c, f}
3591 bindShapeDims(rhsShapedType, kwSize, cSize);
3592 break;
3593 case ConvOperationKind::Pool:
3594 // kernel{kw}
3595 bindShapeDims(rhsShapedType, kwSize);
3596 cSize = fSize;
3597 break;
3598 }
3599 lhsShape = {nSize,
3600 // iw = ow * sw + kw * dw - 1
3601 // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3602 // Perform the proper inclusive -> exclusive -> inclusive.
3603 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3604 1,
3605 cSize};
3606 switch (oper) {
3607 case ConvOperationKind::Conv:
3608 rhsShape = {kwSize, cSize, fSize};
3609 break;
3610 case ConvOperationKind::Pool:
3611 rhsShape = {kwSize};
3612 break;
3613 }
3614 resShape = {nSize, wSize, fSize};
3615 break;
3616 case Conv1DOpOrder::Ncw:
3617 // out{n, f, w}
3618 bindShapeDims(resShapedType, nSize, fSize, wSize);
3619 switch (oper) {
3620 case ConvOperationKind::Conv:
3621 // kernel{f, c, kw}
3622 bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
3623 break;
3624 case ConvOperationKind::Pool:
3625 // kernel{kw}
3626 bindShapeDims(rhsShapedType, kwSize);
3627 cSize = fSize;
3628 break;
3629 }
3630 lhsShape = {nSize, cSize,
3631 // iw = ow * sw + kw * dw - 1
3632 // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3633 // Perform the proper inclusive -> exclusive -> inclusive.
3634 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3635 1};
3636 switch (oper) {
3637 case ConvOperationKind::Conv:
3638 rhsShape = {fSize, cSize, kwSize};
3639 break;
3640 case ConvOperationKind::Pool:
3641 rhsShape = {kwSize};
3642 break;
3643 }
3644 resShape = {nSize, fSize, wSize};
3645 break;
3646 }
3647
3648 vector::TransferWriteOp write;
3649 Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
3650
3651 // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3652 // When strideW == 1, we can batch the contiguous loads and avoid
3653 // unrolling
3654 int64_t wSizeStep = strideW == 1 ? wSize : 1;
3655
3656 Type lhsEltType = lhsShapedType.getElementType();
3657 Type rhsEltType = rhsShapedType.getElementType();
3658 Type resEltType = resShapedType.getElementType();
3659 auto lhsType = VectorType::get(lhsShape, lhsEltType);
3660 auto rhsType = VectorType::get(rhsShape, rhsEltType);
3661 auto resType = VectorType::get(resShape, resEltType);
3662 // Zero padding with the corresponding dimensions for lhs, rhs and res.
3663 SmallVector<Value> lhsPadding(lhsShape.size(), zero);
3664 SmallVector<Value> rhsPadding(rhsShape.size(), zero);
3665 SmallVector<Value> resPadding(resShape.size(), zero);
3666
3667 // Read the whole lhs, rhs and res in one shot (with zero padding).
3668 Value lhs = vector::TransferReadOp::create(
3669 rewriter, loc, lhsType, lhsShaped, lhsPadding,
3670 /*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
3671 // This is needed only for Conv.
3672 Value rhs = nullptr;
3673 if (oper == ConvOperationKind::Conv)
3674 rhs = vector::TransferReadOp::create(
3675 rewriter, loc, rhsType, rhsShaped, rhsPadding,
3676 /*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
3677 Value res = vector::TransferReadOp::create(
3678 rewriter, loc, resType, resShaped, resPadding,
3679 /*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
3680
3681 // The base vectorization case for channeled convolution is input:
3682 // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base-pattern
3683 // vectorization case, we pre-transpose the input, weight, and output.
3684 switch (conv1DOpOrder) {
3685 case Conv1DOpOrder::W:
3686 case Conv1DOpOrder::Nwc:
3687 // Base case, so no transposes necessary.
3688 break;
3689 case Conv1DOpOrder::Ncw: {
3690 // To match base vectorization case, we pre-transpose current case.
3691 // ncw -> nwc
3692 static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
3693 lhs = vector::TransposeOp::create(rewriter, loc, lhs, permLhs);
3694 // fcw -> wcf
3695 static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
3696
3697 // This is needed only for Conv.
3698 if (oper == ConvOperationKind::Conv)
3699 rhs = vector::TransposeOp::create(rewriter, loc, rhs, permRhs);
3700 // nfw -> nwf
3701 static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
3702 res = vector::TransposeOp::create(rewriter, loc, res, permRes);
3703 break;
3704 }
3705 }
3706
3707 //===------------------------------------------------------------------===//
3708 // Begin vector-only rewrite part
3709 //===------------------------------------------------------------------===//
3710 // Unroll along kw and read slices of lhs and rhs.
3711 SmallVector<Value> lhsVals, rhsVals, resVals;
3712 lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
3713 kwSize, strideW, dilationW, wSizeStep,
3714 isSingleChanneled);
3715 // Do not do for pooling.
3716 if (oper == ConvOperationKind::Conv)
3717 rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
3718 resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
3719 wSizeStep, isSingleChanneled);
3720
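// Map the unrolled (kw, w) position to the flat index of the corresponding
// extracted lhs slice.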
3721 auto linearIndex = [&](int64_t kw, int64_t w) {
3722 return kw * (wSize / wSizeStep) + w;
3723 };
3724
3725 // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f}
3726 // or perform outerproduct for non-channeled convolution or perform simple
3727 // arith operation for pooling
3728 for (int64_t kw = 0; kw < kwSize; ++kw) {
3729 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3730 switch (oper) {
3731 case ConvOperationKind::Conv:
3732 if (isSingleChanneled) {
3733 resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
3734 lhsVals[linearIndex(kw, w)],
3735 rhsVals[kw], resVals[w]);
3736 } else {
3737 resVals[w] = conv1dSliceAsContraction(rewriter, loc,
3738 lhsVals[linearIndex(kw, w)],
3739 rhsVals[kw], resVals[w]);
3740 }
3741 break;
3742 case ConvOperationKind::Pool:
3743 resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
3744 resVals[w]);
3745 break;
3746 }
3747 }
3748 }
3749
3750 res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
3751 isSingleChanneled);
3752 //===------------------------------------------------------------------===//
3753 // End vector-only rewrite part
3754 //===------------------------------------------------------------------===//
3755
3756 // The base vectorization case for channeled convolution produces output
3757 // {n,w,f}. To reuse the result from the base-pattern vectorization case, we
3758 // post-transpose the base case result.
3759 switch (conv1DOpOrder) {
3760 case Conv1DOpOrder::W:
3761 case Conv1DOpOrder::Nwc:
3762 // Base case, so no transposes necessary.
3763 break;
3764 case Conv1DOpOrder::Ncw: {
3765 // nwf -> nfw
3766 static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
3767 res = vector::TransposeOp::create(rewriter, loc, res, perm);
3768 break;
3769 }
3770 }
3771
3772 return vector::TransferWriteOp::create(rewriter, loc, res, resShaped,
3773 resPadding)
3774 .getOperation();
3775 }
3776
3777 // Take a value and widen it to have the same element type as `ty`.
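// Integers are promoted via sign-extension (or sitofp when the destination is
// a float type) and floats via extension; any other combination asserts.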
3778 Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
3779 const Type srcElementType = getElementTypeOrSelf(val.getType());
3780 const Type dstElementType = getElementTypeOrSelf(ty);
3781 assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
3782 if (srcElementType == dstElementType)
3783 return val;
3784
3785 const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
3786 const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
3787 const Type dstType =
3788 cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);
3789
3790 if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
3791 return arith::SIToFPOp::create(rewriter, loc, dstType, val);
3792 }
3793
3794 if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
3795 srcWidth < dstWidth)
3796 return arith::ExtFOp::create(rewriter, loc, dstType, val);
3797
3798 if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
3799 srcWidth < dstWidth)
3800 return arith::ExtSIOp::create(rewriter, loc, dstType, val);
3801
3802 assert(false && "unhandled promotion case");
3803 return nullptr;
3804 }
3805
3806 // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
3807 Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
3808 Value lhs, Value rhs, Value res) {
3809 vector::IteratorType par = vector::IteratorType::parallel;
3810 vector::IteratorType red = vector::IteratorType::reduction;
3811 AffineExpr n, w, f, c;
3812 bindDims(ctx, n, w, f, c);
3813 lhs = promote(rewriter, loc, lhs, res.getType());
3814 rhs = promote(rewriter, loc, rhs, res.getType());
3815 auto contractionOp = vector::ContractionOp::create(
3816 rewriter, loc, lhs, rhs, res,
3817 /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
3818 /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
3819 contractionOp.setKind(reductionKind);
3820 return contractionOp;
3821 }
3822
3823 // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel
3824 // convolution.
3825 Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
3826 Value lhs, Value rhs, Value res) {
3827 return vector::OuterProductOp::create(rewriter, loc, res.getType(), lhs,
3828 rhs, res, vector::CombiningKind::ADD);
3829 }
3830
3831 // Create a reduction: lhs{n, w, c} -> res{n, w, c}
3832 Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
3833 Value res) {
3834 if (isPoolExt)
3835 lhs = rewriter.create(loc, poolExtOp, lhs, res.getType())->getResult(0);
3836 return rewriter
3837 .create(loc, redOp, ArrayRef<Value>{lhs, res}, res.getType())
3838 ->getResult(0);
3839 }
3840
3841 /// Generate a vector implementation for:
3842 /// ```
3843 /// Op def: ( n, w, c, kw)
3844 /// Iters: ({Par(), Par(), Par(), Red()})
3845 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3846 /// ```
3847 /// kw is always unrolled.
3848 /// TODO: w (resp. kw) is unrolled when strideW (resp. dilationW) is
3849 /// > 1.
3850 FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
3851 bool channelDimScalableFlag,
3852 bool flatten) {
3853 bool scalableChDim = false;
3854 bool useMasking = false;
3855 int64_t nSize, wSize, cSize, kwSize;
3856 // kernel{kw, c}
3857 bindShapeDims(rhsShapedType, kwSize, cSize);
3858 if (ShapedType::isDynamic(cSize)) {
3859 assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
3860 cSize = channelDimVecSize;
3861 // Scalable vectors are only used when both conditions are met:
3862 // 1. channel dim is dynamic
3863 // 2. channelDimScalableFlag is set
3864 scalableChDim = channelDimScalableFlag;
3865 useMasking = true;
3866 }
3867
3868 assert(!(useMasking && flatten) &&
3869 "Unsupported flattened conv with dynamic shapes");
3870
3871 // out{n, w, c}
3872 bindShapeDims(resShapedType, nSize, wSize);
3873
3874 vector::TransferWriteOp write;
3875 Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
3876
3877 // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3878 // When strideW == 1, we can batch the contiguous loads and avoid
3879 // unrolling
3880 int64_t wSizeStep = strideW == 1 ? wSize : 1;
3881
3882 Type lhsEltType = lhsShapedType.getElementType();
3883 Type rhsEltType = rhsShapedType.getElementType();
3884 Type resEltType = resShapedType.getElementType();
3885 VectorType lhsType = VectorType::get(
3886 {nSize,
3887 // iw = ow * sw + kw * dw - 1
3888 // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3889 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
3890 cSize},
3891 lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
3892 VectorType rhsType =
3893 VectorType::get({kwSize, cSize}, rhsEltType,
3894 /*scalableDims=*/{false, scalableChDim});
3895 VectorType resType =
3896 VectorType::get({nSize, wSize, cSize}, resEltType,
3897 /*scalableDims=*/{false, false, scalableChDim});
3898
3899 // Masks the input xfer Op along the channel dim, iff the corresponding
3900 // scalable flag is set.
3901 auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
3902 ArrayRef<bool> scalableDims,
3903 Operation *opToMask) {
3904 if (!useMasking)
3905 return opToMask;
3906 auto maskType =
3907 VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
3908
3909 SmallVector<bool> inBounds(maskShape.size(), true);
3910 auto xferOp = cast<VectorTransferOpInterface>(opToMask);
3911 xferOp->setAttr(xferOp.getInBoundsAttrName(),
3912 rewriter.getBoolArrayAttr(inBounds));
3913
3914 SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
3915 cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);
3916
3917 Value maskOp =
3918 vector::CreateMaskOp::create(rewriter, loc, maskType, mixedDims);
3919
3920 return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
3921 };
3922
3923 // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
3924 // 0].
3925 Value lhs = vector::TransferReadOp::create(
3926 rewriter, loc, lhsType, lhsShaped, ValueRange{zero, zero, zero},
3927 /*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
3928 auto *maybeMaskedLhs = maybeMaskXferOp(
3929 lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
3930
3931 // Read rhs slice of size {kw, c} @ [0, 0].
3932 Value rhs = vector::TransferReadOp::create(
3933 rewriter, loc, rhsType, rhsShaped, ValueRange{zero, zero},
3934 /*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
3935 auto *maybeMaskedRhs = maybeMaskXferOp(
3936 rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
3937
3938 // Read res slice of size {n, w, c} @ [0, 0, 0].
3939 Value res = vector::TransferReadOp::create(
3940 rewriter, loc, resType, resShaped, ValueRange{zero, zero, zero},
3941 /*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
3942 auto *maybeMaskedRes = maybeMaskXferOp(
3943 resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
3944
3945 //===------------------------------------------------------------------===//
3946 // Begin vector-only rewrite part
3947 //===------------------------------------------------------------------===//
3948 // Unroll along kw and read slices of lhs and rhs.
3949 SmallVector<Value> lhsVals, rhsVals, resVals;
3950 SmallVector<int64_t> inOutSliceSizes = {nSize, wSizeStep, cSize};
3951 SmallVector<int64_t> inOutStrides = {1, 1, 1};
3952
3953 // Extract lhs slice of size {n, wSizeStep, c}
3954 // @ [0, sw * w + dw * kw, 0].
3955 for (int64_t kw = 0; kw < kwSize; ++kw) {
3956 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3957 lhsVals.push_back(vector::ExtractStridedSliceOp::create(
3958 rewriter, loc, maybeMaskedLhs->getResult(0),
3959 /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
3960 inOutSliceSizes, inOutStrides));
3961 }
3962 }
3963 // Extract rhs slice of size {c} @ [kw].
3964 for (int64_t kw = 0; kw < kwSize; ++kw) {
3965 rhsVals.push_back(
3966 vector::ExtractOp::create(rewriter, loc, maybeMaskedRhs->getResult(0),
3967 /*offsets=*/ArrayRef<int64_t>{kw}));
3968 }
3969 // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
3970 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3971 resVals.push_back(vector::ExtractStridedSliceOp::create(
3972 rewriter, loc, maybeMaskedRes->getResult(0),
3973 /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
3974 inOutStrides));
3975 }
3976
3977 auto linearIndex = [&](int64_t kw, int64_t w) {
3978 return kw * (wSize / wSizeStep) + w;
3979 };
3980
3981 // Note - the scalable flags are ignored as flattening combined with
3982 // scalable vectorization is not supported.
3983 SmallVector<int64_t> inOutFlattenSliceSizes = {nSize, wSizeStep * cSize};
3984 auto lhsTypeAfterFlattening =
3985 VectorType::get(inOutFlattenSliceSizes, lhsEltType);
3986 auto resTypeAfterFlattening =
3987 VectorType::get(inOutFlattenSliceSizes, resEltType);
3988
3989 // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
3990 for (int64_t kw = 0; kw < kwSize; ++kw) {
3991 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3992 Value lhsVal = lhsVals[linearIndex(kw, w)];
3993 Value resVal = resVals[w];
3994 if (flatten) {
3995 // Flatten the input and output vectors (collapse the channel
3996 // dimension)
3997 lhsVal =
3998 vector::ShapeCastOp::create(rewriter, loc, lhsTypeAfterFlattening,
3999 lhsVals[linearIndex(kw, w)]);
4000 resVal = vector::ShapeCastOp::create(
4001 rewriter, loc, resTypeAfterFlattening, resVals[w]);
4002 }
4003 resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
4004 rhsVals[kw], resVal, flatten);
4005 if (flatten) {
4006 // Un-flatten the output vector (restore the channel dimension)
4007 resVals[w] = vector::ShapeCastOp::create(
4008 rewriter, loc, VectorType::get(inOutSliceSizes, resEltType),
4009 resVals[w]);
4010 }
4011 }
4012 }
4013
4014 // It's possible we failed to create the FMA.
4015 if (!llvm::all_of(resVals, [](Value v) { return v; })) {
4016 // Manually revert (in reverse order) to avoid leaving a bad IR state.
4017 for (auto &collection :
4018 {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
4019 for (Value v : collection)
4020 rewriter.eraseOp(v.getDefiningOp());
4021 return rewriter.notifyMatchFailure(op, "failed to create FMA");
4022 }
4023
4024 // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
4025 // This does not depend on kw.
4026 for (int64_t w = 0; w < wSize; w += wSizeStep) {
4027 maybeMaskedRes = vector::InsertStridedSliceOp::create(
4028 rewriter, loc, resVals[w], maybeMaskedRes->getResult(0),
4029 /*offsets=*/ArrayRef<int64_t>{0, w, 0},
4030 /*strides=*/ArrayRef<int64_t>{1, 1, 1});
4031 }
4032 //===------------------------------------------------------------------===//
4033 // End vector-only rewrite part
4034 //===------------------------------------------------------------------===//
4035
4036 // Write back res slice of size {n, w, c} @ [0, 0, 0].
4037 Operation *resOut = vector::TransferWriteOp::create(
4038 rewriter, loc, maybeMaskedRes->getResult(0), resShaped,
4039 ValueRange{zero, zero, zero});
4040 return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
4041 resOut);
4042 }
4043
4044 /// Lower:
4045 /// * lhs{n, w, c} * rhs{c} -> res{n, w, c} (flatten = false)
4046 /// * lhs{n, w * c} * rhs{c} -> res{n, w * c} (flatten = true)
4047 /// to MulAcc.
4048 Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
4049 Value lhs, Value rhs, Value res,
4050 bool flatten) {
4051 auto rhsTy = cast<ShapedType>(rhs.getType());
4052 auto resTy = cast<ShapedType>(res.getType());
4053
4054 // TODO(suderman): Change this to use a vector.ima intrinsic.
4055 lhs = promote(rewriter, loc, lhs, resTy);
4056
4057 if (flatten) {
4058 // NOTE: The following logic won't work for scalable vectors. For this
4059 // reason, "flattening" is not supported when shapes are dynamic (this
4060 // should be captured by one of the pre-conditions).
4061
4062 // There are two options for handling the filter:
4063 // * shape_cast(broadcast(filter))
4064 // * broadcast(shuffle(filter))
4065 // Opt for the option without shape_cast to simplify the codegen.
4066 auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
4067 auto resSize = cast<VectorType>(res.getType()).getShape()[1];
4068
4069 SmallVector<int64_t, 16> indices;
4070 for (int i = 0; i < resSize / rhsSize; ++i) {
4071 for (int j = 0; j < rhsSize; ++j)
4072 indices.push_back(j);
4073 }
4074
4075 rhs = vector::ShuffleOp::create(rewriter, loc, rhs, rhs, indices);
4076 }
4077 // Broadcast the filter to match the output vector
4078 rhs = vector::BroadcastOp::create(rewriter, loc,
4079 resTy.clone(rhsTy.getElementType()), rhs);
4080
4081 rhs = promote(rewriter, loc, rhs, resTy);
4082
4083 if (!lhs || !rhs)
4084 return nullptr;
4085
4086 if (isa<FloatType>(resTy.getElementType()))
4087 return vector::FMAOp::create(rewriter, loc, lhs, rhs, res);
4088
4089 auto mul = arith::MulIOp::create(rewriter, loc, lhs, rhs);
4090 return arith::AddIOp::create(rewriter, loc, mul, res);
4091 }
4092
4093 /// Entry point for non-channeled convolution:
4094 /// {{w + kw}, {kw}, {w}}
4095 FailureOr<Operation *> generateNonChanneledConv() {
4096 AffineExpr w, kw;
4097 bindDims(ctx, w, kw);
4098 if (!iters({Par(), Red()}))
4099 return rewriter.notifyMatchFailure(op,
4100 "failed to match conv::W 1-par 1-red");
4101
4102 // No transposition needed.
4103 if (layout({/*lhsIndex*/ {w + kw},
4104 /*rhsIndex*/ {kw},
4105 /*resIndex*/ {w}}))
4106 return conv(Conv1DOpOrder::W);
4107
4108 return rewriter.notifyMatchFailure(op, "not a conv::W layout");
4109 }
4110
4111 /// Entry point that transposes into the common form:
4112 /// {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
4113 FailureOr<Operation *> generateNwcConv() {
4114 AffineExpr n, w, f, kw, c;
4115 bindDims(ctx, n, w, f, kw, c);
4116 if (!iters({Par(), Par(), Par(), Red(), Red()}))
4117 return rewriter.notifyMatchFailure(
4118 op, "failed to match conv::Nwc 3-par 2-red");
4119
4120 // No transposition needed.
4121 if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
4122 /*rhsIndex*/ {kw, c, f},
4123 /*resIndex*/ {n, w, f}}))
4124 return conv(Conv1DOpOrder::Nwc);
4125
4126 return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
4127 }
4128
4129 /// Entry point that transposes into the common form:
4130 /// {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
4131 FailureOr<Operation *> generateNcwConv() {
4132 AffineExpr n, w, f, kw, c;
4133 bindDims(ctx, n, f, w, c, kw);
4134 if (!iters({Par(), Par(), Par(), Red(), Red()}))
4135 return rewriter.notifyMatchFailure(
4136 op, "failed to match conv::Ncw 3-par 2-red");
4137
4138 if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
4139 /*rhsIndex*/ {f, c, kw},
4140 /*resIndex*/ {n, f, w}}))
4141 return conv(Conv1DOpOrder::Ncw);
4142
4143 return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
4144 }
4145
4146 /// Entry point that transposes into the common form:
4147 /// {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}} for pooling
4148 FailureOr<Operation *> generateNwcPooling() {
4149 AffineExpr n, w, c, kw;
4150 bindDims(ctx, n, w, c, kw);
4151 if (!iters({Par(), Par(), Par(), Red()}))
4152 return rewriter.notifyMatchFailure(op,
4153 "failed to match pooling 3-par 1-red");
4154
4155 // No transposition needed.
4156 if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
4157 /*rhsIndex*/ {kw},
4158 /*resIndex*/ {n, w, c}}))
4159 return conv(Conv1DOpOrder::Nwc);
4160
4161 return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
4162 }
4163
4164 /// Entry point that transposes into the common form:
4165 /// {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}} for pooling
4166 FailureOr<Operation *> generateNcwPooling() {
4167 AffineExpr n, w, c, kw;
4168 bindDims(ctx, n, c, w, kw);
4169 if (!iters({Par(), Par(), Par(), Red()}))
4170 return rewriter.notifyMatchFailure(op,
4171 "failed to match pooling 3-par 1-red");
4172
4173 if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
4174 /*rhsIndex*/ {kw},
4175 /*resIndex*/ {n, c, w}}))
4176 return conv(Conv1DOpOrder::Ncw);
4177
4178 return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
4179 }
4180
4181 /// Entry point that transposes into the common form:
4182 /// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
4183 FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
4184 bool vecChDimScalableFlag = false,
4185 bool flatten = false) {
4186 AffineExpr n, w, c, kw;
4187 bindDims(ctx, n, w, c, kw);
4188 if (!iters({Par(), Par(), Par(), Red()}))
4189 return rewriter.notifyMatchFailure(
4190 op, "failed to match depthwise::Nwc conv 3-par 1-red");
4191
4192 // No transposition needed.
4193 if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
4194 /*rhsIndex*/ {kw, c},
4195 /*resIndex*/ {n, w, c}}))
4196 return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);
4197
4198 return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
4199 }
4200
4201private:
4202 ConvOperationKind oper = ConvOperationKind::Conv;
4203 StringAttr redOp;
4204 StringAttr poolExtOp;
4205 bool isPoolExt = false;
4206 int strideW, dilationW;
4207 Value lhsShaped, rhsShaped, resShaped;
4208 ShapedType lhsShapedType, rhsShapedType, resShapedType;
4209 vector::CombiningKind reductionKind;
4210
4211 // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
4212 void setConvOperationKind(Operation *reduceOp) {
4213 int numBlockArguments =
4214 llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
4215 if (numBlockArguments == 1) {
4216 // Will be convolution if feeder is a MulOp.
4217 // A strength reduced version of MulOp for i1 type is AndOp which is also
4218 // supported. Otherwise, it can be pooling. This strength reduction logic
4219 // is in `buildBinaryFn` helper in the Linalg dialect.
4220 auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
4221 llvm::IsaPred<BlockArgument>);
4222 Operation *feedOp = (*feedValIt).getDefiningOp();
4223 if (isCastOfBlockArgument(feedOp)) {
4224 oper = ConvOperationKind::Pool;
4225 isPoolExt = true;
4226 poolExtOp = feedOp->getName().getIdentifier();
4227 return;
4228 }
4229 oper = ConvOperationKind::Conv;
4230 return;
4231 }
4232 // numBlockArguments == 2, so this is a pooling op.
4233 oper = ConvOperationKind::Pool;
4234 isPoolExt = false;
4235 }
4236};
4237} // namespace
4238
4239/// Helper function to vectorize a LinalgOp with convolution semantics.
4240// TODO: extend the generic vectorization to support windows and drop this.
4241static FailureOr<Operation *> vectorizeConvolution(
4242 RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
4243 ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
4244 Conv1DGenerator conv1dGen(rewriter, op);
4245 auto res = conv1dGen.generateNonChanneledConv();
4246 if (succeeded(res))
4247 return res;
4248 res = conv1dGen.generateNwcConv();
4249 if (succeeded(res))
4250 return res;
4251 res = conv1dGen.generateNcwConv();
4252 if (succeeded(res))
4253 return res;
4254 res = conv1dGen.generateNwcPooling();
4255 if (succeeded(res))
4256 return res;
4257 res = conv1dGen.generateNcwPooling();
4258 if (succeeded(res))
4259 return res;
4260
4261 // Only depthwise 1D NWC convs are left - these can be vectorized using masks
4262 // and scalable vectors. Note that ATM the only dim that can be dynamic (i.e.
4263 // masked/scalable) is the channel dim (i.e. the trailing dim).
4264 uint64_t vecChDimSize = ShapedType::kDynamic;
4265 bool vecChDimScalableFlag = false;
4266 if (!inputVecSizes.empty()) {
4267 // Only use the input vector size corresponding to the channel dim. Other
4268 // vector dims will be inferred from the Ops.
4269 assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
4270 isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
4271 "Not a 1D depthwise conv!");
4272 size_t chDimIdx =
4273 TypeSwitch<Operation *, size_t>(*op)
4274 .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
4275 .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });
4276
4277 vecChDimSize = inputVecSizes[chDimIdx];
4278 vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
4279 }
4280 return conv1dGen.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
4281 flatten1DDepthwiseConv);
4282}
4283
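/// Rewrite pattern that wraps `vectorizeConvolution`: on success the original
/// linalg op is replaced with the vectorized convolution, or erased when the
/// vectorized form has no results.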
4284struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
4285 using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
4286
4287 LogicalResult matchAndRewrite(LinalgOp op,
4288 PatternRewriter &rewriter) const override {
4289 FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
4290 if (failed(resultOrFail))
4291 return failure();
4292 Operation *newOp = *resultOrFail;
4293 if (newOp->getNumResults() == 0) {
4294 rewriter.eraseOp(op.getOperation());
4295 return success();
4296 }
4297 assert(newOp->getNumResults() == 1 && "expected single result");
4298 rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
4299 return success();
4300 }
4301};
4302
4303 void mlir::linalg::populateConvolutionVectorizationPatterns(
4304 RewritePatternSet &patterns, PatternBenefit benefit) {
4305 patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
4306}
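
A minimal usage sketch, not part of this file: how a client pass might register and apply the pattern populated above. The helper name applyConvVectorization is hypothetical, and the greedy driver entry point (applyPatternsGreedily from "mlir/Transforms/GreedyPatternRewriteDriver.h") is assumed from the standard pattern-rewrite API; it is shown here only to illustrate the intended call site of populateConvolutionVectorizationPatterns.

    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    // Hypothetical helper (not part of Vectorization.cpp): collect the
    // convolution vectorization patterns and apply them greedily to `root`.
    static void applyConvVectorization(Operation *root) {
      RewritePatternSet patterns(root->getContext());
      // Registers the VectorizeConvolution pattern defined above.
      linalg::populateConvolutionVectorizationPatterns(patterns, /*benefit=*/2);
      // Each matched low-D conv/pooling op is rewritten into vector form.
      (void)applyPatternsGreedily(root, std::move(patterns));
    }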