MLIR 22.0.0git
EmptyTensorToAllocTensor.cpp
Go to the documentation of this file.
//===- EmptyTensorToAllocTensor.cpp - Lower tensor.empty to alloc_tensor --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
10
15
16namespace mlir {
17namespace bufferization {
18#define GEN_PASS_DEF_EMPTYTENSORTOALLOCTENSORPASS
19#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
20} // namespace bufferization
21} // namespace mlir
22
23using namespace mlir;
24using namespace mlir::bufferization;
25using namespace mlir::tensor;
26
27namespace {
28struct EmptyTensorLoweringPattern : public OpRewritePattern<tensor::EmptyOp> {
29 using OpRewritePattern<tensor::EmptyOp>::OpRewritePattern;
30
31 LogicalResult matchAndRewrite(tensor::EmptyOp op,
32 PatternRewriter &rewriter) const override {
33 rewriter.replaceOpWithNewOp<bufferization::AllocTensorOp>(
34 op, op.getType(), op.getDynamicSizes());
35 return success();
36 }
37};
38
39struct EmptyTensorToAllocTensor
41 EmptyTensorToAllocTensor> {
42 void runOnOperation() override;
43
44 void getDependentDialects(DialectRegistry &registry) const override {
45 registry
46 .insert<tensor::TensorDialect, bufferization::BufferizationDialect>();
47 }
48};
49} // namespace
50
53 patterns.insert<EmptyTensorLoweringPattern>(patterns.getContext());
54}
55
56void EmptyTensorToAllocTensor::runOnOperation() {
57 Operation *op = getOperation();
60 if (failed(applyPatternsGreedily(op, std::move(patterns))))
61 signalPassFailure();
62}
return success()
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
void populateEmptyTensorToAllocTensorPattern(RewritePatternSet &patterns)
Populate patterns to lower tensor.empty ops to bufferization.alloc_tensor ops.
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...