MLIR 23.0.0git
MathToLLVM.cpp
Go to the documentation of this file.
1//===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
19#include "mlir/Pass/Pass.h"
20
21#include "llvm/ADT/FloatingPointMode.h"
22
23namespace mlir {
24#define GEN_PASS_DEF_CONVERTMATHTOLLVMPASS
25#include "mlir/Conversion/Passes.h.inc"
26} // namespace mlir
27
28using namespace mlir;
29
30namespace {
31
/// `ConvertFastMath<SourceOp, TargetOp>` is the attribute converter used by the
/// patterns below. It is constructed from a source math op, and its
/// `getAttrs()` result is forwarded as the attribute list of the created LLVM
/// op. It is also passed as the `AttrConvert` argument of
/// `VectorConvertToLLVMPattern`.
/// NOTE(review): the defining line itself (doxygen line 33) is missing from
/// this listing; presumably it maps the fastmath flags of `SourceOp` onto
/// `TargetOp` — confirm against the source.
32template <typename SourceOp, typename TargetOp>
34
35template <typename SourceOp, typename TargetOp, bool FailOnUnsupportedFP = true>
36using ConvertFMFMathToLLVMPattern =
37 VectorConvertToLLVMPattern<SourceOp, TargetOp, ConvertFastMath,
38 FailOnUnsupportedFP>;
39
40/// Lowering pattern that matches only when the source op's rounding mode
41/// presence agrees with `HasRoundingMode`. Mirrors the helper of the same
42/// name in `mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp`. This lets us
43/// register two patterns for one math op: an unconstrained one that lowers
44/// to a regular LLVM op, and a constrained one (rounding mode present) that
45/// lowers to an `llvm.intr.experimental.constrained.*` intrinsic.
///
/// NOTE(review): the default value of `AttrConvert` (doxygen line 48) is
/// missing from this listing — confirm against the source.
46template <typename SourceOp, typename TargetOp, bool HasRoundingMode,
47 template <typename, typename> typename AttrConvert =
49 bool FailOnUnsupportedFP = true>
50struct ConstrainedVectorConvertToLLVMPattern
51 : public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert,
52 FailOnUnsupportedFP> {
 // Reuse the base pattern's constructors unchanged.
53 using VectorConvertToLLVMPattern<
54 SourceOp, TargetOp, AttrConvert,
55 FailOnUnsupportedFP>::VectorConvertToLLVMPattern;
56
57 LogicalResult
58 matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
59 ConversionPatternRewriter &rewriter) const override {
 // Decline ops whose rounding-mode presence does not match this
 // instantiation, so the sibling pattern with the opposite
 // `HasRoundingMode` handles them instead.
60 if (HasRoundingMode != static_cast<bool>(op.getRoundingModeAttr()))
61 return failure();
 // Otherwise defer to the generic one-result lowering of the base class.
62 return VectorConvertToLLVMPattern<
63 SourceOp, TargetOp, AttrConvert,
64 FailOnUnsupportedFP>::matchAndRewrite(op, adaptor, rewriter);
65 }
66};
67
// One-to-one lowerings: each alias below rewrites one math op directly into
// the named LLVM dialect op. Unless noted otherwise, the source op's
// attributes are translated with `ConvertFastMath`.
// NOTE(review): AbsF spells out `/*FailOnUnsupportedFP=*/true`, which is
// already the default of `ConvertFMFMathToLLVMPattern`.
68using AbsFOpLowering =
69 ConvertFMFMathToLLVMPattern<math::AbsFOp, LLVM::FAbsOp,
70 /*FailOnUnsupportedFP=*/true>;
71using CeilOpLowering = ConvertFMFMathToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
72using CopySignOpLowering =
73 ConvertFMFMathToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
74using CosOpLowering = ConvertFMFMathToLLVMPattern<math::CosOp, LLVM::CosOp>;
75using CoshOpLowering = ConvertFMFMathToLLVMPattern<math::CoshOp, LLVM::CoshOp>;
76using AcosOpLowering = ConvertFMFMathToLLVMPattern<math::AcosOp, LLVM::ACosOp>;
// ctpop does not go through `ConvertFMFMathToLLVMPattern`.
// NOTE(review): its attribute-converter argument (doxygen line 79) is missing
// from this listing — confirm against the source.
77using CtPopFOpLowering =
78 VectorConvertToLLVMPattern<math::CtPopOp, LLVM::CtPopOp,
80 /*FailOnUnsupportedFP=*/true>;
81using Exp2OpLowering = ConvertFMFMathToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
82using ExpOpLowering = ConvertFMFMathToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
83using FloorOpLowering =
84 ConvertFMFMathToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
// FMA is covered by two mutually exclusive patterns. `FmaOpLowering` applies
// when the op has no rounding-mode attribute and emits `LLVM::FMAOp`.
// `ConstrainedFmaOpLowering` applies when the attribute is present and emits
// `LLVM::ConstrainedFMAIntr`, using the arith constrained-FP attribute
// converter.
85using FmaOpLowering =
86 ConstrainedVectorConvertToLLVMPattern<math::FmaOp, LLVM::FMAOp,
87 /*HasRoundingMode=*/false,
88 ConvertFastMath,
89 /*FailOnUnsupportedFP=*/true>;
90using ConstrainedFmaOpLowering = ConstrainedVectorConvertToLLVMPattern<
91 math::FmaOp, LLVM::ConstrainedFMAIntr, /*HasRoundingMode=*/true,
92 arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
93using Log10OpLowering =
94 ConvertFMFMathToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
95using Log2OpLowering = ConvertFMFMathToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
96using LogOpLowering = ConvertFMFMathToLLVMPattern<math::LogOp, LLVM::LogOp>;
97using PowFOpLowering = ConvertFMFMathToLLVMPattern<math::PowFOp, LLVM::PowOp>;
98using FPowIOpLowering =
99 ConvertFMFMathToLLVMPattern<math::FPowIOp, LLVM::PowIOp>;
100using RoundEvenOpLowering =
101 ConvertFMFMathToLLVMPattern<math::RoundEvenOp, LLVM::RoundEvenOp>;
102using RoundOpLowering =
103 ConvertFMFMathToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
104using SinOpLowering = ConvertFMFMathToLLVMPattern<math::SinOp, LLVM::SinOp>;
105using SinhOpLowering = ConvertFMFMathToLLVMPattern<math::SinhOp, LLVM::SinhOp>;
106using ASinOpLowering = ConvertFMFMathToLLVMPattern<math::AsinOp, LLVM::ASinOp>;
107using SqrtOpLowering = ConvertFMFMathToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
// math.trunc maps onto the floating-point truncation op `LLVM::FTruncOp`.
108using FTruncOpLowering =
109 ConvertFMFMathToLLVMPattern<math::TruncOp, LLVM::FTruncOp>;
110using TanOpLowering = ConvertFMFMathToLLVMPattern<math::TanOp, LLVM::TanOp>;
111using TanhOpLowering = ConvertFMFMathToLLVMPattern<math::TanhOp, LLVM::TanhOp>;
112using ATanOpLowering = ConvertFMFMathToLLVMPattern<math::AtanOp, LLVM::ATanOp>;
113using ATan2OpLowering =
114 ConvertFMFMathToLLVMPattern<math::Atan2Op, LLVM::ATan2Op>;
115// A `CtLz/CtTz/absi(a)` is converted into `CtLz/CtTz/absi(a, false)`.
116// TODO: Result and operand types match for `absi` as opposed to `ct*z`, so it
117// may be better to separate the patterns.
// NOTE(review): the trailing `false` is presumably the LLVM op's "is poison"
// flag (zero input for ctlz/cttz, INT_MIN for abs) — confirm against the LLVM
// dialect op definitions.
118template <typename MathOp, typename LLVMOp>
119struct IntOpWithFlagLowering
120 : public ConvertOpToLLVMPattern<MathOp, /*FailOnUnsupportedFP=*/true> {
121 using ConvertOpToLLVMPattern<
122 MathOp, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;
 // NOTE(review): `Super` is not referenced anywhere in this listing.
123 using Super = IntOpWithFlagLowering<MathOp, LLVMOp>;
124
125 LogicalResult
126 matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor,
127 ConversionPatternRewriter &rewriter) const override {
 // Bail out (silently) if the operand or result type cannot be converted.
128 const auto &typeConverter = *this->getTypeConverter();
129 auto operandType = adaptor.getOperand().getType();
130 auto llvmOperandType = typeConverter.convertType(operandType);
131 if (!llvmOperandType)
132 return failure();
133
134 auto loc = op.getLoc();
135 auto resultType = op.getResult().getType();
136 auto llvmResultType = typeConverter.convertType(resultType);
137 if (!llvmResultType)
138 return failure();
139
 // A non-array LLVM operand type (scalar or, presumably, 1-D vector) is
 // lowered directly to a single LLVMOp with the flag operand set to false.
140 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
141 rewriter.replaceOpWithNewOp<LLVMOp>(op, llvmResultType,
142 adaptor.getOperand(), false);
143 return success();
144 }
145
 // The array path is only taken for vector results, presumably n-D vectors
 // that the type converter turned into arrays of 1-D vectors.
146 if (!isa<VectorType>(resultType))
147 return failure();
148
 // NOTE(review): the call line (doxygen line 149) is missing from this
 // listing. The arguments match `handleMultidimensionalVectors`, which
 // presumably invokes the lambda once per 1-D vector slice.
150 op.getOperation(), adaptor.getOperands(), typeConverter,
151 [&](Type llvm1DVectorTy, ValueRange operands) {
152 return LLVMOp::create(rewriter, loc, llvm1DVectorTy, operands[0],
153 false);
154 },
155 rewriter);
156 }
157};
158
// Instantiations for count-leading-zeros, count-trailing-zeros and integer abs.
159using CountLeadingZerosOpLowering =
160 IntOpWithFlagLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>;
161using CountTrailingZerosOpLowering =
162 IntOpWithFlagLowering<math::CountTrailingZerosOp,
163 LLVM::CountTrailingZerosOp>;
164using AbsIOpLowering = IntOpWithFlagLowering<math::AbsIOp, LLVM::AbsOp>;
165
166// A `sincos` is converted into `llvm.intr.sincos` followed by extractvalue ops.
// NOTE(review): unlike ExpM1/Log1p/Rsqrt, there is no LLVMArrayType
// (multi-dimensional vector) path here — confirm such operands are rejected or
// handled elsewhere. The using-declaration head (doxygen line 170) is also
// missing from this listing.
167struct SincosOpLowering
168 : public ConvertOpToLLVMPattern<math::SincosOp,
169 /*FailOnUnsupportedFP=*/true> {
171 math::SincosOp, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;
172
173 LogicalResult
174 matchAndRewrite(math::SincosOp op, OpAdaptor adaptor,
175 ConversionPatternRewriter &rewriter) const override {
176 const LLVMTypeConverter &typeConverter = *this->getTypeConverter();
177 mlir::Location loc = op.getLoc();
178 mlir::Type operandType = adaptor.getOperand().getType();
179 mlir::Type llvmOperandType = typeConverter.convertType(operandType);
 // sinType/cosType are only used to check convertibility; the struct below
 // uses the converted operand type for both fields.
180 mlir::Type sinType = typeConverter.convertType(op.getSin().getType());
181 mlir::Type cosType = typeConverter.convertType(op.getCos().getType());
182 if (!llvmOperandType || !sinType || !cosType)
183 return failure();
184
185 ConvertFastMath<math::SincosOp, LLVM::SincosOp> attrs(op);
186
 // The LLVM op returns both results packed in a literal {T, T} struct.
187 auto structType = LLVM::LLVMStructType::getLiteral(
188 rewriter.getContext(), {llvmOperandType, llvmOperandType});
189
190 auto sincosOp = LLVM::SincosOp::create(
191 rewriter, loc, structType, adaptor.getOperand(), attrs.getAttrs());
192
 // Field 0 is the sine and field 1 is the cosine, matching the replacement
 // order below.
193 auto sinValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 0);
194 auto cosValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 1);
195
196 rewriter.replaceOp(op, {sinValue, cosValue});
197 return success();
198 }
199};
200
201// A `expm1` is converted into `exp - 1`.
// NOTE(review): evaluating exp(x) - 1 directly loses relative precision for
// |x| close to 0 compared to a dedicated expm1 implementation.
202struct ExpM1OpLowering
203 : public ConvertOpToLLVMPattern<math::ExpM1Op,
204 /*FailOnUnsupportedFP=*/true> {
205 using ConvertOpToLLVMPattern<
206 math::ExpM1Op, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;
207
208 LogicalResult
209 matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
210 ConversionPatternRewriter &rewriter) const override {
211 const auto &typeConverter = *this->getTypeConverter();
212 auto operandType = adaptor.getOperand().getType();
213 auto llvmOperandType = typeConverter.convertType(operandType);
214 if (!llvmOperandType)
215 return failure();
216
 // Shared pieces for both paths: the constant 1.0 in the converted element
 // type, plus attribute converters for the emitted exp and fsub ops.
 // NOTE(review): `cast<FloatType>` asserts if the element type does not
 // convert to a float type.
217 auto loc = op.getLoc();
218 auto resultType = op.getResult().getType();
219 auto floatType = cast<FloatType>(
220 typeConverter.convertType(getElementTypeOrSelf(resultType)));
221 auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
222 ConvertFastMath<math::ExpM1Op, LLVM::ExpOp> expAttrs(op);
223 ConvertFastMath<math::ExpM1Op, LLVM::FSubOp> subAttrs(op);
224
 // Non-array operand: a compatible vector gets a splat 1.0 and anything else
 // gets a scalar 1.0. Then emit exp(x) - 1.
225 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
226 LLVM::ConstantOp one;
227 if (LLVM::isCompatibleVectorType(llvmOperandType)) {
228 one = LLVM::ConstantOp::create(
229 rewriter, loc, llvmOperandType,
230 SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
231 floatOne));
232 } else {
233 one =
234 LLVM::ConstantOp::create(rewriter, loc, llvmOperandType, floatOne);
235 }
236 auto exp = LLVM::ExpOp::create(rewriter, loc, adaptor.getOperand(),
237 expAttrs.getAttrs());
238 rewriter.replaceOpWithNewOp<LLVM::FSubOp>(
239 op, llvmOperandType, ValueRange{exp, one}, subAttrs.getAttrs());
240 return success();
241 }
242
 // Array-typed operands are only supported for vector results. Each 1-D
 // slice gets its own (possibly scalable) splat 1.0, exp and fsub.
243 if (!isa<VectorType>(resultType))
244 return rewriter.notifyMatchFailure(op, "expected vector result type");
245
 // NOTE(review): the call line (doxygen line 246) is missing from this
 // listing; the arguments match `handleMultidimensionalVectors`.
247 op.getOperation(), adaptor.getOperands(), typeConverter,
248 [&](Type llvm1DVectorTy, ValueRange operands) {
249 auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
250 auto splatAttr = SplatElementsAttr::get(
251 mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
252 {numElements.isScalable()}),
253 floatOne);
254 auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy,
255 splatAttr);
256 auto exp = LLVM::ExpOp::create(rewriter, loc, llvm1DVectorTy,
257 operands[0], expAttrs.getAttrs());
258 return LLVM::FSubOp::create(rewriter, loc, llvm1DVectorTy,
259 ValueRange{exp, one},
260 subAttrs.getAttrs());
261 },
262 rewriter);
263 }
264};
265
266// A `log1p` is converted into `log(1 + ...)`.
// This is an approximation: forming 1 + x first loses precision for |x| close
// to 0. That is why the populate function registers this pattern only when
// `approximateLog1p` is set.
267struct Log1pOpLowering
268 : public ConvertOpToLLVMPattern<math::Log1pOp,
269 /*FailOnUnsupportedFP=*/true> {
270 using ConvertOpToLLVMPattern<
271 math::Log1pOp, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;
272
273 LogicalResult
274 matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
275 ConversionPatternRewriter &rewriter) const override {
276 const auto &typeConverter = *this->getTypeConverter();
277 auto operandType = adaptor.getOperand().getType();
278 auto llvmOperandType = typeConverter.convertType(operandType);
279 if (!llvmOperandType)
280 return rewriter.notifyMatchFailure(op, "unsupported operand type");
281
 // Shared pieces: the constant 1.0 in the converted element type, plus
 // attribute converters for the emitted fadd and log ops.
 // NOTE(review): `cast<FloatType>` asserts if the element type does not
 // convert to a float type.
282 auto loc = op.getLoc();
283 auto resultType = op.getResult().getType();
284 auto floatType = cast<FloatType>(
285 typeConverter.convertType(getElementTypeOrSelf(resultType)));
286 auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
287 ConvertFastMath<math::Log1pOp, LLVM::FAddOp> addAttrs(op);
288 ConvertFastMath<math::Log1pOp, LLVM::LogOp> logAttrs(op);
289
 // Non-array operand: splat 1.0 for a builtin vector type, scalar 1.0
 // otherwise. Then emit log(1 + x).
 // NOTE(review): ExpM1OpLowering tests `LLVM::isCompatibleVectorType` here
 // instead of `isa<VectorType>`.
290 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
291 LLVM::ConstantOp one =
292 isa<VectorType>(llvmOperandType)
293 ? LLVM::ConstantOp::create(
294 rewriter, loc, llvmOperandType,
295 SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
296 floatOne))
297 : LLVM::ConstantOp::create(rewriter, loc, llvmOperandType,
298 floatOne);
299
300 auto add = LLVM::FAddOp::create(rewriter, loc, llvmOperandType,
301 ValueRange{one, adaptor.getOperand()},
302 addAttrs.getAttrs());
303 rewriter.replaceOpWithNewOp<LLVM::LogOp>(
304 op, llvmOperandType, ValueRange{add}, logAttrs.getAttrs());
305 return success();
306 }
307
 // Array-typed operands are only supported for vector results. Each 1-D
 // slice gets its own splat 1.0, fadd and log.
308 if (!isa<VectorType>(resultType))
309 return rewriter.notifyMatchFailure(op, "expected vector result type");
310
 // NOTE(review): the call line (doxygen line 311) is missing from this
 // listing; the arguments match `handleMultidimensionalVectors`.
312 op.getOperation(), adaptor.getOperands(), typeConverter,
313 [&](Type llvm1DVectorTy, ValueRange operands) {
314 auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
315 auto splatAttr = SplatElementsAttr::get(
316 mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
317 {numElements.isScalable()}),
318 floatOne);
319 auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy,
320 splatAttr);
321 auto add = LLVM::FAddOp::create(rewriter, loc, llvm1DVectorTy,
322 ValueRange{one, operands[0]},
323 addAttrs.getAttrs());
324 return LLVM::LogOp::create(rewriter, loc, llvm1DVectorTy,
325 ValueRange{add}, logAttrs.getAttrs());
326 },
327 rewriter);
328 }
329};
330
331// A `rsqrt` is converted into `1 / sqrt`.
332struct RsqrtOpLowering
333 : public ConvertOpToLLVMPattern<math::RsqrtOp,
334 /*FailOnUnsupportedFP=*/true> {
335 using ConvertOpToLLVMPattern<
336 math::RsqrtOp, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;
337
338 LogicalResult
339 matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
340 ConversionPatternRewriter &rewriter) const override {
341 const auto &typeConverter = *this->getTypeConverter();
342 auto operandType = adaptor.getOperand().getType();
343 auto llvmOperandType = typeConverter.convertType(operandType);
344 if (!llvmOperandType)
345 return failure();
346
 // Shared pieces: the constant 1.0 in the converted element type, plus
 // attribute converters for the emitted sqrt and fdiv ops.
 // NOTE(review): `cast<FloatType>` asserts if the element type does not
 // convert to a float type.
347 auto loc = op.getLoc();
348 auto resultType = op.getResult().getType();
349 auto floatType = cast<FloatType>(
350 typeConverter.convertType(getElementTypeOrSelf(resultType)));
351 auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
352 ConvertFastMath<math::RsqrtOp, LLVM::SqrtOp> sqrtAttrs(op);
353 ConvertFastMath<math::RsqrtOp, LLVM::FDivOp> divAttrs(op);
354
 // Non-array operand: splat 1.0 for a builtin vector type, scalar 1.0
 // otherwise. Then emit 1 / sqrt(x).
355 if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
356 LLVM::ConstantOp one;
357 if (isa<VectorType>(llvmOperandType)) {
358 one = LLVM::ConstantOp::create(
359 rewriter, loc, llvmOperandType,
360 SplatElementsAttr::get(cast<ShapedType>(llvmOperandType),
361 floatOne));
362 } else {
363 one =
364 LLVM::ConstantOp::create(rewriter, loc, llvmOperandType, floatOne);
365 }
366 auto sqrt = LLVM::SqrtOp::create(rewriter, loc, adaptor.getOperand(),
367 sqrtAttrs.getAttrs());
368 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(
369 op, llvmOperandType, ValueRange{one, sqrt}, divAttrs.getAttrs());
370 return success();
371 }
372
 // Array-typed operands are only supported for vector results. Each 1-D
 // slice gets its own splat 1.0, sqrt and fdiv.
373 if (!isa<VectorType>(resultType))
374 return failure();
375
 // NOTE(review): the call line (doxygen line 376) is missing from this
 // listing; the arguments match `handleMultidimensionalVectors`.
377 op.getOperation(), adaptor.getOperands(), typeConverter,
378 [&](Type llvm1DVectorTy, ValueRange operands) {
379 auto numElements = LLVM::getVectorNumElements(llvm1DVectorTy);
380 auto splatAttr = SplatElementsAttr::get(
381 mlir::VectorType::get({numElements.getKnownMinValue()}, floatType,
382 {numElements.isScalable()}),
383 floatOne);
384 auto one = LLVM::ConstantOp::create(rewriter, loc, llvm1DVectorTy,
385 splatAttr);
386 auto sqrt = LLVM::SqrtOp::create(rewriter, loc, llvm1DVectorTy,
387 operands[0], sqrtAttrs.getAttrs());
388 return LLVM::FDivOp::create(rewriter, loc, llvm1DVectorTy,
389 ValueRange{one, sqrt},
390 divAttrs.getAttrs());
391 },
392 rewriter);
393 }
394};
395
396struct IsNaNOpLowering
397 : public ConvertOpToLLVMPattern<math::IsNaNOp,
398 /*FailOnUnsupportedFP=*/true> {
399 using ConvertOpToLLVMPattern<
400 math::IsNaNOp, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;
401
402 LogicalResult
403 matchAndRewrite(math::IsNaNOp op, OpAdaptor adaptor,
404 ConversionPatternRewriter &rewriter) const override {
405 const auto &typeConverter = *this->getTypeConverter();
406 auto operandType =
407 typeConverter.convertType(adaptor.getOperand().getType());
408 auto resultType = typeConverter.convertType(op.getResult().getType());
409 if (!operandType || !resultType)
410 return failure();
411
412 rewriter.replaceOpWithNewOp<LLVM::IsFPClass>(
413 op, resultType, adaptor.getOperand(), llvm::fcNan);
414 return success();
415 }
416};
417
418struct IsFiniteOpLowering
419 : public ConvertOpToLLVMPattern<math::IsFiniteOp,
420 /*FailOnUnsupportedFP=*/true> {
421 using ConvertOpToLLVMPattern<
422 math::IsFiniteOp, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;
423
424 LogicalResult
425 matchAndRewrite(math::IsFiniteOp op, OpAdaptor adaptor,
426 ConversionPatternRewriter &rewriter) const override {
427 const auto &typeConverter = *this->getTypeConverter();
428 auto operandType =
429 typeConverter.convertType(adaptor.getOperand().getType());
430 auto resultType = typeConverter.convertType(op.getResult().getType());
431 if (!operandType || !resultType)
432 return failure();
433
434 rewriter.replaceOpWithNewOp<LLVM::IsFPClass>(
435 op, resultType, adaptor.getOperand(), llvm::fcFinite);
436 return success();
437 }
438};
439
440struct ConvertMathToLLVMPass
441 : public impl::ConvertMathToLLVMPassBase<ConvertMathToLLVMPass> {
442 using Base::Base;
443
444 void runOnOperation() override {
445 RewritePatternSet patterns(&getContext());
446 LLVMTypeConverter converter(&getContext());
447 populateMathToLLVMConversionPatterns(converter, patterns, approximateLog1p);
448 LLVMConversionTarget target(getContext());
449 if (failed(applyPartialConversion(getOperation(), target,
450 std::move(patterns))))
451 signalPassFailure();
452 }
453};
454} // namespace
455
457 const LLVMTypeConverter &converter, RewritePatternSet &patterns,
458 bool approximateLog1p, PatternBenefit benefit) {
 // log1p -> log(1 + x) is an approximation, so it is registered only on
 // request. Otherwise this function adds no pattern for math.log1p.
459 if (approximateLog1p)
460 patterns.add<Log1pOpLowering>(converter, benefit);
 // Every remaining pattern is registered unconditionally with the same
 // benefit. The two FMA patterns are mutually exclusive on the presence of a
 // rounding-mode attribute.
 // NOTE(review): the signature line (doxygen line 456) and one list entry
 // (doxygen line 489, presumably `SincosOpLowering,`) are missing from this
 // listing.
461 // clang-format off
462 patterns.add<
463 IsNaNOpLowering,
464 IsFiniteOpLowering,
465 AbsFOpLowering,
466 AbsIOpLowering,
467 CeilOpLowering,
468 CopySignOpLowering,
469 CosOpLowering,
470 CoshOpLowering,
471 AcosOpLowering,
472 CountLeadingZerosOpLowering,
473 CountTrailingZerosOpLowering,
474 CtPopFOpLowering,
475 Exp2OpLowering,
476 ExpM1OpLowering,
477 ExpOpLowering,
478 FPowIOpLowering,
479 FloorOpLowering,
480 FmaOpLowering,
481 ConstrainedFmaOpLowering,
482 Log10OpLowering,
483 Log2OpLowering,
484 LogOpLowering,
485 PowFOpLowering,
486 RoundEvenOpLowering,
487 RoundOpLowering,
488 RsqrtOpLowering,
490 SinOpLowering,
491 SinhOpLowering,
492 ASinOpLowering,
493 SqrtOpLowering,
494 FTruncOpLowering,
495 TanOpLowering,
496 TanhOpLowering,
497 ATanOpLowering,
498 ATan2OpLowering
499 >(converter, benefit);
500 // clang-format on
501}
502
503//===----------------------------------------------------------------------===//
504// ConvertToLLVMPatternInterface implementation
505//===----------------------------------------------------------------------===//
506
507namespace {
508/// Implement the interface to convert Math to LLVM.
509struct MathToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
510 MathToLLVMDialectInterface(Dialect *dialect)
511 : ConvertToLLVMPatternInterface(dialect) {}
512
513 void loadDependentDialects(MLIRContext *context) const final {
514 context->loadDialect<LLVM::LLVMDialect>();
515 }
516
517 /// Hook for derived dialect interface to provide conversion patterns
518 /// and mark dialect legal for the conversion target.
519 void populateConvertToLLVMConversionPatterns(
520 ConversionTarget &target, LLVMTypeConverter &typeConverter,
521 RewritePatternSet &patterns) const final {
522 populateMathToLLVMConversionPatterns(typeConverter, patterns);
523 }
524};
525} // namespace
526
 // Register an extension that attaches MathToLLVMDialectInterface to the Math
 // dialect instance handed to the callback (presumably when that dialect is
 // loaded into a context).
 // NOTE(review): the signature line (doxygen line 527) is missing from this
 // listing; the tooltip gives it as
 // `void registerConvertMathToLLVMInterface(DialectRegistry &registry)`.
528 registry.addExtension(+[](MLIRContext *ctx, math::MathDialect *dialect) {
529 dialect->addInterfaces<MathToLLVMDialectInterface>();
530 });
531}
return success()
b getContext())
#define add(a, b)
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition Pattern.h:227
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:233
const LLVMTypeConverter * getTypeConverter() const
Definition Pattern.cpp:29
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition Dialect.h:38
Conversion from types to the LLVM IR dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
void populateMathToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool approximateLog1p=true, PatternBenefit benefit=1)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void registerConvertMathToLLVMInterface(DialectRegistry &registry)
LogicalResult matchAndRewrite(math::SincosOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override