MLIR 23.0.0git
VectorUnroll.cpp
Go to the documentation of this file.
1//===- VectorUnrollDistribute.cpp - patterns to do vector unrolling -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to do vector unrolling and vector distribution.
10//
11//===----------------------------------------------------------------------===//
12
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/InterleavedRange.h"
#include <optional>
23
24#define DEBUG_TYPE "vector-unroll"
25
26using namespace mlir;
27using namespace mlir::vector;
28
29SmallVector<Value> mlir::vector::sliceTransferIndices(
31 AffineMap permutationMap, Location loc, OpBuilder &builder) {
32 MLIRContext *ctx = builder.getContext();
33 auto isBroadcast = [](AffineExpr expr) {
34 if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
35 return constExpr.getValue() == 0;
36 return false;
37 };
38 // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
39 SmallVector<Value> slicedIndices(indices);
40 for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
41 int64_t elementOffset = elementOffsets[dim.index()];
42 if (isBroadcast(dim.value()) || elementOffset == 0)
43 continue;
44 unsigned pos = cast<AffineDimExpr>(dim.value()).getPosition();
45 auto expr = getAffineDimExpr(0, builder.getContext()) +
46 getAffineConstantExpr(elementOffset, ctx);
47 auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
48 slicedIndices[pos] =
49 affine::AffineApplyOp::create(builder, loc, map, indices[pos]);
50 }
51 return slicedIndices;
52}
53
54// Compute the new indices by adding `offsets` to `originalIndices`.
55// If m < n (m = offsets.size(), n = originalIndices.size()),
56// then only the trailing m values in `originalIndices` are updated.
58 Location loc,
59 OperandRange originalIndices,
60 ArrayRef<int64_t> offsets) {
61 assert(offsets.size() <= originalIndices.size() &&
62 "Offsets should not exceed the number of original indices");
63 SmallVector<Value> indices(originalIndices);
64
65 auto start = indices.size() - offsets.size();
66 for (auto [i, offset] : llvm::enumerate(offsets)) {
67 if (offset != 0) {
68 indices[start + i] = arith::AddIOp::create(
69 rewriter, loc, originalIndices[start + i],
70 arith::ConstantIndexOp::create(rewriter, loc, offset));
71 }
72 }
73 return indices;
74}
75
76// Clones `op` into a new operations that takes `operands` and returns
77// `resultTypes`.
79 Operation *op,
80 ArrayRef<Value> operands,
81 ArrayRef<Type> resultTypes) {
82 return builder.create(loc, op->getName().getIdentifier(), operands,
83 resultTypes, op->getAttrs());
84}
85
86/// Return the target shape for unrolling for the given `op`. Return
87/// std::nullopt if the op shouldn't be or cannot be unrolled.
88static std::optional<SmallVector<int64_t>>
90 LDBG() << "Get unroll shape for op " << op->getName().getStringRef();
91 if (options.filterConstraint && failed(options.filterConstraint(op))) {
92 LDBG() << "--no filter constraint -> BAIL";
93 return std::nullopt;
94 }
95 assert(options.nativeShape &&
96 "vector unrolling expects the native shape or native"
97 "shape call back function to be set");
98 auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
99 if (!unrollableVectorOp) {
100 LDBG() << "--not an unrollable op -> BAIL";
101 return std::nullopt;
102 }
103 auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
104 if (!maybeUnrollShape) {
105 LDBG() << "--could not get shape of op " << *op << " -> BAIL";
106 return std::nullopt;
107 }
108 LDBG() << "--vector op shape: " << llvm::interleaved(*maybeUnrollShape);
109
110 std::optional<SmallVector<int64_t>> targetShape = options.nativeShape(op);
111 if (!targetShape) {
112 LDBG() << "--no unrolling target shape defined " << *op << "-> SKIP";
113 return std::nullopt;
114 }
115 LDBG() << "--target shape: " << llvm::interleaved(*targetShape);
116
117 auto maybeShapeRatio = computeShapeRatio(*maybeUnrollShape, *targetShape);
118 if (!maybeShapeRatio) {
119 LDBG() << "--could not compute integral shape ratio -> BAIL";
120 return std::nullopt;
121 }
122 if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
123 LDBG() << "--no unrolling needed -> SKIP";
124 return std::nullopt;
125 }
126 LDBG() << "--found an integral shape ratio to unroll to -> SUCCESS";
127 return targetShape;
128}
129
131getUnrollOrder(unsigned numLoops, Operation *op,
133 SmallVector<int64_t> loopOrder =
134 llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));
135 if (options.traversalOrderCallback != nullptr) {
136 std::optional<SmallVector<int64_t>> order =
137 options.traversalOrderCallback(op);
138 if (order) {
139 loopOrder = std::move(*order);
140 }
141 }
142 return loopOrder;
143}
144
145namespace {
146
147struct UnrollTransferReadPattern
148 : public OpRewritePattern<vector::TransferReadOp> {
149 UnrollTransferReadPattern(MLIRContext *context,
150 const vector::UnrollVectorOptions &options,
151 PatternBenefit benefit = 1)
152 : OpRewritePattern<vector::TransferReadOp>(context, benefit),
153 options(options) {}
154
155 LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
156 PatternRewriter &rewriter) const override {
157 // TODO: support 0-d corner case.
158 if (readOp.getTransferRank() == 0)
159 return failure();
160 if (readOp.getMask())
161 return failure();
162 auto targetShape = getTargetShape(options, readOp);
163 if (!targetShape)
164 return failure();
165 auto sourceVectorType = readOp.getVectorType();
166 SmallVector<int64_t> strides(targetShape->size(), 1);
167 Location loc = readOp.getLoc();
168 ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
169
170 // Prepare the result vector;
171 Value result =
172 arith::ConstantOp::create(rewriter, loc, sourceVectorType,
173 rewriter.getZeroAttr(sourceVectorType));
174 auto targetType =
175 VectorType::get(*targetShape, sourceVectorType.getElementType());
176 SmallVector<Value> originalIndices(readOp.getIndices().begin(),
177 readOp.getIndices().end());
178 SmallVector<int64_t> loopOrder =
179 getUnrollOrder(originalSize.size(), readOp, options);
180 for (SmallVector<int64_t> elementOffsets :
181 StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
182 SmallVector<Value> indices =
183 sliceTransferIndices(elementOffsets, originalIndices,
184 readOp.getPermutationMap(), loc, rewriter);
185 auto slicedRead = vector::TransferReadOp::create(
186 rewriter, loc, targetType, readOp.getBase(), indices,
187 readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
188 readOp.getInBoundsAttr());
189
190 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
191 loc, slicedRead, result, elementOffsets, strides);
192 }
193 rewriter.replaceOp(readOp, result);
194 return success();
195 }
196
197private:
198 vector::UnrollVectorOptions options;
199};
200
201struct UnrollTransferWritePattern
202 : public OpRewritePattern<vector::TransferWriteOp> {
203 UnrollTransferWritePattern(MLIRContext *context,
204 const vector::UnrollVectorOptions &options,
205 PatternBenefit benefit = 1)
206 : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
207 options(options) {}
208
209 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
210 PatternRewriter &rewriter) const override {
211 // TODO: support 0-d corner case.
212 if (writeOp.getTransferRank() == 0)
213 return failure();
214
215 if (writeOp.getMask())
216 return failure();
217 auto targetShape = getTargetShape(options, writeOp);
218 if (!targetShape)
219 return failure();
220 auto sourceVectorType = writeOp.getVectorType();
221 SmallVector<int64_t> strides(targetShape->size(), 1);
222 Location loc = writeOp.getLoc();
223 ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
224 // Bail-out if rank(source) != rank(target). The main limitation here is the
225 // fact that `ExtractStridedSlice` requires the rank for the input and
226 // output to match. If needed, we can relax this later.
227 if (originalSize.size() != targetShape->size())
228 return rewriter.notifyMatchFailure(
229 writeOp,
230 "expected source input vector rank to match target shape rank");
231
232 SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
233 writeOp.getIndices().end());
234 SmallVector<int64_t> loopOrder =
235 getUnrollOrder(originalSize.size(), writeOp, options);
236 Value resultTensor;
237 for (SmallVector<int64_t> elementOffsets :
238 StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
239 Value slicedVector = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
240 loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
241 SmallVector<Value> indices =
242 sliceTransferIndices(elementOffsets, originalIndices,
243 writeOp.getPermutationMap(), loc, rewriter);
244 Operation *slicedWrite = vector::TransferWriteOp::create(
245 rewriter, loc, slicedVector,
246 resultTensor ? resultTensor : writeOp.getBase(), indices,
247 writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
248 // For the tensor case update the destination for the next transfer write.
249 if (!slicedWrite->getResults().empty())
250 resultTensor = slicedWrite->getResult(0);
251 }
252 if (resultTensor)
253 rewriter.replaceOp(writeOp, resultTensor);
254 else
255 rewriter.eraseOp(writeOp);
256 return success();
257 }
258
259private:
260 vector::UnrollVectorOptions options;
261};
262
/// DenseMapInfo-style traits allowing SmallVector<int64_t> offset vectors to
/// be used as keys in the accumulator caches below. The empty/tombstone keys
/// are the sentinel vectors {-1}/{-2}, which are assumed never to occur as
/// real tile offsets.
struct OffsetMapInfo {
  static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }

  static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }

  // Hash all elements of the offset vector together.
  static unsigned getHashValue(const SmallVector<int64_t> &v) {
    return static_cast<unsigned>(llvm::hash_combine_range(v));
  }

  // Element-wise equality.
  static bool isEqual(const SmallVector<int64_t> &lhs,
                      const SmallVector<int64_t> &rhs) {
    return lhs == rhs;
  }
};
277
/// Unrolls vector.contract by tiling its iteration space with the target
/// shape. Accumulator tiles are cached in `accCache`, keyed by their offsets
/// in the destination, so that successive reduction steps over the same
/// destination tile chain through one another; the final tiles are assembled
/// back into a single vector at the end.
struct UnrollContractionPattern
    : public OpRewritePattern<vector::ContractionOp> {
  UnrollContractionPattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options,
                           PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, contractOp);
    if (!targetShape)
      return failure();
    auto dstVecType = cast<VectorType>(contractOp.getResultType());
    SmallVector<int64_t> originalSize = *contractOp.getShapeForUnroll();

    Location loc = contractOp.getLoc();
    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
    AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];
    // Maps destination-tile offsets to the most recent partial accumulator
    // value computed for that tile.
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;

    SmallVector<int64_t> loopOrder = getUnrollOrder(
        contractOp.getIteratorTypes().size(), contractOp, options);

    // Iterate over every tile of the (parallel + reduction) iteration space.
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      SmallVector<Value> slicesOperands(contractOp.getNumOperands());

      // Helper to compute the new shape of each operand and extract the slice.
      auto extractOperand = [&](unsigned index, Value operand,
                                AffineMap permutationMap,
                                ArrayRef<int64_t> operandOffets) {
        SmallVector<int64_t> operandShape = applyPermutationMap(
            permutationMap, ArrayRef<int64_t>(*targetShape));
        SmallVector<int64_t> operandStrides(operandOffets.size(), 1);
        slicesOperands[index] =
            rewriter.createOrFold<vector::ExtractStridedSliceOp>(
                loc, operand, operandOffets, operandShape, operandStrides);
      };

      // Extract the new lhs operand.
      AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];
      SmallVector<int64_t> lhsOffets =
          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets);

      // Extract the new rhs operand.
      AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
      SmallVector<int64_t> rhsOffets =
          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets);

      AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
      SmallVector<int64_t> accOffets =
          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
      // If a version of the accumulator has already been computed, use it
      // otherwise extract the first version from the original operand.
      auto *accIt = accCache.find(accOffets);
      if (accIt != accCache.end())
        slicesOperands[2] = accIt->second;
      else
        extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffets);

      // Clone the contraction at the tile's destination shape.
      SmallVector<int64_t> dstShape =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
      auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, contractOp, slicesOperands, targetType);

      SmallVector<int64_t> dstOffets =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
      // Save the accumulated value until all the loops are unrolled since the
      // reduction loop keeps updating the accumulator.
      accCache[dstOffets] = newOp->getResult(0);
    }
    // Assemble back the accumulator into a single vector.
    Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
                                             rewriter.getZeroAttr(dstVecType));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(contractOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
371
/// Unrolls vector.multi_reduction by reducing one source tile at a time.
/// For scalar results (all dims reduced) the partial results chain through
/// the accumulator operand; for vector results, partial accumulator tiles
/// are cached in `accCache` keyed by their destination offsets.
struct UnrollMultiReductionPattern
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  UnrollMultiReductionPattern(MLIRContext *context,
                              const vector::UnrollVectorOptions &options,
                              PatternBenefit benefit = 1)
      : OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
    Location loc = reductionOp.getLoc();
    auto resultType = reductionOp->getResult(0).getType();

    // Handle scalar result case: all dimensions are reduced.
    // Each source tile is reduced to a scalar, and partial results are
    // chained through the accumulator operand.
    if (resultType.isIntOrFloat()) {
      Value accumulator = reductionOp.getAcc();
      for (SmallVector<int64_t> offsets :
           StaticTileOffsetRange(originalSize, *targetShape)) {
        SmallVector<int64_t> operandStrides(offsets.size(), 1);
        Value slicedOperand =
            rewriter.createOrFold<vector::ExtractStridedSliceOp>(
                loc, reductionOp.getSource(), offsets, *targetShape,
                operandStrides);
        // Reduce the tile, feeding the running partial result in as acc.
        Operation *newOp = cloneOpWithOperandsAndTypes(
            rewriter, loc, reductionOp, {slicedOperand, accumulator},
            resultType);
        accumulator = newOp->getResult(0);
      }
      rewriter.replaceOp(reductionOp, accumulator);
      return success();
    }

    // Vector result case.
    // Maps destination-tile offsets to the latest partial accumulator value
    // computed for that tile.
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;

    // Stride of the ratios, this gives us the offsets of sliceCount in a basis
    // of multiples of the targetShape.
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<Value> operands;
      SmallVector<int64_t> operandStrides(offsets.size(), 1);
      Value slicedOperand =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, reductionOp.getSource(), offsets, *targetShape,
              operandStrides);
      operands.push_back(slicedOperand);
      // Project offsets/shape onto the non-reduced dimensions to obtain the
      // destination (accumulator) tile.
      SmallVector<int64_t> dstShape;
      SmallVector<int64_t> destOffset;
      for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
        if (!reductionOp.isReducedDim(i)) {
          destOffset.push_back(offsets[i]);
          dstShape.push_back((*targetShape)[i]);
        }
      }
      Value acc;
      SmallVector<int64_t> accStrides(destOffset.size(), 1);
      // If a version of the accumulator has already been computed, use it
      // otherwise extract the first version from the original operand.
      auto *accIt = accCache.find(destOffset);
      if (accIt != accCache.end())
        acc = accIt->second;
      else
        acc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
            loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
      operands.push_back(acc);
      auto targetType = VectorType::get(
          dstShape, reductionOp.getSourceVectorType().getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
                                                     operands, targetType);
      Value result = newOp->getResult(0);
      // Remember the latest partial result for this destination tile.
      accCache[destOffset] = result;
    }
    // Assemble back the accumulator into a single vector.
    Value result = arith::ConstantOp::create(
        rewriter, loc, reductionOp.getDestType(),
        rewriter.getZeroAttr(reductionOp.getDestType()));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(reductionOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
470
471struct UnrollElementwisePattern : public RewritePattern {
472 UnrollElementwisePattern(MLIRContext *context,
473 const vector::UnrollVectorOptions &options,
474 PatternBenefit benefit = 1)
475 : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
476 options(options) {}
477
478 LogicalResult matchAndRewrite(Operation *op,
479 PatternRewriter &rewriter) const override {
481 return failure();
482 auto targetShape = getTargetShape(options, op);
483 if (!targetShape)
484 return failure();
485 int64_t targetShapeRank = targetShape->size();
486 auto dstVecType = cast<VectorType>(op->getResult(0).getType());
487 SmallVector<int64_t> originalSize =
488 *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
489 int64_t originalShapeRank = originalSize.size();
490
491 Location loc = op->getLoc();
492
493 // Handle rank mismatch by adding leading unit dimensions to targetShape
494 SmallVector<int64_t> adjustedTargetShape(originalShapeRank);
495 int64_t rankDiff = originalShapeRank - targetShapeRank;
496 std::fill(adjustedTargetShape.begin(),
497 adjustedTargetShape.begin() + rankDiff, 1);
498 std::copy(targetShape->begin(), targetShape->end(),
499 adjustedTargetShape.begin() + rankDiff);
500
501 int64_t adjustedTargetShapeRank = adjustedTargetShape.size();
502 // Prepare the result vector.
503 Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
504 rewriter.getZeroAttr(dstVecType));
505 SmallVector<int64_t> strides(adjustedTargetShapeRank, 1);
506 VectorType unrolledVecType =
507 VectorType::get(*targetShape, dstVecType.getElementType());
508
509 // Create the unrolled computation.
510 for (SmallVector<int64_t> offsets :
511 StaticTileOffsetRange(originalSize, adjustedTargetShape)) {
512 SmallVector<Value> extractOperands;
513 for (OpOperand &operand : op->getOpOperands()) {
514 auto vecType = dyn_cast<VectorType>(operand.get().getType());
515 if (!vecType) {
516 extractOperands.push_back(operand.get());
517 continue;
518 }
519 Value extracted = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
520 loc, operand.get(), offsets, adjustedTargetShape, strides);
521
522 // Reshape to remove leading unit dims if needed
523 if (adjustedTargetShapeRank > targetShapeRank) {
524 extracted = rewriter.createOrFold<vector::ShapeCastOp>(
525 loc, VectorType::get(*targetShape, vecType.getElementType()),
526 extracted);
527 }
528 extractOperands.push_back(extracted);
529 }
530
531 Operation *newOp = cloneOpWithOperandsAndTypes(
532 rewriter, loc, op, extractOperands, unrolledVecType);
533
534 Value computeResult = newOp->getResult(0);
535
536 // Use strides sized to targetShape for proper insertion
537 SmallVector<int64_t> insertStrides =
538 (adjustedTargetShapeRank > targetShapeRank)
539 ? SmallVector<int64_t>(targetShapeRank, 1)
540 : strides;
541
542 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
543 loc, computeResult, result, offsets, insertStrides);
544 }
545 rewriter.replaceOp(op, result);
546 return success();
547 }
548
549private:
550 vector::UnrollVectorOptions options;
551};
552
553struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
554 UnrollReductionPattern(MLIRContext *context,
555 const vector::UnrollVectorOptions &options,
556 PatternBenefit benefit = 1)
557 : OpRewritePattern<vector::ReductionOp>(context, benefit),
558 options(options) {}
559
560 LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
561 PatternRewriter &rewriter) const override {
562 std::optional<SmallVector<int64_t>> targetShape =
563 getTargetShape(options, reductionOp);
564 if (!targetShape)
565 return failure();
566 SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
567
568 // Create unrolled vector reduction.
569 Location loc = reductionOp.getLoc();
570 Value accumulator = nullptr;
571 for (SmallVector<int64_t> offsets :
572 StaticTileOffsetRange(originalSize, *targetShape)) {
573 SmallVector<int64_t> strides(offsets.size(), 1);
574 Value slicedOperand =
575 rewriter.createOrFold<vector::ExtractStridedSliceOp>(
576 loc, reductionOp.getVector(), offsets, *targetShape, strides);
577 Operation *newOp = cloneOpWithOperandsAndTypes(
578 rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
579 Value result = newOp->getResult(0);
580
581 if (!accumulator) {
582 // This is the first reduction.
583 accumulator = result;
584 } else {
585 // On subsequent reduction, combine with the accumulator.
586 accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(),
587 accumulator, result);
588 }
589 }
590
591 rewriter.replaceOp(reductionOp, accumulator);
592 return success();
593 }
594
595private:
596 const vector::UnrollVectorOptions options;
597};
598
599struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
600 UnrollTransposePattern(MLIRContext *context,
601 const vector::UnrollVectorOptions &options,
602 PatternBenefit benefit = 1)
603 : OpRewritePattern<vector::TransposeOp>(context, benefit),
604 options(options) {}
605
606 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
607 PatternRewriter &rewriter) const override {
608 if (transposeOp.getResultVectorType().getRank() == 0)
609 return failure();
610 auto targetShape = getTargetShape(options, transposeOp);
611 if (!targetShape)
612 return failure();
613 auto originalVectorType = transposeOp.getResultVectorType();
614 SmallVector<int64_t> strides(targetShape->size(), 1);
615 Location loc = transposeOp.getLoc();
616 ArrayRef<int64_t> originalSize = originalVectorType.getShape();
617
618 // Prepare the result vector;
619 Value result =
620 arith::ConstantOp::create(rewriter, loc, originalVectorType,
621 rewriter.getZeroAttr(originalVectorType));
622 ArrayRef<int64_t> permutation = transposeOp.getPermutation();
623
624 // Unroll the computation.
625 for (SmallVector<int64_t> elementOffsets :
626 StaticTileOffsetRange(originalSize, *targetShape)) {
627 SmallVector<int64_t> permutedOffsets(elementOffsets.size());
628 SmallVector<int64_t> permutedShape(elementOffsets.size());
629 // Compute the source offsets and shape.
630 for (auto indices : llvm::enumerate(permutation)) {
631 permutedOffsets[indices.value()] = elementOffsets[indices.index()];
632 permutedShape[indices.value()] = (*targetShape)[indices.index()];
633 }
634 Value slicedOperand =
635 rewriter.createOrFold<vector::ExtractStridedSliceOp>(
636 loc, transposeOp.getVector(), permutedOffsets, permutedShape,
637 strides);
638 Value transposedSlice = rewriter.createOrFold<vector::TransposeOp>(
639 loc, slicedOperand, permutation);
640 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
641 loc, transposedSlice, result, elementOffsets, strides);
642 }
643 rewriter.replaceOp(transposeOp, result);
644 return success();
645 }
646
647private:
648 vector::UnrollVectorOptions options;
649};
650
651struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
652 UnrollGatherPattern(MLIRContext *context,
653 const vector::UnrollVectorOptions &options,
654 PatternBenefit benefit = 1)
655 : OpRewritePattern<vector::GatherOp>(context, benefit), options(options) {
656 }
657
658 LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
659 PatternRewriter &rewriter) const override {
660 VectorType sourceVectorType = gatherOp.getVectorType();
661 if (sourceVectorType.getRank() == 0)
662 return failure();
663 auto targetShape = getTargetShape(options, gatherOp);
664 if (!targetShape)
665 return failure();
666 SmallVector<int64_t> strides(targetShape->size(), 1);
667 Location loc = gatherOp.getLoc();
668 ArrayRef<int64_t> originalSize = gatherOp.getVectorType().getShape();
669
670 // Prepare the result vector;
671 Value result =
672 arith::ConstantOp::create(rewriter, loc, sourceVectorType,
673 rewriter.getZeroAttr(sourceVectorType));
674 auto targetType =
675 VectorType::get(*targetShape, sourceVectorType.getElementType());
676
677 SmallVector<int64_t> loopOrder =
678 getUnrollOrder(originalSize.size(), gatherOp, options);
679 for (SmallVector<int64_t> elementOffsets :
680 StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
681 // To get the unrolled gather, extract the same slice based on the
682 // decomposed shape from each of the index, mask, and pass-through
683 // vectors.
684 Value indexSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
685 loc, gatherOp.getIndices(), elementOffsets, *targetShape, strides);
686 Value maskSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
687 loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
688 Value passThruSubVec =
689 rewriter.createOrFold<vector::ExtractStridedSliceOp>(
690 loc, gatherOp.getPassThru(), elementOffsets, *targetShape,
691 strides);
692 auto slicedGather = vector::GatherOp::create(
693 rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getOffsets(),
694 indexSubVec, maskSubVec, passThruSubVec);
695
696 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
697 loc, slicedGather, result, elementOffsets, strides);
698 }
699 rewriter.replaceOp(gatherOp, result);
700 return success();
701 }
702
703private:
704 vector::UnrollVectorOptions options;
705};
706
707struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> {
708 UnrollLoadPattern(MLIRContext *context,
709 const vector::UnrollVectorOptions &options,
710 PatternBenefit benefit = 1)
711 : OpRewritePattern<vector::LoadOp>(context, benefit), options(options) {}
712
713 LogicalResult matchAndRewrite(vector::LoadOp loadOp,
714 PatternRewriter &rewriter) const override {
715 VectorType vecType = loadOp.getVectorType();
716
717 auto targetShape = getTargetShape(options, loadOp);
718 if (!targetShape)
719 return failure();
720
721 Location loc = loadOp.getLoc();
722 ArrayRef<int64_t> originalShape = vecType.getShape();
723 SmallVector<int64_t> strides(targetShape->size(), 1);
724
725 Value result = arith::ConstantOp::create(rewriter, loc, vecType,
726 rewriter.getZeroAttr(vecType));
727
728 SmallVector<int64_t> loopOrder =
729 getUnrollOrder(originalShape.size(), loadOp, options);
730
731 auto targetVecType =
732 VectorType::get(*targetShape, vecType.getElementType());
733
734 for (SmallVector<int64_t> offsets :
735 StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
736 SmallVector<Value> indices =
737 sliceLoadStoreIndices(rewriter, loc, loadOp.getIndices(), offsets);
738 Value slicedLoad = vector::LoadOp::create(rewriter, loc, targetVecType,
739 loadOp.getBase(), indices);
740 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
741 loc, slicedLoad, result, offsets, strides);
742 }
743 rewriter.replaceOp(loadOp, result);
744 return success();
745 }
746
747private:
748 vector::UnrollVectorOptions options;
749};
750
751struct UnrollStorePattern : public OpRewritePattern<vector::StoreOp> {
752 UnrollStorePattern(MLIRContext *context,
753 const vector::UnrollVectorOptions &options,
754 PatternBenefit benefit = 1)
755 : OpRewritePattern<vector::StoreOp>(context, benefit), options(options) {}
756
757 LogicalResult matchAndRewrite(vector::StoreOp storeOp,
758 PatternRewriter &rewriter) const override {
759 VectorType vecType = storeOp.getVectorType();
760
761 auto targetShape = getTargetShape(options, storeOp);
762 if (!targetShape)
763 return failure();
764
765 Location loc = storeOp.getLoc();
766 ArrayRef<int64_t> originalShape = vecType.getShape();
767 SmallVector<int64_t> strides(targetShape->size(), 1);
768
769 Value base = storeOp.getBase();
770 Value vector = storeOp.getValueToStore();
771
772 SmallVector<int64_t> loopOrder =
773 getUnrollOrder(originalShape.size(), storeOp, options);
774
775 for (SmallVector<int64_t> offsets :
776 StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
777 SmallVector<Value> indices =
778 sliceLoadStoreIndices(rewriter, loc, storeOp.getIndices(), offsets);
779 Value slice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
780 loc, vector, offsets, *targetShape, strides);
781 vector::StoreOp::create(rewriter, loc, slice, base, indices);
782 }
783 rewriter.eraseOp(storeOp);
784 return success();
785 }
786
787private:
788 vector::UnrollVectorOptions options;
789};
790
/// Unrolls vector.broadcast into broadcasts of target-shaped tiles inserted
/// into a zero-initialized result vector.
struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
  UnrollBroadcastPattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::BroadcastOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, broadcastOp);
    if (!targetShape)
      return failure();

    Location loc = broadcastOp.getLoc();
    // Null when the broadcast source is a scalar.
    VectorType srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
    VectorType resType = broadcastOp.getResultVectorType();
    VectorType targetType =
        resType.cloneWith(*targetShape, resType.getElementType());
    Value result = arith::ConstantOp::create(rewriter, loc, resType,
                                             rewriter.getZeroAttr(resType));

    SmallVector<int64_t> originalShape = *broadcastOp.getShapeForUnroll();
    SmallVector<int64_t> strides(originalShape.size(), 1);

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalShape, *targetShape)) {
      Value newSrc;
      if (!srcType) {
        // Scalar to vector broadcast.
        newSrc = broadcastOp.getSource();
      } else {
        // Vector to vector broadcast.
        // The source aligns with the trailing `rank` dims of the result, so
        // slice it with the trailing portion of the tile offsets/shape.
        int64_t rank = srcType.getRank();
        SmallVector<int64_t> srcOffsets(offsets.end() - rank, offsets.end());
        SmallVector<int64_t> srcShape(targetShape->end() - rank,
                                      targetShape->end());
        SmallVector<int64_t> srcStrides(strides.end() - rank, strides.end());
        // adjust the offset and shape for src if the corresponding dim is 1.
        for (int64_t i = 0; i < rank; ++i) {
          if (srcType.getDimSize(i) == 1) {
            srcOffsets[i] = 0;
            srcShape[i] = 1;
          }
        }
        newSrc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
            loc, broadcastOp.getSource(), srcOffsets, srcShape, srcStrides);
      }

      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, broadcastOp,
                                                     newSrc, targetType);

      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, newOp->getResult(0), result, offsets, strides);
    }

    rewriter.replaceOp(broadcastOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
853
/// Unrolls 2 or more dimensional `vector.to_elements` ops by unrolling the
/// outermost dimension of the operand. For example:
///
/// ```
/// %0:4 = vector.to_elements %v : vector<2x2xf32>
///
/// ==>
///
/// %v0 = vector.extract %v[0] : vector<2xf32> from vector<2x2xf32>
/// %v1 = vector.extract %v[1] : vector<2xf32> from vector<2x2xf32>
/// %0:2 = vector.to_elements %v0 : vector<2xf32>
/// %1:2 = vector.to_elements %v1 : vector<2xf32>
/// ```
///
/// When this pattern is applied until a fixed-point is reached,
/// this will produce a sequence of 1-d `vector.to_elements`
/// ops.
871struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
872 UnrollToElements(MLIRContext *context,
873 const vector::UnrollVectorOptions &options,
874 PatternBenefit benefit = 1)
875 : OpRewritePattern<vector::ToElementsOp>(context, benefit),
876 options(options) {}
877
878 LogicalResult matchAndRewrite(vector::ToElementsOp op,
879 PatternRewriter &rewriter) const override {
880
881 TypedValue<VectorType> source = op.getSource();
882 FailureOr<SmallVector<Value>> result =
883 vector::unrollVectorValue(source, rewriter);
884 if (failed(result)) {
885 return failure();
886 }
887 SmallVector<Value> vectors = *result;
888
889 SmallVector<Value> results;
890 for (Value vector : vectors) {
891 auto subElements =
892 vector::ToElementsOp::create(rewriter, op.getLoc(), vector);
893 llvm::append_range(results, subElements.getResults());
894 }
895 rewriter.replaceOp(op, results);
896 return success();
897 }
898
899private:
900 vector::UnrollVectorOptions options;
901};
902
903/// This pattern unrolls `vector.step` operations according to the provided
904/// target unroll shape. It decomposes a large step vector into smaller step
905/// vectors (segments) and assembles the result by inserting each computed
906/// segment into the appropriate offset of the original vector.
907///
908/// The pattern does not support scalable vectors and will fail to match them.
909///
910/// For each segment, it adds the base step vector and the segment's offset,
911/// then inserts the result into the output vector at the corresponding
912/// position.
913///
914/// Example:
915/// Given a step operation:
916/// %0 = vector.step : vector<8xindex>
917///
918/// and a target unroll shape of <4>, the pattern produces:
919///
920/// %base = vector.step : vector<4xindex>
921/// %zero = arith.constant dense<0> : vector<8xindex>
922/// %result0 = vector.insert_strided_slice %base, %zero
923/// {offsets = [0], strides = [1]} : vector<4xindex> into vector<8xindex>
924/// %offset = arith.constant dense<4> : vector<4xindex>
925/// %segment1 = arith.addi %base, %offset : vector<4xindex>
926/// %result1 = vector.insert_strided_slice %segment1, %result0
927/// {offsets = [4], strides = [1]} : vector<4xindex> into vector<8xindex>
928///
929struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> {
930 UnrollStepPattern(MLIRContext *context,
931 const vector::UnrollVectorOptions &options,
932 PatternBenefit benefit = 1)
933 : OpRewritePattern<vector::StepOp>(context, benefit), options(options) {}
934
935 LogicalResult matchAndRewrite(vector::StepOp stepOp,
936 PatternRewriter &rewriter) const override {
937 std::optional<SmallVector<int64_t>> targetShape =
938 getTargetShape(options, stepOp);
939 if (!targetShape)
940 return failure();
941
942 VectorType vecType = stepOp.getType();
943 if (vecType.isScalable()) {
944 // Scalable vectors are not supported by this pattern.
945 return failure();
946 }
947 int64_t originalSize = vecType.getShape()[0];
948 Location loc = stepOp.getLoc();
949 SmallVector<int64_t> strides(1, 1);
950
951 Value result = arith::ConstantOp::create(rewriter, loc, vecType,
952 rewriter.getZeroAttr(vecType));
953
954 auto targetVecType =
955 VectorType::get(*targetShape, vecType.getElementType());
956 Value baseStep = vector::StepOp::create(rewriter, loc, targetVecType);
957 for (const SmallVector<int64_t> &offsets :
958 StaticTileOffsetRange({originalSize}, *targetShape)) {
959 Value bcastOffset = arith::ConstantOp::create(
960 rewriter, loc, targetVecType,
962 targetVecType,
963 IntegerAttr::get(targetVecType.getElementType(), offsets[0])));
964 Value tileStep =
965 arith::AddIOp::create(rewriter, loc, baseStep, bcastOffset);
966
967 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
968 loc, tileStep, result, offsets, strides);
969 }
970 rewriter.replaceOp(stepOp, result);
971 return success();
972 }
973
974private:
975 vector::UnrollVectorOptions options;
976};
977
978/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
979/// outermost dimension. For example:
980/// ```
981/// %v = vector.from_elements %e0, %e1, %e2, %e3, %e4, %e5 : vector<2x3xf32>
982///
983/// ==>
984///
985/// %0 = ub.poison : vector<2x3xf32>
986/// %v0 = vector.from_elements %e0, %e1, %e2 : vector<3xf32>
987/// %1 = vector.insert %v0, %0 [0] : vector<3xf32> into vector<2x3xf32>
988/// %v1 = vector.from_elements %e3, %e4, %e5 : vector<3xf32>
989/// %v = vector.insert %v1, %1 [1] : vector<3xf32> into vector<2x3xf32>
990/// ```
991///
992/// When this pattern is applied until a fixed-point is reached,
993/// this will produce a sequence of 1-d from_elements
994/// ops.
995struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
996 UnrollFromElements(MLIRContext *context,
997 const vector::UnrollVectorOptions &options,
998 PatternBenefit benefit = 1)
999 : OpRewritePattern<vector::FromElementsOp>(context, benefit),
1000 options(options) {}
1001
1002 LogicalResult matchAndRewrite(vector::FromElementsOp op,
1003 PatternRewriter &rewriter) const override {
1004 ValueRange allElements = op.getElements();
1005
1006 auto unrollFromElementsFn = [&](PatternRewriter &rewriter, Location loc,
1007 VectorType subTy, int64_t index) {
1008 size_t subTyNumElements = subTy.getNumElements();
1009 assert((index + 1) * subTyNumElements <= allElements.size() &&
1010 "out of bounds");
1011 ValueRange subElements =
1012 allElements.slice(index * subTyNumElements, subTyNumElements);
1013 return vector::FromElementsOp::create(rewriter, loc, subTy, subElements);
1014 };
1015
1016 return unrollVectorOp(op, rewriter, unrollFromElementsFn);
1017 }
1018
1019private:
1020 vector::UnrollVectorOptions options;
1021};
1022
1023/// This pattern unrolls `vector.create_mask` operations into smaller mask
1024/// operations based on the target unroll shape. Each unrolled slice computes
1025/// its local mask size in each dimension (d) as:
1026/// min(max(originalMaskSize[d] - offset[d], 0), unrolledDimSize[d]).
1027/// Example:
1028/// Given a create_mask operation:
1029/// %0 = vector.create_mask %c6, %c10 : vector<8x16xi1> // mask first 6x10
1030/// elements
1031///
1032/// and a target unroll shape of <4x8>, the pattern produces:
1033///
1034/// %false = arith.constant dense<false> : vector<8x16xi1>
1035///
1036/// Slice [0,0]:
1037/// mask size = min(max(6-0, 0), 4) x min(max(10-0, 0), 8) = 4x8
1038/// %mask00 = vector.create_mask %c4, %c8 : vector<4x8xi1>
1039/// %r0 = vector.insert_strided_slice %mask00, %false [0, 0], [1, 1]
1040/// : vector<4x8xi1> into vector<8x16xi1>
1041/// Slice [0,8]:
1042/// mask size = min(max(6-0, 0), 4) x min(max(10-8, 0), 8) = 4x2
1043/// %mask01 = vector.create_mask %c4, %c2 : vector<4x8xi1>
1044/// %r1 = vector.insert_strided_slice %mask01, %r0 [0, 8], [1, 1]
1045/// : vector<4x8xi1> into vector<8x16xi1>
1046/// Slice [4,0]:
1047/// mask size = min(max(6-4, 0), 4) x min(max(10-0, 0), 8) = 2x8
1048/// %mask10 = vector.create_mask %c2, %c8 : vector<4x8xi1>
1049/// %r2 = vector.insert_strided_slice %mask10, %r1 [4, 0], [1, 1]
1050/// : vector<4x8xi1> into vector<8x16xi1>
1051/// Slice [4,8]:
1052/// mask size = min(max(6-4, 0), 4) x min(max(10-8, 0), 8) = 2x2
1053/// %mask11 = vector.create_mask %c2, %c2 : vector<4x8xi1>
1054/// %result = vector.insert_strided_slice %mask11, %r2 [4, 8], [1, 1]
1055/// : vector<4x8xi1> into vector<8x16xi1>
struct UnrollCreateMaskPattern : public OpRewritePattern<vector::CreateMaskOp> {
  UnrollCreateMaskPattern(MLIRContext *context,
                          const vector::UnrollVectorOptions &options,
                          PatternBenefit benefit = 1)
      : OpRewritePattern<vector::CreateMaskOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::CreateMaskOp createMaskOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, createMaskOp);
    if (!targetShape)
      return failure();

    VectorType resultType = createMaskOp.getVectorType();
    SmallVector<int64_t> originalSize = *createMaskOp.getShapeForUnroll();
    Location loc = createMaskOp.getLoc();

    // All-false accumulator; each unrolled sub-mask is inserted into it.
    Value result = arith::ConstantOp::create(rewriter, loc, resultType,
                                             rewriter.getZeroAttr(resultType));
    VectorType targetVectorType =
        VectorType::get(*targetShape, rewriter.getI1Type());
    SmallVector<int64_t> strides(targetShape->size(), 1);

    // In each dimension (d), each unrolled vector computes its mask size as:
    // min(max(originalMaskOperands[d] - offset[d], 0), unrolledDimSize[d]).
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<Value> unrolledOperands;

      for (auto [i, originalMaskOperand] :
           llvm::enumerate(createMaskOp.getOperands())) {
        // The mask bounds are runtime values, so the per-slice size is
        // computed with arith ops; createOrFold collapses them to constants
        // when the bound itself is constant.
        Value offsetVal =
            arith::ConstantIndexOp::create(rewriter, loc, offsets[i]);
        Value adjustedMaskSize = rewriter.createOrFold<arith::SubIOp>(
            loc, originalMaskOperand, offsetVal);
        Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
        Value unrolledDimSize =
            arith::ConstantIndexOp::create(rewriter, loc, (*targetShape)[i]);
        // Clamp (bound - offset) into [0, unrolledDimSize].
        Value nonNegative =
            rewriter.createOrFold<arith::MaxSIOp>(loc, adjustedMaskSize, zero);
        Value unrolledOperand = rewriter.createOrFold<arith::MinSIOp>(
            loc, nonNegative, unrolledDimSize);
        unrolledOperands.push_back(unrolledOperand);
      }

      auto unrolledMask = rewriter.createOrFold<vector::CreateMaskOp>(
          loc, targetVectorType, unrolledOperands);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, unrolledMask, result, offsets, strides);
    }
    rewriter.replaceOp(createMaskOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
1113
1114/// This pattern unrolls `vector.constant_mask` operations into smaller mask
1115/// operations based on the target unroll shape. Each unrolled slice computes
1116/// whether its elements should be masked based on the original mask dimensions
1117/// and the slice's offset position.
1118///
1119/// Example:
1120/// Given a constant_mask operation:
1121/// %0 = vector.constant_mask [6, 10] : vector<8x16xi1>
1122///
1123/// and a target unroll shape of <4x8>, the pattern produces:
1124///
1125/// %false = arith.constant dense<false> : vector<8x16xi1>
1126///
1127/// Slice [0,0]: elements [0:4, 0:8] - fully within [6, 10] bounds
1128/// %mask00 = vector.constant_mask [4, 8] : vector<4x8xi1>
1129/// %r0 = vector.insert_strided_slice %mask00, %false [0, 0], [1, 1]
1130/// : vector<4x8xi1> into vector<8x16xi1>
1131///
1132/// Slice [0,8]: elements [0:4, 8:16] - partially within bounds
1133/// %mask01 = vector.constant_mask [4, 2] : vector<4x8xi1>
1134/// %r1 = vector.insert_strided_slice %mask01, %r0 [0, 8], [1, 1]
1135/// : vector<4x8xi1> into vector<8x16xi1>
1136///
1137/// Slice [4,0]: elements [4:8, 0:8] - partially within bounds
1138/// %mask10 = vector.constant_mask [2, 8] : vector<4x8xi1>
1139/// %r2 = vector.insert_strided_slice %mask10, %r1 [4, 0], [1, 1]
1140/// : vector<4x8xi1> into vector<8x16xi1>
1141///
1142/// Slice [4,8]: elements [4:8, 8:16] - partially within bounds
1143/// %mask11 = vector.constant_mask [2, 2] : vector<4x8xi1>
1144/// %result = vector.insert_strided_slice %mask11, %r2 [4, 8], [1, 1]
1145/// : vector<4x8xi1> into vector<8x16xi1>
1146struct UnrollConstantMaskPattern
1147 : public OpRewritePattern<vector::ConstantMaskOp> {
1148 UnrollConstantMaskPattern(MLIRContext *context,
1149 const vector::UnrollVectorOptions &options,
1150 PatternBenefit benefit = 1)
1151 : OpRewritePattern<vector::ConstantMaskOp>(context, benefit),
1152 options(options) {}
1153
1154 LogicalResult matchAndRewrite(vector::ConstantMaskOp constantMaskOp,
1155 PatternRewriter &rewriter) const override {
1156 std::optional<SmallVector<int64_t>> targetShape =
1157 getTargetShape(options, constantMaskOp);
1158 if (!targetShape)
1159 return failure();
1160
1161 VectorType resultType = constantMaskOp.getVectorType();
1162 SmallVector<int64_t> originalSize = *constantMaskOp.getShapeForUnroll();
1163 Location loc = constantMaskOp.getLoc();
1164
1165 Value result = arith::ConstantOp::create(rewriter, loc, resultType,
1166 rewriter.getZeroAttr(resultType));
1167 VectorType targetVectorType =
1168 VectorType::get(*targetShape, rewriter.getI1Type());
1169 SmallVector<int64_t> strides(targetShape->size(), 1);
1170
1171 // In each dimension (d), each unrolled vector computes its mask size as:
1172 // min(max(originalMaskDim[d] - offset[d], 0), unrolledDimSize[d]).
1173 for (const SmallVector<int64_t> &offsets :
1174 StaticTileOffsetRange(originalSize, *targetShape)) {
1175 SmallVector<int64_t> unrolledMaskDims;
1176
1177 for (auto [i, originalMaskDim] :
1178 llvm::enumerate(constantMaskOp.getMaskDimSizes())) {
1179 // Calculate how many elements in this dimension should be masked
1180 // for this particular slice
1181 int64_t adjustedMaskSize =
1182 std::max(originalMaskDim - offsets[i], static_cast<int64_t>(0));
1183 int64_t unrolledMaskDim =
1184 std::min(adjustedMaskSize, static_cast<int64_t>((*targetShape)[i]));
1185 unrolledMaskDims.push_back(unrolledMaskDim);
1186 }
1187
1188 auto unrolledMask = rewriter.createOrFold<vector::ConstantMaskOp>(
1189 loc, targetVectorType, unrolledMaskDims);
1190 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
1191 loc, unrolledMask, result, offsets, strides);
1192 }
1193 rewriter.replaceOp(constantMaskOp, result);
1194 return success();
1195 }
1196
1197private:
1198 vector::UnrollVectorOptions options;
1199};
1200
1201/// Checks whether extractShape is a contiguous slice of shape.
1202/// For extractShape to be contiguous in shape:
1203/// 1) All but the leading dimension of extractShape and shape must match
1204/// exactly. 2) The total number of elements in shape must be evenly divisible
1205/// by
1206/// the total number of elements in extractShape.
1207/// Examples:
1208/// isContiguous([4, 4], [8, 4]) == true
1209/// isContiguous([2, 4], [8, 4]) == true
1210/// isContiguous([2, 2], [8, 4]) == false
1211/// Removes leading unit dimensions to handle cases like:
1212/// isContiguous([1, 16], [1, 32]) == true
1213static bool isContiguous(ArrayRef<int64_t> extractShape,
1215
1216 if (extractShape.empty() || shape.empty() ||
1217 extractShape.size() > shape.size())
1218 return false;
1219
1220 while (extractShape.size() > 1 && extractShape.front() == 1)
1221 extractShape = extractShape.drop_front();
1222
1223 while (shape.size() > 1 && shape.front() == 1) {
1224 shape = shape.drop_front();
1225 }
1226
1227 size_t rankDiff = shape.size() - extractShape.size();
1228 if (!llvm::equal(extractShape.drop_front(), shape.drop_front(rankDiff + 1)))
1229 return false;
1230
1231 int64_t extractElements = ShapedType::getNumElements(extractShape);
1232 int64_t shapeElements = ShapedType::getNumElements(shape);
1233 return shapeElements % extractElements == 0;
1234}
1235
1236/// Determines what shape to use with `vector.extract_strided_slice` to extract
1237/// a contiguous memory region from a source vector. The extraction must be
1238/// contiguous and contain exactly the specified number of elements. If such an
1239/// extraction shape cannot be determined, returns std::nullopt.
1240/// EXAMPLE 1:
1241/// sourceShape = [16], targetElements = 8
1242/// Working right-to-left:
1243/// - Take min(8, 16) = 8 from only dim → extractShape = [8],
1244/// remaining = 8/8 = 1
1245/// Result: [8]
1246///
1247/// EXAMPLE 2:
1248/// sourceShape = [4, 4], targetElements = 8
1249/// Working right-to-left:
1250/// - Take min(8, 4) = 4 from last dim → extractShape = [4],
1251/// remaining = 8/4 = 2
1252/// - Take min(2, 4) = 2 from first dim → extractShape = [2, 4],
1253/// remaining = 2/2 = 1
1254/// Result: [2, 4]
1255static std::optional<SmallVector<int64_t>>
1256calculateSourceExtractShape(ArrayRef<int64_t> sourceShape,
1257 int64_t targetElements) {
1258 SmallVector<int64_t> extractShape;
1259 int64_t remainingElements = targetElements;
1260
1261 // Build extract shape from innermost dimension outward to ensure contiguity.
1262 for (int i = sourceShape.size() - 1; i >= 0 && remainingElements > 1; --i) {
1263 int64_t takeFromDim = std::min(remainingElements, sourceShape[i]);
1264 extractShape.insert(extractShape.begin(), takeFromDim);
1265
1266 if (remainingElements % takeFromDim != 0)
1267 return std::nullopt; // Not evenly divisible.
1268 remainingElements /= takeFromDim;
1269 }
1270
1271 // Fill remaining dimensions with 1.
1272 while (extractShape.size() < sourceShape.size())
1273 extractShape.insert(extractShape.begin(), 1);
1274
1275 if (ShapedType::getNumElements(extractShape) != targetElements)
1276 return std::nullopt;
1277
1278 return extractShape;
1279}
1280
1281// Convert result offsets to source offsets via linear position.
1283calculateSourceOffsets(ArrayRef<int64_t> resultOffsets,
1284 ArrayRef<int64_t> sourceShape,
1285 ArrayRef<int64_t> resultShape) {
1286 // Convert result offsets to linear position.
1287 int64_t linearIndex = linearize(resultOffsets, computeStrides(resultShape));
1288 // Convert linear position to source offsets.
1289 return delinearize(linearIndex, computeStrides(sourceShape));
1290}
1291
1292/// This pattern unrolls `vector.shape_cast` operations according to the
1293/// provided target unroll shape. It unrolls a large shape cast into smaller
1294/// shape casts by extracting contiguous slices from the source vector, casting
1295/// each slice to the target shape, and assembling the result by inserting each
1296/// computed segment into the appropriate offset of the result vector.
1297///
1298/// This pattern only applies when contiguous slices can be extracted from the
1299/// source vector and inserted into the result vector such that each slice
1300/// remains a valid vector (and not decompose to scalars). In these cases, the
1301/// unrolling proceeds as:
1302/// vector.extract_strided_slice -> vector.shape_cast (on the slice) ->
1303/// vector.insert_strided_slice.
1304///
1305/// Example:
1306/// Given a shape cast operation:
1307/// %0 = vector.shape_cast %src : vector<8x2xf32> to vector<4x4xf32>
1308///
1309/// and a target unroll shape of <2x4>, the pattern produces:
1310///
1311/// %zero = arith.constant dense<0.0> : vector<4x4xf32>
1312/// %s0 = vector.extract_strided_slice %src [0, 0], [4, 2], [1, 1]
1313/// : vector<8x2xf32> to vector<4x2xf32>
1314/// %sc0 = vector.shape_cast %s0 : vector<4x2xf32> to vector<2x4xf32>
1315/// %i0 = vector.insert_strided_slice %sc0, %zero [0, 0], [1, 1]
1316/// : vector<2x4xf32> into vector<4x4xf32>
1317/// %s1 = vector.extract_strided_slice %src [4, 0], [4, 2], [1, 1]
1318/// : vector<8x2xf32> to vector<4x2xf32>
1319/// %sc1 = vector.shape_cast %s1 : vector<4x2xf32> to vector<2x4xf32>
1320/// %i1 = vector.insert_strided_slice %sc1, %i0 [2, 0], [1, 1]
1321/// : vector<2x4xf32> into vector<4x4xf32>
1322///
1323struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
1324 UnrollShapeCastPattern(MLIRContext *context,
1325 const vector::UnrollVectorOptions &options,
1326 PatternBenefit benefit = 1)
1327 : OpRewritePattern<vector::ShapeCastOp>(context, benefit),
1328 options(options) {}
1329
1330 LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
1331 PatternRewriter &rewriter) const override {
1332 std::optional<SmallVector<int64_t>> targetShape =
1333 getTargetShape(options, shapeCastOp);
1334 if (!targetShape)
1335 return failure();
1336
1337 VectorType sourceType = shapeCastOp.getSourceVectorType();
1338 VectorType resultType = shapeCastOp.getResultVectorType();
1339 ArrayRef<int64_t> sourceShape = sourceType.getShape();
1340 ArrayRef<int64_t> resultShape = resultType.getShape();
1341
1342 if (!isContiguous(*targetShape, resultShape))
1343 return rewriter.notifyMatchFailure(
1344 shapeCastOp, "Only supports cases where target shape is "
1345 "contiguous in result vector shape");
1346
1347 int64_t targetElements = ShapedType::getNumElements(*targetShape);
1348
1349 // Calculate the shape to extract from source.
1350 std::optional<SmallVector<int64_t>> extractShape =
1351 calculateSourceExtractShape(sourceShape, targetElements);
1352 if (!extractShape)
1353 return rewriter.notifyMatchFailure(
1354 shapeCastOp,
1355 "cannot extract target number of elements contiguously from source");
1356
1357 Location loc = shapeCastOp.getLoc();
1358
1359 // Create result vector initialized to zero.
1360 Value result = arith::ConstantOp::create(rewriter, loc, resultType,
1361 rewriter.getZeroAttr(resultType));
1362
1363 VectorType targetType =
1364 VectorType::get(*targetShape, sourceType.getElementType());
1365
1366 SmallVector<int64_t> extractStrides(extractShape->size(), 1);
1367 SmallVector<int64_t> insertStrides(targetShape->size(), 1);
1368
1369 for (SmallVector<int64_t> resultOffsets :
1370 StaticTileOffsetRange(resultShape, *targetShape)) {
1371 SmallVector<int64_t> sourceOffsets =
1372 calculateSourceOffsets(resultOffsets, sourceShape, resultShape);
1373 Value sourceChunk = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
1374 loc, shapeCastOp.getSource(), sourceOffsets, *extractShape,
1376 Value targetChunk = rewriter.createOrFold<vector::ShapeCastOp>(
1377 loc, targetType, sourceChunk);
1378 result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
1379 loc, targetChunk, result, resultOffsets, insertStrides);
1380 }
1381
1382 rewriter.replaceOp(shapeCastOp, result);
1383 return success();
1384 }
1385
1386private:
1387 vector::UnrollVectorOptions options;
1388};
1389
1390// Unroll vector::BitCastOp into smaller slice-based bitcast operations.
1391// Decomposes the result vector into target shape chunks and bitcasts
1392// corresponding source slices, accounting for element bitwidth ratios.
1393/// Example:
1394/// Given a bitcast Op:
1395///
1396/// vector.bitcast %src : vector<4x8xf32>
1397///
1398/// and a target unroll shape of <2x4>, the pattern produces:
1399///
/// %slice_0 = vector.extract_strided_slice %src[0, 0] : vector<2x4xf32>
/// %cast_0 = vector.bitcast %slice_0 : vector<2x4xf32>
/// %result = vector.insert_strided_slice %cast_0, %init[0, 0]
1403/// // ... repeat for remaining slices
struct UnrollBitCastPattern : public OpRewritePattern<vector::BitCastOp> {
  UnrollBitCastPattern(MLIRContext *context,
                       const vector::UnrollVectorOptions &options,
                       PatternBenefit benefit = 1)
      : OpRewritePattern<vector::BitCastOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, bitCastOp);
    if (!targetShape)
      return rewriter.notifyMatchFailure(bitCastOp,
                                         "failed to get target shape");

    VectorType sourceType = bitCastOp.getSourceVectorType();
    VectorType resultType = bitCastOp.getResultVectorType();
    ArrayRef<int64_t> resultShape = resultType.getShape();
    Location loc = bitCastOp.getLoc();

    if (targetShape->size() != resultShape.size())
      return rewriter.notifyMatchFailure(
          bitCastOp, "target shape rank must match result rank");

    unsigned sourceElementBits = sourceType.getElementTypeBitWidth();
    unsigned resultElementBits = resultType.getElementTypeBitWidth();

    // The source slice matches the target shape in all but the innermost
    // dimension, which is rescaled by the element-bitwidth ratio so the slice
    // spans exactly the same bits as the result tile.
    SmallVector<int64_t> sourceSliceShape(targetShape->begin(),
                                          targetShape->end());
    int64_t lastDim = sourceSliceShape.size() - 1;

    sourceSliceShape[lastDim] =
        ((*targetShape)[lastDim] * resultElementBits) / sourceElementBits;

    // Zero-filled accumulator for the unrolled result tiles.
    Value result = arith::ConstantOp::create(rewriter, loc, resultType,
                                             rewriter.getZeroAttr(resultType));
    SmallVector<int64_t> resultStrides(targetShape->size(), 1);
    SmallVector<int64_t> sourceStrides(sourceSliceShape.size(), 1);

    VectorType targetType =
        VectorType::get(*targetShape, resultType.getElementType());

    for (SmallVector<int64_t> resultOffsets :
         StaticTileOffsetRange(resultShape, *targetShape)) {
      // The innermost source offset is rescaled the same way as the shape.
      SmallVector<int64_t> sourceOffsets = resultOffsets;
      sourceOffsets[lastDim] =
          (resultOffsets[lastDim] * resultElementBits) / sourceElementBits;

      Value sourceSlice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, bitCastOp.getSource(), sourceOffsets, sourceSliceShape,
          sourceStrides);
      Value bitcastSlice = rewriter.createOrFold<vector::BitCastOp>(
          loc, targetType, sourceSlice);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, bitcastSlice, result, resultOffsets, resultStrides);
    }

    rewriter.replaceOp(bitCastOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
1467
1468/// Pattern to unroll vector.interleave into smaller slice-sized operations.
1469/// Decomposes a large interleave into slices by extracting slices from both
1470/// input vectors, interleaving them, and inserting back into the result.
1471///
1472/// Example:
1473/// Given an interleave Op:
1474///
1475/// vector.interleave %lhs, %rhs : vector<4x8xf32>
1476///
1477/// and a target unroll shape of <2x4>, the pattern produces:
1478///
1479/// %slice_lhs_0 = vector.extract_strided_slice %lhs[0, 0] : vector<2x2xf32>
1480/// %slice_rhs_0 = vector.extract_strided_slice %rhs[0, 0] : vector<2x2xf32>
1481/// %slice_0 = vector.interleave %slice_lhs_0, %slice_rhs_0
1482/// : vector<2x4xf32>
1483/// %result = vector.insert_strided_slice %slice_0, %init[0, 0]
1484/// // ... repeat for remaining slices
struct UnrollInterleavePattern : public OpRewritePattern<vector::InterleaveOp> {
  UnrollInterleavePattern(MLIRContext *context,
                          const vector::UnrollVectorOptions &options,
                          PatternBenefit benefit = 1)
      : OpRewritePattern<vector::InterleaveOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::InterleaveOp interleaveOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, interleaveOp);
    if (!targetShape)
      return rewriter.notifyMatchFailure(interleaveOp,
                                         "failed to get target shape");

    VectorType resultType = interleaveOp.getResultVectorType();
    ArrayRef<int64_t> resultShape = resultType.getShape();
    Location loc = interleaveOp.getLoc();

    if (targetShape->size() != resultShape.size())
      return rewriter.notifyMatchFailure(
          interleaveOp, "target shape rank must match result rank");

    // Each result tile interleaves one slice of each input, so each input
    // slice is half the tile size in the innermost dimension.
    SmallVector<int64_t> sourceSliceShape(targetShape->begin(),
                                          targetShape->end());
    int64_t lastDim = sourceSliceShape.size() - 1;
    sourceSliceShape[lastDim] = (*targetShape)[lastDim] / 2;

    // Zero-filled accumulator for the unrolled result tiles.
    Value result = arith::ConstantOp::create(rewriter, loc, resultType,
                                             rewriter.getZeroAttr(resultType));
    SmallVector<int64_t> resultStrides(targetShape->size(), 1);
    SmallVector<int64_t> sourceStrides(sourceSliceShape.size(), 1);

    VectorType targetType =
        VectorType::get(*targetShape, resultType.getElementType());

    for (SmallVector<int64_t> resultOffsets :
         StaticTileOffsetRange(resultShape, *targetShape)) {
      // The innermost source offset is half the result offset, mirroring the
      // halved slice width.
      SmallVector<int64_t> sourceOffsets = resultOffsets;
      sourceOffsets[lastDim] = resultOffsets[lastDim] / 2;

      Value lhsSlice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, interleaveOp.getLhs(), sourceOffsets, sourceSliceShape,
          sourceStrides);
      Value rhsSlice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, interleaveOp.getRhs(), sourceOffsets, sourceSliceShape,
          sourceStrides);
      Value interleaveSlice = rewriter.createOrFold<vector::InterleaveOp>(
          loc, targetType, lhsSlice, rhsSlice);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, interleaveSlice, result, resultOffsets, resultStrides);
    }

    rewriter.replaceOp(interleaveOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
1544
1545/// Pattern to unroll vector.deinterleave into smaller slice-sized operations.
1546/// Decomposes a large deinterleave (which splits a vector into even/odd halves)
1547/// by extracting source slices, deinterleaving them, and inserting into two
1548/// result vectors.
1549///
1550/// Example:
1551/// Given a deinterleave Op:
1552///
1553/// vector.deinterleave %src : vector<4x8xf32>
1554///
1555/// and a target unroll shape of <2x4>, the pattern produces:
1556///
1557/// %slice_0 = vector.extract_strided_slice %src[0, 0] : vector<2x4xf32>
1558/// %slice_lhs_0, %slice_rhs_0 = vector.deinterleave %slice_0 :
1559/// vector<2x4xf32> %result1 = vector.insert_strided_slice %slice_lhs_0,
1560/// %init1[0, 0] %result2 = vector.insert_strided_slice %slice_rhs_0,
1561/// %init2[0, 0]
1562/// // ... repeat for remaining slices
struct UnrollDeinterleavePattern
    : public OpRewritePattern<vector::DeinterleaveOp> {
  UnrollDeinterleavePattern(MLIRContext *context,
                            const vector::UnrollVectorOptions &options,
                            PatternBenefit benefit = 1)
      : OpRewritePattern<vector::DeinterleaveOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::DeinterleaveOp deinterleaveOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, deinterleaveOp);
    if (!targetShape)
      return rewriter.notifyMatchFailure(deinterleaveOp,
                                         "failed to get target shape");

    VectorType resultType = deinterleaveOp.getResultVectorType();
    ArrayRef<int64_t> resultShape = resultType.getShape();
    Location loc = deinterleaveOp.getLoc();

    if (targetShape->size() != resultShape.size())
      return rewriter.notifyMatchFailure(
          deinterleaveOp, "target shape rank must match result rank");

    // Each pair of result tiles is produced from one source slice that is
    // twice the tile size in the innermost dimension.
    SmallVector<int64_t> sourceSliceShape(targetShape->begin(),
                                          targetShape->end());
    int64_t lastDim = sourceSliceShape.size() - 1;
    sourceSliceShape[lastDim] = (*targetShape)[lastDim] * 2;

    // Zero-filled accumulators for the two deinterleaved results.
    // NOTE(review): `resultOdd` accumulates res1 and `resultEven` res2, which
    // looks swapped relative to the usual res1=even-indexed / res2=odd-indexed
    // reading of vector.deinterleave — confirm; what matters for correctness
    // is that the replacement order (res1 first, res2 second) is preserved.
    Value resultOdd = arith::ConstantOp::create(
        rewriter, loc, resultType, rewriter.getZeroAttr(resultType));
    Value resultEven = arith::ConstantOp::create(
        rewriter, loc, resultType, rewriter.getZeroAttr(resultType));
    SmallVector<int64_t> resultStrides(targetShape->size(), 1);
    SmallVector<int64_t> sourceStrides(sourceSliceShape.size(), 1);

    for (SmallVector<int64_t> resultOffsets :
         StaticTileOffsetRange(resultShape, *targetShape)) {
      // The innermost source offset is twice the result offset, mirroring the
      // doubled slice width.
      SmallVector<int64_t> sourceOffsets = resultOffsets;
      sourceOffsets[lastDim] = resultOffsets[lastDim] * 2;

      Value sourceSlice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, deinterleaveOp.getSource(), sourceOffsets, sourceSliceShape,
          sourceStrides);

      auto deinterleaveSlice =
          vector::DeinterleaveOp::create(rewriter, loc, sourceSlice);

      // Insert each half of the deinterleaved slice into its accumulator.
      resultOdd = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, deinterleaveSlice.getRes1(), resultOdd, resultOffsets,
          resultStrides);
      resultEven = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, deinterleaveSlice.getRes2(), resultEven, resultOffsets,
          resultStrides);
    }

    rewriter.replaceOp(deinterleaveOp, ValueRange{resultOdd, resultEven});
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
1625
1626} // namespace
1627
1628void mlir::vector::populateVectorUnrollPatterns(
1630 PatternBenefit benefit) {
1631 patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
1632 UnrollContractionPattern, UnrollElementwisePattern,
1633 UnrollReductionPattern, UnrollMultiReductionPattern,
1634 UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
1635 UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
1636 UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern,
1637 UnrollCreateMaskPattern, UnrollConstantMaskPattern,
1638 UnrollBitCastPattern, UnrollInterleavePattern,
1639 UnrollDeinterleavePattern>(patterns.getContext(), options,
1640 benefit);
1641}
1642
1643void mlir::vector::populateVectorToElementsUnrollPatterns(
1644 RewritePatternSet &patterns, PatternBenefit benefit) {
1645 patterns.add<UnrollToElements>(patterns.getContext(), UnrollVectorOptions(),
1646 benefit);
1647}
1648
1649void mlir::vector::populateVectorFromElementsUnrollPatterns(
1650 RewritePatternSet &patterns, PatternBenefit benefit) {
1651 patterns.add<UnrollFromElements>(patterns.getContext(), UnrollVectorOptions(),
1652 benefit);
1653}
return success()
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
lhs
static llvm::ManagedStatic< PassManagerOptions > options
static SmallVector< Value > sliceLoadStoreIndices(PatternRewriter &rewriter, Location loc, OperandRange originalIndices, ArrayRef< int64_t > offsets)
static std::optional< SmallVector< int64_t > > getTargetShape(const vector::UnrollVectorOptions &options, Operation *op)
Return the target shape for unrolling for the given op.
static SmallVector< int64_t > getUnrollOrder(unsigned numLoops, Operation *op, const vector::UnrollVectorOptions &options)
static Operation * cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc, Operation *op, ArrayRef< Value > operands, ArrayRef< Type > resultTypes)
Base type for affine expression.
Definition AffineExpr.h:68
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
ArrayRef< AffineExpr > getResults() const
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:328
IntegerType getI1Type()
Definition Builders.cpp:57
MLIRContext * getContext() const
Definition Builders.h:56
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:209
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition Builders.cpp:461
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:538
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:433
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:409
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:116
result_range getResults()
Definition Operation.h:441
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:430
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePattern is the common base class for all DAG to DAG replacements.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:369
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
FailureOr< SmallVector< Value > > unrollVectorValue(TypedValue< VectorType >, RewriterBase &)
Generic utility for unrolling values of type vector<NxAxBx...> to N values of type vector<AxBx....
LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter, UnrollVectorOpFn unrollFn)
Include the generated interface declarations.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition AffineMap.h:675
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Return the linearized index of 'offsets' w.r.t.
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Options that control the vector unrolling.