MLIR  20.0.0git
EmulateWideInt.cpp
Go to the documentation of this file.
1 //===- EmulateWideInt.cpp - Wide integer operation emulation ----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/TypeUtilities.h"
20 #include "llvm/ADT/APInt.h"
21 #include "llvm/Support/FormatVariadic.h"
22 #include "llvm/Support/MathExtras.h"
23 #include <cassert>
24 
25 namespace mlir::arith {
26 #define GEN_PASS_DEF_ARITHEMULATEWIDEINT
27 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
28 } // namespace mlir::arith
29 
30 using namespace mlir;
31 
32 //===----------------------------------------------------------------------===//
33 // Common Helper Functions
34 //===----------------------------------------------------------------------===//
35 
36 /// Returns N bottom and N top bits from `value`, where N = `newBitWidth`.
37 /// Treats `value` as a 2*N bits-wide integer.
38 /// The bottom bits are returned in the first pair element, while the top bits
39 /// in the second one.
40 static std::pair<APInt, APInt> getHalves(const APInt &value,
41  unsigned newBitWidth) {
42  APInt low = value.extractBits(newBitWidth, 0);
43  APInt high = value.extractBits(newBitWidth, newBitWidth);
44  return {std::move(low), std::move(high)};
45 }
46 
47 /// Returns the type with the last (innermost) dimension reduced to x1.
48 /// Scalarizes 1D vector inputs to match how we extract/insert vector values,
49 /// e.g.:
50 /// - vector<3x2xi16> --> vector<3x1xi16>
51 /// - vector<2xi16> --> i16
52 static Type reduceInnermostDim(VectorType type) {
53  if (type.getShape().size() == 1)
54  return type.getElementType();
55 
56  auto newShape = to_vector(type.getShape());
57  newShape.back() = 1;
58  return VectorType::get(newShape, type.getElementType());
59 }
60 
61 /// Extracts the `input` vector slice with elements at the last dimension offset
62 /// by `lastOffset`. Returns a value of vector type with the last dimension
63 /// reduced to x1 or fully scalarized, e.g.:
64 /// - vector<3x2xi16> --> vector<3x1xi16>
65 /// - vector<2xi16> --> i16
67  Location loc, Value input,
68  int64_t lastOffset) {
69  ArrayRef<int64_t> shape = cast<VectorType>(input.getType()).getShape();
70  assert(lastOffset < shape.back() && "Offset out of bounds");
71 
72  // Scalarize the result in case of 1D vectors.
73  if (shape.size() == 1)
74  return rewriter.create<vector::ExtractOp>(loc, input, lastOffset);
75 
76  SmallVector<int64_t> offsets(shape.size(), 0);
77  offsets.back() = lastOffset;
78  auto sizes = llvm::to_vector(shape);
79  sizes.back() = 1;
80  SmallVector<int64_t> strides(shape.size(), 1);
81 
82  return rewriter.create<vector::ExtractStridedSliceOp>(loc, input, offsets,
83  sizes, strides);
84 }
85 
86 /// Extracts two vector slices from the `input` whose type is `vector<...x2T>`,
87 /// with the first element at offset 0 and the second element at offset 1.
88 static std::pair<Value, Value>
90  Value input) {
91  return {extractLastDimSlice(rewriter, loc, input, 0),
92  extractLastDimSlice(rewriter, loc, input, 1)};
93 }
94 
95 // Performs a vector shape cast to drop the trailing x1 dimension. If the
96 // `input` is a scalar, this is a noop.
98  Location loc, Value input) {
99  auto vecTy = dyn_cast<VectorType>(input.getType());
100  if (!vecTy)
101  return input;
102 
103  // Shape cast to drop the last x1 dimension.
104  ArrayRef<int64_t> shape = vecTy.getShape();
105  assert(shape.size() >= 2 && "Expected vector with at list two dims");
106  assert(shape.back() == 1 && "Expected the last vector dim to be x1");
107 
108  auto newVecTy = VectorType::get(shape.drop_back(), vecTy.getElementType());
109  return rewriter.create<vector::ShapeCastOp>(loc, newVecTy, input);
110 }
111 
112 /// Performs a vector shape cast to append an x1 dimension. If the
113 /// `input` is a scalar, this is a noop.
115  Value input) {
116  auto vecTy = dyn_cast<VectorType>(input.getType());
117  if (!vecTy)
118  return input;
119 
120  // Add a trailing x1 dim.
121  auto newShape = llvm::to_vector(vecTy.getShape());
122  newShape.push_back(1);
123  auto newTy = VectorType::get(newShape, vecTy.getElementType());
124  return rewriter.create<vector::ShapeCastOp>(loc, newTy, input);
125 }
126 
127 /// Inserts the `source` vector slice into the `dest` vector at offset
128 /// `lastOffset` in the last dimension. `source` can be a scalar when `dest` is
129 /// a 1D vector.
131  Location loc, Value source, Value dest,
132  int64_t lastOffset) {
133  ArrayRef<int64_t> shape = cast<VectorType>(dest.getType()).getShape();
134  assert(lastOffset < shape.back() && "Offset out of bounds");
135 
136  // Handle scalar source.
137  if (isa<IntegerType>(source.getType()))
138  return rewriter.create<vector::InsertOp>(loc, source, dest, lastOffset);
139 
140  SmallVector<int64_t> offsets(shape.size(), 0);
141  offsets.back() = lastOffset;
142  SmallVector<int64_t> strides(shape.size(), 1);
143  return rewriter.create<vector::InsertStridedSliceOp>(loc, source, dest,
144  offsets, strides);
145 }
146 
147 /// Constructs a new vector of type `resultType` by creating a series of
148 /// insertions of `resultComponents`, each at the next offset of the last vector
149 /// dimension.
150 /// When all `resultComponents` are scalars, the result type is `vector<NxT>`;
151 /// when `resultComponents` are `vector<...x1xT>`s, the result type is
152 /// `vector<...xNxT>`, where `N` is the number of `resultComponents`.
154  Location loc, VectorType resultType,
155  ValueRange resultComponents) {
156  llvm::ArrayRef<int64_t> resultShape = resultType.getShape();
157  (void)resultShape;
158  assert(!resultShape.empty() && "Result expected to have dimensions");
159  assert(resultShape.back() == static_cast<int64_t>(resultComponents.size()) &&
160  "Wrong number of result components");
161 
162  Value resultVec = createScalarOrSplatConstant(rewriter, loc, resultType, 0);
163  for (auto [i, component] : llvm::enumerate(resultComponents))
164  resultVec = insertLastDimSlice(rewriter, loc, component, resultVec, i);
165 
166  return resultVec;
167 }
168 
169 namespace {
170 //===----------------------------------------------------------------------===//
171 // ConvertConstant
172 //===----------------------------------------------------------------------===//
173 
174 struct ConvertConstant final : OpConversionPattern<arith::ConstantOp> {
176 
177  LogicalResult
178  matchAndRewrite(arith::ConstantOp op, OpAdaptor,
179  ConversionPatternRewriter &rewriter) const override {
180  Type oldType = op.getType();
181  auto newType = getTypeConverter()->convertType<VectorType>(oldType);
182  if (!newType)
183  return rewriter.notifyMatchFailure(
184  op, llvm::formatv("unsupported type: {0}", op.getType()));
185 
186  unsigned newBitWidth = newType.getElementTypeBitWidth();
187  Attribute oldValue = op.getValueAttr();
188 
189  if (auto intAttr = dyn_cast<IntegerAttr>(oldValue)) {
190  auto [low, high] = getHalves(intAttr.getValue(), newBitWidth);
191  auto newAttr = DenseElementsAttr::get(newType, {low, high});
192  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
193  return success();
194  }
195 
196  if (auto splatAttr = dyn_cast<SplatElementsAttr>(oldValue)) {
197  auto [low, high] =
198  getHalves(splatAttr.getSplatValue<APInt>(), newBitWidth);
199  int64_t numSplatElems = splatAttr.getNumElements();
200  SmallVector<APInt> values;
201  values.reserve(numSplatElems * 2);
202  for (int64_t i = 0; i < numSplatElems; ++i) {
203  values.push_back(low);
204  values.push_back(high);
205  }
206 
207  auto attr = DenseElementsAttr::get(newType, values);
208  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, attr);
209  return success();
210  }
211 
212  if (auto elemsAttr = dyn_cast<DenseElementsAttr>(oldValue)) {
213  int64_t numElems = elemsAttr.getNumElements();
214  SmallVector<APInt> values;
215  values.reserve(numElems * 2);
216  for (const APInt &origVal : elemsAttr.getValues<APInt>()) {
217  auto [low, high] = getHalves(origVal, newBitWidth);
218  values.push_back(std::move(low));
219  values.push_back(std::move(high));
220  }
221 
222  auto attr = DenseElementsAttr::get(newType, values);
223  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, attr);
224  return success();
225  }
226 
227  return rewriter.notifyMatchFailure(op.getLoc(),
228  "unhandled constant attribute");
229  }
230 };
231 
232 //===----------------------------------------------------------------------===//
233 // ConvertAddI
234 //===----------------------------------------------------------------------===//
235 
236 struct ConvertAddI final : OpConversionPattern<arith::AddIOp> {
238 
239  LogicalResult
240  matchAndRewrite(arith::AddIOp op, OpAdaptor adaptor,
241  ConversionPatternRewriter &rewriter) const override {
242  Location loc = op->getLoc();
243  auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
244  if (!newTy)
245  return rewriter.notifyMatchFailure(
246  loc, llvm::formatv("unsupported type: {0}", op.getType()));
247 
248  Type newElemTy = reduceInnermostDim(newTy);
249 
250  auto [lhsElem0, lhsElem1] =
251  extractLastDimHalves(rewriter, loc, adaptor.getLhs());
252  auto [rhsElem0, rhsElem1] =
253  extractLastDimHalves(rewriter, loc, adaptor.getRhs());
254 
255  auto lowSum =
256  rewriter.create<arith::AddUIExtendedOp>(loc, lhsElem0, rhsElem0);
257  Value overflowVal =
258  rewriter.create<arith::ExtUIOp>(loc, newElemTy, lowSum.getOverflow());
259 
260  Value high0 = rewriter.create<arith::AddIOp>(loc, overflowVal, lhsElem1);
261  Value high = rewriter.create<arith::AddIOp>(loc, high0, rhsElem1);
262 
263  Value resultVec =
264  constructResultVector(rewriter, loc, newTy, {lowSum.getSum(), high});
265  rewriter.replaceOp(op, resultVec);
266  return success();
267  }
268 };
269 
270 //===----------------------------------------------------------------------===//
271 // ConvertBitwiseBinary
272 //===----------------------------------------------------------------------===//
273 
274 /// Conversion pattern template for bitwise binary ops, e.g., `arith.andi`.
275 template <typename BinaryOp>
276 struct ConvertBitwiseBinary final : OpConversionPattern<BinaryOp> {
278  using OpAdaptor = typename OpConversionPattern<BinaryOp>::OpAdaptor;
279 
280  LogicalResult
281  matchAndRewrite(BinaryOp op, OpAdaptor adaptor,
282  ConversionPatternRewriter &rewriter) const override {
283  Location loc = op->getLoc();
284  auto newTy = this->getTypeConverter()->template convertType<VectorType>(
285  op.getType());
286  if (!newTy)
287  return rewriter.notifyMatchFailure(
288  loc, llvm::formatv("unsupported type: {0}", op.getType()));
289 
290  auto [lhsElem0, lhsElem1] =
291  extractLastDimHalves(rewriter, loc, adaptor.getLhs());
292  auto [rhsElem0, rhsElem1] =
293  extractLastDimHalves(rewriter, loc, adaptor.getRhs());
294 
295  Value resElem0 = rewriter.create<BinaryOp>(loc, lhsElem0, rhsElem0);
296  Value resElem1 = rewriter.create<BinaryOp>(loc, lhsElem1, rhsElem1);
297  Value resultVec =
298  constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
299  rewriter.replaceOp(op, resultVec);
300  return success();
301  }
302 };
303 
304 //===----------------------------------------------------------------------===//
305 // ConvertCmpI
306 //===----------------------------------------------------------------------===//
307 
308 /// Returns the matching unsigned version of the given predicate `pred`, or the
309 /// same predicate if `pred` is not a signed.
310 static arith::CmpIPredicate toUnsignedPredicate(arith::CmpIPredicate pred) {
311  using P = arith::CmpIPredicate;
312  switch (pred) {
313  case P::sge:
314  return P::uge;
315  case P::sgt:
316  return P::ugt;
317  case P::sle:
318  return P::ule;
319  case P::slt:
320  return P::ult;
321  default:
322  return pred;
323  }
324 }
325 
326 struct ConvertCmpI final : OpConversionPattern<arith::CmpIOp> {
328 
329  LogicalResult
330  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
331  ConversionPatternRewriter &rewriter) const override {
332  Location loc = op->getLoc();
333  auto inputTy =
334  getTypeConverter()->convertType<VectorType>(op.getLhs().getType());
335  if (!inputTy)
336  return rewriter.notifyMatchFailure(
337  loc, llvm::formatv("unsupported type: {0}", op.getType()));
338 
339  arith::CmpIPredicate highPred = adaptor.getPredicate();
340  arith::CmpIPredicate lowPred = toUnsignedPredicate(highPred);
341 
342  auto [lhsElem0, lhsElem1] =
343  extractLastDimHalves(rewriter, loc, adaptor.getLhs());
344  auto [rhsElem0, rhsElem1] =
345  extractLastDimHalves(rewriter, loc, adaptor.getRhs());
346 
347  Value lowCmp =
348  rewriter.create<arith::CmpIOp>(loc, lowPred, lhsElem0, rhsElem0);
349  Value highCmp =
350  rewriter.create<arith::CmpIOp>(loc, highPred, lhsElem1, rhsElem1);
351 
352  Value cmpResult{};
353  switch (highPred) {
354  case arith::CmpIPredicate::eq: {
355  cmpResult = rewriter.create<arith::AndIOp>(loc, lowCmp, highCmp);
356  break;
357  }
358  case arith::CmpIPredicate::ne: {
359  cmpResult = rewriter.create<arith::OrIOp>(loc, lowCmp, highCmp);
360  break;
361  }
362  default: {
363  // Handle inequality checks.
364  Value highEq = rewriter.create<arith::CmpIOp>(
365  loc, arith::CmpIPredicate::eq, lhsElem1, rhsElem1);
366  cmpResult =
367  rewriter.create<arith::SelectOp>(loc, highEq, lowCmp, highCmp);
368  break;
369  }
370  }
371 
372  assert(cmpResult && "Unhandled case");
373  rewriter.replaceOp(op, dropTrailingX1Dim(rewriter, loc, cmpResult));
374  return success();
375  }
376 };
377 
378 //===----------------------------------------------------------------------===//
379 // ConvertMulI
380 //===----------------------------------------------------------------------===//
381 
382 struct ConvertMulI final : OpConversionPattern<arith::MulIOp> {
384 
385  LogicalResult
386  matchAndRewrite(arith::MulIOp op, OpAdaptor adaptor,
387  ConversionPatternRewriter &rewriter) const override {
388  Location loc = op->getLoc();
389  auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
390  if (!newTy)
391  return rewriter.notifyMatchFailure(
392  loc, llvm::formatv("unsupported type: {0}", op.getType()));
393 
394  auto [lhsElem0, lhsElem1] =
395  extractLastDimHalves(rewriter, loc, adaptor.getLhs());
396  auto [rhsElem0, rhsElem1] =
397  extractLastDimHalves(rewriter, loc, adaptor.getRhs());
398 
399  // The multiplication algorithm used is the standard (long) multiplication.
400  // Multiplying two i2N integers produces (at most) an i4N result, but
401  // because the calculation of top i2N is not necessary, we omit it.
402  auto mulLowLow =
403  rewriter.create<arith::MulUIExtendedOp>(loc, lhsElem0, rhsElem0);
404  Value mulLowHi = rewriter.create<arith::MulIOp>(loc, lhsElem0, rhsElem1);
405  Value mulHiLow = rewriter.create<arith::MulIOp>(loc, lhsElem1, rhsElem0);
406 
407  Value resLow = mulLowLow.getLow();
408  Value resHi =
409  rewriter.create<arith::AddIOp>(loc, mulLowLow.getHigh(), mulLowHi);
410  resHi = rewriter.create<arith::AddIOp>(loc, resHi, mulHiLow);
411 
412  Value resultVec =
413  constructResultVector(rewriter, loc, newTy, {resLow, resHi});
414  rewriter.replaceOp(op, resultVec);
415  return success();
416  }
417 };
418 
419 //===----------------------------------------------------------------------===//
420 // ConvertExtSI
421 //===----------------------------------------------------------------------===//
422 
423 struct ConvertExtSI final : OpConversionPattern<arith::ExtSIOp> {
425 
426  LogicalResult
427  matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
428  ConversionPatternRewriter &rewriter) const override {
429  Location loc = op->getLoc();
430  auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
431  if (!newTy)
432  return rewriter.notifyMatchFailure(
433  loc, llvm::formatv("unsupported type: {0}", op.getType()));
434 
435  Type newResultComponentTy = reduceInnermostDim(newTy);
436 
437  // Sign-extend the input value to determine the low half of the result.
438  // Then, check if the low half is negative, and sign-extend the comparison
439  // result to get the high half.
440  Value newOperand = appendX1Dim(rewriter, loc, adaptor.getIn());
441  Value extended = rewriter.createOrFold<arith::ExtSIOp>(
442  loc, newResultComponentTy, newOperand);
443  Value operandZeroCst =
444  createScalarOrSplatConstant(rewriter, loc, newResultComponentTy, 0);
445  Value signBit = rewriter.create<arith::CmpIOp>(
446  loc, arith::CmpIPredicate::slt, extended, operandZeroCst);
447  Value signValue =
448  rewriter.create<arith::ExtSIOp>(loc, newResultComponentTy, signBit);
449 
450  Value resultVec =
451  constructResultVector(rewriter, loc, newTy, {extended, signValue});
452  rewriter.replaceOp(op, resultVec);
453  return success();
454  }
455 };
456 
457 //===----------------------------------------------------------------------===//
458 // ConvertExtUI
459 //===----------------------------------------------------------------------===//
460 
461 struct ConvertExtUI final : OpConversionPattern<arith::ExtUIOp> {
463 
464  LogicalResult
465  matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
466  ConversionPatternRewriter &rewriter) const override {
467  Location loc = op->getLoc();
468  auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
469  if (!newTy)
470  return rewriter.notifyMatchFailure(
471  loc, llvm::formatv("unsupported type: {0}", op.getType()));
472 
473  Type newResultComponentTy = reduceInnermostDim(newTy);
474 
475  // Zero-extend the input value to determine the low half of the result.
476  // The high half is always zero.
477  Value newOperand = appendX1Dim(rewriter, loc, adaptor.getIn());
478  Value extended = rewriter.createOrFold<arith::ExtUIOp>(
479  loc, newResultComponentTy, newOperand);
480  Value zeroCst = createScalarOrSplatConstant(rewriter, loc, newTy, 0);
481  Value newRes = insertLastDimSlice(rewriter, loc, extended, zeroCst, 0);
482  rewriter.replaceOp(op, newRes);
483  return success();
484  }
485 };
486 
487 //===----------------------------------------------------------------------===//
488 // ConvertMaxMin
489 //===----------------------------------------------------------------------===//
490 
491 template <typename SourceOp, arith::CmpIPredicate CmpPred>
492 struct ConvertMaxMin final : OpConversionPattern<SourceOp> {
494 
495  LogicalResult
496  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
497  ConversionPatternRewriter &rewriter) const override {
498  Location loc = op->getLoc();
499 
500  Type oldTy = op.getType();
501  auto newTy = dyn_cast_or_null<VectorType>(
502  this->getTypeConverter()->convertType(oldTy));
503  if (!newTy)
504  return rewriter.notifyMatchFailure(
505  loc, llvm::formatv("unsupported type: {0}", op.getType()));
506 
507  // Rewrite Max*I/Min*I as compare and select over original operands. Let
508  // the CmpI and Select emulation patterns handle the final legalization.
509  Value cmp =
510  rewriter.create<arith::CmpIOp>(loc, CmpPred, op.getLhs(), op.getRhs());
511  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, op.getLhs(),
512  op.getRhs());
513  return success();
514  }
515 };
516 
//===----------------------------------------------------------------------===//
// Convert IndexCast ops
//===----------------------------------------------------------------------===//
519 
520 /// Returns true iff the type is `index` or `vector<...index>`.
521 static bool isIndexOrIndexVector(Type type) {
522  if (isa<IndexType>(type))
523  return true;
524 
525  if (auto vectorTy = dyn_cast<VectorType>(type))
526  if (isa<IndexType>(vectorTy.getElementType()))
527  return true;
528 
529  return false;
530 }
531 
532 template <typename CastOp>
533 struct ConvertIndexCastIntToIndex final : OpConversionPattern<CastOp> {
535 
536  LogicalResult
537  matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
538  ConversionPatternRewriter &rewriter) const override {
539  Type resultType = op.getType();
540  if (!isIndexOrIndexVector(resultType))
541  return failure();
542 
543  Location loc = op.getLoc();
544  Type inType = op.getIn().getType();
545  auto newInTy =
546  this->getTypeConverter()->template convertType<VectorType>(inType);
547  if (!newInTy)
548  return rewriter.notifyMatchFailure(
549  loc, llvm::formatv("unsupported type: {0}", inType));
550 
551  // Discard the high half of the input truncating the original value.
552  Value extracted = extractLastDimSlice(rewriter, loc, adaptor.getIn(), 0);
553  extracted = dropTrailingX1Dim(rewriter, loc, extracted);
554  rewriter.replaceOpWithNewOp<CastOp>(op, resultType, extracted);
555  return success();
556  }
557 };
558 
559 template <typename CastOp, typename ExtensionOp>
560 struct ConvertIndexCastIndexToInt final : OpConversionPattern<CastOp> {
562 
563  LogicalResult
564  matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
565  ConversionPatternRewriter &rewriter) const override {
566  Type inType = op.getIn().getType();
567  if (!isIndexOrIndexVector(inType))
568  return failure();
569 
570  Location loc = op.getLoc();
571  auto *typeConverter =
572  this->template getTypeConverter<arith::WideIntEmulationConverter>();
573 
574  Type resultType = op.getType();
575  auto newTy = typeConverter->template convertType<VectorType>(resultType);
576  if (!newTy)
577  return rewriter.notifyMatchFailure(
578  loc, llvm::formatv("unsupported type: {0}", resultType));
579 
580  // Emit an index cast over the matching narrow type.
581  Type narrowTy =
582  rewriter.getIntegerType(typeConverter->getMaxTargetIntBitWidth());
583  if (auto vecTy = dyn_cast<VectorType>(resultType))
584  narrowTy = VectorType::get(vecTy.getShape(), narrowTy);
585 
586  // Sign or zero-extend the result. Let the matching conversion pattern
587  // legalize the extension op.
588  Value underlyingVal =
589  rewriter.create<CastOp>(loc, narrowTy, adaptor.getIn());
590  rewriter.replaceOpWithNewOp<ExtensionOp>(op, resultType, underlyingVal);
591  return success();
592  }
593 };
594 
595 //===----------------------------------------------------------------------===//
596 // ConvertSelect
597 //===----------------------------------------------------------------------===//
598 
599 struct ConvertSelect final : OpConversionPattern<arith::SelectOp> {
601 
602  LogicalResult
603  matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
604  ConversionPatternRewriter &rewriter) const override {
605  Location loc = op->getLoc();
606  auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
607  if (!newTy)
608  return rewriter.notifyMatchFailure(
609  loc, llvm::formatv("unsupported type: {0}", op.getType()));
610 
611  auto [trueElem0, trueElem1] =
612  extractLastDimHalves(rewriter, loc, adaptor.getTrueValue());
613  auto [falseElem0, falseElem1] =
614  extractLastDimHalves(rewriter, loc, adaptor.getFalseValue());
615  Value cond = appendX1Dim(rewriter, loc, adaptor.getCondition());
616 
617  Value resElem0 =
618  rewriter.create<arith::SelectOp>(loc, cond, trueElem0, falseElem0);
619  Value resElem1 =
620  rewriter.create<arith::SelectOp>(loc, cond, trueElem1, falseElem1);
621  Value resultVec =
622  constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
623  rewriter.replaceOp(op, resultVec);
624  return success();
625  }
626 };
627 
628 //===----------------------------------------------------------------------===//
629 // ConvertShLI
630 //===----------------------------------------------------------------------===//
631 
632 struct ConvertShLI final : OpConversionPattern<arith::ShLIOp> {
634 
635  LogicalResult
636  matchAndRewrite(arith::ShLIOp op, OpAdaptor adaptor,
637  ConversionPatternRewriter &rewriter) const override {
638  Location loc = op->getLoc();
639 
640  Type oldTy = op.getType();
641  auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
642  if (!newTy)
643  return rewriter.notifyMatchFailure(
644  loc, llvm::formatv("unsupported type: {0}", op.getType()));
645 
646  Type newOperandTy = reduceInnermostDim(newTy);
647  // `oldBitWidth` == `2 * newBitWidth`
648  unsigned newBitWidth = newTy.getElementTypeBitWidth();
649 
650  auto [lhsElem0, lhsElem1] =
651  extractLastDimHalves(rewriter, loc, adaptor.getLhs());
652  Value rhsElem0 = extractLastDimSlice(rewriter, loc, adaptor.getRhs(), 0);
653 
654  // Assume that the shift amount is < 2 * newBitWidth. Calculate the low and
655  // high halves of the results separately:
656  // 1. low := LHS.low shli RHS
657  //
658  // 2. high := a or b or c, where:
659  // a) Bits from LHS.high, shifted by the RHS.
660  // b) Bits from LHS.low, shifted right. These come into play when
661  // RHS < newBitWidth, e.g.:
662  // [0000][llll] shli 3 --> [0lll][l000]
663  // ^
664  // |
665  // [llll] shrui (4 - 3)
666  // c) Bits from LHS.low, shifted left. These matter when
667  // RHS > newBitWidth, e.g.:
668  // [0000][llll] shli 7 --> [l000][0000]
669  // ^
670  // |
671  // [llll] shli (7 - 4)
672  //
673  // Because shifts by values >= newBitWidth are undefined, we ignore the high
674  // half of RHS, and introduce 'bounds checks' to account for
675  // RHS.low > newBitWidth.
676  //
677  // TODO: Explore possible optimizations.
678  Value zeroCst = createScalarOrSplatConstant(rewriter, loc, newOperandTy, 0);
679  Value elemBitWidth =
680  createScalarOrSplatConstant(rewriter, loc, newOperandTy, newBitWidth);
681 
682  Value illegalElemShift = rewriter.create<arith::CmpIOp>(
683  loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth);
684 
685  Value shiftedElem0 =
686  rewriter.create<arith::ShLIOp>(loc, lhsElem0, rhsElem0);
687  Value resElem0 = rewriter.create<arith::SelectOp>(loc, illegalElemShift,
688  zeroCst, shiftedElem0);
689 
690  Value cappedShiftAmount = rewriter.create<arith::SelectOp>(
691  loc, illegalElemShift, elemBitWidth, rhsElem0);
692  Value rightShiftAmount =
693  rewriter.create<arith::SubIOp>(loc, elemBitWidth, cappedShiftAmount);
694  Value shiftedRight =
695  rewriter.create<arith::ShRUIOp>(loc, lhsElem0, rightShiftAmount);
696  Value overshotShiftAmount =
697  rewriter.create<arith::SubIOp>(loc, rhsElem0, elemBitWidth);
698  Value shiftedLeft =
699  rewriter.create<arith::ShLIOp>(loc, lhsElem0, overshotShiftAmount);
700 
701  Value shiftedElem1 =
702  rewriter.create<arith::ShLIOp>(loc, lhsElem1, rhsElem0);
703  Value resElem1High = rewriter.create<arith::SelectOp>(
704  loc, illegalElemShift, zeroCst, shiftedElem1);
705  Value resElem1Low = rewriter.create<arith::SelectOp>(
706  loc, illegalElemShift, shiftedLeft, shiftedRight);
707  Value resElem1 =
708  rewriter.create<arith::OrIOp>(loc, resElem1Low, resElem1High);
709 
710  Value resultVec =
711  constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
712  rewriter.replaceOp(op, resultVec);
713  return success();
714  }
715 };
716 
717 //===----------------------------------------------------------------------===//
718 // ConvertShRUI
719 //===----------------------------------------------------------------------===//
720 
721 struct ConvertShRUI final : OpConversionPattern<arith::ShRUIOp> {
723 
724  LogicalResult
725  matchAndRewrite(arith::ShRUIOp op, OpAdaptor adaptor,
726  ConversionPatternRewriter &rewriter) const override {
727  Location loc = op->getLoc();
728 
729  Type oldTy = op.getType();
730  auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
731  if (!newTy)
732  return rewriter.notifyMatchFailure(
733  loc, llvm::formatv("unsupported type: {0}", op.getType()));
734 
735  Type newOperandTy = reduceInnermostDim(newTy);
736  // `oldBitWidth` == `2 * newBitWidth`
737  unsigned newBitWidth = newTy.getElementTypeBitWidth();
738 
739  auto [lhsElem0, lhsElem1] =
740  extractLastDimHalves(rewriter, loc, adaptor.getLhs());
741  Value rhsElem0 = extractLastDimSlice(rewriter, loc, adaptor.getRhs(), 0);
742 
743  // Assume that the shift amount is < 2 * newBitWidth. Calculate the low and
744  // high halves of the results separately:
745  // 1. low := a or b or c, where:
746  // a) Bits from LHS.low, shifted by the RHS.
747  // b) Bits from LHS.high, shifted left. These matter when
748  // RHS < newBitWidth, e.g.:
749  // [hhhh][0000] shrui 3 --> [000h][hhh0]
750  // ^
751  // |
752  // [hhhh] shli (4 - 1)
753  // c) Bits from LHS.high, shifted right. These come into play when
754  // RHS > newBitWidth, e.g.:
755  // [hhhh][0000] shrui 7 --> [0000][000h]
756  // ^
757  // |
758  // [hhhh] shrui (7 - 4)
759  //
760  // 2. high := LHS.high shrui RHS
761  //
762  // Because shifts by values >= newBitWidth are undefined, we ignore the high
763  // half of RHS, and introduce 'bounds checks' to account for
764  // RHS.low > newBitWidth.
765  //
766  // TODO: Explore possible optimizations.
767  Value zeroCst = createScalarOrSplatConstant(rewriter, loc, newOperandTy, 0);
768  Value elemBitWidth =
769  createScalarOrSplatConstant(rewriter, loc, newOperandTy, newBitWidth);
770 
771  Value illegalElemShift = rewriter.create<arith::CmpIOp>(
772  loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth);
773 
774  Value shiftedElem0 =
775  rewriter.create<arith::ShRUIOp>(loc, lhsElem0, rhsElem0);
776  Value resElem0Low = rewriter.create<arith::SelectOp>(loc, illegalElemShift,
777  zeroCst, shiftedElem0);
778  Value shiftedElem1 =
779  rewriter.create<arith::ShRUIOp>(loc, lhsElem1, rhsElem0);
780  Value resElem1 = rewriter.create<arith::SelectOp>(loc, illegalElemShift,
781  zeroCst, shiftedElem1);
782 
783  Value cappedShiftAmount = rewriter.create<arith::SelectOp>(
784  loc, illegalElemShift, elemBitWidth, rhsElem0);
785  Value leftShiftAmount =
786  rewriter.create<arith::SubIOp>(loc, elemBitWidth, cappedShiftAmount);
787  Value shiftedLeft =
788  rewriter.create<arith::ShLIOp>(loc, lhsElem1, leftShiftAmount);
789  Value overshotShiftAmount =
790  rewriter.create<arith::SubIOp>(loc, rhsElem0, elemBitWidth);
791  Value shiftedRight =
792  rewriter.create<arith::ShRUIOp>(loc, lhsElem1, overshotShiftAmount);
793 
794  Value resElem0High = rewriter.create<arith::SelectOp>(
795  loc, illegalElemShift, shiftedRight, shiftedLeft);
796  Value resElem0 =
797  rewriter.create<arith::OrIOp>(loc, resElem0Low, resElem0High);
798 
799  Value resultVec =
800  constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
801  rewriter.replaceOp(op, resultVec);
802  return success();
803  }
804 };
805 
806 //===----------------------------------------------------------------------===//
807 // ConvertShRSI
808 //===----------------------------------------------------------------------===//
809 
810 struct ConvertShRSI final : OpConversionPattern<arith::ShRSIOp> {
812 
813  LogicalResult
814  matchAndRewrite(arith::ShRSIOp op, OpAdaptor adaptor,
815  ConversionPatternRewriter &rewriter) const override {
816  Location loc = op->getLoc();
817 
818  Type oldTy = op.getType();
819  auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
820  if (!newTy)
821  return rewriter.notifyMatchFailure(
822  loc, llvm::formatv("unsupported type: {0}", op.getType()));
823 
824  Value lhsElem1 = extractLastDimSlice(rewriter, loc, adaptor.getLhs(), 1);
825  Value rhsElem0 = extractLastDimSlice(rewriter, loc, adaptor.getRhs(), 0);
826 
827  Type narrowTy = rhsElem0.getType();
828  int64_t origBitwidth = newTy.getElementTypeBitWidth() * 2;
829 
830  // Rewrite this as an bitwise or of `arith.shrui` and sign extension bits.
831  // Perform as many ops over the narrow integer type as possible and let the
832  // other emulation patterns convert the rest.
833  Value elemZero = createScalarOrSplatConstant(rewriter, loc, narrowTy, 0);
834  Value signBit = rewriter.create<arith::CmpIOp>(
835  loc, arith::CmpIPredicate::slt, lhsElem1, elemZero);
836  signBit = dropTrailingX1Dim(rewriter, loc, signBit);
837 
838  // Create a bit pattern of either all ones or all zeros. Then shift it left
839  // to calculate the sign extension bits created by shifting the original
840  // sign bit right.
841  Value allSign = rewriter.create<arith::ExtSIOp>(loc, oldTy, signBit);
842  Value maxShift =
843  createScalarOrSplatConstant(rewriter, loc, narrowTy, origBitwidth);
844  Value numNonSignExtBits =
845  rewriter.create<arith::SubIOp>(loc, maxShift, rhsElem0);
846  numNonSignExtBits = dropTrailingX1Dim(rewriter, loc, numNonSignExtBits);
847  numNonSignExtBits =
848  rewriter.create<arith::ExtUIOp>(loc, oldTy, numNonSignExtBits);
849  Value signBits =
850  rewriter.create<arith::ShLIOp>(loc, allSign, numNonSignExtBits);
851 
852  // Use original arguments to create the right shift.
853  Value shrui =
854  rewriter.create<arith::ShRUIOp>(loc, op.getLhs(), op.getRhs());
855  Value shrsi = rewriter.create<arith::OrIOp>(loc, shrui, signBits);
856 
857  // Handle shifting by zero. This is necessary when the `signBits` shift is
858  // invalid.
859  Value isNoop = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
860  rhsElem0, elemZero);
861  isNoop = dropTrailingX1Dim(rewriter, loc, isNoop);
862  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNoop, op.getLhs(),
863  shrsi);
864 
865  return success();
866  }
867 };
868 
869 //===----------------------------------------------------------------------===//
870 // ConvertSIToFP
871 //===----------------------------------------------------------------------===//
872 
873 struct ConvertSIToFP final : OpConversionPattern<arith::SIToFPOp> {
875 
876  LogicalResult
877  matchAndRewrite(arith::SIToFPOp op, OpAdaptor adaptor,
878  ConversionPatternRewriter &rewriter) const override {
879  Location loc = op.getLoc();
880 
881  Value in = op.getIn();
882  Type oldTy = in.getType();
883  auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
884  if (!newTy)
885  return rewriter.notifyMatchFailure(
886  loc, llvm::formatv("unsupported type: {0}", oldTy));
887 
888  unsigned oldBitWidth = getElementTypeOrSelf(oldTy).getIntOrFloatBitWidth();
889  Value zeroCst = createScalarOrSplatConstant(rewriter, loc, oldTy, 0);
890  Value oneCst = createScalarOrSplatConstant(rewriter, loc, oldTy, 1);
891  Value allOnesCst = createScalarOrSplatConstant(
892  rewriter, loc, oldTy, APInt::getAllOnes(oldBitWidth));
893 
894  // To avoid operating on very large unsigned numbers, perform the
895  // conversion on the absolute value. Then, decide whether to negate the
896  // result or not based on that sign bit. We assume two's complement and
897  // implement negation by flipping all bits and adding 1.
898  // Note that this relies on the the other conversion patterns to legalize
899  // created ops and narrow the bit widths.
900  Value isNeg = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
901  in, zeroCst);
902  Value bitwiseNeg = rewriter.create<arith::XOrIOp>(loc, in, allOnesCst);
903  Value neg = rewriter.create<arith::AddIOp>(loc, bitwiseNeg, oneCst);
904  Value abs = rewriter.create<arith::SelectOp>(loc, isNeg, neg, in);
905 
906  Value absResult = rewriter.create<arith::UIToFPOp>(loc, op.getType(), abs);
907  Value negResult = rewriter.create<arith::NegFOp>(loc, absResult);
908  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNeg, negResult,
909  absResult);
910  return success();
911  }
912 };
913 
914 //===----------------------------------------------------------------------===//
915 // ConvertUIToFP
916 //===----------------------------------------------------------------------===//
917 
918 struct ConvertUIToFP final : OpConversionPattern<arith::UIToFPOp> {
920 
921  LogicalResult
922  matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
923  ConversionPatternRewriter &rewriter) const override {
924  Location loc = op.getLoc();
925 
926  Type oldTy = op.getIn().getType();
927  auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
928  if (!newTy)
929  return rewriter.notifyMatchFailure(
930  loc, llvm::formatv("unsupported type: {0}", oldTy));
931  unsigned newBitWidth = newTy.getElementTypeBitWidth();
932 
933  auto [low, hi] = extractLastDimHalves(rewriter, loc, adaptor.getIn());
934  Value lowInt = dropTrailingX1Dim(rewriter, loc, low);
935  Value hiInt = dropTrailingX1Dim(rewriter, loc, hi);
936  Value zeroCst =
937  createScalarOrSplatConstant(rewriter, loc, hiInt.getType(), 0);
938 
939  // The final result has the following form:
940  // if (hi == 0) return uitofp(low)
941  // else return uitofp(low) + uitofp(hi) * 2^BW
942  //
943  // where `BW` is the bitwidth of the narrowed integer type. We emit a
944  // select to make it easier to fold-away the `hi` part calculation when it
945  // is known to be zero.
946  //
947  // Note 1: The emulation is precise only for input values that have exact
948  // integer representation in the result floating point type, and may lead
949  // loss of precision otherwise.
950  //
951  // Note 2: We do not strictly need the `hi == 0`, case, but it makes
952  // constant folding easier.
953  Value hiEqZero = rewriter.create<arith::CmpIOp>(
954  loc, arith::CmpIPredicate::eq, hiInt, zeroCst);
955 
956  Type resultTy = op.getType();
957  Type resultElemTy = getElementTypeOrSelf(resultTy);
958  Value lowFp = rewriter.create<arith::UIToFPOp>(loc, resultTy, lowInt);
959  Value hiFp = rewriter.create<arith::UIToFPOp>(loc, resultTy, hiInt);
960 
961  int64_t pow2Int = int64_t(1) << newBitWidth;
962  TypedAttr pow2Attr =
963  rewriter.getFloatAttr(resultElemTy, static_cast<double>(pow2Int));
964  if (auto vecTy = dyn_cast<VectorType>(resultTy))
965  pow2Attr = SplatElementsAttr::get(vecTy, pow2Attr);
966 
967  Value pow2Val = rewriter.create<arith::ConstantOp>(loc, resultTy, pow2Attr);
968 
969  Value hiVal = rewriter.create<arith::MulFOp>(loc, hiFp, pow2Val);
970  Value result = rewriter.create<arith::AddFOp>(loc, lowFp, hiVal);
971 
972  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, hiEqZero, lowFp, result);
973  return success();
974  }
975 };
976 
977 //===----------------------------------------------------------------------===//
978 // ConvertTruncI
979 //===----------------------------------------------------------------------===//
980 
981 struct ConvertTruncI final : OpConversionPattern<arith::TruncIOp> {
983 
984  LogicalResult
985  matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
986  ConversionPatternRewriter &rewriter) const override {
987  Location loc = op.getLoc();
988  // Check if the result type is legal for this target. Currently, we do not
989  // support truncation to types wider than supported by the target.
990  if (!getTypeConverter()->isLegal(op.getType()))
991  return rewriter.notifyMatchFailure(
992  loc, llvm::formatv("unsupported truncation result type: {0}",
993  op.getType()));
994 
995  // Discard the high half of the input. Truncate the low half, if
996  // necessary.
997  Value extracted = extractLastDimSlice(rewriter, loc, adaptor.getIn(), 0);
998  extracted = dropTrailingX1Dim(rewriter, loc, extracted);
999  Value truncated =
1000  rewriter.createOrFold<arith::TruncIOp>(loc, op.getType(), extracted);
1001  rewriter.replaceOp(op, truncated);
1002  return success();
1003  }
1004 };
1005 
1006 //===----------------------------------------------------------------------===//
1007 // ConvertVectorPrint
1008 //===----------------------------------------------------------------------===//
1009 
1010 struct ConvertVectorPrint final : OpConversionPattern<vector::PrintOp> {
1012 
1013  LogicalResult
1014  matchAndRewrite(vector::PrintOp op, OpAdaptor adaptor,
1015  ConversionPatternRewriter &rewriter) const override {
1016  rewriter.replaceOpWithNewOp<vector::PrintOp>(op, adaptor.getSource());
1017  return success();
1018  }
1019 };
1020 
1021 //===----------------------------------------------------------------------===//
1022 // Pass Definition
1023 //===----------------------------------------------------------------------===//
1024 
1025 struct EmulateWideIntPass final
1026  : arith::impl::ArithEmulateWideIntBase<EmulateWideIntPass> {
1027  using ArithEmulateWideIntBase::ArithEmulateWideIntBase;
1028 
1029  void runOnOperation() override {
1030  if (!llvm::isPowerOf2_32(widestIntSupported) || widestIntSupported < 2) {
1031  signalPassFailure();
1032  return;
1033  }
1034 
1035  Operation *op = getOperation();
1036  MLIRContext *ctx = op->getContext();
1037 
1038  arith::WideIntEmulationConverter typeConverter(widestIntSupported);
1039  ConversionTarget target(*ctx);
1040  target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](Operation *op) {
1041  return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType());
1042  });
1043  auto opLegalCallback = [&typeConverter](Operation *op) {
1044  return typeConverter.isLegal(op);
1045  };
1046  target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(opLegalCallback);
1047  target
1048  .addDynamicallyLegalDialect<arith::ArithDialect, vector::VectorDialect>(
1049  opLegalCallback);
1050 
1051  RewritePatternSet patterns(ctx);
1052  arith::populateArithWideIntEmulationPatterns(typeConverter, patterns);
1053 
1054  if (failed(applyPartialConversion(op, target, std::move(patterns))))
1055  signalPassFailure();
1056  }
1057 };
1058 } // end anonymous namespace
1059 
1060 //===----------------------------------------------------------------------===//
1061 // Public Interface Definition
1062 //===----------------------------------------------------------------------===//
1063 
1065  unsigned widestIntSupportedByTarget)
1066  : maxIntWidth(widestIntSupportedByTarget) {
1067  assert(llvm::isPowerOf2_32(widestIntSupportedByTarget) &&
1068  "Only power-of-two integers with are supported");
1069  assert(widestIntSupportedByTarget >= 2 && "Integer type too narrow");
1070 
1071  // Allow unknown types.
1072  addConversion([](Type ty) -> std::optional<Type> { return ty; });
1073 
1074  // Scalar case.
1075  addConversion([this](IntegerType ty) -> std::optional<Type> {
1076  unsigned width = ty.getWidth();
1077  if (width <= maxIntWidth)
1078  return ty;
1079 
1080  // i2N --> vector<2xiN>
1081  if (width == 2 * maxIntWidth)
1082  return VectorType::get(2, IntegerType::get(ty.getContext(), maxIntWidth));
1083 
1084  return std::nullopt;
1085  });
1086 
1087  // Vector case.
1088  addConversion([this](VectorType ty) -> std::optional<Type> {
1089  auto intTy = dyn_cast<IntegerType>(ty.getElementType());
1090  if (!intTy)
1091  return ty;
1092 
1093  unsigned width = intTy.getWidth();
1094  if (width <= maxIntWidth)
1095  return ty;
1096 
1097  // vector<...xi2N> --> vector<...x2xiN>
1098  if (width == 2 * maxIntWidth) {
1099  auto newShape = to_vector(ty.getShape());
1100  newShape.push_back(2);
1101  return VectorType::get(newShape,
1102  IntegerType::get(ty.getContext(), maxIntWidth));
1103  }
1104 
1105  return std::nullopt;
1106  });
1107 
1108  // Function case.
1109  addConversion([this](FunctionType ty) -> std::optional<Type> {
1110  // Convert inputs and results, e.g.:
1111  // (i2N, i2N) -> i2N --> (vector<2xiN>, vector<2xiN>) -> vector<2xiN>
1112  SmallVector<Type> inputs;
1113  if (failed(convertTypes(ty.getInputs(), inputs)))
1114  return std::nullopt;
1115 
1116  SmallVector<Type> results;
1117  if (failed(convertTypes(ty.getResults(), results)))
1118  return std::nullopt;
1119 
1120  return FunctionType::get(ty.getContext(), inputs, results);
1121  });
1122 }
1123 
1125  WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns) {
1126  // Populate `func.*` conversion patterns.
1127  populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
1128  typeConverter);
1129  populateCallOpTypeConversionPattern(patterns, typeConverter);
1130  populateReturnOpTypeConversionPattern(patterns, typeConverter);
1131 
1132  // Populate `arith.*` conversion patterns.
1133  patterns.add<
1134  // Misc ops.
1135  ConvertConstant, ConvertCmpI, ConvertSelect, ConvertVectorPrint,
1136  // Binary ops.
1137  ConvertAddI, ConvertMulI, ConvertShLI, ConvertShRSI, ConvertShRUI,
1138  ConvertMaxMin<arith::MaxUIOp, arith::CmpIPredicate::ugt>,
1139  ConvertMaxMin<arith::MaxSIOp, arith::CmpIPredicate::sgt>,
1140  ConvertMaxMin<arith::MinUIOp, arith::CmpIPredicate::ult>,
1141  ConvertMaxMin<arith::MinSIOp, arith::CmpIPredicate::slt>,
1142  // Bitwise binary ops.
1143  ConvertBitwiseBinary<arith::AndIOp>, ConvertBitwiseBinary<arith::OrIOp>,
1144  ConvertBitwiseBinary<arith::XOrIOp>,
1145  // Extension and truncation ops.
1146  ConvertExtSI, ConvertExtUI, ConvertTruncI,
1147  // Cast ops.
1148  ConvertIndexCastIntToIndex<arith::IndexCastOp>,
1149  ConvertIndexCastIntToIndex<arith::IndexCastUIOp>,
1150  ConvertIndexCastIndexToInt<arith::IndexCastOp, arith::ExtSIOp>,
1151  ConvertIndexCastIndexToInt<arith::IndexCastUIOp, arith::ExtUIOp>,
1152  ConvertSIToFP, ConvertUIToFP>(typeConverter, patterns.getContext());
1153 }
static Value insertLastDimSlice(ConversionPatternRewriter &rewriter, Location loc, Value source, Value dest, int64_t lastOffset)
Inserts the source vector slice into the dest vector at offset lastOffset in the last dimension.
static std::pair< APInt, APInt > getHalves(const APInt &value, unsigned newBitWidth)
Returns N bottom and N top bits from value, where N = newBitWidth.
static Value appendX1Dim(ConversionPatternRewriter &rewriter, Location loc, Value input)
Performs a vector shape cast to append an x1 dimension.
static std::pair< Value, Value > extractLastDimHalves(ConversionPatternRewriter &rewriter, Location loc, Value input)
Extracts two vector slices from the input whose type is vector<...x2T>, with the first element at offset 0 and the second element at offset 1.
static Type reduceInnermostDim(VectorType type)
Returns the type with the last (innermost) dimension reduced to x1.
static Value constructResultVector(ConversionPatternRewriter &rewriter, Location loc, VectorType resultType, ValueRange resultComponents)
Constructs a new vector of type resultType by creating a series of insertions of resultComponents, each at the next offset of the last vector dimension.
static Value dropTrailingX1Dim(ConversionPatternRewriter &rewriter, Location loc, Value input)
static Value extractLastDimSlice(ConversionPatternRewriter &rewriter, Location loc, Value input, int64_t lastOffset)
Extracts the input vector slice with elements at the last dimension offset by lastOffset.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:273
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:99
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
This class describes a specific conversion target.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:525
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:476
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
typename SourceOp::Adaptor OpAdaptor
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:128
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Converts integer types that are too wide for the target by splitting them in two halves and thus turn...
WideIntEmulationConverter(unsigned widestIntSupportedByTarget)
void populateArithWideIntEmulationPatterns(WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns)
Adds patterns to emulate wide Arith and Function ops over integer types into supported ones.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Fraction abs(const Fraction &f)
Definition: Fraction.h:107
Include the generated interface declarations.
Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, const APInt &value)
Create a constant of type type at location loc whose value is value (an APInt or APFloat whose type m...
Definition: Utils.cpp:271
void populateCallOpTypeConversionPattern(RewritePatternSet &patterns, TypeConverter &converter)
Add a pattern to the given pattern list to convert the operand and result types of a CallOp with the ...
void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, TypeConverter &converter)
Add a pattern to the given pattern list to rewrite return ops to use operands that have been legalize...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.