24 #define GEN_PASS_DEF_AFFINEEXPANDINDEXOPS
25 #include "mlir/Dialect/Affine/Passes.h.inc"
44 bool knownNonNegative) {
45 if (staticBasis.empty())
49 result.reserve(staticBasis.size());
50 size_t dynamicIndex = dynamicBasis.size();
51 Value dynamicPart =
nullptr;
52 int64_t staticPart = 1;
55 arith::IntegerOverflowFlags ovflags = arith::IntegerOverflowFlags::nsw;
57 ovflags = ovflags | arith::IntegerOverflowFlags::nuw;
58 for (int64_t elem : llvm::reverse(staticBasis)) {
59 if (ShapedType::isDynamic(elem)) {
63 dynamicPart = rewriter.
create<arith::MulIOp>(
64 loc, dynamicPart, dynamicBasis[dynamicIndex - 1], ovflags);
66 dynamicPart = dynamicBasis[dynamicIndex - 1];
72 if (dynamicPart && staticPart == 1) {
73 result.push_back(dynamicPart);
76 rewriter.
createOrFold<arith::ConstantIndexOp>(loc, staticPart);
79 rewriter.
create<arith::MulIOp>(loc, dynamicPart, stride, ovflags);
80 result.push_back(stride);
83 std::reverse(result.begin(), result.end());
90 struct LowerDelinearizeIndexOps
93 LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
96 Value linearIdx = op.getLinearIndex();
97 unsigned numResults = op.getNumResults();
99 if (numResults == staticBasis.size())
100 staticBasis = staticBasis.drop_front();
102 if (numResults == 1) {
108 results.reserve(numResults);
116 rewriter.
create<arith::FloorDivSIOp>(loc, linearIdx, strides.front());
117 results.push_back(initialPart);
119 auto emitModTerm = [&](
Value stride) ->
Value {
120 Value remainder = rewriter.
create<arith::RemSIOp>(loc, linearIdx, stride);
121 Value remainderNegative = rewriter.
create<arith::CmpIOp>(
122 loc, arith::CmpIPredicate::slt, remainder, zero);
127 loc, remainder, stride, arith::IntegerOverflowFlags::nsw);
128 Value mod = rewriter.
create<arith::SelectOp>(loc, remainderNegative,
129 corrected, remainder);
134 for (
size_t i = 0, e = strides.size() - 1; i < e; ++i) {
135 Value thisStride = strides[i];
136 Value nextStride = strides[i + 1];
137 Value modulus = emitModTerm(thisStride);
141 Value divided = rewriter.
create<arith::DivSIOp>(loc, modulus, nextStride);
142 results.push_back(divided);
145 results.push_back(emitModTerm(strides.back()));
156 struct LowerLinearizeIndexOps final :
OpRewritePattern<AffineLinearizeIndexOp> {
158 LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
161 if (op.getMultiIndex().empty()) {
168 size_t numIndexes = multiIndex.size();
170 if (numIndexes == staticBasis.size())
171 staticBasis = staticBasis.drop_front();
177 scaledValues.reserve(numIndexes);
182 for (
auto [stride, idxOp] :
183 llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) {
185 loc, idxOp.get(), stride, arith::IntegerOverflowFlags::nsw);
187 scaledValues.emplace_back(scaledIdx, numHoistableLoops);
189 scaledValues.emplace_back(
195 llvm::stable_sort(scaledValues,
196 [&](
auto l,
auto r) {
return l.second > r.second; });
198 Value result = scaledValues.front().first;
199 for (
auto [scaledValue, numHoistableLoops] :
200 llvm::drop_begin(scaledValues)) {
201 std::ignore = numHoistableLoops;
202 result = rewriter.
create<arith::AddIOp>(loc, result, scaledValue,
203 arith::IntegerOverflowFlags::nsw);
210 class ExpandAffineIndexOpsPass
211 :
public affine::impl::AffineExpandIndexOpsBase<ExpandAffineIndexOpsPass> {
213 ExpandAffineIndexOpsPass() =
default;
215 void runOnOperation()
override {
220 return signalPassFailure();
228 patterns.insert<LowerDelinearizeIndexOps, LowerLinearizeIndexOps>(
233 return std::make_unique<ExpandAffineIndexOpsPass>();
static MLIRContext * getContext(OpFoldResult val)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
std::unique_ptr< Pass > createAffineExpandIndexOpsPass()
Creates a pass to expand affine index operations into more fundamental operations (not necessarily re...
int64_t numEnclosingInvariantLoops(OpOperand &operand)
Count the number of loops surrounding operand such that operand could be hoisted above.
void populateAffineExpandIndexOpsPatterns(RewritePatternSet &patterns)
Populate patterns that expand affine index operations into more fundamental operations (not necessari...
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...