22#define GEN_PASS_DEF_ARITHTOAPFLOATCONVERSIONPASS
23#include "mlir/Conversion/Passes.h.inc"
39static FailureOr<FuncOp>
42 auto i32Type = IntegerType::get(symTable->getContext(), 32);
43 auto i64Type = IntegerType::get(symTable->getContext(), 64);
44 std::string funcName = (llvm::Twine(
"_mlir_apfloat_") + name).str();
46 {i32Type, i64Type, i64Type}, symbolTables);
50template <
typename OpTy>
65 FailureOr<FuncOp> fn =
74 rewriter, loc, op.getLhs(), op.getRhs(), op.getType(),
77 auto floatTy = cast<FloatType>(resultType);
78 auto intWType = rewriter.getIntegerType(floatTy.getWidth());
79 auto int64Type = rewriter.getI64Type();
80 Value lhsBits = arith::ExtUIOp::create(
81 rewriter, loc, int64Type,
82 arith::BitcastOp::create(rewriter, loc, intWType, lhs));
83 Value rhsBits = arith::ExtUIOp::create(
84 rewriter, loc, int64Type,
85 arith::BitcastOp::create(rewriter, loc, intWType, rhs));
88 Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
89 SmallVector<Value> params = {semValue, lhsBits, rhsBits};
90 auto resultOp = func::CallOp::create(rewriter, loc,
92 SymbolRefAttr::get(*fn), params);
95 Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType,
96 resultOp->getResult(0));
97 return arith::BitcastOp::create(rewriter, loc, floatTy,
100 rewriter.replaceOp(op, repl);
108template <
typename OpTy>
120 auto i32Type = IntegerType::get(
symTable->getContext(), 32);
121 auto i64Type = IntegerType::get(
symTable->getContext(), 64);
122 FailureOr<FuncOp> fn =
124 {i32Type, i32Type, i64Type});
132 rewriter, loc, op.getOperand(),
Value(), op.getType(),
135 auto inFloatTy = cast<FloatType>(operand1.getType());
136 auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth());
137 Value operandBits = arith::ExtUIOp::create(
138 rewriter, loc, i64Type,
139 arith::BitcastOp::create(rewriter, loc, inIntWType, operand1));
142 Value inSemValue = getAPFloatSemanticsValue(rewriter, loc, inFloatTy);
143 auto outFloatTy = cast<FloatType>(resultType);
145 getAPFloatSemanticsValue(rewriter, loc, outFloatTy);
146 std::array<Value, 3> params = {inSemValue, outSemValue, operandBits};
147 auto resultOp = func::CallOp::create(rewriter, loc,
149 SymbolRefAttr::get(*fn), params);
152 auto outIntWType = rewriter.
getIntegerType(outFloatTy.getWidth());
153 Value truncatedBits = arith::TruncIOp::create(
154 rewriter, loc, outIntWType, resultOp->getResult(0));
155 return arith::BitcastOp::create(rewriter, loc, outFloatTy,
158 rewriter.replaceOp(op, repl);
165template <
typename OpTy>
178 auto i1Type = IntegerType::get(
symTable->getContext(), 1);
179 auto i32Type = IntegerType::get(
symTable->getContext(), 32);
180 auto i64Type = IntegerType::get(
symTable->getContext(), 64);
181 FailureOr<FuncOp> fn =
183 {i32Type, i32Type, i1Type, i64Type});
191 rewriter, loc, op.getOperand(),
Value(), op.getType(),
194 auto inFloatTy = cast<FloatType>(operand1.getType());
195 auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth());
196 Value operandBits = arith::ExtUIOp::create(
197 rewriter, loc, i64Type,
198 arith::BitcastOp::create(rewriter, loc, inIntWType, operand1));
201 Value inSemValue = getAPFloatSemanticsValue(rewriter, loc, inFloatTy);
202 auto outIntTy = cast<IntegerType>(resultType);
203 Value outWidthValue = arith::ConstantOp::create(
204 rewriter, loc, i32Type,
205 rewriter.getIntegerAttr(i32Type, outIntTy.getWidth()));
206 Value isUnsignedValue = arith::ConstantOp::create(
207 rewriter, loc, i1Type,
208 rewriter.getIntegerAttr(i1Type, isUnsigned));
209 SmallVector<Value> params = {inSemValue, outWidthValue,
210 isUnsignedValue, operandBits};
211 auto resultOp = func::CallOp::create(rewriter, loc,
213 SymbolRefAttr::get(*fn), params);
216 return arith::TruncIOp::create(rewriter, loc, outIntTy,
217 resultOp->getResult(0));
219 rewriter.replaceOp(op, repl);
227template <
typename OpTy>
240 auto i1Type = IntegerType::get(
symTable->getContext(), 1);
241 auto i32Type = IntegerType::get(
symTable->getContext(), 32);
242 auto i64Type = IntegerType::get(
symTable->getContext(), 64);
244 rewriter,
symTable,
"_mlir_apfloat_convert_from_int",
245 {i32Type, i32Type, i1Type, i64Type});
253 rewriter, loc, op.getOperand(),
Value(), op.getType(),
256 auto inIntTy = cast<IntegerType>(operand1.getType());
257 Value operandBits = operand1;
258 if (operandBits.getType().getIntOrFloatBitWidth() < 64) {
261 arith::ExtUIOp::create(rewriter, loc, i64Type, operandBits);
264 arith::ExtSIOp::create(rewriter, loc, i64Type, operandBits);
269 auto outFloatTy = cast<FloatType>(resultType);
272 Value inWidthValue = arith::ConstantOp::create(
273 rewriter, loc, i32Type,
275 Value isUnsignedValue = arith::ConstantOp::create(
276 rewriter, loc, i1Type,
279 isUnsignedValue, operandBits};
280 auto resultOp = func::CallOp::create(rewriter, loc,
282 SymbolRefAttr::get(*fn), params);
285 auto outIntWType = rewriter.
getIntegerType(outFloatTy.getWidth());
286 Value truncatedBits = arith::TruncIOp::create(
287 rewriter, loc, outIntWType, resultOp->getResult(0));
288 return arith::BitcastOp::create(rewriter, loc, outFloatTy,
291 rewriter.replaceOp(op, repl);
310 auto i1Type = IntegerType::get(
symTable->getContext(), 1);
311 auto i8Type = IntegerType::get(
symTable->getContext(), 8);
312 auto i32Type = IntegerType::get(
symTable->getContext(), 32);
313 auto i64Type = IntegerType::get(
symTable->getContext(), 64);
314 FailureOr<FuncOp> fn =
316 {i32Type, i64Type, i64Type},
nullptr, i8Type);
324 rewriter, loc, op.getLhs(), op.getRhs(), op.getType(),
327 auto floatTy = cast<FloatType>(lhs.getType());
328 auto intWType = rewriter.getIntegerType(floatTy.getWidth());
329 Value lhsBits = arith::ExtUIOp::create(
330 rewriter, loc, i64Type,
331 arith::BitcastOp::create(rewriter, loc, intWType, lhs));
332 Value rhsBits = arith::ExtUIOp::create(
333 rewriter, loc, i64Type,
334 arith::BitcastOp::create(rewriter, loc, intWType, rhs));
337 Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
338 SmallVector<Value> params = {semValue, lhsBits, rhsBits};
339 Value comparisonResult =
340 func::CallOp::create(rewriter, loc,
TypeRange(i8Type),
341 SymbolRefAttr::get(*fn), params)
346 auto checkResult = [&](llvm::APFloat::cmpResult val) {
347 return arith::CmpIOp::create(
348 rewriter, loc, arith::CmpIPredicate::eq, comparisonResult,
349 arith::ConstantOp::create(
350 rewriter, loc, i8Type,
358 Value first = checkResult(vals.front());
359 if (vals.size() == 1)
361 Value rest = checkResults(vals.drop_front());
362 return arith::OrIOp::create(rewriter, loc, first, rest)
368 switch (op.getPredicate()) {
369 case arith::CmpFPredicate::AlwaysFalse:
371 arith::ConstantOp::create(rewriter, loc, i1Type,
375 case arith::CmpFPredicate::OEQ:
376 result = checkResult(llvm::APFloat::cmpEqual);
378 case arith::CmpFPredicate::OGT:
379 result = checkResult(llvm::APFloat::cmpGreaterThan);
381 case arith::CmpFPredicate::OGE:
383 {llvm::APFloat::cmpGreaterThan, llvm::APFloat::cmpEqual});
385 case arith::CmpFPredicate::OLT:
386 result = checkResult(llvm::APFloat::cmpLessThan);
388 case arith::CmpFPredicate::OLE:
390 {llvm::APFloat::cmpLessThan, llvm::APFloat::cmpEqual});
392 case arith::CmpFPredicate::ONE:
395 {llvm::APFloat::cmpLessThan, llvm::APFloat::cmpGreaterThan});
397 case arith::CmpFPredicate::ORD:
399 result = checkResults({llvm::APFloat::cmpLessThan,
400 llvm::APFloat::cmpGreaterThan,
401 llvm::APFloat::cmpEqual});
403 case arith::CmpFPredicate::UEQ:
405 {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpEqual});
407 case arith::CmpFPredicate::UGT:
409 {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpGreaterThan});
411 case arith::CmpFPredicate::UGE:
412 result = checkResults({llvm::APFloat::cmpUnordered,
413 llvm::APFloat::cmpGreaterThan,
414 llvm::APFloat::cmpEqual});
416 case arith::CmpFPredicate::ULT:
418 {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpLessThan});
420 case arith::CmpFPredicate::ULE:
421 result = checkResults({llvm::APFloat::cmpUnordered,
422 llvm::APFloat::cmpLessThan,
423 llvm::APFloat::cmpEqual});
425 case arith::CmpFPredicate::UNE:
427 result = checkResults({llvm::APFloat::cmpLessThan,
428 llvm::APFloat::cmpGreaterThan,
429 llvm::APFloat::cmpUnordered});
431 case arith::CmpFPredicate::UNO:
432 result = checkResult(llvm::APFloat::cmpUnordered);
434 case arith::CmpFPredicate::AlwaysTrue:
436 arith::ConstantOp::create(rewriter, loc, i1Type,
443 rewriter.replaceOp(op, repl);
452template <
typename OpTy>
466 auto i32Type = IntegerType::get(
symTable->getContext(), 32);
467 auto i64Type = IntegerType::get(
symTable->getContext(), 64);
468 std::string funcName = (llvm::Twine(
"_mlir_apfloat_") +
APFloatName).str();
469 FailureOr<FuncOp> fn =
478 rewriter, loc, op.getOperand(),
Value(), op.getType(),
481 auto floatTy = cast<FloatType>(operand1.getType());
482 auto intWType = rewriter.getIntegerType(floatTy.getWidth());
483 Value operandBits = arith::ExtUIOp::create(
484 rewriter, loc, i64Type,
485 arith::BitcastOp::create(rewriter, loc, intWType, operand1));
488 Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
489 SmallVector<Value> params = {semValue, operandBits};
491 func::CallOp::create(rewriter, loc,
TypeRange(i64Type),
492 SymbolRefAttr::get(*fn), params)
496 Value truncatedBits =
497 arith::TruncIOp::create(rewriter, loc, intWType, resultBits);
498 return arith::BitcastOp::create(rewriter, loc, floatTy,
501 rewriter.replaceOp(op, repl);
510struct ArithToAPFloatConversionPass final
511 : impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
514 void runOnOperation()
override;
517void ArithToAPFloatConversionPass::runOnOperation() {
519 RewritePatternSet patterns(context);
520 patterns.add<BinaryArithOpToAPFloatConversion<arith::AddFOp>>(context,
"add",
522 patterns.add<BinaryArithOpToAPFloatConversion<arith::SubFOp>>(
523 context,
"subtract", getOperation());
524 patterns.add<BinaryArithOpToAPFloatConversion<arith::MulFOp>>(
525 context,
"multiply", getOperation());
526 patterns.add<BinaryArithOpToAPFloatConversion<arith::DivFOp>>(
527 context,
"divide", getOperation());
528 patterns.add<BinaryArithOpToAPFloatConversion<arith::RemFOp>>(
529 context,
"remainder", getOperation());
530 patterns.add<BinaryArithOpToAPFloatConversion<arith::MinNumFOp>>(
531 context,
"minnum", getOperation());
532 patterns.add<BinaryArithOpToAPFloatConversion<arith::MaxNumFOp>>(
533 context,
"maxnum", getOperation());
534 patterns.add<BinaryArithOpToAPFloatConversion<arith::MinimumFOp>>(
535 context,
"minimum", getOperation());
536 patterns.add<BinaryArithOpToAPFloatConversion<arith::MaximumFOp>>(
537 context,
"maximum", getOperation());
538 patterns.add<FpToFpConversion<arith::ExtFOp>,
539 FpToFpConversion<arith::TruncFOp>, CmpFOpToAPFloatConversion>(
540 context, getOperation());
541 patterns.add<UnaryFloatOpToAPFloatConversion<arith::NegFOp>>(context,
"neg",
543 patterns.add<UnaryFloatOpToAPFloatConversion<arith::FlushDenormalsOp>>(
544 context,
"flush_denormals", getOperation());
545 patterns.add<FpToIntConversion<arith::FPToSIOp>>(context, getOperation(),
547 patterns.add<FpToIntConversion<arith::FPToUIOp>>(context, getOperation(),
549 patterns.add<IntToFpConversion<arith::SIToFPOp>>(context, getOperation(),
551 patterns.add<IntToFpConversion<arith::UIToFPOp>>(context, getOperation(),
554 ScopedDiagnosticHandler scopedHandler(context, [&
result](Diagnostic &
diag) {
555 if (
diag.getSeverity() == DiagnosticSeverity::Error) {
564 return signalPassFailure();
static FailureOr< FuncOp > lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name, SymbolTableCollection *symbolTables=nullptr)
Helper function to look up or create the symbol for a runtime library function for a binary arithmeti...
static std::string diag(const llvm::Value &value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a collection of SymbolTables.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
FailureOr< FuncOp > lookupOrCreateFnDecl(OpBuilder &b, SymbolOpInterface symTable, StringRef name, TypeRange paramTypes, SymbolTableCollection *symbolTables=nullptr, Type resultType={})
Helper function to look up or create the symbol for a runtime library function with the given paramet...
Include the generated interface declarations.
LogicalResult checkPreconditions(RewriterBase &rewriter, Operation *op)
Check preconditions for the conversion:
Value getAPFloatSemanticsValue(OpBuilder &b, Location loc, FloatType floatTy)
Value forEachScalarValue(mlir::RewriterBase &rewriter, Location loc, Value operand1, Value operand2, Type resultType, llvm::function_ref< Value(Value, Value, Type)> fn)
Given two operands of vector type and vector result type (with the same shape), call the given functi...
void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, RewriterBase::Listener *listener=nullptr)
A fast walk-based pattern rewrite driver.
LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override
BinaryArithOpToAPFloatConversion(MLIRContext *context, const char *APFloatName, SymbolOpInterface symTable, PatternBenefit benefit=1)
SymbolOpInterface symTable
LogicalResult matchAndRewrite(arith::CmpFOp op, PatternRewriter &rewriter) const override
SymbolOpInterface symTable
CmpFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable, PatternBenefit benefit=1)
SymbolOpInterface symTable
FpToFpConversion(MLIRContext *context, SymbolOpInterface symTable, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override
SymbolOpInterface symTable
LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override
FpToIntConversion(MLIRContext *context, SymbolOpInterface symTable, bool isUnsigned, PatternBenefit benefit=1)
SymbolOpInterface symTable
LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override
IntToFpConversion(MLIRContext *context, SymbolOpInterface symTable, bool isUnsigned, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override
UnaryFloatOpToAPFloatConversion(MLIRContext *context, const char *APFloatName, SymbolOpInterface symTable, PatternBenefit benefit=1)
SymbolOpInterface symTable
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})