MLIR 23.0.0git
SPIRVCanonicalization.cpp
Go to the documentation of this file.
1//===- SPIRVCanonicalization.cpp - MLIR SPIR-V canonicalization patterns --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the folders and canonicalization patterns for SPIR-V ops.
10//
11//===----------------------------------------------------------------------===//
12
13#include <optional>
14#include <utility>
15
17
21#include "mlir/IR/Matchers.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/SmallVectorExtras.h"
25
26using namespace mlir;
27
28//===----------------------------------------------------------------------===//
29// Common utility functions
30//===----------------------------------------------------------------------===//
31
32/// Returns the boolean value under the hood if the given `boolAttr` is a scalar
33/// or splat vector bool constant.
34static std::optional<bool> getScalarOrSplatBoolAttr(Attribute attr) {
35 if (!attr)
36 return std::nullopt;
37
38 if (auto boolAttr = dyn_cast<BoolAttr>(attr))
39 return boolAttr.getValue();
40 if (auto splatAttr = dyn_cast<SplatElementsAttr>(attr))
41 if (splatAttr.getElementType().isInteger(1))
42 return splatAttr.getSplatValue<bool>();
43 return std::nullopt;
44}
45
46// Extracts an element from the given `composite` by following the given
47// `indices`. Returns a null Attribute if error happens.
50 // Check that given composite is a constant.
51 if (!composite)
52 return {};
53 // Return composite itself if we reach the end of the index chain.
54 if (indices.empty())
55 return composite;
56
57 if (auto vector = dyn_cast<ElementsAttr>(composite)) {
58 assert(indices.size() == 1 && "must have exactly one index for a vector");
59 return vector.getValues<Attribute>()[indices[0]];
60 }
61
62 if (auto array = dyn_cast<ArrayAttr>(composite)) {
63 assert(!indices.empty() && "must have at least one index for an array");
64 return extractCompositeElement(array.getValue()[indices[0]],
65 indices.drop_front());
66 }
67
68 return {};
69}
70
71static bool isDivZeroOrOverflow(const APInt &a, const APInt &b) {
72 bool div0 = b.isZero();
73 bool overflow = a.isMinSignedValue() && b.isAllOnes();
74
75 return div0 || overflow;
76}
77
78//===----------------------------------------------------------------------===//
79// TableGen'erated canonicalizers
80//===----------------------------------------------------------------------===//
81
82namespace {
83#include "SPIRVCanonicalization.inc"
84} // namespace
85
86//===----------------------------------------------------------------------===//
87// spirv.AccessChainOp
88//===----------------------------------------------------------------------===//
89
90namespace {
91
92/// Combines chained `spirv::AccessChainOp` operations into one
93/// `spirv::AccessChainOp` operation.
94struct CombineChainedAccessChain final
95 : OpRewritePattern<spirv::AccessChainOp> {
96 using Base::Base;
97
98 LogicalResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
99 PatternRewriter &rewriter) const override {
100 auto parentAccessChainOp =
101 accessChainOp.getBasePtr().getDefiningOp<spirv::AccessChainOp>();
102
103 if (!parentAccessChainOp) {
104 return failure();
105 }
106
107 // Combine indices.
108 SmallVector<Value, 4> indices(parentAccessChainOp.getIndices());
109 llvm::append_range(indices, accessChainOp.getIndices());
110
111 rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
112 accessChainOp, parentAccessChainOp.getBasePtr(), indices);
113
114 return success();
115 }
116};
117} // namespace
118
119void spirv::AccessChainOp::getCanonicalizationPatterns(
120 RewritePatternSet &results, MLIRContext *context) {
121 results.add<CombineChainedAccessChain>(context);
122}
123
124//===----------------------------------------------------------------------===//
125// spirv.IAddCarry / spirv.ISubBorrow
126//===----------------------------------------------------------------------===//
127
128template <typename Op>
131
132 static constexpr bool IsSub = std::is_same_v<Op, spirv::ISubBorrowOp>;
133
134 LogicalResult matchAndRewrite(Op op,
135 PatternRewriter &rewriter) const override {
136 Value lhs = op.getOperand1();
137 Value rhs = op.getOperand2();
138
139 // iaddcarry (x, 0) = <0, x>
140 // isubborrow (x, 0) = <x, 0>
141 if (matchPattern(rhs, m_Zero())) {
142 std::array<Value, 2> constituents =
143 IsSub ? std::array{lhs, rhs} : std::array{rhs, lhs};
144 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
145 constituents);
146 return success();
147 }
148
149 Attribute lhsAttr;
150 Attribute rhsAttr;
151 if (!matchPattern(lhs, m_Constant(&lhsAttr)) ||
152 !matchPattern(rhs, m_Constant(&rhsAttr)))
153 return failure();
154
155 auto lowBits = constFoldBinaryOp<IntegerAttr>(
156 {lhsAttr, rhsAttr},
157 [](const APInt &a, const APInt &b) { return IsSub ? a - b : a + b; });
158 if (!lowBits)
159 return failure();
160
161 auto wrapBit = constFoldBinaryOp<IntegerAttr>(
162 {lhsAttr, rhsAttr}, [](const APInt &a, const APInt &b) {
163 bool wrapped = IsSub ? a.ult(b) : (a + b).ult(a);
164 return APInt(a.getBitWidth(), wrapped ? 1 : 0);
165 });
166 if (!wrapBit)
167 return failure();
168
169 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
170 op, op.getType(), rewriter.getArrayAttr({lowBits, wrapBit}));
171 return success();
172 }
173};
174
176void spirv::IAddCarryOp::getCanonicalizationPatterns(
177 RewritePatternSet &patterns, MLIRContext *context) {
178 patterns.add<IAddCarryFold>(context);
179}
180
182void spirv::ISubBorrowOp::getCanonicalizationPatterns(
183 RewritePatternSet &patterns, MLIRContext *context) {
184 patterns.add<ISubBorrowFold>(context);
185}
186
187//===----------------------------------------------------------------------===//
188// spirv.[S|U]MulExtended
189//===----------------------------------------------------------------------===//
190
191template <typename MulOp, bool IsSigned>
192struct MulExtendedFold final : OpRewritePattern<MulOp> {
194
195 LogicalResult matchAndRewrite(MulOp op,
196 PatternRewriter &rewriter) const override {
197 Location loc = op.getLoc();
198 Value lhs = op.getOperand1();
199 Value rhs = op.getOperand2();
200 Type constituentType = lhs.getType();
201
202 // [su]mulextended (x, 0) = <0, 0>
203 if (matchPattern(rhs, m_Zero())) {
204 Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
205 Value constituents[2] = {zero, zero};
206 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
207 constituents);
208 return success();
209 }
210
211 // According to the SPIR-V spec:
212 //
213 // Result Type must be from OpTypeStruct. The struct must have two
214 // members...
215 //
216 // Member 0 of the result gets the low-order bits of the multiplication.
217 //
218 // Member 1 of the result gets the high-order bits of the multiplication.
219 Attribute lhsAttr;
220 Attribute rhsAttr;
221 if (!matchPattern(lhs, m_Constant(&lhsAttr)) ||
222 !matchPattern(rhs, m_Constant(&rhsAttr)))
223 return failure();
224
225 auto lowBits = constFoldBinaryOp<IntegerAttr>(
226 {lhsAttr, rhsAttr},
227 [](const APInt &a, const APInt &b) { return a * b; });
228
229 if (!lowBits)
230 return failure();
231
232 auto highBits = constFoldBinaryOp<IntegerAttr>(
233 {lhsAttr, rhsAttr}, [](const APInt &a, const APInt &b) {
234 if (IsSigned) {
235 return llvm::APIntOps::mulhs(a, b);
236 }
237 return llvm::APIntOps::mulhu(a, b);
238 });
239
240 if (!highBits)
241 return failure();
242
243 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
244 op, op.getType(), rewriter.getArrayAttr({lowBits, highBits}));
245 return success();
246 }
247};
248
250void spirv::SMulExtendedOp::getCanonicalizationPatterns(
251 RewritePatternSet &patterns, MLIRContext *context) {
252 patterns.add<SMulExtendedOpFold>(context);
253}
254
255struct UMulExtendedOpXOne final : OpRewritePattern<spirv::UMulExtendedOp> {
256 using Base::Base;
257
258 LogicalResult matchAndRewrite(spirv::UMulExtendedOp op,
259 PatternRewriter &rewriter) const override {
260 Location loc = op.getLoc();
261 Value lhs = op.getOperand1();
262 Value rhs = op.getOperand2();
263 Type constituentType = lhs.getType();
264
265 // umulextended (x, 1) = <x, 0>
266 if (matchPattern(rhs, m_One())) {
267 Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
268 Value constituents[2] = {lhs, zero};
269 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
270 constituents);
271 return success();
272 }
273
274 return failure();
275 }
276};
277
279void spirv::UMulExtendedOp::getCanonicalizationPatterns(
280 RewritePatternSet &patterns, MLIRContext *context) {
281 patterns.add<UMulExtendedOpFold, UMulExtendedOpXOne>(context);
282}
283
284//===----------------------------------------------------------------------===//
285// spirv.UMod
286//===----------------------------------------------------------------------===//
287
288// Input:
289// %0 = spirv.UMod %arg0, %const32 : i32
290// %1 = spirv.UMod %0, %const4 : i32
291// Output:
292// %0 = spirv.UMod %arg0, %const32 : i32
293// %1 = spirv.UMod %arg0, %const4 : i32
294
295// The transformation is only applied if one divisor is a multiple of the other.
296
297struct UModSimplification final : OpRewritePattern<spirv::UModOp> {
298 using Base::Base;
299
300 LogicalResult matchAndRewrite(spirv::UModOp umodOp,
301 PatternRewriter &rewriter) const override {
302 auto prevUMod = umodOp.getOperand(0).getDefiningOp<spirv::UModOp>();
303 if (!prevUMod)
304 return failure();
305
306 TypedAttr prevValue;
307 TypedAttr currValue;
308 if (!matchPattern(prevUMod.getOperand(1), m_Constant(&prevValue)) ||
309 !matchPattern(umodOp.getOperand(1), m_Constant(&currValue)))
310 return failure();
311
312 // Ensure that previous divisor is a multiple of the current divisor. If
313 // not, fail the transformation.
314 bool isApplicable = false;
315 if (auto prevInt = dyn_cast<IntegerAttr>(prevValue)) {
316 auto currInt = cast<IntegerAttr>(currValue);
317 if (currInt.getValue().isZero())
318 return failure();
319 isApplicable = prevInt.getValue().urem(currInt.getValue()) == 0;
320 } else if (auto prevVec = dyn_cast<DenseElementsAttr>(prevValue)) {
321 auto currVec = cast<DenseElementsAttr>(currValue);
322 if (llvm::any_of(currVec.getValues<APInt>(),
323 [](const APInt &curr) { return curr.isZero(); }))
324 return failure();
325 isApplicable = llvm::all_of(llvm::zip_equal(prevVec.getValues<APInt>(),
326 currVec.getValues<APInt>()),
327 [](const auto &pair) {
328 auto &[prev, curr] = pair;
329 return prev.urem(curr) == 0;
330 });
331 }
332
333 if (!isApplicable)
334 return failure();
335
336 // The transformation is safe. Replace the existing UMod operation with a
337 // new UMod operation, using the original dividend and the current divisor.
338 rewriter.replaceOpWithNewOp<spirv::UModOp>(
339 umodOp, umodOp.getType(), prevUMod.getOperand(0), umodOp.getOperand(1));
340
341 return success();
342 }
343};
344
345void spirv::UModOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
346 MLIRContext *context) {
347 patterns.add<UModSimplification>(context);
348}
349
350//===----------------------------------------------------------------------===//
351// spirv.BitcastOp
352//===----------------------------------------------------------------------===//
353
354OpFoldResult spirv::BitcastOp::fold(FoldAdaptor /*adaptor*/) {
355 Value curInput = getOperand();
356 if (getType() == curInput.getType())
357 return curInput;
358
359 // Look through nested bitcasts.
360 if (auto prevCast = curInput.getDefiningOp<spirv::BitcastOp>()) {
361 Value prevInput = prevCast.getOperand();
362 if (prevInput.getType() == getType())
363 return prevInput;
364
365 getOperandMutable().assign(prevInput);
366 return getResult();
367 }
368
369 // TODO(kuhar): Consider constant-folding the operand attribute.
370 return {};
371}
372
373//===----------------------------------------------------------------------===//
374// spirv.CompositeExtractOp
375//===----------------------------------------------------------------------===//
376
377OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) {
378 Value compositeOp = getComposite();
379
380 while (auto insertOp =
381 compositeOp.getDefiningOp<spirv::CompositeInsertOp>()) {
382 if (getIndices() == insertOp.getIndices())
383 return insertOp.getObject();
384 compositeOp = insertOp.getComposite();
385 }
386
387 if (auto constructOp =
388 compositeOp.getDefiningOp<spirv::CompositeConstructOp>()) {
389 auto type = cast<spirv::CompositeType>(constructOp.getType());
390 if (getIndices().size() == 1 &&
391 constructOp.getConstituents().size() == type.getNumElements()) {
392 auto i = cast<IntegerAttr>(*getIndices().begin());
393 if (i.getValue().getSExtValue() <
394 static_cast<int64_t>(constructOp.getConstituents().size()))
395 return constructOp.getConstituents()[i.getValue().getSExtValue()];
396 }
397 }
398
399 auto indexVector = llvm::map_to_vector(getIndices(), [](Attribute attr) {
400 return static_cast<unsigned>(cast<IntegerAttr>(attr).getInt());
401 });
402 return extractCompositeElement(adaptor.getComposite(), indexVector);
403}
404
405//===----------------------------------------------------------------------===//
406// spirv.Constant
407//===----------------------------------------------------------------------===//
408
409OpFoldResult spirv::ConstantOp::fold(FoldAdaptor /*adaptor*/) {
410 return getValue();
411}
412
413//===----------------------------------------------------------------------===//
414// spirv.IAdd
415//===----------------------------------------------------------------------===//
416
417OpFoldResult spirv::IAddOp::fold(FoldAdaptor adaptor) {
418 // x + 0 = x
419 if (matchPattern(getOperand2(), m_Zero()))
420 return getOperand1();
421
422 // According to the SPIR-V spec:
423 //
424 // The resulting value will equal the low-order N bits of the correct result
425 // R, where N is the component width and R is computed with enough precision
426 // to avoid overflow and underflow.
428 adaptor.getOperands(),
429 [](APInt a, const APInt &b) { return std::move(a) + b; });
430}
431
432//===----------------------------------------------------------------------===//
433// spirv.IMul
434//===----------------------------------------------------------------------===//
435
436OpFoldResult spirv::IMulOp::fold(FoldAdaptor adaptor) {
437 // x * 0 == 0
438 if (matchPattern(getOperand2(), m_Zero()))
439 return getOperand2();
440 // x * 1 = x
441 if (matchPattern(getOperand2(), m_One()))
442 return getOperand1();
443
444 // According to the SPIR-V spec:
445 //
446 // The resulting value will equal the low-order N bits of the correct result
447 // R, where N is the component width and R is computed with enough precision
448 // to avoid overflow and underflow.
450 adaptor.getOperands(),
451 [](const APInt &a, const APInt &b) { return a * b; });
452}
453
454//===----------------------------------------------------------------------===//
455// spirv.ISub
456//===----------------------------------------------------------------------===//
457
458OpFoldResult spirv::ISubOp::fold(FoldAdaptor adaptor) {
459 // x - x = 0
460 if (getOperand1() == getOperand2())
462
463 // According to the SPIR-V spec:
464 //
465 // The resulting value will equal the low-order N bits of the correct result
466 // R, where N is the component width and R is computed with enough precision
467 // to avoid overflow and underflow.
469 adaptor.getOperands(),
470 [](APInt a, const APInt &b) { return std::move(a) - b; });
471}
472
473//===----------------------------------------------------------------------===//
474// spirv.SDiv
475//===----------------------------------------------------------------------===//
476
477OpFoldResult spirv::SDivOp::fold(FoldAdaptor adaptor) {
478 // sdiv (x, 1) = x
479 if (matchPattern(getOperand2(), m_One()))
480 return getOperand1();
481
482 // According to the SPIR-V spec:
483 //
484 // Signed-integer division of Operand 1 divided by Operand 2.
485 // Results are computed per component. Behavior is undefined if Operand 2 is
486 // 0. Behavior is undefined if Operand 2 is -1 and Operand 1 is the minimum
487 // representable value for the operands' type, causing signed overflow.
488 //
489 // So don't fold during undefined behavior.
490 bool div0OrOverflow = false;
492 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
493 if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
494 div0OrOverflow = true;
495 return a;
496 }
497 return a.sdiv(b);
498 });
499 return div0OrOverflow ? Attribute() : res;
500}
501
502//===----------------------------------------------------------------------===//
503// spirv.SMod
504//===----------------------------------------------------------------------===//
505
506OpFoldResult spirv::SModOp::fold(FoldAdaptor adaptor) {
507 // smod (x, 1) = 0
508 if (matchPattern(getOperand2(), m_One()))
510
511 // According to SPIR-V spec:
512 //
513 // Signed remainder operation for the remainder whose sign matches the sign
514 // of Operand 2. Behavior is undefined if Operand 2 is 0. Behavior is
515 // undefined if Operand 2 is -1 and Operand 1 is the minimum representable
516 // value for the operands' type, causing signed overflow. Otherwise, the
517 // result is the remainder r of Operand 1 divided by Operand 2 where if
518 // r ≠ 0, the sign of r is the same as the sign of Operand 2.
519 //
520 // So don't fold during undefined behavior
521 bool div0OrOverflow = false;
523 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
524 if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
525 div0OrOverflow = true;
526 return a;
527 }
528 APInt c = a.abs().urem(b.abs());
529 if (c.isZero())
530 return c;
531 if (b.isNegative()) {
532 APInt zero = APInt::getZero(c.getBitWidth());
533 return a.isNegative() ? (zero - c) : (b + c);
534 }
535 return a.isNegative() ? (b - c) : c;
536 });
537 return div0OrOverflow ? Attribute() : res;
538}
539
540//===----------------------------------------------------------------------===//
541// spirv.SRem
542//===----------------------------------------------------------------------===//
543
544OpFoldResult spirv::SRemOp::fold(FoldAdaptor adaptor) {
545 // x % 1 = 0
546 if (matchPattern(getOperand2(), m_One()))
548
549 // According to SPIR-V spec:
550 //
551 // Signed remainder operation for the remainder whose sign matches the sign
552 // of Operand 1. Behavior is undefined if Operand 2 is 0. Behavior is
553 // undefined if Operand 2 is -1 and Operand 1 is the minimum representable
554 // value for the operands' type, causing signed overflow. Otherwise, the
555 // result is the remainder r of Operand 1 divided by Operand 2 where if
556 // r ≠ 0, the sign of r is the same as the sign of Operand 1.
557
558 // Don't fold if it would do undefined behavior.
559 bool div0OrOverflow = false;
561 adaptor.getOperands(), [&](APInt a, const APInt &b) {
562 if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
563 div0OrOverflow = true;
564 return a;
565 }
566 return a.srem(b);
567 });
568 return div0OrOverflow ? Attribute() : res;
569}
570
571//===----------------------------------------------------------------------===//
572// spirv.UDiv
573//===----------------------------------------------------------------------===//
574
575OpFoldResult spirv::UDivOp::fold(FoldAdaptor adaptor) {
576 // udiv (x, 1) = x
577 if (matchPattern(getOperand2(), m_One()))
578 return getOperand1();
579
580 // According to the SPIR-V spec:
581 //
582 // Unsigned-integer division of Operand 1 divided by Operand 2. Behavior is
583 // undefined if Operand 2 is 0.
584 //
585 // So don't fold during undefined behavior.
586 bool div0 = false;
588 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
589 if (div0 || b.isZero()) {
590 div0 = true;
591 return a;
592 }
593 return a.udiv(b);
594 });
595 return div0 ? Attribute() : res;
596}
597
598//===----------------------------------------------------------------------===//
599// spirv.UMod
600//===----------------------------------------------------------------------===//
601
602OpFoldResult spirv::UModOp::fold(FoldAdaptor adaptor) {
603 // umod (x, 1) = 0
604 if (matchPattern(getOperand2(), m_One()))
606
607 // According to the SPIR-V spec:
608 //
609 // Unsigned modulo operation of Operand 1 modulo Operand 2. Behavior is
610 // undefined if Operand 2 is 0.
611 //
612 // So don't fold during undefined behavior.
613 bool div0 = false;
615 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
616 if (div0 || b.isZero()) {
617 div0 = true;
618 return a;
619 }
620 return a.urem(b);
621 });
622 return div0 ? Attribute() : res;
623}
624
625//===----------------------------------------------------------------------===//
626// spirv.SNegate
627//===----------------------------------------------------------------------===//
628
629OpFoldResult spirv::SNegateOp::fold(FoldAdaptor adaptor) {
630 // -(-x) = 0 - (0 - x) = x
631 auto op = getOperand();
632 if (auto negateOp = op.getDefiningOp<spirv::SNegateOp>())
633 return negateOp->getOperand(0);
634
635 // According to the SPIR-V spec:
636 //
637 // Signed-integer subtract of Operand from zero.
639 adaptor.getOperands(), [](const APInt &a) {
640 APInt zero = APInt::getZero(a.getBitWidth());
641 return zero - a;
642 });
643}
644
645//===----------------------------------------------------------------------===//
646// spirv.NotOp
647//===----------------------------------------------------------------------===//
648
649OpFoldResult spirv::NotOp::fold(spirv::NotOp::FoldAdaptor adaptor) {
650 // !(!x) = x
651 auto op = getOperand();
652 if (auto notOp = op.getDefiningOp<spirv::NotOp>())
653 return notOp->getOperand(0);
654
655 // According to the SPIR-V spec:
656 //
657 // Complement the bits of Operand.
658 return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(), [&](APInt a) {
659 a.flipAllBits();
660 return a;
661 });
662}
663
664//===----------------------------------------------------------------------===//
665// spirv.LogicalAnd
666//===----------------------------------------------------------------------===//
667
668OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) {
669 if (std::optional<bool> rhs =
670 getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
671 // x && true = x
672 if (*rhs)
673 return getOperand1();
674
675 // x && false = false
676 if (!*rhs)
677 return adaptor.getOperand2();
678 }
679
680 return Attribute();
681}
682
683//===----------------------------------------------------------------------===//
684// spirv.LogicalEqualOp
685//===----------------------------------------------------------------------===//
686
688spirv::LogicalEqualOp::fold(spirv::LogicalEqualOp::FoldAdaptor adaptor) {
689 // x == x -> true
690 if (getOperand1() == getOperand2()) {
691 auto trueAttr = BoolAttr::get(getContext(), true);
692 if (isa<IntegerType>(getType()))
693 return trueAttr;
694 if (auto vecTy = dyn_cast<VectorType>(getType()))
695 return SplatElementsAttr::get(vecTy, trueAttr);
696 }
697
699 adaptor.getOperands(), [](const APInt &a, const APInt &b) {
700 return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);
701 });
702}
703
704//===----------------------------------------------------------------------===//
705// spirv.LogicalNotEqualOp
706//===----------------------------------------------------------------------===//
707
708OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) {
709 if (std::optional<bool> rhs =
710 getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
711 // x != false -> x
712 if (!rhs.value())
713 return getOperand1();
714 }
715
716 // x == x -> false
717 if (getOperand1() == getOperand2()) {
718 auto falseAttr = BoolAttr::get(getContext(), false);
719 if (isa<IntegerType>(getType()))
720 return falseAttr;
721 if (auto vecTy = dyn_cast<VectorType>(getType()))
722 return SplatElementsAttr::get(vecTy, falseAttr);
723 }
724
726 adaptor.getOperands(), [](const APInt &a, const APInt &b) {
727 return a == b ? APInt::getZero(1) : APInt::getAllOnes(1);
728 });
729}
730
731//===----------------------------------------------------------------------===//
732// spirv.LogicalNot
733//===----------------------------------------------------------------------===//
734
735OpFoldResult spirv::LogicalNotOp::fold(FoldAdaptor adaptor) {
736 // !(!x) = x
737 auto op = getOperand();
738 if (auto notOp = op.getDefiningOp<spirv::LogicalNotOp>())
739 return notOp->getOperand(0);
740
741 // According to the SPIR-V spec:
742 //
743 // Complement the bits of Operand.
744 return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(),
745 [](const APInt &a) {
746 APInt zero = APInt::getZero(1);
747 return a == 1 ? zero : (zero + 1);
748 });
749}
750
751void spirv::LogicalNotOp::getCanonicalizationPatterns(
752 RewritePatternSet &results, MLIRContext *context) {
753 results
754 .add<ConvertLogicalNotOfIEqual, ConvertLogicalNotOfINotEqual,
755 ConvertLogicalNotOfLogicalEqual, ConvertLogicalNotOfLogicalNotEqual>(
756 context);
757}
758
759//===----------------------------------------------------------------------===//
760// spirv.LogicalOr
761//===----------------------------------------------------------------------===//
762
763OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
764 if (auto rhs = getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
765 if (*rhs) {
766 // x || true = true
767 return adaptor.getOperand2();
768 }
769
770 if (!*rhs) {
771 // x || false = x
772 return getOperand1();
773 }
774 }
775
776 return Attribute();
777}
778
779//===----------------------------------------------------------------------===//
780// spirv.SelectOp
781//===----------------------------------------------------------------------===//
782
783OpFoldResult spirv::SelectOp::fold(FoldAdaptor adaptor) {
784 // spirv.Select _ x x -> x
785 Value trueVals = getTrueValue();
786 Value falseVals = getFalseValue();
787 if (trueVals == falseVals)
788 return trueVals;
789
790 ArrayRef<Attribute> operands = adaptor.getOperands();
791
792 // spirv.Select true x y -> x
793 // spirv.Select false x y -> y
794 if (auto boolAttr = getScalarOrSplatBoolAttr(operands[0]))
795 return *boolAttr ? trueVals : falseVals;
796
797 // Check that all the operands are constant
798 if (!operands[0] || !operands[1] || !operands[2])
799 return Attribute();
800
801 // Note: getScalarOrSplatBoolAttr will always return a boolAttr if we are in
802 // the scalar case. Hence, we are only required to consider the case of
803 // DenseElementsAttr in foldSelectOp.
804 auto condAttrs = dyn_cast<DenseElementsAttr>(operands[0]);
805 auto trueAttrs = dyn_cast<DenseElementsAttr>(operands[1]);
806 auto falseAttrs = dyn_cast<DenseElementsAttr>(operands[2]);
807 if (!condAttrs || !trueAttrs || !falseAttrs)
808 return Attribute();
809
810 auto elementResults = llvm::to_vector<4>(trueAttrs.getValues<Attribute>());
811 auto iters = llvm::zip_equal(elementResults, condAttrs.getValues<BoolAttr>(),
812 falseAttrs.getValues<Attribute>());
813 for (auto [result, cond, falseRes] : iters) {
814 if (!cond.getValue())
815 result = falseRes;
816 }
817
818 auto resultType = trueAttrs.getType();
819 return DenseElementsAttr::get(cast<ShapedType>(resultType), elementResults);
820}
821
822//===----------------------------------------------------------------------===//
823// spirv.IEqualOp
824//===----------------------------------------------------------------------===//
825
826OpFoldResult spirv::IEqualOp::fold(spirv::IEqualOp::FoldAdaptor adaptor) {
827 // x == x -> true
828 if (getOperand1() == getOperand2()) {
829 auto trueAttr = BoolAttr::get(getContext(), true);
830 if (isa<IntegerType>(getType()))
831 return trueAttr;
832 if (auto vecTy = dyn_cast<VectorType>(getType()))
833 return SplatElementsAttr::get(vecTy, trueAttr);
834 }
835
837 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
838 return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);
839 });
840}
841
842//===----------------------------------------------------------------------===//
843// spirv.INotEqualOp
844//===----------------------------------------------------------------------===//
845
846OpFoldResult spirv::INotEqualOp::fold(spirv::INotEqualOp::FoldAdaptor adaptor) {
847 // x == x -> false
848 if (getOperand1() == getOperand2()) {
849 auto falseAttr = BoolAttr::get(getContext(), false);
850 if (isa<IntegerType>(getType()))
851 return falseAttr;
852 if (auto vecTy = dyn_cast<VectorType>(getType()))
853 return SplatElementsAttr::get(vecTy, falseAttr);
854 }
855
857 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
858 return a == b ? APInt::getZero(1) : APInt::getAllOnes(1);
859 });
860}
861
862//===----------------------------------------------------------------------===//
863// spirv.SGreaterThan
864//===----------------------------------------------------------------------===//
865
867spirv::SGreaterThanOp::fold(spirv::SGreaterThanOp::FoldAdaptor adaptor) {
868 // x == x -> false
869 if (getOperand1() == getOperand2()) {
870 auto falseAttr = BoolAttr::get(getContext(), false);
871 if (isa<IntegerType>(getType()))
872 return falseAttr;
873 if (auto vecTy = dyn_cast<VectorType>(getType()))
874 return SplatElementsAttr::get(vecTy, falseAttr);
875 }
876
878 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
879 return a.sgt(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
880 });
881}
882
883//===----------------------------------------------------------------------===//
884// spirv.SGreaterThanEqual
885//===----------------------------------------------------------------------===//
886
887OpFoldResult spirv::SGreaterThanEqualOp::fold(
888 spirv::SGreaterThanEqualOp::FoldAdaptor adaptor) {
889 // x == x -> true
890 if (getOperand1() == getOperand2()) {
891 auto trueAttr = BoolAttr::get(getContext(), true);
892 if (isa<IntegerType>(getType()))
893 return trueAttr;
894 if (auto vecTy = dyn_cast<VectorType>(getType()))
895 return SplatElementsAttr::get(vecTy, trueAttr);
896 }
897
899 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
900 return a.sge(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
901 });
902}
903
904//===----------------------------------------------------------------------===//
905// spirv.UGreaterThan
906//===----------------------------------------------------------------------===//
907
909spirv::UGreaterThanOp::fold(spirv::UGreaterThanOp::FoldAdaptor adaptor) {
910 // x == x -> false
911 if (getOperand1() == getOperand2()) {
912 auto falseAttr = BoolAttr::get(getContext(), false);
913 if (isa<IntegerType>(getType()))
914 return falseAttr;
915 if (auto vecTy = dyn_cast<VectorType>(getType()))
916 return SplatElementsAttr::get(vecTy, falseAttr);
917 }
918
920 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
921 return a.ugt(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
922 });
923}
924
925//===----------------------------------------------------------------------===//
926// spirv.UGreaterThanEqual
927//===----------------------------------------------------------------------===//
928
929OpFoldResult spirv::UGreaterThanEqualOp::fold(
930 spirv::UGreaterThanEqualOp::FoldAdaptor adaptor) {
931 // x == x -> true
932 if (getOperand1() == getOperand2()) {
933 auto trueAttr = BoolAttr::get(getContext(), true);
934 if (isa<IntegerType>(getType()))
935 return trueAttr;
936 if (auto vecTy = dyn_cast<VectorType>(getType()))
937 return SplatElementsAttr::get(vecTy, trueAttr);
938 }
939
941 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
942 return a.uge(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
943 });
944}
945
946//===----------------------------------------------------------------------===//
947// spirv.SLessThan
948//===----------------------------------------------------------------------===//
949
950OpFoldResult spirv::SLessThanOp::fold(spirv::SLessThanOp::FoldAdaptor adaptor) {
951 // x == x -> false
952 if (getOperand1() == getOperand2()) {
953 auto falseAttr = BoolAttr::get(getContext(), false);
954 if (isa<IntegerType>(getType()))
955 return falseAttr;
956 if (auto vecTy = dyn_cast<VectorType>(getType()))
957 return SplatElementsAttr::get(vecTy, falseAttr);
958 }
959
961 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
962 return a.slt(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
963 });
964}
965
966//===----------------------------------------------------------------------===//
967// spirv.SLessThanEqual
968//===----------------------------------------------------------------------===//
969
971spirv::SLessThanEqualOp::fold(spirv::SLessThanEqualOp::FoldAdaptor adaptor) {
972 // x == x -> true
973 if (getOperand1() == getOperand2()) {
974 auto trueAttr = BoolAttr::get(getContext(), true);
975 if (isa<IntegerType>(getType()))
976 return trueAttr;
977 if (auto vecTy = dyn_cast<VectorType>(getType()))
978 return SplatElementsAttr::get(vecTy, trueAttr);
979 }
980
982 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
983 return a.sle(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
984 });
985}
986
987//===----------------------------------------------------------------------===//
988// spirv.ULessThan
989//===----------------------------------------------------------------------===//
990
991OpFoldResult spirv::ULessThanOp::fold(spirv::ULessThanOp::FoldAdaptor adaptor) {
992 // x == x -> false
993 if (getOperand1() == getOperand2()) {
994 auto falseAttr = BoolAttr::get(getContext(), false);
995 if (isa<IntegerType>(getType()))
996 return falseAttr;
997 if (auto vecTy = dyn_cast<VectorType>(getType()))
998 return SplatElementsAttr::get(vecTy, falseAttr);
999 }
1000
1002 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
1003 return a.ult(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
1004 });
1005}
1006
1007//===----------------------------------------------------------------------===//
1008// spirv.ULessThanEqual
1009//===----------------------------------------------------------------------===//
1010
1012spirv::ULessThanEqualOp::fold(spirv::ULessThanEqualOp::FoldAdaptor adaptor) {
1013 // x == x -> true
1014 if (getOperand1() == getOperand2()) {
1015 auto trueAttr = BoolAttr::get(getContext(), true);
1016 if (isa<IntegerType>(getType()))
1017 return trueAttr;
1018 if (auto vecTy = dyn_cast<VectorType>(getType()))
1019 return SplatElementsAttr::get(vecTy, trueAttr);
1020 }
1021
1023 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
1024 return a.ule(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
1025 });
1026}
1027
1028//===----------------------------------------------------------------------===//
1029// spirv.ShiftLeftLogical
1030//===----------------------------------------------------------------------===//
1031
1032OpFoldResult spirv::ShiftLeftLogicalOp::fold(
1033 spirv::ShiftLeftLogicalOp::FoldAdaptor adaptor) {
1034 // x << 0 -> x
1035 if (matchPattern(adaptor.getOperand2(), m_Zero())) {
1036 return getOperand1();
1037 }
1038
1039 // Unfortunately due to below undefined behaviour can't fold 0 for Base.
1040
1041 // Results are computed per component, and within each component, per bit...
1042 //
1043 // The result is undefined if Shift is greater than or equal to the bit width
1044 // of the components of Base.
1045 //
1046 // So we can use the APInt << method, but don't fold if undefined behaviour.
1047 bool shiftToLarge = false;
1049 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
1050 if (shiftToLarge || b.uge(a.getBitWidth())) {
1051 shiftToLarge = true;
1052 return a;
1053 }
1054 return a << b;
1055 });
1056 return shiftToLarge ? Attribute() : res;
1057}
1058
1059//===----------------------------------------------------------------------===//
1060// spirv.ShiftRightArithmetic
1061//===----------------------------------------------------------------------===//
1062
1063OpFoldResult spirv::ShiftRightArithmeticOp::fold(
1064 spirv::ShiftRightArithmeticOp::FoldAdaptor adaptor) {
1065 // x >> 0 -> x
1066 if (matchPattern(adaptor.getOperand2(), m_Zero())) {
1067 return getOperand1();
1068 }
1069
1070 // Unfortunately due to below undefined behaviour can't fold 0, -1 for Base.
1071
1072 // Results are computed per component, and within each component, per bit...
1073 //
1074 // The result is undefined if Shift is greater than or equal to the bit width
1075 // of the components of Base.
1076 //
1077 // So we can use the APInt ashr method, but don't fold if undefined behaviour.
1078 bool shiftToLarge = false;
1080 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
1081 if (shiftToLarge || b.uge(a.getBitWidth())) {
1082 shiftToLarge = true;
1083 return a;
1084 }
1085 return a.ashr(b);
1086 });
1087 return shiftToLarge ? Attribute() : res;
1088}
1089
1090//===----------------------------------------------------------------------===//
1091// spirv.ShiftRightLogical
1092//===----------------------------------------------------------------------===//
1093
1094OpFoldResult spirv::ShiftRightLogicalOp::fold(
1095 spirv::ShiftRightLogicalOp::FoldAdaptor adaptor) {
1096 // x >> 0 -> x
1097 if (matchPattern(adaptor.getOperand2(), m_Zero())) {
1098 return getOperand1();
1099 }
1100
1101 // Unfortunately due to below undefined behaviour can't fold 0 for Base.
1102
1103 // Results are computed per component, and within each component, per bit...
1104 //
1105 // The result is undefined if Shift is greater than or equal to the bit width
1106 // of the components of Base.
1107 //
1108 // So we can use the APInt lshr method, but don't fold if undefined behaviour.
1109 bool shiftToLarge = false;
1111 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
1112 if (shiftToLarge || b.uge(a.getBitWidth())) {
1113 shiftToLarge = true;
1114 return a;
1115 }
1116 return a.lshr(b);
1117 });
1118 return shiftToLarge ? Attribute() : res;
1119}
1120
1121//===----------------------------------------------------------------------===//
1122// spirv.BitwiseAndOp
1123//===----------------------------------------------------------------------===//
1124
1126spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) {
1127 // x & x -> x
1128 if (getOperand1() == getOperand2()) {
1129 return getOperand1();
1130 }
1131
1132 APInt rhsMask;
1133 if (matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask))) {
1134 // x & 0 -> 0
1135 if (rhsMask.isZero())
1136 return getOperand2();
1137
1138 // x & <all ones> -> x
1139 if (rhsMask.isAllOnes())
1140 return getOperand1();
1141
1142 // (UConvert x : iN to iK) & <mask with N low bits set> -> UConvert x
1143 if (auto zext = getOperand1().getDefiningOp<spirv::UConvertOp>()) {
1144 int valueBits =
1146 if (rhsMask.zextOrTrunc(valueBits).isAllOnes())
1147 return getOperand1();
1148 }
1149 }
1150
1151 // According to the SPIR-V spec:
1152 //
1153 // Type is a scalar or vector of integer type.
1154 // Results are computed per component, and within each component, per bit.
1155 // So we can use the APInt & method.
1157 adaptor.getOperands(),
1158 [](const APInt &a, const APInt &b) { return a & b; });
1159}
1160
1161//===----------------------------------------------------------------------===//
1162// spirv.BitwiseOrOp
1163//===----------------------------------------------------------------------===//
1164
1165OpFoldResult spirv::BitwiseOrOp::fold(spirv::BitwiseOrOp::FoldAdaptor adaptor) {
1166 // x | x -> x
1167 if (getOperand1() == getOperand2()) {
1168 return getOperand1();
1169 }
1170
1171 APInt rhsMask;
1172 if (matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask))) {
1173 // x | 0 -> x
1174 if (rhsMask.isZero())
1175 return getOperand1();
1176
1177 // x | <all ones> -> <all ones>
1178 if (rhsMask.isAllOnes())
1179 return getOperand2();
1180 }
1181
1182 // According to the SPIR-V spec:
1183 //
1184 // Type is a scalar or vector of integer type.
1185 // Results are computed per component, and within each component, per bit.
1186 // So we can use the APInt | method.
1188 adaptor.getOperands(),
1189 [](const APInt &a, const APInt &b) { return a | b; });
1190}
1191
1192//===----------------------------------------------------------------------===//
1193// spirv.BitwiseXorOp
1194//===----------------------------------------------------------------------===//
1195
1197spirv::BitwiseXorOp::fold(spirv::BitwiseXorOp::FoldAdaptor adaptor) {
1198 // x ^ 0 -> x
1199 if (matchPattern(adaptor.getOperand2(), m_Zero())) {
1200 return getOperand1();
1201 }
1202
1203 // x ^ x -> 0
1204 if (getOperand1() == getOperand2())
1206
1207 // According to the SPIR-V spec:
1208 //
1209 // Type is a scalar or vector of integer type.
1210 // Results are computed per component, and within each component, per bit.
1211 // So we can use the APInt ^ method.
1213 adaptor.getOperands(),
1214 [](const APInt &a, const APInt &b) { return a ^ b; });
1215}
1216
1217//===----------------------------------------------------------------------===//
1218// spirv.mlir.selection
1219//===----------------------------------------------------------------------===//
1220
1221namespace {
1222// Blocks from the given `spirv.mlir.selection` operation must satisfy the
1223// following layout:
1224//
1225// +-----------------------------------------------+
1226// | header block |
1227// | spirv.BranchConditionalOp %cond, ^case0, ^case1 |
1228// +-----------------------------------------------+
1229// / \
1230// ...
1231//
1232//
1233// +------------------------+ +------------------------+
1234// | case #0 | | case #1 |
1235// | spirv.Store %ptr %value0 | | spirv.Store %ptr %value1 |
1236// | spirv.Branch ^merge | | spirv.Branch ^merge |
1237// +------------------------+ +------------------------+
1238//
1239//
1240// ...
1241// \ /
1242// v
1243// +-------------+
1244// | merge block |
1245// +-------------+
1246//
1247struct ConvertSelectionOpToSelect final : OpRewritePattern<spirv::SelectionOp> {
1248 using Base::Base;
1249
1250 LogicalResult matchAndRewrite(spirv::SelectionOp selectionOp,
1251 PatternRewriter &rewriter) const override {
1252 Operation *op = selectionOp.getOperation();
1253 Region &body = op->getRegion(0);
1254 // Verifier allows an empty region for `spirv.mlir.selection`.
1255 if (body.empty()) {
1256 return failure();
1257 }
1258
1259 // Check that region consists of 4 blocks:
1260 // header block, `true` block, `false` block and merge block.
1261 if (llvm::range_size(body) != 4) {
1262 return failure();
1263 }
1264
1265 Block *headerBlock = selectionOp.getHeaderBlock();
1266 if (!onlyContainsBranchConditionalOp(headerBlock)) {
1267 return failure();
1268 }
1269
1270 auto brConditionalOp =
1271 cast<spirv::BranchConditionalOp>(headerBlock->front());
1272
1273 Block *trueBlock = brConditionalOp.getSuccessor(0);
1274 Block *falseBlock = brConditionalOp.getSuccessor(1);
1275 Block *mergeBlock = selectionOp.getMergeBlock();
1276
1277 if (failed(canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)))
1278 return failure();
1279
1280 Value trueValue = getSrcValue(trueBlock);
1281 Value falseValue = getSrcValue(falseBlock);
1282 Value ptrValue = getDstPtr(trueBlock);
1283 auto storeOpAttributes =
1284 cast<spirv::StoreOp>(trueBlock->front())->getAttrs();
1285
1286 auto selectOp = spirv::SelectOp::create(
1287 rewriter, selectionOp.getLoc(), trueValue.getType(),
1288 brConditionalOp.getCondition(), trueValue, falseValue);
1289 spirv::StoreOp::create(rewriter, selectOp.getLoc(), ptrValue,
1290 selectOp.getResult(), storeOpAttributes);
1291
1292 // `spirv.mlir.selection` is not needed anymore.
1293 rewriter.eraseOp(op);
1294 return success();
1295 }
1296
1297private:
1298 // Checks that given blocks follow the following rules:
1299 // 1. Each conditional block consists of two operations, the first operation
1300 // is a `spirv.Store` and the last operation is a `spirv.Branch`.
1301 // 2. Each `spirv.Store` uses the same pointer and the same memory attributes.
1302 // 3. A control flow goes into the given merge block from the given
1303 // conditional blocks.
1304 LogicalResult canCanonicalizeSelection(Block *trueBlock, Block *falseBlock,
1305 Block *mergeBlock) const;
1306
1307 bool onlyContainsBranchConditionalOp(Block *block) const {
1308 return llvm::hasSingleElement(*block) &&
1309 isa<spirv::BranchConditionalOp>(block->front());
1310 }
1311
1312 bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs) const {
1313 return lhs->getDiscardableAttrDictionary() ==
1314 rhs->getDiscardableAttrDictionary() &&
1315 lhs.getProperties() == rhs.getProperties();
1316 }
1317
1318 // Returns a source value for the given block.
1319 Value getSrcValue(Block *block) const {
1320 auto storeOp = cast<spirv::StoreOp>(block->front());
1321 return storeOp.getValue();
1322 }
1323
1324 // Returns a destination value for the given block.
1325 Value getDstPtr(Block *block) const {
1326 auto storeOp = cast<spirv::StoreOp>(block->front());
1327 return storeOp.getPtr();
1328 }
1329};
1330
1331LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
1332 Block *trueBlock, Block *falseBlock, Block *mergeBlock) const {
1333 // Each block must consists of 2 operations.
1334 if (llvm::range_size(*trueBlock) != 2 || llvm::range_size(*falseBlock) != 2) {
1335 return failure();
1336 }
1337
1338 auto trueBrStoreOp = dyn_cast<spirv::StoreOp>(trueBlock->front());
1339 auto trueBrBranchOp =
1340 dyn_cast<spirv::BranchOp>(*std::next(trueBlock->begin()));
1341 auto falseBrStoreOp = dyn_cast<spirv::StoreOp>(falseBlock->front());
1342 auto falseBrBranchOp =
1343 dyn_cast<spirv::BranchOp>(*std::next(falseBlock->begin()));
1344
1345 if (!trueBrStoreOp || !trueBrBranchOp || !falseBrStoreOp ||
1346 !falseBrBranchOp) {
1347 return failure();
1348 }
1349
1350 // Checks that given type is valid for `spirv.SelectOp`.
1351 // According to SPIR-V spec:
1352 // "Before version 1.4, Result Type must be a pointer, scalar, or vector.
1353 // Starting with version 1.4, Result Type can additionally be a composite type
1354 // other than a vector."
1355 bool isScalarOrVector =
1356 cast<spirv::SPIRVType>(trueBrStoreOp.getValue().getType())
1357 .isScalarOrVector();
1358
1359 // Check that each `spirv.Store` uses the same pointer, memory access
1360 // attributes and a valid type of the value.
1361 if ((trueBrStoreOp.getPtr() != falseBrStoreOp.getPtr()) ||
1362 !isSameAttrList(trueBrStoreOp, falseBrStoreOp) || !isScalarOrVector) {
1363 return failure();
1364 }
1365
1366 if ((trueBrBranchOp->getSuccessor(0) != mergeBlock) ||
1367 (falseBrBranchOp->getSuccessor(0) != mergeBlock)) {
1368 return failure();
1369 }
1370
1371 return success();
1372}
1373} // namespace
1374
1375void spirv::SelectionOp::getCanonicalizationPatterns(RewritePatternSet &results,
1376 MLIRContext *context) {
1377 results.add<ConvertSelectionOpToSelect>(context);
1378}
return success()
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static uint64_t zext(uint32_t arg)
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
if(!isCopyOut)
b getContext())
ArithmeticExtendedBinaryFold< spirv::ISubBorrowOp > ISubBorrowFold
static Attribute extractCompositeElement(Attribute composite, ArrayRef< unsigned > indices)
MulExtendedFold< spirv::UMulExtendedOp, false > UMulExtendedOpFold
MulExtendedFold< spirv::SMulExtendedOp, true > SMulExtendedOpFold
static std::optional< bool > getScalarOrSplatBoolAttr(Attribute attr)
Returns the boolean value under the hood if the given boolAttr is a scalar or splat vector bool const...
static bool isDivZeroOrOverflow(const APInt &a, const APInt &b)
ArithmeticExtendedBinaryFold< spirv::IAddCarryOp > IAddCarryFold
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
Operation & front()
Definition Block.h:163
iterator begin()
Definition Block.h:153
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:329
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:271
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class represents a single result from folding an operation.
This provides public APIs that all operations should have.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:711
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
bool empty()
Definition Region.h:60
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition Matchers.h:442
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition Matchers.h:478
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
Attribute constFoldUnaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(MulOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::UModOp umodOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::UMulExtendedOp op, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})