19 #include "llvm/ADT/SmallSet.h"
20 #include "llvm/Support/Debug.h"
24 #define GEN_PASS_DEF_NORMALIZEMEMREFSPASS
25 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
29 #define DEBUG_TYPE "normalize-memrefs"
// Module pass that normalizes non-identity memref layout maps across every
// function in the module: each member below is one phase of that work.
// NOTE(review): this is a doxygen-extraction fragment — original source lines
// are missing between the numbered statements; bytes are preserved as-is.
43 struct NormalizeMemRefs
44 :
public memref::impl::NormalizeMemRefsPassBase<NormalizeMemRefs> {
// Pass entry point (runs on the enclosing ModuleOp).
45 void runOnOperation()
override;
// Rewrites the memrefs inside one function to identity-layout form.
46 void normalizeFuncOpMemRefs(func::FuncOp funcOp, ModuleOp moduleOp);
// Predicate: can every memref in funcOp be normalized safely?
47 bool areMemRefsNormalizable(func::FuncOp funcOp);
// Propagates a normalized function's new type to its signature/call sites.
48 void updateFunctionSignature(func::FuncOp funcOp, ModuleOp moduleOp);
// Marks funcOp and (transitively) its callers/callees as non-normalizable.
// Parameter list is cut off here in the extraction — presumably it also
// takes the normalizable-function set; TODO confirm against full source.
49 void setCalleesAndCallersNonNormalizable(
50 func::FuncOp funcOp, ModuleOp moduleOp,
// Pass driver: seed the candidate set with every function, prune the ones
// whose memrefs cannot be normalized (propagating non-normalizability to
// callers/callees), then normalize the survivors.
// NOTE(review): extraction fragment — lines are missing between statements.
57 void NormalizeMemRefs::runOnOperation() {
58 LLVM_DEBUG(llvm::dbgs() <<
"Normalizing Memrefs...\n");
59 ModuleOp moduleOp = getOperation();
// Optimistically assume every function is normalizable.
67 moduleOp.walk([&](func::FuncOp funcOp) { normalizableFuncs.insert(funcOp); });
// Second walk: remove functions that contain ops which cannot have their
// memrefs normalized, and taint their call graph neighbors too.
75 moduleOp.walk([&](func::FuncOp funcOp) {
76 if (normalizableFuncs.contains(funcOp)) {
77 if (!areMemRefsNormalizable(funcOp)) {
78 LLVM_DEBUG(llvm::dbgs()
79 <<
"@" << funcOp.getName()
80 <<
" contains ops that cannot normalize MemRefs\n");
// A non-normalizable function forces its callers/callees out as well.
84 setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
90 LLVM_DEBUG(llvm::dbgs() <<
"Normalizing " << normalizableFuncs.size()
// Finally rewrite each remaining candidate.
93 for (func::FuncOp &funcOp : normalizableFuncs)
94 normalizeFuncOpMemRefs(funcOp, moduleOp);
// Fragment of the static helper isMemRefNormalizable: all users of the
// memref must satisfy the (elided) lambda predicate for it to qualify.
102 return llvm::all_of(opUsers, [](
Operation *op) {
// Removes funcOp from the normalizable set and recursively taints every
// function that calls it (via symbol uses) and every function it calls
// (via walking its CallOps).
// NOTE(review): extraction fragment — lines are missing between statements.
109 void NormalizeMemRefs::setCalleesAndCallersNonNormalizable(
110 func::FuncOp funcOp, ModuleOp moduleOp,
// Already removed — nothing to propagate (avoids infinite recursion).
112 if (!normalizableFuncs.contains(funcOp))
116 llvm::dbgs() <<
"@" << funcOp.getName()
117 <<
" calls or is called by non-normalizable function\n");
118 normalizableFuncs.erase(funcOp);
// Callers: find every symbol use of funcOp inside the module.
120 std::optional<SymbolTable::UseRange> symbolUses =
121 funcOp.getSymbolUses(moduleOp);
125 func::FuncOp parentFuncOp =
126 symbolUse.getUser()->getParentOfType<func::FuncOp>();
// Note: the loop variable shadows the outer funcOp parameter here.
127 for (func::FuncOp &funcOp : normalizableFuncs) {
128 if (parentFuncOp == funcOp) {
129 setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
// Callees: every function this one calls must also be tainted.
137 funcOp.walk([&](func::CallOp callOp) {
138 StringAttr callee = callOp.getCalleeAttr().getAttr();
139 for (func::FuncOp &funcOp : normalizableFuncs) {
141 if (callee == funcOp.getNameAttr()) {
142 setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
// Returns whether every memref produced/consumed in funcOp can be rewritten
// to identity layout: checks alloc/alloca results, call results, and block
// arguments whose type has a non-identity layout, requiring all their users
// to pass isMemRefNormalizable.
// NOTE(review): extraction fragment — lines are missing between statements.
157 bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
// External functions have no body to inspect (elided early return).
159 if (funcOp.isExternal())
// Alloc results with a non-identity layout must have normalizable users.
164 Value oldMemRef = allocOp.getResult();
165 if (!allocOp.getType().getLayout().isIdentity() &&
// Same check for alloca results.
175 Value oldMemRef = allocaOp.getResult();
176 if (!allocaOp.getType().getLayout().isIdentity() &&
// Call results: any memref-typed result with a non-identity layout whose
// users are not normalizable aborts the walk.
186 for (
unsigned resIndex :
187 llvm::seq<unsigned>(0, callOp.getNumResults())) {
188 Value oldMemRef = callOp.getResult(resIndex);
189 if (auto oldMemRefType =
190 dyn_cast<MemRefType>(oldMemRef.getType()))
191 if (!oldMemRefType.getLayout().isIdentity() &&
192 !isMemRefNormalizable(oldMemRef.getUsers()))
193 return WalkResult::interrupt();
// Function arguments get the same non-identity-layout/users check.
200 for (
unsigned argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
202 if (
auto oldMemRefType = dyn_cast<MemRefType>(oldMemRef.
getType()))
203 if (!oldMemRefType.getLayout().isIdentity() &&
// After funcOp's body has been normalized, recompute its FunctionType from
// the (now-normalized) argument and return types, rewrite every call site
// with a new func::CallOp carrying the updated result types, and recurse
// into the callers whose bodies changed.
// NOTE(review): extraction fragment — lines are missing between statements.
217 void NormalizeMemRefs::updateFunctionSignature(func::FuncOp funcOp,
219 FunctionType functionType = funcOp.getFunctionType();
221 FunctionType newFuncType;
222 resultTypes = llvm::to_vector<4>(functionType.getResults());
// Only functions with a body can have their argument types read back
// from the entry block arguments.
226 if (!funcOp.isExternal()) {
229 argTypes.push_back(argEn.value().getType());
// Derive result types from what the return ops actually yield now.
233 funcOp.walk([&](func::ReturnOp returnOp) {
234 for (
const auto &operandEn :
llvm::enumerate(returnOp.getOperands())) {
235 Type opType = operandEn.value().getType();
236 MemRefType memrefType = dyn_cast<MemRefType>(opType);
// Skip non-memref operands or ones whose type did not change.
239 if (!memrefType || memrefType == resultTypes[operandEn.index()])
// Only adopt the operand's type once it is in identity-layout form.
249 if (memrefType.getLayout().isIdentity())
250 resultTypes[operandEn.index()] = memrefType;
// Callers whose bodies we rewrite below, to be updated recursively.
265 llvm::SmallDenseSet<func::FuncOp, 8> funcOpsToUpdate;
268 std::optional<SymbolTable::UseRange> symbolUses =
269 funcOp.getSymbolUses(moduleOp);
277 auto callOp = dyn_cast<func::CallOp>(userOp);
// Build a replacement call with the same callee but new result types.
281 builder.
create<func::CallOp>(userOp->
getLoc(), callOp.getCalleeAttr(),
283 bool replacingMemRefUsesFailed =
false;
284 bool returnTypeChanged =
false;
// Re-map each old call result to the new one, via the old layout map.
285 for (
unsigned resIndex : llvm::seq<unsigned>(0, userOp->
getNumResults())) {
294 cast<MemRefType>(oldResult.
getType()).getLayout().getAffineMap();
308 replacingMemRefUsesFailed =
true;
311 returnTypeChanged =
true;
// Bail out of this call site if any use could not be replaced.
313 if (replacingMemRefUsesFailed)
318 if (returnTypeChanged) {
// The caller's return types may now be stale — queue it for update.
328 func::FuncOp parentFuncOp = newCallOp->
getParentOfType<func::FuncOp>();
329 funcOpsToUpdate.insert(parentFuncOp);
// External functions keep their declared type; others get the new one.
334 if (!funcOp.isExternal())
335 funcOp.setType(newFuncType);
// Recurse into every caller whose body was rewritten above.
340 for (func::FuncOp parentFuncOp : funcOpsToUpdate)
341 updateFunctionSignature(parentFuncOp, moduleOp);
// Normalizes all memrefs inside funcOp: alloc/alloca/reinterpret_cast
// results, entry-block argument types, results of other "normalizable" ops,
// and (for external functions) the declared result types — then propagates
// the signature change via updateFunctionSignature.
// NOTE(review): extraction fragment — lines are missing between statements.
348 void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
// Collect the three op kinds that define memrefs with layout maps.
357 if (
auto allocOp = dyn_cast<AllocOp>(op))
358 allocOps.push_back(allocOp);
359 else if (
auto allocaOp = dyn_cast<AllocaOp>(op))
360 allocaOps.push_back(allocaOp);
361 else if (
auto reinterpretCastOp = dyn_cast<ReinterpretCastOp>(op))
362 reinterpretCastOps.push_back(reinterpretCastOp);
// Normalize each collected op (normalizeMemRef calls elided here).
364 for (AllocOp allocOp : allocOps)
366 for (AllocaOp allocaOp : allocaOps)
368 for (ReinterpretCastOp reinterpretCastOp : reinterpretCastOps)
// Rewrite the function's argument types in place.
374 FunctionType functionType = funcOp.getFunctionType();
// Preserve each argument's source location for the replacement args.
376 funcOp.getArguments(), [](
BlockArgument arg) { return arg.getLoc(); }));
379 for (
unsigned argIndex :
380 llvm::seq<unsigned>(0, functionType.getNumInputs())) {
381 Type argType = functionType.getInput(argIndex);
382 MemRefType memrefType = dyn_cast<MemRefType>(argType);
// Non-memref arguments pass through unchanged.
386 inputTypes.push_back(argType);
// Already-identity (or external-function) arguments keep their type.
392 if (newMemRefType == memrefType || funcOp.isExternal()) {
395 inputTypes.push_back(newMemRefType);
// Insert the new normalized argument, then replace uses of the old one
// through its layout map and erase it.
401 argIndex, newMemRefType, functionArgLocs[argIndex]);
403 AffineMap layoutMap = memrefType.getLayout().getAffineMap();
416 funcOp.front().eraseArgument(argIndex);
// +1 because the new argument was inserted before the old one.
422 funcOp.front().eraseArgument(argIndex + 1);
// For other ops with memref results (non-external functions only):
// create a clone with normalized result types and migrate the uses.
433 !funcOp.isExternal()) {
435 Operation *newOp = createOpResultsNormalized(funcOp, op);
440 bool replacingMemRefUsesFailed = false;
441 for (unsigned resIndex : llvm::seq<unsigned>(0, op->getNumResults())) {
443 Value oldMemRef = op->getResult(resIndex);
444 Value newMemRef = newOp->getResult(resIndex);
445 MemRefType oldMemRefType = dyn_cast<MemRefType>(oldMemRef.getType());
449 MemRefType newMemRefType = cast<MemRefType>(newMemRef.getType());
// Nothing to do when normalization did not change the type.
450 if (oldMemRefType == newMemRefType)
453 AffineMap layoutMap = oldMemRefType.getLayout().getAffineMap();
454 if (failed(replaceAllMemRefUsesWith(oldMemRef,
465 replacingMemRefUsesFailed = true;
// Only commit the replacement op if every use was migrated.
469 if (!replacingMemRefUsesFailed) {
472 op->replaceAllUsesWith(newOp);
// External functions: normalize the declared result types directly,
// since there is no body to rewrite.
484 if (funcOp.isExternal()) {
486 for (
unsigned resIndex :
487 llvm::seq<unsigned>(0, functionType.getNumResults())) {
488 Type resType = functionType.getResult(resIndex);
489 MemRefType memrefType = dyn_cast<MemRefType>(resType);
493 resultTypes.push_back(resType);
499 resultTypes.push_back(newMemRefType);
502 FunctionType newFuncType =
506 funcOp.setType(newFuncType);
// Propagate the signature change to call sites and callers.
508 updateFunctionSignature(funcOp, moduleOp);
// Builds a clone of oldOp whose memref result types have been normalized
// (via an OperationState carrying the old attributes plus the new types),
// and creates the new op only when at least one result type actually
// changed.
// NOTE(review): extraction fragment — lines are missing between statements.
516 Operation *NormalizeMemRefs::createOpResultsNormalized(func::FuncOp funcOp,
// Carry over all attributes from the original op.
522 result.addAttributes(oldOp->
getAttrs());
526 bool resultTypeNormalized =
false;
527 for (
unsigned resIndex : llvm::seq<unsigned>(0, oldOp->
getNumResults())) {
529 MemRefType memrefType = dyn_cast<MemRefType>(resultType);
// Non-memref results keep their type as-is.
532 resultTypes.push_back(resultType);
// Normalization was a no-op for this result.
538 if (newMemRefType == memrefType) {
541 resultTypes.push_back(memrefType);
544 resultTypes.push_back(newMemRefType);
545 resultTypeNormalized =
true;
547 result.addTypes(resultTypes);
// Only materialize a new op when some result type changed; the region is
// added so the old op's body can be taken over (elided here).
550 if (resultTypeNormalized) {
553 Region *newRegion = result.addRegion();
556 return bb.create(result);
static MLIRContext * getContext(OpFoldResult val)
static bool isMemRefNormalizable(Value::user_range opUsers)
Check whether all the uses of oldMemRef are either dereferencing uses or the op is of type: DeallocOp, CallOp or ReturnOp.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
This class represents an argument of a Block.
This class helps build Operations.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
This class represents a specific symbol use.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
user_range getUsers() const
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
static WalkResult interrupt()
MemRefType normalizeMemRefType(MemRefType memrefType)
Normalizes memrefType so that the affine layout map of the memref is transformed to an identity map with a new shape being computed for the normalized memref type.
LogicalResult normalizeMemRef(AllocLikeOp op)
Rewrites the memref defined by alloc or reinterpret_cast op to have an identity layout map and updates all its indexing uses.
LogicalResult replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef, ArrayRef< Value > extraIndices={}, AffineMap indexRemap=AffineMap(), ArrayRef< Value > extraOperands={}, ArrayRef< Value > symbolOperands={}, Operation *domOpFilter=nullptr, Operation *postDomOpFilter=nullptr, bool allowNonDereferencingOps=false, bool replaceInDeallocOp=false)
Replaces all "dereferencing" uses of oldMemRef with newMemRef while optionally remapping the old memref's indices using the supplied affine map, indexRemap.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
This represents an operation in an abstracted form, suitable for use with the builder APIs.