//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::scf;

namespace mlir {
namespace scf {
namespace {

/// Helper function for loop bufferization. Cast the given buffer to the given
/// memref type.
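///
/// Illustrative sketch (not from the original source; the exact types depend
/// on the bufferization options in use): casting an identity-layout buffer to
/// a fully dynamic layout inserts
///
///   %cast = memref.cast %buffer
///       : memref<10xf32> to memref<10xf32, strided<[?], offset: ?>>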
static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
  assert(isa<BaseMemRefType>(type) && "expected BaseMemRefType");
  assert(isa<BaseMemRefType>(buffer.getType()) && "expected BaseMemRefType");
  // If the buffer already has the correct type, no cast is needed.
  if (buffer.getType() == type)
    return buffer;
  // TODO: In case `type` has a layout map that is not the fully dynamic
  // one, we may not be able to cast the buffer. In that case, the loop
  // iter_arg's layout map must be changed (see uses of `castBuffer`).
  assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
         "scf.while op bufferization: cast incompatible");
  return memref::CastOp::create(b, buffer.getLoc(), type, buffer).getResult();
}

/// Helper function for loop bufferization. Return "true" if the given value
/// is guaranteed to not alias with an external tensor apart from values in
/// `exceptions`. A value is external if it is defined outside of the given
/// region or if it is an entry block argument of the region.
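///
/// Illustrative sketch (hypothetical IR, not from the original source): in
///
///   scf.for ... iter_args(%t = %init) -> (tensor<5xf32>) {
///     %s = tensor.extract_slice %ext[0] [5] [1]
///         : tensor<10xf32> to tensor<5xf32>
///     scf.yield %s : tensor<5xf32>
///   }
///
/// the yielded %s aliases %ext, which is defined outside the loop region, so
/// this function returns "false" for %s (unless %ext is in `exceptions`).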
static bool doesNotAliasExternalValue(Value value, Region *region,
                                      ValueRange exceptions,
                                      const OneShotAnalysisState &state) {
  assert(region->hasOneBlock() && "expected region with single block");
  bool result = true;
  state.applyOnAliases(value, [&](Value alias) {
    if (llvm::is_contained(exceptions, alias))
      return;
    Region *aliasRegion = alias.getParentRegion();
    if (isa<BlockArgument>(alias) && !region->isProperAncestor(aliasRegion))
      result = false;
    if (isa<OpResult>(alias) && !region->isAncestor(aliasRegion))
      result = false;
  });
  return result;
}

/// Bufferization of scf.condition.
struct ConditionOpInterface
    : public BufferizableOpInterface::ExternalModel<ConditionOpInterface,
                                                    scf::ConditionOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Condition operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield
    // allocations when possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto conditionOp = cast<scf::ConditionOp>(op);
    auto whileOp = cast<scf::WhileOp>(conditionOp->getParentOp());

    SmallVector<Value> newArgs;
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      Value value = it.value();
      if (isa<TensorType>(value.getType())) {
        FailureOr<Value> maybeBuffer =
            getBuffer(rewriter, value, options, state);
        if (failed(maybeBuffer))
          return failure();
        FailureOr<BufferLikeType> resultType = bufferization::getBufferType(
            whileOp.getAfterArguments()[it.index()], options, state);
        if (failed(resultType))
          return failure();
        Value buffer = castBuffer(rewriter, *maybeBuffer, *resultType);
        newArgs.push_back(buffer);
      } else {
        newArgs.push_back(value);
      }
    }

    replaceOpWithNewBufferizedOp<scf::ConditionOp>(
        rewriter, op, conditionOp.getCondition(), newArgs);
    return success();
  }
};

/// Return the unique scf.yield op. If there are multiple or no scf.yield ops,
/// return an empty op.
static scf::YieldOp getUniqueYieldOp(scf::ExecuteRegionOp executeRegionOp) {
  scf::YieldOp result;
  for (Block &block : executeRegionOp.getRegion()) {
    if (auto yieldOp = dyn_cast<scf::YieldOp>(block.getTerminator())) {
      if (result)
        return {};
      result = yieldOp;
    }
  }
  return result;
}

/// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
/// fully implemented at the moment.
struct ExecuteRegionOpInterface
    : public OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel<
          ExecuteRegionOpInterface, scf::ExecuteRegionOp> {

  static bool supportsUnstructuredControlFlow() { return true; }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    return true;
  }

  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    // TODO: scf.execute_region with multiple yields are not supported.
    if (!getUniqueYieldOp(executeRegionOp))
      return op->emitOpError("op without unique scf.yield is not supported");
    return success();
  }

  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    if (auto bbArg = dyn_cast<BlockArgument>(value))
      return getAliasingBranchOpOperands(op, bbArg, state);

    // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. To allow for use-def chain traversal
    // through ExecuteRegionOps in the analysis, the corresponding yield value
    // is considered to be aliasing with the result.
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    auto it = llvm::find(op->getOpResults(), value);
    assert(it != op->getOpResults().end() && "invalid value");
    size_t resultNum = std::distance(op->getOpResults().begin(), it);
    auto yieldOp = getUniqueYieldOp(executeRegionOp);
    // Note: If there is no unique scf.yield op, `verifyAnalysis` will fail.
    if (!yieldOp)
      return {};
    return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    auto yieldOp = getUniqueYieldOp(executeRegionOp);
    TypeRange newResultTypes(yieldOp.getResults());

    // Create new op and move over region.
    auto newOp =
        scf::ExecuteRegionOp::create(rewriter, op->getLoc(), newResultTypes);
    newOp.getRegion().takeBody(executeRegionOp.getRegion());

    // Bufferize every block.
    for (Block &block : newOp.getRegion())
      if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
                                                        options, state)))
        return failure();

    // Update all uses of the old op.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
      if (isa<TensorType>(it.value())) {
        newResults.push_back(bufferization::ToTensorOp::create(
            rewriter, executeRegionOp.getLoc(), it.value(),
            newOp->getResult(it.index())));
      } else {
        newResults.push_back(newOp->getResult(it.index()));
      }
    }

    // Replace old op.
    rewriter.replaceOp(executeRegionOp, newResults);

    return success();
  }
};

/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
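///
/// Illustrative sketch (buffer types depend on the bufferization options;
/// this is not literal output):
///
///   %0 = scf.if %c -> (tensor<5xf32>) {
///     scf.yield %t : tensor<5xf32>
///   } else { ... }
///
/// becomes roughly:
///
///   %0 = scf.if %c -> (memref<5xf32, strided<[?], offset: ?>>) {
///     scf.yield %m : memref<5xf32, strided<[?], offset: ?>>
///   } else { ... }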
struct IfOpInterface
    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    // IfOps do not have tensor OpOperands. The yielded value can be any SSA
    // value that is in scope. To allow for use-def chain traversal through
    // IfOps in the analysis, both corresponding yield values from the
    // then/else branches are considered to be aliasing with the result.
    auto ifOp = cast<scf::IfOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), value));
    OpOperand *thenOperand = &ifOp.thenYield()->getOpOperand(resultNum);
    OpOperand *elseOperand = &ifOp.elseYield()->getOpOperand(resultNum);
    return {{thenOperand, BufferRelation::Equivalent, /*isDefinite=*/false},
            {elseOperand, BufferRelation::Equivalent, /*isDefinite=*/false}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto ifOp = cast<scf::IfOp>(op);

    // Compute bufferized result types.
    SmallVector<Type> newTypes;
    for (Value result : ifOp.getResults()) {
      if (!isa<TensorType>(result.getType())) {
        newTypes.push_back(result.getType());
        continue;
      }
      auto bufferType = bufferization::getBufferType(result, options, state);
      if (failed(bufferType))
        return failure();
      newTypes.push_back(*bufferType);
    }

    // Create new op.
    rewriter.setInsertionPoint(ifOp);
    auto newIfOp = scf::IfOp::create(rewriter, ifOp.getLoc(), newTypes,
                                     ifOp.getCondition(),
                                     /*withElseRegion=*/true);

    // Move over then/else blocks.
    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());

    return success();
  }

  FailureOr<BufferLikeType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto ifOp = cast<scf::IfOp>(op);
    auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
    auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());
    assert(value.getDefiningOp() == op && "invalid value");

    // Determine buffer types of the true/false branches.
    auto opResult = cast<OpResult>(value);
    auto thenValue = thenYieldOp.getOperand(opResult.getResultNumber());
    auto elseValue = elseYieldOp.getOperand(opResult.getResultNumber());
    BaseMemRefType thenBufferType, elseBufferType;
    if (isa<BaseMemRefType>(thenValue.getType())) {
      // True branch was already bufferized.
      thenBufferType = cast<BaseMemRefType>(thenValue.getType());
    } else {
      auto maybeBufferType =
          bufferization::detail::asMemRefType(bufferization::getBufferType(
              thenValue, options, state, invocationStack));
      if (failed(maybeBufferType))
        return failure();
      thenBufferType = *maybeBufferType;
    }
    if (isa<BaseMemRefType>(elseValue.getType())) {
      // False branch was already bufferized.
      elseBufferType = cast<BaseMemRefType>(elseValue.getType());
    } else {
      auto maybeBufferType =
          bufferization::detail::asMemRefType(bufferization::getBufferType(
              elseValue, options, state, invocationStack));
      if (failed(maybeBufferType))
        return failure();
      elseBufferType = *maybeBufferType;
    }

    // Best case: Both branches have the exact same buffer type.
    if (thenBufferType == elseBufferType)
      return cast<BufferLikeType>(thenBufferType);

    // Memory space mismatch.
    if (thenBufferType.getMemorySpace() != elseBufferType.getMemorySpace())
      return op->emitError("inconsistent memory space on then/else branches");

    // Layout maps are different: Promote to fully dynamic layout map.
    return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
        cast<TensorType>(opResult.getType()),
        thenBufferType.getMemorySpace()));
  }
};

/// Bufferization of scf.index_switch. Replace with a new scf.index_switch that
/// yields memrefs.
struct IndexSwitchOpInterface
    : public BufferizableOpInterface::ExternalModel<IndexSwitchOpInterface,
                                                    scf::IndexSwitchOp> {
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    // IndexSwitchOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. This is similar to IfOps.
    auto switchOp = cast<scf::IndexSwitchOp>(op);
    int64_t resultNum = cast<OpResult>(value).getResultNumber();
    AliasingOpOperandList result;
    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
      auto yieldOp =
          cast<scf::YieldOp>(switchOp.getCaseBlock(i).getTerminator());
      result.addAlias(AliasingOpOperand(&yieldOp->getOpOperand(resultNum),
                                        BufferRelation::Equivalent,
                                        /*isDefinite=*/false));
    }
    auto defaultYieldOp =
        cast<scf::YieldOp>(switchOp.getDefaultBlock().getTerminator());
    result.addAlias(AliasingOpOperand(&defaultYieldOp->getOpOperand(resultNum),
                                      BufferRelation::Equivalent,
                                      /*isDefinite=*/false));
    return result;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto switchOp = cast<scf::IndexSwitchOp>(op);

    // Compute bufferized result types.
    SmallVector<Type> newTypes;
    for (Value result : switchOp.getResults()) {
      if (!isa<TensorType>(result.getType())) {
        newTypes.push_back(result.getType());
        continue;
      }
      auto bufferType = bufferization::getBufferType(result, options, state);
      if (failed(bufferType))
        return failure();
      newTypes.push_back(*bufferType);
    }

    // Create new op.
    rewriter.setInsertionPoint(switchOp);
    auto newSwitchOp = scf::IndexSwitchOp::create(
        rewriter, switchOp.getLoc(), newTypes, switchOp.getArg(),
        switchOp.getCases(), switchOp.getCases().size());

    // Move over blocks.
    for (auto [src, dest] :
         llvm::zip(switchOp.getCaseRegions(), newSwitchOp.getCaseRegions()))
      rewriter.inlineRegionBefore(src, dest, dest.begin());
    rewriter.inlineRegionBefore(switchOp.getDefaultRegion(),
                                newSwitchOp.getDefaultRegion(),
                                newSwitchOp.getDefaultRegion().begin());

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newSwitchOp->getResults());

    return success();
  }

  FailureOr<BufferLikeType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto switchOp = cast<scf::IndexSwitchOp>(op);
    assert(value.getDefiningOp() == op && "invalid value");
    int64_t resultNum = cast<OpResult>(value).getResultNumber();

    // Helper function to get buffer type of a case.
    auto getYieldedBufferType = [&](Block &b) -> FailureOr<BaseMemRefType> {
      auto yieldOp = cast<scf::YieldOp>(b.getTerminator());
      Value yieldedValue = yieldOp->getOperand(resultNum);
      if (auto bufferType = dyn_cast<BaseMemRefType>(yieldedValue.getType()))
        return bufferType;
      auto maybeBufferType = bufferization::getBufferType(
          yieldedValue, options, state, invocationStack);
      return bufferization::detail::asMemRefType(maybeBufferType);
    };

    // Compute buffer type of the default case.
    auto maybeBufferType = getYieldedBufferType(switchOp.getDefaultBlock());
    if (failed(maybeBufferType))
      return failure();
    BaseMemRefType bufferType = *maybeBufferType;

    // Compute buffer types of all other cases.
    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
      auto yieldedBufferType = getYieldedBufferType(switchOp.getCaseBlock(i));
      if (failed(yieldedBufferType))
        return failure();

      // Best case: Both branches have the exact same buffer type.
      if (bufferType == *yieldedBufferType)
        continue;

      // Memory space mismatch.
      if (bufferType.getMemorySpace() != yieldedBufferType->getMemorySpace())
        return op->emitError("inconsistent memory space on switch cases");

      // Layout maps are different: Promote to fully dynamic layout map.
      bufferType = getMemRefTypeWithFullyDynamicLayout(
          cast<TensorType>(value.getType()), bufferType.getMemorySpace());
    }

    return cast<BufferLikeType>(bufferType);
  }
};

/// Helper function for loop bufferization. Return the indices of all values
/// that have a tensor type.
static DenseSet<int64_t> getTensorIndices(ValueRange values) {
  DenseSet<int64_t> result;
  for (const auto &it : llvm::enumerate(values))
    if (isa<TensorType>(it.value().getType()))
      result.insert(it.index());
  return result;
}

/// Helper function for loop bufferization. Return the indices of all
/// bbArg/yielded value pairs whose buffer relation is "Equivalent".
DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
                                       ValueRange yieldedValues,
                                       const AnalysisState &state) {
  unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
  DenseSet<int64_t> result;
  for (unsigned int i = 0; i < minSize; ++i) {
    if (!isa<TensorType>(bbArgs[i].getType()) ||
        !isa<TensorType>(yieldedValues[i].getType()))
      continue;
    if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
      result.insert(i);
  }
  return result;
}

/// Helper function for loop bufferization. Return the bufferized values of the
/// given OpOperands. If an operand is not a tensor, return the original value.
static FailureOr<SmallVector<Value>>
getBuffers(RewriterBase &rewriter, const MutableOperandRange &operands,
           const BufferizationOptions &options, BufferizationState &state) {
  SmallVector<Value> result;
  for (OpOperand &opOperand : operands) {
    if (isa<TensorType>(opOperand.get().getType())) {
      FailureOr<Value> resultBuffer =
          getBuffer(rewriter, opOperand.get(), options, state);
      if (failed(resultBuffer))
        return failure();
      result.push_back(*resultBuffer);
    } else {
      result.push_back(opOperand.get());
    }
  }
  return result;
}

/// Helper function for loop bufferization. Given a list of bbArgs of the new
/// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into
/// ToTensorOps, so that the block body can be moved over to the new op.
static SmallVector<Value>
getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
                     Block::BlockArgListType oldBbArgs,
                     const DenseSet<int64_t> &tensorIndices) {
  SmallVector<Value> result;
  for (const auto &it : llvm::enumerate(bbArgs)) {
    size_t idx = it.index();
    Value val = it.value();
    if (tensorIndices.contains(idx)) {
      result.push_back(
          bufferization::ToTensorOp::create(rewriter, val.getLoc(),
                                            oldBbArgs[idx].getType(), val)
              .getResult());
    } else {
      result.push_back(val);
    }
  }
  return result;
}

/// Compute the bufferized type of a loop iter_arg. This type must be equal to
/// the bufferized type of the corresponding init_arg and the bufferized type
/// of the corresponding yielded value.
///
/// This function uses bufferization::getBufferType to compute the bufferized
/// type of the init_arg and of the yielded value. (The computation of the
/// bufferized yielded value type usually requires computing the bufferized
/// type of the iter_arg again; the implementation of getBufferType traces
/// back the use-def chain of the given value and computes a buffer type along
/// the way.) If both buffer types are equal, no casts are needed and the
/// computed buffer type can be used directly. Otherwise, the buffer types can
/// only differ in their layout map and a cast must be inserted.
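///
/// For example (an illustrative sketch, not from the original source): if the
/// init_arg bufferizes to memref<5xf32> (identity layout) but the yielded
/// value bufferizes to memref<5xf32, strided<[?], offset: ?>> (e.g., because
/// the loop body yields a subview), the two types differ only in their layout
/// map; the iter_arg is then promoted to the fully dynamic layout and a
/// memref.cast is inserted for the init_arg (see `castBuffer` above).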
static FailureOr<BufferLikeType> computeLoopRegionIterArgBufferType(
    Operation *loopOp, BlockArgument iterArg, Value initArg,
    Value yieldedValue, const BufferizationOptions &options,
    const BufferizationState &state, SmallVector<Value> &invocationStack) {
  // Determine the buffer type of the init_arg.
  auto initArgBufferType =
      bufferization::getBufferType(initArg, options, state, invocationStack);
  if (failed(initArgBufferType))
    return failure();

  if (llvm::count(invocationStack, iterArg) >= 2) {
    // If the iter_arg is already twice on the invocation stack, just take the
    // type of the init_arg. This is to avoid infinite loops when calculating
    // the buffer type. This will most likely result in computing a memref type
    // with a fully dynamic layout map.

    // Note: For more precise layout map computation, a fixpoint iteration
    // could be done (i.e., re-computing the yielded buffer type until the
    // bufferized iter_arg type no longer changes). This current implementation
    // immediately switches to a fully dynamic layout map when a mismatch
    // between bufferized init_arg type and bufferized yield value type is
    // detected.
    return *initArgBufferType;
  }

  // Compute the buffer type of the yielded value.
  BufferLikeType yieldedValueBufferType;
  if (isa<BaseMemRefType>(yieldedValue.getType())) {
    // scf.yield was already bufferized.
    yieldedValueBufferType = cast<BufferLikeType>(yieldedValue.getType());
  } else {
    // Note: This typically triggers a recursive call for the buffer type of
    // the iter_arg.
    auto maybeBufferType = bufferization::getBufferType(
        yieldedValue, options, state, invocationStack);
    if (failed(maybeBufferType))
      return failure();
    yieldedValueBufferType = *maybeBufferType;
  }

  // If yielded type and init_arg type are the same, use that type directly.
  if (*initArgBufferType == yieldedValueBufferType)
    return yieldedValueBufferType;

  // If there is a mismatch between the yielded buffer type and the init_arg
  // buffer type, the buffer type must be promoted to a fully dynamic layout
  // map.
  auto yieldedBufferType = cast<BaseMemRefType>(yieldedValueBufferType);
  auto iterTensorType = cast<TensorType>(iterArg.getType());
  auto initBufferType = llvm::cast<BaseMemRefType>(*initArgBufferType);
  if (initBufferType.getMemorySpace() != yieldedBufferType.getMemorySpace())
    return loopOp->emitOpError(
        "init_arg and yielded value bufferize to inconsistent memory spaces");
#ifndef NDEBUG
  if (auto yieldedRankedBufferType = dyn_cast<MemRefType>(yieldedBufferType)) {
    assert(
        llvm::all_equal({yieldedRankedBufferType.getShape(),
                         cast<MemRefType>(initBufferType).getShape(),
                         cast<RankedTensorType>(iterTensorType).getShape()}) &&
        "expected same shape");
  }
#endif // NDEBUG
  return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
      iterTensorType, yieldedBufferType.getMemorySpace()));
}

/// Return `true` if the given loop may have 0 iterations.
bool mayHaveZeroIterations(scf::ForOp forOp) {
  std::optional<int64_t> lb = getConstantIntValue(forOp.getLowerBound());
  std::optional<int64_t> ub = getConstantIntValue(forOp.getUpperBound());
  if (!lb.has_value() || !ub.has_value())
    return true;
  return *ub <= *lb;
}

/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.
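///
/// Illustrative sketch (assuming a fully dynamic layout is chosen by the
/// bufferization options; not literal output):
///
///   %r = scf.for %i = %lb to %ub step %s iter_args(%t = %init)
///       -> (tensor<?xf32>) {
///     ...
///     scf.yield %t2 : tensor<?xf32>
///   }
///
/// becomes roughly:
///
///   %m = ... : memref<?xf32, strided<[?], offset: ?>>  // buffer of %init
///   %r = scf.for %i = %lb to %ub step %s iter_args(%b = %m)
///       -> (memref<?xf32, strided<[?], offset: ?>>) {
///     ...
///     scf.yield %b2 : memref<?xf32, strided<[?], offset: ?>>
///   }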
struct ForOpInterface
    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                    scf::ForOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);

    // If the loop has zero iterations, the results of the op are their
    // corresponding init_args, meaning that the init_args bufferize to a read.
    if (mayHaveZeroIterations(forOp))
      return true;

    // scf::ForOp alone doesn't bufferize to a memory read, one of the uses of
    // its matching bbArg may.
    return state.isValueRead(forOp.getTiedLoopRegionIterArg(&opOperand));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::ForOps are always considered as a write.
    return true;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    OpResult opResult = forOp.getTiedLoopResult(&opOperand);
    BufferRelation relation = bufferRelation(op, opResult, state);
    return {{opResult, relation,
             /*isDefinite=*/relation == BufferRelation::Equivalent}};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // ForOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent.
    auto forOp = cast<scf::ForOp>(op);
    BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
    bool equivalentYield = state.areEquivalentBufferizedValues(
        bbArg, forOp.getTiedLoopYieldedValue(bbArg)->get());
    return equivalentYield ? BufferRelation::Equivalent
                           : BufferRelation::Unknown;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::ForOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    //   1. Either the matching iter operand is not bufferized inplace and an
    //      alloc + optional copy makes the bbArg itself inplaceable.
    //   2. Or the matching iter operand is bufferized inplace and bbArg just
    //      bufferizes to that too.
    return true;
  }

  LogicalResult
  resolveConflicts(Operation *op, RewriterBase &rewriter,
                   const AnalysisState &analysisState,
                   const BufferizationState &bufferizationState) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(
            rewriter, analysisState, bufferizationState)))
      return failure();

    if (analysisState.getOptions().copyBeforeWrite)
      return success();

    // According to the `getAliasing...` implementations, a bufferized OpResult
    // may alias only with the corresponding bufferized init_arg (or with a
    // newly allocated buffer) and not with other buffers defined outside of
    // the loop. I.e., the i-th OpResult may alias with the i-th init_arg;
    // but not with any other OpOperand.
    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(yieldOp);

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
    // For every yielded value, does it alias with something defined outside of
    // the loop?
    SmallVector<Value> yieldValues;
    for (const auto it : llvm::enumerate(yieldOp.getResults())) {
      // Note: `state` is guaranteed to be a `OneShotAnalysisState`, but this
      // type cannot be used in the signature of `resolveConflicts` because the
      // op interface is in the "IR" build unit and the `OneShotAnalysisState`
      // is defined in the "Transforms" build unit.
      if (!indices.contains(it.index()) ||
          doesNotAliasExternalValue(
              it.value(), &forOp.getRegion(),
              /*exceptions=*/forOp.getRegionIterArg(it.index()),
              static_cast<const OneShotAnalysisState &>(analysisState))) {
        yieldValues.push_back(it.value());
        continue;
      }
      FailureOr<Value> alloc = allocateTensorForShapedValue(
          rewriter, yieldOp.getLoc(), it.value(), analysisState.getOptions(),
          bufferizationState);
      if (failed(alloc))
        return failure();
      yieldValues.push_back(*alloc);
    }

    rewriter.modifyOpInPlace(
        yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); });
    return success();
  }

  FailureOr<BufferLikeType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto forOp = cast<scf::ForOp>(op);
    assert(getOwnerOfValue(value) == op && "invalid value");
    assert(isa<TensorType>(value.getType()) && "expected tensor type");

    if (auto opResult = dyn_cast<OpResult>(value)) {
      // The type of an OpResult must match the corresponding iter_arg type.
      BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
      return bufferization::getBufferType(bbArg, options, state,
                                          invocationStack);
    }

    // Compute result/argument number.
    BlockArgument bbArg = cast<BlockArgument>(value);
    unsigned resultNum = forOp.getTiedLoopResult(bbArg).getResultNumber();

    // Compute the bufferized type.
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    Value yieldedValue = yieldOp.getOperand(resultNum);
    BlockArgument iterArg = forOp.getRegionIterArgs()[resultNum];
    Value initArg = forOp.getInitArgs()[resultNum];
    return computeLoopRegionIterArgBufferType(
        op, iterArg, initArg, yieldedValue, options, state, invocationStack);
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    Block *oldLoopBody = forOp.getBody();

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, forOp.getInitArgsMutable(), options, state);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // Cast init_args if necessary.
    SmallVector<Value> castedInitArgs;
    for (const auto &it : llvm::enumerate(initArgs)) {
      Value initArg = it.value();
      Value result = forOp->getResult(it.index());
      // If the type is not a tensor, bufferization doesn't need to touch it.
      if (!isa<TensorType>(result.getType())) {
        castedInitArgs.push_back(initArg);
        continue;
      }
      auto targetType = bufferization::getBufferType(result, options, state);
      if (failed(targetType))
        return failure();
      castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
    }

    // Construct a new scf.for op with memref instead of tensor values.
    auto newForOp = scf::ForOp::create(
        rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), castedInitArgs, /*bodyBuilder=*/nullptr,
        forOp.getUnsignedCmp());
    newForOp->setAttrs(forOp->getAttrs());
    Block *loopBody = newForOp.getBody();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
    SmallVector<Value> iterArgs =
        getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(),
                             forOp.getRegionIterArgs(), indices);
    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());

    // Move loop body to new loop.
    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.for op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs + copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocsFromLoops)
      return success();

    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    for (OpResult opResult : op->getOpResults()) {
      if (!isa<TensorType>(opResult.getType()))
        continue;

      // Note: This is overly strict. We should check for aliasing bufferized
      // values. But we don't have a "must-alias" analysis yet.
      if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
        return yieldOp->emitError()
               << "Yield operand #" << opResult.getResultNumber()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// Bufferization of scf.while. Replace with a new scf.while that operates on
/// memrefs.
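///
/// Illustrative sketch (not literal output); note that the "before" and
/// "after" regions may have different signatures, and the new op's result
/// types follow the bufferized "after" bbArg types:
///
///   %r = scf.while (%arg = %init)
///       : (tensor<5xf32>) -> tensor<5xf32> { ... } do { ... }
///
/// becomes roughly:
///
///   %m = ... : memref<5xf32, strided<[?], offset: ?>>  // buffer of %init
///   %r = scf.while (%arg = %m)
///       : (memref<5xf32, strided<[?], offset: ?>>)
///       -> memref<5xf32, strided<[?], offset: ?>> { ... } do { ... }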
struct WhileOpInterface
    : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
                                                    scf::WhileOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a read.
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a write.
    return true;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    unsigned int idx = opOperand.getOperandNumber();

    // The OpResults and OpOperands may not match. They may not even have the
    // same type. The number of OpResults and OpOperands can also differ.
    if (idx >= op->getNumResults() ||
        opOperand.get().getType() != op->getResult(idx).getType())
      return {};

    // The only aliasing OpResult may be the one at the same index.
    OpResult opResult = whileOp->getResult(idx);
    BufferRelation relation = bufferRelation(op, opResult, state);
    return {{opResult, relation,
             /*isDefinite=*/relation == BufferRelation::Equivalent}};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // WhileOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent (for both the
    // "before" and the "after" block).
    unsigned int resultNumber = opResult.getResultNumber();
    auto whileOp = cast<scf::WhileOp>(op);

    // The "before" region bbArgs and the OpResults may not match.
    if (resultNumber >= whileOp.getBeforeArguments().size())
      return BufferRelation::Unknown;
    if (opResult.getType() !=
        whileOp.getBeforeArguments()[resultNumber].getType())
      return BufferRelation::Unknown;

    auto conditionOp = whileOp.getConditionOp();
    BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
    Value conditionOperand = conditionOp.getArgs()[resultNumber];
    bool equivCondition =
        state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);

    auto yieldOp = whileOp.getYieldOp();
    BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
    Value yieldOperand = yieldOp.getOperand(resultNumber);
    bool equivYield =
        state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);

    return equivCondition && equivYield ? BufferRelation::Equivalent
                                        : BufferRelation::Unknown;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::WhileOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    //   1. Either the matching iter operand is not bufferized inplace and an
    //      alloc + optional copy makes the bbArg itself inplaceable.
    //   2. Or the matching iter operand is bufferized inplace and bbArg just
    //      bufferizes to that too.
    return true;
  }

  LogicalResult
  resolveConflicts(Operation *op, RewriterBase &rewriter,
                   const AnalysisState &analysisState,
                   const BufferizationState &bufferizationState) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(
            rewriter, analysisState, bufferizationState)))
      return failure();

    if (analysisState.getOptions().copyBeforeWrite)
      return success();

    // According to the `getAliasing...` implementations, a bufferized OpResult
    // may alias only with the corresponding bufferized init_arg and with no
    // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg;
    // but not with any other OpOperand. If a corresponding OpResult/init_arg
    // pair bufferizes to equivalent buffers, this aliasing requirement is
    // satisfied. Otherwise, we cannot be sure and must yield a new buffer
    // copy. (New buffer copies do not alias with any buffer.)
    OpBuilder::InsertionGuard g(rewriter);
    auto whileOp = cast<scf::WhileOp>(op);
    auto conditionOp = whileOp.getConditionOp();

    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
        whileOp.getBeforeArguments(), conditionOp.getArgs(), analysisState);
    DenseSet<int64_t> equivalentYieldsAfter =
        getEquivalentBuffers(whileOp.getAfterArguments(),
                             whileOp.getYieldOp().getResults(), analysisState);

    // Update "before" region.
    rewriter.setInsertionPoint(conditionOp);
    SmallVector<Value> beforeYieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
      Value value = conditionOp.getArgs()[idx];
      if (!isa<TensorType>(value.getType()) ||
          (equivalentYieldsAfter.contains(idx) &&
           equivalentYieldsBefore.contains(idx))) {
        beforeYieldValues.push_back(value);
        continue;
      }
      FailureOr<Value> alloc = allocateTensorForShapedValue(
          rewriter, conditionOp.getLoc(), value, analysisState.getOptions(),
          bufferizationState);
      if (failed(alloc))
        return failure();
      beforeYieldValues.push_back(*alloc);
    }
    rewriter.modifyOpInPlace(conditionOp, [&]() {
      conditionOp.getArgsMutable().assign(beforeYieldValues);
    });

    return success();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);

    // Indices of all bbArgs that have tensor type. These are the ones that
    // are bufferized. The "before" and "after" regions may have different
    // args.
    DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
    DenseSet<int64_t> indicesAfter =
        getTensorIndices(whileOp.getAfterArguments());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, whileOp.getInitsMutable(), options, state);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // Cast init_args if necessary.
    SmallVector<Value> castedInitArgs;
    for (const auto &it : llvm::enumerate(initArgs)) {
      Value initArg = it.value();
      Value beforeArg = whileOp.getBeforeArguments()[it.index()];
      // If the type is not a tensor, bufferization doesn't need to touch it.
      if (!isa<TensorType>(beforeArg.getType())) {
        castedInitArgs.push_back(initArg);
        continue;
      }
      auto targetType =
          bufferization::getBufferType(beforeArg, options, state);
      if (failed(targetType))
        return failure();
      castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
    }

    // The result types of a WhileOp are the same as the "after" bbArg types.
    SmallVector<Type> argsTypesAfter = llvm::to_vector(
        llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
          if (!isa<TensorType>(bbArg.getType()))
            return bbArg.getType();
          // TODO: error handling
          return llvm::cast<Type>(
              *bufferization::getBufferType(bbArg, options, state));
        }));

    // Construct a new scf.while op with memref instead of tensor values.
    ValueRange argsRangeBefore(castedInitArgs);
    TypeRange argsTypesBefore(argsRangeBefore);
    auto newWhileOp = scf::WhileOp::create(rewriter, whileOp.getLoc(),
                                           argsTypesAfter, castedInitArgs);

    // Add before/after regions to the new op.
    SmallVector<Location> bbArgLocsBefore(castedInitArgs.size(),
                                          whileOp.getLoc());
    SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
                                         whileOp.getLoc());
    Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
    newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore);
    Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
    newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter);

    // Set up new iter_args and move the loop condition block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newBeforeBody);
    SmallVector<Value> newBeforeArgs =
        getBbArgReplacements(rewriter, newWhileOp.getBeforeArguments(),
                             whileOp.getBeforeArguments(), indicesBefore);
    rewriter.mergeBlocks(whileOp.getBeforeBody(), newBeforeBody,
                         newBeforeArgs);

    // Set up new iter_args and move the loop body block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newAfterBody);
    SmallVector<Value> newAfterArgs =
        getBbArgReplacements(rewriter, newWhileOp.getAfterArguments(),
                             whileOp.getAfterArguments(), indicesAfter);
    rewriter.mergeBlocks(whileOp.getAfterBody(), newAfterBody, newAfterArgs);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());

    return success();
  }

  FailureOr<BufferLikeType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto whileOp = cast<scf::WhileOp>(op);
    assert(getOwnerOfValue(value) == op && "invalid value");
    assert(isa<TensorType>(value.getType()) && "expected tensor type");

    // Case 1: Block argument of the "before" region.
    if (auto bbArg = dyn_cast<BlockArgument>(value)) {
      if (bbArg.getOwner()->getParent() == &whileOp.getBefore()) {
        Value initArg = whileOp.getInits()[bbArg.getArgNumber()];
        auto yieldOp = whileOp.getYieldOp();
        Value yieldedValue = yieldOp.getOperand(bbArg.getArgNumber());
        return computeLoopRegionIterArgBufferType(
            op, bbArg, initArg, yieldedValue, options, state,
            invocationStack);
      }
    }

    // Case 2: OpResult of the loop or block argument of the "after" region.
    // The bufferized "after" bbArg type can be directly computed from the
    // bufferized "before" bbArg type.
    unsigned resultNum;
    if (auto opResult = dyn_cast<OpResult>(value)) {
      resultNum = opResult.getResultNumber();
    } else if (cast<BlockArgument>(value).getOwner()->getParent() ==
               &whileOp.getAfter()) {
      resultNum = cast<BlockArgument>(value).getArgNumber();
    } else {
      llvm_unreachable("invalid value");
    }
    Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
    if (!isa<TensorType>(conditionYieldedVal.getType())) {
      // scf.condition was already bufferized.
      return cast<BufferLikeType>(conditionYieldedVal.getType());
    }
    return bufferization::getBufferType(conditionYieldedVal, options, state,
                                        invocationStack);
  }

  /// Assert that yielded values of an scf.while op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs + copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  ///
  /// Note: In contrast to scf::ForOp, scf::WhileOp has two regions and the
  /// equivalence condition must be checked for both.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocsFromLoops)
      return success();

    auto conditionOp = whileOp.getConditionOp();
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      Block *block = conditionOp->getBlock();
      if (!isa<TensorType>(it.value().getType()))
        continue;
      if (it.index() >= block->getNumArguments() ||
          !state.areEquivalentBufferizedValues(it.value(),
                                               block->getArgument(it.index())))
        return conditionOp->emitError()
               << "Condition arg #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    auto yieldOp = whileOp.getYieldOp();
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Block *block = yieldOp->getBlock();
      if (!isa<TensorType>(it.value().getType()))
        continue;
      if (it.index() >= block->getNumArguments() ||
          !state.areEquivalentBufferizedValues(it.value(),
                                               block->getArgument(it.index())))
        return yieldOp->emitError()
               << "Yield operand #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so
/// this is for analysis only.
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    scf::YieldOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    if (auto ifOp = dyn_cast<scf::IfOp>(op->getParentOp())) {
      return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
               BufferRelation::Equivalent, /*isDefinite=*/false}};
    }
    if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
      return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
               BufferRelation::Equivalent}};
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield
    // allocations when possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto yieldOp = cast<scf::YieldOp>(op);
    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::IndexSwitchOp, scf::ForOp,
             scf::WhileOp>(yieldOp->getParentOp()))
      return yieldOp->emitError("unsupported scf::YieldOp parent");

    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Value value = it.value();
      if (isa<TensorType>(value.getType())) {
        FailureOr<Value> maybeBuffer =
            getBuffer(rewriter, value, options, state);
        if (failed(maybeBuffer))
          return failure();
        Value buffer = *maybeBuffer;
        // We may have to cast the value before yielding it.
        if (isa<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>(
                yieldOp->getParentOp())) {
          FailureOr<BufferLikeType> resultType = bufferization::getBufferType(
              yieldOp->getParentOp()->getResult(it.index()), options, state);
          if (failed(resultType))
            return failure();
          buffer = castBuffer(rewriter, buffer, *resultType);
        } else if (auto whileOp =
                       dyn_cast<scf::WhileOp>(yieldOp->getParentOp())) {
          FailureOr<BufferLikeType> resultType = bufferization::getBufferType(
              whileOp.getBeforeArguments()[it.index()], options, state);
          if (failed(resultType))
            return failure();
          buffer = castBuffer(rewriter, buffer, *resultType);
        }
        newResults.push_back(buffer);
      } else {
        newResults.push_back(value);
      }
    }

    replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op, newResults);
    return success();
  }
};

/// Bufferization of ForallOp. This also bufferizes the terminator of the
/// region. There are op interfaces for the terminators (InParallelOp and
/// ParallelInsertSliceOp), but these are only used during analysis, not for
/// bufferization.
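///
/// Illustrative sketch (not literal output): each shared_out tensor is
/// replaced by its buffer, the new loop has no results, and uses of the old
/// results are replaced by the buffers directly:
///
///   %r = scf.forall ... shared_outs(%o = %t) -> (tensor<?xf32>) { ... }
///
/// becomes roughly:
///
///   %m = ... : memref<?xf32, strided<[?], offset: ?>>  // buffer of %t
///   scf.forall ... { ... }  // body reads/writes %m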
struct ForallOpInterface
    : public BufferizableOpInterface::ExternalModel<ForallOpInterface,
                                                    ForallOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // All tensor operands to `scf.forall` are `shared_outs` and all
    // shared outs are assumed to be read by the loop. This does not
    // account for the case where the entire value is over-written,
    // but being conservative here.
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Outputs of scf::ForallOps are always considered as a write.
    return true;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto forallOp = cast<ForallOp>(op);
    return {
        {{forallOp.getTiedOpResult(&opOperand), BufferRelation::Equivalent}}};
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    OpBuilder::InsertionGuard guard(rewriter);
    auto forallOp = cast<ForallOp>(op);
    int64_t rank = forallOp.getRank();

    // Get buffers for all output operands.
    SmallVector<Value> buffers;
    for (Value out : forallOp.getOutputs()) {
      FailureOr<Value> buffer = getBuffer(rewriter, out, options, state);
      if (failed(buffer))
        return failure();
      buffers.push_back(*buffer);
    }

    // Use buffers instead of block arguments.
    rewriter.setInsertionPointToStart(forallOp.getBody());
    for (const auto &it : llvm::zip(
             forallOp.getBody()->getArguments().drop_front(rank), buffers)) {
      BlockArgument bbArg = std::get<0>(it);
      Value buffer = std::get<1>(it);
      Value bufferAsTensor = ToTensorOp::create(rewriter, forallOp.getLoc(),
                                                bbArg.getType(), buffer);
      bbArg.replaceAllUsesWith(bufferAsTensor);
    }

    // Create new ForallOp without any results and drop the automatically
    // introduced terminator.
    rewriter.setInsertionPoint(forallOp);
    ForallOp newForallOp;
    newForallOp = ForallOp::create(
        rewriter, forallOp.getLoc(), forallOp.getMixedLowerBound(),
        forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
        /*outputs=*/ValueRange(), forallOp.getMapping());

    // Keep discardable attributes from the original op.
    newForallOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());

    rewriter.eraseOp(newForallOp.getBody()->getTerminator());

    // Move over block contents of the old op.
    SmallVector<Value> replacementBbArgs;
    replacementBbArgs.append(newForallOp.getBody()->getArguments().begin(),
                             newForallOp.getBody()->getArguments().end());
    replacementBbArgs.append(forallOp.getOutputs().size(), Value());
    rewriter.mergeBlocks(forallOp.getBody(), newForallOp.getBody(),
                         replacementBbArgs);

    // Remove the old op and replace all of its uses.
    replaceOpWithBufferizedValues(rewriter, op, buffers);

    return success();
  }

  FailureOr<BufferLikeType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto forallOp = cast<ForallOp>(op);

    if (auto bbArg = dyn_cast<BlockArgument>(value))
      // A tensor block argument has the same bufferized type as the
      // corresponding output operand.
      return bufferization::getBufferType(
          forallOp.getTiedOpOperand(bbArg)->get(), options, state,
          invocationStack);

    // The bufferized result type is the same as the bufferized type of the
    // corresponding output operand.
    return bufferization::getBufferType(
        forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()],
        options, state, invocationStack);
  }

  bool isRepetitiveRegion(Operation *op, unsigned index) const {
    auto forallOp = cast<ForallOp>(op);

    // This op is repetitive if it has 1 or more steps.
    // If the control variables are dynamic, it is also considered so.
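    //
    // Worked example (illustrative): with lb = 0, ub = 8, step = 4, the check
    // below evaluates 0 + 4 < 8, so the region runs more than once (at
    // iterations 0 and 4) and is considered repetitive. With ub = 4 instead,
    // 0 + 4 < 4 is false, so that dimension executes at most once.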
    for (auto [lb, ub, step] :
         llvm::zip(forallOp.getMixedLowerBound(),
                   forallOp.getMixedUpperBound(), forallOp.getMixedStep())) {
      std::optional<int64_t> lbConstant = getConstantIntValue(lb);
      if (!lbConstant)
        return true;

      std::optional<int64_t> ubConstant = getConstantIntValue(ub);
      if (!ubConstant)
        return true;

      std::optional<int64_t> stepConstant = getConstantIntValue(step);
      if (!stepConstant)
        return true;

      if (*lbConstant + *stepConstant < *ubConstant)
        return true;
    }
    return false;
  }

  bool isParallelRegion(Operation *op, unsigned index) const {
    return isRepetitiveRegion(op, index);
  }
};

/// Nothing to do for InParallelOp.
struct InParallelOpInterface
    : public BufferizableOpInterface::ExternalModel<InParallelOpInterface,
                                                    InParallelOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &b,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    llvm_unreachable("op does not have any tensor OpOperands / OpResults");
    return failure();
  }
};

} // namespace
} // namespace scf
} // namespace mlir

void mlir::scf::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
    ConditionOp::attachInterface<ConditionOpInterface>(*ctx);
    ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
    ForOp::attachInterface<ForOpInterface>(*ctx);
    IfOp::attachInterface<IfOpInterface>(*ctx);
    IndexSwitchOp::attachInterface<IndexSwitchOpInterface>(*ctx);
    ForallOp::attachInterface<ForallOpInterface>(*ctx);
    InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
    WhileOp::attachInterface<WhileOpInterface>(*ctx);
    YieldOp::attachInterface<YieldOpInterface>(*ctx);
  });
}
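
// Example usage (an illustrative sketch, not part of the original file): a
// client typically attaches these external models while setting up its
// MLIRContext, e.g.:
//
//   DialectRegistry registry;
//   registry.insert<scf::SCFDialect, memref::MemRefDialect>();
//   scf::registerBufferizableOpInterfaceExternalModels(registry);
//   MLIRContext context(registry);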
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:628
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: Value.h:149
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition: Value.cpp:39
BufferizationState provides information about the state of the IR during the bufferization process.
State for analysis-enabled bufferization.
FailureOr< BaseMemRefType > asMemRefType(FailureOr< BufferLikeType > bufferType)
This is a helper function used when buffer type is guaranteed to be memref.
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, ValueRange values)
Replace an op with replacement values.
Operation * getOwnerOfValue(Value value)
Return the owner of the given value.
LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter, const BufferizationOptions &options, BufferizationState &state)
Bufferize the signature of block and its callers (i.e., ops that have the given block as a successor)...
Definition: Bufferize.cpp:393
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options, const BufferizationState &state)
Lookup the buffer for the given value.
FailureOr< BufferLikeType > getBufferType(Value value, const BufferizationOptions &options, const BufferizationState &state)
Return the buffer type for a given Value (tensor) after bufferization without bufferizing any IR.
FailureOr< Value > allocateTensorForShapedValue(OpBuilder &b, Location loc, Value shapedValue, const BufferizationOptions &options, const BufferizationState &state, bool copy=true)
Create an AllocTensorOp for the given shaped value (memref or tensor).
BaseMemRefType getMemRefTypeWithFullyDynamicLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with fully dynamic layout.
BufferRelation
Specifies a fine-grain relationship between buffers to enable more analysis.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
Options for BufferizableOpInterface-based bufferization.
Options for analysis-enabled bufferization.
A template that provides a default implementation of getAliasingOpOperands for ops that support unstr...