16#include "llvm/ADT/DenseMap.h"
20#define GEN_PASS_DEF_SCALARIZESINGLEELEMENTTENSORRETURNPASS
21#include "mlir/Dialect/Tensor/Transforms/Passes.h.inc"
32enum class ScalarizationState {
51struct ScalarizableFunctionInfo {
52 RankedTensorType tensorType;
53 SmallVector<func::ReturnOp> returnOps;
59static FailureOr<ScalarizableFunctionInfo>
60getScalarizableFunctionInfoIfEligible(func::FuncOp
func) {
61 if (
func.isDeclaration() || !
func.isPrivate())
64 FunctionType functionType =
func.getFunctionType();
65 if (functionType.getNumResults() != 1)
68 auto tensorType = dyn_cast<RankedTensorType>(functionType.getResult(0));
69 if (!tensorType || !tensorType.hasStaticShape() ||
70 tensorType.getNumElements() != 1)
73 ScalarizableFunctionInfo sfi{tensorType, {}};
75 auto returnOp = dyn_cast<func::ReturnOp>(block.getTerminator());
77 sfi.returnOps.push_back(returnOp);
82 if (sfi.returnOps.empty())
87struct ScalarizationAnalysis {
88 explicit ScalarizationAnalysis(SymbolUserMap &userMap) : userMap(userMap) {}
90 SymbolUserMap &userMap;
97 SmallVector<func::FuncOp> rewriteOrder;
102static ScalarizationState
103computeScalarizationState(func::FuncOp
func, ScalarizationAnalysis &analysis) {
105 analysis.states.try_emplace(
func, ScalarizationState::unknown);
108 if (it->second == ScalarizationState::visiting)
109 return ScalarizationState::blocked;
117 auto setBlocked = [&] {
118 analysis.states[
func] = ScalarizationState::blocked;
119 return ScalarizationState::blocked;
121 auto setRewritable = [&] {
122 analysis.states[
func] = ScalarizationState::rewritable;
123 analysis.rewriteOrder.push_back(
func);
124 return ScalarizationState::rewritable;
127 if (!analysis.candidateInfos.contains(
func))
129 analysis.states[
func] = ScalarizationState::visiting;
132 for (
Operation *user : analysis.userMap.getUsers(
func.getOperation())) {
133 auto directCall = dyn_cast<func::CallOp>(user);
138 directCallUsers.push_back(directCall);
140 func::FuncOp caller = directCall->getParentOfType<func::FuncOp>();
149 assert(analysis.moduleFunctions.contains(caller) &&
150 "Caller of private function is not a direct function in the module");
154 if (caller.isPublic())
157 assert(!caller.isExternal() &&
"Caller of private function is external.");
161 if (!analysis.candidateInfos.contains(caller))
164 if (computeScalarizationState(caller, analysis) !=
165 ScalarizationState::rewritable)
169 analysis.callUsers.try_emplace(
func, std::move(directCallUsers));
170 return setRewritable();
174static void computeScalarizationAnalysis(ModuleOp module,
175 ScalarizationAnalysis &analysis) {
178 for (func::FuncOp
func : module.getOps<func::FuncOp>()) {
179 analysis.moduleFunctions.insert(
func);
180 FailureOr<ScalarizableFunctionInfo> sfi =
181 getScalarizableFunctionInfoIfEligible(
func);
183 analysis.candidateInfos.try_emplace(
func, std::move(*sfi));
186 for (func::FuncOp
func : module.getOps<func::FuncOp>())
188 (
void)computeScalarizationState(
func, analysis);
193static void rewriteScalarizableFunction(func::FuncOp
func,
194 const ScalarizableFunctionInfo &sfi,
199 RankedTensorType tensorType = sfi.tensorType;
201 if (tensorType.getRank() != 0) {
204 zeroIndices.assign(tensorType.getRank(), zero);
207 Type scalarType = tensorType.getElementType();
208 for (func::ReturnOp funcReturn : sfi.returnOps) {
209 assert(funcReturn.getNumOperands() == 1 &&
210 "func.return must have exactly one operand");
211 assert(funcReturn.getOperand(0).getType() == tensorType &&
212 "func.return operand type must match the function result type");
215 funcReturn.getLoc(), funcReturn.getOperand(0), zeroIndices);
219 FunctionType functionType =
func.getFunctionType();
224 func.setType(FunctionType::get(
func.getContext(), functionType.getInputs(),
227 for (func::CallOp directCall : directCalls) {
229 func::CallOp newDirectCall = func::CallOp::create(
230 rewriter, directCall.getLoc(),
func, directCall.getOperands());
231 newDirectCall->setAttrs(directCall->getAttrs());
233 if (!directCall.getResult(0).use_empty()) {
234 Value wrappedResult = tensor::FromElementsOp::create(
235 rewriter, directCall.getLoc(), tensorType,
237 rewriter.
replaceOp(directCall, wrappedResult);
283ScalarizeSingleElementTensorReturns(ModuleOp module,
RewriterBase &rewriter) {
293 ScalarizationAnalysis
analysis(userMap);
294 computeScalarizationAnalysis(module, analysis);
296 for (func::FuncOp
func : llvm::reverse(
analysis.rewriteOrder)) {
297 const ScalarizableFunctionInfo &sfi =
300 rewriteScalarizableFunction(
func, sfi, directCalls, rewriter);
306struct ScalarizeSingleElementTensorReturnPass
308 ScalarizeSingleElementTensorReturnPass> {
311 void runOnOperation()
override {
313 if (
failed(ScalarizeSingleElementTensorReturns(getOperation(), rewriter)))
if copies could not be generated due to yet unimplemented cases. copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock specify the insertion points where the incoming and outgoing copies should be inserted (the insertion happens right before the insertion point). Since `begin` can itself be invalidated due to the memref rewriting done from this method
Block represents an ordered list of Operations.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it.
Operation is the basic unit of execution within MLIR.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a collection of SymbolTables.
This class represents a map of symbols to users, and provides efficient implementations of symbol que...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Include the generated interface declarations.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap