MLIR 23.0.0git
Bufferize.cpp
Go to the documentation of this file.
1//===- Bufferize.cpp - Bufferization utilities ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
18#include "mlir/IR/Diagnostics.h"
19#include "mlir/IR/Operation.h"
23#include "llvm/Support/DebugLog.h"
24#include <optional>
25
26namespace mlir {
27namespace bufferization {
28#define GEN_PASS_DEF_ONESHOTBUFFERIZEPASS
29#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
30} // namespace bufferization
31} // namespace mlir
32
33#define DEBUG_TYPE "bufferize"
34
35using namespace mlir;
36using namespace mlir::bufferization;
37
38namespace {
39
41parseHeuristicOption(const std::string &s) {
42 if (s == "bottom-up")
44 if (s == "top-down")
46 if (s == "bottom-up-from-terminators")
47 return OneShotBufferizationOptions::AnalysisHeuristic::
48 BottomUpFromTerminators;
49 if (s == "fuzzer")
51 llvm_unreachable("invalid analysisheuristic option");
52}
53
54struct OneShotBufferizePass
55 : public bufferization::impl::OneShotBufferizePassBase<
56 OneShotBufferizePass> {
57 using Base::Base;
58
59 void runOnOperation() override {
61 if (!options) {
62 // Make new bufferization options if none were provided when creating the
63 // pass.
64 opt.allowReturnAllocsFromLoops = allowReturnAllocsFromLoops;
65 opt.allowUnknownOps = allowUnknownOps;
66 opt.analysisFuzzerSeed = analysisFuzzerSeed;
67 opt.analysisHeuristic = parseHeuristicOption(analysisHeuristic);
68 opt.copyBeforeWrite = copyBeforeWrite;
69 opt.dumpAliasSets = dumpAliasSets;
70 opt.setFunctionBoundaryTypeConversion(functionBoundaryTypeConversion);
71
72 if (mustInferMemorySpace && useEncodingForMemorySpace) {
73 emitError(getOperation()->getLoc())
74 << "only one of 'must-infer-memory-space' and "
75 "'use-encoding-for-memory-space' are allowed in "
76 << getArgument();
77 return signalPassFailure();
78 }
79
80 if (mustInferMemorySpace) {
81 opt.defaultMemorySpaceFn =
82 [](TensorType t) -> std::optional<Attribute> {
83 return std::nullopt;
84 };
85 }
86
87 if (useEncodingForMemorySpace) {
88 opt.defaultMemorySpaceFn =
89 [](TensorType t) -> std::optional<Attribute> {
90 if (auto rtt = dyn_cast<RankedTensorType>(t))
91 return rtt.getEncoding();
92 return std::nullopt;
93 };
94 }
95
96 opt.printConflicts = printConflicts;
97 opt.bufferAlignment = bufferAlignment;
98 opt.testAnalysisOnly = testAnalysisOnly;
99 opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
100 opt.checkParallelRegions = checkParallelRegions;
101 opt.noAnalysisFuncFilter = noAnalysisFuncFilter;
102
103 // Configure type converter.
104 LayoutMapOption unknownTypeConversionOption = unknownTypeConversion;
105 if (unknownTypeConversionOption == LayoutMapOption::InferLayoutMap) {
106 emitError(UnknownLoc::get(&getContext()),
107 "Invalid option: 'infer-layout-map' is not a valid value for "
108 "'unknown-type-conversion'");
109 return signalPassFailure();
110 }
111 opt.unknownTypeConverterFn = [=](TensorType tensorType,
112 Attribute memorySpace,
114 if (unknownTypeConversionOption == LayoutMapOption::IdentityLayoutMap)
115 return bufferization::getMemRefTypeWithStaticIdentityLayout(
116 tensorType, memorySpace);
117 assert(unknownTypeConversionOption ==
118 LayoutMapOption::FullyDynamicLayoutMap &&
119 "invalid layout map option");
120 return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
121 memorySpace);
122 };
123
124 // Configure op filter.
125 OpFilter::Entry::FilterFn filterFn = [&](Operation *op) {
126 // Filter may be specified via options.
127 if (this->dialectFilter.hasValue() && !(*this->dialectFilter).empty())
128 return llvm::is_contained(this->dialectFilter,
129 op->getDialect()->getNamespace());
130 // No filter specified: All other ops are allowed.
131 return true;
132 };
133 opt.opFilter.allowOperation(filterFn);
134 } else {
135 opt = *options;
136 }
137
138 if (opt.copyBeforeWrite && opt.testAnalysisOnly) {
139 // These two flags do not make sense together: "copy-before-write"
140 // indicates that copies should be inserted before every memory write,
141 // but "test-analysis-only" indicates that only the analysis should be
142 // tested. (I.e., no IR is bufferized.)
143 emitError(UnknownLoc::get(&getContext()),
144 "Invalid option: 'copy-before-write' cannot be used with "
145 "'test-analysis-only'");
146 return signalPassFailure();
147 }
148
149 if (opt.printConflicts && !opt.testAnalysisOnly) {
150 emitError(
151 UnknownLoc::get(&getContext()),
152 "Invalid option: 'print-conflicts' requires 'test-analysis-only'");
153 return signalPassFailure();
154 }
155
156 if (opt.dumpAliasSets && !opt.testAnalysisOnly) {
157 emitError(
158 UnknownLoc::get(&getContext()),
159 "Invalid option: 'dump-alias-sets' requires 'test-analysis-only'");
160 return signalPassFailure();
161 }
162
163 BufferizationState state;
164 BufferizationStatistics statistics;
165 ModuleOp moduleOp = getOperation();
166 if (opt.bufferizeFunctionBoundaries) {
167 if (failed(
168 runOneShotModuleBufferize(moduleOp, opt, state, &statistics))) {
169 signalPassFailure();
170 return;
171 }
172 } else {
173 if (!opt.noAnalysisFuncFilter.empty()) {
174 emitError(UnknownLoc::get(&getContext()),
175 "Invalid option: 'no-analysis-func-filter' requires "
176 "'bufferize-function-boundaries'");
177 return signalPassFailure();
178 }
179 if (failed(runOneShotBufferize(moduleOp, opt, state, &statistics))) {
180 signalPassFailure();
181 return;
182 }
183 }
184
185 // Set pass statistics.
186 this->numBufferAlloc = statistics.numBufferAlloc;
187 this->numTensorInPlace = statistics.numTensorInPlace;
188 this->numTensorOutOfPlace = statistics.numTensorOutOfPlace;
189 }
190
191private:
192 std::optional<OneShotBufferizationOptions> options;
193};
194} // namespace
195
196//===----------------------------------------------------------------------===//
197// BufferizableOpInterface-based Bufferization
198//===----------------------------------------------------------------------===//
199
200namespace {
201/// A rewriter that keeps track of extra information during bufferization.
202class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
203public:
204 BufferizationRewriter(MLIRContext *ctx, DenseSet<Operation *> &erasedOps,
205 DenseSet<Operation *> &toBufferOps,
206 SmallVector<Operation *> &worklist,
207 const BufferizationOptions &options,
208 BufferizationStatistics *statistics)
209 : IRRewriter(ctx), erasedOps(erasedOps), toBufferOps(toBufferOps),
210 worklist(worklist), analysisState(options), statistics(statistics) {
211 setListener(this);
212 }
213
214protected:
215 void notifyOperationErased(Operation *op) override {
216 erasedOps.insert(op);
217 // Erase if present.
218 toBufferOps.erase(op);
219 }
220
221 void notifyOperationInserted(Operation *op, InsertPoint previous) override {
222 // We only care about newly created ops.
223 if (previous.isSet())
224 return;
225
226 erasedOps.erase(op);
227
228 // Gather statistics about allocs.
229 if (statistics) {
230 if (auto sideEffectingOp = dyn_cast<MemoryEffectOpInterface>(op))
231 statistics->numBufferAlloc += static_cast<int64_t>(
232 sideEffectingOp.hasEffect<MemoryEffects::Allocate>());
233 }
234
235 // Keep track of to_buffer ops.
236 if (isa<ToBufferOp>(op)) {
237 toBufferOps.insert(op);
238 return;
239 }
240
241 // Skip to_tensor ops.
242 if (isa<ToTensorOp>(op))
243 return;
244
245 // Skip non-tensor ops.
246 if (!hasTensorSemantics(op))
247 return;
248
249 // Skip ops that are not allowed to be bufferized.
250 auto const &options = analysisState.getOptions();
251 if (!options.isOpAllowed(op))
252 return;
253
254 // Add op to worklist.
255 worklist.push_back(op);
256 }
257
258private:
259 /// A set of all erased ops.
260 DenseSet<Operation *> &erasedOps;
261
262 /// A set of all to_buffer ops.
263 DenseSet<Operation *> &toBufferOps;
264
265 /// The worklist of ops to be bufferized.
266 SmallVector<Operation *> &worklist;
267
268 /// The analysis state. Used for debug assertions and access to the
269 /// bufferization options.
270 const AnalysisState analysisState;
271
272 /// Bufferization statistics for debugging.
273 BufferizationStatistics *statistics;
274};
275} // namespace
276
279 BufferizationState &bufferizationState,
280 BufferizationStatistics *statistics) {
281 if (options.copyBeforeWrite) {
282 AnalysisState analysisState(options);
283 if (failed(insertTensorCopies(op, analysisState, bufferizationState)))
284 return failure();
285 }
286
287 // Keep track of to_buffer ops.
288 DenseSet<Operation *> toBufferOps;
289 op->walk([&](ToBufferOp toBufferOp) { toBufferOps.insert(toBufferOp); });
290
291 // Gather all bufferizable ops in top-to-bottom order.
292 //
293 // We should ideally know the exact memref type of all operands when
294 // bufferizing an op. (This is the case when bufferizing top-to-bottom.)
295 // Otherwise, we have to use a memref type with a fully dynamic layout map to
296 // avoid copies. We are currently missing patterns for layout maps to
297 // canonicalize away (or canonicalize to more precise layouts).
299 op->walk<WalkOrder::PostOrder>([&](Operation *op) {
300 if (options.isOpAllowed(op) && hasTensorSemantics(op))
301 worklist.push_back(op);
302 });
303
304 // Keep track of all erased ops.
305 DenseSet<Operation *> erasedOps;
306
307 // Bufferize all ops.
308 BufferizationRewriter rewriter(op->getContext(), erasedOps, toBufferOps,
309 worklist, options, statistics);
310 for (unsigned i = 0; i < worklist.size(); ++i) {
311 Operation *nextOp = worklist[i];
312 // Skip ops that were erased.
313 if (erasedOps.contains(nextOp))
314 continue;
315 // Skip ops that are not bufferizable or not allowed.
316 auto bufferizableOp = options.dynCastBufferizableOp(nextOp);
317 if (!bufferizableOp)
318 continue;
319 // Skip ops that no longer have tensor semantics.
320 if (!hasTensorSemantics(nextOp))
321 continue;
322 // Check for unsupported unstructured control flow.
323 if (!bufferizableOp.supportsUnstructuredControlFlow())
324 for (Region &r : nextOp->getRegions())
325 if (r.getBlocks().size() > 1)
326 return nextOp->emitOpError(
327 "op or BufferizableOpInterface implementation does not support "
328 "unstructured control flow, but at least one region has multiple "
329 "blocks");
330
331 // Bufferize the op.
332 LDBG(3) << "//===-------------------------------------------===//\n"
333 << "IR after bufferizing: " << nextOp->getName();
334 rewriter.setInsertionPoint(nextOp);
335 if (failed(
336 bufferizableOp.bufferize(rewriter, options, bufferizationState))) {
337 LDBG(2) << "failed to bufferize\n"
338 << "//===-------------------------------------------===//";
339 return nextOp->emitError("failed to bufferize op");
340 }
341 LDBG(3) << *op << "\n//===-------------------------------------------===//";
342 }
343
344 // Return early if the top-level op is entirely gone.
345 if (erasedOps.contains(op))
346 return success();
347
348 // Fold all to_buffer(to_tensor(x)) pairs. Snapshot the set first:
349 // `foldToBufferToTensorPair` can erase ops, and the rewriter listener
350 // mutates `toBufferOps` from inside that call, which would invalidate
351 // any DenseSet iterator held across it.
352 SmallVector<Operation *> toBufferOpsSnapshot = llvm::to_vector(toBufferOps);
353 for (Operation *op : toBufferOpsSnapshot) {
354 if (erasedOps.contains(op))
355 continue;
356 rewriter.setInsertionPoint(op);
358 rewriter, cast<ToBufferOp>(op), options);
359 }
360
361 // Remove all dead to_tensor ops.
362 op->walk<WalkOrder::PostOrder>([&](ToTensorOp toTensorOp) {
363 if (toTensorOp->getUses().empty()) {
364 rewriter.eraseOp(toTensorOp);
365 return WalkResult::skip();
366 }
367 return WalkResult::advance();
368 });
369
370 /// Check the result of bufferization. Return an error if an op was not
371 /// bufferized, unless partial bufferization is allowed.
372 if (options.allowUnknownOps)
373 return success();
374
375 for (Operation *op : worklist) {
376 // Skip ops that are entirely gone.
377 if (erasedOps.contains(op))
378 continue;
379 // Ops that no longer have tensor semantics (because they were updated
380 // in-place) are allowed.
381 if (!hasTensorSemantics(op))
382 continue;
383 // Continue ops that are not allowed.
384 if (!options.isOpAllowed(op))
385 continue;
386 // Ops without any uses and no side effects will fold away.
387 if (op->getUses().empty() && isMemoryEffectFree(op))
388 continue;
389 // ToTensorOps/ToBufferOps are allowed in the output.
390 if (isa<ToTensorOp, ToBufferOp>(op))
391 continue;
392 return op->emitError("op was not bufferized");
393 }
394
395 return success();
396}
397
398LogicalResult
401 BufferizationState &state) {
402 OpBuilder::InsertionGuard g(rewriter);
403 auto bufferizableOp = options.dynCastBufferizableOp(block->getParentOp());
404 if (!bufferizableOp)
405 return failure();
406
407 // Compute the new signature.
408 SmallVector<Type> newTypes;
409 for (BlockArgument &bbArg : block->getArguments()) {
410 auto tensorType = dyn_cast<TensorLikeType>(bbArg.getType());
411 if (!tensorType) {
412 newTypes.push_back(bbArg.getType());
413 continue;
414 }
415
416 FailureOr<BufferLikeType> bufferType =
417 bufferization::getBufferType(bbArg, options, state);
418 if (failed(bufferType))
419 return failure();
420 newTypes.push_back(*bufferType);
421 }
422
423 // Change the type of all block arguments.
424 for (auto [bbArg, type] : llvm::zip(block->getArguments(), newTypes)) {
425 if (bbArg.getType() == type)
426 continue;
427
428 // Collect all uses of the bbArg.
429 SmallVector<OpOperand *> bbArgUses;
430 for (OpOperand &use : bbArg.getUses())
431 bbArgUses.push_back(&use);
432
433 Type tensorType = bbArg.getType();
434 // Change the bbArg type to memref.
435 bbArg.setType(type);
436
437 // Replace all uses of the original tensor bbArg.
438 rewriter.setInsertionPointToStart(block);
439 if (!bbArgUses.empty()) {
440 Value toTensorOp = bufferization::ToTensorOp::create(
441 rewriter, bbArg.getLoc(), tensorType, bbArg);
442 for (OpOperand *use : bbArgUses)
443 use->set(toTensorOp);
444 }
445 }
446
447 // Bufferize callers of the block.
448 for (Operation *op : block->getUsers()) {
449 auto branchOp = dyn_cast<BranchOpInterface>(op);
450 if (!branchOp)
451 return op->emitOpError("cannot bufferize ops with block references that "
452 "do not implement BranchOpInterface");
453
454 auto it = llvm::find(op->getSuccessors(), block);
455 assert(it != op->getSuccessors().end() && "could find successor");
456 int64_t successorIdx = std::distance(op->getSuccessors().begin(), it);
457
458 SuccessorOperands operands = branchOp.getSuccessorOperands(successorIdx);
459 SmallVector<Value> newOperands;
460 for (auto [operand, type] :
461 llvm::zip(operands.getForwardedOperands(), newTypes)) {
462 if (operand.getType() == type) {
463 // Not a tensor type. Nothing to do for this operand.
464 newOperands.push_back(operand);
465 continue;
466 }
467 FailureOr<BufferLikeType> operandBufferType =
468 bufferization::getBufferType(operand, options, state);
469 if (failed(operandBufferType))
470 return failure();
471 rewriter.setInsertionPointAfterValue(operand);
472 Value bufferizedOperand = bufferization::ToBufferOp::create(
473 rewriter, operand.getLoc(), *operandBufferType, operand);
474 // A cast is needed if the operand and the block argument have different
475 // bufferized types.
476 if (type != *operandBufferType)
477 bufferizedOperand = memref::CastOp::create(rewriter, operand.getLoc(),
478 type, bufferizedOperand);
479 newOperands.push_back(bufferizedOperand);
480 }
481 operands.getMutableForwardedOperands().assign(newOperands);
482 }
483
484 return success();
485}
return success()
b getContext())
static llvm::ManagedStatic< PassManagerOptions > options
Base class for generic analysis states.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:306
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgListType getArguments()
Definition Block.h:97
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
user_range getUsers() const
Returns a range of all users.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
void assign(ValueRange values)
Assign this range to the given values.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Definition Builders.h:423
This class represents an operand of an operation.
Definition Value.h:254
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:115
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:702
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:822
SuccessorRange getSuccessors()
Definition Operation.h:728
use_range getUses()
Returns a range of all uses, which is useful for iterating over all uses.
Definition Operation.h:871
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:233
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
This class models how operands are forwarded to block arguments in control flow.
MutableOperandRange getMutableForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
OperandRange getForwardedOperands() const
Get the range of operands that are simply forwarded to the successor.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options, BufferizationState &bufferizationState, BufferizationStatistics *statistics=nullptr)
Bufferize op and its nested ops that implement BufferizableOpInterface.
LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter, const BufferizationOptions &options, BufferizationState &state)
Bufferize the signature of block and its callers (i.e., ops that have the given block as a successor)...
LogicalResult insertTensorCopies(Operation *op, const OneShotBufferizationOptions &options, const BufferizationState &bufferizationState, BufferizationStatistics *statistics=nullptr)
Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
LogicalResult runOneShotBufferize(Operation *op, const OneShotBufferizationOptions &options, BufferizationState &state, BufferizationStatistics *statistics=nullptr)
Run One-Shot Bufferize on the given op: Analysis + Bufferization.
LogicalResult foldToBufferToTensorPair(RewriterBase &rewriter, ToBufferOp toBuffer, const BufferizationOptions &options)
Try to fold to_buffer(to_tensor(x)).
llvm::LogicalResult runOneShotModuleBufferize(Operation *moduleOp, const bufferization::OneShotBufferizationOptions &options, BufferizationState &state, BufferizationStatistics *statistics=nullptr)
Run One-Shot Module Bufferization on the given SymbolTable.
Include the generated interface declarations.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:122
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
Bufferization statistics for debugging.
Definition Bufferize.h:35
Options for analysis-enabled bufferization.
unsigned analysisFuzzerSeed
Seed for the analysis fuzzer.
bool dumpAliasSets
Specifies whether the tensor IR should be annotated with alias sets.
bool allowReturnAllocsFromLoops
Specifies whether returning newly allocated memrefs from loops should be allowed.
AnalysisHeuristic analysisHeuristic
The heuristic controls the order in which ops are traversed during the analysis.
llvm::ArrayRef< std::string > noAnalysisFuncFilter
Specify the functions that should not be analyzed.