MLIR  14.0.0git
ArithmeticToSPIRV.cpp
Go to the documentation of this file.
1 //===- ArithmeticToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "../PassDetail.h"
11 #include "../SPIRVCommon/Pattern.h"
16 #include "llvm/Support/Debug.h"
17 
18 #define DEBUG_TYPE "arith-to-spirv-pattern"
19 
20 using namespace mlir;
21 
22 //===----------------------------------------------------------------------===//
23 // Operation Conversion
24 //===----------------------------------------------------------------------===//
25 
26 namespace {
27 
28 /// Converts composite arith.constant operation to spv.Constant.
29 struct ConstantCompositeOpPattern final
30  : public OpConversionPattern<arith::ConstantOp> {
32 
34  matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
35  ConversionPatternRewriter &rewriter) const override;
36 };
37 
38 /// Converts scalar arith.constant operation to spv.Constant.
39 struct ConstantScalarOpPattern final
40  : public OpConversionPattern<arith::ConstantOp> {
42 
44  matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
45  ConversionPatternRewriter &rewriter) const override;
46 };
47 
48 /// Converts arith.remsi to GLSL SPIR-V ops.
49 ///
50 /// This cannot be merged into the template unary/binary pattern due to Vulkan
51 /// restrictions over spv.SRem and spv.SMod.
52 struct RemSIOpGLSLPattern final : public OpConversionPattern<arith::RemSIOp> {
54 
56  matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
57  ConversionPatternRewriter &rewriter) const override;
58 };
59 
60 /// Converts arith.remsi to OpenCL SPIR-V ops.
61 struct RemSIOpOCLPattern final : public OpConversionPattern<arith::RemSIOp> {
63 
65  matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
66  ConversionPatternRewriter &rewriter) const override;
67 };
68 
69 /// Converts bitwise operations to SPIR-V operations. This is a special pattern
70 /// other than the BinaryOpPatternPattern because if the operands are boolean
71 /// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For
72 /// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`.
73 template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
74 struct BitwiseOpPattern final : public OpConversionPattern<Op> {
76 
78  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
79  ConversionPatternRewriter &rewriter) const override;
80 };
81 
82 /// Converts arith.xori to SPIR-V operations.
83 struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> {
85 
87  matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
88  ConversionPatternRewriter &rewriter) const override;
89 };
90 
91 /// Converts arith.xori to SPIR-V operations if the type of source is i1 or
92 /// vector of i1.
93 struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
95 
97  matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
98  ConversionPatternRewriter &rewriter) const override;
99 };
100 
101 /// Converts arith.uitofp to spv.Select if the type of source is i1 or vector of
102 /// i1.
103 struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
105 
107  matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
108  ConversionPatternRewriter &rewriter) const override;
109 };
110 
111 /// Converts arith.extui to spv.Select if the type of source is i1 or vector of
112 /// i1.
113 struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> {
115 
117  matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
118  ConversionPatternRewriter &rewriter) const override;
119 };
120 
121 /// Converts arith.trunci to spv.Select if the type of result is i1 or vector of
122 /// i1.
123 struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
125 
127  matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
128  ConversionPatternRewriter &rewriter) const override;
129 };
130 
131 /// Converts type-casting standard operations to SPIR-V operations.
132 template <typename Op, typename SPIRVOp>
133 struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
135 
137  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
138  ConversionPatternRewriter &rewriter) const override;
139 };
140 
141 /// Converts integer compare operation on i1 type operands to SPIR-V ops.
142 class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
143 public:
145 
147  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
148  ConversionPatternRewriter &rewriter) const override;
149 };
150 
151 /// Converts integer compare operation to SPIR-V ops.
152 class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> {
153 public:
155 
157  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
158  ConversionPatternRewriter &rewriter) const override;
159 };
160 
161 /// Converts floating-point comparison operations to SPIR-V ops.
162 class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> {
163 public:
165 
167  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
168  ConversionPatternRewriter &rewriter) const override;
169 };
170 
171 /// Converts floating point NaN check to SPIR-V ops. This pattern requires
172 /// Kernel capability.
173 class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> {
174 public:
176 
178  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
179  ConversionPatternRewriter &rewriter) const override;
180 };
181 
182 /// Converts floating point NaN check to SPIR-V ops. This pattern does not
183 /// require additional capability.
184 class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
185 public:
187 
189  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
190  ConversionPatternRewriter &rewriter) const override;
191 };
192 
193 } // namespace
194 
195 //===----------------------------------------------------------------------===//
196 // Conversion Helpers
197 //===----------------------------------------------------------------------===//
198 
199 /// Converts the given `srcAttr` into a boolean attribute if it holds an
200 /// integral value. Returns null attribute if conversion fails.
201 static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) {
202  if (auto boolAttr = srcAttr.dyn_cast<BoolAttr>())
203  return boolAttr;
204  if (auto intAttr = srcAttr.dyn_cast<IntegerAttr>())
205  return builder.getBoolAttr(intAttr.getValue().getBoolValue());
206  return BoolAttr();
207 }
208 
209 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
210 /// Returns null attribute if conversion fails.
211 static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType,
212  Builder builder) {
213  // If the source number uses less active bits than the target bitwidth, then
214  // it should be safe to convert.
215  if (srcAttr.getValue().isIntN(dstType.getWidth()))
216  return builder.getIntegerAttr(dstType, srcAttr.getInt());
217 
218  // XXX: Try again by interpreting the source number as a signed value.
219  // Although integers in the standard dialect are signless, they can represent
220  // a signed number. It's the operation decides how to interpret. This is
221  // dangerous, but it seems there is no good way of handling this if we still
222  // want to change the bitwidth. Emit a message at least.
223  if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
224  auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt());
225  LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '"
226  << dstAttr << "' for type '" << dstType << "'\n");
227  return dstAttr;
228  }
229 
230  LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr
231  << "' illegal: cannot fit into target type '"
232  << dstType << "'\n");
233  return IntegerAttr();
234 }
235 
236 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
237 /// Returns null attribute if `dstType` is not 32-bit or conversion fails.
238 static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
239  Builder builder) {
240  // Only support converting to float for now.
241  if (!dstType.isF32())
242  return FloatAttr();
243 
244  // Try to convert the source floating-point number to single precision.
245  APFloat dstVal = srcAttr.getValue();
246  bool losesInfo = false;
247  APFloat::opStatus status =
248  dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
249  if (status != APFloat::opOK || losesInfo) {
250  LLVM_DEBUG(llvm::dbgs()
251  << srcAttr << " illegal: cannot fit into converted type '"
252  << dstType << "'\n");
253  return FloatAttr();
254  }
255 
256  return builder.getF32FloatAttr(dstVal.convertToFloat());
257 }
258 
259 /// Returns true if the given `type` is a boolean scalar or vector type.
260 static bool isBoolScalarOrVector(Type type) {
261  if (type.isInteger(1))
262  return true;
263  if (auto vecType = type.dyn_cast<VectorType>())
264  return vecType.getElementType().isInteger(1);
265  return false;
266 }
267 
268 //===----------------------------------------------------------------------===//
269 // ConstantOp with composite type
270 //===----------------------------------------------------------------------===//
271 
272 LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
273  arith::ConstantOp constOp, OpAdaptor adaptor,
274  ConversionPatternRewriter &rewriter) const {
275  auto srcType = constOp.getType().dyn_cast<ShapedType>();
276  if (!srcType || srcType.getNumElements() == 1)
277  return failure();
278 
279  // arith.constant should only have vector or tenor types.
280  assert((srcType.isa<VectorType, RankedTensorType>()));
281 
282  auto dstType = getTypeConverter()->convertType(srcType);
283  if (!dstType)
284  return failure();
285 
286  auto dstElementsAttr = constOp.getValue().dyn_cast<DenseElementsAttr>();
287  ShapedType dstAttrType = dstElementsAttr.getType();
288  if (!dstElementsAttr)
289  return failure();
290 
291  // If the composite type has more than one dimensions, perform linearization.
292  if (srcType.getRank() > 1) {
293  if (srcType.isa<RankedTensorType>()) {
294  dstAttrType = RankedTensorType::get(srcType.getNumElements(),
295  srcType.getElementType());
296  dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
297  } else {
298  // TODO: add support for large vectors.
299  return failure();
300  }
301  }
302 
303  Type srcElemType = srcType.getElementType();
304  Type dstElemType;
305  // Tensor types are converted to SPIR-V array types; vector types are
306  // converted to SPIR-V vector/array types.
307  if (auto arrayType = dstType.dyn_cast<spirv::ArrayType>())
308  dstElemType = arrayType.getElementType();
309  else
310  dstElemType = dstType.cast<VectorType>().getElementType();
311 
312  // If the source and destination element types are different, perform
313  // attribute conversion.
314  if (srcElemType != dstElemType) {
315  SmallVector<Attribute, 8> elements;
316  if (srcElemType.isa<FloatType>()) {
317  for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
318  FloatAttr dstAttr =
319  convertFloatAttr(srcAttr, dstElemType.cast<FloatType>(), rewriter);
320  if (!dstAttr)
321  return failure();
322  elements.push_back(dstAttr);
323  }
324  } else if (srcElemType.isInteger(1)) {
325  return failure();
326  } else {
327  for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) {
328  IntegerAttr dstAttr = convertIntegerAttr(
329  srcAttr, dstElemType.cast<IntegerType>(), rewriter);
330  if (!dstAttr)
331  return failure();
332  elements.push_back(dstAttr);
333  }
334  }
335 
336  // Unfortunately, we cannot use dialect-specific types for element
337  // attributes; element attributes only works with builtin types. So we need
338  // to prepare another converted builtin types for the destination elements
339  // attribute.
340  if (dstAttrType.isa<RankedTensorType>())
341  dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType);
342  else
343  dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
344 
345  dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
346  }
347 
348  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
349  dstElementsAttr);
350  return success();
351 }
352 
353 //===----------------------------------------------------------------------===//
354 // ConstantOp with scalar type
355 //===----------------------------------------------------------------------===//
356 
357 LogicalResult ConstantScalarOpPattern::matchAndRewrite(
358  arith::ConstantOp constOp, OpAdaptor adaptor,
359  ConversionPatternRewriter &rewriter) const {
360  Type srcType = constOp.getType();
361  if (auto shapedType = srcType.dyn_cast<ShapedType>()) {
362  if (shapedType.getNumElements() != 1)
363  return failure();
364  srcType = shapedType.getElementType();
365  }
366  if (!srcType.isIntOrIndexOrFloat())
367  return failure();
368 
369  Attribute cstAttr = constOp.getValue();
370  if (cstAttr.getType().isa<ShapedType>())
371  cstAttr = cstAttr.cast<DenseElementsAttr>().getSplatValue<Attribute>();
372 
373  Type dstType = getTypeConverter()->convertType(srcType);
374  if (!dstType)
375  return failure();
376 
377  // Floating-point types.
378  if (srcType.isa<FloatType>()) {
379  auto srcAttr = cstAttr.cast<FloatAttr>();
380  auto dstAttr = srcAttr;
381 
382  // Floating-point types not supported in the target environment are all
383  // converted to float type.
384  if (srcType != dstType) {
385  dstAttr = convertFloatAttr(srcAttr, dstType.cast<FloatType>(), rewriter);
386  if (!dstAttr)
387  return failure();
388  }
389 
390  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
391  return success();
392  }
393 
394  // Bool type.
395  if (srcType.isInteger(1)) {
396  // arith.constant can use 0/1 instead of true/false for i1 values. We need
397  // to handle that here.
398  auto dstAttr = convertBoolAttr(cstAttr, rewriter);
399  if (!dstAttr)
400  return failure();
401  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
402  return success();
403  }
404 
405  // IndexType or IntegerType. Index values are converted to 32-bit integer
406  // values when converting to SPIR-V.
407  auto srcAttr = cstAttr.cast<IntegerAttr>();
408  auto dstAttr =
409  convertIntegerAttr(srcAttr, dstType.cast<IntegerType>(), rewriter);
410  if (!dstAttr)
411  return failure();
412  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
413  return success();
414 }
415 
416 //===----------------------------------------------------------------------===//
417 // RemSIOpGLSLPattern
418 //===----------------------------------------------------------------------===//
419 
420 /// Returns signed remainder for `lhs` and `rhs` and lets the result follow
421 /// the sign of `signOperand`.
422 ///
423 /// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment
424 /// spec, "for the OpSRem and OpSMod instructions, if either operand is negative
425 /// the result is undefined." So we cannot directly use spv.SRem/spv.SMod
426 /// if either operand can be negative. Emulate it via spv.UMod.
427 template <typename SignedAbsOp>
429  Value signOperand, OpBuilder &builder) {
430  assert(lhs.getType() == rhs.getType());
431  assert(lhs == signOperand || rhs == signOperand);
432 
433  Type type = lhs.getType();
434 
435  // Calculate the remainder with spv.UMod.
436  Value lhsAbs = builder.create<SignedAbsOp>(loc, type, lhs);
437  Value rhsAbs = builder.create<SignedAbsOp>(loc, type, rhs);
438  Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs);
439 
440  // Fix the sign.
441  Value isPositive;
442  if (lhs == signOperand)
443  isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs);
444  else
445  isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs);
446  Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs);
447  return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate);
448 }
449 
451 RemSIOpGLSLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
452  ConversionPatternRewriter &rewriter) const {
453  Value result = emulateSignedRemainder<spirv::GLSLSAbsOp>(
454  op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
455  adaptor.getOperands()[0], rewriter);
456  rewriter.replaceOp(op, result);
457 
458  return success();
459 }
460 
461 //===----------------------------------------------------------------------===//
462 // RemSIOpOCLPattern
463 //===----------------------------------------------------------------------===//
464 
466 RemSIOpOCLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
467  ConversionPatternRewriter &rewriter) const {
468  Value result = emulateSignedRemainder<spirv::OCLSAbsOp>(
469  op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
470  adaptor.getOperands()[0], rewriter);
471  rewriter.replaceOp(op, result);
472 
473  return success();
474 }
475 
476 //===----------------------------------------------------------------------===//
477 // BitwiseOpPattern
478 //===----------------------------------------------------------------------===//
479 
480 template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
482 BitwiseOpPattern<Op, SPIRVLogicalOp, SPIRVBitwiseOp>::matchAndRewrite(
483  Op op, typename Op::Adaptor adaptor,
484  ConversionPatternRewriter &rewriter) const {
485  assert(adaptor.getOperands().size() == 2);
486  auto dstType =
487  this->getTypeConverter()->convertType(op.getResult().getType());
488  if (!dstType)
489  return failure();
490  if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) {
491  rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(op, dstType,
492  adaptor.getOperands());
493  } else {
494  rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(op, dstType,
495  adaptor.getOperands());
496  }
497  return success();
498 }
499 
500 //===----------------------------------------------------------------------===//
501 // XOrIOpLogicalPattern
502 //===----------------------------------------------------------------------===//
503 
504 LogicalResult XOrIOpLogicalPattern::matchAndRewrite(
505  arith::XOrIOp op, OpAdaptor adaptor,
506  ConversionPatternRewriter &rewriter) const {
507  assert(adaptor.getOperands().size() == 2);
508 
509  if (isBoolScalarOrVector(adaptor.getOperands().front().getType()))
510  return failure();
511 
512  auto dstType = getTypeConverter()->convertType(op.getType());
513  if (!dstType)
514  return failure();
515  rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
516  adaptor.getOperands());
517 
518  return success();
519 }
520 
521 //===----------------------------------------------------------------------===//
522 // XOrIOpBooleanPattern
523 //===----------------------------------------------------------------------===//
524 
525 LogicalResult XOrIOpBooleanPattern::matchAndRewrite(
526  arith::XOrIOp op, OpAdaptor adaptor,
527  ConversionPatternRewriter &rewriter) const {
528  assert(adaptor.getOperands().size() == 2);
529 
530  if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
531  return failure();
532 
533  auto dstType = getTypeConverter()->convertType(op.getType());
534  if (!dstType)
535  return failure();
536  rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(op, dstType,
537  adaptor.getOperands());
538  return success();
539 }
540 
541 //===----------------------------------------------------------------------===//
542 // UIToFPI1Pattern
543 //===----------------------------------------------------------------------===//
544 
546 UIToFPI1Pattern::matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
547  ConversionPatternRewriter &rewriter) const {
548  auto srcType = adaptor.getOperands().front().getType();
549  if (!isBoolScalarOrVector(srcType))
550  return failure();
551 
552  auto dstType =
553  this->getTypeConverter()->convertType(op.getResult().getType());
554  Location loc = op.getLoc();
555  Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
556  Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
557  rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
558  op, dstType, adaptor.getOperands().front(), one, zero);
559  return success();
560 }
561 
562 //===----------------------------------------------------------------------===//
563 // ExtUII1Pattern
564 //===----------------------------------------------------------------------===//
565 
567 ExtUII1Pattern::matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
568  ConversionPatternRewriter &rewriter) const {
569  auto srcType = adaptor.getOperands().front().getType();
570  if (!isBoolScalarOrVector(srcType))
571  return failure();
572 
573  auto dstType =
574  this->getTypeConverter()->convertType(op.getResult().getType());
575  Location loc = op.getLoc();
576  Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
577  Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
578  rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
579  op, dstType, adaptor.getOperands().front(), one, zero);
580  return success();
581 }
582 
583 //===----------------------------------------------------------------------===//
584 // TruncII1Pattern
585 //===----------------------------------------------------------------------===//
586 
588 TruncII1Pattern::matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
589  ConversionPatternRewriter &rewriter) const {
590  auto dstType =
591  this->getTypeConverter()->convertType(op.getResult().getType());
592  if (!isBoolScalarOrVector(dstType))
593  return failure();
594 
595  Location loc = op.getLoc();
596  auto srcType = adaptor.getOperands().front().getType();
597  // Check if (x & 1) == 1.
598  Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
599  Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>(
600  loc, srcType, adaptor.getOperands()[0], mask);
601  Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask);
602 
603  Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
604  Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
605  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero);
606  return success();
607 }
608 
609 //===----------------------------------------------------------------------===//
610 // TypeCastingOpPattern
611 //===----------------------------------------------------------------------===//
612 
613 template <typename Op, typename SPIRVOp>
614 LogicalResult TypeCastingOpPattern<Op, SPIRVOp>::matchAndRewrite(
615  Op op, typename Op::Adaptor adaptor,
616  ConversionPatternRewriter &rewriter) const {
617  assert(adaptor.getOperands().size() == 1);
618  auto srcType = adaptor.getOperands().front().getType();
619  auto dstType =
620  this->getTypeConverter()->convertType(op.getResult().getType());
621  if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
622  return failure();
623  if (dstType == srcType) {
624  // Due to type conversion, we are seeing the same source and target type.
625  // Then we can just erase this operation by forwarding its operand.
626  rewriter.replaceOp(op, adaptor.getOperands().front());
627  } else {
628  rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType,
629  adaptor.getOperands());
630  }
631  return success();
632 }
633 
634 //===----------------------------------------------------------------------===//
635 // CmpIOpBooleanPattern
636 //===----------------------------------------------------------------------===//
637 
638 LogicalResult CmpIOpBooleanPattern::matchAndRewrite(
639  arith::CmpIOp op, OpAdaptor adaptor,
640  ConversionPatternRewriter &rewriter) const {
641  Type operandType = op.getLhs().getType();
642  if (!isBoolScalarOrVector(operandType))
643  return failure();
644 
645  switch (op.getPredicate()) {
646 #define DISPATCH(cmpPredicate, spirvOp) \
647  case cmpPredicate: \
648  rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(), \
649  adaptor.getLhs(), adaptor.getRhs()); \
650  return success();
651 
652  DISPATCH(arith::CmpIPredicate::eq, spirv::LogicalEqualOp);
653  DISPATCH(arith::CmpIPredicate::ne, spirv::LogicalNotEqualOp);
654 
655 #undef DISPATCH
656  default:;
657  }
658  return failure();
659 }
660 
661 //===----------------------------------------------------------------------===//
662 // CmpIOpPattern
663 //===----------------------------------------------------------------------===//
664 
666 CmpIOpPattern::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
667  ConversionPatternRewriter &rewriter) const {
668  Type operandType = op.getLhs().getType();
669  if (isBoolScalarOrVector(operandType))
670  return failure();
671 
672  switch (op.getPredicate()) {
673 #define DISPATCH(cmpPredicate, spirvOp) \
674  case cmpPredicate: \
675  if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \
676  operandType != this->getTypeConverter()->convertType(operandType)) { \
677  return op.emitError( \
678  "bitwidth emulation is not implemented yet on unsigned op"); \
679  } \
680  rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(), \
681  adaptor.getLhs(), adaptor.getRhs()); \
682  return success();
683 
684  DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
685  DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
686  DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
687  DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
688  DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
689  DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
690  DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
691  DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
692  DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
693  DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
694 
695 #undef DISPATCH
696  }
697  return failure();
698 }
699 
700 //===----------------------------------------------------------------------===//
701 // CmpFOpPattern
702 //===----------------------------------------------------------------------===//
703 
705 CmpFOpPattern::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
706  ConversionPatternRewriter &rewriter) const {
707  switch (op.getPredicate()) {
708 #define DISPATCH(cmpPredicate, spirvOp) \
709  case cmpPredicate: \
710  rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(), \
711  adaptor.getLhs(), adaptor.getRhs()); \
712  return success();
713 
714  // Ordered.
715  DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
716  DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
717  DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
718  DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
719  DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
720  DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
721  // Unordered.
722  DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
723  DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
724  DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
725  DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
726  DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
727  DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
728 
729 #undef DISPATCH
730 
731  default:
732  break;
733  }
734  return failure();
735 }
736 
737 //===----------------------------------------------------------------------===//
738 // CmpFOpNanKernelPattern
739 //===----------------------------------------------------------------------===//
740 
741 LogicalResult CmpFOpNanKernelPattern::matchAndRewrite(
742  arith::CmpFOp op, OpAdaptor adaptor,
743  ConversionPatternRewriter &rewriter) const {
744  if (op.getPredicate() == arith::CmpFPredicate::ORD) {
745  rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(),
746  adaptor.getRhs());
747  return success();
748  }
749 
750  if (op.getPredicate() == arith::CmpFPredicate::UNO) {
751  rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(),
752  adaptor.getRhs());
753  return success();
754  }
755 
756  return failure();
757 }
758 
759 //===----------------------------------------------------------------------===//
760 // CmpFOpNanNonePattern
761 //===----------------------------------------------------------------------===//
762 
763 LogicalResult CmpFOpNanNonePattern::matchAndRewrite(
764  arith::CmpFOp op, OpAdaptor adaptor,
765  ConversionPatternRewriter &rewriter) const {
766  if (op.getPredicate() != arith::CmpFPredicate::ORD &&
767  op.getPredicate() != arith::CmpFPredicate::UNO)
768  return failure();
769 
770  Location loc = op.getLoc();
771 
772  Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
773  Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
774 
775  Value replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
776  if (op.getPredicate() == arith::CmpFPredicate::ORD)
777  replace = rewriter.create<spirv::LogicalNotOp>(loc, replace);
778 
779  rewriter.replaceOp(op, replace);
780  return success();
781 }
782 
783 //===----------------------------------------------------------------------===//
784 // Pattern Population
785 //===----------------------------------------------------------------------===//
786 
788  SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
  // Registers the arith-to-SPIR-V conversion patterns into `patterns`; each is
  // constructed with `typeConverter` and the pattern set's context. All
  // patterns in the first list use the default benefit.
  // NOTE(review): this listing is missing original lines 787, 793-798 and
  // 803-811 (the function header and several pattern-list entries); restore
  // them from the upstream source before relying on this text.
789  // clang-format off
790  patterns.add<
791  ConstantCompositeOpPattern,
792  ConstantScalarOpPattern,
799  RemSIOpGLSLPattern, RemSIOpOCLPattern,
800  BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
801  BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
802  XOrIOpLogicalPattern, XOrIOpBooleanPattern,
812  TypeCastingOpPattern<arith::ExtUIOp, spirv::UConvertOp>, ExtUII1Pattern,
813  TypeCastingOpPattern<arith::ExtSIOp, spirv::SConvertOp>,
814  TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
815  TypeCastingOpPattern<arith::TruncIOp, spirv::SConvertOp>, TruncII1Pattern,
816  TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
817  TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern,
818  TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
819  TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
820  TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
821  TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
822  CmpIOpBooleanPattern, CmpIOpPattern,
823  CmpFOpNanNonePattern, CmpFOpPattern
824  >(typeConverter, patterns.getContext());
825  // clang-format on
826 
827  // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel
828  // capability is available.
  // CmpFOpNanNonePattern (registered above with the default benefit) then acts
  // as the fallback for ORD/UNO when the Kernel-based lowering does not apply.
829  patterns.add<CmpFOpNanKernelPattern>(typeConverter, patterns.getContext(),
830  /*benefit=*/2);
831 }
832 
833 //===----------------------------------------------------------------------===//
834 // Pass Definition
835 //===----------------------------------------------------------------------===//
836 
837 namespace {
838 struct ConvertArithmeticToSPIRVPass
839  : public ConvertArithmeticToSPIRVBase<ConvertArithmeticToSPIRVPass> {
840  void runOnOperation() override {
841  auto module = getOperation()->getParentOfType<ModuleOp>();
842  auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
843  auto target = SPIRVConversionTarget::get(targetAttr);
844 
846  options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
847  SPIRVTypeConverter typeConverter(targetAttr, options);
848 
849  RewritePatternSet patterns(&getContext());
850  mlir::arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns);
851 
852  if (failed(applyPartialConversion(getOperation(), *target,
853  std::move(patterns))))
854  signalPassFailure();
855  }
856 };
857 } // namespace
858 
860  return std::make_unique<ConvertArithmeticToSPIRVPass>();
861 }
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
Include the generated interface declarations.
OpTy create(Location location, Args &&...args)
Create an operation of specific op type at the current insertion point.
Definition: Builders.h:430
bool isF32() const
Definition: Types.cpp:23
static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs, Value signOperand, OpBuilder &builder)
Returns signed remainder for lhs and rhs and lets the result follow the sign of signOperand.
U cast() const
Definition: Attributes.h:123
LogicalResult applyPartialConversion(ArrayRef< Operation *> ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation *> *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition: Types.cpp:89
void populateArithmeticToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
bool isInteger(unsigned width) const
Return true if this is an integer type with the specified width.
Definition: Types.cpp:31
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
Definition: SPIRVOps.cpp:639
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.
Definition: Pattern.h:21
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape...
std::unique_ptr< Pass > createConvertArithmeticToSPIRVPass()
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
An attribute that represents a reference to a dense vector or tensor object.
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:170
U dyn_cast() const
Definition: Types.h:244
Attributes are known-constant values of operations.
Definition: Attributes.h:24
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
#define DISPATCH(cmpPredicate, spirvOp)
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
static llvm::ManagedStatic< PassManagerOptions > options
Type getType() const
Return the type of this attribute.
Definition: Attributes.h:64
OpTy replaceOpWithNewOp(Operation *op, Args &&... args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:741
This class is a general helper class for creating context-global objects like types, attributes, and affine expressions.
Definition: Builders.h:49
Type getType() const
Return the type of this value.
Definition: Value.h:117
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:124
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
Definition: PatternMatch.h:930
U dyn_cast() const
Definition: Attributes.h:117
static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder)
Converts the given srcAttr into a boolean attribute if it holds an integral value.
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:87
This class implements a pattern rewriter for use with ConversionPatterns.
This provides public APIs that all operations should have.
bool isa() const
Definition: Types.h:234
FloatAttr getF32FloatAttr(float value)
Definition: Builders.cpp:185
This class helps build Operations.
Definition: Builders.h:177
static bool isBoolScalarOrVector(Type type)
Returns true if the given type is a boolean scalar or vector type.
MLIRContext * getContext() const
Definition: PatternMatch.h:906
bool emulateNon32BitScalarTypes
Whether to emulate non-32-bit scalar types with 32-bit scalar types if no native support.
Type conversion from builtin types to SPIR-V types for shader interface.
U cast() const
Definition: Types.h:250