MLIR  21.0.0git
RaiseMemrefDialect.cpp
Go to the documentation of this file.
1 //===- RaiseMemrefDialect.cpp - raise memref.store and load to affine ops -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements functionality to convert memref load and store ops to
10 // the corresponding affine ops, inferring the affine map as needed.
11 //
12 //===----------------------------------------------------------------------===//
13 
20 #include "mlir/IR/AffineExpr.h"
21 #include "mlir/IR/Matchers.h"
22 #include "mlir/IR/Operation.h"
23 #include "mlir/Pass/Pass.h"
24 #include "llvm/Support/Casting.h"
25 #include "llvm/Support/Debug.h"
26 
27 namespace mlir {
28 namespace affine {
29 #define GEN_PASS_DEF_RAISEMEMREFDIALECT
30 #include "mlir/Dialect/Affine/Passes.h.inc"
31 } // namespace affine
32 } // namespace mlir
33 
34 #define DEBUG_TYPE "raise-memref-to-affine"
35 
36 using namespace mlir;
37 using namespace mlir::affine;
38 
39 namespace {
40 
41 /// Find the index of the given value in the `dims` list,
42 /// and append it if it was not already in the list. The
43 /// dims list is a list of symbols or dimensions of the
44 /// affine map. Within the results of an affine map, they
45 /// are identified by their index, which is why we need
46 /// this function.
47 static std::optional<size_t>
48 findInListOrAdd(Value value, llvm::SmallVectorImpl<Value> &dims,
49  function_ref<bool(Value)> isValidElement) {
50 
51  Value *loopIV = llvm::find(dims, value);
52  if (loopIV != dims.end()) {
53  // We found an IV that already has an index, return that index.
54  return {std::distance(dims.begin(), loopIV)};
55  }
56  if (isValidElement(value)) {
57  // This is a valid element for the dim/symbol list, push this as a
58  // parameter.
59  size_t idx = dims.size();
60  dims.push_back(value);
61  return idx;
62  }
63  return std::nullopt;
64 }
65 
66 /// Convert a value to an affine expr if possible. Adds dims and symbols
67 /// if needed.
68 static AffineExpr toAffineExpr(Value value,
69  llvm::SmallVectorImpl<Value> &affineDims,
70  llvm::SmallVectorImpl<Value> &affineSymbols) {
71  using namespace matchers;
72  IntegerAttr::ValueType cst;
73  if (matchPattern(value, m_ConstantInt(&cst))) {
74  return getAffineConstantExpr(cst.getSExtValue(), value.getContext());
75  }
76 
77  Operation *definingOp = value.getDefiningOp();
78  if (llvm::isa_and_nonnull<arith::AddIOp>(definingOp) ||
79  llvm::isa_and_nonnull<arith::MulIOp>(definingOp)) {
80  // TODO: replace recursion with explicit stack.
81  // For the moment this can be tolerated as we only recurse on
82  // arith.addi and arith.muli, so there cannot be any infinite
83  // recursion. The depth of these expressions should be in most
84  // cases very manageable, as affine expressions should be as
85  // simple as `a + b * c`.
86  AffineExpr lhsE =
87  toAffineExpr(definingOp->getOperand(0), affineDims, affineSymbols);
88  AffineExpr rhsE =
89  toAffineExpr(definingOp->getOperand(1), affineDims, affineSymbols);
90 
91  if (lhsE && rhsE) {
93  if (isa<arith::AddIOp>(definingOp)) {
95  } else {
97 
98  if (!lhsE.isSymbolicOrConstant() && !rhsE.isSymbolicOrConstant()) {
99  // This is not an affine expression, give up.
100  return {};
101  }
102  }
103  return getAffineBinaryOpExpr(kind, lhsE, rhsE);
104  }
105  return {};
106  }
107 
108  if (auto dimIx = findInListOrAdd(value, affineSymbols, [](Value v) {
109  return affine::isValidSymbol(v);
110  })) {
111  return getAffineSymbolExpr(*dimIx, value.getContext());
112  }
113 
114  if (auto dimIx = findInListOrAdd(
115  value, affineDims, [](Value v) { return affine::isValidDim(v); })) {
116 
117  return getAffineDimExpr(*dimIx, value.getContext());
118  }
119 
120  return {};
121 }
122 
123 static LogicalResult
124 computeAffineMapAndArgs(MLIRContext *ctx, ValueRange indices, AffineMap &map,
125  llvm::SmallVectorImpl<Value> &mapArgs) {
126  SmallVector<AffineExpr> results;
127  SmallVector<Value> symbols;
128  SmallVector<Value> dims;
129 
130  for (Value indexExpr : indices) {
131  AffineExpr res = toAffineExpr(indexExpr, dims, symbols);
132  if (!res) {
133  return failure();
134  }
135  results.push_back(res);
136  }
137 
138  map = AffineMap::get(dims.size(), symbols.size(), results, ctx);
139 
140  dims.append(symbols);
141  mapArgs.swap(dims);
142  return success();
143 }
144 
145 struct RaiseMemrefDialect
146  : public affine::impl::RaiseMemrefDialectBase<RaiseMemrefDialect> {
147 
148  void runOnOperation() override {
149  auto *ctx = &getContext();
150  Operation *op = getOperation();
151  IRRewriter rewriter(ctx);
152  AffineMap map;
153  SmallVector<Value> mapArgs;
154  op->walk([&](Operation *op) {
155  rewriter.setInsertionPoint(op);
156  if (auto store = llvm::dyn_cast_or_null<memref::StoreOp>(op)) {
157 
158  if (succeeded(computeAffineMapAndArgs(ctx, store.getIndices(), map,
159  mapArgs))) {
160  rewriter.replaceOpWithNewOp<AffineStoreOp>(
161  op, store.getValueToStore(), store.getMemRef(), map, mapArgs);
162  return;
163  }
164 
165  LLVM_DEBUG(llvm::dbgs()
166  << "[affine] Cannot raise memref op: " << op << "\n");
167 
168  } else if (auto load = llvm::dyn_cast_or_null<memref::LoadOp>(op)) {
169  if (succeeded(computeAffineMapAndArgs(ctx, load.getIndices(), map,
170  mapArgs))) {
171  rewriter.replaceOpWithNewOp<AffineLoadOp>(op, load.getMemRef(), map,
172  mapArgs);
173  return;
174  }
175  LLVM_DEBUG(llvm::dbgs()
176  << "[affine] Cannot raise memref op: " << op << "\n");
177  }
178  });
179  }
180 };
181 
182 } // namespace
183 
184 std::unique_ptr<OperationPass<func::FuncOp>>
186  return std::make_unique<RaiseMemrefDialect>();
187 }
static MLIRContext * getContext(OpFoldResult val)
union mlir::linalg::@1194::ArityGroupAndKind::Kind kind
Base type for affine expression.
Definition: AffineExpr.h:68
bool isSymbolicOrConstant() const
Returns true if this expression is made out of only symbols and constants, i.e., it does not involve dimensional identifiers.
Definition: AffineExpr.cpp:188
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:730
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:395
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition: Value.h:108
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
bool isValidDim(Value value)
Returns true if the given Value can be used as a dimension id in the region of the closest surroundin...
Definition: AffineOps.cpp:288
std::unique_ptr< OperationPass< func::FuncOp > > createRaiseMemrefToAffine()
Creates a pass that converts some memref operators to affine operators.
bool isValidSymbol(Value value)
Returns true if the given value can be used as a symbol in the region of the closest surrounding op t...
Definition: AffineOps.cpp:409
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bind_value.
Definition: Matchers.h:527
AffineExprKind
Definition: AffineExpr.h:40
@ Mul
RHS of mul is always a constant or a symbolic expression.
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
Definition: AffineExpr.cpp:70
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:645
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:621
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
Definition: AffineExpr.cpp:631