23#define GEN_PASS_DEF_AFFINEEXPANDINDEXOPS
24#include "mlir/Dialect/Affine/Transforms/Passes.h.inc"
43 bool knownNonNegative) {
44 if (staticBasis.empty())
48 result.reserve(staticBasis.size());
49 size_t dynamicIndex = dynamicBasis.size();
50 Value dynamicPart =
nullptr;
54 arith::IntegerOverflowFlags ovflags = arith::IntegerOverflowFlags::nsw;
56 ovflags = ovflags | arith::IntegerOverflowFlags::nuw;
57 for (
int64_t elem : llvm::reverse(staticBasis)) {
58 if (ShapedType::isDynamic(elem)) {
63 arith::MulIOp::create(rewriter, loc, dynamicPart,
64 dynamicBasis[dynamicIndex - 1], ovflags);
66 dynamicPart = dynamicBasis[dynamicIndex - 1];
72 if (dynamicPart && staticPart == 1) {
73 result.push_back(dynamicPart);
79 arith::MulIOp::create(rewriter, loc, dynamicPart, stride, ovflags);
91 if (value.
getType() == targetType)
93 return vector::BroadcastOp::create(rewriter, loc, targetType, value);
98 AffineDelinearizeIndexOp op) {
100 Value linearIdx = op.getLinearIndex();
101 unsigned numResults = op.getNumResults();
103 if (numResults == staticBasis.size())
104 staticBasis = staticBasis.drop_front();
106 if (numResults == 1) {
112 results.reserve(numResults);
120 for (
Value &stride : strides)
124 arith::ConstantOp::create(rewriter, loc, rewriter.
getZeroAttr(indexType));
127 arith::FloorDivSIOp::create(rewriter, loc, linearIdx, strides.front());
128 results.push_back(initialPart);
130 auto emitModTerm = [&](
Value stride) ->
Value {
131 Value remainder = arith::RemSIOp::create(rewriter, loc, linearIdx, stride);
132 Value remainderNegative = arith::CmpIOp::create(
133 rewriter, loc, arith::CmpIPredicate::slt, remainder, zero);
137 Value corrected = arith::AddIOp::create(rewriter, loc, remainder, stride,
138 arith::IntegerOverflowFlags::nsw);
139 Value mod = arith::SelectOp::create(rewriter, loc, remainderNegative,
140 corrected, remainder);
145 for (
size_t i = 0, e = strides.size() - 1; i < e; ++i) {
146 Value thisStride = strides[i];
147 Value nextStride = strides[i + 1];
148 Value modulus = emitModTerm(thisStride);
152 Value divided = arith::DivSIOp::create(rewriter, loc, modulus, nextStride);
153 results.push_back(divided);
156 results.push_back(emitModTerm(strides.back()));
163 AffineLinearizeIndexOp op) {
165 if (op.getMultiIndex().empty()) {
167 op, rewriter.
getZeroAttr(op.getLinearIndex().getType()));
173 Type indexType = op.getLinearIndex().getType();
174 size_t numIndexes = multiIndex.size();
176 if (numIndexes == staticBasis.size())
177 staticBasis = staticBasis.drop_front();
184 for (
Value &stride : strides)
188 scaledValues.reserve(numIndexes);
193 for (
auto [stride, idxOp] :
194 llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) {
195 Value scaledIdx = arith::MulIOp::create(rewriter, loc, idxOp.get(), stride,
196 arith::IntegerOverflowFlags::nsw);
198 scaledValues.emplace_back(scaledIdx, numHoistableLoops);
200 scaledValues.emplace_back(
206 llvm::stable_sort(scaledValues,
207 [&](
auto l,
auto r) {
return l.second > r.second; });
210 for (
auto [scaledValue, numHoistableLoops] : llvm::drop_begin(scaledValues)) {
211 std::ignore = numHoistableLoops;
212 result = arith::AddIOp::create(rewriter, loc,
result, scaledValue,
213 arith::IntegerOverflowFlags::nsw);
220struct LowerDelinearizeIndexOps
223 LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
229struct LowerLinearizeIndexOps final :
OpRewritePattern<AffineLinearizeIndexOp> {
231 LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
232 PatternRewriter &rewriter)
const override {
233 return affine::lowerAffineLinearizeIndexOp(rewriter, op);
237class ExpandAffineIndexOpsPass
238 :
public affine::impl::AffineExpandIndexOpsBase<ExpandAffineIndexOpsPass> {
240 ExpandAffineIndexOpsPass() =
default;
242 void runOnOperation()
override {
244 RewritePatternSet patterns(context);
247 return signalPassFailure();
255 patterns.
insert<LowerDelinearizeIndexOps, LowerLinearizeIndexOps>(
260 return std::make_unique<ExpandAffineIndexOpsPass>();
static Value broadcastToMatchType(RewriterBase &rewriter, Location loc, Value value, Type targetType)
Broadcast a scalar value to match the given type.
TypedAttr getZeroAttr(Type type)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
MLIRContext * getContext() const
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Specialization of arith.constant op that returns an integer of index type.
LogicalResult lowerAffineDelinearizeIndexOp(RewriterBase &rewriter, AffineDelinearizeIndexOp op)
Lowers affine.delinearize_index into a sequence of division and remainder operations.
LogicalResult lowerAffineLinearizeIndexOp(RewriterBase &rewriter, AffineLinearizeIndexOp op)
Lowers affine.linearize_index into a sequence of multiplications and additions.
std::unique_ptr< Pass > createAffineExpandIndexOpsPass()
Creates a pass to expand affine index operations into more fundamental operations (not necessarily re...
int64_t numEnclosingInvariantLoops(OpOperand &operand)
Count the number of loops surrounding `operand` such that the operand could be hoisted above them; counting stops at the first loop over which it cannot be hoisted...
void populateAffineExpandIndexOpsPatterns(RewritePatternSet &patterns)
Populate patterns that expand affine index operations into more fundamental operations (not necessari...
Include the generated interface declarations.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...