MLIR  22.0.0git
ExpandOps.cpp
Go to the documentation of this file.
1 //===- ExpandOps.cpp - Pass to legalize Arith ops for LLVM lowering --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 #include "mlir/IR/Location.h"
14 #include "mlir/IR/TypeUtilities.h"
16 
17 namespace mlir {
18 namespace arith {
19 #define GEN_PASS_DEF_ARITHEXPANDOPSPASS
20 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
21 } // namespace arith
22 } // namespace mlir
23 
24 using namespace mlir;
25 
26 /// Create an integer or index constant.
27 static Value createConst(Location loc, Type type, int value,
28  PatternRewriter &rewriter) {
29  auto attr = rewriter.getIntegerAttr(getElementTypeOrSelf(type), value);
30  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
31  return arith::ConstantOp::create(rewriter, loc,
32  DenseElementsAttr::get(shapedTy, attr));
33  }
34  return arith::ConstantOp::create(rewriter, loc, attr);
35 }
36 
37 /// Create a float constant.
38 static Value createFloatConst(Location loc, Type type, APFloat value,
39  PatternRewriter &rewriter) {
40  auto attr = rewriter.getFloatAttr(getElementTypeOrSelf(type), value);
41  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
42  return arith::ConstantOp::create(rewriter, loc,
43  DenseElementsAttr::get(shapedTy, attr));
44  }
45 
46  return arith::ConstantOp::create(rewriter, loc, attr);
47 }
48 
49 /// Creates shapedType using shape from cloneFrom and base type from cloneTo
50 static Type cloneToShapedType(Type cloneFrom, Type cloneTo) {
51  if (auto shapedTy = dyn_cast<ShapedType>(cloneFrom)) {
52  return shapedTy.clone(cloneTo);
53  }
54  return cloneTo;
55 }
56 
57 namespace {
58 
59 /// Expands CeilDivUIOp (n, m) into
60 /// n == 0 ? 0 : ((n-1) / m) + 1
61 struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> {
63  LogicalResult matchAndRewrite(arith::CeilDivUIOp op,
64  PatternRewriter &rewriter) const final {
65  Location loc = op.getLoc();
66  Value a = op.getLhs();
67  Value b = op.getRhs();
68  Value zero = createConst(loc, a.getType(), 0, rewriter);
69  Value compare =
70  arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, a, zero);
71  Value one = createConst(loc, a.getType(), 1, rewriter);
72  Value minusOne = arith::SubIOp::create(rewriter, loc, a, one);
73  Value quotient = arith::DivUIOp::create(rewriter, loc, minusOne, b);
74  Value plusOne = arith::AddIOp::create(rewriter, loc, quotient, one);
75  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compare, zero, plusOne);
76  return success();
77  }
78 };
79 
80 /// Expands CeilDivSIOp (a, b) into
81 /// z = a / b
82 /// if (z * b != a && (a < 0) == (b < 0)) {
83 /// return z + 1;
84 /// } else {
85 /// return z;
86 /// }
87 struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
89  LogicalResult matchAndRewrite(arith::CeilDivSIOp op,
90  PatternRewriter &rewriter) const final {
91  Location loc = op.getLoc();
92  Type type = op.getType();
93  Value a = op.getLhs();
94  Value b = op.getRhs();
95 
96  Value zero = createConst(loc, type, 0, rewriter);
97  Value one = createConst(loc, type, 1, rewriter);
98 
99  Value quotient = arith::DivSIOp::create(rewriter, loc, a, b);
100  Value product = arith::MulIOp::create(rewriter, loc, quotient, b);
101  Value notEqualDivisor = arith::CmpIOp::create(
102  rewriter, loc, arith::CmpIPredicate::ne, a, product);
103 
104  Value aNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
105  a, zero);
106  Value bNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
107  b, zero);
108 
109  Value signEqual = arith::CmpIOp::create(
110  rewriter, loc, arith::CmpIPredicate::eq, aNeg, bNeg);
111  Value cond =
112  arith::AndIOp::create(rewriter, loc, notEqualDivisor, signEqual);
113 
114  Value quotientPlusOne = arith::AddIOp::create(rewriter, loc, quotient, one);
115 
116  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientPlusOne,
117  quotient);
118  return success();
119  }
120 };
121 
122 /// Expands FloorDivSIOp (x, y) into
123 /// z = x / y
124 /// if (z * y != x && (x < 0) != (y < 0)) {
125 /// return z - 1;
126 /// } else {
127 /// return z;
128 /// }
129 struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
131  LogicalResult matchAndRewrite(arith::FloorDivSIOp op,
132  PatternRewriter &rewriter) const final {
133  Location loc = op.getLoc();
134  Type type = op.getType();
135  Value a = op.getLhs();
136  Value b = op.getRhs();
137 
138  Value quotient = arith::DivSIOp::create(rewriter, loc, a, b);
139  Value product = arith::MulIOp::create(rewriter, loc, quotient, b);
140  Value notEqualDivisor = arith::CmpIOp::create(
141  rewriter, loc, arith::CmpIPredicate::ne, a, product);
142  Value zero = createConst(loc, type, 0, rewriter);
143 
144  Value aNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
145  a, zero);
146  Value bNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
147  b, zero);
148 
149  Value signOpposite = arith::CmpIOp::create(
150  rewriter, loc, arith::CmpIPredicate::ne, aNeg, bNeg);
151  Value cond =
152  arith::AndIOp::create(rewriter, loc, notEqualDivisor, signOpposite);
153 
154  Value minusOne = createConst(loc, type, -1, rewriter);
155  Value quotientMinusOne =
156  arith::AddIOp::create(rewriter, loc, quotient, minusOne);
157 
158  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientMinusOne,
159  quotient);
160  return success();
161  }
162 };
163 
164 template <typename OpTy, arith::CmpIPredicate pred>
165 struct MaxMinIOpConverter : public OpRewritePattern<OpTy> {
166 public:
168 
169  LogicalResult matchAndRewrite(OpTy op,
170  PatternRewriter &rewriter) const final {
171  Value lhs = op.getLhs();
172  Value rhs = op.getRhs();
173 
174  Value cmp = arith::CmpIOp::create(rewriter, op.getLoc(), pred, lhs, rhs);
175  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, lhs, rhs);
176  return success();
177  }
178 };
179 
180 template <typename OpTy, arith::CmpFPredicate pred>
181 struct MaximumMinimumFOpConverter : public OpRewritePattern<OpTy> {
182 public:
184 
185  LogicalResult matchAndRewrite(OpTy op,
186  PatternRewriter &rewriter) const final {
187  Value lhs = op.getLhs();
188  Value rhs = op.getRhs();
189 
190  Location loc = op.getLoc();
191  // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs').
192  static_assert(pred == arith::CmpFPredicate::UGT ||
193  pred == arith::CmpFPredicate::ULT,
194  "pred must be either UGT or ULT");
195  Value cmp = arith::CmpFOp::create(rewriter, loc, pred, lhs, rhs);
196  Value select = arith::SelectOp::create(rewriter, loc, cmp, lhs, rhs);
197 
198  // Handle the case where rhs is NaN: 'isNaN(rhs) ? rhs : select'.
199  Value isNaN = arith::CmpFOp::create(rewriter, loc,
200  arith::CmpFPredicate::UNO, rhs, rhs);
201  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
202  return success();
203  }
204 };
205 
206 template <typename OpTy, arith::CmpFPredicate pred>
207 struct MaxNumMinNumFOpConverter : public OpRewritePattern<OpTy> {
208 public:
210 
211  LogicalResult matchAndRewrite(OpTy op,
212  PatternRewriter &rewriter) const final {
213  Value lhs = op.getLhs();
214  Value rhs = op.getRhs();
215 
216  Location loc = op.getLoc();
217  // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs').
218  static_assert(pred == arith::CmpFPredicate::UGT ||
219  pred == arith::CmpFPredicate::ULT,
220  "pred must be either UGT or ULT");
221  Value cmp = arith::CmpFOp::create(rewriter, loc, pred, lhs, rhs);
222  Value select = arith::SelectOp::create(rewriter, loc, cmp, lhs, rhs);
223 
224  // Handle the case where lhs is NaN: 'isNaN(lhs) ? rhs : select'.
225  Value isNaN = arith::CmpFOp::create(rewriter, loc,
226  arith::CmpFPredicate::UNO, lhs, lhs);
227  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
228  return success();
229  }
230 };
231 
232 struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
234  LogicalResult matchAndRewrite(arith::ExtFOp op,
235  PatternRewriter &rewriter) const final {
236  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
237  auto operand = op.getOperand();
238  Type operandTy = operand.getType();
239  Type resultTy = op.getType();
240  Type operandETy = getElementTypeOrSelf(operandTy);
241  Type resultETy = getElementTypeOrSelf(resultTy);
242 
243  if (!operandETy.isBF16() || !resultETy.isF32()) {
244  return rewriter.notifyMatchFailure(op, "not a ext of bf16 to f32.");
245  }
246 
247  Type i16Ty = cloneToShapedType(operandTy, b.getI16Type());
248  Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
249 
250  Value bitcast = arith::BitcastOp::create(b, i16Ty, operand);
251  Value exti = arith::ExtUIOp::create(b, i32Ty, bitcast);
252 
253  Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
254  Value shl = arith::ShLIOp::create(b, exti, c16);
255  Value result = arith::BitcastOp::create(b, resultTy, shl);
256 
257  rewriter.replaceOp(op, result);
258  return success();
259  }
260 };
261 
262 struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
264  LogicalResult matchAndRewrite(arith::TruncFOp op,
265  PatternRewriter &rewriter) const final {
266  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
267  auto operand = op.getOperand();
268  Type operandTy = operand.getType();
269  Type resultTy = op.getType();
270  Type operandETy = getElementTypeOrSelf(operandTy);
271  Type resultETy = getElementTypeOrSelf(resultTy);
272 
273  if (!operandETy.isF32() || !resultETy.isBF16()) {
274  return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");
275  }
276 
277  if (op.getRoundingmodeAttr()) {
278  return rewriter.notifyMatchFailure(
279  op, "only applicable to default rounding mode.");
280  }
281 
282  Type i16Ty = cloneToShapedType(operandTy, b.getI16Type());
283  Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
284 
285  // Algorithm borrowed from this excellent code:
286  // https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/c10/util/BFloat16.h#L60-L79
287  // There is a magic idea there, to let the addition of the rounding_bias to
288  // the mantissa simply overflow into the exponent bits. It's a bit of an
289  // aggressive, obfuscating optimization, but it is well-tested code, and it
290  // results in more concise and efficient IR.
291  // The case of NaN is handled separately (see isNaN and the final select).
292  // The case of infinities is NOT handled separately, which deserves an
293  // explanation. As the encoding of infinities has zero mantissa, the
294  // rounding-bias addition never carries into the exponent so that just gets
295  // truncated away, and as bfloat16 and float32 have the same number of
296  // exponent bits, that simple truncation is the desired outcome for
297  // infinities.
298  Value isNan =
299  arith::CmpFOp::create(b, arith::CmpFPredicate::UNE, operand, operand);
300  // Constant used to make the rounding bias.
301  Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter);
302  // Constant used to generate a quiet NaN.
303  Value c7FC0I16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
304  // Small constants used to address bits.
305  Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
306  Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter);
307  // Reinterpret the input f32 value as bits.
308  Value bitcast = arith::BitcastOp::create(b, i32Ty, operand);
309  // Read bit 16 as a value in {0,1}.
310  Value bit16 =
311  arith::AndIOp::create(b, arith::ShRUIOp::create(b, bitcast, c16), c1);
312  // Determine the rounding bias to add as either 0x7fff or 0x8000 depending
313  // on bit 16, implementing the tie-breaking "to nearest even".
314  Value roundingBias = arith::AddIOp::create(b, bit16, c7FFF);
315  // Add the rounding bias. Generally we want this to be added to the
316  // mantissa, but nothing prevents this to from carrying into the exponent
317  // bits, which would feel like a bug, but this is the magic trick here:
318  // when that happens, the mantissa gets reset to zero and the exponent
319  // gets incremented by the carry... which is actually exactly what we
320  // want.
321  Value biased = arith::AddIOp::create(b, bitcast, roundingBias);
322  // Now that the rounding-bias has been added, truncating the low bits
323  // yields the correctly rounded result.
324  Value biasedAndShifted = arith::ShRUIOp::create(b, biased, c16);
325  Value normalCaseResultI16 =
326  arith::TruncIOp::create(b, i16Ty, biasedAndShifted);
327  // Select either the above-computed result, or a quiet NaN constant
328  // if the input was NaN.
329  Value select =
330  arith::SelectOp::create(b, isNan, c7FC0I16, normalCaseResultI16);
331  Value result = arith::BitcastOp::create(b, resultTy, select);
332  rewriter.replaceOp(op, result);
333  return success();
334  }
335 };
336 
337 /// In this implementation of extf we take advantage of some key patterns we
338 /// notice between the binary representation of an F4E2M1 value and its
339 /// corresponding value in F32.
340 ///
341 /// Note: x is sign bit
342 /// | Binary | F4E2M1 | f32[23:32]
343 /// | x000 | 0.0 | x000 0000 00
344 /// | x001 | 0.5 | x011 1111 00
345 /// | x010 | 1.0 | x011 1111 10
346 /// | x011 | 1.5 | x011 1111 11
347 /// | x100 | 2.0 | x010 0000 00
348 /// | x101 | 3.0 | x010 0000 01
349 /// | x110 | 4.0 | x010 0000 10
350 /// | x111 | 6.0 | x010 0000 11
351 ///
352 /// 1) There are only two versions of bits [25:31] in the f32 result
353 /// F4E2M1 bits[2:3] decide whether:
354 /// - F32 bits[25:31] = 0011 1111
355 /// - F32 bits[25:31] = 0010 0000
356 /// Exception is zero where
357 /// - F32 bits[25:31] = 0000 0000
358 ///
359 /// 2) F4E2M1 bits[1:2] = F32 bits[23:24]
360 /// Exception is 0.5 where
361 /// - F4E2M1 bits[1:2] = 01, F32 bits[23:24] = 00
362 ///
363 /// 3) F4E2M1 bits[4] = F32 bits[32] (sign bits are equal)
364 ///
365 /// 4) F32 bits[1:22] = 0
366 struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
368  LogicalResult matchAndRewrite(arith::ExtFOp op,
369  PatternRewriter &rewriter) const final {
370  Location loc = op.getLoc();
371  ImplicitLocOpBuilder b(loc, rewriter);
372  Value operand = op.getOperand();
373  Type operandTy = operand.getType();
374  Type resultTy = op.getType();
375  Type operandETy = getElementTypeOrSelf(operandTy);
376  Type resultETy = getElementTypeOrSelf(resultTy);
377 
378  if (!isa<Float4E2M1FNType>(operandETy))
379  return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN");
380 
381  Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
382  Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
383  Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
384  Value i4Bits = arith::BitcastOp::create(b, i4Ty, operand);
385 
386  Value c0x0 = createConst(loc, i4Ty, 0x0, rewriter);
387  Value c0x1 = createConst(loc, i4Ty, 0x1, rewriter);
388  Value c0x2 = createConst(loc, i4Ty, 0x2, rewriter);
389  Value c0x4 = createConst(loc, i4Ty, 0x4, rewriter);
390 
391  // Set last Exponent bit and Mantissa.
392  Value c0x00000014 = createConst(loc, i32Ty, 0x14, rewriter);
393  Value bits1To24 = arith::ShLIOp::create(b, i4Bits, c0x2);
394  Value isHalf =
395  arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4Bits, c0x1);
396  bits1To24 = arith::SelectOp::create(b, isHalf, c0x0, bits1To24);
397  bits1To24 = arith::ExtUIOp::create(b, i32Ty, bits1To24);
398  bits1To24 = arith::ShLIOp::create(b, bits1To24, c0x00000014);
399 
400  // Set first 7 bits of Exponent.
401  Value zeroExpBits = createConst(loc, i32Ty, 0x00000000, rewriter);
402  Value highExpBits = createConst(loc, i32Ty, 0x40000000, rewriter);
403  Value lowExpBits = createConst(loc, i32Ty, 0x3f000000, rewriter);
404  Value useLargerExp =
405  arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4Bits, c0x4);
406  Value bits25To31 =
407  arith::SelectOp::create(b, useLargerExp, highExpBits, lowExpBits);
408  Value zeroExp =
409  arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4Bits, c0x0);
410  bits25To31 = arith::SelectOp::create(b, zeroExp, zeroExpBits, bits25To31);
411 
412  // Set sign.
413  Value c0x80000000 = createConst(loc, i32Ty, 0x80000000, rewriter);
414  Value c0x8 = createConst(loc, i4Ty, 0x8, rewriter);
415  Value negative =
416  arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4Bits, c0x8);
417  Value bit32 =
418  arith::SelectOp::create(b, negative, c0x80000000, zeroExpBits);
419 
420  // Add segments together.
421  Value bits1To31 = arith::AddIOp::create(b, bits1To24, bits25To31);
422  Value bits1To32 = arith::AddIOp::create(b, bits1To31, bit32);
423  Value result = arith::BitcastOp::create(b, f32Ty, bits1To32);
424  if (!isa<Float32Type>(resultETy))
425  result = arith::TruncFOp::create(b, resultTy, result);
426 
427  rewriter.replaceOp(op, result);
428  return success();
429  }
430 };
431 
432 struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
434  LogicalResult matchAndRewrite(arith::ExtFOp op,
435  PatternRewriter &rewriter) const final {
436  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
437  Value operand = op.getOperand();
438  Type operandTy = operand.getType();
439  Type resultTy = op.getType();
440  Type operandETy = getElementTypeOrSelf(operandTy);
441  Type resultETy = getElementTypeOrSelf(resultTy);
442 
443  if (!llvm::isa<Float8E8M0FNUType>(operandETy)) {
444  return rewriter.notifyMatchFailure(op, "not a ext of F8E8M0FNU");
445  }
446 
447  Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
448  Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
449  Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
450 
451  Value bitcast = arith::BitcastOp::create(b, i8Ty, operand);
452  // create constants for NaNs
453  Value cF8NaN = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
454  Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
455  Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
456 
457  Value exti = arith::ExtUIOp::create(b, i32Ty, bitcast);
458  Value f32Bits = arith::ShLIOp::create(b, exti, cF32MantissaWidth);
459 
460  Value isNan =
461  arith::CmpIOp::create(b, arith::CmpIPredicate::eq, bitcast, cF8NaN);
462  // select for NaNs
463  f32Bits = arith::SelectOp::create(b, isNan, cF32NaN, f32Bits);
464  Value result = arith::BitcastOp::create(b, f32Ty, f32Bits);
465  if (resultETy.getIntOrFloatBitWidth() < 32) {
466  result = arith::TruncFOp::create(b, resultTy, result, nullptr,
467  op.getFastmathAttr());
468  } else if (resultETy.getIntOrFloatBitWidth() > 32) {
469  result = arith::ExtFOp::create(b, resultTy, result, op.getFastmathAttr());
470  }
471  rewriter.replaceOp(op, result);
472  return success();
473  }
474 };
475 
476 /// Conversion from F32 to F4E2M1 according to the OCP Spec:
477 /// www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
478 ///
479 /// The spec requiers us to perform Round to Nearest, Ties to Even.
480 ///
481 /// This means that after rounding, we should break ties by choosing the option
482 /// which results in a mantissa of 0 in the least significant digit.
483 ///
484 /// Table of representable values in F4E2M1:
485 ///
486 /// Note: x is sign bit
487 /// | Binary | F4E2M1 | F32[23:32]
488 /// | x000 | 0.0 | x000 0000 00
489 /// | x001 | 0.5 | x011 1111 00
490 /// | x010 | 1.0 | x011 1111 10
491 /// | x011 | 1.5 | x011 1111 11
492 /// | x100 | 2.0 | x010 0000 00
493 /// | x101 | 3.0 | x010 0000 01
494 /// | x110 | 4.0 | x010 0000 10
495 /// | x111 | 6.0 | x010 0000 11
496 ///
497 /// Conversion procedure:
498 /// Step 1: Clamp to representable bounds.
499 /// Step 2: Convert exponent by adjusting bias.
500 /// Step 3: Set mantissa to first bit.
501 /// Step 4: Special consideration for subnormal and zero exponent.
502 /// Step 5: Round up if necessary, if mantissa[1:] greater than 1000000 or
503 /// subnormal.
504 struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
506  LogicalResult matchAndRewrite(arith::TruncFOp op,
507  PatternRewriter &rewriter) const final {
508  Location loc = op.getLoc();
509  ImplicitLocOpBuilder b(loc, rewriter);
510  Value operand = op.getOperand();
511  Type operandTy = operand.getType();
512  Type resultTy = op.getType();
513  Type operandETy = getElementTypeOrSelf(operandTy);
514  Type resultETy = getElementTypeOrSelf(resultTy);
515 
516  Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
517  Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
518  Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
519  Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
520 
521  if (!isa<Float4E2M1FNType>(resultETy))
522  return rewriter.notifyMatchFailure(op, "not a trunc of F4E2M1FN");
523  if (!isa<Float32Type>(operandETy))
524  operand = arith::ExtFOp::create(b, f32Ty, operand);
525 
526  Value c0x1 = createConst(loc, i4Ty, 1, rewriter);
527  Value c0x3 = createConst(loc, i4Ty, 3, rewriter);
528  Value c0x00000016 = createConst(loc, i32Ty, 22, rewriter);
529  Value c0x00 = createConst(loc, i8Ty, 0x00, rewriter);
530  Value c0xff = createConst(loc, i8Ty, 0xff, rewriter);
531  Value zeroExpBits = createConst(loc, i32Ty, 0, rewriter);
532 
533  // Step 0: Clamp to bounds.
534  Value cHigherBound = createFloatConst(loc, f32Ty, APFloat(6.0f), rewriter);
535  Value cLowerBound = createFloatConst(loc, f32Ty, APFloat(-6.0f), rewriter);
536  Value operandClamped = arith::MinNumFOp::create(b, cHigherBound, operand);
537  operandClamped = arith::MaxNumFOp::create(b, cLowerBound, operandClamped);
538  Value f32Bits = arith::BitcastOp::create(b, i32Ty, operandClamped);
539 
540  // Step 1: Set sign bit.
541  Value cF32ExpManWidth = createConst(loc, i32Ty, 31, rewriter); // 23
542  Value f32Sign = arith::ShRUIOp::create(b, f32Bits, cF32ExpManWidth);
543  Value f4Sign = arith::TruncIOp::create(b, i4Ty, f32Sign);
544  Value f4Bits = arith::ShLIOp::create(b, f4Sign, c0x3);
545 
546  // Step 2: Convert exponent by adjusting bias.
547  Value biasAdjustment = createConst(loc, i32Ty, 0x7e, rewriter);
548  Value cF4MantissaWidth = c0x1; // 1
549  Value cF32MantissaWidth = createConst(loc, i32Ty, 23, rewriter); // 23
550  Value f32SignExp = arith::ShRUIOp::create(b, f32Bits, cF32MantissaWidth);
551  Value biasAdjustedSignExp =
552  arith::SubIOp::create(b, f32SignExp, biasAdjustment);
553  Value f4Exp = arith::TruncIOp::create(b, i4Ty, biasAdjustedSignExp);
554  f4Exp = arith::ShLIOp::create(b, f4Exp, cF4MantissaWidth);
555  f4Bits = arith::AddIOp::create(b, f4Bits, f4Exp);
556 
557  // Step 3: Set mantissa to first bit.
558  Value cF32FirstBitMask = createConst(loc, i32Ty, 0x400000, rewriter);
559  Value man1Bit = arith::AndIOp::create(b, f32Bits, cF32FirstBitMask);
560  man1Bit = arith::ShRUIOp::create(b, man1Bit, c0x00000016);
561  Value f4Man = arith::TruncIOp::create(b, i4Ty, man1Bit);
562  f4Bits = arith::AddIOp::create(b, f4Bits, f4Man);
563 
564  // Step 4: Special consideration for conversion to 0.5.
565  Value cF32MantissaMask = createConst(loc, i32Ty, 0x7fffff, rewriter);
566  Value f8Exp = arith::TruncIOp::create(b, i8Ty, biasAdjustedSignExp);
567  Value isSubnormal =
568  arith::CmpIOp::create(b, arith::CmpIPredicate::sle, f8Exp, c0x00);
569  Value isNegOneExp =
570  arith::CmpIOp::create(b, arith::CmpIPredicate::eq, f8Exp, c0xff);
571  Value man23Bits = arith::AndIOp::create(b, f32Bits, cF32MantissaMask);
572  Value isNonZeroMan = arith::CmpIOp::create(b, arith::CmpIPredicate::ugt,
573  man23Bits, zeroExpBits);
574  Value roundToHalf = arith::AndIOp::create(b, isNegOneExp, isNonZeroMan);
575  Value isZeroExp =
576  arith::CmpIOp::create(b, arith::CmpIPredicate::eq, f8Exp, c0x00);
577  Value subnormalF4Bits = createConst(loc, i4Ty, 0xf, rewriter);
578  Value halfF4Bits = createConst(loc, i4Ty, 0x0, rewriter);
579  Value subResult =
580  arith::SelectOp::create(b, isSubnormal, subnormalF4Bits, f4Bits);
581  subResult = arith::SelectOp::create(b, roundToHalf, halfF4Bits, subResult);
582  f4Bits = arith::SelectOp::create(b, isZeroExp, f4Bits, subResult);
583 
584  // Step 5: Round up if necessary.
585  Value cF32Last22BitMask = createConst(loc, i32Ty, 0x3fffff, rewriter);
586  Value cRound = createConst(loc, i32Ty, 0x200000, rewriter); // 010 0000...
587  Value man22Bits = arith::AndIOp::create(b, f32Bits, cF32Last22BitMask);
588  Value shouldRound =
589  arith::CmpIOp::create(b, arith::CmpIPredicate::uge, man22Bits, cRound);
590  shouldRound = arith::OrIOp::create(b, shouldRound, isSubnormal);
591  Value roundedF4Bits = arith::AddIOp::create(b, f4Bits, c0x1);
592  f4Bits = arith::SelectOp::create(b, shouldRound, roundedF4Bits, f4Bits);
593 
594  Value result = arith::BitcastOp::create(b, resultTy, f4Bits);
595  rewriter.replaceOp(op, result);
596  return success();
597  }
598 };
599 
600 /*
601 TruncF to F8E8M0 is expected to extract exponent bits out of F32 type
602 Since All kinds of Infs and NaNs are mapped to same exponent bits in F32 type,
603 they all map to NaN in F8E8M0 Type.
604 */
605 struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
607  LogicalResult matchAndRewrite(arith::TruncFOp op,
608  PatternRewriter &rewriter) const final {
609  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
610  Value operand = op.getOperand();
611  Type operandTy = operand.getType();
612  Type operandETy = getElementTypeOrSelf(operandTy);
613  Type resultTy = op.getType();
614  Type resultETy = getElementTypeOrSelf(resultTy);
615  if (!llvm::isa<Float8E8M0FNUType>(resultETy)) {
616  return rewriter.notifyMatchFailure(op, "not a truncf to f8E8M0FNU");
617  }
618 
619  if (op.getRoundingmodeAttr()) {
620  return rewriter.notifyMatchFailure(
621  op, "only applicable to default rounding mode.");
622  }
623 
624  Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
625  Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
626  Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
627 
628  if (operandETy.getIntOrFloatBitWidth() < 32) {
629  operand = arith::ExtFOp::create(b, f32Ty, operand, op.getFastmathAttr());
630  } else if (operandETy.getIntOrFloatBitWidth() > 32) {
631  operand = arith::TruncFOp::create(
632  b, f32Ty, operand, op.getRoundingmodeAttr(), op.getFastmathAttr());
633  }
634  Value f32Bits = arith::BitcastOp::create(b, i32Ty, operand);
635  Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
636  Value f32SignExp = arith::ShRUIOp::create(b, f32Bits, cF32MantissaWidth);
637  Value exp8Bits = arith::TruncIOp::create(b, i8Ty, f32SignExp);
638  Value result = arith::BitcastOp::create(b, resultTy, exp8Bits);
639  rewriter.replaceOp(op, result);
640  return success();
641  }
642 };
643 
644 struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> {
646  LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
647  PatternRewriter &rewriter) const final {
648  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
649  Value inputOperand = op.getIn();
650  Value scaleOperand = op.getScale();
651  Type scaleTy = scaleOperand.getType();
652  Type scaleETy = getElementTypeOrSelf(scaleOperand);
653  // allow implicit exponent extraction from 16/32 bits floats
654  if (scaleETy.getIntOrFloatBitWidth() >= 16) {
655  scaleETy = b.getF8E8M0Type();
656  scaleTy = cloneToShapedType(scaleTy, scaleETy);
657  scaleOperand = arith::TruncFOp::create(b, scaleTy, scaleOperand, nullptr,
658  op.getFastmathAttr());
659  }
660  // Catch scale types like f8E5M2.
661  if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
662  return rewriter.notifyMatchFailure(
663  op, "scaling_extf is using scales of type which can not be converted "
664  "to f8E8M0FNU");
665  }
666  Type resultTy = op.getType();
667  // extf on scale will essentially create floating point number
668  // of type resulTy that is 2^scale and will also propagate NaNs
669  Value scaleExt =
670  arith::ExtFOp::create(b, resultTy, scaleOperand, op.getFastmathAttr());
671  Value inputExt =
672  arith::ExtFOp::create(b, resultTy, inputOperand, op.getFastmathAttr());
673  Value result =
674  arith::MulFOp::create(b, inputExt, scaleExt, op.getFastmathAttr());
675  rewriter.replaceOp(op, result);
676  return success();
677  }
678 };
679 
680 /*
681 Expands arith.ScalingTruncFOp(in, scale) into
682  scale = arith.truncf(scale) : scaleTy -> f8E8M0FNU
683  result = arith.truncf(in / (2^scale))
684  */
685 struct ScalingTruncFOpConverter
686  : public OpRewritePattern<arith::ScalingTruncFOp> {
688  LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
689  PatternRewriter &rewriter) const final {
690  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
691  Value inputOperand = op.getIn();
692  Value scaleOperand = op.getScale();
693  Type scaleTy = scaleOperand.getType();
694  Type scaleETy = getElementTypeOrSelf(scaleOperand);
695  // allow implicit exponent extraction from 16/32 bits floats
696  if (scaleETy.getIntOrFloatBitWidth() >= 16) {
697  scaleETy = b.getF8E8M0Type();
698  scaleTy = cloneToShapedType(scaleTy, scaleETy);
699  scaleOperand = arith::TruncFOp::create(b, scaleTy, scaleOperand, nullptr,
700  op.getFastmathAttr());
701  }
702  if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
703  return rewriter.notifyMatchFailure(
704  op, "scaling_truncf is using scales type which can not be converted "
705  "to f8E8M0FNU");
706  }
707  Type resultTy = op.getType();
708  Type inputTy = inputOperand.getType();
709  // this will create a floating point number of type
710  // inputTy that is 2^scale and will also propagate NaNs
711  scaleOperand =
712  arith::ExtFOp::create(b, inputTy, scaleOperand, op.getFastmathAttr());
713  Value result = arith::DivFOp::create(b, inputOperand, scaleOperand,
714  op.getFastmathAttr());
715  Value resultCast = arith::TruncFOp::create(
716  b, resultTy, result, op.getRoundingmodeAttr(), op.getFastmathAttr());
717  rewriter.replaceOp(op, resultCast);
718  return success();
719  }
720 };
721 
722 struct ArithExpandOpsPass
723  : public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
724  using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
725 
726  void runOnOperation() override {
728  ConversionTarget target(getContext());
729 
731 
732  target.addLegalDialect<arith::ArithDialect>();
733  target.addLegalDialect<vector::VectorDialect>();
734 
735  // clang-format off
736  target.addIllegalOp<
737  arith::CeilDivSIOp,
738  arith::CeilDivUIOp,
739  arith::FloorDivSIOp,
740  arith::MaxSIOp,
741  arith::MaxUIOp,
742  arith::MinSIOp,
743  arith::MinUIOp,
744  arith::MaximumFOp,
745  arith::MinimumFOp,
746  arith::MaxNumFOp,
747  arith::MinNumFOp,
748  arith::ScalingExtFOp,
749  arith::ScalingTruncFOp
750  >();
751 
752  if (includeBf16)
754  if (includeF8E8M0)
756  if (includeF4E2M1)
758 
759  target.addDynamicallyLegalOp<arith::ExtFOp>(
760  [=](arith::ExtFOp op) {
761  Type inETy = getElementTypeOrSelf(op.getOperand().getType());
762  Type outETy = getElementTypeOrSelf(op.getType());
763  bool legalTypes = true;
764  if (includeBf16)
765  legalTypes &= !(inETy.isBF16() && outETy.isF32());
766  if (includeF8E8M0)
767  legalTypes &= !llvm::isa<Float8E8M0FNUType>(inETy);
768  if (includeF4E2M1)
769  legalTypes &= !llvm::isa<Float4E2M1FNType>(inETy);
770  return legalTypes;
771  });
772 
773  target.addDynamicallyLegalOp<arith::TruncFOp>(
774  [=](arith::TruncFOp op) {
775  Type inETy = getElementTypeOrSelf(op.getOperand().getType());
776  Type outETy = getElementTypeOrSelf(op.getType());
777  bool legalTypes = true;
778  if (includeBf16)
779  legalTypes &= !(inETy.isF32() && outETy.isBF16());
780  if (includeF8E8M0)
781  legalTypes &= !(llvm::isa<Float8E8M0FNUType>(outETy));
782  if (includeF4E2M1)
783  legalTypes &= !llvm::isa<Float4E2M1FNType>(outETy);
784  return legalTypes;
785  });
786 
787  // clang-format on
788  if (failed(applyPartialConversion(getOperation(), target,
789  std::move(patterns))))
790  signalPassFailure();
791  }
792 };
793 
794 } // namespace
795 
798  patterns
799  .add<CeilDivSIOpConverter, CeilDivUIOpConverter, FloorDivSIOpConverter>(
800  patterns.getContext());
801 }
802 
804  patterns.add<BFloat16ExtFOpConverter, BFloat16TruncFOpConverter>(
805  patterns.getContext());
806 }
807 
809  patterns.add<F4E2M1ExtFOpConverter, F4E2M1TruncFOpConverter>(
810  patterns.getContext());
811 }
812 
814  patterns.add<F8E8M0ExtFOpConverter, F8E8M0TruncFOpConverter>(
815  patterns.getContext());
816 }
817 
820  patterns.add<ScalingExtFOpConverter, ScalingTruncFOpConverter>(
821  patterns.getContext());
822 }
823 
827  // clang-format off
828  patterns.add<
829  MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
830  MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>,
831  MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>,
832  MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult>,
833  MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
834  MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>,
835  MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>,
836  MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
837  >(patterns.getContext());
838  // clang-format on
839 }
static Value createConst(Location loc, Type type, int value, PatternRewriter &rewriter)
Create an integer or index constant.
Definition: ExpandOps.cpp:27
static Type cloneToShapedType(Type cloneFrom, Type cloneTo)
Creates shapedType using shape from cloneFrom and base type from cloneTo.
Definition: ExpandOps.cpp:50
static Value createFloatConst(Location loc, Type type, APFloat value, PatternRewriter &rewriter)
Create a float constant.
Definition: ExpandOps.cpp:38
static int64_t product(ArrayRef< int64_t > vals)
static MLIRContext * getContext(OpFoldResult val)
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:223
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:249
This class describes a specific conversion target.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
Definition: Builders.h:621
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:76
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:769
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
bool isF32() const
Definition: Types.cpp:40
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
bool isBF16() const
Definition: Types.cpp:37
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
void populateExpandBFloat16Patterns(RewritePatternSet &patterns)
Add patterns to expand Arith bf16 patterns to lower level bitcasts/shifts.
Definition: ExpandOps.cpp:803
void populateExpandScalingExtTruncPatterns(RewritePatternSet &patterns)
Add patterns to expand scaling ExtF/TruncF ops to equivalent arith ops.
Definition: ExpandOps.cpp:818
void populateExpandF8E8M0Patterns(RewritePatternSet &patterns)
Add patterns to expand Arith f8e8m0 patterns to lower level bitcasts/shifts.
Definition: ExpandOps.cpp:813
void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ceil/floor division ops.
Definition: ExpandOps.cpp:796
void populateExpandF4E2M1Patterns(RewritePatternSet &patterns)
Add patterns to expand Arith f4e2m1 patterns to lower level bitcasts/shifts.
Definition: ExpandOps.cpp:808
void populateArithExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ops.
Definition: ExpandOps.cpp:824
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
Definition: Fraction.h:68
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.
Definition: PatternMatch.h:319