1//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
19#include "mlir/IR/Dialect.h"
20#include "mlir/IR/Operation.h"
22#include "llvm/ADT/SmallVectorExtras.h"
23
24using namespace mlir;
25using namespace mlir::bufferization;
26using namespace mlir::scf;
27
28namespace mlir {
29namespace scf {
30namespace {
31
32/// Helper function for loop bufferization. Cast the given buffer to the given
33/// memref type.
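/// A cast is typically needed when the target type has a more general (e.g.,
/// fully dynamic) layout map than the source buffer; see the iter_arg handling
/// in the loop bufferizations below.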
34static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
35 // If the buffer already has the correct type, no cast is needed.
36 if (buffer.getType() == type)
37 return buffer;
38
39 // TODO: Properly support this via options; for now, a hardcoded MemRef-type
40 // based approach is used.
41 assert(isa<BaseMemRefType>(type) && "expected BaseMemRefType");
42 assert(isa<BaseMemRefType>(buffer.getType()) && "expected BaseMemRefType");
43 // TODO: In case `type` has a layout map that is not the fully dynamic
44 // one, we may not be able to cast the buffer. In that case, the loop
45 // iter_arg's layout map must be changed (see uses of `castBuffer`).
46 assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
47 "scf.while op bufferization: cast incompatible");
48 return memref::CastOp::create(b, buffer.getLoc(), type, buffer).getResult();
49}
50
51/// Helper function for loop bufferization. Return "true" if the given value
52/// is guaranteed to not alias with an external tensor apart from values in
53/// `exceptions`. A value is external if it is defined outside of the given
54/// region or if it is an entry block argument of the region.
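/// This is used when resolving conflicts for scf.for: a yielded tensor that
/// may alias such an external value (other than its own iter_arg) must be
/// yielded as a fresh allocation instead.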
55static bool doesNotAliasExternalValue(Value value, Region *region,
56 ValueRange exceptions,
57 const OneShotAnalysisState &state) {
58 assert(region->hasOneBlock() && "expected region with single block");
59 bool result = true;
60 state.applyOnAliases(value, [&](Value alias) {
61 if (llvm::is_contained(exceptions, alias))
62 return;
63 Region *aliasRegion = alias.getParentRegion();
64 if (isa<BlockArgument>(alias) && !region->isProperAncestor(aliasRegion))
65 result = false;
66 if (isa<OpResult>(alias) && !region->isAncestor(aliasRegion))
67 result = false;
68 });
69 return result;
70}
71
72/// Bufferization of scf.condition.
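/// The bufferized scf.condition forwards buffers to the "after" region of the
/// enclosing scf.while, so every tensor argument is replaced by its buffer,
/// cast to the bufferized type of the corresponding "after" block argument.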
73struct ConditionOpInterface
74 : public BufferizableOpInterface::ExternalModel<ConditionOpInterface,
75 scf::ConditionOp> {
76 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
77 const AnalysisState &state) const {
78 return true;
79 }
80
81 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
82 const AnalysisState &state) const {
83 return false;
84 }
85
86 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
87 const AnalysisState &state) const {
88 return {};
89 }
90
91 bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
92 const AnalysisState &state) const {
93 // Condition operands always bufferize inplace. Otherwise, an alloc + copy
94 // may be generated inside the block. Whenever possible, we should avoid
95 // returning/yielding allocations.
96 return true;
97 }
98
99 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
100 const BufferizationOptions &options,
101 BufferizationState &state) const {
102 auto conditionOp = cast<scf::ConditionOp>(op);
103 auto whileOp = cast<scf::WhileOp>(conditionOp->getParentOp());
104
105 SmallVector<Value> newArgs;
106 for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
107 Value value = it.value();
108 if (isa<TensorLikeType>(value.getType())) {
109 FailureOr<Value> maybeBuffer =
110 getBuffer(rewriter, value, options, state);
111 if (failed(maybeBuffer))
112 return failure();
113 FailureOr<BufferLikeType> resultType = bufferization::getBufferType(
114 whileOp.getAfterArguments()[it.index()], options, state);
115 if (failed(resultType))
116 return failure();
117 Value buffer = castBuffer(rewriter, *maybeBuffer, *resultType);
118 newArgs.push_back(buffer);
119 } else {
120 newArgs.push_back(value);
121 }
122 }
123
124 replaceOpWithNewBufferizedOp<scf::ConditionOp>(
125 rewriter, op, conditionOp.getCondition(), newArgs);
126 return success();
127 }
128};
129
130/// Return the unique scf.yield op. If there are multiple or no scf.yield ops,
131/// return a null op.
132static scf::YieldOp getUniqueYieldOp(scf::ExecuteRegionOp executeRegionOp) {
133 scf::YieldOp result;
134 for (Block &block : executeRegionOp.getRegion()) {
135 if (auto yieldOp = dyn_cast<scf::YieldOp>(block.getTerminator())) {
136 if (result)
137 return {};
138 result = yieldOp;
139 }
140 }
141 return result;
142}
143
144/// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
145/// fully implemented at the moment.
146struct ExecuteRegionOpInterface
147 : public OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel<
148 ExecuteRegionOpInterface, scf::ExecuteRegionOp> {
149
150 static bool supportsUnstructuredControlFlow() { return true; }
151
152 bool isWritable(Operation *op, Value value,
153 const AnalysisState &state) const {
154 return true;
155 }
156
157 LogicalResult verifyAnalysis(Operation *op,
158 const AnalysisState &state) const {
159 auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
160 // TODO: scf.execute_region with multiple yields are not supported.
161 if (!getUniqueYieldOp(executeRegionOp))
162 return op->emitOpError("op without unique scf.yield is not supported");
163 return success();
164 }
165
166 AliasingOpOperandList
167 getAliasingOpOperands(Operation *op, Value value,
168 const AnalysisState &state) const {
169 if (auto bbArg = dyn_cast<BlockArgument>(value))
170 return getAliasingBranchOpOperands(op, bbArg, state);
171
172 // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
173 // any SSA value that is in scope. To allow for use-def chain traversal
174 // through ExecuteRegionOps in the analysis, the corresponding yield value
175 // is considered to be aliasing with the result.
176 auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
177 auto it = llvm::find(op->getOpResults(), value);
178 assert(it != op->getOpResults().end() && "invalid value");
179 size_t resultNum = std::distance(op->getOpResults().begin(), it);
180 auto yieldOp = getUniqueYieldOp(executeRegionOp);
181 // Note: If there is no unique scf.yield op, `verifyAnalysis` will fail.
182 if (!yieldOp)
183 return {};
184 return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
185 }
186
187 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
188 const BufferizationOptions &options,
189 BufferizationState &state) const {
190 auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
191 auto yieldOp = getUniqueYieldOp(executeRegionOp);
192 TypeRange newResultTypes(yieldOp.getResults());
193
194 // Create new op and move over region.
195 auto newOp = scf::ExecuteRegionOp::create(
196 rewriter, op->getLoc(), newResultTypes, executeRegionOp.getNoInline());
197 newOp.getRegion().takeBody(executeRegionOp.getRegion());
198
199 // Bufferize every block.
200 for (Block &block : newOp.getRegion())
201 if (failed(bufferizeBlockSignature(&block, rewriter,
202 options, state)))
203 return failure();
204
205 // Update all uses of the old op.
206 rewriter.setInsertionPointAfter(newOp);
207 SmallVector<Value> newResults;
208 for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
209 if (isa<TensorType>(it.value())) {
210 newResults.push_back(bufferization::ToTensorOp::create(
211 rewriter, executeRegionOp.getLoc(), it.value(),
212 newOp->getResult(it.index())));
213 } else {
214 newResults.push_back(newOp->getResult(it.index()));
215 }
216 }
217
218 // Replace old op.
219 rewriter.replaceOp(executeRegionOp, newResults);
220
221 return success();
222 }
223};
224
225/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
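///
/// For illustration (operand names and types are made up), an op such as
///   %0 = scf.if %cond -> (tensor<?xf32>) {
///     scf.yield %t0 : tensor<?xf32>
///   } else {
///     scf.yield %t1 : tensor<?xf32>
///   }
/// bufferizes to an scf.if that yields a memref, e.g.
///   %0 = scf.if %cond -> (memref<?xf32, strided<[?], offset: ?>>) { ... }
/// If the two branches yield memrefs with different layout maps, the result
/// type is promoted to a fully dynamic layout (see getBufferType below).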
226struct IfOpInterface
227 : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
228 AliasingOpOperandList
229 getAliasingOpOperands(Operation *op, Value value,
230 const AnalysisState &state) const {
231 // IfOps do not have tensor OpOperands. The yielded value can be any SSA
232 // value that is in scope. To allow for use-def chain traversal through
233 // IfOps in the analysis, both corresponding yield values from the then/else
234 // branches are considered to be aliasing with the result.
235 auto ifOp = cast<scf::IfOp>(op);
236 size_t resultNum = std::distance(op->getOpResults().begin(),
237 llvm::find(op->getOpResults(), value));
238 OpOperand *thenOperand = &ifOp.thenYield()->getOpOperand(resultNum);
239 OpOperand *elseOperand = &ifOp.elseYield()->getOpOperand(resultNum);
240 return {{thenOperand, BufferRelation::Equivalent, /*isDefinite=*/false},
241 {elseOperand, BufferRelation::Equivalent, /*isDefinite=*/false}};
242 }
243
244 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
245 const BufferizationOptions &options,
246 BufferizationState &state) const {
247 OpBuilder::InsertionGuard g(rewriter);
248 auto ifOp = cast<scf::IfOp>(op);
249
250 // Compute bufferized result types.
251 SmallVector<Type> newTypes;
252 for (Value result : ifOp.getResults()) {
253 if (!isa<TensorLikeType>(result.getType())) {
254 newTypes.push_back(result.getType());
255 continue;
256 }
257 auto bufferType = bufferization::getBufferType(result, options, state);
258 if (failed(bufferType))
259 return failure();
260 newTypes.push_back(*bufferType);
261 }
262
263 // Create new op.
264 rewriter.setInsertionPoint(ifOp);
265 auto newIfOp = scf::IfOp::create(rewriter, ifOp.getLoc(), newTypes,
266 ifOp.getCondition(),
267 /*withElseRegion=*/true);
268
269 // Move over then/else blocks.
270 rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
271 rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());
272
273 // Replace op results.
274 replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());
275
276 return success();
277 }
278
279 FailureOr<BufferLikeType>
280 getBufferType(Operation *op, Value value, const BufferizationOptions &options,
281 const BufferizationState &state,
282 SmallVector<Value> &invocationStack) const {
283 auto ifOp = cast<scf::IfOp>(op);
284 auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
285 auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());
286 assert(value.getDefiningOp() == op && "invalid value");
287
288 // Determine buffer types of the true/false branches.
289 auto opResult = cast<OpResult>(value);
290 auto thenValue = thenYieldOp.getOperand(opResult.getResultNumber());
291 auto elseValue = elseYieldOp.getOperand(opResult.getResultNumber());
292 BufferLikeType thenBufferType, elseBufferType;
293 if (isa<BufferLikeType>(thenValue.getType())) {
294 // True branch was already bufferized.
295 thenBufferType = cast<BufferLikeType>(thenValue.getType());
296 } else {
297 auto maybeBufferType = bufferization::getBufferType(
298 thenValue, options, state, invocationStack);
299 if (failed(maybeBufferType))
300 return failure();
301 thenBufferType = *maybeBufferType;
302 }
303 if (isa<BufferLikeType>(elseValue.getType())) {
304 // False branch was already bufferized.
305 elseBufferType = cast<BufferLikeType>(elseValue.getType());
306 } else {
307 auto maybeBufferType = bufferization::getBufferType(
308 elseValue, options, state, invocationStack);
309 if (failed(maybeBufferType))
310 return failure();
311 elseBufferType = *maybeBufferType;
312 }
313
314 // Best case: Both branches have the exact same buffer type.
315 if (thenBufferType == elseBufferType)
316 return cast<BufferLikeType>(thenBufferType);
317
318 // Memory space mismatch.
319 auto thenBaseMemRefType = dyn_cast<BaseMemRefType>(thenBufferType);
320 auto elseBaseMemRefType = dyn_cast<BaseMemRefType>(elseBufferType);
321 if (thenBaseMemRefType && elseBaseMemRefType &&
322 thenBaseMemRefType.getMemorySpace() !=
323 elseBaseMemRefType.getMemorySpace())
324 return op->emitError("inconsistent memory space on then/else branches");
325
326 // TODO: Properly support this via options; for now, a hardcoded MemRef-type
327 // based approach is used. The layout maps differ: promote to a fully dynamic
328 // layout map.
329 return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
330 cast<TensorType>(opResult.getType()),
331 thenBaseMemRefType.getMemorySpace()));
332 }
333};
334
335/// Bufferization of scf.index_switch. Replace with a new scf.index_switch that
336/// yields memrefs.
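/// All cases (including the default case) must bufferize to the same memory
/// space; differing layout maps are reconciled by promoting the result type to
/// a fully dynamic layout, analogous to scf.if.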
337struct IndexSwitchOpInterface
338 : public BufferizableOpInterface::ExternalModel<IndexSwitchOpInterface,
339 scf::IndexSwitchOp> {
340 AliasingOpOperandList
341 getAliasingOpOperands(Operation *op, Value value,
342 const AnalysisState &state) const {
343 // IndexSwitchOps do not have tensor OpOperands. The yielded value can be
344 // any SSA value that is in scope. This is similar to IfOps.
345 auto switchOp = cast<scf::IndexSwitchOp>(op);
346 int64_t resultNum = cast<OpResult>(value).getResultNumber();
347 AliasingOpOperandList result;
348 for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
349 auto yieldOp =
350 cast<scf::YieldOp>(switchOp.getCaseBlock(i).getTerminator());
351 result.addAlias(AliasingOpOperand(&yieldOp->getOpOperand(resultNum),
352 BufferRelation::Equivalent,
353 /*isDefinite=*/false));
354 }
355 auto defaultYieldOp =
356 cast<scf::YieldOp>(switchOp.getDefaultBlock().getTerminator());
357 result.addAlias(AliasingOpOperand(&defaultYieldOp->getOpOperand(resultNum),
358 BufferRelation::Equivalent,
359 /*isDefinite=*/false));
360 return result;
361 }
362
363 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
364 const BufferizationOptions &options,
365 BufferizationState &state) const {
366 OpBuilder::InsertionGuard g(rewriter);
367 auto switchOp = cast<scf::IndexSwitchOp>(op);
368
369 // Compute bufferized result types.
370 SmallVector<Type> newTypes;
371 for (Value result : switchOp.getResults()) {
372 if (!isa<TensorType>(result.getType())) {
373 newTypes.push_back(result.getType());
374 continue;
375 }
376 auto bufferType = bufferization::getBufferType(result, options, state);
377 if (failed(bufferType))
378 return failure();
379 newTypes.push_back(*bufferType);
380 }
381
382 // Create new op.
383 rewriter.setInsertionPoint(switchOp);
384 auto newSwitchOp = scf::IndexSwitchOp::create(
385 rewriter, switchOp.getLoc(), newTypes, switchOp.getArg(),
386 switchOp.getCases(), switchOp.getCases().size());
387
388 // Move over blocks.
389 for (auto [src, dest] :
390 llvm::zip(switchOp.getCaseRegions(), newSwitchOp.getCaseRegions()))
391 rewriter.inlineRegionBefore(src, dest, dest.begin());
392 rewriter.inlineRegionBefore(switchOp.getDefaultRegion(),
393 newSwitchOp.getDefaultRegion(),
394 newSwitchOp.getDefaultRegion().begin());
395
396 // Replace op results.
397 replaceOpWithBufferizedValues(rewriter, op, newSwitchOp->getResults());
398
399 return success();
400 }
401
402 FailureOr<BufferLikeType>
403 getBufferType(Operation *op, Value value, const BufferizationOptions &options,
404 const BufferizationState &state,
405 SmallVector<Value> &invocationStack) const {
406 auto switchOp = cast<scf::IndexSwitchOp>(op);
407 assert(value.getDefiningOp() == op && "invalid value");
408 int64_t resultNum = cast<OpResult>(value).getResultNumber();
409
410 // TODO: Properly support this via options; for now, a hardcoded MemRef-type
411 // based approach is used. Helper function to get the buffer type of a case.
412 auto getYieldedBufferType = [&](Block &b) -> FailureOr<BaseMemRefType> {
413 auto yieldOp = cast<scf::YieldOp>(b.getTerminator());
414 Value yieldedValue = yieldOp->getOperand(resultNum);
415 if (auto bufferType = dyn_cast<BaseMemRefType>(yieldedValue.getType()))
416 return bufferType;
417 auto maybeBufferType = bufferization::getBufferType(
418 yieldedValue, options, state, invocationStack);
419 return bufferization::detail::asMemRefType(maybeBufferType);
420 };
421
422 // Compute buffer type of the default case.
423 auto maybeBufferType = getYieldedBufferType(switchOp.getDefaultBlock());
424 if (failed(maybeBufferType))
425 return failure();
426 BaseMemRefType bufferType = *maybeBufferType;
427
428 // Compute buffer types of all other cases.
429 for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) {
430 auto yieldedBufferType = getYieldedBufferType(switchOp.getCaseBlock(i));
431 if (failed(yieldedBufferType))
432 return failure();
433
434 // Best case: Both branches have the exact same buffer type.
435 if (bufferType == *yieldedBufferType)
436 continue;
437
438 // Memory space mismatch.
439 if (bufferType.getMemorySpace() != yieldedBufferType->getMemorySpace())
440 return op->emitError("inconsistent memory space on switch cases");
441
442 // TODO: Properly support this via options; for now, a hardcoded MemRef-type
443 // based approach is used. The layout maps differ: promote to a fully dynamic
444 // layout map.
445 bufferType = getMemRefTypeWithFullyDynamicLayout(
446 cast<TensorType>(value.getType()), bufferType.getMemorySpace());
447 }
448
449 return cast<BufferLikeType>(bufferType);
450 }
451};
452
453/// Helper function for loop bufferization. Return the indices of all values
454/// that have a tensor type.
455 static DenseSet<int64_t> getTensorIndices(ValueRange values) {
456 DenseSet<int64_t> result;
457 for (const auto &it : llvm::enumerate(values))
458 if (isa<TensorLikeType>(it.value().getType()))
459 result.insert(it.index());
460 return result;
461}
462
463/// Helper function for loop bufferization. Return the indices of all
464/// bbArg/yielded value pairs whose buffer relation is "Equivalent".
465DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
466 ValueRange yieldedValues,
467 const AnalysisState &state) {
468 unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
469 DenseSet<int64_t> result;
470 for (unsigned int i = 0; i < minSize; ++i) {
471 if (!isa<TensorLikeType>(bbArgs[i].getType()) ||
472 !isa<TensorLikeType>(yieldedValues[i].getType()))
473 continue;
474 if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
475 result.insert(i);
476 }
477 return result;
478}
479
480/// Helper function for loop bufferization. Return the bufferized values of the
481/// given OpOperands. If an operand is not a tensor, return the original value.
482static FailureOr<SmallVector<Value>>
483getBuffers(RewriterBase &rewriter, const MutableOperandRange &operands,
484 const BufferizationOptions &options, BufferizationState &state) {
485 SmallVector<Value> result;
486 for (OpOperand &opOperand : operands) {
487 if (isa<TensorLikeType>(opOperand.get().getType())) {
488 FailureOr<Value> resultBuffer =
489 getBuffer(rewriter, opOperand.get(), options, state);
490 if (failed(resultBuffer))
491 return failure();
492 result.push_back(*resultBuffer);
493 } else {
494 result.push_back(opOperand.get());
495 }
496 }
497 return result;
498}
499
500/// Helper function for loop bufferization. Given a list of bbArgs of the new
501/// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into
502/// ToTensorOps, so that the block body can be moved over to the new op.
503static SmallVector<Value>
504getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
505 Block::BlockArgListType oldBbArgs,
506 const DenseSet<int64_t> &tensorIndices) {
507 SmallVector<Value> result;
508 for (const auto &it : llvm::enumerate(bbArgs)) {
509 size_t idx = it.index();
510 Value val = it.value();
511 if (tensorIndices.contains(idx)) {
512 result.push_back(
513 bufferization::ToTensorOp::create(rewriter, val.getLoc(),
514 oldBbArgs[idx].getType(), val)
515 .getResult());
516 } else {
517 result.push_back(val);
518 }
519 }
520 return result;
521}
522
523/// Compute the bufferized type of a loop iter_arg. This type must be equal to
524/// the bufferized type of the corresponding init_arg and the bufferized type
525/// of the corresponding yielded value.
526///
527/// This function uses bufferization::getBufferType to compute the bufferized
528/// type of the init_arg and of the yielded value. (The computation of the
529/// bufferized yielded value type usually requires computing the bufferized type
530/// of the iter_arg again; the implementation of getBufferType traces back the
531/// use-def chain of the given value and computes a buffer type along the way.)
532/// If both buffer types are equal, no casts are needed and the computed buffer
533/// type can be used directly. Otherwise, the buffer types can only differ in
534/// their layout map and a cast must be inserted.
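///
/// For example (illustrative), if the init_arg bufferizes to a memref with an
/// identity layout but the yielded value bufferizes to a memref with a strided
/// layout, the iter_arg is assigned a memref type with a fully dynamic layout
/// map and the callers cast the init_arg buffer accordingly.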
535static FailureOr<BufferLikeType> computeLoopRegionIterArgBufferType(
536 Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
537 const BufferizationOptions &options, const BufferizationState &state,
538 SmallVector<Value> &invocationStack) {
539 // Determine the buffer type of the init_arg.
540 auto initArgBufferType =
541 bufferization::getBufferType(initArg, options, state, invocationStack);
542 if (failed(initArgBufferType))
543 return failure();
544
545 if (llvm::count(invocationStack, iterArg) >= 2) {
546 // If the iter_arg is already twice on the invocation stack, just take the
547 // type of the init_arg. This is to avoid infinite loops when calculating
548 // the buffer type. This will most likely result in computing a memref type
549 // with a fully dynamic layout map.
550
551 // Note: For more precise layout map computation, a fixpoint iteration could
552 // be done (i.e., re-computing the yielded buffer type until the bufferized
553 // iter_arg type no longer changes). This current implementation immediately
554 // switches to a fully dynamic layout map when a mismatch between bufferized
555 // init_arg type and bufferized yield value type is detected.
556 return *initArgBufferType;
557 }
558
559 // Compute the buffer type of the yielded value.
560 BufferLikeType yieldedValueBufferType;
561 if (isa<BufferLikeType>(yieldedValue.getType())) {
562 // scf.yield was already bufferized.
563 yieldedValueBufferType = cast<BufferLikeType>(yieldedValue.getType());
564 } else {
565 // Note: This typically triggers a recursive call for the buffer type of
566 // the iter_arg.
567 auto maybeBufferType = bufferization::getBufferType(yieldedValue, options,
568 state, invocationStack);
569 if (failed(maybeBufferType))
570 return failure();
571 yieldedValueBufferType = *maybeBufferType;
572 }
573
574 // If yielded type and init_arg type are the same, use that type directly.
575 if (*initArgBufferType == yieldedValueBufferType)
576 return yieldedValueBufferType;
577
578 // If there is a mismatch between the yielded buffer type and the init_arg
579 // buffer type, the buffer type must be promoted to a fully dynamic layout
580 // map.
581 auto yieldedBufferType = cast<BaseMemRefType>(yieldedValueBufferType);
582 auto iterTensorType = cast<TensorType>(iterArg.getType());
583 auto initBufferType = llvm::cast<BaseMemRefType>(*initArgBufferType);
584 if (initBufferType.getMemorySpace() != yieldedBufferType.getMemorySpace())
585 return loopOp->emitOpError(
586 "init_arg and yielded value bufferize to inconsistent memory spaces");
587#ifndef NDEBUG
588 if (auto yieldedRankedBufferType = dyn_cast<MemRefType>(yieldedBufferType)) {
589 assert(
590 llvm::all_equal({yieldedRankedBufferType.getShape(),
591 cast<MemRefType>(initBufferType).getShape(),
592 cast<RankedTensorType>(iterTensorType).getShape()}) &&
593 "expected same shape");
594 }
595#endif // NDEBUG
596 // TODO: Properly support this via options; for now, a hardcoded MemRef-type
597 // based approach is used.
598 return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
599 iterTensorType, yieldedBufferType.getMemorySpace()));
600}
601
602/// Return `true` if the given loop may have 0 iterations.
603bool mayHaveZeroIterations(scf::ForOp forOp) {
604 std::optional<int64_t> lb = getConstantIntValue(forOp.getLowerBound());
605 std::optional<int64_t> ub = getConstantIntValue(forOp.getUpperBound());
606 if (!lb.has_value() || !ub.has_value())
607 return true;
608 return *ub <= *lb;
609}
610
611/// Bufferization of scf.for. Replace with a new scf.for that operates on
612/// memrefs.
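///
/// For illustration (operand names and types are made up), a loop such as
///   %r = scf.for %iv = %lb to %ub step %s iter_args(%a = %t) -> (tensor<?xf32>)
/// becomes an scf.for over memref iter_args, e.g.
///   %r = scf.for %iv = %lb to %ub step %s iter_args(%a = %m)
///       -> (memref<?xf32, strided<[?], offset: ?>>)
/// and bufferization.to_tensor ops are inserted for the new bbArgs so that the
/// (still tensor-based) loop body can be moved over unchanged.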
613struct ForOpInterface
614 : public BufferizableOpInterface::ExternalModel<ForOpInterface,
615 scf::ForOp> {
616 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
617 const AnalysisState &state) const {
618 auto forOp = cast<scf::ForOp>(op);
619
620 // If the loop has zero iterations, the results of the op are their
621 // corresponding init_args, meaning that the init_args bufferize to a read.
622 if (mayHaveZeroIterations(forOp))
623 return true;
624
625 // scf::ForOp alone doesn't bufferize to a memory read, one of the uses of
626 // its matching bbArg may.
627 return state.isValueRead(forOp.getTiedLoopRegionIterArg(&opOperand));
628 }
629
630 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
631 const AnalysisState &state) const {
632 // Tensor iter_args of scf::ForOps are always considered as a write.
633 return true;
634 }
635
636 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
637 const AnalysisState &state) const {
638 auto forOp = cast<scf::ForOp>(op);
639 OpResult opResult = forOp.getTiedLoopResult(&opOperand);
640 BufferRelation relation = bufferRelation(op, opResult, state);
641 return {{opResult, relation,
642 /*isDefinite=*/relation == BufferRelation::Equivalent}};
643 }
644
645 BufferRelation bufferRelation(Operation *op, OpResult opResult,
646 const AnalysisState &state) const {
647 // ForOp results are equivalent to their corresponding init_args if the
648 // corresponding iter_args and yield values are equivalent.
649 auto forOp = cast<scf::ForOp>(op);
650 BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
651 bool equivalentYield = state.areEquivalentBufferizedValues(
652 bbArg, forOp.getTiedLoopYieldedValue(bbArg)->get());
653 return equivalentYield ? BufferRelation::Equivalent
654 : BufferRelation::Unknown;
655 }
656
657 bool isWritable(Operation *op, Value value,
658 const AnalysisState &state) const {
659 // Interestingly, scf::ForOp's bbArg can **always** be viewed
660 // inplace from the perspective of ops nested under:
661 // 1. Either the matching iter operand is not bufferized inplace and an
662 // alloc + optional copy makes the bbArg itself inplaceable.
663 // 2. Or the matching iter operand is bufferized inplace and bbArg just
664 // bufferizes to that too.
665 return true;
666 }
667
668 LogicalResult
669 resolveConflicts(Operation *op, RewriterBase &rewriter,
670 const AnalysisState &analysisState,
671 const BufferizationState &bufferizationState) const {
672 auto bufferizableOp = cast<BufferizableOpInterface>(op);
673 if (failed(bufferizableOp.resolveTensorOpOperandConflicts(
674 rewriter, analysisState, bufferizationState)))
675 return failure();
676
677 if (analysisState.getOptions().copyBeforeWrite)
678 return success();
679
680 // According to the `getAliasing...` implementations, a bufferized OpResult
681 // may alias only with the corresponding bufferized init_arg (or with a
682 // newly allocated buffer) and not with other buffers defined outside of the
683 // loop. I.e., the i-th OpResult may alias with the i-th init_arg;
684 // but not with any other OpOperand.
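// For example, if a yielded value may alias a tensor defined above the loop
// (other than its own iter_arg), a fresh allocation is yielded in its place so
// that the aliasing requirement stated above holds.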
685 auto forOp = cast<scf::ForOp>(op);
686 auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
687 OpBuilder::InsertionGuard g(rewriter);
688 rewriter.setInsertionPoint(yieldOp);
689
690 // Indices of all iter_args that have tensor type. These are the ones that
691 // are bufferized.
692 DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
693 // For every yielded value, does it alias with something defined outside of
694 // the loop?
695 SmallVector<Value> yieldValues;
696 for (const auto it : llvm::enumerate(yieldOp.getResults())) {
697 // Note: `state` is guaranteed to be a `OneShotAnalysisState`, but this
698 // type cannot be used in the signature of `resolveConflicts` because the
699 // op interface is in the "IR" build unit and the `OneShotAnalysisState`
700 // is defined in the "Transforms" build unit.
701 if (!indices.contains(it.index()) ||
702 doesNotAliasExternalValue(
703 it.value(), &forOp.getRegion(),
704 /*exceptions=*/forOp.getRegionIterArg(it.index()),
705 static_cast<const OneShotAnalysisState &>(analysisState))) {
706 yieldValues.push_back(it.value());
707 continue;
708 }
709 FailureOr<Value> alloc = allocateTensorForShapedValue(
710 rewriter, yieldOp.getLoc(), it.value(), analysisState.getOptions(),
711 bufferizationState);
712 if (failed(alloc))
713 return failure();
714 yieldValues.push_back(*alloc);
715 }
716
717 rewriter.modifyOpInPlace(
718 yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); });
719 return success();
720 }
721
722 FailureOr<BufferLikeType>
723 getBufferType(Operation *op, Value value, const BufferizationOptions &options,
724 const BufferizationState &state,
725 SmallVector<Value> &invocationStack) const {
726 auto forOp = cast<scf::ForOp>(op);
727 assert(getOwnerOfValue(value) == op && "invalid value");
728 assert(isa<TensorLikeType>(value.getType()) && "expected tensor type");
729
730 if (auto opResult = dyn_cast<OpResult>(value)) {
731 // The type of an OpResult must match the corresponding iter_arg type.
732 BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
733 return bufferization::getBufferType(bbArg, options, state,
734 invocationStack);
735 }
736
737 // Compute result/argument number.
738 BlockArgument bbArg = cast<BlockArgument>(value);
739 unsigned resultNum = forOp.getTiedLoopResult(bbArg).getResultNumber();
740
741 // Compute the bufferized type.
742 auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
743 Value yieldedValue = yieldOp.getOperand(resultNum);
744 BlockArgument iterArg = forOp.getRegionIterArgs()[resultNum];
745 Value initArg = forOp.getInitArgs()[resultNum];
746 return computeLoopRegionIterArgBufferType(
747 op, iterArg, initArg, yieldedValue, options, state, invocationStack);
748 }
749
750 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
751 const BufferizationOptions &options,
752 BufferizationState &state) const {
753 auto forOp = cast<scf::ForOp>(op);
754 Block *oldLoopBody = forOp.getBody();
755
756 // Indices of all iter_args that have tensor type. These are the ones that
757 // are bufferized.
758 DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
759
760 // The new memref init_args of the loop.
761 FailureOr<SmallVector<Value>> maybeInitArgs =
762 getBuffers(rewriter, forOp.getInitArgsMutable(), options, state);
763 if (failed(maybeInitArgs))
764 return failure();
765 SmallVector<Value> initArgs = *maybeInitArgs;
766
767 // Cast init_args if necessary.
768 SmallVector<Value> castedInitArgs;
769 for (const auto &it : llvm::enumerate(initArgs)) {
770 Value initArg = it.value();
771 Value result = forOp->getResult(it.index());
772 // If the type is not a tensor, bufferization doesn't need to touch it.
773 if (!isa<TensorLikeType>(result.getType())) {
774 castedInitArgs.push_back(initArg);
775 continue;
776 }
777 auto targetType = bufferization::getBufferType(result, options, state);
778 if (failed(targetType))
779 return failure();
780 castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
781 }
782
783 // Construct a new scf.for op with memref instead of tensor values.
784 auto newForOp = scf::ForOp::create(
785 rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
786 forOp.getStep(), castedInitArgs, /*bodyBuilder=*/nullptr,
787 forOp.getUnsignedCmp());
788 newForOp->setAttrs(forOp->getAttrs());
789 Block *loopBody = newForOp.getBody();
790
791 // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
792 // iter_args of the new loop in ToTensorOps.
793 rewriter.setInsertionPointToStart(loopBody);
794 SmallVector<Value> iterArgs =
795 getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(),
796 forOp.getRegionIterArgs(), indices);
797 iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());
798
799 // Move loop body to new loop.
800 rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);
801
802 // Replace loop results.
803 replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());
804
805 return success();
806 }
807
808 /// Assert that yielded values of an scf.for op are equivalent to their
809 /// corresponding bbArgs. In that case, the buffer relations of the
810 /// corresponding OpResults are "Equivalent".
811 ///
812/// If this is not the case, allocs + copies are inserted and yielded from
813/// the loop. This could be a performance problem, so it must be explicitly
814/// activated with `allow-return-allocs`.
815 LogicalResult verifyAnalysis(Operation *op,
816 const AnalysisState &state) const {
817 const auto &options =
818 static_cast<const OneShotBufferizationOptions &>(state.getOptions());
819 if (options.allowReturnAllocsFromLoops)
820 return success();
821
822 auto forOp = cast<scf::ForOp>(op);
823 auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
824 for (OpResult opResult : op->getOpResults()) {
825 if (!isa<TensorLikeType>(opResult.getType()))
826 continue;
827 // Note: This is overly strict. We should check for aliasing bufferized
828 // values. But we don't have a "must-alias" analysis yet.
829 if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
830 return yieldOp->emitError()
831 << "Yield operand #" << opResult.getResultNumber()
832 << " is not equivalent to the corresponding iter bbArg";
833 }
834
835 return success();
836 }
837};
838
839/// Bufferization of scf.while. Replace with a new scf.while that operates on
840/// memrefs.
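///
/// Both regions are bufferized: the "before" bbArgs receive the bufferized
/// types of the init args, while the "after" bbArgs (and the op results)
/// receive the types of the buffers yielded by the bufferized scf.condition.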
841struct WhileOpInterface
842 : public BufferizableOpInterface::ExternalModel<WhileOpInterface,
843 scf::WhileOp> {
844 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
845 const AnalysisState &state) const {
846 // Tensor iter_args of scf::WhileOps are always considered as a read.
847 return true;
848 }
849
850 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
851 const AnalysisState &state) const {
852 // Tensor iter_args of scf::WhileOps are always considered as a write.
853 return true;
854 }
855
856 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
857 const AnalysisState &state) const {
858 auto whileOp = cast<scf::WhileOp>(op);
859 unsigned int idx = opOperand.getOperandNumber();
860
861 // The OpResults and OpOperands may not match. They may not even have the
862 // same type. The number of OpResults and OpOperands can also differ.
863 if (idx >= op->getNumResults() ||
864 opOperand.get().getType() != op->getResult(idx).getType())
865 return {};
866
867 // The only aliasing OpResult may be the one at the same index.
868 OpResult opResult = whileOp->getResult(idx);
869 BufferRelation relation = bufferRelation(op, opResult, state);
870 return {{opResult, relation,
871 /*isDefinite=*/relation == BufferRelation::Equivalent}};
872 }
873
874 BufferRelation bufferRelation(Operation *op, OpResult opResult,
875 const AnalysisState &state) const {
876 // WhileOp results are equivalent to their corresponding init_args if the
877 // corresponding iter_args and yield values are equivalent (for both the
878 // "before" and the "after" block).
879 unsigned int resultNumber = opResult.getResultNumber();
880 auto whileOp = cast<scf::WhileOp>(op);
881
882 // The "before" region bbArgs and the OpResults may not match.
883 if (resultNumber >= whileOp.getBeforeArguments().size())
884 return BufferRelation::Unknown;
885 if (opResult.getType() !=
886 whileOp.getBeforeArguments()[resultNumber].getType())
887 return BufferRelation::Unknown;
888
889 auto conditionOp = whileOp.getConditionOp();
890 BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber];
891 Value conditionOperand = conditionOp.getArgs()[resultNumber];
892 bool equivCondition =
893 state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand);
894
895 auto yieldOp = whileOp.getYieldOp();
896 BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber];
897 Value yieldOperand = yieldOp.getOperand(resultNumber);
898 bool equivYield =
899 state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand);
900
901 return equivCondition && equivYield ? BufferRelation::Equivalent
902 : BufferRelation::Unknown;
903 }
904
905 bool isWritable(Operation *op, Value value,
906 const AnalysisState &state) const {
907 // Interestingly, scf::WhileOp's bbArg can **always** be viewed
908 // inplace from the perspective of ops nested under:
909 // 1. Either the matching iter operand is not bufferized inplace and an
910 // alloc + optional copy makes the bbArg itself inplaceable.
911 // 2. Or the matching iter operand is bufferized inplace and bbArg just
912 // bufferizes to that too.
913 return true;
914 }
915
916 LogicalResult
917 resolveConflicts(Operation *op, RewriterBase &rewriter,
918 const AnalysisState &analysisState,
919 const BufferizationState &bufferizationState) const {
920 auto bufferizableOp = cast<BufferizableOpInterface>(op);
921 if (failed(bufferizableOp.resolveTensorOpOperandConflicts(
922 rewriter, analysisState, bufferizationState)))
923 return failure();
924
925 if (analysisState.getOptions().copyBeforeWrite)
926 return success();
927
928 // According to the `getAliasing...` implementations, a bufferized OpResult
929 // may alias only with the corresponding bufferized init_arg and with no
930 // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg;
931 // but not with any other OpOperand. If a corresponding OpResult/init_arg
932 // pair bufferizes to equivalent buffers, this aliasing requirement is
933 // satisfied. Otherwise, we cannot be sure and must yield a new buffer copy.
934 // (New buffer copies do not alias with any buffer.)
935 OpBuilder::InsertionGuard g(rewriter);
936 auto whileOp = cast<scf::WhileOp>(op);
937 auto conditionOp = whileOp.getConditionOp();
938
939 // For every yielded value, is the value equivalent to its corresponding
940 // bbArg?
941 DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
942 whileOp.getBeforeArguments(), conditionOp.getArgs(), analysisState);
943 DenseSet<int64_t> equivalentYieldsAfter =
944 getEquivalentBuffers(whileOp.getAfterArguments(),
945 whileOp.getYieldOp().getResults(), analysisState);
946
947 // Update "before" region.
948 rewriter.setInsertionPoint(conditionOp);
949 SmallVector<Value> beforeYieldValues;
950 for (int64_t idx = 0;
951 idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
952 Value value = conditionOp.getArgs()[idx];
953 if (!isa<TensorLikeType>(value.getType()) ||
954 (equivalentYieldsAfter.contains(idx) &&
955 equivalentYieldsBefore.contains(idx))) {
956 beforeYieldValues.push_back(value);
957 continue;
958 }
959 FailureOr<Value> alloc = allocateTensorForShapedValue(
960 rewriter, conditionOp.getLoc(), value, analysisState.getOptions(),
961 bufferizationState);
962 if (failed(alloc))
963 return failure();
964 beforeYieldValues.push_back(*alloc);
965 }
966 rewriter.modifyOpInPlace(conditionOp, [&]() {
967 conditionOp.getArgsMutable().assign(beforeYieldValues);
968 });
969
970 return success();
971 }
972
973 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
974 const BufferizationOptions &options,
975 BufferizationState &state) const {
976 auto whileOp = cast<scf::WhileOp>(op);
977
978 // Indices of all bbArgs that have tensor type. These are the ones that
979 // are bufferized. The "before" and "after" regions may have different args.
980 DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits());
981 DenseSet<int64_t> indicesAfter =
982 getTensorIndices(whileOp.getAfterArguments());
983
984 // The new memref init_args of the loop.
985 FailureOr<SmallVector<Value>> maybeInitArgs =
986 getBuffers(rewriter, whileOp.getInitsMutable(), options, state);
987 if (failed(maybeInitArgs))
988 return failure();
989 SmallVector<Value> initArgs = *maybeInitArgs;
990
991 // Cast init_args if necessary.
992 SmallVector<Value> castedInitArgs;
993 for (const auto &it : llvm::enumerate(initArgs)) {
994 Value initArg = it.value();
995 Value beforeArg = whileOp.getBeforeArguments()[it.index()];
996 // If the type is not a tensor, bufferization doesn't need to touch it.
997 if (!isa<TensorLikeType>(beforeArg.getType())) {
998 castedInitArgs.push_back(initArg);
999 continue;
1000 }
1001 auto targetType = bufferization::getBufferType(beforeArg, options, state);
1002 if (failed(targetType))
1003 return failure();
1004 castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
1005 }
1006
1007 // The result types of a WhileOp are the same as the "after" bbArg types.
1008 SmallVector<Type> argsTypesAfter = llvm::map_to_vector(
1009 whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
1010 if (!isa<TensorLikeType>(bbArg.getType()))
1011 return bbArg.getType();
1012 // TODO: error handling
1013 return llvm::cast<Type>(
1014 *bufferization::getBufferType(bbArg, options, state));
1015 });
1016
1017 // Construct a new scf.while op with memref instead of tensor values.
1018 ValueRange argsRangeBefore(castedInitArgs);
1019 TypeRange argsTypesBefore(argsRangeBefore);
1020 auto newWhileOp = scf::WhileOp::create(rewriter, whileOp.getLoc(),
1021 argsTypesAfter, castedInitArgs);
1022
1023 // Add before/after regions to the new op.
1024 SmallVector<Location> bbArgLocsBefore(castedInitArgs.size(),
1025 whileOp.getLoc());
1026 SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(),
1027 whileOp.getLoc());
1028 Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock();
1029 newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore);
1030 Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock();
1031 newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter);
1032
1033 // Set up new iter_args and move the loop condition block to the new op.
1034 // The old block uses tensors, so wrap the (memref) bbArgs of the new block
1035 // in ToTensorOps.
1036 rewriter.setInsertionPointToStart(newBeforeBody);
1037 SmallVector<Value> newBeforeArgs =
1038 getBbArgReplacements(rewriter, newWhileOp.getBeforeArguments(),
1039 whileOp.getBeforeArguments(), indicesBefore);
1040 rewriter.mergeBlocks(whileOp.getBeforeBody(), newBeforeBody, newBeforeArgs);
1041
1042 // Set up new iter_args and move the loop body block to the new op.
1043 // The old block uses tensors, so wrap the (memref) bbArgs of the new block
1044 // in ToTensorOps.
1045 rewriter.setInsertionPointToStart(newAfterBody);
1046 SmallVector<Value> newAfterArgs =
1047 getBbArgReplacements(rewriter, newWhileOp.getAfterArguments(),
1048 whileOp.getAfterArguments(), indicesAfter);
1049 rewriter.mergeBlocks(whileOp.getAfterBody(), newAfterBody, newAfterArgs);
1050
1051 // Replace loop results.
1052 replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());
1053
1054 return success();
1055 }
1056
1057 FailureOr<BufferLikeType>
1058 getBufferType(Operation *op, Value value, const BufferizationOptions &options,
1059 const BufferizationState &state,
1060 SmallVector<Value> &invocationStack) const {
1061 auto whileOp = cast<scf::WhileOp>(op);
1062 assert(getOwnerOfValue(value) == op && "invalid value");
1063 assert(isa<TensorLikeType>(value.getType()) && "expected tensor type");
1064
1065 // Case 1: Block argument of the "before" region.
1066 if (auto bbArg = dyn_cast<BlockArgument>(value)) {
1067 if (bbArg.getOwner()->getParent() == &whileOp.getBefore()) {
1068 Value initArg = whileOp.getInits()[bbArg.getArgNumber()];
1069 auto yieldOp = whileOp.getYieldOp();
1070 Value yieldedValue = yieldOp.getOperand(bbArg.getArgNumber());
1071 return computeLoopRegionIterArgBufferType(
1072 op, bbArg, initArg, yieldedValue, options, state, invocationStack);
1073 }
1074 }
1075
1076 // Case 2: OpResult of the loop or block argument of the "after" region.
1077 // The bufferized "after" bbArg type can be directly computed from the
1078 // bufferized "before" bbArg type.
1079 unsigned resultNum;
1080 if (auto opResult = dyn_cast<OpResult>(value)) {
1081 resultNum = opResult.getResultNumber();
1082 } else if (cast<BlockArgument>(value).getOwner()->getParent() ==
1083 &whileOp.getAfter()) {
1084 resultNum = cast<BlockArgument>(value).getArgNumber();
1085 } else {
1086 llvm_unreachable("invalid value");
1087 }
1088 Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
1089 if (!isa<TensorLikeType>(conditionYieldedVal.getType())) {
1090 // scf.condition was already bufferized.
1091 return cast<BufferLikeType>(conditionYieldedVal.getType());
1092 }
1093 return bufferization::getBufferType(conditionYieldedVal, options, state,
1094 invocationStack);
1095 }
1096
1097 /// Assert that yielded values of an scf.while op are equivalent to their
1098 /// corresponding bbArgs. In that case, the buffer relations of the
1099 /// corresponding OpResults are "Equivalent".
1100 ///
1101 /// If this is not the case, allocs+copies are inserted and yielded from
1102 /// the loop. This could be a performance problem, so it must be explicitly
1103 /// activated with `allow-return-allocs`.
1104 ///
1105/// Note: In contrast to scf::ForOp, scf::WhileOp has two regions and the
1106 /// equivalence condition must be checked for both.
1107 LogicalResult verifyAnalysis(Operation *op,
1108 const AnalysisState &state) const {
1109 auto whileOp = cast<scf::WhileOp>(op);
1110 const auto &options =
1111 static_cast<const OneShotBufferizationOptions &>(state.getOptions());
1112 if (options.allowReturnAllocsFromLoops)
1113 return success();
1114
1115 auto conditionOp = whileOp.getConditionOp();
1116 for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
1117 Block *block = conditionOp->getBlock();
1118 if (!isa<TensorLikeType>(it.value().getType()))
1119 continue;
1120 if (it.index() >= block->getNumArguments() ||
1121 !state.areEquivalentBufferizedValues(it.value(),
1122 block->getArgument(it.index())))
1123 return conditionOp->emitError()
1124 << "Condition arg #" << it.index()
1125 << " is not equivalent to the corresponding iter bbArg";
1126 }
1127
1128 auto yieldOp = whileOp.getYieldOp();
1129 for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
1130 Block *block = yieldOp->getBlock();
1131 if (!isa<TensorLikeType>(it.value().getType()))
1132 continue;
1133 if (it.index() >= block->getNumArguments() ||
1134 !state.areEquivalentBufferizedValues(it.value(),
1135 block->getArgument(it.index())))
1136 return yieldOp->emitError()
1137 << "Yield operand #" << it.index()
1138 << " is not equivalent to the corresponding iter bbArg";
1139 }
1140
1141 return success();
1142 }
1143};
1144
1145/// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so
1146/// this is for analysis only.
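/// The bufferize() implementation below still rewrites yielded tensors to
/// buffers, casting them to the bufferized result/iter_arg types expected by
/// the enclosing scf.for, scf.if, scf.index_switch, or scf.while.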
1147struct YieldOpInterface
1148 : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
1149 scf::YieldOp> {
1150 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
1151 const AnalysisState &state) const {
1152 return true;
1153 }
1154
1155 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
1156 const AnalysisState &state) const {
1157 return false;
1158 }
1159
1160 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
1161 const AnalysisState &state) const {
1162 if (auto ifOp = dyn_cast<scf::IfOp>(op->getParentOp())) {
1163 return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
1164 BufferRelation::Equivalent, /*isDefinite=*/false}};
1165 }
1166 if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
1167 return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
1168 BufferRelation::Equivalent}};
1169 return {};
1170 }
1171
1172 bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
1173 const AnalysisState &state) const {
1174 // Yield operands always bufferize inplace. Otherwise, an alloc + copy
1175 // may be generated inside the block. Whenever possible, we should avoid
1176 // returning/yielding allocations.
1177 return true;
1178 }
1179
1180 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
1181 const BufferizationOptions &options,
1182 BufferizationState &state) const {
1183 auto yieldOp = cast<scf::YieldOp>(op);
1184 if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::IndexSwitchOp, scf::ForOp,
1185 scf::WhileOp>(yieldOp->getParentOp()))
1186 return yieldOp->emitError("unsupported scf::YieldOp parent");
1187
1188 SmallVector<Value> newResults;
1189 for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
1190 Value value = it.value();
1191 if (isa<TensorLikeType>(value.getType())) {
1192 FailureOr<Value> maybeBuffer =
1193 getBuffer(rewriter, value, options, state);
1194 if (failed(maybeBuffer))
1195 return failure();
1196 Value buffer = *maybeBuffer;
1197 // We may have to cast the value before yielding it.
1198 if (isa<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>(
1199 yieldOp->getParentOp())) {
1200 FailureOr<BufferLikeType> resultType = bufferization::getBufferType(
1201 yieldOp->getParentOp()->getResult(it.index()), options, state);
1202 if (failed(resultType))
1203 return failure();
1204 buffer = castBuffer(rewriter, buffer, *resultType);
1205 } else if (auto whileOp =
1206 dyn_cast<scf::WhileOp>(yieldOp->getParentOp())) {
1207 FailureOr<BufferLikeType> resultType = bufferization::getBufferType(
1208 whileOp.getBeforeArguments()[it.index()], options, state);
1209 if (failed(resultType))
1210 return failure();
1211 buffer = castBuffer(rewriter, buffer, *resultType);
1212 }
1213 newResults.push_back(buffer);
1214 } else {
1215 newResults.push_back(value);
1216 }
1217 }
1218
1219 replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op, newResults);
1220 return success();
1221 }
1222};
1223
1224/// Bufferization of ForallOp. This also bufferizes the terminator of the
1225/// region. There are op interfaces for the terminators (InParallelOp
1226/// and ParallelInsertSliceOp), but these are only used during analysis, not
1227/// during bufferization.
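///
/// Each shared_out operand is replaced by its buffer; the corresponding block
/// arguments are rewritten to bufferization.to_tensor of those buffers, and
/// the op results are replaced directly by the output buffers.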
1228struct ForallOpInterface
1229 : public BufferizableOpInterface::ExternalModel<ForallOpInterface,
1230 ForallOp> {
1231 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
1232 const AnalysisState &state) const {
1233 // All tensor operands to `scf.forall` are `shared_outs` and all
1234 // shared outs are assumed to be read by the loop. This does not
1235 // account for the case where the entire value is over-written,
1236 // but being conservative here.
1237 return true;
1238 }
1239
1240 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
1241 const AnalysisState &state) const {
1242 // Outputs of scf::ForallOps are always considered as a write.
1243 return true;
1244 }
1245
1246 AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
1247 const AnalysisState &state) const {
1248 auto forallOp = cast<ForallOp>(op);
1249 return {
1250 {{forallOp.getTiedOpResult(&opOperand), BufferRelation::Equivalent}}};
1251 }
1252
1253 bool isWritable(Operation *op, Value value,
1254 const AnalysisState &state) const {
1255 return true;
1256 }
1257
1258 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
1259 const BufferizationOptions &options,
1260 BufferizationState &state) const {
1261 OpBuilder::InsertionGuard guard(rewriter);
1262 auto forallOp = cast<ForallOp>(op);
1263 int64_t rank = forallOp.getRank();
1264
1265 // Get buffers for all output operands.
1266 SmallVector<Value> buffers;
1267 for (Value out : forallOp.getOutputs()) {
1268 FailureOr<Value> buffer = getBuffer(rewriter, out, options, state);
1269 if (failed(buffer))
1270 return failure();
1271 buffers.push_back(*buffer);
1272 }
1273
1274 // Use buffers instead of block arguments.
1275 rewriter.setInsertionPointToStart(forallOp.getBody());
1276 for (const auto &it : llvm::zip(
1277 forallOp.getBody()->getArguments().drop_front(rank), buffers)) {
1278 BlockArgument bbArg = std::get<0>(it);
1279 Value buffer = std::get<1>(it);
1280 Value bufferAsTensor = ToTensorOp::create(rewriter, forallOp.getLoc(),
1281 bbArg.getType(), buffer);
1282 bbArg.replaceAllUsesWith(bufferAsTensor);
1283 }
1284
1285 // Create new ForallOp without any results and drop the automatically
1286 // introduced terminator.
1287 rewriter.setInsertionPoint(forallOp);
1288 ForallOp newForallOp;
1289 newForallOp = ForallOp::create(
1290 rewriter, forallOp.getLoc(), forallOp.getMixedLowerBound(),
1291 forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
1292 /*outputs=*/ValueRange(), forallOp.getMapping());
1293
1294 // Keep discardable attributes from the original op.
1295 newForallOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
1296
1297 rewriter.eraseOp(newForallOp.getBody()->getTerminator());
1298
1299 // Move over block contents of the old op.
1300 SmallVector<Value> replacementBbArgs;
1301 replacementBbArgs.append(newForallOp.getBody()->getArguments().begin(),
1302 newForallOp.getBody()->getArguments().end());
1303 replacementBbArgs.append(forallOp.getOutputs().size(), Value());
1304 rewriter.mergeBlocks(forallOp.getBody(), newForallOp.getBody(),
1305 replacementBbArgs);
1306
1307 // Remove the old op and replace all of its uses.
1308 replaceOpWithBufferizedValues(rewriter, op, buffers);
1309
1310 return success();
1311 }
1312
1313 FailureOr<BufferLikeType>
1314 getBufferType(Operation *op, Value value, const BufferizationOptions &options,
1315 const BufferizationState &state,
1316 SmallVector<Value> &invocationStack) const {
1317 auto forallOp = cast<ForallOp>(op);
1318
1319 if (auto bbArg = dyn_cast<BlockArgument>(value))
1320 // A tensor block argument has the same bufferized type as the
1321 // corresponding output operand.
1322 return bufferization::getBufferType(
1323 forallOp.getTiedOpOperand(bbArg)->get(), options, state,
1324 invocationStack);
1325
1326 // The bufferized result type is the same as the bufferized type of the
1327 // corresponding output operand.
1328 return bufferization::getBufferType(
1329 forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options,
1330 state, invocationStack);
1331 }
1332
1333 bool isRepetitiveRegion(Operation *op, unsigned index) const {
1334 auto forallOp = cast<ForallOp>(op);
1335
1336 // This op is repetitive if it may execute more than one iteration.
1337 // Dynamic bounds or steps are conservatively treated as repetitive, too.
1338 for (auto [lb, ub, step] :
1339 llvm::zip(forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
1340 forallOp.getMixedStep())) {
1341 std::optional<int64_t> lbConstant = getConstantIntValue(lb);
1342 if (!lbConstant)
1343 return true;
1344
1345 std::optional<int64_t> ubConstant = getConstantIntValue(ub);
1346 if (!ubConstant)
1347 return true;
1348
1349 std::optional<int64_t> stepConstant = getConstantIntValue(step);
1350 if (!stepConstant)
1351 return true;
1352
1353 if (*lbConstant + *stepConstant < *ubConstant)
1354 return true;
1355 }
1356 return false;
1357 }
1358
1359 bool isParallelRegion(Operation *op, unsigned index) const {
1360 return isRepetitiveRegion(op, index);
1361 }
1362};
1363
1364/// Nothing to do for InParallelOp.
1365struct InParallelOpInterface
1366 : public BufferizableOpInterface::ExternalModel<InParallelOpInterface,
1367 InParallelOp> {
1368 LogicalResult bufferize(Operation *op, RewriterBase &b,
1369 const BufferizationOptions &options,
1370 BufferizationState &state) const {
1371 llvm_unreachable("op does not have any tensor OpOperands / OpResults");
1372 return failure();
1373 }
1374};
1375
1376} // namespace
1377} // namespace scf
1378} // namespace mlir
1379
1380void mlir::scf::registerBufferizableOpInterfaceExternalModels(
1381 DialectRegistry &registry) {
1382 registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
1383 ConditionOp::attachInterface<ConditionOpInterface>(*ctx);
1384 ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
1385 ForOp::attachInterface<ForOpInterface>(*ctx);
1386 IfOp::attachInterface<IfOpInterface>(*ctx);
1387 IndexSwitchOp::attachInterface<IndexSwitchOpInterface>(*ctx);
1388 ForallOp::attachInterface<ForallOpInterface>(*ctx);
1389 InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
1390 WhileOp::attachInterface<WhileOpInterface>(*ctx);
1391 YieldOp::attachInterface<YieldOpInterface>(*ctx);
1392 });
1393}