MLIR 23.0.0git
VectorTransforms.cpp
Go to the documentation of this file.
1//===- VectorTransforms.cpp - Conversion within the Vector dialect --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements target-independent rewrites as 1->N patterns.
10//
11//===----------------------------------------------------------------------===//
12
14
25#include "mlir/IR/Location.h"
26#include "mlir/IR/Matchers.h"
29
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/SmallVectorExtras.h"
32#include "llvm/Support/FormatVariadic.h"
33
34#include <cassert>
35#include <cstdint>
36#include <functional>
37#include <optional>
38
39#define DEBUG_TYPE "vector-to-vector"
40
41using namespace mlir;
42using namespace mlir::vector;
43
44template <typename IntType>
46 return llvm::to_vector<4>(llvm::map_range(
47 arrayAttr.getAsRange<IntegerAttr>(),
48 [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
49}
50
51// Helper to find an index in an affine map.
52static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
53 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
54 int64_t idx = map.getDimPosition(i);
55 if (idx == index)
56 return i;
57 }
58 return std::nullopt;
59}
60
61namespace {
62
/// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
/// Ex:
/// ```
///   %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
///   %1 = vector.multi_reduction add, %0 [1]
///     : vector<8x32x16xf32> to vector<8x16xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %0, %arg1, %cst_f0
///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct MultiReduceToContract
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
                                PatternRewriter &rewriter) const override {
    // Only an ADD reduction of a multiply is a contraction.
    if (reduceOp.getKind() != vector::CombiningKind::ADD)
      return failure();
    Operation *mulOp = reduceOp.getSource().getDefiningOp();
    if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
      return failure();
    // Both multiply operands are indexed with the identity map; the result
    // map keeps only the non-reduced dimensions.
    SmallVector<bool> reductionMask = reduceOp.getReductionMask();
    auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
    SmallVector<AffineExpr> exprs;
    SmallVector<vector::IteratorType> iteratorTypes;
    for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
      if (!isReduceDim.value()) {
        // Parallel dim: appears in the result indexing map.
        iteratorTypes.push_back(vector::IteratorType::parallel);
        exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
      } else {
        // Reduced dim: iterator only, dropped from the result map.
        iteratorTypes.push_back(vector::IteratorType::reduction);
      }
    }
    auto dstMap =
        AffineMap::get(/*dimCount=*/reductionMask.size(),
                       /*symbolCount=*/0, exprs, reduceOp.getContext());
    // The reduction's accumulator becomes the contraction's acc operand.
    rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
        reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), reduceOp.getAcc(),
        rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
        rewriter.getArrayAttr(llvm::map_to_vector(
            iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
              return IteratorTypeAttr::get(rewriter.getContext(), t);
            })));
    return success();
  }
};
116
/// Merge LHS/RHS (A/B) TransposeOp into ContractionOp user.
/// Ex:
/// ```
///   %0 = vector.transpose %arg0, [2, 0, 1]
///     : vector<32x16x8xf32> to vector<8x32x16xf32>
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %0, %arg1, %cst_f0
///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
///   %1 = vector.contract {indexing_maps = [
///         affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///         affine_map<(d0, d1, d2) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel", "reduction"],
///    kind = add} %arg0, %arg1, %cst_f0
///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct CombineContractABTranspose final
    : public OpRewritePattern<vector::ContractionOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<AffineMap> maps =
        llvm::to_vector<4>(contractOp.getIndexingMapsArray());
    Value lhs = contractOp.getLhs();
    Value rhs = contractOp.getRhs();
    // Walk LHS (maps[0]) and RHS (maps[1]) in lockstep with their maps.
    size_t index = 0;
    bool changed = false;
    for (Value *operand : {&lhs, &rhs}) {
      AffineMap &map = maps[index++];
      auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
      if (!transposeOp)
        continue;
      // Fold the transpose into the operand's indexing map: indexing the
      // pre-transpose vector requires composing with the inverse permutation.
      AffineMap permutationMap = AffineMap::getPermutationMap(
          transposeOp.getPermutation(), contractOp.getContext());
      map = inversePermutation(permutationMap).compose(map);
      *operand = transposeOp.getVector();
      changed = true;
    }
    if (!changed)
      return failure();
    // Rebuild the contraction with the adjusted maps and untransposed
    // operands; iterator types are unchanged.
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        contractOp, lhs, rhs, contractOp.getAcc(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
    return success();
  }
};
171
/// Merges accumulator and result transposes into contract.
///
/// For example:
/// ```mlir
/// %accT = vector.transpose %acc, [0, 2, 1]
///   : vector<2x8x4xf32> to vector<2x4x8xf32>
/// %contract = vector.contract {
///   indexing_maps = [
///     affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
///     affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
///     affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
///   ],
///   iterator_types = ["parallel", "parallel", "parallel", "reduction"],
///   kind = #vector.kind<add>
/// } %lhs, %rhs, %accT
///   : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x4x8xf32>
/// %0 = vector.transpose %contract, [0, 2, 1]
///   : vector<2x4x8xf32> to vector<2x8x4>
/// ```
/// Becomes:
/// ```mlir
/// %0 = vector.contract {
///   indexing_maps = [
///     affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
///     affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
///     affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
///   ],
///   iterator_types = ["parallel", "parallel", "parallel", "reduction"],
///   kind = #vector.kind<add>
/// } %lhs, %rhs, %acc
///   : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x8x4xf32>
/// ```
struct CombineContractResultTranspose final
    : public OpRewritePattern<vector::TransposeOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::TransposeOp resTOp,
                                PatternRewriter &rewriter) const override {
    // The transposed value must come from a contract that has no other users
    // (otherwise the contract must be kept in its original form anyway).
    auto contractOp = resTOp.getVector().getDefiningOp<vector::ContractionOp>();
    if (!contractOp || !contractOp->hasOneUse())
      return failure();

    // The accumulator must itself be a transpose we can fold away.
    auto accTOp = contractOp.getAcc().getDefiningOp<vector::TransposeOp>();
    if (!accTOp)
      return failure();

    MLIRContext *context = contractOp.getContext();
    auto maps = llvm::to_vector<3>(contractOp.getIndexingMapsArray());
    AffineMap contractMap = maps.back();

    // Accumulator transpose performs f(A) -> B. Contract performs g(C) -> B.
    // To index into A in contract, we need revert(f)(g(C)) -> A.
    auto accTMap =
        AffineMap::getPermutationMap(accTOp.getPermutation(), context);

    // Contract performs g(C) -> D. Result transpose performs h(D) -> E.
    // To index into E in contract, we need h(g(C)) -> E.
    auto resTMap =
        AffineMap::getPermutationMap(resTOp.getPermutation(), context);
    auto combinedResMap = resTMap.compose(contractMap);

    // The accumulator and result share the same indexing map. So they should
    // be the same to be able to merge. This means combinedResMap is the same
    // as inversePermutation(accTMap).compose(contractMap), which means the
    // result transpose must be the inverse of the accumulator transpose.
    if (inversePermutation(accTMap) != resTMap)
      return failure();
    maps.back() = combinedResMap;

    // Rebuild the contract: consume the pre-transpose accumulator and
    // produce the post-transpose result directly.
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        resTOp, contractOp.getLhs(), contractOp.getRhs(), accTOp.getVector(),
        rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
    return success();
  }
};
246
247/// Merge BroadcastOp into ContractionOp user.
248/// Ex:
249/// ```
250/// %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
251/// %1 = vector.contract {indexing_maps = [
252/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
253/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
254/// affine_map<(d0, d1, d2) -> (d0, d1)>],
255/// iterator_types = ["parallel", "parallel", "reduction"],
256/// kind = add} %0, %arg1, %cst_f0
257/// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
258/// ```
259/// Gets converted to:
260/// ```
261/// %1 = vector.contract {indexing_maps = [
262/// affine_map<(d0, d1, d2) -> (d1, d2)>,
263/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
264/// affine_map<(d0, d1, d2) -> (d0, d1)>],
265/// iterator_types = ["parallel", "parallel", "reduction"],
266/// kind = add} %arg0, %arg1, %cst_f0
267/// : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
268/// ```
269///
270/// For masked vector.contract, the mask requires updating when a dimension is
271/// dropped. In such cases, the dropped dimensions must correspond to the mask's
272/// leading unit dimensions. Supporting more generic cases (e.g. non-unit dims)
273/// is not supported.
274FailureOr<Value> combineContractAndBroadcast(vector::ContractionOp contractOp,
275 MaskingOpInterface maskingOp,
276 PatternRewriter &rewriter) {
278 llvm::to_vector<4>(contractOp.getIndexingMapsArray());
279 Value lhs = contractOp.getLhs();
280 Value rhs = contractOp.getRhs();
281 size_t index = 0;
282 bool changed = false;
283 for (Value *operand : {&lhs, &rhs}) {
284 AffineMap &map = maps[index++];
285 auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
286 if (!broadcast)
287 continue;
288 // contractionOp can only take vector as operands.
289 auto srcType = dyn_cast<VectorType>(broadcast.getSourceType());
290 if (!srcType ||
291 srcType.getRank() == broadcast.getResultVectorType().getRank())
292 continue;
293 int64_t rankDiff =
294 broadcast.getResultVectorType().getRank() - srcType.getRank();
295 bool innerDimBroadcast = false;
296 SmallVector<AffineExpr> originalDims;
297 for (const auto &dim : llvm::enumerate(srcType.getShape())) {
298 if (dim.value() !=
299 broadcast.getResultVectorType().getDimSize(rankDiff + dim.index())) {
300 innerDimBroadcast = true;
301 break;
302 }
303 originalDims.push_back(rewriter.getAffineDimExpr(dim.index() + rankDiff));
304 }
305 // Contract doesn't support inner dimension broadcast. Once this is
306 // relaxed we can remove this case.
307 if (innerDimBroadcast)
308 continue;
309
310 // It would be incorrect to fold a broadcast onto a reduction dimension
311 // of non-unit size.
312 bool nonUnitDimReductionBroadcast = false;
313 for (int64_t i = 0; i < rankDiff; ++i) {
314 if (broadcast.getResultVectorType().getDimSize(i) != 1 &&
315 isReductionIterator(contractOp.getIteratorTypes()
316 .getValue()[map.getDimPosition(i)])) {
317 nonUnitDimReductionBroadcast = true;
318 break;
319 }
320 }
321 if (nonUnitDimReductionBroadcast)
322 continue;
323
324 AffineMap broadcastMap =
325 AffineMap::get(broadcast.getResultVectorType().getRank(), 0,
326 originalDims, contractOp.getContext());
327 map = broadcastMap.compose(map);
328 *operand = broadcast.getSource();
329 changed = true;
330 }
331
332 if (!changed)
333 return failure();
334
335 // Determine which dims are usused, now that the maps have been composed
336 // with the broadcast maps.
337 llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
338 // Compress unused dims.
339 for (auto &m : maps)
340 m = compressDims(m, unusedDimsBitVector);
341 // Compute the combined iterators.
342 SmallVector<Attribute> iterators;
343 for (unsigned i = 0, e = unusedDimsBitVector.size(); i < e; ++i) {
344 if (!unusedDimsBitVector.test(i))
345 iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
346 }
347
348 // Check whether any of the unused dims is non-unit, e.g.:
349 // * vector.broadcast %arg0 : vector<8x4xi32> to vector<2x8x4xi32>
350 // This is only required when collapsing a mask. If there is no mask, skip.
351 VectorType oldMaskType;
352 bool isAnyUnusedDimNonUnit = false;
353 if (maskingOp) {
354 oldMaskType = cast<VectorType>(maskingOp.getMask().getType());
355 for (unsigned i = 0, e = unusedDimsBitVector.size(); i < e; ++i) {
356 if (unusedDimsBitVector.test(i) && oldMaskType.getShape()[i] != 1) {
357 isAnyUnusedDimNonUnit = true;
358 break;
359 }
360 }
361 }
362
363 // Check that compressing unused dims isn't removing all reduction dimension
364 // pairs. For example, if the vector.contract had only one reduction
365 // iterator and that was a unit-dimension created by a broadcast,
366 // then we should bail here, otherwise we would create a contract without
367 // a reduction dimension pair.
368 bool hasReductionIteratorApplyingOnBothSides = false;
369 for (unsigned i = 0; i < iterators.size(); ++i) {
370 if (!isReductionIterator(iterators[i]))
371 continue;
372 if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) {
373 hasReductionIteratorApplyingOnBothSides = true;
374 break;
375 }
376 }
377 if (!hasReductionIteratorApplyingOnBothSides)
378 return failure();
379
380 // If the compressed maps have a dimension that is not used by either LHS or
381 // RHS then the ContractionOp verifier would fail.
382 if (getUnusedDimsBitVector({maps[0], maps[1]}).any())
383 return failure();
384
385 Operation *newOp = vector::ContractionOp::create(
386 rewriter, contractOp.getLoc(), lhs, rhs, contractOp.getAcc(),
387 rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
388
389 // Handle the mask.
390 if (maskingOp) {
391 if (isAnyUnusedDimNonUnit)
392 return rewriter.notifyMatchFailure(contractOp,
393 "Cannont drop non-unit mask dim.");
394 assert(unusedDimsBitVector.size() ==
395 static_cast<size_t>(oldMaskType.getRank()) &&
396 "The mask rank is incorrect!");
397
398 // If a dimension has been dropped, update the mask accordingly. Otherwise,
399 // keep it as is.
400 Value mask = maskingOp.getMask();
401 if (unusedDimsBitVector.count() != 0) {
402 // At this point, two assumptions are made:
403 // * The unused dimensions are the leading mask dimensions
404 // (vector.contract does not support inner dim broadcasting).
405 // * The unused dimensions are all unit.
406 // These conditions are effectively verified in the blocks preceeding this
407 // one.
408 auto newShape =
409 oldMaskType.getShape().drop_front(unusedDimsBitVector.count());
410 auto newShapeScalableDims =
411 oldMaskType.getScalableDims().drop_front(unusedDimsBitVector.count());
412 VectorType maskOpType =
413 VectorType::get(newShape, rewriter.getI1Type(), newShapeScalableDims);
414 mask = vector::ShapeCastOp::create(rewriter, contractOp.getLoc(),
415 maskOpType, maskingOp.getMask())
416 .getResult();
417 }
418
419 newOp = mlir::vector::maskOperation(rewriter, newOp, mask);
420 }
421 return newOp->getResult(0);
422}
423
424struct CombineContractBroadcastMask
425 : public MaskableOpRewritePattern<vector::ContractionOp> {
426 using MaskableOpRewritePattern::MaskableOpRewritePattern;
427 FailureOr<Value>
428
429 matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
430 MaskingOpInterface maskingOp,
431 PatternRewriter &rewriter) const override {
432 return combineContractAndBroadcast(contractOp, maskingOp, rewriter);
433 }
434};
435
/// Reorders cast(broadcast) to broadcast(cast). This makes broadcast ops and
/// contraction ops closer, which kicks in CombineContractBroadcast pattern when
/// casting ops are around these operations.
/// Ex:
/// ```
///   %0 = vector.broadcast %arg0 : vector<32x16xi8> to vector<8x32x16xi8>
///   %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
/// ```
/// Gets converted to:
/// ```
///   %0 = arith.extsi %0 : vector<32x16xi8> to vector<32x16xi32>
///   %1 = vector.broadcast %arg0 : vector<32x16xi32> to vector<8x32x16xi32>
/// ```
struct ReorderCastOpsOnBroadcast
    : public OpInterfaceRewritePattern<CastOpInterface> {
  using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(CastOpInterface op,
                                PatternRewriter &rewriter) const override {
    // Only single-operand casts producing a vector are handled.
    if (op->getNumOperands() != 1)
      return failure();
    if (!isa<VectorType>(op->getResult(0).getType()))
      return failure();
    auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
    if (!bcastOp)
      return failure();

    // Result type of the cast applied to the pre-broadcast value: the cast's
    // element type with the broadcast source's shape (or a plain scalar when
    // the broadcast source is not a vector).
    Type castResTy = getElementTypeOrSelf(op->getResult(0));
    if (auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType()))
      castResTy = vecTy.clone(castResTy);
    // Recreate the same cast op (by name, keeping its attributes) on the
    // smaller pre-broadcast value, then broadcast the casted result.
    auto *castOp =
        rewriter.create(op->getLoc(), op->getName().getIdentifier(),
                        bcastOp.getSource(), castResTy, op->getAttrs());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        op, op->getResult(0).getType(), castOp->getResult(0));
    return success();
  }
};
474
475/// Reorders elementwise(transpose) to transpose(elementwise). This makes
476/// transpose ops and contraction ops closer, which kicks in
477/// CombineContractABTranspose pattern when elementwise ops are between these
478/// operations. Ex:
479/// ```
480/// %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
481/// %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
482/// %r = arith.addf %at, %bt : vector<2x4xf32>
483/// ```
484/// Gets converted to:
485/// ```
486/// %0 = arith.addf %a, %b : vector<4x2xf32>
487/// %r = vector.transpose %0, [1, 0] : vector<2x4xf32>
488/// ```
489struct ReorderElementwiseOpsOnTranspose final
490 : public OpTraitRewritePattern<OpTrait::Elementwise> {
492 LogicalResult matchAndRewrite(Operation *op,
493 PatternRewriter &rewriter) const override {
494 if (op->getNumResults() != 1 || op->getNumRegions() != 0)
495 return failure();
496
497 // Make sure all operands are transpose/constant ops and collect their
498 // transposition maps.
499 SmallVector<ArrayRef<int64_t>> transposeMaps;
500 transposeMaps.reserve(op->getNumOperands());
501 // Record the initial type before transposition. We'll use its shape later.
502 // Any type will do here as we will check all transpose maps are the same.
503 VectorType srcType;
504 for (Value operand : op->getOperands()) {
505 auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
506 if (transposeOp) {
507 transposeMaps.push_back(transposeOp.getPermutation());
508 srcType = transposeOp.getSourceVectorType();
509 } else if (!matchPattern(operand, m_Constant())) {
510 return failure();
511 }
512 }
513 if (transposeMaps.empty())
514 return failure();
515 // This is an elementwise op, so all transposed operands should have the
516 // same type. We need to additionally check that all transposes uses the
517 // same map.
518 if (!llvm::all_equal(transposeMaps))
519 return rewriter.notifyMatchFailure(op, "different transpose map");
520
521 SmallVector<Value> srcValues;
522 srcValues.reserve(op->getNumOperands());
523
524 // If there are constant operands, we need to insert inverse transposes for
525 // them. Calculate the inverse order first.
526 auto order = transposeMaps.front();
527 SmallVector<int64_t> invOrder(order.size());
528 for (int i = 0, e = order.size(); i < e; ++i)
529 invOrder[order[i]] = i;
530
531 for (Value operand : op->getOperands()) {
532 auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
533 if (transposeOp) {
534 srcValues.push_back(transposeOp.getVector());
535 } else {
536 // This is a constant. Create a reverse transpose op for it.
537 auto vectorType =
538 srcType.clone(cast<VectorType>(operand.getType()).getElementType());
539 srcValues.push_back(vector::TransposeOp::create(
540 rewriter, operand.getLoc(), vectorType, operand, invOrder));
541 }
542 }
543
544 auto vectorType = srcType.clone(
545 cast<VectorType>(op->getResultTypes()[0]).getElementType());
546 Operation *elementwiseOp =
547 rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
548 vectorType, op->getAttrs());
549 rewriter.replaceOpWithNewOp<vector::TransposeOp>(
550 op, op->getResultTypes()[0], elementwiseOp->getResult(0),
551 transposeMaps.front());
552 return success();
553 }
554};
555
556// Returns the values in `arrayAttr` as an integer vector.
557static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) {
558 return llvm::map_to_vector<4>(arrayAttr.getAsRange<IntegerAttr>(),
559 [](IntegerAttr attr) { return attr.getInt(); });
560}
561
562// Shuffles vector.bitcast op after vector.extract op.
563//
564// This transforms IR like:
565// %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
566// %1 = vector.extract %0[3] : f16 from vector<8xf16>
567// Into:
568// %0 = vector.extract %src[1] : f32 from vector<4xf32>
569// %1 = vector.bitcast %0: vector<1xf32> to vector<2xf16>
570// %2 = vector.extract %1[1] : f16 from vector<2xf16>
571struct BubbleDownVectorBitCastForExtract
572 : public OpRewritePattern<vector::ExtractOp> {
573 using Base::Base;
574
575 LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
576 PatternRewriter &rewriter) const override {
577 // Only support extracting scalars for now.
578 if (extractOp.getSourceVectorType().getRank() != 1)
579 return failure();
580
581 auto castOp = extractOp.getSource().getDefiningOp<vector::BitCastOp>();
582 if (!castOp)
583 return failure();
584
585 VectorType castSrcType = castOp.getSourceVectorType();
586 VectorType castDstType = castOp.getResultVectorType();
587 assert(castSrcType.getRank() == castDstType.getRank());
588
589 // Fail to match if we only have one element in the cast op source.
590 // This is to avoid infinite loop given that this pattern can generate
591 // such cases.
592 if (castSrcType.getNumElements() == 1)
593 return failure();
594
595 // Only support casting to a larger number of elements or now.
596 // E.g., vector<4xf32> -> vector<8xf16>.
597 if (castSrcType.getNumElements() > castDstType.getNumElements())
598 return failure();
599
600 unsigned expandRatio =
601 castDstType.getNumElements() / castSrcType.getNumElements();
602
603 // Get the first element of the mixed position as integer.
604 auto mixedPos = extractOp.getMixedPosition();
605 if (!mixedPos.empty() && !isa<Attribute>(mixedPos[0]))
606 return failure();
607 uint64_t index = cast<IntegerAttr>(cast<Attribute>(mixedPos[0])).getInt();
608
609 // Get the single scalar (as a vector) in the source value that packs the
610 // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>
611 Location loc = extractOp.getLoc();
612 Value packedValue = vector::ExtractOp::create(
613 rewriter, loc, castOp.getSource(), index / expandRatio);
614 Type packedVecType = VectorType::get(/*shape=*/{1}, packedValue.getType());
615 Value zero = arith::ConstantOp::create(rewriter, loc, packedVecType,
616 rewriter.getZeroAttr(packedVecType));
617 packedValue = vector::InsertOp::create(rewriter, loc, packedValue, zero,
618 /*position=*/0);
619
620 // Cast it to a vector with the desired scalar's type.
621 // E.g. f32 -> vector<2xf16>
622 VectorType packedType =
623 VectorType::get({expandRatio}, castDstType.getElementType());
624 Value castedValue =
625 vector::BitCastOp::create(rewriter, loc, packedType, packedValue);
626
627 // Finally extract the desired scalar.
628 rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, castedValue,
629 index % expandRatio);
630 return success();
631 }
632};
633
// Shuffles vector.bitcast op after vector.extract_strided_slice op.
//
// This transforms IR like:
//    %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
//     %0 = vector.extract_strided_slice %cast {
//            offsets = [4], sizes = [4], strides = [1]
//          } : vector<8xf16> to vector<4xf16>
// Into:
//   %0 = vector.extract_strided_slice %src {
//          offsets = [2], sizes = [2], strides = [1]
//        } : vector<4xf32> to vector<2xf32>
//   %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16>
struct BubbleDownBitCastForStridedSliceExtract
    : public OpRewritePattern<vector::ExtractStridedSliceOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = extractOp.getSource().getDefiningOp<vector::BitCastOp>();
    if (!castOp)
      return failure();

    VectorType castSrcType = castOp.getSourceVectorType();
    VectorType castDstType = castOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to more elements for now; other cases to be implemented.
    if (castSrcLastDim > castDstLastDim)
      return failure();

    // Only accept all one strides for now.
    if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOne(); }))
      return failure();

    unsigned rank = extractOp.getSourceVectorType().getRank();
    assert(castDstLastDim % castSrcLastDim == 0);
    // Each pre-cast element expands to this many post-cast elements.
    int64_t expandRatio = castDstLastDim / castSrcLastDim;

    // If we have a less number of offsets than the rank, then implicitly we
    // are selecting the full range for the last bitcasted dimension; other
    // dimensions aren't affected. Otherwise, we need to scale down the last
    // dimension's offset given we are extracting from less elements now.
    ArrayAttr newOffsets = extractOp.getOffsets();
    if (newOffsets.size() == rank) {
      SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
      // The offset must fall on a pre-cast element boundary.
      if (offsets.back() % expandRatio != 0)
        return failure();
      offsets.back() = offsets.back() / expandRatio;
      newOffsets = rewriter.getI64ArrayAttr(offsets);
    }

    // Similarly for sizes.
    ArrayAttr newSizes = extractOp.getSizes();
    if (newSizes.size() == rank) {
      SmallVector<int64_t> sizes = getIntValueVector(newSizes);
      // The slice size must cover whole pre-cast elements.
      if (sizes.back() % expandRatio != 0)
        return failure();
      sizes.back() = sizes.back() / expandRatio;
      newSizes = rewriter.getI64ArrayAttr(sizes);
    }

    // Result of the new (pre-cast) extract: same shape scaled down in the
    // last dim, with the pre-cast element type.
    SmallVector<int64_t> dims =
        llvm::to_vector<4>(cast<VectorType>(extractOp.getType()).getShape());
    dims.back() = dims.back() / expandRatio;
    VectorType newExtractType =
        VectorType::get(dims, castSrcType.getElementType());

    auto newExtractOp = vector::ExtractStridedSliceOp::create(
        rewriter, extractOp.getLoc(), newExtractType, castOp.getSource(),
        newOffsets, newSizes, extractOp.getStrides());

    // Bitcast the narrowed slice to the originally requested type.
    rewriter.replaceOpWithNewOp<vector::BitCastOp>(
        extractOp, extractOp.getType(), newExtractOp);

    return success();
  }
};
714
// Shuffles vector.bitcast op before vector.insert op.
//
// This transforms IR like:
//   %0 = vector.insert %val, %dst[4] : vector<32xi4> into vector<8x32xi4>
//   %1 = vector.bitcast %0 : vector<8x32xi4> to vector<8x16xi8>
// Into:
//   %0 = vector.bitcast %val : vector<32xi4> to vector<16xi8>
//   %1 = vector.bitcast %dst : vector<8x32xi4> to vector<8x16xi8>
//   %2 = vector.insert %0, %1 [4] : vector<16xi8> into vector<8x16xi8>
//
struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();

    // 0-D and scalable vectors are not supported yet.
    if (castSrcType.getRank() == 0 || castSrcType.isScalable() ||
        castDstType.isScalable())
      return failure();

    // The cast only changes the innermost dimension; compute how many
    // elements it packs (shrink) or unpacks (grow) per element.
    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    bool isNumElemsShrink = castSrcLastDim >= castDstLastDim;
    int64_t ratio;
    if (isNumElemsShrink) {
      assert(castSrcLastDim % castDstLastDim == 0);
      ratio = castSrcLastDim / castDstLastDim;
    } else {
      assert(castDstLastDim % castSrcLastDim == 0);
      ratio = castDstLastDim / castSrcLastDim;
    }

    auto insertOp = bitcastOp.getSource().getDefiningOp<vector::InsertOp>();
    if (!insertOp)
      return failure();

    // Only vector sources are supported for now.
    auto insertSrcType = dyn_cast<VectorType>(insertOp.getValueToStoreType());
    if (!insertSrcType)
      return failure();

    // Bitcast the source.
    SmallVector<int64_t> srcDims(insertSrcType.getShape());
    srcDims.back() =
        isNumElemsShrink ? srcDims.back() / ratio : srcDims.back() * ratio;
    VectorType newCastSrcType =
        VectorType::get(srcDims, castDstType.getElementType());
    auto newCastSrcOp =
        vector::BitCastOp::create(rewriter, bitcastOp.getLoc(), newCastSrcType,
                                  insertOp.getValueToStore());

    SmallVector<int64_t> dstDims(insertOp.getDestVectorType().getShape());
    dstDims.back() =
        isNumElemsShrink ? dstDims.back() / ratio : dstDims.back() * ratio;
    VectorType newCastDstType =
        VectorType::get(dstDims, castDstType.getElementType());

    // Bitcast the destination.
    auto newCastDstOp = vector::BitCastOp::create(
        rewriter, bitcastOp.getLoc(), newCastDstType, insertOp.getDest());

    // Generate new insert; the insert position is unaffected because only
    // the innermost dimension's element count changes.
    rewriter.replaceOpWithNewOp<vector::InsertOp>(
        bitcastOp, newCastSrcOp, newCastDstOp, insertOp.getMixedPosition());
    return success();
  }
};
785
786// Shuffles vector.bitcast op before vector.insert_strided_slice op.
787//
788// This transforms IR like:
789// %0 = vector.insert_strided_slice %src, %dst {
790// offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
791// %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
792// Into:
793// %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32>
794// %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32>
795// %2 = vector.insert_strided_slice %src, %dst {
796// offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
struct BubbleUpBitCastForStridedSliceInsert
    : public OpRewritePattern<vector::BitCastOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {
    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());
    // Skip 0-D vectors, which cannot come from InsertStridedSliceOp.
    if (castSrcType.getRank() == 0)
      return failure();

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to fewer elements for now; other cases to be implemented.
    if (castSrcLastDim < castDstLastDim)
      return failure();

    // Element count shrinks by an integral factor; e.g. f16 -> f32 halves the
    // innermost dimension (shrinkRatio == 2).
    assert(castSrcLastDim % castDstLastDim == 0);
    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;

    auto insertOp =
        bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>();
    if (!insertOp)
      return failure();

    // Only accept all one strides for now.
    if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
                     [](const APInt &val) { return !val.isOne(); }))
      return failure();

    unsigned rank = insertOp.getSourceVectorType().getRank();
    // Require insert op to have the same rank for the source and destination
    // vector; other cases to be implemented.
    if (rank != insertOp.getDestVectorType().getRank())
      return failure();

    // Requires that shape of insert op src is castable to dstType.
    unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth();
    unsigned destinationWidth =
        castDstType.getElementType().getIntOrFloatBitWidth();
    unsigned numElements = destinationWidth / sourceWidth;
    if (insertOp.getSourceVectorType().getNumElements() % numElements != 0)
      return failure();

    // Rescale the innermost insertion offset into the wider element type.
    // The offset must land on an element boundary of the destination type.
    ArrayAttr newOffsets = insertOp.getOffsets();
    assert(newOffsets.size() == rank);
    SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
    if (offsets.back() % shrinkRatio != 0)
      return failure();
    offsets.back() = offsets.back() / shrinkRatio;
    newOffsets = rewriter.getI64ArrayAttr(offsets);

    // Bitcast the inserted slice to the wider element type (innermost dim
    // shrinks by shrinkRatio).
    SmallVector<int64_t> srcDims =
        llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
    srcDims.back() = srcDims.back() / shrinkRatio;
    VectorType newCastSrcType =
        VectorType::get(srcDims, castDstType.getElementType());

    auto newCastSrcOp =
        vector::BitCastOp::create(rewriter, bitcastOp.getLoc(), newCastSrcType,
                                  insertOp.getValueToStore());

    // Likewise bitcast the insertion destination.
    SmallVector<int64_t> dstDims =
        llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
    dstDims.back() = dstDims.back() / shrinkRatio;
    VectorType newCastDstType =
        VectorType::get(dstDims, castDstType.getElementType());

    auto newCastDstOp = vector::BitCastOp::create(
        rewriter, bitcastOp.getLoc(), newCastDstType, insertOp.getDest());

    // Re-create the insert on the bitcast operands; this replaces the original
    // bitcast, and the original insert becomes dead if unused elsewhere.
    rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
        bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
        insertOp.getStrides());

    return success();
  }
};
877
878// Breaks down vector.bitcast op
879//
880// This transforms IR like:
881// %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
882// Into:
883// %cst = vector.broadcast %c0_f32 : f32 to vector<4xf32>
884// %1 = vector.extract_strided_slice %0 {
885// offsets = [0], sizes = [4], strides = [1]
886// } : vector<8xf16> to vector<4xf16>
887// %2 = vector.bitcast %1 : vector<4xf16> to vector<2xf32>
888// %4 = vector.insert_strided_slice %2, %cst {
889// offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
890// %5 = vector.extract_strided_slice %0 {
891// offsets = [4], sizes = [4], strides = [1]
892// } : vector<8xf16> to vector<4xf16>
893// %6 = vector.bitcast %5 : vector<4xf16> to vector<2xf32>
894// %7 = vector.insert_strided_slice %6, %cst {
895// offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
  using Base::Base;

public:
  /// `controlFn` (optional) lets clients restrict which bitcast ops are
  /// broken down; when null, every matching op is rewritten.
  BreakDownVectorBitCast(MLIRContext *context,
                         std::function<bool(vector::BitCastOp)> controlFn,
                         PatternBenefit benefit)
      : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}

  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
                                PatternRewriter &rewriter) const override {

    if (controlFn && !controlFn(bitcastOp))
      return failure();

    VectorType castSrcType = bitcastOp.getSourceVectorType();
    VectorType castDstType = bitcastOp.getResultVectorType();
    assert(castSrcType.getRank() == castDstType.getRank());

    // This transformation builds on top of
    // vector.{extract|insert}_strided_slice, which do not support
    // extracting/inserting "scalable sub-vectors". Bail out.
    if (castSrcType.isScalable())
      return rewriter.notifyMatchFailure(bitcastOp,
                                         "Scalable vectors are not supported");

    // Only support rank 1 case for now.
    if (castSrcType.getRank() != 1)
      return failure();

    int64_t castSrcLastDim = castSrcType.getShape().back();
    int64_t castDstLastDim = castDstType.getShape().back();
    // Require casting to fewer elements for now; other cases to be implemented.
    if (castSrcLastDim < castDstLastDim)
      return failure();

    assert(castSrcLastDim % castDstLastDim == 0);
    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
    // Nothing to do if it is already bitcasting to a single element.
    if (castSrcLastDim == shrinkRatio)
      return failure();

    Location loc = bitcastOp.getLoc();
    Type elemType = castDstType.getElementType();
    assert(elemType.isSignlessIntOrIndexOrFloat());

    // Start from a zero-filled destination vector and insert one piece per
    // iteration of the loop below.
    Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
                                           rewriter.getZeroAttr(elemType));
    Value res = BroadcastOp::create(rewriter, loc, castDstType, zero);

    // Each of the `shrinkRatio` iterations extracts `castDstLastDim` source
    // elements and bitcasts them into `castDstLastDim / shrinkRatio`
    // destination elements.
    SmallVector<int64_t> sliceShape = {castDstLastDim};
    SmallVector<int64_t> strides = {1};
    VectorType newCastDstType =
        VectorType::get(SmallVector<int64_t>{castDstLastDim / shrinkRatio},
                        castDstType.getElementType());

    for (int i = 0, e = shrinkRatio; i < e; ++i) {
      Value extracted = ExtractStridedSliceOp::create(
          rewriter, loc, bitcastOp.getSource(),
          ArrayRef<int64_t>{i * castDstLastDim}, sliceShape, strides);
      Value bitcast =
          BitCastOp::create(rewriter, loc, newCastDstType, extracted);
      res = InsertStridedSliceOp::create(
          rewriter, loc, bitcast, res,
          ArrayRef<int64_t>{i * castDstLastDim / shrinkRatio}, strides);
    }
    rewriter.replaceOp(bitcastOp, res);
    return success();
  }

private:
  // Optional filter deciding which ops this pattern applies to.
  std::function<bool(BitCastOp)> controlFn;
};
969
970static bool haveSameShapeAndScaling(Type t, Type u) {
971 auto tVec = dyn_cast<VectorType>(t);
972 auto uVec = dyn_cast<VectorType>(u);
973 if (!tVec) {
974 return !uVec;
975 }
976 if (!uVec) {
977 return false;
978 }
979 return tVec.getShape() == uVec.getShape() &&
980 tVec.getScalableDims() == uVec.getScalableDims();
981}
982
983/// If `type` is shaped, clone it with `newElementType`. Otherwise,
984/// return `newElementType`.
985static Type cloneOrReplace(Type type, Type newElementType) {
986 if (auto shapedType = dyn_cast<ShapedType>(type)) {
987 return shapedType.clone(newElementType);
988 }
989 return newElementType;
990}
991
992/// If `value` is the result of a broadcast operation, return the input
993/// of the broadcast operation.
994static Value getBroadcastLikeSource(Value value) {
995
996 Operation *op = value.getDefiningOp();
997 if (!op)
998 return {};
999
1000 if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
1001 return broadcast.getSource();
1002
1003 return {};
1004}
1005
1006/// Reorders elementwise(broadcast) to broadcast(elementwise). Ex:
1007///
1008/// Example:
1009/// ```
1010/// %a = vector.broadcast %arg1 : index to vector<1x4xindex>
1011/// %b = vector.broadcast %arg2 : index to vector<1x4xindex>
1012/// %r = arith.addi %a, %b : vector<1x4xindex>
1013/// ```
1014/// Gets converted to:
1015/// ```
1016/// %r = arith.addi %arg0, %arg1 : index
1017/// %b = vector.broadcast %r : index to vector<1x4xindex>
1018/// ```
1019struct ReorderElementwiseOpsOnBroadcast final
1020 : public OpTraitRewritePattern<OpTrait::Elementwise> {
1022 LogicalResult matchAndRewrite(Operation *op,
1023 PatternRewriter &rewriter) const override {
1024 if (op->getNumResults() != 1)
1025 return failure();
1026 auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
1027 if (!resultType)
1028 return failure();
1030 return rewriter.notifyMatchFailure(
1031 op, "Op doesn't have ElementwiseMappableTraits");
1032 if (op->getNumOperands() == 0)
1033 return failure();
1034 if (isa<vector::FMAOp>(op)) {
1035 return rewriter.notifyMatchFailure(
1036 op,
1037 "Op only accepts vector types - not supported as broadcast source "
1038 "might be a scalar");
1039 }
1040
1041 Type resultElemType = resultType.getElementType();
1042
1043 // Get the type of the first non-constant operand
1044 Value broadcastSource;
1045 for (Value operand : op->getOperands()) {
1046 Operation *definingOp = operand.getDefiningOp();
1047 if (!definingOp)
1048 return failure();
1049 if (definingOp->hasTrait<OpTrait::ConstantLike>())
1050 continue;
1051 broadcastSource = getBroadcastLikeSource(operand);
1052 break;
1053 }
1054 if (!broadcastSource)
1055 return failure();
1056 Type unbroadcastResultType =
1057 cloneOrReplace(broadcastSource.getType(), resultElemType);
1058
1059 // Make sure that all operands are broadcast from identically-shaped types:
1060 // * scalar (`vector.broadcast`), or
1061 // * vector (`vector.broadcast`).
1062 // Otherwise the re-ordering wouldn't be safe.
1063 if (!llvm::all_of(op->getOperands(), [broadcastSource](Value val) {
1064 if (auto source = getBroadcastLikeSource(val))
1065 return haveSameShapeAndScaling(source.getType(),
1066 broadcastSource.getType());
1067 SplatElementsAttr splatConst;
1068 return matchPattern(val, m_Constant(&splatConst));
1069 })) {
1070 return rewriter.notifyMatchFailure(
1071 op,
1072 "not all operands are constants or broadcasts from the same type");
1073 }
1074
1075 // Collect the source values before broadcasting
1076 SmallVector<Value> srcValues;
1077 srcValues.reserve(op->getNumOperands());
1078 for (Value operand : op->getOperands()) {
1079 SplatElementsAttr splatConst;
1080 if (matchPattern(operand, m_Constant(&splatConst))) {
1081 Attribute newConst;
1082 Type elementType = getElementTypeOrSelf(operand.getType());
1083 Type newType = cloneOrReplace(unbroadcastResultType, elementType);
1084 if (auto newTypeShaped = dyn_cast<ShapedType>(newType)) {
1085 newConst = splatConst.resizeSplat(newTypeShaped);
1086 } else {
1087 newConst = splatConst.getSplatValue<Attribute>();
1088 }
1089 Operation *newConstOp =
1090 operand.getDefiningOp()->getDialect()->materializeConstant(
1091 rewriter, newConst, newType, operand.getLoc());
1092 srcValues.push_back(newConstOp->getResult(0));
1093 } else {
1094 srcValues.push_back(operand.getDefiningOp()->getOperand(0));
1095 }
1096 }
1097
1098 // Create the "elementwise" Op
1099 Operation *elementwiseOp =
1100 rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
1101 unbroadcastResultType, op->getAttrs());
1102
1103 // Replace the original Op with the elementwise Op
1104 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
1105 op, resultType, elementwiseOp->getResults());
1106
1107 return success();
1108 }
1109};
1110
1111/// Pattern to rewrite a ExtractOp(Elementwise) -> Elementwise(ExtractOp).
1112/// This may result in cleaner code when extracting a single value
1113/// from multi-element vector and also to help canonicalize 1-element vectors to
1114/// scalars.
1115///
1116/// Example:
1117/// ```
1118/// %0 = arith.addf %arg0, %arg1 : vector<4xf32>
1119/// %1 = vector.extract %0[1] : f32 from vector<4xf32>
1120/// ```
1121/// Gets converted to:
1122/// ```
1123/// %0 = vector.extract %arg0[1] : f32 from vector<4xf32>
1124/// %1 = vector.extract %arg1[1] : f32 from vector<4xf32>
1125/// %2 = arith.addf %0, %1 : f32
1126/// ```
class ExtractOpFromElementwise final
    : public OpRewritePattern<vector::ExtractOp> {
public:
  using Base::Base;

  LogicalResult matchAndRewrite(vector::ExtractOp op,
                                PatternRewriter &rewriter) const override {
    Operation *eltwise = op.getSource().getDefiningOp();

    // TODO: vector::FMAOp is not an ElementwiseMappable even if it claims to
    // be, as it doesn't support scalars.
    if (!eltwise || !OpTrait::hasElementwiseMappableTraits(eltwise) ||
        isa<vector::FMAOp>(eltwise))
      return rewriter.notifyMatchFailure(op, "not an elementwise op");

    if (eltwise->getNumResults() != 1)
      return rewriter.notifyMatchFailure(op, "expected single result");

    // Single use only, so the vector op goes away instead of being duplicated.
    if (!eltwise->hasOneUse())
      return rewriter.notifyMatchFailure(op, "expected single op use");

    // All operands must share one type so the same extract position applies.
    if (!llvm::all_equal(eltwise->getOperandTypes()))
      return rewriter.notifyMatchFailure(op, "operand types are different");

    // Dynamic position can cause dominance issues, so conservatively fail for
    // now.
    if (!op.getDynamicPosition().empty())
      return rewriter.notifyMatchFailure(
          op, "dynamic position not yet implemented");

    Type dstType = op.getType();

    // Insert the new extracts where the elementwise op was, so its operands
    // dominate them; the guard restores the insertion point on scope exit.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(eltwise);

    // Extract from each operand at the same position, then clone the
    // elementwise op over the extracted scalars/sub-vectors.
    IRMapping mapping;
    Location loc = eltwise->getLoc();
    SmallVector<OpFoldResult> pos = op.getMixedPosition();
    for (Value arg : eltwise->getOperands()) {
      Value newArg = vector::ExtractOp::create(rewriter, loc, arg, pos);
      mapping.map(arg, newArg);
    }

    Operation *newEltwise = rewriter.clone(*eltwise, mapping);
    // The clone inherits the old (vector) result type; narrow it to the
    // extracted type.
    newEltwise->getResult(0).setType(dstType);

    rewriter.replaceOp(op, newEltwise);
    // Safe to erase: the extract was the only user (checked above).
    rewriter.eraseOp(eltwise);
    return success();
  }
};
1178
1179/// Check if the element type is suitable for vector.load/store sinking.
1180/// Element type must be index or byte-aligned integer or floating-point type.
1181static bool isSupportedMemSinkElementType(Type type) {
1182 if (isa<IndexType>(type))
1183 return true;
1184
1185 return type.isIntOrFloat() && type.getIntOrFloatBitWidth() % 8 == 0;
1186}
1187
1188/// Pattern to rewrite `vector.extract(vector.load) -> vector/memref.load.
1189/// Only index and byte-aligned integer and floating-point element types are
1190/// supported for now.
1191///
1192/// Example:
1193/// ```
1194/// vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
1195/// vector.extract %0[1] : f32 from vector<4xf32>
1196/// ```
1197/// Gets converted to:
1198/// ```
1199/// %c1 = arith.constant 1 : index
1200/// %0 = arith.addi %arg1, %c1 overflow<nsw> : index
1201/// %1 = memref.load %arg0[%0] : memref<?xf32>
1202/// ```
class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
public:
  using Base::Base;

  LogicalResult matchAndRewrite(vector::ExtractOp op,
                                PatternRewriter &rewriter) const override {
    auto loadOp = op.getSource().getDefiningOp<vector::LoadOp>();
    if (!loadOp)
      return rewriter.notifyMatchFailure(op, "expected a load op");

    // Checking for single use so we won't duplicate load ops.
    if (!loadOp->hasOneUse())
      return rewriter.notifyMatchFailure(op, "expected single op use");

    VectorType loadVecType = loadOp.getVectorType();
    if (loadVecType.isScalable())
      return rewriter.notifyMatchFailure(op,
                                         "scalable vectors are not supported");

    MemRefType memType = loadOp.getMemRefType();

    // Non-byte-aligned types are tricky and may require special handling,
    // ignore them for now.
    if (!isSupportedMemSinkElementType(memType.getElementType()))
      return rewriter.notifyMatchFailure(op, "unsupported element type");

    // The loaded vector may cover only the trailing memref dims.
    int64_t rankOffset = memType.getRank() - loadVecType.getRank();
    if (rankOffset < 0)
      return rewriter.notifyMatchFailure(op, "unsupported ranks combination");

    // If the extract yields a (sub-)vector, its rank determines how many
    // trailing indices stay untouched; a scalar extract has finalRank == 0.
    auto extractVecType = dyn_cast<VectorType>(op.getResult().getType());
    int64_t finalRank = 0;
    if (extractVecType)
      finalRank = extractVecType.getRank();

    SmallVector<Value> indices = loadOp.getIndices();
    SmallVector<OpFoldResult> extractPos = op.getMixedPosition();

    // There may be memory stores between the load and the extract op, so we
    // need to make sure that the new load op is inserted at the same place as
    // the original load op.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(loadOp);
    Location loc = loadOp.getLoc();
    // Fold each extract position into the corresponding load index
    // (index[i] += pos[i - rankOffset]); zero positions need no add.
    ArithIndexingBuilder idxBuilderf(rewriter, loc);
    for (auto i : llvm::seq<int64_t>(rankOffset, indices.size() - finalRank)) {
      OpFoldResult pos = extractPos[i - rankOffset];
      if (isZeroInteger(pos))
        continue;

      Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos);
      indices[i] = idxBuilderf.add(indices[i], offset);
    }

    // Vector result -> narrower vector.load; scalar result -> memref.load.
    Value base = loadOp.getBase();
    if (extractVecType) {
      rewriter.replaceOpWithNewOp<vector::LoadOp>(op, extractVecType, base,
                                                  indices);
    } else {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(op, base, indices);
    }
    // We checked for single use so we can safely erase the load op.
    rewriter.eraseOp(loadOp);
    return success();
  }
};
1269
1270/// Pattern to rewrite vector.store(vector.broadcast) -> vector/memref.store.
1271///
1272/// Example:
1273/// ```
1274/// %0 = vector.broadcast %arg2 : f32 to vector<1xf32>
1275/// vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
1276/// ```
1277/// Gets converted to:
1278/// ```
1279/// memref.store %arg2, %arg0[%arg1] : memref<?xf32>
1280/// ```
class StoreOpFromBroadcast final : public OpRewritePattern<vector::StoreOp> {
public:
  using Base::Base;

  LogicalResult matchAndRewrite(vector::StoreOp op,
                                PatternRewriter &rewriter) const override {
    VectorType vecType = op.getVectorType();
    if (vecType.isScalable())
      return rewriter.notifyMatchFailure(op,
                                         "scalable vectors are not supported");

    if (isa<VectorType>(op.getMemRefType().getElementType()))
      return rewriter.notifyMatchFailure(
          op, "memrefs of vectors are not supported");

    // Only a 1-element store is equivalent to storing the broadcast source.
    if (vecType.getNumElements() != 1)
      return rewriter.notifyMatchFailure(
          op, "only 1-element vectors are supported");

    Value toStore = op.getValueToStore();
    Value source = getBroadcastLikeSource(toStore);
    if (!source)
      return rewriter.notifyMatchFailure(
          op, "value to store is not from a broadcast");

    // Checking for single use so we can remove broadcast.
    Operation *broadcast = toStore.getDefiningOp();
    if (!broadcast->hasOneUse())
      return rewriter.notifyMatchFailure(op, "expected single op use");

    Value base = op.getBase();
    ValueRange indices = op.getIndices();

    // Vector source -> smaller vector.store; scalar source -> memref.store.
    if (isa<VectorType>(source.getType())) {
      rewriter.replaceOpWithNewOp<vector::StoreOp>(op, source, base, indices);
    } else {
      rewriter.replaceOpWithNewOp<memref::StoreOp>(op, source, base, indices);
    }
    // Safe to erase: the store was the broadcast's only user (checked above).
    rewriter.eraseOp(broadcast);
    return success();
  }
};
1323
1324// Helper that returns a vector comparison that constructs a mask:
1325// mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
1326//
1327// If `dim == 0` then the result will be a 0-D vector.
1328//
1329// NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
1330// much more compact, IR for this operation, but LLVM eventually
1331// generates more elaborate instructions for this intrinsic since it
1332// is very conservative on the boundary conditions.
1333static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
1334 bool force32BitVectorIndices, int64_t dim,
1335 Value b, Value *off = nullptr) {
1336 auto loc = op->getLoc();
1337 // If we can assume all indices fit in 32-bit, we perform the vector
1338 // comparison in 32-bit to get a higher degree of SIMD parallelism.
1339 // Otherwise we perform the vector comparison using 64-bit indices.
1340 Type idxType =
1341 force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
1342 DenseIntElementsAttr indicesAttr;
1343 if (dim == 0 && force32BitVectorIndices) {
1344 indicesAttr = DenseIntElementsAttr::get(
1345 VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0});
1346 } else if (dim == 0) {
1347 indicesAttr = DenseIntElementsAttr::get(
1348 VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0});
1349 } else if (force32BitVectorIndices) {
1350 indicesAttr = rewriter.getI32VectorAttr(
1351 llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
1352 } else {
1353 indicesAttr = rewriter.getI64VectorAttr(
1354 llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
1355 }
1356 Value indices = arith::ConstantOp::create(rewriter, loc, indicesAttr);
1357 // Add in an offset if requested.
1358 if (off) {
1359 Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
1360 Value ov = vector::BroadcastOp::create(rewriter, loc, indices.getType(), o);
1361 indices = arith::AddIOp::create(rewriter, loc, ov, indices);
1362 }
1363 // Construct the vector comparison.
1364 // When using 32-bit indices, cap `b` at INT32_MAX before casting to prevent
1365 // signed overflow for large index values (e.g., 2^51 wrapping to 0 in i32).
1366 // Note: for fixed-size vectors, `dim` is a tighter bound (since any b >= dim
1367 // already implies all-true), but we use INT32_MAX for uniformity with the
1368 // scalable-vector path.
1369 if (force32BitVectorIndices) {
1370 Value maxBound =
1371 arith::ConstantIndexOp::create(rewriter, loc, (1LL << 31) - 1);
1372 b = arith::MinSIOp::create(rewriter, loc, b, maxBound);
1373 }
1374 Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
1375 Value bounds =
1376 vector::BroadcastOp::create(rewriter, loc, indices.getType(), bound);
1377 return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
1378 indices, bounds);
1379}
1380
/// Materializes an explicit mask for a transfer op (read or write) with an
/// out-of-bounds innermost dimension, then marks the op as in-bounds.
template <typename ConcreteOp>
struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
public:
  explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt,
                                   PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<ConcreteOp>(context, benefit),
        force32BitVectorIndices(enableIndexOpt) {}

  LogicalResult matchAndRewrite(ConcreteOp xferOp,
                                PatternRewriter &rewriter) const override {
    // Nothing to mask if every dim is already in-bounds.
    if (!xferOp.hasOutOfBoundsDim())
      return failure();

    // Only 1-D (or smaller) transfers with at least one index are handled.
    if (xferOp.getVectorType().getRank() > 1 || xferOp.getIndices().empty())
      return failure();

    Location loc = xferOp->getLoc();
    VectorType vtp = xferOp.getVectorType();

    // Create the in-bounds mask with all elements between [0 .. dim - offset)
    // set and [dim - offset .. vector_length) unset.
    //
    // TODO: when the leaf transfer rank is k > 1, we need the last `k`
    // dimensions here.
    unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1;
    Value off = xferOp.getIndices()[lastIndex];
    Value dim =
        vector::createOrFoldDimOp(rewriter, loc, xferOp.getBase(), lastIndex);
    Value b = arith::SubIOp::create(rewriter, loc, dim.getType(), dim, off);
    Value mask = vector::CreateMaskOp::create(
        rewriter, loc,
        VectorType::get(vtp.getShape(), rewriter.getI1Type(),
                        vtp.getScalableDims()),
        b);
    if (xferOp.getMask()) {
      // Intersect the in-bounds with the mask specified as an op parameter.
      mask = arith::AndIOp::create(rewriter, loc, mask, xferOp.getMask());
    }

    // Attach the mask and flip the transfer to in-bounds, all in place.
    rewriter.modifyOpInPlace(xferOp, [&]() {
      xferOp.getMaskMutable().assign(mask);
      xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
    });

    return success();
  }

private:
  // NOTE(review): not read in matchAndRewrite above — presumably consumed by
  // the downstream lowering of the vector.create_mask produced here; confirm.
  const bool force32BitVectorIndices;
};
1431
/// Conversion pattern for a `vector.create_mask` (0-D and 1-D only).
class VectorCreateMaskOpConversion
    : public OpRewritePattern<vector::CreateMaskOp> {
public:
  /// `enableIndexOpt` selects 32-bit lane indices in the generated compare
  /// (see buildVectorComparison) instead of 64-bit ones.
  explicit VectorCreateMaskOpConversion(MLIRContext *context,
                                        bool enableIndexOpt,
                                        PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<vector::CreateMaskOp>(context, benefit),
        force32BitVectorIndices(enableIndexOpt) {}

  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
                                PatternRewriter &rewriter) const override {
    auto dstType = op.getType();
    // Fixed-size vectors only; scalable masks need a different lowering.
    if (cast<VectorType>(dstType).isScalable())
      return failure();
    int64_t rank = dstType.getRank();
    if (rank > 1)
      return failure();
    // Lower to `step-vector < bound`, with dim 0 for the 0-D case.
    rewriter.replaceOp(
        op, buildVectorComparison(rewriter, op, force32BitVectorIndices,
                                  rank == 0 ? 0 : dstType.getDimSize(0),
                                  op.getOperand(0)));
    return success();
  }

private:
  const bool force32BitVectorIndices;
};
1460
1461/// Returns true if all the `i1` elements of `constantOp` are set to `value`.
1462static bool allI1ConstantValuesSetTo(arith::ConstantOp constantOp, bool value) {
1463 auto denseAttr = dyn_cast<DenseIntElementsAttr>(constantOp.getValue());
1464 // TODO: Support non-dense constant.
1465 if (!denseAttr)
1466 return false;
1467
1468 assert(denseAttr.getElementType().isInteger(1) && "Unexpected type");
1469 return denseAttr.isSplat() && denseAttr.getSplatValue<bool>() == value;
1470}
1471
1472/// Folds a select operation between an all-true and all-false vector. For now,
1473/// only single element vectors (i.e., vector<1xi1>) are supported. That is:
1474///
1475/// %true = arith.constant dense<true> : vector<1xi1>
1476/// %false = arith.constant dense<false> : vector<1xi1>
1477/// %result = arith.select %cond, %true, %false : i1, vector<1xi1>
1478/// =>
1479/// %result = vector.broadcast %cond : i1 to vector<1xi1>
1480///
1481/// InstCombine seems to handle vectors with multiple elements but not the
1482/// single element ones.
struct FoldI1Select : public OpRewritePattern<arith::SelectOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(arith::SelectOp selectOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = dyn_cast<VectorType>(selectOp.getType());
    if (!vecType || !vecType.getElementType().isInteger(1))
      return failure();

    // Only scalar conditions can be folded.
    Value cond = selectOp.getCondition();
    if (isa<VectorType>(cond.getType()))
      return failure();

    // TODO: Support n-D and scalable vectors.
    if (vecType.getRank() != 1 || vecType.isScalable())
      return failure();

    // TODO: Support vectors with multiple elements.
    if (vecType.getShape()[0] != 1)
      return failure();

    // The true arm must be the all-true constant ...
    auto trueConst = selectOp.getTrueValue().getDefiningOp<arith::ConstantOp>();
    if (!trueConst || !allI1ConstantValuesSetTo(trueConst, true))
      return failure();

    // ... and the false arm the all-false constant, so the select is just a
    // broadcast of the condition.
    auto falseConst =
        selectOp.getFalseValue().getDefiningOp<arith::ConstantOp>();
    if (!falseConst || !allI1ConstantValuesSetTo(falseConst, false))
      return failure();

    // Replace select with its condition broadcasted to single element vector.
    // getNumElements() is 1 here (checked above), so elemType is i1.
    auto elemType = rewriter.getIntegerType(vecType.getNumElements());
    auto bcastType = VectorType::get(/*shape=*/{1}, elemType);
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(selectOp, bcastType, cond);
    return success();
  }
};
1521
1522/// Returns the number of dims can be folded away from transfer ops. It returns
1523/// a failure if it can not determine the number of dims to be folded.
1524///
1525/// Ex 1: returns "2" if `srcType` is memref<512x16x1x1xf32> and
1526/// `vectorType` is vector<16x16x1x1xf32>
1527/// (there two inner most dims can be dropped by memref.subview ops)
1528///
1529/// Ex 2: returns "1" if `srcType` is memref<512x16x1x1xf32> with
1530/// [8192, 16, 8, 1] strides and `vectorType` is vector<16x16x1x1xf32>
1531/// (only the inner most unit dim of `srcType` can be dropped)
1532///
1533/// Ex 3: return "0" if `srcType` is memref<512x16x1x1xf32> and
1534/// `vectorType` is vector<16x16x1x[1]xf32>
1535/// (the most inner dim in `vectorType` is not a unit dim (it's a "scalable
1536/// unit")
1537static FailureOr<size_t>
1538getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
1539 SmallVector<int64_t> srcStrides;
1540 int64_t srcOffset;
1541 if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
1542 return failure();
1543
1544 auto isUnitDim = [](VectorType type, int dim) {
1545 return type.getDimSize(dim) == 1 && !type.getScalableDims()[dim];
1546 };
1547
1548 // According to vector.transfer_read/write semantics, the vector can be a
1549 // slice. Thus, we have to offset the check index with `rankDiff` in
1550 // `srcStrides` and source dim sizes.
1551 size_t result = 0;
1552 int rankDiff = srcType.getRank() - vectorType.getRank();
1553 for (int64_t i = 0, e = vectorType.getRank(); i < e; ++i) {
1554 // Check that the inner dim size is 1 for both memref type and vector slice.
1555 // It can be folded only if they are 1 and the stride is 1.
1556 int dim = vectorType.getRank() - i - 1;
1557 if (srcStrides[dim + rankDiff] != 1 ||
1558 srcType.getDimSize(dim + rankDiff) != 1 || !isUnitDim(vectorType, dim))
1559 break;
1560 result++;
1561 }
1562 return result;
1563}
1564
1565/// Drop inner most contiguous unit dimensions from transfer_read operand.
1567 : public OpRewritePattern<vector::TransferReadOp> {
1568 using Base::Base;
1569
1570 LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
1571 PatternRewriter &rewriter) const override {
1572 // TODO: support 0-d corner case.
1573 if (readOp.getTransferRank() == 0)
1574 return failure();
1575
1576 auto srcType = dyn_cast<MemRefType>(readOp.getBase().getType());
1577 if (!srcType)
1578 return failure();
1579
1580 if (!readOp.getPermutationMap().isMinorIdentity())
1581 return failure();
1582
1583 auto targetType = readOp.getVectorType();
1584 if (targetType.getRank() <= 1)
1585 return failure();
1586
1587 FailureOr<size_t> maybeDimsToDrop =
1588 getTransferFoldableInnerUnitDims(srcType, targetType);
1589 if (failed(maybeDimsToDrop))
1590 return failure();
1591
1592 size_t dimsToDrop = maybeDimsToDrop.value();
1593 if (dimsToDrop == 0)
1594 return failure();
1595
1596 auto inBounds = readOp.getInBoundsValues();
1597 auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
1598 if (llvm::is_contained(droppedInBounds, false))
1599 return failure();
1600
1601 auto resultTargetVecType =
1602 VectorType::get(targetType.getShape().drop_back(dimsToDrop),
1603 targetType.getElementType(),
1604 targetType.getScalableDims().drop_back(dimsToDrop));
1605
1606 auto loc = readOp.getLoc();
1608 memref::getMixedSizes(rewriter, loc, readOp.getBase());
1609 SmallVector<OpFoldResult> offsets(srcType.getRank(),
1610 rewriter.getIndexAttr(0));
1611 SmallVector<OpFoldResult> strides(srcType.getRank(),
1612 rewriter.getIndexAttr(1));
1613 MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
1614 srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
1615 strides);
1616 ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
1617 readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
1618 Value rankedReducedView =
1619 memref::SubViewOp::create(rewriter, loc, resultMemrefType,
1620 readOp.getBase(), offsets, sizes, strides);
1621 auto permMap = getTransferMinorIdentityMap(
1622 cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
1623
1624 // If there is a mask, shape_cast it to drop the same inner unit dims.
1625 Value mask = readOp.getMask();
1626 if (mask) {
1627 auto maskType = cast<VectorType>(mask.getType());
1628 auto reducedMaskType = VectorType::get(
1629 maskType.getShape().drop_back(dimsToDrop), maskType.getElementType(),
1630 maskType.getScalableDims().drop_back(dimsToDrop));
1631 mask = rewriter.createOrFold<vector::ShapeCastOp>(loc, reducedMaskType,
1632 mask);
1633 }
1634
1635 Value result = vector::TransferReadOp::create(
1636 rewriter, loc, resultTargetVecType, rankedReducedView,
1637 readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
1638 readOp.getPadding(), mask, inBoundsAttr);
1639 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
1640 result);
1641 return success();
1642 }
1643};
1644
1645/// Drop inner most contiguous unit dimensions from transfer_write operand.
1646/// E.g.,
1647/// vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
1648/// {in_bounds = [true, true, true, true, true]}
1649/// : vector<1x16x16x1x1xf32>, memref<1x512x16x1x1xf32>
1650///
1651/// will be replaced with
1652///
1653/// %subview = memref.subview %arg0
1654/// [0, 0, 0, 0, 0] [1, 512, 16, 1, 1] [1, 1, 1, 1, 1]
1655/// : memref<1x512x16x1x1xf32> to memref<1x512x16xf32>
1656/// %0 = vector.shape_cast %arg1 : vector<1x16x16x1x1xf32>
1657/// to vector<1x16x16xf32>
1658/// vector.transfer_write %0, %subview[%c0, %arg2, %c0]
1659/// {in_bounds = [true, true, true]}
1660/// : vector<1x16x16xf32>, memref<1x512x16xf32>
1661///
1662/// Note, this pattern will not collapse "scalable unit" dims (i.e. `[1]`).
1664 : public OpRewritePattern<vector::TransferWriteOp> {
1665 using Base::Base;
1666
1667 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
1668 PatternRewriter &rewriter) const override {
1669 // TODO: support 0-d corner case.
1670 if (writeOp.getTransferRank() == 0)
1671 return failure();
1672
1673 auto srcType = dyn_cast<MemRefType>(writeOp.getBase().getType());
1674 if (!srcType)
1675 return failure();
1676
1677 if (!writeOp.getPermutationMap().isMinorIdentity())
1678 return failure();
1679
1680 auto targetType = writeOp.getVectorType();
1681 if (targetType.getRank() <= 1)
1682 return failure();
1683
1684 FailureOr<size_t> maybeDimsToDrop =
1685 getTransferFoldableInnerUnitDims(srcType, targetType);
1686 if (failed(maybeDimsToDrop))
1687 return failure();
1688
1689 size_t dimsToDrop = maybeDimsToDrop.value();
1690 if (dimsToDrop == 0)
1691 return failure();
1692
1693 auto inBounds = writeOp.getInBoundsValues();
1694 auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
1695 if (llvm::is_contained(droppedInBounds, false))
1696 return failure();
1697
1698 auto resultTargetVecType =
1699 VectorType::get(targetType.getShape().drop_back(dimsToDrop),
1700 targetType.getElementType(),
1701 targetType.getScalableDims().drop_back(dimsToDrop));
1702
1703 Location loc = writeOp.getLoc();
1705 memref::getMixedSizes(rewriter, loc, writeOp.getBase());
1706 SmallVector<OpFoldResult> offsets(srcType.getRank(),
1707 rewriter.getIndexAttr(0));
1708 SmallVector<OpFoldResult> strides(srcType.getRank(),
1709 rewriter.getIndexAttr(1));
1710 MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
1711 srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
1712 strides);
1713 ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
1714 writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
1715
1716 Value rankedReducedView =
1717 memref::SubViewOp::create(rewriter, loc, resultMemrefType,
1718 writeOp.getBase(), offsets, sizes, strides);
1719 auto permMap = getTransferMinorIdentityMap(
1720 cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
1721
1722 auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
1723 loc, resultTargetVecType, writeOp.getVector());
1724
1725 // If there is a mask, shape_cast it to drop the same inner unit dims.
1726 Value mask = writeOp.getMask();
1727 if (mask) {
1728 auto maskType = cast<VectorType>(mask.getType());
1729 auto reducedMaskType = VectorType::get(
1730 maskType.getShape().drop_back(dimsToDrop), maskType.getElementType(),
1731 maskType.getScalableDims().drop_back(dimsToDrop));
1732 mask = rewriter.createOrFold<vector::ShapeCastOp>(loc, reducedMaskType,
1733 mask);
1734 }
1735
1736 rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
1737 writeOp, shapeCast, rankedReducedView,
1738 writeOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
1739 mask, inBoundsAttr);
1740 return success();
1741 }
1742};
1743
1744/// Canonicalization of a `vector.contract %a, %b, %c` with row-major matmul
1745/// semantics to a contraction suitable for MMT (matrix matrix multiplication
1746/// with the RHS transposed) lowering.
1748 : OpRewritePattern<vector::ContractionOp> {
1749 using Base::Base;
1750
1752 std::function<LogicalResult(vector::ContractionOp op)>;
1753
1755 FilterConstraintType constraint)
1756 : OpRewritePattern<vector::ContractionOp>(context, benefit),
1757 filter(std::move(constraint)) {}
1758
1759 LogicalResult matchAndRewrite(vector::ContractionOp op,
1760 PatternRewriter &rewriter) const override {
1761 if (failed(filter(op)))
1762 return failure();
1763
1764 Location loc = op.getLoc();
1765 Value lhs = op.getLhs();
1766 Value rhs = op.getRhs();
1767 Value res = op.getAcc();
1768
1769 // Set up the parallel/reduction structure in right form.
1770 using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1771 auto infer = [&](MapList m) {
1772 return AffineMap::inferFromExprList(m, op.getContext());
1773 };
1774 AffineExpr m;
1775 AffineExpr n;
1776 AffineExpr k;
1777 bindDims(rewriter.getContext(), m, n, k);
1778 static constexpr std::array<int64_t, 2> perm = {1, 0};
1779 auto iteratorTypes = op.getIteratorTypes().getValue();
1780 SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
1781 if (iteratorTypes.size() != 3 ||
1782 !vector::isParallelIterator(iteratorTypes[0]) ||
1783 !vector::isParallelIterator(iteratorTypes[1]) ||
1784 !vector::isReductionIterator(iteratorTypes[2]))
1785 return rewriter.notifyMatchFailure(op, "contraction is not a gemm");
1786
1787 // The canonical form is "TNT" = A row-major, B col-major, C row-major.
1788 const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
1789 if (maps == canonicalForm)
1790 return rewriter.notifyMatchFailure(op, "already in the canonical form");
1791
1792 // Create a vector transpose making sure to emit zero/sign-extend at the
1793 // end.
1794 auto createTranspose = [&rewriter, loc](Value mat) -> Value {
1795 if (auto sext = mat.getDefiningOp<arith::ExtSIOp>()) {
1796 Value trans =
1797 vector::TransposeOp::create(rewriter, loc, sext.getIn(), perm);
1798 VectorType newType =
1799 cast<VectorType>(trans.getType())
1800 .clone(cast<VectorType>(mat.getType()).getElementType());
1801 return arith::ExtSIOp::create(rewriter, loc, newType, trans);
1802 }
1803 if (auto zext = mat.getDefiningOp<arith::ExtUIOp>()) {
1804 Value trans =
1805 vector::TransposeOp::create(rewriter, loc, zext.getIn(), perm);
1806 VectorType newType =
1807 VectorType::get(cast<VectorType>(trans.getType()).getShape(),
1808 cast<VectorType>(mat.getType()).getElementType());
1809 return arith::ExtUIOp::create(rewriter, loc, newType, trans);
1810 }
1811 return vector::TransposeOp::create(rewriter, loc, mat, perm);
1812 };
1813
1814 if (maps == infer({{m, k}, {k, n}, {m, n}})) {
1815 rhs = createTranspose(rhs);
1816 } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
1817 lhs = createTranspose(lhs);
1818 } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
1819 rhs = createTranspose(rhs);
1820 lhs = createTranspose(lhs);
1821 } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
1822 std::swap(rhs, lhs);
1823 rhs = createTranspose(rhs);
1824 lhs = createTranspose(lhs);
1825 } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
1826 std::swap(rhs, lhs);
1827 rhs = createTranspose(rhs);
1828 } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
1829 std::swap(lhs, rhs);
1830 lhs = createTranspose(lhs);
1831 } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
1832 std::swap(lhs, rhs);
1833 } else {
1834 return rewriter.notifyMatchFailure(op, "unhandled contraction form");
1835 }
1836 rewriter.replaceOpWithNewOp<vector::ContractionOp>(
1837 op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
1838 op.getIteratorTypes());
1839 return success();
1840 };
1841
1842private:
1843 FilterConstraintType filter;
1844};
1845
1846/// Pattern to fold arithmetic extensions on floating point data types into
1847/// vector contraction operations. linalg.matmul introduces arithmetic
1848/// extensions on its operands. Please mlir snippets below for more details.
1849/// ```mlir
1850/// "linalg.matmul"(%lhs, %rhs, %acc) ({
1851/// ^bb0(%arg1: f16, %arg2: f16, %arg3: f32):
1852/// %lhs_f32 = "arith.extf"(%arg1) : (f16) -> f32
1853/// %rhs_f32 = "arith.extf"(%arg2) : (f16) -> f32
1854/// %mul = "arith.mulf"(%lhs_f32, %rhs_f32) : (f32, f32) -> f32
1855/// %acc = "arith.addf"(%arg3, %mul) : (f32, f32) -> f32
1856/// "linalg.yield"(%acc) : (f32) -> ()
1857/// })
1858/// ```
1859/// This restricts the native usage of mixed precision NVIDIA Ampere Tensor
1860/// Cores, i.e, `mma.sync.*.f32.f16.f16.f32` and `mma.sync.*.f32.bf16.bf16.f32`.
1861/// This pattern folds the arithmetic extensions into the vector contraction and
1862/// enables the usage of native mixed precision Tensor Core instructions.
1863template <typename ExtOp>
1865 : public OpRewritePattern<vector::ContractionOp> {
1866 using Base::Base;
1867
1868 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
1869 PatternRewriter &rewriter) const override {
1870
1871 auto lhsDefOp = contractOp.getLhs().getDefiningOp<ExtOp>();
1872 auto rhsDefOp = contractOp.getRhs().getDefiningOp<ExtOp>();
1873
1874 if (!lhsDefOp || !rhsDefOp) {
1875 return rewriter.notifyMatchFailure(contractOp,
1876 "no defining op on contract operands");
1877 }
1878
1879 rewriter.replaceOpWithNewOp<vector::ContractionOp>(
1880 contractOp, lhsDefOp->getOperand(0), rhsDefOp->getOperand(0),
1881 contractOp.getAcc(), contractOp.getIndexingMapsAttr(),
1882 contractOp.getIteratorTypesAttr());
1883
1884 return success();
1885 }
1886};
1887
1888/// Pattern to fold chained reduction to a series of vector additions and a
1889/// final reduction. This form should require fewer subgroup operations.
1890///
1891/// ```mlir
1892/// %a = vector.reduction <add> %x, %acc
1893/// %b = vector.reduction <add> %y, %a
1894/// ==>
1895/// %a = arith.addf %x, %y
1896/// %b = vector.reduction <add> %a, %acc
1897/// ```
1898struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
1899 using Base::Base;
1900
1901 LogicalResult matchAndRewrite(vector::ReductionOp op,
1902 PatternRewriter &rewriter) const override {
1903 // TODO: Handle other combining kinds.
1904 if (op.getKind() != vector::CombiningKind::ADD)
1905 return failure();
1906
1907 // Accumulator is optional.
1908 Value acc = op.getAcc();
1909 if (!acc)
1910 return failure();
1911
1912 if (!acc.getType().isIntOrFloat())
1913 return failure();
1914
1915 auto parentReduction = acc.getDefiningOp<vector::ReductionOp>();
1916 if (!parentReduction)
1917 return failure();
1918
1919 Location loc = op.getLoc();
1920 Value vAdd;
1921 if (isa<IntegerType>(acc.getType())) {
1922 vAdd = rewriter.createOrFold<arith::AddIOp>(
1923 loc, parentReduction.getVector(), op.getVector());
1924 } else {
1925 vAdd = arith::AddFOp::create(rewriter, loc, parentReduction.getVector(),
1926 op.getVector());
1927 }
1928 rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), vAdd,
1929 parentReduction.getAcc());
1930 return success();
1931 }
1932};
1933
1934// Helper function dropping unit non-scalable dimension from a VectorType
1935// keeping at least 1 dimension to avoid generating 0-D vectors. Scalable unit
1936// dimensions are not dropped. Folding such dimensions would require "shifting"
1937// the scalable flag onto some other fixed-width dim (e.g. vector<[1]x4xf32> ->
1938// vector<[4]xf32>). This could be implemented in the future.
1939static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
1940 auto inVecShape = inVecTy.getShape();
1941 SmallVector<int64_t> newShape;
1942 SmallVector<bool> newScalableDims;
1943 for (auto [dim, isScalable] :
1944 llvm::zip_equal(inVecShape, inVecTy.getScalableDims())) {
1945 if (dim == 1 && !isScalable)
1946 continue;
1947
1948 newShape.push_back(dim);
1949 newScalableDims.push_back(isScalable);
1950 }
1951 // All dims have been dropped, return vector<1xeType>.
1952 if (newShape.empty()) {
1953 newShape.push_back(1);
1954 newScalableDims.push_back(false);
1955 }
1956
1957 return VectorType::get(newShape, inVecTy.getElementType(), newScalableDims);
1958}
1959
1960/// For vectors with at least one unit dim, replaces:
1961/// elementwise(a, b)
1962/// with:
1963/// sc_a = shape_cast(a)
1964/// sc_b = shape_cast(b)
1965/// res = elementwise(sc_a, sc_b)
1966/// return shape_cast(res)
1967/// The newly inserted shape_cast Ops fold (before elementwise Op) and then
1968/// restore (after elementwise Op) the unit dim. Vectors `a` and `b` are
1969/// required to be rank > 1.
1970///
1971/// Ex:
1972/// %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
1973/// %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
1974///
1975/// gets converted to:
1976///
1977/// %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
1978/// %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
1979/// %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
1980/// %cast_new = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32>
1981/// %cast = vector.shape_cast %cast_new : vector<1x[4]xf32> to vector<[4]xf32>
1982///
1983/// Patterns for folding shape_casts should instantly eliminate `%cast_new` and
1984/// `%cast`.
1986 : public OpTraitRewritePattern<OpTrait::Elementwise> {
1988 LogicalResult matchAndRewrite(Operation *op,
1989 PatternRewriter &rewriter) const override {
1990 if (op->getNumResults() != 1 || op->getNumRegions() != 0)
1991 return failure();
1992
1993 auto resultVectorType = dyn_cast<VectorType>(op->getResult(0).getType());
1994 if (!resultVectorType)
1995 return failure();
1996
1997 // Check the operand pre-conditions. For `Elementwise` ops all operands are
1998 // guaranteed to have identical shapes (with some exceptions such as
1999 // `arith.select`) and it suffices to only check one of them.
2000 auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType());
2001 if (!sourceVectorType)
2002 return failure();
2003 if (sourceVectorType.getRank() < 2)
2004 return failure();
2005
2006 SmallVector<Value> newOperands;
2007 auto loc = op->getLoc();
2008 for (auto operand : op->getOperands()) {
2009 auto opVectorType = cast<VectorType>(operand.getType());
2010 auto newVType = dropNonScalableUnitDimFromType(opVectorType);
2011 if (newVType == opVectorType)
2012 return rewriter.notifyMatchFailure(op, "No unit dimension to remove.");
2013
2014 auto opSC = vector::ShapeCastOp::create(rewriter, loc, newVType, operand);
2015 newOperands.push_back(opSC);
2016 }
2017
2018 VectorType newResultVectorType =
2019 dropNonScalableUnitDimFromType(resultVectorType);
2020 // Create an updated elementwise Op without unit dim.
2021 Operation *elementwiseOp =
2022 rewriter.create(loc, op->getName().getIdentifier(), newOperands,
2023 newResultVectorType, op->getAttrs());
2024
2025 // Restore the unit dim by applying vector.shape_cast to the result.
2026 rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType,
2027 elementwiseOp->getResult(0));
2028
2029 return success();
2030 }
2031};
2032
2033/// A pattern to drop unit dims from vector.transpose.
2034///
2035/// Example:
2036///
2037/// BEFORE:
2038/// ```mlir
2039/// %transpose = vector.transpose %vector, [3, 0, 1, 2]
2040/// : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
2041/// ```
2042///
2043/// AFTER:
2044/// ```mlir
2045/// %dropDims = vector.shape_cast %vector
2046/// : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
2047/// %transpose = vector.transpose %0, [1, 0]
2048/// : vector<4x[4]xf32> to vector<[4]x4xf32>
2049/// %restoreDims = vector.shape_cast %transpose
2050/// : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
2051/// ```
2053 : OpRewritePattern<vector::TransposeOp> {
2054 using Base::Base;
2055
2056 LogicalResult matchAndRewrite(vector::TransposeOp op,
2057 PatternRewriter &rewriter) const override {
2058 VectorType sourceType = op.getSourceVectorType();
2059 VectorType sourceTypeWithoutUnitDims =
2061
2062 if (sourceType == sourceTypeWithoutUnitDims)
2063 return failure();
2064
2065 // Construct a map from dimIdx -> number of dims dropped before dimIdx.
2066 auto sourceDims = llvm::to_vector(vector::getDims(sourceType));
2067 SmallVector<int64_t> droppedDimsBefore(sourceType.getRank());
2068 int64_t droppedDims = 0;
2069 for (auto [i, dim] : llvm::enumerate(sourceDims)) {
2070 droppedDimsBefore[i] = droppedDims;
2071 if (dim == std::make_tuple(1, false))
2072 ++droppedDims;
2073 }
2074
2075 // Drop unit dims from transpose permutation.
2076 ArrayRef<int64_t> perm = op.getPermutation();
2077 SmallVector<int64_t> newPerm;
2078 for (int64_t idx : perm) {
2079 if (sourceDims[idx] == std::make_tuple(1, false))
2080 continue;
2081 newPerm.push_back(idx - droppedDimsBefore[idx]);
2082 }
2083
2084 // Fixup for `newPerm`. The `sourceTypeWithoutUnitDims` could be vector<1xT>
2085 // type when the dimensions are unit dimensions. In this case, the newPerm
2086 // should be [0].
2087 if (newPerm.empty()) {
2088 newPerm.push_back(0);
2089 }
2090
2091 Location loc = op.getLoc();
2092 // Drop the unit dims via shape_cast.
2093 auto dropDimsShapeCast = vector::ShapeCastOp::create(
2094 rewriter, loc, sourceTypeWithoutUnitDims, op.getVector());
2095 // Create the new transpose.
2096 auto transposeWithoutUnitDims =
2097 vector::TransposeOp::create(rewriter, loc, dropDimsShapeCast, newPerm);
2098 // Restore the unit dims via shape cast.
2099 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
2100 op, op.getResultVectorType(), transposeWithoutUnitDims);
2101
2102 return success();
2103 }
2104};
2105
2106/// A pattern to drop unit dims from the iter_args of an scf.for.
2107///
2108/// Example:
2109///
2110/// BEFORE:
2111/// ```mlir
2112/// %res = scf.for ... iter_args(%iter = %init) -> vector<[4]x1x1x4xf32> {
2113/// ...
2114/// scf.yield %
2115/// }
2116/// ```
2117///
2118/// AFTER:
2119/// ```mlir
2120/// %drop = vector.shape_cast %init
2121/// : vector<4x1x1x[4]xf32> to vector<4x[4]xf32>
2122/// %new_loop = scf.for ... iter_args(%iter = %drop) -> vector<[4]x4xf32> {
2123/// %new_iter = vector.shape_cast %iter
2124/// : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
2125/// ...
2126/// }
2127/// %res = vector.shape_cast %new_loop
2128/// : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
2129/// ```
2130struct DropUnitDimsFromScfForOp final : OpRewritePattern<scf::ForOp> {
2131 using Base::Base;
2132
2133 LogicalResult matchAndRewrite(scf::ForOp forOp,
2134 PatternRewriter &rewriter) const override {
2135 /// Find the first iter_arg with droppable unit dims. Further applications
2136 /// of this pattern will apply to later arguments.
2137 for (OpOperand &operand : forOp.getInitArgsMutable()) {
2138 auto vectorType = dyn_cast<VectorType>(operand.get().getType());
2139 if (!vectorType)
2140 continue;
2141
2142 VectorType newVectorType = dropNonScalableUnitDimFromType(vectorType);
2143 if (vectorType == newVectorType)
2144 continue;
2145
2146 // Create a new ForOp with that iter operand replaced.
2147 auto castFn = [](OpBuilder &b, Location loc, Type type, Value source) {
2148 return vector::ShapeCastOp::create(b, loc, type, source);
2149 };
2150
2152 castFn(rewriter, forOp.getLoc(), newVectorType, operand.get());
2153 rewriter.replaceOp(forOp,
2154 replaceAndCastForOpIterArg(rewriter, forOp, operand,
2155 replacement, castFn));
2156 return success();
2157 }
2158 return failure();
2159 }
2160};
2161
2162/// Pattern to eliminate redundant zero-constants added to reduction operands.
2163/// It's enough for there to be one initial zero value, so we can eliminate the
2164/// extra ones that feed into `vector.reduction <add>`. These get created by the
2165/// `ChainedReduction` pattern.
2166///
2167/// ```mlir
2168/// %a = arith.addf %x, %zero
2169/// %b = arith.addf %a, %y
2170/// %c = vector.reduction <add> %b, %acc
2171/// ==>
2172/// %b = arith.addf %a, %y
2173/// %c = vector.reduction <add> %b, %acc
2174/// ```
2175struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> {
2176 using Base::Base;
2177
2178 LogicalResult matchAndRewrite(vector::ReductionOp op,
2179 PatternRewriter &rewriter) const override {
2180 // TODO: Handle other reduction kinds and their identity values.
2181 if (op.getKind() != vector::CombiningKind::ADD)
2182 return failure();
2183
2184 Type elemType = op.getSourceVectorType().getElementType();
2185 // The integer case should be handled by `arith.addi` folders, only check
2186 // for floats here.
2187 if (!isa<FloatType>(elemType))
2188 return failure();
2189
2190 auto vAdd = op.getVector().getDefiningOp<arith::AddFOp>();
2191 if (!vAdd)
2192 return failure();
2193 auto addLhs = vAdd.getLhs().getDefiningOp<arith::AddFOp>();
2194 if (!addLhs)
2195 return failure();
2196
2197 if (!matchPattern(addLhs.getRhs(), m_AnyZeroFloat()))
2198 return failure();
2199
2200 auto newAdd = arith::AddFOp::create(rewriter, vAdd.getLoc(),
2201 addLhs.getLhs(), vAdd.getRhs());
2202 rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), newAdd,
2203 op.getAcc());
2204 return success();
2205 }
2206};
2207
2208/// Example:
2209/// ```
2210/// %a = vector.reduction <add> %x : vector<2xf32> into f32
2211/// ```
2212/// is transformed into:
2213/// ```
2214/// %y = vector.extract %x[0] : f32 from vector<2xf32>
2215/// %z = vector.extract %x[1] : f32 from vector<2xf32>
2216/// %a = arith.addf %y, %z : f32
2217/// ```
2218struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
2220 unsigned maxNumElementsToExtract,
2221 PatternBenefit benefit)
2222 : OpRewritePattern(context, benefit),
2223 maxNumElementsToExtract(maxNumElementsToExtract) {}
2224
2225 LogicalResult matchAndRewrite(vector::ReductionOp op,
2226 PatternRewriter &rewriter) const override {
2227 VectorType type = op.getSourceVectorType();
2228 if (type.isScalable() || op.isMasked())
2229 return failure();
2230 assert(type.getRank() == 1 && "Expected a 1-d vector");
2231
2232 int64_t numElems = type.getNumElements();
2233 if (numElems > maxNumElementsToExtract) {
2234 return rewriter.notifyMatchFailure(
2235 op, llvm::formatv("has too many vector elements ({0}) to break down "
2236 "(max allowed: {1})",
2237 numElems, maxNumElementsToExtract));
2238 }
2239
2240 Location loc = op.getLoc();
2241 SmallVector<Value> extracted(numElems, nullptr);
2242 for (auto [idx, extractedElem] : llvm::enumerate(extracted))
2243 extractedElem = vector::ExtractOp::create(rewriter, loc, op.getVector(),
2244 static_cast<int64_t>(idx));
2245
2246 Value res = extracted.front();
2247 for (auto extractedElem : llvm::drop_begin(extracted))
2248 res = vector::makeArithReduction(rewriter, loc, op.getKind(), res,
2249 extractedElem, op.getFastmathAttr());
2250 if (Value acc = op.getAcc())
2251 res = vector::makeArithReduction(rewriter, loc, op.getKind(), res, acc,
2252 op.getFastmathAttr());
2253
2254 rewriter.replaceOp(op, res);
2255 return success();
2256 }
2257
2258private:
2259 unsigned maxNumElementsToExtract = 0;
2260};
2261
2262/// Fold `mulf(tr(broadcast(A)), broadcast(B))` into `vector.outerproduct(A,
2263/// B)`.
2264/// Example:
2265/// %lhsBcast = vector.broadcast %lhs : vector<4xi32> to vector<4x4xi32>
2266/// %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<4x4xi32> to
2267/// vector<4x4xi32> %rhsBcast = vector.broadcast %rhs : vector<4xi32> to
2268/// vector<4x4xi32> %mul = arith.muli %lhsT, %rhsBcast : vector<4x4xi32>
2269///
2270/// Becomes :
2271///
2272/// %res = vector.outerproduct %lhs, %rhs : vector<4xi32>, vector<4xi32>
2273///
2274/// Supports only 1D-to-2D broadcasts. The following cases are not supported.
2275/// %ex1 = vector.broadcast %lhsCast : vector<1x4xf32> to vector<4x4xf32>
2276/// %ex2 = vector.broadcast %lhsCast : f32 to vector<4x4xf32>
2277/// %ex3 = vector.broadcast %lhsCast : vector<1x1xf32> to vector<4x4xf32>
2278template <typename MulOpType>
2279struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
2280 using OpRewritePattern<MulOpType>::OpRewritePattern;
2281 // Returns whether a vector.broadcast matches requirements for an outerproduct
2282 // pattern. aka a 1D-to-2D broadcastOp without broadcasted unit dimension.
2283 bool isValidBroadcastSource(vector::BroadcastOp broadcastOp) const {
2284 // Fail if it is not a 1-to-2 dimension to broadcast to avoid generating
2285 // shape_casts/broadcasts which does not belong in this pattern.
2286 if (!broadcastOp.computeBroadcastedUnitDims().empty())
2287 return false;
2288 // Avoid broadcast like f32 or vector<f32> -> ResType
2289 auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
2290 return srcType && srcType.getRank() != 2;
2291 }
2292
2293 LogicalResult matchAndRewrite(MulOpType mulOp,
2294 PatternRewriter &rewriter) const override {
2295 auto resType = llvm::dyn_cast<VectorType>(mulOp.getResult().getType());
2296 if (!resType)
2297 return failure();
2298 if (resType.getRank() != 2)
2299 return failure();
2300 /// If operandA can be written as tr(broadcast(A)) and operandB as
2301 /// broadcast(B) where broadcasts are 1D-to-2D, create and return
2302 /// vector.outerproduct(A, B). Returns failure() otherwise.
2303 auto matchOuterProduct =
2304 [&](Value operandA,
2305 Value operandB) -> FailureOr<vector::OuterProductOp> {
2306 auto transposedLhs = operandA.getDefiningOp<vector::TransposeOp>();
2307 if (!transposedLhs)
2308 return failure();
2309 // Fail unless this is a true 2-D matrix transpose.
2310 ArrayRef<int64_t> permutation = transposedLhs.getPermutation();
2311 if (permutation.size() != 2 || permutation[0] != 1 || permutation[1] != 0)
2312 return failure();
2313
2314 auto broadcastedLhs =
2315 transposedLhs.getVector().getDefiningOp<vector::BroadcastOp>();
2316 if (!broadcastedLhs || !isValidBroadcastSource(broadcastedLhs))
2317 return failure();
2318
2319 auto broadcastedRhs = operandB.getDefiningOp<vector::BroadcastOp>();
2320 if (!broadcastedRhs || !isValidBroadcastSource(broadcastedRhs))
2321 return failure();
2322
2323 return vector::OuterProductOp::create(
2324 rewriter, mulOp->getLoc(), resType, broadcastedLhs.getSource(),
2325 broadcastedRhs.getSource(), Value(), vector::CombiningKind::ADD);
2326 };
2327
2328 Value lhs = mulOp->getOperand(0), rhs = mulOp->getOperand(1);
2329 auto maybeOuterP = matchOuterProduct(lhs, rhs);
2330 // Handle commutativity, the transposed op is the outerproduct LHS.
2331 if (failed(maybeOuterP))
2332 maybeOuterP = matchOuterProduct(rhs, lhs);
2333 if (failed(maybeOuterP))
2334 return failure();
2335 rewriter.replaceOp(mulOp, maybeOuterP->getResult());
2336 return success();
2337 }
2338};
2339
2340} // namespace
2341
2348
2349void mlir::vector::populateVectorMaskMaterializationPatterns(
2350 RewritePatternSet &patterns, bool force32BitVectorIndices,
2351 PatternBenefit benefit) {
2352 patterns.add<VectorCreateMaskOpConversion,
2353 MaterializeTransferMask<vector::TransferReadOp>,
2354 MaterializeTransferMask<vector::TransferWriteOp>>(
2355 patterns.getContext(), force32BitVectorIndices, benefit);
2356 patterns.add<FoldI1Select>(patterns.getContext(), benefit);
2357}
2358
2359void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
2360 RewritePatternSet &patterns, PatternBenefit benefit) {
2362 DropUnitDimsFromTransposeOp>(patterns.getContext(), benefit);
2363}
2364
2365void mlir::vector::populateBubbleVectorBitCastOpPatterns(
2366 RewritePatternSet &patterns, PatternBenefit benefit) {
2367 patterns.add<BubbleDownVectorBitCastForExtract,
2368 BubbleDownBitCastForStridedSliceExtract,
2369 BubbleUpBitCastForInsert, BubbleUpBitCastForStridedSliceInsert>(
2370 patterns.getContext(), benefit);
2371}
2372
2373void mlir::vector::populateBreakDownVectorBitCastOpPatterns(
2374 RewritePatternSet &patterns,
2375 std::function<bool(vector::BitCastOp)> controlFn, PatternBenefit benefit) {
2376 patterns.add<BreakDownVectorBitCast>(patterns.getContext(),
2377 std::move(controlFn), benefit);
2378}
2379
2381 RewritePatternSet &patterns,
2382 std::function<LogicalResult(vector::ContractionOp)> constraint,
2383 PatternBenefit benefit) {
2384 patterns.add<CanonicalizeContractMatmulToMMT>(patterns.getContext(), benefit,
2385 std::move(constraint));
2386}
2387
2389 RewritePatternSet &patterns, PatternBenefit benefit) {
2390 patterns.add<MultiReduceToContract, CombineContractBroadcastMask,
2391 CombineContractABTranspose, CombineContractResultTranspose>(
2392 patterns.getContext(), benefit);
2393}
2394
2401
2403 PatternBenefit benefit) {
2404 patterns.add<ReorderElementwiseOpsOnTranspose, ReorderCastOpsOnBroadcast,
2405 ReorderElementwiseOpsOnBroadcast, ExtractOpFromElementwise>(
2406 patterns.getContext(), benefit);
2407}
2408
2409void mlir::vector::populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
2410 PatternBenefit benefit) {
2411 // TODO: Consider converting these patterns to canonicalizations.
2412 patterns.add<ExtractOpFromLoad, StoreOpFromBroadcast>(patterns.getContext(),
2413 benefit);
2414}
2415
2416void mlir::vector::populateChainedVectorReductionFoldingPatterns(
2417 RewritePatternSet &patterns, PatternBenefit benefit) {
2418 patterns.add<ChainedReduction>(patterns.getContext(), benefit);
2419 patterns.add<ReduceRedundantZero>(patterns.getContext(),
2420 PatternBenefit(benefit.getBenefit() + 1));
2421}
2422
2423void mlir::vector::populateBreakDownVectorReductionPatterns(
2424 RewritePatternSet &patterns, unsigned maxNumElementsToExtract,
2425 PatternBenefit benefit) {
2426 patterns.add<BreakDownVectorReduction>(patterns.getContext(),
2427 maxNumElementsToExtract, benefit);
2428}
2429
2431 RewritePatternSet &patterns) {
2432 patterns.add<FoldArithToVectorOuterProduct<arith::MulFOp>,
2433 FoldArithToVectorOuterProduct<arith::MulIOp>>(
2434 patterns.getContext());
2435}
2436
2437//===----------------------------------------------------------------------===//
2438// TableGen'd enum attribute definitions
2439//===----------------------------------------------------------------------===//
2440
2441#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.cpp.inc"
return success()
static uint64_t zext(uint32_t arg)
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static std::optional< int64_t > getResultIndex(AffineMap map, int64_t index)
static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy)
static SmallVector< IntType > extractVector(ArrayAttr arrayAttr)
static FailureOr< size_t > getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType)
Returns the number of dims can be folded away from transfer ops. It returns a failure if it can not d...
Drop inner most contiguous unit dimensions from transfer_read operand.
Drop inner most contiguous unit dimensions from transfer_write operand. E.g., vector....
Base type for affine expression.
Definition AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition Builders.cpp:391
IntegerType getI64Type()
Definition Builders.cpp:69
IntegerType getI32Type()
Definition Builders.cpp:67
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:328
AffineExpr getAffineDimExpr(unsigned position)
Definition Builders.cpp:368
DenseIntElementsAttr getI32VectorAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:126
DenseIntElementsAttr getI64VectorAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:132
IntegerType getI1Type()
Definition Builders.cpp:57
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:270
MLIRContext * getContext() const
Definition Builders.h:56
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:285
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
Definition Builders.cpp:274
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition Builders.cpp:322
DenseElementsAttr resizeSplat(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but with a different ...
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:209
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:566
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition Builders.cpp:461
This class represents an operand of an operation.
Definition Value.h:254
OpTraitRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting again...
OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:376
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:775
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:538
bool hasOneUse()
Returns true if this operation has exactly one use.
Definition Operation.h:875
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:433
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:700
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
unsigned getNumOperands()
Definition Operation.h:372
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:116
operand_type_range getOperandTypes()
Definition Operation.h:423
result_type_range getResultTypes()
Definition Operation.h:454
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:404
result_range getResults()
Definition Operation.h:441
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:430
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the pattern cannot match the corresponding operation, return an impossible-to-match sentinel value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
Definition Types.cpp:106
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
Definition Value.h:116
Type getType() const
Return the type of this value.
Definition Value.h:105
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition Value.h:197
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:369
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Definition MemRefOps.cpp:79
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
bool isReductionIterator(Attribute attr)
Returns true if attr has "reduction" iterator type semantics.
Definition VectorOps.h:156
auto getDims(VectorType vType)
Returns a range over the dims (size and scalability) of a VectorType.
void populateElementwiseToVectorOpsPatterns(RewritePatternSet &patterns)
Collect a set of patterns that fold elementwise op on vectors to the vector dialect.
AffineMap getTransferMinorIdentityMap(ShapedType shapedType, VectorType vectorType)
Build the default minor identity map suitable for a vector transfer.
void populateDropInnerMostUnitDimsXferOpPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of patterns to collapse the most inner unit dims in xfer Ops.
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
Definition VectorOps.h:151
void populateFoldArithExtensionPatterns(RewritePatternSet &patterns)
Collect a set of patterns that fold arithmetic extension on floating point into vector contract for t...
void populateVectorContractCanonicalizeMatmulToMMT(RewritePatternSet &patterns, std::function< LogicalResult(vector::ContractionOp)> constraint=[](vector::ContractionOp) { return success();}, PatternBenefit=1)
Canonicalization of a vector.contract a, b, c with row-major matmul semantics to a contraction with M...
void populateSinkVectorOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Patterns that remove redundant Vector Ops by re-ordering them with e.g.
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim)
Helper function that creates a memref::DimOp or tensor::DimOp depending on the type of source.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions 0, 1, ..., one per expression, in order.
Definition AffineExpr.h:311
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
Definition Utils.cpp:122
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
detail::constant_float_predicate_matcher m_AnyZeroFloat()
Matches a constant scalar / vector splat / tensor splat float (both positive and negative) zero.
Definition Matchers.h:399
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
AffineMap compressDims(AffineMap map, const llvm::SmallBitVector &unusedDims)
Drop the dims that are listed in unusedDims.
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
LogicalResult matchAndRewrite(vector::ReductionOp op, PatternRewriter &rewriter) const override
BreakDownVectorReduction(MLIRContext *context, unsigned maxNumElementsToExtract, PatternBenefit benefit)
Canonicalization of a vector.contract a, b, c with row-major matmul semantics to a contraction suitab...
LogicalResult matchAndRewrite(vector::ContractionOp op, PatternRewriter &rewriter) const override
std::function< LogicalResult(vector::ContractionOp op)> FilterConstraintType
CanonicalizeContractMatmulToMMT(MLIRContext *context, PatternBenefit benefit, FilterConstraintType constraint)
Pattern to fold chained reduction to a series of vector additions and a final reduction....
LogicalResult matchAndRewrite(vector::ReductionOp op, PatternRewriter &rewriter) const override
For vectors with at least one unit dim, replaces: elementwise(a, b) with: sc_a = shape_cast(a) sc_b =...
OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
A pattern to drop unit dims from the iter_args of an scf.for.
LogicalResult matchAndRewrite(scf::ForOp forOp, PatternRewriter &rewriter) const override
A pattern to drop unit dims from vector.transpose.
LogicalResult matchAndRewrite(vector::TransposeOp op, PatternRewriter &rewriter) const override
Pattern to fold arithmetic extensions on floating point data types into vector contraction operations...
LogicalResult matchAndRewrite(vector::ContractionOp contractOp, PatternRewriter &rewriter) const override
Pattern to eliminate redundant zero-constants added to reduction operands. It's enough for there to b...
LogicalResult matchAndRewrite(vector::ReductionOp op, PatternRewriter &rewriter) const override
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Wrapper around the RewritePattern method that passes the derived op type.
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e.