22 #define DEBUG_TYPE "lower-vector-mask"
26 #define GEN_PASS_DEF_LOWERVECTORMASKPASS
27 #include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
55 auto dstType = op.getResult().getType().cast<VectorType>();
56 int64_t rank = dstType.getRank();
59 op,
"0-D and 1-D vectors are handled separately");
61 auto loc = op.getLoc();
62 auto eltType = dstType.getElementType();
63 int64_t dim = dstType.getDimSize(0);
64 Value idx = op.getOperand(0);
67 VectorType::get(dstType.getShape().drop_front(), eltType);
68 Value trueVal = rewriter.
create<vector::CreateMaskOp>(
69 loc, lowType, op.getOperands().drop_front());
70 Value falseVal = rewriter.
create<arith::ConstantOp>(
74 for (int64_t d = 0; d < dim; d++) {
77 Value val = rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
79 Value sel = rewriter.
create<arith::SelectOp>(loc, val, trueVal, falseVal);
82 rewriter.
create<vector::InsertOp>(loc, dstType, sel, result, pos);
100 class ConstantMaskOpLowering :
public OpRewritePattern<vector::ConstantMaskOp> {
106 auto loc = op.getLoc();
107 auto dstType = op.getType();
108 auto eltType = dstType.getElementType();
109 auto dimSizes = op.getMaskDimSizes();
110 int64_t rank = dstType.getRank();
113 assert(dimSizes.size() == 1 &&
114 "Expected exactly one dim size for a 0-D vector");
115 bool value = dimSizes[0].cast<IntegerAttr>().getInt() == 1;
125 if (dstType.cast<VectorType>().isScalable()) {
131 int64_t trueDim =
std::min(dstType.getDimSize(0),
132 dimSizes[0].cast<IntegerAttr>().getInt());
138 for (int64_t d = 0; d < trueDim; d++)
146 VectorType::get(dstType.getShape().drop_front(), eltType);
148 for (int64_t r = 1; r < rank; r++)
149 newDimSizes.push_back(dimSizes[r].cast<IntegerAttr>().getInt());
150 Value trueVal = rewriter.
create<vector::ConstantMaskOp>(
154 for (int64_t d = 0; d < trueDim; d++) {
157 rewriter.
create<vector::InsertOp>(loc, dstType, trueVal, result, pos);
167 patterns.
add<CreateMaskOpLowering, ConstantMaskOpLowering>(
187 template <
class SourceOp>
194 MaskableOpInterface maskableOp = maskOp.getMaskableOp();
195 SourceOp sourceOp = dyn_cast<SourceOp>(maskableOp.getOperation());
199 return matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter);
204 matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp,
209 struct MaskedTransferReadOpPattern
210 :
public MaskOpRewritePattern<TransferReadOp> {
212 using MaskOpRewritePattern<TransferReadOp>::MaskOpRewritePattern;
215 matchAndRewriteMaskableOp(TransferReadOp readOp, MaskingOpInterface maskingOp,
220 if (maskingOp.hasPassthru())
222 maskingOp,
"Can't lower passthru to vector.transfer_read");
226 maskingOp.getOperation(), readOp.getVectorType(), readOp.getSource(),
227 readOp.getIndices(), readOp.getPermutationMap(), readOp.getPadding(),
228 maskingOp.getMask(), readOp.getInBounds().value_or(ArrayAttr()));
234 struct MaskedTransferWriteOpPattern
235 :
public MaskOpRewritePattern<TransferWriteOp> {
237 using MaskOpRewritePattern<TransferWriteOp>::MaskOpRewritePattern;
240 matchAndRewriteMaskableOp(TransferWriteOp writeOp,
241 MaskingOpInterface maskingOp,
244 writeOp.getResult() ? writeOp.getResult().getType() :
Type();
248 maskingOp.getOperation(), resultType, writeOp.getVector(),
249 writeOp.getSource(), writeOp.getIndices(), writeOp.getPermutationMap(),
250 maskingOp.getMask(), writeOp.getInBounds().value_or(ArrayAttr()));
256 struct MaskedGatherOpPattern :
public MaskOpRewritePattern<GatherOp> {
258 using MaskOpRewritePattern<GatherOp>::MaskOpRewritePattern;
261 matchAndRewriteMaskableOp(GatherOp gatherOp, MaskingOpInterface maskingOp,
263 Value passthru = maskingOp.hasPassthru()
264 ? maskingOp.getPassthru()
265 : rewriter.
create<arith::ConstantOp>(
271 maskingOp.getOperation(), gatherOp.getVectorType(), gatherOp.getBase(),
272 gatherOp.getIndices(), gatherOp.getIndexVec(), maskingOp.getMask(),
278 struct LowerVectorMaskPass
279 :
public vector::impl::LowerVectorMaskPassBase<LowerVectorMaskPass> {
282 void runOnOperation()
override {
294 registry.
insert<vector::VectorDialect>();
305 patterns.
add<MaskedTransferReadOpPattern, MaskedTransferWriteOpPattern,
306 MaskedGatherOpPattern>(patterns.
getContext());
310 return std::make_unique<LowerVectorMaskPass>();
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
IntegerAttr getIndexAttr(int64_t value)
DenseIntElementsAttr getBoolVectorAttr(ArrayRef< bool > values)
Vector-typed DenseIntElementsAttr getters. values must not be empty.
Attribute getZeroAttr(Type type)
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
MLIRContext * getContext()
Return the context this operation is associated with.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
std::unique_ptr< Pass > createLowerVectorMaskPass()
Creates an instance of the vector.mask lowering pass.
void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorMaskLoweringPatternsForSideEffectingOps(RewritePatternSet &patterns)
Populates instances of MaskOpRewritePattern to lower masked operations with vector....
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig())
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...