MLIR  16.0.0git
DenseBufferizationPass.cpp
Go to the documentation of this file.
1 //===- DenseBufferizationPass.cpp - Dense bufferization pass --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
16 
17 using namespace mlir;
18 using namespace mlir::func;
19 
20 namespace mlir {
21 namespace sparse_tensor {
22 
23 /// Return `true` if one of the given types is a sparse tensor type.
24 static bool containsSparseTensor(TypeRange types) {
25  for (Type t : types)
27  return true;
28  return false;
29 }
30 
31 /// A pass that bufferizes only dense tensor ops and ignores all sparse tensor
32 /// ops. No buffer copies are inserted. All tensor OpOperands must be
33 /// inplacable.
35  : public PassWrapper<BufferizeDenseOpsPass, OperationPass<ModuleOp>> {
36 public:
39  : options(options) {}
40 
41  void runOnOperation() override {
42  // Disallow all sparse tensor ops, so that only dense tensor ops are
43  // bufferized.
44  bufferization::OpFilter opFilter;
45  opFilter.allowOperation([&](Operation *op) {
48  return false;
49  if (auto funcOp = dyn_cast<func::FuncOp>(op)) {
50  FunctionType funcType = funcOp.getFunctionType();
51  if (containsSparseTensor(funcType.getInputs()) ||
52  containsSparseTensor(funcType.getResults()))
53  return false;
54  }
55  return true;
56  });
57 
58  if (failed(bufferization::bufferizeOp(getOperation(), options,
59  /*copyBeforeWrite=*/false,
60  &opFilter)))
61  signalPassFailure();
62  }
63 
64 private:
66 };
67 } // namespace sparse_tensor
68 } // namespace mlir
69 
70 std::unique_ptr<Pass> mlir::createDenseBufferizationPass(
72  return std::make_unique<mlir::sparse_tensor::BufferizeDenseOpsPass>(options);
73 }
Include the generated interface declarations.
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
static bool containsSparseTensor(TypeRange types)
Return true if one of the given types is a sparse tensor type.
operand_range getOperands()
Returns an iterator on the underlying Values.
Definition: Operation.h:295
void allowOperation()
Allow the given ops.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
BufferizeDenseOpsPass(const bufferization::OneShotBufferizationOptions &options)
A pass that bufferizes only dense tensor ops and ignores all sparse tensor ops.
std::unique_ptr< Pass > createDenseBufferizationPass(const bufferization::OneShotBufferizationOptions &options)
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:32
void runOnOperation() override
The polymorphic API that runs the pass over the currently held operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
static llvm::ManagedStatic< PassManagerOptions > options
LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options, bool copyBeforeWrite=true, const OpFilter *opFilter=nullptr)
Bufferize op and its nested ops that implement BufferizableOpInterface.
Definition: Bufferize.cpp:388
result_range getResults()
Definition: Operation.h:332
Options for analysis-enabled bufferization.
This class provides a CRTP wrapper around a base pass class to define several necessary utility methods.
Definition: Pass.h:440