MLIR  20.0.0git
ComplexOps.cpp
Go to the documentation of this file.
1 //===- ComplexOps.cpp - MLIR Complex Operations ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/BuiltinTypes.h"
13 #include "mlir/IR/Matchers.h"
14 #include "mlir/IR/PatternMatch.h"
15 
16 using namespace mlir;
17 using namespace mlir::complex;
18 
19 //===----------------------------------------------------------------------===//
20 // ConstantOp
21 //===----------------------------------------------------------------------===//
22 
23 OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
24  return getValue();
25 }
26 
27 void ConstantOp::getAsmResultNames(
28  function_ref<void(Value, StringRef)> setNameFn) {
29  setNameFn(getResult(), "cst");
30 }
31 
32 bool ConstantOp::isBuildableWith(Attribute value, Type type) {
33  if (auto arrAttr = llvm::dyn_cast<ArrayAttr>(value)) {
34  auto complexTy = llvm::dyn_cast<ComplexType>(type);
35  if (!complexTy || arrAttr.size() != 2)
36  return false;
37  auto complexEltTy = complexTy.getElementType();
38  if (auto fre = llvm::dyn_cast<FloatAttr>(arrAttr[0])) {
39  auto im = llvm::dyn_cast<FloatAttr>(arrAttr[1]);
40  return im && fre.getType() == complexEltTy &&
41  im.getType() == complexEltTy;
42  }
43  if (auto ire = llvm::dyn_cast<IntegerAttr>(arrAttr[0])) {
44  auto im = llvm::dyn_cast<IntegerAttr>(arrAttr[1]);
45  return im && ire.getType() == complexEltTy &&
46  im.getType() == complexEltTy;
47  }
48  }
49  return false;
50 }
51 
52 LogicalResult ConstantOp::verify() {
53  ArrayAttr arrayAttr = getValue();
54  if (arrayAttr.size() != 2) {
55  return emitOpError(
56  "requires 'value' to be a complex constant, represented as array of "
57  "two values");
58  }
59 
60  auto complexEltTy = getType().getElementType();
61  if (!isa<FloatAttr, IntegerAttr>(arrayAttr[0]) ||
62  !isa<FloatAttr, IntegerAttr>(arrayAttr[1]))
63  return emitOpError(
64  "requires attribute's elements to be float or integer attributes");
65  auto re = llvm::dyn_cast<TypedAttr>(arrayAttr[0]);
66  auto im = llvm::dyn_cast<TypedAttr>(arrayAttr[1]);
67  if (complexEltTy != re.getType() || complexEltTy != im.getType()) {
68  return emitOpError()
69  << "requires attribute's element types (" << re.getType() << ", "
70  << im.getType()
71  << ") to match the element type of the op's return type ("
72  << complexEltTy << ")";
73  }
74  return success();
75 }
76 
77 //===----------------------------------------------------------------------===//
78 // BitcastOp
79 //===----------------------------------------------------------------------===//
80 
81 OpFoldResult BitcastOp::fold(FoldAdaptor bitcast) {
82  if (getOperand().getType() == getType())
83  return getOperand();
84 
85  return {};
86 }
87 
88 LogicalResult BitcastOp::verify() {
89  auto operandType = getOperand().getType();
90  auto resultType = getType();
91 
92  // We allow this to be legal as it can be folded away.
93  if (operandType == resultType)
94  return success();
95 
96  if (!operandType.isIntOrFloat() && !isa<ComplexType>(operandType)) {
97  return emitOpError("operand must be int/float/complex");
98  }
99 
100  if (!resultType.isIntOrFloat() && !isa<ComplexType>(resultType)) {
101  return emitOpError("result must be int/float/complex");
102  }
103 
104  if (isa<ComplexType>(operandType) == isa<ComplexType>(resultType)) {
105  return emitOpError(
106  "requires that either input or output has a complex type");
107  }
108 
109  if (isa<ComplexType>(resultType))
110  std::swap(operandType, resultType);
111 
112  int32_t operandBitwidth = dyn_cast<ComplexType>(operandType)
113  .getElementType()
114  .getIntOrFloatBitWidth() *
115  2;
116  int32_t resultBitwidth = resultType.getIntOrFloatBitWidth();
117 
118  if (operandBitwidth != resultBitwidth) {
119  return emitOpError("casting bitwidths do not match");
120  }
121 
122  return success();
123 }
124 
125 struct MergeComplexBitcast final : OpRewritePattern<BitcastOp> {
127 
128  LogicalResult matchAndRewrite(BitcastOp op,
129  PatternRewriter &rewriter) const override {
130  if (auto defining = op.getOperand().getDefiningOp<BitcastOp>()) {
131  if (isa<ComplexType>(op.getType()) ||
132  isa<ComplexType>(defining.getOperand().getType())) {
133  // complex.bitcast requires that input or output is complex.
134  rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(),
135  defining.getOperand());
136  } else {
137  rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, op.getType(),
138  defining.getOperand());
139  }
140  return success();
141  }
142 
143  if (auto defining = op.getOperand().getDefiningOp<arith::BitcastOp>()) {
144  rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(),
145  defining.getOperand());
146  return success();
147  }
148 
149  return failure();
150  }
151 };
152 
153 struct MergeArithBitcast final : OpRewritePattern<arith::BitcastOp> {
155 
156  LogicalResult matchAndRewrite(arith::BitcastOp op,
157  PatternRewriter &rewriter) const override {
158  if (auto defining = op.getOperand().getDefiningOp<complex::BitcastOp>()) {
159  rewriter.replaceOpWithNewOp<complex::BitcastOp>(op, op.getType(),
160  defining.getOperand());
161  return success();
162  }
163 
164  return failure();
165  }
166 };
167 
168 void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
169  MLIRContext *context) {
170  results.add<MergeComplexBitcast, MergeArithBitcast>(context);
171 }
172 
173 //===----------------------------------------------------------------------===//
174 // CreateOp
175 //===----------------------------------------------------------------------===//
176 
177 OpFoldResult CreateOp::fold(FoldAdaptor adaptor) {
178  // Fold complex.create(complex.re(op), complex.im(op)).
179  if (auto reOp = getOperand(0).getDefiningOp<ReOp>()) {
180  if (auto imOp = getOperand(1).getDefiningOp<ImOp>()) {
181  if (reOp.getOperand() == imOp.getOperand()) {
182  return reOp.getOperand();
183  }
184  }
185  }
186  return {};
187 }
188 
189 //===----------------------------------------------------------------------===//
190 // ImOp
191 //===----------------------------------------------------------------------===//
192 
193 OpFoldResult ImOp::fold(FoldAdaptor adaptor) {
194  ArrayAttr arrayAttr =
195  llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex());
196  if (arrayAttr && arrayAttr.size() == 2)
197  return arrayAttr[1];
198  if (auto createOp = getOperand().getDefiningOp<CreateOp>())
199  return createOp.getOperand(1);
200  return {};
201 }
202 
203 namespace {
204 template <typename OpKind, int ComponentIndex>
205 struct FoldComponentNeg final : OpRewritePattern<OpKind> {
207 
208  LogicalResult matchAndRewrite(OpKind op,
209  PatternRewriter &rewriter) const override {
210  auto negOp = op.getOperand().template getDefiningOp<NegOp>();
211  if (!negOp)
212  return failure();
213 
214  auto createOp = negOp.getComplex().template getDefiningOp<CreateOp>();
215  if (!createOp)
216  return failure();
217 
218  Type elementType = createOp.getType().getElementType();
219  assert(isa<FloatType>(elementType));
220 
221  rewriter.replaceOpWithNewOp<arith::NegFOp>(
222  op, elementType, createOp.getOperand(ComponentIndex));
223  return success();
224  }
225 };
226 } // namespace
227 
228 void ImOp::getCanonicalizationPatterns(RewritePatternSet &results,
229  MLIRContext *context) {
230  results.add<FoldComponentNeg<ImOp, 1>>(context);
231 }
232 
233 //===----------------------------------------------------------------------===//
234 // ReOp
235 //===----------------------------------------------------------------------===//
236 
237 OpFoldResult ReOp::fold(FoldAdaptor adaptor) {
238  ArrayAttr arrayAttr =
239  llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex());
240  if (arrayAttr && arrayAttr.size() == 2)
241  return arrayAttr[0];
242  if (auto createOp = getOperand().getDefiningOp<CreateOp>())
243  return createOp.getOperand(0);
244  return {};
245 }
246 
247 void ReOp::getCanonicalizationPatterns(RewritePatternSet &results,
248  MLIRContext *context) {
249  results.add<FoldComponentNeg<ReOp, 0>>(context);
250 }
251 
252 //===----------------------------------------------------------------------===//
253 // AddOp
254 //===----------------------------------------------------------------------===//
255 
256 OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
257  // complex.add(complex.sub(a, b), b) -> a
258  if (auto sub = getLhs().getDefiningOp<SubOp>())
259  if (getRhs() == sub.getRhs())
260  return sub.getLhs();
261 
262  // complex.add(b, complex.sub(a, b)) -> a
263  if (auto sub = getRhs().getDefiningOp<SubOp>())
264  if (getLhs() == sub.getRhs())
265  return sub.getLhs();
266 
267  // complex.add(a, complex.constant<0.0, 0.0>) -> a
268  if (auto constantOp = getRhs().getDefiningOp<ConstantOp>()) {
269  auto arrayAttr = constantOp.getValue();
270  if (llvm::cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
271  llvm::cast<FloatAttr>(arrayAttr[1]).getValue().isZero()) {
272  return getLhs();
273  }
274  }
275 
276  return {};
277 }
278 
279 //===----------------------------------------------------------------------===//
280 // SubOp
281 //===----------------------------------------------------------------------===//
282 
283 OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
284  // complex.sub(complex.add(a, b), b) -> a
285  if (auto add = getLhs().getDefiningOp<AddOp>())
286  if (getRhs() == add.getRhs())
287  return add.getLhs();
288 
289  // complex.sub(a, complex.constant<0.0, 0.0>) -> a
290  if (auto constantOp = getRhs().getDefiningOp<ConstantOp>()) {
291  auto arrayAttr = constantOp.getValue();
292  if (llvm::cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
293  llvm::cast<FloatAttr>(arrayAttr[1]).getValue().isZero()) {
294  return getLhs();
295  }
296  }
297 
298  return {};
299 }
300 
301 //===----------------------------------------------------------------------===//
302 // NegOp
303 //===----------------------------------------------------------------------===//
304 
305 OpFoldResult NegOp::fold(FoldAdaptor adaptor) {
306  // complex.neg(complex.neg(a)) -> a
307  if (auto negOp = getOperand().getDefiningOp<NegOp>())
308  return negOp.getOperand();
309 
310  return {};
311 }
312 
313 //===----------------------------------------------------------------------===//
314 // LogOp
315 //===----------------------------------------------------------------------===//
316 
317 OpFoldResult LogOp::fold(FoldAdaptor adaptor) {
318  // complex.log(complex.exp(a)) -> a
319  if (auto expOp = getOperand().getDefiningOp<ExpOp>())
320  return expOp.getOperand();
321 
322  return {};
323 }
324 
325 //===----------------------------------------------------------------------===//
326 // ExpOp
327 //===----------------------------------------------------------------------===//
328 
329 OpFoldResult ExpOp::fold(FoldAdaptor adaptor) {
330  // complex.exp(complex.log(a)) -> a
331  if (auto logOp = getOperand().getDefiningOp<LogOp>())
332  return logOp.getOperand();
333 
334  return {};
335 }
336 
337 //===----------------------------------------------------------------------===//
338 // ConjOp
339 //===----------------------------------------------------------------------===//
340 
341 OpFoldResult ConjOp::fold(FoldAdaptor adaptor) {
342  // complex.conj(complex.conj(a)) -> a
343  if (auto conjOp = getOperand().getDefiningOp<ConjOp>())
344  return conjOp.getOperand();
345 
346  return {};
347 }
348 
349 //===----------------------------------------------------------------------===//
350 // MulOp
351 //===----------------------------------------------------------------------===//
352 
353 OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
354  auto constant = getRhs().getDefiningOp<ConstantOp>();
355  if (!constant)
356  return {};
357 
358  ArrayAttr arrayAttr = constant.getValue();
359  APFloat real = cast<FloatAttr>(arrayAttr[0]).getValue();
360  APFloat imag = cast<FloatAttr>(arrayAttr[1]).getValue();
361 
362  if (!imag.isZero())
363  return {};
364 
365  // complex.mul(a, complex.constant<1.0, 0.0>) -> a
366  if (real == APFloat(real.getSemantics(), 1))
367  return getLhs();
368 
369  return {};
370 }
371 
372 //===----------------------------------------------------------------------===//
373 // DivOp
374 //===----------------------------------------------------------------------===//
375 
376 OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
377  auto rhs = adaptor.getRhs();
378  if (!rhs)
379  return {};
380 
381  ArrayAttr arrayAttr = dyn_cast<ArrayAttr>(rhs);
382  if (!arrayAttr || arrayAttr.size() != 2)
383  return {};
384 
385  APFloat real = cast<FloatAttr>(arrayAttr[0]).getValue();
386  APFloat imag = cast<FloatAttr>(arrayAttr[1]).getValue();
387 
388  if (!imag.isZero())
389  return {};
390 
391  // complex.div(a, complex.constant<1.0, 0.0>) -> a
392  if (real == APFloat(real.getSemantics(), 1))
393  return getLhs();
394 
395  return {};
396 }
397 
398 //===----------------------------------------------------------------------===//
399 // TableGen'd op method definitions
400 //===----------------------------------------------------------------------===//
401 
402 #define GET_OP_CLASSES
403 #include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc"
Attributes are known-constant values of operations.
Definition: Attributes.h:25
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:791
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:426
LogicalResult matchAndRewrite(arith::BitcastOp op, PatternRewriter &rewriter) const override
Definition: ComplexOps.cpp:156
LogicalResult matchAndRewrite(BitcastOp op, PatternRewriter &rewriter) const override
Definition: ComplexOps.cpp:128
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:358