MLIR  22.0.0git
Utils.cpp
Go to the documentation of this file.
1 //===- Utils.cpp - Utilities to support the Func dialect ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities for the Func dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
15 #include "mlir/IR/IRMapping.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "llvm/ADT/SmallVector.h"
18 
19 using namespace mlir;
20 
21 FailureOr<func::FuncOp>
22 func::replaceFuncWithNewOrder(RewriterBase &rewriter, func::FuncOp funcOp,
23  ArrayRef<unsigned> newArgsOrder,
24  ArrayRef<unsigned> newResultsOrder) {
25  // Generate an empty new function operation with the same name as the
26  // original.
27  assert(funcOp.getNumArguments() == newArgsOrder.size() &&
28  "newArgsOrder must match the number of arguments in the function");
29  assert(funcOp.getNumResults() == newResultsOrder.size() &&
30  "newResultsOrder must match the number of results in the function");
31 
32  if (!funcOp.getBody().hasOneBlock())
33  return rewriter.notifyMatchFailure(
34  funcOp, "expected function to have exactly one block");
35 
36  ArrayRef<Type> origInputTypes = funcOp.getFunctionType().getInputs();
37  ArrayRef<Type> origOutputTypes = funcOp.getFunctionType().getResults();
38  SmallVector<Type> newInputTypes, newOutputTypes;
40  for (unsigned int idx : newArgsOrder) {
41  newInputTypes.push_back(origInputTypes[idx]);
42  locs.push_back(funcOp.getArgument(newArgsOrder[idx]).getLoc());
43  }
44  for (unsigned int idx : newResultsOrder)
45  newOutputTypes.push_back(origOutputTypes[idx]);
46  rewriter.setInsertionPoint(funcOp);
47  auto newFuncOp = func::FuncOp::create(
48  rewriter, funcOp.getLoc(), funcOp.getName(),
49  rewriter.getFunctionType(newInputTypes, newOutputTypes));
50 
51  Region &newRegion = newFuncOp.getBody();
52  rewriter.createBlock(&newRegion, newRegion.begin(), newInputTypes, locs);
53  newFuncOp.setVisibility(funcOp.getVisibility());
54  newFuncOp->setDiscardableAttrs(funcOp->getDiscardableAttrDictionary());
55 
56  // Map the arguments of the original function to the new function in
57  // the new order and adjust the attributes accordingly.
58  IRMapping operandMapper;
59  SmallVector<DictionaryAttr> argAttrs, resultAttrs;
60  funcOp.getAllArgAttrs(argAttrs);
61  for (unsigned int i = 0; i < newArgsOrder.size(); ++i) {
62  operandMapper.map(funcOp.getArgument(newArgsOrder[i]),
63  newFuncOp.getArgument(i));
64  newFuncOp.setArgAttrs(i, argAttrs[newArgsOrder[i]]);
65  }
66  funcOp.getAllResultAttrs(resultAttrs);
67  for (unsigned int i = 0; i < newResultsOrder.size(); ++i)
68  newFuncOp.setResultAttrs(i, resultAttrs[newResultsOrder[i]]);
69 
70  // Clone the operations from the original function to the new function.
71  rewriter.setInsertionPointToStart(&newFuncOp.getBody().front());
72  for (Operation &op : funcOp.getOps())
73  rewriter.clone(op, operandMapper);
74 
75  // Handle the return operation.
76  auto returnOp = cast<func::ReturnOp>(
77  newFuncOp.getFunctionBody().begin()->getTerminator());
78  SmallVector<Value> newReturnValues;
79  for (unsigned int idx : newResultsOrder)
80  newReturnValues.push_back(returnOp.getOperand(idx));
81  rewriter.setInsertionPoint(returnOp);
82  auto newReturnOp =
83  func::ReturnOp::create(rewriter, newFuncOp.getLoc(), newReturnValues);
84  newReturnOp->setDiscardableAttrs(returnOp->getDiscardableAttrDictionary());
85  rewriter.eraseOp(returnOp);
86 
87  rewriter.eraseOp(funcOp);
88 
89  return newFuncOp;
90 }
91 
92 func::CallOp
93 func::replaceCallOpWithNewOrder(RewriterBase &rewriter, func::CallOp callOp,
94  ArrayRef<unsigned> newArgsOrder,
95  ArrayRef<unsigned> newResultsOrder) {
96  assert(
97  callOp.getNumOperands() == newArgsOrder.size() &&
98  "newArgsOrder must match the number of operands in the call operation");
99  assert(
100  callOp.getNumResults() == newResultsOrder.size() &&
101  "newResultsOrder must match the number of results in the call operation");
102  SmallVector<Value> newArgsOrderValues;
103  for (unsigned int argIdx : newArgsOrder)
104  newArgsOrderValues.push_back(callOp.getOperand(argIdx));
105  SmallVector<Type> newResultTypes;
106  for (unsigned int resIdx : newResultsOrder)
107  newResultTypes.push_back(callOp.getResult(resIdx).getType());
108 
109  // Replace the kernel call operation with a new one that has the
110  // reordered arguments.
111  rewriter.setInsertionPoint(callOp);
112  auto newCallOp =
113  func::CallOp::create(rewriter, callOp.getLoc(), callOp.getCallee(),
114  newResultTypes, newArgsOrderValues);
115  newCallOp.setNoInlineAttr(callOp.getNoInlineAttr());
116  for (auto &&[newIndex, origIndex] : llvm::enumerate(newResultsOrder))
117  rewriter.replaceAllUsesWith(callOp.getResult(origIndex),
118  newCallOp.getResult(newIndex));
119  rewriter.eraseOp(callOp);
120 
121  return newCallOp;
122 }
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:75
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:425
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:548
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator begin()
Definition: Region.h:55
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:636
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
CallOp replaceCallOpWithNewOrder(RewriterBase &rewriter, CallOp callOp, llvm::ArrayRef< unsigned > newArgsOrder, llvm::ArrayRef< unsigned > newResultsOrder)
Creates a new call operation with the same values as the original call operation, but with the arguments r...
FailureOr< FuncOp > replaceFuncWithNewOrder(RewriterBase &rewriter, FuncOp funcOp, llvm::ArrayRef< unsigned > newArgsOrder, llvm::ArrayRef< unsigned > newResultsOrder)
Creates a new function operation with the same name as the original function operation,...
Include the generated interface declarations.