MLIR  22.0.0git
ArithToLLVM.cpp
Go to the documentation of this file.
1 //===- ArithToLLVM.cpp - Arithmetic to LLVM dialect conversion -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
19 #include "mlir/IR/TypeUtilities.h"
20 #include <type_traits>
21 
22 namespace mlir {
23 #define GEN_PASS_DEF_ARITHTOLLVMCONVERSIONPASS
24 #include "mlir/Conversion/Passes.h.inc"
25 } // namespace mlir
26 
27 using namespace mlir;
28 
29 namespace {
30 
/// Vector-aware lowering for operations whose conversion depends on whether
/// they carry a rounding mode attribute.
///
/// `SourceOp` is the source operation; `TargetOp` is the operation it lowers
/// to; `Constrained` selects which variant this pattern handles (true: only
/// ops that have a rounding mode attribute match; false: only ops without one
/// match); `AttrConvert` is the attribute converter applied when building
/// `TargetOp` (e.g. to convert the rounding mode attribute).
///
/// NOTE(review): the default argument for `AttrConvert` (the line following
/// `AttrConvert =`) is not visible in this excerpt — restore it from the full
/// file.
template <typename SourceOp, typename TargetOp, bool Constrained,
          template <typename, typename> typename AttrConvert =
struct ConstrainedVectorConvertToLLVMPattern
    : public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert> {
  using VectorConvertToLLVMPattern<SourceOp, TargetOp,
                                   AttrConvert>::VectorConvertToLLVMPattern;

  LogicalResult
  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only match when the presence of a rounding mode agrees with
    // `Constrained`, so the constrained and unconstrained instantiations for
    // the same SourceOp never both apply to one op.
    if (Constrained != static_cast<bool>(op.getRoundingModeAttr()))
      return failure();
    // Delegate the actual rewrite to the generic vector-aware pattern.
    return VectorConvertToLLVMPattern<SourceOp, TargetOp,
                                      AttrConvert>::matchAndRewrite(op, adaptor,
                                                                    rewriter);
  }
};
55 
/// No-op bitcast lowering: when the converted source type and the converted
/// result type of an `arith.bitcast` are identical, the op is replaced by its
/// (already converted) input and no LLVM op is emitted. Otherwise the pattern
/// fails to match, leaving the op to the general BitcastOpLowering.
struct IdentityBitcastLowering final
    : public OpConversionPattern<arith::BitcastOp> {
  // NOTE(review): no constructor is visible in this excerpt, although the
  // pattern is constructed with (converter, context, benefit) in
  // populateArithToLLVMConversionPatterns — an inherited-constructor
  // declaration appears to be missing here; confirm.

  LogicalResult
  matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Value src = adaptor.getIn();
    // Compare against the *converted* result type: the adaptor operand is
    // already in converted form.
    Type resultType = getTypeConverter()->convertType(op.getType());
    if (src.getType() != resultType)
      return rewriter.notifyMatchFailure(op, "Types are different");

    rewriter.replaceOp(op, src);
    return success();
  }
};
74 
75 //===----------------------------------------------------------------------===//
76 // Straightforward Op Lowerings
77 //===----------------------------------------------------------------------===//
78 
// Each alias below instantiates VectorConvertToLLVMPattern (or the
// rounding-mode-aware ConstrainedVectorConvertToLLVMPattern) to map one arith
// op onto one LLVM dialect op. The optional third template argument selects
// the attribute converter; as their names indicate,
// arith::AttrConvertFastMathToLLVM is used on floating-point ops (fastmath
// flags) and arith::AttrConvertOverflowToLLVM on integer add/sub/mul/shl/trunc
// (overflow flags).
//
// NOTE(review): several aliases in this excerpt stop right after `=` (e.g.
// BitcastOpLowering, DivSIOpLowering, UIToFPOpLowering), and AndIOpLowering,
// ExtFOpLowering, OrIOpLowering and XOrIOpLowering are registered in
// populateArithToLLVMConversionPatterns without a visible definition. The
// missing lines must be restored from the full file.
using AddFOpLowering =
    VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
                               arith::AttrConvertFastMathToLLVM>;
using AddIOpLowering =
    VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp,
                               arith::AttrConvertOverflowToLLVM>;
using BitcastOpLowering =
using DivFOpLowering =
    VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
                               arith::AttrConvertFastMathToLLVM>;
using DivSIOpLowering =
using DivUIOpLowering =
using ExtSIOpLowering =
using ExtUIOpLowering =
using FPToSIOpLowering =
using FPToUIOpLowering =
using MaximumFOpLowering =
    VectorConvertToLLVMPattern<arith::MaximumFOp, LLVM::MaximumOp,
                               arith::AttrConvertFastMathToLLVM>;
using MaxNumFOpLowering =
    VectorConvertToLLVMPattern<arith::MaxNumFOp, LLVM::MaxNumOp,
                               arith::AttrConvertFastMathToLLVM>;
using MaxSIOpLowering =
using MaxUIOpLowering =
using MinimumFOpLowering =
    VectorConvertToLLVMPattern<arith::MinimumFOp, LLVM::MinimumOp,
                               arith::AttrConvertFastMathToLLVM>;
using MinNumFOpLowering =
    VectorConvertToLLVMPattern<arith::MinNumFOp, LLVM::MinNumOp,
                               arith::AttrConvertFastMathToLLVM>;
using MinSIOpLowering =
using MinUIOpLowering =
using MulFOpLowering =
    VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
                               arith::AttrConvertFastMathToLLVM>;
using MulIOpLowering =
    VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp,
                               arith::AttrConvertOverflowToLLVM>;
using NegFOpLowering =
    VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
                               arith::AttrConvertFastMathToLLVM>;
using RemFOpLowering =
    VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
                               arith::AttrConvertFastMathToLLVM>;
using RemSIOpLowering =
using RemUIOpLowering =
using SelectOpLowering =
using ShLIOpLowering =
    VectorConvertToLLVMPattern<arith::ShLIOp, LLVM::ShlOp,
                               arith::AttrConvertOverflowToLLVM>;
using ShRSIOpLowering =
using ShRUIOpLowering =
using SIToFPOpLowering =
using SubFOpLowering =
    VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
                               arith::AttrConvertFastMathToLLVM>;
using SubIOpLowering =
    VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
                               arith::AttrConvertOverflowToLLVM>;
// arith.truncf has two mutually exclusive lowerings, selected by whether the
// op carries a rounding mode attribute: without one it becomes
// LLVM::FPTruncOp, with one it becomes LLVM::ConstrainedFPTruncIntr.
using TruncFOpLowering =
    ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp,
                                          false>;
using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern<
    arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, true,
    arith::AttrConverterConstrainedFPToLLVM>;
using TruncIOpLowering =
    VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp,
                               arith::AttrConvertOverflowToLLVM>;
using UIToFPOpLowering =
170 
171 //===----------------------------------------------------------------------===//
172 // Op Lowering Patterns
173 //===----------------------------------------------------------------------===//
174 
/// Lowers `arith.constant` one-to-one to LLVM::ConstantOp, forwarding the
/// op's attributes unchanged (see the out-of-line matchAndRewrite).
struct ConstantOpLowering : public ConvertOpToLLVMPattern<arith::ConstantOp> {
  // NOTE(review): constructor declaration not visible in this excerpt; the
  // pattern is constructed from a type converter in
  // populateArithToLLVMConversionPatterns — confirm it is inherited.

  LogicalResult
  matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
183 
/// Lowers `arith.index_cast` / `arith.index_castui`. After type conversion the
/// `index` side is an ordinary integer, so the cast becomes an integer
/// conversion:
///   - equal bit widths: the op is replaced by its converted operand;
///   - narrower target: the value is truncated with LLVM::TruncOp;
///   - wider target: the value is extended with `ExtCastTy` (LLVM::SExtOp for
///     index_cast, LLVM::ZExtOp for index_castui).
template <typename OpTy, typename ExtCastTy>
struct IndexCastOpLowering : public ConvertOpToLLVMPattern<OpTy> {
  // NOTE(review): constructor declaration not visible in this excerpt.

  LogicalResult
  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
196 
197 using IndexCastOpSILowering =
198  IndexCastOpLowering<arith::IndexCastOp, LLVM::SExtOp>;
199 using IndexCastOpUILowering =
200  IndexCastOpLowering<arith::IndexCastUIOp, LLVM::ZExtOp>;
201 
/// Lowers `arith.addui_extended` to LLVM::UAddWithOverflowOp and splits the
/// intrinsic's struct result into the sum and overflow values. Only scalars
/// and 1-D vectors are supported.
struct AddUIExtendedOpLowering
    : public ConvertOpToLLVMPattern<arith::AddUIExtendedOp> {
  // NOTE(review): constructor declaration not visible in this excerpt.

  LogicalResult
  matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
210 
/// Lowers the extended integer multiplications (`arith.mulsi_extended`,
/// `arith.mului_extended`) by multiplying in an integer type twice as wide
/// and splitting the product into low and high halves. `IsSigned` selects
/// sign-extension (true) or zero-extension (false) of the operands.
template <typename ArithMulOp, bool IsSigned>
struct MulIExtendedOpLowering : public ConvertOpToLLVMPattern<ArithMulOp> {
  // NOTE(review): constructor declaration not visible in this excerpt.

  LogicalResult
  matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
219 
220 using MulSIExtendedOpLowering =
221  MulIExtendedOpLowering<arith::MulSIExtendedOp, true>;
222 using MulUIExtendedOpLowering =
223  MulIExtendedOpLowering<arith::MulUIExtendedOp, false>;
224 
/// Lowers `arith.cmpi` to LLVM::ICmpOp; the predicate is translated with
/// convertCmpPredicate.
struct CmpIOpLowering : public ConvertOpToLLVMPattern<arith::CmpIOp> {
  // NOTE(review): constructor declaration not visible in this excerpt.

  LogicalResult
  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
232 
/// Lowers `arith.cmpf` to LLVM::FCmpOp; the predicate is translated with
/// convertCmpPredicate and the op's fastmath flags are forwarded.
struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> {
  // NOTE(review): constructor declaration not visible in this excerpt.

  LogicalResult
  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
240 
241 } // namespace
242 
243 //===----------------------------------------------------------------------===//
244 // ConstantOpLowering
245 //===----------------------------------------------------------------------===//
246 
247 LogicalResult
248 ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
249  ConversionPatternRewriter &rewriter) const {
250  return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(),
251  adaptor.getOperands(), op->getAttrs(),
252  *getTypeConverter(), rewriter);
253 }
254 
255 //===----------------------------------------------------------------------===//
256 // IndexCastOpLowering
257 //===----------------------------------------------------------------------===//
258 
/// Rewrites an index cast as truncation, `ExtCastTy` extension, or a no-op,
/// based on the bit widths of the converted element types.
template <typename OpTy, typename ExtCastTy>
LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
    OpTy op, typename OpTy::Adaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Type resultType = op.getResult().getType();
  // Compare widths of the *converted* element types, so `index` is measured
  // as the integer type the converter maps it to.
  Type targetElementType =
      this->typeConverter->convertType(getElementTypeOrSelf(resultType));
  Type sourceElementType =
      this->typeConverter->convertType(getElementTypeOrSelf(op.getIn()));
  unsigned targetBits = targetElementType.getIntOrFloatBitWidth();
  unsigned sourceBits = sourceElementType.getIntOrFloatBitWidth();

  // Same width after conversion: the cast is a no-op.
  if (targetBits == sourceBits) {
    rewriter.replaceOp(op, adaptor.getIn());
    return success();
  }

  // Handle the scalar and 1D vector cases.
  Type operandType = adaptor.getIn().getType();
  if (!isa<LLVM::LLVMArrayType>(operandType)) {
    Type targetType = this->typeConverter->convertType(resultType);
    if (targetBits < sourceBits)
      rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType,
                                                 adaptor.getIn());
    else
      rewriter.replaceOpWithNewOp<ExtCastTy>(op, targetType, adaptor.getIn());
    return success();
  }

  // The converted operand is an LLVM array, i.e. the original was an N-D
  // vector; the cast is then emitted once per 1-D vector via the callback.
  if (!isa<VectorType>(resultType))
    return rewriter.notifyMatchFailure(op, "expected vector result type");

  // NOTE(review): the call that consumes the arguments below is missing from
  // this excerpt; the argument list matches
  // LLVM::detail::handleMultidimensionalVectors(op, operands, typeConverter,
  // createOperand, rewriter) — confirm against the full file.
      op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
      [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
        // Re-wrap this 1-D slice's operands in an adaptor.
        typename OpTy::Adaptor adaptor(operands);
        if (targetBits < sourceBits) {
          return LLVM::TruncOp::create(rewriter, op.getLoc(), llvm1DVectorTy,
                                       adaptor.getIn());
        }
        return ExtCastTy::create(rewriter, op.getLoc(), llvm1DVectorTy,
                                 adaptor.getIn());
      },
      rewriter);
}
304 
305 //===----------------------------------------------------------------------===//
306 // AddUIExtendedOpLowering
307 //===----------------------------------------------------------------------===//
308 
309 LogicalResult AddUIExtendedOpLowering::matchAndRewrite(
310  arith::AddUIExtendedOp op, OpAdaptor adaptor,
311  ConversionPatternRewriter &rewriter) const {
312  Type operandType = adaptor.getLhs().getType();
313  Type sumResultType = op.getSum().getType();
314  Type overflowResultType = op.getOverflow().getType();
315 
316  if (!LLVM::isCompatibleType(operandType))
317  return failure();
318 
319  MLIRContext *ctx = rewriter.getContext();
320  Location loc = op.getLoc();
321 
322  // Handle the scalar and 1D vector cases.
323  if (!isa<LLVM::LLVMArrayType>(operandType)) {
324  Type newOverflowType = typeConverter->convertType(overflowResultType);
325  Type structType =
326  LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newOverflowType});
327  Value addOverflow = LLVM::UAddWithOverflowOp::create(
328  rewriter, loc, structType, adaptor.getLhs(), adaptor.getRhs());
329  Value sumExtracted =
330  LLVM::ExtractValueOp::create(rewriter, loc, addOverflow, 0);
331  Value overflowExtracted =
332  LLVM::ExtractValueOp::create(rewriter, loc, addOverflow, 1);
333  rewriter.replaceOp(op, {sumExtracted, overflowExtracted});
334  return success();
335  }
336 
337  if (!isa<VectorType>(sumResultType))
338  return rewriter.notifyMatchFailure(loc, "expected vector result types");
339 
340  return rewriter.notifyMatchFailure(loc,
341  "ND vector types are not supported yet");
342 }
343 
344 //===----------------------------------------------------------------------===//
345 // MulIExtendedOpLowering
346 //===----------------------------------------------------------------------===//
347 
348 template <typename ArithMulOp, bool IsSigned>
349 LogicalResult MulIExtendedOpLowering<ArithMulOp, IsSigned>::matchAndRewrite(
350  ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
351  ConversionPatternRewriter &rewriter) const {
352  Type resultType = adaptor.getLhs().getType();
353 
354  if (!LLVM::isCompatibleType(resultType))
355  return failure();
356 
357  Location loc = op.getLoc();
358 
359  // Handle the scalar and 1D vector cases. Because LLVM does not have a
360  // matching extended multiplication intrinsic, perform regular multiplication
361  // on operands zero-extended to i(2*N) bits, and truncate the results back to
362  // iN types.
363  if (!isa<LLVM::LLVMArrayType>(resultType)) {
364  // Shift amount necessary to extract the high bits from widened result.
365  TypedAttr shiftValAttr;
366 
367  if (auto intTy = dyn_cast<IntegerType>(resultType)) {
368  unsigned resultBitwidth = intTy.getWidth();
369  auto attrTy = rewriter.getIntegerType(resultBitwidth * 2);
370  shiftValAttr = rewriter.getIntegerAttr(attrTy, resultBitwidth);
371  } else {
372  auto vecTy = cast<VectorType>(resultType);
373  unsigned resultBitwidth = vecTy.getElementTypeBitWidth();
374  auto attrTy = VectorType::get(
375  vecTy.getShape(), rewriter.getIntegerType(resultBitwidth * 2));
376  shiftValAttr = SplatElementsAttr::get(
377  attrTy, APInt(resultBitwidth * 2, resultBitwidth));
378  }
379  Type wideType = shiftValAttr.getType();
380  assert(LLVM::isCompatibleType(wideType) &&
381  "LLVM dialect should support all signless integer types");
382 
383  using LLVMExtOp = std::conditional_t<IsSigned, LLVM::SExtOp, LLVM::ZExtOp>;
384  Value lhsExt = LLVMExtOp::create(rewriter, loc, wideType, adaptor.getLhs());
385  Value rhsExt = LLVMExtOp::create(rewriter, loc, wideType, adaptor.getRhs());
386  Value mulExt = LLVM::MulOp::create(rewriter, loc, wideType, lhsExt, rhsExt);
387 
388  // Split the 2*N-bit wide result into two N-bit values.
389  Value low = LLVM::TruncOp::create(rewriter, loc, resultType, mulExt);
390  Value shiftVal = LLVM::ConstantOp::create(rewriter, loc, shiftValAttr);
391  Value highExt = LLVM::LShrOp::create(rewriter, loc, mulExt, shiftVal);
392  Value high = LLVM::TruncOp::create(rewriter, loc, resultType, highExt);
393 
394  rewriter.replaceOp(op, {low, high});
395  return success();
396  }
397 
398  if (!isa<VectorType>(resultType))
399  return rewriter.notifyMatchFailure(op, "expected vector result type");
400 
401  return rewriter.notifyMatchFailure(op,
402  "ND vector types are not supported yet");
403 }
404 
405 //===----------------------------------------------------------------------===//
406 // CmpIOpLowering
407 //===----------------------------------------------------------------------===//
408 
// Map an arith comparison predicate onto the corresponding LLVM dialect
// predicate. The arith and LLVM predicate enums share numerical values, so
// the mapping is a plain value-preserving cast.
template <typename LLVMPredType, typename PredType>
static LLVMPredType convertCmpPredicate(PredType pred) {
  const LLVMPredType mapped = static_cast<LLVMPredType>(pred);
  return mapped;
}
415 
/// Lowers `arith.cmpi` to LLVM::ICmpOp. Scalars and 1-D vectors are rewritten
/// directly; when the converted operand is an LLVM array (an N-D vector), the
/// comparison is built per 1-D vector by the callback below.
LogicalResult
CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  Type operandType = adaptor.getLhs().getType();
  Type resultType = op.getResult().getType();

  // Handle the scalar and 1D vector cases.
  if (!isa<LLVM::LLVMArrayType>(operandType)) {
    rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
        op, typeConverter->convertType(resultType),
        convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
        adaptor.getLhs(), adaptor.getRhs());
    return success();
  }

  if (!isa<VectorType>(resultType))
    return rewriter.notifyMatchFailure(op, "expected vector result type");

  // NOTE(review): the call that consumes the arguments below is missing from
  // this excerpt; the argument list matches
  // LLVM::detail::handleMultidimensionalVectors — confirm against the full
  // file.
      op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
      [&](Type llvm1DVectorTy, ValueRange operands) {
        // Re-wrap this 1-D slice's operands in an adaptor.
        OpAdaptor adaptor(operands);
        return LLVM::ICmpOp::create(
            rewriter, op.getLoc(), llvm1DVectorTy,
            convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
            adaptor.getLhs(), adaptor.getRhs());
      },
      rewriter);
}
445 
446 //===----------------------------------------------------------------------===//
447 // CmpFOpLowering
448 //===----------------------------------------------------------------------===//
449 
/// Lowers `arith.cmpf` to LLVM::FCmpOp, carrying over the op's fastmath flags.
/// Scalars and 1-D vectors are rewritten directly; when the converted operand
/// is an LLVM array (an N-D vector), the comparison is built per 1-D vector by
/// the callback below.
LogicalResult
CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  Type operandType = adaptor.getLhs().getType();
  Type resultType = op.getResult().getType();
  // Computed once and reused by both the direct and the per-1-D-vector paths.
  LLVM::FastmathFlags fmf =
      arith::convertArithFastMathFlagsToLLVM(op.getFastmath());

  // Handle the scalar and 1D vector cases.
  if (!isa<LLVM::LLVMArrayType>(operandType)) {
    rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
        op, typeConverter->convertType(resultType),
        convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
        adaptor.getLhs(), adaptor.getRhs(), fmf);
    return success();
  }

  if (!isa<VectorType>(resultType))
    return rewriter.notifyMatchFailure(op, "expected vector result type");

  // NOTE(review): the call that consumes the arguments below is missing from
  // this excerpt; the argument list matches
  // LLVM::detail::handleMultidimensionalVectors — confirm against the full
  // file.
      op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
      [&](Type llvm1DVectorTy, ValueRange operands) {
        // Re-wrap this 1-D slice's operands in an adaptor.
        OpAdaptor adaptor(operands);
        return LLVM::FCmpOp::create(
            rewriter, op.getLoc(), llvm1DVectorTy,
            convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
            adaptor.getLhs(), adaptor.getRhs(), fmf);
      },
      rewriter);
}
481 
482 //===----------------------------------------------------------------------===//
483 // Pass Definition
484 //===----------------------------------------------------------------------===//
485 
namespace {
/// Pass that converts arith ops to the LLVM dialect via applyPartialConversion,
/// honoring the pass's `indexBitwidth` option.
struct ArithToLLVMConversionPass
    : public impl::ArithToLLVMConversionPassBase<ArithToLLVMConversionPass> {
  using Base::Base;

  void runOnOperation() override {
    // NOTE(review): the declarations of `target`, `options` and `patterns`
    // used below (and the call that populates `patterns`) are not visible in
    // this excerpt — restore them from the full file.

    // Only override the index width when one was requested explicitly;
    // otherwise it is derived from the data layout.
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    LLVMTypeConverter converter(&getContext(), options);

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace
509 
510 //===----------------------------------------------------------------------===//
511 // ConvertToLLVMPatternInterface implementation
512 //===----------------------------------------------------------------------===//
513 
namespace {
/// Implement the interface to convert Arith to LLVM. It is attached to the
/// arith dialect by registerConvertArithToLLVMInterface.
struct ArithToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
  /// The generated patterns create LLVM dialect ops, so ensure that dialect
  /// is loaded.
  void loadDependentDialects(MLIRContext *context) const final {
    context->loadDialect<LLVM::LLVMDialect>();
  }

  /// Hook for derived dialect interface to provide conversion patterns
  /// and mark dialect legal for the conversion target.
  // NOTE(review): the declaration line of this hook, the inherited
  // constructor, and the hook's body are not visible in this excerpt —
  // restore them from the full file.
      ConversionTarget &target, LLVMTypeConverter &typeConverter,
      RewritePatternSet &patterns) const final {
  }
};
} // namespace
532 
// NOTE(review): the function name line preceding this parameter is not
// visible in this excerpt (presumably registerConvertArithToLLVMInterface).
    DialectRegistry &registry) {
  // Register an extension that attaches ArithToLLVMDialectInterface to the
  // arith dialect.
  registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) {
    dialect->addInterfaces<ArithToLLVMDialectInterface>();
  });
}
539 
540 //===----------------------------------------------------------------------===//
541 // Pattern Population
542 //===----------------------------------------------------------------------===//
543 
// NOTE(review): the function name line preceding these parameters is not
// visible in this excerpt (presumably populateArithToLLVMConversionPatterns).
    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {

  // Set a higher pattern benefit for IdentityBitcastLowering so it will run
  // before BitcastOpLowering.
  patterns.add<IdentityBitcastLowering>(converter, patterns.getContext(),
                                        /*patternBenefit*/ 10);

  // All remaining lowerings use the default benefit. TruncFOpLowering and
  // ConstrainedTruncFOpLowering never match the same op: one requires the
  // absence and the other the presence of a rounding mode attribute.
  // clang-format off
  patterns.add<
    AddFOpLowering,
    AddIOpLowering,
    AndIOpLowering,
    AddUIExtendedOpLowering,
    BitcastOpLowering,
    ConstantOpLowering,
    CmpFOpLowering,
    CmpIOpLowering,
    DivFOpLowering,
    DivSIOpLowering,
    DivUIOpLowering,
    ExtFOpLowering,
    ExtSIOpLowering,
    ExtUIOpLowering,
    FPToSIOpLowering,
    FPToUIOpLowering,
    IndexCastOpSILowering,
    IndexCastOpUILowering,
    MaximumFOpLowering,
    MaxNumFOpLowering,
    MaxSIOpLowering,
    MaxUIOpLowering,
    MinimumFOpLowering,
    MinNumFOpLowering,
    MinSIOpLowering,
    MinUIOpLowering,
    MulFOpLowering,
    MulIOpLowering,
    MulSIExtendedOpLowering,
    MulUIExtendedOpLowering,
    NegFOpLowering,
    OrIOpLowering,
    RemFOpLowering,
    RemSIOpLowering,
    RemUIOpLowering,
    SelectOpLowering,
    ShLIOpLowering,
    ShRSIOpLowering,
    ShRUIOpLowering,
    SIToFPOpLowering,
    SubFOpLowering,
    SubIOpLowering,
    TruncFOpLowering,
    ConstrainedTruncFOpLowering,
    TruncIOpLowering,
    UIToFPOpLowering,
    XOrIOpLowering
  >(converter);
  // clang-format on
}
static LLVMPredType convertCmpPredicate(PredType pred)
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static Value handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, ValueRange operands, int64_t vectorWidth, llvm::function_ref< Value(ValueRange)> compute)
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:223
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:66
MLIRContext * getContext() const
Definition: Builders.h:55
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:199
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.h:205
Base class for dialect interfaces providing translation to LLVM IR.
virtual void populateConvertToLLVMConversionPatterns(ConversionTarget &target, LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const =0
Hook for derived dialect interface to provide conversion patterns and mark dialect legal for the conv...
virtual void loadDependentDialects(MLIRContext *context) const
Hook for derived dialect interface to load the dialects they target.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:702
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
Definition: VectorPattern.h:90
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, IntegerOverflowFlags overflowFlags=IntegerOverflowFlags::none)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
Definition: Pattern.cpp:319
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:795
void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ceil/floor division ops.
Definition: ExpandOps.cpp:796
void populateArithToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
void registerConvertArithToLLVMInterface(DialectRegistry &registry)
LLVM::FastmathFlags convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF)
Maps arithmetic fastmath enum values to LLVM enum values.
Include the generated interface declarations.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.