MLIR  18.0.0git
ExpandOps.cpp
Go to the documentation of this file.
1 //===- ExpandOps.cpp - Pass to legalize Arith ops for LLVM lowering --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
14 #include "mlir/IR/TypeUtilities.h"
16 
17 namespace mlir {
18 namespace arith {
19 #define GEN_PASS_DEF_ARITHEXPANDOPSPASS
20 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
21 } // namespace arith
22 } // namespace mlir
23 
24 using namespace mlir;
25 
26 /// Create an integer or index constant.
/// Builds an `arith.constant` of `type` holding `value`. The integer
/// attribute is created for the element type of `type`; when `type` is a
/// ShapedType the constant is a DenseElementsAttr built from that single
/// element attribute, otherwise the attribute is used directly.
27 static Value createConst(Location loc, Type type, int value,
28  PatternRewriter &rewriter) {
29  auto attr = rewriter.getIntegerAttr(getElementTypeOrSelf(type), value);
30  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
31  return rewriter.create<arith::ConstantOp>(
32  loc, DenseElementsAttr::get(shapedTy, attr));
33  }
34
35  return rewriter.create<arith::ConstantOp>(loc, attr);
36 }
37 
38 namespace {
39 
40 /// Expands CeilDivUIOp (n, m) into
41 /// n == 0 ? 0 : ((n-1) / m) + 1
/// The `n == 0` case is selected separately because `n - 1` would wrap
/// around for an unsigned zero.
42 struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> {
44  LogicalResult matchAndRewrite(arith::CeilDivUIOp op,
45  PatternRewriter &rewriter) const final {
46  Location loc = op.getLoc();
47  Value a = op.getLhs();
48  Value b = op.getRhs();
49  Value zero = createConst(loc, a.getType(), 0, rewriter);
 // compare = (a == 0).
50  Value compare =
51  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, a, zero);
52  Value one = createConst(loc, a.getType(), 1, rewriter);
 // plusOne = ((a - 1) /u b) + 1.
53  Value minusOne = rewriter.create<arith::SubIOp>(loc, a, one);
54  Value quotient = rewriter.create<arith::DivUIOp>(loc, minusOne, b);
55  Value plusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
 // Result: (a == 0) ? 0 : plusOne.
56  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compare, zero, plusOne);
57  return success();
58  }
59 };
60 
61 /// Expands CeilDivSIOp (n, m) into
62 /// 1) x = (m > 0) ? -1 : 1
63 /// 2) (n*m>0) ? ((n+x) / m) + 1 : - (-n / m)
/// Both candidate results are computed unconditionally and the final value is
/// picked with a select, so no control flow is introduced.
64 struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
66  LogicalResult matchAndRewrite(arith::CeilDivSIOp op,
67  PatternRewriter &rewriter) const final {
68  Location loc = op.getLoc();
69  Type type = op.getType();
70  Value a = op.getLhs();
71  Value b = op.getRhs();
72  Value plusOne = createConst(loc, type, 1, rewriter);
73  Value zero = createConst(loc, type, 0, rewriter);
74  Value minusOne = createConst(loc, type, -1, rewriter);
75  // Compute x = (b>0) ? -1 : 1.
76  Value compare =
77  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
78  Value x = rewriter.create<arith::SelectOp>(loc, compare, minusOne, plusOne);
79  // Compute positive res: 1 + ((x+a)/b).
80  Value xPlusA = rewriter.create<arith::AddIOp>(loc, x, a);
81  Value xPlusADivB = rewriter.create<arith::DivSIOp>(loc, xPlusA, b);
82  Value posRes = rewriter.create<arith::AddIOp>(loc, plusOne, xPlusADivB);
83  // Compute negative res: - ((-a)/b).
 // NOTE(review): `0 - a` wraps when `a` is the minimum signed value, and this
 // branch is the one selected for a<0, b>0 (e.g. b == 2), so the result looks
 // wrong for that input - confirm and consider an overflow-free formulation.
84  Value minusA = rewriter.create<arith::SubIOp>(loc, zero, a);
85  Value minusADivB = rewriter.create<arith::DivSIOp>(loc, minusA, b);
86  Value negRes = rewriter.create<arith::SubIOp>(loc, zero, minusADivB);
87  // Result is (a*b>0) ? pos result : neg result.
88  // Note, we want to avoid using a*b because of possible overflow.
89  // The cases that matter are a>0, a==0, a<0, b>0 and b<0. We do
90  // not particularly care if a*b>0 is true or false when b is zero
91  // as this will result in an illegal divide. So `a*b>0` can be reformulated
92  // as `(a<0 && b<0) || (a>0 && b>0)` or `(a<0 && b<0) || (a>0 && b>=0)`.
93  // We pick the first expression here.
94  Value aNeg =
95  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
96  Value aPos =
97  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero);
98  Value bNeg =
99  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
100  Value bPos =
101  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
102  Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bNeg);
103  Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bPos);
104  Value compareRes =
105  rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm);
106  // Perform substitution and return success.
107  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compareRes, posRes,
108  negRes);
109  return success();
110  }
111 };
112 
113 /// Expands FloorDivSIOp (n, m) into
114 /// 1) x = (m<0) ? 1 : -1
115 /// 2) return (n*m<0) ? - ((-n+x) / m) -1 : n / m
/// Both candidate results are computed unconditionally and the final value is
/// picked with a select, so no control flow is introduced.
116 struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
118  LogicalResult matchAndRewrite(arith::FloorDivSIOp op,
119  PatternRewriter &rewriter) const final {
120  Location loc = op.getLoc();
121  Type type = op.getType();
122  Value a = op.getLhs();
123  Value b = op.getRhs();
124  Value plusOne = createConst(loc, type, 1, rewriter);
125  Value zero = createConst(loc, type, 0, rewriter);
126  Value minusOne = createConst(loc, type, -1, rewriter);
127  // Compute x = (b<0) ? 1 : -1.
128  Value compare =
129  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
130  Value x = rewriter.create<arith::SelectOp>(loc, compare, plusOne, minusOne);
131  // Compute negative res: -1 - ((x-a)/b).
132  Value xMinusA = rewriter.create<arith::SubIOp>(loc, x, a);
133  Value xMinusADivB = rewriter.create<arith::DivSIOp>(loc, xMinusA, b);
134  Value negRes = rewriter.create<arith::SubIOp>(loc, minusOne, xMinusADivB);
135  // Compute positive res: a/b.
136  Value posRes = rewriter.create<arith::DivSIOp>(loc, a, b);
137  // Result is (a*b<0) ? negative result : positive result.
138  // Note, we want to avoid using a*b because of possible overflow.
139  // The cases that matter are a>0, a==0, a<0, b>0 and b<0. We do
140  // not particularly care if a*b<0 is true or false when b is zero
141  // as this will result in an illegal divide. So `a*b<0` can be reformulated
142  // as `(a<0 && b>0) || (a>0 && b<0)` or `(a<0 && b>0) || (a>0 && b<=0)`.
143  // We pick the first expression here.
144  Value aNeg =
145  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
146  Value aPos =
147  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero);
148  Value bNeg =
149  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
150  Value bPos =
151  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
152  Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bPos);
153  Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bNeg);
154  Value compareRes =
155  rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm);
156  // Perform substitution and return success.
157  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compareRes, negRes,
158  posRes);
159  return success();
160  }
161 };
162 
/// Expands a floating-point maximum/minimum op `OpTy` into compare + select
/// operations. `pred` is the comparison used to pick `lhs` over `rhs` and must
/// be UGT or ULT (enforced by the static_assert below). If either operand is
/// NaN the replacement value is NaN: an unordered compare selects `lhs` when
/// any operand is NaN, and a NaN `rhs` is then selected explicitly.
163 template <typename OpTy, arith::CmpFPredicate pred>
164 struct MaximumMinimumFOpConverter : public OpRewritePattern<OpTy> {
165 public:
167
168  LogicalResult matchAndRewrite(OpTy op,
169  PatternRewriter &rewriter) const final {
170  Value lhs = op.getLhs();
171  Value rhs = op.getRhs();
172
173  Location loc = op.getLoc();
174  // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs').
175  static_assert(pred == arith::CmpFPredicate::UGT ||
176  pred == arith::CmpFPredicate::ULT,
177  "pred must be either UGT or ULT");
178  Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
179  Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
180
181  // Handle the case where rhs is NaN: 'isNaN(rhs) ? rhs : select'.
 // UNO(rhs, rhs) is true exactly when rhs is NaN.
182  Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
183  rhs, rhs);
184  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
185  return success();
186  }
187 };
188 
/// Expands an ExtFOp from bf16 to f32 (scalar or shaped) into integer ops:
/// bitcast to i16, zero-extend to i32, shift left by 16 and bitcast to the
/// result type. Any other element-type pair is reported as a match failure.
189 struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
191  LogicalResult matchAndRewrite(arith::ExtFOp op,
192  PatternRewriter &rewriter) const final {
193  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
194  auto operand = op.getOperand();
195  Type operandTy = operand.getType();
196  Type resultTy = op.getType();
197  Type operandETy = getElementTypeOrSelf(operandTy);
198  Type resultETy = getElementTypeOrSelf(resultTy);
199
 // Only the bf16 -> f32 extension is handled by this pattern.
200  if (!operandETy.isBF16() || !resultETy.isF32()) {
201  return rewriter.notifyMatchFailure(op, "not a ext of bf16 to f32.");
202  }
203
 // Integer types with the same shape as the operand (plain scalars if the
 // operand is not shaped).
204  Type i16Ty = b.getI16Type();
205  Type i32Ty = b.getI32Type();
206  if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
207  i16Ty = shapedTy.clone(i16Ty);
208  i32Ty = shapedTy.clone(i32Ty);
209  }
210
 // Reinterpret the bf16 bits as i16 and widen them to i32.
211  Value bitcast = b.create<arith::BitcastOp>(i16Ty, operand);
212  Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
213
 // Move the 16 source bits into the upper half of the 32-bit word and
 // reinterpret the result as the f32 result type.
214  Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
215  Value shl = b.create<arith::ShLIOp>(exti, c16);
216  Value result = b.create<arith::BitcastOp>(resultTy, shl);
217
218  rewriter.replaceOp(op, result);
219  return success();
220  }
221 };
222 
/// Expands a TruncFOp from f32 to bf16 (scalar or shaped) into integer bit
/// manipulation: the f32 bits are split into sign, exponent and mantissa
/// fields, a rounding constant is added to the mantissa, the fields are
/// reassembled, and the upper 16 bits are bitcast to the result type. Any
/// other element-type pair is reported as a match failure.
223 struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
225  LogicalResult matchAndRewrite(arith::TruncFOp op,
226  PatternRewriter &rewriter) const final {
227  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
228  auto operand = op.getOperand();
229  Type operandTy = operand.getType();
230  Type resultTy = op.getType();
231  Type operandETy = getElementTypeOrSelf(operandTy);
232  Type resultETy = getElementTypeOrSelf(resultTy);
233
 // Only the f32 -> bf16 truncation is handled by this pattern.
234  if (!operandETy.isF32() || !resultETy.isBF16()) {
235  return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");
236  }
237
 // Element types cloned onto the operand's shape when it is shaped.
 // NOTE(review): f32Ty is assigned but not used in the visible code.
238  Type i1Ty = b.getI1Type();
239  Type i16Ty = b.getI16Type();
240  Type i32Ty = b.getI32Type();
241  Type f32Ty = b.getF32Type();
242  if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
243  i1Ty = shapedTy.clone(i1Ty);
244  i16Ty = shapedTy.clone(i16Ty);
245  i32Ty = shapedTy.clone(i32Ty);
246  f32Ty = shapedTy.clone(f32Ty);
247  }
248
249  Value bitcast = b.create<arith::BitcastOp>(i32Ty, operand);
250
 // c23Mask = 0x007FFFFF (low 23 bits), expMask = 0x7F800000 (0xFF << 23),
 // expMax = 0x7F000000 (0xFE << 23).
251  Value c23 = createConst(op.getLoc(), i32Ty, 23, rewriter);
252  Value c31 = createConst(op.getLoc(), i32Ty, 31, rewriter);
253  Value c23Mask = createConst(op.getLoc(), i32Ty, (1 << 23) - 1, rewriter);
254  Value expMask =
255  createConst(op.getLoc(), i32Ty, ((1 << 8) - 1) << 23, rewriter);
256  Value expMax =
257  createConst(op.getLoc(), i32Ty, ((1 << 8) - 2) << 23, rewriter);
258
259  // Grab the sign bit.
260  Value sign = b.create<arith::ShRUIOp>(bitcast, c31);
261
262  // Our mantissa rounding value depends on the sign bit and the last
263  // truncated bit.
 // cManRound = (1 << 15) - sign, i.e. 0x8000 or 0x7FFF.
264  Value cManRound = createConst(op.getLoc(), i32Ty, (1 << 15), rewriter);
265  cManRound = b.create<arith::SubIOp>(cManRound, sign);
266
267  // Grab out the mantissa and directly apply rounding.
268  Value man = b.create<arith::AndIOp>(bitcast, c23Mask);
269  Value manRound = b.create<arith::AddIOp>(man, cManRound);
270
271  // Grab the overflow bit and shift right if we overflow.
272  Value roundBit = b.create<arith::ShRUIOp>(manRound, c23);
273  Value manNew = b.create<arith::ShRUIOp>(manRound, roundBit);
274
275  // Grab the exponent and round using the mantissa's carry bit.
276  Value exp = b.create<arith::AndIOp>(bitcast, expMask);
277  Value expCarry = b.create<arith::AddIOp>(exp, manRound);
278  expCarry = b.create<arith::AndIOp>(expCarry, expMask);
279
280  // If the exponent is saturated, we keep the max value.
281  Value expCmp =
282  b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, exp, expMax);
283  exp = b.create<arith::SelectOp>(expCmp, exp, expCarry);
284
285  // If the exponent is max and we rolled over, keep the old mantissa.
286  Value roundBitBool = b.create<arith::TruncIOp>(i1Ty, roundBit);
287  Value keepOldMan = b.create<arith::AndIOp>(expCmp, roundBitBool);
288  man = b.create<arith::SelectOp>(keepOldMan, man, manNew);
289
290  // Assemble the now rounded f32 value (as an i32).
291  Value rounded = b.create<arith::ShLIOp>(sign, c31);
292  rounded = b.create<arith::OrIOp>(rounded, exp);
293  rounded = b.create<arith::OrIOp>(rounded, man);
294
 // Keep the upper 16 bits of the rounded word and reinterpret them as the
 // result type.
295  Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
296  Value shr = b.create<arith::ShRUIOp>(rounded, c16);
297  Value trunc = b.create<arith::TruncIOp>(i16Ty, shr);
298  Value result = b.create<arith::BitcastOp>(resultTy, trunc);
299
300  rewriter.replaceOp(op, result);
301  return success();
302  }
303 };
304 
/// Pass that runs a partial conversion in which the Arith dialect is legal
/// except for CeilDivSIOp, CeilDivUIOp, FloorDivSIOp, MaximumFOp and
/// MinimumFOp. When `includeBf16` is set, ExtFOp from bf16 to f32 and TruncFOp
/// from f32 to bf16 are additionally treated as illegal.
/// NOTE(review): listing lines 313, 326, 329 and 336 are absent from this
/// view; presumably `patterns` is populated and `inETy` is declared on those
/// lines - verify against the original file.
305 struct ArithExpandOpsPass
306  : public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
307  using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
308
309  void runOnOperation() override {
310  RewritePatternSet patterns(&getContext());
311  ConversionTarget target(getContext());
312
314
 // Everything in the Arith dialect is legal unless listed as illegal below.
315  target.addLegalDialect<arith::ArithDialect>();
316  // clang-format off
317  target.addIllegalOp<
318  arith::CeilDivSIOp,
319  arith::CeilDivUIOp,
320  arith::FloorDivSIOp,
321  arith::MaximumFOp,
322  arith::MinimumFOp
323  >();
324
325  if (includeBf16) {
 // ExtFOp stays legal unless it extends bf16 to f32.
327  target.addDynamicallyLegalOp<arith::ExtFOp>(
328  [](arith::ExtFOp op) {
330  Type outETy = getElementTypeOrSelf(op.getType());
331  return !(inETy.isBF16() && outETy.isF32());
332  });
333
 // TruncFOp stays legal unless it truncates f32 to bf16.
334  target.addDynamicallyLegalOp<arith::TruncFOp>(
335  [](arith::TruncFOp op) {
337  Type outETy = getElementTypeOrSelf(op.getType());
338  return !(inETy.isF32() && outETy.isBF16());
339  });
340  }
341
342  // clang-format on
 // Signal pass failure if the partial conversion fails.
343  if (failed(applyPartialConversion(getOperation(), target,
344  std::move(patterns))))
345  signalPassFailure();
346  }
347 };
348 
349 } // namespace
350 
352  RewritePatternSet &patterns) {
 // Register the signed/unsigned ceil-division and signed floor-division
 // converters, constructed with the pattern set's context.
353  patterns
354  .add<CeilDivSIOpConverter, CeilDivUIOpConverter, FloorDivSIOpConverter>(
355  patterns.getContext());
356 }
357 
// Register the bf16 ExtFOp / TruncFOp converters, constructed with the
// pattern set's context.
359  patterns.add<BFloat16ExtFOpConverter, BFloat16TruncFOpConverter>(
360  patterns.getContext());
361 }
362 
// Register the MaximumFOp / MinimumFOp converters: UGT is the predicate used
// for maximum and ULT the one used for minimum.
365  // clang-format off
366  patterns.add<
367  MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
368  MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>
369  >(patterns.getContext());
370  // clang-format on
371 }
static Value createConst(Location loc, Type type, int value, PatternRewriter &rewriter)
Create an integer or index constant.
Definition: ExpandOps.cpp:27
static MLIRContext * getContext(OpFoldResult val)
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:238
This class describes a specific conversion target.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:63
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
Value getOperand(unsigned idx)
Definition: Operation.h:345
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:727
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
bool isF32() const
Definition: Types.cpp:51
bool isBF16() const
Definition: Types.cpp:48
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
void populateExpandBFloat16Patterns(RewritePatternSet &patterns)
Add patterns to expand Arith bf16 patterns to lower level bitcasts/shifts.
Definition: ExpandOps.cpp:358
void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ceil/floor division ops.
Definition: ExpandOps.cpp:351
void populateArithExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ops.
Definition: ExpandOps.cpp:363
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
Definition: Fraction.h:65
Include the generated interface declarations.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:357
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.
Definition: PatternMatch.h:361