//===- OneShotModuleBufferize.cpp - Bufferization across Func. Boundaries ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Module Bufferization is an extension of One-Shot Bufferize that
// bufferizes function boundaries. It provides `BufferizableOpInterface`
// implementations for FuncOp, CallOp and ReturnOp.
//
// Module Bufferization is run via `runOneShotModuleBufferize(ModuleOp, ...)`.
// This function analyzes the given module and determines the order of analysis
// and bufferization: Functions that are called are processed before their
// respective callers.
//
// After analyzing a FuncOp, additional information about its bbArgs is
// gathered and stored in `FuncAnalysisState`.
//
// * `aliasingFuncOpBBArgsAnalysis` determines the equivalent/aliasing bbArgs
//   for each tensor return value (if any).
// * `funcOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is
//   read/written.
//
// Module Bufferization implements the following calling convention.
//
// * In the absence of conflicts within a FuncOp, the FuncOp's bbArgs may
//   always be written to in-place.
// * If a tensor operand of a CallOp is read after the CallOp, the operand of
//   the CallOp must bufferize out-of-place.
//
// Example: The tensor.insert op bufferizes in-place because it is allowed to
// modify the buffer of `%t1` directly. The CallOp in `caller` must bufferize
// out-of-place because `%t0` is modified by the callee but read by the
// tensor.extract op. The analysis of CallOps decides whether an OpOperand must
// bufferize out-of-place based on results of `funcOpBbArgReadWriteAnalysis`.
// ```
// func @callee(%t1 : tensor<?xf32>) -> tensor<?xf32> {
//   %f = ... : f32
//   %0 = tensor.insert %f into %t1[...] : tensor<?xf32>
//   return %0 : tensor<?xf32>
// }
//
// func @caller() -> () {
//   %t0 = ... : tensor<?xf32>
//   %1 = call @callee(%t0) : (tensor<?xf32>) -> (tensor<?xf32>)
//   %2 = tensor.extract %t0[...] : tensor<?xf32>
// }
// ```
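//
// For illustration, the bufferized IR for this example might look as follows
// (a sketch; the exact memref layout maps and copy placement depend on the
// bufferization options):
// ```
// func @callee(%m1 : memref<?xf32>) -> memref<?xf32> {
//   %f = ... : f32
//   memref.store %f, %m1[...] : memref<?xf32>
//   return %m1 : memref<?xf32>
// }
//
// func @caller() -> () {
//   %m0 = ... : memref<?xf32>
//   %copy = memref.alloc(...) : memref<?xf32>
//   // Out-of-place CallOp operand: %m0 is read later.
//   memref.copy %m0, %copy : memref<?xf32> to memref<?xf32>
//   %r = call @callee(%copy) : (memref<?xf32>) -> (memref<?xf32>)
//   %e = memref.load %m0[...] : memref<?xf32>
// }
// ```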
//
// Note: If a function is external, `funcOpBbArgReadWriteAnalysis` cannot
// analyze the function body. In such a case, the CallOp analysis conservatively
// assumes that each tensor OpOperand is both read and written.
//
// TODO: Add FuncOp attributes so that bbArgs of external FuncOps can be marked
// as "not reading" and/or "not writing".

#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::bufferization::func_ext;

/// A mapping of FuncOps to their callers.
using FuncCallerMap = DenseMap<func::FuncOp, DenseSet<Operation *>>;

/// Get or create FuncAnalysisState.
static FuncAnalysisState &
getOrCreateFuncAnalysisState(OneShotAnalysisState &state) {
  auto *result = state.getExtension<FuncAnalysisState>();
  if (result)
    return *result;
  return state.addExtension<FuncAnalysisState>();
}

namespace {

/// Annotate IR with the results of the analysis. For testing purposes only.
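/// E.g. (illustrative): a func.return with two operands, where only operand
/// #0 was found equivalent to bbArg #1, is annotated as
///   func.return ... {__equivalent_func_args__ = [1, -1]}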
static void annotateEquivalentReturnBbArg(OpOperand &returnVal,
                                          BlockArgument bbArg) {
  const char *kEquivalentArgsAttr = "__equivalent_func_args__";
  Operation *op = returnVal.getOwner();

  SmallVector<int64_t> equivBbArgs;
  if (op->hasAttr(kEquivalentArgsAttr)) {
    auto attr = cast<ArrayAttr>(op->getAttr(kEquivalentArgsAttr));
    equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) {
      return cast<IntegerAttr>(a).getValue().getSExtValue();
    }));
  } else {
    equivBbArgs.append(op->getNumOperands(), -1);
  }
  equivBbArgs[returnVal.getOperandNumber()] = bbArg.getArgNumber();

  OpBuilder b(op->getContext());
  op->setAttr(kEquivalentArgsAttr, b.getI64ArrayAttr(equivBbArgs));
}

/// Store function BlockArguments that are equivalent to/aliasing a returned
/// value in FuncAnalysisState.
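/// E.g. (illustrative): for
///   func @f(%t : tensor<5xf32>) -> tensor<5xf32> { return %t : tensor<5xf32> }
/// the return value is equivalent to (and aliases) bbArg #0, so this analysis
/// records `equivalentFuncArgs[@f][0] = 0` and adds 0 to
/// `aliasingReturnVals[@f][0]`.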
static LogicalResult
aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
                             FuncAnalysisState &funcState) {
  if (funcOp.getBody().empty()) {
    // No function body available. Conservatively assume that every tensor
    // return value may alias with any tensor bbArg.
    FunctionType type = funcOp.getFunctionType();
    for (const auto &inputIt : llvm::enumerate(type.getInputs())) {
      if (!isa<TensorType>(inputIt.value()))
        continue;
      for (const auto &resultIt : llvm::enumerate(type.getResults())) {
        if (!isa<TensorType>(resultIt.value()))
          continue;
        int64_t returnIdx = resultIt.index();
        int64_t bbArgIdx = inputIt.index();
        funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx);
      }
    }
    return success();
  }

  // Find all func.return ops.
  SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
  assert(!returnOps.empty() && "expected at least one ReturnOp");

  // Build alias sets. Merge all aliases from all func.return ops.
  for (BlockArgument bbArg : funcOp.getArguments()) {
    if (isa<RankedTensorType>(bbArg.getType())) {
      int64_t bbArgIdx = bbArg.getArgNumber();
      // Store aliases in a set, so that we don't add the same alias twice.
      SetVector<int64_t> aliases;
      for (func::ReturnOp returnOp : returnOps) {
        for (OpOperand &returnVal : returnOp->getOpOperands()) {
          if (isa<RankedTensorType>(returnVal.get().getType())) {
            int64_t returnIdx = returnVal.getOperandNumber();
            if (state.areAliasingBufferizedValues(returnVal.get(), bbArg))
              aliases.insert(returnIdx);
          }
        }
      }
      for (int64_t alias : aliases)
        funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(alias);
    }
  }

  // Build equivalence sets.
  // Helper function that finds an equivalent block argument index for the
  // given OpOperand. Return std::nullopt if no equivalent block argument could
  // be found.
  auto findEquivalentBlockArgIdx =
      [&](OpOperand &opOperand) -> std::optional<int64_t> {
    Value v = opOperand.get();
    if (!isa<TensorType>(v.getType()))
      return std::nullopt;
    for (BlockArgument bbArg : funcOp.getArguments()) {
      if (isa<RankedTensorType>(bbArg.getType())) {
        if (state.areEquivalentBufferizedValues(v, bbArg)) {
          if (state.getOptions().testAnalysisOnly)
            annotateEquivalentReturnBbArg(opOperand, bbArg);
          return bbArg.getArgNumber();
        }
      }
    }
    return std::nullopt;
  };

  int64_t numResults = returnOps.front()->getNumOperands();
  for (int64_t i = 0; i < numResults; ++i) {
    // Find the equivalent block argument index for the i-th operand of the
    // first func.return op.
    std::optional<int64_t> maybeEquiv =
        findEquivalentBlockArgIdx(returnOps.front()->getOpOperand(i));
    if (!maybeEquiv.has_value())
      continue;
    int64_t bbArgIdx = *maybeEquiv;
    bool allEquiv = true;

    // Check if all other func.return ops have the same equivalent block
    // argument for the i-th operand. In contrast to aliasing information,
    // which is just "merged", equivalence information must match across all
    // func.return ops.
    for (func::ReturnOp returnOp : ArrayRef(returnOps).drop_front()) {
      std::optional<int64_t> maybeEquiv =
          findEquivalentBlockArgIdx(returnOp->getOpOperand(i));
      if (maybeEquiv != bbArgIdx) {
        allEquiv = false;
        break;
      }
    }

    // All func.return ops have the same equivalent block argument for the i-th
    // operand.
    if (allEquiv)
      funcState.equivalentFuncArgs[funcOp][i] = bbArgIdx;
  }

  return success();
}

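/// Annotate the `idx`-th argument of `funcOp` with its inferred access type
/// ("read", "write", "read-write" or "none"), stored via
/// `BufferizationDialect::kBufferAccessAttrName`. For testing purposes only.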
static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx, bool isRead,
                                  bool isWritten) {
  OpBuilder b(funcOp.getContext());
  Attribute accessType;
  if (isRead && isWritten) {
    accessType = b.getStringAttr("read-write");
  } else if (isRead) {
    accessType = b.getStringAttr("read");
  } else if (isWritten) {
    accessType = b.getStringAttr("write");
  } else {
    accessType = b.getStringAttr("none");
  }
  funcOp.setArgAttr(idx, BufferizationDialect::kBufferAccessAttrName,
                    accessType);
}

/// Determine which FuncOp bbArgs are read and which are written. When run on a
/// function with unknown ops, we conservatively assume that such ops bufferize
/// to a read + write.
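/// Access behavior may also be specified manually on the argument, which takes
/// precedence over the analysis; e.g. (an illustrative sketch, assuming the
/// attribute spelling "bufferization.access" for `kBufferAccessAttrName`):
///   func.func private @ext(%t : tensor<?xf32> {bufferization.access = "read"})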
static LogicalResult
funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
                             FuncAnalysisState &funcState) {
  for (int64_t idx = 0, e = funcOp.getFunctionType().getNumInputs(); idx < e;
       ++idx) {
    // Skip non-tensor arguments.
    if (!isa<TensorType>(funcOp.getFunctionType().getInput(idx)))
      continue;
    bool isRead;
    bool isWritten;
    if (auto accessAttr = funcOp.getArgAttrOfType<StringAttr>(
            idx, BufferizationDialect::kBufferAccessAttrName)) {
      // Buffer access behavior is specified on the function. Skip the analysis.
      StringRef str = accessAttr.getValue();
      isRead = str == "read" || str == "read-write";
      isWritten = str == "write" || str == "read-write";
    } else if (funcOp.getBody().empty()) {
      // If the function has no body, conservatively assume that all args are
      // read + written.
      isRead = true;
      isWritten = true;
    } else {
      // Analyze the body of the function.
      BlockArgument bbArg = funcOp.getArgument(idx);
      isRead = state.isValueRead(bbArg);
      isWritten = state.isValueWritten(bbArg);
    }

    if (state.getOptions().testAnalysisOnly)
      annotateFuncArgAccess(funcOp, idx, isRead, isWritten);
    if (isRead)
      funcState.readBbArgs[funcOp].insert(idx);
    if (isWritten)
      funcState.writtenBbArgs[funcOp].insert(idx);
  }

  return success();
}
} // namespace

/// Remove bufferization attributes on FuncOp arguments.
static void removeBufferizationAttributes(BlockArgument bbArg) {
  auto funcOp = cast<func::FuncOp>(bbArg.getOwner()->getParentOp());
  funcOp.removeArgAttr(bbArg.getArgNumber(),
                       BufferizationDialect::kBufferLayoutAttrName);
  funcOp.removeArgAttr(bbArg.getArgNumber(),
                       BufferizationDialect::kWritableAttrName);
}

/// Return the func::FuncOp called by `callOp`.
static func::FuncOp getCalledFunction(func::CallOp callOp) {
  SymbolRefAttr sym =
      llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
  if (!sym)
    return nullptr;
  return dyn_cast_or_null<func::FuncOp>(
      SymbolTable::lookupNearestSymbolFrom(callOp, sym.getLeafReference()));
}

/// Gather equivalence info of CallOps.
/// Note: This only adds new equivalence info if the called function was
/// already analyzed.
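/// E.g. (illustrative): if the analysis found result #0 of @callee to be
/// equivalent to its bbArg #1, then for an in-place `%r = call @callee(%a, %b)`
/// this step unions `%r` and `%b` into one equivalence class.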
// TODO: This does not handle cyclic function call graphs etc.
static void equivalenceAnalysis(func::FuncOp funcOp,
                                OneShotAnalysisState &state,
                                FuncAnalysisState &funcState) {
  funcOp->walk([&](func::CallOp callOp) {
    func::FuncOp calledFunction = getCalledFunction(callOp);
    assert(calledFunction && "could not retrieve called func::FuncOp");

    // No equivalence info available for the called function.
    if (!funcState.equivalentFuncArgs.count(calledFunction))
      return WalkResult::skip();

    for (auto it : funcState.equivalentFuncArgs[calledFunction]) {
      int64_t returnIdx = it.first;
      int64_t bbargIdx = it.second;
      if (!state.isInPlace(callOp->getOpOperand(bbargIdx)))
        continue;
      Value returnVal = callOp.getResult(returnIdx);
      Value argVal = callOp->getOperand(bbargIdx);
      state.unionEquivalenceClasses(returnVal, argVal);
    }

    return WalkResult::advance();
  });
}

/// Return "true" if the given function signature has tensor semantics.
static bool hasTensorSignature(func::FuncOp funcOp) {
  return llvm::any_of(funcOp.getFunctionType().getInputs(),
                      llvm::IsaPred<TensorType>) ||
         llvm::any_of(funcOp.getFunctionType().getResults(),
                      llvm::IsaPred<TensorType>);
}

/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
/// callee-caller order (i.e., callees without callers first). Store all
/// remaining functions (i.e., the ones that call each other recursively) in
/// `remainingFuncOps`.
///
/// Store the map of FuncOp to all its callers in `callerMap`.
///
/// Return `failure()` if we are unable to retrieve the called FuncOp from
/// any func::CallOp.
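/// E.g. (illustrative): if @a calls @b and @b calls @c, the resulting order is
/// [@c, @b, @a]. If @d and @e call each other, both are placed in
/// `remainingFuncOps`.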
static LogicalResult getFuncOpsOrderedByCalls(
    ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
    SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap) {
  // For each FuncOp, the set of functions called by it (i.e. the union of
  // symbols of all nested func::CallOp).
  DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
  // For each FuncOp, the number of func::CallOp it contains.
  DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
  WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
    // Collect function calls and populate the caller map.
    numberCallOpsContainedInFuncOp[funcOp] = 0;
    return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
      func::FuncOp calledFunction = getCalledFunction(callOp);
      assert(calledFunction && "could not retrieve called func::FuncOp");
      // If the called function does not have any tensors in its signature,
      // then it is not necessary to bufferize the callee before the caller.
      if (!hasTensorSignature(calledFunction))
        return WalkResult::skip();

      callerMap[calledFunction].insert(callOp);
      if (calledBy[calledFunction].insert(funcOp).second) {
        numberCallOpsContainedInFuncOp[funcOp]++;
      }
      return WalkResult::advance();
    });
  });
  if (res.wasInterrupted())
    return failure();

  // Iteratively remove functions that do not call any of the functions
  // remaining in `numberCallOpsContainedInFuncOp` and add them to the ordered
  // list.
  while (!numberCallOpsContainedInFuncOp.empty()) {
    auto it = llvm::find_if(numberCallOpsContainedInFuncOp,
                            [](auto entry) { return entry.getSecond() == 0; });
    if (it == numberCallOpsContainedInFuncOp.end())
      break;
    orderedFuncOps.push_back(it->getFirst());
    for (auto caller : calledBy[it->getFirst()])
      numberCallOpsContainedInFuncOp[caller]--;
    numberCallOpsContainedInFuncOp.erase(it);
  }

  // Put all other functions in the list of remaining functions. These are
  // functions that call each other circularly.
  for (auto it : numberCallOpsContainedInFuncOp)
    remainingFuncOps.push_back(it.first);

  return success();
}

/// Helper function that extracts the source from a memref.cast. If the given
/// value is not a memref.cast result, simply returns the given value.
static Value unpackCast(Value v) {
  auto castOp = v.getDefiningOp<memref::CastOp>();
  if (!castOp)
    return v;
  return castOp.getSource();
}

/// Helper function that returns the return types (skipping casts) of the given
/// func.return ops. This function returns as many types as the return ops have
/// operands. If the type of the i-th operand is not the same for all
/// func.return ops, then the i-th returned type is an "empty" type.
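/// E.g. (illustrative): two func.return ops returning
/// (%a : memref<5xf32>, %x : memref<2xf32>) and
/// (%b : memref<5xf32>, %y : memref<3xf32>) yield [memref<5xf32>, Type()].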
static SmallVector<Type> getReturnTypes(SmallVector<func::ReturnOp> returnOps) {
  assert(!returnOps.empty() && "expected at least one ReturnOp");
  int numOperands = returnOps.front()->getNumOperands();

  // Helper function that unpacks memref.cast ops and returns the type.
  auto getSourceType = [&](Value v) { return unpackCast(v).getType(); };

  SmallVector<Type> result;
  for (int i = 0; i < numOperands; ++i) {
    // Get the type of the i-th operand of the first func.return op.
    Type t = getSourceType(returnOps.front()->getOperand(i));

    // Check if all other func.return ops have a matching operand type.
    for (int j = 1; j < static_cast<int>(returnOps.size()); ++j)
      if (getSourceType(returnOps[j]->getOperand(i)) != t)
        t = Type();

    result.push_back(t);
  }

  return result;
}

/// Fold return values that are memref casts and update function return types.
///
/// During FuncOp bufferization, the exact type of the returned memrefs (if any)
/// is not known yet. Therefore, the bufferization uses memref types with the
/// most generic layout map as function return types. After bufferizing the
/// entire function body, a more concise memref type can potentially be used for
/// the return type of the function.
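/// E.g. (an illustrative sketch):
///   func @f() -> memref<?xf32, strided<[?], offset: ?>> {
///     %cast = memref.cast %m : memref<5xf32> to
///         memref<?xf32, strided<[?], offset: ?>>
///     return %cast : memref<?xf32, strided<[?], offset: ?>>
///   }
/// is folded to a function that returns `%m : memref<5xf32>` directly.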
static void foldMemRefCasts(func::FuncOp funcOp) {
  // There is nothing to do for bodiless ops.
  if (funcOp.getBody().empty())
    return;

  // Compute the common result types of all return ops.
  SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
  SmallVector<Type> resultTypes = getReturnTypes(returnOps);

  // Remove direct casts.
  for (func::ReturnOp returnOp : returnOps) {
    for (OpOperand &operand : returnOp->getOpOperands()) {
      // Only fold the cast if a common result type was found for this operand.
      if (resultTypes[operand.getOperandNumber()]) {
        operand.set(unpackCast(operand.get()));
      }
    }
  }

  // Fill in the missing result types that were not the same among all
  // func.return ops.
  for (int i = 0; i < static_cast<int>(resultTypes.size()); ++i) {
    if (resultTypes[i])
      continue;
    resultTypes[i] = funcOp.getFunctionType().getResult(i);
  }

  // Update the function type.
  auto newFuncType = FunctionType::get(
      funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
  funcOp.setType(newFuncType);
}

LogicalResult
mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
                                     OneShotAnalysisState &state,
                                     BufferizationStatistics *statistics) {
  assert(state.getOptions().bufferizeFunctionBoundaries &&
         "expected that function boundary bufferization is activated");
  FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);

  // A list of non-circular functions in the order in which they are analyzed
  // and bufferized.
  SmallVector<func::FuncOp> orderedFuncOps;
  // A list of all other functions. I.e., functions that call each other
  // recursively. For these, we analyze the function body but not the function
  // boundary.
  SmallVector<func::FuncOp> remainingFuncOps;

  // A mapping of FuncOps to their callers.
  FuncCallerMap callerMap;

  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
                                      remainingFuncOps, callerMap)))
    return failure();

  // Analyze functions in order, starting with functions that do not call any
  // other functions.
  for (func::FuncOp funcOp : orderedFuncOps) {
    if (!state.getOptions().isOpAllowed(funcOp))
      continue;

    // Now analyzing function.
    funcState.startFunctionAnalysis(funcOp);

    // Gather equivalence info for CallOps.
    equivalenceAnalysis(funcOp, state, funcState);

    // Analyze funcOp.
    if (failed(analyzeOp(funcOp, state, statistics)))
      return failure();

    // Run some extra function analyses.
    if (failed(aliasingFuncOpBBArgsAnalysis(funcOp, state, funcState)) ||
        failed(funcOpBbArgReadWriteAnalysis(funcOp, state, funcState)))
      return failure();

    // Mark op as fully analyzed.
    funcState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed;
  }

  // Analyze all other functions. All function boundary analyses are skipped.
  for (func::FuncOp funcOp : remainingFuncOps) {
    if (!state.getOptions().isOpAllowed(funcOp))
      continue;

    // Gather equivalence info for CallOps.
    equivalenceAnalysis(funcOp, state, funcState);

    // Analyze funcOp.
    if (failed(analyzeOp(funcOp, state, statistics)))
      return failure();

    // TODO: We currently skip all function argument analyses for functions
    // that call each other circularly. These analyses do not support recursive
    // calls yet. The `BufferizableOpInterface` implementations of `func`
    // dialect ops return conservative results in the absence of analysis
    // information.
  }

  return success();
}

void mlir::bufferization::removeBufferizationAttributesInModule(
    ModuleOp moduleOp) {
  moduleOp.walk([&](func::FuncOp op) {
    for (BlockArgument bbArg : op.getArguments())
      removeBufferizationAttributes(bbArg);
  });
}

LogicalResult mlir::bufferization::bufferizeModuleOp(
    ModuleOp moduleOp, const OneShotBufferizationOptions &options,
    BufferizationStatistics *statistics) {
  assert(options.bufferizeFunctionBoundaries &&
         "expected that function boundary bufferization is activated");
  IRRewriter rewriter(moduleOp.getContext());

  // A list of non-circular functions in the order in which they are analyzed
  // and bufferized.
  SmallVector<func::FuncOp> orderedFuncOps;
  // A list of all other functions. I.e., functions that call each other
  // recursively. For these, we analyze the function body but not the function
  // boundary.
  SmallVector<func::FuncOp> remainingFuncOps;

  // A mapping of FuncOps to their callers.
  FuncCallerMap callerMap;

  // Try to bufferize functions in calling order. I.e., first bufferize
  // functions that do not call other functions. This allows us to infer
  // accurate buffer types for function return values. Functions that call
  // each other recursively are bufferized in an unspecified order at the end.
  // We may use unnecessarily "complex" (in terms of layout map) buffer types.
  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
                                      remainingFuncOps, callerMap)))
    return failure();
  llvm::append_range(orderedFuncOps, remainingFuncOps);

  // Bufferize functions.
  for (func::FuncOp funcOp : orderedFuncOps) {
    // Note: It would be good to apply cleanups here but we cannot as aliasInfo
    // would be invalidated.

    if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getSymName())) {
      // This function was not analyzed and RaW conflicts were not resolved.
      // Buffer copies must be inserted before every write.
      OneShotBufferizationOptions updatedOptions = options;
      updatedOptions.copyBeforeWrite = true;
      if (failed(bufferizeOp(funcOp, updatedOptions, statistics)))
        return failure();
    } else {
      if (failed(bufferizeOp(funcOp, options, statistics)))
        return failure();
    }

    // Change buffer return types to more precise layout maps.
    if (options.inferFunctionResultLayout)
      foldMemRefCasts(funcOp);
  }

  // Bufferize all other ops.
  for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) {
    // Functions were already bufferized.
    if (isa<func::FuncOp>(&op))
      continue;
    if (failed(bufferizeOp(&op, options, statistics)))
      return failure();
  }

  // Post-pass cleanup of function argument attributes.
  removeBufferizationAttributesInModule(moduleOp);

  return success();
}

LogicalResult mlir::bufferization::runOneShotModuleBufferize(
    ModuleOp moduleOp, const OneShotBufferizationOptions &options,
    BufferizationStatistics *statistics) {
  assert(options.bufferizeFunctionBoundaries &&
         "expected that function boundary bufferization is activated");
  assert(!(options.copyBeforeWrite && options.testAnalysisOnly) &&
         "invalid combination of bufferization flags");
  if (!options.copyBeforeWrite) {
    if (options.noAnalysisFuncFilter.empty()) {
      if (failed(insertTensorCopies(moduleOp, options, statistics)))
        return failure();
    } else {
      // FuncOps whose names are specified in options.noAnalysisFuncFilter will
      // not be analyzed. Ops within these FuncOps will not be analyzed either.
      OpFilter::Entry::FilterFn analysisFilterFn = [=](Operation *op) {
        auto func = dyn_cast<func::FuncOp>(op);
        if (!func)
          func = op->getParentOfType<func::FuncOp>();
        if (func)
          return llvm::is_contained(options.noAnalysisFuncFilter,
                                    func.getSymName());
        return false;
      };
      OneShotBufferizationOptions updatedOptions(options);
      updatedOptions.opFilter.denyOperation(analysisFilterFn);
      if (failed(insertTensorCopies(moduleOp, updatedOptions, statistics)))
        return failure();
    }
  }
  if (options.testAnalysisOnly)
    return success();
  if (failed(bufferizeModuleOp(moduleOp, options, statistics)))
    return failure();
  return success();
}
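
// Example usage (an illustrative sketch, not part of this file): a pass or
// other driver would typically invoke Module Bufferization as follows.
//
//   OneShotBufferizationOptions options;
//   options.bufferizeFunctionBoundaries = true;
//   if (failed(runOneShotModuleBufferize(moduleOp, options)))
//     signalPassFailure();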