MLIR 23.0.0git
VectorUnroll.cpp
Go to the documentation of this file.
1//===- VectorUnroll.cpp - patterns to do vector unrolling -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to do vector unrolling and vector distribution.
10//
11//===----------------------------------------------------------------------===//
12
18#include "llvm/ADT/MapVector.h"
19#include "llvm/ADT/STLExtras.h"
20#include "llvm/Support/DebugLog.h"
21#include "llvm/Support/InterleavedRange.h"
22#include <optional>
23
24#define DEBUG_TYPE "vector-unroll"
25
26using namespace mlir;
27using namespace mlir::vector;
28
29/// Compute the indices of the slice `index` for a transfer op.
32 AffineMap permutationMap,
33 Location loc,
34 OpBuilder &builder) {
35 MLIRContext *ctx = builder.getContext();
36 auto isBroadcast = [](AffineExpr expr) {
37 if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
38 return constExpr.getValue() == 0;
39 return false;
40 };
41 // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
42 SmallVector<Value> slicedIndices(indices);
43 for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
44 if (isBroadcast(dim.value()))
45 continue;
46 unsigned pos = cast<AffineDimExpr>(dim.value()).getPosition();
47 auto expr = getAffineDimExpr(0, builder.getContext()) +
48 getAffineConstantExpr(elementOffsets[dim.index()], ctx);
49 auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
50 slicedIndices[pos] =
51 affine::AffineApplyOp::create(builder, loc, map, indices[pos]);
52 }
53 return slicedIndices;
54}
55
56// Compute the new indices by adding `offsets` to `originalIndices`.
57// If m < n (m = offsets.size(), n = originalIndices.size()),
58// then only the trailing m values in `originalIndices` are updated.
60 Location loc,
61 OperandRange originalIndices,
62 ArrayRef<int64_t> offsets) {
63 assert(offsets.size() <= originalIndices.size() &&
64 "Offsets should not exceed the number of original indices");
65 SmallVector<Value> indices(originalIndices);
66
67 auto start = indices.size() - offsets.size();
68 for (auto [i, offset] : llvm::enumerate(offsets)) {
69 if (offset != 0) {
70 indices[start + i] = arith::AddIOp::create(
71 rewriter, loc, originalIndices[start + i],
72 arith::ConstantIndexOp::create(rewriter, loc, offset));
73 }
74 }
75 return indices;
76}
77
78// Clones `op` into a new operations that takes `operands` and returns
79// `resultTypes`.
81 Operation *op,
82 ArrayRef<Value> operands,
83 ArrayRef<Type> resultTypes) {
84 return builder.create(loc, op->getName().getIdentifier(), operands,
85 resultTypes, op->getAttrs());
86}
87
88/// Return the target shape for unrolling for the given `op`. Return
89/// std::nullopt if the op shouldn't be or cannot be unrolled.
90static std::optional<SmallVector<int64_t>>
92 LDBG() << "Get unroll shape for op " << op->getName().getStringRef();
93 if (options.filterConstraint && failed(options.filterConstraint(op))) {
94 LDBG() << "--no filter constraint -> BAIL";
95 return std::nullopt;
96 }
97 assert(options.nativeShape &&
98 "vector unrolling expects the native shape or native"
99 "shape call back function to be set");
100 auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
101 if (!unrollableVectorOp) {
102 LDBG() << "--not an unrollable op -> BAIL";
103 return std::nullopt;
104 }
105 auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
106 if (!maybeUnrollShape) {
107 LDBG() << "--could not get shape of op " << *op << " -> BAIL";
108 return std::nullopt;
109 }
110 LDBG() << "--vector op shape: " << llvm::interleaved(*maybeUnrollShape);
111
112 std::optional<SmallVector<int64_t>> targetShape = options.nativeShape(op);
113 if (!targetShape) {
114 LDBG() << "--no unrolling target shape defined " << *op << "-> SKIP";
115 return std::nullopt;
116 }
117 LDBG() << "--target shape: " << llvm::interleaved(*targetShape);
118
119 auto maybeShapeRatio = computeShapeRatio(*maybeUnrollShape, *targetShape);
120 if (!maybeShapeRatio) {
121 LDBG() << "--could not compute integral shape ratio -> BAIL";
122 return std::nullopt;
123 }
124 if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
125 LDBG() << "--no unrolling needed -> SKIP";
126 return std::nullopt;
127 }
128 LDBG() << "--found an integral shape ratio to unroll to -> SUCCESS";
129 return targetShape;
130}
131
133getUnrollOrder(unsigned numLoops, Operation *op,
135 SmallVector<int64_t> loopOrder =
136 llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));
137 if (options.traversalOrderCallback != nullptr) {
138 std::optional<SmallVector<int64_t>> order =
139 options.traversalOrderCallback(op);
140 if (order) {
141 loopOrder = std::move(*order);
142 }
143 }
144 return loopOrder;
145}
146
147namespace {
148
/// Unrolls vector.transfer_read into per-tile transfer_reads at indices
/// shifted via sliceTransferIndices; tiles are inserted into a
/// zero-initialized full-size vector. Masked transfers are not handled.
149struct UnrollTransferReadPattern
150 : public OpRewritePattern<vector::TransferReadOp> {
151 UnrollTransferReadPattern(MLIRContext *context,
152 const vector::UnrollVectorOptions &options,
153 PatternBenefit benefit = 1)
154 : OpRewritePattern<vector::TransferReadOp>(context, benefit),
155 options(options) {}
156
157 LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
158 PatternRewriter &rewriter) const override {
159 // TODO: support 0-d corner case.
160 if (readOp.getTransferRank() == 0)
161 return failure();
// Masked reads are not supported by this unrolling pattern.
162 if (readOp.getMask())
163 return failure();
164 auto targetShape = getTargetShape(options, readOp);
165 if (!targetShape)
166 return failure();
167 auto sourceVectorType = readOp.getVectorType();
168 SmallVector<int64_t> strides(targetShape->size(), 1);
169 Location loc = readOp.getLoc();
170 ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
171
172 // Prepare the result vector;
173 Value result =
174 arith::ConstantOp::create(rewriter, loc, sourceVectorType,
175 rewriter.getZeroAttr(sourceVectorType));
176 auto targetType =
177 VectorType::get(*targetShape, sourceVectorType.getElementType());
178 SmallVector<Value> originalIndices(readOp.getIndices().begin(),
179 readOp.getIndices().end());
180 SmallVector<int64_t> loopOrder =
181 getUnrollOrder(originalSize.size(), readOp, options);
// Visit each target-shape tile of the source vector in `loopOrder`.
182 for (SmallVector<int64_t> elementOffsets :
183 StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
184 SmallVector<Value> indices =
185 sliceTransferIndices(elementOffsets, originalIndices,
186 readOp.getPermutationMap(), loc, rewriter);
187 auto slicedRead = vector::TransferReadOp::create(
188 rewriter, loc, targetType, readOp.getBase(), indices,
189 readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
190 readOp.getInBoundsAttr());
191
// Accumulate the tile into the result at its original offsets.
192 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
193 loc, slicedRead, result, elementOffsets, strides);
194 }
195 rewriter.replaceOp(readOp, result);
196 return success();
197 }
198
199private:
200 vector::UnrollVectorOptions options;
201};
202
/// Unrolls vector.transfer_write into per-tile transfer_writes. For writes
/// into tensors, each partial write's result is threaded as the destination
/// of the next one. Masked transfers are not handled.
203struct UnrollTransferWritePattern
204 : public OpRewritePattern<vector::TransferWriteOp> {
205 UnrollTransferWritePattern(MLIRContext *context,
206 const vector::UnrollVectorOptions &options,
207 PatternBenefit benefit = 1)
208 : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
209 options(options) {}
210
211 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
212 PatternRewriter &rewriter) const override {
213 // TODO: support 0-d corner case.
214 if (writeOp.getTransferRank() == 0)
215 return failure();
216
// Masked writes are not supported by this unrolling pattern.
217 if (writeOp.getMask())
218 return failure();
219 auto targetShape = getTargetShape(options, writeOp);
220 if (!targetShape)
221 return failure();
222 auto sourceVectorType = writeOp.getVectorType();
223 SmallVector<int64_t> strides(targetShape->size(), 1);
224 Location loc = writeOp.getLoc();
225 ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
226 // Bail-out if rank(source) != rank(target). The main limitation here is the
227 // fact that `ExtractStridedSlice` requires the rank for the input and
228 // output to match. If needed, we can relax this later.
229 if (originalSize.size() != targetShape->size())
230 return rewriter.notifyMatchFailure(
231 writeOp,
232 "expected source input vector rank to match target shape rank");
233
234 SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
235 writeOp.getIndices().end());
236 SmallVector<int64_t> loopOrder =
237 getUnrollOrder(originalSize.size(), writeOp, options);
// Non-null only for the tensor case; carries the latest partial result.
238 Value resultTensor;
239 for (SmallVector<int64_t> elementOffsets :
240 StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
241 Value slicedVector = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
242 loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
243 SmallVector<Value> indices =
244 sliceTransferIndices(elementOffsets, originalIndices,
245 writeOp.getPermutationMap(), loc, rewriter);
246 Operation *slicedWrite = vector::TransferWriteOp::create(
247 rewriter, loc, slicedVector,
248 resultTensor ? resultTensor : writeOp.getBase(), indices,
249 writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
250 // For the tensor case update the destination for the next transfer write.
251 if (!slicedWrite->getResults().empty())
252 resultTensor = slicedWrite->getResult(0);
253 }
// Memref writes have no result: the original op is simply erased.
254 if (resultTensor)
255 rewriter.replaceOp(writeOp, resultTensor);
256 else
257 rewriter.eraseOp(writeOp);
258 return success();
259 }
260
261private:
262 vector::UnrollVectorOptions options;
263};
264
/// Hash-map traits keyed on offset vectors (SmallVector<int64_t>), used by
/// the accumulator caches in the contraction/multi-reduction unroll patterns.
265struct OffsetMapInfo {
// Sentinel keys that cannot collide with real offsets (always >= 0).
266 static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }
267
268 static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }
269
270 static unsigned getHashValue(const SmallVector<int64_t> &v) {
271 return static_cast<unsigned>(llvm::hash_combine_range(v));
272 }
273
274 static bool isEqual(const SmallVector<int64_t> &lhs,
275 const SmallVector<int64_t> &rhs) {
276 return lhs == rhs;
277 }
278};
279
/// Unrolls vector.contract over the iteration space tiles. Partial
/// accumulator values are cached per destination offsets (reduction tiles
/// sharing a destination chain through the cache) and finally re-assembled
/// with insert_strided_slice.
280struct UnrollContractionPattern
281 : public OpRewritePattern<vector::ContractionOp> {
282 UnrollContractionPattern(MLIRContext *context,
283 const vector::UnrollVectorOptions &options,
284 PatternBenefit benefit = 1)
285 : OpRewritePattern<vector::ContractionOp>(context, benefit),
286 options(options) {}
287
288 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
289 PatternRewriter &rewriter) const override {
290 auto targetShape = getTargetShape(options, contractOp);
291 if (!targetShape)
292 return failure();
293 auto dstVecType = cast<VectorType>(contractOp.getResultType());
294 SmallVector<int64_t> originalSize = *contractOp.getShapeForUnroll();
295
296 Location loc = contractOp.getLoc();
297 unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
298 AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];
// Maps destination offsets -> latest partial accumulator value. MapVector
// keeps insertion order deterministic for the final re-assembly loop.
299 llvm::MapVector<
300 SmallVector<int64_t>, Value,
301 llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
302 accCache;
303
304 SmallVector<int64_t> loopOrder = getUnrollOrder(
305 contractOp.getIteratorTypes().size(), contractOp, options);
306
307 for (SmallVector<int64_t> offsets :
308 StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
309 SmallVector<Value> slicesOperands(contractOp.getNumOperands());
310
311 // Helper to compute the new shape of each operand and extract the slice.
312 auto extractOperand = [&](unsigned index, Value operand,
313 AffineMap permutationMap,
314 ArrayRef<int64_t> operandOffets) {
315 SmallVector<int64_t> operandShape = applyPermutationMap(
316 permutationMap, ArrayRef<int64_t>(*targetShape));
317 SmallVector<int64_t> operandStrides(operandOffets.size(), 1);
318 slicesOperands[index] =
319 rewriter.createOrFold<vector::ExtractStridedSliceOp>(
320 loc, operand, operandOffets, operandShape, operandStrides);
321 };
322
323 // Extract the new lhs operand.
324 AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];
325 SmallVector<int64_t> lhsOffets =
326 applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
327 extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets);
328
329 // Extract the new rhs operand.
330 AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
331 SmallVector<int64_t> rhsOffets =
332 applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
333 extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets);
334
335 AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
336 SmallVector<int64_t> accOffets =
337 applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
338 // If a version of the accumulator has already been computed, use it
339 // otherwise extract the first version from the original operand.
340 auto *accIt = accCache.find(accOffets);
341 if (accIt != accCache.end())
342 slicesOperands[2] = accIt->second;
343 else
344 extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffets);
345
346 SmallVector<int64_t> dstShape =
347 applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
348 auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
349 Operation *newOp = cloneOpWithOperandsAndTypes(
350 rewriter, loc, contractOp, slicesOperands, targetType);
351
352 SmallVector<int64_t> dstOffets =
353 applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
354 // Save the accumulated value until all the loops are unrolled since the
355 // reduction loop keeps updating the accumulator.
356 accCache[dstOffets] = newOp->getResult(0);
357 }
358 // Assemble back the accumulator into a single vector.
359 Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
360 rewriter.getZeroAttr(dstVecType));
361 for (const auto &it : accCache) {
362 SmallVector<int64_t> dstStrides(it.first.size(), 1);
363 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
364 loc, it.second, result, it.first, dstStrides);
365 }
366 rewriter.replaceOp(contractOp, result);
367 return success();
368 }
369
370private:
371 vector::UnrollVectorOptions options;
372};
373
/// Unrolls vector.multi_reduction. Scalar results (all dims reduced) chain
/// partial reductions through the accumulator operand; vector results cache
/// partial accumulators per non-reduced destination offsets and re-assemble
/// them at the end.
374struct UnrollMultiReductionPattern
375 : public OpRewritePattern<vector::MultiDimReductionOp> {
376 UnrollMultiReductionPattern(MLIRContext *context,
377 const vector::UnrollVectorOptions &options,
378 PatternBenefit benefit = 1)
379 : OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
380 options(options) {}
381
382 LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
383 PatternRewriter &rewriter) const override {
384 std::optional<SmallVector<int64_t>> targetShape =
385 getTargetShape(options, reductionOp);
386 if (!targetShape)
387 return failure();
388 SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
389 Location loc = reductionOp.getLoc();
390 auto resultType = reductionOp->getResult(0).getType();
391
392 // Handle scalar result case: all dimensions are reduced.
393 // Each source tile is reduced to a scalar, and partial results are
394 // chained through the accumulator operand.
395 if (resultType.isIntOrFloat()) {
396 Value accumulator = reductionOp.getAcc();
397 for (SmallVector<int64_t> offsets :
398 StaticTileOffsetRange(originalSize, *targetShape)) {
399 SmallVector<int64_t> operandStrides(offsets.size(), 1);
400 Value slicedOperand =
401 rewriter.createOrFold<vector::ExtractStridedSliceOp>(
402 loc, reductionOp.getSource(), offsets, *targetShape,
403 operandStrides);
404 Operation *newOp = cloneOpWithOperandsAndTypes(
405 rewriter, loc, reductionOp, {slicedOperand, accumulator},
406 resultType);
// The clone's result feeds the next tile's accumulator operand.
407 accumulator = newOp->getResult(0);
408 }
409 rewriter.replaceOp(reductionOp, accumulator);
410 return success();
411 }
412
413 // Vector result case.
// Maps non-reduced destination offsets -> latest partial accumulator.
414 llvm::MapVector<
415 SmallVector<int64_t>, Value,
416 llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
417 accCache;
418
419 // Stride of the ratios, this gives us the offsets of sliceCount in a basis
420 // of multiples of the targetShape.
421 for (SmallVector<int64_t> offsets :
422 StaticTileOffsetRange(originalSize, *targetShape)) {
423 SmallVector<Value> operands;
424 SmallVector<int64_t> operandStrides(offsets.size(), 1);
425 Value slicedOperand =
426 rewriter.createOrFold<vector::ExtractStridedSliceOp>(
427 loc, reductionOp.getSource(), offsets, *targetShape,
428 operandStrides);
429 operands.push_back(slicedOperand);
// Project offsets/shape onto the non-reduced (destination) dims only.
430 SmallVector<int64_t> dstShape;
431 SmallVector<int64_t> destOffset;
432 for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
433 if (!reductionOp.isReducedDim(i)) {
434 destOffset.push_back(offsets[i]);
435 dstShape.push_back((*targetShape)[i]);
436 }
437 }
438 Value acc;
439 SmallVector<int64_t> accStrides(destOffset.size(), 1);
440 // If a version of the accumulator has already been computed, use it
441 // otherwise extract the first version from the original operand.
442 auto *accIt = accCache.find(destOffset);
443 if (accIt != accCache.end())
444 acc = accIt->second;
445 else
446 acc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
447 loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
448 operands.push_back(acc);
449 auto targetType = VectorType::get(
450 dstShape, reductionOp.getSourceVectorType().getElementType());
451 Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
452 operands, targetType);
453 Value result = newOp->getResult(0);
454 accCache[destOffset] = result;
455 }
456 // Assemble back the accumulator into a single vector.
457 Value result = arith::ConstantOp::create(
458 rewriter, loc, reductionOp.getDestType(),
459 rewriter.getZeroAttr(reductionOp.getDestType()));
460 for (const auto &it : accCache) {
461 SmallVector<int64_t> dstStrides(it.first.size(), 1);
462 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
463 loc, it.second, result, it.first, dstStrides);
464 }
465 rewriter.replaceOp(reductionOp, result);
466 return success();
467 }
468
469private:
470 vector::UnrollVectorOptions options;
471};
472
/// Unrolls any elementwise vector op (matched via MatchAnyOpTypeTag) into
/// per-tile clones. Rank mismatch between the op's shape and the target
/// shape is handled by padding the target shape with leading unit dims and
/// shape_cast-ing them away around the cloned computation.
473struct UnrollElementwisePattern : public RewritePattern {
474 UnrollElementwisePattern(MLIRContext *context,
475 const vector::UnrollVectorOptions &options,
476 PatternBenefit benefit = 1)
477 : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
478 options(options) {}
479
480 LogicalResult matchAndRewrite(Operation *op,
481 PatternRewriter &rewriter) const override {
// NOTE(review): the guard condition (source line 482 in the embedded
// numbering) is missing from this copy; the next line is its body.
// Recover it from the upstream file — it should reject non-elementwise
// ops before any unrolling work happens. TODO confirm against upstream.
483 return failure();
484 auto targetShape = getTargetShape(options, op);
485 if (!targetShape)
486 return failure();
487 int64_t targetShapeRank = targetShape->size();
488 auto dstVecType = cast<VectorType>(op->getResult(0).getType());
489 SmallVector<int64_t> originalSize =
490 *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
491 int64_t originalShapeRank = originalSize.size();
492
493 Location loc = op->getLoc();
494
495 // Handle rank mismatch by adding leading unit dimensions to targetShape
496 SmallVector<int64_t> adjustedTargetShape(originalShapeRank);
497 int64_t rankDiff = originalShapeRank - targetShapeRank;
498 std::fill(adjustedTargetShape.begin(),
499 adjustedTargetShape.begin() + rankDiff, 1);
500 std::copy(targetShape->begin(), targetShape->end(),
501 adjustedTargetShape.begin() + rankDiff);
502
503 int64_t adjustedTargetShapeRank = adjustedTargetShape.size();
504 // Prepare the result vector.
505 Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
506 rewriter.getZeroAttr(dstVecType));
507 SmallVector<int64_t> strides(adjustedTargetShapeRank, 1);
508 VectorType unrolledVecType =
509 VectorType::get(*targetShape, dstVecType.getElementType());
510
511 // Create the unrolled computation.
512 for (SmallVector<int64_t> offsets :
513 StaticTileOffsetRange(originalSize, adjustedTargetShape)) {
514 SmallVector<Value> extractOperands;
515 for (OpOperand &operand : op->getOpOperands()) {
// Scalar operands are passed through unchanged to each clone.
516 auto vecType = dyn_cast<VectorType>(operand.get().getType());
517 if (!vecType) {
518 extractOperands.push_back(operand.get());
519 continue;
520 }
521 Value extracted = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
522 loc, operand.get(), offsets, adjustedTargetShape, strides);
523
524 // Reshape to remove leading unit dims if needed
525 if (adjustedTargetShapeRank > targetShapeRank) {
526 extracted = rewriter.createOrFold<vector::ShapeCastOp>(
527 loc, VectorType::get(*targetShape, vecType.getElementType()),
528 extracted);
529 }
530 extractOperands.push_back(extracted);
531 }
532
533 Operation *newOp = cloneOpWithOperandsAndTypes(
534 rewriter, loc, op, extractOperands, unrolledVecType);
535
536 Value computeResult = newOp->getResult(0);
537
538 // Use strides sized to targetShape for proper insertion
539 SmallVector<int64_t> insertStrides =
540 (adjustedTargetShapeRank > targetShapeRank)
541 ? SmallVector<int64_t>(targetShapeRank, 1)
542 : strides;
543
544 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
545 loc, computeResult, result, offsets, insertStrides);
546 }
547 rewriter.replaceOp(op, result);
548 return success();
549 }
550
551private:
552 vector::UnrollVectorOptions options;
553};
554
555struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
556 UnrollReductionPattern(MLIRContext *context,
557 const vector::UnrollVectorOptions &options,
558 PatternBenefit benefit = 1)
559 : OpRewritePattern<vector::ReductionOp>(context, benefit),
560 options(options) {}
561
562 LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
563 PatternRewriter &rewriter) const override {
564 std::optional<SmallVector<int64_t>> targetShape =
565 getTargetShape(options, reductionOp);
566 if (!targetShape)
567 return failure();
568 SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
569
570 // Create unrolled vector reduction.
571 Location loc = reductionOp.getLoc();
572 Value accumulator = nullptr;
573 for (SmallVector<int64_t> offsets :
574 StaticTileOffsetRange(originalSize, *targetShape)) {
575 SmallVector<int64_t> strides(offsets.size(), 1);
576 Value slicedOperand =
577 rewriter.createOrFold<vector::ExtractStridedSliceOp>(
578 loc, reductionOp.getVector(), offsets, *targetShape, strides);
579 Operation *newOp = cloneOpWithOperandsAndTypes(
580 rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
581 Value result = newOp->getResult(0);
582
583 if (!accumulator) {
584 // This is the first reduction.
585 accumulator = result;
586 } else {
587 // On subsequent reduction, combine with the accumulator.
588 accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(),
589 accumulator, result);
590 }
591 }
592
593 rewriter.replaceOp(reductionOp, accumulator);
594 return success();
595 }
596
597private:
598 const vector::UnrollVectorOptions options;
599};
600
601struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
602 UnrollTransposePattern(MLIRContext *context,
603 const vector::UnrollVectorOptions &options,
604 PatternBenefit benefit = 1)
605 : OpRewritePattern<vector::TransposeOp>(context, benefit),
606 options(options) {}
607
608 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
609 PatternRewriter &rewriter) const override {
610 if (transposeOp.getResultVectorType().getRank() == 0)
611 return failure();
612 auto targetShape = getTargetShape(options, transposeOp);
613 if (!targetShape)
614 return failure();
615 auto originalVectorType = transposeOp.getResultVectorType();
616 SmallVector<int64_t> strides(targetShape->size(), 1);
617 Location loc = transposeOp.getLoc();
618 ArrayRef<int64_t> originalSize = originalVectorType.getShape();
619
620 // Prepare the result vector;
621 Value result =
622 arith::ConstantOp::create(rewriter, loc, originalVectorType,
623 rewriter.getZeroAttr(originalVectorType));
624 ArrayRef<int64_t> permutation = transposeOp.getPermutation();
625
626 // Unroll the computation.
627 for (SmallVector<int64_t> elementOffsets :
628 StaticTileOffsetRange(originalSize, *targetShape)) {
629 SmallVector<int64_t> permutedOffsets(elementOffsets.size());
630 SmallVector<int64_t> permutedShape(elementOffsets.size());
631 // Compute the source offsets and shape.
632 for (auto indices : llvm::enumerate(permutation)) {
633 permutedOffsets[indices.value()] = elementOffsets[indices.index()];
634 permutedShape[indices.value()] = (*targetShape)[indices.index()];
635 }
636 Value slicedOperand =
637 rewriter.createOrFold<vector::ExtractStridedSliceOp>(
638 loc, transposeOp.getVector(), permutedOffsets, permutedShape,
639 strides);
640 Value transposedSlice = rewriter.createOrFold<vector::TransposeOp>(
641 loc, slicedOperand, permutation);
642 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
643 loc, transposedSlice, result, elementOffsets, strides);
644 }
645 rewriter.replaceOp(transposeOp, result);
646 return success();
647 }
648
649private:
650 vector::UnrollVectorOptions options;
651};
652
653struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
654 UnrollGatherPattern(MLIRContext *context,
655 const vector::UnrollVectorOptions &options,
656 PatternBenefit benefit = 1)
657 : OpRewritePattern<vector::GatherOp>(context, benefit), options(options) {
658 }
659
660 LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
661 PatternRewriter &rewriter) const override {
662 VectorType sourceVectorType = gatherOp.getVectorType();
663 if (sourceVectorType.getRank() == 0)
664 return failure();
665 auto targetShape = getTargetShape(options, gatherOp);
666 if (!targetShape)
667 return failure();
668 SmallVector<int64_t> strides(targetShape->size(), 1);
669 Location loc = gatherOp.getLoc();
670 ArrayRef<int64_t> originalSize = gatherOp.getVectorType().getShape();
671
672 // Prepare the result vector;
673 Value result =
674 arith::ConstantOp::create(rewriter, loc, sourceVectorType,
675 rewriter.getZeroAttr(sourceVectorType));
676 auto targetType =
677 VectorType::get(*targetShape, sourceVectorType.getElementType());
678
679 SmallVector<int64_t> loopOrder =
680 getUnrollOrder(originalSize.size(), gatherOp, options);
681 for (SmallVector<int64_t> elementOffsets :
682 StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
683 // To get the unrolled gather, extract the same slice based on the
684 // decomposed shape from each of the index, mask, and pass-through
685 // vectors.
686 Value indexSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
687 loc, gatherOp.getIndices(), elementOffsets, *targetShape, strides);
688 Value maskSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
689 loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
690 Value passThruSubVec =
691 rewriter.createOrFold<vector::ExtractStridedSliceOp>(
692 loc, gatherOp.getPassThru(), elementOffsets, *targetShape,
693 strides);
694 auto slicedGather = vector::GatherOp::create(
695 rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getOffsets(),
696 indexSubVec, maskSubVec, passThruSubVec);
697
698 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
699 loc, slicedGather, result, elementOffsets, strides);
700 }
701 rewriter.replaceOp(gatherOp, result);
702 return success();
703 }
704
705private:
706 vector::UnrollVectorOptions options;
707};
708
709struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> {
710 UnrollLoadPattern(MLIRContext *context,
711 const vector::UnrollVectorOptions &options,
712 PatternBenefit benefit = 1)
713 : OpRewritePattern<vector::LoadOp>(context, benefit), options(options) {}
714
715 LogicalResult matchAndRewrite(vector::LoadOp loadOp,
716 PatternRewriter &rewriter) const override {
717 VectorType vecType = loadOp.getVectorType();
718
719 auto targetShape = getTargetShape(options, loadOp);
720 if (!targetShape)
721 return failure();
722
723 Location loc = loadOp.getLoc();
724 ArrayRef<int64_t> originalShape = vecType.getShape();
725 SmallVector<int64_t> strides(targetShape->size(), 1);
726
727 Value result = arith::ConstantOp::create(rewriter, loc, vecType,
728 rewriter.getZeroAttr(vecType));
729
730 SmallVector<int64_t> loopOrder =
731 getUnrollOrder(originalShape.size(), loadOp, options);
732
733 auto targetVecType =
734 VectorType::get(*targetShape, vecType.getElementType());
735
736 for (SmallVector<int64_t> offsets :
737 StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
738 SmallVector<Value> indices =
739 sliceLoadStoreIndices(rewriter, loc, loadOp.getIndices(), offsets);
740 Value slicedLoad = vector::LoadOp::create(rewriter, loc, targetVecType,
741 loadOp.getBase(), indices);
742 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
743 loc, slicedLoad, result, offsets, strides);
744 }
745 rewriter.replaceOp(loadOp, result);
746 return success();
747 }
748
749private:
750 vector::UnrollVectorOptions options;
751};
752
753struct UnrollStorePattern : public OpRewritePattern<vector::StoreOp> {
754 UnrollStorePattern(MLIRContext *context,
755 const vector::UnrollVectorOptions &options,
756 PatternBenefit benefit = 1)
757 : OpRewritePattern<vector::StoreOp>(context, benefit), options(options) {}
758
759 LogicalResult matchAndRewrite(vector::StoreOp storeOp,
760 PatternRewriter &rewriter) const override {
761 VectorType vecType = storeOp.getVectorType();
762
763 auto targetShape = getTargetShape(options, storeOp);
764 if (!targetShape)
765 return failure();
766
767 Location loc = storeOp.getLoc();
768 ArrayRef<int64_t> originalShape = vecType.getShape();
769 SmallVector<int64_t> strides(targetShape->size(), 1);
770
771 Value base = storeOp.getBase();
772 Value vector = storeOp.getValueToStore();
773
774 SmallVector<int64_t> loopOrder =
775 getUnrollOrder(originalShape.size(), storeOp, options);
776
777 for (SmallVector<int64_t> offsets :
778 StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
779 SmallVector<Value> indices =
780 sliceLoadStoreIndices(rewriter, loc, storeOp.getIndices(), offsets);
781 Value slice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
782 loc, vector, offsets, *targetShape, strides);
783 vector::StoreOp::create(rewriter, loc, slice, base, indices);
784 }
785 rewriter.eraseOp(storeOp);
786 return success();
787 }
788
789private:
790 vector::UnrollVectorOptions options;
791};
792
793struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
794 UnrollBroadcastPattern(MLIRContext *context,
795 const vector::UnrollVectorOptions &options,
796 PatternBenefit benefit = 1)
797 : OpRewritePattern<vector::BroadcastOp>(context, benefit),
798 options(options) {}
799
800 LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
801 PatternRewriter &rewriter) const override {
802 auto targetShape = getTargetShape(options, broadcastOp);
803 if (!targetShape)
804 return failure();
805
806 Location loc = broadcastOp.getLoc();
807 VectorType srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
808 VectorType resType = broadcastOp.getResultVectorType();
809 VectorType targetType =
810 resType.cloneWith(*targetShape, resType.getElementType());
811 Value result = arith::ConstantOp::create(rewriter, loc, resType,
812 rewriter.getZeroAttr(resType));
813
814 SmallVector<int64_t> originalShape = *broadcastOp.getShapeForUnroll();
815 SmallVector<int64_t> strides(originalShape.size(), 1);
816
817 for (SmallVector<int64_t> offsets :
818 StaticTileOffsetRange(originalShape, *targetShape)) {
819 Value newSrc;
820 if (!srcType) {
821 // Scalar to vector broadcast.
822 newSrc = broadcastOp.getSource();
823 } else {
824 // Vector to vector broadcast.
825 int64_t rank = srcType.getRank();
826 SmallVector<int64_t> srcOffsets(offsets.end() - rank, offsets.end());
827 SmallVector<int64_t> srcShape(targetShape->end() - rank,
828 targetShape->end());
829 SmallVector<int64_t> srcStrides(strides.end() - rank, strides.end());
830 // adjust the offset and shape for src if the corresponding dim is 1.
831 for (int64_t i = 0; i < rank; ++i) {
832 if (srcType.getDimSize(i) == 1) {
833 srcOffsets[i] = 0;
834 srcShape[i] = 1;
835 }
836 }
837 newSrc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
838 loc, broadcastOp.getSource(), srcOffsets, srcShape, srcStrides);
839 }
840
841 Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, broadcastOp,
842 newSrc, targetType);
843
844 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
845 loc, newOp->getResult(0), result, offsets, strides);
846 }
847
848 rewriter.replaceOp(broadcastOp, result);
849 return success();
850 }
851
852private:
853 vector::UnrollVectorOptions options;
854};
855
856/// Unrolls 2 or more dimensional `vector.to_elements` ops by unrolling the
857/// outermost dimension of the operand. For example:
858///
859/// ```
860/// %0:4 = vector.to_elements %v : vector<2x2xf32>
861///
862/// ==>
863///
864/// %v0 = vector.extract %v[0] : vector<2x2xf32> from vector<2x2x2xf32>
865/// %v1 = vector.extract %v[1] : vector<2x2xf32> from vector<2x2x2xf32>
866/// %0:4 = vector.to_elements %v0 : vector<2x2xf32>
867/// %1:4 = vector.to_elements %v1 : vector<2x2xf32>
868/// ```
869///
870/// When this pattern is applied until a fixed-point is reached,
871/// this will produce a sequence of 1-d from_elements
872/// ops.
873struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
874 UnrollToElements(MLIRContext *context,
875 const vector::UnrollVectorOptions &options,
876 PatternBenefit benefit = 1)
877 : OpRewritePattern<vector::ToElementsOp>(context, benefit),
878 options(options) {}
879
880 LogicalResult matchAndRewrite(vector::ToElementsOp op,
881 PatternRewriter &rewriter) const override {
882
883 TypedValue<VectorType> source = op.getSource();
884 FailureOr<SmallVector<Value>> result =
885 vector::unrollVectorValue(source, rewriter);
886 if (failed(result)) {
887 return failure();
888 }
889 SmallVector<Value> vectors = *result;
890
891 SmallVector<Value> results;
892 for (Value vector : vectors) {
893 auto subElements =
894 vector::ToElementsOp::create(rewriter, op.getLoc(), vector);
895 llvm::append_range(results, subElements.getResults());
896 }
897 rewriter.replaceOp(op, results);
898 return success();
899 }
900
901private:
902 vector::UnrollVectorOptions options;
903};
904
905/// This pattern unrolls `vector.step` operations according to the provided
906/// target unroll shape. It decomposes a large step vector into smaller step
907/// vectors (segments) and assembles the result by inserting each computed
908/// segment into the appropriate offset of the original vector.
909///
910/// The pattern does not support scalable vectors and will fail to match them.
911///
912/// For each segment, it adds the base step vector and the segment's offset,
913/// then inserts the result into the output vector at the corresponding
914/// position.
915///
916/// Example:
917/// Given a step operation:
918/// %0 = vector.step : vector<8xindex>
919///
920/// and a target unroll shape of <4>, the pattern produces:
921///
922/// %base = vector.step : vector<4xindex>
923/// %zero = arith.constant dense<0> : vector<8xindex>
924/// %result0 = vector.insert_strided_slice %base, %zero
925/// {offsets = [0], strides = [1]} : vector<4xindex> into vector<8xindex>
926/// %offset = arith.constant dense<4> : vector<4xindex>
927/// %segment1 = arith.addi %base, %offset : vector<4xindex>
928/// %result1 = vector.insert_strided_slice %segment1, %result0
929/// {offsets = [4], strides = [1]} : vector<4xindex> into vector<8xindex>
930///
931struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> {
932 UnrollStepPattern(MLIRContext *context,
933 const vector::UnrollVectorOptions &options,
934 PatternBenefit benefit = 1)
935 : OpRewritePattern<vector::StepOp>(context, benefit), options(options) {}
936
937 LogicalResult matchAndRewrite(vector::StepOp stepOp,
938 PatternRewriter &rewriter) const override {
939 std::optional<SmallVector<int64_t>> targetShape =
940 getTargetShape(options, stepOp);
941 if (!targetShape)
942 return failure();
943
944 VectorType vecType = stepOp.getType();
945 if (vecType.isScalable()) {
946 // Scalable vectors are not supported by this pattern.
947 return failure();
948 }
949 int64_t originalSize = vecType.getShape()[0];
950 Location loc = stepOp.getLoc();
951 SmallVector<int64_t> strides(1, 1);
952
953 Value result = arith::ConstantOp::create(rewriter, loc, vecType,
954 rewriter.getZeroAttr(vecType));
955
956 auto targetVecType =
957 VectorType::get(*targetShape, vecType.getElementType());
958 Value baseStep = vector::StepOp::create(rewriter, loc, targetVecType);
959 for (const SmallVector<int64_t> &offsets :
960 StaticTileOffsetRange({originalSize}, *targetShape)) {
961 Value bcastOffset = arith::ConstantOp::create(
962 rewriter, loc, targetVecType,
964 targetVecType,
965 IntegerAttr::get(targetVecType.getElementType(), offsets[0])));
966 Value tileStep =
967 arith::AddIOp::create(rewriter, loc, baseStep, bcastOffset);
968
969 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
970 loc, tileStep, result, offsets, strides);
971 }
972 rewriter.replaceOp(stepOp, result);
973 return success();
974 }
975
976private:
977 vector::UnrollVectorOptions options;
978};
979
980/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
981/// outermost dimension. For example:
982/// ```
983/// %v = vector.from_elements %e0, %e1, %e2, %e3, %e4, %e5 : vector<2x3xf32>
984///
985/// ==>
986///
987/// %0 = ub.poison : vector<2x3xf32>
988/// %v0 = vector.from_elements %e0, %e1, %e2 : vector<3xf32>
989/// %1 = vector.insert %v0, %0 [0] : vector<3xf32> into vector<2x3xf32>
990/// %v1 = vector.from_elements %e3, %e4, %e5 : vector<3xf32>
991/// %v = vector.insert %v1, %1 [1] : vector<3xf32> into vector<2x3xf32>
992/// ```
993///
994/// When this pattern is applied until a fixed-point is reached,
995/// this will produce a sequence of 1-d from_elements
996/// ops.
997struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
998 UnrollFromElements(MLIRContext *context,
999 const vector::UnrollVectorOptions &options,
1000 PatternBenefit benefit = 1)
1001 : OpRewritePattern<vector::FromElementsOp>(context, benefit),
1002 options(options) {}
1003
1004 LogicalResult matchAndRewrite(vector::FromElementsOp op,
1005 PatternRewriter &rewriter) const override {
1006 ValueRange allElements = op.getElements();
1007
1008 auto unrollFromElementsFn = [&](PatternRewriter &rewriter, Location loc,
1009 VectorType subTy, int64_t index) {
1010 size_t subTyNumElements = subTy.getNumElements();
1011 assert((index + 1) * subTyNumElements <= allElements.size() &&
1012 "out of bounds");
1013 ValueRange subElements =
1014 allElements.slice(index * subTyNumElements, subTyNumElements);
1015 return vector::FromElementsOp::create(rewriter, loc, subTy, subElements);
1016 };
1017
1018 return unrollVectorOp(op, rewriter, unrollFromElementsFn);
1019 }
1020
1021private:
1022 vector::UnrollVectorOptions options;
1023};
1024
1025/// This pattern unrolls `vector.create_mask` operations into smaller mask
1026/// operations based on the target unroll shape. Each unrolled slice computes
1027/// its local mask size in each dimension (d) as:
1028/// min(max(originalMaskSize[d] - offset[d], 0), unrolledDimSize[d]).
1029/// Example:
1030/// Given a create_mask operation:
1031/// %0 = vector.create_mask %c6, %c10 : vector<8x16xi1> // mask first 6x10
1032/// elements
1033///
1034/// and a target unroll shape of <4x8>, the pattern produces:
1035///
1036/// %false = arith.constant dense<false> : vector<8x16xi1>
1037///
1038/// Slice [0,0]:
1039/// mask size = min(max(6-0, 0), 4) x min(max(10-0, 0), 8) = 4x8
1040/// %mask00 = vector.create_mask %c4, %c8 : vector<4x8xi1>
1041/// %r0 = vector.insert_strided_slice %mask00, %false [0, 0], [1, 1]
1042/// : vector<4x8xi1> into vector<8x16xi1>
1043/// Slice [0,8]:
1044/// mask size = min(max(6-0, 0), 4) x min(max(10-8, 0), 8) = 4x2
1045/// %mask01 = vector.create_mask %c4, %c2 : vector<4x8xi1>
1046/// %r1 = vector.insert_strided_slice %mask01, %r0 [0, 8], [1, 1]
1047/// : vector<4x8xi1> into vector<8x16xi1>
1048/// Slice [4,0]:
1049/// mask size = min(max(6-4, 0), 4) x min(max(10-0, 0), 8) = 2x8
1050/// %mask10 = vector.create_mask %c2, %c8 : vector<4x8xi1>
1051/// %r2 = vector.insert_strided_slice %mask10, %r1 [4, 0], [1, 1]
1052/// : vector<4x8xi1> into vector<8x16xi1>
1053/// Slice [4,8]:
1054/// mask size = min(max(6-4, 0), 4) x min(max(10-8, 0), 8) = 2x2
1055/// %mask11 = vector.create_mask %c2, %c2 : vector<4x8xi1>
1056/// %result = vector.insert_strided_slice %mask11, %r2 [4, 8], [1, 1]
1057/// : vector<4x8xi1> into vector<8x16xi1>
1058struct UnrollCreateMaskPattern : public OpRewritePattern<vector::CreateMaskOp> {
1059 UnrollCreateMaskPattern(MLIRContext *context,
1060 const vector::UnrollVectorOptions &options,
1061 PatternBenefit benefit = 1)
1062 : OpRewritePattern<vector::CreateMaskOp>(context, benefit),
1063 options(options) {}
1064
1065 LogicalResult matchAndRewrite(vector::CreateMaskOp createMaskOp,
1066 PatternRewriter &rewriter) const override {
1067 auto targetShape = getTargetShape(options, createMaskOp);
1068 if (!targetShape)
1069 return failure();
1070
1071 VectorType resultType = createMaskOp.getVectorType();
1072 SmallVector<int64_t> originalSize = *createMaskOp.getShapeForUnroll();
1073 Location loc = createMaskOp.getLoc();
1074
1075 Value result = arith::ConstantOp::create(rewriter, loc, resultType,
1076 rewriter.getZeroAttr(resultType));
1077 VectorType targetVectorType =
1078 VectorType::get(*targetShape, rewriter.getI1Type());
1079 SmallVector<int64_t> strides(targetShape->size(), 1);
1080
1081 // In each dimension (d), each unrolled vector computes its mask size as:
1082 // min(max(originalMaskOperands[d] - offset[d], 0), unrolledDimSize[d]).
1083 for (SmallVector<int64_t> offsets :
1084 StaticTileOffsetRange(originalSize, *targetShape)) {
1085 SmallVector<Value> unrolledOperands;
1086
1087 for (auto [i, originalMaskOperand] :
1088 llvm::enumerate(createMaskOp.getOperands())) {
1089 Value offsetVal =
1090 arith::ConstantIndexOp::create(rewriter, loc, offsets[i]);
1091 Value adjustedMaskSize = rewriter.createOrFold<arith::SubIOp>(
1092 loc, originalMaskOperand, offsetVal);
1093 Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
1094 Value unrolledDimSize =
1095 arith::ConstantIndexOp::create(rewriter, loc, (*targetShape)[i]);
1096 Value nonNegative =
1097 rewriter.createOrFold<arith::MaxSIOp>(loc, adjustedMaskSize, zero);
1098 Value unrolledOperand = rewriter.createOrFold<arith::MinSIOp>(
1099 loc, nonNegative, unrolledDimSize);
1100 unrolledOperands.push_back(unrolledOperand);
1101 }
1102
1103 auto unrolledMask = rewriter.createOrFold<vector::CreateMaskOp>(
1104 loc, targetVectorType, unrolledOperands);
1105 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
1106 loc, unrolledMask, result, offsets, strides);
1107 }
1108 rewriter.replaceOp(createMaskOp, result);
1109 return success();
1110 }
1111
1112private:
1113 vector::UnrollVectorOptions options;
1114};
1115
1116/// This pattern unrolls `vector.constant_mask` operations into smaller mask
1117/// operations based on the target unroll shape. Each unrolled slice computes
1118/// whether its elements should be masked based on the original mask dimensions
1119/// and the slice's offset position.
1120///
1121/// Example:
1122/// Given a constant_mask operation:
1123/// %0 = vector.constant_mask [6, 10] : vector<8x16xi1>
1124///
1125/// and a target unroll shape of <4x8>, the pattern produces:
1126///
1127/// %false = arith.constant dense<false> : vector<8x16xi1>
1128///
1129/// Slice [0,0]: elements [0:4, 0:8] - fully within [6, 10] bounds
1130/// %mask00 = vector.constant_mask [4, 8] : vector<4x8xi1>
1131/// %r0 = vector.insert_strided_slice %mask00, %false [0, 0], [1, 1]
1132/// : vector<4x8xi1> into vector<8x16xi1>
1133///
1134/// Slice [0,8]: elements [0:4, 8:16] - partially within bounds
1135/// %mask01 = vector.constant_mask [4, 2] : vector<4x8xi1>
1136/// %r1 = vector.insert_strided_slice %mask01, %r0 [0, 8], [1, 1]
1137/// : vector<4x8xi1> into vector<8x16xi1>
1138///
1139/// Slice [4,0]: elements [4:8, 0:8] - partially within bounds
1140/// %mask10 = vector.constant_mask [2, 8] : vector<4x8xi1>
1141/// %r2 = vector.insert_strided_slice %mask10, %r1 [4, 0], [1, 1]
1142/// : vector<4x8xi1> into vector<8x16xi1>
1143///
1144/// Slice [4,8]: elements [4:8, 8:16] - partially within bounds
1145/// %mask11 = vector.constant_mask [2, 2] : vector<4x8xi1>
1146/// %result = vector.insert_strided_slice %mask11, %r2 [4, 8], [1, 1]
1147/// : vector<4x8xi1> into vector<8x16xi1>
1148struct UnrollConstantMaskPattern
1149 : public OpRewritePattern<vector::ConstantMaskOp> {
1150 UnrollConstantMaskPattern(MLIRContext *context,
1151 const vector::UnrollVectorOptions &options,
1152 PatternBenefit benefit = 1)
1153 : OpRewritePattern<vector::ConstantMaskOp>(context, benefit),
1154 options(options) {}
1155
1156 LogicalResult matchAndRewrite(vector::ConstantMaskOp constantMaskOp,
1157 PatternRewriter &rewriter) const override {
1158 std::optional<SmallVector<int64_t>> targetShape =
1159 getTargetShape(options, constantMaskOp);
1160 if (!targetShape)
1161 return failure();
1162
1163 VectorType resultType = constantMaskOp.getVectorType();
1164 SmallVector<int64_t> originalSize = *constantMaskOp.getShapeForUnroll();
1165 Location loc = constantMaskOp.getLoc();
1166
1167 Value result = arith::ConstantOp::create(rewriter, loc, resultType,
1168 rewriter.getZeroAttr(resultType));
1169 VectorType targetVectorType =
1170 VectorType::get(*targetShape, rewriter.getI1Type());
1171 SmallVector<int64_t> strides(targetShape->size(), 1);
1172
1173 // In each dimension (d), each unrolled vector computes its mask size as:
1174 // min(max(originalMaskDim[d] - offset[d], 0), unrolledDimSize[d]).
1175 for (const SmallVector<int64_t> &offsets :
1176 StaticTileOffsetRange(originalSize, *targetShape)) {
1177 SmallVector<int64_t> unrolledMaskDims;
1178
1179 for (auto [i, originalMaskDim] :
1180 llvm::enumerate(constantMaskOp.getMaskDimSizes())) {
1181 // Calculate how many elements in this dimension should be masked
1182 // for this particular slice
1183 int64_t adjustedMaskSize =
1184 std::max(originalMaskDim - offsets[i], static_cast<int64_t>(0));
1185 int64_t unrolledMaskDim =
1186 std::min(adjustedMaskSize, static_cast<int64_t>((*targetShape)[i]));
1187 unrolledMaskDims.push_back(unrolledMaskDim);
1188 }
1189
1190 auto unrolledMask = rewriter.createOrFold<vector::ConstantMaskOp>(
1191 loc, targetVectorType, unrolledMaskDims);
1192 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
1193 loc, unrolledMask, result, offsets, strides);
1194 }
1195 rewriter.replaceOp(constantMaskOp, result);
1196 return success();
1197 }
1198
1199private:
1200 vector::UnrollVectorOptions options;
1201};
1202
1203/// Checks whether extractShape is a contiguous slice of shape.
1204/// For extractShape to be contiguous in shape:
1205/// 1) All but the leading dimension of extractShape and shape must match
1206/// exactly. 2) The total number of elements in shape must be evenly divisible
1207/// by
1208/// the total number of elements in extractShape.
1209/// Examples:
1210/// isContiguous([4, 4], [8, 4]) == true
1211/// isContiguous([2, 4], [8, 4]) == true
1212/// isContiguous([2, 2], [8, 4]) == false
1213/// Removes leading unit dimensions to handle cases like:
1214/// isContiguous([1, 16], [1, 32]) == true
1215static bool isContiguous(ArrayRef<int64_t> extractShape,
1217
1218 if (extractShape.empty() || shape.empty() ||
1219 extractShape.size() > shape.size())
1220 return false;
1221
1222 while (extractShape.size() > 1 && extractShape.front() == 1)
1223 extractShape = extractShape.drop_front();
1224
1225 while (shape.size() > 1 && shape.front() == 1) {
1226 shape = shape.drop_front();
1227 }
1228
1229 size_t rankDiff = shape.size() - extractShape.size();
1230 if (!llvm::equal(extractShape.drop_front(), shape.drop_front(rankDiff + 1)))
1231 return false;
1232
1233 int64_t extractElements = ShapedType::getNumElements(extractShape);
1234 int64_t shapeElements = ShapedType::getNumElements(shape);
1235 return shapeElements % extractElements == 0;
1236}
1237
1238/// Determines what shape to use with `vector.extract_strided_slice` to extract
1239/// a contiguous memory region from a source vector. The extraction must be
1240/// contiguous and contain exactly the specified number of elements. If such an
1241/// extraction shape cannot be determined, returns std::nullopt.
1242/// EXAMPLE 1:
1243/// sourceShape = [16], targetElements = 8
1244/// Working right-to-left:
1245/// - Take min(8, 16) = 8 from only dim → extractShape = [8],
1246/// remaining = 8/8 = 1
1247/// Result: [8]
1248///
1249/// EXAMPLE 2:
1250/// sourceShape = [4, 4], targetElements = 8
1251/// Working right-to-left:
1252/// - Take min(8, 4) = 4 from last dim → extractShape = [4],
1253/// remaining = 8/4 = 2
1254/// - Take min(2, 4) = 2 from first dim → extractShape = [2, 4],
1255/// remaining = 2/2 = 1
1256/// Result: [2, 4]
1257static std::optional<SmallVector<int64_t>>
1258calculateSourceExtractShape(ArrayRef<int64_t> sourceShape,
1259 int64_t targetElements) {
1260 SmallVector<int64_t> extractShape;
1261 int64_t remainingElements = targetElements;
1262
1263 // Build extract shape from innermost dimension outward to ensure contiguity.
1264 for (int i = sourceShape.size() - 1; i >= 0 && remainingElements > 1; --i) {
1265 int64_t takeFromDim = std::min(remainingElements, sourceShape[i]);
1266 extractShape.insert(extractShape.begin(), takeFromDim);
1267
1268 if (remainingElements % takeFromDim != 0)
1269 return std::nullopt; // Not evenly divisible.
1270 remainingElements /= takeFromDim;
1271 }
1272
1273 // Fill remaining dimensions with 1.
1274 while (extractShape.size() < sourceShape.size())
1275 extractShape.insert(extractShape.begin(), 1);
1276
1277 if (ShapedType::getNumElements(extractShape) != targetElements)
1278 return std::nullopt;
1279
1280 return extractShape;
1281}
1282
1283// Convert result offsets to source offsets via linear position.
1285calculateSourceOffsets(ArrayRef<int64_t> resultOffsets,
1286 ArrayRef<int64_t> sourceShape,
1287 ArrayRef<int64_t> resultShape) {
1288 // Convert result offsets to linear position.
1289 int64_t linearIndex = linearize(resultOffsets, computeStrides(resultShape));
1290 // Convert linear position to source offsets.
1291 return delinearize(linearIndex, computeStrides(sourceShape));
1292}
1293
1294/// This pattern unrolls `vector.shape_cast` operations according to the
1295/// provided target unroll shape. It unrolls a large shape cast into smaller
1296/// shape casts by extracting contiguous slices from the source vector, casting
1297/// each slice to the target shape, and assembling the result by inserting each
1298/// computed segment into the appropriate offset of the result vector.
1299///
1300/// This pattern only applies when contiguous slices can be extracted from the
1301/// source vector and inserted into the result vector such that each slice
1302/// remains a valid vector (and not decompose to scalars). In these cases, the
1303/// unrolling proceeds as:
1304/// vector.extract_strided_slice -> vector.shape_cast (on the slice) ->
1305/// vector.insert_strided_slice.
1306///
1307/// Example:
1308/// Given a shape cast operation:
1309/// %0 = vector.shape_cast %src : vector<8x2xf32> to vector<4x4xf32>
1310///
1311/// and a target unroll shape of <2x4>, the pattern produces:
1312///
1313/// %zero = arith.constant dense<0.0> : vector<4x4xf32>
1314/// %s0 = vector.extract_strided_slice %src [0, 0], [4, 2], [1, 1]
1315/// : vector<8x2xf32> to vector<4x2xf32>
1316/// %sc0 = vector.shape_cast %s0 : vector<4x2xf32> to vector<2x4xf32>
1317/// %i0 = vector.insert_strided_slice %sc0, %zero [0, 0], [1, 1]
1318/// : vector<2x4xf32> into vector<4x4xf32>
1319/// %s1 = vector.extract_strided_slice %src [4, 0], [4, 2], [1, 1]
1320/// : vector<8x2xf32> to vector<4x2xf32>
1321/// %sc1 = vector.shape_cast %s1 : vector<4x2xf32> to vector<2x4xf32>
1322/// %i1 = vector.insert_strided_slice %sc1, %i0 [2, 0], [1, 1]
1323/// : vector<2x4xf32> into vector<4x4xf32>
1324///
1325struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
1326 UnrollShapeCastPattern(MLIRContext *context,
1327 const vector::UnrollVectorOptions &options,
1328 PatternBenefit benefit = 1)
1329 : OpRewritePattern<vector::ShapeCastOp>(context, benefit),
1330 options(options) {}
1331
1332 LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
1333 PatternRewriter &rewriter) const override {
1334 std::optional<SmallVector<int64_t>> targetShape =
1335 getTargetShape(options, shapeCastOp);
1336 if (!targetShape)
1337 return failure();
1338
1339 VectorType sourceType = shapeCastOp.getSourceVectorType();
1340 VectorType resultType = shapeCastOp.getResultVectorType();
1341 ArrayRef<int64_t> sourceShape = sourceType.getShape();
1342 ArrayRef<int64_t> resultShape = resultType.getShape();
1343
1344 if (!isContiguous(*targetShape, resultShape))
1345 return rewriter.notifyMatchFailure(
1346 shapeCastOp, "Only supports cases where target shape is "
1347 "contiguous in result vector shape");
1348
1349 int64_t targetElements = ShapedType::getNumElements(*targetShape);
1350
1351 // Calculate the shape to extract from source.
1352 std::optional<SmallVector<int64_t>> extractShape =
1353 calculateSourceExtractShape(sourceShape, targetElements);
1354 if (!extractShape)
1355 return rewriter.notifyMatchFailure(
1356 shapeCastOp,
1357 "cannot extract target number of elements contiguously from source");
1358
1359 Location loc = shapeCastOp.getLoc();
1360
1361 // Create result vector initialized to zero.
1362 Value result = arith::ConstantOp::create(rewriter, loc, resultType,
1363 rewriter.getZeroAttr(resultType));
1364
1365 VectorType targetType =
1366 VectorType::get(*targetShape, sourceType.getElementType());
1367
1368 SmallVector<int64_t> extractStrides(extractShape->size(), 1);
1369 SmallVector<int64_t> insertStrides(targetShape->size(), 1);
1370
1371 for (SmallVector<int64_t> resultOffsets :
1372 StaticTileOffsetRange(resultShape, *targetShape)) {
1373 SmallVector<int64_t> sourceOffsets =
1374 calculateSourceOffsets(resultOffsets, sourceShape, resultShape);
1375 Value sourceChunk = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
1376 loc, shapeCastOp.getSource(), sourceOffsets, *extractShape,
1378 Value targetChunk = rewriter.createOrFold<vector::ShapeCastOp>(
1379 loc, targetType, sourceChunk);
1380 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
1381 loc, targetChunk, result, resultOffsets, insertStrides);
1382 }
1383
1384 rewriter.replaceOp(shapeCastOp, result);
1385 return success();
1386 }
1387
1388private:
1389 vector::UnrollVectorOptions options;
1390};
1391
1392} // namespace
1393
1394void mlir::vector::populateVectorUnrollPatterns(
1396 PatternBenefit benefit) {
1397 patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
1398 UnrollContractionPattern, UnrollElementwisePattern,
1399 UnrollReductionPattern, UnrollMultiReductionPattern,
1400 UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
1401 UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
1402 UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern,
1403 UnrollCreateMaskPattern, UnrollConstantMaskPattern>(
1404 patterns.getContext(), options, benefit);
1405}
1406
1407void mlir::vector::populateVectorToElementsUnrollPatterns(
1408 RewritePatternSet &patterns, PatternBenefit benefit) {
1409 patterns.add<UnrollToElements>(patterns.getContext(), UnrollVectorOptions(),
1410 benefit);
1411}
1412
1413void mlir::vector::populateVectorFromElementsUnrollPatterns(
1414 RewritePatternSet &patterns, PatternBenefit benefit) {
1415 patterns.add<UnrollFromElements>(patterns.getContext(), UnrollVectorOptions(),
1416 benefit);
1417}
return success()
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
lhs
static llvm::ManagedStatic< PassManagerOptions > options
static SmallVector< Value > sliceLoadStoreIndices(PatternRewriter &rewriter, Location loc, OperandRange originalIndices, ArrayRef< int64_t > offsets)
static SmallVector< Value > sliceTransferIndices(ArrayRef< int64_t > elementOffsets, ArrayRef< Value > indices, AffineMap permutationMap, Location loc, OpBuilder &builder)
Compute the indices of the slice index for a transfer op.
static std::optional< SmallVector< int64_t > > getTargetShape(const vector::UnrollVectorOptions &options, Operation *op)
Return the target shape for unrolling for the given op.
static SmallVector< int64_t > getUnrollOrder(unsigned numLoops, Operation *op, const vector::UnrollVectorOptions &options)
static Operation * cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc, Operation *op, ArrayRef< Value > operands, ArrayRef< Type > resultTypes)
Base type for affine expression.
Definition AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
ArrayRef< AffineExpr > getResults() const
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:328
IntegerType getI1Type()
Definition Builders.cpp:57
MLIRContext * getContext() const
Definition Builders.h:56
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:209
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition Builders.cpp:461
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:538
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:433
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:409
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:116
result_range getResults()
Definition Operation.h:441
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:430
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePattern is the common base class for all DAG to DAG replacements.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
FailureOr< SmallVector< Value > > unrollVectorValue(TypedValue< VectorType >, RewriterBase &)
Generic utility for unrolling values of type vector<NxAxBx...> to N values of type vector<AxBx....
LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter, UnrollVectorOpFn unrollFn)
Include the generated interface declarations.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition AffineMap.h:675
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Return the linearized index of 'offsets' w.r.t.
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Options that control the vector unrolling.