MLIR 23.0.0git
SPIRVCanonicalization.cpp
Go to the documentation of this file.
1//===- SPIRVCanonicalization.cpp - MLIR SPIR-V canonicalization patterns --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the folders and canonicalization patterns for SPIR-V ops.
10//
11//===----------------------------------------------------------------------===//
12
13#include <optional>
14#include <utility>
15
17
21#include "mlir/IR/Matchers.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/SmallVectorExtras.h"
25
26using namespace mlir;
27
28//===----------------------------------------------------------------------===//
29// Common utility functions
30//===----------------------------------------------------------------------===//
31
32/// Returns the boolean value under the hood if the given `boolAttr` is a scalar
33/// or splat vector bool constant.
34static std::optional<bool> getScalarOrSplatBoolAttr(Attribute attr) {
35 if (!attr)
36 return std::nullopt;
37
38 if (auto boolAttr = dyn_cast<BoolAttr>(attr))
39 return boolAttr.getValue();
40 if (auto splatAttr = dyn_cast<SplatElementsAttr>(attr))
41 if (splatAttr.getElementType().isInteger(1))
42 return splatAttr.getSplatValue<bool>();
43 return std::nullopt;
44}
45
46// Extracts an element from the given `composite` by following the given
47// `indices`. Returns a null Attribute if error happens.
50 // Check that given composite is a constant.
51 if (!composite)
52 return {};
53 // Return composite itself if we reach the end of the index chain.
54 if (indices.empty())
55 return composite;
56
57 if (auto vector = dyn_cast<ElementsAttr>(composite)) {
58 assert(indices.size() == 1 && "must have exactly one index for a vector");
59 return vector.getValues<Attribute>()[indices[0]];
60 }
61
62 if (auto array = dyn_cast<ArrayAttr>(composite)) {
63 assert(!indices.empty() && "must have at least one index for an array");
64 return extractCompositeElement(array.getValue()[indices[0]],
65 indices.drop_front());
66 }
67
68 return {};
69}
70
71static bool isDivZeroOrOverflow(const APInt &a, const APInt &b) {
72 bool div0 = b.isZero();
73 bool overflow = a.isMinSignedValue() && b.isAllOnes();
74
75 return div0 || overflow;
76}
77
78//===----------------------------------------------------------------------===//
79// TableGen'erated canonicalizers
80//===----------------------------------------------------------------------===//
81
82namespace {
83#include "SPIRVCanonicalization.inc"
84} // namespace
85
86//===----------------------------------------------------------------------===//
87// spirv.AccessChainOp
88//===----------------------------------------------------------------------===//
89
90namespace {
91
92/// Combines chained `spirv::AccessChainOp` operations into one
93/// `spirv::AccessChainOp` operation.
94struct CombineChainedAccessChain final
95 : OpRewritePattern<spirv::AccessChainOp> {
96 using Base::Base;
97
98 LogicalResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
99 PatternRewriter &rewriter) const override {
100 auto parentAccessChainOp =
101 accessChainOp.getBasePtr().getDefiningOp<spirv::AccessChainOp>();
102
103 if (!parentAccessChainOp) {
104 return failure();
105 }
106
107 // Combine indices.
108 SmallVector<Value, 4> indices(parentAccessChainOp.getIndices());
109 llvm::append_range(indices, accessChainOp.getIndices());
110
111 rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
112 accessChainOp, parentAccessChainOp.getBasePtr(), indices);
113
114 return success();
115 }
116};
117} // namespace
118
119void spirv::AccessChainOp::getCanonicalizationPatterns(
120 RewritePatternSet &results, MLIRContext *context) {
121 results.add<CombineChainedAccessChain>(context);
122}
123
124//===----------------------------------------------------------------------===//
125// spirv.IAddCarry / spirv.ISubBorrow
126//===----------------------------------------------------------------------===//
127
128template <typename Op>
131
132 static constexpr bool IsSub = std::is_same_v<Op, spirv::ISubBorrowOp>;
133
134 LogicalResult matchAndRewrite(Op op,
135 PatternRewriter &rewriter) const override {
136 Value lhs = op.getOperand1();
137 Value rhs = op.getOperand2();
138
139 // iaddcarry (x, 0) = <0, x>
140 // isubborrow (x, 0) = <x, 0>
141 if (matchPattern(rhs, m_Zero())) {
142 std::array<Value, 2> constituents =
143 IsSub ? std::array{lhs, rhs} : std::array{rhs, lhs};
144 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
145 constituents);
146 return success();
147 }
148
149 Attribute lhsAttr;
150 Attribute rhsAttr;
151 if (!matchPattern(lhs, m_Constant(&lhsAttr)) ||
152 !matchPattern(rhs, m_Constant(&rhsAttr)))
153 return failure();
154
155 auto lowBits = constFoldBinaryOp<IntegerAttr>(
156 {lhsAttr, rhsAttr},
157 [](const APInt &a, const APInt &b) { return IsSub ? a - b : a + b; });
158 if (!lowBits)
159 return failure();
160
161 auto wrapBit = constFoldBinaryOp<IntegerAttr>(
162 {lhsAttr, rhsAttr}, [](const APInt &a, const APInt &b) {
163 bool wrapped = IsSub ? a.ult(b) : (a + b).ult(a);
164 return APInt(a.getBitWidth(), wrapped ? 1 : 0);
165 });
166 if (!wrapBit)
167 return failure();
168
169 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
170 op, op.getType(), rewriter.getArrayAttr({lowBits, wrapBit}));
171 return success();
172 }
173};
174
176void spirv::IAddCarryOp::getCanonicalizationPatterns(
177 RewritePatternSet &patterns, MLIRContext *context) {
178 patterns.add<IAddCarryFold>(context);
179}
180
182void spirv::ISubBorrowOp::getCanonicalizationPatterns(
183 RewritePatternSet &patterns, MLIRContext *context) {
184 patterns.add<ISubBorrowFold>(context);
185}
186
187//===----------------------------------------------------------------------===//
188// spirv.[S|U]MulExtended
189//===----------------------------------------------------------------------===//
190
191template <typename MulOp, bool IsSigned>
192struct MulExtendedFold final : OpRewritePattern<MulOp> {
194
195 LogicalResult matchAndRewrite(MulOp op,
196 PatternRewriter &rewriter) const override {
197 Location loc = op.getLoc();
198 Value lhs = op.getOperand1();
199 Value rhs = op.getOperand2();
200 Type constituentType = lhs.getType();
201
202 // [su]mulextended (x, 0) = <0, 0>
203 if (matchPattern(rhs, m_Zero())) {
204 Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
205 Value constituents[2] = {zero, zero};
206 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
207 constituents);
208 return success();
209 }
210
211 // According to the SPIR-V spec:
212 //
213 // Result Type must be from OpTypeStruct. The struct must have two
214 // members...
215 //
216 // Member 0 of the result gets the low-order bits of the multiplication.
217 //
218 // Member 1 of the result gets the high-order bits of the multiplication.
219 Attribute lhsAttr;
220 Attribute rhsAttr;
221 if (!matchPattern(lhs, m_Constant(&lhsAttr)) ||
222 !matchPattern(rhs, m_Constant(&rhsAttr)))
223 return failure();
224
225 auto lowBits = constFoldBinaryOp<IntegerAttr>(
226 {lhsAttr, rhsAttr},
227 [](const APInt &a, const APInt &b) { return a * b; });
228
229 if (!lowBits)
230 return failure();
231
232 auto highBits = constFoldBinaryOp<IntegerAttr>(
233 {lhsAttr, rhsAttr}, [](const APInt &a, const APInt &b) {
234 if (IsSigned) {
235 return llvm::APIntOps::mulhs(a, b);
236 }
237 return llvm::APIntOps::mulhu(a, b);
238 });
239
240 if (!highBits)
241 return failure();
242
243 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
244 op, op.getType(), rewriter.getArrayAttr({lowBits, highBits}));
245 return success();
246 }
247};
248
250void spirv::SMulExtendedOp::getCanonicalizationPatterns(
251 RewritePatternSet &patterns, MLIRContext *context) {
252 patterns.add<SMulExtendedOpFold>(context);
253}
254
255struct UMulExtendedOpXOne final : OpRewritePattern<spirv::UMulExtendedOp> {
256 using Base::Base;
257
258 LogicalResult matchAndRewrite(spirv::UMulExtendedOp op,
259 PatternRewriter &rewriter) const override {
260 Location loc = op.getLoc();
261 Value lhs = op.getOperand1();
262 Value rhs = op.getOperand2();
263 Type constituentType = lhs.getType();
264
265 // umulextended (x, 1) = <x, 0>
266 if (matchPattern(rhs, m_One())) {
267 Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
268 Value constituents[2] = {lhs, zero};
269 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
270 constituents);
271 return success();
272 }
273
274 return failure();
275 }
276};
277
279void spirv::UMulExtendedOp::getCanonicalizationPatterns(
280 RewritePatternSet &patterns, MLIRContext *context) {
281 patterns.add<UMulExtendedOpFold, UMulExtendedOpXOne>(context);
282}
283
284//===----------------------------------------------------------------------===//
285// spirv.UMod
286//===----------------------------------------------------------------------===//
287
288// Input:
289// %0 = spirv.UMod %arg0, %const32 : i32
290// %1 = spirv.UMod %0, %const4 : i32
291// Output:
292// %0 = spirv.UMod %arg0, %const32 : i32
293// %1 = spirv.UMod %arg0, %const4 : i32
294
295// The transformation is only applied if one divisor is a multiple of the other.
296
297struct UModSimplification final : OpRewritePattern<spirv::UModOp> {
298 using Base::Base;
299
300 LogicalResult matchAndRewrite(spirv::UModOp umodOp,
301 PatternRewriter &rewriter) const override {
302 auto prevUMod = umodOp.getOperand(0).getDefiningOp<spirv::UModOp>();
303 if (!prevUMod)
304 return failure();
305
306 TypedAttr prevValue;
307 TypedAttr currValue;
308 if (!matchPattern(prevUMod.getOperand(1), m_Constant(&prevValue)) ||
309 !matchPattern(umodOp.getOperand(1), m_Constant(&currValue)))
310 return failure();
311
312 // Ensure that previous divisor is a multiple of the current divisor. If
313 // not, fail the transformation.
314 bool isApplicable = false;
315 if (auto prevInt = dyn_cast<IntegerAttr>(prevValue)) {
316 auto currInt = cast<IntegerAttr>(currValue);
317 isApplicable = prevInt.getValue().urem(currInt.getValue()) == 0;
318 } else if (auto prevVec = dyn_cast<DenseElementsAttr>(prevValue)) {
319 auto currVec = cast<DenseElementsAttr>(currValue);
320 isApplicable = llvm::all_of(llvm::zip_equal(prevVec.getValues<APInt>(),
321 currVec.getValues<APInt>()),
322 [](const auto &pair) {
323 auto &[prev, curr] = pair;
324 return prev.urem(curr) == 0;
325 });
326 }
327
328 if (!isApplicable)
329 return failure();
330
331 // The transformation is safe. Replace the existing UMod operation with a
332 // new UMod operation, using the original dividend and the current divisor.
333 rewriter.replaceOpWithNewOp<spirv::UModOp>(
334 umodOp, umodOp.getType(), prevUMod.getOperand(0), umodOp.getOperand(1));
335
336 return success();
337 }
338};
339
340void spirv::UModOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
341 MLIRContext *context) {
342 patterns.add<UModSimplification>(context);
343}
344
345//===----------------------------------------------------------------------===//
346// spirv.BitcastOp
347//===----------------------------------------------------------------------===//
348
349OpFoldResult spirv::BitcastOp::fold(FoldAdaptor /*adaptor*/) {
350 Value curInput = getOperand();
351 if (getType() == curInput.getType())
352 return curInput;
353
354 // Look through nested bitcasts.
355 if (auto prevCast = curInput.getDefiningOp<spirv::BitcastOp>()) {
356 Value prevInput = prevCast.getOperand();
357 if (prevInput.getType() == getType())
358 return prevInput;
359
360 getOperandMutable().assign(prevInput);
361 return getResult();
362 }
363
364 // TODO(kuhar): Consider constant-folding the operand attribute.
365 return {};
366}
367
368//===----------------------------------------------------------------------===//
369// spirv.CompositeExtractOp
370//===----------------------------------------------------------------------===//
371
372OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) {
373 Value compositeOp = getComposite();
374
375 while (auto insertOp =
376 compositeOp.getDefiningOp<spirv::CompositeInsertOp>()) {
377 if (getIndices() == insertOp.getIndices())
378 return insertOp.getObject();
379 compositeOp = insertOp.getComposite();
380 }
381
382 if (auto constructOp =
383 compositeOp.getDefiningOp<spirv::CompositeConstructOp>()) {
384 auto type = cast<spirv::CompositeType>(constructOp.getType());
385 if (getIndices().size() == 1 &&
386 constructOp.getConstituents().size() == type.getNumElements()) {
387 auto i = cast<IntegerAttr>(*getIndices().begin());
388 if (i.getValue().getSExtValue() <
389 static_cast<int64_t>(constructOp.getConstituents().size()))
390 return constructOp.getConstituents()[i.getValue().getSExtValue()];
391 }
392 }
393
394 auto indexVector = llvm::map_to_vector(getIndices(), [](Attribute attr) {
395 return static_cast<unsigned>(cast<IntegerAttr>(attr).getInt());
396 });
397 return extractCompositeElement(adaptor.getComposite(), indexVector);
398}
399
400//===----------------------------------------------------------------------===//
401// spirv.Constant
402//===----------------------------------------------------------------------===//
403
404OpFoldResult spirv::ConstantOp::fold(FoldAdaptor /*adaptor*/) {
405 return getValue();
406}
407
408//===----------------------------------------------------------------------===//
409// spirv.IAdd
410//===----------------------------------------------------------------------===//
411
412OpFoldResult spirv::IAddOp::fold(FoldAdaptor adaptor) {
413 // x + 0 = x
414 if (matchPattern(getOperand2(), m_Zero()))
415 return getOperand1();
416
417 // According to the SPIR-V spec:
418 //
419 // The resulting value will equal the low-order N bits of the correct result
420 // R, where N is the component width and R is computed with enough precision
421 // to avoid overflow and underflow.
423 adaptor.getOperands(),
424 [](APInt a, const APInt &b) { return std::move(a) + b; });
425}
426
427//===----------------------------------------------------------------------===//
428// spirv.IMul
429//===----------------------------------------------------------------------===//
430
431OpFoldResult spirv::IMulOp::fold(FoldAdaptor adaptor) {
432 // x * 0 == 0
433 if (matchPattern(getOperand2(), m_Zero()))
434 return getOperand2();
435 // x * 1 = x
436 if (matchPattern(getOperand2(), m_One()))
437 return getOperand1();
438
439 // According to the SPIR-V spec:
440 //
441 // The resulting value will equal the low-order N bits of the correct result
442 // R, where N is the component width and R is computed with enough precision
443 // to avoid overflow and underflow.
445 adaptor.getOperands(),
446 [](const APInt &a, const APInt &b) { return a * b; });
447}
448
449//===----------------------------------------------------------------------===//
450// spirv.ISub
451//===----------------------------------------------------------------------===//
452
453OpFoldResult spirv::ISubOp::fold(FoldAdaptor adaptor) {
454 // x - x = 0
455 if (getOperand1() == getOperand2())
457
458 // According to the SPIR-V spec:
459 //
460 // The resulting value will equal the low-order N bits of the correct result
461 // R, where N is the component width and R is computed with enough precision
462 // to avoid overflow and underflow.
464 adaptor.getOperands(),
465 [](APInt a, const APInt &b) { return std::move(a) - b; });
466}
467
468//===----------------------------------------------------------------------===//
469// spirv.SDiv
470//===----------------------------------------------------------------------===//
471
472OpFoldResult spirv::SDivOp::fold(FoldAdaptor adaptor) {
473 // sdiv (x, 1) = x
474 if (matchPattern(getOperand2(), m_One()))
475 return getOperand1();
476
477 // According to the SPIR-V spec:
478 //
479 // Signed-integer division of Operand 1 divided by Operand 2.
480 // Results are computed per component. Behavior is undefined if Operand 2 is
481 // 0. Behavior is undefined if Operand 2 is -1 and Operand 1 is the minimum
482 // representable value for the operands' type, causing signed overflow.
483 //
484 // So don't fold during undefined behavior.
485 bool div0OrOverflow = false;
487 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
488 if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
489 div0OrOverflow = true;
490 return a;
491 }
492 return a.sdiv(b);
493 });
494 return div0OrOverflow ? Attribute() : res;
495}
496
497//===----------------------------------------------------------------------===//
498// spirv.SMod
499//===----------------------------------------------------------------------===//
500
501OpFoldResult spirv::SModOp::fold(FoldAdaptor adaptor) {
502 // smod (x, 1) = 0
503 if (matchPattern(getOperand2(), m_One()))
505
506 // According to SPIR-V spec:
507 //
508 // Signed remainder operation for the remainder whose sign matches the sign
509 // of Operand 2. Behavior is undefined if Operand 2 is 0. Behavior is
510 // undefined if Operand 2 is -1 and Operand 1 is the minimum representable
511 // value for the operands' type, causing signed overflow. Otherwise, the
512 // result is the remainder r of Operand 1 divided by Operand 2 where if
513 // r ≠ 0, the sign of r is the same as the sign of Operand 2.
514 //
515 // So don't fold during undefined behavior
516 bool div0OrOverflow = false;
518 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
519 if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
520 div0OrOverflow = true;
521 return a;
522 }
523 APInt c = a.abs().urem(b.abs());
524 if (c.isZero())
525 return c;
526 if (b.isNegative()) {
527 APInt zero = APInt::getZero(c.getBitWidth());
528 return a.isNegative() ? (zero - c) : (b + c);
529 }
530 return a.isNegative() ? (b - c) : c;
531 });
532 return div0OrOverflow ? Attribute() : res;
533}
534
535//===----------------------------------------------------------------------===//
536// spirv.SRem
537//===----------------------------------------------------------------------===//
538
539OpFoldResult spirv::SRemOp::fold(FoldAdaptor adaptor) {
540 // x % 1 = 0
541 if (matchPattern(getOperand2(), m_One()))
543
544 // According to SPIR-V spec:
545 //
546 // Signed remainder operation for the remainder whose sign matches the sign
547 // of Operand 1. Behavior is undefined if Operand 2 is 0. Behavior is
548 // undefined if Operand 2 is -1 and Operand 1 is the minimum representable
549 // value for the operands' type, causing signed overflow. Otherwise, the
550 // result is the remainder r of Operand 1 divided by Operand 2 where if
551 // r ≠ 0, the sign of r is the same as the sign of Operand 1.
552
553 // Don't fold if it would do undefined behavior.
554 bool div0OrOverflow = false;
556 adaptor.getOperands(), [&](APInt a, const APInt &b) {
557 if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
558 div0OrOverflow = true;
559 return a;
560 }
561 return a.srem(b);
562 });
563 return div0OrOverflow ? Attribute() : res;
564}
565
566//===----------------------------------------------------------------------===//
567// spirv.UDiv
568//===----------------------------------------------------------------------===//
569
570OpFoldResult spirv::UDivOp::fold(FoldAdaptor adaptor) {
571 // udiv (x, 1) = x
572 if (matchPattern(getOperand2(), m_One()))
573 return getOperand1();
574
575 // According to the SPIR-V spec:
576 //
577 // Unsigned-integer division of Operand 1 divided by Operand 2. Behavior is
578 // undefined if Operand 2 is 0.
579 //
580 // So don't fold during undefined behavior.
581 bool div0 = false;
583 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
584 if (div0 || b.isZero()) {
585 div0 = true;
586 return a;
587 }
588 return a.udiv(b);
589 });
590 return div0 ? Attribute() : res;
591}
592
593//===----------------------------------------------------------------------===//
594// spirv.UMod
595//===----------------------------------------------------------------------===//
596
597OpFoldResult spirv::UModOp::fold(FoldAdaptor adaptor) {
598 // umod (x, 1) = 0
599 if (matchPattern(getOperand2(), m_One()))
601
602 // According to the SPIR-V spec:
603 //
604 // Unsigned modulo operation of Operand 1 modulo Operand 2. Behavior is
605 // undefined if Operand 2 is 0.
606 //
607 // So don't fold during undefined behavior.
608 bool div0 = false;
610 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
611 if (div0 || b.isZero()) {
612 div0 = true;
613 return a;
614 }
615 return a.urem(b);
616 });
617 return div0 ? Attribute() : res;
618}
619
620//===----------------------------------------------------------------------===//
621// spirv.SNegate
622//===----------------------------------------------------------------------===//
623
624OpFoldResult spirv::SNegateOp::fold(FoldAdaptor adaptor) {
625 // -(-x) = 0 - (0 - x) = x
626 auto op = getOperand();
627 if (auto negateOp = op.getDefiningOp<spirv::SNegateOp>())
628 return negateOp->getOperand(0);
629
630 // According to the SPIR-V spec:
631 //
632 // Signed-integer subtract of Operand from zero.
634 adaptor.getOperands(), [](const APInt &a) {
635 APInt zero = APInt::getZero(a.getBitWidth());
636 return zero - a;
637 });
638}
639
640//===----------------------------------------------------------------------===//
641// spirv.NotOp
642//===----------------------------------------------------------------------===//
643
644OpFoldResult spirv::NotOp::fold(spirv::NotOp::FoldAdaptor adaptor) {
645 // !(!x) = x
646 auto op = getOperand();
647 if (auto notOp = op.getDefiningOp<spirv::NotOp>())
648 return notOp->getOperand(0);
649
650 // According to the SPIR-V spec:
651 //
652 // Complement the bits of Operand.
653 return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(), [&](APInt a) {
654 a.flipAllBits();
655 return a;
656 });
657}
658
659//===----------------------------------------------------------------------===//
660// spirv.LogicalAnd
661//===----------------------------------------------------------------------===//
662
663OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) {
664 if (std::optional<bool> rhs =
665 getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
666 // x && true = x
667 if (*rhs)
668 return getOperand1();
669
670 // x && false = false
671 if (!*rhs)
672 return adaptor.getOperand2();
673 }
674
675 return Attribute();
676}
677
678//===----------------------------------------------------------------------===//
679// spirv.LogicalEqualOp
680//===----------------------------------------------------------------------===//
681
683spirv::LogicalEqualOp::fold(spirv::LogicalEqualOp::FoldAdaptor adaptor) {
684 // x == x -> true
685 if (getOperand1() == getOperand2()) {
686 auto trueAttr = BoolAttr::get(getContext(), true);
687 if (isa<IntegerType>(getType()))
688 return trueAttr;
689 if (auto vecTy = dyn_cast<VectorType>(getType()))
690 return SplatElementsAttr::get(vecTy, trueAttr);
691 }
692
694 adaptor.getOperands(), [](const APInt &a, const APInt &b) {
695 return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);
696 });
697}
698
699//===----------------------------------------------------------------------===//
700// spirv.LogicalNotEqualOp
701//===----------------------------------------------------------------------===//
702
703OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) {
704 if (std::optional<bool> rhs =
705 getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
706 // x != false -> x
707 if (!rhs.value())
708 return getOperand1();
709 }
710
711 // x == x -> false
712 if (getOperand1() == getOperand2()) {
713 auto falseAttr = BoolAttr::get(getContext(), false);
714 if (isa<IntegerType>(getType()))
715 return falseAttr;
716 if (auto vecTy = dyn_cast<VectorType>(getType()))
717 return SplatElementsAttr::get(vecTy, falseAttr);
718 }
719
721 adaptor.getOperands(), [](const APInt &a, const APInt &b) {
722 return a == b ? APInt::getZero(1) : APInt::getAllOnes(1);
723 });
724}
725
726//===----------------------------------------------------------------------===//
727// spirv.LogicalNot
728//===----------------------------------------------------------------------===//
729
730OpFoldResult spirv::LogicalNotOp::fold(FoldAdaptor adaptor) {
731 // !(!x) = x
732 auto op = getOperand();
733 if (auto notOp = op.getDefiningOp<spirv::LogicalNotOp>())
734 return notOp->getOperand(0);
735
736 // According to the SPIR-V spec:
737 //
738 // Complement the bits of Operand.
739 return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(),
740 [](const APInt &a) {
741 APInt zero = APInt::getZero(1);
742 return a == 1 ? zero : (zero + 1);
743 });
744}
745
746void spirv::LogicalNotOp::getCanonicalizationPatterns(
747 RewritePatternSet &results, MLIRContext *context) {
748 results
749 .add<ConvertLogicalNotOfIEqual, ConvertLogicalNotOfINotEqual,
750 ConvertLogicalNotOfLogicalEqual, ConvertLogicalNotOfLogicalNotEqual>(
751 context);
752}
753
754//===----------------------------------------------------------------------===//
755// spirv.LogicalOr
756//===----------------------------------------------------------------------===//
757
758OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
759 if (auto rhs = getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
760 if (*rhs) {
761 // x || true = true
762 return adaptor.getOperand2();
763 }
764
765 if (!*rhs) {
766 // x || false = x
767 return getOperand1();
768 }
769 }
770
771 return Attribute();
772}
773
774//===----------------------------------------------------------------------===//
775// spirv.SelectOp
776//===----------------------------------------------------------------------===//
777
778OpFoldResult spirv::SelectOp::fold(FoldAdaptor adaptor) {
779 // spirv.Select _ x x -> x
780 Value trueVals = getTrueValue();
781 Value falseVals = getFalseValue();
782 if (trueVals == falseVals)
783 return trueVals;
784
785 ArrayRef<Attribute> operands = adaptor.getOperands();
786
787 // spirv.Select true x y -> x
788 // spirv.Select false x y -> y
789 if (auto boolAttr = getScalarOrSplatBoolAttr(operands[0]))
790 return *boolAttr ? trueVals : falseVals;
791
792 // Check that all the operands are constant
793 if (!operands[0] || !operands[1] || !operands[2])
794 return Attribute();
795
796 // Note: getScalarOrSplatBoolAttr will always return a boolAttr if we are in
797 // the scalar case. Hence, we are only required to consider the case of
798 // DenseElementsAttr in foldSelectOp.
799 auto condAttrs = dyn_cast<DenseElementsAttr>(operands[0]);
800 auto trueAttrs = dyn_cast<DenseElementsAttr>(operands[1]);
801 auto falseAttrs = dyn_cast<DenseElementsAttr>(operands[2]);
802 if (!condAttrs || !trueAttrs || !falseAttrs)
803 return Attribute();
804
805 auto elementResults = llvm::to_vector<4>(trueAttrs.getValues<Attribute>());
806 auto iters = llvm::zip_equal(elementResults, condAttrs.getValues<BoolAttr>(),
807 falseAttrs.getValues<Attribute>());
808 for (auto [result, cond, falseRes] : iters) {
809 if (!cond.getValue())
810 result = falseRes;
811 }
812
813 auto resultType = trueAttrs.getType();
814 return DenseElementsAttr::get(cast<ShapedType>(resultType), elementResults);
815}
816
817//===----------------------------------------------------------------------===//
818// spirv.IEqualOp
819//===----------------------------------------------------------------------===//
820
821OpFoldResult spirv::IEqualOp::fold(spirv::IEqualOp::FoldAdaptor adaptor) {
822 // x == x -> true
823 if (getOperand1() == getOperand2()) {
824 auto trueAttr = BoolAttr::get(getContext(), true);
825 if (isa<IntegerType>(getType()))
826 return trueAttr;
827 if (auto vecTy = dyn_cast<VectorType>(getType()))
828 return SplatElementsAttr::get(vecTy, trueAttr);
829 }
830
832 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
833 return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);
834 });
835}
836
837//===----------------------------------------------------------------------===//
838// spirv.INotEqualOp
839//===----------------------------------------------------------------------===//
840
841OpFoldResult spirv::INotEqualOp::fold(spirv::INotEqualOp::FoldAdaptor adaptor) {
842 // x == x -> false
843 if (getOperand1() == getOperand2()) {
844 auto falseAttr = BoolAttr::get(getContext(), false);
845 if (isa<IntegerType>(getType()))
846 return falseAttr;
847 if (auto vecTy = dyn_cast<VectorType>(getType()))
848 return SplatElementsAttr::get(vecTy, falseAttr);
849 }
850
852 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
853 return a == b ? APInt::getZero(1) : APInt::getAllOnes(1);
854 });
855}
856
857//===----------------------------------------------------------------------===//
858// spirv.SGreaterThan
859//===----------------------------------------------------------------------===//
860
862spirv::SGreaterThanOp::fold(spirv::SGreaterThanOp::FoldAdaptor adaptor) {
863 // x == x -> false
864 if (getOperand1() == getOperand2()) {
865 auto falseAttr = BoolAttr::get(getContext(), false);
866 if (isa<IntegerType>(getType()))
867 return falseAttr;
868 if (auto vecTy = dyn_cast<VectorType>(getType()))
869 return SplatElementsAttr::get(vecTy, falseAttr);
870 }
871
873 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
874 return a.sgt(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
875 });
876}
877
878//===----------------------------------------------------------------------===//
879// spirv.SGreaterThanEqual
880//===----------------------------------------------------------------------===//
881
882OpFoldResult spirv::SGreaterThanEqualOp::fold(
883 spirv::SGreaterThanEqualOp::FoldAdaptor adaptor) {
884 // x == x -> true
885 if (getOperand1() == getOperand2()) {
886 auto trueAttr = BoolAttr::get(getContext(), true);
887 if (isa<IntegerType>(getType()))
888 return trueAttr;
889 if (auto vecTy = dyn_cast<VectorType>(getType()))
890 return SplatElementsAttr::get(vecTy, trueAttr);
891 }
892
894 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
895 return a.sge(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
896 });
897}
898
899//===----------------------------------------------------------------------===//
900// spirv.UGreaterThan
901//===----------------------------------------------------------------------===//
902
904spirv::UGreaterThanOp::fold(spirv::UGreaterThanOp::FoldAdaptor adaptor) {
905 // x == x -> false
906 if (getOperand1() == getOperand2()) {
907 auto falseAttr = BoolAttr::get(getContext(), false);
908 if (isa<IntegerType>(getType()))
909 return falseAttr;
910 if (auto vecTy = dyn_cast<VectorType>(getType()))
911 return SplatElementsAttr::get(vecTy, falseAttr);
912 }
913
915 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
916 return a.ugt(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
917 });
918}
919
920//===----------------------------------------------------------------------===//
921// spirv.UGreaterThanEqual
922//===----------------------------------------------------------------------===//
923
924OpFoldResult spirv::UGreaterThanEqualOp::fold(
925 spirv::UGreaterThanEqualOp::FoldAdaptor adaptor) {
926 // x == x -> true
927 if (getOperand1() == getOperand2()) {
928 auto trueAttr = BoolAttr::get(getContext(), true);
929 if (isa<IntegerType>(getType()))
930 return trueAttr;
931 if (auto vecTy = dyn_cast<VectorType>(getType()))
932 return SplatElementsAttr::get(vecTy, trueAttr);
933 }
934
936 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
937 return a.uge(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
938 });
939}
940
941//===----------------------------------------------------------------------===//
942// spirv.SLessThan
943//===----------------------------------------------------------------------===//
944
945OpFoldResult spirv::SLessThanOp::fold(spirv::SLessThanOp::FoldAdaptor adaptor) {
946 // x == x -> false
947 if (getOperand1() == getOperand2()) {
948 auto falseAttr = BoolAttr::get(getContext(), false);
949 if (isa<IntegerType>(getType()))
950 return falseAttr;
951 if (auto vecTy = dyn_cast<VectorType>(getType()))
952 return SplatElementsAttr::get(vecTy, falseAttr);
953 }
954
956 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
957 return a.slt(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
958 });
959}
960
961//===----------------------------------------------------------------------===//
962// spirv.SLessThanEqual
963//===----------------------------------------------------------------------===//
964
966spirv::SLessThanEqualOp::fold(spirv::SLessThanEqualOp::FoldAdaptor adaptor) {
967 // x == x -> true
968 if (getOperand1() == getOperand2()) {
969 auto trueAttr = BoolAttr::get(getContext(), true);
970 if (isa<IntegerType>(getType()))
971 return trueAttr;
972 if (auto vecTy = dyn_cast<VectorType>(getType()))
973 return SplatElementsAttr::get(vecTy, trueAttr);
974 }
975
977 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
978 return a.sle(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
979 });
980}
981
982//===----------------------------------------------------------------------===//
983// spirv.ULessThan
984//===----------------------------------------------------------------------===//
985
986OpFoldResult spirv::ULessThanOp::fold(spirv::ULessThanOp::FoldAdaptor adaptor) {
987 // x == x -> false
988 if (getOperand1() == getOperand2()) {
989 auto falseAttr = BoolAttr::get(getContext(), false);
990 if (isa<IntegerType>(getType()))
991 return falseAttr;
992 if (auto vecTy = dyn_cast<VectorType>(getType()))
993 return SplatElementsAttr::get(vecTy, falseAttr);
994 }
995
997 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
998 return a.ult(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
999 });
1000}
1001
1002//===----------------------------------------------------------------------===//
1003// spirv.ULessThanEqual
1004//===----------------------------------------------------------------------===//
1005
1007spirv::ULessThanEqualOp::fold(spirv::ULessThanEqualOp::FoldAdaptor adaptor) {
1008 // x == x -> true
1009 if (getOperand1() == getOperand2()) {
1010 auto trueAttr = BoolAttr::get(getContext(), true);
1011 if (isa<IntegerType>(getType()))
1012 return trueAttr;
1013 if (auto vecTy = dyn_cast<VectorType>(getType()))
1014 return SplatElementsAttr::get(vecTy, trueAttr);
1015 }
1016
1018 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
1019 return a.ule(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
1020 });
1021}
1022
1023//===----------------------------------------------------------------------===//
1024// spirv.ShiftLeftLogical
1025//===----------------------------------------------------------------------===//
1026
1027OpFoldResult spirv::ShiftLeftLogicalOp::fold(
1028 spirv::ShiftLeftLogicalOp::FoldAdaptor adaptor) {
1029 // x << 0 -> x
1030 if (matchPattern(adaptor.getOperand2(), m_Zero())) {
1031 return getOperand1();
1032 }
1033
1034 // Unfortunately due to below undefined behaviour can't fold 0 for Base.
1035
1036 // Results are computed per component, and within each component, per bit...
1037 //
1038 // The result is undefined if Shift is greater than or equal to the bit width
1039 // of the components of Base.
1040 //
1041 // So we can use the APInt << method, but don't fold if undefined behaviour.
1042 bool shiftToLarge = false;
1044 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
1045 if (shiftToLarge || b.uge(a.getBitWidth())) {
1046 shiftToLarge = true;
1047 return a;
1048 }
1049 return a << b;
1050 });
1051 return shiftToLarge ? Attribute() : res;
1052}
1053
1054//===----------------------------------------------------------------------===//
1055// spirv.ShiftRightArithmetic
1056//===----------------------------------------------------------------------===//
1057
1058OpFoldResult spirv::ShiftRightArithmeticOp::fold(
1059 spirv::ShiftRightArithmeticOp::FoldAdaptor adaptor) {
1060 // x >> 0 -> x
1061 if (matchPattern(adaptor.getOperand2(), m_Zero())) {
1062 return getOperand1();
1063 }
1064
1065 // Unfortunately due to below undefined behaviour can't fold 0, -1 for Base.
1066
1067 // Results are computed per component, and within each component, per bit...
1068 //
1069 // The result is undefined if Shift is greater than or equal to the bit width
1070 // of the components of Base.
1071 //
1072 // So we can use the APInt ashr method, but don't fold if undefined behaviour.
1073 bool shiftToLarge = false;
1075 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
1076 if (shiftToLarge || b.uge(a.getBitWidth())) {
1077 shiftToLarge = true;
1078 return a;
1079 }
1080 return a.ashr(b);
1081 });
1082 return shiftToLarge ? Attribute() : res;
1083}
1084
1085//===----------------------------------------------------------------------===//
1086// spirv.ShiftRightLogical
1087//===----------------------------------------------------------------------===//
1088
1089OpFoldResult spirv::ShiftRightLogicalOp::fold(
1090 spirv::ShiftRightLogicalOp::FoldAdaptor adaptor) {
1091 // x >> 0 -> x
1092 if (matchPattern(adaptor.getOperand2(), m_Zero())) {
1093 return getOperand1();
1094 }
1095
1096 // Unfortunately due to below undefined behaviour can't fold 0 for Base.
1097
1098 // Results are computed per component, and within each component, per bit...
1099 //
1100 // The result is undefined if Shift is greater than or equal to the bit width
1101 // of the components of Base.
1102 //
1103 // So we can use the APInt lshr method, but don't fold if undefined behaviour.
1104 bool shiftToLarge = false;
1106 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
1107 if (shiftToLarge || b.uge(a.getBitWidth())) {
1108 shiftToLarge = true;
1109 return a;
1110 }
1111 return a.lshr(b);
1112 });
1113 return shiftToLarge ? Attribute() : res;
1114}
1115
1116//===----------------------------------------------------------------------===//
1117// spirv.BitwiseAndOp
1118//===----------------------------------------------------------------------===//
1119
1121spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) {
1122 // x & x -> x
1123 if (getOperand1() == getOperand2()) {
1124 return getOperand1();
1125 }
1126
1127 APInt rhsMask;
1128 if (matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask))) {
1129 // x & 0 -> 0
1130 if (rhsMask.isZero())
1131 return getOperand2();
1132
1133 // x & <all ones> -> x
1134 if (rhsMask.isAllOnes())
1135 return getOperand1();
1136
1137 // (UConvert x : iN to iK) & <mask with N low bits set> -> UConvert x
1138 if (auto zext = getOperand1().getDefiningOp<spirv::UConvertOp>()) {
1139 int valueBits =
1141 if (rhsMask.zextOrTrunc(valueBits).isAllOnes())
1142 return getOperand1();
1143 }
1144 }
1145
1146 // According to the SPIR-V spec:
1147 //
1148 // Type is a scalar or vector of integer type.
1149 // Results are computed per component, and within each component, per bit.
1150 // So we can use the APInt & method.
1152 adaptor.getOperands(),
1153 [](const APInt &a, const APInt &b) { return a & b; });
1154}
1155
1156//===----------------------------------------------------------------------===//
1157// spirv.BitwiseOrOp
1158//===----------------------------------------------------------------------===//
1159
1160OpFoldResult spirv::BitwiseOrOp::fold(spirv::BitwiseOrOp::FoldAdaptor adaptor) {
1161 // x | x -> x
1162 if (getOperand1() == getOperand2()) {
1163 return getOperand1();
1164 }
1165
1166 APInt rhsMask;
1167 if (matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask))) {
1168 // x | 0 -> x
1169 if (rhsMask.isZero())
1170 return getOperand1();
1171
1172 // x | <all ones> -> <all ones>
1173 if (rhsMask.isAllOnes())
1174 return getOperand2();
1175 }
1176
1177 // According to the SPIR-V spec:
1178 //
1179 // Type is a scalar or vector of integer type.
1180 // Results are computed per component, and within each component, per bit.
1181 // So we can use the APInt | method.
1183 adaptor.getOperands(),
1184 [](const APInt &a, const APInt &b) { return a | b; });
1185}
1186
1187//===----------------------------------------------------------------------===//
1188// spirv.BitwiseXorOp
1189//===----------------------------------------------------------------------===//
1190
1192spirv::BitwiseXorOp::fold(spirv::BitwiseXorOp::FoldAdaptor adaptor) {
1193 // x ^ 0 -> x
1194 if (matchPattern(adaptor.getOperand2(), m_Zero())) {
1195 return getOperand1();
1196 }
1197
1198 // x ^ x -> 0
1199 if (getOperand1() == getOperand2())
1201
1202 // According to the SPIR-V spec:
1203 //
1204 // Type is a scalar or vector of integer type.
1205 // Results are computed per component, and within each component, per bit.
1206 // So we can use the APInt ^ method.
1208 adaptor.getOperands(),
1209 [](const APInt &a, const APInt &b) { return a ^ b; });
1210}
1211
1212//===----------------------------------------------------------------------===//
1213// spirv.mlir.selection
1214//===----------------------------------------------------------------------===//
1215
1216namespace {
1217// Blocks from the given `spirv.mlir.selection` operation must satisfy the
1218// following layout:
1219//
1220// +-----------------------------------------------+
1221// | header block |
1222// | spirv.BranchConditionalOp %cond, ^case0, ^case1 |
1223// +-----------------------------------------------+
1224// / \
1225// ...
1226//
1227//
1228// +------------------------+ +------------------------+
1229// | case #0 | | case #1 |
1230// | spirv.Store %ptr %value0 | | spirv.Store %ptr %value1 |
1231// | spirv.Branch ^merge | | spirv.Branch ^merge |
1232// +------------------------+ +------------------------+
1233//
1234//
1235// ...
1236// \ /
1237// v
1238// +-------------+
1239// | merge block |
1240// +-------------+
1241//
1242struct ConvertSelectionOpToSelect final : OpRewritePattern<spirv::SelectionOp> {
1243 using Base::Base;
1244
1245 LogicalResult matchAndRewrite(spirv::SelectionOp selectionOp,
1246 PatternRewriter &rewriter) const override {
1247 Operation *op = selectionOp.getOperation();
1248 Region &body = op->getRegion(0);
1249 // Verifier allows an empty region for `spirv.mlir.selection`.
1250 if (body.empty()) {
1251 return failure();
1252 }
1253
1254 // Check that region consists of 4 blocks:
1255 // header block, `true` block, `false` block and merge block.
1256 if (llvm::range_size(body) != 4) {
1257 return failure();
1258 }
1259
1260 Block *headerBlock = selectionOp.getHeaderBlock();
1261 if (!onlyContainsBranchConditionalOp(headerBlock)) {
1262 return failure();
1263 }
1264
1265 auto brConditionalOp =
1266 cast<spirv::BranchConditionalOp>(headerBlock->front());
1267
1268 Block *trueBlock = brConditionalOp.getSuccessor(0);
1269 Block *falseBlock = brConditionalOp.getSuccessor(1);
1270 Block *mergeBlock = selectionOp.getMergeBlock();
1271
1272 if (failed(canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)))
1273 return failure();
1274
1275 Value trueValue = getSrcValue(trueBlock);
1276 Value falseValue = getSrcValue(falseBlock);
1277 Value ptrValue = getDstPtr(trueBlock);
1278 auto storeOpAttributes =
1279 cast<spirv::StoreOp>(trueBlock->front())->getAttrs();
1280
1281 auto selectOp = spirv::SelectOp::create(
1282 rewriter, selectionOp.getLoc(), trueValue.getType(),
1283 brConditionalOp.getCondition(), trueValue, falseValue);
1284 spirv::StoreOp::create(rewriter, selectOp.getLoc(), ptrValue,
1285 selectOp.getResult(), storeOpAttributes);
1286
1287 // `spirv.mlir.selection` is not needed anymore.
1288 rewriter.eraseOp(op);
1289 return success();
1290 }
1291
1292private:
1293 // Checks that given blocks follow the following rules:
1294 // 1. Each conditional block consists of two operations, the first operation
1295 // is a `spirv.Store` and the last operation is a `spirv.Branch`.
1296 // 2. Each `spirv.Store` uses the same pointer and the same memory attributes.
1297 // 3. A control flow goes into the given merge block from the given
1298 // conditional blocks.
1299 LogicalResult canCanonicalizeSelection(Block *trueBlock, Block *falseBlock,
1300 Block *mergeBlock) const;
1301
1302 bool onlyContainsBranchConditionalOp(Block *block) const {
1303 return llvm::hasSingleElement(*block) &&
1304 isa<spirv::BranchConditionalOp>(block->front());
1305 }
1306
1307 bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs) const {
1308 return lhs->getDiscardableAttrDictionary() ==
1309 rhs->getDiscardableAttrDictionary() &&
1310 lhs.getProperties() == rhs.getProperties();
1311 }
1312
1313 // Returns a source value for the given block.
1314 Value getSrcValue(Block *block) const {
1315 auto storeOp = cast<spirv::StoreOp>(block->front());
1316 return storeOp.getValue();
1317 }
1318
1319 // Returns a destination value for the given block.
1320 Value getDstPtr(Block *block) const {
1321 auto storeOp = cast<spirv::StoreOp>(block->front());
1322 return storeOp.getPtr();
1323 }
1324};
1325
1326LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
1327 Block *trueBlock, Block *falseBlock, Block *mergeBlock) const {
1328 // Each block must consists of 2 operations.
1329 if (llvm::range_size(*trueBlock) != 2 || llvm::range_size(*falseBlock) != 2) {
1330 return failure();
1331 }
1332
1333 auto trueBrStoreOp = dyn_cast<spirv::StoreOp>(trueBlock->front());
1334 auto trueBrBranchOp =
1335 dyn_cast<spirv::BranchOp>(*std::next(trueBlock->begin()));
1336 auto falseBrStoreOp = dyn_cast<spirv::StoreOp>(falseBlock->front());
1337 auto falseBrBranchOp =
1338 dyn_cast<spirv::BranchOp>(*std::next(falseBlock->begin()));
1339
1340 if (!trueBrStoreOp || !trueBrBranchOp || !falseBrStoreOp ||
1341 !falseBrBranchOp) {
1342 return failure();
1343 }
1344
1345 // Checks that given type is valid for `spirv.SelectOp`.
1346 // According to SPIR-V spec:
1347 // "Before version 1.4, Result Type must be a pointer, scalar, or vector.
1348 // Starting with version 1.4, Result Type can additionally be a composite type
1349 // other than a vector."
1350 bool isScalarOrVector =
1351 cast<spirv::SPIRVType>(trueBrStoreOp.getValue().getType())
1352 .isScalarOrVector();
1353
1354 // Check that each `spirv.Store` uses the same pointer, memory access
1355 // attributes and a valid type of the value.
1356 if ((trueBrStoreOp.getPtr() != falseBrStoreOp.getPtr()) ||
1357 !isSameAttrList(trueBrStoreOp, falseBrStoreOp) || !isScalarOrVector) {
1358 return failure();
1359 }
1360
1361 if ((trueBrBranchOp->getSuccessor(0) != mergeBlock) ||
1362 (falseBrBranchOp->getSuccessor(0) != mergeBlock)) {
1363 return failure();
1364 }
1365
1366 return success();
1367}
1368} // namespace
1369
1370void spirv::SelectionOp::getCanonicalizationPatterns(RewritePatternSet &results,
1371 MLIRContext *context) {
1372 results.add<ConvertSelectionOpToSelect>(context);
1373}
return success()
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static uint64_t zext(uint32_t arg)
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
if(!isCopyOut)
b getContext())
ArithmeticExtendedBinaryFold< spirv::ISubBorrowOp > ISubBorrowFold
static Attribute extractCompositeElement(Attribute composite, ArrayRef< unsigned > indices)
MulExtendedFold< spirv::UMulExtendedOp, false > UMulExtendedOpFold
MulExtendedFold< spirv::SMulExtendedOp, true > SMulExtendedOpFold
static std::optional< bool > getScalarOrSplatBoolAttr(Attribute attr)
Returns the boolean value under the hood if the given boolAttr is a scalar or splat vector bool const...
static bool isDivZeroOrOverflow(const APInt &a, const APInt &b)
ArithmeticExtendedBinaryFold< spirv::IAddCarryOp > IAddCarryFold
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
Operation & front()
Definition Block.h:163
iterator begin()
Definition Block.h:153
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:329
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:271
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class represents a single result from folding an operation.
This provides public APIs that all operations should have.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:711
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
bool empty()
Definition Region.h:60
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition Matchers.h:442
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition Matchers.h:478
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
Attribute constFoldUnaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(MulOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::UModOp umodOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::UMulExtendedOp op, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})