26void ConstantOp::getAsmResultNames(
28 setNameFn(getResult(),
"cst");
32 if (
auto arrAttr = llvm::dyn_cast<ArrayAttr>(value)) {
33 auto complexTy = llvm::dyn_cast<ComplexType>(type);
34 if (!complexTy || arrAttr.size() != 2)
36 auto complexEltTy = complexTy.getElementType();
37 if (
auto fre = llvm::dyn_cast<FloatAttr>(arrAttr[0])) {
38 auto im = llvm::dyn_cast<FloatAttr>(arrAttr[1]);
39 return im && fre.getType() == complexEltTy &&
40 im.getType() == complexEltTy;
42 if (
auto ire = llvm::dyn_cast<IntegerAttr>(arrAttr[0])) {
43 auto im = llvm::dyn_cast<IntegerAttr>(arrAttr[1]);
44 return im && ire.getType() == complexEltTy &&
45 im.getType() == complexEltTy;
51LogicalResult ConstantOp::verify() {
53 if (arrayAttr.size() != 2) {
55 "requires 'value' to be a complex constant, represented as array of "
59 auto complexEltTy =
getType().getElementType();
60 if (!isa<FloatAttr, IntegerAttr>(arrayAttr[0]) ||
61 !isa<FloatAttr, IntegerAttr>(arrayAttr[1]))
63 "requires attribute's elements to be float or integer attributes");
64 auto re = llvm::dyn_cast<TypedAttr>(arrayAttr[0]);
65 auto im = llvm::dyn_cast<TypedAttr>(arrayAttr[1]);
66 if (complexEltTy != re.getType() || complexEltTy != im.getType()) {
68 <<
"requires attribute's element types (" << re.getType() <<
", "
70 <<
") to match the element type of the op's return type ("
71 << complexEltTy <<
")";
87LogicalResult BitcastOp::verify() {
88 auto operandType = getOperand().getType();
92 if (operandType == resultType)
95 if (!operandType.isIntOrFloat() && !isa<ComplexType>(operandType)) {
96 return emitOpError(
"operand must be int/float/complex");
99 if (!resultType.isIntOrFloat() && !isa<ComplexType>(resultType)) {
100 return emitOpError(
"result must be int/float/complex");
103 if (isa<ComplexType>(operandType) == isa<ComplexType>(resultType)) {
105 "requires that either input or output has a complex type");
108 if (isa<ComplexType>(resultType))
109 std::swap(operandType, resultType);
111 int32_t operandBitwidth = dyn_cast<ComplexType>(operandType)
113 .getIntOrFloatBitWidth() *
115 int32_t resultBitwidth = resultType.getIntOrFloatBitWidth();
117 if (operandBitwidth != resultBitwidth) {
118 return emitOpError(
"casting bitwidths do not match");
129 if (
auto defining = op.getOperand().getDefiningOp<BitcastOp>()) {
130 if (isa<ComplexType>(op.getType()) ||
131 isa<ComplexType>(defining.getOperand().getType())) {
134 defining.getOperand());
137 defining.getOperand());
142 if (
auto defining = op.getOperand().getDefiningOp<arith::BitcastOp>()) {
144 defining.getOperand());
157 if (
auto defining = op.getOperand().getDefiningOp<complex::BitcastOp>()) {
159 defining.getOperand());
178 if (
auto reOp = getOperand(0).getDefiningOp<ReOp>()) {
179 if (
auto imOp = getOperand(1).getDefiningOp<ImOp>()) {
180 if (reOp.getOperand() == imOp.getOperand()) {
181 return reOp.getOperand();
194 llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex());
195 if (arrayAttr && arrayAttr.size() == 2)
197 if (
auto createOp = getOperand().getDefiningOp<CreateOp>())
198 return createOp.getOperand(1);
203template <
typename OpKind,
int ComponentIndex>
205 using OpRewritePattern<OpKind>::OpRewritePattern;
207 LogicalResult matchAndRewrite(OpKind op,
208 PatternRewriter &rewriter)
const override {
209 auto negOp = op.getOperand().template getDefiningOp<NegOp>();
213 auto createOp = negOp.getComplex().template getDefiningOp<CreateOp>();
217 Type elementType = createOp.getType().getElementType();
218 assert(isa<FloatType>(elementType));
221 op, elementType, createOp.getOperand(ComponentIndex));
229 results.
add<FoldComponentNeg<ImOp, 1>>(context);
238 llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex());
239 if (arrayAttr && arrayAttr.size() == 2)
241 if (
auto createOp = getOperand().getDefiningOp<CreateOp>())
242 return createOp.getOperand(0);
248 results.
add<FoldComponentNeg<ReOp, 0>>(context);
257 if (
auto sub = getLhs().getDefiningOp<SubOp>())
258 if (getRhs() == sub.getRhs())
262 if (
auto sub = getRhs().getDefiningOp<SubOp>())
263 if (getLhs() == sub.getRhs())
267 if (
auto constantOp = getRhs().getDefiningOp<ConstantOp>()) {
268 auto arrayAttr = constantOp.getValue();
269 if (llvm::cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
270 llvm::cast<FloatAttr>(arrayAttr[1]).getValue().isZero()) {
284 if (
auto add = getLhs().getDefiningOp<AddOp>())
285 if (getRhs() ==
add.getRhs())
289 if (
auto constantOp = getRhs().getDefiningOp<ConstantOp>()) {
290 auto arrayAttr = constantOp.getValue();
291 if (llvm::cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
292 llvm::cast<FloatAttr>(arrayAttr[1]).getValue().isZero()) {
306 if (
auto negOp = getOperand().getDefiningOp<NegOp>())
307 return negOp.getOperand();
318 if (
auto expOp = getOperand().getDefiningOp<ExpOp>())
319 return expOp.getOperand();
330 if (
auto logOp = getOperand().getDefiningOp<LogOp>())
331 return logOp.getOperand();
342 if (
auto conjOp = getOperand().getDefiningOp<ConjOp>())
343 return conjOp.getOperand();
353 auto constant = getRhs().getDefiningOp<ConstantOp>();
357 ArrayAttr arrayAttr = constant.getValue();
358 APFloat real = cast<FloatAttr>(arrayAttr[0]).getValue();
359 APFloat imag = cast<FloatAttr>(arrayAttr[1]).getValue();
365 if (real == APFloat(real.getSemantics(), 1))
376 auto rhs = adaptor.getRhs();
381 if (!arrayAttr || arrayAttr.size() != 2)
384 APFloat real = cast<FloatAttr>(arrayAttr[0]).getValue();
385 APFloat imag = cast<FloatAttr>(arrayAttr[1]).getValue();
391 if (real == APFloat(real.getSemantics(), 1))
401#define GET_OP_CLASSES
402#include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
Attributes are known-constant values of operations.
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents a single result from folding an operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
llvm::function_ref< Fn > function_ref
LogicalResult matchAndRewrite(arith::BitcastOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(BitcastOp op, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})