MLIR  18.0.0git
ExpandPatterns.cpp
Go to the documentation of this file.
1 //===- ExpandPatterns.cpp - Code to expand math ops -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements expansion patterns for various math dialect ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
18 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/TypeUtilities.h"
22 
23 using namespace mlir;
24 
25 /// Create a float constant.
26 static Value createFloatConst(Location loc, Type type, double value,
27  OpBuilder &b) {
28  auto attr = b.getFloatAttr(getElementTypeOrSelf(type), value);
29  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
30  return b.create<arith::ConstantOp>(loc,
31  DenseElementsAttr::get(shapedTy, attr));
32  }
33 
34  return b.create<arith::ConstantOp>(loc, attr);
35 }
36 
37 /// Create a float constant.
38 static Value createIntConst(Location loc, Type type, int64_t value,
39  OpBuilder &b) {
40  auto attr = b.getIntegerAttr(getElementTypeOrSelf(type), value);
41  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
42  return b.create<arith::ConstantOp>(loc,
43  DenseElementsAttr::get(shapedTy, attr));
44  }
45 
46  return b.create<arith::ConstantOp>(loc, attr);
47 }
48 
50  Type opType = operand.getType();
51  Type i64Ty = b.getI64Type();
52  if (auto shapedTy = dyn_cast<ShapedType>(opType))
53  i64Ty = shapedTy.clone(i64Ty);
54  Value fixedConvert = b.create<arith::FPToSIOp>(i64Ty, operand);
55  Value fpFixedConvert = b.create<arith::SIToFPOp>(opType, fixedConvert);
56  // The truncation does not preserve the sign when the truncated
57  // value is -0. So here the sign is copied again.
58  return b.create<math::CopySignOp>(fpFixedConvert, operand);
59 }
60 
61 /// Expands tanh op into
62 /// 1) 1-exp^{-2x} / 1+exp^{-2x}, if x => 0
63 /// 2) exp^{2x}-1 / exp^{2x}+1 , if x < 0
64 static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
65  auto floatType = op.getOperand().getType();
66  Location loc = op.getLoc();
67  Value one = createFloatConst(loc, floatType, 1.0, rewriter);
68  Value two = createFloatConst(loc, floatType, 2.0, rewriter);
69  Value doubledX = rewriter.create<arith::MulFOp>(loc, op.getOperand(), two);
70 
71  // Case 1: tanh(x) = 1-exp^{-2x} / 1+exp^{-2x}
72  Value negDoubledX = rewriter.create<arith::NegFOp>(loc, doubledX);
73  Value exp2x = rewriter.create<math::ExpOp>(loc, negDoubledX);
74  Value dividend = rewriter.create<arith::SubFOp>(loc, one, exp2x);
75  Value divisor = rewriter.create<arith::AddFOp>(loc, one, exp2x);
76  Value positiveRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor);
77 
78  // Case 2: tanh(x) = exp^{2x}-1 / exp^{2x}+1
79  exp2x = rewriter.create<math::ExpOp>(loc, doubledX);
80  dividend = rewriter.create<arith::SubFOp>(loc, exp2x, one);
81  divisor = rewriter.create<arith::AddFOp>(loc, exp2x, one);
82  Value negativeRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor);
83 
84  // tanh(x) = x >= 0 ? positiveRes : negativeRes
85  Value zero = createFloatConst(loc, floatType, 0.0, rewriter);
86  Value cmpRes = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
87  op.getOperand(), zero);
88  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmpRes, positiveRes,
89  negativeRes);
90  return success();
91 }
92 
93 // Converts math.tan to math.sin, math.cos, and arith.divf.
94 static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
95  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
96  Value operand = op.getOperand();
97  Type type = operand.getType();
98  Value sin = b.create<math::SinOp>(type, operand);
99  Value cos = b.create<math::CosOp>(type, operand);
100  Value div = b.create<arith::DivFOp>(type, sin, cos);
101  rewriter.replaceOp(op, div);
102  return success();
103 }
104 
105 static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
106  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
107  Value operandA = op.getOperand(0);
108  Value operandB = op.getOperand(1);
109  Value operandC = op.getOperand(2);
110  Type type = op.getType();
111  Value mult = b.create<arith::MulFOp>(type, operandA, operandB);
112  Value add = b.create<arith::AddFOp>(type, mult, operandC);
113  rewriter.replaceOp(op, add);
114  return success();
115 }
116 
117 // Converts a floorf() function to the following:
118 // floorf(float x) ->
119 // y = (float)(int) x
120 // if (x < 0) then incr = -1 else incr = 0
121 // y = y + incr <= replace this op with the floorf op.
122 static LogicalResult convertFloorOp(math::FloorOp op,
123  PatternRewriter &rewriter) {
124  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
125  Value operand = op.getOperand();
126  Type opType = operand.getType();
127  Value fpFixedConvert = createTruncatedFPValue(operand, b);
128 
129  // Creating constants for later use.
130  Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
131  Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
132 
133  Value negCheck =
134  b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);
135  Value incrValue =
136  b.create<arith::SelectOp>(op->getLoc(), negCheck, negOne, zero);
137  Value ret = b.create<arith::AddFOp>(opType, fpFixedConvert, incrValue);
138  rewriter.replaceOp(op, ret);
139  return success();
140 }
141 
142 // Converts a ceilf() function to the following:
143 // ceilf(float x) ->
144 // y = (float)(int) x
145 // if (x > y) then incr = 1 else incr = 0
146 // y = y + incr <= replace this op with the ceilf op.
147 static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
148  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
149  Value operand = op.getOperand();
150  Type opType = operand.getType();
151  Value fpFixedConvert = createTruncatedFPValue(operand, b);
152 
153  // Creating constants for later use.
154  Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
155  Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
156 
157  Value gtCheck = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand,
158  fpFixedConvert);
159  Value incrValue = b.create<arith::SelectOp>(op->getLoc(), gtCheck, one, zero);
160 
161  Value ret = b.create<arith::AddFOp>(opType, fpFixedConvert, incrValue);
162  rewriter.replaceOp(op, ret);
163  return success();
164 }
165 // Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
166 static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
167  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
168  Value operandA = op.getOperand(0);
169  Value operandB = op.getOperand(1);
170  Type opType = operandA.getType();
171  Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
172  Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
173  Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
174  Value opASquared = b.create<arith::MulFOp>(opType, operandA, operandA);
175  Value opBHalf = b.create<arith::DivFOp>(opType, operandB, two);
176 
177  Value logA = b.create<math::LogOp>(opType, opASquared);
178  Value mult = b.create<arith::MulFOp>(opType, opBHalf, logA);
179  Value expResult = b.create<math::ExpOp>(opType, mult);
180  Value negExpResult = b.create<arith::MulFOp>(opType, expResult, negOne);
181  Value remainder = b.create<arith::RemFOp>(opType, operandB, two);
182  Value negCheck =
183  b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
184  Value oddPower =
185  b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero);
186  Value oddAndNeg = b.create<arith::AndIOp>(op->getLoc(), oddPower, negCheck);
187 
188  Value res = b.create<arith::SelectOp>(op->getLoc(), oddAndNeg, negExpResult,
189  expResult);
190  rewriter.replaceOp(op, res);
191  return success();
192 }
193 
194 // exp2f(float x) -> exp(x * ln(2))
195 // Proof: Let's say 2^x = y
196 // ln(2^x) = ln(y)
197 // x * ln(2) = ln(y) => e ^(x*ln(2)) = y
198 static LogicalResult convertExp2fOp(math::Exp2Op op,
199  PatternRewriter &rewriter) {
200  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
201  Value operand = op.getOperand();
202  Type opType = operand.getType();
203  Value ln2 = createFloatConst(op->getLoc(), opType, llvm::numbers::ln2, b);
204  Value mult = b.create<arith::MulFOp>(opType, operand, ln2);
205  Value exp = b.create<math::ExpOp>(op->getLoc(), mult);
206  rewriter.replaceOp(op, exp);
207  return success();
208 }
209 
210 static LogicalResult convertRoundOp(math::RoundOp op,
211  PatternRewriter &rewriter) {
212  Location loc = op.getLoc();
213  ImplicitLocOpBuilder b(loc, rewriter);
214  Value operand = op.getOperand();
215  Type opType = operand.getType();
216  Type opEType = getElementTypeOrSelf(opType);
217 
218  if (!opEType.isF32()) {
219  return rewriter.notifyMatchFailure(op, "not a round of f32.");
220  }
221 
222  Type i32Ty = b.getI32Type();
223  if (auto shapedTy = dyn_cast<ShapedType>(opType))
224  i32Ty = shapedTy.clone(i32Ty);
225 
226  Value half = createFloatConst(loc, opType, 0.5, b);
227  Value c23 = createIntConst(loc, i32Ty, 23, b);
228  Value c127 = createIntConst(loc, i32Ty, 127, b);
229  Value expMask = createIntConst(loc, i32Ty, (1 << 8) - 1, b);
230 
231  Value incrValue = b.create<math::CopySignOp>(half, operand);
232  Value add = b.create<arith::AddFOp>(opType, operand, incrValue);
233  Value fpFixedConvert = createTruncatedFPValue(add, b);
234 
235  // There are three cases where adding 0.5 to the value and truncating by
236  // converting to an i64 does not result in the correct behavior:
237  //
238  // 1. Special values: +-inf and +-nan
239  // Casting these special values to i64 has undefined behavior. To identify
240  // these values, we use the fact that these values are the only float
241  // values with the maximum possible biased exponent.
242  //
243  // 2. Large values: 2^23 <= |x| <= INT_64_MAX
244  // Adding 0.5 to a float larger than or equal to 2^23 results in precision
245  // errors that sometimes round the value up and sometimes round the value
246  // down. For example:
247  // 8388608.0 + 0.5 = 8388608.0
248  // 8388609.0 + 0.5 = 8388610.0
249  //
250  // 3. Very large values: |x| > INT_64_MAX
251  // Casting to i64 a value greater than the max i64 value will overflow the
252  // i64 leading to wrong outputs.
253  //
254  // All three cases satisfy the property `biasedExp >= 23`.
255  Value operandBitcast = b.create<arith::BitcastOp>(i32Ty, operand);
256  Value operandExp = b.create<arith::AndIOp>(
257  b.create<arith::ShRUIOp>(operandBitcast, c23), expMask);
258  Value operandBiasedExp = b.create<arith::SubIOp>(operandExp, c127);
259  Value isSpecialValOrLargeVal =
260  b.create<arith::CmpIOp>(arith::CmpIPredicate::sge, operandBiasedExp, c23);
261 
262  Value result = b.create<arith::SelectOp>(isSpecialValOrLargeVal, operand,
263  fpFixedConvert);
264  rewriter.replaceOp(op, result);
265  return success();
266 }
267 
268 // Converts math.ctlz to scf and arith operations. This is done
269 // by performing a binary search on the bits.
270 static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
271  PatternRewriter &rewriter) {
272  auto operand = op.getOperand();
273  auto operandTy = operand.getType();
274  auto eTy = getElementTypeOrSelf(operandTy);
275  Location loc = op.getLoc();
276 
277  int32_t bitwidth = eTy.getIntOrFloatBitWidth();
278  if (bitwidth > 64)
279  return failure();
280 
281  uint64_t allbits = -1;
282  if (bitwidth < 64) {
283  allbits = allbits >> (64 - bitwidth);
284  }
285 
286  Value x = operand;
287  Value count = createIntConst(loc, operandTy, 0, rewriter);
288  for (int32_t bw = bitwidth; bw > 1; bw = bw / 2) {
289  auto half = bw / 2;
290  auto bits = createIntConst(loc, operandTy, half, rewriter);
291  auto mask = createIntConst(loc, operandTy, allbits >> half, rewriter);
292 
293  Value pred =
294  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule, x, mask);
295  Value add = rewriter.create<arith::AddIOp>(loc, count, bits);
296  Value shift = rewriter.create<arith::ShLIOp>(loc, x, bits);
297 
298  x = rewriter.create<arith::SelectOp>(loc, pred, shift, x);
299  count = rewriter.create<arith::SelectOp>(loc, pred, add, count);
300  }
301 
302  Value zero = createIntConst(loc, operandTy, 0, rewriter);
303  Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
304  operand, zero);
305 
306  Value bwval = createIntConst(loc, operandTy, bitwidth, rewriter);
307  Value sel = rewriter.create<arith::SelectOp>(loc, pred, bwval, count);
308  rewriter.replaceOp(op, sel);
309  return success();
310 }
311 
312 // Convert `math.roundeven` into `math.round` + arith ops
313 static LogicalResult convertRoundEvenOp(math::RoundEvenOp op,
314  PatternRewriter &rewriter) {
315  Location loc = op.getLoc();
316  ImplicitLocOpBuilder b(loc, rewriter);
317  auto operand = op.getOperand();
318  Type operandTy = operand.getType();
319  Type resultTy = op.getType();
320  Type operandETy = getElementTypeOrSelf(operandTy);
321  Type resultETy = getElementTypeOrSelf(resultTy);
322 
323  if (!isa<FloatType>(operandETy) || !isa<FloatType>(resultETy)) {
324  return rewriter.notifyMatchFailure(op, "not a roundeven of f16 or f32.");
325  }
326 
327  Type fTy = operandTy;
328  Type iTy = rewriter.getIntegerType(operandETy.getIntOrFloatBitWidth());
329  if (auto shapedTy = dyn_cast<ShapedType>(fTy)) {
330  iTy = shapedTy.clone(iTy);
331  }
332 
333  unsigned bitWidth = operandETy.getIntOrFloatBitWidth();
334  // The width returned by getFPMantissaWidth includes the integer bit.
335  unsigned mantissaWidth =
336  llvm::cast<FloatType>(operandETy).getFPMantissaWidth() - 1;
337  unsigned exponentWidth = bitWidth - mantissaWidth - 1;
338 
339  // The names of the variables correspond to f32.
340  // f64: 1 bit sign | 11 bits exponent | 52 bits mantissa.
341  // f32: 1 bit sign | 8 bits exponent | 23 bits mantissa.
342  // f16: 1 bit sign | 5 bits exponent | 10 bits mantissa.
343  Value c1Float = createFloatConst(loc, fTy, 1.0, b);
344  Value c0 = createIntConst(loc, iTy, 0, b);
345  Value c1 = createIntConst(loc, iTy, 1, b);
346  Value cNeg1 = createIntConst(loc, iTy, -1, b);
347  Value c23 = createIntConst(loc, iTy, mantissaWidth, b);
348  Value c31 = createIntConst(loc, iTy, bitWidth - 1, b);
349  Value c127 = createIntConst(loc, iTy, (1ull << (exponentWidth - 1)) - 1, b);
350  Value c2To22 = createIntConst(loc, iTy, 1ull << (mantissaWidth - 1), b);
351  Value c23Mask = createIntConst(loc, iTy, (1ull << mantissaWidth) - 1, b);
352  Value expMask = createIntConst(loc, iTy, (1ull << exponentWidth) - 1, b);
353 
354  Value operandBitcast = b.create<arith::BitcastOp>(iTy, operand);
355  Value round = b.create<math::RoundOp>(operand);
356  Value roundBitcast = b.create<arith::BitcastOp>(iTy, round);
357 
358  // Get biased exponents for operand and round(operand)
359  Value operandExp = b.create<arith::AndIOp>(
360  b.create<arith::ShRUIOp>(operandBitcast, c23), expMask);
361  Value operandBiasedExp = b.create<arith::SubIOp>(operandExp, c127);
362  Value roundExp = b.create<arith::AndIOp>(
363  b.create<arith::ShRUIOp>(roundBitcast, c23), expMask);
364  Value roundBiasedExp = b.create<arith::SubIOp>(roundExp, c127);
365 
366  auto safeShiftRight = [&](Value x, Value shift) -> Value {
367  // Clamp shift to valid range [0, bitwidth - 1] to avoid undefined behavior
368  Value clampedShift = b.create<arith::MaxSIOp>(shift, c0);
369  clampedShift = b.create<arith::MinSIOp>(clampedShift, c31);
370  return b.create<arith::ShRUIOp>(x, clampedShift);
371  };
372 
373  auto maskMantissa = [&](Value mantissa,
374  Value mantissaMaskRightShift) -> Value {
375  Value shiftedMantissaMask = safeShiftRight(c23Mask, mantissaMaskRightShift);
376  return b.create<arith::AndIOp>(mantissa, shiftedMantissaMask);
377  };
378 
379  // A whole number `x`, such that `|x| != 1`, is even if the mantissa, ignoring
380  // the leftmost `clamp(biasedExp - 1, 0, 23)` bits, is zero. Large numbers
381  // with `biasedExp > 23` (numbers where there is not enough precision to store
382  // decimals) are always even, and they satisfy the even condition trivially
383  // since the mantissa without all its bits is zero. The even condition
384  // is also true for +-0, since they have `biasedExp = -127` and the entire
385  // mantissa is zero. The case of +-1 has to be handled separately. Here
386  // we identify these values by noting that +-1 are the only whole numbers with
387  // `biasedExp == 0`.
388  //
389  // The special values +-inf and +-nan also satisfy the same property that
390  // whole non-unit even numbers satisfy. In particular, the special values have
391  // `biasedExp > 23`, so they get treated as large numbers with no room for
392  // decimals, which are always even.
393  Value roundBiasedExpEq0 =
394  b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, roundBiasedExp, c0);
395  Value roundBiasedExpMinus1 = b.create<arith::SubIOp>(roundBiasedExp, c1);
396  Value roundMaskedMantissa = maskMantissa(roundBitcast, roundBiasedExpMinus1);
397  Value roundIsNotEvenOrSpecialVal = b.create<arith::CmpIOp>(
398  arith::CmpIPredicate::ne, roundMaskedMantissa, c0);
399  roundIsNotEvenOrSpecialVal =
400  b.create<arith::OrIOp>(roundIsNotEvenOrSpecialVal, roundBiasedExpEq0);
401 
402  // A value `x` with `0 <= biasedExp < 23`, is halfway between two consecutive
403  // integers if the bit at index `biasedExp` starting from the left in the
404  // mantissa is 1 and all the bits to the right are zero. Values with
405  // `biasedExp >= 23` don't have decimals, so they are never halfway. The
406  // values +-0.5 are the only halfway values that have `biasedExp == -1 < 0`,
407  // so these are handled separately. In particular, if `biasedExp == -1`, the
408  // value is halfway if the entire mantissa is zero.
409  Value operandBiasedExpEqNeg1 = b.create<arith::CmpIOp>(
410  arith::CmpIPredicate::eq, operandBiasedExp, cNeg1);
411  Value expectedOperandMaskedMantissa = b.create<arith::SelectOp>(
412  operandBiasedExpEqNeg1, c0, safeShiftRight(c2To22, operandBiasedExp));
413  Value operandMaskedMantissa = maskMantissa(operandBitcast, operandBiasedExp);
414  Value operandIsHalfway =
415  b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, operandMaskedMantissa,
416  expectedOperandMaskedMantissa);
417  // Ensure `biasedExp` is in the valid range for half values.
418  Value operandBiasedExpGeNeg1 = b.create<arith::CmpIOp>(
419  arith::CmpIPredicate::sge, operandBiasedExp, cNeg1);
420  Value operandBiasedExpLt23 =
421  b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, operandBiasedExp, c23);
422  operandIsHalfway =
423  b.create<arith::AndIOp>(operandIsHalfway, operandBiasedExpLt23);
424  operandIsHalfway =
425  b.create<arith::AndIOp>(operandIsHalfway, operandBiasedExpGeNeg1);
426 
427  // Adjust rounded operand with `round(operand) - sign(operand)` to correct the
428  // case where `round` rounded in the opposite direction of `roundeven`.
429  Value sign = b.create<math::CopySignOp>(c1Float, operand);
430  Value roundShifted = b.create<arith::SubFOp>(round, sign);
431  // If the rounded value is even or a special value, we default to the behavior
432  // of `math.round`.
433  Value needsShift =
434  b.create<arith::AndIOp>(roundIsNotEvenOrSpecialVal, operandIsHalfway);
435  Value result = b.create<arith::SelectOp>(needsShift, roundShifted, round);
436  // The `x - sign` adjustment does not preserve the sign when we are adjusting
437  // the value -1 to -0. So here the sign is copied again to ensure that -0.5 is
438  // rounded to -0.0.
439  result = b.create<math::CopySignOp>(result, operand);
440  rewriter.replaceOp(op, result);
441  return success();
442 }
443 
445  patterns.add(convertCtlzOp);
446 }
447 
449  patterns.add(convertTanOp);
450 }
451 
453  patterns.add(convertTanhOp);
454 }
455 
457  patterns.add(convertFmaFOp);
458 }
459 
461  patterns.add(convertCeilOp);
462 }
463 
465  patterns.add(convertExp2fOp);
466 }
467 
469  patterns.add(convertPowfOp);
470 }
471 
473  patterns.add(convertRoundOp);
474 }
475 
477  patterns.add(convertFloorOp);
478 }
479 
481  patterns.add(convertRoundEvenOp);
482 }
static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b)
static Value createFloatConst(Location loc, Type type, double value, OpBuilder &b)
Create a float constant.
static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter)
static LogicalResult convertRoundOp(math::RoundOp op, PatternRewriter &rewriter)
static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter)
static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op, PatternRewriter &rewriter)
static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter)
static LogicalResult convertFloorOp(math::FloorOp op, PatternRewriter &rewriter)
static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter)
static LogicalResult convertRoundEvenOp(math::RoundEvenOp op, PatternRewriter &rewriter)
static Value createIntConst(Location loc, Type type, int64_t value, OpBuilder &b)
Create a float constant.
static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter)
Expands tanh op into 1) 1-exp^{-2x} / 1+exp^{-2x}, if x => 0 2) exp^{2x}-1 / exp^{2x}+1 ,...
static LogicalResult convertExp2fOp(math::Exp2Op op, PatternRewriter &rewriter)
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:238
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:261
IntegerType getI64Type()
Definition: Builders.cpp:85
IntegerType getI32Type()
Definition: Builders.cpp:83
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:87
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
This class helps build Operations.
Definition: Builders.h:206
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
Value getOperand(unsigned idx)
Definition: Operation.h:345
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:727
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:660
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:539
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF32() const
Definition: Types.cpp:51
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:123
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
void populateExpandTanhPattern(RewritePatternSet &patterns)
void populateExpandFmaFPattern(RewritePatternSet &patterns)
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void populateExpandPowFPattern(RewritePatternSet &patterns)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void populateExpandTanPattern(RewritePatternSet &patterns)
void populateExpandRoundFPattern(RewritePatternSet &patterns)
void populateExpandExp2FPattern(RewritePatternSet &patterns)
void populateExpandCeilFPattern(RewritePatternSet &patterns)
void populateExpandCtlzPattern(RewritePatternSet &patterns)
void populateExpandRoundEvenPattern(RewritePatternSet &patterns)
void populateExpandFloorFPattern(RewritePatternSet &patterns)
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26