MLIR 22.0.0git
ArithToAPFloat.cpp
Go to the documentation of this file.
1//===- ArithToAPFloat.cpp - Arithmetic to APFloat Conversion --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
17#include "mlir/IR/Verifier.h"
19
20namespace mlir {
21#define GEN_PASS_DEF_ARITHTOAPFLOATCONVERSIONPASS
22#include "mlir/Conversion/Passes.h.inc"
23} // namespace mlir
24
25using namespace mlir;
26using namespace mlir::func;
27
28static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable,
29 StringRef name, FunctionType funcT, bool setPrivate,
30 SymbolTableCollection *symbolTables = nullptr) {
32 assert(!symTable->getRegion(0).empty() && "expected non-empty region");
33 b.setInsertionPointToStart(&symTable->getRegion(0).front());
34 FuncOp funcOp = FuncOp::create(b, symTable->getLoc(), name, funcT);
35 if (setPrivate)
36 funcOp.setPrivate();
37 if (symbolTables) {
38 SymbolTable &symbolTable = symbolTables->getSymbolTable(symTable);
39 symbolTable.insert(funcOp, symTable->getRegion(0).front().begin());
40 }
41 return funcOp;
42}
43
44/// Helper function to look up or create the symbol for a runtime library
45/// function with the given parameter types. Returns an int64_t, unless a
46/// different result type is specified.
47static FailureOr<FuncOp>
48lookupOrCreateApFloatFn(OpBuilder &b, SymbolOpInterface symTable,
49 StringRef name, TypeRange paramTypes,
50 SymbolTableCollection *symbolTables = nullptr,
51 Type resultType = {}) {
52 if (!resultType)
53 resultType = IntegerType::get(symTable->getContext(), 64);
54 std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str();
55 auto funcT = FunctionType::get(b.getContext(), paramTypes, {resultType});
56 FailureOr<FuncOp> func =
57 lookupFnDecl(symTable, funcName, funcT, symbolTables);
58 // Failed due to type mismatch.
59 if (failed(func))
60 return func;
61 // Successfully matched existing decl.
62 if (*func)
63 return *func;
64
65 return createFnDecl(b, symTable, funcName, funcT,
66 /*setPrivate=*/true, symbolTables);
67}
68
69/// Helper function to look up or create the symbol for a runtime library
70/// function for a binary arithmetic operation.
71///
72/// Parameter 1: APFloat semantics
73/// Parameter 2: Left-hand side operand
74/// Parameter 3: Right-hand side operand
75///
76/// This function will return a failure if the function is found but has an
77/// unexpected signature.
78///
79static FailureOr<FuncOp>
80lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
81 SymbolTableCollection *symbolTables = nullptr) {
82 auto i32Type = IntegerType::get(symTable->getContext(), 32);
83 auto i64Type = IntegerType::get(symTable->getContext(), 64);
84 return lookupOrCreateApFloatFn(b, symTable, name, {i32Type, i64Type, i64Type},
85 symbolTables);
86}
87
88static Value getSemanticsValue(OpBuilder &b, Location loc, FloatType floatTy) {
89 int32_t sem = llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
90 return arith::ConstantOp::create(b, loc, b.getI32Type(),
91 b.getIntegerAttr(b.getI32Type(), sem));
92}
93
94/// Given two operands of vector type and vector result type (with the same
95/// shape), call the given function for each pair of scalar operands and
96/// package the result into a vector. If the given operands and result type are
97/// not vectors, call the function directly. The second operand is optional.
98template <typename Fn, typename... Values>
100 Value operand1, Value operand2, Type resultType,
101 Fn fn) {
102 auto vecTy1 = dyn_cast<VectorType>(operand1.getType());
103 if (operand2) {
104 // Sanity check: Operand types must match.
105 assert(vecTy1 == dyn_cast<VectorType>(operand2.getType()) &&
106 "expected same vector types");
107 }
108 if (!vecTy1) {
109 // Not a vector. Call the function directly.
110 return fn(operand1, operand2, resultType);
111 }
112
113 // Prepare scalar operands.
114 ResultRange sclars1 =
115 vector::ToElementsOp::create(rewriter, loc, operand1)->getResults();
116 SmallVector<Value> scalars2;
117 if (!operand2) {
118 // No second operand. Create a vector of empty values.
119 scalars2.assign(vecTy1.getNumElements(), Value());
120 } else {
121 llvm::append_range(
122 scalars2,
123 vector::ToElementsOp::create(rewriter, loc, operand2)->getResults());
124 }
125
126 // Call the function for each pair of scalar operands.
127 auto resultVecType = cast<VectorType>(resultType);
128 SmallVector<Value> results;
129 for (auto [scalar1, scalar2] : llvm::zip_equal(sclars1, scalars2)) {
130 Value result = fn(scalar1, scalar2, resultVecType.getElementType());
131 results.push_back(result);
132 }
133
134 // Package the results into a vector.
135 return vector::FromElementsOp::create(
136 rewriter, loc,
137 vecTy1.cloneWith(/*shape=*/std::nullopt, results.front().getType()),
138 results);
139}
140
141/// Check preconditions for the conversion:
142/// 1. All operands / results must be integers or floats (or vectors thereof).
143/// 2. The bitwidth of the operands / results must be <= 64.
144static LogicalResult checkPreconditions(RewriterBase &rewriter, Operation *op) {
145 for (Value value : llvm::concat<Value>(op->getOperands(), op->getResults())) {
146 Type type = value.getType();
147 if (auto vecTy = dyn_cast<VectorType>(type)) {
148 type = vecTy.getElementType();
149 }
150 if (!type.isIntOrFloat()) {
151 return rewriter.notifyMatchFailure(
152 op, "only integers and floats (or vectors thereof) are supported");
153 }
154 if (type.getIntOrFloatBitWidth() > 64)
155 return rewriter.notifyMatchFailure(op,
156 "bitwidth > 64 bits is not supported");
157 }
158 return success();
159}
160
161/// Rewrite a binary arithmetic operation to an APFloat function call.
162template <typename OpTy>
165 const char *APFloatName,
166 SymbolOpInterface symTable,
167 PatternBenefit benefit = 1)
168 : OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
170
171 LogicalResult matchAndRewrite(OpTy op,
172 PatternRewriter &rewriter) const override {
173 if (failed(checkPreconditions(rewriter, op)))
174 return failure();
175
176 // Get APFloat function from runtime library.
177 FailureOr<FuncOp> fn =
179 if (failed(fn))
180 return fn;
181
182 // Scalarize and convert to APFloat runtime calls.
183 Location loc = op.getLoc();
184 rewriter.setInsertionPoint(op);
186 rewriter, loc, op.getLhs(), op.getRhs(), op.getType(),
187 [&](Value lhs, Value rhs, Type resultType) {
188 // Cast operands to 64-bit integers.
189 auto floatTy = cast<FloatType>(resultType);
190 auto intWType = rewriter.getIntegerType(floatTy.getWidth());
191 auto int64Type = rewriter.getI64Type();
192 Value lhsBits = arith::ExtUIOp::create(
193 rewriter, loc, int64Type,
194 arith::BitcastOp::create(rewriter, loc, intWType, lhs));
195 Value rhsBits = arith::ExtUIOp::create(
196 rewriter, loc, int64Type,
197 arith::BitcastOp::create(rewriter, loc, intWType, rhs));
198
199 // Call APFloat function.
200 Value semValue = getSemanticsValue(rewriter, loc, floatTy);
201 SmallVector<Value> params = {semValue, lhsBits, rhsBits};
202 auto resultOp = func::CallOp::create(rewriter, loc,
203 TypeRange(rewriter.getI64Type()),
204 SymbolRefAttr::get(*fn), params);
205
206 // Truncate result to the original width.
207 Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType,
208 resultOp->getResult(0));
209 return arith::BitcastOp::create(rewriter, loc, floatTy,
210 truncatedBits);
211 });
212 rewriter.replaceOp(op, repl);
213 return success();
216 SymbolOpInterface symTable;
217 const char *APFloatName;
218};
219
220template <typename OpTy>
221struct FpToFpConversion final : OpRewritePattern<OpTy> {
222 FpToFpConversion(MLIRContext *context, SymbolOpInterface symTable,
223 PatternBenefit benefit = 1)
224 : OpRewritePattern<OpTy>(context, benefit), symTable(symTable) {}
226 LogicalResult matchAndRewrite(OpTy op,
227 PatternRewriter &rewriter) const override {
228 if (failed(checkPreconditions(rewriter, op)))
229 return failure();
230
231 // Get APFloat function from runtime library.
232 auto i32Type = IntegerType::get(symTable->getContext(), 32);
233 auto i64Type = IntegerType::get(symTable->getContext(), 64);
234 FailureOr<FuncOp> fn = lookupOrCreateApFloatFn(
235 rewriter, symTable, "convert", {i32Type, i32Type, i64Type});
236 if (failed(fn))
237 return fn;
238
239 // Scalarize and convert to APFloat runtime calls.
240 Location loc = op.getLoc();
241 rewriter.setInsertionPoint(op);
243 rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(),
244 [&](Value operand1, Value operand2, Type resultType) {
245 // Cast operands to 64-bit integers.
246 auto inFloatTy = cast<FloatType>(operand1.getType());
247 auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth());
248 Value operandBits = arith::ExtUIOp::create(
249 rewriter, loc, i64Type,
250 arith::BitcastOp::create(rewriter, loc, inIntWType, operand1));
251
252 // Call APFloat function.
253 Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
254 auto outFloatTy = cast<FloatType>(resultType);
255 Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
256 std::array<Value, 3> params = {inSemValue, outSemValue, operandBits};
257 auto resultOp = func::CallOp::create(rewriter, loc,
259 SymbolRefAttr::get(*fn), params);
260
261 // Truncate result to the original width.
262 auto outIntWType = rewriter.getIntegerType(outFloatTy.getWidth());
263 Value truncatedBits = arith::TruncIOp::create(
264 rewriter, loc, outIntWType, resultOp->getResult(0));
265 return arith::BitcastOp::create(rewriter, loc, outFloatTy,
266 truncatedBits);
267 });
268 rewriter.replaceOp(op, repl);
269 return success();
270 }
271
272 SymbolOpInterface symTable;
273};
274
275template <typename OpTy>
277 FpToIntConversion(MLIRContext *context, SymbolOpInterface symTable,
278 bool isUnsigned, PatternBenefit benefit = 1)
279 : OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
281
282 LogicalResult matchAndRewrite(OpTy op,
283 PatternRewriter &rewriter) const override {
284 if (failed(checkPreconditions(rewriter, op)))
285 return failure();
286
287 // Get APFloat function from runtime library.
288 auto i1Type = IntegerType::get(symTable->getContext(), 1);
289 auto i32Type = IntegerType::get(symTable->getContext(), 32);
290 auto i64Type = IntegerType::get(symTable->getContext(), 64);
291 FailureOr<FuncOp> fn =
292 lookupOrCreateApFloatFn(rewriter, symTable, "convert_to_int",
293 {i32Type, i32Type, i1Type, i64Type});
294 if (failed(fn))
295 return fn;
296
297 // Scalarize and convert to APFloat runtime calls.
298 Location loc = op.getLoc();
299 rewriter.setInsertionPoint(op);
301 rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(),
302 [&](Value operand1, Value operand2, Type resultType) {
303 // Cast operands to 64-bit integers.
304 auto inFloatTy = cast<FloatType>(operand1.getType());
305 auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth());
306 Value operandBits = arith::ExtUIOp::create(
307 rewriter, loc, i64Type,
308 arith::BitcastOp::create(rewriter, loc, inIntWType, operand1));
309
310 // Call APFloat function.
311 Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy);
312 auto outIntTy = cast<IntegerType>(resultType);
313 Value outWidthValue = arith::ConstantOp::create(
314 rewriter, loc, i32Type,
315 rewriter.getIntegerAttr(i32Type, outIntTy.getWidth()));
316 Value isUnsignedValue = arith::ConstantOp::create(
317 rewriter, loc, i1Type,
318 rewriter.getIntegerAttr(i1Type, isUnsigned));
319 SmallVector<Value> params = {inSemValue, outWidthValue,
320 isUnsignedValue, operandBits};
321 auto resultOp = func::CallOp::create(rewriter, loc,
322 TypeRange(rewriter.getI64Type()),
323 SymbolRefAttr::get(*fn), params);
324
325 // Truncate result to the original width.
326 return arith::TruncIOp::create(rewriter, loc, outIntTy,
327 resultOp->getResult(0));
328 });
329 rewriter.replaceOp(op, repl);
330 return success();
331 }
332
333 SymbolOpInterface symTable;
335};
336
337template <typename OpTy>
339 IntToFpConversion(MLIRContext *context, SymbolOpInterface symTable,
340 bool isUnsigned, PatternBenefit benefit = 1)
341 : OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
343
344 LogicalResult matchAndRewrite(OpTy op,
345 PatternRewriter &rewriter) const override {
346 if (failed(checkPreconditions(rewriter, op)))
347 return failure();
348
349 // Get APFloat function from runtime library.
350 auto i1Type = IntegerType::get(symTable->getContext(), 1);
351 auto i32Type = IntegerType::get(symTable->getContext(), 32);
352 auto i64Type = IntegerType::get(symTable->getContext(), 64);
353 FailureOr<FuncOp> fn =
354 lookupOrCreateApFloatFn(rewriter, symTable, "convert_from_int",
355 {i32Type, i32Type, i1Type, i64Type});
356 if (failed(fn))
357 return fn;
358
359 // Scalarize and convert to APFloat runtime calls.
360 Location loc = op.getLoc();
361 rewriter.setInsertionPoint(op);
363 rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(),
364 [&](Value operand1, Value operand2, Type resultType) {
365 // Cast operands to 64-bit integers.
366 auto inIntTy = cast<IntegerType>(operand1.getType());
367 Value operandBits = operand1;
368 if (operandBits.getType().getIntOrFloatBitWidth() < 64) {
369 if (isUnsigned) {
370 operandBits =
371 arith::ExtUIOp::create(rewriter, loc, i64Type, operandBits);
372 } else {
373 operandBits =
374 arith::ExtSIOp::create(rewriter, loc, i64Type, operandBits);
375 }
376 }
377
378 // Call APFloat function.
379 auto outFloatTy = cast<FloatType>(resultType);
380 Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
381 Value inWidthValue = arith::ConstantOp::create(
382 rewriter, loc, i32Type,
383 rewriter.getIntegerAttr(i32Type, inIntTy.getWidth()));
384 Value isUnsignedValue = arith::ConstantOp::create(
385 rewriter, loc, i1Type,
386 rewriter.getIntegerAttr(i1Type, isUnsigned));
387 SmallVector<Value> params = {outSemValue, inWidthValue,
388 isUnsignedValue, operandBits};
389 auto resultOp = func::CallOp::create(rewriter, loc,
390 TypeRange(rewriter.getI64Type()),
391 SymbolRefAttr::get(*fn), params);
392
393 // Truncate result to the original width.
394 auto outIntWType = rewriter.getIntegerType(outFloatTy.getWidth());
395 Value truncatedBits = arith::TruncIOp::create(
396 rewriter, loc, outIntWType, resultOp->getResult(0));
397 return arith::BitcastOp::create(rewriter, loc, outFloatTy,
398 truncatedBits);
399 });
400 rewriter.replaceOp(op, repl);
401 return success();
402 }
403
404 SymbolOpInterface symTable;
406};
407
408struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
409 CmpFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable,
410 PatternBenefit benefit = 1)
411 : OpRewritePattern<arith::CmpFOp>(context, benefit), symTable(symTable) {}
412
413 LogicalResult matchAndRewrite(arith::CmpFOp op,
414 PatternRewriter &rewriter) const override {
415 if (failed(checkPreconditions(rewriter, op)))
416 return failure();
417
418 // Get APFloat function from runtime library.
419 auto i1Type = IntegerType::get(symTable->getContext(), 1);
420 auto i8Type = IntegerType::get(symTable->getContext(), 8);
421 auto i32Type = IntegerType::get(symTable->getContext(), 32);
422 auto i64Type = IntegerType::get(symTable->getContext(), 64);
423 FailureOr<FuncOp> fn =
424 lookupOrCreateApFloatFn(rewriter, symTable, "compare",
425 {i32Type, i64Type, i64Type}, nullptr, i8Type);
426 if (failed(fn))
427 return fn;
428
429 // Scalarize and convert to APFloat runtime calls.
430 Location loc = op.getLoc();
431 rewriter.setInsertionPoint(op);
433 rewriter, loc, op.getLhs(), op.getRhs(), op.getType(),
434 [&](Value lhs, Value rhs, Type resultType) {
435 // Cast operands to 64-bit integers.
436 auto floatTy = cast<FloatType>(lhs.getType());
437 auto intWType = rewriter.getIntegerType(floatTy.getWidth());
438 Value lhsBits = arith::ExtUIOp::create(
439 rewriter, loc, i64Type,
440 arith::BitcastOp::create(rewriter, loc, intWType, lhs));
441 Value rhsBits = arith::ExtUIOp::create(
442 rewriter, loc, i64Type,
443 arith::BitcastOp::create(rewriter, loc, intWType, rhs));
444
445 // Call APFloat function.
446 Value semValue = getSemanticsValue(rewriter, loc, floatTy);
447 SmallVector<Value> params = {semValue, lhsBits, rhsBits};
448 Value comparisonResult =
449 func::CallOp::create(rewriter, loc, TypeRange(i8Type),
450 SymbolRefAttr::get(*fn), params)
451 ->getResult(0);
452
453 // Generate an i1 SSA value that is "true" if the comparison result
454 // matches the given `val`.
455 auto checkResult = [&](llvm::APFloat::cmpResult val) {
456 return arith::CmpIOp::create(
457 rewriter, loc, arith::CmpIPredicate::eq, comparisonResult,
458 arith::ConstantOp::create(
459 rewriter, loc, i8Type,
460 rewriter.getIntegerAttr(i8Type, static_cast<int8_t>(val)))
461 .getResult());
462 };
463 // Generate an i1 SSA value that is "true" if the comparison result
464 // matches any of the given `vals`.
466 checkResults = [&](ArrayRef<llvm::APFloat::cmpResult> vals) {
467 Value first = checkResult(vals.front());
468 if (vals.size() == 1)
469 return first;
470 Value rest = checkResults(vals.drop_front());
471 return arith::OrIOp::create(rewriter, loc, first, rest)
472 .getResult();
473 };
474
475 // This switch-case statement was taken from arith::applyCmpPredicate.
477 switch (op.getPredicate()) {
478 case arith::CmpFPredicate::AlwaysFalse:
479 result =
480 arith::ConstantOp::create(rewriter, loc, i1Type,
481 rewriter.getIntegerAttr(i1Type, 0))
482 .getResult();
483 break;
484 case arith::CmpFPredicate::OEQ:
485 result = checkResult(llvm::APFloat::cmpEqual);
486 break;
487 case arith::CmpFPredicate::OGT:
488 result = checkResult(llvm::APFloat::cmpGreaterThan);
489 break;
490 case arith::CmpFPredicate::OGE:
491 result = checkResults(
492 {llvm::APFloat::cmpGreaterThan, llvm::APFloat::cmpEqual});
493 break;
494 case arith::CmpFPredicate::OLT:
495 result = checkResult(llvm::APFloat::cmpLessThan);
496 break;
497 case arith::CmpFPredicate::OLE:
498 result = checkResults(
499 {llvm::APFloat::cmpLessThan, llvm::APFloat::cmpEqual});
500 break;
501 case arith::CmpFPredicate::ONE:
502 // Not cmpUnordered and not cmpUnordered.
503 result = checkResults(
504 {llvm::APFloat::cmpLessThan, llvm::APFloat::cmpGreaterThan});
505 break;
506 case arith::CmpFPredicate::ORD:
507 // Not cmpUnordered.
508 result = checkResults({llvm::APFloat::cmpLessThan,
509 llvm::APFloat::cmpGreaterThan,
510 llvm::APFloat::cmpEqual});
511 break;
512 case arith::CmpFPredicate::UEQ:
513 result = checkResults(
514 {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpEqual});
515 break;
516 case arith::CmpFPredicate::UGT:
517 result = checkResults(
518 {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpGreaterThan});
519 break;
520 case arith::CmpFPredicate::UGE:
521 result = checkResults({llvm::APFloat::cmpUnordered,
522 llvm::APFloat::cmpGreaterThan,
523 llvm::APFloat::cmpEqual});
524 break;
525 case arith::CmpFPredicate::ULT:
526 result = checkResults(
527 {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpLessThan});
528 break;
529 case arith::CmpFPredicate::ULE:
530 result = checkResults({llvm::APFloat::cmpUnordered,
531 llvm::APFloat::cmpLessThan,
532 llvm::APFloat::cmpEqual});
533 break;
534 case arith::CmpFPredicate::UNE:
535 // Not cmpEqual.
536 result = checkResults({llvm::APFloat::cmpLessThan,
537 llvm::APFloat::cmpGreaterThan,
538 llvm::APFloat::cmpUnordered});
539 break;
540 case arith::CmpFPredicate::UNO:
541 result = checkResult(llvm::APFloat::cmpUnordered);
542 break;
543 case arith::CmpFPredicate::AlwaysTrue:
544 result =
545 arith::ConstantOp::create(rewriter, loc, i1Type,
546 rewriter.getIntegerAttr(i1Type, 1))
547 .getResult();
548 break;
549 }
550 return result;
551 });
552 rewriter.replaceOp(op, repl);
553 return success();
554 }
555
556 SymbolOpInterface symTable;
557};
558
559struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
560 NegFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable,
561 PatternBenefit benefit = 1)
562 : OpRewritePattern<arith::NegFOp>(context, benefit), symTable(symTable) {}
563
564 LogicalResult matchAndRewrite(arith::NegFOp op,
565 PatternRewriter &rewriter) const override {
566 if (failed(checkPreconditions(rewriter, op)))
567 return failure();
568
569 // Get APFloat function from runtime library.
570 auto i32Type = IntegerType::get(symTable->getContext(), 32);
571 auto i64Type = IntegerType::get(symTable->getContext(), 64);
572 FailureOr<FuncOp> fn =
573 lookupOrCreateApFloatFn(rewriter, symTable, "neg", {i32Type, i64Type});
574 if (failed(fn))
575 return fn;
576
577 // Scalarize and convert to APFloat runtime calls.
578 Location loc = op.getLoc();
579 rewriter.setInsertionPoint(op);
581 rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(),
582 [&](Value operand1, Value operand2, Type resultType) {
583 // Cast operands to 64-bit integers.
584 auto floatTy = cast<FloatType>(operand1.getType());
585 auto intWType = rewriter.getIntegerType(floatTy.getWidth());
586 Value operandBits = arith::ExtUIOp::create(
587 rewriter, loc, i64Type,
588 arith::BitcastOp::create(rewriter, loc, intWType, operand1));
589
590 // Call APFloat function.
591 Value semValue = getSemanticsValue(rewriter, loc, floatTy);
592 SmallVector<Value> params = {semValue, operandBits};
593 Value negatedBits =
594 func::CallOp::create(rewriter, loc, TypeRange(i64Type),
595 SymbolRefAttr::get(*fn), params)
596 ->getResult(0);
597
598 // Truncate result to the original width.
599 Value truncatedBits =
600 arith::TruncIOp::create(rewriter, loc, intWType, negatedBits);
601 return arith::BitcastOp::create(rewriter, loc, floatTy,
602 truncatedBits);
603 });
604 rewriter.replaceOp(op, repl);
605 return success();
606 }
607
608 SymbolOpInterface symTable;
609};
610
611namespace {
612struct ArithToAPFloatConversionPass final
613 : impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
614 using Base::Base;
615
616 void runOnOperation() override;
617};
618
619void ArithToAPFloatConversionPass::runOnOperation() {
620 MLIRContext *context = &getContext();
621 RewritePatternSet patterns(context);
622 patterns.add<BinaryArithOpToAPFloatConversion<arith::AddFOp>>(context, "add",
623 getOperation());
624 patterns.add<BinaryArithOpToAPFloatConversion<arith::SubFOp>>(
625 context, "subtract", getOperation());
626 patterns.add<BinaryArithOpToAPFloatConversion<arith::MulFOp>>(
627 context, "multiply", getOperation());
628 patterns.add<BinaryArithOpToAPFloatConversion<arith::DivFOp>>(
629 context, "divide", getOperation());
630 patterns.add<BinaryArithOpToAPFloatConversion<arith::RemFOp>>(
631 context, "remainder", getOperation());
632 patterns.add<BinaryArithOpToAPFloatConversion<arith::MinNumFOp>>(
633 context, "minnum", getOperation());
634 patterns.add<BinaryArithOpToAPFloatConversion<arith::MaxNumFOp>>(
635 context, "maxnum", getOperation());
636 patterns.add<BinaryArithOpToAPFloatConversion<arith::MinimumFOp>>(
637 context, "minimum", getOperation());
638 patterns.add<BinaryArithOpToAPFloatConversion<arith::MaximumFOp>>(
639 context, "maximum", getOperation());
641 .add<FpToFpConversion<arith::ExtFOp>, FpToFpConversion<arith::TruncFOp>,
642 CmpFOpToAPFloatConversion, NegFOpToAPFloatConversion>(
643 context, getOperation());
644 patterns.add<FpToIntConversion<arith::FPToSIOp>>(context, getOperation(),
645 /*isUnsigned=*/false);
646 patterns.add<FpToIntConversion<arith::FPToUIOp>>(context, getOperation(),
647 /*isUnsigned=*/true);
648 patterns.add<IntToFpConversion<arith::SIToFPOp>>(context, getOperation(),
649 /*isUnsigned=*/false);
650 patterns.add<IntToFpConversion<arith::UIToFPOp>>(context, getOperation(),
651 /*isUnsigned=*/true);
652 LogicalResult result = success();
653 ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) {
654 if (diag.getSeverity() == DiagnosticSeverity::Error) {
655 result = failure();
656 }
657 // NB: if you don't return failure, no other diag handlers will fire (see
658 // mlir/lib/IR/Diagnostics.cpp:DiagnosticEngineImpl::emit).
659 return failure();
660 });
661 walkAndApplyPatterns(getOperation(), std::move(patterns));
662 if (failed(result))
663 return signalPassFailure();
664}
665} // namespace
return success()
static Value forEachScalarValue(RewriterBase &rewriter, Location loc, Value operand1, Value operand2, Type resultType, Fn fn)
Given two operands of vector type and vector result type (with the same shape), call the given functi...
static LogicalResult checkPreconditions(RewriterBase &rewriter, Operation *op)
Check preconditions for the conversion:
static Value getSemanticsValue(OpBuilder &b, Location loc, FloatType floatTy)
static FailureOr< FuncOp > lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name, SymbolTableCollection *symbolTables=nullptr)
Helper function to look up or create the symbol for a runtime library function for a binary arithmeti...
static FailureOr< FuncOp > lookupOrCreateApFloatFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name, TypeRange paramTypes, SymbolTableCollection *symbolTables=nullptr, Type resultType={})
Helper function to look up or create the symbol for a runtime library function with the given paramet...
static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable, StringRef name, FunctionType funcT, bool setPrivate, SymbolTableCollection *symbolTables=nullptr)
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
static std::string diag(const llvm::Value &value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:228
IntegerType getI64Type()
Definition Builders.cpp:65
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:67
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
result_range getResults()
Definition Operation.h:415
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class implements the result iterators for the Operation class.
Definition ValueRange.h:247
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
This class represents a collection of SymbolTables.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition SymbolTable.h:24
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:116
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:122
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
FailureOr< FuncOp > lookupFnDecl(SymbolOpInterface symTable, StringRef name, FunctionType funcT, SymbolTableCollection *symbolTables=nullptr)
Look up a FuncOp with signature resultTypes(paramTypes) and name / name`.
Definition Utils.cpp:259
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, RewriterBase::Listener *listener=nullptr)
A fast walk-based pattern rewrite driver.
LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override
BinaryArithOpToAPFloatConversion(MLIRContext *context, const char *APFloatName, SymbolOpInterface symTable, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(arith::CmpFOp op, PatternRewriter &rewriter) const override
CmpFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable, PatternBenefit benefit=1)
SymbolOpInterface symTable
FpToFpConversion(MLIRContext *context, SymbolOpInterface symTable, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override
SymbolOpInterface symTable
LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override
FpToIntConversion(MLIRContext *context, SymbolOpInterface symTable, bool isUnsigned, PatternBenefit benefit=1)
SymbolOpInterface symTable
LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override
IntToFpConversion(MLIRContext *context, SymbolOpInterface symTable, bool isUnsigned, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(arith::NegFOp op, PatternRewriter &rewriter) const override
NegFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable, PatternBenefit benefit=1)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})