MLIR 22.0.0git
ExpandOps.cpp
Go to the documentation of this file.
1//===- ExpandOps.cpp - Pass to legalize Arith ops for LLVM lowering --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
13#include "mlir/IR/Location.h"
16
17namespace mlir {
18namespace arith {
19#define GEN_PASS_DEF_ARITHEXPANDOPSPASS
20#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
21} // namespace arith
22} // namespace mlir
23
24using namespace mlir;
25
26/// Create an integer or index constant.
27static Value createConst(Location loc, Type type, int value,
28 PatternRewriter &rewriter) {
29 auto attr = rewriter.getIntegerAttr(getElementTypeOrSelf(type), value);
30 if (auto shapedTy = dyn_cast<ShapedType>(type)) {
31 return arith::ConstantOp::create(rewriter, loc,
32 DenseElementsAttr::get(shapedTy, attr));
33 }
34 return arith::ConstantOp::create(rewriter, loc, attr);
35}
36
37/// Create a float constant.
38static Value createFloatConst(Location loc, Type type, APFloat value,
39 PatternRewriter &rewriter) {
40 auto attr = rewriter.getFloatAttr(getElementTypeOrSelf(type), value);
41 if (auto shapedTy = dyn_cast<ShapedType>(type)) {
42 return arith::ConstantOp::create(rewriter, loc,
43 DenseElementsAttr::get(shapedTy, attr));
44 }
45
46 return arith::ConstantOp::create(rewriter, loc, attr);
47}
48
49/// Creates shapedType using shape from cloneFrom and base type from cloneTo
50static Type cloneToShapedType(Type cloneFrom, Type cloneTo) {
51 if (auto shapedTy = dyn_cast<ShapedType>(cloneFrom)) {
52 return shapedTy.clone(cloneTo);
53 }
54 return cloneTo;
55}
56
57namespace {
58
59/// Expands CeilDivUIOp (n, m) into
60/// n == 0 ? 0 : ((n-1) / m) + 1
61struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> {
62 using Base::Base;
63 LogicalResult matchAndRewrite(arith::CeilDivUIOp op,
64 PatternRewriter &rewriter) const final {
65 Location loc = op.getLoc();
66 Value a = op.getLhs();
67 Value b = op.getRhs();
68 Value zero = createConst(loc, a.getType(), 0, rewriter);
69 Value compare =
70 arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, a, zero);
71 Value one = createConst(loc, a.getType(), 1, rewriter);
72 Value minusOne = arith::SubIOp::create(rewriter, loc, a, one);
73 Value quotient = arith::DivUIOp::create(rewriter, loc, minusOne, b);
74 Value plusOne = arith::AddIOp::create(rewriter, loc, quotient, one);
75 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compare, zero, plusOne);
76 return success();
77 }
78};
79
80/// Expands CeilDivSIOp (a, b) into
81/// z = a / b
82/// if (z * b != a && (a < 0) == (b < 0)) {
83/// return z + 1;
84/// } else {
85/// return z;
86/// }
87struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
88 using Base::Base;
89 LogicalResult matchAndRewrite(arith::CeilDivSIOp op,
90 PatternRewriter &rewriter) const final {
91 Location loc = op.getLoc();
92 Type type = op.getType();
93 Value a = op.getLhs();
94 Value b = op.getRhs();
95
96 Value zero = createConst(loc, type, 0, rewriter);
97 Value one = createConst(loc, type, 1, rewriter);
98
99 Value quotient = arith::DivSIOp::create(rewriter, loc, a, b);
100 Value product = arith::MulIOp::create(rewriter, loc, quotient, b);
101 Value notEqualDivisor = arith::CmpIOp::create(
102 rewriter, loc, arith::CmpIPredicate::ne, a, product);
103
104 Value aNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
105 a, zero);
106 Value bNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
107 b, zero);
108
109 Value signEqual = arith::CmpIOp::create(
110 rewriter, loc, arith::CmpIPredicate::eq, aNeg, bNeg);
111 Value cond =
112 arith::AndIOp::create(rewriter, loc, notEqualDivisor, signEqual);
113
114 Value quotientPlusOne = arith::AddIOp::create(rewriter, loc, quotient, one);
115
116 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientPlusOne,
117 quotient);
118 return success();
119 }
120};
121
122/// Expands FloorDivSIOp (x, y) into
123/// z = x / y
124/// if (z * y != x && (x < 0) != (y < 0)) {
125/// return z - 1;
126/// } else {
127/// return z;
128/// }
129struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
130 using Base::Base;
131 LogicalResult matchAndRewrite(arith::FloorDivSIOp op,
132 PatternRewriter &rewriter) const final {
133 Location loc = op.getLoc();
134 Type type = op.getType();
135 Value a = op.getLhs();
136 Value b = op.getRhs();
137
138 Value quotient = arith::DivSIOp::create(rewriter, loc, a, b);
139 Value product = arith::MulIOp::create(rewriter, loc, quotient, b);
140 Value notEqualDivisor = arith::CmpIOp::create(
141 rewriter, loc, arith::CmpIPredicate::ne, a, product);
142 Value zero = createConst(loc, type, 0, rewriter);
143
144 Value aNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
145 a, zero);
146 Value bNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
147 b, zero);
148
149 Value signOpposite = arith::CmpIOp::create(
150 rewriter, loc, arith::CmpIPredicate::ne, aNeg, bNeg);
151 Value cond =
152 arith::AndIOp::create(rewriter, loc, notEqualDivisor, signOpposite);
153
154 Value minusOne = createConst(loc, type, -1, rewriter);
155 Value quotientMinusOne =
156 arith::AddIOp::create(rewriter, loc, quotient, minusOne);
157
158 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientMinusOne,
159 quotient);
160 return success();
161 }
162};
163
164template <typename OpTy, arith::CmpIPredicate pred>
165struct MaxMinIOpConverter : public OpRewritePattern<OpTy> {
166public:
167 using OpRewritePattern<OpTy>::OpRewritePattern;
168
169 LogicalResult matchAndRewrite(OpTy op,
170 PatternRewriter &rewriter) const final {
171 Value lhs = op.getLhs();
172 Value rhs = op.getRhs();
173
174 Value cmp = arith::CmpIOp::create(rewriter, op.getLoc(), pred, lhs, rhs);
175 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, lhs, rhs);
176 return success();
177 }
178};
179
180template <typename OpTy, arith::CmpFPredicate pred>
181struct MaximumMinimumFOpConverter : public OpRewritePattern<OpTy> {
182public:
183 using OpRewritePattern<OpTy>::OpRewritePattern;
184
185 LogicalResult matchAndRewrite(OpTy op,
186 PatternRewriter &rewriter) const final {
187 Value lhs = op.getLhs();
188 Value rhs = op.getRhs();
189
190 Location loc = op.getLoc();
191 // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs').
192 static_assert(pred == arith::CmpFPredicate::UGT ||
193 pred == arith::CmpFPredicate::ULT,
194 "pred must be either UGT or ULT");
195 Value cmp = arith::CmpFOp::create(rewriter, loc, pred, lhs, rhs);
196 Value select = arith::SelectOp::create(rewriter, loc, cmp, lhs, rhs);
197
198 // Handle the case where rhs is NaN: 'isNaN(rhs) ? rhs : select'.
199 Value isNaN = arith::CmpFOp::create(rewriter, loc,
200 arith::CmpFPredicate::UNO, rhs, rhs);
201 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
202 return success();
203 }
204};
205
206template <typename OpTy, arith::CmpFPredicate pred>
207struct MaxNumMinNumFOpConverter : public OpRewritePattern<OpTy> {
208public:
209 using OpRewritePattern<OpTy>::OpRewritePattern;
210
211 LogicalResult matchAndRewrite(OpTy op,
212 PatternRewriter &rewriter) const final {
213 Value lhs = op.getLhs();
214 Value rhs = op.getRhs();
215
216 Location loc = op.getLoc();
217 // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs').
218 static_assert(pred == arith::CmpFPredicate::UGT ||
219 pred == arith::CmpFPredicate::ULT,
220 "pred must be either UGT or ULT");
221 Value cmp = arith::CmpFOp::create(rewriter, loc, pred, lhs, rhs);
222 Value select = arith::SelectOp::create(rewriter, loc, cmp, lhs, rhs);
223
224 // Handle the case where lhs is NaN: 'isNaN(lhs) ? rhs : select'.
225 Value isNaN = arith::CmpFOp::create(rewriter, loc,
226 arith::CmpFPredicate::UNO, lhs, lhs);
227 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
228 return success();
229 }
230};
231
232struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
234 LogicalResult matchAndRewrite(arith::ExtFOp op,
235 PatternRewriter &rewriter) const final {
236 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
237 auto operand = op.getOperand();
238 Type operandTy = operand.getType();
239 Type resultTy = op.getType();
240 Type operandETy = getElementTypeOrSelf(operandTy);
241 Type resultETy = getElementTypeOrSelf(resultTy);
243 if (!operandETy.isBF16() || !resultETy.isF32()) {
244 return rewriter.notifyMatchFailure(op, "not a ext of bf16 to f32.");
246
247 Type i16Ty = cloneToShapedType(operandTy, b.getI16Type());
248 Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
249
250 Value bitcast = arith::BitcastOp::create(b, i16Ty, operand);
251 Value exti = arith::ExtUIOp::create(b, i32Ty, bitcast);
252
253 Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
254 Value shl = arith::ShLIOp::create(b, exti, c16);
255 Value result = arith::BitcastOp::create(b, resultTy, shl);
257 rewriter.replaceOp(op, result);
258 return success();
260};
261
262struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
263 using Base::Base;
264 LogicalResult matchAndRewrite(arith::TruncFOp op,
265 PatternRewriter &rewriter) const final {
266 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
267 auto operand = op.getOperand();
268 Type operandTy = operand.getType();
269 Type resultTy = op.getType();
270 Type operandETy = getElementTypeOrSelf(operandTy);
271 Type resultETy = getElementTypeOrSelf(resultTy);
272
273 if (!operandETy.isF32() || !resultETy.isBF16()) {
274 return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");
275 }
276
277 if (op.getRoundingmodeAttr()) {
278 return rewriter.notifyMatchFailure(
279 op, "only applicable to default rounding mode.");
280 }
281
282 Type i16Ty = cloneToShapedType(operandTy, b.getI16Type());
283 Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
285 // Algorithm borrowed from this excellent code:
286 // https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/c10/util/BFloat16.h#L60-L79
287 // There is a magic idea there, to let the addition of the rounding_bias to
288 // the mantissa simply overflow into the exponent bits. It's a bit of an
289 // aggressive, obfuscating optimization, but it is well-tested code, and it
290 // results in more concise and efficient IR.
291 // The case of NaN is handled separately (see isNaN and the final select).
292 // The case of infinities is NOT handled separately, which deserves an
293 // explanation. As the encoding of infinities has zero mantissa, the
294 // rounding-bias addition never carries into the exponent so that just gets
295 // truncated away, and as bfloat16 and float32 have the same number of
296 // exponent bits, that simple truncation is the desired outcome for
297 // infinities.
298 Value isNan =
299 arith::CmpFOp::create(b, arith::CmpFPredicate::UNE, operand, operand);
300 // Constant used to make the rounding bias.
301 Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter);
302 // Constant used to generate a quiet NaN.
303 Value c7FC0I16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
304 // Small constants used to address bits.
305 Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
306 Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter);
307 // Reinterpret the input f32 value as bits.
308 Value bitcast = arith::BitcastOp::create(b, i32Ty, operand);
309 // Read bit 16 as a value in {0,1}.
310 Value bit16 =
311 arith::AndIOp::create(b, arith::ShRUIOp::create(b, bitcast, c16), c1);
312 // Determine the rounding bias to add as either 0x7fff or 0x8000 depending
313 // on bit 16, implementing the tie-breaking "to nearest even".
314 Value roundingBias = arith::AddIOp::create(b, bit16, c7FFF);
315 // Add the rounding bias. Generally we want this to be added to the
316 // mantissa, but nothing prevents this to from carrying into the exponent
317 // bits, which would feel like a bug, but this is the magic trick here:
318 // when that happens, the mantissa gets reset to zero and the exponent
319 // gets incremented by the carry... which is actually exactly what we
320 // want.
321 Value biased = arith::AddIOp::create(b, bitcast, roundingBias);
322 // Now that the rounding-bias has been added, truncating the low bits
323 // yields the correctly rounded result.
324 Value biasedAndShifted = arith::ShRUIOp::create(b, biased, c16);
325 Value normalCaseResultI16 =
326 arith::TruncIOp::create(b, i16Ty, biasedAndShifted);
327 // Select either the above-computed result, or a quiet NaN constant
328 // if the input was NaN.
329 Value select =
330 arith::SelectOp::create(b, isNan, c7FC0I16, normalCaseResultI16);
331 Value result = arith::BitcastOp::create(b, resultTy, select);
332 rewriter.replaceOp(op, result);
333 return success();
334 }
335};
336
337/// In this implementation of extf we take advantage of some key patterns we
338/// notice between the binary representation of an F4E2M1 value and its
339/// corresponding value in F32.
340///
341/// Note: x is sign bit
342/// | Binary | F4E2M1 | f32[23:32]
343/// | x000 | 0.0 | x000 0000 00
344/// | x001 | 0.5 | x011 1111 00
345/// | x010 | 1.0 | x011 1111 10
346/// | x011 | 1.5 | x011 1111 11
347/// | x100 | 2.0 | x010 0000 00
348/// | x101 | 3.0 | x010 0000 01
349/// | x110 | 4.0 | x010 0000 10
350/// | x111 | 6.0 | x010 0000 11
351///
352/// 1) There are only two versions of bits [25:31] in the f32 result
353/// F4E2M1 bits[2:3] decide whether:
354/// - F32 bits[25:31] = 0011 1111
355/// - F32 bits[25:31] = 0010 0000
356/// Exception is zero where
357/// - F32 bits[25:31] = 0000 0000
358///
359/// 2) F4E2M1 bits[1:2] = F32 bits[23:24]
360/// Exception is 0.5 where
361/// - F4E2M1 bits[1:2] = 01, F32 bits[23:24] = 00
362///
363/// 3) F4E2M1 bits[4] = F32 bits[32] (sign bits are equal)
364///
365/// 4) F32 bits[1:22] = 0
366struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
367 using Base::Base;
368 LogicalResult matchAndRewrite(arith::ExtFOp op,
369 PatternRewriter &rewriter) const final {
370 Location loc = op.getLoc();
371 ImplicitLocOpBuilder b(loc, rewriter);
372 Value operand = op.getOperand();
373 Type operandTy = operand.getType();
374 Type resultTy = op.getType();
375 Type operandETy = getElementTypeOrSelf(operandTy);
376 Type resultETy = getElementTypeOrSelf(resultTy);
377
378 if (!isa<Float4E2M1FNType>(operandETy))
379 return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN");
380
381 Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
382 Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
383 Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
384 Value i4Bits = arith::BitcastOp::create(b, i4Ty, operand);
385
386 Value c0x0 = createConst(loc, i4Ty, 0x0, rewriter);
387 Value c0x1 = createConst(loc, i4Ty, 0x1, rewriter);
388 Value c0x2 = createConst(loc, i4Ty, 0x2, rewriter);
389 Value c0x4 = createConst(loc, i4Ty, 0x4, rewriter);
390 Value c0x7 = createConst(loc, i4Ty, 0x7, rewriter);
391
392 Value i4BitsNoSign = arith::AndIOp::create(b, i4Bits, c0x7);
393
394 // Set last Exponent bit and Mantissa.
395 Value c0x00000014 = createConst(loc, i32Ty, 0x14, rewriter);
396 Value bits1To24 = arith::ShLIOp::create(b, i4BitsNoSign, c0x2);
397 Value isHalf =
398 arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4BitsNoSign, c0x1);
399 bits1To24 = arith::SelectOp::create(b, isHalf, c0x0, bits1To24);
400 bits1To24 = arith::ExtUIOp::create(b, i32Ty, bits1To24);
401 bits1To24 = arith::ShLIOp::create(b, bits1To24, c0x00000014);
402
403 // Set first 7 bits of Exponent.
404 Value zeroExpBits = createConst(loc, i32Ty, 0x00000000, rewriter);
405 Value highExpBits = createConst(loc, i32Ty, 0x40000000, rewriter);
406 Value lowExpBits = createConst(loc, i32Ty, 0x3f000000, rewriter);
407 Value useLargerExp =
408 arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4BitsNoSign, c0x4);
409 Value bits25To31 =
410 arith::SelectOp::create(b, useLargerExp, highExpBits, lowExpBits);
411 Value zeroExp =
412 arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4BitsNoSign, c0x0);
413 bits25To31 = arith::SelectOp::create(b, zeroExp, zeroExpBits, bits25To31);
414
415 // Set sign.
416 Value c0x80000000 = createConst(loc, i32Ty, 0x80000000, rewriter);
417 Value c0x8 = createConst(loc, i4Ty, 0x8, rewriter);
418 Value negative =
419 arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4Bits, c0x8);
420 Value bit32 =
421 arith::SelectOp::create(b, negative, c0x80000000, zeroExpBits);
422
423 // Add segments together.
424 Value bits1To31 = arith::AddIOp::create(b, bits1To24, bits25To31);
425 Value bits1To32 = arith::AddIOp::create(b, bits1To31, bit32);
426 Value result = arith::BitcastOp::create(b, f32Ty, bits1To32);
427 if (!isa<Float32Type>(resultETy))
428 result = arith::TruncFOp::create(b, resultTy, result);
429
430 rewriter.replaceOp(op, result);
431 return success();
432 }
433};
434
435struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
436 using Base::Base;
437 LogicalResult matchAndRewrite(arith::ExtFOp op,
438 PatternRewriter &rewriter) const final {
439 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
440 Value operand = op.getOperand();
441 Type operandTy = operand.getType();
442 Type resultTy = op.getType();
443 Type operandETy = getElementTypeOrSelf(operandTy);
444 Type resultETy = getElementTypeOrSelf(resultTy);
445
446 if (!llvm::isa<Float8E8M0FNUType>(operandETy)) {
447 return rewriter.notifyMatchFailure(op, "not a ext of F8E8M0FNU");
448 }
449
450 Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
451 Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
452 Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
453
454 Value bitcast = arith::BitcastOp::create(b, i8Ty, operand);
455 // create constants for NaNs
456 Value cF8NaN = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
457 Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
458 Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
459
460 Value exti = arith::ExtUIOp::create(b, i32Ty, bitcast);
461 Value f32Bits = arith::ShLIOp::create(b, exti, cF32MantissaWidth);
462
463 Value isNan =
464 arith::CmpIOp::create(b, arith::CmpIPredicate::eq, bitcast, cF8NaN);
465 // select for NaNs
466 f32Bits = arith::SelectOp::create(b, isNan, cF32NaN, f32Bits);
467 Value result = arith::BitcastOp::create(b, f32Ty, f32Bits);
468 if (resultETy.getIntOrFloatBitWidth() < 32) {
469 result = arith::TruncFOp::create(b, resultTy, result, nullptr,
470 op.getFastmathAttr());
471 } else if (resultETy.getIntOrFloatBitWidth() > 32) {
472 result = arith::ExtFOp::create(b, resultTy, result, op.getFastmathAttr());
473 }
474 rewriter.replaceOp(op, result);
475 return success();
476 }
477};
478
479/// Conversion from F32 to F4E2M1 according to the OCP Spec:
480/// www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
481///
482/// The spec requiers us to perform Round to Nearest, Ties to Even.
483///
484/// This means that after rounding, we should break ties by choosing the option
485/// which results in a mantissa of 0 in the least significant digit.
486///
487/// Table of representable values in F4E2M1:
488///
489/// Note: x is sign bit
490/// | Binary | F4E2M1 | F32[23:32]
491/// | x000 | 0.0 | x000 0000 00
492/// | x001 | 0.5 | x011 1111 00
493/// | x010 | 1.0 | x011 1111 10
494/// | x011 | 1.5 | x011 1111 11
495/// | x100 | 2.0 | x010 0000 00
496/// | x101 | 3.0 | x010 0000 01
497/// | x110 | 4.0 | x010 0000 10
498/// | x111 | 6.0 | x010 0000 11
499///
500/// Conversion procedure:
501/// Step 1: Clamp to representable bounds.
502/// Step 2: Convert exponent by adjusting bias.
503/// Step 3: Set mantissa to first bit.
504/// Step 4: Special consideration for subnormal and zero exponent.
505/// Step 5: Round up if necessary, if mantissa[1:] greater than 1000000 or
506/// subnormal.
507struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
508 using Base::Base;
509 LogicalResult matchAndRewrite(arith::TruncFOp op,
510 PatternRewriter &rewriter) const final {
511 Location loc = op.getLoc();
512 ImplicitLocOpBuilder b(loc, rewriter);
513 Value operand = op.getOperand();
514 Type operandTy = operand.getType();
515 Type resultTy = op.getType();
516 Type operandETy = getElementTypeOrSelf(operandTy);
517 Type resultETy = getElementTypeOrSelf(resultTy);
518
519 Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
520 Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
521 Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
522 Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
523
524 if (!isa<Float4E2M1FNType>(resultETy))
525 return rewriter.notifyMatchFailure(op, "not a trunc of F4E2M1FN");
526 if (!isa<Float32Type>(operandETy))
527 operand = arith::ExtFOp::create(b, f32Ty, operand);
528
529 Value c0x1 = createConst(loc, i4Ty, 1, rewriter);
530 Value c0x3 = createConst(loc, i4Ty, 3, rewriter);
531 Value c0x00000016 = createConst(loc, i32Ty, 22, rewriter);
532 Value c0x00 = createConst(loc, i8Ty, 0x00, rewriter);
533 Value c0xff = createConst(loc, i8Ty, 0xff, rewriter);
534 Value zeroExpBits = createConst(loc, i32Ty, 0, rewriter);
535
536 // Step 0: Clamp to bounds.
537 Value cHigherBound = createFloatConst(loc, f32Ty, APFloat(6.0f), rewriter);
538 Value cLowerBound = createFloatConst(loc, f32Ty, APFloat(-6.0f), rewriter);
539 Value operandClamped = arith::MinNumFOp::create(b, cHigherBound, operand);
540 operandClamped = arith::MaxNumFOp::create(b, cLowerBound, operandClamped);
541 Value f32Bits = arith::BitcastOp::create(b, i32Ty, operandClamped);
542
543 // Step 1: Set sign bit.
544 Value cF32ExpManWidth = createConst(loc, i32Ty, 31, rewriter); // 23
545 Value f32Sign = arith::ShRUIOp::create(b, f32Bits, cF32ExpManWidth);
546 Value f4Sign = arith::TruncIOp::create(b, i4Ty, f32Sign);
547 Value f4Bits = arith::ShLIOp::create(b, f4Sign, c0x3);
548
549 // Step 2: Convert exponent by adjusting bias.
550 Value biasAdjustment = createConst(loc, i32Ty, 0x7e, rewriter);
551 Value cF4MantissaWidth = c0x1; // 1
552 Value cF32MantissaWidth = createConst(loc, i32Ty, 23, rewriter); // 23
553 Value f32SignExp = arith::ShRUIOp::create(b, f32Bits, cF32MantissaWidth);
554 Value biasAdjustedSignExp =
555 arith::SubIOp::create(b, f32SignExp, biasAdjustment);
556 Value f4Exp = arith::TruncIOp::create(b, i4Ty, biasAdjustedSignExp);
557 f4Exp = arith::ShLIOp::create(b, f4Exp, cF4MantissaWidth);
558 f4Bits = arith::AddIOp::create(b, f4Bits, f4Exp);
559
560 // Step 3: Set mantissa to first bit.
561 Value cF32FirstBitMask = createConst(loc, i32Ty, 0x400000, rewriter);
562 Value man1Bit = arith::AndIOp::create(b, f32Bits, cF32FirstBitMask);
563 man1Bit = arith::ShRUIOp::create(b, man1Bit, c0x00000016);
564 Value f4Man = arith::TruncIOp::create(b, i4Ty, man1Bit);
565 f4Bits = arith::AddIOp::create(b, f4Bits, f4Man);
566
567 // Step 4: Special consideration for conversion to 0.5.
568 Value cF32MantissaMask = createConst(loc, i32Ty, 0x7fffff, rewriter);
569 Value f8Exp = arith::TruncIOp::create(b, i8Ty, biasAdjustedSignExp);
570 Value isSubnormal =
571 arith::CmpIOp::create(b, arith::CmpIPredicate::sle, f8Exp, c0x00);
572 Value isNegOneExp =
573 arith::CmpIOp::create(b, arith::CmpIPredicate::eq, f8Exp, c0xff);
574 Value man23Bits = arith::AndIOp::create(b, f32Bits, cF32MantissaMask);
575 Value isNonZeroMan = arith::CmpIOp::create(b, arith::CmpIPredicate::ugt,
576 man23Bits, zeroExpBits);
577 Value roundToHalf = arith::AndIOp::create(b, isNegOneExp, isNonZeroMan);
578 Value isZeroExp =
579 arith::CmpIOp::create(b, arith::CmpIPredicate::eq, f8Exp, c0x00);
580 Value subnormalF4Bits = createConst(loc, i4Ty, 0xf, rewriter);
581 Value halfF4Bits = createConst(loc, i4Ty, 0x0, rewriter);
582 Value subResult =
583 arith::SelectOp::create(b, isSubnormal, subnormalF4Bits, f4Bits);
584 subResult = arith::SelectOp::create(b, roundToHalf, halfF4Bits, subResult);
585 f4Bits = arith::SelectOp::create(b, isZeroExp, f4Bits, subResult);
586
587 // Step 5: Round up if necessary.
588 Value cF32Last22BitMask = createConst(loc, i32Ty, 0x3fffff, rewriter);
589 Value cRound = createConst(loc, i32Ty, 0x200000, rewriter); // 010 0000...
590 Value man22Bits = arith::AndIOp::create(b, f32Bits, cF32Last22BitMask);
591 Value shouldRound =
592 arith::CmpIOp::create(b, arith::CmpIPredicate::uge, man22Bits, cRound);
593 shouldRound = arith::OrIOp::create(b, shouldRound, isSubnormal);
594 Value roundedF4Bits = arith::AddIOp::create(b, f4Bits, c0x1);
595 f4Bits = arith::SelectOp::create(b, shouldRound, roundedF4Bits, f4Bits);
596
597 Value result = arith::BitcastOp::create(b, resultTy, f4Bits);
598 rewriter.replaceOp(op, result);
599 return success();
600 }
601};
602
603/*
604TruncF to F8E8M0 is expected to extract exponent bits out of F32 type
605Since All kinds of Infs and NaNs are mapped to same exponent bits in F32 type,
606they all map to NaN in F8E8M0 Type.
607*/
608struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
609 using Base::Base;
610 LogicalResult matchAndRewrite(arith::TruncFOp op,
611 PatternRewriter &rewriter) const final {
612 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
613 Value operand = op.getOperand();
614 Type operandTy = operand.getType();
615 Type operandETy = getElementTypeOrSelf(operandTy);
616 Type resultTy = op.getType();
617 Type resultETy = getElementTypeOrSelf(resultTy);
618 if (!llvm::isa<Float8E8M0FNUType>(resultETy)) {
619 return rewriter.notifyMatchFailure(op, "not a truncf to f8E8M0FNU");
620 }
621
622 if (op.getRoundingmodeAttr()) {
623 return rewriter.notifyMatchFailure(
624 op, "only applicable to default rounding mode.");
625 }
626
627 Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
628 Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
629 Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
630
631 if (operandETy.getIntOrFloatBitWidth() < 32) {
632 operand = arith::ExtFOp::create(b, f32Ty, operand, op.getFastmathAttr());
633 } else if (operandETy.getIntOrFloatBitWidth() > 32) {
634 operand = arith::TruncFOp::create(
635 b, f32Ty, operand, op.getRoundingmodeAttr(), op.getFastmathAttr());
636 }
637 Value f32Bits = arith::BitcastOp::create(b, i32Ty, operand);
638 Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
639 Value f32SignExp = arith::ShRUIOp::create(b, f32Bits, cF32MantissaWidth);
640 Value exp8Bits = arith::TruncIOp::create(b, i8Ty, f32SignExp);
641 Value result = arith::BitcastOp::create(b, resultTy, exp8Bits);
642 rewriter.replaceOp(op, result);
643 return success();
644 }
645};
646
647struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> {
648 using Base::Base;
649 LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
650 PatternRewriter &rewriter) const final {
651 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
652 Value inputOperand = op.getIn();
653 Value scaleOperand = op.getScale();
654 Type scaleTy = scaleOperand.getType();
655 Type scaleETy = getElementTypeOrSelf(scaleOperand);
656 // allow implicit exponent extraction from 16/32 bits floats
657 if (scaleETy.getIntOrFloatBitWidth() >= 16) {
658 scaleETy = b.getF8E8M0Type();
659 scaleTy = cloneToShapedType(scaleTy, scaleETy);
660 scaleOperand = arith::TruncFOp::create(b, scaleTy, scaleOperand, nullptr,
661 op.getFastmathAttr());
662 }
663 // Catch scale types like f8E5M2.
664 if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
665 return rewriter.notifyMatchFailure(
666 op, "scaling_extf is using scales of type which can not be converted "
667 "to f8E8M0FNU");
668 }
669 Type resultTy = op.getType();
670 // extf on scale will essentially create floating point number
671 // of type resulTy that is 2^scale and will also propagate NaNs
672 Value scaleExt =
673 arith::ExtFOp::create(b, resultTy, scaleOperand, op.getFastmathAttr());
674 Value inputExt =
675 arith::ExtFOp::create(b, resultTy, inputOperand, op.getFastmathAttr());
676 Value result =
677 arith::MulFOp::create(b, inputExt, scaleExt, op.getFastmathAttr());
678 rewriter.replaceOp(op, result);
679 return success();
680 }
681};
682
683/*
684Expands arith.ScalingTruncFOp(in, scale) into
685 scale = arith.truncf(scale) : scaleTy -> f8E8M0FNU
686 result = arith.truncf(in / (2^scale))
687 */
688struct ScalingTruncFOpConverter
689 : public OpRewritePattern<arith::ScalingTruncFOp> {
690 using Base::Base;
691 LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
692 PatternRewriter &rewriter) const final {
693 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
694 Value inputOperand = op.getIn();
695 Value scaleOperand = op.getScale();
696 Type scaleTy = scaleOperand.getType();
697 Type scaleETy = getElementTypeOrSelf(scaleOperand);
698 // allow implicit exponent extraction from 16/32 bits floats
699 if (scaleETy.getIntOrFloatBitWidth() >= 16) {
700 scaleETy = b.getF8E8M0Type();
701 scaleTy = cloneToShapedType(scaleTy, scaleETy);
702 scaleOperand = arith::TruncFOp::create(b, scaleTy, scaleOperand, nullptr,
703 op.getFastmathAttr());
704 }
705 if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
706 return rewriter.notifyMatchFailure(
707 op, "scaling_truncf is using scales type which can not be converted "
708 "to f8E8M0FNU");
709 }
710 Type resultTy = op.getType();
711 Type inputTy = inputOperand.getType();
712 // this will create a floating point number of type
713 // inputTy that is 2^scale and will also propagate NaNs
714 scaleOperand =
715 arith::ExtFOp::create(b, inputTy, scaleOperand, op.getFastmathAttr());
716 Value result = arith::DivFOp::create(b, inputOperand, scaleOperand,
717 op.getFastmathAttr());
718 Value resultCast = arith::TruncFOp::create(
719 b, resultTy, result, op.getRoundingmodeAttr(), op.getFastmathAttr());
720 rewriter.replaceOp(op, resultCast);
721 return success();
722 }
723};
724
725struct ArithExpandOpsPass
726 : public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
727 using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
728
729 void runOnOperation() override {
730 RewritePatternSet patterns(&getContext());
731 ConversionTarget target(getContext());
732
733 arith::populateArithExpandOpsPatterns(patterns);
734
735 target.addLegalDialect<arith::ArithDialect>();
736 target.addLegalDialect<vector::VectorDialect>();
737
738 // clang-format off
739 target.addIllegalOp<
740 arith::CeilDivSIOp,
741 arith::CeilDivUIOp,
742 arith::FloorDivSIOp,
743 arith::MaxSIOp,
744 arith::MaxUIOp,
745 arith::MinSIOp,
746 arith::MinUIOp,
747 arith::MaximumFOp,
748 arith::MinimumFOp,
749 arith::MaxNumFOp,
750 arith::MinNumFOp,
751 arith::ScalingExtFOp,
752 arith::ScalingTruncFOp
753 >();
754
755 if (includeBf16)
756 arith::populateExpandBFloat16Patterns(patterns);
757 if (includeF8E8M0)
758 arith::populateExpandF8E8M0Patterns(patterns);
759 if (includeF4E2M1)
760 arith::populateExpandF4E2M1Patterns(patterns);
761
762 target.addDynamicallyLegalOp<arith::ExtFOp>(
763 [=](arith::ExtFOp op) {
764 Type inETy = getElementTypeOrSelf(op.getOperand().getType());
765 Type outETy = getElementTypeOrSelf(op.getType());
766 bool legalTypes = true;
767 if (includeBf16)
768 legalTypes &= !(inETy.isBF16() && outETy.isF32());
769 if (includeF8E8M0)
770 legalTypes &= !llvm::isa<Float8E8M0FNUType>(inETy);
771 if (includeF4E2M1)
772 legalTypes &= !llvm::isa<Float4E2M1FNType>(inETy);
773 return legalTypes;
774 });
775
776 target.addDynamicallyLegalOp<arith::TruncFOp>(
777 [=](arith::TruncFOp op) {
778 Type inETy = getElementTypeOrSelf(op.getOperand().getType());
779 Type outETy = getElementTypeOrSelf(op.getType());
780 bool legalTypes = true;
781 if (includeBf16)
782 legalTypes &= !(inETy.isF32() && outETy.isBF16());
783 if (includeF8E8M0)
784 legalTypes &= !(llvm::isa<Float8E8M0FNUType>(outETy));
785 if (includeF4E2M1)
786 legalTypes &= !llvm::isa<Float4E2M1FNType>(outETy);
787 return legalTypes;
788 });
789
790 // clang-format on
791 if (failed(applyPartialConversion(getOperation(), target,
792 std::move(patterns))))
793 signalPassFailure();
794 }
795};
796
797} // namespace
798
802 .add<CeilDivSIOpConverter, CeilDivUIOpConverter, FloorDivSIOpConverter>(
803 patterns.getContext());
804}
805
807 patterns.add<BFloat16ExtFOpConverter, BFloat16TruncFOpConverter>(
808 patterns.getContext());
809}
810
812 patterns.add<F4E2M1ExtFOpConverter, F4E2M1TruncFOpConverter>(
813 patterns.getContext());
814}
815
817 patterns.add<F8E8M0ExtFOpConverter, F8E8M0TruncFOpConverter>(
818 patterns.getContext());
819}
820
823 patterns.add<ScalingExtFOpConverter, ScalingTruncFOpConverter>(
824 patterns.getContext());
825}
826
830 // clang-format off
831 patterns.add<
832 MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
833 MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>,
834 MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>,
835 MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult>,
836 MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
837 MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>,
838 MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>,
839 MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
840 >(patterns.getContext());
841 // clang-format on
842}
return success()
static Value createConst(Location loc, Type type, int value, PatternRewriter &rewriter)
Create an integer or index constant.
Definition ExpandOps.cpp:27
static Type cloneToShapedType(Type cloneFrom, Type cloneTo)
Creates shapedType using shape from cloneFrom and base type from cloneTo.
Definition ExpandOps.cpp:50
static Value createFloatConst(Location loc, Type type, APFloat value, PatternRewriter &rewriter)
Create a float constant.
Definition ExpandOps.cpp:38
static int64_t product(ArrayRef< int64_t > vals)
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:228
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:254
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:630
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF32() const
Definition Types.cpp:40
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:122
bool isBF16() const
Definition Types.cpp:37
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
void populateExpandBFloat16Patterns(RewritePatternSet &patterns)
Add patterns to expand Arith bf16 patterns to lower level bitcasts/shifts.
void populateExpandScalingExtTruncPatterns(RewritePatternSet &patterns)
Add patterns to expand scaling ExtF/TruncF ops to equivalent arith ops.
void populateExpandF8E8M0Patterns(RewritePatternSet &patterns)
Add patterns to expand Arith f8e8m0 patterns to lower level bitcasts/shifts.
void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ceil/floor division ops.
void populateExpandF4E2M1Patterns(RewritePatternSet &patterns)
Add patterns to expand Arith f4e2m1 patterns to lower level bitcasts/shifts.
void populateArithExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ops.
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
Definition Fraction.h:68
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...