MLIR  20.0.0git
EmptyTensorToAllocTensor.cpp
Go to the documentation of this file.
1 //===- EmptyTensorToAllocTensor.cpp - Lower tensor.empty to alloc_tensor --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
14 #include "mlir/Pass/Pass.h"
16 
17 namespace mlir {
18 namespace bufferization {
19 #define GEN_PASS_DEF_EMPTYTENSORTOALLOCTENSOR
20 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
21 } // namespace bufferization
22 } // namespace mlir
23 
24 using namespace mlir;
25 using namespace mlir::bufferization;
26 using namespace mlir::tensor;
27 
28 namespace {
29 struct EmptyTensorLoweringPattern : public OpRewritePattern<tensor::EmptyOp> {
31 
32  LogicalResult matchAndRewrite(tensor::EmptyOp op,
33  PatternRewriter &rewriter) const override {
34  rewriter.replaceOpWithNewOp<bufferization::AllocTensorOp>(
35  op, op.getType(), op.getDynamicSizes());
36  return success();
37  }
38 };
39 
40 struct EmptyTensorToAllocTensor
41  : public bufferization::impl::EmptyTensorToAllocTensorBase<
42  EmptyTensorToAllocTensor> {
43  EmptyTensorToAllocTensor() = default;
44 
45  void runOnOperation() override;
46 
47  void getDependentDialects(DialectRegistry &registry) const override {
48  registry
49  .insert<tensor::TensorDialect, bufferization::BufferizationDialect>();
50  }
51 };
52 } // namespace
53 
55  RewritePatternSet &patterns) {
56  patterns.insert<EmptyTensorLoweringPattern>(patterns.getContext());
57 }
58 
59 void EmptyTensorToAllocTensor::runOnOperation() {
60  Operation *op = getOperation();
61  RewritePatternSet patterns(op->getContext());
63  if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
64  signalPassFailure();
65 }
66 
67 std::unique_ptr<Pass>
69  return std::make_unique<EmptyTensorToAllocTensor>();
70 }
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:931
MLIRContext * getContext() const
Definition: PatternMatch.h:823
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
void populateEmptyTensorToAllocTensorPattern(RewritePatternSet &patterns)
Populate patterns to lower tensor.empty ops to bufferization.alloc_tensor ops.
std::unique_ptr< Pass > createEmptyTensorToAllocTensorPass()
Create a pass that rewrites tensor.empty to bufferization.alloc_tensor.
Include the generated interface declarations.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358