MLIR 23.0.0git
VectorUnroll.cpp
Go to the documentation of this file.
1//===- VectorUnrollDistribute.cpp - patterns to do vector unrolling -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to do vector unrolling and vector distribution.
10//
11//===----------------------------------------------------------------------===//
12
18#include "llvm/ADT/MapVector.h"
19#include "llvm/ADT/STLExtras.h"
20#include "llvm/Support/DebugLog.h"
21#include "llvm/Support/InterleavedRange.h"
22#include <optional>
23
24#define DEBUG_TYPE "vector-unroll"
25
26using namespace mlir;
27using namespace mlir::vector;
28
29/// Compute the indices of the slice `index` for a transfer op.
32 AffineMap permutationMap,
33 Location loc,
34 OpBuilder &builder) {
35 MLIRContext *ctx = builder.getContext();
36 auto isBroadcast = [](AffineExpr expr) {
37 if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
38 return constExpr.getValue() == 0;
39 return false;
40 };
41 // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
42 SmallVector<Value> slicedIndices(indices);
43 for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
44 if (isBroadcast(dim.value()))
45 continue;
46 unsigned pos = cast<AffineDimExpr>(dim.value()).getPosition();
47 auto expr = getAffineDimExpr(0, builder.getContext()) +
48 getAffineConstantExpr(elementOffsets[dim.index()], ctx);
49 auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
50 slicedIndices[pos] =
51 affine::AffineApplyOp::create(builder, loc, map, indices[pos]);
52 }
53 return slicedIndices;
54}
55
56// Compute the new indices by adding `offsets` to `originalIndices`.
57// If m < n (m = offsets.size(), n = originalIndices.size()),
58// then only the trailing m values in `originalIndices` are updated.
60 Location loc,
61 OperandRange originalIndices,
62 ArrayRef<int64_t> offsets) {
63 assert(offsets.size() <= originalIndices.size() &&
64 "Offsets should not exceed the number of original indices");
65 SmallVector<Value> indices(originalIndices);
66
67 auto start = indices.size() - offsets.size();
68 for (auto [i, offset] : llvm::enumerate(offsets)) {
69 if (offset != 0) {
70 indices[start + i] = arith::AddIOp::create(
71 rewriter, loc, originalIndices[start + i],
72 arith::ConstantIndexOp::create(rewriter, loc, offset));
73 }
74 }
75 return indices;
76}
77
78// Clones `op` into a new operations that takes `operands` and returns
79// `resultTypes`.
81 Operation *op,
82 ArrayRef<Value> operands,
83 ArrayRef<Type> resultTypes) {
84 return builder.create(loc, op->getName().getIdentifier(), operands,
85 resultTypes, op->getAttrs());
86}
87
88/// Return the target shape for unrolling for the given `op`. Return
89/// std::nullopt if the op shouldn't be or cannot be unrolled.
90static std::optional<SmallVector<int64_t>>
92 LDBG() << "Get unroll shape for op " << op->getName().getStringRef();
93 if (options.filterConstraint && failed(options.filterConstraint(op))) {
94 LDBG() << "--no filter constraint -> BAIL";
95 return std::nullopt;
96 }
97 assert(options.nativeShape &&
98 "vector unrolling expects the native shape or native"
99 "shape call back function to be set");
100 auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
101 if (!unrollableVectorOp) {
102 LDBG() << "--not an unrollable op -> BAIL";
103 return std::nullopt;
104 }
105 auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
106 if (!maybeUnrollShape) {
107 LDBG() << "--could not get shape of op " << *op << " -> BAIL";
108 return std::nullopt;
109 }
110 LDBG() << "--vector op shape: " << llvm::interleaved(*maybeUnrollShape);
111
112 std::optional<SmallVector<int64_t>> targetShape = options.nativeShape(op);
113 if (!targetShape) {
114 LDBG() << "--no unrolling target shape defined " << *op << "-> SKIP";
115 return std::nullopt;
116 }
117 LDBG() << "--target shape: " << llvm::interleaved(*targetShape);
118
119 auto maybeShapeRatio = computeShapeRatio(*maybeUnrollShape, *targetShape);
120 if (!maybeShapeRatio) {
121 LDBG() << "--could not compute integral shape ratio -> BAIL";
122 return std::nullopt;
123 }
124 if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
125 LDBG() << "--no unrolling needed -> SKIP";
126 return std::nullopt;
127 }
128 LDBG() << "--found an integral shape ratio to unroll to -> SUCCESS";
129 return targetShape;
130}
131
133getUnrollOrder(unsigned numLoops, Operation *op,
135 SmallVector<int64_t> loopOrder =
136 llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));
137 if (options.traversalOrderCallback != nullptr) {
138 std::optional<SmallVector<int64_t>> order =
139 options.traversalOrderCallback(op);
140 if (order) {
141 loopOrder = std::move(*order);
142 }
143 }
144 return loopOrder;
145}
146
147namespace {
148
/// Unrolls `vector.transfer_read` into reads of the unroll target shape: one
/// read per tile, each inserted into a zero-initialized aggregate result.
struct UnrollTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  UnrollTransferReadPattern(MLIRContext *context,
                            const vector::UnrollVectorOptions &options,
                            PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (readOp.getTransferRank() == 0)
      return failure();
    // Masked transfers are not supported.
    if (readOp.getMask())
      return failure();
    auto targetShape = getTargetShape(options, readOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = readOp.getVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = readOp.getLoc();
    ArrayRef<int64_t> originalSize = sourceVectorType.getShape();

    // Prepare the result vector;
    Value result =
        arith::ConstantOp::create(rewriter, loc, sourceVectorType,
                                  rewriter.getZeroAttr(sourceVectorType));
    auto targetType =
        VectorType::get(*targetShape, sourceVectorType.getElementType());
    SmallVector<Value> originalIndices(readOp.getIndices().begin(),
                                       readOp.getIndices().end());
    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), readOp, options);
    // Emit one transfer_read per tile and insert it at the tile's offsets.
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      // Shift the transfer's indices by the tile offsets (broadcast dims are
      // left unchanged by sliceTransferIndices).
      SmallVector<Value> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               readOp.getPermutationMap(), loc, rewriter);
      auto slicedRead = vector::TransferReadOp::create(
          rewriter, loc, targetType, readOp.getBase(), indices,
          readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
          readOp.getInBoundsAttr());

      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, slicedRead, result, elementOffsets, strides);
    }
    rewriter.replaceOp(readOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
202
/// Unrolls `vector.transfer_write` into writes of the unroll target shape.
/// For tensor destinations, each unrolled write is chained through the result
/// of the previous one so the final tensor value replaces the original op.
struct UnrollTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  UnrollTransferWritePattern(MLIRContext *context,
                             const vector::UnrollVectorOptions &options,
                             PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (writeOp.getTransferRank() == 0)
      return failure();

    // Masked transfers are not supported.
    if (writeOp.getMask())
      return failure();
    auto targetShape = getTargetShape(options, writeOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = writeOp.getVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = writeOp.getLoc();
    ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
    // Bail-out if rank(source) != rank(target). The main limitation here is the
    // fact that `ExtractStridedSlice` requires the rank for the input and
    // output to match. If needed, we can relax this later.
    if (originalSize.size() != targetShape->size())
      return rewriter.notifyMatchFailure(
          writeOp,
          "expected source input vector rank to match target shape rank");

    SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
                                       writeOp.getIndices().end());
    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), writeOp, options);
    // Tensor destination produced by the previous unrolled write; stays null
    // for memref destinations.
    Value resultTensor;
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      // Extract the tile of the source vector to write.
      Value slicedVector = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
      // Shift the transfer's indices by the tile offsets.
      SmallVector<Value> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               writeOp.getPermutationMap(), loc, rewriter);
      Operation *slicedWrite = vector::TransferWriteOp::create(
          rewriter, loc, slicedVector,
          resultTensor ? resultTensor : writeOp.getBase(), indices,
          writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
      // For the tensor case update the destination for the next transfer write.
      if (!slicedWrite->getResults().empty())
        resultTensor = slicedWrite->getResult(0);
    }
    if (resultTensor)
      rewriter.replaceOp(writeOp, resultTensor);
    else
      rewriter.eraseOp(writeOp);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
264
/// DenseMap key info for offset vectors (SmallVector<int64_t>), used by the
/// patterns below to cache per-tile accumulator values keyed by destination
/// offsets.
struct OffsetMapInfo {
  // Reserved key that can never collide with a real offset vector.
  static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }

  // Reserved key marking erased map entries.
  static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }

  static unsigned getHashValue(const SmallVector<int64_t> &v) {
    return static_cast<unsigned>(llvm::hash_combine_range(v));
  }

  static bool isEqual(const SmallVector<int64_t> &lhs,
                      const SmallVector<int64_t> &rhs) {
    return lhs == rhs;
  }
};
279
/// Unrolls `vector.contract` by tiling its iteration space with the target
/// shape. Accumulator tiles are cached in `accCache`, keyed by their offsets
/// in the destination, so successive reduction iterations keep updating the
/// same tile; the tiles are reassembled into the full result at the end.
struct UnrollContractionPattern
    : public OpRewritePattern<vector::ContractionOp> {
  UnrollContractionPattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options,
                           PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, contractOp);
    if (!targetShape)
      return failure();
    auto dstVecType = cast<VectorType>(contractOp.getResultType());
    SmallVector<int64_t> originalSize = *contractOp.getShapeForUnroll();

    Location loc = contractOp.getLoc();
    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
    AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];
    // Cache of partial accumulator tiles, keyed by destination offsets.
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;

    SmallVector<int64_t> loopOrder = getUnrollOrder(
        contractOp.getIteratorTypes().size(), contractOp, options);

    // Iterate over every tile of the contraction's iteration space.
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      SmallVector<Value> slicesOperands(contractOp.getNumOperands());

      // Helper to compute the new shape of each operand and extract the slice.
      auto extractOperand = [&](unsigned index, Value operand,
                                AffineMap permutationMap,
                                ArrayRef<int64_t> operandOffets) {
        SmallVector<int64_t> operandShape = applyPermutationMap(
            permutationMap, ArrayRef<int64_t>(*targetShape));
        SmallVector<int64_t> operandStrides(operandOffets.size(), 1);
        slicesOperands[index] =
            rewriter.createOrFold<vector::ExtractStridedSliceOp>(
                loc, operand, operandOffets, operandShape, operandStrides);
      };

      // Extract the new lhs operand.
      AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];
      SmallVector<int64_t> lhsOffets =
          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets);

      // Extract the new rhs operand.
      AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
      SmallVector<int64_t> rhsOffets =
          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets);

      AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
      SmallVector<int64_t> accOffets =
          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
      // If a version of the accumulator has already been computed, use it
      // otherwise extract the first version from the original operand.
      auto *accIt = accCache.find(accOffets);
      if (accIt != accCache.end())
        slicesOperands[2] = accIt->second;
      else
        extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffets);

      SmallVector<int64_t> dstShape =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
      auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, contractOp, slicesOperands, targetType);

      SmallVector<int64_t> dstOffets =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
      // Save the accumulated value until all the loops are unrolled since
      // the reduction loops keep updating the accumulator.
      accCache[dstOffets] = newOp->getResult(0);
    }
    // Assemble back the accumulator into a single vector.
    Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
                                             rewriter.getZeroAttr(dstVecType));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(contractOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
373
/// Unrolls `vector.multi_reduction` by reducing each source tile separately.
/// Partial accumulators are cached in `accCache`, keyed by the offsets of the
/// non-reduced dimensions, so tiles reducing into the same destination slice
/// are chained; the tiles are reassembled into the full result at the end.
struct UnrollMultiReductionPattern
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  UnrollMultiReductionPattern(MLIRContext *context,
                              const vector::UnrollVectorOptions &options,
                              PatternBenefit benefit = 1)
      : OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    // Full reductions to a scalar cannot be reassembled via strided inserts.
    auto resultType = reductionOp->getResult(0).getType();
    if (resultType.isIntOrFloat()) {
      return rewriter.notifyMatchFailure(reductionOp,
                                         "Unrolling scalars is not supported");
    }
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
    // Cache of partial accumulators, keyed by the destination (non-reduced)
    // offsets of each tile.
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;
    Location loc = reductionOp.getLoc();

    // Iterate over all tiles of the source shape; `offsets` is the position
    // of one `targetShape`-sized tile within `originalSize`.
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<Value> operands;
      SmallVector<int64_t> operandStrides(offsets.size(), 1);
      Value slicedOperand =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, reductionOp.getSource(), offsets, *targetShape,
              operandStrides);
      operands.push_back(slicedOperand);
      // Project out the reduced dimensions to obtain the destination tile's
      // shape and offsets.
      SmallVector<int64_t> dstShape;
      SmallVector<int64_t> destOffset;
      for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
        if (!reductionOp.isReducedDim(i)) {
          destOffset.push_back(offsets[i]);
          dstShape.push_back((*targetShape)[i]);
        }
      }
      Value acc;
      SmallVector<int64_t> accStrides(destOffset.size(), 1);
      // If a version of the accumulator has already been computed, use it
      // otherwise extract the first version from the original operand.
      auto *accIt = accCache.find(destOffset);
      if (accIt != accCache.end())
        acc = accIt->second;
      else
        acc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
            loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
      operands.push_back(acc);
      auto targetType = VectorType::get(
          dstShape, reductionOp.getSourceVectorType().getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
                                                     operands, targetType);
      Value result = newOp->getResult(0);
      accCache[destOffset] = result;
    }
    // Assemble back the accumulator into a single vector.
    Value result = arith::ConstantOp::create(
        rewriter, loc, reductionOp.getDestType(),
        rewriter.getZeroAttr(reductionOp.getDestType()));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(reductionOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
453
454struct UnrollElementwisePattern : public RewritePattern {
455 UnrollElementwisePattern(MLIRContext *context,
456 const vector::UnrollVectorOptions &options,
457 PatternBenefit benefit = 1)
458 : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
459 options(options) {}
460
461 LogicalResult matchAndRewrite(Operation *op,
462 PatternRewriter &rewriter) const override {
464 return failure();
465 auto targetShape = getTargetShape(options, op);
466 if (!targetShape)
467 return failure();
468 int64_t targetShapeRank = targetShape->size();
469 auto dstVecType = cast<VectorType>(op->getResult(0).getType());
470 SmallVector<int64_t> originalSize =
471 *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
472 int64_t originalShapeRank = originalSize.size();
473
474 Location loc = op->getLoc();
475
476 // Handle rank mismatch by adding leading unit dimensions to targetShape
477 SmallVector<int64_t> adjustedTargetShape(originalShapeRank);
478 int64_t rankDiff = originalShapeRank - targetShapeRank;
479 std::fill(adjustedTargetShape.begin(),
480 adjustedTargetShape.begin() + rankDiff, 1);
481 std::copy(targetShape->begin(), targetShape->end(),
482 adjustedTargetShape.begin() + rankDiff);
483
484 int64_t adjustedTargetShapeRank = adjustedTargetShape.size();
485 // Prepare the result vector.
486 Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
487 rewriter.getZeroAttr(dstVecType));
488 SmallVector<int64_t> strides(adjustedTargetShapeRank, 1);
489 VectorType unrolledVecType =
490 VectorType::get(*targetShape, dstVecType.getElementType());
491
492 // Create the unrolled computation.
493 for (SmallVector<int64_t> offsets :
494 StaticTileOffsetRange(originalSize, adjustedTargetShape)) {
495 SmallVector<Value> extractOperands;
496 for (OpOperand &operand : op->getOpOperands()) {
497 auto vecType = dyn_cast<VectorType>(operand.get().getType());
498 if (!vecType) {
499 extractOperands.push_back(operand.get());
500 continue;
501 }
502 Value extracted = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
503 loc, operand.get(), offsets, adjustedTargetShape, strides);
504
505 // Reshape to remove leading unit dims if needed
506 if (adjustedTargetShapeRank > targetShapeRank) {
507 extracted = rewriter.createOrFold<vector::ShapeCastOp>(
508 loc, VectorType::get(*targetShape, vecType.getElementType()),
509 extracted);
510 }
511 extractOperands.push_back(extracted);
512 }
513
514 Operation *newOp = cloneOpWithOperandsAndTypes(
515 rewriter, loc, op, extractOperands, unrolledVecType);
516
517 Value computeResult = newOp->getResult(0);
518
519 // Use strides sized to targetShape for proper insertion
520 SmallVector<int64_t> insertStrides =
521 (adjustedTargetShapeRank > targetShapeRank)
522 ? SmallVector<int64_t>(targetShapeRank, 1)
523 : strides;
524
525 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
526 loc, computeResult, result, offsets, insertStrides);
527 }
528 rewriter.replaceOp(op, result);
529 return success();
530 }
531
532private:
533 vector::UnrollVectorOptions options;
534};
535
536struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
537 UnrollReductionPattern(MLIRContext *context,
538 const vector::UnrollVectorOptions &options,
539 PatternBenefit benefit = 1)
540 : OpRewritePattern<vector::ReductionOp>(context, benefit),
541 options(options) {}
542
543 LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
544 PatternRewriter &rewriter) const override {
545 std::optional<SmallVector<int64_t>> targetShape =
546 getTargetShape(options, reductionOp);
547 if (!targetShape)
548 return failure();
549 SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
550
551 // Create unrolled vector reduction.
552 Location loc = reductionOp.getLoc();
553 Value accumulator = nullptr;
554 for (SmallVector<int64_t> offsets :
555 StaticTileOffsetRange(originalSize, *targetShape)) {
556 SmallVector<int64_t> strides(offsets.size(), 1);
557 Value slicedOperand =
558 rewriter.createOrFold<vector::ExtractStridedSliceOp>(
559 loc, reductionOp.getVector(), offsets, *targetShape, strides);
560 Operation *newOp = cloneOpWithOperandsAndTypes(
561 rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
562 Value result = newOp->getResult(0);
563
564 if (!accumulator) {
565 // This is the first reduction.
566 accumulator = result;
567 } else {
568 // On subsequent reduction, combine with the accumulator.
569 accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(),
570 accumulator, result);
571 }
572 }
573
574 rewriter.replaceOp(reductionOp, accumulator);
575 return success();
576 }
577
578private:
579 const vector::UnrollVectorOptions options;
580};
581
/// Unrolls `vector.transpose` by extracting the (permuted) source tile for
/// each result tile, transposing it, and inserting it at the result offsets.
struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
  UnrollTransposePattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransposeOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    // 0-d transposes have nothing to unroll.
    if (transposeOp.getResultVectorType().getRank() == 0)
      return failure();
    auto targetShape = getTargetShape(options, transposeOp);
    if (!targetShape)
      return failure();
    auto originalVectorType = transposeOp.getResultVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = transposeOp.getLoc();
    ArrayRef<int64_t> originalSize = originalVectorType.getShape();

    // Prepare the result vector;
    Value result =
        arith::ConstantOp::create(rewriter, loc, originalVectorType,
                                  rewriter.getZeroAttr(originalVectorType));
    ArrayRef<int64_t> permutation = transposeOp.getPermutation();

    // Unroll the computation.
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<int64_t> permutedOffsets(elementOffsets.size());
      SmallVector<int64_t> permutedShape(elementOffsets.size());
      // Compute the source offsets and shape: result dim `indices.index()`
      // comes from source dim `indices.value()`.
      for (auto indices : llvm::enumerate(permutation)) {
        permutedOffsets[indices.value()] = elementOffsets[indices.index()];
        permutedShape[indices.value()] = (*targetShape)[indices.index()];
      }
      Value slicedOperand =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, transposeOp.getVector(), permutedOffsets, permutedShape,
              strides);
      // Transpose the extracted tile, then insert it at the result offsets.
      Value transposedSlice = rewriter.createOrFold<vector::TransposeOp>(
          loc, slicedOperand, permutation);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, transposedSlice, result, elementOffsets, strides);
    }
    rewriter.replaceOp(transposeOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
633
/// Unrolls `vector.gather` by extracting matching tiles of the index, mask,
/// and pass-through vectors, emitting one smaller gather per tile, and
/// assembling the tiles into the full result.
struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
  UnrollGatherPattern(MLIRContext *context,
                      const vector::UnrollVectorOptions &options,
                      PatternBenefit benefit = 1)
      : OpRewritePattern<vector::GatherOp>(context, benefit), options(options) {
  }

  LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
                                PatternRewriter &rewriter) const override {
    VectorType sourceVectorType = gatherOp.getVectorType();
    // 0-d gathers have nothing to unroll.
    if (sourceVectorType.getRank() == 0)
      return failure();
    auto targetShape = getTargetShape(options, gatherOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = gatherOp.getLoc();
    ArrayRef<int64_t> originalSize = gatherOp.getVectorType().getShape();

    // Prepare the result vector;
    Value result =
        arith::ConstantOp::create(rewriter, loc, sourceVectorType,
                                  rewriter.getZeroAttr(sourceVectorType));
    auto targetType =
        VectorType::get(*targetShape, sourceVectorType.getElementType());

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), gatherOp, options);
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      // To get the unrolled gather, extract the same slice based on the
      // decomposed shape from each of the index, mask, and pass-through
      // vectors.
      Value indexSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, gatherOp.getIndices(), elementOffsets, *targetShape, strides);
      Value maskSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
      Value passThruSubVec =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, gatherOp.getPassThru(), elementOffsets, *targetShape,
              strides);
      auto slicedGather = vector::GatherOp::create(
          rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getOffsets(),
          indexSubVec, maskSubVec, passThruSubVec);

      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, slicedGather, result, elementOffsets, strides);
    }
    rewriter.replaceOp(gatherOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
689
690struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> {
691 UnrollLoadPattern(MLIRContext *context,
692 const vector::UnrollVectorOptions &options,
693 PatternBenefit benefit = 1)
694 : OpRewritePattern<vector::LoadOp>(context, benefit), options(options) {}
695
696 LogicalResult matchAndRewrite(vector::LoadOp loadOp,
697 PatternRewriter &rewriter) const override {
698 VectorType vecType = loadOp.getVectorType();
699
700 auto targetShape = getTargetShape(options, loadOp);
701 if (!targetShape)
702 return failure();
703
704 Location loc = loadOp.getLoc();
705 ArrayRef<int64_t> originalShape = vecType.getShape();
706 SmallVector<int64_t> strides(targetShape->size(), 1);
707
708 Value result = arith::ConstantOp::create(rewriter, loc, vecType,
709 rewriter.getZeroAttr(vecType));
710
711 SmallVector<int64_t> loopOrder =
712 getUnrollOrder(originalShape.size(), loadOp, options);
713
714 auto targetVecType =
715 VectorType::get(*targetShape, vecType.getElementType());
716
717 for (SmallVector<int64_t> offsets :
718 StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
719 SmallVector<Value> indices =
720 sliceLoadStoreIndices(rewriter, loc, loadOp.getIndices(), offsets);
721 Value slicedLoad = vector::LoadOp::create(rewriter, loc, targetVecType,
722 loadOp.getBase(), indices);
723 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
724 loc, slicedLoad, result, offsets, strides);
725 }
726 rewriter.replaceOp(loadOp, result);
727 return success();
728 }
729
730private:
731 vector::UnrollVectorOptions options;
732};
733
734struct UnrollStorePattern : public OpRewritePattern<vector::StoreOp> {
735 UnrollStorePattern(MLIRContext *context,
736 const vector::UnrollVectorOptions &options,
737 PatternBenefit benefit = 1)
738 : OpRewritePattern<vector::StoreOp>(context, benefit), options(options) {}
739
740 LogicalResult matchAndRewrite(vector::StoreOp storeOp,
741 PatternRewriter &rewriter) const override {
742 VectorType vecType = storeOp.getVectorType();
743
744 auto targetShape = getTargetShape(options, storeOp);
745 if (!targetShape)
746 return failure();
747
748 Location loc = storeOp.getLoc();
749 ArrayRef<int64_t> originalShape = vecType.getShape();
750 SmallVector<int64_t> strides(targetShape->size(), 1);
751
752 Value base = storeOp.getBase();
753 Value vector = storeOp.getValueToStore();
754
755 SmallVector<int64_t> loopOrder =
756 getUnrollOrder(originalShape.size(), storeOp, options);
757
758 for (SmallVector<int64_t> offsets :
759 StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
760 SmallVector<Value> indices =
761 sliceLoadStoreIndices(rewriter, loc, storeOp.getIndices(), offsets);
762 Value slice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
763 loc, vector, offsets, *targetShape, strides);
764 vector::StoreOp::create(rewriter, loc, slice, base, indices);
765 }
766 rewriter.eraseOp(storeOp);
767 return success();
768 }
769
770private:
771 vector::UnrollVectorOptions options;
772};
773
/// Unrolls `vector.broadcast` by broadcasting the (possibly sliced) source
/// into tiles of the target shape and inserting each tile into the result.
struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
  UnrollBroadcastPattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::BroadcastOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, broadcastOp);
    if (!targetShape)
      return failure();

    Location loc = broadcastOp.getLoc();
    // Null when the broadcast source is a scalar.
    VectorType srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
    VectorType resType = broadcastOp.getResultVectorType();
    VectorType targetType =
        resType.cloneWith(*targetShape, resType.getElementType());
    Value result = arith::ConstantOp::create(rewriter, loc, resType,
                                             rewriter.getZeroAttr(resType));

    SmallVector<int64_t> originalShape = *broadcastOp.getShapeForUnroll();
    SmallVector<int64_t> strides(originalShape.size(), 1);

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalShape, *targetShape)) {
      Value newSrc;
      if (!srcType) {
        // Scalar to vector broadcast.
        newSrc = broadcastOp.getSource();
      } else {
        // Vector to vector broadcast.
        // The source maps to the trailing `rank` dimensions of the result.
        int64_t rank = srcType.getRank();
        SmallVector<int64_t> srcOffsets(offsets.end() - rank, offsets.end());
        SmallVector<int64_t> srcShape(targetShape->end() - rank,
                                      targetShape->end());
        SmallVector<int64_t> srcStrides(strides.end() - rank, strides.end());
        // adjust the offset and shape for src if the corresponding dim is 1.
        for (int64_t i = 0; i < rank; ++i) {
          if (srcType.getDimSize(i) == 1) {
            srcOffsets[i] = 0;
            srcShape[i] = 1;
          }
        }
        newSrc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
            loc, broadcastOp.getSource(), srcOffsets, srcShape, srcStrides);
      }

      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, broadcastOp,
                                                     newSrc, targetType);

      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, newOp->getResult(0), result, offsets, strides);
    }

    rewriter.replaceOp(broadcastOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
836
837/// Unrolls 2 or more dimensional `vector.to_elements` ops by unrolling the
838/// outermost dimension of the operand. For example:
839///
840/// ```
841/// %0:4 = vector.to_elements %v : vector<2x2xf32>
842///
843/// ==>
844///
845/// %v0 = vector.extract %v[0] : vector<2x2xf32> from vector<2x2x2xf32>
846/// %v1 = vector.extract %v[1] : vector<2x2xf32> from vector<2x2x2xf32>
847/// %0:4 = vector.to_elements %v0 : vector<2x2xf32>
848/// %1:4 = vector.to_elements %v1 : vector<2x2xf32>
849/// ```
850///
851/// When this pattern is applied until a fixed-point is reached,
852/// this will produce a sequence of 1-d from_elements
853/// ops.
854struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
855 UnrollToElements(MLIRContext *context,
856 const vector::UnrollVectorOptions &options,
857 PatternBenefit benefit = 1)
858 : OpRewritePattern<vector::ToElementsOp>(context, benefit),
859 options(options) {}
860
861 LogicalResult matchAndRewrite(vector::ToElementsOp op,
862 PatternRewriter &rewriter) const override {
863
864 TypedValue<VectorType> source = op.getSource();
865 FailureOr<SmallVector<Value>> result =
866 vector::unrollVectorValue(source, rewriter);
867 if (failed(result)) {
868 return failure();
869 }
870 SmallVector<Value> vectors = *result;
871
872 SmallVector<Value> results;
873 for (Value vector : vectors) {
874 auto subElements =
875 vector::ToElementsOp::create(rewriter, op.getLoc(), vector);
876 llvm::append_range(results, subElements.getResults());
877 }
878 rewriter.replaceOp(op, results);
879 return success();
880 }
881
882private:
883 vector::UnrollVectorOptions options;
884};
885
886/// This pattern unrolls `vector.step` operations according to the provided
887/// target unroll shape. It decomposes a large step vector into smaller step
888/// vectors (segments) and assembles the result by inserting each computed
889/// segment into the appropriate offset of the original vector.
890///
891/// The pattern does not support scalable vectors and will fail to match them.
892///
893/// For each segment, it adds the base step vector and the segment's offset,
894/// then inserts the result into the output vector at the corresponding
895/// position.
896///
897/// Example:
898/// Given a step operation:
899/// %0 = vector.step : vector<8xindex>
900///
901/// and a target unroll shape of <4>, the pattern produces:
902///
903/// %base = vector.step : vector<4xindex>
904/// %zero = arith.constant dense<0> : vector<8xindex>
905/// %result0 = vector.insert_strided_slice %base, %zero
906/// {offsets = [0], strides = [1]} : vector<4xindex> into vector<8xindex>
907/// %offset = arith.constant dense<4> : vector<4xindex>
908/// %segment1 = arith.addi %base, %offset : vector<4xindex>
909/// %result1 = vector.insert_strided_slice %segment1, %result0
910/// {offsets = [4], strides = [1]} : vector<4xindex> into vector<8xindex>
911///
912struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> {
913 UnrollStepPattern(MLIRContext *context,
914 const vector::UnrollVectorOptions &options,
915 PatternBenefit benefit = 1)
916 : OpRewritePattern<vector::StepOp>(context, benefit), options(options) {}
917
918 LogicalResult matchAndRewrite(vector::StepOp stepOp,
919 PatternRewriter &rewriter) const override {
920 std::optional<SmallVector<int64_t>> targetShape =
921 getTargetShape(options, stepOp);
922 if (!targetShape)
923 return failure();
924
925 VectorType vecType = stepOp.getType();
926 if (vecType.isScalable()) {
927 // Scalable vectors are not supported by this pattern.
928 return failure();
929 }
930 int64_t originalSize = vecType.getShape()[0];
931 Location loc = stepOp.getLoc();
932 SmallVector<int64_t> strides(1, 1);
933
934 Value result = arith::ConstantOp::create(rewriter, loc, vecType,
935 rewriter.getZeroAttr(vecType));
936
937 auto targetVecType =
938 VectorType::get(*targetShape, vecType.getElementType());
939 Value baseStep = vector::StepOp::create(rewriter, loc, targetVecType);
940 for (const SmallVector<int64_t> &offsets :
941 StaticTileOffsetRange({originalSize}, *targetShape)) {
942 Value bcastOffset = arith::ConstantOp::create(
943 rewriter, loc, targetVecType,
945 targetVecType,
946 IntegerAttr::get(targetVecType.getElementType(), offsets[0])));
947 Value tileStep =
948 arith::AddIOp::create(rewriter, loc, baseStep, bcastOffset);
949
950 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
951 loc, tileStep, result, offsets, strides);
952 }
953 rewriter.replaceOp(stepOp, result);
954 return success();
955 }
956
957private:
958 vector::UnrollVectorOptions options;
959};
960
961/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
962/// outermost dimension. For example:
963/// ```
964/// %v = vector.from_elements %e0, %e1, %e2, %e3, %e4, %e5 : vector<2x3xf32>
965///
966/// ==>
967///
968/// %0 = ub.poison : vector<2x3xf32>
969/// %v0 = vector.from_elements %e0, %e1, %e2 : vector<3xf32>
970/// %1 = vector.insert %v0, %0 [0] : vector<3xf32> into vector<2x3xf32>
971/// %v1 = vector.from_elements %e3, %e4, %e5 : vector<3xf32>
972/// %v = vector.insert %v1, %1 [1] : vector<3xf32> into vector<2x3xf32>
973/// ```
974///
975/// When this pattern is applied until a fixed-point is reached,
976/// this will produce a sequence of 1-d from_elements
977/// ops.
978struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
979 UnrollFromElements(MLIRContext *context,
980 const vector::UnrollVectorOptions &options,
981 PatternBenefit benefit = 1)
982 : OpRewritePattern<vector::FromElementsOp>(context, benefit),
983 options(options) {}
984
985 LogicalResult matchAndRewrite(vector::FromElementsOp op,
986 PatternRewriter &rewriter) const override {
987 ValueRange allElements = op.getElements();
988
989 auto unrollFromElementsFn = [&](PatternRewriter &rewriter, Location loc,
990 VectorType subTy, int64_t index) {
991 size_t subTyNumElements = subTy.getNumElements();
992 assert((index + 1) * subTyNumElements <= allElements.size() &&
993 "out of bounds");
994 ValueRange subElements =
995 allElements.slice(index * subTyNumElements, subTyNumElements);
996 return vector::FromElementsOp::create(rewriter, loc, subTy, subElements);
997 };
998
999 return unrollVectorOp(op, rewriter, unrollFromElementsFn);
1000 }
1001
1002private:
1003 vector::UnrollVectorOptions options;
1004};
1005
1006/// This pattern unrolls `vector.create_mask` operations into smaller mask
1007/// operations based on the target unroll shape. Each unrolled slice computes
1008/// its local mask size in each dimension (d) as:
1009/// min(max(originalMaskSize[d] - offset[d], 0), unrolledDimSize[d]).
1010/// Example:
1011/// Given a create_mask operation:
1012/// %0 = vector.create_mask %c6, %c10 : vector<8x16xi1> // mask first 6x10
1013/// elements
1014///
1015/// and a target unroll shape of <4x8>, the pattern produces:
1016///
1017/// %false = arith.constant dense<false> : vector<8x16xi1>
1018///
1019/// Slice [0,0]:
1020/// mask size = min(max(6-0, 0), 4) x min(max(10-0, 0), 8) = 4x8
1021/// %mask00 = vector.create_mask %c4, %c8 : vector<4x8xi1>
1022/// %r0 = vector.insert_strided_slice %mask00, %false [0, 0], [1, 1]
1023/// : vector<4x8xi1> into vector<8x16xi1>
1024/// Slice [0,8]:
1025/// mask size = min(max(6-0, 0), 4) x min(max(10-8, 0), 8) = 4x2
1026/// %mask01 = vector.create_mask %c4, %c2 : vector<4x8xi1>
1027/// %r1 = vector.insert_strided_slice %mask01, %r0 [0, 8], [1, 1]
1028/// : vector<4x8xi1> into vector<8x16xi1>
1029/// Slice [4,0]:
1030/// mask size = min(max(6-4, 0), 4) x min(max(10-0, 0), 8) = 2x8
1031/// %mask10 = vector.create_mask %c2, %c8 : vector<4x8xi1>
1032/// %r2 = vector.insert_strided_slice %mask10, %r1 [4, 0], [1, 1]
1033/// : vector<4x8xi1> into vector<8x16xi1>
1034/// Slice [4,8]:
1035/// mask size = min(max(6-4, 0), 4) x min(max(10-8, 0), 8) = 2x2
1036/// %mask11 = vector.create_mask %c2, %c2 : vector<4x8xi1>
1037/// %result = vector.insert_strided_slice %mask11, %r2 [4, 8], [1, 1]
1038/// : vector<4x8xi1> into vector<8x16xi1>
1039struct UnrollCreateMaskPattern : public OpRewritePattern<vector::CreateMaskOp> {
1040 UnrollCreateMaskPattern(MLIRContext *context,
1041 const vector::UnrollVectorOptions &options,
1042 PatternBenefit benefit = 1)
1043 : OpRewritePattern<vector::CreateMaskOp>(context, benefit),
1044 options(options) {}
1045
1046 LogicalResult matchAndRewrite(vector::CreateMaskOp createMaskOp,
1047 PatternRewriter &rewriter) const override {
1048 auto targetShape = getTargetShape(options, createMaskOp);
1049 if (!targetShape)
1050 return failure();
1051
1052 VectorType resultType = createMaskOp.getVectorType();
1053 SmallVector<int64_t> originalSize = *createMaskOp.getShapeForUnroll();
1054 Location loc = createMaskOp.getLoc();
1055
1056 Value result = arith::ConstantOp::create(rewriter, loc, resultType,
1057 rewriter.getZeroAttr(resultType));
1058 VectorType targetVectorType =
1059 VectorType::get(*targetShape, rewriter.getI1Type());
1060 SmallVector<int64_t> strides(targetShape->size(), 1);
1061
1062 // In each dimension (d), each unrolled vector computes its mask size as:
1063 // min(max(originalMaskOperands[d] - offset[d], 0), unrolledDimSize[d]).
1064 for (SmallVector<int64_t> offsets :
1065 StaticTileOffsetRange(originalSize, *targetShape)) {
1066 SmallVector<Value> unrolledOperands;
1067
1068 for (auto [i, originalMaskOperand] :
1069 llvm::enumerate(createMaskOp.getOperands())) {
1070 Value offsetVal =
1071 arith::ConstantIndexOp::create(rewriter, loc, offsets[i]);
1072 Value adjustedMaskSize = rewriter.createOrFold<arith::SubIOp>(
1073 loc, originalMaskOperand, offsetVal);
1074 Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
1075 Value unrolledDimSize =
1076 arith::ConstantIndexOp::create(rewriter, loc, (*targetShape)[i]);
1077 Value nonNegative =
1078 rewriter.createOrFold<arith::MaxSIOp>(loc, adjustedMaskSize, zero);
1079 Value unrolledOperand = rewriter.createOrFold<arith::MinSIOp>(
1080 loc, nonNegative, unrolledDimSize);
1081 unrolledOperands.push_back(unrolledOperand);
1082 }
1083
1084 auto unrolledMask = rewriter.createOrFold<vector::CreateMaskOp>(
1085 loc, targetVectorType, unrolledOperands);
1086 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
1087 loc, unrolledMask, result, offsets, strides);
1088 }
1089 rewriter.replaceOp(createMaskOp, result);
1090 return success();
1091 }
1092
1093private:
1094 vector::UnrollVectorOptions options;
1095};
1096
1097/// This pattern unrolls `vector.constant_mask` operations into smaller mask
1098/// operations based on the target unroll shape. Each unrolled slice computes
1099/// whether its elements should be masked based on the original mask dimensions
1100/// and the slice's offset position.
1101///
1102/// Example:
1103/// Given a constant_mask operation:
1104/// %0 = vector.constant_mask [6, 10] : vector<8x16xi1>
1105///
1106/// and a target unroll shape of <4x8>, the pattern produces:
1107///
1108/// %false = arith.constant dense<false> : vector<8x16xi1>
1109///
1110/// Slice [0,0]: elements [0:4, 0:8] - fully within [6, 10] bounds
1111/// %mask00 = vector.constant_mask [4, 8] : vector<4x8xi1>
1112/// %r0 = vector.insert_strided_slice %mask00, %false [0, 0], [1, 1]
1113/// : vector<4x8xi1> into vector<8x16xi1>
1114///
1115/// Slice [0,8]: elements [0:4, 8:16] - partially within bounds
1116/// %mask01 = vector.constant_mask [4, 2] : vector<4x8xi1>
1117/// %r1 = vector.insert_strided_slice %mask01, %r0 [0, 8], [1, 1]
1118/// : vector<4x8xi1> into vector<8x16xi1>
1119///
1120/// Slice [4,0]: elements [4:8, 0:8] - partially within bounds
1121/// %mask10 = vector.constant_mask [2, 8] : vector<4x8xi1>
1122/// %r2 = vector.insert_strided_slice %mask10, %r1 [4, 0], [1, 1]
1123/// : vector<4x8xi1> into vector<8x16xi1>
1124///
1125/// Slice [4,8]: elements [4:8, 8:16] - partially within bounds
1126/// %mask11 = vector.constant_mask [2, 2] : vector<4x8xi1>
1127/// %result = vector.insert_strided_slice %mask11, %r2 [4, 8], [1, 1]
1128/// : vector<4x8xi1> into vector<8x16xi1>
1129struct UnrollConstantMaskPattern
1130 : public OpRewritePattern<vector::ConstantMaskOp> {
1131 UnrollConstantMaskPattern(MLIRContext *context,
1132 const vector::UnrollVectorOptions &options,
1133 PatternBenefit benefit = 1)
1134 : OpRewritePattern<vector::ConstantMaskOp>(context, benefit),
1135 options(options) {}
1136
1137 LogicalResult matchAndRewrite(vector::ConstantMaskOp constantMaskOp,
1138 PatternRewriter &rewriter) const override {
1139 std::optional<SmallVector<int64_t>> targetShape =
1140 getTargetShape(options, constantMaskOp);
1141 if (!targetShape)
1142 return failure();
1143
1144 VectorType resultType = constantMaskOp.getVectorType();
1145 SmallVector<int64_t> originalSize = *constantMaskOp.getShapeForUnroll();
1146 Location loc = constantMaskOp.getLoc();
1147
1148 Value result = arith::ConstantOp::create(rewriter, loc, resultType,
1149 rewriter.getZeroAttr(resultType));
1150 VectorType targetVectorType =
1151 VectorType::get(*targetShape, rewriter.getI1Type());
1152 SmallVector<int64_t> strides(targetShape->size(), 1);
1153
1154 // In each dimension (d), each unrolled vector computes its mask size as:
1155 // min(max(originalMaskDim[d] - offset[d], 0), unrolledDimSize[d]).
1156 for (const SmallVector<int64_t> &offsets :
1157 StaticTileOffsetRange(originalSize, *targetShape)) {
1158 SmallVector<int64_t> unrolledMaskDims;
1159
1160 for (auto [i, originalMaskDim] :
1161 llvm::enumerate(constantMaskOp.getMaskDimSizes())) {
1162 // Calculate how many elements in this dimension should be masked
1163 // for this particular slice
1164 int64_t adjustedMaskSize =
1165 std::max(originalMaskDim - offsets[i], static_cast<int64_t>(0));
1166 int64_t unrolledMaskDim =
1167 std::min(adjustedMaskSize, static_cast<int64_t>((*targetShape)[i]));
1168 unrolledMaskDims.push_back(unrolledMaskDim);
1169 }
1170
1171 auto unrolledMask = rewriter.createOrFold<vector::ConstantMaskOp>(
1172 loc, targetVectorType, unrolledMaskDims);
1173 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
1174 loc, unrolledMask, result, offsets, strides);
1175 }
1176 rewriter.replaceOp(constantMaskOp, result);
1177 return success();
1178 }
1179
1180private:
1181 vector::UnrollVectorOptions options;
1182};
1183
1184/// Checks whether extractShape is a contiguous slice of shape.
1185/// For extractShape to be contiguous in shape:
/// 1) All but the leading dimension of extractShape and shape must match
/// exactly.
/// 2) The total number of elements in shape must be evenly divisible by
/// the total number of elements in extractShape.
1190/// Examples:
1191/// isContiguous([4, 4], [8, 4]) == true
1192/// isContiguous([2, 4], [8, 4]) == true
1193/// isContiguous([2, 2], [8, 4]) == false
1194/// Removes leading unit dimensions to handle cases like:
1195/// isContiguous([1, 16], [1, 32]) == true
1196static bool isContiguous(ArrayRef<int64_t> extractShape,
1198
1199 if (extractShape.empty() || shape.empty() ||
1200 extractShape.size() > shape.size())
1201 return false;
1202
1203 while (extractShape.size() > 1 && extractShape.front() == 1)
1204 extractShape = extractShape.drop_front();
1205
1206 while (shape.size() > 1 && shape.front() == 1) {
1207 shape = shape.drop_front();
1208 }
1209
1210 size_t rankDiff = shape.size() - extractShape.size();
1211 if (!llvm::equal(extractShape.drop_front(), shape.drop_front(rankDiff + 1)))
1212 return false;
1213
1214 int64_t extractElements = ShapedType::getNumElements(extractShape);
1215 int64_t shapeElements = ShapedType::getNumElements(shape);
1216 return shapeElements % extractElements == 0;
1217}
1218
1219/// Determines what shape to use with `vector.extract_strided_slice` to extract
1220/// a contiguous memory region from a source vector. The extraction must be
1221/// contiguous and contain exactly the specified number of elements. If such an
1222/// extraction shape cannot be determined, returns std::nullopt.
1223/// EXAMPLE 1:
1224/// sourceShape = [16], targetElements = 8
1225/// Working right-to-left:
1226/// - Take min(8, 16) = 8 from only dim → extractShape = [8],
1227/// remaining = 8/8 = 1
1228/// Result: [8]
1229///
1230/// EXAMPLE 2:
1231/// sourceShape = [4, 4], targetElements = 8
1232/// Working right-to-left:
1233/// - Take min(8, 4) = 4 from last dim → extractShape = [4],
1234/// remaining = 8/4 = 2
1235/// - Take min(2, 4) = 2 from first dim → extractShape = [2, 4],
1236/// remaining = 2/2 = 1
1237/// Result: [2, 4]
1238static std::optional<SmallVector<int64_t>>
1239calculateSourceExtractShape(ArrayRef<int64_t> sourceShape,
1240 int64_t targetElements) {
1241 SmallVector<int64_t> extractShape;
1242 int64_t remainingElements = targetElements;
1243
1244 // Build extract shape from innermost dimension outward to ensure contiguity.
1245 for (int i = sourceShape.size() - 1; i >= 0 && remainingElements > 1; --i) {
1246 int64_t takeFromDim = std::min(remainingElements, sourceShape[i]);
1247 extractShape.insert(extractShape.begin(), takeFromDim);
1248
1249 if (remainingElements % takeFromDim != 0)
1250 return std::nullopt; // Not evenly divisible.
1251 remainingElements /= takeFromDim;
1252 }
1253
1254 // Fill remaining dimensions with 1.
1255 while (extractShape.size() < sourceShape.size())
1256 extractShape.insert(extractShape.begin(), 1);
1257
1258 if (ShapedType::getNumElements(extractShape) != targetElements)
1259 return std::nullopt;
1260
1261 return extractShape;
1262}
1263
1264// Convert result offsets to source offsets via linear position.
1266calculateSourceOffsets(ArrayRef<int64_t> resultOffsets,
1267 ArrayRef<int64_t> sourceShape,
1268 ArrayRef<int64_t> resultShape) {
1269 // Convert result offsets to linear position.
1270 int64_t linearIndex = linearize(resultOffsets, computeStrides(resultShape));
1271 // Convert linear position to source offsets.
1272 return delinearize(linearIndex, computeStrides(sourceShape));
1273}
1274
1275/// This pattern unrolls `vector.shape_cast` operations according to the
1276/// provided target unroll shape. It unrolls a large shape cast into smaller
1277/// shape casts by extracting contiguous slices from the source vector, casting
1278/// each slice to the target shape, and assembling the result by inserting each
1279/// computed segment into the appropriate offset of the result vector.
1280///
1281/// This pattern only applies when contiguous slices can be extracted from the
1282/// source vector and inserted into the result vector such that each slice
1283/// remains a valid vector (and not decompose to scalars). In these cases, the
1284/// unrolling proceeds as:
1285/// vector.extract_strided_slice -> vector.shape_cast (on the slice) ->
1286/// vector.insert_strided_slice.
1287///
1288/// Example:
1289/// Given a shape cast operation:
1290/// %0 = vector.shape_cast %src : vector<8x2xf32> to vector<4x4xf32>
1291///
1292/// and a target unroll shape of <2x4>, the pattern produces:
1293///
1294/// %zero = arith.constant dense<0.0> : vector<4x4xf32>
1295/// %s0 = vector.extract_strided_slice %src [0, 0], [4, 2], [1, 1]
1296/// : vector<8x2xf32> to vector<4x2xf32>
1297/// %sc0 = vector.shape_cast %s0 : vector<4x2xf32> to vector<2x4xf32>
1298/// %i0 = vector.insert_strided_slice %sc0, %zero [0, 0], [1, 1]
1299/// : vector<2x4xf32> into vector<4x4xf32>
1300/// %s1 = vector.extract_strided_slice %src [4, 0], [4, 2], [1, 1]
1301/// : vector<8x2xf32> to vector<4x2xf32>
1302/// %sc1 = vector.shape_cast %s1 : vector<4x2xf32> to vector<2x4xf32>
1303/// %i1 = vector.insert_strided_slice %sc1, %i0 [2, 0], [1, 1]
1304/// : vector<2x4xf32> into vector<4x4xf32>
1305///
1306struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
1307 UnrollShapeCastPattern(MLIRContext *context,
1308 const vector::UnrollVectorOptions &options,
1309 PatternBenefit benefit = 1)
1310 : OpRewritePattern<vector::ShapeCastOp>(context, benefit),
1311 options(options) {}
1312
1313 LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
1314 PatternRewriter &rewriter) const override {
1315 std::optional<SmallVector<int64_t>> targetShape =
1316 getTargetShape(options, shapeCastOp);
1317 if (!targetShape)
1318 return failure();
1319
1320 VectorType sourceType = shapeCastOp.getSourceVectorType();
1321 VectorType resultType = shapeCastOp.getResultVectorType();
1322 ArrayRef<int64_t> sourceShape = sourceType.getShape();
1323 ArrayRef<int64_t> resultShape = resultType.getShape();
1324
1325 if (!isContiguous(*targetShape, resultShape))
1326 return rewriter.notifyMatchFailure(
1327 shapeCastOp, "Only supports cases where target shape is "
1328 "contiguous in result vector shape");
1329
1330 int64_t targetElements = ShapedType::getNumElements(*targetShape);
1331
1332 // Calculate the shape to extract from source.
1333 std::optional<SmallVector<int64_t>> extractShape =
1334 calculateSourceExtractShape(sourceShape, targetElements);
1335 if (!extractShape)
1336 return rewriter.notifyMatchFailure(
1337 shapeCastOp,
1338 "cannot extract target number of elements contiguously from source");
1339
1340 Location loc = shapeCastOp.getLoc();
1341
1342 // Create result vector initialized to zero.
1343 Value result = arith::ConstantOp::create(rewriter, loc, resultType,
1344 rewriter.getZeroAttr(resultType));
1345
1346 VectorType targetType =
1347 VectorType::get(*targetShape, sourceType.getElementType());
1348
1349 SmallVector<int64_t> extractStrides(extractShape->size(), 1);
1350 SmallVector<int64_t> insertStrides(targetShape->size(), 1);
1351
1352 for (SmallVector<int64_t> resultOffsets :
1353 StaticTileOffsetRange(resultShape, *targetShape)) {
1354 SmallVector<int64_t> sourceOffsets =
1355 calculateSourceOffsets(resultOffsets, sourceShape, resultShape);
1356 Value sourceChunk = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
1357 loc, shapeCastOp.getSource(), sourceOffsets, *extractShape,
1359 Value targetChunk = rewriter.createOrFold<vector::ShapeCastOp>(
1360 loc, targetType, sourceChunk);
1361 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
1362 loc, targetChunk, result, resultOffsets, insertStrides);
1363 }
1364
1365 rewriter.replaceOp(shapeCastOp, result);
1366 return success();
1367 }
1368
1369private:
1370 vector::UnrollVectorOptions options;
1371};
1372
1373} // namespace
1374
1375void mlir::vector::populateVectorUnrollPatterns(
1377 PatternBenefit benefit) {
1378 patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
1379 UnrollContractionPattern, UnrollElementwisePattern,
1380 UnrollReductionPattern, UnrollMultiReductionPattern,
1381 UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
1382 UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
1383 UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern,
1384 UnrollCreateMaskPattern, UnrollConstantMaskPattern>(
1385 patterns.getContext(), options, benefit);
1386}
1387
1388void mlir::vector::populateVectorToElementsUnrollPatterns(
1389 RewritePatternSet &patterns, PatternBenefit benefit) {
1390 patterns.add<UnrollToElements>(patterns.getContext(), UnrollVectorOptions(),
1391 benefit);
1392}
1393
1394void mlir::vector::populateVectorFromElementsUnrollPatterns(
1395 RewritePatternSet &patterns, PatternBenefit benefit) {
1396 patterns.add<UnrollFromElements>(patterns.getContext(), UnrollVectorOptions(),
1397 benefit);
1398}
return success()
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
lhs
static llvm::ManagedStatic< PassManagerOptions > options
static SmallVector< Value > sliceLoadStoreIndices(PatternRewriter &rewriter, Location loc, OperandRange originalIndices, ArrayRef< int64_t > offsets)
static SmallVector< Value > sliceTransferIndices(ArrayRef< int64_t > elementOffsets, ArrayRef< Value > indices, AffineMap permutationMap, Location loc, OpBuilder &builder)
Compute the indices of the slice index for a transfer op.
static std::optional< SmallVector< int64_t > > getTargetShape(const vector::UnrollVectorOptions &options, Operation *op)
Return the target shape for unrolling for the given op.
static SmallVector< int64_t > getUnrollOrder(unsigned numLoops, Operation *op, const vector::UnrollVectorOptions &options)
static Operation * cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc, Operation *op, ArrayRef< Value > operands, ArrayRef< Type > resultTypes)
Base type for affine expression.
Definition AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
ArrayRef< AffineExpr > getResults() const
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:328
IntegerType getI1Type()
Definition Builders.cpp:57
MLIRContext * getContext() const
Definition Builders.h:56
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:209
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition Builders.cpp:461
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:520
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:415
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:391
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
result_range getResults()
Definition Operation.h:423
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:412
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePattern is the common base class for all DAG to DAG replacements.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
FailureOr< SmallVector< Value > > unrollVectorValue(TypedValue< VectorType >, RewriterBase &)
Generic utility for unrolling values of type vector<NxAxBx...> to N values of type vector<AxBx....
LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter, UnrollVectorOpFn unrollFn)
Include the generated interface declarations.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:497
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition AffineMap.h:675
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Return the linearized index of 'offsets' w.r.t.
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Options that control the vector unrolling.