MLIR 22.0.0git
NamedToElementwise.cpp
Go to the documentation of this file.
1//===- NamedToElementwise.cpp - convert linalg named op into elementwise --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements the rewriting of linalg named ops that are essentially
// elementwise (e.g. `linalg.exp`) into `linalg.elementwise`. This enables
// further optimizations on `linalg.elementwise`, such as folding transposes
// and broadcasts.
12//
13//===----------------------------------------------------------------------===//
14
20#include "llvm/ADT/SmallVector.h"
21#include "llvm/ADT/TypeSwitch.h"
22
23using namespace mlir;
24using namespace mlir::linalg;
25
26#define DEBUG_TYPE "linalg-named-to-elementwise"
27
28namespace {
29ElementwiseKind getKind(Operation *op) {
31 .Case([](SelectOp) { return ElementwiseKind::select; })
32 .Case([](AddOp) { return ElementwiseKind::add; })
33 .Case([](SubOp) { return ElementwiseKind::sub; })
34 .Case([](MulOp) { return ElementwiseKind::mul; })
35 .Case([](DivOp) { return ElementwiseKind::div; })
36 .Case([](DivUnsignedOp) { return ElementwiseKind::div_unsigned; })
37 .Case([](PowFOp) { return ElementwiseKind::powf; })
38 .Case([](ExpOp) { return ElementwiseKind::exp; })
39 .Case([](LogOp) { return ElementwiseKind::log; })
40 .Case([](AbsOp) { return ElementwiseKind::abs; })
41 .Case([](CeilOp) { return ElementwiseKind::ceil; })
42 .Case([](FloorOp) { return ElementwiseKind::floor; })
43 .Case([](NegFOp) { return ElementwiseKind::negf; })
44 .Case([](ReciprocalOp) { return ElementwiseKind::reciprocal; })
45 .Case([](RoundOp) { return ElementwiseKind::round; })
46 .Case([](SqrtOp) { return ElementwiseKind::sqrt; })
47 .Case([](RsqrtOp) { return ElementwiseKind::rsqrt; })
48 .Case([](SquareOp) { return ElementwiseKind::square; })
49 .Case([](TanhOp) { return ElementwiseKind::tanh; })
50 .Case([](ErfOp) { return ElementwiseKind::erf; })
51 .DefaultUnreachable("unhandled case in named to elementwise");
52}
53
54template <typename NamedOpTy>
55struct NamedToElementwisePattern : public OpRewritePattern<NamedOpTy> {
57
58 LogicalResult matchAndRewrite(NamedOpTy op,
59 PatternRewriter &rewriter) const override {
61 auto kindAttr = ElementwiseKindAttr::get(op.getContext(), getKind(op));
62 attrs.push_back(rewriter.getNamedAttr("kind", kindAttr));
63 attrs.push_back(
64 rewriter.getNamedAttr("indexing_maps", op.getIndexingMaps()));
65
66 rewriter.replaceOpWithNewOp<ElementwiseOp>(op, op.getDpsInputs(),
67 op.getDpsInits(), attrs);
68 return success();
69 }
70};
71} // namespace
72
75 patterns.add<NamedToElementwisePattern<SelectOp>>(patterns.getContext());
76 patterns.add<NamedToElementwisePattern<AddOp>>(patterns.getContext());
77 patterns.add<NamedToElementwisePattern<SubOp>>(patterns.getContext());
78 patterns.add<NamedToElementwisePattern<MulOp>>(patterns.getContext());
79 patterns.add<NamedToElementwisePattern<DivOp>>(patterns.getContext());
80 patterns.add<NamedToElementwisePattern<DivUnsignedOp>>(patterns.getContext());
81 patterns.add<NamedToElementwisePattern<PowFOp>>(patterns.getContext());
82 patterns.add<NamedToElementwisePattern<ExpOp>>(patterns.getContext());
83 patterns.add<NamedToElementwisePattern<LogOp>>(patterns.getContext());
84 patterns.add<NamedToElementwisePattern<AbsOp>>(patterns.getContext());
85 patterns.add<NamedToElementwisePattern<CeilOp>>(patterns.getContext());
86 patterns.add<NamedToElementwisePattern<FloorOp>>(patterns.getContext());
87 patterns.add<NamedToElementwisePattern<NegFOp>>(patterns.getContext());
88 patterns.add<NamedToElementwisePattern<ReciprocalOp>>(patterns.getContext());
89 patterns.add<NamedToElementwisePattern<RoundOp>>(patterns.getContext());
90 patterns.add<NamedToElementwisePattern<SqrtOp>>(patterns.getContext());
91 patterns.add<NamedToElementwisePattern<RsqrtOp>>(patterns.getContext());
92 patterns.add<NamedToElementwisePattern<SquareOp>>(patterns.getContext());
93 patterns.add<NamedToElementwisePattern<TanhOp>>(patterns.getContext());
94 patterns.add<NamedToElementwisePattern<ErfOp>>(patterns.getContext());
95}
return success()
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition Builders.cpp:94
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
void populateLinalgNamedToElementwisePatterns(RewritePatternSet &patterns)
Populates patterns that convert linalg named ops e.g.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...