MLIR  16.0.0git
ComplexToLibm.cpp
Go to the documentation of this file.
1 //===-- ComplexToLibm.cpp - conversion from Complex to libm calls ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
13 #include "mlir/IR/PatternMatch.h"
14 #include "mlir/Pass/Pass.h"
15 
16 namespace mlir {
17 #define GEN_PASS_DEF_CONVERTCOMPLEXTOLIBM
18 #include "mlir/Conversion/Passes.h.inc"
19 } // namespace mlir
20 
21 using namespace mlir;
22 
23 namespace {
24 // Functor to resolve the function name corresponding to the given complex
25 // result type.
26 struct ComplexTypeResolver {
27  llvm::Optional<bool> operator()(Type type) const {
28  auto complexType = type.cast<ComplexType>();
29  auto elementType = complexType.getElementType();
30  if (!elementType.isa<Float32Type, Float64Type>())
31  return {};
32 
33  return elementType.getIntOrFloatBitWidth() == 64;
34  }
35 };
36 
37 // Functor to resolve the function name corresponding to the given float result
38 // type.
39 struct FloatTypeResolver {
40  llvm::Optional<bool> operator()(Type type) const {
41  auto elementType = type.cast<FloatType>();
42  if (!elementType.isa<Float32Type, Float64Type>())
43  return {};
44 
45  return elementType.getIntOrFloatBitWidth() == 64;
46  }
47 };
48 
49 // Pattern to convert scalar complex operations to calls to libm functions.
50 // Additionally the libm function signatures are declared.
51 // TypeResolver is a functor returning the libm function name according to the
52 // expected type double or float.
53 template <typename Op, typename TypeResolver = ComplexTypeResolver>
54 struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
55 public:
57  ScalarOpToLibmCall<Op, TypeResolver>(MLIRContext *context,
58  StringRef floatFunc,
59  StringRef doubleFunc,
60  PatternBenefit benefit)
61  : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
62  doubleFunc(doubleFunc){};
63 
64  LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
65 
66 private:
67  std::string floatFunc, doubleFunc;
68 };
69 } // namespace
70 
71 template <typename Op, typename TypeResolver>
72 LogicalResult ScalarOpToLibmCall<Op, TypeResolver>::matchAndRewrite(
73  Op op, PatternRewriter &rewriter) const {
74  auto module = SymbolTable::getNearestSymbolTable(op);
75  auto isDouble = TypeResolver()(op.getType());
76  if (!isDouble.has_value())
77  return failure();
78 
79  auto name = isDouble.value() ? doubleFunc : floatFunc;
80 
81  auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
82  SymbolTable::lookupSymbolIn(module, name));
83  // Forward declare function if it hasn't already been
84  if (!opFunc) {
85  OpBuilder::InsertionGuard guard(rewriter);
86  rewriter.setInsertionPointToStart(&module->getRegion(0).front());
87  auto opFunctionTy = FunctionType::get(
88  rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
89  opFunc = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(), name,
90  opFunctionTy);
91  opFunc.setPrivate();
92  }
93  assert(isa<FunctionOpInterface>(SymbolTable::lookupSymbolIn(module, name)));
94 
95  rewriter.replaceOpWithNewOp<func::CallOp>(op, name, op.getType(),
96  op->getOperands());
97 
98  return success();
99 }
100 
102  PatternBenefit benefit) {
103  patterns.add<ScalarOpToLibmCall<complex::PowOp>>(patterns.getContext(),
104  "cpowf", "cpow", benefit);
105  patterns.add<ScalarOpToLibmCall<complex::SqrtOp>>(patterns.getContext(),
106  "csqrtf", "csqrt", benefit);
107  patterns.add<ScalarOpToLibmCall<complex::TanhOp>>(patterns.getContext(),
108  "ctanhf", "ctanh", benefit);
109  patterns.add<ScalarOpToLibmCall<complex::CosOp>>(patterns.getContext(),
110  "ccosf", "ccos", benefit);
111  patterns.add<ScalarOpToLibmCall<complex::SinOp>>(patterns.getContext(),
112  "csinf", "csin", benefit);
113  patterns.add<ScalarOpToLibmCall<complex::ConjOp>>(patterns.getContext(),
114  "conjf", "conj", benefit);
115  patterns.add<ScalarOpToLibmCall<complex::LogOp>>(patterns.getContext(),
116  "clogf", "clog", benefit);
117  patterns.add<ScalarOpToLibmCall<complex::AbsOp, FloatTypeResolver>>(
118  patterns.getContext(), "cabsf", "cabs", benefit);
119  patterns.add<ScalarOpToLibmCall<complex::AngleOp, FloatTypeResolver>>(
120  patterns.getContext(), "cargf", "carg", benefit);
121 }
122 
123 namespace {
124 struct ConvertComplexToLibmPass
125  : public impl::ConvertComplexToLibmBase<ConvertComplexToLibmPass> {
126  void runOnOperation() override;
127 };
128 } // namespace
129 
130 void ConvertComplexToLibmPass::runOnOperation() {
131  auto module = getOperation();
132 
133  RewritePatternSet patterns(&getContext());
134  populateComplexToLibmConversionPatterns(patterns, /*benefit=*/1);
135 
136  ConversionTarget target(getContext());
137  target.addLegalDialect<func::FuncDialect>();
138  target.addIllegalOp<complex::PowOp, complex::SqrtOp, complex::TanhOp,
139  complex::CosOp, complex::SinOp, complex::ConjOp,
140  complex::LogOp, complex::AbsOp, complex::AngleOp>();
141  if (failed(applyPartialConversion(module, target, std::move(patterns))))
142  signalPassFailure();
143 }
144 
145 std::unique_ptr<OperationPass<ModuleOp>>
147  return std::make_unique<ConvertComplexToLibmPass>();
148 }
Location getUnknownLoc()
Definition: Builders.cpp:26
Include the generated interface declarations.
MLIRContext * getContext() const
Definition: Builders.h:54
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:600
LogicalResult applyPartialConversion(ArrayRef< Operation *> ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation *> *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
void addIllegalOp(OperationName op)
Register the given operation as illegal, i.e.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:418
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:32
std::unique_ptr< OperationPass< ModuleOp > > createConvertComplexToLibmPass()
Create a pass to convert Complex operations to libm calls.
void addLegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as legal.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:355
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:382
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:299
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
void populateComplexToLibmConversionPatterns(RewritePatternSet &patterns, PatternBenefit benefit)
Populate the given list with patterns that convert from Complex to Libm calls.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:92
This provides public APIs that all operations should have.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
This class describes a specific conversion target.
MLIRContext * getContext() const
U cast() const
Definition: Types.h:279