MLIR 23.0.0git
Bufferize.cpp
Go to the documentation of this file.
1//===- Bufferize.cpp - Bufferization utilities ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
18#include "mlir/IR/Diagnostics.h"
19#include "mlir/IR/Operation.h"
23#include "llvm/Support/DebugLog.h"
24#include <optional>
25
26namespace mlir {
27namespace bufferization {
28#define GEN_PASS_DEF_ONESHOTBUFFERIZEPASS
29#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
30} // namespace bufferization
31} // namespace mlir
32
33#define DEBUG_TYPE "bufferize"
34
35using namespace mlir;
36using namespace mlir::bufferization;
37
38namespace {
39
// Translates the textual analysis-heuristic option `s` into a
// OneShotBufferizationOptions::AnalysisHeuristic value. The recognized
// spellings are "bottom-up", "top-down", "bottom-up-from-terminators" and
// "fuzzer"; any other string falls through to llvm_unreachable.
// NOTE(review): listing lines 40, 43, 45 and 50 are absent from this
// extraction (presumably the return type and the return statements of the
// "bottom-up", "top-down" and "fuzzer" branches) - restore them from the
// original file before compiling.
41parseHeuristicOption(const std::string &s) {
42 if (s == "bottom-up")
44 if (s == "top-down")
46 if (s == "bottom-up-from-terminators")
47 return OneShotBufferizationOptions::AnalysisHeuristic::
48 BottomUpFromTerminators;
49 if (s == "fuzzer")
// Reached only when `s` matched none of the spellings above.
51 llvm_unreachable("invalid analysisheuristic option");
52}
53
// Pass that runs One-Shot Bufferize on the module it is scheduled on. The
// bufferization options are either assembled from the pass options or copied
// from the `options` member when that member holds a value. After validating
// option combinations, the pass dispatches to runOneShotModuleBufferize or
// runOneShotBufferize and publishes the collected statistics.
54struct OneShotBufferizePass
55 : public bufferization::impl::OneShotBufferizePassBase<
56 OneShotBufferizePass> {
57 using Base::Base;
58
59 void runOnOperation() override {
// NOTE(review): listing line 60 is absent from this extraction; presumably it
// declares the local `opt` that is populated below.
61 if (!options) {
62 // Make new bufferization options if none were provided when creating the
63 // pass.
64 opt.allowReturnAllocsFromLoops = allowReturnAllocsFromLoops;
65 opt.allowUnknownOps = allowUnknownOps;
66 opt.analysisFuzzerSeed = analysisFuzzerSeed;
67 opt.analysisHeuristic = parseHeuristicOption(analysisHeuristic);
68 opt.copyBeforeWrite = copyBeforeWrite;
69 opt.dumpAliasSets = dumpAliasSets;
70 opt.setFunctionBoundaryTypeConversion(functionBoundaryTypeConversion);
71
// The two memory-space options install different defaultMemorySpaceFn
// callbacks below, so requesting both is rejected.
72 if (mustInferMemorySpace && useEncodingForMemorySpace) {
73 emitError(getOperation()->getLoc())
74 << "only one of 'must-infer-memory-space' and "
75 "'use-encoding-for-memory-space' are allowed in "
76 << getArgument();
77 return signalPassFailure();
78 }
79
// With 'must-infer-memory-space' the callback never supplies a default
// memory space.
80 if (mustInferMemorySpace) {
81 opt.defaultMemorySpaceFn =
82 [](TensorLikeType t) -> std::optional<Attribute> {
83 return std::nullopt;
84 };
85 }
86
// With 'use-encoding-for-memory-space' the callback returns the encoding
// attribute of ranked tensors and nothing for all other tensor-like types.
87 if (useEncodingForMemorySpace) {
88 opt.defaultMemorySpaceFn =
89 [](TensorLikeType t) -> std::optional<Attribute> {
90 if (auto rtt = dyn_cast<RankedTensorType>(t))
91 return rtt.getEncoding();
92 return std::nullopt;
93 };
94 }
95
96 opt.printConflicts = printConflicts;
97 opt.bufferAlignment = bufferAlignment;
98 opt.testAnalysisOnly = testAnalysisOnly;
99 opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
100 opt.checkParallelRegions = checkParallelRegions;
101 opt.noAnalysisFuncFilter = noAnalysisFuncFilter;
102
103 // Configure type converter.
104 LayoutMapOption unknownTypeConversionOption = unknownTypeConversion;
105 if (unknownTypeConversionOption == LayoutMapOption::InferLayoutMap) {
106 emitError(UnknownLoc::get(&getContext()),
107 "Invalid option: 'infer-layout-map' is not a valid value for "
108 "'unknown-type-conversion'");
109 return signalPassFailure();
110 }
// The converter captures the layout option by copy and maps a tensor type to
// a memref type with either a static identity layout or a fully dynamic one.
// NOTE(review): listing line 113 (the remainder of the lambda's parameter
// list and its opening brace) is absent from this extraction.
111 opt.unknownTypeConverterFn = [=](TensorLikeType type,
112 Attribute memorySpace,
114 const auto tensorType = cast<TensorType>(type);
115 if (unknownTypeConversionOption == LayoutMapOption::IdentityLayoutMap)
116 return cast<bufferization::BufferLikeType>(
117 bufferization::getMemRefTypeWithStaticIdentityLayout(
118 tensorType, memorySpace));
119 assert(unknownTypeConversionOption ==
120 LayoutMapOption::FullyDynamicLayoutMap &&
121 "invalid layout map option");
122 return cast<bufferization::BufferLikeType>(
123 bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
124 memorySpace));
125 };
126
127 // Configure op filter.
128 OpFilter::Entry::FilterFn filterFn = [&](Operation *op) {
129 // Filter may be specified via options.
130 if (this->dialectFilter.hasValue() && !(*this->dialectFilter).empty())
131 return llvm::is_contained(this->dialectFilter,
132 op->getDialect()->getNamespace());
133 // No filter specified: All other ops are allowed.
134 return true;
135 };
136 opt.opFilter.allowOperation(filterFn);
137 } else {
// Options were supplied up front: use them verbatim.
138 opt = *options;
139 }
140
141 if (opt.copyBeforeWrite && opt.testAnalysisOnly) {
142 // These two flags do not make sense together: "copy-before-write"
143 // indicates that copies should be inserted before every memory write,
144 // but "test-analysis-only" indicates that only the analysis should be
145 // tested. (I.e., no IR is bufferized.)
146 emitError(UnknownLoc::get(&getContext()),
147 "Invalid option: 'copy-before-write' cannot be used with "
148 "'test-analysis-only'");
149 return signalPassFailure();
150 }
151
// 'print-conflicts' and 'dump-alias-sets' are only accepted together with
// 'test-analysis-only'.
152 if (opt.printConflicts && !opt.testAnalysisOnly) {
153 emitError(
154 UnknownLoc::get(&getContext()),
155 "Invalid option: 'print-conflicts' requires 'test-analysis-only'");
156 return signalPassFailure();
157 }
158
159 if (opt.dumpAliasSets && !opt.testAnalysisOnly) {
160 emitError(
161 UnknownLoc::get(&getContext()),
162 "Invalid option: 'dump-alias-sets' requires 'test-analysis-only'");
163 return signalPassFailure();
164 }
165
// Run the bufferization: module-level when function boundaries are
// bufferized, otherwise the plain one-shot entry point.
166 BufferizationState state;
167 BufferizationStatistics statistics;
168 ModuleOp moduleOp = getOperation();
169 if (opt.bufferizeFunctionBoundaries) {
170 if (failed(
171 runOneShotModuleBufferize(moduleOp, opt, state, &statistics))) {
172 signalPassFailure();
173 return;
174 }
175 } else {
// The function filter is rejected on this path because it requires
// 'bufferize-function-boundaries'.
176 if (!opt.noAnalysisFuncFilter.empty()) {
177 emitError(UnknownLoc::get(&getContext()),
178 "Invalid option: 'no-analysis-func-filter' requires "
179 "'bufferize-function-boundaries'");
180 return signalPassFailure();
181 }
182 if (failed(runOneShotBufferize(moduleOp, opt, state, &statistics))) {
183 signalPassFailure();
184 return;
185 }
186 }
187
188 // Set pass statistics.
189 this->numBufferAlloc = statistics.numBufferAlloc;
190 this->numTensorInPlace = statistics.numTensorInPlace;
191 this->numTensorOutOfPlace = statistics.numTensorOutOfPlace;
192 }
193
194private:
// Pre-built options; when empty, runOnOperation derives the options from the
// pass options instead. NOTE(review): nothing visible in this listing assigns
// this member - confirm whether it is still populated anywhere.
195 std::optional<OneShotBufferizationOptions> options;
196};
197} // namespace
198
199//===----------------------------------------------------------------------===//
200// BufferizableOpInterface-based Bufferization
201//===----------------------------------------------------------------------===//
202
203namespace {
204/// A rewriter that keeps track of extra information during bufferization.
205class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
206public:
207 BufferizationRewriter(MLIRContext *ctx, DenseSet<Operation *> &erasedOps,
208 DenseSet<Operation *> &toBufferOps,
209 SmallVector<Operation *> &worklist,
210 const BufferizationOptions &options,
211 BufferizationStatistics *statistics)
212 : IRRewriter(ctx), erasedOps(erasedOps), toBufferOps(toBufferOps),
213 worklist(worklist), analysisState(options), statistics(statistics) {
214 setListener(this);
215 }
216
217protected:
218 void notifyOperationErased(Operation *op) override {
219 erasedOps.insert(op);
220 // Erase if present.
221 toBufferOps.erase(op);
222 }
223
224 void notifyOperationInserted(Operation *op, InsertPoint previous) override {
225 // We only care about newly created ops.
226 if (previous.isSet())
227 return;
228
229 erasedOps.erase(op);
230
231 // Gather statistics about allocs.
232 if (statistics) {
233 if (auto sideEffectingOp = dyn_cast<MemoryEffectOpInterface>(op))
234 statistics->numBufferAlloc += static_cast<int64_t>(
235 sideEffectingOp.hasEffect<MemoryEffects::Allocate>());
236 }
237
238 // Keep track of to_buffer ops.
239 if (isa<ToBufferOp>(op)) {
240 toBufferOps.insert(op);
241 return;
242 }
243
244 // Skip to_tensor ops.
245 if (isa<ToTensorOp>(op))
246 return;
247
248 // Skip non-tensor ops.
249 if (!hasTensorSemantics(op))
250 return;
251
252 // Skip ops that are not allowed to be bufferized.
253 auto const &options = analysisState.getOptions();
254 if (!options.isOpAllowed(op))
255 return;
256
257 // Add op to worklist.
258 worklist.push_back(op);
259 }
260
261private:
262 /// A set of all erased ops.
263 DenseSet<Operation *> &erasedOps;
264
265 /// A set of all to_buffer ops.
266 DenseSet<Operation *> &toBufferOps;
267
268 /// The worklist of ops to be bufferized.
269 SmallVector<Operation *> &worklist;
270
271 /// The analysis state. Used for debug assertions and access to the
272 /// bufferization options.
273 const AnalysisState analysisState;
274
275 /// Bufferization statistics for debugging.
276 BufferizationStatistics *statistics;
277};
278} // namespace
279
282 BufferizationState &bufferizationState,
283 BufferizationStatistics *statistics) {
284 if (options.copyBeforeWrite) {
285 AnalysisState analysisState(options);
286 if (failed(insertTensorCopies(op, analysisState, bufferizationState)))
287 return failure();
288 }
289
290 // Keep track of to_buffer ops.
291 DenseSet<Operation *> toBufferOps;
292 op->walk([&](ToBufferOp toBufferOp) { toBufferOps.insert(toBufferOp); });
293
294 // Gather all bufferizable ops in top-to-bottom order.
295 //
296 // We should ideally know the exact memref type of all operands when
297 // bufferizing an op. (This is the case when bufferizing top-to-bottom.)
298 // Otherwise, we have to use a memref type with a fully dynamic layout map to
299 // avoid copies. We are currently missing patterns for layout maps to
300 // canonicalize away (or canonicalize to more precise layouts).
302 op->walk<WalkOrder::PostOrder>([&](Operation *op) {
303 if (options.isOpAllowed(op) && hasTensorSemantics(op))
304 worklist.push_back(op);
305 });
306
307 // Keep track of all erased ops.
308 DenseSet<Operation *> erasedOps;
309
310 // Bufferize all ops.
311 BufferizationRewriter rewriter(op->getContext(), erasedOps, toBufferOps,
312 worklist, options, statistics);
313 for (unsigned i = 0; i < worklist.size(); ++i) {
314 Operation *nextOp = worklist[i];
315 // Skip ops that were erased.
316 if (erasedOps.contains(nextOp))
317 continue;
318 // Skip ops that are not bufferizable or not allowed.
319 auto bufferizableOp = options.dynCastBufferizableOp(nextOp);
320 if (!bufferizableOp)
321 continue;
322 // Skip ops that no longer have tensor semantics.
323 if (!hasTensorSemantics(nextOp))
324 continue;
325 // Check for unsupported unstructured control flow.
326 if (!bufferizableOp.supportsUnstructuredControlFlow())
327 for (Region &r : nextOp->getRegions())
328 if (r.getBlocks().size() > 1)
329 return nextOp->emitOpError(
330 "op or BufferizableOpInterface implementation does not support "
331 "unstructured control flow, but at least one region has multiple "
332 "blocks");
333
334 // Bufferize the op.
335 LDBG(3) << "//===-------------------------------------------===//\n"
336 << "IR after bufferizing: " << nextOp->getName();
337 rewriter.setInsertionPoint(nextOp);
338 if (failed(
339 bufferizableOp.bufferize(rewriter, options, bufferizationState))) {
340 LDBG(2) << "failed to bufferize\n"
341 << "//===-------------------------------------------===//";
342 return nextOp->emitError("failed to bufferize op");
343 }
344 LDBG(3) << *op << "\n//===-------------------------------------------===//";
345 }
346
347 // Return early if the top-level op is entirely gone.
348 if (erasedOps.contains(op))
349 return success();
350
351 // Fold all to_buffer(to_tensor(x)) pairs. Snapshot the set first:
352 // `foldToBufferToTensorPair` can erase ops, and the rewriter listener
353 // mutates `toBufferOps` from inside that call, which would invalidate
354 // any DenseSet iterator held across it.
355 SmallVector<Operation *> toBufferOpsSnapshot = llvm::to_vector(toBufferOps);
356 for (Operation *op : toBufferOpsSnapshot) {
357 if (erasedOps.contains(op))
358 continue;
359 rewriter.setInsertionPoint(op);
361 rewriter, cast<ToBufferOp>(op), options);
362 }
363
364 // Remove all dead to_tensor ops.
365 op->walk<WalkOrder::PostOrder>([&](ToTensorOp toTensorOp) {
366 if (toTensorOp->getUses().empty()) {
367 rewriter.eraseOp(toTensorOp);
368 return WalkResult::skip();
369 }
370 return WalkResult::advance();
371 });
372
373 /// Check the result of bufferization. Return an error if an op was not
374 /// bufferized, unless partial bufferization is allowed.
375 if (options.allowUnknownOps)
376 return success();
377
378 for (Operation *op : worklist) {
379 // Skip ops that are entirely gone.
380 if (erasedOps.contains(op))
381 continue;
382 // Ops that no longer have tensor semantics (because they were updated
383 // in-place) are allowed.
384 if (!hasTensorSemantics(op))
385 continue;
386 // Continue ops that are not allowed.
387 if (!options.isOpAllowed(op))
388 continue;
389 // Ops without any uses and no side effects will fold away.
390 if (op->getUses().empty() && isMemoryEffectFree(op))
391 continue;
392 // ToTensorOps/ToBufferOps are allowed in the output.
393 if (isa<ToTensorOp, ToBufferOp>(op))
394 continue;
395 return op->emitError("op was not bufferized");
396 }
397
398 return success();
399}
400
401LogicalResult
404 BufferizationState &state) {
405 OpBuilder::InsertionGuard g(rewriter);
406 auto bufferizableOp = options.dynCastBufferizableOp(block->getParentOp());
407 if (!bufferizableOp)
408 return failure();
409
410 // Compute the new signature.
411 SmallVector<Type> newTypes;
412 for (BlockArgument &bbArg : block->getArguments()) {
413 auto tensorType = dyn_cast<TensorLikeType>(bbArg.getType());
414 if (!tensorType) {
415 newTypes.push_back(bbArg.getType());
416 continue;
417 }
418
419 FailureOr<BufferLikeType> bufferType =
420 bufferization::getBufferType(bbArg, options, state);
421 if (failed(bufferType))
422 return failure();
423 newTypes.push_back(*bufferType);
424 }
425
426 // Change the type of all block arguments.
427 for (auto [bbArg, type] : llvm::zip(block->getArguments(), newTypes)) {
428 if (bbArg.getType() == type)
429 continue;
430
431 // Collect all uses of the bbArg.
432 SmallVector<OpOperand *> bbArgUses;
433 for (OpOperand &use : bbArg.getUses())
434 bbArgUses.push_back(&use);
435
436 Type tensorType = bbArg.getType();
437 // Change the bbArg type to memref.
438 bbArg.setType(type);
439
440 // Replace all uses of the original tensor bbArg.
441 rewriter.setInsertionPointToStart(block);
442 if (!bbArgUses.empty()) {
443 Value toTensorOp = bufferization::ToTensorOp::create(
444 rewriter, bbArg.getLoc(), tensorType, bbArg);
445 for (OpOperand *use : bbArgUses)
446 use->set(toTensorOp);
447 }
448 }
449
450 // Bufferize callers of the block.
451 for (Operation *op : block->getUsers()) {
452 auto branchOp = dyn_cast<BranchOpInterface>(op);
453 if (!branchOp)
454 return op->emitOpError("cannot bufferize ops with block references that "
455 "do not implement BranchOpInterface");
456
457 auto it = llvm::find(op->getSuccessors(), block);
458 assert(it != op->getSuccessors().end() && "could find successor");
459 int64_t successorIdx = std::distance(op->getSuccessors().begin(), it);
460
461 SuccessorOperands operands = branchOp.getSuccessorOperands(successorIdx);
462 SmallVector<Value> newOperands;
463 for (auto [operand, type] :
464 llvm::zip(operands.getForwardedOperands(), newTypes)) {
465 if (operand.getType() == type) {
466 // Not a tensor type. Nothing to do for this operand.
467 newOperands.push_back(operand);
468 continue;
469 }
470 FailureOr<BufferLikeType> operandBufferType =
471 bufferization::getBufferType(operand, options, state);
472 if (failed(operandBufferType))
473 return failure();
474 rewriter.setInsertionPointAfterValue(operand);
475 Value bufferizedOperand = bufferization::ToBufferOp::create(
476 rewriter, operand.getLoc(), *operandBufferType, operand);
477 // A cast is needed if the operand and the block argument have different
478 // bufferized types.
479 if (type != *operandBufferType)
480 bufferizedOperand = memref::CastOp::create(rewriter, operand.getLoc(),
481 type, bufferizedOperand);
482 newOperands.push_back(bufferizedOperand);
483 }
484 operands.getMutableForwardedOperands().assign(newOperands);
485 }
486
487 return success();
488}
return success()
b getContext())
static llvm::ManagedStatic< PassManagerOptions > options
Base class for generic analysis states.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:306
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgListType getArguments()
Definition Block.h:97
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
user_range getUsers() const
Returns a range of all users.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
void assign(ValueRange values)
Assign this range to the given values.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition Builders.h:423
This class represents an operand of an operation.
Definition Value.h:254
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:115
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:702
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:822
SuccessorRange getSuccessors()
Definition Operation.h:728
use_range getUses()
Returns a range of all uses, which is useful for iterating over all uses.
Definition Operation.h:871
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:233
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
This class models how operands are forwarded to block arguments in control flow.
MutableOperandRange getMutableForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
OperandRange getForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options, BufferizationState &bufferizationState, BufferizationStatistics *statistics=nullptr)
Bufferize op and its nested ops that implement BufferizableOpInterface.
LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter, const BufferizationOptions &options, BufferizationState &state)
Bufferize the signature of block and its callers (i.e., ops that have the given block as a successor)...
LogicalResult insertTensorCopies(Operation *op, const OneShotBufferizationOptions &options, const BufferizationState &bufferizationState, BufferizationStatistics *statistics=nullptr)
Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
LogicalResult runOneShotBufferize(Operation *op, const OneShotBufferizationOptions &options, BufferizationState &state, BufferizationStatistics *statistics=nullptr)
Run One-Shot Bufferize on the given op: Analysis + Bufferization.
LogicalResult foldToBufferToTensorPair(RewriterBase &rewriter, ToBufferOp toBuffer, const BufferizationOptions &options)
Try to fold to_buffer(to_tensor(x)).
llvm::LogicalResult runOneShotModuleBufferize(Operation *moduleOp, const bufferization::OneShotBufferizationOptions &options, BufferizationState &state, BufferizationStatistics *statistics=nullptr)
Run One-Shot Module Bufferization on the given SymbolTable.
Include the generated interface declarations.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:122
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
Bufferization statistics for debugging.
Definition Bufferize.h:35
Options for analysis-enabled bufferization.
unsigned analysisFuzzerSeed
Seed for the analysis fuzzer.
bool dumpAliasSets
Specifies whether the tensor IR should be annotated with alias sets.
bool allowReturnAllocsFromLoops
Specifies whether returning newly allocated memrefs from loops should be allowed.
AnalysisHeuristic analysisHeuristic
The heuristic controls the order in which ops are traversed during the analysis.
llvm::ArrayRef< std::string > noAnalysisFuncFilter
Specify the functions that should not be analyzed.