MLIR  19.0.0git
EmulateWideInt.cpp
Go to the documentation of this file.
1 //===- EmulateWideInt.cpp - Wide integer operation emulation ----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/TypeUtilities.h"
21 #include "llvm/ADT/APInt.h"
22 #include "llvm/Support/FormatVariadic.h"
23 #include "llvm/Support/MathExtras.h"
24 #include <cassert>
25 
26 namespace mlir::arith {
27 #define GEN_PASS_DEF_ARITHEMULATEWIDEINT
28 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
29 } // namespace mlir::arith
30 
31 using namespace mlir;
32 
33 //===----------------------------------------------------------------------===//
34 // Common Helper Functions
35 //===----------------------------------------------------------------------===//
36 
37 /// Returns N bottom and N top bits from `value`, where N = `newBitWidth`.
38 /// Treats `value` as a 2*N bits-wide integer.
39 /// The bottom bits are returned in the first pair element, while the top bits
40 /// in the second one.
41 static std::pair<APInt, APInt> getHalves(const APInt &value,
42  unsigned newBitWidth) {
43  APInt low = value.extractBits(newBitWidth, 0);
44  APInt high = value.extractBits(newBitWidth, newBitWidth);
45  return {std::move(low), std::move(high)};
46 }
47 
48 /// Returns the type with the last (innermost) dimension reduced to x1.
49 /// Scalarizes 1D vector inputs to match how we extract/insert vector values,
50 /// e.g.:
51 /// - vector<3x2xi16> --> vector<3x1xi16>
52 /// - vector<2xi16> --> i16
53 static Type reduceInnermostDim(VectorType type) {
54  if (type.getShape().size() == 1)
55  return type.getElementType();
56 
57  auto newShape = to_vector(type.getShape());
58  newShape.back() = 1;
59  return VectorType::get(newShape, type.getElementType());
60 }
61 
62 /// Extracts the `input` vector slice with elements at the last dimension offset
63 /// by `lastOffset`. Returns a value of vector type with the last dimension
64 /// reduced to x1 or fully scalarized, e.g.:
65 /// - vector<3x2xi16> --> vector<3x1xi16>
66 /// - vector<2xi16> --> i16
68  Location loc, Value input,
69  int64_t lastOffset) {
70  ArrayRef<int64_t> shape = cast<VectorType>(input.getType()).getShape();
71  assert(lastOffset < shape.back() && "Offset out of bounds");
72 
73  // Scalarize the result in case of 1D vectors.
74  if (shape.size() == 1)
75  return rewriter.create<vector::ExtractOp>(loc, input, lastOffset);
76 
77  SmallVector<int64_t> offsets(shape.size(), 0);
78  offsets.back() = lastOffset;
79  auto sizes = llvm::to_vector(shape);
80  sizes.back() = 1;
81  SmallVector<int64_t> strides(shape.size(), 1);
82 
83  return rewriter.create<vector::ExtractStridedSliceOp>(loc, input, offsets,
84  sizes, strides);
85 }
86 
87 /// Extracts two vector slices from the `input` whose type is `vector<...x2T>`,
88 /// with the first element at offset 0 and the second element at offset 1.
89 static std::pair<Value, Value>
91  Value input) {
92  return {extractLastDimSlice(rewriter, loc, input, 0),
93  extractLastDimSlice(rewriter, loc, input, 1)};
94 }
95 
96 // Performs a vector shape cast to drop the trailing x1 dimension. If the
97 // `input` is a scalar, this is a noop.
99  Location loc, Value input) {
100  auto vecTy = dyn_cast<VectorType>(input.getType());
101  if (!vecTy)
102  return input;
103 
104  // Shape cast to drop the last x1 dimension.
105  ArrayRef<int64_t> shape = vecTy.getShape();
106  assert(shape.size() >= 2 && "Expected vector with at list two dims");
107  assert(shape.back() == 1 && "Expected the last vector dim to be x1");
108 
109  auto newVecTy = VectorType::get(shape.drop_back(), vecTy.getElementType());
110  return rewriter.create<vector::ShapeCastOp>(loc, newVecTy, input);
111 }
112 
113 /// Performs a vector shape cast to append an x1 dimension. If the
114 /// `input` is a scalar, this is a noop.
116  Value input) {
117  auto vecTy = dyn_cast<VectorType>(input.getType());
118  if (!vecTy)
119  return input;
120 
121  // Add a trailing x1 dim.
122  auto newShape = llvm::to_vector(vecTy.getShape());
123  newShape.push_back(1);
124  auto newTy = VectorType::get(newShape, vecTy.getElementType());
125  return rewriter.create<vector::ShapeCastOp>(loc, newTy, input);
126 }
127 
128 /// Inserts the `source` vector slice into the `dest` vector at offset
129 /// `lastOffset` in the last dimension. `source` can be a scalar when `dest` is
130 /// a 1D vector.
132  Location loc, Value source, Value dest,
133  int64_t lastOffset) {
134  ArrayRef<int64_t> shape = cast<VectorType>(dest.getType()).getShape();
135  assert(lastOffset < shape.back() && "Offset out of bounds");
136 
137  // Handle scalar source.
138  if (isa<IntegerType>(source.getType()))
139  return rewriter.create<vector::InsertOp>(loc, source, dest, lastOffset);
140 
141  SmallVector<int64_t> offsets(shape.size(), 0);
142  offsets.back() = lastOffset;
143  SmallVector<int64_t> strides(shape.size(), 1);
144  return rewriter.create<vector::InsertStridedSliceOp>(loc, source, dest,
145  offsets, strides);
146 }
147 
148 /// Constructs a new vector of type `resultType` by creating a series of
149 /// insertions of `resultComponents`, each at the next offset of the last vector
150 /// dimension.
151 /// When all `resultComponents` are scalars, the result type is `vector<NxT>`;
152 /// when `resultComponents` are `vector<...x1xT>`s, the result type is
153 /// `vector<...xNxT>`, where `N` is the number of `resultComponents`.
155  Location loc, VectorType resultType,
156  ValueRange resultComponents) {
157  llvm::ArrayRef<int64_t> resultShape = resultType.getShape();
158  (void)resultShape;
159  assert(!resultShape.empty() && "Result expected to have dimensions");
160  assert(resultShape.back() == static_cast<int64_t>(resultComponents.size()) &&
161  "Wrong number of result components");
162 
163  Value resultVec = createScalarOrSplatConstant(rewriter, loc, resultType, 0);
164  for (auto [i, component] : llvm::enumerate(resultComponents))
165  resultVec = insertLastDimSlice(rewriter, loc, component, resultVec, i);
166 
167  return resultVec;
168 }
169 
170 namespace {
171 //===----------------------------------------------------------------------===//
172 // ConvertConstant
173 //===----------------------------------------------------------------------===//
174 
175 struct ConvertConstant final : OpConversionPattern<arith::ConstantOp> {
177 
179  matchAndRewrite(arith::ConstantOp op, OpAdaptor,
180  ConversionPatternRewriter &rewriter) const override {
181  Type oldType = op.getType();
182  auto newType = getTypeConverter()->convertType<VectorType>(oldType);
183  if (!newType)
184  return rewriter.notifyMatchFailure(
185  op, llvm::formatv("unsupported type: {0}", op.getType()));
186 
187  unsigned newBitWidth = newType.getElementTypeBitWidth();
188  Attribute oldValue = op.getValueAttr();
189 
190  if (auto intAttr = dyn_cast<IntegerAttr>(oldValue)) {
191  auto [low, high] = getHalves(intAttr.getValue(), newBitWidth);
192  auto newAttr = DenseElementsAttr::get(newType, {low, high});
193  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
194  return success();
195  }
196 
197  if (auto splatAttr = dyn_cast<SplatElementsAttr>(oldValue)) {
198  auto [low, high] =
199  getHalves(splatAttr.getSplatValue<APInt>(), newBitWidth);
200  int64_t numSplatElems = splatAttr.getNumElements();
201  SmallVector<APInt> values;
202  values.reserve(numSplatElems * 2);
203  for (int64_t i = 0; i < numSplatElems; ++i) {
204  values.push_back(low);
205  values.push_back(high);
206  }
207 
208  auto attr = DenseElementsAttr::get(newType, values);
209  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, attr);
210  return success();
211  }
212 
213  if (auto elemsAttr = dyn_cast<DenseElementsAttr>(oldValue)) {
214  int64_t numElems = elemsAttr.getNumElements();
215  SmallVector<APInt> values;
216  values.reserve(numElems * 2);
217  for (const APInt &origVal : elemsAttr.getValues<APInt>()) {
218  auto [low, high] = getHalves(origVal, newBitWidth);
219  values.push_back(std::move(low));
220  values.push_back(std::move(high));
221  }
222 
223  auto attr = DenseElementsAttr::get(newType, values);
224  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, attr);
225  return success();
226  }
227 
228  return rewriter.notifyMatchFailure(op.getLoc(),
229  "unhandled constant attribute");
230  }
231 };
232 
233 //===----------------------------------------------------------------------===//
234 // ConvertAddI
235 //===----------------------------------------------------------------------===//
236 
237 struct ConvertAddI final : OpConversionPattern<arith::AddIOp> {
239 
241  matchAndRewrite(arith::AddIOp op, OpAdaptor adaptor,
242  ConversionPatternRewriter &rewriter) const override {
243  Location loc = op->getLoc();
244  auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
245  if (!newTy)
246  return rewriter.notifyMatchFailure(
247  loc, llvm::formatv("unsupported type: {0}", op.getType()));
248 
249  Type newElemTy = reduceInnermostDim(newTy);
250 
251  auto [lhsElem0, lhsElem1] =
252  extractLastDimHalves(rewriter, loc, adaptor.getLhs());
253  auto [rhsElem0, rhsElem1] =
254  extractLastDimHalves(rewriter, loc, adaptor.getRhs());
255 
256  auto lowSum =
257  rewriter.create<arith::AddUIExtendedOp>(loc, lhsElem0, rhsElem0);
258  Value overflowVal =
259  rewriter.create<arith::ExtUIOp>(loc, newElemTy, lowSum.getOverflow());
260 
261  Value high0 = rewriter.create<arith::AddIOp>(loc, overflowVal, lhsElem1);
262  Value high = rewriter.create<arith::AddIOp>(loc, high0, rhsElem1);
263 
264  Value resultVec =
265  constructResultVector(rewriter, loc, newTy, {lowSum.getSum(), high});
266  rewriter.replaceOp(op, resultVec);
267  return success();
268  }
269 };
270 
271 //===----------------------------------------------------------------------===//
272 // ConvertBitwiseBinary
273 //===----------------------------------------------------------------------===//
274 
275 /// Conversion pattern template for bitwise binary ops, e.g., `arith.andi`.
276 template <typename BinaryOp>
277 struct ConvertBitwiseBinary final : OpConversionPattern<BinaryOp> {
279  using OpAdaptor = typename OpConversionPattern<BinaryOp>::OpAdaptor;
280 
282  matchAndRewrite(BinaryOp op, OpAdaptor adaptor,
283  ConversionPatternRewriter &rewriter) const override {
284  Location loc = op->getLoc();
285  auto newTy = this->getTypeConverter()->template convertType<VectorType>(
286  op.getType());
287  if (!newTy)
288  return rewriter.notifyMatchFailure(
289  loc, llvm::formatv("unsupported type: {0}", op.getType()));
290 
291  auto [lhsElem0, lhsElem1] =
292  extractLastDimHalves(rewriter, loc, adaptor.getLhs());
293  auto [rhsElem0, rhsElem1] =
294  extractLastDimHalves(rewriter, loc, adaptor.getRhs());
295 
296  Value resElem0 = rewriter.create<BinaryOp>(loc, lhsElem0, rhsElem0);
297  Value resElem1 = rewriter.create<BinaryOp>(loc, lhsElem1, rhsElem1);
298  Value resultVec =
299  constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
300  rewriter.replaceOp(op, resultVec);
301  return success();
302  }
303 };
304 
305 //===----------------------------------------------------------------------===//
306 // ConvertCmpI
307 //===----------------------------------------------------------------------===//
308 
309 /// Returns the matching unsigned version of the given predicate `pred`, or the
310 /// same predicate if `pred` is not a signed.
311 static arith::CmpIPredicate toUnsignedPredicate(arith::CmpIPredicate pred) {
312  using P = arith::CmpIPredicate;
313  switch (pred) {
314  case P::sge:
315  return P::uge;
316  case P::sgt:
317  return P::ugt;
318  case P::sle:
319  return P::ule;
320  case P::slt:
321  return P::ult;
322  default:
323  return pred;
324  }
325 }
326 
327 struct ConvertCmpI final : OpConversionPattern<arith::CmpIOp> {
329 
331  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
332  ConversionPatternRewriter &rewriter) const override {
333  Location loc = op->getLoc();
334  auto inputTy =
335  getTypeConverter()->convertType<VectorType>(op.getLhs().getType());
336  if (!inputTy)
337  return rewriter.notifyMatchFailure(
338  loc, llvm::formatv("unsupported type: {0}", op.getType()));
339 
340  arith::CmpIPredicate highPred = adaptor.getPredicate();
341  arith::CmpIPredicate lowPred = toUnsignedPredicate(highPred);
342 
343  auto [lhsElem0, lhsElem1] =
344  extractLastDimHalves(rewriter, loc, adaptor.getLhs());
345  auto [rhsElem0, rhsElem1] =
346  extractLastDimHalves(rewriter, loc, adaptor.getRhs());
347 
348  Value lowCmp =
349  rewriter.create<arith::CmpIOp>(loc, lowPred, lhsElem0, rhsElem0);
350  Value highCmp =
351  rewriter.create<arith::CmpIOp>(loc, highPred, lhsElem1, rhsElem1);
352 
353  Value cmpResult{};
354  switch (highPred) {
355  case arith::CmpIPredicate::eq: {
356  cmpResult = rewriter.create<arith::AndIOp>(loc, lowCmp, highCmp);
357  break;
358  }
359  case arith::CmpIPredicate::ne: {
360  cmpResult = rewriter.create<arith::OrIOp>(loc, lowCmp, highCmp);
361  break;
362  }
363  default: {
364  // Handle inequality checks.
365  Value highEq = rewriter.create<arith::CmpIOp>(
366  loc, arith::CmpIPredicate::eq, lhsElem1, rhsElem1);
367  cmpResult =
368  rewriter.create<arith::SelectOp>(loc, highEq, lowCmp, highCmp);
369  break;
370  }
371  }
372 
373  assert(cmpResult && "Unhandled case");
374  rewriter.replaceOp(op, dropTrailingX1Dim(rewriter, loc, cmpResult));
375  return success();
376  }
377 };
378 
379 //===----------------------------------------------------------------------===//
380 // ConvertMulI
381 //===----------------------------------------------------------------------===//
382 
383 struct ConvertMulI final : OpConversionPattern<arith::MulIOp> {
385 
387  matchAndRewrite(arith::MulIOp op, OpAdaptor adaptor,
388  ConversionPatternRewriter &rewriter) const override {
389  Location loc = op->getLoc();
390  auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
391  if (!newTy)
392  return rewriter.notifyMatchFailure(
393  loc, llvm::formatv("unsupported type: {0}", op.getType()));
394 
395  auto [lhsElem0, lhsElem1] =
396  extractLastDimHalves(rewriter, loc, adaptor.getLhs());
397  auto [rhsElem0, rhsElem1] =
398  extractLastDimHalves(rewriter, loc, adaptor.getRhs());
399 
400  // The multiplication algorithm used is the standard (long) multiplication.
401  // Multiplying two i2N integers produces (at most) an i4N result, but
402  // because the calculation of top i2N is not necessary, we omit it.
403  auto mulLowLow =
404  rewriter.create<arith::MulUIExtendedOp>(loc, lhsElem0, rhsElem0);
405  Value mulLowHi = rewriter.create<arith::MulIOp>(loc, lhsElem0, rhsElem1);
406  Value mulHiLow = rewriter.create<arith::MulIOp>(loc, lhsElem1, rhsElem0);
407 
408  Value resLow = mulLowLow.getLow();
409  Value resHi =
410  rewriter.create<arith::AddIOp>(loc, mulLowLow.getHigh(), mulLowHi);
411  resHi = rewriter.create<arith::AddIOp>(loc, resHi, mulHiLow);
412 
413  Value resultVec =
414  constructResultVector(rewriter, loc, newTy, {resLow, resHi});
415  rewriter.replaceOp(op, resultVec);
416  return success();
417  }
418 };
419 
420 //===----------------------------------------------------------------------===//
421 // ConvertExtSI
422 //===----------------------------------------------------------------------===//
423 
424 struct ConvertExtSI final : OpConversionPattern<arith::ExtSIOp> {
426 
428  matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
429  ConversionPatternRewriter &rewriter) const override {
430  Location loc = op->getLoc();
431  auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
432  if (!newTy)
433  return rewriter.notifyMatchFailure(
434  loc, llvm::formatv("unsupported type: {0}", op.getType()));
435 
436  Type newResultComponentTy = reduceInnermostDim(newTy);
437 
438  // Sign-extend the input value to determine the low half of the result.
439  // Then, check if the low half is negative, and sign-extend the comparison
440  // result to get the high half.
441  Value newOperand = appendX1Dim(rewriter, loc, adaptor.getIn());
442  Value extended = rewriter.createOrFold<arith::ExtSIOp>(
443  loc, newResultComponentTy, newOperand);
444  Value operandZeroCst =
445  createScalarOrSplatConstant(rewriter, loc, newResultComponentTy, 0);
446  Value signBit = rewriter.create<arith::CmpIOp>(
447  loc, arith::CmpIPredicate::slt, extended, operandZeroCst);
448  Value signValue =
449  rewriter.create<arith::ExtSIOp>(loc, newResultComponentTy, signBit);
450 
451  Value resultVec =
452  constructResultVector(rewriter, loc, newTy, {extended, signValue});
453  rewriter.replaceOp(op, resultVec);
454  return success();
455  }
456 };
457 
458 //===----------------------------------------------------------------------===//
459 // ConvertExtUI
460 //===----------------------------------------------------------------------===//
461 
462 struct ConvertExtUI final : OpConversionPattern<arith::ExtUIOp> {
464 
466  matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
467  ConversionPatternRewriter &rewriter) const override {
468  Location loc = op->getLoc();
469  auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
470  if (!newTy)
471  return rewriter.notifyMatchFailure(
472  loc, llvm::formatv("unsupported type: {0}", op.getType()));
473 
474  Type newResultComponentTy = reduceInnermostDim(newTy);
475 
476  // Zero-extend the input value to determine the low half of the result.
477  // The high half is always zero.
478  Value newOperand = appendX1Dim(rewriter, loc, adaptor.getIn());
479  Value extended = rewriter.createOrFold<arith::ExtUIOp>(
480  loc, newResultComponentTy, newOperand);
481  Value zeroCst = createScalarOrSplatConstant(rewriter, loc, newTy, 0);
482  Value newRes = insertLastDimSlice(rewriter, loc, extended, zeroCst, 0);
483  rewriter.replaceOp(op, newRes);
484  return success();
485  }
486 };
487 
488 //===----------------------------------------------------------------------===//
489 // ConvertMaxMin
490 //===----------------------------------------------------------------------===//
491 
492 template <typename SourceOp, arith::CmpIPredicate CmpPred>
493 struct ConvertMaxMin final : OpConversionPattern<SourceOp> {
495 
497  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
498  ConversionPatternRewriter &rewriter) const override {
499  Location loc = op->getLoc();
500 
501  Type oldTy = op.getType();
502  auto newTy = dyn_cast_or_null<VectorType>(
503  this->getTypeConverter()->convertType(oldTy));
504  if (!newTy)
505  return rewriter.notifyMatchFailure(
506  loc, llvm::formatv("unsupported type: {0}", op.getType()));
507 
508  // Rewrite Max*I/Min*I as compare and select over original operands. Let
509  // the CmpI and Select emulation patterns handle the final legalization.
510  Value cmp =
511  rewriter.create<arith::CmpIOp>(loc, CmpPred, op.getLhs(), op.getRhs());
512  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, op.getLhs(),
513  op.getRhs());
514  return success();
515  }
516 };
517 
//===----------------------------------------------------------------------===//
// Convert IndexCast ops
//===----------------------------------------------------------------------===//
520 
521 /// Returns true iff the type is `index` or `vector<...index>`.
522 static bool isIndexOrIndexVector(Type type) {
523  if (isa<IndexType>(type))
524  return true;
525 
526  if (auto vectorTy = dyn_cast<VectorType>(type))
527  if (isa<IndexType>(vectorTy.getElementType()))
528  return true;
529 
530  return false;
531 }
532 
533 template <typename CastOp>
534 struct ConvertIndexCastIntToIndex final : OpConversionPattern<CastOp> {
536 
538  matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
539  ConversionPatternRewriter &rewriter) const override {
540  Type resultType = op.getType();
541  if (!isIndexOrIndexVector(resultType))
542  return failure();
543 
544  Location loc = op.getLoc();
545  Type inType = op.getIn().getType();
546  auto newInTy =
547  this->getTypeConverter()->template convertType<VectorType>(inType);
548  if (!newInTy)
549  return rewriter.notifyMatchFailure(
550  loc, llvm::formatv("unsupported type: {0}", inType));
551 
552  // Discard the high half of the input truncating the original value.
553  Value extracted = extractLastDimSlice(rewriter, loc, adaptor.getIn(), 0);
554  extracted = dropTrailingX1Dim(rewriter, loc, extracted);
555  rewriter.replaceOpWithNewOp<CastOp>(op, resultType, extracted);
556  return success();
557  }
558 };
559 
560 template <typename CastOp, typename ExtensionOp>
561 struct ConvertIndexCastIndexToInt final : OpConversionPattern<CastOp> {
563 
565  matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
566  ConversionPatternRewriter &rewriter) const override {
567  Type inType = op.getIn().getType();
568  if (!isIndexOrIndexVector(inType))
569  return failure();
570 
571  Location loc = op.getLoc();
572  auto *typeConverter =
573  this->template getTypeConverter<arith::WideIntEmulationConverter>();
574 
575  Type resultType = op.getType();
576  auto newTy = typeConverter->template convertType<VectorType>(resultType);
577  if (!newTy)
578  return rewriter.notifyMatchFailure(
579  loc, llvm::formatv("unsupported type: {0}", resultType));
580 
581  // Emit an index cast over the matching narrow type.
582  Type narrowTy =
583  rewriter.getIntegerType(typeConverter->getMaxTargetIntBitWidth());
584  if (auto vecTy = dyn_cast<VectorType>(resultType))
585  narrowTy = VectorType::get(vecTy.getShape(), narrowTy);
586 
587  // Sign or zero-extend the result. Let the matching conversion pattern
588  // legalize the extension op.
589  Value underlyingVal =
590  rewriter.create<CastOp>(loc, narrowTy, adaptor.getIn());
591  rewriter.replaceOpWithNewOp<ExtensionOp>(op, resultType, underlyingVal);
592  return success();
593  }
594 };
595 
596 //===----------------------------------------------------------------------===//
597 // ConvertSelect
598 //===----------------------------------------------------------------------===//
599 
600 struct ConvertSelect final : OpConversionPattern<arith::SelectOp> {
602 
604  matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
605  ConversionPatternRewriter &rewriter) const override {
606  Location loc = op->getLoc();
607  auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
608  if (!newTy)
609  return rewriter.notifyMatchFailure(
610  loc, llvm::formatv("unsupported type: {0}", op.getType()));
611 
612  auto [trueElem0, trueElem1] =
613  extractLastDimHalves(rewriter, loc, adaptor.getTrueValue());
614  auto [falseElem0, falseElem1] =
615  extractLastDimHalves(rewriter, loc, adaptor.getFalseValue());
616  Value cond = appendX1Dim(rewriter, loc, adaptor.getCondition());
617 
618  Value resElem0 =
619  rewriter.create<arith::SelectOp>(loc, cond, trueElem0, falseElem0);
620  Value resElem1 =
621  rewriter.create<arith::SelectOp>(loc, cond, trueElem1, falseElem1);
622  Value resultVec =
623  constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
624  rewriter.replaceOp(op, resultVec);
625  return success();
626  }
627 };
628 
629 //===----------------------------------------------------------------------===//
630 // ConvertShLI
631 //===----------------------------------------------------------------------===//
632 
633 struct ConvertShLI final : OpConversionPattern<arith::ShLIOp> {
635 
637  matchAndRewrite(arith::ShLIOp op, OpAdaptor adaptor,
638  ConversionPatternRewriter &rewriter) const override {
639  Location loc = op->getLoc();
640 
641  Type oldTy = op.getType();
642  auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
643  if (!newTy)
644  return rewriter.notifyMatchFailure(
645  loc, llvm::formatv("unsupported type: {0}", op.getType()));
646 
647  Type newOperandTy = reduceInnermostDim(newTy);
648  // `oldBitWidth` == `2 * newBitWidth`
649  unsigned newBitWidth = newTy.getElementTypeBitWidth();
650 
651  auto [lhsElem0, lhsElem1] =
652  extractLastDimHalves(rewriter, loc, adaptor.getLhs());
653  Value rhsElem0 = extractLastDimSlice(rewriter, loc, adaptor.getRhs(), 0);
654 
655  // Assume that the shift amount is < 2 * newBitWidth. Calculate the low and
656  // high halves of the results separately:
657  // 1. low := LHS.low shli RHS
658  //
659  // 2. high := a or b or c, where:
660  // a) Bits from LHS.high, shifted by the RHS.
661  // b) Bits from LHS.low, shifted right. These come into play when
662  // RHS < newBitWidth, e.g.:
663  // [0000][llll] shli 3 --> [0lll][l000]
664  // ^
665  // |
666  // [llll] shrui (4 - 3)
667  // c) Bits from LHS.low, shifted left. These matter when
668  // RHS > newBitWidth, e.g.:
669  // [0000][llll] shli 7 --> [l000][0000]
670  // ^
671  // |
672  // [llll] shli (7 - 4)
673  //
674  // Because shifts by values >= newBitWidth are undefined, we ignore the high
675  // half of RHS, and introduce 'bounds checks' to account for
676  // RHS.low > newBitWidth.
677  //
678  // TODO: Explore possible optimizations.
679  Value zeroCst = createScalarOrSplatConstant(rewriter, loc, newOperandTy, 0);
680  Value elemBitWidth =
681  createScalarOrSplatConstant(rewriter, loc, newOperandTy, newBitWidth);
682 
683  Value illegalElemShift = rewriter.create<arith::CmpIOp>(
684  loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth);
685 
686  Value shiftedElem0 =
687  rewriter.create<arith::ShLIOp>(loc, lhsElem0, rhsElem0);
688  Value resElem0 = rewriter.create<arith::SelectOp>(loc, illegalElemShift,
689  zeroCst, shiftedElem0);
690 
691  Value cappedShiftAmount = rewriter.create<arith::SelectOp>(
692  loc, illegalElemShift, elemBitWidth, rhsElem0);
693  Value rightShiftAmount =
694  rewriter.create<arith::SubIOp>(loc, elemBitWidth, cappedShiftAmount);
695  Value shiftedRight =
696  rewriter.create<arith::ShRUIOp>(loc, lhsElem0, rightShiftAmount);
697  Value overshotShiftAmount =
698  rewriter.create<arith::SubIOp>(loc, rhsElem0, elemBitWidth);
699  Value shiftedLeft =
700  rewriter.create<arith::ShLIOp>(loc, lhsElem0, overshotShiftAmount);
701 
702  Value shiftedElem1 =
703  rewriter.create<arith::ShLIOp>(loc, lhsElem1, rhsElem0);
704  Value resElem1High = rewriter.create<arith::SelectOp>(
705  loc, illegalElemShift, zeroCst, shiftedElem1);
706  Value resElem1Low = rewriter.create<arith::SelectOp>(
707  loc, illegalElemShift, shiftedLeft, shiftedRight);
708  Value resElem1 =
709  rewriter.create<arith::OrIOp>(loc, resElem1Low, resElem1High);
710 
711  Value resultVec =
712  constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
713  rewriter.replaceOp(op, resultVec);
714  return success();
715  }
716 };
717 
718 //===----------------------------------------------------------------------===//
719 // ConvertShRUI
720 //===----------------------------------------------------------------------===//
721 
722 struct ConvertShRUI final : OpConversionPattern<arith::ShRUIOp> {
724 
726  matchAndRewrite(arith::ShRUIOp op, OpAdaptor adaptor,
727  ConversionPatternRewriter &rewriter) const override {
728  Location loc = op->getLoc();
729 
730  Type oldTy = op.getType();
731  auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
732  if (!newTy)
733  return rewriter.notifyMatchFailure(
734  loc, llvm::formatv("unsupported type: {0}", op.getType()));
735 
736  Type newOperandTy = reduceInnermostDim(newTy);
737  // `oldBitWidth` == `2 * newBitWidth`
738  unsigned newBitWidth = newTy.getElementTypeBitWidth();
739 
740  auto [lhsElem0, lhsElem1] =
741  extractLastDimHalves(rewriter, loc, adaptor.getLhs());
742  Value rhsElem0 = extractLastDimSlice(rewriter, loc, adaptor.getRhs(), 0);
743 
744  // Assume that the shift amount is < 2 * newBitWidth. Calculate the low and
745  // high halves of the results separately:
746  // 1. low := a or b or c, where:
747  // a) Bits from LHS.low, shifted by the RHS.
748  // b) Bits from LHS.high, shifted left. These matter when
749  // RHS < newBitWidth, e.g.:
750  // [hhhh][0000] shrui 3 --> [000h][hhh0]
751  // ^
752  // |
753  // [hhhh] shli (4 - 1)
754  // c) Bits from LHS.high, shifted right. These come into play when
755  // RHS > newBitWidth, e.g.:
756  // [hhhh][0000] shrui 7 --> [0000][000h]
757  // ^
758  // |
759  // [hhhh] shrui (7 - 4)
760  //
761  // 2. high := LHS.high shrui RHS
762  //
763  // Because shifts by values >= newBitWidth are undefined, we ignore the high
764  // half of RHS, and introduce 'bounds checks' to account for
765  // RHS.low > newBitWidth.
766  //
767  // TODO: Explore possible optimizations.
768  Value zeroCst = createScalarOrSplatConstant(rewriter, loc, newOperandTy, 0);
769  Value elemBitWidth =
770  createScalarOrSplatConstant(rewriter, loc, newOperandTy, newBitWidth);
771 
772  Value illegalElemShift = rewriter.create<arith::CmpIOp>(
773  loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth);
774 
775  Value shiftedElem0 =
776  rewriter.create<arith::ShRUIOp>(loc, lhsElem0, rhsElem0);
777  Value resElem0Low = rewriter.create<arith::SelectOp>(loc, illegalElemShift,
778  zeroCst, shiftedElem0);
779  Value shiftedElem1 =
780  rewriter.create<arith::ShRUIOp>(loc, lhsElem1, rhsElem0);
781  Value resElem1 = rewriter.create<arith::SelectOp>(loc, illegalElemShift,
782  zeroCst, shiftedElem1);
783 
784  Value cappedShiftAmount = rewriter.create<arith::SelectOp>(
785  loc, illegalElemShift, elemBitWidth, rhsElem0);
786  Value leftShiftAmount =
787  rewriter.create<arith::SubIOp>(loc, elemBitWidth, cappedShiftAmount);
788  Value shiftedLeft =
789  rewriter.create<arith::ShLIOp>(loc, lhsElem1, leftShiftAmount);
790  Value overshotShiftAmount =
791  rewriter.create<arith::SubIOp>(loc, rhsElem0, elemBitWidth);
792  Value shiftedRight =
793  rewriter.create<arith::ShRUIOp>(loc, lhsElem1, overshotShiftAmount);
794 
795  Value resElem0High = rewriter.create<arith::SelectOp>(
796  loc, illegalElemShift, shiftedRight, shiftedLeft);
797  Value resElem0 =
798  rewriter.create<arith::OrIOp>(loc, resElem0Low, resElem0High);
799 
800  Value resultVec =
801  constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
802  rewriter.replaceOp(op, resultVec);
803  return success();
804  }
805 };
806 
807 //===----------------------------------------------------------------------===//
808 // ConvertShRSI
809 //===----------------------------------------------------------------------===//
810 
811 struct ConvertShRSI final : OpConversionPattern<arith::ShRSIOp> {
813 
815  matchAndRewrite(arith::ShRSIOp op, OpAdaptor adaptor,
816  ConversionPatternRewriter &rewriter) const override {
817  Location loc = op->getLoc();
818 
819  Type oldTy = op.getType();
820  auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
821  if (!newTy)
822  return rewriter.notifyMatchFailure(
823  loc, llvm::formatv("unsupported type: {0}", op.getType()));
824 
825  Value lhsElem1 = extractLastDimSlice(rewriter, loc, adaptor.getLhs(), 1);
826  Value rhsElem0 = extractLastDimSlice(rewriter, loc, adaptor.getRhs(), 0);
827 
828  Type narrowTy = rhsElem0.getType();
829  int64_t origBitwidth = newTy.getElementTypeBitWidth() * 2;
830 
831  // Rewrite this as an bitwise or of `arith.shrui` and sign extension bits.
832  // Perform as many ops over the narrow integer type as possible and let the
833  // other emulation patterns convert the rest.
834  Value elemZero = createScalarOrSplatConstant(rewriter, loc, narrowTy, 0);
835  Value signBit = rewriter.create<arith::CmpIOp>(
836  loc, arith::CmpIPredicate::slt, lhsElem1, elemZero);
837  signBit = dropTrailingX1Dim(rewriter, loc, signBit);
838 
839  // Create a bit pattern of either all ones or all zeros. Then shift it left
840  // to calculate the sign extension bits created by shifting the original
841  // sign bit right.
842  Value allSign = rewriter.create<arith::ExtSIOp>(loc, oldTy, signBit);
843  Value maxShift =
844  createScalarOrSplatConstant(rewriter, loc, narrowTy, origBitwidth);
845  Value numNonSignExtBits =
846  rewriter.create<arith::SubIOp>(loc, maxShift, rhsElem0);
847  numNonSignExtBits = dropTrailingX1Dim(rewriter, loc, numNonSignExtBits);
848  numNonSignExtBits =
849  rewriter.create<arith::ExtUIOp>(loc, oldTy, numNonSignExtBits);
850  Value signBits =
851  rewriter.create<arith::ShLIOp>(loc, allSign, numNonSignExtBits);
852 
853  // Use original arguments to create the right shift.
854  Value shrui =
855  rewriter.create<arith::ShRUIOp>(loc, op.getLhs(), op.getRhs());
856  Value shrsi = rewriter.create<arith::OrIOp>(loc, shrui, signBits);
857 
858  // Handle shifting by zero. This is necessary when the `signBits` shift is
859  // invalid.
860  Value isNoop = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
861  rhsElem0, elemZero);
862  isNoop = dropTrailingX1Dim(rewriter, loc, isNoop);
863  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNoop, op.getLhs(),
864  shrsi);
865 
866  return success();
867  }
868 };
869 
870 //===----------------------------------------------------------------------===//
871 // ConvertSIToFP
872 //===----------------------------------------------------------------------===//
873 
874 struct ConvertSIToFP final : OpConversionPattern<arith::SIToFPOp> {
876 
878  matchAndRewrite(arith::SIToFPOp op, OpAdaptor adaptor,
879  ConversionPatternRewriter &rewriter) const override {
880  Location loc = op.getLoc();
881 
882  Value in = op.getIn();
883  Type oldTy = in.getType();
884  auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
885  if (!newTy)
886  return rewriter.notifyMatchFailure(
887  loc, llvm::formatv("unsupported type: {0}", oldTy));
888 
889  unsigned oldBitWidth = getElementTypeOrSelf(oldTy).getIntOrFloatBitWidth();
890  Value zeroCst = createScalarOrSplatConstant(rewriter, loc, oldTy, 0);
891  Value oneCst = createScalarOrSplatConstant(rewriter, loc, oldTy, 1);
892  Value allOnesCst = createScalarOrSplatConstant(
893  rewriter, loc, oldTy, APInt::getAllOnes(oldBitWidth));
894 
895  // To avoid operating on very large unsigned numbers, perform the
896  // conversion on the absolute value. Then, decide whether to negate the
897  // result or not based on that sign bit. We assume two's complement and
898  // implement negation by flipping all bits and adding 1.
899  // Note that this relies on the the other conversion patterns to legalize
900  // created ops and narrow the bit widths.
901  Value isNeg = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
902  in, zeroCst);
903  Value bitwiseNeg = rewriter.create<arith::XOrIOp>(loc, in, allOnesCst);
904  Value neg = rewriter.create<arith::AddIOp>(loc, bitwiseNeg, oneCst);
905  Value abs = rewriter.create<arith::SelectOp>(loc, isNeg, neg, in);
906 
907  Value absResult = rewriter.create<arith::UIToFPOp>(loc, op.getType(), abs);
908  Value negResult = rewriter.create<arith::NegFOp>(loc, absResult);
909  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNeg, negResult,
910  absResult);
911  return success();
912  }
913 };
914 
915 //===----------------------------------------------------------------------===//
916 // ConvertUIToFP
917 //===----------------------------------------------------------------------===//
918 
919 struct ConvertUIToFP final : OpConversionPattern<arith::UIToFPOp> {
921 
923  matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
924  ConversionPatternRewriter &rewriter) const override {
925  Location loc = op.getLoc();
926 
927  Type oldTy = op.getIn().getType();
928  auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
929  if (!newTy)
930  return rewriter.notifyMatchFailure(
931  loc, llvm::formatv("unsupported type: {0}", oldTy));
932  unsigned newBitWidth = newTy.getElementTypeBitWidth();
933 
934  auto [low, hi] = extractLastDimHalves(rewriter, loc, adaptor.getIn());
935  Value lowInt = dropTrailingX1Dim(rewriter, loc, low);
936  Value hiInt = dropTrailingX1Dim(rewriter, loc, hi);
937  Value zeroCst =
938  createScalarOrSplatConstant(rewriter, loc, hiInt.getType(), 0);
939 
940  // The final result has the following form:
941  // if (hi == 0) return uitofp(low)
942  // else return uitofp(low) + uitofp(hi) * 2^BW
943  //
944  // where `BW` is the bitwidth of the narrowed integer type. We emit a
945  // select to make it easier to fold-away the `hi` part calculation when it
946  // is known to be zero.
947  //
948  // Note 1: The emulation is precise only for input values that have exact
949  // integer representation in the result floating point type, and may lead
950  // loss of precision otherwise.
951  //
952  // Note 2: We do not strictly need the `hi == 0`, case, but it makes
953  // constant folding easier.
954  Value hiEqZero = rewriter.create<arith::CmpIOp>(
955  loc, arith::CmpIPredicate::eq, hiInt, zeroCst);
956 
957  Type resultTy = op.getType();
958  Type resultElemTy = getElementTypeOrSelf(resultTy);
959  Value lowFp = rewriter.create<arith::UIToFPOp>(loc, resultTy, lowInt);
960  Value hiFp = rewriter.create<arith::UIToFPOp>(loc, resultTy, hiInt);
961 
962  int64_t pow2Int = int64_t(1) << newBitWidth;
963  TypedAttr pow2Attr =
964  rewriter.getFloatAttr(resultElemTy, static_cast<double>(pow2Int));
965  if (auto vecTy = dyn_cast<VectorType>(resultTy))
966  pow2Attr = SplatElementsAttr::get(vecTy, pow2Attr);
967 
968  Value pow2Val = rewriter.create<arith::ConstantOp>(loc, resultTy, pow2Attr);
969 
970  Value hiVal = rewriter.create<arith::MulFOp>(loc, hiFp, pow2Val);
971  Value result = rewriter.create<arith::AddFOp>(loc, lowFp, hiVal);
972 
973  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, hiEqZero, lowFp, result);
974  return success();
975  }
976 };
977 
978 //===----------------------------------------------------------------------===//
979 // ConvertTruncI
980 //===----------------------------------------------------------------------===//
981 
982 struct ConvertTruncI final : OpConversionPattern<arith::TruncIOp> {
984 
986  matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
987  ConversionPatternRewriter &rewriter) const override {
988  Location loc = op.getLoc();
989  // Check if the result type is legal for this target. Currently, we do not
990  // support truncation to types wider than supported by the target.
991  if (!getTypeConverter()->isLegal(op.getType()))
992  return rewriter.notifyMatchFailure(
993  loc, llvm::formatv("unsupported truncation result type: {0}",
994  op.getType()));
995 
996  // Discard the high half of the input. Truncate the low half, if
997  // necessary.
998  Value extracted = extractLastDimSlice(rewriter, loc, adaptor.getIn(), 0);
999  extracted = dropTrailingX1Dim(rewriter, loc, extracted);
1000  Value truncated =
1001  rewriter.createOrFold<arith::TruncIOp>(loc, op.getType(), extracted);
1002  rewriter.replaceOp(op, truncated);
1003  return success();
1004  }
1005 };
1006 
1007 //===----------------------------------------------------------------------===//
1008 // ConvertVectorPrint
1009 //===----------------------------------------------------------------------===//
1010 
1011 struct ConvertVectorPrint final : OpConversionPattern<vector::PrintOp> {
1013 
1015  matchAndRewrite(vector::PrintOp op, OpAdaptor adaptor,
1016  ConversionPatternRewriter &rewriter) const override {
1017  rewriter.replaceOpWithNewOp<vector::PrintOp>(op, adaptor.getSource());
1018  return success();
1019  }
1020 };
1021 
1022 //===----------------------------------------------------------------------===//
1023 // Pass Definition
1024 //===----------------------------------------------------------------------===//
1025 
1026 struct EmulateWideIntPass final
1027  : arith::impl::ArithEmulateWideIntBase<EmulateWideIntPass> {
1028  using ArithEmulateWideIntBase::ArithEmulateWideIntBase;
1029 
1030  void runOnOperation() override {
1031  if (!llvm::isPowerOf2_32(widestIntSupported) || widestIntSupported < 2) {
1032  signalPassFailure();
1033  return;
1034  }
1035 
1036  Operation *op = getOperation();
1037  MLIRContext *ctx = op->getContext();
1038 
1039  arith::WideIntEmulationConverter typeConverter(widestIntSupported);
1040  ConversionTarget target(*ctx);
1041  target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](Operation *op) {
1042  return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType());
1043  });
1044  auto opLegalCallback = [&typeConverter](Operation *op) {
1045  return typeConverter.isLegal(op);
1046  };
1047  target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(opLegalCallback);
1048  target
1049  .addDynamicallyLegalDialect<arith::ArithDialect, vector::VectorDialect>(
1050  opLegalCallback);
1051 
1052  RewritePatternSet patterns(ctx);
1053  arith::populateArithWideIntEmulationPatterns(typeConverter, patterns);
1054 
1055  if (failed(applyPartialConversion(op, target, std::move(patterns))))
1056  signalPassFailure();
1057  }
1058 };
1059 } // end anonymous namespace
1060 
1061 //===----------------------------------------------------------------------===//
1062 // Public Interface Definition
1063 //===----------------------------------------------------------------------===//
1064 
1066  unsigned widestIntSupportedByTarget)
1067  : maxIntWidth(widestIntSupportedByTarget) {
1068  assert(llvm::isPowerOf2_32(widestIntSupportedByTarget) &&
1069  "Only power-of-two integers with are supported");
1070  assert(widestIntSupportedByTarget >= 2 && "Integer type too narrow");
1071 
1072  // Allow unknown types.
1073  addConversion([](Type ty) -> std::optional<Type> { return ty; });
1074 
1075  // Scalar case.
1076  addConversion([this](IntegerType ty) -> std::optional<Type> {
1077  unsigned width = ty.getWidth();
1078  if (width <= maxIntWidth)
1079  return ty;
1080 
1081  // i2N --> vector<2xiN>
1082  if (width == 2 * maxIntWidth)
1083  return VectorType::get(2, IntegerType::get(ty.getContext(), maxIntWidth));
1084 
1085  return std::nullopt;
1086  });
1087 
1088  // Vector case.
1089  addConversion([this](VectorType ty) -> std::optional<Type> {
1090  auto intTy = dyn_cast<IntegerType>(ty.getElementType());
1091  if (!intTy)
1092  return ty;
1093 
1094  unsigned width = intTy.getWidth();
1095  if (width <= maxIntWidth)
1096  return ty;
1097 
1098  // vector<...xi2N> --> vector<...x2xiN>
1099  if (width == 2 * maxIntWidth) {
1100  auto newShape = to_vector(ty.getShape());
1101  newShape.push_back(2);
1102  return VectorType::get(newShape,
1103  IntegerType::get(ty.getContext(), maxIntWidth));
1104  }
1105 
1106  return std::nullopt;
1107  });
1108 
1109  // Function case.
1110  addConversion([this](FunctionType ty) -> std::optional<Type> {
1111  // Convert inputs and results, e.g.:
1112  // (i2N, i2N) -> i2N --> (vector<2xiN>, vector<2xiN>) -> vector<2xiN>
1113  SmallVector<Type> inputs;
1114  if (failed(convertTypes(ty.getInputs(), inputs)))
1115  return std::nullopt;
1116 
1117  SmallVector<Type> results;
1118  if (failed(convertTypes(ty.getResults(), results)))
1119  return std::nullopt;
1120 
1121  return FunctionType::get(ty.getContext(), inputs, results);
1122  });
1123 }
1124 
1126  WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns) {
1127  // Populate `func.*` conversion patterns.
1128  populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
1129  typeConverter);
1130  populateCallOpTypeConversionPattern(patterns, typeConverter);
1131  populateReturnOpTypeConversionPattern(patterns, typeConverter);
1132 
1133  // Populate `arith.*` conversion patterns.
1134  patterns.add<
1135  // Misc ops.
1136  ConvertConstant, ConvertCmpI, ConvertSelect, ConvertVectorPrint,
1137  // Binary ops.
1138  ConvertAddI, ConvertMulI, ConvertShLI, ConvertShRSI, ConvertShRUI,
1139  ConvertMaxMin<arith::MaxUIOp, arith::CmpIPredicate::ugt>,
1140  ConvertMaxMin<arith::MaxSIOp, arith::CmpIPredicate::sgt>,
1141  ConvertMaxMin<arith::MinUIOp, arith::CmpIPredicate::ult>,
1142  ConvertMaxMin<arith::MinSIOp, arith::CmpIPredicate::slt>,
1143  // Bitwise binary ops.
1144  ConvertBitwiseBinary<arith::AndIOp>, ConvertBitwiseBinary<arith::OrIOp>,
1145  ConvertBitwiseBinary<arith::XOrIOp>,
1146  // Extension and truncation ops.
1147  ConvertExtSI, ConvertExtUI, ConvertTruncI,
1148  // Cast ops.
1149  ConvertIndexCastIntToIndex<arith::IndexCastOp>,
1150  ConvertIndexCastIntToIndex<arith::IndexCastUIOp>,
1151  ConvertIndexCastIndexToInt<arith::IndexCastOp, arith::ExtSIOp>,
1152  ConvertIndexCastIndexToInt<arith::IndexCastUIOp, arith::ExtUIOp>,
1153  ConvertSIToFP, ConvertUIToFP>(typeConverter, patterns.getContext());
1154 }
static Value insertLastDimSlice(ConversionPatternRewriter &rewriter, Location loc, Value source, Value dest, int64_t lastOffset)
Inserts the source vector slice into the dest vector at offset lastOffset in the last dimension.
static std::pair< APInt, APInt > getHalves(const APInt &value, unsigned newBitWidth)
Returns N bottom and N top bits from value, where N = newBitWidth.
static Value appendX1Dim(ConversionPatternRewriter &rewriter, Location loc, Value input)
Performs a vector shape cast to append an x1 dimension.
static std::pair< Value, Value > extractLastDimHalves(ConversionPatternRewriter &rewriter, Location loc, Value input)
Extracts two vector slices from the input whose type is vector<...x2T>, with the first element at off...
static Type reduceInnermostDim(VectorType type)
Returns the type with the last (innermost) dimension reduced to x1.
static Value constructResultVector(ConversionPatternRewriter &rewriter, Location loc, VectorType resultType, ValueRange resultComponents)
Constructs a new vector of type resultType by creating a series of insertions of resultComponents,...
static Value dropTrailingX1Dim(ConversionPatternRewriter &rewriter, Location loc, Value input)
static Value extractLastDimSlice(ConversionPatternRewriter &rewriter, Location loc, Value input, int64_t lastOffset)
Extracts the input vector slice with elements at the last dimension offset by lastOffset.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:261
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:87
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
This class describes a specific conversion target.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:522
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
typename SourceOp::Adaptor OpAdaptor
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:125
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
Converts integer types that are too wide for the target by splitting them in two halves and thus turn...
WideIntEmulationConverter(unsigned widestIntSupportedByTarget)
void populateArithWideIntEmulationPatterns(WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns)
Adds patterns to emulate wide Arith and Function ops over integer types into supported ones.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
Fraction abs(const Fraction &f)
Definition: Fraction.h:104
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, const APInt &value)
Create a constant of type type at location loc whose value is value (an APInt or APFloat whose type m...
Definition: Utils.cpp:201
void populateCallOpTypeConversionPattern(RewritePatternSet &patterns, TypeConverter &converter)
Add a pattern to the given pattern list to convert the operand and result types of a CallOp with the ...
void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, TypeConverter &converter)
Add a pattern to the given pattern list to rewrite return ops to use operands that have been legalize...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26