MLIR 23.0.0git
ExpandOps.cpp
Go to the documentation of this file.
1//===- ExpandOps.cpp - Pass to legalize Arith ops for LLVM lowering --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
13#include "mlir/IR/Location.h"
16
17namespace mlir {
18namespace arith {
19#define GEN_PASS_DEF_ARITHEXPANDOPSPASS
20#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
21} // namespace arith
22} // namespace mlir
23
24using namespace mlir;
25
26/// Create an integer or index constant.
27static Value createConst(Location loc, Type type, int value,
28 PatternRewriter &rewriter) {
29 auto attr = rewriter.getIntegerAttr(getElementTypeOrSelf(type), value);
30 if (auto shapedTy = dyn_cast<ShapedType>(type)) {
31 return arith::ConstantOp::create(rewriter, loc,
32 DenseElementsAttr::get(shapedTy, attr));
33 }
34 return arith::ConstantOp::create(rewriter, loc, attr);
35}
36
37/// Create an integer constant from an APInt.
38static Value createAPIntConst(Location loc, Type type, const APInt &value,
39 PatternRewriter &rewriter) {
40 auto attr = IntegerAttr::get(getElementTypeOrSelf(type), value);
41 if (auto shapedTy = dyn_cast<ShapedType>(type)) {
42 return arith::ConstantOp::create(rewriter, loc,
43 DenseElementsAttr::get(shapedTy, attr));
44 }
45 return arith::ConstantOp::create(rewriter, loc, attr);
46}
47
48/// Create a float constant.
49static Value createFloatConst(Location loc, Type type, const APFloat &value,
50 PatternRewriter &rewriter) {
51 auto attr = rewriter.getFloatAttr(getElementTypeOrSelf(type), value);
52 if (auto shapedTy = dyn_cast<ShapedType>(type)) {
53 return arith::ConstantOp::create(rewriter, loc,
54 DenseElementsAttr::get(shapedTy, attr));
55 }
56
57 return arith::ConstantOp::create(rewriter, loc, attr);
58}
59
60/// Creates shapedType using shape from cloneFrom and base type from cloneTo
61static Type cloneToShapedType(Type cloneFrom, Type cloneTo) {
62 if (auto shapedTy = dyn_cast<ShapedType>(cloneFrom)) {
63 return shapedTy.clone(cloneTo);
64 }
65 return cloneTo;
66}
67
68namespace {
69
70/// Expands CeilDivUIOp (n, m) into
71/// n == 0 ? 0 : ((n-1) / m) + 1
72struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> {
73 using Base::Base;
74 LogicalResult matchAndRewrite(arith::CeilDivUIOp op,
75 PatternRewriter &rewriter) const final {
76 Location loc = op.getLoc();
77 Value a = op.getLhs();
78 Value b = op.getRhs();
79 Value zero = createConst(loc, a.getType(), 0, rewriter);
80 Value compare =
81 arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, a, zero);
82 Value one = createConst(loc, a.getType(), 1, rewriter);
83 Value minusOne = arith::SubIOp::create(rewriter, loc, a, one);
84 Value quotient = arith::DivUIOp::create(rewriter, loc, minusOne, b);
85 Value plusOne = arith::AddIOp::create(rewriter, loc, quotient, one);
86 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compare, zero, plusOne);
87 return success();
88 }
89};
90
91/// Expands CeilDivSIOp (a, b) into
92/// z = a / b
93/// if (z * b != a && (a < 0) == (b < 0)) {
94/// return z + 1;
95/// } else {
96/// return z;
97/// }
98struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
99 using Base::Base;
100 LogicalResult matchAndRewrite(arith::CeilDivSIOp op,
101 PatternRewriter &rewriter) const final {
102 Location loc = op.getLoc();
103 Type type = op.getType();
104 Value a = op.getLhs();
105 Value b = op.getRhs();
106
107 Value zero = createConst(loc, type, 0, rewriter);
108 Value one = createConst(loc, type, 1, rewriter);
109
110 Value quotient = arith::DivSIOp::create(rewriter, loc, a, b);
111 Value product = arith::MulIOp::create(rewriter, loc, quotient, b);
112 Value notEqualDivisor = arith::CmpIOp::create(
113 rewriter, loc, arith::CmpIPredicate::ne, a, product);
114
115 Value aNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
116 a, zero);
117 Value bNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
118 b, zero);
119
120 Value signEqual = arith::CmpIOp::create(
121 rewriter, loc, arith::CmpIPredicate::eq, aNeg, bNeg);
122 Value cond =
123 arith::AndIOp::create(rewriter, loc, notEqualDivisor, signEqual);
124
125 Value quotientPlusOne = arith::AddIOp::create(rewriter, loc, quotient, one);
126
127 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientPlusOne,
128 quotient);
129 return success();
130 }
131};
132
133/// Expands FloorDivSIOp (x, y) into
134/// z = x / y
135/// if (z * y != x && (x < 0) != (y < 0)) {
136/// return z - 1;
137/// } else {
138/// return z;
139/// }
140struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
141 using Base::Base;
142 LogicalResult matchAndRewrite(arith::FloorDivSIOp op,
143 PatternRewriter &rewriter) const final {
144 Location loc = op.getLoc();
145 Type type = op.getType();
146 Value a = op.getLhs();
147 Value b = op.getRhs();
148
149 Value quotient = arith::DivSIOp::create(rewriter, loc, a, b);
150 Value product = arith::MulIOp::create(rewriter, loc, quotient, b);
151 Value notEqualDivisor = arith::CmpIOp::create(
152 rewriter, loc, arith::CmpIPredicate::ne, a, product);
153 Value zero = createConst(loc, type, 0, rewriter);
154
155 Value aNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
156 a, zero);
157 Value bNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
158 b, zero);
159
160 Value signOpposite = arith::CmpIOp::create(
161 rewriter, loc, arith::CmpIPredicate::ne, aNeg, bNeg);
162 Value cond =
163 arith::AndIOp::create(rewriter, loc, notEqualDivisor, signOpposite);
164
165 Value minusOne = createConst(loc, type, -1, rewriter);
166 Value quotientMinusOne =
167 arith::AddIOp::create(rewriter, loc, quotient, minusOne);
168
169 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientMinusOne,
170 quotient);
171 return success();
172 }
173};
174
175template <typename OpTy, arith::CmpIPredicate pred>
176struct MaxMinIOpConverter : public OpRewritePattern<OpTy> {
177public:
178 using OpRewritePattern<OpTy>::OpRewritePattern;
179
180 LogicalResult matchAndRewrite(OpTy op,
181 PatternRewriter &rewriter) const final {
182 Value lhs = op.getLhs();
183 Value rhs = op.getRhs();
184
185 Value cmp = arith::CmpIOp::create(rewriter, op.getLoc(), pred, lhs, rhs);
186 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, lhs, rhs);
187 return success();
188 }
189};
190
191template <typename OpTy, arith::CmpFPredicate pred>
192struct MaximumMinimumFOpConverter : public OpRewritePattern<OpTy> {
193public:
194 using OpRewritePattern<OpTy>::OpRewritePattern;
195
196 LogicalResult matchAndRewrite(OpTy op,
197 PatternRewriter &rewriter) const final {
198 Value lhs = op.getLhs();
199 Value rhs = op.getRhs();
200
201 Location loc = op.getLoc();
202 // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs').
203 static_assert(pred == arith::CmpFPredicate::UGT ||
204 pred == arith::CmpFPredicate::ULT,
205 "pred must be either UGT or ULT");
206 Value cmp = arith::CmpFOp::create(rewriter, loc, pred, lhs, rhs);
207 Value select = arith::SelectOp::create(rewriter, loc, cmp, lhs, rhs);
208
209 // Handle the case where rhs is NaN: 'isNaN(rhs) ? rhs : select'.
210 Value isNaN = arith::CmpFOp::create(rewriter, loc,
211 arith::CmpFPredicate::UNO, rhs, rhs);
212 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
213 return success();
214 }
215};
216
217template <typename OpTy, arith::CmpFPredicate pred>
218struct MaxNumMinNumFOpConverter : public OpRewritePattern<OpTy> {
219public:
220 using OpRewritePattern<OpTy>::OpRewritePattern;
221
222 LogicalResult matchAndRewrite(OpTy op,
223 PatternRewriter &rewriter) const final {
224 Value lhs = op.getLhs();
225 Value rhs = op.getRhs();
226
227 Location loc = op.getLoc();
228 // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs').
229 static_assert(pred == arith::CmpFPredicate::UGT ||
230 pred == arith::CmpFPredicate::ULT,
231 "pred must be either UGT or ULT");
232 Value cmp = arith::CmpFOp::create(rewriter, loc, pred, lhs, rhs);
233 Value select = arith::SelectOp::create(rewriter, loc, cmp, lhs, rhs);
234
235 // Handle the case where lhs is NaN: 'isNaN(lhs) ? rhs : select'.
236 Value isNaN = arith::CmpFOp::create(rewriter, loc,
237 arith::CmpFPredicate::UNO, lhs, lhs);
238 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
239 return success();
240 }
241};
242
243struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
244 using Base::Base;
245 LogicalResult matchAndRewrite(arith::ExtFOp op,
246 PatternRewriter &rewriter) const final {
247 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
248 auto operand = op.getOperand();
249 Type operandTy = operand.getType();
250 Type resultTy = op.getType();
251 Type operandETy = getElementTypeOrSelf(operandTy);
252 Type resultETy = getElementTypeOrSelf(resultTy);
253
254 if (!operandETy.isBF16() || !resultETy.isF32()) {
255 return rewriter.notifyMatchFailure(op, "not a ext of bf16 to f32.");
256 }
257
258 Type i16Ty = cloneToShapedType(operandTy, b.getI16Type());
259 Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
260
261 Value bitcast = arith::BitcastOp::create(b, i16Ty, operand);
262 Value exti = arith::ExtUIOp::create(b, i32Ty, bitcast);
263
264 Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
265 Value shl = arith::ShLIOp::create(b, exti, c16);
266 Value result = arith::BitcastOp::create(b, resultTy, shl);
267
268 rewriter.replaceOp(op, result);
269 return success();
270 }
271};
272
273struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
274 using Base::Base;
275 LogicalResult matchAndRewrite(arith::TruncFOp op,
276 PatternRewriter &rewriter) const final {
277 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
278 auto operand = op.getOperand();
279 Type operandTy = operand.getType();
280 Type resultTy = op.getType();
281 Type operandETy = getElementTypeOrSelf(operandTy);
282 Type resultETy = getElementTypeOrSelf(resultTy);
283
284 if (!operandETy.isF32() || !resultETy.isBF16()) {
285 return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");
286 }
287
288 if (op.getRoundingmodeAttr()) {
289 return rewriter.notifyMatchFailure(
290 op, "only applicable to default rounding mode.");
291 }
292
293 Type i16Ty = cloneToShapedType(operandTy, b.getI16Type());
294 Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
295
296 // Algorithm borrowed from this excellent code:
297 // https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/c10/util/BFloat16.h#L60-L79
298 // There is a magic idea there, to let the addition of the rounding_bias to
299 // the mantissa simply overflow into the exponent bits. It's a bit of an
300 // aggressive, obfuscating optimization, but it is well-tested code, and it
301 // results in more concise and efficient IR.
302 // The case of NaN is handled separately (see isNaN and the final select).
303 // The case of infinities is NOT handled separately, which deserves an
304 // explanation. As the encoding of infinities has zero mantissa, the
305 // rounding-bias addition never carries into the exponent so that just gets
306 // truncated away, and as bfloat16 and float32 have the same number of
307 // exponent bits, that simple truncation is the desired outcome for
308 // infinities.
309 Value isNan =
310 arith::CmpFOp::create(b, arith::CmpFPredicate::UNE, operand, operand);
311 // Constant used to make the rounding bias.
312 Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter);
313 // Constant used to generate a quiet NaN.
314 Value c7FC0I16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
315 // Small constants used to address bits.
316 Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
317 Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter);
318 // Reinterpret the input f32 value as bits.
319 Value bitcast = arith::BitcastOp::create(b, i32Ty, operand);
320 // Read bit 16 as a value in {0,1}.
321 Value bit16 =
322 arith::AndIOp::create(b, arith::ShRUIOp::create(b, bitcast, c16), c1);
323 // Determine the rounding bias to add as either 0x7fff or 0x8000 depending
324 // on bit 16, implementing the tie-breaking "to nearest even".
325 Value roundingBias = arith::AddIOp::create(b, bit16, c7FFF);
326 // Add the rounding bias. Generally we want this to be added to the
327 // mantissa, but nothing prevents this to from carrying into the exponent
328 // bits, which would feel like a bug, but this is the magic trick here:
329 // when that happens, the mantissa gets reset to zero and the exponent
330 // gets incremented by the carry... which is actually exactly what we
331 // want.
332 Value biased = arith::AddIOp::create(b, bitcast, roundingBias);
333 // Now that the rounding-bias has been added, truncating the low bits
334 // yields the correctly rounded result.
335 Value biasedAndShifted = arith::ShRUIOp::create(b, biased, c16);
336 Value normalCaseResultI16 =
337 arith::TruncIOp::create(b, i16Ty, biasedAndShifted);
338 // Select either the above-computed result, or a quiet NaN constant
339 // if the input was NaN.
340 Value select =
341 arith::SelectOp::create(b, isNan, c7FC0I16, normalCaseResultI16);
342 Value result = arith::BitcastOp::create(b, resultTy, select);
343 rewriter.replaceOp(op, result);
344 return success();
345 }
346};
347
348/// In this implementation of extf we take advantage of some key patterns we
349/// notice between the binary representation of an F4E2M1 value and its
350/// corresponding value in F32.
351///
352/// Note: x is sign bit
353/// | Binary | F4E2M1 | f32[23:32]
354/// | x000 | 0.0 | x000 0000 00
355/// | x001 | 0.5 | x011 1111 00
356/// | x010 | 1.0 | x011 1111 10
357/// | x011 | 1.5 | x011 1111 11
358/// | x100 | 2.0 | x010 0000 00
359/// | x101 | 3.0 | x010 0000 01
360/// | x110 | 4.0 | x010 0000 10
361/// | x111 | 6.0 | x010 0000 11
362///
363/// 1) There are only two versions of bits [25:31] in the f32 result
364/// F4E2M1 bits[2:3] decide whether:
365/// - F32 bits[25:31] = 0011 1111
366/// - F32 bits[25:31] = 0010 0000
367/// Exception is zero where
368/// - F32 bits[25:31] = 0000 0000
369///
370/// 2) F4E2M1 bits[1:2] = F32 bits[23:24]
371/// Exception is 0.5 where
372/// - F4E2M1 bits[1:2] = 01, F32 bits[23:24] = 00
373///
374/// 3) F4E2M1 bits[4] = F32 bits[32] (sign bits are equal)
375///
376/// 4) F32 bits[1:22] = 0
377struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
378 using Base::Base;
379 LogicalResult matchAndRewrite(arith::ExtFOp op,
380 PatternRewriter &rewriter) const final {
381 Location loc = op.getLoc();
382 ImplicitLocOpBuilder b(loc, rewriter);
383 Value operand = op.getOperand();
384 Type operandTy = operand.getType();
385 Type resultTy = op.getType();
386 Type operandETy = getElementTypeOrSelf(operandTy);
387 Type resultETy = getElementTypeOrSelf(resultTy);
388
389 if (!isa<Float4E2M1FNType>(operandETy))
390 return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN");
391
392 Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
393 Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
394 Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
395 Value i4Bits = arith::BitcastOp::create(b, i4Ty, operand);
396
397 Value c0x0 = createConst(loc, i4Ty, 0x0, rewriter);
398 Value c0x1 = createConst(loc, i4Ty, 0x1, rewriter);
399 Value c0x2 = createConst(loc, i4Ty, 0x2, rewriter);
400 Value c0x4 = createConst(loc, i4Ty, 0x4, rewriter);
401 Value c0x7 = createConst(loc, i4Ty, 0x7, rewriter);
402
403 Value i4BitsNoSign = arith::AndIOp::create(b, i4Bits, c0x7);
404
405 // Set last Exponent bit and Mantissa.
406 Value c0x00000014 = createConst(loc, i32Ty, 0x14, rewriter);
407 Value bits1To24 = arith::ShLIOp::create(b, i4BitsNoSign, c0x2);
408 Value isHalf =
409 arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4BitsNoSign, c0x1);
410 bits1To24 = arith::SelectOp::create(b, isHalf, c0x0, bits1To24);
411 bits1To24 = arith::ExtUIOp::create(b, i32Ty, bits1To24);
412 bits1To24 = arith::ShLIOp::create(b, bits1To24, c0x00000014);
413
414 // Set first 7 bits of Exponent.
415 Value zeroExpBits = createConst(loc, i32Ty, 0x00000000, rewriter);
416 Value highExpBits = createConst(loc, i32Ty, 0x40000000, rewriter);
417 Value lowExpBits = createConst(loc, i32Ty, 0x3f000000, rewriter);
418 Value useLargerExp =
419 arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4BitsNoSign, c0x4);
420 Value bits25To31 =
421 arith::SelectOp::create(b, useLargerExp, highExpBits, lowExpBits);
422 Value zeroExp =
423 arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4BitsNoSign, c0x0);
424 bits25To31 = arith::SelectOp::create(b, zeroExp, zeroExpBits, bits25To31);
425
426 // Set sign.
427 Value c0x80000000 = createConst(loc, i32Ty, 0x80000000, rewriter);
428 Value c0x8 = createConst(loc, i4Ty, 0x8, rewriter);
429 Value negative =
430 arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4Bits, c0x8);
431 Value bit32 =
432 arith::SelectOp::create(b, negative, c0x80000000, zeroExpBits);
433
434 // Add segments together.
435 Value bits1To31 = arith::AddIOp::create(b, bits1To24, bits25To31);
436 Value bits1To32 = arith::AddIOp::create(b, bits1To31, bit32);
437 Value result = arith::BitcastOp::create(b, f32Ty, bits1To32);
438 if (!isa<Float32Type>(resultETy))
439 result = arith::TruncFOp::create(b, resultTy, result);
440
441 rewriter.replaceOp(op, result);
442 return success();
443 }
444};
445
446struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
447 using Base::Base;
448 LogicalResult matchAndRewrite(arith::ExtFOp op,
449 PatternRewriter &rewriter) const final {
450 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
451 Value operand = op.getOperand();
452 Type operandTy = operand.getType();
453 Type resultTy = op.getType();
454 Type operandETy = getElementTypeOrSelf(operandTy);
455 Type resultETy = getElementTypeOrSelf(resultTy);
456
457 if (!llvm::isa<Float8E8M0FNUType>(operandETy)) {
458 return rewriter.notifyMatchFailure(op, "not a ext of F8E8M0FNU");
459 }
460
461 Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
462 Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
463 Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
464
465 Value bitcast = arith::BitcastOp::create(b, i8Ty, operand);
466 Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
467 Value exti = arith::ExtUIOp::create(b, i32Ty, bitcast);
468 Value f32Bits = arith::ShLIOp::create(b, exti, cF32MantissaWidth);
469
470 // If FastMathFlag allows no NaN checks, skip it
471 auto fastMath = op.getFastmathAttr();
472 bool NoNaN = fastMath
473 ? (fastMath.getValue() & arith::FastMathFlags::nnan) ==
474 arith::FastMathFlags::nnan
475 : false;
476 if (!NoNaN) {
477 Value cF8NaN = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
478 Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
479 Value isNan =
480 arith::CmpIOp::create(b, arith::CmpIPredicate::eq, bitcast, cF8NaN);
481 // select for NaNs
482 f32Bits = arith::SelectOp::create(b, isNan, cF32NaN, f32Bits);
483 }
484
485 Value result = arith::BitcastOp::create(b, f32Ty, f32Bits);
486 if (resultETy.getIntOrFloatBitWidth() < 32) {
487 result = arith::TruncFOp::create(b, resultTy, result, nullptr,
488 op.getFastmathAttr());
489 } else if (resultETy.getIntOrFloatBitWidth() > 32) {
490 result = arith::ExtFOp::create(b, resultTy, result, op.getFastmathAttr());
491 }
492 rewriter.replaceOp(op, result);
493 return success();
494 }
495};
496
497/// Conversion from F32 to F4E2M1 according to the OCP Spec:
498/// www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
499///
500/// The spec requiers us to perform Round to Nearest, Ties to Even.
501///
502/// This means that after rounding, we should break ties by choosing the option
503/// which results in a mantissa of 0 in the least significant digit.
504///
505/// Table of representable values in F4E2M1:
506///
507/// Note: x is sign bit
508/// | Binary | F4E2M1 | F32[23:32]
509/// | x000 | 0.0 | x000 0000 00
510/// | x001 | 0.5 | x011 1111 00
511/// | x010 | 1.0 | x011 1111 10
512/// | x011 | 1.5 | x011 1111 11
513/// | x100 | 2.0 | x010 0000 00
514/// | x101 | 3.0 | x010 0000 01
515/// | x110 | 4.0 | x010 0000 10
516/// | x111 | 6.0 | x010 0000 11
517///
518/// Conversion procedure:
519/// Step 1: Clamp to representable bounds.
520/// Step 2: Convert exponent by adjusting bias.
521/// Step 3: Set mantissa to first bit.
522/// Step 4: Special consideration for subnormal and zero exponent.
523/// Step 5: Round up if necessary, if mantissa[1:] greater than 1000000 or
524/// subnormal.
525struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
526 using Base::Base;
527 LogicalResult matchAndRewrite(arith::TruncFOp op,
528 PatternRewriter &rewriter) const final {
529 Location loc = op.getLoc();
530 ImplicitLocOpBuilder b(loc, rewriter);
531 Value operand = op.getOperand();
532 Type operandTy = operand.getType();
533 Type resultTy = op.getType();
534 Type operandETy = getElementTypeOrSelf(operandTy);
535 Type resultETy = getElementTypeOrSelf(resultTy);
536
537 Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
538 Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
539 Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
540 Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
541
542 if (!isa<Float4E2M1FNType>(resultETy))
543 return rewriter.notifyMatchFailure(op, "not a trunc of F4E2M1FN");
544 if (!isa<Float32Type>(operandETy))
545 operand = arith::ExtFOp::create(b, f32Ty, operand);
546
547 Value c0x1 = createConst(loc, i4Ty, 1, rewriter);
548 Value c0x3 = createConst(loc, i4Ty, 3, rewriter);
549 Value c0x00000016 = createConst(loc, i32Ty, 22, rewriter);
550 Value c0x00 = createConst(loc, i8Ty, 0x00, rewriter);
551 Value c0xff = createConst(loc, i8Ty, 0xff, rewriter);
552 Value zeroExpBits = createConst(loc, i32Ty, 0, rewriter);
553
554 // Step 0: Clamp to bounds.
555 Value cHigherBound = createFloatConst(loc, f32Ty, APFloat(6.0f), rewriter);
556 Value cLowerBound = createFloatConst(loc, f32Ty, APFloat(-6.0f), rewriter);
557 Value operandClamped = arith::MinNumFOp::create(b, cHigherBound, operand);
558 operandClamped = arith::MaxNumFOp::create(b, cLowerBound, operandClamped);
559 Value f32Bits = arith::BitcastOp::create(b, i32Ty, operandClamped);
560
561 // Step 1: Set sign bit.
562 Value cF32ExpManWidth = createConst(loc, i32Ty, 31, rewriter); // 23
563 Value f32Sign = arith::ShRUIOp::create(b, f32Bits, cF32ExpManWidth);
564 Value f4Sign = arith::TruncIOp::create(b, i4Ty, f32Sign);
565 Value f4Bits = arith::ShLIOp::create(b, f4Sign, c0x3);
566
567 // Step 2: Convert exponent by adjusting bias.
568 Value biasAdjustment = createConst(loc, i32Ty, 0x7e, rewriter);
569 Value cF4MantissaWidth = c0x1; // 1
570 Value cF32MantissaWidth = createConst(loc, i32Ty, 23, rewriter); // 23
571 Value f32SignExp = arith::ShRUIOp::create(b, f32Bits, cF32MantissaWidth);
572 Value biasAdjustedSignExp =
573 arith::SubIOp::create(b, f32SignExp, biasAdjustment);
574 Value f4Exp = arith::TruncIOp::create(b, i4Ty, biasAdjustedSignExp);
575 f4Exp = arith::ShLIOp::create(b, f4Exp, cF4MantissaWidth);
576 f4Bits = arith::AddIOp::create(b, f4Bits, f4Exp);
577
578 // Step 3: Set mantissa to first bit.
579 Value cF32FirstBitMask = createConst(loc, i32Ty, 0x400000, rewriter);
580 Value man1Bit = arith::AndIOp::create(b, f32Bits, cF32FirstBitMask);
581 man1Bit = arith::ShRUIOp::create(b, man1Bit, c0x00000016);
582 Value f4Man = arith::TruncIOp::create(b, i4Ty, man1Bit);
583 f4Bits = arith::AddIOp::create(b, f4Bits, f4Man);
584
585 // Step 4: Special consideration for conversion to 0.5.
586 Value cF32MantissaMask = createConst(loc, i32Ty, 0x7fffff, rewriter);
587 Value f8Exp = arith::TruncIOp::create(b, i8Ty, biasAdjustedSignExp);
588 Value isSubnormal =
589 arith::CmpIOp::create(b, arith::CmpIPredicate::sle, f8Exp, c0x00);
590 Value isNegOneExp =
591 arith::CmpIOp::create(b, arith::CmpIPredicate::eq, f8Exp, c0xff);
592 Value man23Bits = arith::AndIOp::create(b, f32Bits, cF32MantissaMask);
593 Value isNonZeroMan = arith::CmpIOp::create(b, arith::CmpIPredicate::ugt,
594 man23Bits, zeroExpBits);
595 Value roundToHalf = arith::AndIOp::create(b, isNegOneExp, isNonZeroMan);
596 Value isZeroExp =
597 arith::CmpIOp::create(b, arith::CmpIPredicate::eq, f8Exp, c0x00);
598 Value subnormalF4Bits = createConst(loc, i4Ty, 0xf, rewriter);
599 Value halfF4Bits = createConst(loc, i4Ty, 0x0, rewriter);
600 Value subResult =
601 arith::SelectOp::create(b, isSubnormal, subnormalF4Bits, f4Bits);
602 subResult = arith::SelectOp::create(b, roundToHalf, halfF4Bits, subResult);
603 f4Bits = arith::SelectOp::create(b, isZeroExp, f4Bits, subResult);
604
605 // Step 5: Round up if necessary.
606 Value cF32Last22BitMask = createConst(loc, i32Ty, 0x3fffff, rewriter);
607 Value cRound = createConst(loc, i32Ty, 0x200000, rewriter); // 010 0000...
608 Value man22Bits = arith::AndIOp::create(b, f32Bits, cF32Last22BitMask);
609 Value shouldRound =
610 arith::CmpIOp::create(b, arith::CmpIPredicate::uge, man22Bits, cRound);
611 shouldRound = arith::OrIOp::create(b, shouldRound, isSubnormal);
612 Value roundedF4Bits = arith::AddIOp::create(b, f4Bits, c0x1);
613 f4Bits = arith::SelectOp::create(b, shouldRound, roundedF4Bits, f4Bits);
614
615 Value result = arith::BitcastOp::create(b, resultTy, f4Bits);
616 rewriter.replaceOp(op, result);
617 return success();
618 }
619};
620
621/*
622TruncF to F8E8M0 is expected to extract exponent bits out of F32 type
623Since All kinds of Infs and NaNs are mapped to same exponent bits in F32 type,
624they all map to NaN in F8E8M0 Type.
625*/
626struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
627 using Base::Base;
628 LogicalResult matchAndRewrite(arith::TruncFOp op,
629 PatternRewriter &rewriter) const final {
630 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
631 Value operand = op.getOperand();
632 Type operandTy = operand.getType();
633 Type operandETy = getElementTypeOrSelf(operandTy);
634 Type resultTy = op.getType();
635 Type resultETy = getElementTypeOrSelf(resultTy);
636 if (!llvm::isa<Float8E8M0FNUType>(resultETy)) {
637 return rewriter.notifyMatchFailure(op, "not a truncf to f8E8M0FNU");
638 }
639
640 if (op.getRoundingmodeAttr()) {
641 return rewriter.notifyMatchFailure(
642 op, "only applicable to default rounding mode.");
643 }
644
645 Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
646 Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
647 Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
648
649 if (operandETy.getIntOrFloatBitWidth() < 32) {
650 operand = arith::ExtFOp::create(b, f32Ty, operand, op.getFastmathAttr());
651 } else if (operandETy.getIntOrFloatBitWidth() > 32) {
652 operand = arith::TruncFOp::create(
653 b, f32Ty, operand, op.getRoundingmodeAttr(), op.getFastmathAttr());
654 }
655 Value f32Bits = arith::BitcastOp::create(b, i32Ty, operand);
656 Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
657 Value f32SignExp = arith::ShRUIOp::create(b, f32Bits, cF32MantissaWidth);
658 Value exp8Bits = arith::TruncIOp::create(b, i8Ty, f32SignExp);
659 Value result = arith::BitcastOp::create(b, resultTy, exp8Bits);
660 rewriter.replaceOp(op, result);
661 return success();
662 }
663};
664
665struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> {
666 using Base::Base;
667 LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
668 PatternRewriter &rewriter) const final {
669 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
670 Value inputOperand = op.getIn();
671 Value scaleOperand = op.getScale();
672 Type scaleTy = scaleOperand.getType();
673 Type scaleETy = getElementTypeOrSelf(scaleOperand);
674 // allow implicit exponent extraction from 16/32 bits floats
675 if (scaleETy.getIntOrFloatBitWidth() >= 16) {
676 scaleETy = b.getF8E8M0Type();
677 scaleTy = cloneToShapedType(scaleTy, scaleETy);
678 scaleOperand = arith::TruncFOp::create(b, scaleTy, scaleOperand, nullptr,
679 op.getFastmathAttr());
680 }
681 // Catch scale types like f8E5M2.
682 if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
683 return rewriter.notifyMatchFailure(
684 op, "scaling_extf is using scales of type which can not be converted "
685 "to f8E8M0FNU");
686 }
687 Type resultTy = op.getType();
688 // extf on scale will essentially create floating point number
689 // of type resulTy that is 2^scale and will also propagate NaNs
690 Value scaleExt =
691 arith::ExtFOp::create(b, resultTy, scaleOperand, op.getFastmathAttr());
692 Value inputExt =
693 arith::ExtFOp::create(b, resultTy, inputOperand, op.getFastmathAttr());
694 Value result =
695 arith::MulFOp::create(b, inputExt, scaleExt, op.getFastmathAttr());
696 rewriter.replaceOp(op, result);
697 return success();
698 }
699};
700
701/*
702Expands arith.ScalingTruncFOp(in, scale) into
703 scale = arith.truncf(scale) : scaleTy -> f8E8M0FNU
704 result = arith.truncf(in / (2^scale))
705 */
706struct ScalingTruncFOpConverter
707 : public OpRewritePattern<arith::ScalingTruncFOp> {
708 using Base::Base;
709 LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
710 PatternRewriter &rewriter) const final {
711 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
712 Value inputOperand = op.getIn();
713 Value scaleOperand = op.getScale();
714 Type scaleTy = scaleOperand.getType();
715 Type scaleETy = getElementTypeOrSelf(scaleOperand);
716 // allow implicit exponent extraction from 16/32 bits floats
717 if (scaleETy.getIntOrFloatBitWidth() >= 16) {
718 scaleETy = b.getF8E8M0Type();
719 scaleTy = cloneToShapedType(scaleTy, scaleETy);
720 scaleOperand = arith::TruncFOp::create(b, scaleTy, scaleOperand, nullptr,
721 op.getFastmathAttr());
722 }
723 if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
724 return rewriter.notifyMatchFailure(
725 op, "scaling_truncf is using scales type which can not be converted "
726 "to f8E8M0FNU");
727 }
728 Type resultTy = op.getType();
729 Type inputTy = inputOperand.getType();
730 // this will create a floating point number of type
731 // inputTy that is 2^scale and will also propagate NaNs
732 scaleOperand =
733 arith::ExtFOp::create(b, inputTy, scaleOperand, op.getFastmathAttr());
734 Value result = arith::DivFOp::create(b, inputOperand, scaleOperand,
735 op.getFastmathAttr());
736 Value resultCast = arith::TruncFOp::create(
737 b, resultTy, result, op.getRoundingmodeAttr(), op.getFastmathAttr());
738 rewriter.replaceOp(op, resultCast);
739 return success();
740 }
741};
742
743/// Expands `arith.flush_denormals` into integer arithmetic.
744///
745/// For an IEEE-like floating-point value with a sign|exponent|mantissa bit
746/// layout, a value is denormal iff its biased exponent field is zero and its
747/// stored mantissa is non-zero. When the exponent field is zero, the value is
748/// either pos/neg 0 (mantissa = 0) or a denormal (mantissa != 0); in both
749/// cases, clearing the mantissa bits produces the desired sign-preserved zero
750/// (a no-op for pos/neg 0, a flush for denormals). When the exponent field is
751/// non-zero, the value passes through unchanged.
752///
753/// Pseudocode:
754/// bits = bitcast(x, iN)
755/// expIsZero = (bits & expMask) == 0
756/// cleared = bits & ~manMask
757/// resultBits = select(expIsZero, cleared, bits)
758/// result = bitcast(resultBits, floatTy)
759struct FlushDenormalsOpConverter
760 : public OpRewritePattern<arith::FlushDenormalsOp> {
761 using Base::Base;
762 LogicalResult matchAndRewrite(arith::FlushDenormalsOp op,
763 PatternRewriter &rewriter) const final {
764 Location loc = op.getLoc();
765 ImplicitLocOpBuilder b(loc, rewriter);
766 Value operand = op.getOperand();
767 Type operandTy = operand.getType();
768 auto floatTy = dyn_cast<FloatType>(getElementTypeOrSelf(operandTy));
769 if (!floatTy)
770 return rewriter.notifyMatchFailure(op, "operand is not a float type");
771
772 const llvm::fltSemantics &sem = floatTy.getFloatSemantics();
773 // Restrict to IEEE-like encodings, where the sign bit is the MSB and
774 // denormals are exactly "biased exponent == 0 and non-zero mantissa".
775 if (!llvm::APFloatBase::isIEEELikeFP(sem))
776 return rewriter.notifyMatchFailure(
777 op, "only IEEE-like floating-point types are supported");
778
779 unsigned totalBits = llvm::APFloatBase::semanticsSizeInBits(sem);
780 unsigned precision = llvm::APFloatBase::semanticsPrecision(sem);
781 // Stored mantissa bits = precision - 1 (implicit leading bit not stored).
782 // Exponent field bits = totalBits - 1 (sign) - storedMantissa.
783 if (precision < 1 || precision > totalBits)
784 return rewriter.notifyMatchFailure(op, "unexpected float semantics");
785 unsigned mantissaBits = precision - 1;
786 unsigned expBits = totalBits - 1 - mantissaBits;
787 if (expBits == 0 || mantissaBits == 0)
788 return rewriter.notifyMatchFailure(
789 op, "degenerate float encoding has no exponent or mantissa");
790
791 Type intTy =
792 cloneToShapedType(operandTy, rewriter.getIntegerType(totalBits));
793 Value bits = arith::BitcastOp::create(b, intTy, operand);
794 APInt expMaskVal =
795 APInt::getBitsSet(totalBits, mantissaBits, mantissaBits + expBits);
796 APInt clearMantissaMaskVal = ~APInt::getLowBitsSet(totalBits, mantissaBits);
797 APInt zeroVal = APInt::getZero(totalBits);
798 Value expMask = createAPIntConst(loc, intTy, expMaskVal, rewriter);
799 Value clearMantissaMask =
800 createAPIntConst(loc, intTy, clearMantissaMaskVal, rewriter);
801 Value zero = createAPIntConst(loc, intTy, zeroVal, rewriter);
802
803 // expField == 0
804 Value expField = arith::AndIOp::create(b, bits, expMask);
805 Value expIsZero =
806 arith::CmpIOp::create(b, arith::CmpIPredicate::eq, expField, zero);
807
808 // Clear mantissa bits: when exp == 0, this produces pos/neg 0.0.
809 Value cleared = arith::AndIOp::create(b, bits, clearMantissaMask);
810 Value resultBits = arith::SelectOp::create(b, expIsZero, cleared, bits);
811 Value result = arith::BitcastOp::create(b, operandTy, resultBits);
812
813 rewriter.replaceOp(op, result);
814 return success();
815 }
816};
817
818struct ArithExpandOpsPass
819 : public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
820 using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
821
822 void runOnOperation() override {
823 RewritePatternSet patterns(&getContext());
824 ConversionTarget target(getContext());
825
826 arith::populateArithExpandOpsPatterns(patterns);
827
828 target.addLegalDialect<arith::ArithDialect>();
829 target.addLegalDialect<vector::VectorDialect>();
830
831 // clang-format off
832 target.addIllegalOp<
833 arith::CeilDivSIOp,
834 arith::CeilDivUIOp,
835 arith::FloorDivSIOp,
836 arith::MaxSIOp,
837 arith::MaxUIOp,
838 arith::MinSIOp,
839 arith::MinUIOp,
840 arith::MaximumFOp,
841 arith::MinimumFOp,
842 arith::MaxNumFOp,
843 arith::MinNumFOp,
844 arith::ScalingExtFOp,
845 arith::ScalingTruncFOp
846 >();
847
848 if (includeBf16)
849 arith::populateExpandBFloat16Patterns(patterns);
850 if (includeF8E8M0)
851 arith::populateExpandF8E8M0Patterns(patterns);
852 if (includeF4E2M1)
853 arith::populateExpandF4E2M1Patterns(patterns);
854 if (includeFlushDenormals) {
855 arith::populateExpandFlushDenormalsPatterns(patterns);
856 // Only IEEE-like floating-point types are expanded by the pattern;
857 // leave `arith.flush_denormals` on other types alone.
858 target.addDynamicallyLegalOp<arith::FlushDenormalsOp>(
859 [](arith::FlushDenormalsOp op) {
860 auto floatTy =
861 dyn_cast<FloatType>(getElementTypeOrSelf(op.getType()));
862 if (!floatTy)
863 return true;
864 return !llvm::APFloatBase::isIEEELikeFP(
865 floatTy.getFloatSemantics());
866 });
867 }
868
869 target.addDynamicallyLegalOp<arith::ExtFOp>(
870 [=](arith::ExtFOp op) {
871 Type inETy = getElementTypeOrSelf(op.getOperand().getType());
872 Type outETy = getElementTypeOrSelf(op.getType());
873 bool legalTypes = true;
874 if (includeBf16)
875 legalTypes &= !(inETy.isBF16() && outETy.isF32());
876 if (includeF8E8M0)
877 legalTypes &= !llvm::isa<Float8E8M0FNUType>(inETy);
878 if (includeF4E2M1)
879 legalTypes &= !llvm::isa<Float4E2M1FNType>(inETy);
880 return legalTypes;
881 });
882
883 target.addDynamicallyLegalOp<arith::TruncFOp>(
884 [=](arith::TruncFOp op) {
885 Type inETy = getElementTypeOrSelf(op.getOperand().getType());
886 Type outETy = getElementTypeOrSelf(op.getType());
887 bool legalTypes = true;
888 if (includeBf16)
889 legalTypes &= !(inETy.isF32() && outETy.isBF16());
890 if (includeF8E8M0)
891 legalTypes &= !(llvm::isa<Float8E8M0FNUType>(outETy));
892 if (includeF4E2M1)
893 legalTypes &= !llvm::isa<Float4E2M1FNType>(outETy);
894 return legalTypes;
895 });
896
897 // clang-format on
898 if (failed(applyPartialConversion(getOperation(), target,
899 std::move(patterns))))
900 signalPassFailure();
901 }
902};
903
904} // namespace
905
907 RewritePatternSet &patterns) {
908 patterns
909 .add<CeilDivSIOpConverter, CeilDivUIOpConverter, FloorDivSIOpConverter>(
910 patterns.getContext());
911}
912
914 patterns.add<BFloat16ExtFOpConverter, BFloat16TruncFOpConverter>(
915 patterns.getContext());
916}
917
919 patterns.add<F4E2M1ExtFOpConverter, F4E2M1TruncFOpConverter>(
920 patterns.getContext());
921}
922
924 patterns.add<F8E8M0ExtFOpConverter, F8E8M0TruncFOpConverter>(
925 patterns.getContext());
926}
927
929 RewritePatternSet &patterns) {
930 patterns.add<ScalingExtFOpConverter, ScalingTruncFOpConverter>(
931 patterns.getContext());
932}
933
935 RewritePatternSet &patterns) {
936 patterns.add<FlushDenormalsOpConverter>(patterns.getContext());
937}
938
942 // clang-format off
943 patterns.add<
944 MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
945 MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>,
946 MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>,
947 MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult>,
948 MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
949 MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>,
950 MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>,
951 MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
952 >(patterns.getContext());
953 // clang-format on
954}
return success()
static Value createConst(Location loc, Type type, int value, PatternRewriter &rewriter)
Create an integer or index constant.
Definition ExpandOps.cpp:27
static Value createAPIntConst(Location loc, Type type, const APInt &value, PatternRewriter &rewriter)
Create an integer constant from an APInt.
Definition ExpandOps.cpp:38
static Type cloneToShapedType(Type cloneFrom, Type cloneTo)
Creates shapedType using shape from cloneFrom and base type from cloneTo.
Definition ExpandOps.cpp:61
static Value createFloatConst(Location loc, Type type, const APFloat &value, PatternRewriter &rewriter)
Create a float constant.
Definition ExpandOps.cpp:49
static int64_t product(ArrayRef< int64_t > vals)
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:258
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF32() const
Definition Types.cpp:40
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
bool isBF16() const
Definition Types.cpp:37
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
void populateExpandBFloat16Patterns(RewritePatternSet &patterns)
Add patterns to expand Arith bf16 patterns to lower level bitcasts/shifts.
void populateExpandScalingExtTruncPatterns(RewritePatternSet &patterns)
Add patterns to expand scaling ExtF/TruncF ops to equivalent arith ops.
void populateExpandF8E8M0Patterns(RewritePatternSet &patterns)
Add patterns to expand Arith f8e8m0 patterns to lower level bitcasts/shifts.
void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ceil/floor division ops.
void populateExpandF4E2M1Patterns(RewritePatternSet &patterns)
Add patterns to expand Arith f4e2m1 patterns to lower level bitcasts/shifts.
void populateExpandFlushDenormalsPatterns(RewritePatternSet &patterns)
Add patterns to expand arith.flush_denormals into integer arithmetic (bitcast + bit masks + compare +...
void populateArithExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ops.
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
Definition Fraction.h:68
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...