MLIR  20.0.0git
ArithToLLVM.cpp
Go to the documentation of this file.
1 //===- ArithToLLVM.cpp - Arithmetic to LLVM dialect conversion -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
18 #include "mlir/IR/TypeUtilities.h"
19 #include "mlir/Pass/Pass.h"
20 #include <type_traits>
21 
22 namespace mlir {
23 #define GEN_PASS_DEF_ARITHTOLLVMCONVERSIONPASS
24 #include "mlir/Conversion/Passes.h.inc"
25 } // namespace mlir
26 
27 using namespace mlir;
28 
29 namespace {
30 
31 /// Operations whose conversion will depend on whether they are passed a
32 /// rounding mode attribute or not.
33 ///
34 /// `SourceOp` is the source operation; `TargetOp`, the operation it will lower
35 /// to; `AttrConvert` is the attribute conversion to convert the rounding mode
36 /// attribute.
37 template <typename SourceOp, typename TargetOp, bool Constrained,
38  template <typename, typename> typename AttrConvert =
40 struct ConstrainedVectorConvertToLLVMPattern
41  : public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert> {
42  using VectorConvertToLLVMPattern<SourceOp, TargetOp,
43  AttrConvert>::VectorConvertToLLVMPattern;
44 
45  LogicalResult
46  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
47  ConversionPatternRewriter &rewriter) const override {
48  if (Constrained != static_cast<bool>(op.getRoundingModeAttr()))
49  return failure();
50  return VectorConvertToLLVMPattern<SourceOp, TargetOp,
51  AttrConvert>::matchAndRewrite(op, adaptor,
52  rewriter);
53  }
54 };
55 
56 //===----------------------------------------------------------------------===//
57 // Straightforward Op Lowerings
58 //===----------------------------------------------------------------------===//
59 
60 using AddFOpLowering =
61  VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
62  arith::AttrConvertFastMathToLLVM>;
63 using AddIOpLowering =
64  VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp,
65  arith::AttrConvertOverflowToLLVM>;
67 using BitcastOpLowering =
69 using DivFOpLowering =
70  VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
71  arith::AttrConvertFastMathToLLVM>;
72 using DivSIOpLowering =
74 using DivUIOpLowering =
77 using ExtSIOpLowering =
79 using ExtUIOpLowering =
81 using FPToSIOpLowering =
83 using FPToUIOpLowering =
85 using MaximumFOpLowering =
86  VectorConvertToLLVMPattern<arith::MaximumFOp, LLVM::MaximumOp,
87  arith::AttrConvertFastMathToLLVM>;
88 using MaxNumFOpLowering =
89  VectorConvertToLLVMPattern<arith::MaxNumFOp, LLVM::MaxNumOp,
90  arith::AttrConvertFastMathToLLVM>;
91 using MaxSIOpLowering =
93 using MaxUIOpLowering =
95 using MinimumFOpLowering =
96  VectorConvertToLLVMPattern<arith::MinimumFOp, LLVM::MinimumOp,
97  arith::AttrConvertFastMathToLLVM>;
98 using MinNumFOpLowering =
99  VectorConvertToLLVMPattern<arith::MinNumFOp, LLVM::MinNumOp,
100  arith::AttrConvertFastMathToLLVM>;
101 using MinSIOpLowering =
103 using MinUIOpLowering =
105 using MulFOpLowering =
106  VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
107  arith::AttrConvertFastMathToLLVM>;
108 using MulIOpLowering =
109  VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp,
110  arith::AttrConvertOverflowToLLVM>;
111 using NegFOpLowering =
112  VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
113  arith::AttrConvertFastMathToLLVM>;
115 using RemFOpLowering =
116  VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
117  arith::AttrConvertFastMathToLLVM>;
118 using RemSIOpLowering =
120 using RemUIOpLowering =
122 using SelectOpLowering =
124 using ShLIOpLowering =
125  VectorConvertToLLVMPattern<arith::ShLIOp, LLVM::ShlOp,
126  arith::AttrConvertOverflowToLLVM>;
127 using ShRSIOpLowering =
129 using ShRUIOpLowering =
131 using SIToFPOpLowering =
133 using SubFOpLowering =
134  VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
135  arith::AttrConvertFastMathToLLVM>;
136 using SubIOpLowering =
137  VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
138  arith::AttrConvertOverflowToLLVM>;
139 using TruncFOpLowering =
140  ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp,
141  false>;
142 using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern<
143  arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, true,
144  arith::AttrConverterConstrainedFPToLLVM>;
145 using TruncIOpLowering =
147 using UIToFPOpLowering =
150 
151 //===----------------------------------------------------------------------===//
152 // Op Lowering Patterns
153 //===----------------------------------------------------------------------===//
154 
155 /// Directly lower to LLVM op.
156 struct ConstantOpLowering : public ConvertOpToLLVMPattern<arith::ConstantOp> {
158 
159  LogicalResult
160  matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
161  ConversionPatternRewriter &rewriter) const override;
162 };
163 
164 /// The lowering of index_cast becomes an integer conversion since index
165 /// becomes an integer. If the bit width of the source and target integer
166 /// types is the same, just erase the cast. If the target type is wider,
167 /// extend the value with `ExtCastTy` (sign- or zero-extension), otherwise truncate it.
168 template <typename OpTy, typename ExtCastTy>
169 struct IndexCastOpLowering : public ConvertOpToLLVMPattern<OpTy> {
171 
172  LogicalResult
173  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
174  ConversionPatternRewriter &rewriter) const override;
175 };
176 
177 using IndexCastOpSILowering =
178  IndexCastOpLowering<arith::IndexCastOp, LLVM::SExtOp>;
179 using IndexCastOpUILowering =
180  IndexCastOpLowering<arith::IndexCastUIOp, LLVM::ZExtOp>;
181 
182 struct AddUIExtendedOpLowering
183  : public ConvertOpToLLVMPattern<arith::AddUIExtendedOp> {
185 
186  LogicalResult
187  matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
188  ConversionPatternRewriter &rewriter) const override;
189 };
190 
191 template <typename ArithMulOp, bool IsSigned>
192 struct MulIExtendedOpLowering : public ConvertOpToLLVMPattern<ArithMulOp> {
194 
195  LogicalResult
196  matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
197  ConversionPatternRewriter &rewriter) const override;
198 };
199 
200 using MulSIExtendedOpLowering =
201  MulIExtendedOpLowering<arith::MulSIExtendedOp, true>;
202 using MulUIExtendedOpLowering =
203  MulIExtendedOpLowering<arith::MulUIExtendedOp, false>;
204 
205 struct CmpIOpLowering : public ConvertOpToLLVMPattern<arith::CmpIOp> {
207 
208  LogicalResult
209  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
210  ConversionPatternRewriter &rewriter) const override;
211 };
212 
213 struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> {
215 
216  LogicalResult
217  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
218  ConversionPatternRewriter &rewriter) const override;
219 };
220 
221 } // namespace
222 
223 //===----------------------------------------------------------------------===//
224 // ConstantOpLowering
225 //===----------------------------------------------------------------------===//
226 
227 LogicalResult
228 ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
229  ConversionPatternRewriter &rewriter) const {
230  return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(),
231  adaptor.getOperands(), op->getAttrs(),
232  *getTypeConverter(), rewriter);
233 }
234 
235 //===----------------------------------------------------------------------===//
236 // IndexCastOpLowering
237 //===----------------------------------------------------------------------===//
238 
239 template <typename OpTy, typename ExtCastTy>
240 LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
241  OpTy op, typename OpTy::Adaptor adaptor,
242  ConversionPatternRewriter &rewriter) const {
243  Type resultType = op.getResult().getType();
244  Type targetElementType =
245  this->typeConverter->convertType(getElementTypeOrSelf(resultType));
246  Type sourceElementType =
247  this->typeConverter->convertType(getElementTypeOrSelf(op.getIn()));
248  unsigned targetBits = targetElementType.getIntOrFloatBitWidth();
249  unsigned sourceBits = sourceElementType.getIntOrFloatBitWidth();
250 
251  if (targetBits == sourceBits) {
252  rewriter.replaceOp(op, adaptor.getIn());
253  return success();
254  }
255 
256  // Handle the scalar and 1D vector cases.
257  Type operandType = adaptor.getIn().getType();
258  if (!isa<LLVM::LLVMArrayType>(operandType)) {
259  Type targetType = this->typeConverter->convertType(resultType);
260  if (targetBits < sourceBits)
261  rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType,
262  adaptor.getIn());
263  else
264  rewriter.replaceOpWithNewOp<ExtCastTy>(op, targetType, adaptor.getIn());
265  return success();
266  }
267 
268  if (!isa<VectorType>(resultType))
269  return rewriter.notifyMatchFailure(op, "expected vector result type");
270 
272  op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
273  [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
274  typename OpTy::Adaptor adaptor(operands);
275  if (targetBits < sourceBits) {
276  return rewriter.create<LLVM::TruncOp>(op.getLoc(), llvm1DVectorTy,
277  adaptor.getIn());
278  }
279  return rewriter.create<ExtCastTy>(op.getLoc(), llvm1DVectorTy,
280  adaptor.getIn());
281  },
282  rewriter);
283 }
284 
285 //===----------------------------------------------------------------------===//
286 // AddUIExtendedOpLowering
287 //===----------------------------------------------------------------------===//
288 
289 LogicalResult AddUIExtendedOpLowering::matchAndRewrite(
290  arith::AddUIExtendedOp op, OpAdaptor adaptor,
291  ConversionPatternRewriter &rewriter) const {
292  Type operandType = adaptor.getLhs().getType();
293  Type sumResultType = op.getSum().getType();
294  Type overflowResultType = op.getOverflow().getType();
295 
296  if (!LLVM::isCompatibleType(operandType))
297  return failure();
298 
299  MLIRContext *ctx = rewriter.getContext();
300  Location loc = op.getLoc();
301 
302  // Handle the scalar and 1D vector cases.
303  if (!isa<LLVM::LLVMArrayType>(operandType)) {
304  Type newOverflowType = typeConverter->convertType(overflowResultType);
305  Type structType =
306  LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newOverflowType});
307  Value addOverflow = rewriter.create<LLVM::UAddWithOverflowOp>(
308  loc, structType, adaptor.getLhs(), adaptor.getRhs());
309  Value sumExtracted =
310  rewriter.create<LLVM::ExtractValueOp>(loc, addOverflow, 0);
311  Value overflowExtracted =
312  rewriter.create<LLVM::ExtractValueOp>(loc, addOverflow, 1);
313  rewriter.replaceOp(op, {sumExtracted, overflowExtracted});
314  return success();
315  }
316 
317  if (!isa<VectorType>(sumResultType))
318  return rewriter.notifyMatchFailure(loc, "expected vector result types");
319 
320  return rewriter.notifyMatchFailure(loc,
321  "ND vector types are not supported yet");
322 }
323 
324 //===----------------------------------------------------------------------===//
325 // MulIExtendedOpLowering
326 //===----------------------------------------------------------------------===//
327 
328 template <typename ArithMulOp, bool IsSigned>
329 LogicalResult MulIExtendedOpLowering<ArithMulOp, IsSigned>::matchAndRewrite(
330  ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
331  ConversionPatternRewriter &rewriter) const {
332  Type resultType = adaptor.getLhs().getType();
333 
334  if (!LLVM::isCompatibleType(resultType))
335  return failure();
336 
337  Location loc = op.getLoc();
338 
339  // Handle the scalar and 1D vector cases. Because LLVM does not have a
340  // matching extended multiplication intrinsic, perform regular multiplication
341  // on operands zero-extended to i(2*N) bits, and truncate the results back to
342  // iN types.
343  if (!isa<LLVM::LLVMArrayType>(resultType)) {
344  // Shift amount necessary to extract the high bits from widened result.
345  TypedAttr shiftValAttr;
346 
347  if (auto intTy = dyn_cast<IntegerType>(resultType)) {
348  unsigned resultBitwidth = intTy.getWidth();
349  auto attrTy = rewriter.getIntegerType(resultBitwidth * 2);
350  shiftValAttr = rewriter.getIntegerAttr(attrTy, resultBitwidth);
351  } else {
352  auto vecTy = cast<VectorType>(resultType);
353  unsigned resultBitwidth = vecTy.getElementTypeBitWidth();
354  auto attrTy = VectorType::get(
355  vecTy.getShape(), rewriter.getIntegerType(resultBitwidth * 2));
356  shiftValAttr = SplatElementsAttr::get(
357  attrTy, APInt(resultBitwidth * 2, resultBitwidth));
358  }
359  Type wideType = shiftValAttr.getType();
360  assert(LLVM::isCompatibleType(wideType) &&
361  "LLVM dialect should support all signless integer types");
362 
363  using LLVMExtOp = std::conditional_t<IsSigned, LLVM::SExtOp, LLVM::ZExtOp>;
364  Value lhsExt = rewriter.create<LLVMExtOp>(loc, wideType, adaptor.getLhs());
365  Value rhsExt = rewriter.create<LLVMExtOp>(loc, wideType, adaptor.getRhs());
366  Value mulExt = rewriter.create<LLVM::MulOp>(loc, wideType, lhsExt, rhsExt);
367 
368  // Split the 2*N-bit wide result into two N-bit values.
369  Value low = rewriter.create<LLVM::TruncOp>(loc, resultType, mulExt);
370  Value shiftVal = rewriter.create<LLVM::ConstantOp>(loc, shiftValAttr);
371  Value highExt = rewriter.create<LLVM::LShrOp>(loc, mulExt, shiftVal);
372  Value high = rewriter.create<LLVM::TruncOp>(loc, resultType, highExt);
373 
374  rewriter.replaceOp(op, {low, high});
375  return success();
376  }
377 
378  if (!isa<VectorType>(resultType))
379  return rewriter.notifyMatchFailure(op, "expected vector result type");
380 
381  return rewriter.notifyMatchFailure(op,
382  "ND vector types are not supported yet");
383 }
384 
385 //===----------------------------------------------------------------------===//
386 // CmpIOpLowering
387 //===----------------------------------------------------------------------===//
388 
/// Maps an `arith` comparison predicate onto its LLVM dialect counterpart.
///
/// The arith and LLVM predicate enumerations are defined with identical
/// numeric values, so the mapping is a value-preserving cast rather than a
/// lookup table.
template <typename LLVMPredType, typename PredType>
static LLVMPredType convertCmpPredicate(PredType pred) {
  auto llvmPred = static_cast<LLVMPredType>(pred);
  return llvmPred;
}
395 
396 LogicalResult
397 CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
398  ConversionPatternRewriter &rewriter) const {
399  Type operandType = adaptor.getLhs().getType();
400  Type resultType = op.getResult().getType();
401 
402  // Handle the scalar and 1D vector cases.
403  if (!isa<LLVM::LLVMArrayType>(operandType)) {
404  rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
405  op, typeConverter->convertType(resultType),
406  convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
407  adaptor.getLhs(), adaptor.getRhs());
408  return success();
409  }
410 
411  if (!isa<VectorType>(resultType))
412  return rewriter.notifyMatchFailure(op, "expected vector result type");
413 
415  op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
416  [&](Type llvm1DVectorTy, ValueRange operands) {
417  OpAdaptor adaptor(operands);
418  return rewriter.create<LLVM::ICmpOp>(
419  op.getLoc(), llvm1DVectorTy,
420  convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
421  adaptor.getLhs(), adaptor.getRhs());
422  },
423  rewriter);
424 }
425 
426 //===----------------------------------------------------------------------===//
427 // CmpFOpLowering
428 //===----------------------------------------------------------------------===//
429 
430 LogicalResult
431 CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
432  ConversionPatternRewriter &rewriter) const {
433  Type operandType = adaptor.getLhs().getType();
434  Type resultType = op.getResult().getType();
435  LLVM::FastmathFlags fmf =
436  arith::convertArithFastMathFlagsToLLVM(op.getFastmath());
437 
438  // Handle the scalar and 1D vector cases.
439  if (!isa<LLVM::LLVMArrayType>(operandType)) {
440  rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
441  op, typeConverter->convertType(resultType),
442  convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
443  adaptor.getLhs(), adaptor.getRhs(), fmf);
444  return success();
445  }
446 
447  if (!isa<VectorType>(resultType))
448  return rewriter.notifyMatchFailure(op, "expected vector result type");
449 
451  op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
452  [&](Type llvm1DVectorTy, ValueRange operands) {
453  OpAdaptor adaptor(operands);
454  return rewriter.create<LLVM::FCmpOp>(
455  op.getLoc(), llvm1DVectorTy,
456  convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
457  adaptor.getLhs(), adaptor.getRhs(), fmf);
458  },
459  rewriter);
460 }
461 
462 //===----------------------------------------------------------------------===//
463 // Pass Definition
464 //===----------------------------------------------------------------------===//
465 
466 namespace {
467 struct ArithToLLVMConversionPass
468  : public impl::ArithToLLVMConversionPassBase<ArithToLLVMConversionPass> {
469  using Base::Base;
470 
471  void runOnOperation() override {
473  RewritePatternSet patterns(&getContext());
474 
476  if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
477  options.overrideIndexBitwidth(indexBitwidth);
478 
479  LLVMTypeConverter converter(&getContext(), options);
481 
482  if (failed(applyPartialConversion(getOperation(), target,
483  std::move(patterns))))
484  signalPassFailure();
485  }
486 };
487 } // namespace
488 
489 //===----------------------------------------------------------------------===//
490 // ConvertToLLVMPatternInterface implementation
491 //===----------------------------------------------------------------------===//
492 
493 namespace {
494 /// Implement the interface to convert Arith to LLVM.
495 struct ArithToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
497  void loadDependentDialects(MLIRContext *context) const final {
498  context->loadDialect<LLVM::LLVMDialect>();
499  }
500 
501  /// Hook for derived dialect interface to provide conversion patterns
502  /// and mark dialect legal for the conversion target.
504  ConversionTarget &target, LLVMTypeConverter &typeConverter,
505  RewritePatternSet &patterns) const final {
506  arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
507  }
508 };
509 } // namespace
510 
512  DialectRegistry &registry) {
513  registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) {
514  dialect->addInterfaces<ArithToLLVMDialectInterface>();
515  });
516 }
517 
518 //===----------------------------------------------------------------------===//
519 // Pattern Population
520 //===----------------------------------------------------------------------===//
521 
523  LLVMTypeConverter &converter, RewritePatternSet &patterns) {
524  // clang-format off
525  patterns.add<
526  AddFOpLowering,
527  AddIOpLowering,
528  AndIOpLowering,
529  AddUIExtendedOpLowering,
530  BitcastOpLowering,
531  ConstantOpLowering,
532  CmpFOpLowering,
533  CmpIOpLowering,
534  DivFOpLowering,
535  DivSIOpLowering,
536  DivUIOpLowering,
537  ExtFOpLowering,
538  ExtSIOpLowering,
539  ExtUIOpLowering,
540  FPToSIOpLowering,
541  FPToUIOpLowering,
542  IndexCastOpSILowering,
543  IndexCastOpUILowering,
544  MaximumFOpLowering,
545  MaxNumFOpLowering,
546  MaxSIOpLowering,
547  MaxUIOpLowering,
548  MinimumFOpLowering,
549  MinNumFOpLowering,
550  MinSIOpLowering,
551  MinUIOpLowering,
552  MulFOpLowering,
553  MulIOpLowering,
554  MulSIExtendedOpLowering,
555  MulUIExtendedOpLowering,
556  NegFOpLowering,
557  OrIOpLowering,
558  RemFOpLowering,
559  RemSIOpLowering,
560  RemUIOpLowering,
561  SelectOpLowering,
562  ShLIOpLowering,
563  ShRSIOpLowering,
564  ShRUIOpLowering,
565  SIToFPOpLowering,
566  SubFOpLowering,
567  SubIOpLowering,
568  TruncFOpLowering,
569  ConstrainedTruncFOpLowering,
570  TruncIOpLowering,
571  UIToFPOpLowering,
572  XOrIOpLowering
573  >(converter);
574  // clang-format on
575 }
static LLVMPredType convertCmpPredicate(PredType pred)
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static Value handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, ValueRange operands, int64_t vectorWidth, llvm::function_ref< Value(ValueRange)> compute)
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:250
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:99
MLIRContext * getContext() const
Definition: Builders.h:55
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
Definition: Pattern.h:143
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.h:147
Base class for dialect interfaces providing translation to LLVM IR.
virtual void populateConvertToLLVMConversionPatterns(ConversionTarget &target, LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const =0
Hook for derived dialect interface to provide conversion patterns and mark dialect legal for the conversion target.
virtual void loadDependentDialects(MLIRContext *context) const
Hook for derived dialect interface to load the dialects they target.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:63
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:476
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
Definition: PatternMatch.h:718
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (replace its uses).
Definition: PatternMatch.h:536
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:128
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
Definition: VectorPattern.h:90
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, IntegerOverflowFlags overflowFlags=IntegerOverflowFlags::none)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
Definition: Pattern.cpp:340
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:856
void registerConvertArithToLLVMInterface(DialectRegistry &registry)
LLVM::FastmathFlags convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF)
Maps arithmetic fastmath enum values to LLVM enum values.
void populateArithToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Include the generated interface declarations.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth from the LLVM data layout.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.