MLIR  20.0.0git
ArithToSPIRV.cpp
Go to the documentation of this file.
1 //===- ArithToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
11 #include "../SPIRVCommon/Pattern.h"
19 #include "mlir/IR/BuiltinTypes.h"
21 #include "llvm/ADT/APInt.h"
22 #include "llvm/ADT/ArrayRef.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/MathExtras.h"
26 #include <cassert>
27 #include <memory>
28 
29 namespace mlir {
30 #define GEN_PASS_DEF_CONVERTARITHTOSPIRV
31 #include "mlir/Conversion/Passes.h.inc"
32 } // namespace mlir
33 
34 #define DEBUG_TYPE "arith-to-spirv-pattern"
35 
36 using namespace mlir;
37 
38 //===----------------------------------------------------------------------===//
39 // Conversion Helpers
40 //===----------------------------------------------------------------------===//
41 
42 /// Converts the given `srcAttr` into a boolean attribute if it holds an
43 /// integral value. Returns null attribute if conversion fails.
44 static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) {
45  if (auto boolAttr = dyn_cast<BoolAttr>(srcAttr))
46  return boolAttr;
47  if (auto intAttr = dyn_cast<IntegerAttr>(srcAttr))
48  return builder.getBoolAttr(intAttr.getValue().getBoolValue());
49  return {};
50 }
51 
52 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
53 /// Returns null attribute if conversion fails.
54 static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType,
55  Builder builder) {
56  // If the source number uses less active bits than the target bitwidth, then
57  // it should be safe to convert.
58  if (srcAttr.getValue().isIntN(dstType.getWidth()))
59  return builder.getIntegerAttr(dstType, srcAttr.getInt());
60 
61  // XXX: Try again by interpreting the source number as a signed value.
62  // Although integers in the standard dialect are signless, they can represent
63  // a signed number. It's the operation decides how to interpret. This is
64  // dangerous, but it seems there is no good way of handling this if we still
65  // want to change the bitwidth. Emit a message at least.
66  if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
67  auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt());
68  LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '"
69  << dstAttr << "' for type '" << dstType << "'\n");
70  return dstAttr;
71  }
72 
73  LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr
74  << "' illegal: cannot fit into target type '"
75  << dstType << "'\n");
76  return {};
77 }
78 
79 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
80 /// Returns null attribute if `dstType` is not 32-bit or conversion fails.
81 static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
82  Builder builder) {
83  // Only support converting to float for now.
84  if (!dstType.isF32())
85  return FloatAttr();
86 
87  // Try to convert the source floating-point number to single precision.
88  APFloat dstVal = srcAttr.getValue();
89  bool losesInfo = false;
90  APFloat::opStatus status =
91  dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
92  if (status != APFloat::opOK || losesInfo) {
93  LLVM_DEBUG(llvm::dbgs()
94  << srcAttr << " illegal: cannot fit into converted type '"
95  << dstType << "'\n");
96  return FloatAttr();
97  }
98 
99  return builder.getF32FloatAttr(dstVal.convertToFloat());
100 }
101 
102 /// Returns true if the given `type` is a boolean scalar or vector type.
103 static bool isBoolScalarOrVector(Type type) {
104  assert(type && "Not a valid type");
105  if (type.isInteger(1))
106  return true;
107 
108  if (auto vecType = dyn_cast<VectorType>(type))
109  return vecType.getElementType().isInteger(1);
110 
111  return false;
112 }
113 
114 /// Creates a scalar/vector integer constant.
115 static Value getScalarOrVectorConstInt(Type type, uint64_t value,
116  OpBuilder &builder, Location loc) {
117  if (auto vectorType = dyn_cast<VectorType>(type)) {
118  Attribute element = IntegerAttr::get(vectorType.getElementType(), value);
119  auto attr = SplatElementsAttr::get(vectorType, element);
120  return builder.create<spirv::ConstantOp>(loc, vectorType, attr);
121  }
122 
123  if (auto intType = dyn_cast<IntegerType>(type))
124  return builder.create<spirv::ConstantOp>(
125  loc, type, builder.getIntegerAttr(type, value));
126 
127  return nullptr;
128 }
129 
130 /// Returns true if scalar/vector type `a` and `b` have the same number of
131 /// bitwidth.
132 static bool hasSameBitwidth(Type a, Type b) {
133  auto getNumBitwidth = [](Type type) {
134  unsigned bw = 0;
135  if (type.isIntOrFloat())
136  bw = type.getIntOrFloatBitWidth();
137  else if (auto vecType = dyn_cast<VectorType>(type))
138  bw = vecType.getElementTypeBitWidth() * vecType.getNumElements();
139  return bw;
140  };
141  unsigned aBW = getNumBitwidth(a);
142  unsigned bBW = getNumBitwidth(b);
143  return aBW != 0 && bBW != 0 && aBW == bBW;
144 }
145 
146 /// Returns a source type conversion failure for `srcType` and operation `op`.
147 static LogicalResult
149  Type srcType) {
150  return rewriter.notifyMatchFailure(
151  op->getLoc(),
152  llvm::formatv("failed to convert source type '{0}'", srcType));
153 }
154 
155 /// Returns a source type conversion failure for the result type of `op`.
156 static LogicalResult
158  assert(op->getNumResults() == 1);
159  return getTypeConversionFailure(rewriter, op, op->getResultTypes().front());
160 }
161 
162 // TODO: Move to some common place?
163 static std::string getDecorationString(spirv::Decoration decor) {
164  return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decor));
165 }
166 
167 namespace {
168 
169 /// Converts elementwise unary, binary and ternary arith operations to SPIR-V
170 /// operations. Op can potentially support overflow flags.
171 template <typename Op, typename SPIRVOp>
172 struct ElementwiseArithOpPattern final : OpConversionPattern<Op> {
174 
175  LogicalResult
176  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
177  ConversionPatternRewriter &rewriter) const override {
178  assert(adaptor.getOperands().size() <= 3);
179  auto converter = this->template getTypeConverter<SPIRVTypeConverter>();
180  Type dstType = converter->convertType(op.getType());
181  if (!dstType) {
182  return rewriter.notifyMatchFailure(
183  op->getLoc(),
184  llvm::formatv("failed to convert type {0} for SPIR-V", op.getType()));
185  }
186 
187  if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
188  !getElementTypeOrSelf(op.getType()).isIndex() &&
189  dstType != op.getType()) {
190  return op.emitError("bitwidth emulation is not implemented yet on "
191  "unsigned op pattern version");
192  }
193 
194  auto overflowFlags = arith::IntegerOverflowFlags::none;
195  if (auto overflowIface =
196  dyn_cast<arith::ArithIntegerOverflowFlagsInterface>(*op)) {
197  if (converter->getTargetEnv().allows(
198  spirv::Extension::SPV_KHR_no_integer_wrap_decoration))
199  overflowFlags = overflowIface.getOverflowAttr().getValue();
200  }
201 
202  auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
203  op, dstType, adaptor.getOperands());
204 
205  if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nsw))
206  newOp->setAttr(getDecorationString(spirv::Decoration::NoSignedWrap),
207  rewriter.getUnitAttr());
208 
209  if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nuw))
210  newOp->setAttr(getDecorationString(spirv::Decoration::NoUnsignedWrap),
211  rewriter.getUnitAttr());
212 
213  return success();
214  }
215 };
216 
217 //===----------------------------------------------------------------------===//
218 // ConstantOp
219 //===----------------------------------------------------------------------===//
220 
221 /// Converts composite arith.constant operation to spirv.Constant.
222 struct ConstantCompositeOpPattern final
225 
226  LogicalResult
227  matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
228  ConversionPatternRewriter &rewriter) const override {
229  auto srcType = dyn_cast<ShapedType>(constOp.getType());
230  if (!srcType || srcType.getNumElements() == 1)
231  return failure();
232 
233  // arith.constant should only have vector or tensor types. This is a MLIR
234  // wide problem at the moment.
235  if (!isa<VectorType, RankedTensorType>(srcType))
236  return rewriter.notifyMatchFailure(constOp, "unsupported ShapedType");
237 
238  Type dstType = getTypeConverter()->convertType(srcType);
239  if (!dstType)
240  return failure();
241 
242  // Import the resource into the IR to make use of the special handling of
243  // element types later on.
244  mlir::DenseElementsAttr dstElementsAttr;
245  if (auto denseElementsAttr =
246  dyn_cast<DenseElementsAttr>(constOp.getValue())) {
247  dstElementsAttr = denseElementsAttr;
248  } else if (auto resourceAttr =
249  dyn_cast<DenseResourceElementsAttr>(constOp.getValue())) {
250 
251  AsmResourceBlob *blob = resourceAttr.getRawHandle().getBlob();
252  if (!blob)
253  return constOp->emitError("could not find resource blob");
254 
255  ArrayRef<char> ptr = blob->getData();
256 
257  // Check that the buffer meets the requirements to get converted to a
258  // DenseElementsAttr
259  bool detectedSplat = false;
260  if (!DenseElementsAttr::isValidRawBuffer(srcType, ptr, detectedSplat))
261  return constOp->emitError("resource is not a valid buffer");
262 
263  dstElementsAttr =
264  DenseElementsAttr::getFromRawBuffer(resourceAttr.getType(), ptr);
265  } else {
266  return constOp->emitError("unsupported elements attribute");
267  }
268 
269  ShapedType dstAttrType = dstElementsAttr.getType();
270 
271  // If the composite type has more than one dimensions, perform
272  // linearization.
273  if (srcType.getRank() > 1) {
274  if (isa<RankedTensorType>(srcType)) {
275  dstAttrType = RankedTensorType::get(srcType.getNumElements(),
276  srcType.getElementType());
277  dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
278  } else {
279  // TODO: add support for large vectors.
280  return failure();
281  }
282  }
283 
284  Type srcElemType = srcType.getElementType();
285  Type dstElemType;
286  // Tensor types are converted to SPIR-V array types; vector types are
287  // converted to SPIR-V vector/array types.
288  if (auto arrayType = dyn_cast<spirv::ArrayType>(dstType))
289  dstElemType = arrayType.getElementType();
290  else
291  dstElemType = cast<VectorType>(dstType).getElementType();
292 
293  // If the source and destination element types are different, perform
294  // attribute conversion.
295  if (srcElemType != dstElemType) {
296  SmallVector<Attribute, 8> elements;
297  if (isa<FloatType>(srcElemType)) {
298  for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
299  FloatAttr dstAttr =
300  convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), rewriter);
301  if (!dstAttr)
302  return failure();
303  elements.push_back(dstAttr);
304  }
305  } else if (srcElemType.isInteger(1)) {
306  return failure();
307  } else {
308  for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) {
309  IntegerAttr dstAttr = convertIntegerAttr(
310  srcAttr, cast<IntegerType>(dstElemType), rewriter);
311  if (!dstAttr)
312  return failure();
313  elements.push_back(dstAttr);
314  }
315  }
316 
317  // Unfortunately, we cannot use dialect-specific types for element
318  // attributes; element attributes only works with builtin types. So we
319  // need to prepare another converted builtin types for the destination
320  // elements attribute.
321  if (isa<RankedTensorType>(dstAttrType))
322  dstAttrType =
323  RankedTensorType::get(dstAttrType.getShape(), dstElemType);
324  else
325  dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
326 
327  dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
328  }
329 
330  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
331  dstElementsAttr);
332  return success();
333  }
334 };
335 
336 /// Converts scalar arith.constant operation to spirv.Constant.
337 struct ConstantScalarOpPattern final
338  : public OpConversionPattern<arith::ConstantOp> {
340 
341  LogicalResult
342  matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
343  ConversionPatternRewriter &rewriter) const override {
344  Type srcType = constOp.getType();
345  if (auto shapedType = dyn_cast<ShapedType>(srcType)) {
346  if (shapedType.getNumElements() != 1)
347  return failure();
348  srcType = shapedType.getElementType();
349  }
350  if (!srcType.isIntOrIndexOrFloat())
351  return failure();
352 
353  Attribute cstAttr = constOp.getValue();
354  if (auto elementsAttr = dyn_cast<DenseElementsAttr>(cstAttr))
355  cstAttr = elementsAttr.getSplatValue<Attribute>();
356 
357  Type dstType = getTypeConverter()->convertType(srcType);
358  if (!dstType)
359  return failure();
360 
361  // Floating-point types.
362  if (isa<FloatType>(srcType)) {
363  auto srcAttr = cast<FloatAttr>(cstAttr);
364  auto dstAttr = srcAttr;
365 
366  // Floating-point types not supported in the target environment are all
367  // converted to float type.
368  if (srcType != dstType) {
369  dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter);
370  if (!dstAttr)
371  return failure();
372  }
373 
374  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
375  return success();
376  }
377 
378  // Bool type.
379  if (srcType.isInteger(1)) {
380  // arith.constant can use 0/1 instead of true/false for i1 values. We need
381  // to handle that here.
382  auto dstAttr = convertBoolAttr(cstAttr, rewriter);
383  if (!dstAttr)
384  return failure();
385  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
386  return success();
387  }
388 
389  // IndexType or IntegerType. Index values are converted to 32-bit integer
390  // values when converting to SPIR-V.
391  auto srcAttr = cast<IntegerAttr>(cstAttr);
392  IntegerAttr dstAttr =
393  convertIntegerAttr(srcAttr, cast<IntegerType>(dstType), rewriter);
394  if (!dstAttr)
395  return failure();
396  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
397  return success();
398  }
399 };
400 
401 //===----------------------------------------------------------------------===//
402 // RemSIOp
403 //===----------------------------------------------------------------------===//
404 
405 /// Returns signed remainder for `lhs` and `rhs` and lets the result follow
406 /// the sign of `signOperand`.
407 ///
408 /// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment
409 /// spec, "for the OpSRem and OpSMod instructions, if either operand is negative
410 /// the result is undefined." So we cannot directly use spirv.SRem/spirv.SMod
411 /// if either operand can be negative. Emulate it via spirv.UMod.
412 template <typename SignedAbsOp>
413 static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
414  Value signOperand, OpBuilder &builder) {
415  assert(lhs.getType() == rhs.getType());
416  assert(lhs == signOperand || rhs == signOperand);
417 
418  Type type = lhs.getType();
419 
420  // Calculate the remainder with spirv.UMod.
421  Value lhsAbs = builder.create<SignedAbsOp>(loc, type, lhs);
422  Value rhsAbs = builder.create<SignedAbsOp>(loc, type, rhs);
423  Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs);
424 
425  // Fix the sign.
426  Value isPositive;
427  if (lhs == signOperand)
428  isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs);
429  else
430  isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs);
431  Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs);
432  return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate);
433 }
434 
435 /// Converts arith.remsi to GLSL SPIR-V ops.
436 ///
437 /// This cannot be merged into the template unary/binary pattern due to Vulkan
438 /// restrictions over spirv.SRem and spirv.SMod.
439 struct RemSIOpGLPattern final : public OpConversionPattern<arith::RemSIOp> {
441 
442  LogicalResult
443  matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
444  ConversionPatternRewriter &rewriter) const override {
445  Value result = emulateSignedRemainder<spirv::CLSAbsOp>(
446  op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
447  adaptor.getOperands()[0], rewriter);
448  rewriter.replaceOp(op, result);
449 
450  return success();
451  }
452 };
453 
454 /// Converts arith.remsi to OpenCL SPIR-V ops.
455 struct RemSIOpCLPattern final : public OpConversionPattern<arith::RemSIOp> {
457 
458  LogicalResult
459  matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
460  ConversionPatternRewriter &rewriter) const override {
461  Value result = emulateSignedRemainder<spirv::GLSAbsOp>(
462  op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
463  adaptor.getOperands()[0], rewriter);
464  rewriter.replaceOp(op, result);
465 
466  return success();
467  }
468 };
469 
470 //===----------------------------------------------------------------------===//
471 // BitwiseOp
472 //===----------------------------------------------------------------------===//
473 
474 /// Converts bitwise operations to SPIR-V operations. This is a special pattern
475 /// other than the BinaryOpPatternPattern because if the operands are boolean
476 /// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For
477 /// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`.
478 template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
479 struct BitwiseOpPattern final : public OpConversionPattern<Op> {
481 
482  LogicalResult
483  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
484  ConversionPatternRewriter &rewriter) const override {
485  assert(adaptor.getOperands().size() == 2);
486  Type dstType = this->getTypeConverter()->convertType(op.getType());
487  if (!dstType)
488  return getTypeConversionFailure(rewriter, op);
489 
490  if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) {
491  rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(
492  op, dstType, adaptor.getOperands());
493  } else {
494  rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(
495  op, dstType, adaptor.getOperands());
496  }
497  return success();
498  }
499 };
500 
501 //===----------------------------------------------------------------------===//
502 // XOrIOp
503 //===----------------------------------------------------------------------===//
504 
505 /// Converts arith.xori to SPIR-V operations.
506 struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> {
508 
509  LogicalResult
510  matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
511  ConversionPatternRewriter &rewriter) const override {
512  assert(adaptor.getOperands().size() == 2);
513 
514  if (isBoolScalarOrVector(adaptor.getOperands().front().getType()))
515  return failure();
516 
517  Type dstType = getTypeConverter()->convertType(op.getType());
518  if (!dstType)
519  return getTypeConversionFailure(rewriter, op);
520 
521  rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
522  adaptor.getOperands());
523 
524  return success();
525  }
526 };
527 
528 /// Converts arith.xori to SPIR-V operations if the type of source is i1 or
529 /// vector of i1.
530 struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
532 
533  LogicalResult
534  matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
535  ConversionPatternRewriter &rewriter) const override {
536  assert(adaptor.getOperands().size() == 2);
537 
538  if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
539  return failure();
540 
541  Type dstType = getTypeConverter()->convertType(op.getType());
542  if (!dstType)
543  return getTypeConversionFailure(rewriter, op);
544 
545  rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
546  op, dstType, adaptor.getOperands());
547  return success();
548  }
549 };
550 
551 //===----------------------------------------------------------------------===//
552 // UIToFPOp
553 //===----------------------------------------------------------------------===//
554 
555 /// Converts arith.uitofp to spirv.Select if the type of source is i1 or vector
556 /// of i1.
557 struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
559 
560  LogicalResult
561  matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
562  ConversionPatternRewriter &rewriter) const override {
563  Type srcType = adaptor.getOperands().front().getType();
564  if (!isBoolScalarOrVector(srcType))
565  return failure();
566 
567  Type dstType = getTypeConverter()->convertType(op.getType());
568  if (!dstType)
569  return getTypeConversionFailure(rewriter, op);
570 
571  Location loc = op.getLoc();
572  Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
573  Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
574  rewriter.replaceOpWithNewOp<spirv::SelectOp>(
575  op, dstType, adaptor.getOperands().front(), one, zero);
576  return success();
577  }
578 };
579 
580 //===----------------------------------------------------------------------===//
581 // ExtSIOp
582 //===----------------------------------------------------------------------===//
583 
584 /// Converts arith.extsi to spirv.Select if the type of source is i1 or vector
585 /// of i1.
586 struct ExtSII1Pattern final : public OpConversionPattern<arith::ExtSIOp> {
588 
589  LogicalResult
590  matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
591  ConversionPatternRewriter &rewriter) const override {
592  Value operand = adaptor.getIn();
593  if (!isBoolScalarOrVector(operand.getType()))
594  return failure();
595 
596  Location loc = op.getLoc();
597  Type dstType = getTypeConverter()->convertType(op.getType());
598  if (!dstType)
599  return getTypeConversionFailure(rewriter, op);
600 
601  Value allOnes;
602  if (auto intTy = dyn_cast<IntegerType>(dstType)) {
603  unsigned componentBitwidth = intTy.getWidth();
604  allOnes = rewriter.create<spirv::ConstantOp>(
605  loc, intTy,
606  rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth)));
607  } else if (auto vectorTy = dyn_cast<VectorType>(dstType)) {
608  unsigned componentBitwidth = vectorTy.getElementTypeBitWidth();
609  allOnes = rewriter.create<spirv::ConstantOp>(
610  loc, vectorTy,
611  SplatElementsAttr::get(vectorTy,
612  APInt::getAllOnes(componentBitwidth)));
613  } else {
614  return rewriter.notifyMatchFailure(
615  loc, llvm::formatv("unhandled type: {0}", dstType));
616  }
617 
618  Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
619  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, operand, allOnes,
620  zero);
621  return success();
622  }
623 };
624 
625 /// Converts arith.extsi to spirv.Select if the type of source is neither i1 nor
626 /// vector of i1.
627 struct ExtSIPattern final : public OpConversionPattern<arith::ExtSIOp> {
629 
630  LogicalResult
631  matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
632  ConversionPatternRewriter &rewriter) const override {
633  Type srcType = adaptor.getIn().getType();
634  if (isBoolScalarOrVector(srcType))
635  return failure();
636 
637  Type dstType = getTypeConverter()->convertType(op.getType());
638  if (!dstType)
639  return getTypeConversionFailure(rewriter, op);
640 
641  if (dstType == srcType) {
642  // We can have the same source and destination type due to type emulation.
643  // Perform bit shifting to make sure we have the proper leading set bits.
644 
645  unsigned srcBW =
646  getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth();
647  unsigned dstBW =
649  assert(srcBW < dstBW);
650  Value shiftSize = getScalarOrVectorConstInt(dstType, dstBW - srcBW,
651  rewriter, op.getLoc());
652 
653  // First shift left to sequeeze out all leading bits beyond the original
654  // bitwidth. Here we need to use the original source and result type's
655  // bitwidth.
656  auto shiftLOp = rewriter.create<spirv::ShiftLeftLogicalOp>(
657  op.getLoc(), dstType, adaptor.getIn(), shiftSize);
658 
659  // Then we perform arithmetic right shift to make sure we have the right
660  // sign bits for negative values.
661  rewriter.replaceOpWithNewOp<spirv::ShiftRightArithmeticOp>(
662  op, dstType, shiftLOp, shiftSize);
663  } else {
664  rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
665  adaptor.getOperands());
666  }
667 
668  return success();
669  }
670 };
671 
672 //===----------------------------------------------------------------------===//
673 // ExtUIOp
674 //===----------------------------------------------------------------------===//
675 
676 /// Converts arith.extui to spirv.Select if the type of source is i1 or vector
677 /// of i1.
678 struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> {
680 
681  LogicalResult
682  matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
683  ConversionPatternRewriter &rewriter) const override {
684  Type srcType = adaptor.getOperands().front().getType();
685  if (!isBoolScalarOrVector(srcType))
686  return failure();
687 
688  Type dstType = getTypeConverter()->convertType(op.getType());
689  if (!dstType)
690  return getTypeConversionFailure(rewriter, op);
691 
692  Location loc = op.getLoc();
693  Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
694  Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
695  rewriter.replaceOpWithNewOp<spirv::SelectOp>(
696  op, dstType, adaptor.getOperands().front(), one, zero);
697  return success();
698  }
699 };
700 
701 /// Converts arith.extui for cases where the type of source is neither i1 nor
702 /// vector of i1.
703 struct ExtUIPattern final : public OpConversionPattern<arith::ExtUIOp> {
705 
706  LogicalResult
707  matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
708  ConversionPatternRewriter &rewriter) const override {
709  Type srcType = adaptor.getIn().getType();
710  if (isBoolScalarOrVector(srcType))
711  return failure();
712 
713  Type dstType = getTypeConverter()->convertType(op.getType());
714  if (!dstType)
715  return getTypeConversionFailure(rewriter, op);
716 
717  if (dstType == srcType) {
718  // We can have the same source and destination type due to type emulation.
719  // Perform bit masking to make sure we don't pollute downstream consumers
720  // with unwanted bits. Here we need to use the original source type's
721  // bitwidth.
722  unsigned bitwidth =
723  getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth();
725  dstType, llvm::maskTrailingOnes<uint64_t>(bitwidth), rewriter,
726  op.getLoc());
727  rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
728  adaptor.getIn(), mask);
729  } else {
730  rewriter.replaceOpWithNewOp<spirv::UConvertOp>(op, dstType,
731  adaptor.getOperands());
732  }
733  return success();
734  }
735 };
736 
737 //===----------------------------------------------------------------------===//
738 // TruncIOp
739 //===----------------------------------------------------------------------===//
740 
741 /// Converts arith.trunci to spirv.Select if the type of result is i1 or vector
742 /// of i1.
743 struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
745 
746  LogicalResult
747  matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
748  ConversionPatternRewriter &rewriter) const override {
749  Type dstType = getTypeConverter()->convertType(op.getType());
750  if (!dstType)
751  return getTypeConversionFailure(rewriter, op);
752 
753  if (!isBoolScalarOrVector(dstType))
754  return failure();
755 
756  Location loc = op.getLoc();
757  auto srcType = adaptor.getOperands().front().getType();
758  // Check if (x & 1) == 1.
759  Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
760  Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>(
761  loc, srcType, adaptor.getOperands()[0], mask);
762  Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask);
763 
764  Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
765  Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
766  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero);
767  return success();
768  }
769 };
770 
771 /// Converts arith.trunci for cases where the type of result is neither i1
772 /// nor vector of i1.
773 struct TruncIPattern final : public OpConversionPattern<arith::TruncIOp> {
775 
776  LogicalResult
777  matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
778  ConversionPatternRewriter &rewriter) const override {
779  Type srcType = adaptor.getIn().getType();
780  Type dstType = getTypeConverter()->convertType(op.getType());
781  if (!dstType)
782  return getTypeConversionFailure(rewriter, op);
783 
784  if (isBoolScalarOrVector(dstType))
785  return failure();
786 
787  if (dstType == srcType) {
788  // We can have the same source and destination type due to type emulation.
789  // Perform bit masking to make sure we don't pollute downstream consumers
790  // with unwanted bits. Here we need to use the original result type's
791  // bitwidth.
792  unsigned bw = getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth();
794  dstType, llvm::maskTrailingOnes<uint64_t>(bw), rewriter, op.getLoc());
795  rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
796  adaptor.getIn(), mask);
797  } else {
798  // Given this is truncation, either SConvertOp or UConvertOp works.
799  rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
800  adaptor.getOperands());
801  }
802  return success();
803  }
804 };
805 
806 //===----------------------------------------------------------------------===//
807 // TypeCastingOp
808 //===----------------------------------------------------------------------===//
809 
810 static std::optional<spirv::FPRoundingMode>
811 convertArithRoundingModeToSPIRV(arith::RoundingMode roundingMode) {
812  switch (roundingMode) {
813  case arith::RoundingMode::downward:
814  return spirv::FPRoundingMode::RTN;
815  case arith::RoundingMode::to_nearest_even:
816  return spirv::FPRoundingMode::RTE;
817  case arith::RoundingMode::toward_zero:
818  return spirv::FPRoundingMode::RTZ;
819  case arith::RoundingMode::upward:
820  return spirv::FPRoundingMode::RTP;
821  case arith::RoundingMode::to_nearest_away:
822  // SPIR-V FPRoundingMode decoration has no ties-away-from-zero mode
823  // (as of SPIR-V 1.6)
824  return std::nullopt;
825  }
826  llvm_unreachable("Unhandled rounding mode");
827 }
828 
829 /// Converts type-casting standard operations to SPIR-V operations.
830 template <typename Op, typename SPIRVOp>
831 struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
833 
834  LogicalResult
835  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
836  ConversionPatternRewriter &rewriter) const override {
837  assert(adaptor.getOperands().size() == 1);
838  Type srcType = adaptor.getOperands().front().getType();
839  Type dstType = this->getTypeConverter()->convertType(op.getType());
840  if (!dstType)
841  return getTypeConversionFailure(rewriter, op);
842 
843  if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
844  return failure();
845 
846  if (dstType == srcType) {
847  // Due to type conversion, we are seeing the same source and target type.
848  // Then we can just erase this operation by forwarding its operand.
849  rewriter.replaceOp(op, adaptor.getOperands().front());
850  } else {
851  auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
852  op, dstType, adaptor.getOperands());
853  if (auto roundingModeOp =
854  dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
855  if (arith::RoundingModeAttr roundingMode =
856  roundingModeOp.getRoundingModeAttr()) {
857  if (auto rm =
858  convertArithRoundingModeToSPIRV(roundingMode.getValue())) {
859  newOp->setAttr(
860  getDecorationString(spirv::Decoration::FPRoundingMode),
861  spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm));
862  } else {
863  return rewriter.notifyMatchFailure(
864  op->getLoc(),
865  llvm::formatv("unsupported rounding mode '{0}'", roundingMode));
866  }
867  }
868  }
869  }
870  return success();
871  }
872 };
873 
874 //===----------------------------------------------------------------------===//
875 // CmpIOp
876 //===----------------------------------------------------------------------===//
877 
878 /// Converts integer compare operation on i1 type operands to SPIR-V ops.
879 class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
880 public:
882 
883  LogicalResult
884  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
885  ConversionPatternRewriter &rewriter) const override {
886  Type srcType = op.getLhs().getType();
887  if (!isBoolScalarOrVector(srcType))
888  return failure();
889  Type dstType = getTypeConverter()->convertType(srcType);
890  if (!dstType)
891  return getTypeConversionFailure(rewriter, op, srcType);
892 
893  switch (op.getPredicate()) {
894  case arith::CmpIPredicate::eq: {
895  rewriter.replaceOpWithNewOp<spirv::LogicalEqualOp>(op, adaptor.getLhs(),
896  adaptor.getRhs());
897  return success();
898  }
899  case arith::CmpIPredicate::ne: {
900  rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
901  op, adaptor.getLhs(), adaptor.getRhs());
902  return success();
903  }
904  case arith::CmpIPredicate::uge:
905  case arith::CmpIPredicate::ugt:
906  case arith::CmpIPredicate::ule:
907  case arith::CmpIPredicate::ult: {
908  // There are no direct corresponding instructions in SPIR-V for such
909  // cases. Extend them to 32-bit and do comparision then.
910  Type type = rewriter.getI32Type();
911  if (auto vectorType = dyn_cast<VectorType>(dstType))
912  type = VectorType::get(vectorType.getShape(), type);
913  Value extLhs =
914  rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
915  Value extRhs =
916  rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getRhs());
917 
918  rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs,
919  extRhs);
920  return success();
921  }
922  default:
923  break;
924  }
925  return failure();
926  }
927 };
928 
929 /// Converts integer compare operation to SPIR-V ops.
930 class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> {
931 public:
933 
934  LogicalResult
935  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
936  ConversionPatternRewriter &rewriter) const override {
937  Type srcType = op.getLhs().getType();
938  if (isBoolScalarOrVector(srcType))
939  return failure();
940  Type dstType = getTypeConverter()->convertType(srcType);
941  if (!dstType)
942  return getTypeConversionFailure(rewriter, op, srcType);
943 
944  switch (op.getPredicate()) {
945 #define DISPATCH(cmpPredicate, spirvOp) \
946  case cmpPredicate: \
947  if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \
948  !getElementTypeOrSelf(srcType).isIndex() && srcType != dstType && \
949  !hasSameBitwidth(srcType, dstType)) { \
950  return op.emitError( \
951  "bitwidth emulation is not implemented yet on unsigned op"); \
952  } \
953  rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
954  adaptor.getRhs()); \
955  return success();
956 
957  DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
958  DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
959  DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
960  DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
961  DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
962  DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
963  DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
964  DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
965  DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
966  DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
967 
968 #undef DISPATCH
969  }
970  return failure();
971  }
972 };
973 
974 //===----------------------------------------------------------------------===//
975 // CmpFOpPattern
976 //===----------------------------------------------------------------------===//
977 
978 /// Converts floating-point comparison operations to SPIR-V ops.
979 class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> {
980 public:
982 
983  LogicalResult
984  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
985  ConversionPatternRewriter &rewriter) const override {
986  switch (op.getPredicate()) {
987 #define DISPATCH(cmpPredicate, spirvOp) \
988  case cmpPredicate: \
989  rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
990  adaptor.getRhs()); \
991  return success();
992 
993  // Ordered.
994  DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
995  DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
996  DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
997  DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
998  DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
999  DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
1000  // Unordered.
1001  DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
1002  DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
1003  DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
1004  DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
1005  DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
1006  DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
1007 
1008 #undef DISPATCH
1009 
1010  default:
1011  break;
1012  }
1013  return failure();
1014  }
1015 };
1016 
1017 /// Converts floating point NaN check to SPIR-V ops. This pattern requires
1018 /// Kernel capability.
1019 class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> {
1020 public:
1022 
1023  LogicalResult
1024  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1025  ConversionPatternRewriter &rewriter) const override {
1026  if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1027  rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(),
1028  adaptor.getRhs());
1029  return success();
1030  }
1031 
1032  if (op.getPredicate() == arith::CmpFPredicate::UNO) {
1033  rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(),
1034  adaptor.getRhs());
1035  return success();
1036  }
1037 
1038  return failure();
1039  }
1040 };
1041 
1042 /// Converts floating point NaN check to SPIR-V ops. This pattern does not
1043 /// require additional capability.
1044 class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
1045 public:
1047 
1048  LogicalResult
1049  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1050  ConversionPatternRewriter &rewriter) const override {
1051  if (op.getPredicate() != arith::CmpFPredicate::ORD &&
1052  op.getPredicate() != arith::CmpFPredicate::UNO)
1053  return failure();
1054 
1055  Location loc = op.getLoc();
1056 
1057  Value replace;
1058  if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1059  if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1060  // Ordered comparsion checks if neither operand is NaN.
1061  replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter);
1062  } else {
1063  // Unordered comparsion checks if either operand is NaN.
1064  replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter);
1065  }
1066  } else {
1067  Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
1068  Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
1069 
1070  replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
1071  if (op.getPredicate() == arith::CmpFPredicate::ORD)
1072  replace = rewriter.create<spirv::LogicalNotOp>(loc, replace);
1073  }
1074 
1075  rewriter.replaceOp(op, replace);
1076  return success();
1077  }
1078 };
1079 
1080 //===----------------------------------------------------------------------===//
1081 // AddUIExtendedOp
1082 //===----------------------------------------------------------------------===//
1083 
1084 /// Converts arith.addui_extended to spirv.IAddCarry.
1085 class AddUIExtendedOpPattern final
1086  : public OpConversionPattern<arith::AddUIExtendedOp> {
1087 public:
1089  LogicalResult
1090  matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
1091  ConversionPatternRewriter &rewriter) const override {
1092  Type dstElemTy = adaptor.getLhs().getType();
1093  Location loc = op->getLoc();
1094  Value result = rewriter.create<spirv::IAddCarryOp>(loc, adaptor.getLhs(),
1095  adaptor.getRhs());
1096 
1097  Value sumResult = rewriter.create<spirv::CompositeExtractOp>(
1098  loc, result, llvm::ArrayRef(0));
1099  Value carryValue = rewriter.create<spirv::CompositeExtractOp>(
1100  loc, result, llvm::ArrayRef(1));
1101 
1102  // Convert the carry value to boolean.
1103  Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter);
1104  Value carryResult = rewriter.create<spirv::IEqualOp>(loc, carryValue, one);
1105 
1106  rewriter.replaceOp(op, {sumResult, carryResult});
1107  return success();
1108  }
1109 };
1110 
1111 //===----------------------------------------------------------------------===//
1112 // MulIExtendedOp
1113 //===----------------------------------------------------------------------===//
1114 
1115 /// Converts arith.mul*i_extended to spirv.*MulExtended.
1116 template <typename ArithMulOp, typename SPIRVMulOp>
1117 class MulIExtendedOpPattern final : public OpConversionPattern<ArithMulOp> {
1118 public:
1120  LogicalResult
1121  matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
1122  ConversionPatternRewriter &rewriter) const override {
1123  Location loc = op->getLoc();
1124  Value result =
1125  rewriter.create<SPIRVMulOp>(loc, adaptor.getLhs(), adaptor.getRhs());
1126 
1127  Value low = rewriter.create<spirv::CompositeExtractOp>(loc, result,
1128  llvm::ArrayRef(0));
1129  Value high = rewriter.create<spirv::CompositeExtractOp>(loc, result,
1130  llvm::ArrayRef(1));
1131 
1132  rewriter.replaceOp(op, {low, high});
1133  return success();
1134  }
1135 };
1136 
1137 //===----------------------------------------------------------------------===//
1138 // SelectOp
1139 //===----------------------------------------------------------------------===//
1140 
1141 /// Converts arith.select to spirv.Select.
1142 class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> {
1143 public:
1145  LogicalResult
1146  matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
1147  ConversionPatternRewriter &rewriter) const override {
1148  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, adaptor.getCondition(),
1149  adaptor.getTrueValue(),
1150  adaptor.getFalseValue());
1151  return success();
1152  }
1153 };
1154 
1155 //===----------------------------------------------------------------------===//
1156 // MinimumFOp, MaximumFOp
1157 //===----------------------------------------------------------------------===//
1158 
1159 /// Converts arith.maximumf/minimumf to spirv.GL.FMax/FMin or
1160 /// spirv.CL.fmax/fmin.
1161 template <typename Op, typename SPIRVOp>
1162 class MinimumMaximumFOpPattern final : public OpConversionPattern<Op> {
1163 public:
1165  LogicalResult
1166  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
1167  ConversionPatternRewriter &rewriter) const override {
1168  auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
1169  Type dstType = converter->convertType(op.getType());
1170  if (!dstType)
1171  return getTypeConversionFailure(rewriter, op);
1172 
1173  // arith.maximumf/minimumf:
1174  // "if one of the arguments is NaN, then the result is also NaN."
1175  // spirv.GL.FMax/FMin
1176  // "which operand is the result is undefined if one of the operands
1177  // is a NaN."
1178  // spirv.CL.fmax/fmin:
1179  // "If one argument is a NaN, Fmin returns the other argument."
1180 
1181  Location loc = op.getLoc();
1182  Value spirvOp =
1183  rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
1184 
1185  if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1186  rewriter.replaceOp(op, spirvOp);
1187  return success();
1188  }
1189 
1190  Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
1191  Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
1192 
1193  Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan,
1194  adaptor.getLhs(), spirvOp);
1195  Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan,
1196  adaptor.getRhs(), select1);
1197 
1198  rewriter.replaceOp(op, select2);
1199  return success();
1200  }
1201 };
1202 
1203 //===----------------------------------------------------------------------===//
1204 // MinNumFOp, MaxNumFOp
1205 //===----------------------------------------------------------------------===//
1206 
1207 /// Converts arith.maxnumf/minnumf to spirv.GL.FMax/FMin or
1208 /// spirv.CL.fmax/fmin.
1209 template <typename Op, typename SPIRVOp>
1210 class MinNumMaxNumFOpPattern final : public OpConversionPattern<Op> {
1211  template <typename TargetOp>
1212  constexpr bool shouldInsertNanGuards() const {
1213  return llvm::is_one_of<TargetOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value;
1214  }
1215 
1216 public:
1218  LogicalResult
1219  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
1220  ConversionPatternRewriter &rewriter) const override {
1221  auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
1222  Type dstType = converter->convertType(op.getType());
1223  if (!dstType)
1224  return getTypeConversionFailure(rewriter, op);
1225 
1226  // arith.maxnumf/minnumf:
1227  // "If one of the arguments is NaN, then the result is the other
1228  // argument."
1229  // spirv.GL.FMax/FMin
1230  // "which operand is the result is undefined if one of the operands
1231  // is a NaN."
1232  // spirv.CL.fmax/fmin:
1233  // "If one argument is a NaN, Fmin returns the other argument."
1234 
1235  Location loc = op.getLoc();
1236  Value spirvOp =
1237  rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
1238 
1239  if (!shouldInsertNanGuards<SPIRVOp>() ||
1240  bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1241  rewriter.replaceOp(op, spirvOp);
1242  return success();
1243  }
1244 
1245  Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
1246  Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
1247 
1248  Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan,
1249  adaptor.getRhs(), spirvOp);
1250  Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan,
1251  adaptor.getLhs(), select1);
1252 
1253  rewriter.replaceOp(op, select2);
1254  return success();
1255  }
1256 };
1257 
1258 } // namespace
1259 
1260 //===----------------------------------------------------------------------===//
1261 // Pattern Population
1262 //===----------------------------------------------------------------------===//
1263 
1265  const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
1266  // clang-format off
1267  patterns.add<
1268  ConstantCompositeOpPattern,
1269  ConstantScalarOpPattern,
1270  ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>,
1271  ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>,
1272  ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>,
1276  RemSIOpGLPattern, RemSIOpCLPattern,
1277  BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
1278  BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
1279  XOrIOpLogicalPattern, XOrIOpBooleanPattern,
1280  ElementwiseArithOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
1289  ExtUIPattern, ExtUII1Pattern,
1290  ExtSIPattern, ExtSII1Pattern,
1291  TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
1292  TruncIPattern, TruncII1Pattern,
1293  TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
1294  TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern,
1295  TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
1296  TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
1297  TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
1298  TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
1299  TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
1300  TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
1301  CmpIOpBooleanPattern, CmpIOpPattern,
1302  CmpFOpNanNonePattern, CmpFOpPattern,
1303  AddUIExtendedOpPattern,
1304  MulIExtendedOpPattern<arith::MulSIExtendedOp, spirv::SMulExtendedOp>,
1305  MulIExtendedOpPattern<arith::MulUIExtendedOp, spirv::UMulExtendedOp>,
1306  SelectOpPattern,
1307 
1308  MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::GLFMaxOp>,
1309  MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::GLFMinOp>,
1310  MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::GLFMaxOp>,
1311  MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::GLFMinOp>,
1316 
1317  MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::CLFMaxOp>,
1318  MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::CLFMinOp>,
1319  MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::CLFMaxOp>,
1320  MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::CLFMinOp>,
1325  >(typeConverter, patterns.getContext());
1326  // clang-format on
1327 
1328  // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel
1329  // capability is available.
1330  patterns.add<CmpFOpNanKernelPattern>(typeConverter, patterns.getContext(),
1331  /*benefit=*/2);
1332 }
1333 
1334 //===----------------------------------------------------------------------===//
1335 // Pass Definition
1336 //===----------------------------------------------------------------------===//
1337 
1338 namespace {
1339 struct ConvertArithToSPIRVPass
1340  : public impl::ConvertArithToSPIRVBase<ConvertArithToSPIRVPass> {
1341  void runOnOperation() override {
1342  Operation *op = getOperation();
1344  std::unique_ptr<SPIRVConversionTarget> target =
1345  SPIRVConversionTarget::get(targetAttr);
1346 
1348  options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
1349  SPIRVTypeConverter typeConverter(targetAttr, options);
1350 
1351  // Use UnrealizedConversionCast as the bridge so that we don't need to pull
1352  // in patterns for other dialects.
1353  target->addLegalOp<UnrealizedConversionCastOp>();
1354 
1355  // Fail hard when there are any remaining 'arith' ops.
1356  target->addIllegalDialect<arith::ArithDialect>();
1357 
1358  RewritePatternSet patterns(&getContext());
1359  arith::populateArithToSPIRVPatterns(typeConverter, patterns);
1360 
1361  if (failed(applyPartialConversion(op, *target, std::move(patterns))))
1362  signalPassFailure();
1363  }
1364 };
1365 } // namespace
1366 
1367 std::unique_ptr<OperationPass<>> mlir::arith::createConvertArithToSPIRVPass() {
1368  return std::make_unique<ConvertArithToSPIRVPass>();
1369 }
static bool hasSameBitwidth(Type a, Type b)
Returns true if scalar/vector types a and b have the same bitwidth.
static Value getScalarOrVectorConstInt(Type type, uint64_t value, OpBuilder &builder, Location loc)
Creates a scalar/vector integer constant.
static LogicalResult getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op, Type srcType)
Returns a source type conversion failure for srcType and operation op.
static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder)
Converts the given srcAttr into a boolean attribute if it holds an integral value.
static bool isBoolScalarOrVector(Type type)
Returns true if the given type is a boolean scalar or vector type.
#define DISPATCH(cmpPredicate, spirvOp)
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static std::string getDecorationString(spirv::Decoration decor)
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
This class represents a processed binary blob of data.
Definition: AsmState.h:90
ArrayRef< char > getData() const
Return the raw underlying data of this blob.
Definition: AsmState.h:144
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
UnitAttr getUnitAttr()
Definition: Builders.cpp:138
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:268
IntegerType getI32Type()
Definition: Builders.cpp:107
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:140
MLIRContext * getContext() const
Definition: Builders.h:55
FloatAttr getF32FloatAttr(float value)
Definition: Builders.cpp:286
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
static bool isValidRawBuffer(ShapedType type, ArrayRef< char > rawBuffer, bool &detectedSplat)
Returns true if the given buffer is a valid raw buffer for the given type.
DenseElementsAttr reshape(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but has been reshaped...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
This class helps build Operations.
Definition: Builders.h:215
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:64
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:131
bool isF32() const
Definition: Types.cpp:59
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:66
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:133
Type front()
Return first type in the range.
Definition: TypeRange.h:148
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
An attribute that specifies the target version, allowed extensions and capabilities,...
void populateArithToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
std::unique_ptr< OperationPass<> > createConvertArithToSPIRVPass()
Fraction abs(const Fraction &f)
Definition: Fraction.h:107
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.
Definition: Pattern.h:23