MLIR  22.0.0git
NamedToElementwise.cpp
Go to the documentation of this file.
1 //===- NamedToElementwise.cpp - convert linalg named op into elementwise --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements rewriting those linalg named ops that are essentially
10 // elementwise e.g. `linalg.exp`, to `linalg.elementwise`. This allows further
11 // optimization on `linalg.elementwise` such as folding transpose, broadcast.
12 //
13 //===----------------------------------------------------------------------===//
14 
18 #include "mlir/IR/PatternMatch.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 
23 using namespace mlir;
24 using namespace mlir::linalg;
25 
26 #define DEBUG_TYPE "linalg-named-to-elementwise"
27 
28 namespace {
29 ElementwiseKind getKind(Operation *op) {
31  .Case([](SelectOp) { return ElementwiseKind::select; })
32  .Case([](AddOp) { return ElementwiseKind::add; })
33  .Case([](SubOp) { return ElementwiseKind::sub; })
34  .Case([](MulOp) { return ElementwiseKind::mul; })
35  .Case([](DivOp) { return ElementwiseKind::div; })
36  .Case([](DivUnsignedOp) { return ElementwiseKind::div_unsigned; })
37  .Case([](PowFOp) { return ElementwiseKind::powf; })
38  .Case([](ExpOp) { return ElementwiseKind::exp; })
39  .Case([](LogOp) { return ElementwiseKind::log; })
40  .Case([](AbsOp) { return ElementwiseKind::abs; })
41  .Case([](CeilOp) { return ElementwiseKind::ceil; })
42  .Case([](FloorOp) { return ElementwiseKind::floor; })
43  .Case([](NegFOp) { return ElementwiseKind::negf; })
44  .Case([](ReciprocalOp) { return ElementwiseKind::reciprocal; })
45  .Case([](RoundOp) { return ElementwiseKind::round; })
46  .Case([](SqrtOp) { return ElementwiseKind::sqrt; })
47  .Case([](RsqrtOp) { return ElementwiseKind::rsqrt; })
48  .Case([](SquareOp) { return ElementwiseKind::square; })
49  .Case([](TanhOp) { return ElementwiseKind::tanh; })
50  .Case([](ErfOp) { return ElementwiseKind::erf; })
51  .Default([&](Operation *op) {
52  llvm_unreachable("unhandled case in named to elementwise");
53  return ElementwiseKind::sub;
54  });
55 }
56 
57 template <typename NamedOpTy>
58 struct NamedToElementwisePattern : public OpRewritePattern<NamedOpTy> {
60 
61  LogicalResult matchAndRewrite(NamedOpTy op,
62  PatternRewriter &rewriter) const override {
64  auto kindAttr = ElementwiseKindAttr::get(op.getContext(), getKind(op));
65  attrs.push_back(rewriter.getNamedAttr("kind", kindAttr));
66  attrs.push_back(
67  rewriter.getNamedAttr("indexing_maps", op.getIndexingMaps()));
68 
69  rewriter.replaceOpWithNewOp<ElementwiseOp>(op, op.getDpsInputs(),
70  op.getDpsInits(), attrs);
71  return success();
72  }
73 };
74 } // namespace
75 
78  patterns.add<NamedToElementwisePattern<SelectOp>>(patterns.getContext());
79  patterns.add<NamedToElementwisePattern<AddOp>>(patterns.getContext());
80  patterns.add<NamedToElementwisePattern<SubOp>>(patterns.getContext());
81  patterns.add<NamedToElementwisePattern<MulOp>>(patterns.getContext());
82  patterns.add<NamedToElementwisePattern<DivOp>>(patterns.getContext());
83  patterns.add<NamedToElementwisePattern<DivUnsignedOp>>(patterns.getContext());
84  patterns.add<NamedToElementwisePattern<PowFOp>>(patterns.getContext());
85  patterns.add<NamedToElementwisePattern<ExpOp>>(patterns.getContext());
86  patterns.add<NamedToElementwisePattern<LogOp>>(patterns.getContext());
87  patterns.add<NamedToElementwisePattern<AbsOp>>(patterns.getContext());
88  patterns.add<NamedToElementwisePattern<CeilOp>>(patterns.getContext());
89  patterns.add<NamedToElementwisePattern<FloorOp>>(patterns.getContext());
90  patterns.add<NamedToElementwisePattern<NegFOp>>(patterns.getContext());
91  patterns.add<NamedToElementwisePattern<ReciprocalOp>>(patterns.getContext());
92  patterns.add<NamedToElementwisePattern<RoundOp>>(patterns.getContext());
93  patterns.add<NamedToElementwisePattern<SqrtOp>>(patterns.getContext());
94  patterns.add<NamedToElementwisePattern<RsqrtOp>>(patterns.getContext());
95  patterns.add<NamedToElementwisePattern<SquareOp>>(patterns.getContext());
96  patterns.add<NamedToElementwisePattern<TanhOp>>(patterns.getContext());
97  patterns.add<NamedToElementwisePattern<ErfOp>>(patterns.getContext());
98 }
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition: Builders.cpp:93
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:783
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
void populateLinalgNamedToElementwisePatterns(RewritePatternSet &patterns)
Populates patterns that convert linalg named ops, e.g. `linalg.exp`, to the equivalent `linalg.elementwise`.
DynamicAPInt floor(const Fraction &f)
Definition: Fraction.h:77
DynamicAPInt ceil(const Fraction &f)
Definition: Fraction.h:79
DynamicAPInt round(const Fraction &f)
Definition: Fraction.h:136
Fraction abs(const Fraction &f)
Definition: Fraction.h:107
detail::LazyTextBuild add(const char *fmt, Ts &&...ts)
Create a Remark with llvm::formatv formatting.
Definition: Remarks.h:463
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:314