MLIR  20.0.0git
BufferizableOpInterfaceImpl.cpp
Go to the documentation of this file.
1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
20 #include "mlir/IR/Dialect.h"
21 #include "mlir/IR/Operation.h"
22 #include "mlir/IR/PatternMatch.h"
23 
24 using namespace mlir;
25 using namespace mlir::bufferization;
26 using namespace mlir::scf;
27 
28 namespace mlir {
29 namespace scf {
30 namespace {
31 
32 /// Helper function for loop bufferization. Cast the given buffer to the given
33 /// memref type.
34 static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
35  assert(isa<BaseMemRefType>(type) && "expected BaseMemRefType");
36  assert(isa<BaseMemRefType>(buffer.getType()) && "expected BaseMemRefType");
37  // If the buffer already has the correct type, no cast is needed.
38  if (buffer.getType() == type)
39  return buffer;
40  // TODO: In case `type` has a layout map that is not the fully dynamic
41  // one, we may not be able to cast the buffer. In that case, the loop
42  // iter_arg's layout map must be changed (see uses of `castBuffer`).
43  assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
44  "scf.while op bufferization: cast incompatible");
45  return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
46 }
47 
48 /// Helper function for loop bufferization. Return "true" if the given value
49 /// is guaranteed to not alias with an external tensor apart from values in
50 /// `exceptions`. A value is external if it is defined outside of the given
51 /// region or if it is an entry block argument of the region.
52 static bool doesNotAliasExternalValue(Value value, Region *region,
53  ValueRange exceptions,
54  const OneShotAnalysisState &state) {
55  assert(region->getBlocks().size() == 1 &&
56  "expected region with single block");
57  bool result = true;
58  state.applyOnAliases(value, [&](Value alias) {
59  if (llvm::is_contained(exceptions, alias))
60  return;
61  Region *aliasRegion = alias.getParentRegion();
62  if (isa<BlockArgument>(alias) && !region->isProperAncestor(aliasRegion))
63  result = false;
64  if (isa<OpResult>(alias) && !region->isAncestor(aliasRegion))
65  result = false;
66  });
67  return result;
68 }
69 
70 /// Bufferization of scf.condition.
71 struct ConditionOpInterface
72  : public BufferizableOpInterface::ExternalModel<ConditionOpInterface,
73  scf::ConditionOp> {
74  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
75  const AnalysisState &state) const {
76  return true;
77  }
78 
79  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
80  const AnalysisState &state) const {
81  return false;
82  }
83 
84  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
85  const AnalysisState &state) const {
86  return {};
87  }
88 
89  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
90  const AnalysisState &state) const {
91  // Condition operands always bufferize inplace. Otherwise, an alloc + copy
92  // may be generated inside the block. We should not return/yield allocations
93  // when possible.
94  return true;
95  }
96 
97  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
98  const BufferizationOptions &options) const {
99  auto conditionOp = cast<scf::ConditionOp>(op);
100  auto whileOp = cast<scf::WhileOp>(conditionOp->getParentOp());
101 
102  SmallVector<Value> newArgs;
103  for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
104  Value value = it.value();
105  if (isa<TensorType>(value.getType())) {
106  FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
107  if (failed(maybeBuffer))
108  return failure();
109  FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
110  whileOp.getAfterArguments()[it.index()], options);
111  if (failed(resultType))
112  return failure();
113  Value buffer = castBuffer(rewriter, *maybeBuffer, *resultType);
114  newArgs.push_back(buffer);
115  } else {
116  newArgs.push_back(value);
117  }
118  }
119 
120  replaceOpWithNewBufferizedOp<scf::ConditionOp>(
121  rewriter, op, conditionOp.getCondition(), newArgs);
122  return success();
123  }
124 };
125 
126 /// Return the unique scf.yield op. If there are multiple or no scf.yield ops,
127 /// return an empty op.
128 static scf::YieldOp getUniqueYieldOp(scf::ExecuteRegionOp executeRegionOp) {
129  scf::YieldOp result;
130  for (Block &block : executeRegionOp.getRegion()) {
131  if (auto yieldOp = dyn_cast<scf::YieldOp>(block.getTerminator())) {
132  if (result)
133  return {};
134  result = yieldOp;
135  }
136  }
137  return result;
138 }
139 
140 /// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
141 /// fully implemented at the moment.
142 struct ExecuteRegionOpInterface
144  ExecuteRegionOpInterface, scf::ExecuteRegionOp> {
145 
146  static bool supportsUnstructuredControlFlow() { return true; }
147 
148  bool isWritable(Operation *op, Value value,
149  const AnalysisState &state) const {
150  return true;
151  }
152 
153  LogicalResult verifyAnalysis(Operation *op,
154  const AnalysisState &state) const {
155  auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
156  // TODO: scf.execute_region with multiple yields are not supported.
157  if (!getUniqueYieldOp(executeRegionOp))
158  return op->emitOpError("op without unique scf.yield is not supported");
159  return success();
160  }
161 
163  getAliasingOpOperands(Operation *op, Value value,
164  const AnalysisState &state) const {
165  if (auto bbArg = dyn_cast<BlockArgument>(value))
166  return getAliasingBranchOpOperands(op, bbArg, state);
167 
168  // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
169  // any SSA value that is in scope. To allow for use-def chain traversal
170  // through ExecuteRegionOps in the analysis, the corresponding yield value
171  // is considered to be aliasing with the result.
172  auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
173  auto it = llvm::find(op->getOpResults(), value);
174  assert(it != op->getOpResults().end() && "invalid value");
175  size_t resultNum = std::distance(op->getOpResults().begin(), it);
176  auto yieldOp = getUniqueYieldOp(executeRegionOp);
177  // Note: If there is no unique scf.yield op, `verifyAnalysis` will fail.
178  if (!yieldOp)
179  return {};
180  return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
181  }
182 
183  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
184  const BufferizationOptions &options) const {
185  auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
186  auto yieldOp = getUniqueYieldOp(executeRegionOp);
187  TypeRange newResultTypes(yieldOp.getResults());
188 
189  // Create new op and move over region.
190  auto newOp =
191  rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
192  newOp.getRegion().takeBody(executeRegionOp.getRegion());
193 
194  // Bufferize every block.
195  for (Block &block : newOp.getRegion())
196  if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
197  options)))
198  return failure();
199 
200  // Update all uses of the old op.
201  rewriter.setInsertionPointAfter(newOp);
202  SmallVector<Value> newResults;
203  for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
204  if (isa<TensorType>(it.value())) {
205  newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
206  executeRegionOp.getLoc(), it.value(),
207  newOp->getResult(it.index())));
208  } else {
209  newResults.push_back(newOp->getResult(it.index()));
210  }
211  }
212 
213  // Replace old op.
214  rewriter.replaceOp(executeRegionOp, newResults);
215 
216  return success();
217  }
218 };
219 
220 /// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
221 struct IfOpInterface
222  : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
224  getAliasingOpOperands(Operation *op, Value value,
225  const AnalysisState &state) const {
226  // IfOps do not have tensor OpOperands. The yielded value can be any SSA
227  // value that is in scope. To allow for use-def chain traversal through
228  // IfOps in the analysis, both corresponding yield values from the then/else
229  // branches are considered to be aliasing with the result.
230  auto ifOp = cast<scf::IfOp>(op);
231  size_t resultNum = std::distance(op->getOpResults().begin(),
232  llvm::find(op->getOpResults(), value));
233  OpOperand *thenOperand = &ifOp.thenYield()->getOpOperand(resultNum);
234  OpOperand *elseOperand = &ifOp.elseYield()->getOpOperand(resultNum);
235  return {{thenOperand, BufferRelation::Equivalent, /*isDefinite=*/false},
236  {elseOperand, BufferRelation::Equivalent, /*isDefinite=*/false}};
237  }
238 
239  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
240  const BufferizationOptions &options) const {
241  OpBuilder::InsertionGuard g(rewriter);
242  auto ifOp = cast<scf::IfOp>(op);
243 
244  // Compute bufferized result types.
245  SmallVector<Type> newTypes;
246  for (Value result : ifOp.getResults()) {
247  if (!isa<TensorType>(result.getType())) {
248  newTypes.push_back(result.getType());
249  continue;
250  }
251  auto bufferType = bufferization::getBufferType(result, options);
252  if (failed(bufferType))
253  return failure();
254  newTypes.push_back(*bufferType);
255  }
256 
257  // Create new op.
258  rewriter.setInsertionPoint(ifOp);
259  auto newIfOp =
260  rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
261  /*withElseRegion=*/true);
262 
263  // Move over then/else blocks.
264  rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
265  rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());
266 
267  // Replace op results.
268  replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());
269 
270  return success();
271  }
272 
273  FailureOr<BaseMemRefType>
275  SmallVector<Value> &invocationStack) const {
276  auto ifOp = cast<scf::IfOp>(op);
277  auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
278  auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());
279  assert(value.getDefiningOp() == op && "invalid valid");
280 
281  // Determine buffer types of the true/false branches.
282  auto opResult = cast<OpResult>(value);
283  auto thenValue = thenYieldOp.getOperand(opResult.getResultNumber());
284  auto elseValue = elseYieldOp.getOperand(opResult.getResultNumber());
285  BaseMemRefType thenBufferType, elseBufferType;
286  if (isa<BaseMemRefType>(thenValue.getType())) {
287  // True branch was already bufferized.
288  thenBufferType = cast<BaseMemRefType>(thenValue.getType());
289  } else {
290  auto maybeBufferType =
291  bufferization::getBufferType(thenValue, options, invocationStack);
292  if (failed(maybeBufferType))
293  return failure();
294  thenBufferType = *maybeBufferType;
295  }
296  if (isa<BaseMemRefType>(elseValue.getType())) {
297  // False branch was already bufferized.
298  elseBufferType = cast<BaseMemRefType>(elseValue.getType());
299  } else {
300  auto maybeBufferType =
301  bufferization::getBufferType(elseValue, options, invocationStack);
302  if (failed(maybeBufferType))
303  return failure();
304  elseBufferType = *maybeBufferType;
305  }
306 
307  // Best case: Both branches have the exact same buffer type.
308  if (thenBufferType == elseBufferType)
309  return thenBufferType;
310 
311  // Memory space mismatch.
312  if (thenBufferType.getMemorySpace() != elseBufferType.getMemorySpace())
313  return op->emitError("inconsistent memory space on then/else branches");
314 
315  // Layout maps are different: Promote to fully dynamic layout map.
317  cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace());
318  }
319 };
320 
321 /// Bufferization of scf.index_switch. Replace with a new scf.index_switch that
322 /// yields memrefs.
323 struct IndexSwitchOpInterface
324  : public BufferizableOpInterface::ExternalModel<IndexSwitchOpInterface,
325  scf::IndexSwitchOp> {
327  getAliasingOpOperands(Operation *op, Value value,
328  const AnalysisState &state) const {
329  // IndexSwitchOps do not have tensor OpOperands. The yielded value can be
330  // any SSA. This is similar to IfOps.
331  auto switchOp = cast<scf::IndexSwitchOp>(op);
332  int64_t resultNum = cast<OpResult>(value).getResultNumber();
333  AliasingOpOperandList result;
334  for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
335  auto yieldOp =
336  cast<scf::YieldOp>(switchOp.getCaseBlock(i).getTerminator());
337  result.addAlias(AliasingOpOperand(&yieldOp->getOpOperand(resultNum),
339  /*isDefinite=*/false));
340  }
341  auto defaultYieldOp =
342  cast<scf::YieldOp>(switchOp.getDefaultBlock().getTerminator());
343  result.addAlias(AliasingOpOperand(&defaultYieldOp->getOpOperand(resultNum),
345  /*isDefinite=*/false));
346  return result;
347  }
348 
349  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
350  const BufferizationOptions &options) const {
351  OpBuilder::InsertionGuard g(rewriter);
352  auto switchOp = cast<scf::IndexSwitchOp>(op);
353 
354  // Compute bufferized result types.
355  SmallVector<Type> newTypes;
356  for (Value result : switchOp.getResults()) {
357  if (!isa<TensorType>(result.getType())) {
358  newTypes.push_back(result.getType());
359  continue;
360  }
361  auto bufferType = bufferization::getBufferType(result, options);
362  if (failed(bufferType))
363  return failure();
364  newTypes.push_back(*bufferType);
365  }
366 
367  // Create new op.
368  rewriter.setInsertionPoint(switchOp);
369  auto newSwitchOp = rewriter.create<scf::IndexSwitchOp>(
370  switchOp.getLoc(), newTypes, switchOp.getArg(), switchOp.getCases(),
371  switchOp.getCases().size());
372 
373  // Move over blocks.
374  for (auto [src, dest] :
375  llvm::zip(switchOp.getCaseRegions(), newSwitchOp.getCaseRegions()))
376  rewriter.inlineRegionBefore(src, dest, dest.begin());
377  rewriter.inlineRegionBefore(switchOp.getDefaultRegion(),
378  newSwitchOp.getDefaultRegion(),
379  newSwitchOp.getDefaultRegion().begin());
380 
381  // Replace op results.
382  replaceOpWithBufferizedValues(rewriter, op, newSwitchOp->getResults());
383 
384  return success();
385  }
386 
387  FailureOr<BaseMemRefType>
389  SmallVector<Value> &invocationStack) const {
390  auto switchOp = cast<scf::IndexSwitchOp>(op);
391  assert(value.getDefiningOp() == op && "invalid value");
392  int64_t resultNum = cast<OpResult>(value).getResultNumber();
393 
394  // Helper function to get buffer type of a case.
395  SmallVector<BaseMemRefType> yieldedTypes;
396  auto getYieldedBufferType = [&](Block &b) -> FailureOr<BaseMemRefType> {
397  auto yieldOp = cast<scf::YieldOp>(b.getTerminator());
398  Value yieldedValue = yieldOp->getOperand(resultNum);
399  if (auto bufferType = dyn_cast<BaseMemRefType>(yieldedValue.getType()))
400  return bufferType;
401  auto maybeBufferType =
402  bufferization::getBufferType(yieldedValue, options, invocationStack);
403  if (failed(maybeBufferType))
404  return failure();
405  return maybeBufferType;
406  };
407 
408  // Compute buffer type of the default case.
409  auto maybeBufferType = getYieldedBufferType(switchOp.getDefaultBlock());
410  if (failed(maybeBufferType))
411  return failure();
412  BaseMemRefType bufferType = *maybeBufferType;
413 
414  // Compute buffer types of all other cases.
415  for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
416  auto yieldedBufferType = getYieldedBufferType(switchOp.getCaseBlock(i));
417  if (failed(yieldedBufferType))
418  return failure();
419 
420  // Best case: Both branches have the exact same buffer type.
421  if (bufferType == *yieldedBufferType)
422  continue;
423 
424  // Memory space mismatch.
425  if (bufferType.getMemorySpace() != yieldedBufferType->getMemorySpace())
426  return op->emitError("inconsistent memory space on switch cases");
427 
428  // Layout maps are different: Promote to fully dynamic layout map.
430  cast<TensorType>(value.getType()), bufferType.getMemorySpace());
431  }
432 
433  return bufferType;
434  }
435 };
436 
437 /// Helper function for loop bufferization. Return the indices of all values
438 /// that have a tensor type.
439 static DenseSet<int64_t> getTensorIndices(ValueRange values) {
440  DenseSet<int64_t> result;
441  for (const auto &it : llvm::enumerate(values))
442  if (isa<TensorType>(it.value().getType()))
443  result.insert(it.index());
444  return result;
445 }
446 
447 /// Helper function for loop bufferization. Return the indices of all
448 /// bbArg/yielded value pairs who's buffer relation is "Equivalent".
449 DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
450  ValueRange yieldedValues,
451  const AnalysisState &state) {
452  unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
453  DenseSet<int64_t> result;
454  for (unsigned int i = 0; i < minSize; ++i) {
455  if (!isa<TensorType>(bbArgs[i].getType()) ||
456  !isa<TensorType>(yieldedValues[i].getType()))
457  continue;
458  if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
459  result.insert(i);
460  }
461  return result;
462 }
463 
464 /// Helper function for loop bufferization. Return the bufferized values of the
465 /// given OpOperands. If an operand is not a tensor, return the original value.
466 static FailureOr<SmallVector<Value>>
467 getBuffers(RewriterBase &rewriter, const MutableOperandRange &operands,
468  const BufferizationOptions &options) {
469  SmallVector<Value> result;
470  for (OpOperand &opOperand : operands) {
471  if (isa<TensorType>(opOperand.get().getType())) {
472  FailureOr<Value> resultBuffer =
473  getBuffer(rewriter, opOperand.get(), options);
474  if (failed(resultBuffer))
475  return failure();
476  result.push_back(*resultBuffer);
477  } else {
478  result.push_back(opOperand.get());
479  }
480  }
481  return result;
482 }
483 
484 /// Helper function for loop bufferization. Given a list of bbArgs of the new
485 /// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into
486 /// ToTensorOps, so that the block body can be moved over to the new op.
487 static SmallVector<Value>
488 getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
489  Block::BlockArgListType oldBbArgs,
490  const DenseSet<int64_t> &tensorIndices) {
491  SmallVector<Value> result;
492  for (const auto &it : llvm::enumerate(bbArgs)) {
493  size_t idx = it.index();
494  Value val = it.value();
495  if (tensorIndices.contains(idx)) {
496  result.push_back(rewriter
497  .create<bufferization::ToTensorOp>(
498  val.getLoc(), oldBbArgs[idx].getType(), val)
499  .getResult());
500  } else {
501  result.push_back(val);
502  }
503  }
504  return result;
505 }
506 
507 /// Compute the bufferized type of a loop iter_arg. This type must be equal to
508 /// the bufferized type of the corresponding init_arg and the bufferized type
509 /// of the corresponding yielded value.
510 ///
511 /// This function uses bufferization::getBufferType to compute the bufferized
512 /// type of the init_arg and of the yielded value. (The computation of the
513 /// bufferized yielded value type usually requires computing the bufferized type
514 /// of the iter_arg again; the implementation of getBufferType traces back the
515 /// use-def chain of the given value and computes a buffer type along the way.)
516 /// If both buffer types are equal, no casts are needed the computed buffer type
517 /// can be used directly. Otherwise, the buffer types can only differ in their
518 /// layout map and a cast must be inserted.
519 static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
520  Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
521  const BufferizationOptions &options, SmallVector<Value> &invocationStack) {
522  // Determine the buffer type of the init_arg.
523  auto initArgBufferType =
524  bufferization::getBufferType(initArg, options, invocationStack);
525  if (failed(initArgBufferType))
526  return failure();
527 
528  if (llvm::count(invocationStack, iterArg) >= 2) {
529  // If the iter_arg is already twice on the invocation stack, just take the
530  // type of the init_arg. This is to avoid infinite loops when calculating
531  // the buffer type. This will most likely result in computing a memref type
532  // with a fully dynamic layout map.
533 
534  // Note: For more precise layout map computation, a fixpoint iteration could
535  // be done (i.e., re-computing the yielded buffer type until the bufferized
536  // iter_arg type no longer changes). This current implementation immediately
537  // switches to a fully dynamic layout map when a mismatch between bufferized
538  // init_arg type and bufferized yield value type is detected.
539  return *initArgBufferType;
540  }
541 
542  // Compute the buffer type of the yielded value.
543  BaseMemRefType yieldedValueBufferType;
544  if (isa<BaseMemRefType>(yieldedValue.getType())) {
545  // scf.yield was already bufferized.
546  yieldedValueBufferType = cast<BaseMemRefType>(yieldedValue.getType());
547  } else {
548  // Note: This typically triggers a recursive call for the buffer type of
549  // the iter_arg.
550  auto maybeBufferType =
551  bufferization::getBufferType(yieldedValue, options, invocationStack);
552  if (failed(maybeBufferType))
553  return failure();
554  yieldedValueBufferType = *maybeBufferType;
555  }
556 
557  // If yielded type and init_arg type are the same, use that type directly.
558  if (*initArgBufferType == yieldedValueBufferType)
559  return yieldedValueBufferType;
560 
561  // If there is a mismatch between the yielded buffer type and the init_arg
562  // buffer type, the buffer type must be promoted to a fully dynamic layout
563  // map.
564  auto yieldedBufferType = cast<BaseMemRefType>(yieldedValueBufferType);
565  auto iterTensorType = cast<TensorType>(iterArg.getType());
566  auto initBufferType = llvm::cast<BaseMemRefType>(*initArgBufferType);
567  if (initBufferType.getMemorySpace() != yieldedBufferType.getMemorySpace())
568  return loopOp->emitOpError(
569  "init_arg and yielded value bufferize to inconsistent memory spaces");
570 #ifndef NDEBUG
571  if (auto yieldedRankedBufferType = dyn_cast<MemRefType>(yieldedBufferType)) {
572  assert(
573  llvm::all_equal({yieldedRankedBufferType.getShape(),
574  cast<MemRefType>(initBufferType).getShape(),
575  cast<RankedTensorType>(iterTensorType).getShape()}) &&
576  "expected same shape");
577  }
578 #endif // NDEBUG
580  iterTensorType, yieldedBufferType.getMemorySpace());
581 }
582 
583 /// Return `true` if the given loop may have 0 iterations.
584 bool mayHaveZeroIterations(scf::ForOp forOp) {
585  std::optional<int64_t> lb = getConstantIntValue(forOp.getLowerBound());
586  std::optional<int64_t> ub = getConstantIntValue(forOp.getUpperBound());
587  if (!lb.has_value() || !ub.has_value())
588  return true;
589  return *ub <= *lb;
590 }
591 
592 /// Bufferization of scf.for. Replace with a new scf.for that operates on
593 /// memrefs.
594 struct ForOpInterface
595  : public BufferizableOpInterface::ExternalModel<ForOpInterface,
596  scf::ForOp> {
597  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
598  const AnalysisState &state) const {
599  auto forOp = cast<scf::ForOp>(op);
600 
601  // If the loop has zero iterations, the results of the op are their
602  // corresponding init_args, meaning that the init_args bufferize to a read.
603  if (mayHaveZeroIterations(forOp))
604  return true;
605 
606  // scf::ForOp alone doesn't bufferize to a memory read, one of the uses of
607  // its matching bbArg may.
608  return state.isValueRead(forOp.getTiedLoopRegionIterArg(&opOperand));
609  }
610 
611  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
612  const AnalysisState &state) const {
613  // Tensor iter_args of scf::ForOps are always considered as a write.
614  return true;
615  }
616 
617  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
618  const AnalysisState &state) const {
619  auto forOp = cast<scf::ForOp>(op);
620  OpResult opResult = forOp.getTiedLoopResult(&opOperand);
621  BufferRelation relation = bufferRelation(op, opResult, state);
622  return {{opResult, relation,
623  /*isDefinite=*/relation == BufferRelation::Equivalent}};
624  }
625 
626  BufferRelation bufferRelation(Operation *op, OpResult opResult,
627  const AnalysisState &state) const {
628  // ForOp results are equivalent to their corresponding init_args if the
629  // corresponding iter_args and yield values are equivalent.
630  auto forOp = cast<scf::ForOp>(op);
631  BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
632  bool equivalentYield = state.areEquivalentBufferizedValues(
633  bbArg, forOp.getTiedLoopYieldedValue(bbArg)->get());
634  return equivalentYield ? BufferRelation::Equivalent
636  }
637 
638  bool isWritable(Operation *op, Value value,
639  const AnalysisState &state) const {
640  // Interestingly, scf::ForOp's bbArg can **always** be viewed
641  // inplace from the perspective of ops nested under:
642  // 1. Either the matching iter operand is not bufferized inplace and an
643  // alloc + optional copy makes the bbArg itself inplaceable.
644  // 2. Or the matching iter operand is bufferized inplace and bbArg just
645  // bufferizes to that too.
646  return true;
647  }
648 
649  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
650  const AnalysisState &state) const {
651  auto bufferizableOp = cast<BufferizableOpInterface>(op);
652  if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
653  return failure();
654 
655  if (!state.getOptions().enforceAliasingInvariants ||
656  state.getOptions().copyBeforeWrite)
657  return success();
658 
659  // According to the `getAliasing...` implementations, a bufferized OpResult
660  // may alias only with the corresponding bufferized init_arg (or with a
661  // newly allocated buffer) and not with other buffers defined outside of the
662  // loop. I.e., the i-th OpResult may alias with the i-th init_arg;
663  // but not with any other OpOperand.
664  auto forOp = cast<scf::ForOp>(op);
665  auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
666  OpBuilder::InsertionGuard g(rewriter);
667  rewriter.setInsertionPoint(yieldOp);
668 
669  // Indices of all iter_args that have tensor type. These are the ones that
670  // are bufferized.
671  DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
672  // For every yielded value, does it alias with something defined outside of
673  // the loop?
674  SmallVector<Value> yieldValues;
675  for (const auto it : llvm::enumerate(yieldOp.getResults())) {
676  // Note: `state` is guaranteed to be a `OneShotAnalysisState`, but this
677  // type cannot be used in the signature of `resolveConflicts` because the
678  // op interface is in the "IR" build unit and the `OneShotAnalysisState`
679  // is defined in the "Transforms" build unit.
680  if (!indices.contains(it.index()) ||
681  doesNotAliasExternalValue(
682  it.value(), &forOp.getRegion(),
683  /*exceptions=*/forOp.getRegionIterArg(it.index()),
684  static_cast<const OneShotAnalysisState &>(state))) {
685  yieldValues.push_back(it.value());
686  continue;
687  }
688  FailureOr<Value> alloc = allocateTensorForShapedValue(
689  rewriter, yieldOp.getLoc(), it.value(), state.getOptions());
690  if (failed(alloc))
691  return failure();
692  yieldValues.push_back(*alloc);
693  }
694 
695  rewriter.modifyOpInPlace(
696  yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); });
697  return success();
698  }
699 
700  FailureOr<BaseMemRefType>
702  SmallVector<Value> &invocationStack) const {
703  auto forOp = cast<scf::ForOp>(op);
704  assert(getOwnerOfValue(value) == op && "invalid value");
705  assert(isa<TensorType>(value.getType()) && "expected tensor type");
706 
707  if (auto opResult = dyn_cast<OpResult>(value)) {
708  // The type of an OpResult must match the corresponding iter_arg type.
709  BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
710  return bufferization::getBufferType(bbArg, options, invocationStack);
711  }
712 
713  // Compute result/argument number.
714  BlockArgument bbArg = cast<BlockArgument>(value);
715  unsigned resultNum = forOp.getTiedLoopResult(bbArg).getResultNumber();
716 
717  // Compute the bufferized type.
718  auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
719  Value yieldedValue = yieldOp.getOperand(resultNum);
720  BlockArgument iterArg = forOp.getRegionIterArgs()[resultNum];
721  Value initArg = forOp.getInitArgs()[resultNum];
722  return computeLoopRegionIterArgBufferType(
723  op, iterArg, initArg, yieldedValue, options, invocationStack);
724  }
725 
726  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
727  const BufferizationOptions &options) const {
728  auto forOp = cast<scf::ForOp>(op);
729  Block *oldLoopBody = forOp.getBody();
730 
731  // Indices of all iter_args that have tensor type. These are the ones that
732  // are bufferized.
733  DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
734 
735  // The new memref init_args of the loop.
736  FailureOr<SmallVector<Value>> maybeInitArgs =
737  getBuffers(rewriter, forOp.getInitArgsMutable(), options);
738  if (failed(maybeInitArgs))
739  return failure();
740  SmallVector<Value> initArgs = *maybeInitArgs;
741 
742  // Cast init_args if necessary.
743  SmallVector<Value> castedInitArgs;
744  for (const auto &it : llvm::enumerate(initArgs)) {
745  Value initArg = it.value();
746  Value result = forOp->getResult(it.index());
747  // If the type is not a tensor, bufferization doesn't need to touch it.
748  if (!isa<TensorType>(result.getType())) {
749  castedInitArgs.push_back(initArg);
750  continue;
751  }
752  auto targetType = bufferization::getBufferType(result, options);
753  if (failed(targetType))
754  return failure();
755  castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
756  }
757 
758  // Construct a new scf.for op with memref instead of tensor values.
759  auto newForOp = rewriter.create<scf::ForOp>(
760  forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
761  forOp.getStep(), castedInitArgs);
762  newForOp->setAttrs(forOp->getAttrs());
763  Block *loopBody = newForOp.getBody();
764 
765  // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
766  // iter_args of the new loop in ToTensorOps.
767  rewriter.setInsertionPointToStart(loopBody);
768  SmallVector<Value> iterArgs =
769  getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(),
770  forOp.getRegionIterArgs(), indices);
771  iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());
772 
773  // Move loop body to new loop.
774  rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);
775 
776  // Replace loop results.
777  replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());
778 
779  return success();
780  }
781 
782  /// Assert that yielded values of an scf.for op are equivalent to their
783  /// corresponding bbArgs. In that case, the buffer relations of the
784  /// corresponding OpResults are "Equivalent".
785  ///
786  /// If this is not the case, an allocs+copies are inserted and yielded from
787  /// the loop. This could be a performance problem, so it must be explicitly
788  /// activated with `alloc-return-allocs`.
789  LogicalResult verifyAnalysis(Operation *op,
790  const AnalysisState &state) const {
791  const auto &options =
792  static_cast<const OneShotBufferizationOptions &>(state.getOptions());
793  if (options.allowReturnAllocsFromLoops)
794  return success();
795 
796  auto forOp = cast<scf::ForOp>(op);
797  auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
798  for (OpResult opResult : op->getOpResults()) {
799  if (!isa<TensorType>(opResult.getType()))
800  continue;
801 
802  // Note: This is overly strict. We should check for aliasing bufferized
803  // values. But we don't have a "must-alias" analysis yet.
804  if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
805  return yieldOp->emitError()
806  << "Yield operand #" << opResult.getResultNumber()
807  << " is not equivalent to the corresponding iter bbArg";
808  }
809 
810  return success();
811  }
812 };
813 
814 /// Bufferization of scf.while. Replace with a new scf.while that operates on
815 /// memrefs.
816 struct WhileOpInterface
817  : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
818  scf::WhileOp> {
819  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
820  const AnalysisState &state) const {
821  // Tensor iter_args of scf::WhileOps are always considered as a read.
822  return true;
823  }
824 
825  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
826  const AnalysisState &state) const {
827  // Tensor iter_args of scf::WhileOps are always considered as a write.
828  return true;
829  }
830 
831  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
832  const AnalysisState &state) const {
833  auto whileOp = cast<scf::WhileOp>(op);
834  unsigned int idx = opOperand.getOperandNumber();
835 
836  // The OpResults and OpOperands may not match. They may not even have the
837  // same type. The number of OpResults and OpOperands can also differ.
838  if (idx >= op->getNumResults() ||
839  opOperand.get().getType() != op->getResult(idx).getType())
840  return {};
841 
842  // The only aliasing OpResult may be the one at the same index.
843  OpResult opResult = whileOp->getResult(idx);
844  BufferRelation relation = bufferRelation(op, opResult, state);
845  return {{opResult, relation,
846  /*isDefinite=*/relation == BufferRelation::Equivalent}};
847  }
848 
849  BufferRelation bufferRelation(Operation *op, OpResult opResult,
850  const AnalysisState &state) const {
851  // WhileOp results are equivalent to their corresponding init_args if the
852  // corresponding iter_args and yield values are equivalent (for both the
853  // "before" and the "after" block).
854  unsigned int resultNumber = opResult.getResultNumber();
855  auto whileOp = cast<scf::WhileOp>(op);
856 
857  // The "before" region bbArgs and the OpResults may not match.
858  if (resultNumber >= whileOp.getBeforeArguments().size())
860  if (opResult.getType() !=
861  whileOp.getBeforeArguments()[resultNumber].getType())
863 
864  auto conditionOp = whileOp.getConditionOp();
865  BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
866  Value conditionOperand = conditionOp.getArgs()[resultNumber];
867  bool equivCondition =
868  state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);
869 
870  auto yieldOp = whileOp.getYieldOp();
871  BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
872  Value yieldOperand = yieldOp.getOperand(resultNumber);
873  bool equivYield =
874  state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);
875 
876  return equivCondition && equivYield ? BufferRelation::Equivalent
878  }
879 
880  bool isWritable(Operation *op, Value value,
881  const AnalysisState &state) const {
882  // Interestingly, scf::WhileOp's bbArg can **always** be viewed
883  // inplace from the perspective of ops nested under:
884  // 1. Either the matching iter operand is not bufferized inplace and an
885  // alloc + optional copy makes the bbArg itself inplaceable.
886  // 2. Or the matching iter operand is bufferized inplace and bbArg just
887  // bufferizes to that too.
888  return true;
889  }
890 
891  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
892  const AnalysisState &state) const {
893  auto bufferizableOp = cast<BufferizableOpInterface>(op);
894  if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
895  return failure();
896 
897  if (!state.getOptions().enforceAliasingInvariants ||
898  state.getOptions().copyBeforeWrite)
899  return success();
900 
901  // According to the `getAliasing...` implementations, a bufferized OpResult
902  // may alias only with the corresponding bufferized init_arg and with no
903  // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg;
904  // but not with any other OpOperand. If a corresponding OpResult/init_arg
905  // pair bufferizes to equivalent buffers, this aliasing requirement is
906  // satisfied. Otherwise, we cannot be sure and must yield a new buffer copy.
907  // (New buffer copies do not alias with any buffer.)
908  OpBuilder::InsertionGuard g(rewriter);
909  auto whileOp = cast<scf::WhileOp>(op);
910  auto conditionOp = whileOp.getConditionOp();
911 
912  // For every yielded value, is the value equivalent to its corresponding
913  // bbArg?
914  DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
915  whileOp.getBeforeArguments(), conditionOp.getArgs(), state);
916  DenseSet<int64_t> equivalentYieldsAfter = getEquivalentBuffers(
917  whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(), state);
918 
919  // Update "before" region.
920  rewriter.setInsertionPoint(conditionOp);
921  SmallVector<Value> beforeYieldValues;
922  for (int64_t idx = 0;
923  idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
924  Value value = conditionOp.getArgs()[idx];
925  if (!isa<TensorType>(value.getType()) ||
926  (equivalentYieldsAfter.contains(idx) &&
927  equivalentYieldsBefore.contains(idx))) {
928  beforeYieldValues.push_back(value);
929  continue;
930  }
931  FailureOr<Value> alloc = allocateTensorForShapedValue(
932  rewriter, conditionOp.getLoc(), value, state.getOptions());
933  if (failed(alloc))
934  return failure();
935  beforeYieldValues.push_back(*alloc);
936  }
937  rewriter.modifyOpInPlace(conditionOp, [&]() {
938  conditionOp.getArgsMutable().assign(beforeYieldValues);
939  });
940 
941  return success();
942  }
943 
944  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
945  const BufferizationOptions &options) const {
946  auto whileOp = cast<scf::WhileOp>(op);
947 
948  // Indices of all bbArgs that have tensor type. These are the ones that
949  // are bufferized. The "before" and "after" regions may have different args.
950  DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
951  DenseSet<int64_t> indicesAfter =
952  getTensorIndices(whileOp.getAfterArguments());
953 
954  // The new memref init_args of the loop.
955  FailureOr<SmallVector<Value>> maybeInitArgs =
956  getBuffers(rewriter, whileOp.getInitsMutable(), options);
957  if (failed(maybeInitArgs))
958  return failure();
959  SmallVector<Value> initArgs = *maybeInitArgs;
960 
961  // Cast init_args if necessary.
962  SmallVector<Value> castedInitArgs;
963  for (const auto &it : llvm::enumerate(initArgs)) {
964  Value initArg = it.value();
965  Value beforeArg = whileOp.getBeforeArguments()[it.index()];
966  // If the type is not a tensor, bufferization doesn't need to touch it.
967  if (!isa<TensorType>(beforeArg.getType())) {
968  castedInitArgs.push_back(initArg);
969  continue;
970  }
971  auto targetType = bufferization::getBufferType(beforeArg, options);
972  if (failed(targetType))
973  return failure();
974  castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
975  }
976 
977  // The result types of a WhileOp are the same as the "after" bbArg types.
978  SmallVector<Type> argsTypesAfter = llvm::to_vector(
979  llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
980  if (!isa<TensorType>(bbArg.getType()))
981  return bbArg.getType();
982  // TODO: error handling
983  return llvm::cast<Type>(
984  *bufferization::getBufferType(bbArg, options));
985  }));
986 
987  // Construct a new scf.while op with memref instead of tensor values.
988  ValueRange argsRangeBefore(castedInitArgs);
989  TypeRange argsTypesBefore(argsRangeBefore);
990  auto newWhileOp = rewriter.create<scf::WhileOp>(
991  whileOp.getLoc(), argsTypesAfter, castedInitArgs);
992 
993  // Add before/after regions to the new op.
994  SmallVector<Location> bbArgLocsBefore(castedInitArgs.size(),
995  whileOp.getLoc());
996  SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
997  whileOp.getLoc());
998  Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
999  newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore);
1000  Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
1001  newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter);
1002 
1003  // Set up new iter_args and move the loop condition block to the new op.
1004  // The old block uses tensors, so wrap the (memref) bbArgs of the new block
1005  // in ToTensorOps.
1006  rewriter.setInsertionPointToStart(newBeforeBody);
1007  SmallVector<Value> newBeforeArgs =
1008  getBbArgReplacements(rewriter, newWhileOp.getBeforeArguments(),
1009  whileOp.getBeforeArguments(), indicesBefore);
1010  rewriter.mergeBlocks(whileOp.getBeforeBody(), newBeforeBody, newBeforeArgs);
1011 
1012  // Set up new iter_args and move the loop body block to the new op.
1013  // The old block uses tensors, so wrap the (memref) bbArgs of the new block
1014  // in ToTensorOps.
1015  rewriter.setInsertionPointToStart(newAfterBody);
1016  SmallVector<Value> newAfterArgs =
1017  getBbArgReplacements(rewriter, newWhileOp.getAfterArguments(),
1018  whileOp.getAfterArguments(), indicesAfter);
1019  rewriter.mergeBlocks(whileOp.getAfterBody(), newAfterBody, newAfterArgs);
1020 
1021  // Replace loop results.
1022  replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());
1023 
1024  return success();
1025  }
1026 
1027  FailureOr<BaseMemRefType>
1029  SmallVector<Value> &invocationStack) const {
1030  auto whileOp = cast<scf::WhileOp>(op);
1031  assert(getOwnerOfValue(value) == op && "invalid value");
1032  assert(isa<TensorType>(value.getType()) && "expected tensor type");
1033 
1034  // Case 1: Block argument of the "before" region.
1035  if (auto bbArg = dyn_cast<BlockArgument>(value)) {
1036  if (bbArg.getOwner()->getParent() == &whileOp.getBefore()) {
1037  Value initArg = whileOp.getInits()[bbArg.getArgNumber()];
1038  auto yieldOp = whileOp.getYieldOp();
1039  Value yieldedValue = yieldOp.getOperand(bbArg.getArgNumber());
1040  return computeLoopRegionIterArgBufferType(
1041  op, bbArg, initArg, yieldedValue, options, invocationStack);
1042  }
1043  }
1044 
1045  // Case 2: OpResult of the loop or block argument of the "after" region.
1046  // The bufferized "after" bbArg type can be directly computed from the
1047  // bufferized "before" bbArg type.
1048  unsigned resultNum;
1049  if (auto opResult = dyn_cast<OpResult>(value)) {
1050  resultNum = opResult.getResultNumber();
1051  } else if (cast<BlockArgument>(value).getOwner()->getParent() ==
1052  &whileOp.getAfter()) {
1053  resultNum = cast<BlockArgument>(value).getArgNumber();
1054  } else {
1055  llvm_unreachable("invalid value");
1056  }
1057  Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
1058  if (!isa<TensorType>(conditionYieldedVal.getType())) {
1059  // scf.condition was already bufferized.
1060  return cast<BaseMemRefType>(conditionYieldedVal.getType());
1061  }
1062  return bufferization::getBufferType(conditionYieldedVal, options,
1063  invocationStack);
1064  }
1065 
1066  /// Assert that yielded values of an scf.while op are equivalent to their
1067  /// corresponding bbArgs. In that case, the buffer relations of the
1068  /// corresponding OpResults are "Equivalent".
1069  ///
1070  /// If this is not the case, allocs+copies are inserted and yielded from
1071  /// the loop. This could be a performance problem, so it must be explicitly
1072  /// activated with `allow-return-allocs`.
1073  ///
1074  /// Not: In contrast to scf::ForOp, scf::WhileOp has two regions and the
1075  /// equivalence condition must be checked for both.
1076  LogicalResult verifyAnalysis(Operation *op,
1077  const AnalysisState &state) const {
1078  auto whileOp = cast<scf::WhileOp>(op);
1079  const auto &options =
1080  static_cast<const OneShotBufferizationOptions &>(state.getOptions());
1081  if (options.allowReturnAllocsFromLoops)
1082  return success();
1083 
1084  auto conditionOp = whileOp.getConditionOp();
1085  for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
1086  Block *block = conditionOp->getBlock();
1087  if (!isa<TensorType>(it.value().getType()))
1088  continue;
1089  if (it.index() >= block->getNumArguments() ||
1090  !state.areEquivalentBufferizedValues(it.value(),
1091  block->getArgument(it.index())))
1092  return conditionOp->emitError()
1093  << "Condition arg #" << it.index()
1094  << " is not equivalent to the corresponding iter bbArg";
1095  }
1096 
1097  auto yieldOp = whileOp.getYieldOp();
1098  for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
1099  Block *block = yieldOp->getBlock();
1100  if (!isa<TensorType>(it.value().getType()))
1101  continue;
1102  if (it.index() >= block->getNumArguments() ||
1103  !state.areEquivalentBufferizedValues(it.value(),
1104  block->getArgument(it.index())))
1105  return yieldOp->emitError()
1106  << "Yield operand #" << it.index()
1107  << " is not equivalent to the corresponding iter bbArg";
1108  }
1109 
1110  return success();
1111  }
1112 };
1113 
1114 /// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so
1115 /// this is for analysis only.
1116 struct YieldOpInterface
1117  : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
1118  scf::YieldOp> {
1119  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
1120  const AnalysisState &state) const {
1121  return true;
1122  }
1123 
1124  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
1125  const AnalysisState &state) const {
1126  return false;
1127  }
1128 
1129  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
1130  const AnalysisState &state) const {
1131  if (auto ifOp = dyn_cast<scf::IfOp>(op->getParentOp())) {
1132  return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
1133  BufferRelation::Equivalent, /*isDefinite=*/false}};
1134  }
1135  if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
1136  return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
1138  return {};
1139  }
1140 
1141  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
1142  const AnalysisState &state) const {
1143  // Yield operands always bufferize inplace. Otherwise, an alloc + copy
1144  // may be generated inside the block. We should not return/yield allocations
1145  // when possible.
1146  return true;
1147  }
1148 
1149  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
1150  const BufferizationOptions &options) const {
1151  auto yieldOp = cast<scf::YieldOp>(op);
1152  if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::IndexSwitchOp, scf::ForOp,
1153  scf::WhileOp>(yieldOp->getParentOp()))
1154  return yieldOp->emitError("unsupported scf::YieldOp parent");
1155 
1156  SmallVector<Value> newResults;
1157  for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
1158  Value value = it.value();
1159  if (isa<TensorType>(value.getType())) {
1160  FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
1161  if (failed(maybeBuffer))
1162  return failure();
1163  Value buffer = *maybeBuffer;
1164  // We may have to cast the value before yielding it.
1165  if (isa<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>(
1166  yieldOp->getParentOp())) {
1167  FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
1168  yieldOp->getParentOp()->getResult(it.index()), options);
1169  if (failed(resultType))
1170  return failure();
1171  buffer = castBuffer(rewriter, buffer, *resultType);
1172  } else if (auto whileOp =
1173  dyn_cast<scf::WhileOp>(yieldOp->getParentOp())) {
1174  FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
1175  whileOp.getBeforeArguments()[it.index()], options);
1176  if (failed(resultType))
1177  return failure();
1178  buffer = castBuffer(rewriter, buffer, *resultType);
1179  }
1180  newResults.push_back(buffer);
1181  } else {
1182  newResults.push_back(value);
1183  }
1184  }
1185 
1186  replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op, newResults);
1187  return success();
1188  }
1189 };
1190 
1191 /// Return `true` if the given loop may have 0 iterations.
1192 bool mayHaveZeroIterations(scf::ForallOp forallOp) {
1193  for (auto [lb, ub] : llvm::zip(forallOp.getMixedLowerBound(),
1194  forallOp.getMixedUpperBound())) {
1195  std::optional<int64_t> lbConst = getConstantIntValue(lb);
1196  std::optional<int64_t> ubConst = getConstantIntValue(ub);
1197  if (!lbConst.has_value() || !ubConst.has_value() || *lbConst >= *ubConst)
1198  return true;
1199  }
1200  return false;
1201 }
1202 
1203 /// Bufferization of ForallOp. This also bufferizes the terminator of the
1204 /// region. There are op interfaces for the terminators (InParallelOp
1205 /// and ParallelInsertSliceOp), but these are only used during analysis. Not
1206 /// for bufferization.
1207 struct ForallOpInterface
1208  : public BufferizableOpInterface::ExternalModel<ForallOpInterface,
1209  ForallOp> {
1210  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
1211  const AnalysisState &state) const {
1212  auto forallOp = cast<ForallOp>(op);
1213 
1214  // If the loop has zero iterations, the results of the op are their
1215  // corresponding shared_outs, meaning that the shared_outs bufferize to a
1216  // read.
1217  if (mayHaveZeroIterations(forallOp))
1218  return true;
1219 
1220  // scf::ForallOp alone doesn't bufferize to a memory read, one of the
1221  // uses of its matching bbArg may.
1222  return state.isValueRead(forallOp.getTiedBlockArgument(&opOperand));
1223  }
1224 
1225  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
1226  const AnalysisState &state) const {
1227  // Outputs of scf::ForallOps are always considered as a write.
1228  return true;
1229  }
1230 
1231  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
1232  const AnalysisState &state) const {
1233  auto forallOp = cast<ForallOp>(op);
1234  return {
1235  {{forallOp.getTiedOpResult(&opOperand), BufferRelation::Equivalent}}};
1236  }
1237 
1238  bool isWritable(Operation *op, Value value,
1239  const AnalysisState &state) const {
1240  return true;
1241  }
1242 
1243  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
1244  const BufferizationOptions &options) const {
1245  OpBuilder::InsertionGuard guard(rewriter);
1246  auto forallOp = cast<ForallOp>(op);
1247  int64_t rank = forallOp.getRank();
1248 
1249  // Get buffers for all output operands.
1250  SmallVector<Value> buffers;
1251  for (Value out : forallOp.getOutputs()) {
1252  FailureOr<Value> buffer = getBuffer(rewriter, out, options);
1253  if (failed(buffer))
1254  return failure();
1255  buffers.push_back(*buffer);
1256  }
1257 
1258  // Use buffers instead of block arguments.
1259  rewriter.setInsertionPointToStart(forallOp.getBody());
1260  for (const auto &it : llvm::zip(
1261  forallOp.getBody()->getArguments().drop_front(rank), buffers)) {
1262  BlockArgument bbArg = std::get<0>(it);
1263  Value buffer = std::get<1>(it);
1264  Value bufferAsTensor = rewriter.create<ToTensorOp>(
1265  forallOp.getLoc(), bbArg.getType(), buffer);
1266  bbArg.replaceAllUsesWith(bufferAsTensor);
1267  }
1268 
1269  // Create new ForallOp without any results and drop the automatically
1270  // introduced terminator.
1271  rewriter.setInsertionPoint(forallOp);
1272  ForallOp newForallOp;
1273  newForallOp = rewriter.create<ForallOp>(
1274  forallOp.getLoc(), forallOp.getMixedLowerBound(),
1275  forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
1276  /*outputs=*/ValueRange(), forallOp.getMapping());
1277 
1278  // Keep discardable attributes from the original op.
1280 
1281  rewriter.eraseOp(newForallOp.getBody()->getTerminator());
1282 
1283  // Move over block contents of the old op.
1284  SmallVector<Value> replacementBbArgs;
1285  replacementBbArgs.append(newForallOp.getBody()->getArguments().begin(),
1286  newForallOp.getBody()->getArguments().end());
1287  replacementBbArgs.append(forallOp.getOutputs().size(), Value());
1288  rewriter.mergeBlocks(forallOp.getBody(), newForallOp.getBody(),
1289  replacementBbArgs);
1290 
1291  // Remove the old op and replace all of its uses.
1292  replaceOpWithBufferizedValues(rewriter, op, buffers);
1293 
1294  return success();
1295  }
1296 
1297  FailureOr<BaseMemRefType>
1299  SmallVector<Value> &invocationStack) const {
1300  auto forallOp = cast<ForallOp>(op);
1301 
1302  if (auto bbArg = dyn_cast<BlockArgument>(value))
1303  // A tensor block argument has the same bufferized type as the
1304  // corresponding output operand.
1306  forallOp.getTiedOpOperand(bbArg)->get(), options, invocationStack);
1307 
1308  // The bufferized result type is the same as the bufferized type of the
1309  // corresponding output operand.
1311  forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options,
1312  invocationStack);
1313  }
1314 
1315  bool isRepetitiveRegion(Operation *op, unsigned index) const {
1316  auto forallOp = cast<ForallOp>(op);
1317 
1318  // This op is repetitive if it has 1 or more steps.
1319  // If the control variables are dynamic, it is also considered so.
1320  for (auto [lb, ub, step] :
1321  llvm::zip(forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
1322  forallOp.getMixedStep())) {
1323  std::optional<int64_t> lbConstant = getConstantIntValue(lb);
1324  if (!lbConstant)
1325  return true;
1326 
1327  std::optional<int64_t> ubConstant = getConstantIntValue(ub);
1328  if (!ubConstant)
1329  return true;
1330 
1331  std::optional<int64_t> stepConstant = getConstantIntValue(step);
1332  if (!stepConstant)
1333  return true;
1334 
1335  if (*lbConstant + *stepConstant < *ubConstant)
1336  return true;
1337  }
1338  return false;
1339  }
1340 
1341  bool isParallelRegion(Operation *op, unsigned index) const {
1342  return isRepetitiveRegion(op, index);
1343  }
1344 };
1345 
1346 /// Nothing to do for InParallelOp.
1347 struct InParallelOpInterface
1348  : public BufferizableOpInterface::ExternalModel<InParallelOpInterface,
1349  InParallelOp> {
1350  LogicalResult bufferize(Operation *op, RewriterBase &b,
1351  const BufferizationOptions &options) const {
1352  llvm_unreachable("op does not have any tensor OpOperands / OpResults");
1353  return failure();
1354  }
1355 };
1356 
1357 } // namespace
1358 } // namespace scf
1359 } // namespace mlir
1360 
1362  DialectRegistry &registry) {
1363  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
1364  ConditionOp::attachInterface<ConditionOpInterface>(*ctx);
1365  ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
1366  ForOp::attachInterface<ForOpInterface>(*ctx);
1367  IfOp::attachInterface<IfOpInterface>(*ctx);
1368  IndexSwitchOp::attachInterface<IndexSwitchOpInterface>(*ctx);
1369  ForallOp::attachInterface<ForallOpInterface>(*ctx);
1370  InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
1371  WhileOp::attachInterface<WhileOpInterface>(*ctx);
1372  YieldOp::attachInterface<YieldOpInterface>(*ctx);
1373  });
1374 }
static bool isRepetitiveRegion(Region *region, const BufferizationOptions &options)
static llvm::ManagedStatic< PassManagerOptions > options
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
Base class for generic analysis states.
This class provides a shared interface for ranked and unranked memref types.
Definition: BuiltinTypes.h:149
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
This class represents an argument of a Block.
Definition: Value.h:319
Block * getOwner() const
Returns the block that owns this argument.
Definition: Value.h:328
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:331
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition: Block.cpp:162
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:29
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:115
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:357
This class helps build Operations.
Definition: Builders.h:216
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:440
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:407
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Definition: Builders.h:421
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:457
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:469
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:305
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Definition: Operation.cpp:268
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:687
DictionaryAttr getDiscardableAttrDictionary()
Return all of the discardable attributes on this operation as a DictionaryAttr.
Definition: Operation.h:501
result_range getOpResults()
Definition: Operation.h:420
void setDiscardableAttrs(DictionaryAttr newAttrs)
Set the discardable attribute dictionary on this operation.
Definition: Operation.h:523
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
Definition: Region.h:222
bool isProperAncestor(Region *other)
Return true if this region is a proper ancestor of the other region.
Definition: Region.cpp:50
BlockListType & getBlocks()
Definition: Region.h:45
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
Definition: Region.h:241
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
Definition: PatternMatch.h:400
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:636
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to use the other value instead.
Definition: Value.h:173
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition: Value.cpp:41
State for analysis-enabled bufferization.
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, ValueRange values)
Replace an op with replacement values.
Operation * getOwnerOfValue(Value value)
Return the owner of the given value.
LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter, const BufferizationOptions &options)
Bufferize the signature of block and its callers (i.e., ops that have the given block as a successor).
Definition: Bufferize.cpp:423
FailureOr< Value > allocateTensorForShapedValue(OpBuilder &b, Location loc, Value shapedValue, const BufferizationOptions &options, bool copy=true)
Create an AllocTensorOp for the given shaped value (memref or tensor).
FailureOr< BaseMemRefType > getBufferType(Value value, const BufferizationOptions &options)
Return the buffer type for a given Value (tensor) after bufferization without bufferizing any IR.
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options)
Lookup the buffer for the given value.
BaseMemRefType getMemRefTypeWithFullyDynamicLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with fully dynamic layout.
BufferRelation
Specifies a fine-grain relationship between buffers to enable more analysis.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
Options for BufferizableOpInterface-based bufferization.
Options for analysis-enabled bufferization.
A template that provides a default implementation of getAliasingOpOperands for ops that support unstructured control flow within their regions.