MLIR  21.0.0git
ArithToLLVM.cpp
Go to the documentation of this file.
1 //===- ArithToLLVM.cpp - Arithmetic to LLVM dialect conversion -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Pass/Pass.h"
21 #include <type_traits>
22 
23 namespace mlir {
24 #define GEN_PASS_DEF_ARITHTOLLVMCONVERSIONPASS
25 #include "mlir/Conversion/Passes.h.inc"
26 } // namespace mlir
27 
28 using namespace mlir;
29 
30 namespace {
31 
32 /// Operations whose conversion will depend on whether they are passed a
33 /// rounding mode attribute or not.
34 ///
35 /// `SourceOp` is the source operation; `TargetOp`, the operation it will lower
36 /// to; `AttrConvert` is the attribute conversion to convert the rounding mode
37 /// attribute.
38 template <typename SourceOp, typename TargetOp, bool Constrained,
39  template <typename, typename> typename AttrConvert =
41 struct ConstrainedVectorConvertToLLVMPattern
42  : public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert> {
43  using VectorConvertToLLVMPattern<SourceOp, TargetOp,
44  AttrConvert>::VectorConvertToLLVMPattern;
45 
46  LogicalResult
47  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
48  ConversionPatternRewriter &rewriter) const override {
49  if (Constrained != static_cast<bool>(op.getRoundingModeAttr()))
50  return failure();
51  return VectorConvertToLLVMPattern<SourceOp, TargetOp,
52  AttrConvert>::matchAndRewrite(op, adaptor,
53  rewriter);
54  }
55 };
56 
57 /// No-op bitcast. Forward the input operand unchanged if the converted source
58 /// and destination types are the same.
59 struct IdentityBitcastLowering final
60  : public OpConversionPattern<arith::BitcastOp> {
62 
63  LogicalResult
64  matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor,
65  ConversionPatternRewriter &rewriter) const final {
66  Value src = adaptor.getIn();
67  Type resultType = getTypeConverter()->convertType(op.getType());
68  if (src.getType() != resultType)
69  return rewriter.notifyMatchFailure(op, "Types are different");
70 
71  rewriter.replaceOp(op, src);
72  return success();
73  }
74 };
75 
76 //===----------------------------------------------------------------------===//
77 // Straightforward Op Lowerings
78 //===----------------------------------------------------------------------===//
79 
80 using AddFOpLowering =
81  VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
82  arith::AttrConvertFastMathToLLVM>;
83 using AddIOpLowering =
84  VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp,
85  arith::AttrConvertOverflowToLLVM>;
87 using BitcastOpLowering =
89 using DivFOpLowering =
90  VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
91  arith::AttrConvertFastMathToLLVM>;
92 using DivSIOpLowering =
94 using DivUIOpLowering =
97 using ExtSIOpLowering =
99 using ExtUIOpLowering =
101 using FPToSIOpLowering =
103 using FPToUIOpLowering =
105 using MaximumFOpLowering =
106  VectorConvertToLLVMPattern<arith::MaximumFOp, LLVM::MaximumOp,
107  arith::AttrConvertFastMathToLLVM>;
108 using MaxNumFOpLowering =
109  VectorConvertToLLVMPattern<arith::MaxNumFOp, LLVM::MaxNumOp,
110  arith::AttrConvertFastMathToLLVM>;
111 using MaxSIOpLowering =
113 using MaxUIOpLowering =
115 using MinimumFOpLowering =
116  VectorConvertToLLVMPattern<arith::MinimumFOp, LLVM::MinimumOp,
117  arith::AttrConvertFastMathToLLVM>;
118 using MinNumFOpLowering =
119  VectorConvertToLLVMPattern<arith::MinNumFOp, LLVM::MinNumOp,
120  arith::AttrConvertFastMathToLLVM>;
121 using MinSIOpLowering =
123 using MinUIOpLowering =
125 using MulFOpLowering =
126  VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
127  arith::AttrConvertFastMathToLLVM>;
128 using MulIOpLowering =
129  VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp,
130  arith::AttrConvertOverflowToLLVM>;
131 using NegFOpLowering =
132  VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
133  arith::AttrConvertFastMathToLLVM>;
135 using RemFOpLowering =
136  VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
137  arith::AttrConvertFastMathToLLVM>;
138 using RemSIOpLowering =
140 using RemUIOpLowering =
142 using SelectOpLowering =
144 using ShLIOpLowering =
145  VectorConvertToLLVMPattern<arith::ShLIOp, LLVM::ShlOp,
146  arith::AttrConvertOverflowToLLVM>;
147 using ShRSIOpLowering =
149 using ShRUIOpLowering =
151 using SIToFPOpLowering =
153 using SubFOpLowering =
154  VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
155  arith::AttrConvertFastMathToLLVM>;
156 using SubIOpLowering =
157  VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
158  arith::AttrConvertOverflowToLLVM>;
159 using TruncFOpLowering =
160  ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp,
161  false>;
162 using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern<
163  arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, true,
164  arith::AttrConverterConstrainedFPToLLVM>;
165 using TruncIOpLowering =
166  VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp,
167  arith::AttrConvertOverflowToLLVM>;
168 using UIToFPOpLowering =
171 
172 //===----------------------------------------------------------------------===//
173 // Op Lowering Patterns
174 //===----------------------------------------------------------------------===//
175 
176 /// Directly lower to LLVM op.
177 struct ConstantOpLowering : public ConvertOpToLLVMPattern<arith::ConstantOp> {
179 
180  LogicalResult
181  matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
182  ConversionPatternRewriter &rewriter) const override;
183 };
184 
185 /// The lowering of index_cast becomes an integer conversion since index
186 /// becomes an integer. If the source and target integer types have the same
187 /// bit width, just erase the cast. If the target type is wider, extend the
188 /// value with `ExtCastTy` (sext or zext), otherwise truncate it.
189 template <typename OpTy, typename ExtCastTy>
190 struct IndexCastOpLowering : public ConvertOpToLLVMPattern<OpTy> {
192 
193  LogicalResult
194  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
195  ConversionPatternRewriter &rewriter) const override;
196 };
197 
198 using IndexCastOpSILowering =
199  IndexCastOpLowering<arith::IndexCastOp, LLVM::SExtOp>;
200 using IndexCastOpUILowering =
201  IndexCastOpLowering<arith::IndexCastUIOp, LLVM::ZExtOp>;
202 
203 struct AddUIExtendedOpLowering
204  : public ConvertOpToLLVMPattern<arith::AddUIExtendedOp> {
206 
207  LogicalResult
208  matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
209  ConversionPatternRewriter &rewriter) const override;
210 };
211 
212 template <typename ArithMulOp, bool IsSigned>
213 struct MulIExtendedOpLowering : public ConvertOpToLLVMPattern<ArithMulOp> {
215 
216  LogicalResult
217  matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
218  ConversionPatternRewriter &rewriter) const override;
219 };
220 
221 using MulSIExtendedOpLowering =
222  MulIExtendedOpLowering<arith::MulSIExtendedOp, true>;
223 using MulUIExtendedOpLowering =
224  MulIExtendedOpLowering<arith::MulUIExtendedOp, false>;
225 
226 struct CmpIOpLowering : public ConvertOpToLLVMPattern<arith::CmpIOp> {
228 
229  LogicalResult
230  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
231  ConversionPatternRewriter &rewriter) const override;
232 };
233 
234 struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> {
236 
237  LogicalResult
238  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
239  ConversionPatternRewriter &rewriter) const override;
240 };
241 
242 } // namespace
243 
244 //===----------------------------------------------------------------------===//
245 // ConstantOpLowering
246 //===----------------------------------------------------------------------===//
247 
248 LogicalResult
249 ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
250  ConversionPatternRewriter &rewriter) const {
251  return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(),
252  adaptor.getOperands(), op->getAttrs(),
253  *getTypeConverter(), rewriter);
254 }
255 
256 //===----------------------------------------------------------------------===//
257 // IndexCastOpLowering
258 //===----------------------------------------------------------------------===//
259 
260 template <typename OpTy, typename ExtCastTy>
261 LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
262  OpTy op, typename OpTy::Adaptor adaptor,
263  ConversionPatternRewriter &rewriter) const {
264  Type resultType = op.getResult().getType();
265  Type targetElementType =
266  this->typeConverter->convertType(getElementTypeOrSelf(resultType));
267  Type sourceElementType =
268  this->typeConverter->convertType(getElementTypeOrSelf(op.getIn()));
269  unsigned targetBits = targetElementType.getIntOrFloatBitWidth();
270  unsigned sourceBits = sourceElementType.getIntOrFloatBitWidth();
271 
272  if (targetBits == sourceBits) {
273  rewriter.replaceOp(op, adaptor.getIn());
274  return success();
275  }
276 
277  // Handle the scalar and 1D vector cases.
278  Type operandType = adaptor.getIn().getType();
279  if (!isa<LLVM::LLVMArrayType>(operandType)) {
280  Type targetType = this->typeConverter->convertType(resultType);
281  if (targetBits < sourceBits)
282  rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType,
283  adaptor.getIn());
284  else
285  rewriter.replaceOpWithNewOp<ExtCastTy>(op, targetType, adaptor.getIn());
286  return success();
287  }
288 
289  if (!isa<VectorType>(resultType))
290  return rewriter.notifyMatchFailure(op, "expected vector result type");
291 
293  op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
294  [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
295  typename OpTy::Adaptor adaptor(operands);
296  if (targetBits < sourceBits) {
297  return rewriter.create<LLVM::TruncOp>(op.getLoc(), llvm1DVectorTy,
298  adaptor.getIn());
299  }
300  return rewriter.create<ExtCastTy>(op.getLoc(), llvm1DVectorTy,
301  adaptor.getIn());
302  },
303  rewriter);
304 }
305 
306 //===----------------------------------------------------------------------===//
307 // AddUIExtendedOpLowering
308 //===----------------------------------------------------------------------===//
309 
310 LogicalResult AddUIExtendedOpLowering::matchAndRewrite(
311  arith::AddUIExtendedOp op, OpAdaptor adaptor,
312  ConversionPatternRewriter &rewriter) const {
313  Type operandType = adaptor.getLhs().getType();
314  Type sumResultType = op.getSum().getType();
315  Type overflowResultType = op.getOverflow().getType();
316 
317  if (!LLVM::isCompatibleType(operandType))
318  return failure();
319 
320  MLIRContext *ctx = rewriter.getContext();
321  Location loc = op.getLoc();
322 
323  // Handle the scalar and 1D vector cases.
324  if (!isa<LLVM::LLVMArrayType>(operandType)) {
325  Type newOverflowType = typeConverter->convertType(overflowResultType);
326  Type structType =
327  LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newOverflowType});
328  Value addOverflow = rewriter.create<LLVM::UAddWithOverflowOp>(
329  loc, structType, adaptor.getLhs(), adaptor.getRhs());
330  Value sumExtracted =
331  rewriter.create<LLVM::ExtractValueOp>(loc, addOverflow, 0);
332  Value overflowExtracted =
333  rewriter.create<LLVM::ExtractValueOp>(loc, addOverflow, 1);
334  rewriter.replaceOp(op, {sumExtracted, overflowExtracted});
335  return success();
336  }
337 
338  if (!isa<VectorType>(sumResultType))
339  return rewriter.notifyMatchFailure(loc, "expected vector result types");
340 
341  return rewriter.notifyMatchFailure(loc,
342  "ND vector types are not supported yet");
343 }
344 
345 //===----------------------------------------------------------------------===//
346 // MulIExtendedOpLowering
347 //===----------------------------------------------------------------------===//
348 
349 template <typename ArithMulOp, bool IsSigned>
350 LogicalResult MulIExtendedOpLowering<ArithMulOp, IsSigned>::matchAndRewrite(
351  ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
352  ConversionPatternRewriter &rewriter) const {
353  Type resultType = adaptor.getLhs().getType();
354 
355  if (!LLVM::isCompatibleType(resultType))
356  return failure();
357 
358  Location loc = op.getLoc();
359 
360  // Handle the scalar and 1D vector cases. Because LLVM does not have a
361  // matching extended multiplication intrinsic, perform regular multiplication
362  // on operands zero-extended to i(2*N) bits, and truncate the results back to
363  // iN types.
364  if (!isa<LLVM::LLVMArrayType>(resultType)) {
365  // Shift amount necessary to extract the high bits from widened result.
366  TypedAttr shiftValAttr;
367 
368  if (auto intTy = dyn_cast<IntegerType>(resultType)) {
369  unsigned resultBitwidth = intTy.getWidth();
370  auto attrTy = rewriter.getIntegerType(resultBitwidth * 2);
371  shiftValAttr = rewriter.getIntegerAttr(attrTy, resultBitwidth);
372  } else {
373  auto vecTy = cast<VectorType>(resultType);
374  unsigned resultBitwidth = vecTy.getElementTypeBitWidth();
375  auto attrTy = VectorType::get(
376  vecTy.getShape(), rewriter.getIntegerType(resultBitwidth * 2));
377  shiftValAttr = SplatElementsAttr::get(
378  attrTy, APInt(resultBitwidth * 2, resultBitwidth));
379  }
380  Type wideType = shiftValAttr.getType();
381  assert(LLVM::isCompatibleType(wideType) &&
382  "LLVM dialect should support all signless integer types");
383 
384  using LLVMExtOp = std::conditional_t<IsSigned, LLVM::SExtOp, LLVM::ZExtOp>;
385  Value lhsExt = rewriter.create<LLVMExtOp>(loc, wideType, adaptor.getLhs());
386  Value rhsExt = rewriter.create<LLVMExtOp>(loc, wideType, adaptor.getRhs());
387  Value mulExt = rewriter.create<LLVM::MulOp>(loc, wideType, lhsExt, rhsExt);
388 
389  // Split the 2*N-bit wide result into two N-bit values.
390  Value low = rewriter.create<LLVM::TruncOp>(loc, resultType, mulExt);
391  Value shiftVal = rewriter.create<LLVM::ConstantOp>(loc, shiftValAttr);
392  Value highExt = rewriter.create<LLVM::LShrOp>(loc, mulExt, shiftVal);
393  Value high = rewriter.create<LLVM::TruncOp>(loc, resultType, highExt);
394 
395  rewriter.replaceOp(op, {low, high});
396  return success();
397  }
398 
399  if (!isa<VectorType>(resultType))
400  return rewriter.notifyMatchFailure(op, "expected vector result type");
401 
402  return rewriter.notifyMatchFailure(op,
403  "ND vector types are not supported yet");
404 }
405 
406 //===----------------------------------------------------------------------===//
407 // CmpIOpLowering
408 //===----------------------------------------------------------------------===//
409 
410 // Convert arith.cmp predicate into the LLVM dialect CmpPredicate. The two enums
411 // share numerical values so just cast.
// Map an arith comparison predicate onto the corresponding LLVM dialect
// predicate. Both enums are declared with identical numeric values, so the
// conversion just carries the underlying integer value across.
template <typename LLVMPredType, typename PredType>
static LLVMPredType convertCmpPredicate(PredType pred) {
  const auto rawValue = static_cast<std::underlying_type_t<PredType>>(pred);
  return static_cast<LLVMPredType>(rawValue);
}
416 
417 LogicalResult
418 CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
419  ConversionPatternRewriter &rewriter) const {
420  Type operandType = adaptor.getLhs().getType();
421  Type resultType = op.getResult().getType();
422 
423  // Handle the scalar and 1D vector cases.
424  if (!isa<LLVM::LLVMArrayType>(operandType)) {
425  rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
426  op, typeConverter->convertType(resultType),
427  convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
428  adaptor.getLhs(), adaptor.getRhs());
429  return success();
430  }
431 
432  if (!isa<VectorType>(resultType))
433  return rewriter.notifyMatchFailure(op, "expected vector result type");
434 
436  op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
437  [&](Type llvm1DVectorTy, ValueRange operands) {
438  OpAdaptor adaptor(operands);
439  return rewriter.create<LLVM::ICmpOp>(
440  op.getLoc(), llvm1DVectorTy,
441  convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
442  adaptor.getLhs(), adaptor.getRhs());
443  },
444  rewriter);
445 }
446 
447 //===----------------------------------------------------------------------===//
448 // CmpFOpLowering
449 //===----------------------------------------------------------------------===//
450 
451 LogicalResult
452 CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
453  ConversionPatternRewriter &rewriter) const {
454  Type operandType = adaptor.getLhs().getType();
455  Type resultType = op.getResult().getType();
456  LLVM::FastmathFlags fmf =
457  arith::convertArithFastMathFlagsToLLVM(op.getFastmath());
458 
459  // Handle the scalar and 1D vector cases.
460  if (!isa<LLVM::LLVMArrayType>(operandType)) {
461  rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
462  op, typeConverter->convertType(resultType),
463  convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
464  adaptor.getLhs(), adaptor.getRhs(), fmf);
465  return success();
466  }
467 
468  if (!isa<VectorType>(resultType))
469  return rewriter.notifyMatchFailure(op, "expected vector result type");
470 
472  op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
473  [&](Type llvm1DVectorTy, ValueRange operands) {
474  OpAdaptor adaptor(operands);
475  return rewriter.create<LLVM::FCmpOp>(
476  op.getLoc(), llvm1DVectorTy,
477  convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
478  adaptor.getLhs(), adaptor.getRhs(), fmf);
479  },
480  rewriter);
481 }
482 
483 //===----------------------------------------------------------------------===//
484 // Pass Definition
485 //===----------------------------------------------------------------------===//
486 
487 namespace {
488 struct ArithToLLVMConversionPass
489  : public impl::ArithToLLVMConversionPassBase<ArithToLLVMConversionPass> {
490  using Base::Base;
491 
492  void runOnOperation() override {
495 
497  if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
498  options.overrideIndexBitwidth(indexBitwidth);
499 
500  LLVMTypeConverter converter(&getContext(), options);
503 
504  if (failed(applyPartialConversion(getOperation(), target,
505  std::move(patterns))))
506  signalPassFailure();
507  }
508 };
509 } // namespace
510 
511 //===----------------------------------------------------------------------===//
512 // ConvertToLLVMPatternInterface implementation
513 //===----------------------------------------------------------------------===//
514 
515 namespace {
516 /// Implement the interface to convert Arith to LLVM.
517 struct ArithToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
519  void loadDependentDialects(MLIRContext *context) const final {
520  context->loadDialect<LLVM::LLVMDialect>();
521  }
522 
523  /// Hook for derived dialect interface to provide conversion patterns
524  /// and mark dialect legal for the conversion target.
526  ConversionTarget &target, LLVMTypeConverter &typeConverter,
527  RewritePatternSet &patterns) const final {
530  }
531 };
532 } // namespace
533 
535  DialectRegistry &registry) {
536  registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) {
537  dialect->addInterfaces<ArithToLLVMDialectInterface>();
538  });
539 }
540 
541 //===----------------------------------------------------------------------===//
542 // Pattern Population
543 //===----------------------------------------------------------------------===//
544 
546  const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
547 
548  // Set a higher pattern benefit for IdentityBitcastLowering so it will run
549  // before BitcastOpLowering.
550  patterns.add<IdentityBitcastLowering>(converter, patterns.getContext(),
551  /*patternBenefit*/ 10);
552 
553  // clang-format off
554  patterns.add<
555  AddFOpLowering,
556  AddIOpLowering,
557  AndIOpLowering,
558  AddUIExtendedOpLowering,
559  BitcastOpLowering,
560  ConstantOpLowering,
561  CmpFOpLowering,
562  CmpIOpLowering,
563  DivFOpLowering,
564  DivSIOpLowering,
565  DivUIOpLowering,
566  ExtFOpLowering,
567  ExtSIOpLowering,
568  ExtUIOpLowering,
569  FPToSIOpLowering,
570  FPToUIOpLowering,
571  IndexCastOpSILowering,
572  IndexCastOpUILowering,
573  MaximumFOpLowering,
574  MaxNumFOpLowering,
575  MaxSIOpLowering,
576  MaxUIOpLowering,
577  MinimumFOpLowering,
578  MinNumFOpLowering,
579  MinSIOpLowering,
580  MinUIOpLowering,
581  MulFOpLowering,
582  MulIOpLowering,
583  MulSIExtendedOpLowering,
584  MulUIExtendedOpLowering,
585  NegFOpLowering,
586  OrIOpLowering,
587  RemFOpLowering,
588  RemSIOpLowering,
589  RemUIOpLowering,
590  SelectOpLowering,
591  ShLIOpLowering,
592  ShRSIOpLowering,
593  ShRUIOpLowering,
594  SIToFPOpLowering,
595  SubFOpLowering,
596  SubIOpLowering,
597  TruncFOpLowering,
598  ConstrainedTruncFOpLowering,
599  TruncIOpLowering,
600  UIToFPOpLowering,
601  XOrIOpLowering
602  >(converter);
603  // clang-format on
604 }
static LLVMPredType convertCmpPredicate(PredType pred)
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static Value handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, ValueRange operands, int64_t vectorWidth, llvm::function_ref< Value(ValueRange)> compute)
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:223
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:66
MLIRContext * getContext() const
Definition: Builders.h:55
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:199
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.h:205
Base class for dialect interfaces providing translation to LLVM IR.
virtual void populateConvertToLLVMConversionPatterns(ConversionTarget &target, LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const =0
Hook for derived dialect interface to provide conversion patterns and mark dialect legal for the conv...
virtual void loadDependentDialects(MLIRContext *context) const
Hook for derived dialect interface to load the dialects they target.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:452
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:681
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
Definition: VectorPattern.h:90
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, IntegerOverflowFlags overflowFlags=IntegerOverflowFlags::none)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
Definition: Pattern.cpp:319
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:796
void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ceil/floor division ops.
Definition: ExpandOps.cpp:797
void populateArithToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
void registerConvertArithToLLVMInterface(DialectRegistry &registry)
LLVM::FastmathFlags convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF)
Maps arithmetic fastmath enum values to LLVM enum values.
Include the generated interface declarations.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.