19 #define DEBUG_TYPE "vector-shuffle-lowering"
45 struct MixedSizeInputShuffleOpRewrite final
49 LogicalResult matchAndRewrite(vector::ShuffleOp shuffleOp,
51 auto v1Type = shuffleOp.getV1VectorType();
52 auto v2Type = shuffleOp.getV2VectorType();
55 if (v1Type.getRank() != 1 || v2Type.getRank() != 1)
59 int64_t v1OrigNumElems = v1Type.getNumElements();
60 int64_t v2OrigNumElems = v2Type.getNumElements();
61 if (v1OrigNumElems == v2OrigNumElems)
65 bool promoteV1 = v1OrigNumElems < v2OrigNumElems;
66 Value inputToPromote = promoteV1 ? shuffleOp.getV1() : shuffleOp.getV2();
67 VectorType promotedType = promoteV1 ? v2Type : v1Type;
68 int64_t origNumElems = promoteV1 ? v1OrigNumElems : v2OrigNumElems;
69 int64_t promotedNumElems = promoteV1 ? v2OrigNumElems : v1OrigNumElems;
74 for (int64_t i = 0; i < origNumElems; ++i)
77 Value promotedInput = rewriter.
create<vector::ShuffleOp>(
78 shuffleOp.getLoc(), promotedType, inputToPromote, inputToPromote,
82 Value promotedV1 = promoteV1 ? promotedInput : shuffleOp.getV1();
83 Value promotedV2 = promoteV1 ? shuffleOp.getV2() : promotedInput;
87 newMask = to_vector(shuffleOp.getMask());
90 for (
auto idx : shuffleOp.getMask()) {
92 if (idx >= v1OrigNumElems) {
93 newIdx += promotedNumElems - v1OrigNumElems;
95 newMask.push_back(newIdx);
100 shuffleOp, shuffleOp.getResultVectorType(), promotedV1, promotedV2,
109 patterns.add<MixedSizeInputShuffleOpRewrite>(
patterns.getContext(), benefit);
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void populateVectorShuffleLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...