18 #include "llvm/Support/Debug.h"
22 #define GEN_PASS_DEF_NORMALIZEMEMREFSPASS
23 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
27 #define DEBUG_TYPE "normalize-memrefs"
41 struct NormalizeMemRefs
42 :
public memref::impl::NormalizeMemRefsPassBase<NormalizeMemRefs> {
43 void runOnOperation()
override;
44 void normalizeFuncOpMemRefs(func::FuncOp funcOp, ModuleOp moduleOp);
45 bool areMemRefsNormalizable(func::FuncOp funcOp);
46 void updateFunctionSignature(func::FuncOp funcOp, ModuleOp moduleOp);
47 void setCalleesAndCallersNonNormalizable(
48 func::FuncOp funcOp, ModuleOp moduleOp,
55 void NormalizeMemRefs::runOnOperation() {
56 LLVM_DEBUG(llvm::dbgs() <<
"Normalizing Memrefs...\n");
57 ModuleOp moduleOp = getOperation();
65 moduleOp.walk([&](func::FuncOp funcOp) { normalizableFuncs.insert(funcOp); });
73 moduleOp.walk([&](func::FuncOp funcOp) {
74 if (normalizableFuncs.contains(funcOp)) {
75 if (!areMemRefsNormalizable(funcOp)) {
76 LLVM_DEBUG(llvm::dbgs()
77 <<
"@" << funcOp.getName()
78 <<
" contains ops that cannot normalize MemRefs\n");
82 setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
88 LLVM_DEBUG(llvm::dbgs() <<
"Normalizing " << normalizableFuncs.size()
91 for (func::FuncOp &funcOp : normalizableFuncs)
92 normalizeFuncOpMemRefs(funcOp, moduleOp);
100 return llvm::all_of(opUsers, [](
Operation *op) {
107 void NormalizeMemRefs::setCalleesAndCallersNonNormalizable(
108 func::FuncOp funcOp, ModuleOp moduleOp,
110 if (!normalizableFuncs.contains(funcOp))
114 llvm::dbgs() <<
"@" << funcOp.getName()
115 <<
" calls or is called by non-normalizable function\n");
116 normalizableFuncs.erase(funcOp);
118 std::optional<SymbolTable::UseRange> symbolUses =
119 funcOp.getSymbolUses(moduleOp);
123 func::FuncOp parentFuncOp =
124 symbolUse.getUser()->getParentOfType<func::FuncOp>();
125 for (func::FuncOp &funcOp : normalizableFuncs) {
126 if (parentFuncOp == funcOp) {
127 setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
135 funcOp.walk([&](func::CallOp callOp) {
136 StringAttr callee = callOp.getCalleeAttr().getAttr();
137 for (func::FuncOp &funcOp : normalizableFuncs) {
139 if (callee == funcOp.getNameAttr()) {
140 setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
155 bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
157 if (funcOp.isExternal())
162 Value oldMemRef = allocOp.getResult();
163 if (!allocOp.getType().getLayout().isIdentity() &&
173 Value oldMemRef = allocaOp.getResult();
174 if (!allocaOp.getType().getLayout().isIdentity() &&
184 for (
unsigned resIndex :
185 llvm::seq<unsigned>(0, callOp.getNumResults())) {
186 Value oldMemRef = callOp.getResult(resIndex);
187 if (auto oldMemRefType =
188 dyn_cast<MemRefType>(oldMemRef.getType()))
189 if (!oldMemRefType.getLayout().isIdentity() &&
190 !isMemRefNormalizable(oldMemRef.getUsers()))
191 return WalkResult::interrupt();
198 for (
unsigned argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
200 if (
auto oldMemRefType = dyn_cast<MemRefType>(oldMemRef.
getType()))
201 if (!oldMemRefType.getLayout().isIdentity() &&
215 void NormalizeMemRefs::updateFunctionSignature(func::FuncOp funcOp,
217 FunctionType functionType = funcOp.getFunctionType();
219 FunctionType newFuncType;
220 resultTypes = llvm::to_vector<4>(functionType.getResults());
224 if (!funcOp.isExternal()) {
227 argTypes.push_back(argEn.value().getType());
231 funcOp.walk([&](func::ReturnOp returnOp) {
232 for (
const auto &operandEn :
llvm::enumerate(returnOp.getOperands())) {
233 Type opType = operandEn.value().getType();
234 MemRefType memrefType = dyn_cast<MemRefType>(opType);
237 if (!memrefType || memrefType == resultTypes[operandEn.index()])
247 if (memrefType.getLayout().isIdentity())
248 resultTypes[operandEn.index()] = memrefType;
263 llvm::SmallDenseSet<func::FuncOp, 8> funcOpsToUpdate;
266 std::optional<SymbolTable::UseRange> symbolUses =
267 funcOp.getSymbolUses(moduleOp);
275 auto callOp = dyn_cast<func::CallOp>(userOp);
279 func::CallOp::create(builder, userOp->
getLoc(), callOp.getCalleeAttr(),
281 bool replacingMemRefUsesFailed =
false;
282 bool returnTypeChanged =
false;
283 for (
unsigned resIndex : llvm::seq<unsigned>(0, userOp->
getNumResults())) {
292 cast<MemRefType>(oldResult.
getType()).getLayout().getAffineMap();
305 replacingMemRefUsesFailed =
true;
308 returnTypeChanged =
true;
310 if (replacingMemRefUsesFailed)
315 if (returnTypeChanged) {
325 func::FuncOp parentFuncOp = newCallOp->
getParentOfType<func::FuncOp>();
326 funcOpsToUpdate.insert(parentFuncOp);
331 if (!funcOp.isExternal())
332 funcOp.setType(newFuncType);
337 for (func::FuncOp parentFuncOp : funcOpsToUpdate)
338 updateFunctionSignature(parentFuncOp, moduleOp);
345 void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
354 if (
auto allocOp = dyn_cast<AllocOp>(op))
355 allocOps.push_back(allocOp);
356 else if (
auto allocaOp = dyn_cast<AllocaOp>(op))
357 allocaOps.push_back(allocaOp);
358 else if (
auto reinterpretCastOp = dyn_cast<ReinterpretCastOp>(op))
359 reinterpretCastOps.push_back(reinterpretCastOp);
361 for (AllocOp allocOp : allocOps)
363 for (AllocaOp allocaOp : allocaOps)
365 for (ReinterpretCastOp reinterpretCastOp : reinterpretCastOps)
371 FunctionType functionType = funcOp.getFunctionType();
373 funcOp.getArguments(), [](
BlockArgument arg) { return arg.getLoc(); }));
376 for (
unsigned argIndex :
377 llvm::seq<unsigned>(0, functionType.getNumInputs())) {
378 Type argType = functionType.getInput(argIndex);
379 MemRefType memrefType = dyn_cast<MemRefType>(argType);
383 inputTypes.push_back(argType);
389 if (newMemRefType == memrefType || funcOp.isExternal()) {
392 inputTypes.push_back(newMemRefType);
398 argIndex, newMemRefType, functionArgLocs[argIndex]);
400 AffineMap layoutMap = memrefType.getLayout().getAffineMap();
412 funcOp.front().eraseArgument(argIndex);
418 funcOp.front().eraseArgument(argIndex + 1);
429 !funcOp.isExternal()) {
431 Operation *newOp = createOpResultsNormalized(funcOp, op);
436 bool replacingMemRefUsesFailed = false;
437 for (unsigned resIndex : llvm::seq<unsigned>(0, op->getNumResults())) {
439 Value oldMemRef = op->getResult(resIndex);
440 Value newMemRef = newOp->getResult(resIndex);
441 MemRefType oldMemRefType = dyn_cast<MemRefType>(oldMemRef.getType());
445 MemRefType newMemRefType = cast<MemRefType>(newMemRef.getType());
446 if (oldMemRefType == newMemRefType)
449 AffineMap layoutMap = oldMemRefType.getLayout().getAffineMap();
450 if (failed(replaceAllMemRefUsesWith(oldMemRef,
460 replacingMemRefUsesFailed = true;
464 if (!replacingMemRefUsesFailed) {
467 op->replaceAllUsesWith(newOp);
479 if (funcOp.isExternal()) {
481 for (
unsigned resIndex :
482 llvm::seq<unsigned>(0, functionType.getNumResults())) {
483 Type resType = functionType.getResult(resIndex);
484 MemRefType memrefType = dyn_cast<MemRefType>(resType);
488 resultTypes.push_back(resType);
494 resultTypes.push_back(newMemRefType);
497 FunctionType newFuncType =
501 funcOp.setType(newFuncType);
503 updateFunctionSignature(funcOp, moduleOp);
511 Operation *NormalizeMemRefs::createOpResultsNormalized(func::FuncOp funcOp,
517 result.addAttributes(oldOp->
getAttrs());
521 bool resultTypeNormalized =
false;
522 for (
unsigned resIndex : llvm::seq<unsigned>(0, oldOp->
getNumResults())) {
524 MemRefType memrefType = dyn_cast<MemRefType>(resultType);
527 resultTypes.push_back(resultType);
533 if (newMemRefType == memrefType) {
536 resultTypes.push_back(memrefType);
539 resultTypes.push_back(newMemRefType);
540 resultTypeNormalized =
true;
542 result.addTypes(resultTypes);
545 if (resultTypeNormalized) {
548 Region *newRegion = result.addRegion();
551 return bb.create(result);
static MLIRContext * getContext(OpFoldResult val)
static bool isMemRefNormalizable(Value::user_range opUsers)
Check whether all the uses of oldMemRef are either dereferencing uses or the op is of type : DeallocO...
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
This class represents an argument of a Block.
This class helps build Operations.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
This class represents a specific symbol use.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_range getUsers() const
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
static WalkResult interrupt()
MemRefType normalizeMemRefType(MemRefType memrefType)
Normalizes memrefType so that the affine layout map of the memref is transformed to an identity map w...
LogicalResult normalizeMemRef(AllocLikeOp op)
Rewrites the memref defined by alloc or reinterpret_cast op to have an identity layout map and update...
LogicalResult replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef, ArrayRef< Value > extraIndices={}, AffineMap indexRemap=AffineMap(), ArrayRef< Value > extraOperands={}, ArrayRef< Value > symbolOperands={}, llvm::function_ref< bool(Operation *)> userFilterFn=nullptr, bool allowNonDereferencingOps=false, bool replaceInDeallocOp=false)
Replaces all "dereferencing" uses of oldMemRef with newMemRef while optionally remapping the old memr...
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This represents an operation in an abstracted form, suitable for use with the builder APIs.