MLIR 23.0.0git
SPIRVCanonicalization.cpp
Go to the documentation of this file.
1//===- SPIRVCanonicalization.cpp - MLIR SPIR-V canonicalization patterns --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the folders and canonicalization patterns for SPIR-V ops.
10//
11//===----------------------------------------------------------------------===//
12
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"

#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVectorExtras.h"

#include <optional>
#include <utility>
25
26using namespace mlir;
27
28//===----------------------------------------------------------------------===//
29// Common utility functions
30//===----------------------------------------------------------------------===//
31
32/// Returns the boolean value under the hood if the given `boolAttr` is a scalar
33/// or splat vector bool constant.
34static std::optional<bool> getScalarOrSplatBoolAttr(Attribute attr) {
35 if (!attr)
36 return std::nullopt;
37
38 if (auto boolAttr = dyn_cast<BoolAttr>(attr))
39 return boolAttr.getValue();
40 if (auto splatAttr = dyn_cast<SplatElementsAttr>(attr))
41 if (splatAttr.getElementType().isInteger(1))
42 return splatAttr.getSplatValue<bool>();
43 return std::nullopt;
44}
45
46// Extracts an element from the given `composite` by following the given
47// `indices`. Returns a null Attribute if error happens.
50 // Check that given composite is a constant.
51 if (!composite)
52 return {};
53 // Return composite itself if we reach the end of the index chain.
54 if (indices.empty())
55 return composite;
56
57 if (auto vector = dyn_cast<ElementsAttr>(composite)) {
58 assert(indices.size() == 1 && "must have exactly one index for a vector");
59 return vector.getValues<Attribute>()[indices[0]];
60 }
61
62 if (auto array = dyn_cast<ArrayAttr>(composite)) {
63 assert(!indices.empty() && "must have at least one index for an array");
64 return extractCompositeElement(array.getValue()[indices[0]],
65 indices.drop_front());
66 }
67
68 return {};
69}
70
71static bool isDivZeroOrOverflow(const APInt &a, const APInt &b) {
72 bool div0 = b.isZero();
73 bool overflow = a.isMinSignedValue() && b.isAllOnes();
74
75 return div0 || overflow;
76}
77
78//===----------------------------------------------------------------------===//
79// TableGen'erated canonicalizers
80//===----------------------------------------------------------------------===//
81
82namespace {
83#include "SPIRVCanonicalization.inc"
84} // namespace
85
86//===----------------------------------------------------------------------===//
87// spirv.AccessChainOp
88//===----------------------------------------------------------------------===//
89
90namespace {
91
92/// Combines chained `spirv::AccessChainOp` operations into one
93/// `spirv::AccessChainOp` operation.
94struct CombineChainedAccessChain final
95 : OpRewritePattern<spirv::AccessChainOp> {
96 using Base::Base;
97
98 LogicalResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
99 PatternRewriter &rewriter) const override {
100 auto parentAccessChainOp =
101 accessChainOp.getBasePtr().getDefiningOp<spirv::AccessChainOp>();
102
103 if (!parentAccessChainOp) {
104 return failure();
105 }
106
107 // Combine indices.
108 SmallVector<Value, 4> indices(parentAccessChainOp.getIndices());
109 llvm::append_range(indices, accessChainOp.getIndices());
110
111 rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
112 accessChainOp, parentAccessChainOp.getBasePtr(), indices);
113
114 return success();
115 }
116};
117} // namespace
118
119void spirv::AccessChainOp::getCanonicalizationPatterns(
120 RewritePatternSet &results, MLIRContext *context) {
121 results.add<CombineChainedAccessChain>(context);
122}
123
124//===----------------------------------------------------------------------===//
125// spirv.IAddCarry
126//===----------------------------------------------------------------------===//
127
128// We are required to use CompositeConstructOp to create a constant struct as
129// they are not yet implemented as constant, hence we can not do so in a fold.
130struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
131 using Base::Base;
132
133 LogicalResult matchAndRewrite(spirv::IAddCarryOp op,
134 PatternRewriter &rewriter) const override {
135 Location loc = op.getLoc();
136 Value lhs = op.getOperand1();
137 Value rhs = op.getOperand2();
138 Type constituentType = lhs.getType();
139
140 // iaddcarry (x, 0) = <0, x>
141 if (matchPattern(rhs, m_Zero())) {
142 Value constituents[2] = {rhs, lhs};
143 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
144 constituents);
145 return success();
146 }
147
148 // According to the SPIR-V spec:
149 //
150 // Result Type must be from OpTypeStruct. The struct must have two
151 // members...
152 //
153 // Member 0 of the result gets the low-order bits (full component width) of
154 // the addition.
155 //
156 // Member 1 of the result gets the high-order (carry) bit of the result of
157 // the addition. That is, it gets the value 1 if the addition overflowed
158 // the component width, and 0 otherwise.
159 Attribute lhsAttr;
160 Attribute rhsAttr;
161 if (!matchPattern(lhs, m_Constant(&lhsAttr)) ||
162 !matchPattern(rhs, m_Constant(&rhsAttr)))
163 return failure();
164
166 {lhsAttr, rhsAttr},
167 [](const APInt &a, const APInt &b) { return a + b; });
168 if (!adds)
169 return failure();
170
171 auto carrys = constFoldBinaryOp<IntegerAttr>(
172 ArrayRef{adds, lhsAttr}, [](const APInt &a, const APInt &b) {
173 APInt zero = APInt::getZero(a.getBitWidth());
174 return a.ult(b) ? (zero + 1) : zero;
175 });
176
177 if (!carrys)
178 return failure();
179
180 Value addsVal =
181 spirv::ConstantOp::create(rewriter, loc, constituentType, adds);
182
183 Value carrysVal =
184 spirv::ConstantOp::create(rewriter, loc, constituentType, carrys);
185
186 // Create empty struct
187 Value undef = spirv::UndefOp::create(rewriter, loc, op.getType());
188 // Fill in adds at id 0
189 Value intermediate =
190 spirv::CompositeInsertOp::create(rewriter, loc, addsVal, undef, 0);
191 // Fill in carrys at id 1
192 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, carrysVal,
193 intermediate, 1);
194 return success();
195 }
196};
197
198void spirv::IAddCarryOp::getCanonicalizationPatterns(
200 patterns.add<IAddCarryFold>(context);
201}
202
203//===----------------------------------------------------------------------===//
204// spirv.[S|U]MulExtended
205//===----------------------------------------------------------------------===//
206
207// We are required to use CompositeConstructOp to create a constant struct as
208// they are not yet implemented as constant, hence we can not do so in a fold.
209template <typename MulOp, bool IsSigned>
210struct MulExtendedFold final : OpRewritePattern<MulOp> {
212
213 LogicalResult matchAndRewrite(MulOp op,
214 PatternRewriter &rewriter) const override {
215 Location loc = op.getLoc();
216 Value lhs = op.getOperand1();
217 Value rhs = op.getOperand2();
218 Type constituentType = lhs.getType();
219
220 // [su]mulextended (x, 0) = <0, 0>
221 if (matchPattern(rhs, m_Zero())) {
222 Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
223 Value constituents[2] = {zero, zero};
224 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
225 constituents);
226 return success();
227 }
228
229 // According to the SPIR-V spec:
230 //
231 // Result Type must be from OpTypeStruct. The struct must have two
232 // members...
233 //
234 // Member 0 of the result gets the low-order bits of the multiplication.
235 //
236 // Member 1 of the result gets the high-order bits of the multiplication.
237 Attribute lhsAttr;
238 Attribute rhsAttr;
239 if (!matchPattern(lhs, m_Constant(&lhsAttr)) ||
240 !matchPattern(rhs, m_Constant(&rhsAttr)))
241 return failure();
242
243 auto lowBits = constFoldBinaryOp<IntegerAttr>(
244 {lhsAttr, rhsAttr},
245 [](const APInt &a, const APInt &b) { return a * b; });
246
247 if (!lowBits)
248 return failure();
249
250 auto highBits = constFoldBinaryOp<IntegerAttr>(
251 {lhsAttr, rhsAttr}, [](const APInt &a, const APInt &b) {
252 if (IsSigned) {
253 return llvm::APIntOps::mulhs(a, b);
254 }
255 return llvm::APIntOps::mulhu(a, b);
256 });
257
258 if (!highBits)
259 return failure();
260
261 Value lowBitsVal =
262 spirv::ConstantOp::create(rewriter, loc, constituentType, lowBits);
263
264 Value highBitsVal =
265 spirv::ConstantOp::create(rewriter, loc, constituentType, highBits);
266
267 // Create empty struct
268 Value undef = spirv::UndefOp::create(rewriter, loc, op.getType());
269 // Fill in lowBits at id 0
270 Value intermediate =
271 spirv::CompositeInsertOp::create(rewriter, loc, lowBitsVal, undef, 0);
272 // Fill in highBits at id 1
273 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, highBitsVal,
274 intermediate, 1);
275 return success();
276 }
277};
278
280void spirv::SMulExtendedOp::getCanonicalizationPatterns(
282 patterns.add<SMulExtendedOpFold>(context);
283}
284
285struct UMulExtendedOpXOne final : OpRewritePattern<spirv::UMulExtendedOp> {
286 using Base::Base;
287
288 LogicalResult matchAndRewrite(spirv::UMulExtendedOp op,
289 PatternRewriter &rewriter) const override {
290 Location loc = op.getLoc();
291 Value lhs = op.getOperand1();
292 Value rhs = op.getOperand2();
293 Type constituentType = lhs.getType();
294
295 // umulextended (x, 1) = <x, 0>
296 if (matchPattern(rhs, m_One())) {
297 Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
298 Value constituents[2] = {lhs, zero};
299 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
300 constituents);
301 return success();
302 }
303
304 return failure();
305 }
306};
307
309void spirv::UMulExtendedOp::getCanonicalizationPatterns(
312}
313
314//===----------------------------------------------------------------------===//
315// spirv.UMod
316//===----------------------------------------------------------------------===//
317
318// Input:
319// %0 = spirv.UMod %arg0, %const32 : i32
320// %1 = spirv.UMod %0, %const4 : i32
321// Output:
322// %0 = spirv.UMod %arg0, %const32 : i32
323// %1 = spirv.UMod %arg0, %const4 : i32
324
325// The transformation is only applied if one divisor is a multiple of the other.
326
327struct UModSimplification final : OpRewritePattern<spirv::UModOp> {
328 using Base::Base;
329
330 LogicalResult matchAndRewrite(spirv::UModOp umodOp,
331 PatternRewriter &rewriter) const override {
332 auto prevUMod = umodOp.getOperand(0).getDefiningOp<spirv::UModOp>();
333 if (!prevUMod)
334 return failure();
335
336 TypedAttr prevValue;
337 TypedAttr currValue;
338 if (!matchPattern(prevUMod.getOperand(1), m_Constant(&prevValue)) ||
339 !matchPattern(umodOp.getOperand(1), m_Constant(&currValue)))
340 return failure();
341
342 // Ensure that previous divisor is a multiple of the current divisor. If
343 // not, fail the transformation.
344 bool isApplicable = false;
345 if (auto prevInt = dyn_cast<IntegerAttr>(prevValue)) {
346 auto currInt = cast<IntegerAttr>(currValue);
347 isApplicable = prevInt.getValue().urem(currInt.getValue()) == 0;
348 } else if (auto prevVec = dyn_cast<DenseElementsAttr>(prevValue)) {
349 auto currVec = cast<DenseElementsAttr>(currValue);
350 isApplicable = llvm::all_of(llvm::zip_equal(prevVec.getValues<APInt>(),
351 currVec.getValues<APInt>()),
352 [](const auto &pair) {
353 auto &[prev, curr] = pair;
354 return prev.urem(curr) == 0;
355 });
356 }
357
358 if (!isApplicable)
359 return failure();
360
361 // The transformation is safe. Replace the existing UMod operation with a
362 // new UMod operation, using the original dividend and the current divisor.
363 rewriter.replaceOpWithNewOp<spirv::UModOp>(
364 umodOp, umodOp.getType(), prevUMod.getOperand(0), umodOp.getOperand(1));
365
366 return success();
367 }
368};
369
370void spirv::UModOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
371 MLIRContext *context) {
372 patterns.add<UModSimplification>(context);
373}
374
375//===----------------------------------------------------------------------===//
376// spirv.BitcastOp
377//===----------------------------------------------------------------------===//
378
379OpFoldResult spirv::BitcastOp::fold(FoldAdaptor /*adaptor*/) {
380 Value curInput = getOperand();
381 if (getType() == curInput.getType())
382 return curInput;
383
384 // Look through nested bitcasts.
385 if (auto prevCast = curInput.getDefiningOp<spirv::BitcastOp>()) {
386 Value prevInput = prevCast.getOperand();
387 if (prevInput.getType() == getType())
388 return prevInput;
389
390 getOperandMutable().assign(prevInput);
391 return getResult();
392 }
393
394 // TODO(kuhar): Consider constant-folding the operand attribute.
395 return {};
396}
397
398//===----------------------------------------------------------------------===//
399// spirv.CompositeExtractOp
400//===----------------------------------------------------------------------===//
401
402OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) {
403 Value compositeOp = getComposite();
404
405 while (auto insertOp =
406 compositeOp.getDefiningOp<spirv::CompositeInsertOp>()) {
407 if (getIndices() == insertOp.getIndices())
408 return insertOp.getObject();
409 compositeOp = insertOp.getComposite();
410 }
411
412 if (auto constructOp =
413 compositeOp.getDefiningOp<spirv::CompositeConstructOp>()) {
414 auto type = cast<spirv::CompositeType>(constructOp.getType());
415 if (getIndices().size() == 1 &&
416 constructOp.getConstituents().size() == type.getNumElements()) {
417 auto i = cast<IntegerAttr>(*getIndices().begin());
418 if (i.getValue().getSExtValue() <
419 static_cast<int64_t>(constructOp.getConstituents().size()))
420 return constructOp.getConstituents()[i.getValue().getSExtValue()];
421 }
422 }
423
424 auto indexVector = llvm::map_to_vector(getIndices(), [](Attribute attr) {
425 return static_cast<unsigned>(cast<IntegerAttr>(attr).getInt());
426 });
427 return extractCompositeElement(adaptor.getComposite(), indexVector);
428}
429
430//===----------------------------------------------------------------------===//
431// spirv.Constant
432//===----------------------------------------------------------------------===//
433
434OpFoldResult spirv::ConstantOp::fold(FoldAdaptor /*adaptor*/) {
435 return getValue();
436}
437
438//===----------------------------------------------------------------------===//
439// spirv.IAdd
440//===----------------------------------------------------------------------===//
441
442OpFoldResult spirv::IAddOp::fold(FoldAdaptor adaptor) {
443 // x + 0 = x
444 if (matchPattern(getOperand2(), m_Zero()))
445 return getOperand1();
446
447 // According to the SPIR-V spec:
448 //
449 // The resulting value will equal the low-order N bits of the correct result
450 // R, where N is the component width and R is computed with enough precision
451 // to avoid overflow and underflow.
453 adaptor.getOperands(),
454 [](APInt a, const APInt &b) { return std::move(a) + b; });
455}
456
457//===----------------------------------------------------------------------===//
458// spirv.IMul
459//===----------------------------------------------------------------------===//
460
461OpFoldResult spirv::IMulOp::fold(FoldAdaptor adaptor) {
462 // x * 0 == 0
463 if (matchPattern(getOperand2(), m_Zero()))
464 return getOperand2();
465 // x * 1 = x
466 if (matchPattern(getOperand2(), m_One()))
467 return getOperand1();
468
469 // According to the SPIR-V spec:
470 //
471 // The resulting value will equal the low-order N bits of the correct result
472 // R, where N is the component width and R is computed with enough precision
473 // to avoid overflow and underflow.
475 adaptor.getOperands(),
476 [](const APInt &a, const APInt &b) { return a * b; });
477}
478
479//===----------------------------------------------------------------------===//
480// spirv.ISub
481//===----------------------------------------------------------------------===//
482
483OpFoldResult spirv::ISubOp::fold(FoldAdaptor adaptor) {
484 // x - x = 0
485 if (getOperand1() == getOperand2())
487
488 // According to the SPIR-V spec:
489 //
490 // The resulting value will equal the low-order N bits of the correct result
491 // R, where N is the component width and R is computed with enough precision
492 // to avoid overflow and underflow.
494 adaptor.getOperands(),
495 [](APInt a, const APInt &b) { return std::move(a) - b; });
496}
497
498//===----------------------------------------------------------------------===//
499// spirv.SDiv
500//===----------------------------------------------------------------------===//
501
502OpFoldResult spirv::SDivOp::fold(FoldAdaptor adaptor) {
503 // sdiv (x, 1) = x
504 if (matchPattern(getOperand2(), m_One()))
505 return getOperand1();
506
507 // According to the SPIR-V spec:
508 //
509 // Signed-integer division of Operand 1 divided by Operand 2.
510 // Results are computed per component. Behavior is undefined if Operand 2 is
511 // 0. Behavior is undefined if Operand 2 is -1 and Operand 1 is the minimum
512 // representable value for the operands' type, causing signed overflow.
513 //
514 // So don't fold during undefined behavior.
515 bool div0OrOverflow = false;
517 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
518 if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
519 div0OrOverflow = true;
520 return a;
521 }
522 return a.sdiv(b);
523 });
524 return div0OrOverflow ? Attribute() : res;
525}
526
527//===----------------------------------------------------------------------===//
528// spirv.SMod
529//===----------------------------------------------------------------------===//
530
531OpFoldResult spirv::SModOp::fold(FoldAdaptor adaptor) {
532 // smod (x, 1) = 0
533 if (matchPattern(getOperand2(), m_One()))
535
536 // According to SPIR-V spec:
537 //
538 // Signed remainder operation for the remainder whose sign matches the sign
539 // of Operand 2. Behavior is undefined if Operand 2 is 0. Behavior is
540 // undefined if Operand 2 is -1 and Operand 1 is the minimum representable
541 // value for the operands' type, causing signed overflow. Otherwise, the
542 // result is the remainder r of Operand 1 divided by Operand 2 where if
543 // r ≠ 0, the sign of r is the same as the sign of Operand 2.
544 //
545 // So don't fold during undefined behavior
546 bool div0OrOverflow = false;
548 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
549 if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
550 div0OrOverflow = true;
551 return a;
552 }
553 APInt c = a.abs().urem(b.abs());
554 if (c.isZero())
555 return c;
556 if (b.isNegative()) {
557 APInt zero = APInt::getZero(c.getBitWidth());
558 return a.isNegative() ? (zero - c) : (b + c);
559 }
560 return a.isNegative() ? (b - c) : c;
561 });
562 return div0OrOverflow ? Attribute() : res;
563}
564
565//===----------------------------------------------------------------------===//
566// spirv.SRem
567//===----------------------------------------------------------------------===//
568
569OpFoldResult spirv::SRemOp::fold(FoldAdaptor adaptor) {
570 // x % 1 = 0
571 if (matchPattern(getOperand2(), m_One()))
573
574 // According to SPIR-V spec:
575 //
576 // Signed remainder operation for the remainder whose sign matches the sign
577 // of Operand 1. Behavior is undefined if Operand 2 is 0. Behavior is
578 // undefined if Operand 2 is -1 and Operand 1 is the minimum representable
579 // value for the operands' type, causing signed overflow. Otherwise, the
580 // result is the remainder r of Operand 1 divided by Operand 2 where if
581 // r ≠ 0, the sign of r is the same as the sign of Operand 1.
582
583 // Don't fold if it would do undefined behavior.
584 bool div0OrOverflow = false;
586 adaptor.getOperands(), [&](APInt a, const APInt &b) {
587 if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
588 div0OrOverflow = true;
589 return a;
590 }
591 return a.srem(b);
592 });
593 return div0OrOverflow ? Attribute() : res;
594}
595
596//===----------------------------------------------------------------------===//
597// spirv.UDiv
598//===----------------------------------------------------------------------===//
599
600OpFoldResult spirv::UDivOp::fold(FoldAdaptor adaptor) {
601 // udiv (x, 1) = x
602 if (matchPattern(getOperand2(), m_One()))
603 return getOperand1();
604
605 // According to the SPIR-V spec:
606 //
607 // Unsigned-integer division of Operand 1 divided by Operand 2. Behavior is
608 // undefined if Operand 2 is 0.
609 //
610 // So don't fold during undefined behavior.
611 bool div0 = false;
613 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
614 if (div0 || b.isZero()) {
615 div0 = true;
616 return a;
617 }
618 return a.udiv(b);
619 });
620 return div0 ? Attribute() : res;
621}
622
623//===----------------------------------------------------------------------===//
624// spirv.UMod
625//===----------------------------------------------------------------------===//
626
627OpFoldResult spirv::UModOp::fold(FoldAdaptor adaptor) {
628 // umod (x, 1) = 0
629 if (matchPattern(getOperand2(), m_One()))
631
632 // According to the SPIR-V spec:
633 //
634 // Unsigned modulo operation of Operand 1 modulo Operand 2. Behavior is
635 // undefined if Operand 2 is 0.
636 //
637 // So don't fold during undefined behavior.
638 bool div0 = false;
640 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
641 if (div0 || b.isZero()) {
642 div0 = true;
643 return a;
644 }
645 return a.urem(b);
646 });
647 return div0 ? Attribute() : res;
648}
649
650//===----------------------------------------------------------------------===//
651// spirv.SNegate
652//===----------------------------------------------------------------------===//
653
654OpFoldResult spirv::SNegateOp::fold(FoldAdaptor adaptor) {
655 // -(-x) = 0 - (0 - x) = x
656 auto op = getOperand();
657 if (auto negateOp = op.getDefiningOp<spirv::SNegateOp>())
658 return negateOp->getOperand(0);
659
660 // According to the SPIR-V spec:
661 //
662 // Signed-integer subtract of Operand from zero.
664 adaptor.getOperands(), [](const APInt &a) {
665 APInt zero = APInt::getZero(a.getBitWidth());
666 return zero - a;
667 });
668}
669
670//===----------------------------------------------------------------------===//
671// spirv.NotOp
672//===----------------------------------------------------------------------===//
673
674OpFoldResult spirv::NotOp::fold(spirv::NotOp::FoldAdaptor adaptor) {
675 // !(!x) = x
676 auto op = getOperand();
677 if (auto notOp = op.getDefiningOp<spirv::NotOp>())
678 return notOp->getOperand(0);
679
680 // According to the SPIR-V spec:
681 //
682 // Complement the bits of Operand.
683 return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(), [&](APInt a) {
684 a.flipAllBits();
685 return a;
686 });
687}
688
689//===----------------------------------------------------------------------===//
690// spirv.LogicalAnd
691//===----------------------------------------------------------------------===//
692
693OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) {
694 if (std::optional<bool> rhs =
695 getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
696 // x && true = x
697 if (*rhs)
698 return getOperand1();
699
700 // x && false = false
701 if (!*rhs)
702 return adaptor.getOperand2();
703 }
704
705 return Attribute();
706}
707
708//===----------------------------------------------------------------------===//
709// spirv.LogicalEqualOp
710//===----------------------------------------------------------------------===//
711
713spirv::LogicalEqualOp::fold(spirv::LogicalEqualOp::FoldAdaptor adaptor) {
714 // x == x -> true
715 if (getOperand1() == getOperand2()) {
716 auto trueAttr = BoolAttr::get(getContext(), true);
717 if (isa<IntegerType>(getType()))
718 return trueAttr;
719 if (auto vecTy = dyn_cast<VectorType>(getType()))
720 return SplatElementsAttr::get(vecTy, trueAttr);
721 }
722
724 adaptor.getOperands(), [](const APInt &a, const APInt &b) {
725 return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);
726 });
727}
728
729//===----------------------------------------------------------------------===//
730// spirv.LogicalNotEqualOp
731//===----------------------------------------------------------------------===//
732
733OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) {
734 if (std::optional<bool> rhs =
735 getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
736 // x != false -> x
737 if (!rhs.value())
738 return getOperand1();
739 }
740
741 // x == x -> false
742 if (getOperand1() == getOperand2()) {
743 auto falseAttr = BoolAttr::get(getContext(), false);
744 if (isa<IntegerType>(getType()))
745 return falseAttr;
746 if (auto vecTy = dyn_cast<VectorType>(getType()))
747 return SplatElementsAttr::get(vecTy, falseAttr);
748 }
749
751 adaptor.getOperands(), [](const APInt &a, const APInt &b) {
752 return a == b ? APInt::getZero(1) : APInt::getAllOnes(1);
753 });
754}
755
756//===----------------------------------------------------------------------===//
757// spirv.LogicalNot
758//===----------------------------------------------------------------------===//
759
760OpFoldResult spirv::LogicalNotOp::fold(FoldAdaptor adaptor) {
761 // !(!x) = x
762 auto op = getOperand();
763 if (auto notOp = op.getDefiningOp<spirv::LogicalNotOp>())
764 return notOp->getOperand(0);
765
766 // According to the SPIR-V spec:
767 //
768 // Complement the bits of Operand.
769 return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(),
770 [](const APInt &a) {
771 APInt zero = APInt::getZero(1);
772 return a == 1 ? zero : (zero + 1);
773 });
774}
775
776void spirv::LogicalNotOp::getCanonicalizationPatterns(
777 RewritePatternSet &results, MLIRContext *context) {
778 results
779 .add<ConvertLogicalNotOfIEqual, ConvertLogicalNotOfINotEqual,
780 ConvertLogicalNotOfLogicalEqual, ConvertLogicalNotOfLogicalNotEqual>(
781 context);
782}
783
784//===----------------------------------------------------------------------===//
785// spirv.LogicalOr
786//===----------------------------------------------------------------------===//
787
788OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
789 if (auto rhs = getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
790 if (*rhs) {
791 // x || true = true
792 return adaptor.getOperand2();
793 }
794
795 if (!*rhs) {
796 // x || false = x
797 return getOperand1();
798 }
799 }
800
801 return Attribute();
802}
803
804//===----------------------------------------------------------------------===//
805// spirv.SelectOp
806//===----------------------------------------------------------------------===//
807
808OpFoldResult spirv::SelectOp::fold(FoldAdaptor adaptor) {
809 // spirv.Select _ x x -> x
810 Value trueVals = getTrueValue();
811 Value falseVals = getFalseValue();
812 if (trueVals == falseVals)
813 return trueVals;
814
815 ArrayRef<Attribute> operands = adaptor.getOperands();
816
817 // spirv.Select true x y -> x
818 // spirv.Select false x y -> y
819 if (auto boolAttr = getScalarOrSplatBoolAttr(operands[0]))
820 return *boolAttr ? trueVals : falseVals;
821
822 // Check that all the operands are constant
823 if (!operands[0] || !operands[1] || !operands[2])
824 return Attribute();
825
826 // Note: getScalarOrSplatBoolAttr will always return a boolAttr if we are in
827 // the scalar case. Hence, we are only required to consider the case of
828 // DenseElementsAttr in foldSelectOp.
829 auto condAttrs = dyn_cast<DenseElementsAttr>(operands[0]);
830 auto trueAttrs = dyn_cast<DenseElementsAttr>(operands[1]);
831 auto falseAttrs = dyn_cast<DenseElementsAttr>(operands[2]);
832 if (!condAttrs || !trueAttrs || !falseAttrs)
833 return Attribute();
834
835 auto elementResults = llvm::to_vector<4>(trueAttrs.getValues<Attribute>());
836 auto iters = llvm::zip_equal(elementResults, condAttrs.getValues<BoolAttr>(),
837 falseAttrs.getValues<Attribute>());
838 for (auto [result, cond, falseRes] : iters) {
839 if (!cond.getValue())
840 result = falseRes;
841 }
842
843 auto resultType = trueAttrs.getType();
844 return DenseElementsAttr::get(cast<ShapedType>(resultType), elementResults);
845}
846
847//===----------------------------------------------------------------------===//
848// spirv.IEqualOp
849//===----------------------------------------------------------------------===//
850
851OpFoldResult spirv::IEqualOp::fold(spirv::IEqualOp::FoldAdaptor adaptor) {
852 // x == x -> true
853 if (getOperand1() == getOperand2()) {
854 auto trueAttr = BoolAttr::get(getContext(), true);
855 if (isa<IntegerType>(getType()))
856 return trueAttr;
857 if (auto vecTy = dyn_cast<VectorType>(getType()))
858 return SplatElementsAttr::get(vecTy, trueAttr);
859 }
860
862 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
863 return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);
864 });
865}
866
867//===----------------------------------------------------------------------===//
868// spirv.INotEqualOp
869//===----------------------------------------------------------------------===//
870
871OpFoldResult spirv::INotEqualOp::fold(spirv::INotEqualOp::FoldAdaptor adaptor) {
872 // x == x -> false
873 if (getOperand1() == getOperand2()) {
874 auto falseAttr = BoolAttr::get(getContext(), false);
875 if (isa<IntegerType>(getType()))
876 return falseAttr;
877 if (auto vecTy = dyn_cast<VectorType>(getType()))
878 return SplatElementsAttr::get(vecTy, falseAttr);
879 }
880
882 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
883 return a == b ? APInt::getZero(1) : APInt::getAllOnes(1);
884 });
885}
886
887//===----------------------------------------------------------------------===//
888// spirv.SGreaterThan
889//===----------------------------------------------------------------------===//
890
892spirv::SGreaterThanOp::fold(spirv::SGreaterThanOp::FoldAdaptor adaptor) {
893 // x == x -> false
894 if (getOperand1() == getOperand2()) {
895 auto falseAttr = BoolAttr::get(getContext(), false);
896 if (isa<IntegerType>(getType()))
897 return falseAttr;
898 if (auto vecTy = dyn_cast<VectorType>(getType()))
899 return SplatElementsAttr::get(vecTy, falseAttr);
900 }
901
903 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
904 return a.sgt(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
905 });
906}
907
908//===----------------------------------------------------------------------===//
909// spirv.SGreaterThanEqual
910//===----------------------------------------------------------------------===//
911
912OpFoldResult spirv::SGreaterThanEqualOp::fold(
913 spirv::SGreaterThanEqualOp::FoldAdaptor adaptor) {
914 // x == x -> true
915 if (getOperand1() == getOperand2()) {
916 auto trueAttr = BoolAttr::get(getContext(), true);
917 if (isa<IntegerType>(getType()))
918 return trueAttr;
919 if (auto vecTy = dyn_cast<VectorType>(getType()))
920 return SplatElementsAttr::get(vecTy, trueAttr);
921 }
922
924 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
925 return a.sge(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
926 });
927}
928
929//===----------------------------------------------------------------------===//
930// spirv.UGreaterThan
931//===----------------------------------------------------------------------===//
932
934spirv::UGreaterThanOp::fold(spirv::UGreaterThanOp::FoldAdaptor adaptor) {
935 // x == x -> false
936 if (getOperand1() == getOperand2()) {
937 auto falseAttr = BoolAttr::get(getContext(), false);
938 if (isa<IntegerType>(getType()))
939 return falseAttr;
940 if (auto vecTy = dyn_cast<VectorType>(getType()))
941 return SplatElementsAttr::get(vecTy, falseAttr);
942 }
943
945 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
946 return a.ugt(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
947 });
948}
949
950//===----------------------------------------------------------------------===//
951// spirv.UGreaterThanEqual
952//===----------------------------------------------------------------------===//
953
954OpFoldResult spirv::UGreaterThanEqualOp::fold(
955 spirv::UGreaterThanEqualOp::FoldAdaptor adaptor) {
956 // x == x -> true
957 if (getOperand1() == getOperand2()) {
958 auto trueAttr = BoolAttr::get(getContext(), true);
959 if (isa<IntegerType>(getType()))
960 return trueAttr;
961 if (auto vecTy = dyn_cast<VectorType>(getType()))
962 return SplatElementsAttr::get(vecTy, trueAttr);
963 }
964
966 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
967 return a.uge(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
968 });
969}
970
971//===----------------------------------------------------------------------===//
972// spirv.SLessThan
973//===----------------------------------------------------------------------===//
974
975OpFoldResult spirv::SLessThanOp::fold(spirv::SLessThanOp::FoldAdaptor adaptor) {
976 // x == x -> false
977 if (getOperand1() == getOperand2()) {
978 auto falseAttr = BoolAttr::get(getContext(), false);
979 if (isa<IntegerType>(getType()))
980 return falseAttr;
981 if (auto vecTy = dyn_cast<VectorType>(getType()))
982 return SplatElementsAttr::get(vecTy, falseAttr);
983 }
984
986 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
987 return a.slt(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
988 });
989}
990
991//===----------------------------------------------------------------------===//
992// spirv.SLessThanEqual
993//===----------------------------------------------------------------------===//
994
996spirv::SLessThanEqualOp::fold(spirv::SLessThanEqualOp::FoldAdaptor adaptor) {
997 // x == x -> true
998 if (getOperand1() == getOperand2()) {
999 auto trueAttr = BoolAttr::get(getContext(), true);
1000 if (isa<IntegerType>(getType()))
1001 return trueAttr;
1002 if (auto vecTy = dyn_cast<VectorType>(getType()))
1003 return SplatElementsAttr::get(vecTy, trueAttr);
1004 }
1005
1007 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
1008 return a.sle(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
1009 });
1010}
1011
1012//===----------------------------------------------------------------------===//
1013// spirv.ULessThan
1014//===----------------------------------------------------------------------===//
1015
1016OpFoldResult spirv::ULessThanOp::fold(spirv::ULessThanOp::FoldAdaptor adaptor) {
1017 // x == x -> false
1018 if (getOperand1() == getOperand2()) {
1019 auto falseAttr = BoolAttr::get(getContext(), false);
1020 if (isa<IntegerType>(getType()))
1021 return falseAttr;
1022 if (auto vecTy = dyn_cast<VectorType>(getType()))
1023 return SplatElementsAttr::get(vecTy, falseAttr);
1024 }
1025
1027 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
1028 return a.ult(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
1029 });
1030}
1031
1032//===----------------------------------------------------------------------===//
1033// spirv.ULessThanEqual
1034//===----------------------------------------------------------------------===//
1035
1037spirv::ULessThanEqualOp::fold(spirv::ULessThanEqualOp::FoldAdaptor adaptor) {
1038 // x == x -> true
1039 if (getOperand1() == getOperand2()) {
1040 auto trueAttr = BoolAttr::get(getContext(), true);
1041 if (isa<IntegerType>(getType()))
1042 return trueAttr;
1043 if (auto vecTy = dyn_cast<VectorType>(getType()))
1044 return SplatElementsAttr::get(vecTy, trueAttr);
1045 }
1046
1048 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
1049 return a.ule(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
1050 });
1051}
1052
1053//===----------------------------------------------------------------------===//
1054// spirv.ShiftLeftLogical
1055//===----------------------------------------------------------------------===//
1056
1057OpFoldResult spirv::ShiftLeftLogicalOp::fold(
1058 spirv::ShiftLeftLogicalOp::FoldAdaptor adaptor) {
1059 // x << 0 -> x
1060 if (matchPattern(adaptor.getOperand2(), m_Zero())) {
1061 return getOperand1();
1062 }
1063
1064 // Unfortunately due to below undefined behaviour can't fold 0 for Base.
1065
1066 // Results are computed per component, and within each component, per bit...
1067 //
1068 // The result is undefined if Shift is greater than or equal to the bit width
1069 // of the components of Base.
1070 //
1071 // So we can use the APInt << method, but don't fold if undefined behaviour.
1072 bool shiftToLarge = false;
1074 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
1075 if (shiftToLarge || b.uge(a.getBitWidth())) {
1076 shiftToLarge = true;
1077 return a;
1078 }
1079 return a << b;
1080 });
1081 return shiftToLarge ? Attribute() : res;
1082}
1083
1084//===----------------------------------------------------------------------===//
1085// spirv.ShiftRightArithmetic
1086//===----------------------------------------------------------------------===//
1087
1088OpFoldResult spirv::ShiftRightArithmeticOp::fold(
1089 spirv::ShiftRightArithmeticOp::FoldAdaptor adaptor) {
1090 // x >> 0 -> x
1091 if (matchPattern(adaptor.getOperand2(), m_Zero())) {
1092 return getOperand1();
1093 }
1094
1095 // Unfortunately due to below undefined behaviour can't fold 0, -1 for Base.
1096
1097 // Results are computed per component, and within each component, per bit...
1098 //
1099 // The result is undefined if Shift is greater than or equal to the bit width
1100 // of the components of Base.
1101 //
1102 // So we can use the APInt ashr method, but don't fold if undefined behaviour.
1103 bool shiftToLarge = false;
1105 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
1106 if (shiftToLarge || b.uge(a.getBitWidth())) {
1107 shiftToLarge = true;
1108 return a;
1109 }
1110 return a.ashr(b);
1111 });
1112 return shiftToLarge ? Attribute() : res;
1113}
1114
1115//===----------------------------------------------------------------------===//
1116// spirv.ShiftRightLogical
1117//===----------------------------------------------------------------------===//
1118
1119OpFoldResult spirv::ShiftRightLogicalOp::fold(
1120 spirv::ShiftRightLogicalOp::FoldAdaptor adaptor) {
1121 // x >> 0 -> x
1122 if (matchPattern(adaptor.getOperand2(), m_Zero())) {
1123 return getOperand1();
1124 }
1125
1126 // Unfortunately due to below undefined behaviour can't fold 0 for Base.
1127
1128 // Results are computed per component, and within each component, per bit...
1129 //
1130 // The result is undefined if Shift is greater than or equal to the bit width
1131 // of the components of Base.
1132 //
1133 // So we can use the APInt lshr method, but don't fold if undefined behaviour.
1134 bool shiftToLarge = false;
1136 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
1137 if (shiftToLarge || b.uge(a.getBitWidth())) {
1138 shiftToLarge = true;
1139 return a;
1140 }
1141 return a.lshr(b);
1142 });
1143 return shiftToLarge ? Attribute() : res;
1144}
1145
1146//===----------------------------------------------------------------------===//
1147// spirv.BitwiseAndOp
1148//===----------------------------------------------------------------------===//
1149
1151spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) {
1152 // x & x -> x
1153 if (getOperand1() == getOperand2()) {
1154 return getOperand1();
1155 }
1156
1157 APInt rhsMask;
1158 if (matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask))) {
1159 // x & 0 -> 0
1160 if (rhsMask.isZero())
1161 return getOperand2();
1162
1163 // x & <all ones> -> x
1164 if (rhsMask.isAllOnes())
1165 return getOperand1();
1166
1167 // (UConvert x : iN to iK) & <mask with N low bits set> -> UConvert x
1168 if (auto zext = getOperand1().getDefiningOp<spirv::UConvertOp>()) {
1169 int valueBits =
1171 if (rhsMask.zextOrTrunc(valueBits).isAllOnes())
1172 return getOperand1();
1173 }
1174 }
1175
1176 // According to the SPIR-V spec:
1177 //
1178 // Type is a scalar or vector of integer type.
1179 // Results are computed per component, and within each component, per bit.
1180 // So we can use the APInt & method.
1182 adaptor.getOperands(),
1183 [](const APInt &a, const APInt &b) { return a & b; });
1184}
1185
1186//===----------------------------------------------------------------------===//
1187// spirv.BitwiseOrOp
1188//===----------------------------------------------------------------------===//
1189
1190OpFoldResult spirv::BitwiseOrOp::fold(spirv::BitwiseOrOp::FoldAdaptor adaptor) {
1191 // x | x -> x
1192 if (getOperand1() == getOperand2()) {
1193 return getOperand1();
1194 }
1195
1196 APInt rhsMask;
1197 if (matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask))) {
1198 // x | 0 -> x
1199 if (rhsMask.isZero())
1200 return getOperand1();
1201
1202 // x | <all ones> -> <all ones>
1203 if (rhsMask.isAllOnes())
1204 return getOperand2();
1205 }
1206
1207 // According to the SPIR-V spec:
1208 //
1209 // Type is a scalar or vector of integer type.
1210 // Results are computed per component, and within each component, per bit.
1211 // So we can use the APInt | method.
1213 adaptor.getOperands(),
1214 [](const APInt &a, const APInt &b) { return a | b; });
1215}
1216
1217//===----------------------------------------------------------------------===//
1218// spirv.BitwiseXorOp
1219//===----------------------------------------------------------------------===//
1220
1222spirv::BitwiseXorOp::fold(spirv::BitwiseXorOp::FoldAdaptor adaptor) {
1223 // x ^ 0 -> x
1224 if (matchPattern(adaptor.getOperand2(), m_Zero())) {
1225 return getOperand1();
1226 }
1227
1228 // x ^ x -> 0
1229 if (getOperand1() == getOperand2())
1231
1232 // According to the SPIR-V spec:
1233 //
1234 // Type is a scalar or vector of integer type.
1235 // Results are computed per component, and within each component, per bit.
1236 // So we can use the APInt ^ method.
1238 adaptor.getOperands(),
1239 [](const APInt &a, const APInt &b) { return a ^ b; });
1240}
1241
1242//===----------------------------------------------------------------------===//
1243// spirv.mlir.selection
1244//===----------------------------------------------------------------------===//
1245
1246namespace {
1247// Blocks from the given `spirv.mlir.selection` operation must satisfy the
1248// following layout:
1249//
1250// +-----------------------------------------------+
1251// | header block |
1252// | spirv.BranchConditionalOp %cond, ^case0, ^case1 |
1253// +-----------------------------------------------+
1254// / \
1255// ...
1256//
1257//
1258// +------------------------+ +------------------------+
1259// | case #0 | | case #1 |
1260// | spirv.Store %ptr %value0 | | spirv.Store %ptr %value1 |
1261// | spirv.Branch ^merge | | spirv.Branch ^merge |
1262// +------------------------+ +------------------------+
1263//
1264//
1265// ...
1266// \ /
1267// v
1268// +-------------+
1269// | merge block |
1270// +-------------+
1271//
1272struct ConvertSelectionOpToSelect final : OpRewritePattern<spirv::SelectionOp> {
1273 using Base::Base;
1274
1275 LogicalResult matchAndRewrite(spirv::SelectionOp selectionOp,
1276 PatternRewriter &rewriter) const override {
1277 Operation *op = selectionOp.getOperation();
1278 Region &body = op->getRegion(0);
1279 // Verifier allows an empty region for `spirv.mlir.selection`.
1280 if (body.empty()) {
1281 return failure();
1282 }
1283
1284 // Check that region consists of 4 blocks:
1285 // header block, `true` block, `false` block and merge block.
1286 if (llvm::range_size(body) != 4) {
1287 return failure();
1288 }
1289
1290 Block *headerBlock = selectionOp.getHeaderBlock();
1291 if (!onlyContainsBranchConditionalOp(headerBlock)) {
1292 return failure();
1293 }
1294
1295 auto brConditionalOp =
1296 cast<spirv::BranchConditionalOp>(headerBlock->front());
1297
1298 Block *trueBlock = brConditionalOp.getSuccessor(0);
1299 Block *falseBlock = brConditionalOp.getSuccessor(1);
1300 Block *mergeBlock = selectionOp.getMergeBlock();
1301
1302 if (failed(canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)))
1303 return failure();
1304
1305 Value trueValue = getSrcValue(trueBlock);
1306 Value falseValue = getSrcValue(falseBlock);
1307 Value ptrValue = getDstPtr(trueBlock);
1308 auto storeOpAttributes =
1309 cast<spirv::StoreOp>(trueBlock->front())->getAttrs();
1310
1311 auto selectOp = spirv::SelectOp::create(
1312 rewriter, selectionOp.getLoc(), trueValue.getType(),
1313 brConditionalOp.getCondition(), trueValue, falseValue);
1314 spirv::StoreOp::create(rewriter, selectOp.getLoc(), ptrValue,
1315 selectOp.getResult(), storeOpAttributes);
1316
1317 // `spirv.mlir.selection` is not needed anymore.
1318 rewriter.eraseOp(op);
1319 return success();
1320 }
1321
1322private:
1323 // Checks that given blocks follow the following rules:
1324 // 1. Each conditional block consists of two operations, the first operation
1325 // is a `spirv.Store` and the last operation is a `spirv.Branch`.
1326 // 2. Each `spirv.Store` uses the same pointer and the same memory attributes.
1327 // 3. A control flow goes into the given merge block from the given
1328 // conditional blocks.
1329 LogicalResult canCanonicalizeSelection(Block *trueBlock, Block *falseBlock,
1330 Block *mergeBlock) const;
1331
1332 bool onlyContainsBranchConditionalOp(Block *block) const {
1333 return llvm::hasSingleElement(*block) &&
1334 isa<spirv::BranchConditionalOp>(block->front());
1335 }
1336
1337 bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs) const {
1338 return lhs->getDiscardableAttrDictionary() ==
1339 rhs->getDiscardableAttrDictionary() &&
1340 lhs.getProperties() == rhs.getProperties();
1341 }
1342
1343 // Returns a source value for the given block.
1344 Value getSrcValue(Block *block) const {
1345 auto storeOp = cast<spirv::StoreOp>(block->front());
1346 return storeOp.getValue();
1347 }
1348
1349 // Returns a destination value for the given block.
1350 Value getDstPtr(Block *block) const {
1351 auto storeOp = cast<spirv::StoreOp>(block->front());
1352 return storeOp.getPtr();
1353 }
1354};
1355
1356LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
1357 Block *trueBlock, Block *falseBlock, Block *mergeBlock) const {
1358 // Each block must consists of 2 operations.
1359 if (llvm::range_size(*trueBlock) != 2 || llvm::range_size(*falseBlock) != 2) {
1360 return failure();
1361 }
1362
1363 auto trueBrStoreOp = dyn_cast<spirv::StoreOp>(trueBlock->front());
1364 auto trueBrBranchOp =
1365 dyn_cast<spirv::BranchOp>(*std::next(trueBlock->begin()));
1366 auto falseBrStoreOp = dyn_cast<spirv::StoreOp>(falseBlock->front());
1367 auto falseBrBranchOp =
1368 dyn_cast<spirv::BranchOp>(*std::next(falseBlock->begin()));
1369
1370 if (!trueBrStoreOp || !trueBrBranchOp || !falseBrStoreOp ||
1371 !falseBrBranchOp) {
1372 return failure();
1373 }
1374
1375 // Checks that given type is valid for `spirv.SelectOp`.
1376 // According to SPIR-V spec:
1377 // "Before version 1.4, Result Type must be a pointer, scalar, or vector.
1378 // Starting with version 1.4, Result Type can additionally be a composite type
1379 // other than a vector."
1380 bool isScalarOrVector =
1381 cast<spirv::SPIRVType>(trueBrStoreOp.getValue().getType())
1382 .isScalarOrVector();
1383
1384 // Check that each `spirv.Store` uses the same pointer, memory access
1385 // attributes and a valid type of the value.
1386 if ((trueBrStoreOp.getPtr() != falseBrStoreOp.getPtr()) ||
1387 !isSameAttrList(trueBrStoreOp, falseBrStoreOp) || !isScalarOrVector) {
1388 return failure();
1389 }
1390
1391 if ((trueBrBranchOp->getSuccessor(0) != mergeBlock) ||
1392 (falseBrBranchOp->getSuccessor(0) != mergeBlock)) {
1393 return failure();
1394 }
1395
1396 return success();
1397}
1398} // namespace
1399
1400void spirv::SelectionOp::getCanonicalizationPatterns(RewritePatternSet &results,
1401 MLIRContext *context) {
1402 results.add<ConvertSelectionOpToSelect>(context);
1403}
return success()
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static uint64_t zext(uint32_t arg)
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
if(!isCopyOut)
b getContext())
static Attribute extractCompositeElement(Attribute composite, ArrayRef< unsigned > indices)
MulExtendedFold< spirv::UMulExtendedOp, false > UMulExtendedOpFold
MulExtendedFold< spirv::SMulExtendedOp, true > SMulExtendedOpFold
static std::optional< bool > getScalarOrSplatBoolAttr(Attribute attr)
Returns the boolean value under the hood if the given boolAttr is a scalar or splat vector bool const...
static bool isDivZeroOrOverflow(const APInt &a, const APInt &b)
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
Operation & front()
Definition Block.h:163
iterator begin()
Definition Block.h:153
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:328
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class represents a single result from folding an operation.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:686
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
bool empty()
Definition Region.h:60
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition Matchers.h:442
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition Matchers.h:478
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
Attribute constFoldUnaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
LogicalResult matchAndRewrite(spirv::IAddCarryOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(MulOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::UModOp umodOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::UMulExtendedOp op, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})