MLIR  18.0.0git
ArithToSPIRV.cpp
Go to the documentation of this file.
1 //===- ArithToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
11 #include "../SPIRVCommon/Pattern.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "llvm/ADT/APInt.h"
21 #include "llvm/ADT/ArrayRef.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/Support/Debug.h"
24 #include "llvm/Support/MathExtras.h"
25 #include <cassert>
26 #include <memory>
27 
28 namespace mlir {
29 #define GEN_PASS_DEF_CONVERTARITHTOSPIRV
30 #include "mlir/Conversion/Passes.h.inc"
31 } // namespace mlir
32 
33 #define DEBUG_TYPE "arith-to-spirv-pattern"
34 
35 using namespace mlir;
36 
37 //===----------------------------------------------------------------------===//
38 // Conversion Helpers
39 //===----------------------------------------------------------------------===//
40 
41 /// Converts the given `srcAttr` into a boolean attribute if it holds an
42 /// integral value. Returns null attribute if conversion fails.
43 static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) {
44  if (auto boolAttr = dyn_cast<BoolAttr>(srcAttr))
45  return boolAttr;
46  if (auto intAttr = dyn_cast<IntegerAttr>(srcAttr))
47  return builder.getBoolAttr(intAttr.getValue().getBoolValue());
48  return {};
49 }
50 
51 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
52 /// Returns null attribute if conversion fails.
53 static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType,
54  Builder builder) {
55  // If the source number uses less active bits than the target bitwidth, then
56  // it should be safe to convert.
57  if (srcAttr.getValue().isIntN(dstType.getWidth()))
58  return builder.getIntegerAttr(dstType, srcAttr.getInt());
59 
60  // XXX: Try again by interpreting the source number as a signed value.
61  // Although integers in the standard dialect are signless, they can represent
62  // a signed number. It's the operation decides how to interpret. This is
63  // dangerous, but it seems there is no good way of handling this if we still
64  // want to change the bitwidth. Emit a message at least.
65  if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
66  auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt());
67  LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '"
68  << dstAttr << "' for type '" << dstType << "'\n");
69  return dstAttr;
70  }
71 
72  LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr
73  << "' illegal: cannot fit into target type '"
74  << dstType << "'\n");
75  return {};
76 }
77 
78 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
79 /// Returns null attribute if `dstType` is not 32-bit or conversion fails.
80 static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
81  Builder builder) {
82  // Only support converting to float for now.
83  if (!dstType.isF32())
84  return FloatAttr();
85 
86  // Try to convert the source floating-point number to single precision.
87  APFloat dstVal = srcAttr.getValue();
88  bool losesInfo = false;
89  APFloat::opStatus status =
90  dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
91  if (status != APFloat::opOK || losesInfo) {
92  LLVM_DEBUG(llvm::dbgs()
93  << srcAttr << " illegal: cannot fit into converted type '"
94  << dstType << "'\n");
95  return FloatAttr();
96  }
97 
98  return builder.getF32FloatAttr(dstVal.convertToFloat());
99 }
100 
101 /// Returns true if the given `type` is a boolean scalar or vector type.
102 static bool isBoolScalarOrVector(Type type) {
103  assert(type && "Not a valid type");
104  if (type.isInteger(1))
105  return true;
106 
107  if (auto vecType = dyn_cast<VectorType>(type))
108  return vecType.getElementType().isInteger(1);
109 
110  return false;
111 }
112 
113 /// Creates a scalar/vector integer constant.
114 static Value getScalarOrVectorConstInt(Type type, uint64_t value,
115  OpBuilder &builder, Location loc) {
116  if (auto vectorType = dyn_cast<VectorType>(type)) {
117  Attribute element = IntegerAttr::get(vectorType.getElementType(), value);
118  auto attr = SplatElementsAttr::get(vectorType, element);
119  return builder.create<spirv::ConstantOp>(loc, vectorType, attr);
120  }
121 
122  if (auto intType = dyn_cast<IntegerType>(type))
123  return builder.create<spirv::ConstantOp>(
124  loc, type, builder.getIntegerAttr(type, value));
125 
126  return nullptr;
127 }
128 
129 /// Returns true if scalar/vector type `a` and `b` have the same number of
130 /// bitwidth.
131 static bool hasSameBitwidth(Type a, Type b) {
132  auto getNumBitwidth = [](Type type) {
133  unsigned bw = 0;
134  if (type.isIntOrFloat())
135  bw = type.getIntOrFloatBitWidth();
136  else if (auto vecType = dyn_cast<VectorType>(type))
137  bw = vecType.getElementTypeBitWidth() * vecType.getNumElements();
138  return bw;
139  };
140  unsigned aBW = getNumBitwidth(a);
141  unsigned bBW = getNumBitwidth(b);
142  return aBW != 0 && bBW != 0 && aBW == bBW;
143 }
144 
145 /// Returns a source type conversion failure for `srcType` and operation `op`.
146 static LogicalResult
148  Type srcType) {
149  return rewriter.notifyMatchFailure(
150  op->getLoc(),
151  llvm::formatv("failed to convert source type '{0}'", srcType));
152 }
153 
154 /// Returns a source type conversion failure for the result type of `op`.
155 static LogicalResult
157  assert(op->getNumResults() == 1);
158  return getTypeConversionFailure(rewriter, op, op->getResultTypes().front());
159 }
160 
161 namespace {
162 
163 //===----------------------------------------------------------------------===//
164 // ConstantOp
165 //===----------------------------------------------------------------------===//
166 
167 /// Converts composite arith.constant operation to spirv.Constant.
168 struct ConstantCompositeOpPattern final
169  : public OpConversionPattern<arith::ConstantOp> {
171 
173  matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
174  ConversionPatternRewriter &rewriter) const override {
175  auto srcType = dyn_cast<ShapedType>(constOp.getType());
176  if (!srcType || srcType.getNumElements() == 1)
177  return failure();
178 
179  // arith.constant should only have vector or tenor types.
180  assert((isa<VectorType, RankedTensorType>(srcType)));
181 
182  Type dstType = getTypeConverter()->convertType(srcType);
183  if (!dstType)
184  return failure();
185 
186  auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
187  if (!dstElementsAttr)
188  return failure();
189 
190  ShapedType dstAttrType = dstElementsAttr.getType();
191 
192  // If the composite type has more than one dimensions, perform
193  // linearization.
194  if (srcType.getRank() > 1) {
195  if (isa<RankedTensorType>(srcType)) {
196  dstAttrType = RankedTensorType::get(srcType.getNumElements(),
197  srcType.getElementType());
198  dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
199  } else {
200  // TODO: add support for large vectors.
201  return failure();
202  }
203  }
204 
205  Type srcElemType = srcType.getElementType();
206  Type dstElemType;
207  // Tensor types are converted to SPIR-V array types; vector types are
208  // converted to SPIR-V vector/array types.
209  if (auto arrayType = dyn_cast<spirv::ArrayType>(dstType))
210  dstElemType = arrayType.getElementType();
211  else
212  dstElemType = cast<VectorType>(dstType).getElementType();
213 
214  // If the source and destination element types are different, perform
215  // attribute conversion.
216  if (srcElemType != dstElemType) {
217  SmallVector<Attribute, 8> elements;
218  if (isa<FloatType>(srcElemType)) {
219  for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
220  FloatAttr dstAttr =
221  convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), rewriter);
222  if (!dstAttr)
223  return failure();
224  elements.push_back(dstAttr);
225  }
226  } else if (srcElemType.isInteger(1)) {
227  return failure();
228  } else {
229  for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) {
230  IntegerAttr dstAttr = convertIntegerAttr(
231  srcAttr, cast<IntegerType>(dstElemType), rewriter);
232  if (!dstAttr)
233  return failure();
234  elements.push_back(dstAttr);
235  }
236  }
237 
238  // Unfortunately, we cannot use dialect-specific types for element
239  // attributes; element attributes only works with builtin types. So we
240  // need to prepare another converted builtin types for the destination
241  // elements attribute.
242  if (isa<RankedTensorType>(dstAttrType))
243  dstAttrType =
244  RankedTensorType::get(dstAttrType.getShape(), dstElemType);
245  else
246  dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
247 
248  dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
249  }
250 
251  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
252  dstElementsAttr);
253  return success();
254  }
255 };
256 
257 /// Converts scalar arith.constant operation to spirv.Constant.
258 struct ConstantScalarOpPattern final
259  : public OpConversionPattern<arith::ConstantOp> {
261 
263  matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
264  ConversionPatternRewriter &rewriter) const override {
265  Type srcType = constOp.getType();
266  if (auto shapedType = dyn_cast<ShapedType>(srcType)) {
267  if (shapedType.getNumElements() != 1)
268  return failure();
269  srcType = shapedType.getElementType();
270  }
271  if (!srcType.isIntOrIndexOrFloat())
272  return failure();
273 
274  Attribute cstAttr = constOp.getValue();
275  if (auto elementsAttr = dyn_cast<DenseElementsAttr>(cstAttr))
276  cstAttr = elementsAttr.getSplatValue<Attribute>();
277 
278  Type dstType = getTypeConverter()->convertType(srcType);
279  if (!dstType)
280  return failure();
281 
282  // Floating-point types.
283  if (isa<FloatType>(srcType)) {
284  auto srcAttr = cast<FloatAttr>(cstAttr);
285  auto dstAttr = srcAttr;
286 
287  // Floating-point types not supported in the target environment are all
288  // converted to float type.
289  if (srcType != dstType) {
290  dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter);
291  if (!dstAttr)
292  return failure();
293  }
294 
295  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
296  return success();
297  }
298 
299  // Bool type.
300  if (srcType.isInteger(1)) {
301  // arith.constant can use 0/1 instead of true/false for i1 values. We need
302  // to handle that here.
303  auto dstAttr = convertBoolAttr(cstAttr, rewriter);
304  if (!dstAttr)
305  return failure();
306  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
307  return success();
308  }
309 
310  // IndexType or IntegerType. Index values are converted to 32-bit integer
311  // values when converting to SPIR-V.
312  auto srcAttr = cast<IntegerAttr>(cstAttr);
313  IntegerAttr dstAttr =
314  convertIntegerAttr(srcAttr, cast<IntegerType>(dstType), rewriter);
315  if (!dstAttr)
316  return failure();
317  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
318  return success();
319  }
320 };
321 
322 //===----------------------------------------------------------------------===//
323 // RemSIOp
324 //===----------------------------------------------------------------------===//
325 
326 /// Returns signed remainder for `lhs` and `rhs` and lets the result follow
327 /// the sign of `signOperand`.
328 ///
329 /// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment
330 /// spec, "for the OpSRem and OpSMod instructions, if either operand is negative
331 /// the result is undefined." So we cannot directly use spirv.SRem/spirv.SMod
332 /// if either operand can be negative. Emulate it via spirv.UMod.
333 template <typename SignedAbsOp>
334 static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
335  Value signOperand, OpBuilder &builder) {
336  assert(lhs.getType() == rhs.getType());
337  assert(lhs == signOperand || rhs == signOperand);
338 
339  Type type = lhs.getType();
340 
341  // Calculate the remainder with spirv.UMod.
342  Value lhsAbs = builder.create<SignedAbsOp>(loc, type, lhs);
343  Value rhsAbs = builder.create<SignedAbsOp>(loc, type, rhs);
344  Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs);
345 
346  // Fix the sign.
347  Value isPositive;
348  if (lhs == signOperand)
349  isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs);
350  else
351  isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs);
352  Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs);
353  return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate);
354 }
355 
356 /// Converts arith.remsi to GLSL SPIR-V ops.
357 ///
358 /// This cannot be merged into the template unary/binary pattern due to Vulkan
359 /// restrictions over spirv.SRem and spirv.SMod.
360 struct RemSIOpGLPattern final : public OpConversionPattern<arith::RemSIOp> {
362 
364  matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
365  ConversionPatternRewriter &rewriter) const override {
366  Value result = emulateSignedRemainder<spirv::CLSAbsOp>(
367  op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
368  adaptor.getOperands()[0], rewriter);
369  rewriter.replaceOp(op, result);
370 
371  return success();
372  }
373 };
374 
375 /// Converts arith.remsi to OpenCL SPIR-V ops.
376 struct RemSIOpCLPattern final : public OpConversionPattern<arith::RemSIOp> {
378 
380  matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
381  ConversionPatternRewriter &rewriter) const override {
382  Value result = emulateSignedRemainder<spirv::GLSAbsOp>(
383  op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
384  adaptor.getOperands()[0], rewriter);
385  rewriter.replaceOp(op, result);
386 
387  return success();
388  }
389 };
390 
391 //===----------------------------------------------------------------------===//
392 // BitwiseOp
393 //===----------------------------------------------------------------------===//
394 
395 /// Converts bitwise operations to SPIR-V operations. This is a special pattern
396 /// other than the BinaryOpPatternPattern because if the operands are boolean
397 /// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For
398 /// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`.
399 template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
400 struct BitwiseOpPattern final : public OpConversionPattern<Op> {
402 
404  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
405  ConversionPatternRewriter &rewriter) const override {
406  assert(adaptor.getOperands().size() == 2);
407  Type dstType = this->getTypeConverter()->convertType(op.getType());
408  if (!dstType)
409  return getTypeConversionFailure(rewriter, op);
410 
411  if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) {
412  rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(
413  op, dstType, adaptor.getOperands());
414  } else {
415  rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(
416  op, dstType, adaptor.getOperands());
417  }
418  return success();
419  }
420 };
421 
422 //===----------------------------------------------------------------------===//
423 // XOrIOp
424 //===----------------------------------------------------------------------===//
425 
426 /// Converts arith.xori to SPIR-V operations.
427 struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> {
429 
431  matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
432  ConversionPatternRewriter &rewriter) const override {
433  assert(adaptor.getOperands().size() == 2);
434 
435  if (isBoolScalarOrVector(adaptor.getOperands().front().getType()))
436  return failure();
437 
438  Type dstType = getTypeConverter()->convertType(op.getType());
439  if (!dstType)
440  return getTypeConversionFailure(rewriter, op);
441 
442  rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
443  adaptor.getOperands());
444 
445  return success();
446  }
447 };
448 
449 /// Converts arith.xori to SPIR-V operations if the type of source is i1 or
450 /// vector of i1.
451 struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
453 
455  matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
456  ConversionPatternRewriter &rewriter) const override {
457  assert(adaptor.getOperands().size() == 2);
458 
459  if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
460  return failure();
461 
462  Type dstType = getTypeConverter()->convertType(op.getType());
463  if (!dstType)
464  return getTypeConversionFailure(rewriter, op);
465 
466  rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
467  op, dstType, adaptor.getOperands());
468  return success();
469  }
470 };
471 
472 //===----------------------------------------------------------------------===//
473 // UIToFPOp
474 //===----------------------------------------------------------------------===//
475 
476 /// Converts arith.uitofp to spirv.Select if the type of source is i1 or vector
477 /// of i1.
478 struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
480 
482  matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
483  ConversionPatternRewriter &rewriter) const override {
484  Type srcType = adaptor.getOperands().front().getType();
485  if (!isBoolScalarOrVector(srcType))
486  return failure();
487 
488  Type dstType = getTypeConverter()->convertType(op.getType());
489  if (!dstType)
490  return getTypeConversionFailure(rewriter, op);
491 
492  Location loc = op.getLoc();
493  Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
494  Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
495  rewriter.replaceOpWithNewOp<spirv::SelectOp>(
496  op, dstType, adaptor.getOperands().front(), one, zero);
497  return success();
498  }
499 };
500 
501 //===----------------------------------------------------------------------===//
502 // ExtSIOp
503 //===----------------------------------------------------------------------===//
504 
505 /// Converts arith.extsi to spirv.Select if the type of source is i1 or vector
506 /// of i1.
507 struct ExtSII1Pattern final : public OpConversionPattern<arith::ExtSIOp> {
509 
511  matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
512  ConversionPatternRewriter &rewriter) const override {
513  Value operand = adaptor.getIn();
514  if (!isBoolScalarOrVector(operand.getType()))
515  return failure();
516 
517  Location loc = op.getLoc();
518  Type dstType = getTypeConverter()->convertType(op.getType());
519  if (!dstType)
520  return getTypeConversionFailure(rewriter, op);
521 
522  Value allOnes;
523  if (auto intTy = dyn_cast<IntegerType>(dstType)) {
524  unsigned componentBitwidth = intTy.getWidth();
525  allOnes = rewriter.create<spirv::ConstantOp>(
526  loc, intTy,
527  rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth)));
528  } else if (auto vectorTy = dyn_cast<VectorType>(dstType)) {
529  unsigned componentBitwidth = vectorTy.getElementTypeBitWidth();
530  allOnes = rewriter.create<spirv::ConstantOp>(
531  loc, vectorTy,
532  SplatElementsAttr::get(vectorTy,
533  APInt::getAllOnes(componentBitwidth)));
534  } else {
535  return rewriter.notifyMatchFailure(
536  loc, llvm::formatv("unhandled type: {0}", dstType));
537  }
538 
539  Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
540  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, operand, allOnes,
541  zero);
542  return success();
543  }
544 };
545 
546 /// Converts arith.extsi to spirv.Select if the type of source is neither i1 nor
547 /// vector of i1.
548 struct ExtSIPattern final : public OpConversionPattern<arith::ExtSIOp> {
550 
552  matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
553  ConversionPatternRewriter &rewriter) const override {
554  Type srcType = adaptor.getIn().getType();
555  if (isBoolScalarOrVector(srcType))
556  return failure();
557 
558  Type dstType = getTypeConverter()->convertType(op.getType());
559  if (!dstType)
560  return getTypeConversionFailure(rewriter, op);
561 
562  if (dstType == srcType) {
563  // We can have the same source and destination type due to type emulation.
564  // Perform bit shifting to make sure we have the proper leading set bits.
565 
566  unsigned srcBW =
567  getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth();
568  unsigned dstBW =
570  assert(srcBW < dstBW);
571  Value shiftSize = getScalarOrVectorConstInt(dstType, dstBW - srcBW,
572  rewriter, op.getLoc());
573 
574  // First shift left to sequeeze out all leading bits beyond the original
575  // bitwidth. Here we need to use the original source and result type's
576  // bitwidth.
577  auto shiftLOp = rewriter.create<spirv::ShiftLeftLogicalOp>(
578  op.getLoc(), dstType, adaptor.getIn(), shiftSize);
579 
580  // Then we perform arithmetic right shift to make sure we have the right
581  // sign bits for negative values.
582  rewriter.replaceOpWithNewOp<spirv::ShiftRightArithmeticOp>(
583  op, dstType, shiftLOp, shiftSize);
584  } else {
585  rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
586  adaptor.getOperands());
587  }
588 
589  return success();
590  }
591 };
592 
593 //===----------------------------------------------------------------------===//
594 // ExtUIOp
595 //===----------------------------------------------------------------------===//
596 
597 /// Converts arith.extui to spirv.Select if the type of source is i1 or vector
598 /// of i1.
599 struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> {
601 
603  matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
604  ConversionPatternRewriter &rewriter) const override {
605  Type srcType = adaptor.getOperands().front().getType();
606  if (!isBoolScalarOrVector(srcType))
607  return failure();
608 
609  Type dstType = getTypeConverter()->convertType(op.getType());
610  if (!dstType)
611  return getTypeConversionFailure(rewriter, op);
612 
613  Location loc = op.getLoc();
614  Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
615  Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
616  rewriter.replaceOpWithNewOp<spirv::SelectOp>(
617  op, dstType, adaptor.getOperands().front(), one, zero);
618  return success();
619  }
620 };
621 
622 /// Converts arith.extui for cases where the type of source is neither i1 nor
623 /// vector of i1.
624 struct ExtUIPattern final : public OpConversionPattern<arith::ExtUIOp> {
626 
628  matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
629  ConversionPatternRewriter &rewriter) const override {
630  Type srcType = adaptor.getIn().getType();
631  if (isBoolScalarOrVector(srcType))
632  return failure();
633 
634  Type dstType = getTypeConverter()->convertType(op.getType());
635  if (!dstType)
636  return getTypeConversionFailure(rewriter, op);
637 
638  if (dstType == srcType) {
639  // We can have the same source and destination type due to type emulation.
640  // Perform bit masking to make sure we don't pollute downstream consumers
641  // with unwanted bits. Here we need to use the original source type's
642  // bitwidth.
643  unsigned bitwidth =
644  getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth();
646  dstType, llvm::maskTrailingOnes<uint64_t>(bitwidth), rewriter,
647  op.getLoc());
648  rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
649  adaptor.getIn(), mask);
650  } else {
651  rewriter.replaceOpWithNewOp<spirv::UConvertOp>(op, dstType,
652  adaptor.getOperands());
653  }
654  return success();
655  }
656 };
657 
658 //===----------------------------------------------------------------------===//
659 // TruncIOp
660 //===----------------------------------------------------------------------===//
661 
662 /// Converts arith.trunci to spirv.Select if the type of result is i1 or vector
663 /// of i1.
664 struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
666 
668  matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
669  ConversionPatternRewriter &rewriter) const override {
670  Type dstType = getTypeConverter()->convertType(op.getType());
671  if (!dstType)
672  return getTypeConversionFailure(rewriter, op);
673 
674  if (!isBoolScalarOrVector(dstType))
675  return failure();
676 
677  Location loc = op.getLoc();
678  auto srcType = adaptor.getOperands().front().getType();
679  // Check if (x & 1) == 1.
680  Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
681  Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>(
682  loc, srcType, adaptor.getOperands()[0], mask);
683  Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask);
684 
685  Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
686  Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
687  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero);
688  return success();
689  }
690 };
691 
692 /// Converts arith.trunci for cases where the type of result is neither i1
693 /// nor vector of i1.
694 struct TruncIPattern final : public OpConversionPattern<arith::TruncIOp> {
696 
698  matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
699  ConversionPatternRewriter &rewriter) const override {
700  Type srcType = adaptor.getIn().getType();
701  Type dstType = getTypeConverter()->convertType(op.getType());
702  if (!dstType)
703  return getTypeConversionFailure(rewriter, op);
704 
705  if (isBoolScalarOrVector(dstType))
706  return failure();
707 
708  if (dstType == srcType) {
709  // We can have the same source and destination type due to type emulation.
710  // Perform bit masking to make sure we don't pollute downstream consumers
711  // with unwanted bits. Here we need to use the original result type's
712  // bitwidth.
713  unsigned bw = getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth();
715  dstType, llvm::maskTrailingOnes<uint64_t>(bw), rewriter, op.getLoc());
716  rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
717  adaptor.getIn(), mask);
718  } else {
719  // Given this is truncation, either SConvertOp or UConvertOp works.
720  rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
721  adaptor.getOperands());
722  }
723  return success();
724  }
725 };
726 
727 //===----------------------------------------------------------------------===//
728 // TypeCastingOp
729 //===----------------------------------------------------------------------===//
730 
731 /// Converts type-casting standard operations to SPIR-V operations.
732 template <typename Op, typename SPIRVOp>
733 struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
735 
737  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
738  ConversionPatternRewriter &rewriter) const override {
739  assert(adaptor.getOperands().size() == 1);
740  Type srcType = adaptor.getOperands().front().getType();
741  Type dstType = this->getTypeConverter()->convertType(op.getType());
742  if (!dstType)
743  return getTypeConversionFailure(rewriter, op);
744 
745  if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
746  return failure();
747 
748  if (dstType == srcType) {
749  // Due to type conversion, we are seeing the same source and target type.
750  // Then we can just erase this operation by forwarding its operand.
751  rewriter.replaceOp(op, adaptor.getOperands().front());
752  } else {
753  rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType,
754  adaptor.getOperands());
755  }
756  return success();
757  }
758 };
759 
760 //===----------------------------------------------------------------------===//
761 // CmpIOp
762 //===----------------------------------------------------------------------===//
763 
764 /// Converts integer compare operation on i1 type operands to SPIR-V ops.
765 class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
766 public:
768 
770  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
771  ConversionPatternRewriter &rewriter) const override {
772  Type srcType = op.getLhs().getType();
773  if (!isBoolScalarOrVector(srcType))
774  return failure();
775  Type dstType = getTypeConverter()->convertType(srcType);
776  if (!dstType)
777  return getTypeConversionFailure(rewriter, op, srcType);
778 
779  switch (op.getPredicate()) {
780  case arith::CmpIPredicate::eq: {
781  rewriter.replaceOpWithNewOp<spirv::LogicalEqualOp>(op, adaptor.getLhs(),
782  adaptor.getRhs());
783  return success();
784  }
785  case arith::CmpIPredicate::ne: {
786  rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
787  op, adaptor.getLhs(), adaptor.getRhs());
788  return success();
789  }
790  case arith::CmpIPredicate::uge:
791  case arith::CmpIPredicate::ugt:
792  case arith::CmpIPredicate::ule:
793  case arith::CmpIPredicate::ult: {
794  // There are no direct corresponding instructions in SPIR-V for such
795  // cases. Extend them to 32-bit and do comparision then.
796  Type type = rewriter.getI32Type();
797  if (auto vectorType = dyn_cast<VectorType>(dstType))
798  type = VectorType::get(vectorType.getShape(), type);
799  Value extLhs =
800  rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
801  Value extRhs =
802  rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getRhs());
803 
804  rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs,
805  extRhs);
806  return success();
807  }
808  default:
809  break;
810  }
811  return failure();
812  }
813 };
814 
815 /// Converts integer compare operation to SPIR-V ops.
816 class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> {
817 public:
819 
821  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
822  ConversionPatternRewriter &rewriter) const override {
823  Type srcType = op.getLhs().getType();
824  if (isBoolScalarOrVector(srcType))
825  return failure();
826  Type dstType = getTypeConverter()->convertType(srcType);
827  if (!dstType)
828  return getTypeConversionFailure(rewriter, op, srcType);
829 
830  switch (op.getPredicate()) {
831 #define DISPATCH(cmpPredicate, spirvOp) \
832  case cmpPredicate: \
833  if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \
834  !getElementTypeOrSelf(srcType).isIndex() && srcType != dstType && \
835  !hasSameBitwidth(srcType, dstType)) { \
836  return op.emitError( \
837  "bitwidth emulation is not implemented yet on unsigned op"); \
838  } \
839  rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
840  adaptor.getRhs()); \
841  return success();
842 
843  DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
844  DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
845  DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
846  DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
847  DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
848  DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
849  DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
850  DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
851  DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
852  DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
853 
854 #undef DISPATCH
855  }
856  return failure();
857  }
858 };
859 
860 //===----------------------------------------------------------------------===//
861 // CmpFOpPattern
862 //===----------------------------------------------------------------------===//
863 
864 /// Converts floating-point comparison operations to SPIR-V ops.
865 class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> {
866 public:
868 
870  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
871  ConversionPatternRewriter &rewriter) const override {
872  switch (op.getPredicate()) {
873 #define DISPATCH(cmpPredicate, spirvOp) \
874  case cmpPredicate: \
875  rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
876  adaptor.getRhs()); \
877  return success();
878 
879  // Ordered.
880  DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
881  DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
882  DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
883  DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
884  DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
885  DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
886  // Unordered.
887  DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
888  DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
889  DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
890  DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
891  DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
892  DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
893 
894 #undef DISPATCH
895 
896  default:
897  break;
898  }
899  return failure();
900  }
901 };
902 
903 /// Converts floating point NaN check to SPIR-V ops. This pattern requires
904 /// Kernel capability.
905 class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> {
906 public:
908 
910  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
911  ConversionPatternRewriter &rewriter) const override {
912  if (op.getPredicate() == arith::CmpFPredicate::ORD) {
913  rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(),
914  adaptor.getRhs());
915  return success();
916  }
917 
918  if (op.getPredicate() == arith::CmpFPredicate::UNO) {
919  rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(),
920  adaptor.getRhs());
921  return success();
922  }
923 
924  return failure();
925  }
926 };
927 
928 /// Converts floating point NaN check to SPIR-V ops. This pattern does not
929 /// require additional capability.
930 class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
931 public:
933 
935  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
936  ConversionPatternRewriter &rewriter) const override {
937  if (op.getPredicate() != arith::CmpFPredicate::ORD &&
938  op.getPredicate() != arith::CmpFPredicate::UNO)
939  return failure();
940 
941  Location loc = op.getLoc();
942  auto *converter = getTypeConverter<SPIRVTypeConverter>();
943 
944  Value replace;
945  if (converter->getOptions().enableFastMathMode) {
946  if (op.getPredicate() == arith::CmpFPredicate::ORD) {
947  // Ordered comparsion checks if neither operand is NaN.
948  replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter);
949  } else {
950  // Unordered comparsion checks if either operand is NaN.
951  replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter);
952  }
953  } else {
954  Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
955  Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
956 
957  replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
958  if (op.getPredicate() == arith::CmpFPredicate::ORD)
959  replace = rewriter.create<spirv::LogicalNotOp>(loc, replace);
960  }
961 
962  rewriter.replaceOp(op, replace);
963  return success();
964  }
965 };
966 
967 //===----------------------------------------------------------------------===//
968 // AddUIExtendedOp
969 //===----------------------------------------------------------------------===//
970 
971 /// Converts arith.addui_extended to spirv.IAddCarry.
972 class AddUIExtendedOpPattern final
973  : public OpConversionPattern<arith::AddUIExtendedOp> {
974 public:
977  matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
978  ConversionPatternRewriter &rewriter) const override {
979  Type dstElemTy = adaptor.getLhs().getType();
980  Location loc = op->getLoc();
981  Value result = rewriter.create<spirv::IAddCarryOp>(loc, adaptor.getLhs(),
982  adaptor.getRhs());
983 
984  Value sumResult = rewriter.create<spirv::CompositeExtractOp>(
985  loc, result, llvm::ArrayRef(0));
986  Value carryValue = rewriter.create<spirv::CompositeExtractOp>(
987  loc, result, llvm::ArrayRef(1));
988 
989  // Convert the carry value to boolean.
990  Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter);
991  Value carryResult = rewriter.create<spirv::IEqualOp>(loc, carryValue, one);
992 
993  rewriter.replaceOp(op, {sumResult, carryResult});
994  return success();
995  }
996 };
997 
998 //===----------------------------------------------------------------------===//
999 // MulIExtendedOp
1000 //===----------------------------------------------------------------------===//
1001 
1002 /// Converts arith.mul*i_extended to spirv.*MulExtended.
1003 template <typename ArithMulOp, typename SPIRVMulOp>
1004 class MulIExtendedOpPattern final : public OpConversionPattern<ArithMulOp> {
1005 public:
1008  matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
1009  ConversionPatternRewriter &rewriter) const override {
1010  Location loc = op->getLoc();
1011  Value result =
1012  rewriter.create<SPIRVMulOp>(loc, adaptor.getLhs(), adaptor.getRhs());
1013 
1014  Value low = rewriter.create<spirv::CompositeExtractOp>(loc, result,
1015  llvm::ArrayRef(0));
1016  Value high = rewriter.create<spirv::CompositeExtractOp>(loc, result,
1017  llvm::ArrayRef(1));
1018 
1019  rewriter.replaceOp(op, {low, high});
1020  return success();
1021  }
1022 };
1023 
1024 //===----------------------------------------------------------------------===//
1025 // SelectOp
1026 //===----------------------------------------------------------------------===//
1027 
1028 /// Converts arith.select to spirv.Select.
1029 class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> {
1030 public:
1033  matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
1034  ConversionPatternRewriter &rewriter) const override {
1035  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, adaptor.getCondition(),
1036  adaptor.getTrueValue(),
1037  adaptor.getFalseValue());
1038  return success();
1039  }
1040 };
1041 
1042 //===----------------------------------------------------------------------===//
1043 // MinimumFOp, MaximumFOp
1044 //===----------------------------------------------------------------------===//
1045 
1046 /// Converts arith.maximumf/minimumf to spirv.GL.FMax/FMin or
1047 /// spirv.CL.fmax/fmin.
1048 template <typename Op, typename SPIRVOp>
1049 class MinimumMaximumFOpPattern final : public OpConversionPattern<Op> {
1050 public:
1053  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
1054  ConversionPatternRewriter &rewriter) const override {
1055  auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
1056  Type dstType = converter->convertType(op.getType());
1057  if (!dstType)
1058  return getTypeConversionFailure(rewriter, op);
1059 
1060  // arith.maximumf/minimumf:
1061  // "if one of the arguments is NaN, then the result is also NaN."
1062  // spirv.GL.FMax/FMin
1063  // "which operand is the result is undefined if one of the operands
1064  // is a NaN."
1065  // spirv.CL.fmax/fmin:
1066  // "If one argument is a NaN, Fmin returns the other argument."
1067 
1068  Location loc = op.getLoc();
1069  Value spirvOp =
1070  rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
1071 
1072  if (converter->getOptions().enableFastMathMode) {
1073  rewriter.replaceOp(op, spirvOp);
1074  return success();
1075  }
1076 
1077  Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
1078  Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
1079 
1080  Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan,
1081  adaptor.getLhs(), spirvOp);
1082  Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan,
1083  adaptor.getRhs(), select1);
1084 
1085  rewriter.replaceOp(op, select2);
1086  return success();
1087  }
1088 };
1089 
1090 //===----------------------------------------------------------------------===//
1091 // MinNumFOp, MaxNumFOp
1092 //===----------------------------------------------------------------------===//
1093 
1094 /// Converts arith.maxnumf/minnumf to spirv.GL.FMax/FMin or
1095 /// spirv.CL.fmax/fmin.
1096 template <typename Op, typename SPIRVOp>
1097 class MinNumMaxNumFOpPattern final : public OpConversionPattern<Op> {
1098  template <typename TargetOp>
1099  constexpr bool shouldInsertNanGuards() const {
1100  return llvm::is_one_of<TargetOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value;
1101  }
1102 
1103 public:
1106  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
1107  ConversionPatternRewriter &rewriter) const override {
1108  auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
1109  Type dstType = converter->convertType(op.getType());
1110  if (!dstType)
1111  return getTypeConversionFailure(rewriter, op);
1112 
1113  // arith.maxnumf/minnumf:
1114  // "If one of the arguments is NaN, then the result is the other
1115  // argument."
1116  // spirv.GL.FMax/FMin
1117  // "which operand is the result is undefined if one of the operands
1118  // is a NaN."
1119  // spirv.CL.fmax/fmin:
1120  // "If one argument is a NaN, Fmin returns the other argument."
1121 
1122  Location loc = op.getLoc();
1123  Value spirvOp =
1124  rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
1125 
1126  if (!shouldInsertNanGuards<SPIRVOp>() ||
1127  converter->getOptions().enableFastMathMode) {
1128  rewriter.replaceOp(op, spirvOp);
1129  return success();
1130  }
1131 
1132  Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
1133  Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
1134 
1135  Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan,
1136  adaptor.getRhs(), spirvOp);
1137  Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan,
1138  adaptor.getLhs(), select1);
1139 
1140  rewriter.replaceOp(op, select2);
1141  return success();
1142  }
1143 };
1144 
1145 } // namespace
1146 
1147 //===----------------------------------------------------------------------===//
1148 // Pattern Population
1149 //===----------------------------------------------------------------------===//
1150 
1152  SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
1153  // clang-format off
1154  patterns.add<
1155  ConstantCompositeOpPattern,
1156  ConstantScalarOpPattern,
1163  RemSIOpGLPattern, RemSIOpCLPattern,
1164  BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
1165  BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
1166  XOrIOpLogicalPattern, XOrIOpBooleanPattern,
1176  ExtUIPattern, ExtUII1Pattern,
1177  ExtSIPattern, ExtSII1Pattern,
1178  TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
1179  TruncIPattern, TruncII1Pattern,
1180  TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
1181  TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern,
1182  TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
1183  TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
1184  TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
1185  TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
1186  TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
1187  TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
1188  CmpIOpBooleanPattern, CmpIOpPattern,
1189  CmpFOpNanNonePattern, CmpFOpPattern,
1190  AddUIExtendedOpPattern,
1191  MulIExtendedOpPattern<arith::MulSIExtendedOp, spirv::SMulExtendedOp>,
1192  MulIExtendedOpPattern<arith::MulUIExtendedOp, spirv::UMulExtendedOp>,
1193  SelectOpPattern,
1194 
1195  MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::GLFMaxOp>,
1196  MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::GLFMinOp>,
1197  MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::GLFMaxOp>,
1198  MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::GLFMinOp>,
1203 
1204  MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::CLFMaxOp>,
1205  MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::CLFMinOp>,
1206  MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::CLFMaxOp>,
1207  MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::CLFMinOp>,
1212  >(typeConverter, patterns.getContext());
1213  // clang-format on
1214 
1215  // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel
1216  // capability is available.
1217  patterns.add<CmpFOpNanKernelPattern>(typeConverter, patterns.getContext(),
1218  /*benefit=*/2);
1219 }
1220 
1221 //===----------------------------------------------------------------------===//
1222 // Pass Definition
1223 //===----------------------------------------------------------------------===//
1224 
1225 namespace {
1226 struct ConvertArithToSPIRVPass
1227  : public impl::ConvertArithToSPIRVBase<ConvertArithToSPIRVPass> {
1228  void runOnOperation() override {
1229  Operation *op = getOperation();
1231  std::unique_ptr<SPIRVConversionTarget> target =
1232  SPIRVConversionTarget::get(targetAttr);
1233 
1235  options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
1236  options.enableFastMathMode = this->enableFastMath;
1237  SPIRVTypeConverter typeConverter(targetAttr, options);
1238 
1239  // Use UnrealizedConversionCast as the bridge so that we don't need to pull
1240  // in patterns for other dialects.
1241  target->addLegalOp<UnrealizedConversionCastOp>();
1242 
1243  // Fail hard when there are any remaining 'arith' ops.
1244  target->addIllegalDialect<arith::ArithDialect>();
1245 
1246  RewritePatternSet patterns(&getContext());
1247  arith::populateArithToSPIRVPatterns(typeConverter, patterns);
1248 
1249  if (failed(applyPartialConversion(op, *target, std::move(patterns))))
1250  signalPassFailure();
1251  }
1252 };
1253 } // namespace
1254 
1255 std::unique_ptr<OperationPass<>> mlir::arith::createConvertArithToSPIRVPass() {
1256  return std::make_unique<ConvertArithToSPIRVPass>();
1257 }
static bool hasSameBitwidth(Type a, Type b)
Returns true if scalar/vector types a and b have the same bitwidth.
static Value getScalarOrVectorConstInt(Type type, uint64_t value, OpBuilder &builder, Location loc)
Creates a scalar/vector integer constant.
static LogicalResult getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op, Type srcType)
Returns a source type conversion failure for srcType and operation op.
static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder)
Converts the given srcAttr into a boolean attribute if it holds an integral value.
static bool isBoolScalarOrVector(Type type)
Returns true if the given type is a boolean scalar or vector type.
#define DISPATCH(cmpPredicate, spirvOp)
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:238
IntegerType getI32Type()
Definition: Builders.cpp:83
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:116
FloatAttr getF32FloatAttr(float value)
Definition: Builders.cpp:253
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
This class helps build Operations.
Definition: Builders.h:206
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:539
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isInteger(unsigned width) const
Return true if this is an integer type with the specified width.
Definition: Types.cpp:59
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:121
bool isF32() const
Definition: Types.cpp:51
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:123
Type front()
Return first type in the range.
Definition: TypeRange.h:148
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
An attribute that specifies the target version, allowed extensions and capabilities,...
void populateArithToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
std::unique_ptr< OperationPass<> > createConvertArithToSPIRVPass()
LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt abs(const MPInt &x)
Definition: MPInt.h:370
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.
Definition: Pattern.h:23