VectorUnroll.cpp
//===- VectorUnroll.cpp - patterns to do vector unrolling -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to do vector unrolling and vector distribution.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/InterleavedRange.h"
#include <optional>

#define DEBUG_TYPE "vector-unroll"

using namespace mlir;
using namespace mlir::vector;

/// Compute the indices of the slice at the given `elementOffsets` for a
/// transfer op.
static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
                                               ArrayRef<Value> indices,
                                               AffineMap permutationMap,
                                               Location loc,
                                               OpBuilder &builder) {
  MLIRContext *ctx = builder.getContext();
  auto isBroadcast = [](AffineExpr expr) {
    if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
      return constExpr.getValue() == 0;
    return false;
  };
  // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
  SmallVector<Value> slicedIndices(indices);
  for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
    if (isBroadcast(dim.value()))
      continue;
    unsigned pos = cast<AffineDimExpr>(dim.value()).getPosition();
    auto expr = getAffineDimExpr(0, builder.getContext()) +
                getAffineConstantExpr(elementOffsets[dim.index()], ctx);
    auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
    slicedIndices[pos] =
        affine::AffineApplyOp::create(builder, loc, map, indices[pos]);
  }
  return slicedIndices;
}
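
// Editorial note: an illustrative sketch, not from the original source. With
// indices (%i, %j), elementOffsets [2, 4], and the identity permutation map
// (d0, d1) -> (d0, d1), this helper produces roughly:
//   %i2 = affine.apply affine_map<(d0) -> (d0 + 2)>(%i)
//   %j4 = affine.apply affine_map<(d0) -> (d0 + 4)>(%j)
// Broadcast dimensions (constant-0 results in the permutation map) keep their
// original index unchanged.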

// Compute the new indices by adding `offsets` to `originalIndices`.
// If m < n (m = offsets.size(), n = originalIndices.size()),
// then only the trailing m values in `originalIndices` are updated.
static SmallVector<Value> sliceLoadStoreIndices(PatternRewriter &rewriter,
                                                Location loc,
                                                OperandRange originalIndices,
                                                ArrayRef<int64_t> offsets) {
  assert(offsets.size() <= originalIndices.size() &&
         "Offsets should not exceed the number of original indices");
  SmallVector<Value> indices(originalIndices);

  auto start = indices.size() - offsets.size();
  for (auto [i, offset] : llvm::enumerate(offsets)) {
    if (offset != 0) {
      indices[start + i] = arith::AddIOp::create(
          rewriter, loc, originalIndices[start + i],
          arith::ConstantIndexOp::create(rewriter, loc, offset));
    }
  }
  return indices;
}
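
// Editorial note: an illustrative sketch, not from the original source. With
// originalIndices (%i, %j, %k) and offsets [4, 8], only the trailing two
// indices are updated:
//   %c4 = arith.constant 4 : index
//   %j4 = arith.addi %j, %c4 : index
//   %c8 = arith.constant 8 : index
//   %k8 = arith.addi %k, %c8 : index
// yielding (%i, %j4, %k8); zero offsets reuse the original index unchanged.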

// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
                                              Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  return builder.create(loc, op->getName().getIdentifier(), operands,
                        resultTypes, op->getAttrs());
}

/// Return the target shape for unrolling for the given `op`. Return
/// std::nullopt if the op shouldn't be or cannot be unrolled.
static std::optional<SmallVector<int64_t>>
getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
  LDBG() << "Get unroll shape for op " << op->getName().getStringRef();
  if (options.filterConstraint && failed(options.filterConstraint(op))) {
    LDBG() << "--filter constraint failed -> BAIL";
    return std::nullopt;
  }
  assert(options.nativeShape &&
         "vector unrolling expects the native shape or native "
         "shape callback function to be set");
  auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
  if (!unrollableVectorOp) {
    LDBG() << "--not an unrollable op -> BAIL";
    return std::nullopt;
  }
  auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
  if (!maybeUnrollShape) {
    LDBG() << "--could not get shape of op " << *op << " -> BAIL";
    return std::nullopt;
  }
  LDBG() << "--vector op shape: " << llvm::interleaved(*maybeUnrollShape);

  std::optional<SmallVector<int64_t>> targetShape = options.nativeShape(op);
  if (!targetShape) {
    LDBG() << "--no unrolling target shape defined " << *op << " -> SKIP";
    return std::nullopt;
  }
  LDBG() << "--target shape: " << llvm::interleaved(*targetShape);

  auto maybeShapeRatio = computeShapeRatio(*maybeUnrollShape, *targetShape);
  if (!maybeShapeRatio) {
    LDBG() << "--could not compute integral shape ratio -> BAIL";
    return std::nullopt;
  }
  if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
    LDBG() << "--no unrolling needed -> SKIP";
    return std::nullopt;
  }
  LDBG() << "--found an integral shape ratio to unroll to -> SUCCESS";
  return targetShape;
}
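
// Editorial note: a worked example, not from the original source. For an op
// with unroll shape [8, 8] and a native shape of [4, 4], computeShapeRatio
// returns [2, 2], so the op is unrolled into four 4x4 tiles and this function
// returns [4, 4]. If the native shape equals the op's shape, the ratio is all
// ones and std::nullopt is returned (no unrolling needed).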

static SmallVector<int64_t>
getUnrollOrder(unsigned numLoops, Operation *op,
               const vector::UnrollVectorOptions &options) {
  SmallVector<int64_t> loopOrder =
      llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));
  if (options.traversalOrderCallback != nullptr) {
    std::optional<SmallVector<int64_t>> order =
        options.traversalOrderCallback(op);
    if (order) {
      loopOrder = std::move(*order);
    }
  }
  return loopOrder;
}
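
// Editorial note: an illustrative sketch, not from the original source. By
// default the traversal order is [0, 1, ..., numLoops-1]; a client can
// override it through UnrollVectorOptions, e.g. (assuming the
// setUnrollTraversalOrderFn hook available on UnrollVectorOptions):
//   options.setUnrollTraversalOrderFn(
//       [](Operation *op) -> std::optional<SmallVector<int64_t>> {
//         return SmallVector<int64_t>{2, 1, 0}; // innermost dimension first
//       });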

namespace {

struct UnrollTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  UnrollTransferReadPattern(MLIRContext *context,
                            const vector::UnrollVectorOptions &options,
                            PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (readOp.getTransferRank() == 0)
      return failure();
    if (readOp.getMask())
      return failure();
    auto targetShape = getTargetShape(options, readOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = readOp.getVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = readOp.getLoc();
    ArrayRef<int64_t> originalSize = sourceVectorType.getShape();

    // Prepare the result vector.
    Value result =
        arith::ConstantOp::create(rewriter, loc, sourceVectorType,
                                  rewriter.getZeroAttr(sourceVectorType));
    auto targetType =
        VectorType::get(*targetShape, sourceVectorType.getElementType());
    SmallVector<Value> originalIndices(readOp.getIndices().begin(),
                                       readOp.getIndices().end());
    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), readOp, options);
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      SmallVector<Value> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               readOp.getPermutationMap(), loc, rewriter);
      auto slicedRead = vector::TransferReadOp::create(
          rewriter, loc, targetType, readOp.getBase(), indices,
          readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
          readOp.getInBoundsAttr());

      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, slicedRead, result, elementOffsets, strides);
    }
    rewriter.replaceOp(readOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
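
// Editorial note: an illustrative IR sketch, not from the original source.
// Assuming a native shape of [2, 2],
//   %v = vector.transfer_read %src[%i, %j], %pad
//       : memref<?x?xf32>, vector<4x2xf32>
// is rewritten into two vector<2x2xf32> transfer_reads, the second with its
// row index advanced to %i + 2 via affine.apply, whose results are composed
// with vector.insert_strided_slice into a zero-initialized vector<4x2xf32>.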

struct UnrollTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  UnrollTransferWritePattern(MLIRContext *context,
                             const vector::UnrollVectorOptions &options,
                             PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (writeOp.getTransferRank() == 0)
      return failure();

    if (writeOp.getMask())
      return failure();
    auto targetShape = getTargetShape(options, writeOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = writeOp.getVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = writeOp.getLoc();
    ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
    // Bail out if rank(source) != rank(target). The main limitation here is
    // that `ExtractStridedSlice` requires the ranks of the input and output to
    // match. If needed, we can relax this later.
    if (originalSize.size() != targetShape->size())
      return rewriter.notifyMatchFailure(
          writeOp,
          "expected source input vector rank to match target shape rank");

    SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
                                       writeOp.getIndices().end());
    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), writeOp, options);
    Value resultTensor;
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      Value slicedVector = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
      SmallVector<Value> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               writeOp.getPermutationMap(), loc, rewriter);
      Operation *slicedWrite = vector::TransferWriteOp::create(
          rewriter, loc, slicedVector,
          resultTensor ? resultTensor : writeOp.getBase(), indices,
          writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
      // For the tensor case, update the destination for the next transfer
      // write.
      if (!slicedWrite->getResults().empty())
        resultTensor = slicedWrite->getResult(0);
    }
    if (resultTensor)
      rewriter.replaceOp(writeOp, resultTensor);
    else
      rewriter.eraseOp(writeOp);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

struct OffsetMapInfo {
  static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }

  static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }

  static unsigned getHashValue(const SmallVector<int64_t> &v) {
    return static_cast<unsigned>(llvm::hash_combine_range(v));
  }

  static bool isEqual(const SmallVector<int64_t> &lhs,
                      const SmallVector<int64_t> &rhs) {
    return lhs == rhs;
  }
};

struct UnrollContractionPattern
    : public OpRewritePattern<vector::ContractionOp> {
  UnrollContractionPattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options,
                           PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, contractOp);
    if (!targetShape)
      return failure();
    auto dstVecType = cast<VectorType>(contractOp.getResultType());
    SmallVector<int64_t> originalSize = *contractOp.getShapeForUnroll();

    Location loc = contractOp.getLoc();
    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
    AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;

    SmallVector<int64_t> loopOrder = getUnrollOrder(
        contractOp.getIteratorTypes().size(), contractOp, options);

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      SmallVector<Value> slicesOperands(contractOp.getNumOperands());

      // Helper to compute the new shape of each operand and extract the slice.
      auto extractOperand = [&](unsigned index, Value operand,
                                AffineMap permutationMap,
                                ArrayRef<int64_t> operandOffsets) {
        SmallVector<int64_t> operandShape = applyPermutationMap(
            permutationMap, ArrayRef<int64_t>(*targetShape));
        SmallVector<int64_t> operandStrides(operandOffsets.size(), 1);
        slicesOperands[index] =
            rewriter.createOrFold<vector::ExtractStridedSliceOp>(
                loc, operand, operandOffsets, operandShape, operandStrides);
      };

      // Extract the new lhs operand.
      AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];
      SmallVector<int64_t> lhsOffsets =
          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffsets);

      // Extract the new rhs operand.
      AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
      SmallVector<int64_t> rhsOffsets =
          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffsets);

      AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
      SmallVector<int64_t> accOffsets =
          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
      // If a version of the accumulator has already been computed, use it;
      // otherwise extract the first version from the original operand.
      auto *accIt = accCache.find(accOffsets);
      if (accIt != accCache.end())
        slicesOperands[2] = accIt->second;
      else
        extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffsets);

      SmallVector<int64_t> dstShape =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
      auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, contractOp, slicesOperands, targetType);

      SmallVector<int64_t> dstOffsets =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
      // Save the accumulated value until all the loops are unrolled, since the
      // reduction loops keep updating the accumulator.
      accCache[dstOffsets] = newOp->getResult(0);
    }
    // Assemble the accumulator tiles back into a single vector.
    Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
                                             rewriter.getZeroAttr(dstVecType));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(contractOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
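
// Editorial note: an illustrative sketch, not from the original source. For a
// 4x4x4 matmul-style contraction and a native shape of [2, 2, 2], this
// produces eight 2x2x2 vector.contract ops. Tiles sharing the same (m, n)
// accumulator offsets across the reduction dimension k are chained through
// accCache, so each partial contraction feeds the next one's accumulator
// before the final insert_strided_slice reassembly.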

struct UnrollMultiReductionPattern
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  UnrollMultiReductionPattern(MLIRContext *context,
                              const vector::UnrollVectorOptions &options,
                              PatternBenefit benefit = 1)
      : OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    auto resultType = reductionOp->getResult(0).getType();
    if (resultType.isIntOrFloat()) {
      return rewriter.notifyMatchFailure(reductionOp,
                                         "Unrolling scalars is not supported");
    }
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;
    Location loc = reductionOp.getLoc();

    // Iterate over all slice offsets, expressed in a basis of multiples of the
    // target shape.
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<Value> operands;
      SmallVector<int64_t> operandStrides(offsets.size(), 1);
      Value slicedOperand =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, reductionOp.getSource(), offsets, *targetShape,
              operandStrides);
      operands.push_back(slicedOperand);
      SmallVector<int64_t> dstShape;
      SmallVector<int64_t> destOffset;
      for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
        if (!reductionOp.isReducedDim(i)) {
          destOffset.push_back(offsets[i]);
          dstShape.push_back((*targetShape)[i]);
        }
      }
      Value acc;
      SmallVector<int64_t> accStrides(destOffset.size(), 1);
      // If a version of the accumulator has already been computed, use it;
      // otherwise extract the first version from the original operand.
      auto *accIt = accCache.find(destOffset);
      if (accIt != accCache.end())
        acc = accIt->second;
      else
        acc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
            loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
      operands.push_back(acc);
      auto targetType = VectorType::get(
          dstShape, reductionOp.getSourceVectorType().getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
                                                     operands, targetType);
      Value result = newOp->getResult(0);
      accCache[destOffset] = result;
    }
    // Assemble the accumulator tiles back into a single vector.
    Value result = arith::ConstantOp::create(
        rewriter, loc, reductionOp.getDestType(),
        rewriter.getZeroAttr(reductionOp.getDestType()));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(reductionOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

struct UnrollElementwisePattern : public RewritePattern {
  UnrollElementwisePattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options,
                           PatternBenefit benefit = 1)
      : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
        options(options) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
      return failure();
    auto targetShape = getTargetShape(options, op);
    if (!targetShape)
      return failure();
    int64_t targetShapeRank = targetShape->size();
    auto dstVecType = cast<VectorType>(op->getResult(0).getType());
    SmallVector<int64_t> originalSize =
        *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
    int64_t originalShapeRank = originalSize.size();

    Location loc = op->getLoc();

    // Handle rank mismatch by adding leading unit dimensions to targetShape.
    SmallVector<int64_t> adjustedTargetShape(originalShapeRank);
    int64_t rankDiff = originalShapeRank - targetShapeRank;
    std::fill(adjustedTargetShape.begin(),
              adjustedTargetShape.begin() + rankDiff, 1);
    std::copy(targetShape->begin(), targetShape->end(),
              adjustedTargetShape.begin() + rankDiff);

    int64_t adjustedTargetShapeRank = adjustedTargetShape.size();
    // Prepare the result vector.
    Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
                                             rewriter.getZeroAttr(dstVecType));
    SmallVector<int64_t> strides(adjustedTargetShapeRank, 1);
    VectorType unrolledVecType =
        VectorType::get(*targetShape, dstVecType.getElementType());

    // Create the unrolled computation.
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, adjustedTargetShape)) {
      SmallVector<Value> extractOperands;
      for (OpOperand &operand : op->getOpOperands()) {
        auto vecType = dyn_cast<VectorType>(operand.get().getType());
        if (!vecType) {
          extractOperands.push_back(operand.get());
          continue;
        }
        Value extracted = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
            loc, operand.get(), offsets, adjustedTargetShape, strides);

        // Reshape to remove leading unit dims if needed.
        if (adjustedTargetShapeRank > targetShapeRank) {
          extracted = rewriter.createOrFold<vector::ShapeCastOp>(
              loc, VectorType::get(*targetShape, vecType.getElementType()),
              extracted);
        }
        extractOperands.push_back(extracted);
      }

      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, op, extractOperands, unrolledVecType);

      Value computeResult = newOp->getResult(0);

      // Use strides sized to targetShape for proper insertion.
      SmallVector<int64_t> insertStrides =
          (adjustedTargetShapeRank > targetShapeRank)
              ? SmallVector<int64_t>(targetShapeRank, 1)
              : strides;

      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, computeResult, result, offsets, insertStrides);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
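
// Editorial note: an illustrative IR sketch, not from the original source.
// Assuming a native shape of [4],
//   %0 = arith.addf %a, %b : vector<8xf32>
// is rewritten roughly as:
//   %cst = arith.constant dense<0.000000e+00> : vector<8xf32>
//   %a0 = vector.extract_strided_slice %a {offsets = [0], sizes = [4],
//       strides = [1]} : vector<8xf32> to vector<4xf32>
//   %b0 = vector.extract_strided_slice %b {offsets = [0], sizes = [4],
//       strides = [1]} : vector<8xf32> to vector<4xf32>
//   %s0 = arith.addf %a0, %b0 : vector<4xf32>
//   %r0 = vector.insert_strided_slice %s0, %cst {offsets = [0],
//       strides = [1]} : vector<4xf32> into vector<8xf32>
// and similarly for the tile at offset [4].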

struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
  UnrollReductionPattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ReductionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();

    // Create unrolled vector reduction.
    Location loc = reductionOp.getLoc();
    Value accumulator = nullptr;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<int64_t> strides(offsets.size(), 1);
      Value slicedOperand =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, reductionOp.getVector(), offsets, *targetShape, strides);
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
      Value result = newOp->getResult(0);

      if (!accumulator) {
        // This is the first reduction.
        accumulator = result;
      } else {
        // On subsequent reductions, combine with the accumulator.
        accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(),
                                         accumulator, result);
      }
    }

    rewriter.replaceOp(reductionOp, accumulator);
    return success();
  }

private:
  const vector::UnrollVectorOptions options;
};
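
// Editorial note: an illustrative IR sketch, not from the original source.
// Assuming a native shape of [4],
//   %0 = vector.reduction <add>, %v : vector<8xf32> into f32
// becomes two vector.reduction ops over vector<4xf32> slices whose scalar
// results are combined with arith.addf (via makeArithReduction).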

struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
  UnrollTransposePattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransposeOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (transposeOp.getResultVectorType().getRank() == 0)
      return failure();
    auto targetShape = getTargetShape(options, transposeOp);
    if (!targetShape)
      return failure();
    auto originalVectorType = transposeOp.getResultVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = transposeOp.getLoc();
    ArrayRef<int64_t> originalSize = originalVectorType.getShape();

    // Prepare the result vector.
    Value result =
        arith::ConstantOp::create(rewriter, loc, originalVectorType,
                                  rewriter.getZeroAttr(originalVectorType));
    ArrayRef<int64_t> permutation = transposeOp.getPermutation();

    // Unroll the computation.
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<int64_t> permutedOffsets(elementOffsets.size());
      SmallVector<int64_t> permutedShape(elementOffsets.size());
      // Compute the source offsets and shape.
      for (auto indices : llvm::enumerate(permutation)) {
        permutedOffsets[indices.value()] = elementOffsets[indices.index()];
        permutedShape[indices.value()] = (*targetShape)[indices.index()];
      }
      Value slicedOperand =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, transposeOp.getVector(), permutedOffsets, permutedShape,
              strides);
      Value transposedSlice = rewriter.createOrFold<vector::TransposeOp>(
          loc, slicedOperand, permutation);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, transposedSlice, result, elementOffsets, strides);
    }
    rewriter.replaceOp(transposeOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
  UnrollGatherPattern(MLIRContext *context,
                      const vector::UnrollVectorOptions &options,
                      PatternBenefit benefit = 1)
      : OpRewritePattern<vector::GatherOp>(context, benefit), options(options) {
  }

  LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
                                PatternRewriter &rewriter) const override {
    VectorType sourceVectorType = gatherOp.getVectorType();
    if (sourceVectorType.getRank() == 0)
      return failure();
    auto targetShape = getTargetShape(options, gatherOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = gatherOp.getLoc();
    ArrayRef<int64_t> originalSize = gatherOp.getVectorType().getShape();

    // Prepare the result vector.
    Value result =
        arith::ConstantOp::create(rewriter, loc, sourceVectorType,
                                  rewriter.getZeroAttr(sourceVectorType));
    auto targetType =
        VectorType::get(*targetShape, sourceVectorType.getElementType());

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), gatherOp, options);
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      // To get the unrolled gather, extract the same slice based on the
      // decomposed shape from each of the index, mask, and pass-through
      // vectors.
      Value indexSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, gatherOp.getIndices(), elementOffsets, *targetShape, strides);
      Value maskSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
      Value passThruSubVec =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, gatherOp.getPassThru(), elementOffsets, *targetShape,
              strides);
      auto slicedGather = vector::GatherOp::create(
          rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getOffsets(),
          indexSubVec, maskSubVec, passThruSubVec);

      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, slicedGather, result, elementOffsets, strides);
    }
    rewriter.replaceOp(gatherOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> {
  UnrollLoadPattern(MLIRContext *context,
                    const vector::UnrollVectorOptions &options,
                    PatternBenefit benefit = 1)
      : OpRewritePattern<vector::LoadOp>(context, benefit), options(options) {}

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    VectorType vecType = loadOp.getVectorType();

    auto targetShape = getTargetShape(options, loadOp);
    if (!targetShape)
      return failure();

    Location loc = loadOp.getLoc();
    ArrayRef<int64_t> originalShape = vecType.getShape();
    SmallVector<int64_t> strides(targetShape->size(), 1);

    Value result = arith::ConstantOp::create(rewriter, loc, vecType,
                                             rewriter.getZeroAttr(vecType));

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalShape.size(), loadOp, options);

    auto targetVecType =
        VectorType::get(*targetShape, vecType.getElementType());

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
      SmallVector<Value> indices =
          sliceLoadStoreIndices(rewriter, loc, loadOp.getIndices(), offsets);
      Value slicedLoad = vector::LoadOp::create(rewriter, loc, targetVecType,
                                                loadOp.getBase(), indices);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, slicedLoad, result, offsets, strides);
    }
    rewriter.replaceOp(loadOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
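
// Editorial note: an illustrative IR sketch, not from the original source.
// Assuming a native shape of [4],
//   %v = vector.load %base[%i] : memref<?xf32>, vector<8xf32>
// becomes two vector<4xf32> loads, the second at index %i + 4 (computed with
// arith.addi), assembled with vector.insert_strided_slice into a
// zero-initialized vector<8xf32>.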

struct UnrollStorePattern : public OpRewritePattern<vector::StoreOp> {
  UnrollStorePattern(MLIRContext *context,
                     const vector::UnrollVectorOptions &options,
                     PatternBenefit benefit = 1)
      : OpRewritePattern<vector::StoreOp>(context, benefit), options(options) {}

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    VectorType vecType = storeOp.getVectorType();

    auto targetShape = getTargetShape(options, storeOp);
    if (!targetShape)
      return failure();

    Location loc = storeOp.getLoc();
    ArrayRef<int64_t> originalShape = vecType.getShape();
    SmallVector<int64_t> strides(targetShape->size(), 1);

    Value base = storeOp.getBase();
    Value vector = storeOp.getValueToStore();

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalShape.size(), storeOp, options);

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
      SmallVector<Value> indices =
          sliceLoadStoreIndices(rewriter, loc, storeOp.getIndices(), offsets);
      Value slice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, vector, offsets, *targetShape, strides);
      vector::StoreOp::create(rewriter, loc, slice, base, indices);
    }
    rewriter.eraseOp(storeOp);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
  UnrollBroadcastPattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::BroadcastOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, broadcastOp);
    if (!targetShape)
      return failure();

    Location loc = broadcastOp.getLoc();
    VectorType srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
    VectorType resType = broadcastOp.getResultVectorType();
    VectorType targetType =
        resType.cloneWith(*targetShape, resType.getElementType());
    Value result = arith::ConstantOp::create(rewriter, loc, resType,
                                             rewriter.getZeroAttr(resType));

    SmallVector<int64_t> originalShape = *broadcastOp.getShapeForUnroll();
    SmallVector<int64_t> strides(originalShape.size(), 1);

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalShape, *targetShape)) {
      Value newSrc;
      if (!srcType) {
        // Scalar to vector broadcast.
        newSrc = broadcastOp.getSource();
      } else {
        // Vector to vector broadcast.
        int64_t rank = srcType.getRank();
        SmallVector<int64_t> srcOffsets(offsets.end() - rank, offsets.end());
        SmallVector<int64_t> srcShape(targetShape->end() - rank,
                                      targetShape->end());
        SmallVector<int64_t> srcStrides(strides.end() - rank, strides.end());
        // Adjust the offset and shape for src if the corresponding dim is 1.
        for (int64_t i = 0; i < rank; ++i) {
          if (srcType.getDimSize(i) == 1) {
            srcOffsets[i] = 0;
            srcShape[i] = 1;
          }
        }
        newSrc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
            loc, broadcastOp.getSource(), srcOffsets, srcShape, srcStrides);
      }

      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, broadcastOp,
                                                     newSrc, targetType);

      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, newOp->getResult(0), result, offsets, strides);
    }

    rewriter.replaceOp(broadcastOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
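
// Editorial note: an illustrative IR sketch, not from the original source.
// Assuming a native shape of [2, 2],
//   %b = vector.broadcast %s : vector<4xf32> to vector<4x4xf32>
// becomes four broadcasts of vector<2xf32> slices of %s to vector<2x2xf32>,
// reassembled with insert_strided_slice. Unit source dimensions are clamped
// to offset 0 and size 1 so each slice remains a valid broadcast source.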

/// Unrolls 2 or more dimensional `vector.to_elements` ops by unrolling the
/// outermost dimension of the operand. For example:
///
/// ```
/// %0:4 = vector.to_elements %v : vector<2x2xf32>
///
/// ==>
///
/// %v0 = vector.extract %v[0] : vector<2xf32> from vector<2x2xf32>
/// %v1 = vector.extract %v[1] : vector<2xf32> from vector<2x2xf32>
/// %0:2 = vector.to_elements %v0 : vector<2xf32>
/// %1:2 = vector.to_elements %v1 : vector<2xf32>
/// ```
///
/// When this pattern is applied until a fixed point is reached, it produces
/// a sequence of 1-d to_elements ops.
struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
  UnrollToElements(MLIRContext *context,
                   const vector::UnrollVectorOptions &options,
                   PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ToElementsOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ToElementsOp op,
                                PatternRewriter &rewriter) const override {

    TypedValue<VectorType> source = op.getSource();
    FailureOr<SmallVector<Value>> result =
        vector::unrollVectorValue(source, rewriter);
    if (failed(result)) {
      return failure();
    }
    SmallVector<Value> vectors = *result;

    SmallVector<Value> results;
    for (Value vector : vectors) {
      auto subElements =
          vector::ToElementsOp::create(rewriter, op.getLoc(), vector);
      llvm::append_range(results, subElements.getResults());
    }
    rewriter.replaceOp(op, results);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

/// This pattern unrolls `vector.step` operations according to the provided
/// target unroll shape. It decomposes a large step vector into smaller step
/// vectors (segments) and assembles the result by inserting each computed
/// segment into the appropriate offset of the original vector.
///
/// The pattern does not support scalable vectors and will fail to match them.
///
/// For each segment, it adds the base step vector and the segment's offset,
/// then inserts the result into the output vector at the corresponding
/// position.
///
/// Example:
/// Given a step operation:
/// %0 = vector.step : vector<8xindex>
///
/// and a target unroll shape of <4>, the pattern produces:
///
/// %base = vector.step : vector<4xindex>
/// %zero = arith.constant dense<0> : vector<8xindex>
/// %result0 = vector.insert_strided_slice %base, %zero
/// {offsets = [0], strides = [1]} : vector<4xindex> into vector<8xindex>
/// %offset = arith.constant dense<4> : vector<4xindex>
/// %segment1 = arith.addi %base, %offset : vector<4xindex>
/// %result1 = vector.insert_strided_slice %segment1, %result0
/// {offsets = [4], strides = [1]} : vector<4xindex> into vector<8xindex>
///
struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> {
  UnrollStepPattern(MLIRContext *context,
                    const vector::UnrollVectorOptions &options,
                    PatternBenefit benefit = 1)
      : OpRewritePattern<vector::StepOp>(context, benefit), options(options) {}

  LogicalResult matchAndRewrite(vector::StepOp stepOp,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, stepOp);
    if (!targetShape)
      return failure();

    VectorType vecType = stepOp.getType();
    if (vecType.isScalable()) {
      // Scalable vectors are not supported by this pattern.
      return failure();
    }
    int64_t originalSize = vecType.getShape()[0];
    Location loc = stepOp.getLoc();
    SmallVector<int64_t> strides(1, 1);

    Value result = arith::ConstantOp::create(rewriter, loc, vecType,
                                             rewriter.getZeroAttr(vecType));

    auto targetVecType =
        VectorType::get(*targetShape, vecType.getElementType());
    Value baseStep = vector::StepOp::create(rewriter, loc, targetVecType);
    for (const SmallVector<int64_t> &offsets :
         StaticTileOffsetRange({originalSize}, *targetShape)) {
      Value bcastOffset = arith::ConstantOp::create(
          rewriter, loc, targetVecType,
          DenseElementsAttr::get(
              targetVecType,
              IntegerAttr::get(targetVecType.getElementType(), offsets[0])));
      Value tileStep =
          arith::AddIOp::create(rewriter, loc, baseStep, bcastOffset);

      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, tileStep, result, offsets, strides);
    }
    rewriter.replaceOp(stepOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
/// outermost dimension. For example:
/// ```
/// %v = vector.from_elements %e0, %e1, %e2, %e3, %e4, %e5 : vector<2x3xf32>
///
/// ==>
///
/// %0 = ub.poison : vector<2x3xf32>
/// %v0 = vector.from_elements %e0, %e1, %e2 : vector<3xf32>
/// %1 = vector.insert %v0, %0 [0] : vector<3xf32> into vector<2x3xf32>
/// %v1 = vector.from_elements %e3, %e4, %e5 : vector<3xf32>
/// %v = vector.insert %v1, %1 [1] : vector<3xf32> into vector<2x3xf32>
/// ```
///
/// When this pattern is applied until a fixed point is reached, it produces
/// a sequence of 1-d from_elements ops.
struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
  UnrollFromElements(MLIRContext *context,
                     const vector::UnrollVectorOptions &options,
                     PatternBenefit benefit = 1)
      : OpRewritePattern<vector::FromElementsOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::FromElementsOp op,
                                PatternRewriter &rewriter) const override {
    ValueRange allElements = op.getElements();

    auto unrollFromElementsFn = [&](PatternRewriter &rewriter, Location loc,
                                    VectorType subTy, int64_t index) {
      size_t subTyNumElements = subTy.getNumElements();
      assert((index + 1) * subTyNumElements <= allElements.size() &&
             "out of bounds");
      ValueRange subElements =
          allElements.slice(index * subTyNumElements, subTyNumElements);
      return vector::FromElementsOp::create(rewriter, loc, subTy, subElements);
    };

    return unrollVectorOp(op, rewriter, unrollFromElementsFn);
  }

private:
  vector::UnrollVectorOptions options;
};

/// This pattern unrolls `vector.create_mask` operations into smaller mask
/// operations based on the target unroll shape. Each unrolled slice computes
/// its local mask size in each dimension (d) as:
/// min(max(originalMaskSize[d] - offset[d], 0), unrolledDimSize[d]).
/// Example:
/// Given a create_mask operation:
/// %0 = vector.create_mask %c6, %c10 : vector<8x16xi1> // mask first 6x10
/// elements
///
/// and a target unroll shape of <4x8>, the pattern produces:
///
/// %false = arith.constant dense<false> : vector<8x16xi1>
///
/// Slice [0,0]:
/// mask size = min(max(6-0, 0), 4) x min(max(10-0, 0), 8) = 4x8
/// %mask00 = vector.create_mask %c4, %c8 : vector<4x8xi1>
/// %r0 = vector.insert_strided_slice %mask00, %false [0, 0], [1, 1]
/// : vector<4x8xi1> into vector<8x16xi1>
/// Slice [0,8]:
/// mask size = min(max(6-0, 0), 4) x min(max(10-8, 0), 8) = 4x2
/// %mask01 = vector.create_mask %c4, %c2 : vector<4x8xi1>
/// %r1 = vector.insert_strided_slice %mask01, %r0 [0, 8], [1, 1]
/// : vector<4x8xi1> into vector<8x16xi1>
/// Slice [4,0]:
/// mask size = min(max(6-4, 0), 4) x min(max(10-0, 0), 8) = 2x8
/// %mask10 = vector.create_mask %c2, %c8 : vector<4x8xi1>
/// %r2 = vector.insert_strided_slice %mask10, %r1 [4, 0], [1, 1]
/// : vector<4x8xi1> into vector<8x16xi1>
/// Slice [4,8]:
/// mask size = min(max(6-4, 0), 4) x min(max(10-8, 0), 8) = 2x2
/// %mask11 = vector.create_mask %c2, %c2 : vector<4x8xi1>
/// %result = vector.insert_strided_slice %mask11, %r2 [4, 8], [1, 1]
/// : vector<4x8xi1> into vector<8x16xi1>
struct UnrollCreateMaskPattern : public OpRewritePattern<vector::CreateMaskOp> {
  UnrollCreateMaskPattern(MLIRContext *context,
                          const vector::UnrollVectorOptions &options,
                          PatternBenefit benefit = 1)
      : OpRewritePattern<vector::CreateMaskOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::CreateMaskOp createMaskOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, createMaskOp);
    if (!targetShape)
      return failure();

    VectorType resultType = createMaskOp.getVectorType();
    SmallVector<int64_t> originalSize = *createMaskOp.getShapeForUnroll();
    Location loc = createMaskOp.getLoc();

    Value result = arith::ConstantOp::create(rewriter, loc, resultType,
                                             rewriter.getZeroAttr(resultType));
    VectorType targetVectorType =
        VectorType::get(*targetShape, rewriter.getI1Type());
    SmallVector<int64_t> strides(targetShape->size(), 1);

    // In each dimension (d), each unrolled vector computes its mask size as:
    // min(max(originalMaskOperands[d] - offset[d], 0), unrolledDimSize[d]).
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<Value> unrolledOperands;

      for (auto [i, originalMaskOperand] :
           llvm::enumerate(createMaskOp.getOperands())) {
        Value offsetVal =
            arith::ConstantIndexOp::create(rewriter, loc, offsets[i]);
        Value adjustedMaskSize = rewriter.createOrFold<arith::SubIOp>(
            loc, originalMaskOperand, offsetVal);
        Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
        Value unrolledDimSize =
            arith::ConstantIndexOp::create(rewriter, loc, (*targetShape)[i]);
        Value nonNegative =
            rewriter.createOrFold<arith::MaxSIOp>(loc, adjustedMaskSize, zero);
        Value unrolledOperand = rewriter.createOrFold<arith::MinSIOp>(
            loc, nonNegative, unrolledDimSize);
        unrolledOperands.push_back(unrolledOperand);
      }

      auto unrolledMask = rewriter.createOrFold<vector::CreateMaskOp>(
          loc, targetVectorType, unrolledOperands);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, unrolledMask, result, offsets, strides);
    }
    rewriter.replaceOp(createMaskOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

/// This pattern unrolls `vector.constant_mask` operations into smaller mask
/// operations based on the target unroll shape. Each unrolled slice computes
/// whether its elements should be masked based on the original mask dimensions
/// and the slice's offset position.
///
/// Example:
/// Given a constant_mask operation:
/// %0 = vector.constant_mask [6, 10] : vector<8x16xi1>
///
/// and a target unroll shape of <4x8>, the pattern produces:
///
/// %false = arith.constant dense<false> : vector<8x16xi1>
///
/// Slice [0,0]: elements [0:4, 0:8] - fully within [6, 10] bounds
/// %mask00 = vector.constant_mask [4, 8] : vector<4x8xi1>
/// %r0 = vector.insert_strided_slice %mask00, %false [0, 0], [1, 1]
/// : vector<4x8xi1> into vector<8x16xi1>
///
/// Slice [0,8]: elements [0:4, 8:16] - partially within bounds
/// %mask01 = vector.constant_mask [4, 2] : vector<4x8xi1>
/// %r1 = vector.insert_strided_slice %mask01, %r0 [0, 8], [1, 1]
/// : vector<4x8xi1> into vector<8x16xi1>
///
/// Slice [4,0]: elements [4:8, 0:8] - partially within bounds
/// %mask10 = vector.constant_mask [2, 8] : vector<4x8xi1>
/// %r2 = vector.insert_strided_slice %mask10, %r1 [4, 0], [1, 1]
/// : vector<4x8xi1> into vector<8x16xi1>
///
/// Slice [4,8]: elements [4:8, 8:16] - partially within bounds
/// %mask11 = vector.constant_mask [2, 2] : vector<4x8xi1>
/// %result = vector.insert_strided_slice %mask11, %r2 [4, 8], [1, 1]
/// : vector<4x8xi1> into vector<8x16xi1>
struct UnrollConstantMaskPattern
    : public OpRewritePattern<vector::ConstantMaskOp> {
  UnrollConstantMaskPattern(MLIRContext *context,
                            const vector::UnrollVectorOptions &options,
                            PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ConstantMaskOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ConstantMaskOp constantMaskOp,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, constantMaskOp);
    if (!targetShape)
      return failure();

    VectorType resultType = constantMaskOp.getVectorType();
    SmallVector<int64_t> originalSize = *constantMaskOp.getShapeForUnroll();
    Location loc = constantMaskOp.getLoc();

    Value result = arith::ConstantOp::create(rewriter, loc, resultType,
                                             rewriter.getZeroAttr(resultType));
    VectorType targetVectorType =
        VectorType::get(*targetShape, rewriter.getI1Type());
    SmallVector<int64_t> strides(targetShape->size(), 1);

    // In each dimension (d), each unrolled vector computes its mask size as:
    // min(max(originalMaskDim[d] - offset[d], 0), unrolledDimSize[d]).
    for (const SmallVector<int64_t> &offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<int64_t> unrolledMaskDims;

      for (auto [i, originalMaskDim] :
           llvm::enumerate(constantMaskOp.getMaskDimSizes())) {
        // Calculate how many elements in this dimension should be masked
        // for this particular slice.
        int64_t adjustedMaskSize =
            std::max(originalMaskDim - offsets[i], static_cast<int64_t>(0));
        int64_t unrolledMaskDim =
            std::min(adjustedMaskSize, static_cast<int64_t>((*targetShape)[i]));
        unrolledMaskDims.push_back(unrolledMaskDim);
      }

      auto unrolledMask = rewriter.createOrFold<vector::ConstantMaskOp>(
          loc, targetVectorType, unrolledMaskDims);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, unrolledMask, result, offsets, strides);
    }
    rewriter.replaceOp(constantMaskOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

/// Checks whether `extractShape` is a contiguous slice of `shape`.
/// For `extractShape` to be contiguous in `shape`:
/// 1) All but the leading dimension of `extractShape` and `shape` must match
///    exactly.
/// 2) The total number of elements in `shape` must be evenly divisible by the
///    total number of elements in `extractShape`.
/// Examples:
/// isContiguous([4, 4], [8, 4]) == true
/// isContiguous([2, 4], [8, 4]) == true
/// isContiguous([2, 2], [8, 4]) == false
/// Removes leading unit dimensions to handle cases like:
/// isContiguous([1, 16], [1, 32]) == true
static bool isContiguous(ArrayRef<int64_t> extractShape,
                         ArrayRef<int64_t> shape) {

  if (extractShape.size() > shape.size())
    return false;

  while (!extractShape.empty() && extractShape.front() == 1) {
    extractShape = extractShape.drop_front();
  }

  while (!shape.empty() && shape.front() == 1) {
    shape = shape.drop_front();
  }

  size_t rankDiff = shape.size() - extractShape.size();
  if (!llvm::equal(extractShape.drop_front(), shape.drop_front(rankDiff + 1)))
    return false;

  int64_t extractElements = ShapedType::getNumElements(extractShape);
  int64_t shapeElements = ShapedType::getNumElements(shape);
  return shapeElements % extractElements == 0;
}

/// Determines what shape to use with `vector.extract_strided_slice` to extract
/// a contiguous memory region from a source vector. The extraction must be
/// contiguous and contain exactly the specified number of elements. If such an
/// extraction shape cannot be determined, returns std::nullopt.
/// EXAMPLE 1:
/// sourceShape = [16], targetElements = 8
/// Working right-to-left:
/// - Take min(8, 16) = 8 from the only dim → extractShape = [8],
///   remaining = 8/8 = 1
/// Result: [8]
///
/// EXAMPLE 2:
/// sourceShape = [4, 4], targetElements = 8
/// Working right-to-left:
/// - Take min(8, 4) = 4 from the last dim → extractShape = [4],
///   remaining = 8/4 = 2
/// - Take min(2, 4) = 2 from the first dim → extractShape = [2, 4],
///   remaining = 2/2 = 1
/// Result: [2, 4]
static std::optional<SmallVector<int64_t>>
calculateSourceExtractShape(ArrayRef<int64_t> sourceShape,
                            int64_t targetElements) {
  SmallVector<int64_t> extractShape;
  int64_t remainingElements = targetElements;

  // Build the extract shape from the innermost dimension outward to ensure
  // contiguity.
  for (int i = sourceShape.size() - 1; i >= 0 && remainingElements > 1; --i) {
    int64_t takeFromDim = std::min(remainingElements, sourceShape[i]);
    extractShape.insert(extractShape.begin(), takeFromDim);

    if (remainingElements % takeFromDim != 0)
      return std::nullopt; // Not evenly divisible.
    remainingElements /= takeFromDim;
  }

  // Fill the remaining dimensions with 1.
  while (extractShape.size() < sourceShape.size())
    extractShape.insert(extractShape.begin(), 1);

  if (ShapedType::getNumElements(extractShape) != targetElements)
    return std::nullopt;

  return extractShape;
}

// Convert result offsets to source offsets via the linear position.
static SmallVector<int64_t>
calculateSourceOffsets(ArrayRef<int64_t> resultOffsets,
                       ArrayRef<int64_t> sourceShape,
                       ArrayRef<int64_t> resultShape) {
  // Convert result offsets to a linear position.
  int64_t linearIndex = linearize(resultOffsets, computeStrides(resultShape));
  // Convert the linear position to source offsets.
  return delinearize(linearIndex, computeStrides(sourceShape));
}
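
// Editorial note: a worked example, not from the original source. When
// casting vector<8x2xf32> to vector<4x4xf32>, resultOffsets [2, 0] linearize
// with the result strides [4, 1] to position 8, which delinearizes with the
// source strides [2, 1] to source offsets [4, 0]: the slice starting at row 4
// of the source holds the same contiguous elements.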

/// This pattern unrolls `vector.shape_cast` operations according to the
/// provided target unroll shape. It unrolls a large shape cast into smaller
/// shape casts by extracting contiguous slices from the source vector, casting
/// each slice to the target shape, and assembling the result by inserting each
/// computed segment into the appropriate offset of the result vector.
///
/// This pattern only applies when contiguous slices can be extracted from the
/// source vector and inserted into the result vector such that each slice
/// remains a valid vector (and does not decompose to scalars). In these cases,
/// the unrolling proceeds as:
/// vector.extract_strided_slice -> vector.shape_cast (on the slice) ->
/// vector.insert_strided_slice.
///
/// Example:
/// Given a shape cast operation:
/// %0 = vector.shape_cast %src : vector<8x2xf32> to vector<4x4xf32>
///
/// and a target unroll shape of <2x4>, the pattern produces:
///
/// %zero = arith.constant dense<0.0> : vector<4x4xf32>
/// %s0 = vector.extract_strided_slice %src [0, 0], [4, 2], [1, 1]
/// : vector<8x2xf32> to vector<4x2xf32>
/// %sc0 = vector.shape_cast %s0 : vector<4x2xf32> to vector<2x4xf32>
/// %i0 = vector.insert_strided_slice %sc0, %zero [0, 0], [1, 1]
/// : vector<2x4xf32> into vector<4x4xf32>
/// %s1 = vector.extract_strided_slice %src [4, 0], [4, 2], [1, 1]
/// : vector<8x2xf32> to vector<4x2xf32>
/// %sc1 = vector.shape_cast %s1 : vector<4x2xf32> to vector<2x4xf32>
/// %i1 = vector.insert_strided_slice %sc1, %i0 [2, 0], [1, 1]
/// : vector<2x4xf32> into vector<4x4xf32>
///
struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
  UnrollShapeCastPattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ShapeCastOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, shapeCastOp);
    if (!targetShape)
      return failure();

    VectorType sourceType = shapeCastOp.getSourceVectorType();
    VectorType resultType = shapeCastOp.getResultVectorType();
    ArrayRef<int64_t> sourceShape = sourceType.getShape();
    ArrayRef<int64_t> resultShape = resultType.getShape();

    if (!isContiguous(*targetShape, resultShape))
      return rewriter.notifyMatchFailure(
          shapeCastOp, "Only supports cases where target shape is "
                       "contiguous in result vector shape");

    int64_t targetElements = ShapedType::getNumElements(*targetShape);

    // Calculate the shape to extract from the source.
    std::optional<SmallVector<int64_t>> extractShape =
        calculateSourceExtractShape(sourceShape, targetElements);
    if (!extractShape)
      return rewriter.notifyMatchFailure(
          shapeCastOp,
          "cannot extract target number of elements contiguously from source");

    Location loc = shapeCastOp.getLoc();

    // Create the result vector initialized to zero.
    Value result = arith::ConstantOp::create(rewriter, loc, resultType,
                                             rewriter.getZeroAttr(resultType));

    VectorType targetType =
        VectorType::get(*targetShape, sourceType.getElementType());

    SmallVector<int64_t> extractStrides(extractShape->size(), 1);
    SmallVector<int64_t> insertStrides(targetShape->size(), 1);

    for (SmallVector<int64_t> resultOffsets :
         StaticTileOffsetRange(resultShape, *targetShape)) {
      SmallVector<int64_t> sourceOffsets =
          calculateSourceOffsets(resultOffsets, sourceShape, resultShape);
      Value sourceChunk = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, shapeCastOp.getSource(), sourceOffsets, *extractShape,
          extractStrides);
      Value targetChunk = rewriter.createOrFold<vector::ShapeCastOp>(
          loc, targetType, sourceChunk);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, targetChunk, result, resultOffsets, insertStrides);
    }

    rewriter.replaceOp(shapeCastOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

} // namespace

void mlir::vector::populateVectorUnrollPatterns(
    RewritePatternSet &patterns, const UnrollVectorOptions &options,
    PatternBenefit benefit) {
  patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
               UnrollContractionPattern, UnrollElementwisePattern,
               UnrollReductionPattern, UnrollMultiReductionPattern,
               UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
               UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
               UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern,
               UnrollCreateMaskPattern, UnrollConstantMaskPattern>(
      patterns.getContext(), options, benefit);
}
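
// Editorial note: a usage sketch, not part of this file; the driver API name
// is assumed from the current MLIR tree and worth verifying. A pass could
// drive these patterns with a fixed native shape as follows:
//
//   RewritePatternSet patterns(ctx);
//   vector::UnrollVectorOptions unrollOptions;
//   unrollOptions.setNativeShape(SmallVector<int64_t>{4, 4});
//   vector::populateVectorUnrollPatterns(patterns, unrollOptions);
//   if (failed(applyPatternsGreedily(funcOp, std::move(patterns))))
//     signalPassFailure();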

void mlir::vector::populateVectorToElementsUnrollPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<UnrollToElements>(patterns.getContext(), UnrollVectorOptions(),
                                 benefit);
}

void mlir::vector::populateVectorFromElementsUnrollPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<UnrollFromElements>(patterns.getContext(), UnrollVectorOptions(),
                                   benefit);
}