MLIR  20.0.0git
EmulateWideInt.cpp
Go to the documentation of this file.
1 //===- EmulateWideInt.cpp - Wide integer operation emulation ----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/TypeUtilities.h"
20 #include "llvm/ADT/APInt.h"
21 #include "llvm/Support/FormatVariadic.h"
22 #include "llvm/Support/MathExtras.h"
23 #include <cassert>
24 
25 namespace mlir::arith {
26 #define GEN_PASS_DEF_ARITHEMULATEWIDEINT
27 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
28 } // namespace mlir::arith
29 
30 using namespace mlir;
31 
32 //===----------------------------------------------------------------------===//
33 // Common Helper Functions
34 //===----------------------------------------------------------------------===//
35 
36 /// Returns N bottom and N top bits from `value`, where N = `newBitWidth`.
37 /// Treats `value` as a 2*N bits-wide integer.
38 /// The bottom bits are returned in the first pair element, while the top bits
39 /// in the second one.
40 static std::pair<APInt, APInt> getHalves(const APInt &value,
41  unsigned newBitWidth) {
42  APInt low = value.extractBits(newBitWidth, 0);
43  APInt high = value.extractBits(newBitWidth, newBitWidth);
44  return {std::move(low), std::move(high)};
45 }
46 
47 /// Returns the type with the last (innermost) dimension reduced to x1.
48 /// Scalarizes 1D vector inputs to match how we extract/insert vector values,
49 /// e.g.:
50 /// - vector<3x2xi16> --> vector<3x1xi16>
51 /// - vector<2xi16> --> i16
52 static Type reduceInnermostDim(VectorType type) {
53  if (type.getShape().size() == 1)
54  return type.getElementType();
55 
56  auto newShape = to_vector(type.getShape());
57  newShape.back() = 1;
58  return VectorType::get(newShape, type.getElementType());
59 }
60 
61 /// Extracts the `input` vector slice with elements at the last dimension offset
62 /// by `lastOffset`. Returns a value of vector type with the last dimension
63 /// reduced to x1 or fully scalarized, e.g.:
64 /// - vector<3x2xi16> --> vector<3x1xi16>
65 /// - vector<2xi16> --> i16
67  Location loc, Value input,
68  int64_t lastOffset) {
69  ArrayRef<int64_t> shape = cast<VectorType>(input.getType()).getShape();
70  assert(lastOffset < shape.back() && "Offset out of bounds");
71 
72  // Scalarize the result in case of 1D vectors.
73  if (shape.size() == 1)
74  return rewriter.create<vector::ExtractOp>(loc, input, lastOffset);
75 
76  SmallVector<int64_t> offsets(shape.size(), 0);
77  offsets.back() = lastOffset;
78  auto sizes = llvm::to_vector(shape);
79  sizes.back() = 1;
80  SmallVector<int64_t> strides(shape.size(), 1);
81 
82  return rewriter.create<vector::ExtractStridedSliceOp>(loc, input, offsets,
83  sizes, strides);
84 }
85 
86 /// Extracts two vector slices from the `input` whose type is `vector<...x2T>`,
87 /// with the first element at offset 0 and the second element at offset 1.
88 static std::pair<Value, Value>
90  Value input) {
91  return {extractLastDimSlice(rewriter, loc, input, 0),
92  extractLastDimSlice(rewriter, loc, input, 1)};
93 }
94 
95 // Performs a vector shape cast to drop the trailing x1 dimension. If the
96 // `input` is a scalar, this is a noop.
98  Location loc, Value input) {
99  auto vecTy = dyn_cast<VectorType>(input.getType());
100  if (!vecTy)
101  return input;
102 
103  // Shape cast to drop the last x1 dimension.
104  ArrayRef<int64_t> shape = vecTy.getShape();
105  assert(shape.size() >= 2 && "Expected vector with at list two dims");
106  assert(shape.back() == 1 && "Expected the last vector dim to be x1");
107 
108  auto newVecTy = VectorType::get(shape.drop_back(), vecTy.getElementType());
109  return rewriter.create<vector::ShapeCastOp>(loc, newVecTy, input);
110 }
111 
112 /// Performs a vector shape cast to append an x1 dimension. If the
113 /// `input` is a scalar, this is a noop.
115  Value input) {
116  auto vecTy = dyn_cast<VectorType>(input.getType());
117  if (!vecTy)
118  return input;
119 
120  // Add a trailing x1 dim.
121  auto newShape = llvm::to_vector(vecTy.getShape());
122  newShape.push_back(1);
123  auto newTy = VectorType::get(newShape, vecTy.getElementType());
124  return rewriter.create<vector::ShapeCastOp>(loc, newTy, input);
125 }
126 
127 /// Inserts the `source` vector slice into the `dest` vector at offset
128 /// `lastOffset` in the last dimension. `source` can be a scalar when `dest` is
129 /// a 1D vector.
131  Location loc, Value source, Value dest,
132  int64_t lastOffset) {
133  ArrayRef<int64_t> shape = cast<VectorType>(dest.getType()).getShape();
134  assert(lastOffset < shape.back() && "Offset out of bounds");
135 
136  // Handle scalar source.
137  if (isa<IntegerType>(source.getType()))
138  return rewriter.create<vector::InsertOp>(loc, source, dest, lastOffset);
139 
140  SmallVector<int64_t> offsets(shape.size(), 0);
141  offsets.back() = lastOffset;
142  SmallVector<int64_t> strides(shape.size(), 1);
143  return rewriter.create<vector::InsertStridedSliceOp>(loc, source, dest,
144  offsets, strides);
145 }
146 
147 /// Constructs a new vector of type `resultType` by creating a series of
148 /// insertions of `resultComponents`, each at the next offset of the last vector
149 /// dimension.
150 /// When all `resultComponents` are scalars, the result type is `vector<NxT>`;
151 /// when `resultComponents` are `vector<...x1xT>`s, the result type is
152 /// `vector<...xNxT>`, where `N` is the number of `resultComponents`.
154  Location loc, VectorType resultType,
155  ValueRange resultComponents) {
156  llvm::ArrayRef<int64_t> resultShape = resultType.getShape();
157  (void)resultShape;
158  assert(!resultShape.empty() && "Result expected to have dimensions");
159  assert(resultShape.back() == static_cast<int64_t>(resultComponents.size()) &&
160  "Wrong number of result components");
161 
162  Value resultVec = createScalarOrSplatConstant(rewriter, loc, resultType, 0);
163  for (auto [i, component] : llvm::enumerate(resultComponents))
164  resultVec = insertLastDimSlice(rewriter, loc, component, resultVec, i);
165 
166  return resultVec;
167 }
168 
169 namespace {
170 //===----------------------------------------------------------------------===//
171 // ConvertConstant
172 //===----------------------------------------------------------------------===//
173 
174 struct ConvertConstant final : OpConversionPattern<arith::ConstantOp> {
176 
177  LogicalResult
178  matchAndRewrite(arith::ConstantOp op, OpAdaptor,
179  ConversionPatternRewriter &rewriter) const override {
180  Type oldType = op.getType();
181  auto newType = getTypeConverter()->convertType<VectorType>(oldType);
182  if (!newType)
183  return rewriter.notifyMatchFailure(
184  op, llvm::formatv("unsupported type: {0}", op.getType()));
185 
186  unsigned newBitWidth = newType.getElementTypeBitWidth();
187  Attribute oldValue = op.getValueAttr();
188 
189  if (auto intAttr = dyn_cast<IntegerAttr>(oldValue)) {
190  auto [low, high] = getHalves(intAttr.getValue(), newBitWidth);
191  auto newAttr = DenseElementsAttr::get(newType, {low, high});
192  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
193  return success();
194  }
195 
196  if (auto splatAttr = dyn_cast<SplatElementsAttr>(oldValue)) {
197  auto [low, high] =
198  getHalves(splatAttr.getSplatValue<APInt>(), newBitWidth);
199  int64_t numSplatElems = splatAttr.getNumElements();
200  SmallVector<APInt> values;
201  values.reserve(numSplatElems * 2);
202  for (int64_t i = 0; i < numSplatElems; ++i) {
203  values.push_back(low);
204  values.push_back(high);
205  }
206 
207  auto attr = DenseElementsAttr::get(newType, values);
208  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, attr);
209  return success();
210  }
211 
212  if (auto elemsAttr = dyn_cast<DenseElementsAttr>(oldValue)) {
213  int64_t numElems = elemsAttr.getNumElements();
214  SmallVector<APInt> values;
215  values.reserve(numElems * 2);
216  for (const APInt &origVal : elemsAttr.getValues<APInt>()) {
217  auto [low, high] = getHalves(origVal, newBitWidth);
218  values.push_back(std::move(low));
219  values.push_back(std::move(high));
220  }
221 
222  auto attr = DenseElementsAttr::get(newType, values);
223  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, attr);
224  return success();
225  }
226 
227  return rewriter.notifyMatchFailure(op.getLoc(),
228  "unhandled constant attribute");
229  }
230 };
231 
232 //===----------------------------------------------------------------------===//
233 // ConvertAddI
234 //===----------------------------------------------------------------------===//
235 
236 struct ConvertAddI final : OpConversionPattern<arith::AddIOp> {
238 
239  LogicalResult
240  matchAndRewrite(arith::AddIOp op, OpAdaptor adaptor,
241  ConversionPatternRewriter &rewriter) const override {
242  Location loc = op->getLoc();
243  auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
244  if (!newTy)
245  return rewriter.notifyMatchFailure(
246  loc, llvm::formatv("unsupported type: {0}", op.getType()));
247 
248  Type newElemTy = reduceInnermostDim(newTy);
249 
250  auto [lhsElem0, lhsElem1] =
251  extractLastDimHalves(rewriter, loc, adaptor.getLhs());
252  auto [rhsElem0, rhsElem1] =
253  extractLastDimHalves(rewriter, loc, adaptor.getRhs());
254 
255  auto lowSum =
256  rewriter.create<arith::AddUIExtendedOp>(loc, lhsElem0, rhsElem0);
257  Value overflowVal =
258  rewriter.create<arith::ExtUIOp>(loc, newElemTy, lowSum.getOverflow());
259 
260  Value high0 = rewriter.create<arith::AddIOp>(loc, overflowVal, lhsElem1);
261  Value high = rewriter.create<arith::AddIOp>(loc, high0, rhsElem1);
262 
263  Value resultVec =
264  constructResultVector(rewriter, loc, newTy, {lowSum.getSum(), high});
265  rewriter.replaceOp(op, resultVec);
266  return success();
267  }
268 };
269 
270 //===----------------------------------------------------------------------===//
271 // ConvertBitwiseBinary
272 //===----------------------------------------------------------------------===//
273 
274 /// Conversion pattern template for bitwise binary ops, e.g., `arith.andi`.
275 template <typename BinaryOp>
276 struct ConvertBitwiseBinary final : OpConversionPattern<BinaryOp> {
278  using OpAdaptor = typename OpConversionPattern<BinaryOp>::OpAdaptor;
279 
280  LogicalResult
281  matchAndRewrite(BinaryOp op, OpAdaptor adaptor,
282  ConversionPatternRewriter &rewriter) const override {
283  Location loc = op->getLoc();
284  auto newTy = this->getTypeConverter()->template convertType<VectorType>(
285  op.getType());
286  if (!newTy)
287  return rewriter.notifyMatchFailure(
288  loc, llvm::formatv("unsupported type: {0}", op.getType()));
289 
290  auto [lhsElem0, lhsElem1] =
291  extractLastDimHalves(rewriter, loc, adaptor.getLhs());
292  auto [rhsElem0, rhsElem1] =
293  extractLastDimHalves(rewriter, loc, adaptor.getRhs());
294 
295  Value resElem0 = rewriter.create<BinaryOp>(loc, lhsElem0, rhsElem0);
296  Value resElem1 = rewriter.create<BinaryOp>(loc, lhsElem1, rhsElem1);
297  Value resultVec =
298  constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
299  rewriter.replaceOp(op, resultVec);
300  return success();
301  }
302 };
303 
304 //===----------------------------------------------------------------------===//
305 // ConvertCmpI
306 //===----------------------------------------------------------------------===//
307 
308 /// Returns the matching unsigned version of the given predicate `pred`, or the
309 /// same predicate if `pred` is not a signed.
310 static arith::CmpIPredicate toUnsignedPredicate(arith::CmpIPredicate pred) {
311  using P = arith::CmpIPredicate;
312  switch (pred) {
313  case P::sge:
314  return P::uge;
315  case P::sgt:
316  return P::ugt;
317  case P::sle:
318  return P::ule;
319  case P::slt:
320  return P::ult;
321  default:
322  return pred;
323  }
324 }
325 
326 struct ConvertCmpI final : OpConversionPattern<arith::CmpIOp> {
328 
329  LogicalResult
330  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
331  ConversionPatternRewriter &rewriter) const override {
332  Location loc = op->getLoc();
333  auto inputTy =
334  getTypeConverter()->convertType<VectorType>(op.getLhs().getType());
335  if (!inputTy)
336  return rewriter.notifyMatchFailure(
337  loc, llvm::formatv("unsupported type: {0}", op.getType()));
338 
339  arith::CmpIPredicate highPred = adaptor.getPredicate();
340  arith::CmpIPredicate lowPred = toUnsignedPredicate(highPred);
341 
342  auto [lhsElem0, lhsElem1] =
343  extractLastDimHalves(rewriter, loc, adaptor.getLhs());
344  auto [rhsElem0, rhsElem1] =
345  extractLastDimHalves(rewriter, loc, adaptor.getRhs());
346 
347  Value lowCmp =
348  rewriter.create<arith::CmpIOp>(loc, lowPred, lhsElem0, rhsElem0);
349  Value highCmp =
350  rewriter.create<arith::CmpIOp>(loc, highPred, lhsElem1, rhsElem1);
351 
352  Value cmpResult{};
353  switch (highPred) {
354  case arith::CmpIPredicate::eq: {
355  cmpResult = rewriter.create<arith::AndIOp>(loc, lowCmp, highCmp);
356  break;
357  }
358  case arith::CmpIPredicate::ne: {
359  cmpResult = rewriter.create<arith::OrIOp>(loc, lowCmp, highCmp);
360  break;
361  }
362  default: {
363  // Handle inequality checks.
364  Value highEq = rewriter.create<arith::CmpIOp>(
365  loc, arith::CmpIPredicate::eq, lhsElem1, rhsElem1);
366  cmpResult =
367  rewriter.create<arith::SelectOp>(loc, highEq, lowCmp, highCmp);
368  break;
369  }
370  }
371 
372  assert(cmpResult && "Unhandled case");
373  rewriter.replaceOp(op, dropTrailingX1Dim(rewriter, loc, cmpResult));
374  return success();
375  }
376 };
377 
378 //===----------------------------------------------------------------------===//
379 // ConvertMulI
380 //===----------------------------------------------------------------------===//
381 
382 struct ConvertMulI final : OpConversionPattern<arith::MulIOp> {
384 
385  LogicalResult
386  matchAndRewrite(arith::MulIOp op, OpAdaptor adaptor,
387  ConversionPatternRewriter &rewriter) const override {
388  Location loc = op->getLoc();
389  auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
390  if (!newTy)
391  return rewriter.notifyMatchFailure(
392  loc, llvm::formatv("unsupported type: {0}", op.getType()));
393 
394  auto [lhsElem0, lhsElem1] =
395  extractLastDimHalves(rewriter, loc, adaptor.getLhs());
396  auto [rhsElem0, rhsElem1] =
397  extractLastDimHalves(rewriter, loc, adaptor.getRhs());
398 
399  // The multiplication algorithm used is the standard (long) multiplication.
400  // Multiplying two i2N integers produces (at most) an i4N result, but
401  // because the calculation of top i2N is not necessary, we omit it.
402  auto mulLowLow =
403  rewriter.create<arith::MulUIExtendedOp>(loc, lhsElem0, rhsElem0);
404  Value mulLowHi = rewriter.create<arith::MulIOp>(loc, lhsElem0, rhsElem1);
405  Value mulHiLow = rewriter.create<arith::MulIOp>(loc, lhsElem1, rhsElem0);
406 
407  Value resLow = mulLowLow.getLow();
408  Value resHi =
409  rewriter.create<arith::AddIOp>(loc, mulLowLow.getHigh(), mulLowHi);
410  resHi = rewriter.create<arith::AddIOp>(loc, resHi, mulHiLow);
411 
412  Value resultVec =
413  constructResultVector(rewriter, loc, newTy, {resLow, resHi});
414  rewriter.replaceOp(op, resultVec);
415  return success();
416  }
417 };
418 
419 //===----------------------------------------------------------------------===//
420 // ConvertExtSI
421 //===----------------------------------------------------------------------===//
422 
423 struct ConvertExtSI final : OpConversionPattern<arith::ExtSIOp> {
425 
426  LogicalResult
427  matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
428  ConversionPatternRewriter &rewriter) const override {
429  Location loc = op->getLoc();
430  auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
431  if (!newTy)
432  return rewriter.notifyMatchFailure(
433  loc, llvm::formatv("unsupported type: {0}", op.getType()));
434 
435  Type newResultComponentTy = reduceInnermostDim(newTy);
436 
437  // Sign-extend the input value to determine the low half of the result.
438  // Then, check if the low half is negative, and sign-extend the comparison
439  // result to get the high half.
440  Value newOperand = appendX1Dim(rewriter, loc, adaptor.getIn());
441  Value extended = rewriter.createOrFold<arith::ExtSIOp>(
442  loc, newResultComponentTy, newOperand);
443  Value operandZeroCst =
444  createScalarOrSplatConstant(rewriter, loc, newResultComponentTy, 0);
445  Value signBit = rewriter.create<arith::CmpIOp>(
446  loc, arith::CmpIPredicate::slt, extended, operandZeroCst);
447  Value signValue =
448  rewriter.create<arith::ExtSIOp>(loc, newResultComponentTy, signBit);
449 
450  Value resultVec =
451  constructResultVector(rewriter, loc, newTy, {extended, signValue});
452  rewriter.replaceOp(op, resultVec);
453  return success();
454  }
455 };
456 
457 //===----------------------------------------------------------------------===//
458 // ConvertExtUI
459 //===----------------------------------------------------------------------===//
460 
461 struct ConvertExtUI final : OpConversionPattern<arith::ExtUIOp> {
463 
464  LogicalResult
465  matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
466  ConversionPatternRewriter &rewriter) const override {
467  Location loc = op->getLoc();
468  auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
469  if (!newTy)
470  return rewriter.notifyMatchFailure(
471  loc, llvm::formatv("unsupported type: {0}", op.getType()));
472 
473  Type newResultComponentTy = reduceInnermostDim(newTy);
474 
475  // Zero-extend the input value to determine the low half of the result.
476  // The high half is always zero.
477  Value newOperand = appendX1Dim(rewriter, loc, adaptor.getIn());
478  Value extended = rewriter.createOrFold<arith::ExtUIOp>(
479  loc, newResultComponentTy, newOperand);
480  Value zeroCst = createScalarOrSplatConstant(rewriter, loc, newTy, 0);
481  Value newRes = insertLastDimSlice(rewriter, loc, extended, zeroCst, 0);
482  rewriter.replaceOp(op, newRes);
483  return success();
484  }
485 };
486 
487 //===----------------------------------------------------------------------===//
488 // ConvertMaxMin
489 //===----------------------------------------------------------------------===//
490 
491 template <typename SourceOp, arith::CmpIPredicate CmpPred>
492 struct ConvertMaxMin final : OpConversionPattern<SourceOp> {
494 
495  LogicalResult
496  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
497  ConversionPatternRewriter &rewriter) const override {
498  Location loc = op->getLoc();
499 
500  Type oldTy = op.getType();
501  auto newTy = dyn_cast_or_null<VectorType>(
502  this->getTypeConverter()->convertType(oldTy));
503  if (!newTy)
504  return rewriter.notifyMatchFailure(
505  loc, llvm::formatv("unsupported type: {0}", op.getType()));
506 
507  // Rewrite Max*I/Min*I as compare and select over original operands. Let
508  // the CmpI and Select emulation patterns handle the final legalization.
509  Value cmp =
510  rewriter.create<arith::CmpIOp>(loc, CmpPred, op.getLhs(), op.getRhs());
511  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, op.getLhs(),
512  op.getRhs());
513  return success();
514  }
515 };
516 
//===----------------------------------------------------------------------===//
// Convert IndexCast ops
//===----------------------------------------------------------------------===//
519 
520 /// Returns true iff the type is `index` or `vector<...index>`.
521 static bool isIndexOrIndexVector(Type type) {
522  if (isa<IndexType>(type))
523  return true;
524 
525  if (auto vectorTy = dyn_cast<VectorType>(type))
526  if (isa<IndexType>(vectorTy.getElementType()))
527  return true;
528 
529  return false;
530 }
531 
532 template <typename CastOp>
533 struct ConvertIndexCastIntToIndex final : OpConversionPattern<CastOp> {
535 
536  LogicalResult
537  matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
538  ConversionPatternRewriter &rewriter) const override {
539  Type resultType = op.getType();
540  if (!isIndexOrIndexVector(resultType))
541  return failure();
542 
543  Location loc = op.getLoc();
544  Type inType = op.getIn().getType();
545  auto newInTy =
546  this->getTypeConverter()->template convertType<VectorType>(inType);
547  if (!newInTy)
548  return rewriter.notifyMatchFailure(
549  loc, llvm::formatv("unsupported type: {0}", inType));
550 
551  // Discard the high half of the input truncating the original value.
552  Value extracted = extractLastDimSlice(rewriter, loc, adaptor.getIn(), 0);
553  extracted = dropTrailingX1Dim(rewriter, loc, extracted);
554  rewriter.replaceOpWithNewOp<CastOp>(op, resultType, extracted);
555  return success();
556  }
557 };
558 
559 template <typename CastOp, typename ExtensionOp>
560 struct ConvertIndexCastIndexToInt final : OpConversionPattern<CastOp> {
562 
563  LogicalResult
564  matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
565  ConversionPatternRewriter &rewriter) const override {
566  Type inType = op.getIn().getType();
567  if (!isIndexOrIndexVector(inType))
568  return failure();
569 
570  Location loc = op.getLoc();
571  auto *typeConverter =
572  this->template getTypeConverter<arith::WideIntEmulationConverter>();
573 
574  Type resultType = op.getType();
575  auto newTy = typeConverter->template convertType<VectorType>(resultType);
576  if (!newTy)
577  return rewriter.notifyMatchFailure(
578  loc, llvm::formatv("unsupported type: {0}", resultType));
579 
580  // Emit an index cast over the matching narrow type.
581  Type narrowTy =
582  rewriter.getIntegerType(typeConverter->getMaxTargetIntBitWidth());
583  if (auto vecTy = dyn_cast<VectorType>(resultType))
584  narrowTy = VectorType::get(vecTy.getShape(), narrowTy);
585 
586  // Sign or zero-extend the result. Let the matching conversion pattern
587  // legalize the extension op.
588  Value underlyingVal =
589  rewriter.create<CastOp>(loc, narrowTy, adaptor.getIn());
590  rewriter.replaceOpWithNewOp<ExtensionOp>(op, resultType, underlyingVal);
591  return success();
592  }
593 };
594 
595 //===----------------------------------------------------------------------===//
596 // ConvertSelect
597 //===----------------------------------------------------------------------===//
598 
599 struct ConvertSelect final : OpConversionPattern<arith::SelectOp> {
601 
602  LogicalResult
603  matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
604  ConversionPatternRewriter &rewriter) const override {
605  Location loc = op->getLoc();
606  auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
607  if (!newTy)
608  return rewriter.notifyMatchFailure(
609  loc, llvm::formatv("unsupported type: {0}", op.getType()));
610 
611  auto [trueElem0, trueElem1] =
612  extractLastDimHalves(rewriter, loc, adaptor.getTrueValue());
613  auto [falseElem0, falseElem1] =
614  extractLastDimHalves(rewriter, loc, adaptor.getFalseValue());
615  Value cond = appendX1Dim(rewriter, loc, adaptor.getCondition());
616 
617  Value resElem0 =
618  rewriter.create<arith::SelectOp>(loc, cond, trueElem0, falseElem0);
619  Value resElem1 =
620  rewriter.create<arith::SelectOp>(loc, cond, trueElem1, falseElem1);
621  Value resultVec =
622  constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
623  rewriter.replaceOp(op, resultVec);
624  return success();
625  }
626 };
627 
628 //===----------------------------------------------------------------------===//
629 // ConvertShLI
630 //===----------------------------------------------------------------------===//
631 
632 struct ConvertShLI final : OpConversionPattern<arith::ShLIOp> {
634 
635  LogicalResult
636  matchAndRewrite(arith::ShLIOp op, OpAdaptor adaptor,
637  ConversionPatternRewriter &rewriter) const override {
638  Location loc = op->getLoc();
639 
640  Type oldTy = op.getType();
641  auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
642  if (!newTy)
643  return rewriter.notifyMatchFailure(
644  loc, llvm::formatv("unsupported type: {0}", op.getType()));
645 
646  Type newOperandTy = reduceInnermostDim(newTy);
647  // `oldBitWidth` == `2 * newBitWidth`
648  unsigned newBitWidth = newTy.getElementTypeBitWidth();
649 
650  auto [lhsElem0, lhsElem1] =
651  extractLastDimHalves(rewriter, loc, adaptor.getLhs());
652  Value rhsElem0 = extractLastDimSlice(rewriter, loc, adaptor.getRhs(), 0);
653 
654  // Assume that the shift amount is < 2 * newBitWidth. Calculate the low and
655  // high halves of the results separately:
656  // 1. low := LHS.low shli RHS
657  //
658  // 2. high := a or b or c, where:
659  // a) Bits from LHS.high, shifted by the RHS.
660  // b) Bits from LHS.low, shifted right. These come into play when
661  // RHS < newBitWidth, e.g.:
662  // [0000][llll] shli 3 --> [0lll][l000]
663  // ^
664  // |
665  // [llll] shrui (4 - 3)
666  // c) Bits from LHS.low, shifted left. These matter when
667  // RHS > newBitWidth, e.g.:
668  // [0000][llll] shli 7 --> [l000][0000]
669  // ^
670  // |
671  // [llll] shli (7 - 4)
672  //
673  // Because shifts by values >= newBitWidth are undefined, we ignore the high
674  // half of RHS, and introduce 'bounds checks' to account for
675  // RHS.low > newBitWidth.
676  //
677  // TODO: Explore possible optimizations.
678  Value zeroCst = createScalarOrSplatConstant(rewriter, loc, newOperandTy, 0);
679  Value elemBitWidth =
680  createScalarOrSplatConstant(rewriter, loc, newOperandTy, newBitWidth);
681 
682  Value illegalElemShift = rewriter.create<arith::CmpIOp>(
683  loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth);
684 
685  Value shiftedElem0 =
686  rewriter.create<arith::ShLIOp>(loc, lhsElem0, rhsElem0);
687  Value resElem0 = rewriter.create<arith::SelectOp>(loc, illegalElemShift,
688  zeroCst, shiftedElem0);
689 
690  Value cappedShiftAmount = rewriter.create<arith::SelectOp>(
691  loc, illegalElemShift, elemBitWidth, rhsElem0);
692  Value rightShiftAmount =
693  rewriter.create<arith::SubIOp>(loc, elemBitWidth, cappedShiftAmount);
694  Value shiftedRight =
695  rewriter.create<arith::ShRUIOp>(loc, lhsElem0, rightShiftAmount);
696  Value overshotShiftAmount =
697  rewriter.create<arith::SubIOp>(loc, rhsElem0, elemBitWidth);
698  Value shiftedLeft =
699  rewriter.create<arith::ShLIOp>(loc, lhsElem0, overshotShiftAmount);
700 
701  Value shiftedElem1 =
702  rewriter.create<arith::ShLIOp>(loc, lhsElem1, rhsElem0);
703  Value resElem1High = rewriter.create<arith::SelectOp>(
704  loc, illegalElemShift, zeroCst, shiftedElem1);
705  Value resElem1Low = rewriter.create<arith::SelectOp>(
706  loc, illegalElemShift, shiftedLeft, shiftedRight);
707  Value resElem1 =
708  rewriter.create<arith::OrIOp>(loc, resElem1Low, resElem1High);
709 
710  Value resultVec =
711  constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
712  rewriter.replaceOp(op, resultVec);
713  return success();
714  }
715 };
716 
717 //===----------------------------------------------------------------------===//
718 // ConvertShRUI
719 //===----------------------------------------------------------------------===//
720 
721 struct ConvertShRUI final : OpConversionPattern<arith::ShRUIOp> {
723 
724  LogicalResult
725  matchAndRewrite(arith::ShRUIOp op, OpAdaptor adaptor,
726  ConversionPatternRewriter &rewriter) const override {
727  Location loc = op->getLoc();
728 
729  Type oldTy = op.getType();
730  auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
731  if (!newTy)
732  return rewriter.notifyMatchFailure(
733  loc, llvm::formatv("unsupported type: {0}", op.getType()));
734 
735  Type newOperandTy = reduceInnermostDim(newTy);
736  // `oldBitWidth` == `2 * newBitWidth`
737  unsigned newBitWidth = newTy.getElementTypeBitWidth();
738 
739  auto [lhsElem0, lhsElem1] =
740  extractLastDimHalves(rewriter, loc, adaptor.getLhs());
741  Value rhsElem0 = extractLastDimSlice(rewriter, loc, adaptor.getRhs(), 0);
742 
743  // Assume that the shift amount is < 2 * newBitWidth. Calculate the low and
744  // high halves of the results separately:
745  // 1. low := a or b or c, where:
746  // a) Bits from LHS.low, shifted by the RHS.
747  // b) Bits from LHS.high, shifted left. These matter when
748  // RHS < newBitWidth, e.g.:
749  // [hhhh][0000] shrui 3 --> [000h][hhh0]
750  // ^
751  // |
752  // [hhhh] shli (4 - 1)
753  // c) Bits from LHS.high, shifted right. These come into play when
754  // RHS > newBitWidth, e.g.:
755  // [hhhh][0000] shrui 7 --> [0000][000h]
756  // ^
757  // |
758  // [hhhh] shrui (7 - 4)
759  //
760  // 2. high := LHS.high shrui RHS
761  //
762  // Because shifts by values >= newBitWidth are undefined, we ignore the high
763  // half of RHS, and introduce 'bounds checks' to account for
764  // RHS.low > newBitWidth.
765  //
766  // TODO: Explore possible optimizations.
767  Value zeroCst = createScalarOrSplatConstant(rewriter, loc, newOperandTy, 0);
768  Value elemBitWidth =
769  createScalarOrSplatConstant(rewriter, loc, newOperandTy, newBitWidth);
770 
771  Value illegalElemShift = rewriter.create<arith::CmpIOp>(
772  loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth);
773 
774  Value shiftedElem0 =
775  rewriter.create<arith::ShRUIOp>(loc, lhsElem0, rhsElem0);
776  Value resElem0Low = rewriter.create<arith::SelectOp>(loc, illegalElemShift,
777  zeroCst, shiftedElem0);
778  Value shiftedElem1 =
779  rewriter.create<arith::ShRUIOp>(loc, lhsElem1, rhsElem0);
780  Value resElem1 = rewriter.create<arith::SelectOp>(loc, illegalElemShift,
781  zeroCst, shiftedElem1);
782 
783  Value cappedShiftAmount = rewriter.create<arith::SelectOp>(
784  loc, illegalElemShift, elemBitWidth, rhsElem0);
785  Value leftShiftAmount =
786  rewriter.create<arith::SubIOp>(loc, elemBitWidth, cappedShiftAmount);
787  Value shiftedLeft =
788  rewriter.create<arith::ShLIOp>(loc, lhsElem1, leftShiftAmount);
789  Value overshotShiftAmount =
790  rewriter.create<arith::SubIOp>(loc, rhsElem0, elemBitWidth);
791  Value shiftedRight =
792  rewriter.create<arith::ShRUIOp>(loc, lhsElem1, overshotShiftAmount);
793 
794  Value resElem0High = rewriter.create<arith::SelectOp>(
795  loc, illegalElemShift, shiftedRight, shiftedLeft);
796  Value resElem0 =
797  rewriter.create<arith::OrIOp>(loc, resElem0Low, resElem0High);
798 
799  Value resultVec =
800  constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
801  rewriter.replaceOp(op, resultVec);
802  return success();
803  }
804 };
805 
806 //===----------------------------------------------------------------------===//
807 // ConvertShRSI
808 //===----------------------------------------------------------------------===//
809 
810 struct ConvertShRSI final : OpConversionPattern<arith::ShRSIOp> {
812 
813  LogicalResult
814  matchAndRewrite(arith::ShRSIOp op, OpAdaptor adaptor,
815  ConversionPatternRewriter &rewriter) const override {
816  Location loc = op->getLoc();
817 
818  Type oldTy = op.getType();
819  auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
820  if (!newTy)
821  return rewriter.notifyMatchFailure(
822  loc, llvm::formatv("unsupported type: {0}", op.getType()));
823 
824  Value lhsElem1 = extractLastDimSlice(rewriter, loc, adaptor.getLhs(), 1);
825  Value rhsElem0 = extractLastDimSlice(rewriter, loc, adaptor.getRhs(), 0);
826 
827  Type narrowTy = rhsElem0.getType();
828  int64_t origBitwidth = newTy.getElementTypeBitWidth() * 2;
829 
830  // Rewrite this as an bitwise or of `arith.shrui` and sign extension bits.
831  // Perform as many ops over the narrow integer type as possible and let the
832  // other emulation patterns convert the rest.
833  Value elemZero = createScalarOrSplatConstant(rewriter, loc, narrowTy, 0);
834  Value signBit = rewriter.create<arith::CmpIOp>(
835  loc, arith::CmpIPredicate::slt, lhsElem1, elemZero);
836  signBit = dropTrailingX1Dim(rewriter, loc, signBit);
837 
838  // Create a bit pattern of either all ones or all zeros. Then shift it left
839  // to calculate the sign extension bits created by shifting the original
840  // sign bit right.
841  Value allSign = rewriter.create<arith::ExtSIOp>(loc, oldTy, signBit);
842  Value maxShift =
843  createScalarOrSplatConstant(rewriter, loc, narrowTy, origBitwidth);
844  Value numNonSignExtBits =
845  rewriter.create<arith::SubIOp>(loc, maxShift, rhsElem0);
846  numNonSignExtBits = dropTrailingX1Dim(rewriter, loc, numNonSignExtBits);
847  numNonSignExtBits =
848  rewriter.create<arith::ExtUIOp>(loc, oldTy, numNonSignExtBits);
849  Value signBits =
850  rewriter.create<arith::ShLIOp>(loc, allSign, numNonSignExtBits);
851 
852  // Use original arguments to create the right shift.
853  Value shrui =
854  rewriter.create<arith::ShRUIOp>(loc, op.getLhs(), op.getRhs());
855  Value shrsi = rewriter.create<arith::OrIOp>(loc, shrui, signBits);
856 
857  // Handle shifting by zero. This is necessary when the `signBits` shift is
858  // invalid.
859  Value isNoop = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
860  rhsElem0, elemZero);
861  isNoop = dropTrailingX1Dim(rewriter, loc, isNoop);
862  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNoop, op.getLhs(),
863  shrsi);
864 
865  return success();
866  }
867 };
868 
869 //===----------------------------------------------------------------------===//
870 // ConvertSIToFP
871 //===----------------------------------------------------------------------===//
872 
873 struct ConvertSIToFP final : OpConversionPattern<arith::SIToFPOp> {
875 
876  LogicalResult
877  matchAndRewrite(arith::SIToFPOp op, OpAdaptor adaptor,
878  ConversionPatternRewriter &rewriter) const override {
879  Location loc = op.getLoc();
880 
881  Value in = op.getIn();
882  Type oldTy = in.getType();
883  auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
884  if (!newTy)
885  return rewriter.notifyMatchFailure(
886  loc, llvm::formatv("unsupported type: {0}", oldTy));
887 
888  unsigned oldBitWidth = getElementTypeOrSelf(oldTy).getIntOrFloatBitWidth();
889  Value zeroCst = createScalarOrSplatConstant(rewriter, loc, oldTy, 0);
890  Value oneCst = createScalarOrSplatConstant(rewriter, loc, oldTy, 1);
891  Value allOnesCst = createScalarOrSplatConstant(
892  rewriter, loc, oldTy, APInt::getAllOnes(oldBitWidth));
893 
894  // To avoid operating on very large unsigned numbers, perform the
895  // conversion on the absolute value. Then, decide whether to negate the
896  // result or not based on that sign bit. We assume two's complement and
897  // implement negation by flipping all bits and adding 1.
898  // Note that this relies on the the other conversion patterns to legalize
899  // created ops and narrow the bit widths.
900  Value isNeg = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
901  in, zeroCst);
902  Value bitwiseNeg = rewriter.create<arith::XOrIOp>(loc, in, allOnesCst);
903  Value neg = rewriter.create<arith::AddIOp>(loc, bitwiseNeg, oneCst);
904  Value abs = rewriter.create<arith::SelectOp>(loc, isNeg, neg, in);
905 
906  Value absResult = rewriter.create<arith::UIToFPOp>(loc, op.getType(), abs);
907  Value negResult = rewriter.create<arith::NegFOp>(loc, absResult);
908  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNeg, negResult,
909  absResult);
910  return success();
911  }
912 };
913 
914 //===----------------------------------------------------------------------===//
915 // ConvertUIToFP
916 //===----------------------------------------------------------------------===//
917 
918 struct ConvertUIToFP final : OpConversionPattern<arith::UIToFPOp> {
920 
921  LogicalResult
922  matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
923  ConversionPatternRewriter &rewriter) const override {
924  Location loc = op.getLoc();
925 
926  Type oldTy = op.getIn().getType();
927  auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
928  if (!newTy)
929  return rewriter.notifyMatchFailure(
930  loc, llvm::formatv("unsupported type: {0}", oldTy));
931  unsigned newBitWidth = newTy.getElementTypeBitWidth();
932 
933  auto [low, hi] = extractLastDimHalves(rewriter, loc, adaptor.getIn());
934  Value lowInt = dropTrailingX1Dim(rewriter, loc, low);
935  Value hiInt = dropTrailingX1Dim(rewriter, loc, hi);
936  Value zeroCst =
937  createScalarOrSplatConstant(rewriter, loc, hiInt.getType(), 0);
938 
939  // The final result has the following form:
940  // if (hi == 0) return uitofp(low)
941  // else return uitofp(low) + uitofp(hi) * 2^BW
942  //
943  // where `BW` is the bitwidth of the narrowed integer type. We emit a
944  // select to make it easier to fold-away the `hi` part calculation when it
945  // is known to be zero.
946  //
947  // Note 1: The emulation is precise only for input values that have exact
948  // integer representation in the result floating point type, and may lead
949  // loss of precision otherwise.
950  //
951  // Note 2: We do not strictly need the `hi == 0`, case, but it makes
952  // constant folding easier.
953  Value hiEqZero = rewriter.create<arith::CmpIOp>(
954  loc, arith::CmpIPredicate::eq, hiInt, zeroCst);
955 
956  Type resultTy = op.getType();
957  Type resultElemTy = getElementTypeOrSelf(resultTy);
958  Value lowFp = rewriter.create<arith::UIToFPOp>(loc, resultTy, lowInt);
959  Value hiFp = rewriter.create<arith::UIToFPOp>(loc, resultTy, hiInt);
960 
961  int64_t pow2Int = int64_t(1) << newBitWidth;
962  TypedAttr pow2Attr =
963  rewriter.getFloatAttr(resultElemTy, static_cast<double>(pow2Int));
964  if (auto vecTy = dyn_cast<VectorType>(resultTy))
965  pow2Attr = SplatElementsAttr::get(vecTy, pow2Attr);
966 
967  Value pow2Val = rewriter.create<arith::ConstantOp>(loc, resultTy, pow2Attr);
968 
969  Value hiVal = rewriter.create<arith::MulFOp>(loc, hiFp, pow2Val);
970  Value result = rewriter.create<arith::AddFOp>(loc, lowFp, hiVal);
971 
972  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, hiEqZero, lowFp, result);
973  return success();
974  }
975 };
976 
977 //===----------------------------------------------------------------------===//
978 // ConvertTruncI
979 //===----------------------------------------------------------------------===//
980 
981 struct ConvertTruncI final : OpConversionPattern<arith::TruncIOp> {
983 
984  LogicalResult
985  matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
986  ConversionPatternRewriter &rewriter) const override {
987  Location loc = op.getLoc();
988  // Check if the result type is legal for this target. Currently, we do not
989  // support truncation to types wider than supported by the target.
990  if (!getTypeConverter()->isLegal(op.getType()))
991  return rewriter.notifyMatchFailure(
992  loc, llvm::formatv("unsupported truncation result type: {0}",
993  op.getType()));
994 
995  // Discard the high half of the input. Truncate the low half, if
996  // necessary.
997  Value extracted = extractLastDimSlice(rewriter, loc, adaptor.getIn(), 0);
998  extracted = dropTrailingX1Dim(rewriter, loc, extracted);
999  Value truncated =
1000  rewriter.createOrFold<arith::TruncIOp>(loc, op.getType(), extracted);
1001  rewriter.replaceOp(op, truncated);
1002  return success();
1003  }
1004 };
1005 
1006 //===----------------------------------------------------------------------===//
1007 // ConvertVectorPrint
1008 //===----------------------------------------------------------------------===//
1009 
1010 struct ConvertVectorPrint final : OpConversionPattern<vector::PrintOp> {
1012 
1013  LogicalResult
1014  matchAndRewrite(vector::PrintOp op, OpAdaptor adaptor,
1015  ConversionPatternRewriter &rewriter) const override {
1016  rewriter.replaceOpWithNewOp<vector::PrintOp>(op, adaptor.getSource());
1017  return success();
1018  }
1019 };
1020 
1021 //===----------------------------------------------------------------------===//
1022 // Pass Definition
1023 //===----------------------------------------------------------------------===//
1024 
1025 struct EmulateWideIntPass final
1026  : arith::impl::ArithEmulateWideIntBase<EmulateWideIntPass> {
1027  using ArithEmulateWideIntBase::ArithEmulateWideIntBase;
1028 
1029  void runOnOperation() override {
1030  if (!llvm::isPowerOf2_32(widestIntSupported) || widestIntSupported < 2) {
1031  signalPassFailure();
1032  return;
1033  }
1034 
1035  Operation *op = getOperation();
1036  MLIRContext *ctx = op->getContext();
1037 
1038  arith::WideIntEmulationConverter typeConverter(widestIntSupported);
1039  ConversionTarget target(*ctx);
1040  target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](Operation *op) {
1041  return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType());
1042  });
1043  auto opLegalCallback = [&typeConverter](Operation *op) {
1044  return typeConverter.isLegal(op);
1045  };
1046  target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(opLegalCallback);
1047  target
1048  .addDynamicallyLegalDialect<arith::ArithDialect, vector::VectorDialect>(
1049  opLegalCallback);
1050 
1051  RewritePatternSet patterns(ctx);
1052  arith::populateArithWideIntEmulationPatterns(typeConverter, patterns);
1053 
1054  if (failed(applyPartialConversion(op, target, std::move(patterns))))
1055  signalPassFailure();
1056  }
1057 };
1058 } // end anonymous namespace
1059 
1060 //===----------------------------------------------------------------------===//
1061 // Public Interface Definition
1062 //===----------------------------------------------------------------------===//
1063 
1065  unsigned widestIntSupportedByTarget)
1066  : maxIntWidth(widestIntSupportedByTarget) {
1067  assert(llvm::isPowerOf2_32(widestIntSupportedByTarget) &&
1068  "Only power-of-two integers with are supported");
1069  assert(widestIntSupportedByTarget >= 2 && "Integer type too narrow");
1070 
1071  // Allow unknown types.
1072  addConversion([](Type ty) -> std::optional<Type> { return ty; });
1073 
1074  // Scalar case.
1075  addConversion([this](IntegerType ty) -> std::optional<Type> {
1076  unsigned width = ty.getWidth();
1077  if (width <= maxIntWidth)
1078  return ty;
1079 
1080  // i2N --> vector<2xiN>
1081  if (width == 2 * maxIntWidth)
1082  return VectorType::get(2, IntegerType::get(ty.getContext(), maxIntWidth));
1083 
1084  return nullptr;
1085  });
1086 
1087  // Vector case.
1088  addConversion([this](VectorType ty) -> std::optional<Type> {
1089  auto intTy = dyn_cast<IntegerType>(ty.getElementType());
1090  if (!intTy)
1091  return ty;
1092 
1093  unsigned width = intTy.getWidth();
1094  if (width <= maxIntWidth)
1095  return ty;
1096 
1097  // vector<...xi2N> --> vector<...x2xiN>
1098  if (width == 2 * maxIntWidth) {
1099  auto newShape = to_vector(ty.getShape());
1100  newShape.push_back(2);
1101  return VectorType::get(newShape,
1102  IntegerType::get(ty.getContext(), maxIntWidth));
1103  }
1104 
1105  return nullptr;
1106  });
1107 
1108  // Function case.
1109  addConversion([this](FunctionType ty) -> std::optional<Type> {
1110  // Convert inputs and results, e.g.:
1111  // (i2N, i2N) -> i2N --> (vector<2xiN>, vector<2xiN>) -> vector<2xiN>
1112  SmallVector<Type> inputs;
1113  if (failed(convertTypes(ty.getInputs(), inputs)))
1114  return nullptr;
1115 
1116  SmallVector<Type> results;
1117  if (failed(convertTypes(ty.getResults(), results)))
1118  return nullptr;
1119 
1120  return FunctionType::get(ty.getContext(), inputs, results);
1121  });
1122 }
1123 
1125  const WideIntEmulationConverter &typeConverter,
1126  RewritePatternSet &patterns) {
1127  // Populate `func.*` conversion patterns.
1128  populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
1129  typeConverter);
1130  populateCallOpTypeConversionPattern(patterns, typeConverter);
1131  populateReturnOpTypeConversionPattern(patterns, typeConverter);
1132 
1133  // Populate `arith.*` conversion patterns.
1134  patterns.add<
1135  // Misc ops.
1136  ConvertConstant, ConvertCmpI, ConvertSelect, ConvertVectorPrint,
1137  // Binary ops.
1138  ConvertAddI, ConvertMulI, ConvertShLI, ConvertShRSI, ConvertShRUI,
1139  ConvertMaxMin<arith::MaxUIOp, arith::CmpIPredicate::ugt>,
1140  ConvertMaxMin<arith::MaxSIOp, arith::CmpIPredicate::sgt>,
1141  ConvertMaxMin<arith::MinUIOp, arith::CmpIPredicate::ult>,
1142  ConvertMaxMin<arith::MinSIOp, arith::CmpIPredicate::slt>,
1143  // Bitwise binary ops.
1144  ConvertBitwiseBinary<arith::AndIOp>, ConvertBitwiseBinary<arith::OrIOp>,
1145  ConvertBitwiseBinary<arith::XOrIOp>,
1146  // Extension and truncation ops.
1147  ConvertExtSI, ConvertExtUI, ConvertTruncI,
1148  // Cast ops.
1149  ConvertIndexCastIntToIndex<arith::IndexCastOp>,
1150  ConvertIndexCastIntToIndex<arith::IndexCastUIOp>,
1151  ConvertIndexCastIndexToInt<arith::IndexCastOp, arith::ExtSIOp>,
1152  ConvertIndexCastIndexToInt<arith::IndexCastUIOp, arith::ExtUIOp>,
1153  ConvertSIToFP, ConvertUIToFP>(typeConverter, patterns.getContext());
1154 }
static Value insertLastDimSlice(ConversionPatternRewriter &rewriter, Location loc, Value source, Value dest, int64_t lastOffset)
Inserts the source vector slice into the dest vector at offset lastOffset in the last dimension.
static std::pair< APInt, APInt > getHalves(const APInt &value, unsigned newBitWidth)
Returns N bottom and N top bits from value, where N = newBitWidth.
static Value appendX1Dim(ConversionPatternRewriter &rewriter, Location loc, Value input)
Performs a vector shape cast to append an x1 dimension.
static std::pair< Value, Value > extractLastDimHalves(ConversionPatternRewriter &rewriter, Location loc, Value input)
Extracts two vector slices from the input whose type is vector<...x2T>, with the first element at off...
static Type reduceInnermostDim(VectorType type)
Returns the type with the last (innermost) dimension reduced to x1.
static Value constructResultVector(ConversionPatternRewriter &rewriter, Location loc, VectorType resultType, ValueRange resultComponents)
Constructs a new vector of type resultType by creating a series of insertions of resultComponents,...
static Value dropTrailingX1Dim(ConversionPatternRewriter &rewriter, Location loc, Value input)
static Value extractLastDimSlice(ConversionPatternRewriter &rewriter, Location loc, Value input, int64_t lastOffset)
Extracts the input vector slice with elements at the last dimension offset by lastOffset.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:294
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:111
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:528
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
typename SourceOp::Adaptor OpAdaptor
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
MLIRContext * getContext() const
Definition: PatternMatch.h:829
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:133
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Converts integer types that are too wide for the target by splitting them in two halves and thus turn...
WideIntEmulationConverter(unsigned widestIntSupportedByTarget)
void populateArithWideIntEmulationPatterns(const WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns)
Adds patterns to emulate wide Arith and Function ops over integer types into supported ones.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Fraction abs(const Fraction &f)
Definition: Fraction.h:107
Include the generated interface declarations.
Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, const APInt &value)
Create a constant of type type at location loc whose value is value (an APInt or APFloat whose type m...
Definition: Utils.cpp:271
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void populateCallOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the operand and result types of a CallOp with the ...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to rewrite return ops to use operands that have been legalize...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.