MLIR  18.0.0git
EmulateWideInt.cpp
Go to the documentation of this file.
1 //===- EmulateWideInt.cpp - Wide integer operation emulation ----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/TypeUtilities.h"
20 #include "llvm/ADT/APInt.h"
21 #include "llvm/Support/FormatVariadic.h"
22 #include "llvm/Support/MathExtras.h"
23 #include <cassert>
24 
25 namespace mlir::arith {
26 #define GEN_PASS_DEF_ARITHEMULATEWIDEINT
27 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
28 } // namespace mlir::arith
29 
30 using namespace mlir;
31 
32 //===----------------------------------------------------------------------===//
33 // Common Helper Functions
34 //===----------------------------------------------------------------------===//
35 
36 /// Returns N bottom and N top bits from `value`, where N = `newBitWidth`.
37 /// Treats `value` as a 2*N bits-wide integer.
38 /// The bottom bits are returned in the first pair element, while the top bits
39 /// in the second one.
40 static std::pair<APInt, APInt> getHalves(const APInt &value,
41  unsigned newBitWidth) {
42  APInt low = value.extractBits(newBitWidth, 0);
43  APInt high = value.extractBits(newBitWidth, newBitWidth);
44  return {std::move(low), std::move(high)};
45 }
46 
47 /// Returns the type with the last (innermost) dimension reduced to x1.
48 /// Scalarizes 1D vector inputs to match how we extract/insert vector values,
49 /// e.g.:
50 /// - vector<3x2xi16> --> vector<3x1xi16>
51 /// - vector<2xi16> --> i16
52 static Type reduceInnermostDim(VectorType type) {
53  if (type.getShape().size() == 1)
54  return type.getElementType();
55 
56  auto newShape = to_vector(type.getShape());
57  newShape.back() = 1;
58  return VectorType::get(newShape, type.getElementType());
59 }
60 
61 /// Returns a constant of integer of vector type filled with (repeated) `value`.
63  Location loc, Type type,
64  const APInt &value) {
65  TypedAttr attr;
66  if (dyn_cast<IntegerType>(type)) {
67  attr = rewriter.getIntegerAttr(type, value);
68  } else {
69  auto vecTy = cast<VectorType>(type);
70  attr = SplatElementsAttr::get(vecTy, value);
71  }
72 
73  return rewriter.create<arith::ConstantOp>(loc, attr);
74 }
75 
76 /// Returns a constant of integer of vector type filled with (repeated) `value`.
78  Location loc, Type type,
79  int64_t value) {
80  unsigned elementBitWidth = 0;
81  if (auto intTy = dyn_cast<IntegerType>(type))
82  elementBitWidth = intTy.getWidth();
83  else
84  elementBitWidth = cast<VectorType>(type).getElementTypeBitWidth();
85 
86  return createScalarOrSplatConstant(rewriter, loc, type,
87  APInt(elementBitWidth, value));
88 }
89 
90 /// Extracts the `input` vector slice with elements at the last dimension offset
91 /// by `lastOffset`. Returns a value of vector type with the last dimension
92 /// reduced to x1 or fully scalarized, e.g.:
93 /// - vector<3x2xi16> --> vector<3x1xi16>
94 /// - vector<2xi16> --> i16
96  Location loc, Value input,
97  int64_t lastOffset) {
98  ArrayRef<int64_t> shape = cast<VectorType>(input.getType()).getShape();
99  assert(lastOffset < shape.back() && "Offset out of bounds");
100 
101  // Scalarize the result in case of 1D vectors.
102  if (shape.size() == 1)
103  return rewriter.create<vector::ExtractOp>(loc, input, lastOffset);
104 
105  SmallVector<int64_t> offsets(shape.size(), 0);
106  offsets.back() = lastOffset;
107  auto sizes = llvm::to_vector(shape);
108  sizes.back() = 1;
109  SmallVector<int64_t> strides(shape.size(), 1);
110 
111  return rewriter.create<vector::ExtractStridedSliceOp>(loc, input, offsets,
112  sizes, strides);
113 }
114 
115 /// Extracts two vector slices from the `input` whose type is `vector<...x2T>`,
116 /// with the first element at offset 0 and the second element at offset 1.
117 static std::pair<Value, Value>
119  Value input) {
120  return {extractLastDimSlice(rewriter, loc, input, 0),
121  extractLastDimSlice(rewriter, loc, input, 1)};
122 }
123 
124 // Performs a vector shape cast to drop the trailing x1 dimension. If the
125 // `input` is a scalar, this is a noop.
127  Location loc, Value input) {
128  auto vecTy = dyn_cast<VectorType>(input.getType());
129  if (!vecTy)
130  return input;
131 
132  // Shape cast to drop the last x1 dimension.
133  ArrayRef<int64_t> shape = vecTy.getShape();
134  assert(shape.size() >= 2 && "Expected vector with at list two dims");
135  assert(shape.back() == 1 && "Expected the last vector dim to be x1");
136 
137  auto newVecTy = VectorType::get(shape.drop_back(), vecTy.getElementType());
138  return rewriter.create<vector::ShapeCastOp>(loc, newVecTy, input);
139 }
140 
141 /// Performs a vector shape cast to append an x1 dimension. If the
142 /// `input` is a scalar, this is a noop.
144  Value input) {
145  auto vecTy = dyn_cast<VectorType>(input.getType());
146  if (!vecTy)
147  return input;
148 
149  // Add a trailing x1 dim.
150  auto newShape = llvm::to_vector(vecTy.getShape());
151  newShape.push_back(1);
152  auto newTy = VectorType::get(newShape, vecTy.getElementType());
153  return rewriter.create<vector::ShapeCastOp>(loc, newTy, input);
154 }
155 
156 /// Inserts the `source` vector slice into the `dest` vector at offset
157 /// `lastOffset` in the last dimension. `source` can be a scalar when `dest` is
158 /// a 1D vector.
160  Location loc, Value source, Value dest,
161  int64_t lastOffset) {
162  ArrayRef<int64_t> shape = cast<VectorType>(dest.getType()).getShape();
163  assert(lastOffset < shape.back() && "Offset out of bounds");
164 
165  // Handle scalar source.
166  if (isa<IntegerType>(source.getType()))
167  return rewriter.create<vector::InsertOp>(loc, source, dest, lastOffset);
168 
169  SmallVector<int64_t> offsets(shape.size(), 0);
170  offsets.back() = lastOffset;
171  SmallVector<int64_t> strides(shape.size(), 1);
172  return rewriter.create<vector::InsertStridedSliceOp>(loc, source, dest,
173  offsets, strides);
174 }
175 
176 /// Constructs a new vector of type `resultType` by creating a series of
177 /// insertions of `resultComponents`, each at the next offset of the last vector
178 /// dimension.
179 /// When all `resultComponents` are scalars, the result type is `vector<NxT>`;
180 /// when `resultComponents` are `vector<...x1xT>`s, the result type is
181 /// `vector<...xNxT>`, where `N` is the number of `resultComponents`.
183  Location loc, VectorType resultType,
184  ValueRange resultComponents) {
185  llvm::ArrayRef<int64_t> resultShape = resultType.getShape();
186  (void)resultShape;
187  assert(!resultShape.empty() && "Result expected to have dimensions");
188  assert(resultShape.back() == static_cast<int64_t>(resultComponents.size()) &&
189  "Wrong number of result components");
190 
191  Value resultVec = createScalarOrSplatConstant(rewriter, loc, resultType, 0);
192  for (auto [i, component] : llvm::enumerate(resultComponents))
193  resultVec = insertLastDimSlice(rewriter, loc, component, resultVec, i);
194 
195  return resultVec;
196 }
197 
198 namespace {
199 //===----------------------------------------------------------------------===//
200 // ConvertConstant
201 //===----------------------------------------------------------------------===//
202 
203 struct ConvertConstant final : OpConversionPattern<arith::ConstantOp> {
205 
207  matchAndRewrite(arith::ConstantOp op, OpAdaptor,
208  ConversionPatternRewriter &rewriter) const override {
209  Type oldType = op.getType();
210  auto newType = getTypeConverter()->convertType<VectorType>(oldType);
211  if (!newType)
212  return rewriter.notifyMatchFailure(
213  op, llvm::formatv("unsupported type: {0}", op.getType()));
214 
215  unsigned newBitWidth = newType.getElementTypeBitWidth();
216  Attribute oldValue = op.getValueAttr();
217 
218  if (auto intAttr = dyn_cast<IntegerAttr>(oldValue)) {
219  auto [low, high] = getHalves(intAttr.getValue(), newBitWidth);
220  auto newAttr = DenseElementsAttr::get(newType, {low, high});
221  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
222  return success();
223  }
224 
225  if (auto splatAttr = dyn_cast<SplatElementsAttr>(oldValue)) {
226  auto [low, high] =
227  getHalves(splatAttr.getSplatValue<APInt>(), newBitWidth);
228  int64_t numSplatElems = splatAttr.getNumElements();
229  SmallVector<APInt> values;
230  values.reserve(numSplatElems * 2);
231  for (int64_t i = 0; i < numSplatElems; ++i) {
232  values.push_back(low);
233  values.push_back(high);
234  }
235 
236  auto attr = DenseElementsAttr::get(newType, values);
237  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, attr);
238  return success();
239  }
240 
241  if (auto elemsAttr = dyn_cast<DenseElementsAttr>(oldValue)) {
242  int64_t numElems = elemsAttr.getNumElements();
243  SmallVector<APInt> values;
244  values.reserve(numElems * 2);
245  for (const APInt &origVal : elemsAttr.getValues<APInt>()) {
246  auto [low, high] = getHalves(origVal, newBitWidth);
247  values.push_back(std::move(low));
248  values.push_back(std::move(high));
249  }
250 
251  auto attr = DenseElementsAttr::get(newType, values);
252  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, attr);
253  return success();
254  }
255 
256  return rewriter.notifyMatchFailure(op.getLoc(),
257  "unhandled constant attribute");
258  }
259 };
260 
261 //===----------------------------------------------------------------------===//
262 // ConvertAddI
263 //===----------------------------------------------------------------------===//
264 
265 struct ConvertAddI final : OpConversionPattern<arith::AddIOp> {
267 
269  matchAndRewrite(arith::AddIOp op, OpAdaptor adaptor,
270  ConversionPatternRewriter &rewriter) const override {
271  Location loc = op->getLoc();
272  auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
273  if (!newTy)
274  return rewriter.notifyMatchFailure(
275  loc, llvm::formatv("unsupported type: {0}", op.getType()));
276 
277  Type newElemTy = reduceInnermostDim(newTy);
278 
279  auto [lhsElem0, lhsElem1] =
280  extractLastDimHalves(rewriter, loc, adaptor.getLhs());
281  auto [rhsElem0, rhsElem1] =
282  extractLastDimHalves(rewriter, loc, adaptor.getRhs());
283 
284  auto lowSum =
285  rewriter.create<arith::AddUIExtendedOp>(loc, lhsElem0, rhsElem0);
286  Value overflowVal =
287  rewriter.create<arith::ExtUIOp>(loc, newElemTy, lowSum.getOverflow());
288 
289  Value high0 = rewriter.create<arith::AddIOp>(loc, overflowVal, lhsElem1);
290  Value high = rewriter.create<arith::AddIOp>(loc, high0, rhsElem1);
291 
292  Value resultVec =
293  constructResultVector(rewriter, loc, newTy, {lowSum.getSum(), high});
294  rewriter.replaceOp(op, resultVec);
295  return success();
296  }
297 };
298 
299 //===----------------------------------------------------------------------===//
300 // ConvertBitwiseBinary
301 //===----------------------------------------------------------------------===//
302 
303 /// Conversion pattern template for bitwise binary ops, e.g., `arith.andi`.
304 template <typename BinaryOp>
305 struct ConvertBitwiseBinary final : OpConversionPattern<BinaryOp> {
307  using OpAdaptor = typename OpConversionPattern<BinaryOp>::OpAdaptor;
308 
310  matchAndRewrite(BinaryOp op, OpAdaptor adaptor,
311  ConversionPatternRewriter &rewriter) const override {
312  Location loc = op->getLoc();
313  auto newTy = this->getTypeConverter()->template convertType<VectorType>(
314  op.getType());
315  if (!newTy)
316  return rewriter.notifyMatchFailure(
317  loc, llvm::formatv("unsupported type: {0}", op.getType()));
318 
319  auto [lhsElem0, lhsElem1] =
320  extractLastDimHalves(rewriter, loc, adaptor.getLhs());
321  auto [rhsElem0, rhsElem1] =
322  extractLastDimHalves(rewriter, loc, adaptor.getRhs());
323 
324  Value resElem0 = rewriter.create<BinaryOp>(loc, lhsElem0, rhsElem0);
325  Value resElem1 = rewriter.create<BinaryOp>(loc, lhsElem1, rhsElem1);
326  Value resultVec =
327  constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
328  rewriter.replaceOp(op, resultVec);
329  return success();
330  }
331 };
332 
333 //===----------------------------------------------------------------------===//
334 // ConvertCmpI
335 //===----------------------------------------------------------------------===//
336 
337 /// Returns the matching unsigned version of the given predicate `pred`, or the
338 /// same predicate if `pred` is not a signed.
339 static arith::CmpIPredicate toUnsignedPredicate(arith::CmpIPredicate pred) {
340  using P = arith::CmpIPredicate;
341  switch (pred) {
342  case P::sge:
343  return P::uge;
344  case P::sgt:
345  return P::ugt;
346  case P::sle:
347  return P::ule;
348  case P::slt:
349  return P::ult;
350  default:
351  return pred;
352  }
353 }
354 
355 struct ConvertCmpI final : OpConversionPattern<arith::CmpIOp> {
357 
359  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
360  ConversionPatternRewriter &rewriter) const override {
361  Location loc = op->getLoc();
362  auto inputTy =
363  getTypeConverter()->convertType<VectorType>(op.getLhs().getType());
364  if (!inputTy)
365  return rewriter.notifyMatchFailure(
366  loc, llvm::formatv("unsupported type: {0}", op.getType()));
367 
368  arith::CmpIPredicate highPred = adaptor.getPredicate();
369  arith::CmpIPredicate lowPred = toUnsignedPredicate(highPred);
370 
371  auto [lhsElem0, lhsElem1] =
372  extractLastDimHalves(rewriter, loc, adaptor.getLhs());
373  auto [rhsElem0, rhsElem1] =
374  extractLastDimHalves(rewriter, loc, adaptor.getRhs());
375 
376  Value lowCmp =
377  rewriter.create<arith::CmpIOp>(loc, lowPred, lhsElem0, rhsElem0);
378  Value highCmp =
379  rewriter.create<arith::CmpIOp>(loc, highPred, lhsElem1, rhsElem1);
380 
381  Value cmpResult{};
382  switch (highPred) {
383  case arith::CmpIPredicate::eq: {
384  cmpResult = rewriter.create<arith::AndIOp>(loc, lowCmp, highCmp);
385  break;
386  }
387  case arith::CmpIPredicate::ne: {
388  cmpResult = rewriter.create<arith::OrIOp>(loc, lowCmp, highCmp);
389  break;
390  }
391  default: {
392  // Handle inequality checks.
393  Value highEq = rewriter.create<arith::CmpIOp>(
394  loc, arith::CmpIPredicate::eq, lhsElem1, rhsElem1);
395  cmpResult =
396  rewriter.create<arith::SelectOp>(loc, highEq, lowCmp, highCmp);
397  break;
398  }
399  }
400 
401  assert(cmpResult && "Unhandled case");
402  rewriter.replaceOp(op, dropTrailingX1Dim(rewriter, loc, cmpResult));
403  return success();
404  }
405 };
406 
407 //===----------------------------------------------------------------------===//
408 // ConvertMulI
409 //===----------------------------------------------------------------------===//
410 
411 struct ConvertMulI final : OpConversionPattern<arith::MulIOp> {
413 
415  matchAndRewrite(arith::MulIOp op, OpAdaptor adaptor,
416  ConversionPatternRewriter &rewriter) const override {
417  Location loc = op->getLoc();
418  auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
419  if (!newTy)
420  return rewriter.notifyMatchFailure(
421  loc, llvm::formatv("unsupported type: {0}", op.getType()));
422 
423  auto [lhsElem0, lhsElem1] =
424  extractLastDimHalves(rewriter, loc, adaptor.getLhs());
425  auto [rhsElem0, rhsElem1] =
426  extractLastDimHalves(rewriter, loc, adaptor.getRhs());
427 
428  // The multiplication algorithm used is the standard (long) multiplication.
429  // Multiplying two i2N integers produces (at most) an i4N result, but
430  // because the calculation of top i2N is not necessary, we omit it.
431  auto mulLowLow =
432  rewriter.create<arith::MulUIExtendedOp>(loc, lhsElem0, rhsElem0);
433  Value mulLowHi = rewriter.create<arith::MulIOp>(loc, lhsElem0, rhsElem1);
434  Value mulHiLow = rewriter.create<arith::MulIOp>(loc, lhsElem1, rhsElem0);
435 
436  Value resLow = mulLowLow.getLow();
437  Value resHi =
438  rewriter.create<arith::AddIOp>(loc, mulLowLow.getHigh(), mulLowHi);
439  resHi = rewriter.create<arith::AddIOp>(loc, resHi, mulHiLow);
440 
441  Value resultVec =
442  constructResultVector(rewriter, loc, newTy, {resLow, resHi});
443  rewriter.replaceOp(op, resultVec);
444  return success();
445  }
446 };
447 
448 //===----------------------------------------------------------------------===//
449 // ConvertExtSI
450 //===----------------------------------------------------------------------===//
451 
452 struct ConvertExtSI final : OpConversionPattern<arith::ExtSIOp> {
454 
456  matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
457  ConversionPatternRewriter &rewriter) const override {
458  Location loc = op->getLoc();
459  auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
460  if (!newTy)
461  return rewriter.notifyMatchFailure(
462  loc, llvm::formatv("unsupported type: {0}", op.getType()));
463 
464  Type newResultComponentTy = reduceInnermostDim(newTy);
465 
466  // Sign-extend the input value to determine the low half of the result.
467  // Then, check if the low half is negative, and sign-extend the comparison
468  // result to get the high half.
469  Value newOperand = appendX1Dim(rewriter, loc, adaptor.getIn());
470  Value extended = rewriter.createOrFold<arith::ExtSIOp>(
471  loc, newResultComponentTy, newOperand);
472  Value operandZeroCst =
473  createScalarOrSplatConstant(rewriter, loc, newResultComponentTy, 0);
474  Value signBit = rewriter.create<arith::CmpIOp>(
475  loc, arith::CmpIPredicate::slt, extended, operandZeroCst);
476  Value signValue =
477  rewriter.create<arith::ExtSIOp>(loc, newResultComponentTy, signBit);
478 
479  Value resultVec =
480  constructResultVector(rewriter, loc, newTy, {extended, signValue});
481  rewriter.replaceOp(op, resultVec);
482  return success();
483  }
484 };
485 
486 //===----------------------------------------------------------------------===//
487 // ConvertExtUI
488 //===----------------------------------------------------------------------===//
489 
490 struct ConvertExtUI final : OpConversionPattern<arith::ExtUIOp> {
492 
494  matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
495  ConversionPatternRewriter &rewriter) const override {
496  Location loc = op->getLoc();
497  auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
498  if (!newTy)
499  return rewriter.notifyMatchFailure(
500  loc, llvm::formatv("unsupported type: {0}", op.getType()));
501 
502  Type newResultComponentTy = reduceInnermostDim(newTy);
503 
504  // Zero-extend the input value to determine the low half of the result.
505  // The high half is always zero.
506  Value newOperand = appendX1Dim(rewriter, loc, adaptor.getIn());
507  Value extended = rewriter.createOrFold<arith::ExtUIOp>(
508  loc, newResultComponentTy, newOperand);
509  Value zeroCst = createScalarOrSplatConstant(rewriter, loc, newTy, 0);
510  Value newRes = insertLastDimSlice(rewriter, loc, extended, zeroCst, 0);
511  rewriter.replaceOp(op, newRes);
512  return success();
513  }
514 };
515 
516 //===----------------------------------------------------------------------===//
517 // ConvertMaxMin
518 //===----------------------------------------------------------------------===//
519 
520 template <typename SourceOp, arith::CmpIPredicate CmpPred>
521 struct ConvertMaxMin final : OpConversionPattern<SourceOp> {
523 
525  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
526  ConversionPatternRewriter &rewriter) const override {
527  Location loc = op->getLoc();
528 
529  Type oldTy = op.getType();
530  auto newTy = dyn_cast_or_null<VectorType>(
531  this->getTypeConverter()->convertType(oldTy));
532  if (!newTy)
533  return rewriter.notifyMatchFailure(
534  loc, llvm::formatv("unsupported type: {0}", op.getType()));
535 
536  // Rewrite Max*I/Min*I as compare and select over original operands. Let
537  // the CmpI and Select emulation patterns handle the final legalization.
538  Value cmp =
539  rewriter.create<arith::CmpIOp>(loc, CmpPred, op.getLhs(), op.getRhs());
540  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, op.getLhs(),
541  op.getRhs());
542  return success();
543  }
544 };
545 
546 // Convert IndexCast ops
547 //===----------------------------------------------------------------------===//
548 
549 /// Returns true iff the type is `index` or `vector<...index>`.
550 static bool isIndexOrIndexVector(Type type) {
551  if (isa<IndexType>(type))
552  return true;
553 
554  if (auto vectorTy = dyn_cast<VectorType>(type))
555  if (isa<IndexType>(vectorTy.getElementType()))
556  return true;
557 
558  return false;
559 }
560 
561 template <typename CastOp>
562 struct ConvertIndexCastIntToIndex final : OpConversionPattern<CastOp> {
564 
566  matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
567  ConversionPatternRewriter &rewriter) const override {
568  Type resultType = op.getType();
569  if (!isIndexOrIndexVector(resultType))
570  return failure();
571 
572  Location loc = op.getLoc();
573  Type inType = op.getIn().getType();
574  auto newInTy =
575  this->getTypeConverter()->template convertType<VectorType>(inType);
576  if (!newInTy)
577  return rewriter.notifyMatchFailure(
578  loc, llvm::formatv("unsupported type: {0}", inType));
579 
580  // Discard the high half of the input truncating the original value.
581  Value extracted = extractLastDimSlice(rewriter, loc, adaptor.getIn(), 0);
582  extracted = dropTrailingX1Dim(rewriter, loc, extracted);
583  rewriter.replaceOpWithNewOp<CastOp>(op, resultType, extracted);
584  return success();
585  }
586 };
587 
588 template <typename CastOp, typename ExtensionOp>
589 struct ConvertIndexCastIndexToInt final : OpConversionPattern<CastOp> {
591 
593  matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
594  ConversionPatternRewriter &rewriter) const override {
595  Type inType = op.getIn().getType();
596  if (!isIndexOrIndexVector(inType))
597  return failure();
598 
599  Location loc = op.getLoc();
600  auto *typeConverter =
601  this->template getTypeConverter<arith::WideIntEmulationConverter>();
602 
603  Type resultType = op.getType();
604  auto newTy = typeConverter->template convertType<VectorType>(resultType);
605  if (!newTy)
606  return rewriter.notifyMatchFailure(
607  loc, llvm::formatv("unsupported type: {0}", resultType));
608 
609  // Emit an index cast over the matching narrow type.
610  Type narrowTy =
611  rewriter.getIntegerType(typeConverter->getMaxTargetIntBitWidth());
612  if (auto vecTy = dyn_cast<VectorType>(resultType))
613  narrowTy = VectorType::get(vecTy.getShape(), narrowTy);
614 
615  // Sign or zero-extend the result. Let the matching conversion pattern
616  // legalize the extension op.
617  Value underlyingVal =
618  rewriter.create<CastOp>(loc, narrowTy, adaptor.getIn());
619  rewriter.replaceOpWithNewOp<ExtensionOp>(op, resultType, underlyingVal);
620  return success();
621  }
622 };
623 
624 //===----------------------------------------------------------------------===//
625 // ConvertSelect
626 //===----------------------------------------------------------------------===//
627 
628 struct ConvertSelect final : OpConversionPattern<arith::SelectOp> {
630 
632  matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
633  ConversionPatternRewriter &rewriter) const override {
634  Location loc = op->getLoc();
635  auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
636  if (!newTy)
637  return rewriter.notifyMatchFailure(
638  loc, llvm::formatv("unsupported type: {0}", op.getType()));
639 
640  auto [trueElem0, trueElem1] =
641  extractLastDimHalves(rewriter, loc, adaptor.getTrueValue());
642  auto [falseElem0, falseElem1] =
643  extractLastDimHalves(rewriter, loc, adaptor.getFalseValue());
644  Value cond = appendX1Dim(rewriter, loc, adaptor.getCondition());
645 
646  Value resElem0 =
647  rewriter.create<arith::SelectOp>(loc, cond, trueElem0, falseElem0);
648  Value resElem1 =
649  rewriter.create<arith::SelectOp>(loc, cond, trueElem1, falseElem1);
650  Value resultVec =
651  constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
652  rewriter.replaceOp(op, resultVec);
653  return success();
654  }
655 };
656 
657 //===----------------------------------------------------------------------===//
658 // ConvertShLI
659 //===----------------------------------------------------------------------===//
660 
661 struct ConvertShLI final : OpConversionPattern<arith::ShLIOp> {
663 
665  matchAndRewrite(arith::ShLIOp op, OpAdaptor adaptor,
666  ConversionPatternRewriter &rewriter) const override {
667  Location loc = op->getLoc();
668 
669  Type oldTy = op.getType();
670  auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
671  if (!newTy)
672  return rewriter.notifyMatchFailure(
673  loc, llvm::formatv("unsupported type: {0}", op.getType()));
674 
675  Type newOperandTy = reduceInnermostDim(newTy);
676  // `oldBitWidth` == `2 * newBitWidth`
677  unsigned newBitWidth = newTy.getElementTypeBitWidth();
678 
679  auto [lhsElem0, lhsElem1] =
680  extractLastDimHalves(rewriter, loc, adaptor.getLhs());
681  Value rhsElem0 = extractLastDimSlice(rewriter, loc, adaptor.getRhs(), 0);
682 
683  // Assume that the shift amount is < 2 * newBitWidth. Calculate the low and
684  // high halves of the results separately:
685  // 1. low := LHS.low shli RHS
686  //
687  // 2. high := a or b or c, where:
688  // a) Bits from LHS.high, shifted by the RHS.
689  // b) Bits from LHS.low, shifted right. These come into play when
690  // RHS < newBitWidth, e.g.:
691  // [0000][llll] shli 3 --> [0lll][l000]
692  // ^
693  // |
694  // [llll] shrui (4 - 3)
695  // c) Bits from LHS.low, shifted left. These matter when
696  // RHS > newBitWidth, e.g.:
697  // [0000][llll] shli 7 --> [l000][0000]
698  // ^
699  // |
700  // [llll] shli (7 - 4)
701  //
702  // Because shifts by values >= newBitWidth are undefined, we ignore the high
703  // half of RHS, and introduce 'bounds checks' to account for
704  // RHS.low > newBitWidth.
705  //
706  // TODO: Explore possible optimizations.
707  Value zeroCst = createScalarOrSplatConstant(rewriter, loc, newOperandTy, 0);
708  Value elemBitWidth =
709  createScalarOrSplatConstant(rewriter, loc, newOperandTy, newBitWidth);
710 
711  Value illegalElemShift = rewriter.create<arith::CmpIOp>(
712  loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth);
713 
714  Value shiftedElem0 =
715  rewriter.create<arith::ShLIOp>(loc, lhsElem0, rhsElem0);
716  Value resElem0 = rewriter.create<arith::SelectOp>(loc, illegalElemShift,
717  zeroCst, shiftedElem0);
718 
719  Value cappedShiftAmount = rewriter.create<arith::SelectOp>(
720  loc, illegalElemShift, elemBitWidth, rhsElem0);
721  Value rightShiftAmount =
722  rewriter.create<arith::SubIOp>(loc, elemBitWidth, cappedShiftAmount);
723  Value shiftedRight =
724  rewriter.create<arith::ShRUIOp>(loc, lhsElem0, rightShiftAmount);
725  Value overshotShiftAmount =
726  rewriter.create<arith::SubIOp>(loc, rhsElem0, elemBitWidth);
727  Value shiftedLeft =
728  rewriter.create<arith::ShLIOp>(loc, lhsElem0, overshotShiftAmount);
729 
730  Value shiftedElem1 =
731  rewriter.create<arith::ShLIOp>(loc, lhsElem1, rhsElem0);
732  Value resElem1High = rewriter.create<arith::SelectOp>(
733  loc, illegalElemShift, zeroCst, shiftedElem1);
734  Value resElem1Low = rewriter.create<arith::SelectOp>(
735  loc, illegalElemShift, shiftedLeft, shiftedRight);
736  Value resElem1 =
737  rewriter.create<arith::OrIOp>(loc, resElem1Low, resElem1High);
738 
739  Value resultVec =
740  constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
741  rewriter.replaceOp(op, resultVec);
742  return success();
743  }
744 };
745 
746 //===----------------------------------------------------------------------===//
747 // ConvertShRUI
748 //===----------------------------------------------------------------------===//
749 
750 struct ConvertShRUI final : OpConversionPattern<arith::ShRUIOp> {
752 
754  matchAndRewrite(arith::ShRUIOp op, OpAdaptor adaptor,
755  ConversionPatternRewriter &rewriter) const override {
756  Location loc = op->getLoc();
757 
758  Type oldTy = op.getType();
759  auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
760  if (!newTy)
761  return rewriter.notifyMatchFailure(
762  loc, llvm::formatv("unsupported type: {0}", op.getType()));
763 
764  Type newOperandTy = reduceInnermostDim(newTy);
765  // `oldBitWidth` == `2 * newBitWidth`
766  unsigned newBitWidth = newTy.getElementTypeBitWidth();
767 
768  auto [lhsElem0, lhsElem1] =
769  extractLastDimHalves(rewriter, loc, adaptor.getLhs());
770  Value rhsElem0 = extractLastDimSlice(rewriter, loc, adaptor.getRhs(), 0);
771 
772  // Assume that the shift amount is < 2 * newBitWidth. Calculate the low and
773  // high halves of the results separately:
774  // 1. low := a or b or c, where:
775  // a) Bits from LHS.low, shifted by the RHS.
776  // b) Bits from LHS.high, shifted left. These matter when
777  // RHS < newBitWidth, e.g.:
778  // [hhhh][0000] shrui 3 --> [000h][hhh0]
779  // ^
780  // |
781  // [hhhh] shli (4 - 1)
782  // c) Bits from LHS.high, shifted right. These come into play when
783  // RHS > newBitWidth, e.g.:
784  // [hhhh][0000] shrui 7 --> [0000][000h]
785  // ^
786  // |
787  // [hhhh] shrui (7 - 4)
788  //
789  // 2. high := LHS.high shrui RHS
790  //
791  // Because shifts by values >= newBitWidth are undefined, we ignore the high
792  // half of RHS, and introduce 'bounds checks' to account for
793  // RHS.low > newBitWidth.
794  //
795  // TODO: Explore possible optimizations.
796  Value zeroCst = createScalarOrSplatConstant(rewriter, loc, newOperandTy, 0);
797  Value elemBitWidth =
798  createScalarOrSplatConstant(rewriter, loc, newOperandTy, newBitWidth);
799 
800  Value illegalElemShift = rewriter.create<arith::CmpIOp>(
801  loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth);
802 
803  Value shiftedElem0 =
804  rewriter.create<arith::ShRUIOp>(loc, lhsElem0, rhsElem0);
805  Value resElem0Low = rewriter.create<arith::SelectOp>(loc, illegalElemShift,
806  zeroCst, shiftedElem0);
807  Value shiftedElem1 =
808  rewriter.create<arith::ShRUIOp>(loc, lhsElem1, rhsElem0);
809  Value resElem1 = rewriter.create<arith::SelectOp>(loc, illegalElemShift,
810  zeroCst, shiftedElem1);
811 
812  Value cappedShiftAmount = rewriter.create<arith::SelectOp>(
813  loc, illegalElemShift, elemBitWidth, rhsElem0);
814  Value leftShiftAmount =
815  rewriter.create<arith::SubIOp>(loc, elemBitWidth, cappedShiftAmount);
816  Value shiftedLeft =
817  rewriter.create<arith::ShLIOp>(loc, lhsElem1, leftShiftAmount);
818  Value overshotShiftAmount =
819  rewriter.create<arith::SubIOp>(loc, rhsElem0, elemBitWidth);
820  Value shiftedRight =
821  rewriter.create<arith::ShRUIOp>(loc, lhsElem1, overshotShiftAmount);
822 
823  Value resElem0High = rewriter.create<arith::SelectOp>(
824  loc, illegalElemShift, shiftedRight, shiftedLeft);
825  Value resElem0 =
826  rewriter.create<arith::OrIOp>(loc, resElem0Low, resElem0High);
827 
828  Value resultVec =
829  constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
830  rewriter.replaceOp(op, resultVec);
831  return success();
832  }
833 };
834 
835 //===----------------------------------------------------------------------===//
836 // ConvertShRSI
837 //===----------------------------------------------------------------------===//
838 
839 struct ConvertShRSI final : OpConversionPattern<arith::ShRSIOp> {
841 
843  matchAndRewrite(arith::ShRSIOp op, OpAdaptor adaptor,
844  ConversionPatternRewriter &rewriter) const override {
845  Location loc = op->getLoc();
846 
847  Type oldTy = op.getType();
848  auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
849  if (!newTy)
850  return rewriter.notifyMatchFailure(
851  loc, llvm::formatv("unsupported type: {0}", op.getType()));
852 
853  Value lhsElem1 = extractLastDimSlice(rewriter, loc, adaptor.getLhs(), 1);
854  Value rhsElem0 = extractLastDimSlice(rewriter, loc, adaptor.getRhs(), 0);
855 
856  Type narrowTy = rhsElem0.getType();
857  int64_t origBitwidth = newTy.getElementTypeBitWidth() * 2;
858 
859  // Rewrite this as an bitwise or of `arith.shrui` and sign extension bits.
860  // Perform as many ops over the narrow integer type as possible and let the
861  // other emulation patterns convert the rest.
862  Value elemZero = createScalarOrSplatConstant(rewriter, loc, narrowTy, 0);
863  Value signBit = rewriter.create<arith::CmpIOp>(
864  loc, arith::CmpIPredicate::slt, lhsElem1, elemZero);
865  signBit = dropTrailingX1Dim(rewriter, loc, signBit);
866 
867  // Create a bit pattern of either all ones or all zeros. Then shift it left
868  // to calculate the sign extension bits created by shifting the original
869  // sign bit right.
870  Value allSign = rewriter.create<arith::ExtSIOp>(loc, oldTy, signBit);
871  Value maxShift =
872  createScalarOrSplatConstant(rewriter, loc, narrowTy, origBitwidth);
873  Value numNonSignExtBits =
874  rewriter.create<arith::SubIOp>(loc, maxShift, rhsElem0);
875  numNonSignExtBits = dropTrailingX1Dim(rewriter, loc, numNonSignExtBits);
876  numNonSignExtBits =
877  rewriter.create<arith::ExtUIOp>(loc, oldTy, numNonSignExtBits);
878  Value signBits =
879  rewriter.create<arith::ShLIOp>(loc, allSign, numNonSignExtBits);
880 
881  // Use original arguments to create the right shift.
882  Value shrui =
883  rewriter.create<arith::ShRUIOp>(loc, op.getLhs(), op.getRhs());
884  Value shrsi = rewriter.create<arith::OrIOp>(loc, shrui, signBits);
885 
886  // Handle shifting by zero. This is necessary when the `signBits` shift is
887  // invalid.
888  Value isNoop = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
889  rhsElem0, elemZero);
890  isNoop = dropTrailingX1Dim(rewriter, loc, isNoop);
891  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNoop, op.getLhs(),
892  shrsi);
893 
894  return success();
895  }
896 };
897 
898 //===----------------------------------------------------------------------===//
899 // ConvertSIToFP
900 //===----------------------------------------------------------------------===//
901 
902 struct ConvertSIToFP final : OpConversionPattern<arith::SIToFPOp> {
904 
906  matchAndRewrite(arith::SIToFPOp op, OpAdaptor adaptor,
907  ConversionPatternRewriter &rewriter) const override {
908  Location loc = op.getLoc();
909 
910  Value in = op.getIn();
911  Type oldTy = in.getType();
912  auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
913  if (!newTy)
914  return rewriter.notifyMatchFailure(
915  loc, llvm::formatv("unsupported type: {0}", oldTy));
916 
917  unsigned oldBitWidth = getElementTypeOrSelf(oldTy).getIntOrFloatBitWidth();
918  Value zeroCst = createScalarOrSplatConstant(rewriter, loc, oldTy, 0);
919  Value oneCst = createScalarOrSplatConstant(rewriter, loc, oldTy, 1);
920  Value allOnesCst = createScalarOrSplatConstant(
921  rewriter, loc, oldTy, APInt::getAllOnes(oldBitWidth));
922 
923  // To avoid operating on very large unsigned numbers, perform the
924  // conversion on the absolute value. Then, decide whether to negate the
925  // result or not based on that sign bit. We assume two's complement and
926  // implement negation by flipping all bits and adding 1.
927  // Note that this relies on the the other conversion patterns to legalize
928  // created ops and narrow the bit widths.
929  Value isNeg = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
930  in, zeroCst);
931  Value bitwiseNeg = rewriter.create<arith::XOrIOp>(loc, in, allOnesCst);
932  Value neg = rewriter.create<arith::AddIOp>(loc, bitwiseNeg, oneCst);
933  Value abs = rewriter.create<arith::SelectOp>(loc, isNeg, neg, in);
934 
935  Value absResult = rewriter.create<arith::UIToFPOp>(loc, op.getType(), abs);
936  Value negResult = rewriter.create<arith::NegFOp>(loc, absResult);
937  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNeg, negResult,
938  absResult);
939  return success();
940  }
941 };
942 
943 //===----------------------------------------------------------------------===//
944 // ConvertUIToFP
945 //===----------------------------------------------------------------------===//
946 
947 struct ConvertUIToFP final : OpConversionPattern<arith::UIToFPOp> {
949 
951  matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
952  ConversionPatternRewriter &rewriter) const override {
953  Location loc = op.getLoc();
954 
955  Type oldTy = op.getIn().getType();
956  auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
957  if (!newTy)
958  return rewriter.notifyMatchFailure(
959  loc, llvm::formatv("unsupported type: {0}", oldTy));
960  unsigned newBitWidth = newTy.getElementTypeBitWidth();
961 
962  auto [low, hi] = extractLastDimHalves(rewriter, loc, adaptor.getIn());
963  Value lowInt = dropTrailingX1Dim(rewriter, loc, low);
964  Value hiInt = dropTrailingX1Dim(rewriter, loc, hi);
965  Value zeroCst =
966  createScalarOrSplatConstant(rewriter, loc, hiInt.getType(), 0);
967 
968  // The final result has the following form:
969  // if (hi == 0) return uitofp(low)
970  // else return uitofp(low) + uitofp(hi) * 2^BW
971  //
972  // where `BW` is the bitwidth of the narrowed integer type. We emit a
973  // select to make it easier to fold-away the `hi` part calculation when it
974  // is known to be zero.
975  //
976  // Note 1: The emulation is precise only for input values that have exact
977  // integer representation in the result floating point type, and may lead
978  // loss of precision otherwise.
979  //
980  // Note 2: We do not strictly need the `hi == 0`, case, but it makes
981  // constant folding easier.
982  Value hiEqZero = rewriter.create<arith::CmpIOp>(
983  loc, arith::CmpIPredicate::eq, hiInt, zeroCst);
984 
985  Type resultTy = op.getType();
986  Type resultElemTy = getElementTypeOrSelf(resultTy);
987  Value lowFp = rewriter.create<arith::UIToFPOp>(loc, resultTy, lowInt);
988  Value hiFp = rewriter.create<arith::UIToFPOp>(loc, resultTy, hiInt);
989 
990  int64_t pow2Int = int64_t(1) << newBitWidth;
991  TypedAttr pow2Attr =
992  rewriter.getFloatAttr(resultElemTy, static_cast<double>(pow2Int));
993  if (auto vecTy = dyn_cast<VectorType>(resultTy))
994  pow2Attr = SplatElementsAttr::get(vecTy, pow2Attr);
995 
996  Value pow2Val = rewriter.create<arith::ConstantOp>(loc, resultTy, pow2Attr);
997 
998  Value hiVal = rewriter.create<arith::MulFOp>(loc, hiFp, pow2Val);
999  Value result = rewriter.create<arith::AddFOp>(loc, lowFp, hiVal);
1000 
1001  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, hiEqZero, lowFp, result);
1002  return success();
1003  }
1004 };
1005 
1006 //===----------------------------------------------------------------------===//
1007 // ConvertTruncI
1008 //===----------------------------------------------------------------------===//
1009 
1010 struct ConvertTruncI final : OpConversionPattern<arith::TruncIOp> {
1012 
1014  matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
1015  ConversionPatternRewriter &rewriter) const override {
1016  Location loc = op.getLoc();
1017  // Check if the result type is legal for this target. Currently, we do not
1018  // support truncation to types wider than supported by the target.
1019  if (!getTypeConverter()->isLegal(op.getType()))
1020  return rewriter.notifyMatchFailure(
1021  loc, llvm::formatv("unsupported truncation result type: {0}",
1022  op.getType()));
1023 
1024  // Discard the high half of the input. Truncate the low half, if
1025  // necessary.
1026  Value extracted = extractLastDimSlice(rewriter, loc, adaptor.getIn(), 0);
1027  extracted = dropTrailingX1Dim(rewriter, loc, extracted);
1028  Value truncated =
1029  rewriter.createOrFold<arith::TruncIOp>(loc, op.getType(), extracted);
1030  rewriter.replaceOp(op, truncated);
1031  return success();
1032  }
1033 };
1034 
1035 //===----------------------------------------------------------------------===//
1036 // ConvertVectorPrint
1037 //===----------------------------------------------------------------------===//
1038 
1039 struct ConvertVectorPrint final : OpConversionPattern<vector::PrintOp> {
1041 
1043  matchAndRewrite(vector::PrintOp op, OpAdaptor adaptor,
1044  ConversionPatternRewriter &rewriter) const override {
1045  rewriter.replaceOpWithNewOp<vector::PrintOp>(op, adaptor.getSource());
1046  return success();
1047  }
1048 };
1049 
1050 //===----------------------------------------------------------------------===//
1051 // Pass Definition
1052 //===----------------------------------------------------------------------===//
1053 
1054 struct EmulateWideIntPass final
1055  : arith::impl::ArithEmulateWideIntBase<EmulateWideIntPass> {
1056  using ArithEmulateWideIntBase::ArithEmulateWideIntBase;
1057 
1058  void runOnOperation() override {
1059  if (!llvm::isPowerOf2_32(widestIntSupported) || widestIntSupported < 2) {
1060  signalPassFailure();
1061  return;
1062  }
1063 
1064  Operation *op = getOperation();
1065  MLIRContext *ctx = op->getContext();
1066 
1067  arith::WideIntEmulationConverter typeConverter(widestIntSupported);
1068  ConversionTarget target(*ctx);
1069  target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](Operation *op) {
1070  return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType());
1071  });
1072  auto opLegalCallback = [&typeConverter](Operation *op) {
1073  return typeConverter.isLegal(op);
1074  };
1075  target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(opLegalCallback);
1076  target
1077  .addDynamicallyLegalDialect<arith::ArithDialect, vector::VectorDialect>(
1078  opLegalCallback);
1079 
1080  RewritePatternSet patterns(ctx);
1081  arith::populateArithWideIntEmulationPatterns(typeConverter, patterns);
1082 
1083  if (failed(applyPartialConversion(op, target, std::move(patterns))))
1084  signalPassFailure();
1085  }
1086 };
1087 } // end anonymous namespace
1088 
1089 //===----------------------------------------------------------------------===//
1090 // Public Interface Definition
1091 //===----------------------------------------------------------------------===//
1092 
1094  unsigned widestIntSupportedByTarget)
1095  : maxIntWidth(widestIntSupportedByTarget) {
1096  assert(llvm::isPowerOf2_32(widestIntSupportedByTarget) &&
1097  "Only power-of-two integers with are supported");
1098  assert(widestIntSupportedByTarget >= 2 && "Integer type too narrow");
1099 
1100  // Allow unknown types.
1101  addConversion([](Type ty) -> std::optional<Type> { return ty; });
1102 
1103  // Scalar case.
1104  addConversion([this](IntegerType ty) -> std::optional<Type> {
1105  unsigned width = ty.getWidth();
1106  if (width <= maxIntWidth)
1107  return ty;
1108 
1109  // i2N --> vector<2xiN>
1110  if (width == 2 * maxIntWidth)
1111  return VectorType::get(2, IntegerType::get(ty.getContext(), maxIntWidth));
1112 
1113  return std::nullopt;
1114  });
1115 
1116  // Vector case.
1117  addConversion([this](VectorType ty) -> std::optional<Type> {
1118  auto intTy = dyn_cast<IntegerType>(ty.getElementType());
1119  if (!intTy)
1120  return ty;
1121 
1122  unsigned width = intTy.getWidth();
1123  if (width <= maxIntWidth)
1124  return ty;
1125 
1126  // vector<...xi2N> --> vector<...x2xiN>
1127  if (width == 2 * maxIntWidth) {
1128  auto newShape = to_vector(ty.getShape());
1129  newShape.push_back(2);
1130  return VectorType::get(newShape,
1131  IntegerType::get(ty.getContext(), maxIntWidth));
1132  }
1133 
1134  return std::nullopt;
1135  });
1136 
1137  // Function case.
1138  addConversion([this](FunctionType ty) -> std::optional<Type> {
1139  // Convert inputs and results, e.g.:
1140  // (i2N, i2N) -> i2N --> (vector<2xiN>, vector<2xiN>) -> vector<2xiN>
1141  SmallVector<Type> inputs;
1142  if (failed(convertTypes(ty.getInputs(), inputs)))
1143  return std::nullopt;
1144 
1145  SmallVector<Type> results;
1146  if (failed(convertTypes(ty.getResults(), results)))
1147  return std::nullopt;
1148 
1149  return FunctionType::get(ty.getContext(), inputs, results);
1150  });
1151 }
1152 
1154  WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns) {
1155  // Populate `func.*` conversion patterns.
1156  populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
1157  typeConverter);
1158  populateCallOpTypeConversionPattern(patterns, typeConverter);
1159  populateReturnOpTypeConversionPattern(patterns, typeConverter);
1160 
1161  // Populate `arith.*` conversion patterns.
1162  patterns.add<
1163  // Misc ops.
1164  ConvertConstant, ConvertCmpI, ConvertSelect, ConvertVectorPrint,
1165  // Binary ops.
1166  ConvertAddI, ConvertMulI, ConvertShLI, ConvertShRSI, ConvertShRUI,
1167  ConvertMaxMin<arith::MaxUIOp, arith::CmpIPredicate::ugt>,
1168  ConvertMaxMin<arith::MaxSIOp, arith::CmpIPredicate::sgt>,
1169  ConvertMaxMin<arith::MinUIOp, arith::CmpIPredicate::ult>,
1170  ConvertMaxMin<arith::MinSIOp, arith::CmpIPredicate::slt>,
1171  // Bitwise binary ops.
1172  ConvertBitwiseBinary<arith::AndIOp>, ConvertBitwiseBinary<arith::OrIOp>,
1173  ConvertBitwiseBinary<arith::XOrIOp>,
1174  // Extension and truncation ops.
1175  ConvertExtSI, ConvertExtUI, ConvertTruncI,
1176  // Cast ops.
1177  ConvertIndexCastIntToIndex<arith::IndexCastOp>,
1178  ConvertIndexCastIntToIndex<arith::IndexCastUIOp>,
1179  ConvertIndexCastIndexToInt<arith::IndexCastOp, arith::ExtSIOp>,
1180  ConvertIndexCastIndexToInt<arith::IndexCastUIOp, arith::ExtUIOp>,
1181  ConvertSIToFP, ConvertUIToFP>(typeConverter, patterns.getContext());
1182 }
static Value createScalarOrSplatConstant(ConversionPatternRewriter &rewriter, Location loc, Type type, const APInt &value)
Returns a constant of integer or vector type filled with (repeated) value.
static Value insertLastDimSlice(ConversionPatternRewriter &rewriter, Location loc, Value source, Value dest, int64_t lastOffset)
Inserts the source vector slice into the dest vector at offset lastOffset in the last dimension.
static std::pair< APInt, APInt > getHalves(const APInt &value, unsigned newBitWidth)
Returns N bottom and N top bits from value, where N = newBitWidth.
static Value appendX1Dim(ConversionPatternRewriter &rewriter, Location loc, Value input)
Performs a vector shape cast to append an x1 dimension.
static std::pair< Value, Value > extractLastDimHalves(ConversionPatternRewriter &rewriter, Location loc, Value input)
Extracts two vector slices from the input whose type is vector<...x2T>, with the first element at off...
static Type reduceInnermostDim(VectorType type)
Returns the type with the last (innermost) dimension reduced to x1.
static Value constructResultVector(ConversionPatternRewriter &rewriter, Location loc, VectorType resultType, ValueRange resultComponents)
Constructs a new vector of type resultType by creating a series of insertions of resultComponents,...
static Value dropTrailingX1Dim(ConversionPatternRewriter &rewriter, Location loc, Value input)
static Value extractLastDimSlice(ConversionPatternRewriter &rewriter, Location loc, Value input, int64_t lastOffset)
Extracts the input vector slice with elements at the last dimension offset by lastOffset.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:238
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:261
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:87
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
This class describes a specific conversion target.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:505
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
typename SourceOp::Adaptor OpAdaptor
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:539
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:123
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:372
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
Type getType() const
Return the type of this value.
Definition: Value.h:122
Converts integer types that are too wide for the target by splitting them in two halves and thus turn...
WideIntEmulationConverter(unsigned widestIntSupportedByTarget)
void populateArithWideIntEmulationPatterns(WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns)
Adds patterns to emulate wide Arith and Function ops over integer types into supported ones.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt abs(const MPInt &x)
Definition: MPInt.h:370
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
void populateCallOpTypeConversionPattern(RewritePatternSet &patterns, TypeConverter &converter)
Add a pattern to the given pattern list to convert the operand and result types of a CallOp with the ...
void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, TypeConverter &converter)
Add a pattern to the given pattern list to rewrite return ops to use operands that have been legalize...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26