MLIR  21.0.0git
ArithToSPIRV.cpp
Go to the documentation of this file.
1 //===- ArithToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
11 #include "../SPIRVCommon/Pattern.h"
19 #include "mlir/IR/BuiltinTypes.h"
21 #include "llvm/ADT/APInt.h"
22 #include "llvm/ADT/ArrayRef.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/MathExtras.h"
26 #include <cassert>
27 #include <memory>
28 
29 namespace mlir {
30 #define GEN_PASS_DEF_CONVERTARITHTOSPIRVPASS
31 #include "mlir/Conversion/Passes.h.inc"
32 } // namespace mlir
33 
34 #define DEBUG_TYPE "arith-to-spirv-pattern"
35 
36 using namespace mlir;
37 
38 //===----------------------------------------------------------------------===//
39 // Conversion Helpers
40 //===----------------------------------------------------------------------===//
41 
42 /// Converts the given `srcAttr` into a boolean attribute if it holds an
43 /// integral value. Returns null attribute if conversion fails.
44 static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) {
45  if (auto boolAttr = dyn_cast<BoolAttr>(srcAttr))
46  return boolAttr;
47  if (auto intAttr = dyn_cast<IntegerAttr>(srcAttr))
48  return builder.getBoolAttr(intAttr.getValue().getBoolValue());
49  return {};
50 }
51 
52 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
53 /// Returns null attribute if conversion fails.
54 static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType,
55  Builder builder) {
56  // If the source number uses less active bits than the target bitwidth, then
57  // it should be safe to convert.
58  if (srcAttr.getValue().isIntN(dstType.getWidth()))
59  return builder.getIntegerAttr(dstType, srcAttr.getInt());
60 
61  // XXX: Try again by interpreting the source number as a signed value.
62  // Although integers in the standard dialect are signless, they can represent
63  // a signed number. It's the operation decides how to interpret. This is
64  // dangerous, but it seems there is no good way of handling this if we still
65  // want to change the bitwidth. Emit a message at least.
66  if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
67  auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt());
68  LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '"
69  << dstAttr << "' for type '" << dstType << "'\n");
70  return dstAttr;
71  }
72 
73  LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr
74  << "' illegal: cannot fit into target type '"
75  << dstType << "'\n");
76  return {};
77 }
78 
79 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
80 /// Returns null attribute if `dstType` is not 32-bit or conversion fails.
81 static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
82  Builder builder) {
83  // Only support converting to float for now.
84  if (!dstType.isF32())
85  return FloatAttr();
86 
87  // Try to convert the source floating-point number to single precision.
88  APFloat dstVal = srcAttr.getValue();
89  bool losesInfo = false;
90  APFloat::opStatus status =
91  dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
92  if (status != APFloat::opOK || losesInfo) {
93  LLVM_DEBUG(llvm::dbgs()
94  << srcAttr << " illegal: cannot fit into converted type '"
95  << dstType << "'\n");
96  return FloatAttr();
97  }
98 
99  return builder.getF32FloatAttr(dstVal.convertToFloat());
100 }
101 
102 /// Returns true if the given `type` is a boolean scalar or vector type.
103 static bool isBoolScalarOrVector(Type type) {
104  assert(type && "Not a valid type");
105  if (type.isInteger(1))
106  return true;
107 
108  if (auto vecType = dyn_cast<VectorType>(type))
109  return vecType.getElementType().isInteger(1);
110 
111  return false;
112 }
113 
114 /// Creates a scalar/vector integer constant.
115 static Value getScalarOrVectorConstInt(Type type, uint64_t value,
116  OpBuilder &builder, Location loc) {
117  if (auto vectorType = dyn_cast<VectorType>(type)) {
118  Attribute element = IntegerAttr::get(vectorType.getElementType(), value);
119  auto attr = SplatElementsAttr::get(vectorType, element);
120  return builder.create<spirv::ConstantOp>(loc, vectorType, attr);
121  }
122 
123  if (auto intType = dyn_cast<IntegerType>(type))
124  return builder.create<spirv::ConstantOp>(
125  loc, type, builder.getIntegerAttr(type, value));
126 
127  return nullptr;
128 }
129 
130 /// Returns true if scalar/vector type `a` and `b` have the same number of
131 /// bitwidth.
132 static bool hasSameBitwidth(Type a, Type b) {
133  auto getNumBitwidth = [](Type type) {
134  unsigned bw = 0;
135  if (type.isIntOrFloat())
136  bw = type.getIntOrFloatBitWidth();
137  else if (auto vecType = dyn_cast<VectorType>(type))
138  bw = vecType.getElementTypeBitWidth() * vecType.getNumElements();
139  return bw;
140  };
141  unsigned aBW = getNumBitwidth(a);
142  unsigned bBW = getNumBitwidth(b);
143  return aBW != 0 && bBW != 0 && aBW == bBW;
144 }
145 
146 /// Returns a source type conversion failure for `srcType` and operation `op`.
147 static LogicalResult
149  Type srcType) {
150  return rewriter.notifyMatchFailure(
151  op->getLoc(),
152  llvm::formatv("failed to convert source type '{0}'", srcType));
153 }
154 
155 /// Returns a source type conversion failure for the result type of `op`.
156 static LogicalResult
158  assert(op->getNumResults() == 1);
159  return getTypeConversionFailure(rewriter, op, op->getResultTypes().front());
160 }
161 
162 // TODO: Move to some common place?
163 static std::string getDecorationString(spirv::Decoration decor) {
164  return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decor));
165 }
166 
167 namespace {
168 
169 /// Converts elementwise unary, binary and ternary arith operations to SPIR-V
170 /// operations. Op can potentially support overflow flags.
171 template <typename Op, typename SPIRVOp>
172 struct ElementwiseArithOpPattern final : OpConversionPattern<Op> {
174 
175  LogicalResult
176  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
177  ConversionPatternRewriter &rewriter) const override {
178  assert(adaptor.getOperands().size() <= 3);
179  auto converter = this->template getTypeConverter<SPIRVTypeConverter>();
180  Type dstType = converter->convertType(op.getType());
181  if (!dstType) {
182  return rewriter.notifyMatchFailure(
183  op->getLoc(),
184  llvm::formatv("failed to convert type {0} for SPIR-V", op.getType()));
185  }
186 
187  if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
188  !getElementTypeOrSelf(op.getType()).isIndex() &&
189  dstType != op.getType()) {
190  return op.emitError("bitwidth emulation is not implemented yet on "
191  "unsigned op pattern version");
192  }
193 
194  auto overflowFlags = arith::IntegerOverflowFlags::none;
195  if (auto overflowIface =
196  dyn_cast<arith::ArithIntegerOverflowFlagsInterface>(*op)) {
197  if (converter->getTargetEnv().allows(
198  spirv::Extension::SPV_KHR_no_integer_wrap_decoration))
199  overflowFlags = overflowIface.getOverflowAttr().getValue();
200  }
201 
202  auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
203  op, dstType, adaptor.getOperands());
204 
205  if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nsw))
206  newOp->setAttr(getDecorationString(spirv::Decoration::NoSignedWrap),
207  rewriter.getUnitAttr());
208 
209  if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nuw))
210  newOp->setAttr(getDecorationString(spirv::Decoration::NoUnsignedWrap),
211  rewriter.getUnitAttr());
212 
213  return success();
214  }
215 };
216 
217 //===----------------------------------------------------------------------===//
218 // ConstantOp
219 //===----------------------------------------------------------------------===//
220 
221 /// Converts composite arith.constant operation to spirv.Constant.
222 struct ConstantCompositeOpPattern final
225 
226  LogicalResult
227  matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
228  ConversionPatternRewriter &rewriter) const override {
229  auto srcType = dyn_cast<ShapedType>(constOp.getType());
230  if (!srcType || srcType.getNumElements() == 1)
231  return failure();
232 
233  // arith.constant should only have vector or tensor types. This is a MLIR
234  // wide problem at the moment.
235  if (!isa<VectorType, RankedTensorType>(srcType))
236  return rewriter.notifyMatchFailure(constOp, "unsupported ShapedType");
237 
238  Type dstType = getTypeConverter()->convertType(srcType);
239  if (!dstType)
240  return failure();
241 
242  // Import the resource into the IR to make use of the special handling of
243  // element types later on.
244  mlir::DenseElementsAttr dstElementsAttr;
245  if (auto denseElementsAttr =
246  dyn_cast<DenseElementsAttr>(constOp.getValue())) {
247  dstElementsAttr = denseElementsAttr;
248  } else if (auto resourceAttr =
249  dyn_cast<DenseResourceElementsAttr>(constOp.getValue())) {
250 
251  AsmResourceBlob *blob = resourceAttr.getRawHandle().getBlob();
252  if (!blob)
253  return constOp->emitError("could not find resource blob");
254 
255  ArrayRef<char> ptr = blob->getData();
256 
257  // Check that the buffer meets the requirements to get converted to a
258  // DenseElementsAttr
259  bool detectedSplat = false;
260  if (!DenseElementsAttr::isValidRawBuffer(srcType, ptr, detectedSplat))
261  return constOp->emitError("resource is not a valid buffer");
262 
263  dstElementsAttr =
264  DenseElementsAttr::getFromRawBuffer(resourceAttr.getType(), ptr);
265  } else {
266  return constOp->emitError("unsupported elements attribute");
267  }
268 
269  ShapedType dstAttrType = dstElementsAttr.getType();
270 
271  // If the composite type has more than one dimensions, perform
272  // linearization.
273  if (srcType.getRank() > 1) {
274  if (isa<RankedTensorType>(srcType)) {
275  dstAttrType = RankedTensorType::get(srcType.getNumElements(),
276  srcType.getElementType());
277  dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
278  } else {
279  // TODO: add support for large vectors.
280  return failure();
281  }
282  }
283 
284  Type srcElemType = srcType.getElementType();
285  Type dstElemType;
286  // Tensor types are converted to SPIR-V array types; vector types are
287  // converted to SPIR-V vector/array types.
288  if (auto arrayType = dyn_cast<spirv::ArrayType>(dstType))
289  dstElemType = arrayType.getElementType();
290  else
291  dstElemType = cast<VectorType>(dstType).getElementType();
292 
293  // If the source and destination element types are different, perform
294  // attribute conversion.
295  if (srcElemType != dstElemType) {
296  SmallVector<Attribute, 8> elements;
297  if (isa<FloatType>(srcElemType)) {
298  for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
299  FloatAttr dstAttr =
300  convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), rewriter);
301  if (!dstAttr)
302  return failure();
303  elements.push_back(dstAttr);
304  }
305  } else if (srcElemType.isInteger(1)) {
306  return failure();
307  } else {
308  for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) {
309  IntegerAttr dstAttr = convertIntegerAttr(
310  srcAttr, cast<IntegerType>(dstElemType), rewriter);
311  if (!dstAttr)
312  return failure();
313  elements.push_back(dstAttr);
314  }
315  }
316 
317  // Unfortunately, we cannot use dialect-specific types for element
318  // attributes; element attributes only works with builtin types. So we
319  // need to prepare another converted builtin types for the destination
320  // elements attribute.
321  if (isa<RankedTensorType>(dstAttrType))
322  dstAttrType =
323  RankedTensorType::get(dstAttrType.getShape(), dstElemType);
324  else
325  dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
326 
327  dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
328  }
329 
330  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
331  dstElementsAttr);
332  return success();
333  }
334 };
335 
336 /// Converts scalar arith.constant operation to spirv.Constant.
337 struct ConstantScalarOpPattern final
338  : public OpConversionPattern<arith::ConstantOp> {
340 
341  LogicalResult
342  matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
343  ConversionPatternRewriter &rewriter) const override {
344  Type srcType = constOp.getType();
345  if (auto shapedType = dyn_cast<ShapedType>(srcType)) {
346  if (shapedType.getNumElements() != 1)
347  return failure();
348  srcType = shapedType.getElementType();
349  }
350  if (!srcType.isIntOrIndexOrFloat())
351  return failure();
352 
353  Attribute cstAttr = constOp.getValue();
354  if (auto elementsAttr = dyn_cast<DenseElementsAttr>(cstAttr))
355  cstAttr = elementsAttr.getSplatValue<Attribute>();
356 
357  Type dstType = getTypeConverter()->convertType(srcType);
358  if (!dstType)
359  return failure();
360 
361  // Floating-point types.
362  if (isa<FloatType>(srcType)) {
363  auto srcAttr = cast<FloatAttr>(cstAttr);
364  auto dstAttr = srcAttr;
365 
366  // Floating-point types not supported in the target environment are all
367  // converted to float type.
368  if (srcType != dstType) {
369  dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter);
370  if (!dstAttr)
371  return failure();
372  }
373 
374  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
375  return success();
376  }
377 
378  // Bool type.
379  if (srcType.isInteger(1)) {
380  // arith.constant can use 0/1 instead of true/false for i1 values. We need
381  // to handle that here.
382  auto dstAttr = convertBoolAttr(cstAttr, rewriter);
383  if (!dstAttr)
384  return failure();
385  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
386  return success();
387  }
388 
389  // IndexType or IntegerType. Index values are converted to 32-bit integer
390  // values when converting to SPIR-V.
391  auto srcAttr = cast<IntegerAttr>(cstAttr);
392  IntegerAttr dstAttr =
393  convertIntegerAttr(srcAttr, cast<IntegerType>(dstType), rewriter);
394  if (!dstAttr)
395  return failure();
396  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
397  return success();
398  }
399 };
400 
401 //===----------------------------------------------------------------------===//
402 // RemSIOp
403 //===----------------------------------------------------------------------===//
404 
405 /// Returns signed remainder for `lhs` and `rhs` and lets the result follow
406 /// the sign of `signOperand`.
407 ///
408 /// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment
409 /// spec, "for the OpSRem and OpSMod instructions, if either operand is negative
410 /// the result is undefined." So we cannot directly use spirv.SRem/spirv.SMod
411 /// if either operand can be negative. Emulate it via spirv.UMod.
412 template <typename SignedAbsOp>
413 static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
414  Value signOperand, OpBuilder &builder) {
415  assert(lhs.getType() == rhs.getType());
416  assert(lhs == signOperand || rhs == signOperand);
417 
418  Type type = lhs.getType();
419 
420  // Calculate the remainder with spirv.UMod.
421  Value lhsAbs = builder.create<SignedAbsOp>(loc, type, lhs);
422  Value rhsAbs = builder.create<SignedAbsOp>(loc, type, rhs);
423  Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs);
424 
425  // Fix the sign.
426  Value isPositive;
427  if (lhs == signOperand)
428  isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs);
429  else
430  isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs);
431  Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs);
432  return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate);
433 }
434 
435 /// Converts arith.remsi to GLSL SPIR-V ops.
436 ///
437 /// This cannot be merged into the template unary/binary pattern due to Vulkan
438 /// restrictions over spirv.SRem and spirv.SMod.
439 struct RemSIOpGLPattern final : public OpConversionPattern<arith::RemSIOp> {
441 
442  LogicalResult
443  matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
444  ConversionPatternRewriter &rewriter) const override {
445  Value result = emulateSignedRemainder<spirv::CLSAbsOp>(
446  op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
447  adaptor.getOperands()[0], rewriter);
448  rewriter.replaceOp(op, result);
449 
450  return success();
451  }
452 };
453 
454 /// Converts arith.remsi to OpenCL SPIR-V ops.
455 struct RemSIOpCLPattern final : public OpConversionPattern<arith::RemSIOp> {
457 
458  LogicalResult
459  matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
460  ConversionPatternRewriter &rewriter) const override {
461  Value result = emulateSignedRemainder<spirv::GLSAbsOp>(
462  op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
463  adaptor.getOperands()[0], rewriter);
464  rewriter.replaceOp(op, result);
465 
466  return success();
467  }
468 };
469 
470 //===----------------------------------------------------------------------===//
471 // BitwiseOp
472 //===----------------------------------------------------------------------===//
473 
474 /// Converts bitwise operations to SPIR-V operations. This is a special pattern
475 /// other than the BinaryOpPatternPattern because if the operands are boolean
476 /// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For
477 /// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`.
478 template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
479 struct BitwiseOpPattern final : public OpConversionPattern<Op> {
481 
482  LogicalResult
483  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
484  ConversionPatternRewriter &rewriter) const override {
485  assert(adaptor.getOperands().size() == 2);
486  Type dstType = this->getTypeConverter()->convertType(op.getType());
487  if (!dstType)
488  return getTypeConversionFailure(rewriter, op);
489 
490  if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) {
491  rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(
492  op, dstType, adaptor.getOperands());
493  } else {
494  rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(
495  op, dstType, adaptor.getOperands());
496  }
497  return success();
498  }
499 };
500 
501 //===----------------------------------------------------------------------===//
502 // XOrIOp
503 //===----------------------------------------------------------------------===//
504 
505 /// Converts arith.xori to SPIR-V operations.
506 struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> {
508 
509  LogicalResult
510  matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
511  ConversionPatternRewriter &rewriter) const override {
512  assert(adaptor.getOperands().size() == 2);
513 
514  if (isBoolScalarOrVector(adaptor.getOperands().front().getType()))
515  return failure();
516 
517  Type dstType = getTypeConverter()->convertType(op.getType());
518  if (!dstType)
519  return getTypeConversionFailure(rewriter, op);
520 
521  rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
522  adaptor.getOperands());
523 
524  return success();
525  }
526 };
527 
528 /// Converts arith.xori to SPIR-V operations if the type of source is i1 or
529 /// vector of i1.
530 struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
532 
533  LogicalResult
534  matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
535  ConversionPatternRewriter &rewriter) const override {
536  assert(adaptor.getOperands().size() == 2);
537 
538  if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
539  return failure();
540 
541  Type dstType = getTypeConverter()->convertType(op.getType());
542  if (!dstType)
543  return getTypeConversionFailure(rewriter, op);
544 
545  rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
546  op, dstType, adaptor.getOperands());
547  return success();
548  }
549 };
550 
551 //===----------------------------------------------------------------------===//
552 // UIToFPOp
553 //===----------------------------------------------------------------------===//
554 
555 /// Converts arith.uitofp to spirv.Select if the type of source is i1 or vector
556 /// of i1.
557 struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
559 
560  LogicalResult
561  matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
562  ConversionPatternRewriter &rewriter) const override {
563  Type srcType = adaptor.getOperands().front().getType();
564  if (!isBoolScalarOrVector(srcType))
565  return failure();
566 
567  Type dstType = getTypeConverter()->convertType(op.getType());
568  if (!dstType)
569  return getTypeConversionFailure(rewriter, op);
570 
571  Location loc = op.getLoc();
572  Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
573  Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
574  rewriter.replaceOpWithNewOp<spirv::SelectOp>(
575  op, dstType, adaptor.getOperands().front(), one, zero);
576  return success();
577  }
578 };
579 
580 //===----------------------------------------------------------------------===//
581 // ExtSIOp
582 //===----------------------------------------------------------------------===//
583 
584 /// Converts arith.extsi to spirv.Select if the type of source is i1 or vector
585 /// of i1.
586 struct ExtSII1Pattern final : public OpConversionPattern<arith::ExtSIOp> {
588 
589  LogicalResult
590  matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
591  ConversionPatternRewriter &rewriter) const override {
592  Value operand = adaptor.getIn();
593  if (!isBoolScalarOrVector(operand.getType()))
594  return failure();
595 
596  Location loc = op.getLoc();
597  Type dstType = getTypeConverter()->convertType(op.getType());
598  if (!dstType)
599  return getTypeConversionFailure(rewriter, op);
600 
601  Value allOnes;
602  if (auto intTy = dyn_cast<IntegerType>(dstType)) {
603  unsigned componentBitwidth = intTy.getWidth();
604  allOnes = rewriter.create<spirv::ConstantOp>(
605  loc, intTy,
606  rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth)));
607  } else if (auto vectorTy = dyn_cast<VectorType>(dstType)) {
608  unsigned componentBitwidth = vectorTy.getElementTypeBitWidth();
609  allOnes = rewriter.create<spirv::ConstantOp>(
610  loc, vectorTy,
611  SplatElementsAttr::get(vectorTy,
612  APInt::getAllOnes(componentBitwidth)));
613  } else {
614  return rewriter.notifyMatchFailure(
615  loc, llvm::formatv("unhandled type: {0}", dstType));
616  }
617 
618  Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
619  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, operand, allOnes,
620  zero);
621  return success();
622  }
623 };
624 
625 /// Converts arith.extsi to spirv.Select if the type of source is neither i1 nor
626 /// vector of i1.
627 struct ExtSIPattern final : public OpConversionPattern<arith::ExtSIOp> {
629 
630  LogicalResult
631  matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
632  ConversionPatternRewriter &rewriter) const override {
633  Type srcType = adaptor.getIn().getType();
634  if (isBoolScalarOrVector(srcType))
635  return failure();
636 
637  Type dstType = getTypeConverter()->convertType(op.getType());
638  if (!dstType)
639  return getTypeConversionFailure(rewriter, op);
640 
641  if (dstType == srcType) {
642  // We can have the same source and destination type due to type emulation.
643  // Perform bit shifting to make sure we have the proper leading set bits.
644 
645  unsigned srcBW =
646  getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth();
647  unsigned dstBW =
649  assert(srcBW < dstBW);
650  Value shiftSize = getScalarOrVectorConstInt(dstType, dstBW - srcBW,
651  rewriter, op.getLoc());
652 
653  // First shift left to sequeeze out all leading bits beyond the original
654  // bitwidth. Here we need to use the original source and result type's
655  // bitwidth.
656  auto shiftLOp = rewriter.create<spirv::ShiftLeftLogicalOp>(
657  op.getLoc(), dstType, adaptor.getIn(), shiftSize);
658 
659  // Then we perform arithmetic right shift to make sure we have the right
660  // sign bits for negative values.
661  rewriter.replaceOpWithNewOp<spirv::ShiftRightArithmeticOp>(
662  op, dstType, shiftLOp, shiftSize);
663  } else {
664  rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
665  adaptor.getOperands());
666  }
667 
668  return success();
669  }
670 };
671 
672 //===----------------------------------------------------------------------===//
673 // ExtUIOp
674 //===----------------------------------------------------------------------===//
675 
676 /// Converts arith.extui to spirv.Select if the type of source is i1 or vector
677 /// of i1.
678 struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> {
680 
681  LogicalResult
682  matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
683  ConversionPatternRewriter &rewriter) const override {
684  Type srcType = adaptor.getOperands().front().getType();
685  if (!isBoolScalarOrVector(srcType))
686  return failure();
687 
688  Type dstType = getTypeConverter()->convertType(op.getType());
689  if (!dstType)
690  return getTypeConversionFailure(rewriter, op);
691 
692  Location loc = op.getLoc();
693  Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
694  Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
695  rewriter.replaceOpWithNewOp<spirv::SelectOp>(
696  op, dstType, adaptor.getOperands().front(), one, zero);
697  return success();
698  }
699 };
700 
701 /// Converts arith.extui for cases where the type of source is neither i1 nor
702 /// vector of i1.
703 struct ExtUIPattern final : public OpConversionPattern<arith::ExtUIOp> {
705 
706  LogicalResult
707  matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
708  ConversionPatternRewriter &rewriter) const override {
709  Type srcType = adaptor.getIn().getType();
710  if (isBoolScalarOrVector(srcType))
711  return failure();
712 
713  Type dstType = getTypeConverter()->convertType(op.getType());
714  if (!dstType)
715  return getTypeConversionFailure(rewriter, op);
716 
717  if (dstType == srcType) {
718  // We can have the same source and destination type due to type emulation.
719  // Perform bit masking to make sure we don't pollute downstream consumers
720  // with unwanted bits. Here we need to use the original source type's
721  // bitwidth.
722  unsigned bitwidth =
723  getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth();
725  dstType, llvm::maskTrailingOnes<uint64_t>(bitwidth), rewriter,
726  op.getLoc());
727  rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
728  adaptor.getIn(), mask);
729  } else {
730  rewriter.replaceOpWithNewOp<spirv::UConvertOp>(op, dstType,
731  adaptor.getOperands());
732  }
733  return success();
734  }
735 };
736 
737 //===----------------------------------------------------------------------===//
738 // TruncIOp
739 //===----------------------------------------------------------------------===//
740 
741 /// Converts arith.trunci to spirv.Select if the type of result is i1 or vector
742 /// of i1.
743 struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
745 
746  LogicalResult
747  matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
748  ConversionPatternRewriter &rewriter) const override {
749  Type dstType = getTypeConverter()->convertType(op.getType());
750  if (!dstType)
751  return getTypeConversionFailure(rewriter, op);
752 
753  if (!isBoolScalarOrVector(dstType))
754  return failure();
755 
756  Location loc = op.getLoc();
757  auto srcType = adaptor.getOperands().front().getType();
758  // Check if (x & 1) == 1.
759  Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
760  Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>(
761  loc, srcType, adaptor.getOperands()[0], mask);
762  Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask);
763 
764  Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
765  Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
766  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero);
767  return success();
768  }
769 };
770 
771 /// Converts arith.trunci for cases where the type of result is neither i1
772 /// nor vector of i1.
773 struct TruncIPattern final : public OpConversionPattern<arith::TruncIOp> {
775 
776  LogicalResult
777  matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
778  ConversionPatternRewriter &rewriter) const override {
779  Type srcType = adaptor.getIn().getType();
780  Type dstType = getTypeConverter()->convertType(op.getType());
781  if (!dstType)
782  return getTypeConversionFailure(rewriter, op);
783 
784  if (isBoolScalarOrVector(dstType))
785  return failure();
786 
787  if (dstType == srcType) {
788  // We can have the same source and destination type due to type emulation.
789  // Perform bit masking to make sure we don't pollute downstream consumers
790  // with unwanted bits. Here we need to use the original result type's
791  // bitwidth.
792  unsigned bw = getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth();
794  dstType, llvm::maskTrailingOnes<uint64_t>(bw), rewriter, op.getLoc());
795  rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
796  adaptor.getIn(), mask);
797  } else {
798  // Given this is truncation, either SConvertOp or UConvertOp works.
799  rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
800  adaptor.getOperands());
801  }
802  return success();
803  }
804 };
805 
806 //===----------------------------------------------------------------------===//
807 // TypeCastingOp
808 //===----------------------------------------------------------------------===//
809 
810 static std::optional<spirv::FPRoundingMode>
811 convertArithRoundingModeToSPIRV(arith::RoundingMode roundingMode) {
812  switch (roundingMode) {
813  case arith::RoundingMode::downward:
814  return spirv::FPRoundingMode::RTN;
815  case arith::RoundingMode::to_nearest_even:
816  return spirv::FPRoundingMode::RTE;
817  case arith::RoundingMode::toward_zero:
818  return spirv::FPRoundingMode::RTZ;
819  case arith::RoundingMode::upward:
820  return spirv::FPRoundingMode::RTP;
821  case arith::RoundingMode::to_nearest_away:
822  // SPIR-V FPRoundingMode decoration has no ties-away-from-zero mode
823  // (as of SPIR-V 1.6)
824  return std::nullopt;
825  }
826  llvm_unreachable("Unhandled rounding mode");
827 }
828 
829 /// Converts type-casting standard operations to SPIR-V operations.
830 template <typename Op, typename SPIRVOp>
831 struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
833 
834  LogicalResult
835  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
836  ConversionPatternRewriter &rewriter) const override {
837  Type srcType = llvm::getSingleElement(adaptor.getOperands()).getType();
838  Type dstType = this->getTypeConverter()->convertType(op.getType());
839  if (!dstType)
840  return getTypeConversionFailure(rewriter, op);
841 
842  if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
843  return failure();
844 
845  if (dstType == srcType) {
846  // Due to type conversion, we are seeing the same source and target type.
847  // Then we can just erase this operation by forwarding its operand.
848  rewriter.replaceOp(op, adaptor.getOperands().front());
849  } else {
850  // Compute new rounding mode (if any).
851  std::optional<spirv::FPRoundingMode> rm = std::nullopt;
852  if (auto roundingModeOp =
853  dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
854  if (arith::RoundingModeAttr roundingMode =
855  roundingModeOp.getRoundingModeAttr()) {
856  if (!(rm =
857  convertArithRoundingModeToSPIRV(roundingMode.getValue()))) {
858  return rewriter.notifyMatchFailure(
859  op->getLoc(),
860  llvm::formatv("unsupported rounding mode '{0}'", roundingMode));
861  }
862  }
863  }
864  // Create replacement op and attach rounding mode attribute (if any).
865  auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
866  op, dstType, adaptor.getOperands());
867  if (rm) {
868  newOp->setAttr(
869  getDecorationString(spirv::Decoration::FPRoundingMode),
870  spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm));
871  }
872  }
873  return success();
874  }
875 };
876 
877 //===----------------------------------------------------------------------===//
878 // CmpIOp
879 //===----------------------------------------------------------------------===//
880 
881 /// Converts integer compare operation on i1 type operands to SPIR-V ops.
882 class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
883 public:
885 
886  LogicalResult
887  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
888  ConversionPatternRewriter &rewriter) const override {
889  Type srcType = op.getLhs().getType();
890  if (!isBoolScalarOrVector(srcType))
891  return failure();
892  Type dstType = getTypeConverter()->convertType(srcType);
893  if (!dstType)
894  return getTypeConversionFailure(rewriter, op, srcType);
895 
896  switch (op.getPredicate()) {
897  case arith::CmpIPredicate::eq: {
898  rewriter.replaceOpWithNewOp<spirv::LogicalEqualOp>(op, adaptor.getLhs(),
899  adaptor.getRhs());
900  return success();
901  }
902  case arith::CmpIPredicate::ne: {
903  rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
904  op, adaptor.getLhs(), adaptor.getRhs());
905  return success();
906  }
907  case arith::CmpIPredicate::uge:
908  case arith::CmpIPredicate::ugt:
909  case arith::CmpIPredicate::ule:
910  case arith::CmpIPredicate::ult: {
911  // There are no direct corresponding instructions in SPIR-V for such
912  // cases. Extend them to 32-bit and do comparision then.
913  Type type = rewriter.getI32Type();
914  if (auto vectorType = dyn_cast<VectorType>(dstType))
915  type = VectorType::get(vectorType.getShape(), type);
916  Value extLhs =
917  rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
918  Value extRhs =
919  rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getRhs());
920 
921  rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs,
922  extRhs);
923  return success();
924  }
925  default:
926  break;
927  }
928  return failure();
929  }
930 };
931 
932 /// Converts integer compare operation to SPIR-V ops.
933 class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> {
934 public:
936 
937  LogicalResult
938  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
939  ConversionPatternRewriter &rewriter) const override {
940  Type srcType = op.getLhs().getType();
941  if (isBoolScalarOrVector(srcType))
942  return failure();
943  Type dstType = getTypeConverter()->convertType(srcType);
944  if (!dstType)
945  return getTypeConversionFailure(rewriter, op, srcType);
946 
947  switch (op.getPredicate()) {
948 #define DISPATCH(cmpPredicate, spirvOp) \
949  case cmpPredicate: \
950  if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \
951  !getElementTypeOrSelf(srcType).isIndex() && srcType != dstType && \
952  !hasSameBitwidth(srcType, dstType)) { \
953  return op.emitError( \
954  "bitwidth emulation is not implemented yet on unsigned op"); \
955  } \
956  rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
957  adaptor.getRhs()); \
958  return success();
959 
960  DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
961  DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
962  DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
963  DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
964  DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
965  DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
966  DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
967  DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
968  DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
969  DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
970 
971 #undef DISPATCH
972  }
973  return failure();
974  }
975 };
976 
977 //===----------------------------------------------------------------------===//
978 // CmpFOpPattern
979 //===----------------------------------------------------------------------===//
980 
981 /// Converts floating-point comparison operations to SPIR-V ops.
982 class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> {
983 public:
985 
986  LogicalResult
987  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
988  ConversionPatternRewriter &rewriter) const override {
989  switch (op.getPredicate()) {
990 #define DISPATCH(cmpPredicate, spirvOp) \
991  case cmpPredicate: \
992  rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
993  adaptor.getRhs()); \
994  return success();
995 
996  // Ordered.
997  DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
998  DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
999  DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
1000  DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
1001  DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
1002  DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
1003  // Unordered.
1004  DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
1005  DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
1006  DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
1007  DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
1008  DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
1009  DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
1010 
1011 #undef DISPATCH
1012 
1013  default:
1014  break;
1015  }
1016  return failure();
1017  }
1018 };
1019 
1020 /// Converts floating point NaN check to SPIR-V ops. This pattern requires
1021 /// Kernel capability.
1022 class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> {
1023 public:
1025 
1026  LogicalResult
1027  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1028  ConversionPatternRewriter &rewriter) const override {
1029  if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1030  rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(),
1031  adaptor.getRhs());
1032  return success();
1033  }
1034 
1035  if (op.getPredicate() == arith::CmpFPredicate::UNO) {
1036  rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(),
1037  adaptor.getRhs());
1038  return success();
1039  }
1040 
1041  return failure();
1042  }
1043 };
1044 
1045 /// Converts floating point NaN check to SPIR-V ops. This pattern does not
1046 /// require additional capability.
1047 class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
1048 public:
1050 
1051  LogicalResult
1052  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1053  ConversionPatternRewriter &rewriter) const override {
1054  if (op.getPredicate() != arith::CmpFPredicate::ORD &&
1055  op.getPredicate() != arith::CmpFPredicate::UNO)
1056  return failure();
1057 
1058  Location loc = op.getLoc();
1059 
1060  Value replace;
1061  if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1062  if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1063  // Ordered comparsion checks if neither operand is NaN.
1064  replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter);
1065  } else {
1066  // Unordered comparsion checks if either operand is NaN.
1067  replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter);
1068  }
1069  } else {
1070  Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
1071  Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
1072 
1073  replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
1074  if (op.getPredicate() == arith::CmpFPredicate::ORD)
1075  replace = rewriter.create<spirv::LogicalNotOp>(loc, replace);
1076  }
1077 
1078  rewriter.replaceOp(op, replace);
1079  return success();
1080  }
1081 };
1082 
1083 //===----------------------------------------------------------------------===//
1084 // AddUIExtendedOp
1085 //===----------------------------------------------------------------------===//
1086 
1087 /// Converts arith.addui_extended to spirv.IAddCarry.
1088 class AddUIExtendedOpPattern final
1089  : public OpConversionPattern<arith::AddUIExtendedOp> {
1090 public:
1092  LogicalResult
1093  matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
1094  ConversionPatternRewriter &rewriter) const override {
1095  Type dstElemTy = adaptor.getLhs().getType();
1096  Location loc = op->getLoc();
1097  Value result = rewriter.create<spirv::IAddCarryOp>(loc, adaptor.getLhs(),
1098  adaptor.getRhs());
1099 
1100  Value sumResult = rewriter.create<spirv::CompositeExtractOp>(
1101  loc, result, llvm::ArrayRef(0));
1102  Value carryValue = rewriter.create<spirv::CompositeExtractOp>(
1103  loc, result, llvm::ArrayRef(1));
1104 
1105  // Convert the carry value to boolean.
1106  Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter);
1107  Value carryResult = rewriter.create<spirv::IEqualOp>(loc, carryValue, one);
1108 
1109  rewriter.replaceOp(op, {sumResult, carryResult});
1110  return success();
1111  }
1112 };
1113 
1114 //===----------------------------------------------------------------------===//
1115 // MulIExtendedOp
1116 //===----------------------------------------------------------------------===//
1117 
1118 /// Converts arith.mul*i_extended to spirv.*MulExtended.
1119 template <typename ArithMulOp, typename SPIRVMulOp>
1120 class MulIExtendedOpPattern final : public OpConversionPattern<ArithMulOp> {
1121 public:
1123  LogicalResult
1124  matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
1125  ConversionPatternRewriter &rewriter) const override {
1126  Location loc = op->getLoc();
1127  Value result =
1128  rewriter.create<SPIRVMulOp>(loc, adaptor.getLhs(), adaptor.getRhs());
1129 
1130  Value low = rewriter.create<spirv::CompositeExtractOp>(loc, result,
1131  llvm::ArrayRef(0));
1132  Value high = rewriter.create<spirv::CompositeExtractOp>(loc, result,
1133  llvm::ArrayRef(1));
1134 
1135  rewriter.replaceOp(op, {low, high});
1136  return success();
1137  }
1138 };
1139 
1140 //===----------------------------------------------------------------------===//
1141 // SelectOp
1142 //===----------------------------------------------------------------------===//
1143 
1144 /// Converts arith.select to spirv.Select.
1145 class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> {
1146 public:
1148  LogicalResult
1149  matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
1150  ConversionPatternRewriter &rewriter) const override {
1151  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, adaptor.getCondition(),
1152  adaptor.getTrueValue(),
1153  adaptor.getFalseValue());
1154  return success();
1155  }
1156 };
1157 
1158 //===----------------------------------------------------------------------===//
1159 // MinimumFOp, MaximumFOp
1160 //===----------------------------------------------------------------------===//
1161 
1162 /// Converts arith.maximumf/minimumf to spirv.GL.FMax/FMin or
1163 /// spirv.CL.fmax/fmin.
1164 template <typename Op, typename SPIRVOp>
1165 class MinimumMaximumFOpPattern final : public OpConversionPattern<Op> {
1166 public:
1168  LogicalResult
1169  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
1170  ConversionPatternRewriter &rewriter) const override {
1171  auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
1172  Type dstType = converter->convertType(op.getType());
1173  if (!dstType)
1174  return getTypeConversionFailure(rewriter, op);
1175 
1176  // arith.maximumf/minimumf:
1177  // "if one of the arguments is NaN, then the result is also NaN."
1178  // spirv.GL.FMax/FMin
1179  // "which operand is the result is undefined if one of the operands
1180  // is a NaN."
1181  // spirv.CL.fmax/fmin:
1182  // "If one argument is a NaN, Fmin returns the other argument."
1183 
1184  Location loc = op.getLoc();
1185  Value spirvOp =
1186  rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
1187 
1188  if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1189  rewriter.replaceOp(op, spirvOp);
1190  return success();
1191  }
1192 
1193  Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
1194  Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
1195 
1196  Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan,
1197  adaptor.getLhs(), spirvOp);
1198  Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan,
1199  adaptor.getRhs(), select1);
1200 
1201  rewriter.replaceOp(op, select2);
1202  return success();
1203  }
1204 };
1205 
1206 //===----------------------------------------------------------------------===//
1207 // MinNumFOp, MaxNumFOp
1208 //===----------------------------------------------------------------------===//
1209 
1210 /// Converts arith.maxnumf/minnumf to spirv.GL.FMax/FMin or
1211 /// spirv.CL.fmax/fmin.
1212 template <typename Op, typename SPIRVOp>
1213 class MinNumMaxNumFOpPattern final : public OpConversionPattern<Op> {
1214  template <typename TargetOp>
1215  constexpr bool shouldInsertNanGuards() const {
1216  return llvm::is_one_of<TargetOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value;
1217  }
1218 
1219 public:
1221  LogicalResult
1222  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
1223  ConversionPatternRewriter &rewriter) const override {
1224  auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
1225  Type dstType = converter->convertType(op.getType());
1226  if (!dstType)
1227  return getTypeConversionFailure(rewriter, op);
1228 
1229  // arith.maxnumf/minnumf:
1230  // "If one of the arguments is NaN, then the result is the other
1231  // argument."
1232  // spirv.GL.FMax/FMin
1233  // "which operand is the result is undefined if one of the operands
1234  // is a NaN."
1235  // spirv.CL.fmax/fmin:
1236  // "If one argument is a NaN, Fmin returns the other argument."
1237 
1238  Location loc = op.getLoc();
1239  Value spirvOp =
1240  rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
1241 
1242  if (!shouldInsertNanGuards<SPIRVOp>() ||
1243  bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1244  rewriter.replaceOp(op, spirvOp);
1245  return success();
1246  }
1247 
1248  Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
1249  Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
1250 
1251  Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan,
1252  adaptor.getRhs(), spirvOp);
1253  Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan,
1254  adaptor.getLhs(), select1);
1255 
1256  rewriter.replaceOp(op, select2);
1257  return success();
1258  }
1259 };
1260 
1261 } // namespace
1262 
1263 //===----------------------------------------------------------------------===//
1264 // Pattern Population
1265 //===----------------------------------------------------------------------===//
1266 
1268  const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
1269  // clang-format off
1270  patterns.add<
1271  ConstantCompositeOpPattern,
1272  ConstantScalarOpPattern,
1273  ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>,
1274  ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>,
1275  ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>,
1279  RemSIOpGLPattern, RemSIOpCLPattern,
1280  BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
1281  BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
1282  XOrIOpLogicalPattern, XOrIOpBooleanPattern,
1283  ElementwiseArithOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
1292  ExtUIPattern, ExtUII1Pattern,
1293  ExtSIPattern, ExtSII1Pattern,
1294  TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
1295  TruncIPattern, TruncII1Pattern,
1296  TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
1297  TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern,
1298  TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
1299  TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
1300  TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
1301  TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
1302  TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
1303  TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
1304  CmpIOpBooleanPattern, CmpIOpPattern,
1305  CmpFOpNanNonePattern, CmpFOpPattern,
1306  AddUIExtendedOpPattern,
1307  MulIExtendedOpPattern<arith::MulSIExtendedOp, spirv::SMulExtendedOp>,
1308  MulIExtendedOpPattern<arith::MulUIExtendedOp, spirv::UMulExtendedOp>,
1309  SelectOpPattern,
1310 
1311  MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::GLFMaxOp>,
1312  MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::GLFMinOp>,
1313  MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::GLFMaxOp>,
1314  MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::GLFMinOp>,
1319 
1320  MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::CLFMaxOp>,
1321  MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::CLFMinOp>,
1322  MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::CLFMaxOp>,
1323  MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::CLFMinOp>,
1328  >(typeConverter, patterns.getContext());
1329  // clang-format on
1330 
1331  // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel
1332  // capability is available.
1333  patterns.add<CmpFOpNanKernelPattern>(typeConverter, patterns.getContext(),
1334  /*benefit=*/2);
1335 }
1336 
1337 //===----------------------------------------------------------------------===//
1338 // Pass Definition
1339 //===----------------------------------------------------------------------===//
1340 
1341 namespace {
1342 struct ConvertArithToSPIRVPass
1343  : public impl::ConvertArithToSPIRVPassBase<ConvertArithToSPIRVPass> {
1344  using Base::Base;
1345 
1346  void runOnOperation() override {
1347  Operation *op = getOperation();
1349  std::unique_ptr<SPIRVConversionTarget> target =
1350  SPIRVConversionTarget::get(targetAttr);
1351 
1353  options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
1354  SPIRVTypeConverter typeConverter(targetAttr, options);
1355 
1356  // Use UnrealizedConversionCast as the bridge so that we don't need to pull
1357  // in patterns for other dialects.
1358  target->addLegalOp<UnrealizedConversionCastOp>();
1359 
1360  // Fail hard when there are any remaining 'arith' ops.
1361  target->addIllegalDialect<arith::ArithDialect>();
1362 
1365 
1366  if (failed(applyPartialConversion(op, *target, std::move(patterns))))
1367  signalPassFailure();
1368  }
1369 };
1370 } // namespace
static bool hasSameBitwidth(Type a, Type b)
Returns true if scalar/vector type a and b have the same number of bitwidth.
static Value getScalarOrVectorConstInt(Type type, uint64_t value, OpBuilder &builder, Location loc)
Creates a scalar/vector integer constant.
static LogicalResult getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op, Type srcType)
Returns a source type conversion failure for srcType and operation op.
static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder)
Converts the given srcAttr into a boolean attribute if it holds an integral value.
static bool isBoolScalarOrVector(Type type)
Returns true if the given type is a boolean scalar or vector type.
#define DISPATCH(cmpPredicate, spirvOp)
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static std::string getDecorationString(spirv::Decoration decor)
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
This class represents a processed binary blob of data.
Definition: AsmState.h:91
ArrayRef< char > getData() const
Return the raw underlying data of this blob.
Definition: AsmState.h:145
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
UnitAttr getUnitAttr()
Definition: Builders.cpp:94
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:224
IntegerType getI32Type()
Definition: Builders.cpp:63
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:96
MLIRContext * getContext() const
Definition: Builders.h:55
FloatAttr getF32FloatAttr(float value)
Definition: Builders.cpp:242
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
static bool isValidRawBuffer(ShapedType type, ArrayRef< char > rawBuffer, bool &detectedSplat)
Returns true if the given buffer is a valid raw buffer for the given type.
DenseElementsAttr reshape(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but has been reshaped...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
This class helps build Operations.
Definition: Builders.h:204
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:828
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:128
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
result_type_range getResultTypes()
Definition: Operation.h:428
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:682
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:54
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:120
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
Type front()
Return first type in the range.
Definition: TypeRange.h:152
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
An attribute that specifies the target version, allowed extensions and capabilities,...
void populateArithToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Fraction abs(const Fraction &f)
Definition: Fraction.h:107
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.
Definition: Pattern.h:23