18#include "llvm/Support/Debug.h"
22#define GEN_PASS_DEF_NORMALIZEMEMREFSPASS
23#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
27#define DEBUG_TYPE "normalize-memrefs"
41struct NormalizeMemRefs
43 void runOnOperation()
override;
44 void normalizeFuncOpMemRefs(func::FuncOp funcOp, ModuleOp moduleOp);
45 bool areMemRefsNormalizable(func::FuncOp funcOp);
46 void updateFunctionSignature(func::FuncOp funcOp, ModuleOp moduleOp);
47 void setCalleesAndCallersNonNormalizable(
48 func::FuncOp funcOp, ModuleOp moduleOp,
55void NormalizeMemRefs::runOnOperation() {
56 LLVM_DEBUG(llvm::dbgs() <<
"Normalizing Memrefs...\n");
57 ModuleOp moduleOp = getOperation();
65 moduleOp.walk([&](func::FuncOp funcOp) { normalizableFuncs.insert(funcOp); });
73 moduleOp.walk([&](func::FuncOp funcOp) {
74 if (normalizableFuncs.contains(funcOp)) {
75 if (!areMemRefsNormalizable(funcOp)) {
76 LLVM_DEBUG(llvm::dbgs()
77 <<
"@" << funcOp.getName()
78 <<
" contains ops that cannot normalize MemRefs\n");
82 setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
88 LLVM_DEBUG(llvm::dbgs() <<
"Normalizing " << normalizableFuncs.size()
91 for (func::FuncOp &funcOp : normalizableFuncs)
92 normalizeFuncOpMemRefs(funcOp, moduleOp);
100 return llvm::all_of(opUsers, [](
Operation *op) {
107void NormalizeMemRefs::setCalleesAndCallersNonNormalizable(
108 func::FuncOp funcOp, ModuleOp moduleOp,
110 if (!normalizableFuncs.contains(funcOp))
114 llvm::dbgs() <<
"@" << funcOp.getName()
115 <<
" calls or is called by non-normalizable function\n");
116 normalizableFuncs.erase(funcOp);
118 std::optional<SymbolTable::UseRange> symbolUses =
119 funcOp.getSymbolUses(moduleOp);
120 for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
123 func::FuncOp parentFuncOp =
124 symbolUse.getUser()->getParentOfType<func::FuncOp>();
125 for (func::FuncOp &funcOp : normalizableFuncs) {
126 if (parentFuncOp == funcOp) {
127 setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
135 funcOp.walk([&](func::CallOp callOp) {
136 StringAttr callee = callOp.getCalleeAttr().getAttr();
137 for (func::FuncOp &funcOp : normalizableFuncs) {
139 if (callee == funcOp.getNameAttr()) {
140 setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
155bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
157 if (funcOp.isExternal())
161 .
walk([&](AllocOp allocOp) -> WalkResult {
162 Value oldMemRef = allocOp.getResult();
163 if (!allocOp.getType().getLayout().isIdentity() &&
172 .
walk([&](AllocaOp allocaOp) -> WalkResult {
173 Value oldMemRef = allocaOp.getResult();
174 if (!allocaOp.getType().getLayout().isIdentity() &&
183 .
walk([&](func::CallOp callOp) -> WalkResult {
184 for (
unsigned resIndex :
185 llvm::seq<unsigned>(0, callOp.getNumResults())) {
186 Value oldMemRef = callOp.getResult(resIndex);
187 if (
auto oldMemRefType =
188 dyn_cast<MemRefType>(oldMemRef.
getType()))
189 if (!oldMemRefType.getLayout().isIdentity() &&
198 for (
unsigned argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
199 BlockArgument oldMemRef = funcOp.getArgument(argIndex);
200 if (
auto oldMemRefType = dyn_cast<MemRefType>(oldMemRef.
getType()))
201 if (!oldMemRefType.getLayout().isIdentity() &&
215void NormalizeMemRefs::updateFunctionSignature(func::FuncOp funcOp,
217 FunctionType functionType = funcOp.getFunctionType();
218 SmallVector<Type, 4> resultTypes;
219 FunctionType newFuncType;
220 resultTypes = llvm::to_vector<4>(functionType.getResults());
224 if (!funcOp.isExternal()) {
225 SmallVector<Type, 8> argTypes;
226 for (
const auto &argEn : llvm::enumerate(funcOp.getArguments()))
227 argTypes.push_back(argEn.value().getType());
231 funcOp.walk([&](func::ReturnOp returnOp) {
232 for (
const auto &operandEn : llvm::enumerate(returnOp.getOperands())) {
233 Type opType = operandEn.value().getType();
234 MemRefType memrefType = dyn_cast<MemRefType>(opType);
237 if (!memrefType || memrefType == resultTypes[operandEn.index()])
247 if (memrefType.getLayout().isIdentity())
248 resultTypes[operandEn.index()] = memrefType;
254 newFuncType = FunctionType::get(&
getContext(), argTypes,
263 llvm::SmallDenseSet<func::FuncOp, 8> funcOpsToUpdate;
266 std::optional<SymbolTable::UseRange> symbolUses =
267 funcOp.getSymbolUses(moduleOp);
268 for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
269 Operation *userOp = symbolUse.getUser();
270 OpBuilder builder(userOp);
275 auto callOp = dyn_cast<func::CallOp>(userOp);
278 Operation *newCallOp =
279 func::CallOp::create(builder, userOp->
getLoc(), callOp.getCalleeAttr(),
281 bool replacingMemRefUsesFailed =
false;
282 bool returnTypeChanged =
false;
283 for (
unsigned resIndex : llvm::seq<unsigned>(0, userOp->
getNumResults())) {
284 OpResult oldResult = userOp->
getResult(resIndex);
285 OpResult newResult = newCallOp->
getResult(resIndex);
291 AffineMap layoutMap =
292 cast<MemRefType>(oldResult.
getType()).getLayout().getAffineMap();
293 if (
failed(replaceAllMemRefUsesWith(oldResult, newResult,
305 replacingMemRefUsesFailed =
true;
308 returnTypeChanged =
true;
310 if (replacingMemRefUsesFailed)
315 if (returnTypeChanged) {
325 func::FuncOp parentFuncOp = newCallOp->
getParentOfType<func::FuncOp>();
326 funcOpsToUpdate.insert(parentFuncOp);
331 if (!funcOp.isExternal())
332 funcOp.setType(newFuncType);
337 for (func::FuncOp parentFuncOp : funcOpsToUpdate)
338 updateFunctionSignature(parentFuncOp, moduleOp);
345void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
350 SmallVector<AllocOp, 4> allocOps;
351 SmallVector<AllocaOp> allocaOps;
352 SmallVector<ReinterpretCastOp> reinterpretCastOps;
353 funcOp.walk([&](Operation *op) {
354 if (
auto allocOp = dyn_cast<AllocOp>(op))
355 allocOps.push_back(allocOp);
356 else if (
auto allocaOp = dyn_cast<AllocaOp>(op))
357 allocaOps.push_back(allocaOp);
358 else if (
auto reinterpretCastOp = dyn_cast<ReinterpretCastOp>(op))
359 reinterpretCastOps.push_back(reinterpretCastOp);
361 for (AllocOp allocOp : allocOps)
362 (void)normalizeMemRef(allocOp);
363 for (AllocaOp allocaOp : allocaOps)
364 (void)normalizeMemRef(allocaOp);
365 for (ReinterpretCastOp reinterpretCastOp : reinterpretCastOps)
366 (void)normalizeMemRef(reinterpretCastOp);
371 FunctionType functionType = funcOp.getFunctionType();
372 SmallVector<Location> functionArgLocs(llvm::map_range(
373 funcOp.getArguments(), [](BlockArgument arg) { return arg.getLoc(); }));
374 SmallVector<Type, 8> inputTypes;
376 for (
unsigned argIndex :
377 llvm::seq<unsigned>(0, functionType.getNumInputs())) {
378 Type argType = functionType.getInput(argIndex);
379 MemRefType memrefType = dyn_cast<MemRefType>(argType);
383 inputTypes.push_back(argType);
388 MemRefType newMemRefType = normalizeMemRefType(memrefType);
389 if (newMemRefType == memrefType || funcOp.isExternal()) {
392 inputTypes.push_back(newMemRefType);
397 BlockArgument newMemRef = funcOp.front().insertArgument(
398 argIndex, newMemRefType, functionArgLocs[argIndex]);
399 BlockArgument oldMemRef = funcOp.getArgument(argIndex + 1);
400 AffineMap layoutMap = memrefType.getLayout().getAffineMap();
402 if (
failed(replaceAllMemRefUsesWith(oldMemRef, newMemRef,
412 funcOp.front().eraseArgument(argIndex);
418 funcOp.front().eraseArgument(argIndex + 1);
426 funcOp.walk([&](Operation *op) {
427 if (op->
hasTrait<OpTrait::MemRefsNormalizable>() &&
429 !funcOp.isExternal()) {
431 Operation *newOp = createOpResultsNormalized(funcOp, op);
436 bool replacingMemRefUsesFailed =
false;
437 for (
unsigned resIndex : llvm::seq<unsigned>(0, op->
getNumResults())) {
439 Value oldMemRef = op->
getResult(resIndex);
440 Value newMemRef = newOp->
getResult(resIndex);
441 MemRefType oldMemRefType = dyn_cast<MemRefType>(oldMemRef.
getType());
445 MemRefType newMemRefType = cast<MemRefType>(newMemRef.
getType());
446 if (oldMemRefType == newMemRefType)
449 AffineMap layoutMap = oldMemRefType.getLayout().getAffineMap();
450 if (
failed(replaceAllMemRefUsesWith(oldMemRef,
460 replacingMemRefUsesFailed =
true;
464 if (!replacingMemRefUsesFailed) {
479 if (funcOp.isExternal()) {
480 SmallVector<Type, 4> resultTypes;
481 for (
unsigned resIndex :
482 llvm::seq<unsigned>(0, functionType.getNumResults())) {
483 Type resType = functionType.getResult(resIndex);
484 MemRefType memrefType = dyn_cast<MemRefType>(resType);
488 resultTypes.push_back(resType);
493 MemRefType newMemRefType = normalizeMemRefType(memrefType);
494 resultTypes.push_back(newMemRefType);
497 FunctionType newFuncType =
501 funcOp.setType(newFuncType);
503 updateFunctionSignature(funcOp, moduleOp);
511Operation *NormalizeMemRefs::createOpResultsNormalized(func::FuncOp funcOp,
519 SmallVector<Type, 4> resultTypes;
521 bool resultTypeNormalized =
false;
522 for (
unsigned resIndex : llvm::seq<unsigned>(0, oldOp->
getNumResults())) {
524 MemRefType memrefType = dyn_cast<MemRefType>(resultType);
527 resultTypes.push_back(resultType);
532 MemRefType newMemRefType = normalizeMemRefType(memrefType);
533 if (newMemRefType == memrefType) {
536 resultTypes.push_back(memrefType);
539 resultTypes.push_back(newMemRefType);
540 resultTypeNormalized =
true;
542 result.addTypes(resultTypes);
545 if (resultTypeNormalized) {
static bool isMemRefNormalizable(Value::user_range opUsers)
Check whether all the uses of oldMemRef are either dereferencing uses or the op is of type : DeallocO...
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
operand_range getOperands()
Returns an iterator on the underlying Value's.
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
Type getType() const
Return the type of this value.
iterator_range< user_iterator > user_range
user_range getUsers() const
static WalkResult advance()
static WalkResult interrupt()
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Include the generated interface declarations.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet