//===- OneShotModuleBufferize.cpp - Bufferization across Func. Boundaries
//----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Module Bufferization is an extension of One-Shot Bufferize that
// bufferizes function boundaries. It provides `BufferizableOpInterface`
// implementations for FuncOp, CallOp and ReturnOp. Although it is named
// Module Bufferization, it may operate on any SymbolTable.
//
// Module Bufferization is run via `runOneShotModuleBufferize(SymbolTableOp,
// ...)`. This function analyzes the given op and determines the order of
// analysis and bufferization: Functions that are called are processed before
// their respective callers.
//
// After analyzing a FuncOp, additional information about its bbArgs is
// gathered and stored in `FuncAnalysisState`.
//
// * `aliasingFuncOpBBArgsAnalysis` determines the equivalent/aliasing bbArgs
//   for each tensor return value (if any).
// * `funcOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is
//   read/written.
//
// Module Bufferization implements the following calling convention.
//
// * In the absence of conflicts within a FuncOp, the FuncOp's bbArgs may
//   always be written to in-place.
// * If a tensor operand of a CallOp is read after the CallOp, the operand of
//   the CallOp must bufferize out-of-place.
//
// Example: The tensor.insert op bufferizes in-place because it is allowed to
// modify the buffer of `%t1` directly. The CallOp in `caller` must bufferize
// out-of-place because `%t0` is modified by the callee but read by the
// tensor.extract op. The analysis of CallOps decides whether an OpOperand must
// bufferize out-of-place based on results of `funcOpBbArgReadWriteAnalysis`.
// ```
// func @callee(%t1 : tensor<?xf32>) -> tensor<?xf32> {
//   %f = ... : f32
//   %0 = tensor.insert %f into %t1[...] : tensor<?xf32>
//   return %0 : tensor<?xf32>
// }
//
// func @caller() -> () {
//   %t0 = ... : tensor<?xf32>
//   %1 = call @callee(%t0) : (tensor<?xf32>) -> (tensor<?xf32>)
//   %2 = tensor.extract %t0[...] : tensor<?xf32>
// }
// ```
//
// Note: If a function is external, `funcOpBbArgReadWriteAnalysis` cannot
// analyze the function body. In such a case, the CallOp analysis conservatively
// assumes that each tensor OpOperand is both read and written.
//
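// For illustration, a sketch (all names below are made up): because `@ext` is
// bodiless and carries no access attributes, the analysis assumes it may read
// and write `%t`. Since `%t` is also read after the call, the call operand
// must bufferize out-of-place per the calling convention above:
// ```
// func.func private @ext(%t : tensor<?xf32>) -> tensor<?xf32>
//
// func.func @caller(%t : tensor<?xf32>) -> f32 {
//   %c0 = arith.constant 0 : index
//   %0 = call @ext(%t) : (tensor<?xf32>) -> (tensor<?xf32>)
//   %1 = tensor.extract %t[%c0] : tensor<?xf32>
//   return %1 : f32
// }
// ```
//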
// TODO: Add FuncOp attributes so that bbArgs of external FuncOps can be marked
// as "not reading" and/or "not writing".

#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::bufferization::func_ext;

/// A mapping of FuncOps to their callers.
using FuncCallerMap = DenseMap<func::FuncOp, DenseSet<Operation *>>;

/// Get or create FuncAnalysisState.
static FuncAnalysisState &
getOrCreateFuncAnalysisState(OneShotAnalysisState &state) {
  auto *result = state.getExtension<FuncAnalysisState>();
  if (result)
    return *result;
  return state.addExtension<FuncAnalysisState>();
}

namespace {

/// Annotate IR with the results of the analysis. For testing purposes only.
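/// For example (a sketch; the exact test-output rendering may differ): after
/// analyzing
/// ```
/// func.func @foo(%t0 : tensor<?xf32>, %t1 : tensor<?xf32>) -> tensor<?xf32> {
///   return %t1 : tensor<?xf32>
/// }
/// ```
/// the func.return op carries `__equivalent_func_args__ = [1]`: operand #0 of
/// the return is equivalent to bbArg #1. Return operands without an equivalent
/// bbArg keep the placeholder value -1.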
static void annotateEquivalentReturnBbArg(OpOperand &returnVal,
                                          BlockArgument bbArg) {
  const char *kEquivalentArgsAttr = "__equivalent_func_args__";
  Operation *op = returnVal.getOwner();

  SmallVector<int64_t> equivBbArgs;
  if (op->hasAttr(kEquivalentArgsAttr)) {
    auto attr = cast<ArrayAttr>(op->getAttr(kEquivalentArgsAttr));
    equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) {
      return cast<IntegerAttr>(a).getValue().getSExtValue();
    }));
  } else {
    equivBbArgs.append(op->getNumOperands(), -1);
  }
  equivBbArgs[returnVal.getOperandNumber()] = bbArg.getArgNumber();

  OpBuilder b(op->getContext());
  op->setAttr(kEquivalentArgsAttr, b.getI64ArrayAttr(equivBbArgs));
}

/// Store function BlockArguments that are equivalent to/aliasing a returned
/// value in FuncAnalysisState.
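/// To illustrate the distinction (a sketch): in
/// ```
/// func.func @id(%t : tensor<5xf32>) -> tensor<5xf32> {
///   return %t : tensor<5xf32>
/// }
/// ```
/// return value #0 is equivalent to (and therefore also aliases) bbArg #0.
/// A return value produced by, e.g., a tensor.extract_slice of a bbArg would
/// alias that bbArg without being equivalent to it.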
static LogicalResult
aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
                             FuncAnalysisState &funcState) {
  if (funcOp.getBody().empty()) {
    // No function body available. Conservatively assume that every tensor
    // return value may alias with any tensor bbArg.
    FunctionType type = funcOp.getFunctionType();
    for (const auto &inputIt : llvm::enumerate(type.getInputs())) {
      if (!isa<TensorType>(inputIt.value()))
        continue;
      for (const auto &resultIt : llvm::enumerate(type.getResults())) {
        if (!isa<TensorType>(resultIt.value()))
          continue;
        int64_t returnIdx = resultIt.index();
        int64_t bbArgIdx = inputIt.index();
        funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx);
      }
    }
    return success();
  }

  // Find all func.return ops.
  SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
  assert(!returnOps.empty() && "expected at least one ReturnOp");

  // Build alias sets. Merge all aliases from all func.return ops.
  for (BlockArgument bbArg : funcOp.getArguments()) {
    if (isa<RankedTensorType>(bbArg.getType())) {
      int64_t bbArgIdx = bbArg.getArgNumber();
      // Store aliases in a set, so that we don't add the same alias twice.
      SetVector<int64_t> aliases;
      for (func::ReturnOp returnOp : returnOps) {
        for (OpOperand &returnVal : returnOp->getOpOperands()) {
          if (isa<RankedTensorType>(returnVal.get().getType())) {
            int64_t returnIdx = returnVal.getOperandNumber();
            if (state.areAliasingBufferizedValues(returnVal.get(), bbArg))
              aliases.insert(returnIdx);
          }
        }
      }
      for (int64_t alias : aliases)
        funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(alias);
    }
  }

  // Build equivalence sets.
  // Helper function that finds an equivalent block argument index for the
  // given OpOperand. Return std::nullopt if no equivalent block argument could
  // be found.
  auto findEquivalentBlockArgIdx =
      [&](OpOperand &opOperand) -> std::optional<int64_t> {
    Value v = opOperand.get();
    if (!isa<TensorType>(v.getType()))
      return std::nullopt;
    for (BlockArgument bbArg : funcOp.getArguments()) {
      if (isa<RankedTensorType>(bbArg.getType())) {
        if (state.areEquivalentBufferizedValues(v, bbArg)) {
          if (state.getOptions().testAnalysisOnly)
            annotateEquivalentReturnBbArg(opOperand, bbArg);
          return bbArg.getArgNumber();
        }
      }
    }
    return std::nullopt;
  };

  int64_t numResults = returnOps.front()->getNumOperands();
  for (int64_t i = 0; i < numResults; ++i) {
    // Find the equivalent block argument index for the i-th operand of the
    // first func.return op.
    std::optional<int64_t> maybeEquiv =
        findEquivalentBlockArgIdx(returnOps.front()->getOpOperand(i));
    if (!maybeEquiv.has_value())
      continue;
    int64_t bbArgIdx = *maybeEquiv;
    bool allEquiv = true;

    // Check if all other func.return ops have the same equivalent block
    // argument for the i-th operand. In contrast to aliasing information,
    // which is just "merged", equivalence information must match across all
    // func.return ops.
    for (func::ReturnOp returnOp : ArrayRef(returnOps).drop_front()) {
      std::optional<int64_t> maybeEquiv =
          findEquivalentBlockArgIdx(returnOp->getOpOperand(i));
      if (maybeEquiv != bbArgIdx) {
        allEquiv = false;
        break;
      }
    }

    // All func.return ops have the same equivalent block argument for the i-th
    // operand.
    if (allEquiv)
      funcState.equivalentFuncArgs[funcOp][i] = bbArgIdx;
  }

  return success();
}

static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx, bool isRead,
                                  bool isWritten) {
  OpBuilder b(funcOp.getContext());
  Attribute accessType;
  if (isRead && isWritten) {
    accessType = b.getStringAttr("read-write");
  } else if (isRead) {
    accessType = b.getStringAttr("read");
  } else if (isWritten) {
    accessType = b.getStringAttr("write");
  } else {
    accessType = b.getStringAttr("none");
  }
  funcOp.setArgAttr(idx, BufferizationDialect::kBufferAccessAttrName,
                    accessType);
}

/// Determine which FuncOp bbArgs are read and which are written. When run on a
/// function with unknown ops, we conservatively assume that such ops bufferize
/// to a read + write.
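///
/// When `testAnalysisOnly` is set, the result is also annotated on the
/// function arguments, e.g. (a sketch, assuming `kBufferAccessAttrName`
/// prints as "bufferization.access"):
/// ```
/// func.func @f(%t : tensor<?xf32> {bufferization.access = "read"})
/// ```
/// As implemented below, a pre-existing access attribute takes precedence
/// over the body analysis, so it can also be used to refine the conservative
/// treatment of bodiless functions.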
static LogicalResult
funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
                             FuncAnalysisState &funcState) {
  for (int64_t idx = 0, e = funcOp.getFunctionType().getNumInputs(); idx < e;
       ++idx) {
    // Skip non-tensor arguments.
    if (!isa<TensorType>(funcOp.getFunctionType().getInput(idx)))
      continue;
    bool isRead;
    bool isWritten;
    if (auto accessAttr = funcOp.getArgAttrOfType<StringAttr>(
            idx, BufferizationDialect::kBufferAccessAttrName)) {
      // Buffer access behavior is specified on the function. Skip the
      // analysis.
      StringRef str = accessAttr.getValue();
      isRead = str == "read" || str == "read-write";
      isWritten = str == "write" || str == "read-write";
    } else if (funcOp.getBody().empty()) {
      // If the function has no body, conservatively assume that all args are
      // read + written.
      isRead = true;
      isWritten = true;
    } else {
      // Analyze the body of the function.
      BlockArgument bbArg = funcOp.getArgument(idx);
      isRead = state.isValueRead(bbArg);
      isWritten = state.isValueWritten(bbArg);
    }

    if (state.getOptions().testAnalysisOnly)
      annotateFuncArgAccess(funcOp, idx, isRead, isWritten);
    if (isRead)
      funcState.readBbArgs[funcOp].insert(idx);
    if (isWritten)
      funcState.writtenBbArgs[funcOp].insert(idx);
  }

  return success();
}
} // namespace

/// Remove bufferization attributes on FuncOp arguments.
static void removeBufferizationAttributes(BlockArgument bbArg) {
  auto funcOp = cast<func::FuncOp>(bbArg.getOwner()->getParentOp());
  funcOp.removeArgAttr(bbArg.getArgNumber(),
                       BufferizationDialect::kBufferLayoutAttrName);
  funcOp.removeArgAttr(bbArg.getArgNumber(),
                       BufferizationDialect::kWritableAttrName);
}

/// Return the func::FuncOp called by `callOp`.
static func::FuncOp
getCalledFunction(func::CallOp callOp,
                  mlir::SymbolTableCollection &symbolTable) {
  SymbolRefAttr sym =
      llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
  if (!sym)
    return nullptr;
  return dyn_cast_or_null<func::FuncOp>(
      symbolTable.lookupNearestSymbolFrom(callOp, sym));
}

/// Return "true" if the given function signature has tensor semantics.
static bool hasTensorSignature(func::FuncOp funcOp) {
  return llvm::any_of(funcOp.getFunctionType().getInputs(),
                      llvm::IsaPred<TensorType>) ||
         llvm::any_of(funcOp.getFunctionType().getResults(),
                      llvm::IsaPred<TensorType>);
}

/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
/// callee-caller order (i.e., callees without callers first). Store all
/// remaining functions (i.e., the ones that call each other recursively) in
/// `remainingFuncOps`. Does not traverse nested symbol tables.
///
/// Store the map of FuncOp to all its callers in `callerMap`.
///
/// Return `failure()` if we are unable to retrieve the called FuncOp from
/// any func::CallOp.
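///
/// For example (a sketch), given
/// ```
/// func.func @inner(%t : tensor<?xf32>) -> tensor<?xf32> { ... }
/// func.func @outer(%t : tensor<?xf32>) -> tensor<?xf32> {
///   %0 = call @inner(%t) : (tensor<?xf32>) -> (tensor<?xf32>)
///   return %0 : tensor<?xf32>
/// }
/// ```
/// `orderedFuncOps` is [@inner, @outer] and `callerMap` maps @inner to the
/// CallOp inside @outer. Two functions that call each other would both end
/// up in `remainingFuncOps` instead.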
static LogicalResult getFuncOpsOrderedByCalls(
    Operation *moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
    SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap,
    SymbolTableCollection &symbolTables) {
  // For each FuncOp, the set of functions called by it (i.e. the union of
  // symbols of all nested func::CallOp).
  DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
  // For each FuncOp, the number of distinct callees (with tensor semantics)
  // that it contains calls to.
  DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
  for (mlir::Region &region : moduleOp->getRegions()) {
    for (mlir::Block &block : region.getBlocks()) {
      for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) {
        // Collect function calls and populate the caller map.
        numberCallOpsContainedInFuncOp[funcOp] = 0;
        WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
          func::FuncOp calledFunction = getCalledFunction(callOp, symbolTables);
          assert(calledFunction && "could not retrieve called func::FuncOp");
          // If the called function does not have any tensors in its signature,
          // then it is not necessary to bufferize the callee before the
          // caller.
          if (!hasTensorSignature(calledFunction))
            return WalkResult::skip();

          callerMap[calledFunction].insert(callOp);
          if (calledBy[calledFunction].insert(funcOp).second) {
            numberCallOpsContainedInFuncOp[funcOp]++;
          }
          return WalkResult::advance();
        });
        if (res.wasInterrupted())
          return failure();
      }
    }
  }

  // Iteratively remove function operations that do not call any of the
  // functions remaining in the callCounter map and add them to the ordered
  // list.
  SmallVector<func::FuncOp> worklist;

  for (const auto &entry : numberCallOpsContainedInFuncOp) {
    if (entry.second == 0)
      worklist.push_back(entry.first);
  }

  while (!worklist.empty()) {
    func::FuncOp func = worklist.pop_back_val();
    orderedFuncOps.push_back(func);

    for (func::FuncOp caller : calledBy[func]) {
      auto &count = numberCallOpsContainedInFuncOp[caller];

      if (--count == 0)
        worklist.push_back(caller);
    }

    numberCallOpsContainedInFuncOp.erase(func);
  }

  // Put all other functions in the list of remaining functions. These are
  // functions that call each other circularly.
  for (auto it : numberCallOpsContainedInFuncOp)
    remainingFuncOps.push_back(it.first);

  return success();
}

/// Helper function that extracts the source from a memref.cast. If the given
/// value is not a memref.cast result, simply returns the given value.
static Value unpackCast(Value v) {
  auto castOp = v.getDefiningOp<memref::CastOp>();
  if (!castOp)
    return v;
  return castOp.getSource();
}

/// Helper function that returns the return types (skipping casts) of the given
/// func.return ops. This function returns as many types as the return ops have
/// operands. If the i-th operand is not the same for all func.return ops, then
/// the i-th returned type is an "empty" type.
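///
/// For example (a sketch): if every func.return op yields, at position i, a
/// `memref.cast %src : memref<10xf32> to memref<?xf32>`, the i-th computed
/// type is memref<10xf32> (the type of the cast source). If the unpacked
/// types disagree across the return ops, Type() is returned for that
/// position.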
static SmallVector<Type> getReturnTypes(SmallVector<func::ReturnOp> returnOps) {
  assert(!returnOps.empty() && "expected at least one ReturnOp");
  int numOperands = returnOps.front()->getNumOperands();

  // Helper function that unpacks memref.cast ops and returns the type.
  auto getSourceType = [&](Value v) { return unpackCast(v).getType(); };

  SmallVector<Type> result;
  for (int i = 0; i < numOperands; ++i) {
    // Get the type of the i-th operand of the first func.return op.
    Type t = getSourceType(returnOps.front()->getOperand(i));

    // Check if all other func.return ops have a matching operand type.
    for (int j = 1; j < static_cast<int>(returnOps.size()); ++j)
      if (getSourceType(returnOps[j]->getOperand(i)) != t)
        t = Type();

    result.push_back(t);
  }

  return result;
}

/// Fold return values that are memref casts and update function return types.
///
/// During FuncOp bufferization, the exact type of the returned memrefs (if any)
/// is not known yet. Therefore, the bufferization uses memref types with the
/// most generic layout map as function return types. After bufferizing the
/// entire function body, a more concise memref type can potentially be used for
/// the return type of the function.
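///
/// For example (a sketch): a bufferized function such as
/// ```
/// func.func @f() -> memref<?xf32, strided<[?], offset: ?>> {
///   %alloc = memref.alloc() : memref<10xf32>
///   %cast = memref.cast %alloc
///       : memref<10xf32> to memref<?xf32, strided<[?], offset: ?>>
///   return %cast : memref<?xf32, strided<[?], offset: ?>>
/// }
/// ```
/// is rewritten to return `%alloc` directly, and the function type is
/// tightened to `() -> memref<10xf32>`.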
static void foldMemRefCasts(func::FuncOp funcOp) {
  // There is nothing to do for bodiless ops.
  if (funcOp.getBody().empty())
    return;

  // Compute the common result types of all return ops.
  SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
  SmallVector<Type> resultTypes = getReturnTypes(returnOps);

  // Remove direct casts.
  for (func::ReturnOp returnOp : returnOps) {
    for (OpOperand &operand : returnOp->getOpOperands()) {
      // Only fold the cast if a common result type was found for this
      // position.
      if (resultTypes[operand.getOperandNumber()]) {
        operand.set(unpackCast(operand.get()));
      }
    }
  }

  // Fill in the missing result types that were not the same among all
  // func.return ops.
  for (int i = 0; i < static_cast<int>(resultTypes.size()); ++i) {
    if (resultTypes[i])
      continue;
    resultTypes[i] = funcOp.getFunctionType().getResult(i);
  }

  // Update the function type.
  auto newFuncType = FunctionType::get(
      funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
  funcOp.setType(newFuncType);
}

LogicalResult
mlir::bufferization::analyzeModuleOp(Operation *moduleOp,
                                     OneShotAnalysisState &state,
                                     BufferizationStatistics *statistics) {
  assert(state.getOptions().bufferizeFunctionBoundaries &&
         "expected that function boundary bufferization is activated");
  FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);

  // A list of non-circular functions in the order in which they are analyzed
  // and bufferized.
  SmallVector<func::FuncOp> orderedFuncOps;
  // A list of all other functions. I.e., functions that call each other
  // recursively. For these, we analyze the function body but not the function
  // boundary.
  SmallVector<func::FuncOp> remainingFuncOps;

  // A mapping of FuncOps to their callers.
  FuncCallerMap callerMap;

  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
                                      remainingFuncOps, callerMap,
                                      funcState.symbolTables)))
    return failure();

  // Analyze functions in order, starting with functions that do not call any
  // other functions.
  for (func::FuncOp funcOp : orderedFuncOps) {
    if (!state.getOptions().isOpAllowed(funcOp))
      continue;

    // Mark the function as currently being analyzed.
    funcState.startFunctionAnalysis(funcOp);

    // Analyze funcOp.
    if (failed(analyzeOp(funcOp, state, statistics)))
      return failure();

    // Run some extra function analyses.
    if (failed(aliasingFuncOpBBArgsAnalysis(funcOp, state, funcState)) ||
        failed(funcOpBbArgReadWriteAnalysis(funcOp, state, funcState)))
      return failure();

    // Mark op as fully analyzed.
    funcState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed;
  }

  // Analyze all other functions. All function boundary analyses are skipped.
  for (func::FuncOp funcOp : remainingFuncOps) {
    if (!state.getOptions().isOpAllowed(funcOp))
      continue;

    // Analyze funcOp.
    if (failed(analyzeOp(funcOp, state, statistics)))
      return failure();

    // TODO: We currently skip all function argument analyses for functions
    // that call each other circularly. These analyses do not support recursive
    // calls yet. The `BufferizableOpInterface` implementations of `func`
    // dialect ops return conservative results in the absence of analysis
    // information.
  }

  return success();
}

void mlir::bufferization::removeBufferizationAttributesInModule(
    Operation *moduleOp) {
  for (mlir::Region &region : moduleOp->getRegions()) {
    for (mlir::Block &block : region.getBlocks()) {
      for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) {
        for (BlockArgument bbArg : funcOp.getArguments())
          removeBufferizationAttributes(bbArg);
      }
    }
  }
}

LogicalResult mlir::bufferization::bufferizeModuleOp(
    Operation *moduleOp, const OneShotBufferizationOptions &options,
    BufferizationState &state, BufferizationStatistics *statistics) {
  assert(options.bufferizeFunctionBoundaries &&
         "expected that function boundary bufferization is activated");
  IRRewriter rewriter(moduleOp->getContext());

  // A list of non-circular functions in the order in which they are analyzed
  // and bufferized.
  SmallVector<func::FuncOp> orderedFuncOps;
  // A list of all other functions. I.e., functions that call each other
  // recursively. For these, we analyze the function body but not the function
  // boundary.
  SmallVector<func::FuncOp> remainingFuncOps;

  // A mapping of FuncOps to their callers.
  FuncCallerMap callerMap;

  // Try to bufferize functions in calling order, i.e., first bufferize
  // functions that do not call other functions. This allows us to infer
  // accurate buffer types for function return values. Functions that call
  // each other recursively are bufferized in an unspecified order at the end.
  // We may use unnecessarily "complex" (in terms of layout map) buffer types.
  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
                                      remainingFuncOps, callerMap,
                                      state.getSymbolTables())))
    return failure();
  llvm::append_range(orderedFuncOps, remainingFuncOps);

  // Bufferize functions.
  for (func::FuncOp funcOp : orderedFuncOps) {
    // Note: It would be good to apply cleanups here but we cannot as aliasInfo
    // would be invalidated.

    if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getSymName())) {
      // This function was not analyzed and RaW conflicts were not resolved.
      // Buffer copies must be inserted before every write.
      OneShotBufferizationOptions updatedOptions = options;
      updatedOptions.copyBeforeWrite = true;
      if (failed(bufferizeOp(funcOp, updatedOptions, state, statistics)))
        return failure();
    } else {
      if (failed(bufferizeOp(funcOp, options, state, statistics)))
        return failure();
    }

    // Change buffer return types to more precise layout maps.
    if (options.inferFunctionResultLayout)
      foldMemRefCasts(funcOp);
  }

  // Bufferize all other ops.
  for (mlir::Region &region : moduleOp->getRegions()) {
    for (mlir::Block &block : region.getBlocks()) {
      for (mlir::Operation &op :
           llvm::make_early_inc_range(block.getOperations())) {
        // Functions were already bufferized.
        if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>())
          continue;
        if (failed(bufferizeOp(&op, options, state, statistics)))
          return failure();
      }
    }
  }

  // Post-pass cleanup of function argument attributes.
  removeBufferizationAttributesInModule(moduleOp);

  return success();
}

LogicalResult mlir::bufferization::runOneShotModuleBufferize(
    Operation *moduleOp, const OneShotBufferizationOptions &options,
    BufferizationState &state, BufferizationStatistics *statistics) {
  assert(options.bufferizeFunctionBoundaries &&
         "expected that function boundary bufferization is activated");
  assert(!(options.copyBeforeWrite && options.testAnalysisOnly) &&
         "invalid combination of bufferization flags");
  if (!options.copyBeforeWrite) {
    if (options.noAnalysisFuncFilter.empty()) {
      if (failed(insertTensorCopies(moduleOp, options, state, statistics)))
        return failure();
    } else {
      // FuncOps whose names are specified in options.noAnalysisFuncFilter will
      // not be analyzed. Ops in these FuncOps will not be analyzed either.
      OpFilter::Entry::FilterFn analysisFilterFn = [=](Operation *op) {
        auto func = dyn_cast<func::FuncOp>(op);
        if (!func)
          func = op->getParentOfType<func::FuncOp>();
        if (func)
          return llvm::is_contained(options.noAnalysisFuncFilter,
                                    func.getSymName());
        return false;
      };
      OneShotBufferizationOptions updatedOptions(options);
      updatedOptions.opFilter.denyOperation(analysisFilterFn);
      if (failed(
              insertTensorCopies(moduleOp, updatedOptions, state, statistics)))
        return failure();
    }
  }
  if (options.testAnalysisOnly)
    return success();
  if (failed(bufferizeModuleOp(moduleOp, options, state, statistics)))
    return failure();
  return success();
}