MLIR  21.0.0git
ArithToLLVM.cpp
Go to the documentation of this file.
1 //===- ArithToLLVM.cpp - Arithmetic to LLVM dialect conversion -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Pass/Pass.h"
21 #include <type_traits>
22 
23 namespace mlir {
24 #define GEN_PASS_DEF_ARITHTOLLVMCONVERSIONPASS
25 #include "mlir/Conversion/Passes.h.inc"
26 } // namespace mlir
27 
28 using namespace mlir;
29 
30 namespace {
31 
32 /// Pattern for operations whose conversion depends on whether they carry a
33 /// rounding mode attribute or not.
34 ///
35 /// `SourceOp` is the source operation; `TargetOp`, the operation it will lower
36 /// to; `AttrConvert` is the attribute conversion used to convert the rounding
37 /// mode attribute.
/// `Constrained` selects which form of `SourceOp` this pattern accepts: when
/// true it only matches ops that have a rounding-mode attribute, when false
/// only ops without one. Registering a `false` and a `true` instantiation for
/// the same op (e.g. `TruncFOpLowering` and `ConstrainedTruncFOpLowering`)
/// therefore lowers each such op through exactly one of the two.
38 template <typename SourceOp, typename TargetOp, bool Constrained,
39  template <typename, typename> typename AttrConvert =
// NOTE(review): the listing this text was extracted from skips a line here;
// the default template argument for `AttrConvert` is not visible.
41 struct ConstrainedVectorConvertToLLVMPattern
42  : public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert> {
43  using VectorConvertToLLVMPattern<SourceOp, TargetOp,
44  AttrConvert>::VectorConvertToLLVMPattern;
45 
46  LogicalResult
47  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
48  ConversionPatternRewriter &rewriter) const override {
  // Reject the op unless "has a rounding-mode attribute" agrees with
  // `Constrained`; the instantiation with the opposite flag (if registered)
  // handles the other case.
49  if (Constrained != static_cast<bool>(op.getRoundingModeAttr()))
50  return failure();
  // Otherwise defer to the generic one-result lowering of the base pattern.
51  return VectorConvertToLLVMPattern<SourceOp, TargetOp,
52  AttrConvert>::matchAndRewrite(op, adaptor,
53  rewriter);
54  }
55 };
56 
57 /// No-op bitcast: when the converted operand type already equals the converted
58 /// result type, the op is replaced by its operand and no cast is emitted.
59 struct IdentityBitcastLowering final
60  : public OpConversionPattern<arith::BitcastOp> {
// NOTE(review): the extracted listing skips a line here. The pattern is built
// as (converter, context, benefit) in populateArithToLLVMConversionPatterns,
// so presumably an inherited OpConversionPattern constructor is declared on
// the missing line — confirm against the real source.
62 
63  LogicalResult
64  matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor,
65  ConversionPatternRewriter &rewriter) const final {
66  Value src = adaptor.getIn();
67  Type resultType = getTypeConverter()->convertType(op.getType());
  // Non-identity bitcasts are left to BitcastOpLowering, which is registered
  // with a lower benefit than this pattern.
68  if (src.getType() != resultType)
69  return rewriter.notifyMatchFailure(op, "Types are different");
70 
71  rewriter.replaceOp(op, src);
72  return success();
73  }
74 };
75 
76 //===----------------------------------------------------------------------===//
77 // Straightforward Op Lowerings
78 //===----------------------------------------------------------------------===//
79 
// Each alias below lowers one arith op to one LLVM dialect op through
// `VectorConvertToLLVMPattern`, which rewrites ops with a single result to
// the LLVM dialect. The optional third template argument selects how the op's
// attributes are carried over: `AttrConvertFastMathToLLVM` is used for the
// floating-point ops and `AttrConvertOverflowToLLVM` for integer ops
// (presumably fastmath and overflow flags respectively, as the names suggest).
// NOTE(review): the listing this text was extracted from omits the right-hand
// side of several aliases (e.g. BitcastOpLowering, DivSIOpLowering) and some
// aliases entirely (AndIOpLowering, ExtFOpLowering, OrIOpLowering and
// XOrIOpLowering are registered below but not defined in the visible text).
80 using AddFOpLowering =
81  VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
82  arith::AttrConvertFastMathToLLVM>;
83 using AddIOpLowering =
84  VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp,
85  arith::AttrConvertOverflowToLLVM>;
87 using BitcastOpLowering =
89 using DivFOpLowering =
90  VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
91  arith::AttrConvertFastMathToLLVM>;
92 using DivSIOpLowering =
94 using DivUIOpLowering =
97 using ExtSIOpLowering =
99 using ExtUIOpLowering =
101 using FPToSIOpLowering =
103 using FPToUIOpLowering =
105 using MaximumFOpLowering =
106  VectorConvertToLLVMPattern<arith::MaximumFOp, LLVM::MaximumOp,
107  arith::AttrConvertFastMathToLLVM>;
108 using MaxNumFOpLowering =
109  VectorConvertToLLVMPattern<arith::MaxNumFOp, LLVM::MaxNumOp,
110  arith::AttrConvertFastMathToLLVM>;
111 using MaxSIOpLowering =
113 using MaxUIOpLowering =
115 using MinimumFOpLowering =
116  VectorConvertToLLVMPattern<arith::MinimumFOp, LLVM::MinimumOp,
117  arith::AttrConvertFastMathToLLVM>;
118 using MinNumFOpLowering =
119  VectorConvertToLLVMPattern<arith::MinNumFOp, LLVM::MinNumOp,
120  arith::AttrConvertFastMathToLLVM>;
121 using MinSIOpLowering =
123 using MinUIOpLowering =
125 using MulFOpLowering =
126  VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
127  arith::AttrConvertFastMathToLLVM>;
128 using MulIOpLowering =
129  VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp,
130  arith::AttrConvertOverflowToLLVM>;
131 using NegFOpLowering =
132  VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
133  arith::AttrConvertFastMathToLLVM>;
135 using RemFOpLowering =
136  VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
137  arith::AttrConvertFastMathToLLVM>;
138 using RemSIOpLowering =
140 using RemUIOpLowering =
142 using SelectOpLowering =
144 using ShLIOpLowering =
145  VectorConvertToLLVMPattern<arith::ShLIOp, LLVM::ShlOp,
146  arith::AttrConvertOverflowToLLVM>;
147 using ShRSIOpLowering =
149 using ShRUIOpLowering =
151 using SIToFPOpLowering =
153 using SubFOpLowering =
154  VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
155  arith::AttrConvertFastMathToLLVM>;
156 using SubIOpLowering =
157  VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
158  arith::AttrConvertOverflowToLLVM>;
// arith.truncf has two complementary lowerings: without a rounding-mode
// attribute it becomes LLVM::FPTruncOp; with one it becomes the constrained
// FP truncation intrinsic, whose attribute is converted by
// AttrConverterConstrainedFPToLLVM.
159 using TruncFOpLowering =
160  ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp,
161  false>;
162 using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern<
163  arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, true,
164  arith::AttrConverterConstrainedFPToLLVM>;
165 using TruncIOpLowering =
167 using UIToFPOpLowering =
170 
171 //===----------------------------------------------------------------------===//
172 // Op Lowering Patterns
173 //===----------------------------------------------------------------------===//
174 
175 /// Lowers arith.constant directly to LLVM::ConstantOp, forwarding the op's
/// converted operands and all of its attributes unchanged.
176 struct ConstantOpLowering : public ConvertOpToLLVMPattern<arith::ConstantOp> {
// NOTE(review): the extracted listing skips a line here; since the pattern is
// registered with just `(converter)`, presumably it inherits the
// ConvertOpToLLVMPattern(typeConverter, benefit) constructor — confirm.
178 
179  LogicalResult
180  matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
181  ConversionPatternRewriter &rewriter) const override;
182 };
183 
184 /// The lowering of index_cast becomes an integer conversion since index
185 /// becomes an integer. If the bit width of the source and target integer
186 /// types is the same, the cast is replaced by its operand. If the target type
187 /// is wider, the value is extended with `ExtCastTy` (sign-extension for
/// arith.index_cast, zero-extension for arith.index_castui, see the aliases
/// below); otherwise it is truncated.
188 template <typename OpTy, typename ExtCastTy>
189 struct IndexCastOpLowering : public ConvertOpToLLVMPattern<OpTy> {
// NOTE(review): the extracted listing skips a line here (presumably the
// inherited ConvertOpToLLVMPattern constructor) — confirm.
191 
192  LogicalResult
193  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
194  ConversionPatternRewriter &rewriter) const override;
195 };
196 
/// arith.index_cast: widening uses sign-extension (LLVM::SExtOp).
197 using IndexCastOpSILowering =
198  IndexCastOpLowering<arith::IndexCastOp, LLVM::SExtOp>;
/// arith.index_castui: widening uses zero-extension (LLVM::ZExtOp).
199 using IndexCastOpUILowering =
200  IndexCastOpLowering<arith::IndexCastUIOp, LLVM::ZExtOp>;
201 
/// Lowers arith.addui_extended to LLVM::UAddWithOverflowOp and unpacks its
/// two-element struct result into the (sum, overflow) results. Operands that
/// were converted to LLVM arrays (N-D vectors) are not supported yet.
202 struct AddUIExtendedOpLowering
203  : public ConvertOpToLLVMPattern<arith::AddUIExtendedOp> {
// NOTE(review): the extracted listing skips a line here (presumably the
// inherited ConvertOpToLLVMPattern constructor) — confirm.
205 
206  LogicalResult
207  matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
208  ConversionPatternRewriter &rewriter) const override;
209 };
210 
/// Lowers arith.mulsi_extended / arith.mului_extended (selected by `IsSigned`)
/// by widening both operands to twice their bit width, multiplying, and
/// splitting the wide product into its low and high halves. Operands that were
/// converted to LLVM arrays (N-D vectors) are not supported yet.
211 template <typename ArithMulOp, bool IsSigned>
212 struct MulIExtendedOpLowering : public ConvertOpToLLVMPattern<ArithMulOp> {
// NOTE(review): the extracted listing skips a line here (presumably the
// inherited ConvertOpToLLVMPattern constructor) — confirm.
214 
215  LogicalResult
216  matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
217  ConversionPatternRewriter &rewriter) const override;
218 };
219 
/// Signed variant: operands are sign-extended before the wide multiply.
220 using MulSIExtendedOpLowering =
221  MulIExtendedOpLowering<arith::MulSIExtendedOp, true>;
/// Unsigned variant: operands are zero-extended before the wide multiply.
222 using MulUIExtendedOpLowering =
223  MulIExtendedOpLowering<arith::MulUIExtendedOp, false>;
224 
/// Lowers arith.cmpi to LLVM::ICmpOp, reusing the numeric value of the
/// predicate. Operands converted to LLVM arrays (N-D vectors) are handled by
/// emitting one comparison per 1-D vector.
225 struct CmpIOpLowering : public ConvertOpToLLVMPattern<arith::CmpIOp> {
// NOTE(review): the extracted listing skips a line here (presumably the
// inherited ConvertOpToLLVMPattern constructor) — confirm.
227 
228  LogicalResult
229  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
230  ConversionPatternRewriter &rewriter) const override;
231 };
232 
/// Lowers arith.cmpf to LLVM::FCmpOp, reusing the numeric value of the
/// predicate and converting the op's fastmath flags to LLVM fastmath flags.
/// Operands converted to LLVM arrays (N-D vectors) are handled by emitting one
/// comparison per 1-D vector.
233 struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> {
// NOTE(review): the extracted listing skips a line here (presumably the
// inherited ConvertOpToLLVMPattern constructor) — confirm.
235 
236  LogicalResult
237  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
238  ConversionPatternRewriter &rewriter) const override;
239 };
240 
241 } // namespace
242 
243 //===----------------------------------------------------------------------===//
244 // ConstantOpLowering
245 //===----------------------------------------------------------------------===//
246 
247 LogicalResult
248 ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
249  ConversionPatternRewriter &rewriter) const {
250  return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(),
251  adaptor.getOperands(), op->getAttrs(),
252  *getTypeConverter(), rewriter);
253 }
254 
255 //===----------------------------------------------------------------------===//
256 // IndexCastOpLowering
257 //===----------------------------------------------------------------------===//
258 
/// Lowers index_cast / index_castui. The source and result element types are
/// converted with the type converter and compared by bit width: equal widths
/// forward the operand unchanged, a narrower target truncates
/// (LLVM::TruncOp), and a wider target extends with `ExtCastTy`. Operands
/// converted to LLVM arrays (N-D vectors) are rewritten one 1-D vector at a
/// time; any other array case is rejected with a match failure.
259 template <typename OpTy, typename ExtCastTy>
260 LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
261  OpTy op, typename OpTy::Adaptor adaptor,
262  ConversionPatternRewriter &rewriter) const {
263  Type resultType = op.getResult().getType();
  // Compare element bit widths after type conversion, where `index` has
  // become an integer type.
  // NOTE(review): getIntOrFloatBitWidth() asserts on non-integer/float types;
  // this relies on both converted element types being integers.
264  Type targetElementType =
265  this->typeConverter->convertType(getElementTypeOrSelf(resultType));
266  Type sourceElementType =
267  this->typeConverter->convertType(getElementTypeOrSelf(op.getIn()));
268  unsigned targetBits = targetElementType.getIntOrFloatBitWidth();
269  unsigned sourceBits = sourceElementType.getIntOrFloatBitWidth();
270 
  // Same width after conversion: the cast is a no-op.
271  if (targetBits == sourceBits) {
272  rewriter.replaceOp(op, adaptor.getIn());
273  return success();
274  }
275 
276  // Handle the scalar and 1D vector cases.
277  Type operandType = adaptor.getIn().getType();
278  if (!isa<LLVM::LLVMArrayType>(operandType)) {
279  Type targetType = this->typeConverter->convertType(resultType);
280  if (targetBits < sourceBits)
281  rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType,
282  adaptor.getIn());
283  else
284  rewriter.replaceOpWithNewOp<ExtCastTy>(op, targetType, adaptor.getIn());
285  return success();
286  }
287 
  // An LLVM array operand is only expected for N-D vector inputs.
288  if (!isa<VectorType>(resultType))
289  return rewriter.notifyMatchFailure(op, "expected vector result type");
290 
  // NOTE(review): the opening line of this statement is missing from the
  // listing; given the arguments below and the handleMultidimensionalVectors
  // signature documented at the end of this page, it presumably is
  // `return LLVM::detail::handleMultidimensionalVectors(` — confirm.
292  op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
293  [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
  // Re-wrap the per-1-D-vector operands in an adaptor (this shadows the
  // outer `adaptor`), then truncate or extend exactly as in the 1-D case.
294  typename OpTy::Adaptor adaptor(operands);
295  if (targetBits < sourceBits) {
296  return rewriter.create<LLVM::TruncOp>(op.getLoc(), llvm1DVectorTy,
297  adaptor.getIn());
298  }
299  return rewriter.create<ExtCastTy>(op.getLoc(), llvm1DVectorTy,
300  adaptor.getIn());
301  },
302  rewriter);
303 }
304 
305 //===----------------------------------------------------------------------===//
306 // AddUIExtendedOpLowering
307 //===----------------------------------------------------------------------===//
308 
309 LogicalResult AddUIExtendedOpLowering::matchAndRewrite(
310  arith::AddUIExtendedOp op, OpAdaptor adaptor,
311  ConversionPatternRewriter &rewriter) const {
312  Type operandType = adaptor.getLhs().getType();
313  Type sumResultType = op.getSum().getType();
314  Type overflowResultType = op.getOverflow().getType();
315 
316  if (!LLVM::isCompatibleType(operandType))
317  return failure();
318 
319  MLIRContext *ctx = rewriter.getContext();
320  Location loc = op.getLoc();
321 
322  // Handle the scalar and 1D vector cases.
323  if (!isa<LLVM::LLVMArrayType>(operandType)) {
324  Type newOverflowType = typeConverter->convertType(overflowResultType);
325  Type structType =
326  LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newOverflowType});
327  Value addOverflow = rewriter.create<LLVM::UAddWithOverflowOp>(
328  loc, structType, adaptor.getLhs(), adaptor.getRhs());
329  Value sumExtracted =
330  rewriter.create<LLVM::ExtractValueOp>(loc, addOverflow, 0);
331  Value overflowExtracted =
332  rewriter.create<LLVM::ExtractValueOp>(loc, addOverflow, 1);
333  rewriter.replaceOp(op, {sumExtracted, overflowExtracted});
334  return success();
335  }
336 
337  if (!isa<VectorType>(sumResultType))
338  return rewriter.notifyMatchFailure(loc, "expected vector result types");
339 
340  return rewriter.notifyMatchFailure(loc,
341  "ND vector types are not supported yet");
342 }
343 
344 //===----------------------------------------------------------------------===//
345 // MulIExtendedOpLowering
346 //===----------------------------------------------------------------------===//
347 
348 template <typename ArithMulOp, bool IsSigned>
349 LogicalResult MulIExtendedOpLowering<ArithMulOp, IsSigned>::matchAndRewrite(
350  ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
351  ConversionPatternRewriter &rewriter) const {
352  Type resultType = adaptor.getLhs().getType();
353 
354  if (!LLVM::isCompatibleType(resultType))
355  return failure();
356 
357  Location loc = op.getLoc();
358 
359  // Handle the scalar and 1D vector cases. Because LLVM does not have a
360  // matching extended multiplication intrinsic, perform regular multiplication
361  // on operands zero-extended to i(2*N) bits, and truncate the results back to
362  // iN types.
363  if (!isa<LLVM::LLVMArrayType>(resultType)) {
364  // Shift amount necessary to extract the high bits from widened result.
365  TypedAttr shiftValAttr;
366 
367  if (auto intTy = dyn_cast<IntegerType>(resultType)) {
368  unsigned resultBitwidth = intTy.getWidth();
369  auto attrTy = rewriter.getIntegerType(resultBitwidth * 2);
370  shiftValAttr = rewriter.getIntegerAttr(attrTy, resultBitwidth);
371  } else {
372  auto vecTy = cast<VectorType>(resultType);
373  unsigned resultBitwidth = vecTy.getElementTypeBitWidth();
374  auto attrTy = VectorType::get(
375  vecTy.getShape(), rewriter.getIntegerType(resultBitwidth * 2));
376  shiftValAttr = SplatElementsAttr::get(
377  attrTy, APInt(resultBitwidth * 2, resultBitwidth));
378  }
379  Type wideType = shiftValAttr.getType();
380  assert(LLVM::isCompatibleType(wideType) &&
381  "LLVM dialect should support all signless integer types");
382 
383  using LLVMExtOp = std::conditional_t<IsSigned, LLVM::SExtOp, LLVM::ZExtOp>;
384  Value lhsExt = rewriter.create<LLVMExtOp>(loc, wideType, adaptor.getLhs());
385  Value rhsExt = rewriter.create<LLVMExtOp>(loc, wideType, adaptor.getRhs());
386  Value mulExt = rewriter.create<LLVM::MulOp>(loc, wideType, lhsExt, rhsExt);
387 
388  // Split the 2*N-bit wide result into two N-bit values.
389  Value low = rewriter.create<LLVM::TruncOp>(loc, resultType, mulExt);
390  Value shiftVal = rewriter.create<LLVM::ConstantOp>(loc, shiftValAttr);
391  Value highExt = rewriter.create<LLVM::LShrOp>(loc, mulExt, shiftVal);
392  Value high = rewriter.create<LLVM::TruncOp>(loc, resultType, highExt);
393 
394  rewriter.replaceOp(op, {low, high});
395  return success();
396  }
397 
398  if (!isa<VectorType>(resultType))
399  return rewriter.notifyMatchFailure(op, "expected vector result type");
400 
401  return rewriter.notifyMatchFailure(op,
402  "ND vector types are not supported yet");
403 }
404 
405 //===----------------------------------------------------------------------===//
406 // CmpIOpLowering
407 //===----------------------------------------------------------------------===//
408 
/// Maps an arith comparison predicate onto the corresponding LLVM dialect
/// predicate enum. Both enums assign the same numeric value to each predicate,
/// so the mapping simply carries the underlying integer value across.
template <typename LLVMPredType, typename PredType>
static LLVMPredType convertCmpPredicate(PredType pred) {
  const auto rawValue = static_cast<std::underlying_type_t<PredType>>(pred);
  return static_cast<LLVMPredType>(rawValue);
}
415 
/// Lowers arith.cmpi to LLVM::ICmpOp. Scalars and 1-D vectors are rewritten
/// directly; operands converted to LLVM arrays (N-D vectors) get one ICmpOp per
/// 1-D vector, built by the callback below.
416 LogicalResult
417 CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
418  ConversionPatternRewriter &rewriter) const {
419  Type operandType = adaptor.getLhs().getType();
420  Type resultType = op.getResult().getType();
421 
422  // Handle the scalar and 1D vector cases.
423  if (!isa<LLVM::LLVMArrayType>(operandType)) {
424  rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
425  op, typeConverter->convertType(resultType),
426  convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
427  adaptor.getLhs(), adaptor.getRhs());
428  return success();
429  }
430 
  // An LLVM array operand is only expected for N-D vector inputs.
431  if (!isa<VectorType>(resultType))
432  return rewriter.notifyMatchFailure(op, "expected vector result type");
433 
  // NOTE(review): the opening line of this statement is missing from the
  // listing; it presumably is
  // `return LLVM::detail::handleMultidimensionalVectors(` — confirm.
435  op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
436  [&](Type llvm1DVectorTy, ValueRange operands) {
  // Re-wrap the per-1-D-vector operands (shadows the outer `adaptor`).
437  OpAdaptor adaptor(operands);
438  return rewriter.create<LLVM::ICmpOp>(
439  op.getLoc(), llvm1DVectorTy,
440  convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
441  adaptor.getLhs(), adaptor.getRhs());
442  },
443  rewriter);
444 }
445 
446 //===----------------------------------------------------------------------===//
447 // CmpFOpLowering
448 //===----------------------------------------------------------------------===//
449 
/// Lowers arith.cmpf to LLVM::FCmpOp, carrying the op's fastmath flags over as
/// LLVM fastmath flags. Scalars and 1-D vectors are rewritten directly;
/// operands converted to LLVM arrays (N-D vectors) get one FCmpOp per 1-D
/// vector, built by the callback below.
450 LogicalResult
451 CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
452  ConversionPatternRewriter &rewriter) const {
453  Type operandType = adaptor.getLhs().getType();
454  Type resultType = op.getResult().getType();
455  LLVM::FastmathFlags fmf =
456  arith::convertArithFastMathFlagsToLLVM(op.getFastmath());
457 
458  // Handle the scalar and 1D vector cases.
459  if (!isa<LLVM::LLVMArrayType>(operandType)) {
460  rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
461  op, typeConverter->convertType(resultType),
462  convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
463  adaptor.getLhs(), adaptor.getRhs(), fmf);
464  return success();
465  }
466 
  // An LLVM array operand is only expected for N-D vector inputs.
467  if (!isa<VectorType>(resultType))
468  return rewriter.notifyMatchFailure(op, "expected vector result type");
469 
  // NOTE(review): the opening line of this statement is missing from the
  // listing; it presumably is
  // `return LLVM::detail::handleMultidimensionalVectors(` — confirm.
471  op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
472  [&](Type llvm1DVectorTy, ValueRange operands) {
  // Re-wrap the per-1-D-vector operands (shadows the outer `adaptor`).
473  OpAdaptor adaptor(operands);
474  return rewriter.create<LLVM::FCmpOp>(
475  op.getLoc(), llvm1DVectorTy,
476  convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
477  adaptor.getLhs(), adaptor.getRhs(), fmf);
478  },
479  rewriter);
480 }
481 
482 //===----------------------------------------------------------------------===//
483 // Pass Definition
484 //===----------------------------------------------------------------------===//
485 
486 namespace {
/// Pass that converts arith ops to the LLVM dialect by running
/// applyPartialConversion on the pass's operation.
487 struct ArithToLLVMConversionPass
488  : public impl::ArithToLLVMConversionPassBase<ArithToLLVMConversionPass> {
489  using Base::Base;
490 
491  void runOnOperation() override {
  // NOTE(review): the listing skips several lines in this body; the
  // declarations of `options`, `target` and `patterns` (and whatever fills
  // `patterns`) are not visible here.
494 
  // Only override the index bitwidth when the option was set explicitly;
  // kDeriveIndexBitwidthFromDataLayout is the "derive it from the data
  // layout" sentinel.
496  if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
497  options.overrideIndexBitwidth(indexBitwidth);
498 
499  LLVMTypeConverter converter(&getContext(), options);
502 
  // A failed conversion marks the whole pass as failed.
503  if (failed(applyPartialConversion(getOperation(), target,
504  std::move(patterns))))
505  signalPassFailure();
506  }
507 };
508 } // namespace
509 
510 //===----------------------------------------------------------------------===//
511 // ConvertToLLVMPatternInterface implementation
512 //===----------------------------------------------------------------------===//
513 
514 namespace {
515 /// Implement the interface to convert Arith to LLVM.
516 struct ArithToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
// NOTE(review): the listing skips a line here (presumably the inherited
// ConvertToLLVMPatternInterface constructor) — confirm.
518  void loadDependentDialects(MLIRContext *context) const final {
  // The lowering creates LLVM dialect ops, so that dialect must be loaded.
519  context->loadDialect<LLVM::LLVMDialect>();
520  }
521 
522  /// Hook for derived dialect interface to provide conversion patterns
523  /// and mark dialect legal for the conversion target.
// NOTE(review): the first line of this override's signature is missing from
// the listing; per the interface documented at the end of this page it is
// presumably `void populateConvertToLLVMConversionPatterns(`.
525  ConversionTarget &target, LLVMTypeConverter &typeConverter,
526  RewritePatternSet &patterns) const final {
  // NOTE(review): the body of this hook is not shown in the listing.
529  }
530 };
531 } // namespace
532 
534  DialectRegistry &registry) {
  // NOTE(review): the function's opening line is not in this listing; per the
  // index at the end of this page it is
  // `void registerConvertArithToLLVMInterface(DialectRegistry &registry)`.
  // Attach ArithToLLVMDialectInterface to the Arith dialect through a registry
  // extension; the unary `+` converts the capture-less lambda into a plain
  // function pointer.
535  registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) {
536  dialect->addInterfaces<ArithToLLVMDialectInterface>();
537  });
538 }
539 
540 //===----------------------------------------------------------------------===//
541 // Pattern Population
542 //===----------------------------------------------------------------------===//
543 
545  const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  // NOTE(review): the function's opening line is not in this listing; per the
  // index at the end of this page it is
  // `void populateArithToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)`.
546 
547  // Set a higher pattern benefit for IdentityBitcastLowering so it will run
548  // before BitcastOpLowering.
549  patterns.add<IdentityBitcastLowering>(converter, patterns.getContext(),
550  /*patternBenefit*/ 10);
551 
  // Every remaining lowering is constructed from the type converter alone, so
  // it uses the constructor's default benefit. TruncFOpLowering and
  // ConstrainedTruncFOpLowering are complementary: each accepts only truncf
  // ops without / with a rounding-mode attribute, respectively.
552  // clang-format off
553  patterns.add<
554  AddFOpLowering,
555  AddIOpLowering,
556  AndIOpLowering,
557  AddUIExtendedOpLowering,
558  BitcastOpLowering,
559  ConstantOpLowering,
560  CmpFOpLowering,
561  CmpIOpLowering,
562  DivFOpLowering,
563  DivSIOpLowering,
564  DivUIOpLowering,
565  ExtFOpLowering,
566  ExtSIOpLowering,
567  ExtUIOpLowering,
568  FPToSIOpLowering,
569  FPToUIOpLowering,
570  IndexCastOpSILowering,
571  IndexCastOpUILowering,
572  MaximumFOpLowering,
573  MaxNumFOpLowering,
574  MaxSIOpLowering,
575  MaxUIOpLowering,
576  MinimumFOpLowering,
577  MinNumFOpLowering,
578  MinSIOpLowering,
579  MinUIOpLowering,
580  MulFOpLowering,
581  MulIOpLowering,
582  MulSIExtendedOpLowering,
583  MulUIExtendedOpLowering,
584  NegFOpLowering,
585  OrIOpLowering,
586  RemFOpLowering,
587  RemSIOpLowering,
588  RemUIOpLowering,
589  SelectOpLowering,
590  ShLIOpLowering,
591  ShRSIOpLowering,
592  ShRUIOpLowering,
593  SIToFPOpLowering,
594  SubFOpLowering,
595  SubIOpLowering,
596  TruncFOpLowering,
597  ConstrainedTruncFOpLowering,
598  TruncIOpLowering,
599  UIToFPOpLowering,
600  XOrIOpLowering
601  >(converter);
602  // clang-format on
603 }
static LLVMPredType convertCmpPredicate(PredType pred)
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static Value handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, ValueRange operands, int64_t vectorWidth, llvm::function_ref< Value(ValueRange)> compute)
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:224
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
MLIRContext * getContext() const
Definition: Builders.h:56
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:155
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.h:161
Base class for dialect interfaces providing translation to LLVM IR.
virtual void populateConvertToLLVMConversionPatterns(ConversionTarget &target, LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const =0
Hook for derived dialect interface to provide conversion patterns and mark dialect legal for the conv...
virtual void loadDependentDialects(MLIRContext *context) const
Hook for derived dialect interface to load the dialects they target.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:682
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
Definition: VectorPattern.h:90
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, IntegerOverflowFlags overflowFlags=IntegerOverflowFlags::none)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
Definition: Pattern.cpp:345
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:796
void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ceil/floor division ops.
Definition: ExpandOps.cpp:380
void populateArithToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
void registerConvertArithToLLVMInterface(DialectRegistry &registry)
LLVM::FastmathFlags convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF)
Maps arithmetic fastmath enum values to LLVM enum values.
Include the generated interface declarations.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.