MLIR 23.0.0git
MathOps.cpp
Go to the documentation of this file.
1//===- MathOps.cpp - MLIR operations for math implementation --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/Builders.h"
#include <cmath>
#include <optional>
15
16using namespace mlir;
17using namespace mlir::math;
18
19//===----------------------------------------------------------------------===//
20// Common helpers
21//===----------------------------------------------------------------------===//
22
23/// Return the type of the same shape (scalar, vector or tensor) containing i1.
24static Type getI1SameShape(Type type) {
25 auto i1Type = IntegerType::get(type.getContext(), 1);
26 if (auto shapedType = llvm::dyn_cast<ShapedType>(type))
27 return shapedType.cloneWith(std::nullopt, i1Type);
28 if (llvm::isa<UnrankedTensorType>(type))
29 return UnrankedTensorType::get(i1Type);
30 return i1Type;
31}
32
33//===----------------------------------------------------------------------===//
34// TableGen'd op method definitions
35//===----------------------------------------------------------------------===//
36
37#define GET_OP_CLASSES
38#include "mlir/Dialect/Math/IR/MathOps.cpp.inc"
39
40//===----------------------------------------------------------------------===//
41// AbsFOp folder
42//===----------------------------------------------------------------------===//
43
44OpFoldResult math::AbsFOp::fold(FoldAdaptor adaptor) {
45 return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(),
46 [](const APFloat &a) { return abs(a); });
47}
48
49//===----------------------------------------------------------------------===//
50// AbsIOp folder
51//===----------------------------------------------------------------------===//
52
53OpFoldResult math::AbsIOp::fold(FoldAdaptor adaptor) {
54 return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(),
55 [](const APInt &a) { return a.abs(); });
56}
57
58//===----------------------------------------------------------------------===//
59// AcosOp folder
60//===----------------------------------------------------------------------===//
61
62OpFoldResult math::AcosOp::fold(FoldAdaptor adaptor) {
64 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
65 switch (a.getSizeInBits(a.getSemantics())) {
66 case 64:
67 return APFloat(acos(a.convertToDouble()));
68 case 32:
69 return APFloat(acosf(a.convertToFloat()));
70 default:
71 return {};
72 }
73 });
74}
75
76//===----------------------------------------------------------------------===//
77// AcoshOp folder
78//===----------------------------------------------------------------------===//
79
80OpFoldResult math::AcoshOp::fold(FoldAdaptor adaptor) {
82 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
83 switch (a.getSizeInBits(a.getSemantics())) {
84 case 64:
85 return APFloat(acosh(a.convertToDouble()));
86 case 32:
87 return APFloat(acoshf(a.convertToFloat()));
88 default:
89 return {};
90 }
91 });
92}
93
94//===----------------------------------------------------------------------===//
95// AsinOp folder
96//===----------------------------------------------------------------------===//
97
98OpFoldResult math::AsinOp::fold(FoldAdaptor adaptor) {
100 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
101 switch (a.getSizeInBits(a.getSemantics())) {
102 case 64:
103 return APFloat(asin(a.convertToDouble()));
104 case 32:
105 return APFloat(asinf(a.convertToFloat()));
106 default:
107 return {};
108 }
109 });
110}
111
112//===----------------------------------------------------------------------===//
113// AsinhOp folder
114//===----------------------------------------------------------------------===//
115
116OpFoldResult math::AsinhOp::fold(FoldAdaptor adaptor) {
118 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
119 switch (a.getSizeInBits(a.getSemantics())) {
120 case 64:
121 return APFloat(asinh(a.convertToDouble()));
122 case 32:
123 return APFloat(asinhf(a.convertToFloat()));
124 default:
125 return {};
126 }
127 });
128}
129
130//===----------------------------------------------------------------------===//
131// AtanOp folder
132//===----------------------------------------------------------------------===//
133
134OpFoldResult math::AtanOp::fold(FoldAdaptor adaptor) {
136 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
137 switch (a.getSizeInBits(a.getSemantics())) {
138 case 64:
139 return APFloat(atan(a.convertToDouble()));
140 case 32:
141 return APFloat(atanf(a.convertToFloat()));
142 default:
143 return {};
144 }
145 });
146}
147
148//===----------------------------------------------------------------------===//
149// AtanhOp folder
150//===----------------------------------------------------------------------===//
151
152OpFoldResult math::AtanhOp::fold(FoldAdaptor adaptor) {
154 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
155 switch (a.getSizeInBits(a.getSemantics())) {
156 case 64:
157 return APFloat(atanh(a.convertToDouble()));
158 case 32:
159 return APFloat(atanhf(a.convertToFloat()));
160 default:
161 return {};
162 }
163 });
164}
165
166//===----------------------------------------------------------------------===//
167// Atan2Op folder
168//===----------------------------------------------------------------------===//
169
170OpFoldResult math::Atan2Op::fold(FoldAdaptor adaptor) {
172 adaptor.getOperands(),
173 [](const APFloat &a, const APFloat &b) -> std::optional<APFloat> {
174 if (a.isZero() && b.isZero())
175 return llvm::APFloat::getNaN(a.getSemantics());
176
177 if (a.getSizeInBits(a.getSemantics()) == 64 &&
178 b.getSizeInBits(b.getSemantics()) == 64)
179 return APFloat(atan2(a.convertToDouble(), b.convertToDouble()));
180
181 if (a.getSizeInBits(a.getSemantics()) == 32 &&
182 b.getSizeInBits(b.getSemantics()) == 32)
183 return APFloat(atan2f(a.convertToFloat(), b.convertToFloat()));
184
185 return {};
186 });
187}
188
189//===----------------------------------------------------------------------===//
190// CeilOp folder
191//===----------------------------------------------------------------------===//
192
193OpFoldResult math::CeilOp::fold(FoldAdaptor adaptor) {
195 adaptor.getOperands(), [](const APFloat &a) {
196 APFloat result(a);
197 result.roundToIntegral(llvm::RoundingMode::TowardPositive);
198 return result;
199 });
200}
201
202//===----------------------------------------------------------------------===//
203// CopySignOp folder
204//===----------------------------------------------------------------------===//
205
206OpFoldResult math::CopySignOp::fold(FoldAdaptor adaptor) {
207 return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
208 [](const APFloat &a, const APFloat &b) {
209 APFloat result(a);
210 result.copySign(b);
211 return result;
212 });
213}
214
215//===----------------------------------------------------------------------===//
216// CosOp folder
217//===----------------------------------------------------------------------===//
218
219OpFoldResult math::CosOp::fold(FoldAdaptor adaptor) {
221 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
222 switch (a.getSizeInBits(a.getSemantics())) {
223 case 64:
224 return APFloat(cos(a.convertToDouble()));
225 case 32:
226 return APFloat(cosf(a.convertToFloat()));
227 default:
228 return {};
229 }
230 });
231}
232
233//===----------------------------------------------------------------------===//
234// CoshOp folder
235//===----------------------------------------------------------------------===//
236
237OpFoldResult math::CoshOp::fold(FoldAdaptor adaptor) {
239 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
240 switch (a.getSizeInBits(a.getSemantics())) {
241 case 64:
242 return APFloat(cosh(a.convertToDouble()));
243 case 32:
244 return APFloat(coshf(a.convertToFloat()));
245 default:
246 return {};
247 }
248 });
249}
250
251//===----------------------------------------------------------------------===//
252// SinOp folder
253//===----------------------------------------------------------------------===//
254
255OpFoldResult math::SinOp::fold(FoldAdaptor adaptor) {
257 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
258 switch (a.getSizeInBits(a.getSemantics())) {
259 case 64:
260 return APFloat(sin(a.convertToDouble()));
261 case 32:
262 return APFloat(sinf(a.convertToFloat()));
263 default:
264 return {};
265 }
266 });
267}
268
269//===----------------------------------------------------------------------===//
270// SinhOp folder
271//===----------------------------------------------------------------------===//
272
273OpFoldResult math::SinhOp::fold(FoldAdaptor adaptor) {
275 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
276 switch (a.getSizeInBits(a.getSemantics())) {
277 case 64:
278 return APFloat(sinh(a.convertToDouble()));
279 case 32:
280 return APFloat(sinhf(a.convertToFloat()));
281 default:
282 return {};
283 }
284 });
285}
286
287//===----------------------------------------------------------------------===//
288// SinCosOp getShapeForUnroll
289//===----------------------------------------------------------------------===//
290
291std::optional<SmallVector<int64_t, 4>> math::SincosOp::getShapeForUnroll() {
292 if (auto vt = mlir::dyn_cast<VectorType>(getOperand().getType()))
293 return llvm::to_vector<4>(vt.getShape());
294 return std::nullopt;
295}
296
297//===----------------------------------------------------------------------===//
298// CountLeadingZerosOp folder
299//===----------------------------------------------------------------------===//
300
301OpFoldResult math::CountLeadingZerosOp::fold(FoldAdaptor adaptor) {
303 adaptor.getOperands(),
304 [](const APInt &a) { return APInt(a.getBitWidth(), a.countl_zero()); });
305}
306
307//===----------------------------------------------------------------------===//
308// CountTrailingZerosOp folder
309//===----------------------------------------------------------------------===//
310
311OpFoldResult math::CountTrailingZerosOp::fold(FoldAdaptor adaptor) {
313 adaptor.getOperands(),
314 [](const APInt &a) { return APInt(a.getBitWidth(), a.countr_zero()); });
315}
316
317//===----------------------------------------------------------------------===//
318// CtPopOp folder
319//===----------------------------------------------------------------------===//
320
321OpFoldResult math::CtPopOp::fold(FoldAdaptor adaptor) {
323 adaptor.getOperands(),
324 [](const APInt &a) { return APInt(a.getBitWidth(), a.popcount()); });
325}
326
327//===----------------------------------------------------------------------===//
328// ErfOp folder
329//===----------------------------------------------------------------------===//
330
331OpFoldResult math::ErfOp::fold(FoldAdaptor adaptor) {
333 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
334 switch (a.getSizeInBits(a.getSemantics())) {
335 case 64:
336 return APFloat(erf(a.convertToDouble()));
337 case 32:
338 return APFloat(erff(a.convertToFloat()));
339 default:
340 return {};
341 }
342 });
343}
344
345//===----------------------------------------------------------------------===//
346// ErfcOp folder
347//===----------------------------------------------------------------------===//
348
349OpFoldResult math::ErfcOp::fold(FoldAdaptor adaptor) {
351 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
352 switch (APFloat::SemanticsToEnum(a.getSemantics())) {
353 case APFloat::Semantics::S_IEEEdouble:
354 return APFloat(erfc(a.convertToDouble()));
355 case APFloat::Semantics::S_IEEEsingle:
356 return APFloat(erfcf(a.convertToFloat()));
357 default:
358 return {};
359 }
360 });
361}
362
363//===----------------------------------------------------------------------===//
364// IPowIOp folder
365//===----------------------------------------------------------------------===//
366
367OpFoldResult math::IPowIOp::fold(FoldAdaptor adaptor) {
369 adaptor.getOperands(),
370 [](const APInt &base, const APInt &power) -> std::optional<APInt> {
371 unsigned width = base.getBitWidth();
372 auto zeroValue = APInt::getZero(width);
373 // i1 folding is ambiguous with signed semantics, don't fold.
374 if (width == 1)
375 return {};
376 APInt oneValue{width, 1ULL, /*isSigned=*/true};
377 APInt minusOneValue{width, -1ULL, /*isSigned=*/true};
378
379 if (power.isZero())
380 return oneValue;
381
382 if (power.isNegative()) {
383 // Leave 0 raised to negative power not folded.
384 if (base.isZero())
385 return {};
386 if (base.isOne())
387 return oneValue;
388 // If abs(base) > 1, then the result is zero.
389 if (base.ne(minusOneValue))
390 return zeroValue;
391 // base == -1:
392 // -1: power is odd
393 // 1: power is even
394 if (power[0] == 1)
395 return minusOneValue;
396
397 return oneValue;
398 }
399
400 // power is positive.
401 APInt result = oneValue;
402 APInt curBase = base;
403 APInt curPower = power;
404 while (true) {
405 if (curPower[0] == 1)
406 result *= curBase;
407 curPower.lshrInPlace(1);
408 if (curPower.isZero())
409 return result;
410 curBase *= curBase;
411 }
412 });
413
414 return Attribute();
415}
416
417//===----------------------------------------------------------------------===//
418// LogOp folder
419//===----------------------------------------------------------------------===//
420
421OpFoldResult math::LogOp::fold(FoldAdaptor adaptor) {
423 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
424 if (a.isNegative())
425 return {};
426
427 if (a.getSizeInBits(a.getSemantics()) == 64)
428 return APFloat(log(a.convertToDouble()));
429
430 if (a.getSizeInBits(a.getSemantics()) == 32)
431 return APFloat(logf(a.convertToFloat()));
432
433 return {};
434 });
435}
436
437//===----------------------------------------------------------------------===//
438// Log2Op folder
439//===----------------------------------------------------------------------===//
440
441OpFoldResult math::Log2Op::fold(FoldAdaptor adaptor) {
443 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
444 if (a.isNegative())
445 return {};
446
447 if (a.getSizeInBits(a.getSemantics()) == 64)
448 return APFloat(log2(a.convertToDouble()));
449
450 if (a.getSizeInBits(a.getSemantics()) == 32)
451 return APFloat(log2f(a.convertToFloat()));
452
453 return {};
454 });
455}
456
457//===----------------------------------------------------------------------===//
458// Log10Op folder
459//===----------------------------------------------------------------------===//
460
461OpFoldResult math::Log10Op::fold(FoldAdaptor adaptor) {
463 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
464 if (a.isNegative())
465 return {};
466
467 switch (a.getSizeInBits(a.getSemantics())) {
468 case 64:
469 return APFloat(log10(a.convertToDouble()));
470 case 32:
471 return APFloat(log10f(a.convertToFloat()));
472 default:
473 return {};
474 }
475 });
476}
477
478//===----------------------------------------------------------------------===//
479// Log1pOp folder
480//===----------------------------------------------------------------------===//
481
482OpFoldResult math::Log1pOp::fold(FoldAdaptor adaptor) {
484 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
485 switch (a.getSizeInBits(a.getSemantics())) {
486 case 64:
487 if ((a + APFloat(1.0)).isNegative())
488 return {};
489 return APFloat(log1p(a.convertToDouble()));
490 case 32:
491 if ((a + APFloat(1.0f)).isNegative())
492 return {};
493 return APFloat(log1pf(a.convertToFloat()));
494 default:
495 return {};
496 }
497 });
498}
499
500//===----------------------------------------------------------------------===//
501// PowFOp folder
502//===----------------------------------------------------------------------===//
503
504OpFoldResult math::PowFOp::fold(FoldAdaptor adaptor) {
506 adaptor.getOperands(),
507 [](const APFloat &a, const APFloat &b) -> std::optional<APFloat> {
508 if (a.getSizeInBits(a.getSemantics()) == 64 &&
509 b.getSizeInBits(b.getSemantics()) == 64)
510 return APFloat(pow(a.convertToDouble(), b.convertToDouble()));
511
512 if (a.getSizeInBits(a.getSemantics()) == 32 &&
513 b.getSizeInBits(b.getSemantics()) == 32)
514 return APFloat(powf(a.convertToFloat(), b.convertToFloat()));
515
516 return {};
517 });
518}
519
520//===----------------------------------------------------------------------===//
521// SqrtOp folder
522//===----------------------------------------------------------------------===//
523
524OpFoldResult math::SqrtOp::fold(FoldAdaptor adaptor) {
526 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
527 if (a.isNegative())
528 return {};
529
530 switch (a.getSizeInBits(a.getSemantics())) {
531 case 64:
532 return APFloat(sqrt(a.convertToDouble()));
533 case 32:
534 return APFloat(sqrtf(a.convertToFloat()));
535 default:
536 return {};
537 }
538 });
539}
540
541//===----------------------------------------------------------------------===//
542// ExpOp folder
543//===----------------------------------------------------------------------===//
544
545OpFoldResult math::ExpOp::fold(FoldAdaptor adaptor) {
547 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
548 switch (a.getSizeInBits(a.getSemantics())) {
549 case 64:
550 return APFloat(exp(a.convertToDouble()));
551 case 32:
552 return APFloat(expf(a.convertToFloat()));
553 default:
554 return {};
555 }
556 });
557}
558
559//===----------------------------------------------------------------------===//
560// Exp2Op folder
561//===----------------------------------------------------------------------===//
562
563OpFoldResult math::Exp2Op::fold(FoldAdaptor adaptor) {
565 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
566 switch (a.getSizeInBits(a.getSemantics())) {
567 case 64:
568 return APFloat(exp2(a.convertToDouble()));
569 case 32:
570 return APFloat(exp2f(a.convertToFloat()));
571 default:
572 return {};
573 }
574 });
575}
576
577//===----------------------------------------------------------------------===//
578// ExpM1Op folder
579//===----------------------------------------------------------------------===//
580
581OpFoldResult math::ExpM1Op::fold(FoldAdaptor adaptor) {
583 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
584 switch (a.getSizeInBits(a.getSemantics())) {
585 case 64:
586 return APFloat(expm1(a.convertToDouble()));
587 case 32:
588 return APFloat(expm1f(a.convertToFloat()));
589 default:
590 return {};
591 }
592 });
593}
594
595//===----------------------------------------------------------------------===//
596// IsFiniteOp folder
597//===----------------------------------------------------------------------===//
598
599OpFoldResult math::IsFiniteOp::fold(FoldAdaptor adaptor) {
600 if (auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
601 return BoolAttr::get(val.getContext(), val.getValue().isFinite());
602 }
603 if (auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
605 cast<ShapedType>(getType()),
606 APInt(1, splat.getSplatValue<APFloat>().isFinite()));
607 }
608 return {};
609}
610
611//===----------------------------------------------------------------------===//
612// IsInfOp folder
613//===----------------------------------------------------------------------===//
614
615OpFoldResult math::IsInfOp::fold(FoldAdaptor adaptor) {
616 if (auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
617 return BoolAttr::get(val.getContext(), val.getValue().isInfinity());
618 }
619 if (auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
621 cast<ShapedType>(getType()),
622 APInt(1, splat.getSplatValue<APFloat>().isInfinity()));
623 }
624 return {};
625}
626
627//===----------------------------------------------------------------------===//
628// IsNaNOp folder
629//===----------------------------------------------------------------------===//
630
631OpFoldResult math::IsNaNOp::fold(FoldAdaptor adaptor) {
632 if (auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
633 return BoolAttr::get(val.getContext(), val.getValue().isNaN());
634 }
635 if (auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
637 cast<ShapedType>(getType()),
638 APInt(1, splat.getSplatValue<APFloat>().isNaN()));
639 }
640 return {};
641}
642
643//===----------------------------------------------------------------------===//
644// IsNormalOp folder
645//===----------------------------------------------------------------------===//
646
647OpFoldResult math::IsNormalOp::fold(FoldAdaptor adaptor) {
648 if (auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
649 return BoolAttr::get(val.getContext(), val.getValue().isNormal());
650 }
651 if (auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
653 cast<ShapedType>(getType()),
654 APInt(1, splat.getSplatValue<APFloat>().isNormal()));
655 }
656 return {};
657}
658
659//===----------------------------------------------------------------------===//
660// TanOp folder
661//===----------------------------------------------------------------------===//
662
663OpFoldResult math::TanOp::fold(FoldAdaptor adaptor) {
665 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
666 switch (a.getSizeInBits(a.getSemantics())) {
667 case 64:
668 return APFloat(tan(a.convertToDouble()));
669 case 32:
670 return APFloat(tanf(a.convertToFloat()));
671 default:
672 return {};
673 }
674 });
675}
676
677//===----------------------------------------------------------------------===//
678// TanhOp folder
679//===----------------------------------------------------------------------===//
680
681OpFoldResult math::TanhOp::fold(FoldAdaptor adaptor) {
683 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
684 switch (a.getSizeInBits(a.getSemantics())) {
685 case 64:
686 return APFloat(tanh(a.convertToDouble()));
687 case 32:
688 return APFloat(tanhf(a.convertToFloat()));
689 default:
690 return {};
691 }
692 });
693}
694
695//===----------------------------------------------------------------------===//
696// RoundEvenOp folder
697//===----------------------------------------------------------------------===//
698
699OpFoldResult math::RoundEvenOp::fold(FoldAdaptor adaptor) {
701 adaptor.getOperands(), [](const APFloat &a) {
702 APFloat result(a);
703 result.roundToIntegral(llvm::RoundingMode::NearestTiesToEven);
704 return result;
705 });
706}
707
708//===----------------------------------------------------------------------===//
709// FloorOp folder
710//===----------------------------------------------------------------------===//
711
712OpFoldResult math::FloorOp::fold(FoldAdaptor adaptor) {
714 adaptor.getOperands(), [](const APFloat &a) {
715 APFloat result(a);
716 result.roundToIntegral(llvm::RoundingMode::TowardNegative);
717 return result;
718 });
719}
720
721//===----------------------------------------------------------------------===//
722// RoundOp folder
723//===----------------------------------------------------------------------===//
724
725OpFoldResult math::RoundOp::fold(FoldAdaptor adaptor) {
727 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
728 switch (a.getSizeInBits(a.getSemantics())) {
729 case 64:
730 return APFloat(round(a.convertToDouble()));
731 case 32:
732 return APFloat(roundf(a.convertToFloat()));
733 default:
734 return {};
735 }
736 });
737}
738
739//===----------------------------------------------------------------------===//
740// TruncOp folder
741//===----------------------------------------------------------------------===//
742
743OpFoldResult math::TruncOp::fold(FoldAdaptor adaptor) {
745 adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
746 switch (a.getSizeInBits(a.getSemantics())) {
747 case 64:
748 return APFloat(trunc(a.convertToDouble()));
749 case 32:
750 return APFloat(truncf(a.convertToFloat()));
751 default:
752 return {};
753 }
754 });
755}
756
757/// Materialize an integer or floating point constant.
758Operation *math::MathDialect::materializeConstant(OpBuilder &builder,
759 Attribute value, Type type,
760 Location loc) {
761 if (auto poison = dyn_cast<ub::PoisonAttr>(value))
762 return ub::PoisonOp::create(builder, loc, type, poison);
763
764 return arith::ConstantOp::materialize(builder, value, type, loc);
765}
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
if(!isCopyOut)
static Type getI1SameShape(Type type)
Return the type of the same shape (scalar, vector or tensor) containing i1.
Definition MathOps.cpp:24
Attributes are known-constant values of operations.
Definition Attributes.h:25
static BoolAttr get(MLIRContext *context, bool value)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:209
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Attribute constFoldBinaryOpConditional(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Performs constant folding calculate with element-wise behavior on the two attributes in operands and ...
Attribute constFoldUnaryOpConditional(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Performs constant folding calculate with element-wise behavior on the one attributes in operands and ...
Attribute constFoldUnaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)