17 using namespace mlir::complex;
27 void ConstantOp::getAsmResultNames(
29 setNameFn(getResult(),
"cst");
32 bool ConstantOp::isBuildableWith(
Attribute value,
Type type) {
33 if (
auto arrAttr = llvm::dyn_cast<ArrayAttr>(value)) {
34 auto complexTy = llvm::dyn_cast<ComplexType>(type);
35 if (!complexTy || arrAttr.size() != 2)
37 auto complexEltTy = complexTy.getElementType();
38 if (
auto fre = llvm::dyn_cast<FloatAttr>(arrAttr[0])) {
39 auto im = llvm::dyn_cast<FloatAttr>(arrAttr[1]);
40 return im && fre.getType() == complexEltTy &&
41 im.getType() == complexEltTy;
43 if (
auto ire = llvm::dyn_cast<IntegerAttr>(arrAttr[0])) {
44 auto im = llvm::dyn_cast<IntegerAttr>(arrAttr[1]);
45 return im && ire.getType() == complexEltTy &&
46 im.getType() == complexEltTy;
53 ArrayAttr arrayAttr = getValue();
54 if (arrayAttr.size() != 2) {
56 "requires 'value' to be a complex constant, represented as array of "
60 auto complexEltTy =
getType().getElementType();
61 if (!isa<FloatAttr, IntegerAttr>(arrayAttr[0]) ||
62 !isa<FloatAttr, IntegerAttr>(arrayAttr[1]))
64 "requires attribute's elements to be float or integer attributes");
65 auto re = llvm::dyn_cast<TypedAttr>(arrayAttr[0]);
66 auto im = llvm::dyn_cast<TypedAttr>(arrayAttr[1]);
67 if (complexEltTy != re.getType() || complexEltTy != im.getType()) {
69 <<
"requires attribute's element types (" << re.getType() <<
", "
71 <<
") to match the element type of the op's return type ("
72 << complexEltTy <<
")";
89 auto operandType = getOperand().getType();
93 if (operandType == resultType)
96 if (!operandType.isIntOrFloat() && !isa<ComplexType>(operandType)) {
97 return emitOpError(
"operand must be int/float/complex");
100 if (!resultType.isIntOrFloat() && !isa<ComplexType>(resultType)) {
101 return emitOpError(
"result must be int/float/complex");
104 if (isa<ComplexType>(operandType) == isa<ComplexType>(resultType)) {
106 "requires that either input or output has a complex type");
109 if (isa<ComplexType>(resultType))
110 std::swap(operandType, resultType);
112 int32_t operandBitwidth = dyn_cast<ComplexType>(operandType)
114 .getIntOrFloatBitWidth() *
116 int32_t resultBitwidth = resultType.getIntOrFloatBitWidth();
118 if (operandBitwidth != resultBitwidth) {
119 return emitOpError(
"casting bitwidths do not match");
130 if (
auto defining = op.getOperand().getDefiningOp<BitcastOp>()) {
131 if (isa<ComplexType>(op.getType()) ||
132 isa<ComplexType>(defining.getOperand().getType())) {
135 defining.getOperand());
138 defining.getOperand());
143 if (
auto defining = op.getOperand().getDefiningOp<arith::BitcastOp>()) {
145 defining.getOperand());
158 if (
auto defining = op.getOperand().getDefiningOp<complex::BitcastOp>()) {
160 defining.getOperand());
179 if (
auto reOp = getOperand(0).getDefiningOp<ReOp>()) {
180 if (
auto imOp = getOperand(1).getDefiningOp<ImOp>()) {
181 if (reOp.getOperand() == imOp.getOperand()) {
182 return reOp.getOperand();
194 ArrayAttr arrayAttr =
195 llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex());
196 if (arrayAttr && arrayAttr.size() == 2)
198 if (
auto createOp = getOperand().getDefiningOp<CreateOp>())
199 return createOp.getOperand(1);
204 template <
typename OpKind,
int ComponentIndex>
208 LogicalResult matchAndRewrite(OpKind op,
210 auto negOp = op.getOperand().template getDefiningOp<NegOp>();
214 auto createOp = negOp.getComplex().template getDefiningOp<CreateOp>();
218 Type elementType = createOp.getType().getElementType();
219 assert(isa<FloatType>(elementType));
222 op, elementType, createOp.getOperand(ComponentIndex));
230 results.
add<FoldComponentNeg<ImOp, 1>>(context);
238 ArrayAttr arrayAttr =
239 llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex());
240 if (arrayAttr && arrayAttr.size() == 2)
242 if (
auto createOp = getOperand().getDefiningOp<CreateOp>())
243 return createOp.getOperand(0);
249 results.
add<FoldComponentNeg<ReOp, 0>>(context);
258 if (
auto sub = getLhs().getDefiningOp<SubOp>())
259 if (getRhs() == sub.getRhs())
263 if (
auto sub = getRhs().getDefiningOp<SubOp>())
264 if (getLhs() == sub.getRhs())
268 if (
auto constantOp = getRhs().getDefiningOp<ConstantOp>()) {
269 auto arrayAttr = constantOp.getValue();
270 if (llvm::cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
271 llvm::cast<FloatAttr>(arrayAttr[1]).getValue().isZero()) {
285 if (
auto add = getLhs().getDefiningOp<AddOp>())
286 if (getRhs() == add.getRhs())
290 if (
auto constantOp = getRhs().getDefiningOp<ConstantOp>()) {
291 auto arrayAttr = constantOp.getValue();
292 if (llvm::cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
293 llvm::cast<FloatAttr>(arrayAttr[1]).getValue().isZero()) {
307 if (
auto negOp = getOperand().getDefiningOp<NegOp>())
308 return negOp.getOperand();
319 if (
auto expOp = getOperand().getDefiningOp<ExpOp>())
320 return expOp.getOperand();
331 if (
auto logOp = getOperand().getDefiningOp<LogOp>())
332 return logOp.getOperand();
343 if (
auto conjOp = getOperand().getDefiningOp<ConjOp>())
344 return conjOp.getOperand();
354 auto constant = getRhs().getDefiningOp<ConstantOp>();
358 ArrayAttr arrayAttr = constant.getValue();
359 APFloat real = cast<FloatAttr>(arrayAttr[0]).getValue();
360 APFloat imag = cast<FloatAttr>(arrayAttr[1]).getValue();
366 if (real == APFloat(real.getSemantics(), 1))
377 auto rhs = adaptor.getRhs();
381 ArrayAttr arrayAttr = dyn_cast<ArrayAttr>(rhs);
382 if (!arrayAttr || arrayAttr.size() != 2)
385 APFloat real = cast<FloatAttr>(arrayAttr[0]).getValue();
386 APFloat imag = cast<FloatAttr>(arrayAttr[1]).getValue();
392 if (real == APFloat(real.getSemantics(), 1))
402 #define GET_OP_CLASSES
403 #include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc"
Attributes are known-constant values of operations.
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents a single result from folding an operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
LogicalResult matchAndRewrite(arith::BitcastOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(BitcastOp op, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.