24 #define GEN_PASS_DEF_AFFINEEXPANDINDEXOPS
25 #include "mlir/Dialect/Affine/Passes.h.inc"
42 if (staticBasis.empty())
46 result.reserve(staticBasis.size());
47 size_t dynamicIndex = dynamicBasis.size();
48 Value dynamicPart =
nullptr;
49 int64_t staticPart = 1;
50 for (int64_t elem : llvm::reverse(staticBasis)) {
51 if (ShapedType::isDynamic(elem)) {
53 dynamicPart = rewriter.
create<arith::MulIOp>(
54 loc, dynamicPart, dynamicBasis[dynamicIndex - 1]);
56 dynamicPart = dynamicBasis[dynamicIndex - 1];
62 if (dynamicPart && staticPart == 1) {
63 result.push_back(dynamicPart);
66 rewriter.
createOrFold<arith::ConstantIndexOp>(loc, staticPart);
68 stride = rewriter.
create<arith::MulIOp>(loc, dynamicPart, stride);
69 result.push_back(stride);
72 std::reverse(result.begin(), result.end());
79 struct LowerDelinearizeIndexOps
82 LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
85 Value linearIdx = op.getLinearIndex();
86 unsigned numResults = op.getNumResults();
88 if (numResults == staticBasis.size())
89 staticBasis = staticBasis.drop_front();
91 if (numResults == 1) {
97 results.reserve(numResults);
104 rewriter.
create<arith::FloorDivSIOp>(loc, linearIdx, strides.front());
105 results.push_back(initialPart);
107 auto emitModTerm = [&](
Value stride) ->
Value {
108 Value remainder = rewriter.
create<arith::RemSIOp>(loc, linearIdx, stride);
109 Value remainderNegative = rewriter.
create<arith::CmpIOp>(
110 loc, arith::CmpIPredicate::slt, remainder, zero);
111 Value corrected = rewriter.
create<arith::AddIOp>(loc, remainder, stride);
112 Value mod = rewriter.
create<arith::SelectOp>(loc, remainderNegative,
113 corrected, remainder);
118 for (
size_t i = 0, e = strides.size() - 1; i < e; ++i) {
119 Value thisStride = strides[i];
120 Value nextStride = strides[i + 1];
121 Value modulus = emitModTerm(thisStride);
125 Value divided = rewriter.
create<arith::DivSIOp>(loc, modulus, nextStride);
126 results.push_back(divided);
129 results.push_back(emitModTerm(strides.back()));
140 struct LowerLinearizeIndexOps final :
OpRewritePattern<AffineLinearizeIndexOp> {
142 LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
145 if (op.getMultiIndex().empty()) {
152 size_t numIndexes = multiIndex.size();
154 if (numIndexes == staticBasis.size())
155 staticBasis = staticBasis.drop_front();
160 scaledValues.reserve(numIndexes);
165 for (
auto [stride, idxOp] :
166 llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) {
168 rewriter.
create<arith::MulIOp>(loc, idxOp.get(), stride);
170 scaledValues.emplace_back(scaledIdx, numHoistableLoops);
172 scaledValues.emplace_back(
178 llvm::stable_sort(scaledValues,
179 [&](
auto l,
auto r) {
return l.second > r.second; });
181 Value result = scaledValues.front().first;
182 for (
auto [scaledValue, numHoistableLoops] :
183 llvm::drop_begin(scaledValues)) {
184 std::ignore = numHoistableLoops;
185 result = rewriter.
create<arith::AddIOp>(loc, result, scaledValue);
192 class ExpandAffineIndexOpsPass
193 :
public affine::impl::AffineExpandIndexOpsBase<ExpandAffineIndexOpsPass> {
195 ExpandAffineIndexOpsPass() =
default;
197 void runOnOperation()
override {
202 return signalPassFailure();
210 patterns.insert<LowerDelinearizeIndexOps, LowerLinearizeIndexOps>(
215 return std::make_unique<ExpandAffineIndexOpsPass>();
static MLIRContext * getContext(OpFoldResult val)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
std::unique_ptr< Pass > createAffineExpandIndexOpsPass()
Creates a pass to expand affine index operations into more fundamental operations (not necessarily re...
int64_t numEnclosingInvariantLoops(OpOperand &operand)
Count the number of loops surrounding operand such that operand could be hoisted above.
void populateAffineExpandIndexOpsPatterns(RewritePatternSet &patterns)
Populate patterns that expand affine index operations into more fundamental operations (not necessari...
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...