MLIR  22.0.0git
ArithToLLVM.cpp
Go to the documentation of this file.
1 //===- ArithToLLVM.cpp - Arithmetic to LLVM dialect conversion -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
19 #include "mlir/IR/TypeUtilities.h"
20 #include <type_traits>
21 
22 namespace mlir {
23 #define GEN_PASS_DEF_ARITHTOLLVMCONVERSIONPASS
24 #include "mlir/Conversion/Passes.h.inc"
25 } // namespace mlir
26 
27 using namespace mlir;
28 
29 namespace {
30 
31 /// Operations whose conversion will depend on whether they are passed a
32 /// rounding mode attribute or not.
33 ///
34 /// `SourceOp` is the source operation; `TargetOp`, the operation it will lower
35 /// to; `AttrConvert` is the attribute conversion to convert the rounding mode
36 /// attribute.
/// `Constrained` selects which variant this pattern handles: it only matches
/// ops for which "has a rounding-mode attribute" equals `Constrained`, and
/// otherwise fails so that the opposite instantiation (registered alongside
/// it) can match instead.
37 template <typename SourceOp, typename TargetOp, bool Constrained,
38  template <typename, typename> typename AttrConvert =
40 struct ConstrainedVectorConvertToLLVMPattern
41  : public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert> {
42  using VectorConvertToLLVMPattern<SourceOp, TargetOp,
43  AttrConvert>::VectorConvertToLLVMPattern;
44 
45  LogicalResult
46  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
47  ConversionPatternRewriter &rewriter) const override {
  // Bail out unless the presence of a rounding mode matches this variant.
48  if (Constrained != static_cast<bool>(op.getRoundingModeAttr()))
49  return failure();
  // Otherwise defer to the generic (vector-aware) 1:1 lowering.
50  return VectorConvertToLLVMPattern<SourceOp, TargetOp,
51  AttrConvert>::matchAndRewrite(op, adaptor,
52  rewriter);
53  }
54 };
55 
56 /// No-op bitcast: when the converted source type equals the converted result
57 /// type, the op is replaced by its converted input value.
/// Registered with a higher benefit than BitcastOpLowering (see
/// populateArithToLLVMConversionPatterns) so that it is tried first.
58 struct IdentityBitcastLowering final
59  : public OpConversionPattern<arith::BitcastOp> {
61 
62  LogicalResult
63  matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor,
64  ConversionPatternRewriter &rewriter) const final {
65  Value src = adaptor.getIn();
  // Compare against the *converted* result type: types that differ before
  // conversion may become identical once lowered.
66  Type resultType = getTypeConverter()->convertType(op.getType());
67  if (src.getType() != resultType)
68  return rewriter.notifyMatchFailure(op, "Types are different");
69 
70  rewriter.replaceOp(op, src);
71  return success();
72  }
73 };
74 
75 //===----------------------------------------------------------------------===//
76 // Straightforward Op Lowerings
77 //===----------------------------------------------------------------------===//
78 
// Each alias below instantiates a generic 1:1 lowering from an arith op to the
// LLVM op given as the second template argument. The optional third argument
// selects how arith attributes are carried over (fastmath flags for float ops,
// overflow flags for integer ops).
79 using AddFOpLowering =
80  VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
81  arith::AttrConvertFastMathToLLVM>;
82 using AddIOpLowering =
83  VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp,
84  arith::AttrConvertOverflowToLLVM>;
86 using BitcastOpLowering =
88 using DivFOpLowering =
89  VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
90  arith::AttrConvertFastMathToLLVM>;
91 using DivSIOpLowering =
93 using DivUIOpLowering =
96 using ExtSIOpLowering =
98 using ExtUIOpLowering =
100 using FPToSIOpLowering =
102 using FPToUIOpLowering =
104 using MaximumFOpLowering =
105  VectorConvertToLLVMPattern<arith::MaximumFOp, LLVM::MaximumOp,
106  arith::AttrConvertFastMathToLLVM>;
107 using MaxNumFOpLowering =
108  VectorConvertToLLVMPattern<arith::MaxNumFOp, LLVM::MaxNumOp,
109  arith::AttrConvertFastMathToLLVM>;
110 using MaxSIOpLowering =
112 using MaxUIOpLowering =
114 using MinimumFOpLowering =
115  VectorConvertToLLVMPattern<arith::MinimumFOp, LLVM::MinimumOp,
116  arith::AttrConvertFastMathToLLVM>;
117 using MinNumFOpLowering =
118  VectorConvertToLLVMPattern<arith::MinNumFOp, LLVM::MinNumOp,
119  arith::AttrConvertFastMathToLLVM>;
120 using MinSIOpLowering =
122 using MinUIOpLowering =
124 using MulFOpLowering =
125  VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
126  arith::AttrConvertFastMathToLLVM>;
127 using MulIOpLowering =
128  VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp,
129  arith::AttrConvertOverflowToLLVM>;
130 using NegFOpLowering =
131  VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
132  arith::AttrConvertFastMathToLLVM>;
134 using RemFOpLowering =
135  VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
136  arith::AttrConvertFastMathToLLVM>;
137 using RemSIOpLowering =
139 using RemUIOpLowering =
141 using SelectOpLowering =
143 using ShLIOpLowering =
144  VectorConvertToLLVMPattern<arith::ShLIOp, LLVM::ShlOp,
145  arith::AttrConvertOverflowToLLVM>;
146 using ShRSIOpLowering =
148 using ShRUIOpLowering =
150 using SIToFPOpLowering =
152 using SubFOpLowering =
153  VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
154  arith::AttrConvertFastMathToLLVM>;
155 using SubIOpLowering =
156  VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
157  arith::AttrConvertOverflowToLLVM>;
// arith.truncf has two lowerings: without a rounding-mode attribute it maps to
// LLVM::FPTruncOp; with one it maps to the constrained fptrunc intrinsic, with
// the rounding mode converted by AttrConverterConstrainedFPToLLVM.
158 using TruncFOpLowering =
159  ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp,
160  false>;
161 using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern<
162  arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, true,
163  arith::AttrConverterConstrainedFPToLLVM>;
164 using TruncIOpLowering =
165  VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp,
166  arith::AttrConvertOverflowToLLVM>;
167 using UIToFPOpLowering =
170 
171 //===----------------------------------------------------------------------===//
172 // Op Lowering Patterns
173 //===----------------------------------------------------------------------===//
174 
175 /// Directly lower `arith.constant` to `LLVM::ConstantOp`, forwarding the
/// source op's attributes unchanged.
176 struct ConstantOpLowering : public ConvertOpToLLVMPattern<arith::ConstantOp> {
178 
179  LogicalResult
180  matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
181  ConversionPatternRewriter &rewriter) const override;
182 };
183 
184 /// The lowering of index_cast becomes an integer conversion since index
185 /// becomes an integer. If the bit width of the source and target integer
186 /// types is the same, the converted operand is forwarded unchanged. If the
187 /// target type is wider, the value is extended with `ExtCastTy` (sign-extension
/// for `arith.index_cast`, zero-extension for `arith.index_castui`, per the
/// aliases below); otherwise it is truncated.
188 template <typename OpTy, typename ExtCastTy>
189 struct IndexCastOpLowering : public ConvertOpToLLVMPattern<OpTy> {
191 
192  LogicalResult
193  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
194  ConversionPatternRewriter &rewriter) const override;
195 };
196 
197 using IndexCastOpSILowering =
198  IndexCastOpLowering<arith::IndexCastOp, LLVM::SExtOp>;
199 using IndexCastOpUILowering =
200  IndexCastOpLowering<arith::IndexCastUIOp, LLVM::ZExtOp>;
201 
/// Lowers `arith.addui_extended` to `LLVM::UAddWithOverflowOp` and splits the
/// returned struct into the sum and overflow results. Only scalars and 1-D
/// vectors are handled; N-D vectors report a match failure.
202 struct AddUIExtendedOpLowering
203  : public ConvertOpToLLVMPattern<arith::AddUIExtendedOp> {
205 
206  LogicalResult
207  matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
208  ConversionPatternRewriter &rewriter) const override;
209 };
210 
/// Lowers the extended integer multiplications by multiplying in an integer
/// type of twice the bit width and splitting the product into low and high
/// halves. `IsSigned` selects sign-extension (true) or zero-extension (false)
/// of the operands.
211 template <typename ArithMulOp, bool IsSigned>
212 struct MulIExtendedOpLowering : public ConvertOpToLLVMPattern<ArithMulOp> {
214 
215  LogicalResult
216  matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
217  ConversionPatternRewriter &rewriter) const override;
218 };
219 
220 using MulSIExtendedOpLowering =
221  MulIExtendedOpLowering<arith::MulSIExtendedOp, true>;
222 using MulUIExtendedOpLowering =
223  MulIExtendedOpLowering<arith::MulUIExtendedOp, false>;
224 
/// Lowers `arith.cmpi` to `LLVM::ICmpOp`, reusing the predicate's numeric
/// value. Operands that convert to LLVM arrays (N-D vectors) are handled one
/// 1-D vector at a time.
225 struct CmpIOpLowering : public ConvertOpToLLVMPattern<arith::CmpIOp> {
227 
228  LogicalResult
229  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
230  ConversionPatternRewriter &rewriter) const override;
231 };
232 
/// Lowers `arith.cmpf` to `LLVM::FCmpOp`, reusing the predicate's numeric
/// value and carrying over the op's fastmath flags. Operands that convert to
/// LLVM arrays (N-D vectors) are handled one 1-D vector at a time.
233 struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> {
235 
236  LogicalResult
237  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
238  ConversionPatternRewriter &rewriter) const override;
239 };
240 
/// Splits an `arith.select` whose true/false values were converted 1:N into
/// one `arith.select` per converted value; 1:1 cases are left to the regular
/// select lowering.
241 struct SelectOpOneToNLowering : public ConvertOpToLLVMPattern<arith::SelectOp> {
243  using Adaptor =
245 
246  LogicalResult
247  matchAndRewrite(arith::SelectOp op, Adaptor adaptor,
248  ConversionPatternRewriter &rewriter) const override;
249 };
250 
251 } // namespace
252 
253 //===----------------------------------------------------------------------===//
254 // ConstantOpLowering
255 //===----------------------------------------------------------------------===//
256 
257 LogicalResult
258 ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
259  ConversionPatternRewriter &rewriter) const {
260  return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(),
261  adaptor.getOperands(), op->getAttrs(),
262  *getTypeConverter(), rewriter);
263 }
264 
265 //===----------------------------------------------------------------------===//
266 // IndexCastOpLowering
267 //===----------------------------------------------------------------------===//
268 
/// Lowers index_cast / index_castui by comparing the bit widths of the
/// converted source and target element types: equal widths forward the
/// converted operand, a narrower target truncates, a wider target extends
/// with `ExtCastTy`.
269 template <typename OpTy, typename ExtCastTy>
270 LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
271  OpTy op, typename OpTy::Adaptor adaptor,
272  ConversionPatternRewriter &rewriter) const {
273  Type resultType = op.getResult().getType();
  // NOTE(review): convertType can return a null Type when an element type is
  // not convertible; there is no null check before getIntOrFloatBitWidth().
274  Type targetElementType =
275  this->typeConverter->convertType(getElementTypeOrSelf(resultType));
276  Type sourceElementType =
277  this->typeConverter->convertType(getElementTypeOrSelf(op.getIn()));
278  unsigned targetBits = targetElementType.getIntOrFloatBitWidth();
279  unsigned sourceBits = sourceElementType.getIntOrFloatBitWidth();
280 
  // Same width after conversion: the cast is a no-op, forward the operand.
281  if (targetBits == sourceBits) {
282  rewriter.replaceOp(op, adaptor.getIn());
283  return success();
284  }
285 
286  // Handle the scalar and 1D vector cases.
287  Type operandType = adaptor.getIn().getType();
288  if (!isa<LLVM::LLVMArrayType>(operandType)) {
289  Type targetType = this->typeConverter->convertType(resultType);
290  if (targetBits < sourceBits)
291  rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType,
292  adaptor.getIn());
293  else
294  rewriter.replaceOpWithNewOp<ExtCastTy>(op, targetType, adaptor.getIn());
295  return success();
296  }
297 
  // The operand converted to an LLVM array, i.e. an N-D vector: apply the
  // same trunc/ext choice to each 1-D vector via the callback below.
298  if (!isa<VectorType>(resultType))
299  return rewriter.notifyMatchFailure(op, "expected vector result type");
300 
302  op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
303  [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
  // This local adaptor wraps the 1-D slices and shadows the outer one.
304  typename OpTy::Adaptor adaptor(operands);
305  if (targetBits < sourceBits) {
306  return LLVM::TruncOp::create(rewriter, op.getLoc(), llvm1DVectorTy,
307  adaptor.getIn());
308  }
309  return ExtCastTy::create(rewriter, op.getLoc(), llvm1DVectorTy,
310  adaptor.getIn());
311  },
312  rewriter);
313 }
314 
315 //===----------------------------------------------------------------------===//
316 // AddUIExtendedOpLowering
317 //===----------------------------------------------------------------------===//
318 
319 LogicalResult AddUIExtendedOpLowering::matchAndRewrite(
320  arith::AddUIExtendedOp op, OpAdaptor adaptor,
321  ConversionPatternRewriter &rewriter) const {
322  Type operandType = adaptor.getLhs().getType();
323  Type sumResultType = op.getSum().getType();
324  Type overflowResultType = op.getOverflow().getType();
325 
326  if (!LLVM::isCompatibleType(operandType))
327  return failure();
328 
329  MLIRContext *ctx = rewriter.getContext();
330  Location loc = op.getLoc();
331 
332  // Handle the scalar and 1D vector cases.
333  if (!isa<LLVM::LLVMArrayType>(operandType)) {
334  Type newOverflowType = typeConverter->convertType(overflowResultType);
335  Type structType =
336  LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newOverflowType});
337  Value addOverflow = LLVM::UAddWithOverflowOp::create(
338  rewriter, loc, structType, adaptor.getLhs(), adaptor.getRhs());
339  Value sumExtracted =
340  LLVM::ExtractValueOp::create(rewriter, loc, addOverflow, 0);
341  Value overflowExtracted =
342  LLVM::ExtractValueOp::create(rewriter, loc, addOverflow, 1);
343  rewriter.replaceOp(op, {sumExtracted, overflowExtracted});
344  return success();
345  }
346 
347  if (!isa<VectorType>(sumResultType))
348  return rewriter.notifyMatchFailure(loc, "expected vector result types");
349 
350  return rewriter.notifyMatchFailure(loc,
351  "ND vector types are not supported yet");
352 }
353 
354 //===----------------------------------------------------------------------===//
355 // MulIExtendedOpLowering
356 //===----------------------------------------------------------------------===//
357 
358 template <typename ArithMulOp, bool IsSigned>
359 LogicalResult MulIExtendedOpLowering<ArithMulOp, IsSigned>::matchAndRewrite(
360  ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
361  ConversionPatternRewriter &rewriter) const {
362  Type resultType = adaptor.getLhs().getType();
363 
364  if (!LLVM::isCompatibleType(resultType))
365  return failure();
366 
367  Location loc = op.getLoc();
368 
369  // Handle the scalar and 1D vector cases. Because LLVM does not have a
370  // matching extended multiplication intrinsic, perform regular multiplication
371  // on operands zero-extended to i(2*N) bits, and truncate the results back to
372  // iN types.
373  if (!isa<LLVM::LLVMArrayType>(resultType)) {
374  // Shift amount necessary to extract the high bits from widened result.
375  TypedAttr shiftValAttr;
376 
377  if (auto intTy = dyn_cast<IntegerType>(resultType)) {
378  unsigned resultBitwidth = intTy.getWidth();
379  auto attrTy = rewriter.getIntegerType(resultBitwidth * 2);
380  shiftValAttr = rewriter.getIntegerAttr(attrTy, resultBitwidth);
381  } else {
382  auto vecTy = cast<VectorType>(resultType);
383  unsigned resultBitwidth = vecTy.getElementTypeBitWidth();
384  auto attrTy = VectorType::get(
385  vecTy.getShape(), rewriter.getIntegerType(resultBitwidth * 2));
386  shiftValAttr = SplatElementsAttr::get(
387  attrTy, APInt(resultBitwidth * 2, resultBitwidth));
388  }
389  Type wideType = shiftValAttr.getType();
390  assert(LLVM::isCompatibleType(wideType) &&
391  "LLVM dialect should support all signless integer types");
392 
393  using LLVMExtOp = std::conditional_t<IsSigned, LLVM::SExtOp, LLVM::ZExtOp>;
394  Value lhsExt = LLVMExtOp::create(rewriter, loc, wideType, adaptor.getLhs());
395  Value rhsExt = LLVMExtOp::create(rewriter, loc, wideType, adaptor.getRhs());
396  Value mulExt = LLVM::MulOp::create(rewriter, loc, wideType, lhsExt, rhsExt);
397 
398  // Split the 2*N-bit wide result into two N-bit values.
399  Value low = LLVM::TruncOp::create(rewriter, loc, resultType, mulExt);
400  Value shiftVal = LLVM::ConstantOp::create(rewriter, loc, shiftValAttr);
401  Value highExt = LLVM::LShrOp::create(rewriter, loc, mulExt, shiftVal);
402  Value high = LLVM::TruncOp::create(rewriter, loc, resultType, highExt);
403 
404  rewriter.replaceOp(op, {low, high});
405  return success();
406  }
407 
408  if (!isa<VectorType>(resultType))
409  return rewriter.notifyMatchFailure(op, "expected vector result type");
410 
411  return rewriter.notifyMatchFailure(op,
412  "ND vector types are not supported yet");
413 }
414 
415 //===----------------------------------------------------------------------===//
416 // CmpIOpLowering
417 //===----------------------------------------------------------------------===//
418 
// Map an arith comparison predicate onto the LLVM dialect predicate of the
// same meaning. Both enumerations assign identical numeric values to matching
// predicates, so the mapping simply reinterprets the underlying integer.
template <typename LLVMPredType, typename PredType>
static LLVMPredType convertCmpPredicate(PredType pred) {
  const auto rawValue = static_cast<std::underlying_type_t<PredType>>(pred);
  return static_cast<LLVMPredType>(rawValue);
}
425 
/// Lowers `arith.cmpi` to `LLVM::ICmpOp`. Scalars and 1-D vectors become a
/// single compare on the converted operands. Operands converted to LLVM arrays
/// (N-D vectors) are handled by emitting one compare per 1-D vector through
/// the callback below.
426 LogicalResult
427 CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
428  ConversionPatternRewriter &rewriter) const {
429  Type operandType = adaptor.getLhs().getType();
430  Type resultType = op.getResult().getType();
431 
432  // Handle the scalar and 1D vector cases.
433  if (!isa<LLVM::LLVMArrayType>(operandType)) {
434  rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
435  op, typeConverter->convertType(resultType),
436  convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
437  adaptor.getLhs(), adaptor.getRhs());
438  return success();
439  }
440 
441  if (!isa<VectorType>(resultType))
442  return rewriter.notifyMatchFailure(op, "expected vector result type");
443 
445  op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
446  [&](Type llvm1DVectorTy, ValueRange operands) {
  // This local adaptor wraps the 1-D operand slices and shadows the outer
  // one.
447  OpAdaptor adaptor(operands);
448  return LLVM::ICmpOp::create(
449  rewriter, op.getLoc(), llvm1DVectorTy,
450  convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
451  adaptor.getLhs(), adaptor.getRhs());
452  },
453  rewriter);
454 }
455 
456 //===----------------------------------------------------------------------===//
457 // CmpFOpLowering
458 //===----------------------------------------------------------------------===//
459 
/// Lowers `arith.cmpf` to `LLVM::FCmpOp`, translating the op's fastmath flags
/// once and attaching them to every compare it creates. Scalars and 1-D
/// vectors become a single compare. Operands converted to LLVM arrays (N-D
/// vectors) are handled by emitting one compare per 1-D vector through the
/// callback below.
460 LogicalResult
461 CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
462  ConversionPatternRewriter &rewriter) const {
463  Type operandType = adaptor.getLhs().getType();
464  Type resultType = op.getResult().getType();
465  LLVM::FastmathFlags fmf =
466  arith::convertArithFastMathFlagsToLLVM(op.getFastmath());
467 
468  // Handle the scalar and 1D vector cases.
469  if (!isa<LLVM::LLVMArrayType>(operandType)) {
470  rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
471  op, typeConverter->convertType(resultType),
472  convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
473  adaptor.getLhs(), adaptor.getRhs(), fmf);
474  return success();
475  }
476 
477  if (!isa<VectorType>(resultType))
478  return rewriter.notifyMatchFailure(op, "expected vector result type");
479 
481  op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
482  [&](Type llvm1DVectorTy, ValueRange operands) {
  // This local adaptor wraps the 1-D operand slices and shadows the outer
  // one.
483  OpAdaptor adaptor(operands);
484  return LLVM::FCmpOp::create(
485  rewriter, op.getLoc(), llvm1DVectorTy,
486  convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
487  adaptor.getLhs(), adaptor.getRhs(), fmf);
488  },
489  rewriter);
490 }
491 
492 //===----------------------------------------------------------------------===//
493 // SelectOpOneToNLowering
494 //===----------------------------------------------------------------------===//
495 
496 /// Pattern for arith.select where the true/false values lower to multiple
497 /// SSA values (1:N conversion). This pattern generates multiple arith.select
498 /// than can be lowered by the 1:1 arith.select pattern.
499 LogicalResult SelectOpOneToNLowering::matchAndRewrite(
500  arith::SelectOp op, Adaptor adaptor,
501  ConversionPatternRewriter &rewriter) const {
502  // In case of a 1:1 conversion, the 1:1 pattern will match.
503  if (llvm::hasSingleElement(adaptor.getTrueValue()))
504  return rewriter.notifyMatchFailure(
505  op, "not a 1:N conversion, 1:1 pattern will match");
506  if (!op.getCondition().getType().isInteger(1))
507  return rewriter.notifyMatchFailure(op,
508  "non-i1 conditions are not supported");
509  SmallVector<Value> results;
510  for (auto [trueValue, falseValue] :
511  llvm::zip_equal(adaptor.getTrueValue(), adaptor.getFalseValue()))
512  results.push_back(arith::SelectOp::create(
513  rewriter, op.getLoc(), op.getCondition(), trueValue, falseValue));
514  rewriter.replaceOpWithMultiple(op, {results});
515  return success();
516 }
517 
518 //===----------------------------------------------------------------------===//
519 // Pass Definition
520 //===----------------------------------------------------------------------===//
521 
522 namespace {
/// Pass that lowers arith ops to the LLVM dialect with a partial conversion.
/// The index bitwidth comes from the pass option when it is set explicitly.
523 struct ArithToLLVMConversionPass
524  : public impl::ArithToLLVMConversionPassBase<ArithToLLVMConversionPass> {
525  using Base::Base;
526 
527  void runOnOperation() override {
530 
  // kDeriveIndexBitwidthFromDataLayout is the "not set" sentinel: only
  // override the converter's index width for an explicit option value.
532  if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
533  options.overrideIndexBitwidth(indexBitwidth);
534 
535  LLVMTypeConverter converter(&getContext(), options);
538 
  // Partial conversion: ops the target does not mark illegal may remain;
  // the pass fails only if an illegal op cannot be converted.
539  if (failed(applyPartialConversion(getOperation(), target,
540  std::move(patterns))))
541  signalPassFailure();
542  }
543 };
544 } // namespace
545 
546 //===----------------------------------------------------------------------===//
547 // ConvertToLLVMPatternInterface implementation
548 //===----------------------------------------------------------------------===//
549 
550 namespace {
551 /// Implement the interface to convert Arith to LLVM.
552 struct ArithToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
  // The conversion patterns create LLVM dialect ops, so that dialect must be
  // loaded before conversion.
554  void loadDependentDialects(MLIRContext *context) const final {
555  context->loadDialect<LLVM::LLVMDialect>();
556  }
557 
558  /// Hook for derived dialect interface to provide conversion patterns
559  /// and mark dialect legal for the conversion target.
561  ConversionTarget &target, LLVMTypeConverter &typeConverter,
562  RewritePatternSet &patterns) const final {
565  }
566 };
567 } // namespace
568 
570  DialectRegistry &registry) {
  // Register an extension that attaches ArithToLLVMDialectInterface to the
  // arith dialect, making it available to the generic convert-to-llvm flow.
571  registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) {
572  dialect->addInterfaces<ArithToLLVMDialectInterface>();
573  });
574 }
575 
576 //===----------------------------------------------------------------------===//
577 // Pattern Population
578 //===----------------------------------------------------------------------===//
579 
581  const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
582 
583  // Set a higher pattern benefit for IdentityBitcastLowering so it will run
584  // before BitcastOpLowering.
585  patterns.add<IdentityBitcastLowering>(converter, patterns.getContext(),
586  /*patternBenefit*/ 10);
587 
  // The remaining patterns use the default benefit. Some ops have two
  // registered patterns that are mutually exclusive by construction:
  // - TruncFOpLowering / ConstrainedTruncFOpLowering: each matches only when
  //   the presence of a rounding-mode attribute agrees with its variant.
  // - SelectOpLowering / SelectOpOneToNLowering: the 1:N pattern bails out
  //   when the true value converts to a single value.
588  // clang-format off
589  patterns.add<
590  AddFOpLowering,
591  AddIOpLowering,
592  AndIOpLowering,
593  AddUIExtendedOpLowering,
594  BitcastOpLowering,
595  ConstantOpLowering,
596  CmpFOpLowering,
597  CmpIOpLowering,
598  DivFOpLowering,
599  DivSIOpLowering,
600  DivUIOpLowering,
601  ExtFOpLowering,
602  ExtSIOpLowering,
603  ExtUIOpLowering,
604  FPToSIOpLowering,
605  FPToUIOpLowering,
606  IndexCastOpSILowering,
607  IndexCastOpUILowering,
608  MaximumFOpLowering,
609  MaxNumFOpLowering,
610  MaxSIOpLowering,
611  MaxUIOpLowering,
612  MinimumFOpLowering,
613  MinNumFOpLowering,
614  MinSIOpLowering,
615  MinUIOpLowering,
616  MulFOpLowering,
617  MulIOpLowering,
618  MulSIExtendedOpLowering,
619  MulUIExtendedOpLowering,
620  NegFOpLowering,
621  OrIOpLowering,
622  RemFOpLowering,
623  RemSIOpLowering,
624  RemUIOpLowering,
625  SelectOpLowering,
626  SelectOpOneToNLowering,
627  ShLIOpLowering,
628  ShRSIOpLowering,
629  ShRUIOpLowering,
630  SIToFPOpLowering,
631  SubFOpLowering,
632  SubIOpLowering,
633  TruncFOpLowering,
634  ConstrainedTruncFOpLowering,
635  TruncIOpLowering,
636  UIToFPOpLowering,
637  XOrIOpLowering
638  >(converter);
639  // clang-format on
640 }
static LLVMPredType convertCmpPredicate(PredType pred)
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static Value handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, ValueRange operands, int64_t vectorWidth, llvm::function_ref< Value(ValueRange)> compute)
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:227
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:66
MLIRContext * getContext() const
Definition: Builders.h:56
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void replaceOpWithMultiple(Operation *op, SmallVector< SmallVector< Value >> &&newValues)
Replace the given operation with the new value ranges.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:209
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.h:215
typename SourceOp::template GenericAdaptor< ArrayRef< ValueRange > > OneToNOpAdaptor
Definition: Pattern.h:213
Base class for dialect interfaces providing translation to LLVM IR.
virtual void populateConvertToLLVMConversionPatterns(ConversionTarget &target, LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const =0
Hook for derived dialect interface to provide conversion patterns and mark dialect legal for the conv...
virtual void loadDependentDialects(MLIRContext *context) const
Hook for derived dialect interface to load the dialects they target.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
Definition: VectorPattern.h:90
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, IntegerOverflowFlags overflowFlags=IntegerOverflowFlags::none)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
Definition: Pattern.cpp:307
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:809
void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ceil/floor division ops.
Definition: ExpandOps.cpp:796
void populateArithToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
void registerConvertArithToLLVMInterface(DialectRegistry &registry)
LLVM::FastmathFlags convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF)
Maps arithmetic fastmath enum values to LLVM enum values.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.