MLIR  22.0.0git
ComplexToROCDLLibraryCalls.cpp
Go to the documentation of this file.
1 //=== ComplexToROCDLLibraryCalls.cpp - convert from Complex to ROCDL calls ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 #include "mlir/IR/PatternMatch.h"
14 
15 namespace mlir {
16 #define GEN_PASS_DEF_CONVERTCOMPLEXTOROCDLLIBRARYCALLS
17 #include "mlir/Conversion/Passes.h.inc"
18 } // namespace mlir
19 
20 using namespace mlir;
21 
22 namespace {
23 
24 template <typename Op, typename FloatTy>
25 // Pattern to convert Complex ops to ROCDL function calls.
26 struct ComplexOpToROCDLLibraryCalls : public OpRewritePattern<Op> {
28  ComplexOpToROCDLLibraryCalls(MLIRContext *context, StringRef funcName,
29  PatternBenefit benefit = 1)
30  : OpRewritePattern<Op>(context, benefit), funcName(funcName) {}
31 
32  LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final {
34  Type resType = op.getType();
35  if (auto complexType = dyn_cast<ComplexType>(resType))
36  resType = complexType.getElementType();
37  if (!isa<FloatTy>(resType))
38  return failure();
39 
40  auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
41  SymbolTable::lookupSymbolIn(symTable, funcName));
42  if (!opFunc) {
43  OpBuilder::InsertionGuard guard(rewriter);
44  rewriter.setInsertionPointToStart(&symTable->getRegion(0).front());
45  auto funcTy = FunctionType::get(
46  rewriter.getContext(), op->getOperandTypes(), op->getResultTypes());
47  opFunc = func::FuncOp::create(rewriter, rewriter.getUnknownLoc(),
48  funcName, funcTy);
49  opFunc.setPrivate();
50  }
51  rewriter.replaceOpWithNewOp<func::CallOp>(op, funcName, op.getType(),
52  op->getOperands());
53  return success();
54  }
55 
56 private:
57  std::string funcName;
58 };
59 
60 // Rewrite complex.pow(z, w) -> complex.exp(w * complex.log(z))
61 struct PowOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowOp> {
63 
64  LogicalResult matchAndRewrite(complex::PowOp op,
65  PatternRewriter &rewriter) const final {
66  Location loc = op.getLoc();
67  Value logBase = complex::LogOp::create(rewriter, loc, op.getLhs());
68  Value mul = complex::MulOp::create(rewriter, loc, op.getRhs(), logBase);
69  Value exp = complex::ExpOp::create(rewriter, loc, mul);
70  rewriter.replaceOp(op, exp);
71  return success();
72  }
73 };
74 } // namespace
75 
78  patterns.add<PowOpToROCDLLibraryCalls>(patterns.getContext());
79  patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float32Type>>(
80  patterns.getContext(), "__ocml_cabs_f32");
81  patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float64Type>>(
82  patterns.getContext(), "__ocml_cabs_f64");
83  patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float32Type>>(
84  patterns.getContext(), "__ocml_ccos_f32");
85  patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float64Type>>(
86  patterns.getContext(), "__ocml_ccos_f64");
87  patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float32Type>>(
88  patterns.getContext(), "__ocml_cexp_f32");
89  patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float64Type>>(
90  patterns.getContext(), "__ocml_cexp_f64");
91  patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float32Type>>(
92  patterns.getContext(), "__ocml_clog_f32");
93  patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float64Type>>(
94  patterns.getContext(), "__ocml_clog_f64");
95  patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float32Type>>(
96  patterns.getContext(), "__ocml_csin_f32");
97  patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float64Type>>(
98  patterns.getContext(), "__ocml_csin_f64");
99  patterns.add<ComplexOpToROCDLLibraryCalls<complex::SqrtOp, Float32Type>>(
100  patterns.getContext(), "__ocml_csqrt_f32");
101  patterns.add<ComplexOpToROCDLLibraryCalls<complex::SqrtOp, Float64Type>>(
102  patterns.getContext(), "__ocml_csqrt_f64");
103  patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanOp, Float32Type>>(
104  patterns.getContext(), "__ocml_ctan_f32");
105  patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanOp, Float64Type>>(
106  patterns.getContext(), "__ocml_ctan_f64");
107  patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanhOp, Float32Type>>(
108  patterns.getContext(), "__ocml_ctanh_f32");
109  patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanhOp, Float64Type>>(
110  patterns.getContext(), "__ocml_ctanh_f64");
111 }
112 
113 namespace {
114 struct ConvertComplexToROCDLLibraryCallsPass
115  : public impl::ConvertComplexToROCDLLibraryCallsBase<
116  ConvertComplexToROCDLLibraryCallsPass> {
117  void runOnOperation() override;
118 };
119 } // namespace
120 
121 void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() {
122  Operation *op = getOperation();
123 
126 
127  ConversionTarget target(getContext());
128  target.addLegalDialect<func::FuncDialect>();
129  target.addLegalOp<complex::MulOp>();
130  target.addIllegalOp<complex::AbsOp, complex::CosOp, complex::ExpOp,
131  complex::LogOp, complex::PowOp, complex::SinOp,
132  complex::SqrtOp, complex::TanOp, complex::TanhOp>();
133  if (failed(applyPartialConversion(op, target, std::move(patterns))))
134  signalPassFailure();
135 }
static MLIRContext * getContext(OpFoldResult val)
MLIRContext * getContext() const
Definition: Builders.h:56
Location getUnknownLoc()
Definition: Builders.cpp:24
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:348
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:431
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:686
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:783
Block & front()
Definition: Region.h:65
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (replace op with new op).
Definition: PatternMatch.h:519
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name within the regions of 'symbolTableOp'.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
void populateComplexToROCDLLibraryCallsConversionPatterns(RewritePatternSet &patterns)
Populate the given list with patterns that convert from Complex to ROCDL calls.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:314