MLIR  19.0.0git
ArithToSPIRV.cpp
Go to the documentation of this file.
1 //===- ArithToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
11 #include "../SPIRVCommon/Pattern.h"
19 #include "mlir/IR/BuiltinTypes.h"
21 #include "llvm/ADT/APInt.h"
22 #include "llvm/ADT/ArrayRef.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/MathExtras.h"
26 #include <cassert>
27 #include <memory>
28 
29 namespace mlir {
30 #define GEN_PASS_DEF_CONVERTARITHTOSPIRV
31 #include "mlir/Conversion/Passes.h.inc"
32 } // namespace mlir
33 
34 #define DEBUG_TYPE "arith-to-spirv-pattern"
35 
36 using namespace mlir;
37 
38 //===----------------------------------------------------------------------===//
39 // Conversion Helpers
40 //===----------------------------------------------------------------------===//
41 
42 /// Converts the given `srcAttr` into a boolean attribute if it holds an
43 /// integral value. Returns null attribute if conversion fails.
44 static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) {
45  if (auto boolAttr = dyn_cast<BoolAttr>(srcAttr))
46  return boolAttr;
47  if (auto intAttr = dyn_cast<IntegerAttr>(srcAttr))
48  return builder.getBoolAttr(intAttr.getValue().getBoolValue());
49  return {};
50 }
51 
52 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
53 /// Returns null attribute if conversion fails.
54 static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType,
55  Builder builder) {
56  // If the source number uses less active bits than the target bitwidth, then
57  // it should be safe to convert.
58  if (srcAttr.getValue().isIntN(dstType.getWidth()))
59  return builder.getIntegerAttr(dstType, srcAttr.getInt());
60 
61  // XXX: Try again by interpreting the source number as a signed value.
62  // Although integers in the standard dialect are signless, they can represent
63  // a signed number. It's the operation decides how to interpret. This is
64  // dangerous, but it seems there is no good way of handling this if we still
65  // want to change the bitwidth. Emit a message at least.
66  if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
67  auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt());
68  LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '"
69  << dstAttr << "' for type '" << dstType << "'\n");
70  return dstAttr;
71  }
72 
73  LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr
74  << "' illegal: cannot fit into target type '"
75  << dstType << "'\n");
76  return {};
77 }
78 
79 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
80 /// Returns null attribute if `dstType` is not 32-bit or conversion fails.
81 static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
82  Builder builder) {
83  // Only support converting to float for now.
84  if (!dstType.isF32())
85  return FloatAttr();
86 
87  // Try to convert the source floating-point number to single precision.
88  APFloat dstVal = srcAttr.getValue();
89  bool losesInfo = false;
90  APFloat::opStatus status =
91  dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
92  if (status != APFloat::opOK || losesInfo) {
93  LLVM_DEBUG(llvm::dbgs()
94  << srcAttr << " illegal: cannot fit into converted type '"
95  << dstType << "'\n");
96  return FloatAttr();
97  }
98 
99  return builder.getF32FloatAttr(dstVal.convertToFloat());
100 }
101 
102 /// Returns true if the given `type` is a boolean scalar or vector type.
103 static bool isBoolScalarOrVector(Type type) {
104  assert(type && "Not a valid type");
105  if (type.isInteger(1))
106  return true;
107 
108  if (auto vecType = dyn_cast<VectorType>(type))
109  return vecType.getElementType().isInteger(1);
110 
111  return false;
112 }
113 
114 /// Creates a scalar/vector integer constant.
115 static Value getScalarOrVectorConstInt(Type type, uint64_t value,
116  OpBuilder &builder, Location loc) {
117  if (auto vectorType = dyn_cast<VectorType>(type)) {
118  Attribute element = IntegerAttr::get(vectorType.getElementType(), value);
119  auto attr = SplatElementsAttr::get(vectorType, element);
120  return builder.create<spirv::ConstantOp>(loc, vectorType, attr);
121  }
122 
123  if (auto intType = dyn_cast<IntegerType>(type))
124  return builder.create<spirv::ConstantOp>(
125  loc, type, builder.getIntegerAttr(type, value));
126 
127  return nullptr;
128 }
129 
130 /// Returns true if scalar/vector type `a` and `b` have the same number of
131 /// bitwidth.
132 static bool hasSameBitwidth(Type a, Type b) {
133  auto getNumBitwidth = [](Type type) {
134  unsigned bw = 0;
135  if (type.isIntOrFloat())
136  bw = type.getIntOrFloatBitWidth();
137  else if (auto vecType = dyn_cast<VectorType>(type))
138  bw = vecType.getElementTypeBitWidth() * vecType.getNumElements();
139  return bw;
140  };
141  unsigned aBW = getNumBitwidth(a);
142  unsigned bBW = getNumBitwidth(b);
143  return aBW != 0 && bBW != 0 && aBW == bBW;
144 }
145 
146 /// Returns a source type conversion failure for `srcType` and operation `op`.
147 static LogicalResult
149  Type srcType) {
150  return rewriter.notifyMatchFailure(
151  op->getLoc(),
152  llvm::formatv("failed to convert source type '{0}'", srcType));
153 }
154 
155 /// Returns a source type conversion failure for the result type of `op`.
156 static LogicalResult
158  assert(op->getNumResults() == 1);
159  return getTypeConversionFailure(rewriter, op, op->getResultTypes().front());
160 }
161 
162 // TODO: Move to some common place?
163 static std::string getDecorationString(spirv::Decoration decor) {
164  return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decor));
165 }
166 
167 namespace {
168 
169 /// Converts elementwise unary, binary and ternary arith operations to SPIR-V
170 /// operations. Op can potentially support overflow flags.
171 template <typename Op, typename SPIRVOp>
172 struct ElementwiseArithOpPattern final : OpConversionPattern<Op> {
174 
175  LogicalResult
176  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
177  ConversionPatternRewriter &rewriter) const override {
178  assert(adaptor.getOperands().size() <= 3);
179  auto converter = this->template getTypeConverter<SPIRVTypeConverter>();
180  Type dstType = converter->convertType(op.getType());
181  if (!dstType) {
182  return rewriter.notifyMatchFailure(
183  op->getLoc(),
184  llvm::formatv("failed to convert type {0} for SPIR-V", op.getType()));
185  }
186 
187  if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
188  !getElementTypeOrSelf(op.getType()).isIndex() &&
189  dstType != op.getType()) {
190  return op.emitError("bitwidth emulation is not implemented yet on "
191  "unsigned op pattern version");
192  }
193 
194  auto overflowFlags = arith::IntegerOverflowFlags::none;
195  if (auto overflowIface =
196  dyn_cast<arith::ArithIntegerOverflowFlagsInterface>(*op)) {
197  if (converter->getTargetEnv().allows(
198  spirv::Extension::SPV_KHR_no_integer_wrap_decoration))
199  overflowFlags = overflowIface.getOverflowAttr().getValue();
200  }
201 
202  auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
203  op, dstType, adaptor.getOperands());
204 
205  if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nsw))
206  newOp->setAttr(getDecorationString(spirv::Decoration::NoSignedWrap),
207  rewriter.getUnitAttr());
208 
209  if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nuw))
210  newOp->setAttr(getDecorationString(spirv::Decoration::NoUnsignedWrap),
211  rewriter.getUnitAttr());
212 
213  return success();
214  }
215 };
216 
217 //===----------------------------------------------------------------------===//
218 // ConstantOp
219 //===----------------------------------------------------------------------===//
220 
221 /// Converts composite arith.constant operation to spirv.Constant.
222 struct ConstantCompositeOpPattern final
225 
226  LogicalResult
227  matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
228  ConversionPatternRewriter &rewriter) const override {
229  auto srcType = dyn_cast<ShapedType>(constOp.getType());
230  if (!srcType || srcType.getNumElements() == 1)
231  return failure();
232 
233  // arith.constant should only have vector or tensor types. This is a MLIR
234  // wide problem at the moment.
235  if (!isa<VectorType, RankedTensorType>(srcType))
236  return rewriter.notifyMatchFailure(constOp, "unsupported ShapedType");
237 
238  Type dstType = getTypeConverter()->convertType(srcType);
239  if (!dstType)
240  return failure();
241 
242  // Import the resource into the IR to make use of the special handling of
243  // element types later on.
244  mlir::DenseElementsAttr dstElementsAttr;
245  if (auto denseElementsAttr =
246  dyn_cast<DenseElementsAttr>(constOp.getValue())) {
247  dstElementsAttr = denseElementsAttr;
248  } else if (auto resourceAttr =
249  dyn_cast<DenseResourceElementsAttr>(constOp.getValue())) {
250 
251  AsmResourceBlob *blob = resourceAttr.getRawHandle().getBlob();
252  if (!blob)
253  return constOp->emitError("could not find resource blob");
254 
255  ArrayRef<char> ptr = blob->getData();
256 
257  // Check that the buffer meets the requirements to get converted to a
258  // DenseElementsAttr
259  bool detectedSplat = false;
260  if (!DenseElementsAttr::isValidRawBuffer(srcType, ptr, detectedSplat))
261  return constOp->emitError("resource is not a valid buffer");
262 
263  dstElementsAttr =
264  DenseElementsAttr::getFromRawBuffer(resourceAttr.getType(), ptr);
265  } else {
266  return constOp->emitError("unsupported elements attribute");
267  }
268 
269  ShapedType dstAttrType = dstElementsAttr.getType();
270 
271  // If the composite type has more than one dimensions, perform
272  // linearization.
273  if (srcType.getRank() > 1) {
274  if (isa<RankedTensorType>(srcType)) {
275  dstAttrType = RankedTensorType::get(srcType.getNumElements(),
276  srcType.getElementType());
277  dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
278  } else {
279  // TODO: add support for large vectors.
280  return failure();
281  }
282  }
283 
284  Type srcElemType = srcType.getElementType();
285  Type dstElemType;
286  // Tensor types are converted to SPIR-V array types; vector types are
287  // converted to SPIR-V vector/array types.
288  if (auto arrayType = dyn_cast<spirv::ArrayType>(dstType))
289  dstElemType = arrayType.getElementType();
290  else
291  dstElemType = cast<VectorType>(dstType).getElementType();
292 
293  // If the source and destination element types are different, perform
294  // attribute conversion.
295  if (srcElemType != dstElemType) {
296  SmallVector<Attribute, 8> elements;
297  if (isa<FloatType>(srcElemType)) {
298  for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
299  FloatAttr dstAttr =
300  convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), rewriter);
301  if (!dstAttr)
302  return failure();
303  elements.push_back(dstAttr);
304  }
305  } else if (srcElemType.isInteger(1)) {
306  return failure();
307  } else {
308  for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) {
309  IntegerAttr dstAttr = convertIntegerAttr(
310  srcAttr, cast<IntegerType>(dstElemType), rewriter);
311  if (!dstAttr)
312  return failure();
313  elements.push_back(dstAttr);
314  }
315  }
316 
317  // Unfortunately, we cannot use dialect-specific types for element
318  // attributes; element attributes only works with builtin types. So we
319  // need to prepare another converted builtin types for the destination
320  // elements attribute.
321  if (isa<RankedTensorType>(dstAttrType))
322  dstAttrType =
323  RankedTensorType::get(dstAttrType.getShape(), dstElemType);
324  else
325  dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
326 
327  dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
328  }
329 
330  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
331  dstElementsAttr);
332  return success();
333  }
334 };
335 
336 /// Converts scalar arith.constant operation to spirv.Constant.
337 struct ConstantScalarOpPattern final
338  : public OpConversionPattern<arith::ConstantOp> {
340 
341  LogicalResult
342  matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
343  ConversionPatternRewriter &rewriter) const override {
344  Type srcType = constOp.getType();
345  if (auto shapedType = dyn_cast<ShapedType>(srcType)) {
346  if (shapedType.getNumElements() != 1)
347  return failure();
348  srcType = shapedType.getElementType();
349  }
350  if (!srcType.isIntOrIndexOrFloat())
351  return failure();
352 
353  Attribute cstAttr = constOp.getValue();
354  if (auto elementsAttr = dyn_cast<DenseElementsAttr>(cstAttr))
355  cstAttr = elementsAttr.getSplatValue<Attribute>();
356 
357  Type dstType = getTypeConverter()->convertType(srcType);
358  if (!dstType)
359  return failure();
360 
361  // Floating-point types.
362  if (isa<FloatType>(srcType)) {
363  auto srcAttr = cast<FloatAttr>(cstAttr);
364  auto dstAttr = srcAttr;
365 
366  // Floating-point types not supported in the target environment are all
367  // converted to float type.
368  if (srcType != dstType) {
369  dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter);
370  if (!dstAttr)
371  return failure();
372  }
373 
374  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
375  return success();
376  }
377 
378  // Bool type.
379  if (srcType.isInteger(1)) {
380  // arith.constant can use 0/1 instead of true/false for i1 values. We need
381  // to handle that here.
382  auto dstAttr = convertBoolAttr(cstAttr, rewriter);
383  if (!dstAttr)
384  return failure();
385  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
386  return success();
387  }
388 
389  // IndexType or IntegerType. Index values are converted to 32-bit integer
390  // values when converting to SPIR-V.
391  auto srcAttr = cast<IntegerAttr>(cstAttr);
392  IntegerAttr dstAttr =
393  convertIntegerAttr(srcAttr, cast<IntegerType>(dstType), rewriter);
394  if (!dstAttr)
395  return failure();
396  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
397  return success();
398  }
399 };
400 
401 //===----------------------------------------------------------------------===//
402 // RemSIOp
403 //===----------------------------------------------------------------------===//
404 
405 /// Returns signed remainder for `lhs` and `rhs` and lets the result follow
406 /// the sign of `signOperand`.
407 ///
408 /// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment
409 /// spec, "for the OpSRem and OpSMod instructions, if either operand is negative
410 /// the result is undefined." So we cannot directly use spirv.SRem/spirv.SMod
411 /// if either operand can be negative. Emulate it via spirv.UMod.
412 template <typename SignedAbsOp>
413 static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
414  Value signOperand, OpBuilder &builder) {
415  assert(lhs.getType() == rhs.getType());
416  assert(lhs == signOperand || rhs == signOperand);
417 
418  Type type = lhs.getType();
419 
420  // Calculate the remainder with spirv.UMod.
421  Value lhsAbs = builder.create<SignedAbsOp>(loc, type, lhs);
422  Value rhsAbs = builder.create<SignedAbsOp>(loc, type, rhs);
423  Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs);
424 
425  // Fix the sign.
426  Value isPositive;
427  if (lhs == signOperand)
428  isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs);
429  else
430  isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs);
431  Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs);
432  return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate);
433 }
434 
435 /// Converts arith.remsi to GLSL SPIR-V ops.
436 ///
437 /// This cannot be merged into the template unary/binary pattern due to Vulkan
438 /// restrictions over spirv.SRem and spirv.SMod.
439 struct RemSIOpGLPattern final : public OpConversionPattern<arith::RemSIOp> {
441 
442  LogicalResult
443  matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
444  ConversionPatternRewriter &rewriter) const override {
445  Value result = emulateSignedRemainder<spirv::CLSAbsOp>(
446  op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
447  adaptor.getOperands()[0], rewriter);
448  rewriter.replaceOp(op, result);
449 
450  return success();
451  }
452 };
453 
454 /// Converts arith.remsi to OpenCL SPIR-V ops.
455 struct RemSIOpCLPattern final : public OpConversionPattern<arith::RemSIOp> {
457 
458  LogicalResult
459  matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
460  ConversionPatternRewriter &rewriter) const override {
461  Value result = emulateSignedRemainder<spirv::GLSAbsOp>(
462  op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
463  adaptor.getOperands()[0], rewriter);
464  rewriter.replaceOp(op, result);
465 
466  return success();
467  }
468 };
469 
470 //===----------------------------------------------------------------------===//
471 // BitwiseOp
472 //===----------------------------------------------------------------------===//
473 
474 /// Converts bitwise operations to SPIR-V operations. This is a special pattern
475 /// other than the BinaryOpPatternPattern because if the operands are boolean
476 /// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For
477 /// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`.
478 template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
479 struct BitwiseOpPattern final : public OpConversionPattern<Op> {
481 
482  LogicalResult
483  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
484  ConversionPatternRewriter &rewriter) const override {
485  assert(adaptor.getOperands().size() == 2);
486  Type dstType = this->getTypeConverter()->convertType(op.getType());
487  if (!dstType)
488  return getTypeConversionFailure(rewriter, op);
489 
490  if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) {
491  rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(
492  op, dstType, adaptor.getOperands());
493  } else {
494  rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(
495  op, dstType, adaptor.getOperands());
496  }
497  return success();
498  }
499 };
500 
501 //===----------------------------------------------------------------------===//
502 // XOrIOp
503 //===----------------------------------------------------------------------===//
504 
505 /// Converts arith.xori to SPIR-V operations.
506 struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> {
508 
509  LogicalResult
510  matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
511  ConversionPatternRewriter &rewriter) const override {
512  assert(adaptor.getOperands().size() == 2);
513 
514  if (isBoolScalarOrVector(adaptor.getOperands().front().getType()))
515  return failure();
516 
517  Type dstType = getTypeConverter()->convertType(op.getType());
518  if (!dstType)
519  return getTypeConversionFailure(rewriter, op);
520 
521  rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
522  adaptor.getOperands());
523 
524  return success();
525  }
526 };
527 
528 /// Converts arith.xori to SPIR-V operations if the type of source is i1 or
529 /// vector of i1.
530 struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
532 
533  LogicalResult
534  matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
535  ConversionPatternRewriter &rewriter) const override {
536  assert(adaptor.getOperands().size() == 2);
537 
538  if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
539  return failure();
540 
541  Type dstType = getTypeConverter()->convertType(op.getType());
542  if (!dstType)
543  return getTypeConversionFailure(rewriter, op);
544 
545  rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
546  op, dstType, adaptor.getOperands());
547  return success();
548  }
549 };
550 
551 //===----------------------------------------------------------------------===//
552 // UIToFPOp
553 //===----------------------------------------------------------------------===//
554 
555 /// Converts arith.uitofp to spirv.Select if the type of source is i1 or vector
556 /// of i1.
557 struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
559 
560  LogicalResult
561  matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
562  ConversionPatternRewriter &rewriter) const override {
563  Type srcType = adaptor.getOperands().front().getType();
564  if (!isBoolScalarOrVector(srcType))
565  return failure();
566 
567  Type dstType = getTypeConverter()->convertType(op.getType());
568  if (!dstType)
569  return getTypeConversionFailure(rewriter, op);
570 
571  Location loc = op.getLoc();
572  Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
573  Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
574  rewriter.replaceOpWithNewOp<spirv::SelectOp>(
575  op, dstType, adaptor.getOperands().front(), one, zero);
576  return success();
577  }
578 };
579 
580 //===----------------------------------------------------------------------===//
581 // ExtSIOp
582 //===----------------------------------------------------------------------===//
583 
584 /// Converts arith.extsi to spirv.Select if the type of source is i1 or vector
585 /// of i1.
586 struct ExtSII1Pattern final : public OpConversionPattern<arith::ExtSIOp> {
588 
589  LogicalResult
590  matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
591  ConversionPatternRewriter &rewriter) const override {
592  Value operand = adaptor.getIn();
593  if (!isBoolScalarOrVector(operand.getType()))
594  return failure();
595 
596  Location loc = op.getLoc();
597  Type dstType = getTypeConverter()->convertType(op.getType());
598  if (!dstType)
599  return getTypeConversionFailure(rewriter, op);
600 
601  Value allOnes;
602  if (auto intTy = dyn_cast<IntegerType>(dstType)) {
603  unsigned componentBitwidth = intTy.getWidth();
604  allOnes = rewriter.create<spirv::ConstantOp>(
605  loc, intTy,
606  rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth)));
607  } else if (auto vectorTy = dyn_cast<VectorType>(dstType)) {
608  unsigned componentBitwidth = vectorTy.getElementTypeBitWidth();
609  allOnes = rewriter.create<spirv::ConstantOp>(
610  loc, vectorTy,
611  SplatElementsAttr::get(vectorTy,
612  APInt::getAllOnes(componentBitwidth)));
613  } else {
614  return rewriter.notifyMatchFailure(
615  loc, llvm::formatv("unhandled type: {0}", dstType));
616  }
617 
618  Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
619  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, operand, allOnes,
620  zero);
621  return success();
622  }
623 };
624 
625 /// Converts arith.extsi to spirv.Select if the type of source is neither i1 nor
626 /// vector of i1.
627 struct ExtSIPattern final : public OpConversionPattern<arith::ExtSIOp> {
629 
630  LogicalResult
631  matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
632  ConversionPatternRewriter &rewriter) const override {
633  Type srcType = adaptor.getIn().getType();
634  if (isBoolScalarOrVector(srcType))
635  return failure();
636 
637  Type dstType = getTypeConverter()->convertType(op.getType());
638  if (!dstType)
639  return getTypeConversionFailure(rewriter, op);
640 
641  if (dstType == srcType) {
642  // We can have the same source and destination type due to type emulation.
643  // Perform bit shifting to make sure we have the proper leading set bits.
644 
645  unsigned srcBW =
646  getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth();
647  unsigned dstBW =
649  assert(srcBW < dstBW);
650  Value shiftSize = getScalarOrVectorConstInt(dstType, dstBW - srcBW,
651  rewriter, op.getLoc());
652 
653  // First shift left to sequeeze out all leading bits beyond the original
654  // bitwidth. Here we need to use the original source and result type's
655  // bitwidth.
656  auto shiftLOp = rewriter.create<spirv::ShiftLeftLogicalOp>(
657  op.getLoc(), dstType, adaptor.getIn(), shiftSize);
658 
659  // Then we perform arithmetic right shift to make sure we have the right
660  // sign bits for negative values.
661  rewriter.replaceOpWithNewOp<spirv::ShiftRightArithmeticOp>(
662  op, dstType, shiftLOp, shiftSize);
663  } else {
664  rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
665  adaptor.getOperands());
666  }
667 
668  return success();
669  }
670 };
671 
672 //===----------------------------------------------------------------------===//
673 // ExtUIOp
674 //===----------------------------------------------------------------------===//
675 
676 /// Converts arith.extui to spirv.Select if the type of source is i1 or vector
677 /// of i1.
678 struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> {
680 
681  LogicalResult
682  matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
683  ConversionPatternRewriter &rewriter) const override {
684  Type srcType = adaptor.getOperands().front().getType();
685  if (!isBoolScalarOrVector(srcType))
686  return failure();
687 
688  Type dstType = getTypeConverter()->convertType(op.getType());
689  if (!dstType)
690  return getTypeConversionFailure(rewriter, op);
691 
692  Location loc = op.getLoc();
693  Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
694  Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
695  rewriter.replaceOpWithNewOp<spirv::SelectOp>(
696  op, dstType, adaptor.getOperands().front(), one, zero);
697  return success();
698  }
699 };
700 
701 /// Converts arith.extui for cases where the type of source is neither i1 nor
702 /// vector of i1.
703 struct ExtUIPattern final : public OpConversionPattern<arith::ExtUIOp> {
705 
706  LogicalResult
707  matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
708  ConversionPatternRewriter &rewriter) const override {
709  Type srcType = adaptor.getIn().getType();
710  if (isBoolScalarOrVector(srcType))
711  return failure();
712 
713  Type dstType = getTypeConverter()->convertType(op.getType());
714  if (!dstType)
715  return getTypeConversionFailure(rewriter, op);
716 
717  if (dstType == srcType) {
718  // We can have the same source and destination type due to type emulation.
719  // Perform bit masking to make sure we don't pollute downstream consumers
720  // with unwanted bits. Here we need to use the original source type's
721  // bitwidth.
722  unsigned bitwidth =
723  getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth();
725  dstType, llvm::maskTrailingOnes<uint64_t>(bitwidth), rewriter,
726  op.getLoc());
727  rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
728  adaptor.getIn(), mask);
729  } else {
730  rewriter.replaceOpWithNewOp<spirv::UConvertOp>(op, dstType,
731  adaptor.getOperands());
732  }
733  return success();
734  }
735 };
736 
737 //===----------------------------------------------------------------------===//
738 // TruncIOp
739 //===----------------------------------------------------------------------===//
740 
741 /// Converts arith.trunci to spirv.Select if the type of result is i1 or vector
742 /// of i1.
743 struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
745 
746  LogicalResult
747  matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
748  ConversionPatternRewriter &rewriter) const override {
749  Type dstType = getTypeConverter()->convertType(op.getType());
750  if (!dstType)
751  return getTypeConversionFailure(rewriter, op);
752 
753  if (!isBoolScalarOrVector(dstType))
754  return failure();
755 
756  Location loc = op.getLoc();
757  auto srcType = adaptor.getOperands().front().getType();
758  // Check if (x & 1) == 1.
759  Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
760  Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>(
761  loc, srcType, adaptor.getOperands()[0], mask);
762  Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask);
763 
764  Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
765  Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
766  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero);
767  return success();
768  }
769 };
770 
771 /// Converts arith.trunci for cases where the type of result is neither i1
772 /// nor vector of i1.
773 struct TruncIPattern final : public OpConversionPattern<arith::TruncIOp> {
775 
776  LogicalResult
777  matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
778  ConversionPatternRewriter &rewriter) const override {
779  Type srcType = adaptor.getIn().getType();
780  Type dstType = getTypeConverter()->convertType(op.getType());
781  if (!dstType)
782  return getTypeConversionFailure(rewriter, op);
783 
784  if (isBoolScalarOrVector(dstType))
785  return failure();
786 
787  if (dstType == srcType) {
788  // We can have the same source and destination type due to type emulation.
789  // Perform bit masking to make sure we don't pollute downstream consumers
790  // with unwanted bits. Here we need to use the original result type's
791  // bitwidth.
792  unsigned bw = getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth();
794  dstType, llvm::maskTrailingOnes<uint64_t>(bw), rewriter, op.getLoc());
795  rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
796  adaptor.getIn(), mask);
797  } else {
798  // Given this is truncation, either SConvertOp or UConvertOp works.
799  rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
800  adaptor.getOperands());
801  }
802  return success();
803  }
804 };
805 
806 //===----------------------------------------------------------------------===//
807 // TypeCastingOp
808 //===----------------------------------------------------------------------===//
809 
810 /// Converts type-casting standard operations to SPIR-V operations.
811 template <typename Op, typename SPIRVOp>
812 struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
814 
815  LogicalResult
816  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
817  ConversionPatternRewriter &rewriter) const override {
818  assert(adaptor.getOperands().size() == 1);
819  Type srcType = adaptor.getOperands().front().getType();
820  Type dstType = this->getTypeConverter()->convertType(op.getType());
821  if (!dstType)
822  return getTypeConversionFailure(rewriter, op);
823 
824  if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
825  return failure();
826 
827  if (dstType == srcType) {
828  // Due to type conversion, we are seeing the same source and target type.
829  // Then we can just erase this operation by forwarding its operand.
830  rewriter.replaceOp(op, adaptor.getOperands().front());
831  } else {
832  rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType,
833  adaptor.getOperands());
834  if (auto roundingModeOp =
835  dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
836  if (arith::RoundingModeAttr roundingMode =
837  roundingModeOp.getRoundingModeAttr()) {
838  // TODO: Perform rounding mode attribute conversion and attach to new
839  // operation when defined in the dialect.
840  return failure();
841  }
842  }
843  }
844  return success();
845  }
846 };
847 
848 //===----------------------------------------------------------------------===//
849 // CmpIOp
850 //===----------------------------------------------------------------------===//
851 
852 /// Converts integer compare operation on i1 type operands to SPIR-V ops.
853 class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
854 public:
856 
857  LogicalResult
858  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
859  ConversionPatternRewriter &rewriter) const override {
860  Type srcType = op.getLhs().getType();
861  if (!isBoolScalarOrVector(srcType))
862  return failure();
863  Type dstType = getTypeConverter()->convertType(srcType);
864  if (!dstType)
865  return getTypeConversionFailure(rewriter, op, srcType);
866 
867  switch (op.getPredicate()) {
868  case arith::CmpIPredicate::eq: {
869  rewriter.replaceOpWithNewOp<spirv::LogicalEqualOp>(op, adaptor.getLhs(),
870  adaptor.getRhs());
871  return success();
872  }
873  case arith::CmpIPredicate::ne: {
874  rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
875  op, adaptor.getLhs(), adaptor.getRhs());
876  return success();
877  }
878  case arith::CmpIPredicate::uge:
879  case arith::CmpIPredicate::ugt:
880  case arith::CmpIPredicate::ule:
881  case arith::CmpIPredicate::ult: {
882  // There are no direct corresponding instructions in SPIR-V for such
883  // cases. Extend them to 32-bit and do comparision then.
884  Type type = rewriter.getI32Type();
885  if (auto vectorType = dyn_cast<VectorType>(dstType))
886  type = VectorType::get(vectorType.getShape(), type);
887  Value extLhs =
888  rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
889  Value extRhs =
890  rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getRhs());
891 
892  rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs,
893  extRhs);
894  return success();
895  }
896  default:
897  break;
898  }
899  return failure();
900  }
901 };
902 
903 /// Converts integer compare operation to SPIR-V ops.
904 class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> {
905 public:
907 
908  LogicalResult
909  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
910  ConversionPatternRewriter &rewriter) const override {
911  Type srcType = op.getLhs().getType();
912  if (isBoolScalarOrVector(srcType))
913  return failure();
914  Type dstType = getTypeConverter()->convertType(srcType);
915  if (!dstType)
916  return getTypeConversionFailure(rewriter, op, srcType);
917 
918  switch (op.getPredicate()) {
919 #define DISPATCH(cmpPredicate, spirvOp) \
920  case cmpPredicate: \
921  if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \
922  !getElementTypeOrSelf(srcType).isIndex() && srcType != dstType && \
923  !hasSameBitwidth(srcType, dstType)) { \
924  return op.emitError( \
925  "bitwidth emulation is not implemented yet on unsigned op"); \
926  } \
927  rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
928  adaptor.getRhs()); \
929  return success();
930 
931  DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
932  DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
933  DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
934  DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
935  DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
936  DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
937  DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
938  DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
939  DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
940  DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
941 
942 #undef DISPATCH
943  }
944  return failure();
945  }
946 };
947 
948 //===----------------------------------------------------------------------===//
949 // CmpFOpPattern
950 //===----------------------------------------------------------------------===//
951 
952 /// Converts floating-point comparison operations to SPIR-V ops.
953 class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> {
954 public:
956 
957  LogicalResult
958  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
959  ConversionPatternRewriter &rewriter) const override {
960  switch (op.getPredicate()) {
961 #define DISPATCH(cmpPredicate, spirvOp) \
962  case cmpPredicate: \
963  rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
964  adaptor.getRhs()); \
965  return success();
966 
967  // Ordered.
968  DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
969  DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
970  DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
971  DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
972  DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
973  DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
974  // Unordered.
975  DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
976  DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
977  DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
978  DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
979  DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
980  DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
981 
982 #undef DISPATCH
983 
984  default:
985  break;
986  }
987  return failure();
988  }
989 };
990 
991 /// Converts floating point NaN check to SPIR-V ops. This pattern requires
992 /// Kernel capability.
993 class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> {
994 public:
996 
997  LogicalResult
998  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
999  ConversionPatternRewriter &rewriter) const override {
1000  if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1001  rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(),
1002  adaptor.getRhs());
1003  return success();
1004  }
1005 
1006  if (op.getPredicate() == arith::CmpFPredicate::UNO) {
1007  rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(),
1008  adaptor.getRhs());
1009  return success();
1010  }
1011 
1012  return failure();
1013  }
1014 };
1015 
1016 /// Converts floating point NaN check to SPIR-V ops. This pattern does not
1017 /// require additional capability.
1018 class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
1019 public:
1021 
1022  LogicalResult
1023  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1024  ConversionPatternRewriter &rewriter) const override {
1025  if (op.getPredicate() != arith::CmpFPredicate::ORD &&
1026  op.getPredicate() != arith::CmpFPredicate::UNO)
1027  return failure();
1028 
1029  Location loc = op.getLoc();
1030 
1031  Value replace;
1032  if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1033  if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1034  // Ordered comparsion checks if neither operand is NaN.
1035  replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter);
1036  } else {
1037  // Unordered comparsion checks if either operand is NaN.
1038  replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter);
1039  }
1040  } else {
1041  Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
1042  Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
1043 
1044  replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
1045  if (op.getPredicate() == arith::CmpFPredicate::ORD)
1046  replace = rewriter.create<spirv::LogicalNotOp>(loc, replace);
1047  }
1048 
1049  rewriter.replaceOp(op, replace);
1050  return success();
1051  }
1052 };
1053 
1054 //===----------------------------------------------------------------------===//
1055 // AddUIExtendedOp
1056 //===----------------------------------------------------------------------===//
1057 
1058 /// Converts arith.addui_extended to spirv.IAddCarry.
1059 class AddUIExtendedOpPattern final
1060  : public OpConversionPattern<arith::AddUIExtendedOp> {
1061 public:
1063  LogicalResult
1064  matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
1065  ConversionPatternRewriter &rewriter) const override {
1066  Type dstElemTy = adaptor.getLhs().getType();
1067  Location loc = op->getLoc();
1068  Value result = rewriter.create<spirv::IAddCarryOp>(loc, adaptor.getLhs(),
1069  adaptor.getRhs());
1070 
1071  Value sumResult = rewriter.create<spirv::CompositeExtractOp>(
1072  loc, result, llvm::ArrayRef(0));
1073  Value carryValue = rewriter.create<spirv::CompositeExtractOp>(
1074  loc, result, llvm::ArrayRef(1));
1075 
1076  // Convert the carry value to boolean.
1077  Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter);
1078  Value carryResult = rewriter.create<spirv::IEqualOp>(loc, carryValue, one);
1079 
1080  rewriter.replaceOp(op, {sumResult, carryResult});
1081  return success();
1082  }
1083 };
1084 
1085 //===----------------------------------------------------------------------===//
1086 // MulIExtendedOp
1087 //===----------------------------------------------------------------------===//
1088 
1089 /// Converts arith.mul*i_extended to spirv.*MulExtended.
1090 template <typename ArithMulOp, typename SPIRVMulOp>
1091 class MulIExtendedOpPattern final : public OpConversionPattern<ArithMulOp> {
1092 public:
1094  LogicalResult
1095  matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
1096  ConversionPatternRewriter &rewriter) const override {
1097  Location loc = op->getLoc();
1098  Value result =
1099  rewriter.create<SPIRVMulOp>(loc, adaptor.getLhs(), adaptor.getRhs());
1100 
1101  Value low = rewriter.create<spirv::CompositeExtractOp>(loc, result,
1102  llvm::ArrayRef(0));
1103  Value high = rewriter.create<spirv::CompositeExtractOp>(loc, result,
1104  llvm::ArrayRef(1));
1105 
1106  rewriter.replaceOp(op, {low, high});
1107  return success();
1108  }
1109 };
1110 
1111 //===----------------------------------------------------------------------===//
1112 // SelectOp
1113 //===----------------------------------------------------------------------===//
1114 
1115 /// Converts arith.select to spirv.Select.
1116 class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> {
1117 public:
1119  LogicalResult
1120  matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
1121  ConversionPatternRewriter &rewriter) const override {
1122  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, adaptor.getCondition(),
1123  adaptor.getTrueValue(),
1124  adaptor.getFalseValue());
1125  return success();
1126  }
1127 };
1128 
1129 //===----------------------------------------------------------------------===//
1130 // MinimumFOp, MaximumFOp
1131 //===----------------------------------------------------------------------===//
1132 
1133 /// Converts arith.maximumf/minimumf to spirv.GL.FMax/FMin or
1134 /// spirv.CL.fmax/fmin.
1135 template <typename Op, typename SPIRVOp>
1136 class MinimumMaximumFOpPattern final : public OpConversionPattern<Op> {
1137 public:
1139  LogicalResult
1140  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
1141  ConversionPatternRewriter &rewriter) const override {
1142  auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
1143  Type dstType = converter->convertType(op.getType());
1144  if (!dstType)
1145  return getTypeConversionFailure(rewriter, op);
1146 
1147  // arith.maximumf/minimumf:
1148  // "if one of the arguments is NaN, then the result is also NaN."
1149  // spirv.GL.FMax/FMin
1150  // "which operand is the result is undefined if one of the operands
1151  // is a NaN."
1152  // spirv.CL.fmax/fmin:
1153  // "If one argument is a NaN, Fmin returns the other argument."
1154 
1155  Location loc = op.getLoc();
1156  Value spirvOp =
1157  rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
1158 
1159  if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1160  rewriter.replaceOp(op, spirvOp);
1161  return success();
1162  }
1163 
1164  Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
1165  Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
1166 
1167  Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan,
1168  adaptor.getLhs(), spirvOp);
1169  Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan,
1170  adaptor.getRhs(), select1);
1171 
1172  rewriter.replaceOp(op, select2);
1173  return success();
1174  }
1175 };
1176 
1177 //===----------------------------------------------------------------------===//
1178 // MinNumFOp, MaxNumFOp
1179 //===----------------------------------------------------------------------===//
1180 
1181 /// Converts arith.maxnumf/minnumf to spirv.GL.FMax/FMin or
1182 /// spirv.CL.fmax/fmin.
1183 template <typename Op, typename SPIRVOp>
1184 class MinNumMaxNumFOpPattern final : public OpConversionPattern<Op> {
1185  template <typename TargetOp>
1186  constexpr bool shouldInsertNanGuards() const {
1187  return llvm::is_one_of<TargetOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value;
1188  }
1189 
1190 public:
1192  LogicalResult
1193  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
1194  ConversionPatternRewriter &rewriter) const override {
1195  auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
1196  Type dstType = converter->convertType(op.getType());
1197  if (!dstType)
1198  return getTypeConversionFailure(rewriter, op);
1199 
1200  // arith.maxnumf/minnumf:
1201  // "If one of the arguments is NaN, then the result is the other
1202  // argument."
1203  // spirv.GL.FMax/FMin
1204  // "which operand is the result is undefined if one of the operands
1205  // is a NaN."
1206  // spirv.CL.fmax/fmin:
1207  // "If one argument is a NaN, Fmin returns the other argument."
1208 
1209  Location loc = op.getLoc();
1210  Value spirvOp =
1211  rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
1212 
1213  if (!shouldInsertNanGuards<SPIRVOp>() ||
1214  bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1215  rewriter.replaceOp(op, spirvOp);
1216  return success();
1217  }
1218 
1219  Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
1220  Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
1221 
1222  Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan,
1223  adaptor.getRhs(), spirvOp);
1224  Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan,
1225  adaptor.getLhs(), select1);
1226 
1227  rewriter.replaceOp(op, select2);
1228  return success();
1229  }
1230 };
1231 
1232 } // namespace
1233 
1234 //===----------------------------------------------------------------------===//
1235 // Pattern Population
1236 //===----------------------------------------------------------------------===//
1237 
1239  SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
1240  // clang-format off
1241  patterns.add<
1242  ConstantCompositeOpPattern,
1243  ConstantScalarOpPattern,
1244  ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>,
1245  ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>,
1246  ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>,
1250  RemSIOpGLPattern, RemSIOpCLPattern,
1251  BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
1252  BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
1253  XOrIOpLogicalPattern, XOrIOpBooleanPattern,
1254  ElementwiseArithOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
1263  ExtUIPattern, ExtUII1Pattern,
1264  ExtSIPattern, ExtSII1Pattern,
1265  TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
1266  TruncIPattern, TruncII1Pattern,
1267  TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
1268  TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern,
1269  TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
1270  TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
1271  TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
1272  TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
1273  TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
1274  TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
1275  CmpIOpBooleanPattern, CmpIOpPattern,
1276  CmpFOpNanNonePattern, CmpFOpPattern,
1277  AddUIExtendedOpPattern,
1278  MulIExtendedOpPattern<arith::MulSIExtendedOp, spirv::SMulExtendedOp>,
1279  MulIExtendedOpPattern<arith::MulUIExtendedOp, spirv::UMulExtendedOp>,
1280  SelectOpPattern,
1281 
1282  MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::GLFMaxOp>,
1283  MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::GLFMinOp>,
1284  MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::GLFMaxOp>,
1285  MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::GLFMinOp>,
1290 
1291  MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::CLFMaxOp>,
1292  MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::CLFMinOp>,
1293  MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::CLFMaxOp>,
1294  MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::CLFMinOp>,
1299  >(typeConverter, patterns.getContext());
1300  // clang-format on
1301 
1302  // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel
1303  // capability is available.
1304  patterns.add<CmpFOpNanKernelPattern>(typeConverter, patterns.getContext(),
1305  /*benefit=*/2);
1306 }
1307 
1308 //===----------------------------------------------------------------------===//
1309 // Pass Definition
1310 //===----------------------------------------------------------------------===//
1311 
1312 namespace {
1313 struct ConvertArithToSPIRVPass
1314  : public impl::ConvertArithToSPIRVBase<ConvertArithToSPIRVPass> {
1315  void runOnOperation() override {
1316  Operation *op = getOperation();
1318  std::unique_ptr<SPIRVConversionTarget> target =
1319  SPIRVConversionTarget::get(targetAttr);
1320 
1322  options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
1323  SPIRVTypeConverter typeConverter(targetAttr, options);
1324 
1325  // Use UnrealizedConversionCast as the bridge so that we don't need to pull
1326  // in patterns for other dialects.
1327  target->addLegalOp<UnrealizedConversionCastOp>();
1328 
1329  // Fail hard when there are any remaining 'arith' ops.
1330  target->addIllegalDialect<arith::ArithDialect>();
1331 
1332  RewritePatternSet patterns(&getContext());
1333  arith::populateArithToSPIRVPatterns(typeConverter, patterns);
1334 
1335  if (failed(applyPartialConversion(op, *target, std::move(patterns))))
1336  signalPassFailure();
1337  }
1338 };
1339 } // namespace
1340 
1341 std::unique_ptr<OperationPass<>> mlir::arith::createConvertArithToSPIRVPass() {
1342  return std::make_unique<ConvertArithToSPIRVPass>();
1343 }
static bool hasSameBitwidth(Type a, Type b)
Returns true if scalar/vector type a and b have the same number of bitwidth.
static Value getScalarOrVectorConstInt(Type type, uint64_t value, OpBuilder &builder, Location loc)
Creates a scalar/vector integer constant.
static LogicalResult getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op, Type srcType)
Returns a source type conversion failure for srcType and operation op.
static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder)
Converts the given srcAttr into a boolean attribute if it holds an integral value.
static bool isBoolScalarOrVector(Type type)
Returns true if the given type is a boolean scalar or vector type.
#define DISPATCH(cmpPredicate, spirvOp)
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static std::string getDecorationString(spirv::Decoration decor)
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
The following classes enable support for parsing and printing resources within MLIR assembly formats.
Definition: AsmState.h:88
ArrayRef< char > getData() const
Return the raw underlying data of this blob.
Definition: AsmState.h:142
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
UnitAttr getUnitAttr()
Definition: Builders.cpp:114
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:238
IntegerType getI32Type()
Definition: Builders.cpp:83
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:116
FloatAttr getF32FloatAttr(float value)
Definition: Builders.cpp:253
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
static bool isValidRawBuffer(ShapedType type, ArrayRef< char > rawBuffer, bool &detectedSplat)
Returns true if the given buffer is a valid raw buffer for the given type.
DenseElementsAttr reshape(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but has been reshaped...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
This class helps build Operations.
Definition: Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:56
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:123
bool isF32() const
Definition: Types.cpp:51
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:58
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:125
Type front()
Return first type in the range.
Definition: TypeRange.h:148
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
An attribute that specifies the target version, allowed extensions and capabilities,...
void populateArithToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
std::unique_ptr< OperationPass<> > createConvertArithToSPIRVPass()
Fraction abs(const Fraction &f)
Definition: Fraction.h:106
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.
Definition: Pattern.h:23