VectorTransforms.cpp
1//===- VectorTransforms.cpp - Conversion within the Vector dialect --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements target-independent rewrites as 1->N patterns.
10//
11//===----------------------------------------------------------------------===//
12
14
15#include <cassert>
16#include <cstdint>
17#include <functional>
18#include <optional>
19
30#include "mlir/IR/Location.h"
31#include "mlir/IR/Matchers.h"
34
35#include "llvm/ADT/STLExtras.h"
36#include "llvm/Support/FormatVariadic.h"
37
38#define DEBUG_TYPE "vector-to-vector"
39
40using namespace mlir;
41using namespace mlir::vector;
42
43template <typename IntType>
44static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
45 return llvm::to_vector<4>(llvm::map_range(
46 arrayAttr.getAsRange<IntegerAttr>(),
47 [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
48}
49
50// Helper to find the result position of dimension `index` in an affine map.
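// For example, given the map (d0, d1, d2) -> (d2, d0), dimension d0 appears as
// result 1, while d1 yields std::nullopt.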
51static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
52 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
53 int64_t idx = map.getDimPosition(i);
54 if (idx == index)
55 return i;
56 }
57 return std::nullopt;
58}
59
60namespace {
61
62/// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
63/// Ex:
64/// ```
65/// %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
66/// %1 = vector.multi_reduction add, %0 [1]
67/// : vector<8x32x16xf32> to vector<8x16xf32>
68/// ```
69/// Gets converted to:
70/// ```
71/// %1 = vector.contract {indexing_maps = [
72/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
73/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
74/// affine_map<(d0, d1, d2) -> (d0, d2)>],
75/// iterator_types = ["parallel", "reduction", "parallel"],
76/// kind = add} %arg0, %arg1, %cst_f0
77/// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x16xf32>
78/// ```
79struct MultiReduceToContract
80 : public OpRewritePattern<vector::MultiDimReductionOp> {
81 using Base::Base;
82
83 LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
84 PatternRewriter &rewriter) const override {
85 if (reduceOp.getKind() != vector::CombiningKind::ADD)
86 return failure();
87 Operation *mulOp = reduceOp.getSource().getDefiningOp();
88 if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
89 return failure();
90 SmallVector<bool> reductionMask = reduceOp.getReductionMask();
91 auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
92 SmallVector<AffineExpr> exprs;
93 SmallVector<vector::IteratorType> iteratorTypes;
94 for (const auto &isReduceDim : llvm::enumerate(reductionMask)) {
95 if (!isReduceDim.value()) {
96 iteratorTypes.push_back(vector::IteratorType::parallel);
97 exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
98 } else {
99 iteratorTypes.push_back(vector::IteratorType::reduction);
100 }
101 }
102 auto dstMap =
103 AffineMap::get(/*dimCount=*/reductionMask.size(),
104 /*symbolCount=*/0, exprs, reduceOp.getContext());
105 rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
106 reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), reduceOp.getAcc(),
107 rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
108 rewriter.getArrayAttr(llvm::to_vector(llvm::map_range(
109 iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
110 return IteratorTypeAttr::get(rewriter.getContext(), t);
111 }))));
112 return success();
113 }
114};
115
116/// Merge LHS/RHS (A/B) TransposeOp into ContractionOp user.
117/// Ex:
118/// ```
119/// %0 = vector.transpose %arg0, [2, 0, 1]
120/// : vector<32x16x8xf32> to vector<8x32x16xf32>
121/// %1 = vector.contract {indexing_maps = [
122/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
123/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
124/// affine_map<(d0, d1, d2) -> (d0, d1)>],
125/// iterator_types = ["parallel", "parallel", "reduction"],
126/// kind = add} %0, %arg1, %cst_f0
127/// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
128/// ```
129/// Gets converted to:
130/// ```
131/// %1 = vector.contract {indexing_maps = [
132/// affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
133/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
134/// affine_map<(d0, d1, d2) -> (d0, d1)>],
135/// iterator_types = ["parallel", "parallel", "reduction"],
136/// kind = add} %arg0, %arg1, %cst_f0
137/// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
138/// ```
139struct CombineContractABTranspose final
140 : public OpRewritePattern<vector::ContractionOp> {
141 using Base::Base;
142
143 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
144 PatternRewriter &rewriter) const override {
145 SmallVector<AffineMap> maps =
146 llvm::to_vector<4>(contractOp.getIndexingMapsArray());
147 Value lhs = contractOp.getLhs();
148 Value rhs = contractOp.getRhs();
149 size_t index = 0;
150 bool changed = false;
151 for (Value *operand : {&lhs, &rhs}) {
152 AffineMap &map = maps[index++];
153 auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
154 if (!transposeOp)
155 continue;
156 AffineMap permutationMap = AffineMap::getPermutationMap(
157 transposeOp.getPermutation(), contractOp.getContext());
158 map = inversePermutation(permutationMap).compose(map);
159 *operand = transposeOp.getVector();
160 changed = true;
161 }
162 if (!changed)
163 return failure();
164 rewriter.replaceOpWithNewOp<vector::ContractionOp>(
165 contractOp, lhs, rhs, contractOp.getAcc(),
166 rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
167 return success();
168 }
169};
170
171/// Merges accumulator and result transposes into contract.
172///
173/// For example:
174/// ```mlir
175/// %accT = vector.transpose %acc, [0, 2, 1]
176/// : vector<2x8x4xf32> to vector<2x4x8xf32>
177/// %contract = vector.contract {
178/// indexing_maps = [
179/// affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
180/// affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
181/// affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
182/// ],
183/// iterator_types = ["parallel", "parallel", "parallel", "reduction"],
184/// kind = #vector.kind<add>
185/// } %lhs, %rhs, %accT
186/// : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x4x8xf32>
187/// %0 = vector.transpose %contract, [0, 2, 1]
188/// : vector<2x4x8xf32> to vector<2x8x4xf32>
189/// ```
190/// Becomes:
191/// ```mlir
192/// %0 = vector.contract {
193/// indexing_maps = [
194/// affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
195/// affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
196/// affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
197/// ],
198/// iterator_types = ["parallel", "parallel", "parallel", "reduction"],
199/// kind = #vector.kind<add>
200/// } %lhs, %rhs, %acc
201/// : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x8x4xf32>
202/// ```
203struct CombineContractResultTranspose final
204 : public OpRewritePattern<vector::TransposeOp> {
205 using Base::Base;
206
207 LogicalResult matchAndRewrite(vector::TransposeOp resTOp,
208 PatternRewriter &rewriter) const override {
209 auto contractOp = resTOp.getVector().getDefiningOp<vector::ContractionOp>();
210 if (!contractOp || !contractOp->hasOneUse())
211 return failure();
212
213 auto accTOp = contractOp.getAcc().getDefiningOp<vector::TransposeOp>();
214 if (!accTOp)
215 return failure();
216
217 MLIRContext *context = contractOp.getContext();
218 auto maps = llvm::to_vector<3>(contractOp.getIndexingMapsArray());
219 AffineMap contractMap = maps.back();
220
221 // Accumulator transpose performs f(A) -> B. Contract performs g(C) -> B.
222 // To index into A in the contract, we need inverse(f)(g(C)) -> A.
223 auto accTMap =
224 AffineMap::getPermutationMap(accTOp.getPermutation(), context);
225
226 // Contract performs g(C) -> D. Result transpose performs h(D) -> E.
227 // To index into E in contract, we need h(g(C)) -> E.
228 auto resTMap =
229 AffineMap::getPermutationMap(resTOp.getPermutation(), context);
230 auto combinedResMap = resTMap.compose(contractMap);
231
232 // The accumulator and result share the same indexing map, so merging both
233 // transposes is only valid if they cancel out, i.e. resTMap is the inverse
234 // of accTMap; combinedResMap then indexes the pre-transpose accumulator.
235 if (inversePermutation(accTMap) != resTMap)
236 return failure();
237 maps.back() = combinedResMap;
238
239 rewriter.replaceOpWithNewOp<vector::ContractionOp>(
240 resTOp, contractOp.getLhs(), contractOp.getRhs(), accTOp.getVector(),
241 rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes());
242 return success();
243 }
244};
245
246/// Merge BroadcastOp into ContractionOp user.
247/// Ex:
248/// ```
249/// %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
250/// %1 = vector.contract {indexing_maps = [
251/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
252/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
253/// affine_map<(d0, d1, d2) -> (d0, d1)>],
254/// iterator_types = ["parallel", "parallel", "reduction"],
255/// kind = add} %0, %arg1, %cst_f0
256/// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
257/// ```
258/// Gets converted to:
259/// ```
260/// %1 = vector.contract {indexing_maps = [
261/// affine_map<(d0, d1, d2) -> (d1, d2)>,
262/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
263/// affine_map<(d0, d1, d2) -> (d0, d1)>],
264/// iterator_types = ["parallel", "parallel", "reduction"],
265/// kind = add} %arg0, %arg1, %cst_f0
266/// : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
267/// ```
268///
269/// For masked vector.contract, the mask requires updating when a dimension is
270/// dropped. In such cases, the dropped dimensions must correspond to the mask's
271/// leading unit dimensions. More general cases (e.g. non-unit dims) are not
272/// supported.
273FailureOr<Value> combineContractAndBroadcast(vector::ContractionOp contractOp,
274 MaskingOpInterface maskingOp,
275 PatternRewriter &rewriter) {
276 SmallVector<AffineMap> maps =
277 llvm::to_vector<4>(contractOp.getIndexingMapsArray());
278 Value lhs = contractOp.getLhs();
279 Value rhs = contractOp.getRhs();
280 size_t index = 0;
281 bool changed = false;
282 for (Value *operand : {&lhs, &rhs}) {
283 AffineMap &map = maps[index++];
284 auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
285 if (!broadcast)
286 continue;
287 // contractionOp can only take vector as operands.
288 auto srcType = dyn_cast<VectorType>(broadcast.getSourceType());
289 if (!srcType ||
290 srcType.getRank() == broadcast.getResultVectorType().getRank())
291 continue;
292 int64_t rankDiff =
293 broadcast.getResultVectorType().getRank() - srcType.getRank();
294 bool innerDimBroadcast = false;
295 SmallVector<AffineExpr> originalDims;
296 for (const auto &dim : llvm::enumerate(srcType.getShape())) {
297 if (dim.value() !=
298 broadcast.getResultVectorType().getDimSize(rankDiff + dim.index())) {
299 innerDimBroadcast = true;
300 break;
301 }
302 originalDims.push_back(rewriter.getAffineDimExpr(dim.index() + rankDiff));
303 }
304 // Contract doesn't support inner dimension broadcast. Once this is
305 // relaxed we can remove this case.
306 if (innerDimBroadcast)
307 continue;
308
309 // It would be incorrect to fold a broadcast onto a reduction dimension
310 // of non-unit size.
311 bool nonUnitDimReductionBroadcast = false;
312 for (int64_t i = 0; i < rankDiff; ++i) {
313 if (broadcast.getResultVectorType().getDimSize(i) != 1 &&
314 isReductionIterator(contractOp.getIteratorTypes()
315 .getValue()[map.getDimPosition(i)])) {
316 nonUnitDimReductionBroadcast = true;
317 break;
318 }
319 }
320 if (nonUnitDimReductionBroadcast)
321 continue;
322
323 AffineMap broadcastMap =
324 AffineMap::get(broadcast.getResultVectorType().getRank(), 0,
325 originalDims, contractOp.getContext());
326 map = broadcastMap.compose(map);
327 *operand = broadcast.getSource();
328 changed = true;
329 }
330
331 if (!changed)
332 return failure();
333
334 // Determine which dims are unused, now that the maps have been composed
335 // with the broadcast maps.
336 llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
337 // Compress unused dims.
338 for (auto &m : maps)
339 m = compressDims(m, unusedDimsBitVector);
340 // Compute the combined iterators.
341 SmallVector<Attribute> iterators;
342 for (unsigned i = 0, e = unusedDimsBitVector.size(); i < e; ++i) {
343 if (!unusedDimsBitVector.test(i))
344 iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
345 }
346
347 // Check whether any of the unused dims is non-unit, e.g.:
348 // * vector.broadcast %arg0 : vector<8x4xi32> to vector<2x8x4xi32>
349 // This is only required when collapsing a mask. If there is no mask, skip.
350 VectorType oldMaskType;
351 bool isAnyUnusedDimNonUnit = false;
352 if (maskingOp) {
353 oldMaskType = cast<VectorType>(maskingOp.getMask().getType());
354 for (unsigned i = 0, e = unusedDimsBitVector.size(); i < e; ++i) {
355 if (unusedDimsBitVector.test(i) && oldMaskType.getShape()[i] != 1) {
356 isAnyUnusedDimNonUnit = true;
357 break;
358 }
359 }
360 }
361
362 // Check that compressing unused dims isn't removing all reduction dimension
363 // pairs. For example, if the vector.contract had only one reduction
364 // iterator and that was a unit-dimension created by a broadcast,
365 // then we should bail here, otherwise we would create a contract without
366 // a reduction dimension pair.
367 bool hasReductionIteratorApplyingOnBothSides = false;
368 for (unsigned i = 0; i < iterators.size(); ++i) {
369 if (!isReductionIterator(iterators[i]))
370 continue;
371 if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) {
372 hasReductionIteratorApplyingOnBothSides = true;
373 break;
374 }
375 }
376 if (!hasReductionIteratorApplyingOnBothSides)
377 return failure();
378
379 // If the compressed maps have a dimension that is not used by either LHS or
380 // RHS then the ContractionOp verifier would fail.
381 if (getUnusedDimsBitVector({maps[0], maps[1]}).any())
382 return failure();
383
384 Operation *newOp = vector::ContractionOp::create(
385 rewriter, contractOp.getLoc(), lhs, rhs, contractOp.getAcc(),
386 rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
387
388 // Handle the mask.
389 if (maskingOp) {
390 if (isAnyUnusedDimNonUnit)
391 return rewriter.notifyMatchFailure(contractOp,
392 "Cannont drop non-unit mask dim.");
393 assert(unusedDimsBitVector.size() ==
394 static_cast<size_t>(oldMaskType.getRank()) &&
395 "The mask rank is incorrect!");
396
397 // If a dimension has been dropped, update the mask accordingly. Otherwise,
398 // keep it as is.
399 Value mask = maskingOp.getMask();
400 if (unusedDimsBitVector.count() != 0) {
401 // At this point, two assumptions are made:
402 // * The unused dimensions are the leading mask dimensions
403 // (vector.contract does not support inner dim broadcasting).
404 // * The unused dimensions are all unit.
405 // These conditions are effectively verified in the blocks preceding this
406 // one.
407 auto newShape =
408 oldMaskType.getShape().drop_front(unusedDimsBitVector.count());
409 auto newShapeScalableDims =
410 oldMaskType.getScalableDims().drop_front(unusedDimsBitVector.count());
411 VectorType maskOpType =
412 VectorType::get(newShape, rewriter.getI1Type(), newShapeScalableDims);
413 mask = vector::ShapeCastOp::create(rewriter, contractOp.getLoc(),
414 maskOpType, maskingOp.getMask())
415 .getResult();
416 }
417
418 newOp = mlir::vector::maskOperation(rewriter, newOp, mask);
419 }
420 return newOp->getResult(0);
421}
422
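/// Maskable wrapper pattern around combineContractAndBroadcast above: folds a
/// vector.broadcast on the LHS/RHS into a (possibly masked) vector.contract.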
423struct CombineContractBroadcastMask
424 : public MaskableOpRewritePattern<vector::ContractionOp> {
425 using MaskableOpRewritePattern::MaskableOpRewritePattern;
426 FailureOr<Value>
427
428 matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
429 MaskingOpInterface maskingOp,
430 PatternRewriter &rewriter) const override {
431 return combineContractAndBroadcast(contractOp, maskingOp, rewriter);
432 }
433};
434
435/// Reorders cast(broadcast) to broadcast(cast). This moves broadcast ops
436/// closer to contraction ops, enabling the CombineContractBroadcast pattern
437/// when cast ops sit between them.
438/// Ex:
439/// ```
440/// %0 = vector.broadcast %arg0 : vector<32x16xi8> to vector<8x32x16xi8>
441/// %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
442/// ```
443/// Gets converted to:
444/// ```
445/// %0 = arith.extsi %arg0 : vector<32x16xi8> to vector<32x16xi32>
446/// %1 = vector.broadcast %0 : vector<32x16xi32> to vector<8x32x16xi32>
447/// ```
448struct ReorderCastOpsOnBroadcast
449 : public OpInterfaceRewritePattern<CastOpInterface> {
450 using OpInterfaceRewritePattern<CastOpInterface>::OpInterfaceRewritePattern;
451
452 LogicalResult matchAndRewrite(CastOpInterface op,
453 PatternRewriter &rewriter) const override {
454 if (op->getNumOperands() != 1)
455 return failure();
456 if (!isa<VectorType>(op->getResult(0).getType()))
457 return failure();
458 auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
459 if (!bcastOp)
460 return failure();
461
462 Type castResTy = getElementTypeOrSelf(op->getResult(0));
463 if (auto vecTy = dyn_cast<VectorType>(bcastOp.getSourceType()))
464 castResTy = vecTy.clone(castResTy);
465 auto *castOp =
466 rewriter.create(op->getLoc(), op->getName().getIdentifier(),
467 bcastOp.getSource(), castResTy, op->getAttrs());
468 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
469 op, op->getResult(0).getType(), castOp->getResult(0));
470 return success();
471 }
472};
473
474/// Reorders elementwise(transpose) to transpose(elementwise). This moves
475/// transpose ops closer to contraction ops, enabling the
476/// CombineContractABTranspose pattern when elementwise ops sit between these
477/// operations. Ex:
478/// ```
479/// %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
480/// %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
481/// %r = arith.addf %at, %bt : vector<2x4xf32>
482/// ```
483/// Gets converted to:
484/// ```
485/// %0 = arith.addf %a, %b : vector<4x2xf32>
486/// %r = vector.transpose %0, [1, 0] : vector<4x2xf32> to vector<2x4xf32>
487/// ```
488struct ReorderElementwiseOpsOnTranspose final
489 : public OpTraitRewritePattern<OpTrait::Elementwise> {
490 using OpTraitRewritePattern::OpTraitRewritePattern;
491 LogicalResult matchAndRewrite(Operation *op,
492 PatternRewriter &rewriter) const override {
493 if (op->getNumResults() != 1 || op->getNumRegions() != 0)
494 return failure();
495
496 // Make sure all operands are transpose/constant ops and collect their
497 // transposition maps.
498 SmallVector<ArrayRef<int64_t>> transposeMaps;
499 transposeMaps.reserve(op->getNumOperands());
500 // Record the initial type before transposition. We'll use its shape later.
501 // Any type will do here as we will check all transpose maps are the same.
502 VectorType srcType;
503 for (Value operand : op->getOperands()) {
504 auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
505 if (transposeOp) {
506 transposeMaps.push_back(transposeOp.getPermutation());
507 srcType = transposeOp.getSourceVectorType();
508 } else if (!matchPattern(operand, m_Constant())) {
509 return failure();
510 }
511 }
512 if (transposeMaps.empty())
513 return failure();
514 // This is an elementwise op, so all transposed operands should have the
515 // same type. We need to additionally check that all transposes use the
516 // same map.
517 if (!llvm::all_equal(transposeMaps))
518 return rewriter.notifyMatchFailure(op, "different transpose map");
519
520 SmallVector<Value> srcValues;
521 srcValues.reserve(op->getNumOperands());
522
523 // If there are constant operands, we need to insert inverse transposes for
524 // them. Calculate the inverse order first.
525 auto order = transposeMaps.front();
526 SmallVector<int64_t> invOrder(order.size());
527 for (int i = 0, e = order.size(); i < e; ++i)
528 invOrder[order[i]] = i;
529
530 for (Value operand : op->getOperands()) {
531 auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
532 if (transposeOp) {
533 srcValues.push_back(transposeOp.getVector());
534 } else {
535 // This is a constant. Create a reverse transpose op for it.
536 auto vectorType =
537 srcType.clone(cast<VectorType>(operand.getType()).getElementType());
538 srcValues.push_back(vector::TransposeOp::create(
539 rewriter, operand.getLoc(), vectorType, operand, invOrder));
540 }
541 }
542
543 auto vectorType = srcType.clone(
544 cast<VectorType>(op->getResultTypes()[0]).getElementType());
545 Operation *elementwiseOp =
546 rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
547 vectorType, op->getAttrs());
548 rewriter.replaceOpWithNewOp<vector::TransposeOp>(
549 op, op->getResultTypes()[0], elementwiseOp->getResult(0),
550 transposeMaps.front());
551 return success();
552 }
553};
554
555// Returns the values in `arrayAttr` as an integer vector.
556static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) {
557 return llvm::to_vector<4>(
558 llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
559 [](IntegerAttr attr) { return attr.getInt(); }));
560}
561
562// Shuffles vector.bitcast op after vector.extract op.
563//
564// This transforms IR like:
565// %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
566// %1 = vector.extract %0[3] : f16 from vector<8xf16>
567// Into:
568// %0 = vector.extract %src[1] : f32 from vector<4xf32>
569// %1 = vector.bitcast %0: vector<1xf32> to vector<2xf16>
570// %2 = vector.extract %1[1] : f16 from vector<2xf16>
571struct BubbleDownVectorBitCastForExtract
572 : public OpRewritePattern<vector::ExtractOp> {
573 using Base::Base;
574
575 LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
576 PatternRewriter &rewriter) const override {
577 // Only support extracting scalars for now.
578 if (extractOp.getSourceVectorType().getRank() != 1)
579 return failure();
580
581 auto castOp = extractOp.getSource().getDefiningOp<vector::BitCastOp>();
582 if (!castOp)
583 return failure();
584
585 VectorType castSrcType = castOp.getSourceVectorType();
586 VectorType castDstType = castOp.getResultVectorType();
587 assert(castSrcType.getRank() == castDstType.getRank());
588
589 // Fail to match if we only have one element in the cast op source.
590 // This is to avoid infinite loop given that this pattern can generate
591 // such cases.
592 if (castSrcType.getNumElements() == 1)
593 return failure();
594
595 // Only support casting to a larger number of elements for now.
596 // E.g., vector<4xf32> -> vector<8xf16>.
597 if (castSrcType.getNumElements() > castDstType.getNumElements())
598 return failure();
599
600 unsigned expandRatio =
601 castDstType.getNumElements() / castSrcType.getNumElements();
602
603 // Get the first element of the mixed position as integer.
604 auto mixedPos = extractOp.getMixedPosition();
605 if (!mixedPos.empty() && !isa<Attribute>(mixedPos[0]))
606 return failure();
607 uint64_t index = cast<IntegerAttr>(cast<Attribute>(mixedPos[0])).getInt();
608
609 // Get the single scalar (as a vector) in the source value that packs the
610 // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>
611 Location loc = extractOp.getLoc();
612 Value packedValue = vector::ExtractOp::create(
613 rewriter, loc, castOp.getSource(), index / expandRatio);
614 Type packedVecType = VectorType::get(/*shape=*/{1}, packedValue.getType());
615 Value zero = arith::ConstantOp::create(rewriter, loc, packedVecType,
616 rewriter.getZeroAttr(packedVecType));
617 packedValue = vector::InsertOp::create(rewriter, loc, packedValue, zero,
618 /*position=*/0);
619
620 // Cast it to a vector with the desired scalar's type.
621 // E.g. f32 -> vector<2xf16>
622 VectorType packedType =
623 VectorType::get({expandRatio}, castDstType.getElementType());
624 Value castedValue =
625 vector::BitCastOp::create(rewriter, loc, packedType, packedValue);
626
627 // Finally extract the desired scalar.
628 rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, castedValue,
629 index % expandRatio);
630 return success();
631 }
632};
633
634// Shuffles vector.bitcast op after vector.extract_strided_slice op.
635//
636// This transforms IR like:
637// %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
638// %0 = vector.extract_strided_slice %cast {
639// offsets = [4], sizes = [4], strides = [1]
640// } : vector<8xf16> to vector<4xf16>
641// Into:
642// %0 = vector.extract_strided_slice %src {
643// offsets = [2], sizes = [2], strides = [1]
644// } : vector<4xf32> to vector<2xf32>
645// %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16>
646struct BubbleDownBitCastForStridedSliceExtract
647 : public OpRewritePattern<vector::ExtractStridedSliceOp> {
648 using Base::Base;
649
650 LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
651 PatternRewriter &rewriter) const override {
652 auto castOp = extractOp.getSource().getDefiningOp<vector::BitCastOp>();
653 if (!castOp)
654 return failure();
655
656 VectorType castSrcType = castOp.getSourceVectorType();
657 VectorType castDstType = castOp.getResultVectorType();
658 assert(castSrcType.getRank() == castDstType.getRank());
659
660 int64_t castSrcLastDim = castSrcType.getShape().back();
661 int64_t castDstLastDim = castDstType.getShape().back();
662 // Require casting to more elements for now; other cases to be implemented.
663 if (castSrcLastDim > castDstLastDim)
664 return failure();
665
666 // Only accept all one strides for now.
667 if (llvm::any_of(extractOp.getStrides().getAsValueRange<IntegerAttr>(),
668 [](const APInt &val) { return !val.isOne(); }))
669 return failure();
670
671 unsigned rank = extractOp.getSourceVectorType().getRank();
672 assert(castDstLastDim % castSrcLastDim == 0);
673 int64_t expandRatio = castDstLastDim / castSrcLastDim;
674
675 // If we have fewer offsets than the rank, then implicitly we
676 // are selecting the full range for the last bitcasted dimension; other
677 // dimensions aren't affected. Otherwise, we need to scale down the last
678 // dimension's offset given we are extracting from fewer elements now.
679 ArrayAttr newOffsets = extractOp.getOffsets();
680 if (newOffsets.size() == rank) {
681 SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
682 if (offsets.back() % expandRatio != 0)
683 return failure();
684 offsets.back() = offsets.back() / expandRatio;
685 newOffsets = rewriter.getI64ArrayAttr(offsets);
686 }
687
688 // Similarly for sizes.
689 ArrayAttr newSizes = extractOp.getSizes();
690 if (newSizes.size() == rank) {
691 SmallVector<int64_t> sizes = getIntValueVector(newSizes);
692 if (sizes.back() % expandRatio != 0)
693 return failure();
694 sizes.back() = sizes.back() / expandRatio;
695 newSizes = rewriter.getI64ArrayAttr(sizes);
696 }
697
698 SmallVector<int64_t> dims =
699 llvm::to_vector<4>(cast<VectorType>(extractOp.getType()).getShape());
700 dims.back() = dims.back() / expandRatio;
701 VectorType newExtractType =
702 VectorType::get(dims, castSrcType.getElementType());
703
704 auto newExtractOp = vector::ExtractStridedSliceOp::create(
705 rewriter, extractOp.getLoc(), newExtractType, castOp.getSource(),
706 newOffsets, newSizes, extractOp.getStrides());
707
708 rewriter.replaceOpWithNewOp<vector::BitCastOp>(
709 extractOp, extractOp.getType(), newExtractOp);
710
711 return success();
712 }
713};
714
715// Shuffles vector.bitcast op before vector.insert op.
716//
717// This transforms IR like:
718// %0 = vector.insert %val, %dst[4] : vector<32xi4> into vector<8x32xi4>
719// %1 = vector.bitcast %0 : vector<8x32xi4> to vector<8x16xi8>
720// Into:
721// %0 = vector.bitcast %val : vector<32xi4> to vector<16xi8>
722// %1 = vector.bitcast %dst : vector<8x32xi4> to vector<8x16xi8>
723// %2 = vector.insert %0, %1 [4] : vector<16xi8> into vector<8x16xi8>
724//
725struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
726 using Base::Base;
727
728 LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
729 PatternRewriter &rewriter) const override {
730 VectorType castSrcType = bitcastOp.getSourceVectorType();
731 VectorType castDstType = bitcastOp.getResultVectorType();
732
733 // 0-D and scalable vectors are not supported yet.
734 if (castSrcType.getRank() == 0 || castSrcType.isScalable() ||
735 castDstType.isScalable())
736 return failure();
737
738 int64_t castSrcLastDim = castSrcType.getShape().back();
739 int64_t castDstLastDim = castDstType.getShape().back();
740 bool isNumElemsShrink = castSrcLastDim >= castDstLastDim;
741 int64_t ratio;
742 if (isNumElemsShrink) {
743 assert(castSrcLastDim % castDstLastDim == 0);
744 ratio = castSrcLastDim / castDstLastDim;
745 } else {
746 assert(castDstLastDim % castSrcLastDim == 0);
747 ratio = castDstLastDim / castSrcLastDim;
748 }
749
750 auto insertOp = bitcastOp.getSource().getDefiningOp<vector::InsertOp>();
751 if (!insertOp)
752 return failure();
753
754 // Only vector sources are supported for now.
755 auto insertSrcType = dyn_cast<VectorType>(insertOp.getValueToStoreType());
756 if (!insertSrcType)
757 return failure();
758
759 // Bitcast the source.
760 SmallVector<int64_t> srcDims(insertSrcType.getShape());
761 srcDims.back() =
762 isNumElemsShrink ? srcDims.back() / ratio : srcDims.back() * ratio;
763 VectorType newCastSrcType =
764 VectorType::get(srcDims, castDstType.getElementType());
765 auto newCastSrcOp =
766 vector::BitCastOp::create(rewriter, bitcastOp.getLoc(), newCastSrcType,
767 insertOp.getValueToStore());
768
769 SmallVector<int64_t> dstDims(insertOp.getDestVectorType().getShape());
770 dstDims.back() =
771 isNumElemsShrink ? dstDims.back() / ratio : dstDims.back() * ratio;
772 VectorType newCastDstType =
773 VectorType::get(dstDims, castDstType.getElementType());
774
775 // Bitcast the destination.
776 auto newCastDstOp = vector::BitCastOp::create(
777 rewriter, bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
778
779 // Generate new insert.
780 rewriter.replaceOpWithNewOp<vector::InsertOp>(
781 bitcastOp, newCastSrcOp, newCastDstOp, insertOp.getMixedPosition());
782 return success();
783 }
784};
785
786// Shuffles vector.bitcast op before vector.insert_strided_slice op.
787//
788// This transforms IR like:
789// %0 = vector.insert_strided_slice %src, %dst {
790// offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
791// %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
792// Into:
793// %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32>
794// %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32>
795// %2 = vector.insert_strided_slice %0, %1 {
796// offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
797struct BubbleUpBitCastForStridedSliceInsert
798 : public OpRewritePattern<vector::BitCastOp> {
799 using Base::Base;
800
801 LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
802 PatternRewriter &rewriter) const override {
803 VectorType castSrcType = bitcastOp.getSourceVectorType();
804 VectorType castDstType = bitcastOp.getResultVectorType();
805 assert(castSrcType.getRank() == castDstType.getRank());
806 // Skip 0-D vectors, which cannot come from InsertStridedSliceOp.
807 if (castSrcType.getRank() == 0)
808 return failure();
809
810 int64_t castSrcLastDim = castSrcType.getShape().back();
811 int64_t castDstLastDim = castDstType.getShape().back();
812 // Require casting to less elements for now; other cases to be implemented.
813 if (castSrcLastDim < castDstLastDim)
814 return failure();
815
816 assert(castSrcLastDim % castDstLastDim == 0);
817 int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
818
819 auto insertOp =
820 bitcastOp.getSource().getDefiningOp<vector::InsertStridedSliceOp>();
821 if (!insertOp)
822 return failure();
823
824 // Only accept all one strides for now.
825 if (llvm::any_of(insertOp.getStrides().getAsValueRange<IntegerAttr>(),
826 [](const APInt &val) { return !val.isOne(); }))
827 return failure();
828
829 unsigned rank = insertOp.getSourceVectorType().getRank();
830 // Require insert op to have the same rank for the source and destination
831 // vector; other cases to be implemented.
832 if (rank != insertOp.getDestVectorType().getRank())
833 return failure();
834
835 // Requires that shape of insert op src is castable to dstType.
836 unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth();
837 unsigned destinationWidth =
838 castDstType.getElementType().getIntOrFloatBitWidth();
839 unsigned numElements = destinationWidth / sourceWidth;
840 if (insertOp.getSourceVectorType().getNumElements() % numElements != 0)
841 return failure();
842
843 ArrayAttr newOffsets = insertOp.getOffsets();
844 assert(newOffsets.size() == rank);
845 SmallVector<int64_t> offsets = getIntValueVector(newOffsets);
846 if (offsets.back() % shrinkRatio != 0)
847 return failure();
848 offsets.back() = offsets.back() / shrinkRatio;
849 newOffsets = rewriter.getI64ArrayAttr(offsets);
850
851 SmallVector<int64_t> srcDims =
852 llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
853 srcDims.back() = srcDims.back() / shrinkRatio;
854 VectorType newCastSrcType =
855 VectorType::get(srcDims, castDstType.getElementType());
856
857 auto newCastSrcOp =
858 vector::BitCastOp::create(rewriter, bitcastOp.getLoc(), newCastSrcType,
859 insertOp.getValueToStore());
860
861 SmallVector<int64_t> dstDims =
862 llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
863 dstDims.back() = dstDims.back() / shrinkRatio;
864 VectorType newCastDstType =
865 VectorType::get(dstDims, castDstType.getElementType());
866
867 auto newCastDstOp = vector::BitCastOp::create(
868 rewriter, bitcastOp.getLoc(), newCastDstType, insertOp.getDest());
869
870 rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
871 bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
872 insertOp.getStrides());
873
874 return success();
875 }
876};
877
878// Breaks down vector.bitcast op
879//
880// This transforms IR like:
881// %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
882// Into:
883// %cst = vector.broadcast %c0_f32 : f32 to vector<4xf32>
884// %1 = vector.extract_strided_slice %0 {
885// offsets = [0], sizes = [4], strides = [1]
886// } : vector<8xf16> to vector<4xf16>
887// %2 = vector.bitcast %1 : vector<4xf16> to vector<2xf32>
888// %4 = vector.insert_strided_slice %2, %cst {
889// offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
890// %5 = vector.extract_strided_slice %0 {
891// offsets = [4], sizes = [4], strides = [1]
892// } : vector<8xf16> to vector<4xf16>
893// %6 = vector.bitcast %5 : vector<4xf16> to vector<2xf32>
894// %7 = vector.insert_strided_slice %6, %cst {
895// offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
896struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
897 using Base::Base;
898
899public:
900 BreakDownVectorBitCast(MLIRContext *context,
901 std::function<bool(vector::BitCastOp)> controlFn,
902 PatternBenefit benefit)
903 : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}
904
905 LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
906 PatternRewriter &rewriter) const override {
907
908 if (controlFn && !controlFn(bitcastOp))
909 return failure();
910
911 VectorType castSrcType = bitcastOp.getSourceVectorType();
912 VectorType castDstType = bitcastOp.getResultVectorType();
913 assert(castSrcType.getRank() == castDstType.getRank());
914
915 // This transformation builds on top of
916 // vector.{extract|insert}_strided_slice, which do not support
917 // extracting/inserting "scalable sub-vectors". Bail out.
918 if (castSrcType.isScalable())
919 return rewriter.notifyMatchFailure(bitcastOp,
920 "Scalable vectors are not supported");
921
922 // Only support rank 1 case for now.
923 if (castSrcType.getRank() != 1)
924 return failure();
925
926 int64_t castSrcLastDim = castSrcType.getShape().back();
927 int64_t castDstLastDim = castDstType.getShape().back();
928 // Require casting to less elements for now; other cases to be implemented.
929 if (castSrcLastDim < castDstLastDim)
930 return failure();
931
932 assert(castSrcLastDim % castDstLastDim == 0);
933 int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
934 // Nothing to do if it is already bitcasting to a single element.
935 if (castSrcLastDim == shrinkRatio)
936 return failure();
937
938 Location loc = bitcastOp.getLoc();
939 Type elemType = castDstType.getElementType();
940 assert(elemType.isSignlessIntOrIndexOrFloat());
941
942 Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
943 rewriter.getZeroAttr(elemType));
944 Value res = BroadcastOp::create(rewriter, loc, castDstType, zero);
945
946 SmallVector<int64_t> sliceShape = {castDstLastDim};
947 SmallVector<int64_t> strides = {1};
948 VectorType newCastDstType =
949 VectorType::get(SmallVector<int64_t>{castDstLastDim / shrinkRatio},
950 castDstType.getElementType());
951
952 for (int i = 0, e = shrinkRatio; i < e; ++i) {
953 Value extracted = ExtractStridedSliceOp::create(
954 rewriter, loc, bitcastOp.getSource(),
955 ArrayRef<int64_t>{i * castDstLastDim}, sliceShape, strides);
956 Value bitcast =
957 BitCastOp::create(rewriter, loc, newCastDstType, extracted);
958 res = InsertStridedSliceOp::create(
959 rewriter, loc, bitcast, res,
960 ArrayRef<int64_t>{i * castDstLastDim / shrinkRatio}, strides);
961 }
962 rewriter.replaceOp(bitcastOp, res);
963 return success();
964 }
965
966private:
967 std::function<bool(BitCastOp)> controlFn;
968};
969
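/// Returns true if `t` and `u` are both non-vector types, or both vectors with
/// identical shape and scalable dims (their element types may differ).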
970static bool haveSameShapeAndScaling(Type t, Type u) {
971 auto tVec = dyn_cast<VectorType>(t);
972 auto uVec = dyn_cast<VectorType>(u);
973 if (!tVec) {
974 return !uVec;
975 }
976 if (!uVec) {
977 return false;
978 }
979 return tVec.getShape() == uVec.getShape() &&
980 tVec.getScalableDims() == uVec.getScalableDims();
981}
982
983/// If `type` is shaped, clone it with `newElementType`. Otherwise,
984/// return `newElementType`.
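/// E.g., cloneOrReplace(vector<4xf32>, i32) yields vector<4xi32>, while
/// cloneOrReplace(f32, i32) yields i32.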
985static Type cloneOrReplace(Type type, Type newElementType) {
986 if (auto shapedType = dyn_cast<ShapedType>(type)) {
987 return shapedType.clone(newElementType);
988 }
989 return newElementType;
990}
991
992/// If `value` is the result of a broadcast operation, return the input
993/// of the broadcast operation; otherwise return a null Value.
994static Value getBroadcastLikeSource(Value value) {
995
996 Operation *op = value.getDefiningOp();
997 if (!op)
998 return {};
999
1000 if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
1001 return broadcast.getSource();
1002
1003 return {};
1004}
1005
1006/// Reorders elementwise(broadcast) to broadcast(elementwise).
1007///
1008/// Example:
1009/// ```
1010/// %a = vector.broadcast %arg1 : index to vector<1x4xindex>
1011/// %b = vector.broadcast %arg2 : index to vector<1x4xindex>
1012/// %r = arith.addi %a, %b : vector<1x4xindex>
1013/// ```
1014/// Gets converted to:
1015/// ```
1016/// %r = arith.addi %arg1, %arg2 : index
1017/// %b = vector.broadcast %r : index to vector<1x4xindex>
1018/// ```
1019struct ReorderElementwiseOpsOnBroadcast final
1020 : public OpTraitRewritePattern<OpTrait::Elementwise> {
1021 using OpTraitRewritePattern::OpTraitRewritePattern;
1022 LogicalResult matchAndRewrite(Operation *op,
1023 PatternRewriter &rewriter) const override {
1024 if (op->getNumResults() != 1)
1025 return failure();
1026 auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
1027 if (!resultType)
1028 return failure();
1029 if (!OpTrait::hasElementwiseMappableTraits(op))
1030 return rewriter.notifyMatchFailure(
1031 op, "Op doesn't have ElementwiseMappableTraits");
1032 if (op->getNumOperands() == 0)
1033 return failure();
1034 if (isa<vector::FMAOp>(op)) {
1035 return rewriter.notifyMatchFailure(
1036 op,
1037 "Op only accepts vector types - not supported as broadcast source "
1038 "might be a scalar");
1039 }
1040
1041 Type resultElemType = resultType.getElementType();
1042
1043 // Get the broadcast source of the first non-constant operand.
1044 Value broadcastSource;
1045 for (Value operand : op->getOperands()) {
1046 Operation *definingOp = operand.getDefiningOp();
1047 if (!definingOp)
1048 return failure();
1049 if (definingOp->hasTrait<OpTrait::ConstantLike>())
1050 continue;
1051 broadcastSource = getBroadcastLikeSource(operand);
1052 break;
1053 }
1054 if (!broadcastSource)
1055 return failure();
1056 Type unbroadcastResultType =
1057 cloneOrReplace(broadcastSource.getType(), resultElemType);
1058
1059 // Make sure that all operands are broadcast from identically-shaped types:
1060 // * scalar (`vector.broadcast`), or
1061 // * vector (`vector.broadcast`).
1062 // Otherwise the re-ordering wouldn't be safe.
1063 if (!llvm::all_of(op->getOperands(), [broadcastSource](Value val) {
1064 if (auto source = getBroadcastLikeSource(val))
1065 return haveSameShapeAndScaling(source.getType(),
1066 broadcastSource.getType());
1067 SplatElementsAttr splatConst;
1068 return matchPattern(val, m_Constant(&splatConst));
1069 })) {
1070 return rewriter.notifyMatchFailure(
1071 op,
1072 "not all operands are constants or broadcasts from the same type");
1073 }
1074
1075 // Collect the source values before broadcasting
1076 SmallVector<Value> srcValues;
1077 srcValues.reserve(op->getNumOperands());
1078 for (Value operand : op->getOperands()) {
1079 SplatElementsAttr splatConst;
1080 if (matchPattern(operand, m_Constant(&splatConst))) {
1081 Attribute newConst;
1082 Type elementType = getElementTypeOrSelf(operand.getType());
1083 Type newType = cloneOrReplace(unbroadcastResultType, elementType);
1084 if (auto newTypeShaped = dyn_cast<ShapedType>(newType)) {
1085 newConst = splatConst.resizeSplat(newTypeShaped);
1086 } else {
1087 newConst = splatConst.getSplatValue<Attribute>();
1088 }
1089 Operation *newConstOp =
1090 operand.getDefiningOp()->getDialect()->materializeConstant(
1091 rewriter, newConst, newType, operand.getLoc());
1092 srcValues.push_back(newConstOp->getResult(0));
1093 } else {
1094 srcValues.push_back(operand.getDefiningOp()->getOperand(0));
1095 }
1096 }
1097
1098 // Create the "elementwise" Op
1099 Operation *elementwiseOp =
1100 rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
1101 unbroadcastResultType, op->getAttrs());
1102
1103 // Replace the original Op with the elementwise Op
1104 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
1105 op, resultType, elementwiseOp->getResults());
1106
1107 return success();
1108 }
1109};
1110
1111/// Pattern to rewrite an ExtractOp(Elementwise) -> Elementwise(ExtractOp).
1112/// This may result in cleaner code when extracting a single value
1113/// from a multi-element vector, and also helps canonicalize 1-element vectors
1114/// to scalars.
1115///
1116/// Example:
1117/// ```
1118/// %0 = arith.addf %arg0, %arg1 : vector<4xf32>
1119/// %1 = vector.extract %0[1] : f32 from vector<4xf32>
1120/// ```
1121/// Gets converted to:
1122/// ```
1123/// %0 = vector.extract %arg0[1] : f32 from vector<4xf32>
1124/// %1 = vector.extract %arg1[1] : f32 from vector<4xf32>
1125/// %2 = arith.addf %0, %1 : f32
1126/// ```
1127class ExtractOpFromElementwise final
1128 : public OpRewritePattern<vector::ExtractOp> {
1129public:
1130 using Base::Base;
1131
1132 LogicalResult matchAndRewrite(vector::ExtractOp op,
1133 PatternRewriter &rewriter) const override {
1134 Operation *eltwise = op.getSource().getDefiningOp();
1135
1136 // TODO: vector::FMAOp is not ElementwiseMappable even if it claims to be,
1137 // as it doesn't support scalars.
1138 if (!eltwise || !OpTrait::hasElementwiseMappableTraits(eltwise) ||
1139 isa<vector::FMAOp>(eltwise))
1140 return rewriter.notifyMatchFailure(op, "not an elementwise op");
1141
1142 if (eltwise->getNumResults() != 1)
1143 return rewriter.notifyMatchFailure(op, "expected single result");
1144
1145 if (!eltwise->hasOneUse())
1146 return rewriter.notifyMatchFailure(op, "expected single op use");
1147
1148 if (!llvm::all_equal(eltwise->getOperandTypes()))
1149 return rewriter.notifyMatchFailure(op, "operand types are different");
1150
1151 // Dynamic position can cause dominance issues, so conservatively fail for
1152 // now.
1153 if (!op.getDynamicPosition().empty())
1154 return rewriter.notifyMatchFailure(
1155 op, "dynamic position not yet implemented");
1156
1157 Type dstType = op.getType();
1158
1159 OpBuilder::InsertionGuard g(rewriter);
1160 rewriter.setInsertionPoint(eltwise);
1161
1162 IRMapping mapping;
1163 Location loc = eltwise->getLoc();
1164 SmallVector<OpFoldResult> pos = op.getMixedPosition();
1165 for (Value arg : eltwise->getOperands()) {
1166 Value newArg = vector::ExtractOp::create(rewriter, loc, arg, pos);
1167 mapping.map(arg, newArg);
1168 }
1169
1170 Operation *newEltwise = rewriter.clone(*eltwise, mapping);
1171 newEltwise->getResult(0).setType(dstType);
1172
1173 rewriter.replaceOp(op, newEltwise);
1174 rewriter.eraseOp(eltwise);
1175 return success();
1176 }
1177};
1178
1179/// Check if the element type is suitable for vector.load/store sinking.
1180/// Element type must be index or byte-aligned integer or floating-point type.
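/// E.g., f32, i8, i32 and index are supported; i1 and i4 are not.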
1181static bool isSupportedMemSinkElementType(Type type) {
1182 if (isa<IndexType>(type))
1183 return true;
1184
1185 return type.isIntOrFloat() && type.getIntOrFloatBitWidth() % 8 == 0;
1186}
1187
1188/// Pattern to rewrite `vector.extract(vector.load) -> vector/memref.load.
1189/// Only index and byte-aligned integer and floating-point element types are
1190/// supported for now.
1191///
1192/// Example:
1193/// ```
1194/// vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
1195/// vector.extract %0[1] : f32 from vector<4xf32>
1196/// ```
1197/// Gets converted to:
1198/// ```
1199/// %c1 = arith.constant 1 : index
1200/// %0 = arith.addi %arg1, %c1 overflow<nsw> : index
1201/// %1 = memref.load %arg0[%0] : memref<?xf32>
1202/// ```
1203class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
1204public:
1205 using Base::Base;
1206
1207 LogicalResult matchAndRewrite(vector::ExtractOp op,
1208 PatternRewriter &rewriter) const override {
1209 auto loadOp = op.getSource().getDefiningOp<vector::LoadOp>();
1210 if (!loadOp)
1211 return rewriter.notifyMatchFailure(op, "expected a load op");
1212
1213 // Checking for single use so we won't duplicate load ops.
1214 if (!loadOp->hasOneUse())
1215 return rewriter.notifyMatchFailure(op, "expected single op use");
1216
1217 VectorType loadVecType = loadOp.getVectorType();
1218 if (loadVecType.isScalable())
1219 return rewriter.notifyMatchFailure(op,
1220 "scalable vectors are not supported");
1221
1222 MemRefType memType = loadOp.getMemRefType();
1223
1224 // Non-byte-aligned types are tricky and may require special handling,
1225 // ignore them for now.
1226 if (!isSupportedMemSinkElementType(memType.getElementType()))
1227 return rewriter.notifyMatchFailure(op, "unsupported element type");
1228
1229 int64_t rankOffset = memType.getRank() - loadVecType.getRank();
1230 if (rankOffset < 0)
1231 return rewriter.notifyMatchFailure(op, "unsupported ranks combination");
1232
1233 auto extractVecType = dyn_cast<VectorType>(op.getResult().getType());
1234 int64_t finalRank = 0;
1235 if (extractVecType)
1236 finalRank = extractVecType.getRank();
1237
1238 SmallVector<Value> indices = loadOp.getIndices();
1239 SmallVector<OpFoldResult> extractPos = op.getMixedPosition();
1240
1241 // There may be memory stores between the load and the extract op, so we
1242 // need to make sure that the new load op is inserted at the same place as
1243 // the original load op.
1244 OpBuilder::InsertionGuard g(rewriter);
1245 rewriter.setInsertionPoint(loadOp);
1246 Location loc = loadOp.getLoc();
1247 ArithIndexingBuilder idxBuilderf(rewriter, loc);
1248 for (auto i : llvm::seq<int64_t>(rankOffset, indices.size() - finalRank)) {
1249 OpFoldResult pos = extractPos[i - rankOffset];
1250 if (isZeroInteger(pos))
1251 continue;
1252
1253 Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos);
1254 indices[i] = idxBuilderf.add(indices[i], offset);
1255 }
1256
1257 Value base = loadOp.getBase();
1258 if (extractVecType) {
1259 rewriter.replaceOpWithNewOp<vector::LoadOp>(op, extractVecType, base,
1260 indices);
1261 } else {
1262 rewriter.replaceOpWithNewOp<memref::LoadOp>(op, base, indices);
1263 }
1264 // We checked for single use so we can safely erase the load op.
1265 rewriter.eraseOp(loadOp);
1266 return success();
1267 }
1268};
1269
1270/// Pattern to rewrite vector.store(vector.broadcast) -> vector/memref.store.
1271///
1272/// Example:
1273/// ```
1274/// %0 = vector.broadcast %arg2 : f32 to vector<1xf32>
1275/// vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
1276/// ```
1277/// Gets converted to:
1278/// ```
1279/// memref.store %arg2, %arg0[%arg1] : memref<?xf32>
1280/// ```
1281class StoreOpFromBroadcast final : public OpRewritePattern<vector::StoreOp> {
1282public:
1283 using Base::Base;
1284
1285 LogicalResult matchAndRewrite(vector::StoreOp op,
1286 PatternRewriter &rewriter) const override {
1287 VectorType vecType = op.getVectorType();
1288 if (vecType.isScalable())
1289 return rewriter.notifyMatchFailure(op,
1290 "scalable vectors are not supported");
1291
1292 if (isa<VectorType>(op.getMemRefType().getElementType()))
1293 return rewriter.notifyMatchFailure(
1294 op, "memrefs of vectors are not supported");
1295
1296 if (vecType.getNumElements() != 1)
1297 return rewriter.notifyMatchFailure(
1298 op, "only 1-element vectors are supported");
1299
1300 Value toStore = op.getValueToStore();
1301 Value source = getBroadcastLikeSource(toStore);
1302 if (!source)
1303 return rewriter.notifyMatchFailure(
1304 op, "value to store is not from a broadcast");
1305
1306 // Checking for single use so we can remove broadcast.
1307 Operation *broadcast = toStore.getDefiningOp();
1308 if (!broadcast->hasOneUse())
1309 return rewriter.notifyMatchFailure(op, "expected single op use");
1310
1311 Value base = op.getBase();
1312 ValueRange indices = op.getIndices();
1313
1314 if (isa<VectorType>(source.getType())) {
1315 rewriter.replaceOpWithNewOp<vector::StoreOp>(op, source, base, indices);
1316 } else {
1317 rewriter.replaceOpWithNewOp<memref::StoreOp>(op, source, base, indices);
1318 }
1319 rewriter.eraseOp(broadcast);
1320 return success();
1321 }
1322};
1323
1324// Helper that returns a vector comparison that constructs a mask:
1325// mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
1326//
1327// If `dim == 0` then the result will be a 0-D vector.
1328//
1329// NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative,
1330// much more compact, IR for this operation, but LLVM eventually
1331// generates more elaborate instructions for this intrinsic since it
1332// is very conservative on the boundary conditions.
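//
// For example, with `dim == 4`, 32-bit indices, an offset `%off` and bound
// `%b`, the generated IR is roughly (a sketch; SSA names are illustrative):
//   %idx  = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
//   %offv = vector.broadcast %off : i32 to vector<4xi32>
//   %sum  = arith.addi %offv, %idx : vector<4xi32>
//   %bndv = vector.broadcast %b : i32 to vector<4xi32>
//   %mask = arith.cmpi slt, %sum, %bndv : vector<4xi32>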
1333static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
1334 bool force32BitVectorIndices, int64_t dim,
1335 Value b, Value *off = nullptr) {
1336 auto loc = op->getLoc();
1337 // If we can assume all indices fit in 32-bit, we perform the vector
1338 // comparison in 32-bit to get a higher degree of SIMD parallelism.
1339 // Otherwise we perform the vector comparison using 64-bit indices.
1340 Type idxType =
1341 force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
1342 DenseIntElementsAttr indicesAttr;
1343 if (dim == 0 && force32BitVectorIndices) {
1344 indicesAttr = DenseIntElementsAttr::get(
1345 VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int32_t>{0});
1346 } else if (dim == 0) {
1347 indicesAttr = DenseIntElementsAttr::get(
1348 VectorType::get(ArrayRef<int64_t>{}, idxType), ArrayRef<int64_t>{0});
1349 } else if (force32BitVectorIndices) {
1350 indicesAttr = rewriter.getI32VectorAttr(
1351 llvm::to_vector<4>(llvm::seq<int32_t>(0, dim)));
1352 } else {
1353 indicesAttr = rewriter.getI64VectorAttr(
1354 llvm::to_vector<4>(llvm::seq<int64_t>(0, dim)));
1355 }
1356 Value indices = arith::ConstantOp::create(rewriter, loc, indicesAttr);
1357 // Add in an offset if requested.
1358 if (off) {
1359 Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
1360 Value ov = vector::BroadcastOp::create(rewriter, loc, indices.getType(), o);
1361 indices = arith::AddIOp::create(rewriter, loc, ov, indices);
1362 }
1363 // Construct the vector comparison.
1364 Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
1365 Value bounds =
1366 vector::BroadcastOp::create(rewriter, loc, indices.getType(), bound);
1367 return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
1368 indices, bounds);
1369}
1370
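/// Materialize an explicit mask on a 1-D vector transfer op that has an
/// out-of-bounds dimension: build a vector.create_mask from
/// (dim_size - index), intersect it with any existing mask, and mark the
/// transfer as in-bounds (summary of the pattern below).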
1371template <typename ConcreteOp>
1372struct MaterializeTransferMask : public OpRewritePattern<ConcreteOp> {
1373public:
1374 explicit MaterializeTransferMask(MLIRContext *context, bool enableIndexOpt,
1375 PatternBenefit benefit = 1)
1376 : mlir::OpRewritePattern<ConcreteOp>(context, benefit),
1377 force32BitVectorIndices(enableIndexOpt) {}
1378
1379 LogicalResult matchAndRewrite(ConcreteOp xferOp,
1380 PatternRewriter &rewriter) const override {
1381 if (!xferOp.hasOutOfBoundsDim())
1382 return failure();
1383
1384 if (xferOp.getVectorType().getRank() > 1 || xferOp.getIndices().empty())
1385 return failure();
1386
1387 Location loc = xferOp->getLoc();
1388 VectorType vtp = xferOp.getVectorType();
1389
1390 // Create the in-bounds mask with all elements between [0 .. dim - offset)
1391 // set and [dim - offset .. vector_length) unset.
1392 //
1393 // TODO: when the leaf transfer rank is k > 1, we need the last `k`
1394 // dimensions here.
1395 unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1;
1396 Value off = xferOp.getIndices()[lastIndex];
1397 Value dim =
1398 vector::createOrFoldDimOp(rewriter, loc, xferOp.getBase(), lastIndex);
1399 Value b = arith::SubIOp::create(rewriter, loc, dim.getType(), dim, off);
1400 Value mask = vector::CreateMaskOp::create(
1401 rewriter, loc,
1402 VectorType::get(vtp.getShape(), rewriter.getI1Type(),
1403 vtp.getScalableDims()),
1404 b);
1405 if (xferOp.getMask()) {
1406 // Intersect the in-bounds with the mask specified as an op parameter.
1407 mask = arith::AndIOp::create(rewriter, loc, mask, xferOp.getMask());
1408 }
1409
1410 rewriter.modifyOpInPlace(xferOp, [&]() {
1411 xferOp.getMaskMutable().assign(mask);
1412 xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
1413 });
1414
1415 return success();
1416 }
1417
1418private:
1419 const bool force32BitVectorIndices;
1420};
1421
1422/// Conversion pattern for a `vector.create_mask` (0-D and 1-D only).
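/// For example (a sketch; `%m` and `%ub` are illustrative names):
///   `%m = vector.create_mask %ub : vector<4xi1>`
/// lowers to a comparison of the constant index vector [0, 1, 2, 3] against
/// `%ub` broadcast to a vector (see buildVectorComparison above).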
1423class VectorCreateMaskOpConversion
1424 : public OpRewritePattern<vector::CreateMaskOp> {
1425public:
1426 explicit VectorCreateMaskOpConversion(MLIRContext *context,
1427 bool enableIndexOpt,
1428 PatternBenefit benefit = 1)
1429 : mlir::OpRewritePattern<vector::CreateMaskOp>(context, benefit),
1430 force32BitVectorIndices(enableIndexOpt) {}
1431
1432 LogicalResult matchAndRewrite(vector::CreateMaskOp op,
1433 PatternRewriter &rewriter) const override {
1434 auto dstType = op.getType();
1435 if (cast<VectorType>(dstType).isScalable())
1436 return failure();
1437 int64_t rank = dstType.getRank();
1438 if (rank > 1)
1439 return failure();
1440 rewriter.replaceOp(
1441 op, buildVectorComparison(rewriter, op, force32BitVectorIndices,
1442 rank == 0 ? 0 : dstType.getDimSize(0),
1443 op.getOperand(0)));
1444 return success();
1445 }
1446
1447private:
1448 const bool force32BitVectorIndices;
1449};
1450
1451/// Returns true if all the `i1` elements of `constantOp` are set to `value`.
1452static bool allI1ConstantValuesSetTo(arith::ConstantOp constantOp, bool value) {
1453 auto denseAttr = dyn_cast<DenseIntElementsAttr>(constantOp.getValue());
1454 // TODO: Support non-dense constant.
1455 if (!denseAttr)
1456 return false;
1457
1458 assert(denseAttr.getElementType().isInteger(1) && "Unexpected type");
1459 return denseAttr.isSplat() && denseAttr.getSplatValue<bool>() == value;
1460}
1461
1462/// Folds a select operation between an all-true and all-false vector. For now,
1463/// only single element vectors (i.e., vector<1xi1>) are supported. That is:
1464///
1465/// %true = arith.constant dense<true> : vector<1xi1>
1466/// %false = arith.constant dense<false> : vector<1xi1>
1467/// %result = arith.select %cond, %true, %false : i1, vector<1xi1>
1468/// =>
1469/// %result = vector.broadcast %cond : i1 to vector<1xi1>
1470///
1471/// InstCombine seems to handle vectors with multiple elements but not the
1472/// single element ones.
1473struct FoldI1Select : public OpRewritePattern<arith::SelectOp> {
1474 using Base::Base;
1475
1476 LogicalResult matchAndRewrite(arith::SelectOp selectOp,
1477 PatternRewriter &rewriter) const override {
1478 auto vecType = dyn_cast<VectorType>(selectOp.getType());
1479 if (!vecType || !vecType.getElementType().isInteger(1))
1480 return failure();
1481
1482 // Only scalar conditions can be folded.
1483 Value cond = selectOp.getCondition();
1484 if (isa<VectorType>(cond.getType()))
1485 return failure();
1486
1487 // TODO: Support n-D and scalable vectors.
1488 if (vecType.getRank() != 1 || vecType.isScalable())
1489 return failure();
1490
1491 // TODO: Support vectors with multiple elements.
1492 if (vecType.getShape()[0] != 1)
1493 return failure();
1494
1495 auto trueConst = selectOp.getTrueValue().getDefiningOp<arith::ConstantOp>();
1496 if (!trueConst || !allI1ConstantValuesSetTo(trueConst, true))
1497 return failure();
1498
1499 auto falseConst =
1500 selectOp.getFalseValue().getDefiningOp<arith::ConstantOp>();
1501 if (!falseConst || !allI1ConstantValuesSetTo(falseConst, false))
1502 return failure();
1503
1504 // Replace select with its condition broadcasted to single element vector.
1505 auto elemType = rewriter.getIntegerType(vecType.getNumElements());
1506 auto bcastType = VectorType::get(/*shape=*/{1}, elemType);
1507 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(selectOp, bcastType, cond);
1508 return success();
1509 }
1510};
1511
1512/// Returns the number of dims that can be folded away from transfer ops. It
1513/// returns a failure if it cannot determine the number of dims to be folded.
1514///
1515/// Ex 1: returns "2" if `srcType` is memref<512x16x1x1xf32> and
1516///       `vectorType` is vector<16x16x1x1xf32>
1517///       (the two innermost dims can be dropped by memref.subview ops)
1518///
1519/// Ex 2: returns "1" if `srcType` is memref<512x16x1x1xf32> with
1520///       [8192, 16, 8, 1] strides and `vectorType` is vector<16x16x1x1xf32>
1521///       (only the innermost unit dim of `srcType` can be dropped)
1522///
1523/// Ex 3: returns "0" if `srcType` is memref<512x16x1x1xf32> and
1524///       `vectorType` is vector<16x16x1x[1]xf32>
1525///       (the innermost dim in `vectorType` is not a unit dim; it is a
1526///       "scalable unit" dim)
1527static FailureOr<size_t>
1528getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
1529 SmallVector<int64_t> srcStrides;
1530 int64_t srcOffset;
1531 if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
1532 return failure();
1533
1534 auto isUnitDim = [](VectorType type, int dim) {
1535 return type.getDimSize(dim) == 1 && !type.getScalableDims()[dim];
1536 };
1537
1538 // According to vector.transfer_read/write semantics, the vector can be a
1539 // slice. Thus, we have to offset the check index with `rankDiff` in
1540 // `srcStrides` and source dim sizes.
1541 size_t result = 0;
1542 int rankDiff = srcType.getRank() - vectorType.getRank();
1543 for (int64_t i = 0, e = vectorType.getRank(); i < e; ++i) {
1544 // Check that the inner dim size is 1 for both memref type and vector slice.
1545 // It can be folded only if they are 1 and the stride is 1.
1546 int dim = vectorType.getRank() - i - 1;
1547 if (srcStrides[dim + rankDiff] != 1 ||
1548 srcType.getDimSize(dim + rankDiff) != 1 || !isUnitDim(vectorType, dim))
1549 break;
1550 result++;
1551 }
1552 return result;
1553}
1554
1555/// Drop inner most contiguous unit dimensions from transfer_read operand.
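/// Illustrative example (assumed, mirroring the transfer_write example below):
///   %v = vector.transfer_read %src[%c0, %arg2, %c0, %c0, %c0], %pad
///     {in_bounds = [true, true, true, true, true]}
///     : memref<1x512x16x1x1xf32>, vector<1x16x16x1x1xf32>
/// becomes a transfer_read of vector<1x16x16xf32> from a rank-reduced
/// memref.subview, followed by a vector.shape_cast back to the original type.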
1556class DropInnerMostUnitDimsTransferRead
1557 : public OpRewritePattern<vector::TransferReadOp> {
1558 using Base::Base;
1559
1560 LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
1561 PatternRewriter &rewriter) const override {
1562 // TODO: support 0-d corner case.
1563 if (readOp.getTransferRank() == 0)
1564 return failure();
1565
1566 // TODO: support mask.
1567 if (readOp.getMask())
1568 return failure();
1569
1570 auto srcType = dyn_cast<MemRefType>(readOp.getBase().getType());
1571 if (!srcType)
1572 return failure();
1573
1574 if (!readOp.getPermutationMap().isMinorIdentity())
1575 return failure();
1576
1577 auto targetType = readOp.getVectorType();
1578 if (targetType.getRank() <= 1)
1579 return failure();
1580
1581 FailureOr<size_t> maybeDimsToDrop =
1582 getTransferFoldableInnerUnitDims(srcType, targetType);
1583 if (failed(maybeDimsToDrop))
1584 return failure();
1585
1586 size_t dimsToDrop = maybeDimsToDrop.value();
1587 if (dimsToDrop == 0)
1588 return failure();
1589
1590 auto inBounds = readOp.getInBoundsValues();
1591 auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
1592 if (llvm::is_contained(droppedInBounds, false))
1593 return failure();
1594
1595 auto resultTargetVecType =
1596 VectorType::get(targetType.getShape().drop_back(dimsToDrop),
1597 targetType.getElementType(),
1598 targetType.getScalableDims().drop_back(dimsToDrop));
1599
1600 auto loc = readOp.getLoc();
1601 SmallVector<OpFoldResult> sizes =
1602 memref::getMixedSizes(rewriter, loc, readOp.getBase());
1603 SmallVector<OpFoldResult> offsets(srcType.getRank(),
1604 rewriter.getIndexAttr(0));
1605 SmallVector<OpFoldResult> strides(srcType.getRank(),
1606 rewriter.getIndexAttr(1));
1607 MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
1608 srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
1609 strides);
1610 ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
1611 readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
1612 Value rankedReducedView =
1613 memref::SubViewOp::create(rewriter, loc, resultMemrefType,
1614 readOp.getBase(), offsets, sizes, strides);
1615 auto permMap = getTransferMinorIdentityMap(
1616 cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
1617 Value result = vector::TransferReadOp::create(
1618 rewriter, loc, resultTargetVecType, rankedReducedView,
1619 readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
1620 readOp.getPadding(),
1621 // TODO: support mask.
1622 /*mask=*/Value(), inBoundsAttr);
1623 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(readOp, targetType,
1624 result);
1625 return success();
1626 }
1627};
1628
1629/// Drop inner most contiguous unit dimensions from transfer_write operand.
1630/// E.g.,
1631/// vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
1632/// {in_bounds = [true, true, true, true, true]}
1633/// : vector<1x16x16x1x1xf32>, memref<1x512x16x1x1xf32>
1634///
1635/// will be replaced with
1636///
1637/// %subview = memref.subview %arg0
1638/// [0, 0, 0, 0, 0] [1, 512, 16, 1, 1] [1, 1, 1, 1, 1]
1639/// : memref<1x512x16x1x1xf32> to memref<1x512x16xf32>
1640/// %0 = vector.shape_cast %arg1 : vector<1x16x16x1x1xf32>
1641/// to vector<1x16x16xf32>
1642/// vector.transfer_write %0, %subview[%c0, %arg2, %c0]
1643/// {in_bounds = [true, true, true]}
1644/// : vector<1x16x16xf32>, memref<1x512x16xf32>
1645///
1646/// Note, this pattern will not collapse "scalable unit" dims (i.e. `[1]`).
1647class DropInnerMostUnitDimsTransferWrite
1648 : public OpRewritePattern<vector::TransferWriteOp> {
1649 using Base::Base;
1650
1651 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
1652 PatternRewriter &rewriter) const override {
1653 // TODO: support 0-d corner case.
1654 if (writeOp.getTransferRank() == 0)
1655 return failure();
1656
1657 // TODO: support mask.
1658 if (writeOp.getMask())
1659 return failure();
1660
1661 auto srcType = dyn_cast<MemRefType>(writeOp.getBase().getType());
1662 if (!srcType)
1663 return failure();
1664
1665 if (!writeOp.getPermutationMap().isMinorIdentity())
1666 return failure();
1667
1668 auto targetType = writeOp.getVectorType();
1669 if (targetType.getRank() <= 1)
1670 return failure();
1671
1672 FailureOr<size_t> maybeDimsToDrop =
1673 getTransferFoldableInnerUnitDims(srcType, targetType);
1674 if (failed(maybeDimsToDrop))
1675 return failure();
1676
1677 size_t dimsToDrop = maybeDimsToDrop.value();
1678 if (dimsToDrop == 0)
1679 return failure();
1680
1681 auto inBounds = writeOp.getInBoundsValues();
1682 auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
1683 if (llvm::is_contained(droppedInBounds, false))
1684 return failure();
1685
1686 auto resultTargetVecType =
1687 VectorType::get(targetType.getShape().drop_back(dimsToDrop),
1688 targetType.getElementType(),
1689 targetType.getScalableDims().drop_back(dimsToDrop));
1690
1691 Location loc = writeOp.getLoc();
1692 SmallVector<OpFoldResult> sizes =
1693 memref::getMixedSizes(rewriter, loc, writeOp.getBase());
1694 SmallVector<OpFoldResult> offsets(srcType.getRank(),
1695 rewriter.getIndexAttr(0));
1696 SmallVector<OpFoldResult> strides(srcType.getRank(),
1697 rewriter.getIndexAttr(1));
1698 MemRefType resultMemrefType = memref::SubViewOp::inferRankReducedResultType(
1699 srcType.getShape().drop_back(dimsToDrop), srcType, offsets, sizes,
1700 strides);
1701 ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
1702 writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
1703
1704 Value rankedReducedView =
1705 memref::SubViewOp::create(rewriter, loc, resultMemrefType,
1706 writeOp.getBase(), offsets, sizes, strides);
1707 auto permMap = getTransferMinorIdentityMap(
1708 cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
1709
1710 auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
1711 loc, resultTargetVecType, writeOp.getVector());
1712 rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
1713 writeOp, shapeCast, rankedReducedView,
1714 writeOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
1715 // TODO: support mask.
1716 /*mask=*/Value(), inBoundsAttr);
1717 return success();
1718 }
1719};
1720
1721/// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul
1722/// semantics to a contraction suitable for MMT (matrix matrix multiplication
1723/// with the RHS transposed) lowering.
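/// Illustrative example (assumed, derived from the handled cases below): a
/// plain row-major matmul with indexing maps
///   (m, k), (k, n) -> (m, n)
/// gets its RHS transposed so that the maps become the canonical MMT form
///   (m, k), (n, k) -> (m, n).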
1724struct CanonicalizeContractMatmulToMMT
1725 : OpRewritePattern<vector::ContractionOp> {
1726 using Base::Base;
1727
1728 using FilterConstraintType =
1729 std::function<LogicalResult(vector::ContractionOp op)>;
1730
1731 CanonicalizeContractMatmulToMMT(MLIRContext *context, PatternBenefit benefit,
1732 FilterConstraintType constraint)
1733 : OpRewritePattern<vector::ContractionOp>(context, benefit),
1734 filter(std::move(constraint)) {}
1735
1736 LogicalResult matchAndRewrite(vector::ContractionOp op,
1737 PatternRewriter &rewriter) const override {
1738 if (failed(filter(op)))
1739 return failure();
1740
1741 Location loc = op.getLoc();
1742 Value lhs = op.getLhs();
1743 Value rhs = op.getRhs();
1744 Value res = op.getAcc();
1745
1746 // Set up the parallel/reduction structure in right form.
1747 using MapList = ArrayRef<ArrayRef<AffineExpr>>;
1748 auto infer = [&](MapList m) {
1749 return AffineMap::inferFromExprList(m, op.getContext());
1750 };
1751 AffineExpr m;
1752 AffineExpr n;
1753 AffineExpr k;
1754 bindDims(rewriter.getContext(), m, n, k);
1755 static constexpr std::array<int64_t, 2> perm = {1, 0};
1756 auto iteratorTypes = op.getIteratorTypes().getValue();
1757 SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
1758 if (iteratorTypes.size() != 3 ||
1759 !vector::isParallelIterator(iteratorTypes[0]) ||
1760 !vector::isParallelIterator(iteratorTypes[1]) ||
1761 !vector::isReductionIterator(iteratorTypes[2]))
1762 return rewriter.notifyMatchFailure(op, "contraction is not a gemm");
1763
1764 // The canonical form is "TNT" = A row-major, B col-major, C row-major.
1765 const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
1766 if (maps == canonicalForm)
1767 return rewriter.notifyMatchFailure(op, "already in the canonical form");
1768
1769 // Create a vector transpose making sure to emit zero/sign-extend at the
1770 // end.
1771 auto createTranspose = [&rewriter, loc](Value mat) -> Value {
1772 if (auto sext = mat.getDefiningOp<arith::ExtSIOp>()) {
1773 Value trans =
1774 vector::TransposeOp::create(rewriter, loc, sext.getIn(), perm);
1775 VectorType newType =
1776 cast<VectorType>(trans.getType())
1777 .clone(cast<VectorType>(mat.getType()).getElementType());
1778 return arith::ExtSIOp::create(rewriter, loc, newType, trans);
1779 }
1780 if (auto zext = mat.getDefiningOp<arith::ExtUIOp>()) {
1781 Value trans =
1782 vector::TransposeOp::create(rewriter, loc, zext.getIn(), perm);
1783 VectorType newType =
1784 VectorType::get(cast<VectorType>(trans.getType()).getShape(),
1785 cast<VectorType>(mat.getType()).getElementType());
1786 return arith::ExtUIOp::create(rewriter, loc, newType, trans);
1787 }
1788 return vector::TransposeOp::create(rewriter, loc, mat, perm);
1789 };
1790
1791 if (maps == infer({{m, k}, {k, n}, {m, n}})) {
1792 rhs = createTranspose(rhs);
1793 } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
1794 lhs = createTranspose(lhs);
1795 } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
1796 rhs = createTranspose(rhs);
1797 lhs = createTranspose(lhs);
1798 } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
1799 std::swap(rhs, lhs);
1800 rhs = createTranspose(rhs);
1801 lhs = createTranspose(lhs);
1802 } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
1803 std::swap(rhs, lhs);
1804 rhs = createTranspose(rhs);
1805 } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
1806 std::swap(lhs, rhs);
1807 lhs = createTranspose(lhs);
1808 } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
1809 std::swap(lhs, rhs);
1810 } else {
1811 return rewriter.notifyMatchFailure(op, "unhandled contraction form");
1812 }
1813 rewriter.replaceOpWithNewOp<vector::ContractionOp>(
1814 op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
1815 op.getIteratorTypes());
1816 return success();
1817 };
1818
1819private:
1820 FilterConstraintType filter;
1821};
1822
1823/// Pattern to fold arithmetic extensions on floating point data types into
1824/// vector contraction operations. linalg.matmul introduces arithmetic
1825/// extensions on its operands. Please see the mlir snippets below for details.
1826/// ```mlir
1827/// "linalg.matmul"(%lhs, %rhs, %acc) ({
1828/// ^bb0(%arg1: f16, %arg2: f16, %arg3: f32):
1829/// %lhs_f32 = "arith.extf"(%arg1) : (f16) -> f32
1830/// %rhs_f32 = "arith.extf"(%arg2) : (f16) -> f32
1831/// %mul = "arith.mulf"(%lhs_f32, %rhs_f32) : (f32, f32) -> f32
1832/// %acc = "arith.addf"(%arg3, %mul) : (f32, f32) -> f32
1833/// "linalg.yield"(%acc) : (f32) -> ()
1834/// })
1835/// ```
1836/// This restricts the native usage of mixed-precision NVIDIA Ampere Tensor
1837/// Cores, i.e., `mma.sync.*.f32.f16.f16.f32` and `mma.sync.*.f32.bf16.bf16.f32`.
1838/// This pattern folds the arithmetic extensions into the vector contraction and
1839/// enables the usage of native mixed precision Tensor Core instructions.
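/// Illustrative example at the vector level (assumed): given
///   %lhs32 = arith.extf %lhs : vector<4x4xf16> to vector<4x4xf32>
///   %rhs32 = arith.extf %rhs : vector<4x4xf16> to vector<4x4xf32>
///   %res = vector.contract ... %lhs32, %rhs32, %acc : ... into vector<4x4xf32>
/// the pattern rewrites the contract to consume %lhs and %rhs directly,
/// keeping the same indexing maps, iterator types, and f32 accumulator.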
1840template <typename ExtOp>
1841struct FoldArithExtIntoContractionOp
1842 : public OpRewritePattern<vector::ContractionOp> {
1843 using Base::Base;
1844
1845 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
1846 PatternRewriter &rewriter) const override {
1847
1848 auto lhsDefOp = contractOp.getLhs().getDefiningOp<ExtOp>();
1849 auto rhsDefOp = contractOp.getRhs().getDefiningOp<ExtOp>();
1850
1851 if (!lhsDefOp || !rhsDefOp) {
1852 return rewriter.notifyMatchFailure(contractOp,
1853 "no defining op on contract operands");
1854 }
1855
1856 rewriter.replaceOpWithNewOp<vector::ContractionOp>(
1857 contractOp, lhsDefOp->getOperand(0), rhsDefOp->getOperand(0),
1858 contractOp.getAcc(), contractOp.getIndexingMapsAttr(),
1859 contractOp.getIteratorTypesAttr());
1860
1861 return success();
1862 }
1863};
1864
1865/// Pattern to fold chained reduction to a series of vector additions and a
1866/// final reduction. This form should require fewer subgroup operations.
1867///
1868/// ```mlir
1869/// %a = vector.reduction <add> %x, %acc
1870/// %b = vector.reduction <add> %y, %a
1871/// ==>
1872/// %a = arith.addf %x, %y
1873/// %b = vector.reduction <add> %a, %acc
1874/// ```
1875struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
1876 using Base::Base;
1877
1878 LogicalResult matchAndRewrite(vector::ReductionOp op,
1879 PatternRewriter &rewriter) const override {
1880 // TODO: Handle other combining kinds.
1881 if (op.getKind() != vector::CombiningKind::ADD)
1882 return failure();
1883
1884 // Accumulator is optional.
1885 Value acc = op.getAcc();
1886 if (!acc)
1887 return failure();
1888
1889 if (!acc.getType().isIntOrFloat())
1890 return failure();
1891
1892 auto parentReduction = acc.getDefiningOp<vector::ReductionOp>();
1893 if (!parentReduction)
1894 return failure();
1895
1896 Location loc = op.getLoc();
1897 Value vAdd;
1898 if (isa<IntegerType>(acc.getType())) {
1899 vAdd = rewriter.createOrFold<arith::AddIOp>(
1900 loc, parentReduction.getVector(), op.getVector());
1901 } else {
1902 vAdd = arith::AddFOp::create(rewriter, loc, parentReduction.getVector(),
1903 op.getVector());
1904 }
1905 rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), vAdd,
1906 parentReduction.getAcc());
1907 return success();
1908 }
1909};
1910
1911// Helper function dropping unit non-scalable dimensions from a VectorType,
1912// keeping at least one dimension to avoid generating 0-D vectors. Scalable unit
1913// dimensions are not dropped. Folding such dimensions would require "shifting"
1914// the scalable flag onto some other fixed-width dim (e.g. vector<[1]x4xf32> ->
1915// vector<[4]xf32>). This could be implemented in the future.
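// Illustrative examples (assumed): vector<1x4x1xf32> -> vector<4xf32>;
// vector<1x1xf32> -> vector<1xf32> (at least one dim is kept);
// vector<[1]x4xf32> is returned unchanged because its unit dim is scalable.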
1916static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
1917 auto inVecShape = inVecTy.getShape();
1918 SmallVector<int64_t> newShape;
1919 SmallVector<bool> newScalableDims;
1920 for (auto [dim, isScalable] :
1921 llvm::zip_equal(inVecShape, inVecTy.getScalableDims())) {
1922 if (dim == 1 && !isScalable)
1923 continue;
1924
1925 newShape.push_back(dim);
1926 newScalableDims.push_back(isScalable);
1927 }
1928 // All dims have been dropped, return vector<1xeType>.
1929 if (newShape.empty()) {
1930 newShape.push_back(1);
1931 newScalableDims.push_back(false);
1932 }
1933
1934 return VectorType::get(newShape, inVecTy.getElementType(), newScalableDims);
1935}
1936
1937/// For vectors with at least one unit dim, replaces:
1938/// elementwise(a, b)
1939/// with:
1940/// sc_a = shape_cast(a)
1941/// sc_b = shape_cast(b)
1942/// res = elementwise(sc_a, sc_b)
1943/// return shape_cast(res)
1944/// The newly inserted shape_cast Ops fold (before elementwise Op) and then
1945/// restore (after elementwise Op) the unit dim. Vectors `a` and `b` are
1946/// required to be rank > 1.
1947///
1948/// Ex:
1949/// %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
1950/// %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
1951///
1952/// gets converted to:
1953///
1954/// %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
1955/// %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
1956/// %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
1957/// %cast_new = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32>
1958/// %cast = vector.shape_cast %cast_new : vector<1x[4]xf32> to vector<[4]xf32>
1959///
1960/// Patterns for folding shape_casts should instantly eliminate `%cast_new` and
1961/// `%cast`.
1962struct DropUnitDimFromElementwiseOps final
1963 : public OpTraitRewritePattern<OpTrait::Elementwise> {
1964 using OpTraitRewritePattern::OpTraitRewritePattern;
1965 LogicalResult matchAndRewrite(Operation *op,
1966 PatternRewriter &rewriter) const override {
1967 if (op->getNumResults() != 1 || op->getNumRegions() != 0)
1968 return failure();
1969
1970 auto resultVectorType = dyn_cast<VectorType>(op->getResult(0).getType());
1971 if (!resultVectorType)
1972 return failure();
1973
1974 // Check the operand pre-conditions. For `Elementwise` ops all operands are
1975 // guaranteed to have identical shapes (with some exceptions such as
1976 // `arith.select`) and it suffices to only check one of them.
1977 auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType());
1978 if (!sourceVectorType)
1979 return failure();
1980 if (sourceVectorType.getRank() < 2)
1981 return failure();
1982
1983 SmallVector<Value> newOperands;
1984 auto loc = op->getLoc();
1985 for (auto operand : op->getOperands()) {
1986 auto opVectorType = cast<VectorType>(operand.getType());
1987 auto newVType = dropNonScalableUnitDimFromType(opVectorType);
1988 if (newVType == opVectorType)
1989 return rewriter.notifyMatchFailure(op, "No unit dimension to remove.");
1990
1991 auto opSC = vector::ShapeCastOp::create(rewriter, loc, newVType, operand);
1992 newOperands.push_back(opSC);
1993 }
1994
1995 VectorType newResultVectorType =
1996 dropNonScalableUnitDimFromType(resultVectorType);
1997 // Create an updated elementwise Op without unit dim.
1998 Operation *elementwiseOp =
1999 rewriter.create(loc, op->getName().getIdentifier(), newOperands,
2000 newResultVectorType, op->getAttrs());
2001
2002 // Restore the unit dim by applying vector.shape_cast to the result.
2003 rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType,
2004 elementwiseOp->getResult(0));
2005
2006 return success();
2007 }
2008};
2009
2010/// A pattern to drop unit dims from vector.transpose.
2011///
2012/// Example:
2013///
2014/// BEFORE:
2015/// ```mlir
2016/// %transpose = vector.transpose %vector, [3, 0, 1, 2]
2017/// : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
2018/// ```
2019///
2020/// AFTER:
2021/// ```mlir
2022/// %dropDims = vector.shape_cast %vector
2023/// : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
2024/// %transpose = vector.transpose %0, [1, 0]
2025/// : vector<4x[4]xf32> to vector<[4]x4xf32>
2026/// %restoreDims = vector.shape_cast %transpose
2027/// : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
2028/// ```
2029struct DropUnitDimsFromTransposeOp final
2030 : OpRewritePattern<vector::TransposeOp> {
2031 using Base::Base;
2032
2033 LogicalResult matchAndRewrite(vector::TransposeOp op,
2034 PatternRewriter &rewriter) const override {
2035 VectorType sourceType = op.getSourceVectorType();
2036 VectorType sourceTypeWithoutUnitDims =
2037 dropNonScalableUnitDimFromType(sourceType);
2038
2039 if (sourceType == sourceTypeWithoutUnitDims)
2040 return failure();
2041
2042 // Construct a map from dimIdx -> number of dims dropped before dimIdx.
2043 auto sourceDims = llvm::to_vector(vector::getDims(sourceType));
2044 SmallVector<int64_t> droppedDimsBefore(sourceType.getRank());
2045 int64_t droppedDims = 0;
2046 for (auto [i, dim] : llvm::enumerate(sourceDims)) {
2047 droppedDimsBefore[i] = droppedDims;
2048 if (dim == std::make_tuple(1, false))
2049 ++droppedDims;
2050 }
2051
2052 // Drop unit dims from transpose permutation.
2053 ArrayRef<int64_t> perm = op.getPermutation();
2054 SmallVector<int64_t> newPerm;
2055 for (int64_t idx : perm) {
2056 if (sourceDims[idx] == std::make_tuple(1, false))
2057 continue;
2058 newPerm.push_back(idx - droppedDimsBefore[idx]);
2059 }
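 // For the doc example above (illustrative): the source dims are 1, 1, 4 and
 // a scalable [4], so droppedDimsBefore = [0, 1, 2, 2]; applying it to
 // perm = [3, 0, 1, 2] keeps indices 3 and 2 and yields newPerm = [1, 0].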
2060
2061 // Fixup for `newPerm`: when all the source dimensions are unit dimensions,
2062 // `sourceTypeWithoutUnitDims` collapses to vector<1xT> and `newPerm` ends up
2063 // empty. In this case, the newPerm should be [0].
2064 if (newPerm.empty()) {
2065 newPerm.push_back(0);
2066 }
2067
2068 Location loc = op.getLoc();
2069 // Drop the unit dims via shape_cast.
2070 auto dropDimsShapeCast = vector::ShapeCastOp::create(
2071 rewriter, loc, sourceTypeWithoutUnitDims, op.getVector());
2072 // Create the new transpose.
2073 auto transposeWithoutUnitDims =
2074 vector::TransposeOp::create(rewriter, loc, dropDimsShapeCast, newPerm);
2075 // Restore the unit dims via shape cast.
2076 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
2077 op, op.getResultVectorType(), transposeWithoutUnitDims);
2078
2079 return success();
2080 }
2081};
2082
2083/// A pattern to drop unit dims from the iter_args of an scf.for.
2084///
2085/// Example:
2086///
2087/// BEFORE:
2088/// ```mlir
2089/// %res = scf.for ... iter_args(%iter = %init) -> vector<[4]x1x1x4xf32> {
2090/// ...
2091/// scf.yield %
2092/// }
2093/// ```
2094///
2095/// AFTER:
2096/// ```mlir
2097/// %drop = vector.shape_cast %init
2098/// : vector<4x1x1x[4]xf32> to vector<4x[4]xf32>
2099/// %new_loop = scf.for ... iter_args(%iter = %drop) -> vector<[4]x4xf32> {
2100/// %new_iter = vector.shape_cast %iter
2101/// : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
2102/// ...
2103/// }
2104/// %res = vector.shape_cast %new_loop
2105/// : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
2106/// ```
2107struct DropUnitDimsFromScfForOp final : OpRewritePattern<scf::ForOp> {
2108 using Base::Base;
2109
2110 LogicalResult matchAndRewrite(scf::ForOp forOp,
2111 PatternRewriter &rewriter) const override {
2112 /// Find the first iter_arg with droppable unit dims. Further applications
2113 /// of this pattern will apply to later arguments.
2114 for (OpOperand &operand : forOp.getInitArgsMutable()) {
2115 auto vectorType = dyn_cast<VectorType>(operand.get().getType());
2116 if (!vectorType)
2117 continue;
2118
2119 VectorType newVectorType = dropNonScalableUnitDimFromType(vectorType);
2120 if (vectorType == newVectorType)
2121 continue;
2122
2123 // Create a new ForOp with that iter operand replaced.
2124 auto castFn = [](OpBuilder &b, Location loc, Type type, Value source) {
2125 return vector::ShapeCastOp::create(b, loc, type, source);
2126 };
2127
2128 Value replacement =
2129 castFn(rewriter, forOp.getLoc(), newVectorType, operand.get());
2130 rewriter.replaceOp(forOp,
2131 replaceAndCastForOpIterArg(rewriter, forOp, operand,
2132 replacement, castFn));
2133 return success();
2134 }
2135 return failure();
2136 }
2137};
2138
2139/// Pattern to eliminate redundant zero-constants added to reduction operands.
2140/// It's enough for there to be one initial zero value, so we can eliminate the
2141/// extra ones that feed into `vector.reduction <add>`. These get created by the
2142/// `ChainedReduction` pattern.
2143///
2144/// ```mlir
2145/// %a = arith.addf %x, %zero
2146/// %b = arith.addf %a, %y
2147/// %c = vector.reduction <add> %b, %acc
2148/// ==>
2149/// %b = arith.addf %x, %y
2150/// %c = vector.reduction <add> %b, %acc
2151/// ```
2152struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> {
2153 using Base::Base;
2154
2155 LogicalResult matchAndRewrite(vector::ReductionOp op,
2156 PatternRewriter &rewriter) const override {
2157 // TODO: Handle other reduction kinds and their identity values.
2158 if (op.getKind() != vector::CombiningKind::ADD)
2159 return failure();
2160
2161 Type elemType = op.getSourceVectorType().getElementType();
2162 // The integer case should be handled by `arith.addi` folders, only check
2163 // for floats here.
2164 if (!isa<FloatType>(elemType))
2165 return failure();
2166
2167 auto vAdd = op.getVector().getDefiningOp<arith::AddFOp>();
2168 if (!vAdd)
2169 return failure();
2170 auto addLhs = vAdd.getLhs().getDefiningOp<arith::AddFOp>();
2171 if (!addLhs)
2172 return failure();
2173
2174 if (!matchPattern(addLhs.getRhs(), m_AnyZeroFloat()))
2175 return failure();
2176
2177 auto newAdd = arith::AddFOp::create(rewriter, vAdd.getLoc(),
2178 addLhs.getLhs(), vAdd.getRhs());
2179 rewriter.replaceOpWithNewOp<vector::ReductionOp>(op, op.getKind(), newAdd,
2180 op.getAcc());
2181 return success();
2182 }
2183};
2184
2185/// Example:
2186/// ```
2187/// %a = vector.reduction <add> %x : vector<2xf32> into f32
2188/// ```
2189/// is transformed into:
2190/// ```
2191/// %y = vector.extract %x[0] : f32 from vector<2xf32>
2192/// %z = vector.extract %x[1] : f32 from vector<2xf32>
2193/// %a = arith.addf %y, %z : f32
2194/// ```
2195struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
2196 BreakDownVectorReduction(MLIRContext *context,
2197 unsigned maxNumElementsToExtract,
2198 PatternBenefit benefit)
2199 : OpRewritePattern(context, benefit),
2200 maxNumElementsToExtract(maxNumElementsToExtract) {}
2201
2202 LogicalResult matchAndRewrite(vector::ReductionOp op,
2203 PatternRewriter &rewriter) const override {
2204 VectorType type = op.getSourceVectorType();
2205 if (type.isScalable() || op.isMasked())
2206 return failure();
2207 assert(type.getRank() == 1 && "Expected a 1-d vector");
2208
2209 int64_t numElems = type.getNumElements();
2210 if (numElems > maxNumElementsToExtract) {
2211 return rewriter.notifyMatchFailure(
2212 op, llvm::formatv("has too many vector elements ({0}) to break down "
2213 "(max allowed: {1})",
2214 numElems, maxNumElementsToExtract));
2215 }
2216
2217 Location loc = op.getLoc();
2218 SmallVector<Value> extracted(numElems, nullptr);
2219 for (auto [idx, extractedElem] : llvm::enumerate(extracted))
2220 extractedElem = vector::ExtractOp::create(rewriter, loc, op.getVector(),
2221 static_cast<int64_t>(idx));
2222
2223 Value res = extracted.front();
2224 for (auto extractedElem : llvm::drop_begin(extracted))
2225 res = vector::makeArithReduction(rewriter, loc, op.getKind(), res,
2226 extractedElem, op.getFastmathAttr());
2227 if (Value acc = op.getAcc())
2228 res = vector::makeArithReduction(rewriter, loc, op.getKind(), res, acc,
2229 op.getFastmathAttr());
2230
2231 rewriter.replaceOp(op, res);
2232 return success();
2233 }
2234
2235private:
2236 unsigned maxNumElementsToExtract = 0;
2237};
2238
2239/// Fold `mulf(tr(broadcast(A)), broadcast(B))` into `vector.outerproduct(A,
2240/// B)`.
2241/// Example:
2242/// %lhsBcast = vector.broadcast %lhs : vector<4xi32> to vector<4x4xi32>
2243/// %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<4x4xi32> to vector<4x4xi32>
2244/// %rhsBcast = vector.broadcast %rhs : vector<4xi32> to vector<4x4xi32>
2245/// %mul = arith.muli %lhsT, %rhsBcast : vector<4x4xi32>
2246///
2247/// Becomes :
2248///
2249/// %res = vector.outerproduct %lhs, %rhs : vector<4xi32>, vector<4xi32>
2250///
2251/// Supports only 1D-to-2D broadcasts. The following cases are not supported.
2252/// %ex1 = vector.broadcast %lhsCast : vector<1x4xf32> to vector<4x4xf32>
2253/// %ex2 = vector.broadcast %lhsCast : f32 to vector<4x4xf32>
2254/// %ex3 = vector.broadcast %lhsCast : vector<1x1xf32> to vector<4x4xf32>
2255template <typename MulOpType>
2256struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
2257 using OpRewritePattern<MulOpType>::OpRewritePattern;
2258 // Returns whether a vector.broadcast matches the requirements of an
2259 // outerproduct pattern, i.e., a 1D-to-2D broadcastOp without broadcasted unit dimensions.
2260 bool isValidBroadcastSource(vector::BroadcastOp broadcastOp) const {
2261 // Fail if this is not a 1D-to-2D broadcast, to avoid generating
2262 // shape_casts/broadcasts that do not belong in this pattern.
2263 if (!broadcastOp.computeBroadcastedUnitDims().empty())
2264 return false;
2265 // Avoid broadcast like f32 or vector<f32> -> ResType
2266 auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
2267 return srcType && srcType.getRank() != 2;
2268 }
2269
2270 LogicalResult matchAndRewrite(MulOpType mulOp,
2271 PatternRewriter &rewriter) const override {
2272 auto resType = llvm::dyn_cast<VectorType>(mulOp.getResult().getType());
2273 if (!resType)
2274 return failure();
2275 if (resType.getRank() != 2)
2276 return failure();
2277 /// If operandA can be written as tr(broadcast(A)) and operandB as
2278 /// broadcast(B) where broadcasts are 1D-to-2D, create and return
2279 /// vector.outerproduct(A, B). Returns failure() otherwise.
2280 auto matchOuterProduct =
2281 [&](Value operandA,
2282 Value operandB) -> FailureOr<vector::OuterProductOp> {
2283 auto transposedLhs = operandA.getDefiningOp<vector::TransposeOp>();
2284 if (!transposedLhs)
2285 return failure();
2286 // Fail unless this is a true 2-D matrix transpose.
2287 ArrayRef<int64_t> permutation = transposedLhs.getPermutation();
2288 if (permutation.size() != 2 || permutation[0] != 1 || permutation[1] != 0)
2289 return failure();
2290
2291 auto broadcastedLhs =
2292 transposedLhs.getVector().getDefiningOp<vector::BroadcastOp>();
2293 if (!broadcastedLhs || !isValidBroadcastSource(broadcastedLhs))
2294 return failure();
2295
2296 auto broadcastedRhs = operandB.getDefiningOp<vector::BroadcastOp>();
2297 if (!broadcastedRhs || !isValidBroadcastSource(broadcastedRhs))
2298 return failure();
2299
2300 return vector::OuterProductOp::create(
2301 rewriter, mulOp->getLoc(), resType, broadcastedLhs.getSource(),
2302 broadcastedRhs.getSource(), Value(), vector::CombiningKind::ADD);
2303 };
2304
2305 Value lhs = mulOp->getOperand(0), rhs = mulOp->getOperand(1);
2306 auto maybeOuterP = matchOuterProduct(lhs, rhs);
2307 // Handle commutativity, the transposed op is the outerproduct LHS.
2308 if (failed(maybeOuterP))
2309 maybeOuterP = matchOuterProduct(rhs, lhs);
2310 if (failed(maybeOuterP))
2311 return failure();
2312 rewriter.replaceOp(mulOp, maybeOuterP->getResult());
2313 return success();
2314 }
2315};
2316
2317} // namespace
2318
2319void mlir::vector::populateFoldArithExtensionPatterns(
2320 RewritePatternSet &patterns) {
2321 patterns.add<FoldArithExtIntoContractionOp<arith::ExtFOp>,
2322 FoldArithExtIntoContractionOp<arith::ExtSIOp>>(
2323 patterns.getContext());
2324}
2325
2326void mlir::vector::populateVectorMaskMaterializationPatterns(
2327 RewritePatternSet &patterns, bool force32BitVectorIndices,
2328 PatternBenefit benefit) {
2329 patterns.add<VectorCreateMaskOpConversion,
2330 MaterializeTransferMask<vector::TransferReadOp>,
2331 MaterializeTransferMask<vector::TransferWriteOp>>(
2332 patterns.getContext(), force32BitVectorIndices, benefit);
2333 patterns.add<FoldI1Select>(patterns.getContext(), benefit);
2334}
2335
2336void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
2337 RewritePatternSet &patterns, PatternBenefit benefit) {
2338 patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromScfForOp,
2339 DropUnitDimsFromTransposeOp>(patterns.getContext(), benefit);
2340}
2341
2342void mlir::vector::populateBubbleVectorBitCastOpPatterns(
2343 RewritePatternSet &patterns, PatternBenefit benefit) {
2344 patterns.add<BubbleDownVectorBitCastForExtract,
2345 BubbleDownBitCastForStridedSliceExtract,
2346 BubbleUpBitCastForInsert, BubbleUpBitCastForStridedSliceInsert>(
2347 patterns.getContext(), benefit);
2348}
2349
2350void mlir::vector::populateBreakDownVectorBitCastOpPatterns(
2351 RewritePatternSet &patterns,
2352 std::function<bool(vector::BitCastOp)> controlFn, PatternBenefit benefit) {
2353 patterns.add<BreakDownVectorBitCast>(patterns.getContext(),
2354 std::move(controlFn), benefit);
2355}
2356
2357void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
2358 RewritePatternSet &patterns,
2359 std::function<LogicalResult(vector::ContractionOp)> constraint,
2360 PatternBenefit benefit) {
2361 patterns.add<CanonicalizeContractMatmulToMMT>(patterns.getContext(), benefit,
2362 std::move(constraint));
2363}
2364
2365void mlir::vector::populateVectorReductionToContractPatterns(
2366 RewritePatternSet &patterns, PatternBenefit benefit) {
2367 patterns.add<MultiReduceToContract, CombineContractBroadcastMask,
2368 CombineContractABTranspose, CombineContractResultTranspose>(
2369 patterns.getContext(), benefit);
2370}
2371
2372void mlir::vector::populateDropInnerMostUnitDimsXferOpPatterns(
2373 RewritePatternSet &patterns, PatternBenefit benefit) {
2374 patterns.add<DropInnerMostUnitDimsTransferRead,
2375 DropInnerMostUnitDimsTransferWrite>(patterns.getContext(),
2376 benefit);
2377}
2378
2379void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
2380 PatternBenefit benefit) {
2381 patterns.add<ReorderElementwiseOpsOnTranspose, ReorderCastOpsOnBroadcast,
2382 ReorderElementwiseOpsOnBroadcast, ExtractOpFromElementwise>(
2383 patterns.getContext(), benefit);
2384}
2385
2386void mlir::vector::populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
2387 PatternBenefit benefit) {
2388 // TODO: Consider converting these patterns to canonicalizations.
2389 patterns.add<ExtractOpFromLoad, StoreOpFromBroadcast>(patterns.getContext(),
2390 benefit);
2391}
2392
2393void mlir::vector::populateChainedVectorReductionFoldingPatterns(
2395 patterns.add<ChainedReduction>(patterns.getContext(), benefit);
2396 patterns.add<ReduceRedundantZero>(patterns.getContext(),
2397 PatternBenefit(benefit.getBenefit() + 1));
2398}
2399
2400void mlir::vector::populateBreakDownVectorReductionPatterns(
2401 RewritePatternSet &patterns, unsigned maxNumElementsToExtract,
2402 PatternBenefit benefit) {
2403 patterns.add<BreakDownVectorReduction>(patterns.getContext(),
2404 maxNumElementsToExtract, benefit);
2405}
2406
2407void mlir::vector::populateElementwiseToVectorOpsPatterns(
2408 RewritePatternSet &patterns) {
2409 patterns.add<FoldArithToVectorOuterProduct<arith::MulFOp>,
2410 FoldArithToVectorOuterProduct<arith::MulIOp>>(
2411 patterns.getContext());
2412}
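
// Illustrative usage sketch (assumed, not part of the original file): the
// populate* entry points above are typically combined into a greedy rewrite:
//
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorMaskMaterializationPatterns(
//       patterns, /*force32BitVectorIndices=*/false);
//   vector::populateSinkVectorOpsPatterns(patterns);
//   (void)applyPatternsGreedily(rootOp, std::move(patterns));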
2413
2414//===----------------------------------------------------------------------===//
2415// TableGen'd enum attribute definitions
2416//===----------------------------------------------------------------------===//
2417
2418#include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.cpp.inc"