MLIR 22.0.0git
SPIRVCanonicalization.cpp
Go to the documentation of this file.
1//===- SPIRVCanonicalization.cpp - MLIR SPIR-V canonicalization patterns --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the folders and canonicalization patterns for SPIR-V ops.
10//
11//===----------------------------------------------------------------------===//
12
13#include <optional>
14#include <utility>
15
17
21#include "mlir/IR/Matchers.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/SmallVectorExtras.h"
25
26using namespace mlir;
27
28//===----------------------------------------------------------------------===//
29// Common utility functions
30//===----------------------------------------------------------------------===//
31
32/// Returns the boolean value under the hood if the given `boolAttr` is a scalar
33/// or splat vector bool constant.
34static std::optional<bool> getScalarOrSplatBoolAttr(Attribute attr) {
35 if (!attr)
36 return std::nullopt;
37
38 if (auto boolAttr = llvm::dyn_cast<BoolAttr>(attr))
39 return boolAttr.getValue();
40 if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(attr))
41 if (splatAttr.getElementType().isInteger(1))
42 return splatAttr.getSplatValue<bool>();
43 return std::nullopt;
44}
45
46// Extracts an element from the given `composite` by following the given
47// `indices`. Returns a null Attribute if error happens.
50 // Check that given composite is a constant.
51 if (!composite)
52 return {};
53 // Return composite itself if we reach the end of the index chain.
54 if (indices.empty())
55 return composite;
56
57 if (auto vector = llvm::dyn_cast<ElementsAttr>(composite)) {
58 assert(indices.size() == 1 && "must have exactly one index for a vector");
59 return vector.getValues<Attribute>()[indices[0]];
60 }
61
62 if (auto array = llvm::dyn_cast<ArrayAttr>(composite)) {
63 assert(!indices.empty() && "must have at least one index for an array");
64 return extractCompositeElement(array.getValue()[indices[0]],
65 indices.drop_front());
66 }
67
68 return {};
69}
70
71static bool isDivZeroOrOverflow(const APInt &a, const APInt &b) {
72 bool div0 = b.isZero();
73 bool overflow = a.isMinSignedValue() && b.isAllOnes();
74
75 return div0 || overflow;
76}
77
78//===----------------------------------------------------------------------===//
79// TableGen'erated canonicalizers
80//===----------------------------------------------------------------------===//
81
82namespace {
83#include "SPIRVCanonicalization.inc"
84} // namespace
85
86//===----------------------------------------------------------------------===//
87// spirv.AccessChainOp
88//===----------------------------------------------------------------------===//
89
90namespace {
91
92/// Combines chained `spirv::AccessChainOp` operations into one
93/// `spirv::AccessChainOp` operation.
94struct CombineChainedAccessChain final
95 : OpRewritePattern<spirv::AccessChainOp> {
96 using Base::Base;
97
98 LogicalResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
99 PatternRewriter &rewriter) const override {
100 auto parentAccessChainOp =
101 accessChainOp.getBasePtr().getDefiningOp<spirv::AccessChainOp>();
102
103 if (!parentAccessChainOp) {
104 return failure();
105 }
106
107 // Combine indices.
108 SmallVector<Value, 4> indices(parentAccessChainOp.getIndices());
109 llvm::append_range(indices, accessChainOp.getIndices());
110
111 rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
112 accessChainOp, parentAccessChainOp.getBasePtr(), indices);
113
114 return success();
115 }
116};
117} // namespace
118
119void spirv::AccessChainOp::getCanonicalizationPatterns(
120 RewritePatternSet &results, MLIRContext *context) {
121 results.add<CombineChainedAccessChain>(context);
122}
123
124//===----------------------------------------------------------------------===//
125// spirv.IAddCarry
126//===----------------------------------------------------------------------===//
127
128// We are required to use CompositeConstructOp to create a constant struct as
129// they are not yet implemented as constant, hence we can not do so in a fold.
130struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
131 using Base::Base;
132
133 LogicalResult matchAndRewrite(spirv::IAddCarryOp op,
134 PatternRewriter &rewriter) const override {
135 Location loc = op.getLoc();
136 Value lhs = op.getOperand1();
137 Value rhs = op.getOperand2();
138 Type constituentType = lhs.getType();
139
140 // iaddcarry (x, 0) = <0, x>
141 if (matchPattern(rhs, m_Zero())) {
142 Value constituents[2] = {rhs, lhs};
143 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
144 constituents);
145 return success();
146 }
147
148 // According to the SPIR-V spec:
149 //
150 // Result Type must be from OpTypeStruct. The struct must have two
151 // members...
152 //
153 // Member 0 of the result gets the low-order bits (full component width) of
154 // the addition.
155 //
156 // Member 1 of the result gets the high-order (carry) bit of the result of
157 // the addition. That is, it gets the value 1 if the addition overflowed
158 // the component width, and 0 otherwise.
159 Attribute lhsAttr;
160 Attribute rhsAttr;
161 if (!matchPattern(lhs, m_Constant(&lhsAttr)) ||
162 !matchPattern(rhs, m_Constant(&rhsAttr)))
163 return failure();
164
166 {lhsAttr, rhsAttr},
167 [](const APInt &a, const APInt &b) { return a + b; });
168 if (!adds)
169 return failure();
170
171 auto carrys = constFoldBinaryOp<IntegerAttr>(
172 ArrayRef{adds, lhsAttr}, [](const APInt &a, const APInt &b) {
173 APInt zero = APInt::getZero(a.getBitWidth());
174 return a.ult(b) ? (zero + 1) : zero;
175 });
176
177 if (!carrys)
178 return failure();
179
180 Value addsVal =
181 spirv::ConstantOp::create(rewriter, loc, constituentType, adds);
182
183 Value carrysVal =
184 spirv::ConstantOp::create(rewriter, loc, constituentType, carrys);
185
186 // Create empty struct
187 Value undef = spirv::UndefOp::create(rewriter, loc, op.getType());
188 // Fill in adds at id 0
189 Value intermediate =
190 spirv::CompositeInsertOp::create(rewriter, loc, addsVal, undef, 0);
191 // Fill in carrys at id 1
192 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, carrysVal,
193 intermediate, 1);
194 return success();
195 }
196};
197
198void spirv::IAddCarryOp::getCanonicalizationPatterns(
200 patterns.add<IAddCarryFold>(context);
201}
202
203//===----------------------------------------------------------------------===//
204// spirv.[S|U]MulExtended
205//===----------------------------------------------------------------------===//
206
207// We are required to use CompositeConstructOp to create a constant struct as
208// they are not yet implemented as constant, hence we can not do so in a fold.
209template <typename MulOp, bool IsSigned>
210struct MulExtendedFold final : OpRewritePattern<MulOp> {
212
213 LogicalResult matchAndRewrite(MulOp op,
214 PatternRewriter &rewriter) const override {
215 Location loc = op.getLoc();
216 Value lhs = op.getOperand1();
217 Value rhs = op.getOperand2();
218 Type constituentType = lhs.getType();
219
220 // [su]mulextended (x, 0) = <0, 0>
221 if (matchPattern(rhs, m_Zero())) {
222 Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
223 Value constituents[2] = {zero, zero};
224 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
225 constituents);
226 return success();
227 }
228
229 // According to the SPIR-V spec:
230 //
231 // Result Type must be from OpTypeStruct. The struct must have two
232 // members...
233 //
234 // Member 0 of the result gets the low-order bits of the multiplication.
235 //
236 // Member 1 of the result gets the high-order bits of the multiplication.
237 Attribute lhsAttr;
238 Attribute rhsAttr;
239 if (!matchPattern(lhs, m_Constant(&lhsAttr)) ||
240 !matchPattern(rhs, m_Constant(&rhsAttr)))
241 return failure();
242
243 auto lowBits = constFoldBinaryOp<IntegerAttr>(
244 {lhsAttr, rhsAttr},
245 [](const APInt &a, const APInt &b) { return a * b; });
246
247 if (!lowBits)
248 return failure();
249
250 auto highBits = constFoldBinaryOp<IntegerAttr>(
251 {lhsAttr, rhsAttr}, [](const APInt &a, const APInt &b) {
252 if (IsSigned) {
253 return llvm::APIntOps::mulhs(a, b);
254 } else {
255 return llvm::APIntOps::mulhu(a, b);
256 }
257 });
258
259 if (!highBits)
260 return failure();
261
262 Value lowBitsVal =
263 spirv::ConstantOp::create(rewriter, loc, constituentType, lowBits);
264
265 Value highBitsVal =
266 spirv::ConstantOp::create(rewriter, loc, constituentType, highBits);
267
268 // Create empty struct
269 Value undef = spirv::UndefOp::create(rewriter, loc, op.getType());
270 // Fill in lowBits at id 0
271 Value intermediate =
272 spirv::CompositeInsertOp::create(rewriter, loc, lowBitsVal, undef, 0);
273 // Fill in highBits at id 1
274 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, highBitsVal,
275 intermediate, 1);
276 return success();
277 }
278};
279
281void spirv::SMulExtendedOp::getCanonicalizationPatterns(
283 patterns.add<SMulExtendedOpFold>(context);
284}
285
286struct UMulExtendedOpXOne final : OpRewritePattern<spirv::UMulExtendedOp> {
287 using Base::Base;
288
289 LogicalResult matchAndRewrite(spirv::UMulExtendedOp op,
290 PatternRewriter &rewriter) const override {
291 Location loc = op.getLoc();
292 Value lhs = op.getOperand1();
293 Value rhs = op.getOperand2();
294 Type constituentType = lhs.getType();
295
296 // umulextended (x, 1) = <x, 0>
297 if (matchPattern(rhs, m_One())) {
298 Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
299 Value constituents[2] = {lhs, zero};
300 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
301 constituents);
302 return success();
303 }
304
305 return failure();
306 }
307};
308
310void spirv::UMulExtendedOp::getCanonicalizationPatterns(
313}
314
315//===----------------------------------------------------------------------===//
316// spirv.UMod
317//===----------------------------------------------------------------------===//
318
319// Input:
320// %0 = spirv.UMod %arg0, %const32 : i32
321// %1 = spirv.UMod %0, %const4 : i32
322// Output:
323// %0 = spirv.UMod %arg0, %const32 : i32
324// %1 = spirv.UMod %arg0, %const4 : i32
325
326// The transformation is only applied if one divisor is a multiple of the other.
327
328struct UModSimplification final : OpRewritePattern<spirv::UModOp> {
329 using Base::Base;
330
331 LogicalResult matchAndRewrite(spirv::UModOp umodOp,
332 PatternRewriter &rewriter) const override {
333 auto prevUMod = umodOp.getOperand(0).getDefiningOp<spirv::UModOp>();
334 if (!prevUMod)
335 return failure();
336
337 TypedAttr prevValue;
338 TypedAttr currValue;
339 if (!matchPattern(prevUMod.getOperand(1), m_Constant(&prevValue)) ||
340 !matchPattern(umodOp.getOperand(1), m_Constant(&currValue)))
341 return failure();
342
343 // Ensure that previous divisor is a multiple of the current divisor. If
344 // not, fail the transformation.
345 bool isApplicable = false;
346 if (auto prevInt = dyn_cast<IntegerAttr>(prevValue)) {
347 auto currInt = cast<IntegerAttr>(currValue);
348 isApplicable = prevInt.getValue().urem(currInt.getValue()) == 0;
349 } else if (auto prevVec = dyn_cast<DenseElementsAttr>(prevValue)) {
350 auto currVec = cast<DenseElementsAttr>(currValue);
351 isApplicable = llvm::all_of(llvm::zip_equal(prevVec.getValues<APInt>(),
352 currVec.getValues<APInt>()),
353 [](const auto &pair) {
354 auto &[prev, curr] = pair;
355 return prev.urem(curr) == 0;
356 });
357 }
358
359 if (!isApplicable)
360 return failure();
361
362 // The transformation is safe. Replace the existing UMod operation with a
363 // new UMod operation, using the original dividend and the current divisor.
364 rewriter.replaceOpWithNewOp<spirv::UModOp>(
365 umodOp, umodOp.getType(), prevUMod.getOperand(0), umodOp.getOperand(1));
366
367 return success();
368 }
369};
370
371void spirv::UModOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
372 MLIRContext *context) {
373 patterns.insert<UModSimplification>(context);
374}
375
376//===----------------------------------------------------------------------===//
377// spirv.BitcastOp
378//===----------------------------------------------------------------------===//
379
380OpFoldResult spirv::BitcastOp::fold(FoldAdaptor /*adaptor*/) {
381 Value curInput = getOperand();
382 if (getType() == curInput.getType())
383 return curInput;
384
385 // Look through nested bitcasts.
386 if (auto prevCast = curInput.getDefiningOp<spirv::BitcastOp>()) {
387 Value prevInput = prevCast.getOperand();
388 if (prevInput.getType() == getType())
389 return prevInput;
390
391 getOperandMutable().assign(prevInput);
392 return getResult();
393 }
394
395 // TODO(kuhar): Consider constant-folding the operand attribute.
396 return {};
397}
398
399//===----------------------------------------------------------------------===//
400// spirv.CompositeExtractOp
401//===----------------------------------------------------------------------===//
402
403OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) {
404 Value compositeOp = getComposite();
405
406 while (auto insertOp =
407 compositeOp.getDefiningOp<spirv::CompositeInsertOp>()) {
408 if (getIndices() == insertOp.getIndices())
409 return insertOp.getObject();
410 compositeOp = insertOp.getComposite();
411 }
412
413 if (auto constructOp =
414 compositeOp.getDefiningOp<spirv::CompositeConstructOp>()) {
415 auto type = llvm::cast<spirv::CompositeType>(constructOp.getType());
416 if (getIndices().size() == 1 &&
417 constructOp.getConstituents().size() == type.getNumElements()) {
418 auto i = llvm::cast<IntegerAttr>(*getIndices().begin());
419 if (i.getValue().getSExtValue() <
420 static_cast<int64_t>(constructOp.getConstituents().size()))
421 return constructOp.getConstituents()[i.getValue().getSExtValue()];
422 }
423 }
424
425 auto indexVector = llvm::map_to_vector(getIndices(), [](Attribute attr) {
426 return static_cast<unsigned>(llvm::cast<IntegerAttr>(attr).getInt());
427 });
428 return extractCompositeElement(adaptor.getComposite(), indexVector);
429}
430
431//===----------------------------------------------------------------------===//
432// spirv.Constant
433//===----------------------------------------------------------------------===//
434
435OpFoldResult spirv::ConstantOp::fold(FoldAdaptor /*adaptor*/) {
436 return getValue();
437}
438
439//===----------------------------------------------------------------------===//
440// spirv.IAdd
441//===----------------------------------------------------------------------===//
442
443OpFoldResult spirv::IAddOp::fold(FoldAdaptor adaptor) {
444 // x + 0 = x
445 if (matchPattern(getOperand2(), m_Zero()))
446 return getOperand1();
447
448 // According to the SPIR-V spec:
449 //
450 // The resulting value will equal the low-order N bits of the correct result
451 // R, where N is the component width and R is computed with enough precision
452 // to avoid overflow and underflow.
454 adaptor.getOperands(),
455 [](APInt a, const APInt &b) { return std::move(a) + b; });
456}
457
458//===----------------------------------------------------------------------===//
459// spirv.IMul
460//===----------------------------------------------------------------------===//
461
462OpFoldResult spirv::IMulOp::fold(FoldAdaptor adaptor) {
463 // x * 0 == 0
464 if (matchPattern(getOperand2(), m_Zero()))
465 return getOperand2();
466 // x * 1 = x
467 if (matchPattern(getOperand2(), m_One()))
468 return getOperand1();
469
470 // According to the SPIR-V spec:
471 //
472 // The resulting value will equal the low-order N bits of the correct result
473 // R, where N is the component width and R is computed with enough precision
474 // to avoid overflow and underflow.
476 adaptor.getOperands(),
477 [](const APInt &a, const APInt &b) { return a * b; });
478}
479
480//===----------------------------------------------------------------------===//
481// spirv.ISub
482//===----------------------------------------------------------------------===//
483
484OpFoldResult spirv::ISubOp::fold(FoldAdaptor adaptor) {
485 // x - x = 0
486 if (getOperand1() == getOperand2())
488
489 // According to the SPIR-V spec:
490 //
491 // The resulting value will equal the low-order N bits of the correct result
492 // R, where N is the component width and R is computed with enough precision
493 // to avoid overflow and underflow.
495 adaptor.getOperands(),
496 [](APInt a, const APInt &b) { return std::move(a) - b; });
497}
498
499//===----------------------------------------------------------------------===//
500// spirv.SDiv
501//===----------------------------------------------------------------------===//
502
503OpFoldResult spirv::SDivOp::fold(FoldAdaptor adaptor) {
504 // sdiv (x, 1) = x
505 if (matchPattern(getOperand2(), m_One()))
506 return getOperand1();
507
508 // According to the SPIR-V spec:
509 //
510 // Signed-integer division of Operand 1 divided by Operand 2.
511 // Results are computed per component. Behavior is undefined if Operand 2 is
512 // 0. Behavior is undefined if Operand 2 is -1 and Operand 1 is the minimum
513 // representable value for the operands' type, causing signed overflow.
514 //
515 // So don't fold during undefined behavior.
516 bool div0OrOverflow = false;
518 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
519 if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
520 div0OrOverflow = true;
521 return a;
522 }
523 return a.sdiv(b);
524 });
525 return div0OrOverflow ? Attribute() : res;
526}
527
528//===----------------------------------------------------------------------===//
529// spirv.SMod
530//===----------------------------------------------------------------------===//
531
532OpFoldResult spirv::SModOp::fold(FoldAdaptor adaptor) {
533 // smod (x, 1) = 0
534 if (matchPattern(getOperand2(), m_One()))
536
537 // According to SPIR-V spec:
538 //
539 // Signed remainder operation for the remainder whose sign matches the sign
540 // of Operand 2. Behavior is undefined if Operand 2 is 0. Behavior is
541 // undefined if Operand 2 is -1 and Operand 1 is the minimum representable
542 // value for the operands' type, causing signed overflow. Otherwise, the
543 // result is the remainder r of Operand 1 divided by Operand 2 where if
544 // r ≠ 0, the sign of r is the same as the sign of Operand 2.
545 //
546 // So don't fold during undefined behavior
547 bool div0OrOverflow = false;
549 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
550 if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
551 div0OrOverflow = true;
552 return a;
553 }
554 APInt c = a.abs().urem(b.abs());
555 if (c.isZero())
556 return c;
557 if (b.isNegative()) {
558 APInt zero = APInt::getZero(c.getBitWidth());
559 return a.isNegative() ? (zero - c) : (b + c);
560 }
561 return a.isNegative() ? (b - c) : c;
562 });
563 return div0OrOverflow ? Attribute() : res;
564}
565
566//===----------------------------------------------------------------------===//
567// spirv.SRem
568//===----------------------------------------------------------------------===//
569
570OpFoldResult spirv::SRemOp::fold(FoldAdaptor adaptor) {
571 // x % 1 = 0
572 if (matchPattern(getOperand2(), m_One()))
574
575 // According to SPIR-V spec:
576 //
577 // Signed remainder operation for the remainder whose sign matches the sign
578 // of Operand 1. Behavior is undefined if Operand 2 is 0. Behavior is
579 // undefined if Operand 2 is -1 and Operand 1 is the minimum representable
580 // value for the operands' type, causing signed overflow. Otherwise, the
581 // result is the remainder r of Operand 1 divided by Operand 2 where if
582 // r ≠ 0, the sign of r is the same as the sign of Operand 1.
583
584 // Don't fold if it would do undefined behavior.
585 bool div0OrOverflow = false;
587 adaptor.getOperands(), [&](APInt a, const APInt &b) {
588 if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
589 div0OrOverflow = true;
590 return a;
591 }
592 return a.srem(b);
593 });
594 return div0OrOverflow ? Attribute() : res;
595}
596
597//===----------------------------------------------------------------------===//
598// spirv.UDiv
599//===----------------------------------------------------------------------===//
600
601OpFoldResult spirv::UDivOp::fold(FoldAdaptor adaptor) {
602 // udiv (x, 1) = x
603 if (matchPattern(getOperand2(), m_One()))
604 return getOperand1();
605
606 // According to the SPIR-V spec:
607 //
608 // Unsigned-integer division of Operand 1 divided by Operand 2. Behavior is
609 // undefined if Operand 2 is 0.
610 //
611 // So don't fold during undefined behavior.
612 bool div0 = false;
614 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
615 if (div0 || b.isZero()) {
616 div0 = true;
617 return a;
618 }
619 return a.udiv(b);
620 });
621 return div0 ? Attribute() : res;
622}
623
624//===----------------------------------------------------------------------===//
625// spirv.UMod
626//===----------------------------------------------------------------------===//
627
628OpFoldResult spirv::UModOp::fold(FoldAdaptor adaptor) {
629 // umod (x, 1) = 0
630 if (matchPattern(getOperand2(), m_One()))
632
633 // According to the SPIR-V spec:
634 //
635 // Unsigned modulo operation of Operand 1 modulo Operand 2. Behavior is
636 // undefined if Operand 2 is 0.
637 //
638 // So don't fold during undefined behavior.
639 bool div0 = false;
641 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
642 if (div0 || b.isZero()) {
643 div0 = true;
644 return a;
645 }
646 return a.urem(b);
647 });
648 return div0 ? Attribute() : res;
649}
650
651//===----------------------------------------------------------------------===//
652// spirv.SNegate
653//===----------------------------------------------------------------------===//
654
655OpFoldResult spirv::SNegateOp::fold(FoldAdaptor adaptor) {
656 // -(-x) = 0 - (0 - x) = x
657 auto op = getOperand();
658 if (auto negateOp = op.getDefiningOp<spirv::SNegateOp>())
659 return negateOp->getOperand(0);
660
661 // According to the SPIR-V spec:
662 //
663 // Signed-integer subtract of Operand from zero.
665 adaptor.getOperands(), [](const APInt &a) {
666 APInt zero = APInt::getZero(a.getBitWidth());
667 return zero - a;
668 });
669}
670
671//===----------------------------------------------------------------------===//
672// spirv.NotOp
673//===----------------------------------------------------------------------===//
674
675OpFoldResult spirv::NotOp::fold(spirv::NotOp::FoldAdaptor adaptor) {
676 // !(!x) = x
677 auto op = getOperand();
678 if (auto notOp = op.getDefiningOp<spirv::NotOp>())
679 return notOp->getOperand(0);
680
681 // According to the SPIR-V spec:
682 //
683 // Complement the bits of Operand.
684 return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(), [&](APInt a) {
685 a.flipAllBits();
686 return a;
687 });
688}
689
690//===----------------------------------------------------------------------===//
691// spirv.LogicalAnd
692//===----------------------------------------------------------------------===//
693
694OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) {
695 if (std::optional<bool> rhs =
696 getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
697 // x && true = x
698 if (*rhs)
699 return getOperand1();
700
701 // x && false = false
702 if (!*rhs)
703 return adaptor.getOperand2();
704 }
705
706 return Attribute();
707}
708
709//===----------------------------------------------------------------------===//
710// spirv.LogicalEqualOp
711//===----------------------------------------------------------------------===//
712
714spirv::LogicalEqualOp::fold(spirv::LogicalEqualOp::FoldAdaptor adaptor) {
715 // x == x -> true
716 if (getOperand1() == getOperand2()) {
717 auto trueAttr = BoolAttr::get(getContext(), true);
718 if (isa<IntegerType>(getType()))
719 return trueAttr;
720 if (auto vecTy = dyn_cast<VectorType>(getType()))
721 return SplatElementsAttr::get(vecTy, trueAttr);
722 }
723
725 adaptor.getOperands(), [](const APInt &a, const APInt &b) {
726 return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);
727 });
728}
729
730//===----------------------------------------------------------------------===//
731// spirv.LogicalNotEqualOp
732//===----------------------------------------------------------------------===//
733
734OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) {
735 if (std::optional<bool> rhs =
736 getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
737 // x != false -> x
738 if (!rhs.value())
739 return getOperand1();
740 }
741
742 // x == x -> false
743 if (getOperand1() == getOperand2()) {
744 auto falseAttr = BoolAttr::get(getContext(), false);
745 if (isa<IntegerType>(getType()))
746 return falseAttr;
747 if (auto vecTy = dyn_cast<VectorType>(getType()))
748 return SplatElementsAttr::get(vecTy, falseAttr);
749 }
750
752 adaptor.getOperands(), [](const APInt &a, const APInt &b) {
753 return a == b ? APInt::getZero(1) : APInt::getAllOnes(1);
754 });
755}
756
757//===----------------------------------------------------------------------===//
758// spirv.LogicalNot
759//===----------------------------------------------------------------------===//
760
761OpFoldResult spirv::LogicalNotOp::fold(FoldAdaptor adaptor) {
762 // !(!x) = x
763 auto op = getOperand();
764 if (auto notOp = op.getDefiningOp<spirv::LogicalNotOp>())
765 return notOp->getOperand(0);
766
767 // According to the SPIR-V spec:
768 //
769 // Complement the bits of Operand.
770 return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(),
771 [](const APInt &a) {
772 APInt zero = APInt::getZero(1);
773 return a == 1 ? zero : (zero + 1);
774 });
775}
776
777void spirv::LogicalNotOp::getCanonicalizationPatterns(
778 RewritePatternSet &results, MLIRContext *context) {
779 results
780 .add<ConvertLogicalNotOfIEqual, ConvertLogicalNotOfINotEqual,
781 ConvertLogicalNotOfLogicalEqual, ConvertLogicalNotOfLogicalNotEqual>(
782 context);
783}
784
785//===----------------------------------------------------------------------===//
786// spirv.LogicalOr
787//===----------------------------------------------------------------------===//
788
789OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
790 if (auto rhs = getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
791 if (*rhs) {
792 // x || true = true
793 return adaptor.getOperand2();
794 }
795
796 if (!*rhs) {
797 // x || false = x
798 return getOperand1();
799 }
800 }
801
802 return Attribute();
803}
804
805//===----------------------------------------------------------------------===//
806// spirv.SelectOp
807//===----------------------------------------------------------------------===//
808
809OpFoldResult spirv::SelectOp::fold(FoldAdaptor adaptor) {
810 // spirv.Select _ x x -> x
811 Value trueVals = getTrueValue();
812 Value falseVals = getFalseValue();
813 if (trueVals == falseVals)
814 return trueVals;
815
816 ArrayRef<Attribute> operands = adaptor.getOperands();
817
818 // spirv.Select true x y -> x
819 // spirv.Select false x y -> y
820 if (auto boolAttr = getScalarOrSplatBoolAttr(operands[0]))
821 return *boolAttr ? trueVals : falseVals;
822
823 // Check that all the operands are constant
824 if (!operands[0] || !operands[1] || !operands[2])
825 return Attribute();
826
827 // Note: getScalarOrSplatBoolAttr will always return a boolAttr if we are in
828 // the scalar case. Hence, we are only required to consider the case of
829 // DenseElementsAttr in foldSelectOp.
830 auto condAttrs = dyn_cast<DenseElementsAttr>(operands[0]);
831 auto trueAttrs = dyn_cast<DenseElementsAttr>(operands[1]);
832 auto falseAttrs = dyn_cast<DenseElementsAttr>(operands[2]);
833 if (!condAttrs || !trueAttrs || !falseAttrs)
834 return Attribute();
835
836 auto elementResults = llvm::to_vector<4>(trueAttrs.getValues<Attribute>());
837 auto iters = llvm::zip_equal(elementResults, condAttrs.getValues<BoolAttr>(),
838 falseAttrs.getValues<Attribute>());
839 for (auto [result, cond, falseRes] : iters) {
840 if (!cond.getValue())
841 result = falseRes;
842 }
843
844 auto resultType = trueAttrs.getType();
845 return DenseElementsAttr::get(cast<ShapedType>(resultType), elementResults);
846}
847
848//===----------------------------------------------------------------------===//
849// spirv.IEqualOp
850//===----------------------------------------------------------------------===//
851
852OpFoldResult spirv::IEqualOp::fold(spirv::IEqualOp::FoldAdaptor adaptor) {
853 // x == x -> true
854 if (getOperand1() == getOperand2()) {
855 auto trueAttr = BoolAttr::get(getContext(), true);
856 if (isa<IntegerType>(getType()))
857 return trueAttr;
858 if (auto vecTy = dyn_cast<VectorType>(getType()))
859 return SplatElementsAttr::get(vecTy, trueAttr);
860 }
861
863 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
864 return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);
865 });
866}
867
868//===----------------------------------------------------------------------===//
869// spirv.INotEqualOp
870//===----------------------------------------------------------------------===//
871
872OpFoldResult spirv::INotEqualOp::fold(spirv::INotEqualOp::FoldAdaptor adaptor) {
873 // x == x -> false
874 if (getOperand1() == getOperand2()) {
875 auto falseAttr = BoolAttr::get(getContext(), false);
876 if (isa<IntegerType>(getType()))
877 return falseAttr;
878 if (auto vecTy = dyn_cast<VectorType>(getType()))
879 return SplatElementsAttr::get(vecTy, falseAttr);
880 }
881
883 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
884 return a == b ? APInt::getZero(1) : APInt::getAllOnes(1);
885 });
886}
887
888//===----------------------------------------------------------------------===//
889// spirv.SGreaterThan
890//===----------------------------------------------------------------------===//
891
893spirv::SGreaterThanOp::fold(spirv::SGreaterThanOp::FoldAdaptor adaptor) {
894 // x == x -> false
895 if (getOperand1() == getOperand2()) {
896 auto falseAttr = BoolAttr::get(getContext(), false);
897 if (isa<IntegerType>(getType()))
898 return falseAttr;
899 if (auto vecTy = dyn_cast<VectorType>(getType()))
900 return SplatElementsAttr::get(vecTy, falseAttr);
901 }
902
904 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
905 return a.sgt(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
906 });
907}
908
909//===----------------------------------------------------------------------===//
910// spirv.SGreaterThanEqual
911//===----------------------------------------------------------------------===//
912
913OpFoldResult spirv::SGreaterThanEqualOp::fold(
914 spirv::SGreaterThanEqualOp::FoldAdaptor adaptor) {
915 // x == x -> true
916 if (getOperand1() == getOperand2()) {
917 auto trueAttr = BoolAttr::get(getContext(), true);
918 if (isa<IntegerType>(getType()))
919 return trueAttr;
920 if (auto vecTy = dyn_cast<VectorType>(getType()))
921 return SplatElementsAttr::get(vecTy, trueAttr);
922 }
923
925 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
926 return a.sge(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
927 });
928}
929
930//===----------------------------------------------------------------------===//
931// spirv.UGreaterThan
932//===----------------------------------------------------------------------===//
933
935spirv::UGreaterThanOp::fold(spirv::UGreaterThanOp::FoldAdaptor adaptor) {
936 // x == x -> false
937 if (getOperand1() == getOperand2()) {
938 auto falseAttr = BoolAttr::get(getContext(), false);
939 if (isa<IntegerType>(getType()))
940 return falseAttr;
941 if (auto vecTy = dyn_cast<VectorType>(getType()))
942 return SplatElementsAttr::get(vecTy, falseAttr);
943 }
944
946 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
947 return a.ugt(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
948 });
949}
950
951//===----------------------------------------------------------------------===//
952// spirv.UGreaterThanEqual
953//===----------------------------------------------------------------------===//
954
955OpFoldResult spirv::UGreaterThanEqualOp::fold(
956 spirv::UGreaterThanEqualOp::FoldAdaptor adaptor) {
957 // x == x -> true
958 if (getOperand1() == getOperand2()) {
959 auto trueAttr = BoolAttr::get(getContext(), true);
960 if (isa<IntegerType>(getType()))
961 return trueAttr;
962 if (auto vecTy = dyn_cast<VectorType>(getType()))
963 return SplatElementsAttr::get(vecTy, trueAttr);
964 }
965
967 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
968 return a.uge(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
969 });
970}
971
972//===----------------------------------------------------------------------===//
973// spirv.SLessThan
974//===----------------------------------------------------------------------===//
975
976OpFoldResult spirv::SLessThanOp::fold(spirv::SLessThanOp::FoldAdaptor adaptor) {
977 // x == x -> false
978 if (getOperand1() == getOperand2()) {
979 auto falseAttr = BoolAttr::get(getContext(), false);
980 if (isa<IntegerType>(getType()))
981 return falseAttr;
982 if (auto vecTy = dyn_cast<VectorType>(getType()))
983 return SplatElementsAttr::get(vecTy, falseAttr);
984 }
985
987 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
988 return a.slt(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
989 });
990}
991
992//===----------------------------------------------------------------------===//
993// spirv.SLessThanEqual
994//===----------------------------------------------------------------------===//
995
997spirv::SLessThanEqualOp::fold(spirv::SLessThanEqualOp::FoldAdaptor adaptor) {
998 // x == x -> true
999 if (getOperand1() == getOperand2()) {
1000 auto trueAttr = BoolAttr::get(getContext(), true);
1001 if (isa<IntegerType>(getType()))
1002 return trueAttr;
1003 if (auto vecTy = dyn_cast<VectorType>(getType()))
1004 return SplatElementsAttr::get(vecTy, trueAttr);
1005 }
1006
1008 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
1009 return a.sle(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
1010 });
1011}
1012
1013//===----------------------------------------------------------------------===//
1014// spirv.ULessThan
1015//===----------------------------------------------------------------------===//
1016
1017OpFoldResult spirv::ULessThanOp::fold(spirv::ULessThanOp::FoldAdaptor adaptor) {
1018 // x == x -> false
1019 if (getOperand1() == getOperand2()) {
1020 auto falseAttr = BoolAttr::get(getContext(), false);
1021 if (isa<IntegerType>(getType()))
1022 return falseAttr;
1023 if (auto vecTy = dyn_cast<VectorType>(getType()))
1024 return SplatElementsAttr::get(vecTy, falseAttr);
1025 }
1026
1028 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
1029 return a.ult(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
1030 });
1031}
1032
1033//===----------------------------------------------------------------------===//
1034// spirv.ULessThanEqual
1035//===----------------------------------------------------------------------===//
1036
1038spirv::ULessThanEqualOp::fold(spirv::ULessThanEqualOp::FoldAdaptor adaptor) {
1039 // x == x -> true
1040 if (getOperand1() == getOperand2()) {
1041 auto trueAttr = BoolAttr::get(getContext(), true);
1042 if (isa<IntegerType>(getType()))
1043 return trueAttr;
1044 if (auto vecTy = dyn_cast<VectorType>(getType()))
1045 return SplatElementsAttr::get(vecTy, trueAttr);
1046 }
1047
1049 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
1050 return a.ule(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
1051 });
1052}
1053
1054//===----------------------------------------------------------------------===//
1055// spirv.ShiftLeftLogical
1056//===----------------------------------------------------------------------===//
1057
1058OpFoldResult spirv::ShiftLeftLogicalOp::fold(
1059 spirv::ShiftLeftLogicalOp::FoldAdaptor adaptor) {
1060 // x << 0 -> x
1061 if (matchPattern(adaptor.getOperand2(), m_Zero())) {
1062 return getOperand1();
1063 }
1064
1065 // Unfortunately due to below undefined behaviour can't fold 0 for Base.
1066
1067 // Results are computed per component, and within each component, per bit...
1068 //
1069 // The result is undefined if Shift is greater than or equal to the bit width
1070 // of the components of Base.
1071 //
1072 // So we can use the APInt << method, but don't fold if undefined behaviour.
1073 bool shiftToLarge = false;
1075 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
1076 if (shiftToLarge || b.uge(a.getBitWidth())) {
1077 shiftToLarge = true;
1078 return a;
1079 }
1080 return a << b;
1081 });
1082 return shiftToLarge ? Attribute() : res;
1083}
1084
1085//===----------------------------------------------------------------------===//
1086// spirv.ShiftRightArithmetic
1087//===----------------------------------------------------------------------===//
1088
1089OpFoldResult spirv::ShiftRightArithmeticOp::fold(
1090 spirv::ShiftRightArithmeticOp::FoldAdaptor adaptor) {
1091 // x >> 0 -> x
1092 if (matchPattern(adaptor.getOperand2(), m_Zero())) {
1093 return getOperand1();
1094 }
1095
1096 // Unfortunately due to below undefined behaviour can't fold 0, -1 for Base.
1097
1098 // Results are computed per component, and within each component, per bit...
1099 //
1100 // The result is undefined if Shift is greater than or equal to the bit width
1101 // of the components of Base.
1102 //
1103 // So we can use the APInt ashr method, but don't fold if undefined behaviour.
1104 bool shiftToLarge = false;
1106 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
1107 if (shiftToLarge || b.uge(a.getBitWidth())) {
1108 shiftToLarge = true;
1109 return a;
1110 }
1111 return a.ashr(b);
1112 });
1113 return shiftToLarge ? Attribute() : res;
1114}
1115
1116//===----------------------------------------------------------------------===//
1117// spirv.ShiftRightLogical
1118//===----------------------------------------------------------------------===//
1119
1120OpFoldResult spirv::ShiftRightLogicalOp::fold(
1121 spirv::ShiftRightLogicalOp::FoldAdaptor adaptor) {
1122 // x >> 0 -> x
1123 if (matchPattern(adaptor.getOperand2(), m_Zero())) {
1124 return getOperand1();
1125 }
1126
1127 // Unfortunately due to below undefined behaviour can't fold 0 for Base.
1128
1129 // Results are computed per component, and within each component, per bit...
1130 //
1131 // The result is undefined if Shift is greater than or equal to the bit width
1132 // of the components of Base.
1133 //
1134 // So we can use the APInt lshr method, but don't fold if undefined behaviour.
1135 bool shiftToLarge = false;
1137 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
1138 if (shiftToLarge || b.uge(a.getBitWidth())) {
1139 shiftToLarge = true;
1140 return a;
1141 }
1142 return a.lshr(b);
1143 });
1144 return shiftToLarge ? Attribute() : res;
1145}
1146
1147//===----------------------------------------------------------------------===//
1148// spirv.BitwiseAndOp
1149//===----------------------------------------------------------------------===//
1150
1152spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) {
1153 // x & x -> x
1154 if (getOperand1() == getOperand2()) {
1155 return getOperand1();
1156 }
1157
1158 APInt rhsMask;
1159 if (matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask))) {
1160 // x & 0 -> 0
1161 if (rhsMask.isZero())
1162 return getOperand2();
1163
1164 // x & <all ones> -> x
1165 if (rhsMask.isAllOnes())
1166 return getOperand1();
1167
1168 // (UConvert x : iN to iK) & <mask with N low bits set> -> UConvert x
1169 if (auto zext = getOperand1().getDefiningOp<spirv::UConvertOp>()) {
1170 int valueBits =
1172 if (rhsMask.zextOrTrunc(valueBits).isAllOnes())
1173 return getOperand1();
1174 }
1175 }
1176
1177 // According to the SPIR-V spec:
1178 //
1179 // Type is a scalar or vector of integer type.
1180 // Results are computed per component, and within each component, per bit.
1181 // So we can use the APInt & method.
1183 adaptor.getOperands(),
1184 [](const APInt &a, const APInt &b) { return a & b; });
1185}
1186
1187//===----------------------------------------------------------------------===//
1188// spirv.BitwiseOrOp
1189//===----------------------------------------------------------------------===//
1190
1191OpFoldResult spirv::BitwiseOrOp::fold(spirv::BitwiseOrOp::FoldAdaptor adaptor) {
1192 // x | x -> x
1193 if (getOperand1() == getOperand2()) {
1194 return getOperand1();
1195 }
1196
1197 APInt rhsMask;
1198 if (matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask))) {
1199 // x | 0 -> x
1200 if (rhsMask.isZero())
1201 return getOperand1();
1202
1203 // x | <all ones> -> <all ones>
1204 if (rhsMask.isAllOnes())
1205 return getOperand2();
1206 }
1207
1208 // According to the SPIR-V spec:
1209 //
1210 // Type is a scalar or vector of integer type.
1211 // Results are computed per component, and within each component, per bit.
1212 // So we can use the APInt | method.
1214 adaptor.getOperands(),
1215 [](const APInt &a, const APInt &b) { return a | b; });
1216}
1217
1218//===----------------------------------------------------------------------===//
1219// spirv.BitwiseXorOp
1220//===----------------------------------------------------------------------===//
1221
1223spirv::BitwiseXorOp::fold(spirv::BitwiseXorOp::FoldAdaptor adaptor) {
1224 // x ^ 0 -> x
1225 if (matchPattern(adaptor.getOperand2(), m_Zero())) {
1226 return getOperand1();
1227 }
1228
1229 // x ^ x -> 0
1230 if (getOperand1() == getOperand2())
1232
1233 // According to the SPIR-V spec:
1234 //
1235 // Type is a scalar or vector of integer type.
1236 // Results are computed per component, and within each component, per bit.
1237 // So we can use the APInt ^ method.
1239 adaptor.getOperands(),
1240 [](const APInt &a, const APInt &b) { return a ^ b; });
1241}
1242
1243//===----------------------------------------------------------------------===//
1244// spirv.mlir.selection
1245//===----------------------------------------------------------------------===//
1246
1247namespace {
1248// Blocks from the given `spirv.mlir.selection` operation must satisfy the
1249// following layout:
1250//
1251// +-----------------------------------------------+
1252// | header block |
1253// | spirv.BranchConditionalOp %cond, ^case0, ^case1 |
1254// +-----------------------------------------------+
1255// / \
1256// ...
1257//
1258//
1259// +------------------------+ +------------------------+
1260// | case #0 | | case #1 |
1261// | spirv.Store %ptr %value0 | | spirv.Store %ptr %value1 |
1262// | spirv.Branch ^merge | | spirv.Branch ^merge |
1263// +------------------------+ +------------------------+
1264//
1265//
1266// ...
1267// \ /
1268// v
1269// +-------------+
1270// | merge block |
1271// +-------------+
1272//
1273struct ConvertSelectionOpToSelect final : OpRewritePattern<spirv::SelectionOp> {
1274 using Base::Base;
1275
1276 LogicalResult matchAndRewrite(spirv::SelectionOp selectionOp,
1277 PatternRewriter &rewriter) const override {
1278 Operation *op = selectionOp.getOperation();
1279 Region &body = op->getRegion(0);
1280 // Verifier allows an empty region for `spirv.mlir.selection`.
1281 if (body.empty()) {
1282 return failure();
1283 }
1284
1285 // Check that region consists of 4 blocks:
1286 // header block, `true` block, `false` block and merge block.
1287 if (llvm::range_size(body) != 4) {
1288 return failure();
1289 }
1290
1291 Block *headerBlock = selectionOp.getHeaderBlock();
1292 if (!onlyContainsBranchConditionalOp(headerBlock)) {
1293 return failure();
1294 }
1295
1296 auto brConditionalOp =
1297 cast<spirv::BranchConditionalOp>(headerBlock->front());
1298
1299 Block *trueBlock = brConditionalOp.getSuccessor(0);
1300 Block *falseBlock = brConditionalOp.getSuccessor(1);
1301 Block *mergeBlock = selectionOp.getMergeBlock();
1302
1303 if (failed(canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)))
1304 return failure();
1305
1306 Value trueValue = getSrcValue(trueBlock);
1307 Value falseValue = getSrcValue(falseBlock);
1308 Value ptrValue = getDstPtr(trueBlock);
1309 auto storeOpAttributes =
1310 cast<spirv::StoreOp>(trueBlock->front())->getAttrs();
1311
1312 auto selectOp = spirv::SelectOp::create(
1313 rewriter, selectionOp.getLoc(), trueValue.getType(),
1314 brConditionalOp.getCondition(), trueValue, falseValue);
1315 spirv::StoreOp::create(rewriter, selectOp.getLoc(), ptrValue,
1316 selectOp.getResult(), storeOpAttributes);
1317
1318 // `spirv.mlir.selection` is not needed anymore.
1319 rewriter.eraseOp(op);
1320 return success();
1321 }
1322
1323private:
1324 // Checks that given blocks follow the following rules:
1325 // 1. Each conditional block consists of two operations, the first operation
1326 // is a `spirv.Store` and the last operation is a `spirv.Branch`.
1327 // 2. Each `spirv.Store` uses the same pointer and the same memory attributes.
1328 // 3. A control flow goes into the given merge block from the given
1329 // conditional blocks.
1330 LogicalResult canCanonicalizeSelection(Block *trueBlock, Block *falseBlock,
1331 Block *mergeBlock) const;
1332
1333 bool onlyContainsBranchConditionalOp(Block *block) const {
1334 return llvm::hasSingleElement(*block) &&
1335 isa<spirv::BranchConditionalOp>(block->front());
1336 }
1337
1338 bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs) const {
1339 return lhs->getDiscardableAttrDictionary() ==
1340 rhs->getDiscardableAttrDictionary() &&
1341 lhs.getProperties() == rhs.getProperties();
1342 }
1343
1344 // Returns a source value for the given block.
1345 Value getSrcValue(Block *block) const {
1346 auto storeOp = cast<spirv::StoreOp>(block->front());
1347 return storeOp.getValue();
1348 }
1349
1350 // Returns a destination value for the given block.
1351 Value getDstPtr(Block *block) const {
1352 auto storeOp = cast<spirv::StoreOp>(block->front());
1353 return storeOp.getPtr();
1354 }
1355};
1356
1357LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
1358 Block *trueBlock, Block *falseBlock, Block *mergeBlock) const {
1359 // Each block must consists of 2 operations.
1360 if (llvm::range_size(*trueBlock) != 2 || llvm::range_size(*falseBlock) != 2) {
1361 return failure();
1362 }
1363
1364 auto trueBrStoreOp = dyn_cast<spirv::StoreOp>(trueBlock->front());
1365 auto trueBrBranchOp =
1366 dyn_cast<spirv::BranchOp>(*std::next(trueBlock->begin()));
1367 auto falseBrStoreOp = dyn_cast<spirv::StoreOp>(falseBlock->front());
1368 auto falseBrBranchOp =
1369 dyn_cast<spirv::BranchOp>(*std::next(falseBlock->begin()));
1370
1371 if (!trueBrStoreOp || !trueBrBranchOp || !falseBrStoreOp ||
1372 !falseBrBranchOp) {
1373 return failure();
1374 }
1375
1376 // Checks that given type is valid for `spirv.SelectOp`.
1377 // According to SPIR-V spec:
1378 // "Before version 1.4, Result Type must be a pointer, scalar, or vector.
1379 // Starting with version 1.4, Result Type can additionally be a composite type
1380 // other than a vector."
1381 bool isScalarOrVector =
1382 llvm::cast<spirv::SPIRVType>(trueBrStoreOp.getValue().getType())
1383 .isScalarOrVector();
1384
1385 // Check that each `spirv.Store` uses the same pointer, memory access
1386 // attributes and a valid type of the value.
1387 if ((trueBrStoreOp.getPtr() != falseBrStoreOp.getPtr()) ||
1388 !isSameAttrList(trueBrStoreOp, falseBrStoreOp) || !isScalarOrVector) {
1389 return failure();
1390 }
1391
1392 if ((trueBrBranchOp->getSuccessor(0) != mergeBlock) ||
1393 (falseBrBranchOp->getSuccessor(0) != mergeBlock)) {
1394 return failure();
1395 }
1396
1397 return success();
1398}
1399} // namespace
1400
1401void spirv::SelectionOp::getCanonicalizationPatterns(RewritePatternSet &results,
1402 MLIRContext *context) {
1403 results.add<ConvertSelectionOpToSelect>(context);
1404}
return success()
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static uint64_t zext(uint32_t arg)
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
if(!isCopyOut)
b getContext())
static Attribute extractCompositeElement(Attribute composite, ArrayRef< unsigned > indices)
MulExtendedFold< spirv::UMulExtendedOp, false > UMulExtendedOpFold
MulExtendedFold< spirv::SMulExtendedOp, true > SMulExtendedOpFold
static std::optional< bool > getScalarOrSplatBoolAttr(Attribute attr)
Returns the boolean value under the hood if the given boolAttr is a scalar or splat vector bool const...
static bool isDivZeroOrOverflow(const APInt &a, const APInt &b)
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
Operation & front()
Definition Block.h:153
iterator begin()
Definition Block.h:143
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:324
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class represents a single result from folding an operation.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:686
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
bool empty()
Definition Region.h:60
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:122
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition Matchers.h:442
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition Matchers.h:478
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
Attribute constFoldUnaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
LogicalResult matchAndRewrite(spirv::IAddCarryOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(MulOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::UModOp umodOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::UMulExtendedOp op, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})