MLIR  22.0.0git
LowerVectorShuffle.cpp
Go to the documentation of this file.
1 //===- LowerVectorShuffle.cpp - Lower 'vector.shuffle' operation ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the lowering of complex `vector.shuffle` operations to a
10 // set of simpler operations supported by LLVM/SPIR-V.
11 //
12 //===----------------------------------------------------------------------===//
13 
17 #include "mlir/IR/PatternMatch.h"
18 
19 #define DEBUG_TYPE "vector-shuffle-lowering"
20 
21 using namespace mlir;
22 using namespace mlir::vector;
23 
24 namespace {
25 
26 /// Lowers a `vector.shuffle` operation with mixed-size inputs to a new
27 /// `vector.shuffle` which promotes the smaller input to the larger vector size
28 /// and an updated version of the original `vector.shuffle`.
29 ///
30 /// Example:
31 ///
32 /// %0 = vector.shuffle %v1, %v2 [0, 2, 1, 3] : vector<2xf32>, vector<4xf32>
33 ///
34 /// is lowered to:
35 ///
36 /// %0 = vector.shuffle %v1, %v1 [0, 1, -1, -1] :
37 /// vector<2xf32>, vector<2xf32>
38 /// %1 = vector.shuffle %0, %v2 [0, 4, 1, 5] :
39 /// vector<4xf32>, vector<4xf32>
40 ///
41 /// Note: This transformation helps legalize vector.shuffle ops when lowering
42 /// to SPIR-V/LLVM, which don't support shuffle operations with mixed-size
43 /// inputs.
44 ///
45 struct MixedSizeInputShuffleOpRewrite final
46  : OpRewritePattern<vector::ShuffleOp> {
48 
49  LogicalResult matchAndRewrite(vector::ShuffleOp shuffleOp,
50  PatternRewriter &rewriter) const override {
51  auto v1Type = shuffleOp.getV1VectorType();
52  auto v2Type = shuffleOp.getV2VectorType();
53 
54  // Only support 1-D shuffle for now.
55  if (v1Type.getRank() != 1 || v2Type.getRank() != 1)
56  return failure();
57 
58  // Bail out if inputs don't have mixed sizes.
59  int64_t v1OrigNumElems = v1Type.getNumElements();
60  int64_t v2OrigNumElems = v2Type.getNumElements();
61  if (v1OrigNumElems == v2OrigNumElems)
62  return failure();
63 
64  // Determine which input needs promotion.
65  bool promoteV1 = v1OrigNumElems < v2OrigNumElems;
66  Value inputToPromote = promoteV1 ? shuffleOp.getV1() : shuffleOp.getV2();
67  VectorType promotedType = promoteV1 ? v2Type : v1Type;
68  int64_t origNumElems = promoteV1 ? v1OrigNumElems : v2OrigNumElems;
69  int64_t promotedNumElems = promoteV1 ? v2OrigNumElems : v1OrigNumElems;
70 
71  // Create a shuffle with a mask that preserves existing elements and fills
72  // up with poison.
73  SmallVector<int64_t> promoteMask(promotedNumElems, ShuffleOp::kPoisonIndex);
74  for (int64_t i = 0; i < origNumElems; ++i)
75  promoteMask[i] = i;
76 
77  Value promotedInput = rewriter.create<vector::ShuffleOp>(
78  shuffleOp.getLoc(), promotedType, inputToPromote, inputToPromote,
79  promoteMask);
80 
81  // Create the final shuffle with the promoted inputs.
82  Value promotedV1 = promoteV1 ? promotedInput : shuffleOp.getV1();
83  Value promotedV2 = promoteV1 ? shuffleOp.getV2() : promotedInput;
84 
85  SmallVector<int64_t> newMask;
86  if (!promoteV1) {
87  newMask = to_vector(shuffleOp.getMask());
88  } else {
89  // Adjust V2 indices to account for the new V1 size.
90  for (auto idx : shuffleOp.getMask()) {
91  int64_t newIdx = idx;
92  if (idx >= v1OrigNumElems) {
93  newIdx += promotedNumElems - v1OrigNumElems;
94  }
95  newMask.push_back(newIdx);
96  }
97  }
98 
99  rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
100  shuffleOp, shuffleOp.getResultVectorType(), promotedV1, promotedV2,
101  newMask);
102  return success();
103  }
104 };
105 } // namespace
106 
109  patterns.add<MixedSizeInputShuffleOpRewrite>(patterns.getContext(), benefit);
110 }
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:456
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:793
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:529
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void populateVectorShuffleLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:322