MLIR  17.0.0git
ArithToLLVM.cpp
Go to the documentation of this file.
1 //===- ArithToLLVM.cpp - Arithmetic to LLVM dialect conversion -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
16 #include "mlir/IR/TypeUtilities.h"
17 #include "mlir/Pass/Pass.h"
18 #include <type_traits>
19 
20 namespace mlir {
21 #define GEN_PASS_DEF_ARITHTOLLVMCONVERSIONPASS
22 #include "mlir/Conversion/Passes.h.inc"
23 } // namespace mlir
24 
25 using namespace mlir;
26 
27 namespace {
28 
29 //===----------------------------------------------------------------------===//
30 // Straightforward Op Lowerings
31 //===----------------------------------------------------------------------===//
32 
// Each alias below maps one arith op onto a single LLVM dialect op through
// VectorConvertToLLVMPattern. The floating-point arithmetic ops additionally
// pass arith::AttrConvertFastMathToLLVM (presumably translating arith fastmath
// flags to LLVM ones -- confirm).
// NOTE(review): this listing has dropped several original source lines (for
// example lines 36-37 and 39), so many aliases below have no right-hand side,
// and some aliases registered later in populateArithToLLVMConversionPatterns
// (AddIOpLowering, AndIOpLowering, ExtFOpLowering, MulIOpLowering,
// OrIOpLowering, ShLIOpLowering, SubIOpLowering, XOrIOpLowering) are not
// visible at all. Restore them from the upstream file before compiling.
33 using AddFOpLowering =
34  VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
35  arith::AttrConvertFastMathToLLVM>;
38 using BitcastOpLowering =
40 using DivFOpLowering =
41  VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
42  arith::AttrConvertFastMathToLLVM>;
43 using DivSIOpLowering =
45 using DivUIOpLowering =
48 using ExtSIOpLowering =
50 using ExtUIOpLowering =
52 using FPToSIOpLowering =
54 using FPToUIOpLowering =
56 using MaxFOpLowering =
57  VectorConvertToLLVMPattern<arith::MaxFOp, LLVM::MaxNumOp,
58  arith::AttrConvertFastMathToLLVM>;
59 using MaxSIOpLowering =
61 using MaxUIOpLowering =
63 using MinFOpLowering =
64  VectorConvertToLLVMPattern<arith::MinFOp, LLVM::MinNumOp,
65  arith::AttrConvertFastMathToLLVM>;
66 using MinSIOpLowering =
68 using MinUIOpLowering =
70 using MulFOpLowering =
71  VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
72  arith::AttrConvertFastMathToLLVM>;
74 using NegFOpLowering =
75  VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
76  arith::AttrConvertFastMathToLLVM>;
78 using RemFOpLowering =
79  VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
80  arith::AttrConvertFastMathToLLVM>;
81 using RemSIOpLowering =
83 using RemUIOpLowering =
85 using SelectOpLowering =
88 using ShRSIOpLowering =
90 using ShRUIOpLowering =
92 using SIToFPOpLowering =
94 using SubFOpLowering =
95  VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
96  arith::AttrConvertFastMathToLLVM>;
98 using TruncFOpLowering =
100 using TruncIOpLowering =
102 using UIToFPOpLowering =
105 
106 //===----------------------------------------------------------------------===//
107 // Op Lowering Patterns
108 //===----------------------------------------------------------------------===//
109 
// NOTE(review): in every pattern struct below, the original lines between the
// `struct` header and `matchAndRewrite` (e.g. lines 112 and 114) are missing
// from this listing. They presumably hold the inherited
// ConvertOpToLLVMPattern constructor and the `LogicalResult` return type of
// matchAndRewrite -- confirm against upstream before compiling.
110 /// Directly lower to LLVM op.
111 struct ConstantOpLowering : public ConvertOpToLLVMPattern<arith::ConstantOp> {
113 
115  matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
116  ConversionPatternRewriter &rewriter) const override;
117 };
118 
119 /// The lowering of index_cast becomes an integer conversion since index
120 /// becomes an integer. If the bit width of the source and target integer
121 /// types is the same, just erase the cast. If the target type is wider,
122 /// sign-extend the value, otherwise truncate it.
/// NOTE: widening actually uses ExtCastTy: LLVM::SExtOp for
/// IndexCastOpSILowering and LLVM::ZExtOp for IndexCastOpUILowering.
123 template <typename OpTy, typename ExtCastTy>
124 struct IndexCastOpLowering : public ConvertOpToLLVMPattern<OpTy> {
126 
128  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
129  ConversionPatternRewriter &rewriter) const override;
130 };
131 
132 using IndexCastOpSILowering =
133  IndexCastOpLowering<arith::IndexCastOp, LLVM::SExtOp>;
134 using IndexCastOpUILowering =
135  IndexCastOpLowering<arith::IndexCastUIOp, LLVM::ZExtOp>;
136 
/// Lowers arith::AddUIExtendedOp to LLVM::UAddWithOverflowOp and unpacks its
/// struct result into the sum and overflow results. Operands converted to LLVM
/// arrays (N-D vectors) are rejected.
137 struct AddUIExtendedOpLowering
138  : public ConvertOpToLLVMPattern<arith::AddUIExtendedOp> {
140 
142  matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
143  ConversionPatternRewriter &rewriter) const override;
144 };
145 
/// Lowers the extended multiplications by extending both operands to twice
/// their width (SExt when IsSigned, ZExt otherwise), multiplying, and
/// splitting the product into low and high halves.
146 template <typename ArithMulOp, bool IsSigned>
147 struct MulIExtendedOpLowering : public ConvertOpToLLVMPattern<ArithMulOp> {
149 
151  matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
152  ConversionPatternRewriter &rewriter) const override;
153 };
154 
155 using MulSIExtendedOpLowering =
156  MulIExtendedOpLowering<arith::MulSIExtendedOp, true>;
157 using MulUIExtendedOpLowering =
158  MulIExtendedOpLowering<arith::MulUIExtendedOp, false>;
159 
/// Lowers arith::CmpIOp to LLVM::ICmpOp (the predicate is cast numerically);
/// operands converted to LLVM arrays are unrolled into 1-D vector compares.
160 struct CmpIOpLowering : public ConvertOpToLLVMPattern<arith::CmpIOp> {
162 
164  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
165  ConversionPatternRewriter &rewriter) const override;
166 };
167 
/// Lowers arith::CmpFOp to LLVM::FCmpOp (the predicate is cast numerically);
/// operands converted to LLVM arrays are unrolled into 1-D vector compares.
168 struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> {
170 
172  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
173  ConversionPatternRewriter &rewriter) const override;
174 };
175 
176 } // namespace
177 
178 //===----------------------------------------------------------------------===//
179 // ConstantOpLowering
180 //===----------------------------------------------------------------------===//
181 
183 ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
184  ConversionPatternRewriter &rewriter) const {
185  return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(),
186  adaptor.getOperands(), op->getAttrs(),
187  *getTypeConverter(), rewriter);
188 }
189 
190 //===----------------------------------------------------------------------===//
191 // IndexCastOpLowering
192 //===----------------------------------------------------------------------===//
193 
194 template <typename OpTy, typename ExtCastTy>
195 LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
196  OpTy op, typename OpTy::Adaptor adaptor,
197  ConversionPatternRewriter &rewriter) const {
198  Type resultType = op.getResult().getType();
199  Type targetElementType =
200  this->typeConverter->convertType(getElementTypeOrSelf(resultType));
201  Type sourceElementType =
202  this->typeConverter->convertType(getElementTypeOrSelf(op.getIn()));
203  unsigned targetBits = targetElementType.getIntOrFloatBitWidth();
204  unsigned sourceBits = sourceElementType.getIntOrFloatBitWidth();
205 
206  if (targetBits == sourceBits) {
207  rewriter.replaceOp(op, adaptor.getIn());
208  return success();
209  }
210 
211  // Handle the scalar and 1D vector cases.
212  Type operandType = adaptor.getIn().getType();
213  if (!operandType.isa<LLVM::LLVMArrayType>()) {
214  Type targetType = this->typeConverter->convertType(resultType);
215  if (targetBits < sourceBits)
216  rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType,
217  adaptor.getIn());
218  else
219  rewriter.replaceOpWithNewOp<ExtCastTy>(op, targetType, adaptor.getIn());
220  return success();
221  }
222 
223  if (!resultType.isa<VectorType>())
224  return rewriter.notifyMatchFailure(op, "expected vector result type");
225 
227  op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
228  [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
229  typename OpTy::Adaptor adaptor(operands);
230  if (targetBits < sourceBits) {
231  return rewriter.create<LLVM::TruncOp>(op.getLoc(), llvm1DVectorTy,
232  adaptor.getIn());
233  }
234  return rewriter.create<ExtCastTy>(op.getLoc(), llvm1DVectorTy,
235  adaptor.getIn());
236  },
237  rewriter);
238 }
239 
240 //===----------------------------------------------------------------------===//
241 // AddUIExtendedOpLowering
242 //===----------------------------------------------------------------------===//
243 
244 LogicalResult AddUIExtendedOpLowering::matchAndRewrite(
245  arith::AddUIExtendedOp op, OpAdaptor adaptor,
246  ConversionPatternRewriter &rewriter) const {
247  Type operandType = adaptor.getLhs().getType();
248  Type sumResultType = op.getSum().getType();
249  Type overflowResultType = op.getOverflow().getType();
250 
251  if (!LLVM::isCompatibleType(operandType))
252  return failure();
253 
254  MLIRContext *ctx = rewriter.getContext();
255  Location loc = op.getLoc();
256 
257  // Handle the scalar and 1D vector cases.
258  if (!operandType.isa<LLVM::LLVMArrayType>()) {
259  Type newOverflowType = typeConverter->convertType(overflowResultType);
260  Type structType =
261  LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newOverflowType});
262  Value addOverflow = rewriter.create<LLVM::UAddWithOverflowOp>(
263  loc, structType, adaptor.getLhs(), adaptor.getRhs());
264  Value sumExtracted =
265  rewriter.create<LLVM::ExtractValueOp>(loc, addOverflow, 0);
266  Value overflowExtracted =
267  rewriter.create<LLVM::ExtractValueOp>(loc, addOverflow, 1);
268  rewriter.replaceOp(op, {sumExtracted, overflowExtracted});
269  return success();
270  }
271 
272  if (!sumResultType.isa<VectorType>())
273  return rewriter.notifyMatchFailure(loc, "expected vector result types");
274 
275  return rewriter.notifyMatchFailure(loc,
276  "ND vector types are not supported yet");
277 }
278 
279 //===----------------------------------------------------------------------===//
280 // MulIExtendedOpLowering
281 //===----------------------------------------------------------------------===//
282 
283 template <typename ArithMulOp, bool IsSigned>
284 LogicalResult MulIExtendedOpLowering<ArithMulOp, IsSigned>::matchAndRewrite(
285  ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
286  ConversionPatternRewriter &rewriter) const {
287  Type resultType = adaptor.getLhs().getType();
288 
289  if (!LLVM::isCompatibleType(resultType))
290  return failure();
291 
292  Location loc = op.getLoc();
293 
294  // Handle the scalar and 1D vector cases. Because LLVM does not have a
295  // matching extended multiplication intrinsic, perform regular multiplication
296  // on operands zero-extended to i(2*N) bits, and truncate the results back to
297  // iN types.
298  if (!resultType.isa<LLVM::LLVMArrayType>()) {
299  Type wideType;
300  // Shift amount necessary to extract the high bits from widened result.
301  Attribute shiftValAttr;
302 
303  if (auto intTy = resultType.dyn_cast<IntegerType>()) {
304  unsigned resultBitwidth = intTy.getWidth();
305  wideType = rewriter.getIntegerType(resultBitwidth * 2);
306  shiftValAttr = rewriter.getIntegerAttr(wideType, resultBitwidth);
307  } else {
308  auto vecTy = resultType.cast<VectorType>();
309  unsigned resultBitwidth = vecTy.getElementTypeBitWidth();
310  wideType = VectorType::get(vecTy.getShape(),
311  rewriter.getIntegerType(resultBitwidth * 2));
312  shiftValAttr = SplatElementsAttr::get(
313  wideType, APInt(resultBitwidth * 2, resultBitwidth));
314  }
315  assert(LLVM::isCompatibleType(wideType) &&
316  "LLVM dialect should support all signless integer types");
317 
318  using LLVMExtOp = std::conditional_t<IsSigned, LLVM::SExtOp, LLVM::ZExtOp>;
319  Value lhsExt = rewriter.create<LLVMExtOp>(loc, wideType, adaptor.getLhs());
320  Value rhsExt = rewriter.create<LLVMExtOp>(loc, wideType, adaptor.getRhs());
321  Value mulExt = rewriter.create<LLVM::MulOp>(loc, wideType, lhsExt, rhsExt);
322 
323  // Split the 2*N-bit wide result into two N-bit values.
324  Value low = rewriter.create<LLVM::TruncOp>(loc, resultType, mulExt);
325  Value shiftVal = rewriter.create<LLVM::ConstantOp>(loc, shiftValAttr);
326  Value highExt = rewriter.create<LLVM::LShrOp>(loc, mulExt, shiftVal);
327  Value high = rewriter.create<LLVM::TruncOp>(loc, resultType, highExt);
328 
329  rewriter.replaceOp(op, {low, high});
330  return success();
331  }
332 
333  if (!resultType.isa<VectorType>())
334  return rewriter.notifyMatchFailure(op, "expected vector result type");
335 
336  return rewriter.notifyMatchFailure(op,
337  "ND vector types are not supported yet");
338 }
339 
340 //===----------------------------------------------------------------------===//
341 // CmpIOpLowering
342 //===----------------------------------------------------------------------===//
343 
/// Translates an arith comparison predicate into the matching LLVM dialect
/// predicate enumerator. The two enumerations are numbered identically, so the
/// translation is a value-preserving cast.
template <typename LLVMPredType, typename PredType>
static LLVMPredType convertCmpPredicate(PredType pred) {
  const LLVMPredType translated = static_cast<LLVMPredType>(pred);
  return translated;
}
350 
352 CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
353  ConversionPatternRewriter &rewriter) const {
354  Type operandType = adaptor.getLhs().getType();
355  Type resultType = op.getResult().getType();
356 
357  // Handle the scalar and 1D vector cases.
358  if (!operandType.isa<LLVM::LLVMArrayType>()) {
359  rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
360  op, typeConverter->convertType(resultType),
361  convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
362  adaptor.getLhs(), adaptor.getRhs());
363  return success();
364  }
365 
366  if (!resultType.isa<VectorType>())
367  return rewriter.notifyMatchFailure(op, "expected vector result type");
368 
370  op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
371  [&](Type llvm1DVectorTy, ValueRange operands) {
372  OpAdaptor adaptor(operands);
373  return rewriter.create<LLVM::ICmpOp>(
374  op.getLoc(), llvm1DVectorTy,
375  convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
376  adaptor.getLhs(), adaptor.getRhs());
377  },
378  rewriter);
379 }
380 
381 //===----------------------------------------------------------------------===//
382 // CmpFOpLowering
383 //===----------------------------------------------------------------------===//
384 
386 CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
387  ConversionPatternRewriter &rewriter) const {
388  Type operandType = adaptor.getLhs().getType();
389  Type resultType = op.getResult().getType();
390 
391  // Handle the scalar and 1D vector cases.
392  if (!operandType.isa<LLVM::LLVMArrayType>()) {
393  rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
394  op, typeConverter->convertType(resultType),
395  convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
396  adaptor.getLhs(), adaptor.getRhs());
397  return success();
398  }
399 
400  if (!resultType.isa<VectorType>())
401  return rewriter.notifyMatchFailure(op, "expected vector result type");
402 
404  op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
405  [&](Type llvm1DVectorTy, ValueRange operands) {
406  OpAdaptor adaptor(operands);
407  return rewriter.create<LLVM::FCmpOp>(
408  op.getLoc(), llvm1DVectorTy,
409  convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
410  adaptor.getLhs(), adaptor.getRhs());
411  },
412  rewriter);
413 }
414 
415 //===----------------------------------------------------------------------===//
416 // Pass Definition
417 //===----------------------------------------------------------------------===//
418 
419 namespace {
420 struct ArithToLLVMConversionPass
421  : public impl::ArithToLLVMConversionPassBase<ArithToLLVMConversionPass> {
422  using Base::Base;
423 
424  void runOnOperation() override {
425  LLVMConversionTarget target(getContext());
426  RewritePatternSet patterns(&getContext());
427 
428  LowerToLLVMOptions options(&getContext());
429  if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
430  options.overrideIndexBitwidth(indexBitwidth);
431 
432  LLVMTypeConverter converter(&getContext(), options);
434 
435  if (failed(applyPartialConversion(getOperation(), target,
436  std::move(patterns))))
437  signalPassFailure();
438  }
439 };
440 } // namespace
441 
442 //===----------------------------------------------------------------------===//
443 // Pattern Population
444 //===----------------------------------------------------------------------===//
445 
447  LLVMTypeConverter &converter, RewritePatternSet &patterns) {
448  // clang-format off
449  patterns.add<
450  AddFOpLowering,
451  AddIOpLowering,
452  AndIOpLowering,
453  AddUIExtendedOpLowering,
454  BitcastOpLowering,
455  ConstantOpLowering,
456  CmpFOpLowering,
457  CmpIOpLowering,
458  DivFOpLowering,
459  DivSIOpLowering,
460  DivUIOpLowering,
461  ExtFOpLowering,
462  ExtSIOpLowering,
463  ExtUIOpLowering,
464  FPToSIOpLowering,
465  FPToUIOpLowering,
466  IndexCastOpSILowering,
467  IndexCastOpUILowering,
468  MaxFOpLowering,
469  MaxSIOpLowering,
470  MaxUIOpLowering,
471  MinFOpLowering,
472  MinSIOpLowering,
473  MinUIOpLowering,
474  MulFOpLowering,
475  MulIOpLowering,
476  MulSIExtendedOpLowering,
477  MulUIExtendedOpLowering,
478  NegFOpLowering,
479  OrIOpLowering,
480  RemFOpLowering,
481  RemSIOpLowering,
482  RemUIOpLowering,
483  SelectOpLowering,
484  ShLIOpLowering,
485  ShRSIOpLowering,
486  ShRUIOpLowering,
487  SIToFPOpLowering,
488  SubFOpLowering,
489  SubIOpLowering,
490  TruncFOpLowering,
491  TruncIOpLowering,
492  UIToFPOpLowering,
493  XOrIOpLowering
494  >(converter);
495  // clang-format on
496 }
static LLVMPredType convertCmpPredicate(PredType pred)
static llvm::ManagedStatic< PassManagerOptions > options
static Value handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, ValueRange operands, int64_t vectorWidth, llvm::function_ref< Value(ValueRange)> compute)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:224
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:84
MLIRContext * getContext() const
Definition: Builders.h:55
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:135
ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.h:139
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:33
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:432
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:482
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
U cast() const
Definition: Types.h:321
U dyn_cast() const
Definition: Types.h:311
bool isa() const
Definition: Types.h:301
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:112
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:370
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
Definition: VectorPattern.h:87
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
Definition: Pattern.cpp:323
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:824
void populateArithToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
LLVM_ATTRIBUTE_ALWAYS_INLINE bool addOverflow(int64_t x, int64_t y, int64_t &result)
If builtin intrinsics for overflow-checked arithmetic are available, use them.
Definition: MPInt.h:45
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26