MLIR  15.0.0git
ArithmeticToLLVM.cpp
Go to the documentation of this file.
1 //===- ArithmeticToLLVM.cpp - Arithmetic to LLVM dialect conversion -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "../PassDetail.h"
15 #include "mlir/IR/TypeUtilities.h"
16 
17 using namespace mlir;
18 
19 namespace {
20 
21 //===----------------------------------------------------------------------===//
22 // Straightforward Op Lowerings
23 //===----------------------------------------------------------------------===//
24 
28 using DivUIOpLowering =
30 using DivSIOpLowering =
32 using RemUIOpLowering =
34 using RemSIOpLowering =
40 using ShRUIOpLowering =
42 using ShRSIOpLowering =
50 using ExtUIOpLowering =
52 using ExtSIOpLowering =
55 using TruncIOpLowering =
57 using TruncFOpLowering =
59 using UIToFPOpLowering =
61 using SIToFPOpLowering =
63 using FPToUIOpLowering =
65 using FPToSIOpLowering =
67 using BitcastOpLowering =
69 using SelectOpLowering =
71 
72 //===----------------------------------------------------------------------===//
73 // Op Lowering Patterns
74 //===----------------------------------------------------------------------===//
75 
76 /// Directly lower to LLVM op.
77 struct ConstantOpLowering : public ConvertOpToLLVMPattern<arith::ConstantOp> {
79 
81  matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
82  ConversionPatternRewriter &rewriter) const override;
83 };
84 
85 /// The lowering of index_cast becomes an integer conversion since index
86 /// becomes an integer. If the bit width of the source and target integer
87 /// types is the same, just erase the cast. If the target type is wider,
88 /// sign-extend the value, otherwise truncate it.
89 struct IndexCastOpLowering : public ConvertOpToLLVMPattern<arith::IndexCastOp> {
91 
93  matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
94  ConversionPatternRewriter &rewriter) const override;
95 };
96 
97 struct CmpIOpLowering : public ConvertOpToLLVMPattern<arith::CmpIOp> {
99 
101  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
102  ConversionPatternRewriter &rewriter) const override;
103 };
104 
105 struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> {
107 
109  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
110  ConversionPatternRewriter &rewriter) const override;
111 };
112 
113 } // namespace
114 
115 //===----------------------------------------------------------------------===//
116 // ConstantOpLowering
117 //===----------------------------------------------------------------------===//
118 
120 ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
121  ConversionPatternRewriter &rewriter) const {
122  return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(),
123  adaptor.getOperands(),
124  *getTypeConverter(), rewriter);
125 }
126 
127 //===----------------------------------------------------------------------===//
128 // IndexCastOpLowering
129 //===----------------------------------------------------------------------===//
130 
131 LogicalResult IndexCastOpLowering::matchAndRewrite(
132  arith::IndexCastOp op, OpAdaptor adaptor,
133  ConversionPatternRewriter &rewriter) const {
134  auto targetType = typeConverter->convertType(op.getResult().getType());
135  auto targetElementType =
136  typeConverter->convertType(getElementTypeOrSelf(op.getResult()))
137  .cast<IntegerType>();
138  auto sourceElementType =
139  getElementTypeOrSelf(adaptor.getIn()).cast<IntegerType>();
140  unsigned targetBits = targetElementType.getWidth();
141  unsigned sourceBits = sourceElementType.getWidth();
142 
143  if (targetBits == sourceBits)
144  rewriter.replaceOp(op, adaptor.getIn());
145  else if (targetBits < sourceBits)
146  rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType, adaptor.getIn());
147  else
148  rewriter.replaceOpWithNewOp<LLVM::SExtOp>(op, targetType, adaptor.getIn());
149  return success();
150 }
151 
152 //===----------------------------------------------------------------------===//
153 // CmpIOpLowering
154 //===----------------------------------------------------------------------===//
155 
/// Maps an arith comparison predicate (e.g. arith::CmpIPredicate) onto the
/// corresponding LLVM dialect predicate type (e.g. LLVM::ICmpPredicate).
/// No lookup table is needed: the two enumerations use the same numeric
/// value for each predicate, so a value-preserving cast is sufficient.
template <typename LLVMPredType, typename PredType>
static LLVMPredType convertCmpPredicate(PredType pred) {
  const auto converted = static_cast<LLVMPredType>(pred);
  return converted;
}
162 
164 CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
165  ConversionPatternRewriter &rewriter) const {
166  auto operandType = adaptor.getLhs().getType();
167  auto resultType = op.getResult().getType();
168 
169  // Handle the scalar and 1D vector cases.
170  if (!operandType.isa<LLVM::LLVMArrayType>()) {
171  rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
172  op, typeConverter->convertType(resultType),
173  convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
174  adaptor.getLhs(), adaptor.getRhs());
175  return success();
176  }
177 
178  auto vectorType = resultType.dyn_cast<VectorType>();
179  if (!vectorType)
180  return rewriter.notifyMatchFailure(op, "expected vector result type");
181 
183  op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
184  [&](Type llvm1DVectorTy, ValueRange operands) {
185  OpAdaptor adaptor(operands);
186  return rewriter.create<LLVM::ICmpOp>(
187  op.getLoc(), llvm1DVectorTy,
188  convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
189  adaptor.getLhs(), adaptor.getRhs());
190  },
191  rewriter);
192 }
193 
194 //===----------------------------------------------------------------------===//
195 // CmpFOpLowering
196 //===----------------------------------------------------------------------===//
197 
199 CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
200  ConversionPatternRewriter &rewriter) const {
201  auto operandType = adaptor.getLhs().getType();
202  auto resultType = op.getResult().getType();
203 
204  // Handle the scalar and 1D vector cases.
205  if (!operandType.isa<LLVM::LLVMArrayType>()) {
206  rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
207  op, typeConverter->convertType(resultType),
208  convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
209  adaptor.getLhs(), adaptor.getRhs());
210  return success();
211  }
212 
213  auto vectorType = resultType.dyn_cast<VectorType>();
214  if (!vectorType)
215  return rewriter.notifyMatchFailure(op, "expected vector result type");
216 
218  op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
219  [&](Type llvm1DVectorTy, ValueRange operands) {
220  OpAdaptor adaptor(operands);
221  return rewriter.create<LLVM::FCmpOp>(
222  op.getLoc(), llvm1DVectorTy,
223  convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
224  adaptor.getLhs(), adaptor.getRhs());
225  },
226  rewriter);
227 }
228 
229 //===----------------------------------------------------------------------===//
230 // Pass Definition
231 //===----------------------------------------------------------------------===//
232 
233 namespace {
234 struct ConvertArithmeticToLLVMPass
235  : public ConvertArithmeticToLLVMBase<ConvertArithmeticToLLVMPass> {
236  ConvertArithmeticToLLVMPass() = default;
237 
238  void runOnOperation() override {
239  LLVMConversionTarget target(getContext());
240  RewritePatternSet patterns(&getContext());
241 
242  LowerToLLVMOptions options(&getContext());
243  if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
244  options.overrideIndexBitwidth(indexBitwidth);
245 
246  LLVMTypeConverter converter(&getContext(), options);
248  patterns);
249 
250  if (failed(applyPartialConversion(getOperation(), target,
251  std::move(patterns))))
252  signalPassFailure();
253  }
254 };
255 } // namespace
256 
257 //===----------------------------------------------------------------------===//
258 // Pattern Population
259 //===----------------------------------------------------------------------===//
260 
262  LLVMTypeConverter &converter, RewritePatternSet &patterns) {
263  // clang-format off
264  patterns.add<
265  ConstantOpLowering,
266  AddIOpLowering,
267  SubIOpLowering,
268  MulIOpLowering,
269  DivUIOpLowering,
270  DivSIOpLowering,
271  RemUIOpLowering,
272  RemSIOpLowering,
273  AndIOpLowering,
274  OrIOpLowering,
275  XOrIOpLowering,
276  ShLIOpLowering,
277  ShRUIOpLowering,
278  ShRSIOpLowering,
279  NegFOpLowering,
280  AddFOpLowering,
281  SubFOpLowering,
282  MulFOpLowering,
283  DivFOpLowering,
284  RemFOpLowering,
285  ExtUIOpLowering,
286  ExtSIOpLowering,
287  ExtFOpLowering,
288  TruncIOpLowering,
289  TruncFOpLowering,
290  UIToFPOpLowering,
291  SIToFPOpLowering,
292  FPToUIOpLowering,
293  FPToSIOpLowering,
294  IndexCastOpLowering,
295  BitcastOpLowering,
296  CmpIOpLowering,
297  CmpFOpLowering,
298  SelectOpLowering
299  >(converter);
300  // clang-format on
301 }
302 
304  return std::make_unique<ConvertArithmeticToLLVMPass>();
305 }
Include the generated interface declarations.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:132
LogicalResult applyPartialConversion(ArrayRef< Operation *> ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation *> *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
Definition: VectorPattern.h:67
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
Derived class that automatically populates legalization information for different LLVM ops...
void populateArithmeticToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
std::unique_ptr< Pass > createConvertArithmeticToLLVMPass()
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:380
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void overrideIndexBitwidth(unsigned bitwidth)
Set the index bitwidth to the given value.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
LLVM dialect array type.
Definition: LLVMTypes.h:74
static llvm::ManagedStatic< PassManagerOptions > options
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:30
static VectorType vectorType(CodeGen &codegen, Type etp)
Constructs vector type.
Options to control the LLVM lowering.
This class implements a pattern rewriter for use with ConversionPatterns.
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands...
Definition: Pattern.cpp:310
This class provides an abstraction over the different types of ranges over Values.
static LLVMPredType convertCmpPredicate(PredType pred)