MLIR  19.0.0git
BufferizableOpInterfaceImpl.cpp
Go to the documentation of this file.
1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
20 #include "mlir/IR/Dialect.h"
21 #include "mlir/IR/Operation.h"
22 #include "mlir/IR/PatternMatch.h"
23 
24 using namespace mlir;
25 using namespace mlir::bufferization;
26 using namespace mlir::scf;
27 
28 namespace mlir {
29 namespace scf {
30 namespace {
31 
32 /// Helper function for loop bufferization. Cast the given buffer to the given
33 /// memref type.
34 static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
35  assert(isa<BaseMemRefType>(type) && "expected BaseMemRefType");
36  assert(isa<BaseMemRefType>(buffer.getType()) && "expected BaseMemRefType");
37  // If the buffer already has the correct type, no cast is needed.
38  if (buffer.getType() == type)
39  return buffer;
40  // TODO: In case `type` has a layout map that is not the fully dynamic
41  // one, we may not be able to cast the buffer. In that case, the loop
42  // iter_arg's layout map must be changed (see uses of `castBuffer`).
43  assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
44  "scf.while op bufferization: cast incompatible");
45  return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
46 }
47 
48 /// Helper function for loop bufferization. Return "true" if the given value
49 /// is guaranteed to not alias with an external tensor apart from values in
50 /// `exceptions`. A value is external if it is defined outside of the given
51 /// region or if it is an entry block argument of the region.
52 static bool doesNotAliasExternalValue(Value value, Region *region,
53  ValueRange exceptions,
54  const OneShotAnalysisState &state) {
55  assert(region->getBlocks().size() == 1 &&
56  "expected region with single block");
57  bool result = true;
58  state.applyOnAliases(value, [&](Value alias) {
59  if (llvm::is_contained(exceptions, alias))
60  return;
61  Region *aliasRegion = alias.getParentRegion();
62  if (isa<BlockArgument>(alias) && !region->isProperAncestor(aliasRegion))
63  result = false;
64  if (isa<OpResult>(alias) && !region->isAncestor(aliasRegion))
65  result = false;
66  });
67  return result;
68 }
69 
70 /// Bufferization of scf.condition.
/// Tensor operands forwarded to the "after" region of the parent scf.while are
/// replaced by their buffers, cast to the buffer type of the corresponding
/// "after" region block argument. Non-tensor operands are forwarded unchanged.
71 struct ConditionOpInterface
72  : public BufferizableOpInterface::ExternalModel<ConditionOpInterface,
73  scf::ConditionOp> {
 // Forwarded tensor operands are conservatively treated as being read.
74  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
75  const AnalysisState &state) const {
76  return true;
77  }
78 
 // scf.condition only forwards values; it never writes to them.
79  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
80  const AnalysisState &state) const {
81  return false;
82  }
83 
 // No operand is reported as aliasing any OpResult of this op.
84  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
85  const AnalysisState &state) const {
86  return {};
87  }
88 
89  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
90  const AnalysisState &state) const {
91  // Condition operands always bufferize inplace. Otherwise, an alloc + copy
92  // may be generated inside the block. We should not return/yield allocations
93  // when possible.
94  return true;
95  }
96 
97  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
98  const BufferizationOptions &options) const {
99  auto conditionOp = cast<scf::ConditionOp>(op);
100  auto whileOp = cast<scf::WhileOp>(conditionOp->getParentOp());
101 
102  SmallVector<Value> newArgs;
103  for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
104  Value value = it.value();
105  if (isa<TensorType>(value.getType())) {
106  FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
107  if (failed(maybeBuffer))
108  return failure();
 // NOTE(review): the declaration of `resultType` is not visible in this
 // listing; from its uses it holds a FailureOr of the buffer type computed
 // for the matching "after" region block argument of `whileOp`.
110  whileOp.getAfterArguments()[it.index()], options);
111  if (failed(resultType))
112  return failure();
 // Cast so that the forwarded buffer matches the "after" bbArg type.
113  Value buffer = castBuffer(rewriter, *maybeBuffer, *resultType);
114  newArgs.push_back(buffer);
115  } else {
116  newArgs.push_back(value);
117  }
118  }
119 
120  replaceOpWithNewBufferizedOp<scf::ConditionOp>(
121  rewriter, op, conditionOp.getCondition(), newArgs);
122  return success();
123  }
124 };
125 
126 /// Return the unique scf.yield op. If there are multiple or no scf.yield ops,
127 /// return an empty op.
128 static scf::YieldOp getUniqueYieldOp(scf::ExecuteRegionOp executeRegionOp) {
129  scf::YieldOp result;
130  for (Block &block : executeRegionOp.getRegion()) {
131  if (auto yieldOp = dyn_cast<scf::YieldOp>(block.getTerminator())) {
132  if (result)
133  return {};
134  result = yieldOp;
135  }
136  }
137  return result;
138 }
139 
140 /// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
141 /// fully implemented at the moment.
/// Only ops whose region contains exactly one scf.yield are supported; see
/// `verifyAnalysis`.
142 struct ExecuteRegionOpInterface
144  ExecuteRegionOpInterface, scf::ExecuteRegionOp> {
145 
 // The region may contain multiple blocks.
146  static bool supportsUnstructuredControlFlow() { return true; }
147 
 // Every value is reported as writable by this model.
148  bool isWritable(Operation *op, Value value,
149  const AnalysisState &state) const {
150  return true;
151  }
152 
153  LogicalResult verifyAnalysis(Operation *op,
154  const AnalysisState &state) const {
155  auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
156  // TODO: scf.execute_region with multiple yields is not supported.
157  if (!getUniqueYieldOp(executeRegionOp))
158  return op->emitOpError("op without unique scf.yield is not supported");
159  return success();
160  }
161 
163  getAliasingOpOperands(Operation *op, Value value,
164  const AnalysisState &state) const {
 // Block arguments are handled by the generic branch-op helper.
165  if (auto bbArg = dyn_cast<BlockArgument>(value))
166  return getAliasingBranchOpOperands(op, bbArg, state);
167 
168  // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
169  // any SSA value that is in scope. To allow for use-def chain traversal
170  // through ExecuteRegionOps in the analysis, the corresponding yield value
171  // is considered to be aliasing with the result.
172  auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
173  auto it = llvm::find(op->getOpResults(), value);
174  assert(it != op->getOpResults().end() && "invalid value");
175  size_t resultNum = std::distance(op->getOpResults().begin(), it);
176  auto yieldOp = getUniqueYieldOp(executeRegionOp);
177  // Note: If there is no unique scf.yield op, `verifyAnalysis` will fail.
178  if (!yieldOp)
179  return {};
180  return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
181  }
182 
183  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
184  const BufferizationOptions &options) const {
185  auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
 // NOTE(review): `yieldOp` is used below without a null check; this relies
 // on `verifyAnalysis` having rejected ops without a unique scf.yield. Confirm
 // for bufferization paths that do not run the analysis.
186  auto yieldOp = getUniqueYieldOp(executeRegionOp);
 // The new result types are the types of the yielded values.
187  TypeRange newResultTypes(yieldOp.getResults());
188 
189  // Create new op and move over region.
190  auto newOp =
191  rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
192  newOp.getRegion().takeBody(executeRegionOp.getRegion());
193 
194  // Bufferize every block.
195  for (Block &block : newOp.getRegion())
196  if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
197  options)))
198  return failure();
199 
200  // Update all uses of the old op.
 // Results that were tensors are wrapped in ToTensorOps so that existing
 // tensor users of the old op remain type-correct.
201  rewriter.setInsertionPointAfter(newOp);
202  SmallVector<Value> newResults;
203  for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
204  if (isa<TensorType>(it.value())) {
205  newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
206  executeRegionOp.getLoc(), newOp->getResult(it.index())));
207  } else {
208  newResults.push_back(newOp->getResult(it.index()));
209  }
210  }
211 
212  // Replace old op.
213  rewriter.replaceOp(executeRegionOp, newResults);
214 
215  return success();
216  }
217 };
218 
219 /// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
/// A result's buffer type is the common buffer type of the two yielded values.
/// If the layouts differ, the fully dynamic layout is used. If the memory
/// spaces differ, an error is emitted.
220 struct IfOpInterface
221  : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
223  getAliasingOpOperands(Operation *op, Value value,
224  const AnalysisState &state) const {
225  // IfOps do not have tensor OpOperands. The yielded value can be any SSA
226  // value that is in scope. To allow for use-def chain traversal through
227  // IfOps in the analysis, both corresponding yield values from the then/else
228  // branches are considered to be aliasing with the result.
 // Each alias is marked non-definite (isDefinite=false) because the result
 // comes from only one of the two branches.
229  auto ifOp = cast<scf::IfOp>(op);
230  size_t resultNum = std::distance(op->getOpResults().begin(),
231  llvm::find(op->getOpResults(), value));
232  OpOperand *thenOperand = &ifOp.thenYield()->getOpOperand(resultNum);
233  OpOperand *elseOperand = &ifOp.elseYield()->getOpOperand(resultNum);
234  return {{thenOperand, BufferRelation::Equivalent, /*isDefinite=*/false},
235  {elseOperand, BufferRelation::Equivalent, /*isDefinite=*/false}};
236  }
237 
238  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
239  const BufferizationOptions &options) const {
240  OpBuilder::InsertionGuard g(rewriter);
241  auto ifOp = cast<scf::IfOp>(op);
242 
243  // Compute bufferized result types.
 // Non-tensor results keep their type; tensor results use the type from
 // `bufferization::getBufferType`.
244  SmallVector<Type> newTypes;
245  for (Value result : ifOp.getResults()) {
246  if (!isa<TensorType>(result.getType())) {
247  newTypes.push_back(result.getType());
248  continue;
249  }
250  auto bufferType = bufferization::getBufferType(result, options);
251  if (failed(bufferType))
252  return failure();
253  newTypes.push_back(*bufferType);
254  }
255 
256  // Create new op.
257  rewriter.setInsertionPoint(ifOp);
258  auto newIfOp =
259  rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
260  /*withElseRegion=*/true);
261 
262  // Move over then/else blocks.
 // NOTE(review): this assumes the old op has an else block, presumably
 // guaranteed because only scf.if ops with (tensor) results get here. Confirm.
263  rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
264  rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());
265 
266  // Replace op results.
267  replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());
268 
269  return success();
270  }
271 
274  SmallVector<Value> &invocationStack) const {
275  auto ifOp = cast<scf::IfOp>(op);
276  auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
277  auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());
 // NOTE(review): the assertion message "invalid valid" is a typo for
 // "invalid value".
278  assert(value.getDefiningOp() == op && "invalid valid");
279 
280  // Determine buffer types of the true/false branches.
281  auto opResult = cast<OpResult>(value);
282  auto thenValue = thenYieldOp.getOperand(opResult.getResultNumber());
283  auto elseValue = elseYieldOp.getOperand(opResult.getResultNumber());
284  BaseMemRefType thenBufferType, elseBufferType;
285  if (isa<BaseMemRefType>(thenValue.getType())) {
286  // True branch was already bufferized.
287  thenBufferType = cast<BaseMemRefType>(thenValue.getType());
288  } else {
289  auto maybeBufferType =
290  bufferization::getBufferType(thenValue, options, invocationStack);
291  if (failed(maybeBufferType))
292  return failure();
293  thenBufferType = *maybeBufferType;
294  }
295  if (isa<BaseMemRefType>(elseValue.getType())) {
296  // False branch was already bufferized.
297  elseBufferType = cast<BaseMemRefType>(elseValue.getType());
298  } else {
299  auto maybeBufferType =
300  bufferization::getBufferType(elseValue, options, invocationStack);
301  if (failed(maybeBufferType))
302  return failure();
303  elseBufferType = *maybeBufferType;
304  }
305 
306  // Best case: Both branches have the exact same buffer type.
307  if (thenBufferType == elseBufferType)
308  return thenBufferType;
309 
310  // Memory space mismatch.
311  if (thenBufferType.getMemorySpace() != elseBufferType.getMemorySpace())
312  return op->emitError("inconsistent memory space on then/else branches");
313 
314  // Layout maps are different: Promote to fully dynamic layout map.
316  cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace());
317  }
318 };
319 
320 /// Bufferization of scf.index_switch. Replace with a new scf.index_switch that
321 /// yields memrefs.
/// This mirrors the scf.if bufferization, with one scf.yield per case block
/// plus the default block.
322 struct IndexSwitchOpInterface
323  : public BufferizableOpInterface::ExternalModel<IndexSwitchOpInterface,
324  scf::IndexSwitchOp> {
326  getAliasingOpOperands(Operation *op, Value value,
327  const AnalysisState &state) const {
328  // IndexSwitchOps do not have tensor OpOperands. The yielded value can be
329  // any SSA value that is in scope. This is similar to IfOps.
 // The matching yield operand of every case block and of the default block
 // is reported as a non-definite alias of the result.
330  auto switchOp = cast<scf::IndexSwitchOp>(op);
331  int64_t resultNum = cast<OpResult>(value).getResultNumber();
332  AliasingOpOperandList result;
333  for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
334  auto yieldOp =
335  cast<scf::YieldOp>(switchOp.getCaseBlock(i).getTerminator());
336  result.addAlias(AliasingOpOperand(&yieldOp->getOpOperand(resultNum),
338  /*isDefinite=*/false));
339  }
340  auto defaultYieldOp =
341  cast<scf::YieldOp>(switchOp.getDefaultBlock().getTerminator());
342  result.addAlias(AliasingOpOperand(&defaultYieldOp->getOpOperand(resultNum),
344  /*isDefinite=*/false));
345  return result;
346  }
347 
348  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
349  const BufferizationOptions &options) const {
350  OpBuilder::InsertionGuard g(rewriter);
351  auto switchOp = cast<scf::IndexSwitchOp>(op);
352 
353  // Compute bufferized result types.
354  SmallVector<Type> newTypes;
355  for (Value result : switchOp.getResults()) {
356  if (!isa<TensorType>(result.getType())) {
357  newTypes.push_back(result.getType());
358  continue;
359  }
360  auto bufferType = bufferization::getBufferType(result, options);
361  if (failed(bufferType))
362  return failure();
363  newTypes.push_back(*bufferType);
364  }
365 
366  // Create new op.
367  rewriter.setInsertionPoint(switchOp);
368  auto newSwitchOp = rewriter.create<scf::IndexSwitchOp>(
369  switchOp.getLoc(), newTypes, switchOp.getArg(), switchOp.getCases(),
370  switchOp.getCases().size());
371 
372  // Move over blocks.
 // Each case region, and then the default region, is inlined into the
 // corresponding region of the new op.
373  for (auto [src, dest] :
374  llvm::zip(switchOp.getCaseRegions(), newSwitchOp.getCaseRegions()))
375  rewriter.inlineRegionBefore(src, dest, dest.begin());
376  rewriter.inlineRegionBefore(switchOp.getDefaultRegion(),
377  newSwitchOp.getDefaultRegion(),
378  newSwitchOp.getDefaultRegion().begin());
379 
380  // Replace op results.
381  replaceOpWithBufferizedValues(rewriter, op, newSwitchOp->getResults());
382 
383  return success();
384  }
385 
388  SmallVector<Value> &invocationStack) const {
389  auto switchOp = cast<scf::IndexSwitchOp>(op);
390  assert(value.getDefiningOp() == op && "invalid value");
391  int64_t resultNum = cast<OpResult>(value).getResultNumber();
392 
393  // Helper function to get buffer type of a case.
 // NOTE(review): `yieldedTypes` below is never used and could be removed.
394  SmallVector<BaseMemRefType> yieldedTypes;
395  auto getYieldedBufferType = [&](Block &b) -> FailureOr<BaseMemRefType> {
396  auto yieldOp = cast<scf::YieldOp>(b.getTerminator());
397  Value yieldedValue = yieldOp->getOperand(resultNum);
398  if (auto bufferType = dyn_cast<BaseMemRefType>(yieldedValue.getType()))
399  return bufferType;
400  auto maybeBufferType =
401  bufferization::getBufferType(yieldedValue, options, invocationStack);
402  if (failed(maybeBufferType))
403  return failure();
404  return maybeBufferType;
405  };
406 
407  // Compute buffer type of the default case.
 // The default case's buffer type is the starting point. Each other case is
 // then compared against the running `bufferType`.
408  auto maybeBufferType = getYieldedBufferType(switchOp.getDefaultBlock());
409  if (failed(maybeBufferType))
410  return failure();
411  BaseMemRefType bufferType = *maybeBufferType;
412 
413  // Compute buffer types of all other cases.
414  for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
415  auto yieldedBufferType = getYieldedBufferType(switchOp.getCaseBlock(i));
416  if (failed(yieldedBufferType))
417  return failure();
418 
419  // Best case: Both branches have the exact same buffer type.
420  if (bufferType == *yieldedBufferType)
421  continue;
422 
423  // Memory space mismatch.
424  if (bufferType.getMemorySpace() != yieldedBufferType->getMemorySpace())
425  return op->emitError("inconsistent memory space on switch cases");
426 
427  // Layout maps are different: Promote to fully dynamic layout map.
429  cast<TensorType>(value.getType()), bufferType.getMemorySpace());
430  }
431 
432  return bufferType;
433  }
434 };
435 
436 /// Helper function for loop bufferization. Return the indices of all values
437 /// that have a tensor type.
438 static DenseSet<int64_t> getTensorIndices(ValueRange values) {
439  DenseSet<int64_t> result;
440  for (const auto &it : llvm::enumerate(values))
441  if (isa<TensorType>(it.value().getType()))
442  result.insert(it.index());
443  return result;
444 }
445 
446 /// Helper function for loop bufferization. Return the indices of all
447 /// bbArg/yielded value pairs who's buffer relation is "Equivalent".
448 DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
449  ValueRange yieldedValues,
450  const AnalysisState &state) {
451  unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
452  DenseSet<int64_t> result;
453  for (unsigned int i = 0; i < minSize; ++i) {
454  if (!isa<TensorType>(bbArgs[i].getType()) ||
455  !isa<TensorType>(yieldedValues[i].getType()))
456  continue;
457  if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
458  result.insert(i);
459  }
460  return result;
461 }
462 
463 /// Helper function for loop bufferization. Return the bufferized values of the
464 /// given OpOperands. If an operand is not a tensor, return the original value.
/// The returned values are in the same order as `operands`. Returns failure as
/// soon as the buffer of any tensor operand cannot be obtained.
466 getBuffers(RewriterBase &rewriter, const MutableOperandRange &operands,
467  const BufferizationOptions &options) {
468  SmallVector<Value> result;
469  for (OpOperand &opOperand : operands) {
470  if (isa<TensorType>(opOperand.get().getType())) {
471  FailureOr<Value> resultBuffer =
472  getBuffer(rewriter, opOperand.get(), options);
473  if (failed(resultBuffer))
474  return failure();
475  result.push_back(*resultBuffer);
476  } else {
477  result.push_back(opOperand.get());
478  }
479  }
480  return result;
481 }
482 
483 /// Helper function for loop bufferization. Given a list of bbArgs of the new
484 /// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into
485 /// ToTensorOps, so that the block body can be moved over to the new op.
486 static SmallVector<Value>
487 getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
488  const DenseSet<int64_t> &tensorIndices) {
489  SmallVector<Value> result;
490  for (const auto &it : llvm::enumerate(bbArgs)) {
491  size_t idx = it.index();
492  Value val = it.value();
493  if (tensorIndices.contains(idx)) {
494  result.push_back(
495  rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val)
496  .getResult());
497  } else {
498  result.push_back(val);
499  }
500  }
501  return result;
502 }
503 
504 /// Compute the bufferized type of a loop iter_arg. This type must be equal to
505 /// the bufferized type of the corresponding init_arg and the bufferized type
506 /// of the corresponding yielded value.
507 ///
508 /// This function uses bufferization::getBufferType to compute the bufferized
509 /// type of the init_arg and of the yielded value. (The computation of the
510 /// bufferized yielded value type usually requires computing the bufferized type
511 /// of the iter_arg again; the implementation of getBufferType traces back the
512 /// use-def chain of the given value and computes a buffer type along the way.)
513 /// If both buffer types are equal, no casts are needed and the computed buffer
/// type can be used directly. Otherwise, the buffer types can only differ in
514 /// their layout map and a cast must be inserted.
515 /// In that case, a memref type with a fully dynamic layout map is computed.
/// An error is emitted if the two buffer types are in different memory spaces.
516 static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
517  Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
518  const BufferizationOptions &options, SmallVector<Value> &invocationStack) {
519  // Determine the buffer type of the init_arg.
520  auto initArgBufferType =
521  bufferization::getBufferType(initArg, options, invocationStack);
522  if (failed(initArgBufferType))
523  return failure();
524 
525  if (llvm::count(invocationStack, iterArg) >= 2) {
526  // If the iter_arg is already twice on the invocation stack, just take the
527  // type of the init_arg. This is to avoid infinite loops when calculating
528  // the buffer type. This will most likely result in computing a memref type
529  // with a fully dynamic layout map.
530 
531  // Note: For more precise layout map computation, a fixpoint iteration could
532  // be done (i.e., re-computing the yielded buffer type until the bufferized
533  // iter_arg type no longer changes). This current implementation immediately
534  // switches to a fully dynamic layout map when a mismatch between bufferized
535  // init_arg type and bufferized yield value type is detected.
536  return *initArgBufferType;
537  }
538 
539  // Compute the buffer type of the yielded value.
540  BaseMemRefType yieldedValueBufferType;
541  if (isa<BaseMemRefType>(yieldedValue.getType())) {
542  // scf.yield was already bufferized.
543  yieldedValueBufferType = cast<BaseMemRefType>(yieldedValue.getType());
544  } else {
545  // Note: This typically triggers a recursive call for the buffer type of
546  // the iter_arg.
547  auto maybeBufferType =
548  bufferization::getBufferType(yieldedValue, options, invocationStack);
549  if (failed(maybeBufferType))
550  return failure();
551  yieldedValueBufferType = *maybeBufferType;
552  }
553 
554  // If yielded type and init_arg type are the same, use that type directly.
555  if (*initArgBufferType == yieldedValueBufferType)
556  return yieldedValueBufferType;
557 
558  // If there is a mismatch between the yielded buffer type and the init_arg
559  // buffer type, the buffer type must be promoted to a fully dynamic layout
560  // map.
 // NOTE(review): both casts below are no-ops; the values already have type
 // BaseMemRefType.
561  auto yieldedBufferType = cast<BaseMemRefType>(yieldedValueBufferType);
562  auto iterTensorType = cast<TensorType>(iterArg.getType());
563  auto initBufferType = llvm::cast<BaseMemRefType>(*initArgBufferType);
564  if (initBufferType.getMemorySpace() != yieldedBufferType.getMemorySpace())
565  return loopOp->emitOpError(
566  "init_arg and yielded value bufferize to inconsistent memory spaces");
 // Debug-only sanity check: for ranked buffers, the yielded buffer, the
 // init_arg buffer and the iter_arg tensor must all have the same shape.
567 #ifndef NDEBUG
568  if (auto yieldedRankedBufferType = dyn_cast<MemRefType>(yieldedBufferType)) {
569  assert(
570  llvm::all_equal({yieldedRankedBufferType.getShape(),
571  cast<MemRefType>(initBufferType).getShape(),
572  cast<RankedTensorType>(iterTensorType).getShape()}) &&
573  "expected same shape");
574  }
575 #endif // NDEBUG
577  iterTensorType, yieldedBufferType.getMemorySpace());
578 }
579 
580 /// Return `true` if the given loop may have 0 iterations.
581 bool mayHaveZeroIterations(scf::ForOp forOp) {
582  std::optional<int64_t> lb = getConstantIntValue(forOp.getLowerBound());
583  std::optional<int64_t> ub = getConstantIntValue(forOp.getUpperBound());
584  if (!lb.has_value() || !ub.has_value())
585  return true;
586  return *ub <= *lb;
587 }
588 
589 /// Bufferization of scf.for. Replace with a new scf.for that operates on
590 /// memrefs.
591 struct ForOpInterface
592  : public BufferizableOpInterface::ExternalModel<ForOpInterface,
593  scf::ForOp> {
594  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
595  const AnalysisState &state) const {
596  auto forOp = cast<scf::ForOp>(op);
597 
598  // If the loop may have zero iterations (see `mayHaveZeroIterations`), the
599  // results of the op may be their corresponding init_args, meaning that the
 // init_args bufferize to a read.
600  if (mayHaveZeroIterations(forOp))
601  return true;
602 
603  // scf::ForOp alone doesn't bufferize to a memory read, one of the uses of
604  // its matching bbArg may.
605  return state.isValueRead(forOp.getTiedLoopRegionIterArg(&opOperand));
606  }
607 
608  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
609  const AnalysisState &state) const {
610  // Tensor iter_args of scf::ForOps are always considered as a write.
611  return true;
612  }
613 
614  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
615  const AnalysisState &state) const {
 // An init_arg may alias only its tied loop result. The alias is reported as
 // definite only when the buffer relation is Equivalent.
616  auto forOp = cast<scf::ForOp>(op);
617  OpResult opResult = forOp.getTiedLoopResult(&opOperand);
618  BufferRelation relation = bufferRelation(op, opResult, state);
619  return {{opResult, relation,
620  /*isDefinite=*/relation == BufferRelation::Equivalent}};
621  }
622 
623  BufferRelation bufferRelation(Operation *op, OpResult opResult,
624  const AnalysisState &state) const {
625  // ForOp results are equivalent to their corresponding init_args if the
626  // corresponding iter_args and yield values are equivalent.
627  auto forOp = cast<scf::ForOp>(op);
628  BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
629  bool equivalentYield = state.areEquivalentBufferizedValues(
630  bbArg, forOp.getTiedLoopYieldedValue(bbArg)->get());
631  return equivalentYield ? BufferRelation::Equivalent
633  }
634 
635  bool isWritable(Operation *op, Value value,
636  const AnalysisState &state) const {
637  // Interestingly, scf::ForOp's bbArg can **always** be viewed
638  // inplace from the perspective of ops nested under:
639  // 1. Either the matching iter operand is not bufferized inplace and an
640  // alloc + optional copy makes the bbArg itself inplaceable.
641  // 2. Or the matching iter operand is bufferized inplace and bbArg just
642  // bufferizes to that too.
643  return true;
644  }
645 
646  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
647  const AnalysisState &state) const {
648  auto bufferizableOp = cast<BufferizableOpInterface>(op);
649  if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
650  return failure();
651 
652  if (!state.getOptions().enforceAliasingInvariants)
653  return success();
654 
655  // According to the `getAliasing...` implementations, a bufferized OpResult
656  // may alias only with the corresponding bufferized init_arg (or with a
657  // newly allocated buffer) and not with other buffers defined outside of the
658  // loop. I.e., the i-th OpResult may alias with the i-th init_arg;
659  // but not with any other OpOperand.
 // To enforce this, every tensor yield operand that may alias a value
 // external to the loop body (other than its own iter_arg) is replaced by the
 // value produced by the allocation helper (`alloc`).
660  auto forOp = cast<scf::ForOp>(op);
661  auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
662  OpBuilder::InsertionGuard g(rewriter);
663  rewriter.setInsertionPoint(yieldOp);
664 
665  // Indices of all iter_args that have tensor type. These are the ones that
666  // are bufferized.
667  DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
668  // For every yielded value, does it alias with something defined outside of
669  // the loop?
670  SmallVector<Value> yieldValues;
671  for (const auto it : llvm::enumerate(yieldOp.getResults())) {
672  // Note: `state` is guaranteed to be a `OneShotAnalysisState`, but this
673  // type cannot be used in the signature of `resolveConflicts` because the
674  // op interface is in the "IR" build unit and the `OneShotAnalysisState`
675  // is defined in the "Transforms" build unit.
676  if (!indices.contains(it.index()) ||
677  doesNotAliasExternalValue(
678  it.value(), &forOp.getRegion(),
679  /*exceptions=*/forOp.getRegionIterArg(it.index()),
680  static_cast<const OneShotAnalysisState &>(state))) {
681  yieldValues.push_back(it.value());
682  continue;
683  }
685  rewriter, yieldOp.getLoc(), it.value(), state.getOptions());
686  if (failed(alloc))
687  return failure();
688  yieldValues.push_back(*alloc);
689  }
690 
691  rewriter.modifyOpInPlace(
692  yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); });
693  return success();
694  }
695 
698  SmallVector<Value> &invocationStack) const {
699  auto forOp = cast<scf::ForOp>(op);
700  assert(getOwnerOfValue(value) == op && "invalid value");
701  assert(isa<TensorType>(value.getType()) && "expected tensor type");
702 
703  if (auto opResult = dyn_cast<OpResult>(value)) {
704  // The type of an OpResult must match the corresponding iter_arg type.
705  BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
706  return bufferization::getBufferType(bbArg, options, invocationStack);
707  }
708 
709  // Compute result/argument number.
710  BlockArgument bbArg = cast<BlockArgument>(value);
711  unsigned resultNum = forOp.getTiedLoopResult(bbArg).getResultNumber();
712 
713  // Compute the bufferized type.
 // The iter_arg type is derived from its init_arg and its yielded value.
714  auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
715  Value yieldedValue = yieldOp.getOperand(resultNum);
716  BlockArgument iterArg = forOp.getRegionIterArgs()[resultNum];
717  Value initArg = forOp.getInitArgs()[resultNum];
718  return computeLoopRegionIterArgBufferType(
719  op, iterArg, initArg, yieldedValue, options, invocationStack);
720  }
721 
722  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
723  const BufferizationOptions &options) const {
724  auto forOp = cast<scf::ForOp>(op);
725  Block *oldLoopBody = forOp.getBody();
726 
727  // Indices of all iter_args that have tensor type. These are the ones that
728  // are bufferized.
729  DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
730 
731  // The new memref init_args of the loop.
732  FailureOr<SmallVector<Value>> maybeInitArgs =
733  getBuffers(rewriter, forOp.getInitArgsMutable(), options);
734  if (failed(maybeInitArgs))
735  return failure();
736  SmallVector<Value> initArgs = *maybeInitArgs;
737 
738  // Cast init_args if necessary.
 // Each tensor init_arg buffer is cast to the buffer type of its tied loop
 // result, so that the new loop's iter_arg types are consistent.
739  SmallVector<Value> castedInitArgs;
740  for (const auto &it : llvm::enumerate(initArgs)) {
741  Value initArg = it.value();
742  Value result = forOp->getResult(it.index());
743  // If the type is not a tensor, bufferization doesn't need to touch it.
744  if (!isa<TensorType>(result.getType())) {
745  castedInitArgs.push_back(initArg);
746  continue;
747  }
748  auto targetType = bufferization::getBufferType(result, options);
749  if (failed(targetType))
750  return failure();
751  castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
752  }
753 
754  // Construct a new scf.for op with memref instead of tensor values.
755  auto newForOp = rewriter.create<scf::ForOp>(
756  forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
757  forOp.getStep(), castedInitArgs);
758  newForOp->setAttrs(forOp->getAttrs());
759  Block *loopBody = newForOp.getBody();
760 
761  // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
762  // iter_args of the new loop in ToTensorOps.
 // The induction variable is prepended so the replacement list lines up with
 // the old body's block arguments (induction variable first, then iter_args).
763  rewriter.setInsertionPointToStart(loopBody);
764  SmallVector<Value> iterArgs =
765  getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(), indices);
766  iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());
767 
768  // Move loop body to new loop.
769  rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);
770 
771  // Replace loop results.
772  replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());
773 
774  return success();
775  }
776 
777  /// Assert that yielded values of an scf.for op are equivalent to their
778  /// corresponding bbArgs. In that case, the buffer relations of the
779  /// corresponding OpResults are "Equivalent".
780  ///
781  /// If this is not the case, allocs + copies are inserted and yielded from
782  /// the loop. This could be a performance problem, so it must be explicitly
783  /// activated with the `allowReturnAllocsFromLoops` option.
784  LogicalResult verifyAnalysis(Operation *op,
785  const AnalysisState &state) const {
786  const auto &options =
787  static_cast<const OneShotBufferizationOptions &>(state.getOptions());
788  if (options.allowReturnAllocsFromLoops)
789  return success();
790 
791  auto forOp = cast<scf::ForOp>(op);
792  auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
793  for (OpResult opResult : op->getOpResults()) {
794  if (!isa<TensorType>(opResult.getType()))
795  continue;
796 
797  // Note: This is overly strict. We should check for aliasing bufferized
798  // values. But we don't have a "must-alias" analysis yet.
799  if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
800  return yieldOp->emitError()
801  << "Yield operand #" << opResult.getResultNumber()
802  << " is not equivalent to the corresponding iter bbArg";
803  }
804 
805  return success();
806  }
807 };
808 
809 /// Bufferization of scf.while. Replace with a new scf.while that operates on
810 /// memrefs.
811 struct WhileOpInterface
812  : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
813  scf::WhileOp> {
814  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
815  const AnalysisState &state) const {
816  // Tensor iter_args of scf::WhileOps are always considered as a read.
817  return true;
818  }
819 
820  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
821  const AnalysisState &state) const {
822  // Tensor iter_args of scf::WhileOps are always considered as a write.
823  return true;
824  }
825 
826  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
827  const AnalysisState &state) const {
828  auto whileOp = cast<scf::WhileOp>(op);
829  unsigned int idx = opOperand.getOperandNumber();
830 
831  // The OpResults and OpOperands may not match. They may not even have the
832  // same type. The number of OpResults and OpOperands can also differ.
833  if (idx >= op->getNumResults() ||
834  opOperand.get().getType() != op->getResult(idx).getType())
835  return {};
836 
837  // The only aliasing OpResult may be the one at the same index.
838  OpResult opResult = whileOp->getResult(idx);
839  BufferRelation relation = bufferRelation(op, opResult, state);
840  return {{opResult, relation,
841  /*isDefinite=*/relation == BufferRelation::Equivalent}};
842  }
843 
844  BufferRelation bufferRelation(Operation *op, OpResult opResult,
845  const AnalysisState &state) const {
846  // WhileOp results are equivalent to their corresponding init_args if the
847  // corresponding iter_args and yield values are equivalent (for both the
848  // "before" and the "after" block).
849  unsigned int resultNumber = opResult.getResultNumber();
850  auto whileOp = cast<scf::WhileOp>(op);
851 
852  // The "before" region bbArgs and the OpResults may not match.
853  if (resultNumber >= whileOp.getBeforeArguments().size())
855  if (opResult.getType() !=
856  whileOp.getBeforeArguments()[resultNumber].getType())
858 
859  auto conditionOp = whileOp.getConditionOp();
860  BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
861  Value conditionOperand = conditionOp.getArgs()[resultNumber];
862  bool equivCondition =
863  state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);
864 
865  auto yieldOp = whileOp.getYieldOp();
866  BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
867  Value yieldOperand = yieldOp.getOperand(resultNumber);
868  bool equivYield =
869  state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);
870 
871  return equivCondition && equivYield ? BufferRelation::Equivalent
873  }
874 
875  bool isWritable(Operation *op, Value value,
876  const AnalysisState &state) const {
877  // Interestingly, scf::WhileOp's bbArg can **always** be viewed
878  // inplace from the perspective of ops nested under:
879  // 1. Either the matching iter operand is not bufferized inplace and an
880  // alloc + optional copy makes the bbArg itself inplaceable.
881  // 2. Or the matching iter operand is bufferized inplace and bbArg just
882  // bufferizes to that too.
883  return true;
884  }
885 
886  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
887  const AnalysisState &state) const {
888  auto bufferizableOp = cast<BufferizableOpInterface>(op);
889  if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
890  return failure();
891 
892  if (!state.getOptions().enforceAliasingInvariants)
893  return success();
894 
895  // According to the `getAliasing...` implementations, a bufferized OpResult
896  // may alias only with the corresponding bufferized init_arg and with no
897  // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg;
898  // but not with any other OpOperand. If a corresponding OpResult/init_arg
899  // pair bufferizes to equivalent buffers, this aliasing requirement is
900  // satisfied. Otherwise, we cannot be sure and must yield a new buffer copy.
901  // (New buffer copies do not alias with any buffer.)
902  OpBuilder::InsertionGuard g(rewriter);
903  auto whileOp = cast<scf::WhileOp>(op);
904  auto conditionOp = whileOp.getConditionOp();
905 
906  // For every yielded value, is the value equivalent to its corresponding
907  // bbArg?
908  DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
909  whileOp.getBeforeArguments(), conditionOp.getArgs(), state);
910  DenseSet<int64_t> equivalentYieldsAfter = getEquivalentBuffers(
911  whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(), state);
912 
913  // Update "before" region.
914  rewriter.setInsertionPoint(conditionOp);
915  SmallVector<Value> beforeYieldValues;
916  for (int64_t idx = 0;
917  idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
918  Value value = conditionOp.getArgs()[idx];
919  if (!isa<TensorType>(value.getType()) ||
920  (equivalentYieldsAfter.contains(idx) &&
921  equivalentYieldsBefore.contains(idx))) {
922  beforeYieldValues.push_back(value);
923  continue;
924  }
926  rewriter, conditionOp.getLoc(), value, state.getOptions());
927  if (failed(alloc))
928  return failure();
929  beforeYieldValues.push_back(*alloc);
930  }
931  rewriter.modifyOpInPlace(conditionOp, [&]() {
932  conditionOp.getArgsMutable().assign(beforeYieldValues);
933  });
934 
935  return success();
936  }
937 
938  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
939  const BufferizationOptions &options) const {
940  auto whileOp = cast<scf::WhileOp>(op);
941 
942  // Indices of all bbArgs that have tensor type. These are the ones that
943  // are bufferized. The "before" and "after" regions may have different args.
944  DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
945  DenseSet<int64_t> indicesAfter =
946  getTensorIndices(whileOp.getAfterArguments());
947 
948  // The new memref init_args of the loop.
949  FailureOr<SmallVector<Value>> maybeInitArgs =
950  getBuffers(rewriter, whileOp.getInitsMutable(), options);
951  if (failed(maybeInitArgs))
952  return failure();
953  SmallVector<Value> initArgs = *maybeInitArgs;
954 
955  // Cast init_args if necessary.
956  SmallVector<Value> castedInitArgs;
957  for (const auto &it : llvm::enumerate(initArgs)) {
958  Value initArg = it.value();
959  Value beforeArg = whileOp.getBeforeArguments()[it.index()];
960  // If the type is not a tensor, bufferization doesn't need to touch it.
961  if (!isa<TensorType>(beforeArg.getType())) {
962  castedInitArgs.push_back(initArg);
963  continue;
964  }
965  auto targetType = bufferization::getBufferType(beforeArg, options);
966  if (failed(targetType))
967  return failure();
968  castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
969  }
970 
971  // The result types of a WhileOp are the same as the "after" bbArg types.
972  SmallVector<Type> argsTypesAfter = llvm::to_vector(
973  llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
974  if (!isa<TensorType>(bbArg.getType()))
975  return bbArg.getType();
976  // TODO: error handling
977  return llvm::cast<Type>(
978  *bufferization::getBufferType(bbArg, options));
979  }));
980 
981  // Construct a new scf.while op with memref instead of tensor values.
982  ValueRange argsRangeBefore(castedInitArgs);
983  TypeRange argsTypesBefore(argsRangeBefore);
984  auto newWhileOp = rewriter.create<scf::WhileOp>(
985  whileOp.getLoc(), argsTypesAfter, castedInitArgs);
986 
987  // Add before/after regions to the new op.
988  SmallVector<Location> bbArgLocsBefore(castedInitArgs.size(),
989  whileOp.getLoc());
990  SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
991  whileOp.getLoc());
992  Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
993  newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore);
994  Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
995  newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter);
996 
997  // Set up new iter_args and move the loop condition block to the new op.
998  // The old block uses tensors, so wrap the (memref) bbArgs of the new block
999  // in ToTensorOps.
1000  rewriter.setInsertionPointToStart(newBeforeBody);
1001  SmallVector<Value> newBeforeArgs = getBbArgReplacements(
1002  rewriter, newWhileOp.getBeforeArguments(), indicesBefore);
1003  rewriter.mergeBlocks(whileOp.getBeforeBody(), newBeforeBody, newBeforeArgs);
1004 
1005  // Set up new iter_args and move the loop body block to the new op.
1006  // The old block uses tensors, so wrap the (memref) bbArgs of the new block
1007  // in ToTensorOps.
1008  rewriter.setInsertionPointToStart(newAfterBody);
1009  SmallVector<Value> newAfterArgs = getBbArgReplacements(
1010  rewriter, newWhileOp.getAfterArguments(), indicesAfter);
1011  rewriter.mergeBlocks(whileOp.getAfterBody(), newAfterBody, newAfterArgs);
1012 
1013  // Replace loop results.
1014  replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());
1015 
1016  return success();
1017  }
1018 
1021  SmallVector<Value> &invocationStack) const {
1022  auto whileOp = cast<scf::WhileOp>(op);
1023  assert(getOwnerOfValue(value) == op && "invalid value");
1024  assert(isa<TensorType>(value.getType()) && "expected tensor type");
1025 
1026  // Case 1: Block argument of the "before" region.
1027  if (auto bbArg = dyn_cast<BlockArgument>(value)) {
1028  if (bbArg.getOwner()->getParent() == &whileOp.getBefore()) {
1029  Value initArg = whileOp.getInits()[bbArg.getArgNumber()];
1030  auto yieldOp = whileOp.getYieldOp();
1031  Value yieldedValue = yieldOp.getOperand(bbArg.getArgNumber());
1032  return computeLoopRegionIterArgBufferType(
1033  op, bbArg, initArg, yieldedValue, options, invocationStack);
1034  }
1035  }
1036 
1037  // Case 2: OpResult of the loop or block argument of the "after" region.
1038  // The bufferized "after" bbArg type can be directly computed from the
1039  // bufferized "before" bbArg type.
1040  unsigned resultNum;
1041  if (auto opResult = dyn_cast<OpResult>(value)) {
1042  resultNum = opResult.getResultNumber();
1043  } else if (cast<BlockArgument>(value).getOwner()->getParent() ==
1044  &whileOp.getAfter()) {
1045  resultNum = cast<BlockArgument>(value).getArgNumber();
1046  } else {
1047  llvm_unreachable("invalid value");
1048  }
1049  Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
1050  if (!isa<TensorType>(conditionYieldedVal.getType())) {
1051  // scf.condition was already bufferized.
1052  return cast<BaseMemRefType>(conditionYieldedVal.getType());
1053  }
1054  return bufferization::getBufferType(conditionYieldedVal, options,
1055  invocationStack);
1056  }
1057 
1058  /// Assert that yielded values of an scf.while op are equivalent to their
1059  /// corresponding bbArgs. In that case, the buffer relations of the
1060  /// corresponding OpResults are "Equivalent".
1061  ///
1062  /// If this is not the case, allocs+copies are inserted and yielded from
1063  /// the loop. This could be a performance problem, so it must be explicitly
1064  /// activated with `allow-return-allocs`.
1065  ///
1066  /// Not: In contrast to scf::ForOp, scf::WhileOp has two regions and the
1067  /// equivalence condition must be checked for both.
1068  LogicalResult verifyAnalysis(Operation *op,
1069  const AnalysisState &state) const {
1070  auto whileOp = cast<scf::WhileOp>(op);
1071  const auto &options =
1072  static_cast<const OneShotBufferizationOptions &>(state.getOptions());
1073  if (options.allowReturnAllocsFromLoops)
1074  return success();
1075 
1076  auto conditionOp = whileOp.getConditionOp();
1077  for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
1078  Block *block = conditionOp->getBlock();
1079  if (!isa<TensorType>(it.value().getType()))
1080  continue;
1081  if (it.index() >= block->getNumArguments() ||
1082  !state.areEquivalentBufferizedValues(it.value(),
1083  block->getArgument(it.index())))
1084  return conditionOp->emitError()
1085  << "Condition arg #" << it.index()
1086  << " is not equivalent to the corresponding iter bbArg";
1087  }
1088 
1089  auto yieldOp = whileOp.getYieldOp();
1090  for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
1091  Block *block = yieldOp->getBlock();
1092  if (!isa<TensorType>(it.value().getType()))
1093  continue;
1094  if (it.index() >= block->getNumArguments() ||
1095  !state.areEquivalentBufferizedValues(it.value(),
1096  block->getArgument(it.index())))
1097  return yieldOp->emitError()
1098  << "Yield operand #" << it.index()
1099  << " is not equivalent to the corresponding iter bbArg";
1100  }
1101 
1102  return success();
1103  }
1104 };
1105 
1106 /// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so
1107 /// this is for analysis only.
1108 struct YieldOpInterface
1109  : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
1110  scf::YieldOp> {
1111  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
1112  const AnalysisState &state) const {
1113  return true;
1114  }
1115 
1116  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
1117  const AnalysisState &state) const {
1118  return false;
1119  }
1120 
1121  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
1122  const AnalysisState &state) const {
1123  if (auto ifOp = dyn_cast<scf::IfOp>(op->getParentOp())) {
1124  return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
1125  BufferRelation::Equivalent, /*isDefinite=*/false}};
1126  }
1127  if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
1128  return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
1130  return {};
1131  }
1132 
1133  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
1134  const AnalysisState &state) const {
1135  // Yield operands always bufferize inplace. Otherwise, an alloc + copy
1136  // may be generated inside the block. We should not return/yield allocations
1137  // when possible.
1138  return true;
1139  }
1140 
1141  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
1142  const BufferizationOptions &options) const {
1143  auto yieldOp = cast<scf::YieldOp>(op);
1144  if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::IndexSwitchOp, scf::ForOp,
1145  scf::WhileOp>(yieldOp->getParentOp()))
1146  return yieldOp->emitError("unsupported scf::YieldOp parent");
1147 
1148  SmallVector<Value> newResults;
1149  for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
1150  Value value = it.value();
1151  if (isa<TensorType>(value.getType())) {
1152  FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
1153  if (failed(maybeBuffer))
1154  return failure();
1155  Value buffer = *maybeBuffer;
1156  // We may have to cast the value before yielding it.
1157  if (isa<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>(
1158  yieldOp->getParentOp())) {
1160  yieldOp->getParentOp()->getResult(it.index()), options);
1161  if (failed(resultType))
1162  return failure();
1163  buffer = castBuffer(rewriter, buffer, *resultType);
1164  } else if (auto whileOp =
1165  dyn_cast<scf::WhileOp>(yieldOp->getParentOp())) {
1167  whileOp.getBeforeArguments()[it.index()], options);
1168  if (failed(resultType))
1169  return failure();
1170  buffer = castBuffer(rewriter, buffer, *resultType);
1171  }
1172  newResults.push_back(buffer);
1173  } else {
1174  newResults.push_back(value);
1175  }
1176  }
1177 
1178  replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op, newResults);
1179  return success();
1180  }
1181 };
1182 
1183 /// Return `true` if the given loop may have 0 iterations.
1184 bool mayHaveZeroIterations(scf::ForallOp forallOp) {
1185  for (auto [lb, ub] : llvm::zip(forallOp.getMixedLowerBound(),
1186  forallOp.getMixedUpperBound())) {
1187  std::optional<int64_t> lbConst = getConstantIntValue(lb);
1188  std::optional<int64_t> ubConst = getConstantIntValue(ub);
1189  if (!lbConst.has_value() || !ubConst.has_value() || *lbConst >= *ubConst)
1190  return true;
1191  }
1192  return false;
1193 }
1194 
1195 /// Bufferization of ForallOp. This also bufferizes the terminator of the
1196 /// region. There are op interfaces for the terminators (InParallelOp
1197 /// and ParallelInsertSliceOp), but these are only used during analysis. Not
1198 /// for bufferization.
1199 struct ForallOpInterface
1200  : public BufferizableOpInterface::ExternalModel<ForallOpInterface,
1201  ForallOp> {
1202  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
1203  const AnalysisState &state) const {
1204  auto forallOp = cast<ForallOp>(op);
1205 
1206  // If the loop has zero iterations, the results of the op are their
1207  // corresponding shared_outs, meaning that the shared_outs bufferize to a
1208  // read.
1209  if (mayHaveZeroIterations(forallOp))
1210  return true;
1211 
1212  // scf::ForallOp alone doesn't bufferize to a memory read, one of the
1213  // uses of its matching bbArg may.
1214  return state.isValueRead(forallOp.getTiedBlockArgument(&opOperand));
1215  }
1216 
1217  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
1218  const AnalysisState &state) const {
1219  // Outputs of scf::ForallOps are always considered as a write.
1220  return true;
1221  }
1222 
1223  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
1224  const AnalysisState &state) const {
1225  auto forallOp = cast<ForallOp>(op);
1226  return {
1227  {{forallOp.getTiedOpResult(&opOperand), BufferRelation::Equivalent}}};
1228  }
1229 
1230  bool isWritable(Operation *op, Value value,
1231  const AnalysisState &state) const {
1232  return true;
1233  }
1234 
1235  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
1236  const BufferizationOptions &options) const {
1237  OpBuilder::InsertionGuard guard(rewriter);
1238  auto forallOp = cast<ForallOp>(op);
1239  int64_t rank = forallOp.getRank();
1240 
1241  // Get buffers for all output operands.
1242  SmallVector<Value> buffers;
1243  for (Value out : forallOp.getOutputs()) {
1244  FailureOr<Value> buffer = getBuffer(rewriter, out, options);
1245  if (failed(buffer))
1246  return failure();
1247  buffers.push_back(*buffer);
1248  }
1249 
1250  // Use buffers instead of block arguments.
1251  rewriter.setInsertionPointToStart(forallOp.getBody());
1252  for (const auto &it : llvm::zip(
1253  forallOp.getBody()->getArguments().drop_front(rank), buffers)) {
1254  BlockArgument bbArg = std::get<0>(it);
1255  Value buffer = std::get<1>(it);
1256  Value bufferAsTensor =
1257  rewriter.create<ToTensorOp>(forallOp.getLoc(), buffer);
1258  bbArg.replaceAllUsesWith(bufferAsTensor);
1259  }
1260 
1261  // Create new ForallOp without any results and drop the automatically
1262  // introduced terminator.
1263  rewriter.setInsertionPoint(forallOp);
1264  ForallOp newForallOp;
1265  newForallOp = rewriter.create<ForallOp>(
1266  forallOp.getLoc(), forallOp.getMixedLowerBound(),
1267  forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
1268  /*outputs=*/ValueRange(), forallOp.getMapping());
1269 
1270  rewriter.eraseOp(newForallOp.getBody()->getTerminator());
1271 
1272  // Move over block contents of the old op.
1273  SmallVector<Value> replacementBbArgs;
1274  replacementBbArgs.append(newForallOp.getBody()->getArguments().begin(),
1275  newForallOp.getBody()->getArguments().end());
1276  replacementBbArgs.append(forallOp.getOutputs().size(), Value());
1277  rewriter.mergeBlocks(forallOp.getBody(), newForallOp.getBody(),
1278  replacementBbArgs);
1279 
1280  // Remove the old op and replace all of its uses.
1281  replaceOpWithBufferizedValues(rewriter, op, buffers);
1282 
1283  return success();
1284  }
1285 
1288  SmallVector<Value> &invocationStack) const {
1289  auto forallOp = cast<ForallOp>(op);
1290 
1291  if (auto bbArg = dyn_cast<BlockArgument>(value))
1292  // A tensor block argument has the same bufferized type as the
1293  // corresponding output operand.
1295  forallOp.getTiedOpOperand(bbArg)->get(), options, invocationStack);
1296 
1297  // The bufferized result type is the same as the bufferized type of the
1298  // corresponding output operand.
1300  forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options,
1301  invocationStack);
1302  }
1303 
1304  bool isRepetitiveRegion(Operation *op, unsigned index) const {
1305  auto forallOp = cast<ForallOp>(op);
1306 
1307  // This op is repetitive if it has 1 or more steps.
1308  // If the control variables are dynamic, it is also considered so.
1309  for (auto [lb, ub, step] :
1310  llvm::zip(forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
1311  forallOp.getMixedStep())) {
1312  std::optional<int64_t> lbConstant = getConstantIntValue(lb);
1313  if (!lbConstant)
1314  return true;
1315 
1316  std::optional<int64_t> ubConstant = getConstantIntValue(ub);
1317  if (!ubConstant)
1318  return true;
1319 
1320  std::optional<int64_t> stepConstant = getConstantIntValue(step);
1321  if (!stepConstant)
1322  return true;
1323 
1324  if (*lbConstant + *stepConstant < *ubConstant)
1325  return true;
1326  }
1327  return false;
1328  }
1329 
1330  bool isParallelRegion(Operation *op, unsigned index) const {
1331  return isRepetitiveRegion(op, index);
1332  }
1333 };
1334 
1335 /// Nothing to do for InParallelOp.
1336 struct InParallelOpInterface
1337  : public BufferizableOpInterface::ExternalModel<InParallelOpInterface,
1338  InParallelOp> {
1339  LogicalResult bufferize(Operation *op, RewriterBase &b,
1340  const BufferizationOptions &options) const {
1341  llvm_unreachable("op does not have any tensor OpOperands / OpResults");
1342  return failure();
1343  }
1344 };
1345 
1346 } // namespace
1347 } // namespace scf
1348 } // namespace mlir
1349 
1351  DialectRegistry &registry) {
1352  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
1353  ConditionOp::attachInterface<ConditionOpInterface>(*ctx);
1354  ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
1355  ForOp::attachInterface<ForOpInterface>(*ctx);
1356  IfOp::attachInterface<IfOpInterface>(*ctx);
1357  IndexSwitchOp::attachInterface<IndexSwitchOpInterface>(*ctx);
1358  ForallOp::attachInterface<ForallOpInterface>(*ctx);
1359  InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
1360  WhileOp::attachInterface<WhileOpInterface>(*ctx);
1361  YieldOp::attachInterface<YieldOpInterface>(*ctx);
1362  });
1363 }
static bool isRepetitiveRegion(Region *region, const BufferizationOptions &options)
static llvm::ManagedStatic< PassManagerOptions > options
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
Base class for generic analysis states.
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:138
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
This class represents an argument of a Block.
Definition: Value.h:315
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:324
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:327
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgument getArgument(unsigned i)
Definition: Block.h:126
unsigned getNumArguments()
Definition: Block.h:125
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition: Block.cpp:159
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:115
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:414
This class represents an operand of an operation.
Definition: Value.h:263
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:453
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:465
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:305
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
result_range getOpResults()
Definition: Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
Definition: Region.h:222
bool isProperAncestor(Region *other)
Return true if this region is a proper ancestor of the other region.
Definition: Region.cpp:50
BlockListType & getBlocks()
Definition: Region.h:45
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
Definition: Region.h:241
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:399
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:631
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:378
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: Value.h:169
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition: Value.cpp:41
State for analysis-enabled bufferization.
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, ValueRange values)
Replace an op with replacement values.
Operation * getOwnerOfValue(Value value)
Return the owner of the given value.
LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter, const BufferizationOptions &options)
Bufferize the signature of block and its callers (i.e., ops that have the given block as a successor)...
Definition: Bufferize.cpp:554
FailureOr< Value > allocateTensorForShapedValue(OpBuilder &b, Location loc, Value shapedValue, const BufferizationOptions &options, bool copy=true)
Create an AllocTensorOp for the given shaped value (memref or tensor).
FailureOr< BaseMemRefType > getBufferType(Value value, const BufferizationOptions &options)
Return the buffer type for a given Value (tensor) after bufferization without bufferizing any IR.
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options)
Lookup the buffer for the given value.
BaseMemRefType getMemRefTypeWithFullyDynamicLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with fully dynamic layout.
BufferRelation
Specifies a fine-grain relationship between buffers to enable more analysis.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Options for BufferizableOpInterface-based bufferization.
Options for analysis-enabled bufferization.
A template that provides a default implementation of getAliasingOpOperands for ops that support unstr...