MLIR 22.0.0git
OneShotModuleBufferize.cpp
Go to the documentation of this file.
//===- OneShotModuleBufferize.cpp - Bufferization across Func. Boundaries -===//
3//
4// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5// See https://llvm.org/LICENSE.txt for license information.
6// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7//
8//===----------------------------------------------------------------------===//
9//
10// Module Bufferization is an extension of One-Shot Bufferize that
11// bufferizes function boundaries. It provides `BufferizableOpInterface`
12// implementations for FuncOp, CallOp and ReturnOp. Although it is named
13// Module Bufferization, it may operate on any SymbolTable.
14//
15// Module Bufferization is run via `runOneShotModuleBufferize(SymbolTableOp,
16// ...)`. This function analyzes the given op and determines the order of
17// analysis and bufferization: Functions that are called are processed before
18// their respective callers.
19//
20// After analyzing a FuncOp, additional information about its bbArgs is
21// gathered and stored in `FuncAnalysisState`.
22//
23// * `aliasingFuncOpBBArgsAnalysis` determines the equivalent/aliasing bbArgs
24// for
25// each tensor return value (if any).
26// * `funcOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is
27// read/written.
28//
29// Module Bufferization implements the following calling convention.
30//
31// * In the absence of conflicts within a FuncOp, the FuncOp's bbArgs may always
32// be written to in-place.
33// * If a tensor operand of a CallOp is read after the CallOp, the operand of
34// the CallOp must bufferize out-of-place.
35//
36// Example: The tensor.insert op bufferizes in-place because it is allowed to
37// modify the buffer of `%t1` directly. The CallOp in `caller` must bufferize
38// out-of-place because `%t0` is modified by the callee but read by the
39// tensor.extract op. The analysis of CallOps decides whether an OpOperand must
40// bufferize out-of-place based on results of `funcOpBbArgReadWriteAnalysis`.
41// ```
42// func @callee(%t1 : tensor<?xf32>) -> tensor<?xf32> {
43// %f = ... : f32
44// %0 = tensor.insert %f into %t1[...] : tensor<?xf32>
45// return %0 : tensor<?xf32>
46// }
47//
48// func @caller() -> () {
49// %t0 = ... : tensor<?xf32>
50// %1 = call @callee(%t0) : (tensor<?xf32>) -> (tensor<?xf32>)
51// %2 = tensor.extract %1[...] : tensor<?xf32>
52// }
53// ```
54//
55// Note: If a function is external, `funcOpBbArgReadWriteAnalysis` cannot
56// analyze the function body. In such a case, the CallOp analysis conservatively
57// assumes that each tensor OpOperand is both read and written.
58//
59// TODO: Add FuncOp attributes so that bbArgs of external FuncOps can be marked
60// as "not reading" and/or "not writing".
61
63
73#include "mlir/IR/Operation.h"
74
75using namespace mlir;
76using namespace mlir::bufferization;
77using namespace mlir::bufferization::func_ext;
78
79/// A mapping of FuncOps to their callers.
81
82/// Get or create FuncAnalysisState.
83static FuncAnalysisState &
90
91namespace {
92
93/// Annotate IR with the results of the analysis. For testing purposes only.
94static void annotateEquivalentReturnBbArg(OpOperand &returnVal,
95 BlockArgument bbArg) {
96 const char *kEquivalentArgsAttr = "__equivalent_func_args__";
97 Operation *op = returnVal.getOwner();
98
99 SmallVector<int64_t> equivBbArgs;
100 if (op->hasAttr(kEquivalentArgsAttr)) {
101 auto attr = cast<ArrayAttr>(op->getAttr(kEquivalentArgsAttr));
102 equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) {
103 return cast<IntegerAttr>(a).getValue().getSExtValue();
104 }));
105 } else {
106 equivBbArgs.append(op->getNumOperands(), -1);
107 }
108 equivBbArgs[returnVal.getOperandNumber()] = bbArg.getArgNumber();
109
110 OpBuilder b(op->getContext());
111 op->setAttr(kEquivalentArgsAttr, b.getI64ArrayAttr(equivBbArgs));
112}
113
114/// Store function BlockArguments that are equivalent to/aliasing a returned
115/// value in FuncAnalysisState.
116static LogicalResult
117aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
118 FuncAnalysisState &funcState) {
119 if (funcOp.getBody().empty()) {
120 // No function body available. Conservatively assume that every tensor
121 // return value may alias with any tensor bbArg.
122 FunctionType type = funcOp.getFunctionType();
123 for (const auto &inputIt : llvm::enumerate(type.getInputs())) {
124 if (!isa<TensorType>(inputIt.value()))
125 continue;
126 for (const auto &resultIt : llvm::enumerate(type.getResults())) {
127 if (!isa<TensorType>(resultIt.value()))
128 continue;
129 int64_t returnIdx = resultIt.index();
130 int64_t bbArgIdx = inputIt.index();
131 funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx);
132 }
133 }
134 return success();
135 }
136
137 // Find all func.return ops.
138 SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
139 // TODO: throw error when there is any non-func.return op that has the
140 // ReturnLike trait
141 if (returnOps.empty()) {
142 return funcOp.emitError("cannot bufferize func.func without func.return");
143 }
144
145 // Build alias sets. Merge all aliases from all func.return ops.
146 for (BlockArgument bbArg : funcOp.getArguments()) {
147 if (isa<RankedTensorType>(bbArg.getType())) {
148 int64_t bbArgIdx = bbArg.getArgNumber();
149 // Store aliases in a set, so that we don't add the same alias twice.
150 SetVector<int64_t> aliases;
151 for (func::ReturnOp returnOp : returnOps) {
152 for (OpOperand &returnVal : returnOp->getOpOperands()) {
153 if (isa<RankedTensorType>(returnVal.get().getType())) {
154 int64_t returnIdx = returnVal.getOperandNumber();
155 if (state.areAliasingBufferizedValues(returnVal.get(), bbArg))
156 aliases.insert(returnIdx);
157 }
158 }
159 }
160 for (int64_t alias : aliases)
161 funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(alias);
162 }
163 }
164
165 // Build equivalence sets.
166 // Helper function that finds an equivalent block argument index for the
167 // given OpOperand. Return std::nullopt if no equivalent block argument could
168 // be found.
169 auto findEquivalentBlockArgIdx =
170 [&](OpOperand &opOperand) -> std::optional<int64_t> {
171 Value v = opOperand.get();
172 if (!isa<TensorType>(v.getType()))
173 return std::nullopt;
174 for (BlockArgument bbArg : funcOp.getArguments()) {
175 if (isa<RankedTensorType>(bbArg.getType())) {
176 if (state.areEquivalentBufferizedValues(v, bbArg)) {
177 if (state.getOptions().testAnalysisOnly)
178 annotateEquivalentReturnBbArg(opOperand, bbArg);
179 return bbArg.getArgNumber();
180 }
181 }
182 }
183 return std::nullopt;
184 };
185
186 int64_t numResults = returnOps.front()->getNumOperands();
187 for (int64_t i = 0; i < numResults; ++i) {
188 // Find the equivalent block argument index for the i-th operand of the
189 // first func.return op.
190 std::optional<int64_t> maybeEquiv =
191 findEquivalentBlockArgIdx(returnOps.front()->getOpOperand(i));
192 if (!maybeEquiv.has_value())
193 continue;
194 int64_t bbArgIdx = *maybeEquiv;
195 bool allEquiv = true;
196
197 // Check if all other func.return ops have the same equivalent block
198 // argument for the i-th operand. In contrast to aliasing information,
199 // which is just "merged", equivalence information must match across all
200 // func.return ops.
201 for (func::ReturnOp returnOp : ArrayRef(returnOps).drop_front()) {
202 std::optional<int64_t> maybeEquiv =
203 findEquivalentBlockArgIdx(returnOp->getOpOperand(i));
204 if (maybeEquiv != bbArgIdx) {
205 allEquiv = false;
206 break;
207 }
208 }
209
210 // All func.return ops have the same equivalent block argument for the i-th
211 // operand.
212 if (allEquiv)
213 funcState.equivalentFuncArgs[funcOp][i] = bbArgIdx;
214 }
215
216 return success();
217}
218
219static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx, bool isRead,
220 bool isWritten) {
221 OpBuilder b(funcOp.getContext());
222 Attribute accessType;
223 if (isRead && isWritten) {
224 accessType = b.getStringAttr("read-write");
225 } else if (isRead) {
226 accessType = b.getStringAttr("read");
227 } else if (isWritten) {
228 accessType = b.getStringAttr("write");
229 } else {
230 accessType = b.getStringAttr("none");
231 }
232 funcOp.setArgAttr(idx, BufferizationDialect::kBufferAccessAttrName,
233 accessType);
234}
235
236/// Determine which FuncOp bbArgs are read and which are written. When run on a
237/// function with unknown ops, we conservatively assume that such ops bufferize
238/// to a read + write.
239static LogicalResult
240funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
241 FuncAnalysisState &funcState) {
242 for (int64_t idx = 0, e = funcOp.getFunctionType().getNumInputs(); idx < e;
243 ++idx) {
244 // Skip non-tensor arguments.
245 if (!isa<TensorType>(funcOp.getFunctionType().getInput(idx)))
246 continue;
247 bool isRead;
248 bool isWritten;
249 if (auto accessAttr = funcOp.getArgAttrOfType<StringAttr>(
250 idx, BufferizationDialect::kBufferAccessAttrName)) {
251 // Buffer access behavior is specified on the function. Skip the analysis.
252 StringRef str = accessAttr.getValue();
253 isRead = str == "read" || str == "read-write";
254 isWritten = str == "write" || str == "read-write";
255 } else if (funcOp.getBody().empty()) {
256 // If the function has no body, conservatively assume that all args are
257 // read + written.
258 isRead = true;
259 isWritten = true;
260 } else {
261 // Analyze the body of the function.
262 BlockArgument bbArg = funcOp.getArgument(idx);
263 isRead = state.isValueRead(bbArg);
264 isWritten = state.isValueWritten(bbArg);
265 }
266
267 if (state.getOptions().testAnalysisOnly)
268 annotateFuncArgAccess(funcOp, idx, isRead, isWritten);
269 if (isRead)
270 funcState.readBbArgs[funcOp].insert(idx);
271 if (isWritten)
272 funcState.writtenBbArgs[funcOp].insert(idx);
273 }
274
275 return success();
276}
277} // namespace
278
279/// Remove bufferization attributes on FuncOp arguments.
281 auto funcOp = cast<func::FuncOp>(bbArg.getOwner()->getParentOp());
282 funcOp.removeArgAttr(bbArg.getArgNumber(),
283 BufferizationDialect::kBufferLayoutAttrName);
284 funcOp.removeArgAttr(bbArg.getArgNumber(),
285 BufferizationDialect::kWritableAttrName);
286}
287
288/// Return the func::FuncOp called by `callOp`.
289static func::FuncOp
290getCalledFunction(func::CallOp callOp,
291 mlir::SymbolTableCollection &symbolTable) {
292 return dyn_cast_or_null<func::FuncOp>(
293 callOp.resolveCallableInTable(&symbolTable));
294}
295
296/// Return "true" if the given function signature has tensor semantics.
297static bool hasTensorSignature(func::FuncOp funcOp) {
298 return llvm::any_of(funcOp.getFunctionType().getInputs(),
299 llvm::IsaPred<TensorType>) ||
300 llvm::any_of(funcOp.getFunctionType().getResults(),
301 llvm::IsaPred<TensorType>);
302}
303
304/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
305/// callee-caller order (i.e., callees without callers first). Store all
306/// remaining functions (i.e., the ones that call each other recursively) in
307/// `remainingFuncOps`. Does not traverse nested symbol tables.
308///
309/// Store the map of FuncOp to all its callers in `callerMap`.
310///
311/// Return `failure()` if we are unable to retrieve the called FuncOp from
312/// any func::CallOp.
313static LogicalResult getFuncOpsOrderedByCalls(
314 Operation *moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
315 SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap,
316 SymbolTableCollection &symbolTables) {
317 // For each FuncOp, the set of functions called by it (i.e. the union of
318 // symbols of all nested func::CallOp).
320 // For each FuncOp, the number of func::CallOp it contains.
321 DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
322 for (mlir::Region &region : moduleOp->getRegions()) {
323 for (mlir::Block &block : region.getBlocks()) {
324 for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) {
325 // Collect function calls and populate the caller map.
326 numberCallOpsContainedInFuncOp[funcOp] = 0;
327 WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
328 func::FuncOp calledFunction = getCalledFunction(callOp, symbolTables);
329 assert(calledFunction && "could not retrieved called func::FuncOp");
330 // If the called function does not have any tensors in its signature,
331 // then it is not necessary to bufferize the callee before the caller.
332 if (!hasTensorSignature(calledFunction))
333 return WalkResult::skip();
334
335 callerMap[calledFunction].insert(callOp);
336 if (calledBy[calledFunction].insert(funcOp).second) {
337 numberCallOpsContainedInFuncOp[funcOp]++;
338 }
339 return WalkResult::advance();
340 });
341 if (res.wasInterrupted())
342 return failure();
343 }
344 }
345 }
346
347 // Iteratively remove function operations that do not call any of the
348 // functions remaining in the callCounter map and add them to ordered list.
350
351 for (const auto &entry : numberCallOpsContainedInFuncOp) {
352 if (entry.second == 0)
353 worklist.push_back(entry.first);
354 }
355
356 while (!worklist.empty()) {
357 func::FuncOp func = worklist.pop_back_val();
358 orderedFuncOps.push_back(func);
359
360 for (func::FuncOp caller : calledBy[func]) {
361 auto &count = numberCallOpsContainedInFuncOp[caller];
362
363 if (--count == 0)
364 worklist.push_back(caller);
365 }
366
367 numberCallOpsContainedInFuncOp.erase(func);
368 }
369
370 // Put all other functions in the list of remaining functions. These are
371 // functions that call each other circularly.
372 for (auto it : numberCallOpsContainedInFuncOp)
373 remainingFuncOps.push_back(it.first);
374
375 return success();
376}
377
378/// Helper function that extracts the source from a memref.cast. If the given
379/// value is not a memref.cast result, simply returns the given value.
381 auto castOp = v.getDefiningOp<memref::CastOp>();
382 if (!castOp)
383 return v;
384 return castOp.getSource();
385}
386
387/// Helper function that returns the return types (skipping casts) of the given
388/// func.return ops. This function returns as many types as the return ops have
389/// operands. If the i-th operand is not the same for all func.return ops, then
390/// the i-th returned type is an "empty" type.
392 assert(!returnOps.empty() && "expected at least one ReturnOp");
393 int numOperands = returnOps.front()->getNumOperands();
394
395 // Helper function that unpacks memref.cast ops and returns the type.
396 auto getSourceType = [&](Value v) { return unpackCast(v).getType(); };
397
399 for (int i = 0; i < numOperands; ++i) {
400 // Get the type of the i-th operand of the first func.return ops.
401 Type t = getSourceType(returnOps.front()->getOperand(i));
402
403 // Check if all other func.return ops have a matching operand type.
404 for (int j = 1; j < static_cast<int>(returnOps.size()); ++j)
405 if (getSourceType(returnOps[j]->getOperand(i)) != t)
406 t = Type();
407
408 result.push_back(t);
409 }
410
411 return result;
412}
413
414/// Fold return values that are memref casts and update function return types.
415///
416/// During FuncOp bufferization, the exact type of the returned memrefs (if any)
417/// is not known yet. Therefore, the bufferization uses memref types with the
418/// most generic layout map as function return types. After bufferizing the
419/// entire function body, a more concise memref type can potentially be used for
420/// the return type of the function.
421static void foldMemRefCasts(func::FuncOp funcOp) {
422 // There is nothing to do for bodiless ops.
423 if (funcOp.getBody().empty())
424 return;
425
426 // Compute the common result types of all return ops.
427 SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
428 SmallVector<Type> resultTypes = getReturnTypes(returnOps);
429
430 // Remove direct casts.
431 for (func::ReturnOp returnOp : returnOps) {
432 for (OpOperand &operand : returnOp->getOpOperands()) {
433 // Bail if no common result type was found.
434 if (resultTypes[operand.getOperandNumber()]) {
435 operand.set(unpackCast(operand.get()));
436 }
437 }
438 }
439
440 // Fill in the missing result types that were not the same among all
441 // func.return ops.
442 for (int i = 0; i < static_cast<int>(resultTypes.size()); ++i) {
443 if (resultTypes[i])
444 continue;
445 resultTypes[i] = funcOp.getFunctionType().getResult(i);
446 }
447
448 // Update the function type.
449 auto newFuncType = FunctionType::get(
450 funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
451 funcOp.setType(newFuncType);
452}
453
454LogicalResult
457 BufferizationStatistics *statistics) {
458 assert(state.getOptions().bufferizeFunctionBoundaries &&
459 "expected that function boundary bufferization is activated");
461
462 // A list of non-circular functions in the order in which they are analyzed
463 // and bufferized.
464 SmallVector<func::FuncOp> orderedFuncOps;
465 // A list of all other functions. I.e., functions that call each other
466 // recursively. For these, we analyze the function body but not the function
467 // boundary.
468 SmallVector<func::FuncOp> remainingFuncOps;
469
470 // A mapping of FuncOps to their callers.
471 FuncCallerMap callerMap;
472
473 if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
474 remainingFuncOps, callerMap,
475 funcState.symbolTables)))
476 return failure();
477
478 // Analyze functions in order. Starting with functions that are not calling
479 // any other functions.
480 for (func::FuncOp funcOp : orderedFuncOps) {
481 if (!state.getOptions().isOpAllowed(funcOp))
482 continue;
483
484 // Now analyzing function.
485 funcState.startFunctionAnalysis(funcOp);
486
487 // Analyze funcOp.
488 if (failed(analyzeOp(funcOp, state, statistics)))
489 return failure();
490
491 // Run some extra function analyses.
492 if (failed(aliasingFuncOpBBArgsAnalysis(funcOp, state, funcState)) ||
493 failed(funcOpBbArgReadWriteAnalysis(funcOp, state, funcState)))
494 return failure();
495
496 // Mark op as fully analyzed.
498 }
499
500 // Analyze all other functions. All function boundary analyses are skipped.
501 for (func::FuncOp funcOp : remainingFuncOps) {
502 if (!state.getOptions().isOpAllowed(funcOp))
503 continue;
504
505 // Analyze funcOp.
506 if (failed(analyzeOp(funcOp, state, statistics)))
507 return failure();
508
509 // TODO: We currently skip all function argument analyses for functions
510 // that call each other circularly. These analyses do not support recursive
511 // calls yet. The `BufferizableOpInterface` implementations of `func`
512 // dialect ops return conservative results in the absence of analysis
513 // information.
514 }
515
516 return success();
517}
518
520 Operation *moduleOp) {
521 for (mlir::Region &region : moduleOp->getRegions()) {
522 for (mlir::Block &block : region.getBlocks()) {
523 for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) {
524 for (BlockArgument bbArg : funcOp.getArguments())
526 }
527 }
528 }
529}
530
533 BufferizationState &state, BufferizationStatistics *statistics) {
534 assert(options.bufferizeFunctionBoundaries &&
535 "expected that function boundary bufferization is activated");
536 IRRewriter rewriter(moduleOp->getContext());
537
538 // A list of non-circular functions in the order in which they are analyzed
539 // and bufferized.
540 SmallVector<func::FuncOp> orderedFuncOps;
541 // A list of all other functions. I.e., functions that call each other
542 // recursively. For these, we analyze the function body but not the function
543 // boundary.
544 SmallVector<func::FuncOp> remainingFuncOps;
545
546 // A mapping of FuncOps to their callers.
547 FuncCallerMap callerMap;
548
549 // Try to bufferize functions in calling order. I.e., first bufferize
550 // functions that do not call other functions. This allows us to infer
551 // accurate buffer types for function return values. Functions that call
552 // each other recursively are bufferized in an unspecified order at the end.
553 // We may use unnecessarily "complex" (in terms of layout map) buffer types.
554 if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
555 remainingFuncOps, callerMap,
556 state.getSymbolTables())))
557 return failure();
558 llvm::append_range(orderedFuncOps, remainingFuncOps);
559
560 // Bufferize functions.
561 for (func::FuncOp funcOp : orderedFuncOps) {
562 // Note: It would be good to apply cleanups here but we cannot as aliasInfo
563 // would be invalidated.
564
565 if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getSymName())) {
566 // This function was not analyzed and RaW conflicts were not resolved.
567 // Buffer copies must be inserted before every write.
568 OneShotBufferizationOptions updatedOptions = options;
569 updatedOptions.copyBeforeWrite = true;
570 if (failed(bufferizeOp(funcOp, updatedOptions, state, statistics)))
571 return failure();
572 } else {
573 if (failed(bufferizeOp(funcOp, options, state, statistics)))
574 return failure();
575 }
576
577 // Change buffer return types to more precise layout maps.
578 if (options.inferFunctionResultLayout)
579 foldMemRefCasts(funcOp);
580 }
581
582 // Bufferize all other ops.
583 for (mlir::Region &region : moduleOp->getRegions()) {
584 for (mlir::Block &block : region.getBlocks()) {
585 for (mlir::Operation &op :
586 llvm::make_early_inc_range(block.getOperations())) {
587 // Functions were already bufferized.
588 if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>())
589 continue;
590 if (failed(bufferizeOp(&op, options, state, statistics)))
591 return failure();
592 }
593 }
594 }
595
596 // Post-pass cleanup of function argument attributes.
598
599 return success();
600}
601
604 BufferizationState &state, BufferizationStatistics *statistics) {
605 assert(options.bufferizeFunctionBoundaries &&
606 "expected that function boundary bufferization is activated");
607 assert(!(options.copyBeforeWrite && options.testAnalysisOnly) &&
608 "invalid combination of bufferization flags");
609 if (!options.copyBeforeWrite) {
610 if (options.noAnalysisFuncFilter.empty()) {
611 if (failed(insertTensorCopies(moduleOp, options, state, statistics)))
612 return failure();
613 } else {
614 // FuncOps whose names are specified in options.noAnalysisFuncFilter will
615 // not be analyzed. Ops in these FuncOps will not be analyzed as well.
616 OpFilter::Entry::FilterFn analysisFilterFn = [=](Operation *op) {
617 auto func = dyn_cast<func::FuncOp>(op);
618 if (!func)
619 func = op->getParentOfType<func::FuncOp>();
620 if (func)
621 return llvm::is_contained(options.noAnalysisFuncFilter,
622 func.getSymName());
623 return false;
624 };
625 OneShotBufferizationOptions updatedOptions(options);
626 updatedOptions.opFilter.denyOperation(analysisFilterFn);
627 if (failed(
628 insertTensorCopies(moduleOp, updatedOptions, state, statistics)))
629 return failure();
630 }
631 }
632 if (options.testAnalysisOnly)
633 return success();
634 if (failed(bufferizeModuleOp(moduleOp, options, state, statistics)))
635 return failure();
636 return success();
637}
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
DenseMap< func::FuncOp, DenseSet< Operation * > > FuncCallerMap
A mapping of FuncOps to their callers.
static SmallVector< Type > getReturnTypes(SmallVector< func::ReturnOp > returnOps)
Helper function that returns the return types (skipping casts) of the given func.return ops.
static FuncAnalysisState & getOrCreateFuncAnalysisState(OneShotAnalysisState &state)
Get or create FuncAnalysisState.
static bool hasTensorSignature(func::FuncOp funcOp)
Return "true" if the given function signature has tensor semantics.
static void removeBufferizationAttributes(BlockArgument bbArg)
Remove bufferization attributes on FuncOp arguments.
static void foldMemRefCasts(func::FuncOp funcOp)
Fold return values that are memref casts and update function return types.
static Value unpackCast(Value v)
Helper function that extracts the source from a memref.cast.
static LogicalResult getFuncOpsOrderedByCalls(Operation *moduleOp, SmallVectorImpl< func::FuncOp > &orderedFuncOps, SmallVectorImpl< func::FuncOp > &remainingFuncOps, FuncCallerMap &callerMap, SymbolTableCollection &symbolTables)
Store all functions of the moduleOp in orderedFuncOps, sorted by callee-caller order (i....
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:309
unsigned getArgNumber() const
Returns the number of this argument.
Definition Value.h:321
Block * getOwner() const
Returns the block that owns this argument.
Definition Value.h:318
Block represents an ordered list of Operations.
Definition Block.h:33
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class helps build Operations.
Definition Builders.h:207
This class represents an operand of an operation.
Definition Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition Operation.h:534
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition Operation.h:560
unsigned getNumOperands()
Definition Operation.h:346
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:238
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
Definition Operation.h:582
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition WalkResult.h:51
State for analysis-enabled bufferization.
bool isValueWritten(Value value) const
Return true if the buffer of the given tensor value is written to.
Ty & addExtension(Args &&...args)
Adds a new Extension of the type specified as template parameter, constructing it with the arguments ...
const OneShotBufferizationOptions & getOptions() const
Return a reference to the BufferizationOptions.
Ty * getExtension()
Returns the extension of the specified type.
bool areEquivalentBufferizedValues(Value v1, Value v2) const override
Return true if v1 and v2 bufferize to equivalent buffers.
bool areAliasingBufferizedValues(Value v1, Value v2) const override
Return true if v1 and v2 may bufferize to aliasing buffers.
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
static FuncOp getCalledFunction(CallOpInterface callOp, SymbolTableCollection &symbolTables)
Return the FuncOp called by callOp.
LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options, BufferizationState &bufferizationState, BufferizationStatistics *statistics=nullptr)
Bufferize op and its nested ops that implement BufferizableOpInterface.
llvm::LogicalResult bufferizeModuleOp(Operation *moduleOp, const OneShotBufferizationOptions &options, BufferizationState &state, BufferizationStatistics *statistics=nullptr)
Bufferize an ops nested ops that implement BufferizableOpInterface.
void removeBufferizationAttributesInModule(Operation *moduleOp)
Remove bufferization attributes on every FuncOp arguments in the SymbolTable op.
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
SmallVector< func::ReturnOp > getReturnOps(func::FuncOp funcOp)
Helper function that returns all func.return ops in the given function.
LogicalResult insertTensorCopies(Operation *op, const OneShotBufferizationOptions &options, const BufferizationState &bufferizationState, BufferizationStatistics *statistics=nullptr)
Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
llvm::LogicalResult analyzeModuleOp(Operation *moduleOp, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze moduleOp and its nested ops.
llvm::LogicalResult runOneShotModuleBufferize(Operation *moduleOp, const bufferization::OneShotBufferizationOptions &options, BufferizationState &state, BufferizationStatistics *statistics=nullptr)
Run One-Shot Module Bufferization on the given SymbolTable.
Include the generated interface declarations.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
Bufferization statistics for debugging.
Definition Bufferize.h:35
Options for analysis-enabled bufferization.
Extra analysis state that is required for bufferization of function boundaries.
DenseMap< FuncOp, IndexMapping > equivalentFuncArgs
A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg indices.
DenseMap< FuncOp, IndexToIndexListMapping > aliasingReturnVals
A mapping of FuncOp BBArg indices to aliasing ReturnOp OpOperand indices.
SymbolTableCollection symbolTables
A collection of cached SymbolTables used for faster function lookup.
DenseMap< FuncOp, BbArgIndexSet > readBbArgs
A set of all read BlockArguments of FuncOps.
DenseMap< FuncOp, BbArgIndexSet > writtenBbArgs
A set of all written-to BlockArguments of FuncOps.
DenseMap< FuncOp, FuncOpAnalysisState > analyzedFuncOps
Keep track of which FuncOps are fully analyzed or currently being analyzed.
void startFunctionAnalysis(FuncOp funcOp)
This function is called right before analyzing the given FuncOp.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.