MLIR 23.0.0git
ComplexOps.cpp
Go to the documentation of this file.
1//===- ComplexOps.cpp - MLIR Complex Operations ---------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
11#include "mlir/IR/Builders.h"
14
15using namespace mlir;
16using namespace mlir::complex;
17
18//===----------------------------------------------------------------------===//
19// ConstantOp
20//===----------------------------------------------------------------------===//
21
22OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
23
24void ConstantOp::getAsmResultNames(
25 function_ref<void(Value, StringRef)> setNameFn) {
26 setNameFn(getResult(), "cst");
27}
28
29bool ConstantOp::isBuildableWith(Attribute value, Type type) {
30 if (auto arrAttr = llvm::dyn_cast<ArrayAttr>(value)) {
31 auto complexTy = llvm::dyn_cast<ComplexType>(type);
32 if (!complexTy || arrAttr.size() != 2)
33 return false;
34 auto complexEltTy = complexTy.getElementType();
35 if (auto fre = llvm::dyn_cast<FloatAttr>(arrAttr[0])) {
36 auto im = llvm::dyn_cast<FloatAttr>(arrAttr[1]);
37 return im && fre.getType() == complexEltTy &&
38 im.getType() == complexEltTy;
39 }
40 if (auto ire = llvm::dyn_cast<IntegerAttr>(arrAttr[0])) {
41 auto im = llvm::dyn_cast<IntegerAttr>(arrAttr[1]);
42 return im && ire.getType() == complexEltTy &&
43 im.getType() == complexEltTy;
44 }
45 }
46 return false;
47}
48
49LogicalResult ConstantOp::verify() {
50 ArrayAttr arrayAttr = getValue();
51 if (arrayAttr.size() != 2) {
52 return emitOpError(
53 "requires 'value' to be a complex constant, represented as array of "
54 "two values");
55 }
56
57 auto complexEltTy = getType().getElementType();
58 if (!isa<FloatAttr, IntegerAttr>(arrayAttr[0]) ||
59 !isa<FloatAttr, IntegerAttr>(arrayAttr[1]))
60 return emitOpError(
61 "requires attribute's elements to be float or integer attributes");
62 auto re = llvm::dyn_cast<TypedAttr>(arrayAttr[0]);
63 auto im = llvm::dyn_cast<TypedAttr>(arrayAttr[1]);
64 if (complexEltTy != re.getType() || complexEltTy != im.getType()) {
65 return emitOpError()
66 << "requires attribute's element types (" << re.getType() << ", "
67 << im.getType()
68 << ") to match the element type of the op's return type ("
69 << complexEltTy << ")";
70 }
71 return success();
72}
73
74//===----------------------------------------------------------------------===//
75// BitcastOp
76//===----------------------------------------------------------------------===//
77
78OpFoldResult BitcastOp::fold(FoldAdaptor bitcast) {
79 if (getOperand().getType() == getType())
80 return getOperand();
81
82 return {};
83}
84
85LogicalResult BitcastOp::verify() {
86 auto operandType = getOperand().getType();
87 auto resultType = getType();
88
89 // We allow this to be legal as it can be folded away.
90 if (operandType == resultType)
91 return success();
92
93 if (!operandType.isIntOrFloat() && !isa<ComplexType>(operandType)) {
94 return emitOpError("operand must be int/float/complex");
95 }
96
97 if (!resultType.isIntOrFloat() && !isa<ComplexType>(resultType)) {
98 return emitOpError("result must be int/float/complex");
99 }
100
101 if (isa<ComplexType>(operandType) == isa<ComplexType>(resultType)) {
102 return emitOpError(
103 "requires that either input or output has a complex type");
104 }
105
106 if (isa<ComplexType>(resultType))
107 std::swap(operandType, resultType);
108
109 int32_t operandBitwidth = dyn_cast<ComplexType>(operandType)
110 .getElementType()
111 .getIntOrFloatBitWidth() *
112 2;
113 int32_t resultBitwidth = resultType.getIntOrFloatBitWidth();
114
115 if (operandBitwidth != resultBitwidth) {
116 return emitOpError("casting bitwidths do not match");
117 }
118
119 return success();
120}
121
122struct MergeComplexBitcast final : OpRewritePattern<BitcastOp> {
123 using OpRewritePattern<BitcastOp>::OpRewritePattern;
124
125 LogicalResult matchAndRewrite(BitcastOp op,
126 PatternRewriter &rewriter) const override {
127 if (auto defining = op.getOperand().getDefiningOp<BitcastOp>()) {
128 if (isa<ComplexType>(op.getType()) ||
129 isa<ComplexType>(defining.getOperand().getType())) {
130 // complex.bitcast requires that input or output is complex.
131 rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(),
132 defining.getOperand());
133 } else {
134 rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, op.getType(),
135 defining.getOperand());
136 }
137 return success();
138 }
139
140 if (auto defining = op.getOperand().getDefiningOp<arith::BitcastOp>()) {
141 rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(),
142 defining.getOperand());
143 return success();
144 }
145
146 return failure();
147 }
148};
149
150struct MergeArithBitcast final : OpRewritePattern<arith::BitcastOp> {
151 using OpRewritePattern<arith::BitcastOp>::OpRewritePattern;
152
153 LogicalResult matchAndRewrite(arith::BitcastOp op,
154 PatternRewriter &rewriter) const override {
155 if (auto defining = op.getOperand().getDefiningOp<complex::BitcastOp>()) {
156 rewriter.replaceOpWithNewOp<complex::BitcastOp>(op, op.getType(),
157 defining.getOperand());
158 return success();
159 }
160
161 return failure();
162 }
163};
164
165void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
166 MLIRContext *context) {
167 results.add<MergeComplexBitcast, MergeArithBitcast>(context);
168}
169
170//===----------------------------------------------------------------------===//
171// CreateOp
172//===----------------------------------------------------------------------===//
173
174OpFoldResult CreateOp::fold(FoldAdaptor adaptor) {
175 // Fold complex.create(complex.re(op), complex.im(op)).
176 if (auto reOp = getOperand(0).getDefiningOp<ReOp>()) {
177 if (auto imOp = getOperand(1).getDefiningOp<ImOp>()) {
178 if (reOp.getOperand() == imOp.getOperand()) {
179 return reOp.getOperand();
180 }
181 }
182 }
183 return {};
184}
185
186//===----------------------------------------------------------------------===//
187// ImOp
188//===----------------------------------------------------------------------===//
189
190OpFoldResult ImOp::fold(FoldAdaptor adaptor) {
191 ArrayAttr arrayAttr =
192 llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex());
193 if (arrayAttr && arrayAttr.size() == 2)
194 return arrayAttr[1];
195 if (auto createOp = getOperand().getDefiningOp<CreateOp>())
196 return createOp.getOperand(1);
197 return {};
198}
199
200namespace {
201template <typename OpKind, int ComponentIndex>
202struct FoldComponentNeg final : OpRewritePattern<OpKind> {
203 using OpRewritePattern<OpKind>::OpRewritePattern;
204
205 LogicalResult matchAndRewrite(OpKind op,
206 PatternRewriter &rewriter) const override {
207 auto negOp = op.getOperand().template getDefiningOp<NegOp>();
208 if (!negOp)
209 return failure();
210
211 auto createOp = negOp.getComplex().template getDefiningOp<CreateOp>();
212 if (!createOp)
213 return failure();
214
215 Type elementType = createOp.getType().getElementType();
216 assert(isa<FloatType>(elementType));
217
218 rewriter.replaceOpWithNewOp<arith::NegFOp>(
219 op, elementType, createOp.getOperand(ComponentIndex));
220 return success();
221 }
222};
223} // namespace
224
225void ImOp::getCanonicalizationPatterns(RewritePatternSet &results,
226 MLIRContext *context) {
227 results.add<FoldComponentNeg<ImOp, 1>>(context);
228}
229
230//===----------------------------------------------------------------------===//
231// ReOp
232//===----------------------------------------------------------------------===//
233
234OpFoldResult ReOp::fold(FoldAdaptor adaptor) {
235 ArrayAttr arrayAttr =
236 llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex());
237 if (arrayAttr && arrayAttr.size() == 2)
238 return arrayAttr[0];
239 if (auto createOp = getOperand().getDefiningOp<CreateOp>())
240 return createOp.getOperand(0);
241 return {};
242}
243
244void ReOp::getCanonicalizationPatterns(RewritePatternSet &results,
245 MLIRContext *context) {
246 results.add<FoldComponentNeg<ReOp, 0>>(context);
247}
248
249//===----------------------------------------------------------------------===//
250// AddOp
251//===----------------------------------------------------------------------===//
252
253OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
254 // complex.add(complex.sub(a, b), b) -> a
255 if (auto sub = getLhs().getDefiningOp<SubOp>())
256 if (getRhs() == sub.getRhs())
257 return sub.getLhs();
258
259 // complex.add(b, complex.sub(a, b)) -> a
260 if (auto sub = getRhs().getDefiningOp<SubOp>())
261 if (getLhs() == sub.getRhs())
262 return sub.getLhs();
263
264 // complex.add(a, complex.constant<0.0, 0.0>) -> a
265 if (auto constantOp = getRhs().getDefiningOp<ConstantOp>()) {
266 auto arrayAttr = constantOp.getValue();
267 if (llvm::cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
268 llvm::cast<FloatAttr>(arrayAttr[1]).getValue().isZero()) {
269 return getLhs();
270 }
271 }
272
273 return {};
274}
275
276//===----------------------------------------------------------------------===//
277// SubOp
278//===----------------------------------------------------------------------===//
279
280OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
281 // complex.sub(complex.add(a, b), b) -> a
282 if (auto add = getLhs().getDefiningOp<AddOp>())
283 if (getRhs() == add.getRhs())
284 return add.getLhs();
285
286 // complex.sub(a, complex.constant<0.0, 0.0>) -> a
287 if (auto constantOp = getRhs().getDefiningOp<ConstantOp>()) {
288 auto arrayAttr = constantOp.getValue();
289 if (llvm::cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
290 llvm::cast<FloatAttr>(arrayAttr[1]).getValue().isZero()) {
291 return getLhs();
292 }
293 }
294
295 return {};
296}
297
298//===----------------------------------------------------------------------===//
299// NegOp
300//===----------------------------------------------------------------------===//
301
302OpFoldResult NegOp::fold(FoldAdaptor adaptor) {
303 // complex.neg(complex.neg(a)) -> a
304 if (auto negOp = getOperand().getDefiningOp<NegOp>())
305 return negOp.getOperand();
306
307 return {};
308}
309
310//===----------------------------------------------------------------------===//
311// LogOp
312//===----------------------------------------------------------------------===//
313
314OpFoldResult LogOp::fold(FoldAdaptor adaptor) {
315 // complex.log(complex.exp(a)) -> a
316 if (auto expOp = getOperand().getDefiningOp<ExpOp>())
317 return expOp.getOperand();
318
319 return {};
320}
321
322//===----------------------------------------------------------------------===//
323// ExpOp
324//===----------------------------------------------------------------------===//
325
326OpFoldResult ExpOp::fold(FoldAdaptor adaptor) {
327 // complex.exp(complex.log(a)) -> a
328 if (auto logOp = getOperand().getDefiningOp<LogOp>())
329 return logOp.getOperand();
330
331 return {};
332}
333
334//===----------------------------------------------------------------------===//
335// ConjOp
336//===----------------------------------------------------------------------===//
337
338OpFoldResult ConjOp::fold(FoldAdaptor adaptor) {
339 // complex.conj(complex.conj(a)) -> a
340 if (auto conjOp = getOperand().getDefiningOp<ConjOp>())
341 return conjOp.getOperand();
342
343 return {};
344}
345
346//===----------------------------------------------------------------------===//
347// MulOp
348//===----------------------------------------------------------------------===//
349
350OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
351 auto constant = getRhs().getDefiningOp<ConstantOp>();
352 if (!constant)
353 return {};
354
355 ArrayAttr arrayAttr = constant.getValue();
356 APFloat real = cast<FloatAttr>(arrayAttr[0]).getValue();
357 APFloat imag = cast<FloatAttr>(arrayAttr[1]).getValue();
358
359 if (!imag.isZero())
360 return {};
361
362 // complex.mul(a, complex.constant<1.0, 0.0>) -> a
363 if (real == APFloat(real.getSemantics(), 1))
364 return getLhs();
365
366 return {};
367}
368
369//===----------------------------------------------------------------------===//
370// DivOp
371//===----------------------------------------------------------------------===//
372
373OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
374 auto rhs = adaptor.getRhs();
375 auto lhs = adaptor.getLhs();
376
377 // We can't fold without knowing that LHS isn't NaN
378 if (!rhs || !lhs)
379 return {};
380
381 ArrayAttr rhsArrayAttr = dyn_cast<ArrayAttr>(rhs);
382 if (!rhsArrayAttr || rhsArrayAttr.size() != 2)
383 return {};
384
385 ArrayAttr lhsArrayAttr = dyn_cast<ArrayAttr>(lhs);
386 if (!lhsArrayAttr || lhsArrayAttr.size() != 2)
387 return {};
388
389 APFloat rhsImag = cast<FloatAttr>(rhsArrayAttr[1]).getValue();
390 if (!rhsImag.isZero())
391 return {};
392
393 APFloat lhsReal = cast<FloatAttr>(lhsArrayAttr[0]).getValue();
394 APFloat lhsImag = cast<FloatAttr>(lhsArrayAttr[1]).getValue();
395 if (lhsReal.isNaN() || lhsImag.isNaN()) {
396 Attribute nanValue = lhsReal.isNaN() ? lhsArrayAttr[0] : lhsArrayAttr[1];
397 return ArrayAttr::get(getContext(), {nanValue, nanValue});
398 }
399
400 // complex.div(a, complex.constant<1.0, 0.0>) -> a
401 APFloat rhsReal = cast<FloatAttr>(rhsArrayAttr[0]).getValue();
402 if (rhsReal == APFloat(rhsReal.getSemantics(), 1))
403 return getLhs();
404
405 return {};
406}
407
408//===----------------------------------------------------------------------===//
409// TableGen'd op method definitions
410//===----------------------------------------------------------------------===//
411
412#define GET_OP_CLASSES
413#include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
lhs
ArrayAttr()
b getContext())
#define add(a, b)
Attributes are known-constant values of operations.
Definition Attributes.h:25
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class represents a single result from folding an operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
LogicalResult matchAndRewrite(arith::BitcastOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(BitcastOp op, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})