//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::scf;

namespace mlir {
namespace scf {
namespace {

/// Helper function for loop bufferization. Cast the given buffer to the given
/// memref type.
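///
/// Illustrative example (hypothetical IR; the exact layout depends on the
/// bufferization options): a buffer with identity layout may be cast to the
/// fully dynamic layout expected by a loop iter_arg:
///
///   %cast = memref.cast %buffer
///       : memref<5xf32> to memref<5xf32, strided<[?], offset: ?>>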
static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
  assert(isa<BaseMemRefType>(type) && "expected BaseMemRefType");
  assert(isa<BaseMemRefType>(buffer.getType()) && "expected BaseMemRefType");
  // If the buffer already has the correct type, no cast is needed.
  if (buffer.getType() == type)
    return buffer;
  // TODO: In case `type` has a layout map that is not the fully dynamic
  // one, we may not be able to cast the buffer. In that case, the loop
  // iter_arg's layout map must be changed (see uses of `castBuffer`).
  assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
         "scf.while op bufferization: cast incompatible");
  return memref::CastOp::create(b, buffer.getLoc(), type, buffer).getResult();
}

/// Helper function for loop bufferization. Return "true" if the given value
/// is guaranteed to not alias with an external tensor apart from values in
/// `exceptions`. A value is external if it is defined outside of the given
/// region or if it is an entry block argument of the region.
static bool doesNotAliasExternalValue(Value value, Region *region,
                                      ValueRange exceptions,
                                      const OneShotAnalysisState &state) {
  assert(region->hasOneBlock() && "expected region with single block");
  bool result = true;
  state.applyOnAliases(value, [&](Value alias) {
    if (llvm::is_contained(exceptions, alias))
      return;
    Region *aliasRegion = alias.getParentRegion();
    if (isa<BlockArgument>(alias) && !region->isProperAncestor(aliasRegion))
      result = false;
    if (isa<OpResult>(alias) && !region->isAncestor(aliasRegion))
      result = false;
  });
  return result;
}

/// Bufferization of scf.condition.
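///
/// Illustrative example (hypothetical IR; exact types depend on the
/// bufferization options): the forwarded tensor is replaced by its buffer,
/// cast to the bufferized type of the matching "after" block argument:
///
///   scf.condition(%cond) %t : tensor<5xf32>
/// becomes
///   scf.condition(%cond) %buf : memref<5xf32, strided<[?], offset: ?>>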
struct ConditionOpInterface
    : public BufferizableOpInterface::ExternalModel<ConditionOpInterface,
                                                    scf::ConditionOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Condition operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield
    // allocations when possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto conditionOp = cast<scf::ConditionOp>(op);
    auto whileOp = cast<scf::WhileOp>(conditionOp->getParentOp());

    SmallVector<Value> newArgs;
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      Value value = it.value();
      if (isa<TensorType>(value.getType())) {
        FailureOr<Value> maybeBuffer =
            getBuffer(rewriter, value, options, state);
        if (failed(maybeBuffer))
          return failure();
        FailureOr<BufferLikeType> resultType = bufferization::getBufferType(
            whileOp.getAfterArguments()[it.index()], options, state);
        if (failed(resultType))
          return failure();
        Value buffer = castBuffer(rewriter, *maybeBuffer, *resultType);
        newArgs.push_back(buffer);
      } else {
        newArgs.push_back(value);
      }
    }

    replaceOpWithNewBufferizedOp<scf::ConditionOp>(
        rewriter, op, conditionOp.getCondition(), newArgs);
    return success();
  }
};

/// Return the unique scf.yield op. If there are multiple or no scf.yield ops,
/// return an empty op.
static scf::YieldOp getUniqueYieldOp(scf::ExecuteRegionOp executeRegionOp) {
  scf::YieldOp result;
  for (Block &block : executeRegionOp.getRegion()) {
    if (auto yieldOp = dyn_cast<scf::YieldOp>(block.getTerminator())) {
      if (result)
        return {};
      result = yieldOp;
    }
  }
  return result;
}

/// Bufferization of scf.execute_region. Can be analyzed, but bufferization is
/// not fully implemented at the moment.
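///
/// Illustrative example (hypothetical IR): the op is recreated with memref
/// result types and each old tensor result is rematerialized via
/// bufferization.to_tensor:
///
///   %0 = scf.execute_region -> tensor<5xf32> { ... }
/// becomes
///   %m = scf.execute_region -> memref<5xf32, strided<[?], offset: ?>> { ... }
///   %0 = bufferization.to_tensor %m ...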
struct ExecuteRegionOpInterface
    : public OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel<
          ExecuteRegionOpInterface, scf::ExecuteRegionOp> {

  static bool supportsUnstructuredControlFlow() { return true; }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    return true;
  }

  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    // TODO: scf.execute_region with multiple yields are not supported.
    if (!getUniqueYieldOp(executeRegionOp))
      return op->emitOpError("op without unique scf.yield is not supported");
    return success();
  }

  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    if (auto bbArg = dyn_cast<BlockArgument>(value))
      return getAliasingBranchOpOperands(op, bbArg, state);

    // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. To allow for use-def chain traversal
    // through ExecuteRegionOps in the analysis, the corresponding yield value
    // is considered to be aliasing with the result.
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    auto it = llvm::find(op->getOpResults(), value);
    assert(it != op->getOpResults().end() && "invalid value");
    size_t resultNum = std::distance(op->getOpResults().begin(), it);
    auto yieldOp = getUniqueYieldOp(executeRegionOp);
    // Note: If there is no unique scf.yield op, `verifyAnalysis` will fail.
    if (!yieldOp)
      return {};
    return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    auto yieldOp = getUniqueYieldOp(executeRegionOp);
    TypeRange newResultTypes(yieldOp.getResults());

    // Create new op and move over region.
    auto newOp = scf::ExecuteRegionOp::create(
        rewriter, op->getLoc(), newResultTypes, executeRegionOp.getNoInline());
    newOp.getRegion().takeBody(executeRegionOp.getRegion());

    // Bufferize every block.
    for (Block &block : newOp.getRegion())
      if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
                                                        options, state)))
        return failure();

    // Update all uses of the old op.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
      if (isa<TensorType>(it.value())) {
        newResults.push_back(bufferization::ToTensorOp::create(
            rewriter, executeRegionOp.getLoc(), it.value(),
            newOp->getResult(it.index())));
      } else {
        newResults.push_back(newOp->getResult(it.index()));
      }
    }

    // Replace old op.
    rewriter.replaceOp(executeRegionOp, newResults);

    return success();
  }
};

/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
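///
/// Illustrative example (hypothetical IR; exact types depend on the
/// bufferization options): both branches must yield the same buffer type, so
/// mismatching layouts are promoted to a fully dynamic layout:
///
///   %0 = scf.if %cond -> (tensor<5xf32>) { ... }
/// becomes
///   %m = scf.if %cond -> (memref<5xf32, strided<[?], offset: ?>>) { ... }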
struct IfOpInterface
    : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    // IfOps do not have tensor OpOperands. The yielded value can be any SSA
    // value that is in scope. To allow for use-def chain traversal through
    // IfOps in the analysis, both corresponding yield values from the
    // then/else branches are considered to be aliasing with the result.
    auto ifOp = cast<scf::IfOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), value));
    OpOperand *thenOperand = &ifOp.thenYield()->getOpOperand(resultNum);
    OpOperand *elseOperand = &ifOp.elseYield()->getOpOperand(resultNum);
    return {{thenOperand, BufferRelation::Equivalent, /*isDefinite=*/false},
            {elseOperand, BufferRelation::Equivalent, /*isDefinite=*/false}};
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto ifOp = cast<scf::IfOp>(op);

    // Compute bufferized result types.
    SmallVector<Type> newTypes;
    for (Value result : ifOp.getResults()) {
      if (!isa<TensorType>(result.getType())) {
        newTypes.push_back(result.getType());
        continue;
      }
      auto bufferType = bufferization::getBufferType(result, options, state);
      if (failed(bufferType))
        return failure();
      newTypes.push_back(*bufferType);
    }

    // Create new op.
    rewriter.setInsertionPoint(ifOp);
    auto newIfOp = scf::IfOp::create(rewriter, ifOp.getLoc(), newTypes,
                                     ifOp.getCondition(),
                                     /*withElseRegion=*/true);

    // Move over then/else blocks.
    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());

    return success();
  }

  FailureOr<BufferLikeType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto ifOp = cast<scf::IfOp>(op);
    auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
    auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());
    assert(value.getDefiningOp() == op && "invalid value");

    // Determine buffer types of the true/false branches.
    auto opResult = cast<OpResult>(value);
    auto thenValue = thenYieldOp.getOperand(opResult.getResultNumber());
    auto elseValue = elseYieldOp.getOperand(opResult.getResultNumber());
    BaseMemRefType thenBufferType, elseBufferType;
    if (isa<BaseMemRefType>(thenValue.getType())) {
      // True branch was already bufferized.
      thenBufferType = cast<BaseMemRefType>(thenValue.getType());
    } else {
      auto maybeBufferType =
          bufferization::detail::asMemRefType(bufferization::getBufferType(
              thenValue, options, state, invocationStack));
      if (failed(maybeBufferType))
        return failure();
      thenBufferType = *maybeBufferType;
    }
    if (isa<BaseMemRefType>(elseValue.getType())) {
      // False branch was already bufferized.
      elseBufferType = cast<BaseMemRefType>(elseValue.getType());
    } else {
      auto maybeBufferType =
          bufferization::detail::asMemRefType(bufferization::getBufferType(
              elseValue, options, state, invocationStack));
      if (failed(maybeBufferType))
        return failure();
      elseBufferType = *maybeBufferType;
    }

    // Best case: Both branches have the exact same buffer type.
    if (thenBufferType == elseBufferType)
      return cast<BufferLikeType>(thenBufferType);

    // Memory space mismatch.
    if (thenBufferType.getMemorySpace() != elseBufferType.getMemorySpace())
      return op->emitError("inconsistent memory space on then/else branches");

    // Layout maps are different: Promote to fully dynamic layout map.
    return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
        cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace()));
  }
};

/// Bufferization of scf.index_switch. Replace with a new scf.index_switch that
/// yields memrefs.
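///
/// Illustrative example (hypothetical IR): all cases and the default region
/// must yield the same buffer type, analogous to scf.if:
///
///   %0 = scf.index_switch %idx -> tensor<5xf32>
///   case 0 { scf.yield %t0 : tensor<5xf32> }
///   default { scf.yield %t1 : tensor<5xf32> }
/// becomes an scf.index_switch with result type
///   memref<5xf32, strided<[?], offset: ?>>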
struct IndexSwitchOpInterface
    : public BufferizableOpInterface::ExternalModel<IndexSwitchOpInterface,
                                                    scf::IndexSwitchOp> {
  AliasingOpOperandList
  getAliasingOpOperands(Operation *op, Value value,
                        const AnalysisState &state) const {
    // IndexSwitchOps do not have tensor OpOperands. The yielded value can be
    // any SSA value that is in scope. This is similar to IfOps.
    auto switchOp = cast<scf::IndexSwitchOp>(op);
    int64_t resultNum = cast<OpResult>(value).getResultNumber();
    AliasingOpOperandList result;
    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
      auto yieldOp =
          cast<scf::YieldOp>(switchOp.getCaseBlock(i).getTerminator());
      result.addAlias(AliasingOpOperand(&yieldOp->getOpOperand(resultNum),
                                        BufferRelation::Equivalent,
                                        /*isDefinite=*/false));
    }
    auto defaultYieldOp =
        cast<scf::YieldOp>(switchOp.getDefaultBlock().getTerminator());
    result.addAlias(AliasingOpOperand(&defaultYieldOp->getOpOperand(resultNum),
                                      BufferRelation::Equivalent,
                                      /*isDefinite=*/false));
    return result;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    OpBuilder::InsertionGuard g(rewriter);
    auto switchOp = cast<scf::IndexSwitchOp>(op);

    // Compute bufferized result types.
    SmallVector<Type> newTypes;
    for (Value result : switchOp.getResults()) {
      if (!isa<TensorType>(result.getType())) {
        newTypes.push_back(result.getType());
        continue;
      }
      auto bufferType = bufferization::getBufferType(result, options, state);
      if (failed(bufferType))
        return failure();
      newTypes.push_back(*bufferType);
    }

    // Create new op.
    rewriter.setInsertionPoint(switchOp);
    auto newSwitchOp = scf::IndexSwitchOp::create(
        rewriter, switchOp.getLoc(), newTypes, switchOp.getArg(),
        switchOp.getCases(), switchOp.getCases().size());

    // Move over blocks.
    for (auto [src, dest] :
         llvm::zip(switchOp.getCaseRegions(), newSwitchOp.getCaseRegions()))
      rewriter.inlineRegionBefore(src, dest, dest.begin());
    rewriter.inlineRegionBefore(switchOp.getDefaultRegion(),
                                newSwitchOp.getDefaultRegion(),
                                newSwitchOp.getDefaultRegion().begin());

    // Replace op results.
    replaceOpWithBufferizedValues(rewriter, op, newSwitchOp->getResults());

    return success();
  }

  FailureOr<BufferLikeType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto switchOp = cast<scf::IndexSwitchOp>(op);
    assert(value.getDefiningOp() == op && "invalid value");
    int64_t resultNum = cast<OpResult>(value).getResultNumber();

    // Helper function to get buffer type of a case.
    auto getYieldedBufferType = [&](Block &b) -> FailureOr<BaseMemRefType> {
      auto yieldOp = cast<scf::YieldOp>(b.getTerminator());
      Value yieldedValue = yieldOp->getOperand(resultNum);
      if (auto bufferType = dyn_cast<BaseMemRefType>(yieldedValue.getType()))
        return bufferType;
      auto maybeBufferType = bufferization::getBufferType(
          yieldedValue, options, state, invocationStack);
      return bufferization::detail::asMemRefType(maybeBufferType);
    };

    // Compute buffer type of the default case.
    auto maybeBufferType = getYieldedBufferType(switchOp.getDefaultBlock());
    if (failed(maybeBufferType))
      return failure();
    BaseMemRefType bufferType = *maybeBufferType;

    // Compute buffer types of all other cases.
    for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
      auto yieldedBufferType = getYieldedBufferType(switchOp.getCaseBlock(i));
      if (failed(yieldedBufferType))
        return failure();

      // Best case: Both cases have the exact same buffer type.
      if (bufferType == *yieldedBufferType)
        continue;

      // Memory space mismatch.
      if (bufferType.getMemorySpace() != yieldedBufferType->getMemorySpace())
        return op->emitError("inconsistent memory space on switch cases");

      // Layout maps are different: Promote to fully dynamic layout map.
      bufferType = getMemRefTypeWithFullyDynamicLayout(
          cast<TensorType>(value.getType()), bufferType.getMemorySpace());
    }

    return cast<BufferLikeType>(bufferType);
  }
};

/// Helper function for loop bufferization. Return the indices of all values
/// that have a tensor type.
static DenseSet<int64_t> getTensorIndices(ValueRange values) {
  DenseSet<int64_t> result;
  for (const auto &it : llvm::enumerate(values))
    if (isa<TensorType>(it.value().getType()))
      result.insert(it.index());
  return result;
}

/// Helper function for loop bufferization. Return the indices of all
/// bbArg/yielded value pairs whose buffer relation is "Equivalent".
static DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
                                              ValueRange yieldedValues,
                                              const AnalysisState &state) {
  unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
  DenseSet<int64_t> result;
  for (unsigned int i = 0; i < minSize; ++i) {
    if (!isa<TensorType>(bbArgs[i].getType()) ||
        !isa<TensorType>(yieldedValues[i].getType()))
      continue;
    if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
      result.insert(i);
  }
  return result;
}

/// Helper function for loop bufferization. Return the bufferized values of the
/// given OpOperands. If an operand is not a tensor, return the original value.
static FailureOr<SmallVector<Value>>
getBuffers(RewriterBase &rewriter, const MutableOperandRange &operands,
           const BufferizationOptions &options, BufferizationState &state) {
  SmallVector<Value> result;
  for (OpOperand &opOperand : operands) {
    if (isa<TensorType>(opOperand.get().getType())) {
      FailureOr<Value> resultBuffer =
          getBuffer(rewriter, opOperand.get(), options, state);
      if (failed(resultBuffer))
        return failure();
      result.push_back(*resultBuffer);
    } else {
      result.push_back(opOperand.get());
    }
  }
  return result;
}

/// Helper function for loop bufferization. Given a list of bbArgs of the new
/// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into
/// ToTensorOps, so that the block body can be moved over to the new op.
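///
/// Illustrative sketch (hypothetical IR): a memref bbArg %arg of the new loop
/// is wrapped as
///
///   %t = bufferization.to_tensor %arg ...
///
/// so that the old tensor-based loop body can be merged in unchanged.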
static SmallVector<Value>
getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
                     Block::BlockArgListType oldBbArgs,
                     const DenseSet<int64_t> &tensorIndices) {
  SmallVector<Value> result;
  for (const auto &it : llvm::enumerate(bbArgs)) {
    size_t idx = it.index();
    Value val = it.value();
    if (tensorIndices.contains(idx)) {
      result.push_back(
          bufferization::ToTensorOp::create(rewriter, val.getLoc(),
                                            oldBbArgs[idx].getType(), val)
              .getResult());
    } else {
      result.push_back(val);
    }
  }
  return result;
}

/// Compute the bufferized type of a loop iter_arg. This type must be equal to
/// the bufferized type of the corresponding init_arg and the bufferized type
/// of the corresponding yielded value.
///
/// This function uses bufferization::getBufferType to compute the bufferized
/// type of the init_arg and of the yielded value. (The computation of the
/// bufferized yielded value type usually requires computing the bufferized
/// type of the iter_arg again; the implementation of getBufferType traces back
/// the use-def chain of the given value and computes a buffer type along the
/// way.) If both buffer types are equal, no casts are needed and the computed
/// buffer type can be used directly. Otherwise, the buffer types can only
/// differ in their layout map and a cast must be inserted.
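///
/// Illustrative example (hypothetical types): if the init_arg bufferizes to
/// memref<5xf32> (identity layout) but the yielded value bufferizes to
/// memref<5xf32, strided<[?], offset: ?>>, the iter_arg is assigned the fully
/// dynamic layout and a memref.cast is inserted for the init_arg (see
/// castBuffer).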
static FailureOr<BufferLikeType> computeLoopRegionIterArgBufferType(
    Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
    const BufferizationOptions &options, const BufferizationState &state,
    SmallVector<Value> &invocationStack) {
  // Determine the buffer type of the init_arg.
  auto initArgBufferType =
      bufferization::getBufferType(initArg, options, state, invocationStack);
  if (failed(initArgBufferType))
    return failure();

  if (llvm::count(invocationStack, iterArg) >= 2) {
    // If the iter_arg is already twice on the invocation stack, just take the
    // type of the init_arg. This is to avoid infinite loops when calculating
    // the buffer type. This will most likely result in computing a memref type
    // with a fully dynamic layout map.

    // Note: For more precise layout map computation, a fixpoint iteration
    // could be done (i.e., re-computing the yielded buffer type until the
    // bufferized iter_arg type no longer changes). This current implementation
    // immediately switches to a fully dynamic layout map when a mismatch
    // between bufferized init_arg type and bufferized yield value type is
    // detected.
    return *initArgBufferType;
  }

  // Compute the buffer type of the yielded value.
  BufferLikeType yieldedValueBufferType;
  if (isa<BaseMemRefType>(yieldedValue.getType())) {
    // scf.yield was already bufferized.
    yieldedValueBufferType = cast<BufferLikeType>(yieldedValue.getType());
  } else {
    // Note: This typically triggers a recursive call for the buffer type of
    // the iter_arg.
    auto maybeBufferType = bufferization::getBufferType(yieldedValue, options,
                                                        state, invocationStack);
    if (failed(maybeBufferType))
      return failure();
    yieldedValueBufferType = *maybeBufferType;
  }

  // If yielded type and init_arg type are the same, use that type directly.
  if (*initArgBufferType == yieldedValueBufferType)
    return yieldedValueBufferType;

  // If there is a mismatch between the yielded buffer type and the init_arg
  // buffer type, the buffer type must be promoted to a fully dynamic layout
  // map.
  auto yieldedBufferType = cast<BaseMemRefType>(yieldedValueBufferType);
  auto iterTensorType = cast<TensorType>(iterArg.getType());
  auto initBufferType = llvm::cast<BaseMemRefType>(*initArgBufferType);
  if (initBufferType.getMemorySpace() != yieldedBufferType.getMemorySpace())
    return loopOp->emitOpError(
        "init_arg and yielded value bufferize to inconsistent memory spaces");
#ifndef NDEBUG
  if (auto yieldedRankedBufferType = dyn_cast<MemRefType>(yieldedBufferType)) {
    assert(
        llvm::all_equal({yieldedRankedBufferType.getShape(),
                         cast<MemRefType>(initBufferType).getShape(),
                         cast<RankedTensorType>(iterTensorType).getShape()}) &&
        "expected same shape");
  }
#endif // NDEBUG
  return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
      iterTensorType, yieldedBufferType.getMemorySpace()));
}

/// Return `true` if the given loop may have 0 iterations.
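/// E.g., a loop from %c0 to %c0 has zero iterations; if either bound is not a
/// constant, conservatively return "true".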
bool mayHaveZeroIterations(scf::ForOp forOp) {
  std::optional<int64_t> lb = getConstantIntValue(forOp.getLowerBound());
  std::optional<int64_t> ub = getConstantIntValue(forOp.getUpperBound());
  if (!lb.has_value() || !ub.has_value())
    return true;
  return *ub <= *lb;
}

/// Bufferization of scf.for. Replace with a new scf.for that operates on
/// memrefs.
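///
/// Illustrative example (hypothetical IR; exact types depend on the
/// bufferization options):
///
///   %0 = scf.for %i = %lb to %ub step %s iter_args(%t = %init)
///       -> (tensor<5xf32>) { ... }
/// becomes
///   %m = scf.for %i = %lb to %ub step %s iter_args(%b = %init_buffer)
///       -> (memref<5xf32, strided<[?], offset: ?>>) { ... }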
struct ForOpInterface
    : public BufferizableOpInterface::ExternalModel<ForOpInterface,
                                                    scf::ForOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);

    // If the loop has zero iterations, the results of the op are their
    // corresponding init_args, meaning that the init_args bufferize to a read.
    if (mayHaveZeroIterations(forOp))
      return true;

    // scf::ForOp alone doesn't bufferize to a memory read, but one of the uses
    // of its matching bbArg may.
    return state.isValueRead(forOp.getTiedLoopRegionIterArg(&opOperand));
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::ForOps are always considered as a write.
    return true;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    OpResult opResult = forOp.getTiedLoopResult(&opOperand);
    BufferRelation relation = bufferRelation(op, opResult, state);
    return {{opResult, relation,
             /*isDefinite=*/relation == BufferRelation::Equivalent}};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // ForOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent.
    auto forOp = cast<scf::ForOp>(op);
    BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
    bool equivalentYield = state.areEquivalentBufferizedValues(
        bbArg, forOp.getTiedLoopYieldedValue(bbArg)->get());
    return equivalentYield ? BufferRelation::Equivalent
                           : BufferRelation::Unknown;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::ForOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    // 1. Either the matching iter operand is not bufferized inplace and an
    //    alloc + optional copy makes the bbArg itself inplaceable.
    // 2. Or the matching iter operand is bufferized inplace and bbArg just
    //    bufferizes to that too.
    return true;
  }

  LogicalResult
  resolveConflicts(Operation *op, RewriterBase &rewriter,
                   const AnalysisState &analysisState,
                   const BufferizationState &bufferizationState) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(
            rewriter, analysisState, bufferizationState)))
      return failure();

    if (analysisState.getOptions().copyBeforeWrite)
      return success();

    // According to the `getAliasing...` implementations, a bufferized
    // OpResult may alias only with the corresponding bufferized init_arg (or
    // with a newly allocated buffer) and not with other buffers defined
    // outside of the loop. I.e., the i-th OpResult may alias with the i-th
    // init_arg, but not with any other OpOperand.
    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(yieldOp);

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
    // For every yielded value, does it alias with something defined outside of
    // the loop?
    SmallVector<Value> yieldValues;
    for (const auto it : llvm::enumerate(yieldOp.getResults())) {
      // Note: `state` is guaranteed to be a `OneShotAnalysisState`, but this
      // type cannot be used in the signature of `resolveConflicts` because the
      // op interface is in the "IR" build unit and the `OneShotAnalysisState`
      // is defined in the "Transforms" build unit.
      if (!indices.contains(it.index()) ||
          doesNotAliasExternalValue(
              it.value(), &forOp.getRegion(),
              /*exceptions=*/forOp.getRegionIterArg(it.index()),
              static_cast<const OneShotAnalysisState &>(analysisState))) {
        yieldValues.push_back(it.value());
        continue;
      }
      FailureOr<Value> alloc = allocateTensorForShapedValue(
          rewriter, yieldOp.getLoc(), it.value(), analysisState.getOptions(),
          bufferizationState);
      if (failed(alloc))
        return failure();
      yieldValues.push_back(*alloc);
    }

    rewriter.modifyOpInPlace(
        yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); });
    return success();
  }

  FailureOr<BufferLikeType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto forOp = cast<scf::ForOp>(op);
    assert(getOwnerOfValue(value) == op && "invalid value");
    assert(isa<TensorType>(value.getType()) && "expected tensor type");

    if (auto opResult = dyn_cast<OpResult>(value)) {
      // The type of an OpResult must match the corresponding iter_arg type.
      BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
      return bufferization::getBufferType(bbArg, options, state,
                                          invocationStack);
    }

    // Compute result/argument number.
    BlockArgument bbArg = cast<BlockArgument>(value);
    unsigned resultNum = forOp.getTiedLoopResult(bbArg).getResultNumber();

    // Compute the bufferized type.
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    Value yieldedValue = yieldOp.getOperand(resultNum);
    BlockArgument iterArg = forOp.getRegionIterArgs()[resultNum];
    Value initArg = forOp.getInitArgs()[resultNum];
    return computeLoopRegionIterArgBufferType(
        op, iterArg, initArg, yieldedValue, options, state, invocationStack);
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto forOp = cast<scf::ForOp>(op);
    Block *oldLoopBody = forOp.getBody();

    // Indices of all iter_args that have tensor type. These are the ones that
    // are bufferized.
    DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, forOp.getInitArgsMutable(), options, state);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // Cast init_args if necessary.
    SmallVector<Value> castedInitArgs;
    for (const auto &it : llvm::enumerate(initArgs)) {
      Value initArg = it.value();
      Value result = forOp->getResult(it.index());
      // If the type is not a tensor, bufferization doesn't need to touch it.
      if (!isa<TensorType>(result.getType())) {
        castedInitArgs.push_back(initArg);
        continue;
      }
      auto targetType = bufferization::getBufferType(result, options, state);
      if (failed(targetType))
        return failure();
      castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
    }

    // Construct a new scf.for op with memref instead of tensor values.
    auto newForOp = scf::ForOp::create(
        rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), castedInitArgs, /*bodyBuilder=*/nullptr,
        forOp.getUnsignedCmp());
    newForOp->setAttrs(forOp->getAttrs());
    Block *loopBody = newForOp.getBody();

    // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
    // iter_args of the new loop in ToTensorOps.
    rewriter.setInsertionPointToStart(loopBody);
    SmallVector<Value> iterArgs =
        getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(),
                             forOp.getRegionIterArgs(), indices);
    iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());

    // Move loop body to new loop.
    rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());

    return success();
  }

  /// Assert that yielded values of an scf.for op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs + copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs-from-loops`.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocsFromLoops)
      return success();

    auto forOp = cast<scf::ForOp>(op);
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    for (OpResult opResult : op->getOpResults()) {
      if (!isa<TensorType>(opResult.getType()))
        continue;

      // Note: This is overly strict. We should check for aliasing bufferized
      // values. But we don't have a "must-alias" analysis yet.
      if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
        return yieldOp->emitError()
               << "Yield operand #" << opResult.getResultNumber()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// Bufferization of scf.while. Replace with a new scf.while that operates on
/// memrefs.
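///
/// Illustrative example (hypothetical IR): both regions keep their
/// signatures, but all tensor types become buffer types:
///
///   %0 = scf.while (%t = %init) : (tensor<5xf32>) -> tensor<5xf32> { ... }
/// becomes
///   %m = scf.while (%b = %init_buffer)
///       : (memref<5xf32, strided<[?], offset: ?>>)
///       -> memref<5xf32, strided<[?], offset: ?>> { ... }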
struct WhileOpInterface
    : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
                                                    scf::WhileOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a read.
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Tensor iter_args of scf::WhileOps are always considered as a write.
    return true;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    unsigned int idx = opOperand.getOperandNumber();

    // The OpResults and OpOperands may not match: their numbers can differ,
    // and the operand and result at the same index may have different types.
    if (idx >= op->getNumResults() ||
        opOperand.get().getType() != op->getResult(idx).getType())
      return {};

    // The only aliasing OpResult may be the one at the same index.
    OpResult opResult = whileOp->getResult(idx);
    BufferRelation relation = bufferRelation(op, opResult, state);
    return {{opResult, relation,
             /*isDefinite=*/relation == BufferRelation::Equivalent}};
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    // WhileOp results are equivalent to their corresponding init_args if the
    // corresponding iter_args and yield values are equivalent (for both the
    // "before" and the "after" block).
    unsigned int resultNumber = opResult.getResultNumber();
    auto whileOp = cast<scf::WhileOp>(op);

    // The "before" region bbArgs and the OpResults may not match.
    if (resultNumber >= whileOp.getBeforeArguments().size())
      return BufferRelation::Unknown;
    if (opResult.getType() !=
        whileOp.getBeforeArguments()[resultNumber].getType())
      return BufferRelation::Unknown;

    auto conditionOp = whileOp.getConditionOp();
    BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
    Value conditionOperand = conditionOp.getArgs()[resultNumber];
    bool equivCondition =
        state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);

    auto yieldOp = whileOp.getYieldOp();
    BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
    Value yieldOperand = yieldOp.getOperand(resultNumber);
    bool equivYield =
        state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);

    return equivCondition && equivYield ? BufferRelation::Equivalent
                                        : BufferRelation::Unknown;
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    // Interestingly, scf::WhileOp's bbArg can **always** be viewed
    // inplace from the perspective of ops nested under:
    // 1. Either the matching iter operand is not bufferized inplace and an
    //    alloc + optional copy makes the bbArg itself inplaceable.
    // 2. Or the matching iter operand is bufferized inplace and bbArg just
    //    bufferizes to that too.
    return true;
  }

  LogicalResult
  resolveConflicts(Operation *op, RewriterBase &rewriter,
                   const AnalysisState &analysisState,
                   const BufferizationState &bufferizationState) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(
            rewriter, analysisState, bufferizationState)))
      return failure();

    if (analysisState.getOptions().copyBeforeWrite)
      return success();

    // According to the `getAliasing...` implementations, a bufferized
    // OpResult may alias only with the corresponding bufferized init_arg and
    // with no other buffers. I.e., the i-th OpResult may alias with the i-th
    // init_arg, but not with any other OpOperand. If a corresponding
    // OpResult/init_arg pair bufferizes to equivalent buffers, this aliasing
    // requirement is satisfied. Otherwise, we cannot be sure and must yield a
    // new buffer copy. (New buffer copies do not alias with any buffer.)
    OpBuilder::InsertionGuard g(rewriter);
    auto whileOp = cast<scf::WhileOp>(op);
    auto conditionOp = whileOp.getConditionOp();

    // For every yielded value, is the value equivalent to its corresponding
    // bbArg?
    DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
        whileOp.getBeforeArguments(), conditionOp.getArgs(), analysisState);
    DenseSet<int64_t> equivalentYieldsAfter =
        getEquivalentBuffers(whileOp.getAfterArguments(),
                             whileOp.getYieldOp().getResults(), analysisState);

    // Update "before" region.
    rewriter.setInsertionPoint(conditionOp);
    SmallVector<Value> beforeYieldValues;
    for (int64_t idx = 0;
         idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
      Value value = conditionOp.getArgs()[idx];
      if (!isa<TensorType>(value.getType()) ||
          (equivalentYieldsAfter.contains(idx) &&
           equivalentYieldsBefore.contains(idx))) {
        beforeYieldValues.push_back(value);
        continue;
      }
      FailureOr<Value> alloc = allocateTensorForShapedValue(
          rewriter, conditionOp.getLoc(), value, analysisState.getOptions(),
          bufferizationState);
      if (failed(alloc))
        return failure();
      beforeYieldValues.push_back(*alloc);
    }
    rewriter.modifyOpInPlace(conditionOp, [&]() {
      conditionOp.getArgsMutable().assign(beforeYieldValues);
    });

    return success();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);

    // Indices of all bbArgs that have tensor type. These are the ones that
    // are bufferized. The "before" and "after" regions may have different
    // args.
    DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
    DenseSet<int64_t> indicesAfter =
        getTensorIndices(whileOp.getAfterArguments());

    // The new memref init_args of the loop.
    FailureOr<SmallVector<Value>> maybeInitArgs =
        getBuffers(rewriter, whileOp.getInitsMutable(), options, state);
    if (failed(maybeInitArgs))
      return failure();
    SmallVector<Value> initArgs = *maybeInitArgs;

    // Cast init_args if necessary.
    SmallVector<Value> castedInitArgs;
    for (const auto &it : llvm::enumerate(initArgs)) {
      Value initArg = it.value();
      Value beforeArg = whileOp.getBeforeArguments()[it.index()];
      // If the type is not a tensor, bufferization doesn't need to touch it.
      if (!isa<TensorType>(beforeArg.getType())) {
        castedInitArgs.push_back(initArg);
        continue;
      }
      auto targetType = bufferization::getBufferType(beforeArg, options, state);
      if (failed(targetType))
        return failure();
      castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
    }

    // The result types of a WhileOp are the same as the "after" bbArg types.
    SmallVector<Type> argsTypesAfter = llvm::to_vector(
        llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
          if (!isa<TensorType>(bbArg.getType()))
            return bbArg.getType();
          // TODO: error handling
          return llvm::cast<Type>(
              *bufferization::getBufferType(bbArg, options, state));
        }));

    // Construct a new scf.while op with memref instead of tensor values.
    ValueRange argsRangeBefore(castedInitArgs);
    TypeRange argsTypesBefore(argsRangeBefore);
    auto newWhileOp = scf::WhileOp::create(rewriter, whileOp.getLoc(),
                                           argsTypesAfter, castedInitArgs);

    // Add before/after regions to the new op.
    SmallVector<Location> bbArgLocsBefore(castedInitArgs.size(),
                                          whileOp.getLoc());
    SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
                                         whileOp.getLoc());
    Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
    newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore);
    Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
    newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter);

    // Set up new iter_args and move the loop condition block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newBeforeBody);
    SmallVector<Value> newBeforeArgs =
        getBbArgReplacements(rewriter, newWhileOp.getBeforeArguments(),
                             whileOp.getBeforeArguments(), indicesBefore);
    rewriter.mergeBlocks(whileOp.getBeforeBody(), newBeforeBody, newBeforeArgs);

    // Set up new iter_args and move the loop body block to the new op.
    // The old block uses tensors, so wrap the (memref) bbArgs of the new block
    // in ToTensorOps.
    rewriter.setInsertionPointToStart(newAfterBody);
    SmallVector<Value> newAfterArgs =
        getBbArgReplacements(rewriter, newWhileOp.getAfterArguments(),
                             whileOp.getAfterArguments(), indicesAfter);
    rewriter.mergeBlocks(whileOp.getAfterBody(), newAfterBody, newAfterArgs);

    // Replace loop results.
    replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());

    return success();
  }

  FailureOr<BufferLikeType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto whileOp = cast<scf::WhileOp>(op);
    assert(getOwnerOfValue(value) == op && "invalid value");
    assert(isa<TensorType>(value.getType()) && "expected tensor type");

    // Case 1: Block argument of the "before" region.
    if (auto bbArg = dyn_cast<BlockArgument>(value)) {
      if (bbArg.getOwner()->getParent() == &whileOp.getBefore()) {
        Value initArg = whileOp.getInits()[bbArg.getArgNumber()];
        auto yieldOp = whileOp.getYieldOp();
        Value yieldedValue = yieldOp.getOperand(bbArg.getArgNumber());
        return computeLoopRegionIterArgBufferType(
            op, bbArg, initArg, yieldedValue, options, state, invocationStack);
      }
    }

    // Case 2: OpResult of the loop or block argument of the "after" region.
    // The bufferized "after" bbArg type can be directly computed from the
    // bufferized "before" bbArg type.
    unsigned resultNum;
    if (auto opResult = dyn_cast<OpResult>(value)) {
      resultNum = opResult.getResultNumber();
    } else if (cast<BlockArgument>(value).getOwner()->getParent() ==
               &whileOp.getAfter()) {
      resultNum = cast<BlockArgument>(value).getArgNumber();
    } else {
      llvm_unreachable("invalid value");
    }
    Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
    if (!isa<TensorType>(conditionYieldedVal.getType())) {
      // scf.condition was already bufferized.
      return cast<BufferLikeType>(conditionYieldedVal.getType());
    }
    return bufferization::getBufferType(conditionYieldedVal, options, state,
                                        invocationStack);
  }

  /// Assert that yielded values of an scf.while op are equivalent to their
  /// corresponding bbArgs. In that case, the buffer relations of the
  /// corresponding OpResults are "Equivalent".
  ///
  /// If this is not the case, allocs + copies are inserted and yielded from
  /// the loop. This could be a performance problem, so it must be explicitly
  /// activated with `allow-return-allocs-from-loops`.
  ///
  /// Note: In contrast to scf::ForOp, scf::WhileOp has two regions and the
  /// equivalence condition must be checked for both.
  LogicalResult verifyAnalysis(Operation *op,
                               const AnalysisState &state) const {
    auto whileOp = cast<scf::WhileOp>(op);
    const auto &options =
        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    if (options.allowReturnAllocsFromLoops)
      return success();

    auto conditionOp = whileOp.getConditionOp();
    for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
      Block *block = conditionOp->getBlock();
      if (!isa<TensorType>(it.value().getType()))
        continue;
      if (it.index() >= block->getNumArguments() ||
          !state.areEquivalentBufferizedValues(it.value(),
                                               block->getArgument(it.index())))
        return conditionOp->emitError()
               << "Condition arg #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    auto yieldOp = whileOp.getYieldOp();
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Block *block = yieldOp->getBlock();
      if (!isa<TensorType>(it.value().getType()))
        continue;
      if (it.index() >= block->getNumArguments() ||
          !state.areEquivalentBufferizedValues(it.value(),
                                               block->getArgument(it.index())))
        return yieldOp->emitError()
               << "Yield operand #" << it.index()
               << " is not equivalent to the corresponding iter bbArg";
    }

    return success();
  }
};

/// Bufferization of scf.yield. Bufferized as part of its enclosing op, so this
/// interface is used for analysis only.
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    scf::YieldOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    if (isa<scf::IfOp>(op->getParentOp())) {
      return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
               BufferRelation::Equivalent, /*isDefinite=*/false}};
    }
    if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
      return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
               BufferRelation::Equivalent}};
    return {};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield
    // allocations when possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    auto yieldOp = cast<scf::YieldOp>(op);
    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::IndexSwitchOp, scf::ForOp,
             scf::WhileOp>(yieldOp->getParentOp()))
      return yieldOp->emitError("unsupported scf::YieldOp parent");

    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
      Value value = it.value();
      if (isa<TensorType>(value.getType())) {
        FailureOr<Value> maybeBuffer =
            getBuffer(rewriter, value, options, state);
        if (failed(maybeBuffer))
          return failure();
        Value buffer = *maybeBuffer;
        // We may have to cast the value before yielding it.
        if (isa<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>(
                yieldOp->getParentOp())) {
          FailureOr<BufferLikeType> resultType = bufferization::getBufferType(
              yieldOp->getParentOp()->getResult(it.index()), options, state);
          if (failed(resultType))
            return failure();
          buffer = castBuffer(rewriter, buffer, *resultType);
        } else if (auto whileOp =
                       dyn_cast<scf::WhileOp>(yieldOp->getParentOp())) {
          FailureOr<BufferLikeType> resultType = bufferization::getBufferType(
              whileOp.getBeforeArguments()[it.index()], options, state);
          if (failed(resultType))
            return failure();
          buffer = castBuffer(rewriter, buffer, *resultType);
        }
        newResults.push_back(buffer);
      } else {
        newResults.push_back(value);
      }
    }

    replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op, newResults);
    return success();
  }
};

/// Bufferization of ForallOp. This also bufferizes the terminator of the
/// region. There are op interfaces for the terminators (InParallelOp
/// and ParallelInsertSliceOp), but these are used only during analysis, not
/// for bufferization.
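///
/// Illustrative example (hypothetical IR): the bufferized op has no results;
/// all uses of the old results are replaced by the shared_outs buffers:
///
///   %0 = scf.forall (%i) in (4) shared_outs(%o = %t) -> (tensor<5xf32>) {...}
/// becomes
///   scf.forall (%i) in (4) { ... }  // body reads/writes %t's buffer directly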
struct ForallOpInterface
    : public BufferizableOpInterface::ExternalModel<ForallOpInterface,
                                                    ForallOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // All tensor operands to `scf.forall` are `shared_outs` and all
    // shared outs are assumed to be read by the loop. This does not
    // account for the case where the entire value is over-written,
    // but being conservative here.
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Outputs of scf::ForallOps are always considered as a write.
    return true;
  }

  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    auto forallOp = cast<ForallOp>(op);
    return {
        {{forallOp.getTiedOpResult(&opOperand), BufferRelation::Equivalent}}};
  }

  bool isWritable(Operation *op, Value value,
                  const AnalysisState &state) const {
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    OpBuilder::InsertionGuard guard(rewriter);
    auto forallOp = cast<ForallOp>(op);
    int64_t rank = forallOp.getRank();

    // Get buffers for all output operands.
    SmallVector<Value> buffers;
    for (Value out : forallOp.getOutputs()) {
      FailureOr<Value> buffer = getBuffer(rewriter, out, options, state);
      if (failed(buffer))
        return failure();
      buffers.push_back(*buffer);
    }

    // Use buffers instead of block arguments.
    rewriter.setInsertionPointToStart(forallOp.getBody());
    for (const auto &it : llvm::zip(
             forallOp.getBody()->getArguments().drop_front(rank), buffers)) {
      BlockArgument bbArg = std::get<0>(it);
      Value buffer = std::get<1>(it);
      Value bufferAsTensor = ToTensorOp::create(rewriter, forallOp.getLoc(),
                                                bbArg.getType(), buffer);
      bbArg.replaceAllUsesWith(bufferAsTensor);
    }

    // Create new ForallOp without any results and drop the automatically
    // introduced terminator.
    rewriter.setInsertionPoint(forallOp);
    auto newForallOp = ForallOp::create(
        rewriter, forallOp.getLoc(), forallOp.getMixedLowerBound(),
        forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
        /*outputs=*/ValueRange(), forallOp.getMapping());

    // Keep discardable attributes from the original op.
    newForallOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());

    rewriter.eraseOp(newForallOp.getBody()->getTerminator());

    // Move over block contents of the old op.
    SmallVector<Value> replacementBbArgs;
    replacementBbArgs.append(newForallOp.getBody()->getArguments().begin(),
                             newForallOp.getBody()->getArguments().end());
    replacementBbArgs.append(forallOp.getOutputs().size(), Value());
    rewriter.mergeBlocks(forallOp.getBody(), newForallOp.getBody(),
                         replacementBbArgs);

    // Remove the old op and replace all of its uses.
    replaceOpWithBufferizedValues(rewriter, op, buffers);

    return success();
  }

  FailureOr<BufferLikeType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto forallOp = cast<ForallOp>(op);

    if (auto bbArg = dyn_cast<BlockArgument>(value))
      // A tensor block argument has the same bufferized type as the
      // corresponding output operand.
      return bufferization::getBufferType(
          forallOp.getTiedOpOperand(bbArg)->get(), options, state,
          invocationStack);

    // The bufferized result type is the same as the bufferized type of the
    // corresponding output operand.
    return bufferization::getBufferType(
        forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options,
        state, invocationStack);
  }

  bool isRepetitiveRegion(Operation *op, unsigned index) const {
    auto forallOp = cast<ForallOp>(op);

    // The op is repetitive if any of its loop dimensions may execute more than
    // once. If any bound or step is dynamic, conservatively consider the op
    // repetitive.
    for (auto [lb, ub, step] :
         llvm::zip(forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
                   forallOp.getMixedStep())) {
      std::optional<int64_t> lbConstant = getConstantIntValue(lb);
      if (!lbConstant)
        return true;

      std::optional<int64_t> ubConstant = getConstantIntValue(ub);
      if (!ubConstant)
        return true;

      std::optional<int64_t> stepConstant = getConstantIntValue(step);
      if (!stepConstant)
        return true;

      if (*lbConstant + *stepConstant < *ubConstant)
        return true;
    }
    return false;
  }

  bool isParallelRegion(Operation *op, unsigned index) const {
    return isRepetitiveRegion(op, index);
  }
};

/// Nothing to do for InParallelOp.
struct InParallelOpInterface
    : public BufferizableOpInterface::ExternalModel<InParallelOpInterface,
                                                    InParallelOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &b,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    llvm_unreachable("op does not have any tensor OpOperands / OpResults");
    return failure();
  }
};

} // namespace
} // namespace scf
} // namespace mlir

void mlir::scf::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
    ConditionOp::attachInterface<ConditionOpInterface>(*ctx);
    ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
    ForOp::attachInterface<ForOpInterface>(*ctx);
    IfOp::attachInterface<IfOpInterface>(*ctx);
    IndexSwitchOp::attachInterface<IndexSwitchOpInterface>(*ctx);
    ForallOp::attachInterface<ForallOpInterface>(*ctx);
    InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
    WhileOp::attachInterface<WhileOpInterface>(*ctx);
    YieldOp::attachInterface<YieldOpInterface>(*ctx);
  });
}