MLIR  21.0.0git
ExpandOps.cpp
Go to the documentation of this file.
1 //===- ExpandOps.cpp - Pass to legalize Arith ops for LLVM lowering --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
14 #include "mlir/IR/TypeUtilities.h"
16 
17 namespace mlir {
18 namespace arith {
19 #define GEN_PASS_DEF_ARITHEXPANDOPSPASS
20 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
21 } // namespace arith
22 } // namespace mlir
23 
24 using namespace mlir;
25 
26 /// Create an integer or index constant.
27 static Value createConst(Location loc, Type type, int value,
28  PatternRewriter &rewriter) {
29  auto attr = rewriter.getIntegerAttr(getElementTypeOrSelf(type), value);
30  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
31  return rewriter.create<arith::ConstantOp>(
32  loc, DenseElementsAttr::get(shapedTy, attr));
33  }
34 
35  return rewriter.create<arith::ConstantOp>(loc, attr);
36 }
37 
38 namespace {
39 
40 /// Expands CeilDivUIOp (n, m) into
41 /// n == 0 ? 0 : ((n-1) / m) + 1
42 struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> {
44  LogicalResult matchAndRewrite(arith::CeilDivUIOp op,
45  PatternRewriter &rewriter) const final {
46  Location loc = op.getLoc();
47  Value a = op.getLhs();
48  Value b = op.getRhs();
49  Value zero = createConst(loc, a.getType(), 0, rewriter);
50  Value compare =
51  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, a, zero);
52  Value one = createConst(loc, a.getType(), 1, rewriter);
53  Value minusOne = rewriter.create<arith::SubIOp>(loc, a, one);
54  Value quotient = rewriter.create<arith::DivUIOp>(loc, minusOne, b);
55  Value plusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
56  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compare, zero, plusOne);
57  return success();
58  }
59 };
60 
61 /// Expands CeilDivSIOp (a, b) into
62 /// z = a / b
63 /// if (z * b != a && (a < 0) == (b < 0)) {
64 /// return z + 1;
65 /// } else {
66 /// return z;
67 /// }
68 struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
70  LogicalResult matchAndRewrite(arith::CeilDivSIOp op,
71  PatternRewriter &rewriter) const final {
72  Location loc = op.getLoc();
73  Type type = op.getType();
74  Value a = op.getLhs();
75  Value b = op.getRhs();
76 
77  Value zero = createConst(loc, type, 0, rewriter);
78  Value one = createConst(loc, type, 1, rewriter);
79 
80  Value quotient = rewriter.create<arith::DivSIOp>(loc, a, b);
81  Value product = rewriter.create<arith::MulIOp>(loc, quotient, b);
82  Value notEqualDivisor = rewriter.create<arith::CmpIOp>(
83  loc, arith::CmpIPredicate::ne, a, product);
84 
85  Value aNeg =
86  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
87  Value bNeg =
88  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
89 
90  Value signEqual = rewriter.create<arith::CmpIOp>(
91  loc, arith::CmpIPredicate::eq, aNeg, bNeg);
92  Value cond =
93  rewriter.create<arith::AndIOp>(loc, notEqualDivisor, signEqual);
94 
95  Value quotientPlusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
96 
97  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientPlusOne,
98  quotient);
99  return success();
100  }
101 };
102 
103 /// Expands FloorDivSIOp (x, y) into
104 /// z = x / y
105 /// if (z * y != x && (x < 0) != (y < 0)) {
106 /// return z - 1;
107 /// } else {
108 /// return z;
109 /// }
110 struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
112  LogicalResult matchAndRewrite(arith::FloorDivSIOp op,
113  PatternRewriter &rewriter) const final {
114  Location loc = op.getLoc();
115  Type type = op.getType();
116  Value a = op.getLhs();
117  Value b = op.getRhs();
118 
119  Value quotient = rewriter.create<arith::DivSIOp>(loc, a, b);
120  Value product = rewriter.create<arith::MulIOp>(loc, quotient, b);
121  Value notEqualDivisor = rewriter.create<arith::CmpIOp>(
122  loc, arith::CmpIPredicate::ne, a, product);
123  Value zero = createConst(loc, type, 0, rewriter);
124 
125  Value aNeg =
126  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
127  Value bNeg =
128  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
129 
130  Value signOpposite = rewriter.create<arith::CmpIOp>(
131  loc, arith::CmpIPredicate::ne, aNeg, bNeg);
132  Value cond =
133  rewriter.create<arith::AndIOp>(loc, notEqualDivisor, signOpposite);
134 
135  Value minusOne = createConst(loc, type, -1, rewriter);
136  Value quotientMinusOne =
137  rewriter.create<arith::AddIOp>(loc, quotient, minusOne);
138 
139  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientMinusOne,
140  quotient);
141  return success();
142  }
143 };
144 
145 template <typename OpTy, arith::CmpIPredicate pred>
146 struct MaxMinIOpConverter : public OpRewritePattern<OpTy> {
147 public:
149 
150  LogicalResult matchAndRewrite(OpTy op,
151  PatternRewriter &rewriter) const final {
152  Value lhs = op.getLhs();
153  Value rhs = op.getRhs();
154 
155  Value cmp = rewriter.create<arith::CmpIOp>(op.getLoc(), pred, lhs, rhs);
156  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, lhs, rhs);
157  return success();
158  }
159 };
160 
161 template <typename OpTy, arith::CmpFPredicate pred>
162 struct MaximumMinimumFOpConverter : public OpRewritePattern<OpTy> {
163 public:
165 
166  LogicalResult matchAndRewrite(OpTy op,
167  PatternRewriter &rewriter) const final {
168  Value lhs = op.getLhs();
169  Value rhs = op.getRhs();
170 
171  Location loc = op.getLoc();
172  // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs').
173  static_assert(pred == arith::CmpFPredicate::UGT ||
174  pred == arith::CmpFPredicate::ULT,
175  "pred must be either UGT or ULT");
176  Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
177  Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
178 
179  // Handle the case where rhs is NaN: 'isNaN(rhs) ? rhs : select'.
180  Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
181  rhs, rhs);
182  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
183  return success();
184  }
185 };
186 
187 template <typename OpTy, arith::CmpFPredicate pred>
188 struct MaxNumMinNumFOpConverter : public OpRewritePattern<OpTy> {
189 public:
191 
192  LogicalResult matchAndRewrite(OpTy op,
193  PatternRewriter &rewriter) const final {
194  Value lhs = op.getLhs();
195  Value rhs = op.getRhs();
196 
197  Location loc = op.getLoc();
198  // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs').
199  static_assert(pred == arith::CmpFPredicate::UGT ||
200  pred == arith::CmpFPredicate::ULT,
201  "pred must be either UGT or ULT");
202  Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
203  Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
204 
205  // Handle the case where lhs is NaN: 'isNaN(lhs) ? rhs : select'.
206  Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
207  lhs, lhs);
208  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
209  return success();
210  }
211 };
212 
213 struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
215  LogicalResult matchAndRewrite(arith::ExtFOp op,
216  PatternRewriter &rewriter) const final {
217  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
218  auto operand = op.getOperand();
219  Type operandTy = operand.getType();
220  Type resultTy = op.getType();
221  Type operandETy = getElementTypeOrSelf(operandTy);
222  Type resultETy = getElementTypeOrSelf(resultTy);
223 
224  if (!operandETy.isBF16() || !resultETy.isF32()) {
225  return rewriter.notifyMatchFailure(op, "not a ext of bf16 to f32.");
226  }
227 
228  Type i16Ty = b.getI16Type();
229  Type i32Ty = b.getI32Type();
230  if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
231  i16Ty = shapedTy.clone(i16Ty);
232  i32Ty = shapedTy.clone(i32Ty);
233  }
234 
235  Value bitcast = b.create<arith::BitcastOp>(i16Ty, operand);
236  Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
237 
238  Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
239  Value shl = b.create<arith::ShLIOp>(exti, c16);
240  Value result = b.create<arith::BitcastOp>(resultTy, shl);
241 
242  rewriter.replaceOp(op, result);
243  return success();
244  }
245 };
246 
247 struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
249  LogicalResult matchAndRewrite(arith::TruncFOp op,
250  PatternRewriter &rewriter) const final {
251  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
252  auto operand = op.getOperand();
253  Type operandTy = operand.getType();
254  Type resultTy = op.getType();
255  Type operandETy = getElementTypeOrSelf(operandTy);
256  Type resultETy = getElementTypeOrSelf(resultTy);
257 
258  if (!operandETy.isF32() || !resultETy.isBF16()) {
259  return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");
260  }
261 
262  if (op.getRoundingmodeAttr()) {
263  return rewriter.notifyMatchFailure(
264  op, "only applicable to default rounding mode.");
265  }
266 
267  Type i16Ty = b.getI16Type();
268  Type i32Ty = b.getI32Type();
269  Type f32Ty = b.getF32Type();
270  if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
271  i16Ty = shapedTy.clone(i16Ty);
272  i32Ty = shapedTy.clone(i32Ty);
273  f32Ty = shapedTy.clone(f32Ty);
274  }
275 
276  // Algorithm borrowed from this excellent code:
277  // https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/c10/util/BFloat16.h#L60-L79
278  // There is a magic idea there, to let the addition of the rounding_bias to
279  // the mantissa simply overflow into the exponent bits. It's a bit of an
280  // aggressive, obfuscating optimization, but it is well-tested code, and it
281  // results in more concise and efficient IR.
282  // The case of NaN is handled separately (see isNaN and the final select).
283  // The case of infinities is NOT handled separately, which deserves an
284  // explanation. As the encoding of infinities has zero mantissa, the
285  // rounding-bias addition never carries into the exponent so that just gets
286  // truncated away, and as bfloat16 and float32 have the same number of
287  // exponent bits, that simple truncation is the desired outcome for
288  // infinities.
289  Value isNan =
290  b.create<arith::CmpFOp>(arith::CmpFPredicate::UNE, operand, operand);
291  // Constant used to make the rounding bias.
292  Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter);
293  // Constant used to generate a quiet NaN.
294  Value c7FC0_i16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
295  // Small constants used to address bits.
296  Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
297  Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter);
298  // Reinterpret the input f32 value as bits.
299  Value bitcast = b.create<arith::BitcastOp>(i32Ty, operand);
300  // Read bit 16 as a value in {0,1}.
301  Value bit16 =
302  b.create<arith::AndIOp>(b.create<arith::ShRUIOp>(bitcast, c16), c1);
303  // Determine the rounding bias to add as either 0x7fff or 0x8000 depending
304  // on bit 16, implementing the tie-breaking "to nearest even".
305  Value roundingBias = b.create<arith::AddIOp>(bit16, c7FFF);
306  // Add the rounding bias. Generally we want this to be added to the
307  // mantissa, but nothing prevents this to from carrying into the exponent
308  // bits, which would feel like a bug, but this is the magic trick here:
309  // when that happens, the mantissa gets reset to zero and the exponent
310  // gets incremented by the carry... which is actually exactly what we
311  // want.
312  Value biased = b.create<arith::AddIOp>(bitcast, roundingBias);
313  // Now that the rounding-bias has been added, truncating the low bits
314  // yields the correctly rounded result.
315  Value biasedAndShifted = b.create<arith::ShRUIOp>(biased, c16);
316  Value normalCaseResult_i16 =
317  b.create<arith::TruncIOp>(i16Ty, biasedAndShifted);
318  // Select either the above-computed result, or a quiet NaN constant
319  // if the input was NaN.
320  Value select =
321  b.create<arith::SelectOp>(isNan, c7FC0_i16, normalCaseResult_i16);
322  Value result = b.create<arith::BitcastOp>(resultTy, select);
323  rewriter.replaceOp(op, result);
324  return success();
325  }
326 };
327 
328 struct ArithExpandOpsPass
329  : public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
330  using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
331 
332  void runOnOperation() override {
334  ConversionTarget target(getContext());
335 
337 
338  target.addLegalDialect<arith::ArithDialect>();
339  // clang-format off
340  target.addIllegalOp<
341  arith::CeilDivSIOp,
342  arith::CeilDivUIOp,
343  arith::FloorDivSIOp,
344  arith::MaxSIOp,
345  arith::MaxUIOp,
346  arith::MinSIOp,
347  arith::MinUIOp,
348  arith::MaximumFOp,
349  arith::MinimumFOp,
350  arith::MaxNumFOp,
351  arith::MinNumFOp
352  >();
353 
354  if (includeBf16) {
356  target.addDynamicallyLegalOp<arith::ExtFOp>(
357  [](arith::ExtFOp op) {
358  Type inETy = getElementTypeOrSelf(op.getOperand().getType());
359  Type outETy = getElementTypeOrSelf(op.getType());
360  return !(inETy.isBF16() && outETy.isF32());
361  });
362 
363  target.addDynamicallyLegalOp<arith::TruncFOp>(
364  [](arith::TruncFOp op) {
365  Type inETy = getElementTypeOrSelf(op.getOperand().getType());
366  Type outETy = getElementTypeOrSelf(op.getType());
367  return !(inETy.isF32() && outETy.isBF16());
368  });
369  }
370 
371  // clang-format on
372  if (failed(applyPartialConversion(getOperation(), target,
373  std::move(patterns))))
374  signalPassFailure();
375  }
376 };
377 
378 } // namespace
379 
382  patterns
383  .add<CeilDivSIOpConverter, CeilDivUIOpConverter, FloorDivSIOpConverter>(
384  patterns.getContext());
385 }
386 
388  patterns.add<BFloat16ExtFOpConverter, BFloat16TruncFOpConverter>(
389  patterns.getContext());
390 }
391 
394  // clang-format off
395  patterns.add<
396  MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
397  MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>,
398  MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>,
399  MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult>,
400  MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
401  MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>,
402  MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>,
403  MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
404  >(patterns.getContext());
405  // clang-format on
406 }
static Value createConst(Location loc, Type type, int value, PatternRewriter &rewriter)
Create an integer or index constant.
Definition: ExpandOps.cpp:27
static int64_t product(ArrayRef< int64_t > vals)
static MLIRContext * getContext(OpFoldResult val)
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:224
This class describes a specific conversion target.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF32() const
Definition: Types.cpp:40
bool isBF16() const
Definition: Types.cpp:37
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
void populateExpandBFloat16Patterns(RewritePatternSet &patterns)
Add patterns to expand Arith bf16 patterns to lower level bitcasts/shifts.
Definition: ExpandOps.cpp:387
void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ceil/floor division ops.
Definition: ExpandOps.cpp:380
void populateArithExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ops.
Definition: ExpandOps.cpp:392
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
Definition: Fraction.h:68
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319