MLIR  20.0.0git
ExpandOps.cpp
Go to the documentation of this file.
1 //===- ExpandOps.cpp - Pass to legalize Arith ops for LLVM lowering --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
14 #include "mlir/IR/TypeUtilities.h"
16 
// Pull in the TableGen-generated pass definition. With
// GEN_PASS_DEF_ARITHEXPANDOPSPASS defined, the include below provides the
// `arith::impl::ArithExpandOpsPassBase` CRTP base that ArithExpandOpsPass
// derives from further down in this file.
namespace mlir {
namespace arith {
#define GEN_PASS_DEF_ARITHEXPANDOPSPASS
#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
} // namespace arith
} // namespace mlir
23 
24 using namespace mlir;
25 
26 /// Create an integer or index constant.
27 static Value createConst(Location loc, Type type, int value,
28  PatternRewriter &rewriter) {
29  auto attr = rewriter.getIntegerAttr(getElementTypeOrSelf(type), value);
30  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
31  return rewriter.create<arith::ConstantOp>(
32  loc, DenseElementsAttr::get(shapedTy, attr));
33  }
34 
35  return rewriter.create<arith::ConstantOp>(loc, attr);
36 }
37 
38 namespace {
39 
40 /// Expands CeilDivUIOp (n, m) into
41 /// n == 0 ? 0 : ((n-1) / m) + 1
42 struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> {
44  LogicalResult matchAndRewrite(arith::CeilDivUIOp op,
45  PatternRewriter &rewriter) const final {
46  Location loc = op.getLoc();
47  Value a = op.getLhs();
48  Value b = op.getRhs();
49  Value zero = createConst(loc, a.getType(), 0, rewriter);
50  Value compare =
51  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, a, zero);
52  Value one = createConst(loc, a.getType(), 1, rewriter);
53  Value minusOne = rewriter.create<arith::SubIOp>(loc, a, one);
54  Value quotient = rewriter.create<arith::DivUIOp>(loc, minusOne, b);
55  Value plusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
56  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compare, zero, plusOne);
57  return success();
58  }
59 };
60 
61 /// Expands CeilDivSIOp (n, m) into
62 /// 1) x = (m > 0) ? -1 : 1
63 /// 2) (n*m>0) ? ((n+x) / m) + 1 : - (-n / m)
64 struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
66  LogicalResult matchAndRewrite(arith::CeilDivSIOp op,
67  PatternRewriter &rewriter) const final {
68  Location loc = op.getLoc();
69  Type type = op.getType();
70  Value a = op.getLhs();
71  Value b = op.getRhs();
72  Value plusOne = createConst(loc, type, 1, rewriter);
73  Value zero = createConst(loc, type, 0, rewriter);
74  Value minusOne = createConst(loc, type, -1, rewriter);
75  // Compute x = (b>0) ? -1 : 1.
76  Value compare =
77  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
78  Value x = rewriter.create<arith::SelectOp>(loc, compare, minusOne, plusOne);
79  // Compute positive res: 1 + ((x+a)/b).
80  Value xPlusA = rewriter.create<arith::AddIOp>(loc, x, a);
81  Value xPlusADivB = rewriter.create<arith::DivSIOp>(loc, xPlusA, b);
82  Value posRes = rewriter.create<arith::AddIOp>(loc, plusOne, xPlusADivB);
83  // Compute negative res: - ((-a)/b).
84  Value minusA = rewriter.create<arith::SubIOp>(loc, zero, a);
85  Value minusADivB = rewriter.create<arith::DivSIOp>(loc, minusA, b);
86  Value negRes = rewriter.create<arith::SubIOp>(loc, zero, minusADivB);
87  // Result is (a*b>0) ? pos result : neg result.
88  // Note, we want to avoid using a*b because of possible overflow.
89  // The case that matters are a>0, a==0, a<0, b>0 and b<0. We do
90  // not particuliarly care if a*b<0 is true or false when b is zero
91  // as this will result in an illegal divide. So `a*b<0` can be reformulated
92  // as `(a<0 && b<0) || (a>0 && b>0)' or `(a<0 && b<0) || (a>0 && b>=0)'.
93  // We pick the first expression here.
94  Value aNeg =
95  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
96  Value aPos =
97  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero);
98  Value bNeg =
99  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
100  Value bPos =
101  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
102  Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bNeg);
103  Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bPos);
104  Value compareRes =
105  rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm);
106  // Perform substitution and return success.
107  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compareRes, posRes,
108  negRes);
109  return success();
110  }
111 };
112 
113 /// Expands FloorDivSIOp (x, y) into
114 /// z = x / y
115 /// if (z * y != x && (x < 0) != (y < 0)) {
116 /// return z - 1;
117 /// } else {
118 /// return z;
119 /// }
120 struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
122  LogicalResult matchAndRewrite(arith::FloorDivSIOp op,
123  PatternRewriter &rewriter) const final {
124  Location loc = op.getLoc();
125  Type type = op.getType();
126  Value a = op.getLhs();
127  Value b = op.getRhs();
128 
129  Value quotient = rewriter.create<arith::DivSIOp>(loc, a, b);
130  Value product = rewriter.create<arith::MulIOp>(loc, quotient, b);
131  Value notEqualDivisor = rewriter.create<arith::CmpIOp>(
132  loc, arith::CmpIPredicate::ne, a, product);
133  Value zero = createConst(loc, type, 0, rewriter);
134 
135  Value aNeg =
136  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
137  Value bNeg =
138  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
139 
140  Value signOpposite = rewriter.create<arith::CmpIOp>(
141  loc, arith::CmpIPredicate::ne, aNeg, bNeg);
142  Value cond =
143  rewriter.create<arith::AndIOp>(loc, notEqualDivisor, signOpposite);
144 
145  Value minusOne = createConst(loc, type, -1, rewriter);
146  Value quotientMinusOne =
147  rewriter.create<arith::AddIOp>(loc, quotient, minusOne);
148 
149  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientMinusOne,
150  quotient);
151  return success();
152  }
153 };
154 
155 template <typename OpTy, arith::CmpIPredicate pred>
156 struct MaxMinIOpConverter : public OpRewritePattern<OpTy> {
157 public:
159 
160  LogicalResult matchAndRewrite(OpTy op,
161  PatternRewriter &rewriter) const final {
162  Value lhs = op.getLhs();
163  Value rhs = op.getRhs();
164 
165  Value cmp = rewriter.create<arith::CmpIOp>(op.getLoc(), pred, lhs, rhs);
166  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, lhs, rhs);
167  return success();
168  }
169 };
170 
171 template <typename OpTy, arith::CmpFPredicate pred>
172 struct MaximumMinimumFOpConverter : public OpRewritePattern<OpTy> {
173 public:
175 
176  LogicalResult matchAndRewrite(OpTy op,
177  PatternRewriter &rewriter) const final {
178  Value lhs = op.getLhs();
179  Value rhs = op.getRhs();
180 
181  Location loc = op.getLoc();
182  // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs').
183  static_assert(pred == arith::CmpFPredicate::UGT ||
184  pred == arith::CmpFPredicate::ULT,
185  "pred must be either UGT or ULT");
186  Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
187  Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
188 
189  // Handle the case where rhs is NaN: 'isNaN(rhs) ? rhs : select'.
190  Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
191  rhs, rhs);
192  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
193  return success();
194  }
195 };
196 
197 template <typename OpTy, arith::CmpFPredicate pred>
198 struct MaxNumMinNumFOpConverter : public OpRewritePattern<OpTy> {
199 public:
201 
202  LogicalResult matchAndRewrite(OpTy op,
203  PatternRewriter &rewriter) const final {
204  Value lhs = op.getLhs();
205  Value rhs = op.getRhs();
206 
207  Location loc = op.getLoc();
208  // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs').
209  static_assert(pred == arith::CmpFPredicate::UGT ||
210  pred == arith::CmpFPredicate::ULT,
211  "pred must be either UGT or ULT");
212  Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
213  Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
214 
215  // Handle the case where lhs is NaN: 'isNaN(lhs) ? rhs : select'.
216  Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
217  lhs, lhs);
218  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
219  return success();
220  }
221 };
222 
223 struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
225  LogicalResult matchAndRewrite(arith::ExtFOp op,
226  PatternRewriter &rewriter) const final {
227  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
228  auto operand = op.getOperand();
229  Type operandTy = operand.getType();
230  Type resultTy = op.getType();
231  Type operandETy = getElementTypeOrSelf(operandTy);
232  Type resultETy = getElementTypeOrSelf(resultTy);
233 
234  if (!operandETy.isBF16() || !resultETy.isF32()) {
235  return rewriter.notifyMatchFailure(op, "not a ext of bf16 to f32.");
236  }
237 
238  Type i16Ty = b.getI16Type();
239  Type i32Ty = b.getI32Type();
240  if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
241  i16Ty = shapedTy.clone(i16Ty);
242  i32Ty = shapedTy.clone(i32Ty);
243  }
244 
245  Value bitcast = b.create<arith::BitcastOp>(i16Ty, operand);
246  Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
247 
248  Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
249  Value shl = b.create<arith::ShLIOp>(exti, c16);
250  Value result = b.create<arith::BitcastOp>(resultTy, shl);
251 
252  rewriter.replaceOp(op, result);
253  return success();
254  }
255 };
256 
257 struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
259  LogicalResult matchAndRewrite(arith::TruncFOp op,
260  PatternRewriter &rewriter) const final {
261  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
262  auto operand = op.getOperand();
263  Type operandTy = operand.getType();
264  Type resultTy = op.getType();
265  Type operandETy = getElementTypeOrSelf(operandTy);
266  Type resultETy = getElementTypeOrSelf(resultTy);
267 
268  if (!operandETy.isF32() || !resultETy.isBF16()) {
269  return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");
270  }
271 
272  if (op.getRoundingmodeAttr()) {
273  return rewriter.notifyMatchFailure(
274  op, "only applicable to default rounding mode.");
275  }
276 
277  Type i16Ty = b.getI16Type();
278  Type i32Ty = b.getI32Type();
279  Type f32Ty = b.getF32Type();
280  if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
281  i16Ty = shapedTy.clone(i16Ty);
282  i32Ty = shapedTy.clone(i32Ty);
283  f32Ty = shapedTy.clone(f32Ty);
284  }
285 
286  // Algorithm borrowed from this excellent code:
287  // https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/c10/util/BFloat16.h#L60-L79
288  // There is a magic idea there, to let the addition of the rounding_bias to
289  // the mantissa simply overflow into the exponent bits. It's a bit of an
290  // aggressive, obfuscating optimization, but it is well-tested code, and it
291  // results in more concise and efficient IR.
292  // The case of NaN is handled separately (see isNaN and the final select).
293  // The case of infinities is NOT handled separately, which deserves an
294  // explanation. As the encoding of infinities has zero mantissa, the
295  // rounding-bias addition never carries into the exponent so that just gets
296  // truncated away, and as bfloat16 and float32 have the same number of
297  // exponent bits, that simple truncation is the desired outcome for
298  // infinities.
299  Value isNan =
300  b.create<arith::CmpFOp>(arith::CmpFPredicate::UNE, operand, operand);
301  // Constant used to make the rounding bias.
302  Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter);
303  // Constant used to generate a quiet NaN.
304  Value c7FC0_i16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
305  // Small constants used to address bits.
306  Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
307  Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter);
308  // Reinterpret the input f32 value as bits.
309  Value bitcast = b.create<arith::BitcastOp>(i32Ty, operand);
310  // Read bit 16 as a value in {0,1}.
311  Value bit16 =
312  b.create<arith::AndIOp>(b.create<arith::ShRUIOp>(bitcast, c16), c1);
313  // Determine the rounding bias to add as either 0x7fff or 0x8000 depending
314  // on bit 16, implementing the tie-breaking "to nearest even".
315  Value roundingBias = b.create<arith::AddIOp>(bit16, c7FFF);
316  // Add the rounding bias. Generally we want this to be added to the
317  // mantissa, but nothing prevents this to from carrying into the exponent
318  // bits, which would feel like a bug, but this is the magic trick here:
319  // when that happens, the mantissa gets reset to zero and the exponent
320  // gets incremented by the carry... which is actually exactly what we
321  // want.
322  Value biased = b.create<arith::AddIOp>(bitcast, roundingBias);
323  // Now that the rounding-bias has been added, truncating the low bits
324  // yields the correctly rounded result.
325  Value biasedAndShifted = b.create<arith::ShRUIOp>(biased, c16);
326  Value normalCaseResult_i16 =
327  b.create<arith::TruncIOp>(i16Ty, biasedAndShifted);
328  // Select either the above-computed result, or a quiet NaN constant
329  // if the input was NaN.
330  Value select =
331  b.create<arith::SelectOp>(isNan, c7FC0_i16, normalCaseResult_i16);
332  Value result = b.create<arith::BitcastOp>(resultTy, select);
333  rewriter.replaceOp(op, result);
334  return success();
335  }
336 };
337 
338 struct ArithExpandOpsPass
339  : public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
340  using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
341 
342  void runOnOperation() override {
343  RewritePatternSet patterns(&getContext());
344  ConversionTarget target(getContext());
345 
347 
348  target.addLegalDialect<arith::ArithDialect>();
349  // clang-format off
350  target.addIllegalOp<
351  arith::CeilDivSIOp,
352  arith::CeilDivUIOp,
353  arith::FloorDivSIOp,
354  arith::MaxSIOp,
355  arith::MaxUIOp,
356  arith::MinSIOp,
357  arith::MinUIOp,
358  arith::MaximumFOp,
359  arith::MinimumFOp,
360  arith::MaxNumFOp,
361  arith::MinNumFOp
362  >();
363 
364  if (includeBf16) {
366  target.addDynamicallyLegalOp<arith::ExtFOp>(
367  [](arith::ExtFOp op) {
368  Type inETy = getElementTypeOrSelf(op.getOperand().getType());
369  Type outETy = getElementTypeOrSelf(op.getType());
370  return !(inETy.isBF16() && outETy.isF32());
371  });
372 
373  target.addDynamicallyLegalOp<arith::TruncFOp>(
374  [](arith::TruncFOp op) {
375  Type inETy = getElementTypeOrSelf(op.getOperand().getType());
376  Type outETy = getElementTypeOrSelf(op.getType());
377  return !(inETy.isF32() && outETy.isBF16());
378  });
379  }
380 
381  // clang-format on
382  if (failed(applyPartialConversion(getOperation(), target,
383  std::move(patterns))))
384  signalPassFailure();
385  }
386 };
387 
388 } // namespace
389 
391  RewritePatternSet &patterns) {
392  patterns
393  .add<CeilDivSIOpConverter, CeilDivUIOpConverter, FloorDivSIOpConverter>(
394  patterns.getContext());
395 }
396 
398  patterns.add<BFloat16ExtFOpConverter, BFloat16TruncFOpConverter>(
399  patterns.getContext());
400 }
401 
404  // clang-format off
405  patterns.add<
406  MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
407  MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>,
408  MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>,
409  MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult>,
410  MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
411  MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>,
412  MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>,
413  MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
414  >(patterns.getContext());
415  // clang-format on
416 }
static Value createConst(Location loc, Type type, int value, PatternRewriter &rewriter)
Create an integer or index constant.
Definition: ExpandOps.cpp:27
static int64_t product(ArrayRef< int64_t > vals)
static MLIRContext * getContext(OpFoldResult val)
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:268
This class describes a specific conversion target.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
MLIRContext * getContext() const
Definition: PatternMatch.h:829
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF32() const
Definition: Types.cpp:59
bool isBF16() const
Definition: Types.cpp:56
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
void populateExpandBFloat16Patterns(RewritePatternSet &patterns)
Add patterns to expand Arith bf16 patterns to lower level bitcasts/shifts.
Definition: ExpandOps.cpp:397
void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ceil/floor division ops.
Definition: ExpandOps.cpp:390
void populateArithExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ops.
Definition: ExpandOps.cpp:402
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
Definition: Fraction.h:68
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362