1//===- Vectorization.cpp - Implementation of linalg Vectorization ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the linalg dialect Vectorization transformations.
10//
11//===----------------------------------------------------------------------===//
13
27#include "mlir/IR/AffineExpr.h"
28#include "mlir/IR/AffineMap.h"
29#include "mlir/IR/Builders.h"
34#include "mlir/IR/Value.h"
35#include "mlir/Support/LLVM.h"
37#include "llvm/ADT/STLExtras.h"
38#include "llvm/ADT/Sequence.h"
39#include "llvm/ADT/SmallVector.h"
40#include "llvm/ADT/TypeSwitch.h"
41#include "llvm/Support/DebugLog.h"
42#include "llvm/Support/InterleavedRange.h"
43#include "llvm/Support/MathExtras.h"
44#include "llvm/Support/raw_ostream.h"
45#include <optional>
46
47using namespace mlir;
48using namespace mlir::linalg;
49
50#define DEBUG_TYPE "linalg-vectorization"
51
52/// Try to vectorize `convOp` as a convolution.
53static FailureOr<Operation *>
54vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
55 ArrayRef<int64_t> inputVecSizes = {},
56 ArrayRef<bool> inputVecScalableFlags = {},
57 bool flatten1DDepthwiseConv = false);
58
59/// Vectorize tensor::InsertSliceOp with:
60/// * vector::TransferReadOp + vector::TransferWriteOp
61/// The vector sizes are either:
62/// * user-provided in `inputVectorSizes`, or
63/// * inferred from the static dims in the input and output tensors.
64/// Bails out if:
65/// * vector sizes are not user-provided, and
66/// * at least one dim is dynamic (in both the input and output tensors).
67///
68/// Before:
69/// !t_in_type = tensor<1x2x3xf32>
70/// !t_out_type = tensor<9x8x7x1x2x3xf32>
71/// !v_type = vector<1x2x3xf32>
72/// %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
73/// into !t_out_type
74/// After:
75/// %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
76/// %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
77static LogicalResult
78vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
79 ArrayRef<int64_t> inputVectorSizes,
80 SmallVectorImpl<Value> &newResults);
81
82/// Returns the effective Pad value for the input op, provided it's a scalar.
83///
84/// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
85/// this Op performs padding, retrieve the padding value provided that it's
86/// a scalar and static/fixed for all the padded values. Returns an empty value
87/// otherwise.
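///
/// For example (an illustrative sketch; SSA names are made up), for a
/// tensor.pad whose region yields a constant scalar, that constant is the
/// effective pad value:
///   %cst = arith.constant 0.0 : f32
///   %padded = tensor.pad %src low[0, 1] high[0, 1] {
///   ^bb0(%i: index, %j: index):
///     tensor.yield %cst : f32
///   } : tensor<4x4xf32> to tensor<4x6xf32>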
89
90/// Return the unique instance of OpType in `block` if it is indeed unique.
91/// Return null if zero or more than one instance exists.
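/// A minimal usage sketch (illustrative): retrieve the unique linalg.yield
/// terminator of a LinalgOp body, or null if there is none or more than one:
///   auto yieldOp = getSingleOpOfType<linalg::YieldOp>(*linalgOp.getBlock());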
92template <typename OpType>
93static OpType getSingleOpOfType(Block &block) {
94 OpType res;
95 block.walk([&](OpType op) {
96 if (res) {
97 res = nullptr;
98 return WalkResult::interrupt();
99 }
100 res = op;
101 return WalkResult::advance();
102 });
103 return res;
104}
105
106/// Helper function to extract the input slices after the filter is unrolled along
107/// kw.
110 int64_t nSize, int64_t wSize, int64_t cSize,
111 int64_t kwSize, int strideW, int dilationW,
112 int64_t wSizeStep, bool isSingleChanneled) {
114 if (isSingleChanneled) {
115 // Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled
116 // convolution.
117 SmallVector<int64_t> sizes = {wSizeStep};
118 SmallVector<int64_t> strides = {1};
119 for (int64_t kw = 0; kw < kwSize; ++kw) {
120 for (int64_t w = 0; w < wSize; w += wSizeStep) {
121 result.push_back(vector::ExtractStridedSliceOp::create(
122 rewriter, loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw}, sizes,
123 strides));
124 }
125 }
126 } else {
127 // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0]
128 // for channeled convolution.
129 SmallVector<int64_t> sizes = {nSize, wSizeStep, cSize};
130 SmallVector<int64_t> strides = {1, 1, 1};
131 for (int64_t kw = 0; kw < kwSize; ++kw) {
132 for (int64_t w = 0; w < wSize; w += wSizeStep) {
133 result.push_back(vector::ExtractStridedSliceOp::create(
134 rewriter, loc, input,
135 /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
136 sizes, strides));
137 }
138 }
139 }
140 return result;
141}
142
143/// Helper function to extract the filter slices after the filter is unrolled along
144/// kw.
146 Location loc, Value filter,
147 int64_t kwSize) {
149 // Extract rhs slice of size [{c, f} for channeled convolutions and {1} for
150 // non-channeled convolution] @ [kw].
151 for (int64_t kw = 0; kw < kwSize; ++kw) {
152 result.push_back(vector::ExtractOp::create(
153 rewriter, loc, filter, /*offsets=*/ArrayRef<int64_t>{kw}));
154 }
155 return result;
156}
157
158/// Helper function to extract the result slices after the filter is unrolled along
159/// kw.
162 int64_t nSize, int64_t wSize, int64_t fSize,
163 int64_t wSizeStep, bool isSingleChanneled) {
165 if (isSingleChanneled) {
166 // Extract res slice: {wSizeStep} @ [w] for non-channeled convolution.
167 SmallVector<int64_t> sizes = {wSizeStep};
168 SmallVector<int64_t> strides = {1};
169 for (int64_t w = 0; w < wSize; w += wSizeStep) {
170 result.push_back(vector::ExtractStridedSliceOp::create(
171 rewriter, loc, res, /*offsets=*/ArrayRef<int64_t>{w}, sizes,
172 strides));
173 }
174 } else {
175 // Extract res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
176 // convolution.
177 SmallVector<int64_t> sizes = {nSize, wSizeStep, fSize};
178 SmallVector<int64_t> strides = {1, 1, 1};
179 for (int64_t w = 0; w < wSize; w += wSizeStep) {
180 result.push_back(vector::ExtractStridedSliceOp::create(
181 rewriter, loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes,
182 strides));
183 }
184 }
185 return result;
186}
187
188/// Helper function to insert the computed result slices.
190 Value res, int64_t wSize, int64_t wSizeStep,
191 SmallVectorImpl<Value> &resVals,
192 bool isSingleChanneled) {
193
194 if (isSingleChanneled) {
195 // Write back res slice: {wSizeStep} @ [w] for non-channeled convolution.
196 // This does not depend on kw.
197 SmallVector<int64_t> strides = {1};
198 for (int64_t w = 0; w < wSize; w += wSizeStep) {
199 res = vector::InsertStridedSliceOp::create(
200 rewriter, loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w},
201 strides);
202 }
203 } else {
204 // Write back res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled
205 // convolution. This does not depend on kw.
206 SmallVector<int64_t> strides = {1, 1, 1};
207 for (int64_t w = 0; w < wSize; w += wSizeStep) {
208 res = vector::InsertStridedSliceOp::create(
209 rewriter, loc, resVals[w], res,
210 /*offsets=*/ArrayRef<int64_t>{0, w, 0}, strides);
211 }
212 }
213 return res;
214}
215
216/// Contains the vectorization state and related methods used across the
217/// vectorization process of a given operation.
219 VectorizationState(RewriterBase &rewriter) : rewriterGuard(rewriter) {}
220
221 /// Initializes the vectorization state, including the computation of the
222 /// canonical vector shape for vectorization.
223 LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
224 ArrayRef<int64_t> inputVectorSizes,
225 ArrayRef<bool> inputScalableVecDims,
226 bool assumeDynamicDimsMatchVecSizes = false);
227
228 /// Returns the canonical vector shape used to vectorize the iteration space.
229 ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
230
231 /// Returns the vector dimensions that are scalable in the canonical vector
232 /// shape.
233 ArrayRef<bool> getScalableVecDims() const { return scalableVecDims; }
234
235 /// Returns a vector type of the provided `elementType` with the canonical
236 /// vector shape and the corresponding fixed/scalable dimensions bit. If
237 /// `dimPermutation` is provided, the canonical vector dimensions are permuted
238 /// accordingly.
240 Type elementType,
241 std::optional<AffineMap> dimPermutation = std::nullopt) const {
243 SmallVector<bool> scalableDims;
244 if (dimPermutation.has_value()) {
246 applyPermutationMap<int64_t>(*dimPermutation, canonicalVecShape);
247 scalableDims =
248 applyPermutationMap<bool>(*dimPermutation, scalableVecDims);
249 } else {
250 vectorShape.append(canonicalVecShape.begin(), canonicalVecShape.end());
251 scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
252 }
253
254 return VectorType::get(vectorShape, elementType, scalableDims);
255 }
256
257 /// Masks an operation with the canonical vector mask if the operation needs
258 /// masking. Returns the masked operation or the original operation if masking
259 /// is not needed. If provided, the canonical mask for this operation is
260 /// permuted using `maybeIndexingMap`.
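  ///
  /// For example (an illustrative sketch; SSA names are made up), a masked
  /// transfer_read ends up wrapped as:
  ///   %read = vector.mask %mask {
  ///     vector.transfer_read %src[%c0, %c0], %pad
  ///       : tensor<?x?xf32>, vector<8x16xf32>
  ///   } : vector<8x16xi1> -> vector<8x16xf32>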
261 Operation *
262 maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
263 std::optional<AffineMap> maybeIndexingMap = std::nullopt);
264
265private:
266 /// Initializes the iteration space static sizes using the Linalg op
267 /// information. This may become more complicated in the future.
268 void initIterSpaceStaticSizes(LinalgOp linalgOp) {
269 iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
270 }
271
272 /// Generates 'arith.constant' and 'tensor/memref.dim' operations for
273 /// all the static and dynamic dimensions of the iteration space to be
274 /// vectorized and store them in `iterSpaceValueSizes`.
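  ///
  /// For example (illustrative; names are made up), for a 2-D iteration space
  /// with static size 8 and one dynamic size, this emits roughly:
  ///   %c8 = arith.constant 8 : index
  ///   %dim = tensor.dim %operand, %c1 : tensor<8x?xf32>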
275 LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
276 LinalgOp linalgOp);
277
278 /// Create or retrieve an existing mask value to mask `opToMask` in the
279 /// canonical vector iteration space. If `maybeMaskingMap` is provided, the mask is
280 /// permuted using that permutation map. If a new mask is created, it will be
281 /// cached for future users.
282 Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
283 LinalgOp linalgOp,
284 std::optional<AffineMap> maybeMaskingMap);
285
286 /// Check whether this permutation map can be used for masking. At the
287 /// moment we only make sure that there are no broadcast dimensions, but this
288 /// might change if indexing maps evolve.
289 bool isValidMaskingMap(AffineMap maskingMap) {
290 return maskingMap.getBroadcastDims().empty();
291 }
292
293 /// Turn the input indexing map into a valid masking map.
294 ///
295 /// The input indexing map may contain "zero" results, e.g.:
296 /// (d0, d1, d2, d3) -> (d2, d1, d0, 0)
297 /// Applying such maps to canonical vector shapes like this one:
298 /// (1, 16, 16, 4)
299 /// would yield an invalid vector shape like this:
300 /// (16, 16, 1, 0)
301 /// Instead, drop the broadcasting dims that make no sense for masking perm.
302 /// maps:
303 /// (d0, d1, d2, d3) -> (d2, d1, d0)
304 /// This way, the corresponding vector/mask type will be:
305 /// vector<16x16x1xty>
306 /// rather than this invalid Vector type:
307 /// vector<16x16x1x0xty>
308 AffineMap getMaskingMapFromIndexingMap(AffineMap &indexingMap) {
309 return indexingMap.dropZeroResults();
310 }
311
312 // Holds the compile-time static sizes of the iteration space to vectorize.
313 // Dynamic dimensions are represented using ShapedType::kDynamic.
314 SmallVector<int64_t> iterSpaceStaticSizes;
315
316 /// Holds the value sizes of the iteration space to vectorize. Static
317 /// dimensions are represented by 'arith.constant' and dynamic
318 /// dimensions by 'tensor/memref.dim'.
319 SmallVector<Value> iterSpaceValueSizes;
320
321 /// Holds the canonical vector shape used to vectorize the iteration space.
322 SmallVector<int64_t> canonicalVecShape;
323
324 /// Holds the vector dimensions that are scalable in the canonical vector
325 /// shape.
326 SmallVector<bool> scalableVecDims;
327
328 /// Holds the active masks for permutations of the canonical vector iteration
329 /// space.
330 DenseMap<AffineMap, Value> activeMaskCache;
331
332 /// Global vectorization guard for the incoming rewriter. It's initialized
333 /// when the vectorization state is initialized.
334 OpBuilder::InsertionGuard rewriterGuard;
335
336 /// Do all dynamic dims match the corresponding vector sizes?
337 ///
338 /// When a dynamic tensor/memref dimension matches the corresponding vector
339 /// dimension, masking can be safely skipped, despite the presence of dynamic
340 /// shapes. Use this flag with care and only for cases where you are
341 /// confident the assumption holds.
342 bool assumeDynamicDimsMatchVecSizes = false;
343};
344
345LogicalResult
346VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
347 LinalgOp linalgOp) {
348 // TODO: Support 0-d vectors.
349 for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
350 if (ShapedType::isStatic(iterSpaceStaticSizes[vecDim])) {
351 // Create constant index op for static dimensions.
352 iterSpaceValueSizes.push_back(arith::ConstantIndexOp::create(
353 rewriter, linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
354 continue;
355 }
356
357 // Find an operand defined on this dimension of the iteration space to
358 // extract the runtime dimension size.
359 Value operand;
360 unsigned operandDimPos;
361 if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
362 operandDimPos)))
363 return failure();
364
365 Value dynamicDim =
366 linalgOp.hasPureTensorSemantics()
367 ? (Value)tensor::DimOp::create(rewriter, linalgOp.getLoc(), operand,
368 operandDimPos)
369 : (Value)memref::DimOp::create(rewriter, linalgOp.getLoc(), operand,
370 operandDimPos);
371 iterSpaceValueSizes.push_back(dynamicDim);
372 }
373
374 return success();
375}
376
377/// Initializes the vectorization state, including the computation of the
378/// canonical vector shape for vectorization.
379// TODO: Move this to the constructor when we can remove the failure cases.
381 LinalgOp linalgOp,
382 ArrayRef<int64_t> inputVectorSizes,
383 ArrayRef<bool> inputScalableVecDims,
384 bool assumeDimsMatchVec) {
385 assumeDynamicDimsMatchVecSizes = assumeDimsMatchVec;
386 // Initialize the insertion point.
387 rewriter.setInsertionPoint(linalgOp);
388
389 if (!inputVectorSizes.empty()) {
390 // Get the canonical vector shape from the input vector sizes provided. This
391 // path should be taken to vectorize code with dynamic shapes and when using
392 // vector sizes greater than the iteration space sizes.
393 canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
394 scalableVecDims.append(inputScalableVecDims.begin(),
395 inputScalableVecDims.end());
396 } else {
397 // Compute the canonical vector shape from the operation shape. If there are
398 // dynamic shapes, the operation won't be vectorized. We assume all the
399 // vector dimensions are fixed.
400 canonicalVecShape = linalgOp.getStaticLoopRanges();
401 scalableVecDims.append(linalgOp.getNumLoops(), false);
402 }
403
404 LDBG() << "Canonical vector shape: " << llvm::interleaved(canonicalVecShape);
405 LDBG() << "Scalable vector dims: " << llvm::interleaved(scalableVecDims);
406
407 if (ShapedType::isDynamicShape(canonicalVecShape))
408 return failure();
409
410 // Initialize iteration space static sizes.
411 initIterSpaceStaticSizes(linalgOp);
412
413 // Generate 'arith.constant' and 'tensor/memref.dim' operations for
414 // all the static and dynamic dimensions of the iteration space, needed to
415 // compute a mask during vectorization.
416 if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
417 return failure();
418
419 return success();
420}
421
422/// Create or retrieve an existing mask value to mask `opToMask` in the
423/// canonical vector iteration space. If `maybeMaskingMap` is provided, the mask is permuted
424/// using that permutation map. If a new mask is created, it will be cached for
425/// future users.
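///
/// For example (an illustrative sketch; SSA names are made up), for a mask
/// type vector<8x16xi1> and permuted iteration-space sizes %d0 and %d1, the
/// created mask is:
///   %mask = vector.create_mask %d0, %d1 : vector<8x16xi1>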
426Value VectorizationState::getOrCreateMaskFor(
427 RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
428 std::optional<AffineMap> maybeMaskingMap) {
429
430 assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) &&
431 "Ill-formed masking map.");
432
433 // No mask is needed if the operation is not maskable.
434 auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
435 if (!maskableOp)
436 return Value();
437
438 assert(!maskableOp.isMasked() &&
439 "Masking an operation that is already masked");
440
441 // If no masking map was provided, use an identity map with the loop dims.
442 assert((!maybeMaskingMap || *maybeMaskingMap) &&
443 "Unexpected null mask permutation map");
444 AffineMap maskingMap =
445 maybeMaskingMap ? *maybeMaskingMap
447 linalgOp.getNumLoops(), rewriter.getContext());
448
449 LDBG() << "Masking map: " << maskingMap;
450
451 // Return the active mask for the masking map of this operation if it was
452 // already created.
453 auto activeMaskIt = activeMaskCache.find(maskingMap);
454 if (activeMaskIt != activeMaskCache.end()) {
455 Value mask = activeMaskIt->second;
456 LDBG() << "Reusing mask: " << mask;
457 return mask;
458 }
459
460 // Compute permuted projection of the iteration space to be masked and the
461 // corresponding mask shape. If the resulting iteration space dimensions are
462 // static and identical to the mask shape, masking is not needed for this
463 // operation.
464 // TODO: Improve this check. Only projected permutation indexing maps are
465 // supported.
466 SmallVector<int64_t> permutedStaticSizes =
467 applyPermutationMap<int64_t>(maskingMap, iterSpaceStaticSizes);
468 auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap);
469 auto maskShape = maskType.getShape();
470
471 LDBG() << "Mask shape: " << llvm::interleaved(maskShape);
472
473 if (permutedStaticSizes == maskShape) {
474 LDBG() << "Masking is not needed for masking map: " << maskingMap;
475 activeMaskCache[maskingMap] = Value();
476 return Value();
477 }
478
479 if (assumeDynamicDimsMatchVecSizes) {
480 // While for _dynamic_ dim sizes we can _assume_ that the corresponding
481 // vector sizes match, we still need to check the _static_ dim sizes. Only
482 // then can we be 100% sure that masking is not required.
483 if (llvm::all_of(llvm::zip(permutedStaticSizes, maskType.getShape()),
484 [](auto it) {
485 return std::get<0>(it) == ShapedType::kDynamic
486 ? true
487 : std::get<0>(it) == std::get<1>(it);
488 })) {
489 LDBG()
490 << "Dynamic + static dimensions match vector sizes, masking is not "
491 "required.";
492 activeMaskCache[maskingMap] = Value();
493 return Value();
494 }
495 }
496
497 // Permute the iteration space value sizes to compute the mask upper bounds.
498 SmallVector<Value> upperBounds =
499 applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
500 assert(!maskShape.empty() && !upperBounds.empty() &&
501 "Masked 0-d vectors are not supported yet");
502
503 // Create the mask based on the dimension values.
504 Value mask = vector::CreateMaskOp::create(rewriter, linalgOp.getLoc(),
505 maskType, upperBounds);
506 LDBG() << "Creating new mask: " << mask;
507 activeMaskCache[maskingMap] = mask;
508 return mask;
509}
510
511Operation *
513 LinalgOp linalgOp,
514 std::optional<AffineMap> maybeIndexingMap) {
515 LDBG() << "Trying to mask: " << *opToMask;
516
517 std::optional<AffineMap> maybeMaskingMap = std::nullopt;
518 if (maybeIndexingMap)
519 maybeMaskingMap = getMaskingMapFromIndexingMap(*maybeIndexingMap);
520
521 // Create or retrieve mask for this operation.
522 Value mask =
523 getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
524
525 if (!mask) {
526 LDBG() << "No mask required";
527 if (assumeDynamicDimsMatchVecSizes) {
529 .Case<vector::TransferReadOp, vector::TransferWriteOp>(
530 [&](auto xferOp) {
531 // For vector.transfer_read and vector.transfer_write, there is
532 // also the `in-bounds` attribute that has to be set explicitly
533 // to true. Otherwise, "out-of-bounds" access will be assumed
534 // and masks will be generated while lowering these.
535 LDBG() << "Assuming dynamic dimensions match vector sizes and "
536 "setting their in-bounds to true!";
537 SmallVector<bool> inBoundsMap = xferOp.getInBoundsValues();
538 ShapedType xferType = xferOp.getShapedType();
539 AffineMap permMap = xferOp.getPermutationMap();
540 // Only set the in-bounds values to true for dynamic dims.
541 // Different mechanisms will set these accordingly for the
542 // static dims.
543 for (unsigned i = 0; i < xferOp.getTransferRank(); i++) {
544 auto dimExpr = dyn_cast<AffineDimExpr>(permMap.getResult(i));
545 // Skip broadcast dimensions.
546 if (!dimExpr)
547 continue;
548 unsigned pos = dimExpr.getPosition();
549 if (xferType.isDynamicDim(pos))
550 inBoundsMap[i] = true;
551 }
552 rewriter.modifyOpInPlace(xferOp, [&]() {
553 xferOp.setInBoundsAttr(
554 rewriter.getBoolArrayAttr(inBoundsMap));
555 });
556 })
557 .Default([](Operation *op) {
558 // No-op if the operation is not an xfer read or write.
559 });
560 }
561 return opToMask;
562 }
563
564 // Wrap the operation with a new `vector.mask` and update D-U chain.
565 assert(opToMask && "Expected a valid operation to mask");
566 auto maskOp = cast<vector::MaskOp>(
567 mlir::vector::maskOperation(rewriter, opToMask, mask));
568 Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
569
570 for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
571 rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
572 maskOpTerminator);
573
574 LDBG() << "Masked operation: " << *maskOp;
575 return maskOp;
576}
577
578/// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
579/// projectedPermutation, compress the unused dimensions to serve as a
580/// permutation_map for a vector transfer operation.
581/// For example, given a linalg op such as:
582///
583/// ```
584/// %0 = linalg.generic {
585/// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>,
586/// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)>
587/// }
588/// ins(%0 : tensor<2x3x4xf32>)
589/// outs(%1 : tensor<5x6xf32>)
590/// ```
591///
592/// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine
593/// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second
594/// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`.
596 assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) &&
597 "expected projected permutation");
598 auto res = compressUnusedDims(map);
599 assert(res.getNumDims() ==
600 (res.getNumResults() - res.getNumOfZeroResults()) &&
601 "expected reindexed map with same number of dims and results");
602 return res;
603}
604
605/// Helper enum to represent conv1d input traversal order.
606enum class Conv1DOpOrder {
607 W, // Corresponds to non-channeled 1D convolution operation.
608 Ncw, // Corresponds to operation that traverses the input in (n, c, w) order.
609 Nwc // Corresponds to operation that traverses the input in (n, w, c) order.
610};
611
612/// Helper data structure to represent the result of vectorization for a single
613/// operation. In certain specific cases, like terminators, we do not want to
614/// propagate the vectorized results.
616 /// Op failed to vectorize.
618 /// Op vectorized and custom function took care of replacement logic
620 /// Op vectorized into a new Op whose results will replace original Op's
621 /// results.
623 // TODO: support values if Op vectorized to Many-Ops whose results we need to
624 // aggregate for replacement.
625};
626/// VectorizationHookResult contains the vectorized op returned from a
627/// CustomVectorizationHook. This is an internal implementation detail of
628/// linalg vectorization, not to be confused with VectorizationResult.
630 /// Return status from vectorizing the current op.
632 /// New vectorized operation to replace the current op.
633 /// Replacement behavior is specified by `status`.
635};
636
637std::optional<vector::CombiningKind>
639 using ::mlir::vector::CombiningKind;
640
641 if (!combinerOp)
642 return std::nullopt;
644 .Case<arith::AddIOp, arith::AddFOp>(
645 [&](auto op) { return CombiningKind::ADD; })
646 .Case<arith::AndIOp>([&](auto op) { return CombiningKind::AND; })
647 .Case<arith::MaxSIOp>([&](auto op) { return CombiningKind::MAXSI; })
648 .Case<arith::MaxUIOp>([&](auto op) { return CombiningKind::MAXUI; })
649 .Case<arith::MaximumFOp>([&](auto op) { return CombiningKind::MAXIMUMF; })
650 .Case<arith::MaxNumFOp>([&](auto op) { return CombiningKind::MAXNUMF; })
651 .Case<arith::MinSIOp>([&](auto op) { return CombiningKind::MINSI; })
652 .Case<arith::MinUIOp>([&](auto op) { return CombiningKind::MINUI; })
653 .Case<arith::MinimumFOp>([&](auto op) { return CombiningKind::MINIMUMF; })
654 .Case<arith::MinNumFOp>([&](auto op) { return CombiningKind::MINNUMF; })
655 .Case<arith::MulIOp, arith::MulFOp>(
656 [&](auto op) { return CombiningKind::MUL; })
657 .Case<arith::OrIOp>([&](auto op) { return CombiningKind::OR; })
658 .Case<arith::XOrIOp>([&](auto op) { return CombiningKind::XOR; })
659 .Default(std::nullopt);
660}
661
662/// Check whether `outputOperand` is a reduction with a single combiner
663/// operation. Return the combiner operation of the reduction. Return
664/// nullptr otherwise. Multiple reduction operations would impose an
665/// ordering between reduction dimensions and are currently unsupported in
666/// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) !=
667/// max(min(X))
668// TODO: use in LinalgOp verification, there is a circular dependency atm.
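//
// For example (illustrative), for the output region below the combiner is the
// single arith.addf feeding the yield:
//   ^bb0(%in: f32, %acc: f32):
//     %sum = arith.addf %in, %acc : f32
//     linalg.yield %sum : f32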
669static Operation *matchLinalgReduction(OpOperand *outputOperand) {
670 auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
671 unsigned outputPos =
672 outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs();
673 // Only single combiner operations are supported for now.
674 SmallVector<Operation *, 4> combinerOps;
675 if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
676 combinerOps.size() != 1)
677 return nullptr;
678
679 // Return the combiner operation.
680 return combinerOps[0];
681}
682
683/// Broadcast `value` to a vector of the `dstType` shape if possible. Return `value`
684/// otherwise.
685static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) {
686 auto dstVecType = dyn_cast<VectorType>(dstType);
687 // If no shape to broadcast to, just return `value`.
688 if (dstVecType.getRank() == 0)
689 return value;
690 if (vector::isBroadcastableTo(value.getType(), dstVecType) !=
692 return value;
693 Location loc = b.getInsertionPoint()->getLoc();
694 return b.createOrFold<vector::BroadcastOp>(loc, dstVecType, value);
695}
696
697/// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
698/// assumes that `reductionOp` has two operands and one of them is the reduction
699/// initial value.
700// Note: this is a true builder that notifies the OpBuilder listener.
701// TODO: Consider moving as a static helper on the ReduceOp.
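//
// For example (an illustrative sketch; SSA names are made up), reducing
// dimension 1 of a vector<4x8xf32> with an ADD combiner produces roughly:
//   %r = vector.multi_reduction <add>, %v, %acc [1]
//          : vector<4x8xf32> to vector<4xf32>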
703 Value valueToReduce, Value acc,
704 ArrayRef<bool> dimsToMask) {
705 auto maybeKind = getCombinerOpKind(reduceOp);
706 assert(maybeKind && "Failed precondition: could not get reduction kind");
707 return vector::MultiDimReductionOp::create(
708 b, reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
709}
710
711static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
712 return llvm::to_vector(
713 llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
714}
715
716/// Check if `op` is a linalg.reduce or a linalg.generic that has at least one
717/// reduction iterator.
718static bool hasReductionIterator(LinalgOp &op) {
719 return isa<linalg::ReduceOp>(op) ||
720 (isa<linalg::GenericOp>(op) &&
721 llvm::any_of(op.getIteratorTypesArray(), isReductionIterator));
722}
723
724/// Build a vector.transfer_write of `value` into `outputOperand` at indices set
725/// to all `0`; where `outputOperand` is an output operand of the LinalgOp
726/// currently being vectorized. If `dest` has null rank, build a memref.store.
727/// Return the produced value or null if no value is produced.
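///
/// For example (an illustrative sketch; SSA names are made up), writing a
/// vector<8x16xf32> into a matching init tensor at all-zero indices:
///   %w = vector.transfer_write %val, %init[%c0, %c0]
///          {in_bounds = [true, true]} : vector<8x16xf32>, tensor<8x16xf32>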
728// Note: this is a true builder that notifies the OpBuilder listener.
729// TODO: Consider moving as a static helper on the ReduceOp.
730static Value buildVectorWrite(RewriterBase &rewriter, Value value,
731 OpOperand *outputOperand,
732 VectorizationState &state) {
733 Location loc = value.getLoc();
734 auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
735 AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
736
737 // Compute the vector type of the value to store. This type should be an
738 // identity or projection of the canonical vector type without any permutation
739 // applied, given that any permutation in a transfer write happens as part of
740 // the write itself.
742 opOperandMap.getContext(), opOperandMap.getNumInputs(),
743 [&](AffineDimExpr dimExpr) -> bool {
744 return llvm::is_contained(opOperandMap.getResults(), dimExpr);
745 });
746 auto vectorType = state.getCanonicalVecType(
747 getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap);
748
749 SmallVector<Value> indices(linalgOp.getRank(outputOperand),
750 arith::ConstantIndexOp::create(rewriter, loc, 0));
751
752 Operation *write;
753 if (vectorType.getRank() > 0) {
754 AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap));
755 value = broadcastIfNeeded(rewriter, value, vectorType);
756 assert(value.getType() == vectorType && "Incorrect type");
757 write = vector::TransferWriteOp::create(
758 rewriter, loc, value, outputOperand->get(), indices, writeMap);
759 } else {
760 // 0-d case is still special: do not invert the reindexing writeMap.
761 if (!isa<VectorType>(value.getType()))
762 value = vector::BroadcastOp::create(rewriter, loc, vectorType, value);
763 assert(value.getType() == vectorType && "Incorrect type");
764 write = vector::TransferWriteOp::create(rewriter, loc, value,
765 outputOperand->get(), indices);
766 }
767
768 write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
769
770 // If masked, set in-bounds to true. Masking guarantees that the access will
771 // be in-bounds.
772 if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
773 auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
774 SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
775 maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
776 }
777
778 LDBG() << "vectorized op: " << *write;
779 if (!write->getResults().empty())
780 return write->getResult(0);
781 return Value();
782}
783
784// Custom vectorization precondition function type. This is intended to be used
785// with CustomVectorizationHook. Returns success if the corresponding custom
786// hook can vectorize the op.
788 std::function<LogicalResult(Operation *, bool)>;
789
790// Custom vectorization function type. Produce a vector form of Operation*
791// assuming all its vectorized operands are already in the IRMapping.
792// Return nullptr if the Operation cannot be vectorized.
794 std::function<VectorizationHookResult(Operation *, const IRMapping &)>;
795
796/// Helper function to vectorize the terminator of a `linalgOp`. New result
797/// vector values are appended to `newResults`. Return
798/// VectorizationHookStatus::NoReplace to signal the vectorization algorithm
799/// that it should not try to map produced operations and instead return the
800/// results using the `newResults` vector, making them available to the
801/// vectorization algorithm for RAUW. This function is meant to be used as a
802/// CustomVectorizationHook.
805 const IRMapping &bvm, VectorizationState &state,
806 LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
807 auto yieldOp = dyn_cast<linalg::YieldOp>(op);
808 if (!yieldOp)
810 for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
811 // TODO: Scan for an opportunity for reuse.
812 // TODO: use a map.
813 Value vectorValue = bvm.lookup(output.value());
814 Value newResult =
815 buildVectorWrite(rewriter, vectorValue,
816 linalgOp.getDpsInitOperand(output.index()), state);
817 if (newResult)
818 newResults.push_back(newResult);
819 }
820
822}
823
824/// Helper function to vectorize the index operations of a `linalgOp`. Return
825/// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it
826/// should map the produced operations. This function is meant to be used as a
827/// CustomVectorizationHook.
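///
/// For example (an illustrative sketch), for a canonical vector shape
/// (4, 8, 16), `linalg.index 1` is vectorized roughly as:
///   %step  = vector.step : vector<8xindex>
///   %bcast = vector.broadcast %step : vector<8xindex> to vector<4x16x8xindex>
///   %idx   = vector.transpose %bcast, [0, 2, 1]
///              : vector<4x16x8xindex> to vector<4x8x16xindex>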
829 VectorizationState &state,
830 Operation *op,
831 LinalgOp linalgOp) {
832 IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
833 if (!indexOp)
835 auto loc = indexOp.getLoc();
836 // Compute the static loop sizes of the index op.
837 ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
838 auto dim = indexOp.getDim();
839 // Compute a one-dimensional index vector for the index op dimension.
840 auto indexVectorType =
841 VectorType::get({targetShape[dim]}, rewriter.getIndexType(),
842 state.getScalableVecDims()[dim]);
843 auto indexSteps = vector::StepOp::create(rewriter, loc, indexVectorType);
844 // Return the one-dimensional index vector if it lives in the trailing
845 // dimension of the iteration space since the vectorization algorithm in this
846 // case can handle the broadcast.
847 if (dim == targetShape.size() - 1)
849 // Otherwise permute the targetShape to move the index dimension last,
850 // broadcast the one-dimensional index vector to the permuted shape, and
851 // finally transpose the broadcasted index vector to undo the permutation.
852 auto permPattern =
853 llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
854 std::swap(permPattern[dim], permPattern.back());
855 auto permMap =
856 AffineMap::getPermutationMap(permPattern, linalgOp.getContext());
857
858 auto broadCastOp = vector::BroadcastOp::create(
859 rewriter, loc,
860 state.getCanonicalVecType(rewriter.getIndexType(), permMap), indexSteps);
861 SmallVector<int64_t> transposition =
862 llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
863 std::swap(transposition.back(), transposition[dim]);
864 auto transposeOp =
865 vector::TransposeOp::create(rewriter, loc, broadCastOp, transposition);
867}
868
869/// Helper function to check if the tensor.extract can be vectorized by the
870/// custom hook vectorizeTensorExtract.
871static LogicalResult
873 tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
874 if (!extractOp)
875 return failure();
876
877 if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
878 return failure();
879
880 // Check the index type, but only for non 0-d tensors (for which we do need
881 // access indices).
882 if (not extractOp.getIndices().empty()) {
883 if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
884 return failure();
885 }
886
887 if (!llvm::all_of(extractOp->getResultTypes(),
888 VectorType::isValidElementType)) {
889 return failure();
890 }
891
892 return success();
893}
894
895/// Calculates the offsets (`$index_vec`) for `vector.gather` operations
896/// generated from `tensor.extract`. The offset is calculated as follows
897/// (example using scalar values):
898///
899/// offset = extractOp.indices[0]
900/// for (i = 1; i < numIndices; i++)
901/// offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
902///
903/// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to:
904/// offset = ( ( 1 ) * 80 + 2 ) * 15 + 3
906 VectorizationState &state,
907 tensor::ExtractOp extractOp,
908 const IRMapping &bvm) {
909 // The vector of indices for GatherOp should be shaped as the output vector.
910 auto indexVecType = state.getCanonicalVecType(rewriter.getIndexType());
911 auto loc = extractOp.getLoc();
912
913 Value offset = broadcastIfNeeded(
914 rewriter, bvm.lookup(extractOp.getIndices()[0]), indexVecType);
915
916 const size_t numIndices = extractOp.getIndices().size();
917 for (size_t i = 1; i < numIndices; i++) {
918 Value dimIdx = arith::ConstantIndexOp::create(rewriter, loc, i);
919
920 auto dimSize = broadcastIfNeeded(
921 rewriter,
922 tensor::DimOp::create(rewriter, loc, extractOp.getTensor(), dimIdx),
923 indexVecType);
924
925 offset = arith::MulIOp::create(rewriter, loc, offset, dimSize);
926
927 auto extractOpIndex = broadcastIfNeeded(
928 rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType);
929
930 offset = arith::AddIOp::create(rewriter, loc, extractOpIndex, offset);
931 }
932
933 return offset;
934}
935
937
938/// Find the index of the trailing non-unit dim in linalgOp. This hook is used
939/// when checking whether `tensor.extract` Op (within a `linalg.generic` Op)
940/// represents a contiguous load operation.
941///
942/// Note that when calling this hook, it is assumed that the output vector is
943/// effectively 1D. Other cases (i.e. reading n-D vectors) should've been
944/// labelled as a gather load before entering this method.
945///
946/// Following on from the above, it is assumed that:
947/// * for statically shaped loops, when no masks are used, only one dim is !=
948/// 1 (that's what the shape of the output vector is based on).
949/// * for dynamically shaped loops, there might be more non-unit dims
950/// as the output vector type is user-specified.
951///
952/// TODO: Statically shaped loops + vector masking
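///
/// For example (illustrative), for static loop ranges (1, 1, 4, 1) this
/// returns 2.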
953static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp) {
954 SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges();
955 assert(
956 (linalgOp.hasDynamicShape() ||
957 llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) == 1) &&
958 "For statically shaped Linalg Ops, only one "
959 "non-unit loop dim is expected");
960 assert(!loopRanges.empty() && "Empty loops, nothing to analyse.");
961
962 size_t idx = loopRanges.size() - 1;
963 for (; idx != 0; idx--)
964 if (loopRanges[idx] != 1)
965 break;
966
967 return idx;
968}
969
970/// Checks whether `val` can be used for calculating a loop invariant index.
971static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
972 VectorType resType) {
973
974 assert(((llvm::count_if(resType.getShape(),
975 [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
976 "n-D vectors are not yet supported");
977
978 // Blocks outside _this_ linalg.generic are effectively loop invariant.
979 // However, analysing block arguments for _this_ linalg.generic Op is a bit
980 // tricky. Just bail out in the latter case.
981 // TODO: We could try analysing the corresponding affine map here.
982 auto *block = linalgOp.getBlock();
983 if (isa<BlockArgument>(val))
984 return !llvm::is_contained(block->getArguments(), val);
985
986 Operation *defOp = val.getDefiningOp();
987 assert(defOp && "This is neither a block argument nor an operation result");
988
989 // IndexOp is loop invariant as long as its result remains constant across
990 // iterations. Note that for dynamic shapes, the corresponding dim will also
991 // be conservatively treated as != 1.
992 if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
993 return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
994 }
995
996 auto *ancestor = block->findAncestorOpInBlock(*defOp);
997
998 // Values defined outside `linalgOp` are loop invariant.
999 if (!ancestor)
1000 return true;
1001
1002 // Values defined inside `linalgOp`, which are constant, are loop invariant.
1003 if (isa<arith::ConstantOp>(ancestor))
1004 return true;
1005
1006 bool result = true;
1007 for (auto op : ancestor->getOperands())
1008 result &= isLoopInvariantIdx(linalgOp, op, resType);
1009
1010 return result;
1011}
1012
1013/// Check whether `val` could be used for calculating the trailing index for a
1014/// contiguous load operation.
1015///
1016/// There are currently 3 types of values that are allowed here:
1017/// 1. loop-invariant values,
1018/// 2. values that increment by 1 with every loop iteration,
1019/// 3. results of basic arithmetic operations (linear and continuous)
1020/// involving 1., 2. and 3.
1021/// This method returns True if indeed only such values are used in calculating
1022/// `val`.
1023///
1024/// Additionally, the trailing index for a contiguous load operation should
1025/// increment by 1 with every loop iteration, i.e. be based on:
1026/// * `linalg.index <dim>` ,
1027/// where <dim> is the trailing non-unit dim of the iteration space (this way,
1028/// `linalg.index <dim>` increments by 1 with every loop iteration).
1029/// `foundIndexOp` is updated to `true` when such Op is found.
1030static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
1031 bool &foundIndexOp, VectorType resType) {
1032
1033 assert(((llvm::count_if(resType.getShape(),
1034 [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
1035 "n-D vectors are not yet supported");
1036
1037 // Blocks outside _this_ linalg.generic are effectively loop invariant.
1038 // However, analysing block arguments for _this_ linalg.generic Op is a bit
1039 // tricky. Just bail out in the latter case.
1040 // TODO: We could try analysing the corresponding affine map here.
1041 auto *block = linalgOp.getBlock();
1042 if (isa<BlockArgument>(val))
1043 return !llvm::is_contained(block->getArguments(), val);
1044
1045 Operation *defOp = val.getDefiningOp();
1046 assert(defOp && "This is neither a block argument nor an operation result");
1047
1048 if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
1049 auto loopDimThatIncrementsByOne = getTrailingNonUnitLoopDimIdx(linalgOp);
1050
1051 foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
1052 return true;
1053 }
1054
1055 auto *ancestor = block->findAncestorOpInBlock(*defOp);
1056
1057 if (!ancestor)
1058 return false;
1059
1060 // Conservatively reject Ops that could lead to indices with stride other
1061 // than 1.
1062 if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
1063 return false;
1064
1065 bool result = false;
1066 for (auto op : ancestor->getOperands())
1067 result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType);
1068
1069 return result;
1070}
1071
1072/// Infer the memory access pattern for the input ExtractOp
1073///
1074/// Based on the ExtractOp result shape and the access indices, decides whether
1075/// this Op corresponds to a contiguous load (including a broadcast of a scalar)
1076/// or a gather load. When analysing the ExtractOp indices (to identify
1077/// contiguous loads), this method looks for "loop" invariant indices (e.g.
1078/// block arguments) and indices that change linearly (e.g. via `linalg.index`
1079/// Op).
1080///
1081/// Note that it is always safe to use gather load operations for contiguous
1082/// loads (albeit slow), but not vice-versa. When in doubt, bail out and assume
1083/// that `extractOp` is a gather load.
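///
/// For example (an illustrative sketch; SSA names are made up), with %i being
/// `linalg.index` of the trailing non-unit loop dim and %j varying with
/// another loop dim:
///   %e0 = tensor.extract %src[%c0, %i] : tensor<3x32xf32>  // contiguous load
///   %e1 = tensor.extract %src[%c0, %c1] : tensor<3x32xf32> // scalar broadcast
///   %e2 = tensor.extract %src[%j, %i] : tensor<32x32xf32>  // gather load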
1085getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
1086 LinalgOp &linalgOp, VectorType resType) {
1087
1088 auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
1089
1090 // 0. Is this a 0-D vector? If yes then this is a scalar broadcast.
1091 if (inputShape.getShape().empty())
1093
1094 // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false
1095 // otherwise.
1096 bool isOutput1DVector =
1097 (llvm::count_if(resType.getShape(),
1098 [](int64_t dimSize) { return dimSize > 1; }) == 1);
1099 // 1. Assume that it's a gather load when reading non-1D vector.
1100 if (!isOutput1DVector)
1102
1103 bool leadingIdxsLoopInvariant = true;
1104
1105 // 2. Analyze the leading indices of `extractOp`.
1106 // Look at the way each index is calculated and decide whether it is suitable
1107 // for a contiguous load, i.e. whether it's loop invariant. If not, it's a
1108 // gather load.
1109 auto indices = extractOp.getIndices();
1110 auto leadIndices = indices.drop_back(1);
1111
1112 for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
1113 if (inputShape.getShape()[i] == 1)
1114 continue;
1115
1116 leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType);
1117 }
1118
1119 if (!leadingIdxsLoopInvariant) {
1120 LDBG() << "Found gather load: " << extractOp;
1122 }
1123
1124 // 3. Analyze the trailing index for `extractOp`.
1125 // At this point we know that the leading indices are loop invariant. This
1126 // means that this is potentially a scalar or a contiguous load. We can decide
1127 // based on the trailing idx.
1128 auto extractOpTrailingIdx = indices.back();
1129
1130 // 3a. Scalar broadcast load
1131 // If the trailing index is loop invariant then this is a scalar load.
1132 if (leadingIdxsLoopInvariant &&
1133 isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) {
1134 LDBG() << "Found scalar broadcast load: " << extractOp;
1135
1137 }
1138
1139 // 3b. Contiguous loads
1140 // The trailing `extractOp` index should increment with every loop iteration.
1141 // This effectively means that it must be based on the trailing loop index.
1142 // This is what the following bool captures.
1143 bool foundIndexOp = false;
1144 bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
1145 foundIndexOp, resType);
1146 // TODO: Support generating contiguous loads for column vectors - that will
1147 // require adding a permutation map to transfer_read Ops.
1148 bool isRowVector = resType.getShape().back() != 1;
1149 isContiguousLoad &= (foundIndexOp && isRowVector);
1150
1151 if (isContiguousLoad) {
1152 LDBG() << "Found contiguous load: " << extractOp;
1154 }
1155
1156 // 4. Fallback case - gather load.
1157 LDBG() << "Found gather load: " << extractOp;
1159}
1160
1161/// Helper function to vectorize the tensor.extract operations. Returns
1162/// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it
1163/// should map the produced operations. This function is meant to be used as a
1164/// CustomVectorizationHook.
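///
/// For example (an illustrative sketch; SSA names are made up), the gather
/// path produces roughly:
///   %g = vector.gather %src[%c0, %c0] [%offsets], %all_true, %pass_thru
///          : tensor<3x32xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32>
///            into vector<4xf32>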
1165static VectorizationHookResult
1166vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
1167 Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
1168 tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
1169 if (!extractOp)
1171 auto loc = extractOp.getLoc();
1172
1173 // Compute the static loop sizes of the extract op.
1174 auto resultType = state.getCanonicalVecType(extractOp.getResult().getType());
1175 auto maskConstantOp = arith::ConstantOp::create(
1176 rewriter, loc,
1177 DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()),
1178 /*value=*/true));
1179 auto passThruConstantOp = arith::ConstantOp::create(
1180 rewriter, loc, rewriter.getZeroAttr(resultType));
1181
1182 // Base indices are currently set to 0. We will need to re-visit if more
1183 // generic scenarios are to be supported.
1184 SmallVector<Value> baseIndices(
1185 extractOp.getIndices().size(),
1186 arith::ConstantIndexOp::create(rewriter, loc, 0));
1187
1188 VectorMemoryAccessKind memAccessKind =
1189 getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType);
1190
1191 // 1. Handle gather access
1192 if (memAccessKind == VectorMemoryAccessKind::Gather) {
1193 Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm);
1194
1195 // Generate the gather load
1196 Operation *gatherOp = vector::GatherOp::create(
1197 rewriter, loc, resultType, extractOp.getTensor(), baseIndices, offset,
1198 maskConstantOp, passThruConstantOp);
1199 gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
1200
1201 LDBG() << "Vectorised as gather load: " << extractOp;
1203 }
1204
1205 // 2. Handle:
1206 // a. scalar loads + broadcast,
1207 // b. contiguous loads.
1208 // Both cases use vector.transfer_read.
1209
1210 // Collect indices for `vector.transfer_read`. At this point, the indices will
1211 // either be scalars or would have been broadcast to vectors matching the
1212 // result type. For indices that are vectors, there are two options:
1213 // * for non-trailing indices, all elements are identical (contiguous
1214 // loads are identified by looking for non-trailing indices that are
1215 // invariant with respect to the corresponding linalg.generic), or
1216 // * for trailing indices, the index vector will contain values with stride
1217 // one, but for `vector.transfer_read` only the first (i.e. 0th) index is
1218 // needed.
1219 // This means that
1220 // * for scalar indices - just re-use it,
1221 // * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
1222 // (0th) element and use that.
1223 SmallVector<Value> transferReadIdxs;
1224 for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
1225 Value idx = bvm.lookup(extractOp.getIndices()[i]);
1226 if (idx.getType().isIndex()) {
1227 transferReadIdxs.push_back(idx);
1228 continue;
1229 }
1230
1231 auto indexAs1dVector = vector::ShapeCastOp::create(
1232 rewriter, loc,
1233 VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
1234 resultType.getScalableDims().back()),
1235 idx);
1236 transferReadIdxs.push_back(
1237 vector::ExtractOp::create(rewriter, loc, indexAs1dVector, 0));
1238 }
1239
1240 // `tensor.extract` is always in-bounds, hence the following holds.
1241 auto dstRank = resultType.getRank();
1242 auto srcRank = extractOp.getTensor().getType().getRank();
1243 SmallVector<bool> inBounds(dstRank, true);
1244
1245 // 2a. Handle scalar broadcast access.
1246 if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
1247 MLIRContext *ctx = rewriter.getContext();
1248 SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
1249 auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);
1250
1251 auto transferReadOp = vector::TransferReadOp::create(
1252 rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs,
1253 /*padding=*/std::nullopt, permutationMap, inBounds);
1254
1255 // Mask this broadcasting xfer_read here rather than relying on the generic
1256 // path (the generic path assumes identity masking map, which wouldn't be
1257 // valid here).
1258 SmallVector<int64_t> readMaskShape = {1};
1259 auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type());
1260 auto allTrue = vector::ConstantMaskOp::create(
1261 rewriter, loc, readMaskType, vector::ConstantMaskKind::AllTrue);
1262 auto *maskedReadOp =
1263 mlir::vector::maskOperation(rewriter, transferReadOp, allTrue);
1264
1265 LDBG() << "Vectorised as scalar broadcast load: " << extractOp;
1267 maskedReadOp};
1268 }
1269
1270 // 2b. Handle contiguous access.
1271 auto permutationMap = AffineMap::getMinorIdentityMap(
1272 srcRank, std::min(dstRank, srcRank), rewriter.getContext());
1273
1274 int32_t rankDiff = dstRank - srcRank;
1275 // When dstRank > srcRank, broadcast the source tensor to the unitary leading
1276 // dims so that the ranks match. This is done by extending the map with 0s.
1277 // For example, for dstRank = 3, srcRank = 2, the following map created
1278 // above:
1279 // (d0, d1) --> (d0, d1)
1280 // is extended as:
1281 // (d0, d1) --> (0, d0, d1)
1282 while (rankDiff > 0) {
1283 permutationMap = permutationMap.insertResult(
1284 mlir::getAffineConstantExpr(0, rewriter.getContext()), 0);
1285 rankDiff--;
1286 }
1287
1288 auto transferReadOp = vector::TransferReadOp::create(
1289 rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs,
1290 /*padding=*/std::nullopt, permutationMap, inBounds);
1291
1292 LDBG() << "Vectorised as contiguous load: " << extractOp;
1294 transferReadOp};
1295}
1296
1297/// Emit reduction operations if the shape of the value to reduce is different
1298/// from the result shape.
1299// Note: this is a true builder that notifies the OpBuilder listener.
1300// TODO: Consider moving as a static helper on the ReduceOp.
1301static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
1302 Value reduceValue, Value initialValue,
1303 const IRMapping &bvm) {
1304 Value reduceVec = bvm.lookup(reduceValue);
1305 Value outputVec = bvm.lookup(initialValue);
1306 auto reduceType = dyn_cast<VectorType>(reduceVec.getType());
1307 auto outputType = dyn_cast<VectorType>(outputVec.getType());
1308 // Reduce only if needed as the value may already have been reduced for
1309 // contraction vectorization.
1310 if (!reduceType ||
1311 (outputType && reduceType.getShape() == outputType.getShape()))
1312 return nullptr;
1313 SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
1314 return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
1315}
1316
1317/// Generic vectorization for a single operation `op`, given already vectorized
1318/// operands carried by `bvm`. Vectorization occurs as follows:
1319/// 1. Try to apply any of the `customVectorizationHooks` and return its
1320/// result on success.
1321/// 2. Clone any constant in the current scope without vectorization: each
1322/// consumer of the constant will later determine the shape to which the
1323/// constant needs to be broadcast.
1324/// 3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose
1325/// of the `customVectorizationHooks` to cover such cases.
1326/// 4. Clone `op` in vector form to a vector of shape prescribed by the first
1327/// operand of maximal rank. Other operands have smaller rank and are
1328/// broadcast accordingly. It is assumed this broadcast is always legal,
1329/// otherwise, it means one of the `customVectorizationHooks` is incorrect.
1330///
1331/// This function assumes all operands of `op` have been vectorized and are in
1332/// the `bvm` mapping. As a consequence, this function is meant to be called on
1333/// a topologically-sorted list of ops.
1334/// This function does not update `bvm` but returns a VectorizationHookStatus
1335/// that instructs the caller what `bvm` update needs to occur.
1336static VectorizationHookResult
1337vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
1338 LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
1339 ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
1340 LDBG() << "vectorize op " << *op;
1341
1342 // 1. Try to apply any CustomVectorizationHook.
1343 if (!customVectorizationHooks.empty()) {
1344 for (auto &customFunc : customVectorizationHooks) {
1345 VectorizationHookResult result = customFunc(op, bvm);
1347 continue;
1348 return result;
1349 }
1350 }
1351
1352 // 2. Constant ops don't get vectorized but rather broadcasted at their users.
1353 // Clone so that the constant is not confined to the linalgOp block.
1354 if (isa<arith::ConstantOp, func::ConstantOp>(op))
1356 rewriter.clone(*op)};
1357
1358 // 3. Only ElementwiseMappable ops are allowed in the generic vectorization.
1361
1362 // 4. Check if the operation is a reduction.
1363 SmallVector<std::pair<Value, Value>> reductionOperands;
1364 for (Value operand : op->getOperands()) {
1365 auto blockArg = dyn_cast<BlockArgument>(operand);
1366 if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
1367 blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
1368 continue;
1369 SmallVector<Operation *> reductionOps;
1370 Value reduceValue = matchReduction(
1371 linalgOp.getRegionOutputArgs(),
1372 blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
1373 if (!reduceValue)
1374 continue;
1375 reductionOperands.push_back(std::make_pair(reduceValue, operand));
1376 }
1377 if (!reductionOperands.empty()) {
1378 assert(reductionOperands.size() == 1);
1379 Operation *reduceOp =
1380 reduceIfNeeded(rewriter, linalgOp, op, reductionOperands[0].first,
1381 reductionOperands[0].second, bvm);
1382 if (reduceOp)
1384 }
1385
1386 // 5. Generic vectorization path for ElementwiseMappable ops.
1387 // a. Get the first max ranked shape.
1388 VectorType firstMaxRankedType;
1389 for (Value operand : op->getOperands()) {
1390 auto vecOperand = bvm.lookup(operand);
1391 assert(vecOperand && "Vector operand couldn't be found");
1392
1393 auto vecType = dyn_cast<VectorType>(vecOperand.getType());
1394 if (vecType && (!firstMaxRankedType ||
1395 firstMaxRankedType.getRank() < vecType.getRank()))
1396 firstMaxRankedType = vecType;
1397 }
1398 // b. Broadcast each op if needed.
1399 SmallVector<Value> vecOperands;
1400 for (Value scalarOperand : op->getOperands()) {
1401 Value vecOperand = bvm.lookup(scalarOperand);
1402 assert(vecOperand && "Vector operand couldn't be found");
1403
1404 if (firstMaxRankedType) {
1405 auto vecType = VectorType::get(firstMaxRankedType.getShape(),
1406 getElementTypeOrSelf(vecOperand.getType()),
1407 firstMaxRankedType.getScalableDims());
1408 vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
1409 } else {
1410 vecOperands.push_back(vecOperand);
1411 }
1412 }
1413 // c. For elementwise ops, the result is a vector with the firstMaxRankedType shape.
1414 SmallVector<Type> resultTypes;
1415 for (Type resultType : op->getResultTypes()) {
1416 resultTypes.push_back(
1417 firstMaxRankedType
1418 ? VectorType::get(firstMaxRankedType.getShape(), resultType,
1419 firstMaxRankedType.getScalableDims())
1420 : resultType);
1421 }
1422 // d. Build and return the new op.
1423 return VectorizationHookResult{
1424 VectorizationHookStatus::NewOp,
1425 rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
1426 resultTypes, op->getAttrs())};
1427}
1428
1429/// Generic vectorization function that rewrites the body of a `linalgOp` into
1430/// vector form. Generic vectorization proceeds as follows:
1431/// 1. Verify the `linalgOp` has one non-empty region.
1432/// 2. Values defined above the region are mapped to themselves and will be
1433/// broadcasted on a per-need basis by their consumers.
1434/// 3. Each region argument is vectorized into a vector.transfer_read (or 0-d
1435/// load).
1436/// TODO: Reuse opportunities for RAR dependencies.
1437/// 4a. Register CustomVectorizationHook for YieldOp to capture the results.
1438/// 4b. Register CustomVectorizationHook for IndexOp to access the
1439/// iteration indices.
1440/// 5. Iteratively call vectorizeOneOp on the region operations.
1441///
1442/// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is
1443/// performed to the maximal common vector size implied by the `linalgOp`
1444/// iteration space. This eager broadcasting is introduced in the
1445/// permutation_map of the vector.transfer_read operations. The eager
1446/// broadcasting makes it trivial to determine where broadcast, transposes and
1447/// reductions should occur, without any bookkeeping. The tradeoff is that, in
1448/// the absence of good canonicalizations, the amount of work increases.
1449/// This is not deemed a problem as we expect canonicalizations and foldings to
1450/// aggressively clean up the useless work.
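///
/// For example (an illustrative sketch; the op, shapes and names are made up),
/// a simple elementwise generic such as:
/// ```
///   #map = affine_map<(d0) -> (d0)>
///   %res = linalg.generic
///     {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
///     ins(%src : tensor<8xf32>) outs(%init : tensor<8xf32>) {
///   ^bb0(%in: f32, %out: f32):
///     %neg = arith.negf %in : f32
///     linalg.yield %neg : f32
///   } -> tensor<8xf32>
/// ```
/// is vectorized, roughly (modulo canonicalization), as:
/// ```
///   %read = vector.transfer_read %src[%c0], %pad {in_bounds = [true]}
///     : tensor<8xf32>, vector<8xf32>
///   %neg = arith.negf %read : vector<8xf32>
///   %res = vector.transfer_write %neg, %init[%c0] {in_bounds = [true]}
///     : vector<8xf32>, tensor<8xf32>
/// ```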
1451static LogicalResult
1452vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
1453 LinalgOp linalgOp,
1454 SmallVectorImpl<Value> &newResults) {
1455 LDBG() << "Vectorizing operation as linalg generic";
1456 Block *block = linalgOp.getBlock();
1457
1458 // 2. Values defined above the region can only be broadcast for now. Make them
1459 // map to themselves.
1460 IRMapping bvm;
1461 SetVector<Value> valuesSet;
1462 mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet);
1463 bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef());
1464
1465 if (linalgOp.getNumDpsInits() == 0)
1466 return failure();
1467
1468 // 3. Turn all BBArgs into vector.transfer_read / load.
1469 Location loc = linalgOp.getLoc();
1470 Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
1471 for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1472 BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand);
1473 if (linalgOp.isScalar(opOperand)) {
1474 bvm.map(bbarg, opOperand->get());
1475 continue;
1476 }
1477
1478 // 3.a. Convert the indexing map for this input/output to a transfer read
1479 // permutation map and masking map.
1480 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
1481
1482 AffineMap readMap;
1483 VectorType readType;
1484 Type elemType = getElementTypeOrSelf(opOperand->get());
1485 if (linalgOp.isDpsInput(opOperand)) {
1486 // 3.a.i. For input reads we use the canonical vector shape.
1487 readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
1488 readType = state.getCanonicalVecType(elemType);
1489 } else {
1490 // 3.a.ii. For output reads (iteration-carried dependence, e.g.,
1491 // reductions), the vector shape is computed by mapping the canonical
1492 // vector shape to the output domain and back to the canonical domain.
1493 readMap = inversePermutation(reindexIndexingMap(indexingMap));
1494 readType =
1495 state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
1496 }
1497
1498 SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
1499
1500 Operation *read = vector::TransferReadOp::create(
1501 rewriter, loc, readType, opOperand->get(), indices,
1502 /*padding=*/std::nullopt, readMap);
1503 read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
1504 Value readValue = read->getResult(0);
1505
1506 // 3.b. If masked, set in-bounds to true. Masking guarantees that the access
1507 // will be in-bounds.
1508 if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
1509 SmallVector<bool> inBounds(readType.getRank(), true);
1510 cast<vector::TransferReadOp>(maskOp.getMaskableOp())
1511 .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
1512 }
1513
1514 // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
1515 // TODO: remove this.
1516 if (readType.getRank() == 0)
1517 readValue = vector::ExtractOp::create(rewriter, loc, readValue,
1518 ArrayRef<int64_t>());
1519
1520 LDBG() << "New vectorized bbarg(" << bbarg.getArgNumber()
1521 << "): " << readValue;
1522 bvm.map(bbarg, readValue);
1523 bvm.map(opOperand->get(), readValue);
1524 }
1525
1526 SmallVector<CustomVectorizationHook> hooks;
1527 // 4a. Register CustomVectorizationHook for yieldOp.
1528 CustomVectorizationHook vectorizeYield =
1529 [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1530 return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
1531 };
1532 hooks.push_back(vectorizeYield);
1533
1534 // 4b. Register CustomVectorizationHook for indexOp.
1535 CustomVectorizationHook vectorizeIndex =
1536 [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1537 return vectorizeLinalgIndex(rewriter, state, op, linalgOp);
1538 };
1539 hooks.push_back(vectorizeIndex);
1540
1541 // 4c. Register CustomVectorizationHook for extractOp.
1542 CustomVectorizationHook vectorizeExtract =
1543 [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
1544 return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm);
1545 };
1546 hooks.push_back(vectorizeExtract);
1547
1548 // 5. Iteratively call `vectorizeOneOp` to each op in the slice.
1549 for (Operation &op : block->getOperations()) {
1550 VectorizationHookResult result =
1551 vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks);
1552 if (result.status == VectorizationHookStatus::Failure) {
1553 LDBG() << "failed to vectorize: " << op;
1554 return failure();
1555 }
1556 if (result.status == VectorizationHookStatus::NewOp) {
1557 Operation *maybeMaskedOp =
1558 state.maskOperation(rewriter, result.newOp, linalgOp);
1559 LDBG() << "New vector op: " << *maybeMaskedOp;
1560 bvm.map(op.getResults(), maybeMaskedOp->getResults());
1561 }
1562 }
1563
1564 return success();
1565}
1566
1567/// Determines whether a mask for xfer_write is trivially "all true"
1568///
1569/// Given all the inputs required to generate a mask (mask sizes and shapes),
1570/// and an xfer_write operation (write indices and the destination tensor
1571/// shape), determines whether the corresponding mask would be trivially
1572/// foldable (i.e., trivially "all true").
1573///
1574/// Use this method to avoid generating spurious masks and relying on
1575/// vectorization post-processing to remove them.
1576///
1577/// Pre-conditions for a mask to be trivially foldable:
1578/// * All involved shapes (mask + destination tensor) are static.
1579/// * All write indices are constant.
1580/// * All mask sizes are constant (including `arith.constant`).
1581///
1582/// If the pre-conditions are met, the method checks for each destination
1583/// dimension `d`:
1584/// (1) maskShape[d] <= destDimSize[rankDiff + d]
1585/// (2) writeIndex[d] + maskSize[d] <= destDimSize[rankDiff + d]
1586///
1587/// rankDiff = rank(dest) - rank(mask).
1588///
1589/// This method takes a conservative view: it may return false even if the mask
1590/// is technically foldable.
1591///
1592/// EXAMPLE 1 (trivially foldable, all shapes match, mask sizes match the shape
1593/// of the dest tensor):
1594/// %c0 = arith.constant 0 : index
1595/// %mask = vector.create_mask 5, 1
1596/// vector.mask %mask {
1597/// vector.transfer_write %vecToStore_1, %dest[%c0, %c0]
1598/// {in_bounds = [true, true]}
1599/// : vector<5x1xi32>, tensor<5x1xi32>
1600/// }
1601///
1602/// EXAMPLE 2 (not trivially foldable - vector shape exceeds the tensor shape,
1603/// mask is required to avoid out-of-bounds write):
1604/// %c0 = arith.constant 0 : index
1605/// %mask = vector.create_mask 5, 1
1606/// vector.mask %mask {
1607/// vector.transfer_write %vecToStore_2, %dest[%c0, %c0]
1608/// {in_bounds = [true, true]}
1609/// : vector<8x1xi32>, tensor<5x1xi32>
1610/// }
1611///
1612/// TODO: Re-use in createReadOrMaskedRead
1613static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
1614 SmallVector<Value> &writeIdxs,
1615 ArrayRef<int64_t> destShape,
1616 ArrayRef<int64_t> maskShape) {
1617 // Masking is unavoidable in the case of dynamic tensors.
1618 if (ShapedType::isDynamicShape(destShape))
1619 return false;
1620
1621 // Collect all constant mask sizes.
1622 SmallVector<int64_t, 4> cstMaskSizes;
1623 for (auto [i, dimSize] : llvm::enumerate(maskSizes)) {
1624 if (auto intSize = getConstantIntValue(dimSize)) {
1625 cstMaskSizes.push_back(*intSize);
1626 }
1627 }
1628
1629 // If any of the mask sizes is non-constant, bail out.
1630 if (cstMaskSizes.size() != maskShape.size())
1631 return false;
1632
1633 // Collect all constant write indices.
1634 SmallVector<int64_t, 4> cstWriteIdxs;
1635 for (auto [i, idx] : llvm::enumerate(writeIdxs)) {
1636 APSInt intVal;
1637 if (matchPattern(idx, m_ConstantInt(&intVal))) {
1638 cstWriteIdxs.push_back(intVal.getSExtValue());
1639 }
1640 }
1641
1642 // If any of the write indices is non-constant, bail out.
1643 if (cstWriteIdxs.size() != destShape.size())
1644 return false;
1645
1646 // Go over all destination dims and check (1) and (2). Take into account that:
1647 // * The number of mask sizes will match the rank of the vector to store.
1648 // This could be lower than the rank of the destination tensor.
1649 // * Mask sizes could be larger than the corresponding mask shape (hence
1650 // `clamp`).
1651 // TODO: The 2nd item should be rejected by the verifier.
1652 int64_t rankDiff = destShape.size() - cstMaskSizes.size();
1653 for (auto [i, idx] : llvm::enumerate(cstMaskSizes)) {
1654 if (/*(1)*/ maskShape[i] > destShape[rankDiff + i] ||
1655 /*(2)*/ destShape[rankDiff + i] <
1656 (std::clamp(cstMaskSizes[i], int64_t(0), maskShape[i]) +
1657 cstWriteIdxs[i]))
1658 return false;
1659 }
1660
1661 return true;
1662}
1663
1664/// Creates an optionally masked TransferWriteOp
1665///
1666/// Generates the following operation:
1667/// %res = vector.transfer_write %vecToStore into %dest
1668///
1669/// If shape(vecToStore) != shape(dest), masking is used to ensure correctness:
1670///
1671/// %mask = vector.create_mask(%destShape) : %vecToStoreShape
1672/// %res = vector.mask %mask {
1673/// vector.transfer_write %vecToStore into %dest
1674/// }
1675///
1676/// The mask shape is identical to `vecToStore` (with the element type ==
1677/// i1), and the mask values are based on the shape of the `dest` tensor.
1678///
1679/// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
1680/// is used instead of masking:
1681///
1683/// in_bounds_flags = (...)
1684/// %res = vector.transfer_write %vecToStore into %dest
1685/// {in_bounds = in_bounds_flags}
1686///
1687/// Finally, `writeIndices` specifies the offsets to use. If empty, all indices
1688/// are set to 0.
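///
/// For example (an illustrative sketch; shapes and value names are made up),
/// storing a vector<4x8xf32> into a dynamically shaped tensor generates:
///
///   %c0 = arith.constant 0 : index
///   %c1 = arith.constant 1 : index
///   %d0 = tensor.dim %dest, %c0 : tensor<?x?xf32>
///   %d1 = tensor.dim %dest, %c1 : tensor<?x?xf32>
///   %mask = vector.create_mask %d0, %d1 : vector<4x8xi1>
///   %res = vector.mask %mask {
///     vector.transfer_write %vecToStore, %dest[%c0, %c0]
///       {in_bounds = [true, true]} : vector<4x8xf32>, tensor<?x?xf32>
///   } : vector<4x8xi1> -> tensor<?x?xf32>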
1689static Operation *
1690createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
1691 Value dest, SmallVector<Value> writeIndices = {},
1692 bool useInBoundsInsteadOfMasking = false) {
1693
1694 ShapedType destType = cast<ShapedType>(dest.getType());
1695 int64_t destRank = destType.getRank();
1696 auto destShape = destType.getShape();
1697
1698 VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
1699 int64_t vecToStoreRank = vecToStoreType.getRank();
1700 auto vecToStoreShape = vecToStoreType.getShape();
1701
1702 // Compute the in_bounds attribute
1703 SmallVector<bool> inBoundsVal(vecToStoreRank, true);
1704 if (useInBoundsInsteadOfMasking) {
1705 // Update the inBounds attribute.
1706 // FIXME: This computation is too weak - it ignores the write indices.
1707 for (unsigned i = 0; i < vecToStoreRank; i++)
1708 inBoundsVal[i] =
1709 (destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) &&
1710 ShapedType::isStatic(destShape[destRank - vecToStoreRank + i]);
1711 }
1712
1713 // If missing, initialize the write indices to 0.
1714 assert((writeIndices.empty() ||
1715 writeIndices.size() == static_cast<size_t>(destRank)) &&
1716 "Invalid number of write indices!");
1717 if (writeIndices.empty()) {
1718 auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
1719 writeIndices.assign(destRank, zero);
1720 }
1721
1722 // Generate the xfer_write Op
1723 Operation *write = vector::TransferWriteOp::create(builder, loc,
1724 /*vector=*/vecToStore,
1725 /*source=*/dest,
1726 /*indices=*/writeIndices,
1727 /*inBounds=*/inBoundsVal);
1728
1729 // If masking is disabled, exit.
1730 if (useInBoundsInsteadOfMasking)
1731 return write;
1732
1733 // Check if masking is needed. If not, exit.
1734 if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
1735 return write;
1736
1737 // Compute the mask and mask the write Op.
1738 auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type(),
1739 vecToStoreType.getScalableDims());
1740
1741 SmallVector<OpFoldResult> destSizes =
1742 isa<MemRefType>(dest.getType())
1743 ? memref::getMixedSizes(builder, loc, dest)
1744 : tensor::getMixedSizes(builder, loc, dest);
1745 SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
1746 destSizes.end());
1747
1748 if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
1749 vecToStoreShape))
1750 return write;
1751
1752 Value maskForWrite =
1753 builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
1754 return mlir::vector::maskOperation(builder, write, maskForWrite);
1755}
1756
1757/// Given the re-associations, "collapses" the input Vector type
1758///
1759/// This is similar to CollapseShapeOp::inferCollapsedType with two notable
1760/// differences:
1761/// * We can safely assume that there are no dynamic sizes.
1762/// * Scalable flags are updated alongside regular dims.
1763///
1764/// When collapsing scalable flags, conservatively avoids cases with two
1765/// scalable dims. We could re-visit this in the future.
1766///
1767/// EXAMPLE:
1768/// type = vector<4x16x[8]x16xf32>
1769/// reassociation = [(d0, d1, d2, d3) -> (d0, d1),
1770/// (d0, d1, d2, d3) -> (d2, d3)]
1771/// Result:
1772/// vector<64x[128]xf32>
1773static VectorType getCollapsedVecType(VectorType type,
1774 ArrayRef<AffineMap> reassociation) {
1775 assert(type.getNumScalableDims() < 2 &&
1776 "Collapsing more than 1 scalable dim is not supported ATM");
1777
1778 // Use the fact that reassociation is valid to simplify the logic: only use
1779 // each map's rank.
1780 assert(isReassociationValid(reassociation) && "invalid reassociation");
1781
1782 auto shape = type.getShape();
1783 auto scalableFlags = type.getScalableDims();
1784 SmallVector<int64_t> newShape;
1785 SmallVector<bool> newScalableFlags;
1786
1787 unsigned currentDim = 0;
1788 for (AffineMap m : reassociation) {
1789 unsigned dim = m.getNumResults();
1790 int64_t size = 1;
1791 bool flag = false;
1792 for (unsigned d = 0; d < dim; ++d) {
1793 size *= shape[currentDim + d];
1794 flag |= scalableFlags[currentDim + d];
1795 }
1796 newShape.push_back(size);
1797 newScalableFlags.push_back(flag);
1798 currentDim += dim;
1799 }
1800
1801 return VectorType::get(newShape, type.getElementType(), newScalableFlags);
1802}
1803
1804/// Vectorize `linalg.pack` as:
1805/// * xfer_read -> shape_cast -> transpose -> xfer_write
1806///
1807/// The input-vector-sizes specify the _write_ vector sizes (i.e. the vector
1808/// sizes for the xfer_write operation). This is sufficient to infer the other
1809/// vector sizes required here.
1810///
1811/// If the vector sizes are not provided:
1812/// * the vector sizes are determined from the destination tensor static shape.
1813/// * the inBounds attribute is used instead of masking.
1814///
1815/// EXAMPLE (no vector sizes):
1816/// ```
1817/// %pack = linalg.pack %src
1818/// inner_dims_pos = [2, 1]
1819/// inner_tiles = [16, 2]
1820/// into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
1821/// ```
1822/// is vectorized as:
1823/// ```
1824/// %read = vector.transfer_read %src
1825/// : tensor<32x8x16xf32>, vector<32x8x16xf32>
1826/// %sc = vector.shape_cast %read
1827/// : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
1828/// %tr = vector.transpose %sc, [0, 1, 3, 4, 2]
1829/// : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
1830/// %write = vector.transfer_write %tr into %dest
1831/// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
1832/// ```
1833static LogicalResult
1834vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
1835 ArrayRef<int64_t> inputVectorSizes,
1836 SmallVectorImpl<Value> &newResults) {
1837 if (!inputVectorSizes.empty()) {
1838 assert(inputVectorSizes.size() == packOp.getDestRank() &&
1839 "Invalid number of input vector sizes!");
1840 }
1841
1842 // TODO: Introduce a parent class that will handle the insertion point update.
1843 OpBuilder::InsertionGuard g(rewriter);
1844 rewriter.setInsertionPoint(packOp);
1845
1846 Location loc = packOp.getLoc();
1847 std::optional<Value> padValue = packOp.getPaddingValue()
1848 ? std::optional(packOp.getPaddingValue())
1849 : std::nullopt;
1850
1851 SmallVector<int64_t> destShape =
1852 SmallVector<int64_t>(packOp.getDestType().getShape());
1853
1854 // This is just a convenience alias to clearly communicate that the input
1855 // vector sizes determine the _write_ sizes.
1856 ArrayRef<int64_t> &writeVectorSizes = inputVectorSizes;
1857
1858 // In the absence of input-vector-sizes, use the _static_ dest tensor shape.
1859 // In addition, use the inBounds attribute instead of masking.
1860 bool useInBoundsInsteadOfMasking = false;
1861 if (writeVectorSizes.empty()) {
1862 if (ShapedType::isDynamicShape(destShape))
1863 return rewriter.notifyMatchFailure(packOp,
1864 "unable to infer vector sizes");
1865
1866 writeVectorSizes = destShape;
1867 useInBoundsInsteadOfMasking = true;
1868 }
1869
1870 // Compute pre-transpose-write-vector-type, i.e. the write vector type
1871 // _before_ the transposition (i.e. before dimension permutation). This is
1872 // done by inverting the permutation/transposition that's part of the Pack
1873 // operation. This type is required to:
1874 // 1) compute the read vector type for masked-read below, and
1875 // 2) generate shape-cast Op below that expands the read vector type.
1876 PackingMetadata packMetadata;
1877 SmallVector<int64_t> preTransposeWriteVecSizes(writeVectorSizes);
1878 auto destInvPermutation = getPackInverseDestPerm(packOp, packMetadata);
1879 applyPermutationToVector(preTransposeWriteVecSizes, destInvPermutation);
1880 auto preTransposeWriteVecType = VectorType::get(
1881 preTransposeWriteVecSizes, packOp.getType().getElementType());
1882
1883 // Compute vector type for the _read_ operation. This is simply
1884 // pre-transpose-write-vector-type with the dimensions collapsed
1885 // as per the Pack operation.
1886 VectorType readVecType = getCollapsedVecType(
1887 preTransposeWriteVecType,
1888 getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1889 rewriter.getContext(), packMetadata.reassociations)));
1890
1891 // Create masked TransferReadOp.
1892 auto maskedRead = vector::createReadOrMaskedRead(
1893 rewriter, loc, packOp.getSource(), readVecType, padValue,
1894 useInBoundsInsteadOfMasking);
1895
1896 // Create ShapeCastOp.
1897 auto shapeCastOp = vector::ShapeCastOp::create(
1898 rewriter, loc, preTransposeWriteVecType, maskedRead);
1899
1900 // Create TransposeOp.
1901 auto destPermutation = invertPermutationVector(destInvPermutation);
1902 auto transposeOp = vector::TransposeOp::create(
1903 rewriter, loc, shapeCastOp.getResult(), destPermutation);
1904
1905 // Create TransferWriteOp.
1906 Operation *write = createWriteOrMaskedWrite(
1907 rewriter, loc, transposeOp.getResult(), packOp.getDest());
1908 newResults.push_back(write->getResult(0));
1909 return success();
1910}
1911
1912/// Vectorize `linalg.unpack` as:
1913/// * xfer_read -> vector.transpose -> vector.shape_cast -> xfer_write
1914///
1915/// The input-vector-sizes specify the _read_ vector sizes (i.e. the vector
1916/// sizes for the xfer_read operation). This is sufficient to infer the other
1917/// vector sizes required here.
1918///
1919/// If the vector sizes are not provided:
1920/// * the vector sizes are determined from the input tensor static shape.
1921/// * the inBounds attribute is used instead of masking.
1922///
1923/// EXAMPLE (no vector sizes):
1924/// ```
1925/// %unpack = linalg.unpack %src
1926/// inner_dims_pos = [0, 1]
1927/// inner_tiles = [8, 8]
1928/// into %dest : tensor<1x1x8x8xf32> -> tensor<8x8xf32>
1929/// ```
1930/// is vectorized as:
1931/// ```
1932/// %read = vector.transfer_read %src
1933/// : tensor<1x1x8x8xf32>, vector<1x1x8x8xf32>
1934/// %tr = vector.transpose %read, [0, 2, 1, 3]
1935/// : vector<1x1x8x8xf32> to vector<1x8x1x8xf32>
1936/// %sc = vector.shape_cast %tr
1937/// : vector<1x8x1x8xf32> to vector<8x8xf32>
1938/// %vector = vector.transfer_write %sc into %dest
1939/// : vector<8x8xf32>, tensor<8x8xf32>
1940/// ```
1941static LogicalResult
1942vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
1943 ArrayRef<int64_t> inputVectorSizes,
1944 ArrayRef<bool> inputScalableVecDims,
1945 SmallVectorImpl<Value> &newResults) {
1946 if (!inputVectorSizes.empty()) {
1947 assert(inputVectorSizes.size() == unpackOp.getSourceRank() &&
1948 "Invalid number of input vector sizes!");
1949 assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
1950 "Incompatible number of vector sizes and vector scalable flags!");
1951 }
1952
1953 // TODO: Introduce a parent class that will handle the insertion point update.
1954 OpBuilder::InsertionGuard g(rewriter);
1955 rewriter.setInsertionPoint(unpackOp);
1956
1957 RankedTensorType unpackTensorType = unpackOp.getSourceType();
1958
1959 ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
1960 bool useInBoundsInsteadOfMasking = false;
1961
1962 Location loc = unpackOp->getLoc();
1963
1964 // Obtain vector sizes for the read operation.
1965 SmallVector<int64_t> readVectorSizes(inputVectorSizes);
1966 SmallVector<bool> readScalableVectorFlags(inputScalableVecDims);
1967
1968 // In the absence of input-vector-sizes, use the _static_ input tensor shape.
1969 if (inputVectorSizes.empty()) {
1970 if (ShapedType::isDynamicShape(sourceShape))
1971 return rewriter.notifyMatchFailure(unpackOp,
1972 "Unable to infer vector sizes!");
1973
1974 readVectorSizes.assign(sourceShape.begin(), sourceShape.end());
1975 useInBoundsInsteadOfMasking = true;
1976 }
1977
1978 // -- Generate the read operation --
1979 VectorType readVecType =
1980 VectorType::get(readVectorSizes, unpackTensorType.getElementType(),
1981 readScalableVectorFlags);
1982 Value readResult = vector::createReadOrMaskedRead(
1983 rewriter, loc, unpackOp.getSource(), readVecType, std::nullopt,
1984 useInBoundsInsteadOfMasking);
1985
1986 // -- Generate the transpose operation --
1987 PackingMetadata packMetadata;
1988 SmallVector<int64_t> lastDimToInsertPosPerm =
1989 getUnPackInverseSrcPerm(unpackOp, packMetadata);
1990 vector::TransposeOp transposeOp = vector::TransposeOp::create(
1991 rewriter, loc, readResult, lastDimToInsertPosPerm);
1992
1993 // -- Generate the shape_cast operation --
1994 VectorType collapsedVecType = getCollapsedVecType(
1995 transposeOp.getType(),
1996 getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1997 rewriter.getContext(), packMetadata.reassociations)));
1998 vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create(
1999 rewriter, loc, collapsedVecType, transposeOp->getResult(0));
2000
2001 // -- Generate the write operation --
2002 Operation *write = createWriteOrMaskedWrite(
2003 rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
2004 /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
2005
2006 newResults.push_back(write->getResult(0));
2007 return success();
2008}
2009
2010/// Vectorize a `padOp` with (1) static result type, (2) constant padding value
2011/// and (3) all-zero lowPad to
2012/// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
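///
/// For example (an illustrative sketch; shapes and the vector sizes [8, 16]
/// are made up), a pad such as:
///   %pad = tensor.pad %src low[0, 0] high[1, 2] {
///   ^bb0(%i: index, %j: index):
///     tensor.yield %cst : f32
///   } : tensor<7x14xf32> to tensor<8x16xf32>
/// is vectorized, roughly, as:
///   %init = tensor.empty() : tensor<8x16xf32>
///   %mask = vector.create_mask %c7, %c14 : vector<8x16xi1>
///   %read = vector.mask %mask {
///     vector.transfer_read %src[%c0, %c0], %cst
///       : tensor<7x14xf32>, vector<8x16xf32>
///   } : vector<8x16xi1> -> vector<8x16xf32>
///   %res = vector.transfer_write %read, %init[%c0, %c0]
///     {in_bounds = [true, true]} : vector<8x16xf32>, tensor<8x16xf32>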
2013static LogicalResult
2014vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
2015 ArrayRef<int64_t> inputVectorSizes,
2016 SmallVectorImpl<Value> &newResults) {
2017 auto padValue = padOp.getConstantPaddingValue();
2018 Location loc = padOp.getLoc();
2019
2020 // TODO: Introduce a parent class that will handle the insertion point update.
2021 OpBuilder::InsertionGuard g(rewriter);
2022 rewriter.setInsertionPoint(padOp);
2023
2024 ReifiedRankedShapedTypeDims reifiedReturnShapes;
2025 LogicalResult status =
2026 cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation())
2027 .reifyResultShapes(rewriter, reifiedReturnShapes);
2028 (void)status; // prevent unused variable warning on non-assert builds
2029 assert(succeeded(status) && "failed to reify result shapes");
2030 auto readType = VectorType::get(inputVectorSizes, padValue.getType());
2031 auto maskedRead = vector::createReadOrMaskedRead(
2032 rewriter, loc, padOp.getSource(), readType, padValue,
2033 /*useInBoundsInsteadOfMasking=*/false);
2034
2035 // Create Xfer write Op
2036 Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0],
2037 padOp.getResultType().getElementType());
2038 Operation *write = createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest);
2039 newResults.push_back(write->getResult(0));
2040 return success();
2041}
2042
2043// TODO: probably need some extra checks for reduction followed by consumer
2044// ops that may not commute (e.g. linear reduction + non-linear instructions).
2045static LogicalResult reductionPreconditions(LinalgOp op) {
2046 if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
2047 LDBG() << "reduction precondition failed: no reduction iterator";
2048 return failure();
2049 }
2050 for (OpOperand &opOperand : op.getDpsInitsMutable()) {
2051 AffineMap indexingMap = op.getMatchingIndexingMap(&opOperand);
2052 if (indexingMap.isPermutation())
2053 continue;
2054
2055 Operation *reduceOp = matchLinalgReduction(&opOperand);
2056 if (!reduceOp || !getCombinerOpKind(reduceOp)) {
2057 LDBG() << "reduction precondition failed: reduction detection failed";
2058 return failure();
2059 }
2060 }
2061 return success();
2062}
2063
2064static LogicalResult
2065vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
2066 bool flatten1DDepthwiseConv) {
2067 if (flatten1DDepthwiseConv) {
2068 LDBG() << "Vectorization of flattened convs with dynamic shapes is not "
2069 "supported";
2070 return failure();
2071 }
2072
2073 if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
2074 LDBG() << "Not a 1D depth-wise WC conv, dynamic shapes are not supported";
2075 return failure();
2076 }
2077
2078 // Support dynamic shapes in 1D depthwise convolution, but only in the
2079 // _channel_ dimension.
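 // For example (an illustrative sketch; shapes and attributes are made up),
 // the following op is accepted because only the channel dimension is dynamic:
 //   linalg.depthwise_conv_1d_nwc_wc
 //     {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
 //     ins(%input, %filter : tensor<1x8x?xf32>, tensor<3x?xf32>)
 //     outs(%out : tensor<1x6x?xf32>) -> tensor<1x6x?xf32>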
2080 Value lhs = conv.getDpsInputOperand(0)->get();
2081 ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
2082 auto shapeWithoutCh = lhsShape.drop_back(1);
2083 if (ShapedType::isDynamicShape(shapeWithoutCh)) {
2084 LDBG() << "Dynamically-shaped op vectorization precondition failed: only "
2085 "channel dim can be dynamic";
2086 return failure();
2087 }
2088
2089 return success();
2090}
2091
2092static LogicalResult
2093vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
2094 bool flatten1DDepthwiseConv) {
2095 if (isa<ConvolutionOpInterface>(op.getOperation()))
2096 return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);
2097
2098 if (hasReductionIterator(op))
2099 return reductionPreconditions(op);
2100
2101 // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
2102 // linalg.copy ops and ops that implement ContractionOpInterface for now.
2103 if (!isElementwise(op) &&
2104 !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>(
2105 op.getOperation()))
2106 return failure();
2107
2108 LDBG() << "Dynamically-shaped op meets vectorization pre-conditions";
2109 return success();
2110}
2111
2112/// This hook considers two cases:
2113/// (1) If the input-vector-sizes are empty, then the vector sizes will be
2114/// inferred. This is only possible when all shapes are static.
2115/// (2) If the input-vector-sizes are non-empty (i.e. user provided), then
2116/// carry out basic sanity-checking.
2117static LogicalResult
2118vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
2119 ArrayRef<int64_t> inputVectorSizes) {
2120 // If there are no input vector sizes and all shapes are static, there is
2121 // nothing left to check.
2122 if (inputVectorSizes.empty() && unpackOp.getDestType().hasStaticShape() &&
2123 unpackOp.getSourceType().hasStaticShape())
2124 return success();
2125
2126 // The number of input vector sizes must be equal to:
2127 // * read-vector-rank
2128 if (!inputVectorSizes.empty() &&
2129 (inputVectorSizes.size() != unpackOp.getSourceRank())) {
2130 LDBG() << "Incorrect number of input vector sizes";
2131 return failure();
2132 }
2133
2134 // Check the vector sizes for the read operation.
2135 if (failed(vector::isValidMaskedInputVector(
2136 unpackOp.getSourceType().getShape(), inputVectorSizes))) {
2137 LDBG() << "Invalid vector sizes for the read operation";
2138 return failure();
2139 }
2140
2141 return success();
2142}
2143
2144static LogicalResult
2145vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
2146 ArrayRef<int64_t> inputVectorSizes) {
2147
2148 TypedValue<RankedTensorType> source = sliceOp.getSource();
2149 auto sourceType = source.getType();
2150 if (!VectorType::isValidElementType(sourceType.getElementType()))
2151 return failure();
2152
2153 // Get the pad value.
2154 // TransferReadOp (which is used to vectorize InsertSliceOp) requires a
2155 // scalar padding value. Note that:
2156 // * for in-bounds accesses,
2157 // the value is actually irrelevant. There are 2 cases in which xfer.read
2158 // accesses are known to be in-bounds:
2159 // 1. The source shape is static (output vector sizes would be based on
2160 // the source shape and hence all memory accesses would be in-bounds),
2161 // 2. Masking is used, i.e. the output vector sizes are user-provided. In
2162 // this case it is safe to assume that all memory accesses are in-bounds.
2163 //
2164 // When the value is not known and not needed, use 0. Otherwise, bail out.
2165 Value padValue = getStaticPadVal(sliceOp);
2166 bool isOutOfBoundsRead =
2167 !sourceType.hasStaticShape() && inputVectorSizes.empty();
2168
2169 if (!padValue && isOutOfBoundsRead) {
2170 LDBG() << "Failed to get a pad value for out-of-bounds read access";
2171 return failure();
2172 }
2173 return success();
2174}
2175
2176/// Vectorize a named linalg contraction op into:
2177/// vector::TransferReadOp - Reads vectors from the operands
2178/// vector::ContractionOp - Performs contraction
2179/// vector::TransferWriteOp - Write the result vector back to the
2180/// destination
2181/// The operands shapes are preserved and loaded directly into vectors.
2182/// Any further permutations or numerical casting remain within contraction op.
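///
/// For example (an illustrative sketch; shapes and names are made up), a
/// linalg.matmul:
///   %res = linalg.matmul ins(%A, %B : tensor<4x8xf32>, tensor<8x16xf32>)
///                        outs(%C : tensor<4x16xf32>) -> tensor<4x16xf32>
/// is vectorized, roughly, as:
///   %vA = vector.transfer_read %A[%c0, %c0], %cst
///     : tensor<4x8xf32>, vector<4x8xf32>
///   %vB = vector.transfer_read %B[%c0, %c0], %cst
///     : tensor<8x16xf32>, vector<8x16xf32>
///   %vC = vector.transfer_read %C[%c0, %c0], %cst
///     : tensor<4x16xf32>, vector<4x16xf32>
///   %vR = vector.contract
///     {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
///                       affine_map<(d0, d1, d2) -> (d2, d1)>,
///                       affine_map<(d0, d1, d2) -> (d0, d1)>],
///      iterator_types = ["parallel", "parallel", "reduction"],
///      kind = #vector.kind<add>}
///     %vA, %vB, %vC : vector<4x8xf32>, vector<8x16xf32> into vector<4x16xf32>
///   %res = vector.transfer_write %vR, %C[%c0, %c0]
///     : vector<4x16xf32>, tensor<4x16xf32>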
2183static LogicalResult
2184vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
2185 LinalgOp linalgOp,
2186 SmallVectorImpl<Value> &newResults) {
2187 Location loc = linalgOp.getLoc();
2188 MLIRContext *ctx = linalgOp.getContext();
2189
2190 // For simplicity, contraction vectorization is limited to linalg named ops.
2191 // Generic op is ignored as not every arbitrary contraction body can be
2192 // expressed by a vector.contract.
2193 if (!isa<ContractionOpInterface>(linalgOp.getOperation()))
2194 return failure();
2195
2196 OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
2197 Operation *reduceOp = matchLinalgReduction(outOperand);
2198 auto maybeKind = getCombinerOpKind(reduceOp);
2199 if (!maybeKind) {
2200 LDBG() << "Failed to determine contraction combining kind.";
2201 return failure();
2202 }
2203
2204 // Check that all dimensions are present in the input operands.
2205 // Arbitrary broadcasts are not supported by the vector contraction.
2206 // Broadcasts are expected to be decomposed before vectorization.
2207 AffineMap lhsMap = linalgOp.getIndexingMapsArray()[0];
2208 AffineMap rhsMap = linalgOp.getIndexingMapsArray()[1];
2209 if (getUnusedDimsBitVector({lhsMap, rhsMap}).any()) {
2210 LDBG() << "Contractions with broadcasts are not supported.";
2211 return failure();
2212 }
2213
2214 // Load operands.
2215 SmallVector<Value> vecOperands;
2216 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2217 // The operand vector shape is computed by mapping the canonical vector
2218 // shape to the operand's domain. Further permutations are left as a part of
2219 // the contraction.
2220 AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
2221 AffineMap readMap = AffineMap::getMultiDimIdentityMap(
2222 indexingMap.getNumResults(), rewriter.getContext());
2223 Type elemType = getElementTypeOrSelf(opOperand.get());
2224 VectorType readType =
2225 state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
2226
2227 Value read = vector::createReadOrMaskedRead(
2228 rewriter, loc, opOperand.get(), readType,
2229 /*padding=*/arith::getZeroConstant(rewriter, loc, elemType),
2230 /*useInBoundsInsteadOfMasking=*/false);
2231 vecOperands.push_back(read);
2232 }
2233
2234 // Remap iterators from linalg to vector.
2235 SmallVector<Attribute> iterAttrs;
2236 auto iterators = linalgOp.getIteratorTypesArray();
2237 for (utils::IteratorType iter : iterators) {
2238 auto vecIter = iter == utils::IteratorType::parallel
2239 ? vector::IteratorType::parallel
2240 : vector::IteratorType::reduction;
2241 iterAttrs.push_back(vector::IteratorTypeAttr::get(ctx, vecIter));
2242 }
2243
2244 // Create contraction.
2245 Operation *contractOp = vector::ContractionOp::create(
2246 rewriter, loc, /*lhs=*/vecOperands[0],
2247 /*rhs=*/vecOperands[1], /*acc=*/vecOperands[2],
2248 linalgOp.getIndexingMaps(), rewriter.getArrayAttr(iterAttrs), *maybeKind);
2249 contractOp = state.maskOperation(rewriter, contractOp, linalgOp);
2250
2251 // Store result.
2252 Operation *write = createWriteOrMaskedWrite(
2253 rewriter, loc, contractOp->getResult(0), outOperand->get());
2254
2255 // Finalize.
2256 if (!write->getResults().empty())
2257 newResults.push_back(write->getResult(0));
2258
2259 return success();
2260}
2261
2262namespace {
2263enum class ConvOperationKind { Conv, Pool };
2264} // namespace
2265
2266static bool isCastOfBlockArgument(Operation *op) {
2267 return isa<CastOpInterface>(op) && op->getNumOperands() == 1 &&
2268 isa<BlockArgument>(op->getOperand(0));
2269}
2270
2271// Returns the ConvOperationKind of the op using reduceOp of the generic
2272// payload. If it is neither a convolution nor a pooling, it returns
2273// std::nullopt.
2274//
2275// If the region has 2 ops (reduction + yield) or 3 ops (extension + reduction
2276// + yield) and rhs is not used, then it is the body of a pooling operation.
2277// If it is a convolution, check for a single `mul` predecessor. The `mul`
2278// operands must be block arguments or extensions of block arguments.
2279// Otherwise, check for one or zero `ext` predecessors. The `ext` operands
2280// must be block arguments or extensions of block arguments.
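//
// For example (illustrative bodies only):
//   Convolution:
//     ^bb0(%in: f32, %flt: f32, %acc: f32):
//       %mul = arith.mulf %in, %flt : f32
//       %add = arith.addf %acc, %mul : f32
//       linalg.yield %add : f32
//   Pooling (sum; the window argument %w is unused):
//     ^bb0(%in: f32, %w: f32, %acc: f32):
//       %add = arith.addf %acc, %in : f32
//       linalg.yield %add : f32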
2281static std::optional<ConvOperationKind>
2282getConvOperationKind(Operation *reduceOp) {
2283 int numBlockArguments =
2284 llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
2285
2286 switch (numBlockArguments) {
2287 case 1: {
2288 // Will be convolution if feeder is a MulOp.
2289 // A strength reduced version of MulOp for i1 type is AndOp which is also
2290 // supported. Otherwise, it can be pooling. This strength reduction logic
2291 // is in `buildBinaryFn` helper in the Linalg dialect.
2292 auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
2293 llvm::IsaPred<BlockArgument>);
2294 assert(feedValIt != reduceOp->operand_end() &&
2295 "Expected a non-block argument operand");
2296 Operation *feedOp = (*feedValIt).getDefiningOp();
2297 if (isCastOfBlockArgument(feedOp)) {
2298 return ConvOperationKind::Pool;
2299 }
2300
2301 if (!((isa<arith::MulIOp, arith::MulFOp>(feedOp) ||
2302 (isa<arith::AndIOp>(feedOp) &&
2303 feedOp->getResultTypes()[0].isInteger(1))) &&
2304 llvm::all_of(feedOp->getOperands(), [](Value v) {
2305 if (isa<BlockArgument>(v))
2306 return true;
2307 if (Operation *op = v.getDefiningOp())
2308 return isCastOfBlockArgument(op);
2309 return false;
2310 }))) {
2311 return std::nullopt;
2312 }
2313
2314 return ConvOperationKind::Conv;
2315 }
2316 case 2:
2317 // Must be pooling
2318 return ConvOperationKind::Pool;
2319 default:
2320 return std::nullopt;
2321 }
2322}
2323
2324static bool isSupportedPoolKind(vector::CombiningKind kind) {
2325 switch (kind) {
2326 case vector::CombiningKind::ADD:
2327 case vector::CombiningKind::MAXNUMF:
2328 case vector::CombiningKind::MAXIMUMF:
2329 case vector::CombiningKind::MAXSI:
2330 case vector::CombiningKind::MAXUI:
2331 case vector::CombiningKind::MINNUMF:
2332 case vector::CombiningKind::MINIMUMF:
2333 case vector::CombiningKind::MINSI:
2334 case vector::CombiningKind::MINUI:
2335 return true;
2336 default:
2337 return false;
2338 }
2339}
2340
2341static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) {
2342 auto getOperandType = [&](auto operand) {
2343 return dyn_cast<ShapedType>((operand->get()).getType());
2344 };
2345 ShapedType lhsShapedType = getOperandType(convOp.getDpsInputOperand(0));
2346 ShapedType rhsShapedType = getOperandType(convOp.getDpsInputOperand(1));
2347 ShapedType resShapedType = getOperandType(convOp.getDpsInitOperand(0));
2348 // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
2349 // (non-channeled convolution -> LHS and RHS both have single dimensions).
2350 // Note that this also ensures 2D and 3D convolutions are rejected.
2351 if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
2352 (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
2353 return failure();
2354
2355 Operation *reduceOp = matchLinalgReduction(convOp.getDpsInitOperand(0));
2356 if (!reduceOp)
2357 return failure();
2358
2359 auto maybeOper = getConvOperationKind(reduceOp);
2360 if (!maybeOper.has_value())
2361 return failure();
2362
2363 auto maybeKind = getCombinerOpKind(reduceOp);
2364 // Typically convolution will have a `Add` CombiningKind but for i1 type it
2365 // can get strength reduced to `OR` which is also supported. This strength
2366 // reduction logic is in `buildBinaryFn` helper in the Linalg dialect.
2367 if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
2368 *maybeKind != vector::CombiningKind::OR) &&
2369 (*maybeOper != ConvOperationKind::Pool ||
2370 !isSupportedPoolKind(*maybeKind)))) {
2371 return failure();
2372 }
2373
2374 auto rhsRank = rhsShapedType.getRank();
2375 if (*maybeOper == ConvOperationKind::Pool) {
2376 if (rhsRank != 1)
2377 return failure();
2378 } else {
2379 if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
2380 return failure();
2381 }
2382
2383 return success();
2384}
2385
2386static LogicalResult vectorizeLinalgOpPrecondition(
2387 LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
2388 bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
2389 // Tensors with a dimension of size 0 cannot be vectorized.
2390 if (llvm::any_of(linalgOp->getOpOperands(), [&](OpOperand &operand) {
2391 return llvm::is_contained(linalgOp.getShape(&operand), 0);
2392 }))
2393 return failure();
2394 // Check API contract for input vector sizes.
2395 if (!inputVectorSizes.empty() &&
2396 failed(vector::isValidMaskedInputVector(linalgOp.getStaticLoopRanges(),
2397 inputVectorSizes)))
2398 return failure();
2399
2400 if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition(
2401 linalgOp, flatten1DDepthwiseConv))) {
2402 LDBG() << "Dynamically-shaped op failed vectorization pre-conditions";
2403 return failure();
2404 }
2405
2406 SmallVector<CustomVectorizationPrecondition> customPreconditions;
2407
2408 // Register CustomVectorizationPrecondition for extractOp.
2409 customPreconditions.push_back(tensorExtractVectorizationPrecondition);
2410
2411 // All types in the body should be a supported element type for VectorType.
2412 for (Operation &innerOp : linalgOp->getRegion(0).front()) {
2413 // Check if any custom hook can vectorize the inner op.
2414 if (llvm::any_of(
2415 customPreconditions,
2416 [&](const CustomVectorizationPrecondition &customPrecondition) {
2417 return succeeded(
2418 customPrecondition(&innerOp, vectorizeNDExtract));
2419 })) {
2420 continue;
2421 }
2422 if (!llvm::all_of(innerOp.getOperandTypes(),
2423 VectorType::isValidElementType)) {
2424 return failure();
2425 }
2426 if (!llvm::all_of(innerOp.getResultTypes(),
2427 VectorType::isValidElementType)) {
2428 return failure();
2429 }
2430 }
2431 if (isElementwise(linalgOp))
2432 return success();
2433
2434 // TODO: isaConvolutionOpInterface that can also infer from generic
2435 // features. But we will still need stride/dilation attributes that will be
2436 // annoying to reverse-engineer...
2437 if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
2438 return vectorizeConvOpPrecondition(linalgOp);
2439
2440 // TODO: the common vector shape is equal to the static loop sizes only when
2441 // all indexing maps are projected permutations. For convs and stencils the
2442 // logic will need to evolve.
2443 if (!allIndexingsAreProjectedPermutation(linalgOp)) {
2444 LDBG() << "precondition failed: not projected permutations";
2445 return failure();
2446 }
2447 if (failed(reductionPreconditions(linalgOp))) {
2448 LDBG() << "precondition failed: reduction preconditions";
2449 return failure();
2450 }
2451 return success();
2452}
2453
2454static LogicalResult
2455vectorizePackOpPrecondition(linalg::PackOp packOp,
2456 ArrayRef<int64_t> inputVectorSizes) {
2457 auto padValue = packOp.getPaddingValue();
2458 Attribute cstAttr;
2459 // TODO: Relax this condition
2460 if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
2461 LDBG() << "pad value is not constant: " << packOp;
2462 return failure();
2463 }
2464
2465 ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
2466 bool satisfyEmptyCond = true;
2467 if (inputVectorSizes.empty()) {
2468 if (!packOp.getDestType().hasStaticShape() ||
2469 !packOp.getSourceType().hasStaticShape())
2470 satisfyEmptyCond = false;
2471 }
2472
2473 if (!satisfyEmptyCond &&
2474 failed(vector::isValidMaskedInputVector(
2475 resultTensorShape.take_front(packOp.getSourceRank()),
2476 inputVectorSizes)))
2477 return failure();
2478
2479 if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
2480 return !getConstantIntValue(v).has_value();
2481 })) {
2482 LDBG() << "inner_tiles must be constant: " << packOp;
2483 return failure();
2484 }
2485
2486 return success();
2487}
2488
2489static LogicalResult
2490vectorizePadOpPrecondition(tensor::PadOp padOp,
2491 ArrayRef<int64_t> inputVectorSizes) {
2492 auto padValue = padOp.getConstantPaddingValue();
2493 if (!padValue) {
2494 LDBG() << "pad value is not constant: " << padOp;
2495 return failure();
2496 }
2497
2498 ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
2499 if (failed(vector::isValidMaskedInputVector(resultTensorShape,
2500 inputVectorSizes)))
2501 return failure();
2502
2503 // Padding with non-zero low pad values is not supported, unless the
2504 // corresponding result dim is 1 as this would require shifting the results to
2505 // the right for the low padded dims by the required amount of low padding.
2506 // However, we do support low padding if the dims being low padded have result
2507 // sizes of 1. The reason is when we have a low pad on a unit result dim, the
2508 // input size of that dimension will be dynamically zero (as the sum of the
2509 // low pad and input dim size has to be one) and hence we will create a zero
2510 // mask as the lowering logic just makes the mask one for the input dim size -
2511 // which is zero here. Hence we will load the pad value which is what we want
2512 // in this case. If the low pad is dynamically zero then the lowering is
2513 // correct as well as no shifts are necessary.
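 // For example (illustrative), this pad is rejected because of the non-zero
 // low pad on a non-unit result dim:
 //   tensor.pad %src low[2, 0] high[0, 0] ... : tensor<6x8xf32> to tensor<8x8xf32>
 // whereas a (possibly dynamic) low pad on a unit result dim is accepted:
 //   tensor.pad %src low[%lp, 0] high[%hp, 0] ... : tensor<?x8xf32> to tensor<1x8xf32>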
2514 if (llvm::any_of(llvm::enumerate(padOp.getLow()), [&](const auto &en) {
2515 Value padValue = en.value();
2516 unsigned pos = en.index();
2517 std::optional<int64_t> pad = getConstantIntValue(padValue);
2518 return (!pad.has_value() || pad.value() != 0) &&
2519 resultTensorShape[pos] != 1;
2520 })) {
2521 LDBG() << "low pad must all be zero for all non unit dims: " << padOp;
2522 return failure();
2523 }
2524
2525 return success();
2526}
2527
2528/// Preconditions for scalable vectors.
2529///
2530/// For Ops implementing the LinalgOp interface, this is quite restrictive - it
2531/// models the fact that in practice we would only make selected dimensions
2532/// scalable. For other Ops (e.g. `linalg.unpack`), this will succeed
2533/// unconditionally - we are yet to identify meaningful conditions.
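///
/// For example (an illustrative sketch; the vector sizes are made up), for a
/// linalg.matmul with iterator types [parallel, parallel, reduction]:
///   * vector sizes [8, [16], 1] (scalable flags [false, true, false]) are
///     accepted - the trailing non-unit parallel dim is scalable;
///   * vector sizes [8, 16, [4]] (scalable flags [false, false, true]) are
///     rejected - scalable vectorization of the reduction dim of a matmul is
///     not supported.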
2534static LogicalResult
2535vectorizeScalableVectorPrecondition(Operation *op,
2536 ArrayRef<int64_t> inputVectorSizes,
2537 ArrayRef<bool> inputScalableVecDims) {
2538 assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
2539 "Number of input vector sizes and scalable dims doesn't match");
2540
2541 size_t numOfScalableDims =
2542 llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });
2543
2544 if (numOfScalableDims == 0)
2545 return success();
2546
2547 auto linalgOp = dyn_cast<LinalgOp>(op);
2548
2549 // Cond 1: Reject Ops that don't implement the LinalgOp interface, with the
2550 // exception of UnpackOp for which there is a dedicated hook.
2551 if (!linalgOp) {
2552 return success(isa<linalg::UnPackOp>(op));
2553 }
2554
2555 // Cond 2: There's been no need for more than 2 scalable dims so far
2556 if (numOfScalableDims > 2)
2557 return failure();
2558
2559 // Cond 3: Look at the configuration in `inputScalableVecDims` and verify that
2560 // it matches one of the supported cases:
2561 // 1. Exactly 1 dim is scalable and that's the _last_ non-unit parallel dim
2562 // (*).
2563 // 2. Exactly 2 dims are scalable and those are the _last two adjacent_
2564 // parallel dims.
2565 // 3. Exactly 1 reduction dim is scalable and that's the last (innermost)
2566 // dim.
2567 // The 2nd restriction above means that only Matmul-like Ops are supported
2568 // when 2 dims are scalable, e.g. :
2569 // * iterators = [parallel, parallel, reduction]
2570 // * scalable flags = [true, true, false]
2571 //
2572 // (*) Unit dims get folded away in practice.
2573 // TODO: Relax these conditions as good motivating examples are identified.
2574
2575 // Find the first scalable flag.
2576 bool seenNonUnitParallel = false;
2577 auto iterators = linalgOp.getIteratorTypesArray();
2578 SmallVector<bool> scalableFlags(inputScalableVecDims);
2579 int64_t idx = scalableFlags.size() - 1;
2580 while (!scalableFlags[idx]) {
2581 bool isNonUnitDim = (inputVectorSizes[idx] != 1);
2582 seenNonUnitParallel |=
2583 (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim);
2584
2585 iterators.pop_back();
2586 scalableFlags.pop_back();
2587 --idx;
2588 }
2589
2590 // Analyze the iterator corresponding to the first scalable dim.
2591 switch (iterators.back()) {
2592 case utils::IteratorType::reduction: {
2593 // Check 3. above is met.
2594 if (iterators.size() != inputVectorSizes.size()) {
2595 LDBG() << "Non-trailing reduction dim requested for scalable "
2596 "vectorization";
2597 return failure();
2598 }
2599 if (isa<linalg::MatmulOp>(op)) {
2600 LDBG()
2601 << "Scalable vectorization of the reduction dim in Matmul-like ops "
2602 "is not supported";
2603 return failure();
2604 }
2605 break;
2606 }
2607 case utils::IteratorType::parallel: {
2608 // Check 1. and 2. above are met.
2609 if (seenNonUnitParallel) {
2610 LDBG() << "Inner parallel dim not requested for scalable "
2611 "vectorization";
2612 return failure();
2613 }
2614 break;
2615 }
2616 }
2617
2618 // If present, check the 2nd scalable dim. ATM, only Matmul-like Ops are
2619 // supported, for which we expect the following config:
2620 // * iterators = [parallel, parallel, reduction]
2621 // * scalable flags = [true, true, false]
2622 if (numOfScalableDims == 2) {
2623 // Disallow below case which breaks 3. above:
2624 // * iterators = [..., parallel, reduction]
2625 // * scalable flags = [..., true, true]
2626 if (iterators.back() == utils::IteratorType::reduction) {
2627 LDBG() << "Higher dim than the trailing reduction dim requested for "
2628 "scalable vectorization";
2630 return failure();
2631 }
2632 scalableFlags.pop_back();
2633 iterators.pop_back();
2634
2635 if (!scalableFlags.back() ||
2636 (iterators.back() != utils::IteratorType::parallel))
2637 return failure();
2638 }
2639
2640 // Cond 4: Only the following ops are supported in the
2641 // presence of scalable vectors
2642 return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
2643 isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
2644 isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
2645 isa<linalg::BatchMmt4DOp>(op) ||
2646 hasReductionIterator(linalgOp));
2647}
2648
2649LogicalResult mlir::linalg::vectorizeOpPrecondition(
2650 Operation *op, ArrayRef<int64_t> inputVectorSizes,
2651 ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
2652 bool flatten1DDepthwiseConv) {
2653
2654 if (!hasVectorizationImpl(op))
2655 return failure();
2656
2657 if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
2658 inputScalableVecDims)))
2659 return failure();
2660
2661 return TypeSwitch<Operation *, LogicalResult>(op)
2662 .Case<linalg::LinalgOp>([&](auto linalgOp) {
2663 return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
2664 vectorizeNDExtract,
2665 flatten1DDepthwiseConv);
2666 })
2667 .Case<tensor::PadOp>([&](auto padOp) {
2668 return vectorizePadOpPrecondition(padOp, inputVectorSizes);
2669 })
2670 .Case<linalg::PackOp>([&](auto packOp) {
2671 return vectorizePackOpPrecondition(packOp, inputVectorSizes);
2672 })
2673 .Case<linalg::UnPackOp>([&](auto unpackOp) {
2674 return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
2675 })
2676 .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
2677 return vectorizeInsertSliceOpPrecondition(sliceOp, inputVectorSizes);
2678 })
2679 .Default([](auto) { return failure(); });
2680}
2681
2682/// Converts affine.apply Ops to arithmetic operations.
2683static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
2684 OpBuilder::InsertionGuard g(rewriter);
2685 auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>();
2686
2687 for (auto op : make_early_inc_range(toReplace)) {
2688 rewriter.setInsertionPoint(op);
2689 auto expanded = affine::expandAffineExpr(
2690 rewriter, op->getLoc(), op.getAffineMap().getResult(0),
2691 op.getOperands().take_front(op.getAffineMap().getNumDims()),
2692 op.getOperands().take_back(op.getAffineMap().getNumSymbols()));
2693 rewriter.replaceOp(op, expanded);
2694 }
2695}
2696
2697bool mlir::linalg::hasVectorizationImpl(Operation *op) {
2698 return isa<linalg::LinalgOp, tensor::PadOp, linalg::PackOp, linalg::UnPackOp,
2699 tensor::InsertSliceOp>(op);
2700}
2701
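/// Note: in practice this entry point is commonly driven from the transform
/// dialect; an illustrative sketch (assuming the
/// `transform.structured.vectorize` op and a matched handle `%matmul`):
///   transform.structured.vectorize %matmul vector_sizes [8, [16], 1]
///     : !transform.any_op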
2702FailureOr<VectorizationResult> mlir::linalg::vectorize(
2703 RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
2704 ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
2705 bool flatten1DDepthwiseConv, bool assumeDynamicDimsMatchVecSizes,
2706 bool createNamedContraction) {
2707 LDBG() << "Attempting to vectorize: " << *op;
2708 LDBG() << "Input vector sizes: " << llvm::interleaved(inputVectorSizes);
2709 LDBG() << "Input scalable vector dims: "
2710 << llvm::interleaved(inputScalableVecDims);
2711
2712 if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
2713 vectorizeNDExtract,
2714 flatten1DDepthwiseConv))) {
2715 LDBG() << "Vectorization pre-conditions failed";
2716 return failure();
2717 }
2718
2719 // Initialize vectorization state.
2720 VectorizationState state(rewriter);
2721 if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
2722 if (failed(state.initState(rewriter, linalgOp, inputVectorSizes,
2723 inputScalableVecDims,
2724 assumeDynamicDimsMatchVecSizes))) {
2725 LDBG() << "Vectorization state couldn't be initialized";
2726 return failure();
2727 }
2728 }
2729
2730 SmallVector<Value> results;
2731 auto vectorizeResult =
2732 TypeSwitch<Operation *, LogicalResult>(op)
2733 .Case<linalg::LinalgOp>([&](auto linalgOp) {
2734 // TODO: isaConvolutionOpInterface that can also infer from
2735 // generic features. Will require stride/dilation attributes
2736 // inference.
2737 if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
2738 FailureOr<Operation *> convOr = vectorizeConvolution(
2739 rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
2740 flatten1DDepthwiseConv);
2741 if (succeeded(convOr)) {
2742 llvm::append_range(results, (*convOr)->getResults());
2743 return success();
2744 }
2745
2746 LDBG() << "Unsupported convolution can't be vectorized.";
2747 return failure();
2748 }
2749
2750 if (createNamedContraction &&
2751 isa<ContractionOpInterface>(linalgOp.getOperation()))
2752 return vectorizeAsLinalgContraction(rewriter, state, linalgOp,
2753 results);
2754
2755 LDBG()
2756 << "Vectorize generic by broadcasting to the canonical vector "
2757 "shape";
2758
2759 // Pre-process before proceeding.
2760 convertAffineApply(rewriter, linalgOp);
2761
2762 // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
2763 // to 'OpBuilder' when it is passed over to some methods like
2764 // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
2765 // erase an op within these methods, the actual rewriter won't be
2766 // notified and we will end up with read-after-free issues!
2767 return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
2768 })
2769 .Case<tensor::PadOp>([&](auto padOp) {
2770 return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
2771 results);
2772 })
2773 .Case<linalg::PackOp>([&](auto packOp) {
2774 return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
2775 results);
2776 })
2777 .Case<linalg::UnPackOp>([&](auto unpackOp) {
2778 return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
2779 inputVectorSizes,
2780 inputScalableVecDims, results);
2781 })
2782 .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
2783 return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
2784 results);
2785 })
2786 .Default([](auto) { return failure(); });
2787
2788 if (failed(vectorizeResult)) {
2789 LDBG() << "Vectorization failed";
2790 return failure();
2791 }
2792
2793 return VectorizationResult{results};
2794}
2795
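/// For example (illustrative shapes), a static copy such as:
///   memref.copy %src, %dst : memref<4x8xf32> to memref<4x8xf32>
/// is rewritten, roughly, as:
///   %v = vector.transfer_read %src[%c0, %c0], %pad
///     : memref<4x8xf32>, vector<4x8xf32>
///   vector.transfer_write %v, %dst[%c0, %c0]
///     : vector<4x8xf32>, memref<4x8xf32>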
2796LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
2797 memref::CopyOp copyOp) {
2798 auto srcType = cast<MemRefType>(copyOp.getSource().getType());
2799 auto dstType = cast<MemRefType>(copyOp.getTarget().getType());
2800 if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
2801 return failure();
2802
2803 auto srcElementType = getElementTypeOrSelf(srcType);
2804 auto dstElementType = getElementTypeOrSelf(dstType);
2805 if (!VectorType::isValidElementType(srcElementType) ||
2806 !VectorType::isValidElementType(dstElementType))
2807 return failure();
2808
2809 auto readType = VectorType::get(srcType.getShape(), srcElementType);
2810 auto writeType = VectorType::get(dstType.getShape(), dstElementType);
2811
2812 Location loc = copyOp->getLoc();
2813 Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
2814 SmallVector<Value> indices(srcType.getRank(), zero);
2815
2816 Value readValue = vector::TransferReadOp::create(
2817 rewriter, loc, readType, copyOp.getSource(), indices,
2818 /*padding=*/std::nullopt,
2819 rewriter.getMultiDimIdentityMap(srcType.getRank()));
2820 if (cast<VectorType>(readValue.getType()).getRank() == 0) {
2821 readValue = vector::ExtractOp::create(rewriter, loc, readValue,
2822 ArrayRef<int64_t>());
2823 readValue =
2824 vector::BroadcastOp::create(rewriter, loc, writeType, readValue);
2825 }
2826 Operation *writeValue = vector::TransferWriteOp::create(
2827 rewriter, loc, readValue, copyOp.getTarget(), indices,
2828 rewriter.getMultiDimIdentityMap(srcType.getRank()));
2829 rewriter.replaceOp(copyOp, writeValue->getResults());
2830 return success();
2831}
2832
2833//----------------------------------------------------------------------------//
2834// Misc. vectorization patterns.
2835//----------------------------------------------------------------------------//
2836/// Base pattern for rewriting tensor::PadOps whose result is consumed by a
2837/// given operation type OpTy.
2838template <typename OpTy>
2839struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> {
2840 using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
2841
2842 LogicalResult matchAndRewrite(tensor::PadOp padOp,
2843 PatternRewriter &rewriter) const final {
2844 bool changed = false;
2845    // Collect the users in a vector first, because some may be replaced/removed.
2846 for (auto *user : llvm::to_vector<4>(padOp->getUsers()))
2847 if (auto op = dyn_cast<OpTy>(user))
2848 changed |= rewriteUser(rewriter, padOp, op).succeeded();
2849 return success(changed);
2850 }
2851
2852protected:
2853 virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
2854 tensor::PadOp padOp, OpTy op) const = 0;
2855};
2856
2857/// Rewrite use of tensor::PadOp result in TransferReadOp. E.g.:
2858/// ```
2859/// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
2860/// %r = vector.transfer_read %0[%c0, %c0], %cst
2861/// {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32>
2862/// ```
2863/// is rewritten to:
2864/// ```
2865/// %r = vector.transfer_read %src[%c0, %c0], %padding
2866/// {in_bounds = [true, true]}
2867/// : tensor<?x?xf32>, vector<17x5xf32>
2868/// ```
2869/// Note: By restricting this pattern to in-bounds TransferReadOps, we can be
2870/// sure that the original padding value %cst was never used.
2871///
2872/// This rewrite is possible if:
2873/// - `xferOp` has no out-of-bounds dims or mask.
2874/// - Low padding is static 0.
2875/// - Single, scalar padding value.
2876struct PadOpVectorizationWithTransferReadPattern
2877 : public VectorizePadOpUserPattern<vector::TransferReadOp> {
2878 using VectorizePadOpUserPattern<
2879 vector::TransferReadOp>::VectorizePadOpUserPattern;
2880
2881 LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2882 vector::TransferReadOp xferOp) const override {
2883 // Low padding must be static 0.
2884 if (!padOp.hasZeroLowPad())
2885 return failure();
2886 // Pad value must be a constant.
2887 auto padValue = padOp.getConstantPaddingValue();
2888 if (!padValue)
2889 return failure();
2890 // Padding value of existing `xferOp` is unused.
2891 if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
2892 return failure();
2893
2894 rewriter.modifyOpInPlace(xferOp, [&]() {
2895 SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2896 xferOp->setAttr(xferOp.getInBoundsAttrName(),
2897 rewriter.getBoolArrayAttr(inBounds));
2898 xferOp.getBaseMutable().assign(padOp.getSource());
2899 xferOp.getPaddingMutable().assign(padValue);
2900 });
2901
2902 return success();
2903 }
2904};
2905
2906/// Rewrite use of tensor::PadOp result in TransferWriteOp.
2907/// This pattern rewrites TransferWriteOps that write to a padded tensor
2908/// value, where the same amount of padding is immediately removed again after
2909/// the write. In such cases, the TransferWriteOp can write to the non-padded
2910/// tensor value and apply out-of-bounds masking. E.g.:
2911/// ```
2912/// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2913/// : tensor<...> to tensor<?x?xf32>
2914/// %1 = tensor.pad %0 ... : tensor<?x?xf32> to tensor<17x5xf32>
2915/// %2 = vector.transfer_write %vec, %1[...]
2916/// : vector<17x5xf32>, tensor<17x5xf32>
2917/// %r = tensor.extract_slice %2[0, 0] [%s0, %s1] [1, 1]
2918/// : tensor<17x5xf32> to tensor<?x?xf32>
2919/// ```
2920/// is rewritten to:
2921/// ```
2922/// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1]
2923/// : tensor<...> to tensor<?x?xf32>
2924/// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>,
2925/// tensor<?x?xf32>
2926/// ```
2927/// Note: It is important that the ExtractSliceOp %r resizes the result of the
2928/// TransferWriteOp to the same size as the input of the TensorPadOp (or an
2929/// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ
2930/// from %r's old dimensions.
2931///
2932/// This rewrite is possible if:
2933/// - Low padding is static 0.
2934/// - `xferOp` has exactly one use, which is an ExtractSliceOp. This
2935/// ExtractSliceOp trims the same amount of padding that was added
2936/// beforehand.
2937/// - Single, scalar padding value.
2938struct PadOpVectorizationWithTransferWritePattern
2939 : public VectorizePadOpUserPattern<vector::TransferWriteOp> {
2940 using VectorizePadOpUserPattern<
2941 vector::TransferWriteOp>::VectorizePadOpUserPattern;
2942
2943 LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
2944 vector::TransferWriteOp xferOp) const override {
2945 // TODO: support 0-d corner case.
2946 if (xferOp.getTransferRank() == 0)
2947 return failure();
2948
2949 // Low padding must be static 0.
2950 if (!padOp.hasZeroLowPad())
2951 return failure();
2952 // Pad value must be a constant.
2953 auto padValue = padOp.getConstantPaddingValue();
2954 if (!padValue)
2955 return failure();
2956 // TransferWriteOp result must be directly consumed by an ExtractSliceOp.
2957 if (!xferOp->hasOneUse())
2958 return failure();
2959 auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
2960 if (!trimPadding)
2961 return failure();
2962 // Only static zero offsets supported when trimming padding.
2963 if (!trimPadding.hasZeroOffset())
2964 return failure();
2965 // trimPadding must remove the amount of padding that was added earlier.
2966 if (!hasSameTensorSize(padOp.getSource(), trimPadding))
2967 return failure();
2968
2969 // Insert the new TransferWriteOp at position of the old TransferWriteOp.
2970 rewriter.setInsertionPoint(xferOp);
2971
2972 SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
2973 auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2974 xferOp, padOp.getSource().getType(), xferOp.getVector(),
2975 padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(),
2976 xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds));
2977 rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
2978
2979 return success();
2980 }
2981
2982 /// Check if `beforePadding` and `afterTrimming` have the same tensor size,
2983 /// i.e., same dimensions.
2984 ///
2985 /// Dimensions may be static, dynamic or mix of both. In case of dynamic
2986 /// dimensions, this function tries to infer the (static) tensor size by
2987 /// looking at the defining op and utilizing op-specific knowledge.
2988 ///
2989 /// This is a conservative analysis. In case equal tensor sizes cannot be
2990 /// proven statically, this analysis returns `false` even though the tensor
2991 /// sizes may turn out to be equal at runtime.
2992 bool hasSameTensorSize(Value beforePadding,
2993 tensor::ExtractSliceOp afterTrimming) const {
2994 // If the input to tensor::PadOp is a CastOp, try with both CastOp
2995 // result and CastOp operand.
2996 if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
2997 if (hasSameTensorSize(castOp.getSource(), afterTrimming))
2998 return true;
2999
3000 auto t1 = dyn_cast<RankedTensorType>(beforePadding.getType());
3001 auto t2 = dyn_cast<RankedTensorType>(afterTrimming.getType());
3002 // Only RankedTensorType supported.
3003 if (!t1 || !t2)
3004 return false;
3005 // Rank of both values must be the same.
3006 if (t1.getRank() != t2.getRank())
3007 return false;
3008
3009 // All static dimensions must be the same. Mixed cases (e.g., dimension
3010 // static in `t1` but dynamic in `t2`) are not supported.
3011 for (unsigned i = 0; i < t1.getRank(); ++i) {
3012 if (t1.isDynamicDim(i) != t2.isDynamicDim(i))
3013 return false;
3014 if (!t1.isDynamicDim(i) && t1.getDimSize(i) != t2.getDimSize(i))
3015 return false;
3016 }
3017
3018 // Nothing more to check if all dimensions are static.
3019 if (t1.getNumDynamicDims() == 0)
3020 return true;
3021
3022 // All dynamic sizes must be the same. The only supported case at the
3023 // moment is when `beforePadding` is an ExtractSliceOp (or a cast
3024 // thereof).
3025
3026 // Apart from CastOp, only ExtractSliceOp is supported.
3027 auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>();
3028 if (!beforeSlice)
3029 return false;
3030
3031 assert(static_cast<size_t>(t1.getRank()) ==
3032 beforeSlice.getMixedSizes().size());
3033 assert(static_cast<size_t>(t2.getRank()) ==
3034 afterTrimming.getMixedSizes().size());
3035
3036 for (unsigned i = 0; i < t1.getRank(); ++i) {
3037 // Skip static dimensions.
3038 if (!t1.isDynamicDim(i))
3039 continue;
3040 auto size1 = beforeSlice.getMixedSizes()[i];
3041 auto size2 = afterTrimming.getMixedSizes()[i];
3042
3043 // Case 1: Same value or same constant int.
3044 if (isEqualConstantIntOrValue(size1, size2))
3045 continue;
3046
3047 // Other cases: Take a deeper look at defining ops of values.
3048 auto v1 = llvm::dyn_cast_if_present<Value>(size1);
3049 auto v2 = llvm::dyn_cast_if_present<Value>(size2);
3050 if (!v1 || !v2)
3051 return false;
3052
3053 // Case 2: Both values are identical AffineMinOps. (Should not happen if
3054 // CSE is run.)
3055 auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>();
3056 auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>();
3057 if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
3058 minOp1.getOperands() == minOp2.getOperands())
3059 continue;
3060
3061 // Add additional cases as needed.
3062 }
3063
3064 // All tests passed.
3065 return true;
3066 }
3067};
3068
3069/// Returns the effective Pad value for the input op, provided it's a scalar.
3070///
3071/// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
3072/// this Op performs padding, retrieve the padding value provided that it's
3073/// a scalar and static/fixed for all the padded values. Returns an empty value
3074/// otherwise.
3075///
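/// Example (illustrative only; names and shapes are made up for exposition):
/// given
/// ```
///   %fill = linalg.fill ins(%pad : f32) outs(%empty : tensor<8x8xf32>)
///             -> tensor<8x8xf32>
///   %res = tensor.insert_slice %src into %fill[0, 0] [4, 4] [1, 1]
///            : tensor<4x4xf32> into tensor<8x8xf32>
/// ```
/// querying the `tensor.insert_slice` traces through its destination (the
/// `linalg.fill`) and returns %pad.
///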
3076/// TODO: This is used twice (when checking vectorization pre-conditions and
3077/// when vectorizing). Cache results instead of re-running.
3078static Value getStaticPadVal(Operation *op) {
3079 if (!op)
3080 return {};
3081
3082  // 1. vector.broadcast (f32 -> vector<...xf32>) - return the value that's
3083 // being broadcast, provided that it's a scalar.
3084 if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
3085 auto source = bcast.getSource();
3086 if (llvm::dyn_cast<VectorType>(source.getType()))
3087 return {};
3088
3089 return source;
3090 }
3091
3092  // 2. linalg.fill - use the scalar input value that was used to fill the
3093  // output tensor.
3094 if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
3095 return fill.getInputs()[0];
3096 }
3097
3098  // 3. tensor.generate - can't guarantee the value is fixed without
3099  // analysing it, bail out.
3100 if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {
3101 return {};
3102 }
3103
3104  // 4. vector.transfer_write - inspect the input vector that's being written.
3105  // If it contains a single value that has been broadcast (e.g. via
3106  // vector.broadcast), extract it, fail otherwise.
3107 if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
3108 return getStaticPadVal(xferWrite.getVector().getDefiningOp());
3109
3110 // 5. tensor.insert_slice - inspect the destination tensor. If it's larger
3111 // than the input tensor, then, provided it's constant, we'll extract the
3112 // value that was used to generate it (via e.g. linalg.fill), fail otherwise.
3113 // TODO: Clarify the semantics when the input tensor is larger than the
3114 // destination.
3115 if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
3116 return getStaticPadVal(slice.getDest().getDefiningOp());
3117
3118 return {};
3119}
3120
3121static LogicalResult
3122vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
3123 ArrayRef<int64_t> inputVectorSizes,
3124 SmallVectorImpl<Value> &newResults) {
3125 // TODO: Introduce a parent class that will handle the insertion point update.
3126 OpBuilder::InsertionGuard g(rewriter);
3127 rewriter.setInsertionPoint(sliceOp);
3128
3129 TypedValue<RankedTensorType> source = sliceOp.getSource();
3130 auto sourceType = source.getType();
3131 auto resultType = sliceOp.getResultType();
3132
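  // 1. Get the padding value (an inferred static pad value if available,
  // otherwise a zero constant of the source element type).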
3133 Value padValue = getStaticPadVal(sliceOp);
3134
3135 if (!padValue) {
3136 auto elemType = sourceType.getElementType();
3137 padValue = arith::ConstantOp::create(rewriter, sliceOp.getLoc(), elemType,
3138 rewriter.getZeroAttr(elemType));
3139 }
3140
3141 // 2. Get the vector shape
3142 SmallVector<int64_t> vecShape;
3143 size_t rankDiff = resultType.getRank() - sourceType.getRank();
3144 for (int64_t i = 0, end = sourceType.getRank(); i < end; ++i) {
3145 if (!inputVectorSizes.empty()) {
3146 vecShape.push_back(inputVectorSizes[i]);
3147 } else if (!sourceType.isDynamicDim(i)) {
3148 vecShape.push_back(sourceType.getDimSize(i));
3149 } else if (!resultType.isDynamicDim(i)) {
3150 // Source shape is not statically known, but result shape is.
3151 // Vectorize with size of result shape. This may be larger than the
3152 // source size.
3153 // FIXME: Using rankDiff implies that the source tensor is inserted at
3154 // the end of the destination tensor. However, that's not required.
3155 vecShape.push_back(resultType.getDimSize(rankDiff + i));
3156 } else {
3157      // Neither the source nor the result dim of sliceOp is static. Cannot
3158      // vectorize the copy.
3159 return failure();
3160 }
3161 }
3162 auto vecType = VectorType::get(vecShape, sourceType.getElementType());
3163
3164 // 3. Generate TransferReadOp + TransferWriteOp
3165 auto loc = sliceOp.getLoc();
3166
3167 // Create read
3168 SmallVector<Value> readIndices(
3169 vecType.getRank(), arith::ConstantIndexOp::create(rewriter, loc, 0));
3170  Value read = vector::createReadOrMaskedRead(
3171      rewriter, loc, source, vecType, padValue,
3172 /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());
3173
3174 // Create write
3175 auto writeIndices =
3176 getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
3177 Operation *write =
3178 createWriteOrMaskedWrite(rewriter, loc, read, sliceOp.getDest(),
3179 writeIndices, inputVectorSizes.empty());
3180
3181 // 4. Finalize
3182 newResults.push_back(write->getResult(0));
3183
3184 return success();
3185}
3186
3187/// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
3188/// ```
3189/// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
3190/// %r = tensor.insert_slice %0
3191/// into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
3192/// : tensor<17x5xf32> into tensor<?x?x17x5xf32>
3193/// ```
3194/// is rewritten to:
3195/// ```
3196/// %0 = vector.transfer_read %src[%c0, %c0], %padding
3197/// : tensor<?x?xf32>, vector<17x5xf32>
3198/// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
3199/// {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
3200/// ```
3201///
3202/// This rewrite is possible if:
3203/// - Low padding is static 0.
3204/// - `padOp` result shape is static.
3205/// - The entire padded tensor is inserted.
3206/// (Implies that sizes of `insertOp` are all static.)
3207/// - Only unit strides in `insertOp`.
3208/// - Single, scalar padding value.
3209/// - `padOp` result not used as destination.
3210struct PadOpVectorizationWithInsertSlicePattern
3211 : public VectorizePadOpUserPattern<tensor::InsertSliceOp> {
3212 using VectorizePadOpUserPattern<
3213 tensor::InsertSliceOp>::VectorizePadOpUserPattern;
3214
3215 LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp,
3216 tensor::InsertSliceOp insertOp) const override {
3217 // Low padding must be static 0.
3218 if (!padOp.hasZeroLowPad())
3219 return failure();
3220 // Only unit stride supported.
3221 if (!insertOp.hasUnitStride())
3222 return failure();
3223 // Pad value must be a constant.
3224 auto padValue = padOp.getConstantPaddingValue();
3225 if (!padValue)
3226 return failure();
3227 // Dynamic shapes not supported.
3228 if (!cast<ShapedType>(padOp.getResult().getType()).hasStaticShape())
3229 return failure();
3230 // Pad result not used as destination.
3231 if (insertOp.getDest() == padOp.getResult())
3232 return failure();
3233
3234 auto vecType = VectorType::get(padOp.getType().getShape(),
3235 padOp.getType().getElementType());
3236 unsigned vecRank = vecType.getRank();
3237 unsigned tensorRank = insertOp.getType().getRank();
3238
3239 // Check if sizes match: Insert the entire tensor into most minor dims.
3240 // (No permutations allowed.)
3241 SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
3242 expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
3243 if (!llvm::all_of(
3244 llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
3245 return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
3246 }))
3247 return failure();
3248
3249 // Insert the TransferReadOp and TransferWriteOp at the position of the
3250 // InsertSliceOp.
3251 rewriter.setInsertionPoint(insertOp);
3252
3253 // Generate TransferReadOp: Read entire source tensor and add high
3254 // padding.
3255 SmallVector<Value> readIndices(
3256 vecRank, arith::ConstantIndexOp::create(rewriter, padOp.getLoc(), 0));
3257 auto read = vector::TransferReadOp::create(rewriter, padOp.getLoc(),
3258 vecType, padOp.getSource(),
3259 readIndices, padValue);
3260
3261 // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
3262    // specified offsets. Write is fully in-bounds because an InsertSliceOp's
3263 // source must fit into the destination at the specified offsets.
3264 auto writeIndices = getValueOrCreateConstantIndexOp(
3265 rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
3266 SmallVector<bool> inBounds(vecRank, true);
3267 rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
3268 insertOp, read, insertOp.getDest(), writeIndices,
3269 ArrayRef<bool>{inBounds});
3270
3271 return success();
3272 }
3273};
3274
3275void mlir::linalg::populatePadOpVectorizationPatterns(
3276    RewritePatternSet &patterns, PatternBenefit baseBenefit) {
3277 patterns.add<PadOpVectorizationWithTransferReadPattern,
3278 PadOpVectorizationWithTransferWritePattern,
3279 PadOpVectorizationWithInsertSlicePattern>(
3280 patterns.getContext(), baseBenefit.getBenefit() + 1);
3281}
3282
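// Typical usage of the populate function above (illustrative sketch; the
// greedy driver `applyPatternsGreedily` and the surrounding `funcOp`/`ctx`
// are assumptions, not part of this file):
//
//   RewritePatternSet patterns(ctx);
//   linalg::populatePadOpVectorizationPatterns(patterns);
//   (void)applyPatternsGreedily(funcOp, std::move(patterns));
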
3283//----------------------------------------------------------------------------//
3284// Forwarding patterns
3285//----------------------------------------------------------------------------//
3286
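//
// Schematically (an illustrative sketch; types, layouts and indices are
// abbreviated), the read-forwarding pattern below turns
//
//   linalg.fill ins(%pad : f32) outs(%alloc : memref<...>)
//   memref.copy %in, %subview_of_alloc
//   %v = vector.transfer_read %alloc[...], %pad
//
// into a single vector.transfer_read from %in (with conservatively reset,
// i.e. all-false, `in_bounds`), erasing the fill and the copy. The
// write-forwarding pattern is the symmetric transformation for
// vector.transfer_write + memref.copy.
//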
3287/// Check whether there is any interleaved use of any `values` between
3288/// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
3289/// is in a different block.
3290static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
3291 ValueRange values) {
3292 if (firstOp->getBlock() != secondOp->getBlock() ||
3293 !firstOp->isBeforeInBlock(secondOp)) {
3294 LDBG() << "interleavedUses precondition failed, firstOp: " << *firstOp
3295 << ", second op: " << *secondOp;
3296 return true;
3297 }
3298 for (auto v : values) {
3299 for (auto &u : v.getUses()) {
3300 Operation *owner = u.getOwner();
3301 if (owner == firstOp || owner == secondOp)
3302 continue;
3303 // TODO: this is too conservative, use dominance info in the future.
3304 if (owner->getBlock() == firstOp->getBlock() &&
3305 (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
3306 continue;
3307 LDBG() << " found interleaved op " << *owner << ", firstOp: " << *firstOp
3308 << ", second op: " << *secondOp;
3309 return true;
3310 }
3311 }
3312 return false;
3313}
3314
3315/// Return the unique subview use of `v` if it is indeed unique, null
3316/// otherwise.
3317static memref::SubViewOp getSubViewUseIfUnique(Value v) {
3318 memref::SubViewOp subViewOp;
3319 for (auto &u : v.getUses()) {
3320 if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(u.getOwner())) {
3321 if (subViewOp)
3322 return memref::SubViewOp();
3323 subViewOp = newSubViewOp;
3324 }
3325 }
3326 return subViewOp;
3327}
3328
3329/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
3330/// when available.
3331LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
3332    vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
3333
3334 // TODO: support mask.
3335 if (xferOp.getMask())
3336 return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
3337
3338 // Transfer into `view`.
3339 Value viewOrAlloc = xferOp.getBase();
3340 if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
3341 !viewOrAlloc.getDefiningOp<memref::AllocOp>())
3342 return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
3343
3344 // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
3345 memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
3346 if (!subViewOp)
3347 return rewriter.notifyMatchFailure(xferOp, "no subview found");
3348 Value subView = subViewOp.getResult();
3349
3350 // Find the copy into `subView` without interleaved uses.
3351 memref::CopyOp copyOp;
3352 for (auto &u : subView.getUses()) {
3353 if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
3354 assert(isa<MemRefType>(newCopyOp.getTarget().getType()));
3355 if (newCopyOp.getTarget() != subView)
3356 continue;
3357 if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
3358 continue;
3359 copyOp = newCopyOp;
3360 break;
3361 }
3362 }
3363 if (!copyOp)
3364 return rewriter.notifyMatchFailure(xferOp, "no copy found");
3365
3366 // Find the fill into `viewOrAlloc` without interleaved uses before the
3367 // copy.
3368 FillOp maybeFillOp;
3369 for (auto &u : viewOrAlloc.getUses()) {
3370 if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
3371 assert(isa<MemRefType>(newFillOp.output().getType()));
3372 if (newFillOp.output() != viewOrAlloc)
3373 continue;
3374 if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
3375 continue;
3376 maybeFillOp = newFillOp;
3377 break;
3378 }
3379 }
3380 // Ensure padding matches.
3381 if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value())
3382 return rewriter.notifyMatchFailure(xferOp,
3383 "padding value does not match fill");
3384
3385 // `in` is the subview that memref.copy reads. Replace it.
3386 Value in = copyOp.getSource();
3387
3388  // memref.copy + linalg.fill can be used to create a padded local buffer.
3389  // The `in_bounds` attribute is only valid on this padded buffer.
3390  // When forwarding to vector.transfer_read, the attribute must be reset
3391  // conservatively (i.e. set to all-false).
3392 auto vectorType = xferOp.getVectorType();
3393 Value res = vector::TransferReadOp::create(
3394 rewriter, xferOp.getLoc(), vectorType, in, xferOp.getIndices(),
3395 xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(),
3396 rewriter.getBoolArrayAttr(
3397 SmallVector<bool>(vectorType.getRank(), false)));
3398
3399 if (maybeFillOp)
3400 rewriter.eraseOp(maybeFillOp);
3401 rewriter.eraseOp(copyOp);
3402 rewriter.replaceOp(xferOp, res);
3403
3404 return success();
3405}
3406
3407/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
3408/// when available.
3409LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
3410    vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
3411 // TODO: support mask.
3412 if (xferOp.getMask())
3413 return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
3414
3415 // Transfer into `viewOrAlloc`.
3416 Value viewOrAlloc = xferOp.getBase();
3417 if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
3418 !viewOrAlloc.getDefiningOp<memref::AllocOp>())
3419 return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
3420
3421 // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
3422 memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
3423 if (!subViewOp)
3424 return rewriter.notifyMatchFailure(xferOp, "no subview found");
3425 Value subView = subViewOp.getResult();
3426
3427 // Find the copy from `subView` without interleaved uses.
3428 memref::CopyOp copyOp;
3429 for (auto &u : subViewOp.getResult().getUses()) {
3430 if (auto newCopyOp = dyn_cast<memref::CopyOp>(u.getOwner())) {
3431 if (newCopyOp.getSource() != subView)
3432 continue;
3433 if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
3434 continue;
3435 copyOp = newCopyOp;
3436 break;
3437 }
3438 }
3439 if (!copyOp)
3440 return rewriter.notifyMatchFailure(xferOp, "no copy found");
3441
3442  // `out` is the subview that the copy writes into; forward the write to it.
3443 assert(isa<MemRefType>(copyOp.getTarget().getType()));
3444 Value out = copyOp.getTarget();
3445
3446 // Forward vector.transfer into copy.
3447  // memref.copy + linalg.fill can be used to create a padded local buffer.
3448  // The `in_bounds` attribute is only valid on this padded buffer.
3449  // When forwarding to vector.transfer_write, the attribute must be reset
3450  // conservatively (i.e. set to all-false).
3451 auto vector = xferOp.getVector();
3452 vector::TransferWriteOp::create(
3453 rewriter, xferOp.getLoc(), vector, out, xferOp.getIndices(),
3454 xferOp.getPermutationMapAttr(), xferOp.getMask(),
3455 rewriter.getBoolArrayAttr(SmallVector<bool>(
3456 dyn_cast<VectorType>(vector.getType()).getRank(), false)));
3457
3458 rewriter.eraseOp(copyOp);
3459 rewriter.eraseOp(xferOp);
3460
3461 return success();
3462}
3463
3464//===----------------------------------------------------------------------===//
3465// Convolution vectorization patterns
3466//===----------------------------------------------------------------------===//
3467
3468template <int N>
3469static void bindShapeDims(ShapedType shapedType) {}
3470
3471template <int N, typename IntTy, typename... IntTy2>
3472static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
3473 val = shapedType.getShape()[N];
3474 bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
3475}
3476
3477/// Bind a pack of int& to the leading dimensions of shapedType.getShape().
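/// For example (illustrative), for a shapedType with shape {n, w, c}:
///   int64_t nSize, wSize, cSize;
///   bindShapeDims(shapedType, nSize, wSize, cSize); // nSize = shape[0], ...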
3478template <typename... IntTy>
3479static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
3480 bindShapeDims<0>(shapedType, vals...);
3481}
3482
3483namespace {
3484/// Generate a vector implementation for either:
3485/// ```
3486/// Op def: ( w, kw )
3487/// Iters: ({Par(), Red()})
3488/// Layout: {{w + kw}, {kw}, {w}}
3489/// ```
3490/// kw is unrolled.
3491///
3492/// or
3493///
3494/// ```
3495/// Op def: ( n, w, c, kw, f )
3496/// Iters: ({Par(), Par(), Par(), Red(), Red()})
3497/// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3498/// ```
3499/// kw is unrolled, w is unrolled iff dilationW > 1.
3500///
3501/// or
3502///
3503/// ```
3504/// Op def: ( n, c, w, f, kw )
3505/// Iters: ({Par(), Par(), Par(), Red(), Red()})
3506/// Layout: {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
3507/// ```
3508/// kw is unrolled, w is unrolled iff dilationW > 1.
3509///
3510/// or
3511///
3512/// ```
3513/// Op def: ( n, w, c, kw )
3514/// Iters: ({Par(), Par(), Par(), Red()})
3515/// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3516/// ```
3517/// kw is unrolled, w is unrolled iff dilationW > 1.
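///
/// For example (illustrative only; attributes and shapes are made up for
/// exposition), the channeled NWC case corresponds to ops such as:
/// ```
///   %0 = linalg.conv_1d_nwc_wcf
///          {dilations = dense<1> : tensor<1xi64>,
///           strides = dense<1> : tensor<1xi64>}
///          ins(%input, %filter : tensor<1x16x3xf32>, tensor<3x3x8xf32>)
///          outs(%init : tensor<1x14x8xf32>) -> tensor<1x14x8xf32>
/// ```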
3518struct Conv1DGenerator
3519 : public StructuredGenerator<LinalgOp, utils::IteratorType> {
3520 Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp)
3521 : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp) {
3522
3523 lhsShaped = linalgOp.getDpsInputOperand(0)->get();
3524 rhsShaped = linalgOp.getDpsInputOperand(1)->get();
3525 resShaped = linalgOp.getDpsInitOperand(0)->get();
3526 lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
3527 rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
3528 resShapedType = dyn_cast<ShapedType>(resShaped.getType());
3529
3530 Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
3531 redOp = reduceOp->getName().getIdentifier();
3532
3533 setConvOperationKind(reduceOp);
3534
3535 auto maybeKind = getCombinerOpKind(reduceOp);
3536 reductionKind = maybeKind.value();
3537
3538 // The ConvolutionOpInterface gives us guarantees of existence for
3539 // strides/dilations. However, we do not need to rely on those, we can
3540 // simply use them if present, otherwise use the default and let the generic
3541 // conv. matcher in the ConvGenerator succeed or fail.
3542 auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
3543 auto dilations = linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
3544 strideW = strides ? *strides.getValues<uint64_t>().begin() : 1;
3545 dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
3546 }
3547
3548 /// Generate a vector implementation for:
3549 /// ```
3550 /// Op def: ( w, kw )
3551 /// Iters: ({Par(), Red()})
3552 /// Layout: {{w + kw}, {kw}, {w}}
3553 /// ```
3554 /// kw is always unrolled.
3555 ///
3556 /// or
3557 ///
3558 /// ```
3559 /// Op def: ( n, w, c, kw, f )
3560 /// Iters: ({Par(), Par(), Par(), Red(), Red()})
3561 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
3562 /// ```
3563 /// kw is always unrolled.
3564  /// TODO: w (resp. kw) is unrolled when strideW (resp. dilationW) is
3565  /// > 1.
3566 FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) {
3567 int64_t nSize, wSize, cSize, kwSize, fSize;
3568 SmallVector<int64_t, 3> lhsShape, rhsShape, resShape;
3569 bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W);
3570 switch (conv1DOpOrder) {
3571 case Conv1DOpOrder::W:
3572 // Initialize unused dimensions
3573 nSize = fSize = cSize = 0;
3574 // out{W}
3575 bindShapeDims(resShapedType, wSize);
3576 // kernel{kw}
3577 bindShapeDims(rhsShapedType, kwSize);
3578 lhsShape = {// iw = ow + kw - 1
3579 // (i.e. 16 convolved with 3 -> 14)
3580 (wSize + kwSize - 1)};
3581 rhsShape = {kwSize};
3582 resShape = {wSize};
3583 break;
3584 case Conv1DOpOrder::Nwc:
3585 // out{n, w, f}
3586 bindShapeDims(resShapedType, nSize, wSize, fSize);
3587 switch (oper) {
3588 case ConvOperationKind::Conv:
3589 // kernel{kw, c, f}
3590 bindShapeDims(rhsShapedType, kwSize, cSize);
3591 break;
3592 case ConvOperationKind::Pool:
3593 // kernel{kw}
3594 bindShapeDims(rhsShapedType, kwSize);
3595 cSize = fSize;
3596 break;
3597 }
3598 lhsShape = {nSize,
3599 // iw = ow * sw + kw * dw - 1
3600 // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3601 // Perform the proper inclusive -> exclusive -> inclusive.
3602 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3603 1,
3604 cSize};
3605 switch (oper) {
3606 case ConvOperationKind::Conv:
3607 rhsShape = {kwSize, cSize, fSize};
3608 break;
3609 case ConvOperationKind::Pool:
3610 rhsShape = {kwSize};
3611 break;
3612 }
3613 resShape = {nSize, wSize, fSize};
3614 break;
3615 case Conv1DOpOrder::Ncw:
3616 // out{n, f, w}
3617 bindShapeDims(resShapedType, nSize, fSize, wSize);
3618 switch (oper) {
3619 case ConvOperationKind::Conv:
3620 // kernel{f, c, kw}
3621 bindShapeDims(rhsShapedType, fSize, cSize, kwSize);
3622 break;
3623 case ConvOperationKind::Pool:
3624 // kernel{kw}
3625 bindShapeDims(rhsShapedType, kwSize);
3626 cSize = fSize;
3627 break;
3628 }
3629 lhsShape = {nSize, cSize,
3630 // iw = ow * sw + kw * dw - 1
3631 // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3632 // Perform the proper inclusive -> exclusive -> inclusive.
3633 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) -
3634 1};
3635 switch (oper) {
3636 case ConvOperationKind::Conv:
3637 rhsShape = {fSize, cSize, kwSize};
3638 break;
3639 case ConvOperationKind::Pool:
3640 rhsShape = {kwSize};
3641 break;
3642 }
3643 resShape = {nSize, fSize, wSize};
3644 break;
3645 }
3646
3647 vector::TransferWriteOp write;
3648 Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
3649
3650 // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3651 // When strideW == 1, we can batch the contiguous loads and avoid
3652 // unrolling
3653 int64_t wSizeStep = strideW == 1 ? wSize : 1;
3654
3655 Type lhsEltType = lhsShapedType.getElementType();
3656 Type rhsEltType = rhsShapedType.getElementType();
3657 Type resEltType = resShapedType.getElementType();
3658 auto lhsType = VectorType::get(lhsShape, lhsEltType);
3659 auto rhsType = VectorType::get(rhsShape, rhsEltType);
3660 auto resType = VectorType::get(resShape, resEltType);
3661 // Zero padding with the corresponding dimensions for lhs, rhs and res.
3662 SmallVector<Value> lhsPadding(lhsShape.size(), zero);
3663 SmallVector<Value> rhsPadding(rhsShape.size(), zero);
3664 SmallVector<Value> resPadding(resShape.size(), zero);
3665
3666 // Read the whole lhs, rhs and res in one shot (with zero padding).
3667 Value lhs = vector::TransferReadOp::create(
3668 rewriter, loc, lhsType, lhsShaped, lhsPadding,
3669 /*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
3670 // This is needed only for Conv.
3671 Value rhs = nullptr;
3672 if (oper == ConvOperationKind::Conv)
3673 rhs = vector::TransferReadOp::create(
3674 rewriter, loc, rhsType, rhsShaped, rhsPadding,
3675 /*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
3676 Value res = vector::TransferReadOp::create(
3677 rewriter, loc, resType, resShaped, resPadding,
3678 /*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
3679
3680 // The base vectorization case for channeled convolution is input:
3681 // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
3682    // vectorization case, we pre-transpose the input, weight, and output.
3683 switch (conv1DOpOrder) {
3684 case Conv1DOpOrder::W:
3685 case Conv1DOpOrder::Nwc:
3686 // Base case, so no transposes necessary.
3687 break;
3688 case Conv1DOpOrder::Ncw: {
3689 // To match base vectorization case, we pre-transpose current case.
3690 // ncw -> nwc
3691 static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1};
3692 lhs = vector::TransposeOp::create(rewriter, loc, lhs, permLhs);
3693 // fcw -> wcf
3694 static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0};
3695
3696 // This is needed only for Conv.
3697 if (oper == ConvOperationKind::Conv)
3698 rhs = vector::TransposeOp::create(rewriter, loc, rhs, permRhs);
3699 // nfw -> nwf
3700 static constexpr std::array<int64_t, 3> permRes = {0, 2, 1};
3701 res = vector::TransposeOp::create(rewriter, loc, res, permRes);
3702 break;
3703 }
3704 }
3705
3706 //===------------------------------------------------------------------===//
3707 // Begin vector-only rewrite part
3708 //===------------------------------------------------------------------===//
3709 // Unroll along kw and read slices of lhs and rhs.
3710 SmallVector<Value> lhsVals, rhsVals, resVals;
3711 lhsVals = extractConvInputSlices(rewriter, loc, lhs, nSize, wSize, cSize,
3712 kwSize, strideW, dilationW, wSizeStep,
3713 isSingleChanneled);
3714 // Do not do for pooling.
3715 if (oper == ConvOperationKind::Conv)
3716 rhsVals = extractConvFilterSlices(rewriter, loc, rhs, kwSize);
3717 resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize,
3718 wSizeStep, isSingleChanneled);
3719
3720 auto linearIndex = [&](int64_t kw, int64_t w) {
3721 return kw * (wSize / wSizeStep) + w;
3722 };
3723
3724    // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f},
3725    // or perform an outer product for non-channeled convolution, or a simple
3726    // arith operation for pooling.
3727 for (int64_t kw = 0; kw < kwSize; ++kw) {
3728 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3729 switch (oper) {
3730 case ConvOperationKind::Conv:
3731 if (isSingleChanneled) {
3732 resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc,
3733 lhsVals[linearIndex(kw, w)],
3734 rhsVals[kw], resVals[w]);
3735 } else {
3736 resVals[w] = conv1dSliceAsContraction(rewriter, loc,
3737 lhsVals[linearIndex(kw, w)],
3738 rhsVals[kw], resVals[w]);
3739 }
3740 break;
3741 case ConvOperationKind::Pool:
3742 resVals[w] = pool1dSlice(rewriter, loc, lhsVals[linearIndex(kw, w)],
3743 resVals[w]);
3744 break;
3745 }
3746 }
3747 }
3748
3749 res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals,
3750 isSingleChanneled);
3751 //===------------------------------------------------------------------===//
3752 // End vector-only rewrite part
3753 //===------------------------------------------------------------------===//
3754
3755 // The base vectorization case for channeled convolution is output:
3756    // {n,w,f}. To reuse the result from the base pattern vectorization case,
3757    // we post-transpose the base case result.
3758 switch (conv1DOpOrder) {
3759 case Conv1DOpOrder::W:
3760 case Conv1DOpOrder::Nwc:
3761 // Base case, so no transposes necessary.
3762 break;
3763 case Conv1DOpOrder::Ncw: {
3764 // nwf -> nfw
3765 static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
3766 res = vector::TransposeOp::create(rewriter, loc, res, perm);
3767 break;
3768 }
3769 }
3770
3771 return vector::TransferWriteOp::create(rewriter, loc, res, resShaped,
3772 resPadding)
3773 .getOperation();
3774 }
3775
3776 // Take a value and widen to have the same element type as `ty`.
3777 Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) {
3778 const Type srcElementType = getElementTypeOrSelf(val.getType());
3779 const Type dstElementType = getElementTypeOrSelf(ty);
3780 assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType));
3781 if (srcElementType == dstElementType)
3782 return val;
3783
3784 const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth();
3785 const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth();
3786 const Type dstType =
3787 cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType);
3788
3789 if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) {
3790 return arith::SIToFPOp::create(rewriter, loc, dstType, val);
3791 }
3792
3793 if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) &&
3794 srcWidth < dstWidth)
3795 return arith::ExtFOp::create(rewriter, loc, dstType, val);
3796
3797 if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) &&
3798 srcWidth < dstWidth)
3799 return arith::ExtSIOp::create(rewriter, loc, dstType, val);
3800
3801 assert(false && "unhandled promotion case");
3802 return nullptr;
3803 }
3804
3805 // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
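  // The generated op is roughly of the following shape (illustrative sketch;
  // the shapes are made up and the combining kind depends on `reductionKind`):
  //   %o = vector.contract {
  //          indexing_maps = [affine_map<(n, w, f, c) -> (n, w, c)>,
  //                           affine_map<(n, w, f, c) -> (c, f)>,
  //                           affine_map<(n, w, f, c) -> (n, w, f)>],
  //          iterator_types = ["parallel", "parallel", "parallel", "reduction"],
  //          kind = #vector.kind<add>}
  //        %lhs, %rhs, %acc
  //        : vector<1x4x3xf32>, vector<3x8xf32> into vector<1x4x8xf32>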
3806 Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc,
3807 Value lhs, Value rhs, Value res) {
3808 vector::IteratorType par = vector::IteratorType::parallel;
3809 vector::IteratorType red = vector::IteratorType::reduction;
3810 AffineExpr n, w, f, c;
3811 bindDims(ctx, n, w, f, c);
3812 lhs = promote(rewriter, loc, lhs, res.getType());
3813 rhs = promote(rewriter, loc, rhs, res.getType());
3814    auto contractionOp = vector::ContractionOp::create(
3815        rewriter, loc, lhs, rhs, res,
3816        /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}},
3817        /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red});
3818    contractionOp.setKind(reductionKind);
3819    return contractionOp;
3820 }
3821
3822 // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel
3823 // convolution.
3824 Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc,
3825 Value lhs, Value rhs, Value res) {
3826 return vector::OuterProductOp::create(rewriter, loc, res.getType(), lhs,
3827 rhs, res, vector::CombiningKind::ADD);
3828 }
3829
3830 // Create a reduction: lhs{n, w, c} -> res{n, w, c}
3831 Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs,
3832 Value res) {
3833 if (isPoolExt)
3834 lhs = rewriter.create(loc, poolExtOp, lhs, res.getType())->getResult(0);
3835 return rewriter
3836 .create(loc, redOp, ArrayRef<Value>{lhs, res}, res.getType())
3837 ->getResult(0);
3838 }
3839
3840 /// Generate a vector implementation for:
3841 /// ```
3842 /// Op def: ( n, w, c, kw)
3843 /// Iters: ({Par(), Par(), Par(), Red()})
3844 /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3845 /// ```
3846 /// kw is always unrolled.
3847  /// TODO: w (resp. kw) is unrolled when strideW (resp. dilationW) is
3848  /// > 1.
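  ///
  /// For example (illustrative only; shapes are made up for exposition), this
  /// handles ops such as:
  /// ```
  ///   %0 = linalg.depthwise_conv_1d_nwc_wc
  ///          {dilations = dense<1> : tensor<1xi64>,
  ///           strides = dense<1> : tensor<1xi64>}
  ///          ins(%input, %filter : tensor<1x10x8xf32>, tensor<3x8xf32>)
  ///          outs(%init : tensor<1x8x8xf32>) -> tensor<1x8x8xf32>
  /// ```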
3849 FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
3850 bool channelDimScalableFlag,
3851 bool flatten) {
3852 bool scalableChDim = false;
3853 bool useMasking = false;
3854 int64_t nSize, wSize, cSize, kwSize;
3855 // kernel{kw, c}
3856 bindShapeDims(rhsShapedType, kwSize, cSize);
3857 if (ShapedType::isDynamic(cSize)) {
3858 assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
3859 cSize = channelDimVecSize;
3860 // Scalable vectors are only used when both conditions are met:
3861 // 1. channel dim is dynamic
3862 // 2. channelDimScalableFlag is set
3863 scalableChDim = channelDimScalableFlag;
3864 useMasking = true;
3865 }
3866
3867 assert(!(useMasking && flatten) &&
3868 "Unsupported flattened conv with dynamic shapes");
3869
3870 // out{n, w, c}
3871 bindShapeDims(resShapedType, nSize, wSize);
3872
3873 vector::TransferWriteOp write;
3874 Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
3875
3876 // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
3877 // When strideW == 1, we can batch the contiguous loads and avoid
3878 // unrolling
3879 int64_t wSizeStep = strideW == 1 ? wSize : 1;
3880
3881 Type lhsEltType = lhsShapedType.getElementType();
3882 Type rhsEltType = rhsShapedType.getElementType();
3883 Type resEltType = resShapedType.getElementType();
3884 VectorType lhsType = VectorType::get(
3885 {nSize,
3886 // iw = ow * sw + kw * dw - 1
3887 // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
3888 ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
3889 cSize},
3890 lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
3891 VectorType rhsType =
3892 VectorType::get({kwSize, cSize}, rhsEltType,
3893 /*scalableDims=*/{false, scalableChDim});
3894 VectorType resType =
3895 VectorType::get({nSize, wSize, cSize}, resEltType,
3896 /*scalableDims=*/{false, false, scalableChDim});
3897
3898 // Masks the input xfer Op along the channel dim, iff the corresponding
3899 // scalable flag is set.
3900 auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
3901 ArrayRef<bool> scalableDims,
3902 Operation *opToMask) {
3903 if (!useMasking)
3904 return opToMask;
3905 auto maskType =
3906 VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
3907
3908 SmallVector<bool> inBounds(maskShape.size(), true);
3909 auto xferOp = cast<VectorTransferOpInterface>(opToMask);
3910 xferOp->setAttr(xferOp.getInBoundsAttrName(),
3911 rewriter.getBoolArrayAttr(inBounds));
3912
3913 SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer(
3914 cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter);
3915
3916 Value maskOp =
3917 vector::CreateMaskOp::create(rewriter, loc, maskType, mixedDims);
3918
3919 return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
3920 };
3921
3922 // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
3923 // 0].
3924 Value lhs = vector::TransferReadOp::create(
3925 rewriter, loc, lhsType, lhsShaped, ValueRange{zero, zero, zero},
3926 /*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
3927 auto *maybeMaskedLhs = maybeMaskXferOp(
3928 lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
3929
3930 // Read rhs slice of size {kw, c} @ [0, 0].
3931 Value rhs = vector::TransferReadOp::create(
3932 rewriter, loc, rhsType, rhsShaped, ValueRange{zero, zero},
3933 /*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
3934 auto *maybeMaskedRhs = maybeMaskXferOp(
3935 rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
3936
3937 // Read res slice of size {n, w, c} @ [0, 0, 0].
3938 Value res = vector::TransferReadOp::create(
3939 rewriter, loc, resType, resShaped, ValueRange{zero, zero, zero},
3940 /*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
3941 auto *maybeMaskedRes = maybeMaskXferOp(
3942 resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
3943
3944 //===------------------------------------------------------------------===//
3945 // Begin vector-only rewrite part
3946 //===------------------------------------------------------------------===//
3947 // Unroll along kw and read slices of lhs and rhs.
3948 SmallVector<Value> lhsVals, rhsVals, resVals;
3949 SmallVector<int64_t> inOutSliceSizes = {nSize, wSizeStep, cSize};
3950 SmallVector<int64_t> inOutStrides = {1, 1, 1};
3951
3952 // Extract lhs slice of size {n, wSizeStep, c}
3953 // @ [0, sw * w + dw * kw, 0].
3954 for (int64_t kw = 0; kw < kwSize; ++kw) {
3955 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3956 lhsVals.push_back(vector::ExtractStridedSliceOp::create(
3957 rewriter, loc, maybeMaskedLhs->getResult(0),
3958 /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
3959 inOutSliceSizes, inOutStrides));
3960 }
3961 }
3962 // Extract rhs slice of size {c} @ [kw].
3963 for (int64_t kw = 0; kw < kwSize; ++kw) {
3964 rhsVals.push_back(
3965 vector::ExtractOp::create(rewriter, loc, maybeMaskedRhs->getResult(0),
3966 /*offsets=*/ArrayRef<int64_t>{kw}));
3967 }
3968 // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
3969 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3970 resVals.push_back(vector::ExtractStridedSliceOp::create(
3971 rewriter, loc, maybeMaskedRes->getResult(0),
3972 /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
3973 inOutStrides));
3974 }
3975
3976 auto linearIndex = [&](int64_t kw, int64_t w) {
3977 return kw * (wSize / wSizeStep) + w;
3978 };
3979
3980 // Note - the scalable flags are ignored as flattening combined with
3981 // scalable vectorization is not supported.
3982 SmallVector<int64_t> inOutFlattenSliceSizes = {nSize, wSizeStep * cSize};
3983 auto lhsTypeAfterFlattening =
3984 VectorType::get(inOutFlattenSliceSizes, lhsEltType);
3985 auto resTypeAfterFlattening =
3986 VectorType::get(inOutFlattenSliceSizes, resEltType);
3987
3988 // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
3989 for (int64_t kw = 0; kw < kwSize; ++kw) {
3990 for (int64_t w = 0; w < wSize; w += wSizeStep) {
3991 Value lhsVal = lhsVals[linearIndex(kw, w)];
3992 Value resVal = resVals[w];
3993 if (flatten) {
3994 // Flatten the input and output vectors (collapse the channel
3995 // dimension)
3996 lhsVal =
3997 vector::ShapeCastOp::create(rewriter, loc, lhsTypeAfterFlattening,
3998 lhsVals[linearIndex(kw, w)]);
3999 resVal = vector::ShapeCastOp::create(
4000 rewriter, loc, resTypeAfterFlattening, resVals[w]);
4001 }
4002 resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
4003 rhsVals[kw], resVal, flatten);
4004 if (flatten) {
4005 // Un-flatten the output vector (restore the channel dimension)
4006 resVals[w] = vector::ShapeCastOp::create(
4007 rewriter, loc, VectorType::get(inOutSliceSizes, resEltType),
4008 resVals[w]);
4009 }
4010 }
4011 }
4012
4013    // It's possible we failed to create the FMA.
4014 if (!llvm::all_of(resVals, [](Value v) { return v; })) {
4015 // Manually revert (in reverse order) to avoid leaving a bad IR state.
4016 for (auto &collection :
4017 {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
4018 for (Value v : collection)
4019 rewriter.eraseOp(v.getDefiningOp());
4020 return rewriter.notifyMatchFailure(op, "failed to create FMA");
4021 }
4022
4023 // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
4024 // This does not depend on kw.
4025 for (int64_t w = 0; w < wSize; w += wSizeStep) {
4026 maybeMaskedRes = vector::InsertStridedSliceOp::create(
4027 rewriter, loc, resVals[w], maybeMaskedRes->getResult(0),
4028 /*offsets=*/ArrayRef<int64_t>{0, w, 0},
4029 /*strides=*/ArrayRef<int64_t>{1, 1, 1});
4030 }
4031 //===------------------------------------------------------------------===//
4032 // End vector-only rewrite part
4033 //===------------------------------------------------------------------===//
4034
4035 // Write back res slice of size {n, w, c} @ [0, 0, 0].
4036 Operation *resOut = vector::TransferWriteOp::create(
4037 rewriter, loc, maybeMaskedRes->getResult(0), resShaped,
4038 ValueRange{zero, zero, zero});
4039 return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(),
4040 resOut);
4041 }
4042
4043 /// Lower:
4044 /// * lhs{n, w, c} * rhs{c} -> res{n, w, c} (flatten = false)
4045 /// * lhs{n, w * c} * rhs{c} -> res{n, w * c} (flatten = true)
4046 /// to MulAcc.
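  ///
  /// For floats, and in the non-flattened case, this amounts to roughly
  /// (illustrative sketch; shapes are made up for exposition):
  /// ```
  ///   %b = vector.broadcast %filter : vector<4xf32> to vector<1x8x4xf32>
  ///   %r = vector.fma %input, %b, %acc : vector<1x8x4xf32>
  /// ```
  /// For integer element types, an `arith.muli` + `arith.addi` pair is
  /// emitted instead.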
4047 Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
4048 Value lhs, Value rhs, Value res,
4049 bool flatten) {
4050 auto rhsTy = cast<ShapedType>(rhs.getType());
4051 auto resTy = cast<ShapedType>(res.getType());
4052
4053 // TODO(suderman): Change this to use a vector.ima intrinsic.
4054 lhs = promote(rewriter, loc, lhs, resTy);
4055
4056 if (flatten) {
4057      // NOTE: The following logic won't work for scalable vectors. For this
4058 // reason, "flattening" is not supported when shapes are dynamic (this
4059 // should be captured by one of the pre-conditions).
4060
4061 // There are two options for handling the filter:
4062 // * shape_cast(broadcast(filter))
4063 // * broadcast(shuffle(filter))
4064 // Opt for the option without shape_cast to simplify the codegen.
4065 auto rhsSize = cast<VectorType>(rhs.getType()).getShape()[0];
4066 auto resSize = cast<VectorType>(res.getType()).getShape()[1];
4067
4068 SmallVector<int64_t, 16> indices;
4069 for (int i = 0; i < resSize / rhsSize; ++i) {
4070 for (int j = 0; j < rhsSize; ++j)
4071 indices.push_back(j);
4072 }
4073
4074 rhs = vector::ShuffleOp::create(rewriter, loc, rhs, rhs, indices);
4075 }
4076 // Broadcast the filter to match the output vector
4077 rhs = vector::BroadcastOp::create(rewriter, loc,
4078 resTy.clone(rhsTy.getElementType()), rhs);
4079
4080 rhs = promote(rewriter, loc, rhs, resTy);
4081
4082 if (!lhs || !rhs)
4083 return nullptr;
4084
4085 if (isa<FloatType>(resTy.getElementType()))
4086 return vector::FMAOp::create(rewriter, loc, lhs, rhs, res);
4087
4088 auto mul = arith::MulIOp::create(rewriter, loc, lhs, rhs);
4089 return arith::AddIOp::create(rewriter, loc, mul, res);
4090 }
4091
4092 /// Entry point for non-channeled convolution:
4093 /// {{w + kw}, {kw}, {w}}
4094 FailureOr<Operation *> generateNonChanneledConv() {
4095 AffineExpr w, kw;
4096 bindDims(ctx, w, kw);
4097 if (!iters({Par(), Red()}))
4098 return rewriter.notifyMatchFailure(op,
4099 "failed to match conv::W 1-par 1-red");
4100
4101 // No transposition needed.
4102 if (layout({/*lhsIndex*/ {w + kw},
4103 /*rhsIndex*/ {kw},
4104 /*resIndex*/ {w}}))
4105 return conv(Conv1DOpOrder::W);
4106
4107 return rewriter.notifyMatchFailure(op, "not a conv::W layout");
4108 }
4109
4110 /// Entry point that transposes into the common form:
4111 /// {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
4112 FailureOr<Operation *> generateNwcConv() {
4113 AffineExpr n, w, f, kw, c;
4114 bindDims(ctx, n, w, f, kw, c);
4115 if (!iters({Par(), Par(), Par(), Red(), Red()}))
4116 return rewriter.notifyMatchFailure(
4117 op, "failed to match conv::Nwc 3-par 2-red");
4118
4119 // No transposition needed.
4120 if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
4121 /*rhsIndex*/ {kw, c, f},
4122 /*resIndex*/ {n, w, f}}))
4123 return conv(Conv1DOpOrder::Nwc);
4124
4125 return rewriter.notifyMatchFailure(op, "not a conv::Nwc layout");
4126 }
4127
4128 /// Entry point that transposes into the common form:
4129 /// {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}}
4130 FailureOr<Operation *> generateNcwConv() {
4131 AffineExpr n, w, f, kw, c;
4132 bindDims(ctx, n, f, w, c, kw);
4133 if (!iters({Par(), Par(), Par(), Red(), Red()}))
4134 return rewriter.notifyMatchFailure(
4135 op, "failed to match conv::Ncw 3-par 2-red");
4136
4137 if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
4138 /*rhsIndex*/ {f, c, kw},
4139 /*resIndex*/ {n, f, w}}))
4140 return conv(Conv1DOpOrder::Ncw);
4141
4142 return rewriter.notifyMatchFailure(op, "not a conv::Ncw layout");
4143 }
4144
4145 /// Entry point that transposes into the common form:
4146 /// {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}} for pooling
4147 FailureOr<Operation *> generateNwcPooling() {
4148 AffineExpr n, w, c, kw;
4149 bindDims(ctx, n, w, c, kw);
4150 if (!iters({Par(), Par(), Par(), Red()}))
4151 return rewriter.notifyMatchFailure(op,
4152 "failed to match pooling 3-par 1-red");
4153
4154 // No transposition needed.
4155 if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
4156 /*rhsIndex*/ {kw},
4157 /*resIndex*/ {n, w, c}}))
4158 return conv(Conv1DOpOrder::Nwc);
4159
4160 return rewriter.notifyMatchFailure(op, "not a pooling::Nwc layout");
4161 }
4162
4163 /// Entry point that transposes into the common form:
4164 /// {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}} for pooling
4165 FailureOr<Operation *> generateNcwPooling() {
4166 AffineExpr n, w, c, kw;
4167 bindDims(ctx, n, c, w, kw);
4168 if (!iters({Par(), Par(), Par(), Red()}))
4169 return rewriter.notifyMatchFailure(op,
4170 "failed to match pooling 3-par 1-red");
4171
4172 if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw},
4173 /*rhsIndex*/ {kw},
4174 /*resIndex*/ {n, c, w}}))
4175 return conv(Conv1DOpOrder::Ncw);
4176
4177 return rewriter.notifyMatchFailure(op, "not a pooling::Ncw layout");
4178 }
4179
4180 /// Entry point that transposes into the common form:
4181 /// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
4182 FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
4183 bool vecChDimScalableFlag = false,
4184 bool flatten = false) {
4185 AffineExpr n, w, c, kw;
4186 bindDims(ctx, n, w, c, kw);
4187 if (!iters({Par(), Par(), Par(), Red()}))
4188 return rewriter.notifyMatchFailure(
4189 op, "failed to match depthwise::Nwc conv 3-par 1-red");
4190
4191 // No transposition needed.
4192 if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
4193 /*rhsIndex*/ {kw, c},
4194 /*resIndex*/ {n, w, c}}))
4195 return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);
4196
4197 return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
4198 }
4199
4200private:
4201 ConvOperationKind oper = ConvOperationKind::Conv;
4202 StringAttr redOp;
4203 StringAttr poolExtOp;
4204 bool isPoolExt = false;
4205 int strideW, dilationW;
4206 Value lhsShaped, rhsShaped, resShaped;
4207 ShapedType lhsShapedType, rhsShapedType, resShapedType;
4208 vector::CombiningKind reductionKind;
4209
4210 // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops.
4211 void setConvOperationKind(Operation *reduceOp) {
4212 int numBlockArguments =
4213 llvm::count_if(reduceOp->getOperands(), llvm::IsaPred<BlockArgument>);
4214 if (numBlockArguments == 1) {
4215      // Will be a convolution if the feeder is a MulOp.
4216      // A strength-reduced version of MulOp for the i1 type is AndOp, which is
4217      // also supported. Otherwise, it can be pooling. This strength-reduction
4218      // logic is in the `buildBinaryFn` helper in the Linalg dialect.
4219 auto feedValIt = llvm::find_if_not(reduceOp->getOperands(),
4220 llvm::IsaPred<BlockArgument>);
4221 Operation *feedOp = (*feedValIt).getDefiningOp();
4222 if (isCastOfBlockArgument(feedOp)) {
4223 oper = ConvOperationKind::Pool;
4224 isPoolExt = true;
4225 poolExtOp = feedOp->getName().getIdentifier();
4226 return;
4227 }
4228 oper = ConvOperationKind::Conv;
4229 return;
4230 }
4231    // numBlockArguments == 2 and this is a pooling op.
4232 oper = ConvOperationKind::Pool;
4233 isPoolExt = false;
4234 }
4235};
4236} // namespace
4237
4238/// Helper function to vectorize a LinalgOp with convolution semantics.
4239// TODO: extend the generic vectorization to support windows and drop this.
4240static FailureOr<Operation *> vectorizeConvolution(
4241 RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
4242 ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
4243 Conv1DGenerator conv1dGen(rewriter, op);
4244 auto res = conv1dGen.generateNonChanneledConv();
4245 if (succeeded(res))
4246 return res;
4247 res = conv1dGen.generateNwcConv();
4248 if (succeeded(res))
4249 return res;
4250 res = conv1dGen.generateNcwConv();
4251 if (succeeded(res))
4252 return res;
4253 res = conv1dGen.generateNwcPooling();
4254 if (succeeded(res))
4255 return res;
4256 res = conv1dGen.generateNcwPooling();
4257 if (succeeded(res))
4258 return res;
4259
4260 // Only depthwise 1D NWC convs are left - these can be vectorized using masks
4261 // and scalable vectors. Note that ATM the only dim that can be dynamic (i.e.
4262 // masked/scalable) is the channel dim (i.e. the trailing dim).
4263 uint64_t vecChDimSize = ShapedType::kDynamic;
4264 bool vecChDimScalableFlag = false;
4265 if (!inputVecSizes.empty()) {
4266 // Only use the input vector size corresponding to the channel dim. Other
4267 // vector dims will be inferred from the Ops.
4268 assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
4269 isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
4270 "Not a 1D depthwise conv!");
4271 size_t chDimIdx =
4272        TypeSwitch<Operation *, size_t>(*op)
4273            .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
4274 .Case<linalg::DepthwiseConv1DNcwCwOp>([](auto conv) { return 1; });
4275
4276 vecChDimSize = inputVecSizes[chDimIdx];
4277 vecChDimScalableFlag = inputScalableVecDims[chDimIdx];
4278 }
4279 return conv1dGen.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
4280 flatten1DDepthwiseConv);
4281}
4282
4283struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
4284  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
4285
4286 LogicalResult matchAndRewrite(LinalgOp op,
4287 PatternRewriter &rewriter) const override {
4288 FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op);
4289 if (failed(resultOrFail))
4290 return failure();
4291 Operation *newOp = *resultOrFail;
4292 if (newOp->getNumResults() == 0) {
4293 rewriter.eraseOp(op.getOperation());
4294 return success();
4295 }
4296 assert(newOp->getNumResults() == 1 && "expected single result");
4297 rewriter.replaceOp(op.getOperation(), newOp->getResult(0));
4298 return success();
4299 }
4300};
4301
4302 void mlir::linalg::populateConvolutionVectorizationPatterns(
4303 RewritePatternSet &patterns, PatternBenefit benefit) {
4304 patterns.add<VectorizeConvolution>(patterns.getContext(), benefit);
4305}
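// Illustrative usage sketch (not part of this file; `funcOp` and the greedy
// driver call are assumptions based on typical MLIR pattern-driver usage):
//
//   RewritePatternSet patterns(funcOp->getContext());
//   linalg::populateConvolutionVectorizationPatterns(patterns);
//   if (failed(applyPatternsGreedily(funcOp, std::move(patterns))))
//     return signalPassFailure();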