MLIR  21.0.0git
OneShotModuleBufferize.cpp
Go to the documentation of this file.
1 //===- OneShotModuleBufferize.cpp - Bufferization across Func. Boundaries -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Module Bufferization is an extension of One-Shot Bufferize that
10 // bufferizes function boundaries. It provides `BufferizableOpInterface`
11 // implementations for FuncOp, CallOp and ReturnOp.
12 //
13 // Module Bufferization is run via `runOneShotModuleBufferize(ModuleOp, ...)`.
14 // This function analyzes the given module and determines the order of analysis
15 // and bufferization: Functions that are called are processed before their
16 // respective callers.
17 //
18 // After analyzing a FuncOp, additional information about its bbArgs is
19 // gathered and stored in `FuncAnalysisState`.
20 //
21 // * `aliasingFuncOpBBArgsAnalysis` determines the equivalent/aliasing bbArgs
22 //   for each tensor return value (if any).
24 // * `funcOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is
25 // read/written.
26 //
27 // Module Bufferization implements the following calling convention.
28 //
29 // * In the absence of conflicts within a FuncOp, the FuncOp's bbArgs may always
30 // be written to in-place.
31 // * If a tensor operand of a CallOp is read after the CallOp, the operand of
32 // the CallOp must bufferize out-of-place.
33 //
34 // Example: The tensor.insert op bufferizes in-place because it is allowed to
35 // modify the buffer of `%t1` directly. The CallOp in `caller` must bufferize
36 // out-of-place because `%t0` is modified by the callee but read by the
37 // tensor.extract op. The analysis of CallOps decides whether an OpOperand must
38 // bufferize out-of-place based on results of `funcOpBbArgReadWriteAnalysis`.
39 // ```
40 // func @callee(%t1 : tensor<?xf32>) -> tensor<?xf32> {
41 // %f = ... : f32
42 // %0 = tensor.insert %f into %t1[...] : tensor<?xf32>
43 // return %0 : tensor<?xf32>
44 // }
45 //
46 // func @caller() -> () {
47 // %t0 = ... : tensor<?xf32>
48 // %1 = call @callee(%t0) : (tensor<?xf32>) -> (tensor<?xf32>)
49 // %2 = tensor.extract %t0[...] : tensor<?xf32>
50 // }
51 // ```
52 //
53 // Note: If a function is external, `funcOpBbArgReadWriteAnalysis` cannot
54 // analyze the function body. In such a case, each tensor bbArg is
55 // conservatively assumed to be both read and written, unless its access
56 // behavior is specified explicitly via the
57 // `BufferizationDialect::kBufferAccessAttrName` argument attribute ("none",
58 // "read", "write" or "read-write"), which takes precedence over the analysis.
59 
61 
70 #include "mlir/IR/BuiltinTypes.h"
71 #include "mlir/IR/Operation.h"
72 
73 using namespace mlir;
74 using namespace mlir::bufferization;
75 using namespace mlir::bufferization::func_ext;
76 
77 /// A mapping of FuncOps to their callers.
79 
80 /// Get or create FuncAnalysisState.
81 static FuncAnalysisState &
83  auto *result = state.getExtension<FuncAnalysisState>();
84  if (result)
85  return *result;
86  return state.addExtension<FuncAnalysisState>();
87 }
88 
89 namespace {
90 
91 /// Annotate IR with the results of the analysis. For testing purposes only.
92 static void annotateEquivalentReturnBbArg(OpOperand &returnVal,
93  BlockArgument bbArg) {
94  const char *kEquivalentArgsAttr = "__equivalent_func_args__";
95  Operation *op = returnVal.getOwner();
96 
97  SmallVector<int64_t> equivBbArgs;
98  if (op->hasAttr(kEquivalentArgsAttr)) {
99  auto attr = cast<ArrayAttr>(op->getAttr(kEquivalentArgsAttr));
100  equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) {
101  return cast<IntegerAttr>(a).getValue().getSExtValue();
102  }));
103  } else {
104  equivBbArgs.append(op->getNumOperands(), -1);
105  }
106  equivBbArgs[returnVal.getOperandNumber()] = bbArg.getArgNumber();
107 
108  OpBuilder b(op->getContext());
109  op->setAttr(kEquivalentArgsAttr, b.getI64ArrayAttr(equivBbArgs));
110 }
111 
112 /// Store function BlockArguments that are equivalent to/aliasing a returned
113 /// value in FuncAnalysisState.
///
/// * Aliasing: for every ranked tensor bbArg, the indices of the func.return
///   operands that may alias it are merged over all func.return ops (without
///   duplicates) into `funcState.aliasingReturnVals[funcOp][bbArgIdx]`.
/// * Equivalence: for every return operand index i, the index of an
///   equivalent bbArg is stored in `funcState.equivalentFuncArgs[funcOp][i]`,
///   but only if all func.return ops agree on the same bbArg.
/// * Bodiless functions: every tensor result is conservatively recorded as
///   aliasing every tensor bbArg; no equivalences are recorded.
///
/// Always returns success().
///
/// NOTE(review): the bodiless path uses `TensorType`, whereas the path with a
/// body only considers `RankedTensorType` bbArgs (and, for aliasing, ranked
/// return values), so unranked tensors get no entries there -- confirm this
/// is intended.
114 static LogicalResult
115 aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
116  FuncAnalysisState &funcState) {
117  if (funcOp.getBody().empty()) {
118  // No function body available. Conservatively assume that every tensor
119  // return value may alias with any tensor bbArg.
120  FunctionType type = funcOp.getFunctionType();
121  for (const auto &inputIt : llvm::enumerate(type.getInputs())) {
122  if (!isa<TensorType>(inputIt.value()))
123  continue;
124  for (const auto &resultIt : llvm::enumerate(type.getResults())) {
125  if (!isa<TensorType>(resultIt.value()))
126  continue;
127  int64_t returnIdx = resultIt.index();
128  int64_t bbArgIdx = inputIt.index();
129  funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx);
130  }
131  }
132  return success();
133  }
134 
135  // Find all func.return ops.
136  SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
137  assert(!returnOps.empty() && "expected at least one ReturnOp");
138 
139  // Build alias sets. Merge all aliases from all func.return ops.
140  for (BlockArgument bbArg : funcOp.getArguments()) {
141  if (isa<RankedTensorType>(bbArg.getType())) {
142  int64_t bbArgIdx = bbArg.getArgNumber();
143  // Store aliases in a set, so that we don't add the same alias twice.
144  SetVector<int64_t> aliases;
145  for (func::ReturnOp returnOp : returnOps) {
146  for (OpOperand &returnVal : returnOp->getOpOperands()) {
147  if (isa<RankedTensorType>(returnVal.get().getType())) {
148  int64_t returnIdx = returnVal.getOperandNumber();
149  if (state.areAliasingBufferizedValues(returnVal.get(), bbArg))
150  aliases.insert(returnIdx);
151  }
152  }
153  }
154  for (int64_t alias : aliases)
155  funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(alias);
156  }
157  }
158 
159  // Build equivalence sets.
160  // Helper function that finds an equivalent block argument index for the
161  // given OpOperand. Return std::nullopt if no equivalent block argument could
162  // be found.
  // The first (lowest-numbered) equivalent ranked tensor bbArg is returned.
  // With `testAnalysisOnly`, a match is annotated on the return op as soon as
  // it is found, even if the cross-check over all func.return ops below later
  // rejects the equivalence.
163  auto findEquivalentBlockArgIdx =
164  [&](OpOperand &opOperand) -> std::optional<int64_t> {
165  Value v = opOperand.get();
166  if (!isa<TensorType>(v.getType()))
167  return std::nullopt;
168  for (BlockArgument bbArg : funcOp.getArguments()) {
169  if (isa<RankedTensorType>(bbArg.getType())) {
170  if (state.areEquivalentBufferizedValues(v, bbArg)) {
171  if (state.getOptions().testAnalysisOnly)
172  annotateEquivalentReturnBbArg(opOperand, bbArg);
173  return bbArg.getArgNumber();
174  }
175  }
176  }
177  return std::nullopt;
178  };
179 
  // All func.return ops of a function have the same number of operands, so
  // the first one determines the number of result positions to check.
180  int64_t numResults = returnOps.front()->getNumOperands();
181  for (int64_t i = 0; i < numResults; ++i) {
182  // Find the equivalent block argument index for the i-th operand of the
183  // first func.return op.
184  std::optional<int64_t> maybeEquiv =
185  findEquivalentBlockArgIdx(returnOps.front()->getOpOperand(i));
186  if (!maybeEquiv.has_value())
187  continue;
188  int64_t bbArgIdx = *maybeEquiv;
189  bool allEquiv = true;
190 
191  // Check if all other func.return ops have the same equivalent block
192  // argument for the i-th operand. In contrast to aliasing information,
193  // which is just "merged", equivalence information must match across all
194  // func.return ops.
195  for (func::ReturnOp returnOp : ArrayRef(returnOps).drop_front()) {
196  std::optional<int64_t> maybeEquiv =
197  findEquivalentBlockArgIdx(returnOp->getOpOperand(i));
198  if (maybeEquiv != bbArgIdx) {
199  allEquiv = false;
200  break;
201  }
202  }
203 
204  // All func.return ops have the same equivalent block argument for the i-th
205  // operand.
206  if (allEquiv)
207  funcState.equivalentFuncArgs[funcOp][i] = bbArgIdx;
208  }
209 
210  return success();
211 }
212 
213 static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx, bool isRead,
214  bool isWritten) {
215  OpBuilder b(funcOp.getContext());
216  Attribute accessType;
217  if (isRead && isWritten) {
218  accessType = b.getStringAttr("read-write");
219  } else if (isRead) {
220  accessType = b.getStringAttr("read");
221  } else if (isWritten) {
222  accessType = b.getStringAttr("write");
223  } else {
224  accessType = b.getStringAttr("none");
225  }
226  funcOp.setArgAttr(idx, BufferizationDialect::kBufferAccessAttrName,
227  accessType);
228 }
229 
230 /// Determine which FuncOp bbArgs are read and which are written. When run on a
231 /// function with unknown ops, we conservatively assume that such ops bufferize
232 /// to a read + write.
233 static LogicalResult
234 funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
235  FuncAnalysisState &funcState) {
236  for (int64_t idx = 0, e = funcOp.getFunctionType().getNumInputs(); idx < e;
237  ++idx) {
238  // Skip non-tensor arguments.
239  if (!isa<TensorType>(funcOp.getFunctionType().getInput(idx)))
240  continue;
241  bool isRead;
242  bool isWritten;
243  if (auto accessAttr = funcOp.getArgAttrOfType<StringAttr>(
244  idx, BufferizationDialect::kBufferAccessAttrName)) {
245  // Buffer access behavior is specified on the function. Skip the analysis.
246  StringRef str = accessAttr.getValue();
247  isRead = str == "read" || str == "read-write";
248  isWritten = str == "write" || str == "read-write";
249  } else if (funcOp.getBody().empty()) {
250  // If the function has no body, conservatively assume that all args are
251  // read + written.
252  isRead = true;
253  isWritten = true;
254  } else {
255  // Analyze the body of the function.
256  BlockArgument bbArg = funcOp.getArgument(idx);
257  isRead = state.isValueRead(bbArg);
258  isWritten = state.isValueWritten(bbArg);
259  }
260 
261  if (state.getOptions().testAnalysisOnly)
262  annotateFuncArgAccess(funcOp, idx, isRead, isWritten);
263  if (isRead)
264  funcState.readBbArgs[funcOp].insert(idx);
265  if (isWritten)
266  funcState.writtenBbArgs[funcOp].insert(idx);
267  }
268 
269  return success();
270 }
271 } // namespace
272 
273 /// Remove bufferization attributes on FuncOp arguments.
275  auto funcOp = cast<func::FuncOp>(bbArg.getOwner()->getParentOp());
276  funcOp.removeArgAttr(bbArg.getArgNumber(),
277  BufferizationDialect::kBufferLayoutAttrName);
278  funcOp.removeArgAttr(bbArg.getArgNumber(),
279  BufferizationDialect::kWritableAttrName);
280 }
281 
282 /// Return the func::FuncOp called by `callOp`.
283 static func::FuncOp
284 getCalledFunction(func::CallOp callOp,
285  mlir::SymbolTableCollection &symbolTable) {
286  SymbolRefAttr sym =
287  llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
288  if (!sym)
289  return nullptr;
290  return dyn_cast_or_null<func::FuncOp>(
291  symbolTable.lookupNearestSymbolFrom(callOp, sym));
292 }
293 
294 /// Return "true" if the given function signature has tensor semantics.
295 static bool hasTensorSignature(func::FuncOp funcOp) {
296  return llvm::any_of(funcOp.getFunctionType().getInputs(),
297  llvm::IsaPred<TensorType>) ||
298  llvm::any_of(funcOp.getFunctionType().getResults(),
299  llvm::IsaPred<TensorType>);
300 }
301 
302 /// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
303 /// callee-caller order (i.e., callees without callers first). Store all
304 /// remaining functions (i.e., the ones that call each other recursively) in
305 /// `remainingFuncOps`. Does not traverse nested symbol tables.
306 ///
307 /// Store the map of FuncOp to all its callers in `callerMap`.
308 ///
309 /// Return `failure()` if we are unable to retrieve the called FuncOp from
310 /// any func::CallOp.
311 static LogicalResult getFuncOpsOrderedByCalls(
312  ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
313  SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap,
314  SymbolTableCollection &symbolTables) {
315  // For each FuncOp, the set of functions called by it (i.e. the union of
316  // symbols of all nested func::CallOp).
318  // For each FuncOp, the number of func::CallOp it contains.
319  DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
320 
321  for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
322  // Collect function calls and populate the caller map.
323  numberCallOpsContainedInFuncOp[funcOp] = 0;
324  WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
325  func::FuncOp calledFunction = getCalledFunction(callOp, symbolTables);
326  assert(calledFunction && "could not retrieved called func::FuncOp");
327  // If the called function does not have any tensors in its signature, then
328  // it is not necessary to bufferize the callee before the caller.
329  if (!hasTensorSignature(calledFunction))
330  return WalkResult::skip();
331 
332  callerMap[calledFunction].insert(callOp);
333  if (calledBy[calledFunction].insert(funcOp).second) {
334  numberCallOpsContainedInFuncOp[funcOp]++;
335  }
336  return WalkResult::advance();
337  });
338  if (res.wasInterrupted())
339  return failure();
340  }
341 
342  // Iteratively remove function operations that do not call any of the
343  // functions remaining in the callCounter map and add them to ordered list.
344  SmallVector<func::FuncOp> worklist;
345 
346  for (const auto &entry : numberCallOpsContainedInFuncOp) {
347  if (entry.second == 0)
348  worklist.push_back(entry.first);
349  }
350 
351  while (!worklist.empty()) {
352  func::FuncOp func = worklist.pop_back_val();
353  orderedFuncOps.push_back(func);
354 
355  for (func::FuncOp caller : calledBy[func]) {
356  auto &count = numberCallOpsContainedInFuncOp[caller];
357 
358  if (--count == 0)
359  worklist.push_back(caller);
360  }
361 
362  numberCallOpsContainedInFuncOp.erase(func);
363  }
364 
365  // Put all other functions in the list of remaining functions. These are
366  // functions that call each other circularly.
367  for (auto it : numberCallOpsContainedInFuncOp)
368  remainingFuncOps.push_back(it.first);
369 
370  return success();
371 }
372 
373 /// Helper function that extracts the source from a memref.cast. If the given
374 /// value is not a memref.cast result, simply returns the given value.
375 static Value unpackCast(Value v) {
376  auto castOp = v.getDefiningOp<memref::CastOp>();
377  if (!castOp)
378  return v;
379  return castOp.getSource();
380 }
381 
382 /// Helper function that returns the return types (skipping casts) of the given
383 /// func.return ops. This function returns as many types as the return ops have
384 /// operands. If the i-th operand is not the same for all func.return ops, then
385 /// the i-th returned type is an "empty" type.
387  assert(!returnOps.empty() && "expected at least one ReturnOp");
388  int numOperands = returnOps.front()->getNumOperands();
389 
390  // Helper function that unpacks memref.cast ops and returns the type.
391  auto getSourceType = [&](Value v) { return unpackCast(v).getType(); };
392 
393  SmallVector<Type> result;
394  for (int i = 0; i < numOperands; ++i) {
395  // Get the type of the i-th operand of the first func.return ops.
396  Type t = getSourceType(returnOps.front()->getOperand(i));
397 
398  // Check if all other func.return ops have a matching operand type.
399  for (int j = 1; j < static_cast<int>(returnOps.size()); ++j)
400  if (getSourceType(returnOps[j]->getOperand(i)) != t)
401  t = Type();
402 
403  result.push_back(t);
404  }
405 
406  return result;
407 }
408 
409 /// Fold return values that are memref casts and update function return types.
410 ///
411 /// During FuncOp bufferization, the exact type of the returned memrefs (if any)
412 /// is not known yet. Therefore, the bufferization uses memref types with the
413 /// most generic layout map as function return types. After bufferizing the
414 /// entire function body, a more concise memref type can potentially be used for
415 /// the return type of the function.
416 static void foldMemRefCasts(func::FuncOp funcOp) {
417  // There is nothing to do for bodiless ops.
418  if (funcOp.getBody().empty())
419  return;
420 
421  // Compute the common result types of all return ops.
422  SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
423  SmallVector<Type> resultTypes = getReturnTypes(returnOps);
424 
425  // Remove direct casts.
426  for (func::ReturnOp returnOp : returnOps) {
427  for (OpOperand &operand : returnOp->getOpOperands()) {
428  // Bail if no common result type was found.
429  if (resultTypes[operand.getOperandNumber()]) {
430  operand.set(unpackCast(operand.get()));
431  }
432  }
433  }
434 
435  // Fill in the missing result types that were not the same among all
436  // func.return ops.
437  for (int i = 0; i < static_cast<int>(resultTypes.size()); ++i) {
438  if (resultTypes[i])
439  continue;
440  resultTypes[i] = funcOp.getFunctionType().getResult(i);
441  }
442 
443  // Update the function type.
444  auto newFuncType = FunctionType::get(
445  funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
446  funcOp.setType(newFuncType);
447 }
448 
/// Run One-Shot Analysis on all functions of `moduleOp`, including the
/// function boundary analyses. Requires `bufferizeFunctionBoundaries`.
///
/// Functions that are neither part of a call cycle nor call into one are
/// analyzed in callee-before-caller order. For each of them,
/// `aliasingFuncOpBBArgsAnalysis` and `funcOpBbArgReadWriteAnalysis` are run
/// after the body analysis. All remaining functions are analyzed afterwards
/// without these function boundary analyses. Functions rejected by the op
/// filter (`isOpAllowed`) are skipped.
///
/// Returns failure if the call order cannot be computed or any analysis fails.
449 LogicalResult
451  OneShotAnalysisState &state,
452  BufferizationStatistics *statistics) {
453  assert(state.getOptions().bufferizeFunctionBoundaries &&
454  "expected that function boundary bufferization is activated");
456 
457  // A list of non-circular functions in the order in which they are analyzed
458  // and bufferized.
459  SmallVector<func::FuncOp> orderedFuncOps;
460  // A list of all other functions. I.e., functions that call each other
461  // recursively, or that (transitively) call such functions. For these, we
462  // analyze the function body but not the function boundary.
463  SmallVector<func::FuncOp> remainingFuncOps;
464 
465  // A mapping of FuncOps to their callers.
466  FuncCallerMap callerMap;
467 
468  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
469  remainingFuncOps, callerMap,
470  funcState.symbolTables)))
471  return failure();
472 
473  // Analyze functions in order. Starting with functions that are not calling
474  // any other functions.
475  for (func::FuncOp funcOp : orderedFuncOps) {
476  if (!state.getOptions().isOpAllowed(funcOp))
477  continue;
478 
479  // Now analyzing function.
480  funcState.startFunctionAnalysis(funcOp);
481 
482  // Analyze funcOp.
483  if (failed(analyzeOp(funcOp, state, statistics)))
484  return failure();
485 
486  // Run some extra function analyses.
487  if (failed(aliasingFuncOpBBArgsAnalysis(funcOp, state, funcState)) ||
488  failed(funcOpBbArgReadWriteAnalysis(funcOp, state, funcState)))
489  return failure();
490 
491  // Mark op as fully analyzed.
493  }
494 
495  // Analyze all other functions. All function boundary analyses are skipped.
496  for (func::FuncOp funcOp : remainingFuncOps) {
497  if (!state.getOptions().isOpAllowed(funcOp))
498  continue;
499 
500  // Analyze funcOp.
501  if (failed(analyzeOp(funcOp, state, statistics)))
502  return failure();
503 
504  // TODO: We currently skip all function argument analyses for functions
505  // that call each other circularly. These analyses do not support recursive
506  // calls yet. The `BufferizableOpInterface` implementations of `func`
507  // dialect ops return conservative results in the absence of analysis
508  // information.
509  }
510 
511  return success();
512 }
513 
515  ModuleOp moduleOp) {
  // Remove the bufferization-specific argument attributes of every func.func
  // op that is directly nested in `moduleOp`.
516  for (auto op : moduleOp.getOps<func::FuncOp>()) {
517  for (BlockArgument bbArg : op.getArguments())
519  }
520 }
521 
523  ModuleOp moduleOp, const OneShotBufferizationOptions &options,
524  BufferizationState &state, BufferizationStatistics *statistics) {
  // Bufferize all functions of `moduleOp` (callees before callers where the
  // call graph allows it, followed by functions in or calling into call
  // cycles), then all other top-level ops that are neither functions nor
  // symbol tables. Finally, clean up function argument attributes.
  // Requires `bufferizeFunctionBoundaries`.
525  assert(options.bufferizeFunctionBoundaries &&
526  "expected that function boundary bufferization is activated");
  // NOTE(review): `rewriter` does not appear to be used in this function;
  // consider removing it.
527  IRRewriter rewriter(moduleOp.getContext());
528 
529  // A list of non-circular functions in the order in which they are analyzed
530  // and bufferized.
531  SmallVector<func::FuncOp> orderedFuncOps;
532  // A list of all other functions. I.e., functions that call each other
533  // recursively, or that (transitively) call such functions. They are
534  // bufferized after all functions in `orderedFuncOps`.
535  SmallVector<func::FuncOp> remainingFuncOps;
536 
537  // A mapping of FuncOps to their callers.
538  FuncCallerMap callerMap;
539 
540  // Try to bufferize functions in calling order. I.e., first bufferize
541  // functions that do not call other functions. This allows us to infer
542  // accurate buffer types for function return values. Functions that call
543  // each other recursively are bufferized in an unspecified order at the end.
544  // We may use unnecessarily "complex" (in terms of layout map) buffer types.
545  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
546  remainingFuncOps, callerMap,
547  state.getSymbolTables())))
548  return failure();
549  llvm::append_range(orderedFuncOps, remainingFuncOps);
550 
551  // Bufferize functions.
552  for (func::FuncOp funcOp : orderedFuncOps) {
553  // Note: It would be good to apply cleanups here but we cannot as aliasInfo
554  // would be invalidated.
555 
556  if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getSymName())) {
557  // This function was not analyzed and RaW conflicts were not resolved.
558  // Buffer copies must be inserted before every write.
559  OneShotBufferizationOptions updatedOptions = options;
560  updatedOptions.copyBeforeWrite = true;
561  if (failed(bufferizeOp(funcOp, updatedOptions, state, statistics)))
562  return failure();
563  } else {
564  if (failed(bufferizeOp(funcOp, options, state, statistics)))
565  return failure();
566  }
567 
568  // Change buffer return types to more precise layout maps.
    // NOTE(review): foldMemRefCasts rewrites the function's result types in
    // place but does not update existing call sites. This relies on callers
    // being bufferized after their callees; for functions that came from
    // `remainingFuncOps` (call cycles), some callers may already have been
    // bufferized -- confirm the call and function types still match.
569  if (options.inferFunctionResultLayout)
570  foldMemRefCasts(funcOp);
571  }
572 
573  // Bufferize all other ops.
574  for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) {
575  // Functions were already bufferized. Nested symbol tables (ops with the
    // SymbolTable trait) are skipped as well.
576  if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>())
577  continue;
578  if (failed(bufferizeOp(&op, options, state, statistics)))
579  return failure();
580  }
581 
582  // Post-pass cleanup of function argument attributes.
584 
585  return success();
586 }
587 
589  ModuleOp moduleOp, const OneShotBufferizationOptions &options,
590  BufferizationState &state, BufferizationStatistics *statistics) {
  // Unless `copyBeforeWrite` is set, first analyze the module and resolve
  // conflicts with `insertTensorCopies`. Functions listed in
  // `noAnalysisFuncFilter`, and the ops nested in them, are excluded from the
  // analysis through the op filter. With `testAnalysisOnly`, stop after that
  // step; otherwise bufferize the module with `bufferizeModuleOp`.
591  assert(options.bufferizeFunctionBoundaries &&
592  "expected that function boundary bufferization is activated");
593  assert(!(options.copyBeforeWrite && options.testAnalysisOnly) &&
594  "invalid combination of bufferization flags");
595  if (!options.copyBeforeWrite) {
596  if (options.noAnalysisFuncFilter.empty()) {
597  if (failed(insertTensorCopies(moduleOp, options, state, statistics)))
598  return failure();
599  } else {
600  // FuncOps whose names are specified in options.noAnalysisFuncFilter will
601  // not be analyzed. Ops in these FuncOps will not be analyzed as well.
      // NOTE(review): `[=]` copies the entire `options` object into the filter
      // only to read `noAnalysisFuncFilter`; capturing just that list would be
      // cheaper.
602  OpFilter::Entry::FilterFn analysisFilterFn = [=](Operation *op) {
603  auto func = dyn_cast<func::FuncOp>(op);
604  if (!func)
605  func = op->getParentOfType<func::FuncOp>();
606  if (func)
607  return llvm::is_contained(options.noAnalysisFuncFilter,
608  func.getSymName());
609  return false;
610  };
611  OneShotBufferizationOptions updatedOptions(options);
612  updatedOptions.opFilter.denyOperation(analysisFilterFn);
613  if (failed(
614  insertTensorCopies(moduleOp, updatedOptions, state, statistics)))
615  return failure();
616  }
617  }
  // Analysis-only mode: do not bufferize.
618  if (options.testAnalysisOnly)
619  return success();
620  if (failed(bufferizeModuleOp(moduleOp, options, state, statistics)))
621  return failure();
622  return success();
623 }
static bool hasTensorSignature(func::FuncOp funcOp)
Return "true" if the given function signature has tensor semantics.
static LogicalResult getFuncOpsOrderedByCalls(ModuleOp moduleOp, SmallVectorImpl< func::FuncOp > &orderedFuncOps, SmallVectorImpl< func::FuncOp > &remainingFuncOps, FuncCallerMap &callerMap, SymbolTableCollection &symbolTables)
Store all functions of the moduleOp in orderedFuncOps, sorted by callee-caller order (i....
static FuncAnalysisState & getOrCreateFuncAnalysisState(OneShotAnalysisState &state)
Get or create FuncAnalysisState.
static void removeBufferizationAttributes(BlockArgument bbArg)
Remove bufferization attributes on FuncOp arguments.
static void foldMemRefCasts(func::FuncOp funcOp)
Fold return values that are memref casts and update function return types.
static Value unpackCast(Value v)
Helper function that extracts the source from a memref.cast.
static SmallVector< Type > getReturnTypes(SmallVector< func::ReturnOp > returnOps)
Helper function that returns the return types (skipping casts) of the given func.return ops.
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:309
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:318
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:321
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:33
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:729
This class helps build Operations.
Definition: Builders.h:205
This class represents an operand of an operation.
Definition: Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:228
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:452
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:534
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition: Operation.h:560
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumOperands()
Definition: Operation.h:346
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:582
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:33
static WalkResult skip()
Definition: Visitors.h:52
static WalkResult advance()
Definition: Visitors.h:51
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: Visitors.h:55
BufferizationState provides information about the state of the IR during the bufferization process.
State for analysis-enabled bufferization.
void denyOperation()
Deny the given ops.
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
static FuncOp getCalledFunction(CallOpInterface callOp, SymbolTableCollection &symbolTables)
Return the FuncOp called by callOp.
LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options, BufferizationState &bufferizationState, BufferizationStatistics *statistics=nullptr)
Bufferize op and its nested ops that implement BufferizableOpInterface.
Definition: Bufferize.cpp:278
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
SmallVector< func::ReturnOp > getReturnOps(func::FuncOp funcOp)
Helper function that returns all func.return ops in the given function.
LogicalResult insertTensorCopies(Operation *op, const OneShotBufferizationOptions &options, const BufferizationState &bufferizationState, BufferizationStatistics *statistics=nullptr)
Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
llvm::LogicalResult analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze moduleOp and its nested ops.
llvm::LogicalResult bufferizeModuleOp(ModuleOp moduleOp, const OneShotBufferizationOptions &options, BufferizationState &state, BufferizationStatistics *statistics=nullptr)
Bufferize op and its nested ops that implement BufferizableOpInterface.
void removeBufferizationAttributesInModule(ModuleOp moduleOp)
Remove bufferization attributes on every FuncOp arguments in the ModuleOp.
llvm::LogicalResult runOneShotModuleBufferize(ModuleOp moduleOp, const bufferization::OneShotBufferizationOptions &options, BufferizationState &state, BufferizationStatistics *statistics=nullptr)
Run One-Shot Module Bufferization on the given module.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool copyBeforeWrite
If set to true, the analysis is skipped.
OpFilter opFilter
A filter that specifies which ops should be bufferized and which ops should be ignored.
Bufferization statistics for debugging.
Definition: Bufferize.h:35
Options for analysis-enabled bufferization.
std::function< bool(Operation *)> FilterFn
If the filter function evaluates to true, the filter matches.
Extra analysis state that is required for bufferization of function boundaries.
DenseMap< FuncOp, IndexMapping > equivalentFuncArgs
A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg indices.
DenseMap< FuncOp, IndexToIndexListMapping > aliasingReturnVals
A mapping of FuncOp BBArg indices to aliasing ReturnOp OpOperand indices.
SymbolTableCollection symbolTables
A collection of cached SymbolTables used for faster function lookup.
DenseMap< FuncOp, BbArgIndexSet > readBbArgs
A set of all read BlockArguments of FuncOps.
DenseMap< FuncOp, BbArgIndexSet > writtenBbArgs
A set of all written-to BlockArguments of FuncOps.
DenseMap< FuncOp, FuncOpAnalysisState > analyzedFuncOps
Keep track of which FuncOps are fully analyzed or currently being analyzed.
void startFunctionAnalysis(FuncOp funcOp)
This function is called right before analyzing the given FuncOp.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.