MLIR  16.0.0git
ArithmeticToSPIRV.cpp
Go to the documentation of this file.
1 //===- ArithmeticToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "../PassDetail.h"
11 #include "../SPIRVCommon/Pattern.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "llvm/Support/Debug.h"
19 
20 #define DEBUG_TYPE "arith-to-spirv-pattern"
21 
22 using namespace mlir;
23 
24 //===----------------------------------------------------------------------===//
25 // Operation Conversion
26 //===----------------------------------------------------------------------===//
27 
28 namespace {
29 
30 /// Converts composite arith.constant operation to spv.Constant.
31 struct ConstantCompositeOpPattern final
32  : public OpConversionPattern<arith::ConstantOp> {
34 
36  matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
37  ConversionPatternRewriter &rewriter) const override;
38 };
39 
40 /// Converts scalar arith.constant operation to spv.Constant.
41 struct ConstantScalarOpPattern final
42  : public OpConversionPattern<arith::ConstantOp> {
44 
46  matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
47  ConversionPatternRewriter &rewriter) const override;
48 };
49 
50 /// Converts arith.remsi to GLSL SPIR-V ops.
51 ///
52 /// This cannot be merged into the template unary/binary pattern due to Vulkan
53 /// restrictions over spv.SRem and spv.SMod.
54 struct RemSIOpGLPattern final : public OpConversionPattern<arith::RemSIOp> {
56 
58  matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
59  ConversionPatternRewriter &rewriter) const override;
60 };
61 
62 /// Converts arith.remsi to OpenCL SPIR-V ops.
63 struct RemSIOpCLPattern final : public OpConversionPattern<arith::RemSIOp> {
65 
67  matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
68  ConversionPatternRewriter &rewriter) const override;
69 };
70 
71 /// Converts bitwise operations to SPIR-V operations. This is a special pattern
72 /// other than the BinaryOpPatternPattern because if the operands are boolean
73 /// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For
74 /// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`.
75 template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
76 struct BitwiseOpPattern final : public OpConversionPattern<Op> {
78 
80  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
81  ConversionPatternRewriter &rewriter) const override;
82 };
83 
84 /// Converts arith.xori to SPIR-V operations.
85 struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> {
87 
89  matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
90  ConversionPatternRewriter &rewriter) const override;
91 };
92 
93 /// Converts arith.xori to SPIR-V operations if the type of source is i1 or
94 /// vector of i1.
95 struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
97 
99  matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
100  ConversionPatternRewriter &rewriter) const override;
101 };
102 
103 /// Converts arith.uitofp to spv.Select if the type of source is i1 or vector of
104 /// i1.
105 struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
107 
109  matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
110  ConversionPatternRewriter &rewriter) const override;
111 };
112 
113 /// Converts arith.extui to spv.Select if the type of source is i1 or vector of
114 /// i1.
115 struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> {
117 
119  matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
120  ConversionPatternRewriter &rewriter) const override;
121 };
122 
123 /// Converts arith.trunci to spv.Select if the type of result is i1 or vector of
124 /// i1.
125 struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
127 
129  matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
130  ConversionPatternRewriter &rewriter) const override;
131 };
132 
133 /// Converts type-casting standard operations to SPIR-V operations.
134 template <typename Op, typename SPIRVOp>
135 struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
137 
139  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
140  ConversionPatternRewriter &rewriter) const override;
141 };
142 
143 /// Converts integer compare operation on i1 type operands to SPIR-V ops.
144 class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
145 public:
147 
149  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
150  ConversionPatternRewriter &rewriter) const override;
151 };
152 
153 /// Converts integer compare operation to SPIR-V ops.
154 class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> {
155 public:
157 
159  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
160  ConversionPatternRewriter &rewriter) const override;
161 };
162 
163 /// Converts floating-point comparison operations to SPIR-V ops.
164 class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> {
165 public:
167 
169  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
170  ConversionPatternRewriter &rewriter) const override;
171 };
172 
173 /// Converts floating point NaN check to SPIR-V ops. This pattern requires
174 /// Kernel capability.
175 class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> {
176 public:
178 
180  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
181  ConversionPatternRewriter &rewriter) const override;
182 };
183 
184 /// Converts floating point NaN check to SPIR-V ops. This pattern does not
185 /// require additional capability.
186 class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
187 public:
189 
191  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
192  ConversionPatternRewriter &rewriter) const override;
193 };
194 
195 /// Converts arith.select to spv.Select.
196 class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> {
197 public:
200  matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
201  ConversionPatternRewriter &rewriter) const override;
202 };
203 
204 } // namespace
205 
206 //===----------------------------------------------------------------------===//
207 // Conversion Helpers
208 //===----------------------------------------------------------------------===//
209 
210 /// Converts the given `srcAttr` into a boolean attribute if it holds an
211 /// integral value. Returns null attribute if conversion fails.
212 static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) {
213  if (auto boolAttr = srcAttr.dyn_cast<BoolAttr>())
214  return boolAttr;
215  if (auto intAttr = srcAttr.dyn_cast<IntegerAttr>())
216  return builder.getBoolAttr(intAttr.getValue().getBoolValue());
217  return BoolAttr();
218 }
219 
220 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
221 /// Returns null attribute if conversion fails.
222 static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType,
223  Builder builder) {
224  // If the source number uses less active bits than the target bitwidth, then
225  // it should be safe to convert.
226  if (srcAttr.getValue().isIntN(dstType.getWidth()))
227  return builder.getIntegerAttr(dstType, srcAttr.getInt());
228 
229  // XXX: Try again by interpreting the source number as a signed value.
230  // Although integers in the standard dialect are signless, they can represent
231  // a signed number. It's the operation decides how to interpret. This is
232  // dangerous, but it seems there is no good way of handling this if we still
233  // want to change the bitwidth. Emit a message at least.
234  if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
235  auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt());
236  LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '"
237  << dstAttr << "' for type '" << dstType << "'\n");
238  return dstAttr;
239  }
240 
241  LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr
242  << "' illegal: cannot fit into target type '"
243  << dstType << "'\n");
244  return IntegerAttr();
245 }
246 
247 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
248 /// Returns null attribute if `dstType` is not 32-bit or conversion fails.
249 static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
250  Builder builder) {
251  // Only support converting to float for now.
252  if (!dstType.isF32())
253  return FloatAttr();
254 
255  // Try to convert the source floating-point number to single precision.
256  APFloat dstVal = srcAttr.getValue();
257  bool losesInfo = false;
258  APFloat::opStatus status =
259  dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
260  if (status != APFloat::opOK || losesInfo) {
261  LLVM_DEBUG(llvm::dbgs()
262  << srcAttr << " illegal: cannot fit into converted type '"
263  << dstType << "'\n");
264  return FloatAttr();
265  }
266 
267  return builder.getF32FloatAttr(dstVal.convertToFloat());
268 }
269 
270 /// Returns true if the given `type` is a boolean scalar or vector type.
271 static bool isBoolScalarOrVector(Type type) {
272  if (type.isInteger(1))
273  return true;
274  if (auto vecType = type.dyn_cast<VectorType>())
275  return vecType.getElementType().isInteger(1);
276  return false;
277 }
278 
279 /// Returns true if scalar/vector type `a` and `b` have the same number of
280 /// bitwidth.
281 static bool hasSameBitwidth(Type a, Type b) {
282  auto getNumBitwidth = [](Type type) {
283  unsigned bw = 0;
284  if (type.isIntOrFloat())
285  bw = type.getIntOrFloatBitWidth();
286  else if (auto vecType = type.dyn_cast<VectorType>())
287  bw = vecType.getElementTypeBitWidth() * vecType.getNumElements();
288  return bw;
289  };
290  unsigned aBW = getNumBitwidth(a);
291  unsigned bBW = getNumBitwidth(b);
292  return aBW != 0 && bBW != 0 && aBW == bBW;
293 }
294 
295 //===----------------------------------------------------------------------===//
296 // ConstantOp with composite type
297 //===----------------------------------------------------------------------===//
298 
299 LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
300  arith::ConstantOp constOp, OpAdaptor adaptor,
301  ConversionPatternRewriter &rewriter) const {
302  auto srcType = constOp.getType().dyn_cast<ShapedType>();
303  if (!srcType || srcType.getNumElements() == 1)
304  return failure();
305 
306  // arith.constant should only have vector or tenor types.
307  assert((srcType.isa<VectorType, RankedTensorType>()));
308 
309  auto dstType = getTypeConverter()->convertType(srcType);
310  if (!dstType)
311  return failure();
312 
313  auto dstElementsAttr = constOp.getValue().dyn_cast<DenseElementsAttr>();
314  if (!dstElementsAttr)
315  return failure();
316 
317  ShapedType dstAttrType = dstElementsAttr.getType();
318 
319  // If the composite type has more than one dimensions, perform linearization.
320  if (srcType.getRank() > 1) {
321  if (srcType.isa<RankedTensorType>()) {
322  dstAttrType = RankedTensorType::get(srcType.getNumElements(),
323  srcType.getElementType());
324  dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
325  } else {
326  // TODO: add support for large vectors.
327  return failure();
328  }
329  }
330 
331  Type srcElemType = srcType.getElementType();
332  Type dstElemType;
333  // Tensor types are converted to SPIR-V array types; vector types are
334  // converted to SPIR-V vector/array types.
335  if (auto arrayType = dstType.dyn_cast<spirv::ArrayType>())
336  dstElemType = arrayType.getElementType();
337  else
338  dstElemType = dstType.cast<VectorType>().getElementType();
339 
340  // If the source and destination element types are different, perform
341  // attribute conversion.
342  if (srcElemType != dstElemType) {
343  SmallVector<Attribute, 8> elements;
344  if (srcElemType.isa<FloatType>()) {
345  for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
346  FloatAttr dstAttr =
347  convertFloatAttr(srcAttr, dstElemType.cast<FloatType>(), rewriter);
348  if (!dstAttr)
349  return failure();
350  elements.push_back(dstAttr);
351  }
352  } else if (srcElemType.isInteger(1)) {
353  return failure();
354  } else {
355  for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) {
356  IntegerAttr dstAttr = convertIntegerAttr(
357  srcAttr, dstElemType.cast<IntegerType>(), rewriter);
358  if (!dstAttr)
359  return failure();
360  elements.push_back(dstAttr);
361  }
362  }
363 
364  // Unfortunately, we cannot use dialect-specific types for element
365  // attributes; element attributes only works with builtin types. So we need
366  // to prepare another converted builtin types for the destination elements
367  // attribute.
368  if (dstAttrType.isa<RankedTensorType>())
369  dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType);
370  else
371  dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
372 
373  dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
374  }
375 
376  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
377  dstElementsAttr);
378  return success();
379 }
380 
381 //===----------------------------------------------------------------------===//
382 // ConstantOp with scalar type
383 //===----------------------------------------------------------------------===//
384 
385 LogicalResult ConstantScalarOpPattern::matchAndRewrite(
386  arith::ConstantOp constOp, OpAdaptor adaptor,
387  ConversionPatternRewriter &rewriter) const {
388  Type srcType = constOp.getType();
389  if (auto shapedType = srcType.dyn_cast<ShapedType>()) {
390  if (shapedType.getNumElements() != 1)
391  return failure();
392  srcType = shapedType.getElementType();
393  }
394  if (!srcType.isIntOrIndexOrFloat())
395  return failure();
396 
397  Attribute cstAttr = constOp.getValue();
398  if (auto elementsAttr = cstAttr.dyn_cast<DenseElementsAttr>())
399  cstAttr = elementsAttr.getSplatValue<Attribute>();
400 
401  Type dstType = getTypeConverter()->convertType(srcType);
402  if (!dstType)
403  return failure();
404 
405  // Floating-point types.
406  if (srcType.isa<FloatType>()) {
407  auto srcAttr = cstAttr.cast<FloatAttr>();
408  auto dstAttr = srcAttr;
409 
410  // Floating-point types not supported in the target environment are all
411  // converted to float type.
412  if (srcType != dstType) {
413  dstAttr = convertFloatAttr(srcAttr, dstType.cast<FloatType>(), rewriter);
414  if (!dstAttr)
415  return failure();
416  }
417 
418  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
419  return success();
420  }
421 
422  // Bool type.
423  if (srcType.isInteger(1)) {
424  // arith.constant can use 0/1 instead of true/false for i1 values. We need
425  // to handle that here.
426  auto dstAttr = convertBoolAttr(cstAttr, rewriter);
427  if (!dstAttr)
428  return failure();
429  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
430  return success();
431  }
432 
433  // IndexType or IntegerType. Index values are converted to 32-bit integer
434  // values when converting to SPIR-V.
435  auto srcAttr = cstAttr.cast<IntegerAttr>();
436  auto dstAttr =
437  convertIntegerAttr(srcAttr, dstType.cast<IntegerType>(), rewriter);
438  if (!dstAttr)
439  return failure();
440  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
441  return success();
442 }
443 
444 //===----------------------------------------------------------------------===//
445 // RemSIOpGLPattern
446 //===----------------------------------------------------------------------===//
447 
448 /// Returns signed remainder for `lhs` and `rhs` and lets the result follow
449 /// the sign of `signOperand`.
450 ///
451 /// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment
452 /// spec, "for the OpSRem and OpSMod instructions, if either operand is negative
453 /// the result is undefined." So we cannot directly use spv.SRem/spv.SMod
454 /// if either operand can be negative. Emulate it via spv.UMod.
455 template <typename SignedAbsOp>
457  Value signOperand, OpBuilder &builder) {
458  assert(lhs.getType() == rhs.getType());
459  assert(lhs == signOperand || rhs == signOperand);
460 
461  Type type = lhs.getType();
462 
463  // Calculate the remainder with spv.UMod.
464  Value lhsAbs = builder.create<SignedAbsOp>(loc, type, lhs);
465  Value rhsAbs = builder.create<SignedAbsOp>(loc, type, rhs);
466  Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs);
467 
468  // Fix the sign.
469  Value isPositive;
470  if (lhs == signOperand)
471  isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs);
472  else
473  isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs);
474  Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs);
475  return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate);
476 }
477 
479 RemSIOpGLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
480  ConversionPatternRewriter &rewriter) const {
481  Value result = emulateSignedRemainder<spirv::GLSAbsOp>(
482  op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
483  adaptor.getOperands()[0], rewriter);
484  rewriter.replaceOp(op, result);
485 
486  return success();
487 }
488 
489 //===----------------------------------------------------------------------===//
490 // RemSIOpCLPattern
491 //===----------------------------------------------------------------------===//
492 
494 RemSIOpCLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
495  ConversionPatternRewriter &rewriter) const {
496  Value result = emulateSignedRemainder<spirv::CLSAbsOp>(
497  op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
498  adaptor.getOperands()[0], rewriter);
499  rewriter.replaceOp(op, result);
500 
501  return success();
502 }
503 
504 //===----------------------------------------------------------------------===//
505 // BitwiseOpPattern
506 //===----------------------------------------------------------------------===//
507 
508 template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
510 BitwiseOpPattern<Op, SPIRVLogicalOp, SPIRVBitwiseOp>::matchAndRewrite(
511  Op op, typename Op::Adaptor adaptor,
512  ConversionPatternRewriter &rewriter) const {
513  assert(adaptor.getOperands().size() == 2);
514  auto dstType =
515  this->getTypeConverter()->convertType(op.getResult().getType());
516  if (!dstType)
517  return failure();
518  if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) {
519  rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(op, dstType,
520  adaptor.getOperands());
521  } else {
522  rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(op, dstType,
523  adaptor.getOperands());
524  }
525  return success();
526 }
527 
528 //===----------------------------------------------------------------------===//
529 // XOrIOpLogicalPattern
530 //===----------------------------------------------------------------------===//
531 
532 LogicalResult XOrIOpLogicalPattern::matchAndRewrite(
533  arith::XOrIOp op, OpAdaptor adaptor,
534  ConversionPatternRewriter &rewriter) const {
535  assert(adaptor.getOperands().size() == 2);
536 
537  if (isBoolScalarOrVector(adaptor.getOperands().front().getType()))
538  return failure();
539 
540  auto dstType = getTypeConverter()->convertType(op.getType());
541  if (!dstType)
542  return failure();
543  rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
544  adaptor.getOperands());
545 
546  return success();
547 }
548 
549 //===----------------------------------------------------------------------===//
550 // XOrIOpBooleanPattern
551 //===----------------------------------------------------------------------===//
552 
553 LogicalResult XOrIOpBooleanPattern::matchAndRewrite(
554  arith::XOrIOp op, OpAdaptor adaptor,
555  ConversionPatternRewriter &rewriter) const {
556  assert(adaptor.getOperands().size() == 2);
557 
558  if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
559  return failure();
560 
561  auto dstType = getTypeConverter()->convertType(op.getType());
562  if (!dstType)
563  return failure();
564  rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(op, dstType,
565  adaptor.getOperands());
566  return success();
567 }
568 
569 //===----------------------------------------------------------------------===//
570 // UIToFPI1Pattern
571 //===----------------------------------------------------------------------===//
572 
574 UIToFPI1Pattern::matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
575  ConversionPatternRewriter &rewriter) const {
576  auto srcType = adaptor.getOperands().front().getType();
577  if (!isBoolScalarOrVector(srcType))
578  return failure();
579 
580  auto dstType =
581  this->getTypeConverter()->convertType(op.getResult().getType());
582  Location loc = op.getLoc();
583  Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
584  Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
585  rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
586  op, dstType, adaptor.getOperands().front(), one, zero);
587  return success();
588 }
589 
590 //===----------------------------------------------------------------------===//
591 // ExtUII1Pattern
592 //===----------------------------------------------------------------------===//
593 
595 ExtUII1Pattern::matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
596  ConversionPatternRewriter &rewriter) const {
597  auto srcType = adaptor.getOperands().front().getType();
598  if (!isBoolScalarOrVector(srcType))
599  return failure();
600 
601  auto dstType =
602  this->getTypeConverter()->convertType(op.getResult().getType());
603  Location loc = op.getLoc();
604  Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
605  Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
606  rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
607  op, dstType, adaptor.getOperands().front(), one, zero);
608  return success();
609 }
610 
611 //===----------------------------------------------------------------------===//
612 // TruncII1Pattern
613 //===----------------------------------------------------------------------===//
614 
616 TruncII1Pattern::matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
617  ConversionPatternRewriter &rewriter) const {
618  auto dstType =
619  this->getTypeConverter()->convertType(op.getResult().getType());
620  if (!isBoolScalarOrVector(dstType))
621  return failure();
622 
623  Location loc = op.getLoc();
624  auto srcType = adaptor.getOperands().front().getType();
625  // Check if (x & 1) == 1.
626  Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
627  Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>(
628  loc, srcType, adaptor.getOperands()[0], mask);
629  Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask);
630 
631  Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
632  Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
633  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero);
634  return success();
635 }
636 
637 //===----------------------------------------------------------------------===//
638 // TypeCastingOpPattern
639 //===----------------------------------------------------------------------===//
640 
641 template <typename Op, typename SPIRVOp>
642 LogicalResult TypeCastingOpPattern<Op, SPIRVOp>::matchAndRewrite(
643  Op op, typename Op::Adaptor adaptor,
644  ConversionPatternRewriter &rewriter) const {
645  assert(adaptor.getOperands().size() == 1);
646  auto srcType = adaptor.getOperands().front().getType();
647  auto dstType =
648  this->getTypeConverter()->convertType(op.getResult().getType());
649  if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
650  return failure();
651  if (dstType == srcType) {
652  // Due to type conversion, we are seeing the same source and target type.
653  // Then we can just erase this operation by forwarding its operand.
654  rewriter.replaceOp(op, adaptor.getOperands().front());
655  } else {
656  rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType,
657  adaptor.getOperands());
658  }
659  return success();
660 }
661 
662 //===----------------------------------------------------------------------===//
663 // CmpIOpBooleanPattern
664 //===----------------------------------------------------------------------===//
665 
666 LogicalResult CmpIOpBooleanPattern::matchAndRewrite(
667  arith::CmpIOp op, OpAdaptor adaptor,
668  ConversionPatternRewriter &rewriter) const {
669  Type srcType = op.getLhs().getType();
670  if (!isBoolScalarOrVector(srcType))
671  return failure();
672  Type dstType = getTypeConverter()->convertType(srcType);
673  if (!dstType)
674  return failure();
675 
676  switch (op.getPredicate()) {
677  case arith::CmpIPredicate::eq: {
678  rewriter.replaceOpWithNewOp<spirv::LogicalEqualOp>(op, adaptor.getLhs(),
679  adaptor.getRhs());
680  return success();
681  }
682  case arith::CmpIPredicate::ne: {
683  rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(op, adaptor.getLhs(),
684  adaptor.getRhs());
685  return success();
686  }
687  case arith::CmpIPredicate::uge:
688  case arith::CmpIPredicate::ugt:
689  case arith::CmpIPredicate::ule:
690  case arith::CmpIPredicate::ult: {
691  // There are no direct corresponding instructions in SPIR-V for such cases.
692  // Extend them to 32-bit and do comparision then.
693  Type type = rewriter.getI32Type();
694  if (auto vectorType = dstType.dyn_cast<VectorType>())
695  type = VectorType::get(vectorType.getShape(), type);
696  auto extLhs =
697  rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
698  auto extRhs =
699  rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getRhs());
700 
701  rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs,
702  extRhs);
703  return success();
704  }
705  default:
706  break;
707  }
708  return failure();
709 }
710 
711 //===----------------------------------------------------------------------===//
712 // CmpIOpPattern
713 //===----------------------------------------------------------------------===//
714 
716 CmpIOpPattern::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
717  ConversionPatternRewriter &rewriter) const {
718  Type srcType = op.getLhs().getType();
719  if (isBoolScalarOrVector(srcType))
720  return failure();
721  Type dstType = getTypeConverter()->convertType(srcType);
722  if (!dstType)
723  return failure();
724 
725  switch (op.getPredicate()) {
726 #define DISPATCH(cmpPredicate, spirvOp) \
727  case cmpPredicate: \
728  if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \
729  srcType != dstType && !hasSameBitwidth(srcType, dstType)) { \
730  return op.emitError( \
731  "bitwidth emulation is not implemented yet on unsigned op"); \
732  } \
733  rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
734  adaptor.getRhs()); \
735  return success();
736 
737  DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
738  DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
739  DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
740  DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
741  DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
742  DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
743  DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
744  DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
745  DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
746  DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
747 
748 #undef DISPATCH
749  }
750  return failure();
751 }
752 
753 //===----------------------------------------------------------------------===//
754 // CmpFOpPattern
755 //===----------------------------------------------------------------------===//
756 
758 CmpFOpPattern::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
759  ConversionPatternRewriter &rewriter) const {
760  switch (op.getPredicate()) {
761 #define DISPATCH(cmpPredicate, spirvOp) \
762  case cmpPredicate: \
763  rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
764  adaptor.getRhs()); \
765  return success();
766 
767  // Ordered.
768  DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
769  DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
770  DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
771  DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
772  DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
773  DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
774  // Unordered.
775  DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
776  DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
777  DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
778  DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
779  DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
780  DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
781 
782 #undef DISPATCH
783 
784  default:
785  break;
786  }
787  return failure();
788 }
789 
790 //===----------------------------------------------------------------------===//
791 // CmpFOpNanKernelPattern
792 //===----------------------------------------------------------------------===//
793 
794 LogicalResult CmpFOpNanKernelPattern::matchAndRewrite(
795  arith::CmpFOp op, OpAdaptor adaptor,
796  ConversionPatternRewriter &rewriter) const {
797  if (op.getPredicate() == arith::CmpFPredicate::ORD) {
798  rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(),
799  adaptor.getRhs());
800  return success();
801  }
802 
803  if (op.getPredicate() == arith::CmpFPredicate::UNO) {
804  rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(),
805  adaptor.getRhs());
806  return success();
807  }
808 
809  return failure();
810 }
811 
812 //===----------------------------------------------------------------------===//
813 // CmpFOpNanNonePattern
814 //===----------------------------------------------------------------------===//
815 
816 LogicalResult CmpFOpNanNonePattern::matchAndRewrite(
817  arith::CmpFOp op, OpAdaptor adaptor,
818  ConversionPatternRewriter &rewriter) const {
819  if (op.getPredicate() != arith::CmpFPredicate::ORD &&
820  op.getPredicate() != arith::CmpFPredicate::UNO)
821  return failure();
822 
823  Location loc = op.getLoc();
824 
825  Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
826  Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
827 
828  Value replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
829  if (op.getPredicate() == arith::CmpFPredicate::ORD)
830  replace = rewriter.create<spirv::LogicalNotOp>(loc, replace);
831 
832  rewriter.replaceOp(op, replace);
833  return success();
834 }
835 
836 //===----------------------------------------------------------------------===//
837 // SelectOpPattern
838 //===----------------------------------------------------------------------===//
839 
841 SelectOpPattern::matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
842  ConversionPatternRewriter &rewriter) const {
843  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, adaptor.getCondition(),
844  adaptor.getTrueValue(),
845  adaptor.getFalseValue());
846  return success();
847 }
848 
849 //===----------------------------------------------------------------------===//
850 // Pattern Population
851 //===----------------------------------------------------------------------===//
852 
854  SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
855  // clang-format off
856  patterns.add<
857  ConstantCompositeOpPattern,
858  ConstantScalarOpPattern,
865  RemSIOpGLPattern, RemSIOpCLPattern,
866  BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
867  BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
868  XOrIOpLogicalPattern, XOrIOpBooleanPattern,
878  TypeCastingOpPattern<arith::ExtUIOp, spirv::UConvertOp>, ExtUII1Pattern,
879  TypeCastingOpPattern<arith::ExtSIOp, spirv::SConvertOp>,
880  TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
881  TypeCastingOpPattern<arith::TruncIOp, spirv::SConvertOp>, TruncII1Pattern,
882  TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
883  TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern,
884  TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
885  TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
886  TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
887  TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
888  CmpIOpBooleanPattern, CmpIOpPattern,
889  CmpFOpNanNonePattern, CmpFOpPattern,
890  SelectOpPattern,
891 
898  >(typeConverter, patterns.getContext());
899  // clang-format on
900 
901  // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel
902  // capability is available.
903  patterns.add<CmpFOpNanKernelPattern>(typeConverter, patterns.getContext(),
904  /*benefit=*/2);
905 }
906 
907 //===----------------------------------------------------------------------===//
908 // Pass Definition
909 //===----------------------------------------------------------------------===//
910 
911 namespace {
912 struct ConvertArithmeticToSPIRVPass
913  : public ConvertArithmeticToSPIRVBase<ConvertArithmeticToSPIRVPass> {
914  void runOnOperation() override {
915  Operation *op = getOperation();
916  auto targetAttr = spirv::lookupTargetEnvOrDefault(op);
917  auto target = SPIRVConversionTarget::get(targetAttr);
918 
920  options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
921  SPIRVTypeConverter typeConverter(targetAttr, options);
922 
923  // Use UnrealizedConversionCast as the bridge so that we don't need to pull
924  // in patterns for other dialects.
925  auto addUnrealizedCast = [](OpBuilder &builder, Type type,
926  ValueRange inputs, Location loc) {
927  auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
928  return Optional<Value>(cast.getResult(0));
929  };
930  typeConverter.addSourceMaterialization(addUnrealizedCast);
931  typeConverter.addTargetMaterialization(addUnrealizedCast);
932  target->addLegalOp<UnrealizedConversionCastOp>();
933 
934  RewritePatternSet patterns(&getContext());
935  arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns);
936 
937  if (failed(applyPartialConversion(op, *target, std::move(patterns))))
938  signalPassFailure();
939  }
940 };
941 } // namespace
942 
943 std::unique_ptr<OperationPass<>>
945  return std::make_unique<ConvertArithmeticToSPIRVPass>();
946 }
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
Include the generated interface declarations.
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
bool isF32() const
Definition: Types.cpp:23
static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs, Value signOperand, OpBuilder &builder)
Returns signed remainder for lhs and rhs and lets the result follow the sign of signOperand.
U cast() const
Definition: Attributes.h:135
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt abs(const MPInt &x)
Definition: MPInt.h:369
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
LogicalResult applyPartialConversion(ArrayRef< Operation *> ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation *> *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:89
void populateArithmeticToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
bool isInteger(unsigned width) const
Return true if this is an integer type with the specified width.
Definition: Types.cpp:31
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
Definition: SPIRVOps.cpp:685
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.
Definition: Pattern.h:21
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting type from an illegal...
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:404
static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
An attribute that represents a reference to a dense vector or tensor object.
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:194
U dyn_cast() const
Definition: Types.h:270
Attributes are known-constant values of operations.
Definition: Attributes.h:24
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
#define DISPATCH(cmpPredicate, spirvOp)
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
void addSourceMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a legal type to an illega...
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
static llvm::ManagedStatic< PassManagerOptions > options
This class is a general helper class for creating context-global objects like types, attributes, and affine expressions.
Definition: Builders.h:49
Type getType() const
Return the type of this value.
Definition: Value.h:118
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:114
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
U dyn_cast() const
Definition: Attributes.h:127
static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder)
Converts the given srcAttr into a boolean attribute if it holds an integral value.
static VectorType vectorType(CodeGen &codegen, Type etp)
Constructs vector type.
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:87
This class implements a pattern rewriter for use with ConversionPatterns.
This provides public APIs that all operations should have.
std::unique_ptr< OperationPass<> > createConvertArithmeticToSPIRVPass()
bool isa() const
Definition: Types.h:254
FloatAttr getF32FloatAttr(float value)
Definition: Builders.cpp:209
static bool hasSameBitwidth(Type a, Type b)
Returns true if scalar/vector type a and b have the same number of bitwidth.
This class helps build Operations.
Definition: Builders.h:192
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:345
static bool isBoolScalarOrVector(Type type)
Returns true if the given type is a boolean scalar or vector type.
MLIRContext * getContext() const
IntegerType getI32Type()
Definition: Builders.cpp:54
bool emulateNon32BitScalarTypes
Whether to emulate non-32-bit scalar types with 32-bit scalar types if no native support.
Type conversion from builtin types to SPIR-V types for shader interface.
U cast() const
Definition: Types.h:278