MLIR  19.0.0git
RewriteAsConstant.cpp
Go to the documentation of this file.
1 //===- RewriteAsConstant.cpp - Patterns to rewrite tensor ops as constants ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
11 #include "mlir/IR/Matchers.h"
12 #include "mlir/IR/PatternMatch.h"
13 
14 using namespace mlir;
15 using namespace mlir::tensor;
16 
17 namespace {
18 
19 /// Rewrite tensor.generate with arith.constant if the yielded value is a
20 /// constant and the tensor type is static.
21 struct GenerateToConstant : public OpRewritePattern<GenerateOp> {
23 
24  LogicalResult matchAndRewrite(GenerateOp generateOp,
25  PatternRewriter &rewriter) const override {
26  auto tensorType =
27  llvm::cast<RankedTensorType>(generateOp.getResult().getType());
28  if (!tensorType.hasStaticShape())
29  return failure();
30  auto terminatorOp =
31  cast<tensor::YieldOp>(generateOp.getBody().front().getTerminator());
32  Attribute attr;
33  if (!matchPattern(terminatorOp.getValue(), m_Constant(&attr)))
34  return failure();
35  Operation *constantOp =
36  rewriter.getContext()
37  ->getLoadedDialect<TensorDialect>()
38  ->materializeConstant(rewriter,
39  DenseElementsAttr::get(tensorType, attr),
40  tensorType, generateOp->getLoc());
41  if (!constantOp)
42  return failure();
43  rewriter.replaceOp(generateOp, constantOp->getResults());
44  return success();
45  }
46 };
47 
48 } // namespace
49 
51  RewritePatternSet &patterns) {
52  patterns.add<GenerateToConstant>(patterns.getContext());
53 }
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
Attributes are known-constant values of operations.
Definition: Attributes.h:25
MLIRContext * getContext() const
Definition: Builders.h:55
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
result_range getResults()
Definition: Operation.h:410
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void populateRewriteAsConstantPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that replace tensor ops (such as tensor.generate) with constants when possible.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:310
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:358