MLIR 23.0.0git
VectorUnroll.cpp
Go to the documentation of this file.
1//===- VectorUnrollDistribute.cpp - patterns to do vector unrolling -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to do vector unrolling and vector distribution.
10//
11//===----------------------------------------------------------------------===//
12
18#include "llvm/ADT/MapVector.h"
19#include "llvm/ADT/STLExtras.h"
20#include "llvm/Support/DebugLog.h"
21#include "llvm/Support/InterleavedRange.h"
22#include <optional>
23
24#define DEBUG_TYPE "vector-unroll"
25
26using namespace mlir;
27using namespace mlir::vector;
28
29SmallVector<Value> mlir::vector::sliceTransferIndices(
31 AffineMap permutationMap, Location loc, OpBuilder &builder) {
32 MLIRContext *ctx = builder.getContext();
33 auto isBroadcast = [](AffineExpr expr) {
34 if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
35 return constExpr.getValue() == 0;
36 return false;
37 };
38 // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
39 SmallVector<Value> slicedIndices(indices);
40 for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
41 int64_t elementOffset = elementOffsets[dim.index()];
42 if (isBroadcast(dim.value()) || elementOffset == 0)
43 continue;
44 unsigned pos = cast<AffineDimExpr>(dim.value()).getPosition();
45 auto expr = getAffineDimExpr(0, builder.getContext()) +
46 getAffineConstantExpr(elementOffset, ctx);
47 auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
48 slicedIndices[pos] =
49 affine::AffineApplyOp::create(builder, loc, map, indices[pos]);
50 }
51 return slicedIndices;
52}
53
54// Compute the new indices by adding `offsets` to `originalIndices`.
55// If m < n (m = offsets.size(), n = originalIndices.size()),
56// then only the trailing m values in `originalIndices` are updated.
58 Location loc,
59 OperandRange originalIndices,
60 ArrayRef<int64_t> offsets) {
61 assert(offsets.size() <= originalIndices.size() &&
62 "Offsets should not exceed the number of original indices");
63 SmallVector<Value> indices(originalIndices);
64
65 auto start = indices.size() - offsets.size();
66 for (auto [i, offset] : llvm::enumerate(offsets)) {
67 if (offset != 0) {
68 indices[start + i] = arith::AddIOp::create(
69 rewriter, loc, originalIndices[start + i],
70 arith::ConstantIndexOp::create(rewriter, loc, offset));
71 }
72 }
73 return indices;
74}
75
76// Clones `op` into a new operations that takes `operands` and returns
77// `resultTypes`.
79 Operation *op,
80 ArrayRef<Value> operands,
81 ArrayRef<Type> resultTypes) {
82 return builder.create(loc, op->getName().getIdentifier(), operands,
83 resultTypes, op->getAttrs());
84}
85
86/// Return the target shape for unrolling for the given `op`. Return
87/// std::nullopt if the op shouldn't be or cannot be unrolled.
88static std::optional<SmallVector<int64_t>>
90 LDBG() << "Get unroll shape for op " << op->getName().getStringRef();
91 if (options.filterConstraint && failed(options.filterConstraint(op))) {
92 LDBG() << "--no filter constraint -> BAIL";
93 return std::nullopt;
94 }
95 assert(options.nativeShape &&
96 "vector unrolling expects the native shape or native"
97 "shape call back function to be set");
98 auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
99 if (!unrollableVectorOp) {
100 LDBG() << "--not an unrollable op -> BAIL";
101 return std::nullopt;
102 }
103 auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
104 if (!maybeUnrollShape) {
105 LDBG() << "--could not get shape of op " << *op << " -> BAIL";
106 return std::nullopt;
107 }
108 LDBG() << "--vector op shape: " << llvm::interleaved(*maybeUnrollShape);
109
110 std::optional<SmallVector<int64_t>> targetShape = options.nativeShape(op);
111 if (!targetShape) {
112 LDBG() << "--no unrolling target shape defined " << *op << "-> SKIP";
113 return std::nullopt;
114 }
115 LDBG() << "--target shape: " << llvm::interleaved(*targetShape);
116
117 auto maybeShapeRatio = computeShapeRatio(*maybeUnrollShape, *targetShape);
118 if (!maybeShapeRatio) {
119 LDBG() << "--could not compute integral shape ratio -> BAIL";
120 return std::nullopt;
121 }
122 if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
123 LDBG() << "--no unrolling needed -> SKIP";
124 return std::nullopt;
125 }
126 LDBG() << "--found an integral shape ratio to unroll to -> SUCCESS";
127 return targetShape;
128}
129
131getUnrollOrder(unsigned numLoops, Operation *op,
133 SmallVector<int64_t> loopOrder =
134 llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));
135 if (options.traversalOrderCallback != nullptr) {
136 std::optional<SmallVector<int64_t>> order =
137 options.traversalOrderCallback(op);
138 if (order) {
139 loopOrder = std::move(*order);
140 }
141 }
142 return loopOrder;
143}
144
145namespace {
146
147struct UnrollTransferReadPattern
148 : public OpRewritePattern<vector::TransferReadOp> {
149 UnrollTransferReadPattern(MLIRContext *context,
150 const vector::UnrollVectorOptions &options,
151 PatternBenefit benefit = 1)
152 : OpRewritePattern<vector::TransferReadOp>(context, benefit),
153 options(options) {}
154
155 LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
156 PatternRewriter &rewriter) const override {
157 // TODO: support 0-d corner case.
158 if (readOp.getTransferRank() == 0)
159 return failure();
160 if (readOp.getMask())
161 return failure();
162 auto targetShape = getTargetShape(options, readOp);
163 if (!targetShape)
164 return failure();
165 auto sourceVectorType = readOp.getVectorType();
166 SmallVector<int64_t> strides(targetShape->size(), 1);
167 Location loc = readOp.getLoc();
168 ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
169
170 // Prepare the result vector;
171 Value result =
172 arith::ConstantOp::create(rewriter, loc, sourceVectorType,
173 rewriter.getZeroAttr(sourceVectorType));
174 auto targetType =
175 VectorType::get(*targetShape, sourceVectorType.getElementType());
176 SmallVector<Value> originalIndices(readOp.getIndices().begin(),
177 readOp.getIndices().end());
178 SmallVector<int64_t> loopOrder =
179 getUnrollOrder(originalSize.size(), readOp, options);
180 for (SmallVector<int64_t> elementOffsets :
181 StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
182 SmallVector<Value> indices =
183 sliceTransferIndices(elementOffsets, originalIndices,
184 readOp.getPermutationMap(), loc, rewriter);
185 auto slicedRead = vector::TransferReadOp::create(
186 rewriter, loc, targetType, readOp.getBase(), indices,
187 readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
188 readOp.getInBoundsAttr());
189
190 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
191 loc, slicedRead, result, elementOffsets, strides);
192 }
193 rewriter.replaceOp(readOp, result);
194 return success();
195 }
196
197private:
198 vector::UnrollVectorOptions options;
199};
200
201struct UnrollTransferWritePattern
202 : public OpRewritePattern<vector::TransferWriteOp> {
203 UnrollTransferWritePattern(MLIRContext *context,
204 const vector::UnrollVectorOptions &options,
205 PatternBenefit benefit = 1)
206 : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
207 options(options) {}
208
209 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
210 PatternRewriter &rewriter) const override {
211 // TODO: support 0-d corner case.
212 if (writeOp.getTransferRank() == 0)
213 return failure();
214
215 if (writeOp.getMask())
216 return failure();
217 auto targetShape = getTargetShape(options, writeOp);
218 if (!targetShape)
219 return failure();
220 auto sourceVectorType = writeOp.getVectorType();
221 SmallVector<int64_t> strides(targetShape->size(), 1);
222 Location loc = writeOp.getLoc();
223 ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
224 // Bail-out if rank(source) != rank(target). The main limitation here is the
225 // fact that `ExtractStridedSlice` requires the rank for the input and
226 // output to match. If needed, we can relax this later.
227 if (originalSize.size() != targetShape->size())
228 return rewriter.notifyMatchFailure(
229 writeOp,
230 "expected source input vector rank to match target shape rank");
231
232 SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
233 writeOp.getIndices().end());
234 SmallVector<int64_t> loopOrder =
235 getUnrollOrder(originalSize.size(), writeOp, options);
236 Value resultTensor;
237 for (SmallVector<int64_t> elementOffsets :
238 StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
239 Value slicedVector = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
240 loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
241 SmallVector<Value> indices =
242 sliceTransferIndices(elementOffsets, originalIndices,
243 writeOp.getPermutationMap(), loc, rewriter);
244 Operation *slicedWrite = vector::TransferWriteOp::create(
245 rewriter, loc, slicedVector,
246 resultTensor ? resultTensor : writeOp.getBase(), indices,
247 writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
248 // For the tensor case update the destination for the next transfer write.
249 if (!slicedWrite->getResults().empty())
250 resultTensor = slicedWrite->getResult(0);
251 }
252 if (resultTensor)
253 rewriter.replaceOp(writeOp, resultTensor);
254 else
255 rewriter.eraseOp(writeOp);
256 return success();
257 }
258
259private:
260 vector::UnrollVectorOptions options;
261};
262
263struct OffsetMapInfo {
264 static unsigned getHashValue(const SmallVector<int64_t> &v) {
265 return static_cast<unsigned>(llvm::hash_combine_range(v));
266 }
267
268 static bool isEqual(const SmallVector<int64_t> &lhs,
269 const SmallVector<int64_t> &rhs) {
270 return lhs == rhs;
271 }
272};
273
274struct UnrollContractionPattern
275 : public OpRewritePattern<vector::ContractionOp> {
276 UnrollContractionPattern(MLIRContext *context,
277 const vector::UnrollVectorOptions &options,
278 PatternBenefit benefit = 1)
279 : OpRewritePattern<vector::ContractionOp>(context, benefit),
280 options(options) {}
281
282 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
283 PatternRewriter &rewriter) const override {
284 auto targetShape = getTargetShape(options, contractOp);
285 if (!targetShape)
286 return failure();
287 auto dstVecType = cast<VectorType>(contractOp.getResultType());
288 SmallVector<int64_t> originalSize = *contractOp.getShapeForUnroll();
289
290 Location loc = contractOp.getLoc();
291 unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
292 AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];
293 llvm::MapVector<
294 SmallVector<int64_t>, Value,
295 llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
296 accCache;
297
298 SmallVector<int64_t> loopOrder = getUnrollOrder(
299 contractOp.getIteratorTypes().size(), contractOp, options);
300
301 for (SmallVector<int64_t> offsets :
302 StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
303 SmallVector<Value> slicesOperands(contractOp.getNumOperands());
304
305 // Helper to compute the new shape of each operand and extract the slice.
306 auto extractOperand = [&](unsigned index, Value operand,
307 AffineMap permutationMap,
308 ArrayRef<int64_t> operandOffets) {
309 SmallVector<int64_t> operandShape = applyPermutationMap(
310 permutationMap, ArrayRef<int64_t>(*targetShape));
311 SmallVector<int64_t> operandStrides(operandOffets.size(), 1);
312 slicesOperands[index] =
313 rewriter.createOrFold<vector::ExtractStridedSliceOp>(
314 loc, operand, operandOffets, operandShape, operandStrides);
315 };
316
317 // Extract the new lhs operand.
318 AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];
319 SmallVector<int64_t> lhsOffets =
320 applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
321 extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets);
322
323 // Extract the new rhs operand.
324 AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
325 SmallVector<int64_t> rhsOffets =
326 applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
327 extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets);
328
329 AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
330 SmallVector<int64_t> accOffets =
331 applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
332 // If a version of the accumulator has already been computed, use it
333 // otherwise extract the first version from the original operand.
334 auto *accIt = accCache.find(accOffets);
335 if (accIt != accCache.end())
336 slicesOperands[2] = accIt->second;
337 else
338 extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffets);
339
340 SmallVector<int64_t> dstShape =
341 applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
342 auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
343 Operation *newOp = cloneOpWithOperandsAndTypes(
344 rewriter, loc, contractOp, slicesOperands, targetType);
345
346 SmallVector<int64_t> dstOffets =
347 applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
348 // Save the accumulated value untill all the loops are unrolled since
349 // reduction loop keep updating the accumulator.
350 accCache[dstOffets] = newOp->getResult(0);
351 }
352 // Assemble back the accumulator into a single vector.
353 Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
354 rewriter.getZeroAttr(dstVecType));
355 for (const auto &it : accCache) {
356 SmallVector<int64_t> dstStrides(it.first.size(), 1);
357 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
358 loc, it.second, result, it.first, dstStrides);
359 }
360 rewriter.replaceOp(contractOp, result);
361 return success();
362 }
363
364private:
365 vector::UnrollVectorOptions options;
366};
367
368struct UnrollMultiReductionPattern
369 : public OpRewritePattern<vector::MultiDimReductionOp> {
370 UnrollMultiReductionPattern(MLIRContext *context,
371 const vector::UnrollVectorOptions &options,
372 PatternBenefit benefit = 1)
373 : OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
374 options(options) {}
375
376 LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
377 PatternRewriter &rewriter) const override {
378 std::optional<SmallVector<int64_t>> targetShape =
379 getTargetShape(options, reductionOp);
380 if (!targetShape)
381 return failure();
382 SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
383 Location loc = reductionOp.getLoc();
384 auto resultType = reductionOp->getResult(0).getType();
385
386 // Handle scalar result case: all dimensions are reduced.
387 // Each source tile is reduced to a scalar, and partial results are
388 // chained through the accumulator operand.
389 if (resultType.isIntOrFloat()) {
390 Value accumulator = reductionOp.getAcc();
391 for (SmallVector<int64_t> offsets :
392 StaticTileOffsetRange(originalSize, *targetShape)) {
393 SmallVector<int64_t> operandStrides(offsets.size(), 1);
394 Value slicedOperand =
395 rewriter.createOrFold<vector::ExtractStridedSliceOp>(
396 loc, reductionOp.getSource(), offsets, *targetShape,
397 operandStrides);
398 Operation *newOp = cloneOpWithOperandsAndTypes(
399 rewriter, loc, reductionOp, {slicedOperand, accumulator},
400 resultType);
401 accumulator = newOp->getResult(0);
402 }
403 rewriter.replaceOp(reductionOp, accumulator);
404 return success();
405 }
406
407 // Vector result case.
408 llvm::MapVector<
409 SmallVector<int64_t>, Value,
410 llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
411 accCache;
412
413 // Stride of the ratios, this gives us the offsets of sliceCount in a basis
414 // of multiples of the targetShape.
415 for (SmallVector<int64_t> offsets :
416 StaticTileOffsetRange(originalSize, *targetShape)) {
417 SmallVector<Value> operands;
418 SmallVector<int64_t> operandStrides(offsets.size(), 1);
419 Value slicedOperand =
420 rewriter.createOrFold<vector::ExtractStridedSliceOp>(
421 loc, reductionOp.getSource(), offsets, *targetShape,
422 operandStrides);
423 operands.push_back(slicedOperand);
424 SmallVector<int64_t> dstShape;
425 SmallVector<int64_t> destOffset;
426 for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
427 if (!reductionOp.isReducedDim(i)) {
428 destOffset.push_back(offsets[i]);
429 dstShape.push_back((*targetShape)[i]);
430 }
431 }
432 Value acc;
433 SmallVector<int64_t> accStrides(destOffset.size(), 1);
434 // If a version of the accumulator has already been computed, use it
435 // otherwise extract the first version from the original operand.
436 auto *accIt = accCache.find(destOffset);
437 if (accIt != accCache.end())
438 acc = accIt->second;
439 else
440 acc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
441 loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
442 operands.push_back(acc);
443 auto targetType = VectorType::get(
444 dstShape, reductionOp.getSourceVectorType().getElementType());
445 Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
446 operands, targetType);
447 Value result = newOp->getResult(0);
448 accCache[destOffset] = result;
449 }
450 // Assemble back the accumulator into a single vector.
451 Value result = arith::ConstantOp::create(
452 rewriter, loc, reductionOp.getDestType(),
453 rewriter.getZeroAttr(reductionOp.getDestType()));
454 for (const auto &it : accCache) {
455 SmallVector<int64_t> dstStrides(it.first.size(), 1);
456 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
457 loc, it.second, result, it.first, dstStrides);
458 }
459 rewriter.replaceOp(reductionOp, result);
460 return success();
461 }
462
463private:
464 vector::UnrollVectorOptions options;
465};
466
467struct UnrollElementwisePattern : public RewritePattern {
468 UnrollElementwisePattern(MLIRContext *context,
469 const vector::UnrollVectorOptions &options,
470 PatternBenefit benefit = 1)
471 : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
472 options(options) {}
473
474 LogicalResult matchAndRewrite(Operation *op,
475 PatternRewriter &rewriter) const override {
477 return failure();
478 auto targetShape = getTargetShape(options, op);
479 if (!targetShape)
480 return failure();
481 int64_t targetShapeRank = targetShape->size();
482 auto dstVecType = cast<VectorType>(op->getResult(0).getType());
483 SmallVector<int64_t> originalSize =
484 *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
485 int64_t originalShapeRank = originalSize.size();
486
487 Location loc = op->getLoc();
488
489 // Handle rank mismatch by adding leading unit dimensions to targetShape
490 SmallVector<int64_t> adjustedTargetShape(originalShapeRank);
491 int64_t rankDiff = originalShapeRank - targetShapeRank;
492 std::fill(adjustedTargetShape.begin(),
493 adjustedTargetShape.begin() + rankDiff, 1);
494 std::copy(targetShape->begin(), targetShape->end(),
495 adjustedTargetShape.begin() + rankDiff);
496
497 int64_t adjustedTargetShapeRank = adjustedTargetShape.size();
498 // Prepare the result vector.
499 Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
500 rewriter.getZeroAttr(dstVecType));
501 SmallVector<int64_t> strides(adjustedTargetShapeRank, 1);
502 VectorType unrolledVecType =
503 VectorType::get(*targetShape, dstVecType.getElementType());
504
505 // Create the unrolled computation.
506 for (SmallVector<int64_t> offsets :
507 StaticTileOffsetRange(originalSize, adjustedTargetShape)) {
508 SmallVector<Value> extractOperands;
509 for (OpOperand &operand : op->getOpOperands()) {
510 auto vecType = dyn_cast<VectorType>(operand.get().getType());
511 if (!vecType) {
512 extractOperands.push_back(operand.get());
513 continue;
514 }
515 Value extracted = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
516 loc, operand.get(), offsets, adjustedTargetShape, strides);
517
518 // Reshape to remove leading unit dims if needed
519 if (adjustedTargetShapeRank > targetShapeRank) {
520 extracted = rewriter.createOrFold<vector::ShapeCastOp>(
521 loc, VectorType::get(*targetShape, vecType.getElementType()),
522 extracted);
523 }
524 extractOperands.push_back(extracted);
525 }
526
527 Operation *newOp = cloneOpWithOperandsAndTypes(
528 rewriter, loc, op, extractOperands, unrolledVecType);
529
530 Value computeResult = newOp->getResult(0);
531
532 // Use strides sized to targetShape for proper insertion
533 SmallVector<int64_t> insertStrides =
534 (adjustedTargetShapeRank > targetShapeRank)
535 ? SmallVector<int64_t>(targetShapeRank, 1)
536 : strides;
537
538 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
539 loc, computeResult, result, offsets, insertStrides);
540 }
541 rewriter.replaceOp(op, result);
542 return success();
543 }
544
545private:
546 vector::UnrollVectorOptions options;
547};
548
549struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
550 UnrollReductionPattern(MLIRContext *context,
551 const vector::UnrollVectorOptions &options,
552 PatternBenefit benefit = 1)
553 : OpRewritePattern<vector::ReductionOp>(context, benefit),
554 options(options) {}
555
556 LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
557 PatternRewriter &rewriter) const override {
558 std::optional<SmallVector<int64_t>> targetShape =
559 getTargetShape(options, reductionOp);
560 if (!targetShape)
561 return failure();
562 SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
563
564 // Create unrolled vector reduction.
565 Location loc = reductionOp.getLoc();
566 Value accumulator = nullptr;
567 for (SmallVector<int64_t> offsets :
568 StaticTileOffsetRange(originalSize, *targetShape)) {
569 SmallVector<int64_t> strides(offsets.size(), 1);
570 Value slicedOperand =
571 rewriter.createOrFold<vector::ExtractStridedSliceOp>(
572 loc, reductionOp.getVector(), offsets, *targetShape, strides);
573 Operation *newOp = cloneOpWithOperandsAndTypes(
574 rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
575 Value result = newOp->getResult(0);
576
577 if (!accumulator) {
578 // This is the first reduction.
579 accumulator = result;
580 } else {
581 // On subsequent reduction, combine with the accumulator.
582 accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(),
583 accumulator, result);
584 }
585 }
586
587 rewriter.replaceOp(reductionOp, accumulator);
588 return success();
589 }
590
591private:
592 const vector::UnrollVectorOptions options;
593};
594
595struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
596 UnrollTransposePattern(MLIRContext *context,
597 const vector::UnrollVectorOptions &options,
598 PatternBenefit benefit = 1)
599 : OpRewritePattern<vector::TransposeOp>(context, benefit),
600 options(options) {}
601
602 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
603 PatternRewriter &rewriter) const override {
604 if (transposeOp.getResultVectorType().getRank() == 0)
605 return failure();
606 auto targetShape = getTargetShape(options, transposeOp);
607 if (!targetShape)
608 return failure();
609 auto originalVectorType = transposeOp.getResultVectorType();
610 SmallVector<int64_t> strides(targetShape->size(), 1);
611 Location loc = transposeOp.getLoc();
612 ArrayRef<int64_t> originalSize = originalVectorType.getShape();
613
614 // Prepare the result vector;
615 Value result =
616 arith::ConstantOp::create(rewriter, loc, originalVectorType,
617 rewriter.getZeroAttr(originalVectorType));
618 ArrayRef<int64_t> permutation = transposeOp.getPermutation();
619
620 // Unroll the computation.
621 for (SmallVector<int64_t> elementOffsets :
622 StaticTileOffsetRange(originalSize, *targetShape)) {
623 SmallVector<int64_t> permutedOffsets(elementOffsets.size());
624 SmallVector<int64_t> permutedShape(elementOffsets.size());
625 // Compute the source offsets and shape.
626 for (auto indices : llvm::enumerate(permutation)) {
627 permutedOffsets[indices.value()] = elementOffsets[indices.index()];
628 permutedShape[indices.value()] = (*targetShape)[indices.index()];
629 }
630 Value slicedOperand =
631 rewriter.createOrFold<vector::ExtractStridedSliceOp>(
632 loc, transposeOp.getVector(), permutedOffsets, permutedShape,
633 strides);
634 Value transposedSlice = rewriter.createOrFold<vector::TransposeOp>(
635 loc, slicedOperand, permutation);
636 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
637 loc, transposedSlice, result, elementOffsets, strides);
638 }
639 rewriter.replaceOp(transposeOp, result);
640 return success();
641 }
642
643private:
644 vector::UnrollVectorOptions options;
645};
646
647struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
648 UnrollGatherPattern(MLIRContext *context,
649 const vector::UnrollVectorOptions &options,
650 PatternBenefit benefit = 1)
651 : OpRewritePattern<vector::GatherOp>(context, benefit), options(options) {
652 }
653
654 LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
655 PatternRewriter &rewriter) const override {
656 VectorType sourceVectorType = gatherOp.getVectorType();
657 if (sourceVectorType.getRank() == 0)
658 return failure();
659 auto targetShape = getTargetShape(options, gatherOp);
660 if (!targetShape)
661 return failure();
662 SmallVector<int64_t> strides(targetShape->size(), 1);
663 Location loc = gatherOp.getLoc();
664 ArrayRef<int64_t> originalSize = gatherOp.getVectorType().getShape();
665
666 // Prepare the result vector;
667 Value result =
668 arith::ConstantOp::create(rewriter, loc, sourceVectorType,
669 rewriter.getZeroAttr(sourceVectorType));
670 auto targetType =
671 VectorType::get(*targetShape, sourceVectorType.getElementType());
672
673 SmallVector<int64_t> loopOrder =
674 getUnrollOrder(originalSize.size(), gatherOp, options);
675 for (SmallVector<int64_t> elementOffsets :
676 StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
677 // To get the unrolled gather, extract the same slice based on the
678 // decomposed shape from each of the index, mask, and pass-through
679 // vectors.
680 Value indexSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
681 loc, gatherOp.getIndices(), elementOffsets, *targetShape, strides);
682 Value maskSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
683 loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
684 Value passThruSubVec =
685 rewriter.createOrFold<vector::ExtractStridedSliceOp>(
686 loc, gatherOp.getPassThru(), elementOffsets, *targetShape,
687 strides);
688 auto slicedGather = vector::GatherOp::create(
689 rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getOffsets(),
690 indexSubVec, maskSubVec, passThruSubVec);
691
692 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
693 loc, slicedGather, result, elementOffsets, strides);
694 }
695 rewriter.replaceOp(gatherOp, result);
696 return success();
697 }
698
699private:
700 vector::UnrollVectorOptions options;
701};
702
703struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> {
704 UnrollLoadPattern(MLIRContext *context,
705 const vector::UnrollVectorOptions &options,
706 PatternBenefit benefit = 1)
707 : OpRewritePattern<vector::LoadOp>(context, benefit), options(options) {}
708
709 LogicalResult matchAndRewrite(vector::LoadOp loadOp,
710 PatternRewriter &rewriter) const override {
711 VectorType vecType = loadOp.getVectorType();
712
713 auto targetShape = getTargetShape(options, loadOp);
714 if (!targetShape)
715 return failure();
716
717 Location loc = loadOp.getLoc();
718 ArrayRef<int64_t> originalShape = vecType.getShape();
719 SmallVector<int64_t> strides(targetShape->size(), 1);
720
721 Value result = arith::ConstantOp::create(rewriter, loc, vecType,
722 rewriter.getZeroAttr(vecType));
723
724 SmallVector<int64_t> loopOrder =
725 getUnrollOrder(originalShape.size(), loadOp, options);
726
727 auto targetVecType =
728 VectorType::get(*targetShape, vecType.getElementType());
729
730 for (SmallVector<int64_t> offsets :
731 StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
732 SmallVector<Value> indices =
733 sliceLoadStoreIndices(rewriter, loc, loadOp.getIndices(), offsets);
734 Value slicedLoad = vector::LoadOp::create(rewriter, loc, targetVecType,
735 loadOp.getBase(), indices);
736 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
737 loc, slicedLoad, result, offsets, strides);
738 }
739 rewriter.replaceOp(loadOp, result);
740 return success();
741 }
742
743private:
744 vector::UnrollVectorOptions options;
745};
746
747struct UnrollStorePattern : public OpRewritePattern<vector::StoreOp> {
748 UnrollStorePattern(MLIRContext *context,
749 const vector::UnrollVectorOptions &options,
750 PatternBenefit benefit = 1)
751 : OpRewritePattern<vector::StoreOp>(context, benefit), options(options) {}
752
753 LogicalResult matchAndRewrite(vector::StoreOp storeOp,
754 PatternRewriter &rewriter) const override {
755 VectorType vecType = storeOp.getVectorType();
756
757 auto targetShape = getTargetShape(options, storeOp);
758 if (!targetShape)
759 return failure();
760
761 Location loc = storeOp.getLoc();
762 ArrayRef<int64_t> originalShape = vecType.getShape();
763 SmallVector<int64_t> strides(targetShape->size(), 1);
764
765 Value base = storeOp.getBase();
766 Value vector = storeOp.getValueToStore();
767
768 SmallVector<int64_t> loopOrder =
769 getUnrollOrder(originalShape.size(), storeOp, options);
770
771 for (SmallVector<int64_t> offsets :
772 StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
773 SmallVector<Value> indices =
774 sliceLoadStoreIndices(rewriter, loc, storeOp.getIndices(), offsets);
775 Value slice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
776 loc, vector, offsets, *targetShape, strides);
777 vector::StoreOp::create(rewriter, loc, slice, base, indices);
778 }
779 rewriter.eraseOp(storeOp);
780 return success();
781 }
782
783private:
784 vector::UnrollVectorOptions options;
785};
786
787struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
788 UnrollBroadcastPattern(MLIRContext *context,
789 const vector::UnrollVectorOptions &options,
790 PatternBenefit benefit = 1)
791 : OpRewritePattern<vector::BroadcastOp>(context, benefit),
792 options(options) {}
793
794 LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
795 PatternRewriter &rewriter) const override {
796 auto targetShape = getTargetShape(options, broadcastOp);
797 if (!targetShape)
798 return failure();
799
800 Location loc = broadcastOp.getLoc();
801 VectorType srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
802 VectorType resType = broadcastOp.getResultVectorType();
803 VectorType targetType =
804 resType.cloneWith(*targetShape, resType.getElementType());
805 Value result = arith::ConstantOp::create(rewriter, loc, resType,
806 rewriter.getZeroAttr(resType));
807
808 SmallVector<int64_t> originalShape = *broadcastOp.getShapeForUnroll();
809 SmallVector<int64_t> strides(originalShape.size(), 1);
810
811 for (SmallVector<int64_t> offsets :
812 StaticTileOffsetRange(originalShape, *targetShape)) {
813 Value newSrc;
814 if (!srcType) {
815 // Scalar to vector broadcast.
816 newSrc = broadcastOp.getSource();
817 } else {
818 // Vector to vector broadcast.
819 int64_t rank = srcType.getRank();
820 SmallVector<int64_t> srcOffsets(offsets.end() - rank, offsets.end());
821 SmallVector<int64_t> srcShape(targetShape->end() - rank,
822 targetShape->end());
823 SmallVector<int64_t> srcStrides(strides.end() - rank, strides.end());
824 // adjust the offset and shape for src if the corresponding dim is 1.
825 for (int64_t i = 0; i < rank; ++i) {
826 if (srcType.getDimSize(i) == 1) {
827 srcOffsets[i] = 0;
828 srcShape[i] = 1;
829 }
830 }
831 newSrc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
832 loc, broadcastOp.getSource(), srcOffsets, srcShape, srcStrides);
833 }
834
835 Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, broadcastOp,
836 newSrc, targetType);
837
838 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
839 loc, newOp->getResult(0), result, offsets, strides);
840 }
841
842 rewriter.replaceOp(broadcastOp, result);
843 return success();
844 }
845
846private:
847 vector::UnrollVectorOptions options;
848};
849
850/// Unrolls 2 or more dimensional `vector.to_elements` ops by unrolling the
851/// outermost dimension of the operand. For example:
852///
853/// ```
854/// %0:4 = vector.to_elements %v : vector<2x2xf32>
855///
856/// ==>
857///
858/// %v0 = vector.extract %v[0] : vector<2x2xf32> from vector<2x2x2xf32>
859/// %v1 = vector.extract %v[1] : vector<2x2xf32> from vector<2x2x2xf32>
860/// %0:4 = vector.to_elements %v0 : vector<2x2xf32>
861/// %1:4 = vector.to_elements %v1 : vector<2x2xf32>
862/// ```
863///
864/// When this pattern is applied until a fixed-point is reached,
865/// this will produce a sequence of 1-d from_elements
866/// ops.
867struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
868 UnrollToElements(MLIRContext *context,
869 const vector::UnrollVectorOptions &options,
870 PatternBenefit benefit = 1)
871 : OpRewritePattern<vector::ToElementsOp>(context, benefit),
872 options(options) {}
873
874 LogicalResult matchAndRewrite(vector::ToElementsOp op,
875 PatternRewriter &rewriter) const override {
876
877 TypedValue<VectorType> source = op.getSource();
878 FailureOr<SmallVector<Value>> result =
879 vector::unrollVectorValue(source, rewriter);
880 if (failed(result)) {
881 return failure();
882 }
883 SmallVector<Value> vectors = *result;
884
885 SmallVector<Value> results;
886 for (Value vector : vectors) {
887 auto subElements =
888 vector::ToElementsOp::create(rewriter, op.getLoc(), vector);
889 llvm::append_range(results, subElements.getResults());
890 }
891 rewriter.replaceOp(op, results);
892 return success();
893 }
894
895private:
896 vector::UnrollVectorOptions options;
897};
898
899/// This pattern unrolls `vector.step` operations according to the provided
900/// target unroll shape. It decomposes a large step vector into smaller step
901/// vectors (segments) and assembles the result by inserting each computed
902/// segment into the appropriate offset of the original vector.
903///
904/// The pattern does not support scalable vectors and will fail to match them.
905///
906/// For each segment, it adds the base step vector and the segment's offset,
907/// then inserts the result into the output vector at the corresponding
908/// position.
909///
910/// Example:
911/// Given a step operation:
912/// %0 = vector.step : vector<8xindex>
913///
914/// and a target unroll shape of <4>, the pattern produces:
915///
916/// %base = vector.step : vector<4xindex>
917/// %zero = arith.constant dense<0> : vector<8xindex>
918/// %result0 = vector.insert_strided_slice %base, %zero
919/// {offsets = [0], strides = [1]} : vector<4xindex> into vector<8xindex>
920/// %offset = arith.constant dense<4> : vector<4xindex>
921/// %segment1 = arith.addi %base, %offset : vector<4xindex>
922/// %result1 = vector.insert_strided_slice %segment1, %result0
923/// {offsets = [4], strides = [1]} : vector<4xindex> into vector<8xindex>
924///
925struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> {
926 UnrollStepPattern(MLIRContext *context,
927 const vector::UnrollVectorOptions &options,
928 PatternBenefit benefit = 1)
929 : OpRewritePattern<vector::StepOp>(context, benefit), options(options) {}
930
931 LogicalResult matchAndRewrite(vector::StepOp stepOp,
932 PatternRewriter &rewriter) const override {
933 std::optional<SmallVector<int64_t>> targetShape =
934 getTargetShape(options, stepOp);
935 if (!targetShape)
936 return failure();
937
938 VectorType vecType = stepOp.getType();
939 if (vecType.isScalable()) {
940 // Scalable vectors are not supported by this pattern.
941 return failure();
942 }
943 int64_t originalSize = vecType.getShape()[0];
944 Location loc = stepOp.getLoc();
945 SmallVector<int64_t> strides(1, 1);
946
947 Value result = arith::ConstantOp::create(rewriter, loc, vecType,
948 rewriter.getZeroAttr(vecType));
949
950 auto targetVecType =
951 VectorType::get(*targetShape, vecType.getElementType());
952 Value baseStep = vector::StepOp::create(rewriter, loc, targetVecType);
953 for (const SmallVector<int64_t> &offsets :
954 StaticTileOffsetRange({originalSize}, *targetShape)) {
955 Value bcastOffset = arith::ConstantOp::create(
956 rewriter, loc, targetVecType,
958 targetVecType,
959 IntegerAttr::get(targetVecType.getElementType(), offsets[0])));
960 Value tileStep =
961 arith::AddIOp::create(rewriter, loc, baseStep, bcastOffset);
962
963 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
964 loc, tileStep, result, offsets, strides);
965 }
966 rewriter.replaceOp(stepOp, result);
967 return success();
968 }
969
970private:
971 vector::UnrollVectorOptions options;
972};
973
974/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
975/// outermost dimension. For example:
976/// ```
977/// %v = vector.from_elements %e0, %e1, %e2, %e3, %e4, %e5 : vector<2x3xf32>
978///
979/// ==>
980///
981/// %0 = ub.poison : vector<2x3xf32>
982/// %v0 = vector.from_elements %e0, %e1, %e2 : vector<3xf32>
983/// %1 = vector.insert %v0, %0 [0] : vector<3xf32> into vector<2x3xf32>
984/// %v1 = vector.from_elements %e3, %e4, %e5 : vector<3xf32>
985/// %v = vector.insert %v1, %1 [1] : vector<3xf32> into vector<2x3xf32>
986/// ```
987///
988/// When this pattern is applied until a fixed-point is reached,
989/// this will produce a sequence of 1-d from_elements
990/// ops.
991struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
992 UnrollFromElements(MLIRContext *context,
993 const vector::UnrollVectorOptions &options,
994 PatternBenefit benefit = 1)
995 : OpRewritePattern<vector::FromElementsOp>(context, benefit),
996 options(options) {}
997
998 LogicalResult matchAndRewrite(vector::FromElementsOp op,
999 PatternRewriter &rewriter) const override {
1000 ValueRange allElements = op.getElements();
1001
1002 auto unrollFromElementsFn = [&](PatternRewriter &rewriter, Location loc,
1003 VectorType subTy, int64_t index) {
1004 size_t subTyNumElements = subTy.getNumElements();
1005 assert((index + 1) * subTyNumElements <= allElements.size() &&
1006 "out of bounds");
1007 ValueRange subElements =
1008 allElements.slice(index * subTyNumElements, subTyNumElements);
1009 return vector::FromElementsOp::create(rewriter, loc, subTy, subElements);
1010 };
1011
1012 return unrollVectorOp(op, rewriter, unrollFromElementsFn);
1013 }
1014
1015private:
1016 vector::UnrollVectorOptions options;
1017};
1018
1019/// This pattern unrolls `vector.create_mask` operations into smaller mask
1020/// operations based on the target unroll shape. Each unrolled slice computes
1021/// its local mask size in each dimension (d) as:
1022/// min(max(originalMaskSize[d] - offset[d], 0), unrolledDimSize[d]).
1023/// Example:
1024/// Given a create_mask operation:
1025/// %0 = vector.create_mask %c6, %c10 : vector<8x16xi1> // mask first 6x10
1026/// elements
1027///
1028/// and a target unroll shape of <4x8>, the pattern produces:
1029///
1030/// %false = arith.constant dense<false> : vector<8x16xi1>
1031///
1032/// Slice [0,0]:
1033/// mask size = min(max(6-0, 0), 4) x min(max(10-0, 0), 8) = 4x8
1034/// %mask00 = vector.create_mask %c4, %c8 : vector<4x8xi1>
1035/// %r0 = vector.insert_strided_slice %mask00, %false [0, 0], [1, 1]
1036/// : vector<4x8xi1> into vector<8x16xi1>
1037/// Slice [0,8]:
1038/// mask size = min(max(6-0, 0), 4) x min(max(10-8, 0), 8) = 4x2
1039/// %mask01 = vector.create_mask %c4, %c2 : vector<4x8xi1>
1040/// %r1 = vector.insert_strided_slice %mask01, %r0 [0, 8], [1, 1]
1041/// : vector<4x8xi1> into vector<8x16xi1>
1042/// Slice [4,0]:
1043/// mask size = min(max(6-4, 0), 4) x min(max(10-0, 0), 8) = 2x8
1044/// %mask10 = vector.create_mask %c2, %c8 : vector<4x8xi1>
1045/// %r2 = vector.insert_strided_slice %mask10, %r1 [4, 0], [1, 1]
1046/// : vector<4x8xi1> into vector<8x16xi1>
1047/// Slice [4,8]:
1048/// mask size = min(max(6-4, 0), 4) x min(max(10-8, 0), 8) = 2x2
1049/// %mask11 = vector.create_mask %c2, %c2 : vector<4x8xi1>
1050/// %result = vector.insert_strided_slice %mask11, %r2 [4, 8], [1, 1]
1051/// : vector<4x8xi1> into vector<8x16xi1>
1052struct UnrollCreateMaskPattern : public OpRewritePattern<vector::CreateMaskOp> {
1053 UnrollCreateMaskPattern(MLIRContext *context,
1054 const vector::UnrollVectorOptions &options,
1055 PatternBenefit benefit = 1)
1056 : OpRewritePattern<vector::CreateMaskOp>(context, benefit),
1057 options(options) {}
1058
1059 LogicalResult matchAndRewrite(vector::CreateMaskOp createMaskOp,
1060 PatternRewriter &rewriter) const override {
1061 auto targetShape = getTargetShape(options, createMaskOp);
1062 if (!targetShape)
1063 return failure();
1064
1065 VectorType resultType = createMaskOp.getVectorType();
1066 SmallVector<int64_t> originalSize = *createMaskOp.getShapeForUnroll();
1067 Location loc = createMaskOp.getLoc();
1068
1069 Value result = arith::ConstantOp::create(rewriter, loc, resultType,
1070 rewriter.getZeroAttr(resultType));
1071 VectorType targetVectorType =
1072 VectorType::get(*targetShape, rewriter.getI1Type());
1073 SmallVector<int64_t> strides(targetShape->size(), 1);
1074
1075 // In each dimension (d), each unrolled vector computes its mask size as:
1076 // min(max(originalMaskOperands[d] - offset[d], 0), unrolledDimSize[d]).
1077 for (SmallVector<int64_t> offsets :
1078 StaticTileOffsetRange(originalSize, *targetShape)) {
1079 SmallVector<Value> unrolledOperands;
1080
1081 for (auto [i, originalMaskOperand] :
1082 llvm::enumerate(createMaskOp.getOperands())) {
1083 Value offsetVal =
1084 arith::ConstantIndexOp::create(rewriter, loc, offsets[i]);
1085 Value adjustedMaskSize = rewriter.createOrFold<arith::SubIOp>(
1086 loc, originalMaskOperand, offsetVal);
1087 Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
1088 Value unrolledDimSize =
1089 arith::ConstantIndexOp::create(rewriter, loc, (*targetShape)[i]);
1090 Value nonNegative =
1091 rewriter.createOrFold<arith::MaxSIOp>(loc, adjustedMaskSize, zero);
1092 Value unrolledOperand = rewriter.createOrFold<arith::MinSIOp>(
1093 loc, nonNegative, unrolledDimSize);
1094 unrolledOperands.push_back(unrolledOperand);
1095 }
1096
1097 auto unrolledMask = rewriter.createOrFold<vector::CreateMaskOp>(
1098 loc, targetVectorType, unrolledOperands);
1099 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
1100 loc, unrolledMask, result, offsets, strides);
1101 }
1102 rewriter.replaceOp(createMaskOp, result);
1103 return success();
1104 }
1105
1106private:
1107 vector::UnrollVectorOptions options;
1108};
1109
1110/// This pattern unrolls `vector.constant_mask` operations into smaller mask
1111/// operations based on the target unroll shape. Each unrolled slice computes
1112/// whether its elements should be masked based on the original mask dimensions
1113/// and the slice's offset position.
1114///
1115/// Example:
1116/// Given a constant_mask operation:
1117/// %0 = vector.constant_mask [6, 10] : vector<8x16xi1>
1118///
1119/// and a target unroll shape of <4x8>, the pattern produces:
1120///
1121/// %false = arith.constant dense<false> : vector<8x16xi1>
1122///
1123/// Slice [0,0]: elements [0:4, 0:8] - fully within [6, 10] bounds
1124/// %mask00 = vector.constant_mask [4, 8] : vector<4x8xi1>
1125/// %r0 = vector.insert_strided_slice %mask00, %false [0, 0], [1, 1]
1126/// : vector<4x8xi1> into vector<8x16xi1>
1127///
1128/// Slice [0,8]: elements [0:4, 8:16] - partially within bounds
1129/// %mask01 = vector.constant_mask [4, 2] : vector<4x8xi1>
1130/// %r1 = vector.insert_strided_slice %mask01, %r0 [0, 8], [1, 1]
1131/// : vector<4x8xi1> into vector<8x16xi1>
1132///
1133/// Slice [4,0]: elements [4:8, 0:8] - partially within bounds
1134/// %mask10 = vector.constant_mask [2, 8] : vector<4x8xi1>
1135/// %r2 = vector.insert_strided_slice %mask10, %r1 [4, 0], [1, 1]
1136/// : vector<4x8xi1> into vector<8x16xi1>
1137///
1138/// Slice [4,8]: elements [4:8, 8:16] - partially within bounds
1139/// %mask11 = vector.constant_mask [2, 2] : vector<4x8xi1>
1140/// %result = vector.insert_strided_slice %mask11, %r2 [4, 8], [1, 1]
1141/// : vector<4x8xi1> into vector<8x16xi1>
1142struct UnrollConstantMaskPattern
1143 : public OpRewritePattern<vector::ConstantMaskOp> {
1144 UnrollConstantMaskPattern(MLIRContext *context,
1145 const vector::UnrollVectorOptions &options,
1146 PatternBenefit benefit = 1)
1147 : OpRewritePattern<vector::ConstantMaskOp>(context, benefit),
1148 options(options) {}
1149
1150 LogicalResult matchAndRewrite(vector::ConstantMaskOp constantMaskOp,
1151 PatternRewriter &rewriter) const override {
1152 std::optional<SmallVector<int64_t>> targetShape =
1153 getTargetShape(options, constantMaskOp);
1154 if (!targetShape)
1155 return failure();
1156
1157 VectorType resultType = constantMaskOp.getVectorType();
1158 SmallVector<int64_t> originalSize = *constantMaskOp.getShapeForUnroll();
1159 Location loc = constantMaskOp.getLoc();
1160
1161 Value result = arith::ConstantOp::create(rewriter, loc, resultType,
1162 rewriter.getZeroAttr(resultType));
1163 VectorType targetVectorType =
1164 VectorType::get(*targetShape, rewriter.getI1Type());
1165 SmallVector<int64_t> strides(targetShape->size(), 1);
1166
1167 // In each dimension (d), each unrolled vector computes its mask size as:
1168 // min(max(originalMaskDim[d] - offset[d], 0), unrolledDimSize[d]).
1169 for (const SmallVector<int64_t> &offsets :
1170 StaticTileOffsetRange(originalSize, *targetShape)) {
1171 SmallVector<int64_t> unrolledMaskDims;
1172
1173 for (auto [i, originalMaskDim] :
1174 llvm::enumerate(constantMaskOp.getMaskDimSizes())) {
1175 // Calculate how many elements in this dimension should be masked
1176 // for this particular slice
1177 int64_t adjustedMaskSize =
1178 std::max(originalMaskDim - offsets[i], static_cast<int64_t>(0));
1179 int64_t unrolledMaskDim =
1180 std::min(adjustedMaskSize, static_cast<int64_t>((*targetShape)[i]));
1181 unrolledMaskDims.push_back(unrolledMaskDim);
1182 }
1183
1184 auto unrolledMask = rewriter.createOrFold<vector::ConstantMaskOp>(
1185 loc, targetVectorType, unrolledMaskDims);
1186 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
1187 loc, unrolledMask, result, offsets, strides);
1188 }
1189 rewriter.replaceOp(constantMaskOp, result);
1190 return success();
1191 }
1192
1193private:
1194 vector::UnrollVectorOptions options;
1195};
1196
1197/// Checks whether extractShape is a contiguous slice of shape.
1198/// For extractShape to be contiguous in shape:
1199/// 1) All but the leading dimension of extractShape and shape must match
1200/// exactly. 2) The total number of elements in shape must be evenly divisible
1201/// by
1202/// the total number of elements in extractShape.
1203/// Examples:
1204/// isContiguous([4, 4], [8, 4]) == true
1205/// isContiguous([2, 4], [8, 4]) == true
1206/// isContiguous([2, 2], [8, 4]) == false
1207/// Removes leading unit dimensions to handle cases like:
1208/// isContiguous([1, 16], [1, 32]) == true
1209static bool isContiguous(ArrayRef<int64_t> extractShape,
1211
1212 if (extractShape.empty() || shape.empty() ||
1213 extractShape.size() > shape.size())
1214 return false;
1215
1216 while (extractShape.size() > 1 && extractShape.front() == 1)
1217 extractShape = extractShape.drop_front();
1218
1219 while (shape.size() > 1 && shape.front() == 1) {
1220 shape = shape.drop_front();
1221 }
1222
1223 size_t rankDiff = shape.size() - extractShape.size();
1224 if (!llvm::equal(extractShape.drop_front(), shape.drop_front(rankDiff + 1)))
1225 return false;
1226
1227 int64_t extractElements = ShapedType::getNumElements(extractShape);
1228 int64_t shapeElements = ShapedType::getNumElements(shape);
1229 return shapeElements % extractElements == 0;
1230}
1231
1232/// Determines what shape to use with `vector.extract_strided_slice` to extract
1233/// a contiguous memory region from a source vector. The extraction must be
1234/// contiguous and contain exactly the specified number of elements. If such an
1235/// extraction shape cannot be determined, returns std::nullopt.
1236/// EXAMPLE 1:
1237/// sourceShape = [16], targetElements = 8
1238/// Working right-to-left:
1239/// - Take min(8, 16) = 8 from only dim → extractShape = [8],
1240/// remaining = 8/8 = 1
1241/// Result: [8]
1242///
1243/// EXAMPLE 2:
1244/// sourceShape = [4, 4], targetElements = 8
1245/// Working right-to-left:
1246/// - Take min(8, 4) = 4 from last dim → extractShape = [4],
1247/// remaining = 8/4 = 2
1248/// - Take min(2, 4) = 2 from first dim → extractShape = [2, 4],
1249/// remaining = 2/2 = 1
1250/// Result: [2, 4]
1251static std::optional<SmallVector<int64_t>>
1252calculateSourceExtractShape(ArrayRef<int64_t> sourceShape,
1253 int64_t targetElements) {
1254 SmallVector<int64_t> extractShape;
1255 int64_t remainingElements = targetElements;
1256
1257 // Build extract shape from innermost dimension outward to ensure contiguity.
1258 for (int i = sourceShape.size() - 1; i >= 0 && remainingElements > 1; --i) {
1259 int64_t takeFromDim = std::min(remainingElements, sourceShape[i]);
1260 extractShape.insert(extractShape.begin(), takeFromDim);
1261
1262 if (remainingElements % takeFromDim != 0)
1263 return std::nullopt; // Not evenly divisible.
1264 remainingElements /= takeFromDim;
1265 }
1266
1267 // Fill remaining dimensions with 1.
1268 while (extractShape.size() < sourceShape.size())
1269 extractShape.insert(extractShape.begin(), 1);
1270
1271 if (ShapedType::getNumElements(extractShape) != targetElements)
1272 return std::nullopt;
1273
1274 return extractShape;
1275}
1276
1277// Convert result offsets to source offsets via linear position.
1279calculateSourceOffsets(ArrayRef<int64_t> resultOffsets,
1280 ArrayRef<int64_t> sourceShape,
1281 ArrayRef<int64_t> resultShape) {
1282 // Convert result offsets to linear position.
1283 int64_t linearIndex = linearize(resultOffsets, computeStrides(resultShape));
1284 // Convert linear position to source offsets.
1285 return delinearize(linearIndex, computeStrides(sourceShape));
1286}
1287
1288/// This pattern unrolls `vector.shape_cast` operations according to the
1289/// provided target unroll shape. It unrolls a large shape cast into smaller
1290/// shape casts by extracting contiguous slices from the source vector, casting
1291/// each slice to the target shape, and assembling the result by inserting each
1292/// computed segment into the appropriate offset of the result vector.
1293///
1294/// This pattern only applies when contiguous slices can be extracted from the
1295/// source vector and inserted into the result vector such that each slice
1296/// remains a valid vector (and not decompose to scalars). In these cases, the
1297/// unrolling proceeds as:
1298/// vector.extract_strided_slice -> vector.shape_cast (on the slice) ->
1299/// vector.insert_strided_slice.
1300///
1301/// Example:
1302/// Given a shape cast operation:
1303/// %0 = vector.shape_cast %src : vector<8x2xf32> to vector<4x4xf32>
1304///
1305/// and a target unroll shape of <2x4>, the pattern produces:
1306///
1307/// %zero = arith.constant dense<0.0> : vector<4x4xf32>
1308/// %s0 = vector.extract_strided_slice %src [0, 0], [4, 2], [1, 1]
1309/// : vector<8x2xf32> to vector<4x2xf32>
1310/// %sc0 = vector.shape_cast %s0 : vector<4x2xf32> to vector<2x4xf32>
1311/// %i0 = vector.insert_strided_slice %sc0, %zero [0, 0], [1, 1]
1312/// : vector<2x4xf32> into vector<4x4xf32>
1313/// %s1 = vector.extract_strided_slice %src [4, 0], [4, 2], [1, 1]
1314/// : vector<8x2xf32> to vector<4x2xf32>
1315/// %sc1 = vector.shape_cast %s1 : vector<4x2xf32> to vector<2x4xf32>
1316/// %i1 = vector.insert_strided_slice %sc1, %i0 [2, 0], [1, 1]
1317/// : vector<2x4xf32> into vector<4x4xf32>
1318///
1319struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
1320 UnrollShapeCastPattern(MLIRContext *context,
1321 const vector::UnrollVectorOptions &options,
1322 PatternBenefit benefit = 1)
1323 : OpRewritePattern<vector::ShapeCastOp>(context, benefit),
1324 options(options) {}
1325
1326 LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
1327 PatternRewriter &rewriter) const override {
1328 std::optional<SmallVector<int64_t>> targetShape =
1329 getTargetShape(options, shapeCastOp);
1330 if (!targetShape)
1331 return failure();
1332
1333 VectorType sourceType = shapeCastOp.getSourceVectorType();
1334 VectorType resultType = shapeCastOp.getResultVectorType();
1335 ArrayRef<int64_t> sourceShape = sourceType.getShape();
1336 ArrayRef<int64_t> resultShape = resultType.getShape();
1337
1338 if (!isContiguous(*targetShape, resultShape))
1339 return rewriter.notifyMatchFailure(
1340 shapeCastOp, "Only supports cases where target shape is "
1341 "contiguous in result vector shape");
1342
1343 int64_t targetElements = ShapedType::getNumElements(*targetShape);
1344
1345 // Calculate the shape to extract from source.
1346 std::optional<SmallVector<int64_t>> extractShape =
1347 calculateSourceExtractShape(sourceShape, targetElements);
1348 if (!extractShape)
1349 return rewriter.notifyMatchFailure(
1350 shapeCastOp,
1351 "cannot extract target number of elements contiguously from source");
1352
1353 Location loc = shapeCastOp.getLoc();
1354
1355 // Create result vector initialized to zero.
1356 Value result = arith::ConstantOp::create(rewriter, loc, resultType,
1357 rewriter.getZeroAttr(resultType));
1358
1359 VectorType targetType =
1360 VectorType::get(*targetShape, sourceType.getElementType());
1361
1362 SmallVector<int64_t> extractStrides(extractShape->size(), 1);
1363 SmallVector<int64_t> insertStrides(targetShape->size(), 1);
1364
1365 for (SmallVector<int64_t> resultOffsets :
1366 StaticTileOffsetRange(resultShape, *targetShape)) {
1367 SmallVector<int64_t> sourceOffsets =
1368 calculateSourceOffsets(resultOffsets, sourceShape, resultShape);
1369 Value sourceChunk = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
1370 loc, shapeCastOp.getSource(), sourceOffsets, *extractShape,
1372 Value targetChunk = rewriter.createOrFold<vector::ShapeCastOp>(
1373 loc, targetType, sourceChunk);
1374 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
1375 loc, targetChunk, result, resultOffsets, insertStrides);
1376 }
1377
1378 rewriter.replaceOp(shapeCastOp, result);
1379 return success();
1380 }
1381
1382private:
1383 vector::UnrollVectorOptions options;
1384};
1385
1386// Unroll vector::BitCastOp into smaller slice-based bitcast operations.
1387// Decomposes the result vector into target shape chunks and bitcasts
1388// corresponding source slices, accounting for element bitwidth ratios.
1389/// Example:
1390/// Given a bitcast Op:
1391///
1392/// vector.bitcast %src : vector<4x8xf32>
1393///
1394/// and a target unroll shape of <2x4>, the pattern produces:
1395///
1396/// %slice_0 = vector.extract_strided_slice %lhs[0, 0] : vector<2x4xf32>
1397/// %slice_0 = vector.bitcast %slice_0 : vector<2x4xf32>
1398/// %result = vector.insert_strided_slice %slice_0, %init[0, 0]
1399/// // ... repeat for remaining slices
1400struct UnrollBitCastPattern : public OpRewritePattern<vector::BitCastOp> {
1401 UnrollBitCastPattern(MLIRContext *context,
1402 const vector::UnrollVectorOptions &options,
1403 PatternBenefit benefit = 1)
1404 : OpRewritePattern<vector::BitCastOp>(context, benefit),
1405 options(options) {}
1406
1407 LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
1408 PatternRewriter &rewriter) const override {
1409 auto targetShape = getTargetShape(options, bitCastOp);
1410 if (!targetShape)
1411 return rewriter.notifyMatchFailure(bitCastOp,
1412 "failed to get target shape");
1413
1414 VectorType sourceType = bitCastOp.getSourceVectorType();
1415 VectorType resultType = bitCastOp.getResultVectorType();
1416 ArrayRef<int64_t> resultShape = resultType.getShape();
1417 Location loc = bitCastOp.getLoc();
1418
1419 if (targetShape->size() != resultShape.size())
1420 return rewriter.notifyMatchFailure(
1421 bitCastOp, "target shape rank must match result rank");
1422
1423 unsigned sourceElementBits = sourceType.getElementTypeBitWidth();
1424 unsigned resultElementBits = resultType.getElementTypeBitWidth();
1425
1426 SmallVector<int64_t> sourceSliceShape(targetShape->begin(),
1427 targetShape->end());
1428 int64_t lastDim = sourceSliceShape.size() - 1;
1429
1430 sourceSliceShape[lastDim] =
1431 ((*targetShape)[lastDim] * resultElementBits) / sourceElementBits;
1432
1433 Value result = arith::ConstantOp::create(rewriter, loc, resultType,
1434 rewriter.getZeroAttr(resultType));
1435 SmallVector<int64_t> resultStrides(targetShape->size(), 1);
1436 SmallVector<int64_t> sourceStrides(sourceSliceShape.size(), 1);
1437
1438 VectorType targetType =
1439 VectorType::get(*targetShape, resultType.getElementType());
1440
1441 for (SmallVector<int64_t> resultOffsets :
1442 StaticTileOffsetRange(resultShape, *targetShape)) {
1443 SmallVector<int64_t> sourceOffsets = resultOffsets;
1444 sourceOffsets[lastDim] =
1445 (resultOffsets[lastDim] * resultElementBits) / sourceElementBits;
1446
1447 Value sourceSlice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
1448 loc, bitCastOp.getSource(), sourceOffsets, sourceSliceShape,
1449 sourceStrides);
1450 Value bitcastSlice = rewriter.createOrFold<vector::BitCastOp>(
1451 loc, targetType, sourceSlice);
1452 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
1453 loc, bitcastSlice, result, resultOffsets, resultStrides);
1454 }
1455
1456 rewriter.replaceOp(bitCastOp, result);
1457 return success();
1458 }
1459
1460private:
1461 vector::UnrollVectorOptions options;
1462};
1463
1464/// Pattern to unroll vector.interleave into smaller slice-sized operations.
1465/// Decomposes a large interleave into slices by extracting slices from both
1466/// input vectors, interleaving them, and inserting back into the result.
1467///
1468/// Example:
1469/// Given an interleave Op:
1470///
1471/// vector.interleave %lhs, %rhs : vector<4x8xf32>
1472///
1473/// and a target unroll shape of <2x4>, the pattern produces:
1474///
1475/// %slice_lhs_0 = vector.extract_strided_slice %lhs[0, 0] : vector<2x2xf32>
1476/// %slice_rhs_0 = vector.extract_strided_slice %rhs[0, 0] : vector<2x2xf32>
1477/// %slice_0 = vector.interleave %slice_lhs_0, %slice_rhs_0
1478/// : vector<2x4xf32>
1479/// %result = vector.insert_strided_slice %slice_0, %init[0, 0]
1480/// // ... repeat for remaining slices
1481struct UnrollInterleavePattern : public OpRewritePattern<vector::InterleaveOp> {
1482 UnrollInterleavePattern(MLIRContext *context,
1483 const vector::UnrollVectorOptions &options,
1484 PatternBenefit benefit = 1)
1485 : OpRewritePattern<vector::InterleaveOp>(context, benefit),
1486 options(options) {}
1487
1488 LogicalResult matchAndRewrite(vector::InterleaveOp interleaveOp,
1489 PatternRewriter &rewriter) const override {
1490 auto targetShape = getTargetShape(options, interleaveOp);
1491 if (!targetShape)
1492 return rewriter.notifyMatchFailure(interleaveOp,
1493 "failed to get target shape");
1494
1495 VectorType resultType = interleaveOp.getResultVectorType();
1496 ArrayRef<int64_t> resultShape = resultType.getShape();
1497 Location loc = interleaveOp.getLoc();
1498
1499 if (targetShape->size() != resultShape.size())
1500 return rewriter.notifyMatchFailure(
1501 interleaveOp, "target shape rank must match result rank");
1502
1503 SmallVector<int64_t> sourceSliceShape(targetShape->begin(),
1504 targetShape->end());
1505 int64_t lastDim = sourceSliceShape.size() - 1;
1506 sourceSliceShape[lastDim] = (*targetShape)[lastDim] / 2;
1507
1508 Value result = arith::ConstantOp::create(rewriter, loc, resultType,
1509 rewriter.getZeroAttr(resultType));
1510 SmallVector<int64_t> resultStrides(targetShape->size(), 1);
1511 SmallVector<int64_t> sourceStrides(sourceSliceShape.size(), 1);
1512
1513 VectorType targetType =
1514 VectorType::get(*targetShape, resultType.getElementType());
1515
1516 for (SmallVector<int64_t> resultOffsets :
1517 StaticTileOffsetRange(resultShape, *targetShape)) {
1518 SmallVector<int64_t> sourceOffsets = resultOffsets;
1519 sourceOffsets[lastDim] = resultOffsets[lastDim] / 2;
1520
1521 Value lhsSlice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
1522 loc, interleaveOp.getLhs(), sourceOffsets, sourceSliceShape,
1523 sourceStrides);
1524 Value rhsSlice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
1525 loc, interleaveOp.getRhs(), sourceOffsets, sourceSliceShape,
1526 sourceStrides);
1527 Value interleaveSlice = rewriter.createOrFold<vector::InterleaveOp>(
1528 loc, targetType, lhsSlice, rhsSlice);
1529 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
1530 loc, interleaveSlice, result, resultOffsets, resultStrides);
1531 }
1532
1533 rewriter.replaceOp(interleaveOp, result);
1534 return success();
1535 }
1536
1537private:
1538 vector::UnrollVectorOptions options;
1539};
1540
1541/// Pattern to unroll vector.deinterleave into smaller slice-sized operations.
1542/// Decomposes a large deinterleave (which splits a vector into even/odd halves)
1543/// by extracting source slices, deinterleaving them, and inserting into two
1544/// result vectors.
1545///
1546/// Example:
1547/// Given a deinterleave Op:
1548///
1549/// vector.deinterleave %src : vector<4x8xf32>
1550///
1551/// and a target unroll shape of <2x4>, the pattern produces:
1552///
1553/// %slice_0 = vector.extract_strided_slice %src[0, 0] : vector<2x4xf32>
1554/// %slice_lhs_0, %slice_rhs_0 = vector.deinterleave %slice_0 :
1555/// vector<2x4xf32> %result1 = vector.insert_strided_slice %slice_lhs_0,
1556/// %init1[0, 0] %result2 = vector.insert_strided_slice %slice_rhs_0,
1557/// %init2[0, 0]
1558/// // ... repeat for remaining slices
1559struct UnrollDeinterleavePattern
1560 : public OpRewritePattern<vector::DeinterleaveOp> {
1561 UnrollDeinterleavePattern(MLIRContext *context,
1562 const vector::UnrollVectorOptions &options,
1563 PatternBenefit benefit = 1)
1564 : OpRewritePattern<vector::DeinterleaveOp>(context, benefit),
1565 options(options) {}
1566
1567 LogicalResult matchAndRewrite(vector::DeinterleaveOp deinterleaveOp,
1568 PatternRewriter &rewriter) const override {
1569 auto targetShape = getTargetShape(options, deinterleaveOp);
1570 if (!targetShape)
1571 return rewriter.notifyMatchFailure(deinterleaveOp,
1572 "failed to get target shape");
1573
1574 VectorType resultType = deinterleaveOp.getResultVectorType();
1575 ArrayRef<int64_t> resultShape = resultType.getShape();
1576 Location loc = deinterleaveOp.getLoc();
1577
1578 if (targetShape->size() != resultShape.size())
1579 return rewriter.notifyMatchFailure(
1580 deinterleaveOp, "target shape rank must match result rank");
1581
1582 SmallVector<int64_t> sourceSliceShape(targetShape->begin(),
1583 targetShape->end());
1584 int64_t lastDim = sourceSliceShape.size() - 1;
1585 sourceSliceShape[lastDim] = (*targetShape)[lastDim] * 2;
1586
1587 Value resultOdd = arith::ConstantOp::create(
1588 rewriter, loc, resultType, rewriter.getZeroAttr(resultType));
1589 Value resultEven = arith::ConstantOp::create(
1590 rewriter, loc, resultType, rewriter.getZeroAttr(resultType));
1591 SmallVector<int64_t> resultStrides(targetShape->size(), 1);
1592 SmallVector<int64_t> sourceStrides(sourceSliceShape.size(), 1);
1593
1594 for (SmallVector<int64_t> resultOffsets :
1595 StaticTileOffsetRange(resultShape, *targetShape)) {
1596 SmallVector<int64_t> sourceOffsets = resultOffsets;
1597 sourceOffsets[lastDim] = resultOffsets[lastDim] * 2;
1598
1599 Value sourceSlice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
1600 loc, deinterleaveOp.getSource(), sourceOffsets, sourceSliceShape,
1601 sourceStrides);
1602
1603 auto deinterleaveSlice =
1604 vector::DeinterleaveOp::create(rewriter, loc, sourceSlice);
1605
1606 resultOdd = rewriter.createOrFold<vector::InsertStridedSliceOp>(
1607 loc, deinterleaveSlice.getRes1(), resultOdd, resultOffsets,
1608 resultStrides);
1609 resultEven = rewriter.createOrFold<vector::InsertStridedSliceOp>(
1610 loc, deinterleaveSlice.getRes2(), resultEven, resultOffsets,
1611 resultStrides);
1612 }
1613
1614 rewriter.replaceOp(deinterleaveOp, ValueRange{resultOdd, resultEven});
1615 return success();
1616 }
1617
1618private:
1619 vector::UnrollVectorOptions options;
1620};
1621
1622} // namespace
1623
1624void mlir::vector::populateVectorUnrollPatterns(
1626 PatternBenefit benefit) {
1627 patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
1628 UnrollContractionPattern, UnrollElementwisePattern,
1629 UnrollReductionPattern, UnrollMultiReductionPattern,
1630 UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
1631 UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
1632 UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern,
1633 UnrollCreateMaskPattern, UnrollConstantMaskPattern,
1634 UnrollBitCastPattern, UnrollInterleavePattern,
1635 UnrollDeinterleavePattern>(patterns.getContext(), options,
1636 benefit);
1637}
1638
1639void mlir::vector::populateVectorToElementsUnrollPatterns(
1640 RewritePatternSet &patterns, PatternBenefit benefit) {
1641 patterns.add<UnrollToElements>(patterns.getContext(), UnrollVectorOptions(),
1642 benefit);
1643}
1644
1645void mlir::vector::populateVectorFromElementsUnrollPatterns(
1646 RewritePatternSet &patterns, PatternBenefit benefit) {
1647 patterns.add<UnrollFromElements>(patterns.getContext(), UnrollVectorOptions(),
1648 benefit);
1649}
return success()
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
lhs
static llvm::ManagedStatic< PassManagerOptions > options
static SmallVector< Value > sliceLoadStoreIndices(PatternRewriter &rewriter, Location loc, OperandRange originalIndices, ArrayRef< int64_t > offsets)
static std::optional< SmallVector< int64_t > > getTargetShape(const vector::UnrollVectorOptions &options, Operation *op)
Return the target shape for unrolling for the given op.
static SmallVector< int64_t > getUnrollOrder(unsigned numLoops, Operation *op, const vector::UnrollVectorOptions &options)
static Operation * cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc, Operation *op, ArrayRef< Value > operands, ArrayRef< Type > resultTypes)
Base type for affine expression.
Definition AffineExpr.h:68
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
ArrayRef< AffineExpr > getResults() const
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:329
IntegerType getI1Type()
Definition Builders.cpp:57
MLIRContext * getContext() const
Definition Builders.h:56
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:209
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition Builders.cpp:462
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:537
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:432
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:240
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:408
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:115
result_range getResults()
Definition Operation.h:440
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:429
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePattern is the common base class for all DAG to DAG replacements.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:384
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
FailureOr< SmallVector< Value > > unrollVectorValue(TypedValue< VectorType >, RewriterBase &)
Generic utility for unrolling values of type vector<NxAxBx...> to N values of type vector<AxBx....
LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter, UnrollVectorOpFn unrollFn)
Include the generated interface declarations.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition AffineMap.h:675
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Return the linearized index of 'offsets' w.r.t.
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Options that control the vector unrolling.