//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::scf;

namespace mlir {
namespace scf {
namespace {

/// Helper function for loop bufferization. Cast the given buffer to the given
/// memref type.
static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
  assert(isa<BaseMemRefType>(type) && "expected BaseMemRefType");
  assert(isa<BaseMemRefType>(buffer.getType()) && "expected BaseMemRefType");
  // If the buffer already has the correct type, no cast is needed.
  if (buffer.getType() == type)
    return buffer;
  // TODO: In case `type` has a layout map that is not the fully dynamic
  // one, we may not be able to cast the buffer. In that case, the loop
  // iter_arg's layout map must be changed (see uses of `castBuffer`).
  assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
         "scf.while op bufferization: cast incompatible");
  return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
}
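
// Illustrative usage (not from the original source): promoting a buffer with
// a static identity layout to the fully dynamic strided layout that loop
// iter_args are typically promoted to. The two types are cast-compatible, so
// `castBuffer` emits a plain memref.cast:
//
//   %cast = memref.cast %buf
//       : memref<8xf32> to memref<8xf32, strided<[?], offset: ?>>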

/// Helper function for loop bufferization. Return "true" if the given value
/// is guaranteed to not alias with an external tensor apart from values in
/// `exceptions`. A value is external if it is defined outside of the given
/// region or if it is an entry block argument of the region.
static bool doesNotAliasExternalValue(Value value, Region *region,
                                      ValueRange exceptions,
                                      const OneShotAnalysisState &state) {
  assert(llvm::hasSingleElement(region->getBlocks()) &&
         "expected region with single block");
  bool result = true;
  state.applyOnAliases(value, [&](Value alias) {
    if (llvm::is_contained(exceptions, alias))
      return;
    Region *aliasRegion = alias.getParentRegion();
    if (isa<BlockArgument>(alias) && !region->isProperAncestor(aliasRegion))
      result = false;
    if (isa<OpResult>(alias) && !region->isAncestor(aliasRegion))
      result = false;
  });
  return result;
}

/// Bufferization of scf.condition.
struct ConditionOpInterface
    : public BufferizableOpInterface::ExternalModel<ConditionOpInterface,
                                                    scf::ConditionOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Condition operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield
    // allocations when possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto conditionOp = cast<scf::ConditionOp>(op);
    auto whileOp = cast<scf::WhileOp>(conditionOp->getParentOp());

    SmallVector<Value> newArgs;
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      Value value = it.value();
      if (isa<TensorType>(value.getType())) {
        FailureOr<Value> maybeBuffer =
            getBuffer(rewriter, value, options, state);
        if (failed(maybeBuffer))
          return failure();
        FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
            whileOp.getAfterArguments()[it.index()], options, state);
        if (failed(resultType))
          return failure();
        Value buffer = castBuffer(rewriter, *maybeBuffer, *resultType);
        newArgs.push_back(buffer);
      } else {
        newArgs.push_back(value);
      }
    }

    replaceOpWithNewBufferizedOp<scf::ConditionOp>(
        rewriter, op, conditionOp.getCondition(), newArgs);
    return success();
  }
};
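
// Illustrative result (hypothetical IR): after bufferization, scf.condition
// forwards buffers that have been cast to the buffer types of the "after"
// region arguments, e.g.:
//
//   scf.condition(%cmp) %cast : memref<8xf32, strided<[?], offset: ?>>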

/// Return the unique scf.yield op. If there are multiple or no scf.yield ops,
/// return an empty op.
static scf::YieldOp getUniqueYieldOp(scf::ExecuteRegionOp executeRegionOp) {
  scf::YieldOp result;
  for (Block &block : executeRegionOp.getRegion()) {
    if (auto yieldOp = dyn_cast<scf::YieldOp>(block.getTerminator())) {
      if (result)
        return {};
      result = yieldOp;
    }
  }
  return result;
}

/// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
/// fully implemented at the moment.
struct ExecuteRegionOpInterface
    : public OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel<
          ExecuteRegionOpInterface, scf::ExecuteRegionOp> {

  static bool supportsUnstructuredControlFlow() { return true; }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    return true;
  }

  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    // TODO: scf.execute_region ops with multiple yields are not supported.
    if (!getUniqueYieldOp(executeRegionOp))
      return op->emitOpError("op without unique scf.yield is not supported");
    return success();
  }

  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    if (auto bbArg = dyn_cast<BlockArgument>(value))
      return getAliasingBranchOpOperands(op, bbArg, state);

    // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. To allow for use-def chain traversal
    // through ExecuteRegionOps in the analysis, the corresponding yield value
    // is considered to be aliasing with the result.
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    auto it = llvm::find(op->getOpResults(), value);
    assert(it != op->getOpResults().end() && "invalid value");
    size_t resultNum = std::distance(op->getOpResults().begin(), it);
    auto yieldOp = getUniqueYieldOp(executeRegionOp);
    // Note: If there is no unique scf.yield op, `verifyAnalysis` will fail.
    if (!yieldOp)
      return {};
    return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    auto yieldOp = getUniqueYieldOp(executeRegionOp);
    TypeRange newResultTypes(yieldOp.getResults());

    // Create new op and move over region.
    auto newOp =
        rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
    newOp.getRegion().takeBody(executeRegionOp.getRegion());

    // Bufferize every block.
    for (Block &block : newOp.getRegion())
      if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
                                                        options, state)))
        return failure();

    // Update all uses of the old op.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
      if (isa<TensorType>(it.value())) {
        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
            executeRegionOp.getLoc(), it.value(),
            newOp->getResult(it.index())));
      } else {
        newResults.push_back(newOp->getResult(it.index()));
      }
    }

    // Replace old op.
    rewriter.replaceOp(executeRegionOp, newResults);

    return success();
  }
};

/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
struct IfOpInterface
    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    // IfOps do not have tensor OpOperands. The yielded value can be any SSA
    // value that is in scope. To allow for use-def chain traversal through
    // IfOps in the analysis, both corresponding yield values from the
    // then/else branches are considered to be aliasing with the result.
    auto ifOp = cast<scf::IfOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), value));
    OpOperand *thenOperand = &ifOp.thenYield()->getOpOperand(resultNum);
    OpOperand *elseOperand = &ifOp.elseYield()->getOpOperand(resultNum);
    return {{thenOperand, BufferRelation::Equivalent, /*isDefinite=*/false},
            {elseOperand, BufferRelation::Equivalent, /*isDefinite=*/false}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto ifOp = cast<scf::IfOp>(op);

    // Compute bufferized result types.
    SmallVector<Type> newTypes;
    for (Value result : ifOp.getResults()) {
      if (!isa<TensorType>(result.getType())) {
        newTypes.push_back(result.getType());
        continue;
      }
      auto bufferType = bufferization::getBufferType(result, options, state);
      if (failed(bufferType))
        return failure();
      newTypes.push_back(*bufferType);
    }

    // Create new op.
    rewriter.setInsertionPoint(ifOp);
    auto newIfOp =
        rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
                                   /*withElseRegion=*/true);

    // Move over then/else blocks.
    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());

    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto ifOp = cast<scf::IfOp>(op);
    auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
    auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());
    assert(value.getDefiningOp() == op && "invalid value");

    // Determine buffer types of the true/false branches.
    auto opResult = cast<OpResult>(value);
    auto thenValue = thenYieldOp.getOperand(opResult.getResultNumber());
    auto elseValue = elseYieldOp.getOperand(opResult.getResultNumber());
    BaseMemRefType thenBufferType, elseBufferType;
    if (isa<BaseMemRefType>(thenValue.getType())) {
      // True branch was already bufferized.
      thenBufferType = cast<BaseMemRefType>(thenValue.getType());
    } else {
      auto maybeBufferType = bufferization::getBufferType(
          thenValue, options, state, invocationStack);
      if (failed(maybeBufferType))
        return failure();
      thenBufferType = *maybeBufferType;
    }
    if (isa<BaseMemRefType>(elseValue.getType())) {
      // False branch was already bufferized.
      elseBufferType = cast<BaseMemRefType>(elseValue.getType());
    } else {
      auto maybeBufferType = bufferization::getBufferType(
          elseValue, options, state, invocationStack);
      if (failed(maybeBufferType))
        return failure();
      elseBufferType = *maybeBufferType;
    }

    // Best case: Both branches have the exact same buffer type.
    if (thenBufferType == elseBufferType)
      return thenBufferType;

    // Memory space mismatch.
    if (thenBufferType.getMemorySpace() != elseBufferType.getMemorySpace())
      return op->emitError("inconsistent memory space on then/else branches");

    // Layout maps are different: Promote to fully dynamic layout map.
    return getMemRefTypeWithFullyDynamicLayout(
        cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace());
  }
};
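
// Illustrative result (hypothetical IR): when the two branches yield buffers
// with different layout maps, the scf.if result is promoted to a fully
// dynamic layout and the yielded values are cast to it:
//
//   %r = scf.if %cond -> (memref<4xf32, strided<[?], offset: ?>>) {
//     %c = memref.cast %a
//         : memref<4xf32> to memref<4xf32, strided<[?], offset: ?>>
//     scf.yield %c : memref<4xf32, strided<[?], offset: ?>>
//   } else {
//     scf.yield %b : memref<4xf32, strided<[?], offset: ?>>
//   }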

/// Bufferization of scf.index_switch. Replace with a new scf.index_switch that
/// yields memrefs.
struct IndexSwitchOpInterface
    : public BufferizableOpInterface::ExternalModel<IndexSwitchOpInterface,
                                                    scf::IndexSwitchOp> {
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    // IndexSwitchOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. This is similar to IfOps.
    auto switchOp = cast<scf::IndexSwitchOp>(op);
    int64_t resultNum = cast<OpResult>(value).getResultNumber();
    AliasingOpOperandList result;
    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
      auto yieldOp =
          cast<scf::YieldOp>(switchOp.getCaseBlock(i).getTerminator());
      result.addAlias(AliasingOpOperand(&yieldOp->getOpOperand(resultNum),
                                        BufferRelation::Equivalent,
                                        /*isDefinite=*/false));
    }
    auto defaultYieldOp =
        cast<scf::YieldOp>(switchOp.getDefaultBlock().getTerminator());
    result.addAlias(AliasingOpOperand(&defaultYieldOp->getOpOperand(resultNum),
                                      BufferRelation::Equivalent,
                                      /*isDefinite=*/false));
    return result;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto switchOp = cast<scf::IndexSwitchOp>(op);

    // Compute bufferized result types.
    SmallVector<Type> newTypes;
    for (Value result : switchOp.getResults()) {
      if (!isa<TensorType>(result.getType())) {
        newTypes.push_back(result.getType());
        continue;
      }
      auto bufferType = bufferization::getBufferType(result, options, state);
      if (failed(bufferType))
        return failure();
      newTypes.push_back(*bufferType);
    }

    // Create new op.
    rewriter.setInsertionPoint(switchOp);
    auto newSwitchOp = rewriter.create<scf::IndexSwitchOp>(
        switchOp.getLoc(), newTypes, switchOp.getArg(), switchOp.getCases(),
        switchOp.getCases().size());

    // Move over blocks.
    for (auto [src, dest] :
         llvm::zip(switchOp.getCaseRegions(), newSwitchOp.getCaseRegions()))
      rewriter.inlineRegionBefore(src, dest, dest.begin());
    rewriter.inlineRegionBefore(switchOp.getDefaultRegion(),
                                newSwitchOp.getDefaultRegion(),
                                newSwitchOp.getDefaultRegion().begin());

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newSwitchOp->getResults());

    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto switchOp = cast<scf::IndexSwitchOp>(op);
    assert(value.getDefiningOp() == op && "invalid value");
    int64_t resultNum = cast<OpResult>(value).getResultNumber();

    // Helper function to get the buffer type of a case.
    auto getYieldedBufferType = [&](Block &b) -> FailureOr<BaseMemRefType> {
      auto yieldOp = cast<scf::YieldOp>(b.getTerminator());
      Value yieldedValue = yieldOp->getOperand(resultNum);
      if (auto bufferType = dyn_cast<BaseMemRefType>(yieldedValue.getType()))
        return bufferType;
      auto maybeBufferType = bufferization::getBufferType(
          yieldedValue, options, state, invocationStack);
      if (failed(maybeBufferType))
        return failure();
      return maybeBufferType;
    };

    // Compute buffer type of the default case.
    auto maybeBufferType = getYieldedBufferType(switchOp.getDefaultBlock());
    if (failed(maybeBufferType))
      return failure();
    BaseMemRefType bufferType = *maybeBufferType;

    // Compute buffer types of all other cases.
    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
      auto yieldedBufferType = getYieldedBufferType(switchOp.getCaseBlock(i));
      if (failed(yieldedBufferType))
        return failure();

      // Best case: Both cases have the exact same buffer type.
      if (bufferType == *yieldedBufferType)
        continue;

      // Memory space mismatch.
      if (bufferType.getMemorySpace() != yieldedBufferType->getMemorySpace())
        return op->emitError("inconsistent memory space on switch cases");

      // Layout maps are different: Promote to fully dynamic layout map.
      bufferType = getMemRefTypeWithFullyDynamicLayout(
          cast<TensorType>(value.getType()), bufferType.getMemorySpace());
    }

    return bufferType;
  }
};

/// Helper function for loop bufferization. Return the indices of all values
/// that have a tensor type.
static DenseSet<int64_t> getTensorIndices(ValueRange values) {
  DenseSet<int64_t> result;
  for (const auto &it : llvm::enumerate(values))
    if (isa<TensorType>(it.value().getType()))
      result.insert(it.index());
  return result;
}

/// Helper function for loop bufferization. Return the indices of all
/// bbArg/yielded value pairs whose buffer relation is "Equivalent".
DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
                                       ValueRange yieldedValues,
                                       const AnalysisState &state) {
  unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
  DenseSet<int64_t> result;
  for (unsigned int i = 0; i < minSize; ++i) {
    if (!isa<TensorType>(bbArgs[i].getType()) ||
        !isa<TensorType>(yieldedValues[i].getType()))
      continue;
    if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
      result.insert(i);
  }
  return result;
}

/// Helper function for loop bufferization. Return the bufferized values of the
/// given OpOperands. If an operand is not a tensor, return the original value.
static FailureOr<SmallVector<Value>>
getBuffers(RewriterBase &rewriter, const MutableOperandRange &operands,
           const BufferizationOptions &options, BufferizationState &state) {
  SmallVector<Value> result;
  for (OpOperand &opOperand : operands) {
    if (isa<TensorType>(opOperand.get().getType())) {
      FailureOr<Value> resultBuffer =
          getBuffer(rewriter, opOperand.get(), options, state);
      if (failed(resultBuffer))
        return failure();
      result.push_back(*resultBuffer);
    } else {
      result.push_back(opOperand.get());
    }
  }
  return result;
}

/// Helper function for loop bufferization. Given a list of bbArgs of the new
/// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into
/// ToTensorOps, so that the block body can be moved over to the new op.
static SmallVector<Value>
getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
                     Block::BlockArgListType oldBbArgs,
                     const DenseSet<int64_t> &tensorIndices) {
  SmallVector<Value> result;
  for (const auto &it : llvm::enumerate(bbArgs)) {
    size_t idx = it.index();
    Value val = it.value();
    if (tensorIndices.contains(idx)) {
      result.push_back(rewriter
                           .create<bufferization::ToTensorOp>(
                               val.getLoc(), oldBbArgs[idx].getType(), val)
                           .getResult());
    } else {
      result.push_back(val);
    }
  }
  return result;
}

/// Compute the bufferized type of a loop iter_arg. This type must be equal to
/// the bufferized type of the corresponding init_arg and the bufferized type
/// of the corresponding yielded value.
///
/// This function uses bufferization::getBufferType to compute the bufferized
/// type of the init_arg and of the yielded value. (The computation of the
/// bufferized yielded value type usually requires computing the bufferized
/// type of the iter_arg again; the implementation of getBufferType traces
/// back the use-def chain of the given value and computes a buffer type along
/// the way.) If both buffer types are equal, no casts are needed and the
/// computed buffer type can be used directly. Otherwise, the buffer types can
/// only differ in their layout map and a cast must be inserted.
static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
    Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
    const BufferizationOptions &options, const BufferizationState &state,
    SmallVector<Value> &invocationStack) {
  // Determine the buffer type of the init_arg.
  auto initArgBufferType =
      bufferization::getBufferType(initArg, options, state, invocationStack);
  if (failed(initArgBufferType))
    return failure();

  if (llvm::count(invocationStack, iterArg) >= 2) {
    // If the iter_arg is already twice on the invocation stack, just take the
    // type of the init_arg. This is to avoid infinite loops when calculating
    // the buffer type. This will most likely result in computing a memref type
    // with a fully dynamic layout map.

    // Note: For more precise layout map computation, a fixpoint iteration
    // could be done (i.e., re-computing the yielded buffer type until the
    // bufferized iter_arg type no longer changes). The current implementation
    // immediately switches to a fully dynamic layout map when a mismatch
    // between the bufferized init_arg type and the bufferized yield value type
    // is detected.
    return *initArgBufferType;
  }

  // Compute the buffer type of the yielded value.
  BaseMemRefType yieldedValueBufferType;
  if (isa<BaseMemRefType>(yieldedValue.getType())) {
    // scf.yield was already bufferized.
    yieldedValueBufferType = cast<BaseMemRefType>(yieldedValue.getType());
  } else {
    // Note: This typically triggers a recursive call for the buffer type of
    // the iter_arg.
    auto maybeBufferType = bufferization::getBufferType(yieldedValue, options,
                                                        state, invocationStack);
    if (failed(maybeBufferType))
      return failure();
    yieldedValueBufferType = *maybeBufferType;
  }

  // If the yielded type and the init_arg type are the same, use that type
  // directly.
  if (*initArgBufferType == yieldedValueBufferType)
    return yieldedValueBufferType;

  // If there is a mismatch between the yielded buffer type and the init_arg
  // buffer type, the buffer type must be promoted to a fully dynamic layout
  // map.
  auto yieldedBufferType = cast<BaseMemRefType>(yieldedValueBufferType);
  auto iterTensorType = cast<TensorType>(iterArg.getType());
  auto initBufferType = llvm::cast<BaseMemRefType>(*initArgBufferType);
  if (initBufferType.getMemorySpace() != yieldedBufferType.getMemorySpace())
    return loopOp->emitOpError(
        "init_arg and yielded value bufferize to inconsistent memory spaces");
#ifndef NDEBUG
  if (auto yieldedRankedBufferType = dyn_cast<MemRefType>(yieldedBufferType)) {
    assert(
        llvm::all_equal({yieldedRankedBufferType.getShape(),
                         cast<MemRefType>(initBufferType).getShape(),
                         cast<RankedTensorType>(iterTensorType).getShape()}) &&
        "expected same shape");
  }
#endif // NDEBUG
  return getMemRefTypeWithFullyDynamicLayout(
      iterTensorType, yieldedBufferType.getMemorySpace());
}
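
// Illustrative walk-through (assumed example): for an iter_arg of type
// tensor<8xf32> whose init_arg bufferizes to memref<8xf32> while the loop
// body yields a buffer with a strided layout (e.g., the result of a subview),
// the two buffer types differ, so the iter_arg is promoted to
// memref<8xf32, strided<[?], offset: ?>> and `castBuffer` reconciles the
// init_arg and the yielded buffer with memref.cast ops.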

/// Return `true` if the given loop may have 0 iterations.
bool mayHaveZeroIterations(scf::ForOp forOp) {
  std::optional<int64_t> lb = getConstantIntValue(forOp.getLowerBound());
  std::optional<int64_t> ub = getConstantIntValue(forOp.getUpperBound());
  if (!lb.has_value() || !ub.has_value())
    return true;
  return *ub <= *lb;
}
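
// E.g., constant bounds lb = 0 and ub = 0 give zero iterations (0 <= 0); if
// either bound is non-constant, the function conservatively returns "true".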

/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.
struct ForOpInterface
    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                    scf::ForOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);

    // If the loop has zero iterations, the results of the op are their
    // corresponding init_args, meaning that the init_args bufferize to a read.
    if (mayHaveZeroIterations(forOp))
      return true;

    // scf::ForOp alone doesn't bufferize to a memory read, but one of the
    // uses of its matching bbArg may.
    return state.isValueRead(forOp.getTiedLoopRegionIterArg(&opOperand));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::ForOps are always considered as a write.
    return true;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    OpResult opResult = forOp.getTiedLoopResult(&opOperand);
    BufferRelation relation = bufferRelation(op, opResult, state);
    return {{opResult, relation,
             /*isDefinite=*/relation == BufferRelation::Equivalent}};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // ForOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent.
    auto forOp = cast<scf::ForOp>(op);
    BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
    bool equivalentYield = state.areEquivalentBufferizedValues(
        bbArg, forOp.getTiedLoopYieldedValue(bbArg)->get());
    return equivalentYield ? BufferRelation::Equivalent
                           : BufferRelation::Unknown;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::ForOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    // 1. Either the matching iter operand is not bufferized inplace and an
    //    alloc + optional copy makes the bbArg itself inplaceable.
    // 2. Or the matching iter operand is bufferized inplace and bbArg just
    //    bufferizes to that too.
    return true;
  }

  LogicalResult
  resolveConflicts(Operation *op, RewriterBase &rewriter,
                   const AnalysisState &analysisState,
                   const BufferizationState &bufferizationState) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(
            rewriter, analysisState, bufferizationState)))
      return failure();

    if (analysisState.getOptions().copyBeforeWrite)
      return success();

    // According to the `getAliasing...` implementations, a bufferized OpResult
    // may alias only with the corresponding bufferized init_arg (or with a
    // newly allocated buffer) and not with other buffers defined outside of
    // the loop. I.e., the i-th OpResult may alias with the i-th init_arg,
    // but not with any other OpOperand.
    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(yieldOp);

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
    // For every yielded value, does it alias with something defined outside of
    // the loop?
    SmallVector<Value> yieldValues;
    for (const auto it : llvm::enumerate(yieldOp.getResults())) {
      // Note: `state` is guaranteed to be a `OneShotAnalysisState`, but this
      // type cannot be used in the signature of `resolveConflicts` because the
      // op interface is in the "IR" build unit and the `OneShotAnalysisState`
      // is defined in the "Transforms" build unit.
      if (!indices.contains(it.index()) ||
          doesNotAliasExternalValue(
              it.value(), &forOp.getRegion(),
              /*exceptions=*/forOp.getRegionIterArg(it.index()),
              static_cast<const OneShotAnalysisState &>(analysisState))) {
        yieldValues.push_back(it.value());
        continue;
      }
      FailureOr<Value> alloc = allocateTensorForShapedValue(
          rewriter, yieldOp.getLoc(), it.value(), analysisState.getOptions(),
          bufferizationState);
      if (failed(alloc))
        return failure();
      yieldValues.push_back(*alloc);
    }

    rewriter.modifyOpInPlace(
        yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); });
    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto forOp = cast<scf::ForOp>(op);
    assert(getOwnerOfValue(value) == op && "invalid value");
    assert(isa<TensorType>(value.getType()) && "expected tensor type");

    if (auto opResult = dyn_cast<OpResult>(value)) {
      // The type of an OpResult must match the corresponding iter_arg type.
      BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
      return bufferization::getBufferType(bbArg, options, state,
                                          invocationStack);
    }

    // Compute result/argument number.
    BlockArgument bbArg = cast<BlockArgument>(value);
    unsigned resultNum = forOp.getTiedLoopResult(bbArg).getResultNumber();

    // Compute the bufferized type.
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    Value yieldedValue = yieldOp.getOperand(resultNum);
    BlockArgument iterArg = forOp.getRegionIterArgs()[resultNum];
    Value initArg = forOp.getInitArgs()[resultNum];
    return computeLoopRegionIterArgBufferType(
        op, iterArg, initArg, yieldedValue, options, state, invocationStack);
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    Block *oldLoopBody = forOp.getBody();

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, forOp.getInitArgsMutable(), options, state);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // Cast init_args if necessary.
    SmallVector<Value> castedInitArgs;
    for (const auto &it : llvm::enumerate(initArgs)) {
      Value initArg = it.value();
      Value result = forOp->getResult(it.index());
      // If the type is not a tensor, bufferization doesn't need to touch it.
      if (!isa<TensorType>(result.getType())) {
        castedInitArgs.push_back(initArg);
        continue;
      }
      auto targetType = bufferization::getBufferType(result, options, state);
      if (failed(targetType))
        return failure();
      castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
    }

    // Construct a new scf.for op with memref instead of tensor values.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), castedInitArgs);
    newForOp->setAttrs(forOp->getAttrs());
    Block *loopBody = newForOp.getBody();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
    SmallVector<Value> iterArgs =
        getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(),
                             forOp.getRegionIterArgs(), indices);
    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());

    // Move the loop body to the new loop.
    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.for op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs + copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocsFromLoops)
      return success();

    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    for (OpResult opResult : op->getOpResults()) {
      if (!isa<TensorType>(opResult.getType()))
        continue;

      // Note: This is overly strict. We should check for aliasing bufferized
      // values. But we don't have a "must-alias" analysis yet.
      if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
        return yieldOp->emitError()
               << "Yield operand #" << opResult.getResultNumber()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};
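
// Illustrative before/after (hypothetical IR): a tensor loop such as
//
//   %0 = scf.for %i = %lb to %ub step %s iter_args(%t = %init)
//       -> (tensor<8xf32>) { ... }
//
// becomes a memref loop whose iter_arg is re-wrapped in a
// bufferization.to_tensor op at the top of the body, so that the (still
// tensor-based) loop body can be moved over unchanged:
//
//   %0 = scf.for %i = %lb to %ub step %s iter_args(%b = %init_buffer)
//       -> (memref<8xf32, strided<[?], offset: ?>>) { ... }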

/// Bufferization of scf.while. Replace with a new scf.while that operates on
/// memrefs.
struct WhileOpInterface
    : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
                                                    scf::WhileOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a read.
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a write.
    return true;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    unsigned int idx = opOperand.getOperandNumber();

    // The OpResults and OpOperands may not match. They may not even have the
    // same type. The number of OpResults and OpOperands can also differ.
    if (idx >= op->getNumResults() ||
        opOperand.get().getType() != op->getResult(idx).getType())
      return {};

    // The only aliasing OpResult may be the one at the same index.
    OpResult opResult = whileOp->getResult(idx);
    BufferRelation relation = bufferRelation(op, opResult, state);
    return {{opResult, relation,
             /*isDefinite=*/relation == BufferRelation::Equivalent}};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // WhileOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent (for both the
    // "before" and the "after" block).
    unsigned int resultNumber = opResult.getResultNumber();
    auto whileOp = cast<scf::WhileOp>(op);

    // The "before" region bbArgs and the OpResults may not match.
    if (resultNumber >= whileOp.getBeforeArguments().size())
      return BufferRelation::Unknown;
    if (opResult.getType() !=
        whileOp.getBeforeArguments()[resultNumber].getType())
      return BufferRelation::Unknown;

    auto conditionOp = whileOp.getConditionOp();
    BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
    Value conditionOperand = conditionOp.getArgs()[resultNumber];
    bool equivCondition =
        state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);

    auto yieldOp = whileOp.getYieldOp();
    BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
    Value yieldOperand = yieldOp.getOperand(resultNumber);
    bool equivYield =
        state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);

    return equivCondition && equivYield ? BufferRelation::Equivalent
                                        : BufferRelation::Unknown;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::WhileOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    // 1. Either the matching iter operand is not bufferized inplace and an
    //    alloc + optional copy makes the bbArg itself inplaceable.
    // 2. Or the matching iter operand is bufferized inplace and bbArg just
    //    bufferizes to that too.
    return true;
  }

  LogicalResult
  resolveConflicts(Operation *op, RewriterBase &rewriter,
                   const AnalysisState &analysisState,
                   const BufferizationState &bufferizationState) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(
            rewriter, analysisState, bufferizationState)))
      return failure();

    if (analysisState.getOptions().copyBeforeWrite)
      return success();

    // According to the `getAliasing...` implementations, a bufferized OpResult
    // may alias only with the corresponding bufferized init_arg and with no
    // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg,
    // but not with any other OpOperand. If a corresponding OpResult/init_arg
    // pair bufferizes to equivalent buffers, this aliasing requirement is
    // satisfied. Otherwise, we cannot be sure and must yield a new buffer
    // copy. (New buffer copies do not alias with any buffer.)
    OpBuilder::InsertionGuard g(rewriter);
    auto whileOp = cast<scf::WhileOp>(op);
    auto conditionOp = whileOp.getConditionOp();

    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
        whileOp.getBeforeArguments(), conditionOp.getArgs(), analysisState);
    DenseSet<int64_t> equivalentYieldsAfter =
        getEquivalentBuffers(whileOp.getAfterArguments(),
                             whileOp.getYieldOp().getResults(), analysisState);

    // Update the "before" region.
    rewriter.setInsertionPoint(conditionOp);
    SmallVector<Value> beforeYieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
      Value value = conditionOp.getArgs()[idx];
      if (!isa<TensorType>(value.getType()) ||
          (equivalentYieldsAfter.contains(idx) &&
           equivalentYieldsBefore.contains(idx))) {
        beforeYieldValues.push_back(value);
        continue;
      }
      FailureOr<Value> alloc = allocateTensorForShapedValue(
          rewriter, conditionOp.getLoc(), value, analysisState.getOptions(),
          bufferizationState);
      if (failed(alloc))
        return failure();
      beforeYieldValues.push_back(*alloc);
    }
    rewriter.modifyOpInPlace(conditionOp, [&]() {
      conditionOp.getArgsMutable().assign(beforeYieldValues);
    });

    return success();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);

    // Indices of all bbArgs that have tensor type. These are the ones that
    // are bufferized. The "before" and "after" regions may have different
    // args.
    DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
    DenseSet<int64_t> indicesAfter =
        getTensorIndices(whileOp.getAfterArguments());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, whileOp.getInitsMutable(), options, state);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // Cast init_args if necessary.
    SmallVector<Value> castedInitArgs;
    for (const auto &it : llvm::enumerate(initArgs)) {
      Value initArg = it.value();
      Value beforeArg = whileOp.getBeforeArguments()[it.index()];
      // If the type is not a tensor, bufferization doesn't need to touch it.
      if (!isa<TensorType>(beforeArg.getType())) {
        castedInitArgs.push_back(initArg);
        continue;
      }
      auto targetType = bufferization::getBufferType(beforeArg, options, state);
      if (failed(targetType))
        return failure();
      castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
    }

    // The result types of a WhileOp are the same as the "after" bbArg types.
    SmallVector<Type> argsTypesAfter = llvm::to_vector(
        llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
          if (!isa<TensorType>(bbArg.getType()))
            return bbArg.getType();
          // TODO: error handling
          return llvm::cast<Type>(
              *bufferization::getBufferType(bbArg, options, state));
        }));

    // Construct a new scf.while op with memref instead of tensor values.
    ValueRange argsRangeBefore(castedInitArgs);
    TypeRange argsTypesBefore(argsRangeBefore);
    auto newWhileOp = rewriter.create<scf::WhileOp>(
        whileOp.getLoc(), argsTypesAfter, castedInitArgs);

    // Add before/after regions to the new op.
    SmallVector<Location> bbArgLocsBefore(castedInitArgs.size(),
                                          whileOp.getLoc());
    SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
                                         whileOp.getLoc());
    Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
    newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore);
    Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
    newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter);

    // Set up new iter_args and move the loop condition block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newBeforeBody);
    SmallVector<Value> newBeforeArgs =
        getBbArgReplacements(rewriter, newWhileOp.getBeforeArguments(),
                             whileOp.getBeforeArguments(), indicesBefore);
    rewriter.mergeBlocks(whileOp.getBeforeBody(), newBeforeBody, newBeforeArgs);

    // Set up new iter_args and move the loop body block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newAfterBody);
    SmallVector<Value> newAfterArgs =
        getBbArgReplacements(rewriter, newWhileOp.getAfterArguments(),
                             whileOp.getAfterArguments(), indicesAfter);
    rewriter.mergeBlocks(whileOp.getAfterBody(), newAfterBody, newAfterArgs);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());

    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto whileOp = cast<scf::WhileOp>(op);
    assert(getOwnerOfValue(value) == op && "invalid value");
    assert(isa<TensorType>(value.getType()) && "expected tensor type");

    // Case 1: Block argument of the "before" region.
    if (auto bbArg = dyn_cast<BlockArgument>(value)) {
      if (bbArg.getOwner()->getParent() == &whileOp.getBefore()) {
        Value initArg = whileOp.getInits()[bbArg.getArgNumber()];
        auto yieldOp = whileOp.getYieldOp();
        Value yieldedValue = yieldOp.getOperand(bbArg.getArgNumber());
        return computeLoopRegionIterArgBufferType(
            op, bbArg, initArg, yieldedValue, options, state, invocationStack);
      }
    }

    // Case 2: OpResult of the loop or block argument of the "after" region.
    // The bufferized "after" bbArg type can be directly computed from the
    // bufferized "before" bbArg type.
    unsigned resultNum;
    if (auto opResult = dyn_cast<OpResult>(value)) {
      resultNum = opResult.getResultNumber();
    } else if (cast<BlockArgument>(value).getOwner()->getParent() ==
               &whileOp.getAfter()) {
      resultNum = cast<BlockArgument>(value).getArgNumber();
    } else {
      llvm_unreachable("invalid value");
    }
    Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
    if (!isa<TensorType>(conditionYieldedVal.getType())) {
      // scf.condition was already bufferized.
      return cast<BaseMemRefType>(conditionYieldedVal.getType());
    }
    return bufferization::getBufferType(conditionYieldedVal, options, state,
                                        invocationStack);
  }

  /// Assert that yielded values of an scf.while op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs + copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  ///
  /// Note: In contrast to scf::ForOp, scf::WhileOp has two regions and the
  /// equivalence condition must be checked for both.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocsFromLoops)
      return success();

    auto conditionOp = whileOp.getConditionOp();
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      Block *block = conditionOp->getBlock();
      if (!isa<TensorType>(it.value().getType()))
        continue;
      if (it.index() >= block->getNumArguments() ||
          !state.areEquivalentBufferizedValues(it.value(),
                                               block->getArgument(it.index())))
        return conditionOp->emitError()
               << "Condition arg #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    auto yieldOp = whileOp.getYieldOp();
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Block *block = yieldOp->getBlock();
      if (!isa<TensorType>(it.value().getType()))
        continue;
      if (it.index() >= block->getNumArguments() ||
          !state.areEquivalentBufferizedValues(it.value(),
                                               block->getArgument(it.index())))
        return yieldOp->emitError()
               << "Yield operand #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};
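
// Illustrative result (hypothetical IR): after bufferization, both regions of
// the scf.while carry memref arguments, and scf.condition forwards the
// buffers that become the "after" arguments and the op results:
//
//   %r = scf.while (%b = %init_buffer)
//       : (memref<8xf32, strided<[?], offset: ?>>)
//       -> memref<8xf32, strided<[?], offset: ?>> {
//     %cond = ...
//     scf.condition(%cond) %b : memref<8xf32, strided<[?], offset: ?>>
//   } do {
//   ^bb0(%a: memref<8xf32, strided<[?], offset: ?>>):
//     ...
//     scf.yield %a : memref<8xf32, strided<[?], offset: ?>>
//   }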

/// Bufferization of scf.yield. scf.yield ops are bufferized as part of their
/// enclosing ops, so this is for analysis only.
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    scf::YieldOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    if (isa<scf::IfOp>(op->getParentOp())) {
      return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
               BufferRelation::Equivalent, /*isDefinite=*/false}};
    }
    if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
      return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
               BufferRelation::Equivalent}};
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield
    // allocations when possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto yieldOp = cast<scf::YieldOp>(op);
    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::IndexSwitchOp, scf::ForOp,
             scf::WhileOp>(yieldOp->getParentOp()))
      return yieldOp->emitError("unsupported scf::YieldOp parent");

    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Value value = it.value();
      if (isa<TensorType>(value.getType())) {
        FailureOr<Value> maybeBuffer =
            getBuffer(rewriter, value, options, state);
        if (failed(maybeBuffer))
          return failure();
        Value buffer = *maybeBuffer;
        // We may have to cast the value before yielding it.
        if (isa<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>(
                yieldOp->getParentOp())) {
          FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
              yieldOp->getParentOp()->getResult(it.index()), options, state);
          if (failed(resultType))
            return failure();
          buffer = castBuffer(rewriter, buffer, *resultType);
        } else if (auto whileOp =
                       dyn_cast<scf::WhileOp>(yieldOp->getParentOp())) {
          FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
              whileOp.getBeforeArguments()[it.index()], options, state);
          if (failed(resultType))
            return failure();
          buffer = castBuffer(rewriter, buffer, *resultType);
        }
        newResults.push_back(buffer);
      } else {
        newResults.push_back(value);
      }
    }

    replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op, newResults);
    return success();
  }
};

/// Bufferization of ForallOp. This also bufferizes the terminator of the
/// region. There are op interfaces for the terminators (InParallelOp
/// and ParallelInsertSliceOp), but these are only used during analysis, not
/// for bufferization.
struct ForallOpInterface
    : public BufferizableOpInterface::ExternalModel<ForallOpInterface,
                                                    ForallOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // All tensor operands to `scf.forall` are `shared_outs` and all
    // shared outs are assumed to be read by the loop. This does not
    // account for the case where the entire value is over-written,
    // but we are being conservative here.
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Outputs of scf::ForallOps are always considered as a write.
    return true;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto forallOp = cast<ForallOp>(op);
    return {
        {{forallOp.getTiedOpResult(&opOperand), BufferRelation::Equivalent}}};
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    OpBuilder::InsertionGuard guard(rewriter);
    auto forallOp = cast<ForallOp>(op);
    int64_t rank = forallOp.getRank();

    // Get buffers for all output operands.
    SmallVector<Value> buffers;
    for (Value out : forallOp.getOutputs()) {
      FailureOr<Value> buffer = getBuffer(rewriter, out, options, state);
      if (failed(buffer))
        return failure();
      buffers.push_back(*buffer);
    }

    // Use buffers instead of block arguments.
    rewriter.setInsertionPointToStart(forallOp.getBody());
    for (const auto &it : llvm::zip(
             forallOp.getBody()->getArguments().drop_front(rank), buffers)) {
      BlockArgument bbArg = std::get<0>(it);
      Value buffer = std::get<1>(it);
      Value bufferAsTensor = rewriter.create<ToTensorOp>(
          forallOp.getLoc(), bbArg.getType(), buffer);
      bbArg.replaceAllUsesWith(bufferAsTensor);
    }

    // Create a new ForallOp without any results and drop the automatically
    // introduced terminator.
    rewriter.setInsertionPoint(forallOp);
    ForallOp newForallOp;
    newForallOp = rewriter.create<ForallOp>(
        forallOp.getLoc(), forallOp.getMixedLowerBound(),
        forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
        /*outputs=*/ValueRange(), forallOp.getMapping());

    // Keep discardable attributes from the original op.
    newForallOp->setDiscardableAttrs(forallOp->getDiscardableAttrDictionary());

    rewriter.eraseOp(newForallOp.getBody()->getTerminator());

    // Move over the block contents of the old op.
    SmallVector<Value> replacementBbArgs;
    replacementBbArgs.append(newForallOp.getBody()->getArguments().begin(),
                             newForallOp.getBody()->getArguments().end());
    replacementBbArgs.append(forallOp.getOutputs().size(), Value());
    rewriter.mergeBlocks(forallOp.getBody(), newForallOp.getBody(),
                         replacementBbArgs);

    // Remove the old op and replace all of its uses.
    replaceOpWithBufferizedValues(rewriter, op, buffers);

    return success();
  }
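
  // Illustrative before/after (hypothetical IR): a forall op with a
  // shared_out such as
  //
  //   %0 = scf.forall (%i) in (4) shared_outs(%o = %t) -> (tensor<4xf32>) {
  //     ...
  //   }
  //
  // is rewritten into a result-less forall that operates directly on the
  // buffer of %t; all uses of %0 are replaced with that buffer.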

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto forallOp = cast<ForallOp>(op);

    if (auto bbArg = dyn_cast<BlockArgument>(value))
      // A tensor block argument has the same bufferized type as the
      // corresponding output operand.
      return bufferization::getBufferType(
          forallOp.getTiedOpOperand(bbArg)->get(), options, state,
          invocationStack);

    // The bufferized result type is the same as the bufferized type of the
    // corresponding output operand.
    return bufferization::getBufferType(
        forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options,
        state, invocationStack);
  }

  bool isRepetitiveRegion(Operation *op, unsigned index) const {
    auto forallOp = cast<ForallOp>(op);

    // This op is repetitive if it has 1 or more steps.
    // If the control variables are dynamic, it is also considered so.
    for (auto [lb, ub, step] :
         llvm::zip(forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
                   forallOp.getMixedStep())) {
      std::optional<int64_t> lbConstant = getConstantIntValue(lb);
      if (!lbConstant)
        return true;

      std::optional<int64_t> ubConstant = getConstantIntValue(ub);
      if (!ubConstant)
        return true;

      std::optional<int64_t> stepConstant = getConstantIntValue(step);
      if (!stepConstant)
        return true;

      if (*lbConstant + *stepConstant < *ubConstant)
        return true;
    }
    return false;
  }
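
  // Worked example (assumed constants): with lb = 0, ub = 2, step = 1 we have
  // 0 + 1 < 2, so the region may execute more than once and is repetitive.
  // With lb = 0, ub = 1, step = 1 the check fails for the only dimension, so
  // the single-iteration forall is not considered repetitive.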

  bool isParallelRegion(Operation *op, unsigned index) const {
    return isRepetitiveRegion(op, index);
  }
};

/// Nothing to do for InParallelOp.
struct InParallelOpInterface
    : public BufferizableOpInterface::ExternalModel<InParallelOpInterface,
                                                    InParallelOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &b,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    llvm_unreachable("op does not have any tensor OpOperands / OpResults");
    return failure();
  }
};

} // namespace
} // namespace scf
} // namespace mlir

void mlir::scf::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
    ConditionOp::attachInterface<ConditionOpInterface>(*ctx);
    ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
    ForOp::attachInterface<ForOpInterface>(*ctx);
    IfOp::attachInterface<IfOpInterface>(*ctx);
    IndexSwitchOp::attachInterface<IndexSwitchOpInterface>(*ctx);
    ForallOp::attachInterface<ForallOpInterface>(*ctx);
    InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
    WhileOp::attachInterface<WhileOpInterface>(*ctx);
    YieldOp::attachInterface<YieldOpInterface>(*ctx);
  });
}
BlockListType & getBlocks()
Definition: Region.h:45
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
Definition: Region.h:241
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:594
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: Value.h:149
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition: Value.cpp:41
BufferizationState provides information about the state of the IR during the bufferization process.
State for analysis-enabled bufferization.
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, ValueRange values)
Replace an op with replacement values.
Operation * getOwnerOfValue(Value value)
Return the owner of the given value.
LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter, const BufferizationOptions &options, BufferizationState &state)
Bufferize the signature of block and its callers (i.e., ops that have the given block as a successor)...
Definition: Bufferize.cpp:398
FailureOr< BaseMemRefType > getBufferType(Value value, const BufferizationOptions &options, const BufferizationState &state)
Return the buffer type for a given Value (tensor) after bufferization without bufferizing any IR.
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options, const BufferizationState &state)
Lookup the buffer for the given value.
FailureOr< Value > allocateTensorForShapedValue(OpBuilder &b, Location loc, Value shapedValue, const BufferizationOptions &options, const BufferizationState &state, bool copy=true)
Create an AllocTensorOp for the given shaped value (memref or tensor).
BaseMemRefType getMemRefTypeWithFullyDynamicLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with fully dynamic layout.
BufferRelation
Specifies a fine-grain relationship between buffers to enable more analysis.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
Options for BufferizableOpInterface-based bufferization.
Options for analysis-enabled bufferization.
A template that provides a default implementation of getAliasingOpOperands for ops that support unstr...