MLIR  21.0.0git
ArithToSPIRV.cpp
Go to the documentation of this file.
1 //===- ArithToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
11 #include "../SPIRVCommon/Pattern.h"
19 #include "mlir/IR/BuiltinTypes.h"
21 #include "llvm/ADT/APInt.h"
22 #include "llvm/ADT/ArrayRef.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/MathExtras.h"
26 #include <cassert>
27 #include <memory>
28 
29 namespace mlir {
30 #define GEN_PASS_DEF_CONVERTARITHTOSPIRVPASS
31 #include "mlir/Conversion/Passes.h.inc"
32 } // namespace mlir
33 
34 #define DEBUG_TYPE "arith-to-spirv-pattern"
35 
36 using namespace mlir;
37 
38 //===----------------------------------------------------------------------===//
39 // Conversion Helpers
40 //===----------------------------------------------------------------------===//
41 
42 /// Converts the given `srcAttr` into a boolean attribute if it holds an
43 /// integral value. Returns null attribute if conversion fails.
44 static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) {
45  if (auto boolAttr = dyn_cast<BoolAttr>(srcAttr))
46  return boolAttr;
47  if (auto intAttr = dyn_cast<IntegerAttr>(srcAttr))
48  return builder.getBoolAttr(intAttr.getValue().getBoolValue());
49  return {};
50 }
51 
52 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
53 /// Returns null attribute if conversion fails.
54 static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType,
55  Builder builder) {
56  // If the source number uses less active bits than the target bitwidth, then
57  // it should be safe to convert.
58  if (srcAttr.getValue().isIntN(dstType.getWidth()))
59  return builder.getIntegerAttr(dstType, srcAttr.getInt());
60 
61  // XXX: Try again by interpreting the source number as a signed value.
62  // Although integers in the standard dialect are signless, they can represent
63  // a signed number. It's the operation decides how to interpret. This is
64  // dangerous, but it seems there is no good way of handling this if we still
65  // want to change the bitwidth. Emit a message at least.
66  if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
67  auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt());
68  LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '"
69  << dstAttr << "' for type '" << dstType << "'\n");
70  return dstAttr;
71  }
72 
73  LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr
74  << "' illegal: cannot fit into target type '"
75  << dstType << "'\n");
76  return {};
77 }
78 
79 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
80 /// Returns null attribute if `dstType` is not 32-bit or conversion fails.
81 static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
82  Builder builder) {
83  // Only support converting to float for now.
84  if (!dstType.isF32())
85  return FloatAttr();
86 
87  // Try to convert the source floating-point number to single precision.
88  APFloat dstVal = srcAttr.getValue();
89  bool losesInfo = false;
90  APFloat::opStatus status =
91  dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
92  if (status != APFloat::opOK || losesInfo) {
93  LLVM_DEBUG(llvm::dbgs()
94  << srcAttr << " illegal: cannot fit into converted type '"
95  << dstType << "'\n");
96  return FloatAttr();
97  }
98 
99  return builder.getF32FloatAttr(dstVal.convertToFloat());
100 }
101 
102 /// Returns true if the given `type` is a boolean scalar or vector type.
103 static bool isBoolScalarOrVector(Type type) {
104  assert(type && "Not a valid type");
105  if (type.isInteger(1))
106  return true;
107 
108  if (auto vecType = dyn_cast<VectorType>(type))
109  return vecType.getElementType().isInteger(1);
110 
111  return false;
112 }
113 
114 /// Creates a scalar/vector integer constant.
115 static Value getScalarOrVectorConstInt(Type type, uint64_t value,
116  OpBuilder &builder, Location loc) {
117  if (auto vectorType = dyn_cast<VectorType>(type)) {
118  Attribute element = IntegerAttr::get(vectorType.getElementType(), value);
119  auto attr = SplatElementsAttr::get(vectorType, element);
120  return builder.create<spirv::ConstantOp>(loc, vectorType, attr);
121  }
122 
123  if (auto intType = dyn_cast<IntegerType>(type))
124  return builder.create<spirv::ConstantOp>(
125  loc, type, builder.getIntegerAttr(type, value));
126 
127  return nullptr;
128 }
129 
130 /// Returns true if scalar/vector type `a` and `b` have the same number of
131 /// bitwidth.
132 static bool hasSameBitwidth(Type a, Type b) {
133  auto getNumBitwidth = [](Type type) {
134  unsigned bw = 0;
135  if (type.isIntOrFloat())
136  bw = type.getIntOrFloatBitWidth();
137  else if (auto vecType = dyn_cast<VectorType>(type))
138  bw = vecType.getElementTypeBitWidth() * vecType.getNumElements();
139  return bw;
140  };
141  unsigned aBW = getNumBitwidth(a);
142  unsigned bBW = getNumBitwidth(b);
143  return aBW != 0 && bBW != 0 && aBW == bBW;
144 }
145 
146 /// Returns a source type conversion failure for `srcType` and operation `op`.
147 static LogicalResult
149  Type srcType) {
150  return rewriter.notifyMatchFailure(
151  op->getLoc(),
152  llvm::formatv("failed to convert source type '{0}'", srcType));
153 }
154 
155 /// Returns a source type conversion failure for the result type of `op`.
156 static LogicalResult
158  assert(op->getNumResults() == 1);
159  return getTypeConversionFailure(rewriter, op, op->getResultTypes().front());
160 }
161 
162 // TODO: Move to some common place?
163 static std::string getDecorationString(spirv::Decoration decor) {
164  return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decor));
165 }
166 
167 namespace {
168 
169 /// Converts elementwise unary, binary and ternary arith operations to SPIR-V
170 /// operations. Op can potentially support overflow flags.
171 template <typename Op, typename SPIRVOp>
172 struct ElementwiseArithOpPattern final : OpConversionPattern<Op> {
174 
175  LogicalResult
176  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
177  ConversionPatternRewriter &rewriter) const override {
178  assert(adaptor.getOperands().size() <= 3);
179  auto converter = this->template getTypeConverter<SPIRVTypeConverter>();
180  Type dstType = converter->convertType(op.getType());
181  if (!dstType) {
182  return rewriter.notifyMatchFailure(
183  op->getLoc(),
184  llvm::formatv("failed to convert type {0} for SPIR-V", op.getType()));
185  }
186 
187  if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
188  !getElementTypeOrSelf(op.getType()).isIndex() &&
189  dstType != op.getType()) {
190  return op.emitError("bitwidth emulation is not implemented yet on "
191  "unsigned op pattern version");
192  }
193 
194  auto overflowFlags = arith::IntegerOverflowFlags::none;
195  if (auto overflowIface =
196  dyn_cast<arith::ArithIntegerOverflowFlagsInterface>(*op)) {
197  if (converter->getTargetEnv().allows(
198  spirv::Extension::SPV_KHR_no_integer_wrap_decoration))
199  overflowFlags = overflowIface.getOverflowAttr().getValue();
200  }
201 
202  auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
203  op, dstType, adaptor.getOperands());
204 
205  if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nsw))
206  newOp->setAttr(getDecorationString(spirv::Decoration::NoSignedWrap),
207  rewriter.getUnitAttr());
208 
209  if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nuw))
210  newOp->setAttr(getDecorationString(spirv::Decoration::NoUnsignedWrap),
211  rewriter.getUnitAttr());
212 
213  return success();
214  }
215 };
216 
217 //===----------------------------------------------------------------------===//
218 // ConstantOp
219 //===----------------------------------------------------------------------===//
220 
221 /// Converts composite arith.constant operation to spirv.Constant.
222 struct ConstantCompositeOpPattern final
225 
226  LogicalResult
227  matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
228  ConversionPatternRewriter &rewriter) const override {
229  auto srcType = dyn_cast<ShapedType>(constOp.getType());
230  if (!srcType || srcType.getNumElements() == 1)
231  return failure();
232 
233  // arith.constant should only have vector or tensor types. This is a MLIR
234  // wide problem at the moment.
235  if (!isa<VectorType, RankedTensorType>(srcType))
236  return rewriter.notifyMatchFailure(constOp, "unsupported ShapedType");
237 
238  Type dstType = getTypeConverter()->convertType(srcType);
239  if (!dstType)
240  return failure();
241 
242  // Import the resource into the IR to make use of the special handling of
243  // element types later on.
244  mlir::DenseElementsAttr dstElementsAttr;
245  if (auto denseElementsAttr =
246  dyn_cast<DenseElementsAttr>(constOp.getValue())) {
247  dstElementsAttr = denseElementsAttr;
248  } else if (auto resourceAttr =
249  dyn_cast<DenseResourceElementsAttr>(constOp.getValue())) {
250 
251  AsmResourceBlob *blob = resourceAttr.getRawHandle().getBlob();
252  if (!blob)
253  return constOp->emitError("could not find resource blob");
254 
255  ArrayRef<char> ptr = blob->getData();
256 
257  // Check that the buffer meets the requirements to get converted to a
258  // DenseElementsAttr
259  bool detectedSplat = false;
260  if (!DenseElementsAttr::isValidRawBuffer(srcType, ptr, detectedSplat))
261  return constOp->emitError("resource is not a valid buffer");
262 
263  dstElementsAttr =
264  DenseElementsAttr::getFromRawBuffer(resourceAttr.getType(), ptr);
265  } else {
266  return constOp->emitError("unsupported elements attribute");
267  }
268 
269  ShapedType dstAttrType = dstElementsAttr.getType();
270 
271  // If the composite type has more than one dimensions, perform
272  // linearization.
273  if (srcType.getRank() > 1) {
274  if (isa<RankedTensorType>(srcType)) {
275  dstAttrType = RankedTensorType::get(srcType.getNumElements(),
276  srcType.getElementType());
277  dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
278  } else {
279  // TODO: add support for large vectors.
280  return failure();
281  }
282  }
283 
284  Type srcElemType = srcType.getElementType();
285  Type dstElemType;
286  // Tensor types are converted to SPIR-V array types; vector types are
287  // converted to SPIR-V vector/array types.
288  if (auto arrayType = dyn_cast<spirv::ArrayType>(dstType))
289  dstElemType = arrayType.getElementType();
290  else
291  dstElemType = cast<VectorType>(dstType).getElementType();
292 
293  // If the source and destination element types are different, perform
294  // attribute conversion.
295  if (srcElemType != dstElemType) {
296  SmallVector<Attribute, 8> elements;
297  if (isa<FloatType>(srcElemType)) {
298  for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
299  FloatAttr dstAttr =
300  convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), rewriter);
301  if (!dstAttr)
302  return failure();
303  elements.push_back(dstAttr);
304  }
305  } else if (srcElemType.isInteger(1)) {
306  return failure();
307  } else {
308  for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) {
309  IntegerAttr dstAttr = convertIntegerAttr(
310  srcAttr, cast<IntegerType>(dstElemType), rewriter);
311  if (!dstAttr)
312  return failure();
313  elements.push_back(dstAttr);
314  }
315  }
316 
317  // Unfortunately, we cannot use dialect-specific types for element
318  // attributes; element attributes only works with builtin types. So we
319  // need to prepare another converted builtin types for the destination
320  // elements attribute.
321  if (isa<RankedTensorType>(dstAttrType))
322  dstAttrType =
323  RankedTensorType::get(dstAttrType.getShape(), dstElemType);
324  else
325  dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
326 
327  dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
328  }
329 
330  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
331  dstElementsAttr);
332  return success();
333  }
334 };
335 
336 /// Converts scalar arith.constant operation to spirv.Constant.
337 struct ConstantScalarOpPattern final
338  : public OpConversionPattern<arith::ConstantOp> {
340 
341  LogicalResult
342  matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
343  ConversionPatternRewriter &rewriter) const override {
344  Type srcType = constOp.getType();
345  if (auto shapedType = dyn_cast<ShapedType>(srcType)) {
346  if (shapedType.getNumElements() != 1)
347  return failure();
348  srcType = shapedType.getElementType();
349  }
350  if (!srcType.isIntOrIndexOrFloat())
351  return failure();
352 
353  Attribute cstAttr = constOp.getValue();
354  if (auto elementsAttr = dyn_cast<DenseElementsAttr>(cstAttr))
355  cstAttr = elementsAttr.getSplatValue<Attribute>();
356 
357  Type dstType = getTypeConverter()->convertType(srcType);
358  if (!dstType)
359  return failure();
360 
361  // Floating-point types.
362  if (isa<FloatType>(srcType)) {
363  auto srcAttr = cast<FloatAttr>(cstAttr);
364  auto dstAttr = srcAttr;
365 
366  // Floating-point types not supported in the target environment are all
367  // converted to float type.
368  if (srcType != dstType) {
369  dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter);
370  if (!dstAttr)
371  return failure();
372  }
373 
374  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
375  return success();
376  }
377 
378  // Bool type.
379  if (srcType.isInteger(1)) {
380  // arith.constant can use 0/1 instead of true/false for i1 values. We need
381  // to handle that here.
382  auto dstAttr = convertBoolAttr(cstAttr, rewriter);
383  if (!dstAttr)
384  return failure();
385  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
386  return success();
387  }
388 
389  // IndexType or IntegerType. Index values are converted to 32-bit integer
390  // values when converting to SPIR-V.
391  auto srcAttr = cast<IntegerAttr>(cstAttr);
392  IntegerAttr dstAttr =
393  convertIntegerAttr(srcAttr, cast<IntegerType>(dstType), rewriter);
394  if (!dstAttr)
395  return failure();
396  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
397  return success();
398  }
399 };
400 
401 //===----------------------------------------------------------------------===//
402 // RemSIOp
403 //===----------------------------------------------------------------------===//
404 
405 /// Returns signed remainder for `lhs` and `rhs` and lets the result follow
406 /// the sign of `signOperand`.
407 ///
408 /// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment
409 /// spec, "for the OpSRem and OpSMod instructions, if either operand is negative
410 /// the result is undefined." So we cannot directly use spirv.SRem/spirv.SMod
411 /// if either operand can be negative. Emulate it via spirv.UMod.
412 template <typename SignedAbsOp>
413 static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
414  Value signOperand, OpBuilder &builder) {
415  assert(lhs.getType() == rhs.getType());
416  assert(lhs == signOperand || rhs == signOperand);
417 
418  Type type = lhs.getType();
419 
420  // Calculate the remainder with spirv.UMod.
421  Value lhsAbs = builder.create<SignedAbsOp>(loc, type, lhs);
422  Value rhsAbs = builder.create<SignedAbsOp>(loc, type, rhs);
423  Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs);
424 
425  // Fix the sign.
426  Value isPositive;
427  if (lhs == signOperand)
428  isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs);
429  else
430  isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs);
431  Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs);
432  return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate);
433 }
434 
435 /// Converts arith.remsi to GLSL SPIR-V ops.
436 ///
437 /// This cannot be merged into the template unary/binary pattern due to Vulkan
438 /// restrictions over spirv.SRem and spirv.SMod.
439 struct RemSIOpGLPattern final : public OpConversionPattern<arith::RemSIOp> {
441 
442  LogicalResult
443  matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
444  ConversionPatternRewriter &rewriter) const override {
445  Value result = emulateSignedRemainder<spirv::CLSAbsOp>(
446  op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
447  adaptor.getOperands()[0], rewriter);
448  rewriter.replaceOp(op, result);
449 
450  return success();
451  }
452 };
453 
454 /// Converts arith.remsi to OpenCL SPIR-V ops.
455 struct RemSIOpCLPattern final : public OpConversionPattern<arith::RemSIOp> {
457 
458  LogicalResult
459  matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
460  ConversionPatternRewriter &rewriter) const override {
461  Value result = emulateSignedRemainder<spirv::GLSAbsOp>(
462  op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
463  adaptor.getOperands()[0], rewriter);
464  rewriter.replaceOp(op, result);
465 
466  return success();
467  }
468 };
469 
470 //===----------------------------------------------------------------------===//
471 // BitwiseOp
472 //===----------------------------------------------------------------------===//
473 
474 /// Converts bitwise operations to SPIR-V operations. This is a special pattern
475 /// other than the BinaryOpPatternPattern because if the operands are boolean
476 /// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For
477 /// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`.
478 template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
479 struct BitwiseOpPattern final : public OpConversionPattern<Op> {
481 
482  LogicalResult
483  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
484  ConversionPatternRewriter &rewriter) const override {
485  assert(adaptor.getOperands().size() == 2);
486  Type dstType = this->getTypeConverter()->convertType(op.getType());
487  if (!dstType)
488  return getTypeConversionFailure(rewriter, op);
489 
490  if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) {
491  rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(
492  op, dstType, adaptor.getOperands());
493  } else {
494  rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(
495  op, dstType, adaptor.getOperands());
496  }
497  return success();
498  }
499 };
500 
501 //===----------------------------------------------------------------------===//
502 // XOrIOp
503 //===----------------------------------------------------------------------===//
504 
505 /// Converts arith.xori to SPIR-V operations.
506 struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> {
508 
509  LogicalResult
510  matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
511  ConversionPatternRewriter &rewriter) const override {
512  assert(adaptor.getOperands().size() == 2);
513 
514  if (isBoolScalarOrVector(adaptor.getOperands().front().getType()))
515  return failure();
516 
517  Type dstType = getTypeConverter()->convertType(op.getType());
518  if (!dstType)
519  return getTypeConversionFailure(rewriter, op);
520 
521  rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
522  adaptor.getOperands());
523 
524  return success();
525  }
526 };
527 
528 /// Converts arith.xori to SPIR-V operations if the type of source is i1 or
529 /// vector of i1.
530 struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
532 
533  LogicalResult
534  matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
535  ConversionPatternRewriter &rewriter) const override {
536  assert(adaptor.getOperands().size() == 2);
537 
538  if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
539  return failure();
540 
541  Type dstType = getTypeConverter()->convertType(op.getType());
542  if (!dstType)
543  return getTypeConversionFailure(rewriter, op);
544 
545  rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
546  op, dstType, adaptor.getOperands());
547  return success();
548  }
549 };
550 
551 //===----------------------------------------------------------------------===//
552 // UIToFPOp
553 //===----------------------------------------------------------------------===//
554 
555 /// Converts arith.uitofp to spirv.Select if the type of source is i1 or vector
556 /// of i1.
557 struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
559 
560  LogicalResult
561  matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
562  ConversionPatternRewriter &rewriter) const override {
563  Type srcType = adaptor.getOperands().front().getType();
564  if (!isBoolScalarOrVector(srcType))
565  return failure();
566 
567  Type dstType = getTypeConverter()->convertType(op.getType());
568  if (!dstType)
569  return getTypeConversionFailure(rewriter, op);
570 
571  Location loc = op.getLoc();
572  Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
573  Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
574  rewriter.replaceOpWithNewOp<spirv::SelectOp>(
575  op, dstType, adaptor.getOperands().front(), one, zero);
576  return success();
577  }
578 };
579 
580 //===----------------------------------------------------------------------===//
581 // ExtSIOp
582 //===----------------------------------------------------------------------===//
583 
584 /// Converts arith.extsi to spirv.Select if the type of source is i1 or vector
585 /// of i1.
586 struct ExtSII1Pattern final : public OpConversionPattern<arith::ExtSIOp> {
588 
589  LogicalResult
590  matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
591  ConversionPatternRewriter &rewriter) const override {
592  Value operand = adaptor.getIn();
593  if (!isBoolScalarOrVector(operand.getType()))
594  return failure();
595 
596  Location loc = op.getLoc();
597  Type dstType = getTypeConverter()->convertType(op.getType());
598  if (!dstType)
599  return getTypeConversionFailure(rewriter, op);
600 
601  Value allOnes;
602  if (auto intTy = dyn_cast<IntegerType>(dstType)) {
603  unsigned componentBitwidth = intTy.getWidth();
604  allOnes = rewriter.create<spirv::ConstantOp>(
605  loc, intTy,
606  rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth)));
607  } else if (auto vectorTy = dyn_cast<VectorType>(dstType)) {
608  unsigned componentBitwidth = vectorTy.getElementTypeBitWidth();
609  allOnes = rewriter.create<spirv::ConstantOp>(
610  loc, vectorTy,
611  SplatElementsAttr::get(vectorTy,
612  APInt::getAllOnes(componentBitwidth)));
613  } else {
614  return rewriter.notifyMatchFailure(
615  loc, llvm::formatv("unhandled type: {0}", dstType));
616  }
617 
618  Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
619  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, operand, allOnes,
620  zero);
621  return success();
622  }
623 };
624 
625 /// Converts arith.extsi to spirv.Select if the type of source is neither i1 nor
626 /// vector of i1.
627 struct ExtSIPattern final : public OpConversionPattern<arith::ExtSIOp> {
629 
630  LogicalResult
631  matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
632  ConversionPatternRewriter &rewriter) const override {
633  Type srcType = adaptor.getIn().getType();
634  if (isBoolScalarOrVector(srcType))
635  return failure();
636 
637  Type dstType = getTypeConverter()->convertType(op.getType());
638  if (!dstType)
639  return getTypeConversionFailure(rewriter, op);
640 
641  if (dstType == srcType) {
642  // We can have the same source and destination type due to type emulation.
643  // Perform bit shifting to make sure we have the proper leading set bits.
644 
645  unsigned srcBW =
646  getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth();
647  unsigned dstBW =
649  assert(srcBW < dstBW);
650  Value shiftSize = getScalarOrVectorConstInt(dstType, dstBW - srcBW,
651  rewriter, op.getLoc());
652 
653  // First shift left to sequeeze out all leading bits beyond the original
654  // bitwidth. Here we need to use the original source and result type's
655  // bitwidth.
656  auto shiftLOp = rewriter.create<spirv::ShiftLeftLogicalOp>(
657  op.getLoc(), dstType, adaptor.getIn(), shiftSize);
658 
659  // Then we perform arithmetic right shift to make sure we have the right
660  // sign bits for negative values.
661  rewriter.replaceOpWithNewOp<spirv::ShiftRightArithmeticOp>(
662  op, dstType, shiftLOp, shiftSize);
663  } else {
664  rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
665  adaptor.getOperands());
666  }
667 
668  return success();
669  }
670 };
671 
672 //===----------------------------------------------------------------------===//
673 // ExtUIOp
674 //===----------------------------------------------------------------------===//
675 
676 /// Converts arith.extui to spirv.Select if the type of source is i1 or vector
677 /// of i1.
678 struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> {
680 
681  LogicalResult
682  matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
683  ConversionPatternRewriter &rewriter) const override {
684  Type srcType = adaptor.getOperands().front().getType();
685  if (!isBoolScalarOrVector(srcType))
686  return failure();
687 
688  Type dstType = getTypeConverter()->convertType(op.getType());
689  if (!dstType)
690  return getTypeConversionFailure(rewriter, op);
691 
692  Location loc = op.getLoc();
693  Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
694  Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
695  rewriter.replaceOpWithNewOp<spirv::SelectOp>(
696  op, dstType, adaptor.getOperands().front(), one, zero);
697  return success();
698  }
699 };
700 
701 /// Converts arith.extui for cases where the type of source is neither i1 nor
702 /// vector of i1.
703 struct ExtUIPattern final : public OpConversionPattern<arith::ExtUIOp> {
705 
706  LogicalResult
707  matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
708  ConversionPatternRewriter &rewriter) const override {
709  Type srcType = adaptor.getIn().getType();
710  if (isBoolScalarOrVector(srcType))
711  return failure();
712 
713  Type dstType = getTypeConverter()->convertType(op.getType());
714  if (!dstType)
715  return getTypeConversionFailure(rewriter, op);
716 
717  if (dstType == srcType) {
718  // We can have the same source and destination type due to type emulation.
719  // Perform bit masking to make sure we don't pollute downstream consumers
720  // with unwanted bits. Here we need to use the original source type's
721  // bitwidth.
722  unsigned bitwidth =
723  getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth();
725  dstType, llvm::maskTrailingOnes<uint64_t>(bitwidth), rewriter,
726  op.getLoc());
727  rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
728  adaptor.getIn(), mask);
729  } else {
730  rewriter.replaceOpWithNewOp<spirv::UConvertOp>(op, dstType,
731  adaptor.getOperands());
732  }
733  return success();
734  }
735 };
736 
737 //===----------------------------------------------------------------------===//
738 // TruncIOp
739 //===----------------------------------------------------------------------===//
740 
741 /// Converts arith.trunci to spirv.Select if the type of result is i1 or vector
742 /// of i1.
743 struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
745 
746  LogicalResult
747  matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
748  ConversionPatternRewriter &rewriter) const override {
749  Type dstType = getTypeConverter()->convertType(op.getType());
750  if (!dstType)
751  return getTypeConversionFailure(rewriter, op);
752 
753  if (!isBoolScalarOrVector(dstType))
754  return failure();
755 
756  Location loc = op.getLoc();
757  auto srcType = adaptor.getOperands().front().getType();
758  // Check if (x & 1) == 1.
759  Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
760  Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>(
761  loc, srcType, adaptor.getOperands()[0], mask);
762  Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask);
763 
764  Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
765  Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
766  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero);
767  return success();
768  }
769 };
770 
771 /// Converts arith.trunci for cases where the type of result is neither i1
772 /// nor vector of i1.
773 struct TruncIPattern final : public OpConversionPattern<arith::TruncIOp> {
775 
776  LogicalResult
777  matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
778  ConversionPatternRewriter &rewriter) const override {
779  Type srcType = adaptor.getIn().getType();
780  Type dstType = getTypeConverter()->convertType(op.getType());
781  if (!dstType)
782  return getTypeConversionFailure(rewriter, op);
783 
784  if (isBoolScalarOrVector(dstType))
785  return failure();
786 
787  if (dstType == srcType) {
788  // We can have the same source and destination type due to type emulation.
789  // Perform bit masking to make sure we don't pollute downstream consumers
790  // with unwanted bits. Here we need to use the original result type's
791  // bitwidth.
792  unsigned bw = getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth();
794  dstType, llvm::maskTrailingOnes<uint64_t>(bw), rewriter, op.getLoc());
795  rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
796  adaptor.getIn(), mask);
797  } else {
798  // Given this is truncation, either SConvertOp or UConvertOp works.
799  rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
800  adaptor.getOperands());
801  }
802  return success();
803  }
804 };
805 
806 //===----------------------------------------------------------------------===//
807 // TypeCastingOp
808 //===----------------------------------------------------------------------===//
809 
810 static std::optional<spirv::FPRoundingMode>
811 convertArithRoundingModeToSPIRV(arith::RoundingMode roundingMode) {
812  switch (roundingMode) {
813  case arith::RoundingMode::downward:
814  return spirv::FPRoundingMode::RTN;
815  case arith::RoundingMode::to_nearest_even:
816  return spirv::FPRoundingMode::RTE;
817  case arith::RoundingMode::toward_zero:
818  return spirv::FPRoundingMode::RTZ;
819  case arith::RoundingMode::upward:
820  return spirv::FPRoundingMode::RTP;
821  case arith::RoundingMode::to_nearest_away:
822  // SPIR-V FPRoundingMode decoration has no ties-away-from-zero mode
823  // (as of SPIR-V 1.6)
824  return std::nullopt;
825  }
826  llvm_unreachable("Unhandled rounding mode");
827 }
828 
829 /// Converts type-casting standard operations to SPIR-V operations.
830 template <typename Op, typename SPIRVOp>
831 struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
833 
834  LogicalResult
835  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
836  ConversionPatternRewriter &rewriter) const override {
837  Type srcType = llvm::getSingleElement(adaptor.getOperands()).getType();
838  Type dstType = this->getTypeConverter()->convertType(op.getType());
839  if (!dstType)
840  return getTypeConversionFailure(rewriter, op);
841 
842  if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
843  return failure();
844 
845  if (dstType == srcType) {
846  // Due to type conversion, we are seeing the same source and target type.
847  // Then we can just erase this operation by forwarding its operand.
848  rewriter.replaceOp(op, adaptor.getOperands().front());
849  } else {
850  auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
851  op, dstType, adaptor.getOperands());
852  if (auto roundingModeOp =
853  dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
854  if (arith::RoundingModeAttr roundingMode =
855  roundingModeOp.getRoundingModeAttr()) {
856  if (auto rm =
857  convertArithRoundingModeToSPIRV(roundingMode.getValue())) {
858  newOp->setAttr(
859  getDecorationString(spirv::Decoration::FPRoundingMode),
860  spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm));
861  } else {
862  return rewriter.notifyMatchFailure(
863  op->getLoc(),
864  llvm::formatv("unsupported rounding mode '{0}'", roundingMode));
865  }
866  }
867  }
868  }
869  return success();
870  }
871 };
872 
873 //===----------------------------------------------------------------------===//
874 // CmpIOp
875 //===----------------------------------------------------------------------===//
876 
877 /// Converts integer compare operation on i1 type operands to SPIR-V ops.
878 class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
879 public:
881 
882  LogicalResult
883  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
884  ConversionPatternRewriter &rewriter) const override {
885  Type srcType = op.getLhs().getType();
886  if (!isBoolScalarOrVector(srcType))
887  return failure();
888  Type dstType = getTypeConverter()->convertType(srcType);
889  if (!dstType)
890  return getTypeConversionFailure(rewriter, op, srcType);
891 
892  switch (op.getPredicate()) {
893  case arith::CmpIPredicate::eq: {
894  rewriter.replaceOpWithNewOp<spirv::LogicalEqualOp>(op, adaptor.getLhs(),
895  adaptor.getRhs());
896  return success();
897  }
898  case arith::CmpIPredicate::ne: {
899  rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
900  op, adaptor.getLhs(), adaptor.getRhs());
901  return success();
902  }
903  case arith::CmpIPredicate::uge:
904  case arith::CmpIPredicate::ugt:
905  case arith::CmpIPredicate::ule:
906  case arith::CmpIPredicate::ult: {
907  // There are no direct corresponding instructions in SPIR-V for such
908  // cases. Extend them to 32-bit and do comparision then.
909  Type type = rewriter.getI32Type();
910  if (auto vectorType = dyn_cast<VectorType>(dstType))
911  type = VectorType::get(vectorType.getShape(), type);
912  Value extLhs =
913  rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
914  Value extRhs =
915  rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getRhs());
916 
917  rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs,
918  extRhs);
919  return success();
920  }
921  default:
922  break;
923  }
924  return failure();
925  }
926 };
927 
928 /// Converts integer compare operation to SPIR-V ops.
929 class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> {
930 public:
932 
933  LogicalResult
934  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
935  ConversionPatternRewriter &rewriter) const override {
936  Type srcType = op.getLhs().getType();
937  if (isBoolScalarOrVector(srcType))
938  return failure();
939  Type dstType = getTypeConverter()->convertType(srcType);
940  if (!dstType)
941  return getTypeConversionFailure(rewriter, op, srcType);
942 
943  switch (op.getPredicate()) {
944 #define DISPATCH(cmpPredicate, spirvOp) \
945  case cmpPredicate: \
946  if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \
947  !getElementTypeOrSelf(srcType).isIndex() && srcType != dstType && \
948  !hasSameBitwidth(srcType, dstType)) { \
949  return op.emitError( \
950  "bitwidth emulation is not implemented yet on unsigned op"); \
951  } \
952  rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
953  adaptor.getRhs()); \
954  return success();
955 
956  DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
957  DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
958  DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
959  DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
960  DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
961  DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
962  DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
963  DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
964  DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
965  DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
966 
967 #undef DISPATCH
968  }
969  return failure();
970  }
971 };
972 
973 //===----------------------------------------------------------------------===//
974 // CmpFOpPattern
975 //===----------------------------------------------------------------------===//
976 
977 /// Converts floating-point comparison operations to SPIR-V ops.
978 class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> {
979 public:
981 
982  LogicalResult
983  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
984  ConversionPatternRewriter &rewriter) const override {
985  switch (op.getPredicate()) {
986 #define DISPATCH(cmpPredicate, spirvOp) \
987  case cmpPredicate: \
988  rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
989  adaptor.getRhs()); \
990  return success();
991 
992  // Ordered.
993  DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
994  DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
995  DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
996  DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
997  DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
998  DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
999  // Unordered.
1000  DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
1001  DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
1002  DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
1003  DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
1004  DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
1005  DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
1006 
1007 #undef DISPATCH
1008 
1009  default:
1010  break;
1011  }
1012  return failure();
1013  }
1014 };
1015 
1016 /// Converts floating point NaN check to SPIR-V ops. This pattern requires
1017 /// Kernel capability.
1018 class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> {
1019 public:
1021 
1022  LogicalResult
1023  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1024  ConversionPatternRewriter &rewriter) const override {
1025  if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1026  rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(),
1027  adaptor.getRhs());
1028  return success();
1029  }
1030 
1031  if (op.getPredicate() == arith::CmpFPredicate::UNO) {
1032  rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(),
1033  adaptor.getRhs());
1034  return success();
1035  }
1036 
1037  return failure();
1038  }
1039 };
1040 
1041 /// Converts floating point NaN check to SPIR-V ops. This pattern does not
1042 /// require additional capability.
1043 class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
1044 public:
1046 
1047  LogicalResult
1048  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1049  ConversionPatternRewriter &rewriter) const override {
1050  if (op.getPredicate() != arith::CmpFPredicate::ORD &&
1051  op.getPredicate() != arith::CmpFPredicate::UNO)
1052  return failure();
1053 
1054  Location loc = op.getLoc();
1055 
1056  Value replace;
1057  if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1058  if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1059  // Ordered comparsion checks if neither operand is NaN.
1060  replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter);
1061  } else {
1062  // Unordered comparsion checks if either operand is NaN.
1063  replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter);
1064  }
1065  } else {
1066  Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
1067  Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
1068 
1069  replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
1070  if (op.getPredicate() == arith::CmpFPredicate::ORD)
1071  replace = rewriter.create<spirv::LogicalNotOp>(loc, replace);
1072  }
1073 
1074  rewriter.replaceOp(op, replace);
1075  return success();
1076  }
1077 };
1078 
1079 //===----------------------------------------------------------------------===//
1080 // AddUIExtendedOp
1081 //===----------------------------------------------------------------------===//
1082 
1083 /// Converts arith.addui_extended to spirv.IAddCarry.
1084 class AddUIExtendedOpPattern final
1085  : public OpConversionPattern<arith::AddUIExtendedOp> {
1086 public:
1088  LogicalResult
1089  matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
1090  ConversionPatternRewriter &rewriter) const override {
1091  Type dstElemTy = adaptor.getLhs().getType();
1092  Location loc = op->getLoc();
1093  Value result = rewriter.create<spirv::IAddCarryOp>(loc, adaptor.getLhs(),
1094  adaptor.getRhs());
1095 
1096  Value sumResult = rewriter.create<spirv::CompositeExtractOp>(
1097  loc, result, llvm::ArrayRef(0));
1098  Value carryValue = rewriter.create<spirv::CompositeExtractOp>(
1099  loc, result, llvm::ArrayRef(1));
1100 
1101  // Convert the carry value to boolean.
1102  Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter);
1103  Value carryResult = rewriter.create<spirv::IEqualOp>(loc, carryValue, one);
1104 
1105  rewriter.replaceOp(op, {sumResult, carryResult});
1106  return success();
1107  }
1108 };
1109 
1110 //===----------------------------------------------------------------------===//
1111 // MulIExtendedOp
1112 //===----------------------------------------------------------------------===//
1113 
1114 /// Converts arith.mul*i_extended to spirv.*MulExtended.
1115 template <typename ArithMulOp, typename SPIRVMulOp>
1116 class MulIExtendedOpPattern final : public OpConversionPattern<ArithMulOp> {
1117 public:
1119  LogicalResult
1120  matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
1121  ConversionPatternRewriter &rewriter) const override {
1122  Location loc = op->getLoc();
1123  Value result =
1124  rewriter.create<SPIRVMulOp>(loc, adaptor.getLhs(), adaptor.getRhs());
1125 
1126  Value low = rewriter.create<spirv::CompositeExtractOp>(loc, result,
1127  llvm::ArrayRef(0));
1128  Value high = rewriter.create<spirv::CompositeExtractOp>(loc, result,
1129  llvm::ArrayRef(1));
1130 
1131  rewriter.replaceOp(op, {low, high});
1132  return success();
1133  }
1134 };
1135 
1136 //===----------------------------------------------------------------------===//
1137 // SelectOp
1138 //===----------------------------------------------------------------------===//
1139 
1140 /// Converts arith.select to spirv.Select.
1141 class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> {
1142 public:
1144  LogicalResult
1145  matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
1146  ConversionPatternRewriter &rewriter) const override {
1147  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, adaptor.getCondition(),
1148  adaptor.getTrueValue(),
1149  adaptor.getFalseValue());
1150  return success();
1151  }
1152 };
1153 
1154 //===----------------------------------------------------------------------===//
1155 // MinimumFOp, MaximumFOp
1156 //===----------------------------------------------------------------------===//
1157 
1158 /// Converts arith.maximumf/minimumf to spirv.GL.FMax/FMin or
1159 /// spirv.CL.fmax/fmin.
1160 template <typename Op, typename SPIRVOp>
1161 class MinimumMaximumFOpPattern final : public OpConversionPattern<Op> {
1162 public:
1164  LogicalResult
1165  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
1166  ConversionPatternRewriter &rewriter) const override {
1167  auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
1168  Type dstType = converter->convertType(op.getType());
1169  if (!dstType)
1170  return getTypeConversionFailure(rewriter, op);
1171 
1172  // arith.maximumf/minimumf:
1173  // "if one of the arguments is NaN, then the result is also NaN."
1174  // spirv.GL.FMax/FMin
1175  // "which operand is the result is undefined if one of the operands
1176  // is a NaN."
1177  // spirv.CL.fmax/fmin:
1178  // "If one argument is a NaN, Fmin returns the other argument."
1179 
1180  Location loc = op.getLoc();
1181  Value spirvOp =
1182  rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
1183 
1184  if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1185  rewriter.replaceOp(op, spirvOp);
1186  return success();
1187  }
1188 
1189  Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
1190  Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
1191 
1192  Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan,
1193  adaptor.getLhs(), spirvOp);
1194  Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan,
1195  adaptor.getRhs(), select1);
1196 
1197  rewriter.replaceOp(op, select2);
1198  return success();
1199  }
1200 };
1201 
1202 //===----------------------------------------------------------------------===//
1203 // MinNumFOp, MaxNumFOp
1204 //===----------------------------------------------------------------------===//
1205 
1206 /// Converts arith.maxnumf/minnumf to spirv.GL.FMax/FMin or
1207 /// spirv.CL.fmax/fmin.
1208 template <typename Op, typename SPIRVOp>
1209 class MinNumMaxNumFOpPattern final : public OpConversionPattern<Op> {
1210  template <typename TargetOp>
1211  constexpr bool shouldInsertNanGuards() const {
1212  return llvm::is_one_of<TargetOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value;
1213  }
1214 
1215 public:
1217  LogicalResult
1218  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
1219  ConversionPatternRewriter &rewriter) const override {
1220  auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
1221  Type dstType = converter->convertType(op.getType());
1222  if (!dstType)
1223  return getTypeConversionFailure(rewriter, op);
1224 
1225  // arith.maxnumf/minnumf:
1226  // "If one of the arguments is NaN, then the result is the other
1227  // argument."
1228  // spirv.GL.FMax/FMin
1229  // "which operand is the result is undefined if one of the operands
1230  // is a NaN."
1231  // spirv.CL.fmax/fmin:
1232  // "If one argument is a NaN, Fmin returns the other argument."
1233 
1234  Location loc = op.getLoc();
1235  Value spirvOp =
1236  rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
1237 
1238  if (!shouldInsertNanGuards<SPIRVOp>() ||
1239  bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1240  rewriter.replaceOp(op, spirvOp);
1241  return success();
1242  }
1243 
1244  Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
1245  Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
1246 
1247  Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan,
1248  adaptor.getRhs(), spirvOp);
1249  Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan,
1250  adaptor.getLhs(), select1);
1251 
1252  rewriter.replaceOp(op, select2);
1253  return success();
1254  }
1255 };
1256 
1257 } // namespace
1258 
1259 //===----------------------------------------------------------------------===//
1260 // Pattern Population
1261 //===----------------------------------------------------------------------===//
1262 
// Registers the Arith-to-SPIR-V conversion patterns with `patterns`. Every
// pattern is constructed from `typeConverter` and the pattern set's context.
// NOTE(review): the embedded listing numbers jump here (1262->1264,
// 1271->1275, 1279->1288, 1310->1315, 1319->1324), so the first line of the
// signature and several entries of the pattern list are not visible in this
// copy -- restore them from the upstream file before building.
// NOTE(review): both GL and CL variants are registered for the same arith
// ops; presumably the conversion target decides which one applies -- confirm.
1264  const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
1265  // clang-format off
1266  patterns.add<
1267  ConstantCompositeOpPattern,
1268  ConstantScalarOpPattern,
1269  ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>,
1270  ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>,
1271  ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>,
1275  RemSIOpGLPattern, RemSIOpCLPattern,
1276  BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
1277  BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
1278  XOrIOpLogicalPattern, XOrIOpBooleanPattern,
1279  ElementwiseArithOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
1288  ExtUIPattern, ExtUII1Pattern,
1289  ExtSIPattern, ExtSII1Pattern,
1290  TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
1291  TruncIPattern, TruncII1Pattern,
1292  TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
1293  TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern,
1294  TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
1295  TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
1296  TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
1297  TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
1298  TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
1299  TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
1300  CmpIOpBooleanPattern, CmpIOpPattern,
1301  CmpFOpNanNonePattern, CmpFOpPattern,
1302  AddUIExtendedOpPattern,
1303  MulIExtendedOpPattern<arith::MulSIExtendedOp, spirv::SMulExtendedOp>,
1304  MulIExtendedOpPattern<arith::MulUIExtendedOp, spirv::UMulExtendedOp>,
1305  SelectOpPattern,
1306 
1307  MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::GLFMaxOp>,
1308  MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::GLFMinOp>,
1309  MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::GLFMaxOp>,
1310  MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::GLFMinOp>,
1315 
1316  MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::CLFMaxOp>,
1317  MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::CLFMinOp>,
1318  MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::CLFMaxOp>,
1319  MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::CLFMinOp>,
1324  >(typeConverter, patterns.getContext());
1325  // clang-format on
1326 
1327  // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel
1328  // capability is available.
1329  patterns.add<CmpFOpNanKernelPattern>(typeConverter, patterns.getContext(),
1330  /*benefit=*/2);
1331 }
1332 
1333 //===----------------------------------------------------------------------===//
1334 // Pass Definition
1335 //===----------------------------------------------------------------------===//
1336 
1337 namespace {
1338 struct ConvertArithToSPIRVPass
1339  : public impl::ConvertArithToSPIRVPassBase<ConvertArithToSPIRVPass> {
1340  using Base::Base;
1341 
1342  void runOnOperation() override {
1343  Operation *op = getOperation();
1345  std::unique_ptr<SPIRVConversionTarget> target =
1346  SPIRVConversionTarget::get(targetAttr);
1347 
1349  options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
1350  SPIRVTypeConverter typeConverter(targetAttr, options);
1351 
1352  // Use UnrealizedConversionCast as the bridge so that we don't need to pull
1353  // in patterns for other dialects.
1354  target->addLegalOp<UnrealizedConversionCastOp>();
1355 
1356  // Fail hard when there are any remaining 'arith' ops.
1357  target->addIllegalDialect<arith::ArithDialect>();
1358 
1361 
1362  if (failed(applyPartialConversion(op, *target, std::move(patterns))))
1363  signalPassFailure();
1364  }
1365 };
1366 } // namespace
static bool hasSameBitwidth(Type a, Type b)
Returns true if scalar/vector types a and b have the same bitwidth.
static Value getScalarOrVectorConstInt(Type type, uint64_t value, OpBuilder &builder, Location loc)
Creates a scalar/vector integer constant.
static LogicalResult getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op, Type srcType)
Returns a source type conversion failure for srcType and operation op.
static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder)
Converts the given srcAttr into a boolean attribute if it holds an integral value.
static bool isBoolScalarOrVector(Type type)
Returns true if the given type is a boolean scalar or vector type.
#define DISPATCH(cmpPredicate, spirvOp)
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static std::string getDecorationString(spirv::Decoration decor)
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
This class represents a processed binary blob of data.
Definition: AsmState.h:91
ArrayRef< char > getData() const
Return the raw underlying data of this blob.
Definition: AsmState.h:145
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
UnitAttr getUnitAttr()
Definition: Builders.cpp:94
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:224
IntegerType getI32Type()
Definition: Builders.cpp:63
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:96
MLIRContext * getContext() const
Definition: Builders.h:56
FloatAttr getF32FloatAttr(float value)
Definition: Builders.cpp:242
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
static bool isValidRawBuffer(ShapedType type, ArrayRef< char > rawBuffer, bool &detectedSplat)
Returns true if the given buffer is a valid raw buffer for the given type.
DenseElementsAttr reshape(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but has been reshaped...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
This class helps build Operations.
Definition: Builders.h:205
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:826
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:128
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
result_type_range getResultTypes()
Definition: Operation.h:428
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:686
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:504
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:54
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:120
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
Type front()
Return first type in the range.
Definition: TypeRange.h:152
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
An attribute that specifies the target version, allowed extensions and capabilities,...
void populateArithToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Fraction abs(const Fraction &f)
Definition: Fraction.h:107
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.
Definition: Pattern.h:23