//===- OneShotModuleBufferize.cpp - Bufferization across Func. Boundaries ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Module Bufferization is an extension of One-Shot Bufferize that
// bufferizes function boundaries. It provides `BufferizableOpInterface`
// implementations for FuncOp, CallOp and ReturnOp. Although it is named
// Module Bufferization, it may operate on any SymbolTable.
//
// Module Bufferization is run via `runOneShotModuleBufferize(SymbolTableOp,
// ...)`. This function analyzes the given op and determines the order of
// analysis and bufferization: Functions that are called are processed before
// their respective callers.
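//
// A minimal usage sketch (assumptions: `moduleOp` is the op to bufferize, and
// default-constructed options/state suffice; error handling elided):
// ```
// OneShotBufferizationOptions options;
// options.bufferizeFunctionBoundaries = true;
// BufferizationState bufferizationState;
// if (failed(runOneShotModuleBufferize(moduleOp, options, bufferizationState)))
//   ...; // handle failure
// ```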
//
// After analyzing a FuncOp, additional information about its bbArgs is
// gathered and stored in `FuncAnalysisState`.
//
// * `aliasingFuncOpBBArgsAnalysis` determines the equivalent/aliasing bbArgs
//   for each tensor return value (if any).
// * `funcOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is
//   read/written.
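//
// For illustration, when `testAnalysisOnly` is set, the results of these
// analyses are attached to the IR roughly as follows (a sketch; the attributes
// are emitted by `annotateFuncArgAccess` and `annotateEquivalentReturnBbArg`
// below, and the access attribute name `bufferization.access` is assumed to be
// the value of `BufferizationDialect::kBufferAccessAttrName`):
// ```
// func @callee(%t1 : tensor<?xf32> {bufferization.access = "read-write"})
//     -> tensor<?xf32> {
//   ...
//   return %0 {__equivalent_func_args__ = [0]} : tensor<?xf32>
// }
// ```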
//
// Module Bufferization implements the following calling convention.
//
// * In the absence of conflicts within a FuncOp, the FuncOp's bbArgs may
//   always be written to in-place.
// * If a tensor operand of a CallOp is read after the CallOp, the operand of
//   the CallOp must bufferize out-of-place.
//
// Example: The tensor.insert op bufferizes in-place because it is allowed to
// modify the buffer of `%t1` directly. The CallOp in `caller` must bufferize
// out-of-place because `%t0` is modified by the callee but read by the
// tensor.extract op. The analysis of CallOps decides whether an OpOperand must
// bufferize out-of-place based on the results of
// `funcOpBbArgReadWriteAnalysis`.
// ```
// func @callee(%t1 : tensor<?xf32>) -> tensor<?xf32> {
//   %f = ... : f32
//   %0 = tensor.insert %f into %t1[...] : tensor<?xf32>
//   return %0 : tensor<?xf32>
// }
//
// func @caller() -> () {
//   %t0 = ... : tensor<?xf32>
//   %1 = call @callee(%t0) : (tensor<?xf32>) -> (tensor<?xf32>)
//   %2 = tensor.extract %1[...] : tensor<?xf32>
// }
// ```
//
// Note: If a function is external, `funcOpBbArgReadWriteAnalysis` cannot
// analyze the function body. In such a case, the CallOp analysis conservatively
// assumes that each tensor OpOperand is both read and written.
//
// TODO: Add FuncOp attributes so that bbArgs of external FuncOps can be marked
// as "not reading" and/or "not writing".

#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::bufferization::func_ext;

/// A mapping of FuncOps to their callers.
using FuncCallerMap = DenseMap<func::FuncOp, DenseSet<Operation *>>;

/// Get or create FuncAnalysisState.
static FuncAnalysisState &
getOrCreateFuncAnalysisState(OneShotAnalysisState &state) {
  auto *result = state.getExtension<FuncAnalysisState>();
  if (result)
    return *result;
  return state.addExtension<FuncAnalysisState>();
}

namespace {

/// Annotate IR with the results of the analysis. For testing purposes only.
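/// Entries for return operands without an equivalent bbArg are set to -1.
/// E.g., a return op whose 0th operand is equivalent to bbArg #2 and whose 1st
/// operand has no equivalent bbArg would be annotated (illustrative output)
/// `{__equivalent_func_args__ = [2, -1]}`.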
static void annotateEquivalentReturnBbArg(OpOperand &returnVal,
                                          BlockArgument bbArg) {
  const char *kEquivalentArgsAttr = "__equivalent_func_args__";
  Operation *op = returnVal.getOwner();

  SmallVector<int64_t> equivBbArgs;
  if (op->hasAttr(kEquivalentArgsAttr)) {
    auto attr = cast<ArrayAttr>(op->getAttr(kEquivalentArgsAttr));
    equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) {
      return cast<IntegerAttr>(a).getValue().getSExtValue();
    }));
  } else {
    equivBbArgs.append(op->getNumOperands(), -1);
  }
  equivBbArgs[returnVal.getOperandNumber()] = bbArg.getArgNumber();

  OpBuilder b(op->getContext());
  op->setAttr(kEquivalentArgsAttr, b.getI64ArrayAttr(equivBbArgs));
}

/// Store function BlockArguments that are equivalent to/aliasing a returned
/// value in FuncAnalysisState.
static LogicalResult
aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
                             FuncAnalysisState &funcState) {
  if (funcOp.getBody().empty()) {
    // No function body available. Conservatively assume that every tensor
    // return value may alias with any tensor bbArg.
    FunctionType type = funcOp.getFunctionType();
    for (const auto &inputIt : llvm::enumerate(type.getInputs())) {
      if (!isa<TensorType>(inputIt.value()))
        continue;
      for (const auto &resultIt : llvm::enumerate(type.getResults())) {
        if (!isa<TensorType>(resultIt.value()))
          continue;
        int64_t returnIdx = resultIt.index();
        int64_t bbArgIdx = inputIt.index();
        funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx);
      }
    }
    return success();
  }

  // Find all func.return ops.
  SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
  assert(!returnOps.empty() && "expected at least one ReturnOp");

  // Build alias sets. Merge all aliases from all func.return ops.
  for (BlockArgument bbArg : funcOp.getArguments()) {
    if (isa<RankedTensorType>(bbArg.getType())) {
      int64_t bbArgIdx = bbArg.getArgNumber();
      // Store aliases in a set, so that we don't add the same alias twice.
      SetVector<int64_t> aliases;
      for (func::ReturnOp returnOp : returnOps) {
        for (OpOperand &returnVal : returnOp->getOpOperands()) {
          if (isa<RankedTensorType>(returnVal.get().getType())) {
            int64_t returnIdx = returnVal.getOperandNumber();
            if (state.areAliasingBufferizedValues(returnVal.get(), bbArg))
              aliases.insert(returnIdx);
          }
        }
      }
      for (int64_t alias : aliases)
        funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(alias);
    }
  }

  // Build equivalence sets.
  // Helper function that finds an equivalent block argument index for the
  // given OpOperand. Return std::nullopt if no equivalent block argument could
  // be found.
  auto findEquivalentBlockArgIdx =
      [&](OpOperand &opOperand) -> std::optional<int64_t> {
    Value v = opOperand.get();
    if (!isa<TensorType>(v.getType()))
      return std::nullopt;
    for (BlockArgument bbArg : funcOp.getArguments()) {
      if (isa<RankedTensorType>(bbArg.getType())) {
        if (state.areEquivalentBufferizedValues(v, bbArg)) {
          if (state.getOptions().testAnalysisOnly)
            annotateEquivalentReturnBbArg(opOperand, bbArg);
          return bbArg.getArgNumber();
        }
      }
    }
    return std::nullopt;
  };

  int64_t numResults = returnOps.front()->getNumOperands();
  for (int64_t i = 0; i < numResults; ++i) {
    // Find the equivalent block argument index for the i-th operand of the
    // first func.return op.
    std::optional<int64_t> maybeEquiv =
        findEquivalentBlockArgIdx(returnOps.front()->getOpOperand(i));
    if (!maybeEquiv.has_value())
      continue;
    int64_t bbArgIdx = *maybeEquiv;
    bool allEquiv = true;

    // Check if all other func.return ops have the same equivalent block
    // argument for the i-th operand. In contrast to aliasing information,
    // which is just "merged", equivalence information must match across all
    // func.return ops.
    for (func::ReturnOp returnOp : ArrayRef(returnOps).drop_front()) {
      std::optional<int64_t> maybeEquiv =
          findEquivalentBlockArgIdx(returnOp->getOpOperand(i));
      if (maybeEquiv != bbArgIdx) {
        allEquiv = false;
        break;
      }
    }

    // All func.return ops have the same equivalent block argument for the i-th
    // operand.
    if (allEquiv)
      funcState.equivalentFuncArgs[funcOp][i] = bbArgIdx;
  }

  return success();
}

static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx, bool isRead,
                                  bool isWritten) {
  OpBuilder b(funcOp.getContext());
  Attribute accessType;
  if (isRead && isWritten) {
    accessType = b.getStringAttr("read-write");
  } else if (isRead) {
    accessType = b.getStringAttr("read");
  } else if (isWritten) {
    accessType = b.getStringAttr("write");
  } else {
    accessType = b.getStringAttr("none");
  }
  funcOp.setArgAttr(idx, BufferizationDialect::kBufferAccessAttrName,
                    accessType);
}

/// Determine which FuncOp bbArgs are read and which are written. When run on a
/// function with unknown ops, we conservatively assume that such ops bufferize
/// to a read + write.
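///
/// Note that the access behavior can also be pre-specified per argument via
/// the arg attribute named by `BufferizationDialect::kBufferAccessAttrName`,
/// in which case the analysis is skipped for that argument. A sketch, assuming
/// the attribute name is `bufferization.access`:
/// ```
/// func.func private @ext(%t : tensor<?xf32> {bufferization.access = "read"})
/// ```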
static LogicalResult
funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
                             FuncAnalysisState &funcState) {
  for (int64_t idx = 0, e = funcOp.getFunctionType().getNumInputs(); idx < e;
       ++idx) {
    // Skip non-tensor arguments.
    if (!isa<TensorType>(funcOp.getFunctionType().getInput(idx)))
      continue;
    bool isRead;
    bool isWritten;
    if (auto accessAttr = funcOp.getArgAttrOfType<StringAttr>(
            idx, BufferizationDialect::kBufferAccessAttrName)) {
      // Buffer access behavior is specified on the function. Skip the
      // analysis.
      StringRef str = accessAttr.getValue();
      isRead = str == "read" || str == "read-write";
      isWritten = str == "write" || str == "read-write";
    } else if (funcOp.getBody().empty()) {
      // If the function has no body, conservatively assume that all args are
      // read + written.
      isRead = true;
      isWritten = true;
    } else {
      // Analyze the body of the function.
      BlockArgument bbArg = funcOp.getArgument(idx);
      isRead = state.isValueRead(bbArg);
      isWritten = state.isValueWritten(bbArg);
    }

    if (state.getOptions().testAnalysisOnly)
      annotateFuncArgAccess(funcOp, idx, isRead, isWritten);
    if (isRead)
      funcState.readBbArgs[funcOp].insert(idx);
    if (isWritten)
      funcState.writtenBbArgs[funcOp].insert(idx);
  }

  return success();
}
} // namespace

/// Remove bufferization attributes on FuncOp arguments.
static void removeBufferizationAttributes(BlockArgument bbArg) {
  auto funcOp = cast<func::FuncOp>(bbArg.getOwner()->getParentOp());
  funcOp.removeArgAttr(bbArg.getArgNumber(),
                       BufferizationDialect::kBufferLayoutAttrName);
  funcOp.removeArgAttr(bbArg.getArgNumber(),
                       BufferizationDialect::kWritableAttrName);
}

/// Return the func::FuncOp called by `callOp`.
static func::FuncOp
getCalledFunction(func::CallOp callOp,
                  mlir::SymbolTableCollection &symbolTable) {
  return dyn_cast_or_null<func::FuncOp>(
      callOp.resolveCallableInTable(&symbolTable));
}

/// Return "true" if the given function signature has tensor semantics.
static bool hasTensorSignature(func::FuncOp funcOp) {
  return llvm::any_of(funcOp.getFunctionType().getInputs(),
                      llvm::IsaPred<TensorType>) ||
         llvm::any_of(funcOp.getFunctionType().getResults(),
                      llvm::IsaPred<TensorType>);
}

/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
/// callee-caller order (i.e., callees without callers first). Store all
/// remaining functions (i.e., the ones that call each other recursively) in
/// `remainingFuncOps`. Does not traverse nested symbol tables.
///
/// Store the map of FuncOp to all its callers in `callerMap`.
///
/// Return `failure()` if we are unable to retrieve the called FuncOp from
/// any func::CallOp.
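///
/// For illustration: if `@a` calls `@b` and `@b` calls no one, the resulting
/// `orderedFuncOps` is `[@b, @a]`. If `@a` and `@b` called each other, both
/// would end up in `remainingFuncOps` instead.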
static LogicalResult getFuncOpsOrderedByCalls(
    Operation *moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
    SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap,
    SymbolTableCollection &symbolTables) {
  // For each FuncOp, the set of functions called by it (i.e. the union of
  // symbols of all nested func::CallOp).
  DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
  // For each FuncOp, the number of func::CallOp it contains.
  DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
  for (mlir::Region &region : moduleOp->getRegions()) {
    for (mlir::Block &block : region.getBlocks()) {
      for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) {
        // Collect function calls and populate the caller map.
        numberCallOpsContainedInFuncOp[funcOp] = 0;
        WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
          func::FuncOp calledFunction = getCalledFunction(callOp, symbolTables);
          assert(calledFunction && "could not retrieve called func::FuncOp");
          // If the called function does not have any tensors in its signature,
          // then it is not necessary to bufferize the callee before the
          // caller.
          if (!hasTensorSignature(calledFunction))
            return WalkResult::skip();

          callerMap[calledFunction].insert(callOp);
          if (calledBy[calledFunction].insert(funcOp).second) {
            numberCallOpsContainedInFuncOp[funcOp]++;
          }
          return WalkResult::advance();
        });
        if (res.wasInterrupted())
          return failure();
      }
    }
  }

  // Iteratively remove function operations that do not call any of the
  // functions remaining in the callCounter map and add them to the ordered
  // list.
  SmallVector<func::FuncOp> worklist;

  for (const auto &entry : numberCallOpsContainedInFuncOp) {
    if (entry.second == 0)
      worklist.push_back(entry.first);
  }

  while (!worklist.empty()) {
    func::FuncOp func = worklist.pop_back_val();
    orderedFuncOps.push_back(func);

    for (func::FuncOp caller : calledBy[func]) {
      auto &count = numberCallOpsContainedInFuncOp[caller];

      if (--count == 0)
        worklist.push_back(caller);
    }

    numberCallOpsContainedInFuncOp.erase(func);
  }

  // Put all other functions in the list of remaining functions. These are
  // functions that call each other circularly.
  for (auto it : numberCallOpsContainedInFuncOp)
    remainingFuncOps.push_back(it.first);

  return success();
}

/// Helper function that extracts the source from a memref.cast. If the given
/// value is not a memref.cast result, simply returns the given value.
static Value unpackCast(Value v) {
  auto castOp = v.getDefiningOp<memref::CastOp>();
  if (!castOp)
    return v;
  return castOp.getSource();
}

/// Helper function that returns the return types (skipping casts) of the given
/// func.return ops. This function returns as many types as the return ops have
/// operands. If the i-th operand is not the same for all func.return ops, then
/// the i-th returned type is an "empty" type.
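///
/// For illustration: given two return ops whose i-th operands have (after
/// skipping casts) types `memref<10xf32>` and `memref<20xf32>`, the i-th
/// returned type is the "empty" `Type()`.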
static SmallVector<Type> getReturnTypes(SmallVector<func::ReturnOp> returnOps) {
  assert(!returnOps.empty() && "expected at least one ReturnOp");
  int numOperands = returnOps.front()->getNumOperands();

  // Helper function that unpacks memref.cast ops and returns the type.
  auto getSourceType = [&](Value v) { return unpackCast(v).getType(); };

  SmallVector<Type> result;
  for (int i = 0; i < numOperands; ++i) {
    // Get the type of the i-th operand of the first func.return op.
    Type t = getSourceType(returnOps.front()->getOperand(i));

    // Check if all other func.return ops have a matching operand type.
    for (int j = 1; j < static_cast<int>(returnOps.size()); ++j)
      if (getSourceType(returnOps[j]->getOperand(i)) != t)
        t = Type();

    result.push_back(t);
  }

  return result;
}

/// Fold return values that are memref casts and update function return types.
///
/// During FuncOp bufferization, the exact type of the returned memrefs (if
/// any) is not known yet. Therefore, the bufferization uses memref types with
/// the most generic layout map as function return types. After bufferizing the
/// entire function body, a more concise memref type can potentially be used
/// for the return type of the function.
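///
/// A sketch of the rewrite (illustrative types):
/// ```
/// %c = memref.cast %m : memref<10xf32> to memref<?xf32, strided<[?], offset: ?>>
/// return %c : memref<?xf32, strided<[?], offset: ?>>
/// ```
/// becomes `return %m : memref<10xf32>`, and the function result type is
/// updated to `memref<10xf32>`.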
static void foldMemRefCasts(func::FuncOp funcOp) {
  // There is nothing to do for bodiless ops.
  if (funcOp.getBody().empty())
    return;

  // Compute the common result types of all return ops.
  SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
  SmallVector<Type> resultTypes = getReturnTypes(returnOps);

  // Remove direct casts.
  for (func::ReturnOp returnOp : returnOps) {
    for (OpOperand &operand : returnOp->getOpOperands()) {
      // Skip the operand if no common result type was found.
      if (resultTypes[operand.getOperandNumber()]) {
        operand.set(unpackCast(operand.get()));
      }
    }
  }

  // Fill in the missing result types that were not the same among all
  // func.return ops.
  for (int i = 0; i < static_cast<int>(resultTypes.size()); ++i) {
    if (resultTypes[i])
      continue;
    resultTypes[i] = funcOp.getFunctionType().getResult(i);
  }

  // Update the function type.
  auto newFuncType = FunctionType::get(
      funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
  funcOp.setType(newFuncType);
}

LogicalResult
mlir::bufferization::analyzeModuleOp(Operation *moduleOp,
                                     OneShotAnalysisState &state,
                                     BufferizationStatistics *statistics) {
  assert(state.getOptions().bufferizeFunctionBoundaries &&
         "expected that function boundary bufferization is activated");
  FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);

  // A list of non-circular functions in the order in which they are analyzed
  // and bufferized.
  SmallVector<func::FuncOp> orderedFuncOps;
  // A list of all other functions, i.e., functions that call each other
  // recursively. For these, we analyze the function body but not the function
  // boundary.
  SmallVector<func::FuncOp> remainingFuncOps;

  // A mapping of FuncOps to their callers.
  FuncCallerMap callerMap;

  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
                                      remainingFuncOps, callerMap,
                                      funcState.symbolTables)))
    return failure();

  // Analyze functions in order, starting with functions that do not call any
  // other functions.
  for (func::FuncOp funcOp : orderedFuncOps) {
    if (!state.getOptions().isOpAllowed(funcOp))
      continue;

    // Now analyzing function.
    funcState.startFunctionAnalysis(funcOp);

    // Analyze funcOp.
    if (failed(analyzeOp(funcOp, state, statistics)))
      return failure();

    // Run some extra function analyses.
    if (failed(aliasingFuncOpBBArgsAnalysis(funcOp, state, funcState)) ||
        failed(funcOpBbArgReadWriteAnalysis(funcOp, state, funcState)))
      return failure();

    // Mark op as fully analyzed.
    funcState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed;
  }

  // Analyze all other functions. All function boundary analyses are skipped.
  for (func::FuncOp funcOp : remainingFuncOps) {
    if (!state.getOptions().isOpAllowed(funcOp))
      continue;

    // Analyze funcOp.
    if (failed(analyzeOp(funcOp, state, statistics)))
      return failure();

    // TODO: We currently skip all function argument analyses for functions
    // that call each other circularly. These analyses do not support recursive
    // calls yet. The `BufferizableOpInterface` implementations of `func`
    // dialect ops return conservative results in the absence of analysis
    // information.
  }

  return success();
}

void mlir::bufferization::removeBufferizationAttributesInModule(
    Operation *moduleOp) {
  for (mlir::Region &region : moduleOp->getRegions()) {
    for (mlir::Block &block : region.getBlocks()) {
      for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) {
        for (BlockArgument bbArg : funcOp.getArguments())
          removeBufferizationAttributes(bbArg);
      }
    }
  }
}

LogicalResult mlir::bufferization::bufferizeModuleOp(
    Operation *moduleOp, const OneShotBufferizationOptions &options,
    BufferizationState &state, BufferizationStatistics *statistics) {
  assert(options.bufferizeFunctionBoundaries &&
         "expected that function boundary bufferization is activated");
  IRRewriter rewriter(moduleOp->getContext());

  // A list of non-circular functions in the order in which they are analyzed
  // and bufferized.
  SmallVector<func::FuncOp> orderedFuncOps;
  // A list of all other functions, i.e., functions that call each other
  // recursively. For these, we analyze the function body but not the function
  // boundary.
  SmallVector<func::FuncOp> remainingFuncOps;

  // A mapping of FuncOps to their callers.
  FuncCallerMap callerMap;

  // Try to bufferize functions in calling order, i.e., first bufferize
  // functions that do not call other functions. This allows us to infer
  // accurate buffer types for function return values. Functions that call
  // each other recursively are bufferized in an unspecified order at the end.
  // We may use unnecessarily "complex" (in terms of layout map) buffer types.
  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
                                      remainingFuncOps, callerMap,
                                      state.getSymbolTables())))
    return failure();
  llvm::append_range(orderedFuncOps, remainingFuncOps);

  // Bufferize functions.
  for (func::FuncOp funcOp : orderedFuncOps) {
    // Note: It would be good to apply cleanups here but we cannot as aliasInfo
    // would be invalidated.

    if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getSymName())) {
      // This function was not analyzed and RaW conflicts were not resolved.
      // Buffer copies must be inserted before every write.
      OneShotBufferizationOptions updatedOptions = options;
      updatedOptions.copyBeforeWrite = true;
      if (failed(bufferizeOp(funcOp, updatedOptions, state, statistics)))
        return failure();
    } else {
      if (failed(bufferizeOp(funcOp, options, state, statistics)))
        return failure();
    }

    // Change buffer return types to more precise layout maps.
    if (options.inferFunctionResultLayout)
      foldMemRefCasts(funcOp);
  }

  // Bufferize all other ops.
  for (mlir::Region &region : moduleOp->getRegions()) {
    for (mlir::Block &block : region.getBlocks()) {
      for (mlir::Operation &op :
           llvm::make_early_inc_range(block.getOperations())) {
        // Functions were already bufferized.
        if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>())
          continue;
        if (failed(bufferizeOp(&op, options, state, statistics)))
          return failure();
      }
    }
  }

  // Post-pass cleanup of function argument attributes.
  removeBufferizationAttributesInModule(moduleOp);

  return success();
}

LogicalResult mlir::bufferization::runOneShotModuleBufferize(
    Operation *moduleOp, const OneShotBufferizationOptions &options,
    BufferizationState &state, BufferizationStatistics *statistics) {
  assert(options.bufferizeFunctionBoundaries &&
         "expected that function boundary bufferization is activated");
  assert(!(options.copyBeforeWrite && options.testAnalysisOnly) &&
         "invalid combination of bufferization flags");
  if (!options.copyBeforeWrite) {
    if (options.noAnalysisFuncFilter.empty()) {
      if (failed(insertTensorCopies(moduleOp, options, state, statistics)))
        return failure();
    } else {
      // FuncOps whose names are specified in options.noAnalysisFuncFilter will
      // not be analyzed. Ops within these FuncOps will not be analyzed either.
      OpFilter::Entry::FilterFn analysisFilterFn = [=](Operation *op) {
        auto func = dyn_cast<func::FuncOp>(op);
        if (!func)
          func = op->getParentOfType<func::FuncOp>();
        if (func)
          return llvm::is_contained(options.noAnalysisFuncFilter,
                                    func.getSymName());
        return false;
      };
      OneShotBufferizationOptions updatedOptions(options);
      updatedOptions.opFilter.denyOperation(analysisFilterFn);
      if (failed(
              insertTensorCopies(moduleOp, updatedOptions, state, statistics)))
        return failure();
    }
  }
  if (options.testAnalysisOnly)
    return success();
  if (failed(bufferizeModuleOp(moduleOp, options, state, statistics)))
    return failure();
  return success();
}