//===- ModuleBufferization.cpp - Bufferization across Func. Boundaries ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Module Bufferization is an extension of Comprehensive Bufferize that
// bufferizes function boundaries. It provides `BufferizableOpInterface`
// implementations for FuncOp, CallOp and ReturnOp.
//
// Module Bufferization is run via `runComprehensiveBufferize(ModuleOp, ...)`.
// This function analyzes the given module and determines the order of
// analysis and bufferization: Functions that are called are processed before
// their respective callers.
//
// After analyzing a FuncOp, additional information about its bbArgs is
// gathered through PostAnalysisSteps and stored in `ModuleBufferizationState`.
//
// * `EquivalentFuncOpBBArgsAnalysis` determines the equivalent bbArg for each
//   tensor return value (if any).
// * `FuncOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is
//   read/written.
//
// Only tensors that are equivalent to some FuncOp bbArg may be returned.
// Bufferization currently fails if other tensors (in particular tensors that
// bufferize out-of-place and result in a new buffer allocation) are returned.
// In the future, such allocations could be hoisted to the caller.
//
// Example: `foo` fails bufferization because %0 is not equivalent to any bbArg.
// ```
// func @foo() -> tensor<?xf32> {
//   %0 = linalg.init_tensor [...] : tensor<?xf32>
//   return %0 : tensor<?xf32>
// }
// ```
//
// Module Bufferization implements the following calling convention.
//
// * In the absence of conflicts within a FuncOp, the FuncOp's bbArgs may
//   always be written to in-place.
// * If a tensor operand of a CallOp is read after the CallOp, the operand of
//   the CallOp must bufferize out-of-place.
//
// Example: The tensor.insert op bufferizes in-place because it is allowed to
// modify the buffer of `%t1` directly. The CallOp in `caller` must bufferize
// out-of-place because `%t0` is modified by the callee but read by the
// tensor.extract op. The analysis of CallOps decides whether an OpOperand must
// bufferize out-of-place based on results of `FuncOpBbArgReadWriteAnalysis`.
// ```
// func @callee(%t1 : tensor<?xf32>) -> tensor<?xf32> {
//   %f = ... : f32
//   %0 = tensor.insert %f into %t1[...] : tensor<?xf32>
//   return %0 : tensor<?xf32>
// }
//
// func @caller() -> () {
//   %t0 = ... : tensor<?xf32>
//   %1 = call @callee(%t0) : (tensor<?xf32>) -> (tensor<?xf32>)
//   %2 = tensor.extract %1[...] : tensor<?xf32>
// }
// ```
//
// Note: If a function is external, `FuncOpBbArgReadWriteAnalysis` cannot
// analyze the function body. In such a case, the CallOp analysis conservatively
// assumes that each tensor OpOperand is both read and written.
//
// TODO: Add FuncOp attributes so that bbArgs of external FuncOps can be marked
// as "not reading" and/or "not writing".
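//
// Example: Because `@ext` below is external, a call such as `call @ext(%t)`
// is conservatively treated as both reading and writing `%t`. (Illustrative
// snippet, not part of the original file.)
// ```
// func private @ext(tensor<?xf32>) -> tensor<?xf32>
// ```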
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Operation.h"

using namespace mlir;
using namespace linalg;
using namespace tensor;
using namespace comprehensive_bufferize;
using namespace mlir::bufferization;
namespace {
/// The state of analysis of a FuncOp.
enum class FuncOpAnalysisState { NotAnalyzed, InProgress, Analyzed };

/// Extra bufferization state that is required for bufferization of function
/// boundaries.
struct ModuleBufferizationState : public DialectBufferizationState {
  /// A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg
  /// indices.
  DenseMap<FuncOp, DenseMap<int64_t, int64_t>> equivalentFuncArgs;

  /// A set of all read BlockArguments of FuncOps.
  // Note: BlockArgument knows about its owner, so we do not need to store
  // FuncOps here.
  DenseSet<BlockArgument> readBbArgs;

  /// A set of all written-to BlockArguments of FuncOps.
  DenseSet<BlockArgument> writtenBbArgs;

  /// Keep track of which FuncOps are fully analyzed or currently being
  /// analyzed.
  DenseMap<FuncOp, FuncOpAnalysisState> analyzedFuncOps;

  // A list of functions in the order in which they are analyzed + bufferized.
  SmallVector<FuncOp> orderedFuncOps;

  // A mapping of FuncOps to their callers.
  DenseMap<FuncOp, DenseSet<Operation *>> callerMap;
};
} // namespace

/// Get ModuleBufferizationState.
static const ModuleBufferizationState &
getModuleBufferizationState(const BufferizationState &state) {
  Optional<const ModuleBufferizationState *> maybeState =
      state.getDialectState<ModuleBufferizationState>(
          StandardOpsDialect::getDialectNamespace());
  assert(maybeState.hasValue() && "ModuleBufferizationState does not exist");
  return **maybeState;
}

/// Get or create ModuleBufferizationState.
static ModuleBufferizationState &
getModuleBufferizationState(BufferizationState &state) {
  return state.getOrCreateDialectState<ModuleBufferizationState>(
      StandardOpsDialect::getDialectNamespace());
}

/// Return the state (phase) of analysis of the FuncOp.
static FuncOpAnalysisState
getFuncOpAnalysisState(const BufferizationState &state, FuncOp funcOp) {
  const ModuleBufferizationState &moduleState =
      getModuleBufferizationState(state);
  auto it = moduleState.analyzedFuncOps.find(funcOp);
  if (it == moduleState.analyzedFuncOps.end())
    return FuncOpAnalysisState::NotAnalyzed;
  return it->second;
}

/// Return the unique ReturnOp that terminates `funcOp`.
/// Return nullptr if there is no such unique ReturnOp.
static ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
  ReturnOp returnOp;
  for (Block &b : funcOp.body()) {
    if (auto candidateOp = dyn_cast<ReturnOp>(b.getTerminator())) {
      if (returnOp)
        return nullptr;
      returnOp = candidateOp;
    }
  }
  return returnOp;
}

namespace {
/// Store function BlockArguments that are equivalent to a returned value in
/// ModuleBufferizationState.
struct EquivalentFuncOpBBArgsAnalysis : public PostAnalysisStep {
  /// Annotate IR with the results of the analysis. For testing purposes only.
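  /// For example, a ReturnOp whose first operand is equivalent to bbArg #0
  /// receives the attribute `__equivalent_func_args__ = [0]`; return operands
  /// without an equivalent bbArg are recorded as -1. (Illustrative note, not
  /// part of the original file.)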
  static void annotateReturnOp(OpOperand &returnVal, BlockArgument bbArg) {
    const char *kEquivalentArgsAttr = "__equivalent_func_args__";
    Operation *op = returnVal.getOwner();

    SmallVector<int64_t> equivBbArgs;
    if (op->hasAttr(kEquivalentArgsAttr)) {
      auto attr = op->getAttr(kEquivalentArgsAttr).cast<ArrayAttr>();
      equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) {
        return a.cast<IntegerAttr>().getValue().getSExtValue();
      }));
    } else {
      equivBbArgs.append(op->getNumOperands(), -1);
    }
    equivBbArgs[returnVal.getOperandNumber()] = bbArg.getArgNumber();

    OpBuilder b(op->getContext());
    op->setAttr(kEquivalentArgsAttr, b.getI64ArrayAttr(equivBbArgs));
  }

  LogicalResult run(Operation *op, BufferizationState &state,
                    BufferizationAliasInfo &aliasInfo,
                    SmallVector<Operation *> &newOps) override {
    ModuleBufferizationState &moduleState = getModuleBufferizationState(state);

    // Support only a single return-terminated block in the function.
    auto funcOp = cast<FuncOp>(op);
    ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
    assert(returnOp && "expected func with single return op");

    for (OpOperand &returnVal : returnOp->getOpOperands())
      if (returnVal.get().getType().isa<RankedTensorType>())
        for (BlockArgument bbArg : funcOp.getArguments())
          if (bbArg.getType().isa<RankedTensorType>())
            if (aliasInfo.areEquivalentBufferizedValues(returnVal.get(),
                                                        bbArg)) {
              moduleState
                  .equivalentFuncArgs[funcOp][returnVal.getOperandNumber()] =
                  bbArg.getArgNumber();
              if (state.getOptions().testAnalysisOnly)
                annotateReturnOp(returnVal, bbArg);
            }

    return success();
  }
};

/// Return true if the buffer of the given tensor value is written to. Must not
/// be called for values inside functions that have not been analyzed yet.
/// (Post-analysis steps do not have to be run yet, i.e., "in progress" is also
/// OK.)
static bool isValueWritten(Value value, const BufferizationState &state,
                           const BufferizationAliasInfo &aliasInfo) {
#ifndef NDEBUG
  assert(value.getType().isa<TensorType>() && "expected TensorType");
  FuncOp funcOp;
  if (auto bbArg = value.dyn_cast<BlockArgument>()) {
    Operation *owner = bbArg.getOwner()->getParentOp();
    funcOp = isa<FuncOp>(owner) ? cast<FuncOp>(owner)
                                : owner->getParentOfType<FuncOp>();
  } else {
    funcOp = value.getDefiningOp()->getParentOfType<FuncOp>();
  }
  assert(getFuncOpAnalysisState(state, funcOp) !=
             FuncOpAnalysisState::NotAnalyzed &&
         "FuncOp must be fully analyzed or analysis in progress");
#endif // NDEBUG

  bool isWritten = false;
  aliasInfo.applyOnAliases(value, [&](Value val) {
    for (OpOperand &use : val.getUses())
      if (state.isInPlace(use) && state.bufferizesToMemoryWrite(use))
        isWritten = true;
  });
  return isWritten;
}

/// Determine which FuncOp bbArgs are read and which are written. If this
/// PostAnalysisStep is run on a function with unknown ops, it will
/// conservatively assume that such ops bufferize to a read + write.
struct FuncOpBbArgReadWriteAnalysis : public PostAnalysisStep {
  LogicalResult run(Operation *op, BufferizationState &state,
                    BufferizationAliasInfo &aliasInfo,
                    SmallVector<Operation *> &newOps) override {
    ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
    auto funcOp = cast<FuncOp>(op);

    // If the function has no body, conservatively assume that all args are
    // read + written.
    if (funcOp.getBody().empty()) {
      for (BlockArgument bbArg : funcOp.getArguments()) {
        moduleState.readBbArgs.insert(bbArg);
        moduleState.writtenBbArgs.insert(bbArg);
      }

      return success();
    }

    for (BlockArgument bbArg : funcOp.getArguments()) {
      if (!bbArg.getType().isa<TensorType>())
        continue;
      if (state.isValueRead(bbArg))
        moduleState.readBbArgs.insert(bbArg);
      if (isValueWritten(bbArg, state, aliasInfo))
        moduleState.writtenBbArgs.insert(bbArg);
    }

    return success();
  }
};
} // namespace

static bool isaTensor(Type t) { return t.isa<TensorType>(); }

/// If `value` is a memref::CastOp, return its source. Otherwise, return
/// `value` directly.
static Value getNonCastedValue(Value value) {
  while (auto castOp = value.getDefiningOp<memref::CastOp>())
    value = castOp.source();
  return value;
}

/// Remove the attribute that triggers inplace bufferization on a FuncOp
/// argument `bbArg`.
static void removeBufferizationFuncArguments(BlockArgument bbArg) {
  auto funcOp = cast<FuncOp>(bbArg.getOwner()->getParentOp());
  funcOp.removeArgAttr(bbArg.getArgNumber(),
                       BufferizableOpInterface::kBufferLayoutAttrName);
  funcOp.removeArgAttr(bbArg.getArgNumber(),
                       BufferizableOpInterface::kInplaceableAttrName);
}

/// Return the FuncOp called by `callOp`.
static FuncOp getCalledFunction(CallOpInterface callOp) {
  SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
  if (!sym)
    return nullptr;
  return dyn_cast_or_null<FuncOp>(
      SymbolTable::lookupNearestSymbolFrom(callOp, sym.getLeafReference()));
}

/// Return the FunctionType with `argumentTypes` and `resultTypes` where each
/// tensor is replaced by the corresponding buffer type.
/// In order for all the callers to agree, this *must* bufferize to the most
/// dynamic buffer type supported.
/// A later pass across all CallOps in the module can decide whether to
/// simplify the types to a less dynamic version according to some cost model.
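/// For example (illustrative, assuming default options), the function type
///   (tensor<4xf32>, f32) -> tensor<4xf32>
/// may bufferize to
///   (memref<4xf32, #strided>, f32) -> memref<4xf32, #strided>
/// where `#strided` is a fully dynamic strided layout map.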
static FunctionType
getBufferizedFunctionType(MLIRContext *ctx, TypeRange argumentTypes,
                          TypeRange resultTypes,
                          const BufferizationOptions &options) {
  auto rewrite = [&](Type t) -> Type {
    // TODO: non-zero address space.
    // TODO: layout information if relevant.
    if (auto tensorType = t.dyn_cast<TensorType>())
      return getMemRefType(tensorType, options);
    return t;
  };
  auto argTypes = llvm::to_vector<4>(llvm::map_range(argumentTypes, rewrite));
  auto retTypes = llvm::to_vector<4>(llvm::map_range(resultTypes, rewrite));
  return FunctionType::get(ctx, argTypes, retTypes);
}

/// Gather equivalence info of CallOps.
/// Note: This only adds new equivalence info if `funcOp` was already analyzed.
// TODO: This does not handle cyclic function call graphs etc.
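// For example, given `%r = call @f(%t)` where `@f` is known to return its
// tensor argument, `%r` and `%t` are unioned into one equivalence class
// (illustrative).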
static void equivalenceAnalysis(FuncOp funcOp,
                                BufferizationAliasInfo &aliasInfo,
                                ModuleBufferizationState &moduleState) {
  funcOp->walk([&](CallOp callOp) {
    FuncOp calledFunction = getCalledFunction(callOp);
    assert(calledFunction && "could not retrieve called FuncOp");

    // No equivalence info available for the called function.
    if (!moduleState.equivalentFuncArgs.count(calledFunction))
      return WalkResult::skip();

    for (auto it : moduleState.equivalentFuncArgs[calledFunction]) {
      int64_t returnIdx = it.first;
      int64_t bbargIdx = it.second;
      Value returnVal = callOp.getResult(returnIdx);
      Value argVal = callOp->getOperand(bbargIdx);
      aliasInfo.unionEquivalenceClasses(returnVal, argVal);
    }

    return WalkResult::advance();
  });
}

/// Rewrite the `funcOp` arguments, return values and terminator into buffer
/// form (using the canonical memref layout for now), according to the
/// inPlace-bufferizable information of the function arguments.
///
/// This relies on a buffer equivalence analysis of each return operand. When a
/// result buffer is equivalent to a BlockArgument of `funcOp`, it can be
/// dropped from the return values and becomes inplaceable at all callers. This
/// assumes all CallOps perform the necessary work to clone operands so as to
/// make them inplaceable. Reliance on this logic will need to be relaxed in
/// the future.
///
/// Note: Returning a memref currently fails bufferization. If such memrefs
/// originate from an op with an Alloc effect, they could be hoisted in the
/// future.
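/// For example (illustrative), a function whose returned tensor is equivalent
/// to its bbArg:
///   func @f(%t: tensor<4xf32>) -> tensor<4xf32> { ... return %t ... }
/// is rewritten to
///   func @f(%m: memref<4xf32, #strided>) { ... return }
/// and each caller reuses its own operand buffer for the result.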
static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
                                             RewriterBase &rewriter,
                                             BufferizationState &state) {
  ModuleBufferizationState &moduleState = getModuleBufferizationState(state);

  // If there is nothing to do, we are done.
  if (!llvm::any_of(funcOp.getType().getInputs(), isaTensor) &&
      !llvm::any_of(funcOp.getType().getResults(), isaTensor))
    return success();

  // Get the bufferized FunctionType for funcOp or construct it if not yet
  // available.
  // TODO: Atm we have 3 cases:
  // 1. if a function is called from within the Module, it must have bufferized
  //    to inplaceable tensor results.
  // 2. if it is bodiless, it must have bufferized and is not allowed to have
  //    result tensors.
  // 3. if it is not called internally, it still must bufferize to inplaceable
  //    tensor results and we construct it now (e.g. top-level function called
  //    externally).
  // -> Figure out a better layering.
  TypeRange resultTypes;

  // Corner case: Bodiless FuncOp
  // ============================
  // The body of such functions is assumed opaque and we can't know the
  // bufferization contract they want to enforce atm.
  // As a consequence, only support functions that don't return any tensor atm.
  if (funcOp.getBody().empty()) {
    if (llvm::any_of(funcOp.getType().getResults(), isaTensor))
      return funcOp->emitError() << "cannot bufferize bodiless function that "
                                 << "returns a tensor";
    FunctionType bufferizedFuncType = getBufferizedFunctionType(
        funcOp.getContext(), funcOp.getType().getInputs(), TypeRange{},
        state.getOptions());
    funcOp.setType(bufferizedFuncType);
    return success();
  }

  // Support only a single return-terminated block in the function.
  ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
  assert(returnOp && "expected func with single return op");

  // 1. For each FuncOp result, keep track of which inplace argument it reuses.
  SmallVector<Value> returnValues;
  for (OpOperand &returnOperand : returnOp->getOpOperands()) {
    Value returnVal = returnOperand.get();

    // If not a ranked tensor type, just forward it.
    if (!returnVal.getType().isa<RankedTensorType>()) {
      returnValues.push_back(returnVal);
      continue;
    }

    // If the return operand is equivalent to some bbArg, no need to return it.
    if (moduleState.equivalentFuncArgs[funcOp].count(
            returnOperand.getOperandNumber()))
      continue;

    // Cast values at the call site if necessary.
    returnValues.push_back(
        getNonCastedValue(*state.getBuffer(rewriter, returnOperand)));
  }

  // 2. Rewrite the terminator without the inPlace bufferizable values.
  ValueRange retValues{returnValues};
  FunctionType bufferizedFuncType = getBufferizedFunctionType(
      funcOp.getContext(), funcOp.getType().getInputs(), retValues.getTypes(),
      state.getOptions());
  OpBuilder b(returnOp);
  b.create<ReturnOp>(returnOp.getLoc(), returnValues);
  returnOp->erase();

  // 3. Rewrite the bbArgs.
  // Iterate on the original `numArgs` and replace them in order.
  // This guarantees the argument order still matches after the rewrite.
  Block &frontBlock = funcOp.body().front();
  unsigned numArgs = frontBlock.getNumArguments();
  for (unsigned idx = 0; idx < numArgs; ++idx) {
    auto bbArg = frontBlock.getArgument(0);
    auto tensorType = bbArg.getType().dyn_cast<TensorType>();
    // Non-tensor types are just forwarded.
    if (!tensorType) {
      frontBlock.addArgument(bbArg.getType(), bbArg.getLoc());
      bbArg.replaceAllUsesWith(frontBlock.getArguments().back());
      frontBlock.eraseArgument(0);
      continue;
    }

    // Get the buffer type from the bufferized function type.
    Type memrefType = bufferizedFuncType.getInput(idx);
    Value memref = frontBlock.addArgument(memrefType, bbArg.getLoc());
    OpBuilder b(funcOp->getContext());
    b.setInsertionPointToStart(&frontBlock);
    // Replace all uses of bbArg through a ToMemRefOp.
    for (auto &use : llvm::make_early_inc_range(bbArg.getUses())) {
      if (auto toMemrefOp =
              dyn_cast<bufferization::ToMemrefOp>(use.getOwner())) {
        if (memref.getType() != toMemrefOp.memref().getType()) {
          // Type has changed, insert a cast.
          assert(memref::CastOp::areCastCompatible(
                     memref.getType(), toMemrefOp.memref().getType()) &&
                 "bufferizeFuncOpBoundary: cast incompatible");
          auto castOp = b.create<memref::CastOp>(
              funcOp.getLoc(), toMemrefOp.memref().getType(), memref);
          toMemrefOp.memref().replaceAllUsesWith(castOp);
        } else {
          // Type did not change, replace directly.
          toMemrefOp.memref().replaceAllUsesWith(memref);
        }
      }
    }
    // Replace all remaining uses by a to_tensor.
    if (!bbArg.use_empty()) {
      auto toTensorOp =
          b.create<bufferization::ToTensorOp>(funcOp.getLoc(), memref);
      bbArg.replaceAllUsesWith(toTensorOp);
    }
    frontBlock.eraseArgument(0);
    // TODO: add support to erase aliasInfo entries if deemed necessary.
  }

  // 4. Rewrite the FuncOp type to buffer form.
  funcOp.setType(bufferizedFuncType);

  return success();
}

/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
/// callee-caller order (i.e. callees without callers first).
/// Store the map of FuncOp to all its callers in `callerMap`.
/// Return `failure()` if a cycle of calls is detected or if we are unable to
/// retrieve the called FuncOp from any CallOpInterface.
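/// For example, if `@a` calls `@b` and `@b` calls `@c`, the computed order is
/// [`@c`, `@b`, `@a`].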
static LogicalResult
getFuncOpsOrderedByCalls(ModuleOp moduleOp,
                         SmallVectorImpl<FuncOp> &orderedFuncOps,
                         DenseMap<FuncOp, DenseSet<Operation *>> &callerMap) {
  // For each FuncOp, the set of functions called by it (i.e. the union of
  // symbols of all nested CallOpInterfaceOp).
  DenseMap<FuncOp, DenseSet<FuncOp>> calledBy;
  // For each FuncOp, the number of CallOpInterface it contains.
  DenseMap<FuncOp, unsigned> numberCallOpsContainedInFuncOp;
  WalkResult res = moduleOp.walk([&](FuncOp funcOp) -> WalkResult {
    if (!funcOp.body().empty()) {
      ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
      if (!returnOp)
        return funcOp->emitError()
               << "cannot bufferize a FuncOp with tensors and "
                  "without a unique ReturnOp";
    }

    numberCallOpsContainedInFuncOp[funcOp] = 0;
    return funcOp.walk([&](CallOpInterface callOp) -> WalkResult {
      // Only support CallOp for now.
      if (!isa<CallOp>(callOp.getOperation()))
        return callOp->emitError() << "expected a CallOp";
      FuncOp calledFunction = getCalledFunction(callOp);
      assert(calledFunction && "could not retrieve called FuncOp");
      auto it = callerMap.try_emplace(calledFunction, DenseSet<Operation *>{});
      it.first->getSecond().insert(callOp);
      if (calledBy[calledFunction].count(funcOp) == 0) {
        calledBy[calledFunction].insert(funcOp);
        numberCallOpsContainedInFuncOp[funcOp]++;
      }
      return WalkResult::advance();
    });
  });
  if (res.wasInterrupted())
    return failure();
  // Iteratively remove function operations that do not call any of the
  // functions remaining in the callCounter map and add them to the worklist.
  while (!numberCallOpsContainedInFuncOp.empty()) {
    auto it = llvm::find_if(numberCallOpsContainedInFuncOp,
                            [](auto entry) { return entry.getSecond() == 0; });
    if (it == numberCallOpsContainedInFuncOp.end())
      return moduleOp.emitOpError(
          "expected callgraph to be free of circular dependencies.");
    orderedFuncOps.push_back(it->getFirst());
    for (auto callee : calledBy[it->getFirst()])
      numberCallOpsContainedInFuncOp[callee]--;
    numberCallOpsContainedInFuncOp.erase(it);
  }
  return success();
}
551 
552 static void
553 foreachCaller(const DenseMap<FuncOp, DenseSet<Operation *>> &callerMap,
554  FuncOp callee, llvm::function_ref<void(Operation *)> doit) {
555  auto itCallers = callerMap.find(callee);
556  if (itCallers == callerMap.end())
557  return;
558  for (Operation *caller : itCallers->second)
559  doit(caller);
560 }

/// Postprocess the linalg.buffer_layout annotation across function boundaries.
/// This is a purely mechanical process that may later become part of a
/// separate pass with its own layout assignment heuristic.
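/// For example (illustrative), an argument annotated with
///   %arg0: memref<4x4xf32> {linalg.buffer_layout = affine_map<(i, j) -> (j, i)>}
/// has its type rewritten to `memref<4x4xf32, affine_map<(i, j) -> (j, i)>>`,
/// and every caller casts its operand to that type with a memref.cast.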
static void layoutPostProcessing(ModuleOp moduleOp) {
  SmallVector<FuncOp> orderedFuncOps;
  DenseMap<FuncOp, DenseSet<Operation *>> callerMap;
  auto res = getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap);
  (void)res;
  assert(succeeded(res) && "unexpected getFuncOpsOrderedByCalls failure");

  for (FuncOp funcOp : orderedFuncOps) {
    DenseMap<Operation *, SmallVector<Value>> operandsPerCaller;
    foreachCaller(callerMap, funcOp, [&](Operation *caller) {
      operandsPerCaller.try_emplace(caller, SmallVector<Value>());
    });

    SmallVector<Type> argumentTypes;
    // Iterate on each function argument and check if it was marked with a
    // desired layout.
    for (const auto &it : llvm::enumerate(funcOp.getType().getInputs())) {
      int argNumber = it.index();
      Type inputType = it.value();
      auto memrefType = inputType.dyn_cast<MemRefType>();
      auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
          argNumber, BufferizableOpInterface::kBufferLayoutAttrName);
      AffineMap desiredLayoutMap =
          layoutAttr ? layoutAttr.getValue() : AffineMap();
      AffineMap currentLayoutMap =
          memrefType ? getStridedLinearLayoutMap(memrefType) : AffineMap();
      if (!memrefType || !layoutAttr || desiredLayoutMap == currentLayoutMap) {
        argumentTypes.push_back(inputType);
        foreachCaller(callerMap, funcOp, [&](Operation *caller) {
          operandsPerCaller.find(caller)->getSecond().push_back(
              caller->getOperand(argNumber));
        });
        continue;
      }

      // Compute the buffer type with desired layout and add to input argument
      // types.
      MemRefType desiredMemrefType = MemRefType::get(
          memrefType.getShape(), memrefType.getElementType(), desiredLayoutMap);
      argumentTypes.push_back(desiredMemrefType);

      // If funcOp's body is not empty, change the bbArg type and propagate.
      if (!funcOp.body().empty()) {
        BlockArgument bbArg = funcOp.getArgument(argNumber);
        bbArg.setType(desiredMemrefType);
        OpBuilder b(bbArg.getContext());
        b.setInsertionPointToStart(bbArg.getOwner());
        assert(memref::CastOp::areCastCompatible(bbArg.getType(), memrefType) &&
               "layoutPostProcessing: cast incompatible");
        // Cast back to the original memrefType and let it canonicalize.
        Value cast =
            b.create<memref::CastOp>(funcOp.getLoc(), memrefType, bbArg);
        bbArg.replaceAllUsesExcept(cast, cast.getDefiningOp());
      }

      // Cast to desired buffer type on all callers to `funcOp`.
      // TODO: on the callee side, this may even have to trigger a copy to
      // change the layout. For now let the memref::CastOp fail to verify in
      // such cases.
      auto castArg = [&](Operation *caller) {
        OpBuilder b(caller);
        assert(
            memref::CastOp::areCastCompatible(
                caller->getOperand(argNumber).getType(), desiredMemrefType) &&
            "layoutPostProcessing.2: cast incompatible");
        Value newOperand = b.create<memref::CastOp>(
            funcOp.getLoc(), desiredMemrefType, caller->getOperand(argNumber));
        operandsPerCaller.find(caller)->getSecond().push_back(newOperand);
      };
      foreachCaller(callerMap, funcOp, castArg);
    }

    // Set operands with cast buffer on all callers to `funcOp`.
    foreachCaller(callerMap, funcOp, [&](Operation *caller) {
      caller->setOperands(operandsPerCaller.lookup(caller));
    });

    // Finally set the funcOp type to update the arguments.
    auto newFuncType = FunctionType::get(moduleOp.getContext(), argumentTypes,
                                         funcOp.getType().getResults());
    funcOp.setType(newFuncType);
  }
}

namespace mlir {
namespace linalg {
namespace comprehensive_bufferize {
namespace std_ext {

/// Return the index of the bbArg in the given FuncOp that is equivalent to the
/// specified return value (if any).
static Optional<int64_t>
getEquivalentFuncArgIdx(FuncOp funcOp, const ModuleBufferizationState &state,
                        int64_t returnValIdx) {
  auto funcOpIt = state.equivalentFuncArgs.find(funcOp);
  if (funcOpIt == state.equivalentFuncArgs.end())
    // No equivalence info stored for funcOp.
    return None;

  auto retValIt = funcOpIt->getSecond().find(returnValIdx);
  if (retValIt == funcOpIt->getSecond().end())
    // Return value has no equivalent bbArg.
    return None;

  return retValIt->getSecond();
}

struct CallOpInterface
    : public BufferizableOpInterface::ExternalModel<CallOpInterface, CallOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    CallOp callOp = cast<CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");

    const ModuleBufferizationState &moduleState =
        getModuleBufferizationState(state);
    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Assume that OpOperand is read.
      return true;

    return moduleState.readBbArgs.contains(
        funcOp.getArgument(opOperand.getOperandNumber()));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    CallOp callOp = cast<CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");

    const ModuleBufferizationState &moduleState =
        getModuleBufferizationState(state);
    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
      // FuncOp not analyzed yet. Assume that OpOperand is written.
      return true;

    return moduleState.writtenBbArgs.contains(
        funcOp.getArgument(opOperand.getOperandNumber()));
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    CallOp callOp = cast<CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    const ModuleBufferizationState &moduleState =
        getModuleBufferizationState(state);

    for (int64_t resultIdx = 0; resultIdx < callOp->getNumResults();
         ++resultIdx)
      if (Optional<int64_t> maybeArgNumber =
              getEquivalentFuncArgIdx(funcOp, moduleState, resultIdx))
        if (*maybeArgNumber == opOperand.getOperandNumber())
          return callOp->getOpResult(resultIdx);

    // Note: Returning a non-equivalent tensor from a FuncOp is currently not
    // supported and will fail bufferization. (Even with allow-return-memref,
    // it will fail when the function is called.)
    return OpResult();
  }

  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const BufferizationState &state) const {
    CallOp callOp = cast<CallOp>(op);
    FuncOp funcOp = getCalledFunction(callOp);
    assert(funcOp && "expected CallOp to a FuncOp");
    const ModuleBufferizationState &moduleState =
        getModuleBufferizationState(state);

    // TODO: We should be looking for aliasing block arguments here. The
    // current condition is actually stronger than necessary. Once we check for
    // aliasing block arguments, there may be multiple.
    if (Optional<int64_t> maybeArgNumber = getEquivalentFuncArgIdx(
            funcOp, moduleState, opResult.getResultNumber()))
      return {&op->getOpOperand(*maybeArgNumber)};

    // Note: Returning a non-equivalent tensor from a FuncOp is currently not
    // supported and will fail bufferization.
    return {};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const BufferizationState &state) const {
    return BufferRelation::Equivalent;
  }

  /// In a first approximation, all the function arguments of a FuncOp are
  /// marked inplaceable. For now, it is the responsibility of the `callOp`
  /// bufferization to allow FuncOps that are inplaceable to write inPlace.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    CallOp callOp = cast<CallOp>(op);
    unsigned numResults = callOp.getNumResults();
    unsigned numOperands = callOp->getNumOperands();
    FuncOp funcOp = getCalledFunction(callOp);
    assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
           "expected CallOp to a FuncOp");
    const ModuleBufferizationState &moduleState =
        getModuleBufferizationState(state);

    // Result types of the bufferized CallOp.
    SmallVector<Type> resultTypes;
    // Replacement values for the existing CallOp. These are usually the
    // results of the bufferized CallOp, unless a tensor result folds onto an
    // operand.
    SmallVector<Value> replacementValues(numResults, Value());
    // For non-tensor results: A mapping from return val indices of the old
    // CallOp to return val indices of the bufferized CallOp.
    SmallVector<Optional<unsigned>> retValMapping(numResults, None);
    // Operands of the bufferized CallOp.
    SmallVector<Value> newOperands(numOperands, Value());

    // Based on previously gathered equivalence information, we know if a
    // tensor result folds onto an operand. These are the only tensor value
    // results that are supported at the moment.
    //
    // For tensor return values that do not fold onto an operand, additional
    // work is needed (TODO) to either:
    // * hoist a result into an inplaceable operand or
    // * devise a better representation to truly return a buffer.
    //
    // Note: If a function has no body, no equivalence information is
    // available. Consequently, a tensor return value cannot be proven to fold
    // onto a FuncOp bbArg, so calls to such functions are not bufferizable at
    // the moment.

    // 1. Compute the result types of the new CallOp. Tensor results that are
    // equivalent to a FuncOp bbArg are no longer returned.
    for (const auto &it : llvm::enumerate(callOp.getResultTypes())) {
      unsigned returnValIdx = it.index();
      Type returnType = it.value();
      if (!isaTensor(returnType)) {
        // Non-tensor values are returned.
        retValMapping[returnValIdx] = resultTypes.size();
        resultTypes.push_back(returnType);
        continue;
      }

      if (Optional<int64_t> bbArgIdx =
              getEquivalentFuncArgIdx(funcOp, moduleState, returnValIdx)) {
        // Return operands that are equivalent to some bbArg are not
        // returned.
        FailureOr<Value> bufferOrFailure =
            state.getBuffer(rewriter, callOp->getOpOperand(*bbArgIdx));
        if (failed(bufferOrFailure))
          return failure();
        replacementValues[returnValIdx] = *bufferOrFailure;
        newOperands[*bbArgIdx] = *bufferOrFailure;
        continue;
      }

      return callOp->emitError(
          "call to FuncOp that returns non-equivalent tensors not supported");
    }

    // 2. Compute bufferized FunctionType.
    SmallVector<Type> argumentTypes{callOp->getOperandTypes()};
    // Get the bufferized FunctionType for funcOp or construct it if not yet
    // available.
    FunctionType bufferizedFuncType = getBufferizedFunctionType(
        funcOp.getContext(), argumentTypes, resultTypes, state.getOptions());

    // 3. Rewrite tensor operands as memrefs based on `bufferizedFuncType`.
    for (OpOperand &opOperand : callOp->getOpOperands()) {
      unsigned idx = opOperand.getOperandNumber();
      Value tensorOperand = opOperand.get();

      // Non-tensor operands are just copied.
      if (!tensorOperand.getType().isa<TensorType>()) {
        newOperands[idx] = tensorOperand;
        continue;
      }

      // Retrieve buffers for tensor operands. Tensor operand buffers whose
      // corresponding FuncOp bbArgs are equivalent to a returned tensor were
      // already stored in `newOperands` during Step 1.
      Value buffer = newOperands[idx];
      if (!buffer) {
        FailureOr<Value> bufferOrFailure = state.getBuffer(rewriter, opOperand);
        if (failed(bufferOrFailure))
          return failure();
        buffer = *bufferOrFailure;
      }

      // Caller / callee type mismatch is handled with a CastOp.
      auto memRefType = bufferizedFuncType.getInput(idx);
      // Since we don't yet have a clear layout story, to_memref may
      // conservatively turn tensors into more dynamic memrefs than necessary.
      // If the memref type of the callee does not match, introduce an extra
      // memref.cast that will either canonicalize away or fail compilation
      // until we can do something better.
      if (buffer.getType() != memRefType) {
        assert(
            memref::CastOp::areCastCompatible(buffer.getType(), memRefType) &&
            "CallOp::bufferize: cast incompatible");
        Value castBuffer = rewriter.create<memref::CastOp>(callOp.getLoc(),
                                                           memRefType, buffer);
        buffer = castBuffer;
      }
      newOperands[idx] = buffer;
    }

    // 4. Create the new CallOp.
    Operation *newCallOp = rewriter.create<CallOp>(
        callOp.getLoc(), funcOp.sym_name(), resultTypes, newOperands);
    newCallOp->setAttrs(callOp->getAttrs());
    // Get replacement values for non-tensor / non-equivalent results.
    for (unsigned i = 0; i < replacementValues.size(); ++i) {
      if (replacementValues[i])
        continue;
      replacementValues[i] = newCallOp->getResult(*retValMapping[i]);
    }

    // 5. Replace the old op with the new op.
    replaceOpWithBufferizedValues(rewriter, callOp, replacementValues);

    return success();
  }
};

struct ReturnOpInterface
    : public BufferizableOpInterface::ExternalModel<ReturnOpInterface,
                                                    ReturnOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const BufferizationState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return false;
  }

  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
    return OpResult();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
#ifndef NDEBUG
    auto returnOp = cast<ReturnOp>(op);
    assert(isa<FuncOp>(returnOp->getParentOp()) &&
           "only support FuncOp parent for ReturnOp");
#endif // NDEBUG
    return failure();
  }
};

struct FuncOpInterface
    : public BufferizableOpInterface::ExternalModel<FuncOpInterface, FuncOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationState &state) const {
    return failure();
  }

  /// Return `true` if the given function argument is writable.
  bool isWritable(Operation *op, Value value,
                  const BufferizationState &state) const {
    auto funcOp = cast<FuncOp>(op);
    BlockArgument bbArg = value.dyn_cast<BlockArgument>();
    assert(bbArg && "expected BlockArgument");

    // "linalg.inplaceable" overrides other writability decisions. This is
    // currently used for testing only.
    if (BoolAttr inplaceAttr = funcOp.getArgAttrOfType<BoolAttr>(
            bbArg.getArgNumber(),
            BufferizableOpInterface::kInplaceableAttrName))
      return inplaceAttr.getValue();

    // All function arguments are writable by default.
    return true;
  }

  bool isAllocationHoistingBarrier(Operation *op) const { return true; }
};

} // namespace std_ext
} // namespace comprehensive_bufferize
} // namespace linalg
} // namespace mlir

void mlir::linalg::comprehensive_bufferize::std_ext::
    registerModuleBufferizationExternalModels(DialectRegistry &registry) {
  registry.addOpInterface<CallOp, std_ext::CallOpInterface>();
  registry.addOpInterface<ReturnOp, std_ext::ReturnOpInterface>();
  registry.addOpInterface<FuncOp, std_ext::FuncOpInterface>();
}

/// Set the attribute that triggers inplace bufferization on a FuncOp argument
/// `bbArg`.
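/// For example, with `inPlace = true` the argument ends up annotated as
/// (illustrative):
///   func @f(%t: tensor<?xf32> {linalg.inplaceable = true})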
static void setInPlaceFuncArgument(BlockArgument bbArg, bool inPlace) {
  auto funcOp = cast<FuncOp>(bbArg.getOwner()->getParentOp());
  funcOp.setArgAttr(bbArg.getArgNumber(),
                    BufferizableOpInterface::kInplaceableAttrName,
                    BoolAttr::get(bbArg.getContext(), inPlace));
}

/// Annotate the IR with the result of the analysis. For testing/debugging
/// only.
static void
annotateOpsWithBufferizationMarkers(FuncOp funcOp,
                                    const BufferizationState &state) {
  auto bufferizableOp = cast<BufferizableOpInterface>(funcOp.getOperation());
  for (BlockArgument bbArg : funcOp.getArguments())
    if (bbArg.getType().isa<TensorType>())
      setInPlaceFuncArgument(bbArg, bufferizableOp.isWritable(bbArg, state));
}

LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
    ModuleOp moduleOp, std::unique_ptr<AnalysisBufferizationOptions> options) {
  IRRewriter rewriter(moduleOp.getContext());
  AnalysisBufferizationState state(moduleOp, *options);
  ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
  BufferizationAliasInfo &aliasInfo = state.getAliasInfo();

  if (failed(getFuncOpsOrderedByCalls(moduleOp, moduleState.orderedFuncOps,
                                      moduleState.callerMap)))
    return failure();

  // Collect bbArg/return value information after the analysis.
  options->postAnalysisSteps.emplace_back(
      std::make_unique<EquivalentFuncOpBBArgsAnalysis>());
  options->postAnalysisSteps.emplace_back(
      std::make_unique<FuncOpBbArgReadWriteAnalysis>());

  // Analyze ops.
  for (FuncOp funcOp : moduleState.orderedFuncOps) {
    // No body => no analysis.
    if (funcOp.body().empty())
      continue;

    // Now analyzing function.
    moduleState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress;

    // Analyze funcOp.
    if (failed(analyzeOp(funcOp, state)))
      return failure();

    // Gather equivalence info for CallOps.
    // TODO: Make this a post-analysis step.
    equivalenceAnalysis(funcOp, aliasInfo, moduleState);

    // Mark op as fully analyzed.
    moduleState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed;

    // Add annotations to function arguments.
    if (options->testAnalysisOnly)
      annotateOpsWithBufferizationMarkers(funcOp, state);
  }

  if (options->testAnalysisOnly)
    return success();

  // Bufferize function bodies.
  for (FuncOp funcOp : moduleState.orderedFuncOps) {
    // No body => no bufferization.
    if (funcOp.body().empty())
      continue;

    if (failed(bufferizeOp(funcOp, state)))
      return failure();
  }

  // Bufferize function boundaries.
  for (FuncOp funcOp : moduleState.orderedFuncOps) {
    // Note: It would be good to apply cleanups here but we cannot as aliasInfo
    // would be invalidated.
    if (failed(bufferizeFuncOpBoundary(funcOp, rewriter, state)))
      return failure();

    if (!options->allowReturnMemref &&
        llvm::any_of(funcOp.getType().getResults(), [](Type t) {
          return t.isa<MemRefType, UnrankedMemRefType>();
        })) {
      funcOp->emitError("memref return type is unsupported");
      return failure();
    }
  }

  // Perform a post-processing pass of layout modification at function
  // boundaries according to the kBufferLayoutAttrName.
  layoutPostProcessing(moduleOp);

  // Post-pass cleanup of inplaceable and buffer_layout attributes.
  moduleOp.walk([&](FuncOp op) {
    for (BlockArgument bbArg : op.getArguments())
      removeBufferizationFuncArguments(bbArg);
  });

  return success();
}