SPIRVCanonicalization.cpp
1 //===- SPIRVCanonicalization.cpp - MLIR SPIR-V canonicalization patterns --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the folders and canonicalization patterns for SPIR-V ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include <optional>
14 #include <utility>
15 
16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17 
18 #include "mlir/Dialect/CommonFolders.h"
19 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
20 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
21 #include "mlir/IR/Matchers.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/SmallVectorExtras.h"
25 
26 using namespace mlir;
27 
28 //===----------------------------------------------------------------------===//
29 // Common utility functions
30 //===----------------------------------------------------------------------===//
31 
32 /// Returns the underlying boolean value if the given `attr` is a scalar or
33 /// splat vector bool constant.
34 static std::optional<bool> getScalarOrSplatBoolAttr(Attribute attr) {
35  if (!attr)
36  return std::nullopt;
37 
38  if (auto boolAttr = llvm::dyn_cast<BoolAttr>(attr))
39  return boolAttr.getValue();
40  if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(attr))
41  if (splatAttr.getElementType().isInteger(1))
42  return splatAttr.getSplatValue<bool>();
43  return std::nullopt;
44 }
45 
46 // Extracts an element from the given `composite` by following the given
47 // `indices`. Returns a null Attribute if an error occurs.
48 static Attribute extractCompositeElement(Attribute composite,
49  ArrayRef<unsigned> indices) {
50  // Check that given composite is a constant.
51  if (!composite)
52  return {};
53  // Return composite itself if we reach the end of the index chain.
54  if (indices.empty())
55  return composite;
56 
57  if (auto vector = llvm::dyn_cast<ElementsAttr>(composite)) {
58  assert(indices.size() == 1 && "must have exactly one index for a vector");
59  return vector.getValues<Attribute>()[indices[0]];
60  }
61 
62  if (auto array = llvm::dyn_cast<ArrayAttr>(composite)) {
63  assert(!indices.empty() && "must have at least one index for an array");
64  return extractCompositeElement(array.getValue()[indices[0]],
65  indices.drop_front());
66  }
67 
68  return {};
69 }
70 
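/// Returns true if signed division `a / b` would be undefined behaviour in
/// SPIR-V: division by zero, or the minimum signed value divided by -1
/// (signed overflow). Used below to block folding in those cases.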
71 static bool isDivZeroOrOverflow(const APInt &a, const APInt &b) {
72  bool div0 = b.isZero();
73  bool overflow = a.isMinSignedValue() && b.isAllOnes();
74 
75  return div0 || overflow;
76 }
77 
78 //===----------------------------------------------------------------------===//
79 // TableGen'erated canonicalizers
80 //===----------------------------------------------------------------------===//
81 
82 namespace {
83 #include "SPIRVCanonicalization.inc"
84 } // namespace
85 
86 //===----------------------------------------------------------------------===//
87 // spirv.AccessChainOp
88 //===----------------------------------------------------------------------===//
89 
90 namespace {
91 
92 /// Combines chained `spirv::AccessChainOp` operations into one
93 /// `spirv::AccessChainOp` operation.
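///
/// Illustrative example (types elided):
///   %0 = spirv.AccessChain %base[%i]
///   %1 = spirv.AccessChain %0[%j]
/// is rewritten to:
///   %1 = spirv.AccessChain %base[%i, %j]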
94 struct CombineChainedAccessChain final
95  : OpRewritePattern<spirv::AccessChainOp> {
96  using OpRewritePattern::OpRewritePattern;
97 
98  LogicalResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
99  PatternRewriter &rewriter) const override {
100  auto parentAccessChainOp =
101  accessChainOp.getBasePtr().getDefiningOp<spirv::AccessChainOp>();
102 
103  if (!parentAccessChainOp) {
104  return failure();
105  }
106 
107  // Combine indices.
108  SmallVector<Value, 4> indices(parentAccessChainOp.getIndices());
109  llvm::append_range(indices, accessChainOp.getIndices());
110 
111  rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
112  accessChainOp, parentAccessChainOp.getBasePtr(), indices);
113 
114  return success();
115  }
116 };
117 } // namespace
118 
119 void spirv::AccessChainOp::getCanonicalizationPatterns(
120  RewritePatternSet &results, MLIRContext *context) {
121  results.add<CombineChainedAccessChain>(context);
122 }
123 
124 //===----------------------------------------------------------------------===//
125 // spirv.IAddCarry
126 //===----------------------------------------------------------------------===//
127 
128 // We are required to use CompositeConstructOp to create a constant struct,
129 // as struct constants are not yet supported; hence we cannot do so in a fold.
130 struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
131  using OpRewritePattern::OpRewritePattern;
132 
133  LogicalResult matchAndRewrite(spirv::IAddCarryOp op,
134  PatternRewriter &rewriter) const override {
135  Location loc = op.getLoc();
136  Value lhs = op.getOperand1();
137  Value rhs = op.getOperand2();
138  Type constituentType = lhs.getType();
139 
140  // iaddcarry (x, 0) = <x, 0>
141  if (matchPattern(rhs, m_Zero())) {
142  Value constituents[2] = {lhs, rhs};
143  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
144  constituents);
145  return success();
146  }
147 
148  // According to the SPIR-V spec:
149  //
150  // Result Type must be from OpTypeStruct. The struct must have two
151  // members...
152  //
153  // Member 0 of the result gets the low-order bits (full component width) of
154  // the addition.
155  //
156  // Member 1 of the result gets the high-order (carry) bit of the result of
157  // the addition. That is, it gets the value 1 if the addition overflowed
158  // the component width, and 0 otherwise.
159  Attribute lhsAttr;
160  Attribute rhsAttr;
161  if (!matchPattern(lhs, m_Constant(&lhsAttr)) ||
162  !matchPattern(rhs, m_Constant(&rhsAttr)))
163  return failure();
164 
165  auto adds = constFoldBinaryOp<IntegerAttr>(
166  {lhsAttr, rhsAttr},
167  [](const APInt &a, const APInt &b) { return a + b; });
168  if (!adds)
169  return failure();
170 
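 // The wrapped addition overflowed iff the truncated sum is smaller than one
 // of the addends, so comparing `adds` against `lhs` recovers the carry bit.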
171  auto carrys = constFoldBinaryOp<IntegerAttr>(
172  ArrayRef{adds, lhsAttr}, [](const APInt &a, const APInt &b) {
173  APInt zero = APInt::getZero(a.getBitWidth());
174  return a.ult(b) ? (zero + 1) : zero;
175  });
176 
177  if (!carrys)
178  return failure();
179 
180  Value addsVal =
181  spirv::ConstantOp::create(rewriter, loc, constituentType, adds);
182 
183  Value carrysVal =
184  spirv::ConstantOp::create(rewriter, loc, constituentType, carrys);
185 
186  // Create empty struct
187  Value undef = spirv::UndefOp::create(rewriter, loc, op.getType());
188  // Fill in adds at id 0
189  Value intermediate =
190  spirv::CompositeInsertOp::create(rewriter, loc, addsVal, undef, 0);
191  // Fill in carrys at id 1
192  rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, carrysVal,
193  intermediate, 1);
194  return success();
195  }
196 };
197 
198 void spirv::IAddCarryOp::getCanonicalizationPatterns(
199  RewritePatternSet &patterns, MLIRContext *context) {
200  patterns.add<IAddCarryFold>(context);
201 }
202 
203 //===----------------------------------------------------------------------===//
204 // spirv.[S|U]MulExtended
205 //===----------------------------------------------------------------------===//
206 
207 // We are required to use CompositeConstructOp to create a constant struct,
208 // as struct constants are not yet supported; hence we cannot do so in a fold.
209 template <typename MulOp, bool IsSigned>
210 struct MulExtendedFold final : OpRewritePattern<MulOp> {
211  using OpRewritePattern<MulOp>::OpRewritePattern;
212 
213  LogicalResult matchAndRewrite(MulOp op,
214  PatternRewriter &rewriter) const override {
215  Location loc = op.getLoc();
216  Value lhs = op.getOperand1();
217  Value rhs = op.getOperand2();
218  Type constituentType = lhs.getType();
219 
220  // [su]mulextended (x, 0) = <0, 0>
221  if (matchPattern(rhs, m_Zero())) {
222  Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
223  Value constituents[2] = {zero, zero};
224  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
225  constituents);
226  return success();
227  }
228 
229  // According to the SPIR-V spec:
230  //
231  // Result Type must be from OpTypeStruct. The struct must have two
232  // members...
233  //
234  // Member 0 of the result gets the low-order bits of the multiplication.
235  //
236  // Member 1 of the result gets the high-order bits of the multiplication.
237  Attribute lhsAttr;
238  Attribute rhsAttr;
239  if (!matchPattern(lhs, m_Constant(&lhsAttr)) ||
240  !matchPattern(rhs, m_Constant(&rhsAttr)))
241  return failure();
242 
243  auto lowBits = constFoldBinaryOp<IntegerAttr>(
244  {lhsAttr, rhsAttr},
245  [](const APInt &a, const APInt &b) { return a * b; });
246 
247  if (!lowBits)
248  return failure();
249 
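 // mulhs/mulhu return the high half of the full 2N-bit signed/unsigned
 // product, which is exactly what Member 1 of the result must hold.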
250  auto highBits = constFoldBinaryOp<IntegerAttr>(
251  {lhsAttr, rhsAttr}, [](const APInt &a, const APInt &b) {
252  if (IsSigned) {
253  return llvm::APIntOps::mulhs(a, b);
254  } else {
255  return llvm::APIntOps::mulhu(a, b);
256  }
257  });
258 
259  if (!highBits)
260  return failure();
261 
262  Value lowBitsVal =
263  spirv::ConstantOp::create(rewriter, loc, constituentType, lowBits);
264 
265  Value highBitsVal =
266  spirv::ConstantOp::create(rewriter, loc, constituentType, highBits);
267 
268  // Create empty struct
269  Value undef = spirv::UndefOp::create(rewriter, loc, op.getType());
270  // Fill in lowBits at id 0
271  Value intermediate =
272  spirv::CompositeInsertOp::create(rewriter, loc, lowBitsVal, undef, 0);
273  // Fill in highBits at id 1
274  rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, highBitsVal,
275  intermediate, 1);
276  return success();
277  }
278 };
279 
280 using SMulExtendedOpFold = MulExtendedFold<spirv::SMulExtendedOp, true>;
281 void spirv::SMulExtendedOp::getCanonicalizationPatterns(
282  RewritePatternSet &patterns, MLIRContext *context) {
283  patterns.add<SMulExtendedOpFold>(context);
284 }
285 
286 struct UMulExtendedOpXOne final : OpRewritePattern<spirv::UMulExtendedOp> {
287  using OpRewritePattern::OpRewritePattern;
288 
289  LogicalResult matchAndRewrite(spirv::UMulExtendedOp op,
290  PatternRewriter &rewriter) const override {
291  Location loc = op.getLoc();
292  Value lhs = op.getOperand1();
293  Value rhs = op.getOperand2();
294  Type constituentType = lhs.getType();
295 
296  // umulextended (x, 1) = <x, 0>
297  if (matchPattern(rhs, m_One())) {
298  Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
299  Value constituents[2] = {lhs, zero};
300  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
301  constituents);
302  return success();
303  }
304 
305  return failure();
306  }
307 };
308 
309 using UMulExtendedOpFold = MulExtendedFold<spirv::UMulExtendedOp, false>;
310 void spirv::UMulExtendedOp::getCanonicalizationPatterns(
311  RewritePatternSet &patterns, MLIRContext *context) {
312  patterns.add<UMulExtendedOpFold, UMulExtendedOpXOne>(context);
313 }
314 
315 //===----------------------------------------------------------------------===//
316 // spirv.UMod
317 //===----------------------------------------------------------------------===//
318 
319 // Input:
320 // %0 = spirv.UMod %arg0, %const32 : i32
321 // %1 = spirv.UMod %0, %const4 : i32
322 // Output:
323 // %0 = spirv.UMod %arg0, %const32 : i32
324 // %1 = spirv.UMod %arg0, %const4 : i32
325 
326 // The transformation is only applied if the previous (outer) divisor is a multiple of the current divisor.
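// This is sound because if the previous divisor P is a multiple of the
// current divisor C, then (x umod P) umod C == x umod C for every x.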
327 
328 struct UModSimplification final : OpRewritePattern<spirv::UModOp> {
329  using OpRewritePattern::OpRewritePattern;
330 
331  LogicalResult matchAndRewrite(spirv::UModOp umodOp,
332  PatternRewriter &rewriter) const override {
333  auto prevUMod = umodOp.getOperand(0).getDefiningOp<spirv::UModOp>();
334  if (!prevUMod)
335  return failure();
336 
337  TypedAttr prevValue;
338  TypedAttr currValue;
339  if (!matchPattern(prevUMod.getOperand(1), m_Constant(&prevValue)) ||
340  !matchPattern(umodOp.getOperand(1), m_Constant(&currValue)))
341  return failure();
342 
343  // Ensure that previous divisor is a multiple of the current divisor. If
344  // not, fail the transformation.
345  bool isApplicable = false;
346  if (auto prevInt = dyn_cast<IntegerAttr>(prevValue)) {
347  auto currInt = cast<IntegerAttr>(currValue);
348  isApplicable = prevInt.getValue().urem(currInt.getValue()) == 0;
349  } else if (auto prevVec = dyn_cast<DenseElementsAttr>(prevValue)) {
350  auto currVec = cast<DenseElementsAttr>(currValue);
351  isApplicable = llvm::all_of(llvm::zip_equal(prevVec.getValues<APInt>(),
352  currVec.getValues<APInt>()),
353  [](const auto &pair) {
354  auto &[prev, curr] = pair;
355  return prev.urem(curr) == 0;
356  });
357  }
358 
359  if (!isApplicable)
360  return failure();
361 
362  // The transformation is safe. Replace the existing UMod operation with a
363  // new UMod operation, using the original dividend and the current divisor.
364  rewriter.replaceOpWithNewOp<spirv::UModOp>(
365  umodOp, umodOp.getType(), prevUMod.getOperand(0), umodOp.getOperand(1));
366 
367  return success();
368  }
369 };
370 
371 void spirv::UModOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
372  MLIRContext *context) {
373  patterns.insert<UModSimplification>(context);
374 }
375 
376 //===----------------------------------------------------------------------===//
377 // spirv.BitcastOp
378 //===----------------------------------------------------------------------===//
379 
380 OpFoldResult spirv::BitcastOp::fold(FoldAdaptor /*adaptor*/) {
381  Value curInput = getOperand();
382  if (getType() == curInput.getType())
383  return curInput;
384 
385  // Look through nested bitcasts.
386  if (auto prevCast = curInput.getDefiningOp<spirv::BitcastOp>()) {
387  Value prevInput = prevCast.getOperand();
388  if (prevInput.getType() == getType())
389  return prevInput;
390 
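 // Otherwise forward the inner bitcast's source and return getResult() to
 // signal that this op was updated in place.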
391  getOperandMutable().assign(prevInput);
392  return getResult();
393  }
394 
395  // TODO(kuhar): Consider constant-folding the operand attribute.
396  return {};
397 }
398 
399 //===----------------------------------------------------------------------===//
400 // spirv.CompositeExtractOp
401 //===----------------------------------------------------------------------===//
402 
403 OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) {
404  Value compositeOp = getComposite();
405 
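 // Walk the chain of CompositeInsert ops feeding this extract: if one of them
 // inserts at exactly the indices we extract, forward the inserted object;
 // otherwise keep looking through to the underlying composite.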
406  while (auto insertOp =
407  compositeOp.getDefiningOp<spirv::CompositeInsertOp>()) {
408  if (getIndices() == insertOp.getIndices())
409  return insertOp.getObject();
410  compositeOp = insertOp.getComposite();
411  }
412 
413  if (auto constructOp =
414  compositeOp.getDefiningOp<spirv::CompositeConstructOp>()) {
415  auto type = llvm::cast<spirv::CompositeType>(constructOp.getType());
416  if (getIndices().size() == 1 &&
417  constructOp.getConstituents().size() == type.getNumElements()) {
418  auto i = llvm::cast<IntegerAttr>(*getIndices().begin());
419  if (i.getValue().getSExtValue() <
420  static_cast<int64_t>(constructOp.getConstituents().size()))
421  return constructOp.getConstituents()[i.getValue().getSExtValue()];
422  }
423  }
424 
425  auto indexVector = llvm::map_to_vector(getIndices(), [](Attribute attr) {
426  return static_cast<unsigned>(llvm::cast<IntegerAttr>(attr).getInt());
427  });
428  return extractCompositeElement(adaptor.getComposite(), indexVector);
429 }
430 
431 //===----------------------------------------------------------------------===//
432 // spirv.Constant
433 //===----------------------------------------------------------------------===//
434 
435 OpFoldResult spirv::ConstantOp::fold(FoldAdaptor /*adaptor*/) {
436  return getValue();
437 }
438 
439 //===----------------------------------------------------------------------===//
440 // spirv.IAdd
441 //===----------------------------------------------------------------------===//
442 
443 OpFoldResult spirv::IAddOp::fold(FoldAdaptor adaptor) {
444  // x + 0 = x
445  if (matchPattern(getOperand2(), m_Zero()))
446  return getOperand1();
447 
448  // According to the SPIR-V spec:
449  //
450  // The resulting value will equal the low-order N bits of the correct result
451  // R, where N is the component width and R is computed with enough precision
452  // to avoid overflow and underflow.
453  return constFoldBinaryOp<IntegerAttr>(
454  adaptor.getOperands(),
455  [](APInt a, const APInt &b) { return std::move(a) + b; });
456 }
457 
458 //===----------------------------------------------------------------------===//
459 // spirv.IMul
460 //===----------------------------------------------------------------------===//
461 
462 OpFoldResult spirv::IMulOp::fold(FoldAdaptor adaptor) {
463  // x * 0 == 0
464  if (matchPattern(getOperand2(), m_Zero()))
465  return getOperand2();
466  // x * 1 = x
467  if (matchPattern(getOperand2(), m_One()))
468  return getOperand1();
469 
470  // According to the SPIR-V spec:
471  //
472  // The resulting value will equal the low-order N bits of the correct result
473  // R, where N is the component width and R is computed with enough precision
474  // to avoid overflow and underflow.
475  return constFoldBinaryOp<IntegerAttr>(
476  adaptor.getOperands(),
477  [](const APInt &a, const APInt &b) { return a * b; });
478 }
479 
480 //===----------------------------------------------------------------------===//
481 // spirv.ISub
482 //===----------------------------------------------------------------------===//
483 
484 OpFoldResult spirv::ISubOp::fold(FoldAdaptor adaptor) {
485  // x - x = 0
486  if (getOperand1() == getOperand2())
487  return Builder(getContext()).getZeroAttr(getType());
488 
489  // According to the SPIR-V spec:
490  //
491  // The resulting value will equal the low-order N bits of the correct result
492  // R, where N is the component width and R is computed with enough precision
493  // to avoid overflow and underflow.
494  return constFoldBinaryOp<IntegerAttr>(
495  adaptor.getOperands(),
496  [](APInt a, const APInt &b) { return std::move(a) - b; });
497 }
498 
499 //===----------------------------------------------------------------------===//
500 // spirv.SDiv
501 //===----------------------------------------------------------------------===//
502 
503 OpFoldResult spirv::SDivOp::fold(FoldAdaptor adaptor) {
504  // sdiv (x, 1) = x
505  if (matchPattern(getOperand2(), m_One()))
506  return getOperand1();
507 
508  // According to the SPIR-V spec:
509  //
510  // Signed-integer division of Operand 1 divided by Operand 2.
511  // Results are computed per component. Behavior is undefined if Operand 2 is
512  // 0. Behavior is undefined if Operand 2 is -1 and Operand 1 is the minimum
513  // representable value for the operands' type, causing signed overflow.
514  //
515  // So don't fold during undefined behavior.
516  bool div0OrOverflow = false;
517  auto res = constFoldBinaryOp<IntegerAttr>(
518  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
519  if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
520  div0OrOverflow = true;
521  return a;
522  }
523  return a.sdiv(b);
524  });
525  return div0OrOverflow ? Attribute() : res;
526 }
527 
528 //===----------------------------------------------------------------------===//
529 // spirv.SMod
530 //===----------------------------------------------------------------------===//
531 
532 OpFoldResult spirv::SModOp::fold(FoldAdaptor adaptor) {
533  // smod (x, 1) = 0
534  if (matchPattern(getOperand2(), m_One()))
535  return Builder(getContext()).getZeroAttr(getType());
536 
537  // According to SPIR-V spec:
538  //
539  // Signed remainder operation for the remainder whose sign matches the sign
540  // of Operand 2. Behavior is undefined if Operand 2 is 0. Behavior is
541  // undefined if Operand 2 is -1 and Operand 1 is the minimum representable
542  // value for the operands' type, causing signed overflow. Otherwise, the
543  // result is the remainder r of Operand 1 divided by Operand 2 where if
544  // r ≠ 0, the sign of r is the same as the sign of Operand 2.
545  //
546  // So don't fold during undefined behavior
547  bool div0OrOverflow = false;
548  auto res = constFoldBinaryOp<IntegerAttr>(
549  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
550  if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
551  div0OrOverflow = true;
552  return a;
553  }
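 // Compute the magnitude of the remainder first, then adjust its sign so
 // that a non-zero result matches the sign of Operand 2 (this is what
 // distinguishes OpSMod from OpSRem).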
554  APInt c = a.abs().urem(b.abs());
555  if (c.isZero())
556  return c;
557  if (b.isNegative()) {
558  APInt zero = APInt::getZero(c.getBitWidth());
559  return a.isNegative() ? (zero - c) : (b + c);
560  }
561  return a.isNegative() ? (b - c) : c;
562  });
563  return div0OrOverflow ? Attribute() : res;
564 }
565 
566 //===----------------------------------------------------------------------===//
567 // spirv.SRem
568 //===----------------------------------------------------------------------===//
569 
570 OpFoldResult spirv::SRemOp::fold(FoldAdaptor adaptor) {
571  // x % 1 = 0
572  if (matchPattern(getOperand2(), m_One()))
573  return Builder(getContext()).getZeroAttr(getType());
574 
575  // According to SPIR-V spec:
576  //
577  // Signed remainder operation for the remainder whose sign matches the sign
578  // of Operand 1. Behavior is undefined if Operand 2 is 0. Behavior is
579  // undefined if Operand 2 is -1 and Operand 1 is the minimum representable
580  // value for the operands' type, causing signed overflow. Otherwise, the
581  // result is the remainder r of Operand 1 divided by Operand 2 where if
582  // r ≠ 0, the sign of r is the same as the sign of Operand 1.
583 
584  // Don't fold if it would do undefined behavior.
585  bool div0OrOverflow = false;
586  auto res = constFoldBinaryOp<IntegerAttr>(
587  adaptor.getOperands(), [&](APInt a, const APInt &b) {
588  if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
589  div0OrOverflow = true;
590  return a;
591  }
592  return a.srem(b);
593  });
594  return div0OrOverflow ? Attribute() : res;
595 }
596 
597 //===----------------------------------------------------------------------===//
598 // spirv.UDiv
599 //===----------------------------------------------------------------------===//
600 
601 OpFoldResult spirv::UDivOp::fold(FoldAdaptor adaptor) {
602  // udiv (x, 1) = x
603  if (matchPattern(getOperand2(), m_One()))
604  return getOperand1();
605 
606  // According to the SPIR-V spec:
607  //
608  // Unsigned-integer division of Operand 1 divided by Operand 2. Behavior is
609  // undefined if Operand 2 is 0.
610  //
611  // So don't fold during undefined behavior.
612  bool div0 = false;
613  auto res = constFoldBinaryOp<IntegerAttr>(
614  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
615  if (div0 || b.isZero()) {
616  div0 = true;
617  return a;
618  }
619  return a.udiv(b);
620  });
621  return div0 ? Attribute() : res;
622 }
623 
624 //===----------------------------------------------------------------------===//
625 // spirv.UMod
626 //===----------------------------------------------------------------------===//
627 
628 OpFoldResult spirv::UModOp::fold(FoldAdaptor adaptor) {
629  // umod (x, 1) = 0
630  if (matchPattern(getOperand2(), m_One()))
631  return Builder(getContext()).getZeroAttr(getType());
632 
633  // According to the SPIR-V spec:
634  //
635  // Unsigned modulo operation of Operand 1 modulo Operand 2. Behavior is
636  // undefined if Operand 2 is 0.
637  //
638  // So don't fold during undefined behavior.
639  bool div0 = false;
640  auto res = constFoldBinaryOp<IntegerAttr>(
641  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
642  if (div0 || b.isZero()) {
643  div0 = true;
644  return a;
645  }
646  return a.urem(b);
647  });
648  return div0 ? Attribute() : res;
649 }
650 
651 //===----------------------------------------------------------------------===//
652 // spirv.SNegate
653 //===----------------------------------------------------------------------===//
654 
655 OpFoldResult spirv::SNegateOp::fold(FoldAdaptor adaptor) {
656  // -(-x) = 0 - (0 - x) = x
657  auto op = getOperand();
658  if (auto negateOp = op.getDefiningOp<spirv::SNegateOp>())
659  return negateOp->getOperand(0);
660 
661  // According to the SPIR-V spec:
662  //
663  // Signed-integer subtract of Operand from zero.
664  return constFoldUnaryOp<IntegerAttr>(
665  adaptor.getOperands(), [](const APInt &a) {
666  APInt zero = APInt::getZero(a.getBitWidth());
667  return zero - a;
668  });
669 }
670 
671 //===----------------------------------------------------------------------===//
672 // spirv.NotOp
673 //===----------------------------------------------------------------------===//
674 
675 OpFoldResult spirv::NotOp::fold(spirv::NotOp::FoldAdaptor adaptor) {
676  // !(!x) = x
677  auto op = getOperand();
678  if (auto notOp = op.getDefiningOp<spirv::NotOp>())
679  return notOp->getOperand(0);
680 
681  // According to the SPIR-V spec:
682  //
683  // Complement the bits of Operand.
684  return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(), [&](APInt a) {
685  a.flipAllBits();
686  return a;
687  });
688 }
689 
690 //===----------------------------------------------------------------------===//
691 // spirv.LogicalAnd
692 //===----------------------------------------------------------------------===//
693 
694 OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) {
695  if (std::optional<bool> rhs =
696  getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
697  // x && true = x
698  if (*rhs)
699  return getOperand1();
700 
701  // x && false = false
702  if (!*rhs)
703  return adaptor.getOperand2();
704  }
705 
706  return Attribute();
707 }
708 
709 //===----------------------------------------------------------------------===//
710 // spirv.LogicalEqualOp
711 //===----------------------------------------------------------------------===//
712 
713 OpFoldResult
714 spirv::LogicalEqualOp::fold(spirv::LogicalEqualOp::FoldAdaptor adaptor) {
715  // x == x -> true
716  if (getOperand1() == getOperand2()) {
717  auto trueAttr = BoolAttr::get(getContext(), true);
718  if (isa<IntegerType>(getType()))
719  return trueAttr;
720  if (auto vecTy = dyn_cast<VectorType>(getType()))
721  return SplatElementsAttr::get(vecTy, trueAttr);
722  }
723 
724  return constFoldBinaryOp<IntegerAttr>(
725  adaptor.getOperands(), [](const APInt &a, const APInt &b) {
726  return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);
727  });
728 }
729 
730 //===----------------------------------------------------------------------===//
731 // spirv.LogicalNotEqualOp
732 //===----------------------------------------------------------------------===//
733 
734 OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) {
735  if (std::optional<bool> rhs =
736  getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
737  // x != false -> x
738  if (!rhs.value())
739  return getOperand1();
740  }
741 
742  // x == x -> false
743  if (getOperand1() == getOperand2()) {
744  auto falseAttr = BoolAttr::get(getContext(), false);
745  if (isa<IntegerType>(getType()))
746  return falseAttr;
747  if (auto vecTy = dyn_cast<VectorType>(getType()))
748  return SplatElementsAttr::get(vecTy, falseAttr);
749  }
750 
751  return constFoldBinaryOp<IntegerAttr>(
752  adaptor.getOperands(), [](const APInt &a, const APInt &b) {
753  return a == b ? APInt::getZero(1) : APInt::getAllOnes(1);
754  });
755 }
756 
757 //===----------------------------------------------------------------------===//
758 // spirv.LogicalNot
759 //===----------------------------------------------------------------------===//
760 
761 OpFoldResult spirv::LogicalNotOp::fold(FoldAdaptor adaptor) {
762  // !(!x) = x
763  auto op = getOperand();
764  if (auto notOp = op.getDefiningOp<spirv::LogicalNotOp>())
765  return notOp->getOperand(0);
766 
767  // According to the SPIR-V spec:
768  //
769  // Result is true if Operand is false. Result is false if Operand is true.
770  return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(),
771  [](const APInt &a) {
772  APInt zero = APInt::getZero(1);
773  return a == 1 ? zero : (zero + 1);
774  });
775 }
776 
777 void spirv::LogicalNotOp::getCanonicalizationPatterns(
778  RewritePatternSet &results, MLIRContext *context) {
779  results
780  .add<ConvertLogicalNotOfIEqual, ConvertLogicalNotOfINotEqual,
781  ConvertLogicalNotOfLogicalEqual, ConvertLogicalNotOfLogicalNotEqual>(
782  context);
783 }
784 
785 //===----------------------------------------------------------------------===//
786 // spirv.LogicalOr
787 //===----------------------------------------------------------------------===//
788 
789 OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
790  if (auto rhs = getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
791  if (*rhs) {
792  // x || true = true
793  return adaptor.getOperand2();
794  }
795 
796  if (!*rhs) {
797  // x || false = x
798  return getOperand1();
799  }
800  }
801 
802  return Attribute();
803 }
804 
805 //===----------------------------------------------------------------------===//
806 // spirv.SelectOp
807 //===----------------------------------------------------------------------===//
808 
809 OpFoldResult spirv::SelectOp::fold(FoldAdaptor adaptor) {
810  // spirv.Select _ x x -> x
811  Value trueVals = getTrueValue();
812  Value falseVals = getFalseValue();
813  if (trueVals == falseVals)
814  return trueVals;
815 
816  ArrayRef<Attribute> operands = adaptor.getOperands();
817 
818  // spirv.Select true x y -> x
819  // spirv.Select false x y -> y
820  if (auto boolAttr = getScalarOrSplatBoolAttr(operands[0]))
821  return *boolAttr ? trueVals : falseVals;
822 
823  // Check that all the operands are constant
824  if (!operands[0] || !operands[1] || !operands[2])
825  return Attribute();
826 
827  // Note: getScalarOrSplatBoolAttr will always return a boolAttr if we are in
828  // the scalar case. Hence, we are only required to consider the case of
829  // DenseElementsAttr in foldSelectOp.
830  auto condAttrs = dyn_cast<DenseElementsAttr>(operands[0]);
831  auto trueAttrs = dyn_cast<DenseElementsAttr>(operands[1]);
832  auto falseAttrs = dyn_cast<DenseElementsAttr>(operands[2]);
833  if (!condAttrs || !trueAttrs || !falseAttrs)
834  return Attribute();
835 
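 // Start from the `true` values and overwrite every element whose condition
 // is false with the corresponding `false` value.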
836  auto elementResults = llvm::to_vector<4>(trueAttrs.getValues<Attribute>());
837  auto iters = llvm::zip_equal(elementResults, condAttrs.getValues<BoolAttr>(),
838  falseAttrs.getValues<Attribute>());
839  for (auto [result, cond, falseRes] : iters) {
840  if (!cond.getValue())
841  result = falseRes;
842  }
843 
844  auto resultType = trueAttrs.getType();
845  return DenseElementsAttr::get(cast<ShapedType>(resultType), elementResults);
846 }
847 
848 //===----------------------------------------------------------------------===//
849 // spirv.IEqualOp
850 //===----------------------------------------------------------------------===//
851 
852 OpFoldResult spirv::IEqualOp::fold(spirv::IEqualOp::FoldAdaptor adaptor) {
853  // x == x -> true
854  if (getOperand1() == getOperand2()) {
855  auto trueAttr = BoolAttr::get(getContext(), true);
856  if (isa<IntegerType>(getType()))
857  return trueAttr;
858  if (auto vecTy = dyn_cast<VectorType>(getType()))
859  return SplatElementsAttr::get(vecTy, trueAttr);
860  }
861 
862  return constFoldBinaryOp<IntegerAttr>(
863  adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
864  return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);
865  });
866 }
867 
868 //===----------------------------------------------------------------------===//
869 // spirv.INotEqualOp
870 //===----------------------------------------------------------------------===//
871 
872 OpFoldResult spirv::INotEqualOp::fold(spirv::INotEqualOp::FoldAdaptor adaptor) {
873  // x == x -> false
874  if (getOperand1() == getOperand2()) {
875  auto falseAttr = BoolAttr::get(getContext(), false);
876  if (isa<IntegerType>(getType()))
877  return falseAttr;
878  if (auto vecTy = dyn_cast<VectorType>(getType()))
879  return SplatElementsAttr::get(vecTy, falseAttr);
880  }
881 
882  return constFoldBinaryOp<IntegerAttr>(
883  adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
884  return a == b ? APInt::getZero(1) : APInt::getAllOnes(1);
885  });
886 }
887 
888 //===----------------------------------------------------------------------===//
889 // spirv.SGreaterThan
890 //===----------------------------------------------------------------------===//
891 
892 OpFoldResult
893 spirv::SGreaterThanOp::fold(spirv::SGreaterThanOp::FoldAdaptor adaptor) {
894  // x == x -> false
895  if (getOperand1() == getOperand2()) {
896  auto falseAttr = BoolAttr::get(getContext(), false);
897  if (isa<IntegerType>(getType()))
898  return falseAttr;
899  if (auto vecTy = dyn_cast<VectorType>(getType()))
900  return SplatElementsAttr::get(vecTy, falseAttr);
901  }
902 
903  return constFoldBinaryOp<IntegerAttr>(
904  adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
905  return a.sgt(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
906  });
907 }
908 
909 //===----------------------------------------------------------------------===//
910 // spirv.SGreaterThanEqual
911 //===----------------------------------------------------------------------===//
912 
913 OpFoldResult spirv::SGreaterThanEqualOp::fold(
914  spirv::SGreaterThanEqualOp::FoldAdaptor adaptor) {
915  // x == x -> true
916  if (getOperand1() == getOperand2()) {
917  auto trueAttr = BoolAttr::get(getContext(), true);
918  if (isa<IntegerType>(getType()))
919  return trueAttr;
920  if (auto vecTy = dyn_cast<VectorType>(getType()))
921  return SplatElementsAttr::get(vecTy, trueAttr);
922  }
923 
924  return constFoldBinaryOp<IntegerAttr>(
925  adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
926  return a.sge(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
927  });
928 }
929 
930 //===----------------------------------------------------------------------===//
931 // spirv.UGreaterThan
932 //===----------------------------------------------------------------------===//
933 
934 OpFoldResult
935 spirv::UGreaterThanOp::fold(spirv::UGreaterThanOp::FoldAdaptor adaptor) {
936  // x == x -> false
937  if (getOperand1() == getOperand2()) {
938  auto falseAttr = BoolAttr::get(getContext(), false);
939  if (isa<IntegerType>(getType()))
940  return falseAttr;
941  if (auto vecTy = dyn_cast<VectorType>(getType()))
942  return SplatElementsAttr::get(vecTy, falseAttr);
943  }
944 
945  return constFoldBinaryOp<IntegerAttr>(
946  adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
947  return a.ugt(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
948  });
949 }
950 
951 //===----------------------------------------------------------------------===//
952 // spirv.UGreaterThanEqual
953 //===----------------------------------------------------------------------===//
954 
955 OpFoldResult spirv::UGreaterThanEqualOp::fold(
956  spirv::UGreaterThanEqualOp::FoldAdaptor adaptor) {
957  // x == x -> true
958  if (getOperand1() == getOperand2()) {
959  auto trueAttr = BoolAttr::get(getContext(), true);
960  if (isa<IntegerType>(getType()))
961  return trueAttr;
962  if (auto vecTy = dyn_cast<VectorType>(getType()))
963  return SplatElementsAttr::get(vecTy, trueAttr);
964  }
965 
966  return constFoldBinaryOp<IntegerAttr>(
967  adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
968  return a.uge(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
969  });
970 }
971 
972 //===----------------------------------------------------------------------===//
973 // spirv.SLessThan
974 //===----------------------------------------------------------------------===//
975 
976 OpFoldResult spirv::SLessThanOp::fold(spirv::SLessThanOp::FoldAdaptor adaptor) {
977  // x == x -> false
978  if (getOperand1() == getOperand2()) {
979  auto falseAttr = BoolAttr::get(getContext(), false);
980  if (isa<IntegerType>(getType()))
981  return falseAttr;
982  if (auto vecTy = dyn_cast<VectorType>(getType()))
983  return SplatElementsAttr::get(vecTy, falseAttr);
984  }
985 
986  return constFoldBinaryOp<IntegerAttr>(
987  adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
988  return a.slt(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
989  });
990 }
991 
992 //===----------------------------------------------------------------------===//
993 // spirv.SLessThanEqual
994 //===----------------------------------------------------------------------===//
995 
996 OpFoldResult
997 spirv::SLessThanEqualOp::fold(spirv::SLessThanEqualOp::FoldAdaptor adaptor) {
998  // x == x -> true
999  if (getOperand1() == getOperand2()) {
1000  auto trueAttr = BoolAttr::get(getContext(), true);
1001  if (isa<IntegerType>(getType()))
1002  return trueAttr;
1003  if (auto vecTy = dyn_cast<VectorType>(getType()))
1004  return SplatElementsAttr::get(vecTy, trueAttr);
1005  }
1006 
1007  return constFoldBinaryOp<IntegerAttr>(
1008  adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
1009  return a.sle(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
1010  });
1011 }
1012 
1013 //===----------------------------------------------------------------------===//
1014 // spirv.ULessThan
1015 //===----------------------------------------------------------------------===//
1016 
1017 OpFoldResult spirv::ULessThanOp::fold(spirv::ULessThanOp::FoldAdaptor adaptor) {
1018  // x == x -> false
1019  if (getOperand1() == getOperand2()) {
1020  auto falseAttr = BoolAttr::get(getContext(), false);
1021  if (isa<IntegerType>(getType()))
1022  return falseAttr;
1023  if (auto vecTy = dyn_cast<VectorType>(getType()))
1024  return SplatElementsAttr::get(vecTy, falseAttr);
1025  }
1026 
1027  return constFoldBinaryOp<IntegerAttr>(
1028  adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
1029  return a.ult(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
1030  });
1031 }
1032 
1033 //===----------------------------------------------------------------------===//
1034 // spirv.ULessThanEqual
1035 //===----------------------------------------------------------------------===//
1036 
1037 OpFoldResult
1038 spirv::ULessThanEqualOp::fold(spirv::ULessThanEqualOp::FoldAdaptor adaptor) {
1039  // x == x -> true
1040  if (getOperand1() == getOperand2()) {
1041  auto trueAttr = BoolAttr::get(getContext(), true);
1042  if (isa<IntegerType>(getType()))
1043  return trueAttr;
1044  if (auto vecTy = dyn_cast<VectorType>(getType()))
1045  return SplatElementsAttr::get(vecTy, trueAttr);
1046  }
1047 
1048  return constFoldBinaryOp<IntegerAttr>(
1049  adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
1050  return a.ule(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
1051  });
1052 }
1053 
1054 //===----------------------------------------------------------------------===//
1055 // spirv.ShiftLeftLogical
1056 //===----------------------------------------------------------------------===//
1057 
1058 OpFoldResult spirv::ShiftLeftLogicalOp::fold(
1059  spirv::ShiftLeftLogicalOp::FoldAdaptor adaptor) {
1060  // x << 0 -> x
1061  if (matchPattern(adaptor.getOperand2(), m_Zero())) {
1062  return getOperand1();
1063  }
1064 
1065  // Unfortunately, due to the undefined behaviour below, we can't fold a zero Base.
1066 
1067  // Results are computed per component, and within each component, per bit...
1068  //
1069  // The result is undefined if Shift is greater than or equal to the bit width
1070  // of the components of Base.
1071  //
1072  // So we can use the APInt << method, but don't fold if undefined behaviour.
1073  bool shiftToLarge = false;
1074  auto res = constFoldBinaryOp<IntegerAttr>(
1075  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
1076  if (shiftToLarge || b.uge(a.getBitWidth())) {
1077  shiftToLarge = true;
1078  return a;
1079  }
1080  return a << b;
1081  });
1082  return shiftToLarge ? Attribute() : res;
1083 }
1084 
1085 //===----------------------------------------------------------------------===//
1086 // spirv.ShiftRightArithmetic
1087 //===----------------------------------------------------------------------===//
1088 
1089 OpFoldResult spirv::ShiftRightArithmeticOp::fold(
1090  spirv::ShiftRightArithmeticOp::FoldAdaptor adaptor) {
1091  // x >> 0 -> x
1092  if (matchPattern(adaptor.getOperand2(), m_Zero())) {
1093  return getOperand1();
1094  }
1095 
1096  // Unfortunately, due to the undefined behaviour below, we can't fold a Base of 0 or -1.
1097 
1098  // Results are computed per component, and within each component, per bit...
1099  //
1100  // The result is undefined if Shift is greater than or equal to the bit width
1101  // of the components of Base.
1102  //
1103  // So we can use the APInt ashr method, but don't fold if undefined behaviour.
1104  bool shiftToLarge = false;
1105  auto res = constFoldBinaryOp<IntegerAttr>(
1106  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
1107  if (shiftToLarge || b.uge(a.getBitWidth())) {
1108  shiftToLarge = true;
1109  return a;
1110  }
1111  return a.ashr(b);
1112  });
1113  return shiftToLarge ? Attribute() : res;
1114 }
1115 
1116 //===----------------------------------------------------------------------===//
1117 // spirv.ShiftRightLogical
1118 //===----------------------------------------------------------------------===//
1119 
1120 OpFoldResult spirv::ShiftRightLogicalOp::fold(
1121  spirv::ShiftRightLogicalOp::FoldAdaptor adaptor) {
1122  // x >> 0 -> x
1123  if (matchPattern(adaptor.getOperand2(), m_Zero())) {
1124  return getOperand1();
1125  }
1126 
1127  // Unfortunately, due to the undefined behaviour below, we can't fold a zero Base.
1128 
1129  // Results are computed per component, and within each component, per bit...
1130  //
1131  // The result is undefined if Shift is greater than or equal to the bit width
1132  // of the components of Base.
1133  //
1134  // So we can use the APInt lshr method, but don't fold if undefined behaviour.
1135  bool shiftToLarge = false;
1136  auto res = constFoldBinaryOp<IntegerAttr>(
1137  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
1138  if (shiftToLarge || b.uge(a.getBitWidth())) {
1139  shiftToLarge = true;
1140  return a;
1141  }
1142  return a.lshr(b);
1143  });
1144  return shiftToLarge ? Attribute() : res;
1145 }
1146 
1147 //===----------------------------------------------------------------------===//
1148 // spirv.BitwiseAndOp
1149 //===----------------------------------------------------------------------===//
1150 
1151 OpFoldResult
1152 spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) {
1153  // x & x -> x
1154  if (getOperand1() == getOperand2()) {
1155  return getOperand1();
1156  }
1157 
1158  APInt rhsMask;
1159  if (matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask))) {
1160  // x & 0 -> 0
1161  if (rhsMask.isZero())
1162  return getOperand2();
1163 
1164  // x & <all ones> -> x
1165  if (rhsMask.isAllOnes())
1166  return getOperand1();
1167 
1168  // (UConvert x : iN to iK) & <mask with N low bits set> -> UConvert x
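 // e.g. (illustrative) masking a value zero-extended from i8 to i32 with
 // 0xff is a no-op, so the And can be dropped.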
1169  if (auto zext = getOperand1().getDefiningOp<spirv::UConvertOp>()) {
1170  int valueBits =
1171  getElementTypeOrSelf(zext.getOperand().getType()).getIntOrFloatBitWidth();
1172  if (rhsMask.zextOrTrunc(valueBits).isAllOnes())
1173  return getOperand1();
1174  }
1175  }
1176 
1177  // According to the SPIR-V spec:
1178  //
1179  // Type is a scalar or vector of integer type.
1180  // Results are computed per component, and within each component, per bit.
1181  // So we can use the APInt & method.
1182  return constFoldBinaryOp<IntegerAttr>(
1183  adaptor.getOperands(),
1184  [](const APInt &a, const APInt &b) { return a & b; });
1185 }
1186 
1187 //===----------------------------------------------------------------------===//
1188 // spirv.BitwiseOrOp
1189 //===----------------------------------------------------------------------===//
1190 
1191 OpFoldResult spirv::BitwiseOrOp::fold(spirv::BitwiseOrOp::FoldAdaptor adaptor) {
1192  // x | x -> x
1193  if (getOperand1() == getOperand2()) {
1194  return getOperand1();
1195  }
1196 
1197  APInt rhsMask;
1198  if (matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask))) {
1199  // x | 0 -> x
1200  if (rhsMask.isZero())
1201  return getOperand1();
1202 
1203  // x | <all ones> -> <all ones>
1204  if (rhsMask.isAllOnes())
1205  return getOperand2();
1206  }
1207 
1208  // According to the SPIR-V spec:
1209  //
1210  // Type is a scalar or vector of integer type.
1211  // Results are computed per component, and within each component, per bit.
1212  // So we can use the APInt | method.
1213  return constFoldBinaryOp<IntegerAttr>(
1214  adaptor.getOperands(),
1215  [](const APInt &a, const APInt &b) { return a | b; });
1216 }
1217 
1218 //===----------------------------------------------------------------------===//
1219 // spirv.BitwiseXorOp
1220 //===----------------------------------------------------------------------===//
1221 
1222 OpFoldResult
1223 spirv::BitwiseXorOp::fold(spirv::BitwiseXorOp::FoldAdaptor adaptor) {
1224  // x ^ 0 -> x
1225  if (matchPattern(adaptor.getOperand2(), m_Zero())) {
1226  return getOperand1();
1227  }
1228 
1229  // x ^ x -> 0
1230  if (getOperand1() == getOperand2())
1231  return Builder(getContext()).getZeroAttr(getType());
1232 
1233  // According to the SPIR-V spec:
1234  //
1235  // Type is a scalar or vector of integer type.
1236  // Results are computed per component, and within each component, per bit.
1237  // So we can use the APInt ^ method.
1238  return constFoldBinaryOp<IntegerAttr>(
1239  adaptor.getOperands(),
1240  [](const APInt &a, const APInt &b) { return a ^ b; });
1241 }
1242 
1243 //===----------------------------------------------------------------------===//
1244 // spirv.mlir.selection
1245 //===----------------------------------------------------------------------===//
1246 
1247 namespace {
1248 // Blocks from the given `spirv.mlir.selection` operation must satisfy the
1249 // following layout:
1250 //
1251 //                +-------------------------------------------------+
1252 //                | header block                                     |
1253 //                | spirv.BranchConditionalOp %cond, ^case0, ^case1  |
1254 //                +-------------------------------------------------+
1255 //                              /           \
1256 //                                  ...
1257 //
1258 //
1259 //   +--------------------------+    +--------------------------+
1260 //   | case #0                  |    | case #1                  |
1261 //   | spirv.Store %ptr %value0 |    | spirv.Store %ptr %value1 |
1262 //   | spirv.Branch ^merge      |    | spirv.Branch ^merge      |
1263 //   +--------------------------+    +--------------------------+
1264 //
1265 //
1266 //                                  ...
1267 //                              \           /
1268 //                                    v
1269 //                             +-------------+
1270 //                             | merge block |
1271 //                             +-------------+
1272 //
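// Such a selection is rewritten into a single `spirv.Select` of the two
// stored values feeding one `spirv.Store`, and the selection region is erased.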
1273 struct ConvertSelectionOpToSelect final : OpRewritePattern<spirv::SelectionOp> {
1274  using OpRewritePattern::OpRewritePattern;
1275 
1276  LogicalResult matchAndRewrite(spirv::SelectionOp selectionOp,
1277  PatternRewriter &rewriter) const override {
1278  Operation *op = selectionOp.getOperation();
1279  Region &body = op->getRegion(0);
1280  // Verifier allows an empty region for `spirv.mlir.selection`.
1281  if (body.empty()) {
1282  return failure();
1283  }
1284 
1285  // Check that region consists of 4 blocks:
1286  // header block, `true` block, `false` block and merge block.
1287  if (llvm::range_size(body) != 4) {
1288  return failure();
1289  }
1290 
1291  Block *headerBlock = selectionOp.getHeaderBlock();
1292  if (!onlyContainsBranchConditionalOp(headerBlock)) {
1293  return failure();
1294  }
1295 
1296  auto brConditionalOp =
1297  cast<spirv::BranchConditionalOp>(headerBlock->front());
1298 
1299  Block *trueBlock = brConditionalOp.getSuccessor(0);
1300  Block *falseBlock = brConditionalOp.getSuccessor(1);
1301  Block *mergeBlock = selectionOp.getMergeBlock();
1302 
1303  if (failed(canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)))
1304  return failure();
1305 
1306  Value trueValue = getSrcValue(trueBlock);
1307  Value falseValue = getSrcValue(falseBlock);
1308  Value ptrValue = getDstPtr(trueBlock);
1309  auto storeOpAttributes =
1310  cast<spirv::StoreOp>(trueBlock->front())->getAttrs();
1311 
1312  auto selectOp = spirv::SelectOp::create(
1313  rewriter, selectionOp.getLoc(), trueValue.getType(),
1314  brConditionalOp.getCondition(), trueValue, falseValue);
1315  spirv::StoreOp::create(rewriter, selectOp.getLoc(), ptrValue,
1316  selectOp.getResult(), storeOpAttributes);
1317 
1318  // `spirv.mlir.selection` is not needed anymore.
1319  rewriter.eraseOp(op);
1320  return success();
1321  }
1322 
1323 private:
1324  // Checks that the given blocks follow these rules:
1325  // 1. Each conditional block consists of two operations: the first is a
1326  // `spirv.Store` and the last is a `spirv.Branch`.
1327  // 2. Each `spirv.Store` uses the same pointer and the same memory attributes.
1328  // 3. Control flow reaches the given merge block from both conditional
1329  // blocks.
1330  LogicalResult canCanonicalizeSelection(Block *trueBlock, Block *falseBlock,
1331  Block *mergeBlock) const;
1332 
1333  bool onlyContainsBranchConditionalOp(Block *block) const {
1334  return llvm::hasSingleElement(*block) &&
1335  isa<spirv::BranchConditionalOp>(block->front());
1336  }
1337 
1338  bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs) const {
1339  return lhs->getDiscardableAttrDictionary() ==
1340  rhs->getDiscardableAttrDictionary() &&
1341  lhs.getProperties() == rhs.getProperties();
1342  }
1343 
1344  // Returns a source value for the given block.
1345  Value getSrcValue(Block *block) const {
1346  auto storeOp = cast<spirv::StoreOp>(block->front());
1347  return storeOp.getValue();
1348  }
1349 
1350  // Returns a destination value for the given block.
1351  Value getDstPtr(Block *block) const {
1352  auto storeOp = cast<spirv::StoreOp>(block->front());
1353  return storeOp.getPtr();
1354  }
1355 };
1356 
1357 LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
1358  Block *trueBlock, Block *falseBlock, Block *mergeBlock) const {
1359  // Each block must consist of exactly two operations.
1360  if (llvm::range_size(*trueBlock) != 2 || llvm::range_size(*falseBlock) != 2) {
1361  return failure();
1362  }
1363 
1364  auto trueBrStoreOp = dyn_cast<spirv::StoreOp>(trueBlock->front());
1365  auto trueBrBranchOp =
1366  dyn_cast<spirv::BranchOp>(*std::next(trueBlock->begin()));
1367  auto falseBrStoreOp = dyn_cast<spirv::StoreOp>(falseBlock->front());
1368  auto falseBrBranchOp =
1369  dyn_cast<spirv::BranchOp>(*std::next(falseBlock->begin()));
1370 
1371  if (!trueBrStoreOp || !trueBrBranchOp || !falseBrStoreOp ||
1372  !falseBrBranchOp) {
1373  return failure();
1374  }
1375 
1376  // Checks that given type is valid for `spirv.SelectOp`.
1377  // According to SPIR-V spec:
1378  // "Before version 1.4, Result Type must be a pointer, scalar, or vector.
1379  // Starting with version 1.4, Result Type can additionally be a composite type
1380  // other than a vector."
1381  bool isScalarOrVector =
1382  llvm::cast<spirv::SPIRVType>(trueBrStoreOp.getValue().getType())
1383  .isScalarOrVector();
1384 
1385  // Check that each `spirv.Store` uses the same pointer, memory access
1386  // attributes and a valid type of the value.
1387  if ((trueBrStoreOp.getPtr() != falseBrStoreOp.getPtr()) ||
1388  !isSameAttrList(trueBrStoreOp, falseBrStoreOp) || !isScalarOrVector) {
1389  return failure();
1390  }
1391 
1392  if ((trueBrBranchOp->getSuccessor(0) != mergeBlock) ||
1393  (falseBrBranchOp->getSuccessor(0) != mergeBlock)) {
1394  return failure();
1395  }
1396 
1397  return success();
1398 }
1399 } // namespace
1400 
1401 void spirv::SelectionOp::getCanonicalizationPatterns(RewritePatternSet &results,
1402  MLIRContext *context) {
1403  results.add<ConvertSelectionOpToSelect>(context);
1404 }