22 LogicalResult resolveAllTrueCreateMaskOp(
IRRewriter &rewriter,
23 vector::CreateMaskOp createMaskOp,
25 auto maskType = createMaskOp.getVectorType();
26 auto maskTypeDimScalableFlags = maskType.getScalableDims();
27 auto maskTypeDimSizes = maskType.getShape();
29 struct UnknownMaskDim {
44 if (maskTypeDimScalableFlags[i] || intSize < maskTypeDimSizes[i])
48 if (vscaleMultiplier < maskTypeDimSizes[i])
52 unknownDims.push_back(UnknownMaskDim{i, dimSize});
56 for (
auto [i, dimSize] : unknownDims) {
59 FailureOr<ConstantOrScalableBound> dimLowerBound =
63 if (failed(dimLowerBound))
65 auto dimLowerBoundSize = dimLowerBound->getSize();
66 if (failed(dimLowerBoundSize))
68 if (dimLowerBoundSize->scalable) {
71 if (dimLowerBoundSize->baseSize < maskTypeDimSizes[i])
76 if (maskTypeDimScalableFlags[i])
79 if (dimLowerBoundSize->baseSize < maskTypeDimSizes[i])
87 auto allTrue = rewriter.
create<vector::ConstantMaskOp>(
98 std::optional<VscaleRange> vscaleRange) {
109 function.walk([&](vector::CreateMaskOp createMaskOp) {
110 worklist.push_back(createMaskOp);
114 for (
auto mask : worklist)
115 (void)resolveAllTrueCreateMaskOp(rewriter, mask, *vscaleRange);
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep track of the mutations made to the IR.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
std::optional< int64_t > getConstantVscaleMultiplier(Value value)
If value is a constant multiple of vector.vscale (e.g. %cst * vector.vscale), return the multiplier (%cst); otherwise, return std::nullopt.
void eliminateVectorMasks(IRRewriter &rewriter, FunctionOpInterface function, std::optional< VscaleRange > vscaleRange={})
Attempts to eliminate redundant vector masks by replacing them with all-true constants at the top of the function (which results in the masks folding away).
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
static FailureOr< ConstantOrScalableBound > computeScalableBound(Value value, std::optional< int64_t > dim, unsigned vscaleMin, unsigned vscaleMax, presburger::BoundType boundType, bool closedUB=true, StopConditionFn stopCondition=nullptr)
Computes a (possibly) scalable bound for a given value.