MLIR 23.0.0git
ExpandOps.cpp
Go to the documentation of this file.
1//===- ExpandOps.cpp - Pass to legalize Arith ops for LLVM lowering --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
13#include "mlir/IR/Location.h"
16
17namespace mlir {
18namespace arith {
19#define GEN_PASS_DEF_ARITHEXPANDOPSPASS
20#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
21} // namespace arith
22} // namespace mlir
23
24using namespace mlir;
25
26/// Create an integer or index constant.
27static Value createConst(Location loc, Type type, int value,
28 PatternRewriter &rewriter) {
29 auto attr = rewriter.getIntegerAttr(getElementTypeOrSelf(type), value);
30 if (auto shapedTy = dyn_cast<ShapedType>(type)) {
31 return arith::ConstantOp::create(rewriter, loc,
32 DenseElementsAttr::get(shapedTy, attr));
33 }
34 return arith::ConstantOp::create(rewriter, loc, attr);
35}
36
37/// Create a float constant.
38static Value createFloatConst(Location loc, Type type, const APFloat &value,
39 PatternRewriter &rewriter) {
40 auto attr = rewriter.getFloatAttr(getElementTypeOrSelf(type), value);
41 if (auto shapedTy = dyn_cast<ShapedType>(type)) {
42 return arith::ConstantOp::create(rewriter, loc,
43 DenseElementsAttr::get(shapedTy, attr));
44 }
45
46 return arith::ConstantOp::create(rewriter, loc, attr);
47}
48
49/// Creates shapedType using shape from cloneFrom and base type from cloneTo
50static Type cloneToShapedType(Type cloneFrom, Type cloneTo) {
51 if (auto shapedTy = dyn_cast<ShapedType>(cloneFrom)) {
52 return shapedTy.clone(cloneTo);
53 }
54 return cloneTo;
55}
56
57namespace {
58
59/// Expands CeilDivUIOp (n, m) into
60/// n == 0 ? 0 : ((n-1) / m) + 1
61struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> {
62 using Base::Base;
63 LogicalResult matchAndRewrite(arith::CeilDivUIOp op,
64 PatternRewriter &rewriter) const final {
65 Location loc = op.getLoc();
66 Value a = op.getLhs();
67 Value b = op.getRhs();
68 Value zero = createConst(loc, a.getType(), 0, rewriter);
69 Value compare =
70 arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, a, zero);
71 Value one = createConst(loc, a.getType(), 1, rewriter);
72 Value minusOne = arith::SubIOp::create(rewriter, loc, a, one);
73 Value quotient = arith::DivUIOp::create(rewriter, loc, minusOne, b);
74 Value plusOne = arith::AddIOp::create(rewriter, loc, quotient, one);
75 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compare, zero, plusOne);
76 return success();
77 }
78};
79
80/// Expands CeilDivSIOp (a, b) into
81/// z = a / b
82/// if (z * b != a && (a < 0) == (b < 0)) {
83/// return z + 1;
84/// } else {
85/// return z;
86/// }
87struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
88 using Base::Base;
89 LogicalResult matchAndRewrite(arith::CeilDivSIOp op,
90 PatternRewriter &rewriter) const final {
91 Location loc = op.getLoc();
92 Type type = op.getType();
93 Value a = op.getLhs();
94 Value b = op.getRhs();
95
96 Value zero = createConst(loc, type, 0, rewriter);
97 Value one = createConst(loc, type, 1, rewriter);
98
99 Value quotient = arith::DivSIOp::create(rewriter, loc, a, b);
100 Value product = arith::MulIOp::create(rewriter, loc, quotient, b);
101 Value notEqualDivisor = arith::CmpIOp::create(
102 rewriter, loc, arith::CmpIPredicate::ne, a, product);
103
104 Value aNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
105 a, zero);
106 Value bNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
107 b, zero);
108
109 Value signEqual = arith::CmpIOp::create(
110 rewriter, loc, arith::CmpIPredicate::eq, aNeg, bNeg);
111 Value cond =
112 arith::AndIOp::create(rewriter, loc, notEqualDivisor, signEqual);
113
114 Value quotientPlusOne = arith::AddIOp::create(rewriter, loc, quotient, one);
115
116 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientPlusOne,
117 quotient);
118 return success();
119 }
120};
121
122/// Expands FloorDivSIOp (x, y) into
123/// z = x / y
124/// if (z * y != x && (x < 0) != (y < 0)) {
125/// return z - 1;
126/// } else {
127/// return z;
128/// }
129struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
130 using Base::Base;
131 LogicalResult matchAndRewrite(arith::FloorDivSIOp op,
132 PatternRewriter &rewriter) const final {
133 Location loc = op.getLoc();
134 Type type = op.getType();
135 Value a = op.getLhs();
136 Value b = op.getRhs();
137
138 Value quotient = arith::DivSIOp::create(rewriter, loc, a, b);
139 Value product = arith::MulIOp::create(rewriter, loc, quotient, b);
140 Value notEqualDivisor = arith::CmpIOp::create(
141 rewriter, loc, arith::CmpIPredicate::ne, a, product);
142 Value zero = createConst(loc, type, 0, rewriter);
143
144 Value aNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
145 a, zero);
146 Value bNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
147 b, zero);
148
149 Value signOpposite = arith::CmpIOp::create(
150 rewriter, loc, arith::CmpIPredicate::ne, aNeg, bNeg);
151 Value cond =
152 arith::AndIOp::create(rewriter, loc, notEqualDivisor, signOpposite);
153
154 Value minusOne = createConst(loc, type, -1, rewriter);
155 Value quotientMinusOne =
156 arith::AddIOp::create(rewriter, loc, quotient, minusOne);
157
158 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientMinusOne,
159 quotient);
160 return success();
161 }
162};
163
164template <typename OpTy, arith::CmpIPredicate pred>
165struct MaxMinIOpConverter : public OpRewritePattern<OpTy> {
166public:
167 using OpRewritePattern<OpTy>::OpRewritePattern;
168
169 LogicalResult matchAndRewrite(OpTy op,
170 PatternRewriter &rewriter) const final {
171 Value lhs = op.getLhs();
172 Value rhs = op.getRhs();
173
174 Value cmp = arith::CmpIOp::create(rewriter, op.getLoc(), pred, lhs, rhs);
175 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, lhs, rhs);
176 return success();
177 }
178};
179
180template <typename OpTy, arith::CmpFPredicate pred>
181struct MaximumMinimumFOpConverter : public OpRewritePattern<OpTy> {
182public:
183 using OpRewritePattern<OpTy>::OpRewritePattern;
184
185 LogicalResult matchAndRewrite(OpTy op,
186 PatternRewriter &rewriter) const final {
187 Value lhs = op.getLhs();
188 Value rhs = op.getRhs();
189
190 Location loc = op.getLoc();
191 // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs').
192 static_assert(pred == arith::CmpFPredicate::UGT ||
193 pred == arith::CmpFPredicate::ULT,
194 "pred must be either UGT or ULT");
195 Value cmp = arith::CmpFOp::create(rewriter, loc, pred, lhs, rhs);
196 Value select = arith::SelectOp::create(rewriter, loc, cmp, lhs, rhs);
197
198 // Handle the case where rhs is NaN: 'isNaN(rhs) ? rhs : select'.
199 Value isNaN = arith::CmpFOp::create(rewriter, loc,
200 arith::CmpFPredicate::UNO, rhs, rhs);
201 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
202 return success();
203 }
204};
205
206template <typename OpTy, arith::CmpFPredicate pred>
207struct MaxNumMinNumFOpConverter : public OpRewritePattern<OpTy> {
208public:
209 using OpRewritePattern<OpTy>::OpRewritePattern;
210
211 LogicalResult matchAndRewrite(OpTy op,
212 PatternRewriter &rewriter) const final {
213 Value lhs = op.getLhs();
214 Value rhs = op.getRhs();
215
216 Location loc = op.getLoc();
217 // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs').
218 static_assert(pred == arith::CmpFPredicate::UGT ||
219 pred == arith::CmpFPredicate::ULT,
220 "pred must be either UGT or ULT");
221 Value cmp = arith::CmpFOp::create(rewriter, loc, pred, lhs, rhs);
222 Value select = arith::SelectOp::create(rewriter, loc, cmp, lhs, rhs);
223
224 // Handle the case where lhs is NaN: 'isNaN(lhs) ? rhs : select'.
225 Value isNaN = arith::CmpFOp::create(rewriter, loc,
226 arith::CmpFPredicate::UNO, lhs, lhs);
227 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
228 return success();
229 }
230};
231
232struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
234 LogicalResult matchAndRewrite(arith::ExtFOp op,
235 PatternRewriter &rewriter) const final {
236 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
237 auto operand = op.getOperand();
238 Type operandTy = operand.getType();
239 Type resultTy = op.getType();
240 Type operandETy = getElementTypeOrSelf(operandTy);
241 Type resultETy = getElementTypeOrSelf(resultTy);
243 if (!operandETy.isBF16() || !resultETy.isF32()) {
244 return rewriter.notifyMatchFailure(op, "not a ext of bf16 to f32.");
246
247 Type i16Ty = cloneToShapedType(operandTy, b.getI16Type());
248 Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
249
250 Value bitcast = arith::BitcastOp::create(b, i16Ty, operand);
251 Value exti = arith::ExtUIOp::create(b, i32Ty, bitcast);
252
253 Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
254 Value shl = arith::ShLIOp::create(b, exti, c16);
255 Value result = arith::BitcastOp::create(b, resultTy, shl);
257 rewriter.replaceOp(op, result);
258 return success();
260};
261
262struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
263 using Base::Base;
264 LogicalResult matchAndRewrite(arith::TruncFOp op,
265 PatternRewriter &rewriter) const final {
266 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
267 auto operand = op.getOperand();
268 Type operandTy = operand.getType();
269 Type resultTy = op.getType();
270 Type operandETy = getElementTypeOrSelf(operandTy);
271 Type resultETy = getElementTypeOrSelf(resultTy);
272
273 if (!operandETy.isF32() || !resultETy.isBF16()) {
274 return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");
275 }
276
277 if (op.getRoundingmodeAttr()) {
278 return rewriter.notifyMatchFailure(
279 op, "only applicable to default rounding mode.");
280 }
281
282 Type i16Ty = cloneToShapedType(operandTy, b.getI16Type());
283 Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
285 // Algorithm borrowed from this excellent code:
286 // https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/c10/util/BFloat16.h#L60-L79
287 // There is a magic idea there, to let the addition of the rounding_bias to
288 // the mantissa simply overflow into the exponent bits. It's a bit of an
289 // aggressive, obfuscating optimization, but it is well-tested code, and it
290 // results in more concise and efficient IR.
291 // The case of NaN is handled separately (see isNaN and the final select).
292 // The case of infinities is NOT handled separately, which deserves an
293 // explanation. As the encoding of infinities has zero mantissa, the
294 // rounding-bias addition never carries into the exponent so that just gets
295 // truncated away, and as bfloat16 and float32 have the same number of
296 // exponent bits, that simple truncation is the desired outcome for
297 // infinities.
298 Value isNan =
299 arith::CmpFOp::create(b, arith::CmpFPredicate::UNE, operand, operand);
300 // Constant used to make the rounding bias.
301 Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter);
302 // Constant used to generate a quiet NaN.
303 Value c7FC0I16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
304 // Small constants used to address bits.
305 Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
306 Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter);
307 // Reinterpret the input f32 value as bits.
308 Value bitcast = arith::BitcastOp::create(b, i32Ty, operand);
309 // Read bit 16 as a value in {0,1}.
310 Value bit16 =
311 arith::AndIOp::create(b, arith::ShRUIOp::create(b, bitcast, c16), c1);
312 // Determine the rounding bias to add as either 0x7fff or 0x8000 depending
313 // on bit 16, implementing the tie-breaking "to nearest even".
314 Value roundingBias = arith::AddIOp::create(b, bit16, c7FFF);
315 // Add the rounding bias. Generally we want this to be added to the
316 // mantissa, but nothing prevents this to from carrying into the exponent
317 // bits, which would feel like a bug, but this is the magic trick here:
318 // when that happens, the mantissa gets reset to zero and the exponent
319 // gets incremented by the carry... which is actually exactly what we
320 // want.
321 Value biased = arith::AddIOp::create(b, bitcast, roundingBias);
322 // Now that the rounding-bias has been added, truncating the low bits
323 // yields the correctly rounded result.
324 Value biasedAndShifted = arith::ShRUIOp::create(b, biased, c16);
325 Value normalCaseResultI16 =
326 arith::TruncIOp::create(b, i16Ty, biasedAndShifted);
327 // Select either the above-computed result, or a quiet NaN constant
328 // if the input was NaN.
329 Value select =
330 arith::SelectOp::create(b, isNan, c7FC0I16, normalCaseResultI16);
331 Value result = arith::BitcastOp::create(b, resultTy, select);
332 rewriter.replaceOp(op, result);
333 return success();
334 }
335};
336
337/// In this implementation of extf we take advantage of some key patterns we
338/// notice between the binary representation of an F4E2M1 value and its
339/// corresponding value in F32.
340///
341/// Note: x is sign bit
342/// | Binary | F4E2M1 | f32[23:32]
343/// | x000 | 0.0 | x000 0000 00
344/// | x001 | 0.5 | x011 1111 00
345/// | x010 | 1.0 | x011 1111 10
346/// | x011 | 1.5 | x011 1111 11
347/// | x100 | 2.0 | x010 0000 00
348/// | x101 | 3.0 | x010 0000 01
349/// | x110 | 4.0 | x010 0000 10
350/// | x111 | 6.0 | x010 0000 11
351///
352/// 1) There are only two versions of bits [25:31] in the f32 result
353/// F4E2M1 bits[2:3] decide whether:
354/// - F32 bits[25:31] = 0011 1111
355/// - F32 bits[25:31] = 0010 0000
356/// Exception is zero where
357/// - F32 bits[25:31] = 0000 0000
358///
359/// 2) F4E2M1 bits[1:2] = F32 bits[23:24]
360/// Exception is 0.5 where
361/// - F4E2M1 bits[1:2] = 01, F32 bits[23:24] = 00
362///
363/// 3) F4E2M1 bits[4] = F32 bits[32] (sign bits are equal)
364///
365/// 4) F32 bits[1:22] = 0
366struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
367 using Base::Base;
368 LogicalResult matchAndRewrite(arith::ExtFOp op,
369 PatternRewriter &rewriter) const final {
370 Location loc = op.getLoc();
371 ImplicitLocOpBuilder b(loc, rewriter);
372 Value operand = op.getOperand();
373 Type operandTy = operand.getType();
374 Type resultTy = op.getType();
375 Type operandETy = getElementTypeOrSelf(operandTy);
376 Type resultETy = getElementTypeOrSelf(resultTy);
377
378 if (!isa<Float4E2M1FNType>(operandETy))
379 return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN");
380
381 Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
382 Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
383 Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
384 Value i4Bits = arith::BitcastOp::create(b, i4Ty, operand);
385
386 Value c0x0 = createConst(loc, i4Ty, 0x0, rewriter);
387 Value c0x1 = createConst(loc, i4Ty, 0x1, rewriter);
388 Value c0x2 = createConst(loc, i4Ty, 0x2, rewriter);
389 Value c0x4 = createConst(loc, i4Ty, 0x4, rewriter);
390 Value c0x7 = createConst(loc, i4Ty, 0x7, rewriter);
391
392 Value i4BitsNoSign = arith::AndIOp::create(b, i4Bits, c0x7);
393
394 // Set last Exponent bit and Mantissa.
395 Value c0x00000014 = createConst(loc, i32Ty, 0x14, rewriter);
396 Value bits1To24 = arith::ShLIOp::create(b, i4BitsNoSign, c0x2);
397 Value isHalf =
398 arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4BitsNoSign, c0x1);
399 bits1To24 = arith::SelectOp::create(b, isHalf, c0x0, bits1To24);
400 bits1To24 = arith::ExtUIOp::create(b, i32Ty, bits1To24);
401 bits1To24 = arith::ShLIOp::create(b, bits1To24, c0x00000014);
402
403 // Set first 7 bits of Exponent.
404 Value zeroExpBits = createConst(loc, i32Ty, 0x00000000, rewriter);
405 Value highExpBits = createConst(loc, i32Ty, 0x40000000, rewriter);
406 Value lowExpBits = createConst(loc, i32Ty, 0x3f000000, rewriter);
407 Value useLargerExp =
408 arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4BitsNoSign, c0x4);
409 Value bits25To31 =
410 arith::SelectOp::create(b, useLargerExp, highExpBits, lowExpBits);
411 Value zeroExp =
412 arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4BitsNoSign, c0x0);
413 bits25To31 = arith::SelectOp::create(b, zeroExp, zeroExpBits, bits25To31);
414
415 // Set sign.
416 Value c0x80000000 = createConst(loc, i32Ty, 0x80000000, rewriter);
417 Value c0x8 = createConst(loc, i4Ty, 0x8, rewriter);
418 Value negative =
419 arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4Bits, c0x8);
420 Value bit32 =
421 arith::SelectOp::create(b, negative, c0x80000000, zeroExpBits);
422
423 // Add segments together.
424 Value bits1To31 = arith::AddIOp::create(b, bits1To24, bits25To31);
425 Value bits1To32 = arith::AddIOp::create(b, bits1To31, bit32);
426 Value result = arith::BitcastOp::create(b, f32Ty, bits1To32);
427 if (!isa<Float32Type>(resultETy))
428 result = arith::TruncFOp::create(b, resultTy, result);
429
430 rewriter.replaceOp(op, result);
431 return success();
432 }
433};
434
435struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
436 using Base::Base;
437 LogicalResult matchAndRewrite(arith::ExtFOp op,
438 PatternRewriter &rewriter) const final {
439 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
440 Value operand = op.getOperand();
441 Type operandTy = operand.getType();
442 Type resultTy = op.getType();
443 Type operandETy = getElementTypeOrSelf(operandTy);
444 Type resultETy = getElementTypeOrSelf(resultTy);
445
446 if (!llvm::isa<Float8E8M0FNUType>(operandETy)) {
447 return rewriter.notifyMatchFailure(op, "not a ext of F8E8M0FNU");
448 }
449
450 Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
451 Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
452 Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
453
454 Value bitcast = arith::BitcastOp::create(b, i8Ty, operand);
455 Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
456 Value exti = arith::ExtUIOp::create(b, i32Ty, bitcast);
457 Value f32Bits = arith::ShLIOp::create(b, exti, cF32MantissaWidth);
458
459 // If FastMathFlag allows no NaN checks, skip it
460 auto fastMath = op.getFastmathAttr();
461 bool NoNaN = fastMath
462 ? (fastMath.getValue() & arith::FastMathFlags::nnan) ==
463 arith::FastMathFlags::nnan
464 : false;
465 if (!NoNaN) {
466 Value cF8NaN = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
467 Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
468 Value isNan =
469 arith::CmpIOp::create(b, arith::CmpIPredicate::eq, bitcast, cF8NaN);
470 // select for NaNs
471 f32Bits = arith::SelectOp::create(b, isNan, cF32NaN, f32Bits);
472 }
473
474 Value result = arith::BitcastOp::create(b, f32Ty, f32Bits);
475 if (resultETy.getIntOrFloatBitWidth() < 32) {
476 result = arith::TruncFOp::create(b, resultTy, result, nullptr,
477 op.getFastmathAttr());
478 } else if (resultETy.getIntOrFloatBitWidth() > 32) {
479 result = arith::ExtFOp::create(b, resultTy, result, op.getFastmathAttr());
480 }
481 rewriter.replaceOp(op, result);
482 return success();
483 }
484};
485
486/// Conversion from F32 to F4E2M1 according to the OCP Spec:
487/// www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
488///
489/// The spec requires us to perform Round to Nearest, Ties to Even.
490///
491/// This means that after rounding, we should break ties by choosing the option
492/// which results in a mantissa of 0 in the least significant digit.
493///
494/// Table of representable values in F4E2M1:
495///
496/// Note: x is sign bit
497/// | Binary | F4E2M1 | F32[23:32]
498/// | x000 | 0.0 | x000 0000 00
499/// | x001 | 0.5 | x011 1111 00
500/// | x010 | 1.0 | x011 1111 10
501/// | x011 | 1.5 | x011 1111 11
502/// | x100 | 2.0 | x010 0000 00
503/// | x101 | 3.0 | x010 0000 01
504/// | x110 | 4.0 | x010 0000 10
505/// | x111 | 6.0 | x010 0000 11
506///
507/// Conversion procedure:
508/// Step 1: Clamp to representable bounds.
509/// Step 2: Convert exponent by adjusting bias.
510/// Step 3: Set mantissa to first bit.
511/// Step 4: Special consideration for subnormal and zero exponent.
512/// Step 5: Round up if necessary, if mantissa[1:] greater than 1000000 or
513/// subnormal.
514struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
515 using Base::Base;
516 LogicalResult matchAndRewrite(arith::TruncFOp op,
517 PatternRewriter &rewriter) const final {
518 Location loc = op.getLoc();
519 ImplicitLocOpBuilder b(loc, rewriter);
520 Value operand = op.getOperand();
521 Type operandTy = operand.getType();
522 Type resultTy = op.getType();
523 Type operandETy = getElementTypeOrSelf(operandTy);
524 Type resultETy = getElementTypeOrSelf(resultTy);
525
526 Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
527 Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
528 Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
529 Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
530
531 if (!isa<Float4E2M1FNType>(resultETy))
532 return rewriter.notifyMatchFailure(op, "not a trunc of F4E2M1FN");
533 if (!isa<Float32Type>(operandETy))
534 operand = arith::ExtFOp::create(b, f32Ty, operand);
535
536 Value c0x1 = createConst(loc, i4Ty, 1, rewriter);
537 Value c0x3 = createConst(loc, i4Ty, 3, rewriter);
538 Value c0x00000016 = createConst(loc, i32Ty, 22, rewriter);
539 Value c0x00 = createConst(loc, i8Ty, 0x00, rewriter);
540 Value c0xff = createConst(loc, i8Ty, 0xff, rewriter);
541 Value zeroExpBits = createConst(loc, i32Ty, 0, rewriter);
542
543 // Step 0: Clamp to bounds.
544 Value cHigherBound = createFloatConst(loc, f32Ty, APFloat(6.0f), rewriter);
545 Value cLowerBound = createFloatConst(loc, f32Ty, APFloat(-6.0f), rewriter);
546 Value operandClamped = arith::MinNumFOp::create(b, cHigherBound, operand);
547 operandClamped = arith::MaxNumFOp::create(b, cLowerBound, operandClamped);
548 Value f32Bits = arith::BitcastOp::create(b, i32Ty, operandClamped);
549
550 // Step 1: Set sign bit.
551 Value cF32ExpManWidth = createConst(loc, i32Ty, 31, rewriter); // 23
552 Value f32Sign = arith::ShRUIOp::create(b, f32Bits, cF32ExpManWidth);
553 Value f4Sign = arith::TruncIOp::create(b, i4Ty, f32Sign);
554 Value f4Bits = arith::ShLIOp::create(b, f4Sign, c0x3);
555
556 // Step 2: Convert exponent by adjusting bias.
557 Value biasAdjustment = createConst(loc, i32Ty, 0x7e, rewriter);
558 Value cF4MantissaWidth = c0x1; // 1
559 Value cF32MantissaWidth = createConst(loc, i32Ty, 23, rewriter); // 23
560 Value f32SignExp = arith::ShRUIOp::create(b, f32Bits, cF32MantissaWidth);
561 Value biasAdjustedSignExp =
562 arith::SubIOp::create(b, f32SignExp, biasAdjustment);
563 Value f4Exp = arith::TruncIOp::create(b, i4Ty, biasAdjustedSignExp);
564 f4Exp = arith::ShLIOp::create(b, f4Exp, cF4MantissaWidth);
565 f4Bits = arith::AddIOp::create(b, f4Bits, f4Exp);
566
567 // Step 3: Set mantissa to first bit.
568 Value cF32FirstBitMask = createConst(loc, i32Ty, 0x400000, rewriter);
569 Value man1Bit = arith::AndIOp::create(b, f32Bits, cF32FirstBitMask);
570 man1Bit = arith::ShRUIOp::create(b, man1Bit, c0x00000016);
571 Value f4Man = arith::TruncIOp::create(b, i4Ty, man1Bit);
572 f4Bits = arith::AddIOp::create(b, f4Bits, f4Man);
573
574 // Step 4: Special consideration for conversion to 0.5.
575 Value cF32MantissaMask = createConst(loc, i32Ty, 0x7fffff, rewriter);
576 Value f8Exp = arith::TruncIOp::create(b, i8Ty, biasAdjustedSignExp);
577 Value isSubnormal =
578 arith::CmpIOp::create(b, arith::CmpIPredicate::sle, f8Exp, c0x00);
579 Value isNegOneExp =
580 arith::CmpIOp::create(b, arith::CmpIPredicate::eq, f8Exp, c0xff);
581 Value man23Bits = arith::AndIOp::create(b, f32Bits, cF32MantissaMask);
582 Value isNonZeroMan = arith::CmpIOp::create(b, arith::CmpIPredicate::ugt,
583 man23Bits, zeroExpBits);
584 Value roundToHalf = arith::AndIOp::create(b, isNegOneExp, isNonZeroMan);
585 Value isZeroExp =
586 arith::CmpIOp::create(b, arith::CmpIPredicate::eq, f8Exp, c0x00);
587 Value subnormalF4Bits = createConst(loc, i4Ty, 0xf, rewriter);
588 Value halfF4Bits = createConst(loc, i4Ty, 0x0, rewriter);
589 Value subResult =
590 arith::SelectOp::create(b, isSubnormal, subnormalF4Bits, f4Bits);
591 subResult = arith::SelectOp::create(b, roundToHalf, halfF4Bits, subResult);
592 f4Bits = arith::SelectOp::create(b, isZeroExp, f4Bits, subResult);
593
594 // Step 5: Round up if necessary.
595 Value cF32Last22BitMask = createConst(loc, i32Ty, 0x3fffff, rewriter);
596 Value cRound = createConst(loc, i32Ty, 0x200000, rewriter); // 010 0000...
597 Value man22Bits = arith::AndIOp::create(b, f32Bits, cF32Last22BitMask);
598 Value shouldRound =
599 arith::CmpIOp::create(b, arith::CmpIPredicate::uge, man22Bits, cRound);
600 shouldRound = arith::OrIOp::create(b, shouldRound, isSubnormal);
601 Value roundedF4Bits = arith::AddIOp::create(b, f4Bits, c0x1);
602 f4Bits = arith::SelectOp::create(b, shouldRound, roundedF4Bits, f4Bits);
603
604 Value result = arith::BitcastOp::create(b, resultTy, f4Bits);
605 rewriter.replaceOp(op, result);
606 return success();
607 }
608};
609
610/*
611TruncF to F8E8M0 is expected to extract exponent bits out of F32 type
612Since All kinds of Infs and NaNs are mapped to same exponent bits in F32 type,
613they all map to NaN in F8E8M0 Type.
614*/
615struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
616 using Base::Base;
617 LogicalResult matchAndRewrite(arith::TruncFOp op,
618 PatternRewriter &rewriter) const final {
619 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
620 Value operand = op.getOperand();
621 Type operandTy = operand.getType();
622 Type operandETy = getElementTypeOrSelf(operandTy);
623 Type resultTy = op.getType();
624 Type resultETy = getElementTypeOrSelf(resultTy);
625 if (!llvm::isa<Float8E8M0FNUType>(resultETy)) {
626 return rewriter.notifyMatchFailure(op, "not a truncf to f8E8M0FNU");
627 }
628
629 if (op.getRoundingmodeAttr()) {
630 return rewriter.notifyMatchFailure(
631 op, "only applicable to default rounding mode.");
632 }
633
634 Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
635 Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
636 Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
637
638 if (operandETy.getIntOrFloatBitWidth() < 32) {
639 operand = arith::ExtFOp::create(b, f32Ty, operand, op.getFastmathAttr());
640 } else if (operandETy.getIntOrFloatBitWidth() > 32) {
641 operand = arith::TruncFOp::create(
642 b, f32Ty, operand, op.getRoundingmodeAttr(), op.getFastmathAttr());
643 }
644 Value f32Bits = arith::BitcastOp::create(b, i32Ty, operand);
645 Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
646 Value f32SignExp = arith::ShRUIOp::create(b, f32Bits, cF32MantissaWidth);
647 Value exp8Bits = arith::TruncIOp::create(b, i8Ty, f32SignExp);
648 Value result = arith::BitcastOp::create(b, resultTy, exp8Bits);
649 rewriter.replaceOp(op, result);
650 return success();
651 }
652};
653
654struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> {
655 using Base::Base;
656 LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
657 PatternRewriter &rewriter) const final {
658 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
659 Value inputOperand = op.getIn();
660 Value scaleOperand = op.getScale();
661 Type scaleTy = scaleOperand.getType();
662 Type scaleETy = getElementTypeOrSelf(scaleOperand);
663 // allow implicit exponent extraction from 16/32 bits floats
664 if (scaleETy.getIntOrFloatBitWidth() >= 16) {
665 scaleETy = b.getF8E8M0Type();
666 scaleTy = cloneToShapedType(scaleTy, scaleETy);
667 scaleOperand = arith::TruncFOp::create(b, scaleTy, scaleOperand, nullptr,
668 op.getFastmathAttr());
669 }
670 // Catch scale types like f8E5M2.
671 if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
672 return rewriter.notifyMatchFailure(
673 op, "scaling_extf is using scales of type which can not be converted "
674 "to f8E8M0FNU");
675 }
676 Type resultTy = op.getType();
677 // extf on scale will essentially create floating point number
678 // of type resulTy that is 2^scale and will also propagate NaNs
679 Value scaleExt =
680 arith::ExtFOp::create(b, resultTy, scaleOperand, op.getFastmathAttr());
681 Value inputExt =
682 arith::ExtFOp::create(b, resultTy, inputOperand, op.getFastmathAttr());
683 Value result =
684 arith::MulFOp::create(b, inputExt, scaleExt, op.getFastmathAttr());
685 rewriter.replaceOp(op, result);
686 return success();
687 }
688};
689
690/*
691Expands arith.ScalingTruncFOp(in, scale) into
692 scale = arith.truncf(scale) : scaleTy -> f8E8M0FNU
693 result = arith.truncf(in / (2^scale))
694 */
695struct ScalingTruncFOpConverter
696 : public OpRewritePattern<arith::ScalingTruncFOp> {
697 using Base::Base;
698 LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
699 PatternRewriter &rewriter) const final {
700 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
701 Value inputOperand = op.getIn();
702 Value scaleOperand = op.getScale();
703 Type scaleTy = scaleOperand.getType();
704 Type scaleETy = getElementTypeOrSelf(scaleOperand);
705 // allow implicit exponent extraction from 16/32 bits floats
706 if (scaleETy.getIntOrFloatBitWidth() >= 16) {
707 scaleETy = b.getF8E8M0Type();
708 scaleTy = cloneToShapedType(scaleTy, scaleETy);
709 scaleOperand = arith::TruncFOp::create(b, scaleTy, scaleOperand, nullptr,
710 op.getFastmathAttr());
711 }
712 if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
713 return rewriter.notifyMatchFailure(
714 op, "scaling_truncf is using scales type which can not be converted "
715 "to f8E8M0FNU");
716 }
717 Type resultTy = op.getType();
718 Type inputTy = inputOperand.getType();
719 // this will create a floating point number of type
720 // inputTy that is 2^scale and will also propagate NaNs
721 scaleOperand =
722 arith::ExtFOp::create(b, inputTy, scaleOperand, op.getFastmathAttr());
723 Value result = arith::DivFOp::create(b, inputOperand, scaleOperand,
724 op.getFastmathAttr());
725 Value resultCast = arith::TruncFOp::create(
726 b, resultTy, result, op.getRoundingmodeAttr(), op.getFastmathAttr());
727 rewriter.replaceOp(op, resultCast);
728 return success();
729 }
730};
731
732struct ArithExpandOpsPass
733 : public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
734 using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
735
736 void runOnOperation() override {
737 RewritePatternSet patterns(&getContext());
738 ConversionTarget target(getContext());
739
740 arith::populateArithExpandOpsPatterns(patterns);
741
742 target.addLegalDialect<arith::ArithDialect>();
743 target.addLegalDialect<vector::VectorDialect>();
744
745 // clang-format off
746 target.addIllegalOp<
747 arith::CeilDivSIOp,
748 arith::CeilDivUIOp,
749 arith::FloorDivSIOp,
750 arith::MaxSIOp,
751 arith::MaxUIOp,
752 arith::MinSIOp,
753 arith::MinUIOp,
754 arith::MaximumFOp,
755 arith::MinimumFOp,
756 arith::MaxNumFOp,
757 arith::MinNumFOp,
758 arith::ScalingExtFOp,
759 arith::ScalingTruncFOp
760 >();
761
762 if (includeBf16)
763 arith::populateExpandBFloat16Patterns(patterns);
764 if (includeF8E8M0)
765 arith::populateExpandF8E8M0Patterns(patterns);
766 if (includeF4E2M1)
767 arith::populateExpandF4E2M1Patterns(patterns);
768
769 target.addDynamicallyLegalOp<arith::ExtFOp>(
770 [=](arith::ExtFOp op) {
771 Type inETy = getElementTypeOrSelf(op.getOperand().getType());
772 Type outETy = getElementTypeOrSelf(op.getType());
773 bool legalTypes = true;
774 if (includeBf16)
775 legalTypes &= !(inETy.isBF16() && outETy.isF32());
776 if (includeF8E8M0)
777 legalTypes &= !llvm::isa<Float8E8M0FNUType>(inETy);
778 if (includeF4E2M1)
779 legalTypes &= !llvm::isa<Float4E2M1FNType>(inETy);
780 return legalTypes;
781 });
782
783 target.addDynamicallyLegalOp<arith::TruncFOp>(
784 [=](arith::TruncFOp op) {
785 Type inETy = getElementTypeOrSelf(op.getOperand().getType());
786 Type outETy = getElementTypeOrSelf(op.getType());
787 bool legalTypes = true;
788 if (includeBf16)
789 legalTypes &= !(inETy.isF32() && outETy.isBF16());
790 if (includeF8E8M0)
791 legalTypes &= !(llvm::isa<Float8E8M0FNUType>(outETy));
792 if (includeF4E2M1)
793 legalTypes &= !llvm::isa<Float4E2M1FNType>(outETy);
794 return legalTypes;
795 });
796
797 // clang-format on
798 if (failed(applyPartialConversion(getOperation(), target,
799 std::move(patterns))))
800 signalPassFailure();
801 }
802};
803
804} // namespace
805
809 .add<CeilDivSIOpConverter, CeilDivUIOpConverter, FloorDivSIOpConverter>(
810 patterns.getContext());
811}
812
814 patterns.add<BFloat16ExtFOpConverter, BFloat16TruncFOpConverter>(
815 patterns.getContext());
816}
817
819 patterns.add<F4E2M1ExtFOpConverter, F4E2M1TruncFOpConverter>(
820 patterns.getContext());
821}
822
824 patterns.add<F8E8M0ExtFOpConverter, F8E8M0TruncFOpConverter>(
825 patterns.getContext());
826}
827
830 patterns.add<ScalingExtFOpConverter, ScalingTruncFOpConverter>(
831 patterns.getContext());
832}
833
837 // clang-format off
838 patterns.add<
839 MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
840 MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>,
841 MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>,
842 MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult>,
843 MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
844 MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>,
845 MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>,
846 MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
847 >(patterns.getContext());
848 // clang-format on
849}
return success()
static Value createConst(Location loc, Type type, int value, PatternRewriter &rewriter)
Create an integer or index constant.
Definition ExpandOps.cpp:27
static Type cloneToShapedType(Type cloneFrom, Type cloneTo)
Creates a shaped type that has the shape of cloneFrom and the element type cloneTo; returns cloneTo unchanged when cloneFrom is not a shaped type.
Definition ExpandOps.cpp:50
static Value createFloatConst(Location loc, Type type, const APFloat &value, PatternRewriter &rewriter)
Create a float constant.
Definition ExpandOps.cpp:38
static int64_t product(ArrayRef< int64_t > vals)
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:258
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:632
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF32() const
Definition Types.cpp:40
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
bool isBF16() const
Definition Types.cpp:37
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
void populateExpandBFloat16Patterns(RewritePatternSet &patterns)
Add patterns to expand Arith bf16 patterns to lower level bitcasts/shifts.
void populateExpandScalingExtTruncPatterns(RewritePatternSet &patterns)
Add patterns to expand scaling ExtF/TruncF ops to equivalent arith ops.
void populateExpandF8E8M0Patterns(RewritePatternSet &patterns)
Add patterns to expand Arith f8e8m0 patterns to lower level bitcasts/shifts.
void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ceil/floor division ops.
void populateExpandF4E2M1Patterns(RewritePatternSet &patterns)
Add patterns to expand Arith f4e2m1 patterns to lower level bitcasts/shifts.
void populateArithExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ops.
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
Definition Fraction.h:68
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...