MLIR  19.0.0git
Specialize.cpp
Go to the documentation of this file.
1 //===- Specialize.cpp - linalg generic ops to named ops ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a method to specialize generic operations to named
10 // operations. Conceptually it is the opposite of generalize.cpp.
11 //
12 //===----------------------------------------------------------------------===//
13 
18 #include "llvm/Support/Debug.h"
19 
20 #define DEBUG_TYPE "linalg-specialization"
21 
// Expands to an expression that replaces `genericOp` with a newly created
// named op of type NEWOP, built from the generic op's first two DPS inputs
// and its first DPS init. When OPERANDS_SWAP is true the two inputs are
// passed in reverse order (input 1 first, then input 0).
//
// The expansion refers to `rewriter` and `genericOp` by name, so both must be
// in scope at the point of use. OPERANDS_SWAP appears twice in the expansion;
// pass an expression without side effects.
#define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP)                                \
  (rewriter.replaceOpWithNewOp<NEWOP>(                                         \
      genericOp,                                                               \
      ValueRange{genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 1 : 0],            \
                 genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 0 : 1]},           \
      ValueRange{genericOp.getDpsInits()[0]}))
28 
// Expands to an expression that replaces `genericOp` with a newly created
// named op of type NEWOP, built from the generic op's first DPS input and its
// first DPS init.
//
// The expansion refers to `rewriter` and `genericOp` by name, so both must be
// in scope at the point of use.
#define REPLACE_UNARY_OP(NEWOP)                                                \
  (rewriter.replaceOpWithNewOp<NEWOP>(                                         \
      genericOp, ValueRange{genericOp.getDpsInputs()[0]},                      \
      ValueRange{genericOp.getDpsInits()[0]}))
33 
34 using namespace mlir;
35 using namespace mlir::linalg;
36 
37 // Given a elementwise single binary linalg generic op, checks whether the
38 // binary op accesses operands as swapped. e.g.
39 // this differentiates between a linalg-generic body that contains:
40 // ^bb0(%a: f32, %b: f32, %c : f32):
41 // %0 = arith.subf %a, %b : f32
42 // linalg.yield %0: f32
43 // against:
44 // ^bb0(%a: f32, %b: f32, %c : f32):
45 // %0 = arith.subf %b, %a : f32
46 // linalg.yield %0: f32
47 // Former is linalg.sub(a,b), latter is linalg.sub(b,a).
48 static bool areBinOpsSwapped(GenericOp genericOp) {
49  Block *body = genericOp.getBody();
50  Operation *op = &body->front();
51  bool swapped = false;
52  if (op->getOpOperand(0).get() != body->getArgument(0)) {
53  swapped = true;
54  assert(op->getOpOperand(0).get() == body->getArgument(1) &&
55  op->getOpOperand(1).get() == body->getArgument(0) &&
56  "binary op uses just one block arg");
57  }
58  return swapped;
59 }
60 
62  GenericOp genericOp) {
63  if (isaCopyOpInterface(genericOp)) {
64  LinalgOp namedOp = rewriter.replaceOpWithNewOp<CopyOp>(
65  genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
66  return namedOp;
67  }
68 
69  if (isaFillOpInterface(genericOp)) {
70  LinalgOp namedOp = rewriter.replaceOpWithNewOp<FillOp>(
71  genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
72  return namedOp;
73  }
74 
75  if (isaElemwiseSingleUnaryOpInterface(genericOp)) {
76  Operation *op = &genericOp.getBody()->front();
77  if (isa<math::ExpOp>(op)) {
78  LinalgOp namedOp = REPLACE_UNARY_OP(ExpOp);
79  return namedOp;
80  }
81  }
82 
83  if (isaElemwiseSingleBinaryOpInterface(genericOp)) {
84  bool swap = areBinOpsSwapped(genericOp);
85  Operation *op = &genericOp.getBody()->front();
86  if (isa<arith::AddFOp>(op)) {
87  LinalgOp namedOp = REPLACE_BINARY_OP(AddOp, swap);
88  return namedOp;
89  }
90  if (isa<arith::SubFOp>(op)) {
91  LinalgOp namedOp = REPLACE_BINARY_OP(SubOp, swap);
92  return namedOp;
93  }
94  if (isa<arith::MulFOp>(op)) {
95  LinalgOp namedOp = REPLACE_BINARY_OP(MulOp, swap);
96  return namedOp;
97  }
98  if (isa<arith::DivFOp>(op)) {
99  LinalgOp namedOp = REPLACE_BINARY_OP(DivOp, swap);
100  return namedOp;
101  }
102  }
103  return failure();
104 }
#define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP)
Definition: Specialize.cpp:22
static bool areBinOpsSwapped(GenericOp genericOp)
Definition: Specialize.cpp:48
#define REPLACE_UNARY_OP(NEWOP)
Definition: Specialize.cpp:29
Block represents an ordered list of Operations.
Definition: Block.h:31
BlockArgument getArgument(unsigned i)
Definition: Block.h:127
Operation & front()
Definition: Block.h:151
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpOperand & getOpOperand(unsigned idx)
Definition: Operation.h:383
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp)
Checks whether a given genericOp is semantically equivalent to a single linalg elementwise unary op.
bool isaCopyOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a linalg.copyOp.
FailureOr< LinalgOp > specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp)
Create a namedOp from the given GenericOp and replace the GenericOp.
Definition: Specialize.cpp:61
std::optional< Value > isaFillOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.fill.
bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a single linalg elementwise binary op e....
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62