MLIR 22.0.0git
ComplexOps.cpp
Go to the documentation of this file.
1//===- ComplexOps.cpp - MLIR Complex Operations ---------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
11#include "mlir/IR/Builders.h"
14
15using namespace mlir;
16using namespace mlir::complex;
17
18//===----------------------------------------------------------------------===//
19// ConstantOp
20//===----------------------------------------------------------------------===//
21
22OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
23 return getValue();
24}
25
26void ConstantOp::getAsmResultNames(
27 function_ref<void(Value, StringRef)> setNameFn) {
28 setNameFn(getResult(), "cst");
29}
30
31bool ConstantOp::isBuildableWith(Attribute value, Type type) {
32 if (auto arrAttr = llvm::dyn_cast<ArrayAttr>(value)) {
33 auto complexTy = llvm::dyn_cast<ComplexType>(type);
34 if (!complexTy || arrAttr.size() != 2)
35 return false;
36 auto complexEltTy = complexTy.getElementType();
37 if (auto fre = llvm::dyn_cast<FloatAttr>(arrAttr[0])) {
38 auto im = llvm::dyn_cast<FloatAttr>(arrAttr[1]);
39 return im && fre.getType() == complexEltTy &&
40 im.getType() == complexEltTy;
41 }
42 if (auto ire = llvm::dyn_cast<IntegerAttr>(arrAttr[0])) {
43 auto im = llvm::dyn_cast<IntegerAttr>(arrAttr[1]);
44 return im && ire.getType() == complexEltTy &&
45 im.getType() == complexEltTy;
46 }
47 }
48 return false;
49}
50
51LogicalResult ConstantOp::verify() {
52 ArrayAttr arrayAttr = getValue();
53 if (arrayAttr.size() != 2) {
54 return emitOpError(
55 "requires 'value' to be a complex constant, represented as array of "
56 "two values");
57 }
58
59 auto complexEltTy = getType().getElementType();
60 if (!isa<FloatAttr, IntegerAttr>(arrayAttr[0]) ||
61 !isa<FloatAttr, IntegerAttr>(arrayAttr[1]))
62 return emitOpError(
63 "requires attribute's elements to be float or integer attributes");
64 auto re = llvm::dyn_cast<TypedAttr>(arrayAttr[0]);
65 auto im = llvm::dyn_cast<TypedAttr>(arrayAttr[1]);
66 if (complexEltTy != re.getType() || complexEltTy != im.getType()) {
67 return emitOpError()
68 << "requires attribute's element types (" << re.getType() << ", "
69 << im.getType()
70 << ") to match the element type of the op's return type ("
71 << complexEltTy << ")";
72 }
73 return success();
74}
75
76//===----------------------------------------------------------------------===//
77// BitcastOp
78//===----------------------------------------------------------------------===//
79
80OpFoldResult BitcastOp::fold(FoldAdaptor bitcast) {
81 if (getOperand().getType() == getType())
82 return getOperand();
83
84 return {};
85}
86
87LogicalResult BitcastOp::verify() {
88 auto operandType = getOperand().getType();
89 auto resultType = getType();
90
91 // We allow this to be legal as it can be folded away.
92 if (operandType == resultType)
93 return success();
94
95 if (!operandType.isIntOrFloat() && !isa<ComplexType>(operandType)) {
96 return emitOpError("operand must be int/float/complex");
97 }
98
99 if (!resultType.isIntOrFloat() && !isa<ComplexType>(resultType)) {
100 return emitOpError("result must be int/float/complex");
101 }
102
103 if (isa<ComplexType>(operandType) == isa<ComplexType>(resultType)) {
104 return emitOpError(
105 "requires that either input or output has a complex type");
106 }
107
108 if (isa<ComplexType>(resultType))
109 std::swap(operandType, resultType);
110
111 int32_t operandBitwidth = dyn_cast<ComplexType>(operandType)
112 .getElementType()
113 .getIntOrFloatBitWidth() *
114 2;
115 int32_t resultBitwidth = resultType.getIntOrFloatBitWidth();
116
117 if (operandBitwidth != resultBitwidth) {
118 return emitOpError("casting bitwidths do not match");
119 }
120
121 return success();
122}
123
124struct MergeComplexBitcast final : OpRewritePattern<BitcastOp> {
125 using OpRewritePattern<BitcastOp>::OpRewritePattern;
126
127 LogicalResult matchAndRewrite(BitcastOp op,
128 PatternRewriter &rewriter) const override {
129 if (auto defining = op.getOperand().getDefiningOp<BitcastOp>()) {
130 if (isa<ComplexType>(op.getType()) ||
131 isa<ComplexType>(defining.getOperand().getType())) {
132 // complex.bitcast requires that input or output is complex.
133 rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(),
134 defining.getOperand());
135 } else {
136 rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, op.getType(),
137 defining.getOperand());
138 }
139 return success();
140 }
141
142 if (auto defining = op.getOperand().getDefiningOp<arith::BitcastOp>()) {
143 rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(),
144 defining.getOperand());
145 return success();
146 }
147
148 return failure();
149 }
150};
151
152struct MergeArithBitcast final : OpRewritePattern<arith::BitcastOp> {
153 using OpRewritePattern<arith::BitcastOp>::OpRewritePattern;
154
155 LogicalResult matchAndRewrite(arith::BitcastOp op,
156 PatternRewriter &rewriter) const override {
157 if (auto defining = op.getOperand().getDefiningOp<complex::BitcastOp>()) {
158 rewriter.replaceOpWithNewOp<complex::BitcastOp>(op, op.getType(),
159 defining.getOperand());
160 return success();
161 }
162
163 return failure();
164 }
165};
166
167void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
168 MLIRContext *context) {
169 results.add<MergeComplexBitcast, MergeArithBitcast>(context);
170}
171
172//===----------------------------------------------------------------------===//
173// CreateOp
174//===----------------------------------------------------------------------===//
175
176OpFoldResult CreateOp::fold(FoldAdaptor adaptor) {
177 // Fold complex.create(complex.re(op), complex.im(op)).
178 if (auto reOp = getOperand(0).getDefiningOp<ReOp>()) {
179 if (auto imOp = getOperand(1).getDefiningOp<ImOp>()) {
180 if (reOp.getOperand() == imOp.getOperand()) {
181 return reOp.getOperand();
182 }
183 }
184 }
185 return {};
186}
187
188//===----------------------------------------------------------------------===//
189// ImOp
190//===----------------------------------------------------------------------===//
191
192OpFoldResult ImOp::fold(FoldAdaptor adaptor) {
193 ArrayAttr arrayAttr =
194 llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex());
195 if (arrayAttr && arrayAttr.size() == 2)
196 return arrayAttr[1];
197 if (auto createOp = getOperand().getDefiningOp<CreateOp>())
198 return createOp.getOperand(1);
199 return {};
200}
201
202namespace {
203template <typename OpKind, int ComponentIndex>
204struct FoldComponentNeg final : OpRewritePattern<OpKind> {
205 using OpRewritePattern<OpKind>::OpRewritePattern;
206
207 LogicalResult matchAndRewrite(OpKind op,
208 PatternRewriter &rewriter) const override {
209 auto negOp = op.getOperand().template getDefiningOp<NegOp>();
210 if (!negOp)
211 return failure();
212
213 auto createOp = negOp.getComplex().template getDefiningOp<CreateOp>();
214 if (!createOp)
215 return failure();
216
217 Type elementType = createOp.getType().getElementType();
218 assert(isa<FloatType>(elementType));
219
220 rewriter.replaceOpWithNewOp<arith::NegFOp>(
221 op, elementType, createOp.getOperand(ComponentIndex));
222 return success();
223 }
224};
225} // namespace
226
227void ImOp::getCanonicalizationPatterns(RewritePatternSet &results,
228 MLIRContext *context) {
229 results.add<FoldComponentNeg<ImOp, 1>>(context);
230}
231
232//===----------------------------------------------------------------------===//
233// ReOp
234//===----------------------------------------------------------------------===//
235
236OpFoldResult ReOp::fold(FoldAdaptor adaptor) {
237 ArrayAttr arrayAttr =
238 llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex());
239 if (arrayAttr && arrayAttr.size() == 2)
240 return arrayAttr[0];
241 if (auto createOp = getOperand().getDefiningOp<CreateOp>())
242 return createOp.getOperand(0);
243 return {};
244}
245
246void ReOp::getCanonicalizationPatterns(RewritePatternSet &results,
247 MLIRContext *context) {
248 results.add<FoldComponentNeg<ReOp, 0>>(context);
249}
250
251//===----------------------------------------------------------------------===//
252// AddOp
253//===----------------------------------------------------------------------===//
254
255OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
256 // complex.add(complex.sub(a, b), b) -> a
257 if (auto sub = getLhs().getDefiningOp<SubOp>())
258 if (getRhs() == sub.getRhs())
259 return sub.getLhs();
260
261 // complex.add(b, complex.sub(a, b)) -> a
262 if (auto sub = getRhs().getDefiningOp<SubOp>())
263 if (getLhs() == sub.getRhs())
264 return sub.getLhs();
265
266 // complex.add(a, complex.constant<0.0, 0.0>) -> a
267 if (auto constantOp = getRhs().getDefiningOp<ConstantOp>()) {
268 auto arrayAttr = constantOp.getValue();
269 if (llvm::cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
270 llvm::cast<FloatAttr>(arrayAttr[1]).getValue().isZero()) {
271 return getLhs();
272 }
273 }
274
275 return {};
276}
277
278//===----------------------------------------------------------------------===//
279// SubOp
280//===----------------------------------------------------------------------===//
281
282OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
283 // complex.sub(complex.add(a, b), b) -> a
284 if (auto add = getLhs().getDefiningOp<AddOp>())
285 if (getRhs() == add.getRhs())
286 return add.getLhs();
287
288 // complex.sub(a, complex.constant<0.0, 0.0>) -> a
289 if (auto constantOp = getRhs().getDefiningOp<ConstantOp>()) {
290 auto arrayAttr = constantOp.getValue();
291 if (llvm::cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
292 llvm::cast<FloatAttr>(arrayAttr[1]).getValue().isZero()) {
293 return getLhs();
294 }
295 }
296
297 return {};
298}
299
300//===----------------------------------------------------------------------===//
301// NegOp
302//===----------------------------------------------------------------------===//
303
304OpFoldResult NegOp::fold(FoldAdaptor adaptor) {
305 // complex.neg(complex.neg(a)) -> a
306 if (auto negOp = getOperand().getDefiningOp<NegOp>())
307 return negOp.getOperand();
308
309 return {};
310}
311
312//===----------------------------------------------------------------------===//
313// LogOp
314//===----------------------------------------------------------------------===//
315
316OpFoldResult LogOp::fold(FoldAdaptor adaptor) {
317 // complex.log(complex.exp(a)) -> a
318 if (auto expOp = getOperand().getDefiningOp<ExpOp>())
319 return expOp.getOperand();
320
321 return {};
322}
323
324//===----------------------------------------------------------------------===//
325// ExpOp
326//===----------------------------------------------------------------------===//
327
328OpFoldResult ExpOp::fold(FoldAdaptor adaptor) {
329 // complex.exp(complex.log(a)) -> a
330 if (auto logOp = getOperand().getDefiningOp<LogOp>())
331 return logOp.getOperand();
332
333 return {};
334}
335
336//===----------------------------------------------------------------------===//
337// ConjOp
338//===----------------------------------------------------------------------===//
339
340OpFoldResult ConjOp::fold(FoldAdaptor adaptor) {
341 // complex.conj(complex.conj(a)) -> a
342 if (auto conjOp = getOperand().getDefiningOp<ConjOp>())
343 return conjOp.getOperand();
344
345 return {};
346}
347
348//===----------------------------------------------------------------------===//
349// MulOp
350//===----------------------------------------------------------------------===//
351
352OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
353 auto constant = getRhs().getDefiningOp<ConstantOp>();
354 if (!constant)
355 return {};
356
357 ArrayAttr arrayAttr = constant.getValue();
358 APFloat real = cast<FloatAttr>(arrayAttr[0]).getValue();
359 APFloat imag = cast<FloatAttr>(arrayAttr[1]).getValue();
360
361 if (!imag.isZero())
362 return {};
363
364 // complex.mul(a, complex.constant<1.0, 0.0>) -> a
365 if (real == APFloat(real.getSemantics(), 1))
366 return getLhs();
367
368 return {};
369}
370
371//===----------------------------------------------------------------------===//
372// DivOp
373//===----------------------------------------------------------------------===//
374
375OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
376 auto rhs = adaptor.getRhs();
377 if (!rhs)
378 return {};
379
380 ArrayAttr arrayAttr = dyn_cast<ArrayAttr>(rhs);
381 if (!arrayAttr || arrayAttr.size() != 2)
382 return {};
383
384 APFloat real = cast<FloatAttr>(arrayAttr[0]).getValue();
385 APFloat imag = cast<FloatAttr>(arrayAttr[1]).getValue();
386
387 if (!imag.isZero())
388 return {};
389
390 // complex.div(a, complex.constant<1.0, 0.0>) -> a
391 if (real == APFloat(real.getSemantics(), 1))
392 return getLhs();
393
394 return {};
395}
396
397//===----------------------------------------------------------------------===//
398// TableGen'd op method definitions
399//===----------------------------------------------------------------------===//
400
401#define GET_OP_CLASSES
402#include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
ArrayAttr()
#define add(a, b)
Attributes are known-constant values of operations.
Definition Attributes.h:25
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class represents a single result from folding an operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
LogicalResult matchAndRewrite(arith::BitcastOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(BitcastOp op, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})