MLIR  22.0.0git
EmptyTensorToAllocTensor.cpp
Go to the documentation of this file.
//===- EmptyTensorToAllocTensor.cpp - Lower tensor.empty to alloc_tensor --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
15 
16 namespace mlir {
17 namespace bufferization {
18 #define GEN_PASS_DEF_EMPTYTENSORTOALLOCTENSORPASS
19 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
20 } // namespace bufferization
21 } // namespace mlir
22 
23 using namespace mlir;
24 using namespace mlir::bufferization;
25 using namespace mlir::tensor;
26 
27 namespace {
28 struct EmptyTensorLoweringPattern : public OpRewritePattern<tensor::EmptyOp> {
30 
31  LogicalResult matchAndRewrite(tensor::EmptyOp op,
32  PatternRewriter &rewriter) const override {
33  rewriter.replaceOpWithNewOp<bufferization::AllocTensorOp>(
34  op, op.getType(), op.getDynamicSizes());
35  return success();
36  }
37 };
38 
39 struct EmptyTensorToAllocTensor
40  : public bufferization::impl::EmptyTensorToAllocTensorPassBase<
41  EmptyTensorToAllocTensor> {
42  void runOnOperation() override;
43 
44  void getDependentDialects(DialectRegistry &registry) const override {
45  registry
46  .insert<tensor::TensorDialect, bufferization::BufferizationDialect>();
47  }
48 };
49 } // namespace
50 
53  patterns.insert<EmptyTensorLoweringPattern>(patterns.getContext());
54 }
55 
56 void EmptyTensorToAllocTensor::runOnOperation() {
57  Operation *op = getOperation();
60  if (failed(applyPatternsGreedily(op, std::move(patterns))))
61  signalPassFailure();
62 }
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:767
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
void populateEmptyTensorToAllocTensorPattern(RewritePatternSet &patterns)
Populate patterns to lower tensor.empty ops to bufferization.alloc_tensor ops.
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy worklist-driven manner until a fixpoint is reached.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:314