MLIR  17.0.0git
EmulateWideInt.cpp
Go to the documentation of this file.
1 //===- EmulateWideInt.cpp - Wide integer operation emulation ----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
18 #include "llvm/Support/FormatVariadic.h"
19 #include "llvm/Support/MathExtras.h"
20 #include <cassert>
21 
22 namespace mlir::arith {
23 #define GEN_PASS_DEF_ARITHEMULATEWIDEINT
24 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
25 } // namespace mlir::arith
26 
27 using namespace mlir;
28 
29 //===----------------------------------------------------------------------===//
30 // Common Helper Functions
31 //===----------------------------------------------------------------------===//
32 
33 /// Returns N bottom and N top bits from `value`, where N = `newBitWidth`.
34 /// Treats `value` as a 2*N bits-wide integer.
35 /// The bottom bits are returned in the first pair element, while the top bits
36 /// in the second one.
37 static std::pair<APInt, APInt> getHalves(const APInt &value,
38  unsigned newBitWidth) {
39  APInt low = value.extractBits(newBitWidth, 0);
40  APInt high = value.extractBits(newBitWidth, newBitWidth);
41  return {std::move(low), std::move(high)};
42 }
43 
44 /// Returns the type with the last (innermost) dimention reduced to x1.
45 /// Scalarizes 1D vector inputs to match how we extract/insert vector values,
46 /// e.g.:
47 /// - vector<3x2xi16> --> vector<3x1xi16>
48 /// - vector<2xi16> --> i16
49 static Type reduceInnermostDim(VectorType type) {
50  if (type.getShape().size() == 1)
51  return type.getElementType();
52 
53  auto newShape = to_vector(type.getShape());
54  newShape.back() = 1;
55  return VectorType::get(newShape, type.getElementType());
56 }
57 
58 /// Returns a constant of integer of vector type filled with (repeated) `value`.
60  Location loc, Type type,
61  const APInt &value) {
62  Attribute attr;
63  if (auto intTy = type.dyn_cast<IntegerType>()) {
64  attr = rewriter.getIntegerAttr(type, value);
65  } else {
66  auto vecTy = type.cast<VectorType>();
67  attr = SplatElementsAttr::get(vecTy, value);
68  }
69 
70  return rewriter.create<arith::ConstantOp>(loc, attr);
71 }
72 
73 /// Returns a constant of integer of vector type filled with (repeated) `value`.
75  Location loc, Type type,
76  int64_t value) {
77  unsigned elementBitWidth = 0;
78  if (auto intTy = type.dyn_cast<IntegerType>())
79  elementBitWidth = intTy.getWidth();
80  else
81  elementBitWidth = type.cast<VectorType>().getElementTypeBitWidth();
82 
83  return createScalarOrSplatConstant(rewriter, loc, type,
84  APInt(elementBitWidth, value));
85 }
86 
87 /// Extracts the `input` vector slice with elements at the last dimension offset
88 /// by `lastOffset`. Returns a value of vector type with the last dimension
89 /// reduced to x1 or fully scalarized, e.g.:
90 /// - vector<3x2xi16> --> vector<3x1xi16>
91 /// - vector<2xi16> --> i16
93  Location loc, Value input,
94  int64_t lastOffset) {
95  ArrayRef<int64_t> shape = input.getType().cast<VectorType>().getShape();
96  assert(lastOffset < shape.back() && "Offset out of bounds");
97 
98  // Scalarize the result in case of 1D vectors.
99  if (shape.size() == 1)
100  return rewriter.create<vector::ExtractOp>(loc, input, lastOffset);
101 
102  SmallVector<int64_t> offsets(shape.size(), 0);
103  offsets.back() = lastOffset;
104  auto sizes = llvm::to_vector(shape);
105  sizes.back() = 1;
106  SmallVector<int64_t> strides(shape.size(), 1);
107 
108  return rewriter.create<vector::ExtractStridedSliceOp>(loc, input, offsets,
109  sizes, strides);
110 }
111 
112 /// Extracts two vector slices from the `input` whose type is `vector<...x2T>`,
113 /// with the first element at offset 0 and the second element at offset 1.
114 static std::pair<Value, Value>
116  Value input) {
117  return {extractLastDimSlice(rewriter, loc, input, 0),
118  extractLastDimSlice(rewriter, loc, input, 1)};
119 }
120 
121 // Performs a vector shape cast to drop the trailing x1 dimension. If the
122 // `input` is a scalar, this is a noop.
124  Location loc, Value input) {
125  auto vecTy = input.getType().dyn_cast<VectorType>();
126  if (!vecTy)
127  return input;
128 
129  // Shape cast to drop the last x1 dimention.
130  ArrayRef<int64_t> shape = vecTy.getShape();
131  assert(shape.size() >= 2 && "Expected vector with at list two dims");
132  assert(shape.back() == 1 && "Expected the last vector dim to be x1");
133 
134  auto newVecTy = VectorType::get(shape.drop_back(), vecTy.getElementType());
135  return rewriter.create<vector::ShapeCastOp>(loc, newVecTy, input);
136 }
137 
138 /// Performs a vector shape cast to append an x1 dimension. If the
139 /// `input` is a scalar, this is a noop.
141  Value input) {
142  auto vecTy = input.getType().dyn_cast<VectorType>();
143  if (!vecTy)
144  return input;
145 
146  // Add a trailing x1 dim.
147  auto newShape = llvm::to_vector(vecTy.getShape());
148  newShape.push_back(1);
149  auto newTy = VectorType::get(newShape, vecTy.getElementType());
150  return rewriter.create<vector::ShapeCastOp>(loc, newTy, input);
151 }
152 
153 /// Inserts the `source` vector slice into the `dest` vector at offset
154 /// `lastOffset` in the last dimension. `source` can be a scalar when `dest` is
155 /// a 1D vector.
157  Location loc, Value source, Value dest,
158  int64_t lastOffset) {
159  ArrayRef<int64_t> shape = dest.getType().cast<VectorType>().getShape();
160  assert(lastOffset < shape.back() && "Offset out of bounds");
161 
162  // Handle scalar source.
163  if (source.getType().isa<IntegerType>())
164  return rewriter.create<vector::InsertOp>(loc, source, dest, lastOffset);
165 
166  SmallVector<int64_t> offsets(shape.size(), 0);
167  offsets.back() = lastOffset;
168  SmallVector<int64_t> strides(shape.size(), 1);
169  return rewriter.create<vector::InsertStridedSliceOp>(loc, source, dest,
170  offsets, strides);
171 }
172 
173 /// Constructs a new vector of type `resultType` by creating a series of
174 /// insertions of `resultComponents`, each at the next offset of the last vector
175 /// dimension.
176 /// When all `resultComponents` are scalars, the result type is `vector<NxT>`;
177 /// when `resultComponents` are `vector<...x1xT>`s, the result type is
178 /// `vector<...xNxT>`, where `N` is the number of `resultComponenets`.
180  Location loc, VectorType resultType,
181  ValueRange resultComponents) {
182  llvm::ArrayRef<int64_t> resultShape = resultType.getShape();
183  (void)resultShape;
184  assert(!resultShape.empty() && "Result expected to have dimentions");
185  assert(resultShape.back() == static_cast<int64_t>(resultComponents.size()) &&
186  "Wrong number of result components");
187 
188  Value resultVec = createScalarOrSplatConstant(rewriter, loc, resultType, 0);
189  for (auto [i, component] : llvm::enumerate(resultComponents))
190  resultVec = insertLastDimSlice(rewriter, loc, component, resultVec, i);
191 
192  return resultVec;
193 }
194 
195 namespace {
196 //===----------------------------------------------------------------------===//
197 // ConvertConstant
198 //===----------------------------------------------------------------------===//
199 
200 struct ConvertConstant final : OpConversionPattern<arith::ConstantOp> {
202 
204  matchAndRewrite(arith::ConstantOp op, OpAdaptor,
205  ConversionPatternRewriter &rewriter) const override {
206  Type oldType = op.getType();
207  auto newType = getTypeConverter()->convertType(oldType).cast<VectorType>();
208  unsigned newBitWidth = newType.getElementTypeBitWidth();
209  Attribute oldValue = op.getValueAttr();
210 
211  if (auto intAttr = oldValue.dyn_cast<IntegerAttr>()) {
212  auto [low, high] = getHalves(intAttr.getValue(), newBitWidth);
213  auto newAttr = DenseElementsAttr::get(newType, {low, high});
214  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
215  return success();
216  }
217 
218  if (auto splatAttr = oldValue.dyn_cast<SplatElementsAttr>()) {
219  auto [low, high] =
220  getHalves(splatAttr.getSplatValue<APInt>(), newBitWidth);
221  int64_t numSplatElems = splatAttr.getNumElements();
222  SmallVector<APInt> values;
223  values.reserve(numSplatElems * 2);
224  for (int64_t i = 0; i < numSplatElems; ++i) {
225  values.push_back(low);
226  values.push_back(high);
227  }
228 
229  auto attr = DenseElementsAttr::get(newType, values);
230  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, attr);
231  return success();
232  }
233 
234  if (auto elemsAttr = oldValue.dyn_cast<DenseElementsAttr>()) {
235  int64_t numElems = elemsAttr.getNumElements();
236  SmallVector<APInt> values;
237  values.reserve(numElems * 2);
238  for (const APInt &origVal : elemsAttr.getValues<APInt>()) {
239  auto [low, high] = getHalves(origVal, newBitWidth);
240  values.push_back(std::move(low));
241  values.push_back(std::move(high));
242  }
243 
244  auto attr = DenseElementsAttr::get(newType, values);
245  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, attr);
246  return success();
247  }
248 
249  return rewriter.notifyMatchFailure(op.getLoc(),
250  "unhandled constant attribute");
251  }
252 };
253 
254 //===----------------------------------------------------------------------===//
255 // ConvertAddI
256 //===----------------------------------------------------------------------===//
257 
258 struct ConvertAddI final : OpConversionPattern<arith::AddIOp> {
260 
262  matchAndRewrite(arith::AddIOp op, OpAdaptor adaptor,
263  ConversionPatternRewriter &rewriter) const override {
264  Location loc = op->getLoc();
265  auto newTy = getTypeConverter()
266  ->convertType(op.getType())
267  .dyn_cast_or_null<VectorType>();
268  if (!newTy)
269  return rewriter.notifyMatchFailure(
270  loc, llvm::formatv("unsupported type: {0}", op.getType()));
271 
272  Type newElemTy = reduceInnermostDim(newTy);
273 
274  auto [lhsElem0, lhsElem1] =
275  extractLastDimHalves(rewriter, loc, adaptor.getLhs());
276  auto [rhsElem0, rhsElem1] =
277  extractLastDimHalves(rewriter, loc, adaptor.getRhs());
278 
279  auto lowSum =
280  rewriter.create<arith::AddUIExtendedOp>(loc, lhsElem0, rhsElem0);
281  Value overflowVal =
282  rewriter.create<arith::ExtUIOp>(loc, newElemTy, lowSum.getOverflow());
283 
284  Value high0 = rewriter.create<arith::AddIOp>(loc, overflowVal, lhsElem1);
285  Value high = rewriter.create<arith::AddIOp>(loc, high0, rhsElem1);
286 
287  Value resultVec =
288  constructResultVector(rewriter, loc, newTy, {lowSum.getSum(), high});
289  rewriter.replaceOp(op, resultVec);
290  return success();
291  }
292 };
293 
294 //===----------------------------------------------------------------------===//
295 // ConvertBitwiseBinary
296 //===----------------------------------------------------------------------===//
297 
298 /// Conversion pattern template for bitwise binary ops, e.g., `arith.andi`.
299 template <typename BinaryOp>
300 struct ConvertBitwiseBinary final : OpConversionPattern<BinaryOp> {
302  using OpAdaptor = typename OpConversionPattern<BinaryOp>::OpAdaptor;
303 
305  matchAndRewrite(BinaryOp op, OpAdaptor adaptor,
306  ConversionPatternRewriter &rewriter) const override {
307  Location loc = op->getLoc();
308  auto newTy = this->getTypeConverter()
309  ->convertType(op.getType())
310  .template dyn_cast_or_null<VectorType>();
311  if (!newTy)
312  return rewriter.notifyMatchFailure(
313  loc, llvm::formatv("unsupported type: {0}", op.getType()));
314 
315  auto [lhsElem0, lhsElem1] =
316  extractLastDimHalves(rewriter, loc, adaptor.getLhs());
317  auto [rhsElem0, rhsElem1] =
318  extractLastDimHalves(rewriter, loc, adaptor.getRhs());
319 
320  Value resElem0 = rewriter.create<BinaryOp>(loc, lhsElem0, rhsElem0);
321  Value resElem1 = rewriter.create<BinaryOp>(loc, lhsElem1, rhsElem1);
322  Value resultVec =
323  constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
324  rewriter.replaceOp(op, resultVec);
325  return success();
326  }
327 };
328 
329 //===----------------------------------------------------------------------===//
330 // ConvertCmpI
331 //===----------------------------------------------------------------------===//
332 
333 /// Returns the matching unsigned version of the given predicate `pred`, or the
334 /// same predicate if `pred` is not a signed.
335 static arith::CmpIPredicate toUnsignedPredicate(arith::CmpIPredicate pred) {
336  using P = arith::CmpIPredicate;
337  switch (pred) {
338  case P::sge:
339  return P::uge;
340  case P::sgt:
341  return P::ugt;
342  case P::sle:
343  return P::ule;
344  case P::slt:
345  return P::ult;
346  default:
347  return pred;
348  }
349 }
350 
351 struct ConvertCmpI final : OpConversionPattern<arith::CmpIOp> {
353 
355  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
356  ConversionPatternRewriter &rewriter) const override {
357  Location loc = op->getLoc();
358  auto inputTy = getTypeConverter()
359  ->convertType(op.getLhs().getType())
360  .dyn_cast_or_null<VectorType>();
361  if (!inputTy)
362  return rewriter.notifyMatchFailure(
363  loc, llvm::formatv("unsupported type: {0}", op.getType()));
364 
365  arith::CmpIPredicate highPred = adaptor.getPredicate();
366  arith::CmpIPredicate lowPred = toUnsignedPredicate(highPred);
367 
368  auto [lhsElem0, lhsElem1] =
369  extractLastDimHalves(rewriter, loc, adaptor.getLhs());
370  auto [rhsElem0, rhsElem1] =
371  extractLastDimHalves(rewriter, loc, adaptor.getRhs());
372 
373  Value lowCmp =
374  rewriter.create<arith::CmpIOp>(loc, lowPred, lhsElem0, rhsElem0);
375  Value highCmp =
376  rewriter.create<arith::CmpIOp>(loc, highPred, lhsElem1, rhsElem1);
377 
378  Value cmpResult{};
379  switch (highPred) {
380  case arith::CmpIPredicate::eq: {
381  cmpResult = rewriter.create<arith::AndIOp>(loc, lowCmp, highCmp);
382  break;
383  }
384  case arith::CmpIPredicate::ne: {
385  cmpResult = rewriter.create<arith::OrIOp>(loc, lowCmp, highCmp);
386  break;
387  }
388  default: {
389  // Handle inequality checks.
390  Value highEq = rewriter.create<arith::CmpIOp>(
391  loc, arith::CmpIPredicate::eq, lhsElem1, rhsElem1);
392  cmpResult =
393  rewriter.create<arith::SelectOp>(loc, highEq, lowCmp, highCmp);
394  break;
395  }
396  }
397 
398  assert(cmpResult && "Unhandled case");
399  rewriter.replaceOp(op, dropTrailingX1Dim(rewriter, loc, cmpResult));
400  return success();
401  }
402 };
403 
404 //===----------------------------------------------------------------------===//
405 // ConvertMulI
406 //===----------------------------------------------------------------------===//
407 
408 struct ConvertMulI final : OpConversionPattern<arith::MulIOp> {
410 
412  matchAndRewrite(arith::MulIOp op, OpAdaptor adaptor,
413  ConversionPatternRewriter &rewriter) const override {
414  Location loc = op->getLoc();
415  auto newTy = getTypeConverter()
416  ->convertType(op.getType())
417  .dyn_cast_or_null<VectorType>();
418  if (!newTy)
419  return rewriter.notifyMatchFailure(
420  loc, llvm::formatv("unsupported type: {0}", op.getType()));
421 
422  auto [lhsElem0, lhsElem1] =
423  extractLastDimHalves(rewriter, loc, adaptor.getLhs());
424  auto [rhsElem0, rhsElem1] =
425  extractLastDimHalves(rewriter, loc, adaptor.getRhs());
426 
427  // The multiplication algorithm used is the standard (long) multiplication.
428  // Multiplying two i2N integers produces (at most) an i4N result, but
429  // because the calculation of top i2N is not necessary, we omit it.
430  auto mulLowLow =
431  rewriter.create<arith::MulUIExtendedOp>(loc, lhsElem0, rhsElem0);
432  Value mulLowHi = rewriter.create<arith::MulIOp>(loc, lhsElem0, rhsElem1);
433  Value mulHiLow = rewriter.create<arith::MulIOp>(loc, lhsElem1, rhsElem0);
434 
435  Value resLow = mulLowLow.getLow();
436  Value resHi =
437  rewriter.create<arith::AddIOp>(loc, mulLowLow.getHigh(), mulLowHi);
438  resHi = rewriter.create<arith::AddIOp>(loc, resHi, mulHiLow);
439 
440  Value resultVec =
441  constructResultVector(rewriter, loc, newTy, {resLow, resHi});
442  rewriter.replaceOp(op, resultVec);
443  return success();
444  }
445 };
446 
447 //===----------------------------------------------------------------------===//
448 // ConvertExtSI
449 //===----------------------------------------------------------------------===//
450 
451 struct ConvertExtSI final : OpConversionPattern<arith::ExtSIOp> {
453 
455  matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
456  ConversionPatternRewriter &rewriter) const override {
457  Location loc = op->getLoc();
458  auto newTy = getTypeConverter()
459  ->convertType(op.getType())
460  .dyn_cast_or_null<VectorType>();
461  if (!newTy)
462  return rewriter.notifyMatchFailure(
463  loc, llvm::formatv("unsupported type: {0}", op.getType()));
464 
465  Type newResultComponentTy = reduceInnermostDim(newTy);
466 
467  // Sign-extend the input value to determine the low half of the result.
468  // Then, check if the low half is negative, and sign-extend the comparison
469  // result to get the high half.
470  Value newOperand = appendX1Dim(rewriter, loc, adaptor.getIn());
471  Value extended = rewriter.createOrFold<arith::ExtSIOp>(
472  loc, newResultComponentTy, newOperand);
473  Value operandZeroCst =
474  createScalarOrSplatConstant(rewriter, loc, newResultComponentTy, 0);
475  Value signBit = rewriter.create<arith::CmpIOp>(
476  loc, arith::CmpIPredicate::slt, extended, operandZeroCst);
477  Value signValue =
478  rewriter.create<arith::ExtSIOp>(loc, newResultComponentTy, signBit);
479 
480  Value resultVec =
481  constructResultVector(rewriter, loc, newTy, {extended, signValue});
482  rewriter.replaceOp(op, resultVec);
483  return success();
484  }
485 };
486 
487 //===----------------------------------------------------------------------===//
488 // ConvertExtUI
489 //===----------------------------------------------------------------------===//
490 
491 struct ConvertExtUI final : OpConversionPattern<arith::ExtUIOp> {
493 
495  matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
496  ConversionPatternRewriter &rewriter) const override {
497  Location loc = op->getLoc();
498  auto newTy = getTypeConverter()
499  ->convertType(op.getType())
500  .dyn_cast_or_null<VectorType>();
501  if (!newTy)
502  return rewriter.notifyMatchFailure(
503  loc, llvm::formatv("unsupported type: {0}", op.getType()));
504 
505  Type newResultComponentTy = reduceInnermostDim(newTy);
506 
507  // Zero-extend the input value to determine the low half of the result.
508  // The high half is always zero.
509  Value newOperand = appendX1Dim(rewriter, loc, adaptor.getIn());
510  Value extended = rewriter.createOrFold<arith::ExtUIOp>(
511  loc, newResultComponentTy, newOperand);
512  Value zeroCst = createScalarOrSplatConstant(rewriter, loc, newTy, 0);
513  Value newRes = insertLastDimSlice(rewriter, loc, extended, zeroCst, 0);
514  rewriter.replaceOp(op, newRes);
515  return success();
516  }
517 };
518 
519 //===----------------------------------------------------------------------===//
520 // ConvertMaxMin
521 //===----------------------------------------------------------------------===//
522 
523 template <typename SourceOp, arith::CmpIPredicate CmpPred>
524 struct ConvertMaxMin final : OpConversionPattern<SourceOp> {
526 
528  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
529  ConversionPatternRewriter &rewriter) const override {
530  Location loc = op->getLoc();
531 
532  Type oldTy = op.getType();
533  auto newTy = this->getTypeConverter()
534  ->convertType(oldTy)
535  .template dyn_cast_or_null<VectorType>();
536  if (!newTy)
537  return rewriter.notifyMatchFailure(
538  loc, llvm::formatv("unsupported type: {0}", op.getType()));
539 
540  // Rewrite Max*I/Min*I as compare and select over original operands. Let
541  // the CmpI and Select emulation patterns handle the final legalization.
542  Value cmp =
543  rewriter.create<arith::CmpIOp>(loc, CmpPred, op.getLhs(), op.getRhs());
544  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, op.getLhs(),
545  op.getRhs());
546  return success();
547  }
548 };
549 
//===----------------------------------------------------------------------===//
// Convert IndexCast ops
551 //===----------------------------------------------------------------------===//
552 
553 /// Returns true iff the type is `index` or `vector<...index>`.
554 static bool isIndexOrIndexVector(Type type) {
555  if (type.isa<IndexType>())
556  return true;
557 
558  if (auto vectorTy = type.dyn_cast<VectorType>())
559  if (vectorTy.getElementType().isa<IndexType>())
560  return true;
561 
562  return false;
563 }
564 
565 template <typename CastOp>
566 struct ConvertIndexCastIntToIndex final : OpConversionPattern<CastOp> {
568 
570  matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
571  ConversionPatternRewriter &rewriter) const override {
572  Type resultType = op.getType();
573  if (!isIndexOrIndexVector(resultType))
574  return failure();
575 
576  Location loc = op.getLoc();
577  Type inType = op.getIn().getType();
578  auto newInTy = this->getTypeConverter()
579  ->convertType(inType)
580  .template dyn_cast_or_null<VectorType>();
581  if (!newInTy)
582  return rewriter.notifyMatchFailure(
583  loc, llvm::formatv("unsupported type: {0}", inType));
584 
585  // Discard the high half of the input truncating the original value.
586  Value extracted = extractLastDimSlice(rewriter, loc, adaptor.getIn(), 0);
587  extracted = dropTrailingX1Dim(rewriter, loc, extracted);
588  rewriter.replaceOpWithNewOp<CastOp>(op, resultType, extracted);
589  return success();
590  }
591 };
592 
593 template <typename CastOp, typename ExtensionOp>
594 struct ConvertIndexCastIndexToInt final : OpConversionPattern<CastOp> {
596 
598  matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
599  ConversionPatternRewriter &rewriter) const override {
600  Type inType = op.getIn().getType();
601  if (!isIndexOrIndexVector(inType))
602  return failure();
603 
604  Location loc = op.getLoc();
605  auto *typeConverter =
606  this->template getTypeConverter<arith::WideIntEmulationConverter>();
607 
608  Type resultType = op.getType();
609  auto newTy = typeConverter->convertType(resultType)
610  .template dyn_cast_or_null<VectorType>();
611  if (!newTy)
612  return rewriter.notifyMatchFailure(
613  loc, llvm::formatv("unsupported type: {0}", resultType));
614 
615  // Emit an index cast over the matching narrow type.
616  Type narrowTy =
617  rewriter.getIntegerType(typeConverter->getMaxTargetIntBitWidth());
618  if (auto vecTy = resultType.dyn_cast<VectorType>())
619  narrowTy = VectorType::get(vecTy.getShape(), narrowTy);
620 
621  // Sign or zero-extend the result. Let the matching conversion pattern
622  // legalize the extension op.
623  Value underlyingVal =
624  rewriter.create<CastOp>(loc, narrowTy, adaptor.getIn());
625  rewriter.replaceOpWithNewOp<ExtensionOp>(op, resultType, underlyingVal);
626  return success();
627  }
628 };
629 
630 //===----------------------------------------------------------------------===//
631 // ConvertSelect
632 //===----------------------------------------------------------------------===//
633 
634 struct ConvertSelect final : OpConversionPattern<arith::SelectOp> {
636 
638  matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
639  ConversionPatternRewriter &rewriter) const override {
640  Location loc = op->getLoc();
641  auto newTy = getTypeConverter()
642  ->convertType(op.getType())
643  .dyn_cast_or_null<VectorType>();
644  if (!newTy)
645  return rewriter.notifyMatchFailure(
646  loc, llvm::formatv("unsupported type: {0}", op.getType()));
647 
648  auto [trueElem0, trueElem1] =
649  extractLastDimHalves(rewriter, loc, adaptor.getTrueValue());
650  auto [falseElem0, falseElem1] =
651  extractLastDimHalves(rewriter, loc, adaptor.getFalseValue());
652  Value cond = appendX1Dim(rewriter, loc, adaptor.getCondition());
653 
654  Value resElem0 =
655  rewriter.create<arith::SelectOp>(loc, cond, trueElem0, falseElem0);
656  Value resElem1 =
657  rewriter.create<arith::SelectOp>(loc, cond, trueElem1, falseElem1);
658  Value resultVec =
659  constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
660  rewriter.replaceOp(op, resultVec);
661  return success();
662  }
663 };
664 
665 //===----------------------------------------------------------------------===//
666 // ConvertShLI
667 //===----------------------------------------------------------------------===//
668 
669 struct ConvertShLI final : OpConversionPattern<arith::ShLIOp> {
671 
673  matchAndRewrite(arith::ShLIOp op, OpAdaptor adaptor,
674  ConversionPatternRewriter &rewriter) const override {
675  Location loc = op->getLoc();
676 
677  Type oldTy = op.getType();
678  auto newTy =
679  getTypeConverter()->convertType(oldTy).dyn_cast_or_null<VectorType>();
680  if (!newTy)
681  return rewriter.notifyMatchFailure(
682  loc, llvm::formatv("unsupported type: {0}", op.getType()));
683 
684  Type newOperandTy = reduceInnermostDim(newTy);
685  // `oldBitWidth` == `2 * newBitWidth`
686  unsigned newBitWidth = newTy.getElementTypeBitWidth();
687 
688  auto [lhsElem0, lhsElem1] =
689  extractLastDimHalves(rewriter, loc, adaptor.getLhs());
690  Value rhsElem0 = extractLastDimSlice(rewriter, loc, adaptor.getRhs(), 0);
691 
692  // Assume that the shift amount is < 2 * newBitWidth. Calculate the low and
693  // high halves of the results separately:
694  // 1. low := LHS.low shli RHS
695  //
696  // 2. high := a or b or c, where:
697  // a) Bits from LHS.high, shifted by the RHS.
698  // b) Bits from LHS.low, shifted right. These come into play when
699  // RHS < newBitWidth, e.g.:
700  // [0000][llll] shli 3 --> [0lll][l000]
701  // ^
702  // |
703  // [llll] shrui (4 - 3)
704  // c) Bits from LHS.low, shifted left. These matter when
705  // RHS > newBitWidth, e.g.:
706  // [0000][llll] shli 7 --> [l000][0000]
707  // ^
708  // |
709  // [llll] shli (7 - 4)
710  //
711  // Because shifts by values >= newBitWidth are undefined, we ignore the high
712  // half of RHS, and introduce 'bounds checks' to account for
713  // RHS.low > newBitWidth.
714  //
715  // TODO: Explore possible optimizations.
716  Value zeroCst = createScalarOrSplatConstant(rewriter, loc, newOperandTy, 0);
717  Value elemBitWidth =
718  createScalarOrSplatConstant(rewriter, loc, newOperandTy, newBitWidth);
719 
720  Value illegalElemShift = rewriter.create<arith::CmpIOp>(
721  loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth);
722 
723  Value shiftedElem0 =
724  rewriter.create<arith::ShLIOp>(loc, lhsElem0, rhsElem0);
725  Value resElem0 = rewriter.create<arith::SelectOp>(loc, illegalElemShift,
726  zeroCst, shiftedElem0);
727 
728  Value cappedShiftAmount = rewriter.create<arith::SelectOp>(
729  loc, illegalElemShift, elemBitWidth, rhsElem0);
730  Value rightShiftAmount =
731  rewriter.create<arith::SubIOp>(loc, elemBitWidth, cappedShiftAmount);
732  Value shiftedRight =
733  rewriter.create<arith::ShRUIOp>(loc, lhsElem0, rightShiftAmount);
734  Value overshotShiftAmount =
735  rewriter.create<arith::SubIOp>(loc, rhsElem0, elemBitWidth);
736  Value shiftedLeft =
737  rewriter.create<arith::ShLIOp>(loc, lhsElem0, overshotShiftAmount);
738 
739  Value shiftedElem1 =
740  rewriter.create<arith::ShLIOp>(loc, lhsElem1, rhsElem0);
741  Value resElem1High = rewriter.create<arith::SelectOp>(
742  loc, illegalElemShift, zeroCst, shiftedElem1);
743  Value resElem1Low = rewriter.create<arith::SelectOp>(
744  loc, illegalElemShift, shiftedLeft, shiftedRight);
745  Value resElem1 =
746  rewriter.create<arith::OrIOp>(loc, resElem1Low, resElem1High);
747 
748  Value resultVec =
749  constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
750  rewriter.replaceOp(op, resultVec);
751  return success();
752  }
753 };
754 
755 //===----------------------------------------------------------------------===//
756 // ConvertShRUI
757 //===----------------------------------------------------------------------===//
758 
759 struct ConvertShRUI final : OpConversionPattern<arith::ShRUIOp> {
761 
763  matchAndRewrite(arith::ShRUIOp op, OpAdaptor adaptor,
764  ConversionPatternRewriter &rewriter) const override {
765  Location loc = op->getLoc();
766 
767  Type oldTy = op.getType();
768  auto newTy =
769  getTypeConverter()->convertType(oldTy).dyn_cast_or_null<VectorType>();
770  if (!newTy)
771  return rewriter.notifyMatchFailure(
772  loc, llvm::formatv("unsupported type: {0}", op.getType()));
773 
774  Type newOperandTy = reduceInnermostDim(newTy);
775  // `oldBitWidth` == `2 * newBitWidth`
776  unsigned newBitWidth = newTy.getElementTypeBitWidth();
777 
778  auto [lhsElem0, lhsElem1] =
779  extractLastDimHalves(rewriter, loc, adaptor.getLhs());
780  Value rhsElem0 = extractLastDimSlice(rewriter, loc, adaptor.getRhs(), 0);
781 
782  // Assume that the shift amount is < 2 * newBitWidth. Calculate the low and
783  // high halves of the results separately:
784  // 1. low := a or b or c, where:
785  // a) Bits from LHS.low, shifted by the RHS.
786  // b) Bits from LHS.high, shifted left. These matter when
787  // RHS < newBitWidth, e.g.:
788  // [hhhh][0000] shrui 3 --> [000h][hhh0]
789  // ^
790  // |
791  // [hhhh] shli (4 - 1)
792  // c) Bits from LHS.high, shifted right. These come into play when
793  // RHS > newBitWidth, e.g.:
794  // [hhhh][0000] shrui 7 --> [0000][000h]
795  // ^
796  // |
797  // [hhhh] shrui (7 - 4)
798  //
799  // 2. high := LHS.high shrui RHS
800  //
801  // Because shifts by values >= newBitWidth are undefined, we ignore the high
802  // half of RHS, and introduce 'bounds checks' to account for
803  // RHS.low > newBitWidth.
804  //
805  // TODO: Explore possible optimizations.
806  Value zeroCst = createScalarOrSplatConstant(rewriter, loc, newOperandTy, 0);
807  Value elemBitWidth =
808  createScalarOrSplatConstant(rewriter, loc, newOperandTy, newBitWidth);
809 
810  Value illegalElemShift = rewriter.create<arith::CmpIOp>(
811  loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth);
812 
813  Value shiftedElem0 =
814  rewriter.create<arith::ShRUIOp>(loc, lhsElem0, rhsElem0);
815  Value resElem0Low = rewriter.create<arith::SelectOp>(loc, illegalElemShift,
816  zeroCst, shiftedElem0);
817  Value shiftedElem1 =
818  rewriter.create<arith::ShRUIOp>(loc, lhsElem1, rhsElem0);
819  Value resElem1 = rewriter.create<arith::SelectOp>(loc, illegalElemShift,
820  zeroCst, shiftedElem1);
821 
822  Value cappedShiftAmount = rewriter.create<arith::SelectOp>(
823  loc, illegalElemShift, elemBitWidth, rhsElem0);
824  Value leftShiftAmount =
825  rewriter.create<arith::SubIOp>(loc, elemBitWidth, cappedShiftAmount);
826  Value shiftedLeft =
827  rewriter.create<arith::ShLIOp>(loc, lhsElem1, leftShiftAmount);
828  Value overshotShiftAmount =
829  rewriter.create<arith::SubIOp>(loc, rhsElem0, elemBitWidth);
830  Value shiftedRight =
831  rewriter.create<arith::ShRUIOp>(loc, lhsElem1, overshotShiftAmount);
832 
833  Value resElem0High = rewriter.create<arith::SelectOp>(
834  loc, illegalElemShift, shiftedRight, shiftedLeft);
835  Value resElem0 =
836  rewriter.create<arith::OrIOp>(loc, resElem0Low, resElem0High);
837 
838  Value resultVec =
839  constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
840  rewriter.replaceOp(op, resultVec);
841  return success();
842  }
843 };
844 
845 //===----------------------------------------------------------------------===//
846 // ConvertShRSI
847 //===----------------------------------------------------------------------===//
848 
849 struct ConvertShRSI final : OpConversionPattern<arith::ShRSIOp> {
851 
853  matchAndRewrite(arith::ShRSIOp op, OpAdaptor adaptor,
854  ConversionPatternRewriter &rewriter) const override {
855  Location loc = op->getLoc();
856 
857  Type oldTy = op.getType();
858  auto newTy =
859  getTypeConverter()->convertType(oldTy).dyn_cast_or_null<VectorType>();
860  if (!newTy)
861  return rewriter.notifyMatchFailure(
862  loc, llvm::formatv("unsupported type: {0}", op.getType()));
863 
864  Value lhsElem1 = extractLastDimSlice(rewriter, loc, adaptor.getLhs(), 1);
865  Value rhsElem0 = extractLastDimSlice(rewriter, loc, adaptor.getRhs(), 0);
866 
867  Type narrowTy = rhsElem0.getType();
868  int64_t origBitwidth = newTy.getElementTypeBitWidth() * 2;
869 
870  // Rewrite this as an bitwise or of `arith.shrui` and sign extension bits.
871  // Perform as many ops over the narrow integer type as possible and let the
872  // other emulation patterns convert the rest.
873  Value elemZero = createScalarOrSplatConstant(rewriter, loc, narrowTy, 0);
874  Value signBit = rewriter.create<arith::CmpIOp>(
875  loc, arith::CmpIPredicate::slt, lhsElem1, elemZero);
876  signBit = dropTrailingX1Dim(rewriter, loc, signBit);
877 
878  // Create a bit pattern of either all ones or all zeros. Then shift it left
879  // to calculate the sign extension bits created by shifting the original
880  // sign bit right.
881  Value allSign = rewriter.create<arith::ExtSIOp>(loc, oldTy, signBit);
882  Value maxShift =
883  createScalarOrSplatConstant(rewriter, loc, narrowTy, origBitwidth);
884  Value numNonSignExtBits =
885  rewriter.create<arith::SubIOp>(loc, maxShift, rhsElem0);
886  numNonSignExtBits = dropTrailingX1Dim(rewriter, loc, numNonSignExtBits);
887  numNonSignExtBits =
888  rewriter.create<arith::ExtUIOp>(loc, oldTy, numNonSignExtBits);
889  Value signBits =
890  rewriter.create<arith::ShLIOp>(loc, allSign, numNonSignExtBits);
891 
892  // Use original arguments to create the right shift.
893  Value shrui =
894  rewriter.create<arith::ShRUIOp>(loc, op.getLhs(), op.getRhs());
895  Value shrsi = rewriter.create<arith::OrIOp>(loc, shrui, signBits);
896 
897  // Handle shifting by zero. This is necessary when the `signBits` shift is
898  // invalid.
899  Value isNoop = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
900  rhsElem0, elemZero);
901  isNoop = dropTrailingX1Dim(rewriter, loc, isNoop);
902  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNoop, op.getLhs(),
903  shrsi);
904 
905  return success();
906  }
907 };
908 
909 //===----------------------------------------------------------------------===//
910 // ConvertTruncI
911 //===----------------------------------------------------------------------===//
912 
913 struct ConvertTruncI final : OpConversionPattern<arith::TruncIOp> {
915 
917  matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
918  ConversionPatternRewriter &rewriter) const override {
919  Location loc = op.getLoc();
920  // Check if the result type is legal for this target. Currently, we do not
921  // support truncation to types wider than supported by the target.
922  if (!getTypeConverter()->isLegal(op.getType()))
923  return rewriter.notifyMatchFailure(
924  loc, llvm::formatv("unsupported truncation result type: {0}",
925  op.getType()));
926 
927  // Discard the high half of the input. Truncate the low half, if
928  // necessary.
929  Value extracted = extractLastDimSlice(rewriter, loc, adaptor.getIn(), 0);
930  extracted = dropTrailingX1Dim(rewriter, loc, extracted);
931  Value truncated =
932  rewriter.createOrFold<arith::TruncIOp>(loc, op.getType(), extracted);
933  rewriter.replaceOp(op, truncated);
934  return success();
935  }
936 };
937 
938 //===----------------------------------------------------------------------===//
939 // ConvertVectorPrint
940 //===----------------------------------------------------------------------===//
941 
942 struct ConvertVectorPrint final : OpConversionPattern<vector::PrintOp> {
944 
946  matchAndRewrite(vector::PrintOp op, OpAdaptor adaptor,
947  ConversionPatternRewriter &rewriter) const override {
948  rewriter.replaceOpWithNewOp<vector::PrintOp>(op, adaptor.getSource());
949  return success();
950  }
951 };
952 
953 //===----------------------------------------------------------------------===//
954 // Pass Definition
955 //===----------------------------------------------------------------------===//
956 
957 struct EmulateWideIntPass final
958  : arith::impl::ArithEmulateWideIntBase<EmulateWideIntPass> {
959  using ArithEmulateWideIntBase::ArithEmulateWideIntBase;
960 
961  void runOnOperation() override {
962  if (!llvm::isPowerOf2_32(widestIntSupported) || widestIntSupported < 2) {
963  signalPassFailure();
964  return;
965  }
966 
967  Operation *op = getOperation();
968  MLIRContext *ctx = op->getContext();
969 
970  arith::WideIntEmulationConverter typeConverter(widestIntSupported);
971  ConversionTarget target(*ctx);
972  target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](Operation *op) {
973  return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType());
974  });
975  auto opLegalCallback = [&typeConverter](Operation *op) {
976  return typeConverter.isLegal(op);
977  };
978  target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(opLegalCallback);
979  target
980  .addDynamicallyLegalDialect<arith::ArithDialect, vector::VectorDialect>(
981  opLegalCallback);
982 
983  RewritePatternSet patterns(ctx);
984  arith::populateArithWideIntEmulationPatterns(typeConverter, patterns);
985 
986  if (failed(applyPartialConversion(op, target, std::move(patterns))))
987  signalPassFailure();
988  }
989 };
990 } // end anonymous namespace
991 
992 //===----------------------------------------------------------------------===//
993 // Public Interface Definition
994 //===----------------------------------------------------------------------===//
995 
997  unsigned widestIntSupportedByTarget)
998  : maxIntWidth(widestIntSupportedByTarget) {
999  assert(llvm::isPowerOf2_32(widestIntSupportedByTarget) &&
1000  "Only power-of-two integers with are supported");
1001  assert(widestIntSupportedByTarget >= 2 && "Integer type too narrow");
1002 
1003  // Allow unknown types.
1004  addConversion([](Type ty) -> std::optional<Type> { return ty; });
1005 
1006  // Scalar case.
1007  addConversion([this](IntegerType ty) -> std::optional<Type> {
1008  unsigned width = ty.getWidth();
1009  if (width <= maxIntWidth)
1010  return ty;
1011 
1012  // i2N --> vector<2xiN>
1013  if (width == 2 * maxIntWidth)
1014  return VectorType::get(2, IntegerType::get(ty.getContext(), maxIntWidth));
1015 
1016  return std::nullopt;
1017  });
1018 
1019  // Vector case.
1020  addConversion([this](VectorType ty) -> std::optional<Type> {
1021  auto intTy = ty.getElementType().dyn_cast<IntegerType>();
1022  if (!intTy)
1023  return ty;
1024 
1025  unsigned width = intTy.getWidth();
1026  if (width <= maxIntWidth)
1027  return ty;
1028 
1029  // vector<...xi2N> --> vector<...x2xiN>
1030  if (width == 2 * maxIntWidth) {
1031  auto newShape = to_vector(ty.getShape());
1032  newShape.push_back(2);
1033  return VectorType::get(newShape,
1034  IntegerType::get(ty.getContext(), maxIntWidth));
1035  }
1036 
1037  return std::nullopt;
1038  });
1039 
1040  // Function case.
1041  addConversion([this](FunctionType ty) -> std::optional<Type> {
1042  // Convert inputs and results, e.g.:
1043  // (i2N, i2N) -> i2N --> (vector<2xiN>, vector<2xiN>) -> vector<2xiN>
1044  SmallVector<Type> inputs;
1045  if (failed(convertTypes(ty.getInputs(), inputs)))
1046  return std::nullopt;
1047 
1048  SmallVector<Type> results;
1049  if (failed(convertTypes(ty.getResults(), results)))
1050  return std::nullopt;
1051 
1052  return FunctionType::get(ty.getContext(), inputs, results);
1053  });
1054 }
1055 
1057  WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns) {
1058  // Populate `func.*` conversion patterns.
1059  populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
1060  typeConverter);
1061  populateCallOpTypeConversionPattern(patterns, typeConverter);
1062  populateReturnOpTypeConversionPattern(patterns, typeConverter);
1063 
1064  // Populate `arith.*` conversion patterns.
1065  patterns.add<
1066  // Misc ops.
1067  ConvertConstant, ConvertCmpI, ConvertSelect, ConvertVectorPrint,
1068  // Binary ops.
1069  ConvertAddI, ConvertMulI, ConvertShLI, ConvertShRSI, ConvertShRUI,
1070  ConvertMaxMin<arith::MaxUIOp, arith::CmpIPredicate::ugt>,
1071  ConvertMaxMin<arith::MaxSIOp, arith::CmpIPredicate::sgt>,
1072  ConvertMaxMin<arith::MinUIOp, arith::CmpIPredicate::ult>,
1073  ConvertMaxMin<arith::MinSIOp, arith::CmpIPredicate::slt>,
1074  // Bitwise binary ops.
1075  ConvertBitwiseBinary<arith::AndIOp>, ConvertBitwiseBinary<arith::OrIOp>,
1076  ConvertBitwiseBinary<arith::XOrIOp>,
1077  // Extension and truncation ops.
1078  ConvertExtSI, ConvertExtUI, ConvertTruncI,
1079  // Cast ops.
1080  ConvertIndexCastIntToIndex<arith::IndexCastOp>,
1081  ConvertIndexCastIntToIndex<arith::IndexCastUIOp>,
1082  ConvertIndexCastIndexToInt<arith::IndexCastOp, arith::ExtSIOp>,
1083  ConvertIndexCastIndexToInt<arith::IndexCastUIOp, arith::ExtUIOp>>(
1084  typeConverter, patterns.getContext());
1085 }
static Value createScalarOrSplatConstant(ConversionPatternRewriter &rewriter, Location loc, Type type, const APInt &value)
Returns a constant of integer or vector type filled with (repeated) value.
static Value insertLastDimSlice(ConversionPatternRewriter &rewriter, Location loc, Value source, Value dest, int64_t lastOffset)
Inserts the source vector slice into the dest vector at offset lastOffset in the last dimension.
static std::pair< APInt, APInt > getHalves(const APInt &value, unsigned newBitWidth)
Returns N bottom and N top bits from value, where N = newBitWidth.
static Value appendX1Dim(ConversionPatternRewriter &rewriter, Location loc, Value input)
Performs a vector shape cast to append an x1 dimension.
static std::pair< Value, Value > extractLastDimHalves(ConversionPatternRewriter &rewriter, Location loc, Value input)
Extracts two vector slices from the input whose type is vector<...x2T>, with the first element at off...
static Type reduceInnermostDim(VectorType type)
Returns the type with the last (innermost) dimension reduced to x1.
static Value constructResultVector(ConversionPatternRewriter &rewriter, Location loc, VectorType resultType, ValueRange resultComponents)
Constructs a new vector of type resultType by creating a series of insertions of resultComponents,...
static Value dropTrailingX1Dim(ConversionPatternRewriter &rewriter, Location loc, Value input)
static Value extractLastDimSlice(ConversionPatternRewriter &rewriter, Location loc, Value input, int64_t lastOffset)
Extracts the input vector slice with elements at the last dimension offset by lastOffset.
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
Attributes are known-constant values of operations.
Definition: Attributes.h:25
U dyn_cast_or_null() const
Definition: Attributes.h:174
U dyn_cast() const
Definition: Attributes.h:169
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:212
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:72
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
This class describes a specific conversion target.
An attribute that represents a reference to a dense vector or tensor object.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:473
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:422
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
typename SourceOp::Adaptor OpAdaptor
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:75
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:191
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
An attribute that represents a reference to a splat vector or tensor constant, meaning all of the ele...
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results)
Convert the given set of types, filling 'results' as necessary.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
U cast() const
Definition: Types.h:321
U dyn_cast_or_null() const
Definition: Types.h:316
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
U dyn_cast() const
Definition: Types.h:311
bool isa() const
Definition: Types.h:301
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:350
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
Type getType() const
Return the type of this value.
Definition: Value.h:122
Converts integer types that are too wide for the target by splitting them in two halves and thus turn...
WideIntEmulationConverter(unsigned widestIntSupportedByTarget)
void populateArithWideIntEmulationPatterns(WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns)
Adds patterns to emulate wide Arith and Function ops over integer types into supported ones.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:223
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
void populateCallOpTypeConversionPattern(RewritePatternSet &patterns, TypeConverter &converter)
Add a pattern to the given pattern list to convert the operand and result types of a CallOp with the ...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, TypeConverter &converter)
Add a pattern to the given pattern list to rewrite return ops to use operands that have been legalize...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26