//===- VectorUnroll.cpp - patterns to do vector unrolling ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to do vector unrolling.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/InterleavedRange.h"
#include <optional>

#define DEBUG_TYPE "vector-unroll"

using namespace mlir;
using namespace mlir::vector;

/// Compute the indices of the slice `index` for a transfer op.
static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
                                               ArrayRef<Value> indices,
                                               AffineMap permutationMap,
                                               Location loc,
                                               OpBuilder &builder) {
  MLIRContext *ctx = builder.getContext();
  auto isBroadcast = [](AffineExpr expr) {
    if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
      return constExpr.getValue() == 0;
    return false;
  };
  // Compute 'slicedIndices' by adding 'elementOffsets[i]' to 'indices[i]'.
  SmallVector<Value> slicedIndices(indices);
  for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
    if (isBroadcast(dim.value()))
      continue;
    unsigned pos = cast<AffineDimExpr>(dim.value()).getPosition();
    auto expr = getAffineDimExpr(0, ctx) +
                getAffineConstantExpr(elementOffsets[dim.index()], ctx);
    auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
    slicedIndices[pos] =
        affine::AffineApplyOp::create(builder, loc, map, indices[pos]);
  }
  return slicedIndices;
}

// Compute the new indices by adding `offsets` to `originalIndices`.
// If m < n (where m = offsets.size() and n = originalIndices.size()), then
// only the trailing m values in `originalIndices` are updated.
static SmallVector<Value> sliceLoadStoreIndices(PatternRewriter &rewriter,
                                                Location loc,
                                                OperandRange originalIndices,
                                                ArrayRef<int64_t> offsets) {
  assert(offsets.size() <= originalIndices.size() &&
         "Offsets should not exceed the number of original indices");
  SmallVector<Value> indices(originalIndices);

  auto start = indices.size() - offsets.size();
  for (auto [i, offset] : llvm::enumerate(offsets)) {
    if (offset != 0) {
      indices[start + i] = arith::AddIOp::create(
          rewriter, loc, originalIndices[start + i],
          arith::ConstantIndexOp::create(rewriter, loc, offset));
    }
  }
  return indices;
}

// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
                                              Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  return builder.create(loc, op->getName().getIdentifier(), operands,
                        resultTypes, op->getAttrs());
}

/// Return the target shape for unrolling for the given `op`. Return
/// std::nullopt if the op shouldn't be or cannot be unrolled.
static std::optional<SmallVector<int64_t>>
getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
  LDBG() << "Get unroll shape for op " << op->getName().getStringRef();
  if (options.filterConstraint && failed(options.filterConstraint(op))) {
    LDBG() << "--fails the filter constraint -> BAIL";
    return std::nullopt;
  }
  assert(options.nativeShape &&
         "vector unrolling expects the native shape or native "
         "shape callback function to be set");
  auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
  if (!unrollableVectorOp) {
    LDBG() << "--not an unrollable op -> BAIL";
    return std::nullopt;
  }
  auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
  if (!maybeUnrollShape) {
    LDBG() << "--could not get shape of op " << *op << " -> BAIL";
    return std::nullopt;
  }
  LDBG() << "--vector op shape: " << llvm::interleaved(*maybeUnrollShape);

  std::optional<SmallVector<int64_t>> targetShape = options.nativeShape(op);
  if (!targetShape) {
    LDBG() << "--no unrolling target shape defined for " << *op << " -> SKIP";
    return std::nullopt;
  }
  LDBG() << "--target shape: " << llvm::interleaved(*targetShape);

  auto maybeShapeRatio = computeShapeRatio(*maybeUnrollShape, *targetShape);
  if (!maybeShapeRatio) {
    LDBG() << "--could not compute integral shape ratio -> BAIL";
    return std::nullopt;
  }
  if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
    LDBG() << "--no unrolling needed -> SKIP";
    return std::nullopt;
  }
  LDBG() << "--found an integral shape ratio to unroll to -> SUCCESS";
  return targetShape;
}

static SmallVector<int64_t>
getUnrollOrder(unsigned numLoops, Operation *op,
               const vector::UnrollVectorOptions &options) {
  SmallVector<int64_t> loopOrder =
      llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));
  if (options.traversalOrderCallback != nullptr) {
    std::optional<SmallVector<int64_t>> order =
        options.traversalOrderCallback(op);
    if (order) {
      loopOrder = std::move(*order);
    }
  }
  return loopOrder;
}

namespace {

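/// Unrolls `vector.transfer_read` by reading target-shaped slices and
/// assembling them with `vector.insert_strided_slice`. A minimal sketch,
/// assuming a native shape of <2x2> (the IR below is illustrative, not the
/// exact output):
///
/// ```
/// %0 = vector.transfer_read %src[%i, %j], %pad
///        : memref<?x?xf32>, vector<4x4xf32>
///
/// ==>
///
/// %cst = arith.constant dense<0.0> : vector<4x4xf32>
/// %r00 = vector.transfer_read %src[%i, %j], %pad
///        : memref<?x?xf32>, vector<2x2xf32>
/// %v00 = vector.insert_strided_slice %r00, %cst
///        {offsets = [0, 0], strides = [1, 1]}
///        : vector<2x2xf32> into vector<4x4xf32>
/// // ... three more reads at offsets [0, 2], [2, 0], [2, 2], with the
/// // memref indices advanced via affine.apply accordingly.
/// ```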
struct UnrollTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  UnrollTransferReadPattern(MLIRContext *context,
                            const vector::UnrollVectorOptions &options,
                            PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (readOp.getTransferRank() == 0)
      return failure();
    if (readOp.getMask())
      return failure();
    auto targetShape = getTargetShape(options, readOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = readOp.getVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = readOp.getLoc();
    ArrayRef<int64_t> originalSize = sourceVectorType.getShape();

    // Prepare the result vector.
    Value result =
        arith::ConstantOp::create(rewriter, loc, sourceVectorType,
                                  rewriter.getZeroAttr(sourceVectorType));
    auto targetType =
        VectorType::get(*targetShape, sourceVectorType.getElementType());
    SmallVector<Value> originalIndices(readOp.getIndices().begin(),
                                       readOp.getIndices().end());
    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), readOp, options);
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      SmallVector<Value> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               readOp.getPermutationMap(), loc, rewriter);
      auto slicedRead = vector::TransferReadOp::create(
          rewriter, loc, targetType, readOp.getBase(), indices,
          readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
          readOp.getInBoundsAttr());

      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, slicedRead, result, elementOffsets, strides);
    }
    rewriter.replaceOp(readOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

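/// Unrolls `vector.transfer_write` by extracting target-shaped slices of the
/// written vector and writing each slice with its own `vector.transfer_write`.
/// A minimal sketch, assuming a native shape of <2x2> (illustrative IR):
///
/// ```
/// vector.transfer_write %v, %dest[%i, %j] : vector<4x4xf32>, memref<?x?xf32>
///
/// ==>
///
/// %s00 = vector.extract_strided_slice %v
///        {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}
///        : vector<4x4xf32> to vector<2x2xf32>
/// vector.transfer_write %s00, %dest[%i, %j]
///        : vector<2x2xf32>, memref<?x?xf32>
/// // ... three more writes at offsets [0, 2], [2, 0], [2, 2]. For tensor
/// // destinations, each write feeds the next write's destination operand.
/// ```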
struct UnrollTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  UnrollTransferWritePattern(MLIRContext *context,
                             const vector::UnrollVectorOptions &options,
                             PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (writeOp.getTransferRank() == 0)
      return failure();

    if (writeOp.getMask())
      return failure();
    auto targetShape = getTargetShape(options, writeOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = writeOp.getVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = writeOp.getLoc();
    ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
    // Bail out if rank(source) != rank(target). The main limitation here is
    // that `ExtractStridedSlice` requires the rank of the input and output to
    // match. If needed, we can relax this later.
    if (originalSize.size() != targetShape->size())
      return rewriter.notifyMatchFailure(
          writeOp,
          "expected source input vector rank to match target shape rank");

    SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
                                       writeOp.getIndices().end());
    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), writeOp, options);
    Value resultTensor;
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      Value slicedVector = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
      SmallVector<Value> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               writeOp.getPermutationMap(), loc, rewriter);
      Operation *slicedWrite = vector::TransferWriteOp::create(
          rewriter, loc, slicedVector,
          resultTensor ? resultTensor : writeOp.getBase(), indices,
          writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
      // For the tensor case, update the destination for the next transfer
      // write.
      if (!slicedWrite->getResults().empty())
        resultTensor = slicedWrite->getResult(0);
    }
    if (resultTensor)
      rewriter.replaceOp(writeOp, resultTensor);
    else
      rewriter.eraseOp(writeOp);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

struct OffsetMapInfo {
  static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }

  static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }

  static unsigned getHashValue(const SmallVector<int64_t> &v) {
    return static_cast<unsigned>(llvm::hash_combine_range(v));
  }

  static bool isEqual(const SmallVector<int64_t> &lhs,
                      const SmallVector<int64_t> &rhs) {
    return lhs == rhs;
  }
};

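/// Unrolls `vector.contract` into contractions over target-shaped tiles of
/// its operands. Accumulator tiles are cached across the tiles of the
/// reduction dimensions and reassembled at the end. A minimal sketch,
/// assuming a 4x4x4 matmul and a native shape of <2x2x2> (illustrative IR,
/// indexing maps elided):
///
/// ```
/// %0 = vector.contract {...} %lhs, %rhs, %acc
///        : vector<4x4xf32>, vector<4x4xf32> into vector<4x4xf32>
///
/// ==>
///
/// // 2x2x2 = 8 contractions over vector<2x2xf32> slices. The two K-tiles
/// // of each (m, n) tile chain through the same cached accumulator slice;
/// // the four final accumulator slices are reassembled with
/// // vector.insert_strided_slice.
/// ```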
struct UnrollContractionPattern
    : public OpRewritePattern<vector::ContractionOp> {
  UnrollContractionPattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options,
                           PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, contractOp);
    if (!targetShape)
      return failure();
    auto dstVecType = cast<VectorType>(contractOp.getResultType());
    SmallVector<int64_t> originalSize = *contractOp.getShapeForUnroll();

    Location loc = contractOp.getLoc();
    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
    AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;

    SmallVector<int64_t> loopOrder = getUnrollOrder(
        contractOp.getIteratorTypes().size(), contractOp, options);

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      SmallVector<Value> slicesOperands(contractOp.getNumOperands());

      // Helper to compute the new shape of each operand and extract the slice.
      auto extractOperand = [&](unsigned index, Value operand,
                                AffineMap permutationMap,
                                ArrayRef<int64_t> operandOffsets) {
        SmallVector<int64_t> operandShape = applyPermutationMap(
            permutationMap, ArrayRef<int64_t>(*targetShape));
        SmallVector<int64_t> operandStrides(operandOffsets.size(), 1);
        slicesOperands[index] =
            rewriter.createOrFold<vector::ExtractStridedSliceOp>(
                loc, operand, operandOffsets, operandShape, operandStrides);
      };

      // Extract the new lhs operand.
      AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];
      SmallVector<int64_t> lhsOffsets =
          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffsets);

      // Extract the new rhs operand.
      AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
      SmallVector<int64_t> rhsOffsets =
          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffsets);

      AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
      SmallVector<int64_t> accOffsets =
          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
      // If a version of the accumulator has already been computed, use it;
      // otherwise extract the first version from the original operand.
      auto *accIt = accCache.find(accOffsets);
      if (accIt != accCache.end())
        slicesOperands[2] = accIt->second;
      else
        extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffsets);

      SmallVector<int64_t> dstShape =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
      auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, contractOp, slicesOperands, targetType);

      SmallVector<int64_t> dstOffsets =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
      // Save the accumulated value until all the loops are unrolled, since
      // the reduction loops keep updating the accumulator.
      accCache[dstOffsets] = newOp->getResult(0);
    }
    // Assemble the accumulator slices back into a single vector.
    Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
                                             rewriter.getZeroAttr(dstVecType));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(contractOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

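/// Unrolls `vector.multi_reduction` by reducing target-shaped slices of the
/// source and accumulating into cached slices of the accumulator. A minimal
/// sketch, assuming a reduction over dim 1 and a native shape of <2x4>
/// (illustrative IR):
///
/// ```
/// %0 = vector.multi_reduction <add>, %src, %acc [1]
///        : vector<4x8xf32> to vector<4xf32>
///
/// ==>
///
/// // Four vector<2x4xf32> slices are each reduced to vector<2xf32>; the two
/// // slices that share destination offsets feed the same accumulator tile,
/// // and the tiles are reassembled with vector.insert_strided_slice.
/// ```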
struct UnrollMultiReductionPattern
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  UnrollMultiReductionPattern(MLIRContext *context,
                              const vector::UnrollVectorOptions &options,
                              PatternBenefit benefit = 1)
      : OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    auto resultType = reductionOp->getResult(0).getType();
    if (resultType.isIntOrFloat()) {
      return rewriter.notifyMatchFailure(reductionOp,
                                         "Unrolling scalars is not supported");
    }
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;
    Location loc = reductionOp.getLoc();

    // Iterate over all the slice offsets of the source, in steps of
    // `targetShape`.
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<Value> operands;
      SmallVector<int64_t> operandStrides(offsets.size(), 1);
      Value slicedOperand =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, reductionOp.getSource(), offsets, *targetShape,
              operandStrides);
      operands.push_back(slicedOperand);
      SmallVector<int64_t> dstShape;
      SmallVector<int64_t> destOffset;
      for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
        if (!reductionOp.isReducedDim(i)) {
          destOffset.push_back(offsets[i]);
          dstShape.push_back((*targetShape)[i]);
        }
      }
      Value acc;
      SmallVector<int64_t> accStrides(destOffset.size(), 1);
      // If a version of the accumulator has already been computed, use it;
      // otherwise extract the first version from the original operand.
      auto *accIt = accCache.find(destOffset);
      if (accIt != accCache.end())
        acc = accIt->second;
      else
        acc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
            loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
      operands.push_back(acc);
      auto targetType = VectorType::get(
          dstShape, reductionOp.getSourceVectorType().getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
                                                     operands, targetType);
      Value result = newOp->getResult(0);
      accCache[destOffset] = result;
    }
    // Assemble the accumulator slices back into a single vector.
    Value result = arith::ConstantOp::create(
        rewriter, loc, reductionOp.getDestType(),
        rewriter.getZeroAttr(reductionOp.getDestType()));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(reductionOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

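/// Unrolls elementwise ops (ops with elementwise-mappable traits) into
/// target-shaped computations on extracted slices. A minimal sketch, assuming
/// `arith.addf` and a native shape of <2x2> (illustrative IR):
///
/// ```
/// %0 = arith.addf %a, %b : vector<4x4xf32>
///
/// ==>
///
/// %a00 = vector.extract_strided_slice %a
///        {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}
///        : vector<4x4xf32> to vector<2x2xf32>
/// %b00 = vector.extract_strided_slice %b
///        {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}
///        : vector<4x4xf32> to vector<2x2xf32>
/// %s00 = arith.addf %a00, %b00 : vector<2x2xf32>
/// // ... repeated for the remaining offsets and reassembled with
/// // vector.insert_strided_slice.
/// ```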
struct UnrollElementwisePattern : public RewritePattern {
  UnrollElementwisePattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options,
                           PatternBenefit benefit = 1)
      : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
        options(options) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
      return failure();
    auto targetShape = getTargetShape(options, op);
    if (!targetShape)
      return failure();
    int64_t targetShapeRank = targetShape->size();
    auto dstVecType = cast<VectorType>(op->getResult(0).getType());
    SmallVector<int64_t> originalSize =
        *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
    int64_t originalShapeRank = originalSize.size();

    Location loc = op->getLoc();

    // Handle rank mismatch by adding leading unit dimensions to targetShape.
    SmallVector<int64_t> adjustedTargetShape(originalShapeRank);
    int64_t rankDiff = originalShapeRank - targetShapeRank;
    std::fill(adjustedTargetShape.begin(),
              adjustedTargetShape.begin() + rankDiff, 1);
    std::copy(targetShape->begin(), targetShape->end(),
              adjustedTargetShape.begin() + rankDiff);

    int64_t adjustedTargetShapeRank = adjustedTargetShape.size();
    // Prepare the result vector.
    Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
                                             rewriter.getZeroAttr(dstVecType));
    SmallVector<int64_t> strides(adjustedTargetShapeRank, 1);
    VectorType unrolledVecType =
        VectorType::get(*targetShape, dstVecType.getElementType());

    // Create the unrolled computation.
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, adjustedTargetShape)) {
      SmallVector<Value> extractOperands;
      for (OpOperand &operand : op->getOpOperands()) {
        auto vecType = dyn_cast<VectorType>(operand.get().getType());
        if (!vecType) {
          extractOperands.push_back(operand.get());
          continue;
        }
        Value extracted = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
            loc, operand.get(), offsets, adjustedTargetShape, strides);

        // Reshape to remove leading unit dims if needed.
        if (adjustedTargetShapeRank > targetShapeRank) {
          extracted = rewriter.createOrFold<vector::ShapeCastOp>(
              loc, VectorType::get(*targetShape, vecType.getElementType()),
              extracted);
        }
        extractOperands.push_back(extracted);
      }

      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, op, extractOperands, unrolledVecType);

      Value computeResult = newOp->getResult(0);

      // Use strides sized to targetShape for proper insertion.
      SmallVector<int64_t> insertStrides =
          (adjustedTargetShapeRank > targetShapeRank)
              ? SmallVector<int64_t>(targetShapeRank, 1)
              : strides;

      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, computeResult, result, offsets, insertStrides);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

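/// Unrolls `vector.reduction` into reductions over target-shaped slices whose
/// scalar results are combined with the corresponding arith op. A minimal
/// sketch, assuming a native shape of <4> (illustrative IR):
///
/// ```
/// %0 = vector.reduction <add>, %v : vector<8xf32> into f32
///
/// ==>
///
/// %s0 = vector.extract_strided_slice %v
///        {offsets = [0], sizes = [4], strides = [1]}
///        : vector<8xf32> to vector<4xf32>
/// %r0 = vector.reduction <add>, %s0 : vector<4xf32> into f32
/// %s1 = vector.extract_strided_slice %v
///        {offsets = [4], sizes = [4], strides = [1]}
///        : vector<8xf32> to vector<4xf32>
/// %r1 = vector.reduction <add>, %s1 : vector<4xf32> into f32
/// %sum = arith.addf %r0, %r1 : f32
/// ```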
struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
  UnrollReductionPattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ReductionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();

    // Create unrolled vector reduction.
    Location loc = reductionOp.getLoc();
    Value accumulator = nullptr;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<int64_t> strides(offsets.size(), 1);
      Value slicedOperand =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, reductionOp.getVector(), offsets, *targetShape, strides);
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
      Value result = newOp->getResult(0);

      if (!accumulator) {
        // This is the first reduction.
        accumulator = result;
      } else {
        // On subsequent reductions, combine with the accumulator.
        accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(),
                                         accumulator, result);
      }
    }

    rewriter.replaceOp(reductionOp, accumulator);
    return success();
  }

private:
  const vector::UnrollVectorOptions options;
};

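/// Unrolls `vector.transpose` by transposing target-shaped slices: the source
/// slice is extracted at permuted offsets with a permuted shape, transposed,
/// and inserted at the result offsets. A minimal sketch, assuming permutation
/// [1, 0] and a native shape of <2x2> (illustrative IR, showing the tile at
/// result offsets [2, 0]):
///
/// ```
/// %0 = vector.transpose %v, [1, 0] : vector<4x4xf32> to vector<4x4xf32>
///
/// ==>
///
/// %s = vector.extract_strided_slice %v
///        {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]}
///        : vector<4x4xf32> to vector<2x2xf32>
/// %t = vector.transpose %s, [1, 0] : vector<2x2xf32> to vector<2x2xf32>
/// %r = vector.insert_strided_slice %t, %acc
///        {offsets = [2, 0], strides = [1, 1]}
///        : vector<2x2xf32> into vector<4x4xf32>
/// // ... repeated for the remaining result offsets.
/// ```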
struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
  UnrollTransposePattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransposeOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (transposeOp.getResultVectorType().getRank() == 0)
      return failure();
    auto targetShape = getTargetShape(options, transposeOp);
    if (!targetShape)
      return failure();
    auto originalVectorType = transposeOp.getResultVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = transposeOp.getLoc();
    ArrayRef<int64_t> originalSize = originalVectorType.getShape();

    // Prepare the result vector.
    Value result =
        arith::ConstantOp::create(rewriter, loc, originalVectorType,
                                  rewriter.getZeroAttr(originalVectorType));
    ArrayRef<int64_t> permutation = transposeOp.getPermutation();

    // Unroll the computation.
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<int64_t> permutedOffsets(elementOffsets.size());
      SmallVector<int64_t> permutedShape(elementOffsets.size());
      // Compute the source offsets and shape.
      for (auto indices : llvm::enumerate(permutation)) {
        permutedOffsets[indices.value()] = elementOffsets[indices.index()];
        permutedShape[indices.value()] = (*targetShape)[indices.index()];
      }
      Value slicedOperand =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, transposeOp.getVector(), permutedOffsets, permutedShape,
              strides);
      Value transposedSlice = rewriter.createOrFold<vector::TransposeOp>(
          loc, slicedOperand, permutation);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, transposedSlice, result, elementOffsets, strides);
    }
    rewriter.replaceOp(transposeOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

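/// Unrolls `vector.gather` by gathering target-shaped slices of the index,
/// mask, and pass-through vectors. A minimal sketch, assuming a native shape
/// of <2> (illustrative IR):
///
/// ```
/// %0 = vector.gather %base[%c0] [%idx], %mask, %pass
///        : memref<?xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32>
///          into vector<4xf32>
///
/// ==>
///
/// %i0 = vector.extract_strided_slice %idx ... to vector<2xindex>
/// %m0 = vector.extract_strided_slice %mask ... to vector<2xi1>
/// %p0 = vector.extract_strided_slice %pass ... to vector<2xf32>
/// %g0 = vector.gather %base[%c0] [%i0], %m0, %p0 ... into vector<2xf32>
/// // ... repeated for offsets [2], then reassembled with
/// // vector.insert_strided_slice.
/// ```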
struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
  UnrollGatherPattern(MLIRContext *context,
                      const vector::UnrollVectorOptions &options,
                      PatternBenefit benefit = 1)
      : OpRewritePattern<vector::GatherOp>(context, benefit), options(options) {
  }

  LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
                                PatternRewriter &rewriter) const override {
    VectorType sourceVectorType = gatherOp.getVectorType();
    if (sourceVectorType.getRank() == 0)
      return failure();
    auto targetShape = getTargetShape(options, gatherOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = gatherOp.getLoc();
    ArrayRef<int64_t> originalSize = gatherOp.getVectorType().getShape();

    // Prepare the result vector.
    Value result =
        arith::ConstantOp::create(rewriter, loc, sourceVectorType,
                                  rewriter.getZeroAttr(sourceVectorType));
    auto targetType =
        VectorType::get(*targetShape, sourceVectorType.getElementType());

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), gatherOp, options);
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      // To get the unrolled gather, extract the same slice based on the
      // decomposed shape from each of the index, mask, and pass-through
      // vectors.
      Value indexSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, gatherOp.getIndices(), elementOffsets, *targetShape, strides);
      Value maskSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
      Value passThruSubVec =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, gatherOp.getPassThru(), elementOffsets, *targetShape,
              strides);
      auto slicedGather = vector::GatherOp::create(
          rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getOffsets(),
          indexSubVec, maskSubVec, passThruSubVec);

      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, slicedGather, result, elementOffsets, strides);
    }
    rewriter.replaceOp(gatherOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

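/// Unrolls `vector.load` into loads of target-shaped slices with the memref
/// indices advanced by the slice offsets. A minimal sketch, assuming a native
/// shape of <4> (illustrative IR):
///
/// ```
/// %0 = vector.load %base[%i] : memref<?xf32>, vector<8xf32>
///
/// ==>
///
/// %cst = arith.constant dense<0.0> : vector<8xf32>
/// %l0 = vector.load %base[%i] : memref<?xf32>, vector<4xf32>
/// %v0 = vector.insert_strided_slice %l0, %cst
///        {offsets = [0], strides = [1]} : vector<4xf32> into vector<8xf32>
/// %c4 = arith.constant 4 : index
/// %i4 = arith.addi %i, %c4 : index
/// %l1 = vector.load %base[%i4] : memref<?xf32>, vector<4xf32>
/// %v1 = vector.insert_strided_slice %l1, %v0
///        {offsets = [4], strides = [1]} : vector<4xf32> into vector<8xf32>
/// ```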
struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> {
  UnrollLoadPattern(MLIRContext *context,
                    const vector::UnrollVectorOptions &options,
                    PatternBenefit benefit = 1)
      : OpRewritePattern<vector::LoadOp>(context, benefit), options(options) {}

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    VectorType vecType = loadOp.getVectorType();

    auto targetShape = getTargetShape(options, loadOp);
    if (!targetShape)
      return failure();

    Location loc = loadOp.getLoc();
    ArrayRef<int64_t> originalShape = vecType.getShape();
    SmallVector<int64_t> strides(targetShape->size(), 1);

    Value result = arith::ConstantOp::create(rewriter, loc, vecType,
                                             rewriter.getZeroAttr(vecType));

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalShape.size(), loadOp, options);

    auto targetVecType =
        VectorType::get(*targetShape, vecType.getElementType());

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
      SmallVector<Value> indices =
          sliceLoadStoreIndices(rewriter, loc, loadOp.getIndices(), offsets);
      Value slicedLoad = vector::LoadOp::create(rewriter, loc, targetVecType,
                                                loadOp.getBase(), indices);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, slicedLoad, result, offsets, strides);
    }
    rewriter.replaceOp(loadOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

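/// Unrolls `vector.store` by extracting target-shaped slices of the stored
/// vector and storing each at the correspondingly offset indices. A minimal
/// sketch, assuming a native shape of <4> (illustrative IR):
///
/// ```
/// vector.store %v, %base[%i] : memref<?xf32>, vector<8xf32>
///
/// ==>
///
/// %s0 = vector.extract_strided_slice %v
///        {offsets = [0], sizes = [4], strides = [1]}
///        : vector<8xf32> to vector<4xf32>
/// vector.store %s0, %base[%i] : memref<?xf32>, vector<4xf32>
/// %c4 = arith.constant 4 : index
/// %i4 = arith.addi %i, %c4 : index
/// %s1 = vector.extract_strided_slice %v
///        {offsets = [4], sizes = [4], strides = [1]}
///        : vector<8xf32> to vector<4xf32>
/// vector.store %s1, %base[%i4] : memref<?xf32>, vector<4xf32>
/// ```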
struct UnrollStorePattern : public OpRewritePattern<vector::StoreOp> {
  UnrollStorePattern(MLIRContext *context,
                     const vector::UnrollVectorOptions &options,
                     PatternBenefit benefit = 1)
      : OpRewritePattern<vector::StoreOp>(context, benefit), options(options) {}

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    VectorType vecType = storeOp.getVectorType();

    auto targetShape = getTargetShape(options, storeOp);
    if (!targetShape)
      return failure();

    Location loc = storeOp.getLoc();
    ArrayRef<int64_t> originalShape = vecType.getShape();
    SmallVector<int64_t> strides(targetShape->size(), 1);

    Value base = storeOp.getBase();
    Value vector = storeOp.getValueToStore();

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalShape.size(), storeOp, options);

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
      SmallVector<Value> indices =
          sliceLoadStoreIndices(rewriter, loc, storeOp.getIndices(), offsets);
      Value slice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, vector, offsets, *targetShape, strides);
      vector::StoreOp::create(rewriter, loc, slice, base, indices);
    }
    rewriter.eraseOp(storeOp);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

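/// Unrolls `vector.broadcast` by broadcasting into target-shaped tiles of the
/// result; source dims of size 1 keep offset 0 and size 1, so each tile still
/// broadcasts. A minimal sketch, assuming a native shape of <2x2>
/// (illustrative IR):
///
/// ```
/// %0 = vector.broadcast %src : vector<4xf32> to vector<4x4xf32>
///
/// ==>
///
/// %s0 = vector.extract_strided_slice %src
///        {offsets = [0], sizes = [2], strides = [1]}
///        : vector<4xf32> to vector<2xf32>
/// %b0 = vector.broadcast %s0 : vector<2xf32> to vector<2x2xf32>
/// // ... repeated for the remaining offsets and reassembled with
/// // vector.insert_strided_slice.
/// ```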
struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
  UnrollBroadcastPattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::BroadcastOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, broadcastOp);
    if (!targetShape)
      return failure();

    Location loc = broadcastOp.getLoc();
    VectorType srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
    VectorType resType = broadcastOp.getResultVectorType();
    VectorType targetType =
        resType.cloneWith(*targetShape, resType.getElementType());
    Value result = arith::ConstantOp::create(rewriter, loc, resType,
                                             rewriter.getZeroAttr(resType));

    SmallVector<int64_t> originalShape = *broadcastOp.getShapeForUnroll();
    SmallVector<int64_t> strides(originalShape.size(), 1);

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalShape, *targetShape)) {
      Value newSrc;
      if (!srcType) {
        // Scalar to vector broadcast.
        newSrc = broadcastOp.getSource();
      } else {
        // Vector to vector broadcast.
        int64_t rank = srcType.getRank();
        SmallVector<int64_t> srcOffsets(offsets.end() - rank, offsets.end());
        SmallVector<int64_t> srcShape(targetShape->end() - rank,
                                      targetShape->end());
        SmallVector<int64_t> srcStrides(strides.end() - rank, strides.end());
        // Adjust the offset and shape for the source if the corresponding
        // dim is 1.
        for (int64_t i = 0; i < rank; ++i) {
          if (srcType.getDimSize(i) == 1) {
            srcOffsets[i] = 0;
            srcShape[i] = 1;
          }
        }
        newSrc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
            loc, broadcastOp.getSource(), srcOffsets, srcShape, srcStrides);
      }

      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, broadcastOp,
                                                     newSrc, targetType);

      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, newOp->getResult(0), result, offsets, strides);
    }

    rewriter.replaceOp(broadcastOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

/// Unrolls 2 or more dimensional `vector.to_elements` ops by unrolling the
/// outermost dimension of the operand. For example:
///
/// ```
/// %0:8 = vector.to_elements %v : vector<2x2x2xf32>
///
/// ==>
///
/// %v0 = vector.extract %v[0] : vector<2x2xf32> from vector<2x2x2xf32>
/// %v1 = vector.extract %v[1] : vector<2x2xf32> from vector<2x2x2xf32>
/// %0:4 = vector.to_elements %v0 : vector<2x2xf32>
/// %1:4 = vector.to_elements %v1 : vector<2x2xf32>
/// ```
///
/// When this pattern is applied until a fixed point is reached, it produces
/// a sequence of 1-D `vector.to_elements` ops.
struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
  UnrollToElements(MLIRContext *context,
                   const vector::UnrollVectorOptions &options,
                   PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ToElementsOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ToElementsOp op,
                                PatternRewriter &rewriter) const override {
    TypedValue<VectorType> source = op.getSource();
    FailureOr<SmallVector<Value>> result =
        vector::unrollVectorValue(source, rewriter);
    if (failed(result)) {
      return failure();
    }
    SmallVector<Value> vectors = *result;

    SmallVector<Value> results;
    for (Value vector : vectors) {
      auto subElements =
          vector::ToElementsOp::create(rewriter, op.getLoc(), vector);
      llvm::append_range(results, subElements.getResults());
    }
    rewriter.replaceOp(op, results);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

/// This pattern unrolls `vector.step` operations according to the provided
/// target unroll shape. It decomposes a large step vector into smaller step
/// vectors (segments) and assembles the result by inserting each computed
/// segment into the appropriate offset of the original vector.
///
/// The pattern does not support scalable vectors and will fail to match them.
///
/// For each segment, it adds the base step vector and the segment's offset,
/// then inserts the result into the output vector at the corresponding
/// position.
///
/// Example:
/// Given a step operation:
///   %0 = vector.step : vector<8xindex>
///
/// and a target unroll shape of <4>, the pattern produces:
///
///   %base = vector.step : vector<4xindex>
///   %zero = arith.constant dense<0> : vector<8xindex>
///   %result0 = vector.insert_strided_slice %base, %zero
///     {offsets = [0], strides = [1]} : vector<4xindex> into vector<8xindex>
///   %offset = arith.constant dense<4> : vector<4xindex>
///   %segment1 = arith.addi %base, %offset : vector<4xindex>
///   %result1 = vector.insert_strided_slice %segment1, %result0
///     {offsets = [4], strides = [1]} : vector<4xindex> into vector<8xindex>
///
struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> {
  UnrollStepPattern(MLIRContext *context,
                    const vector::UnrollVectorOptions &options,
                    PatternBenefit benefit = 1)
      : OpRewritePattern<vector::StepOp>(context, benefit), options(options) {}

  LogicalResult matchAndRewrite(vector::StepOp stepOp,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, stepOp);
    if (!targetShape)
      return failure();

    VectorType vecType = stepOp.getType();
    if (vecType.isScalable()) {
      // Scalable vectors are not supported by this pattern.
      return failure();
    }
    int64_t originalSize = vecType.getShape()[0];
    Location loc = stepOp.getLoc();
    SmallVector<int64_t> strides(1, 1);

    Value result = arith::ConstantOp::create(rewriter, loc, vecType,
                                             rewriter.getZeroAttr(vecType));

    auto targetVecType =
        VectorType::get(*targetShape, vecType.getElementType());
    Value baseStep = vector::StepOp::create(rewriter, loc, targetVecType);
    for (const SmallVector<int64_t> &offsets :
         StaticTileOffsetRange({originalSize}, *targetShape)) {
      Value bcastOffset = arith::ConstantOp::create(
          rewriter, loc, targetVecType,
          DenseElementsAttr::get(
              targetVecType,
              IntegerAttr::get(targetVecType.getElementType(), offsets[0])));
      Value tileStep =
          arith::AddIOp::create(rewriter, loc, baseStep, bcastOffset);

      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, tileStep, result, offsets, strides);
    }
    rewriter.replaceOp(stepOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
/// outermost dimension. For example:
///
/// ```
/// %v = vector.from_elements %e0, %e1, %e2, %e3, %e4, %e5 : vector<2x3xf32>
///
/// ==>
///
/// %0 = ub.poison : vector<2x3xf32>
/// %v0 = vector.from_elements %e0, %e1, %e2 : vector<3xf32>
/// %1 = vector.insert %v0, %0 [0] : vector<3xf32> into vector<2x3xf32>
/// %v1 = vector.from_elements %e3, %e4, %e5 : vector<3xf32>
/// %v = vector.insert %v1, %1 [1] : vector<3xf32> into vector<2x3xf32>
/// ```
///
/// When this pattern is applied until a fixed point is reached, it produces
/// a sequence of 1-D `vector.from_elements` ops.
struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
  UnrollFromElements(MLIRContext *context,
                     const vector::UnrollVectorOptions &options,
                     PatternBenefit benefit = 1)
      : OpRewritePattern<vector::FromElementsOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::FromElementsOp op,
                                PatternRewriter &rewriter) const override {
    ValueRange allElements = op.getElements();

    auto unrollFromElementsFn = [&](PatternRewriter &rewriter, Location loc,
                                    VectorType subTy, int64_t index) {
      size_t subTyNumElements = subTy.getNumElements();
      assert((index + 1) * subTyNumElements <= allElements.size() &&
             "out of bounds");
      ValueRange subElements =
          allElements.slice(index * subTyNumElements, subTyNumElements);
      return vector::FromElementsOp::create(rewriter, loc, subTy, subElements);
    };

    return unrollVectorOp(op, rewriter, unrollFromElementsFn);
  }

private:
  vector::UnrollVectorOptions options;
};

/// Checks whether `extractShape` is a contiguous slice of `shape`. For
/// `extractShape` to be contiguous in `shape`:
/// 1) All but the leading dimension of `extractShape` and `shape` must match
///    exactly.
/// 2) The total number of elements in `shape` must be evenly divisible by the
///    total number of elements in `extractShape`.
/// Examples:
///   isContiguous([4, 4], [8, 4]) == true
///   isContiguous([2, 4], [8, 4]) == true
///   isContiguous([2, 2], [8, 4]) == false
/// Leading unit dimensions are removed first, to handle cases like:
///   isContiguous([1, 16], [1, 32]) == true
static bool isContiguous(ArrayRef<int64_t> extractShape,
                         ArrayRef<int64_t> shape) {
  if (extractShape.size() > shape.size())
    return false;

  while (!extractShape.empty() && extractShape.front() == 1) {
    extractShape = extractShape.drop_front();
  }

  while (!shape.empty() && shape.front() == 1) {
    shape = shape.drop_front();
  }

  size_t rankDiff = shape.size() - extractShape.size();
  if (!llvm::equal(extractShape.drop_front(), shape.drop_front(rankDiff + 1)))
    return false;

  int64_t extractElements = ShapedType::getNumElements(extractShape);
  int64_t shapeElements = ShapedType::getNumElements(shape);
  return shapeElements % extractElements == 0;
}

/// Determines what shape to use with `vector.extract_strided_slice` to extract
/// a contiguous memory region from a source vector. The extraction must be
/// contiguous and contain exactly the specified number of elements. If such an
/// extraction shape cannot be determined, returns std::nullopt.
///
/// EXAMPLE 1:
///   sourceShape = [16], targetElements = 8
///   Working right-to-left:
///   - Take min(8, 16) = 8 from the only dim -> extractShape = [8],
///     remaining = 8/8 = 1
///   Result: [8]
///
/// EXAMPLE 2:
///   sourceShape = [4, 4], targetElements = 8
///   Working right-to-left:
///   - Take min(8, 4) = 4 from the last dim -> extractShape = [4],
///     remaining = 8/4 = 2
///   - Take min(2, 4) = 2 from the first dim -> extractShape = [2, 4],
///     remaining = 2/2 = 1
///   Result: [2, 4]
static std::optional<SmallVector<int64_t>>
calculateSourceExtractShape(ArrayRef<int64_t> sourceShape,
                            int64_t targetElements) {
  SmallVector<int64_t> extractShape;
  int64_t remainingElements = targetElements;

  // Build the extract shape from the innermost dimension outward to ensure
  // contiguity.
  for (int i = sourceShape.size() - 1; i >= 0 && remainingElements > 1; --i) {
    int64_t takeFromDim = std::min(remainingElements, sourceShape[i]);
    extractShape.insert(extractShape.begin(), takeFromDim);

    if (remainingElements % takeFromDim != 0)
      return std::nullopt; // Not evenly divisible.
    remainingElements /= takeFromDim;
  }

  // Fill the remaining dimensions with 1.
  while (extractShape.size() < sourceShape.size())
    extractShape.insert(extractShape.begin(), 1);

  if (ShapedType::getNumElements(extractShape) != targetElements)
    return std::nullopt;

  return extractShape;
}

// Convert result offsets to source offsets via the linear position.
static SmallVector<int64_t>
calculateSourceOffsets(ArrayRef<int64_t> resultOffsets,
                       ArrayRef<int64_t> sourceShape,
                       ArrayRef<int64_t> resultShape) {
  // Convert result offsets to a linear position.
  int64_t linearIndex = linearize(resultOffsets, computeStrides(resultShape));
  // Convert the linear position to source offsets.
  return delinearize(linearIndex, computeStrides(sourceShape));
}

/// This pattern unrolls `vector.shape_cast` operations according to the
/// provided target unroll shape. It unrolls a large shape cast into smaller
/// shape casts by extracting contiguous slices from the source vector, casting
/// each slice to the target shape, and assembling the result by inserting each
/// computed segment into the appropriate offset of the result vector.
///
/// This pattern only applies when contiguous slices can be extracted from the
/// source vector and inserted into the result vector such that each slice
/// remains a valid vector (rather than decomposing to scalars). In these
/// cases, the unrolling proceeds as:
///   vector.extract_strided_slice -> vector.shape_cast (on the slice) ->
///   vector.insert_strided_slice.
///
/// Example:
/// Given a shape cast operation:
///   %0 = vector.shape_cast %src : vector<8x2xf32> to vector<4x4xf32>
///
/// and a target unroll shape of <2x4>, the pattern produces:
///
///   %zero = arith.constant dense<0.0> : vector<4x4xf32>
///   %s0 = vector.extract_strided_slice %src [0, 0], [4, 2], [1, 1]
///       : vector<8x2xf32> to vector<4x2xf32>
///   %sc0 = vector.shape_cast %s0 : vector<4x2xf32> to vector<2x4xf32>
///   %i0 = vector.insert_strided_slice %sc0, %zero [0, 0], [1, 1]
///       : vector<2x4xf32> into vector<4x4xf32>
///   %s1 = vector.extract_strided_slice %src [4, 0], [4, 2], [1, 1]
///       : vector<8x2xf32> to vector<4x2xf32>
///   %sc1 = vector.shape_cast %s1 : vector<4x2xf32> to vector<2x4xf32>
///   %i1 = vector.insert_strided_slice %sc1, %i0 [2, 0], [1, 1]
///       : vector<2x4xf32> into vector<4x4xf32>
///
struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
  UnrollShapeCastPattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ShapeCastOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, shapeCastOp);
    if (!targetShape)
      return failure();

    VectorType sourceType = shapeCastOp.getSourceVectorType();
    VectorType resultType = shapeCastOp.getResultVectorType();
    ArrayRef<int64_t> sourceShape = sourceType.getShape();
    ArrayRef<int64_t> resultShape = resultType.getShape();

    if (!isContiguous(*targetShape, resultShape))
      return rewriter.notifyMatchFailure(
          shapeCastOp, "Only supports cases where target shape is "
                       "contiguous in result vector shape");

    int64_t targetElements = ShapedType::getNumElements(*targetShape);

    // Calculate the shape to extract from the source.
    std::optional<SmallVector<int64_t>> extractShape =
        calculateSourceExtractShape(sourceShape, targetElements);
    if (!extractShape)
      return rewriter.notifyMatchFailure(
          shapeCastOp,
          "cannot extract target number of elements contiguously from source");

    Location loc = shapeCastOp.getLoc();

    // Create the result vector initialized to zero.
    Value result = arith::ConstantOp::create(rewriter, loc, resultType,
                                             rewriter.getZeroAttr(resultType));

    VectorType targetType =
        VectorType::get(*targetShape, sourceType.getElementType());

    SmallVector<int64_t> extractStrides(extractShape->size(), 1);
    SmallVector<int64_t> insertStrides(targetShape->size(), 1);

    for (SmallVector<int64_t> resultOffsets :
         StaticTileOffsetRange(resultShape, *targetShape)) {
      SmallVector<int64_t> sourceOffsets =
          calculateSourceOffsets(resultOffsets, sourceShape, resultShape);
      Value sourceChunk = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, shapeCastOp.getSource(), sourceOffsets, *extractShape,
          extractStrides);
      Value targetChunk = rewriter.createOrFold<vector::ShapeCastOp>(
          loc, targetType, sourceChunk);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, targetChunk, result, resultOffsets, insertStrides);
    }

    rewriter.replaceOp(shapeCastOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

} // namespace

void mlir::vector::populateVectorUnrollPatterns(
    RewritePatternSet &patterns, const UnrollVectorOptions &options,
    PatternBenefit benefit) {
  patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
               UnrollContractionPattern, UnrollElementwisePattern,
               UnrollReductionPattern, UnrollMultiReductionPattern,
               UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
               UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
               UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern>(
      patterns.getContext(), options, benefit);
}

void mlir::vector::populateVectorToElementsUnrollPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<UnrollToElements>(patterns.getContext(), UnrollVectorOptions(),
                                 benefit);
}

void mlir::vector::populateVectorFromElementsUnrollPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<UnrollFromElements>(patterns.getContext(), UnrollVectorOptions(),
                                   benefit);
}
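
// A minimal sketch of how these patterns can be driven, assuming a greedy
// rewrite driver and a fixed native shape of <2x2>; the shape and driver
// call below are illustrative, not prescriptive:
//
//   RewritePatternSet patterns(ctx);
//   vector::UnrollVectorOptions options;
//   options.setNativeShape(ArrayRef<int64_t>{2, 2});
//   vector::populateVectorUnrollPatterns(patterns, options);
//   (void)applyPatternsGreedily(op, std::move(patterns));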