MLIR  19.0.0git
EmptyTensorToAllocTensor.cpp
Go to the documentation of this file.
1 //===- EmptyTensorToAllocTensor.cpp - Lower tensor.empty to alloc_tensor --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

17 namespace mlir {
18 namespace bufferization {
19 #define GEN_PASS_DEF_EMPTYTENSORTOALLOCTENSOR
20 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
21 } // namespace bufferization
22 } // namespace mlir
23 
24 using namespace mlir;
25 using namespace mlir::bufferization;
26 using namespace mlir::tensor;
27 
28 namespace {
29 struct EmptyTensorLoweringPattern : public OpRewritePattern<tensor::EmptyOp> {
31 
32  LogicalResult matchAndRewrite(tensor::EmptyOp op,
33  PatternRewriter &rewriter) const override {
34  rewriter.replaceOpWithNewOp<bufferization::AllocTensorOp>(
35  op, op.getType(), op.getDynamicSizes());
36  return success();
37  }
38 };
39 
40 struct EmptyTensorToAllocTensor
41  : public bufferization::impl::EmptyTensorToAllocTensorBase<
42  EmptyTensorToAllocTensor> {
43  EmptyTensorToAllocTensor() = default;
44 
45  void runOnOperation() override;
46 
47  void getDependentDialects(DialectRegistry &registry) const override {
48  registry
49  .insert<tensor::TensorDialect, bufferization::BufferizationDialect>();
50  }
51 };
52 } // namespace
53 
55  RewritePatternSet &patterns) {
56  patterns.insert<EmptyTensorLoweringPattern>(patterns.getContext());
57 }
58 
59 void EmptyTensorToAllocTensor::runOnOperation() {
60  Operation *op = getOperation();
61  RewritePatternSet patterns(op->getContext());
63  if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
64  signalPassFailure();
65 }
66 
67 std::unique_ptr<Pass>
69  return std::make_unique<EmptyTensorToAllocTensor>();
70 }
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:785
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:930
MLIRContext * getContext() const
Definition: PatternMatch.h:822
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (replace cannot fail).
Definition: PatternMatch.h:536
void populateEmptyTensorToAllocTensorPattern(RewritePatternSet &patterns)
Populate patterns to lower tensor.empty ops to bufferization.alloc_tensor ops.
std::unique_ptr< Pass > createEmptyTensorToAllocTensorPass()
Create a pass that rewrites tensor.empty to bufferization.alloc_tensor.
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy worklist driven manner until a fixpoint is reached.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:358