MLIR 22.0.0git
ExpandOps.cpp
Go to the documentation of this file.
1//===- ExpandPatterns.cpp - Code to expand various math operations. -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements expansion of various math operations.
10//
11//===----------------------------------------------------------------------===//
12
16#include "mlir/IR/Builders.h"
17#include "mlir/IR/Matchers.h"
20
// Make the core MLIR names available unqualified in the rest of this file.
using namespace mlir;

namespace mlir::math {
// Request the generated pass definition for MathExpandOpsPass from the
// tablegen-produced include below; the pass at the bottom of this file derives
// from the resulting `math::impl::MathExpandOpsPassBase`.
#define GEN_PASS_DEF_MATHEXPANDOPSPASS
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
} // namespace mlir::math
27
28/// Create a float constant.
29static Value createFloatConst(Location loc, Type type, APFloat value,
30 OpBuilder &b) {
31 bool losesInfo = false;
32 auto eltType = getElementTypeOrSelf(type);
33 // Convert double to the given `FloatType` with round-to-nearest-ties-to-even.
34 value.convert(cast<FloatType>(eltType).getFloatSemantics(),
35 APFloat::rmNearestTiesToEven, &losesInfo);
36 auto attr = b.getFloatAttr(eltType, value);
37 if (auto shapedTy = dyn_cast<ShapedType>(type)) {
38 return arith::ConstantOp::create(b, loc,
39 DenseElementsAttr::get(shapedTy, attr));
40 }
42 return arith::ConstantOp::create(b, loc, attr);
45static Value createFloatConst(Location loc, Type type, double value,
46 OpBuilder &b) {
47 return createFloatConst(loc, type, APFloat(value), b);
49
50/// Create an integer constant.
51static Value createIntConst(Location loc, Type type, int64_t value,
52 OpBuilder &b) {
53 auto attr = b.getIntegerAttr(getElementTypeOrSelf(type), value);
54 if (auto shapedTy = dyn_cast<ShapedType>(type)) {
55 return arith::ConstantOp::create(b, loc,
56 DenseElementsAttr::get(shapedTy, attr));
57 }
58
59 return arith::ConstantOp::create(b, loc, attr);
60}
61
63 Type opType = operand.getType();
64 Type i64Ty = b.getI64Type();
65 if (auto shapedTy = dyn_cast<ShapedType>(opType))
66 i64Ty = shapedTy.clone(i64Ty);
67 Value fixedConvert = arith::FPToSIOp::create(b, i64Ty, operand);
68 Value fpFixedConvert = arith::SIToFPOp::create(b, opType, fixedConvert);
69 // The truncation does not preserve the sign when the truncated
70 // value is -0. So here the sign is copied again.
71 return math::CopySignOp::create(b, fpFixedConvert, operand);
73
74// sinhf(float x) -> (exp(x) - exp(-x)) / 2
75static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter) {
76 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
77 Value operand = op.getOperand();
78 Type opType = operand.getType();
79
80 Value exp = math::ExpOp::create(b, operand);
81 Value neg = arith::NegFOp::create(b, operand);
82 Value nexp = math::ExpOp::create(b, neg);
83 Value sub = arith::SubFOp::create(b, exp, nexp);
84 Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
85 Value res = arith::MulFOp::create(b, sub, half);
86 rewriter.replaceOp(op, res);
87 return success();
89
90// coshf(float x) -> (exp(x) + exp(-x)) / 2
91static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) {
92 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
93 Value operand = op.getOperand();
94 Type opType = operand.getType();
95
96 Value exp = math::ExpOp::create(b, operand);
97 Value neg = arith::NegFOp::create(b, operand);
98 Value nexp = math::ExpOp::create(b, neg);
99 Value add = arith::AddFOp::create(b, exp, nexp);
100 Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
101 Value res = arith::MulFOp::create(b, add, half);
102 rewriter.replaceOp(op, res);
103 return success();
104}
105
106/// Expands tanh op into
107/// 1-exp^{-2x} / 1+exp^{-2x}
108/// To avoid overflow we exploit the reflection symmetry `tanh(-x) = -tanh(x)`.
109/// We compute a "signs" value which is -1 if input is negative and +1 if input
110/// is positive. Then multiply the input by this value, guaranteeing that the
111/// result is positive, which also guarantees `exp^{-2x * sign(x)}` is in (0,
112/// 1]. Expand the computation on the input `x * sign(x)`, then multiply the
113/// result by `sign(x)` to retain sign of the real result.
114static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
115 auto floatType = op.getOperand().getType();
116 Location loc = op.getLoc();
117 Value zero = createFloatConst(loc, floatType, 0.0, rewriter);
118 Value one = createFloatConst(loc, floatType, 1.0, rewriter);
119 Value negTwo = createFloatConst(loc, floatType, -2.0, rewriter);
120
121 // Compute sign(x) = cast<float_type>(x < 0) * (-2) + 1
122 Value isNegative = arith::CmpFOp::create(
123 rewriter, loc, arith::CmpFPredicate::OLT, op.getOperand(), zero);
124 Value isNegativeFloat =
125 arith::UIToFPOp::create(rewriter, loc, floatType, isNegative);
126 Value isNegativeTimesNegTwo =
127 arith::MulFOp::create(rewriter, loc, isNegativeFloat, negTwo);
128 Value sign = arith::AddFOp::create(rewriter, loc, isNegativeTimesNegTwo, one);
129
130 // Normalize input to positive value: y = sign(x) * x
131 Value positiveX = arith::MulFOp::create(rewriter, loc, sign, op.getOperand());
132
133 // Decompose on normalized input
134 Value negDoubledX = arith::MulFOp::create(rewriter, loc, negTwo, positiveX);
135 Value exp2x = math::ExpOp::create(rewriter, loc, negDoubledX);
136 Value dividend = arith::SubFOp::create(rewriter, loc, one, exp2x);
137 Value divisor = arith::AddFOp::create(rewriter, loc, one, exp2x);
138 Value positiveRes = arith::DivFOp::create(rewriter, loc, dividend, divisor);
139
140 // Multiply result by sign(x) to retain signs from negative inputs
141 rewriter.replaceOpWithNewOp<arith::MulFOp>(op, sign, positiveRes);
142
143 return success();
144}
145
146// Converts math.tan to math.sin, math.cos, and arith.divf.
147static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
148 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
149 Value operand = op.getOperand();
150 Type type = operand.getType();
151 Value sin = math::SinOp::create(b, type, operand);
152 Value cos = math::CosOp::create(b, type, operand);
153 Value div = arith::DivFOp::create(b, type, sin, cos);
154 rewriter.replaceOp(op, div);
155 return success();
156}
157
158// asinh(float x) -> log(x + sqrt(x**2 + 1))
159static LogicalResult convertAsinhOp(math::AsinhOp op,
160 PatternRewriter &rewriter) {
161 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
162 Value operand = op.getOperand();
163 Type opType = operand.getType();
164
165 Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
166 Value fma = math::FmaOp::create(b, operand, operand, one);
167 Value sqrt = math::SqrtOp::create(b, fma);
168 Value add = arith::AddFOp::create(b, operand, sqrt);
169 Value res = math::LogOp::create(b, add);
170 rewriter.replaceOp(op, res);
171 return success();
172}
173
174// acosh(float x) -> log(x + sqrt(x**2 - 1))
175static LogicalResult convertAcoshOp(math::AcoshOp op,
176 PatternRewriter &rewriter) {
177 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
178 Value operand = op.getOperand();
179 Type opType = operand.getType();
180
181 Value negOne = createFloatConst(op->getLoc(), opType, -1.0, rewriter);
182 Value fma = math::FmaOp::create(b, operand, operand, negOne);
183 Value sqrt = math::SqrtOp::create(b, fma);
184 Value add = arith::AddFOp::create(b, operand, sqrt);
185 Value res = math::LogOp::create(b, add);
186 rewriter.replaceOp(op, res);
187 return success();
188}
189
190// atanh(float x) -> log((1 + x) / (1 - x)) / 2
191static LogicalResult convertAtanhOp(math::AtanhOp op,
192 PatternRewriter &rewriter) {
193 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
194 Value operand = op.getOperand();
195 Type opType = operand.getType();
196
197 Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
198 Value add = arith::AddFOp::create(b, operand, one);
199 Value neg = arith::NegFOp::create(b, operand);
200 Value sub = arith::AddFOp::create(b, neg, one);
201 Value div = arith::DivFOp::create(b, add, sub);
202 Value log = math::LogOp::create(b, div);
203 Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
204 Value res = arith::MulFOp::create(b, log, half);
205 rewriter.replaceOp(op, res);
206 return success();
207}
208
209static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
210 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
211 Value operandA = op.getOperand(0);
212 Value operandB = op.getOperand(1);
213 Value operandC = op.getOperand(2);
214 Type type = op.getType();
215 Value mult = arith::MulFOp::create(b, type, operandA, operandB);
216 Value add = arith::AddFOp::create(b, type, mult, operandC);
217 rewriter.replaceOp(op, add);
218 return success();
219}
220
221// Converts a ceilf() function to the following:
222// ceilf(float x) ->
223// y = (float)(int) x
224// if (x > y) then incr = 1 else incr = 0
225// y = y + incr <= replace this op with the ceilf op.
226static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
227 // Creating constants assumes the static shaped type.
228 auto shapedType = dyn_cast<ShapedType>(op.getType());
229 if (shapedType && !shapedType.hasStaticShape())
230 return failure();
231
232 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
233 Value operand = op.getOperand();
234 Type opType = operand.getType();
235 Value fpFixedConvert = createTruncatedFPValue(operand, b);
236
237 // Creating constants for later use.
238 Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
239 Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
240
241 Value gtCheck = arith::CmpFOp::create(b, arith::CmpFPredicate::OGT, operand,
242 fpFixedConvert);
243 Value incrValue =
244 arith::SelectOp::create(b, op->getLoc(), gtCheck, one, zero);
245
246 Value ret = arith::AddFOp::create(b, opType, fpFixedConvert, incrValue);
247 rewriter.replaceOp(op, ret);
248 return success();
249}
250
251// Convert `math.fpowi` to a series of `arith.mulf` operations.
252// If the power is negative, we divide one by the result.
253// If both the base and power are zero, the result is 1.
254// In the case of non constant power, we convert the operation to `math.powf`.
255static LogicalResult convertFPowIOp(math::FPowIOp op,
256 PatternRewriter &rewriter) {
257 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
258 Value base = op.getOperand(0);
259 Value power = op.getOperand(1);
260 Type baseType = base.getType();
261
262 auto convertFPowItoPowf = [&]() -> LogicalResult {
263 Value castPowerToFp =
264 arith::SIToFPOp::create(rewriter, op.getLoc(), baseType, power);
265 Value res = math::PowFOp::create(rewriter, op.getLoc(), baseType, base,
266 castPowerToFp);
267 rewriter.replaceOp(op, res);
268 return success();
269 };
270
271 Attribute cstAttr;
272 if (!matchPattern(power, m_Constant(&cstAttr)))
273 return convertFPowItoPowf();
274
275 APInt value;
276 if (!matchPattern(cstAttr, m_ConstantInt(&value)))
277 return convertFPowItoPowf();
278
279 int64_t powerInt = value.getSExtValue();
280 bool isNegative = powerInt < 0;
281 int64_t absPower = std::abs(powerInt);
282 Value one = createFloatConst(op->getLoc(), baseType, 1.00, rewriter);
283 Value res = createFloatConst(op->getLoc(), baseType, 1.00, rewriter);
284
285 while (absPower > 0) {
286 if (absPower & 1)
287 res = arith::MulFOp::create(b, baseType, base, res);
288 absPower >>= 1;
289 base = arith::MulFOp::create(b, baseType, base, base);
290 }
291
292 // Make sure not to introduce UB in case of negative power.
293 if (isNegative) {
294 auto &sem = dyn_cast<mlir::FloatType>(getElementTypeOrSelf(baseType))
295 .getFloatSemantics();
296 Value zero =
297 createFloatConst(op->getLoc(), baseType,
298 APFloat::getZero(sem, /*Negative=*/false), rewriter);
299 Value negZero =
300 createFloatConst(op->getLoc(), baseType,
301 APFloat::getZero(sem, /*Negative=*/true), rewriter);
302 Value posInfinity =
303 createFloatConst(op->getLoc(), baseType,
304 APFloat::getInf(sem, /*Negative=*/false), rewriter);
305 Value negInfinity =
306 createFloatConst(op->getLoc(), baseType,
307 APFloat::getInf(sem, /*Negative=*/true), rewriter);
308 Value zeroEqCheck =
309 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, zero);
310 Value negZeroEqCheck =
311 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, negZero);
312 res = arith::DivFOp::create(b, baseType, one, res);
313 res =
314 arith::SelectOp::create(b, op->getLoc(), zeroEqCheck, posInfinity, res);
315 res = arith::SelectOp::create(b, op->getLoc(), negZeroEqCheck, negInfinity,
316 res);
317 }
318
319 rewriter.replaceOp(op, res);
320 return success();
321}
322
323// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
324// Some special cases where b is constant are handled separately:
325// when b == 0, or |b| == 0.5, 1.0, or 2.0.
326static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
327 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
328 Value operandA = op.getOperand(0);
329 Value operandB = op.getOperand(1);
330 auto typeA = operandA.getType();
331 auto typeB = operandB.getType();
332
333 auto &sem =
334 cast<mlir::FloatType>(getElementTypeOrSelf(typeB)).getFloatSemantics();
335 APFloat valueB(sem);
336 auto mulf = [&](Value x, Value y) -> Value {
337 return arith::MulFOp::create(b, x, y);
338 };
339 if (matchPattern(operandB, m_ConstantFloat(&valueB))) {
340 if (valueB.isZero()) {
341 // a^0 -> 1
342 Value one = createFloatConst(op->getLoc(), typeA, 1.0, rewriter);
343 rewriter.replaceOp(op, one);
344 return success();
345 }
346 if (valueB.isExactlyValue(1.0)) {
347 // a^1 -> a
348 rewriter.replaceOp(op, operandA);
349 return success();
350 }
351 if (valueB.isExactlyValue(-1.0)) {
352 // a^(-1) -> 1 / a
353 Value one = createFloatConst(op->getLoc(), typeA, 1.0, rewriter);
354 Value div = arith::DivFOp::create(b, one, operandA);
355 rewriter.replaceOp(op, div);
356 return success();
357 }
358 if (valueB.isExactlyValue(0.5)) {
359 // a^(1/2) -> sqrt(a)
360 Value sqrt = math::SqrtOp::create(b, operandA);
361 rewriter.replaceOp(op, sqrt);
362 return success();
363 }
364 if (valueB.isExactlyValue(-0.5)) {
365 // a^(-1/2) -> 1 / sqrt(a)
366 Value rsqrt = math::RsqrtOp::create(b, operandA);
367 rewriter.replaceOp(op, rsqrt);
368 return success();
369 }
370 if (valueB.isExactlyValue(2.0)) {
371 // a^2 -> a * a
372 rewriter.replaceOp(op, mulf(operandA, operandA));
373 return success();
374 }
375 if (valueB.isExactlyValue(-2.0)) {
376 // a^(-2) -> 1 / (a * a)
377 Value one =
378 createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
379 Value div = arith::DivFOp::create(b, one, mulf(operandA, operandA));
380 rewriter.replaceOp(op, div);
381 return success();
382 }
383 if (valueB.isExactlyValue(3.0)) {
384 rewriter.replaceOp(op, mulf(mulf(operandA, operandA), operandA));
385 return success();
386 }
387 }
388
389 Value logA = math::LogOp::create(b, operandA);
390 Value mult = arith::MulFOp::create(b, operandB, logA);
391 Value expResult = math::ExpOp::create(b, mult);
392 rewriter.replaceOp(op, expResult);
393 return success();
394}
395
396// exp2f(float x) -> exp(x * ln(2))
397// Proof: Let's say 2^x = y
398// ln(2^x) = ln(y)
399// x * ln(2) = ln(y) => e ^(x*ln(2)) = y
400static LogicalResult convertExp2fOp(math::Exp2Op op,
401 PatternRewriter &rewriter) {
402 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
403 Value operand = op.getOperand();
404 Type opType = operand.getType();
405 Value ln2 = createFloatConst(op->getLoc(), opType, llvm::numbers::ln2, b);
406 Value mult = arith::MulFOp::create(b, opType, operand, ln2);
407 Value exp = math::ExpOp::create(b, op->getLoc(), mult);
408 rewriter.replaceOp(op, exp);
409 return success();
410}
411
// Convert `math.round` on f32 (scalar or shaped) into arith ops plus
// `math.copysign`: add `copysign(0.5, x)` to `x`, drop the fractional part by
// a round trip through i64, and keep `x` itself for inputs whose exponent
// shows that the round trip would be wrong (see the cases listed below).
// Any other element type is reported as a match failure.
static LogicalResult convertRoundOp(math::RoundOp op,
                                    PatternRewriter &rewriter) {
  Location loc = op.getLoc();
  ImplicitLocOpBuilder b(loc, rewriter);
  Value operand = op.getOperand();
  Type opType = operand.getType();
  Type opEType = getElementTypeOrSelf(opType);

  // The constants below (23, 127, 8-bit exponent mask) hard-code the f32
  // layout, so only f32 elements are supported.
  if (!opEType.isF32()) {
    return rewriter.notifyMatchFailure(op, "not a round of f32.");
  }

  // Integer type used to inspect the float bits; same shape as the operand.
  Type i32Ty = b.getI32Type();
  if (auto shapedTy = dyn_cast<ShapedType>(opType))
    i32Ty = shapedTy.clone(i32Ty);

  Value half = createFloatConst(loc, opType, 0.5, b);
  Value c23 = createIntConst(loc, i32Ty, 23, b);
  Value c127 = createIntConst(loc, i32Ty, 127, b);
  Value expMask = createIntConst(loc, i32Ty, (1 << 8) - 1, b);

  // x + copysign(0.5, x), then truncate.
  Value incrValue = math::CopySignOp::create(b, half, operand);
  Value add = arith::AddFOp::create(b, opType, operand, incrValue);
  Value fpFixedConvert = createTruncatedFPValue(add, b);

  // There are three cases where adding 0.5 to the value and truncating by
  // converting to an i64 does not result in the correct behavior:
  //
  // 1. Special values: +-inf and +-nan
  //     Casting these special values to i64 has undefined behavior. To identify
  //     these values, we use the fact that these values are the only float
  //     values with the maximum possible biased exponent.
  //
  // 2. Large values: 2^23 <= |x| <= INT_64_MAX
  //     Adding 0.5 to a float larger than or equal to 2^23 results in precision
  //     errors that sometimes round the value up and sometimes round the value
  //     down. For example:
  //         8388608.0 + 0.5 = 8388608.0
  //         8388609.0 + 0.5 = 8388610.0
  //
  // 3. Very large values: |x| > INT_64_MAX
  //     Casting to i64 a value greater than the max i64 value will overflow the
  //     i64 leading to wrong outputs.
  //
  // All three cases satisfy the property `biasedExp >= 23`.
  //
  // Note: `operandBiasedExp` below is the raw exponent field with the f32 bias
  // (127) subtracted.
  Value operandBitcast = arith::BitcastOp::create(b, i32Ty, operand);
  Value operandExp = arith::AndIOp::create(
      b, arith::ShRUIOp::create(b, operandBitcast, c23), expMask);
  Value operandBiasedExp = arith::SubIOp::create(b, operandExp, c127);
  Value isSpecialValOrLargeVal = arith::CmpIOp::create(
      b, arith::CmpIPredicate::sge, operandBiasedExp, c23);

  // For special/large values return the operand unchanged; otherwise return
  // the truncated `x + copysign(0.5, x)`.
  Value result = arith::SelectOp::create(b, isSpecialValOrLargeVal, operand,
                                         fpFixedConvert);
  rewriter.replaceOp(op, result);
  return success();
}
469
470// Converts math.ctlz to scf and arith operations. This is done
471// by performing a binary search on the bits.
472static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
473 PatternRewriter &rewriter) {
474 auto operand = op.getOperand();
475 auto operandTy = operand.getType();
476 auto eTy = getElementTypeOrSelf(operandTy);
477 Location loc = op.getLoc();
478
479 int32_t bitwidth = eTy.getIntOrFloatBitWidth();
480 if (bitwidth > 64)
481 return failure();
482
483 uint64_t allbits = -1;
484 if (bitwidth < 64) {
485 allbits = allbits >> (64 - bitwidth);
486 }
487
488 Value x = operand;
489 Value count = createIntConst(loc, operandTy, 0, rewriter);
490 for (int32_t bw = bitwidth; bw > 1; bw = bw / 2) {
491 auto half = bw / 2;
492 auto bits = createIntConst(loc, operandTy, half, rewriter);
493 auto mask = createIntConst(loc, operandTy, allbits >> half, rewriter);
494
495 Value pred = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ule,
496 x, mask);
497 Value add = arith::AddIOp::create(rewriter, loc, count, bits);
498 Value shift = arith::ShLIOp::create(rewriter, loc, x, bits);
499
500 x = arith::SelectOp::create(rewriter, loc, pred, shift, x);
501 count = arith::SelectOp::create(rewriter, loc, pred, add, count);
502 }
503
504 Value zero = createIntConst(loc, operandTy, 0, rewriter);
505 Value pred = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
506 operand, zero);
507
508 Value bwval = createIntConst(loc, operandTy, bitwidth, rewriter);
509 Value sel = arith::SelectOp::create(rewriter, loc, pred, bwval, count);
510 rewriter.replaceOp(op, sel);
511 return success();
512}
513
// Convert `math.roundeven` into `math.round` + arith ops
//
// The operand is first rounded with `math.round`; the result is then moved
// one unit back towards zero when the operand was exactly halfway between two
// integers and `math.round` landed on an odd integer. The halfway and parity
// tests are done on the integer bit pattern of the float, with all layout
// constants derived from the operand's float type.
static LogicalResult convertRoundEvenOp(math::RoundEvenOp op,
                                        PatternRewriter &rewriter) {
  Location loc = op.getLoc();
  ImplicitLocOpBuilder b(loc, rewriter);
  auto operand = op.getOperand();
  Type operandTy = operand.getType();
  Type resultTy = op.getType();
  Type operandETy = getElementTypeOrSelf(operandTy);
  Type resultETy = getElementTypeOrSelf(resultTy);

  // NOTE(review): the check accepts any float element type, while the message
  // only mentions f16 and f32.
  if (!isa<FloatType>(operandETy) || !isa<FloatType>(resultETy)) {
    return rewriter.notifyMatchFailure(op, "not a roundeven of f16 or f32.");
  }

  // Float type of the operand and the same-width (and same-shape) integer
  // type used to inspect its bits.
  Type fTy = operandTy;
  Type iTy = rewriter.getIntegerType(operandETy.getIntOrFloatBitWidth());
  if (auto shapedTy = dyn_cast<ShapedType>(fTy)) {
    iTy = shapedTy.clone(iTy);
  }

  unsigned bitWidth = operandETy.getIntOrFloatBitWidth();
  // The width returned by getFPMantissaWidth includes the integer bit.
  unsigned mantissaWidth =
      llvm::cast<FloatType>(operandETy).getFPMantissaWidth() - 1;
  unsigned exponentWidth = bitWidth - mantissaWidth - 1;

  // The names of the variables correspond to f32.
  // f64: 1 bit sign | 11 bits exponent | 52 bits mantissa.
  // f32: 1 bit sign | 8 bits exponent  | 23 bits mantissa.
  // f16: 1 bit sign | 5 bits exponent  | 10 bits mantissa.
  Value c1Float = createFloatConst(loc, fTy, 1.0, b);
  Value c0 = createIntConst(loc, iTy, 0, b);
  Value c1 = createIntConst(loc, iTy, 1, b);
  Value cNeg1 = createIntConst(loc, iTy, -1, b);
  Value c23 = createIntConst(loc, iTy, mantissaWidth, b);
  Value c31 = createIntConst(loc, iTy, bitWidth - 1, b);
  Value c127 = createIntConst(loc, iTy, (1ull << (exponentWidth - 1)) - 1, b);
  Value c2To22 = createIntConst(loc, iTy, 1ull << (mantissaWidth - 1), b);
  Value c23Mask = createIntConst(loc, iTy, (1ull << mantissaWidth) - 1, b);
  Value expMask = createIntConst(loc, iTy, (1ull << exponentWidth) - 1, b);

  Value operandBitcast = arith::BitcastOp::create(b, iTy, operand);
  Value round = math::RoundOp::create(b, operand);
  Value roundBitcast = arith::BitcastOp::create(b, iTy, round);

  // Get biased exponents for operand and round(operand)
  // (here "biased exponent" is the raw exponent field minus the bias `c127`).
  Value operandExp = arith::AndIOp::create(
      b, arith::ShRUIOp::create(b, operandBitcast, c23), expMask);
  Value operandBiasedExp = arith::SubIOp::create(b, operandExp, c127);
  Value roundExp = arith::AndIOp::create(
      b, arith::ShRUIOp::create(b, roundBitcast, c23), expMask);
  Value roundBiasedExp = arith::SubIOp::create(b, roundExp, c127);

  auto safeShiftRight = [&](Value x, Value shift) -> Value {
    // Clamp shift to valid range [0, bitwidth - 1] to avoid undefined behavior
    Value clampedShift = arith::MaxSIOp::create(b, shift, c0);
    clampedShift = arith::MinSIOp::create(b, clampedShift, c31);
    return arith::ShRUIOp::create(b, x, clampedShift);
  };

  // Keeps only the mantissa bits of `mantissa` that remain after dropping the
  // leftmost `mantissaMaskRightShift` mantissa bits.
  auto maskMantissa = [&](Value mantissa,
                          Value mantissaMaskRightShift) -> Value {
    Value shiftedMantissaMask = safeShiftRight(c23Mask, mantissaMaskRightShift);
    return arith::AndIOp::create(b, mantissa, shiftedMantissaMask);
  };

  // A whole number `x`, such that `|x| != 1`, is even if the mantissa, ignoring
  // the leftmost `clamp(biasedExp - 1, 0, 23)` bits, is zero. Large numbers
  // with `biasedExp > 23` (numbers where there is not enough precision to store
  // decimals) are always even, and they satisfy the even condition trivially
  // since the mantissa without all its bits is zero. The even condition
  // is also true for +-0, since they have `biasedExp = -127` and the entire
  // mantissa is zero. The case of +-1 has to be handled separately. Here
  // we identify these values by noting that +-1 are the only whole numbers with
  // `biasedExp == 0`.
  //
  // The special values +-inf and +-nan also satisfy the same property that
  // whole non-unit even numbers satisfy. In particular, the special values have
  // `biasedExp > 23`, so they get treated as large numbers with no room for
  // decimals, which are always even.
  Value roundBiasedExpEq0 =
      arith::CmpIOp::create(b, arith::CmpIPredicate::eq, roundBiasedExp, c0);
  Value roundBiasedExpMinus1 = arith::SubIOp::create(b, roundBiasedExp, c1);
  Value roundMaskedMantissa = maskMantissa(roundBitcast, roundBiasedExpMinus1);
  Value roundIsNotEvenOrSpecialVal = arith::CmpIOp::create(
      b, arith::CmpIPredicate::ne, roundMaskedMantissa, c0);
  roundIsNotEvenOrSpecialVal =
      arith::OrIOp::create(b, roundIsNotEvenOrSpecialVal, roundBiasedExpEq0);

  // A value `x` with `0 <= biasedExp < 23`, is halfway between two consecutive
  // integers if the bit at index `biasedExp` starting from the left in the
  // mantissa is 1 and all the bits to the right are zero. Values with
  // `biasedExp >= 23` don't have decimals, so they are never halfway. The
  // values +-0.5 are the only halfway values that have `biasedExp == -1 < 0`,
  // so these are handled separately. In particular, if `biasedExp == -1`, the
  // value is halfway if the entire mantissa is zero.
  Value operandBiasedExpEqNeg1 = arith::CmpIOp::create(
      b, arith::CmpIPredicate::eq, operandBiasedExp, cNeg1);
  Value expectedOperandMaskedMantissa = arith::SelectOp::create(
      b, operandBiasedExpEqNeg1, c0, safeShiftRight(c2To22, operandBiasedExp));
  Value operandMaskedMantissa = maskMantissa(operandBitcast, operandBiasedExp);
  Value operandIsHalfway =
      arith::CmpIOp::create(b, arith::CmpIPredicate::eq, operandMaskedMantissa,
                            expectedOperandMaskedMantissa);
  // Ensure `biasedExp` is in the valid range for half values.
  Value operandBiasedExpGeNeg1 = arith::CmpIOp::create(
      b, arith::CmpIPredicate::sge, operandBiasedExp, cNeg1);
  Value operandBiasedExpLt23 = arith::CmpIOp::create(
      b, arith::CmpIPredicate::slt, operandBiasedExp, c23);
  operandIsHalfway =
      arith::AndIOp::create(b, operandIsHalfway, operandBiasedExpLt23);
  operandIsHalfway =
      arith::AndIOp::create(b, operandIsHalfway, operandBiasedExpGeNeg1);

  // Adjust rounded operand with `round(operand) - sign(operand)` to correct the
  // case where `round` rounded in the opposite direction of `roundeven`.
  Value sign = math::CopySignOp::create(b, c1Float, operand);
  Value roundShifted = arith::SubFOp::create(b, round, sign);
  // If the rounded value is even or a special value, we default to the behavior
  // of `math.round`.
  Value needsShift =
      arith::AndIOp::create(b, roundIsNotEvenOrSpecialVal, operandIsHalfway);
  Value result = arith::SelectOp::create(b, needsShift, roundShifted, round);
  // The `x - sign` adjustment does not preserve the sign when we are adjusting
  // the value -1 to -0. So here the sign is copied again to ensure that -0.5 is
  // rounded to -0.0.
  result = math::CopySignOp::create(b, result, operand);
  rewriter.replaceOp(op, result);
  return success();
}
645
646// Convert `math.rsqrt` into `arith.divf` + `math.sqrt`
647static LogicalResult convertRsqrtOp(math::RsqrtOp op,
648 PatternRewriter &rewriter) {
649
650 auto operand = op.getOperand();
651 auto operandTy = operand.getType();
652 // Operand type must be shatic shaped type to create const float.
653 auto shapedOperandType = dyn_cast<ShapedType>(operandTy);
654 if (shapedOperandType && !shapedOperandType.hasStaticShape())
655 return failure();
656
657 auto eTy = getElementTypeOrSelf(operandTy);
658 if (!isa<FloatType>(eTy))
659 return failure();
660
661 Location loc = op->getLoc();
662 auto constOneFloat = createFloatConst(loc, operandTy, 1.0, rewriter);
663 auto sqrtOp = math::SqrtOp::create(rewriter, loc, operand);
664 rewriter.replaceOpWithNewOp<arith::DivFOp>(op, constOneFloat, sqrtOp);
665 return success();
666}
667
668// Convert `math.clampf` into `arith.minimumf` + `arith.maximumf`
669static LogicalResult convertClampfOp(math::ClampFOp op,
670 PatternRewriter &rewriter) {
671 auto minOp = arith::MinimumFOp::create(rewriter, op.getLoc(), op.getValue(),
672 op.getMin(), op.getFastmath());
673 rewriter.replaceOpWithNewOp<arith::MaximumFOp>(op, minOp, op.getMax(),
674 op.getFastmath());
675 return success();
676}
677
679 ArrayRef<StringRef> opMnemonics) {
680 auto filter = [&](StringRef name) {
681 // This should be a static assert and `consume_front` take a twine, but none
682 // is currently possible. TODO: augment `StringRef::consume_front` and make
683 // `getDialectNamespace` use `std::string_view`.
684 assert("math" == MathDialect::getDialectNamespace());
685 name.consume_front("math.");
686 return opMnemonics.empty() || (llvm::count(opMnemonics, name) > 0);
687 };
688 if (filter(CountLeadingZerosOp::getOperationName()))
690 if (filter(SinhOp::getOperationName()))
692 if (filter(CoshOp::getOperationName()))
694 if (filter(TanOp::getOperationName()))
696 if (filter(TanhOp::getOperationName()))
698 if (filter(AsinhOp::getOperationName()))
700 if (filter(AcoshOp::getOperationName()))
702 if (filter(AtanhOp::getOperationName()))
704 if (filter(FmaOp::getOperationName()))
706 if (filter(CeilOp::getOperationName()))
708 if (filter(Exp2Op::getOperationName()))
710 if (filter(PowFOp::getOperationName()))
712 if (filter(FPowIOp::getOperationName()))
714 if (filter(RoundOp::getOperationName()))
716 if (filter(RoundEvenOp::getOperationName()))
718 if (filter(RsqrtOp::getOperationName()))
720 if (filter(ClampFOp::getOperationName()))
722}
723
724//===----------------------------------------------------------------------===//
725// MathExpandOpsPass pass
726//===----------------------------------------------------------------------===//
727namespace {
728struct MathExpandOpsPass final
729 : math::impl::MathExpandOpsPassBase<MathExpandOpsPass> {
730 using MathExpandOpsPassBase::MathExpandOpsPassBase;
731
732 void runOnOperation() override {
734 SmallVector<StringRef> mnemonics =
735 llvm::to_vector_of<StringRef>(opMnemonics);
737 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
738 return signalPassFailure();
739 }
740};
741} // namespace
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
static LogicalResult convertRsqrtOp(math::RsqrtOp op, PatternRewriter &rewriter)
static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b)
Definition ExpandOps.cpp:62
static LogicalResult convertFPowIOp(math::FPowIOp op, PatternRewriter &rewriter)
static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter)
static LogicalResult convertClampfOp(math::ClampFOp op, PatternRewriter &rewriter)
static LogicalResult convertRoundOp(math::RoundOp op, PatternRewriter &rewriter)
static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter)
static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op, PatternRewriter &rewriter)
static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter)
static LogicalResult convertAtanhOp(math::AtanhOp op, PatternRewriter &rewriter)
static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter)
static LogicalResult convertRoundEvenOp(math::RoundEvenOp op, PatternRewriter &rewriter)
static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter)
Definition ExpandOps.cpp:91
static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter)
Definition ExpandOps.cpp:75
static Value createFloatConst(Location loc, Type type, APFloat value, OpBuilder &b)
Create a float constant.
Definition ExpandOps.cpp:29
static LogicalResult convertAsinhOp(math::AsinhOp op, PatternRewriter &rewriter)
static Value createIntConst(Location loc, Type type, int64_t value, OpBuilder &b)
Create an integer constant.
Definition ExpandOps.cpp:51
static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter)
Expands tanh op into 1-exp^{-2x} / 1+exp^{-2x} To avoid overflow we exploit the reflection symmetry t...
static LogicalResult convertAcoshOp(math::AcoshOp op, PatternRewriter &rewriter)
static LogicalResult convertExp2fOp(math::Exp2Op op, PatternRewriter &rewriter)
#define add(a, b)
#define div(a, b)
Attributes are known-constant values of operations.
Definition Attributes.h:25
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:67
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:630
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:207
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF32() const
Definition Types.cpp:40
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:122
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
void populateExpansionPatterns(RewritePatternSet &patterns, ArrayRef< StringRef > opMnemonics={})
Adds patterns to expand math operations into other more fundamental operations.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
detail::constant_float_value_binder m_ConstantFloat(FloatAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor float (splat) and writes the float value to bind_va...
Definition Matchers.h:520