22#define DEBUG_TYPE "lower-vector-mask"
26#define GEN_PASS_DEF_LOWERVECTORMASKPASS
27#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
53 LogicalResult matchAndRewrite(vector::CreateMaskOp op,
55 auto dstType = cast<VectorType>(op.getResult().getType());
56 int64_t rank = dstType.getRank();
59 op,
"0-D and 1-D vectors are handled separately");
61 if (dstType.getScalableDims().front())
63 op,
"Cannot unroll leading scalable dim in dstType");
65 auto loc = op.getLoc();
66 int64_t dim = dstType.getDimSize(0);
67 Value idx = op.getOperand(0);
70 Value trueVal = vector::CreateMaskOp::create(rewriter, loc, lowType,
71 op.getOperands().drop_front());
72 Value falseVal = arith::ConstantOp::create(rewriter, loc, lowType,
74 Value result = arith::ConstantOp::create(rewriter, loc, dstType,
76 for (
int64_t d = 0; d < dim; d++) {
78 arith::ConstantOp::create(rewriter, loc, rewriter.
getIndexAttr(d));
79 Value val = arith::CmpIOp::create(rewriter, loc,
80 arith::CmpIPredicate::slt, bnd, idx);
82 arith::SelectOp::create(rewriter, loc, val, trueVal, falseVal);
83 result = vector::InsertOp::create(rewriter, loc, sel,
result, d);
101class ConstantMaskOpLowering :
public OpRewritePattern<vector::ConstantMaskOp> {
105 LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
107 auto loc = op.getLoc();
108 auto dstType = op.getType();
109 auto dimSizes = op.getMaskDimSizes();
110 int64_t rank = dstType.getRank();
113 assert(dimSizes.size() == 1 &&
114 "Expected exactly one dim size for a 0-D vector");
115 bool value = dimSizes.front() == 1;
123 int64_t trueDimSize = dimSizes.front();
126 if (trueDimSize == 0 || trueDimSize == dstType.getDimSize(0)) {
136 SmallVector<bool> values(dstType.getDimSize(0),
false);
137 for (int64_t d = 0; d < trueDimSize; d++)
145 if (dstType.getScalableDims().front())
147 op,
"Cannot unroll leading scalable dim in dstType");
149 VectorType lowType = VectorType::Builder(dstType).dropDim(0);
150 Value trueVal = vector::ConstantMaskOp::create(rewriter, loc, lowType,
151 dimSizes.drop_front());
152 Value
result = arith::ConstantOp::create(rewriter, loc, dstType,
154 for (int64_t d = 0; d < trueDimSize; d++)
155 result = vector::InsertOp::create(rewriter, loc, trueVal,
result, d);
165 patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>(
185template <
class SourceOp>
190 LogicalResult matchAndRewrite(MaskOp maskOp,
192 auto maskableOp = cast_or_null<MaskableOpInterface>(maskOp.getMaskableOp());
195 SourceOp sourceOp = dyn_cast<SourceOp>(maskableOp.getOperation());
199 return matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter);
203 virtual LogicalResult
204 matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp,
209struct MaskedTransferReadOpPattern
210 :
public MaskOpRewritePattern<TransferReadOp> {
212 using MaskOpRewritePattern<TransferReadOp>::MaskOpRewritePattern;
215 matchAndRewriteMaskableOp(TransferReadOp readOp, MaskingOpInterface maskingOp,
216 PatternRewriter &rewriter)
const override {
220 if (maskingOp.hasPassthru())
222 maskingOp,
"Can't lower passthru to vector.transfer_read");
226 maskingOp.getOperation(), readOp.getVectorType(), readOp.getBase(),
227 readOp.getIndices(), readOp.getPermutationMap(), readOp.getPadding(),
228 maskingOp.getMask(), readOp.getInBounds());
234struct MaskedTransferWriteOpPattern
235 :
public MaskOpRewritePattern<TransferWriteOp> {
237 using MaskOpRewritePattern<TransferWriteOp>::MaskOpRewritePattern;
240 matchAndRewriteMaskableOp(TransferWriteOp writeOp,
241 MaskingOpInterface maskingOp,
242 PatternRewriter &rewriter)
const override {
244 writeOp.getResult() ? writeOp.getResult().getType() : Type();
248 maskingOp.getOperation(), resultType, writeOp.getVector(),
249 writeOp.getBase(), writeOp.getIndices(), writeOp.getPermutationMap(),
250 maskingOp.getMask(), writeOp.getInBounds());
256struct MaskedGatherOpPattern :
public MaskOpRewritePattern<GatherOp> {
258 using MaskOpRewritePattern<GatherOp>::MaskOpRewritePattern;
261 matchAndRewriteMaskableOp(GatherOp gatherOp, MaskingOpInterface maskingOp,
262 PatternRewriter &rewriter)
const override {
263 Value passthru = maskingOp.hasPassthru()
264 ? maskingOp.getPassthru()
265 : arith::ConstantOp::create(
266 rewriter, gatherOp.getLoc(),
271 maskingOp.getOperation(), gatherOp.getVectorType(), gatherOp.getBase(),
272 gatherOp.getOffsets(), gatherOp.getIndices(), maskingOp.getMask(),
278struct LowerVectorMaskPass
282 void runOnOperation()
override {
283 Operation *op = getOperation();
286 RewritePatternSet loweringPatterns(context);
288 MaskOp::getCanonicalizationPatterns(loweringPatterns, context);
294 void getDependentDialects(DialectRegistry ®istry)
const override {
295 registry.
insert<vector::VectorDialect>();
306 patterns.add<MaskedTransferReadOpPattern, MaskedTransferWriteOpPattern,
307 MaskedGatherOpPattern>(
patterns.getContext());
311 return std::make_unique<LowerVectorMaskPass>();
IntegerAttr getIndexAttr(int64_t value)
DenseIntElementsAttr getBoolVectorAttr(ArrayRef< bool > values)
Vector-typed DenseIntElementsAttr getters. values must not be empty.
TypedAttr getZeroAttr(Type type)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
MLIRContext * getContext()
Return the context this operation is associated with.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
This is a builder type that keeps local references to arguments.
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
LowerVectorMaskPassBase Base
std::unique_ptr< Pass > createLowerVectorMaskPass()
Creates an instance of the vector.mask lowering pass.
void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorMaskLoweringPatternsForSideEffectingOps(RewritePatternSet &patterns)
Populates instances of MaskOpRewritePattern to lower masked operations with vector....
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...