MLIR 23.0.0git
OneShotModuleBufferize.cpp
Go to the documentation of this file.
//===- OneShotModuleBufferize.cpp - Bufferization across Func. Boundaries
//----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Module Bufferization is an extension of One-Shot Bufferize that
// bufferizes function boundaries. It provides `BufferizableOpInterface`
// implementations for FuncOp, CallOp and ReturnOp. Although it is named
// Module Bufferization, it may operate on any SymbolTable.
//
// Module Bufferization is run via `runOneShotModuleBufferize(SymbolTableOp,
// ...)`. This function analyzes the given op and determines the order of
// analysis and bufferization: Functions that are called are processed before
// their respective callers.
//
// After analyzing a FuncOp, additional information about its bbArgs is
// gathered and stored in `FuncAnalysisState`.
//
// * `aliasingFuncOpBBArgsAnalysis` determines the equivalent/aliasing bbArgs
//   for each tensor return value (if any).
// * `funcOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is
//   read/written.
//
// Module Bufferization implements the following calling convention.
//
// * In the absence of conflicts within a FuncOp, the FuncOp's bbArgs may
//   always be written to in-place.
// * If a tensor operand of a CallOp is read after the CallOp, the operand of
//   the CallOp must bufferize out-of-place.
//
// Example: The tensor.insert op bufferizes in-place because it is allowed to
// modify the buffer of `%t1` directly. The CallOp in `caller` must bufferize
// out-of-place because `%t0` is modified by the callee but read by the
// tensor.extract op. The analysis of CallOps decides whether an OpOperand must
// bufferize out-of-place based on results of `funcOpBbArgReadWriteAnalysis`.
// ```
// func @callee(%t1 : tensor<?xf32>) -> tensor<?xf32> {
//   %f = ... : f32
//   %0 = tensor.insert %f into %t1[...] : tensor<?xf32>
//   return %0 : tensor<?xf32>
// }
//
// func @caller() -> () {
//   %t0 = ... : tensor<?xf32>
//   %1 = call @callee(%t0) : (tensor<?xf32>) -> (tensor<?xf32>)
//   %2 = tensor.extract %1[...] : tensor<?xf32>
// }
// ```
//
// Note: If a function is external, `funcOpBbArgReadWriteAnalysis` cannot
// analyze the function body. In such a case, the CallOp analysis
// conservatively assumes that each tensor OpOperand is both read and written.
//
// TODO: Add FuncOp attributes so that bbArgs of external FuncOps can be marked
// as "not reading" and/or "not writing".

#include "mlir/IR/Operation.h"

#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
76
77using namespace mlir;
78using namespace mlir::bufferization;
79using namespace mlir::bufferization::func_ext;
80
81/// A mapping of FuncOps to their callers.
83
84/// Get or create FuncAnalysisState.
85static FuncAnalysisState &
92
93namespace {
94
95/// Annotate IR with the results of the analysis. For testing purposes only.
96static void annotateEquivalentReturnBbArg(OpOperand &returnVal,
97 BlockArgument bbArg) {
98 const char *kEquivalentArgsAttr = "__equivalent_func_args__";
99 Operation *op = returnVal.getOwner();
100
101 SmallVector<int64_t> equivBbArgs;
102 if (op->hasAttr(kEquivalentArgsAttr)) {
103 auto attr = cast<ArrayAttr>(op->getAttr(kEquivalentArgsAttr));
104 equivBbArgs = llvm::map_to_vector<4>(attr, [](Attribute a) {
105 return cast<IntegerAttr>(a).getValue().getSExtValue();
106 });
107 } else {
108 equivBbArgs.append(op->getNumOperands(), -1);
109 }
110 equivBbArgs[returnVal.getOperandNumber()] = bbArg.getArgNumber();
111
112 OpBuilder b(op->getContext());
113 op->setAttr(kEquivalentArgsAttr, b.getI64ArrayAttr(equivBbArgs));
114}
115
116/// Store function BlockArguments that are equivalent to/aliasing a returned
117/// value in FuncAnalysisState.
118static LogicalResult
119aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
120 FuncAnalysisState &funcState) {
121 if (funcOp.getBody().empty()) {
122 // No function body available. Conservatively assume that every tensor
123 // return value may alias with any tensor bbArg.
124 FunctionType type = funcOp.getFunctionType();
125 for (const auto &inputIt : llvm::enumerate(type.getInputs())) {
126 if (!isa<TensorLikeType>(inputIt.value()))
127 continue;
128 for (const auto &resultIt : llvm::enumerate(type.getResults())) {
129 if (!isa<TensorLikeType>(resultIt.value()))
130 continue;
131 int64_t returnIdx = resultIt.index();
132 int64_t bbArgIdx = inputIt.index();
133 funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx);
134 }
135 }
136 return success();
137 }
138
139 // Find all func.return ops.
140 SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
141 // TODO: throw error when there is any non-func.return op that has the
142 // ReturnLike trait
143 if (returnOps.empty()) {
144 return funcOp.emitError("cannot bufferize func.func without func.return");
145 }
146
147 // Build alias sets. Merge all aliases from all func.return ops.
148 for (BlockArgument bbArg : funcOp.getArguments()) {
149 if (isa<TensorLikeType>(bbArg.getType())) {
150 int64_t bbArgIdx = bbArg.getArgNumber();
151 // Store aliases in a set, so that we don't add the same alias twice.
152 SetVector<int64_t> aliases;
153 for (func::ReturnOp returnOp : returnOps) {
154 for (OpOperand &returnVal : returnOp->getOpOperands()) {
155 if (isa<TensorLikeType>(returnVal.get().getType())) {
156 int64_t returnIdx = returnVal.getOperandNumber();
157 if (state.areAliasingBufferizedValues(returnVal.get(), bbArg))
158 aliases.insert(returnIdx);
159 }
160 }
161 }
162 for (int64_t alias : aliases)
163 funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(alias);
164 }
165 }
166
167 // Build equivalence sets.
168 // Helper function that finds an equivalent block argument index for the
169 // given OpOperand. Return std::nullopt if no equivalent block argument could
170 // be found.
171 auto findEquivalentBlockArgIdx =
172 [&](OpOperand &opOperand) -> std::optional<int64_t> {
173 Value v = opOperand.get();
174 if (!isa<TensorLikeType>(v.getType()))
175 return std::nullopt;
176 for (BlockArgument bbArg : funcOp.getArguments()) {
177 if (isa<TensorLikeType>(bbArg.getType())) {
178 if (state.areEquivalentBufferizedValues(v, bbArg)) {
179 if (state.getOptions().testAnalysisOnly)
180 annotateEquivalentReturnBbArg(opOperand, bbArg);
181 return bbArg.getArgNumber();
182 }
183 }
184 }
185 return std::nullopt;
186 };
187
188 int64_t numResults = returnOps.front()->getNumOperands();
189 for (int64_t i = 0; i < numResults; ++i) {
190 // Find the equivalent block argument index for the i-th operand of the
191 // first func.return op.
192 std::optional<int64_t> maybeEquiv =
193 findEquivalentBlockArgIdx(returnOps.front()->getOpOperand(i));
194 if (!maybeEquiv.has_value())
195 continue;
196 int64_t bbArgIdx = *maybeEquiv;
197 bool allEquiv = true;
198
199 // Check if all other func.return ops have the same equivalent block
200 // argument for the i-th operand. In contrast to aliasing information,
201 // which is just "merged", equivalence information must match across all
202 // func.return ops.
203 for (func::ReturnOp returnOp : ArrayRef(returnOps).drop_front()) {
204 std::optional<int64_t> maybeEquiv =
205 findEquivalentBlockArgIdx(returnOp->getOpOperand(i));
206 if (maybeEquiv != bbArgIdx) {
207 allEquiv = false;
208 break;
209 }
210 }
211
212 // All func.return ops have the same equivalent block argument for the i-th
213 // operand.
214 if (allEquiv)
215 funcState.equivalentFuncArgs[funcOp][i] = bbArgIdx;
216 }
217
218 return success();
219}
220
221static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx, bool isRead,
222 bool isWritten) {
223 OpBuilder b(funcOp.getContext());
224 Attribute accessType;
225 if (isRead && isWritten) {
226 accessType = b.getStringAttr("read-write");
227 } else if (isRead) {
228 accessType = b.getStringAttr("read");
229 } else if (isWritten) {
230 accessType = b.getStringAttr("write");
231 } else {
232 accessType = b.getStringAttr("none");
233 }
234 funcOp.setArgAttr(idx, BufferizationDialect::kBufferAccessAttrName,
235 accessType);
236}
237
238/// Determine which FuncOp bbArgs are read and which are written. When run on a
239/// function with unknown ops, we conservatively assume that such ops bufferize
240/// to a read + write.
241static LogicalResult
242funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
243 FuncAnalysisState &funcState) {
244 for (int64_t idx = 0, e = funcOp.getFunctionType().getNumInputs(); idx < e;
245 ++idx) {
246 // Skip non-tensor arguments.
247 if (!isa<TensorLikeType>(funcOp.getFunctionType().getInput(idx)))
248 continue;
249 bool isRead;
250 bool isWritten;
251 if (auto accessAttr = funcOp.getArgAttrOfType<StringAttr>(
252 idx, BufferizationDialect::kBufferAccessAttrName)) {
253 // Buffer access behavior is specified on the function. Skip the analysis.
254 StringRef str = accessAttr.getValue();
255 isRead = str == "read" || str == "read-write";
256 isWritten = str == "write" || str == "read-write";
257 } else if (funcOp.getBody().empty()) {
258 // If the function has no body, conservatively assume that all args are
259 // read + written.
260 isRead = true;
261 isWritten = true;
262 } else {
263 // Analyze the body of the function.
264 BlockArgument bbArg = funcOp.getArgument(idx);
265 isRead = state.isValueRead(bbArg);
266 isWritten = state.isValueWritten(bbArg);
267 }
268
269 if (state.getOptions().testAnalysisOnly)
270 annotateFuncArgAccess(funcOp, idx, isRead, isWritten);
271 if (isRead)
272 funcState.readBbArgs[funcOp].insert(idx);
273 if (isWritten)
274 funcState.writtenBbArgs[funcOp].insert(idx);
275 }
276
277 return success();
278}
279} // namespace
280
281/// Remove bufferization attributes on FuncOp arguments.
283 auto funcOp = cast<func::FuncOp>(bbArg.getOwner()->getParentOp());
284 funcOp.removeArgAttr(bbArg.getArgNumber(),
285 BufferizationDialect::kBufferLayoutAttrName);
286 funcOp.removeArgAttr(bbArg.getArgNumber(),
287 BufferizationDialect::kWritableAttrName);
288}
289
290/// Return the func::FuncOp called by `callOp`.
291static func::FuncOp
292getCalledFunction(func::CallOp callOp,
293 mlir::SymbolTableCollection &symbolTable) {
294 return dyn_cast_or_null<func::FuncOp>(
295 callOp.resolveCallableInTable(&symbolTable));
296}
297
298/// Return "true" if the given function signature has tensor semantics.
299static bool hasTensorSignature(func::FuncOp funcOp) {
300 return llvm::any_of(funcOp.getFunctionType().getInputs(),
301 llvm::IsaPred<TensorLikeType>) ||
302 llvm::any_of(funcOp.getFunctionType().getResults(),
303 llvm::IsaPred<TensorLikeType>);
304}
305
306/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
307/// callee-caller order (i.e., callees without callers first). Store all
308/// remaining functions (i.e., the ones that call each other recursively) in
309/// `remainingFuncOps`. Does not traverse nested symbol tables.
310///
311/// Store the map of FuncOp to all its callers in `callerMap`.
312///
313/// Return `failure()` if we are unable to retrieve the called FuncOp from
314/// any func::CallOp.
315static LogicalResult getFuncOpsOrderedByCalls(
316 Operation *moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
317 SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap,
318 SymbolTableCollection &symbolTables) {
319 // For each FuncOp, the set of functions called by it (i.e. the union of
320 // symbols of all nested func::CallOp).
322 // For each FuncOp, the number of func::CallOp it contains.
323 llvm::MapVector<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
324 for (mlir::Region &region : moduleOp->getRegions()) {
325 for (mlir::Block &block : region.getBlocks()) {
326 for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) {
327 // Collect function calls and populate the caller map.
328 numberCallOpsContainedInFuncOp[funcOp] = 0;
329 WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
330 func::FuncOp calledFunction = getCalledFunction(callOp, symbolTables);
331 assert(calledFunction && "could not retrieved called func::FuncOp");
332 // If the called function does not have any tensors in its signature,
333 // then it is not necessary to bufferize the callee before the caller.
334 if (!hasTensorSignature(calledFunction))
335 return WalkResult::skip();
336
337 callerMap[calledFunction].insert(callOp);
338 if (calledBy[calledFunction].insert(funcOp)) {
339 numberCallOpsContainedInFuncOp[funcOp]++;
340 }
341 return WalkResult::advance();
342 });
343 if (res.wasInterrupted())
344 return failure();
345 }
346 }
347 }
348
349 // Iteratively remove function operations that do not call any of the
350 // functions remaining in the callCounter map and add them to ordered list.
352
353 for (const auto &entry : numberCallOpsContainedInFuncOp) {
354 if (entry.second == 0)
355 worklist.push_back(entry.first);
356 }
357
358 while (!worklist.empty()) {
359 func::FuncOp func = worklist.pop_back_val();
360 orderedFuncOps.push_back(func);
361
362 for (func::FuncOp caller : calledBy[func]) {
363 auto &count = numberCallOpsContainedInFuncOp[caller];
364
365 if (--count == 0)
366 worklist.push_back(caller);
367 }
368
369 numberCallOpsContainedInFuncOp.erase(func);
370 }
371
372 // Put all other functions in the list of remaining functions. These are
373 // functions that call each other circularly.
374 for (auto it : numberCallOpsContainedInFuncOp)
375 remainingFuncOps.push_back(it.first);
376
377 return success();
378}
379
380/// Helper function that extracts the source from a memref.cast. If the given
381/// value is not a memref.cast result, simply returns the given value.
382/// Only unpacks casts where the source is at least as specific as the result
383/// (i.e., does not unpack casts from unranked to ranked memref, which would
384/// downgrade the type).
386 auto castOp = v.getDefiningOp<memref::CastOp>();
387 if (!castOp)
388 return v;
389 // Do not unpack a cast from unranked to ranked memref: folding would
390 // downgrade the function return type from ranked to unranked.
391 if (isa<UnrankedMemRefType>(castOp.getSource().getType()) &&
392 isa<MemRefType>(v.getType()))
393 return v;
394 return castOp.getSource();
395}
396
397/// Helper function that returns the return types (skipping casts) of the given
398/// func.return ops. This function returns as many types as the return ops have
399/// operands. If the i-th operand is not the same for all func.return ops, then
400/// the i-th returned type is an "empty" type.
402 assert(!returnOps.empty() && "expected at least one ReturnOp");
403 int numOperands = returnOps.front()->getNumOperands();
404
405 // Helper function that unpacks memref.cast ops and returns the type.
406 auto getSourceType = [&](Value v) { return unpackCast(v).getType(); };
407
409 for (int i = 0; i < numOperands; ++i) {
410 // Get the type of the i-th operand of the first func.return ops.
411 Type t = getSourceType(returnOps.front()->getOperand(i));
412
413 // Check if all other func.return ops have a matching operand type.
414 for (int j = 1; j < static_cast<int>(returnOps.size()); ++j)
415 if (getSourceType(returnOps[j]->getOperand(i)) != t)
416 t = Type();
417
418 result.push_back(t);
419 }
420
421 return result;
422}
423
424/// Fold return values that are memref casts and update function return types.
425///
426/// During FuncOp bufferization, the exact type of the returned memrefs (if any)
427/// is not known yet. Therefore, the bufferization uses memref types with the
428/// most generic layout map as function return types. After bufferizing the
429/// entire function body, a more concise memref type can potentially be used for
430/// the return type of the function.
431static void foldMemRefCasts(func::FuncOp funcOp) {
432 // There is nothing to do for bodiless ops.
433 if (funcOp.getBody().empty())
434 return;
435
436 // Compute the common result types of all return ops.
437 SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
438 SmallVector<Type> resultTypes = getReturnTypes(returnOps);
439
440 // Remove direct casts.
441 for (func::ReturnOp returnOp : returnOps) {
442 for (OpOperand &operand : returnOp->getOpOperands()) {
443 // Bail if no common result type was found.
444 if (resultTypes[operand.getOperandNumber()]) {
445 operand.set(unpackCast(operand.get()));
446 }
447 }
448 }
449
450 // Fill in the missing result types that were not the same among all
451 // func.return ops.
452 for (int i = 0; i < static_cast<int>(resultTypes.size()); ++i) {
453 if (resultTypes[i])
454 continue;
455 resultTypes[i] = funcOp.getFunctionType().getResult(i);
456 }
457
458 // Update the function type.
459 auto newFuncType = FunctionType::get(
460 funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
461 funcOp.setType(newFuncType);
462}
463
464LogicalResult
467 BufferizationStatistics *statistics) {
468 assert(state.getOptions().bufferizeFunctionBoundaries &&
469 "expected that function boundary bufferization is activated");
471
472 // A list of non-circular functions in the order in which they are analyzed
473 // and bufferized.
474 SmallVector<func::FuncOp> orderedFuncOps;
475 // A list of all other functions. I.e., functions that call each other
476 // recursively. For these, we analyze the function body but not the function
477 // boundary.
478 SmallVector<func::FuncOp> remainingFuncOps;
479
480 // A mapping of FuncOps to their callers.
481 FuncCallerMap callerMap;
482
483 if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
484 remainingFuncOps, callerMap,
485 funcState.symbolTables)))
486 return failure();
487
488 // Analyze functions in order. Starting with functions that are not calling
489 // any other functions.
490 for (func::FuncOp funcOp : orderedFuncOps) {
491 if (!state.getOptions().isOpAllowed(funcOp))
492 continue;
493
494 // Now analyzing function.
495 funcState.startFunctionAnalysis(funcOp);
496
497 // Analyze funcOp.
498 if (failed(analyzeOp(funcOp, state, statistics)))
499 return failure();
500
501 // Run some extra function analyses.
502 if (failed(aliasingFuncOpBBArgsAnalysis(funcOp, state, funcState)) ||
503 failed(funcOpBbArgReadWriteAnalysis(funcOp, state, funcState)))
504 return failure();
505
506 // Mark op as fully analyzed.
508 }
509
510 // Analyze all other functions. All function boundary analyses are skipped.
511 for (func::FuncOp funcOp : remainingFuncOps) {
512 if (!state.getOptions().isOpAllowed(funcOp))
513 continue;
514
515 // Analyze funcOp.
516 if (failed(analyzeOp(funcOp, state, statistics)))
517 return failure();
518
519 // TODO: We currently skip all function argument analyses for functions
520 // that call each other circularly. These analyses do not support recursive
521 // calls yet. The `BufferizableOpInterface` implementations of `func`
522 // dialect ops return conservative results in the absence of analysis
523 // information.
524 }
525
526 return success();
527}
528
530 Operation *moduleOp) {
531 for (mlir::Region &region : moduleOp->getRegions()) {
532 for (mlir::Block &block : region.getBlocks()) {
533 for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) {
534 for (BlockArgument bbArg : funcOp.getArguments())
536 }
537 }
538 }
539}
540
543 BufferizationState &state, BufferizationStatistics *statistics) {
544 assert(options.bufferizeFunctionBoundaries &&
545 "expected that function boundary bufferization is activated");
546 IRRewriter rewriter(moduleOp->getContext());
547
548 // A list of non-circular functions in the order in which they are analyzed
549 // and bufferized.
550 SmallVector<func::FuncOp> orderedFuncOps;
551 // A list of all other functions. I.e., functions that call each other
552 // recursively. For these, we analyze the function body but not the function
553 // boundary.
554 SmallVector<func::FuncOp> remainingFuncOps;
555
556 // A mapping of FuncOps to their callers.
557 FuncCallerMap callerMap;
558
559 // Try to bufferize functions in calling order. I.e., first bufferize
560 // functions that do not call other functions. This allows us to infer
561 // accurate buffer types for function return values. Functions that call
562 // each other recursively are bufferized in an unspecified order at the end.
563 // We may use unnecessarily "complex" (in terms of layout map) buffer types.
564 if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
565 remainingFuncOps, callerMap,
566 state.getSymbolTables())))
567 return failure();
568 llvm::append_range(orderedFuncOps, remainingFuncOps);
569
570 // Bufferize functions.
571 for (func::FuncOp funcOp : orderedFuncOps) {
572 // Note: It would be good to apply cleanups here but we cannot as aliasInfo
573 // would be invalidated.
574
575 if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getSymName())) {
576 // This function was not analyzed and RaW conflicts were not resolved.
577 // Buffer copies must be inserted before every write.
578 OneShotBufferizationOptions updatedOptions = options;
579 updatedOptions.copyBeforeWrite = true;
580 if (failed(bufferizeOp(funcOp, updatedOptions, state, statistics)))
581 return failure();
582 } else {
583 if (failed(bufferizeOp(funcOp, options, state, statistics)))
584 return failure();
585 }
586
587 // Change buffer return types to more precise layout maps.
588 if (options.inferFunctionResultLayout)
589 foldMemRefCasts(funcOp);
590 }
591
592 // Bufferize all other ops.
593 for (mlir::Region &region : moduleOp->getRegions()) {
594 for (mlir::Block &block : region.getBlocks()) {
595 for (mlir::Operation &op :
596 llvm::make_early_inc_range(block.getOperations())) {
597 // Functions were already bufferized.
598 if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>())
599 continue;
600 if (failed(bufferizeOp(&op, options, state, statistics)))
601 return failure();
602 }
603 }
604 }
605
606 // Post-pass cleanup of function argument attributes.
608
609 return success();
610}
611
614 BufferizationState &state, BufferizationStatistics *statistics) {
615 assert(options.bufferizeFunctionBoundaries &&
616 "expected that function boundary bufferization is activated");
617 assert(!(options.copyBeforeWrite && options.testAnalysisOnly) &&
618 "invalid combination of bufferization flags");
619 if (!options.copyBeforeWrite) {
620 if (options.noAnalysisFuncFilter.empty()) {
621 if (failed(insertTensorCopies(moduleOp, options, state, statistics)))
622 return failure();
623 } else {
624 // FuncOps whose names are specified in options.noAnalysisFuncFilter will
625 // not be analyzed. Ops in these FuncOps will not be analyzed as well.
626 OpFilter::Entry::FilterFn analysisFilterFn = [=](Operation *op) {
627 auto func = dyn_cast<func::FuncOp>(op);
628 if (!func)
629 func = op->getParentOfType<func::FuncOp>();
630 if (func)
631 return llvm::is_contained(options.noAnalysisFuncFilter,
632 func.getSymName());
633 return false;
634 };
635 OneShotBufferizationOptions updatedOptions(options);
636 updatedOptions.opFilter.denyOperation(analysisFilterFn);
637 if (failed(
638 insertTensorCopies(moduleOp, updatedOptions, state, statistics)))
639 return failure();
640 }
641 }
642 if (options.testAnalysisOnly)
643 return success();
644 if (failed(bufferizeModuleOp(moduleOp, options, state, statistics)))
645 return failure();
646 return success();
647}
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
DenseMap< func::FuncOp, DenseSet< Operation * > > FuncCallerMap
A mapping of FuncOps to their callers.
static SmallVector< Type > getReturnTypes(SmallVector< func::ReturnOp > returnOps)
Helper function that returns the return types (skipping casts) of the given func.return ops.
static FuncAnalysisState & getOrCreateFuncAnalysisState(OneShotAnalysisState &state)
Get or create FuncAnalysisState.
static bool hasTensorSignature(func::FuncOp funcOp)
Return "true" if the given function signature has tensor semantics.
static void removeBufferizationAttributes(BlockArgument bbArg)
Remove bufferization attributes on FuncOp arguments.
static void foldMemRefCasts(func::FuncOp funcOp)
Fold return values that are memref casts and update function return types.
static Value unpackCast(Value v)
Helper function that extracts the source from a memref.cast.
static LogicalResult getFuncOpsOrderedByCalls(Operation *moduleOp, SmallVectorImpl< func::FuncOp > &orderedFuncOps, SmallVectorImpl< func::FuncOp > &remainingFuncOps, FuncCallerMap &callerMap, SymbolTableCollection &symbolTables)
Store all functions of the moduleOp in orderedFuncOps, sorted by callee-caller order (i....
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:306
unsigned getArgNumber() const
Returns the number of this argument.
Definition Value.h:318
Block * getOwner() const
Returns the block that owns this argument.
Definition Value.h:315
Block represents an ordered list of Operations.
Definition Block.h:33
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class helps build Operations.
Definition Builders.h:209
This class represents an operand of an operation.
Definition Value.h:254
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:775
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition Operation.h:560
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition Operation.h:586
unsigned getNumOperands()
Definition Operation.h:372
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:256
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
Definition Operation.h:608
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:703
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:234
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition WalkResult.h:51
State for analysis-enabled bufferization.
bool isValueWritten(Value value) const
Return true if the buffer of the given tensor value is written to.
Ty & addExtension(Args &&...args)
Adds a new Extension of the type specified as template parameter, constructing it with the arguments ...
const OneShotBufferizationOptions & getOptions() const
Return a reference to the BufferizationOptions.
Ty * getExtension()
Returns the extension of the specified type.
bool areEquivalentBufferizedValues(Value v1, Value v2) const override
Return true if v1 and v2 bufferize to equivalent buffers.
bool areAliasingBufferizedValues(Value v1, Value v2) const override
Return true if v1 and v2 may bufferize to aliasing buffers.
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
static FuncOp getCalledFunction(CallOpInterface callOp, SymbolTableCollection &symbolTables)
Return the FuncOp called by callOp.
LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options, BufferizationState &bufferizationState, BufferizationStatistics *statistics=nullptr)
Bufferize op and its nested ops that implement BufferizableOpInterface.
llvm::LogicalResult bufferizeModuleOp(Operation *moduleOp, const OneShotBufferizationOptions &options, BufferizationState &state, BufferizationStatistics *statistics=nullptr)
Bufferize an ops nested ops that implement BufferizableOpInterface.
void removeBufferizationAttributesInModule(Operation *moduleOp)
Remove bufferization attributes on every FuncOp arguments in the SymbolTable op.
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
SmallVector< func::ReturnOp > getReturnOps(func::FuncOp funcOp)
Helper function that returns all func.return ops in the given function.
LogicalResult insertTensorCopies(Operation *op, const OneShotBufferizationOptions &options, const BufferizationState &bufferizationState, BufferizationStatistics *statistics=nullptr)
Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
llvm::LogicalResult analyzeModuleOp(Operation *moduleOp, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze moduleOp and its nested ops.
llvm::LogicalResult runOneShotModuleBufferize(Operation *moduleOp, const bufferization::OneShotBufferizationOptions &options, BufferizationState &state, BufferizationStatistics *statistics=nullptr)
Run One-Shot Module Bufferization on the given SymbolTable.
Include the generated interface declarations.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:125
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
Bufferization statistics for debugging.
Definition Bufferize.h:35
Options for analysis-enabled bufferization.
Extra analysis state that is required for bufferization of function boundaries.
DenseMap< FuncOp, IndexMapping > equivalentFuncArgs
A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg indices.
DenseMap< FuncOp, IndexToIndexListMapping > aliasingReturnVals
A mapping of FuncOp BBArg indices to aliasing ReturnOp OpOperand indices.
SymbolTableCollection symbolTables
A collection of cached SymbolTables used for faster function lookup.
DenseMap< FuncOp, BbArgIndexSet > readBbArgs
A set of all read BlockArguments of FuncOps.
DenseMap< FuncOp, BbArgIndexSet > writtenBbArgs
A set of all written-to BlockArguments of FuncOps.
DenseMap< FuncOp, FuncOpAnalysisState > analyzedFuncOps
Keep track of which FuncOps are fully analyzed or currently being analyzed.
void startFunctionAnalysis(FuncOp funcOp)
This function is called right before analyzing the given FuncOp.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.