//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::scf;

namespace mlir {
namespace scf {
namespace {

/// Helper function for loop bufferization. Cast the given buffer to the given
/// memref type.
static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
  assert(isa<BaseMemRefType>(type) && "expected BaseMemRefType");
  assert(isa<BaseMemRefType>(buffer.getType()) && "expected BaseMemRefType");
  // If the buffer already has the correct type, no cast is needed.
  if (buffer.getType() == type)
    return buffer;
  // TODO: In case `type` has a layout map that is not the fully dynamic
  // one, we may not be able to cast the buffer. In that case, the loop
  // iter_arg's layout map must be changed (see uses of `castBuffer`).
  assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
         "scf.while op bufferization: cast incompatible");
  return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
}

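// Example (an illustrative sketch, not verbatim compiler output): `castBuffer`
// is typically used to adjust layout maps. Assuming a buffer of type
// `memref<5xf32>` that must match an iter_arg with a fully dynamic layout,
// it would insert:
//   %cast = memref.cast %buffer
//       : memref<5xf32> to memref<5xf32, strided<[?], offset: ?>>
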
/// Helper function for loop bufferization. Return "true" if the given value
/// is guaranteed to not alias with an external tensor apart from values in
/// `exceptions`. A value is external if it is defined outside of the given
/// region or if it is an entry block argument of the region.
static bool doesNotAliasExternalValue(Value value, Region *region,
                                      ValueRange exceptions,
                                      const OneShotAnalysisState &state) {
  assert(llvm::hasSingleElement(region->getBlocks()) &&
         "expected region with single block");
  bool result = true;
  state.applyOnAliases(value, [&](Value alias) {
    if (llvm::is_contained(exceptions, alias))
      return;
    Region *aliasRegion = alias.getParentRegion();
    if (isa<BlockArgument>(alias) && !region->isProperAncestor(aliasRegion))
      result = false;
    if (isa<OpResult>(alias) && !region->isAncestor(aliasRegion))
      result = false;
  });
  return result;
}

/// Bufferization of scf.condition.
struct ConditionOpInterface
    : public BufferizableOpInterface::ExternalModel<ConditionOpInterface,
                                                    scf::ConditionOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Condition operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield allocations
    // when possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto conditionOp = cast<scf::ConditionOp>(op);
    auto whileOp = cast<scf::WhileOp>(conditionOp->getParentOp());

    SmallVector<Value> newArgs;
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      Value value = it.value();
      if (isa<TensorType>(value.getType())) {
        FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
        if (failed(maybeBuffer))
          return failure();
        FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
            whileOp.getAfterArguments()[it.index()], options);
        if (failed(resultType))
          return failure();
        Value buffer = castBuffer(rewriter, *maybeBuffer, *resultType);
        newArgs.push_back(buffer);
      } else {
        newArgs.push_back(value);
      }
    }

    replaceOpWithNewBufferizedOp<scf::ConditionOp>(
        rewriter, op, conditionOp.getCondition(), newArgs);
    return success();
  }
};

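// Illustrative sketch (not verbatim compiler output): a condition op such as
//   scf.condition(%cond) %t : tensor<5xf32>
// is rewritten to forward the buffer of %t, cast to the bufferized type of the
// corresponding "after" region bbArg:
//   scf.condition(%cond) %t_buffer : memref<5xf32, strided<[?], offset: ?>>
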
/// Return the unique scf.yield op. If there are multiple or no scf.yield ops,
/// return an empty op.
static scf::YieldOp getUniqueYieldOp(scf::ExecuteRegionOp executeRegionOp) {
  scf::YieldOp result;
  for (Block &block : executeRegionOp.getRegion()) {
    if (auto yieldOp = dyn_cast<scf::YieldOp>(block.getTerminator())) {
      if (result)
        return {};
      result = yieldOp;
    }
  }
  return result;
}

/// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
/// fully implemented at the moment.
struct ExecuteRegionOpInterface
    : public OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel<
          ExecuteRegionOpInterface, scf::ExecuteRegionOp> {

  static bool supportsUnstructuredControlFlow() { return true; }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    return true;
  }

  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    // TODO: scf.execute_region with multiple yields are not supported.
    if (!getUniqueYieldOp(executeRegionOp))
      return op->emitOpError("op without unique scf.yield is not supported");
    return success();
  }

  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    if (auto bbArg = dyn_cast<BlockArgument>(value))
      return getAliasingBranchOpOperands(op, bbArg, state);

    // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. To allow for use-def chain traversal
    // through ExecuteRegionOps in the analysis, the corresponding yield value
    // is considered to be aliasing with the result.
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    auto it = llvm::find(op->getOpResults(), value);
    assert(it != op->getOpResults().end() && "invalid value");
    size_t resultNum = std::distance(op->getOpResults().begin(), it);
    auto yieldOp = getUniqueYieldOp(executeRegionOp);
    // Note: If there is no unique scf.yield op, `verifyAnalysis` will fail.
    if (!yieldOp)
      return {};
    return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    auto yieldOp = getUniqueYieldOp(executeRegionOp);
    TypeRange newResultTypes(yieldOp.getResults());

    // Create new op and move over region.
    auto newOp =
        rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
    newOp.getRegion().takeBody(executeRegionOp.getRegion());

    // Bufferize every block.
    for (Block &block : newOp.getRegion())
      if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
                                                        options)))
        return failure();

    // Update all uses of the old op.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
      if (isa<TensorType>(it.value())) {
        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
            executeRegionOp.getLoc(), it.value(),
            newOp->getResult(it.index())));
      } else {
        newResults.push_back(newOp->getResult(it.index()));
      }
    }

    // Replace old op.
    rewriter.replaceOp(executeRegionOp, newResults);

    return success();
  }
};

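// Illustrative sketch (not verbatim compiler output): the bufferized op yields
// memrefs, and the old tensor results are reconstructed via
// bufferization.to_tensor:
//   %0 = scf.execute_region -> memref<5xf32, ...> { ... scf.yield %buf ... }
//   %1 = bufferization.to_tensor %0 : ...
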
/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
struct IfOpInterface
    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    // IfOps do not have tensor OpOperands. The yielded value can be any SSA
    // value that is in scope. To allow for use-def chain traversal through
    // IfOps in the analysis, both corresponding yield values from the then/else
    // branches are considered to be aliasing with the result.
    auto ifOp = cast<scf::IfOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), value));
    OpOperand *thenOperand = &ifOp.thenYield()->getOpOperand(resultNum);
    OpOperand *elseOperand = &ifOp.elseYield()->getOpOperand(resultNum);
    return {{thenOperand, BufferRelation::Equivalent, /*isDefinite=*/false},
            {elseOperand, BufferRelation::Equivalent, /*isDefinite=*/false}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto ifOp = cast<scf::IfOp>(op);

    // Compute bufferized result types.
    SmallVector<Type> newTypes;
    for (Value result : ifOp.getResults()) {
      if (!isa<TensorType>(result.getType())) {
        newTypes.push_back(result.getType());
        continue;
      }
      auto bufferType = bufferization::getBufferType(result, options);
      if (failed(bufferType))
        return failure();
      newTypes.push_back(*bufferType);
    }

    // Create new op.
    rewriter.setInsertionPoint(ifOp);
    auto newIfOp =
        rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
                                   /*withElseRegion=*/true);

    // Move over then/else blocks.
    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());

    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto ifOp = cast<scf::IfOp>(op);
    auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
    auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());
    assert(value.getDefiningOp() == op && "invalid value");

    // Determine buffer types of the true/false branches.
    auto opResult = cast<OpResult>(value);
    auto thenValue = thenYieldOp.getOperand(opResult.getResultNumber());
    auto elseValue = elseYieldOp.getOperand(opResult.getResultNumber());
    BaseMemRefType thenBufferType, elseBufferType;
    if (isa<BaseMemRefType>(thenValue.getType())) {
      // True branch was already bufferized.
      thenBufferType = cast<BaseMemRefType>(thenValue.getType());
    } else {
      auto maybeBufferType =
          bufferization::getBufferType(thenValue, options, invocationStack);
      if (failed(maybeBufferType))
        return failure();
      thenBufferType = *maybeBufferType;
    }
    if (isa<BaseMemRefType>(elseValue.getType())) {
      // False branch was already bufferized.
      elseBufferType = cast<BaseMemRefType>(elseValue.getType());
    } else {
      auto maybeBufferType =
          bufferization::getBufferType(elseValue, options, invocationStack);
      if (failed(maybeBufferType))
        return failure();
      elseBufferType = *maybeBufferType;
    }

    // Best case: Both branches have the exact same buffer type.
    if (thenBufferType == elseBufferType)
      return thenBufferType;

    // Memory space mismatch.
    if (thenBufferType.getMemorySpace() != elseBufferType.getMemorySpace())
      return op->emitError("inconsistent memory space on then/else branches");

    // Layout maps are different: Promote to fully dynamic layout map.
    return getMemRefTypeWithFullyDynamicLayout(
        cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace());
  }
};

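// Example of the layout promotion above (an illustrative sketch): if the
// "then" branch yields a memref<5xf32> and the "else" branch yields a
// memref<5xf32, strided<[?], offset: ?>>, the bufferized scf.if result type
// becomes memref<5xf32, strided<[?], offset: ?>>; the YieldOpInterface then
// inserts the necessary memref.cast on the "then" side.
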
/// Bufferization of scf.index_switch. Replace with a new scf.index_switch that
/// yields memrefs.
struct IndexSwitchOpInterface
    : public BufferizableOpInterface::ExternalModel<IndexSwitchOpInterface,
                                                    scf::IndexSwitchOp> {
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    // IndexSwitchOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. This is similar to IfOps.
    auto switchOp = cast<scf::IndexSwitchOp>(op);
    int64_t resultNum = cast<OpResult>(value).getResultNumber();
    AliasingOpOperandList result;
    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
      auto yieldOp =
          cast<scf::YieldOp>(switchOp.getCaseBlock(i).getTerminator());
      result.addAlias(AliasingOpOperand(&yieldOp->getOpOperand(resultNum),
                                        BufferRelation::Equivalent,
                                        /*isDefinite=*/false));
    }
    auto defaultYieldOp =
        cast<scf::YieldOp>(switchOp.getDefaultBlock().getTerminator());
    result.addAlias(AliasingOpOperand(&defaultYieldOp->getOpOperand(resultNum),
                                      BufferRelation::Equivalent,
                                      /*isDefinite=*/false));
    return result;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto switchOp = cast<scf::IndexSwitchOp>(op);

    // Compute bufferized result types.
    SmallVector<Type> newTypes;
    for (Value result : switchOp.getResults()) {
      if (!isa<TensorType>(result.getType())) {
        newTypes.push_back(result.getType());
        continue;
      }
      auto bufferType = bufferization::getBufferType(result, options);
      if (failed(bufferType))
        return failure();
      newTypes.push_back(*bufferType);
    }

    // Create new op.
    rewriter.setInsertionPoint(switchOp);
    auto newSwitchOp = rewriter.create<scf::IndexSwitchOp>(
        switchOp.getLoc(), newTypes, switchOp.getArg(), switchOp.getCases(),
        switchOp.getCases().size());

    // Move over blocks.
    for (auto [src, dest] :
         llvm::zip(switchOp.getCaseRegions(), newSwitchOp.getCaseRegions()))
      rewriter.inlineRegionBefore(src, dest, dest.begin());
    rewriter.inlineRegionBefore(switchOp.getDefaultRegion(),
                                newSwitchOp.getDefaultRegion(),
                                newSwitchOp.getDefaultRegion().begin());

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newSwitchOp->getResults());

    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto switchOp = cast<scf::IndexSwitchOp>(op);
    assert(value.getDefiningOp() == op && "invalid value");
    int64_t resultNum = cast<OpResult>(value).getResultNumber();

    // Helper function to get buffer type of a case.
    SmallVector<BaseMemRefType> yieldedTypes;
    auto getYieldedBufferType = [&](Block &b) -> FailureOr<BaseMemRefType> {
      auto yieldOp = cast<scf::YieldOp>(b.getTerminator());
      Value yieldedValue = yieldOp->getOperand(resultNum);
      if (auto bufferType = dyn_cast<BaseMemRefType>(yieldedValue.getType()))
        return bufferType;
      auto maybeBufferType =
          bufferization::getBufferType(yieldedValue, options, invocationStack);
      if (failed(maybeBufferType))
        return failure();
      return maybeBufferType;
    };

    // Compute buffer type of the default case.
    auto maybeBufferType = getYieldedBufferType(switchOp.getDefaultBlock());
    if (failed(maybeBufferType))
      return failure();
    BaseMemRefType bufferType = *maybeBufferType;

    // Compute buffer types of all other cases.
    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
      auto yieldedBufferType = getYieldedBufferType(switchOp.getCaseBlock(i));
      if (failed(yieldedBufferType))
        return failure();

      // Best case: Both branches have the exact same buffer type.
      if (bufferType == *yieldedBufferType)
        continue;

      // Memory space mismatch.
      if (bufferType.getMemorySpace() != yieldedBufferType->getMemorySpace())
        return op->emitError("inconsistent memory space on switch cases");

      // Layout maps are different: Promote to fully dynamic layout map.
      bufferType = getMemRefTypeWithFullyDynamicLayout(
          cast<TensorType>(value.getType()), bufferType.getMemorySpace());
    }

    return bufferType;
  }
};

/// Helper function for loop bufferization. Return the indices of all values
/// that have a tensor type.
static DenseSet<int64_t> getTensorIndices(ValueRange values) {
  DenseSet<int64_t> result;
  for (const auto &it : llvm::enumerate(values))
    if (isa<TensorType>(it.value().getType()))
      result.insert(it.index());
  return result;
}

/// Helper function for loop bufferization. Return the indices of all
/// bbArg/yielded value pairs whose buffer relation is "Equivalent".
DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
                                       ValueRange yieldedValues,
                                       const AnalysisState &state) {
  unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
  DenseSet<int64_t> result;
  for (unsigned int i = 0; i < minSize; ++i) {
    if (!isa<TensorType>(bbArgs[i].getType()) ||
        !isa<TensorType>(yieldedValues[i].getType()))
      continue;
    if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
      result.insert(i);
  }
  return result;
}

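// Example (an illustrative sketch): in
//   %r = scf.for ... iter_args(%a = %t) -> (tensor<5xf32>) {
//     %w = tensor.insert %f into %a[%i] : tensor<5xf32>
//     scf.yield %w : tensor<5xf32>
//   }
// the yielded %w bufferizes in-place onto the bbArg %a, so index 0 is
// reported as equivalent. Yielding a freshly allocated tensor would not be.
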
/// Helper function for loop bufferization. Return the bufferized values of the
/// given OpOperands. If an operand is not a tensor, return the original value.
static FailureOr<SmallVector<Value>>
getBuffers(RewriterBase &rewriter, const MutableOperandRange &operands,
           const BufferizationOptions &options) {
  SmallVector<Value> result;
  for (OpOperand &opOperand : operands) {
    if (isa<TensorType>(opOperand.get().getType())) {
      FailureOr<Value> resultBuffer =
          getBuffer(rewriter, opOperand.get(), options);
      if (failed(resultBuffer))
        return failure();
      result.push_back(*resultBuffer);
    } else {
      result.push_back(opOperand.get());
    }
  }
  return result;
}

/// Helper function for loop bufferization. Given a list of bbArgs of the new
/// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into
/// ToTensorOps, so that the block body can be moved over to the new op.
static SmallVector<Value>
getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
                     Block::BlockArgListType oldBbArgs,
                     const DenseSet<int64_t> &tensorIndices) {
  SmallVector<Value> result;
  for (const auto &it : llvm::enumerate(bbArgs)) {
    size_t idx = it.index();
    Value val = it.value();
    if (tensorIndices.contains(idx)) {
      result.push_back(rewriter
                           .create<bufferization::ToTensorOp>(
                               val.getLoc(), oldBbArgs[idx].getType(), val)
                           .getResult());
    } else {
      result.push_back(val);
    }
  }
  return result;
}

/// Compute the bufferized type of a loop iter_arg. This type must be equal to
/// the bufferized type of the corresponding init_arg and the bufferized type
/// of the corresponding yielded value.
///
/// This function uses bufferization::getBufferType to compute the bufferized
/// type of the init_arg and of the yielded value. (The computation of the
/// bufferized yielded value type usually requires computing the bufferized type
/// of the iter_arg again; the implementation of getBufferType traces back the
/// use-def chain of the given value and computes a buffer type along the way.)
/// If both buffer types are equal, no casts are needed and the computed buffer
/// type can be used directly. Otherwise, the buffer types can only differ in
/// their layout map and a cast must be inserted.
static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
    Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
    const BufferizationOptions &options, SmallVector<Value> &invocationStack) {
  // Determine the buffer type of the init_arg.
  auto initArgBufferType =
      bufferization::getBufferType(initArg, options, invocationStack);
  if (failed(initArgBufferType))
    return failure();

  if (llvm::count(invocationStack, iterArg) >= 2) {
    // If the iter_arg is already twice on the invocation stack, just take the
    // type of the init_arg. This is to avoid infinite loops when calculating
    // the buffer type. This will most likely result in computing a memref type
    // with a fully dynamic layout map.

    // Note: For more precise layout map computation, a fixpoint iteration could
    // be done (i.e., re-computing the yielded buffer type until the bufferized
    // iter_arg type no longer changes). This current implementation immediately
    // switches to a fully dynamic layout map when a mismatch between bufferized
    // init_arg type and bufferized yield value type is detected.
    return *initArgBufferType;
  }

  // Compute the buffer type of the yielded value.
  BaseMemRefType yieldedValueBufferType;
  if (isa<BaseMemRefType>(yieldedValue.getType())) {
    // scf.yield was already bufferized.
    yieldedValueBufferType = cast<BaseMemRefType>(yieldedValue.getType());
  } else {
    // Note: This typically triggers a recursive call for the buffer type of
    // the iter_arg.
    auto maybeBufferType =
        bufferization::getBufferType(yieldedValue, options, invocationStack);
    if (failed(maybeBufferType))
      return failure();
    yieldedValueBufferType = *maybeBufferType;
  }

  // If yielded type and init_arg type are the same, use that type directly.
  if (*initArgBufferType == yieldedValueBufferType)
    return yieldedValueBufferType;

  // If there is a mismatch between the yielded buffer type and the init_arg
  // buffer type, the buffer type must be promoted to a fully dynamic layout
  // map.
  auto yieldedBufferType = cast<BaseMemRefType>(yieldedValueBufferType);
  auto iterTensorType = cast<TensorType>(iterArg.getType());
  auto initBufferType = llvm::cast<BaseMemRefType>(*initArgBufferType);
  if (initBufferType.getMemorySpace() != yieldedBufferType.getMemorySpace())
    return loopOp->emitOpError(
        "init_arg and yielded value bufferize to inconsistent memory spaces");
#ifndef NDEBUG
  if (auto yieldedRankedBufferType = dyn_cast<MemRefType>(yieldedBufferType)) {
    assert(
        llvm::all_equal({yieldedRankedBufferType.getShape(),
                         cast<MemRefType>(initBufferType).getShape(),
                         cast<RankedTensorType>(iterTensorType).getShape()}) &&
        "expected same shape");
  }
#endif // NDEBUG
  return getMemRefTypeWithFullyDynamicLayout(
      iterTensorType, yieldedBufferType.getMemorySpace());
}

/// Return `true` if the given loop may have 0 iterations.
bool mayHaveZeroIterations(scf::ForOp forOp) {
  std::optional<int64_t> lb = getConstantIntValue(forOp.getLowerBound());
  std::optional<int64_t> ub = getConstantIntValue(forOp.getUpperBound());
  if (!lb.has_value() || !ub.has_value())
    return true;
  return *ub <= *lb;
}

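// E.g., `scf.for %i = %c4 to %c4 step %c1` has zero iterations. Loops with
// non-constant bounds are conservatively assumed to possibly have zero
// iterations as well.
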
/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.
struct ForOpInterface
    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                    scf::ForOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);

    // If the loop has zero iterations, the results of the op are their
    // corresponding init_args, meaning that the init_args bufferize to a read.
    if (mayHaveZeroIterations(forOp))
      return true;

    // scf::ForOp alone doesn't bufferize to a memory read, but one of the uses
    // of its matching bbArg may.
    return state.isValueRead(forOp.getTiedLoopRegionIterArg(&opOperand));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::ForOps are always considered as a write.
    return true;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    OpResult opResult = forOp.getTiedLoopResult(&opOperand);
    BufferRelation relation = bufferRelation(op, opResult, state);
    return {{opResult, relation,
             /*isDefinite=*/relation == BufferRelation::Equivalent}};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // ForOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent.
    auto forOp = cast<scf::ForOp>(op);
    BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
    bool equivalentYield = state.areEquivalentBufferizedValues(
        bbArg, forOp.getTiedLoopYieldedValue(bbArg)->get());
    return equivalentYield ? BufferRelation::Equivalent
                           : BufferRelation::Unknown;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::ForOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    // 1. Either the matching iter operand is not bufferized inplace and an
    //    alloc + optional copy makes the bbArg itself inplaceable.
    // 2. Or the matching iter operand is bufferized inplace and bbArg just
    //    bufferizes to that too.
    return true;
  }

  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();

    if (state.getOptions().copyBeforeWrite)
      return success();

    // According to the `getAliasing...` implementations, a bufferized OpResult
    // may alias only with the corresponding bufferized init_arg (or with a
    // newly allocated buffer) and not with other buffers defined outside of the
    // loop. I.e., the i-th OpResult may alias with the i-th init_arg;
    // but not with any other OpOperand.
    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(yieldOp);

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
    // For every yielded value, does it alias with something defined outside of
    // the loop?
    SmallVector<Value> yieldValues;
    for (const auto it : llvm::enumerate(yieldOp.getResults())) {
      // Note: `state` is guaranteed to be a `OneShotAnalysisState`, but this
      // type cannot be used in the signature of `resolveConflicts` because the
      // op interface is in the "IR" build unit and the `OneShotAnalysisState`
      // is defined in the "Transforms" build unit.
      if (!indices.contains(it.index()) ||
          doesNotAliasExternalValue(
              it.value(), &forOp.getRegion(),
              /*exceptions=*/forOp.getRegionIterArg(it.index()),
              static_cast<const OneShotAnalysisState &>(state))) {
        yieldValues.push_back(it.value());
        continue;
      }
      FailureOr<Value> alloc = allocateTensorForShapedValue(
          rewriter, yieldOp.getLoc(), it.value(), state.getOptions());
      if (failed(alloc))
        return failure();
      yieldValues.push_back(*alloc);
    }

    rewriter.modifyOpInPlace(
        yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); });
    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto forOp = cast<scf::ForOp>(op);
    assert(getOwnerOfValue(value) == op && "invalid value");
    assert(isa<TensorType>(value.getType()) && "expected tensor type");

    if (auto opResult = dyn_cast<OpResult>(value)) {
      // The type of an OpResult must match the corresponding iter_arg type.
      BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
      return bufferization::getBufferType(bbArg, options, invocationStack);
    }

    // Compute result/argument number.
    BlockArgument bbArg = cast<BlockArgument>(value);
    unsigned resultNum = forOp.getTiedLoopResult(bbArg).getResultNumber();

    // Compute the bufferized type.
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    Value yieldedValue = yieldOp.getOperand(resultNum);
    BlockArgument iterArg = forOp.getRegionIterArgs()[resultNum];
    Value initArg = forOp.getInitArgs()[resultNum];
    return computeLoopRegionIterArgBufferType(
        op, iterArg, initArg, yieldedValue, options, invocationStack);
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto forOp = cast<scf::ForOp>(op);
    Block *oldLoopBody = forOp.getBody();

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, forOp.getInitArgsMutable(), options);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // Cast init_args if necessary.
    SmallVector<Value> castedInitArgs;
    for (const auto &it : llvm::enumerate(initArgs)) {
      Value initArg = it.value();
      Value result = forOp->getResult(it.index());
      // If the type is not a tensor, bufferization doesn't need to touch it.
      if (!isa<TensorType>(result.getType())) {
        castedInitArgs.push_back(initArg);
        continue;
      }
      auto targetType = bufferization::getBufferType(result, options);
      if (failed(targetType))
        return failure();
      castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
    }

    // Construct a new scf.for op with memref instead of tensor values.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), castedInitArgs);
    newForOp->setAttrs(forOp->getAttrs());
    Block *loopBody = newForOp.getBody();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
    SmallVector<Value> iterArgs =
        getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(),
                             forOp.getRegionIterArgs(), indices);
    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());

    // Move loop body to new loop.
    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());

    return success();
  }

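  // Illustrative before/after sketch for `bufferize` (not verbatim output):
  //   %r = scf.for %i = %lb to %ub step %s iter_args(%a = %t) -> (tensor<5xf32>) {
  //     ...
  //     scf.yield %y : tensor<5xf32>
  //   }
  // becomes a loop on memrefs whose body starts with a to_tensor op:
  //   %m = scf.for %i = %lb to %ub step %s iter_args(%b = %t_buf) -> (memref<...>) {
  //     %a = bufferization.to_tensor %b : ...
  //     ...
  //     scf.yield %y_buf : memref<...>
  //   }
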
  /// Assert that yielded values of an scf.for op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs+copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocsFromLoops)
      return success();

    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    for (OpResult opResult : op->getOpResults()) {
      if (!isa<TensorType>(opResult.getType()))
        continue;

      // Note: This is overly strict. We should check for aliasing bufferized
      // values. But we don't have a "must-alias" analysis yet.
      if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
        return yieldOp->emitError()
               << "Yield operand #" << opResult.getResultNumber()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// Bufferization of scf.while. Replace with a new scf.while that operates on
/// memrefs.
struct WhileOpInterface
    : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
                                                    scf::WhileOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a read.
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a write.
    return true;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    unsigned int idx = opOperand.getOperandNumber();

    // The OpResults and OpOperands may not match. They may not even have the
    // same type. The number of OpResults and OpOperands can also differ.
    if (idx >= op->getNumResults() ||
        opOperand.get().getType() != op->getResult(idx).getType())
      return {};

    // The only aliasing OpResult may be the one at the same index.
    OpResult opResult = whileOp->getResult(idx);
    BufferRelation relation = bufferRelation(op, opResult, state);
    return {{opResult, relation,
             /*isDefinite=*/relation == BufferRelation::Equivalent}};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // WhileOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent (for both the
    // "before" and the "after" block).
    unsigned int resultNumber = opResult.getResultNumber();
    auto whileOp = cast<scf::WhileOp>(op);

    // The "before" region bbArgs and the OpResults may not match.
    if (resultNumber >= whileOp.getBeforeArguments().size())
      return BufferRelation::Unknown;
    if (opResult.getType() !=
        whileOp.getBeforeArguments()[resultNumber].getType())
      return BufferRelation::Unknown;

    auto conditionOp = whileOp.getConditionOp();
    BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
    Value conditionOperand = conditionOp.getArgs()[resultNumber];
    bool equivCondition =
        state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);

    auto yieldOp = whileOp.getYieldOp();
    BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
    Value yieldOperand = yieldOp.getOperand(resultNumber);
    bool equivYield =
        state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);

    return equivCondition && equivYield ? BufferRelation::Equivalent
                                        : BufferRelation::Unknown;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::WhileOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    // 1. Either the matching iter operand is not bufferized inplace and an
    //    alloc + optional copy makes the bbArg itself inplaceable.
    // 2. Or the matching iter operand is bufferized inplace and bbArg just
    //    bufferizes to that too.
    return true;
  }

  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();

    if (state.getOptions().copyBeforeWrite)
      return success();

    // According to the `getAliasing...` implementations, a bufferized OpResult
    // may alias only with the corresponding bufferized init_arg and with no
    // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg;
    // but not with any other OpOperand. If a corresponding OpResult/init_arg
    // pair bufferizes to equivalent buffers, this aliasing requirement is
    // satisfied. Otherwise, we cannot be sure and must yield a new buffer copy.
    // (New buffer copies do not alias with any buffer.)
    OpBuilder::InsertionGuard g(rewriter);
    auto whileOp = cast<scf::WhileOp>(op);
    auto conditionOp = whileOp.getConditionOp();

    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
        whileOp.getBeforeArguments(), conditionOp.getArgs(), state);
    DenseSet<int64_t> equivalentYieldsAfter = getEquivalentBuffers(
        whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(), state);

    // Update "before" region.
    rewriter.setInsertionPoint(conditionOp);
    SmallVector<Value> beforeYieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
      Value value = conditionOp.getArgs()[idx];
      if (!isa<TensorType>(value.getType()) ||
          (equivalentYieldsAfter.contains(idx) &&
           equivalentYieldsBefore.contains(idx))) {
        beforeYieldValues.push_back(value);
        continue;
      }
      FailureOr<Value> alloc = allocateTensorForShapedValue(
          rewriter, conditionOp.getLoc(), value, state.getOptions());
      if (failed(alloc))
        return failure();
      beforeYieldValues.push_back(*alloc);
    }
    rewriter.modifyOpInPlace(conditionOp, [&]() {
      conditionOp.getArgsMutable().assign(beforeYieldValues);
    });

    return success();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto whileOp = cast<scf::WhileOp>(op);

    // Indices of all bbArgs that have tensor type. These are the ones that
    // are bufferized. The "before" and "after" regions may have different args.
    DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
    DenseSet<int64_t> indicesAfter =
        getTensorIndices(whileOp.getAfterArguments());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, whileOp.getInitsMutable(), options);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // Cast init_args if necessary.
    SmallVector<Value> castedInitArgs;
    for (const auto &it : llvm::enumerate(initArgs)) {
      Value initArg = it.value();
      Value beforeArg = whileOp.getBeforeArguments()[it.index()];
      // If the type is not a tensor, bufferization doesn't need to touch it.
      if (!isa<TensorType>(beforeArg.getType())) {
        castedInitArgs.push_back(initArg);
        continue;
      }
      auto targetType = bufferization::getBufferType(beforeArg, options);
      if (failed(targetType))
        return failure();
      castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
    }

    // The result types of a WhileOp are the same as the "after" bbArg types.
    SmallVector<Type> argsTypesAfter = llvm::to_vector(
        llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
          if (!isa<TensorType>(bbArg.getType()))
            return bbArg.getType();
          // TODO: error handling
          return llvm::cast<Type>(
              *bufferization::getBufferType(bbArg, options));
        }));

    // Construct a new scf.while op with memref instead of tensor values.
    ValueRange argsRangeBefore(castedInitArgs);
    TypeRange argsTypesBefore(argsRangeBefore);
    auto newWhileOp = rewriter.create<scf::WhileOp>(
        whileOp.getLoc(), argsTypesAfter, castedInitArgs);

    // Add before/after regions to the new op.
    SmallVector<Location> bbArgLocsBefore(castedInitArgs.size(),
                                          whileOp.getLoc());
    SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
                                         whileOp.getLoc());
    Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
    newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore);
    Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
    newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter);

    // Set up new iter_args and move the loop condition block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newBeforeBody);
    SmallVector<Value> newBeforeArgs =
        getBbArgReplacements(rewriter, newWhileOp.getBeforeArguments(),
                             whileOp.getBeforeArguments(), indicesBefore);
    rewriter.mergeBlocks(whileOp.getBeforeBody(), newBeforeBody, newBeforeArgs);

    // Set up new iter_args and move the loop body block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newAfterBody);
    SmallVector<Value> newAfterArgs =
        getBbArgReplacements(rewriter, newWhileOp.getAfterArguments(),
                             whileOp.getAfterArguments(), indicesAfter);
    rewriter.mergeBlocks(whileOp.getAfterBody(), newAfterBody, newAfterArgs);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());

    return success();
  }

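  // Illustrative sketch (not verbatim output): after bufferization, both
  // regions of the scf.while operate on memrefs, and the (memref) bbArgs are
  // wrapped in bufferization.to_tensor ops so that the moved-over tensor
  // bodies keep type-checking until they are bufferized themselves.
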
  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto whileOp = cast<scf::WhileOp>(op);
    assert(getOwnerOfValue(value) == op && "invalid value");
    assert(isa<TensorType>(value.getType()) && "expected tensor type");

    // Case 1: Block argument of the "before" region.
    if (auto bbArg = dyn_cast<BlockArgument>(value)) {
      if (bbArg.getOwner()->getParent() == &whileOp.getBefore()) {
        Value initArg = whileOp.getInits()[bbArg.getArgNumber()];
        auto yieldOp = whileOp.getYieldOp();
        Value yieldedValue = yieldOp.getOperand(bbArg.getArgNumber());
        return computeLoopRegionIterArgBufferType(
            op, bbArg, initArg, yieldedValue, options, invocationStack);
      }
    }

    // Case 2: OpResult of the loop or block argument of the "after" region.
    // The bufferized "after" bbArg type can be directly computed from the
    // bufferized "before" bbArg type.
    unsigned resultNum;
    if (auto opResult = dyn_cast<OpResult>(value)) {
      resultNum = opResult.getResultNumber();
    } else if (cast<BlockArgument>(value).getOwner()->getParent() ==
               &whileOp.getAfter()) {
      resultNum = cast<BlockArgument>(value).getArgNumber();
    } else {
      llvm_unreachable("invalid value");
    }
    Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
    if (!isa<TensorType>(conditionYieldedVal.getType())) {
      // scf.condition was already bufferized.
      return cast<BaseMemRefType>(conditionYieldedVal.getType());
    }
    return bufferization::getBufferType(conditionYieldedVal, options,
                                        invocationStack);
  }

  /// Assert that yielded values of an scf.while op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs+copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  ///
  /// Note: In contrast to scf::ForOp, scf::WhileOp has two regions and the
  /// equivalence condition must be checked for both.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocsFromLoops)
      return success();

    auto conditionOp = whileOp.getConditionOp();
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      Block *block = conditionOp->getBlock();
      if (!isa<TensorType>(it.value().getType()))
        continue;
      if (it.index() >= block->getNumArguments() ||
          !state.areEquivalentBufferizedValues(it.value(),
                                               block->getArgument(it.index())))
        return conditionOp->emitError()
               << "Condition arg #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    auto yieldOp = whileOp.getYieldOp();
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Block *block = yieldOp->getBlock();
      if (!isa<TensorType>(it.value().getType()))
        continue;
      if (it.index() >= block->getNumArguments() ||
          !state.areEquivalentBufferizedValues(it.value(),
                                               block->getArgument(it.index())))
        return yieldOp->emitError()
               << "Yield operand #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so
/// this is for analysis only.
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    scf::YieldOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    if (auto ifOp = dyn_cast<scf::IfOp>(op->getParentOp())) {
      return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
               BufferRelation::Equivalent, /*isDefinite=*/false}};
    }
    if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
      return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
               BufferRelation::Equivalent}};
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield allocations
    // when possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto yieldOp = cast<scf::YieldOp>(op);
    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::IndexSwitchOp, scf::ForOp,
             scf::WhileOp>(yieldOp->getParentOp()))
      return yieldOp->emitError("unsupported scf::YieldOp parent");

    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Value value = it.value();
      if (isa<TensorType>(value.getType())) {
        FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
        if (failed(maybeBuffer))
          return failure();
        Value buffer = *maybeBuffer;
        // We may have to cast the value before yielding it.
        if (isa<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>(
                yieldOp->getParentOp())) {
          FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
              yieldOp->getParentOp()->getResult(it.index()), options);
          if (failed(resultType))
            return failure();
          buffer = castBuffer(rewriter, buffer, *resultType);
        } else if (auto whileOp =
                       dyn_cast<scf::WhileOp>(yieldOp->getParentOp())) {
          FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
              whileOp.getBeforeArguments()[it.index()], options);
          if (failed(resultType))
            return failure();
          buffer = castBuffer(rewriter, buffer, *resultType);
        }
        newResults.push_back(buffer);
      } else {
        newResults.push_back(value);
      }
    }

    replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op, newResults);
    return success();
  }
};

/// Bufferization of ForallOp. This also bufferizes the terminator of the
/// region. There are op interfaces for the terminators (InParallelOp
/// and ParallelInsertSliceOp), but these are only used during analysis. Not
/// for bufferization.
struct ForallOpInterface
    : public BufferizableOpInterface::ExternalModel<ForallOpInterface,
                                                    ForallOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // All tensor operands to `scf.forall` are `shared_outs` and all
    // shared outs are assumed to be read by the loop. This does not
    // account for the case where the entire value is over-written,
    // but being conservative here.
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Outputs of scf::ForallOps are always considered as a write.
    return true;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto forallOp = cast<ForallOp>(op);
    return {
        {{forallOp.getTiedOpResult(&opOperand), BufferRelation::Equivalent}}};
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard guard(rewriter);
    auto forallOp = cast<ForallOp>(op);
    int64_t rank = forallOp.getRank();

    // Get buffers for all output operands.
    SmallVector<Value> buffers;
    for (Value out : forallOp.getOutputs()) {
      FailureOr<Value> buffer = getBuffer(rewriter, out, options);
      if (failed(buffer))
        return failure();
      buffers.push_back(*buffer);
    }

    // Use buffers instead of block arguments.
    rewriter.setInsertionPointToStart(forallOp.getBody());
    for (const auto &it : llvm::zip(
             forallOp.getBody()->getArguments().drop_front(rank), buffers)) {
      BlockArgument bbArg = std::get<0>(it);
      Value buffer = std::get<1>(it);
      Value bufferAsTensor = rewriter.create<ToTensorOp>(
          forallOp.getLoc(), bbArg.getType(), buffer);
      bbArg.replaceAllUsesWith(bufferAsTensor);
    }

    // Create new ForallOp without any results and drop the automatically
    // introduced terminator.
    rewriter.setInsertionPoint(forallOp);
    ForallOp newForallOp;
    newForallOp = rewriter.create<ForallOp>(
        forallOp.getLoc(), forallOp.getMixedLowerBound(),
        forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
        /*outputs=*/ValueRange(), forallOp.getMapping());

    // Keep discardable attributes from the original op.
    newForallOp->setDiscardableAttrs(forallOp->getDiscardableAttrDictionary());

    rewriter.eraseOp(newForallOp.getBody()->getTerminator());

    // Move over block contents of the old op.
    SmallVector<Value> replacementBbArgs;
    replacementBbArgs.append(newForallOp.getBody()->getArguments().begin(),
                             newForallOp.getBody()->getArguments().end());
    replacementBbArgs.append(forallOp.getOutputs().size(), Value());
    rewriter.mergeBlocks(forallOp.getBody(), newForallOp.getBody(),
                         replacementBbArgs);

    // Remove the old op and replace all of its uses.
    replaceOpWithBufferizedValues(rewriter, op, buffers);

    return success();
  }

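  // Illustrative sketch of the effect (not verbatim output): each shared_out
  // bbArg is replaced by a bufferization.to_tensor of its output buffer, the
  // new scf.forall carries no results, and all uses of the old results are
  // replaced directly with the output buffers.
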
  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                SmallVector<Value> &invocationStack) const {
    auto forallOp = cast<ForallOp>(op);

    if (auto bbArg = dyn_cast<BlockArgument>(value))
      // A tensor block argument has the same bufferized type as the
      // corresponding output operand.
      return bufferization::getBufferType(
          forallOp.getTiedOpOperand(bbArg)->get(), options, invocationStack);

    // The bufferized result type is the same as the bufferized type of the
    // corresponding output operand.
    return bufferization::getBufferType(
        forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options,
        invocationStack);
  }

  bool isRepetitiveRegion(Operation *op, unsigned index) const {
    auto forallOp = cast<ForallOp>(op);

    // This op is repetitive if it may execute its body more than once. If any
    // of the control variables are dynamic, it is conservatively considered
    // repetitive as well.
    for (auto [lb, ub, step] :
         llvm::zip(forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
                   forallOp.getMixedStep())) {
      std::optional<int64_t> lbConstant = getConstantIntValue(lb);
      if (!lbConstant)
        return true;

      std::optional<int64_t> ubConstant = getConstantIntValue(ub);
      if (!ubConstant)
        return true;

      std::optional<int64_t> stepConstant = getConstantIntValue(step);
      if (!stepConstant)
        return true;

      if (*lbConstant + *stepConstant < *ubConstant)
        return true;
    }
    return false;
  }

  bool isParallelRegion(Operation *op, unsigned index) const {
    return isRepetitiveRegion(op, index);
  }
};

/// Nothing to do for InParallelOp.
struct InParallelOpInterface
    : public BufferizableOpInterface::ExternalModel<InParallelOpInterface,
                                                    InParallelOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &b,
                          const BufferizationOptions &options) const {
    llvm_unreachable("op does not have any tensor OpOperands / OpResults");
    return failure();
  }
};

} // namespace
} // namespace scf
} // namespace mlir

void mlir::scf::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
    ConditionOp::attachInterface<ConditionOpInterface>(*ctx);
    ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
    ForOp::attachInterface<ForOpInterface>(*ctx);
    IfOp::attachInterface<IfOpInterface>(*ctx);
    IndexSwitchOp::attachInterface<IndexSwitchOpInterface>(*ctx);
    ForallOp::attachInterface<ForallOpInterface>(*ctx);
    InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
    WhileOp::attachInterface<WhileOpInterface>(*ctx);
    YieldOp::attachInterface<YieldOpInterface>(*ctx);
  });
}
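
// Typical usage (an illustrative sketch): clients register these external
// models before running One-Shot Bufferize, e.g.:
//   DialectRegistry registry;
//   scf::registerBufferizableOpInterfaceExternalModels(registry);
//   context.appendDialectRegistry(registry);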
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:594
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: Value.h:149
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition: Value.cpp:41
State for analysis-enabled bufferization.
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, ValueRange values)
Replace an op with replacement values.
Operation * getOwnerOfValue(Value value)
Return the owner of the given value.
LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter, const BufferizationOptions &options)
Bufferize the signature of block and its callers (i.e., ops that have the given block as a successor)...
Definition: Bufferize.cpp:394
FailureOr< Value > allocateTensorForShapedValue(OpBuilder &b, Location loc, Value shapedValue, const BufferizationOptions &options, bool copy=true)
Create an AllocTensorOp for the given shaped value (memref or tensor).
FailureOr< BaseMemRefType > getBufferType(Value value, const BufferizationOptions &options)
Return the buffer type for a given Value (tensor) after bufferization without bufferizing any IR.
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options)
Lookup the buffer for the given value.
BaseMemRefType getMemRefTypeWithFullyDynamicLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with fully dynamic layout.
BufferRelation
Specifies a fine-grain relationship between buffers to enable more analysis.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
Options for BufferizableOpInterface-based bufferization.
Options for analysis-enabled bufferization.
A template that provides a default implementation of getAliasingOpOperands for ops that support unstr...