//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::scf;

namespace mlir {
namespace scf {
namespace {

/// Helper function for loop bufferization. Cast the given buffer to the given
/// memref type.
static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
  assert(isa<BaseMemRefType>(type) && "expected BaseMemRefType");
  assert(isa<BaseMemRefType>(buffer.getType()) && "expected BaseMemRefType");
  // If the buffer already has the correct type, no cast is needed.
  if (buffer.getType() == type)
    return buffer;
  // TODO: In case `type` has a layout map that is not the fully dynamic
  // one, we may not be able to cast the buffer. In that case, the loop
  // iter_arg's layout map must be changed (see uses of `castBuffer`).
  assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
         "scf.while op bufferization: cast incompatible");
  return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
}
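
// Illustrative sketch (not from the original source): a typical cast inserted
// by `castBuffer` generalizes the layout of a buffer, e.g.:
//
//   %cast = memref.cast %buf
//       : memref<5xf32> to memref<5xf32, strided<[?], offset: ?>>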

/// Bufferization of scf.condition.
struct ConditionOpInterface
    : public BufferizableOpInterface::ExternalModel<ConditionOpInterface,
                                                    scf::ConditionOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Condition operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield allocations
    // when possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto conditionOp = cast<scf::ConditionOp>(op);
    auto whileOp = cast<scf::WhileOp>(conditionOp->getParentOp());

    SmallVector<Value> newArgs;
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      Value value = it.value();
      if (isa<TensorType>(value.getType())) {
        FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
        if (failed(maybeBuffer))
          return failure();
        FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
            whileOp.getAfterArguments()[it.index()], options);
        if (failed(resultType))
          return failure();
        Value buffer = castBuffer(rewriter, *maybeBuffer, *resultType);
        newArgs.push_back(buffer);
      } else {
        newArgs.push_back(value);
      }
    }

    replaceOpWithNewBufferizedOp<scf::ConditionOp>(
        rewriter, op, conditionOp.getCondition(), newArgs);
    return success();
  }
};
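
// Illustrative sketch (assumed IR, not from the original source): a tensor
// carried by scf.condition is replaced by its buffer, cast to the bufferized
// type of the matching "after" block argument, e.g.:
//
//   scf.condition(%cond) %t : tensor<5xf32>
//
// becomes, with %m the buffer of %t:
//
//   scf.condition(%cond) %m : memref<5xf32, strided<[?], offset: ?>>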

/// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
/// fully implemented at the moment.
struct ExecuteRegionOpInterface
    : public BufferizableOpInterface::ExternalModel<ExecuteRegionOpInterface,
                                                    scf::ExecuteRegionOp> {
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, OpResult opResult,
                        const AnalysisState &state) const {
    // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. To allow for use-def chain traversal
    // through ExecuteRegionOps in the analysis, the corresponding yield value
    // is considered to be aliasing with the result.
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    // TODO: Support multiple blocks.
    assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
           "expected exactly 1 block");
    auto yieldOp = dyn_cast<scf::YieldOp>(
        executeRegionOp.getRegion().front().getTerminator());
    assert(yieldOp && "expected scf.yield terminator in scf.execute_region");
    return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
           "only 1 block supported");
    auto yieldOp =
        cast<scf::YieldOp>(executeRegionOp.getRegion().front().getTerminator());
    TypeRange newResultTypes(yieldOp.getResults());

    // Create new op and move over region.
    auto newOp =
        rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
    newOp.getRegion().takeBody(executeRegionOp.getRegion());

    // Update all uses of the old op.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
      if (isa<TensorType>(it.value())) {
        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
            executeRegionOp.getLoc(), newOp->getResult(it.index())));
      } else {
        newResults.push_back(newOp->getResult(it.index()));
      }
    }

    // Replace old op.
    rewriter.replaceOp(executeRegionOp, newResults);

    return success();
  }
};
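
// Illustrative sketch (assumed IR, not from the original source): the rewrite
// above turns
//
//   %r = scf.execute_region -> tensor<5xf32> {
//     ...
//     scf.yield %t : tensor<5xf32>
//   }
//
// into an scf.execute_region yielding a memref, with the old tensor result
// reconstructed via bufferization.to_tensor:
//
//   %m = scf.execute_region -> memref<5xf32> {
//     ...
//     scf.yield %buf : memref<5xf32>
//   }
//   %r = bufferization.to_tensor %m : memref<5xf32>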

/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
struct IfOpInterface
    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, OpResult opResult,
                        const AnalysisState &state) const {
    // IfOps do not have tensor OpOperands. The yielded value can be any SSA
    // value that is in scope. To allow for use-def chain traversal through
    // IfOps in the analysis, both corresponding yield values from the then/else
    // branches are considered to be aliasing with the result.
    auto ifOp = cast<scf::IfOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    OpOperand *thenOperand = &ifOp.thenYield()->getOpOperand(resultNum);
    OpOperand *elseOperand = &ifOp.elseYield()->getOpOperand(resultNum);
    return {{thenOperand, BufferRelation::Equivalent, /*isDefinite=*/false},
            {elseOperand, BufferRelation::Equivalent, /*isDefinite=*/false}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto ifOp = cast<scf::IfOp>(op);

    // Compute bufferized result types.
    SmallVector<Type> newTypes;
    for (Value result : ifOp.getResults()) {
      if (!isa<TensorType>(result.getType())) {
        newTypes.push_back(result.getType());
        continue;
      }
      auto bufferType = bufferization::getBufferType(result, options);
      if (failed(bufferType))
        return failure();
      newTypes.push_back(*bufferType);
    }

    // Create new op.
    rewriter.setInsertionPoint(ifOp);
    auto newIfOp =
        rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
                                   /*withElseRegion=*/true);

    // Move over then/else blocks.
    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());

    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
    auto ifOp = cast<scf::IfOp>(op);
    auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
    auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());
    assert(value.getDefiningOp() == op && "invalid value");

    // Determine buffer types of the true/false branches.
    auto opResult = cast<OpResult>(value);
    auto thenValue = thenYieldOp.getOperand(opResult.getResultNumber());
    auto elseValue = elseYieldOp.getOperand(opResult.getResultNumber());
    BaseMemRefType thenBufferType, elseBufferType;
    if (isa<BaseMemRefType>(thenValue.getType())) {
      // True branch was already bufferized.
      thenBufferType = cast<BaseMemRefType>(thenValue.getType());
    } else {
      auto maybeBufferType =
          bufferization::getBufferType(thenValue, options, fixedTypes);
      if (failed(maybeBufferType))
        return failure();
      thenBufferType = *maybeBufferType;
    }
    if (isa<BaseMemRefType>(elseValue.getType())) {
      // False branch was already bufferized.
      elseBufferType = cast<BaseMemRefType>(elseValue.getType());
    } else {
      auto maybeBufferType =
          bufferization::getBufferType(elseValue, options, fixedTypes);
      if (failed(maybeBufferType))
        return failure();
      elseBufferType = *maybeBufferType;
    }

    // Best case: Both branches have the exact same buffer type.
    if (thenBufferType == elseBufferType)
      return thenBufferType;

    // Memory space mismatch.
    if (thenBufferType.getMemorySpace() != elseBufferType.getMemorySpace())
      return op->emitError("inconsistent memory space on then/else branches");

    // Layout maps are different: Promote to fully dynamic layout map.
    return getMemRefTypeWithFullyDynamicLayout(
        cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace());
  }
};
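
// Illustrative sketch (assumed types, not from the original source): if the
// two branches bufferize to different layouts, e.g.
//
//   then: memref<5xf32>
//   else: memref<5xf32, strided<[?], offset: ?>>
//
// the result type is promoted to the fully dynamic layout
// memref<5xf32, strided<[?], offset: ?>>, and the scf.yield bufferization
// later inserts the matching memref.cast in the "then" branch.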

/// Helper function for loop bufferization. Return the indices of all values
/// that have a tensor type.
static DenseSet<int64_t> getTensorIndices(ValueRange values) {
  DenseSet<int64_t> result;
  for (const auto &it : llvm::enumerate(values))
    if (isa<TensorType>(it.value().getType()))
      result.insert(it.index());
  return result;
}

/// Helper function for loop bufferization. Return the indices of all
/// bbArg/yielded value pairs whose buffer relation is "Equivalent".
DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
                                       ValueRange yieldedValues,
                                       const AnalysisState &state) {
  unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
  DenseSet<int64_t> result;
  for (unsigned int i = 0; i < minSize; ++i) {
    if (!isa<TensorType>(bbArgs[i].getType()) ||
        !isa<TensorType>(yieldedValues[i].getType()))
      continue;
    if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
      result.insert(i);
  }
  return result;
}

/// Helper function for loop bufferization. Return the bufferized values of the
/// given OpOperands. If an operand is not a tensor, return the original value.
static FailureOr<SmallVector<Value>>
getBuffers(RewriterBase &rewriter, MutableArrayRef<OpOperand> operands,
           const BufferizationOptions &options) {
  SmallVector<Value> result;
  for (OpOperand &opOperand : operands) {
    if (isa<TensorType>(opOperand.get().getType())) {
      FailureOr<Value> resultBuffer =
          getBuffer(rewriter, opOperand.get(), options);
      if (failed(resultBuffer))
        return failure();
      result.push_back(*resultBuffer);
    } else {
      result.push_back(opOperand.get());
    }
  }
  return result;
}

/// Helper function for loop bufferization. Given a list of bbArgs of the new
/// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into
/// ToTensorOps, so that the block body can be moved over to the new op.
static SmallVector<Value>
getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
                     const DenseSet<int64_t> &tensorIndices) {
  SmallVector<Value> result;
  for (const auto &it : llvm::enumerate(bbArgs)) {
    size_t idx = it.index();
    Value val = it.value();
    if (tensorIndices.contains(idx)) {
      result.push_back(
          rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val)
              .getResult());
    } else {
      result.push_back(val);
    }
  }
  return result;
}

/// Compute the bufferized type of a loop iter_arg. This type must be equal to
/// the bufferized type of the corresponding init_arg and the bufferized type
/// of the corresponding yielded value.
///
/// This function uses bufferization::getBufferType to compute the bufferized
/// type of the init_arg and of the yielded value. (The computation of the
/// latter usually requires computing the bufferized type of the corresponding
/// iter_arg; the implementation of getBufferType traces back the use-def chain
/// of the given value and computes a buffer type along the way.) If both buffer
/// types are equal, no casts are needed and the computed buffer type can be
/// used directly. Otherwise, the buffer types can only differ in their layout
/// map and a cast must be inserted.
static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
    BlockArgument iterArg, Value initArg, Value yieldedValue,
    const BufferizationOptions &options,
    const DenseMap<Value, BaseMemRefType> &fixedTypes) {
  // Determine the buffer type of the init_arg.
  auto initArgBufferType =
      bufferization::getBufferType(initArg, options, fixedTypes);
  if (failed(initArgBufferType))
    return failure();

  // Fix the iter_arg type, so that recursive lookups return the buffer type
  // of the init_arg. This is to avoid infinite loops when calculating the
  // buffer type of the yielded value.
  //
  // Note: For more precise layout map computation, a fixpoint iteration could
  // be done (i.e., re-computing the yielded buffer type until the bufferized
  // iter_arg type no longer changes). This current implementation immediately
  // switches to a fully dynamic layout map when a mismatch between bufferized
  // init_arg type and bufferized yield value type is detected.
  DenseMap<Value, BaseMemRefType> newFixedTypes(fixedTypes);
  newFixedTypes[iterArg] = *initArgBufferType;

  // Compute the buffer type of the yielded value.
  BaseMemRefType yieldedValueBufferType;
  if (isa<BaseMemRefType>(yieldedValue.getType())) {
    // scf.yield was already bufferized.
    yieldedValueBufferType = cast<BaseMemRefType>(yieldedValue.getType());
  } else {
    auto maybeBufferType =
        bufferization::getBufferType(yieldedValue, options, newFixedTypes);
    if (failed(maybeBufferType))
      return failure();
    yieldedValueBufferType = *maybeBufferType;
  }

  // If yielded type and init_arg type are the same, use that type directly.
  if (*initArgBufferType == yieldedValueBufferType)
    return yieldedValueBufferType;

  // If there is a mismatch between the yielded buffer type and the iter_arg
  // buffer type, the buffer type must be promoted to a fully dynamic layout
  // map.
  auto yieldedRanked = cast<MemRefType>(yieldedValueBufferType);
#ifndef NDEBUG
  auto iterRanked = llvm::cast<MemRefType>(*initArgBufferType);
  assert(llvm::equal(yieldedRanked.getShape(), iterRanked.getShape()) &&
         "expected same shape");
  assert(yieldedRanked.getMemorySpace() == iterRanked.getMemorySpace() &&
         "expected same memory space");
#endif // NDEBUG
  return getMemRefTypeWithFullyDynamicLayout(
      cast<RankedTensorType>(iterArg.getType()),
      yieldedRanked.getMemorySpace());
}
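
// Illustrative sketch (assumed types, not from the original source): given an
// init_arg that bufferizes to memref<5xf32> and a loop body that yields, say,
// a subview of type memref<5xf32, strided<[?], offset: ?>>, the two types
// disagree only in their layout, so the iter_arg is promoted to the fully
// dynamic layout memref<5xf32, strided<[?], offset: ?>>.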

/// Return `true` if the given loop may have 0 iterations.
bool mayHaveZeroIterations(scf::ForOp forOp) {
  std::optional<int64_t> lb = getConstantIntValue(forOp.getLowerBound());
  std::optional<int64_t> ub = getConstantIntValue(forOp.getUpperBound());
  if (!lb.has_value() || !ub.has_value())
    return true;
  return *ub <= *lb;
}

/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.
struct ForOpInterface
    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                    scf::ForOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);

    // If the loop has zero iterations, the results of the op are their
    // corresponding init_args, meaning that the init_args bufferize to a read.
    if (mayHaveZeroIterations(forOp))
      return true;

    // scf::ForOp alone doesn't bufferize to a memory read, but one of the uses
    // of its matching bbArg may.
    return state.isValueRead(forOp.getRegionIterArgForOpOperand(opOperand));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::ForOps are always considered as a write.
    return true;
  }

  AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    OpResult opResult = forOp.getResultForOpOperand(opOperand);
    BufferRelation relation = bufferRelation(op, opResult, state);
    return {{opResult, relation,
             /*isDefinite=*/relation == BufferRelation::Equivalent}};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // ForOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent.
    auto forOp = cast<scf::ForOp>(op);
    OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
    auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    bool equivalentYield = state.areEquivalentBufferizedValues(
        bbArg, yieldOp->getOperand(opResult.getResultNumber()));
    return equivalentYield ? BufferRelation::Equivalent
                           : BufferRelation::Unknown;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::ForOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    // 1. Either the matching iter operand is not bufferized inplace and an
    //    alloc + optional copy makes the bbArg itself inplaceable.
    // 2. Or the matching iter operand is bufferized inplace and bbArg just
    //    bufferizes to that too.
    return true;
  }

  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();

    if (!state.getOptions().enforceAliasingInvariants)
      return success();

    // According to the `getAliasing...` implementations, a bufferized OpResult
    // may alias only with the corresponding bufferized init_arg and with no
    // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg;
    // but not with any other OpOperand. If a corresponding OpResult/init_arg
    // pair bufferizes to equivalent buffers, this aliasing requirement is
    // satisfied. Otherwise, we cannot be sure and must yield a new buffer copy.
    // (New buffer copies do not alias with any buffer.)
    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(yieldOp);

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYields = getEquivalentBuffers(
        forOp.getRegionIterArgs(), yieldOp.getResults(), state);
    SmallVector<Value> yieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(yieldOp.getResults().size()); ++idx) {
      Value value = yieldOp.getResults()[idx];
      if (!indices.contains(idx) || equivalentYields.contains(idx)) {
        yieldValues.push_back(value);
        continue;
      }
      FailureOr<Value> alloc =
          allocateTensorForShapedValue(rewriter, yieldOp.getLoc(), value,
                                       /*escape=*/true, state.getOptions());
      if (failed(alloc))
        return failure();
      yieldValues.push_back(*alloc);
    }

    rewriter.updateRootInPlace(
        yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); });
    return success();
  }
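
  // Illustrative sketch (assumed IR, not from the original source): a yielded
  // tensor that is not equivalent to its iter_arg is replaced by a copy, so
  // that the yielded buffer cannot alias anything else:
  //
  //   scf.yield %other : tensor<5xf32>
  //
  // becomes
  //
  //   %copy = bufferization.alloc_tensor() copy(%other) : tensor<5xf32>
  //   scf.yield %copy : tensor<5xf32>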

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
    auto forOp = cast<scf::ForOp>(op);
    assert(getOwnerOfValue(value) == op && "invalid value");
    assert(isa<TensorType>(value.getType()) && "expected tensor type");

    // Get result/argument number.
    unsigned resultNum;
    if (auto bbArg = dyn_cast<BlockArgument>(value)) {
      resultNum =
          forOp.getResultForOpOperand(forOp.getOpOperandForRegionIterArg(bbArg))
              .getResultNumber();
    } else {
      resultNum = cast<OpResult>(value).getResultNumber();
    }

    // Compute the bufferized type.
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    Value yieldedValue = yieldOp.getOperand(resultNum);
    BlockArgument iterArg = forOp.getRegionIterArgs()[resultNum];
    Value initArg = forOp.getInitArgs()[resultNum];
    return computeLoopRegionIterArgBufferType(iterArg, initArg, yieldedValue,
                                              options, fixedTypes);
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto forOp = cast<scf::ForOp>(op);
    Block *oldLoopBody = &forOp.getLoopBody().front();

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, forOp.getIterOpOperands(), options);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // Cast init_args if necessary.
    SmallVector<Value> castedInitArgs;
    for (const auto &it : llvm::enumerate(initArgs)) {
      Value initArg = it.value();
      Value result = forOp->getResult(it.index());
      // If the type is not a tensor, bufferization doesn't need to touch it.
      if (!isa<TensorType>(result.getType())) {
        castedInitArgs.push_back(initArg);
        continue;
      }
      auto targetType = bufferization::getBufferType(result, options);
      if (failed(targetType))
        return failure();
      castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
    }

    // Construct a new scf.for op with memref instead of tensor values.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), castedInitArgs);
    newForOp->setAttrs(forOp->getAttrs());
    Block *loopBody = &newForOp.getLoopBody().front();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
    SmallVector<Value> iterArgs =
        getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(), indices);
    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());

    // Move loop body to new loop.
    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.for op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs+copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocs)
      return success();

    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp =
        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
    for (OpResult opResult : op->getOpResults()) {
      if (!isa<TensorType>(opResult.getType()))
        continue;

      // Note: This is overly strict. We should check for aliasing bufferized
      // values. But we don't have a "must-alias" analysis yet.
      if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
        return yieldOp->emitError()
               << "Yield operand #" << opResult.getResultNumber()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};
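
// Illustrative sketch (assumed IR, not from the original source): overall,
// ForOpInterface rewrites a tensor loop such as
//
//   %r = scf.for %iv = %lb to %ub step %c1 iter_args(%t = %init)
//       -> (tensor<5xf32>) {
//     ...
//     scf.yield %t2 : tensor<5xf32>
//   }
//
// into a memref loop whose bbArg is rewrapped for the (not yet bufferized)
// body; after the nested scf.yield is bufferized as well, the result is:
//
//   %m = scf.for %iv = %lb to %ub step %c1 iter_args(%b = %init_buffer)
//       -> (memref<5xf32>) {
//     %t = bufferization.to_tensor %b : memref<5xf32>
//     ...
//     scf.yield %b2 : memref<5xf32>
//   }
//   %r = bufferization.to_tensor %m : memref<5xf32>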

/// Bufferization of scf.while. Replace with a new scf.while that operates on
/// memrefs.
struct WhileOpInterface
    : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
                                                    scf::WhileOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a read.
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a write.
    return true;
  }

  AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    unsigned int idx = opOperand.getOperandNumber();

    // The OpResults and OpOperands may not match. They may not even have the
    // same type. The number of OpResults and OpOperands can also differ.
    if (idx >= op->getNumResults() ||
        opOperand.get().getType() != op->getResult(idx).getType())
      return {};

    // The only aliasing OpResult may be the one at the same index.
    OpResult opResult = whileOp->getResult(idx);
    BufferRelation relation = bufferRelation(op, opResult, state);
    return {{opResult, relation,
             /*isDefinite=*/relation == BufferRelation::Equivalent}};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // WhileOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent (for both the
    // "before" and the "after" block).
    unsigned int resultNumber = opResult.getResultNumber();
    auto whileOp = cast<scf::WhileOp>(op);

    // The "before" region bbArgs and the OpResults may not match.
    if (resultNumber >= whileOp.getBeforeArguments().size())
      return BufferRelation::Unknown;
    if (opResult.getType() !=
        whileOp.getBeforeArguments()[resultNumber].getType())
      return BufferRelation::Unknown;

    auto conditionOp = whileOp.getConditionOp();
    BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
    Value conditionOperand = conditionOp.getArgs()[resultNumber];
    bool equivCondition =
        state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);

    auto yieldOp = whileOp.getYieldOp();
    BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
    Value yieldOperand = yieldOp.getOperand(resultNumber);
    bool equivYield =
        state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);

    return equivCondition && equivYield ? BufferRelation::Equivalent
                                        : BufferRelation::Unknown;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::WhileOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    // 1. Either the matching iter operand is not bufferized inplace and an
    //    alloc + optional copy makes the bbArg itself inplaceable.
    // 2. Or the matching iter operand is bufferized inplace and bbArg just
    //    bufferizes to that too.
    return true;
  }

  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();

    if (!state.getOptions().enforceAliasingInvariants)
      return success();

    // According to the `getAliasing...` implementations, a bufferized OpResult
    // may alias only with the corresponding bufferized init_arg and with no
    // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg;
    // but not with any other OpOperand. If a corresponding OpResult/init_arg
    // pair bufferizes to equivalent buffers, this aliasing requirement is
    // satisfied. Otherwise, we cannot be sure and must yield a new buffer copy.
    // (New buffer copies do not alias with any buffer.)
    OpBuilder::InsertionGuard g(rewriter);
    auto whileOp = cast<scf::WhileOp>(op);
    auto conditionOp = whileOp.getConditionOp();

    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
        whileOp.getBeforeArguments(), conditionOp.getArgs(), state);
    DenseSet<int64_t> equivalentYieldsAfter = getEquivalentBuffers(
        whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(), state);

    // Update "before" region.
    rewriter.setInsertionPoint(conditionOp);
    SmallVector<Value> beforeYieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
      Value value = conditionOp.getArgs()[idx];
      if (!isa<TensorType>(value.getType()) ||
          (equivalentYieldsAfter.contains(idx) &&
           equivalentYieldsBefore.contains(idx))) {
        beforeYieldValues.push_back(value);
        continue;
      }
      FailureOr<Value> alloc =
          allocateTensorForShapedValue(rewriter, conditionOp.getLoc(), value,
                                       /*escape=*/true, state.getOptions());
      if (failed(alloc))
        return failure();
      beforeYieldValues.push_back(*alloc);
    }
    rewriter.updateRootInPlace(conditionOp, [&]() {
      conditionOp.getArgsMutable().assign(beforeYieldValues);
    });

    return success();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto whileOp = cast<scf::WhileOp>(op);

    assert(whileOp.getBefore().getBlocks().size() == 1 &&
           "regions with multiple blocks not supported");
    Block *beforeBody = &whileOp.getBefore().front();
    assert(whileOp.getAfter().getBlocks().size() == 1 &&
           "regions with multiple blocks not supported");
    Block *afterBody = &whileOp.getAfter().front();

    // Indices of all bbArgs that have tensor type. These are the ones that
    // are bufferized. The "before" and "after" regions may have different args.
    DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
    DenseSet<int64_t> indicesAfter =
        getTensorIndices(whileOp.getAfterArguments());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, whileOp->getOpOperands(), options);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // Cast init_args if necessary.
    SmallVector<Value> castedInitArgs;
    for (const auto &it : llvm::enumerate(initArgs)) {
      Value initArg = it.value();
      Value beforeArg = whileOp.getBeforeArguments()[it.index()];
      // If the type is not a tensor, bufferization doesn't need to touch it.
      if (!isa<TensorType>(beforeArg.getType())) {
        castedInitArgs.push_back(initArg);
        continue;
      }
      auto targetType = bufferization::getBufferType(beforeArg, options);
      if (failed(targetType))
        return failure();
      castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
    }

    // The result types of a WhileOp are the same as the "after" bbArg types.
    SmallVector<Type> argsTypesAfter = llvm::to_vector(
        llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
          if (!isa<TensorType>(bbArg.getType()))
            return bbArg.getType();
          // TODO: error handling
          return llvm::cast<Type>(
              *bufferization::getBufferType(bbArg, options));
        }));

    // Construct a new scf.while op with memref instead of tensor values.
    ValueRange argsRangeBefore(castedInitArgs);
    TypeRange argsTypesBefore(argsRangeBefore);
    auto newWhileOp = rewriter.create<scf::WhileOp>(
        whileOp.getLoc(), argsTypesAfter, castedInitArgs);

    // Add before/after regions to the new op.
    SmallVector<Location> bbArgLocsBefore(castedInitArgs.size(),
                                          whileOp.getLoc());
    SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
                                         whileOp.getLoc());
    Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
    newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore);
    Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
    newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter);

    // Set up new iter_args and move the loop condition block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newBeforeBody);
    SmallVector<Value> newBeforeArgs = getBbArgReplacements(
        rewriter, newWhileOp.getBeforeArguments(), indicesBefore);
    rewriter.mergeBlocks(beforeBody, newBeforeBody, newBeforeArgs);

    // Set up new iter_args and move the loop body block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newAfterBody);
    SmallVector<Value> newAfterArgs = getBbArgReplacements(
        rewriter, newWhileOp.getAfterArguments(), indicesAfter);
    rewriter.mergeBlocks(afterBody, newAfterBody, newAfterArgs);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());

    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
    auto whileOp = cast<scf::WhileOp>(op);
    assert(getOwnerOfValue(value) == op && "invalid value");
    assert(isa<TensorType>(value.getType()) && "expected tensor type");

    // Case 1: Block argument of the "before" region.
    if (auto bbArg = dyn_cast<BlockArgument>(value)) {
      if (bbArg.getOwner()->getParent() == &whileOp.getBefore()) {
        Value initArg = whileOp.getInits()[bbArg.getArgNumber()];
        auto yieldOp = whileOp.getYieldOp();
        Value yieldedValue = yieldOp.getOperand(bbArg.getArgNumber());
        return computeLoopRegionIterArgBufferType(bbArg, initArg, yieldedValue,
                                                  options, fixedTypes);
      }
    }

    // Case 2: OpResult of the loop or block argument of the "after" region.
    // The bufferized "after" bbArg type can be directly computed from the
    // bufferized "before" bbArg type.
    unsigned resultNum;
    if (auto opResult = dyn_cast<OpResult>(value)) {
      resultNum = opResult.getResultNumber();
    } else if (cast<BlockArgument>(value).getOwner()->getParent() ==
               &whileOp.getAfter()) {
      resultNum = cast<BlockArgument>(value).getArgNumber();
    } else {
      llvm_unreachable("invalid value");
    }
    Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
    if (!isa<TensorType>(conditionYieldedVal.getType())) {
      // scf.condition was already bufferized.
      return cast<BaseMemRefType>(conditionYieldedVal.getType());
    }
    return bufferization::getBufferType(conditionYieldedVal, options,
                                        fixedTypes);
  }

  /// Assert that yielded values of an scf.while op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs+copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  ///
  /// Note: In contrast to scf::ForOp, scf::WhileOp has two regions and the
  /// equivalence condition must be checked for both.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocs)
      return success();

    auto conditionOp = whileOp.getConditionOp();
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      Block *block = conditionOp->getBlock();
      if (!isa<TensorType>(it.value().getType()))
        continue;
      if (it.index() >= block->getNumArguments() ||
          !state.areEquivalentBufferizedValues(it.value(),
                                               block->getArgument(it.index())))
        return conditionOp->emitError()
               << "Condition arg #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    auto yieldOp = whileOp.getYieldOp();
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Block *block = yieldOp->getBlock();
      if (!isa<TensorType>(it.value().getType()))
        continue;
      if (it.index() >= block->getNumArguments() ||
          !state.areEquivalentBufferizedValues(it.value(),
                                               block->getArgument(it.index())))
        return yieldOp->emitError()
               << "Yield operand #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};
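
// Illustrative sketch (assumed IR, not from the original source): after full
// bufferization, a tensor-carrying loop such as
//
//   %r = scf.while (%t = %init) : (tensor<5xf32>) -> (tensor<5xf32>) {
//     %cond = ...
//     scf.condition(%cond) %t : tensor<5xf32>
//   } do {
//   ^bb0(%arg : tensor<5xf32>):
//     ...
//     scf.yield %next : tensor<5xf32>
//   }
//
// carries memrefs instead, with scf.condition and scf.yield operands cast to
// the bufferized types of the "after" and "before" block arguments,
// respectively.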

/// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so
/// this is for analysis only.
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    scf::YieldOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    if (auto ifOp = dyn_cast<scf::IfOp>(op->getParentOp())) {
      return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
               BufferRelation::Equivalent, /*isDefinite=*/false}};
    }
    if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
      return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
               BufferRelation::Equivalent}};
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield allocations
    // when possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto yieldOp = cast<scf::YieldOp>(op);
    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp, scf::WhileOp>(
            yieldOp->getParentOp()))
      return yieldOp->emitError("unsupported scf::YieldOp parent");

    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Value value = it.value();
      if (isa<TensorType>(value.getType())) {
        FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
        if (failed(maybeBuffer))
          return failure();
        Value buffer = *maybeBuffer;
        // We may have to cast the value before yielding it.
        if (isa<scf::ForOp, scf::IfOp>(yieldOp->getParentOp())) {
          FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
              yieldOp->getParentOp()->getResult(it.index()), options);
          if (failed(resultType))
            return failure();
          buffer = castBuffer(rewriter, buffer, *resultType);
        } else if (auto whileOp =
                       dyn_cast<scf::WhileOp>(yieldOp->getParentOp())) {
          FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
              whileOp.getBeforeArguments()[it.index()], options);
          if (failed(resultType))
            return failure();
          buffer = castBuffer(rewriter, buffer, *resultType);
        }
        newResults.push_back(buffer);
      } else {
        newResults.push_back(value);
      }
    }

    replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op, newResults);
    return success();
  }
};

/// Return `true` if the given loop may have 0 iterations.
bool mayHaveZeroIterations(scf::ForallOp forallOp) {
  for (auto [lb, ub] : llvm::zip(forallOp.getMixedLowerBound(),
                                 forallOp.getMixedUpperBound())) {
    std::optional<int64_t> lbConst = getConstantIntValue(lb);
    std::optional<int64_t> ubConst = getConstantIntValue(ub);
    if (!lbConst.has_value() || !ubConst.has_value() || *lbConst >= *ubConst)
      return true;
  }
  return false;
}

/// Bufferization of ForallOp. This also bufferizes the terminator of the
/// region. There are op interfaces for the terminators (InParallelOp
/// and ParallelInsertSliceOp), but these are only used during analysis, not
/// for bufferization.
struct ForallOpInterface
    : public BufferizableOpInterface::ExternalModel<ForallOpInterface,
                                                    ForallOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    auto forallOp = cast<ForallOp>(op);

    // If the loop has zero iterations, the results of the op are their
    // corresponding shared_outs, meaning that the shared_outs bufferize to a
    // read.
    if (mayHaveZeroIterations(forallOp))
      return true;

    // scf::ForallOp alone doesn't bufferize to a memory read, but one of the
    // uses of its matching bbArg may.
    return state.isValueRead(forallOp.getTiedBlockArgument(&opOperand));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Outputs of scf::ForallOps are always considered as a write.
    return true;
  }

  AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    auto forallOp = cast<ForallOp>(op);
    return {
        {{forallOp.getTiedOpResult(&opOperand), BufferRelation::Equivalent}}};
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard guard(rewriter);
    auto forallOp = cast<ForallOp>(op);
    int64_t rank = forallOp.getRank();

    // Get buffers for all output operands.
    SmallVector<Value> buffers;
    for (Value out : forallOp.getOutputs()) {
      FailureOr<Value> buffer = getBuffer(rewriter, out, options);
      if (failed(buffer))
        return failure();
      buffers.push_back(*buffer);
    }

    // Use buffers instead of block arguments.
    rewriter.setInsertionPointToStart(forallOp.getBody());
    for (const auto &it : llvm::zip(
             forallOp.getBody()->getArguments().drop_front(rank), buffers)) {
      BlockArgument bbArg = std::get<0>(it);
      Value buffer = std::get<1>(it);
      Value bufferAsTensor =
          rewriter.create<ToTensorOp>(forallOp.getLoc(), buffer);
      bbArg.replaceAllUsesWith(bufferAsTensor);
    }

    // Create new ForallOp without any results and drop the automatically
    // introduced terminator.
    rewriter.setInsertionPoint(forallOp);
    ForallOp newForallOp;
    newForallOp = rewriter.create<ForallOp>(
        forallOp.getLoc(), forallOp.getMixedLowerBound(),
        forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
        /*outputs=*/ValueRange(), forallOp.getMapping());

    rewriter.eraseOp(newForallOp.getBody()->getTerminator());

    // Move over block contents of the old op.
    SmallVector<Value> replacementBbArgs;
    replacementBbArgs.append(newForallOp.getBody()->getArguments().begin(),
                             newForallOp.getBody()->getArguments().end());
    replacementBbArgs.append(forallOp.getOutputs().size(), Value());
    rewriter.mergeBlocks(forallOp.getBody(), newForallOp.getBody(),
                         replacementBbArgs);

    // Remove the old op and replace all of its uses.
    replaceOpWithBufferizedValues(rewriter, op, buffers);

    return success();
  }

  FailureOr<BaseMemRefType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
    auto forallOp = cast<ForallOp>(op);

    if (auto bbArg = dyn_cast<BlockArgument>(value))
      // A tensor block argument has the same bufferized type as the
      // corresponding output operand.
      return bufferization::getBufferType(
          forallOp.getTiedOpOperand(bbArg)->get(), options, fixedTypes);

    // The bufferized result type is the same as the bufferized type of the
    // corresponding output operand.
    return bufferization::getBufferType(
        forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options,
        fixedTypes);
  }

  bool isRepetitiveRegion(Operation *op, unsigned index) const {
    auto forallOp = cast<ForallOp>(op);

    // This op is repetitive if it may execute more than one iteration in any
    // dimension. If any bound or step is dynamic, it is conservatively
    // considered repetitive as well.
    for (auto [lb, ub, step] :
         llvm::zip(forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
                   forallOp.getMixedStep())) {
      std::optional<int64_t> lbConstant = getConstantIntValue(lb);
      if (!lbConstant)
        return true;

      std::optional<int64_t> ubConstant = getConstantIntValue(ub);
      if (!ubConstant)
        return true;

      std::optional<int64_t> stepConstant = getConstantIntValue(step);
      if (!stepConstant)
        return true;

      if (*lbConstant + *stepConstant < *ubConstant)
        return true;
    }
    return false;
  }
};
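
// Illustrative sketch (assumed IR, not from the original source): a forall op
// with shared_outs such as
//
//   %r = scf.forall (%i) in (%c8) shared_outs(%o = %dest) -> (tensor<8xf32>) {
//     ...
//     scf.forall.in_parallel {
//       tensor.parallel_insert_slice %slice into %o[%i] [1] [1]
//           : tensor<1xf32> into tensor<8xf32>
//     }
//   }
//
// becomes a result-less scf.forall whose body writes directly into the buffer
// of %dest; uses of %r are replaced with that buffer.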

/// Nothing to do for InParallelOp.
struct InParallelOpInterface
    : public BufferizableOpInterface::ExternalModel<InParallelOpInterface,
                                                    InParallelOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &b,
                          const BufferizationOptions &options) const {
    llvm_unreachable("op does not have any tensor OpOperands / OpResults");
    return failure();
  }
};

} // namespace
} // namespace scf
} // namespace mlir

void mlir::scf::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
    ConditionOp::attachInterface<ConditionOpInterface>(*ctx);
    ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
    ForOp::attachInterface<ForOpInterface>(*ctx);
    IfOp::attachInterface<IfOpInterface>(*ctx);
    ForallOp::attachInterface<ForallOpInterface>(*ctx);
    InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
    WhileOp::attachInterface<WhileOpInterface>(*ctx);
    YieldOp::attachInterface<YieldOpInterface>(*ctx);
  });
}
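
// Typical usage (a sketch; the exact registration point depends on the
// embedding application):
//
//   DialectRegistry registry;
//   registry.insert<scf::SCFDialect>();
//   scf::registerBufferizableOpInterfaceExternalModels(registry);
//   MLIRContext context(registry);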