MLIR 23.0.0git
ArithToSPIRV.cpp
Go to the documentation of this file.
1//===- ArithToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
21#include "llvm/ADT/APInt.h"
22#include "llvm/ADT/ArrayRef.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/Support/Debug.h"
25#include "llvm/Support/MathExtras.h"
26#include <cassert>
27#include <memory>
28
29namespace mlir {
30#define GEN_PASS_DEF_CONVERTARITHTOSPIRVPASS
31#include "mlir/Conversion/Passes.h.inc"
32} // namespace mlir
33
34#define DEBUG_TYPE "arith-to-spirv-pattern"
35
36using namespace mlir;
37
38//===----------------------------------------------------------------------===//
39// Conversion Helpers
40//===----------------------------------------------------------------------===//
41
42/// Converts the given `srcAttr` into a boolean attribute if it holds an
43/// integral value. Returns null attribute if conversion fails.
44static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) {
45 if (auto boolAttr = dyn_cast<BoolAttr>(srcAttr))
46 return boolAttr;
47 if (auto intAttr = dyn_cast<IntegerAttr>(srcAttr))
48 return builder.getBoolAttr(intAttr.getValue().getBoolValue());
49 return {};
50}
51
52/// Converts the given `srcAttr` to a new attribute of the given `dstType`.
53/// Returns null attribute if conversion fails.
54static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType,
55 Builder builder) {
56 // If the source number uses less active bits than the target bitwidth, then
57 // it should be safe to convert.
58 if (srcAttr.getValue().isIntN(dstType.getWidth()))
59 return builder.getIntegerAttr(dstType, srcAttr.getInt());
60
61 // XXX: Try again by interpreting the source number as a signed value.
62 // Although integers in the standard dialect are signless, they can represent
63 // a signed number. It's the operation decides how to interpret. This is
64 // dangerous, but it seems there is no good way of handling this if we still
65 // want to change the bitwidth. Emit a message at least.
66 if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
67 auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt());
68 LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '"
69 << dstAttr << "' for type '" << dstType << "'\n");
70 return dstAttr;
71 }
72
73 LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr
74 << "' illegal: cannot fit into target type '"
75 << dstType << "'\n");
76 return {};
77}
78
79/// Converts the given `srcAttr` to a new attribute of the given `dstType`.
80/// Returns null attribute if `dstType` is not 32-bit or conversion fails.
81static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
82 Builder builder) {
83 // Only support converting to float for now.
84 if (!dstType.isF32())
85 return FloatAttr();
86
87 // Try to convert the source floating-point number to single precision.
88 APFloat dstVal = srcAttr.getValue();
89 bool losesInfo = false;
90 APFloat::opStatus status =
91 dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
92 if (status != APFloat::opOK || losesInfo) {
93 LLVM_DEBUG(llvm::dbgs()
94 << srcAttr << " illegal: cannot fit into converted type '"
95 << dstType << "'\n");
96 return FloatAttr();
97 }
98
99 return builder.getF32FloatAttr(dstVal.convertToFloat());
100}
101
102// Get in IntegerAttr from FloatAttr while preserving the bits.
103// Useful for converting float constants to integer constants while preserving
104// the bits.
105static IntegerAttr
106getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType,
107 ConversionPatternRewriter &rewriter) {
108 APFloat floatVal = floatAttr.getValue();
109 APInt intVal = floatVal.bitcastToAPInt();
110 return rewriter.getIntegerAttr(dstType, intVal);
111}
112
113/// Returns true if the given `type` is a boolean scalar or vector type.
114static bool isBoolScalarOrVector(Type type) {
115 assert(type && "Not a valid type");
116 if (type.isInteger(1))
117 return true;
118
119 if (auto vecType = dyn_cast<VectorType>(type))
120 return vecType.getElementType().isInteger(1);
121
122 return false;
123}
124
125/// Creates a scalar/vector integer constant.
126static Value getScalarOrVectorConstInt(Type type, uint64_t value,
127 OpBuilder &builder, Location loc) {
128 if (auto vectorType = dyn_cast<VectorType>(type)) {
129 Attribute element = IntegerAttr::get(vectorType.getElementType(), value);
130 auto attr = SplatElementsAttr::get(vectorType, element);
131 return spirv::ConstantOp::create(builder, loc, vectorType, attr);
132 }
133
134 if (auto intType = dyn_cast<IntegerType>(type))
135 return spirv::ConstantOp::create(builder, loc, type,
136 builder.getIntegerAttr(type, value));
137
138 return nullptr;
139}
140
141/// Returns true if scalar/vector type `a` and `b` have the same number of
142/// bitwidth.
143static bool hasSameBitwidth(Type a, Type b) {
144 auto getNumBitwidth = [](Type type) {
145 unsigned bw = 0;
146 if (type.isIntOrFloat())
147 bw = type.getIntOrFloatBitWidth();
148 else if (auto vecType = dyn_cast<VectorType>(type))
149 bw = vecType.getElementTypeBitWidth() * vecType.getNumElements();
150 return bw;
151 };
152 unsigned aBW = getNumBitwidth(a);
153 unsigned bBW = getNumBitwidth(b);
154 return aBW != 0 && bBW != 0 && aBW == bBW;
155}
156
157/// Returns a source type conversion failure for `srcType` and operation `op`.
158static LogicalResult
159getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op,
160 Type srcType) {
161 return rewriter.notifyMatchFailure(
162 op->getLoc(),
163 llvm::formatv("failed to convert source type '{0}'", srcType));
164}
165
166/// Returns a source type conversion failure for the result type of `op`.
167static LogicalResult
168getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op) {
169 assert(op->getNumResults() == 1);
170 return getTypeConversionFailure(rewriter, op, op->getResultTypes().front());
171}
172
173namespace {
174
175/// Converts elementwise unary, binary and ternary arith operations to SPIR-V
176/// operations. Op can potentially support overflow flags.
177template <typename Op, typename SPIRVOp>
178struct ElementwiseArithOpPattern final : OpConversionPattern<Op> {
179 using OpConversionPattern<Op>::OpConversionPattern;
180
181 LogicalResult
182 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
183 ConversionPatternRewriter &rewriter) const override {
184 assert(adaptor.getOperands().size() <= 3);
185 auto converter = this->template getTypeConverter<SPIRVTypeConverter>();
186 Type dstType = converter->convertType(op.getType());
187 if (!dstType) {
188 return rewriter.notifyMatchFailure(
189 op->getLoc(),
190 llvm::formatv("failed to convert type {0} for SPIR-V", op.getType()));
191 }
192
193 if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
194 !getElementTypeOrSelf(op.getType()).isIndex() &&
195 dstType != op.getType()) {
196 return op.emitError("bitwidth emulation is not implemented yet on "
197 "unsigned op pattern version");
198 }
199
200 auto overflowFlags = arith::IntegerOverflowFlags::none;
201 if (auto overflowIface =
202 dyn_cast<arith::ArithIntegerOverflowFlagsInterface>(*op)) {
203 if (converter->getTargetEnv().allows(
204 spirv::Extension::SPV_KHR_no_integer_wrap_decoration))
205 overflowFlags = overflowIface.getOverflowAttr().getValue();
206 }
207
208 auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
209 op, dstType, adaptor.getOperands());
210
211 if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nsw))
212 newOp->setAttr(getDecorationString(spirv::Decoration::NoSignedWrap),
213 rewriter.getUnitAttr());
214
215 if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nuw))
216 newOp->setAttr(getDecorationString(spirv::Decoration::NoUnsignedWrap),
217 rewriter.getUnitAttr());
218
219 return success();
220 }
221};
222
223//===----------------------------------------------------------------------===//
224// ConstantOp
225//===----------------------------------------------------------------------===//
226
227/// Converts composite arith.constant operation to spirv.Constant.
228struct ConstantCompositeOpPattern final
229 : public OpConversionPattern<arith::ConstantOp> {
230 using Base::Base;
231
232 LogicalResult
233 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
234 ConversionPatternRewriter &rewriter) const override {
235 auto srcType = dyn_cast<ShapedType>(constOp.getType());
236 if (!srcType || srcType.getNumElements() == 1)
237 return failure();
238
239 // arith.constant should only have vector or tensor types. This is a MLIR
240 // wide problem at the moment.
241 if (!isa<VectorType, RankedTensorType>(srcType))
242 return rewriter.notifyMatchFailure(constOp, "unsupported ShapedType");
243
244 Type dstType = getTypeConverter()->convertType(srcType);
245 if (!dstType)
246 return failure();
247
248 // Import the resource into the IR to make use of the special handling of
249 // element types later on.
250 mlir::DenseElementsAttr dstElementsAttr;
251 if (auto denseElementsAttr =
252 dyn_cast<DenseElementsAttr>(constOp.getValue())) {
253 dstElementsAttr = denseElementsAttr;
254 } else if (auto resourceAttr =
255 dyn_cast<DenseResourceElementsAttr>(constOp.getValue())) {
256
257 AsmResourceBlob *blob = resourceAttr.getRawHandle().getBlob();
258 if (!blob)
259 return constOp->emitError("could not find resource blob");
260
261 ArrayRef<char> ptr = blob->getData();
262
263 // Check that the buffer meets the requirements to get converted to a
264 // DenseElementsAttr
265 bool detectedSplat = false;
266 if (!DenseElementsAttr::isValidRawBuffer(srcType, ptr, detectedSplat))
267 return constOp->emitError("resource is not a valid buffer");
268
269 dstElementsAttr =
270 DenseElementsAttr::getFromRawBuffer(resourceAttr.getType(), ptr);
271 } else {
272 return constOp->emitError("unsupported elements attribute");
273 }
274
275 ShapedType dstAttrType = dstElementsAttr.getType();
276
277 // If the composite type has more than one dimensions, perform
278 // linearization.
279 if (srcType.getRank() > 1) {
280 if (isa<RankedTensorType>(srcType)) {
281 dstAttrType = RankedTensorType::get(srcType.getNumElements(),
282 srcType.getElementType());
283 dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
284 } else {
285 // TODO: add support for large vectors.
286 return failure();
287 }
288 }
289
290 Type srcElemType = srcType.getElementType();
291 Type dstElemType;
292 // Tensor types are converted to SPIR-V array types; vector types are
293 // converted to SPIR-V vector/array types.
294 if (auto arrayType = dyn_cast<spirv::ArrayType>(dstType))
295 dstElemType = arrayType.getElementType();
296 else
297 dstElemType = cast<VectorType>(dstType).getElementType();
298
299 // If the source and destination element types are different, perform
300 // attribute conversion.
301 if (srcElemType != dstElemType) {
303 if (isa<FloatType>(srcElemType)) {
304 for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
305 Attribute dstAttr = nullptr;
306 // Handle 8-bit float conversion to 8-bit integer.
307 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
308 if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
309 srcElemType.getIntOrFloatBitWidth() == 8 &&
310 isa<IntegerType>(dstElemType)) {
311 dstAttr =
312 getIntegerAttrFromFloatAttr(srcAttr, dstElemType, rewriter);
313 } else {
314 dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstElemType),
315 rewriter);
316 }
317 if (!dstAttr)
318 return failure();
319 elements.push_back(dstAttr);
320 }
321 } else if (srcElemType.isInteger(1)) {
322 return failure();
323 } else {
324 for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) {
325 IntegerAttr dstAttr = convertIntegerAttr(
326 srcAttr, cast<IntegerType>(dstElemType), rewriter);
327 if (!dstAttr)
328 return failure();
329 elements.push_back(dstAttr);
330 }
331 }
332
333 // Unfortunately, we cannot use dialect-specific types for element
334 // attributes; element attributes only works with builtin types. So we
335 // need to prepare another converted builtin types for the destination
336 // elements attribute.
337 if (isa<RankedTensorType>(dstAttrType))
338 dstAttrType =
339 RankedTensorType::get(dstAttrType.getShape(), dstElemType);
340 else
341 dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
342
343 dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
344 }
345
346 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
347 dstElementsAttr);
348 return success();
349 }
350};
351
352/// Converts scalar arith.constant operation to spirv.Constant.
353struct ConstantScalarOpPattern final
354 : public OpConversionPattern<arith::ConstantOp> {
355 using Base::Base;
356
357 LogicalResult
358 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
359 ConversionPatternRewriter &rewriter) const override {
360 Type srcType = constOp.getType();
361 if (auto shapedType = dyn_cast<ShapedType>(srcType)) {
362 if (shapedType.getNumElements() != 1)
363 return failure();
364 srcType = shapedType.getElementType();
365 }
366 if (!srcType.isIntOrIndexOrFloat())
367 return failure();
368
369 Attribute cstAttr = constOp.getValue();
370 if (auto elementsAttr = dyn_cast<DenseElementsAttr>(cstAttr))
371 cstAttr = elementsAttr.getSplatValue<Attribute>();
372
373 Type dstType = getTypeConverter()->convertType(srcType);
374 if (!dstType)
375 return failure();
376
377 // Floating-point types.
378 if (isa<FloatType>(srcType)) {
379 auto srcAttr = cast<FloatAttr>(cstAttr);
380 Attribute dstAttr = srcAttr;
381
382 // Floating-point types not supported in the target environment are all
383 // converted to float type.
384 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
385 if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
386 srcType.getIntOrFloatBitWidth() == 8 && isa<IntegerType>(dstType) &&
387 dstType.getIntOrFloatBitWidth() == 8) {
388 // If the source is an 8-bit float, convert it to a 8-bit integer.
389 dstAttr = getIntegerAttrFromFloatAttr(srcAttr, dstType, rewriter);
390 if (!dstAttr)
391 return failure();
392 } else if (srcType != dstType) {
393 dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter);
394 if (!dstAttr)
395 return failure();
396 }
397
398 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
399 return success();
400 }
401
402 // Bool type.
403 if (srcType.isInteger(1)) {
404 // arith.constant can use 0/1 instead of true/false for i1 values. We need
405 // to handle that here.
406 auto dstAttr = convertBoolAttr(cstAttr, rewriter);
407 if (!dstAttr)
408 return failure();
409 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
410 return success();
411 }
412
413 // IndexType or IntegerType. Index values are converted to 32-bit integer
414 // values when converting to SPIR-V.
415 auto srcAttr = cast<IntegerAttr>(cstAttr);
416 IntegerAttr dstAttr =
417 convertIntegerAttr(srcAttr, cast<IntegerType>(dstType), rewriter);
418 if (!dstAttr)
419 return failure();
420 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
421 return success();
422 }
423};
424
425//===----------------------------------------------------------------------===//
426// RemSIOp
427//===----------------------------------------------------------------------===//
428
429/// Returns signed remainder for `lhs` and `rhs` and lets the result follow
430/// the sign of `signOperand`.
431///
432/// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment
433/// spec, "for the OpSRem and OpSMod instructions, if either operand is negative
434/// the result is undefined." So we cannot directly use spirv.SRem/spirv.SMod
435/// if either operand can be negative. Emulate it via spirv.UMod.
436template <typename SignedAbsOp>
437static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
438 Value signOperand, OpBuilder &builder) {
439 assert(lhs.getType() == rhs.getType());
440 assert(lhs == signOperand || rhs == signOperand);
441
442 Type type = lhs.getType();
443
444 // Calculate the remainder with spirv.UMod.
445 Value lhsAbs = SignedAbsOp::create(builder, loc, type, lhs);
446 Value rhsAbs = SignedAbsOp::create(builder, loc, type, rhs);
447 Value abs = spirv::UModOp::create(builder, loc, lhsAbs, rhsAbs);
448
449 // Fix the sign.
450 Value isPositive;
451 if (lhs == signOperand)
452 isPositive = spirv::IEqualOp::create(builder, loc, lhs, lhsAbs);
453 else
454 isPositive = spirv::IEqualOp::create(builder, loc, rhs, rhsAbs);
455 Value absNegate = spirv::SNegateOp::create(builder, loc, type, abs);
456 return spirv::SelectOp::create(builder, loc, type, isPositive, abs,
457 absNegate);
458}
459
460/// Converts arith.remsi to GLSL SPIR-V ops.
461///
462/// This cannot be merged into the template unary/binary pattern due to Vulkan
463/// restrictions over spirv.SRem and spirv.SMod.
464struct RemSIOpGLPattern final : public OpConversionPattern<arith::RemSIOp> {
465 using Base::Base;
466
467 LogicalResult
468 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
469 ConversionPatternRewriter &rewriter) const override {
470 Value result = emulateSignedRemainder<spirv::CLSAbsOp>(
471 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
472 adaptor.getOperands()[0], rewriter);
473 rewriter.replaceOp(op, result);
474
475 return success();
476 }
477};
478
479/// Converts arith.remsi to OpenCL SPIR-V ops.
480struct RemSIOpCLPattern final : public OpConversionPattern<arith::RemSIOp> {
481 using Base::Base;
482
483 LogicalResult
484 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
485 ConversionPatternRewriter &rewriter) const override {
486 Value result = emulateSignedRemainder<spirv::GLSAbsOp>(
487 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
488 adaptor.getOperands()[0], rewriter);
489 rewriter.replaceOp(op, result);
490
491 return success();
492 }
493};
494
495//===----------------------------------------------------------------------===//
496// BitwiseOp
497//===----------------------------------------------------------------------===//
498
499/// Converts bitwise operations to SPIR-V operations. This is a special pattern
500/// other than the BinaryOpPatternPattern because if the operands are boolean
501/// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For
502/// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`.
503template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
504struct BitwiseOpPattern final : public OpConversionPattern<Op> {
505 using OpConversionPattern<Op>::OpConversionPattern;
506
507 LogicalResult
508 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
509 ConversionPatternRewriter &rewriter) const override {
510 assert(adaptor.getOperands().size() == 2);
511 Type dstType = this->getTypeConverter()->convertType(op.getType());
512 if (!dstType)
513 return getTypeConversionFailure(rewriter, op);
514
515 if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) {
516 rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(
517 op, dstType, adaptor.getOperands());
518 } else {
519 rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(
520 op, dstType, adaptor.getOperands());
521 }
522 return success();
523 }
524};
525
526//===----------------------------------------------------------------------===//
527// XOrIOp
528//===----------------------------------------------------------------------===//
529
530/// Converts arith.xori to SPIR-V operations.
531struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> {
532 using Base::Base;
533
534 LogicalResult
535 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
536 ConversionPatternRewriter &rewriter) const override {
537 assert(adaptor.getOperands().size() == 2);
538
539 if (isBoolScalarOrVector(adaptor.getOperands().front().getType()))
540 return failure();
541
542 Type dstType = getTypeConverter()->convertType(op.getType());
543 if (!dstType)
544 return getTypeConversionFailure(rewriter, op);
545
546 rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
547 adaptor.getOperands());
548
549 return success();
550 }
551};
552
553/// Converts arith.xori to SPIR-V operations if the type of source is i1 or
554/// vector of i1.
555struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
556 using Base::Base;
557
558 LogicalResult
559 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
560 ConversionPatternRewriter &rewriter) const override {
561 assert(adaptor.getOperands().size() == 2);
562
563 if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
564 return failure();
565
566 Type dstType = getTypeConverter()->convertType(op.getType());
567 if (!dstType)
568 return getTypeConversionFailure(rewriter, op);
569
570 rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
571 op, dstType, adaptor.getOperands());
572 return success();
573 }
574};
575
576//===----------------------------------------------------------------------===//
577// UIToFPOp
578//===----------------------------------------------------------------------===//
579
580/// Converts arith.uitofp to spirv.Select if the type of source is i1 or vector
581/// of i1.
582struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
583 using Base::Base;
584
585 LogicalResult
586 matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
587 ConversionPatternRewriter &rewriter) const override {
588 Type srcType = adaptor.getOperands().front().getType();
589 if (!isBoolScalarOrVector(srcType))
590 return failure();
591
592 Type dstType = getTypeConverter()->convertType(op.getType());
593 if (!dstType)
594 return getTypeConversionFailure(rewriter, op);
595
596 Location loc = op.getLoc();
597 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
598 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
599 rewriter.replaceOpWithNewOp<spirv::SelectOp>(
600 op, dstType, adaptor.getOperands().front(), one, zero);
601 return success();
602 }
603};
604
605//===----------------------------------------------------------------------===//
606// IndexCastOp
607//===----------------------------------------------------------------------===//
608
609/// Converts arith.index_cast to spirv.INotEqual if the target type is i1.
610struct IndexCastIndexI1Pattern final
611 : public OpConversionPattern<arith::IndexCastOp> {
612 using Base::Base;
613
614 LogicalResult
615 matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
616 ConversionPatternRewriter &rewriter) const override {
617 if (!isBoolScalarOrVector(op.getType()))
618 return failure();
619
620 Type dstType = getTypeConverter()->convertType(op.getType());
621 if (!dstType)
622 return getTypeConversionFailure(rewriter, op);
623
624 Location loc = op.getLoc();
625 Value zeroIdx =
626 spirv::ConstantOp::getZero(adaptor.getIn().getType(), loc, rewriter);
627 rewriter.replaceOpWithNewOp<spirv::INotEqualOp>(op, dstType, zeroIdx,
628 adaptor.getIn());
629 return success();
630 }
631};
632
633/// Converts arith.index_cast to spirv.Select if the source type is i1.
634struct IndexCastI1IndexPattern final
635 : public OpConversionPattern<arith::IndexCastOp> {
636 using Base::Base;
637
638 LogicalResult
639 matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
640 ConversionPatternRewriter &rewriter) const override {
641 if (!isBoolScalarOrVector(adaptor.getIn().getType()))
642 return failure();
643
644 Type dstType = getTypeConverter()->convertType(op.getType());
645 if (!dstType)
646 return getTypeConversionFailure(rewriter, op);
647
648 Location loc = op.getLoc();
649 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
650 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
651 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, adaptor.getIn(),
652 one, zero);
653 return success();
654 }
655};
656
657//===----------------------------------------------------------------------===//
658// ExtSIOp
659//===----------------------------------------------------------------------===//
660
661/// Converts arith.extsi to spirv.Select if the type of source is i1 or vector
662/// of i1.
663struct ExtSII1Pattern final : public OpConversionPattern<arith::ExtSIOp> {
664 using Base::Base;
665
666 LogicalResult
667 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
668 ConversionPatternRewriter &rewriter) const override {
669 Value operand = adaptor.getIn();
670 if (!isBoolScalarOrVector(operand.getType()))
671 return failure();
672
673 Location loc = op.getLoc();
674 Type dstType = getTypeConverter()->convertType(op.getType());
675 if (!dstType)
676 return getTypeConversionFailure(rewriter, op);
677
678 Value allOnes;
679 if (auto intTy = dyn_cast<IntegerType>(dstType)) {
680 unsigned componentBitwidth = intTy.getWidth();
681 allOnes = spirv::ConstantOp::create(
682 rewriter, loc, intTy,
683 rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth)));
684 } else if (auto vectorTy = dyn_cast<VectorType>(dstType)) {
685 unsigned componentBitwidth = vectorTy.getElementTypeBitWidth();
686 allOnes = spirv::ConstantOp::create(
687 rewriter, loc, vectorTy,
688 SplatElementsAttr::get(vectorTy,
689 APInt::getAllOnes(componentBitwidth)));
690 } else {
691 return rewriter.notifyMatchFailure(
692 loc, llvm::formatv("unhandled type: {0}", dstType));
693 }
694
695 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
696 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, operand, allOnes,
697 zero);
698 return success();
699 }
700};
701
702/// Converts arith.extsi to spirv.Select if the type of source is neither i1 nor
703/// vector of i1.
704struct ExtSIPattern final : public OpConversionPattern<arith::ExtSIOp> {
705 using Base::Base;
706
707 LogicalResult
708 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
709 ConversionPatternRewriter &rewriter) const override {
710 Type srcType = adaptor.getIn().getType();
711 if (isBoolScalarOrVector(srcType))
712 return failure();
713
714 Type dstType = getTypeConverter()->convertType(op.getType());
715 if (!dstType)
716 return getTypeConversionFailure(rewriter, op);
717
718 if (dstType == srcType) {
719 // We can have the same source and destination type due to type emulation.
720 // Perform bit shifting to make sure we have the proper leading set bits.
721
722 unsigned srcBW =
723 getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth();
724 unsigned dstBW =
726 assert(srcBW < dstBW);
727 Value shiftSize = getScalarOrVectorConstInt(dstType, dstBW - srcBW,
728 rewriter, op.getLoc());
729
730 // First shift left to sequeeze out all leading bits beyond the original
731 // bitwidth. Here we need to use the original source and result type's
732 // bitwidth.
733 auto shiftLOp = spirv::ShiftLeftLogicalOp::create(
734 rewriter, op.getLoc(), dstType, adaptor.getIn(), shiftSize);
735
736 // Then we perform arithmetic right shift to make sure we have the right
737 // sign bits for negative values.
738 rewriter.replaceOpWithNewOp<spirv::ShiftRightArithmeticOp>(
739 op, dstType, shiftLOp, shiftSize);
740 } else {
741 rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
742 adaptor.getOperands());
743 }
744
745 return success();
746 }
747};
748
749//===----------------------------------------------------------------------===//
750// ExtUIOp
751//===----------------------------------------------------------------------===//
752
753/// Converts arith.extui to spirv.Select if the type of source is i1 or vector
754/// of i1.
755struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> {
756 using Base::Base;
757
758 LogicalResult
759 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
760 ConversionPatternRewriter &rewriter) const override {
761 Type srcType = adaptor.getOperands().front().getType();
762 if (!isBoolScalarOrVector(srcType))
763 return failure();
764
765 Type dstType = getTypeConverter()->convertType(op.getType());
766 if (!dstType)
767 return getTypeConversionFailure(rewriter, op);
768
769 Location loc = op.getLoc();
770 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
771 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
772 rewriter.replaceOpWithNewOp<spirv::SelectOp>(
773 op, dstType, adaptor.getOperands().front(), one, zero);
774 return success();
775 }
776};
777
778/// Converts arith.extui for cases where the type of source is neither i1 nor
779/// vector of i1.
780struct ExtUIPattern final : public OpConversionPattern<arith::ExtUIOp> {
781 using Base::Base;
782
783 LogicalResult
784 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
785 ConversionPatternRewriter &rewriter) const override {
786 Type srcType = adaptor.getIn().getType();
787 if (isBoolScalarOrVector(srcType))
788 return failure();
789
790 Type dstType = getTypeConverter()->convertType(op.getType());
791 if (!dstType)
792 return getTypeConversionFailure(rewriter, op);
793
794 if (dstType == srcType) {
795 // We can have the same source and destination type due to type emulation.
796 // Perform bit masking to make sure we don't pollute downstream consumers
797 // with unwanted bits. Here we need to use the original source type's
798 // bitwidth.
799 unsigned bitwidth =
800 getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth();
801 Value mask = getScalarOrVectorConstInt(
802 dstType, llvm::maskTrailingOnes<uint64_t>(bitwidth), rewriter,
803 op.getLoc());
804 rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
805 adaptor.getIn(), mask);
806 } else {
807 rewriter.replaceOpWithNewOp<spirv::UConvertOp>(op, dstType,
808 adaptor.getOperands());
809 }
810 return success();
811 }
812};
813
814//===----------------------------------------------------------------------===//
815// TruncIOp
816//===----------------------------------------------------------------------===//
817
818/// Converts arith.trunci to spirv.Select if the type of result is i1 or vector
819/// of i1.
820struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
821 using Base::Base;
822
823 LogicalResult
824 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
825 ConversionPatternRewriter &rewriter) const override {
826 Type dstType = getTypeConverter()->convertType(op.getType());
827 if (!dstType)
828 return getTypeConversionFailure(rewriter, op);
829
830 if (!isBoolScalarOrVector(dstType))
831 return failure();
832
833 Location loc = op.getLoc();
834 auto srcType = adaptor.getOperands().front().getType();
835 // Check if (x & 1) == 1.
836 Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
837 Value maskedSrc = spirv::BitwiseAndOp::create(
838 rewriter, loc, srcType, adaptor.getOperands()[0], mask);
839 Value isOne = spirv::IEqualOp::create(rewriter, loc, maskedSrc, mask);
840
841 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
842 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
843 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero);
844 return success();
845 }
846};
847
848/// Converts arith.trunci for cases where the type of result is neither i1
849/// nor vector of i1.
850struct TruncIPattern final : public OpConversionPattern<arith::TruncIOp> {
851 using Base::Base;
852
853 LogicalResult
854 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
855 ConversionPatternRewriter &rewriter) const override {
856 Type srcType = adaptor.getIn().getType();
857 Type dstType = getTypeConverter()->convertType(op.getType());
858 if (!dstType)
859 return getTypeConversionFailure(rewriter, op);
860
861 if (isBoolScalarOrVector(dstType))
862 return failure();
863
864 if (dstType == srcType) {
865 // We can have the same source and destination type due to type emulation.
866 // Perform bit masking to make sure we don't pollute downstream consumers
867 // with unwanted bits. Here we need to use the original result type's
868 // bitwidth.
869 unsigned bw = getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth();
870 Value mask = getScalarOrVectorConstInt(
871 dstType, llvm::maskTrailingOnes<uint64_t>(bw), rewriter, op.getLoc());
872 rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
873 adaptor.getIn(), mask);
874 } else {
875 // Given this is truncation, either SConvertOp or UConvertOp works.
876 rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
877 adaptor.getOperands());
878 }
879 return success();
880 }
881};
882
883//===----------------------------------------------------------------------===//
884// TypeCastingOp
885//===----------------------------------------------------------------------===//
886
887static std::optional<spirv::FPRoundingMode>
888convertArithRoundingModeToSPIRV(arith::RoundingMode roundingMode) {
889 switch (roundingMode) {
890 case arith::RoundingMode::downward:
891 return spirv::FPRoundingMode::RTN;
892 case arith::RoundingMode::to_nearest_even:
893 return spirv::FPRoundingMode::RTE;
894 case arith::RoundingMode::toward_zero:
895 return spirv::FPRoundingMode::RTZ;
896 case arith::RoundingMode::upward:
897 return spirv::FPRoundingMode::RTP;
898 case arith::RoundingMode::to_nearest_away:
899 // SPIR-V FPRoundingMode decoration has no ties-away-from-zero mode
900 // (as of SPIR-V 1.6)
901 return std::nullopt;
902 }
903 llvm_unreachable("Unhandled rounding mode");
904}
905
906/// Converts type-casting standard operations to SPIR-V operations.
907template <typename Op, typename SPIRVOp>
908struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
909 using OpConversionPattern<Op>::OpConversionPattern;
910
911 LogicalResult
912 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
913 ConversionPatternRewriter &rewriter) const override {
914 Type srcType = llvm::getSingleElement(adaptor.getOperands()).getType();
915 Type dstType = this->getTypeConverter()->convertType(op.getType());
916 if (!dstType)
917 return getTypeConversionFailure(rewriter, op);
918
919 if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
920 return failure();
921
922 if (dstType == srcType) {
923 // Due to type conversion, we are seeing the same source and target type.
924 // Then we can just erase this operation by forwarding its operand.
925 rewriter.replaceOp(op, adaptor.getOperands().front());
926 } else {
927 // Compute new rounding mode (if any).
928 std::optional<spirv::FPRoundingMode> rm = std::nullopt;
929 if (auto roundingModeOp =
930 dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
931 if (arith::RoundingModeAttr roundingMode =
932 roundingModeOp.getRoundingModeAttr()) {
933 if (!(rm =
934 convertArithRoundingModeToSPIRV(roundingMode.getValue()))) {
935 return rewriter.notifyMatchFailure(
936 op->getLoc(),
937 llvm::formatv("unsupported rounding mode '{0}'", roundingMode));
938 }
939 }
940 }
941 // Create replacement op and attach rounding mode attribute (if any).
942 auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
943 op, dstType, adaptor.getOperands());
944 if (rm) {
945 newOp->setAttr(
946 getDecorationString(spirv::Decoration::FPRoundingMode),
947 spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm));
948 }
949 }
950 return success();
951 }
952};
953
954//===----------------------------------------------------------------------===//
955// CmpIOp
956//===----------------------------------------------------------------------===//
957
958/// Converts integer compare operation on i1 type operands to SPIR-V ops.
959class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
960public:
961 using Base::Base;
962
963 LogicalResult
964 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
965 ConversionPatternRewriter &rewriter) const override {
966 Type srcType = op.getLhs().getType();
967 if (!isBoolScalarOrVector(srcType))
968 return failure();
969 Type dstType = getTypeConverter()->convertType(srcType);
970 if (!dstType)
971 return getTypeConversionFailure(rewriter, op, srcType);
972
973 switch (op.getPredicate()) {
974 case arith::CmpIPredicate::eq: {
975 rewriter.replaceOpWithNewOp<spirv::LogicalEqualOp>(op, adaptor.getLhs(),
976 adaptor.getRhs());
977 return success();
978 }
979 case arith::CmpIPredicate::ne: {
980 rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
981 op, adaptor.getLhs(), adaptor.getRhs());
982 return success();
983 }
984 case arith::CmpIPredicate::uge:
985 case arith::CmpIPredicate::ugt:
986 case arith::CmpIPredicate::ule:
987 case arith::CmpIPredicate::ult: {
988 // There are no direct corresponding instructions in SPIR-V for such
989 // cases. Extend them to 32-bit and do comparision then.
990 Type type = rewriter.getI32Type();
991 if (auto vectorType = dyn_cast<VectorType>(dstType))
992 type = VectorType::get(vectorType.getShape(), type);
993 Value extLhs =
994 arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getLhs());
995 Value extRhs =
996 arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getRhs());
997
998 rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs,
999 extRhs);
1000 return success();
1001 }
1002 default:
1003 break;
1004 }
1005 return failure();
1006 }
1007};
1008
1009/// Converts integer compare operation to SPIR-V ops.
1010class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> {
1011public:
1012 using Base::Base;
1013
1014 LogicalResult
1015 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
1016 ConversionPatternRewriter &rewriter) const override {
1017 Type srcType = op.getLhs().getType();
1018 if (isBoolScalarOrVector(srcType))
1019 return failure();
1020 Type dstType = getTypeConverter()->convertType(srcType);
1021 if (!dstType)
1022 return getTypeConversionFailure(rewriter, op, srcType);
1023
1024 switch (op.getPredicate()) {
1025#define DISPATCH(cmpPredicate, spirvOp) \
1026 case cmpPredicate: \
1027 if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \
1028 !getElementTypeOrSelf(srcType).isIndex() && srcType != dstType && \
1029 !hasSameBitwidth(srcType, dstType)) { \
1030 return op.emitError( \
1031 "bitwidth emulation is not implemented yet on unsigned op"); \
1032 } \
1033 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
1034 adaptor.getRhs()); \
1035 return success();
1036
1037 DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
1038 DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
1039 DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
1040 DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
1041 DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
1042 DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
1043 DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
1044 DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
1045 DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
1046 DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
1047
1048#undef DISPATCH
1049 }
1050 return failure();
1051 }
1052};
1053
1054//===----------------------------------------------------------------------===//
1055// CmpFOpPattern
1056//===----------------------------------------------------------------------===//
1057
1058/// Converts floating-point comparison operations to SPIR-V ops.
1059class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> {
1060public:
1061 using Base::Base;
1062
1063 LogicalResult
1064 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1065 ConversionPatternRewriter &rewriter) const override {
1066 switch (op.getPredicate()) {
1067#define DISPATCH(cmpPredicate, spirvOp) \
1068 case cmpPredicate: \
1069 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
1070 adaptor.getRhs()); \
1071 return success();
1072
1073 // Ordered.
1074 DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
1075 DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
1076 DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
1077 DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
1078 DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
1079 DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
1080 // Unordered.
1081 DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
1082 DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
1083 DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
1084 DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
1085 DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
1086 DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
1087
1088#undef DISPATCH
1089
1090 default:
1091 break;
1092 }
1093 return failure();
1094 }
1095};
1096
1097/// Converts floating point NaN check to SPIR-V ops. This pattern requires
1098/// Kernel capability.
1099class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> {
1100public:
1101 using Base::Base;
1102
1103 LogicalResult
1104 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1105 ConversionPatternRewriter &rewriter) const override {
1106 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1107 rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(),
1108 adaptor.getRhs());
1109 return success();
1110 }
1111
1112 if (op.getPredicate() == arith::CmpFPredicate::UNO) {
1113 rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(),
1114 adaptor.getRhs());
1115 return success();
1116 }
1117
1118 return failure();
1119 }
1120};
1121
1122/// Converts floating point NaN check to SPIR-V ops. This pattern does not
1123/// require additional capability.
1124class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
1125public:
1126 using Base::Base;
1127
1128 LogicalResult
1129 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1130 ConversionPatternRewriter &rewriter) const override {
1131 if (op.getPredicate() != arith::CmpFPredicate::ORD &&
1132 op.getPredicate() != arith::CmpFPredicate::UNO)
1133 return failure();
1134
1135 Location loc = op.getLoc();
1136
1137 Value replace;
1138 if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1139 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1140 // Ordered comparsion checks if neither operand is NaN.
1141 replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter);
1142 } else {
1143 // Unordered comparsion checks if either operand is NaN.
1144 replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter);
1145 }
1146 } else {
1147 Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1148 Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1149
1150 replace = spirv::LogicalOrOp::create(rewriter, loc, lhsIsNan, rhsIsNan);
1151 if (op.getPredicate() == arith::CmpFPredicate::ORD)
1152 replace = spirv::LogicalNotOp::create(rewriter, loc, replace);
1153 }
1154
1155 rewriter.replaceOp(op, replace);
1156 return success();
1157 }
1158};
1159
1160//===----------------------------------------------------------------------===//
1161// AddUIExtendedOp
1162//===----------------------------------------------------------------------===//
1163
1164/// Converts arith.addui_extended to spirv.IAddCarry.
1165class AddUIExtendedOpPattern final
1166 : public OpConversionPattern<arith::AddUIExtendedOp> {
1167public:
1168 using Base::Base;
1169 LogicalResult
1170 matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
1171 ConversionPatternRewriter &rewriter) const override {
1172 Type dstElemTy = adaptor.getLhs().getType();
1173 Location loc = op->getLoc();
1174 Value result = spirv::IAddCarryOp::create(rewriter, loc, adaptor.getLhs(),
1175 adaptor.getRhs());
1176
1177 Value sumResult = spirv::CompositeExtractOp::create(rewriter, loc, result,
1178 llvm::ArrayRef(0));
1179 Value carryValue = spirv::CompositeExtractOp::create(rewriter, loc, result,
1180 llvm::ArrayRef(1));
1181
1182 // Convert the carry value to boolean.
1183 Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter);
1184 Value carryResult = spirv::IEqualOp::create(rewriter, loc, carryValue, one);
1185
1186 rewriter.replaceOp(op, {sumResult, carryResult});
1187 return success();
1188 }
1189};
1190
1191//===----------------------------------------------------------------------===//
1192// MulIExtendedOp
1193//===----------------------------------------------------------------------===//
1194
1195/// Converts arith.mul*i_extended to spirv.*MulExtended.
1196template <typename ArithMulOp, typename SPIRVMulOp>
1197class MulIExtendedOpPattern final : public OpConversionPattern<ArithMulOp> {
1198public:
1199 using OpConversionPattern<ArithMulOp>::OpConversionPattern;
1200 LogicalResult
1201 matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
1202 ConversionPatternRewriter &rewriter) const override {
1203 Location loc = op->getLoc();
1204 Value result =
1205 SPIRVMulOp::create(rewriter, loc, adaptor.getLhs(), adaptor.getRhs());
1206
1207 Value low = spirv::CompositeExtractOp::create(rewriter, loc, result,
1208 llvm::ArrayRef(0));
1209 Value high = spirv::CompositeExtractOp::create(rewriter, loc, result,
1210 llvm::ArrayRef(1));
1211
1212 rewriter.replaceOp(op, {low, high});
1213 return success();
1214 }
1215};
1216
1217//===----------------------------------------------------------------------===//
1218// SelectOp
1219//===----------------------------------------------------------------------===//
1220
1221/// Converts arith.select to spirv.Select.
1222class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> {
1223public:
1224 using Base::Base;
1225 LogicalResult
1226 matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
1227 ConversionPatternRewriter &rewriter) const override {
1228 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, adaptor.getCondition(),
1229 adaptor.getTrueValue(),
1230 adaptor.getFalseValue());
1231 return success();
1232 }
1233};
1234
1235//===----------------------------------------------------------------------===//
1236// MinimumFOp, MaximumFOp
1237//===----------------------------------------------------------------------===//
1238
1239/// Converts arith.maximumf/minimumf to spirv.GL.FMax/FMin or
1240/// spirv.CL.fmax/fmin.
1241template <typename Op, typename SPIRVOp>
1242class MinimumMaximumFOpPattern final : public OpConversionPattern<Op> {
1243public:
1244 using OpConversionPattern<Op>::OpConversionPattern;
1245 LogicalResult
1246 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
1247 ConversionPatternRewriter &rewriter) const override {
1248 auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
1249 Type dstType = converter->convertType(op.getType());
1250 if (!dstType)
1251 return getTypeConversionFailure(rewriter, op);
1252
1253 // arith.maximumf/minimumf:
1254 // "if one of the arguments is NaN, then the result is also NaN."
1255 // spirv.GL.FMax/FMin
1256 // "which operand is the result is undefined if one of the operands
1257 // is a NaN."
1258 // spirv.CL.fmax/fmin:
1259 // "If one argument is a NaN, Fmin returns the other argument."
1260
1261 Location loc = op.getLoc();
1262 Value spirvOp =
1263 SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands());
1264
1265 if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1266 rewriter.replaceOp(op, spirvOp);
1267 return success();
1268 }
1269
1270 Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1271 Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1272
1273 Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan,
1274 adaptor.getLhs(), spirvOp);
1275 Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan,
1276 adaptor.getRhs(), select1);
1277
1278 rewriter.replaceOp(op, select2);
1279 return success();
1280 }
1281};
1282
1283//===----------------------------------------------------------------------===//
1284// MinNumFOp, MaxNumFOp
1285//===----------------------------------------------------------------------===//
1286
1287/// Converts arith.maxnumf/minnumf to spirv.GL.FMax/FMin or
1288/// spirv.CL.fmax/fmin.
1289template <typename Op, typename SPIRVOp>
1290class MinNumMaxNumFOpPattern final : public OpConversionPattern<Op> {
1291 template <typename TargetOp>
1292 constexpr bool shouldInsertNanGuards() const {
1293 return llvm::is_one_of<TargetOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value;
1294 }
1295
1296public:
1297 using OpConversionPattern<Op>::OpConversionPattern;
1298 LogicalResult
1299 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
1300 ConversionPatternRewriter &rewriter) const override {
1301 auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
1302 Type dstType = converter->convertType(op.getType());
1303 if (!dstType)
1304 return getTypeConversionFailure(rewriter, op);
1305
1306 // arith.maxnumf/minnumf:
1307 // "If one of the arguments is NaN, then the result is the other
1308 // argument."
1309 // spirv.GL.FMax/FMin
1310 // "which operand is the result is undefined if one of the operands
1311 // is a NaN."
1312 // spirv.CL.fmax/fmin:
1313 // "If one argument is a NaN, Fmin returns the other argument."
1314
1315 Location loc = op.getLoc();
1316 Value spirvOp =
1317 SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands());
1318
1319 if (!shouldInsertNanGuards<SPIRVOp>() ||
1320 bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1321 rewriter.replaceOp(op, spirvOp);
1322 return success();
1323 }
1324
1325 Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1326 Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1327
1328 Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan,
1329 adaptor.getRhs(), spirvOp);
1330 Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan,
1331 adaptor.getLhs(), select1);
1332
1333 rewriter.replaceOp(op, select2);
1334 return success();
1335 }
1336};
1337
1338} // namespace
1339
1340//===----------------------------------------------------------------------===//
1341// Pattern Population
1342//===----------------------------------------------------------------------===//
1343
1345 const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
1346 // clang-format off
1347 patterns.add<
1348 ConstantCompositeOpPattern,
1349 ConstantScalarOpPattern,
1350 ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>,
1351 ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>,
1352 ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>,
1356 RemSIOpGLPattern, RemSIOpCLPattern,
1357 BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
1358 BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
1359 XOrIOpLogicalPattern, XOrIOpBooleanPattern,
1360 ElementwiseArithOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
1369 ExtUIPattern, ExtUII1Pattern,
1370 ExtSIPattern, ExtSII1Pattern,
1371 TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
1372 TruncIPattern, TruncII1Pattern,
1373 TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
1374 TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern,
1375 TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
1376 TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
1377 TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
1378 TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
1379 IndexCastIndexI1Pattern, IndexCastI1IndexPattern,
1380 TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
1381 TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
1382 CmpIOpBooleanPattern, CmpIOpPattern,
1383 CmpFOpNanNonePattern, CmpFOpPattern,
1384 AddUIExtendedOpPattern,
1385 MulIExtendedOpPattern<arith::MulSIExtendedOp, spirv::SMulExtendedOp>,
1386 MulIExtendedOpPattern<arith::MulUIExtendedOp, spirv::UMulExtendedOp>,
1387 SelectOpPattern,
1388
1389 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::GLFMaxOp>,
1390 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::GLFMinOp>,
1391 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::GLFMaxOp>,
1392 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::GLFMinOp>,
1397
1398 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::CLFMaxOp>,
1399 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::CLFMinOp>,
1400 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::CLFMaxOp>,
1401 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::CLFMinOp>,
1406 >(typeConverter, patterns.getContext());
1407 // clang-format on
1408
1409 // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel
1410 // capability is available.
1411 patterns.add<CmpFOpNanKernelPattern>(typeConverter, patterns.getContext(),
1412 /*benefit=*/2);
1413}
1414
1415//===----------------------------------------------------------------------===//
1416// Pass Definition
1417//===----------------------------------------------------------------------===//
1418
1419namespace {
1420struct ConvertArithToSPIRVPass
1421 : public impl::ConvertArithToSPIRVPassBase<ConvertArithToSPIRVPass> {
1422 using Base::Base;
1423
1424 void runOnOperation() override {
1425 Operation *op = getOperation();
1427 std::unique_ptr<SPIRVConversionTarget> target =
1428 SPIRVConversionTarget::get(targetAttr);
1429
1431 options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
1432 options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
1433 SPIRVTypeConverter typeConverter(targetAttr, options);
1434
1435 // Use UnrealizedConversionCast as the bridge so that we don't need to pull
1436 // in patterns for other dialects.
1437 target->addLegalOp<UnrealizedConversionCastOp>();
1438
1439 // Fail hard when there are any remaining 'arith' ops.
1440 target->addIllegalDialect<arith::ArithDialect>();
1441
1444
1445 if (failed(applyPartialConversion(op, *target, std::move(patterns))))
1446 signalPassFailure();
1447 }
1448};
1449} // namespace
return success()
static bool hasSameBitwidth(Type a, Type b)
Returns true if scalar/vector type a and b have the same number of bitwidth.
static Value getScalarOrVectorConstInt(Type type, uint64_t value, OpBuilder &builder, Location loc)
Creates a scalar/vector integer constant.
static LogicalResult getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op, Type srcType)
Returns a source type conversion failure for srcType and operation op.
static IntegerAttr getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType, ConversionPatternRewriter &rewriter)
static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder)
Converts the given srcAttr into a boolean attribute if it holds an integral value.
static bool isBoolScalarOrVector(Type type)
Returns true if the given type is a boolean scalar or vector type.
#define DISPATCH(cmpPredicate, spirvOp)
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
static llvm::ManagedStatic< PassManagerOptions > options
This class represents a processed binary blob of data.
Definition AsmState.h:91
ArrayRef< char > getData() const
Return the raw underlying data of this blob.
Definition AsmState.h:145
Attributes are known-constant values of operations.
Definition Attributes.h:25
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:228
BoolAttr getBoolAttr(bool value)
Definition Builders.cpp:100
FloatAttr getF32FloatAttr(float value)
Definition Builders.cpp:246
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
static bool isValidRawBuffer(ShapedType type, ArrayRef< char > rawBuffer, bool &detectedSplat)
Returns true if the given buffer is a valid raw buffer for the given type.
DenseElementsAttr reshape(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but has been reshaped...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:207
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
result_type_range getResultTypes()
Definition Operation.h:428
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:54
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition Types.cpp:120
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:56
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:122
Type front()
Return first type in the range.
Definition TypeRange.h:152
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
An attribute that specifies the target version, allowed extensions and capabilities,...
NestedPattern Op(FilterFunctionType filter=defaultFilterFunction)
void populateArithToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
std::string getDecorationString(Decoration decoration)
Converts a SPIR-V Decoration enum value to its snake_case string representation for use in MLIR attri...
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.
Definition Pattern.h:24