MLIR 22.0.0git
OneShotModuleBufferize.cpp
Go to the documentation of this file.
1//===- OneShotModuleBufferize.cpp - Bufferization across Func. Boundaries
2//----===//
3//
4// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5// See https://llvm.org/LICENSE.txt for license information.
6// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7//
8//===----------------------------------------------------------------------===//
9//
10// Module Bufferization is an extension of One-Shot Bufferize that
11// bufferizes function boundaries. It provides `BufferizableOpInterface`
12// implementations for FuncOp, CallOp and ReturnOp. Although it is named
13// Module Bufferization, it may operate on any SymbolTable.
14//
15// Module Bufferization is run via `runOneShotModuleBufferize(SymbolTableOp,
16// ...)`. This function analyzes the given op and determines the order of
17// analysis and bufferization: Functions that are called are processed before
18// their respective callers.
19//
20// After analyzing a FuncOp, additional information about its bbArgs is
21// gathered and stored in `FuncAnalysisState`.
22//
23// * `aliasingFuncOpBBArgsAnalysis` determines the equivalent/aliasing bbArgs
24// for each tensor return value (if any).
26// * `funcOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is
27// read/written.
28//
29// Module Bufferization implements the following calling convention.
30//
31// * In the absence of conflicts within a FuncOp, the FuncOp's bbArgs may always
32// be written to in-place.
33// * If a tensor operand of a CallOp is read after the CallOp, the operand of
34// the CallOp must bufferize out-of-place.
35//
36// Example: The tensor.insert op bufferizes in-place because it is allowed to
37// modify the buffer of `%t1` directly. The CallOp in `caller` must bufferize
38// out-of-place because `%t0` is modified by the callee but read by the
39// tensor.extract op. The analysis of CallOps decides whether an OpOperand must
40// bufferize out-of-place based on results of `funcOpBbArgReadWriteAnalysis`.
41// ```
42// func @callee(%t1 : tensor<?xf32>) -> tensor<?xf32> {
43// %f = ... : f32
44// %0 = tensor.insert %f into %t1[...] : tensor<?xf32>
45// return %0 : tensor<?xf32>
46// }
47//
48// func @caller() -> () {
49// %t0 = ... : tensor<?xf32>
50// %1 = call @callee(%t0) : (tensor<?xf32>) -> (tensor<?xf32>)
51// %2 = tensor.extract %t0[...] : tensor<?xf32>
52// }
53// ```
54//
55// Note: If a function is external, `funcOpBbArgReadWriteAnalysis` cannot
56// analyze the function body. In such a case, the CallOp analysis conservatively
57// assumes that each tensor OpOperand is both read and written.
58//
59// TODO: Add FuncOp attributes so that bbArgs of external FuncOps can be marked
60// as "not reading" and/or "not writing".
61
63
73#include "mlir/IR/Operation.h"
74
75using namespace mlir;
76using namespace mlir::bufferization;
77using namespace mlir::bufferization::func_ext;
78
79/// A mapping of FuncOps to their callers.
81
82/// Get or create FuncAnalysisState.
83static FuncAnalysisState &
90
91namespace {
92
93/// Annotate IR with the results of the analysis. For testing purposes only.
94static void annotateEquivalentReturnBbArg(OpOperand &returnVal,
95 BlockArgument bbArg) {
96 const char *kEquivalentArgsAttr = "__equivalent_func_args__";
97 Operation *op = returnVal.getOwner();
98
99 SmallVector<int64_t> equivBbArgs;
100 if (op->hasAttr(kEquivalentArgsAttr)) {
101 auto attr = cast<ArrayAttr>(op->getAttr(kEquivalentArgsAttr));
102 equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) {
103 return cast<IntegerAttr>(a).getValue().getSExtValue();
104 }));
105 } else {
106 equivBbArgs.append(op->getNumOperands(), -1);
107 }
108 equivBbArgs[returnVal.getOperandNumber()] = bbArg.getArgNumber();
109
110 OpBuilder b(op->getContext());
111 op->setAttr(kEquivalentArgsAttr, b.getI64ArrayAttr(equivBbArgs));
112}
113
114/// Store function BlockArguments that are equivalent to/aliasing a returned
115/// value in FuncAnalysisState.
116static LogicalResult
117aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
118 FuncAnalysisState &funcState) {
119 if (funcOp.getBody().empty()) {
120 // No function body available. Conservatively assume that every tensor
121 // return value may alias with any tensor bbArg.
122 FunctionType type = funcOp.getFunctionType();
123 for (const auto &inputIt : llvm::enumerate(type.getInputs())) {
124 if (!isa<TensorType>(inputIt.value()))
125 continue;
126 for (const auto &resultIt : llvm::enumerate(type.getResults())) {
127 if (!isa<TensorType>(resultIt.value()))
128 continue;
129 int64_t returnIdx = resultIt.index();
130 int64_t bbArgIdx = inputIt.index();
131 funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx);
132 }
133 }
134 return success();
135 }
136
137 // Find all func.return ops.
138 SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
139 assert(!returnOps.empty() && "expected at least one ReturnOp");
140
141 // Build alias sets. Merge all aliases from all func.return ops.
142 for (BlockArgument bbArg : funcOp.getArguments()) {
143 if (isa<RankedTensorType>(bbArg.getType())) {
144 int64_t bbArgIdx = bbArg.getArgNumber();
145 // Store aliases in a set, so that we don't add the same alias twice.
146 SetVector<int64_t> aliases;
147 for (func::ReturnOp returnOp : returnOps) {
148 for (OpOperand &returnVal : returnOp->getOpOperands()) {
149 if (isa<RankedTensorType>(returnVal.get().getType())) {
150 int64_t returnIdx = returnVal.getOperandNumber();
151 if (state.areAliasingBufferizedValues(returnVal.get(), bbArg))
152 aliases.insert(returnIdx);
153 }
154 }
155 }
156 for (int64_t alias : aliases)
157 funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(alias);
158 }
159 }
160
161 // Build equivalence sets.
162 // Helper function that finds an equivalent block argument index for the
163 // given OpOperand. Return std::nullopt if no equivalent block argument could
164 // be found.
165 auto findEquivalentBlockArgIdx =
166 [&](OpOperand &opOperand) -> std::optional<int64_t> {
167 Value v = opOperand.get();
168 if (!isa<TensorType>(v.getType()))
169 return std::nullopt;
170 for (BlockArgument bbArg : funcOp.getArguments()) {
171 if (isa<RankedTensorType>(bbArg.getType())) {
172 if (state.areEquivalentBufferizedValues(v, bbArg)) {
173 if (state.getOptions().testAnalysisOnly)
174 annotateEquivalentReturnBbArg(opOperand, bbArg);
175 return bbArg.getArgNumber();
176 }
177 }
178 }
179 return std::nullopt;
180 };
181
182 int64_t numResults = returnOps.front()->getNumOperands();
183 for (int64_t i = 0; i < numResults; ++i) {
184 // Find the equivalent block argument index for the i-th operand of the
185 // first func.return op.
186 std::optional<int64_t> maybeEquiv =
187 findEquivalentBlockArgIdx(returnOps.front()->getOpOperand(i));
188 if (!maybeEquiv.has_value())
189 continue;
190 int64_t bbArgIdx = *maybeEquiv;
191 bool allEquiv = true;
192
193 // Check if all other func.return ops have the same equivalent block
194 // argument for the i-th operand. In contrast to aliasing information,
195 // which is just "merged", equivalence information must match across all
196 // func.return ops.
197 for (func::ReturnOp returnOp : ArrayRef(returnOps).drop_front()) {
198 std::optional<int64_t> maybeEquiv =
199 findEquivalentBlockArgIdx(returnOp->getOpOperand(i));
200 if (maybeEquiv != bbArgIdx) {
201 allEquiv = false;
202 break;
203 }
204 }
205
206 // All func.return ops have the same equivalent block argument for the i-th
207 // operand.
208 if (allEquiv)
209 funcState.equivalentFuncArgs[funcOp][i] = bbArgIdx;
210 }
211
212 return success();
213}
214
215static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx, bool isRead,
216 bool isWritten) {
217 OpBuilder b(funcOp.getContext());
218 Attribute accessType;
219 if (isRead && isWritten) {
220 accessType = b.getStringAttr("read-write");
221 } else if (isRead) {
222 accessType = b.getStringAttr("read");
223 } else if (isWritten) {
224 accessType = b.getStringAttr("write");
225 } else {
226 accessType = b.getStringAttr("none");
227 }
228 funcOp.setArgAttr(idx, BufferizationDialect::kBufferAccessAttrName,
229 accessType);
230}
231
232/// Determine which FuncOp bbArgs are read and which are written. When run on a
233/// function with unknown ops, we conservatively assume that such ops bufferize
234/// to a read + write.
235static LogicalResult
236funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
237 FuncAnalysisState &funcState) {
238 for (int64_t idx = 0, e = funcOp.getFunctionType().getNumInputs(); idx < e;
239 ++idx) {
240 // Skip non-tensor arguments.
241 if (!isa<TensorType>(funcOp.getFunctionType().getInput(idx)))
242 continue;
243 bool isRead;
244 bool isWritten;
245 if (auto accessAttr = funcOp.getArgAttrOfType<StringAttr>(
246 idx, BufferizationDialect::kBufferAccessAttrName)) {
247 // Buffer access behavior is specified on the function. Skip the analysis.
248 StringRef str = accessAttr.getValue();
249 isRead = str == "read" || str == "read-write";
250 isWritten = str == "write" || str == "read-write";
251 } else if (funcOp.getBody().empty()) {
252 // If the function has no body, conservatively assume that all args are
253 // read + written.
254 isRead = true;
255 isWritten = true;
256 } else {
257 // Analyze the body of the function.
258 BlockArgument bbArg = funcOp.getArgument(idx);
259 isRead = state.isValueRead(bbArg);
260 isWritten = state.isValueWritten(bbArg);
261 }
262
263 if (state.getOptions().testAnalysisOnly)
264 annotateFuncArgAccess(funcOp, idx, isRead, isWritten);
265 if (isRead)
266 funcState.readBbArgs[funcOp].insert(idx);
267 if (isWritten)
268 funcState.writtenBbArgs[funcOp].insert(idx);
269 }
270
271 return success();
272}
273} // namespace
274
275/// Remove bufferization attributes on FuncOp arguments.
277 auto funcOp = cast<func::FuncOp>(bbArg.getOwner()->getParentOp());
278 funcOp.removeArgAttr(bbArg.getArgNumber(),
279 BufferizationDialect::kBufferLayoutAttrName);
280 funcOp.removeArgAttr(bbArg.getArgNumber(),
281 BufferizationDialect::kWritableAttrName);
282}
283
284/// Return the func::FuncOp called by `callOp`.
285static func::FuncOp
286getCalledFunction(func::CallOp callOp,
287 mlir::SymbolTableCollection &symbolTable) {
288 return dyn_cast_or_null<func::FuncOp>(
289 callOp.resolveCallableInTable(&symbolTable));
290}
291
292/// Return "true" if the given function signature has tensor semantics.
293static bool hasTensorSignature(func::FuncOp funcOp) {
294 return llvm::any_of(funcOp.getFunctionType().getInputs(),
295 llvm::IsaPred<TensorType>) ||
296 llvm::any_of(funcOp.getFunctionType().getResults(),
297 llvm::IsaPred<TensorType>);
298}
299
300/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
301/// callee-caller order (i.e., callees before their callers). Store all
302/// remaining functions (i.e., the ones that call each other recursively) in
303/// `remainingFuncOps`. Does not traverse nested symbol tables.
304///
/// Only calls to functions with a tensor in their signature create an ordering
/// constraint. A function that is not itself part of a call cycle, but calls
/// (directly or transitively) into one, also ends up in `remainingFuncOps`.
///
305/// Store the map of each called FuncOp (with a tensor signature) to all of its
/// func::CallOp call sites in `callerMap`.
306///
307/// Return `failure()` if we are unable to retrieve the called FuncOp from
308/// any func::CallOp.
/// NOTE(review): the walk below never returns WalkResult::interrupt(). An
/// unresolvable callee is only caught by an assert, and in release builds it
/// would be passed as a null FuncOp to hasTensorSignature. The `failure()`
/// path is therefore currently unreachable; consider interrupting the walk.
309static LogicalResult getFuncOpsOrderedByCalls(
310 Operation *moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
311 SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap,
312 SymbolTableCollection &symbolTables) {
313 // For each called FuncOp (with a tensor signature), the set of FuncOps that
314 // call it (i.e. that contain at least one func::CallOp to it).
316 // For each FuncOp, the number of distinct tensor-signature callees it calls.
317 DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
318 for (mlir::Region &region : moduleOp->getRegions()) {
319 for (mlir::Block &block : region.getBlocks()) {
320 for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) {
321 // Collect function calls and populate the caller map.
322 numberCallOpsContainedInFuncOp[funcOp] = 0;
323 WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
324 func::FuncOp calledFunction = getCalledFunction(callOp, symbolTables);
325 assert(calledFunction && "could not retrieved called func::FuncOp");
326 // If the called function does not have any tensors in its signature,
327 // then it is not necessary to bufferize the callee before the caller.
328 if (!hasTensorSignature(calledFunction))
329 return WalkResult::skip();
330
331 callerMap[calledFunction].insert(callOp);
// Count each callee only once per caller, no matter how many call sites.
332 if (calledBy[calledFunction].insert(funcOp).second) {
333 numberCallOpsContainedInFuncOp[funcOp]++;
334 }
335 return WalkResult::advance();
336 });
337 if (res.wasInterrupted())
338 return failure();
339 }
340 }
341 }
342
343 // Iteratively remove function operations that do not call any of the
344 // functions remaining in the `numberCallOpsContainedInFuncOp` map and add
// them to the ordered list (a topological sort over the call graph).
346
347 for (const auto &entry : numberCallOpsContainedInFuncOp) {
348 if (entry.second == 0)
349 worklist.push_back(entry.first);
350 }
351
352 while (!worklist.empty()) {
353 func::FuncOp func = worklist.pop_back_val();
354 orderedFuncOps.push_back(func);
355
// `func` is done: each of its callers now has one fewer pending callee.
356 for (func::FuncOp caller : calledBy[func]) {
357 auto &count = numberCallOpsContainedInFuncOp[caller];
358
359 if (--count == 0)
360 worklist.push_back(caller);
361 }
362
363 numberCallOpsContainedInFuncOp.erase(func);
364 }
365
366 // Put all other functions in the list of remaining functions. These are
367 // functions that call each other circularly.
// NOTE(review): DenseMap iteration order is unspecified, so the relative order
// of independent functions here (and in the initial worklist) is not
// guaranteed to be stable.
368 for (auto it : numberCallOpsContainedInFuncOp)
369 remainingFuncOps.push_back(it.first);
370
371 return success();
372}
373
374/// Helper function that extracts the source from a memref.cast. If the given
375/// value is not a memref.cast result, simply returns the given value.
377 auto castOp = v.getDefiningOp<memref::CastOp>();
378 if (!castOp)
379 return v;
380 return castOp.getSource();
381}
382
383/// Helper function that returns the return types (skipping casts) of the given
384/// func.return ops. This function returns as many types as the return ops have
385/// operands. If the i-th operand is not the same for all func.return ops, then
386/// the i-th returned type is an "empty" type.
388 assert(!returnOps.empty() && "expected at least one ReturnOp");
389 int numOperands = returnOps.front()->getNumOperands();
390
391 // Helper function that unpacks memref.cast ops and returns the type.
392 auto getSourceType = [&](Value v) { return unpackCast(v).getType(); };
393
395 for (int i = 0; i < numOperands; ++i) {
396 // Get the type of the i-th operand of the first func.return ops.
397 Type t = getSourceType(returnOps.front()->getOperand(i));
398
399 // Check if all other func.return ops have a matching operand type.
400 for (int j = 1; j < static_cast<int>(returnOps.size()); ++j)
401 if (getSourceType(returnOps[j]->getOperand(i)) != t)
402 t = Type();
403
404 result.push_back(t);
405 }
406
407 return result;
408}
409
410/// Fold return values that are memref casts and update function return types.
411///
412/// During FuncOp bufferization, the exact type of the returned memrefs (if any)
413/// is not known yet. Therefore, the bufferization uses memref types with the
414/// most generic layout map as function return types. After bufferizing the
415/// entire function body, a more concise memref type can potentially be used for
416/// the return type of the function.
417static void foldMemRefCasts(func::FuncOp funcOp) {
418 // There is nothing to do for bodiless ops.
419 if (funcOp.getBody().empty())
420 return;
421
422 // Compute the common result types of all return ops.
423 SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
424 SmallVector<Type> resultTypes = getReturnTypes(returnOps);
425
426 // Remove direct casts.
427 for (func::ReturnOp returnOp : returnOps) {
428 for (OpOperand &operand : returnOp->getOpOperands()) {
429 // Bail if no common result type was found.
430 if (resultTypes[operand.getOperandNumber()]) {
431 operand.set(unpackCast(operand.get()));
432 }
433 }
434 }
435
436 // Fill in the missing result types that were not the same among all
437 // func.return ops.
438 for (int i = 0; i < static_cast<int>(resultTypes.size()); ++i) {
439 if (resultTypes[i])
440 continue;
441 resultTypes[i] = funcOp.getFunctionType().getResult(i);
442 }
443
444 // Update the function type.
445 auto newFuncType = FunctionType::get(
446 funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
447 funcOp.setType(newFuncType);
448}
449
/// Analyze the FuncOps of `moduleOp` (and their nested ops) with One-Shot
/// Analysis.
///
/// Functions are analyzed in callee-before-caller order, as computed by
/// `getFuncOpsOrderedByCalls`. For each such function that passes the op
/// filter, the body is analyzed, and then the function-boundary analyses
/// (`aliasingFuncOpBBArgsAnalysis`, `funcOpBbArgReadWriteAnalysis`) are run, so
/// that callers see precise information about their callees.
///
/// Functions in or calling into call cycles (`remainingFuncOps`) are analyzed
/// last, and only their bodies are analyzed. Functions rejected by
/// `isOpAllowed` are skipped entirely.
///
/// Returns failure if the call order cannot be computed or any analysis fails.
450LogicalResult
453 BufferizationStatistics *statistics) {
454 assert(state.getOptions().bufferizeFunctionBoundaries &&
455 "expected that function boundary bufferization is activated");
457
458 // A list of non-circular functions in the order in which they are analyzed
459 // and bufferized.
460 SmallVector<func::FuncOp> orderedFuncOps;
461 // A list of all other functions. I.e., functions that call each other
462 // recursively. For these, we analyze the function body but not the function
463 // boundary.
464 SmallVector<func::FuncOp> remainingFuncOps;
465
466 // A mapping of FuncOps to their callers.
467 FuncCallerMap callerMap;
468
469 if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
470 remainingFuncOps, callerMap,
471 funcState.symbolTables)))
472 return failure();
473
474 // Analyze functions in order. Starting with functions that are not calling
475 // any other functions.
476 for (func::FuncOp funcOp : orderedFuncOps) {
477 if (!state.getOptions().isOpAllowed(funcOp))
478 continue;
479
480 // Now analyzing function.
481 funcState.startFunctionAnalysis(funcOp);
482
483 // Analyze funcOp.
484 if (failed(analyzeOp(funcOp, state, statistics)))
485 return failure();
486
487 // Run some extra function analyses.
// These read the alias/read/write results that analyzeOp just computed.
488 if (failed(aliasingFuncOpBBArgsAnalysis(funcOp, state, funcState)) ||
489 failed(funcOpBbArgReadWriteAnalysis(funcOp, state, funcState)))
490 return failure();
491
492 // Mark op as fully analyzed.
494 }
495
496 // Analyze all other functions. All function boundary analyses are skipped.
497 for (func::FuncOp funcOp : remainingFuncOps) {
498 if (!state.getOptions().isOpAllowed(funcOp))
499 continue;
500
501 // Analyze funcOp.
502 if (failed(analyzeOp(funcOp, state, statistics)))
503 return failure();
504
505 // TODO: We currently skip all function argument analyses for functions
506 // that call each other circularly. These analyses do not support recursive
507 // calls yet. The `BufferizableOpInterface` implementations of `func`
508 // dialect ops return conservative results in the absence of analysis
509 // information.
510 }
511
512 return success();
513}
514
516 Operation *moduleOp) {
// Visit only FuncOps that are direct children of `moduleOp`'s blocks. FuncOps
// inside other nested ops (e.g. nested symbol tables) are not visited.
517 for (mlir::Region &region : moduleOp->getRegions()) {
518 for (mlir::Block &block : region.getBlocks()) {
519 for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) {
520 for (BlockArgument bbArg : funcOp.getArguments())
522 }
523 }
524 }
525}
526
529 BufferizationState &state, BufferizationStatistics *statistics) {
530 assert(options.bufferizeFunctionBoundaries &&
531 "expected that function boundary bufferization is activated");
// NOTE(review): `rewriter` is not used anywhere in the visible body of this
// function; confirm whether it can be removed.
532 IRRewriter rewriter(moduleOp->getContext());
533
534 // A list of non-circular functions in the order in which they are analyzed
535 // and bufferized.
536 SmallVector<func::FuncOp> orderedFuncOps;
537 // A list of all other functions. I.e., functions that call each other
538 // recursively. For these, we analyze the function body but not the function
539 // boundary.
540 SmallVector<func::FuncOp> remainingFuncOps;
541
542 // A mapping of FuncOps to their callers.
// NOTE(review): `callerMap` is only populated below and is never read again
// in the visible code.
543 FuncCallerMap callerMap;
544
545 // Try to bufferize functions in calling order. I.e., first bufferize
546 // functions that do not call other functions. This allows us to infer
547 // accurate buffer types for function return values. Functions that call
548 // each other recursively are bufferized in an unspecified order at the end.
549 // We may use unnecessarily "complex" (in terms of layout map) buffer types.
550 if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
551 remainingFuncOps, callerMap,
552 state.getSymbolTables())))
553 return failure();
554 llvm::append_range(orderedFuncOps, remainingFuncOps);
555
556 // Bufferize functions.
557 for (func::FuncOp funcOp : orderedFuncOps) {
558 // Note: It would be good to apply cleanups here but we cannot as aliasInfo
559 // would be invalidated.
560
561 if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getSymName())) {
562 // This function was not analyzed and RaW conflicts were not resolved.
563 // Buffer copies must be inserted before every write.
564 OneShotBufferizationOptions updatedOptions = options;
565 updatedOptions.copyBeforeWrite = true;
566 if (failed(bufferizeOp(funcOp, updatedOptions, state, statistics)))
567 return failure();
568 } else {
569 if (failed(bufferizeOp(funcOp, options, state, statistics)))
570 return failure();
571 }
572
573 // Change buffer return types to more precise layout maps.
// NOTE(review): for functions that came from `remainingFuncOps`, some callers
// may already have been bufferized before the result types are refined here;
// confirm that those call sites stay type-consistent.
574 if (options.inferFunctionResultLayout)
575 foldMemRefCasts(funcOp);
576 }
577
578 // Bufferize all other ops.
579 for (mlir::Region &region : moduleOp->getRegions()) {
580 for (mlir::Block &block : region.getBlocks()) {
// Early-increment iteration: bufferizeOp may replace the current op.
581 for (mlir::Operation &op :
582 llvm::make_early_inc_range(block.getOperations())) {
583 // Functions were already bufferized. Nested symbol tables are skipped.
584 if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>())
585 continue;
586 if (failed(bufferizeOp(&op, options, state, statistics)))
587 return failure();
588 }
589 }
590 }
591
592 // Post-pass cleanup of function argument attributes.
594
595 return success();
596}
597
600 BufferizationState &state, BufferizationStatistics *statistics) {
601 assert(options.bufferizeFunctionBoundaries &&
602 "expected that function boundary bufferization is activated");
603 assert(!(options.copyBeforeWrite && options.testAnalysisOnly) &&
604 "invalid combination of bufferization flags");
// Step 1: unless copy-before-write mode is requested, resolve RaW conflicts
// up front by inserting tensor copies (this runs the One-Shot analysis).
605 if (!options.copyBeforeWrite) {
606 if (options.noAnalysisFuncFilter.empty()) {
607 if (failed(insertTensorCopies(moduleOp, options, state, statistics)))
608 return failure();
609 } else {
610 // FuncOps whose names are specified in options.noAnalysisFuncFilter will
611 // not be analyzed. Ops in these FuncOps will not be analyzed as well.
// NOTE(review): `[=]` copies the whole `options` object into the closure
// (presumably `options` is a const reference parameter); capturing only
// `options.noAnalysisFuncFilter` would be cheaper.
612 OpFilter::Entry::FilterFn analysisFilterFn = [=](Operation *op) {
613 auto func = dyn_cast<func::FuncOp>(op);
614 if (!func)
615 func = op->getParentOfType<func::FuncOp>();
616 if (func)
617 return llvm::is_contained(options.noAnalysisFuncFilter,
618 func.getSymName());
619 return false;
620 };
621 OneShotBufferizationOptions updatedOptions(options);
622 updatedOptions.opFilter.denyOperation(analysisFilterFn);
623 if (failed(
624 insertTensorCopies(moduleOp, updatedOptions, state, statistics)))
625 return failure();
626 }
627 }
// In test-analysis-only mode, stop here; no bufferization is performed.
628 if (options.testAnalysisOnly)
629 return success();
// Step 2: bufferize the functions and the remaining ops of the module.
630 if (failed(bufferizeModuleOp(moduleOp, options, state, statistics)))
631 return failure();
632 return success();
633}
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
DenseMap< func::FuncOp, DenseSet< Operation * > > FuncCallerMap
A mapping of FuncOps to their callers.
static SmallVector< Type > getReturnTypes(SmallVector< func::ReturnOp > returnOps)
Helper function that returns the return types (skipping casts) of the given func.return ops.
static FuncAnalysisState & getOrCreateFuncAnalysisState(OneShotAnalysisState &state)
Get or create FuncAnalysisState.
static bool hasTensorSignature(func::FuncOp funcOp)
Return "true" if the given function signature has tensor semantics.
static void removeBufferizationAttributes(BlockArgument bbArg)
Remove bufferization attributes on FuncOp arguments.
static void foldMemRefCasts(func::FuncOp funcOp)
Fold return values that are memref casts and update function return types.
static Value unpackCast(Value v)
Helper function that extracts the source from a memref.cast.
static LogicalResult getFuncOpsOrderedByCalls(Operation *moduleOp, SmallVectorImpl< func::FuncOp > &orderedFuncOps, SmallVectorImpl< func::FuncOp > &remainingFuncOps, FuncCallerMap &callerMap, SymbolTableCollection &symbolTables)
Store all functions of the moduleOp in orderedFuncOps, sorted by callee-caller order (i....
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:309
unsigned getArgNumber() const
Returns the number of this argument.
Definition Value.h:321
Block * getOwner() const
Returns the block that owns this argument.
Definition Value.h:318
Block represents an ordered list of Operations.
Definition Block.h:33
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class helps build Operations.
Definition Builders.h:207
This class represents an operand of an operation.
Definition Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition Operation.h:534
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition Operation.h:560
unsigned getNumOperands()
Definition Operation.h:346
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:238
void setAttr(StringAttr name, Attribute value)
If an attribute with the specified name exists, change it to the new value.
Definition Operation.h:582
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition WalkResult.h:51
State for analysis-enabled bufferization.
bool isValueWritten(Value value) const
Return true if the buffer of the given tensor value is written to.
Ty & addExtension(Args &&...args)
Adds a new Extension of the type specified as template parameter, constructing it with the arguments ...
const OneShotBufferizationOptions & getOptions() const
Return a reference to the BufferizationOptions.
Ty * getExtension()
Returns the extension of the specified type.
bool areEquivalentBufferizedValues(Value v1, Value v2) const override
Return true if v1 and v2 bufferize to equivalent buffers.
bool areAliasingBufferizedValues(Value v1, Value v2) const override
Return true if v1 and v2 may bufferize to aliasing buffers.
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
static FuncOp getCalledFunction(CallOpInterface callOp, SymbolTableCollection &symbolTables)
Return the FuncOp called by callOp.
LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options, BufferizationState &bufferizationState, BufferizationStatistics *statistics=nullptr)
Bufferize op and its nested ops that implement BufferizableOpInterface.
llvm::LogicalResult bufferizeModuleOp(Operation *moduleOp, const OneShotBufferizationOptions &options, BufferizationState &state, BufferizationStatistics *statistics=nullptr)
Bufferize an op's nested ops that implement BufferizableOpInterface.
void removeBufferizationAttributesInModule(Operation *moduleOp)
Remove bufferization attributes on every FuncOp argument in the SymbolTable op.
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
SmallVector< func::ReturnOp > getReturnOps(func::FuncOp funcOp)
Helper function that returns all func.return ops in the given function.
LogicalResult insertTensorCopies(Operation *op, const OneShotBufferizationOptions &options, const BufferizationState &bufferizationState, BufferizationStatistics *statistics=nullptr)
Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
llvm::LogicalResult analyzeModuleOp(Operation *moduleOp, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze moduleOp and its nested ops.
llvm::LogicalResult runOneShotModuleBufferize(Operation *moduleOp, const bufferization::OneShotBufferizationOptions &options, BufferizationState &state, BufferizationStatistics *statistics=nullptr)
Run One-Shot Module Bufferization on the given SymbolTable.
Include the generated interface declarations.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
Bufferization statistics for debugging.
Definition Bufferize.h:35
Options for analysis-enabled bufferization.
Extra analysis state that is required for bufferization of function boundaries.
DenseMap< FuncOp, IndexMapping > equivalentFuncArgs
A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg indices.
DenseMap< FuncOp, IndexToIndexListMapping > aliasingReturnVals
A mapping of FuncOp BBArg indices to aliasing ReturnOp OpOperand indices.
SymbolTableCollection symbolTables
A collection of cached SymbolTables used for faster function lookup.
DenseMap< FuncOp, BbArgIndexSet > readBbArgs
A set of all read BlockArguments of FuncOps.
DenseMap< FuncOp, BbArgIndexSet > writtenBbArgs
A set of all written-to BlockArguments of FuncOps.
DenseMap< FuncOp, FuncOpAnalysisState > analyzedFuncOps
Keep track of which FuncOps are fully analyzed or currently being analyzed.
void startFunctionAnalysis(FuncOp funcOp)
This function is called right before analyzing the given FuncOp.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.