MLIR 23.0.0git
ScalarizeFunctionResult.cpp
Go to the documentation of this file.
1//===- ScalarizeFunctionResult.cpp - Scalarize tensor returns ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
13#include "mlir/IR/BuiltinOps.h"
15#include "mlir/IR/SymbolTable.h"
16#include "llvm/ADT/DenseMap.h"
17
18namespace mlir {
19namespace tensor {
20#define GEN_PASS_DEF_SCALARIZESINGLEELEMENTTENSORRETURNPASS
21#include "mlir/Dialect/Tensor/Transforms/Passes.h.inc"
22} // namespace tensor
23} // namespace mlir
24
25using namespace mlir;
26
27namespace {
28
// Analysis state used by the memoized DFS below. It classifies whether a
// candidate private function can be scalarized. The rewrite only consumes the
// functions classified as `rewritable`.
enum class ScalarizationState {
  // No DFS classification has been computed for this function yet. This is
  // only the placeholder inserted on first visit; the DFS immediately
  // overwrites it with `visiting` or `blocked`.
  unknown,
  // The function is currently on the recursive DFS stack. Re-entering this
  // state means a cycle was found, which this pass conservatively blocks.
  visiting,
  // The function cannot be scalarized safely because either it is not locally
  // eligible or one of its transitive users blocks the rewrite.
  blocked,
  // The function is locally eligible and all transitive users considered by
  // this pass can be updated consistently.
  rewritable
};
45
46// Info analyzed to decide scalarizing a locally eligible function: a private
47// definition with exactly one statically-shaped ranked tensor result containing
48// one element. Functions not meeting these criteria are not represented here,
49// although they may still appear in the broader module analysis as callers or
50// blockers.
51struct ScalarizableFunctionInfo {
52 RankedTensorType tensorType;
53 SmallVector<func::ReturnOp> returnOps;
54};
55
56// Returns per-function scalarization info when this function is locally
57// eligible, i.e. it is a private definition with one statically-shaped ranked
58// tensor result containing exactly one element.
59static FailureOr<ScalarizableFunctionInfo>
60getScalarizableFunctionInfoIfEligible(func::FuncOp func) {
61 if (func.isDeclaration() || !func.isPrivate())
62 return failure();
63
64 FunctionType functionType = func.getFunctionType();
65 if (functionType.getNumResults() != 1)
66 return failure();
67
68 auto tensorType = dyn_cast<RankedTensorType>(functionType.getResult(0));
69 if (!tensorType || !tensorType.hasStaticShape() ||
70 tensorType.getNumElements() != 1)
71 return failure();
72
73 ScalarizableFunctionInfo sfi{tensorType, {}};
74 for (Block &block : func.getBody()) {
75 auto returnOp = dyn_cast<func::ReturnOp>(block.getTerminator());
76 if (returnOp)
77 sfi.returnOps.push_back(returnOp);
78 }
79
80 // While FuncOp is guaranteed to contain terminator ops, there is no guarantee
81 // that it will contain ReturnOp(s). Hence the check.
82 if (sfi.returnOps.empty())
83 return failure();
84 return sfi;
85}
86
87struct ScalarizationAnalysis {
88 explicit ScalarizationAnalysis(SymbolUserMap &userMap) : userMap(userMap) {}
89
90 SymbolUserMap &userMap;
91 DenseSet<func::FuncOp> moduleFunctions;
95 // DFS completion order for rewritable functions. The rewrite phase walks
96 // this list in reverse so callees are rewritten before their call sites.
97 SmallVector<func::FuncOp> rewriteOrder;
98};
99
100// Runs the memoized DFS that classifies one candidate function by walking its
101// transitive private call users.
102static ScalarizationState
103computeScalarizationState(func::FuncOp func, ScalarizationAnalysis &analysis) {
104 auto [it, inserted] =
105 analysis.states.try_emplace(func, ScalarizationState::unknown);
106 if (!inserted) {
107 // Conservatively reject recursive cycles instead of reasoning about SCCs.
108 if (it->second == ScalarizationState::visiting)
109 return ScalarizationState::blocked;
110 return it->second;
113 // Starting from one locally-eligible private function, walk its symbol users
114 // upward through private callers. Memoization avoids re-traversing the same
115 // subgraph from the outer linear scan, and early blocking truncates the DFS
116 // as soon as a public/unsupported user is found.
117 auto setBlocked = [&] {
118 analysis.states[func] = ScalarizationState::blocked;
119 return ScalarizationState::blocked;
120 };
121 auto setRewritable = [&] {
122 analysis.states[func] = ScalarizationState::rewritable;
123 analysis.rewriteOrder.push_back(func);
124 return ScalarizationState::rewritable;
125 };
127 if (!analysis.candidateInfos.contains(func))
128 return setBlocked();
129 analysis.states[func] = ScalarizationState::visiting;
130
131 SmallVector<func::CallOp> directCallUsers;
132 for (Operation *user : analysis.userMap.getUsers(func.getOperation())) {
133 auto directCall = dyn_cast<func::CallOp>(user);
134 // Non-call symbol uses, such as func.constant, prevent updating all users
135 // consistently, so the current function stays blocked.
136 if (!directCall)
137 return setBlocked();
138 directCallUsers.push_back(directCall);
140 func::FuncOp caller = directCall->getParentOfType<func::FuncOp>();
141 // A direct call user outside any func.func can still be updated in place,
142 // but it terminates the DFS because there is no caller signature to
143 // analyze or rewrite transitively.
144 if (!caller)
145 continue;
146 // Since `getScalarizableFunctionInfoIfEligible` has already categorized
147 // every direct func.func in the current module, any direct caller must
148 // already appear in the analysis tables.
149 assert(analysis.moduleFunctions.contains(caller) &&
150 "Caller of private function is not a direct function in the module");
151
152 // Public and external callers keep the current function blocked because the
153 // pass cannot rewrite every visible call boundary.
154 if (caller.isPublic())
155 return setBlocked();
156
157 assert(!caller.isExternal() && "Caller of private function is external.");
158
159 // A private non-candidate caller can absorb the scalarized call by
160 // reboxing the scalar result back into a single-element tensor.
161 if (!analysis.candidateInfos.contains(caller))
162 continue;
163
164 if (computeScalarizationState(caller, analysis) !=
165 ScalarizationState::rewritable)
166 return setBlocked();
167 }
168
169 analysis.callUsers.try_emplace(func, std::move(directCallUsers));
170 return setRewritable();
171}
172
173// Builds the module-level analysis state used by the rewrite phase.
174static void computeScalarizationAnalysis(ModuleOp module,
175 ScalarizationAnalysis &analysis) {
176 // First collect all direct functions and the subset that is locally
177 // eligible.
178 for (func::FuncOp func : module.getOps<func::FuncOp>()) {
179 analysis.moduleFunctions.insert(func);
180 FailureOr<ScalarizableFunctionInfo> sfi =
181 getScalarizableFunctionInfoIfEligible(func);
182 if (succeeded(sfi))
183 analysis.candidateInfos.try_emplace(func, std::move(*sfi));
184 }
185 // Then run the memoized DFS for candidate roots.
186 for (func::FuncOp func : module.getOps<func::FuncOp>())
187 if (analysis.candidateInfos.contains(func))
188 (void)computeScalarizationState(func, analysis);
189}
190
191// Rewrites one function that has already been proven rewritable and updates
192// the direct call users cached before any IR mutation started.
193static void rewriteScalarizableFunction(func::FuncOp func,
194 const ScalarizableFunctionInfo &sfi,
195 ArrayRef<func::CallOp> directCalls,
196 RewriterBase &rewriter) {
197 OpBuilder::InsertionGuard guard(rewriter);
198 // Scalarize the unique element before each return.
199 RankedTensorType tensorType = sfi.tensorType;
200 SmallVector<Value> zeroIndices;
201 if (tensorType.getRank() != 0) {
202 rewriter.setInsertionPointToStart(&func.getBody().front());
203 Value zero = arith::ConstantIndexOp::create(rewriter, func.getLoc(), 0);
204 zeroIndices.assign(tensorType.getRank(), zero);
205 }
206
207 Type scalarType = tensorType.getElementType();
208 for (func::ReturnOp funcReturn : sfi.returnOps) {
209 assert(funcReturn.getNumOperands() == 1 &&
210 "func.return must have exactly one operand");
211 assert(funcReturn.getOperand(0).getType() == tensorType &&
212 "func.return operand type must match the function result type");
213 rewriter.setInsertionPoint(funcReturn);
214 Value scalar = rewriter.createOrFold<tensor::ExtractOp>(
215 funcReturn.getLoc(), funcReturn.getOperand(0), zeroIndices);
216 rewriter.replaceOpWithNewOp<func::ReturnOp>(funcReturn, scalar);
217 }
218
219 FunctionType functionType = func.getFunctionType();
220 // Update the function type:
221 // This is a 1-result to 1-result type replacement, so the existing result
222 // attribute dictionary remains attached to result #0 without reordering,
223 // hence the rewrite is done directly without function_interface methods.
224 func.setType(FunctionType::get(func.getContext(), functionType.getInputs(),
225 TypeRange{scalarType}));
226 // Fix direct call users that were recorded during analysis.
227 for (func::CallOp directCall : directCalls) {
228 rewriter.setInsertionPoint(directCall);
229 func::CallOp newDirectCall = func::CallOp::create(
230 rewriter, directCall.getLoc(), func, directCall.getOperands());
231 newDirectCall->setAttrs(directCall->getAttrs());
232
233 if (!directCall.getResult(0).use_empty()) {
234 Value wrappedResult = tensor::FromElementsOp::create(
235 rewriter, directCall.getLoc(), tensorType,
236 ValueRange{newDirectCall.getResult(0)});
237 rewriter.replaceOp(directCall, wrappedResult);
238 } else {
239 rewriter.eraseOp(directCall);
240 }
241 }
242}
243
244/// Drives the complete module-level transform: analyze the original module,
245/// determine which private functions can be scalarized safely, then
246/// rewrite only that precomputed set.
247///
248/// BEFORE (both callee and private caller rewritten)
249/// private callee(x : tensor<1xT>) -> tensor<1xT> { return x }
250/// private caller(x : tensor<1xT>) -> tensor<1xT> {
251/// y = call callee(x)
252/// return y
253/// }
254///
255/// AFTER
256/// private callee(x : tensor<1xT>) -> T {
257/// return tensor.extract x[0]
258/// }
259/// private caller(x : tensor<1xT>) -> T {
260/// y = call callee(x)
261/// return y
262/// }
263///
264/// BEFORE (callee rewritten, unchanged private caller reboxes)
265/// private callee(x : tensor<1xT>) -> tensor<1xT> { return x }
266/// private caller(x : tensor<1xT>, z : tensor<1xT>) -> tensor<1xT> {
267/// y = call callee(x)
268/// r = tensor_op(y, z)
269/// return r
270/// }
271///
272/// AFTER
273/// private callee(x : tensor<1xT>) -> T {
274/// return tensor.extract x[0]
275/// }
276/// private caller(x : tensor<1xT>, z : tensor<1xT>) -> tensor<1xT> {
277/// y = call callee(x)
278/// y_boxed = tensor.from_elements y
279/// r = tensor_op(y_boxed, z)
280/// return r
281/// }
282static LogicalResult
283ScalarizeSingleElementTensorReturns(ModuleOp module, RewriterBase &rewriter) {
284 // The transform depends on module-scoped `SymbolUserMap` state and on a
285 // precomputed DFS result, so it runs as a direct module analysis + rewrite
286 // instead of exposing a pattern-testing / transform-dialect pattern path.
287 SymbolTableCollection symbolTable;
288 // Take a snapshot of symbol users for the original module. This is safe
289 // because it is consulted only during `computeScalarizationAnalysis`, before
290 // rewriting starts, and this pass invocation has exclusive access to the
291 // current module.
292 SymbolUserMap userMap(symbolTable, module);
293 ScalarizationAnalysis analysis(userMap);
294 computeScalarizationAnalysis(module, analysis);
295
296 for (func::FuncOp func : llvm::reverse(analysis.rewriteOrder)) {
297 const ScalarizableFunctionInfo &sfi =
298 analysis.candidateInfos.find(func)->second;
299 ArrayRef<func::CallOp> directCalls = analysis.callUsers.find(func)->second;
300 rewriteScalarizableFunction(func, sfi, directCalls, rewriter);
301 }
302
303 return success();
304}
305
306struct ScalarizeSingleElementTensorReturnPass
308 ScalarizeSingleElementTensorReturnPass> {
309 using Base::Base;
310
311 void runOnOperation() override {
312 IRRewriter rewriter(&getContext());
313 if (failed(ScalarizeSingleElementTensorReturns(getOperation(), rewriter)))
314 signalPassFailure();
315 }
316};
317
318} // namespace
return success()
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
Block represents an ordered list of Operations.
Definition Block.h:33
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a collection of SymbolTables.
This class represents a map of symbols to users, and provides efficient implementations of symbol que...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:384
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
detail::InFlightRemark analysis(Location loc, RemarkOpts opts)
Report an optimization analysis remark.
Definition Remarks.h:723
Include the generated interface declarations.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:122
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120