MLIR  19.0.0git
IndexToLLVM.cpp
Go to the documentation of this file.
1 //===- IndexToLLVM.cpp - Index to LLVM dialect conversion -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
17 #include "mlir/Pass/Pass.h"
18 
19 using namespace mlir;
20 using namespace index;
21 
22 namespace {
23 
24 //===----------------------------------------------------------------------===//
25 // ConvertIndexCeilDivS
26 //===----------------------------------------------------------------------===//
27 
28 /// Convert `ceildivs(n, m)` into `x = m > 0 ? -1 : 1` and then
29 /// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`.
///
/// The product `n*m` is never emitted. The `n*m > 0` test is instead built
/// from comparisons of each operand against zero. Both candidate results are
/// emitted unconditionally and a final `LLVM::SelectOp` replaces `op` with one
/// of them. The pattern always returns `success()`.
30 struct ConvertIndexCeilDivS : mlir::ConvertOpToLLVMPattern<CeilDivSOp> {
32 
34  matchAndRewrite(CeilDivSOp op, CeilDivSOpAdaptor adaptor,
35  ConversionPatternRewriter &rewriter) const override {
36  Location loc = op.getLoc();
37  Value n = adaptor.getLhs();
38  Value m = adaptor.getRhs();
    // Constants 0, 1 and -1, all created with the type of the converted lhs.
39  Value zero = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 0);
40  Value posOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 1);
41  Value negOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), -1);
42 
43  // Compute `x`, i.e. `m > 0 ? -1 : 1`.
44  Value mPos =
45  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sgt, m, zero);
46  Value x = rewriter.create<LLVM::SelectOp>(loc, mPos, negOne, posOne);
47 
48  // Compute the positive result: `(n + x) / m + 1`.
49  Value nPlusX = rewriter.create<LLVM::AddOp>(loc, n, x);
50  Value nPlusXDivM = rewriter.create<LLVM::SDivOp>(loc, nPlusX, m);
51  Value posRes = rewriter.create<LLVM::AddOp>(loc, nPlusXDivM, posOne);
52 
53  // Compute the negative result: `0 - ((0 - n) / m)`.
54  Value negN = rewriter.create<LLVM::SubOp>(loc, zero, n);
55  Value negNDivM = rewriter.create<LLVM::SDivOp>(loc, negN, m);
56  Value negRes = rewriter.create<LLVM::SubOp>(loc, zero, negNDivM);
57 
58  // Pick the positive result if `n` and `m` have the same sign and `n` is
59  // non-zero, i.e. `(n > 0) == (m > 0) && n != 0`.
60  Value nPos =
61  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sgt, n, zero);
62  Value sameSign =
63  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, nPos, mPos);
64  Value nNonZero =
65  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, n, zero);
66  Value cmp = rewriter.create<LLVM::AndOp>(loc, sameSign, nNonZero);
    // Replace the original op with the selected result.
67  rewriter.replaceOpWithNewOp<LLVM::SelectOp>(op, cmp, posRes, negRes);
68  return success();
69  }
70 };
71 
72 //===----------------------------------------------------------------------===//
73 // ConvertIndexCeilDivU
74 //===----------------------------------------------------------------------===//
75 
76 /// Convert `ceildivu(n, m)` into `n == 0 ? 0 : (n-1)/m + 1`.
///
/// The `(n-1)/m + 1` value is emitted unconditionally with an unsigned
/// division. A final `LLVM::SelectOp` replaces `op`, yielding the zero
/// constant when `n == 0` and that value otherwise. The pattern always
/// returns `success()`.
77 struct ConvertIndexCeilDivU : mlir::ConvertOpToLLVMPattern<CeilDivUOp> {
79 
81  matchAndRewrite(CeilDivUOp op, CeilDivUOpAdaptor adaptor,
82  ConversionPatternRewriter &rewriter) const override {
83  Location loc = op.getLoc();
84  Value n = adaptor.getLhs();
85  Value m = adaptor.getRhs();
    // Constants 0 and 1, created with the type of the converted lhs.
86  Value zero = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 0);
87  Value one = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 1);
88 
89  // Compute the non-zero result: `(n - 1) / m + 1`.
90  Value minusOne = rewriter.create<LLVM::SubOp>(loc, n, one);
91  Value quotient = rewriter.create<LLVM::UDivOp>(loc, minusOne, m);
92  Value plusOne = rewriter.create<LLVM::AddOp>(loc, quotient, one);
93 
94  // Pick the result: zero when `n == 0`, otherwise the value computed above.
95  Value cmp =
96  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, n, zero);
97  rewriter.replaceOpWithNewOp<LLVM::SelectOp>(op, cmp, zero, plusOne);
98  return success();
99  }
100 };
101 
102 //===----------------------------------------------------------------------===//
103 // ConvertIndexFloorDivS
104 //===----------------------------------------------------------------------===//
105 
106 /// Convert `floordivs(n, m)` into `x = m < 0 ? 1 : -1` and then
107 /// `n*m < 0 ? -1 - (x-n)/m : n/m`.
///
/// The product `n*m` is never emitted. The `n*m < 0` test is instead built
/// from comparisons of each operand against zero. Both candidate results are
/// emitted unconditionally and a final `LLVM::SelectOp` replaces `op` with one
/// of them. The pattern always returns `success()`.
108 struct ConvertIndexFloorDivS : mlir::ConvertOpToLLVMPattern<FloorDivSOp> {
110 
112  matchAndRewrite(FloorDivSOp op, FloorDivSOpAdaptor adaptor,
113  ConversionPatternRewriter &rewriter) const override {
114  Location loc = op.getLoc();
115  Value n = adaptor.getLhs();
116  Value m = adaptor.getRhs();
    // Constants 0, 1 and -1, all created with the type of the converted lhs.
117  Value zero = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 0);
118  Value posOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 1);
119  Value negOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), -1);
120 
121  // Compute `x`, i.e. `m < 0 ? 1 : -1`.
122  Value mNeg =
123  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, m, zero);
124  Value x = rewriter.create<LLVM::SelectOp>(loc, mNeg, posOne, negOne);
125 
126  // Compute the negative result: `-1 - (x - n) / m`.
127  Value xMinusN = rewriter.create<LLVM::SubOp>(loc, x, n);
128  Value xMinusNDivM = rewriter.create<LLVM::SDivOp>(loc, xMinusN, m);
129  Value negRes = rewriter.create<LLVM::SubOp>(loc, negOne, xMinusNDivM);
130 
131  // Compute the positive result: a plain signed division `n / m`.
132  Value posRes = rewriter.create<LLVM::SDivOp>(loc, n, m);
133 
134  // Pick the negative result if `n` and `m` have different signs and `n` is
135  // non-zero, i.e. `(n < 0) != (m < 0) && n != 0`.
136  Value nNeg =
137  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, n, zero);
138  Value diffSign =
139  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, nNeg, mNeg);
140  Value nNonZero =
141  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, n, zero);
142  Value cmp = rewriter.create<LLVM::AndOp>(loc, diffSign, nNonZero);
    // Replace the original op with the selected result.
143  rewriter.replaceOpWithNewOp<LLVM::SelectOp>(op, cmp, negRes, posRes);
144  return success();
145  }
146 };
147 
148 //===----------------------------------------------------------------------===//
149 // ConvertIndexCast
150 //===----------------------------------------------------------------------===//
151 
152 /// Convert a cast op. If the materialized index type is the same as the other
153 /// type, fold away the op. Otherwise, truncate or extend the op as appropriate.
154 /// Signed casts sign extend when the result bitwidth is larger. Unsigned casts
155 /// zero extend when the result bitwidth is larger.
///
/// `CastOp` is the source operation being matched. `ExtOp` is the operation
/// created when the input is not wider than the converted result type.
156 template <typename CastOp, typename ExtOp>
157 struct ConvertIndexCast : public mlir::ConvertOpToLLVMPattern<CastOp> {
159 
161  matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
162  ConversionPatternRewriter &rewriter) const override {
    // `in` is the type of the already-converted operand; `out` is the op's
    // result type run through the type converter.
163  Type in = adaptor.getInput().getType();
164  Type out = this->getTypeConverter()->convertType(op.getType());
    // Same type: forward the input value. Wider input: truncate. Otherwise
    // (types differ and the input is not wider): extend with `ExtOp`.
165  if (in == out)
166  rewriter.replaceOp(op, adaptor.getInput());
167  else if (in.getIntOrFloatBitWidth() > out.getIntOrFloatBitWidth())
168  rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, out, adaptor.getInput());
169  else
170  rewriter.replaceOpWithNewOp<ExtOp>(op, out, adaptor.getInput());
171  return success();
172  }
173 };
174 
/// `CastSOp` lowering: extends with `LLVM::SExtOp` when an extension is needed.
175 using ConvertIndexCastS = ConvertIndexCast<CastSOp, LLVM::SExtOp>;
/// `CastUOp` lowering: extends with `LLVM::ZExtOp` when an extension is needed.
176 using ConvertIndexCastU = ConvertIndexCast<CastUOp, LLVM::ZExtOp>;
177 
178 //===----------------------------------------------------------------------===//
179 // ConvertIndexCmp
180 //===----------------------------------------------------------------------===//
181 
182 /// Assert that the LLVM comparison enum lines up with index's enum.
///
/// Returns true when `lhs` and `rhs` have the same value after both are cast
/// to `int`. Being `constexpr`, it can be used inside a `static_assert`.
183 static constexpr bool checkPredicates(LLVM::ICmpPredicate lhs,
184  IndexCmpPredicate rhs) {
185  return static_cast<int>(lhs) == static_cast<int>(rhs);
186 }
187 
// Compile-time check that the two predicate enums report the same maximum
// enumerator value and that each listed LLVM predicate has the same numeric
// value as its index counterpart.
188 static_assert(
189  LLVM::getMaxEnumValForICmpPredicate() ==
190  getMaxEnumValForIndexCmpPredicate() &&
191  checkPredicates(LLVM::ICmpPredicate::eq, IndexCmpPredicate::EQ) &&
192  checkPredicates(LLVM::ICmpPredicate::ne, IndexCmpPredicate::NE) &&
193  checkPredicates(LLVM::ICmpPredicate::sge, IndexCmpPredicate::SGE) &&
194  checkPredicates(LLVM::ICmpPredicate::sgt, IndexCmpPredicate::SGT) &&
195  checkPredicates(LLVM::ICmpPredicate::sle, IndexCmpPredicate::SLE) &&
196  checkPredicates(LLVM::ICmpPredicate::slt, IndexCmpPredicate::SLT) &&
197  checkPredicates(LLVM::ICmpPredicate::uge, IndexCmpPredicate::UGE) &&
198  checkPredicates(LLVM::ICmpPredicate::ugt, IndexCmpPredicate::UGT) &&
199  checkPredicates(LLVM::ICmpPredicate::ule, IndexCmpPredicate::ULE) &&
200  checkPredicates(LLVM::ICmpPredicate::ult, IndexCmpPredicate::ULT),
201  "LLVM ICmpPredicate mismatches IndexCmpPredicate");
202 
/// Convert `CmpOp` into `LLVM::ICmpOp` on the converted operands. The index
/// predicate is reinterpreted as an LLVM predicate by numeric value. The
/// pattern always returns `success()`.
203 struct ConvertIndexCmp : public mlir::ConvertOpToLLVMPattern<CmpOp> {
205 
207  matchAndRewrite(CmpOp op, CmpOpAdaptor adaptor,
208  ConversionPatternRewriter &rewriter) const override {
209  // The LLVM enum has the same values as the index predicate enums.
    // The result of `symbolizeICmpPredicate` is dereferenced without a check.
210  rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
211  op, *LLVM::symbolizeICmpPredicate(static_cast<uint32_t>(op.getPred())),
212  adaptor.getLhs(), adaptor.getRhs());
213  return success();
214  }
215 };
216 
217 //===----------------------------------------------------------------------===//
218 // ConvertIndexSizeOf
219 //===----------------------------------------------------------------------===//
220 
221 /// Lower `index.sizeof` to a constant with the value of the index bitwidth.
///
/// Both the constant's type and its value are taken from the type converter
/// (`getIndexType()` and `getIndexTypeBitwidth()`). The pattern always returns
/// `success()`.
222 struct ConvertIndexSizeOf : public mlir::ConvertOpToLLVMPattern<SizeOfOp> {
224 
226  matchAndRewrite(SizeOfOp op, SizeOfOpAdaptor adaptor,
227  ConversionPatternRewriter &rewriter) const override {
228  rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
229  op, getTypeConverter()->getIndexType(),
230  getTypeConverter()->getIndexTypeBitwidth());
231  return success();
232  }
233 };
234 
235 //===----------------------------------------------------------------------===//
236 // ConvertIndexConstant
237 //===----------------------------------------------------------------------===//
238 
239 /// Convert an index constant. Truncate the value as appropriate.
///
/// The replacement `LLVM::ConstantOp` uses the type converter's index type.
/// The pattern always returns `success()`.
240 struct ConvertIndexConstant : public mlir::ConvertOpToLLVMPattern<ConstantOp> {
242 
244  matchAndRewrite(ConstantOp op, ConstantOpAdaptor adaptor,
245  ConversionPatternRewriter &rewriter) const override {
246  Type type = getTypeConverter()->getIndexType();
    // Truncate the op's value to the bit width of the converted index type.
    // NOTE(review): presumably the stored value is never narrower than that
    // width -- confirm against `APInt::trunc`'s preconditions.
247  APInt value = op.getValue().trunc(type.getIntOrFloatBitWidth());
248  rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
249  op, type, IntegerAttr::get(type, value));
250  return success();
251  }
252 };
253 
254 //===----------------------------------------------------------------------===//
255 // Trivial Conversions
256 //===----------------------------------------------------------------------===//
257 
261 using ConvertIndexDivS =
263 using ConvertIndexDivU =
265 using ConvertIndexRemS =
267 using ConvertIndexRemU =
269 using ConvertIndexMaxS =
271 using ConvertIndexMaxU =
273 using ConvertIndexMinS =
275 using ConvertIndexMinU =
278 using ConvertIndexShrS =
280 using ConvertIndexShrU =
285 using ConvertIndexBoolConstant =
287 
288 } // namespace
289 
290 //===----------------------------------------------------------------------===//
291 // Pattern Population
292 //===----------------------------------------------------------------------===//
293 
295  LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
296  patterns.insert<
297  // clang-format off
298  ConvertIndexAdd,
299  ConvertIndexSub,
300  ConvertIndexMul,
301  ConvertIndexDivS,
302  ConvertIndexDivU,
303  ConvertIndexRemS,
304  ConvertIndexRemU,
305  ConvertIndexMaxS,
306  ConvertIndexMaxU,
307  ConvertIndexMinS,
308  ConvertIndexMinU,
309  ConvertIndexShl,
310  ConvertIndexShrS,
311  ConvertIndexShrU,
312  ConvertIndexAnd,
313  ConvertIndexOr,
314  ConvertIndexXor,
315  ConvertIndexCeilDivS,
316  ConvertIndexCeilDivU,
317  ConvertIndexFloorDivS,
318  ConvertIndexCastS,
319  ConvertIndexCastU,
320  ConvertIndexCmp,
321  ConvertIndexSizeOf,
322  ConvertIndexConstant,
323  ConvertIndexBoolConstant
324  // clang-format on
325  >(typeConverter);
326 }
327 
328 //===----------------------------------------------------------------------===//
329 // ODS-Generated Definitions
330 //===----------------------------------------------------------------------===//
331 
332 namespace mlir {
333 #define GEN_PASS_DEF_CONVERTINDEXTOLLVMPASS
334 #include "mlir/Conversion/Passes.h.inc"
335 } // namespace mlir
336 
337 //===----------------------------------------------------------------------===//
338 // Pass Definition
339 //===----------------------------------------------------------------------===//
340 
341 namespace {
/// Pass type derived from `impl::ConvertIndexToLLVMPassBase`. It inherits the
/// base-class constructors and overrides `runOnOperation()`, which is defined
/// out of line.
342 struct ConvertIndexToLLVMPass
343  : public impl::ConvertIndexToLLVMPassBase<ConvertIndexToLLVMPass> {
344  using Base::Base;
345 
    /// Entry point of the pass; declared here and defined outside the struct.
346  void runOnOperation() override;
347 };
348 } // namespace
349 
/// Run the conversion on the pass's operation. The conversion target marks
/// `IndexDialect` illegal and `LLVM::LLVMDialect` legal. An `LLVMTypeConverter`
/// is built from the lowering options, the Index-to-LLVM patterns are
/// populated, and `applyPartialConversion` is applied. The pass signals
/// failure when that conversion fails.
350 void ConvertIndexToLLVMPass::runOnOperation() {
351  // Configure dialect conversion.
352  ConversionTarget target(getContext());
353  target.addIllegalDialect<IndexDialect>();
354  target.addLegalDialect<LLVM::LLVMDialect>();
355 
356  // Set LLVM lowering options.
    // NOTE(review): the declaration of `options` is not visible in this
    // listing -- confirm against the original file.
    // The index bitwidth is overridden only when `indexBitwidth` differs from
    // `kDeriveIndexBitwidthFromDataLayout`.
358  if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
359  options.overrideIndexBitwidth(indexBitwidth);
360  LLVMTypeConverter typeConverter(&getContext(), options);
361 
362  // Populate patterns and run the conversion.
363  RewritePatternSet patterns(&getContext());
364  populateIndexToLLVMConversionPatterns(typeConverter, patterns);
365 
    // The pattern set is moved into the conversion driver; on failure the pass
    // is marked as failed.
366  if (failed(
367  applyPartialConversion(getOperation(), target, std::move(patterns))))
368  return signalPassFailure();
369 }
370 
371 //===----------------------------------------------------------------------===//
372 // ConvertToLLVMPatternInterface implementation
373 //===----------------------------------------------------------------------===//
374 
375 namespace {
376 /// Implement the interface to convert Index to LLVM.
377 struct IndexToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
    /// Load `LLVM::LLVMDialect` into the given context.
379  void loadDependentDialects(MLIRContext *context) const final {
380  context->loadDialect<LLVM::LLVMDialect>();
381  }
382 
383  /// Hook for derived dialect interface to provide conversion patterns
384  /// and mark dialect legal for the conversion target.
    ///
    /// This implementation only adds the Index-to-LLVM patterns to `patterns`
    /// using `typeConverter`; `target` is not modified here.
385  void populateConvertToLLVMConversionPatterns(
386  ConversionTarget &target, LLVMTypeConverter &typeConverter,
387  RewritePatternSet &patterns) const final {
388  populateIndexToLLVMConversionPatterns(typeConverter, patterns);
389  }
390 };
391 } // namespace
392 
394  DialectRegistry &registry) {
395  registry.addExtension(+[](MLIRContext *ctx, index::IndexDialect *dialect) {
396  dialect->addInterfaces<IndexToLLVMDialectInterface>();
397  });
398 }
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:143
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.h:147
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:34
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Generic implementation of one-to-one conversion from "SourceOp" to "TargetOp" where the latter belong...
Definition: Pattern.h:198
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:930
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:125
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
void populateIndexToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
void registerConvertIndexToLLVMInterface(DialectRegistry &registry)
Include the generated interface declarations.
static constexpr unsigned kDeriveIndexBitwidthFromDataLayout
Value to pass as bitwidth for the index type when the converter is expected to derive the bitwidth fr...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26