19 #include "llvm/ADT/SmallSet.h"
20 #include "llvm/Support/Debug.h"
24 #define GEN_PASS_DEF_NORMALIZEMEMREFS
25 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
29 #define DEBUG_TYPE "normalize-memrefs"
// NOTE(review): this file is a line-wrapped source-listing scrape; the bare
// leading integers (42, 43, ...) are listing line numbers, not code, and a
// number of original lines — including the remaining members and the closing
// "};" of this struct — were dropped by the scrape.
//
// Pass class: interprocedurally normalizes memrefs with non-identity layout
// maps, function by function, via the generated NormalizeMemRefsBase (see
// GEN_PASS_DEF_NORMALIZEMEMREFS above).
42 struct NormalizeMemRefs
43     :
public memref::impl::NormalizeMemRefsBase<NormalizeMemRefs> {
// Pass entry point: drives normalization over every func.func in the module.
44   void runOnOperation()
override;
// Rewrites allocs, op results, and block arguments of funcOp to identity
// layout, then fixes up the function signature and its call sites.
45   void normalizeFuncOpMemRefs(func::FuncOp funcOp, ModuleOp moduleOp);
// Returns true when every non-identity-layout memref in funcOp has only
// users that can be rewritten (dereferencing uses / dealloc-like ops).
46   bool areMemRefsNormalizable(func::FuncOp funcOp);
// Propagates funcOp's updated result types to its callers' call sites.
47   void updateFunctionSignature(func::FuncOp funcOp, ModuleOp moduleOp);
// Removes funcOp — and, transitively, its callers and callees — from the
// normalizable set. Declaration truncated here by the scrape (the trailing
// set parameter is among the elided lines).
48   void setCalleesAndCallersNonNormalizable(
49       func::FuncOp funcOp, ModuleOp moduleOp,
56 std::unique_ptr<OperationPass<ModuleOp>>
58 return std::make_unique<NormalizeMemRefs>();
// Pass driver: optimistically seed every func.func as normalizable, prune
// those containing memrefs that cannot be normalized (poisoning their call
// graph neighborhood), then normalize the survivors.
// NOTE(review): several original lines are elided by the scrape — including
// the declaration of `normalizableFuncs` and the closing braces of the walk
// callbacks below — so this span is not compilable as-is.
61 void NormalizeMemRefs::runOnOperation() {
62   LLVM_DEBUG(llvm::dbgs() <<
"Normalizing Memrefs...\n");
63   ModuleOp moduleOp = getOperation();
// Seed: every function in the module starts out marked normalizable.
71   moduleOp.walk([&](func::FuncOp funcOp) { normalizableFuncs.insert(funcOp); });
// Prune: a function with non-normalizable memref uses is removed, together
// with its callers/callees, via setCalleesAndCallersNonNormalizable.
79   moduleOp.walk([&](func::FuncOp funcOp) {
80     if (normalizableFuncs.contains(funcOp)) {
81       if (!areMemRefsNormalizable(funcOp)) {
82         LLVM_DEBUG(llvm::dbgs()
83                    <<
"@" << funcOp.getName()
84                    <<
" contains ops that cannot normalize MemRefs\n");
88         setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
94   LLVM_DEBUG(llvm::dbgs() <<
"Normalizing " << normalizableFuncs.size()
// Normalize every function that survived the pruning walk.
97   for (func::FuncOp &funcOp : normalizableFuncs)
98     normalizeFuncOpMemRefs(funcOp, moduleOp);
// NOTE(review): fragment of the static helper
//   static bool isMemRefNormalizable(Value::user_range opUsers)
// which, per the reference text at the end of this file, checks that every
// user of a memref is either a dereferencing use or a dealloc-like op. The
// lambda body and closing lines are elided by the scrape.
106   return llvm::all_of(opUsers, [](
Operation *op) {
// Recursively removes funcOp from the normalizable set and propagates
// non-normalizability in both directions: to callers (found through symbol
// uses of funcOp in the module) and to callees (found by walking funcOp's
// call ops). NOTE(review): guard bodies, the trailing set parameter, and
// several closing braces are elided by the scrape.
113 void NormalizeMemRefs::setCalleesAndCallersNonNormalizable(
114     func::FuncOp funcOp, ModuleOp moduleOp,
// Already pruned — nothing to do (the early-return body is elided).
116   if (!normalizableFuncs.contains(funcOp))
120       llvm::dbgs() <<
"@" << funcOp.getName()
121                    <<
" calls or is called by non-normalizable function\n");
122   normalizableFuncs.erase(funcOp);
// Callers: every symbol use of funcOp identifies a call site whose parent
// function must also become non-normalizable.
124   std::optional<SymbolTable::UseRange> symbolUses =
125       funcOp.getSymbolUses(moduleOp);
129     func::FuncOp parentFuncOp =
130         symbolUse.getUser()->getParentOfType<func::FuncOp>();
131     for (func::FuncOp &funcOp : normalizableFuncs) {
132       if (parentFuncOp == funcOp) {
133         setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
// Callees: any still-normalizable function funcOp calls is pruned too.
141   funcOp.walk([&](func::CallOp callOp) {
142     StringAttr callee = callOp.getCalleeAttr().getAttr();
143     for (func::FuncOp &funcOp : normalizableFuncs) {
145       if (callee == funcOp.getNameAttr()) {
146         setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
// Decides whether funcOp's memrefs can all be normalized. Three memref
// sources are inspected: alloc results, call results, and function
// arguments; any non-identity-layout memref whose users fail
// isMemRefNormalizable makes the function non-normalizable (the walks
// signal this with WalkResult::interrupt()).
// NOTE(review): the isExternal branch body, walk scaffolding, and the final
// return are elided by the scrape.
161 bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
// External function — branch body elided; presumably handled specially.
163   if (funcOp.isExternal())
// Alloc results with non-identity layout must have normalizable users.
168         Value oldMemRef = allocOp.getResult();
169         if (!allocOp.getType().getLayout().isIdentity() &&
// Call results: same check per memref-typed result.
179         for (
unsigned resIndex :
180              llvm::seq<unsigned>(0, callOp.getNumResults())) {
181           Value oldMemRef = callOp.getResult(resIndex);
182           if (auto oldMemRefType =
183                   dyn_cast<MemRefType>(oldMemRef.getType()))
184             if (!oldMemRefType.getLayout().isIdentity() &&
185                 !isMemRefNormalizable(oldMemRef.getUsers()))
186               return WalkResult::interrupt();
// Function arguments: same check per memref-typed argument (the binding of
// `oldMemRef` to the argument is among the elided lines).
193   for (
unsigned argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
195     if (
auto oldMemRefType = dyn_cast<MemRefType>(oldMemRef.
getType()))
196       if (!oldMemRefType.getLayout().isIdentity() &&
// After funcOp's body has been normalized, recomputes its result types from
// its return ops, rebuilds each call site of funcOp with the new result
// types, rewires uses of the old call results, and recurses into caller
// functions whose own signatures may now need updating.
// NOTE(review): heavily elided by the scrape — the ModuleOp parameter on the
// signature, argTypes/resultTypes declarations, builder setup, and most
// closing braces are missing.
210 void NormalizeMemRefs::updateFunctionSignature(func::FuncOp funcOp,
212   FunctionType functionType = funcOp.getFunctionType();
214   FunctionType newFuncType;
215   resultTypes = llvm::to_vector<4>(functionType.getResults());
// Non-external functions: take argument types from the (already rewritten)
// entry block arguments, and refresh result types from return operands.
219   if (!funcOp.isExternal()) {
222       argTypes.push_back(argEn.value().getType());
226     funcOp.walk([&](func::ReturnOp returnOp) {
227       for (
const auto &operandEn :
llvm::enumerate(returnOp.getOperands())) {
228         Type opType = operandEn.value().getType();
229         MemRefType memrefType = dyn_cast<MemRefType>(opType);
// Skip non-memref operands and operands whose type is already recorded.
232         if (!memrefType || memrefType == resultTypes[operandEn.index()])
// Only identity-layout (i.e. normalized) memref types are propagated into
// the new signature.
242         if (memrefType.getLayout().isIdentity())
243           resultTypes[operandEn.index()] = memrefType;
// Callers whose call sites were rewritten; their signatures get updated
// recursively at the end.
258   llvm::SmallDenseSet<func::FuncOp, 8> funcOpsToUpdate;
261   std::optional<SymbolTable::UseRange> symbolUses =
262       funcOp.getSymbolUses(moduleOp);
// For each call site of funcOp, build a replacement call op with the new
// result types (creation arguments partially elided).
270       auto callOp = dyn_cast<func::CallOp>(userOp);
274           builder.
create<func::CallOp>(userOp->
getLoc(), callOp.getCalleeAttr(),
276       bool replacingMemRefUsesFailed =
false;
277       bool returnTypeChanged =
false;
// Rewire each old call result to the corresponding new result; memref
// results with a non-identity layout go through replaceAllMemRefUsesWith
// using the old layout map (the call itself is among the elided lines).
278       for (
unsigned resIndex : llvm::seq<unsigned>(0, userOp->
getNumResults())) {
287             cast<MemRefType>(oldResult.
getType()).getLayout().getAffineMap();
301           replacingMemRefUsesFailed =
true;
304           returnTypeChanged =
true;
// Bail out for this call site if any memref use could not be replaced
// (branch body elided).
306       if (replacingMemRefUsesFailed)
// A changed return type means the enclosing caller's signature must be
// revisited too.
311       if (returnTypeChanged) {
321         func::FuncOp parentFuncOp = newCallOp->
getParentOfType<func::FuncOp>();
322         funcOpsToUpdate.insert(parentFuncOp);
// Commit the new signature (external functions keep theirs here).
327   if (!funcOp.isExternal())
328     funcOp.setType(newFuncType);
// Recurse into every caller whose call sites changed shape.
333   for (func::FuncOp parentFuncOp : funcOpsToUpdate)
334     updateFunctionSignature(parentFuncOp, moduleOp);
// Normalizes all memrefs inside funcOp: alloc ops first, then memref block
// arguments (insert new identity-layout argument, remap uses, erase the old
// one), then results of other ops via createOpResultsNormalized, and
// finally the function signature / call sites via updateFunctionSignature.
// NOTE(review): the ModuleOp parameter, the allocOps/inputTypes/resultTypes
// declarations, the normalizeMemRef call at listing line ~348, and many
// closing braces are elided by the scrape.
340 void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
// Collect allocs up front, then normalize each (the per-alloc call body is
// among the elided lines).
346   funcOp.walk([&](memref::AllocOp op) { allocOps.push_back(op); });
347   for (memref::AllocOp allocOp : allocOps)
353   FunctionType functionType = funcOp.getFunctionType();
// Preserve per-argument locations for the replacement block arguments.
355       funcOp.getArguments(), [](
BlockArgument arg) { return arg.getLoc(); }));
// Rewrite each memref-typed input to its normalized type.
358   for (
unsigned argIndex :
359        llvm::seq<unsigned>(0, functionType.getNumInputs())) {
360     Type argType = functionType.getInput(argIndex);
361     MemRefType memrefType = dyn_cast<MemRefType>(argType);
// Non-memref inputs pass through unchanged.
365       inputTypes.push_back(argType);
// Already identity layout, or external function (no body to rewrite):
// record the (new) type and move on.
371     if (newMemRefType == memrefType || funcOp.isExternal()) {
374       inputTypes.push_back(newMemRefType);
// Insert the normalized argument, remap old uses through the old layout
// map, then erase the superseded argument. The eraseArgument(argIndex + 1)
// path presumably covers the failure/rollback case — elided here; confirm
// against the full source.
380         argIndex, newMemRefType, functionArgLocs[argIndex]);
382     AffineMap layoutMap = memrefType.getLayout().getAffineMap();
395       funcOp.front().eraseArgument(argIndex);
401       funcOp.front().eraseArgument(argIndex + 1);
// Normalize results of remaining (non-external) ops: build a clone with
// normalized result types, remap memref uses, then swap the ops.
412         !funcOp.isExternal()) {
414       Operation *newOp = createOpResultsNormalized(funcOp, op);
419       bool replacingMemRefUsesFailed = false;
420       for (unsigned resIndex : llvm::seq<unsigned>(0, op->getNumResults())) {
422         Value oldMemRef = op->getResult(resIndex);
423         Value newMemRef = newOp->getResult(resIndex);
424         MemRefType oldMemRefType = dyn_cast<MemRefType>(oldMemRef.getType());
428         MemRefType newMemRefType = cast<MemRefType>(newMemRef.getType());
// Result type unchanged — nothing to remap for this result.
429         if (oldMemRefType == newMemRefType)
432         AffineMap layoutMap = oldMemRefType.getLayout().getAffineMap();
433         if (failed(replaceAllMemRefUsesWith(oldMemRef,
444           replacingMemRefUsesFailed = true;
// Only commit the swap when every memref use was successfully replaced.
448       if (!replacingMemRefUsesFailed) {
451         op->replaceAllUsesWith(newOp);
// External functions have no body: just rewrite memref result types in the
// declared signature.
463   if (funcOp.isExternal()) {
465     for (
unsigned resIndex :
466          llvm::seq<unsigned>(0, functionType.getNumResults())) {
467       Type resType = functionType.getResult(resIndex);
468       MemRefType memrefType = dyn_cast<MemRefType>(resType);
472         resultTypes.push_back(resType);
478         resultTypes.push_back(newMemRefType);
481     FunctionType newFuncType =
485     funcOp.setType(newFuncType);
// Propagate the new shape to callers/call sites.
487   updateFunctionSignature(funcOp, moduleOp);
// Builds a clone of oldOp whose memref result types are normalized
// (non-memref and already-identity results pass through unchanged), and
// creates it only when at least one result type actually changed.
// NOTE(review): the oldOp parameter, the OperationState `result` setup, the
// normalizeMemRefType call, region/operand transfer, and the else branch
// returning oldOp are elided by the scrape.
495 Operation *NormalizeMemRefs::createOpResultsNormalized(func::FuncOp funcOp,
// Carry over all of oldOp's attributes onto the new op.
501   result.addAttributes(oldOp->
getAttrs());
505   bool resultTypeNormalized =
false;
// Decide, per result, whether normalization changes its memref type.
506   for (
unsigned resIndex : llvm::seq<unsigned>(0, oldOp->
getNumResults())) {
508     MemRefType memrefType = dyn_cast<MemRefType>(resultType);
// Non-memref result: keep as-is.
511       resultTypes.push_back(resultType);
// Normalization was a no-op for this memref type.
517     if (newMemRefType == memrefType) {
520       resultTypes.push_back(memrefType);
// Record the normalized type and remember that a new op is needed.
523       resultTypes.push_back(newMemRefType);
524       resultTypeNormalized =
true;
526   result.addTypes(resultTypes);
// Only materialize a new operation when something changed; a fresh region
// is added for the new op (body transfer elided — presumably takeBody).
529   if (resultTypeNormalized) {
532     Region *newRegion = result.addRegion();
535     return bb.create(result);
static MLIRContext * getContext(OpFoldResult val)
static bool isMemRefNormalizable(Value::user_range opUsers)
Check whether all the uses of oldMemRef are either dereferencing uses, or the op is of type DeallocOp.
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
This class represents an argument of a Block.
This class helps build Operations.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
void erase()
Remove this operation from its parent block and delete it.
unsigned getNumResults()
Return the number of results held by this operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
This class represents a specific symbol use.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
user_range getUsers() const
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
static WalkResult interrupt()
LogicalResult normalizeMemRef(memref::AllocOp *op)
Rewrites the memref defined by this alloc op to have an identity layout map and updates all its indexing uses.
MemRefType normalizeMemRefType(MemRefType memrefType)
Normalizes memrefType so that the affine layout map of the memref is transformed to an identity map with a new index space.
LogicalResult replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef, ArrayRef< Value > extraIndices={}, AffineMap indexRemap=AffineMap(), ArrayRef< Value > extraOperands={}, ArrayRef< Value > symbolOperands={}, Operation *domOpFilter=nullptr, Operation *postDomOpFilter=nullptr, bool allowNonDereferencingOps=false, bool replaceInDeallocOp=false)
Replaces all "dereferencing" uses of oldMemRef with newMemRef while optionally remapping the old memref's indices using the supplied affine map, indexRemap.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
std::unique_ptr< OperationPass< ModuleOp > > createNormalizeMemRefsPass()
Creates an interprocedural pass to normalize memrefs to have a trivial (identity) layout map.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
This represents an operation in an abstracted form, suitable for use with the builder APIs.