MLIR 22.0.0git
RaiseMemrefDialect.cpp
Go to the documentation of this file.
1//===- RaiseMemrefDialect.cpp - raise memref.store and load to affine ops -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements functionality to convert memref load and store ops to
10// the corresponding affine ops, inferring the affine map as needed.
11//
12//===----------------------------------------------------------------------===//
13
20#include "mlir/IR/AffineExpr.h"
21#include "mlir/IR/Matchers.h"
22#include "mlir/IR/Operation.h"
23#include "mlir/Pass/Pass.h"
24#include "llvm/Support/Casting.h"
25#include "llvm/Support/Debug.h"
26
27namespace mlir {
28namespace affine {
29#define GEN_PASS_DEF_RAISEMEMREFDIALECT
30#include "mlir/Dialect/Affine/Passes.h.inc"
31} // namespace affine
32} // namespace mlir
33
34#define DEBUG_TYPE "raise-memref-to-affine"
35
36using namespace mlir;
37using namespace mlir::affine;
38
39namespace {
40
41/// Find the index of the given value in the `dims` list,
42/// and append it if it was not already in the list. The
43/// dims list is a list of symbols or dimensions of the
44/// affine map. Within the results of an affine map, they
45/// are identified by their index, which is why we need
46/// this function.
47static std::optional<size_t>
48findInListOrAdd(Value value, llvm::SmallVectorImpl<Value> &dims,
49 function_ref<bool(Value)> isValidElement) {
50
51 Value *loopIV = llvm::find(dims, value);
52 if (loopIV != dims.end()) {
53 // We found an IV that already has an index, return that index.
54 return {std::distance(dims.begin(), loopIV)};
55 }
56 if (isValidElement(value)) {
57 // This is a valid element for the dim/symbol list, push this as a
58 // parameter.
59 size_t idx = dims.size();
60 dims.push_back(value);
61 return idx;
62 }
63 return std::nullopt;
64}
65
66/// Convert a value to an affine expr if possible. Adds dims and symbols
67/// if needed.
68static AffineExpr toAffineExpr(Value value,
70 llvm::SmallVectorImpl<Value> &affineSymbols) {
71 using namespace matchers;
72 IntegerAttr::ValueType cst;
73 if (matchPattern(value, m_ConstantInt(&cst))) {
74 return getAffineConstantExpr(cst.getSExtValue(), value.getContext());
75 }
76
77 Operation *definingOp = value.getDefiningOp();
78 if (llvm::isa_and_nonnull<arith::AddIOp>(definingOp) ||
79 llvm::isa_and_nonnull<arith::MulIOp>(definingOp)) {
80 // TODO: replace recursion with explicit stack.
81 // For the moment this can be tolerated as we only recurse on
82 // arith.addi and arith.muli, so there cannot be any infinite
83 // recursion. The depth of these expressions should be in most
84 // cases very manageable, as affine expressions should be as
85 // simple as `a + b * c`.
86 AffineExpr lhsE =
87 toAffineExpr(definingOp->getOperand(0), affineDims, affineSymbols);
88 AffineExpr rhsE =
89 toAffineExpr(definingOp->getOperand(1), affineDims, affineSymbols);
90
91 if (lhsE && rhsE) {
92 AffineExprKind kind;
93 if (isa<arith::AddIOp>(definingOp)) {
95 } else {
97
98 if (!lhsE.isSymbolicOrConstant() && !rhsE.isSymbolicOrConstant()) {
99 // This is not an affine expression, give up.
100 return {};
101 }
102 }
103 return getAffineBinaryOpExpr(kind, lhsE, rhsE);
104 }
105 return {};
106 }
107
108 if (auto dimIx = findInListOrAdd(value, affineSymbols, [](Value v) {
109 return affine::isValidSymbol(v);
110 })) {
111 return getAffineSymbolExpr(*dimIx, value.getContext());
112 }
113
114 if (auto dimIx = findInListOrAdd(
115 value, affineDims, [](Value v) { return affine::isValidDim(v); })) {
116
117 return getAffineDimExpr(*dimIx, value.getContext());
118 }
119
120 return {};
121}
122
123static LogicalResult
124computeAffineMapAndArgs(MLIRContext *ctx, ValueRange indices, AffineMap &map,
127 SmallVector<Value> symbols;
129
130 for (Value indexExpr : indices) {
131 AffineExpr res = toAffineExpr(indexExpr, dims, symbols);
132 if (!res) {
133 return failure();
134 }
135 results.push_back(res);
136 }
137
138 map = AffineMap::get(dims.size(), symbols.size(), results, ctx);
139
140 dims.append(symbols);
141 mapArgs.swap(dims);
142 return success();
143}
144
145struct RaiseMemrefDialect
146 : public affine::impl::RaiseMemrefDialectBase<RaiseMemrefDialect> {
147
148 void runOnOperation() override {
149 auto *ctx = &getContext();
150 Operation *op = getOperation();
151 IRRewriter rewriter(ctx);
152 AffineMap map;
153 SmallVector<Value> mapArgs;
154 op->walk([&](Operation *op) {
155 rewriter.setInsertionPoint(op);
156 if (auto store = llvm::dyn_cast_or_null<memref::StoreOp>(op)) {
157
158 if (succeeded(computeAffineMapAndArgs(ctx, store.getIndices(), map,
159 mapArgs))) {
160 rewriter.replaceOpWithNewOp<AffineStoreOp>(
161 op, store.getValueToStore(), store.getMemRef(), map, mapArgs);
162 return;
163 }
164
165 LLVM_DEBUG(llvm::dbgs()
166 << "[affine] Cannot raise memref op: " << op << "\n");
167
168 } else if (auto load = llvm::dyn_cast_or_null<memref::LoadOp>(op)) {
169 if (succeeded(computeAffineMapAndArgs(ctx, load.getIndices(), map,
170 mapArgs))) {
171 rewriter.replaceOpWithNewOp<AffineLoadOp>(op, load.getMemRef(), map,
172 mapArgs);
173 return;
174 }
175 LLVM_DEBUG(llvm::dbgs()
176 << "[affine] Cannot raise memref op: " << op << "\n");
177 }
178 });
179 }
180};
181
182} // namespace
183
184std::unique_ptr<OperationPass<func::FuncOp>>
186 return std::make_unique<RaiseMemrefDialect>();
187}
return success()
b getContext())
auto load
Base type for affine expression.
Definition AffineExpr.h:68
bool isSymbolicOrConstant() const
Returns true if this expression is made out of only symbols and constants, i.e., it does not involve ...
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:350
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:797
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition Value.h:108
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
bool isValidDim(Value value)
Returns true if the given Value can be used as a dimension id in the region of the closest surroundin...
std::unique_ptr< OperationPass< func::FuncOp > > createRaiseMemrefToAffine()
Creates a pass that converts some memref operators to affine operators.
bool isValidSymbol(Value value)
Returns true if the given value can be used as a symbol in the region of the closest surrounding op t...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
AffineExprKind
Definition AffineExpr.h:40
@ Mul
RHS of mul is always a constant or a symbolic expression.
Definition AffineExpr.h:43
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)