22OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
return getValue(); }
24void ConstantOp::getAsmResultNames(
26 setNameFn(getResult(),
"cst");
30 if (
auto arrAttr = llvm::dyn_cast<ArrayAttr>(value)) {
31 auto complexTy = llvm::dyn_cast<ComplexType>(type);
32 if (!complexTy || arrAttr.size() != 2)
34 auto complexEltTy = complexTy.getElementType();
35 if (
auto fre = llvm::dyn_cast<FloatAttr>(arrAttr[0])) {
36 auto im = llvm::dyn_cast<FloatAttr>(arrAttr[1]);
37 return im && fre.getType() == complexEltTy &&
38 im.getType() == complexEltTy;
40 if (
auto ire = llvm::dyn_cast<IntegerAttr>(arrAttr[0])) {
41 auto im = llvm::dyn_cast<IntegerAttr>(arrAttr[1]);
42 return im && ire.getType() == complexEltTy &&
43 im.getType() == complexEltTy;
49LogicalResult ConstantOp::verify() {
51 if (arrayAttr.size() != 2) {
53 "requires 'value' to be a complex constant, represented as array of "
57 auto complexEltTy =
getType().getElementType();
58 if (!isa<FloatAttr, IntegerAttr>(arrayAttr[0]) ||
59 !isa<FloatAttr, IntegerAttr>(arrayAttr[1]))
61 "requires attribute's elements to be float or integer attributes");
62 auto re = llvm::dyn_cast<TypedAttr>(arrayAttr[0]);
63 auto im = llvm::dyn_cast<TypedAttr>(arrayAttr[1]);
64 if (complexEltTy != re.getType() || complexEltTy != im.getType()) {
66 <<
"requires attribute's element types (" << re.getType() <<
", "
68 <<
") to match the element type of the op's return type ("
69 << complexEltTy <<
")";
85LogicalResult BitcastOp::verify() {
86 auto operandType = getOperand().getType();
90 if (operandType == resultType)
93 if (!operandType.isIntOrFloat() && !isa<ComplexType>(operandType)) {
94 return emitOpError(
"operand must be int/float/complex");
97 if (!resultType.isIntOrFloat() && !isa<ComplexType>(resultType)) {
98 return emitOpError(
"result must be int/float/complex");
101 if (isa<ComplexType>(operandType) == isa<ComplexType>(resultType)) {
103 "requires that either input or output has a complex type");
106 if (isa<ComplexType>(resultType))
107 std::swap(operandType, resultType);
109 int32_t operandBitwidth = dyn_cast<ComplexType>(operandType)
111 .getIntOrFloatBitWidth() *
113 int32_t resultBitwidth = resultType.getIntOrFloatBitWidth();
115 if (operandBitwidth != resultBitwidth) {
116 return emitOpError(
"casting bitwidths do not match");
127 if (
auto defining = op.getOperand().getDefiningOp<BitcastOp>()) {
128 if (isa<ComplexType>(op.getType()) ||
129 isa<ComplexType>(defining.getOperand().getType())) {
132 defining.getOperand());
135 defining.getOperand());
140 if (
auto defining = op.getOperand().getDefiningOp<arith::BitcastOp>()) {
142 defining.getOperand());
155 if (
auto defining = op.getOperand().getDefiningOp<complex::BitcastOp>()) {
157 defining.getOperand());
176 if (
auto reOp = getOperand(0).getDefiningOp<ReOp>()) {
177 if (
auto imOp = getOperand(1).getDefiningOp<ImOp>()) {
178 if (reOp.getOperand() == imOp.getOperand()) {
179 return reOp.getOperand();
192 llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex());
193 if (arrayAttr && arrayAttr.size() == 2)
195 if (
auto createOp = getOperand().getDefiningOp<CreateOp>())
196 return createOp.getOperand(1);
201template <
typename OpKind,
int ComponentIndex>
203 using OpRewritePattern<OpKind>::OpRewritePattern;
205 LogicalResult matchAndRewrite(OpKind op,
206 PatternRewriter &rewriter)
const override {
207 auto negOp = op.getOperand().template getDefiningOp<NegOp>();
211 auto createOp = negOp.getComplex().template getDefiningOp<CreateOp>();
215 Type elementType = createOp.getType().getElementType();
216 assert(isa<FloatType>(elementType));
219 op, elementType, createOp.getOperand(ComponentIndex));
227 results.
add<FoldComponentNeg<ImOp, 1>>(context);
236 llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex());
237 if (arrayAttr && arrayAttr.size() == 2)
239 if (
auto createOp = getOperand().getDefiningOp<CreateOp>())
240 return createOp.getOperand(0);
246 results.
add<FoldComponentNeg<ReOp, 0>>(context);
255 if (
auto sub = getLhs().getDefiningOp<SubOp>())
256 if (getRhs() == sub.getRhs())
260 if (
auto sub = getRhs().getDefiningOp<SubOp>())
261 if (getLhs() == sub.getRhs())
265 if (
auto constantOp = getRhs().getDefiningOp<ConstantOp>()) {
266 auto arrayAttr = constantOp.getValue();
267 if (llvm::cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
268 llvm::cast<FloatAttr>(arrayAttr[1]).getValue().isZero()) {
282 if (
auto add = getLhs().getDefiningOp<AddOp>())
283 if (getRhs() ==
add.getRhs())
287 if (
auto constantOp = getRhs().getDefiningOp<ConstantOp>()) {
288 auto arrayAttr = constantOp.getValue();
289 if (llvm::cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
290 llvm::cast<FloatAttr>(arrayAttr[1]).getValue().isZero()) {
304 if (
auto negOp = getOperand().getDefiningOp<NegOp>())
305 return negOp.getOperand();
316 if (
auto expOp = getOperand().getDefiningOp<ExpOp>())
317 return expOp.getOperand();
328 if (
auto logOp = getOperand().getDefiningOp<LogOp>())
329 return logOp.getOperand();
340 if (
auto conjOp = getOperand().getDefiningOp<ConjOp>())
341 return conjOp.getOperand();
351 auto constant = getRhs().getDefiningOp<ConstantOp>();
355 ArrayAttr arrayAttr = constant.getValue();
356 APFloat real = cast<FloatAttr>(arrayAttr[0]).getValue();
357 APFloat imag = cast<FloatAttr>(arrayAttr[1]).getValue();
363 if (real == APFloat(real.getSemantics(), 1))
374 auto rhs = adaptor.getRhs();
375 auto lhs = adaptor.getLhs();
382 if (!rhsArrayAttr || rhsArrayAttr.size() != 2)
386 if (!lhsArrayAttr || lhsArrayAttr.size() != 2)
389 APFloat rhsImag = cast<FloatAttr>(rhsArrayAttr[1]).getValue();
390 if (!rhsImag.isZero())
393 APFloat lhsReal = cast<FloatAttr>(lhsArrayAttr[0]).getValue();
394 APFloat lhsImag = cast<FloatAttr>(lhsArrayAttr[1]).getValue();
395 if (lhsReal.isNaN() || lhsImag.isNaN()) {
396 Attribute nanValue = lhsReal.isNaN() ? lhsArrayAttr[0] : lhsArrayAttr[1];
397 return ArrayAttr::get(
getContext(), {nanValue, nanValue});
401 APFloat rhsReal = cast<FloatAttr>(rhsArrayAttr[0]).getValue();
402 if (rhsReal == APFloat(rhsReal.getSemantics(), 1))
412#define GET_OP_CLASSES
413#include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
Attributes are known-constant values of operations.
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents a single result from folding an operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
llvm::function_ref< Fn > function_ref
LogicalResult matchAndRewrite(arith::BitcastOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(BitcastOp op, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})