/// Given a basis (in static and dynamic components), return the sequence of
/// suffix products of the basis, including the product of the entire basis,
/// which must "wrap around" the register used to store the indexes.
static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
                                         ValueRange dynamicBasis,
                                         ArrayRef<int64_t> staticBasis,
                                         bool knownNonNegative) {
  if (staticBasis.empty())
    return {};

  SmallVector<Value> result;
  result.reserve(staticBasis.size());
  size_t dynamicIndex = dynamicBasis.size();
  Value dynamicPart = nullptr;
  int64_t staticPart = 1;
  // The products of the basis elements can't overflow by definition of
  // affine.*_index, so `nsw` always applies; when the basis is known
  // non-negative, `nuw` applies as well.
  arith::IntegerOverflowFlags ovflags = arith::IntegerOverflowFlags::nsw;
  if (knownNonNegative)
    ovflags = ovflags | arith::IntegerOverflowFlags::nuw;
  for (int64_t elem : llvm::reverse(staticBasis)) {
    if (ShapedType::isDynamic(elem)) {
      // Fold dynamic basis elements into a running product.
      if (dynamicPart)
        dynamicPart =
            arith::MulIOp::create(rewriter, loc, dynamicPart,
                                  dynamicBasis[dynamicIndex - 1], ovflags);
      else
        dynamicPart = dynamicBasis[dynamicIndex - 1];
      --dynamicIndex;
    } else {
      staticPart *= elem;
    }
    if (dynamicPart && staticPart == 1) {
      result.push_back(dynamicPart);
    } else {
      Value stride =
          rewriter.createOrFold<arith::ConstantIndexOp>(loc, staticPart);
      if (dynamicPart)
        stride =
            arith::MulIOp::create(rewriter, loc, dynamicPart, stride, ovflags);
      result.push_back(stride);
    }
  }
  std::reverse(result.begin(), result.end());
  return result;
}
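
// Illustrative example (added note, not from the original source): for a
// fully static basis (4, 5, 3), computeStrides() returns the suffix products
// (60, 15, 3); that is, result[i] is the product of basis[i..end], so
// result[0] spans the entire index space described by the basis.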
LogicalResult
affine::lowerAffineDelinearizeIndexOp(RewriterBase &rewriter,
                                      AffineDelinearizeIndexOp op) {
  Location loc = op.getLoc();
  Value linearIdx = op.getLinearIndex();
  unsigned numResults = op.getNumResults();
  ArrayRef<int64_t> staticBasis = op.getStaticBasis();
  if (numResults == staticBasis.size())
    staticBasis = staticBasis.drop_front();

  if (numResults == 1) {
    rewriter.replaceOp(op, linearIdx);
    return success();
  }

  SmallVector<Value> results;
  results.reserve(numResults);
  SmallVector<Value> strides =
      computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
                     /*knownNonNegative=*/true);

  Value zero = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
  Value initialPart =
      arith::FloorDivSIOp::create(rewriter, loc, linearIdx, strides.front());
  results.push_back(initialPart);
  auto emitModTerm = [&](Value stride) -> Value {
    Value remainder = arith::RemSIOp::create(rewriter, loc, linearIdx, stride);
    Value remainderNegative = arith::CmpIOp::create(
        rewriter, loc, arith::CmpIPredicate::slt, remainder, zero);
    // If the correction is relevant, this term is <= stride, which is known
    // to be positive in `index`. Otherwise, while 2 * stride might overflow,
    // this branch won't be taken, so the risk of poison is fine.
    Value corrected = arith::AddIOp::create(rewriter, loc, remainder, stride,
                                            arith::IntegerOverflowFlags::nsw);
    Value mod = arith::SelectOp::create(rewriter, loc, remainderNegative,
                                        corrected, remainder);
    return mod;
  };
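
  // Numeric check of the correction above (added note): for linearIdx = -7
  // and stride = 3, arith.remsi yields -1; since that is negative, the
  // select picks -1 + 3 = 2, the positive (floor) modulus that
  // delinearization requires.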
  // Generate all the intermediate parts.
  for (size_t i = 0, e = strides.size() - 1; i < e; ++i) {
    Value thisStride = strides[i];
    Value nextStride = strides[i + 1];
    Value modulus = emitModTerm(thisStride);
    // Both inputs are positive here, so floor division and signed division
    // agree.
    Value divided = arith::DivSIOp::create(rewriter, loc, modulus, nextStride);
    results.push_back(divided);
  }
  results.push_back(emitModTerm(strides.back()));

  rewriter.replaceOp(op, results);
  return success();
}
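
// Illustrative end-to-end example (added note, not from the original source):
//   %i:3 = affine.delinearize_index %lin into (4, 5, 3) : index, index, index
// yields strides (15, 3) and lowers to, roughly:
//   %i0 = %lin floordiv 15
//   %i1 = (%lin mod 15) divsi 3   // mod computed via emitModTerm
//   %i2 = %lin mod 3              // also via emitModTerm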
LogicalResult affine::lowerAffineLinearizeIndexOp(RewriterBase &rewriter,
                                                  AffineLinearizeIndexOp op) {
  // An empty multi-index linearizes to 0. This should be folded away before
  // reaching this lowering; it is handled here for safety.
  if (op.getMultiIndex().empty()) {
    rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
    return success();
  }

  Location loc = op.getLoc();
  ValueRange multiIndex = op.getMultiIndex();
  size_t numIndexes = multiIndex.size();
  ArrayRef<int64_t> staticBasis = op.getStaticBasis();
  if (numIndexes == staticBasis.size())
    staticBasis = staticBasis.drop_front();

  SmallVector<Value> strides =
      computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
                     /*knownNonNegative=*/op.getDisjoint());
  SmallVector<std::pair<Value, int64_t>> scaledValues;
  scaledValues.reserve(numIndexes);
  // Note: `strides` has no entry for the final index (its stride is 1), so
  // drop_end aligns the two ranges. The mutable accessor yields OpOperand&s,
  // which carry the use location needed to query hoistability.
  for (auto [stride, idxOp] :
       llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) {
    Value scaledIdx = arith::MulIOp::create(rewriter, loc, idxOp.get(), stride,
                                            arith::IntegerOverflowFlags::nsw);
    int64_t numHoistableLoops = numEnclosingInvariantLoops(idxOp);
    scaledValues.emplace_back(scaledIdx, numHoistableLoops);
  }
  scaledValues.emplace_back(
      multiIndex.back(),
      numEnclosingInvariantLoops(op.getMultiIndexMutable()[numIndexes - 1]));
  // Sort by how many enclosing loops each term could be hoisted out of; the
  // stable sort keeps ties in their original (stride-ordered) relative order.
  llvm::stable_sort(scaledValues,
                    [&](auto l, auto r) { return l.second > r.second; });
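
  // Rationale (added note): summing the most-hoistable terms first keeps the
  // partial sums loop-invariant for as long as possible, so later
  // loop-invariant code motion can lift them out of hot loops.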
  Value result = scaledValues.front().first;
  for (auto [scaledValue, numHoistableLoops] :
       llvm::drop_begin(scaledValues)) {
    std::ignore = numHoistableLoops;
    result = arith::AddIOp::create(rewriter, loc, result, scaledValue,
                                   arith::IntegerOverflowFlags::nsw);
  }
  rewriter.replaceOp(op, result);
  return success();
}
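
// Illustrative end-to-end example (added note, not from the original source):
//   %l = affine.linearize_index [%i, %j, %k] by (4, 5, 3) : index
// uses strides (15, 3) and materializes %l = %i * 15 + %j * 3 + %k, with the
// additions reordered so the most loop-invariant terms are summed first.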
namespace {
struct LowerDelinearizeIndexOps
    : public OpRewritePattern<AffineDelinearizeIndexOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
                                PatternRewriter &rewriter) const override {
    return affine::lowerAffineDelinearizeIndexOp(rewriter, op);
  }
};
struct LowerLinearizeIndexOps final
    : OpRewritePattern<AffineLinearizeIndexOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
                                PatternRewriter &rewriter) const override {
    return affine::lowerAffineLinearizeIndexOp(rewriter, op);
  }
};
class ExpandAffineIndexOpsPass
    : public affine::impl::AffineExpandIndexOpsBase<ExpandAffineIndexOpsPass> {
public:
  ExpandAffineIndexOpsPass() = default;

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);
    populateAffineExpandIndexOpsPatterns(patterns);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace
void mlir::affine::populateAffineExpandIndexOpsPatterns(
    RewritePatternSet &patterns) {
  patterns.insert<LowerDelinearizeIndexOps, LowerLinearizeIndexOps>(
      patterns.getContext());
}
std::unique_ptr<Pass> mlir::affine::createAffineExpandIndexOpsPass() {
  return std::make_unique<ExpandAffineIndexOpsPass>();
}
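
// Usage note (added, assuming the standard pass registration): this pass is
// exposed to `mlir-opt` as `--affine-expand-index-ops`, or can be scheduled
// programmatically with `pm.addPass(affine::createAffineExpandIndexOpsPass());`.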