20LogicalResult resolveAllTrueCreateMaskOp(
IRRewriter &rewriter,
21 vector::CreateMaskOp createMaskOp,
22 VscaleRange vscaleRange) {
23 auto maskType = createMaskOp.getVectorType();
24 auto maskTypeDimScalableFlags = maskType.getScalableDims();
25 auto maskTypeDimSizes = maskType.getShape();
27 struct UnknownMaskDim {
39 for (
auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
42 if (maskTypeDimScalableFlags[i] || intSize < maskTypeDimSizes[i])
46 if (vscaleMultiplier < maskTypeDimSizes[i])
50 unknownDims.push_back(UnknownMaskDim{i, dimSize});
54 for (
auto [i, dimSize] : unknownDims) {
57 FailureOr<ConstantOrScalableBound> dimLowerBound =
59 dimSize, {}, vscaleRange.vscaleMin, vscaleRange.vscaleMax,
63 auto dimLowerBoundSize = dimLowerBound->getSize();
64 if (
failed(dimLowerBoundSize))
66 if (dimLowerBoundSize->scalable) {
69 if (dimLowerBoundSize->baseSize < maskTypeDimSizes[i])
74 if (maskTypeDimScalableFlags[i])
77 if (dimLowerBoundSize->baseSize < maskTypeDimSizes[i])
85 auto allTrue = vector::ConstantMaskOp::create(
96 std::optional<VscaleRange> vscaleRange) {
107 function.walk([&](vector::CreateMaskOp createMaskOp) {
108 worklist.push_back(createMaskOp);
112 for (
auto mask : worklist)
113 (
void)resolveAllTrueCreateMaskOp(rewriter, mask, *vscaleRange);
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep track of the mutations made to the IR.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
std::optional< int64_t > getConstantVscaleMultiplier(Value value)
If value is a constant multiple of vector.vscale (e.g. %cst * vector.vscale), return the multiplier (%cst). Otherwise, return std::nullopt.
void eliminateVectorMasks(IRRewriter &rewriter, FunctionOpInterface function, std::optional< VscaleRange > vscaleRange={})
Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds masking) fastpath and a slowpath.
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
static FailureOr< ConstantOrScalableBound > computeScalableBound(Value value, std::optional< int64_t > dim, unsigned vscaleMin, unsigned vscaleMax, presburger::BoundType boundType, bool closedUB=true, const StopConditionFn &stopCondition=nullptr)
Computes a (possibly) scalable bound for a given value.