MLIR 23.0.0git
ArithToSPIRV.cpp
Go to the documentation of this file.
1//===- ArithToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
21#include "llvm/ADT/APInt.h"
22#include "llvm/ADT/ArrayRef.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/Support/Debug.h"
25#include "llvm/Support/MathExtras.h"
26#include <cassert>
27#include <memory>
28
29namespace mlir {
30#define GEN_PASS_DEF_CONVERTARITHTOSPIRVPASS
31#include "mlir/Conversion/Passes.h.inc"
32} // namespace mlir
33
34#define DEBUG_TYPE "arith-to-spirv-pattern"
35
36using namespace mlir;
37
38//===----------------------------------------------------------------------===//
39// Conversion Helpers
40//===----------------------------------------------------------------------===//
41
42/// Converts the given `srcAttr` into a boolean attribute if it holds an
43/// integral value. Returns null attribute if conversion fails.
44static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) {
45 if (auto boolAttr = dyn_cast<BoolAttr>(srcAttr))
46 return boolAttr;
47 if (auto intAttr = dyn_cast<IntegerAttr>(srcAttr))
48 return builder.getBoolAttr(intAttr.getValue().getBoolValue());
49 return {};
50}
51
52/// Converts the given `srcAttr` to a new attribute of the given `dstType`.
53/// Returns null attribute if conversion fails.
54static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType,
55 Builder builder) {
56 // If the source number uses less active bits than the target bitwidth, then
57 // it should be safe to convert.
58 if (srcAttr.getValue().isIntN(dstType.getWidth()))
59 return builder.getIntegerAttr(dstType, srcAttr.getInt());
60
61 // XXX: Try again by interpreting the source number as a signed value.
62 // Although integers in the standard dialect are signless, they can represent
63 // a signed number. It's the operation decides how to interpret. This is
64 // dangerous, but it seems there is no good way of handling this if we still
65 // want to change the bitwidth. Emit a message at least.
66 if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
67 auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt());
68 LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '"
69 << dstAttr << "' for type '" << dstType << "'\n");
70 return dstAttr;
71 }
72
73 LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr
74 << "' illegal: cannot fit into target type '"
75 << dstType << "'\n");
76 return {};
77}
78
79/// Converts the given `srcAttr` to a new attribute of the given `dstType`.
80/// Returns null attribute if `dstType` is not 32-bit or conversion fails.
81static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
82 Builder builder) {
83 // Only support converting to float for now.
84 if (!dstType.isF32())
85 return FloatAttr();
86
87 // Try to convert the source floating-point number to single precision.
88 APFloat dstVal = srcAttr.getValue();
89 bool losesInfo = false;
90 APFloat::opStatus status =
91 dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
92 if (status != APFloat::opOK || losesInfo) {
93 LLVM_DEBUG(llvm::dbgs()
94 << srcAttr << " illegal: cannot fit into converted type '"
95 << dstType << "'\n");
96 return FloatAttr();
97 }
98
99 return builder.getF32FloatAttr(dstVal.convertToFloat());
100}
101
102// Get in IntegerAttr from FloatAttr while preserving the bits.
103// Useful for converting float constants to integer constants while preserving
104// the bits.
105static IntegerAttr
106getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType,
107 ConversionPatternRewriter &rewriter) {
108 APFloat floatVal = floatAttr.getValue();
109 APInt intVal = floatVal.bitcastToAPInt();
110 return rewriter.getIntegerAttr(dstType, intVal);
111}
112
113/// Returns true if the given `type` is a boolean scalar or vector type.
114static bool isBoolScalarOrVector(Type type) {
115 assert(type && "Not a valid type");
116 if (type.isInteger(1))
117 return true;
118
119 if (auto vecType = dyn_cast<VectorType>(type))
120 return vecType.getElementType().isInteger(1);
121
122 return false;
123}
124
125/// Creates a scalar/vector integer constant.
126static Value getScalarOrVectorConstInt(Type type, uint64_t value,
127 OpBuilder &builder, Location loc) {
128 if (auto vectorType = dyn_cast<VectorType>(type)) {
129 Attribute element = IntegerAttr::get(vectorType.getElementType(), value);
130 auto attr = SplatElementsAttr::get(vectorType, element);
131 return spirv::ConstantOp::create(builder, loc, vectorType, attr);
132 }
133
134 if (auto intType = dyn_cast<IntegerType>(type))
135 return spirv::ConstantOp::create(builder, loc, type,
136 builder.getIntegerAttr(type, value));
137
138 return nullptr;
139}
140
141/// Returns true if scalar/vector type `a` and `b` have the same number of
142/// bitwidth.
143static bool hasSameBitwidth(Type a, Type b) {
144 auto getNumBitwidth = [](Type type) {
145 unsigned bw = 0;
146 if (type.isIntOrFloat())
147 bw = type.getIntOrFloatBitWidth();
148 else if (auto vecType = dyn_cast<VectorType>(type))
149 bw = vecType.getElementTypeBitWidth() * vecType.getNumElements();
150 return bw;
151 };
152 unsigned aBW = getNumBitwidth(a);
153 unsigned bBW = getNumBitwidth(b);
154 return aBW != 0 && bBW != 0 && aBW == bBW;
155}
156
157/// Returns a source type conversion failure for `srcType` and operation `op`.
158static LogicalResult
159getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op,
160 Type srcType) {
161 return rewriter.notifyMatchFailure(
162 op->getLoc(),
163 llvm::formatv("failed to convert source type '{0}'", srcType));
164}
165
166/// Returns a source type conversion failure for the result type of `op`.
167static LogicalResult
168getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op) {
169 assert(op->getNumResults() == 1);
170 return getTypeConversionFailure(rewriter, op, op->getResultTypes().front());
171}
172
173namespace {
174
175/// Converts elementwise unary, binary and ternary arith operations to SPIR-V
176/// operations. Op can potentially support overflow flags.
177template <typename Op, typename SPIRVOp>
178struct ElementwiseArithOpPattern final : OpConversionPattern<Op> {
179 using OpConversionPattern<Op>::OpConversionPattern;
180
181 LogicalResult
182 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
183 ConversionPatternRewriter &rewriter) const override {
184 assert(adaptor.getOperands().size() <= 3);
185 auto converter = this->template getTypeConverter<SPIRVTypeConverter>();
186 Type dstType = converter->convertType(op.getType());
187 if (!dstType) {
188 return rewriter.notifyMatchFailure(
189 op->getLoc(),
190 llvm::formatv("failed to convert type {0} for SPIR-V", op.getType()));
191 }
192
193 if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
194 !getElementTypeOrSelf(op.getType()).isIndex() &&
195 dstType != op.getType()) {
196 return op.emitError("bitwidth emulation is not implemented yet on "
197 "unsigned op pattern version");
198 }
199
200 auto overflowFlags = arith::IntegerOverflowFlags::none;
201 if (auto overflowIface =
202 dyn_cast<arith::ArithIntegerOverflowFlagsInterface>(*op)) {
203 if (converter->getTargetEnv().allows(
204 spirv::Extension::SPV_KHR_no_integer_wrap_decoration))
205 overflowFlags = overflowIface.getOverflowAttr().getValue();
206 }
207
208 auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
209 op, dstType, adaptor.getOperands());
210
211 if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nsw))
212 newOp->setAttr(getDecorationString(spirv::Decoration::NoSignedWrap),
213 rewriter.getUnitAttr());
214
215 if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nuw))
216 newOp->setAttr(getDecorationString(spirv::Decoration::NoUnsignedWrap),
217 rewriter.getUnitAttr());
218
219 return success();
220 }
221};
222
223//===----------------------------------------------------------------------===//
224// ConstantOp
225//===----------------------------------------------------------------------===//
226
227/// Converts composite arith.constant operation to spirv.Constant.
228struct ConstantCompositeOpPattern final
229 : public OpConversionPattern<arith::ConstantOp> {
230 using Base::Base;
231
232 LogicalResult
233 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
234 ConversionPatternRewriter &rewriter) const override {
235 auto srcType = dyn_cast<ShapedType>(constOp.getType());
236 if (!srcType || srcType.getNumElements() == 1)
237 return failure();
238
239 // arith.constant should only have vector or tensor types. This is a MLIR
240 // wide problem at the moment.
241 if (!isa<VectorType, RankedTensorType>(srcType))
242 return rewriter.notifyMatchFailure(constOp, "unsupported ShapedType");
243
244 Type dstType = getTypeConverter()->convertType(srcType);
245 if (!dstType)
246 return failure();
247
248 // Import the resource into the IR to make use of the special handling of
249 // element types later on.
250 mlir::DenseElementsAttr dstElementsAttr;
251 if (auto denseElementsAttr =
252 dyn_cast<DenseElementsAttr>(constOp.getValue())) {
253 dstElementsAttr = denseElementsAttr;
254 } else if (auto resourceAttr =
255 dyn_cast<DenseResourceElementsAttr>(constOp.getValue())) {
256
257 AsmResourceBlob *blob = resourceAttr.getRawHandle().getBlob();
258 if (!blob)
259 return constOp->emitError("could not find resource blob");
260
261 ArrayRef<char> ptr = blob->getData();
262
263 // Check that the buffer meets the requirements to get converted to a
264 // DenseElementsAttr
266 return constOp->emitError("resource is not a valid buffer");
267
268 dstElementsAttr =
269 DenseElementsAttr::getFromRawBuffer(resourceAttr.getType(), ptr);
270 } else {
271 return constOp->emitError("unsupported elements attribute");
272 }
273
274 ShapedType dstAttrType = dstElementsAttr.getType();
275
276 // If the composite type has more than one dimensions, perform
277 // linearization.
278 if (srcType.getRank() > 1) {
279 if (isa<RankedTensorType>(srcType)) {
280 dstAttrType = RankedTensorType::get(srcType.getNumElements(),
281 srcType.getElementType());
282 dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
283 } else {
284 // TODO: add support for large vectors.
285 return failure();
286 }
287 }
288
289 Type srcElemType = srcType.getElementType();
290 Type dstElemType;
291 // Tensor types are converted to SPIR-V array types; vector types are
292 // converted to SPIR-V vector/array types.
293 if (auto arrayType = dyn_cast<spirv::ArrayType>(dstType))
294 dstElemType = arrayType.getElementType();
295 else
296 dstElemType = cast<VectorType>(dstType).getElementType();
297
298 // If the source and destination element types are different, perform
299 // attribute conversion.
300 if (srcElemType != dstElemType) {
302 if (isa<FloatType>(srcElemType)) {
303 for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
304 Attribute dstAttr = nullptr;
305 // Handle 8-bit float conversion to 8-bit integer.
306 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
307 if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
308 srcElemType.getIntOrFloatBitWidth() == 8 &&
309 isa<IntegerType>(dstElemType)) {
310 dstAttr =
311 getIntegerAttrFromFloatAttr(srcAttr, dstElemType, rewriter);
312 } else {
313 dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstElemType),
314 rewriter);
315 }
316 if (!dstAttr)
317 return failure();
318 elements.push_back(dstAttr);
319 }
320 } else if (srcElemType.isInteger(1)) {
321 return failure();
322 } else {
323 for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) {
324 IntegerAttr dstAttr = convertIntegerAttr(
325 srcAttr, cast<IntegerType>(dstElemType), rewriter);
326 if (!dstAttr)
327 return failure();
328 elements.push_back(dstAttr);
329 }
330 }
331
332 // Unfortunately, we cannot use dialect-specific types for element
333 // attributes; element attributes only works with builtin types. So we
334 // need to prepare another converted builtin types for the destination
335 // elements attribute.
336 if (isa<RankedTensorType>(dstAttrType))
337 dstAttrType =
338 RankedTensorType::get(dstAttrType.getShape(), dstElemType);
339 else
340 dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
341
342 dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
343 }
344
345 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
346 dstElementsAttr);
347 return success();
348 }
349};
350
351/// Converts scalar arith.constant operation to spirv.Constant.
352struct ConstantScalarOpPattern final
353 : public OpConversionPattern<arith::ConstantOp> {
354 using Base::Base;
355
356 LogicalResult
357 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
358 ConversionPatternRewriter &rewriter) const override {
359 Type srcType = constOp.getType();
360 if (auto shapedType = dyn_cast<ShapedType>(srcType)) {
361 if (shapedType.getNumElements() != 1)
362 return failure();
363 srcType = shapedType.getElementType();
364 }
365 if (!srcType.isIntOrIndexOrFloat())
366 return failure();
367
368 Attribute cstAttr = constOp.getValue();
369 if (auto elementsAttr = dyn_cast<DenseElementsAttr>(cstAttr))
370 cstAttr = elementsAttr.getSplatValue<Attribute>();
371
372 Type dstType = getTypeConverter()->convertType(srcType);
373 if (!dstType)
374 return failure();
375
376 // Floating-point types.
377 if (isa<FloatType>(srcType)) {
378 auto srcAttr = cast<FloatAttr>(cstAttr);
379 Attribute dstAttr = srcAttr;
380
381 // Floating-point types not supported in the target environment are all
382 // converted to float type.
383 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
384 if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
385 srcType.getIntOrFloatBitWidth() == 8 && isa<IntegerType>(dstType) &&
386 dstType.getIntOrFloatBitWidth() == 8) {
387 // If the source is an 8-bit float, convert it to a 8-bit integer.
388 dstAttr = getIntegerAttrFromFloatAttr(srcAttr, dstType, rewriter);
389 if (!dstAttr)
390 return failure();
391 } else if (srcType != dstType) {
392 dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter);
393 if (!dstAttr)
394 return failure();
395 }
396
397 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
398 return success();
399 }
400
401 // Bool type.
402 if (srcType.isInteger(1)) {
403 // arith.constant can use 0/1 instead of true/false for i1 values. We need
404 // to handle that here.
405 auto dstAttr = convertBoolAttr(cstAttr, rewriter);
406 if (!dstAttr)
407 return failure();
408 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
409 return success();
410 }
411
412 // IndexType or IntegerType. Index values are converted to 32-bit integer
413 // values when converting to SPIR-V.
414 auto srcAttr = cast<IntegerAttr>(cstAttr);
415 IntegerAttr dstAttr =
416 convertIntegerAttr(srcAttr, cast<IntegerType>(dstType), rewriter);
417 if (!dstAttr)
418 return failure();
419 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
420 return success();
421 }
422};
423
424//===----------------------------------------------------------------------===//
425// RemSIOp
426//===----------------------------------------------------------------------===//
427
428/// Returns signed remainder for `lhs` and `rhs` and lets the result follow
429/// the sign of `signOperand`.
430///
431/// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment
432/// spec, "for the OpSRem and OpSMod instructions, if either operand is negative
433/// the result is undefined." So we cannot directly use spirv.SRem/spirv.SMod
434/// if either operand can be negative. Emulate it via spirv.UMod.
435template <typename SignedAbsOp>
436static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
437 Value signOperand, OpBuilder &builder) {
438 assert(lhs.getType() == rhs.getType());
439 assert(lhs == signOperand || rhs == signOperand);
440
441 Type type = lhs.getType();
442
443 // Calculate the remainder with spirv.UMod.
444 Value lhsAbs = SignedAbsOp::create(builder, loc, type, lhs);
445 Value rhsAbs = SignedAbsOp::create(builder, loc, type, rhs);
446 Value abs = spirv::UModOp::create(builder, loc, lhsAbs, rhsAbs);
447
448 // Fix the sign.
449 Value isPositive;
450 if (lhs == signOperand)
451 isPositive = spirv::IEqualOp::create(builder, loc, lhs, lhsAbs);
452 else
453 isPositive = spirv::IEqualOp::create(builder, loc, rhs, rhsAbs);
454 Value absNegate = spirv::SNegateOp::create(builder, loc, type, abs);
455 return spirv::SelectOp::create(builder, loc, type, isPositive, abs,
456 absNegate);
457}
458
459/// Converts arith.remsi to GLSL SPIR-V ops.
460///
461/// This cannot be merged into the template unary/binary pattern due to Vulkan
462/// restrictions over spirv.SRem and spirv.SMod.
463struct RemSIOpGLPattern final : public OpConversionPattern<arith::RemSIOp> {
464 using Base::Base;
465
466 LogicalResult
467 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
468 ConversionPatternRewriter &rewriter) const override {
469 Value result = emulateSignedRemainder<spirv::CLSAbsOp>(
470 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
471 adaptor.getOperands()[0], rewriter);
472 rewriter.replaceOp(op, result);
473
474 return success();
475 }
476};
477
478/// Converts arith.remsi to OpenCL SPIR-V ops.
479struct RemSIOpCLPattern final : public OpConversionPattern<arith::RemSIOp> {
480 using Base::Base;
481
482 LogicalResult
483 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
484 ConversionPatternRewriter &rewriter) const override {
485 Value result = emulateSignedRemainder<spirv::GLSAbsOp>(
486 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
487 adaptor.getOperands()[0], rewriter);
488 rewriter.replaceOp(op, result);
489
490 return success();
491 }
492};
493
494//===----------------------------------------------------------------------===//
495// BitwiseOp
496//===----------------------------------------------------------------------===//
497
498/// Converts bitwise operations to SPIR-V operations. This is a special pattern
499/// other than the BinaryOpPatternPattern because if the operands are boolean
500/// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For
501/// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`.
502template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
503struct BitwiseOpPattern final : public OpConversionPattern<Op> {
504 using OpConversionPattern<Op>::OpConversionPattern;
505
506 LogicalResult
507 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
508 ConversionPatternRewriter &rewriter) const override {
509 assert(adaptor.getOperands().size() == 2);
510 Type dstType = this->getTypeConverter()->convertType(op.getType());
511 if (!dstType)
512 return getTypeConversionFailure(rewriter, op);
513
514 if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) {
515 rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(
516 op, dstType, adaptor.getOperands());
517 } else {
518 rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(
519 op, dstType, adaptor.getOperands());
520 }
521 return success();
522 }
523};
524
525//===----------------------------------------------------------------------===//
526// XOrIOp
527//===----------------------------------------------------------------------===//
528
529/// Converts arith.xori to SPIR-V operations.
530struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> {
531 using Base::Base;
532
533 LogicalResult
534 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
535 ConversionPatternRewriter &rewriter) const override {
536 assert(adaptor.getOperands().size() == 2);
537
538 if (isBoolScalarOrVector(adaptor.getOperands().front().getType()))
539 return failure();
540
541 Type dstType = getTypeConverter()->convertType(op.getType());
542 if (!dstType)
543 return getTypeConversionFailure(rewriter, op);
544
545 rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
546 adaptor.getOperands());
547
548 return success();
549 }
550};
551
552/// Converts arith.xori to SPIR-V operations if the type of source is i1 or
553/// vector of i1.
554struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
555 using Base::Base;
556
557 LogicalResult
558 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
559 ConversionPatternRewriter &rewriter) const override {
560 assert(adaptor.getOperands().size() == 2);
561
562 if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
563 return failure();
564
565 Type dstType = getTypeConverter()->convertType(op.getType());
566 if (!dstType)
567 return getTypeConversionFailure(rewriter, op);
568
569 rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
570 op, dstType, adaptor.getOperands());
571 return success();
572 }
573};
574
575//===----------------------------------------------------------------------===//
576// UIToFPOp
577//===----------------------------------------------------------------------===//
578
579/// Converts arith.uitofp to spirv.Select if the type of source is i1 or vector
580/// of i1.
581struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
582 using Base::Base;
583
584 LogicalResult
585 matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
586 ConversionPatternRewriter &rewriter) const override {
587 Type srcType = adaptor.getOperands().front().getType();
588 if (!isBoolScalarOrVector(srcType))
589 return failure();
590
591 Type dstType = getTypeConverter()->convertType(op.getType());
592 if (!dstType)
593 return getTypeConversionFailure(rewriter, op);
594
595 Location loc = op.getLoc();
596 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
597 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
598 rewriter.replaceOpWithNewOp<spirv::SelectOp>(
599 op, dstType, adaptor.getOperands().front(), one, zero);
600 return success();
601 }
602};
603
604/// Converts arith.uitofp/arith.sitofp to spirv.ConvertUToF/spirv.ConvertSToF.
605/// When the source integer type was widened during type conversion (e.g., i8
606/// emulated as i32), the upper bits of the widened value may contain garbage.
607/// This pattern cleans the upper bits before the conversion:
608/// - For unsigned (IsSigned=false): mask with BitwiseAnd.
609/// - For signed (IsSigned=true): sign-extend via ShiftLeftLogical +
610/// ShiftRightArithmetic.
611template <typename ArithOp, typename SPIRVOp, bool IsSigned>
612struct IntToFPPattern final : public OpConversionPattern<ArithOp> {
613 using OpConversionPattern<ArithOp>::OpConversionPattern;
614
615 LogicalResult
616 matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
617 ConversionPatternRewriter &rewriter) const override {
618 Type srcType = adaptor.getOperands().front().getType();
619 if (isBoolScalarOrVector(srcType))
620 return failure();
621
622 Type dstType = this->getTypeConverter()->convertType(op.getType());
623 if (!dstType)
624 return getTypeConversionFailure(rewriter, op);
625
626 // Check if the source integer type was widened during type conversion.
627 unsigned originalBitwidth =
628 getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth();
629 unsigned convertedBitwidth =
631
632 if (originalBitwidth >= convertedBitwidth) {
633 rewriter.replaceOpWithNewOp<SPIRVOp>(op, dstType, adaptor.getOperands());
634 return success();
635 }
636
637 // The source was widened. Clean the upper bits before converting.
638 Location loc = op.getLoc();
639 Value cleaned;
640 if constexpr (IsSigned) {
641 // Sign-extend by shifting left then arithmetic right.
642 unsigned shiftAmount = convertedBitwidth - originalBitwidth;
643 Value shiftSize =
644 getScalarOrVectorConstInt(srcType, shiftAmount, rewriter, loc);
645 Value shifted = spirv::ShiftLeftLogicalOp::create(
646 rewriter, loc, srcType, adaptor.getIn(), shiftSize);
647 cleaned = spirv::ShiftRightArithmeticOp::create(rewriter, loc, srcType,
648 shifted, shiftSize);
649 } else {
650 // Zero-extend by masking off the upper bits.
651 Value mask = getScalarOrVectorConstInt(
652 srcType, llvm::maskTrailingOnes<uint64_t>(originalBitwidth), rewriter,
653 loc);
654 cleaned = spirv::BitwiseAndOp::create(rewriter, loc, srcType,
655 adaptor.getIn(), mask);
656 }
657 rewriter.replaceOpWithNewOp<SPIRVOp>(op, dstType, cleaned);
658 return success();
659 }
660};
661
662//===----------------------------------------------------------------------===//
663// IndexCastOp
664//===----------------------------------------------------------------------===//
665
666/// Converts arith.index_cast to spirv.INotEqual if the target type is i1.
667struct IndexCastIndexI1Pattern final
668 : public OpConversionPattern<arith::IndexCastOp> {
669 using Base::Base;
670
671 LogicalResult
672 matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
673 ConversionPatternRewriter &rewriter) const override {
674 if (!isBoolScalarOrVector(op.getType()))
675 return failure();
676
677 Type dstType = getTypeConverter()->convertType(op.getType());
678 if (!dstType)
679 return getTypeConversionFailure(rewriter, op);
680
681 Location loc = op.getLoc();
682 Value zeroIdx =
683 spirv::ConstantOp::getZero(adaptor.getIn().getType(), loc, rewriter);
684 rewriter.replaceOpWithNewOp<spirv::INotEqualOp>(op, dstType, zeroIdx,
685 adaptor.getIn());
686 return success();
687 }
688};
689
690/// Converts arith.index_cast to spirv.Select if the source type is i1.
691struct IndexCastI1IndexPattern final
692 : public OpConversionPattern<arith::IndexCastOp> {
693 using Base::Base;
694
695 LogicalResult
696 matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
697 ConversionPatternRewriter &rewriter) const override {
698 if (!isBoolScalarOrVector(adaptor.getIn().getType()))
699 return failure();
700
701 Type dstType = getTypeConverter()->convertType(op.getType());
702 if (!dstType)
703 return getTypeConversionFailure(rewriter, op);
704
705 Location loc = op.getLoc();
706 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
707 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
708 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, adaptor.getIn(),
709 one, zero);
710 return success();
711 }
712};
713
714//===----------------------------------------------------------------------===//
715// ExtSIOp
716//===----------------------------------------------------------------------===//
717
718/// Converts arith.extsi to spirv.Select if the type of source is i1 or vector
719/// of i1.
720struct ExtSII1Pattern final : public OpConversionPattern<arith::ExtSIOp> {
721 using Base::Base;
722
723 LogicalResult
724 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
725 ConversionPatternRewriter &rewriter) const override {
726 Value operand = adaptor.getIn();
727 if (!isBoolScalarOrVector(operand.getType()))
728 return failure();
729
730 Location loc = op.getLoc();
731 Type dstType = getTypeConverter()->convertType(op.getType());
732 if (!dstType)
733 return getTypeConversionFailure(rewriter, op);
734
735 Value allOnes;
736 if (auto intTy = dyn_cast<IntegerType>(dstType)) {
737 unsigned componentBitwidth = intTy.getWidth();
738 allOnes = spirv::ConstantOp::create(
739 rewriter, loc, intTy,
740 rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth)));
741 } else if (auto vectorTy = dyn_cast<VectorType>(dstType)) {
742 unsigned componentBitwidth = vectorTy.getElementTypeBitWidth();
743 allOnes = spirv::ConstantOp::create(
744 rewriter, loc, vectorTy,
745 SplatElementsAttr::get(vectorTy,
746 APInt::getAllOnes(componentBitwidth)));
747 } else {
748 return rewriter.notifyMatchFailure(
749 loc, llvm::formatv("unhandled type: {0}", dstType));
751
752 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
753 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, operand, allOnes,
754 zero);
755 return success();
756 }
757};
759/// Converts arith.extsi to spirv.Select if the type of source is neither i1 nor
760/// vector of i1.
761struct ExtSIPattern final : public OpConversionPattern<arith::ExtSIOp> {
762 using Base::Base;
763
764 LogicalResult
765 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
766 ConversionPatternRewriter &rewriter) const override {
767 Type srcType = adaptor.getIn().getType();
768 if (isBoolScalarOrVector(srcType))
769 return failure();
770
771 Type dstType = getTypeConverter()->convertType(op.getType());
772 if (!dstType)
773 return getTypeConversionFailure(rewriter, op);
774
775 if (dstType == srcType) {
776 // We can have the same source and destination type due to type emulation.
777 // Perform bit shifting to make sure we have the proper leading set bits.
778
779 unsigned srcBW =
781 unsigned dstBW =
783 assert(srcBW < dstBW);
784 Value shiftSize = getScalarOrVectorConstInt(dstType, dstBW - srcBW,
785 rewriter, op.getLoc());
786 if (!shiftSize)
787 return rewriter.notifyMatchFailure(op, "unsupported type for shift");
788
789 // First shift left to sequeeze out all leading bits beyond the original
790 // bitwidth. Here we need to use the original source and result type's
791 // bitwidth.
792 auto shiftLOp = spirv::ShiftLeftLogicalOp::create(
793 rewriter, op.getLoc(), dstType, adaptor.getIn(), shiftSize);
794
795 // Then we perform arithmetic right shift to make sure we have the right
796 // sign bits for negative values.
797 rewriter.replaceOpWithNewOp<spirv::ShiftRightArithmeticOp>(
798 op, dstType, shiftLOp, shiftSize);
799 } else {
800 rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
801 adaptor.getOperands());
802 }
804 return success();
805 }
806};
807
808//===----------------------------------------------------------------------===//
809// ExtUIOp
810//===----------------------------------------------------------------------===//
811
812/// Converts arith.extui to spirv.Select if the type of source is i1 or vector
813/// of i1.
814struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> {
815 using Base::Base;
816
817 LogicalResult
818 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
819 ConversionPatternRewriter &rewriter) const override {
820 Type srcType = adaptor.getOperands().front().getType();
821 if (!isBoolScalarOrVector(srcType))
822 return failure();
823
824 Type dstType = getTypeConverter()->convertType(op.getType());
825 if (!dstType)
826 return getTypeConversionFailure(rewriter, op);
827
828 Location loc = op.getLoc();
829 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
830 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
831 rewriter.replaceOpWithNewOp<spirv::SelectOp>(
832 op, dstType, adaptor.getOperands().front(), one, zero);
833 return success();
834 }
835};
836
837/// Converts arith.extui for cases where the type of source is neither i1 nor
838/// vector of i1.
839struct ExtUIPattern final : public OpConversionPattern<arith::ExtUIOp> {
840 using Base::Base;
841
842 LogicalResult
843 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
844 ConversionPatternRewriter &rewriter) const override {
845 Type srcType = adaptor.getIn().getType();
846 if (isBoolScalarOrVector(srcType))
847 return failure();
848
849 Type dstType = getTypeConverter()->convertType(op.getType());
850 if (!dstType)
851 return getTypeConversionFailure(rewriter, op);
852
853 if (dstType == srcType) {
854 // We can have the same source and destination type due to type emulation.
855 // Perform bit masking to make sure we don't pollute downstream consumers
856 // with unwanted bits. Here we need to use the original source type's
857 // bitwidth.
858 unsigned bitwidth =
859 getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth();
860 Value mask = getScalarOrVectorConstInt(
861 dstType, llvm::maskTrailingOnes<uint64_t>(bitwidth), rewriter,
862 op.getLoc());
863 if (!mask)
864 return rewriter.notifyMatchFailure(op, "unsupported type for mask");
865 rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
866 adaptor.getIn(), mask);
867 } else {
868 rewriter.replaceOpWithNewOp<spirv::UConvertOp>(op, dstType,
869 adaptor.getOperands());
870 }
871 return success();
872 }
873};
874
875//===----------------------------------------------------------------------===//
876// TruncIOp
877//===----------------------------------------------------------------------===//
878
879/// Converts arith.trunci to spirv.Select if the type of result is i1 or vector
880/// of i1.
881struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
882 using Base::Base;
883
884 LogicalResult
885 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
886 ConversionPatternRewriter &rewriter) const override {
887 Type dstType = getTypeConverter()->convertType(op.getType());
888 if (!dstType)
889 return getTypeConversionFailure(rewriter, op);
890
891 if (!isBoolScalarOrVector(dstType))
892 return failure();
893
894 Location loc = op.getLoc();
895 auto srcType = adaptor.getOperands().front().getType();
896 // Check if (x & 1) == 1.
897 Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
898 Value maskedSrc = spirv::BitwiseAndOp::create(
899 rewriter, loc, srcType, adaptor.getOperands()[0], mask);
900 Value isOne = spirv::IEqualOp::create(rewriter, loc, maskedSrc, mask);
901
902 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
903 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
904 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero);
905 return success();
906 }
907};
908
909/// Converts arith.trunci for cases where the type of result is neither i1
910/// nor vector of i1.
911struct TruncIPattern final : public OpConversionPattern<arith::TruncIOp> {
912 using Base::Base;
913
914 LogicalResult
915 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
916 ConversionPatternRewriter &rewriter) const override {
917 Type srcType = adaptor.getIn().getType();
918 Type dstType = getTypeConverter()->convertType(op.getType());
919 if (!dstType)
920 return getTypeConversionFailure(rewriter, op);
921
922 if (isBoolScalarOrVector(dstType))
923 return failure();
924
925 if (dstType == srcType) {
926 // We can have the same source and destination type due to type emulation.
927 // Perform bit masking to make sure we don't pollute downstream consumers
928 // with unwanted bits. Here we need to use the original result type's
929 // bitwidth.
930 unsigned bw = getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth();
931 Value mask = getScalarOrVectorConstInt(
932 dstType, llvm::maskTrailingOnes<uint64_t>(bw), rewriter, op.getLoc());
933 if (!mask)
934 return rewriter.notifyMatchFailure(op, "unsupported type for mask");
935 rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
936 adaptor.getIn(), mask);
937 } else {
938 // Given this is truncation, either SConvertOp or UConvertOp works.
939 rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
940 adaptor.getOperands());
941 }
942 return success();
943 }
944};
945
946//===----------------------------------------------------------------------===//
947// TypeCastingOp
948//===----------------------------------------------------------------------===//
949
950static std::optional<spirv::FPRoundingMode>
951convertArithRoundingModeToSPIRV(arith::RoundingMode roundingMode) {
952 switch (roundingMode) {
953 case arith::RoundingMode::downward:
954 return spirv::FPRoundingMode::RTN;
955 case arith::RoundingMode::to_nearest_even:
956 return spirv::FPRoundingMode::RTE;
957 case arith::RoundingMode::toward_zero:
958 return spirv::FPRoundingMode::RTZ;
959 case arith::RoundingMode::upward:
960 return spirv::FPRoundingMode::RTP;
961 case arith::RoundingMode::to_nearest_away:
962 // SPIR-V FPRoundingMode decoration has no ties-away-from-zero mode
963 // (as of SPIR-V 1.6)
964 return std::nullopt;
965 }
966 llvm_unreachable("Unhandled rounding mode");
967}
968
969/// Converts type-casting standard operations to SPIR-V operations.
970template <typename Op, typename SPIRVOp>
971struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
972 using OpConversionPattern<Op>::OpConversionPattern;
973
974 LogicalResult
975 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
976 ConversionPatternRewriter &rewriter) const override {
977 Type srcType = llvm::getSingleElement(adaptor.getOperands()).getType();
978 Type dstType = this->getTypeConverter()->convertType(op.getType());
979 if (!dstType)
980 return getTypeConversionFailure(rewriter, op);
981
982 if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
983 return failure();
984
985 if (dstType == srcType) {
986 // Due to type conversion, we are seeing the same source and target type.
987 // Then we can just erase this operation by forwarding its operand.
988 rewriter.replaceOp(op, adaptor.getOperands().front());
989 } else {
990 // Compute new rounding mode (if any).
991 std::optional<spirv::FPRoundingMode> rm = std::nullopt;
992 if (auto roundingModeOp =
993 dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
994 if (arith::RoundingModeAttr roundingMode =
995 roundingModeOp.getRoundingModeAttr()) {
996 if (!(rm =
997 convertArithRoundingModeToSPIRV(roundingMode.getValue()))) {
998 return rewriter.notifyMatchFailure(
999 op->getLoc(),
1000 llvm::formatv("unsupported rounding mode '{0}'", roundingMode));
1001 }
1002 }
1003 }
1004 // Create replacement op and attach rounding mode attribute (if any).
1005 auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
1006 op, dstType, adaptor.getOperands());
1007 if (rm) {
1008 newOp->setAttr(
1009 getDecorationString(spirv::Decoration::FPRoundingMode),
1010 spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm));
1011 }
1012 }
1013 return success();
1014 }
1015};
1016
1017//===----------------------------------------------------------------------===//
1018// CmpIOp
1019//===----------------------------------------------------------------------===//
1020
1021/// Converts integer compare operation on i1 type operands to SPIR-V ops.
1022class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
1023public:
1024 using Base::Base;
1025
1026 LogicalResult
1027 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
1028 ConversionPatternRewriter &rewriter) const override {
1029 Type srcType = op.getLhs().getType();
1030 if (!isBoolScalarOrVector(srcType))
1031 return failure();
1032 Type dstType = getTypeConverter()->convertType(srcType);
1033 if (!dstType)
1034 return getTypeConversionFailure(rewriter, op, srcType);
1035
1036 switch (op.getPredicate()) {
1037 case arith::CmpIPredicate::eq: {
1038 rewriter.replaceOpWithNewOp<spirv::LogicalEqualOp>(op, adaptor.getLhs(),
1039 adaptor.getRhs());
1040 return success();
1041 }
1042 case arith::CmpIPredicate::ne: {
1043 rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
1044 op, adaptor.getLhs(), adaptor.getRhs());
1045 return success();
1046 }
1047 case arith::CmpIPredicate::uge:
1048 case arith::CmpIPredicate::ugt:
1049 case arith::CmpIPredicate::ule:
1050 case arith::CmpIPredicate::ult: {
1051 // There are no direct corresponding instructions in SPIR-V for such
1052 // cases. Extend them to 32-bit and do comparision then.
1053 Type type = rewriter.getI32Type();
1054 if (auto vectorType = dyn_cast<VectorType>(dstType))
1055 type = VectorType::get(vectorType.getShape(), type);
1056 Value extLhs =
1057 arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getLhs());
1058 Value extRhs =
1059 arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getRhs());
1060
1061 rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs,
1062 extRhs);
1063 return success();
1064 }
1065 default:
1066 break;
1067 }
1068 return failure();
1069 }
1070};
1071
1072/// Converts integer compare operation to SPIR-V ops.
1073class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> {
1074public:
1075 using Base::Base;
1076
1077 LogicalResult
1078 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
1079 ConversionPatternRewriter &rewriter) const override {
1080 Type srcType = op.getLhs().getType();
1081 if (isBoolScalarOrVector(srcType))
1082 return failure();
1083 Type dstType = getTypeConverter()->convertType(srcType);
1084 if (!dstType)
1085 return getTypeConversionFailure(rewriter, op, srcType);
1086
1087 switch (op.getPredicate()) {
1088#define DISPATCH(cmpPredicate, spirvOp) \
1089 case cmpPredicate: \
1090 if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \
1091 !getElementTypeOrSelf(srcType).isIndex() && srcType != dstType && \
1092 !hasSameBitwidth(srcType, dstType)) { \
1093 return op.emitError( \
1094 "bitwidth emulation is not implemented yet on unsigned op"); \
1095 } \
1096 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
1097 adaptor.getRhs()); \
1098 return success();
1099
1100 DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
1101 DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
1102 DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
1103 DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
1104 DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
1105 DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
1106 DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
1107 DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
1108 DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
1109 DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
1110
1111#undef DISPATCH
1112 }
1113 return failure();
1114 }
1115};
1116
1117//===----------------------------------------------------------------------===//
1118// CmpFOpPattern
1119//===----------------------------------------------------------------------===//
1120
1121/// Converts floating-point comparison operations to SPIR-V ops.
1122class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> {
1123public:
1124 using Base::Base;
1125
1126 LogicalResult
1127 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1128 ConversionPatternRewriter &rewriter) const override {
1129 switch (op.getPredicate()) {
1130#define DISPATCH(cmpPredicate, spirvOp) \
1131 case cmpPredicate: \
1132 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
1133 adaptor.getRhs()); \
1134 return success();
1135
1136 // Ordered.
1137 DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
1138 DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
1139 DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
1140 DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
1141 DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
1142 DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
1143 // Unordered.
1144 DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
1145 DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
1146 DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
1147 DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
1148 DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
1149 DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
1150
1151#undef DISPATCH
1152
1153 default:
1154 break;
1155 }
1156 return failure();
1157 }
1158};
1159
1160/// Converts floating point NaN check to SPIR-V ops. This pattern requires
1161/// Kernel capability.
1162class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> {
1163public:
1164 using Base::Base;
1165
1166 LogicalResult
1167 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1168 ConversionPatternRewriter &rewriter) const override {
1169 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1170 rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(),
1171 adaptor.getRhs());
1172 return success();
1173 }
1174
1175 if (op.getPredicate() == arith::CmpFPredicate::UNO) {
1176 rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(),
1177 adaptor.getRhs());
1178 return success();
1179 }
1180
1181 return failure();
1182 }
1183};
1184
1185/// Converts floating point NaN check to SPIR-V ops. This pattern does not
1186/// require additional capability.
1187class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
1188public:
1189 using Base::Base;
1190
1191 LogicalResult
1192 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1193 ConversionPatternRewriter &rewriter) const override {
1194 if (op.getPredicate() != arith::CmpFPredicate::ORD &&
1195 op.getPredicate() != arith::CmpFPredicate::UNO)
1196 return failure();
1197
1198 Location loc = op.getLoc();
1199
1200 Value replace;
1201 if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1202 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1203 // Ordered comparsion checks if neither operand is NaN.
1204 replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter);
1205 } else {
1206 // Unordered comparsion checks if either operand is NaN.
1207 replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter);
1208 }
1209 } else {
1210 Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1211 Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1212
1213 replace = spirv::LogicalOrOp::create(rewriter, loc, lhsIsNan, rhsIsNan);
1214 if (op.getPredicate() == arith::CmpFPredicate::ORD)
1215 replace = spirv::LogicalNotOp::create(rewriter, loc, replace);
1216 }
1217
1218 rewriter.replaceOp(op, replace);
1219 return success();
1220 }
1221};
1222
1223//===----------------------------------------------------------------------===//
1224// AddUIExtendedOp
1225//===----------------------------------------------------------------------===//
1226
1227/// Converts arith.addui_extended to spirv.IAddCarry.
1228class AddUIExtendedOpPattern final
1229 : public OpConversionPattern<arith::AddUIExtendedOp> {
1230public:
1231 using Base::Base;
1232 LogicalResult
1233 matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
1234 ConversionPatternRewriter &rewriter) const override {
1235 Type dstElemTy = adaptor.getLhs().getType();
1236 Location loc = op->getLoc();
1237 Value result = spirv::IAddCarryOp::create(rewriter, loc, adaptor.getLhs(),
1238 adaptor.getRhs());
1239
1240 Value sumResult = spirv::CompositeExtractOp::create(rewriter, loc, result,
1241 llvm::ArrayRef(0));
1242 Value carryValue = spirv::CompositeExtractOp::create(rewriter, loc, result,
1243 llvm::ArrayRef(1));
1244
1245 // Convert the carry value to boolean.
1246 Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter);
1247 Value carryResult = spirv::IEqualOp::create(rewriter, loc, carryValue, one);
1248
1249 rewriter.replaceOp(op, {sumResult, carryResult});
1250 return success();
1251 }
1252};
1253
1254//===----------------------------------------------------------------------===//
1255// MulIExtendedOp
1256//===----------------------------------------------------------------------===//
1257
1258/// Converts arith.mul*i_extended to spirv.*MulExtended.
1259template <typename ArithMulOp, typename SPIRVMulOp>
1260class MulIExtendedOpPattern final : public OpConversionPattern<ArithMulOp> {
1261public:
1262 using OpConversionPattern<ArithMulOp>::OpConversionPattern;
1263 LogicalResult
1264 matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
1265 ConversionPatternRewriter &rewriter) const override {
1266 Location loc = op->getLoc();
1267 Value result =
1268 SPIRVMulOp::create(rewriter, loc, adaptor.getLhs(), adaptor.getRhs());
1269
1270 Value low = spirv::CompositeExtractOp::create(rewriter, loc, result,
1271 llvm::ArrayRef(0));
1272 Value high = spirv::CompositeExtractOp::create(rewriter, loc, result,
1273 llvm::ArrayRef(1));
1274
1275 rewriter.replaceOp(op, {low, high});
1276 return success();
1277 }
1278};
1279
1280//===----------------------------------------------------------------------===//
1281// SelectOp
1282//===----------------------------------------------------------------------===//
1283
1284/// Converts arith.select to spirv.Select.
1285class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> {
1286public:
1287 using Base::Base;
1288 LogicalResult
1289 matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
1290 ConversionPatternRewriter &rewriter) const override {
1291 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, adaptor.getCondition(),
1292 adaptor.getTrueValue(),
1293 adaptor.getFalseValue());
1294 return success();
1295 }
1296};
1297
1298//===----------------------------------------------------------------------===//
1299// MinimumFOp, MaximumFOp
1300//===----------------------------------------------------------------------===//
1301
1302/// Converts arith.maximumf/minimumf to spirv.GL.FMax/FMin or
1303/// spirv.CL.fmax/fmin.
1304template <typename Op, typename SPIRVOp>
1305class MinimumMaximumFOpPattern final : public OpConversionPattern<Op> {
1306public:
1307 using OpConversionPattern<Op>::OpConversionPattern;
1308 LogicalResult
1309 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
1310 ConversionPatternRewriter &rewriter) const override {
1311 auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
1312 Type dstType = converter->convertType(op.getType());
1313 if (!dstType)
1314 return getTypeConversionFailure(rewriter, op);
1315
1316 // arith.maximumf/minimumf:
1317 // "if one of the arguments is NaN, then the result is also NaN."
1318 // spirv.GL.FMax/FMin
1319 // "which operand is the result is undefined if one of the operands
1320 // is a NaN."
1321 // spirv.CL.fmax/fmin:
1322 // "If one argument is a NaN, Fmin returns the other argument."
1323
1324 Location loc = op.getLoc();
1325 Value spirvOp =
1326 SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands());
1327
1328 if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1329 rewriter.replaceOp(op, spirvOp);
1330 return success();
1331 }
1332
1333 Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1334 Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1335
1336 Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan,
1337 adaptor.getLhs(), spirvOp);
1338 Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan,
1339 adaptor.getRhs(), select1);
1340
1341 rewriter.replaceOp(op, select2);
1342 return success();
1343 }
1344};
1345
1346//===----------------------------------------------------------------------===//
1347// MinNumFOp, MaxNumFOp
1348//===----------------------------------------------------------------------===//
1349
1350/// Converts arith.maxnumf/minnumf to spirv.GL.FMax/FMin or
1351/// spirv.CL.fmax/fmin.
1352template <typename Op, typename SPIRVOp>
1353class MinNumMaxNumFOpPattern final : public OpConversionPattern<Op> {
1354 template <typename TargetOp>
1355 constexpr bool shouldInsertNanGuards() const {
1356 return llvm::is_one_of<TargetOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value;
1357 }
1358
1359public:
1360 using OpConversionPattern<Op>::OpConversionPattern;
1361 LogicalResult
1362 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
1363 ConversionPatternRewriter &rewriter) const override {
1364 auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
1365 Type dstType = converter->convertType(op.getType());
1366 if (!dstType)
1367 return getTypeConversionFailure(rewriter, op);
1368
1369 // arith.maxnumf/minnumf:
1370 // "If one of the arguments is NaN, then the result is the other
1371 // argument."
1372 // spirv.GL.FMax/FMin
1373 // "which operand is the result is undefined if one of the operands
1374 // is a NaN."
1375 // spirv.CL.fmax/fmin:
1376 // "If one argument is a NaN, Fmin returns the other argument."
1377
1378 Location loc = op.getLoc();
1379 Value spirvOp =
1380 SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands());
1381
1382 if (!shouldInsertNanGuards<SPIRVOp>() ||
1383 bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1384 rewriter.replaceOp(op, spirvOp);
1385 return success();
1386 }
1387
1388 Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1389 Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1390
1391 Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan,
1392 adaptor.getRhs(), spirvOp);
1393 Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan,
1394 adaptor.getLhs(), select1);
1395
1396 rewriter.replaceOp(op, select2);
1397 return success();
1398 }
1399};
1400
1401} // namespace
1402
1403//===----------------------------------------------------------------------===//
1404// Pattern Population
1405//===----------------------------------------------------------------------===//
1406
1408 const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
1409 // clang-format off
1410 patterns.add<
1411 ConstantCompositeOpPattern,
1412 ConstantScalarOpPattern,
1413 ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>,
1414 ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>,
1415 ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>,
1419 RemSIOpGLPattern, RemSIOpCLPattern,
1420 BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
1421 BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
1422 XOrIOpLogicalPattern, XOrIOpBooleanPattern,
1423 ElementwiseArithOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
1432 ExtUIPattern, ExtUII1Pattern,
1433 ExtSIPattern, ExtSII1Pattern,
1434 TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
1435 TruncIPattern, TruncII1Pattern,
1436 TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
1437 IntToFPPattern<arith::UIToFPOp, spirv::ConvertUToFOp, false>,
1438 UIToFPI1Pattern,
1439 IntToFPPattern<arith::SIToFPOp, spirv::ConvertSToFOp, true>,
1440 TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
1441 TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
1442 TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
1443 IndexCastIndexI1Pattern, IndexCastI1IndexPattern,
1444 TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
1445 TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
1446 CmpIOpBooleanPattern, CmpIOpPattern,
1447 CmpFOpNanNonePattern, CmpFOpPattern,
1448 AddUIExtendedOpPattern,
1449 MulIExtendedOpPattern<arith::MulSIExtendedOp, spirv::SMulExtendedOp>,
1450 MulIExtendedOpPattern<arith::MulUIExtendedOp, spirv::UMulExtendedOp>,
1451 SelectOpPattern,
1452
1453 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::GLFMaxOp>,
1454 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::GLFMinOp>,
1455 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::GLFMaxOp>,
1456 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::GLFMinOp>,
1461
1462 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::CLFMaxOp>,
1463 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::CLFMinOp>,
1464 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::CLFMaxOp>,
1465 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::CLFMinOp>,
1470 >(typeConverter, patterns.getContext());
1471 // clang-format on
1472
1473 // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel
1474 // capability is available.
1475 patterns.add<CmpFOpNanKernelPattern>(typeConverter, patterns.getContext(),
1476 /*benefit=*/2);
1477}
1478
1479//===----------------------------------------------------------------------===//
1480// Pass Definition
1481//===----------------------------------------------------------------------===//
1482
1483namespace {
1484struct ConvertArithToSPIRVPass
1485 : public impl::ConvertArithToSPIRVPassBase<ConvertArithToSPIRVPass> {
1486 using Base::Base;
1487
1488 void runOnOperation() override {
1489 Operation *op = getOperation();
1491 std::unique_ptr<SPIRVConversionTarget> target =
1492 SPIRVConversionTarget::get(targetAttr);
1493
1495 options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
1496 options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
1497 SPIRVTypeConverter typeConverter(targetAttr, options);
1498
1499 // Use UnrealizedConversionCast as the bridge so that we don't need to pull
1500 // in patterns for other dialects.
1501 target->addLegalOp<UnrealizedConversionCastOp>();
1502
1503 // Fail hard when there are any remaining 'arith' ops.
1504 target->addIllegalDialect<arith::ArithDialect>();
1505
1506 RewritePatternSet patterns(&getContext());
1507 arith::populateArithToSPIRVPatterns(typeConverter, patterns);
1508
1509 if (failed(applyPartialConversion(op, *target, std::move(patterns))))
1510 signalPassFailure();
1511 }
1512};
1513} // namespace
return success()
static bool hasSameBitwidth(Type a, Type b)
Returns true if scalar/vector types a and b have the same bitwidth.
static Value getScalarOrVectorConstInt(Type type, uint64_t value, OpBuilder &builder, Location loc)
Creates a scalar/vector integer constant.
static LogicalResult getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op, Type srcType)
Returns a source type conversion failure for srcType and operation op.
static IntegerAttr getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType, ConversionPatternRewriter &rewriter)
static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder)
Converts the given srcAttr into a boolean attribute if it holds an integral value.
static bool isBoolScalarOrVector(Type type)
Returns true if the given type is a boolean scalar or vector type.
#define DISPATCH(cmpPredicate, spirvOp)
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
static llvm::ManagedStatic< PassManagerOptions > options
This class represents a processed binary blob of data.
Definition AsmState.h:91
ArrayRef< char > getData() const
Return the raw underlying data of this blob.
Definition AsmState.h:145
Attributes are known-constant values of operations.
Definition Attributes.h:25
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
BoolAttr getBoolAttr(bool value)
Definition Builders.cpp:104
FloatAttr getF32FloatAttr(float value)
Definition Builders.cpp:250
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
static bool isValidRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Returns true if the given buffer is a valid raw buffer for the given type.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
DenseElementsAttr reshape(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but has been reshaped...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:209
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:244
result_type_range getResultTypes()
Definition Operation.h:457
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:433
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:56
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition Types.cpp:122
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
Type front()
Return first type in the range.
Definition TypeRange.h:164
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
An attribute that specifies the target version, allowed extensions and capabilities,...
NestedPattern Op(FilterFunctionType filter=defaultFilterFunction)
void populateArithToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
std::string getDecorationString(Decoration decoration)
Converts a SPIR-V Decoration enum value to its snake_case string representation for use in MLIR attri...
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.
Definition Pattern.h:24