22 #define GEN_PASS_DEF_AFFINEEXPANDINDEXOPS
23 #include "mlir/Dialect/Affine/Passes.h.inc"
42 bool knownNonNegative) {
43 if (staticBasis.empty())
47 result.reserve(staticBasis.size());
48 size_t dynamicIndex = dynamicBasis.size();
49 Value dynamicPart =
nullptr;
50 int64_t staticPart = 1;
53 arith::IntegerOverflowFlags ovflags = arith::IntegerOverflowFlags::nsw;
55 ovflags = ovflags | arith::IntegerOverflowFlags::nuw;
56 for (int64_t elem : llvm::reverse(staticBasis)) {
57 if (ShapedType::isDynamic(elem)) {
62 arith::MulIOp::create(rewriter, loc, dynamicPart,
63 dynamicBasis[dynamicIndex - 1], ovflags);
65 dynamicPart = dynamicBasis[dynamicIndex - 1];
71 if (dynamicPart && staticPart == 1) {
72 result.push_back(dynamicPart);
75 rewriter.
createOrFold<arith::ConstantIndexOp>(loc, staticPart);
78 arith::MulIOp::create(rewriter, loc, dynamicPart, stride, ovflags);
79 result.push_back(stride);
82 std::reverse(result.begin(), result.end());
88 AffineDelinearizeIndexOp op) {
90 Value linearIdx = op.getLinearIndex();
91 unsigned numResults = op.getNumResults();
93 if (numResults == staticBasis.size())
94 staticBasis = staticBasis.drop_front();
96 if (numResults == 1) {
102 results.reserve(numResults);
110 arith::FloorDivSIOp::create(rewriter, loc, linearIdx, strides.front());
111 results.push_back(initialPart);
113 auto emitModTerm = [&](
Value stride) ->
Value {
114 Value remainder = arith::RemSIOp::create(rewriter, loc, linearIdx, stride);
115 Value remainderNegative = arith::CmpIOp::create(
116 rewriter, loc, arith::CmpIPredicate::slt, remainder, zero);
120 Value corrected = arith::AddIOp::create(rewriter, loc, remainder, stride,
121 arith::IntegerOverflowFlags::nsw);
122 Value mod = arith::SelectOp::create(rewriter, loc, remainderNegative,
123 corrected, remainder);
128 for (
size_t i = 0, e = strides.size() - 1; i < e; ++i) {
129 Value thisStride = strides[i];
130 Value nextStride = strides[i + 1];
131 Value modulus = emitModTerm(thisStride);
135 Value divided = arith::DivSIOp::create(rewriter, loc, modulus, nextStride);
136 results.push_back(divided);
139 results.push_back(emitModTerm(strides.back()));
146 AffineLinearizeIndexOp op) {
148 if (op.getMultiIndex().empty()) {
155 size_t numIndexes = multiIndex.size();
157 if (numIndexes == staticBasis.size())
158 staticBasis = staticBasis.drop_front();
164 scaledValues.reserve(numIndexes);
169 for (
auto [stride, idxOp] :
170 llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) {
171 Value scaledIdx = arith::MulIOp::create(rewriter, loc, idxOp.get(), stride,
172 arith::IntegerOverflowFlags::nsw);
174 scaledValues.emplace_back(scaledIdx, numHoistableLoops);
176 scaledValues.emplace_back(
182 llvm::stable_sort(scaledValues,
183 [&](
auto l,
auto r) {
return l.second > r.second; });
185 Value result = scaledValues.front().first;
186 for (
auto [scaledValue, numHoistableLoops] : llvm::drop_begin(scaledValues)) {
187 std::ignore = numHoistableLoops;
188 result = arith::AddIOp::create(rewriter, loc, result, scaledValue,
189 arith::IntegerOverflowFlags::nsw);
196 struct LowerDelinearizeIndexOps
199 LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
205 struct LowerLinearizeIndexOps final :
OpRewritePattern<AffineLinearizeIndexOp> {
207 LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
213 class ExpandAffineIndexOpsPass
214 :
public affine::impl::AffineExpandIndexOpsBase<ExpandAffineIndexOpsPass> {
216 ExpandAffineIndexOpsPass() =
default;
218 void runOnOperation()
override {
223 return signalPassFailure();
231 patterns.insert<LowerDelinearizeIndexOps, LowerLinearizeIndexOps>(
236 return std::make_unique<ExpandAffineIndexOpsPass>();
static MLIRContext * getContext(OpFoldResult val)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Specialization of arith.constant op that returns an integer of index type.
LogicalResult lowerAffineDelinearizeIndexOp(RewriterBase &rewriter, AffineDelinearizeIndexOp op)
Lowers affine.delinearize_index into a sequence of division and remainder operations.
LogicalResult lowerAffineLinearizeIndexOp(RewriterBase &rewriter, AffineLinearizeIndexOp op)
Lowers affine.linearize_index into a sequence of multiplications and additions.
std::unique_ptr< Pass > createAffineExpandIndexOpsPass()
Creates a pass to expand affine index operations into more fundamental operations (not necessarily re...
int64_t numEnclosingInvariantLoops(OpOperand &operand)
Count the number of loops surrounding operand such that operand could be hoisted above.
void populateAffineExpandIndexOpsPatterns(RewritePatternSet &patterns)
Populate patterns that expand affine index operations into more fundamental operations (not necessari...
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...