MLIR  22.0.0git
ArithToSPIRV.cpp
Go to the documentation of this file.
1 //===- ArithToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
11 #include "../SPIRVCommon/Pattern.h"
19 #include "mlir/IR/BuiltinTypes.h"
21 #include "llvm/ADT/APInt.h"
22 #include "llvm/ADT/ArrayRef.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/MathExtras.h"
26 #include <cassert>
27 #include <memory>
28 
29 namespace mlir {
30 #define GEN_PASS_DEF_CONVERTARITHTOSPIRVPASS
31 #include "mlir/Conversion/Passes.h.inc"
32 } // namespace mlir
33 
34 #define DEBUG_TYPE "arith-to-spirv-pattern"
35 
36 using namespace mlir;
37 
38 //===----------------------------------------------------------------------===//
39 // Conversion Helpers
40 //===----------------------------------------------------------------------===//
41 
42 /// Converts the given `srcAttr` into a boolean attribute if it holds an
43 /// integral value. Returns null attribute if conversion fails.
44 static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) {
45  if (auto boolAttr = dyn_cast<BoolAttr>(srcAttr))
46  return boolAttr;
47  if (auto intAttr = dyn_cast<IntegerAttr>(srcAttr))
48  return builder.getBoolAttr(intAttr.getValue().getBoolValue());
49  return {};
50 }
51 
52 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
53 /// Returns null attribute if conversion fails.
54 static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType,
55  Builder builder) {
56  // If the source number uses less active bits than the target bitwidth, then
57  // it should be safe to convert.
58  if (srcAttr.getValue().isIntN(dstType.getWidth()))
59  return builder.getIntegerAttr(dstType, srcAttr.getInt());
60 
61  // XXX: Try again by interpreting the source number as a signed value.
62  // Although integers in the standard dialect are signless, they can represent
63  // a signed number. It's the operation decides how to interpret. This is
64  // dangerous, but it seems there is no good way of handling this if we still
65  // want to change the bitwidth. Emit a message at least.
66  if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
67  auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt());
68  LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '"
69  << dstAttr << "' for type '" << dstType << "'\n");
70  return dstAttr;
71  }
72 
73  LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr
74  << "' illegal: cannot fit into target type '"
75  << dstType << "'\n");
76  return {};
77 }
78 
79 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
80 /// Returns null attribute if `dstType` is not 32-bit or conversion fails.
81 static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
82  Builder builder) {
83  // Only support converting to float for now.
84  if (!dstType.isF32())
85  return FloatAttr();
86 
87  // Try to convert the source floating-point number to single precision.
88  APFloat dstVal = srcAttr.getValue();
89  bool losesInfo = false;
90  APFloat::opStatus status =
91  dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
92  if (status != APFloat::opOK || losesInfo) {
93  LLVM_DEBUG(llvm::dbgs()
94  << srcAttr << " illegal: cannot fit into converted type '"
95  << dstType << "'\n");
96  return FloatAttr();
97  }
98 
99  return builder.getF32FloatAttr(dstVal.convertToFloat());
100 }
101 
102 // Get in IntegerAttr from FloatAttr while preserving the bits.
103 // Useful for converting float constants to integer constants while preserving
104 // the bits.
105 static IntegerAttr
106 getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType,
107  ConversionPatternRewriter &rewriter) {
108  APFloat floatVal = floatAttr.getValue();
109  APInt intVal = floatVal.bitcastToAPInt();
110  return rewriter.getIntegerAttr(dstType, intVal);
111 }
112 
113 /// Returns true if the given `type` is a boolean scalar or vector type.
114 static bool isBoolScalarOrVector(Type type) {
115  assert(type && "Not a valid type");
116  if (type.isInteger(1))
117  return true;
118 
119  if (auto vecType = dyn_cast<VectorType>(type))
120  return vecType.getElementType().isInteger(1);
121 
122  return false;
123 }
124 
125 /// Creates a scalar/vector integer constant.
126 static Value getScalarOrVectorConstInt(Type type, uint64_t value,
127  OpBuilder &builder, Location loc) {
128  if (auto vectorType = dyn_cast<VectorType>(type)) {
129  Attribute element = IntegerAttr::get(vectorType.getElementType(), value);
130  auto attr = SplatElementsAttr::get(vectorType, element);
131  return spirv::ConstantOp::create(builder, loc, vectorType, attr);
132  }
133 
134  if (auto intType = dyn_cast<IntegerType>(type))
135  return spirv::ConstantOp::create(builder, loc, type,
136  builder.getIntegerAttr(type, value));
137 
138  return nullptr;
139 }
140 
141 /// Returns true if scalar/vector type `a` and `b` have the same number of
142 /// bitwidth.
143 static bool hasSameBitwidth(Type a, Type b) {
144  auto getNumBitwidth = [](Type type) {
145  unsigned bw = 0;
146  if (type.isIntOrFloat())
147  bw = type.getIntOrFloatBitWidth();
148  else if (auto vecType = dyn_cast<VectorType>(type))
149  bw = vecType.getElementTypeBitWidth() * vecType.getNumElements();
150  return bw;
151  };
152  unsigned aBW = getNumBitwidth(a);
153  unsigned bBW = getNumBitwidth(b);
154  return aBW != 0 && bBW != 0 && aBW == bBW;
155 }
156 
157 /// Returns a source type conversion failure for `srcType` and operation `op`.
158 static LogicalResult
160  Type srcType) {
161  return rewriter.notifyMatchFailure(
162  op->getLoc(),
163  llvm::formatv("failed to convert source type '{0}'", srcType));
164 }
165 
166 /// Returns a source type conversion failure for the result type of `op`.
167 static LogicalResult
169  assert(op->getNumResults() == 1);
170  return getTypeConversionFailure(rewriter, op, op->getResultTypes().front());
171 }
172 
173 // TODO: Move to some common place?
174 static std::string getDecorationString(spirv::Decoration decor) {
175  return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decor));
176 }
177 
178 namespace {
179 
180 /// Converts elementwise unary, binary and ternary arith operations to SPIR-V
181 /// operations. Op can potentially support overflow flags.
182 template <typename Op, typename SPIRVOp>
183 struct ElementwiseArithOpPattern final : OpConversionPattern<Op> {
185 
186  LogicalResult
187  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
188  ConversionPatternRewriter &rewriter) const override {
189  assert(adaptor.getOperands().size() <= 3);
190  auto converter = this->template getTypeConverter<SPIRVTypeConverter>();
191  Type dstType = converter->convertType(op.getType());
192  if (!dstType) {
193  return rewriter.notifyMatchFailure(
194  op->getLoc(),
195  llvm::formatv("failed to convert type {0} for SPIR-V", op.getType()));
196  }
197 
198  if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
199  !getElementTypeOrSelf(op.getType()).isIndex() &&
200  dstType != op.getType()) {
201  return op.emitError("bitwidth emulation is not implemented yet on "
202  "unsigned op pattern version");
203  }
204 
205  auto overflowFlags = arith::IntegerOverflowFlags::none;
206  if (auto overflowIface =
207  dyn_cast<arith::ArithIntegerOverflowFlagsInterface>(*op)) {
208  if (converter->getTargetEnv().allows(
209  spirv::Extension::SPV_KHR_no_integer_wrap_decoration))
210  overflowFlags = overflowIface.getOverflowAttr().getValue();
211  }
212 
213  auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
214  op, dstType, adaptor.getOperands());
215 
216  if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nsw))
217  newOp->setAttr(getDecorationString(spirv::Decoration::NoSignedWrap),
218  rewriter.getUnitAttr());
219 
220  if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nuw))
221  newOp->setAttr(getDecorationString(spirv::Decoration::NoUnsignedWrap),
222  rewriter.getUnitAttr());
223 
224  return success();
225  }
226 };
227 
228 //===----------------------------------------------------------------------===//
229 // ConstantOp
230 //===----------------------------------------------------------------------===//
231 
232 /// Converts composite arith.constant operation to spirv.Constant.
233 struct ConstantCompositeOpPattern final
235  using Base::Base;
236 
237  LogicalResult
238  matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
239  ConversionPatternRewriter &rewriter) const override {
240  auto srcType = dyn_cast<ShapedType>(constOp.getType());
241  if (!srcType || srcType.getNumElements() == 1)
242  return failure();
243 
244  // arith.constant should only have vector or tensor types. This is a MLIR
245  // wide problem at the moment.
246  if (!isa<VectorType, RankedTensorType>(srcType))
247  return rewriter.notifyMatchFailure(constOp, "unsupported ShapedType");
248 
249  Type dstType = getTypeConverter()->convertType(srcType);
250  if (!dstType)
251  return failure();
252 
253  // Import the resource into the IR to make use of the special handling of
254  // element types later on.
255  mlir::DenseElementsAttr dstElementsAttr;
256  if (auto denseElementsAttr =
257  dyn_cast<DenseElementsAttr>(constOp.getValue())) {
258  dstElementsAttr = denseElementsAttr;
259  } else if (auto resourceAttr =
260  dyn_cast<DenseResourceElementsAttr>(constOp.getValue())) {
261 
262  AsmResourceBlob *blob = resourceAttr.getRawHandle().getBlob();
263  if (!blob)
264  return constOp->emitError("could not find resource blob");
265 
266  ArrayRef<char> ptr = blob->getData();
267 
268  // Check that the buffer meets the requirements to get converted to a
269  // DenseElementsAttr
270  bool detectedSplat = false;
271  if (!DenseElementsAttr::isValidRawBuffer(srcType, ptr, detectedSplat))
272  return constOp->emitError("resource is not a valid buffer");
273 
274  dstElementsAttr =
275  DenseElementsAttr::getFromRawBuffer(resourceAttr.getType(), ptr);
276  } else {
277  return constOp->emitError("unsupported elements attribute");
278  }
279 
280  ShapedType dstAttrType = dstElementsAttr.getType();
281 
282  // If the composite type has more than one dimensions, perform
283  // linearization.
284  if (srcType.getRank() > 1) {
285  if (isa<RankedTensorType>(srcType)) {
286  dstAttrType = RankedTensorType::get(srcType.getNumElements(),
287  srcType.getElementType());
288  dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
289  } else {
290  // TODO: add support for large vectors.
291  return failure();
292  }
293  }
294 
295  Type srcElemType = srcType.getElementType();
296  Type dstElemType;
297  // Tensor types are converted to SPIR-V array types; vector types are
298  // converted to SPIR-V vector/array types.
299  if (auto arrayType = dyn_cast<spirv::ArrayType>(dstType))
300  dstElemType = arrayType.getElementType();
301  else
302  dstElemType = cast<VectorType>(dstType).getElementType();
303 
304  // If the source and destination element types are different, perform
305  // attribute conversion.
306  if (srcElemType != dstElemType) {
307  SmallVector<Attribute, 8> elements;
308  if (isa<FloatType>(srcElemType)) {
309  for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
310  Attribute dstAttr = nullptr;
311  // Handle 8-bit float conversion to 8-bit integer.
312  auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
313  if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
314  srcElemType.getIntOrFloatBitWidth() == 8 &&
315  isa<IntegerType>(dstElemType)) {
316  dstAttr =
317  getIntegerAttrFromFloatAttr(srcAttr, dstElemType, rewriter);
318  } else {
319  dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstElemType),
320  rewriter);
321  }
322  if (!dstAttr)
323  return failure();
324  elements.push_back(dstAttr);
325  }
326  } else if (srcElemType.isInteger(1)) {
327  return failure();
328  } else {
329  for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) {
330  IntegerAttr dstAttr = convertIntegerAttr(
331  srcAttr, cast<IntegerType>(dstElemType), rewriter);
332  if (!dstAttr)
333  return failure();
334  elements.push_back(dstAttr);
335  }
336  }
337 
338  // Unfortunately, we cannot use dialect-specific types for element
339  // attributes; element attributes only works with builtin types. So we
340  // need to prepare another converted builtin types for the destination
341  // elements attribute.
342  if (isa<RankedTensorType>(dstAttrType))
343  dstAttrType =
344  RankedTensorType::get(dstAttrType.getShape(), dstElemType);
345  else
346  dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
347 
348  dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
349  }
350 
351  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
352  dstElementsAttr);
353  return success();
354  }
355 };
356 
357 /// Converts scalar arith.constant operation to spirv.Constant.
358 struct ConstantScalarOpPattern final
359  : public OpConversionPattern<arith::ConstantOp> {
360  using Base::Base;
361 
362  LogicalResult
363  matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
364  ConversionPatternRewriter &rewriter) const override {
365  Type srcType = constOp.getType();
366  if (auto shapedType = dyn_cast<ShapedType>(srcType)) {
367  if (shapedType.getNumElements() != 1)
368  return failure();
369  srcType = shapedType.getElementType();
370  }
371  if (!srcType.isIntOrIndexOrFloat())
372  return failure();
373 
374  Attribute cstAttr = constOp.getValue();
375  if (auto elementsAttr = dyn_cast<DenseElementsAttr>(cstAttr))
376  cstAttr = elementsAttr.getSplatValue<Attribute>();
377 
378  Type dstType = getTypeConverter()->convertType(srcType);
379  if (!dstType)
380  return failure();
381 
382  // Floating-point types.
383  if (isa<FloatType>(srcType)) {
384  auto srcAttr = cast<FloatAttr>(cstAttr);
385  Attribute dstAttr = srcAttr;
386 
387  // Floating-point types not supported in the target environment are all
388  // converted to float type.
389  auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
390  if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
391  srcType.getIntOrFloatBitWidth() == 8 && isa<IntegerType>(dstType) &&
392  dstType.getIntOrFloatBitWidth() == 8) {
393  // If the source is an 8-bit float, convert it to a 8-bit integer.
394  dstAttr = getIntegerAttrFromFloatAttr(srcAttr, dstType, rewriter);
395  if (!dstAttr)
396  return failure();
397  } else if (srcType != dstType) {
398  dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter);
399  if (!dstAttr)
400  return failure();
401  }
402 
403  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
404  return success();
405  }
406 
407  // Bool type.
408  if (srcType.isInteger(1)) {
409  // arith.constant can use 0/1 instead of true/false for i1 values. We need
410  // to handle that here.
411  auto dstAttr = convertBoolAttr(cstAttr, rewriter);
412  if (!dstAttr)
413  return failure();
414  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
415  return success();
416  }
417 
418  // IndexType or IntegerType. Index values are converted to 32-bit integer
419  // values when converting to SPIR-V.
420  auto srcAttr = cast<IntegerAttr>(cstAttr);
421  IntegerAttr dstAttr =
422  convertIntegerAttr(srcAttr, cast<IntegerType>(dstType), rewriter);
423  if (!dstAttr)
424  return failure();
425  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
426  return success();
427  }
428 };
429 
430 //===----------------------------------------------------------------------===//
431 // RemSIOp
432 //===----------------------------------------------------------------------===//
433 
434 /// Returns signed remainder for `lhs` and `rhs` and lets the result follow
435 /// the sign of `signOperand`.
436 ///
437 /// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment
438 /// spec, "for the OpSRem and OpSMod instructions, if either operand is negative
439 /// the result is undefined." So we cannot directly use spirv.SRem/spirv.SMod
440 /// if either operand can be negative. Emulate it via spirv.UMod.
441 template <typename SignedAbsOp>
442 static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
443  Value signOperand, OpBuilder &builder) {
444  assert(lhs.getType() == rhs.getType());
445  assert(lhs == signOperand || rhs == signOperand);
446 
447  Type type = lhs.getType();
448 
449  // Calculate the remainder with spirv.UMod.
450  Value lhsAbs = SignedAbsOp::create(builder, loc, type, lhs);
451  Value rhsAbs = SignedAbsOp::create(builder, loc, type, rhs);
452  Value abs = spirv::UModOp::create(builder, loc, lhsAbs, rhsAbs);
453 
454  // Fix the sign.
455  Value isPositive;
456  if (lhs == signOperand)
457  isPositive = spirv::IEqualOp::create(builder, loc, lhs, lhsAbs);
458  else
459  isPositive = spirv::IEqualOp::create(builder, loc, rhs, rhsAbs);
460  Value absNegate = spirv::SNegateOp::create(builder, loc, type, abs);
461  return spirv::SelectOp::create(builder, loc, type, isPositive, abs,
462  absNegate);
463 }
464 
465 /// Converts arith.remsi to GLSL SPIR-V ops.
466 ///
467 /// This cannot be merged into the template unary/binary pattern due to Vulkan
468 /// restrictions over spirv.SRem and spirv.SMod.
469 struct RemSIOpGLPattern final : public OpConversionPattern<arith::RemSIOp> {
470  using Base::Base;
471 
472  LogicalResult
473  matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
474  ConversionPatternRewriter &rewriter) const override {
475  Value result = emulateSignedRemainder<spirv::CLSAbsOp>(
476  op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
477  adaptor.getOperands()[0], rewriter);
478  rewriter.replaceOp(op, result);
479 
480  return success();
481  }
482 };
483 
484 /// Converts arith.remsi to OpenCL SPIR-V ops.
485 struct RemSIOpCLPattern final : public OpConversionPattern<arith::RemSIOp> {
486  using Base::Base;
487 
488  LogicalResult
489  matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
490  ConversionPatternRewriter &rewriter) const override {
491  Value result = emulateSignedRemainder<spirv::GLSAbsOp>(
492  op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
493  adaptor.getOperands()[0], rewriter);
494  rewriter.replaceOp(op, result);
495 
496  return success();
497  }
498 };
499 
500 //===----------------------------------------------------------------------===//
501 // BitwiseOp
502 //===----------------------------------------------------------------------===//
503 
504 /// Converts bitwise operations to SPIR-V operations. This is a special pattern
505 /// other than the BinaryOpPatternPattern because if the operands are boolean
506 /// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For
507 /// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`.
508 template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
509 struct BitwiseOpPattern final : public OpConversionPattern<Op> {
511 
512  LogicalResult
513  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
514  ConversionPatternRewriter &rewriter) const override {
515  assert(adaptor.getOperands().size() == 2);
516  Type dstType = this->getTypeConverter()->convertType(op.getType());
517  if (!dstType)
518  return getTypeConversionFailure(rewriter, op);
519 
520  if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) {
521  rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(
522  op, dstType, adaptor.getOperands());
523  } else {
524  rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(
525  op, dstType, adaptor.getOperands());
526  }
527  return success();
528  }
529 };
530 
531 //===----------------------------------------------------------------------===//
532 // XOrIOp
533 //===----------------------------------------------------------------------===//
534 
535 /// Converts arith.xori to SPIR-V operations.
536 struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> {
537  using Base::Base;
538 
539  LogicalResult
540  matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
541  ConversionPatternRewriter &rewriter) const override {
542  assert(adaptor.getOperands().size() == 2);
543 
544  if (isBoolScalarOrVector(adaptor.getOperands().front().getType()))
545  return failure();
546 
547  Type dstType = getTypeConverter()->convertType(op.getType());
548  if (!dstType)
549  return getTypeConversionFailure(rewriter, op);
550 
551  rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
552  adaptor.getOperands());
553 
554  return success();
555  }
556 };
557 
558 /// Converts arith.xori to SPIR-V operations if the type of source is i1 or
559 /// vector of i1.
560 struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
561  using Base::Base;
562 
563  LogicalResult
564  matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
565  ConversionPatternRewriter &rewriter) const override {
566  assert(adaptor.getOperands().size() == 2);
567 
568  if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
569  return failure();
570 
571  Type dstType = getTypeConverter()->convertType(op.getType());
572  if (!dstType)
573  return getTypeConversionFailure(rewriter, op);
574 
575  rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
576  op, dstType, adaptor.getOperands());
577  return success();
578  }
579 };
580 
581 //===----------------------------------------------------------------------===//
582 // UIToFPOp
583 //===----------------------------------------------------------------------===//
584 
585 /// Converts arith.uitofp to spirv.Select if the type of source is i1 or vector
586 /// of i1.
587 struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
588  using Base::Base;
589 
590  LogicalResult
591  matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
592  ConversionPatternRewriter &rewriter) const override {
593  Type srcType = adaptor.getOperands().front().getType();
594  if (!isBoolScalarOrVector(srcType))
595  return failure();
596 
597  Type dstType = getTypeConverter()->convertType(op.getType());
598  if (!dstType)
599  return getTypeConversionFailure(rewriter, op);
600 
601  Location loc = op.getLoc();
602  Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
603  Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
604  rewriter.replaceOpWithNewOp<spirv::SelectOp>(
605  op, dstType, adaptor.getOperands().front(), one, zero);
606  return success();
607  }
608 };
609 
610 //===----------------------------------------------------------------------===//
611 // IndexCastOp
612 //===----------------------------------------------------------------------===//
613 
614 /// Converts arith.index_cast to spirv.INotEqual if the target type is i1.
615 struct IndexCastIndexI1Pattern final
616  : public OpConversionPattern<arith::IndexCastOp> {
617  using Base::Base;
618 
619  LogicalResult
620  matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
621  ConversionPatternRewriter &rewriter) const override {
622  if (!isBoolScalarOrVector(op.getType()))
623  return failure();
624 
625  Type dstType = getTypeConverter()->convertType(op.getType());
626  if (!dstType)
627  return getTypeConversionFailure(rewriter, op);
628 
629  Location loc = op.getLoc();
630  Value zeroIdx =
631  spirv::ConstantOp::getZero(adaptor.getIn().getType(), loc, rewriter);
632  rewriter.replaceOpWithNewOp<spirv::INotEqualOp>(op, dstType, zeroIdx,
633  adaptor.getIn());
634  return success();
635  }
636 };
637 
638 /// Converts arith.index_cast to spirv.Select if the source type is i1.
639 struct IndexCastI1IndexPattern final
640  : public OpConversionPattern<arith::IndexCastOp> {
641  using Base::Base;
642 
643  LogicalResult
644  matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
645  ConversionPatternRewriter &rewriter) const override {
646  if (!isBoolScalarOrVector(adaptor.getIn().getType()))
647  return failure();
648 
649  Type dstType = getTypeConverter()->convertType(op.getType());
650  if (!dstType)
651  return getTypeConversionFailure(rewriter, op);
652 
653  Location loc = op.getLoc();
654  Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
655  Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
656  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, adaptor.getIn(),
657  one, zero);
658  return success();
659  }
660 };
661 
662 //===----------------------------------------------------------------------===//
663 // ExtSIOp
664 //===----------------------------------------------------------------------===//
665 
666 /// Converts arith.extsi to spirv.Select if the type of source is i1 or vector
667 /// of i1.
668 struct ExtSII1Pattern final : public OpConversionPattern<arith::ExtSIOp> {
669  using Base::Base;
670 
671  LogicalResult
672  matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
673  ConversionPatternRewriter &rewriter) const override {
674  Value operand = adaptor.getIn();
675  if (!isBoolScalarOrVector(operand.getType()))
676  return failure();
677 
678  Location loc = op.getLoc();
679  Type dstType = getTypeConverter()->convertType(op.getType());
680  if (!dstType)
681  return getTypeConversionFailure(rewriter, op);
682 
683  Value allOnes;
684  if (auto intTy = dyn_cast<IntegerType>(dstType)) {
685  unsigned componentBitwidth = intTy.getWidth();
686  allOnes = spirv::ConstantOp::create(
687  rewriter, loc, intTy,
688  rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth)));
689  } else if (auto vectorTy = dyn_cast<VectorType>(dstType)) {
690  unsigned componentBitwidth = vectorTy.getElementTypeBitWidth();
691  allOnes = spirv::ConstantOp::create(
692  rewriter, loc, vectorTy,
693  SplatElementsAttr::get(vectorTy,
694  APInt::getAllOnes(componentBitwidth)));
695  } else {
696  return rewriter.notifyMatchFailure(
697  loc, llvm::formatv("unhandled type: {0}", dstType));
698  }
699 
700  Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
701  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, operand, allOnes,
702  zero);
703  return success();
704  }
705 };
706 
707 /// Converts arith.extsi to spirv.Select if the type of source is neither i1 nor
708 /// vector of i1.
709 struct ExtSIPattern final : public OpConversionPattern<arith::ExtSIOp> {
710  using Base::Base;
711 
712  LogicalResult
713  matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
714  ConversionPatternRewriter &rewriter) const override {
715  Type srcType = adaptor.getIn().getType();
716  if (isBoolScalarOrVector(srcType))
717  return failure();
718 
719  Type dstType = getTypeConverter()->convertType(op.getType());
720  if (!dstType)
721  return getTypeConversionFailure(rewriter, op);
722 
723  if (dstType == srcType) {
724  // We can have the same source and destination type due to type emulation.
725  // Perform bit shifting to make sure we have the proper leading set bits.
726 
727  unsigned srcBW =
728  getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth();
729  unsigned dstBW =
731  assert(srcBW < dstBW);
732  Value shiftSize = getScalarOrVectorConstInt(dstType, dstBW - srcBW,
733  rewriter, op.getLoc());
734 
735  // First shift left to sequeeze out all leading bits beyond the original
736  // bitwidth. Here we need to use the original source and result type's
737  // bitwidth.
738  auto shiftLOp = spirv::ShiftLeftLogicalOp::create(
739  rewriter, op.getLoc(), dstType, adaptor.getIn(), shiftSize);
740 
741  // Then we perform arithmetic right shift to make sure we have the right
742  // sign bits for negative values.
743  rewriter.replaceOpWithNewOp<spirv::ShiftRightArithmeticOp>(
744  op, dstType, shiftLOp, shiftSize);
745  } else {
746  rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
747  adaptor.getOperands());
748  }
749 
750  return success();
751  }
752 };
753 
754 //===----------------------------------------------------------------------===//
755 // ExtUIOp
756 //===----------------------------------------------------------------------===//
757 
758 /// Converts arith.extui to spirv.Select if the type of source is i1 or vector
759 /// of i1.
760 struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> {
761  using Base::Base;
762 
763  LogicalResult
764  matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
765  ConversionPatternRewriter &rewriter) const override {
766  Type srcType = adaptor.getOperands().front().getType();
767  if (!isBoolScalarOrVector(srcType))
768  return failure();
769 
770  Type dstType = getTypeConverter()->convertType(op.getType());
771  if (!dstType)
772  return getTypeConversionFailure(rewriter, op);
773 
774  Location loc = op.getLoc();
775  Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
776  Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
777  rewriter.replaceOpWithNewOp<spirv::SelectOp>(
778  op, dstType, adaptor.getOperands().front(), one, zero);
779  return success();
780  }
781 };
782 
783 /// Converts arith.extui for cases where the type of source is neither i1 nor
784 /// vector of i1.
785 struct ExtUIPattern final : public OpConversionPattern<arith::ExtUIOp> {
786  using Base::Base;
787 
788  LogicalResult
789  matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
790  ConversionPatternRewriter &rewriter) const override {
791  Type srcType = adaptor.getIn().getType();
792  if (isBoolScalarOrVector(srcType))
793  return failure();
794 
795  Type dstType = getTypeConverter()->convertType(op.getType());
796  if (!dstType)
797  return getTypeConversionFailure(rewriter, op);
798 
799  if (dstType == srcType) {
800  // We can have the same source and destination type due to type emulation.
801  // Perform bit masking to make sure we don't pollute downstream consumers
802  // with unwanted bits. Here we need to use the original source type's
803  // bitwidth.
804  unsigned bitwidth =
805  getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth();
807  dstType, llvm::maskTrailingOnes<uint64_t>(bitwidth), rewriter,
808  op.getLoc());
809  rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
810  adaptor.getIn(), mask);
811  } else {
812  rewriter.replaceOpWithNewOp<spirv::UConvertOp>(op, dstType,
813  adaptor.getOperands());
814  }
815  return success();
816  }
817 };
818 
819 //===----------------------------------------------------------------------===//
820 // TruncIOp
821 //===----------------------------------------------------------------------===//
822 
823 /// Converts arith.trunci to spirv.Select if the type of result is i1 or vector
824 /// of i1.
825 struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
826  using Base::Base;
827 
828  LogicalResult
829  matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
830  ConversionPatternRewriter &rewriter) const override {
831  Type dstType = getTypeConverter()->convertType(op.getType());
832  if (!dstType)
833  return getTypeConversionFailure(rewriter, op);
834 
835  if (!isBoolScalarOrVector(dstType))
836  return failure();
837 
838  Location loc = op.getLoc();
839  auto srcType = adaptor.getOperands().front().getType();
840  // Check if (x & 1) == 1.
841  Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
842  Value maskedSrc = spirv::BitwiseAndOp::create(
843  rewriter, loc, srcType, adaptor.getOperands()[0], mask);
844  Value isOne = spirv::IEqualOp::create(rewriter, loc, maskedSrc, mask);
845 
846  Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
847  Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
848  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero);
849  return success();
850  }
851 };
852 
853 /// Converts arith.trunci for cases where the type of result is neither i1
854 /// nor vector of i1.
855 struct TruncIPattern final : public OpConversionPattern<arith::TruncIOp> {
856  using Base::Base;
857 
858  LogicalResult
859  matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
860  ConversionPatternRewriter &rewriter) const override {
861  Type srcType = adaptor.getIn().getType();
862  Type dstType = getTypeConverter()->convertType(op.getType());
863  if (!dstType)
864  return getTypeConversionFailure(rewriter, op);
865 
866  if (isBoolScalarOrVector(dstType))
867  return failure();
868 
869  if (dstType == srcType) {
870  // We can have the same source and destination type due to type emulation.
871  // Perform bit masking to make sure we don't pollute downstream consumers
872  // with unwanted bits. Here we need to use the original result type's
873  // bitwidth.
874  unsigned bw = getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth();
876  dstType, llvm::maskTrailingOnes<uint64_t>(bw), rewriter, op.getLoc());
877  rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
878  adaptor.getIn(), mask);
879  } else {
880  // Given this is truncation, either SConvertOp or UConvertOp works.
881  rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
882  adaptor.getOperands());
883  }
884  return success();
885  }
886 };
887 
888 //===----------------------------------------------------------------------===//
889 // TypeCastingOp
890 //===----------------------------------------------------------------------===//
891 
892 static std::optional<spirv::FPRoundingMode>
893 convertArithRoundingModeToSPIRV(arith::RoundingMode roundingMode) {
894  switch (roundingMode) {
895  case arith::RoundingMode::downward:
896  return spirv::FPRoundingMode::RTN;
897  case arith::RoundingMode::to_nearest_even:
898  return spirv::FPRoundingMode::RTE;
899  case arith::RoundingMode::toward_zero:
900  return spirv::FPRoundingMode::RTZ;
901  case arith::RoundingMode::upward:
902  return spirv::FPRoundingMode::RTP;
903  case arith::RoundingMode::to_nearest_away:
904  // SPIR-V FPRoundingMode decoration has no ties-away-from-zero mode
905  // (as of SPIR-V 1.6)
906  return std::nullopt;
907  }
908  llvm_unreachable("Unhandled rounding mode");
909 }
910 
911 /// Converts type-casting standard operations to SPIR-V operations.
912 template <typename Op, typename SPIRVOp>
913 struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
915 
916  LogicalResult
917  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
918  ConversionPatternRewriter &rewriter) const override {
919  Type srcType = llvm::getSingleElement(adaptor.getOperands()).getType();
920  Type dstType = this->getTypeConverter()->convertType(op.getType());
921  if (!dstType)
922  return getTypeConversionFailure(rewriter, op);
923 
924  if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
925  return failure();
926 
927  if (dstType == srcType) {
928  // Due to type conversion, we are seeing the same source and target type.
929  // Then we can just erase this operation by forwarding its operand.
930  rewriter.replaceOp(op, adaptor.getOperands().front());
931  } else {
932  // Compute new rounding mode (if any).
933  std::optional<spirv::FPRoundingMode> rm = std::nullopt;
934  if (auto roundingModeOp =
935  dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
936  if (arith::RoundingModeAttr roundingMode =
937  roundingModeOp.getRoundingModeAttr()) {
938  if (!(rm =
939  convertArithRoundingModeToSPIRV(roundingMode.getValue()))) {
940  return rewriter.notifyMatchFailure(
941  op->getLoc(),
942  llvm::formatv("unsupported rounding mode '{0}'", roundingMode));
943  }
944  }
945  }
946  // Create replacement op and attach rounding mode attribute (if any).
947  auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
948  op, dstType, adaptor.getOperands());
949  if (rm) {
950  newOp->setAttr(
951  getDecorationString(spirv::Decoration::FPRoundingMode),
952  spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm));
953  }
954  }
955  return success();
956  }
957 };
958 
959 //===----------------------------------------------------------------------===//
960 // CmpIOp
961 //===----------------------------------------------------------------------===//
962 
963 /// Converts integer compare operation on i1 type operands to SPIR-V ops.
964 class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
965 public:
966  using Base::Base;
967 
968  LogicalResult
969  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
970  ConversionPatternRewriter &rewriter) const override {
971  Type srcType = op.getLhs().getType();
972  if (!isBoolScalarOrVector(srcType))
973  return failure();
974  Type dstType = getTypeConverter()->convertType(srcType);
975  if (!dstType)
976  return getTypeConversionFailure(rewriter, op, srcType);
977 
978  switch (op.getPredicate()) {
979  case arith::CmpIPredicate::eq: {
980  rewriter.replaceOpWithNewOp<spirv::LogicalEqualOp>(op, adaptor.getLhs(),
981  adaptor.getRhs());
982  return success();
983  }
984  case arith::CmpIPredicate::ne: {
985  rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
986  op, adaptor.getLhs(), adaptor.getRhs());
987  return success();
988  }
989  case arith::CmpIPredicate::uge:
990  case arith::CmpIPredicate::ugt:
991  case arith::CmpIPredicate::ule:
992  case arith::CmpIPredicate::ult: {
993  // There are no direct corresponding instructions in SPIR-V for such
994  // cases. Extend them to 32-bit and do comparision then.
995  Type type = rewriter.getI32Type();
996  if (auto vectorType = dyn_cast<VectorType>(dstType))
997  type = VectorType::get(vectorType.getShape(), type);
998  Value extLhs =
999  arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getLhs());
1000  Value extRhs =
1001  arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getRhs());
1002 
1003  rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs,
1004  extRhs);
1005  return success();
1006  }
1007  default:
1008  break;
1009  }
1010  return failure();
1011  }
1012 };
1013 
1014 /// Converts integer compare operation to SPIR-V ops.
1015 class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> {
1016 public:
1017  using Base::Base;
1018 
1019  LogicalResult
1020  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
1021  ConversionPatternRewriter &rewriter) const override {
1022  Type srcType = op.getLhs().getType();
1023  if (isBoolScalarOrVector(srcType))
1024  return failure();
1025  Type dstType = getTypeConverter()->convertType(srcType);
1026  if (!dstType)
1027  return getTypeConversionFailure(rewriter, op, srcType);
1028 
1029  switch (op.getPredicate()) {
1030 #define DISPATCH(cmpPredicate, spirvOp) \
1031  case cmpPredicate: \
1032  if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \
1033  !getElementTypeOrSelf(srcType).isIndex() && srcType != dstType && \
1034  !hasSameBitwidth(srcType, dstType)) { \
1035  return op.emitError( \
1036  "bitwidth emulation is not implemented yet on unsigned op"); \
1037  } \
1038  rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
1039  adaptor.getRhs()); \
1040  return success();
1041 
1042  DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
1043  DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
1044  DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
1045  DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
1046  DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
1047  DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
1048  DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
1049  DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
1050  DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
1051  DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
1052 
1053 #undef DISPATCH
1054  }
1055  return failure();
1056  }
1057 };
1058 
1059 //===----------------------------------------------------------------------===//
1060 // CmpFOpPattern
1061 //===----------------------------------------------------------------------===//
1062 
1063 /// Converts floating-point comparison operations to SPIR-V ops.
1064 class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> {
1065 public:
1066  using Base::Base;
1067 
1068  LogicalResult
1069  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1070  ConversionPatternRewriter &rewriter) const override {
1071  switch (op.getPredicate()) {
1072 #define DISPATCH(cmpPredicate, spirvOp) \
1073  case cmpPredicate: \
1074  rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
1075  adaptor.getRhs()); \
1076  return success();
1077 
1078  // Ordered.
1079  DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
1080  DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
1081  DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
1082  DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
1083  DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
1084  DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
1085  // Unordered.
1086  DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
1087  DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
1088  DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
1089  DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
1090  DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
1091  DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
1092 
1093 #undef DISPATCH
1094 
1095  default:
1096  break;
1097  }
1098  return failure();
1099  }
1100 };
1101 
1102 /// Converts floating point NaN check to SPIR-V ops. This pattern requires
1103 /// Kernel capability.
1104 class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> {
1105 public:
1106  using Base::Base;
1107 
1108  LogicalResult
1109  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1110  ConversionPatternRewriter &rewriter) const override {
1111  if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1112  rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(),
1113  adaptor.getRhs());
1114  return success();
1115  }
1116 
1117  if (op.getPredicate() == arith::CmpFPredicate::UNO) {
1118  rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(),
1119  adaptor.getRhs());
1120  return success();
1121  }
1122 
1123  return failure();
1124  }
1125 };
1126 
1127 /// Converts floating point NaN check to SPIR-V ops. This pattern does not
1128 /// require additional capability.
1129 class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
1130 public:
1131  using Base::Base;
1132 
1133  LogicalResult
1134  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1135  ConversionPatternRewriter &rewriter) const override {
1136  if (op.getPredicate() != arith::CmpFPredicate::ORD &&
1137  op.getPredicate() != arith::CmpFPredicate::UNO)
1138  return failure();
1139 
1140  Location loc = op.getLoc();
1141 
1142  Value replace;
1143  if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1144  if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1145  // Ordered comparsion checks if neither operand is NaN.
1146  replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter);
1147  } else {
1148  // Unordered comparsion checks if either operand is NaN.
1149  replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter);
1150  }
1151  } else {
1152  Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1153  Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1154 
1155  replace = spirv::LogicalOrOp::create(rewriter, loc, lhsIsNan, rhsIsNan);
1156  if (op.getPredicate() == arith::CmpFPredicate::ORD)
1157  replace = spirv::LogicalNotOp::create(rewriter, loc, replace);
1158  }
1159 
1160  rewriter.replaceOp(op, replace);
1161  return success();
1162  }
1163 };
1164 
1165 //===----------------------------------------------------------------------===//
1166 // AddUIExtendedOp
1167 //===----------------------------------------------------------------------===//
1168 
1169 /// Converts arith.addui_extended to spirv.IAddCarry.
1170 class AddUIExtendedOpPattern final
1171  : public OpConversionPattern<arith::AddUIExtendedOp> {
1172 public:
1173  using Base::Base;
1174  LogicalResult
1175  matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
1176  ConversionPatternRewriter &rewriter) const override {
1177  Type dstElemTy = adaptor.getLhs().getType();
1178  Location loc = op->getLoc();
1179  Value result = spirv::IAddCarryOp::create(rewriter, loc, adaptor.getLhs(),
1180  adaptor.getRhs());
1181 
1182  Value sumResult = spirv::CompositeExtractOp::create(rewriter, loc, result,
1183  llvm::ArrayRef(0));
1184  Value carryValue = spirv::CompositeExtractOp::create(rewriter, loc, result,
1185  llvm::ArrayRef(1));
1186 
1187  // Convert the carry value to boolean.
1188  Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter);
1189  Value carryResult = spirv::IEqualOp::create(rewriter, loc, carryValue, one);
1190 
1191  rewriter.replaceOp(op, {sumResult, carryResult});
1192  return success();
1193  }
1194 };
1195 
1196 //===----------------------------------------------------------------------===//
1197 // MulIExtendedOp
1198 //===----------------------------------------------------------------------===//
1199 
1200 /// Converts arith.mul*i_extended to spirv.*MulExtended.
1201 template <typename ArithMulOp, typename SPIRVMulOp>
1202 class MulIExtendedOpPattern final : public OpConversionPattern<ArithMulOp> {
1203 public:
1205  LogicalResult
1206  matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
1207  ConversionPatternRewriter &rewriter) const override {
1208  Location loc = op->getLoc();
1209  Value result =
1210  SPIRVMulOp::create(rewriter, loc, adaptor.getLhs(), adaptor.getRhs());
1211 
1212  Value low = spirv::CompositeExtractOp::create(rewriter, loc, result,
1213  llvm::ArrayRef(0));
1214  Value high = spirv::CompositeExtractOp::create(rewriter, loc, result,
1215  llvm::ArrayRef(1));
1216 
1217  rewriter.replaceOp(op, {low, high});
1218  return success();
1219  }
1220 };
1221 
1222 //===----------------------------------------------------------------------===//
1223 // SelectOp
1224 //===----------------------------------------------------------------------===//
1225 
1226 /// Converts arith.select to spirv.Select.
1227 class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> {
1228 public:
1229  using Base::Base;
1230  LogicalResult
1231  matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
1232  ConversionPatternRewriter &rewriter) const override {
1233  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, adaptor.getCondition(),
1234  adaptor.getTrueValue(),
1235  adaptor.getFalseValue());
1236  return success();
1237  }
1238 };
1239 
1240 //===----------------------------------------------------------------------===//
1241 // MinimumFOp, MaximumFOp
1242 //===----------------------------------------------------------------------===//
1243 
1244 /// Converts arith.maximumf/minimumf to spirv.GL.FMax/FMin or
1245 /// spirv.CL.fmax/fmin.
1246 template <typename Op, typename SPIRVOp>
1247 class MinimumMaximumFOpPattern final : public OpConversionPattern<Op> {
1248 public:
1250  LogicalResult
1251  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
1252  ConversionPatternRewriter &rewriter) const override {
1253  auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
1254  Type dstType = converter->convertType(op.getType());
1255  if (!dstType)
1256  return getTypeConversionFailure(rewriter, op);
1257 
1258  // arith.maximumf/minimumf:
1259  // "if one of the arguments is NaN, then the result is also NaN."
1260  // spirv.GL.FMax/FMin
1261  // "which operand is the result is undefined if one of the operands
1262  // is a NaN."
1263  // spirv.CL.fmax/fmin:
1264  // "If one argument is a NaN, Fmin returns the other argument."
1265 
1266  Location loc = op.getLoc();
1267  Value spirvOp =
1268  SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands());
1269 
1270  if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1271  rewriter.replaceOp(op, spirvOp);
1272  return success();
1273  }
1274 
1275  Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1276  Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1277 
1278  Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan,
1279  adaptor.getLhs(), spirvOp);
1280  Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan,
1281  adaptor.getRhs(), select1);
1282 
1283  rewriter.replaceOp(op, select2);
1284  return success();
1285  }
1286 };
1287 
1288 //===----------------------------------------------------------------------===//
1289 // MinNumFOp, MaxNumFOp
1290 //===----------------------------------------------------------------------===//
1291 
1292 /// Converts arith.maxnumf/minnumf to spirv.GL.FMax/FMin or
1293 /// spirv.CL.fmax/fmin.
1294 template <typename Op, typename SPIRVOp>
1295 class MinNumMaxNumFOpPattern final : public OpConversionPattern<Op> {
1296  template <typename TargetOp>
1297  constexpr bool shouldInsertNanGuards() const {
1298  return llvm::is_one_of<TargetOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value;
1299  }
1300 
1301 public:
1303  LogicalResult
1304  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
1305  ConversionPatternRewriter &rewriter) const override {
1306  auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
1307  Type dstType = converter->convertType(op.getType());
1308  if (!dstType)
1309  return getTypeConversionFailure(rewriter, op);
1310 
1311  // arith.maxnumf/minnumf:
1312  // "If one of the arguments is NaN, then the result is the other
1313  // argument."
1314  // spirv.GL.FMax/FMin
1315  // "which operand is the result is undefined if one of the operands
1316  // is a NaN."
1317  // spirv.CL.fmax/fmin:
1318  // "If one argument is a NaN, Fmin returns the other argument."
1319 
1320  Location loc = op.getLoc();
1321  Value spirvOp =
1322  SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands());
1323 
1324  if (!shouldInsertNanGuards<SPIRVOp>() ||
1325  bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1326  rewriter.replaceOp(op, spirvOp);
1327  return success();
1328  }
1329 
1330  Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1331  Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1332 
1333  Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan,
1334  adaptor.getRhs(), spirvOp);
1335  Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan,
1336  adaptor.getLhs(), select1);
1337 
1338  rewriter.replaceOp(op, select2);
1339  return success();
1340  }
1341 };
1342 
1343 } // namespace
1344 
1345 //===----------------------------------------------------------------------===//
1346 // Pattern Population
1347 //===----------------------------------------------------------------------===//
1348 
1350  const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
1351  // clang-format off
1352  patterns.add<
1353  ConstantCompositeOpPattern,
1354  ConstantScalarOpPattern,
1355  ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>,
1356  ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>,
1357  ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>,
1361  RemSIOpGLPattern, RemSIOpCLPattern,
1362  BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
1363  BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
1364  XOrIOpLogicalPattern, XOrIOpBooleanPattern,
1365  ElementwiseArithOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
1374  ExtUIPattern, ExtUII1Pattern,
1375  ExtSIPattern, ExtSII1Pattern,
1376  TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
1377  TruncIPattern, TruncII1Pattern,
1378  TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
1379  TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern,
1380  TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
1381  TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
1382  TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
1383  TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
1384  IndexCastIndexI1Pattern, IndexCastI1IndexPattern,
1385  TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
1386  TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
1387  CmpIOpBooleanPattern, CmpIOpPattern,
1388  CmpFOpNanNonePattern, CmpFOpPattern,
1389  AddUIExtendedOpPattern,
1390  MulIExtendedOpPattern<arith::MulSIExtendedOp, spirv::SMulExtendedOp>,
1391  MulIExtendedOpPattern<arith::MulUIExtendedOp, spirv::UMulExtendedOp>,
1392  SelectOpPattern,
1393 
1394  MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::GLFMaxOp>,
1395  MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::GLFMinOp>,
1396  MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::GLFMaxOp>,
1397  MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::GLFMinOp>,
1402 
1403  MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::CLFMaxOp>,
1404  MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::CLFMinOp>,
1405  MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::CLFMaxOp>,
1406  MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::CLFMinOp>,
1411  >(typeConverter, patterns.getContext());
1412  // clang-format on
1413 
1414  // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel
1415  // capability is available.
1416  patterns.add<CmpFOpNanKernelPattern>(typeConverter, patterns.getContext(),
1417  /*benefit=*/2);
1418 }
1419 
1420 //===----------------------------------------------------------------------===//
1421 // Pass Definition
1422 //===----------------------------------------------------------------------===//
1423 
1424 namespace {
1425 struct ConvertArithToSPIRVPass
1426  : public impl::ConvertArithToSPIRVPassBase<ConvertArithToSPIRVPass> {
1427  using Base::Base;
1428 
1429  void runOnOperation() override {
1430  Operation *op = getOperation();
1432  std::unique_ptr<SPIRVConversionTarget> target =
1433  SPIRVConversionTarget::get(targetAttr);
1434 
1436  options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
1437  options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
1438  SPIRVTypeConverter typeConverter(targetAttr, options);
1439 
1440  // Use UnrealizedConversionCast as the bridge so that we don't need to pull
1441  // in patterns for other dialects.
1442  target->addLegalOp<UnrealizedConversionCastOp>();
1443 
1444  // Fail hard when there are any remaining 'arith' ops.
1445  target->addIllegalDialect<arith::ArithDialect>();
1446 
1449 
1450  if (failed(applyPartialConversion(op, *target, std::move(patterns))))
1451  signalPassFailure();
1452  }
1453 };
1454 } // namespace
static bool hasSameBitwidth(Type a, Type b)
Returns true if scalar/vector type a and b have the same number of bitwidth.
static Value getScalarOrVectorConstInt(Type type, uint64_t value, OpBuilder &builder, Location loc)
Creates a scalar/vector integer constant.
static LogicalResult getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op, Type srcType)
Returns a source type conversion failure for srcType and operation op.
static IntegerAttr getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType, ConversionPatternRewriter &rewriter)
static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder)
Converts the given srcAttr into a boolean attribute if it holds an integral value.
static bool isBoolScalarOrVector(Type type)
Returns true if the given type is a boolean scalar or vector type.
#define DISPATCH(cmpPredicate, spirvOp)
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static std::string getDecorationString(spirv::Decoration decor)
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
This class represents a processed binary blob of data.
Definition: AsmState.h:91
ArrayRef< char > getData() const
Return the raw underlying data of this blob.
Definition: AsmState.h:145
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
UnitAttr getUnitAttr()
Definition: Builders.cpp:97
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:227
IntegerType getI32Type()
Definition: Builders.cpp:62
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:99
MLIRContext * getContext() const
Definition: Builders.h:56
FloatAttr getF32FloatAttr(float value)
Definition: Builders.cpp:245
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
static bool isValidRawBuffer(ShapedType type, ArrayRef< char > rawBuffer, bool &detectedSplat)
Returns true if the given buffer is a valid raw buffer for the given type.
DenseElementsAttr reshape(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but has been reshaped...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
This class helps build Operations.
Definition: Builders.h:207
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:830
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:129
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
result_type_range getResultTypes()
Definition: Operation.h:428
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:726
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:529
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:54
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:120
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
Type front()
Return first type in the range.
Definition: TypeRange.h:152
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
An attribute that specifies the target version, allowed extensions and capabilities,...
void populateArithToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Fraction abs(const Fraction &f)
Definition: Fraction.h:107
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.
Definition: Pattern.h:23