// Emulates vector.maskedload with a 1-D mask using one guarded scalar
// memref.load per mask lane.
// NOTE(review): this is a fragment of a generated source listing; original
// lines between the ones shown are missing and are intentionally not
// reconstructed here.
49 struct VectorMaskedLoadOpConverter final
53 LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedLoadOp,
55 VectorType maskVType = maskedLoadOp.getMaskVectorType();
56 if (maskVType.getShape().size() != 1)
58 maskedLoadOp,
// The diagnostic names the op this pattern matches (vector.maskedload); it
// previously said "vector.maskedstore", copied from the store converter.
"expected vector.maskedload with 1-D mask");
60 Location loc = maskedLoadOp.getLoc();
61 int64_t maskLength = maskVType.getShape()[0];
64 Value mask = maskedLoadOp.getMask();
65 Value base = maskedLoadOp.getBase();
// iValue starts as the pass-thru vector; lanes whose mask bit is unset yield it.
66 Value iValue = maskedLoadOp.getPassThru();
67 auto indices = llvm::to_vector_of<Value>(maskedLoadOp.getIndices());
// One scf.if per mask lane i: when the lane's mask bit is set, load
// base[indices] and insert it into iValue at position i; otherwise yield
// iValue unchanged.
70 for (int64_t i = 0; i < maskLength; ++i) {
71 auto maskBit = rewriter.
create<vector::ExtractOp>(loc, mask, i);
73 auto ifOp = rewriter.
create<scf::IfOp>(
77 builder.create<memref::LoadOp>(loc, base, indices);
79 builder.create<vector::InsertOp>(loc, loadedValue, iValue, i);
80 builder.create<scf::YieldOp>(loc, combinedValue.getResult());
83 builder.
create<scf::YieldOp>(loc, iValue);
// Advance the innermost index by `one` so the next lane reads the next element.
87 indices.back() = rewriter.
create<arith::AddIOp>(loc, indices.back(), one);
// Emulates vector.maskedstore with a 1-D mask: for each mask lane, an scf.if
// is built on that lane's mask bit, the lane of the stored value is extracted,
// and it is written with memref.store at the current indices.
// NOTE(review): fragment of a generated source listing; lines between those
// shown (e.g. where `one` is defined) are missing.
116 struct VectorMaskedStoreOpConverter final
120 LogicalResult matchAndRewrite(vector::MaskedStoreOp maskedStoreOp,
122 VectorType maskVType = maskedStoreOp.getMaskVectorType();
// Only 1-D masks are handled; anything else is reported as a match failure.
123 if (maskVType.getShape().size() != 1)
125 maskedStoreOp,
"expected vector.maskedstore with 1-D mask");
127 Location loc = maskedStoreOp.getLoc();
128 int64_t maskLength = maskVType.getShape()[0];
131 Value mask = maskedStoreOp.getMask();
132 Value base = maskedStoreOp.getBase();
133 Value value = maskedStoreOp.getValueToStore();
134 auto indices = llvm::to_vector_of<Value>(maskedStoreOp.getIndices());
137 for (int64_t i = 0; i < maskLength; ++i) {
138 auto maskBit = rewriter.
create<vector::ExtractOp>(loc, mask, i);
// NOTE(review): the trailing `false` presumably requests no else-region, and
// the extract/store below presumably go into the then-block; the
// insertion-point change is on a line not shown here. Confirm against the
// full source.
140 auto ifOp = rewriter.
create<scf::IfOp>(loc, maskBit,
false);
142 auto extractedValue = rewriter.
create<vector::ExtractOp>(loc, value, i);
143 rewriter.
create<memref::StoreOp>(loc, extractedValue, base, indices);
// Advance the innermost index by `one` for the next lane.
146 indices.back() = rewriter.
create<arith::AddIOp>(loc, indices.back(), one);
// The original masked store is removed once all per-lane stores are emitted.
149 rewriter.
eraseOp(maskedStoreOp);
// Registers both masked load/store emulation patterns; the constructor
// arguments continue on lines not shown in this fragment.
159 patterns.
add<VectorMaskedLoadOpConverter, VectorMaskedStoreOpConverter>(
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void populateVectorMaskedLoadStoreEmulationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...