MLIR 22.0.0git
EmulateWideInt.cpp
Go to the documentation of this file.
1//===- EmulateWideInt.cpp - Wide integer operation emulation ----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
20#include "llvm/ADT/APFloat.h"
21#include "llvm/ADT/APInt.h"
22#include "llvm/Support/FormatVariadic.h"
23#include "llvm/Support/MathExtras.h"
24#include <cassert>
25
26namespace mlir::arith {
27#define GEN_PASS_DEF_ARITHEMULATEWIDEINT
28#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
29} // namespace mlir::arith
30
31using namespace mlir;
32
33//===----------------------------------------------------------------------===//
34// Common Helper Functions
35//===----------------------------------------------------------------------===//
36
37/// Returns N bottom and N top bits from `value`, where N = `newBitWidth`.
38/// Treats `value` as a 2*N bits-wide integer.
39/// The bottom bits are returned in the first pair element, while the top bits
40/// in the second one.
41static std::pair<APInt, APInt> getHalves(const APInt &value,
42 unsigned newBitWidth) {
43 APInt low = value.extractBits(newBitWidth, 0);
44 APInt high = value.extractBits(newBitWidth, newBitWidth);
45 return {std::move(low), std::move(high)};
46}
47
48/// Returns the type with the last (innermost) dimension reduced to x1.
49/// Scalarizes 1D vector inputs to match how we extract/insert vector values,
50/// e.g.:
51/// - vector<3x2xi16> --> vector<3x1xi16>
52/// - vector<2xi16> --> i16
53static Type reduceInnermostDim(VectorType type) {
54 if (type.getShape().size() == 1)
55 return type.getElementType();
56
57 auto newShape = to_vector(type.getShape());
58 newShape.back() = 1;
59 return VectorType::get(newShape, type.getElementType());
60}
61
62/// Extracts the `input` vector slice with elements at the last dimension offset
63/// by `lastOffset`. Returns a value of vector type with the last dimension
64/// reduced to x1 or fully scalarized, e.g.:
65/// - vector<3x2xi16> --> vector<3x1xi16>
66/// - vector<2xi16> --> i16
67static Value extractLastDimSlice(ConversionPatternRewriter &rewriter,
68 Location loc, Value input,
69 int64_t lastOffset) {
70 ArrayRef<int64_t> shape = cast<VectorType>(input.getType()).getShape();
71 assert(lastOffset < shape.back() && "Offset out of bounds");
72
73 // Scalarize the result in case of 1D vectors.
74 if (shape.size() == 1)
75 return vector::ExtractOp::create(rewriter, loc, input, lastOffset);
76
77 SmallVector<int64_t> offsets(shape.size(), 0);
78 offsets.back() = lastOffset;
79 auto sizes = llvm::to_vector(shape);
80 sizes.back() = 1;
81 SmallVector<int64_t> strides(shape.size(), 1);
82
83 return vector::ExtractStridedSliceOp::create(rewriter, loc, input, offsets,
84 sizes, strides);
85}
86
87/// Extracts two vector slices from the `input` whose type is `vector<...x2T>`,
88/// with the first element at offset 0 and the second element at offset 1.
89static std::pair<Value, Value>
90extractLastDimHalves(ConversionPatternRewriter &rewriter, Location loc,
91 Value input) {
92 return {extractLastDimSlice(rewriter, loc, input, 0),
93 extractLastDimSlice(rewriter, loc, input, 1)};
94}
95
96// Performs a vector shape cast to drop the trailing x1 dimension. If the
97// `input` is a scalar, this is a noop.
98static Value dropTrailingX1Dim(ConversionPatternRewriter &rewriter,
99 Location loc, Value input) {
100 auto vecTy = dyn_cast<VectorType>(input.getType());
101 if (!vecTy)
102 return input;
103
104 // Shape cast to drop the last x1 dimension.
105 ArrayRef<int64_t> shape = vecTy.getShape();
106 assert(shape.size() >= 2 && "Expected vector with at list two dims");
107 assert(shape.back() == 1 && "Expected the last vector dim to be x1");
108
109 auto newVecTy = VectorType::get(shape.drop_back(), vecTy.getElementType());
110 return vector::ShapeCastOp::create(rewriter, loc, newVecTy, input);
111}
112
113/// Performs a vector shape cast to append an x1 dimension. If the
114/// `input` is a scalar, this is a noop.
115static Value appendX1Dim(ConversionPatternRewriter &rewriter, Location loc,
116 Value input) {
117 auto vecTy = dyn_cast<VectorType>(input.getType());
118 if (!vecTy)
119 return input;
120
121 // Add a trailing x1 dim.
122 auto newShape = llvm::to_vector(vecTy.getShape());
123 newShape.push_back(1);
124 auto newTy = VectorType::get(newShape, vecTy.getElementType());
125 return vector::ShapeCastOp::create(rewriter, loc, newTy, input);
126}
127
128/// Inserts the `source` vector slice into the `dest` vector at offset
129/// `lastOffset` in the last dimension. `source` can be a scalar when `dest` is
130/// a 1D vector.
131static Value insertLastDimSlice(ConversionPatternRewriter &rewriter,
132 Location loc, Value source, Value dest,
133 int64_t lastOffset) {
134 ArrayRef<int64_t> shape = cast<VectorType>(dest.getType()).getShape();
135 assert(lastOffset < shape.back() && "Offset out of bounds");
137 // Handle scalar source.
138 if (isa<IntegerType>(source.getType()))
139 return vector::InsertOp::create(rewriter, loc, source, dest, lastOffset);
141 SmallVector<int64_t> offsets(shape.size(), 0);
142 offsets.back() = lastOffset;
143 SmallVector<int64_t> strides(shape.size(), 1);
144 return vector::InsertStridedSliceOp::create(rewriter, loc, source, dest,
145 offsets, strides);
146}
147
148/// Constructs a new vector of type `resultType` by creating a series of
149/// insertions of `resultComponents`, each at the next offset of the last vector
150/// dimension.
151/// When all `resultComponents` are scalars, the result type is `vector<NxT>`;
152/// when `resultComponents` are `vector<...x1xT>`s, the result type is
153/// `vector<...xNxT>`, where `N` is the number of `resultComponents`.
154static Value constructResultVector(ConversionPatternRewriter &rewriter,
155 Location loc, VectorType resultType,
156 ValueRange resultComponents) {
157 llvm::ArrayRef<int64_t> resultShape = resultType.getShape();
158 (void)resultShape;
159 assert(!resultShape.empty() && "Result expected to have dimensions");
160 assert(resultShape.back() == static_cast<int64_t>(resultComponents.size()) &&
161 "Wrong number of result components");
163 Value resultVec = createScalarOrSplatConstant(rewriter, loc, resultType, 0);
164 for (auto [i, component] : llvm::enumerate(resultComponents))
165 resultVec = insertLastDimSlice(rewriter, loc, component, resultVec, i);
166
167 return resultVec;
168}
169
170namespace {
171//===----------------------------------------------------------------------===//
172// ConvertConstant
173//===----------------------------------------------------------------------===//
174
175struct ConvertConstant final : OpConversionPattern<arith::ConstantOp> {
176 using Base::Base;
177
178 LogicalResult
179 matchAndRewrite(arith::ConstantOp op, OpAdaptor,
180 ConversionPatternRewriter &rewriter) const override {
181 Type oldType = op.getType();
182 auto newType = getTypeConverter()->convertType<VectorType>(oldType);
183 if (!newType)
184 return rewriter.notifyMatchFailure(
185 op, llvm::formatv("unsupported type: {0}", op.getType()));
186
187 unsigned newBitWidth = newType.getElementTypeBitWidth();
188 Attribute oldValue = op.getValueAttr();
189
190 if (auto intAttr = dyn_cast<IntegerAttr>(oldValue)) {
191 auto [low, high] = getHalves(intAttr.getValue(), newBitWidth);
192 auto newAttr = DenseElementsAttr::get(newType, {low, high});
193 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
194 return success();
195 }
196
197 if (auto splatAttr = dyn_cast<SplatElementsAttr>(oldValue)) {
198 auto [low, high] =
199 getHalves(splatAttr.getSplatValue<APInt>(), newBitWidth);
200 int64_t numSplatElems = splatAttr.getNumElements();
201 SmallVector<APInt> values;
202 values.reserve(numSplatElems * 2);
203 for (int64_t i = 0; i < numSplatElems; ++i) {
204 values.push_back(low);
205 values.push_back(high);
206 }
207
208 auto attr = DenseElementsAttr::get(newType, values);
209 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, attr);
210 return success();
211 }
212
213 if (auto elemsAttr = dyn_cast<DenseElementsAttr>(oldValue)) {
214 int64_t numElems = elemsAttr.getNumElements();
215 SmallVector<APInt> values;
216 values.reserve(numElems * 2);
217 for (const APInt &origVal : elemsAttr.getValues<APInt>()) {
218 auto [low, high] = getHalves(origVal, newBitWidth);
219 values.push_back(std::move(low));
220 values.push_back(std::move(high));
221 }
222
223 auto attr = DenseElementsAttr::get(newType, values);
224 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, attr);
225 return success();
226 }
227
228 return rewriter.notifyMatchFailure(op.getLoc(),
229 "unhandled constant attribute");
230 }
231};
232
233//===----------------------------------------------------------------------===//
234// ConvertAddI
235//===----------------------------------------------------------------------===//
236
237struct ConvertAddI final : OpConversionPattern<arith::AddIOp> {
238 using Base::Base;
239
240 LogicalResult
241 matchAndRewrite(arith::AddIOp op, OpAdaptor adaptor,
242 ConversionPatternRewriter &rewriter) const override {
243 Location loc = op->getLoc();
244 auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
245 if (!newTy)
246 return rewriter.notifyMatchFailure(
247 loc, llvm::formatv("unsupported type: {0}", op.getType()));
248
249 Type newElemTy = reduceInnermostDim(newTy);
250
251 auto [lhsElem0, lhsElem1] =
252 extractLastDimHalves(rewriter, loc, adaptor.getLhs());
253 auto [rhsElem0, rhsElem1] =
254 extractLastDimHalves(rewriter, loc, adaptor.getRhs());
255
256 auto lowSum =
257 arith::AddUIExtendedOp::create(rewriter, loc, lhsElem0, rhsElem0);
258 Value overflowVal =
259 arith::ExtUIOp::create(rewriter, loc, newElemTy, lowSum.getOverflow());
260
261 Value high0 = arith::AddIOp::create(rewriter, loc, overflowVal, lhsElem1);
262 Value high = arith::AddIOp::create(rewriter, loc, high0, rhsElem1);
263
264 Value resultVec =
265 constructResultVector(rewriter, loc, newTy, {lowSum.getSum(), high});
266 rewriter.replaceOp(op, resultVec);
267 return success();
268 }
269};
270
271//===----------------------------------------------------------------------===//
272// ConvertBitwiseBinary
273//===----------------------------------------------------------------------===//
274
275/// Conversion pattern template for bitwise binary ops, e.g., `arith.andi`.
276template <typename BinaryOp>
277struct ConvertBitwiseBinary final : OpConversionPattern<BinaryOp> {
278 using OpConversionPattern<BinaryOp>::OpConversionPattern;
279 using OpAdaptor = typename OpConversionPattern<BinaryOp>::OpAdaptor;
280
281 LogicalResult
282 matchAndRewrite(BinaryOp op, OpAdaptor adaptor,
283 ConversionPatternRewriter &rewriter) const override {
284 Location loc = op->getLoc();
285 auto newTy = this->getTypeConverter()->template convertType<VectorType>(
286 op.getType());
287 if (!newTy)
288 return rewriter.notifyMatchFailure(
289 loc, llvm::formatv("unsupported type: {0}", op.getType()));
290
291 auto [lhsElem0, lhsElem1] =
292 extractLastDimHalves(rewriter, loc, adaptor.getLhs());
293 auto [rhsElem0, rhsElem1] =
294 extractLastDimHalves(rewriter, loc, adaptor.getRhs());
295
296 Value resElem0 = BinaryOp::create(rewriter, loc, lhsElem0, rhsElem0);
297 Value resElem1 = BinaryOp::create(rewriter, loc, lhsElem1, rhsElem1);
298 Value resultVec =
299 constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
300 rewriter.replaceOp(op, resultVec);
301 return success();
302 }
303};
304
305//===----------------------------------------------------------------------===//
306// ConvertCmpI
307//===----------------------------------------------------------------------===//
308
309/// Returns the matching unsigned version of the given predicate `pred`, or the
310/// same predicate if `pred` is not a signed.
311static arith::CmpIPredicate toUnsignedPredicate(arith::CmpIPredicate pred) {
312 using P = arith::CmpIPredicate;
313 switch (pred) {
314 case P::sge:
315 return P::uge;
316 case P::sgt:
317 return P::ugt;
318 case P::sle:
319 return P::ule;
320 case P::slt:
321 return P::ult;
322 default:
323 return pred;
324 }
325}
326
327struct ConvertCmpI final : OpConversionPattern<arith::CmpIOp> {
328 using Base::Base;
329
330 LogicalResult
331 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
332 ConversionPatternRewriter &rewriter) const override {
333 Location loc = op->getLoc();
334 auto inputTy =
335 getTypeConverter()->convertType<VectorType>(op.getLhs().getType());
336 if (!inputTy)
337 return rewriter.notifyMatchFailure(
338 loc, llvm::formatv("unsupported type: {0}", op.getType()));
339
340 arith::CmpIPredicate highPred = adaptor.getPredicate();
341 arith::CmpIPredicate lowPred = toUnsignedPredicate(highPred);
342
343 auto [lhsElem0, lhsElem1] =
344 extractLastDimHalves(rewriter, loc, adaptor.getLhs());
345 auto [rhsElem0, rhsElem1] =
346 extractLastDimHalves(rewriter, loc, adaptor.getRhs());
347
348 Value lowCmp =
349 arith::CmpIOp::create(rewriter, loc, lowPred, lhsElem0, rhsElem0);
350 Value highCmp =
351 arith::CmpIOp::create(rewriter, loc, highPred, lhsElem1, rhsElem1);
352
353 Value cmpResult{};
354 switch (highPred) {
355 case arith::CmpIPredicate::eq: {
356 cmpResult = arith::AndIOp::create(rewriter, loc, lowCmp, highCmp);
357 break;
358 }
359 case arith::CmpIPredicate::ne: {
360 cmpResult = arith::OrIOp::create(rewriter, loc, lowCmp, highCmp);
361 break;
362 }
363 default: {
364 // Handle inequality checks.
365 Value highEq = arith::CmpIOp::create(
366 rewriter, loc, arith::CmpIPredicate::eq, lhsElem1, rhsElem1);
367 cmpResult =
368 arith::SelectOp::create(rewriter, loc, highEq, lowCmp, highCmp);
369 break;
370 }
371 }
372
373 assert(cmpResult && "Unhandled case");
374 rewriter.replaceOp(op, dropTrailingX1Dim(rewriter, loc, cmpResult));
375 return success();
376 }
377};
378
379//===----------------------------------------------------------------------===//
380// ConvertMulI
381//===----------------------------------------------------------------------===//
382
383struct ConvertMulI final : OpConversionPattern<arith::MulIOp> {
384 using Base::Base;
385
386 LogicalResult
387 matchAndRewrite(arith::MulIOp op, OpAdaptor adaptor,
388 ConversionPatternRewriter &rewriter) const override {
389 Location loc = op->getLoc();
390 auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
391 if (!newTy)
392 return rewriter.notifyMatchFailure(
393 loc, llvm::formatv("unsupported type: {0}", op.getType()));
394
395 auto [lhsElem0, lhsElem1] =
396 extractLastDimHalves(rewriter, loc, adaptor.getLhs());
397 auto [rhsElem0, rhsElem1] =
398 extractLastDimHalves(rewriter, loc, adaptor.getRhs());
399
400 // The multiplication algorithm used is the standard (long) multiplication.
401 // Multiplying two i2N integers produces (at most) an i4N result, but
402 // because the calculation of top i2N is not necessary, we omit it.
403 auto mulLowLow =
404 arith::MulUIExtendedOp::create(rewriter, loc, lhsElem0, rhsElem0);
405 Value mulLowHi = arith::MulIOp::create(rewriter, loc, lhsElem0, rhsElem1);
406 Value mulHiLow = arith::MulIOp::create(rewriter, loc, lhsElem1, rhsElem0);
407
408 Value resLow = mulLowLow.getLow();
409 Value resHi =
410 arith::AddIOp::create(rewriter, loc, mulLowLow.getHigh(), mulLowHi);
411 resHi = arith::AddIOp::create(rewriter, loc, resHi, mulHiLow);
412
413 Value resultVec =
414 constructResultVector(rewriter, loc, newTy, {resLow, resHi});
415 rewriter.replaceOp(op, resultVec);
416 return success();
417 }
418};
419
420//===----------------------------------------------------------------------===//
421// ConvertExtSI
422//===----------------------------------------------------------------------===//
423
424struct ConvertExtSI final : OpConversionPattern<arith::ExtSIOp> {
425 using Base::Base;
426
427 LogicalResult
428 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
429 ConversionPatternRewriter &rewriter) const override {
430 Location loc = op->getLoc();
431 auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
432 if (!newTy)
433 return rewriter.notifyMatchFailure(
434 loc, llvm::formatv("unsupported type: {0}", op.getType()));
435
436 Type newResultComponentTy = reduceInnermostDim(newTy);
437
438 // Sign-extend the input value to determine the low half of the result.
439 // Then, check if the low half is negative, and sign-extend the comparison
440 // result to get the high half.
441 Value newOperand = appendX1Dim(rewriter, loc, adaptor.getIn());
442 Value extended = rewriter.createOrFold<arith::ExtSIOp>(
443 loc, newResultComponentTy, newOperand);
444 Value operandZeroCst =
445 createScalarOrSplatConstant(rewriter, loc, newResultComponentTy, 0);
446 Value signBit = arith::CmpIOp::create(
447 rewriter, loc, arith::CmpIPredicate::slt, extended, operandZeroCst);
448 Value signValue =
449 arith::ExtSIOp::create(rewriter, loc, newResultComponentTy, signBit);
450
451 Value resultVec =
452 constructResultVector(rewriter, loc, newTy, {extended, signValue});
453 rewriter.replaceOp(op, resultVec);
454 return success();
455 }
456};
457
458//===----------------------------------------------------------------------===//
459// ConvertExtUI
460//===----------------------------------------------------------------------===//
461
462struct ConvertExtUI final : OpConversionPattern<arith::ExtUIOp> {
463 using Base::Base;
464
465 LogicalResult
466 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
467 ConversionPatternRewriter &rewriter) const override {
468 Location loc = op->getLoc();
469 auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
470 if (!newTy)
471 return rewriter.notifyMatchFailure(
472 loc, llvm::formatv("unsupported type: {0}", op.getType()));
473
474 Type newResultComponentTy = reduceInnermostDim(newTy);
475
476 // Zero-extend the input value to determine the low half of the result.
477 // The high half is always zero.
478 Value newOperand = appendX1Dim(rewriter, loc, adaptor.getIn());
479 Value extended = rewriter.createOrFold<arith::ExtUIOp>(
480 loc, newResultComponentTy, newOperand);
481 Value zeroCst = createScalarOrSplatConstant(rewriter, loc, newTy, 0);
482 Value newRes = insertLastDimSlice(rewriter, loc, extended, zeroCst, 0);
483 rewriter.replaceOp(op, newRes);
484 return success();
485 }
486};
487
488//===----------------------------------------------------------------------===//
489// ConvertMaxMin
490//===----------------------------------------------------------------------===//
491
492template <typename SourceOp, arith::CmpIPredicate CmpPred>
493struct ConvertMaxMin final : OpConversionPattern<SourceOp> {
494 using OpConversionPattern<SourceOp>::OpConversionPattern;
495
496 LogicalResult
497 matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
498 ConversionPatternRewriter &rewriter) const override {
499 Location loc = op->getLoc();
500
501 Type oldTy = op.getType();
502 auto newTy = dyn_cast_or_null<VectorType>(
503 this->getTypeConverter()->convertType(oldTy));
504 if (!newTy)
505 return rewriter.notifyMatchFailure(
506 loc, llvm::formatv("unsupported type: {0}", op.getType()));
507
508 // Rewrite Max*I/Min*I as compare and select over original operands. Let
509 // the CmpI and Select emulation patterns handle the final legalization.
510 Value cmp =
511 arith::CmpIOp::create(rewriter, loc, CmpPred, op.getLhs(), op.getRhs());
512 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, op.getLhs(),
513 op.getRhs());
514 return success();
515 }
516};
517
//===----------------------------------------------------------------------===//
// Convert IndexCast ops
//===----------------------------------------------------------------------===//
520
521/// Returns true iff the type is `index` or `vector<...index>`.
522static bool isIndexOrIndexVector(Type type) {
523 if (isa<IndexType>(type))
524 return true;
525
526 if (auto vectorTy = dyn_cast<VectorType>(type))
527 if (isa<IndexType>(vectorTy.getElementType()))
528 return true;
529
530 return false;
531}
532
533template <typename CastOp>
534struct ConvertIndexCastIntToIndex final : OpConversionPattern<CastOp> {
535 using OpConversionPattern<CastOp>::OpConversionPattern;
536
537 LogicalResult
538 matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
539 ConversionPatternRewriter &rewriter) const override {
540 Type resultType = op.getType();
541 if (!isIndexOrIndexVector(resultType))
542 return failure();
543
544 Location loc = op.getLoc();
545 Type inType = op.getIn().getType();
546 auto newInTy =
547 this->getTypeConverter()->template convertType<VectorType>(inType);
548 if (!newInTy)
549 return rewriter.notifyMatchFailure(
550 loc, llvm::formatv("unsupported type: {0}", inType));
551
552 // Discard the high half of the input truncating the original value.
553 Value extracted = extractLastDimSlice(rewriter, loc, adaptor.getIn(), 0);
554 extracted = dropTrailingX1Dim(rewriter, loc, extracted);
555 rewriter.replaceOpWithNewOp<CastOp>(op, resultType, extracted);
556 return success();
557 }
558};
559
560template <typename CastOp, typename ExtensionOp>
561struct ConvertIndexCastIndexToInt final : OpConversionPattern<CastOp> {
562 using OpConversionPattern<CastOp>::OpConversionPattern;
563
564 LogicalResult
565 matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
566 ConversionPatternRewriter &rewriter) const override {
567 Type inType = op.getIn().getType();
568 if (!isIndexOrIndexVector(inType))
569 return failure();
570
571 Location loc = op.getLoc();
572 auto *typeConverter =
573 this->template getTypeConverter<arith::WideIntEmulationConverter>();
574
575 Type resultType = op.getType();
576 auto newTy = typeConverter->template convertType<VectorType>(resultType);
577 if (!newTy)
578 return rewriter.notifyMatchFailure(
579 loc, llvm::formatv("unsupported type: {0}", resultType));
580
581 // Emit an index cast over the matching narrow type.
582 Type narrowTy =
583 rewriter.getIntegerType(typeConverter->getMaxTargetIntBitWidth());
584 if (auto vecTy = dyn_cast<VectorType>(resultType))
585 narrowTy = VectorType::get(vecTy.getShape(), narrowTy);
586
587 // Sign or zero-extend the result. Let the matching conversion pattern
588 // legalize the extension op.
589 Value underlyingVal =
590 CastOp::create(rewriter, loc, narrowTy, adaptor.getIn());
591 rewriter.replaceOpWithNewOp<ExtensionOp>(op, resultType, underlyingVal);
592 return success();
593 }
594};
595
596//===----------------------------------------------------------------------===//
597// ConvertSelect
598//===----------------------------------------------------------------------===//
599
600struct ConvertSelect final : OpConversionPattern<arith::SelectOp> {
601 using Base::Base;
602
603 LogicalResult
604 matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
605 ConversionPatternRewriter &rewriter) const override {
606 Location loc = op->getLoc();
607 auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
608 if (!newTy)
609 return rewriter.notifyMatchFailure(
610 loc, llvm::formatv("unsupported type: {0}", op.getType()));
611
612 auto [trueElem0, trueElem1] =
613 extractLastDimHalves(rewriter, loc, adaptor.getTrueValue());
614 auto [falseElem0, falseElem1] =
615 extractLastDimHalves(rewriter, loc, adaptor.getFalseValue());
616 Value cond = appendX1Dim(rewriter, loc, adaptor.getCondition());
617
618 Value resElem0 =
619 arith::SelectOp::create(rewriter, loc, cond, trueElem0, falseElem0);
620 Value resElem1 =
621 arith::SelectOp::create(rewriter, loc, cond, trueElem1, falseElem1);
622 Value resultVec =
623 constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
624 rewriter.replaceOp(op, resultVec);
625 return success();
626 }
627};
628
629//===----------------------------------------------------------------------===//
630// ConvertShLI
631//===----------------------------------------------------------------------===//
632
633struct ConvertShLI final : OpConversionPattern<arith::ShLIOp> {
634 using Base::Base;
635
636 LogicalResult
637 matchAndRewrite(arith::ShLIOp op, OpAdaptor adaptor,
638 ConversionPatternRewriter &rewriter) const override {
639 Location loc = op->getLoc();
640
641 Type oldTy = op.getType();
642 auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
643 if (!newTy)
644 return rewriter.notifyMatchFailure(
645 loc, llvm::formatv("unsupported type: {0}", op.getType()));
646
647 Type newOperandTy = reduceInnermostDim(newTy);
648 // `oldBitWidth` == `2 * newBitWidth`
649 unsigned newBitWidth = newTy.getElementTypeBitWidth();
650
651 auto [lhsElem0, lhsElem1] =
652 extractLastDimHalves(rewriter, loc, adaptor.getLhs());
653 Value rhsElem0 = extractLastDimSlice(rewriter, loc, adaptor.getRhs(), 0);
654
655 // Assume that the shift amount is < 2 * newBitWidth. Calculate the low and
656 // high halves of the results separately:
657 // 1. low := LHS.low shli RHS
658 //
659 // 2. high := a or b or c, where:
660 // a) Bits from LHS.high, shifted by the RHS.
661 // b) Bits from LHS.low, shifted right. These come into play when
662 // RHS < newBitWidth, e.g.:
663 // [0000][llll] shli 3 --> [0lll][l000]
664 // ^
665 // |
666 // [llll] shrui (4 - 3)
667 // c) Bits from LHS.low, shifted left. These matter when
668 // RHS > newBitWidth, e.g.:
669 // [0000][llll] shli 7 --> [l000][0000]
670 // ^
671 // |
672 // [llll] shli (7 - 4)
673 //
674 // Because shifts by values >= newBitWidth are undefined, we ignore the high
675 // half of RHS, and introduce 'bounds checks' to account for
676 // RHS.low > newBitWidth.
677 //
678 // TODO: Explore possible optimizations.
679 Value zeroCst = createScalarOrSplatConstant(rewriter, loc, newOperandTy, 0);
680 Value elemBitWidth =
681 createScalarOrSplatConstant(rewriter, loc, newOperandTy, newBitWidth);
682
683 Value illegalElemShift = arith::CmpIOp::create(
684 rewriter, loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth);
685
686 Value shiftedElem0 =
687 arith::ShLIOp::create(rewriter, loc, lhsElem0, rhsElem0);
688 Value resElem0 = arith::SelectOp::create(rewriter, loc, illegalElemShift,
689 zeroCst, shiftedElem0);
690
691 Value cappedShiftAmount = arith::SelectOp::create(
692 rewriter, loc, illegalElemShift, elemBitWidth, rhsElem0);
693 Value rightShiftAmount =
694 arith::SubIOp::create(rewriter, loc, elemBitWidth, cappedShiftAmount);
695 Value shiftedRight =
696 arith::ShRUIOp::create(rewriter, loc, lhsElem0, rightShiftAmount);
697 Value overshotShiftAmount =
698 arith::SubIOp::create(rewriter, loc, rhsElem0, elemBitWidth);
699 Value shiftedLeft =
700 arith::ShLIOp::create(rewriter, loc, lhsElem0, overshotShiftAmount);
701
702 Value shiftedElem1 =
703 arith::ShLIOp::create(rewriter, loc, lhsElem1, rhsElem0);
704 Value resElem1High = arith::SelectOp::create(
705 rewriter, loc, illegalElemShift, zeroCst, shiftedElem1);
706 Value resElem1Low = arith::SelectOp::create(rewriter, loc, illegalElemShift,
707 shiftedLeft, shiftedRight);
708 Value resElem1 =
709 arith::OrIOp::create(rewriter, loc, resElem1Low, resElem1High);
710
711 Value resultVec =
712 constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
713 rewriter.replaceOp(op, resultVec);
714 return success();
715 }
716};
717
718//===----------------------------------------------------------------------===//
719// ConvertShRUI
720//===----------------------------------------------------------------------===//
721
722struct ConvertShRUI final : OpConversionPattern<arith::ShRUIOp> {
723 using Base::Base;
724
725 LogicalResult
726 matchAndRewrite(arith::ShRUIOp op, OpAdaptor adaptor,
727 ConversionPatternRewriter &rewriter) const override {
728 Location loc = op->getLoc();
729
730 Type oldTy = op.getType();
731 auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
732 if (!newTy)
733 return rewriter.notifyMatchFailure(
734 loc, llvm::formatv("unsupported type: {0}", op.getType()));
735
736 Type newOperandTy = reduceInnermostDim(newTy);
737 // `oldBitWidth` == `2 * newBitWidth`
738 unsigned newBitWidth = newTy.getElementTypeBitWidth();
739
740 auto [lhsElem0, lhsElem1] =
741 extractLastDimHalves(rewriter, loc, adaptor.getLhs());
742 Value rhsElem0 = extractLastDimSlice(rewriter, loc, adaptor.getRhs(), 0);
743
744 // Assume that the shift amount is < 2 * newBitWidth. Calculate the low and
745 // high halves of the results separately:
746 // 1. low := a or b or c, where:
747 // a) Bits from LHS.low, shifted by the RHS.
748 // b) Bits from LHS.high, shifted left. These matter when
749 // RHS < newBitWidth, e.g.:
750 // [hhhh][0000] shrui 3 --> [000h][hhh0]
751 // ^
752 // |
753 // [hhhh] shli (4 - 1)
754 // c) Bits from LHS.high, shifted right. These come into play when
755 // RHS > newBitWidth, e.g.:
756 // [hhhh][0000] shrui 7 --> [0000][000h]
757 // ^
758 // |
759 // [hhhh] shrui (7 - 4)
760 //
761 // 2. high := LHS.high shrui RHS
762 //
763 // Because shifts by values >= newBitWidth are undefined, we ignore the high
764 // half of RHS, and introduce 'bounds checks' to account for
765 // RHS.low > newBitWidth.
766 //
767 // TODO: Explore possible optimizations.
768 Value zeroCst = createScalarOrSplatConstant(rewriter, loc, newOperandTy, 0);
769 Value elemBitWidth =
770 createScalarOrSplatConstant(rewriter, loc, newOperandTy, newBitWidth);
771
772 Value illegalElemShift = arith::CmpIOp::create(
773 rewriter, loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth);
774
775 Value shiftedElem0 =
776 arith::ShRUIOp::create(rewriter, loc, lhsElem0, rhsElem0);
777 Value resElem0Low = arith::SelectOp::create(rewriter, loc, illegalElemShift,
778 zeroCst, shiftedElem0);
779 Value shiftedElem1 =
780 arith::ShRUIOp::create(rewriter, loc, lhsElem1, rhsElem0);
781 Value resElem1 = arith::SelectOp::create(rewriter, loc, illegalElemShift,
782 zeroCst, shiftedElem1);
783
784 Value cappedShiftAmount = arith::SelectOp::create(
785 rewriter, loc, illegalElemShift, elemBitWidth, rhsElem0);
786 Value leftShiftAmount =
787 arith::SubIOp::create(rewriter, loc, elemBitWidth, cappedShiftAmount);
788 Value shiftedLeft =
789 arith::ShLIOp::create(rewriter, loc, lhsElem1, leftShiftAmount);
790 Value overshotShiftAmount =
791 arith::SubIOp::create(rewriter, loc, rhsElem0, elemBitWidth);
792 Value shiftedRight =
793 arith::ShRUIOp::create(rewriter, loc, lhsElem1, overshotShiftAmount);
794
795 Value resElem0High = arith::SelectOp::create(
796 rewriter, loc, illegalElemShift, shiftedRight, shiftedLeft);
797 Value resElem0 =
798 arith::OrIOp::create(rewriter, loc, resElem0Low, resElem0High);
799
800 Value resultVec =
801 constructResultVector(rewriter, loc, newTy, {resElem0, resElem1});
802 rewriter.replaceOp(op, resultVec);
803 return success();
804 }
805};
806
807//===----------------------------------------------------------------------===//
808// ConvertShRSI
809//===----------------------------------------------------------------------===//
810
811struct ConvertShRSI final : OpConversionPattern<arith::ShRSIOp> {
812 using Base::Base;
813
814 LogicalResult
815 matchAndRewrite(arith::ShRSIOp op, OpAdaptor adaptor,
816 ConversionPatternRewriter &rewriter) const override {
817 Location loc = op->getLoc();
818
819 Type oldTy = op.getType();
820 auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
821 if (!newTy)
822 return rewriter.notifyMatchFailure(
823 loc, llvm::formatv("unsupported type: {0}", op.getType()));
824
825 Value lhsElem1 = extractLastDimSlice(rewriter, loc, adaptor.getLhs(), 1);
826 Value rhsElem0 = extractLastDimSlice(rewriter, loc, adaptor.getRhs(), 0);
827
828 Type narrowTy = rhsElem0.getType();
829 int64_t origBitwidth = newTy.getElementTypeBitWidth() * 2;
830
831 // Rewrite this as an bitwise or of `arith.shrui` and sign extension bits.
832 // Perform as many ops over the narrow integer type as possible and let the
833 // other emulation patterns convert the rest.
834 Value elemZero = createScalarOrSplatConstant(rewriter, loc, narrowTy, 0);
835 Value signBit = arith::CmpIOp::create(
836 rewriter, loc, arith::CmpIPredicate::slt, lhsElem1, elemZero);
837 signBit = dropTrailingX1Dim(rewriter, loc, signBit);
838
839 // Create a bit pattern of either all ones or all zeros. Then shift it left
840 // to calculate the sign extension bits created by shifting the original
841 // sign bit right.
842 Value allSign = arith::ExtSIOp::create(rewriter, loc, oldTy, signBit);
843 Value maxShift =
844 createScalarOrSplatConstant(rewriter, loc, narrowTy, origBitwidth);
845 Value numNonSignExtBits =
846 arith::SubIOp::create(rewriter, loc, maxShift, rhsElem0);
847 numNonSignExtBits = dropTrailingX1Dim(rewriter, loc, numNonSignExtBits);
848 numNonSignExtBits =
849 arith::ExtUIOp::create(rewriter, loc, oldTy, numNonSignExtBits);
850 Value signBits =
851 arith::ShLIOp::create(rewriter, loc, allSign, numNonSignExtBits);
852
853 // Use original arguments to create the right shift.
854 Value shrui =
855 arith::ShRUIOp::create(rewriter, loc, op.getLhs(), op.getRhs());
856 Value shrsi = arith::OrIOp::create(rewriter, loc, shrui, signBits);
857
858 // Handle shifting by zero. This is necessary when the `signBits` shift is
859 // invalid.
860 Value isNoop = arith::CmpIOp::create(
861 rewriter, loc, arith::CmpIPredicate::eq, rhsElem0, elemZero);
862 isNoop = dropTrailingX1Dim(rewriter, loc, isNoop);
863 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNoop, op.getLhs(),
864 shrsi);
865
866 return success();
867 }
868};
869
870//===----------------------------------------------------------------------===//
871// ConvertSubI
872//===----------------------------------------------------------------------===//
873
874struct ConvertSubI final : OpConversionPattern<arith::SubIOp> {
875 using Base::Base;
876
877 LogicalResult
878 matchAndRewrite(arith::SubIOp op, OpAdaptor adaptor,
879 ConversionPatternRewriter &rewriter) const override {
880 Location loc = op->getLoc();
881 auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
882 if (!newTy)
883 return rewriter.notifyMatchFailure(
884 loc, llvm::formatv("unsupported type: {}", op.getType()));
885
886 Type newElemTy = reduceInnermostDim(newTy);
887
888 auto [lhsElem0, lhsElem1] =
889 extractLastDimHalves(rewriter, loc, adaptor.getLhs());
890 auto [rhsElem0, rhsElem1] =
891 extractLastDimHalves(rewriter, loc, adaptor.getRhs());
892
893 // Emulates LHS - RHS by [LHS0 - RHS0, LHS1 - RHS1 - CARRY] where
894 // CARRY is 1 or 0.
895 Value low = arith::SubIOp::create(rewriter, loc, lhsElem0, rhsElem0);
896 // We have a carry if lhsElem0 < rhsElem0.
897 Value carry0 = arith::CmpIOp::create(
898 rewriter, loc, arith::CmpIPredicate::ult, lhsElem0, rhsElem0);
899 Value carryVal = arith::ExtUIOp::create(rewriter, loc, newElemTy, carry0);
900
901 Value high0 = arith::SubIOp::create(rewriter, loc, lhsElem1, carryVal);
902 Value high = arith::SubIOp::create(rewriter, loc, high0, rhsElem1);
903
904 Value resultVec = constructResultVector(rewriter, loc, newTy, {low, high});
905 rewriter.replaceOp(op, resultVec);
906 return success();
907 }
908};
909
910//===----------------------------------------------------------------------===//
911// ConvertSIToFP
912//===----------------------------------------------------------------------===//
913
914struct ConvertSIToFP final : OpConversionPattern<arith::SIToFPOp> {
915 using Base::Base;
916
917 LogicalResult
918 matchAndRewrite(arith::SIToFPOp op, OpAdaptor adaptor,
919 ConversionPatternRewriter &rewriter) const override {
920 Location loc = op.getLoc();
921
922 Value in = op.getIn();
923 Type oldTy = in.getType();
924 auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
925 if (!newTy)
926 return rewriter.notifyMatchFailure(
927 loc, llvm::formatv("unsupported type: {0}", oldTy));
928
929 Value zeroCst = createScalarOrSplatConstant(rewriter, loc, oldTy, 0);
930
931 // To avoid operating on very large unsigned numbers, perform the
932 // conversion on the absolute value. Then, decide whether to negate the
933 // result or not based on that sign bit. We implement negation by
934 // subtracting from zero. Note that this relies on the the other conversion
935 // patterns to legalize created ops and narrow the bit widths.
936 Value isNeg = arith::CmpIOp::create(rewriter, loc,
937 arith::CmpIPredicate::slt, in, zeroCst);
938 Value neg = arith::SubIOp::create(rewriter, loc, zeroCst, in);
939 Value abs = arith::SelectOp::create(rewriter, loc, isNeg, neg, in);
940
941 Value absResult = arith::UIToFPOp::create(rewriter, loc, op.getType(), abs);
942 Value negResult = arith::NegFOp::create(rewriter, loc, absResult);
943 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNeg, negResult,
944 absResult);
945 return success();
946 }
947};
948
949//===----------------------------------------------------------------------===//
950// ConvertUIToFP
951//===----------------------------------------------------------------------===//
952
953struct ConvertUIToFP final : OpConversionPattern<arith::UIToFPOp> {
954 using Base::Base;
955
956 LogicalResult
957 matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
958 ConversionPatternRewriter &rewriter) const override {
959 Location loc = op.getLoc();
960
961 Type oldTy = op.getIn().getType();
962 auto newTy = getTypeConverter()->convertType<VectorType>(oldTy);
963 if (!newTy)
964 return rewriter.notifyMatchFailure(
965 loc, llvm::formatv("unsupported type: {0}", oldTy));
966 unsigned newBitWidth = newTy.getElementTypeBitWidth();
967
968 auto [low, hi] = extractLastDimHalves(rewriter, loc, adaptor.getIn());
969 Value lowInt = dropTrailingX1Dim(rewriter, loc, low);
970 Value hiInt = dropTrailingX1Dim(rewriter, loc, hi);
971 Value zeroCst =
972 createScalarOrSplatConstant(rewriter, loc, hiInt.getType(), 0);
973
974 // The final result has the following form:
975 // if (hi == 0) return uitofp(low)
976 // else return uitofp(low) + uitofp(hi) * 2^BW
977 //
978 // where `BW` is the bitwidth of the narrowed integer type. We emit a
979 // select to make it easier to fold-away the `hi` part calculation when it
980 // is known to be zero.
981 //
982 // Note 1: The emulation is precise only for input values that have exact
983 // integer representation in the result floating point type, and may lead
984 // loss of precision otherwise.
985 //
986 // Note 2: We do not strictly need the `hi == 0`, case, but it makes
987 // constant folding easier.
988 Value hiEqZero = arith::CmpIOp::create(
989 rewriter, loc, arith::CmpIPredicate::eq, hiInt, zeroCst);
990
991 Type resultTy = op.getType();
992 Type resultElemTy = getElementTypeOrSelf(resultTy);
993 Value lowFp = arith::UIToFPOp::create(rewriter, loc, resultTy, lowInt);
994 Value hiFp = arith::UIToFPOp::create(rewriter, loc, resultTy, hiInt);
995
996 int64_t pow2Int = int64_t(1) << newBitWidth;
997 TypedAttr pow2Attr =
998 rewriter.getFloatAttr(resultElemTy, static_cast<double>(pow2Int));
999 if (auto vecTy = dyn_cast<VectorType>(resultTy))
1000 pow2Attr = SplatElementsAttr::get(vecTy, pow2Attr);
1001
1002 Value pow2Val =
1003 arith::ConstantOp::create(rewriter, loc, resultTy, pow2Attr);
1004
1005 Value hiVal = arith::MulFOp::create(rewriter, loc, hiFp, pow2Val);
1006 Value result = arith::AddFOp::create(rewriter, loc, lowFp, hiVal);
1007
1008 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, hiEqZero, lowFp, result);
1009 return success();
1010 }
1011};
1012
1013//===----------------------------------------------------------------------===//
1014// ConvertFPToSI
1015//===----------------------------------------------------------------------===//
1016
1017struct ConvertFPToSI final : OpConversionPattern<arith::FPToSIOp> {
1018 using Base::Base;
1019
1020 LogicalResult
1021 matchAndRewrite(arith::FPToSIOp op, OpAdaptor adaptor,
1022 ConversionPatternRewriter &rewriter) const override {
1023 Location loc = op.getLoc();
1024 // Get the input float type.
1025 Value inFp = adaptor.getIn();
1026 Type fpTy = inFp.getType();
1027
1028 Type intTy = op.getType();
1029
1030 auto newTy = getTypeConverter()->convertType<VectorType>(intTy);
1031 if (!newTy)
1032 return rewriter.notifyMatchFailure(
1033 loc, llvm::formatv("unsupported type: {}", intTy));
1034
1035 // Work on the absolute value and then convert the result to signed integer.
1036 // Defer absolute value to fptoui. If minSInt < fp < maxSInt, i.e. if the fp
1037 // is representable in signed i2N, emits the correct result. Else, the
1038 // result is UB.
1039
1040 TypedAttr zeroAttr = rewriter.getZeroAttr(fpTy);
1041 Value zeroCst = arith::ConstantOp::create(rewriter, loc, zeroAttr);
1042 Value zeroCstInt = createScalarOrSplatConstant(rewriter, loc, intTy, 0);
1043
1044 // Get the absolute value. One could have used math.absf here, but that
1045 // introduces an extra dependency.
1046 Value isNeg = arith::CmpFOp::create(
1047 rewriter, loc, arith::CmpFPredicate::OLT, inFp, zeroCst);
1048 Value negInFp = arith::NegFOp::create(rewriter, loc, inFp);
1049
1050 Value absVal = arith::SelectOp::create(rewriter, loc, isNeg, negInFp, inFp);
1051
1052 // Defer the absolute value to fptoui.
1053 Value res = arith::FPToUIOp::create(rewriter, loc, intTy, absVal);
1054
1055 // Negate the value if < 0 .
1056 Value neg = arith::SubIOp::create(rewriter, loc, zeroCstInt, res);
1057
1058 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNeg, neg, res);
1059 return success();
1060 }
1061};
1062
1063//===----------------------------------------------------------------------===//
1064// ConvertFPToUI
1065//===----------------------------------------------------------------------===//
1066
1067struct ConvertFPToUI final : OpConversionPattern<arith::FPToUIOp> {
1068 using Base::Base;
1069
1070 LogicalResult
1071 matchAndRewrite(arith::FPToUIOp op, OpAdaptor adaptor,
1072 ConversionPatternRewriter &rewriter) const override {
1073 Location loc = op.getLoc();
1074 // Get the input float type.
1075 Value inFp = adaptor.getIn();
1076 Type fpTy = inFp.getType();
1077
1078 Type intTy = op.getType();
1079 auto newTy = getTypeConverter()->convertType<VectorType>(intTy);
1080 if (!newTy)
1081 return rewriter.notifyMatchFailure(
1082 loc, llvm::formatv("unsupported type: {}", intTy));
1083 unsigned newBitWidth = newTy.getElementTypeBitWidth();
1084
1085 Type newHalfType = IntegerType::get(inFp.getContext(), newBitWidth);
1086 if (auto vecType = dyn_cast<VectorType>(fpTy))
1087 newHalfType = VectorType::get(vecType.getShape(), newHalfType);
1088
1089 // The resulting integer has the upper part and the lower part. This would
1090 // be interpreted as 2^N * high + low, where N is the bitwidth. Therefore,
1091 // to calculate the higher part, we emit resHigh = fptoui(fp/2^N). For the
1092 // lower part, we emit fptoui(fp - resHigh * 2^N). The special cases of
1093 // overflows including +-inf, NaNs and negative numbers are UB.
1094
1095 const llvm::fltSemantics &fSemantics =
1096 cast<FloatType>(getElementTypeOrSelf(fpTy)).getFloatSemantics();
1097
1098 auto powBitwidth = llvm::APFloat(fSemantics);
1099 // If the integer does not fit the floating point number, we set the
1100 // powBitwidth to inf. This ensures that the upper part is set
1101 // correctly to 0. The opStatus inexact here only occurs when we have an
1102 // overflow, since the number is always a power of two.
1103 if (powBitwidth.convertFromAPInt(APInt(newBitWidth * 2, 1).shl(newBitWidth),
1104 false, llvm::RoundingMode::TowardZero) ==
1105 llvm::detail::opStatus::opInexact)
1106 powBitwidth = llvm::APFloat::getInf(fSemantics);
1107
1108 TypedAttr powBitwidthAttr =
1109 FloatAttr::get(getElementTypeOrSelf(fpTy), powBitwidth);
1110 if (auto vecType = dyn_cast<VectorType>(fpTy))
1111 powBitwidthAttr = SplatElementsAttr::get(vecType, powBitwidthAttr);
1112 Value powBitwidthFloatCst =
1113 arith::ConstantOp::create(rewriter, loc, powBitwidthAttr);
1114
1115 Value fpDivPowBitwidth =
1116 arith::DivFOp::create(rewriter, loc, inFp, powBitwidthFloatCst);
1117 Value resHigh =
1118 arith::FPToUIOp::create(rewriter, loc, newHalfType, fpDivPowBitwidth);
1119 // Calculate fp - resHigh * 2^N by getting the remainder of the division
1120 Value remainder =
1121 arith::RemFOp::create(rewriter, loc, inFp, powBitwidthFloatCst);
1122 Value resLow =
1123 arith::FPToUIOp::create(rewriter, loc, newHalfType, remainder);
1124
1125 Value high = appendX1Dim(rewriter, loc, resHigh);
1126 Value low = appendX1Dim(rewriter, loc, resLow);
1127
1128 Value resultVec = constructResultVector(rewriter, loc, newTy, {low, high});
1129
1130 rewriter.replaceOp(op, resultVec);
1131 return success();
1132 }
1133};
1134
1135//===----------------------------------------------------------------------===//
1136// ConvertTruncI
1137//===----------------------------------------------------------------------===//
1138
1139struct ConvertTruncI final : OpConversionPattern<arith::TruncIOp> {
1140 using Base::Base;
1141
1142 LogicalResult
1143 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
1144 ConversionPatternRewriter &rewriter) const override {
1145 Location loc = op.getLoc();
1146 // Check if the result type is legal for this target. Currently, we do not
1147 // support truncation to types wider than supported by the target.
1148 if (!getTypeConverter()->isLegal(op.getType()))
1149 return rewriter.notifyMatchFailure(
1150 loc, llvm::formatv("unsupported truncation result type: {0}",
1151 op.getType()));
1152
1153 // Discard the high half of the input. Truncate the low half, if
1154 // necessary.
1155 Value extracted = extractLastDimSlice(rewriter, loc, adaptor.getIn(), 0);
1156 extracted = dropTrailingX1Dim(rewriter, loc, extracted);
1157 Value truncated =
1158 rewriter.createOrFold<arith::TruncIOp>(loc, op.getType(), extracted);
1159 rewriter.replaceOp(op, truncated);
1160 return success();
1161 }
1162};
1163
1164//===----------------------------------------------------------------------===//
1165// ConvertVectorPrint
1166//===----------------------------------------------------------------------===//
1167
1168struct ConvertVectorPrint final : OpConversionPattern<vector::PrintOp> {
1169 using Base::Base;
1170
1171 LogicalResult
1172 matchAndRewrite(vector::PrintOp op, OpAdaptor adaptor,
1173 ConversionPatternRewriter &rewriter) const override {
1174 rewriter.replaceOpWithNewOp<vector::PrintOp>(op, adaptor.getSource());
1175 return success();
1176 }
1177};
1178
1179//===----------------------------------------------------------------------===//
1180// Pass Definition
1181//===----------------------------------------------------------------------===//
1182
1183struct EmulateWideIntPass final
1184 : arith::impl::ArithEmulateWideIntBase<EmulateWideIntPass> {
1185 using ArithEmulateWideIntBase::ArithEmulateWideIntBase;
1186
1187 void runOnOperation() override {
1188 if (!llvm::isPowerOf2_32(widestIntSupported) || widestIntSupported < 2) {
1189 signalPassFailure();
1190 return;
1191 }
1192
1193 Operation *op = getOperation();
1194 MLIRContext *ctx = op->getContext();
1195
1196 arith::WideIntEmulationConverter typeConverter(widestIntSupported);
1197 ConversionTarget target(*ctx);
1198 target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](Operation *op) {
1199 return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType());
1200 });
1201 auto opLegalCallback = [&typeConverter](Operation *op) {
1202 return typeConverter.isLegal(op);
1203 };
1204 target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(opLegalCallback);
1205 target.addDynamicallyLegalOp<vector::PrintOp>(opLegalCallback);
1206 target.addDynamicallyLegalDialect<arith::ArithDialect>(opLegalCallback);
1207 target.addLegalDialect<vector::VectorDialect>();
1208
1209 RewritePatternSet patterns(ctx);
1210 arith::populateArithWideIntEmulationPatterns(typeConverter, patterns);
1211
1212 // Populate `func.*` conversion patterns.
1213 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
1214 patterns, typeConverter);
1217
1218 if (failed(applyPartialConversion(op, target, std::move(patterns))))
1219 signalPassFailure();
1220 }
1221};
1222} // end anonymous namespace
1223
1224//===----------------------------------------------------------------------===//
1225// Public Interface Definition
1226//===----------------------------------------------------------------------===//
1227
1229 unsigned widestIntSupportedByTarget)
1230 : maxIntWidth(widestIntSupportedByTarget) {
1231 assert(llvm::isPowerOf2_32(widestIntSupportedByTarget) &&
1232 "Only power-of-two integers with are supported");
1233 assert(widestIntSupportedByTarget >= 2 && "Integer type too narrow");
1234
1235 // Allow unknown types.
1236 addConversion([](Type ty) -> std::optional<Type> { return ty; });
1237
1238 // Scalar case.
1239 addConversion([this](IntegerType ty) -> std::optional<Type> {
1240 unsigned width = ty.getWidth();
1241 if (width <= maxIntWidth)
1242 return ty;
1243
1244 // i2N --> vector<2xiN>
1245 if (width == 2 * maxIntWidth)
1246 return VectorType::get(2, IntegerType::get(ty.getContext(), maxIntWidth));
1247
1248 return nullptr;
1249 });
1250
1251 // Vector case.
1252 addConversion([this](VectorType ty) -> std::optional<Type> {
1253 auto intTy = dyn_cast<IntegerType>(ty.getElementType());
1254 if (!intTy)
1255 return ty;
1256
1257 unsigned width = intTy.getWidth();
1258 if (width <= maxIntWidth)
1259 return ty;
1260
1261 // vector<...xi2N> --> vector<...x2xiN>
1262 if (width == 2 * maxIntWidth) {
1263 auto newShape = to_vector(ty.getShape());
1264 newShape.push_back(2);
1265 return VectorType::get(newShape,
1266 IntegerType::get(ty.getContext(), maxIntWidth));
1267 }
1268
1269 return nullptr;
1270 });
1271
1272 // Function case.
1273 addConversion([this](FunctionType ty) -> std::optional<Type> {
1274 // Convert inputs and results, e.g.:
1275 // (i2N, i2N) -> i2N --> (vector<2xiN>, vector<2xiN>) -> vector<2xiN>
1276 SmallVector<Type> inputs;
1277 if (failed(convertTypes(ty.getInputs(), inputs)))
1278 return nullptr;
1279
1280 SmallVector<Type> results;
1281 if (failed(convertTypes(ty.getResults(), results)))
1282 return nullptr;
1283
1284 return FunctionType::get(ty.getContext(), inputs, results);
1285 });
1286}
1287
1289 const WideIntEmulationConverter &typeConverter,
1291 // Populate `arith.*` conversion patterns.
1292 patterns.add<
1293 // Misc ops.
1294 ConvertConstant, ConvertCmpI, ConvertSelect, ConvertVectorPrint,
1295 // Binary ops.
1296 ConvertAddI, ConvertMulI, ConvertShLI, ConvertShRSI, ConvertShRUI,
1297 ConvertMaxMin<arith::MaxUIOp, arith::CmpIPredicate::ugt>,
1298 ConvertMaxMin<arith::MaxSIOp, arith::CmpIPredicate::sgt>,
1299 ConvertMaxMin<arith::MinUIOp, arith::CmpIPredicate::ult>,
1300 ConvertMaxMin<arith::MinSIOp, arith::CmpIPredicate::slt>, ConvertSubI,
1301 // Bitwise binary ops.
1302 ConvertBitwiseBinary<arith::AndIOp>, ConvertBitwiseBinary<arith::OrIOp>,
1303 ConvertBitwiseBinary<arith::XOrIOp>,
1304 // Extension and truncation ops.
1305 ConvertExtSI, ConvertExtUI, ConvertTruncI,
1306 // Cast ops.
1307 ConvertIndexCastIntToIndex<arith::IndexCastOp>,
1308 ConvertIndexCastIntToIndex<arith::IndexCastUIOp>,
1309 ConvertIndexCastIndexToInt<arith::IndexCastOp, arith::ExtSIOp>,
1310 ConvertIndexCastIndexToInt<arith::IndexCastUIOp, arith::ExtUIOp>,
1311 ConvertSIToFP, ConvertUIToFP, ConvertFPToUI, ConvertFPToSI>(
1312 typeConverter, patterns.getContext());
1313}
return success()
static std::pair< APInt, APInt > getHalves(const APInt &value, unsigned newBitWidth)
Returns N bottom and N top bits from value, where N = newBitWidth.
static Value insertLastDimSlice(ConversionPatternRewriter &rewriter, Location loc, Value source, Value dest, int64_t lastOffset)
Inserts the source vector slice into the dest vector at offset lastOffset in the last dimension.
static Value appendX1Dim(ConversionPatternRewriter &rewriter, Location loc, Value input)
Performs a vector shape cast to append an x1 dimension.
static Type reduceInnermostDim(VectorType type)
Returns the type with the last (innermost) dimension reduced to x1.
static Value constructResultVector(ConversionPatternRewriter &rewriter, Location loc, VectorType resultType, ValueRange resultComponents)
Constructs a new vector of type resultType by creating a series of insertions of resultComponents,...
static Value dropTrailingX1Dim(ConversionPatternRewriter &rewriter, Location loc, Value input)
static Value extractLastDimSlice(ConversionPatternRewriter &rewriter, Location loc, Value input, int64_t lastOffset)
Extracts the input vector slice with elements at the last dimension offset by lastOffset.
static std::pair< Value, Value > extractLastDimHalves(ConversionPatternRewriter &rewriter, Location loc, Value input)
Extracts two vector slices from the input whose type is vector<...x2T>, with the first element at off...
Attributes are known-constant values of operations.
Definition Attributes.h:25
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition Value.h:108
Type getType() const
Return the type of this value.
Definition Value.h:105
Converts integer types that are too wide for the target by splitting them in two halves and thus turn...
WideIntEmulationConverter(unsigned widestIntSupportedByTarget)
void populateArithWideIntEmulationPatterns(const WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns)
Adds patterns to emulate wide Arith and Function ops over integer types into supported ones.
Fraction abs(const Fraction &f)
Definition Fraction.h:107
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, const APInt &value)
Create a constant of type type at location loc whose value is value (an APInt or APFloat whose type m...
Definition Utils.cpp:270
void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, PatternBenefit benefit=1)
Add a pattern to the given pattern list to rewrite return ops to use operands that have been legalize...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
void populateCallOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, PatternBenefit benefit=1)
Add a pattern to the given pattern list to convert the operand and result types of a CallOp with the ...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns