MLIR  22.0.0git
ComplexOps.cpp
Go to the documentation of this file.
1 //===- ComplexOps.cpp - MLIR Complex Operations ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/BuiltinTypes.h"
13 #include "mlir/IR/PatternMatch.h"
14 
15 using namespace mlir;
16 using namespace mlir::complex;
17 
18 //===----------------------------------------------------------------------===//
19 // ConstantOp
20 //===----------------------------------------------------------------------===//
21 
22 OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
23  return getValue();
24 }
25 
26 void ConstantOp::getAsmResultNames(
27  function_ref<void(Value, StringRef)> setNameFn) {
28  setNameFn(getResult(), "cst");
29 }
30 
31 bool ConstantOp::isBuildableWith(Attribute value, Type type) {
32  if (auto arrAttr = llvm::dyn_cast<ArrayAttr>(value)) {
33  auto complexTy = llvm::dyn_cast<ComplexType>(type);
34  if (!complexTy || arrAttr.size() != 2)
35  return false;
36  auto complexEltTy = complexTy.getElementType();
37  if (auto fre = llvm::dyn_cast<FloatAttr>(arrAttr[0])) {
38  auto im = llvm::dyn_cast<FloatAttr>(arrAttr[1]);
39  return im && fre.getType() == complexEltTy &&
40  im.getType() == complexEltTy;
41  }
42  if (auto ire = llvm::dyn_cast<IntegerAttr>(arrAttr[0])) {
43  auto im = llvm::dyn_cast<IntegerAttr>(arrAttr[1]);
44  return im && ire.getType() == complexEltTy &&
45  im.getType() == complexEltTy;
46  }
47  }
48  return false;
49 }
50 
51 LogicalResult ConstantOp::verify() {
52  ArrayAttr arrayAttr = getValue();
53  if (arrayAttr.size() != 2) {
54  return emitOpError(
55  "requires 'value' to be a complex constant, represented as array of "
56  "two values");
57  }
58 
59  auto complexEltTy = getType().getElementType();
60  if (!isa<FloatAttr, IntegerAttr>(arrayAttr[0]) ||
61  !isa<FloatAttr, IntegerAttr>(arrayAttr[1]))
62  return emitOpError(
63  "requires attribute's elements to be float or integer attributes");
64  auto re = llvm::dyn_cast<TypedAttr>(arrayAttr[0]);
65  auto im = llvm::dyn_cast<TypedAttr>(arrayAttr[1]);
66  if (complexEltTy != re.getType() || complexEltTy != im.getType()) {
67  return emitOpError()
68  << "requires attribute's element types (" << re.getType() << ", "
69  << im.getType()
70  << ") to match the element type of the op's return type ("
71  << complexEltTy << ")";
72  }
73  return success();
74 }
75 
76 //===----------------------------------------------------------------------===//
77 // BitcastOp
78 //===----------------------------------------------------------------------===//
79 
80 OpFoldResult BitcastOp::fold(FoldAdaptor bitcast) {
81  if (getOperand().getType() == getType())
82  return getOperand();
83 
84  return {};
85 }
86 
87 LogicalResult BitcastOp::verify() {
88  auto operandType = getOperand().getType();
89  auto resultType = getType();
90 
91  // We allow this to be legal as it can be folded away.
92  if (operandType == resultType)
93  return success();
94 
95  if (!operandType.isIntOrFloat() && !isa<ComplexType>(operandType)) {
96  return emitOpError("operand must be int/float/complex");
97  }
98 
99  if (!resultType.isIntOrFloat() && !isa<ComplexType>(resultType)) {
100  return emitOpError("result must be int/float/complex");
101  }
102 
103  if (isa<ComplexType>(operandType) == isa<ComplexType>(resultType)) {
104  return emitOpError(
105  "requires that either input or output has a complex type");
106  }
107 
108  if (isa<ComplexType>(resultType))
109  std::swap(operandType, resultType);
110 
111  int32_t operandBitwidth = dyn_cast<ComplexType>(operandType)
112  .getElementType()
113  .getIntOrFloatBitWidth() *
114  2;
115  int32_t resultBitwidth = resultType.getIntOrFloatBitWidth();
116 
117  if (operandBitwidth != resultBitwidth) {
118  return emitOpError("casting bitwidths do not match");
119  }
120 
121  return success();
122 }
123 
124 struct MergeComplexBitcast final : OpRewritePattern<BitcastOp> {
126 
127  LogicalResult matchAndRewrite(BitcastOp op,
128  PatternRewriter &rewriter) const override {
129  if (auto defining = op.getOperand().getDefiningOp<BitcastOp>()) {
130  if (isa<ComplexType>(op.getType()) ||
131  isa<ComplexType>(defining.getOperand().getType())) {
132  // complex.bitcast requires that input or output is complex.
133  rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(),
134  defining.getOperand());
135  } else {
136  rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, op.getType(),
137  defining.getOperand());
138  }
139  return success();
140  }
141 
142  if (auto defining = op.getOperand().getDefiningOp<arith::BitcastOp>()) {
143  rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(),
144  defining.getOperand());
145  return success();
146  }
147 
148  return failure();
149  }
150 };
151 
152 struct MergeArithBitcast final : OpRewritePattern<arith::BitcastOp> {
154 
155  LogicalResult matchAndRewrite(arith::BitcastOp op,
156  PatternRewriter &rewriter) const override {
157  if (auto defining = op.getOperand().getDefiningOp<complex::BitcastOp>()) {
158  rewriter.replaceOpWithNewOp<complex::BitcastOp>(op, op.getType(),
159  defining.getOperand());
160  return success();
161  }
162 
163  return failure();
164  }
165 };
166 
167 void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
168  MLIRContext *context) {
169  results.add<MergeComplexBitcast, MergeArithBitcast>(context);
170 }
171 
172 //===----------------------------------------------------------------------===//
173 // CreateOp
174 //===----------------------------------------------------------------------===//
175 
176 OpFoldResult CreateOp::fold(FoldAdaptor adaptor) {
177  // Fold complex.create(complex.re(op), complex.im(op)).
178  if (auto reOp = getOperand(0).getDefiningOp<ReOp>()) {
179  if (auto imOp = getOperand(1).getDefiningOp<ImOp>()) {
180  if (reOp.getOperand() == imOp.getOperand()) {
181  return reOp.getOperand();
182  }
183  }
184  }
185  return {};
186 }
187 
188 //===----------------------------------------------------------------------===//
189 // ImOp
190 //===----------------------------------------------------------------------===//
191 
192 OpFoldResult ImOp::fold(FoldAdaptor adaptor) {
193  ArrayAttr arrayAttr =
194  llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex());
195  if (arrayAttr && arrayAttr.size() == 2)
196  return arrayAttr[1];
197  if (auto createOp = getOperand().getDefiningOp<CreateOp>())
198  return createOp.getOperand(1);
199  return {};
200 }
201 
202 namespace {
203 template <typename OpKind, int ComponentIndex>
204 struct FoldComponentNeg final : OpRewritePattern<OpKind> {
206 
207  LogicalResult matchAndRewrite(OpKind op,
208  PatternRewriter &rewriter) const override {
209  auto negOp = op.getOperand().template getDefiningOp<NegOp>();
210  if (!negOp)
211  return failure();
212 
213  auto createOp = negOp.getComplex().template getDefiningOp<CreateOp>();
214  if (!createOp)
215  return failure();
216 
217  Type elementType = createOp.getType().getElementType();
218  assert(isa<FloatType>(elementType));
219 
220  rewriter.replaceOpWithNewOp<arith::NegFOp>(
221  op, elementType, createOp.getOperand(ComponentIndex));
222  return success();
223  }
224 };
225 } // namespace
226 
227 void ImOp::getCanonicalizationPatterns(RewritePatternSet &results,
228  MLIRContext *context) {
229  results.add<FoldComponentNeg<ImOp, 1>>(context);
230 }
231 
232 //===----------------------------------------------------------------------===//
233 // ReOp
234 //===----------------------------------------------------------------------===//
235 
236 OpFoldResult ReOp::fold(FoldAdaptor adaptor) {
237  ArrayAttr arrayAttr =
238  llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex());
239  if (arrayAttr && arrayAttr.size() == 2)
240  return arrayAttr[0];
241  if (auto createOp = getOperand().getDefiningOp<CreateOp>())
242  return createOp.getOperand(0);
243  return {};
244 }
245 
246 void ReOp::getCanonicalizationPatterns(RewritePatternSet &results,
247  MLIRContext *context) {
248  results.add<FoldComponentNeg<ReOp, 0>>(context);
249 }
250 
251 //===----------------------------------------------------------------------===//
252 // AddOp
253 //===----------------------------------------------------------------------===//
254 
255 OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
256  // complex.add(complex.sub(a, b), b) -> a
257  if (auto sub = getLhs().getDefiningOp<SubOp>())
258  if (getRhs() == sub.getRhs())
259  return sub.getLhs();
260 
261  // complex.add(b, complex.sub(a, b)) -> a
262  if (auto sub = getRhs().getDefiningOp<SubOp>())
263  if (getLhs() == sub.getRhs())
264  return sub.getLhs();
265 
266  // complex.add(a, complex.constant<0.0, 0.0>) -> a
267  if (auto constantOp = getRhs().getDefiningOp<ConstantOp>()) {
268  auto arrayAttr = constantOp.getValue();
269  if (llvm::cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
270  llvm::cast<FloatAttr>(arrayAttr[1]).getValue().isZero()) {
271  return getLhs();
272  }
273  }
274 
275  return {};
276 }
277 
278 //===----------------------------------------------------------------------===//
279 // SubOp
280 //===----------------------------------------------------------------------===//
281 
282 OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
283  // complex.sub(complex.add(a, b), b) -> a
284  if (auto add = getLhs().getDefiningOp<AddOp>())
285  if (getRhs() == add.getRhs())
286  return add.getLhs();
287 
288  // complex.sub(a, complex.constant<0.0, 0.0>) -> a
289  if (auto constantOp = getRhs().getDefiningOp<ConstantOp>()) {
290  auto arrayAttr = constantOp.getValue();
291  if (llvm::cast<FloatAttr>(arrayAttr[0]).getValue().isZero() &&
292  llvm::cast<FloatAttr>(arrayAttr[1]).getValue().isZero()) {
293  return getLhs();
294  }
295  }
296 
297  return {};
298 }
299 
300 //===----------------------------------------------------------------------===//
301 // NegOp
302 //===----------------------------------------------------------------------===//
303 
304 OpFoldResult NegOp::fold(FoldAdaptor adaptor) {
305  // complex.neg(complex.neg(a)) -> a
306  if (auto negOp = getOperand().getDefiningOp<NegOp>())
307  return negOp.getOperand();
308 
309  return {};
310 }
311 
312 //===----------------------------------------------------------------------===//
313 // LogOp
314 //===----------------------------------------------------------------------===//
315 
316 OpFoldResult LogOp::fold(FoldAdaptor adaptor) {
317  // complex.log(complex.exp(a)) -> a
318  if (auto expOp = getOperand().getDefiningOp<ExpOp>())
319  return expOp.getOperand();
320 
321  return {};
322 }
323 
324 //===----------------------------------------------------------------------===//
325 // ExpOp
326 //===----------------------------------------------------------------------===//
327 
328 OpFoldResult ExpOp::fold(FoldAdaptor adaptor) {
329  // complex.exp(complex.log(a)) -> a
330  if (auto logOp = getOperand().getDefiningOp<LogOp>())
331  return logOp.getOperand();
332 
333  return {};
334 }
335 
336 //===----------------------------------------------------------------------===//
337 // ConjOp
338 //===----------------------------------------------------------------------===//
339 
340 OpFoldResult ConjOp::fold(FoldAdaptor adaptor) {
341  // complex.conj(complex.conj(a)) -> a
342  if (auto conjOp = getOperand().getDefiningOp<ConjOp>())
343  return conjOp.getOperand();
344 
345  return {};
346 }
347 
348 //===----------------------------------------------------------------------===//
349 // MulOp
350 //===----------------------------------------------------------------------===//
351 
352 OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
353  auto constant = getRhs().getDefiningOp<ConstantOp>();
354  if (!constant)
355  return {};
356 
357  ArrayAttr arrayAttr = constant.getValue();
358  APFloat real = cast<FloatAttr>(arrayAttr[0]).getValue();
359  APFloat imag = cast<FloatAttr>(arrayAttr[1]).getValue();
360 
361  if (!imag.isZero())
362  return {};
363 
364  // complex.mul(a, complex.constant<1.0, 0.0>) -> a
365  if (real == APFloat(real.getSemantics(), 1))
366  return getLhs();
367 
368  return {};
369 }
370 
371 //===----------------------------------------------------------------------===//
372 // DivOp
373 //===----------------------------------------------------------------------===//
374 
375 OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
376  auto rhs = adaptor.getRhs();
377  if (!rhs)
378  return {};
379 
380  ArrayAttr arrayAttr = dyn_cast<ArrayAttr>(rhs);
381  if (!arrayAttr || arrayAttr.size() != 2)
382  return {};
383 
384  APFloat real = cast<FloatAttr>(arrayAttr[0]).getValue();
385  APFloat imag = cast<FloatAttr>(arrayAttr[1]).getValue();
386 
387  if (!imag.isZero())
388  return {};
389 
390  // complex.div(a, complex.constant<1.0, 0.0>) -> a
391  if (real == APFloat(real.getSemantics(), 1))
392  return getLhs();
393 
394  return {};
395 }
396 
397 //===----------------------------------------------------------------------===//
398 // TableGen'd op method definitions
399 //===----------------------------------------------------------------------===//
400 
401 #define GET_OP_CLASSES
402 #include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc"
Attributes are known-constant values of operations.
Definition: Attributes.h:25
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:845
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
detail::LazyTextBuild add(const char *fmt, Ts &&...ts)
Create a Remark with llvm::formatv formatting.
Definition: Remarks.h:463
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
LogicalResult matchAndRewrite(arith::BitcastOp op, PatternRewriter &rewriter) const override
Definition: ComplexOps.cpp:155
LogicalResult matchAndRewrite(BitcastOp op, PatternRewriter &rewriter) const override
Definition: ComplexOps.cpp:127
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314