//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SmallVectorExtras.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::scf;

namespace mlir {
namespace scf {
namespace {

/// Helper function for loop bufferization. Cast the given buffer to the given
/// memref type.
static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
  assert(isa<BaseMemRefType>(type) && "expected BaseMemRefType");
  assert(isa<BaseMemRefType>(buffer.getType()) && "expected BaseMemRefType");
  // If the buffer already has the correct type, no cast is needed.
  if (buffer.getType() == type)
    return buffer;
  // TODO: In case `type` has a layout map that is not the fully dynamic
  // one, we may not be able to cast the buffer. In that case, the loop
  // iter_arg's layout map must be changed (see uses of `castBuffer`).
  assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
         "scf.while op bufferization: cast incompatible");
  return memref::CastOp::create(b, buffer.getLoc(), type, buffer).getResult();
}
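
// Illustrative example (not part of the original source): `castBuffer` is
// typically needed when a loop iter_arg bufferizes to a more general layout
// than the buffer at hand, e.g.:
//
//   %cast = memref.cast %buf
//       : memref<5xf32> to memref<5xf32, strided<[?], offset: ?>>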

/// Helper function for loop bufferization. Return "true" if the given value
/// is guaranteed to not alias with an external tensor apart from values in
/// `exceptions`. A value is external if it is defined outside of the given
/// region or if it is an entry block argument of the region.
static bool doesNotAliasExternalValue(Value value, Region *region,
                                      ValueRange exceptions,
                                      const OneShotAnalysisState &state) {
  assert(region->hasOneBlock() && "expected region with single block");
  bool result = true;
  state.applyOnAliases(value, [&](Value alias) {
    if (llvm::is_contained(exceptions, alias))
      return;
    Region *aliasRegion = alias.getParentRegion();
    if (isa<BlockArgument>(alias) && !region->isProperAncestor(aliasRegion))
      result = false;
    if (isa<OpResult>(alias) && !region->isAncestor(aliasRegion))
      result = false;
  });
  return result;
}
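
// Illustrative example (not part of the original source): in the loop below,
// the yielded value aliases %ext, which is defined outside of the loop
// region, so `doesNotAliasExternalValue` returns "false" for it (unless %ext
// is passed in `exceptions`):
//
//   %ext = tensor.empty(%sz) : tensor<?xf32>
//   %r = scf.for %i = %lb to %ub step %s iter_args(%arg = %init)
//       -> (tensor<?xf32>) {
//     scf.yield %ext : tensor<?xf32>
//   }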

/// Bufferization of scf.condition.
struct ConditionOpInterface
    : public BufferizableOpInterface::ExternalModel<ConditionOpInterface,
                                                    scf::ConditionOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Condition operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield
    // allocations when possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto conditionOp = cast<scf::ConditionOp>(op);
    auto whileOp = cast<scf::WhileOp>(conditionOp->getParentOp());

    SmallVector<Value> newArgs;
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      Value value = it.value();
      if (isa<TensorType>(value.getType())) {
        FailureOr<Value> maybeBuffer =
            getBuffer(rewriter, value, options, state);
        if (failed(maybeBuffer))
          return failure();
        FailureOr<BufferLikeType> resultType = bufferization::getBufferType(
            whileOp.getAfterArguments()[it.index()], options, state);
        if (failed(resultType))
          return failure();
        Value buffer = castBuffer(rewriter, *maybeBuffer, *resultType);
        newArgs.push_back(buffer);
      } else {
        newArgs.push_back(value);
      }
    }

    replaceOpWithNewBufferizedOp<scf::ConditionOp>(
        rewriter, op, conditionOp.getCondition(), newArgs);
    return success();
  }
};

/// Return the unique scf.yield op. If there are multiple or no scf.yield
/// ops, return a null op.
static scf::YieldOp getUniqueYieldOp(scf::ExecuteRegionOp executeRegionOp) {
  scf::YieldOp result;
  for (Block &block : executeRegionOp.getRegion()) {
    if (auto yieldOp = dyn_cast<scf::YieldOp>(block.getTerminator())) {
      if (result)
        return {};
      result = yieldOp;
    }
  }
  return result;
}

/// Bufferization of scf.execute_region. Can be analyzed, but bufferization is
/// not fully implemented at the moment.
struct ExecuteRegionOpInterface
    : public OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel<
          ExecuteRegionOpInterface, scf::ExecuteRegionOp> {

  static bool supportsUnstructuredControlFlow() { return true; }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    return true;
  }

  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    // TODO: scf.execute_region ops with multiple yields are not supported.
    if (!getUniqueYieldOp(executeRegionOp))
      return op->emitOpError("op without unique scf.yield is not supported");
    return success();
  }

  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    if (auto bbArg = dyn_cast<BlockArgument>(value))
      return getAliasingBranchOpOperands(op, bbArg, state);

    // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. To allow for use-def chain traversal
    // through ExecuteRegionOps in the analysis, the corresponding yield value
    // is considered to be aliasing with the result.
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    auto it = llvm::find(op->getOpResults(), value);
    assert(it != op->getOpResults().end() && "invalid value");
    size_t resultNum = std::distance(op->getOpResults().begin(), it);
    auto yieldOp = getUniqueYieldOp(executeRegionOp);
    // Note: If there is no unique scf.yield op, `verifyAnalysis` will fail.
    if (!yieldOp)
      return {};
    return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    auto yieldOp = getUniqueYieldOp(executeRegionOp);
    TypeRange newResultTypes(yieldOp.getResults());

    // Create new op and move over region.
    auto newOp = scf::ExecuteRegionOp::create(
        rewriter, op->getLoc(), newResultTypes, executeRegionOp.getNoInline());
    newOp.getRegion().takeBody(executeRegionOp.getRegion());

    // Bufferize every block.
    for (Block &block : newOp.getRegion())
      if (failed(bufferizeBlockSignature(&block, rewriter, options, state)))
        return failure();

    // Update all uses of the old op.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
      if (isa<TensorType>(it.value())) {
        newResults.push_back(bufferization::ToTensorOp::create(
            rewriter, executeRegionOp.getLoc(), it.value(),
            newOp->getResult(it.index())));
      } else {
        newResults.push_back(newOp->getResult(it.index()));
      }
    }

    // Replace old op.
    rewriter.replaceOp(executeRegionOp, newResults);

    return success();
  }
};
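
// Illustrative example (not part of the original source): an execute_region
// op such as
//
//   %0 = scf.execute_region -> tensor<5xf32> {
//     ...
//     scf.yield %t : tensor<5xf32>
//   }
//
// is rewritten to yield a memref; the original tensor result is then
// recovered with a to_tensor op:
//
//   %m = scf.execute_region -> memref<5xf32, strided<[?], offset: ?>> {
//     ...
//     scf.yield %buf : memref<5xf32, strided<[?], offset: ?>>
//   }
//   %0 = bufferization.to_tensor %m
//       : memref<5xf32, strided<[?], offset: ?>> to tensor<5xf32>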

/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
struct IfOpInterface
    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    // IfOps do not have tensor OpOperands. The yielded value can be any SSA
    // value that is in scope. To allow for use-def chain traversal through
    // IfOps in the analysis, both corresponding yield values from the
    // then/else branches are considered to be aliasing with the result.
    auto ifOp = cast<scf::IfOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), value));
    OpOperand *thenOperand = &ifOp.thenYield()->getOpOperand(resultNum);
    OpOperand *elseOperand = &ifOp.elseYield()->getOpOperand(resultNum);
    return {{thenOperand, BufferRelation::Equivalent, /*isDefinite=*/false},
            {elseOperand, BufferRelation::Equivalent, /*isDefinite=*/false}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto ifOp = cast<scf::IfOp>(op);

    // Compute bufferized result types.
    SmallVector<Type> newTypes;
    for (Value result : ifOp.getResults()) {
      if (!isa<TensorType>(result.getType())) {
        newTypes.push_back(result.getType());
        continue;
      }
      auto bufferType = bufferization::getBufferType(result, options, state);
      if (failed(bufferType))
        return failure();
      newTypes.push_back(*bufferType);
    }

    // Create new op.
    rewriter.setInsertionPoint(ifOp);
    auto newIfOp = scf::IfOp::create(rewriter, ifOp.getLoc(), newTypes,
                                     ifOp.getCondition(),
                                     /*withElseRegion=*/true);

    // Move over then/else blocks.
    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());

    return success();
  }

  FailureOr<BufferLikeType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto ifOp = cast<scf::IfOp>(op);
    auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
    auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());
    assert(value.getDefiningOp() == op && "invalid value");

    // Determine buffer types of the true/false branches.
    auto opResult = cast<OpResult>(value);
    auto thenValue = thenYieldOp.getOperand(opResult.getResultNumber());
    auto elseValue = elseYieldOp.getOperand(opResult.getResultNumber());
    BaseMemRefType thenBufferType, elseBufferType;
    if (isa<BaseMemRefType>(thenValue.getType())) {
      // True branch was already bufferized.
      thenBufferType = cast<BaseMemRefType>(thenValue.getType());
    } else {
      auto maybeBufferType =
          bufferization::detail::asMemRefType(bufferization::getBufferType(
              thenValue, options, state, invocationStack));
      if (failed(maybeBufferType))
        return failure();
      thenBufferType = *maybeBufferType;
    }
    if (isa<BaseMemRefType>(elseValue.getType())) {
      // False branch was already bufferized.
      elseBufferType = cast<BaseMemRefType>(elseValue.getType());
    } else {
      auto maybeBufferType =
          bufferization::detail::asMemRefType(bufferization::getBufferType(
              elseValue, options, state, invocationStack));
      if (failed(maybeBufferType))
        return failure();
      elseBufferType = *maybeBufferType;
    }

    // Best case: Both branches have the exact same buffer type.
    if (thenBufferType == elseBufferType)
      return cast<BufferLikeType>(thenBufferType);

    // Memory space mismatch.
    if (thenBufferType.getMemorySpace() != elseBufferType.getMemorySpace())
      return op->emitError("inconsistent memory space on then/else branches");

    // Layout maps are different: Promote to fully dynamic layout map.
    return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
        cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace()));
  }
};
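
// Illustrative example (not part of the original source): if the two
// branches bufferize to memrefs that differ only in their layout map, e.g.
// memref<5xf32> and memref<5xf32, strided<[1], offset: 2>>, getBufferType
// promotes the result to the fully dynamic layout
// memref<5xf32, strided<[?], offset: ?>>, and the scf.yield bufferization
// inserts the matching memref.cast in each branch.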

/// Bufferization of scf.index_switch. Replace with a new scf.index_switch
/// that yields memrefs.
struct IndexSwitchOpInterface
    : public BufferizableOpInterface::ExternalModel<IndexSwitchOpInterface,
                                                    scf::IndexSwitchOp> {
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    // IndexSwitchOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. This is similar to IfOps.
    auto switchOp = cast<scf::IndexSwitchOp>(op);
    int64_t resultNum = cast<OpResult>(value).getResultNumber();
    AliasingOpOperandList result;
    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
      auto yieldOp =
          cast<scf::YieldOp>(switchOp.getCaseBlock(i).getTerminator());
      result.addAlias(AliasingOpOperand(&yieldOp->getOpOperand(resultNum),
                                        BufferRelation::Equivalent,
                                        /*isDefinite=*/false));
    }
    auto defaultYieldOp =
        cast<scf::YieldOp>(switchOp.getDefaultBlock().getTerminator());
    result.addAlias(AliasingOpOperand(&defaultYieldOp->getOpOperand(resultNum),
                                      BufferRelation::Equivalent,
                                      /*isDefinite=*/false));
    return result;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto switchOp = cast<scf::IndexSwitchOp>(op);

    // Compute bufferized result types.
    SmallVector<Type> newTypes;
    for (Value result : switchOp.getResults()) {
      if (!isa<TensorType>(result.getType())) {
        newTypes.push_back(result.getType());
        continue;
      }
      auto bufferType = bufferization::getBufferType(result, options, state);
      if (failed(bufferType))
        return failure();
      newTypes.push_back(*bufferType);
    }

    // Create new op.
    rewriter.setInsertionPoint(switchOp);
    auto newSwitchOp = scf::IndexSwitchOp::create(
        rewriter, switchOp.getLoc(), newTypes, switchOp.getArg(),
        switchOp.getCases(), switchOp.getCases().size());

    // Move over blocks.
    for (auto [src, dest] :
         llvm::zip(switchOp.getCaseRegions(), newSwitchOp.getCaseRegions()))
      rewriter.inlineRegionBefore(src, dest, dest.begin());
    rewriter.inlineRegionBefore(switchOp.getDefaultRegion(),
                                newSwitchOp.getDefaultRegion(),
                                newSwitchOp.getDefaultRegion().begin());

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newSwitchOp->getResults());

    return success();
  }

  FailureOr<BufferLikeType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto switchOp = cast<scf::IndexSwitchOp>(op);
    assert(value.getDefiningOp() == op && "invalid value");
    int64_t resultNum = cast<OpResult>(value).getResultNumber();

    // Helper function to get buffer type of a case.
    auto getYieldedBufferType = [&](Block &b) -> FailureOr<BaseMemRefType> {
      auto yieldOp = cast<scf::YieldOp>(b.getTerminator());
      Value yieldedValue = yieldOp->getOperand(resultNum);
      if (auto bufferType = dyn_cast<BaseMemRefType>(yieldedValue.getType()))
        return bufferType;
      auto maybeBufferType = bufferization::getBufferType(
          yieldedValue, options, state, invocationStack);
      return bufferization::detail::asMemRefType(maybeBufferType);
    };

    // Compute buffer type of the default case.
    auto maybeBufferType = getYieldedBufferType(switchOp.getDefaultBlock());
    if (failed(maybeBufferType))
      return failure();
    BaseMemRefType bufferType = *maybeBufferType;

    // Compute buffer types of all other cases.
    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
      auto yieldedBufferType = getYieldedBufferType(switchOp.getCaseBlock(i));
      if (failed(yieldedBufferType))
        return failure();

      // Best case: The case yields the exact same buffer type.
      if (bufferType == *yieldedBufferType)
        continue;

      // Memory space mismatch.
      if (bufferType.getMemorySpace() != yieldedBufferType->getMemorySpace())
        return op->emitError("inconsistent memory space on switch cases");

      // Layout maps are different: Promote to fully dynamic layout map.
      bufferType = getMemRefTypeWithFullyDynamicLayout(
          cast<TensorType>(value.getType()), bufferType.getMemorySpace());
    }

    return cast<BufferLikeType>(bufferType);
  }
};
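
// Illustrative example (not part of the original source): all cases,
// including the default case, must agree on one buffer type:
//
//   %r = scf.index_switch %idx -> memref<?xf32, strided<[?], offset: ?>>
//   case 0 {
//     %c = memref.cast %m0
//         : memref<?xf32> to memref<?xf32, strided<[?], offset: ?>>
//     scf.yield %c : memref<?xf32, strided<[?], offset: ?>>
//   }
//   default {
//     scf.yield %m1 : memref<?xf32, strided<[?], offset: ?>>
//   }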

/// Helper function for loop bufferization. Return the indices of all values
/// that have a tensor type.
static DenseSet<int64_t> getTensorIndices(ValueRange values) {
  DenseSet<int64_t> result;
  for (const auto &it : llvm::enumerate(values))
    if (isa<TensorType>(it.value().getType()))
      result.insert(it.index());
  return result;
}

/// Helper function for loop bufferization. Return the indices of all
/// bbArg/yielded value pairs whose buffer relation is "Equivalent".
DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
                                       ValueRange yieldedValues,
                                       const AnalysisState &state) {
  unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
  DenseSet<int64_t> result;
  for (unsigned int i = 0; i < minSize; ++i) {
    if (!isa<TensorType>(bbArgs[i].getType()) ||
        !isa<TensorType>(yieldedValues[i].getType()))
      continue;
    if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
      result.insert(i);
  }
  return result;
}

/// Helper function for loop bufferization. Return the bufferized values of the
/// given OpOperands. If an operand is not a tensor, return the original value.
static FailureOr<SmallVector<Value>>
getBuffers(RewriterBase &rewriter, const MutableOperandRange &operands,
           const BufferizationOptions &options, BufferizationState &state) {
  SmallVector<Value> result;
  for (OpOperand &opOperand : operands) {
    if (isa<TensorType>(opOperand.get().getType())) {
      FailureOr<Value> resultBuffer =
          getBuffer(rewriter, opOperand.get(), options, state);
      if (failed(resultBuffer))
        return failure();
      result.push_back(*resultBuffer);
    } else {
      result.push_back(opOperand.get());
    }
  }
  return result;
}

/// Helper function for loop bufferization. Given a list of bbArgs of the new
/// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into
/// ToTensorOps, so that the block body can be moved over to the new op.
static SmallVector<Value>
getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
                     Block::BlockArgListType oldBbArgs,
                     const DenseSet<int64_t> &tensorIndices) {
  SmallVector<Value> result;
  for (const auto &it : llvm::enumerate(bbArgs)) {
    size_t idx = it.index();
    Value val = it.value();
    if (tensorIndices.contains(idx)) {
      result.push_back(
          bufferization::ToTensorOp::create(rewriter, val.getLoc(),
                                            oldBbArgs[idx].getType(), val)
              .getResult());
    } else {
      result.push_back(val);
    }
  }
  return result;
}
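
// Illustrative example (not part of the original source): when the body of a
// bufferized loop is moved over, each memref iter_arg is wrapped so that the
// old tensor-based body still type-checks:
//
//   %t = bufferization.to_tensor %barg
//       : memref<?xf32, strided<[?], offset: ?>> to tensor<?xf32>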

/// Compute the bufferized type of a loop iter_arg. This type must be equal to
/// the bufferized type of the corresponding init_arg and the bufferized type
/// of the corresponding yielded value.
///
/// This function uses bufferization::getBufferType to compute the bufferized
/// type of the init_arg and of the yielded value. (The computation of the
/// bufferized yielded value type usually requires computing the bufferized
/// type of the iter_arg again; the implementation of getBufferType traces
/// back the use-def chain of the given value and computes a buffer type along
/// the way.) If both buffer types are equal, no casts are needed and the
/// computed buffer type can be used directly. Otherwise, the buffer types can
/// only differ in their layout map and a cast must be inserted.
static FailureOr<BufferLikeType> computeLoopRegionIterArgBufferType(
    Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
    const BufferizationOptions &options, const BufferizationState &state,
    SmallVector<Value> &invocationStack) {
  // Determine the buffer type of the init_arg.
  auto initArgBufferType =
      bufferization::getBufferType(initArg, options, state, invocationStack);
  if (failed(initArgBufferType))
    return failure();

  if (llvm::count(invocationStack, iterArg) >= 2) {
    // If the iter_arg is already twice on the invocation stack, just take the
    // type of the init_arg. This is to avoid infinite loops when calculating
    // the buffer type. This will most likely result in computing a memref
    // type with a fully dynamic layout map.

    // Note: For more precise layout map computation, a fixpoint iteration
    // could be done (i.e., re-computing the yielded buffer type until the
    // bufferized iter_arg type no longer changes). This current
    // implementation immediately switches to a fully dynamic layout map when
    // a mismatch between bufferized init_arg type and bufferized yield value
    // type is detected.
    return *initArgBufferType;
  }

  // Compute the buffer type of the yielded value.
  BufferLikeType yieldedValueBufferType;
  if (isa<BaseMemRefType>(yieldedValue.getType())) {
    // scf.yield was already bufferized.
    yieldedValueBufferType = cast<BufferLikeType>(yieldedValue.getType());
  } else {
    // Note: This typically triggers a recursive call for the buffer type of
    // the iter_arg.
    auto maybeBufferType = bufferization::getBufferType(yieldedValue, options,
                                                        state, invocationStack);
    if (failed(maybeBufferType))
      return failure();
    yieldedValueBufferType = *maybeBufferType;
  }

  // If yielded type and init_arg type are the same, use that type directly.
  if (*initArgBufferType == yieldedValueBufferType)
    return yieldedValueBufferType;

  // If there is a mismatch between the yielded buffer type and the init_arg
  // buffer type, the buffer type must be promoted to a fully dynamic layout
  // map.
  auto yieldedBufferType = cast<BaseMemRefType>(yieldedValueBufferType);
  auto iterTensorType = cast<TensorType>(iterArg.getType());
  auto initBufferType = llvm::cast<BaseMemRefType>(*initArgBufferType);
  if (initBufferType.getMemorySpace() != yieldedBufferType.getMemorySpace())
    return loopOp->emitOpError(
        "init_arg and yielded value bufferize to inconsistent memory spaces");
#ifndef NDEBUG
  if (auto yieldedRankedBufferType = dyn_cast<MemRefType>(yieldedBufferType)) {
    assert(
        llvm::all_equal({yieldedRankedBufferType.getShape(),
                         cast<MemRefType>(initBufferType).getShape(),
                         cast<RankedTensorType>(iterTensorType).getShape()}) &&
        "expected same shape");
  }
#endif // NDEBUG
  return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
      iterTensorType, yieldedBufferType.getMemorySpace()));
}
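
// Illustrative example (not part of the original source): if the init_arg
// bufferizes to memref<5xf32> but the yielded value bufferizes to
// memref<5xf32, strided<[1], offset: 4>> (e.g., the result of a subview),
// the iter_arg is promoted to memref<5xf32, strided<[?], offset: ?>> and
// both the init_arg and the yielded buffer are memref.cast'ed to that type.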

/// Return `true` if the given loop may have 0 iterations.
bool mayHaveZeroIterations(scf::ForOp forOp) {
  std::optional<int64_t> lb = getConstantIntValue(forOp.getLowerBound());
  std::optional<int64_t> ub = getConstantIntValue(forOp.getUpperBound());
  if (!lb.has_value() || !ub.has_value())
    return true;
  return *ub <= *lb;
}

/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.
struct ForOpInterface
    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                    scf::ForOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);

    // If the loop has zero iterations, the results of the op are their
    // corresponding init_args, meaning that the init_args bufferize to a
    // read.
    if (mayHaveZeroIterations(forOp))
      return true;

    // scf::ForOp alone doesn't bufferize to a memory read, but one of the
    // uses of its matching bbArg may.
    return state.isValueRead(forOp.getTiedLoopRegionIterArg(&opOperand));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::ForOps are always considered as a write.
    return true;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    OpResult opResult = forOp.getTiedLoopResult(&opOperand);
    BufferRelation relation = bufferRelation(op, opResult, state);
    return {{opResult, relation,
             /*isDefinite=*/relation == BufferRelation::Equivalent}};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // ForOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent.
    auto forOp = cast<scf::ForOp>(op);
    BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
    bool equivalentYield = state.areEquivalentBufferizedValues(
        bbArg, forOp.getTiedLoopYieldedValue(bbArg)->get());
    return equivalentYield ? BufferRelation::Equivalent
                           : BufferRelation::Unknown;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::ForOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    //   1. Either the matching iter operand is not bufferized inplace and an
    //      alloc + optional copy makes the bbArg itself inplaceable.
    //   2. Or the matching iter operand is bufferized inplace and bbArg just
    //      bufferizes to that too.
    return true;
  }

  LogicalResult
  resolveConflicts(Operation *op, RewriterBase &rewriter,
                   const AnalysisState &analysisState,
                   const BufferizationState &bufferizationState) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(
            rewriter, analysisState, bufferizationState)))
      return failure();

    if (analysisState.getOptions().copyBeforeWrite)
      return success();

    // According to the `getAliasing...` implementations, a bufferized
    // OpResult may alias only with the corresponding bufferized init_arg (or
    // with a newly allocated buffer) and not with other buffers defined
    // outside of the loop. I.e., the i-th OpResult may alias with the i-th
    // init_arg, but not with any other OpOperand.
    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(yieldOp);

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
    // For every yielded value, does it alias with something defined outside
    // of the loop?
    SmallVector<Value> yieldValues;
    for (const auto it : llvm::enumerate(yieldOp.getResults())) {
      // Note: `analysisState` is guaranteed to be a `OneShotAnalysisState`,
      // but this type cannot be used in the signature of `resolveConflicts`
      // because the op interface is in the "IR" build unit and the
      // `OneShotAnalysisState` is defined in the "Transforms" build unit.
      if (!indices.contains(it.index()) ||
          doesNotAliasExternalValue(
              it.value(), &forOp.getRegion(),
              /*exceptions=*/forOp.getRegionIterArg(it.index()),
              static_cast<const OneShotAnalysisState &>(analysisState))) {
        yieldValues.push_back(it.value());
        continue;
      }
      FailureOr<Value> alloc = allocateTensorForShapedValue(
          rewriter, yieldOp.getLoc(), it.value(), analysisState.getOptions(),
          bufferizationState);
      if (failed(alloc))
        return failure();
      yieldValues.push_back(*alloc);
    }

    rewriter.modifyOpInPlace(
        yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); });
    return success();
  }

  FailureOr<BufferLikeType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto forOp = cast<scf::ForOp>(op);
    assert(getOwnerOfValue(value) == op && "invalid value");
    assert(isa<TensorType>(value.getType()) && "expected tensor type");

    if (auto opResult = dyn_cast<OpResult>(value)) {
      // The type of an OpResult must match the corresponding iter_arg type.
      BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
      return bufferization::getBufferType(bbArg, options, state,
                                          invocationStack);
    }

    // Compute result/argument number.
    BlockArgument bbArg = cast<BlockArgument>(value);
    unsigned resultNum = forOp.getTiedLoopResult(bbArg).getResultNumber();

    // Compute the bufferized type.
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    Value yieldedValue = yieldOp.getOperand(resultNum);
    BlockArgument iterArg = forOp.getRegionIterArgs()[resultNum];
    Value initArg = forOp.getInitArgs()[resultNum];
    return computeLoopRegionIterArgBufferType(
        op, iterArg, initArg, yieldedValue, options, state, invocationStack);
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    Block *oldLoopBody = forOp.getBody();

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, forOp.getInitArgsMutable(), options, state);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // Cast init_args if necessary.
    SmallVector<Value> castedInitArgs;
    for (const auto &it : llvm::enumerate(initArgs)) {
      Value initArg = it.value();
      Value result = forOp->getResult(it.index());
      // If the type is not a tensor, bufferization doesn't need to touch it.
      if (!isa<TensorType>(result.getType())) {
        castedInitArgs.push_back(initArg);
        continue;
      }
      auto targetType = bufferization::getBufferType(result, options, state);
      if (failed(targetType))
        return failure();
      castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
    }

    // Construct a new scf.for op with memref instead of tensor values.
    auto newForOp = scf::ForOp::create(
        rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), castedInitArgs, /*bodyBuilder=*/nullptr,
        forOp.getUnsignedCmp());
    newForOp->setAttrs(forOp->getAttrs());
    Block *loopBody = newForOp.getBody();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
    SmallVector<Value> iterArgs =
        getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(),
                             forOp.getRegionIterArgs(), indices);
    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());

    // Move loop body to new loop.
    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.for op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs + copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocsFromLoops)
      return success();

    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    for (OpResult opResult : op->getOpResults()) {
      if (!isa<TensorType>(opResult.getType()))
        continue;

      // Note: This is overly strict. We should check for aliasing bufferized
      // values. But we don't have a "must-alias" analysis yet.
      if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
        return yieldOp->emitError()
               << "Yield operand #" << opResult.getResultNumber()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};
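
// Illustrative example (not part of the original source): a loop such as
//
//   %r = scf.for %i = %lb to %ub step %s iter_args(%arg = %t)
//       -> (tensor<?xf32>) {
//     ...
//     scf.yield %next : tensor<?xf32>
//   }
//
// bufferizes to a loop over memrefs whose body recovers a tensor view of the
// iter_arg via bufferization.to_tensor:
//
//   %m = scf.for %i = %lb to %ub step %s iter_args(%barg = %buf)
//       -> (memref<?xf32, strided<[?], offset: ?>>) {
//     %t0 = bufferization.to_tensor %barg
//         : memref<?xf32, strided<[?], offset: ?>> to tensor<?xf32>
//     ...
//     scf.yield %nextBuf : memref<?xf32, strided<[?], offset: ?>>
//   }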

/// Bufferization of scf.while. Replace with a new scf.while that operates on
/// memrefs.
struct WhileOpInterface
    : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
                                                    scf::WhileOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a read.
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a write.
    return true;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    unsigned int idx = opOperand.getOperandNumber();

    // The OpResults and OpOperands may not match. They may not even have the
    // same type. The number of OpResults and OpOperands can also differ.
    if (idx >= op->getNumResults() ||
        opOperand.get().getType() != op->getResult(idx).getType())
      return {};

    // The only aliasing OpResult may be the one at the same index.
    OpResult opResult = whileOp->getResult(idx);
    BufferRelation relation = bufferRelation(op, opResult, state);
    return {{opResult, relation,
             /*isDefinite=*/relation == BufferRelation::Equivalent}};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // WhileOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent (for both the
    // "before" and the "after" block).
    unsigned int resultNumber = opResult.getResultNumber();
    auto whileOp = cast<scf::WhileOp>(op);

    // The "before" region bbArgs and the OpResults may not match.
    if (resultNumber >= whileOp.getBeforeArguments().size())
      return BufferRelation::Unknown;
    if (opResult.getType() !=
        whileOp.getBeforeArguments()[resultNumber].getType())
      return BufferRelation::Unknown;

    auto conditionOp = whileOp.getConditionOp();
    BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
    Value conditionOperand = conditionOp.getArgs()[resultNumber];
    bool equivCondition =
        state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);

    auto yieldOp = whileOp.getYieldOp();
    BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
    Value yieldOperand = yieldOp.getOperand(resultNumber);
    bool equivYield =
        state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);

    return equivCondition && equivYield ? BufferRelation::Equivalent
                                        : BufferRelation::Unknown;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::WhileOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    //   1. Either the matching iter operand is not bufferized inplace and an
    //      alloc + optional copy makes the bbArg itself inplaceable.
    //   2. Or the matching iter operand is bufferized inplace and bbArg just
    //      bufferizes to that too.
    return true;
  }

  LogicalResult
  resolveConflicts(Operation *op, RewriterBase &rewriter,
                   const AnalysisState &analysisState,
                   const BufferizationState &bufferizationState) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(
            rewriter, analysisState, bufferizationState)))
      return failure();

    if (analysisState.getOptions().copyBeforeWrite)
      return success();

    // According to the `getAliasing...` implementations, a bufferized
    // OpResult may alias only with the corresponding bufferized init_arg and
    // with no other buffers. I.e., the i-th OpResult may alias with the i-th
    // init_arg, but not with any other OpOperand. If a corresponding
    // OpResult/init_arg pair bufferizes to equivalent buffers, this aliasing
    // requirement is satisfied. Otherwise, we cannot be sure and must yield
    // a new buffer copy. (New buffer copies do not alias with any buffer.)
    OpBuilder::InsertionGuard g(rewriter);
    auto whileOp = cast<scf::WhileOp>(op);
    auto conditionOp = whileOp.getConditionOp();

    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
        whileOp.getBeforeArguments(), conditionOp.getArgs(), analysisState);
    DenseSet<int64_t> equivalentYieldsAfter =
        getEquivalentBuffers(whileOp.getAfterArguments(),
                             whileOp.getYieldOp().getResults(), analysisState);

    // Update "before" region.
    rewriter.setInsertionPoint(conditionOp);
    SmallVector<Value> beforeYieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
      Value value = conditionOp.getArgs()[idx];
      if (!isa<TensorType>(value.getType()) ||
          (equivalentYieldsAfter.contains(idx) &&
           equivalentYieldsBefore.contains(idx))) {
        beforeYieldValues.push_back(value);
        continue;
      }
      FailureOr<Value> alloc = allocateTensorForShapedValue(
          rewriter, conditionOp.getLoc(), value, analysisState.getOptions(),
          bufferizationState);
      if (failed(alloc))
        return failure();
      beforeYieldValues.push_back(*alloc);
    }
    rewriter.modifyOpInPlace(conditionOp, [&]() {
      conditionOp.getArgsMutable().assign(beforeYieldValues);
    });

    return success();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);

    // Indices of all bbArgs that have tensor type. These are the ones that
    // are bufferized. The "before" and "after" regions may have different
    // args.
    DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
    DenseSet<int64_t> indicesAfter =
        getTensorIndices(whileOp.getAfterArguments());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, whileOp.getInitsMutable(), options, state);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // Cast init_args if necessary.
    SmallVector<Value> castedInitArgs;
    for (const auto &it : llvm::enumerate(initArgs)) {
      Value initArg = it.value();
      Value beforeArg = whileOp.getBeforeArguments()[it.index()];
      // If the type is not a tensor, bufferization doesn't need to touch it.
      if (!isa<TensorType>(beforeArg.getType())) {
        castedInitArgs.push_back(initArg);
        continue;
      }
      auto targetType = bufferization::getBufferType(beforeArg, options, state);
      if (failed(targetType))
        return failure();
      castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
    }

    // The result types of a WhileOp are the same as the "after" bbArg types.
    SmallVector<Type> argsTypesAfter = llvm::map_to_vector(
        whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
          if (!isa<TensorType>(bbArg.getType()))
            return bbArg.getType();
          // TODO: error handling
          return llvm::cast<Type>(
              *bufferization::getBufferType(bbArg, options, state));
        });

    // Construct a new scf.while op with memref instead of tensor values.
    ValueRange argsRangeBefore(castedInitArgs);
    TypeRange argsTypesBefore(argsRangeBefore);
    auto newWhileOp = scf::WhileOp::create(rewriter, whileOp.getLoc(),
                                           argsTypesAfter, castedInitArgs);

    // Add before/after regions to the new op.
    SmallVector<Location> bbArgLocsBefore(castedInitArgs.size(),
                                          whileOp.getLoc());
    SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
                                         whileOp.getLoc());
    Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
    newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore);
    Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
    newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter);

    // Set up new iter_args and move the loop condition block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new
    // block in ToTensorOps.
    rewriter.setInsertionPointToStart(newBeforeBody);
    SmallVector<Value> newBeforeArgs =
        getBbArgReplacements(rewriter, newWhileOp.getBeforeArguments(),
                             whileOp.getBeforeArguments(), indicesBefore);
    rewriter.mergeBlocks(whileOp.getBeforeBody(), newBeforeBody, newBeforeArgs);

    // Set up new iter_args and move the loop body block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new
    // block in ToTensorOps.
    rewriter.setInsertionPointToStart(newAfterBody);
    SmallVector<Value> newAfterArgs =
        getBbArgReplacements(rewriter, newWhileOp.getAfterArguments(),
                             whileOp.getAfterArguments(), indicesAfter);
    rewriter.mergeBlocks(whileOp.getAfterBody(), newAfterBody, newAfterArgs);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());

    return success();
  }

  FailureOr<BufferLikeType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto whileOp = cast<scf::WhileOp>(op);
    assert(getOwnerOfValue(value) == op && "invalid value");
    assert(isa<TensorType>(value.getType()) && "expected tensor type");

    // Case 1: Block argument of the "before" region.
    if (auto bbArg = dyn_cast<BlockArgument>(value)) {
      if (bbArg.getOwner()->getParent() == &whileOp.getBefore()) {
        Value initArg = whileOp.getInits()[bbArg.getArgNumber()];
        auto yieldOp = whileOp.getYieldOp();
        Value yieldedValue = yieldOp.getOperand(bbArg.getArgNumber());
        return computeLoopRegionIterArgBufferType(
            op, bbArg, initArg, yieldedValue, options, state, invocationStack);
      }
    }

    // Case 2: OpResult of the loop or block argument of the "after" region.
    // The bufferized "after" bbArg type can be directly computed from the
    // bufferized "before" bbArg type.
    unsigned resultNum;
    if (auto opResult = dyn_cast<OpResult>(value)) {
      resultNum = opResult.getResultNumber();
    } else if (cast<BlockArgument>(value).getOwner()->getParent() ==
               &whileOp.getAfter()) {
      resultNum = cast<BlockArgument>(value).getArgNumber();
    } else {
      llvm_unreachable("invalid value");
    }
    Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
    if (!isa<TensorType>(conditionYieldedVal.getType())) {
      // scf.condition was already bufferized.
      return cast<BufferLikeType>(conditionYieldedVal.getType());
    }
    return bufferization::getBufferType(conditionYieldedVal, options, state,
                                        invocationStack);
  }

  /// Assert that yielded values of an scf.while op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs + copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs`.
  ///
  /// Note: In contrast to scf::ForOp, scf::WhileOp has two regions and the
  /// equivalence condition must be checked for both.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocsFromLoops)
      return success();

    auto conditionOp = whileOp.getConditionOp();
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      Block *block = conditionOp->getBlock();
      if (!isa<TensorType>(it.value().getType()))
        continue;
      if (it.index() >= block->getNumArguments() ||
          !state.areEquivalentBufferizedValues(it.value(),
                                               block->getArgument(it.index())))
        return conditionOp->emitError()
               << "Condition arg #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    auto yieldOp = whileOp.getYieldOp();
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Block *block = yieldOp->getBlock();
      if (!isa<TensorType>(it.value().getType()))
        continue;
      if (it.index() >= block->getNumArguments() ||
          !state.areEquivalentBufferizedValues(it.value(),
                                               block->getArgument(it.index())))
        return yieldOp->emitError()
               << "Yield operand #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};
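
// Illustrative example (not part of the original source): for
//
//   %r = scf.while (%w = %init) : (tensor<5xf32>) -> tensor<5xf32> {
//     %c = ... : i1
//     scf.condition(%c) %w : tensor<5xf32>
//   } do {
//   ^bb0(%arg: tensor<5xf32>):
//     ...
//     scf.yield %next : tensor<5xf32>
//   }
//
// the "before" bbArg type is computed like a loop iter_arg (init_arg vs.
// yielded value), while the "after" bbArg and result types follow from the
// operands of scf.condition; memref.cast ops are inserted wherever these
// buffer types disagree only in their layout map.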

/// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so
/// this is for analysis only.
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    scf::YieldOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    if (auto ifOp = dyn_cast<scf::IfOp>(op->getParentOp())) {
      return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
               BufferRelation::Equivalent, /*isDefinite=*/false}};
    }
    if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
      return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
               BufferRelation::Equivalent}};
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield
    // allocations when possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto yieldOp = cast<scf::YieldOp>(op);
    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::IndexSwitchOp, scf::ForOp,
             scf::WhileOp>(yieldOp->getParentOp()))
      return yieldOp->emitError("unsupported scf::YieldOp parent");

    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Value value = it.value();
      if (isa<TensorType>(value.getType())) {
        FailureOr<Value> maybeBuffer =
            getBuffer(rewriter, value, options, state);
        if (failed(maybeBuffer))
          return failure();
        Value buffer = *maybeBuffer;
        // We may have to cast the value before yielding it.
        if (isa<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>(
                yieldOp->getParentOp())) {
          FailureOr<BufferLikeType> resultType = bufferization::getBufferType(
              yieldOp->getParentOp()->getResult(it.index()), options, state);
          if (failed(resultType))
            return failure();
          buffer = castBuffer(rewriter, buffer, *resultType);
        } else if (auto whileOp =
                       dyn_cast<scf::WhileOp>(yieldOp->getParentOp())) {
          FailureOr<BufferLikeType> resultType = bufferization::getBufferType(
              whileOp.getBeforeArguments()[it.index()], options, state);
          if (failed(resultType))
            return failure();
          buffer = castBuffer(rewriter, buffer, *resultType);
        }
        newResults.push_back(buffer);
      } else {
        newResults.push_back(value);
      }
    }

    replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op, newResults);
    return success();
  }
};

/// Bufferization of ForallOp. This also bufferizes the terminator of the
/// region. There are op interfaces for the terminators (InParallelOp and
/// ParallelInsertSliceOp), but these are used only during analysis, not for
/// bufferization.
struct ForallOpInterface
    : public BufferizableOpInterface::ExternalModel<ForallOpInterface,
                                                    ForallOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // All tensor operands to `scf.forall` are `shared_outs` and all
    // shared outs are assumed to be read by the loop. This does not
    // account for the case where the entire value is over-written,
    // but being conservative here.
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Outputs of scf::ForallOps are always considered as a write.
    return true;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto forallOp = cast<ForallOp>(op);
    return {
        {{forallOp.getTiedOpResult(&opOperand), BufferRelation::Equivalent}}};
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    OpBuilder::InsertionGuard guard(rewriter);
    auto forallOp = cast<ForallOp>(op);
    int64_t rank = forallOp.getRank();

    // Get buffers for all output operands.
    SmallVector<Value> buffers;
    for (Value out : forallOp.getOutputs()) {
      FailureOr<Value> buffer = getBuffer(rewriter, out, options, state);
      if (failed(buffer))
        return failure();
      buffers.push_back(*buffer);
    }

    // Use buffers instead of block arguments.
    rewriter.setInsertionPointToStart(forallOp.getBody());
    for (const auto &it : llvm::zip(
             forallOp.getBody()->getArguments().drop_front(rank), buffers)) {
      BlockArgument bbArg = std::get<0>(it);
      Value buffer = std::get<1>(it);
      Value bufferAsTensor = ToTensorOp::create(rewriter, forallOp.getLoc(),
                                                bbArg.getType(), buffer);
      bbArg.replaceAllUsesWith(bufferAsTensor);
    }

    // Create new ForallOp without any results and drop the automatically
    // introduced terminator.
    rewriter.setInsertionPoint(forallOp);
    ForallOp newForallOp = ForallOp::create(
        rewriter, forallOp.getLoc(), forallOp.getMixedLowerBound(),
        forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
        /*outputs=*/ValueRange(), forallOp.getMapping());

    // Keep discardable attributes from the original op.
    newForallOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());

    rewriter.eraseOp(newForallOp.getBody()->getTerminator());

    // Move over block contents of the old op.
    SmallVector<Value> replacementBbArgs;
    replacementBbArgs.append(newForallOp.getBody()->getArguments().begin(),
                             newForallOp.getBody()->getArguments().end());
    replacementBbArgs.append(forallOp.getOutputs().size(), Value());
    rewriter.mergeBlocks(forallOp.getBody(), newForallOp.getBody(),
                         replacementBbArgs);

    // Remove the old op and replace all of its uses.
    replaceOpWithBufferizedValues(rewriter, op, buffers);

    return success();
  }

  FailureOr<BufferLikeType>
  getBufferType(Operation *op, Value value,
                const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto forallOp = cast<ForallOp>(op);

    if (auto bbArg = dyn_cast<BlockArgument>(value))
      // A tensor block argument has the same bufferized type as the
      // corresponding output operand.
      return bufferization::getBufferType(
          forallOp.getTiedOpOperand(bbArg)->get(), options, state,
          invocationStack);

    // The bufferized result type is the same as the bufferized type of the
    // corresponding output operand.
    return bufferization::getBufferType(
        forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options,
        state, invocationStack);
  }

  bool isRepetitiveRegion(Operation *op, unsigned index) const {
    auto forallOp = cast<ForallOp>(op);

    // This op is repetitive if it may execute more than one iteration in any
    // dimension, i.e., if for some dimension lb + step < ub. If any bound or
    // step is dynamic, conservatively assume that the op is repetitive.
    for (auto [lb, ub, step] :
         llvm::zip(forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
                   forallOp.getMixedStep())) {
      std::optional<int64_t> lbConstant = getConstantIntValue(lb);
      if (!lbConstant)
        return true;

      std::optional<int64_t> ubConstant = getConstantIntValue(ub);
      if (!ubConstant)
        return true;

      std::optional<int64_t> stepConstant = getConstantIntValue(step);
      if (!stepConstant)
        return true;

      if (*lbConstant + *stepConstant < *ubConstant)
        return true;
    }
    return false;
  }

  bool isParallelRegion(Operation *op, unsigned index) const {
    return isRepetitiveRegion(op, index);
  }
};
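
// Illustrative example (not part of the original source): the shared_outs of
//
//   %r = scf.forall (%i) in (%n) shared_outs(%o = %t) -> (tensor<?xf32>) {
//     ...
//   }
//
// bufferize in place: the new scf.forall has no results, the body reads and
// writes the output buffer directly, and all uses of %r are replaced with
// that buffer.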

/// Nothing to do for InParallelOp.
struct InParallelOpInterface
    : public BufferizableOpInterface::ExternalModel<InParallelOpInterface,
                                                    InParallelOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &b,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    llvm_unreachable("op does not have any tensor OpOperands / OpResults");
    return failure();
  }
};

} // namespace
} // namespace scf
} // namespace mlir

void mlir::scf::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
    ConditionOp::attachInterface<ConditionOpInterface>(*ctx);
    ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
    ForOp::attachInterface<ForOpInterface>(*ctx);
    IfOp::attachInterface<IfOpInterface>(*ctx);
    IndexSwitchOp::attachInterface<IndexSwitchOpInterface>(*ctx);
    ForallOp::attachInterface<ForallOpInterface>(*ctx);
    InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
    WhileOp::attachInterface<WhileOpInterface>(*ctx);
    YieldOp::attachInterface<YieldOpInterface>(*ctx);
  });
}
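
// Typical usage (illustrative, not part of the original source): clients
// register these external models before running One-Shot Bufferize:
//
//   DialectRegistry registry;
//   registry.insert<scf::SCFDialect>();
//   scf::registerBufferizableOpInterfaceExternalModels(registry);
//   MLIRContext ctx(registry);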