MLIR 23.0.0git
CategoryToNamedOp.cpp
Go to the documentation of this file.
1//===- CategoryToNamedOp.cpp - convert category ops to linalg named ops ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements rewriting of linalg category ops (e.g.
10// `linalg.elementwise`) to their equivalent named ops (e.g. `linalg.add`,
11// `linalg.exp`). This is the reverse of NamedToElementwise.cpp.
12//
13//===----------------------------------------------------------------------===//
14
18
19using namespace mlir;
20using namespace mlir::linalg;
21
22#define DEBUG_TYPE "linalg-category-to-named"
23
24namespace {
/// Rewrites a `linalg.elementwise` op (ElementwiseOp) into the equivalent
/// linalg named op (e.g. `linalg.exp`, `linalg.add`), the reverse of the
/// NamedToElementwise rewrite.
///
/// The pattern only applies when every indexing map of the elementwise op is
/// an identity map (or there are none) and its kind is one of the cases in the
/// switch below. Otherwise it returns failure() and the op is left untouched.
25struct ElementwiseToNamedPattern : public OpRewritePattern<ElementwiseOp> {
26 using OpRewritePattern<ElementwiseOp>::OpRewritePattern;
27
28 LogicalResult matchAndRewrite(ElementwiseOp op,
29 PatternRewriter &rewriter) const override {
30 // Named elementwise ops only support identity indexing maps.
  // The match proceeds when the map list is empty or all maps are identities.
  // NOTE(review): the `!empty()` guard is redundant. llvm::all_of over an
  // empty range already yields true, so the condition is equivalent without it.
31 if (!op.getIndexingMapsArray().empty() &&
32 !llvm::all_of(op.getIndexingMapsArray(),
33 [](AffineMap map) { return map.isIdentity(); }))
34 return failure();
35
  // The named op reuses the elementwise op's DPS inputs, inits and location.
36 auto inputs = op.getDpsInputs();
37 auto inits = op.getDpsInits();
38 auto loc = op.getLoc();
39
40 // Helper to create a named op and replace the elementwise op.
  // `namedOp` is a default-constructed op handle (e.g. `ExpOp{}`). In the
  // visible code it is only used via decltype to select which op type to
  // create.
41 auto replaceWith = [&](auto namedOp) {
42 using OpTy = decltype(namedOp);
  // NOTE(review): the continuation of this create call (source line 44) is
  // absent from this listing. Confirm the trailing arguments against the
  // original file.
43 rewriter.replaceOp(op, OpTy::create(rewriter, loc, inputs, inits,
45 return success();
46 };
47
  // Map each supported elementwise kind to its named op. Note that the signed
  // min/max kinds map to MinOp/MaxOp. Any kind not listed falls through to
  // `default` and the pattern does not apply.
48 switch (op.getKind()) {
49 case ElementwiseKind::exp:
50 return replaceWith(ExpOp{});
51 case ElementwiseKind::log:
52 return replaceWith(LogOp{});
53 case ElementwiseKind::abs:
54 return replaceWith(AbsOp{});
55 case ElementwiseKind::ceil:
56 return replaceWith(CeilOp{});
57 case ElementwiseKind::floor:
58 return replaceWith(FloorOp{});
59 case ElementwiseKind::negf:
60 return replaceWith(NegFOp{});
61 case ElementwiseKind::reciprocal:
62 return replaceWith(ReciprocalOp{});
63 case ElementwiseKind::round:
64 return replaceWith(RoundOp{});
65 case ElementwiseKind::sqrt:
66 return replaceWith(SqrtOp{});
67 case ElementwiseKind::rsqrt:
68 return replaceWith(RsqrtOp{});
69 case ElementwiseKind::square:
70 return replaceWith(SquareOp{});
71 case ElementwiseKind::tanh:
72 return replaceWith(TanhOp{});
73 case ElementwiseKind::erf:
74 return replaceWith(ErfOp{});
75 case ElementwiseKind::add:
76 return replaceWith(AddOp{});
77 case ElementwiseKind::sub:
78 return replaceWith(SubOp{});
79 case ElementwiseKind::mul:
80 return replaceWith(MulOp{});
81 case ElementwiseKind::div:
82 return replaceWith(DivOp{});
83 case ElementwiseKind::div_unsigned:
84 return replaceWith(DivUnsignedOp{});
85 case ElementwiseKind::max_signed:
86 return replaceWith(MaxOp{});
87 case ElementwiseKind::min_signed:
88 return replaceWith(MinOp{});
89 case ElementwiseKind::powf:
90 return replaceWith(PowFOp{});
91 case ElementwiseKind::select:
92 return replaceWith(SelectOp{});
93 default:
  // No named-op equivalent is handled here for this kind.
94 return failure();
95 }
96 }
97};
98} // namespace
99
101 RewritePatternSet &patterns) {
102 patterns.add<ElementwiseToNamedPattern>(patterns.getContext());
103}
return success()
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition AffineMap.h:46
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void populateLinalgCategoryToNamedPatterns(RewritePatternSet &patterns)
Populates patterns that convert linalg category ops (e.g.
Include the generated interface declarations.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...