MLIR  21.0.0git
SPIRVCanonicalization.cpp
Go to the documentation of this file.
1 //===- SPIRVCanonicalization.cpp - MLIR SPIR-V canonicalization patterns --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the folders and canonicalization patterns for SPIR-V ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include <optional>
14 #include <utility>
15 
17 
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SmallVectorExtras.h"
26 
27 using namespace mlir;
28 
29 //===----------------------------------------------------------------------===//
30 // Common utility functions
31 //===----------------------------------------------------------------------===//
32 
33 /// Returns the boolean value under the hood if the given `boolAttr` is a scalar
34 /// or splat vector bool constant.
35 static std::optional<bool> getScalarOrSplatBoolAttr(Attribute attr) {
36  if (!attr)
37  return std::nullopt;
38 
39  if (auto boolAttr = llvm::dyn_cast<BoolAttr>(attr))
40  return boolAttr.getValue();
41  if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(attr))
42  if (splatAttr.getElementType().isInteger(1))
43  return splatAttr.getSplatValue<bool>();
44  return std::nullopt;
45 }
46 
47 // Extracts an element from the given `composite` by following the given
48 // `indices`. Returns a null Attribute if error happens.
50  ArrayRef<unsigned> indices) {
51  // Check that given composite is a constant.
52  if (!composite)
53  return {};
54  // Return composite itself if we reach the end of the index chain.
55  if (indices.empty())
56  return composite;
57 
58  if (auto vector = llvm::dyn_cast<ElementsAttr>(composite)) {
59  assert(indices.size() == 1 && "must have exactly one index for a vector");
60  return vector.getValues<Attribute>()[indices[0]];
61  }
62 
63  if (auto array = llvm::dyn_cast<ArrayAttr>(composite)) {
64  assert(!indices.empty() && "must have at least one index for an array");
65  return extractCompositeElement(array.getValue()[indices[0]],
66  indices.drop_front());
67  }
68 
69  return {};
70 }
71 
72 static bool isDivZeroOrOverflow(const APInt &a, const APInt &b) {
73  bool div0 = b.isZero();
74  bool overflow = a.isMinSignedValue() && b.isAllOnes();
75 
76  return div0 || overflow;
77 }
78 
79 //===----------------------------------------------------------------------===//
80 // TableGen'erated canonicalizers
81 //===----------------------------------------------------------------------===//
82 
83 namespace {
84 #include "SPIRVCanonicalization.inc"
85 } // namespace
86 
87 //===----------------------------------------------------------------------===//
88 // spirv.AccessChainOp
89 //===----------------------------------------------------------------------===//
90 
91 namespace {
92 
93 /// Combines chained `spirv::AccessChainOp` operations into one
94 /// `spirv::AccessChainOp` operation.
95 struct CombineChainedAccessChain final
96  : OpRewritePattern<spirv::AccessChainOp> {
98 
99  LogicalResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
100  PatternRewriter &rewriter) const override {
101  auto parentAccessChainOp =
102  accessChainOp.getBasePtr().getDefiningOp<spirv::AccessChainOp>();
103 
104  if (!parentAccessChainOp) {
105  return failure();
106  }
107 
108  // Combine indices.
109  SmallVector<Value, 4> indices(parentAccessChainOp.getIndices());
110  llvm::append_range(indices, accessChainOp.getIndices());
111 
112  rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
113  accessChainOp, parentAccessChainOp.getBasePtr(), indices);
114 
115  return success();
116  }
117 };
118 } // namespace
119 
120 void spirv::AccessChainOp::getCanonicalizationPatterns(
121  RewritePatternSet &results, MLIRContext *context) {
122  results.add<CombineChainedAccessChain>(context);
123 }
124 
125 //===----------------------------------------------------------------------===//
126 // spirv.IAddCarry
127 //===----------------------------------------------------------------------===//
128 
129 // We are required to use CompositeConstructOp to create a constant struct as
130 // they are not yet implemented as constant, hence we can not do so in a fold.
131 struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
133 
134  LogicalResult matchAndRewrite(spirv::IAddCarryOp op,
135  PatternRewriter &rewriter) const override {
136  Location loc = op.getLoc();
137  Value lhs = op.getOperand1();
138  Value rhs = op.getOperand2();
139  Type constituentType = lhs.getType();
140 
141  // iaddcarry (x, 0) = <0, x>
142  if (matchPattern(rhs, m_Zero())) {
143  Value constituents[2] = {rhs, lhs};
144  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
145  constituents);
146  return success();
147  }
148 
149  // According to the SPIR-V spec:
150  //
151  // Result Type must be from OpTypeStruct. The struct must have two
152  // members...
153  //
154  // Member 0 of the result gets the low-order bits (full component width) of
155  // the addition.
156  //
157  // Member 1 of the result gets the high-order (carry) bit of the result of
158  // the addition. That is, it gets the value 1 if the addition overflowed
159  // the component width, and 0 otherwise.
160  Attribute lhsAttr;
161  Attribute rhsAttr;
162  if (!matchPattern(lhs, m_Constant(&lhsAttr)) ||
163  !matchPattern(rhs, m_Constant(&rhsAttr)))
164  return failure();
165 
166  auto adds = constFoldBinaryOp<IntegerAttr>(
167  {lhsAttr, rhsAttr},
168  [](const APInt &a, const APInt &b) { return a + b; });
169  if (!adds)
170  return failure();
171 
172  auto carrys = constFoldBinaryOp<IntegerAttr>(
173  ArrayRef{adds, lhsAttr}, [](const APInt &a, const APInt &b) {
174  APInt zero = APInt::getZero(a.getBitWidth());
175  return a.ult(b) ? (zero + 1) : zero;
176  });
177 
178  if (!carrys)
179  return failure();
180 
181  Value addsVal =
182  rewriter.create<spirv::ConstantOp>(loc, constituentType, adds);
183 
184  Value carrysVal =
185  rewriter.create<spirv::ConstantOp>(loc, constituentType, carrys);
186 
187  // Create empty struct
188  Value undef = rewriter.create<spirv::UndefOp>(loc, op.getType());
189  // Fill in adds at id 0
190  Value intermediate =
191  rewriter.create<spirv::CompositeInsertOp>(loc, addsVal, undef, 0);
192  // Fill in carrys at id 1
193  rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, carrysVal,
194  intermediate, 1);
195  return success();
196  }
197 };
198 
199 void spirv::IAddCarryOp::getCanonicalizationPatterns(
201  patterns.add<IAddCarryFold>(context);
202 }
203 
204 //===----------------------------------------------------------------------===//
205 // spirv.[S|U]MulExtended
206 //===----------------------------------------------------------------------===//
207 
208 // We are required to use CompositeConstructOp to create a constant struct as
209 // they are not yet implemented as constant, hence we can not do so in a fold.
210 template <typename MulOp, bool IsSigned>
211 struct MulExtendedFold final : OpRewritePattern<MulOp> {
213 
214  LogicalResult matchAndRewrite(MulOp op,
215  PatternRewriter &rewriter) const override {
216  Location loc = op.getLoc();
217  Value lhs = op.getOperand1();
218  Value rhs = op.getOperand2();
219  Type constituentType = lhs.getType();
220 
221  // [su]mulextended (x, 0) = <0, 0>
222  if (matchPattern(rhs, m_Zero())) {
223  Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
224  Value constituents[2] = {zero, zero};
225  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
226  constituents);
227  return success();
228  }
229 
230  // According to the SPIR-V spec:
231  //
232  // Result Type must be from OpTypeStruct. The struct must have two
233  // members...
234  //
235  // Member 0 of the result gets the low-order bits of the multiplication.
236  //
237  // Member 1 of the result gets the high-order bits of the multiplication.
238  Attribute lhsAttr;
239  Attribute rhsAttr;
240  if (!matchPattern(lhs, m_Constant(&lhsAttr)) ||
241  !matchPattern(rhs, m_Constant(&rhsAttr)))
242  return failure();
243 
244  auto lowBits = constFoldBinaryOp<IntegerAttr>(
245  {lhsAttr, rhsAttr},
246  [](const APInt &a, const APInt &b) { return a * b; });
247 
248  if (!lowBits)
249  return failure();
250 
251  auto highBits = constFoldBinaryOp<IntegerAttr>(
252  {lhsAttr, rhsAttr}, [](const APInt &a, const APInt &b) {
253  if (IsSigned) {
254  return llvm::APIntOps::mulhs(a, b);
255  } else {
256  return llvm::APIntOps::mulhu(a, b);
257  }
258  });
259 
260  if (!highBits)
261  return failure();
262 
263  Value lowBitsVal =
264  rewriter.create<spirv::ConstantOp>(loc, constituentType, lowBits);
265 
266  Value highBitsVal =
267  rewriter.create<spirv::ConstantOp>(loc, constituentType, highBits);
268 
269  // Create empty struct
270  Value undef = rewriter.create<spirv::UndefOp>(loc, op.getType());
271  // Fill in lowBits at id 0
272  Value intermediate =
273  rewriter.create<spirv::CompositeInsertOp>(loc, lowBitsVal, undef, 0);
274  // Fill in highBits at id 1
275  rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, highBitsVal,
276  intermediate, 1);
277  return success();
278  }
279 };
280 
282 void spirv::SMulExtendedOp::getCanonicalizationPatterns(
284  patterns.add<SMulExtendedOpFold>(context);
285 }
286 
287 struct UMulExtendedOpXOne final : OpRewritePattern<spirv::UMulExtendedOp> {
289 
290  LogicalResult matchAndRewrite(spirv::UMulExtendedOp op,
291  PatternRewriter &rewriter) const override {
292  Location loc = op.getLoc();
293  Value lhs = op.getOperand1();
294  Value rhs = op.getOperand2();
295  Type constituentType = lhs.getType();
296 
297  // umulextended (x, 1) = <x, 0>
298  if (matchPattern(rhs, m_One())) {
299  Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
300  Value constituents[2] = {lhs, zero};
301  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
302  constituents);
303  return success();
304  }
305 
306  return failure();
307  }
308 };
309 
311 void spirv::UMulExtendedOp::getCanonicalizationPatterns(
314 }
315 
316 //===----------------------------------------------------------------------===//
317 // spirv.UMod
318 //===----------------------------------------------------------------------===//
319 
320 // Input:
321 // %0 = spirv.UMod %arg0, %const32 : i32
322 // %1 = spirv.UMod %0, %const4 : i32
323 // Output:
324 // %0 = spirv.UMod %arg0, %const32 : i32
325 // %1 = spirv.UMod %arg0, %const4 : i32
326 
327 // The transformation is only applied if one divisor is a multiple of the other.
328 
329 struct UModSimplification final : OpRewritePattern<spirv::UModOp> {
331 
332  LogicalResult matchAndRewrite(spirv::UModOp umodOp,
333  PatternRewriter &rewriter) const override {
334  auto prevUMod = umodOp.getOperand(0).getDefiningOp<spirv::UModOp>();
335  if (!prevUMod)
336  return failure();
337 
338  TypedAttr prevValue;
339  TypedAttr currValue;
340  if (!matchPattern(prevUMod.getOperand(1), m_Constant(&prevValue)) ||
341  !matchPattern(umodOp.getOperand(1), m_Constant(&currValue)))
342  return failure();
343 
344  // Ensure that previous divisor is a multiple of the current divisor. If
345  // not, fail the transformation.
346  bool isApplicable = false;
347  if (auto prevInt = dyn_cast<IntegerAttr>(prevValue)) {
348  auto currInt = cast<IntegerAttr>(currValue);
349  isApplicable = prevInt.getValue().urem(currInt.getValue()) == 0;
350  } else if (auto prevVec = dyn_cast<DenseElementsAttr>(prevValue)) {
351  auto currVec = cast<DenseElementsAttr>(currValue);
352  isApplicable = llvm::all_of(llvm::zip_equal(prevVec.getValues<APInt>(),
353  currVec.getValues<APInt>()),
354  [](const auto &pair) {
355  auto &[prev, curr] = pair;
356  return prev.urem(curr) == 0;
357  });
358  }
359 
360  if (!isApplicable)
361  return failure();
362 
363  // The transformation is safe. Replace the existing UMod operation with a
364  // new UMod operation, using the original dividend and the current divisor.
365  rewriter.replaceOpWithNewOp<spirv::UModOp>(
366  umodOp, umodOp.getType(), prevUMod.getOperand(0), umodOp.getOperand(1));
367 
368  return success();
369  }
370 };
371 
372 void spirv::UModOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
373  MLIRContext *context) {
374  patterns.insert<UModSimplification>(context);
375 }
376 
377 //===----------------------------------------------------------------------===//
378 // spirv.BitcastOp
379 //===----------------------------------------------------------------------===//
380 
381 OpFoldResult spirv::BitcastOp::fold(FoldAdaptor /*adaptor*/) {
382  Value curInput = getOperand();
383  if (getType() == curInput.getType())
384  return curInput;
385 
386  // Look through nested bitcasts.
387  if (auto prevCast = curInput.getDefiningOp<spirv::BitcastOp>()) {
388  Value prevInput = prevCast.getOperand();
389  if (prevInput.getType() == getType())
390  return prevInput;
391 
392  getOperandMutable().assign(prevInput);
393  return getResult();
394  }
395 
396  // TODO(kuhar): Consider constant-folding the operand attribute.
397  return {};
398 }
399 
400 //===----------------------------------------------------------------------===//
401 // spirv.CompositeExtractOp
402 //===----------------------------------------------------------------------===//
403 
404 OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) {
405  Value compositeOp = getComposite();
406 
407  while (auto insertOp =
408  compositeOp.getDefiningOp<spirv::CompositeInsertOp>()) {
409  if (getIndices() == insertOp.getIndices())
410  return insertOp.getObject();
411  compositeOp = insertOp.getComposite();
412  }
413 
414  if (auto constructOp =
415  compositeOp.getDefiningOp<spirv::CompositeConstructOp>()) {
416  auto type = llvm::cast<spirv::CompositeType>(constructOp.getType());
417  if (getIndices().size() == 1 &&
418  constructOp.getConstituents().size() == type.getNumElements()) {
419  auto i = llvm::cast<IntegerAttr>(*getIndices().begin());
420  if (i.getValue().getSExtValue() <
421  static_cast<int64_t>(constructOp.getConstituents().size()))
422  return constructOp.getConstituents()[i.getValue().getSExtValue()];
423  }
424  }
425 
426  auto indexVector = llvm::map_to_vector(getIndices(), [](Attribute attr) {
427  return static_cast<unsigned>(llvm::cast<IntegerAttr>(attr).getInt());
428  });
429  return extractCompositeElement(adaptor.getComposite(), indexVector);
430 }
431 
432 //===----------------------------------------------------------------------===//
433 // spirv.Constant
434 //===----------------------------------------------------------------------===//
435 
436 OpFoldResult spirv::ConstantOp::fold(FoldAdaptor /*adaptor*/) {
437  return getValue();
438 }
439 
440 //===----------------------------------------------------------------------===//
441 // spirv.IAdd
442 //===----------------------------------------------------------------------===//
443 
444 OpFoldResult spirv::IAddOp::fold(FoldAdaptor adaptor) {
445  // x + 0 = x
446  if (matchPattern(getOperand2(), m_Zero()))
447  return getOperand1();
448 
449  // According to the SPIR-V spec:
450  //
451  // The resulting value will equal the low-order N bits of the correct result
452  // R, where N is the component width and R is computed with enough precision
453  // to avoid overflow and underflow.
454  return constFoldBinaryOp<IntegerAttr>(
455  adaptor.getOperands(),
456  [](APInt a, const APInt &b) { return std::move(a) + b; });
457 }
458 
459 //===----------------------------------------------------------------------===//
460 // spirv.IMul
461 //===----------------------------------------------------------------------===//
462 
463 OpFoldResult spirv::IMulOp::fold(FoldAdaptor adaptor) {
464  // x * 0 == 0
465  if (matchPattern(getOperand2(), m_Zero()))
466  return getOperand2();
467  // x * 1 = x
468  if (matchPattern(getOperand2(), m_One()))
469  return getOperand1();
470 
471  // According to the SPIR-V spec:
472  //
473  // The resulting value will equal the low-order N bits of the correct result
474  // R, where N is the component width and R is computed with enough precision
475  // to avoid overflow and underflow.
476  return constFoldBinaryOp<IntegerAttr>(
477  adaptor.getOperands(),
478  [](const APInt &a, const APInt &b) { return a * b; });
479 }
480 
481 //===----------------------------------------------------------------------===//
482 // spirv.ISub
483 //===----------------------------------------------------------------------===//
484 
485 OpFoldResult spirv::ISubOp::fold(FoldAdaptor adaptor) {
486  // x - x = 0
487  if (getOperand1() == getOperand2())
488  return Builder(getContext()).getZeroAttr(getType());
489 
490  // According to the SPIR-V spec:
491  //
492  // The resulting value will equal the low-order N bits of the correct result
493  // R, where N is the component width and R is computed with enough precision
494  // to avoid overflow and underflow.
495  return constFoldBinaryOp<IntegerAttr>(
496  adaptor.getOperands(),
497  [](APInt a, const APInt &b) { return std::move(a) - b; });
498 }
499 
500 //===----------------------------------------------------------------------===//
501 // spirv.SDiv
502 //===----------------------------------------------------------------------===//
503 
504 OpFoldResult spirv::SDivOp::fold(FoldAdaptor adaptor) {
505  // sdiv (x, 1) = x
506  if (matchPattern(getOperand2(), m_One()))
507  return getOperand1();
508 
509  // According to the SPIR-V spec:
510  //
511  // Signed-integer division of Operand 1 divided by Operand 2.
512  // Results are computed per component. Behavior is undefined if Operand 2 is
513  // 0. Behavior is undefined if Operand 2 is -1 and Operand 1 is the minimum
514  // representable value for the operands' type, causing signed overflow.
515  //
516  // So don't fold during undefined behavior.
517  bool div0OrOverflow = false;
518  auto res = constFoldBinaryOp<IntegerAttr>(
519  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
520  if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
521  div0OrOverflow = true;
522  return a;
523  }
524  return a.sdiv(b);
525  });
526  return div0OrOverflow ? Attribute() : res;
527 }
528 
529 //===----------------------------------------------------------------------===//
530 // spirv.SMod
531 //===----------------------------------------------------------------------===//
532 
533 OpFoldResult spirv::SModOp::fold(FoldAdaptor adaptor) {
534  // smod (x, 1) = 0
535  if (matchPattern(getOperand2(), m_One()))
536  return Builder(getContext()).getZeroAttr(getType());
537 
538  // According to SPIR-V spec:
539  //
540  // Signed remainder operation for the remainder whose sign matches the sign
541  // of Operand 2. Behavior is undefined if Operand 2 is 0. Behavior is
542  // undefined if Operand 2 is -1 and Operand 1 is the minimum representable
543  // value for the operands' type, causing signed overflow. Otherwise, the
544  // result is the remainder r of Operand 1 divided by Operand 2 where if
545  // r ≠ 0, the sign of r is the same as the sign of Operand 2.
546  //
547  // So don't fold during undefined behavior
548  bool div0OrOverflow = false;
549  auto res = constFoldBinaryOp<IntegerAttr>(
550  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
551  if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
552  div0OrOverflow = true;
553  return a;
554  }
555  APInt c = a.abs().urem(b.abs());
556  if (c.isZero())
557  return c;
558  if (b.isNegative()) {
559  APInt zero = APInt::getZero(c.getBitWidth());
560  return a.isNegative() ? (zero - c) : (b + c);
561  }
562  return a.isNegative() ? (b - c) : c;
563  });
564  return div0OrOverflow ? Attribute() : res;
565 }
566 
567 //===----------------------------------------------------------------------===//
568 // spirv.SRem
569 //===----------------------------------------------------------------------===//
570 
571 OpFoldResult spirv::SRemOp::fold(FoldAdaptor adaptor) {
572  // x % 1 = 0
573  if (matchPattern(getOperand2(), m_One()))
574  return Builder(getContext()).getZeroAttr(getType());
575 
576  // According to SPIR-V spec:
577  //
578  // Signed remainder operation for the remainder whose sign matches the sign
579  // of Operand 1. Behavior is undefined if Operand 2 is 0. Behavior is
580  // undefined if Operand 2 is -1 and Operand 1 is the minimum representable
581  // value for the operands' type, causing signed overflow. Otherwise, the
582  // result is the remainder r of Operand 1 divided by Operand 2 where if
583  // r ≠ 0, the sign of r is the same as the sign of Operand 1.
584 
585  // Don't fold if it would do undefined behavior.
586  bool div0OrOverflow = false;
587  auto res = constFoldBinaryOp<IntegerAttr>(
588  adaptor.getOperands(), [&](APInt a, const APInt &b) {
589  if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
590  div0OrOverflow = true;
591  return a;
592  }
593  return a.srem(b);
594  });
595  return div0OrOverflow ? Attribute() : res;
596 }
597 
598 //===----------------------------------------------------------------------===//
599 // spirv.UDiv
600 //===----------------------------------------------------------------------===//
601 
602 OpFoldResult spirv::UDivOp::fold(FoldAdaptor adaptor) {
603  // udiv (x, 1) = x
604  if (matchPattern(getOperand2(), m_One()))
605  return getOperand1();
606 
607  // According to the SPIR-V spec:
608  //
609  // Unsigned-integer division of Operand 1 divided by Operand 2. Behavior is
610  // undefined if Operand 2 is 0.
611  //
612  // So don't fold during undefined behavior.
613  bool div0 = false;
614  auto res = constFoldBinaryOp<IntegerAttr>(
615  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
616  if (div0 || b.isZero()) {
617  div0 = true;
618  return a;
619  }
620  return a.udiv(b);
621  });
622  return div0 ? Attribute() : res;
623 }
624 
625 //===----------------------------------------------------------------------===//
626 // spirv.UMod
627 //===----------------------------------------------------------------------===//
628 
629 OpFoldResult spirv::UModOp::fold(FoldAdaptor adaptor) {
630  // umod (x, 1) = 0
631  if (matchPattern(getOperand2(), m_One()))
632  return Builder(getContext()).getZeroAttr(getType());
633 
634  // According to the SPIR-V spec:
635  //
636  // Unsigned modulo operation of Operand 1 modulo Operand 2. Behavior is
637  // undefined if Operand 2 is 0.
638  //
639  // So don't fold during undefined behavior.
640  bool div0 = false;
641  auto res = constFoldBinaryOp<IntegerAttr>(
642  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
643  if (div0 || b.isZero()) {
644  div0 = true;
645  return a;
646  }
647  return a.urem(b);
648  });
649  return div0 ? Attribute() : res;
650 }
651 
652 //===----------------------------------------------------------------------===//
653 // spirv.SNegate
654 //===----------------------------------------------------------------------===//
655 
656 OpFoldResult spirv::SNegateOp::fold(FoldAdaptor adaptor) {
657  // -(-x) = 0 - (0 - x) = x
658  auto op = getOperand();
659  if (auto negateOp = op.getDefiningOp<spirv::SNegateOp>())
660  return negateOp->getOperand(0);
661 
662  // According to the SPIR-V spec:
663  //
664  // Signed-integer subtract of Operand from zero.
665  return constFoldUnaryOp<IntegerAttr>(
666  adaptor.getOperands(), [](const APInt &a) {
667  APInt zero = APInt::getZero(a.getBitWidth());
668  return zero - a;
669  });
670 }
671 
672 //===----------------------------------------------------------------------===//
673 // spirv.NotOp
674 //===----------------------------------------------------------------------===//
675 
676 OpFoldResult spirv::NotOp::fold(spirv::NotOp::FoldAdaptor adaptor) {
677  // !(!x) = x
678  auto op = getOperand();
679  if (auto notOp = op.getDefiningOp<spirv::NotOp>())
680  return notOp->getOperand(0);
681 
682  // According to the SPIR-V spec:
683  //
684  // Complement the bits of Operand.
685  return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(), [&](APInt a) {
686  a.flipAllBits();
687  return a;
688  });
689 }
690 
691 //===----------------------------------------------------------------------===//
692 // spirv.LogicalAnd
693 //===----------------------------------------------------------------------===//
694 
695 OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) {
696  if (std::optional<bool> rhs =
697  getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
698  // x && true = x
699  if (*rhs)
700  return getOperand1();
701 
702  // x && false = false
703  if (!*rhs)
704  return adaptor.getOperand2();
705  }
706 
707  return Attribute();
708 }
709 
710 //===----------------------------------------------------------------------===//
711 // spirv.LogicalEqualOp
712 //===----------------------------------------------------------------------===//
713 
715 spirv::LogicalEqualOp::fold(spirv::LogicalEqualOp::FoldAdaptor adaptor) {
716  // x == x -> true
717  if (getOperand1() == getOperand2()) {
718  auto trueAttr = BoolAttr::get(getContext(), true);
719  if (isa<IntegerType>(getType()))
720  return trueAttr;
721  if (auto vecTy = dyn_cast<VectorType>(getType()))
722  return SplatElementsAttr::get(vecTy, trueAttr);
723  }
724 
725  return constFoldBinaryOp<IntegerAttr>(
726  adaptor.getOperands(), [](const APInt &a, const APInt &b) {
727  return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);
728  });
729 }
730 
731 //===----------------------------------------------------------------------===//
732 // spirv.LogicalNotEqualOp
733 //===----------------------------------------------------------------------===//
734 
735 OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) {
736  if (std::optional<bool> rhs =
737  getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
738  // x != false -> x
739  if (!rhs.value())
740  return getOperand1();
741  }
742 
743  // x == x -> false
744  if (getOperand1() == getOperand2()) {
745  auto falseAttr = BoolAttr::get(getContext(), false);
746  if (isa<IntegerType>(getType()))
747  return falseAttr;
748  if (auto vecTy = dyn_cast<VectorType>(getType()))
749  return SplatElementsAttr::get(vecTy, falseAttr);
750  }
751 
752  return constFoldBinaryOp<IntegerAttr>(
753  adaptor.getOperands(), [](const APInt &a, const APInt &b) {
754  return a == b ? APInt::getZero(1) : APInt::getAllOnes(1);
755  });
756 }
757 
758 //===----------------------------------------------------------------------===//
759 // spirv.LogicalNot
760 //===----------------------------------------------------------------------===//
761 
762 OpFoldResult spirv::LogicalNotOp::fold(FoldAdaptor adaptor) {
763  // !(!x) = x
764  auto op = getOperand();
765  if (auto notOp = op.getDefiningOp<spirv::LogicalNotOp>())
766  return notOp->getOperand(0);
767 
768  // According to the SPIR-V spec:
769  //
770  // Complement the bits of Operand.
771  return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(),
772  [](const APInt &a) {
773  APInt zero = APInt::getZero(1);
774  return a == 1 ? zero : (zero + 1);
775  });
776 }
777 
778 void spirv::LogicalNotOp::getCanonicalizationPatterns(
779  RewritePatternSet &results, MLIRContext *context) {
780  results
781  .add<ConvertLogicalNotOfIEqual, ConvertLogicalNotOfINotEqual,
782  ConvertLogicalNotOfLogicalEqual, ConvertLogicalNotOfLogicalNotEqual>(
783  context);
784 }
785 
786 //===----------------------------------------------------------------------===//
787 // spirv.LogicalOr
788 //===----------------------------------------------------------------------===//
789 
790 OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
791  if (auto rhs = getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
792  if (*rhs) {
793  // x || true = true
794  return adaptor.getOperand2();
795  }
796 
797  if (!*rhs) {
798  // x || false = x
799  return getOperand1();
800  }
801  }
802 
803  return Attribute();
804 }
805 
806 //===----------------------------------------------------------------------===//
807 // spirv.SelectOp
808 //===----------------------------------------------------------------------===//
809 
810 OpFoldResult spirv::SelectOp::fold(FoldAdaptor adaptor) {
811  // spirv.Select _ x x -> x
812  Value trueVals = getTrueValue();
813  Value falseVals = getFalseValue();
814  if (trueVals == falseVals)
815  return trueVals;
816 
817  ArrayRef<Attribute> operands = adaptor.getOperands();
818 
819  // spirv.Select true x y -> x
820  // spirv.Select false x y -> y
821  if (auto boolAttr = getScalarOrSplatBoolAttr(operands[0]))
822  return *boolAttr ? trueVals : falseVals;
823 
824  // Check that all the operands are constant
825  if (!operands[0] || !operands[1] || !operands[2])
826  return Attribute();
827 
828  // Note: getScalarOrSplatBoolAttr will always return a boolAttr if we are in
829  // the scalar case. Hence, we are only required to consider the case of
830  // DenseElementsAttr in foldSelectOp.
831  auto condAttrs = dyn_cast<DenseElementsAttr>(operands[0]);
832  auto trueAttrs = dyn_cast<DenseElementsAttr>(operands[1]);
833  auto falseAttrs = dyn_cast<DenseElementsAttr>(operands[2]);
834  if (!condAttrs || !trueAttrs || !falseAttrs)
835  return Attribute();
836 
837  auto elementResults = llvm::to_vector<4>(trueAttrs.getValues<Attribute>());
838  auto iters = llvm::zip_equal(elementResults, condAttrs.getValues<BoolAttr>(),
839  falseAttrs.getValues<Attribute>());
840  for (auto [result, cond, falseRes] : iters) {
841  if (!cond.getValue())
842  result = falseRes;
843  }
844 
845  auto resultType = trueAttrs.getType();
846  return DenseElementsAttr::get(cast<ShapedType>(resultType), elementResults);
847 }
848 
849 //===----------------------------------------------------------------------===//
850 // spirv.IEqualOp
851 //===----------------------------------------------------------------------===//
852 
853 OpFoldResult spirv::IEqualOp::fold(spirv::IEqualOp::FoldAdaptor adaptor) {
854  // x == x -> true
855  if (getOperand1() == getOperand2()) {
856  auto trueAttr = BoolAttr::get(getContext(), true);
857  if (isa<IntegerType>(getType()))
858  return trueAttr;
859  if (auto vecTy = dyn_cast<VectorType>(getType()))
860  return SplatElementsAttr::get(vecTy, trueAttr);
861  }
862 
863  return constFoldBinaryOp<IntegerAttr>(
864  adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
865  return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);
866  });
867 }
868 
869 //===----------------------------------------------------------------------===//
870 // spirv.INotEqualOp
871 //===----------------------------------------------------------------------===//
872 
873 OpFoldResult spirv::INotEqualOp::fold(spirv::INotEqualOp::FoldAdaptor adaptor) {
874  // x == x -> false
875  if (getOperand1() == getOperand2()) {
876  auto falseAttr = BoolAttr::get(getContext(), false);
877  if (isa<IntegerType>(getType()))
878  return falseAttr;
879  if (auto vecTy = dyn_cast<VectorType>(getType()))
880  return SplatElementsAttr::get(vecTy, falseAttr);
881  }
882 
883  return constFoldBinaryOp<IntegerAttr>(
884  adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
885  return a == b ? APInt::getZero(1) : APInt::getAllOnes(1);
886  });
887 }
888 
889 //===----------------------------------------------------------------------===//
890 // spirv.SGreaterThan
891 //===----------------------------------------------------------------------===//
892 
894 spirv::SGreaterThanOp::fold(spirv::SGreaterThanOp::FoldAdaptor adaptor) {
895  // x == x -> false
896  if (getOperand1() == getOperand2()) {
897  auto falseAttr = BoolAttr::get(getContext(), false);
898  if (isa<IntegerType>(getType()))
899  return falseAttr;
900  if (auto vecTy = dyn_cast<VectorType>(getType()))
901  return SplatElementsAttr::get(vecTy, falseAttr);
902  }
903 
904  return constFoldBinaryOp<IntegerAttr>(
905  adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
906  return a.sgt(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
907  });
908 }
909 
910 //===----------------------------------------------------------------------===//
911 // spirv.SGreaterThanEqual
912 //===----------------------------------------------------------------------===//
913 
914 OpFoldResult spirv::SGreaterThanEqualOp::fold(
915  spirv::SGreaterThanEqualOp::FoldAdaptor adaptor) {
916  // x == x -> true
917  if (getOperand1() == getOperand2()) {
918  auto trueAttr = BoolAttr::get(getContext(), true);
919  if (isa<IntegerType>(getType()))
920  return trueAttr;
921  if (auto vecTy = dyn_cast<VectorType>(getType()))
922  return SplatElementsAttr::get(vecTy, trueAttr);
923  }
924 
925  return constFoldBinaryOp<IntegerAttr>(
926  adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
927  return a.sge(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
928  });
929 }
930 
931 //===----------------------------------------------------------------------===//
932 // spirv.UGreaterThan
933 //===----------------------------------------------------------------------===//
934 
936 spirv::UGreaterThanOp::fold(spirv::UGreaterThanOp::FoldAdaptor adaptor) {
937  // x == x -> false
938  if (getOperand1() == getOperand2()) {
939  auto falseAttr = BoolAttr::get(getContext(), false);
940  if (isa<IntegerType>(getType()))
941  return falseAttr;
942  if (auto vecTy = dyn_cast<VectorType>(getType()))
943  return SplatElementsAttr::get(vecTy, falseAttr);
944  }
945 
946  return constFoldBinaryOp<IntegerAttr>(
947  adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
948  return a.ugt(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
949  });
950 }
951 
952 //===----------------------------------------------------------------------===//
953 // spirv.UGreaterThanEqual
954 //===----------------------------------------------------------------------===//
955 
956 OpFoldResult spirv::UGreaterThanEqualOp::fold(
957  spirv::UGreaterThanEqualOp::FoldAdaptor adaptor) {
958  // x == x -> true
959  if (getOperand1() == getOperand2()) {
960  auto trueAttr = BoolAttr::get(getContext(), true);
961  if (isa<IntegerType>(getType()))
962  return trueAttr;
963  if (auto vecTy = dyn_cast<VectorType>(getType()))
964  return SplatElementsAttr::get(vecTy, trueAttr);
965  }
966 
967  return constFoldBinaryOp<IntegerAttr>(
968  adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
969  return a.uge(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
970  });
971 }
972 
973 //===----------------------------------------------------------------------===//
974 // spirv.SLessThan
975 //===----------------------------------------------------------------------===//
976 
977 OpFoldResult spirv::SLessThanOp::fold(spirv::SLessThanOp::FoldAdaptor adaptor) {
978  // x == x -> false
979  if (getOperand1() == getOperand2()) {
980  auto falseAttr = BoolAttr::get(getContext(), false);
981  if (isa<IntegerType>(getType()))
982  return falseAttr;
983  if (auto vecTy = dyn_cast<VectorType>(getType()))
984  return SplatElementsAttr::get(vecTy, falseAttr);
985  }
986 
987  return constFoldBinaryOp<IntegerAttr>(
988  adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
989  return a.slt(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
990  });
991 }
992 
993 //===----------------------------------------------------------------------===//
994 // spirv.SLessThanEqual
995 //===----------------------------------------------------------------------===//
996 
998 spirv::SLessThanEqualOp::fold(spirv::SLessThanEqualOp::FoldAdaptor adaptor) {
999  // x == x -> true
1000  if (getOperand1() == getOperand2()) {
1001  auto trueAttr = BoolAttr::get(getContext(), true);
1002  if (isa<IntegerType>(getType()))
1003  return trueAttr;
1004  if (auto vecTy = dyn_cast<VectorType>(getType()))
1005  return SplatElementsAttr::get(vecTy, trueAttr);
1006  }
1007 
1008  return constFoldBinaryOp<IntegerAttr>(
1009  adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
1010  return a.sle(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
1011  });
1012 }
1013 
1014 //===----------------------------------------------------------------------===//
1015 // spirv.ULessThan
1016 //===----------------------------------------------------------------------===//
1017 
1018 OpFoldResult spirv::ULessThanOp::fold(spirv::ULessThanOp::FoldAdaptor adaptor) {
1019  // x == x -> false
1020  if (getOperand1() == getOperand2()) {
1021  auto falseAttr = BoolAttr::get(getContext(), false);
1022  if (isa<IntegerType>(getType()))
1023  return falseAttr;
1024  if (auto vecTy = dyn_cast<VectorType>(getType()))
1025  return SplatElementsAttr::get(vecTy, falseAttr);
1026  }
1027 
1028  return constFoldBinaryOp<IntegerAttr>(
1029  adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
1030  return a.ult(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
1031  });
1032 }
1033 
1034 //===----------------------------------------------------------------------===//
1035 // spirv.ULessThanEqual
1036 //===----------------------------------------------------------------------===//
1037 
1039 spirv::ULessThanEqualOp::fold(spirv::ULessThanEqualOp::FoldAdaptor adaptor) {
1040  // x == x -> true
1041  if (getOperand1() == getOperand2()) {
1042  auto trueAttr = BoolAttr::get(getContext(), true);
1043  if (isa<IntegerType>(getType()))
1044  return trueAttr;
1045  if (auto vecTy = dyn_cast<VectorType>(getType()))
1046  return SplatElementsAttr::get(vecTy, trueAttr);
1047  }
1048 
1049  return constFoldBinaryOp<IntegerAttr>(
1050  adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
1051  return a.ule(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
1052  });
1053 }
1054 
1055 //===----------------------------------------------------------------------===//
1056 // spirv.ShiftLeftLogical
1057 //===----------------------------------------------------------------------===//
1058 
1059 OpFoldResult spirv::ShiftLeftLogicalOp::fold(
1060  spirv::ShiftLeftLogicalOp::FoldAdaptor adaptor) {
1061  // x << 0 -> x
1062  if (matchPattern(adaptor.getOperand2(), m_Zero())) {
1063  return getOperand1();
1064  }
1065 
1066  // Unfortunately due to below undefined behaviour can't fold 0 for Base.
1067 
1068  // Results are computed per component, and within each component, per bit...
1069  //
1070  // The result is undefined if Shift is greater than or equal to the bit width
1071  // of the components of Base.
1072  //
1073  // So we can use the APInt << method, but don't fold if undefined behaviour.
1074  bool shiftToLarge = false;
1075  auto res = constFoldBinaryOp<IntegerAttr>(
1076  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
1077  if (shiftToLarge || b.uge(a.getBitWidth())) {
1078  shiftToLarge = true;
1079  return a;
1080  }
1081  return a << b;
1082  });
1083  return shiftToLarge ? Attribute() : res;
1084 }
1085 
1086 //===----------------------------------------------------------------------===//
1087 // spirv.ShiftRightArithmetic
1088 //===----------------------------------------------------------------------===//
1089 
1090 OpFoldResult spirv::ShiftRightArithmeticOp::fold(
1091  spirv::ShiftRightArithmeticOp::FoldAdaptor adaptor) {
1092  // x >> 0 -> x
1093  if (matchPattern(adaptor.getOperand2(), m_Zero())) {
1094  return getOperand1();
1095  }
1096 
1097  // Unfortunately due to below undefined behaviour can't fold 0, -1 for Base.
1098 
1099  // Results are computed per component, and within each component, per bit...
1100  //
1101  // The result is undefined if Shift is greater than or equal to the bit width
1102  // of the components of Base.
1103  //
1104  // So we can use the APInt ashr method, but don't fold if undefined behaviour.
1105  bool shiftToLarge = false;
1106  auto res = constFoldBinaryOp<IntegerAttr>(
1107  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
1108  if (shiftToLarge || b.uge(a.getBitWidth())) {
1109  shiftToLarge = true;
1110  return a;
1111  }
1112  return a.ashr(b);
1113  });
1114  return shiftToLarge ? Attribute() : res;
1115 }
1116 
1117 //===----------------------------------------------------------------------===//
1118 // spirv.ShiftRightLogical
1119 //===----------------------------------------------------------------------===//
1120 
1121 OpFoldResult spirv::ShiftRightLogicalOp::fold(
1122  spirv::ShiftRightLogicalOp::FoldAdaptor adaptor) {
1123  // x >> 0 -> x
1124  if (matchPattern(adaptor.getOperand2(), m_Zero())) {
1125  return getOperand1();
1126  }
1127 
1128  // Unfortunately due to below undefined behaviour can't fold 0 for Base.
1129 
1130  // Results are computed per component, and within each component, per bit...
1131  //
1132  // The result is undefined if Shift is greater than or equal to the bit width
1133  // of the components of Base.
1134  //
1135  // So we can use the APInt lshr method, but don't fold if undefined behaviour.
1136  bool shiftToLarge = false;
1137  auto res = constFoldBinaryOp<IntegerAttr>(
1138  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
1139  if (shiftToLarge || b.uge(a.getBitWidth())) {
1140  shiftToLarge = true;
1141  return a;
1142  }
1143  return a.lshr(b);
1144  });
1145  return shiftToLarge ? Attribute() : res;
1146 }
1147 
1148 //===----------------------------------------------------------------------===//
1149 // spirv.BitwiseAndOp
1150 //===----------------------------------------------------------------------===//
1151 
1153 spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) {
1154  // x & x -> x
1155  if (getOperand1() == getOperand2()) {
1156  return getOperand1();
1157  }
1158 
1159  APInt rhsMask;
1160  if (matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask))) {
1161  // x & 0 -> 0
1162  if (rhsMask.isZero())
1163  return getOperand2();
1164 
1165  // x & <all ones> -> x
1166  if (rhsMask.isAllOnes())
1167  return getOperand1();
1168 
1169  // (UConvert x : iN to iK) & <mask with N low bits set> -> UConvert x
1170  if (auto zext = getOperand1().getDefiningOp<spirv::UConvertOp>()) {
1171  int valueBits =
1173  if (rhsMask.zextOrTrunc(valueBits).isAllOnes())
1174  return getOperand1();
1175  }
1176  }
1177 
1178  // According to the SPIR-V spec:
1179  //
1180  // Type is a scalar or vector of integer type.
1181  // Results are computed per component, and within each component, per bit.
1182  // So we can use the APInt & method.
1183  return constFoldBinaryOp<IntegerAttr>(
1184  adaptor.getOperands(),
1185  [](const APInt &a, const APInt &b) { return a & b; });
1186 }
1187 
1188 //===----------------------------------------------------------------------===//
1189 // spirv.BitwiseOrOp
1190 //===----------------------------------------------------------------------===//
1191 
1192 OpFoldResult spirv::BitwiseOrOp::fold(spirv::BitwiseOrOp::FoldAdaptor adaptor) {
1193  // x | x -> x
1194  if (getOperand1() == getOperand2()) {
1195  return getOperand1();
1196  }
1197 
1198  APInt rhsMask;
1199  if (matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask))) {
1200  // x | 0 -> x
1201  if (rhsMask.isZero())
1202  return getOperand1();
1203 
1204  // x | <all ones> -> <all ones>
1205  if (rhsMask.isAllOnes())
1206  return getOperand2();
1207  }
1208 
1209  // According to the SPIR-V spec:
1210  //
1211  // Type is a scalar or vector of integer type.
1212  // Results are computed per component, and within each component, per bit.
1213  // So we can use the APInt | method.
1214  return constFoldBinaryOp<IntegerAttr>(
1215  adaptor.getOperands(),
1216  [](const APInt &a, const APInt &b) { return a | b; });
1217 }
1218 
1219 //===----------------------------------------------------------------------===//
1220 // spirv.BitwiseXorOp
1221 //===----------------------------------------------------------------------===//
1222 
1224 spirv::BitwiseXorOp::fold(spirv::BitwiseXorOp::FoldAdaptor adaptor) {
1225  // x ^ 0 -> x
1226  if (matchPattern(adaptor.getOperand2(), m_Zero())) {
1227  return getOperand1();
1228  }
1229 
1230  // x ^ x -> 0
1231  if (getOperand1() == getOperand2())
1232  return Builder(getContext()).getZeroAttr(getType());
1233 
1234  // According to the SPIR-V spec:
1235  //
1236  // Type is a scalar or vector of integer type.
1237  // Results are computed per component, and within each component, per bit.
1238  // So we can use the APInt ^ method.
1239  return constFoldBinaryOp<IntegerAttr>(
1240  adaptor.getOperands(),
1241  [](const APInt &a, const APInt &b) { return a ^ b; });
1242 }
1243 
1244 //===----------------------------------------------------------------------===//
1245 // spirv.mlir.selection
1246 //===----------------------------------------------------------------------===//
1247 
1248 namespace {
1249 // Blocks from the given `spirv.mlir.selection` operation must satisfy the
1250 // following layout:
1251 //
1252 // +-----------------------------------------------+
1253 // | header block |
1254 // | spirv.BranchConditionalOp %cond, ^case0, ^case1 |
1255 // +-----------------------------------------------+
1256 // / \
1257 // ...
1258 //
1259 //
1260 // +------------------------+ +------------------------+
1261 // | case #0 | | case #1 |
1262 // | spirv.Store %ptr %value0 | | spirv.Store %ptr %value1 |
1263 // | spirv.Branch ^merge | | spirv.Branch ^merge |
1264 // +------------------------+ +------------------------+
1265 //
1266 //
1267 // ...
1268 // \ /
1269 // v
1270 // +-------------+
1271 // | merge block |
1272 // +-------------+
1273 //
1274 struct ConvertSelectionOpToSelect final : OpRewritePattern<spirv::SelectionOp> {
1276 
1277  LogicalResult matchAndRewrite(spirv::SelectionOp selectionOp,
1278  PatternRewriter &rewriter) const override {
1279  Operation *op = selectionOp.getOperation();
1280  Region &body = op->getRegion(0);
1281  // Verifier allows an empty region for `spirv.mlir.selection`.
1282  if (body.empty()) {
1283  return failure();
1284  }
1285 
1286  // Check that region consists of 4 blocks:
1287  // header block, `true` block, `false` block and merge block.
1288  if (llvm::range_size(body) != 4) {
1289  return failure();
1290  }
1291 
1292  Block *headerBlock = selectionOp.getHeaderBlock();
1293  if (!onlyContainsBranchConditionalOp(headerBlock)) {
1294  return failure();
1295  }
1296 
1297  auto brConditionalOp =
1298  cast<spirv::BranchConditionalOp>(headerBlock->front());
1299 
1300  Block *trueBlock = brConditionalOp.getSuccessor(0);
1301  Block *falseBlock = brConditionalOp.getSuccessor(1);
1302  Block *mergeBlock = selectionOp.getMergeBlock();
1303 
1304  if (failed(canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)))
1305  return failure();
1306 
1307  Value trueValue = getSrcValue(trueBlock);
1308  Value falseValue = getSrcValue(falseBlock);
1309  Value ptrValue = getDstPtr(trueBlock);
1310  auto storeOpAttributes =
1311  cast<spirv::StoreOp>(trueBlock->front())->getAttrs();
1312 
1313  auto selectOp = rewriter.create<spirv::SelectOp>(
1314  selectionOp.getLoc(), trueValue.getType(),
1315  brConditionalOp.getCondition(), trueValue, falseValue);
1316  rewriter.create<spirv::StoreOp>(selectOp.getLoc(), ptrValue,
1317  selectOp.getResult(), storeOpAttributes);
1318 
1319  // `spirv.mlir.selection` is not needed anymore.
1320  rewriter.eraseOp(op);
1321  return success();
1322  }
1323 
1324 private:
1325  // Checks that given blocks follow the following rules:
1326  // 1. Each conditional block consists of two operations, the first operation
1327  // is a `spirv.Store` and the last operation is a `spirv.Branch`.
1328  // 2. Each `spirv.Store` uses the same pointer and the same memory attributes.
1329  // 3. A control flow goes into the given merge block from the given
1330  // conditional blocks.
1331  LogicalResult canCanonicalizeSelection(Block *trueBlock, Block *falseBlock,
1332  Block *mergeBlock) const;
1333 
1334  bool onlyContainsBranchConditionalOp(Block *block) const {
1335  return llvm::hasSingleElement(*block) &&
1336  isa<spirv::BranchConditionalOp>(block->front());
1337  }
1338 
1339  bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs) const {
1340  return lhs->getDiscardableAttrDictionary() ==
1341  rhs->getDiscardableAttrDictionary() &&
1342  lhs.getProperties() == rhs.getProperties();
1343  }
1344 
1345  // Returns a source value for the given block.
1346  Value getSrcValue(Block *block) const {
1347  auto storeOp = cast<spirv::StoreOp>(block->front());
1348  return storeOp.getValue();
1349  }
1350 
1351  // Returns a destination value for the given block.
1352  Value getDstPtr(Block *block) const {
1353  auto storeOp = cast<spirv::StoreOp>(block->front());
1354  return storeOp.getPtr();
1355  }
1356 };
1357 
1358 LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
1359  Block *trueBlock, Block *falseBlock, Block *mergeBlock) const {
1360  // Each block must consists of 2 operations.
1361  if (llvm::range_size(*trueBlock) != 2 || llvm::range_size(*falseBlock) != 2) {
1362  return failure();
1363  }
1364 
1365  auto trueBrStoreOp = dyn_cast<spirv::StoreOp>(trueBlock->front());
1366  auto trueBrBranchOp =
1367  dyn_cast<spirv::BranchOp>(*std::next(trueBlock->begin()));
1368  auto falseBrStoreOp = dyn_cast<spirv::StoreOp>(falseBlock->front());
1369  auto falseBrBranchOp =
1370  dyn_cast<spirv::BranchOp>(*std::next(falseBlock->begin()));
1371 
1372  if (!trueBrStoreOp || !trueBrBranchOp || !falseBrStoreOp ||
1373  !falseBrBranchOp) {
1374  return failure();
1375  }
1376 
1377  // Checks that given type is valid for `spirv.SelectOp`.
1378  // According to SPIR-V spec:
1379  // "Before version 1.4, Result Type must be a pointer, scalar, or vector.
1380  // Starting with version 1.4, Result Type can additionally be a composite type
1381  // other than a vector."
1382  bool isScalarOrVector =
1383  llvm::cast<spirv::SPIRVType>(trueBrStoreOp.getValue().getType())
1384  .isScalarOrVector();
1385 
1386  // Check that each `spirv.Store` uses the same pointer, memory access
1387  // attributes and a valid type of the value.
1388  if ((trueBrStoreOp.getPtr() != falseBrStoreOp.getPtr()) ||
1389  !isSameAttrList(trueBrStoreOp, falseBrStoreOp) || !isScalarOrVector) {
1390  return failure();
1391  }
1392 
1393  if ((trueBrBranchOp->getSuccessor(0) != mergeBlock) ||
1394  (falseBrBranchOp->getSuccessor(0) != mergeBlock)) {
1395  return failure();
1396  }
1397 
1398  return success();
1399 }
1400 } // namespace
1401 
1402 void spirv::SelectionOp::getCanonicalizationPatterns(RewritePatternSet &results,
1403  MLIRContext *context) {
1404  results.add<ConvertSelectionOpToSelect>(context);
1405 }
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static uint64_t zext(uint32_t arg)
static MLIRContext * getContext(OpFoldResult val)
static Attribute extractCompositeElement(Attribute composite, ArrayRef< unsigned > indices)
MulExtendedFold< spirv::UMulExtendedOp, false > UMulExtendedOpFold
static std::optional< bool > getScalarOrSplatBoolAttr(Attribute attr)
Returns the boolean value under the hood if the given boolAttr is a scalar or splat vector bool const...
static bool isDivZeroOrOverflow(const APInt &a, const APInt &b)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation & front()
Definition: Block.h:153
iterator begin()
Definition: Block.h:143
Block * getSuccessor(unsigned i)
Definition: Block.cpp:261
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:322
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:455
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:686
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:811
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:527
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:442
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:478
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
LogicalResult matchAndRewrite(spirv::IAddCarryOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(MulOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::UModOp umodOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::UMulExtendedOp op, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Wrapper around the RewritePattern method that passes the derived op type.
Definition: PatternMatch.h:297