1 //===- RewriteAsConstant.cpp - Patterns to rewrite tensor ops as constants ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
14 using namespace mlir;
15 using namespace mlir::tensor;
17 namespace {
19 /// Rewrite tensor.generate with arith.constant if the yielded value is a
20 /// constant and the tensor type is static.
21 struct GenerateToConstant : public OpRewritePattern<GenerateOp> {
24  LogicalResult matchAndRewrite(GenerateOp generateOp,
25  PatternRewriter &rewriter) const override {
26  auto tensorType =
27  llvm::cast<RankedTensorType>(generateOp.getResult().getType());
28  if (!tensorType.hasStaticShape())
29  return failure();
30  auto terminatorOp =
31  cast<tensor::YieldOp>(generateOp.getBody().front().getTerminator());
32  Attribute attr;
33  if (!matchPattern(terminatorOp.getValue(), m_Constant(&attr)))
34  return failure();
35  Operation *constantOp =
36  rewriter.getContext()
37  ->getLoadedDialect<TensorDialect>()
38  ->materializeConstant(rewriter,
39  DenseElementsAttr::get(tensorType, attr),
40  tensorType, generateOp->getLoc());
41  if (!constantOp)
42  return failure();
43  rewriter.replaceOp(generateOp, constantOp->getResults());
44  return success();
45  }
46 };
48 } // namespace
51  RewritePatternSet &patterns) {
52  patterns.add<GenerateToConstant>(patterns.getContext());
53 }
