MLIR  19.0.0git
ArithToSPIRV.cpp
Go to the documentation of this file.
1 //===- ArithToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
11 #include "../SPIRVCommon/Pattern.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "llvm/ADT/APInt.h"
21 #include "llvm/ADT/ArrayRef.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/Support/Debug.h"
24 #include "llvm/Support/MathExtras.h"
25 #include <cassert>
26 #include <memory>
27 
28 namespace mlir {
29 #define GEN_PASS_DEF_CONVERTARITHTOSPIRV
30 #include "mlir/Conversion/Passes.h.inc"
31 } // namespace mlir
32 
33 #define DEBUG_TYPE "arith-to-spirv-pattern"
34 
35 using namespace mlir;
36 
37 //===----------------------------------------------------------------------===//
38 // Conversion Helpers
39 //===----------------------------------------------------------------------===//
40 
41 /// Converts the given `srcAttr` into a boolean attribute if it holds an
42 /// integral value. Returns null attribute if conversion fails.
43 static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) {
44  if (auto boolAttr = dyn_cast<BoolAttr>(srcAttr))
45  return boolAttr;
46  if (auto intAttr = dyn_cast<IntegerAttr>(srcAttr))
47  return builder.getBoolAttr(intAttr.getValue().getBoolValue());
48  return {};
49 }
50 
51 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
52 /// Returns null attribute if conversion fails.
53 static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType,
54  Builder builder) {
55  // If the source number uses less active bits than the target bitwidth, then
56  // it should be safe to convert.
57  if (srcAttr.getValue().isIntN(dstType.getWidth()))
58  return builder.getIntegerAttr(dstType, srcAttr.getInt());
59 
60  // XXX: Try again by interpreting the source number as a signed value.
61  // Although integers in the standard dialect are signless, they can represent
62  // a signed number. It's the operation decides how to interpret. This is
63  // dangerous, but it seems there is no good way of handling this if we still
64  // want to change the bitwidth. Emit a message at least.
65  if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
66  auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt());
67  LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '"
68  << dstAttr << "' for type '" << dstType << "'\n");
69  return dstAttr;
70  }
71 
72  LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr
73  << "' illegal: cannot fit into target type '"
74  << dstType << "'\n");
75  return {};
76 }
77 
78 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
79 /// Returns null attribute if `dstType` is not 32-bit or conversion fails.
80 static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
81  Builder builder) {
82  // Only support converting to float for now.
83  if (!dstType.isF32())
84  return FloatAttr();
85 
86  // Try to convert the source floating-point number to single precision.
87  APFloat dstVal = srcAttr.getValue();
88  bool losesInfo = false;
89  APFloat::opStatus status =
90  dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
91  if (status != APFloat::opOK || losesInfo) {
92  LLVM_DEBUG(llvm::dbgs()
93  << srcAttr << " illegal: cannot fit into converted type '"
94  << dstType << "'\n");
95  return FloatAttr();
96  }
97 
98  return builder.getF32FloatAttr(dstVal.convertToFloat());
99 }
100 
101 /// Returns true if the given `type` is a boolean scalar or vector type.
102 static bool isBoolScalarOrVector(Type type) {
103  assert(type && "Not a valid type");
104  if (type.isInteger(1))
105  return true;
106 
107  if (auto vecType = dyn_cast<VectorType>(type))
108  return vecType.getElementType().isInteger(1);
109 
110  return false;
111 }
112 
113 /// Creates a scalar/vector integer constant.
114 static Value getScalarOrVectorConstInt(Type type, uint64_t value,
115  OpBuilder &builder, Location loc) {
116  if (auto vectorType = dyn_cast<VectorType>(type)) {
117  Attribute element = IntegerAttr::get(vectorType.getElementType(), value);
118  auto attr = SplatElementsAttr::get(vectorType, element);
119  return builder.create<spirv::ConstantOp>(loc, vectorType, attr);
120  }
121 
122  if (auto intType = dyn_cast<IntegerType>(type))
123  return builder.create<spirv::ConstantOp>(
124  loc, type, builder.getIntegerAttr(type, value));
125 
126  return nullptr;
127 }
128 
129 /// Returns true if scalar/vector type `a` and `b` have the same number of
130 /// bitwidth.
131 static bool hasSameBitwidth(Type a, Type b) {
132  auto getNumBitwidth = [](Type type) {
133  unsigned bw = 0;
134  if (type.isIntOrFloat())
135  bw = type.getIntOrFloatBitWidth();
136  else if (auto vecType = dyn_cast<VectorType>(type))
137  bw = vecType.getElementTypeBitWidth() * vecType.getNumElements();
138  return bw;
139  };
140  unsigned aBW = getNumBitwidth(a);
141  unsigned bBW = getNumBitwidth(b);
142  return aBW != 0 && bBW != 0 && aBW == bBW;
143 }
144 
145 /// Returns a source type conversion failure for `srcType` and operation `op`.
146 static LogicalResult
148  Type srcType) {
149  return rewriter.notifyMatchFailure(
150  op->getLoc(),
151  llvm::formatv("failed to convert source type '{0}'", srcType));
152 }
153 
154 /// Returns a source type conversion failure for the result type of `op`.
155 static LogicalResult
157  assert(op->getNumResults() == 1);
158  return getTypeConversionFailure(rewriter, op, op->getResultTypes().front());
159 }
160 
161 // TODO: Move to some common place?
162 static std::string getDecorationString(spirv::Decoration decor) {
163  return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decor));
164 }
165 
166 namespace {
167 
168 /// Converts elementwise unary, binary and ternary arith operations to SPIR-V
169 /// operations. Op can potentially support overflow flags.
170 template <typename Op, typename SPIRVOp>
171 struct ElementwiseArithOpPattern final : OpConversionPattern<Op> {
173 
175  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
176  ConversionPatternRewriter &rewriter) const override {
177  assert(adaptor.getOperands().size() <= 3);
178  auto converter = this->template getTypeConverter<SPIRVTypeConverter>();
179  Type dstType = converter->convertType(op.getType());
180  if (!dstType) {
181  return rewriter.notifyMatchFailure(
182  op->getLoc(),
183  llvm::formatv("failed to convert type {0} for SPIR-V", op.getType()));
184  }
185 
186  if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
187  !getElementTypeOrSelf(op.getType()).isIndex() &&
188  dstType != op.getType()) {
189  return op.emitError("bitwidth emulation is not implemented yet on "
190  "unsigned op pattern version");
191  }
192 
193  auto overflowFlags = arith::IntegerOverflowFlags::none;
194  if (auto overflowIface =
195  dyn_cast<arith::ArithIntegerOverflowFlagsInterface>(*op)) {
196  if (converter->getTargetEnv().allows(
197  spirv::Extension::SPV_KHR_no_integer_wrap_decoration))
198  overflowFlags = overflowIface.getOverflowAttr().getValue();
199  }
200 
201  auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
202  op, dstType, adaptor.getOperands());
203 
204  if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nsw))
205  newOp->setAttr(getDecorationString(spirv::Decoration::NoSignedWrap),
206  rewriter.getUnitAttr());
207 
208  if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nuw))
209  newOp->setAttr(getDecorationString(spirv::Decoration::NoUnsignedWrap),
210  rewriter.getUnitAttr());
211 
212  return success();
213  }
214 };
215 
216 //===----------------------------------------------------------------------===//
217 // ConstantOp
218 //===----------------------------------------------------------------------===//
219 
220 /// Converts composite arith.constant operation to spirv.Constant.
221 struct ConstantCompositeOpPattern final
224 
226  matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
227  ConversionPatternRewriter &rewriter) const override {
228  auto srcType = dyn_cast<ShapedType>(constOp.getType());
229  if (!srcType || srcType.getNumElements() == 1)
230  return failure();
231 
232  // arith.constant should only have vector or tenor types.
233  assert((isa<VectorType, RankedTensorType>(srcType)));
234 
235  Type dstType = getTypeConverter()->convertType(srcType);
236  if (!dstType)
237  return failure();
238 
239  auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
240  if (!dstElementsAttr)
241  return failure();
242 
243  ShapedType dstAttrType = dstElementsAttr.getType();
244 
245  // If the composite type has more than one dimensions, perform
246  // linearization.
247  if (srcType.getRank() > 1) {
248  if (isa<RankedTensorType>(srcType)) {
249  dstAttrType = RankedTensorType::get(srcType.getNumElements(),
250  srcType.getElementType());
251  dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
252  } else {
253  // TODO: add support for large vectors.
254  return failure();
255  }
256  }
257 
258  Type srcElemType = srcType.getElementType();
259  Type dstElemType;
260  // Tensor types are converted to SPIR-V array types; vector types are
261  // converted to SPIR-V vector/array types.
262  if (auto arrayType = dyn_cast<spirv::ArrayType>(dstType))
263  dstElemType = arrayType.getElementType();
264  else
265  dstElemType = cast<VectorType>(dstType).getElementType();
266 
267  // If the source and destination element types are different, perform
268  // attribute conversion.
269  if (srcElemType != dstElemType) {
270  SmallVector<Attribute, 8> elements;
271  if (isa<FloatType>(srcElemType)) {
272  for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
273  FloatAttr dstAttr =
274  convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), rewriter);
275  if (!dstAttr)
276  return failure();
277  elements.push_back(dstAttr);
278  }
279  } else if (srcElemType.isInteger(1)) {
280  return failure();
281  } else {
282  for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) {
283  IntegerAttr dstAttr = convertIntegerAttr(
284  srcAttr, cast<IntegerType>(dstElemType), rewriter);
285  if (!dstAttr)
286  return failure();
287  elements.push_back(dstAttr);
288  }
289  }
290 
291  // Unfortunately, we cannot use dialect-specific types for element
292  // attributes; element attributes only works with builtin types. So we
293  // need to prepare another converted builtin types for the destination
294  // elements attribute.
295  if (isa<RankedTensorType>(dstAttrType))
296  dstAttrType =
297  RankedTensorType::get(dstAttrType.getShape(), dstElemType);
298  else
299  dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
300 
301  dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
302  }
303 
304  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
305  dstElementsAttr);
306  return success();
307  }
308 };
309 
310 /// Converts scalar arith.constant operation to spirv.Constant.
311 struct ConstantScalarOpPattern final
312  : public OpConversionPattern<arith::ConstantOp> {
314 
316  matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
317  ConversionPatternRewriter &rewriter) const override {
318  Type srcType = constOp.getType();
319  if (auto shapedType = dyn_cast<ShapedType>(srcType)) {
320  if (shapedType.getNumElements() != 1)
321  return failure();
322  srcType = shapedType.getElementType();
323  }
324  if (!srcType.isIntOrIndexOrFloat())
325  return failure();
326 
327  Attribute cstAttr = constOp.getValue();
328  if (auto elementsAttr = dyn_cast<DenseElementsAttr>(cstAttr))
329  cstAttr = elementsAttr.getSplatValue<Attribute>();
330 
331  Type dstType = getTypeConverter()->convertType(srcType);
332  if (!dstType)
333  return failure();
334 
335  // Floating-point types.
336  if (isa<FloatType>(srcType)) {
337  auto srcAttr = cast<FloatAttr>(cstAttr);
338  auto dstAttr = srcAttr;
339 
340  // Floating-point types not supported in the target environment are all
341  // converted to float type.
342  if (srcType != dstType) {
343  dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter);
344  if (!dstAttr)
345  return failure();
346  }
347 
348  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
349  return success();
350  }
351 
352  // Bool type.
353  if (srcType.isInteger(1)) {
354  // arith.constant can use 0/1 instead of true/false for i1 values. We need
355  // to handle that here.
356  auto dstAttr = convertBoolAttr(cstAttr, rewriter);
357  if (!dstAttr)
358  return failure();
359  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
360  return success();
361  }
362 
363  // IndexType or IntegerType. Index values are converted to 32-bit integer
364  // values when converting to SPIR-V.
365  auto srcAttr = cast<IntegerAttr>(cstAttr);
366  IntegerAttr dstAttr =
367  convertIntegerAttr(srcAttr, cast<IntegerType>(dstType), rewriter);
368  if (!dstAttr)
369  return failure();
370  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
371  return success();
372  }
373 };
374 
375 //===----------------------------------------------------------------------===//
376 // RemSIOp
377 //===----------------------------------------------------------------------===//
378 
379 /// Returns signed remainder for `lhs` and `rhs` and lets the result follow
380 /// the sign of `signOperand`.
381 ///
382 /// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment
383 /// spec, "for the OpSRem and OpSMod instructions, if either operand is negative
384 /// the result is undefined." So we cannot directly use spirv.SRem/spirv.SMod
385 /// if either operand can be negative. Emulate it via spirv.UMod.
386 template <typename SignedAbsOp>
387 static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
388  Value signOperand, OpBuilder &builder) {
389  assert(lhs.getType() == rhs.getType());
390  assert(lhs == signOperand || rhs == signOperand);
391 
392  Type type = lhs.getType();
393 
394  // Calculate the remainder with spirv.UMod.
395  Value lhsAbs = builder.create<SignedAbsOp>(loc, type, lhs);
396  Value rhsAbs = builder.create<SignedAbsOp>(loc, type, rhs);
397  Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs);
398 
399  // Fix the sign.
400  Value isPositive;
401  if (lhs == signOperand)
402  isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs);
403  else
404  isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs);
405  Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs);
406  return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate);
407 }
408 
409 /// Converts arith.remsi to GLSL SPIR-V ops.
410 ///
411 /// This cannot be merged into the template unary/binary pattern due to Vulkan
412 /// restrictions over spirv.SRem and spirv.SMod.
413 struct RemSIOpGLPattern final : public OpConversionPattern<arith::RemSIOp> {
415 
417  matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
418  ConversionPatternRewriter &rewriter) const override {
419  Value result = emulateSignedRemainder<spirv::CLSAbsOp>(
420  op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
421  adaptor.getOperands()[0], rewriter);
422  rewriter.replaceOp(op, result);
423 
424  return success();
425  }
426 };
427 
428 /// Converts arith.remsi to OpenCL SPIR-V ops.
429 struct RemSIOpCLPattern final : public OpConversionPattern<arith::RemSIOp> {
431 
433  matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
434  ConversionPatternRewriter &rewriter) const override {
435  Value result = emulateSignedRemainder<spirv::GLSAbsOp>(
436  op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
437  adaptor.getOperands()[0], rewriter);
438  rewriter.replaceOp(op, result);
439 
440  return success();
441  }
442 };
443 
444 //===----------------------------------------------------------------------===//
445 // BitwiseOp
446 //===----------------------------------------------------------------------===//
447 
448 /// Converts bitwise operations to SPIR-V operations. This is a special pattern
449 /// other than the BinaryOpPatternPattern because if the operands are boolean
450 /// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For
451 /// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`.
452 template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
453 struct BitwiseOpPattern final : public OpConversionPattern<Op> {
455 
457  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
458  ConversionPatternRewriter &rewriter) const override {
459  assert(adaptor.getOperands().size() == 2);
460  Type dstType = this->getTypeConverter()->convertType(op.getType());
461  if (!dstType)
462  return getTypeConversionFailure(rewriter, op);
463 
464  if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) {
465  rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(
466  op, dstType, adaptor.getOperands());
467  } else {
468  rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(
469  op, dstType, adaptor.getOperands());
470  }
471  return success();
472  }
473 };
474 
475 //===----------------------------------------------------------------------===//
476 // XOrIOp
477 //===----------------------------------------------------------------------===//
478 
479 /// Converts arith.xori to SPIR-V operations.
480 struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> {
482 
484  matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
485  ConversionPatternRewriter &rewriter) const override {
486  assert(adaptor.getOperands().size() == 2);
487 
488  if (isBoolScalarOrVector(adaptor.getOperands().front().getType()))
489  return failure();
490 
491  Type dstType = getTypeConverter()->convertType(op.getType());
492  if (!dstType)
493  return getTypeConversionFailure(rewriter, op);
494 
495  rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
496  adaptor.getOperands());
497 
498  return success();
499  }
500 };
501 
502 /// Converts arith.xori to SPIR-V operations if the type of source is i1 or
503 /// vector of i1.
504 struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
506 
508  matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
509  ConversionPatternRewriter &rewriter) const override {
510  assert(adaptor.getOperands().size() == 2);
511 
512  if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
513  return failure();
514 
515  Type dstType = getTypeConverter()->convertType(op.getType());
516  if (!dstType)
517  return getTypeConversionFailure(rewriter, op);
518 
519  rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
520  op, dstType, adaptor.getOperands());
521  return success();
522  }
523 };
524 
525 //===----------------------------------------------------------------------===//
526 // UIToFPOp
527 //===----------------------------------------------------------------------===//
528 
529 /// Converts arith.uitofp to spirv.Select if the type of source is i1 or vector
530 /// of i1.
531 struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
533 
535  matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
536  ConversionPatternRewriter &rewriter) const override {
537  Type srcType = adaptor.getOperands().front().getType();
538  if (!isBoolScalarOrVector(srcType))
539  return failure();
540 
541  Type dstType = getTypeConverter()->convertType(op.getType());
542  if (!dstType)
543  return getTypeConversionFailure(rewriter, op);
544 
545  Location loc = op.getLoc();
546  Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
547  Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
548  rewriter.replaceOpWithNewOp<spirv::SelectOp>(
549  op, dstType, adaptor.getOperands().front(), one, zero);
550  return success();
551  }
552 };
553 
554 //===----------------------------------------------------------------------===//
555 // ExtSIOp
556 //===----------------------------------------------------------------------===//
557 
558 /// Converts arith.extsi to spirv.Select if the type of source is i1 or vector
559 /// of i1.
560 struct ExtSII1Pattern final : public OpConversionPattern<arith::ExtSIOp> {
562 
564  matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
565  ConversionPatternRewriter &rewriter) const override {
566  Value operand = adaptor.getIn();
567  if (!isBoolScalarOrVector(operand.getType()))
568  return failure();
569 
570  Location loc = op.getLoc();
571  Type dstType = getTypeConverter()->convertType(op.getType());
572  if (!dstType)
573  return getTypeConversionFailure(rewriter, op);
574 
575  Value allOnes;
576  if (auto intTy = dyn_cast<IntegerType>(dstType)) {
577  unsigned componentBitwidth = intTy.getWidth();
578  allOnes = rewriter.create<spirv::ConstantOp>(
579  loc, intTy,
580  rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth)));
581  } else if (auto vectorTy = dyn_cast<VectorType>(dstType)) {
582  unsigned componentBitwidth = vectorTy.getElementTypeBitWidth();
583  allOnes = rewriter.create<spirv::ConstantOp>(
584  loc, vectorTy,
585  SplatElementsAttr::get(vectorTy,
586  APInt::getAllOnes(componentBitwidth)));
587  } else {
588  return rewriter.notifyMatchFailure(
589  loc, llvm::formatv("unhandled type: {0}", dstType));
590  }
591 
592  Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
593  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, operand, allOnes,
594  zero);
595  return success();
596  }
597 };
598 
599 /// Converts arith.extsi to spirv.Select if the type of source is neither i1 nor
600 /// vector of i1.
601 struct ExtSIPattern final : public OpConversionPattern<arith::ExtSIOp> {
603 
605  matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
606  ConversionPatternRewriter &rewriter) const override {
607  Type srcType = adaptor.getIn().getType();
608  if (isBoolScalarOrVector(srcType))
609  return failure();
610 
611  Type dstType = getTypeConverter()->convertType(op.getType());
612  if (!dstType)
613  return getTypeConversionFailure(rewriter, op);
614 
615  if (dstType == srcType) {
616  // We can have the same source and destination type due to type emulation.
617  // Perform bit shifting to make sure we have the proper leading set bits.
618 
619  unsigned srcBW =
620  getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth();
621  unsigned dstBW =
623  assert(srcBW < dstBW);
624  Value shiftSize = getScalarOrVectorConstInt(dstType, dstBW - srcBW,
625  rewriter, op.getLoc());
626 
627  // First shift left to sequeeze out all leading bits beyond the original
628  // bitwidth. Here we need to use the original source and result type's
629  // bitwidth.
630  auto shiftLOp = rewriter.create<spirv::ShiftLeftLogicalOp>(
631  op.getLoc(), dstType, adaptor.getIn(), shiftSize);
632 
633  // Then we perform arithmetic right shift to make sure we have the right
634  // sign bits for negative values.
635  rewriter.replaceOpWithNewOp<spirv::ShiftRightArithmeticOp>(
636  op, dstType, shiftLOp, shiftSize);
637  } else {
638  rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
639  adaptor.getOperands());
640  }
641 
642  return success();
643  }
644 };
645 
646 //===----------------------------------------------------------------------===//
647 // ExtUIOp
648 //===----------------------------------------------------------------------===//
649 
650 /// Converts arith.extui to spirv.Select if the type of source is i1 or vector
651 /// of i1.
652 struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> {
654 
656  matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
657  ConversionPatternRewriter &rewriter) const override {
658  Type srcType = adaptor.getOperands().front().getType();
659  if (!isBoolScalarOrVector(srcType))
660  return failure();
661 
662  Type dstType = getTypeConverter()->convertType(op.getType());
663  if (!dstType)
664  return getTypeConversionFailure(rewriter, op);
665 
666  Location loc = op.getLoc();
667  Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
668  Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
669  rewriter.replaceOpWithNewOp<spirv::SelectOp>(
670  op, dstType, adaptor.getOperands().front(), one, zero);
671  return success();
672  }
673 };
674 
675 /// Converts arith.extui for cases where the type of source is neither i1 nor
676 /// vector of i1.
677 struct ExtUIPattern final : public OpConversionPattern<arith::ExtUIOp> {
679 
681  matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
682  ConversionPatternRewriter &rewriter) const override {
683  Type srcType = adaptor.getIn().getType();
684  if (isBoolScalarOrVector(srcType))
685  return failure();
686 
687  Type dstType = getTypeConverter()->convertType(op.getType());
688  if (!dstType)
689  return getTypeConversionFailure(rewriter, op);
690 
691  if (dstType == srcType) {
692  // We can have the same source and destination type due to type emulation.
693  // Perform bit masking to make sure we don't pollute downstream consumers
694  // with unwanted bits. Here we need to use the original source type's
695  // bitwidth.
696  unsigned bitwidth =
697  getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth();
699  dstType, llvm::maskTrailingOnes<uint64_t>(bitwidth), rewriter,
700  op.getLoc());
701  rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
702  adaptor.getIn(), mask);
703  } else {
704  rewriter.replaceOpWithNewOp<spirv::UConvertOp>(op, dstType,
705  adaptor.getOperands());
706  }
707  return success();
708  }
709 };
710 
711 //===----------------------------------------------------------------------===//
712 // TruncIOp
713 //===----------------------------------------------------------------------===//
714 
715 /// Converts arith.trunci to spirv.Select if the type of result is i1 or vector
716 /// of i1.
717 struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
719 
721  matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
722  ConversionPatternRewriter &rewriter) const override {
723  Type dstType = getTypeConverter()->convertType(op.getType());
724  if (!dstType)
725  return getTypeConversionFailure(rewriter, op);
726 
727  if (!isBoolScalarOrVector(dstType))
728  return failure();
729 
730  Location loc = op.getLoc();
731  auto srcType = adaptor.getOperands().front().getType();
732  // Check if (x & 1) == 1.
733  Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
734  Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>(
735  loc, srcType, adaptor.getOperands()[0], mask);
736  Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask);
737 
738  Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
739  Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
740  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero);
741  return success();
742  }
743 };
744 
745 /// Converts arith.trunci for cases where the type of result is neither i1
746 /// nor vector of i1.
747 struct TruncIPattern final : public OpConversionPattern<arith::TruncIOp> {
749 
751  matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
752  ConversionPatternRewriter &rewriter) const override {
753  Type srcType = adaptor.getIn().getType();
754  Type dstType = getTypeConverter()->convertType(op.getType());
755  if (!dstType)
756  return getTypeConversionFailure(rewriter, op);
757 
758  if (isBoolScalarOrVector(dstType))
759  return failure();
760 
761  if (dstType == srcType) {
762  // We can have the same source and destination type due to type emulation.
763  // Perform bit masking to make sure we don't pollute downstream consumers
764  // with unwanted bits. Here we need to use the original result type's
765  // bitwidth.
766  unsigned bw = getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth();
768  dstType, llvm::maskTrailingOnes<uint64_t>(bw), rewriter, op.getLoc());
769  rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
770  adaptor.getIn(), mask);
771  } else {
772  // Given this is truncation, either SConvertOp or UConvertOp works.
773  rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
774  adaptor.getOperands());
775  }
776  return success();
777  }
778 };
779 
780 //===----------------------------------------------------------------------===//
781 // TypeCastingOp
782 //===----------------------------------------------------------------------===//
783 
784 /// Converts type-casting standard operations to SPIR-V operations.
785 template <typename Op, typename SPIRVOp>
786 struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
788 
790  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
791  ConversionPatternRewriter &rewriter) const override {
792  assert(adaptor.getOperands().size() == 1);
793  Type srcType = adaptor.getOperands().front().getType();
794  Type dstType = this->getTypeConverter()->convertType(op.getType());
795  if (!dstType)
796  return getTypeConversionFailure(rewriter, op);
797 
798  if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
799  return failure();
800 
801  if (dstType == srcType) {
802  // Due to type conversion, we are seeing the same source and target type.
803  // Then we can just erase this operation by forwarding its operand.
804  rewriter.replaceOp(op, adaptor.getOperands().front());
805  } else {
806  rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType,
807  adaptor.getOperands());
808  if (auto roundingModeOp =
809  dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
810  if (arith::RoundingModeAttr roundingMode =
811  roundingModeOp.getRoundingModeAttr()) {
812  // TODO: Perform rounding mode attribute conversion and attach to new
813  // operation when defined in the dialect.
814  return failure();
815  }
816  }
817  }
818  return success();
819  }
820 };
821 
822 //===----------------------------------------------------------------------===//
823 // CmpIOp
824 //===----------------------------------------------------------------------===//
825 
826 /// Converts integer compare operation on i1 type operands to SPIR-V ops.
827 class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
828 public:
830 
832  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
833  ConversionPatternRewriter &rewriter) const override {
834  Type srcType = op.getLhs().getType();
835  if (!isBoolScalarOrVector(srcType))
836  return failure();
837  Type dstType = getTypeConverter()->convertType(srcType);
838  if (!dstType)
839  return getTypeConversionFailure(rewriter, op, srcType);
840 
841  switch (op.getPredicate()) {
842  case arith::CmpIPredicate::eq: {
843  rewriter.replaceOpWithNewOp<spirv::LogicalEqualOp>(op, adaptor.getLhs(),
844  adaptor.getRhs());
845  return success();
846  }
847  case arith::CmpIPredicate::ne: {
848  rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
849  op, adaptor.getLhs(), adaptor.getRhs());
850  return success();
851  }
852  case arith::CmpIPredicate::uge:
853  case arith::CmpIPredicate::ugt:
854  case arith::CmpIPredicate::ule:
855  case arith::CmpIPredicate::ult: {
856  // There are no direct corresponding instructions in SPIR-V for such
857  // cases. Extend them to 32-bit and do comparision then.
858  Type type = rewriter.getI32Type();
859  if (auto vectorType = dyn_cast<VectorType>(dstType))
860  type = VectorType::get(vectorType.getShape(), type);
861  Value extLhs =
862  rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
863  Value extRhs =
864  rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getRhs());
865 
866  rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs,
867  extRhs);
868  return success();
869  }
870  default:
871  break;
872  }
873  return failure();
874  }
875 };
876 
877 /// Converts integer compare operation to SPIR-V ops.
878 class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> {
879 public:
881 
883  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
884  ConversionPatternRewriter &rewriter) const override {
885  Type srcType = op.getLhs().getType();
886  if (isBoolScalarOrVector(srcType))
887  return failure();
888  Type dstType = getTypeConverter()->convertType(srcType);
889  if (!dstType)
890  return getTypeConversionFailure(rewriter, op, srcType);
891 
892  switch (op.getPredicate()) {
893 #define DISPATCH(cmpPredicate, spirvOp) \
894  case cmpPredicate: \
895  if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \
896  !getElementTypeOrSelf(srcType).isIndex() && srcType != dstType && \
897  !hasSameBitwidth(srcType, dstType)) { \
898  return op.emitError( \
899  "bitwidth emulation is not implemented yet on unsigned op"); \
900  } \
901  rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
902  adaptor.getRhs()); \
903  return success();
904 
905  DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
906  DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
907  DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
908  DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
909  DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
910  DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
911  DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
912  DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
913  DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
914  DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
915 
916 #undef DISPATCH
917  }
918  return failure();
919  }
920 };
921 
922 //===----------------------------------------------------------------------===//
923 // CmpFOpPattern
924 //===----------------------------------------------------------------------===//
925 
926 /// Converts floating-point comparison operations to SPIR-V ops.
927 class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> {
928 public:
930 
932  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
933  ConversionPatternRewriter &rewriter) const override {
934  switch (op.getPredicate()) {
935 #define DISPATCH(cmpPredicate, spirvOp) \
936  case cmpPredicate: \
937  rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
938  adaptor.getRhs()); \
939  return success();
940 
941  // Ordered.
942  DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
943  DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
944  DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
945  DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
946  DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
947  DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
948  // Unordered.
949  DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
950  DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
951  DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
952  DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
953  DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
954  DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
955 
956 #undef DISPATCH
957 
958  default:
959  break;
960  }
961  return failure();
962  }
963 };
964 
965 /// Converts floating point NaN check to SPIR-V ops. This pattern requires
966 /// Kernel capability.
967 class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> {
968 public:
970 
972  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
973  ConversionPatternRewriter &rewriter) const override {
974  if (op.getPredicate() == arith::CmpFPredicate::ORD) {
975  rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(),
976  adaptor.getRhs());
977  return success();
978  }
979 
980  if (op.getPredicate() == arith::CmpFPredicate::UNO) {
981  rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(),
982  adaptor.getRhs());
983  return success();
984  }
985 
986  return failure();
987  }
988 };
989 
990 /// Converts floating point NaN check to SPIR-V ops. This pattern does not
991 /// require additional capability.
992 class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
993 public:
995 
997  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
998  ConversionPatternRewriter &rewriter) const override {
999  if (op.getPredicate() != arith::CmpFPredicate::ORD &&
1000  op.getPredicate() != arith::CmpFPredicate::UNO)
1001  return failure();
1002 
1003  Location loc = op.getLoc();
1004 
1005  Value replace;
1006  if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1007  if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1008  // Ordered comparsion checks if neither operand is NaN.
1009  replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter);
1010  } else {
1011  // Unordered comparsion checks if either operand is NaN.
1012  replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter);
1013  }
1014  } else {
1015  Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
1016  Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
1017 
1018  replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
1019  if (op.getPredicate() == arith::CmpFPredicate::ORD)
1020  replace = rewriter.create<spirv::LogicalNotOp>(loc, replace);
1021  }
1022 
1023  rewriter.replaceOp(op, replace);
1024  return success();
1025  }
1026 };
1027 
1028 //===----------------------------------------------------------------------===//
1029 // AddUIExtendedOp
1030 //===----------------------------------------------------------------------===//
1031 
1032 /// Converts arith.addui_extended to spirv.IAddCarry.
1033 class AddUIExtendedOpPattern final
1034  : public OpConversionPattern<arith::AddUIExtendedOp> {
1035 public:
1038  matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
1039  ConversionPatternRewriter &rewriter) const override {
1040  Type dstElemTy = adaptor.getLhs().getType();
1041  Location loc = op->getLoc();
1042  Value result = rewriter.create<spirv::IAddCarryOp>(loc, adaptor.getLhs(),
1043  adaptor.getRhs());
1044 
1045  Value sumResult = rewriter.create<spirv::CompositeExtractOp>(
1046  loc, result, llvm::ArrayRef(0));
1047  Value carryValue = rewriter.create<spirv::CompositeExtractOp>(
1048  loc, result, llvm::ArrayRef(1));
1049 
1050  // Convert the carry value to boolean.
1051  Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter);
1052  Value carryResult = rewriter.create<spirv::IEqualOp>(loc, carryValue, one);
1053 
1054  rewriter.replaceOp(op, {sumResult, carryResult});
1055  return success();
1056  }
1057 };
1058 
1059 //===----------------------------------------------------------------------===//
1060 // MulIExtendedOp
1061 //===----------------------------------------------------------------------===//
1062 
1063 /// Converts arith.mul*i_extended to spirv.*MulExtended.
1064 template <typename ArithMulOp, typename SPIRVMulOp>
1065 class MulIExtendedOpPattern final : public OpConversionPattern<ArithMulOp> {
1066 public:
1069  matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
1070  ConversionPatternRewriter &rewriter) const override {
1071  Location loc = op->getLoc();
1072  Value result =
1073  rewriter.create<SPIRVMulOp>(loc, adaptor.getLhs(), adaptor.getRhs());
1074 
1075  Value low = rewriter.create<spirv::CompositeExtractOp>(loc, result,
1076  llvm::ArrayRef(0));
1077  Value high = rewriter.create<spirv::CompositeExtractOp>(loc, result,
1078  llvm::ArrayRef(1));
1079 
1080  rewriter.replaceOp(op, {low, high});
1081  return success();
1082  }
1083 };
1084 
1085 //===----------------------------------------------------------------------===//
1086 // SelectOp
1087 //===----------------------------------------------------------------------===//
1088 
1089 /// Converts arith.select to spirv.Select.
1090 class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> {
1091 public:
1094  matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
1095  ConversionPatternRewriter &rewriter) const override {
1096  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, adaptor.getCondition(),
1097  adaptor.getTrueValue(),
1098  adaptor.getFalseValue());
1099  return success();
1100  }
1101 };
1102 
1103 //===----------------------------------------------------------------------===//
1104 // MinimumFOp, MaximumFOp
1105 //===----------------------------------------------------------------------===//
1106 
1107 /// Converts arith.maximumf/minimumf to spirv.GL.FMax/FMin or
1108 /// spirv.CL.fmax/fmin.
1109 template <typename Op, typename SPIRVOp>
1110 class MinimumMaximumFOpPattern final : public OpConversionPattern<Op> {
1111 public:
1114  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
1115  ConversionPatternRewriter &rewriter) const override {
1116  auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
1117  Type dstType = converter->convertType(op.getType());
1118  if (!dstType)
1119  return getTypeConversionFailure(rewriter, op);
1120 
1121  // arith.maximumf/minimumf:
1122  // "if one of the arguments is NaN, then the result is also NaN."
1123  // spirv.GL.FMax/FMin
1124  // "which operand is the result is undefined if one of the operands
1125  // is a NaN."
1126  // spirv.CL.fmax/fmin:
1127  // "If one argument is a NaN, Fmin returns the other argument."
1128 
1129  Location loc = op.getLoc();
1130  Value spirvOp =
1131  rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
1132 
1133  if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1134  rewriter.replaceOp(op, spirvOp);
1135  return success();
1136  }
1137 
1138  Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
1139  Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
1140 
1141  Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan,
1142  adaptor.getLhs(), spirvOp);
1143  Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan,
1144  adaptor.getRhs(), select1);
1145 
1146  rewriter.replaceOp(op, select2);
1147  return success();
1148  }
1149 };
1150 
1151 //===----------------------------------------------------------------------===//
1152 // MinNumFOp, MaxNumFOp
1153 //===----------------------------------------------------------------------===//
1154 
1155 /// Converts arith.maxnumf/minnumf to spirv.GL.FMax/FMin or
1156 /// spirv.CL.fmax/fmin.
1157 template <typename Op, typename SPIRVOp>
1158 class MinNumMaxNumFOpPattern final : public OpConversionPattern<Op> {
1159  template <typename TargetOp>
1160  constexpr bool shouldInsertNanGuards() const {
1161  return llvm::is_one_of<TargetOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value;
1162  }
1163 
1164 public:
1167  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
1168  ConversionPatternRewriter &rewriter) const override {
1169  auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
1170  Type dstType = converter->convertType(op.getType());
1171  if (!dstType)
1172  return getTypeConversionFailure(rewriter, op);
1173 
1174  // arith.maxnumf/minnumf:
1175  // "If one of the arguments is NaN, then the result is the other
1176  // argument."
1177  // spirv.GL.FMax/FMin
1178  // "which operand is the result is undefined if one of the operands
1179  // is a NaN."
1180  // spirv.CL.fmax/fmin:
1181  // "If one argument is a NaN, Fmin returns the other argument."
1182 
1183  Location loc = op.getLoc();
1184  Value spirvOp =
1185  rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
1186 
1187  if (!shouldInsertNanGuards<SPIRVOp>() ||
1188  bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1189  rewriter.replaceOp(op, spirvOp);
1190  return success();
1191  }
1192 
1193  Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
1194  Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
1195 
1196  Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan,
1197  adaptor.getRhs(), spirvOp);
1198  Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan,
1199  adaptor.getLhs(), select1);
1200 
1201  rewriter.replaceOp(op, select2);
1202  return success();
1203  }
1204 };
1205 
1206 } // namespace
1207 
1208 //===----------------------------------------------------------------------===//
1209 // Pattern Population
1210 //===----------------------------------------------------------------------===//
1211 
// Adds the arith-to-SPIR-V conversion patterns listed below to `patterns`.
// Every pattern is constructed from `typeConverter` and the pattern set's
// MLIRContext.
// NOTE(review): the embedded listing numbers skip 1212, 1221-1223, 1229-1236,
// 1260-1263 and 1269-1272, so the first line of the signature and several
// entries of the pattern list appear to be missing from this listing. Confirm
// against the upstream file before relying on this list being complete.
1213  SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
1214  // clang-format off
1215  patterns.add<
1216  ConstantCompositeOpPattern,
1217  ConstantScalarOpPattern,
1218  ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>,
1219  ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>,
1220  ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>,
1224  RemSIOpGLPattern, RemSIOpCLPattern,
1225  BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
1226  BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
1227  XOrIOpLogicalPattern, XOrIOpBooleanPattern,
1228  ElementwiseArithOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
1237  ExtUIPattern, ExtUII1Pattern,
1238  ExtSIPattern, ExtSII1Pattern,
1239  TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
1240  TruncIPattern, TruncII1Pattern,
1241  TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
1242  TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern,
1243  TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
1244  TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
1245  TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
1246  TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
1247  TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
1248  TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
1249  CmpIOpBooleanPattern, CmpIOpPattern,
1250  CmpFOpNanNonePattern, CmpFOpPattern,
1251  AddUIExtendedOpPattern,
1252  MulIExtendedOpPattern<arith::MulSIExtendedOp, spirv::SMulExtendedOp>,
1253  MulIExtendedOpPattern<arith::MulUIExtendedOp, spirv::UMulExtendedOp>,
1254  SelectOpPattern,
1255
// Floating-point min/max lowered to the spirv::GL* ops.
1256  MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::GLFMaxOp>,
1257  MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::GLFMinOp>,
1258  MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::GLFMaxOp>,
1259  MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::GLFMinOp>,
1264
// Floating-point min/max lowered to the spirv::CL* ops.
1265  MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::CLFMaxOp>,
1266  MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::CLFMinOp>,
1267  MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::CLFMaxOp>,
1268  MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::CLFMinOp>,
1273  >(typeConverter, patterns.getContext());
1274  // clang-format on
1275
1276  // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel
1277  // capability is available.
// CmpFOpNanNonePattern above handles the same ORD/UNO predicates and is
// registered without an explicit benefit.
1278  patterns.add<CmpFOpNanKernelPattern>(typeConverter, patterns.getContext(),
1279  /*benefit=*/2);
1280 }
1281 
1282 //===----------------------------------------------------------------------===//
1283 // Pass Definition
1284 //===----------------------------------------------------------------------===//
1285 
1286 namespace {
// Pass that converts arith dialect ops to SPIR-V using a partial dialect
// conversion over the operation the pass runs on.
// NOTE(review): the embedded listing numbers skip 1291 and 1295, so the
// declarations of `targetAttr` and `options` used below appear to be missing
// from this listing. Confirm against the upstream file.
1287 struct ConvertArithToSPIRVPass
1288  : public impl::ConvertArithToSPIRVBase<ConvertArithToSPIRVPass> {
// Builds the SPIR-V conversion target and type converter from `targetAttr`,
// registers the arith patterns, and signals pass failure if the partial
// conversion fails.
1289  void runOnOperation() override {
1290  Operation *op = getOperation();
1292  std::unique_ptr<SPIRVConversionTarget> target =
1293  SPIRVConversionTarget::get(targetAttr);
1294 
// Forward the pass option of the same name into the type converter options.
1296  options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
1297  SPIRVTypeConverter typeConverter(targetAttr, options);
1298 
1299  // Use UnrealizedConversionCast as the bridge so that we don't need to pull
1300  // in patterns for other dialects.
1301  target->addLegalOp<UnrealizedConversionCastOp>();
1302 
1303  // Fail hard when there are any remaining 'arith' ops.
1304  target->addIllegalDialect<arith::ArithDialect>();
1305 
1306  RewritePatternSet patterns(&getContext());
1307  arith::populateArithToSPIRVPatterns(typeConverter, patterns);
1308 
1309  if (failed(applyPartialConversion(op, *target, std::move(patterns))))
1310  signalPassFailure();
1311  }
1312 };
1313 } // namespace
1314 
1315 std::unique_ptr<OperationPass<>> mlir::arith::createConvertArithToSPIRVPass() {
1316  return std::make_unique<ConvertArithToSPIRVPass>();
1317 }
static bool hasSameBitwidth(Type a, Type b)
Returns true if scalar/vector types a and b have the same bitwidth.
static Value getScalarOrVectorConstInt(Type type, uint64_t value, OpBuilder &builder, Location loc)
Creates a scalar/vector integer constant.
static LogicalResult getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op, Type srcType)
Returns a source type conversion failure for srcType and operation op.
static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder)
Converts the given srcAttr into a boolean attribute if it holds an integral value.
static bool isBoolScalarOrVector(Type type)
Returns true if the given type is a boolean scalar or vector type.
#define DISPATCH(cmpPredicate, spirvOp)
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static std::string getDecorationString(spirv::Decoration decor)
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
UnitAttr getUnitAttr()
Definition: Builders.cpp:114
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:238
IntegerType getI32Type()
Definition: Builders.cpp:83
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:116
FloatAttr getF32FloatAttr(float value)
Definition: Builders.cpp:253
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
This class helps build Operations.
Definition: Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:56
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:123
bool isF32() const
Definition: Types.cpp:51
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:58
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:125
Type front()
Return first type in the range.
Definition: TypeRange.h:148
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
An attribute that specifies the target version, allowed extensions and capabilities,...
void populateArithToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
std::unique_ptr< OperationPass<> > createConvertArithToSPIRVPass()
Fraction abs(const Fraction &f)
Definition: Fraction.h:104
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.
Definition: Pattern.h:23