22 #define DEBUG_TYPE "lower-vector-mask"
26 #define GEN_PASS_DEF_LOWERVECTORMASKPASS
27 #include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
53 LogicalResult matchAndRewrite(vector::CreateMaskOp op,
55 auto dstType = cast<VectorType>(op.getResult().getType());
56 int64_t rank = dstType.getRank();
59 op,
"0-D and 1-D vectors are handled separately");
61 if (dstType.getScalableDims().front())
63 op,
"Cannot unroll leading scalable dim in dstType");
65 auto loc = op.getLoc();
66 int64_t dim = dstType.getDimSize(0);
67 Value idx = op.getOperand(0);
70 Value trueVal = rewriter.
create<vector::CreateMaskOp>(
71 loc, lowType, op.getOperands().drop_front());
72 Value falseVal = rewriter.
create<arith::ConstantOp>(
76 for (int64_t d = 0; d < dim; d++) {
79 Value val = rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
81 Value sel = rewriter.
create<arith::SelectOp>(loc, val, trueVal, falseVal);
82 result = rewriter.
create<vector::InsertOp>(loc, sel, result, d);
100 class ConstantMaskOpLowering :
public OpRewritePattern<vector::ConstantMaskOp> {
104 LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
106 auto loc = op.getLoc();
107 auto dstType = op.getType();
108 auto dimSizes = op.getMaskDimSizes();
109 int64_t rank = dstType.getRank();
112 assert(dimSizes.size() == 1 &&
113 "Expected exactly one dim size for a 0-D vector");
114 bool value = dimSizes.front() == 1;
122 int64_t trueDimSize = dimSizes.front();
125 if (trueDimSize == 0 || trueDimSize == dstType.getDimSize(0)) {
136 for (int64_t d = 0; d < trueDimSize; d++)
144 if (dstType.getScalableDims().front())
146 op,
"Cannot unroll leading scalable dim in dstType");
149 Value trueVal = rewriter.
create<vector::ConstantMaskOp>(
150 loc, lowType, dimSizes.drop_front());
153 for (int64_t d = 0; d < trueDimSize; d++)
154 result = rewriter.
create<vector::InsertOp>(loc, trueVal, result, d);
164 patterns.
add<CreateMaskOpLowering, ConstantMaskOpLowering>(
184 template <
class SourceOp>
189 LogicalResult matchAndRewrite(MaskOp maskOp,
191 auto maskableOp = cast_or_null<MaskableOpInterface>(maskOp.getMaskableOp());
194 SourceOp sourceOp = dyn_cast<SourceOp>(maskableOp.getOperation());
198 return matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter);
202 virtual LogicalResult
203 matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp,
208 struct MaskedTransferReadOpPattern
209 :
public MaskOpRewritePattern<TransferReadOp> {
211 using MaskOpRewritePattern<TransferReadOp>::MaskOpRewritePattern;
214 matchAndRewriteMaskableOp(TransferReadOp readOp, MaskingOpInterface maskingOp,
219 if (maskingOp.hasPassthru())
221 maskingOp,
"Can't lower passthru to vector.transfer_read");
225 maskingOp.getOperation(), readOp.getVectorType(), readOp.getSource(),
226 readOp.getIndices(), readOp.getPermutationMap(), readOp.getPadding(),
227 maskingOp.getMask(), readOp.getInBounds());
233 struct MaskedTransferWriteOpPattern
234 :
public MaskOpRewritePattern<TransferWriteOp> {
236 using MaskOpRewritePattern<TransferWriteOp>::MaskOpRewritePattern;
239 matchAndRewriteMaskableOp(TransferWriteOp writeOp,
240 MaskingOpInterface maskingOp,
243 writeOp.getResult() ? writeOp.getResult().getType() :
Type();
247 maskingOp.getOperation(), resultType, writeOp.getVector(),
248 writeOp.getSource(), writeOp.getIndices(), writeOp.getPermutationMap(),
249 maskingOp.getMask(), writeOp.getInBounds());
255 struct MaskedGatherOpPattern :
public MaskOpRewritePattern<GatherOp> {
257 using MaskOpRewritePattern<GatherOp>::MaskOpRewritePattern;
260 matchAndRewriteMaskableOp(GatherOp gatherOp, MaskingOpInterface maskingOp,
262 Value passthru = maskingOp.hasPassthru()
263 ? maskingOp.getPassthru()
264 : rewriter.
create<arith::ConstantOp>(
270 maskingOp.getOperation(), gatherOp.getVectorType(), gatherOp.getBase(),
271 gatherOp.getIndices(), gatherOp.getIndexVec(), maskingOp.getMask(),
277 struct LowerVectorMaskPass
278 :
public vector::impl::LowerVectorMaskPassBase<LowerVectorMaskPass> {
281 void runOnOperation()
override {
287 MaskOp::getCanonicalizationPatterns(loweringPatterns, context);
294 registry.
insert<vector::VectorDialect>();
305 patterns.
add<MaskedTransferReadOpPattern, MaskedTransferWriteOpPattern,
306 MaskedGatherOpPattern>(patterns.
getContext());
310 return std::make_unique<LowerVectorMaskPass>();
IntegerAttr getIndexAttr(int64_t value)
DenseIntElementsAttr getBoolVectorAttr(ArrayRef< bool > values)
Vector-typed DenseIntElementsAttr getters. values must not be empty.
TypedAttr getZeroAttr(Type type)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
MLIRContext * getContext()
Return the context this operation is associated with.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
This is a builder type that keeps local references to arguments.
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
std::unique_ptr< Pass > createLowerVectorMaskPass()
Creates an instance of the vector.mask lowering pass.
void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorMaskLoweringPatternsForSideEffectingOps(RewritePatternSet &patterns)
Populates instances of MaskOpRewritePattern to lower masked operations with vector.mask.
Include the generated interface declarations.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy worklist driven manner until a fixpoint is reached.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.