//===- OneShotModuleBufferize.cpp - Bufferization across Func. Boundaries ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Module Bufferization is an extension of One-Shot Bufferize that
// bufferizes function boundaries. It provides `BufferizableOpInterface`
// implementations for FuncOp, CallOp and ReturnOp.
//
// Module Bufferization is run via `runOneShotModuleBufferize(ModuleOp, ...)`.
// This function analyzes the given module and determines the order of analysis
// and bufferization: Functions that are called are processed before their
// respective callers.
//
// After analyzing a FuncOp, additional information about its bbArgs is
// gathered and stored in `FuncAnalysisState`.
//
// * `aliasingFuncOpBBArgsAnalysis` determines the equivalent/aliasing bbArgs
//   for each tensor return value (if any).
// * `funcOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is
//   read/written.
//
// Module Bufferization implements the following calling convention.
//
// * In the absence of conflicts within a FuncOp, the FuncOp's bbArgs may
//   always be written to in-place.
// * If a tensor operand of a CallOp is read after the CallOp, the operand of
//   the CallOp must bufferize out-of-place.
//
// Example: The tensor.insert op bufferizes in-place because it is allowed to
// modify the buffer of `%t1` directly. The CallOp in `caller` must bufferize
// out-of-place because `%t0` is modified by the callee but read by the
// tensor.extract op. The analysis of CallOps decides whether an OpOperand must
// bufferize out-of-place based on results of `funcOpBbArgReadWriteAnalysis`.
// ```
// func @callee(%t1 : tensor<?xf32>) -> tensor<?xf32> {
//   %f = ... : f32
//   %0 = tensor.insert %f into %t1[...] : tensor<?xf32>
//   return %0 : tensor<?xf32>
// }
//
// func @caller() -> () {
//   %t0 = ... : tensor<?xf32>
//   %1 = call @callee(%t0) : (tensor<?xf32>) -> (tensor<?xf32>)
//   %2 = tensor.extract %1[...] : tensor<?xf32>
// }
// ```
//
// Note: If a function is external, `funcOpBbArgReadWriteAnalysis` cannot
// analyze the function body. In such a case, the CallOp analysis conservatively
// assumes that each tensor OpOperand is both read and written.
//
// TODO: Add FuncOp attributes so that bbArgs of external FuncOps can be marked
// as "not reading" and/or "not writing".
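//
// Illustrative sketch: `funcOpBbArgReadWriteAnalysis` below already consults
// the `bufferization.access` arg attribute (the assumed spelling of
// `BufferizationDialect::kBufferAccessAttrName`), so an external function
// whose tensor bbArg is only read could be declared as:
// ```
// func.func private @ext(%t : tensor<?xf32> {bufferization.access = "read"})
//     -> tensor<?xf32>
// ```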

#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::bufferization::func_ext;

/// A mapping of FuncOps to their callers.
using FuncCallerMap = DenseMap<func::FuncOp, DenseSet<Operation *>>;

/// Get or create FuncAnalysisState.
static FuncAnalysisState &
getOrCreateFuncAnalysisState(OneShotAnalysisState &state) {
  auto *result = state.getExtension<FuncAnalysisState>();
  if (result)
    return *result;
  return state.addExtension<FuncAnalysisState>();
}

namespace {

/// Annotate IR with the results of the analysis. For testing purposes only.
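/// E.g. (illustrative): a func.return whose operand #0 was found equivalent
/// to bbArg #1 is tagged with `__equivalent_func_args__ = [1, ...]`, where -1
/// marks operands without an equivalent bbArg.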
static void annotateEquivalentReturnBbArg(OpOperand &returnVal,
                                          BlockArgument bbArg) {
  const char *kEquivalentArgsAttr = "__equivalent_func_args__";
  Operation *op = returnVal.getOwner();

  SmallVector<int64_t> equivBbArgs;
  if (op->hasAttr(kEquivalentArgsAttr)) {
    auto attr = cast<ArrayAttr>(op->getAttr(kEquivalentArgsAttr));
    equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) {
      return cast<IntegerAttr>(a).getValue().getSExtValue();
    }));
  } else {
    equivBbArgs.append(op->getNumOperands(), -1);
  }
  equivBbArgs[returnVal.getOperandNumber()] = bbArg.getArgNumber();

  OpBuilder b(op->getContext());
  op->setAttr(kEquivalentArgsAttr, b.getI64ArrayAttr(equivBbArgs));
}

/// Store function BlockArguments that are equivalent to/aliasing a returned
/// value in FuncAnalysisState.
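///
/// Illustrative sketch: in
/// ```
/// func.func @f(%t : tensor<?xf32>) -> tensor<?xf32> {
///   %f = ... : f32
///   %0 = tensor.insert %f into %t[...] : tensor<?xf32>
///   return %0 : tensor<?xf32>
/// }
/// ```
/// the returned value aliases bbArg #0; if the tensor.insert bufferizes
/// in-place, it is also equivalent to bbArg #0.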
static LogicalResult
aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
                             FuncAnalysisState &funcState) {
  if (funcOp.getBody().empty()) {
    // No function body available. Conservatively assume that every tensor
    // return value may alias with any tensor bbArg.
    FunctionType type = funcOp.getFunctionType();
    for (const auto &inputIt : llvm::enumerate(type.getInputs())) {
      if (!isa<TensorType>(inputIt.value()))
        continue;
      for (const auto &resultIt : llvm::enumerate(type.getResults())) {
        if (!isa<TensorType>(resultIt.value()))
          continue;
        int64_t returnIdx = resultIt.index();
        int64_t bbArgIdx = inputIt.index();
        funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx);
      }
    }
    return success();
  }

  // Find all func.return ops.
  SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
  assert(!returnOps.empty() && "expected at least one ReturnOp");

  // Build alias sets. Merge all aliases from all func.return ops.
  for (BlockArgument bbArg : funcOp.getArguments()) {
    if (isa<RankedTensorType>(bbArg.getType())) {
      int64_t bbArgIdx = bbArg.getArgNumber();
      // Store aliases in a set, so that we don't add the same alias twice.
      SetVector<int64_t> aliases;
      for (func::ReturnOp returnOp : returnOps) {
        for (OpOperand &returnVal : returnOp->getOpOperands()) {
          if (isa<RankedTensorType>(returnVal.get().getType())) {
            int64_t returnIdx = returnVal.getOperandNumber();
            if (state.areAliasingBufferizedValues(returnVal.get(), bbArg))
              aliases.insert(returnIdx);
          }
        }
      }
      for (int64_t alias : aliases)
        funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(alias);
    }
  }

  // Build equivalence sets.
  // Helper function that finds an equivalent block argument index for the
  // given OpOperand. Return std::nullopt if no equivalent block argument could
  // be found.
  auto findEquivalentBlockArgIdx =
      [&](OpOperand &opOperand) -> std::optional<int64_t> {
    Value v = opOperand.get();
    if (!isa<TensorType>(v.getType()))
      return std::nullopt;
    for (BlockArgument bbArg : funcOp.getArguments()) {
      if (isa<RankedTensorType>(bbArg.getType())) {
        if (state.areEquivalentBufferizedValues(v, bbArg)) {
          if (state.getOptions().testAnalysisOnly)
            annotateEquivalentReturnBbArg(opOperand, bbArg);
          return bbArg.getArgNumber();
        }
      }
    }
    return std::nullopt;
  };

  int64_t numResults = returnOps.front()->getNumOperands();
  for (int64_t i = 0; i < numResults; ++i) {
    // Find the equivalent block argument index for the i-th operand of the
    // first func.return op.
    std::optional<int64_t> maybeEquiv =
        findEquivalentBlockArgIdx(returnOps.front()->getOpOperand(i));
    if (!maybeEquiv.has_value())
      continue;
    int64_t bbArgIdx = *maybeEquiv;
    bool allEquiv = true;

    // Check if all other func.return ops have the same equivalent block
    // argument for the i-th operand. In contrast to aliasing information,
    // which is just "merged", equivalence information must match across all
    // func.return ops.
    for (func::ReturnOp returnOp : ArrayRef(returnOps).drop_front()) {
      std::optional<int64_t> maybeEquiv =
          findEquivalentBlockArgIdx(returnOp->getOpOperand(i));
      if (maybeEquiv != bbArgIdx) {
        allEquiv = false;
        break;
      }
    }

    // All func.return ops have the same equivalent block argument for the i-th
    // operand.
    if (allEquiv)
      funcState.equivalentFuncArgs[funcOp][i] = bbArgIdx;
  }

  return success();
}

static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx, bool isRead,
                                  bool isWritten) {
  OpBuilder b(funcOp.getContext());
  Attribute accessType;
  if (isRead && isWritten) {
    accessType = b.getStringAttr("read-write");
  } else if (isRead) {
    accessType = b.getStringAttr("read");
  } else if (isWritten) {
    accessType = b.getStringAttr("write");
  } else {
    accessType = b.getStringAttr("none");
  }
  funcOp.setArgAttr(idx, BufferizationDialect::kBufferAccessAttrName,
                    accessType);
}

/// Determine which FuncOp bbArgs are read and which are written. When run on a
/// function with unknown ops, we conservatively assume that such ops bufferize
/// to a read + write.
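///
/// Illustrative sketch: in
/// ```
/// func.func @f(%a : tensor<?xf32>, %b : tensor<?xf32>)
///     -> (f32, tensor<?xf32>) {
///   %f = ... : f32
///   %0 = tensor.extract %a[...] : tensor<?xf32>
///   %1 = tensor.insert %f into %b[...] : tensor<?xf32>
///   return %0, %1 : f32, tensor<?xf32>
/// }
/// ```
/// bbArg %a would be marked as read, and %b as written (and possibly also as
/// read, depending on what `state.isValueRead` reports for it).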
static LogicalResult
funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
                             FuncAnalysisState &funcState) {
  for (int64_t idx = 0, e = funcOp.getFunctionType().getNumInputs(); idx < e;
       ++idx) {
    // Skip non-tensor arguments.
    if (!isa<TensorType>(funcOp.getFunctionType().getInput(idx)))
      continue;
    bool isRead;
    bool isWritten;
    if (auto accessAttr = funcOp.getArgAttrOfType<StringAttr>(
            idx, BufferizationDialect::kBufferAccessAttrName)) {
      // Buffer access behavior is specified on the function. Skip the
      // analysis.
      StringRef str = accessAttr.getValue();
      isRead = str == "read" || str == "read-write";
      isWritten = str == "write" || str == "read-write";
    } else if (funcOp.getBody().empty()) {
      // If the function has no body, conservatively assume that all args are
      // read + written.
      isRead = true;
      isWritten = true;
    } else {
      // Analyze the body of the function.
      BlockArgument bbArg = funcOp.getArgument(idx);
      isRead = state.isValueRead(bbArg);
      isWritten = state.isValueWritten(bbArg);
    }

    if (state.getOptions().testAnalysisOnly)
      annotateFuncArgAccess(funcOp, idx, isRead, isWritten);
    if (isRead)
      funcState.readBbArgs[funcOp].insert(idx);
    if (isWritten)
      funcState.writtenBbArgs[funcOp].insert(idx);
  }

  return success();
}
} // namespace

/// Remove bufferization attributes on FuncOp arguments.
static void removeBufferizationAttributes(BlockArgument bbArg) {
  auto funcOp = cast<func::FuncOp>(bbArg.getOwner()->getParentOp());
  funcOp.removeArgAttr(bbArg.getArgNumber(),
                       BufferizationDialect::kBufferLayoutAttrName);
  funcOp.removeArgAttr(bbArg.getArgNumber(),
                       BufferizationDialect::kWritableAttrName);
}

/// Return the func::FuncOp called by `callOp`.
static func::FuncOp getCalledFunction(func::CallOp callOp) {
  SymbolRefAttr sym =
      llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
  if (!sym)
    return nullptr;
  return dyn_cast_or_null<func::FuncOp>(
      SymbolTable::lookupNearestSymbolFrom(callOp, sym.getLeafReference()));
}

/// Return "true" if the given function signature has tensor semantics.
static bool hasTensorSignature(func::FuncOp funcOp) {
  return llvm::any_of(funcOp.getFunctionType().getInputs(),
                      llvm::IsaPred<TensorType>) ||
         llvm::any_of(funcOp.getFunctionType().getResults(),
                      llvm::IsaPred<TensorType>);
}

/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
/// callee-caller order (i.e., callees without callers first). Store all
/// remaining functions (i.e., the ones that call each other recursively) in
/// `remainingFuncOps`. Does not traverse nested symbol tables.
///
/// Store the map of FuncOp to all its callers in `callerMap`.
///
/// Return `failure()` if we are unable to retrieve the called FuncOp from
/// any func::CallOp.
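///
/// Illustrative sketch (assuming all signatures involve tensors): if @a calls
/// @b and @b calls @c, `orderedFuncOps` is [@c, @b, @a]. If @d and @e call
/// each other, both are stored in `remainingFuncOps`.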
static LogicalResult getFuncOpsOrderedByCalls(
    ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
    SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap) {
  // For each FuncOp, the set of functions that call it (i.e., the set of
  // FuncOps that contain a func::CallOp referencing it).
  DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
  // For each FuncOp, the number of distinct functions (with tensors in their
  // signatures) that it calls.
  DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
  for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
    // Collect function calls and populate the caller map.
    numberCallOpsContainedInFuncOp[funcOp] = 0;
    WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
      func::FuncOp calledFunction = getCalledFunction(callOp);
      assert(calledFunction && "could not retrieve called func::FuncOp");
      // If the called function does not have any tensors in its signature,
      // then it is not necessary to bufferize the callee before the caller.
      if (!hasTensorSignature(calledFunction))
        return WalkResult::skip();

      callerMap[calledFunction].insert(callOp);
      if (calledBy[calledFunction].insert(funcOp).second) {
        numberCallOpsContainedInFuncOp[funcOp]++;
      }
      return WalkResult::advance();
    });
    if (res.wasInterrupted())
      return failure();
  }

  // Iteratively remove functions that do not call any of the functions still
  // in `numberCallOpsContainedInFuncOp` and append them to the ordered list.
  while (!numberCallOpsContainedInFuncOp.empty()) {
    auto it = llvm::find_if(numberCallOpsContainedInFuncOp,
                            [](auto entry) { return entry.getSecond() == 0; });
    if (it == numberCallOpsContainedInFuncOp.end())
      break;
    orderedFuncOps.push_back(it->getFirst());
    for (auto caller : calledBy[it->getFirst()])
      numberCallOpsContainedInFuncOp[caller]--;
    numberCallOpsContainedInFuncOp.erase(it);
  }

  // Put all other functions in the list of remaining functions. These are
  // functions that call each other circularly.
  for (auto it : numberCallOpsContainedInFuncOp)
    remainingFuncOps.push_back(it.first);

  return success();
}

/// Helper function that extracts the source from a memref.cast. If the given
/// value is not a memref.cast result, simply returns the given value.
static Value unpackCast(Value v) {
  auto castOp = v.getDefiningOp<memref::CastOp>();
  if (!castOp)
    return v;
  return castOp.getSource();
}

/// Helper function that returns the return types (skipping casts) of the given
/// func.return ops. This function returns as many types as the return ops have
/// operands. If the i-th operand is not the same for all func.return ops, then
/// the i-th returned type is an "empty" type.
static SmallVector<Type> getReturnTypes(SmallVector<func::ReturnOp> returnOps) {
  assert(!returnOps.empty() && "expected at least one ReturnOp");
  int numOperands = returnOps.front()->getNumOperands();

  // Helper function that unpacks memref.cast ops and returns the type.
  auto getSourceType = [&](Value v) { return unpackCast(v).getType(); };

  SmallVector<Type> result;
  for (int i = 0; i < numOperands; ++i) {
    // Get the type of the i-th operand of the first func.return op.
    Type t = getSourceType(returnOps.front()->getOperand(i));

    // Check if all other func.return ops have a matching operand type.
    for (int j = 1; j < static_cast<int>(returnOps.size()); ++j)
      if (getSourceType(returnOps[j]->getOperand(i)) != t)
        t = Type();

    result.push_back(t);
  }

  return result;
}

/// Fold return values that are memref casts and update function return types.
///
/// During FuncOp bufferization, the exact type of the returned memrefs (if any)
/// is not known yet. Therefore, the bufferization uses memref types with the
/// most generic layout map as function return types. After bufferizing the
/// entire function body, a more concise memref type can potentially be used for
/// the return type of the function.
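///
/// Illustrative sketch: if every func.return operand #0 is produced by
/// `%c = memref.cast %m : memref<?xf32>
///                        to memref<?xf32, strided<[?], offset: ?>>`,
/// the cast is folded away and result #0 of the function type is tightened
/// from the strided type to `memref<?xf32>`.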
static void foldMemRefCasts(func::FuncOp funcOp) {
  // There is nothing to do for bodiless ops.
  if (funcOp.getBody().empty())
    return;

  // Compute the common result types of all return ops.
  SmallVector<func::ReturnOp> returnOps = getReturnOps(funcOp);
  SmallVector<Type> resultTypes = getReturnTypes(returnOps);

  // Remove direct casts.
  for (func::ReturnOp returnOp : returnOps) {
    for (OpOperand &operand : returnOp->getOpOperands()) {
      // Only fold the cast if a common result type was found for this operand.
      if (resultTypes[operand.getOperandNumber()]) {
        operand.set(unpackCast(operand.get()));
      }
    }
  }

  // Fill in the missing result types that were not the same among all
  // func.return ops.
  for (int i = 0; i < static_cast<int>(resultTypes.size()); ++i) {
    if (resultTypes[i])
      continue;
    resultTypes[i] = funcOp.getFunctionType().getResult(i);
  }

  // Update the function type.
  auto newFuncType = FunctionType::get(
      funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
  funcOp.setType(newFuncType);
}

LogicalResult
mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
                                     OneShotAnalysisState &state,
                                     BufferizationStatistics *statistics) {
  assert(state.getOptions().bufferizeFunctionBoundaries &&
         "expected that function boundary bufferization is activated");
  FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);

  // A list of non-circular functions in the order in which they are analyzed
  // and bufferized.
  SmallVector<func::FuncOp> orderedFuncOps;
  // A list of all other functions. I.e., functions that call each other
  // recursively. For these, we analyze the function body but not the function
  // boundary.
  SmallVector<func::FuncOp> remainingFuncOps;

  // A mapping of FuncOps to their callers.
  FuncCallerMap callerMap;

  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
                                      remainingFuncOps, callerMap)))
    return failure();

  // Analyze functions in order, starting with functions that do not call any
  // other functions.
  for (func::FuncOp funcOp : orderedFuncOps) {
    if (!state.getOptions().isOpAllowed(funcOp))
      continue;

    // Now analyzing function.
    funcState.startFunctionAnalysis(funcOp);

    // Analyze funcOp.
    if (failed(analyzeOp(funcOp, state, statistics)))
      return failure();

    // Run some extra function analyses.
    if (failed(aliasingFuncOpBBArgsAnalysis(funcOp, state, funcState)) ||
        failed(funcOpBbArgReadWriteAnalysis(funcOp, state, funcState)))
      return failure();

    // Mark op as fully analyzed.
    funcState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed;
  }

  // Analyze all other functions. All function boundary analyses are skipped.
  for (func::FuncOp funcOp : remainingFuncOps) {
    if (!state.getOptions().isOpAllowed(funcOp))
      continue;

    // Analyze funcOp.
    if (failed(analyzeOp(funcOp, state, statistics)))
      return failure();

    // TODO: We currently skip all function argument analyses for functions
    // that call each other circularly. These analyses do not support recursive
    // calls yet. The `BufferizableOpInterface` implementations of `func`
    // dialect ops return conservative results in the absence of analysis
    // information.
  }

  return success();
}

void mlir::bufferization::removeBufferizationAttributesInModule(
    ModuleOp moduleOp) {
  for (auto op : moduleOp.getOps<func::FuncOp>()) {
    for (BlockArgument bbArg : op.getArguments())
      removeBufferizationAttributes(bbArg);
  }
}

LogicalResult mlir::bufferization::bufferizeModuleOp(
    ModuleOp moduleOp, const OneShotBufferizationOptions &options,
    BufferizationStatistics *statistics) {
  assert(options.bufferizeFunctionBoundaries &&
         "expected that function boundary bufferization is activated");
  IRRewriter rewriter(moduleOp.getContext());

  // A list of non-circular functions in the order in which they are analyzed
  // and bufferized.
  SmallVector<func::FuncOp> orderedFuncOps;
  // A list of all other functions. I.e., functions that call each other
  // recursively. For these, we analyze the function body but not the function
  // boundary.
  SmallVector<func::FuncOp> remainingFuncOps;

  // A mapping of FuncOps to their callers.
  FuncCallerMap callerMap;

  // Try to bufferize functions in calling order. I.e., first bufferize
  // functions that do not call other functions. This allows us to infer
  // accurate buffer types for function return values. Functions that call
  // each other recursively are bufferized in an unspecified order at the end.
  // We may use unnecessarily "complex" (in terms of layout map) buffer types.
  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
                                      remainingFuncOps, callerMap)))
    return failure();
  llvm::append_range(orderedFuncOps, remainingFuncOps);

  // Bufferize functions.
  for (func::FuncOp funcOp : orderedFuncOps) {
    // Note: It would be good to apply cleanups here but we cannot as aliasInfo
    // would be invalidated.

    if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getSymName())) {
      // This function was not analyzed and RaW conflicts were not resolved.
      // Buffer copies must be inserted before every write.
      OneShotBufferizationOptions updatedOptions = options;
      updatedOptions.copyBeforeWrite = true;
      if (failed(bufferizeOp(funcOp, updatedOptions, statistics)))
        return failure();
    } else {
      if (failed(bufferizeOp(funcOp, options, statistics)))
        return failure();
    }

    // Change buffer return types to more precise layout maps.
    if (options.inferFunctionResultLayout)
      foldMemRefCasts(funcOp);
  }

  // Bufferize all other ops.
  for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) {
    // Functions were already bufferized.
    if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>())
      continue;
    if (failed(bufferizeOp(&op, options, statistics)))
      return failure();
  }

  // Post-pass cleanup of function argument attributes.
  removeBufferizationAttributesInModule(moduleOp);

  return success();
}

LogicalResult mlir::bufferization::runOneShotModuleBufferize(
    ModuleOp moduleOp, const OneShotBufferizationOptions &options,
    BufferizationStatistics *statistics) {
  assert(options.bufferizeFunctionBoundaries &&
         "expected that function boundary bufferization is activated");
  assert(!(options.copyBeforeWrite && options.testAnalysisOnly) &&
         "invalid combination of bufferization flags");
  if (!options.copyBeforeWrite) {
    if (options.noAnalysisFuncFilter.empty()) {
      if (failed(insertTensorCopies(moduleOp, options, statistics)))
        return failure();
    } else {
      // FuncOps whose names are specified in options.noAnalysisFuncFilter will
      // not be analyzed. Ops in these FuncOps will not be analyzed either.
      OpFilter::Entry::FilterFn analysisFilterFn = [=](Operation *op) {
        auto func = dyn_cast<func::FuncOp>(op);
        if (!func)
          func = op->getParentOfType<func::FuncOp>();
        if (func)
          return llvm::is_contained(options.noAnalysisFuncFilter,
                                    func.getSymName());
        return false;
      };
      OneShotBufferizationOptions updatedOptions(options);
      updatedOptions.opFilter.denyOperation(analysisFilterFn);
      if (failed(insertTensorCopies(moduleOp, updatedOptions, statistics)))
        return failure();
    }
  }
  if (options.testAnalysisOnly)
    return success();
  if (failed(bufferizeModuleOp(moduleOp, options, statistics)))
    return failure();
  return success();
}
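
// Illustrative usage sketch (not part of this file): invoking One-Shot Module
// Bufferization from C++. The `module` value and the surrounding pass context
// are assumptions.
//
//   OneShotBufferizationOptions options;
//   options.bufferizeFunctionBoundaries = true;
//   if (failed(runOneShotModuleBufferize(module, options)))
//     return signalPassFailure(); // E.g., inside a Pass::runOnOperation().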