VectorTransforms.cpp
1//===- VectorTransforms.cpp - Conversion within the Vector dialect --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements target-independent rewrites as 1->N patterns.
10//
11//===----------------------------------------------------------------------===//
12
14
15#include <cassert>
16#include <cstdint>
17#include <functional>
18#include <optional>
19
30#include "mlir/IR/Location.h"
31#include "mlir/IR/Matchers.h"
34
35#include "llvm/ADT/STLExtras.h"
36#include "llvm/Support/FormatVariadic.h"
37
38#define DEBUG_TYPE "vector-to-vector"
39
40using namespace mlir;
41using namespace mlir::vector;
42
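// Converts the integer elements of `arrayAttr` into a SmallVector of the
// requested integer type.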
43template <typename IntType>
44 static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
45 return llvm::to_vector<4>(llvm::map_range(
46 arrayAttr.getAsRange<IntegerAttr>(),
47 [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
48}
49
50// Helper returning the position of the result in `map` that equals dimension `index`, if any.
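// For example, for the (hypothetical) map (d0, d1, d2) -> (d2, d0), querying
// dimension 0 yields result position 1, while querying dimension 1 yields
// std::nullopt.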
51static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
52 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
53 int64_t idx = map.getDimPosition(i);
54 if (idx == index)
55 return i;
56 }
57 return std::nullopt;
58}
59
60namespace {
61
62/// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
63/// Ex:
64/// ```
65/// %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
66/// %1 = vector.multi_reduction add, %0 [2]
67/// : vector<8x32x16xf32> to vector<8x32xf32>
68/// ```
69/// Gets converted to:
70/// ```
71/// %1 = vector.contract {indexing_maps = [
72/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
73/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
74/// affine_map<(d0, d1, d2) -> (d0, d1)>],
75/// iterator_types = ["parallel", "parallel", "reduction"],
76/// kind = add} %arg0, %arg1, %cst_f0
77/// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
78/// ```
79struct MultiReduceToContract
80 : public OpRewritePattern<vector::MultiDimReductionOp> {
81 using Base::Base;
82
83 LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
84 PatternRewriter &rewriter) const override {
85 if (reduceOp.getKind() != vector::CombiningKind::ADD)
86 return failure();
87 Operation *mulOp = reduceOp.getSource().getDefiningOp();
88 if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
89 return failure();
90 SmallVector<bool> reductionMask = reduceOp.getReductionMask();
91 auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
92 SmallVector<AffineExpr> exprs;
93 SmallVector<vector::IteratorType> iteratorTypes;
94 for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
95 if (!isReduceDim.value()) {
96 iteratorTypes.push_back(vector::IteratorType::parallel);
97 exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
98 } else {
99 iteratorTypes.push_back(vector::IteratorType::reduction);
100 }
101 }
102 auto dstMap =
103 AffineMap::get(/*dimCount=*/reductionMask.size(),
104 /*symbolCount=*/0, exprs, reduceOp.getContext());
105 rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
106 reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), reduceOp.getAcc(),
107 rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
108 rewriter.getArrayAttr(llvm::to_vector(llvm::map_range(
109 iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
110 return IteratorTypeAttr::get(rewriter.getContext(), t);
111 }))));
112 return success();
113 }
114};
115
116/// Merge LHS/RHS (A/B) TransposeOp into ContractionOp user.
117/// Ex:
118/// ```
119/// %0 = vector.transpose %arg0, [2, 0, 1]
120/// : vector<32x16x8xf32> to vector<8x32x16xf32>
121/// %1 = vector.contract {indexing_maps = [
122/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
123/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
124/// affine_map<(d0, d1, d2) -> (d0, d1)>],
125/// iterator_types = ["parallel", "parallel", "reduction"],
126/// kind = add} %0, %arg1, %cst_f0
127/// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
128/// ```
129/// Gets converted to:
130/// ```
131/// %1 = vector.contract {indexing_maps = [
132/// affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
133/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
134/// affine_map<(d0, d1, d2) -> (d0, d1)>],
135/// iterator_types = ["parallel", "parallel", "reduction"],
136/// kind = add} %arg0, %arg1, %cst_f0
137/// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
138/// ```
139struct CombineContractABTranspose final
140 : public OpRewritePattern<vector::ContractionOp> {
141 using Base::Base;
142
143 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
144 PatternRewriter &rewriter) const override {
145 SmallVector<AffineMap> maps =
146 llvm::to_vector<4>(contractOp.getIndexingMapsArray());
147 Value lhs = contractOp.getLhs();
148 Value rhs = contractOp.getRhs();
149 size_t index = 0;
150 bool changed = false;
151 for (Value *operand : {&lhs, &rhs}) {
152 AffineMap &map = maps[index++];
153 auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
154 if (!transposeOp)
155 continue;
156 AffineMap permutationMap = AffineMap::getPermutationMap(
157 transposeOp.getPermutation(), contractOp.getContext());
158 map = inversePermutation(permutationMap).compose(map);
159 *operand = transposeOp.getVector();
160 changed = true;
161 }
162 if (!changed)
163 return failure();
164 rewriter.replaceOpWithNewOp<vector::ContractionOp>(
165 contractOp, lhs, rhs, contractOp.getAcc(),
166 rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
167 return success();
168 }
169};
170
171/// Merges accumulator and result transposes into contract.
172///
173/// For example:
174/// ```mlir
175/// %accT = vector.transpose %acc, [0, 2, 1]
176/// : vector<2x8x4xf32> to vector<2x4x8xf32>
177/// %contract = vector.contract {
178/// indexing_maps = [
179/// affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
180/// affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
181/// affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
182/// ],
183/// iterator_types = ["parallel", "parallel", "parallel", "reduction"],
184/// kind = #vector.kind<add>
185/// } %lhs, %rhs, %accT
186/// : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x4x8xf32>
187/// %0 = vector.transpose %contract, [0, 2, 1]
188/// : vector<2x4x8xf32> to vector<2x8x4>
189/// ```
190/// Becomes:
191/// ```mlir
192/// %0 = vector.contract {
193/// indexing_maps = [
194/// affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
195/// affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
196/// affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
197/// ],
198/// iterator_types = ["parallel", "parallel", "parallel", "reduction"],
199/// kind = #vector.kind<add>
200/// } %lhs, %rhs, %acc
201/// : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x8x4xf32>
202/// ```
203struct CombineContractResultTranspose final
204 : public OpRewritePattern<vector::TransposeOp> {
205 using Base::Base;
206
207 LogicalResult matchAndRewrite(vector::TransposeOp resTOp,
208 PatternRewriter &rewriter) const override {
209 auto contractOp = resTOp.getVector().getDefiningOp<vector::ContractionOp>();
210 if (!contractOp || !contractOp->hasOneUse())
211 return failure();
212
213 auto accTOp = contractOp.getAcc().getDefiningOp<vector::TransposeOp>();
214 if (!accTOp)
215 return failure();
216
217 MLIRContext *context = contractOp.getContext();
218 auto maps = llvm::to_vector<3>(contractOp.getIndexingMapsArray());
219 AffineMap contractMap = maps.back();
220
221 // Accumulator transpose performs f(A) -> B. Contract performs g(C) -> B.
222 // To index into A in contract, we need revert(f)(g(C)) -> A.
223 auto accTMap =
224 AffineMap::getPermutationMap(accTOp.getPermutation(), context);
225
226 // Contract performs g(C) -> D. Result transpose performs h(D) -> E.
227 // To index into E in contract, we need h(g(C)) -> E.
228 auto resTMap =
229 AffineMap::getPermutationMap(resTOp.getPermutation(), context);
230 auto combinedResMap = resTMap.compose(contractMap);
231
232 // The accumulator and result share the same indexing map, so the two
233 // transposes can only be merged when combinedResMap equals
234 // inversePermutation(accTMap).compose(contractMap), i.e. when resTMap equals inversePermutation(accTMap).
235 if (inversePermutation(accTMap) != resTMap)
236 return failure();
237 maps.back() = combinedResMap;
238
239 rewriter.replaceOpWithNewOp<vector::ContractionOp>(
240 resTOp, contractOp.getLhs(), contractOp.getRhs(), accTOp.getVector(),
241 rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
242 return success();
243 }
244};
245
246/// Merge BroadcastOp into ContractionOp user.
247/// Ex:
248/// ```
249/// %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
250/// %1 = vector.contract {indexing_maps = [
251/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
252/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
253/// affine_map<(d0, d1, d2) -> (d0, d1)>],
254/// iterator_types = ["parallel", "parallel", "reduction"],
255/// kind = add} %0, %arg1, %cst_f0
256/// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
257/// ```
258/// Gets converted to:
259/// ```
260/// %1 = vector.contract {indexing_maps = [
261/// affine_map<(d0, d1, d2) -> (d1, d2)>,
262/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
263/// affine_map<(d0, d1, d2) -> (d0, d1)>],
264/// iterator_types = ["parallel", "parallel", "reduction"],
265/// kind = add} %arg0, %arg1, %cst_f0
266/// : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
267/// ```
268///
269/// For masked vector.contract, the mask requires updating when a dimension is
270/// dropped. In such cases, the dropped dimensions must correspond to the mask's
271/// leading unit dimensions. More general cases (e.g. non-unit dims) are not
272/// supported yet.
273FailureOr<Value> combineContractAndBroadcast(vector::ContractionOp contractOp,
274 MaskingOpInterface maskingOp,
275 PatternRewriter &rewriter) {
276 SmallVector<AffineMap> maps =
277 llvm::to_vector<4>(contractOp.getIndexingMapsArray());
278 Value lhs = contractOp.getLhs();
279 Value rhs = contractOp.getRhs();
280 size_t index = 0;
281 bool changed = false;
282 for (Value *operand : {&lhs, &rhs}) {
283 AffineMap &map = maps[index++];
284 auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
285 if (!broadcast)
286 continue;
287 // ContractionOp can only take vectors as operands.
288 auto srcType = dyn_cast<VectorType>(broadcast.getSourceType());
289 if (!srcType ||
290 srcType.getRank() == broadcast.getResultVectorType().getRank())
291 continue;
292 int64_t rankDiff =
293 broadcast.getResultVectorType().getRank() - srcType.getRank();
294 bool innerDimBroadcast = false;
295 SmallVector<AffineExpr> originalDims;
296 for (const auto &dim : llvm::enumerate(srcType.getShape())) {
297 if (dim.value() !=
298 broadcast.getResultVectorType().getDimSize(rankDiff + dim.index())) {
299 innerDimBroadcast = true;
300 break;
301 }
302 originalDims.push_back(rewriter.getAffineDimExpr(dim.index() + rankDiff));
303 }
304 // Contract doesn't support inner dimension broadcast. Once this is
305 // relaxed we can remove this case.
306 if (innerDimBroadcast)
307 continue;
308
309 // It would be incorrect to fold a broadcast onto a reduction dimension
310 // of non-unit size.
311 bool nonUnitDimReductionBroadcast = false;
312 for (int64_t i = 0; i < rankDiff; ++i) {
313 if (broadcast.getResultVectorType().getDimSize(i) != 1 &&
314 isReductionIterator(contractOp.getIteratorTypes()
315 .getValue()[map.getDimPosition(i)])) {
316 nonUnitDimReductionBroadcast = true;
317 break;
318 }
319 }
320 if (nonUnitDimReductionBroadcast)
321 continue;
322
323 AffineMap broadcastMap =
324 AffineMap::get(broadcast.getResultVectorType().getRank(), 0,
325 originalDims, contractOp.getContext());
326 map = broadcastMap.compose(map);
327 *operand = broadcast.getSource();
328 changed = true;
329 }
330
331 if (!changed)
332 return failure();
333
334 // Determine which dims are unused, now that the maps have been composed
335 // with the broadcast maps.
336 llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
337 // Compress unused dims.
338 for (auto &m : maps)
339 m = compressDims(m, unusedDimsBitVector);
340 // Compute the combined iterators.
341 SmallVector<Attribute> iterators;
342 for (unsigned i = 0, e = unusedDimsBitVector.size(); i < e; ++i) {
343 if (!unusedDimsBitVector.test(i))
344 iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
345 }
346
347 // Check whether any of the unused dims is non-unit, e.g.:
348 // * vector.broadcast %arg0 : vector<8x4xi32> to vector<2x8x4xi32>
349 // This is only required when collapsing a mask. If there is no mask, skip.
350 VectorType oldMaskType;
351 bool isAnyUnusedDimNonUnit = false;
352 if (maskingOp) {
353 oldMaskType = cast<VectorType>(maskingOp.getMask().getType());
354 for (unsigned i = 0, e = unusedDimsBitVector.size(); i < e; ++i) {
355 if (unusedDimsBitVector.test(i) && oldMaskType.getShape()[i] != 1) {
356 isAnyUnusedDimNonUnit = true;
357 break;
358 }
359 }
360 }
361
362 // Check that compressing unused dims isn't removing all reduction dimension
363 // pairs. For example, if the vector.contract had only one reduction
364 // iterator and that was a unit-dimension created by a broadcast,
365 // then we should bail here, otherwise we would create a contract without
366 // a reduction dimension pair.
367 bool hasReductionIteratorApplyingOnBothSides = false;
368 for (unsigned i = 0; i < iterators.size(); ++i) {
369 if (!isReductionIterator(iterators[i]))
370 continue;
371 if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) {
372 hasReductionIteratorApplyingOnBothSides = true;
373 break;
374 }
375 }
376 if (!hasReductionIteratorApplyingOnBothSides)
377 return failure();
378
379 // If the compressed maps have a dimension that is not used by either LHS or
380 // RHS then the ContractionOp verifier would fail.
381 if (getUnusedDimsBitVector({maps[0], maps[1]}).any())
382 return failure();
383
384 Operation *newOp = vector::ContractionOp::create(
385 rewriter, contractOp.getLoc(), lhs, rhs, contractOp.getAcc(),
386 rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
387
388 // Handle the mask.
389 if (maskingOp) {
390 if (isAnyUnusedDimNonUnit)
391 return rewriter.notifyMatchFailure(contractOp,
392 "Cannont drop non-unit mask dim.");
393 assert(unusedDimsBitVector.size() ==
394 static_cast<size_t>(oldMaskType.getRank()) &&
395 "The mask rank is incorrect!");
396
397 // If a dimension has been dropped, update the mask accordingly. Otherwise,
398 // keep it as is.
399 Value mask = maskingOp.getMask();
400 if (unusedDimsBitVector.count() != 0) {
401 // At this point, two assumptions are made:
402 // * The unused dimensions are the leading mask dimensions
403 // (vector.contract does not support inner dim broadcasting).
404 // * The unused dimensions are all unit.
405 // These conditions are effectively verified in the blocks preceding this
406 // one.
407 auto newShape =
408 oldMaskType.getShape().drop_front(unusedDimsBitVector.count());
409 auto newShapeScalableDims =
410 oldMaskType.getScalableDims().drop_front(unusedDimsBitVector.count());
411 VectorType maskOpType =
412 VectorType::get(newShape, rewriter.getI1Type(), newShapeScalableDims);
413 mask = vector::ShapeCastOp::create(rewriter, contractOp.getLoc(),
414 maskOpType, maskingOp.getMask())
415 .getResult();
416 }
417
418 newOp = mlir::vector::maskOperation(rewriter, newOp, mask);
419 }
420 return newOp->getResult(0);
421}
422
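/// Wrapper pattern that applies `combineContractAndBroadcast` to a
/// vector.contract op, taking an optional surrounding vector.mask into
/// account.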
423struct CombineContractBroadcastMask
424 : public MaskableOpRewritePattern<vector::ContractionOp> {
425 using MaskableOpRewritePattern::MaskableOpRewritePattern;
426 FailureOr<Value>
427
428 matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
429 MaskingOpInterface maskingOp,
430 PatternRewriter &rewriter) const override {
431 return combineContractAndBroadcast(contractOp, maskingOp, rewriter);
432 }
433};
434
435/// Reorders cast(broadcast) to broadcast(cast). This moves broadcast ops
436/// closer to contraction ops, enabling the CombineContractBroadcast pattern
437/// when cast ops sit between these operations.
438/// Ex:
439/// ```
440/// %0 = vector.broadcast %arg0 : vector<32x16xi8> to vector<8x32x16xi8>
441/// %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
442/// ```
443/// Gets converted to:
444/// ```
445/// %0 = arith.extsi %arg0 : vector<32x16xi8> to vector<32x16xi32>
446/// %1 = vector.broadcast %0 : vector<32x16xi32> to vector<8x32x16xi32>
447/// ```
448struct ReorderCastOpsOnBroadcast
449 : public OpInterfaceRewritePattern<CastOpInterface> {
450 using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;
451
452 LogicalResult matchAndRewrite(CastOpInterface op,
453 PatternRewriter &rewriter) const override {
454 if (op->getNumOperands() != 1)
455 return failure();
456 auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
457 if (!bcastOp)
458 return failure();
459
460 Type castResTy = getElementTypeOrSelf(op->getResult(0));
461 if (auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType()))
462 castResTy = vecTy.clone(castResTy);
463 auto *castOp =
464 rewriter.create(op->getLoc(), op->getName().getIdentifier(),
465 bcastOp.getSource(), castResTy, op->getAttrs());
466 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
467 op, op->getResult(0).getType(), castOp->getResult(0));
468 return success();
469 }
470};
471
472/// Reorders elementwise(transpose) to transpose(elementwise). This moves
473/// transpose ops closer to contraction ops, enabling the
474/// CombineContractABTranspose pattern when elementwise ops sit between these
475/// operations. Ex:
476/// ```
477/// %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
478/// %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
479/// %r = arith.addf %at, %bt : vector<2x4xf32>
480/// ```
481/// Gets converted to:
482/// ```
483/// %0 = arith.addf %a, %b : vector<4x2xf32>
485/// %r = vector.transpose %0, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
485/// ```
486struct ReorderElementwiseOpsOnTranspose final
487 : public OpTraitRewritePattern<OpTrait::Elementwise> {
488 using OpTraitRewritePattern::OpTraitRewritePattern;
489 LogicalResult matchAndRewrite(Operation *op,
490 PatternRewriter &rewriter) const override {
491 if (op->getNumResults() != 1 || op->getNumRegions() != 0)
492 return failure();
493
494 // Make sure all operands are transpose/constant ops and collect their
495 // transposition maps.
496 SmallVector<ArrayRef<int64_t>> transposeMaps;
497 transposeMaps.reserve(op->getNumOperands());
498 // Record the initial type before transposition. We'll use its shape later.
499 // Any type will do here as we will check all transpose maps are the same.
500 VectorType srcType;
501 for (Value operand : op->getOperands()) {
502 auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
503 if (transposeOp) {
504 transposeMaps.push_back(transposeOp.getPermutation());
505 srcType = transposeOp.getSourceVectorType();
506 } else if (!matchPattern(operand, m_Constant())) {
507 return failure();
508 }
509 }
510 if (transposeMaps.empty())
511 return failure();
512 // This is an elementwise op, so all transposed operands should have the
514 // same type. We additionally need to check that all transposes use the
514 // same map.
515 if (!llvm::all_equal(transposeMaps))
516 return rewriter.notifyMatchFailure(op, "different transpose map");
517
518 SmallVector<Value> srcValues;
519 srcValues.reserve(op->getNumOperands());
520
521 // If there are constant operands, we need to insert inverse transposes for
522 // them. Calculate the inverse order first.
523 auto order = transposeMaps.front();
524 SmallVector<int64_t> invOrder(order.size());
525 for (int i = 0, e = order.size(); i < e; ++i)
526 invOrder[order[i]] = i;
527
528 for (Value operand : op->getOperands()) {
529 auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
530 if (transposeOp) {
531 srcValues.push_back(transposeOp.getVector());
532 } else {
533 // This is a constant. Create a reverse transpose op for it.
534 auto vectorType =
535 srcType.clone(cast<VectorType>(operand.getType()).getElementType());
536 srcValues.push_back(vector::TransposeOp::create(
537 rewriter, operand.getLoc(), vectorType, operand, invOrder));
538 }
539 }
540
541 auto vectorType = srcType.clone(
542 cast<VectorType>(op->getResultTypes()[0]).getElementType());
543 Operation *elementwiseOp =
544 rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
545 vectorType, op->getAttrs());
546 rewriter.replaceOpWithNewOp<vector::TransposeOp>(
547 op, op->getResultTypes()[0], elementwiseOp->getResult(0),
548 transposeMaps.front());
549 return success();
550 }
551};
552
553// Returns the values in `arrayAttr` as an integer vector.
554static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) {
555 return llvm::to_vector<4>(
556 llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
557 [](IntegerAttr attr) { return attr.getInt(); }));
558}
559
560// Shuffles vector.bitcast op after vector.extract op.
561//
562// This transforms IR like:
563// %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
564// %1 = vector.extract %0[3] : f16 from vector<8xf16>
565// Into:
566// %0 = vector.extract %src[1] : f32 from vector<4xf32>
567// %1 = vector.bitcast %0: vector<1xf32> to vector<2xf16>
568// %2 = vector.extract %1[1] : f16 from vector<2xf16>
569struct BubbleDownVectorBitCastForExtract
570 : public OpRewritePattern<vector::ExtractOp> {
571 using Base::Base;
572
573 LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
574 PatternRewriter &rewriter) const override {
575 // Only support extracting scalars for now.
576 if (extractOp.getSourceVectorType().getRank() != 1)
577 return failure();
578
579 auto castOp = extractOp.getSource().getDefiningOp<vector::BitCastOp>();
580 if (!castOp)
581 return failure();
582
583 VectorType castSrcType = castOp.getSourceVectorType();
584 VectorType castDstType = castOp.getResultVectorType();
585 assert(castSrcType.getRank() == castDstType.getRank());
586
587 // Fail to match if we only have one element in the cast op source.
588 // This is to avoid infinite loop given that this pattern can generate
589 // such cases.
590 if (castSrcType.getNumElements() == 1)
591 return failure();
592
593 // Only support casting to a larger number of elements for now.
594 // E.g., vector<4xf32> -> vector<8xf16>.
595 if (castSrcType.getNumElements() > castDstType.getNumElements())
596 return failure();
597
598 unsigned expandRatio =
599 castDstType.getNumElements() / castSrcType.getNumElements();
600
601 // Get the first element of the mixed position as integer.
602 auto mixedPos = extractOp.getMixedPosition();
603 if (!mixedPos.empty() && !isa<Attribute>(mixedPos[0]))
604 return failure();
605 uint64_t index = cast<IntegerAttr>(cast<Attribute>(mixedPos[0])).getInt();
606
607 // Get the single scalar (as a vector) in the source value that packs the
608 // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>
609 Location loc = extractOp.getLoc();
610 Value packedValue = vector::ExtractOp::create(
611 rewriter, loc, castOp.getSource(), index / expandRatio);
612 Type packedVecType = VectorType::get(/*shape=*/{1}, packedValue.getType());
613 Value zero = arith::ConstantOp::create(rewriter, loc, packedVecType,
614 rewriter.getZeroAttr(packedVecType));
615 packedValue = vector::InsertOp::create(rewriter, loc, packedValue, zero,
616 /*position=*/0);
617
618 // Cast it to a vector with the desired scalar's type.
619 // E.g. f32 -> vector<2xf16>
620 VectorType packedType =
621 VectorType::get({expandRatio}, castDstType.getElementType());
622 Value castedValue =
623 vector::BitCastOp::create(rewriter, loc, packedType, packedValue);
624
625 // Finally extract the desired scalar.
626 rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, castedValue,
627 index % expandRatio);
628 return success();
629 }
630};
631
632// Shuffles vector.bitcast op after vector.extract_strided_slice op.
633//
634// This transforms IR like:
635// %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
636// %0 = vector.extract_strided_slice %cast {
637// offsets = [4], sizes = [4], strides = [1]
638// } : vector<8xf16> to vector<4xf16>
639// Into:
640// %0 = vector.extract_strided_slice %src {
641// offsets = [2], sizes = [2], strides = [1]
642// } : vector<4xf32> to vector<2xf32>
643// %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16>
644struct BubbleDownBitCastForStridedSliceExtract
645 : public OpRewritePattern<vector::ExtractStridedSliceOp> {
646 using Base::Base;
647
648 LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
649 PatternRewriter &rewriter) const override {
650 auto castOp = extractOp.getSource().getDefiningOp<vector::BitCastOp>();
651 if (!castOp)
652 return failure();
653
654 VectorType castSrcType = castOp.getSourceVectorType();
655 VectorType castDstType = castOp.getResultVectorType();
656 assert(castSrcType.getRank() == castDstType.getRank());
657
658 int64_t castSrcLastDim = castSrcType.getShape().back();
659 int64_t castDstLastDim = castDstType.getShape().back();
660 // Require casting to more elements for now; other cases to be implemented.
661 if (castSrcLastDim > castDstLastDim)
662 return failure();
663
664 // Only accept all-one strides for now.
665 if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
666 [](const APInt &val) { return !val.isOne(); }))
667 return failure();
668
669 unsigned rank = extractOp.getSourceVectorType().getRank();
670 assert(castDstLastDim % castSrcLastDim == 0);
671 int64_t expandRatio = castDstLastDim / castSrcLastDim;
672
673 // If we have fewer offsets than the rank, then implicitly we are
674 // selecting the full range for the last bitcasted dimension; other
675 // dimensions aren't affected. Otherwise, we need to scale down the last
676 // dimension's offset given we are extracting from fewer elements now.
677 ArrayAttr newOffsets = extractOp.getOffsets();
678 if (newOffsets.size() == rank) {
679 SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
680 if (offsets.back() % expandRatio != 0)
681 return failure();
682 offsets.back() = offsets.back() / expandRatio;
683 newOffsets = rewriter.getI64ArrayAttr(offsets);
684 }
685
686 // Similarly for sizes.
687 ArrayAttr newSizes = extractOp.getSizes();
688 if (newSizes.size() == rank) {
689 SmallVector<int64_t> sizes = getIntValueVector(newSizes);
690 if (sizes.back() % expandRatio != 0)
691 return failure();
692 sizes.back() = sizes.back() / expandRatio;
693 newSizes = rewriter.getI64ArrayAttr(sizes);
694 }
695
696 SmallVector<int64_t> dims =
697 llvm::to_vector<4>(cast<VectorType>(extractOp.getType()).getShape());
698 dims.back() = dims.back() / expandRatio;
699 VectorType newExtractType =
700 VectorType::get(dims, castSrcType.getElementType());
701
702 auto newExtractOp = vector::ExtractStridedSliceOp::create(
703 rewriter, extractOp.getLoc(), newExtractType, castOp.getSource(),
704 newOffsets, newSizes, extractOp.getStrides());
705
706 rewriter.replaceOpWithNewOp<vector::BitCastOp>(
707 extractOp, extractOp.getType(), newExtractOp);
708
709 return success();
710 }
711};
712
713// Shuffles vector.bitcast op before vector.insert op.
714//
715// This transforms IR like:
716// %0 = vector.insert %val, %dst[4] : vector<32xi4> into vector<8x32xi4>
717// %1 = vector.bitcast %0 : vector<8x32xi4> to vector<8x16xi8>
718// Into:
719// %0 = vector.bitcast %val : vector<32xi4> to vector<16xi8>
720// %1 = vector.bitcast %dst : vector<8x32xi4> to vector<8x16xi8>
721// %2 = vector.insert %0, %1 [4] : vector<16xi8> into vector<8x16xi8>
722//
723struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
724 using Base::Base;
725
726 LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
727 PatternRewriter &rewriter) const override {
728 VectorType castSrcType = bitcastOp.getSourceVectorType();
729 VectorType castDstType = bitcastOp.getResultVectorType();
730
731 // 0-D and scalable vectors are not supported yet.
732 if (castSrcType.getRank() == 0 || castSrcType.isScalable() ||
733 castDstType.isScalable())
734 return failure();
735
736 int64_t castSrcLastDim = castSrcType.getShape().back();
737 int64_t castDstLastDim = castDstType.getShape().back();
738 bool isNumElemsShrink = castSrcLastDim >= castDstLastDim;
739 int64_t ratio;
740 if (isNumElemsShrink) {
741 assert(castSrcLastDim % castDstLastDim == 0);
742 ratio = castSrcLastDim / castDstLastDim;
743 } else {
744 assert(castDstLastDim % castSrcLastDim == 0);
745 ratio = castDstLastDim / castSrcLastDim;
746 }
747
748 auto insertOp = bitcastOp.getSource().getDefiningOp<vector::InsertOp>();
749 if (!insertOp)
750 return failure();
751
752 // Only vector sources are supported for now.
753 auto insertSrcType = dyn_cast<VectorType>(insertOp.getValueToStoreType());
754 if (!insertSrcType)
755 return failure();
756
757 // Bitcast the source.
758 SmallVector<int64_t> srcDims(insertSrcType.getShape());
759 srcDims.back() =
760 isNumElemsShrink ? srcDims.back() / ratio : srcDims.back() * ratio;
761 VectorType newCastSrcType =
762 VectorType::get(srcDims, castDstType.getElementType());
763 auto newCastSrcOp =
764 vector::BitCastOp::create(rewriter, bitcastOp.getLoc(), newCastSrcType,
765 insertOp.getValueToStore());
766
767 SmallVector<int64_t> dstDims(insertOp.getDestVectorType().getShape());
768 dstDims.back() =
769 isNumElemsShrink ? dstDims.back() / ratio : dstDims.back() * ratio;
770 VectorType newCastDstType =
771 VectorType::get(dstDims, castDstType.getElementType());
772
773 // Bitcast the destination.
774 auto newCastDstOp = vector::BitCastOp::create(
775 rewriter, bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
776
777 // Generate new insert.
778 rewriter.replaceOpWithNewOp<vector::InsertOp>(
779 bitcastOp, newCastSrcOp, newCastDstOp, insertOp.getMixedPosition());
780 return success();
781 }
782};
783
784// Shuffles vector.bitcast op before vector.insert_strided_slice op.
785//
786// This transforms IR like:
787// %0 = vector.insert_strided_slice %src, %dst {
788// offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
789// %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
790// Into:
791// %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32>
792// %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32>
793// %2 = vector.insert_strided_slice %0, %1 {
794// offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
795struct BubbleUpBitCastForStridedSliceInsert
796 : public OpRewritePattern<vector::BitCastOp> {
797 using Base::Base;
798
799 LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
800 PatternRewriter &rewriter) const override {
801 VectorType castSrcType = bitcastOp.getSourceVectorType();
802 VectorType castDstType = bitcastOp.getResultVectorType();
803 assert(castSrcType.getRank() == castDstType.getRank());
804 // Skip 0-D vectors, which will not come from InsertStridedSliceOp.
805 if (castSrcType.getRank() == 0)
806 return failure();
807
808 int64_t castSrcLastDim = castSrcType.getShape().back();
809 int64_t castDstLastDim = castDstType.getShape().back();
810 // Require casting to fewer elements for now; other cases to be implemented.
811 if (castSrcLastDim < castDstLastDim)
812 return failure();
813
814 assert(castSrcLastDim % castDstLastDim == 0);
815 int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
816
817 auto insertOp =
818 bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>();
819 if (!insertOp)
820 return failure();
821
822 // Only accept all-one strides for now.
823 if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
824 [](const APInt &val) { return !val.isOne(); }))
825 return failure();
826
827 unsigned rank = insertOp.getSourceVectorType().getRank();
828 // Require insert op to have the same rank for the source and destination
829 // vector; other cases to be implemented.
830 if (rank != insertOp.getDestVectorType().getRank())
831 return failure();
832
833 // Requires that shape of insert op src is castable to dstType.
834 unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth();
835 unsigned destinationWidth =
836 castDstType.getElementType().getIntOrFloatBitWidth();
837 unsigned numElements = destinationWidth / sourceWidth;
838 if (insertOp.getSourceVectorType().getNumElements() % numElements != 0)
839 return failure();
840
841 ArrayAttr newOffsets = insertOp.getOffsets();
842 assert(newOffsets.size() == rank);
843 SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
844 if (offsets.back() % shrinkRatio != 0)
845 return failure();
846 offsets.back() = offsets.back() / shrinkRatio;
847 newOffsets = rewriter.getI64ArrayAttr(offsets);
848
849 SmallVector<int64_t> srcDims =
850 llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
851 srcDims.back() = srcDims.back() / shrinkRatio;
852 VectorType newCastSrcType =
853 VectorType::get(srcDims, castDstType.getElementType());
854
855 auto newCastSrcOp =
856 vector::BitCastOp::create(rewriter, bitcastOp.getLoc(), newCastSrcType,
857 insertOp.getValueToStore());
858
859 SmallVector<int64_t> dstDims =
860 llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
861 dstDims.back() = dstDims.back() / shrinkRatio;
862 VectorType newCastDstType =
863 VectorType::get(dstDims, castDstType.getElementType());
864
865 auto newCastDstOp = vector::BitCastOp::create(
866 rewriter, bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
867
868 rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
869 bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
870 insertOp.getStrides());
871
872 return success();
873 }
874};
875
876// Breaks down vector.bitcast op
877//
878// This transforms IR like:
879// %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
880// Into:
881// %cst = vector.broadcast %c0_f32 : f32 to vector<4xf32>
882// %1 = vector.extract_strided_slice %0 {
883// offsets = [0], sizes = [4], strides = [1]
884// } : vector<8xf16> to vector<4xf16>
885// %2 = vector.bitcast %1 : vector<4xf16> to vector<2xf32>
886// %4 = vector.insert_strided_slice %2, %cst {
887// offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
888// %5 = vector.extract_strided_slice %0 {
889// offsets = [4], sizes = [4], strides = [1]
890// } : vector<8xf16> to vector<4xf16>
891// %6 = vector.bitcast %5 : vector<4xf16> to vector<2xf32>
892// %7 = vector.insert_strided_slice %6, %cst {
893// offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
894struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
895 using Base::Base;
896
897public:
898 BreakDownVectorBitCast(MLIRContext *context,
899 std::function<bool(vector::BitCastOp)> controlFn,
900 PatternBenefit benefit)
901 : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}
902
903 LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
904 PatternRewriter &rewriter) const override {
905
906 if (controlFn && !controlFn(bitcastOp))
907 return failure();
908
909 VectorType castSrcType = bitcastOp.getSourceVectorType();
910 VectorType castDstType = bitcastOp.getResultVectorType();
911 assert(castSrcType.getRank() == castDstType.getRank());
912
913 // This transformation builds on top of
914 // vector.{extract|insert}_strided_slice, which do not support
915 // extracting/inserting "scalable sub-vectors". Bail out.
916 if (castSrcType.isScalable())
917 return rewriter.notifyMatchFailure(bitcastOp,
918 "Scalable vectors are not supported");
919
920 // Only support rank 1 case for now.
921 if (castSrcType.getRank() != 1)
922 return failure();
923
924 int64_t castSrcLastDim = castSrcType.getShape().back();
925 int64_t castDstLastDim = castDstType.getShape().back();
926 // Require casting to fewer elements for now; other cases to be implemented.
927 if (castSrcLastDim < castDstLastDim)
928 return failure();
929
930 assert(castSrcLastDim % castDstLastDim == 0);
931 int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
932 // Nothing to do if it is already bitcasting to a single element.
933 if (castSrcLastDim == shrinkRatio)
934 return failure();
935
936 Location loc = bitcastOp.getLoc();
937 Type elemType = castDstType.getElementType();
938 assert(elemType.isSignlessIntOrIndexOrFloat());
939
940 Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
941 rewriter.getZeroAttr(elemType));
942 Value res = BroadcastOp::create(rewriter, loc, castDstType, zero);
943
944 SmallVector<int64_t> sliceShape = {castDstLastDim};
945 SmallVector<int64_t> strides = {1};
946 VectorType newCastDstType =
947 VectorType::get(SmallVector<int64_t>{castDstLastDim / shrinkRatio},
948 castDstType.getElementType());
949
950 for (int i = 0, e = shrinkRatio; i < e; ++i) {
951 Value extracted = ExtractStridedSliceOp::create(
952 rewriter, loc, bitcastOp.getSource(),
953 ArrayRef<int64_t>{i * castDstLastDim}, sliceShape, strides);
954 Value bitcast =
955 BitCastOp::create(rewriter, loc, newCastDstType, extracted);
956 res = InsertStridedSliceOp::create(
957 rewriter, loc, bitcast, res,
958 ArrayRef<int64_t>{i * castDstLastDim / shrinkRatio}, strides);
959 }
960 rewriter.replaceOp(bitcastOp, res);
961 return success();
962 }
963
964private:
965 std::function<bool(BitCastOp)> controlFn;
966};
967
968static bool haveSameShapeAndScaling(Type t, Type u) {
969 auto tVec = dyn_cast<VectorType>(t);
970 auto uVec = dyn_cast<VectorType>(u);
971 if (!tVec) {
972 return !uVec;
973 }
974 if (!uVec) {
975 return false;
976 }
977 return tVec.getShape() == uVec.getShape() &&
978 tVec.getScalableDims() == uVec.getScalableDims();
979}
980
981/// If `type` is shaped, clone it with `newElementType`. Otherwise,
982/// return `newElementType`.
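/// For example, cloneOrReplace(vector<4xf32>, i32) yields vector<4xi32>,
/// whereas cloneOrReplace(f32, i32) yields i32.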
983static Type cloneOrReplace(Type type, Type newElementType) {
984 if (auto shapedType = dyn_cast<ShapedType>(type)) {
985 return shapedType.clone(newElementType);
986 }
987 return newElementType;
988}
989
990/// If `value` is the result of a broadcast operation, return the input
991/// of the broadcast operation.
992static Value getBroadcastLikeSource(Value value) {
993
994 Operation *op = value.getDefiningOp();
995 if (!op)
996 return {};
997
998 if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
999 return broadcast.getSource();
1000
1001 return {};
1002}
1003
1004/// Reorders elementwise(broadcast) to broadcast(elementwise).
1005///
1006/// Example:
1007/// ```
1008/// %a = vector.broadcast %arg1 : index to vector<1x4xindex>
1009/// %b = vector.broadcast %arg2 : index to vector<1x4xindex>
1010/// %r = arith.addi %a, %b : vector<1x4xindex>
1011/// ```
1012/// Gets converted to:
1013/// ```
1014/// %r = arith.addi %arg1, %arg2 : index
1015/// %b = vector.broadcast %r : index to vector<1x4xindex>
1016/// ```
1017struct ReorderElementwiseOpsOnBroadcast final
1018 : public OpTraitRewritePattern<OpTrait::Elementwise> {
1019 using OpTraitRewritePattern::OpTraitRewritePattern;
1020 LogicalResult matchAndRewrite(Operation *op,
1021 PatternRewriter &rewriter) const override {
1022 if (op->getNumResults() != 1)
1023 return failure();
1024 auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
1025 if (!resultType)
1026 return failure();
1027 if (!OpTrait::hasElementwiseMappableTraits(op))
1028 return rewriter.notifyMatchFailure(
1029 op, "Op doesn't have ElementwiseMappableTraits");
1030 if (op->getNumOperands() == 0)
1031 return failure();
1032 if (isa<vector::FMAOp>(op)) {
1033 return rewriter.notifyMatchFailure(
1034 op,
1035 "Op only accepts vector types - not supported as broadcast source "
1036 "might be a scalar");
1037 }
1038
1039 Type resultElemType = resultType.getElementType();
1040
1041 // Find the broadcast source of the first non-constant operand.
1042 Value broadcastSource;
1043 for (Value operand : op->getOperands()) {
1044 Operation *definingOp = operand.getDefiningOp();
1045 if (!definingOp)
1046 return failure();
1047 if (definingOp->hasTrait<OpTrait::ConstantLike>())
1048 continue;
1049 broadcastSource = getBroadcastLikeSource(operand);
1050 break;
1051 }
1052 if (!broadcastSource)
1053 return failure();
1054 Type unbroadcastResultType =
1055 cloneOrReplace(broadcastSource.getType(), resultElemType);
1056
1057 // Make sure that all operands are broadcast from identically-shaped types:
1058 // * scalar (`vector.broadcast`), or
1059 // * vector (`vector.broadcast`).
1060 // Otherwise the re-ordering wouldn't be safe.
1061 if (!llvm::all_of(op->getOperands(), [broadcastSource](Value val) {
1062 if (auto source = getBroadcastLikeSource(val))
1063 return haveSameShapeAndScaling(source.getType(),
1064 broadcastSource.getType());
1065 SplatElementsAttr splatConst;
1066 return matchPattern(val, m_Constant(&splatConst));
1067 })) {
1068 return rewriter.notifyMatchFailure(
1069 op,
1070 "not all operands are constants or broadcasts from the same type");
1071 }
1072
1073 // Collect the source values before broadcasting
1074 SmallVector<Value> srcValues;
1075 srcValues.reserve(op->getNumOperands());
1076 for (Value operand : op->getOperands()) {
1077 SplatElementsAttr splatConst;
1078 if (matchPattern(operand, m_Constant(&splatConst))) {
1079 Attribute newConst;
1080 Type elementType = getElementTypeOrSelf(operand.getType());
1081 Type newType = cloneOrReplace(unbroadcastResultType, elementType);
1082 if (auto newTypeShaped = dyn_cast<ShapedType>(newType)) {
1083 newConst = splatConst.resizeSplat(newTypeShaped);
1084 } else {
1085 newConst = splatConst.getSplatValue<Attribute>();
1086 }
1087 Operation *newConstOp =
1088 operand.getDefiningOp()->getDialect()->materializeConstant(
1089 rewriter, newConst, newType, operand.getLoc());
1090 srcValues.push_back(newConstOp->getResult(0));
1091 } else {
1092 srcValues.push_back(operand.getDefiningOp()->getOperand(0));
1093 }
1094 }
1095
1096 // Create the "elementwise" Op
1097 Operation *elementwiseOp =
1098 rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
1099 unbroadcastResultType, op->getAttrs());
1100
1101 // Replace the original Op with the elementwise Op
1102 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
1103 op, resultType, elementwiseOp->getResults());
1104
1105 return success();
1106 }
1107};
1108
1109/// Pattern to rewrite an ExtractOp(Elementwise) -> Elementwise(ExtractOp).
1110/// This may result in cleaner code when extracting a single value
1111/// from a multi-element vector, and also helps canonicalize 1-element vectors
1112/// to scalars.
1113///
1114/// Example:
1115/// ```
1116/// %0 = arith.addf %arg0, %arg1 : vector<4xf32>
1117/// %1 = vector.extract %0[1] : f32 from vector<4xf32>
1118/// ```
1119/// Gets converted to:
1120/// ```
1121/// %0 = vector.extract %arg0[1] : f32 from vector<4xf32>
1122/// %1 = vector.extract %arg1[1] : f32 from vector<4xf32>
1123/// %2 = arith.addf %0, %1 : f32
1124/// ```
1125class ExtractOpFromElementwise final
1126 : public OpRewritePattern<vector::ExtractOp> {
1127public:
1128 using Base::Base;
1129
1130 LogicalResult matchAndRewrite(vector::ExtractOp op,
1131 PatternRewriter &rewriter) const override {
1132 Operation *eltwise = op.getSource().getDefiningOp();
1133
1134 // TODO: vector::FMAOp is not ElementwiseMappable even if it claims to be,
1135 // as it doesn't support scalars.
1136 if (!eltwise || !OpTrait::hasElementwiseMappableTraits(eltwise) ||
1137 isa<vector::FMAOp>(eltwise))
1138 return rewriter.notifyMatchFailure(op, "not an elementwise op");
1139
1140 if (eltwise->getNumResults() != 1)
1141 return rewriter.notifyMatchFailure(op, "expected single result");
1142
1143 if (!eltwise->hasOneUse())
1144 return rewriter.notifyMatchFailure(op, "expected single op use");
1145
1146 if (!llvm::all_equal(eltwise->getOperandTypes()))
1147 return rewriter.notifyMatchFailure(op, "operand types are different");
1148
1149 // Dynamic position can cause dominance issues, so conservatively fail for
1150 // now.
1151 if (!op.getDynamicPosition().empty())
1152 return rewriter.notifyMatchFailure(
1153 op, "dynamic position not yet implemented");
1154
1155 Type dstType = op.getType();
1156
1157 OpBuilder::InsertionGuard g(rewriter);
1158 rewriter.setInsertionPoint(eltwise);
1159
1160 IRMapping mapping;
1161 Location loc = eltwise->getLoc();
1162 SmallVector<OpFoldResult> pos = op.getMixedPosition();
1163 for (Value arg : eltwise->getOperands()) {
1164 Value newArg = vector::ExtractOp::create(rewriter, loc, arg, pos);
1165 mapping.map(arg, newArg);
1166 }
1167
1168 Operation *newEltwise = rewriter.clone(*eltwise, mapping);
1169 newEltwise->getResult(0).setType(dstType);
1170
1171 rewriter.replaceOp(op, newEltwise);
1172 rewriter.eraseOp(eltwise);
1173 return success();
1174 }
1175};
1176
1177/// Check if the element type is suitable for vector.load/store sinking.
1178/// Element type must be index or byte-aligned integer or floating-point type.
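/// For example, index, i8, i32 and f32 are supported, whereas i1 or i4 are
/// rejected because their bit width is not a multiple of 8.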
1179static bool isSupportedMemSinkElementType(Type type) {
1180 if (isa<IndexType>(type))
1181 return true;
1182
1183 return type.isIntOrFloat() && type.getIntOrFloatBitWidth() % 8 == 0;
1184}
1185
1186/// Pattern to rewrite `vector.extract(vector.load)` -> vector/memref.load.
1187/// Only index and byte-aligned integer and floating-point element types are
1188/// supported for now.
1189///
1190/// Example:
1191/// ```
1192/// vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
1193/// vector.extract %0[1] : f32 from vector<4xf32>
1194/// ```
1195/// Gets converted to:
1196/// ```
1197/// %c1 = arith.constant 1 : index
1198/// %0 = arith.addi %arg1, %c1 overflow<nsw> : index
1199/// %1 = memref.load %arg0[%0] : memref<?xf32>
1200/// ```
1201class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
1202public:
1203 using Base::Base;
1204
1205 LogicalResult matchAndRewrite(vector::ExtractOp op,
1206 PatternRewriter &rewriter) const override {
1207 auto loadOp = op.getSource().getDefiningOp<vector::LoadOp>();
1208 if (!loadOp)
1209 return rewriter.notifyMatchFailure(op, "expected a load op");
1210
1211 // Checking for single use so we won't duplicate load ops.
1212 if (!loadOp->hasOneUse())
1213 return rewriter.notifyMatchFailure(op, "expected single op use");
1214
1215 VectorType loadVecType = loadOp.getVectorType();
1216 if (loadVecType.isScalable())
1217 return rewriter.notifyMatchFailure(op,
1218 "scalable vectors are not supported");
1219
1220 MemRefType memType = loadOp.getMemRefType();
1221
1222 // Non-byte-aligned types are tricky and may require special handling,
1223 // ignore them for now.
1224 if (!isSupportedMemSinkElementType(memType.getElementType()))
1225 return rewriter.notifyMatchFailure(op, "unsupported element type");
1226
1227 int64_t rankOffset = memType.getRank() - loadVecType.getRank();
1228 if (rankOffset < 0)
1229 return rewriter.notifyMatchFailure(op, "unsupported ranks combination");
1230
1231 auto extractVecType = dyn_cast<VectorType>(op.getResult().getType());
1232 int64_t finalRank = 0;
1233 if (extractVecType)
1234 finalRank = extractVecType.getRank();
1235
1236 SmallVector<Value> indices = loadOp.getIndices();
1237 SmallVector<OpFoldResult> extractPos = op.getMixedPosition();
1238
1239 // There may be memory stores between the load and the extract op, so we
1240 // need to make sure that the new load op is inserted at the same place as
1241 // the original load op.
1242 OpBuilder::InsertionGuard g(rewriter);
1243 rewriter.setInsertionPoint(loadOp);
1244 Location loc = loadOp.getLoc();
1245 ArithIndexingBuilder idxBuilderf(rewriter, loc);
1246 for (auto i : llvm::seq<int64_t>(rankOffset, indices.size() - finalRank)) {
1247 OpFoldResult pos = extractPos[i - rankOffset];
1248 if (isZeroInteger(pos))
1249 continue;
1250
1251 Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos);
1252 indices[i] = idxBuilderf.add(indices[i], offset);
1253 }
1254
1255 Value base = loadOp.getBase();
1256 if (extractVecType) {
1257 rewriter.replaceOpWithNewOp<vector::LoadOp>(op, extractVecType, base,
1258 indices);
1259 } else {
1260 rewriter.replaceOpWithNewOp<memref::LoadOp>(op, base, indices);
1261 }
1262 // We checked for single use so we can safely erase the load op.
1263 rewriter.eraseOp(loadOp);
1264 return success();
1265 }
1266};
1267
1268/// Pattern to rewrite vector.store(vector.broadcast) -> vector/memref.store.
1269///
1270/// Example:
1271/// ```
1272/// %0 = vector.broadcast %arg2 : f32 to vector<1xf32>
1273/// vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
1274/// ```
1275/// Gets converted to:
1276/// ```
1277/// memref.store %arg2, %arg0[%arg1] : memref<?xf32>
1278/// ```
1279class StoreOpFromBroadcast final : public OpRewritePattern<vector::StoreOp> {
1280public:
1281 using Base::Base;
1282
1283 LogicalResult matchAndRewrite(vector::StoreOp op,
1284 PatternRewriter &rewriter) const override {
1285 VectorType vecType = op.getVectorType();
1286 if (vecType.isScalable())
1287 return rewriter.notifyMatchFailure(op,
1288 "scalable vectors are not supported");
1289
1290 if (isa<VectorType>(op.getMemRefType().getElementType()))
1291 return rewriter.notifyMatchFailure(
1292 op, "memrefs of vectors are not supported");
1293
1294 if (vecType.getNumElements() != 1)
1295 return rewriter.notifyMatchFailure(
1296 op, "only 1-element vectors are supported");
1297
1298 Value toStore = op.getValueToStore();
1299 Value source = getBroadcastLikeSource(toStore);
1300 if (!source)
1301 return rewriter.notifyMatchFailure(
1302 op, "value to store is not from a broadcast");
1303
1304 // Checking for single use so we can remove broadcast.
1305 Operation *broadcast = toStore.getDefiningOp();
1306 if (!broadcast->hasOneUse())
1307 return rewriter.notifyMatchFailure(op, "expected single op use");
1308
1309 Value base = op.getBase();
1310 ValueRange indices = op.getIndices();
1311
1312 if (isa<VectorType>(source.getType())) {
1313 rewriter.replaceOpWithNewOp<vector::StoreOp>(op, source, base, indices);
1314 } else {
1315 rewriter.replaceOpWithNewOp<memref::StoreOp>(op, source, base, indices);
1316 }
1317 rewriter.eraseOp(broadcast);
1318 return success();
1319 }
1320};
1321
1322// Helper that returns a vector comparison that constructs a mask:
1323// mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
1324//
1325// If `dim == 0` then the result will be a 0-D vector.
1326//
1327// NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
1328// much more compact, IR for this operation, but LLVM eventually
1329// generates more elaborate instructions for this intrinsic since it
1330// is very conservative on the boundary conditions.
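//
// A rough sketch of the generated IR for `dim == 4` with 32-bit indices and
// an offset (value names are illustrative only):
//   %cst = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
//   %ov  = vector.broadcast %off : i32 to vector<4xi32>
//   %idx = arith.addi %ov, %cst : vector<4xi32>
//   %bnd = vector.broadcast %b : i32 to vector<4xi32>
//   %msk = arith.cmpi slt, %idx, %bnd : vector<4xi32>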
1331static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
1332 bool force32BitVectorIndices, int64_t dim,
1333 Value b, Value *off = nullptr) {
1334 auto loc = op->getLoc();
1335 // If we can assume all indices fit in 32-bit, we perform the vector
1336 // comparison in 32-bit to get a higher degree of SIMD parallelism.
1337 // Otherwise we perform the vector comparison using 64-bit indices.
1338 Type idxType =
1339 force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
1340 DenseIntElementsAttr indicesAttr;
1341 if (dim == 0 && force32BitVectorIndices) {
1342 indicesAttr = DenseIntElementsAttr::get(
1343 VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0});
1344 } else if (dim == 0) {
1345 indicesAttr = DenseIntElementsAttr::get(
1346 VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0});
1347 } else if (force32BitVectorIndices) {
1348 indicesAttr = rewriter.getI32VectorAttr(
1349 llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
1350 } else {
1351 indicesAttr = rewriter.getI64VectorAttr(
1352 llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
1353 }
1354 Value indices = arith::ConstantOp::create(rewriter, loc, indicesAttr);
1355 // Add in an offset if requested.
1356 if (off) {
1357 Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
1358 Value ov = vector::BroadcastOp::create(rewriter, loc, indices.getType(), o);
1359 indices = arith::AddIOp::create(rewriter, loc, ov, indices);
1360 }
1361 // Construct the vector comparison.
1362 Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
1363 Value bounds =
1364 vector::BroadcastOp::create(rewriter, loc, indices.getType(), bound);
1365 return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
1366 indices, bounds);
1367}
1368
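/// Materializes an explicit mask for a 1-D vector transfer that has an
/// out-of-bounds dimension. A rough sketch of the rewrite (operand names are
/// illustrative only):
/// ```
/// %v = vector.transfer_read %mem[%i], %pad : memref<?xf32>, vector<8xf32>
/// ```
/// becomes, approximately:
/// ```
/// %d = memref.dim %mem, %c0 : memref<?xf32>
/// %b = arith.subi %d, %i : index
/// %m = vector.create_mask %b : vector<8xi1>
/// %v = vector.transfer_read %mem[%i], %pad, %m {in_bounds = [true]}
///   : memref<?xf32>, vector<8xf32>
/// ```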
1369template <typename ConcreteOp>
1370struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
1371public:
1372 explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt,
1373 PatternBenefit benefit = 1)
1374 : mlir::OpRewritePattern<ConcreteOp>(context, benefit),
1375 force32BitVectorIndices(enableIndexOpt) {}
1376
1377 LogicalResult matchAndRewrite(ConcreteOp xferOp,
1378 PatternRewriter &rewriter) const override {
1379 if (!xferOp.hasOutOfBoundsDim())
1380 return failure();
1381
1382 if (xferOp.getVectorType().getRank() > 1 || xferOp.getIndices().empty())
1383 return failure();
1384
1385 Location loc = xferOp->getLoc();
1386 VectorType vtp = xferOp.getVectorType();
1387
1388 // Create the in-bounds mask with all elements between [0 .. dim - offset)
1389 // set and [dim - offset .. vector_length) unset.
1390 //
1391 // TODO: when the leaf transfer rank is k > 1, we need the last `k`
1392 // dimensions here.
1393 unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1;
1394 Value off = xferOp.getIndices()[lastIndex];
1395 Value dim =
1396 vector::createOrFoldDimOp(rewriter, loc, xferOp.getBase(), lastIndex);
1397 Value b = arith::SubIOp::create(rewriter, loc, dim.getType(), dim, off);
1398 Value mask = vector::CreateMaskOp::create(
1399 rewriter, loc,
1400 VectorType::get(vtp.getShape(), rewriter.getI1Type(),
1401 vtp.getScalableDims()),
1402 b);
1403 if (xferOp.getMask()) {
1404 // Intersect the in-bounds with the mask specified as an op parameter.
1405 mask = arith::AndIOp::create(rewriter, loc, mask, xferOp.getMask());
1406 }
1407
1408 rewriter.modifyOpInPlace(xferOp, [&]() {
1409 xferOp.getMaskMutable().assign(mask);
1410 xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
1411 });
1412
1413 return success();
1414 }
1415
1416private:
1417 const bool force32BitVectorIndices;
1418};
1419
1420/// Conversion pattern for a `vector.create_mask` (0-D and 1-D only).
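/// A rough sketch for the 1-D case when 32-bit indices are forced (the cast
/// of the index-typed bound `%b` to an i32 value `%b_i32` is omitted):
/// ```
/// %m = vector.create_mask %b : vector<4xi1>
/// ```
/// becomes, approximately:
/// ```
/// %cst = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
/// %bnd = vector.broadcast %b_i32 : i32 to vector<4xi32>
/// %m   = arith.cmpi slt, %cst, %bnd : vector<4xi32>
/// ```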
1421class VectorCreateMaskOpConversion
1422 : public OpRewritePattern<vector::CreateMaskOp> {
1423public:
1424 explicit VectorCreateMaskOpConversion(MLIRContext *context,
1425 bool enableIndexOpt,
1426 PatternBenefit benefit = 1)
1427 : mlir::OpRewritePattern<vector::CreateMaskOp>(context, benefit),
1428 force32BitVectorIndices(enableIndexOpt) {}
1429
1430 LogicalResult matchAndRewrite(vector::CreateMaskOp op,
1431 PatternRewriter &rewriter) const override {
1432 auto dstType = op.getType();
1433 if (cast<VectorType>(dstType).isScalable())
1434 return failure();
1435 int64_t rank = dstType.getRank();
1436 if (rank > 1)
1437 return failure();
1438 rewriter.replaceOp(
1439 op, buildVectorComparison(rewriter, op, force32BitVectorIndices,
1440 rank == 0 ? 0 : dstType.getDimSize(0),
1441 op.getOperand(0)));
1442 return success();
1443 }
1444
1445private:
1446 const bool force32BitVectorIndices;
1447};
1448
1449/// Returns true if all the `i1` elements of `constantOp` are set to `value`.
1450static bool allI1ConstantValuesSetTo(arith::ConstantOp constantOp, bool value) {
1451 auto denseAttr = dyn_cast<DenseIntElementsAttr>(constantOp.getValue());
1452 // TODO: Support non-dense constant.
1453 if (!denseAttr)
1454 return false;
1455
1456 assert(denseAttr.getElementType().isInteger(1) && "Unexpected type");
1457 return denseAttr.isSplat() && denseAttr.getSplatValue<bool>() == value;
1458}
1459
1460/// Folds a select operation between an all-true and all-false vector. For now,
1461/// only single element vectors (i.e., vector<1xi1>) are supported. That is:
1462///
1463/// %true = arith.constant dense<true> : vector<1xi1>
1464/// %false = arith.constant dense<false> : vector<1xi1>
1465/// %result = arith.select %cond, %true, %false : i1, vector<1xi1>
1466/// =>
1467/// %result = vector.broadcast %cond : i1 to vector<1xi1>
1468///
1469/// InstCombine seems to handle vectors with multiple elements but not the
1470/// single element ones.
1471struct FoldI1Select : public OpRewritePattern<arith::SelectOp> {
1472 using Base::Base;
1473
1474 LogicalResult matchAndRewrite(arith::SelectOp selectOp,
1475 PatternRewriter &rewriter) const override {
1476 auto vecType = dyn_cast<VectorType>(selectOp.getType());
1477 if (!vecType || !vecType.getElementType().isInteger(1))
1478 return failure();
1479
1480 // Only scalar conditions can be folded.
1481 Value cond = selectOp.getCondition();
1482 if (isa<VectorType>(cond.getType()))
1483 return failure();
1484
1485 // TODO: Support n-D and scalable vectors.
1486 if (vecType.getRank() != 1 || vecType.isScalable())
1487 return failure();
1488
1489 // TODO: Support vectors with multiple elements.
1490 if (vecType.getShape()[0] != 1)
1491 return failure();
1492
1493 auto trueConst = selectOp.getTrueValue().getDefiningOp<arith::ConstantOp>();
1494 if (!trueConst || !allI1ConstantValuesSetTo(trueConst, true))
1495 return failure();
1496
1497 auto falseConst =
1498 selectOp.getFalseValue().getDefiningOp<arith::ConstantOp>();
1499 if (!falseConst || !allI1ConstantValuesSetTo(falseConst, false))
1500 return failure();
1501
1502 // Replace select with its condition broadcasted to single element vector.
1503 auto elemType = rewriter.getIntegerType(vecType.getNumElements());
1504 auto bcastType = VectorType::get(/*shape=*/{1}, elemType);
1505 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(selectOp, bcastType, cond);
1506 return success();
1507 }
1508};
1509
1510 /// Returns the number of dims that can be folded away from transfer ops. It
1511 /// returns a failure if it cannot determine the number of dims to be folded.
1512 ///
1513 /// Ex 1: returns "2" if `srcType` is memref<512x16x1x1xf32> and
1514 ///       `vectorType` is vector<16x16x1x1xf32>
1515 ///       (the two innermost dims can be dropped by memref.subview ops)
1516 ///
1517 /// Ex 2: returns "1" if `srcType` is memref<512x16x1x1xf32> with
1518 ///       [8192, 16, 8, 1] strides and `vectorType` is vector<16x16x1x1xf32>
1519 ///       (only the innermost unit dim of `srcType` can be dropped)
1520 ///
1521 /// Ex 3: returns "0" if `srcType` is memref<512x16x1x1xf32> and
1522 ///       `vectorType` is vector<16x16x1x[1]xf32>
1523 ///       (the innermost dim in `vectorType` is not a unit dim; it is a
1524 ///       "scalable unit" dim).
1525static FailureOr<size_t>
1526getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
1527 SmallVector<int64_t> srcStrides;
1528 int64_t srcOffset;
1529 if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
1530 return failure();
1531
1532 auto isUnitDim = [](VectorType type, int dim) {
1533 return type.getDimSize(dim) == 1 && !type.getScalableDims()[dim];
1534 };
1535
1536 // According to vector.transfer_read/write semantics, the vector can be a
1537 // slice. Thus, we have to offset the check index with `rankDiff` in
1538 // `srcStrides` and source dim sizes.
1539 size_t result = 0;
1540 int rankDiff = srcType.getRank() - vectorType.getRank();
1541 for (int64_t i = 0, e = vectorType.getRank(); i < e; ++i) {
1542 // Check that the inner dim size is 1 for both memref type and vector slice.
1543 // It can be folded only if they are 1 and the stride is 1.
1544 int dim = vectorType.getRank() - i - 1;
1545 if (srcStrides[dim + rankDiff] != 1 ||
1546 srcType.getDimSize(dim + rankDiff) != 1 || !isUnitDim(vectorType, dim))
1547 break;
1548 result++;
1549 }
1550 return result;
1551}
1552
1553 /// Drop the innermost contiguous unit dimensions from the transfer_read operand.
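/// Hypothetical example added for illustration (it mirrors the transfer_write
/// example documented below; names and shapes are made up):
///   %1 = vector.transfer_read %arg0[%c0, %arg1, %c0, %c0], %cst
///     {in_bounds = [true, true, true]}
///     : memref<1x512x16x1xf32>, vector<16x16x1xf32>
/// would be rewritten to roughly:
///   %subview = memref.subview %arg0[0, 0, 0, 0] [1, 512, 16, 1] [1, 1, 1, 1]
///     : memref<1x512x16x1xf32> to memref<1x512x16xf32>
///   %0 = vector.transfer_read %subview[%c0, %arg1, %c0], %cst
///     {in_bounds = [true, true]} : memref<1x512x16xf32>, vector<16x16xf32>
///   %1 = vector.shape_cast %0 : vector<16x16xf32> to vector<16x16x1xf32>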
1554 class DropInnerMostUnitDimsTransferRead
1555 : public OpRewritePattern<vector::TransferReadOp> {
1556 using Base::Base;
1557
1558 LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
1559 PatternRewriter &rewriter) const override {
1560 // TODO: support 0-d corner case.
1561 if (readOp.getTransferRank() == 0)
1562 return failure();
1563
1564 // TODO: support mask.
1565 if (readOp.getMask())
1566 return failure();
1567
1568 auto srcType = dyn_cast<MemRefType>(readOp.getBase().getType());
1569 if (!srcType)
1570 return failure();
1571
1572 if (!readOp.getPermutationMap().isMinorIdentity())
1573 return failure();
1574
1575 auto targetType = readOp.getVectorType();
1576 if (targetType.getRank() <= 1)
1577 return failure();
1578
1579 FailureOr<size_t> maybeDimsToDrop =
1580 getTransferFoldableInnerUnitDims(srcType, targetType);
1581 if (failed(maybeDimsToDrop))
1582 return failure();
1583
1584 size_t dimsToDrop = maybeDimsToDrop.value();
1585 if (dimsToDrop == 0)
1586 return failure();
1587
1588 auto inBounds = readOp.getInBoundsValues();
1589 auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
1590 if (llvm::is_contained(droppedInBounds, false))
1591 return failure();
1592
1593 auto resultTargetVecType =
1594 VectorType::get(targetType.getShape().drop_back(dimsToDrop),
1595 targetType.getElementType(),
1596 targetType.getScalableDims().drop_back(dimsToDrop));
1597
1598 auto loc = readOp.getLoc();
1599 SmallVector<OpFoldResult> sizes =
1600 memref::getMixedSizes(rewriter, loc, readOp.getBase());
1601 SmallVector<OpFoldResult> offsets(srcType.getRank(),
1602 rewriter.getIndexAttr(0));
1603 SmallVector<OpFoldResult> strides(srcType.getRank(),
1604 rewriter.getIndexAttr(1));
1605 MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
1606 srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
1607 strides);
1608 ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
1609 readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
1610 Value rankedReducedView =
1611 memref::SubViewOp::create(rewriter, loc, resultMemrefType,
1612 readOp.getBase(), offsets, sizes, strides);
1613 auto permMap = getTransferMinorIdentityMap(
1614 cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
1615 Value result = vector::TransferReadOp::create(
1616 rewriter, loc, resultTargetVecType, rankedReducedView,
1617 readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
1618 readOp.getPadding(),
1619 // TODO: support mask.
1620 /*mask=*/Value(), inBoundsAttr);
1621 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
1622 result);
1623 return success();
1624 }
1625};
1626
1627 /// Drop the innermost contiguous unit dimensions from the transfer_write operand.
1628/// E.g.,
1629/// vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
1630/// {in_bounds = [true, true, true, true, true]}
1631/// : vector<1x16x16x1x1xf32>, memref<1x512x16x1x1xf32>
1632///
1633/// will be replaced with
1634///
1635/// %subview = memref.subview %arg0
1636/// [0, 0, 0, 0, 0] [1, 512, 16, 1, 1] [1, 1, 1, 1, 1]
1637/// : memref<1x512x16x1x1xf32> to memref<1x512x16xf32>
1638/// %0 = vector.shape_cast %arg1 : vector<1x16x16x1x1xf32>
1639/// to vector<1x16x16xf32>
1640/// vector.transfer_write %0, %subview[%c0, %arg2, %c0]
1641/// {in_bounds = [true, true, true]}
1642/// : vector<1x16x16xf32>, memref<1x512x16xf32>
1643///
1644/// Note, this pattern will not collapse "scalable unit" dims (i.e. `[1]`).
1645 class DropInnerMostUnitDimsTransferWrite
1646 : public OpRewritePattern<vector::TransferWriteOp> {
1647 using Base::Base;
1648
1649 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
1650 PatternRewriter &rewriter) const override {
1651 // TODO: support 0-d corner case.
1652 if (writeOp.getTransferRank() == 0)
1653 return failure();
1654
1655 // TODO: support mask.
1656 if (writeOp.getMask())
1657 return failure();
1658
1659 auto srcType = dyn_cast<MemRefType>(writeOp.getBase().getType());
1660 if (!srcType)
1661 return failure();
1662
1663 if (!writeOp.getPermutationMap().isMinorIdentity())
1664 return failure();
1665
1666 auto targetType = writeOp.getVectorType();
1667 if (targetType.getRank() <= 1)
1668 return failure();
1669
1670 FailureOr<size_t> maybeDimsToDrop =
1671 getTransferFoldableInnerUnitDims(srcType, targetType);
1672 if (failed(maybeDimsToDrop))
1673 return failure();
1674
1675 size_t dimsToDrop = maybeDimsToDrop.value();
1676 if (dimsToDrop == 0)
1677 return failure();
1678
1679 auto inBounds = writeOp.getInBoundsValues();
1680 auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
1681 if (llvm::is_contained(droppedInBounds, false))
1682 return failure();
1683
1684 auto resultTargetVecType =
1685 VectorType::get(targetType.getShape().drop_back(dimsToDrop),
1686 targetType.getElementType(),
1687 targetType.getScalableDims().drop_back(dimsToDrop));
1688
1689 Location loc = writeOp.getLoc();
1690 SmallVector<OpFoldResult> sizes =
1691 memref::getMixedSizes(rewriter, loc, writeOp.getBase());
1692 SmallVector<OpFoldResult> offsets(srcType.getRank(),
1693 rewriter.getIndexAttr(0));
1694 SmallVector<OpFoldResult> strides(srcType.getRank(),
1695 rewriter.getIndexAttr(1));
1696 MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
1697 srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
1698 strides);
1699 ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
1700 writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
1701
1702 Value rankedReducedView =
1703 memref::SubViewOp::create(rewriter, loc, resultMemrefType,
1704 writeOp.getBase(), offsets, sizes, strides);
1705 auto permMap = getTransferMinorIdentityMap(
1706 cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
1707
1708 auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
1709 loc, resultTargetVecType, writeOp.getVector());
1710 rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
1711 writeOp, shapeCast, rankedReducedView,
1712 writeOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
1713 // TODO: support mask.
1714 /*mask=*/Value(), inBoundsAttr);
1715 return success();
1716 }
1717};
1718
1719/// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul
1720 /// semantics to a contraction suitable for MMT (matrix-matrix multiplication
1721/// with the RHS transposed) lowering.
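/// For illustration (a sketch, not from the original source): a plain
/// row-major matmul with indexing maps
///   [(m, n, k) -> (m, k), (m, n, k) -> (k, n), (m, n, k) -> (m, n)]
/// has its RHS transposed with vector.transpose and is rewritten to use the
/// canonical "MMT" maps
///   [(m, n, k) -> (m, k), (m, n, k) -> (n, k), (m, n, k) -> (m, n)].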
1722 struct CanonicalizeContractMatmulToMMT
1723 : OpRewritePattern<vector::ContractionOp> {
1724 using Base::Base;
1725
1726 using FilterConstraintType =
1727 std::function<LogicalResult(vector::ContractionOp op)>;
1728
1729 CanonicalizeContractMatmulToMMT(MLIRContext *context, PatternBenefit benefit,
1730 FilterConstraintType constraint)
1731 : OpRewritePattern<vector::ContractionOp>(context, benefit),
1732 filter(std::move(constraint)) {}
1733
1734 LogicalResult matchAndRewrite(vector::ContractionOp op,
1735 PatternRewriter &rewriter) const override {
1736 if (failed(filter(op)))
1737 return failure();
1738
1739 Location loc = op.getLoc();
1740 Value lhs = op.getLhs();
1741 Value rhs = op.getRhs();
1742 Value res = op.getAcc();
1743
1744 // Set up the parallel/reduction structure in the right form.
1745 using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1746 auto infer = [&](MapList m) {
1747 return AffineMap::inferFromExprList(m, op.getContext());
1748 };
1749 AffineExpr m;
1750 AffineExpr n;
1751 AffineExpr k;
1752 bindDims(rewriter.getContext(), m, n, k);
1753 static constexpr std::array<int64_t, 2> perm = {1, 0};
1754 auto iteratorTypes = op.getIteratorTypes().getValue();
1755 SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
1756 if (iteratorTypes.size() != 3 ||
1757 !vector::isParallelIterator(iteratorTypes[0]) ||
1758 !vector::isParallelIterator(iteratorTypes[1]) ||
1759 !vector::isReductionIterator(iteratorTypes[2]))
1760 return rewriter.notifyMatchFailure(op, "contraction is not a gemm");
1761
1762 // The canonical form is "TNT" = A row-major, B col-major, C row-major.
1763 const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
1764 if (maps == canonicalForm)
1765 return rewriter.notifyMatchFailure(op, "already in the canonical form");
1766
1767 // Create a vector transpose making sure to emit zero/sign-extend at the
1768 // end.
1769 auto createTranspose = [&rewriter, loc](Value mat) -> Value {
1770 if (auto sext = mat.getDefiningOp<arith::ExtSIOp>()) {
1771 Value trans =
1772 vector::TransposeOp::create(rewriter, loc, sext.getIn(), perm);
1773 VectorType newType =
1774 cast<VectorType>(trans.getType())
1775 .clone(cast<VectorType>(mat.getType()).getElementType());
1776 return arith::ExtSIOp::create(rewriter, loc, newType, trans);
1777 }
1778 if (auto zext = mat.getDefiningOp<arith::ExtUIOp>()) {
1779 Value trans =
1780 vector::TransposeOp::create(rewriter, loc, zext.getIn(), perm);
1781 VectorType newType =
1782 VectorType::get(cast<VectorType>(trans.getType()).getShape(),
1783 cast<VectorType>(mat.getType()).getElementType());
1784 return arith::ExtUIOp::create(rewriter, loc, newType, trans);
1785 }
1786 return vector::TransposeOp::create(rewriter, loc, mat, perm);
1787 };
1788
1789 if (maps == infer({{m, k}, {k, n}, {m, n}})) {
1790 rhs = createTranspose(rhs);
1791 } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
1792 lhs = createTranspose(lhs);
1793 } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
1794 rhs = createTranspose(rhs);
1795 lhs = createTranspose(lhs);
1796 } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
1797 std::swap(rhs, lhs);
1798 rhs = createTranspose(rhs);
1799 lhs = createTranspose(lhs);
1800 } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
1801 std::swap(rhs, lhs);
1802 rhs = createTranspose(rhs);
1803 } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
1804 std::swap(lhs, rhs);
1805 lhs = createTranspose(lhs);
1806 } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
1807 std::swap(lhs, rhs);
1808 } else {
1809 return rewriter.notifyMatchFailure(op, "unhandled contraction form");
1810 }
1811 rewriter.replaceOpWithNewOp<vector::ContractionOp>(
1812 op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
1813 op.getIteratorTypes());
1814 return success();
1815 };
1816
1817private:
1818 FilterConstraintType filter;
1819};
1820
1821/// Pattern to fold arithmetic extensions on floating point data types into
1822/// vector contraction operations. linalg.matmul introduces arithmetic
1823 /// extensions on its operands. Please see the MLIR snippet below for details.
1824/// ```mlir
1825/// "linalg.matmul"(%lhs, %rhs, %acc) ({
1826/// ^bb0(%arg1: f16, %arg2: f16, %arg3: f32):
1827/// %lhs_f32 = "arith.extf"(%arg1) : (f16) -> f32
1828/// %rhs_f32 = "arith.extf"(%arg2) : (f16) -> f32
1829/// %mul = "arith.mulf"(%lhs_f32, %rhs_f32) : (f32, f32) -> f32
1830/// %acc = "arith.addf"(%arg3, %mul) : (f32, f32) -> f32
1831/// "linalg.yield"(%acc) : (f32) -> ()
1832/// })
1833/// ```
1834/// This restricts the native usage of mixed precision NVIDIA Ampere Tensor
1835 /// Cores, i.e., `mma.sync.*.f32.f16.f16.f32` and `mma.sync.*.f32.bf16.bf16.f32`.
1836/// This pattern folds the arithmetic extensions into the vector contraction and
1837/// enables the usage of native mixed precision Tensor Core instructions.
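/// At the vector level, the fold looks roughly like this (illustrative sketch;
/// shapes are placeholders and the indexing maps / iterator types are left
/// unchanged by the pattern):
///   %lhs_f32 = arith.extf %lhs : vector<4x8xf16> to vector<4x8xf32>
///   %rhs_f32 = arith.extf %rhs : vector<16x8xf16> to vector<16x8xf32>
///   %res = vector.contract {...} %lhs_f32, %rhs_f32, %acc
/// ==>
///   %res = vector.contract {...} %lhs, %rhs, %acc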
1838template <typename ExtOp>
1839 struct FoldArithExtIntoContractionOp
1840 : public OpRewritePattern<vector::ContractionOp> {
1841 using Base::Base;
1842
1843 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
1844 PatternRewriter &rewriter) const override {
1845
1846 auto lhsDefOp = contractOp.getLhs().getDefiningOp<ExtOp>();
1847 auto rhsDefOp = contractOp.getRhs().getDefiningOp<ExtOp>();
1848
1849 if (!lhsDefOp || !rhsDefOp) {
1850 return rewriter.notifyMatchFailure(contractOp,
1851 "no defining op on contract operands");
1852 }
1853
1854 rewriter.replaceOpWithNewOp<vector::ContractionOp>(
1855 contractOp, lhsDefOp->getOperand(0), rhsDefOp->getOperand(0),
1856 contractOp.getAcc(), contractOp.getIndexingMapsAttr(),
1857 contractOp.getIteratorTypesAttr());
1858
1859 return success();
1860 }
1861};
1862
1863/// Pattern to fold chained reduction to a series of vector additions and a
1864/// final reduction. This form should require fewer subgroup operations.
1865///
1866/// ```mlir
1867/// %a = vector.reduction <add> %x, %acc
1868/// %b = vector.reduction <add> %y, %a
1869/// ==>
1870/// %a = arith.addf %x, %y
1871/// %b = vector.reduction <add> %a, %acc
1872/// ```
1873struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
1874 using Base::Base;
1875
1876 LogicalResult matchAndRewrite(vector::ReductionOp op,
1877 PatternRewriter &rewriter) const override {
1878 // TODO: Handle other combining kinds.
1879 if (op.getKind() != vector::CombiningKind::ADD)
1880 return failure();
1881
1882 // Accumulator is optional.
1883 Value acc = op.getAcc();
1884 if (!acc)
1885 return failure();
1886
1887 if (!acc.getType().isIntOrFloat())
1888 return failure();
1889
1890 auto parentReduction = acc.getDefiningOp<vector::ReductionOp>();
1891 if (!parentReduction)
1892 return failure();
1893
1894 Location loc = op.getLoc();
1895 Value vAdd;
1896 if (isa<IntegerType>(acc.getType())) {
1897 vAdd = rewriter.createOrFold<arith::AddIOp>(
1898 loc, parentReduction.getVector(), op.getVector());
1899 } else {
1900 vAdd = arith::AddFOp::create(rewriter, loc, parentReduction.getVector(),
1901 op.getVector());
1902 }
1903 rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), vAdd,
1904 parentReduction.getAcc());
1905 return success();
1906 }
1907};
1908
1909 // Helper function dropping non-scalable unit dimensions from a VectorType,
1910// keeping at least 1 dimension to avoid generating 0-D vectors. Scalable unit
1911// dimensions are not dropped. Folding such dimensions would require "shifting"
1912// the scalable flag onto some other fixed-width dim (e.g. vector<[1]x4xf32> ->
1913// vector<[4]xf32>). This could be implemented in the future.
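// Illustrative examples (added for clarity): vector<1x4x1xf32> -> vector<4xf32>,
// vector<1x[4]xf32> -> vector<[4]xf32>, vector<[1]x1xf32> -> vector<[1]xf32>,
// and vector<1x1xf32> -> vector<1xf32> (at least one dim is always kept).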
1914static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
1915 auto inVecShape = inVecTy.getShape();
1916 SmallVector<int64_t> newShape;
1917 SmallVector<bool> newScalableDims;
1918 for (auto [dim, isScalable] :
1919 llvm::zip_equal(inVecShape, inVecTy.getScalableDims())) {
1920 if (dim == 1 && !isScalable)
1921 continue;
1922
1923 newShape.push_back(dim);
1924 newScalableDims.push_back(isScalable);
1925 }
1926 // All dims have been dropped, return vector<1xeType>.
1927 if (newShape.empty()) {
1928 newShape.push_back(1);
1929 newScalableDims.push_back(false);
1930 }
1931
1932 return VectorType::get(newShape, inVecTy.getElementType(), newScalableDims);
1933}
1934
1935/// For vectors with at least one unit dim, replaces:
1936/// elementwise(a, b)
1937/// with:
1938/// sc_a = shape_cast(a)
1939/// sc_b = shape_cast(b)
1940/// res = elementwise(sc_a, sc_b)
1941/// return shape_cast(res)
1942/// The newly inserted shape_cast Ops fold (before elementwise Op) and then
1943/// restore (after elementwise Op) the unit dim. Vectors `a` and `b` are
1944/// required to be rank > 1.
1945///
1946/// Ex:
1947/// %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
1948/// %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
1949///
1950/// gets converted to:
1951///
1952/// %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
1953/// %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
1954/// %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
1955/// %cast_new = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32>
1956/// %cast = vector.shape_cast %cast_new : vector<1x[4]xf32> to vector<[4]xf32>
1957///
1958/// Patterns for folding shape_casts should instantly eliminate `%cast_new` and
1959/// `%cast`.
1960 struct DropUnitDimFromElementwiseOps final
1961 : public OpTraitRewritePattern<OpTrait::Elementwise> {
1962 using OpTraitRewritePattern::OpTraitRewritePattern;
1963 LogicalResult matchAndRewrite(Operation *op,
1964 PatternRewriter &rewriter) const override {
1965 if (op->getNumResults() != 1 || op->getNumRegions() != 0)
1966 return failure();
1967
1968 auto resultVectorType = dyn_cast<VectorType>(op->getResult(0).getType());
1969 if (!resultVectorType)
1970 return failure();
1971
1972 // Check the operand pre-conditions. For `Elementwise` ops all operands are
1973 // guaranteed to have identical shapes (with some exceptions such as
1974 // `arith.select`) and it suffices to only check one of them.
1975 auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType());
1976 if (!sourceVectorType)
1977 return failure();
1978 if (sourceVectorType.getRank() < 2)
1979 return failure();
1980
1981 SmallVector<Value> newOperands;
1982 auto loc = op->getLoc();
1983 for (auto operand : op->getOperands()) {
1984 auto opVectorType = cast<VectorType>(operand.getType());
1985 auto newVType = dropNonScalableUnitDimFromType(opVectorType);
1986 if (newVType == opVectorType)
1987 return rewriter.notifyMatchFailure(op, "No unit dimension to remove.");
1988
1989 auto opSC = vector::ShapeCastOp::create(rewriter, loc, newVType, operand);
1990 newOperands.push_back(opSC);
1991 }
1992
1993 VectorType newResultVectorType =
1994 dropNonScalableUnitDimFromType(resultVectorType);
1995 // Create an updated elementwise Op without unit dim.
1996 Operation *elementwiseOp =
1997 rewriter.create(loc, op->getName().getIdentifier(), newOperands,
1998 newResultVectorType, op->getAttrs());
1999
2000 // Restore the unit dim by applying vector.shape_cast to the result.
2001 rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType,
2002 elementwiseOp->getResult(0));
2003
2004 return success();
2005 }
2006};
2007
2008/// A pattern to drop unit dims from vector.transpose.
2009///
2010/// Example:
2011///
2012/// BEFORE:
2013/// ```mlir
2014/// %transpose = vector.transpose %vector, [3, 0, 1, 2]
2015/// : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
2016/// ```
2017///
2018/// AFTER:
2019/// ```mlir
2020/// %dropDims = vector.shape_cast %vector
2021/// : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
2022/// %transpose = vector.transpose %0, [1, 0]
2023/// : vector<4x[4]xf32> to vector<[4]x4xf32>
2024/// %restoreDims = vector.shape_cast %transpose
2025/// : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
2026/// ```
2027 struct DropUnitDimsFromTransposeOp final
2028 : OpRewritePattern<vector::TransposeOp> {
2029 using Base::Base;
2030
2031 LogicalResult matchAndRewrite(vector::TransposeOp op,
2032 PatternRewriter &rewriter) const override {
2033 VectorType sourceType = op.getSourceVectorType();
2034 VectorType sourceTypeWithoutUnitDims =
2035 dropNonScalableUnitDimFromType(sourceType);
2036
2037 if (sourceType == sourceTypeWithoutUnitDims)
2038 return failure();
2039
2040 // Construct a map from dimIdx -> number of dims dropped before dimIdx.
2041 auto sourceDims = llvm::to_vector(vector::getDims(sourceType));
2042 SmallVector<int64_t> droppedDimsBefore(sourceType.getRank());
2043 int64_t droppedDims = 0;
2044 for (auto [i, dim] : llvm::enumerate(sourceDims)) {
2045 droppedDimsBefore[i] = droppedDims;
2046 if (dim == std::make_tuple(1, false))
2047 ++droppedDims;
2048 }
2049
2050 // Drop unit dims from transpose permutation.
2051 ArrayRef<int64_t> perm = op.getPermutation();
2052 SmallVector<int64_t> newPerm;
2053 for (int64_t idx : perm) {
2054 if (sourceDims[idx] == std::make_tuple(1, false))
2055 continue;
2056 newPerm.push_back(idx - droppedDimsBefore[idx]);
2057 }
2058
2059 // Fixup for `newPerm`. The `sourceTypeWithoutUnitDims` could be vector<1xT>
2060 // when all the dimensions are unit dimensions. In this case, `newPerm`
2061 // should be [0].
2062 if (newPerm.empty()) {
2063 newPerm.push_back(0);
2064 }
2065
2066 Location loc = op.getLoc();
2067 // Drop the unit dims via shape_cast.
2068 auto dropDimsShapeCast = vector::ShapeCastOp::create(
2069 rewriter, loc, sourceTypeWithoutUnitDims, op.getVector());
2070 // Create the new transpose.
2071 auto transposeWithoutUnitDims =
2072 vector::TransposeOp::create(rewriter, loc, dropDimsShapeCast, newPerm);
2073 // Restore the unit dims via shape cast.
2074 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
2075 op, op.getResultVectorType(), transposeWithoutUnitDims);
2076
2077 return success();
2078 }
2079};
2080
2081/// A pattern to drop unit dims from the iter_args of an scf.for.
2082///
2083/// Example:
2084///
2085/// BEFORE:
2086/// ```mlir
2087/// %res = scf.for ... iter_args(%iter = %init) -> vector<[4]x1x1x4xf32> {
2088/// ...
2089/// scf.yield %
2090/// }
2091/// ```
2092///
2093/// AFTER:
2094/// ```mlir
2095/// %drop = vector.shape_cast %init
2096 /// : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
2097/// %new_loop = scf.for ... iter_args(%iter = %drop) -> vector<[4]x4xf32> {
2098/// %new_iter = vector.shape_cast %iter
2099/// : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
2100/// ...
2101/// }
2102/// %res = vector.shape_cast %new_loop
2103/// : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
2104/// ```
2105struct DropUnitDimsFromScfForOp final : OpRewritePattern<scf::ForOp> {
2106 using Base::Base;
2107
2108 LogicalResult matchAndRewrite(scf::ForOp forOp,
2109 PatternRewriter &rewriter) const override {
2110 /// Find the first iter_arg with droppable unit dims. Further applications
2111 /// of this pattern will apply to later arguments.
2112 for (OpOperand &operand : forOp.getInitArgsMutable()) {
2113 auto vectorType = dyn_cast<VectorType>(operand.get().getType());
2114 if (!vectorType)
2115 continue;
2116
2117 VectorType newVectorType = dropNonScalableUnitDimFromType(vectorType);
2118 if (vectorType == newVectorType)
2119 continue;
2120
2121 // Create a new ForOp with that iter operand replaced.
2122 auto castFn = [](OpBuilder &b, Location loc, Type type, Value source) {
2123 return vector::ShapeCastOp::create(b, loc, type, source);
2124 };
2125
2126 Value replacement =
2127 castFn(rewriter, forOp.getLoc(), newVectorType, operand.get());
2128 rewriter.replaceOp(forOp,
2129 replaceAndCastForOpIterArg(rewriter, forOp, operand,
2130 replacement, castFn));
2131 return success();
2132 }
2133 return failure();
2134 }
2135};
2136
2137/// Pattern to eliminate redundant zero-constants added to reduction operands.
2138/// It's enough for there to be one initial zero value, so we can eliminate the
2139/// extra ones that feed into `vector.reduction <add>`. These get created by the
2140/// `ChainedReduction` pattern.
2141///
2142/// ```mlir
2143/// %a = arith.addf %x, %zero
2144/// %b = arith.addf %a, %y
2145/// %c = vector.reduction <add> %b, %acc
2146/// ==>
2147/// %b = arith.addf %a, %y
2148/// %c = vector.reduction <add> %b, %acc
2149/// ```
2150struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> {
2151 using Base::Base;
2152
2153 LogicalResult matchAndRewrite(vector::ReductionOp op,
2154 PatternRewriter &rewriter) const override {
2155 // TODO: Handle other reduction kinds and their identity values.
2156 if (op.getKind() != vector::CombiningKind::ADD)
2157 return failure();
2158
2159 Type elemType = op.getSourceVectorType().getElementType();
2160 // The integer case should be handled by `arith.addi` folders, only check
2161 // for floats here.
2162 if (!isa<FloatType>(elemType))
2163 return failure();
2164
2165 auto vAdd = op.getVector().getDefiningOp<arith::AddFOp>();
2166 if (!vAdd)
2167 return failure();
2168 auto addLhs = vAdd.getLhs().getDefiningOp<arith::AddFOp>();
2169 if (!addLhs)
2170 return failure();
2171
2172 if (!matchPattern(addLhs.getRhs(), m_AnyZeroFloat()))
2173 return failure();
2174
2175 auto newAdd = arith::AddFOp::create(rewriter, vAdd.getLoc(),
2176 addLhs.getLhs(), vAdd.getRhs());
2177 rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), newAdd,
2178 op.getAcc());
2179 return success();
2180 }
2181};
2182
2183/// Example:
2184/// ```
2185/// %a = vector.reduction <add> %x : vector<2xf32> into f32
2186/// ```
2187/// is transformed into:
2188/// ```
2189/// %y = vector.extract %x[0] : f32 from vector<2xf32>
2190/// %z = vector.extract %x[1] : f32 from vector<2xf32>
2191/// %a = arith.addf %y, %z : f32
2192/// ```
2193struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
2194 BreakDownVectorReduction(MLIRContext *context,
2195 unsigned maxNumElementsToExtract,
2196 PatternBenefit benefit)
2197 : OpRewritePattern(context, benefit),
2198 maxNumElementsToExtract(maxNumElementsToExtract) {}
2199
2200 LogicalResult matchAndRewrite(vector::ReductionOp op,
2201 PatternRewriter &rewriter) const override {
2202 VectorType type = op.getSourceVectorType();
2203 if (type.isScalable() || op.isMasked())
2204 return failure();
2205 assert(type.getRank() == 1 && "Expected a 1-d vector");
2206
2207 int64_t numElems = type.getNumElements();
2208 if (numElems > maxNumElementsToExtract) {
2209 return rewriter.notifyMatchFailure(
2210 op, llvm::formatv("has too many vector elements ({0}) to break down "
2211 "(max allowed: {1})",
2212 numElems, maxNumElementsToExtract));
2213 }
2214
2215 Location loc = op.getLoc();
2216 SmallVector<Value> extracted(numElems, nullptr);
2217 for (auto [idx, extractedElem] : llvm::enumerate(extracted))
2218 extractedElem = vector::ExtractOp::create(rewriter, loc, op.getVector(),
2219 static_cast<int64_t>(idx));
2220
2221 Value res = extracted.front();
2222 for (auto extractedElem : llvm::drop_begin(extracted))
2223 res = vector::makeArithReduction(rewriter, loc, op.getKind(), res,
2224 extractedElem, op.getFastmathAttr());
2225 if (Value acc = op.getAcc())
2226 res = vector::makeArithReduction(rewriter, loc, op.getKind(), res, acc,
2227 op.getFastmathAttr());
2228
2229 rewriter.replaceOp(op, res);
2230 return success();
2231 }
2232
2233private:
2234 unsigned maxNumElementsToExtract = 0;
2235};
2236
2237/// Fold `mulf(tr(broadcast(A)), broadcast(B))` into `vector.outerproduct(A,
2238/// B)`.
2239/// Example:
2240/// %lhsBcast = vector.broadcast %lhs : vector<4xi32> to vector<4x4xi32>
2241 /// %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<4x4xi32> to vector<4x4xi32>
2242 /// %rhsBcast = vector.broadcast %rhs : vector<4xi32> to vector<4x4xi32>
2243 /// %mul = arith.muli %lhsT, %rhsBcast : vector<4x4xi32>
2244///
2245/// Becomes :
2246///
2247/// %res = vector.outerproduct %lhs, %rhs : vector<4xi32>, vector<4xi32>
2248///
2249/// Supports only 1D-to-2D broadcasts. The following cases are not supported.
2250/// %ex1 = vector.broadcast %lhsCast : vector<1x4xf32> to vector<4x4xf32>
2251/// %ex2 = vector.broadcast %lhsCast : f32 to vector<4x4xf32>
2252/// %ex3 = vector.broadcast %lhsCast : vector<1x1xf32> to vector<4x4xf32>
2253template <typename MulOpType>
2254struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
2255 using OpRewritePattern<MulOpType>::OpRewritePattern;
2256 // Returns whether a vector.broadcast matches the requirements for an
2257 // outerproduct pattern, i.e., a 1D-to-2D broadcast without broadcasted unit dims.
2258 bool isValidBroadcastSource(vector::BroadcastOp broadcastOp) const {
2259 // Fail unless this is a 1-D to 2-D broadcast, to avoid generating
2260 // shape_casts/broadcasts that do not belong in this pattern.
2261 if (!broadcastOp.computeBroadcastedUnitDims().empty())
2262 return false;
2263 // Avoid broadcast like f32 or vector<f32> -> ResType
2264 auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
2265 return srcType && srcType.getRank() != 2;
2266 }
2267
2268 LogicalResult matchAndRewrite(MulOpType mulOp,
2269 PatternRewriter &rewriter) const override {
2270 auto resType = llvm::dyn_cast<VectorType>(mulOp.getResult().getType());
2271 if (!resType)
2272 return failure();
2273 if (resType.getRank() != 2)
2274 return failure();
2275 /// If operandA can be written as tr(broadcast(A)) and operandB as
2276 /// broadcast(B) where broadcasts are 1D-to-2D, create and return
2277 /// vector.outerproduct(A, B). Returns failure() otherwise.
2278 auto matchOuterProduct =
2279 [&](Value operandA,
2280 Value operandB) -> FailureOr<vector::OuterProductOp> {
2281 auto transposedLhs = operandA.getDefiningOp<vector::TransposeOp>();
2282 if (!transposedLhs)
2283 return failure();
2284 // Fail unless this is a true 2-D matrix transpose.
2285 ArrayRef<int64_t> permutation = transposedLhs.getPermutation();
2286 if (permutation.size() != 2 || permutation[0] != 1 || permutation[1] != 0)
2287 return failure();
2288
2289 auto broadcastedLhs =
2290 transposedLhs.getVector().getDefiningOp<vector::BroadcastOp>();
2291 if (!broadcastedLhs || !isValidBroadcastSource(broadcastedLhs))
2292 return failure();
2293
2294 auto broadcastedRhs = operandB.getDefiningOp<vector::BroadcastOp>();
2295 if (!broadcastedRhs || !isValidBroadcastSource(broadcastedRhs))
2296 return failure();
2297
2298 return vector::OuterProductOp::create(
2299 rewriter, mulOp->getLoc(), resType, broadcastedLhs.getSource(),
2300 broadcastedRhs.getSource(), Value(), vector::CombiningKind::ADD);
2301 };
2302
2303 Value lhs = mulOp->getOperand(0), rhs = mulOp->getOperand(1);
2304 auto maybeOuterP = matchOuterProduct(lhs, rhs);
2305 // Handle commutativity, the transposed op is the outerproduct LHS.
2306 if (failed(maybeOuterP))
2307 maybeOuterP = matchOuterProduct(rhs, lhs);
2308 if (failed(maybeOuterP))
2309 return failure();
2310 rewriter.replaceOp(mulOp, maybeOuterP->getResult());
2311 return success();
2312 }
2313};
2314
2315} // namespace
2316
2317 void mlir::vector::populateFoldArithExtensionPatterns(
2318 RewritePatternSet &patterns) {
2319 patterns.add<FoldArithExtIntoContractionOp<arith::ExtFOp>,
2320 FoldArithExtIntoContractionOp<arith::ExtSIOp>>(
2321 patterns.getContext());
2322 }
2323
2324void mlir::vector::populateVectorMaskMaterializationPatterns(
2325 RewritePatternSet &patterns, bool force32BitVectorIndices,
2326 PatternBenefit benefit) {
2327 patterns.add<VectorCreateMaskOpConversion,
2328 MaterializeTransferMask<vector::TransferReadOp>,
2329 MaterializeTransferMask<vector::TransferWriteOp>>(
2330 patterns.getContext(), force32BitVectorIndices, benefit);
2331 patterns.add<FoldI1Select>(patterns.getContext(), benefit);
2332}
2333
2334 void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
2335 RewritePatternSet &patterns, PatternBenefit benefit) {
2336 patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromScfForOp,
2337 DropUnitDimsFromTransposeOp>(patterns.getContext(), benefit);
2338 }
2339
2340void mlir::vector::populateBubbleVectorBitCastOpPatterns(
2341 RewritePatternSet &patterns, PatternBenefit benefit) {
2342 patterns.add<BubbleDownVectorBitCastForExtract,
2343 BubbleDownBitCastForStridedSliceExtract,
2344 BubbleUpBitCastForInsert, BubbleUpBitCastForStridedSliceInsert>(
2345 patterns.getContext(), benefit);
2346}
2347
2348void mlir::vector::populateBreakDownVectorBitCastOpPatterns(
2349 RewritePatternSet &patterns,
2350 std::function<bool(vector::BitCastOp)> controlFn, PatternBenefit benefit) {
2351 patterns.add<BreakDownVectorBitCast>(patterns.getContext(),
2352 std::move(controlFn), benefit);
2353}
2354
2355 void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
2356 RewritePatternSet &patterns,
2357 std::function<LogicalResult(vector::ContractionOp)> constraint,
2358 PatternBenefit benefit) {
2359 patterns.add<CanonicalizeContractMatmulToMMT>(patterns.getContext(), benefit,
2360 std::move(constraint));
2361}
2362
2363 void mlir::vector::populateVectorReductionToContractPatterns(
2364 RewritePatternSet &patterns, PatternBenefit benefit) {
2365 patterns.add<MultiReduceToContract, CombineContractBroadcastMask,
2366 CombineContractABTranspose, CombineContractResultTranspose>(
2367 patterns.getContext(), benefit);
2368}
2369
2370 void mlir::vector::populateDropInnerMostUnitDimsXferOpPatterns(
2371 RewritePatternSet &patterns, PatternBenefit benefit) {
2372 patterns.add<DropInnerMostUnitDimsTransferRead,
2373 DropInnerMostUnitDimsTransferWrite>(patterns.getContext(),
2374 benefit);
2375 }
2376
2377 void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
2378 PatternBenefit benefit) {
2379 patterns.add<ReorderElementwiseOpsOnTranspose, ReorderCastOpsOnBroadcast,
2380 ReorderElementwiseOpsOnBroadcast, ExtractOpFromElementwise>(
2381 patterns.getContext(), benefit);
2382}
2383
2384void mlir::vector::populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
2385 PatternBenefit benefit) {
2386 // TODO: Consider converting these patterns to canonicalizations.
2387 patterns.add<ExtractOpFromLoad, StoreOpFromBroadcast>(patterns.getContext(),
2388 benefit);
2389}
2390
2391void mlir::vector::populateChainedVectorReductionFoldingPatterns(
2392 RewritePatternSet &patterns, PatternBenefit benefit) {
2393 patterns.add<ChainedReduction>(patterns.getContext(), benefit);
2394 patterns.add<ReduceRedundantZero>(patterns.getContext(),
2395 PatternBenefit(benefit.getBenefit() + 1));
2396}
2397
2398void mlir::vector::populateBreakDownVectorReductionPatterns(
2399 RewritePatternSet &patterns, unsigned maxNumElementsToExtract,
2400 PatternBenefit benefit) {
2401 patterns.add<BreakDownVectorReduction>(patterns.getContext(),
2402 maxNumElementsToExtract, benefit);
2403}
2404
2405 void mlir::vector::populateElementwiseToVectorOpsPatterns(
2406 RewritePatternSet &patterns) {
2407 patterns.add<FoldArithToVectorOuterProduct<arith::MulFOp>,
2408 FoldArithToVectorOuterProduct<arith::MulIOp>>(
2409 patterns.getContext());
2410}
2411
2412//===----------------------------------------------------------------------===//
2413// TableGen'd enum attribute definitions
2414//===----------------------------------------------------------------------===//
2415
2416#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.cpp.inc"