1 //===- ExpandOps.cpp - Pass to legalize Arith ops for LLVM lowering --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arith/Transforms/Passes.h"
10 
11 #include "mlir/Dialect/Arith/IR/Arith.h"
12 #include "mlir/Dialect/Vector/IR/VectorOps.h"
13 #include "mlir/IR/ImplicitLocOpBuilder.h"
14 #include "mlir/IR/Location.h"
15 #include "mlir/IR/TypeUtilities.h"
16 #include "mlir/Transforms/DialectConversion.h"
17 #include "llvm/ADT/SmallVectorExtras.h"
18 #include <cstdint>
19 
20 namespace mlir {
21 namespace arith {
22 #define GEN_PASS_DEF_ARITHEXPANDOPSPASS
23 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
24 } // namespace arith
25 } // namespace mlir
26 
27 using namespace mlir;
28 
29 /// Create an integer or index constant.
30 static Value createConst(Location loc, Type type, int value,
31  PatternRewriter &rewriter) {
32  auto attr = rewriter.getIntegerAttr(getElementTypeOrSelf(type), value);
33  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
34  return rewriter.create<arith::ConstantOp>(
35  loc, DenseElementsAttr::get(shapedTy, attr));
36  }
37  return rewriter.create<arith::ConstantOp>(loc, attr);
38 }
39 
40 /// Create a float constant.
41 static Value createFloatConst(Location loc, Type type, APFloat value,
42  PatternRewriter &rewriter) {
43  auto attr = rewriter.getFloatAttr(getElementTypeOrSelf(type), value);
44  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
45  return rewriter.create<arith::ConstantOp>(
46  loc, DenseElementsAttr::get(shapedTy, attr));
47  }
48 
49  return rewriter.create<arith::ConstantOp>(loc, attr);
50 }
51 
52 /// Creates a shaped type with the shape of cloneFrom and the element type cloneTo; returns cloneTo unchanged when cloneFrom is not shaped.
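/// For example, cloneToShapedType(vector<4xf32>, i16) yields vector<4xi16>,
/// while cloneToShapedType(f32, i16) simply yields i16.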
53 static Type cloneToShapedType(Type cloneFrom, Type cloneTo) {
54  if (auto shapedTy = dyn_cast<ShapedType>(cloneFrom)) {
55  return shapedTy.clone(cloneTo);
56  }
57  return cloneTo;
58 }
59 
60 namespace {
61 
62 /// Expands CeilDivUIOp (n, m) into
63 /// n == 0 ? 0 : ((n-1) / m) + 1
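/// For example, n = 7, m = 3 gives ((7 - 1) / 3) + 1 = 3 = ceil(7 / 3), while
/// n = 0 takes the zero branch of the select and so avoids the unsigned wrap
/// in n - 1.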
64 struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> {
65  using OpRewritePattern<arith::CeilDivUIOp>::OpRewritePattern;
66  LogicalResult matchAndRewrite(arith::CeilDivUIOp op,
67  PatternRewriter &rewriter) const final {
68  Location loc = op.getLoc();
69  Value a = op.getLhs();
70  Value b = op.getRhs();
71  Value zero = createConst(loc, a.getType(), 0, rewriter);
72  Value compare =
73  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, a, zero);
74  Value one = createConst(loc, a.getType(), 1, rewriter);
75  Value minusOne = rewriter.create<arith::SubIOp>(loc, a, one);
76  Value quotient = rewriter.create<arith::DivUIOp>(loc, minusOne, b);
77  Value plusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
78  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compare, zero, plusOne);
79  return success();
80  }
81 };
82 
83 /// Expands CeilDivSIOp (a, b) into
84 /// z = a / b
85 /// if (z * b != a && (a < 0) == (b < 0)) {
86 /// return z + 1;
87 /// } else {
88 /// return z;
89 /// }
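/// For example, a = 7, b = 2 gives z = 3 with z * b = 6 != 7 and equal signs,
/// so the result is z + 1 = 4 = ceil(3.5); a = -7, b = 2 gives z = -3 with
/// differing signs, so the result stays z = -3 = ceil(-3.5).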
90 struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
91  using OpRewritePattern<arith::CeilDivSIOp>::OpRewritePattern;
92  LogicalResult matchAndRewrite(arith::CeilDivSIOp op,
93  PatternRewriter &rewriter) const final {
94  Location loc = op.getLoc();
95  Type type = op.getType();
96  Value a = op.getLhs();
97  Value b = op.getRhs();
98 
99  Value zero = createConst(loc, type, 0, rewriter);
100  Value one = createConst(loc, type, 1, rewriter);
101 
102  Value quotient = rewriter.create<arith::DivSIOp>(loc, a, b);
103  Value product = rewriter.create<arith::MulIOp>(loc, quotient, b);
104  Value notEqualDivisor = rewriter.create<arith::CmpIOp>(
105  loc, arith::CmpIPredicate::ne, a, product);
106 
107  Value aNeg =
108  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
109  Value bNeg =
110  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
111 
112  Value signEqual = rewriter.create<arith::CmpIOp>(
113  loc, arith::CmpIPredicate::eq, aNeg, bNeg);
114  Value cond =
115  rewriter.create<arith::AndIOp>(loc, notEqualDivisor, signEqual);
116 
117  Value quotientPlusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
118 
119  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientPlusOne,
120  quotient);
121  return success();
122  }
123 };
124 
125 /// Expands FloorDivSIOp (x, y) into
126 /// z = x / y
127 /// if (z * y != x && (x < 0) != (y < 0)) {
128 /// return z - 1;
129 /// } else {
130 /// return z;
131 /// }
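/// For example, x = -7, y = 2 gives z = -3 with z * y = -6 != -7 and
/// differing signs, so the result is z - 1 = -4 = floor(-3.5); x = 7, y = 2
/// keeps z = 3 = floor(3.5) since the signs match.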
132 struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
133  using OpRewritePattern<arith::FloorDivSIOp>::OpRewritePattern;
134  LogicalResult matchAndRewrite(arith::FloorDivSIOp op,
135  PatternRewriter &rewriter) const final {
136  Location loc = op.getLoc();
137  Type type = op.getType();
138  Value a = op.getLhs();
139  Value b = op.getRhs();
140 
141  Value quotient = rewriter.create<arith::DivSIOp>(loc, a, b);
142  Value product = rewriter.create<arith::MulIOp>(loc, quotient, b);
143  Value notEqualDivisor = rewriter.create<arith::CmpIOp>(
144  loc, arith::CmpIPredicate::ne, a, product);
145  Value zero = createConst(loc, type, 0, rewriter);
146 
147  Value aNeg =
148  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
149  Value bNeg =
150  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
151 
152  Value signOpposite = rewriter.create<arith::CmpIOp>(
153  loc, arith::CmpIPredicate::ne, aNeg, bNeg);
154  Value cond =
155  rewriter.create<arith::AndIOp>(loc, notEqualDivisor, signOpposite);
156 
157  Value minusOne = createConst(loc, type, -1, rewriter);
158  Value quotientMinusOne =
159  rewriter.create<arith::AddIOp>(loc, quotient, minusOne);
160 
161  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientMinusOne,
162  quotient);
163  return success();
164  }
165 };
166 
167 template <typename OpTy, arith::CmpIPredicate pred>
168 struct MaxMinIOpConverter : public OpRewritePattern<OpTy> {
169 public:
170  using OpRewritePattern<OpTy>::OpRewritePattern;
171 
172  LogicalResult matchAndRewrite(OpTy op,
173  PatternRewriter &rewriter) const final {
174  Value lhs = op.getLhs();
175  Value rhs = op.getRhs();
176 
177  Value cmp = rewriter.create<arith::CmpIOp>(op.getLoc(), pred, lhs, rhs);
178  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, lhs, rhs);
179  return success();
180  }
181 };
182 
183 template <typename OpTy, arith::CmpFPredicate pred>
184 struct MaximumMinimumFOpConverter : public OpRewritePattern<OpTy> {
185 public:
186  using OpRewritePattern<OpTy>::OpRewritePattern;
187 
188  LogicalResult matchAndRewrite(OpTy op,
189  PatternRewriter &rewriter) const final {
190  Value lhs = op.getLhs();
191  Value rhs = op.getRhs();
192 
193  Location loc = op.getLoc();
194  // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs').
195  static_assert(pred == arith::CmpFPredicate::UGT ||
196  pred == arith::CmpFPredicate::ULT,
197  "pred must be either UGT or ULT");
198  Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
199  Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
200 
201  // Handle the case where rhs is NaN: 'isNaN(rhs) ? rhs : select'.
202  Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
203  rhs, rhs);
204  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
205  return success();
206  }
207 };
208 
209 template <typename OpTy, arith::CmpFPredicate pred>
210 struct MaxNumMinNumFOpConverter : public OpRewritePattern<OpTy> {
211 public:
212  using OpRewritePattern<OpTy>::OpRewritePattern;
213 
214  LogicalResult matchAndRewrite(OpTy op,
215  PatternRewriter &rewriter) const final {
216  Value lhs = op.getLhs();
217  Value rhs = op.getRhs();
218 
219  Location loc = op.getLoc();
220  // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs').
221  static_assert(pred == arith::CmpFPredicate::UGT ||
222  pred == arith::CmpFPredicate::ULT,
223  "pred must be either UGT or ULT");
224  Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
225  Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
226 
227  // Handle the case where lhs is NaN: 'isNaN(lhs) ? rhs : select'.
228  Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
229  lhs, lhs);
230  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
231  return success();
232  }
233 };
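// Note the contrast between the two float converters above: the expansion of
// maximumf/minimumf returns the NaN operand (NaN propagates), while the
// expansion of maxnumf/minnumf returns the other, non-NaN operand. For
// example, maximumf(NaN, 1.0) lowers to NaN, whereas maxnumf(NaN, 1.0)
// lowers to 1.0; both behaviors follow directly from the selects generated
// above.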
234 
235 struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
236  using OpRewritePattern<arith::ExtFOp>::OpRewritePattern;
237  LogicalResult matchAndRewrite(arith::ExtFOp op,
238  PatternRewriter &rewriter) const final {
239  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
240  auto operand = op.getOperand();
241  Type operandTy = operand.getType();
242  Type resultTy = op.getType();
243  Type operandETy = getElementTypeOrSelf(operandTy);
244  Type resultETy = getElementTypeOrSelf(resultTy);
245 
246  if (!operandETy.isBF16() || !resultETy.isF32()) {
247  return rewriter.notifyMatchFailure(op, "not an extf of bf16 to f32.");
248  }
249 
250  Type i16Ty = cloneToShapedType(operandTy, b.getI16Type());
251  Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
252 
253  Value bitcast = b.create<arith::BitcastOp>(i16Ty, operand);
254  Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
255 
256  Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
257  Value shl = b.create<arith::ShLIOp>(exti, c16);
258  Value result = b.create<arith::BitcastOp>(resultTy, shl);
259 
260  rewriter.replaceOp(op, result);
261  return success();
262  }
263 };
264 
265 struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
266  using OpRewritePattern<arith::TruncFOp>::OpRewritePattern;
267  LogicalResult matchAndRewrite(arith::TruncFOp op,
268  PatternRewriter &rewriter) const final {
269  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
270  auto operand = op.getOperand();
271  Type operandTy = operand.getType();
272  Type resultTy = op.getType();
273  Type operandETy = getElementTypeOrSelf(operandTy);
274  Type resultETy = getElementTypeOrSelf(resultTy);
275 
276  if (!operandETy.isF32() || !resultETy.isBF16()) {
277  return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");
278  }
279 
280  if (op.getRoundingmodeAttr()) {
281  return rewriter.notifyMatchFailure(
282  op, "only applicable to default rounding mode.");
283  }
284 
285  Type i16Ty = cloneToShapedType(operandTy, b.getI16Type());
286  Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
287 
288  // Algorithm borrowed from this excellent code:
289  // https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/c10/util/BFloat16.h#L60-L79
290  // There is a magic idea there, to let the addition of the rounding_bias to
291  // the mantissa simply overflow into the exponent bits. It's a bit of an
292  // aggressive, obfuscating optimization, but it is well-tested code, and it
293  // results in more concise and efficient IR.
294  // The case of NaN is handled separately (see isNaN and the final select).
295  // The case of infinities is NOT handled separately, which deserves an
296  // explanation. As the encoding of infinities has zero mantissa, the
297  // rounding-bias addition never carries into the exponent so that just gets
298  // truncated away, and as bfloat16 and float32 have the same number of
299  // exponent bits, that simple truncation is the desired outcome for
300  // infinities.
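// Worked example of the rounding arithmetic below: for the f32 bit pattern
// 0x3F808000 (exactly halfway between two bf16 values), bit 16 is 0, so the
// bias is 0x7FFF and 0x3F808000 + 0x7FFF = 0x3F80FFFF, which truncates to
// the even bf16 0x3F80. For 0x3F818000, bit 16 is 1, the bias is 0x8000,
// 0x3F818000 + 0x8000 = 0x3F820000, and the result truncates to the even
// bf16 0x3F82.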
301  Value isNan =
302  b.create<arith::CmpFOp>(arith::CmpFPredicate::UNE, operand, operand);
303  // Constant used to make the rounding bias.
304  Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter);
305  // Constant used to generate a quiet NaN.
306  Value c7FC0I16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
307  // Small constants used to address bits.
308  Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
309  Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter);
310  // Reinterpret the input f32 value as bits.
311  Value bitcast = b.create<arith::BitcastOp>(i32Ty, operand);
312  // Read bit 16 as a value in {0,1}.
313  Value bit16 =
314  b.create<arith::AndIOp>(b.create<arith::ShRUIOp>(bitcast, c16), c1);
315  // Determine the rounding bias to add as either 0x7fff or 0x8000 depending
316  // on bit 16, implementing the tie-breaking "to nearest even".
317  Value roundingBias = b.create<arith::AddIOp>(bit16, c7FFF);
318  // Add the rounding bias. Generally we want this to be added to the
319  // mantissa, but nothing prevents it from carrying into the exponent
320  // bits, which might look like a bug, but that is the magic trick here:
321  // when that happens, the mantissa gets reset to zero and the exponent
322  // gets incremented by the carry... which is actually exactly what we
323  // want.
324  Value biased = b.create<arith::AddIOp>(bitcast, roundingBias);
325  // Now that the rounding-bias has been added, truncating the low bits
326  // yields the correctly rounded result.
327  Value biasedAndShifted = b.create<arith::ShRUIOp>(biased, c16);
328  Value normalCaseResultI16 =
329  b.create<arith::TruncIOp>(i16Ty, biasedAndShifted);
330  // Select either the above-computed result, or a quiet NaN constant
331  // if the input was NaN.
332  Value select =
333  b.create<arith::SelectOp>(isNan, c7FC0I16, normalCaseResultI16);
334  Value result = b.create<arith::BitcastOp>(resultTy, select);
335  rewriter.replaceOp(op, result);
336  return success();
337  }
338 };
339 
340 /// In this implementation of extf we take advantage of some key patterns we
341 /// notice between the binary representation of an F4E2M1 value and its
342 /// corresponding value in F32.
343 ///
344 /// Note: x is sign bit
345 /// | Binary | F4E2M1 | f32[23:32]
346 /// | x000 | 0.0 | x000 0000 00
347 /// | x001 | 0.5 | x011 1111 00
348 /// | x010 | 1.0 | x011 1111 10
349 /// | x011 | 1.5 | x011 1111 11
350 /// | x100 | 2.0 | x010 0000 00
351 /// | x101 | 3.0 | x010 0000 01
352 /// | x110 | 4.0 | x010 0000 10
353 /// | x111 | 6.0 | x010 0000 11
354 ///
355 /// 1) There are only two versions of bits [25:31] in the f32 result
356 /// F4E2M1 bits[2:3] decide whether:
357 /// - F32 bits[25:31] = 0011 1111
358 /// - F32 bits[25:31] = 0010 0000
359 /// Exception is zero where
360 /// - F32 bits[25:31] = 0000 0000
361 ///
362 /// 2) F4E2M1 bits[1:2] = F32 bits[23:24]
363 /// Exception is 0.5 where
364 /// - F4E2M1 bits[1:2] = 01, F32 bits[23:24] = 00
365 ///
366 /// 3) F4E2M1 bits[4] = F32 bits[32] (sign bits are equal)
367 ///
368 /// 4) F32 bits[1:22] = 0
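/// Worked example: for the nibble 0b0101 (3.0), the shifted low bits give
/// (0b0101 << 2) = 0b0100 in i4, which zero-extends and shifts left by 20 to
/// 0x400000; the exponent select picks 0x40000000 since 0b0101 >= 0b0100; the
/// sign contribution is 0. The sum is 0x40400000, which is exactly 3.0 in f32.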
369 struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
370  using OpRewritePattern<arith::ExtFOp>::OpRewritePattern;
371  LogicalResult matchAndRewrite(arith::ExtFOp op,
372  PatternRewriter &rewriter) const final {
373  Location loc = op.getLoc();
374  ImplicitLocOpBuilder b(loc, rewriter);
375  Value operand = op.getOperand();
376  Type operandTy = operand.getType();
377  Type resultTy = op.getType();
378  Type operandETy = getElementTypeOrSelf(operandTy);
379  Type resultETy = getElementTypeOrSelf(resultTy);
380 
381  if (!isa<Float4E2M1FNType>(operandETy))
382  return rewriter.notifyMatchFailure(op, "not an extf of F4E2M1FN");
383 
384  Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
385  Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
386  Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
387  Value i4Bits = b.create<arith::BitcastOp>(i4Ty, operand);
388 
389  Value c0x0 = createConst(loc, i4Ty, 0x0, rewriter);
390  Value c0x1 = createConst(loc, i4Ty, 0x1, rewriter);
391  Value c0x2 = createConst(loc, i4Ty, 0x2, rewriter);
392  Value c0x4 = createConst(loc, i4Ty, 0x4, rewriter);
393 
394  // Set last Exponent bit and Mantissa.
395  Value c0x00000014 = createConst(loc, i32Ty, 0x14, rewriter);
396  Value bits1To24 = b.create<arith::ShLIOp>(i4Bits, c0x2);
397  Value isHalf =
398  b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, i4Bits, c0x1);
399  bits1To24 = b.create<arith::SelectOp>(isHalf, c0x0, bits1To24);
400  bits1To24 = b.create<arith::ExtUIOp>(i32Ty, bits1To24);
401  bits1To24 = b.create<arith::ShLIOp>(bits1To24, c0x00000014);
402 
403  // Set first 7 bits of Exponent.
404  Value zeroExpBits = createConst(loc, i32Ty, 0x00000000, rewriter);
405  Value highExpBits = createConst(loc, i32Ty, 0x40000000, rewriter);
406  Value lowExpBits = createConst(loc, i32Ty, 0x3f000000, rewriter);
407  Value useLargerExp =
408  b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, i4Bits, c0x4);
409  Value bits25To31 =
410  b.create<arith::SelectOp>(useLargerExp, highExpBits, lowExpBits);
411  Value zeroExp =
412  b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, i4Bits, c0x0);
413  bits25To31 = b.create<arith::SelectOp>(zeroExp, zeroExpBits, bits25To31);
414 
415  // Set sign.
416  Value c0x80000000 = createConst(loc, i32Ty, 0x80000000, rewriter);
417  Value c0x8 = createConst(loc, i4Ty, 0x8, rewriter);
418  Value negative =
419  b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, i4Bits, c0x8);
420  Value bit32 = b.create<arith::SelectOp>(negative, c0x80000000, zeroExpBits);
421 
422  // Add segments together.
423  Value bits1To31 = b.create<arith::AddIOp>(bits1To24, bits25To31);
424  Value bits1To32 = b.create<arith::AddIOp>(bits1To31, bit32);
425  Value result = b.create<arith::BitcastOp>(f32Ty, bits1To32);
426  if (!isa<Float32Type>(resultETy))
427  result = b.create<arith::TruncFOp>(resultTy, result);
428 
429  rewriter.replaceOp(op, result);
430  return success();
431  }
432 };
433 
434 struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
435  using OpRewritePattern<arith::ExtFOp>::OpRewritePattern;
436  LogicalResult matchAndRewrite(arith::ExtFOp op,
437  PatternRewriter &rewriter) const final {
438  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
439  Value operand = op.getOperand();
440  Type operandTy = operand.getType();
441  Type resultTy = op.getType();
442  Type operandETy = getElementTypeOrSelf(operandTy);
443  Type resultETy = getElementTypeOrSelf(resultTy);
444 
445  if (!llvm::isa<Float8E8M0FNUType>(operandETy)) {
446  return rewriter.notifyMatchFailure(op, "not an extf of F8E8M0FNU");
447  }
448 
449  Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
450  Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
451  Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
452 
453  Value bitcast = b.create<arith::BitcastOp>(i8Ty, operand);
454  // create constants for NaNs
455  Value cF8NaN = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
456  Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
457  Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
458 
459  Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
460  Value f32Bits = b.create<arith::ShLIOp>(exti, cF32MantissaWidth);
461 
462  Value isNan =
463  b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
464  // select for NaNs
465  f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
466  Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
467  if (resultETy.getIntOrFloatBitWidth() < 32) {
468  result = b.create<arith::TruncFOp>(resultTy, result, nullptr,
469  op.getFastmathAttr());
470  } else if (resultETy.getIntOrFloatBitWidth() > 32) {
471  result = b.create<arith::ExtFOp>(resultTy, result, op.getFastmathAttr());
472  }
473  rewriter.replaceOp(op, result);
474  return success();
475  }
476 };
477 
478 /// Conversion from F32 to F4E2M1 according to the OCP Spec:
479 /// www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
480 ///
481 /// The spec requires us to perform Round to Nearest, Ties to Even.
482 ///
483 /// This means that we should break ties by choosing the option
484 /// whose mantissa has a 0 in the least significant bit.
485 ///
486 /// Table of representable values in F4E2M1:
487 ///
488 /// Note: x is sign bit
489 /// | Binary | F4E2M1 | F32[23:32]
490 /// | x000 | 0.0 | x000 0000 00
491 /// | x001 | 0.5 | x011 1111 00
492 /// | x010 | 1.0 | x011 1111 10
493 /// | x011 | 1.5 | x011 1111 11
494 /// | x100 | 2.0 | x010 0000 00
495 /// | x101 | 3.0 | x010 0000 01
496 /// | x110 | 4.0 | x010 0000 10
497 /// | x111 | 6.0 | x010 0000 11
498 ///
499 /// Conversion procedure:
500 /// Step 0: Clamp to representable bounds.
501 /// Step 1: Set the sign bit.
502 /// Step 2: Convert the exponent by adjusting the bias.
503 /// Step 3: Set the mantissa to its first bit.
504 /// Step 4: Special consideration for subnormal and zero exponent.
505 /// Step 5: Round up if necessary (mantissa[1:] > 1000000 or subnormal).
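/// Worked example: f32 1.75 (0x3FE00000) lies halfway between 1.5 and 2.0.
/// The sign is 0; the biased exponent 127 - 126 = 1 contributes 0b0010; the
/// leading mantissa bit adds 1, giving 0b0011 (1.5); the remaining 22
/// mantissa bits equal 0x200000, so Step 5 rounds up to 0b0100, i.e. 2.0,
/// as round-to-nearest-even requires.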
506 struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
507  using OpRewritePattern<arith::TruncFOp>::OpRewritePattern;
508  LogicalResult matchAndRewrite(arith::TruncFOp op,
509  PatternRewriter &rewriter) const final {
510  Location loc = op.getLoc();
511  ImplicitLocOpBuilder b(loc, rewriter);
512  Value operand = op.getOperand();
513  Type operandTy = operand.getType();
514  Type resultTy = op.getType();
515  Type operandETy = getElementTypeOrSelf(operandTy);
516  Type resultETy = getElementTypeOrSelf(resultTy);
517 
518  Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
519  Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
520  Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
521  Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
522 
523  if (!isa<Float32Type>(operandETy))
524  operand = b.create<arith::ExtFOp>(f32Ty, operand);
525  if (!isa<Float4E2M1FNType>(resultETy))
526  return rewriter.notifyMatchFailure(op, "not a trunc of F4E2M1FN");
527 
528  Value c0x1 = createConst(loc, i4Ty, 1, rewriter);
529  Value c0x3 = createConst(loc, i4Ty, 3, rewriter);
530  Value c0x00000016 = createConst(loc, i32Ty, 22, rewriter);
531  Value c0x00 = createConst(loc, i8Ty, 0x00, rewriter);
532  Value c0xff = createConst(loc, i8Ty, 0xff, rewriter);
533  Value zeroExpBits = createConst(loc, i32Ty, 0, rewriter);
534 
535  // Step 0: Clamp to bounds.
536  Value cHigherBound = createFloatConst(loc, f32Ty, APFloat(6.0f), rewriter);
537  Value cLowerBound = createFloatConst(loc, f32Ty, APFloat(-6.0f), rewriter);
538  Value operandClamped = b.create<arith::MinNumFOp>(cHigherBound, operand);
539  operandClamped = b.create<arith::MaxNumFOp>(cLowerBound, operandClamped);
540  Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operandClamped);
541 
542  // Step 1: Set sign bit.
543  Value cF32ExpManWidth = createConst(loc, i32Ty, 31, rewriter); // 8 + 23
544  Value f32Sign = b.create<arith::ShRUIOp>(f32Bits, cF32ExpManWidth);
545  Value f4Sign = b.create<arith::TruncIOp>(i4Ty, f32Sign);
546  Value f4Bits = b.create<arith::ShLIOp>(f4Sign, c0x3);
547 
548  // Step 2: Convert exponent by adjusting bias.
549  Value biasAdjustment = createConst(loc, i32Ty, 0x7e, rewriter);
550  Value cF4MantissaWidth = c0x1; // 1
551  Value cF32MantissaWidth = createConst(loc, i32Ty, 23, rewriter); // 23
552  Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
553  Value biasAdjustedSignExp =
554  b.create<arith::SubIOp>(f32SignExp, biasAdjustment);
555  Value f4Exp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedSignExp);
556  f4Exp = b.create<arith::ShLIOp>(f4Exp, cF4MantissaWidth);
557  f4Bits = b.create<arith::AddIOp>(f4Bits, f4Exp);
558 
559  // Step 3: Set mantissa to first bit.
560  Value cF32FirstBitMask = createConst(loc, i32Ty, 0x400000, rewriter);
561  Value man1Bit = b.create<arith::AndIOp>(f32Bits, cF32FirstBitMask);
562  man1Bit = b.create<arith::ShRUIOp>(man1Bit, c0x00000016);
563  Value f4Man = b.create<arith::TruncIOp>(i4Ty, man1Bit);
564  f4Bits = b.create<arith::AddIOp>(f4Bits, f4Man);
565 
566  // Step 4: Special consideration for conversion to 0.5.
567  Value cF32MantissaMask = createConst(loc, i32Ty, 0x7fffff, rewriter);
568  Value f8Exp = b.create<arith::TruncIOp>(i8Ty, biasAdjustedSignExp);
569  Value isSubnormal =
570  b.create<arith::CmpIOp>(arith::CmpIPredicate::sle, f8Exp, c0x00);
571  Value isNegOneExp =
572  b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0xff);
573  Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32MantissaMask);
574  Value isNonZeroMan = b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt,
575  man23Bits, zeroExpBits);
576  Value roundToHalf = b.create<arith::AndIOp>(isNegOneExp, isNonZeroMan);
577  Value isZeroExp =
578  b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
579  Value subnormalF4Bits = createConst(loc, i4Ty, 0xf, rewriter);
580  Value halfF4Bits = createConst(loc, i4Ty, 0x0, rewriter);
581  Value subResult =
582  b.create<arith::SelectOp>(isSubnormal, subnormalF4Bits, f4Bits);
583  subResult = b.create<arith::SelectOp>(roundToHalf, halfF4Bits, subResult);
584  f4Bits = b.create<arith::SelectOp>(isZeroExp, f4Bits, subResult);
585 
586  // Step 5: Round up if necessary.
587  Value cF32Last22BitMask = createConst(loc, i32Ty, 0x3fffff, rewriter);
588  Value cRound = createConst(loc, i32Ty, 0x200000, rewriter); // 010 0000...
589  Value man22Bits = b.create<arith::AndIOp>(f32Bits, cF32Last22BitMask);
590  Value shouldRound =
591  b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man22Bits, cRound);
592  shouldRound = b.create<arith::OrIOp>(shouldRound, isSubnormal);
593  Value roundedF4Bits = b.create<arith::AddIOp>(f4Bits, c0x1);
594  f4Bits = b.create<arith::SelectOp>(shouldRound, roundedF4Bits, f4Bits);
595 
596  Value result = b.create<arith::BitcastOp>(resultTy, f4Bits);
597  rewriter.replaceOp(op, result);
598  return success();
599  }
600 };
601 
602 /*
603 TruncF to F8E8M0 is expected to extract the exponent bits out of the F32 type.
604 Since all kinds of Infs and NaNs map to the same exponent bits in F32,
605 they all map to NaN in the F8E8M0 type.
606 */
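// For example, f32 1.0 (0x3F800000) truncates to the E8M0 byte 0x7F (unbiased
// exponent 0), while any f32 Inf or NaN has exponent bits 0xFF and therefore
// maps to the E8M0 NaN encoding.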
607 struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
608  using OpRewritePattern<arith::TruncFOp>::OpRewritePattern;
609  LogicalResult matchAndRewrite(arith::TruncFOp op,
610  PatternRewriter &rewriter) const final {
611  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
612  Value operand = op.getOperand();
613  Type operandTy = operand.getType();
614  Type operandETy = getElementTypeOrSelf(operandTy);
615  Type resultTy = op.getType();
616  Type resultETy = getElementTypeOrSelf(resultTy);
617  if (!llvm::isa<Float8E8M0FNUType>(resultETy)) {
618  return rewriter.notifyMatchFailure(op, "not a truncf to f8E8M0FNU");
619  }
620 
621  if (op.getRoundingmodeAttr()) {
622  return rewriter.notifyMatchFailure(
623  op, "only applicable to default rounding mode.");
624  }
625 
626  Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
627  Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
628  Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
629 
630  if (operandETy.getIntOrFloatBitWidth() < 32) {
631  operand = b.create<arith::ExtFOp>(f32Ty, operand, op.getFastmathAttr());
632  } else if (operandETy.getIntOrFloatBitWidth() > 32) {
633  operand = b.create<arith::TruncFOp>(
634  f32Ty, operand, op.getRoundingmodeAttr(), op.getFastmathAttr());
635  }
636  Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operand);
637  Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
638  Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
639  Value exp8Bits = b.create<arith::TruncIOp>(i8Ty, f32SignExp);
640  Value result = b.create<arith::BitcastOp>(resultTy, exp8Bits);
641  rewriter.replaceOp(op, result);
642  return success();
643  }
644 };
645 
646 struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> {
647  using OpRewritePattern<arith::ScalingExtFOp>::OpRewritePattern;
648  LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
649  PatternRewriter &rewriter) const final {
650  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
651  Value inputOperand = op.getIn();
652  Value scaleOperand = op.getScale();
653  Type scaleTy = scaleOperand.getType();
654  Type scaleETy = getElementTypeOrSelf(scaleOperand);
655  // Allow implicit exponent extraction from 16/32-bit floats.
656  if (scaleETy.getIntOrFloatBitWidth() >= 16) {
657  scaleETy = b.getF8E8M0Type();
658  scaleTy = cloneToShapedType(scaleTy, scaleETy);
659  scaleOperand = b.create<arith::TruncFOp>(scaleTy, scaleOperand, nullptr,
660  op.getFastmathAttr());
661  }
662  if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
663  return rewriter.notifyMatchFailure(
664  op, "scaling_extf is using a scale type which cannot be converted "
665  "to f8E8M0FNU");
666  }
667  Type resultTy = op.getType();
668  // extf on the scale operand essentially creates a floating-point number
669  // of type resultTy that is 2^scale and also propagates NaNs.
670  Value scaleExt =
671  b.create<arith::ExtFOp>(resultTy, scaleOperand, op.getFastmathAttr());
672  Value inputExt =
673  b.create<arith::ExtFOp>(resultTy, inputOperand, op.getFastmathAttr());
674  Value result =
675  b.create<arith::MulFOp>(inputExt, scaleExt, op.getFastmathAttr());
676  rewriter.replaceOp(op, result);
677  return success();
678  }
679 };
680 
681 /*
682 Expands arith.ScalingTruncFOp(in, scale) into
683  scale = arith.truncf(scale) : scaleTy -> f8E8M0FNU
684  result = arith.truncf(in / (2^scale))
685  */
686 struct ScalingTruncFOpConverter
687  : public OpRewritePattern<arith::ScalingTruncFOp> {
688  using OpRewritePattern<arith::ScalingTruncFOp>::OpRewritePattern;
689  LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
690  PatternRewriter &rewriter) const final {
691  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
692  Value inputOperand = op.getIn();
693  Value scaleOperand = op.getScale();
694  Type scaleTy = scaleOperand.getType();
695  Type scaleETy = getElementTypeOrSelf(scaleOperand);
696  // Allow implicit exponent extraction from 16/32-bit floats.
697  if (scaleETy.getIntOrFloatBitWidth() >= 16) {
698  scaleETy = b.getF8E8M0Type();
699  scaleTy = cloneToShapedType(scaleTy, scaleETy);
700  scaleOperand = b.create<arith::TruncFOp>(scaleTy, scaleOperand, nullptr,
701  op.getFastmathAttr());
702  }
703  if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
704  return rewriter.notifyMatchFailure(
705  op, "scaling_truncf is using a scale type which cannot be converted "
706  "to f8E8M0FNU");
707  }
708  Type resultTy = op.getType();
709  Type inputTy = inputOperand.getType();
710  // This will create a floating-point number of type
711  // inputTy that is 2^scale and will also propagate NaNs.
712  scaleOperand =
713  b.create<arith::ExtFOp>(inputTy, scaleOperand, op.getFastmathAttr());
714  Value result = b.create<arith::DivFOp>(inputOperand, scaleOperand,
715  op.getFastmathAttr());
716  Value resultCast = b.create<arith::TruncFOp>(
717  resultTy, result, op.getRoundingmodeAttr(), op.getFastmathAttr());
718  rewriter.replaceOp(op, resultCast);
719  return success();
720  }
721 };
722 
723 struct ArithExpandOpsPass
724  : public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
725  using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
726 
727  void runOnOperation() override {
728  RewritePatternSet patterns(&getContext());
729  ConversionTarget target(getContext());
730 
731  arith::populateArithExpandOpsPatterns(patterns);
732 
733  target.addLegalDialect<arith::ArithDialect>();
734  target.addLegalDialect<vector::VectorDialect>();
735 
736  // clang-format off
737  target.addIllegalOp<
738  arith::CeilDivSIOp,
739  arith::CeilDivUIOp,
740  arith::FloorDivSIOp,
741  arith::MaxSIOp,
742  arith::MaxUIOp,
743  arith::MinSIOp,
744  arith::MinUIOp,
745  arith::MaximumFOp,
746  arith::MinimumFOp,
747  arith::MaxNumFOp,
748  arith::MinNumFOp,
749  arith::ScalingExtFOp,
750  arith::ScalingTruncFOp
751  >();
752 
753  if (includeBf16)
754  arith::populateExpandBFloat16Patterns(patterns);
755  if (includeF8E8M0)
756  arith::populateExpandF8E8M0Patterns(patterns);
757  if (includeF4E2M1)
758  arith::populateExpandF4E2M1Patterns(patterns);
759 
760  target.addDynamicallyLegalOp<arith::ExtFOp>(
761  [=](arith::ExtFOp op) {
762  Type inETy = getElementTypeOrSelf(op.getOperand().getType());
763  Type outETy = getElementTypeOrSelf(op.getType());
764  bool legalTypes = true;
765  if (includeBf16)
766  legalTypes &= !(inETy.isBF16() && outETy.isF32());
767  if (includeF8E8M0)
768  legalTypes &= !llvm::isa<Float8E8M0FNUType>(inETy);
769  if (includeF4E2M1)
770  legalTypes &= !llvm::isa<Float4E2M1FNType>(inETy);
771  return legalTypes;
772  });
773 
774  target.addDynamicallyLegalOp<arith::TruncFOp>(
775  [=](arith::TruncFOp op) {
776  Type inETy = getElementTypeOrSelf(op.getOperand().getType());
777  Type outETy = getElementTypeOrSelf(op.getType());
778  bool legalTypes = true;
779  if (includeBf16)
780  legalTypes &= !(inETy.isF32() && outETy.isBF16());
781  if (includeF8E8M0)
782  legalTypes &= !(llvm::isa<Float8E8M0FNUType>(outETy));
783  if (includeF4E2M1)
784  legalTypes &= !llvm::isa<Float4E2M1FNType>(outETy);
785  return legalTypes;
786  });
787 
788  // clang-format on
789  if (failed(applyPartialConversion(getOperation(), target,
790  std::move(patterns))))
791  signalPassFailure();
792  }
793 };
794 
795 } // namespace
796 
797 void mlir::arith::populateCeilFloorDivExpandOpsPatterns(
798  RewritePatternSet &patterns) {
799  patterns
800  .add<CeilDivSIOpConverter, CeilDivUIOpConverter, FloorDivSIOpConverter>(
801  patterns.getContext());
802 }
803 
804 void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) {
805  patterns.add<BFloat16ExtFOpConverter, BFloat16TruncFOpConverter>(
806  patterns.getContext());
807 }
808 
809 void mlir::arith::populateExpandF4E2M1Patterns(RewritePatternSet &patterns) {
810  patterns.add<F4E2M1ExtFOpConverter, F4E2M1TruncFOpConverter>(
811  patterns.getContext());
812 }
813 
814 void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) {
815  patterns.add<F8E8M0ExtFOpConverter, F8E8M0TruncFOpConverter>(
816  patterns.getContext());
817 }
818 
819 void mlir::arith::populateExpandScalingExtTruncPatterns(
820  RewritePatternSet &patterns) {
821  patterns.add<ScalingExtFOpConverter, ScalingTruncFOpConverter>(
822  patterns.getContext());
823 }
824 
825 void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
826  populateCeilFloorDivExpandOpsPatterns(patterns);
827  populateExpandScalingExtTruncPatterns(patterns);
828  // clang-format off
829  patterns.add<
830  MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
831  MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>,
832  MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>,
833  MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult>,
834  MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
835  MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>,
836  MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>,
837  MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
838  >(patterns.getContext());
839  // clang-format on
840 }
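// Example of reusing these entry points outside the pass (a sketch; `ctx` and
// `someOp` stand in for the caller's MLIRContext and root operation): a
// downstream pass can populate only the expansions it needs and drive them
// with its own conversion target, mirroring runOnOperation() above:
//
//   RewritePatternSet patterns(ctx);
//   arith::populateCeilFloorDivExpandOpsPatterns(patterns);
//   ConversionTarget target(*ctx);
//   target.addLegalDialect<arith::ArithDialect>();
//   target.addIllegalOp<arith::CeilDivSIOp, arith::CeilDivUIOp,
//                       arith::FloorDivSIOp>();
//   if (failed(applyPartialConversion(someOp, target, std::move(patterns))))
//     signalPassFailure();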