MLIR  22.0.0git
ExpandOps.cpp
Go to the documentation of this file.
1 //===- ExpandOps.cpp - Pass to legalize Arith ops for LLVM lowering --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 #include "mlir/IR/Location.h"
14 #include "mlir/IR/TypeUtilities.h"
16 
17 namespace mlir {
18 namespace arith {
19 #define GEN_PASS_DEF_ARITHEXPANDOPSPASS
20 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
21 } // namespace arith
22 } // namespace mlir
23 
24 using namespace mlir;
25 
26 /// Create an integer or index constant.
27 static Value createConst(Location loc, Type type, int value,
28  PatternRewriter &rewriter) {
29  auto attr = rewriter.getIntegerAttr(getElementTypeOrSelf(type), value);
30  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
31  return arith::ConstantOp::create(rewriter, loc,
32  DenseElementsAttr::get(shapedTy, attr));
33  }
34  return arith::ConstantOp::create(rewriter, loc, attr);
35 }
36 
37 /// Create a float constant.
38 static Value createFloatConst(Location loc, Type type, APFloat value,
39  PatternRewriter &rewriter) {
40  auto attr = rewriter.getFloatAttr(getElementTypeOrSelf(type), value);
41  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
42  return arith::ConstantOp::create(rewriter, loc,
43  DenseElementsAttr::get(shapedTy, attr));
44  }
45 
46  return arith::ConstantOp::create(rewriter, loc, attr);
47 }
48 
49 /// Creates shapedType using shape from cloneFrom and base type from cloneTo
50 static Type cloneToShapedType(Type cloneFrom, Type cloneTo) {
51  if (auto shapedTy = dyn_cast<ShapedType>(cloneFrom)) {
52  return shapedTy.clone(cloneTo);
53  }
54  return cloneTo;
55 }
56 
57 namespace {
58 
59 /// Expands CeilDivUIOp (n, m) into
60 /// n == 0 ? 0 : ((n-1) / m) + 1
61 struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> {
62  using Base::Base;
63  LogicalResult matchAndRewrite(arith::CeilDivUIOp op,
64  PatternRewriter &rewriter) const final {
65  Location loc = op.getLoc();
66  Value a = op.getLhs();
67  Value b = op.getRhs();
68  Value zero = createConst(loc, a.getType(), 0, rewriter);
69  Value compare =
70  arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, a, zero);
71  Value one = createConst(loc, a.getType(), 1, rewriter);
72  Value minusOne = arith::SubIOp::create(rewriter, loc, a, one);
73  Value quotient = arith::DivUIOp::create(rewriter, loc, minusOne, b);
74  Value plusOne = arith::AddIOp::create(rewriter, loc, quotient, one);
75  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compare, zero, plusOne);
76  return success();
77  }
78 };
79 
80 /// Expands CeilDivSIOp (a, b) into
81 /// z = a / b
82 /// if (z * b != a && (a < 0) == (b < 0)) {
83 /// return z + 1;
84 /// } else {
85 /// return z;
86 /// }
87 struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
88  using Base::Base;
89  LogicalResult matchAndRewrite(arith::CeilDivSIOp op,
90  PatternRewriter &rewriter) const final {
91  Location loc = op.getLoc();
92  Type type = op.getType();
93  Value a = op.getLhs();
94  Value b = op.getRhs();
95 
96  Value zero = createConst(loc, type, 0, rewriter);
97  Value one = createConst(loc, type, 1, rewriter);
98 
99  Value quotient = arith::DivSIOp::create(rewriter, loc, a, b);
100  Value product = arith::MulIOp::create(rewriter, loc, quotient, b);
101  Value notEqualDivisor = arith::CmpIOp::create(
102  rewriter, loc, arith::CmpIPredicate::ne, a, product);
103 
104  Value aNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
105  a, zero);
106  Value bNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
107  b, zero);
108 
109  Value signEqual = arith::CmpIOp::create(
110  rewriter, loc, arith::CmpIPredicate::eq, aNeg, bNeg);
111  Value cond =
112  arith::AndIOp::create(rewriter, loc, notEqualDivisor, signEqual);
113 
114  Value quotientPlusOne = arith::AddIOp::create(rewriter, loc, quotient, one);
115 
116  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientPlusOne,
117  quotient);
118  return success();
119  }
120 };
121 
122 /// Expands FloorDivSIOp (x, y) into
123 /// z = x / y
124 /// if (z * y != x && (x < 0) != (y < 0)) {
125 /// return z - 1;
126 /// } else {
127 /// return z;
128 /// }
129 struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
130  using Base::Base;
131  LogicalResult matchAndRewrite(arith::FloorDivSIOp op,
132  PatternRewriter &rewriter) const final {
133  Location loc = op.getLoc();
134  Type type = op.getType();
135  Value a = op.getLhs();
136  Value b = op.getRhs();
137 
138  Value quotient = arith::DivSIOp::create(rewriter, loc, a, b);
139  Value product = arith::MulIOp::create(rewriter, loc, quotient, b);
140  Value notEqualDivisor = arith::CmpIOp::create(
141  rewriter, loc, arith::CmpIPredicate::ne, a, product);
142  Value zero = createConst(loc, type, 0, rewriter);
143 
144  Value aNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
145  a, zero);
146  Value bNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
147  b, zero);
148 
149  Value signOpposite = arith::CmpIOp::create(
150  rewriter, loc, arith::CmpIPredicate::ne, aNeg, bNeg);
151  Value cond =
152  arith::AndIOp::create(rewriter, loc, notEqualDivisor, signOpposite);
153 
154  Value minusOne = createConst(loc, type, -1, rewriter);
155  Value quotientMinusOne =
156  arith::AddIOp::create(rewriter, loc, quotient, minusOne);
157 
158  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientMinusOne,
159  quotient);
160  return success();
161  }
162 };
163 
164 template <typename OpTy, arith::CmpIPredicate pred>
165 struct MaxMinIOpConverter : public OpRewritePattern<OpTy> {
166 public:
168 
169  LogicalResult matchAndRewrite(OpTy op,
170  PatternRewriter &rewriter) const final {
171  Value lhs = op.getLhs();
172  Value rhs = op.getRhs();
173 
174  Value cmp = arith::CmpIOp::create(rewriter, op.getLoc(), pred, lhs, rhs);
175  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, lhs, rhs);
176  return success();
177  }
178 };
179 
180 template <typename OpTy, arith::CmpFPredicate pred>
181 struct MaximumMinimumFOpConverter : public OpRewritePattern<OpTy> {
182 public:
184 
185  LogicalResult matchAndRewrite(OpTy op,
186  PatternRewriter &rewriter) const final {
187  Value lhs = op.getLhs();
188  Value rhs = op.getRhs();
189 
190  Location loc = op.getLoc();
191  // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs').
192  static_assert(pred == arith::CmpFPredicate::UGT ||
193  pred == arith::CmpFPredicate::ULT,
194  "pred must be either UGT or ULT");
195  Value cmp = arith::CmpFOp::create(rewriter, loc, pred, lhs, rhs);
196  Value select = arith::SelectOp::create(rewriter, loc, cmp, lhs, rhs);
197 
198  // Handle the case where rhs is NaN: 'isNaN(rhs) ? rhs : select'.
199  Value isNaN = arith::CmpFOp::create(rewriter, loc,
200  arith::CmpFPredicate::UNO, rhs, rhs);
201  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
202  return success();
203  }
204 };
205 
206 template <typename OpTy, arith::CmpFPredicate pred>
207 struct MaxNumMinNumFOpConverter : public OpRewritePattern<OpTy> {
208 public:
210 
211  LogicalResult matchAndRewrite(OpTy op,
212  PatternRewriter &rewriter) const final {
213  Value lhs = op.getLhs();
214  Value rhs = op.getRhs();
215 
216  Location loc = op.getLoc();
217  // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs').
218  static_assert(pred == arith::CmpFPredicate::UGT ||
219  pred == arith::CmpFPredicate::ULT,
220  "pred must be either UGT or ULT");
221  Value cmp = arith::CmpFOp::create(rewriter, loc, pred, lhs, rhs);
222  Value select = arith::SelectOp::create(rewriter, loc, cmp, lhs, rhs);
223 
224  // Handle the case where lhs is NaN: 'isNaN(lhs) ? rhs : select'.
225  Value isNaN = arith::CmpFOp::create(rewriter, loc,
226  arith::CmpFPredicate::UNO, lhs, lhs);
227  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
228  return success();
229  }
230 };
231 
232 struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
233  using Base::Base;
234  LogicalResult matchAndRewrite(arith::ExtFOp op,
235  PatternRewriter &rewriter) const final {
236  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
237  auto operand = op.getOperand();
238  Type operandTy = operand.getType();
239  Type resultTy = op.getType();
240  Type operandETy = getElementTypeOrSelf(operandTy);
241  Type resultETy = getElementTypeOrSelf(resultTy);
242 
243  if (!operandETy.isBF16() || !resultETy.isF32()) {
244  return rewriter.notifyMatchFailure(op, "not a ext of bf16 to f32.");
245  }
246 
247  Type i16Ty = cloneToShapedType(operandTy, b.getI16Type());
248  Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
249 
250  Value bitcast = arith::BitcastOp::create(b, i16Ty, operand);
251  Value exti = arith::ExtUIOp::create(b, i32Ty, bitcast);
252 
253  Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
254  Value shl = arith::ShLIOp::create(b, exti, c16);
255  Value result = arith::BitcastOp::create(b, resultTy, shl);
256 
257  rewriter.replaceOp(op, result);
258  return success();
259  }
260 };
261 
262 struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
263  using Base::Base;
264  LogicalResult matchAndRewrite(arith::TruncFOp op,
265  PatternRewriter &rewriter) const final {
266  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
267  auto operand = op.getOperand();
268  Type operandTy = operand.getType();
269  Type resultTy = op.getType();
270  Type operandETy = getElementTypeOrSelf(operandTy);
271  Type resultETy = getElementTypeOrSelf(resultTy);
272 
273  if (!operandETy.isF32() || !resultETy.isBF16()) {
274  return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");
275  }
276 
277  if (op.getRoundingmodeAttr()) {
278  return rewriter.notifyMatchFailure(
279  op, "only applicable to default rounding mode.");
280  }
281 
282  Type i16Ty = cloneToShapedType(operandTy, b.getI16Type());
283  Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
284 
285  // Algorithm borrowed from this excellent code:
286  // https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/c10/util/BFloat16.h#L60-L79
287  // There is a magic idea there, to let the addition of the rounding_bias to
288  // the mantissa simply overflow into the exponent bits. It's a bit of an
289  // aggressive, obfuscating optimization, but it is well-tested code, and it
290  // results in more concise and efficient IR.
291  // The case of NaN is handled separately (see isNaN and the final select).
292  // The case of infinities is NOT handled separately, which deserves an
293  // explanation. As the encoding of infinities has zero mantissa, the
294  // rounding-bias addition never carries into the exponent so that just gets
295  // truncated away, and as bfloat16 and float32 have the same number of
296  // exponent bits, that simple truncation is the desired outcome for
297  // infinities.
298  Value isNan =
299  arith::CmpFOp::create(b, arith::CmpFPredicate::UNE, operand, operand);
300  // Constant used to make the rounding bias.
301  Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter);
302  // Constant used to generate a quiet NaN.
303  Value c7FC0I16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
304  // Small constants used to address bits.
305  Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
306  Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter);
307  // Reinterpret the input f32 value as bits.
308  Value bitcast = arith::BitcastOp::create(b, i32Ty, operand);
309  // Read bit 16 as a value in {0,1}.
310  Value bit16 =
311  arith::AndIOp::create(b, arith::ShRUIOp::create(b, bitcast, c16), c1);
312  // Determine the rounding bias to add as either 0x7fff or 0x8000 depending
313  // on bit 16, implementing the tie-breaking "to nearest even".
314  Value roundingBias = arith::AddIOp::create(b, bit16, c7FFF);
315  // Add the rounding bias. Generally we want this to be added to the
316  // mantissa, but nothing prevents this to from carrying into the exponent
317  // bits, which would feel like a bug, but this is the magic trick here:
318  // when that happens, the mantissa gets reset to zero and the exponent
319  // gets incremented by the carry... which is actually exactly what we
320  // want.
321  Value biased = arith::AddIOp::create(b, bitcast, roundingBias);
322  // Now that the rounding-bias has been added, truncating the low bits
323  // yields the correctly rounded result.
324  Value biasedAndShifted = arith::ShRUIOp::create(b, biased, c16);
325  Value normalCaseResultI16 =
326  arith::TruncIOp::create(b, i16Ty, biasedAndShifted);
327  // Select either the above-computed result, or a quiet NaN constant
328  // if the input was NaN.
329  Value select =
330  arith::SelectOp::create(b, isNan, c7FC0I16, normalCaseResultI16);
331  Value result = arith::BitcastOp::create(b, resultTy, select);
332  rewriter.replaceOp(op, result);
333  return success();
334  }
335 };
336 
337 /// In this implementation of extf we take advantage of some key patterns we
338 /// notice between the binary representation of an F4E2M1 value and its
339 /// corresponding value in F32.
340 ///
341 /// Note: x is sign bit
342 /// | Binary | F4E2M1 | f32[23:32]
343 /// | x000 | 0.0 | x000 0000 00
344 /// | x001 | 0.5 | x011 1111 00
345 /// | x010 | 1.0 | x011 1111 10
346 /// | x011 | 1.5 | x011 1111 11
347 /// | x100 | 2.0 | x010 0000 00
348 /// | x101 | 3.0 | x010 0000 01
349 /// | x110 | 4.0 | x010 0000 10
350 /// | x111 | 6.0 | x010 0000 11
351 ///
352 /// 1) There are only two versions of bits [25:31] in the f32 result
353 /// F4E2M1 bits[2:3] decide whether:
354 /// - F32 bits[25:31] = 0011 1111
355 /// - F32 bits[25:31] = 0010 0000
356 /// Exception is zero where
357 /// - F32 bits[25:31] = 0000 0000
358 ///
359 /// 2) F4E2M1 bits[1:2] = F32 bits[23:24]
360 /// Exception is 0.5 where
361 /// - F4E2M1 bits[1:2] = 01, F32 bits[23:24] = 00
362 ///
363 /// 3) F4E2M1 bits[4] = F32 bits[32] (sign bits are equal)
364 ///
365 /// 4) F32 bits[1:22] = 0
366 struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
367  using Base::Base;
368  LogicalResult matchAndRewrite(arith::ExtFOp op,
369  PatternRewriter &rewriter) const final {
370  Location loc = op.getLoc();
371  ImplicitLocOpBuilder b(loc, rewriter);
372  Value operand = op.getOperand();
373  Type operandTy = operand.getType();
374  Type resultTy = op.getType();
375  Type operandETy = getElementTypeOrSelf(operandTy);
376  Type resultETy = getElementTypeOrSelf(resultTy);
377 
378  if (!isa<Float4E2M1FNType>(operandETy))
379  return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN");
380 
381  Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
382  Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
383  Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
384  Value i4Bits = arith::BitcastOp::create(b, i4Ty, operand);
385 
386  Value c0x0 = createConst(loc, i4Ty, 0x0, rewriter);
387  Value c0x1 = createConst(loc, i4Ty, 0x1, rewriter);
388  Value c0x2 = createConst(loc, i4Ty, 0x2, rewriter);
389  Value c0x4 = createConst(loc, i4Ty, 0x4, rewriter);
390  Value c0x7 = createConst(loc, i4Ty, 0x7, rewriter);
391 
392  Value i4BitsNoSign = arith::AndIOp::create(b, i4Bits, c0x7);
393 
394  // Set last Exponent bit and Mantissa.
395  Value c0x00000014 = createConst(loc, i32Ty, 0x14, rewriter);
396  Value bits1To24 = arith::ShLIOp::create(b, i4BitsNoSign, c0x2);
397  Value isHalf =
398  arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4BitsNoSign, c0x1);
399  bits1To24 = arith::SelectOp::create(b, isHalf, c0x0, bits1To24);
400  bits1To24 = arith::ExtUIOp::create(b, i32Ty, bits1To24);
401  bits1To24 = arith::ShLIOp::create(b, bits1To24, c0x00000014);
402 
403  // Set first 7 bits of Exponent.
404  Value zeroExpBits = createConst(loc, i32Ty, 0x00000000, rewriter);
405  Value highExpBits = createConst(loc, i32Ty, 0x40000000, rewriter);
406  Value lowExpBits = createConst(loc, i32Ty, 0x3f000000, rewriter);
407  Value useLargerExp =
408  arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4BitsNoSign, c0x4);
409  Value bits25To31 =
410  arith::SelectOp::create(b, useLargerExp, highExpBits, lowExpBits);
411  Value zeroExp =
412  arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4BitsNoSign, c0x0);
413  bits25To31 = arith::SelectOp::create(b, zeroExp, zeroExpBits, bits25To31);
414 
415  // Set sign.
416  Value c0x80000000 = createConst(loc, i32Ty, 0x80000000, rewriter);
417  Value c0x8 = createConst(loc, i4Ty, 0x8, rewriter);
418  Value negative =
419  arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4Bits, c0x8);
420  Value bit32 =
421  arith::SelectOp::create(b, negative, c0x80000000, zeroExpBits);
422 
423  // Add segments together.
424  Value bits1To31 = arith::AddIOp::create(b, bits1To24, bits25To31);
425  Value bits1To32 = arith::AddIOp::create(b, bits1To31, bit32);
426  Value result = arith::BitcastOp::create(b, f32Ty, bits1To32);
427  if (!isa<Float32Type>(resultETy))
428  result = arith::TruncFOp::create(b, resultTy, result);
429 
430  rewriter.replaceOp(op, result);
431  return success();
432  }
433 };
434 
435 struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
436  using Base::Base;
437  LogicalResult matchAndRewrite(arith::ExtFOp op,
438  PatternRewriter &rewriter) const final {
439  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
440  Value operand = op.getOperand();
441  Type operandTy = operand.getType();
442  Type resultTy = op.getType();
443  Type operandETy = getElementTypeOrSelf(operandTy);
444  Type resultETy = getElementTypeOrSelf(resultTy);
445 
446  if (!llvm::isa<Float8E8M0FNUType>(operandETy)) {
447  return rewriter.notifyMatchFailure(op, "not a ext of F8E8M0FNU");
448  }
449 
450  Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
451  Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
452  Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
453 
454  Value bitcast = arith::BitcastOp::create(b, i8Ty, operand);
455  // create constants for NaNs
456  Value cF8NaN = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
457  Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
458  Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
459 
460  Value exti = arith::ExtUIOp::create(b, i32Ty, bitcast);
461  Value f32Bits = arith::ShLIOp::create(b, exti, cF32MantissaWidth);
462 
463  Value isNan =
464  arith::CmpIOp::create(b, arith::CmpIPredicate::eq, bitcast, cF8NaN);
465  // select for NaNs
466  f32Bits = arith::SelectOp::create(b, isNan, cF32NaN, f32Bits);
467  Value result = arith::BitcastOp::create(b, f32Ty, f32Bits);
468  if (resultETy.getIntOrFloatBitWidth() < 32) {
469  result = arith::TruncFOp::create(b, resultTy, result, nullptr,
470  op.getFastmathAttr());
471  } else if (resultETy.getIntOrFloatBitWidth() > 32) {
472  result = arith::ExtFOp::create(b, resultTy, result, op.getFastmathAttr());
473  }
474  rewriter.replaceOp(op, result);
475  return success();
476  }
477 };
478 
479 /// Conversion from F32 to F4E2M1 according to the OCP Spec:
480 /// www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
481 ///
482 /// The spec requiers us to perform Round to Nearest, Ties to Even.
483 ///
484 /// This means that after rounding, we should break ties by choosing the option
485 /// which results in a mantissa of 0 in the least significant digit.
486 ///
487 /// Table of representable values in F4E2M1:
488 ///
489 /// Note: x is sign bit
490 /// | Binary | F4E2M1 | F32[23:32]
491 /// | x000 | 0.0 | x000 0000 00
492 /// | x001 | 0.5 | x011 1111 00
493 /// | x010 | 1.0 | x011 1111 10
494 /// | x011 | 1.5 | x011 1111 11
495 /// | x100 | 2.0 | x010 0000 00
496 /// | x101 | 3.0 | x010 0000 01
497 /// | x110 | 4.0 | x010 0000 10
498 /// | x111 | 6.0 | x010 0000 11
499 ///
500 /// Conversion procedure:
501 /// Step 1: Clamp to representable bounds.
502 /// Step 2: Convert exponent by adjusting bias.
503 /// Step 3: Set mantissa to first bit.
504 /// Step 4: Special consideration for subnormal and zero exponent.
505 /// Step 5: Round up if necessary, if mantissa[1:] greater than 1000000 or
506 /// subnormal.
507 struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
508  using Base::Base;
509  LogicalResult matchAndRewrite(arith::TruncFOp op,
510  PatternRewriter &rewriter) const final {
511  Location loc = op.getLoc();
512  ImplicitLocOpBuilder b(loc, rewriter);
513  Value operand = op.getOperand();
514  Type operandTy = operand.getType();
515  Type resultTy = op.getType();
516  Type operandETy = getElementTypeOrSelf(operandTy);
517  Type resultETy = getElementTypeOrSelf(resultTy);
518 
519  Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
520  Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
521  Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
522  Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
523 
524  if (!isa<Float4E2M1FNType>(resultETy))
525  return rewriter.notifyMatchFailure(op, "not a trunc of F4E2M1FN");
526  if (!isa<Float32Type>(operandETy))
527  operand = arith::ExtFOp::create(b, f32Ty, operand);
528 
529  Value c0x1 = createConst(loc, i4Ty, 1, rewriter);
530  Value c0x3 = createConst(loc, i4Ty, 3, rewriter);
531  Value c0x00000016 = createConst(loc, i32Ty, 22, rewriter);
532  Value c0x00 = createConst(loc, i8Ty, 0x00, rewriter);
533  Value c0xff = createConst(loc, i8Ty, 0xff, rewriter);
534  Value zeroExpBits = createConst(loc, i32Ty, 0, rewriter);
535 
536  // Step 0: Clamp to bounds.
537  Value cHigherBound = createFloatConst(loc, f32Ty, APFloat(6.0f), rewriter);
538  Value cLowerBound = createFloatConst(loc, f32Ty, APFloat(-6.0f), rewriter);
539  Value operandClamped = arith::MinNumFOp::create(b, cHigherBound, operand);
540  operandClamped = arith::MaxNumFOp::create(b, cLowerBound, operandClamped);
541  Value f32Bits = arith::BitcastOp::create(b, i32Ty, operandClamped);
542 
543  // Step 1: Set sign bit.
544  Value cF32ExpManWidth = createConst(loc, i32Ty, 31, rewriter); // 23
545  Value f32Sign = arith::ShRUIOp::create(b, f32Bits, cF32ExpManWidth);
546  Value f4Sign = arith::TruncIOp::create(b, i4Ty, f32Sign);
547  Value f4Bits = arith::ShLIOp::create(b, f4Sign, c0x3);
548 
549  // Step 2: Convert exponent by adjusting bias.
550  Value biasAdjustment = createConst(loc, i32Ty, 0x7e, rewriter);
551  Value cF4MantissaWidth = c0x1; // 1
552  Value cF32MantissaWidth = createConst(loc, i32Ty, 23, rewriter); // 23
553  Value f32SignExp = arith::ShRUIOp::create(b, f32Bits, cF32MantissaWidth);
554  Value biasAdjustedSignExp =
555  arith::SubIOp::create(b, f32SignExp, biasAdjustment);
556  Value f4Exp = arith::TruncIOp::create(b, i4Ty, biasAdjustedSignExp);
557  f4Exp = arith::ShLIOp::create(b, f4Exp, cF4MantissaWidth);
558  f4Bits = arith::AddIOp::create(b, f4Bits, f4Exp);
559 
560  // Step 3: Set mantissa to first bit.
561  Value cF32FirstBitMask = createConst(loc, i32Ty, 0x400000, rewriter);
562  Value man1Bit = arith::AndIOp::create(b, f32Bits, cF32FirstBitMask);
563  man1Bit = arith::ShRUIOp::create(b, man1Bit, c0x00000016);
564  Value f4Man = arith::TruncIOp::create(b, i4Ty, man1Bit);
565  f4Bits = arith::AddIOp::create(b, f4Bits, f4Man);
566 
567  // Step 4: Special consideration for conversion to 0.5.
568  Value cF32MantissaMask = createConst(loc, i32Ty, 0x7fffff, rewriter);
569  Value f8Exp = arith::TruncIOp::create(b, i8Ty, biasAdjustedSignExp);
570  Value isSubnormal =
571  arith::CmpIOp::create(b, arith::CmpIPredicate::sle, f8Exp, c0x00);
572  Value isNegOneExp =
573  arith::CmpIOp::create(b, arith::CmpIPredicate::eq, f8Exp, c0xff);
574  Value man23Bits = arith::AndIOp::create(b, f32Bits, cF32MantissaMask);
575  Value isNonZeroMan = arith::CmpIOp::create(b, arith::CmpIPredicate::ugt,
576  man23Bits, zeroExpBits);
577  Value roundToHalf = arith::AndIOp::create(b, isNegOneExp, isNonZeroMan);
578  Value isZeroExp =
579  arith::CmpIOp::create(b, arith::CmpIPredicate::eq, f8Exp, c0x00);
580  Value subnormalF4Bits = createConst(loc, i4Ty, 0xf, rewriter);
581  Value halfF4Bits = createConst(loc, i4Ty, 0x0, rewriter);
582  Value subResult =
583  arith::SelectOp::create(b, isSubnormal, subnormalF4Bits, f4Bits);
584  subResult = arith::SelectOp::create(b, roundToHalf, halfF4Bits, subResult);
585  f4Bits = arith::SelectOp::create(b, isZeroExp, f4Bits, subResult);
586 
587  // Step 5: Round up if necessary.
588  Value cF32Last22BitMask = createConst(loc, i32Ty, 0x3fffff, rewriter);
589  Value cRound = createConst(loc, i32Ty, 0x200000, rewriter); // 010 0000...
590  Value man22Bits = arith::AndIOp::create(b, f32Bits, cF32Last22BitMask);
591  Value shouldRound =
592  arith::CmpIOp::create(b, arith::CmpIPredicate::uge, man22Bits, cRound);
593  shouldRound = arith::OrIOp::create(b, shouldRound, isSubnormal);
594  Value roundedF4Bits = arith::AddIOp::create(b, f4Bits, c0x1);
595  f4Bits = arith::SelectOp::create(b, shouldRound, roundedF4Bits, f4Bits);
596 
597  Value result = arith::BitcastOp::create(b, resultTy, f4Bits);
598  rewriter.replaceOp(op, result);
599  return success();
600  }
601 };
602 
603 /*
604 TruncF to F8E8M0 is expected to extract exponent bits out of F32 type
605 Since All kinds of Infs and NaNs are mapped to same exponent bits in F32 type,
606 they all map to NaN in F8E8M0 Type.
607 */
608 struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
609  using Base::Base;
610  LogicalResult matchAndRewrite(arith::TruncFOp op,
611  PatternRewriter &rewriter) const final {
612  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
613  Value operand = op.getOperand();
614  Type operandTy = operand.getType();
615  Type operandETy = getElementTypeOrSelf(operandTy);
616  Type resultTy = op.getType();
617  Type resultETy = getElementTypeOrSelf(resultTy);
618  if (!llvm::isa<Float8E8M0FNUType>(resultETy)) {
619  return rewriter.notifyMatchFailure(op, "not a truncf to f8E8M0FNU");
620  }
621 
622  if (op.getRoundingmodeAttr()) {
623  return rewriter.notifyMatchFailure(
624  op, "only applicable to default rounding mode.");
625  }
626 
627  Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
628  Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
629  Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
630 
631  if (operandETy.getIntOrFloatBitWidth() < 32) {
632  operand = arith::ExtFOp::create(b, f32Ty, operand, op.getFastmathAttr());
633  } else if (operandETy.getIntOrFloatBitWidth() > 32) {
634  operand = arith::TruncFOp::create(
635  b, f32Ty, operand, op.getRoundingmodeAttr(), op.getFastmathAttr());
636  }
637  Value f32Bits = arith::BitcastOp::create(b, i32Ty, operand);
638  Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
639  Value f32SignExp = arith::ShRUIOp::create(b, f32Bits, cF32MantissaWidth);
640  Value exp8Bits = arith::TruncIOp::create(b, i8Ty, f32SignExp);
641  Value result = arith::BitcastOp::create(b, resultTy, exp8Bits);
642  rewriter.replaceOp(op, result);
643  return success();
644  }
645 };
646 
647 struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> {
648  using Base::Base;
649  LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
650  PatternRewriter &rewriter) const final {
651  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
652  Value inputOperand = op.getIn();
653  Value scaleOperand = op.getScale();
654  Type scaleTy = scaleOperand.getType();
655  Type scaleETy = getElementTypeOrSelf(scaleOperand);
656  // allow implicit exponent extraction from 16/32 bits floats
657  if (scaleETy.getIntOrFloatBitWidth() >= 16) {
658  scaleETy = b.getF8E8M0Type();
659  scaleTy = cloneToShapedType(scaleTy, scaleETy);
660  scaleOperand = arith::TruncFOp::create(b, scaleTy, scaleOperand, nullptr,
661  op.getFastmathAttr());
662  }
663  // Catch scale types like f8E5M2.
664  if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
665  return rewriter.notifyMatchFailure(
666  op, "scaling_extf is using scales of type which can not be converted "
667  "to f8E8M0FNU");
668  }
669  Type resultTy = op.getType();
670  // extf on scale will essentially create floating point number
671  // of type resulTy that is 2^scale and will also propagate NaNs
672  Value scaleExt =
673  arith::ExtFOp::create(b, resultTy, scaleOperand, op.getFastmathAttr());
674  Value inputExt =
675  arith::ExtFOp::create(b, resultTy, inputOperand, op.getFastmathAttr());
676  Value result =
677  arith::MulFOp::create(b, inputExt, scaleExt, op.getFastmathAttr());
678  rewriter.replaceOp(op, result);
679  return success();
680  }
681 };
682 
683 /*
684 Expands arith.ScalingTruncFOp(in, scale) into
685  scale = arith.truncf(scale) : scaleTy -> f8E8M0FNU
686  result = arith.truncf(in / (2^scale))
687  */
688 struct ScalingTruncFOpConverter
689  : public OpRewritePattern<arith::ScalingTruncFOp> {
690  using Base::Base;
691  LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
692  PatternRewriter &rewriter) const final {
693  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
694  Value inputOperand = op.getIn();
695  Value scaleOperand = op.getScale();
696  Type scaleTy = scaleOperand.getType();
697  Type scaleETy = getElementTypeOrSelf(scaleOperand);
698  // allow implicit exponent extraction from 16/32 bits floats
699  if (scaleETy.getIntOrFloatBitWidth() >= 16) {
700  scaleETy = b.getF8E8M0Type();
701  scaleTy = cloneToShapedType(scaleTy, scaleETy);
702  scaleOperand = arith::TruncFOp::create(b, scaleTy, scaleOperand, nullptr,
703  op.getFastmathAttr());
704  }
705  if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
706  return rewriter.notifyMatchFailure(
707  op, "scaling_truncf is using scales type which can not be converted "
708  "to f8E8M0FNU");
709  }
710  Type resultTy = op.getType();
711  Type inputTy = inputOperand.getType();
712  // this will create a floating point number of type
713  // inputTy that is 2^scale and will also propagate NaNs
714  scaleOperand =
715  arith::ExtFOp::create(b, inputTy, scaleOperand, op.getFastmathAttr());
716  Value result = arith::DivFOp::create(b, inputOperand, scaleOperand,
717  op.getFastmathAttr());
718  Value resultCast = arith::TruncFOp::create(
719  b, resultTy, result, op.getRoundingmodeAttr(), op.getFastmathAttr());
720  rewriter.replaceOp(op, resultCast);
721  return success();
722  }
723 };
724 
725 struct ArithExpandOpsPass
726  : public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
727  using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
728 
729  void runOnOperation() override {
731  ConversionTarget target(getContext());
732 
734 
735  target.addLegalDialect<arith::ArithDialect>();
736  target.addLegalDialect<vector::VectorDialect>();
737 
738  // clang-format off
739  target.addIllegalOp<
740  arith::CeilDivSIOp,
741  arith::CeilDivUIOp,
742  arith::FloorDivSIOp,
743  arith::MaxSIOp,
744  arith::MaxUIOp,
745  arith::MinSIOp,
746  arith::MinUIOp,
747  arith::MaximumFOp,
748  arith::MinimumFOp,
749  arith::MaxNumFOp,
750  arith::MinNumFOp,
751  arith::ScalingExtFOp,
752  arith::ScalingTruncFOp
753  >();
754 
755  if (includeBf16)
757  if (includeF8E8M0)
759  if (includeF4E2M1)
761 
762  target.addDynamicallyLegalOp<arith::ExtFOp>(
763  [=](arith::ExtFOp op) {
764  Type inETy = getElementTypeOrSelf(op.getOperand().getType());
765  Type outETy = getElementTypeOrSelf(op.getType());
766  bool legalTypes = true;
767  if (includeBf16)
768  legalTypes &= !(inETy.isBF16() && outETy.isF32());
769  if (includeF8E8M0)
770  legalTypes &= !llvm::isa<Float8E8M0FNUType>(inETy);
771  if (includeF4E2M1)
772  legalTypes &= !llvm::isa<Float4E2M1FNType>(inETy);
773  return legalTypes;
774  });
775 
776  target.addDynamicallyLegalOp<arith::TruncFOp>(
777  [=](arith::TruncFOp op) {
778  Type inETy = getElementTypeOrSelf(op.getOperand().getType());
779  Type outETy = getElementTypeOrSelf(op.getType());
780  bool legalTypes = true;
781  if (includeBf16)
782  legalTypes &= !(inETy.isF32() && outETy.isBF16());
783  if (includeF8E8M0)
784  legalTypes &= !(llvm::isa<Float8E8M0FNUType>(outETy));
785  if (includeF4E2M1)
786  legalTypes &= !llvm::isa<Float4E2M1FNType>(outETy);
787  return legalTypes;
788  });
789 
790  // clang-format on
791  if (failed(applyPartialConversion(getOperation(), target,
792  std::move(patterns))))
793  signalPassFailure();
794  }
795 };
796 
797 } // namespace
798 
801  patterns
802  .add<CeilDivSIOpConverter, CeilDivUIOpConverter, FloorDivSIOpConverter>(
803  patterns.getContext());
804 }
805 
807  patterns.add<BFloat16ExtFOpConverter, BFloat16TruncFOpConverter>(
808  patterns.getContext());
809 }
810 
812  patterns.add<F4E2M1ExtFOpConverter, F4E2M1TruncFOpConverter>(
813  patterns.getContext());
814 }
815 
817  patterns.add<F8E8M0ExtFOpConverter, F8E8M0TruncFOpConverter>(
818  patterns.getContext());
819 }
820 
823  patterns.add<ScalingExtFOpConverter, ScalingTruncFOpConverter>(
824  patterns.getContext());
825 }
826 
830  // clang-format off
831  patterns.add<
832  MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
833  MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>,
834  MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>,
835  MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult>,
836  MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
837  MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>,
838  MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>,
839  MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
840  >(patterns.getContext());
841  // clang-format on
842 }
static Value createConst(Location loc, Type type, int value, PatternRewriter &rewriter)
Create an integer or index constant.
Definition: ExpandOps.cpp:27
static Type cloneToShapedType(Type cloneFrom, Type cloneTo)
Creates shapedType using shape from cloneFrom and base type from cloneTo.
Definition: ExpandOps.cpp:50
static Value createFloatConst(Location loc, Type type, APFloat value, PatternRewriter &rewriter)
Create a float constant.
Definition: ExpandOps.cpp:38
static int64_t product(ArrayRef< int64_t > vals)
static MLIRContext * getContext(OpFoldResult val)
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:228
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:254
This class describes a specific conversion target.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
Definition: Builders.h:629
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:76
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:793
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
bool isF32() const
Definition: Types.cpp:40
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
bool isBF16() const
Definition: Types.cpp:37
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
void populateExpandBFloat16Patterns(RewritePatternSet &patterns)
Add patterns to expand Arith bf16 patterns to lower level bitcasts/shifts.
Definition: ExpandOps.cpp:806
void populateExpandScalingExtTruncPatterns(RewritePatternSet &patterns)
Add patterns to expand scaling ExtF/TruncF ops to equivalent arith ops.
Definition: ExpandOps.cpp:821
void populateExpandF8E8M0Patterns(RewritePatternSet &patterns)
Add patterns to expand Arith f8e8m0 patterns to lower level bitcasts/shifts.
Definition: ExpandOps.cpp:816
void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ceil/floor division ops.
Definition: ExpandOps.cpp:799
void populateExpandF4E2M1Patterns(RewritePatternSet &patterns)
Add patterns to expand Arith f4e2m1 patterns to lower level bitcasts/shifts.
Definition: ExpandOps.cpp:811
void populateArithExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ops.
Definition: ExpandOps.cpp:827
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
Definition: Fraction.h:68
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:314