MLIR  20.0.0git
ArithToLLVM.cpp
Go to the documentation of this file.
1 //===- ArithToLLVM.cpp - Arithmetic to LLVM dialect conversion -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Pass/Pass.h"
21 #include <type_traits>
22 
23 namespace mlir {
24 #define GEN_PASS_DEF_ARITHTOLLVMCONVERSIONPASS
25 #include "mlir/Conversion/Passes.h.inc"
26 } // namespace mlir
27 
28 using namespace mlir;
29 
30 namespace {
31 
32 /// Operations whose conversion will depend on whether they are passed a
33 /// rounding mode attribute or not.
34 ///
35 /// `SourceOp` is the source operation; `TargetOp`, the operation it will lower
36 /// to; `AttrConvert` is the attribute conversion to convert the rounding mode
37 /// attribute.
///
/// `Constrained` selects which form of `SourceOp` this pattern accepts: when
/// true it only matches ops that carry a rounding mode attribute, when false it
/// only matches ops without one. Other ops are rejected with `failure()` so
/// that an instantiation with the opposite flag can take them (this file pairs
/// `TruncFOpLowering` with `ConstrainedTruncFOpLowering` this way).
38 template <typename SourceOp, typename TargetOp, bool Constrained,
39  template <typename, typename> typename AttrConvert =
41 struct ConstrainedVectorConvertToLLVMPattern
42  : public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert> {
43  using VectorConvertToLLVMPattern<SourceOp, TargetOp,
44  AttrConvert>::VectorConvertToLLVMPattern;
45 
46  LogicalResult
47  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
48  ConversionPatternRewriter &rewriter) const override {
  // Reject ops whose rounding-mode presence does not match `Constrained`.
49  if (Constrained != static_cast<bool>(op.getRoundingModeAttr()))
50  return failure();
  // Otherwise defer to the generic single-result lowering of the base class.
51  return VectorConvertToLLVMPattern<SourceOp, TargetOp,
52  AttrConvert>::matchAndRewrite(op, adaptor,
53  rewriter);
54  }
55 };
56 
57 //===----------------------------------------------------------------------===//
58 // Straightforward Op Lowerings
59 //===----------------------------------------------------------------------===//
60 
// Each alias below maps one arith op onto one LLVM dialect op by instantiating
// VectorConvertToLLVMPattern (or, for truncf, the rounding-mode-aware
// ConstrainedVectorConvertToLLVMPattern). When a third template argument is
// given it names the attribute converter used: fastmath-flag conversion for the
// floating-point ops, overflow-flag conversion for addi/muli/shli/subi, and
// constrained-FP conversion for the constrained truncf intrinsic.
61 using AddFOpLowering =
62  VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
63  arith::AttrConvertFastMathToLLVM>;
64 using AddIOpLowering =
65  VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp,
66  arith::AttrConvertOverflowToLLVM>;
68 using BitcastOpLowering =
70 using DivFOpLowering =
71  VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
72  arith::AttrConvertFastMathToLLVM>;
73 using DivSIOpLowering =
75 using DivUIOpLowering =
78 using ExtSIOpLowering =
80 using ExtUIOpLowering =
82 using FPToSIOpLowering =
84 using FPToUIOpLowering =
86 using MaximumFOpLowering =
87  VectorConvertToLLVMPattern<arith::MaximumFOp, LLVM::MaximumOp,
88  arith::AttrConvertFastMathToLLVM>;
89 using MaxNumFOpLowering =
90  VectorConvertToLLVMPattern<arith::MaxNumFOp, LLVM::MaxNumOp,
91  arith::AttrConvertFastMathToLLVM>;
92 using MaxSIOpLowering =
94 using MaxUIOpLowering =
96 using MinimumFOpLowering =
97  VectorConvertToLLVMPattern<arith::MinimumFOp, LLVM::MinimumOp,
98  arith::AttrConvertFastMathToLLVM>;
99 using MinNumFOpLowering =
100  VectorConvertToLLVMPattern<arith::MinNumFOp, LLVM::MinNumOp,
101  arith::AttrConvertFastMathToLLVM>;
102 using MinSIOpLowering =
104 using MinUIOpLowering =
106 using MulFOpLowering =
107  VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
108  arith::AttrConvertFastMathToLLVM>;
109 using MulIOpLowering =
110  VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp,
111  arith::AttrConvertOverflowToLLVM>;
112 using NegFOpLowering =
113  VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
114  arith::AttrConvertFastMathToLLVM>;
116 using RemFOpLowering =
117  VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
118  arith::AttrConvertFastMathToLLVM>;
119 using RemSIOpLowering =
121 using RemUIOpLowering =
123 using SelectOpLowering =
125 using ShLIOpLowering =
126  VectorConvertToLLVMPattern<arith::ShLIOp, LLVM::ShlOp,
127  arith::AttrConvertOverflowToLLVM>;
128 using ShRSIOpLowering =
130 using ShRUIOpLowering =
132 using SIToFPOpLowering =
134 using SubFOpLowering =
135  VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
136  arith::AttrConvertFastMathToLLVM>;
137 using SubIOpLowering =
138  VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
139  arith::AttrConvertOverflowToLLVM>;
// truncf has two lowerings selected by the `Constrained` flag: without a
// rounding mode attribute it becomes a plain FPTruncOp, with one it becomes the
// constrained fptrunc intrinsic.
140 using TruncFOpLowering =
141  ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp,
142  false>;
143 using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern<
144  arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, true,
145  arith::AttrConverterConstrainedFPToLLVM>;
146 using TruncIOpLowering =
148 using UIToFPOpLowering =
151 
152 //===----------------------------------------------------------------------===//
153 // Op Lowering Patterns
154 //===----------------------------------------------------------------------===//
155 
156 /// Directly lower to LLVM op.
/// `arith.constant` is rewritten one-to-one into `LLVM::ConstantOp`, keeping
/// the original op's operands and attributes.
157 struct ConstantOpLowering : public ConvertOpToLLVMPattern<arith::ConstantOp> {
159 
160  LogicalResult
161  matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
162  ConversionPatternRewriter &rewriter) const override;
163 };
164 
165 /// The lowering of index_cast becomes an integer conversion since index
166 /// becomes an integer. If the bit width of the source and target integer
167 /// types is the same, just erase the cast. If the target type is wider,
168 /// extend the value with `ExtCastTy`, otherwise truncate it.
/// `ExtCastTy` is `LLVM::SExtOp` for `arith.index_cast` (sign-extension) and
/// `LLVM::ZExtOp` for `arith.index_castui` (zero-extension); see the aliases
/// below.
169 template <typename OpTy, typename ExtCastTy>
170 struct IndexCastOpLowering : public ConvertOpToLLVMPattern<OpTy> {
172 
173  LogicalResult
174  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
175  ConversionPatternRewriter &rewriter) const override;
176 };
177 
178 using IndexCastOpSILowering =
179  IndexCastOpLowering<arith::IndexCastOp, LLVM::SExtOp>;
180 using IndexCastOpUILowering =
181  IndexCastOpLowering<arith::IndexCastUIOp, LLVM::ZExtOp>;
182 
/// Lowers `arith.addui_extended` to `LLVM::UAddWithOverflowOp`, splitting the
/// intrinsic's {sum, overflow} struct result back into the op's two results.
/// Only scalar and 1-D vector operands are handled.
183 struct AddUIExtendedOpLowering
184  : public ConvertOpToLLVMPattern<arith::AddUIExtendedOp> {
186 
187  LogicalResult
188  matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
189  ConversionPatternRewriter &rewriter) const override;
190 };
191 
/// Lowers `arith.mulsi_extended` / `arith.mului_extended` by widening both
/// operands to twice their bit width, multiplying, and splitting the product
/// into low and high halves. `IsSigned` selects sign-extension (true) or
/// zero-extension (false) of the operands before the widened multiply.
192 template <typename ArithMulOp, bool IsSigned>
193 struct MulIExtendedOpLowering : public ConvertOpToLLVMPattern<ArithMulOp> {
195 
196  LogicalResult
197  matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
198  ConversionPatternRewriter &rewriter) const override;
199 };
200 
201 using MulSIExtendedOpLowering =
202  MulIExtendedOpLowering<arith::MulSIExtendedOp, true>;
203 using MulUIExtendedOpLowering =
204  MulIExtendedOpLowering<arith::MulUIExtendedOp, false>;
205 
/// Lowers `arith.cmpi` to `LLVM::ICmpOp`, converting the predicate by value.
/// N-D vector operands are compared one 1-D vector at a time.
206 struct CmpIOpLowering : public ConvertOpToLLVMPattern<arith::CmpIOp> {
208 
209  LogicalResult
210  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
211  ConversionPatternRewriter &rewriter) const override;
212 };
213 
/// Lowers `arith.cmpf` to `LLVM::FCmpOp`, converting the predicate by value and
/// carrying over the op's fastmath flags. N-D vector operands are compared one
/// 1-D vector at a time.
214 struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> {
216 
217  LogicalResult
218  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
219  ConversionPatternRewriter &rewriter) const override;
220 };
221 
222 } // namespace
223 
224 //===----------------------------------------------------------------------===//
225 // ConstantOpLowering
226 //===----------------------------------------------------------------------===//
227 
228 LogicalResult
229 ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
230  ConversionPatternRewriter &rewriter) const {
231  return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(),
232  adaptor.getOperands(), op->getAttrs(),
233  *getTypeConverter(), rewriter);
234 }
235 
236 //===----------------------------------------------------------------------===//
237 // IndexCastOpLowering
238 //===----------------------------------------------------------------------===//
239 
// Shared implementation for index_cast / index_castui. It compares the bit
// widths of the converted source and result element types and then emits a
// no-op, a truncation, or an `ExtCastTy` extension.
240 template <typename OpTy, typename ExtCastTy>
241 LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
242  OpTy op, typename OpTy::Adaptor adaptor,
243  ConversionPatternRewriter &rewriter) const {
244  Type resultType = op.getResult().getType();
  // Widths come from the converted element types, so an `index` side is
  // measured at whatever integer width the type converter assigns to it.
245  Type targetElementType =
246  this->typeConverter->convertType(getElementTypeOrSelf(resultType));
247  Type sourceElementType =
248  this->typeConverter->convertType(getElementTypeOrSelf(op.getIn()));
249  unsigned targetBits = targetElementType.getIntOrFloatBitWidth();
250  unsigned sourceBits = sourceElementType.getIntOrFloatBitWidth();
251 
  // Equal widths: the cast is a no-op, so forward the converted operand.
252  if (targetBits == sourceBits) {
253  rewriter.replaceOp(op, adaptor.getIn());
254  return success();
255  }
256 
257  // Handle the scalar and 1D vector cases.
258  Type operandType = adaptor.getIn().getType();
259  if (!isa<LLVM::LLVMArrayType>(operandType)) {
260  Type targetType = this->typeConverter->convertType(resultType);
261  if (targetBits < sourceBits)
262  rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType,
263  adaptor.getIn());
264  else
265  rewriter.replaceOpWithNewOp<ExtCastTy>(op, targetType, adaptor.getIn());
266  return success();
267  }
268 
  // N-D vector case: the converted operand is an LLVM array of 1-D vectors, and
  // the callback below emits the same truncate/extend for each 1-D vector.
  // NOTE(review): the line naming the callee (original line 272) is missing
  // from this listing; it is presumably LLVM::detail::handleMultidimensionalVectors.
269  if (!isa<VectorType>(resultType))
270  return rewriter.notifyMatchFailure(op, "expected vector result type");
271 
273  op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
274  [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
      // Re-wrap this 1-D slice's operands so getIn() refers to the slice.
275  typename OpTy::Adaptor adaptor(operands);
276  if (targetBits < sourceBits) {
277  return rewriter.create<LLVM::TruncOp>(op.getLoc(), llvm1DVectorTy,
278  adaptor.getIn());
279  }
280  return rewriter.create<ExtCastTy>(op.getLoc(), llvm1DVectorTy,
281  adaptor.getIn());
282  },
283  rewriter);
284 }
285 
286 //===----------------------------------------------------------------------===//
287 // AddUIExtendedOpLowering
288 //===----------------------------------------------------------------------===//
289 
290 LogicalResult AddUIExtendedOpLowering::matchAndRewrite(
291  arith::AddUIExtendedOp op, OpAdaptor adaptor,
292  ConversionPatternRewriter &rewriter) const {
293  Type operandType = adaptor.getLhs().getType();
294  Type sumResultType = op.getSum().getType();
295  Type overflowResultType = op.getOverflow().getType();
296 
297  if (!LLVM::isCompatibleType(operandType))
298  return failure();
299 
300  MLIRContext *ctx = rewriter.getContext();
301  Location loc = op.getLoc();
302 
303  // Handle the scalar and 1D vector cases.
304  if (!isa<LLVM::LLVMArrayType>(operandType)) {
305  Type newOverflowType = typeConverter->convertType(overflowResultType);
306  Type structType =
307  LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newOverflowType});
308  Value addOverflow = rewriter.create<LLVM::UAddWithOverflowOp>(
309  loc, structType, adaptor.getLhs(), adaptor.getRhs());
310  Value sumExtracted =
311  rewriter.create<LLVM::ExtractValueOp>(loc, addOverflow, 0);
312  Value overflowExtracted =
313  rewriter.create<LLVM::ExtractValueOp>(loc, addOverflow, 1);
314  rewriter.replaceOp(op, {sumExtracted, overflowExtracted});
315  return success();
316  }
317 
318  if (!isa<VectorType>(sumResultType))
319  return rewriter.notifyMatchFailure(loc, "expected vector result types");
320 
321  return rewriter.notifyMatchFailure(loc,
322  "ND vector types are not supported yet");
323 }
324 
325 //===----------------------------------------------------------------------===//
326 // MulIExtendedOpLowering
327 //===----------------------------------------------------------------------===//
328 
329 template <typename ArithMulOp, bool IsSigned>
330 LogicalResult MulIExtendedOpLowering<ArithMulOp, IsSigned>::matchAndRewrite(
331  ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
332  ConversionPatternRewriter &rewriter) const {
333  Type resultType = adaptor.getLhs().getType();
334 
335  if (!LLVM::isCompatibleType(resultType))
336  return failure();
337 
338  Location loc = op.getLoc();
339 
340  // Handle the scalar and 1D vector cases. Because LLVM does not have a
341  // matching extended multiplication intrinsic, perform regular multiplication
342  // on operands zero-extended to i(2*N) bits, and truncate the results back to
343  // iN types.
344  if (!isa<LLVM::LLVMArrayType>(resultType)) {
345  // Shift amount necessary to extract the high bits from widened result.
346  TypedAttr shiftValAttr;
347 
348  if (auto intTy = dyn_cast<IntegerType>(resultType)) {
349  unsigned resultBitwidth = intTy.getWidth();
350  auto attrTy = rewriter.getIntegerType(resultBitwidth * 2);
351  shiftValAttr = rewriter.getIntegerAttr(attrTy, resultBitwidth);
352  } else {
353  auto vecTy = cast<VectorType>(resultType);
354  unsigned resultBitwidth = vecTy.getElementTypeBitWidth();
355  auto attrTy = VectorType::get(
356  vecTy.getShape(), rewriter.getIntegerType(resultBitwidth * 2));
357  shiftValAttr = SplatElementsAttr::get(
358  attrTy, APInt(resultBitwidth * 2, resultBitwidth));
359  }
360  Type wideType = shiftValAttr.getType();
361  assert(LLVM::isCompatibleType(wideType) &&
362  "LLVM dialect should support all signless integer types");
363 
364  using LLVMExtOp = std::conditional_t<IsSigned, LLVM::SExtOp, LLVM::ZExtOp>;
365  Value lhsExt = rewriter.create<LLVMExtOp>(loc, wideType, adaptor.getLhs());
366  Value rhsExt = rewriter.create<LLVMExtOp>(loc, wideType, adaptor.getRhs());
367  Value mulExt = rewriter.create<LLVM::MulOp>(loc, wideType, lhsExt, rhsExt);
368 
369  // Split the 2*N-bit wide result into two N-bit values.
370  Value low = rewriter.create<LLVM::TruncOp>(loc, resultType, mulExt);
371  Value shiftVal = rewriter.create<LLVM::ConstantOp>(loc, shiftValAttr);
372  Value highExt = rewriter.create<LLVM::LShrOp>(loc, mulExt, shiftVal);
373  Value high = rewriter.create<LLVM::TruncOp>(loc, resultType, highExt);
374 
375  rewriter.replaceOp(op, {low, high});
376  return success();
377  }
378 
379  if (!isa<VectorType>(resultType))
380  return rewriter.notifyMatchFailure(op, "expected vector result type");
381 
382  return rewriter.notifyMatchFailure(op,
383  "ND vector types are not supported yet");
384 }
385 
386 //===----------------------------------------------------------------------===//
387 // CmpIOpLowering
388 //===----------------------------------------------------------------------===//
389 
// Maps an arith comparison predicate onto the matching LLVM dialect predicate.
// The arith and LLVM predicate enums are defined with the same numerical
// values, so the mapping is a plain value-preserving cast rather than a lookup.
template <typename LLVMPredType, typename PredType>
static auto convertCmpPredicate(PredType pred) -> LLVMPredType {
  LLVMPredType converted = static_cast<LLVMPredType>(pred);
  return converted;
}
396 
// Lowers arith.cmpi to LLVM::ICmpOp. The predicate is converted by value and
// the result type goes through the type converter. Scalars and 1-D vectors
// become a single icmp; N-D vectors (LLVM arrays after conversion) get one icmp
// per 1-D vector.
397 LogicalResult
398 CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
399  ConversionPatternRewriter &rewriter) const {
400  Type operandType = adaptor.getLhs().getType();
401  Type resultType = op.getResult().getType();
402 
403  // Handle the scalar and 1D vector cases.
404  if (!isa<LLVM::LLVMArrayType>(operandType)) {
405  rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
406  op, typeConverter->convertType(resultType),
407  convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
408  adaptor.getLhs(), adaptor.getRhs());
409  return success();
410  }
411 
  // N-D vector case: emit one icmp per 1-D vector through the callback below.
  // NOTE(review): the line naming the callee (original line 415) is missing
  // from this listing; it is presumably LLVM::detail::handleMultidimensionalVectors.
412  if (!isa<VectorType>(resultType))
413  return rewriter.notifyMatchFailure(op, "expected vector result type");
414 
416  op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
417  [&](Type llvm1DVectorTy, ValueRange operands) {
418  OpAdaptor adaptor(operands);
419  return rewriter.create<LLVM::ICmpOp>(
420  op.getLoc(), llvm1DVectorTy,
421  convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
422  adaptor.getLhs(), adaptor.getRhs());
423  },
424  rewriter);
425 }
426 
427 //===----------------------------------------------------------------------===//
428 // CmpFOpLowering
429 //===----------------------------------------------------------------------===//
430 
// Lowers arith.cmpf to LLVM::FCmpOp. The predicate is converted by value and
// the op's fastmath flags are carried over. Scalars and 1-D vectors become a
// single fcmp; N-D vectors (LLVM arrays after conversion) get one fcmp per
// 1-D vector.
431 LogicalResult
432 CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
433  ConversionPatternRewriter &rewriter) const {
434  Type operandType = adaptor.getLhs().getType();
435  Type resultType = op.getResult().getType();
  // Computed once and reused by both the scalar/1-D path and every N-D slice.
436  LLVM::FastmathFlags fmf =
437  arith::convertArithFastMathFlagsToLLVM(op.getFastmath());
438 
439  // Handle the scalar and 1D vector cases.
440  if (!isa<LLVM::LLVMArrayType>(operandType)) {
441  rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
442  op, typeConverter->convertType(resultType),
443  convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
444  adaptor.getLhs(), adaptor.getRhs(), fmf);
445  return success();
446  }
447 
  // N-D vector case: emit one fcmp per 1-D vector through the callback below.
  // NOTE(review): the line naming the callee (original line 451) is missing
  // from this listing; it is presumably LLVM::detail::handleMultidimensionalVectors.
448  if (!isa<VectorType>(resultType))
449  return rewriter.notifyMatchFailure(op, "expected vector result type");
450 
452  op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
453  [&](Type llvm1DVectorTy, ValueRange operands) {
454  OpAdaptor adaptor(operands);
455  return rewriter.create<LLVM::FCmpOp>(
456  op.getLoc(), llvm1DVectorTy,
457  convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
458  adaptor.getLhs(), adaptor.getRhs(), fmf);
459  },
460  rewriter);
461 }
462 
463 //===----------------------------------------------------------------------===//
464 // Pass Definition
465 //===----------------------------------------------------------------------===//
466 
467 namespace {
/// Pass that converts Arith ops to the LLVM dialect with a partial conversion
/// of the current operation. It can optionally override the index bitwidth.
468 struct ArithToLLVMConversionPass
469  : public impl::ArithToLLVMConversionPassBase<ArithToLLVMConversionPass> {
470  using Base::Base;
471 
472  void runOnOperation() override {
475 
    // An explicit `indexBitwidth` option replaces the data-layout-derived one.
477  if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
478  options.overrideIndexBitwidth(indexBitwidth);
479 
480  LLVMTypeConverter converter(&getContext(), options);
483 
    // A failed conversion fails the whole pass.
484  if (failed(applyPartialConversion(getOperation(), target,
485  std::move(patterns))))
486  signalPassFailure();
487  }
488 };
489 } // namespace
490 
491 //===----------------------------------------------------------------------===//
492 // ConvertToLLVMPatternInterface implementation
493 //===----------------------------------------------------------------------===//
494 
495 namespace {
496 /// Implement the interface to convert Arith to LLVM.
497 struct ArithToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
499  void loadDependentDialects(MLIRContext *context) const final {
    // The lowering creates LLVM dialect ops, so that dialect must be loaded.
500  context->loadDialect<LLVM::LLVMDialect>();
501  }
502 
503  /// Hook for derived dialect interface to provide conversion patterns
504  /// and mark dialect legal for the conversion target.
506  ConversionTarget &target, LLVMTypeConverter &typeConverter,
507  RewritePatternSet &patterns) const final {
510  }
511 };
512 } // namespace
513 
515  DialectRegistry &registry) {
  // Register an extension that attaches ArithToLLVMDialectInterface to the
  // Arith dialect (presumably applied once that dialect is loaded in a context).
516  registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) {
517  dialect->addInterfaces<ArithToLLVMDialectInterface>();
518  });
519 }
520 
521 //===----------------------------------------------------------------------===//
522 // Pattern Population
523 //===----------------------------------------------------------------------===//
524 
526  const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  // Adds every Arith-to-LLVM lowering pattern to `patterns`, all constructed
  // with the same `converter`. truncf contributes two patterns (plain and
  // constrained), chosen at match time by the presence of a rounding mode.
527  // clang-format off
528  patterns.add<
529  AddFOpLowering,
530  AddIOpLowering,
531  AndIOpLowering,
532  AddUIExtendedOpLowering,
533  BitcastOpLowering,
534  ConstantOpLowering,
535  CmpFOpLowering,
536  CmpIOpLowering,
537  DivFOpLowering,
538  DivSIOpLowering,
539  DivUIOpLowering,
540  ExtFOpLowering,
541  ExtSIOpLowering,
542  ExtUIOpLowering,
543  FPToSIOpLowering,
544  FPToUIOpLowering,
545  IndexCastOpSILowering,
546  IndexCastOpUILowering,
547  MaximumFOpLowering,
548  MaxNumFOpLowering,
549  MaxSIOpLowering,
550  MaxUIOpLowering,
551  MinimumFOpLowering,
552  MinNumFOpLowering,
553  MinSIOpLowering,
554  MinUIOpLowering,
555  MulFOpLowering,
556  MulIOpLowering,
557  MulSIExtendedOpLowering,
558  MulUIExtendedOpLowering,
559  NegFOpLowering,
560  OrIOpLowering,
561  RemFOpLowering,
562  RemSIOpLowering,
563  RemUIOpLowering,
564  SelectOpLowering,
565  ShLIOpLowering,
566  ShRSIOpLowering,
567  ShRUIOpLowering,
568  SIToFPOpLowering,
569  SubFOpLowering,
570  SubIOpLowering,
571  TruncFOpLowering,
572  ConstrainedTruncFOpLowering,
573  TruncIOpLowering,
574  UIToFPOpLowering,
575  XOrIOpLowering
576  >(converter);
577  // clang-format on
578 }
static LLVMPredType convertCmpPredicate(PredType pred)
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static Value handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, ValueRange operands, int64_t vectorWidth, llvm::function_ref< Value(ValueRange)> compute)
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:268
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:111
MLIRContext * getContext() const
Definition: Builders.h:56
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:143
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.h:149
Base class for dialect interfaces providing translation to LLVM IR.
virtual void populateConvertToLLVMConversionPatterns(ConversionTarget &target, LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const =0
Hook for derived dialect interface to provide conversion patterns and mark dialect legal for the conv...
virtual void loadDependentDialects(MLIRContext *context) const
Hook for derived dialect interface to load the dialects they target.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:133
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
Definition: VectorPattern.h:90
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, IntegerOverflowFlags overflowFlags=IntegerOverflowFlags::none)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
Definition: Pattern.cpp:338
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:860
void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ceil/floor division ops.
Definition: ExpandOps.cpp:390
void populateArithToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
void registerConvertArithToLLVMInterface(DialectRegistry &registry)
LLVM::FastmathFlags convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF)
Maps arithmetic fastmath enum values to LLVM enum values.
Include the generated interface declarations.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.