MLIR  19.0.0git
SPIRVCanonicalization.cpp
Go to the documentation of this file.
1 //===- SPIRVCanonicalization.cpp - MLIR SPIR-V canonicalization patterns --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the folders and canonicalization patterns for SPIR-V ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include <optional>
14 #include <utility>
15 
17 
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SmallVectorExtras.h"
26 
27 using namespace mlir;
28 
29 //===----------------------------------------------------------------------===//
30 // Common utility functions
31 //===----------------------------------------------------------------------===//
32 
33 /// Returns the boolean value under the hood if the given `boolAttr` is a scalar
34 /// or splat vector bool constant.
35 static std::optional<bool> getScalarOrSplatBoolAttr(Attribute attr) {
36  if (!attr)
37  return std::nullopt;
38 
39  if (auto boolAttr = llvm::dyn_cast<BoolAttr>(attr))
40  return boolAttr.getValue();
41  if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(attr))
42  if (splatAttr.getElementType().isInteger(1))
43  return splatAttr.getSplatValue<bool>();
44  return std::nullopt;
45 }
46 
47 // Extracts an element from the given `composite` by following the given
48 // `indices`. Returns a null Attribute if error happens.
50  ArrayRef<unsigned> indices) {
51  // Check that given composite is a constant.
52  if (!composite)
53  return {};
54  // Return composite itself if we reach the end of the index chain.
55  if (indices.empty())
56  return composite;
57 
58  if (auto vector = llvm::dyn_cast<ElementsAttr>(composite)) {
59  assert(indices.size() == 1 && "must have exactly one index for a vector");
60  return vector.getValues<Attribute>()[indices[0]];
61  }
62 
63  if (auto array = llvm::dyn_cast<ArrayAttr>(composite)) {
64  assert(!indices.empty() && "must have at least one index for an array");
65  return extractCompositeElement(array.getValue()[indices[0]],
66  indices.drop_front());
67  }
68 
69  return {};
70 }
71 
72 static bool isDivZeroOrOverflow(const APInt &a, const APInt &b) {
73  bool div0 = b.isZero();
74  bool overflow = a.isMinSignedValue() && b.isAllOnes();
75 
76  return div0 || overflow;
77 }
78 
79 //===----------------------------------------------------------------------===//
80 // TableGen'erated canonicalizers
81 //===----------------------------------------------------------------------===//
82 
83 namespace {
84 #include "SPIRVCanonicalization.inc"
85 } // namespace
86 
87 //===----------------------------------------------------------------------===//
88 // spirv.AccessChainOp
89 //===----------------------------------------------------------------------===//
90 
91 namespace {
92 
93 /// Combines chained `spirv::AccessChainOp` operations into one
94 /// `spirv::AccessChainOp` operation.
95 struct CombineChainedAccessChain final
96  : OpRewritePattern<spirv::AccessChainOp> {
98 
99  LogicalResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
100  PatternRewriter &rewriter) const override {
101  auto parentAccessChainOp =
102  accessChainOp.getBasePtr().getDefiningOp<spirv::AccessChainOp>();
103 
104  if (!parentAccessChainOp) {
105  return failure();
106  }
107 
108  // Combine indices.
109  SmallVector<Value, 4> indices(parentAccessChainOp.getIndices());
110  llvm::append_range(indices, accessChainOp.getIndices());
111 
112  rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
113  accessChainOp, parentAccessChainOp.getBasePtr(), indices);
114 
115  return success();
116  }
117 };
118 } // namespace
119 
120 void spirv::AccessChainOp::getCanonicalizationPatterns(
121  RewritePatternSet &results, MLIRContext *context) {
122  results.add<CombineChainedAccessChain>(context);
123 }
124 
125 //===----------------------------------------------------------------------===//
126 // spirv.IAddCarry
127 //===----------------------------------------------------------------------===//
128 
129 // We are required to use CompositeConstructOp to create a constant struct as
130 // they are not yet implemented as constant, hence we can not do so in a fold.
131 struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
133 
134  LogicalResult matchAndRewrite(spirv::IAddCarryOp op,
135  PatternRewriter &rewriter) const override {
136  Location loc = op.getLoc();
137  Value lhs = op.getOperand1();
138  Value rhs = op.getOperand2();
139  Type constituentType = lhs.getType();
140 
141  // iaddcarry (x, 0) = <0, x>
142  if (matchPattern(rhs, m_Zero())) {
143  Value constituents[2] = {rhs, lhs};
144  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
145  constituents);
146  return success();
147  }
148 
149  // According to the SPIR-V spec:
150  //
151  // Result Type must be from OpTypeStruct. The struct must have two
152  // members...
153  //
154  // Member 0 of the result gets the low-order bits (full component width) of
155  // the addition.
156  //
157  // Member 1 of the result gets the high-order (carry) bit of the result of
158  // the addition. That is, it gets the value 1 if the addition overflowed
159  // the component width, and 0 otherwise.
160  Attribute lhsAttr;
161  Attribute rhsAttr;
162  if (!matchPattern(lhs, m_Constant(&lhsAttr)) ||
163  !matchPattern(rhs, m_Constant(&rhsAttr)))
164  return failure();
165 
166  auto adds = constFoldBinaryOp<IntegerAttr>(
167  {lhsAttr, rhsAttr},
168  [](const APInt &a, const APInt &b) { return a + b; });
169  if (!adds)
170  return failure();
171 
172  auto carrys = constFoldBinaryOp<IntegerAttr>(
173  ArrayRef{adds, lhsAttr}, [](const APInt &a, const APInt &b) {
174  APInt zero = APInt::getZero(a.getBitWidth());
175  return a.ult(b) ? (zero + 1) : zero;
176  });
177 
178  if (!carrys)
179  return failure();
180 
181  Value addsVal =
182  rewriter.create<spirv::ConstantOp>(loc, constituentType, adds);
183 
184  Value carrysVal =
185  rewriter.create<spirv::ConstantOp>(loc, constituentType, carrys);
186 
187  // Create empty struct
188  Value undef = rewriter.create<spirv::UndefOp>(loc, op.getType());
189  // Fill in adds at id 0
190  Value intermediate =
191  rewriter.create<spirv::CompositeInsertOp>(loc, addsVal, undef, 0);
192  // Fill in carrys at id 1
193  rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, carrysVal,
194  intermediate, 1);
195  return success();
196  }
197 };
198 
199 void spirv::IAddCarryOp::getCanonicalizationPatterns(
200  RewritePatternSet &patterns, MLIRContext *context) {
201  patterns.add<IAddCarryFold>(context);
202 }
203 
204 //===----------------------------------------------------------------------===//
205 // spirv.[S|U]MulExtended
206 //===----------------------------------------------------------------------===//
207 
208 // We are required to use CompositeConstructOp to create a constant struct as
209 // they are not yet implemented as constant, hence we can not do so in a fold.
210 template <typename MulOp, bool IsSigned>
211 struct MulExtendedFold final : OpRewritePattern<MulOp> {
213 
215  PatternRewriter &rewriter) const override {
216  Location loc = op.getLoc();
217  Value lhs = op.getOperand1();
218  Value rhs = op.getOperand2();
219  Type constituentType = lhs.getType();
220 
221  // [su]mulextended (x, 0) = <0, 0>
222  if (matchPattern(rhs, m_Zero())) {
223  Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
224  Value constituents[2] = {zero, zero};
225  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
226  constituents);
227  return success();
228  }
229 
230  // According to the SPIR-V spec:
231  //
232  // Result Type must be from OpTypeStruct. The struct must have two
233  // members...
234  //
235  // Member 0 of the result gets the low-order bits of the multiplication.
236  //
237  // Member 1 of the result gets the high-order bits of the multiplication.
238  Attribute lhsAttr;
239  Attribute rhsAttr;
240  if (!matchPattern(lhs, m_Constant(&lhsAttr)) ||
241  !matchPattern(rhs, m_Constant(&rhsAttr)))
242  return failure();
243 
244  auto lowBits = constFoldBinaryOp<IntegerAttr>(
245  {lhsAttr, rhsAttr},
246  [](const APInt &a, const APInt &b) { return a * b; });
247 
248  if (!lowBits)
249  return failure();
250 
251  auto highBits = constFoldBinaryOp<IntegerAttr>(
252  {lhsAttr, rhsAttr}, [](const APInt &a, const APInt &b) {
253  if (IsSigned) {
254  return llvm::APIntOps::mulhs(a, b);
255  } else {
256  return llvm::APIntOps::mulhu(a, b);
257  }
258  });
259 
260  if (!highBits)
261  return failure();
262 
263  Value lowBitsVal =
264  rewriter.create<spirv::ConstantOp>(loc, constituentType, lowBits);
265 
266  Value highBitsVal =
267  rewriter.create<spirv::ConstantOp>(loc, constituentType, highBits);
268 
269  // Create empty struct
270  Value undef = rewriter.create<spirv::UndefOp>(loc, op.getType());
271  // Fill in lowBits at id 0
272  Value intermediate =
273  rewriter.create<spirv::CompositeInsertOp>(loc, lowBitsVal, undef, 0);
274  // Fill in highBits at id 1
275  rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, highBitsVal,
276  intermediate, 1);
277  return success();
278  }
279 };
280 
282 void spirv::SMulExtendedOp::getCanonicalizationPatterns(
283  RewritePatternSet &patterns, MLIRContext *context) {
284  patterns.add<SMulExtendedOpFold>(context);
285 }
286 
287 struct UMulExtendedOpXOne final : OpRewritePattern<spirv::UMulExtendedOp> {
289 
290  LogicalResult matchAndRewrite(spirv::UMulExtendedOp op,
291  PatternRewriter &rewriter) const override {
292  Location loc = op.getLoc();
293  Value lhs = op.getOperand1();
294  Value rhs = op.getOperand2();
295  Type constituentType = lhs.getType();
296 
297  // umulextended (x, 1) = <x, 0>
298  if (matchPattern(rhs, m_One())) {
299  Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
300  Value constituents[2] = {lhs, zero};
301  rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
302  constituents);
303  return success();
304  }
305 
306  return failure();
307  }
308 };
309 
311 void spirv::UMulExtendedOp::getCanonicalizationPatterns(
312  RewritePatternSet &patterns, MLIRContext *context) {
313  patterns.add<UMulExtendedOpFold, UMulExtendedOpXOne>(context);
314 }
315 
316 //===----------------------------------------------------------------------===//
317 // spirv.UMod
318 //===----------------------------------------------------------------------===//
319 
320 // Input:
321 // %0 = spirv.UMod %arg0, %const32 : i32
322 // %1 = spirv.UMod %0, %const4 : i32
323 // Output:
324 // %0 = spirv.UMod %arg0, %const32 : i32
325 // %1 = spirv.UMod %arg0, %const4 : i32
326 
327 // The transformation is only applied if one divisor is a multiple of the other.
328 
329 // TODO(https://github.com/llvm/llvm-project/issues/63174): Add support for vector constants
330 struct UModSimplification final : OpRewritePattern<spirv::UModOp> {
332 
333  LogicalResult matchAndRewrite(spirv::UModOp umodOp,
334  PatternRewriter &rewriter) const override {
335  auto prevUMod = umodOp.getOperand(0).getDefiningOp<spirv::UModOp>();
336  if (!prevUMod)
337  return failure();
338 
339  IntegerAttr prevValue;
340  IntegerAttr currValue;
341  if (!matchPattern(prevUMod.getOperand(1), m_Constant(&prevValue)) ||
342  !matchPattern(umodOp.getOperand(1), m_Constant(&currValue)))
343  return failure();
344 
345  APInt prevConstValue = prevValue.getValue();
346  APInt currConstValue = currValue.getValue();
347 
348  // Ensure that one divisor is a multiple of the other. If not, fail the
349  // transformation.
350  if (prevConstValue.urem(currConstValue) != 0 &&
351  currConstValue.urem(prevConstValue) != 0)
352  return failure();
353 
354  // The transformation is safe. Replace the existing UMod operation with a
355  // new UMod operation, using the original dividend and the current divisor.
356  rewriter.replaceOpWithNewOp<spirv::UModOp>(
357  umodOp, umodOp.getType(), prevUMod.getOperand(0), umodOp.getOperand(1));
358 
359  return success();
360  }
361 };
362 
363 void spirv::UModOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
364  MLIRContext *context) {
365  patterns.insert<UModSimplification>(context);
366 }
367 
368 //===----------------------------------------------------------------------===//
369 // spirv.BitcastOp
370 //===----------------------------------------------------------------------===//
371 
372 OpFoldResult spirv::BitcastOp::fold(FoldAdaptor /*adaptor*/) {
373  Value curInput = getOperand();
374  if (getType() == curInput.getType())
375  return curInput;
376 
377  // Look through nested bitcasts.
378  if (auto prevCast = curInput.getDefiningOp<spirv::BitcastOp>()) {
379  Value prevInput = prevCast.getOperand();
380  if (prevInput.getType() == getType())
381  return prevInput;
382 
383  getOperandMutable().assign(prevInput);
384  return getResult();
385  }
386 
387  // TODO(kuhar): Consider constant-folding the operand attribute.
388  return {};
389 }
390 
391 //===----------------------------------------------------------------------===//
392 // spirv.CompositeExtractOp
393 //===----------------------------------------------------------------------===//
394 
395 OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) {
396  Value compositeOp = getComposite();
397 
398  while (auto insertOp =
399  compositeOp.getDefiningOp<spirv::CompositeInsertOp>()) {
400  if (getIndices() == insertOp.getIndices())
401  return insertOp.getObject();
402  compositeOp = insertOp.getComposite();
403  }
404 
405  if (auto constructOp =
406  compositeOp.getDefiningOp<spirv::CompositeConstructOp>()) {
407  auto type = llvm::cast<spirv::CompositeType>(constructOp.getType());
408  if (getIndices().size() == 1 &&
409  constructOp.getConstituents().size() == type.getNumElements()) {
410  auto i = llvm::cast<IntegerAttr>(*getIndices().begin());
411  if (i.getValue().getSExtValue() <
412  static_cast<int64_t>(constructOp.getConstituents().size()))
413  return constructOp.getConstituents()[i.getValue().getSExtValue()];
414  }
415  }
416 
417  auto indexVector = llvm::map_to_vector(getIndices(), [](Attribute attr) {
418  return static_cast<unsigned>(llvm::cast<IntegerAttr>(attr).getInt());
419  });
420  return extractCompositeElement(adaptor.getComposite(), indexVector);
421 }
422 
423 //===----------------------------------------------------------------------===//
424 // spirv.Constant
425 //===----------------------------------------------------------------------===//
426 
427 OpFoldResult spirv::ConstantOp::fold(FoldAdaptor /*adaptor*/) {
428  return getValue();
429 }
430 
431 //===----------------------------------------------------------------------===//
432 // spirv.IAdd
433 //===----------------------------------------------------------------------===//
434 
435 OpFoldResult spirv::IAddOp::fold(FoldAdaptor adaptor) {
436  // x + 0 = x
437  if (matchPattern(getOperand2(), m_Zero()))
438  return getOperand1();
439 
440  // According to the SPIR-V spec:
441  //
442  // The resulting value will equal the low-order N bits of the correct result
443  // R, where N is the component width and R is computed with enough precision
444  // to avoid overflow and underflow.
445  return constFoldBinaryOp<IntegerAttr>(
446  adaptor.getOperands(),
447  [](APInt a, const APInt &b) { return std::move(a) + b; });
448 }
449 
450 //===----------------------------------------------------------------------===//
451 // spirv.IMul
452 //===----------------------------------------------------------------------===//
453 
454 OpFoldResult spirv::IMulOp::fold(FoldAdaptor adaptor) {
455  // x * 0 == 0
456  if (matchPattern(getOperand2(), m_Zero()))
457  return getOperand2();
458  // x * 1 = x
459  if (matchPattern(getOperand2(), m_One()))
460  return getOperand1();
461 
462  // According to the SPIR-V spec:
463  //
464  // The resulting value will equal the low-order N bits of the correct result
465  // R, where N is the component width and R is computed with enough precision
466  // to avoid overflow and underflow.
467  return constFoldBinaryOp<IntegerAttr>(
468  adaptor.getOperands(),
469  [](const APInt &a, const APInt &b) { return a * b; });
470 }
471 
472 //===----------------------------------------------------------------------===//
473 // spirv.ISub
474 //===----------------------------------------------------------------------===//
475 
476 OpFoldResult spirv::ISubOp::fold(FoldAdaptor adaptor) {
477  // x - x = 0
478  if (getOperand1() == getOperand2())
479  return Builder(getContext()).getIntegerAttr(getType(), 0);
480 
481  // According to the SPIR-V spec:
482  //
483  // The resulting value will equal the low-order N bits of the correct result
484  // R, where N is the component width and R is computed with enough precision
485  // to avoid overflow and underflow.
486  return constFoldBinaryOp<IntegerAttr>(
487  adaptor.getOperands(),
488  [](APInt a, const APInt &b) { return std::move(a) - b; });
489 }
490 
491 //===----------------------------------------------------------------------===//
492 // spirv.SDiv
493 //===----------------------------------------------------------------------===//
494 
495 OpFoldResult spirv::SDivOp::fold(FoldAdaptor adaptor) {
496  // sdiv (x, 1) = x
497  if (matchPattern(getOperand2(), m_One()))
498  return getOperand1();
499 
500  // According to the SPIR-V spec:
501  //
502  // Signed-integer division of Operand 1 divided by Operand 2.
503  // Results are computed per component. Behavior is undefined if Operand 2 is
504  // 0. Behavior is undefined if Operand 2 is -1 and Operand 1 is the minimum
505  // representable value for the operands' type, causing signed overflow.
506  //
507  // So don't fold during undefined behavior.
508  bool div0OrOverflow = false;
509  auto res = constFoldBinaryOp<IntegerAttr>(
510  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
511  if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
512  div0OrOverflow = true;
513  return a;
514  }
515  return a.sdiv(b);
516  });
517  return div0OrOverflow ? Attribute() : res;
518 }
519 
520 //===----------------------------------------------------------------------===//
521 // spirv.SMod
522 //===----------------------------------------------------------------------===//
523 
524 OpFoldResult spirv::SModOp::fold(FoldAdaptor adaptor) {
525  // smod (x, 1) = 0
526  if (matchPattern(getOperand2(), m_One()))
527  return Builder(getContext()).getZeroAttr(getType());
528 
529  // According to SPIR-V spec:
530  //
531  // Signed remainder operation for the remainder whose sign matches the sign
532  // of Operand 2. Behavior is undefined if Operand 2 is 0. Behavior is
533  // undefined if Operand 2 is -1 and Operand 1 is the minimum representable
534  // value for the operands' type, causing signed overflow. Otherwise, the
535  // result is the remainder r of Operand 1 divided by Operand 2 where if
536  // r ≠ 0, the sign of r is the same as the sign of Operand 2.
537  //
538  // So don't fold during undefined behavior
539  bool div0OrOverflow = false;
540  auto res = constFoldBinaryOp<IntegerAttr>(
541  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
542  if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
543  div0OrOverflow = true;
544  return a;
545  }
546  APInt c = a.abs().urem(b.abs());
547  if (c.isZero())
548  return c;
549  if (b.isNegative()) {
550  APInt zero = APInt::getZero(c.getBitWidth());
551  return a.isNegative() ? (zero - c) : (b + c);
552  }
553  return a.isNegative() ? (b - c) : c;
554  });
555  return div0OrOverflow ? Attribute() : res;
556 }
557 
558 //===----------------------------------------------------------------------===//
559 // spirv.SRem
560 //===----------------------------------------------------------------------===//
561 
562 OpFoldResult spirv::SRemOp::fold(FoldAdaptor adaptor) {
563  // x % 1 = 0
564  if (matchPattern(getOperand2(), m_One()))
565  return Builder(getContext()).getZeroAttr(getType());
566 
567  // According to SPIR-V spec:
568  //
569  // Signed remainder operation for the remainder whose sign matches the sign
570  // of Operand 1. Behavior is undefined if Operand 2 is 0. Behavior is
571  // undefined if Operand 2 is -1 and Operand 1 is the minimum representable
572  // value for the operands' type, causing signed overflow. Otherwise, the
573  // result is the remainder r of Operand 1 divided by Operand 2 where if
574  // r ≠ 0, the sign of r is the same as the sign of Operand 1.
575 
576  // Don't fold if it would do undefined behavior.
577  bool div0OrOverflow = false;
578  auto res = constFoldBinaryOp<IntegerAttr>(
579  adaptor.getOperands(), [&](APInt a, const APInt &b) {
580  if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
581  div0OrOverflow = true;
582  return a;
583  }
584  return a.srem(b);
585  });
586  return div0OrOverflow ? Attribute() : res;
587 }
588 
589 //===----------------------------------------------------------------------===//
590 // spirv.UDiv
591 //===----------------------------------------------------------------------===//
592 
593 OpFoldResult spirv::UDivOp::fold(FoldAdaptor adaptor) {
594  // udiv (x, 1) = x
595  if (matchPattern(getOperand2(), m_One()))
596  return getOperand1();
597 
598  // According to the SPIR-V spec:
599  //
600  // Unsigned-integer division of Operand 1 divided by Operand 2. Behavior is
601  // undefined if Operand 2 is 0.
602  //
603  // So don't fold during undefined behavior.
604  bool div0 = false;
605  auto res = constFoldBinaryOp<IntegerAttr>(
606  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
607  if (div0 || b.isZero()) {
608  div0 = true;
609  return a;
610  }
611  return a.udiv(b);
612  });
613  return div0 ? Attribute() : res;
614 }
615 
616 //===----------------------------------------------------------------------===//
617 // spirv.UMod
618 //===----------------------------------------------------------------------===//
619 
620 OpFoldResult spirv::UModOp::fold(FoldAdaptor adaptor) {
621  // umod (x, 1) = 0
622  if (matchPattern(getOperand2(), m_One()))
623  return Builder(getContext()).getZeroAttr(getType());
624 
625  // According to the SPIR-V spec:
626  //
627  // Unsigned modulo operation of Operand 1 modulo Operand 2. Behavior is
628  // undefined if Operand 2 is 0.
629  //
630  // So don't fold during undefined behavior.
631  bool div0 = false;
632  auto res = constFoldBinaryOp<IntegerAttr>(
633  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
634  if (div0 || b.isZero()) {
635  div0 = true;
636  return a;
637  }
638  return a.urem(b);
639  });
640  return div0 ? Attribute() : res;
641 }
642 
643 //===----------------------------------------------------------------------===//
644 // spirv.SNegate
645 //===----------------------------------------------------------------------===//
646 
647 OpFoldResult spirv::SNegateOp::fold(FoldAdaptor adaptor) {
648  // -(-x) = 0 - (0 - x) = x
649  auto op = getOperand();
650  if (auto negateOp = op.getDefiningOp<spirv::SNegateOp>())
651  return negateOp->getOperand(0);
652 
653  // According to the SPIR-V spec:
654  //
655  // Signed-integer subtract of Operand from zero.
656  return constFoldUnaryOp<IntegerAttr>(
657  adaptor.getOperands(), [](const APInt &a) {
658  APInt zero = APInt::getZero(a.getBitWidth());
659  return zero - a;
660  });
661 }
662 
663 //===----------------------------------------------------------------------===//
664 // spirv.NotOp
665 //===----------------------------------------------------------------------===//
666 
667 OpFoldResult spirv::NotOp::fold(spirv::NotOp::FoldAdaptor adaptor) {
668  // !(!x) = x
669  auto op = getOperand();
670  if (auto notOp = op.getDefiningOp<spirv::NotOp>())
671  return notOp->getOperand(0);
672 
673  // According to the SPIR-V spec:
674  //
675  // Complement the bits of Operand.
676  return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(), [&](APInt a) {
677  a.flipAllBits();
678  return a;
679  });
680 }
681 
682 //===----------------------------------------------------------------------===//
683 // spirv.LogicalAnd
684 //===----------------------------------------------------------------------===//
685 
686 OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) {
687  if (std::optional<bool> rhs =
688  getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
689  // x && true = x
690  if (*rhs)
691  return getOperand1();
692 
693  // x && false = false
694  if (!*rhs)
695  return adaptor.getOperand2();
696  }
697 
698  return Attribute();
699 }
700 
701 //===----------------------------------------------------------------------===//
702 // spirv.LogicalEqualOp
703 //===----------------------------------------------------------------------===//
704 
706 spirv::LogicalEqualOp::fold(spirv::LogicalEqualOp::FoldAdaptor adaptor) {
707  // x == x -> true
708  if (getOperand1() == getOperand2()) {
709  auto trueAttr = BoolAttr::get(getContext(), true);
710  if (isa<IntegerType>(getType()))
711  return trueAttr;
712  if (auto vecTy = dyn_cast<VectorType>(getType()))
713  return SplatElementsAttr::get(vecTy, trueAttr);
714  }
715 
716  return constFoldBinaryOp<IntegerAttr>(
717  adaptor.getOperands(), [](const APInt &a, const APInt &b) {
718  return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);
719  });
720 }
721 
722 //===----------------------------------------------------------------------===//
723 // spirv.LogicalNotEqualOp
724 //===----------------------------------------------------------------------===//
725 
726 OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) {
727  if (std::optional<bool> rhs =
728  getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
729  // x != false -> x
730  if (!rhs.value())
731  return getOperand1();
732  }
733 
734  // x == x -> false
735  if (getOperand1() == getOperand2()) {
736  auto falseAttr = BoolAttr::get(getContext(), false);
737  if (isa<IntegerType>(getType()))
738  return falseAttr;
739  if (auto vecTy = dyn_cast<VectorType>(getType()))
740  return SplatElementsAttr::get(vecTy, falseAttr);
741  }
742 
743  return constFoldBinaryOp<IntegerAttr>(
744  adaptor.getOperands(), [](const APInt &a, const APInt &b) {
745  return a == b ? APInt::getZero(1) : APInt::getAllOnes(1);
746  });
747 }
748 
749 //===----------------------------------------------------------------------===//
750 // spirv.LogicalNot
751 //===----------------------------------------------------------------------===//
752 
753 OpFoldResult spirv::LogicalNotOp::fold(FoldAdaptor adaptor) {
754  // !(!x) = x
755  auto op = getOperand();
756  if (auto notOp = op.getDefiningOp<spirv::LogicalNotOp>())
757  return notOp->getOperand(0);
758 
759  // According to the SPIR-V spec:
760  //
761  // Complement the bits of Operand.
762  return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(),
763  [](const APInt &a) {
764  APInt zero = APInt::getZero(1);
765  return a == 1 ? zero : (zero + 1);
766  });
767 }
768 
769 void spirv::LogicalNotOp::getCanonicalizationPatterns(
770  RewritePatternSet &results, MLIRContext *context) {
771  results
772  .add<ConvertLogicalNotOfIEqual, ConvertLogicalNotOfINotEqual,
773  ConvertLogicalNotOfLogicalEqual, ConvertLogicalNotOfLogicalNotEqual>(
774  context);
775 }
776 
777 //===----------------------------------------------------------------------===//
778 // spirv.LogicalOr
779 //===----------------------------------------------------------------------===//
780 
781 OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
782  if (auto rhs = getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
783  if (*rhs) {
784  // x || true = true
785  return adaptor.getOperand2();
786  }
787 
788  if (!*rhs) {
789  // x || false = x
790  return getOperand1();
791  }
792  }
793 
794  return Attribute();
795 }
796 
797 //===----------------------------------------------------------------------===//
798 // spirv.SelectOp
799 //===----------------------------------------------------------------------===//
800 
801 OpFoldResult spirv::SelectOp::fold(FoldAdaptor adaptor) {
802  // spirv.Select _ x x -> x
803  Value trueVals = getTrueValue();
804  Value falseVals = getFalseValue();
805  if (trueVals == falseVals)
806  return trueVals;
807 
808  ArrayRef<Attribute> operands = adaptor.getOperands();
809 
810  // spirv.Select true x y -> x
811  // spirv.Select false x y -> y
812  if (auto boolAttr = getScalarOrSplatBoolAttr(operands[0]))
813  return *boolAttr ? trueVals : falseVals;
814 
815  // Check that all the operands are constant
816  if (!operands[0] || !operands[1] || !operands[2])
817  return Attribute();
818 
819  // Note: getScalarOrSplatBoolAttr will always return a boolAttr if we are in
820  // the scalar case. Hence, we are only required to consider the case of
821  // DenseElementsAttr in foldSelectOp.
822  auto condAttrs = dyn_cast<DenseElementsAttr>(operands[0]);
823  auto trueAttrs = dyn_cast<DenseElementsAttr>(operands[1]);
824  auto falseAttrs = dyn_cast<DenseElementsAttr>(operands[2]);
825  if (!condAttrs || !trueAttrs || !falseAttrs)
826  return Attribute();
827 
828  auto elementResults = llvm::to_vector<4>(trueAttrs.getValues<Attribute>());
829  auto iters = llvm::zip_equal(elementResults, condAttrs.getValues<BoolAttr>(),
830  falseAttrs.getValues<Attribute>());
831  for (auto [result, cond, falseRes] : iters) {
832  if (!cond.getValue())
833  result = falseRes;
834  }
835 
836  auto resultType = trueAttrs.getType();
837  return DenseElementsAttr::get(cast<ShapedType>(resultType), elementResults);
838 }
839 
840 //===----------------------------------------------------------------------===//
841 // spirv.IEqualOp
842 //===----------------------------------------------------------------------===//
843 
844 OpFoldResult spirv::IEqualOp::fold(spirv::IEqualOp::FoldAdaptor adaptor) {
845  // x == x -> true
846  if (getOperand1() == getOperand2()) {
847  auto trueAttr = BoolAttr::get(getContext(), true);
848  if (isa<IntegerType>(getType()))
849  return trueAttr;
850  if (auto vecTy = dyn_cast<VectorType>(getType()))
851  return SplatElementsAttr::get(vecTy, trueAttr);
852  }
853 
854  return constFoldBinaryOp<IntegerAttr>(
855  adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
856  return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);
857  });
858 }
859 
860 //===----------------------------------------------------------------------===//
861 // spirv.INotEqualOp
862 //===----------------------------------------------------------------------===//
863 
864 OpFoldResult spirv::INotEqualOp::fold(spirv::INotEqualOp::FoldAdaptor adaptor) {
865  // x == x -> false
866  if (getOperand1() == getOperand2()) {
867  auto falseAttr = BoolAttr::get(getContext(), false);
868  if (isa<IntegerType>(getType()))
869  return falseAttr;
870  if (auto vecTy = dyn_cast<VectorType>(getType()))
871  return SplatElementsAttr::get(vecTy, falseAttr);
872  }
873 
874  return constFoldBinaryOp<IntegerAttr>(
875  adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
876  return a == b ? APInt::getZero(1) : APInt::getAllOnes(1);
877  });
878 }
879 
880 //===----------------------------------------------------------------------===//
881 // spirv.SGreaterThan
882 //===----------------------------------------------------------------------===//
883 
885 spirv::SGreaterThanOp::fold(spirv::SGreaterThanOp::FoldAdaptor adaptor) {
886  // x == x -> false
887  if (getOperand1() == getOperand2()) {
888  auto falseAttr = BoolAttr::get(getContext(), false);
889  if (isa<IntegerType>(getType()))
890  return falseAttr;
891  if (auto vecTy = dyn_cast<VectorType>(getType()))
892  return SplatElementsAttr::get(vecTy, falseAttr);
893  }
894 
895  return constFoldBinaryOp<IntegerAttr>(
896  adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
897  return a.sgt(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
898  });
899 }
900 
901 //===----------------------------------------------------------------------===//
902 // spirv.SGreaterThanEqual
903 //===----------------------------------------------------------------------===//
904 
905 OpFoldResult spirv::SGreaterThanEqualOp::fold(
906  spirv::SGreaterThanEqualOp::FoldAdaptor adaptor) {
907  // x == x -> true
908  if (getOperand1() == getOperand2()) {
909  auto trueAttr = BoolAttr::get(getContext(), true);
910  if (isa<IntegerType>(getType()))
911  return trueAttr;
912  if (auto vecTy = dyn_cast<VectorType>(getType()))
913  return SplatElementsAttr::get(vecTy, trueAttr);
914  }
915 
916  return constFoldBinaryOp<IntegerAttr>(
917  adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
918  return a.sge(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
919  });
920 }
921 
922 //===----------------------------------------------------------------------===//
923 // spirv.UGreaterThan
924 //===----------------------------------------------------------------------===//
925 
927 spirv::UGreaterThanOp::fold(spirv::UGreaterThanOp::FoldAdaptor adaptor) {
928  // x == x -> false
929  if (getOperand1() == getOperand2()) {
930  auto falseAttr = BoolAttr::get(getContext(), false);
931  if (isa<IntegerType>(getType()))
932  return falseAttr;
933  if (auto vecTy = dyn_cast<VectorType>(getType()))
934  return SplatElementsAttr::get(vecTy, falseAttr);
935  }
936 
937  return constFoldBinaryOp<IntegerAttr>(
938  adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
939  return a.ugt(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
940  });
941 }
942 
943 //===----------------------------------------------------------------------===//
944 // spirv.UGreaterThanEqual
945 //===----------------------------------------------------------------------===//
946 
947 OpFoldResult spirv::UGreaterThanEqualOp::fold(
948  spirv::UGreaterThanEqualOp::FoldAdaptor adaptor) {
949  // x == x -> true
950  if (getOperand1() == getOperand2()) {
951  auto trueAttr = BoolAttr::get(getContext(), true);
952  if (isa<IntegerType>(getType()))
953  return trueAttr;
954  if (auto vecTy = dyn_cast<VectorType>(getType()))
955  return SplatElementsAttr::get(vecTy, trueAttr);
956  }
957 
958  return constFoldBinaryOp<IntegerAttr>(
959  adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
960  return a.uge(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
961  });
962 }
963 
964 //===----------------------------------------------------------------------===//
965 // spirv.SLessThan
966 //===----------------------------------------------------------------------===//
967 
968 OpFoldResult spirv::SLessThanOp::fold(spirv::SLessThanOp::FoldAdaptor adaptor) {
969  // x == x -> false
970  if (getOperand1() == getOperand2()) {
971  auto falseAttr = BoolAttr::get(getContext(), false);
972  if (isa<IntegerType>(getType()))
973  return falseAttr;
974  if (auto vecTy = dyn_cast<VectorType>(getType()))
975  return SplatElementsAttr::get(vecTy, falseAttr);
976  }
977 
978  return constFoldBinaryOp<IntegerAttr>(
979  adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
980  return a.slt(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
981  });
982 }
983 
984 //===----------------------------------------------------------------------===//
985 // spirv.SLessThanEqual
986 //===----------------------------------------------------------------------===//
987 
989 spirv::SLessThanEqualOp::fold(spirv::SLessThanEqualOp::FoldAdaptor adaptor) {
990  // x == x -> true
991  if (getOperand1() == getOperand2()) {
992  auto trueAttr = BoolAttr::get(getContext(), true);
993  if (isa<IntegerType>(getType()))
994  return trueAttr;
995  if (auto vecTy = dyn_cast<VectorType>(getType()))
996  return SplatElementsAttr::get(vecTy, trueAttr);
997  }
998 
999  return constFoldBinaryOp<IntegerAttr>(
1000  adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
1001  return a.sle(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
1002  });
1003 }
1004 
1005 //===----------------------------------------------------------------------===//
1006 // spirv.ULessThan
1007 //===----------------------------------------------------------------------===//
1008 
1009 OpFoldResult spirv::ULessThanOp::fold(spirv::ULessThanOp::FoldAdaptor adaptor) {
1010  // x == x -> false
1011  if (getOperand1() == getOperand2()) {
1012  auto falseAttr = BoolAttr::get(getContext(), false);
1013  if (isa<IntegerType>(getType()))
1014  return falseAttr;
1015  if (auto vecTy = dyn_cast<VectorType>(getType()))
1016  return SplatElementsAttr::get(vecTy, falseAttr);
1017  }
1018 
1019  return constFoldBinaryOp<IntegerAttr>(
1020  adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
1021  return a.ult(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
1022  });
1023 }
1024 
1025 //===----------------------------------------------------------------------===//
1026 // spirv.ULessThanEqual
1027 //===----------------------------------------------------------------------===//
1028 
1030 spirv::ULessThanEqualOp::fold(spirv::ULessThanEqualOp::FoldAdaptor adaptor) {
1031  // x == x -> true
1032  if (getOperand1() == getOperand2()) {
1033  auto trueAttr = BoolAttr::get(getContext(), true);
1034  if (isa<IntegerType>(getType()))
1035  return trueAttr;
1036  if (auto vecTy = dyn_cast<VectorType>(getType()))
1037  return SplatElementsAttr::get(vecTy, trueAttr);
1038  }
1039 
1040  return constFoldBinaryOp<IntegerAttr>(
1041  adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
1042  return a.ule(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
1043  });
1044 }
1045 
1046 //===----------------------------------------------------------------------===//
1047 // spirv.ShiftLeftLogical
1048 //===----------------------------------------------------------------------===//
1049 
1050 OpFoldResult spirv::ShiftLeftLogicalOp::fold(
1051  spirv::ShiftLeftLogicalOp::FoldAdaptor adaptor) {
1052  // x << 0 -> x
1053  if (matchPattern(adaptor.getOperand2(), m_Zero())) {
1054  return getOperand1();
1055  }
1056 
1057  // Unfortunately due to below undefined behaviour can't fold 0 for Base.
1058 
1059  // Results are computed per component, and within each component, per bit...
1060  //
1061  // The result is undefined if Shift is greater than or equal to the bit width
1062  // of the components of Base.
1063  //
1064  // So we can use the APInt << method, but don't fold if undefined behaviour.
1065  bool shiftToLarge = false;
1066  auto res = constFoldBinaryOp<IntegerAttr>(
1067  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
1068  if (shiftToLarge || b.uge(a.getBitWidth())) {
1069  shiftToLarge = true;
1070  return a;
1071  }
1072  return a << b;
1073  });
1074  return shiftToLarge ? Attribute() : res;
1075 }
1076 
1077 //===----------------------------------------------------------------------===//
1078 // spirv.ShiftRightArithmetic
1079 //===----------------------------------------------------------------------===//
1080 
1081 OpFoldResult spirv::ShiftRightArithmeticOp::fold(
1082  spirv::ShiftRightArithmeticOp::FoldAdaptor adaptor) {
1083  // x >> 0 -> x
1084  if (matchPattern(adaptor.getOperand2(), m_Zero())) {
1085  return getOperand1();
1086  }
1087 
1088  // Unfortunately due to below undefined behaviour can't fold 0, -1 for Base.
1089 
1090  // Results are computed per component, and within each component, per bit...
1091  //
1092  // The result is undefined if Shift is greater than or equal to the bit width
1093  // of the components of Base.
1094  //
1095  // So we can use the APInt ashr method, but don't fold if undefined behaviour.
1096  bool shiftToLarge = false;
1097  auto res = constFoldBinaryOp<IntegerAttr>(
1098  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
1099  if (shiftToLarge || b.uge(a.getBitWidth())) {
1100  shiftToLarge = true;
1101  return a;
1102  }
1103  return a.ashr(b);
1104  });
1105  return shiftToLarge ? Attribute() : res;
1106 }
1107 
1108 //===----------------------------------------------------------------------===//
1109 // spirv.ShiftRightLogical
1110 //===----------------------------------------------------------------------===//
1111 
1112 OpFoldResult spirv::ShiftRightLogicalOp::fold(
1113  spirv::ShiftRightLogicalOp::FoldAdaptor adaptor) {
1114  // x >> 0 -> x
1115  if (matchPattern(adaptor.getOperand2(), m_Zero())) {
1116  return getOperand1();
1117  }
1118 
1119  // Unfortunately due to below undefined behaviour can't fold 0 for Base.
1120 
1121  // Results are computed per component, and within each component, per bit...
1122  //
1123  // The result is undefined if Shift is greater than or equal to the bit width
1124  // of the components of Base.
1125  //
1126  // So we can use the APInt lshr method, but don't fold if undefined behaviour.
1127  bool shiftToLarge = false;
1128  auto res = constFoldBinaryOp<IntegerAttr>(
1129  adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
1130  if (shiftToLarge || b.uge(a.getBitWidth())) {
1131  shiftToLarge = true;
1132  return a;
1133  }
1134  return a.lshr(b);
1135  });
1136  return shiftToLarge ? Attribute() : res;
1137 }
1138 
1139 //===----------------------------------------------------------------------===//
1140 // spirv.BitwiseAndOp
1141 //===----------------------------------------------------------------------===//
1142 
1144 spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) {
1145  // x & x -> x
1146  if (getOperand1() == getOperand2()) {
1147  return getOperand1();
1148  }
1149 
1150  APInt rhsMask;
1151  if (matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask))) {
1152  // x & 0 -> 0
1153  if (rhsMask.isZero())
1154  return getOperand2();
1155 
1156  // x & <all ones> -> x
1157  if (rhsMask.isAllOnes())
1158  return getOperand1();
1159 
1160  // (UConvert x : iN to iK) & <mask with N low bits set> -> UConvert x
1161  if (auto zext = getOperand1().getDefiningOp<spirv::UConvertOp>()) {
1162  int valueBits =
1164  if (rhsMask.zextOrTrunc(valueBits).isAllOnes())
1165  return getOperand1();
1166  }
1167  }
1168 
1169  // According to the SPIR-V spec:
1170  //
1171  // Type is a scalar or vector of integer type.
1172  // Results are computed per component, and within each component, per bit.
1173  // So we can use the APInt & method.
1174  return constFoldBinaryOp<IntegerAttr>(
1175  adaptor.getOperands(),
1176  [](const APInt &a, const APInt &b) { return a & b; });
1177 }
1178 
1179 //===----------------------------------------------------------------------===//
1180 // spirv.BitwiseOrOp
1181 //===----------------------------------------------------------------------===//
1182 
1183 OpFoldResult spirv::BitwiseOrOp::fold(spirv::BitwiseOrOp::FoldAdaptor adaptor) {
1184  // x | x -> x
1185  if (getOperand1() == getOperand2()) {
1186  return getOperand1();
1187  }
1188 
1189  APInt rhsMask;
1190  if (matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask))) {
1191  // x | 0 -> x
1192  if (rhsMask.isZero())
1193  return getOperand1();
1194 
1195  // x | <all ones> -> <all ones>
1196  if (rhsMask.isAllOnes())
1197  return getOperand2();
1198  }
1199 
1200  // According to the SPIR-V spec:
1201  //
1202  // Type is a scalar or vector of integer type.
1203  // Results are computed per component, and within each component, per bit.
1204  // So we can use the APInt | method.
1205  return constFoldBinaryOp<IntegerAttr>(
1206  adaptor.getOperands(),
1207  [](const APInt &a, const APInt &b) { return a | b; });
1208 }
1209 
1210 //===----------------------------------------------------------------------===//
1211 // spirv.BitwiseXorOp
1212 //===----------------------------------------------------------------------===//
1213 
1215 spirv::BitwiseXorOp::fold(spirv::BitwiseXorOp::FoldAdaptor adaptor) {
1216  // x ^ 0 -> x
1217  if (matchPattern(adaptor.getOperand2(), m_Zero())) {
1218  return getOperand1();
1219  }
1220 
1221  // x ^ x -> 0
1222  if (getOperand1() == getOperand2())
1223  return Builder(getContext()).getZeroAttr(getType());
1224 
1225  // According to the SPIR-V spec:
1226  //
1227  // Type is a scalar or vector of integer type.
1228  // Results are computed per component, and within each component, per bit.
1229  // So we can use the APInt ^ method.
1230  return constFoldBinaryOp<IntegerAttr>(
1231  adaptor.getOperands(),
1232  [](const APInt &a, const APInt &b) { return a ^ b; });
1233 }
1234 
1235 //===----------------------------------------------------------------------===//
1236 // spirv.mlir.selection
1237 //===----------------------------------------------------------------------===//
1238 
1239 namespace {
1240 // Blocks from the given `spirv.mlir.selection` operation must satisfy the
1241 // following layout:
1242 //
1243 // +-----------------------------------------------+
1244 // | header block |
1245 // | spirv.BranchConditionalOp %cond, ^case0, ^case1 |
1246 // +-----------------------------------------------+
1247 // / \
1248 // ...
1249 //
1250 //
1251 // +------------------------+ +------------------------+
1252 // | case #0 | | case #1 |
1253 // | spirv.Store %ptr %value0 | | spirv.Store %ptr %value1 |
1254 // | spirv.Branch ^merge | | spirv.Branch ^merge |
1255 // +------------------------+ +------------------------+
1256 //
1257 //
1258 // ...
1259 // \ /
1260 // v
1261 // +-------------+
1262 // | merge block |
1263 // +-------------+
1264 //
1265 struct ConvertSelectionOpToSelect final : OpRewritePattern<spirv::SelectionOp> {
1267 
1268  LogicalResult matchAndRewrite(spirv::SelectionOp selectionOp,
1269  PatternRewriter &rewriter) const override {
1270  Operation *op = selectionOp.getOperation();
1271  Region &body = op->getRegion(0);
1272  // Verifier allows an empty region for `spirv.mlir.selection`.
1273  if (body.empty()) {
1274  return failure();
1275  }
1276 
1277  // Check that region consists of 4 blocks:
1278  // header block, `true` block, `false` block and merge block.
1279  if (llvm::range_size(body) != 4) {
1280  return failure();
1281  }
1282 
1283  Block *headerBlock = selectionOp.getHeaderBlock();
1284  if (!onlyContainsBranchConditionalOp(headerBlock)) {
1285  return failure();
1286  }
1287 
1288  auto brConditionalOp =
1289  cast<spirv::BranchConditionalOp>(headerBlock->front());
1290 
1291  Block *trueBlock = brConditionalOp.getSuccessor(0);
1292  Block *falseBlock = brConditionalOp.getSuccessor(1);
1293  Block *mergeBlock = selectionOp.getMergeBlock();
1294 
1295  if (failed(canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)))
1296  return failure();
1297 
1298  Value trueValue = getSrcValue(trueBlock);
1299  Value falseValue = getSrcValue(falseBlock);
1300  Value ptrValue = getDstPtr(trueBlock);
1301  auto storeOpAttributes =
1302  cast<spirv::StoreOp>(trueBlock->front())->getAttrs();
1303 
1304  auto selectOp = rewriter.create<spirv::SelectOp>(
1305  selectionOp.getLoc(), trueValue.getType(),
1306  brConditionalOp.getCondition(), trueValue, falseValue);
1307  rewriter.create<spirv::StoreOp>(selectOp.getLoc(), ptrValue,
1308  selectOp.getResult(), storeOpAttributes);
1309 
1310  // `spirv.mlir.selection` is not needed anymore.
1311  rewriter.eraseOp(op);
1312  return success();
1313  }
1314 
1315 private:
1316  // Checks that given blocks follow the following rules:
1317  // 1. Each conditional block consists of two operations, the first operation
1318  // is a `spirv.Store` and the last operation is a `spirv.Branch`.
1319  // 2. Each `spirv.Store` uses the same pointer and the same memory attributes.
1320  // 3. A control flow goes into the given merge block from the given
1321  // conditional blocks.
1322  LogicalResult canCanonicalizeSelection(Block *trueBlock, Block *falseBlock,
1323  Block *mergeBlock) const;
1324 
1325  bool onlyContainsBranchConditionalOp(Block *block) const {
1326  return llvm::hasSingleElement(*block) &&
1327  isa<spirv::BranchConditionalOp>(block->front());
1328  }
1329 
1330  bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs) const {
1331  return lhs->getDiscardableAttrDictionary() ==
1332  rhs->getDiscardableAttrDictionary() &&
1333  lhs.getProperties() == rhs.getProperties();
1334  }
1335 
1336  // Returns a source value for the given block.
1337  Value getSrcValue(Block *block) const {
1338  auto storeOp = cast<spirv::StoreOp>(block->front());
1339  return storeOp.getValue();
1340  }
1341 
1342  // Returns a destination value for the given block.
1343  Value getDstPtr(Block *block) const {
1344  auto storeOp = cast<spirv::StoreOp>(block->front());
1345  return storeOp.getPtr();
1346  }
1347 };
1348 
1349 LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
1350  Block *trueBlock, Block *falseBlock, Block *mergeBlock) const {
1351  // Each block must consists of 2 operations.
1352  if (llvm::range_size(*trueBlock) != 2 || llvm::range_size(*falseBlock) != 2) {
1353  return failure();
1354  }
1355 
1356  auto trueBrStoreOp = dyn_cast<spirv::StoreOp>(trueBlock->front());
1357  auto trueBrBranchOp =
1358  dyn_cast<spirv::BranchOp>(*std::next(trueBlock->begin()));
1359  auto falseBrStoreOp = dyn_cast<spirv::StoreOp>(falseBlock->front());
1360  auto falseBrBranchOp =
1361  dyn_cast<spirv::BranchOp>(*std::next(falseBlock->begin()));
1362 
1363  if (!trueBrStoreOp || !trueBrBranchOp || !falseBrStoreOp ||
1364  !falseBrBranchOp) {
1365  return failure();
1366  }
1367 
1368  // Checks that given type is valid for `spirv.SelectOp`.
1369  // According to SPIR-V spec:
1370  // "Before version 1.4, Result Type must be a pointer, scalar, or vector.
1371  // Starting with version 1.4, Result Type can additionally be a composite type
1372  // other than a vector."
1373  bool isScalarOrVector =
1374  llvm::cast<spirv::SPIRVType>(trueBrStoreOp.getValue().getType())
1375  .isScalarOrVector();
1376 
1377  // Check that each `spirv.Store` uses the same pointer, memory access
1378  // attributes and a valid type of the value.
1379  if ((trueBrStoreOp.getPtr() != falseBrStoreOp.getPtr()) ||
1380  !isSameAttrList(trueBrStoreOp, falseBrStoreOp) || !isScalarOrVector) {
1381  return failure();
1382  }
1383 
1384  if ((trueBrBranchOp->getSuccessor(0) != mergeBlock) ||
1385  (falseBrBranchOp->getSuccessor(0) != mergeBlock)) {
1386  return failure();
1387  }
1388 
1389  return success();
1390 }
1391 } // namespace
1392 
1393 void spirv::SelectionOp::getCanonicalizationPatterns(RewritePatternSet &results,
1394  MLIRContext *context) {
1395  results.add<ConvertSelectionOpToSelect>(context);
1396 }
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static uint64_t zext(uint32_t arg)
static MLIRContext * getContext(OpFoldResult val)
static Attribute extractCompositeElement(Attribute composite, ArrayRef< unsigned > indices)
MulExtendedFold< spirv::UMulExtendedOp, false > UMulExtendedOpFold
static std::optional< bool > getScalarOrSplatBoolAttr(Attribute attr)
Returns the boolean value under the hood if the given boolAttr is a scalar or splat vector bool const...
static bool isDivZeroOrOverflow(const APInt &a, const APInt &b)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:30
Operation & front()
Definition: Block.h:150
iterator begin()
Definition: Block.h:140
Block * getSuccessor(unsigned i)
Definition: Block.cpp:258
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:238
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:930
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:125
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:438
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:378
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:389
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:310
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
LogicalResult matchAndRewrite(spirv::IAddCarryOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(MulOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::UModOp umodOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(spirv::UMulExtendedOp op, PatternRewriter &rewriter) const override
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final
Attempt to match against code rooted at the specified operation, which is the same operation code as ...
Definition: PatternMatch.h:329