19 #include "llvm/ADT/SmallSet.h"
20 #include "llvm/Support/Debug.h"
24 #define GEN_PASS_DEF_NORMALIZEMEMREFS
25 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
29 #define DEBUG_TYPE "normalize-memrefs"
// NOTE(review): this listing is an extraction fragment — the leading integer on
// each line is the original source line number fused into the text, and many
// interior lines are missing. Code below is reproduced verbatim.
//
// Module pass that normalizes memrefs: it derives from the generated pass base
// and (per the visible declarations) rewrites memrefs interprocedurally,
// updating function signatures along the way.
42 struct NormalizeMemRefs
43 :
public memref::impl::NormalizeMemRefsBase<NormalizeMemRefs> {
// Pass entry point; walks the enclosing ModuleOp (see runOnOperation below).
44 void runOnOperation()
override;
// Normalizes alloc/alloca results and memref-typed block arguments of funcOp.
45 void normalizeFuncOpMemRefs(func::FuncOp funcOp, ModuleOp moduleOp);
// Returns true when every memref produced/consumed in funcOp is safe to
// normalize (checked via isMemRefNormalizable over each memref's users).
46 bool areMemRefsNormalizable(func::FuncOp funcOp);
// Rewrites funcOp's FunctionType (and, per the body below, its call sites)
// after memref types have been normalized.
47 void updateFunctionSignature(func::FuncOp funcOp, ModuleOp moduleOp);
// Transitively removes funcOp and every caller/callee of it from the
// normalizable set, so partially-normalizable call graphs stay consistent.
48 void setCalleesAndCallersNonNormalizable(
49 func::FuncOp funcOp, ModuleOp moduleOp,
// Factory for the pass (body fragment; the enclosing function header with the
// createNormalizeMemRefsPass name is among the missing lines).
56 std::unique_ptr<OperationPass<ModuleOp>>
58 return std::make_unique<NormalizeMemRefs>();
// NOTE(review): fragment — interior lines (e.g. original 64-70, 85-93) are
// missing from this listing; code reproduced verbatim.
//
// Pass driver: seeds `normalizableFuncs` with every function in the module,
// prunes the set down to functions whose memrefs can actually be normalized
// (propagating non-normalizability to callers/callees), then normalizes the
// survivors.
61 void NormalizeMemRefs::runOnOperation() {
62 LLVM_DEBUG(llvm::dbgs() <<
"Normalizing Memrefs...\n");
63 ModuleOp moduleOp = getOperation();
// Optimistically mark every function as normalizable.
71 moduleOp.walk([&](func::FuncOp funcOp) { normalizableFuncs.insert(funcOp); });
// Second walk: drop functions that contain ops whose memref uses cannot be
// rewritten, and transitively drop their callers/callees.
79 moduleOp.walk([&](func::FuncOp funcOp) {
80 if (normalizableFuncs.contains(funcOp)) {
81 if (!areMemRefsNormalizable(funcOp)) {
82 LLVM_DEBUG(llvm::dbgs()
83 <<
"@" << funcOp.getName()
84 <<
" contains ops that cannot normalize MemRefs\n");
// Non-normalizability is contagious across the call graph.
88 setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
94 LLVM_DEBUG(llvm::dbgs() <<
"Normalizing " << normalizableFuncs.size()
// Normalize every function that survived the filtering walk.
97 for (func::FuncOp &funcOp : normalizableFuncs)
98 normalizeFuncOpMemRefs(funcOp, moduleOp);
// NOTE(review): fragment of the static helper isMemRefNormalizable (header and
// lambda body are among the missing lines). The visible code requires the
// predicate to hold for ALL users of the memref.
106 return llvm::all_of(opUsers, [](
Operation *op) {
// NOTE(review): fragment — interior lines missing; code reproduced verbatim.
//
// Removes funcOp from the normalizable set and recurses into (a) every
// function that contains a symbol use of funcOp (its callers, found via
// getSymbolUses) and (b) every function funcOp calls (found by walking its
// CallOps), so that no function is normalized while a caller or callee
// keeps the old memref layout.
113 void NormalizeMemRefs::setCalleesAndCallersNonNormalizable(
114 func::FuncOp funcOp, ModuleOp moduleOp,
// Already removed — nothing to propagate.
116 if (!normalizableFuncs.contains(funcOp))
120 llvm::dbgs() <<
"@" << funcOp.getName()
121 <<
" calls or is called by non-normalizable function\n");
122 normalizableFuncs.erase(funcOp);
// Callers: any op that references funcOp's symbol inside the module.
124 std::optional<SymbolTable::UseRange> symbolUses =
125 funcOp.getSymbolUses(moduleOp);
129 func::FuncOp parentFuncOp =
130 symbolUse.getUser()->getParentOfType<func::FuncOp>();
// Note: the inner `funcOp` shadows the parameter within this loop.
131 for (func::FuncOp &funcOp : normalizableFuncs) {
132 if (parentFuncOp == funcOp) {
133 setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
// Callees: functions named by CallOps inside funcOp's body.
141 funcOp.walk([&](func::CallOp callOp) {
142 StringAttr callee = callOp.getCalleeAttr().getAttr();
143 for (func::FuncOp &funcOp : normalizableFuncs) {
145 if (callee == funcOp.getNameAttr()) {
146 setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
// NOTE(review): fragment — interior lines missing; code reproduced verbatim.
//
// Checks every memref value created in funcOp (alloc results, alloca results,
// call results, and block arguments): a memref with a non-identity layout is
// only acceptable if isMemRefNormalizable holds for all of its users. The
// surrounding walks (headers missing here) interrupt on the first violation.
161 bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
// External functions have no body to inspect (handling line missing).
163 if (funcOp.isExternal())
// memref.alloc results.
168 Value oldMemRef = allocOp.getResult();
169 if (!allocOp.getType().getLayout().isIdentity() &&
// memref.alloca results.
179 Value oldMemRef = allocaOp.getResult();
180 if (!allocaOp.getType().getLayout().isIdentity() &&
// Memref-typed results of call ops.
190 for (
unsigned resIndex :
191 llvm::seq<unsigned>(0, callOp.getNumResults())) {
192 Value oldMemRef = callOp.getResult(resIndex);
193 if (auto oldMemRefType =
194 dyn_cast<MemRefType>(oldMemRef.getType()))
195 if (!oldMemRefType.getLayout().isIdentity() &&
196 !isMemRefNormalizable(oldMemRef.getUsers()))
197 return WalkResult::interrupt();
// Memref-typed function arguments.
204 for (
unsigned argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
206 if (
auto oldMemRefType = dyn_cast<MemRefType>(oldMemRef.
getType()))
207 if (!oldMemRefType.getLayout().isIdentity() &&
// NOTE(review): fragment — interior lines missing; code reproduced verbatim.
//
// Rewrites funcOp's FunctionType after its body has been normalized: result
// types are recomputed from the operands of ReturnOps, argument types are
// taken from the (already rewritten) entry-block arguments, and every call
// site of funcOp in the module is recreated with the new result types. Any
// caller whose results changed is queued for its own signature update.
221 void NormalizeMemRefs::updateFunctionSignature(func::FuncOp funcOp,
223 FunctionType functionType = funcOp.getFunctionType();
225 FunctionType newFuncType;
// Start from the current result types; ReturnOp operands refine them below.
226 resultTypes = llvm::to_vector<4>(functionType.getResults());
// External functions have no entry block to read argument types from.
230 if (!funcOp.isExternal()) {
233 argTypes.push_back(argEn.value().getType());
// Pick up normalized memref types from the values actually returned.
237 funcOp.walk([&](func::ReturnOp returnOp) {
238 for (
const auto &operandEn :
llvm::enumerate(returnOp.getOperands())) {
239 Type opType = operandEn.value().getType();
240 MemRefType memrefType = dyn_cast<MemRefType>(opType);
// Non-memref results, or results whose type is already current, are skipped.
243 if (!memrefType || memrefType == resultTypes[operandEn.index()])
// Only identity-layout (i.e. normalized) memref types are propagated.
253 if (memrefType.getLayout().isIdentity())
254 resultTypes[operandEn.index()] = memrefType;
// Callers whose call results changed type and therefore need their own
// signature update pass.
269 llvm::SmallDenseSet<func::FuncOp, 8> funcOpsToUpdate;
272 std::optional<SymbolTable::UseRange> symbolUses =
273 funcOp.getSymbolUses(moduleOp);
281 auto callOp = dyn_cast<func::CallOp>(userOp);
// Recreate the call with the same callee but the new result types.
285 builder.
create<func::CallOp>(userOp->
getLoc(), callOp.getCalleeAttr(),
287 bool replacingMemRefUsesFailed =
false;
288 bool returnTypeChanged =
false;
// Replace each old call result with the corresponding new one, remapping
// memref uses through the old layout map (remap code among missing lines).
289 for (
unsigned resIndex : llvm::seq<unsigned>(0, userOp->
getNumResults())) {
298 cast<MemRefType>(oldResult.
getType()).getLayout().getAffineMap();
312 replacingMemRefUsesFailed =
true;
315 returnTypeChanged =
true;
// If any use could not be replaced, bail out for this call site.
317 if (replacingMemRefUsesFailed)
322 if (returnTypeChanged) {
// The caller's own signature may now be stale — queue it.
332 func::FuncOp parentFuncOp = newCallOp->
getParentOfType<func::FuncOp>();
333 funcOpsToUpdate.insert(parentFuncOp);
// Install the rebuilt type (newFuncType construction among missing lines).
338 if (!funcOp.isExternal())
339 funcOp.setType(newFuncType);
// Propagate signature updates to affected callers.
344 for (func::FuncOp parentFuncOp : funcOpsToUpdate)
345 updateFunctionSignature(parentFuncOp, moduleOp);
// NOTE(review): fragment — interior lines missing; code reproduced verbatim.
//
// Normalizes all memrefs inside funcOp: alloc/alloca ops are rewritten to
// identity-layout memrefs, memref-typed function arguments are replaced with
// normalized equivalents (remapping uses through the old layout map), ops
// whose results are memrefs get recreated with normalized result types, and
// finally the function signature is updated.
351 void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
// Collect first, then mutate — walking while erasing would invalidate the walk.
357 funcOp.walk([&](memref::AllocOp op) { allocOps.push_back(op); });
358 for (memref::AllocOp allocOp : allocOps)
362 funcOp.walk([&](memref::AllocaOp op) { allocaOps.push_back(op); });
363 for (memref::AllocaOp allocaOp : allocaOps)
// Rewrite memref-typed arguments of the function itself.
369 FunctionType functionType = funcOp.getFunctionType();
// Preserve per-argument locations for the replacement arguments.
371 funcOp.getArguments(), [](
BlockArgument arg) { return arg.getLoc(); }));
374 for (
unsigned argIndex :
375 llvm::seq<unsigned>(0, functionType.getNumInputs())) {
376 Type argType = functionType.getInput(argIndex);
377 MemRefType memrefType = dyn_cast<MemRefType>(argType);
// Non-memref arguments pass through unchanged.
381 inputTypes.push_back(argType);
// Already normalized (or external, so no body to rewrite) — keep the type.
387 if (newMemRefType == memrefType || funcOp.isExternal()) {
390 inputTypes.push_back(newMemRefType);
// Insert the new argument next to the old one, then remap and erase.
396 argIndex, newMemRefType, functionArgLocs[argIndex]);
398 AffineMap layoutMap = memrefType.getLayout().getAffineMap();
// On replacement failure the NEW argument is removed (index argIndex)...
411 funcOp.front().eraseArgument(argIndex);
// ...on success the OLD argument (shifted to argIndex + 1) is removed.
417 funcOp.front().eraseArgument(argIndex + 1);
// Recreate ops producing memref results with normalized result types
// (the guarding condition's first half is among the missing lines).
428 !funcOp.isExternal()) {
430 Operation *newOp = createOpResultsNormalized(funcOp, op);
435 bool replacingMemRefUsesFailed = false;
436 for (unsigned resIndex : llvm::seq<unsigned>(0, op->getNumResults())) {
438 Value oldMemRef = op->getResult(resIndex);
439 Value newMemRef = newOp->getResult(resIndex);
440 MemRefType oldMemRefType = dyn_cast<MemRefType>(oldMemRef.getType());
444 MemRefType newMemRefType = cast<MemRefType>(newMemRef.getType());
// No layout change for this result — nothing to remap.
445 if (oldMemRefType == newMemRefType)
// Remap dereferencing uses through the old (non-identity) layout map.
448 AffineMap layoutMap = oldMemRefType.getLayout().getAffineMap();
449 if (failed(replaceAllMemRefUsesWith(oldMemRef,
460 replacingMemRefUsesFailed = true;
// Only commit the new op if every result's uses were replaced.
464 if (!replacingMemRefUsesFailed) {
467 op->replaceAllUsesWith(newOp);
// External functions: rewrite result types directly on the FunctionType
// (there is no body/ReturnOp to derive them from).
479 if (funcOp.isExternal()) {
481 for (
unsigned resIndex :
482 llvm::seq<unsigned>(0, functionType.getNumResults())) {
483 Type resType = functionType.getResult(resIndex);
484 MemRefType memrefType = dyn_cast<MemRefType>(resType);
488 resultTypes.push_back(resType);
494 resultTypes.push_back(newMemRefType);
497 FunctionType newFuncType =
501 funcOp.setType(newFuncType);
// Propagate the new signature to call sites and callers.
503 updateFunctionSignature(funcOp, moduleOp);
// NOTE(review): fragment — interior lines missing; code reproduced verbatim.
//
// Builds (via OperationState `result`) a clone of oldOp whose memref result
// types are normalized. Returns the newly created op when any result type
// actually changed; the old-op path for the unchanged case is among the
// missing lines.
511 Operation *NormalizeMemRefs::createOpResultsNormalized(func::FuncOp funcOp,
// Carry over all attributes from the op being cloned.
517 result.addAttributes(oldOp->
getAttrs());
521 bool resultTypeNormalized =
false;
// Compute the (possibly normalized) type for each result.
522 for (
unsigned resIndex : llvm::seq<unsigned>(0, oldOp->
getNumResults())) {
524 MemRefType memrefType = dyn_cast<MemRefType>(resultType);
// Non-memref results keep their type.
527 resultTypes.push_back(resultType);
// Normalization was a no-op for this memref — keep the original type.
533 if (newMemRefType == memrefType) {
536 resultTypes.push_back(memrefType);
539 resultTypes.push_back(newMemRefType);
540 resultTypeNormalized =
true;
542 result.addTypes(resultTypes);
// Only materialize a new op if at least one result type changed; a region
// is added so the old op's body can be moved over (takeBody lines missing).
545 if (resultTypeNormalized) {
548 Region *newRegion = result.addRegion();
551 return bb.create(result);
static MLIRContext * getContext(OpFoldResult val)
static bool isMemRefNormalizable(Value::user_range opUsers)
Check whether all the uses of oldMemRef are either dereferencing uses or the op is of type: DeallocOp, CallOp or ReturnOp.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
This class represents an argument of a Block.
This class helps build Operations.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
This class represents a specific symbol use.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
user_range getUsers() const
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
static WalkResult interrupt()
MemRefType normalizeMemRefType(MemRefType memrefType)
Normalizes memrefType so that the affine layout map of the memref is transformed to an identity map with a new shape being computed for the normalized memref type, and returns it.
LogicalResult normalizeMemRef(AllocLikeOp op)
Rewrites the memref defined by this alloc op to have an identity layout map and updates all its indexing uses.
LogicalResult replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef, ArrayRef< Value > extraIndices={}, AffineMap indexRemap=AffineMap(), ArrayRef< Value > extraOperands={}, ArrayRef< Value > symbolOperands={}, Operation *domOpFilter=nullptr, Operation *postDomOpFilter=nullptr, bool allowNonDereferencingOps=false, bool replaceInDeallocOp=false)
Replaces all "dereferencing" uses of oldMemRef with newMemRef while optionally remapping the old memr...
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
std::unique_ptr< OperationPass< ModuleOp > > createNormalizeMemRefsPass()
Creates an interprocedural pass to normalize memrefs to have a trivial (identity) layout map.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
This represents an operation in an abstracted form, suitable for use with the builder APIs.