MLIR 23.0.0git
ArithToSPIRV.cpp
Go to the documentation of this file.
1//===- ArithToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
21#include "llvm/ADT/APInt.h"
22#include "llvm/ADT/ArrayRef.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/Support/Debug.h"
25#include "llvm/Support/MathExtras.h"
26#include <cassert>
27#include <memory>
28
29namespace mlir {
30#define GEN_PASS_DEF_CONVERTARITHTOSPIRVPASS
31#include "mlir/Conversion/Passes.h.inc"
32} // namespace mlir
33
34#define DEBUG_TYPE "arith-to-spirv-pattern"
35
36using namespace mlir;
37
38//===----------------------------------------------------------------------===//
39// Conversion Helpers
40//===----------------------------------------------------------------------===//
41
42/// Converts the given `srcAttr` into a boolean attribute if it holds an
43/// integral value. Returns null attribute if conversion fails.
44static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) {
45 if (auto boolAttr = dyn_cast<BoolAttr>(srcAttr))
46 return boolAttr;
47 if (auto intAttr = dyn_cast<IntegerAttr>(srcAttr))
48 return builder.getBoolAttr(intAttr.getValue().getBoolValue());
49 return {};
50}
51
52/// Converts the given `srcAttr` to a new attribute of the given `dstType`.
53/// Returns null attribute if conversion fails.
54static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType,
55 Builder builder) {
56 // If the source number uses less active bits than the target bitwidth, then
57 // it should be safe to convert.
58 if (srcAttr.getValue().isIntN(dstType.getWidth()))
59 return builder.getIntegerAttr(dstType, srcAttr.getInt());
60
61 // XXX: Try again by interpreting the source number as a signed value.
62 // Although integers in the standard dialect are signless, they can represent
63 // a signed number. It's the operation decides how to interpret. This is
64 // dangerous, but it seems there is no good way of handling this if we still
65 // want to change the bitwidth. Emit a message at least.
66 if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
67 auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt());
68 LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '"
69 << dstAttr << "' for type '" << dstType << "'\n");
70 return dstAttr;
71 }
72
73 LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr
74 << "' illegal: cannot fit into target type '"
75 << dstType << "'\n");
76 return {};
77}
78
79/// Converts the given `srcAttr` to a new attribute of the given `dstType`.
80/// Returns null attribute if `dstType` is not 32-bit or conversion fails.
81static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
82 Builder builder) {
83 // Only support converting to float for now.
84 if (!dstType.isF32())
85 return FloatAttr();
86
87 // Try to convert the source floating-point number to single precision.
88 APFloat dstVal = srcAttr.getValue();
89 bool losesInfo = false;
90 APFloat::opStatus status =
91 dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
92 if (status != APFloat::opOK || losesInfo) {
93 LLVM_DEBUG(llvm::dbgs()
94 << srcAttr << " illegal: cannot fit into converted type '"
95 << dstType << "'\n");
96 return FloatAttr();
97 }
98
99 return builder.getF32FloatAttr(dstVal.convertToFloat());
100}
101
102// Get in IntegerAttr from FloatAttr while preserving the bits.
103// Useful for converting float constants to integer constants while preserving
104// the bits.
105static IntegerAttr
106getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType,
107 ConversionPatternRewriter &rewriter) {
108 APFloat floatVal = floatAttr.getValue();
109 APInt intVal = floatVal.bitcastToAPInt();
110 return rewriter.getIntegerAttr(dstType, intVal);
111}
112
113/// Returns true if the given `type` is a boolean scalar or vector type.
114static bool isBoolScalarOrVector(Type type) {
115 assert(type && "Not a valid type");
116 if (type.isInteger(1))
117 return true;
118
119 if (auto vecType = dyn_cast<VectorType>(type))
120 return vecType.getElementType().isInteger(1);
121
122 return false;
123}
124
125/// Creates a scalar/vector integer constant.
126static Value getScalarOrVectorConstInt(Type type, uint64_t value,
127 OpBuilder &builder, Location loc) {
128 if (auto vectorType = dyn_cast<VectorType>(type)) {
129 Attribute element = IntegerAttr::get(vectorType.getElementType(), value);
130 auto attr = SplatElementsAttr::get(vectorType, element);
131 return spirv::ConstantOp::create(builder, loc, vectorType, attr);
132 }
133
134 if (auto intType = dyn_cast<IntegerType>(type))
135 return spirv::ConstantOp::create(builder, loc, type,
136 builder.getIntegerAttr(type, value));
137
138 return nullptr;
139}
140
141/// Returns true if scalar/vector type `a` and `b` have the same number of
142/// bitwidth.
143static bool hasSameBitwidth(Type a, Type b) {
144 auto getNumBitwidth = [](Type type) {
145 unsigned bw = 0;
146 if (type.isIntOrFloat())
147 bw = type.getIntOrFloatBitWidth();
148 else if (auto vecType = dyn_cast<VectorType>(type))
149 bw = vecType.getElementTypeBitWidth() * vecType.getNumElements();
150 return bw;
151 };
152 unsigned aBW = getNumBitwidth(a);
153 unsigned bBW = getNumBitwidth(b);
154 return aBW != 0 && bBW != 0 && aBW == bBW;
155}
156
157/// Returns a source type conversion failure for `srcType` and operation `op`.
158static LogicalResult
159getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op,
160 Type srcType) {
161 return rewriter.notifyMatchFailure(
162 op->getLoc(),
163 llvm::formatv("failed to convert source type '{0}'", srcType));
164}
165
166/// Returns a source type conversion failure for the result type of `op`.
167static LogicalResult
168getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op) {
169 assert(op->getNumResults() == 1);
170 return getTypeConversionFailure(rewriter, op, op->getResultTypes().front());
171}
172
173namespace {
174
175/// Converts elementwise unary, binary and ternary arith operations to SPIR-V
176/// operations. Op can potentially support overflow flags.
177template <typename Op, typename SPIRVOp>
178struct ElementwiseArithOpPattern final : OpConversionPattern<Op> {
179 using OpConversionPattern<Op>::OpConversionPattern;
180
181 LogicalResult
182 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
183 ConversionPatternRewriter &rewriter) const override {
184 assert(adaptor.getOperands().size() <= 3);
185 auto converter = this->template getTypeConverter<SPIRVTypeConverter>();
186 Type dstType = converter->convertType(op.getType());
187 if (!dstType) {
188 return rewriter.notifyMatchFailure(
189 op->getLoc(),
190 llvm::formatv("failed to convert type {0} for SPIR-V", op.getType()));
191 }
192
193 if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
194 !getElementTypeOrSelf(op.getType()).isIndex() &&
195 dstType != op.getType()) {
196 return op.emitError("bitwidth emulation is not implemented yet on "
197 "unsigned op pattern version");
198 }
199
200 auto overflowFlags = arith::IntegerOverflowFlags::none;
201 if (auto overflowIface =
202 dyn_cast<arith::ArithIntegerOverflowFlagsInterface>(*op)) {
203 if (converter->getTargetEnv().allows(
204 spirv::Extension::SPV_KHR_no_integer_wrap_decoration))
205 overflowFlags = overflowIface.getOverflowAttr().getValue();
206 }
207
208 auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
209 op, dstType, adaptor.getOperands());
210
211 if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nsw))
212 newOp->setAttr(getDecorationString(spirv::Decoration::NoSignedWrap),
213 rewriter.getUnitAttr());
214
215 if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nuw))
216 newOp->setAttr(getDecorationString(spirv::Decoration::NoUnsignedWrap),
217 rewriter.getUnitAttr());
218
219 return success();
220 }
221};
222
223//===----------------------------------------------------------------------===//
224// ConstantOp
225//===----------------------------------------------------------------------===//
226
227/// Converts composite arith.constant operation to spirv.Constant.
228struct ConstantCompositeOpPattern final
229 : public OpConversionPattern<arith::ConstantOp> {
230 using Base::Base;
231
232 LogicalResult
233 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
234 ConversionPatternRewriter &rewriter) const override {
235 auto srcType = dyn_cast<ShapedType>(constOp.getType());
236 if (!srcType || srcType.getNumElements() == 1)
237 return failure();
238
239 // arith.constant should only have vector or tensor types. This is a MLIR
240 // wide problem at the moment.
241 if (!isa<VectorType, RankedTensorType>(srcType))
242 return rewriter.notifyMatchFailure(constOp, "unsupported ShapedType");
243
244 Type dstType = getTypeConverter()->convertType(srcType);
245 if (!dstType)
246 return failure();
247
248 // Import the resource into the IR to make use of the special handling of
249 // element types later on.
250 mlir::DenseElementsAttr dstElementsAttr;
251 if (auto denseElementsAttr =
252 dyn_cast<DenseElementsAttr>(constOp.getValue())) {
253 dstElementsAttr = denseElementsAttr;
254 } else if (auto resourceAttr =
255 dyn_cast<DenseResourceElementsAttr>(constOp.getValue())) {
256
257 AsmResourceBlob *blob = resourceAttr.getRawHandle().getBlob();
258 if (!blob)
259 return constOp->emitError("could not find resource blob");
260
261 ArrayRef<char> ptr = blob->getData();
262
263 // Check that the buffer meets the requirements to get converted to a
264 // DenseElementsAttr
266 return constOp->emitError("resource is not a valid buffer");
267
268 dstElementsAttr =
269 DenseElementsAttr::getFromRawBuffer(resourceAttr.getType(), ptr);
270 } else {
271 return constOp->emitError("unsupported elements attribute");
272 }
273
274 ShapedType dstAttrType = dstElementsAttr.getType();
275
276 // If the composite type has more than one dimensions, perform
277 // linearization.
278 if (srcType.getRank() > 1) {
279 if (isa<RankedTensorType>(srcType)) {
280 dstAttrType = RankedTensorType::get(srcType.getNumElements(),
281 srcType.getElementType());
282 dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
283 } else {
284 // TODO: add support for large vectors.
285 return failure();
286 }
287 }
288
289 Type srcElemType = srcType.getElementType();
290 Type dstElemType;
291 // Tensor types are converted to SPIR-V array types; vector types are
292 // converted to SPIR-V vector/array types.
293 if (auto arrayType = dyn_cast<spirv::ArrayType>(dstType))
294 dstElemType = arrayType.getElementType();
295 else
296 dstElemType = cast<VectorType>(dstType).getElementType();
297
298 // If the source and destination element types are different, perform
299 // attribute conversion.
300 if (srcElemType != dstElemType) {
302 if (isa<FloatType>(srcElemType)) {
303 for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
304 Attribute dstAttr = nullptr;
305 // Handle 8-bit float conversion to 8-bit integer.
306 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
307 if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
308 srcElemType.getIntOrFloatBitWidth() == 8 &&
309 isa<IntegerType>(dstElemType)) {
310 dstAttr =
311 getIntegerAttrFromFloatAttr(srcAttr, dstElemType, rewriter);
312 } else {
313 dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstElemType),
314 rewriter);
315 }
316 if (!dstAttr)
317 return failure();
318 elements.push_back(dstAttr);
319 }
320 } else if (srcElemType.isInteger(1)) {
321 return failure();
322 } else {
323 for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) {
324 IntegerAttr dstAttr = convertIntegerAttr(
325 srcAttr, cast<IntegerType>(dstElemType), rewriter);
326 if (!dstAttr)
327 return failure();
328 elements.push_back(dstAttr);
329 }
330 }
331
332 // Unfortunately, we cannot use dialect-specific types for element
333 // attributes; element attributes only works with builtin types. So we
334 // need to prepare another converted builtin types for the destination
335 // elements attribute.
336 if (isa<RankedTensorType>(dstAttrType))
337 dstAttrType =
338 RankedTensorType::get(dstAttrType.getShape(), dstElemType);
339 else
340 dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
341
342 dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
343 }
344
345 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
346 dstElementsAttr);
347 return success();
348 }
349};
350
351/// Converts scalar arith.constant operation to spirv.Constant.
352struct ConstantScalarOpPattern final
353 : public OpConversionPattern<arith::ConstantOp> {
354 using Base::Base;
355
356 LogicalResult
357 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
358 ConversionPatternRewriter &rewriter) const override {
359 Type srcType = constOp.getType();
360 if (auto shapedType = dyn_cast<ShapedType>(srcType)) {
361 if (shapedType.getNumElements() != 1)
362 return failure();
363 srcType = shapedType.getElementType();
364 }
365 if (!srcType.isIntOrIndexOrFloat())
366 return failure();
367
368 Attribute cstAttr = constOp.getValue();
369 if (auto elementsAttr = dyn_cast<DenseElementsAttr>(cstAttr))
370 cstAttr = elementsAttr.getSplatValue<Attribute>();
371
372 Type dstType = getTypeConverter()->convertType(srcType);
373 if (!dstType)
374 return failure();
375
376 // Floating-point types.
377 if (isa<FloatType>(srcType)) {
378 auto srcAttr = cast<FloatAttr>(cstAttr);
379 Attribute dstAttr = srcAttr;
380
381 // Floating-point types not supported in the target environment are all
382 // converted to float type.
383 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
384 if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
385 srcType.getIntOrFloatBitWidth() == 8 && isa<IntegerType>(dstType) &&
386 dstType.getIntOrFloatBitWidth() == 8) {
387 // If the source is an 8-bit float, convert it to a 8-bit integer.
388 dstAttr = getIntegerAttrFromFloatAttr(srcAttr, dstType, rewriter);
389 if (!dstAttr)
390 return failure();
391 } else if (srcType != dstType) {
392 dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter);
393 if (!dstAttr)
394 return failure();
395 }
396
397 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
398 return success();
399 }
400
401 // Bool type.
402 if (srcType.isInteger(1)) {
403 // arith.constant can use 0/1 instead of true/false for i1 values. We need
404 // to handle that here.
405 auto dstAttr = convertBoolAttr(cstAttr, rewriter);
406 if (!dstAttr)
407 return failure();
408 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
409 return success();
410 }
411
412 // IndexType or IntegerType. Index values are converted to 32-bit integer
413 // values when converting to SPIR-V.
414 auto srcAttr = cast<IntegerAttr>(cstAttr);
415 IntegerAttr dstAttr =
416 convertIntegerAttr(srcAttr, cast<IntegerType>(dstType), rewriter);
417 if (!dstAttr)
418 return failure();
419 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
420 return success();
421 }
422};
423
424//===----------------------------------------------------------------------===//
425// RemSIOp
426//===----------------------------------------------------------------------===//
427
428/// Returns signed remainder for `lhs` and `rhs` and lets the result follow
429/// the sign of `signOperand`.
430///
431/// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment
432/// spec, "for the OpSRem and OpSMod instructions, if either operand is negative
433/// the result is undefined." So we cannot directly use spirv.SRem/spirv.SMod
434/// if either operand can be negative. Emulate it via spirv.UMod.
435template <typename SignedAbsOp>
436static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
437 Value signOperand, OpBuilder &builder) {
438 assert(lhs.getType() == rhs.getType());
439 assert(lhs == signOperand || rhs == signOperand);
440
441 Type type = lhs.getType();
442
443 // Calculate the remainder with spirv.UMod.
444 Value lhsAbs = SignedAbsOp::create(builder, loc, type, lhs);
445 Value rhsAbs = SignedAbsOp::create(builder, loc, type, rhs);
446 Value abs = spirv::UModOp::create(builder, loc, lhsAbs, rhsAbs);
447
448 // Fix the sign.
449 Value isPositive;
450 if (lhs == signOperand)
451 isPositive = spirv::IEqualOp::create(builder, loc, lhs, lhsAbs);
452 else
453 isPositive = spirv::IEqualOp::create(builder, loc, rhs, rhsAbs);
454 Value absNegate = spirv::SNegateOp::create(builder, loc, type, abs);
455 return spirv::SelectOp::create(builder, loc, type, isPositive, abs,
456 absNegate);
457}
458
459/// Converts arith.remsi to GLSL SPIR-V ops.
460///
461/// This cannot be merged into the template unary/binary pattern due to Vulkan
462/// restrictions over spirv.SRem and spirv.SMod.
463struct RemSIOpGLPattern final : public OpConversionPattern<arith::RemSIOp> {
464 using Base::Base;
465
466 LogicalResult
467 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
468 ConversionPatternRewriter &rewriter) const override {
469 Value result = emulateSignedRemainder<spirv::CLSAbsOp>(
470 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
471 adaptor.getOperands()[0], rewriter);
472 rewriter.replaceOp(op, result);
473
474 return success();
475 }
476};
477
478/// Converts arith.remsi to OpenCL SPIR-V ops.
479struct RemSIOpCLPattern final : public OpConversionPattern<arith::RemSIOp> {
480 using Base::Base;
481
482 LogicalResult
483 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
484 ConversionPatternRewriter &rewriter) const override {
485 Value result = emulateSignedRemainder<spirv::GLSAbsOp>(
486 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
487 adaptor.getOperands()[0], rewriter);
488 rewriter.replaceOp(op, result);
489
490 return success();
491 }
492};
493
494//===----------------------------------------------------------------------===//
495// BitwiseOp
496//===----------------------------------------------------------------------===//
497
498/// Converts bitwise operations to SPIR-V operations. This is a special pattern
499/// other than the BinaryOpPatternPattern because if the operands are boolean
500/// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For
501/// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`.
502template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
503struct BitwiseOpPattern final : public OpConversionPattern<Op> {
504 using OpConversionPattern<Op>::OpConversionPattern;
505
506 LogicalResult
507 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
508 ConversionPatternRewriter &rewriter) const override {
509 assert(adaptor.getOperands().size() == 2);
510 Type dstType = this->getTypeConverter()->convertType(op.getType());
511 if (!dstType)
512 return getTypeConversionFailure(rewriter, op);
513
514 if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) {
515 rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(
516 op, dstType, adaptor.getOperands());
517 } else {
518 rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(
519 op, dstType, adaptor.getOperands());
520 }
521 return success();
522 }
523};
524
525//===----------------------------------------------------------------------===//
526// XOrIOp
527//===----------------------------------------------------------------------===//
528
529/// Converts arith.xori to SPIR-V operations.
530struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> {
531 using Base::Base;
532
533 LogicalResult
534 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
535 ConversionPatternRewriter &rewriter) const override {
536 assert(adaptor.getOperands().size() == 2);
537
538 if (isBoolScalarOrVector(adaptor.getOperands().front().getType()))
539 return failure();
540
541 Type dstType = getTypeConverter()->convertType(op.getType());
542 if (!dstType)
543 return getTypeConversionFailure(rewriter, op);
544
545 rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
546 adaptor.getOperands());
547
548 return success();
549 }
550};
551
552/// Converts arith.xori to SPIR-V operations if the type of source is i1 or
553/// vector of i1.
554struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
555 using Base::Base;
556
557 LogicalResult
558 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
559 ConversionPatternRewriter &rewriter) const override {
560 assert(adaptor.getOperands().size() == 2);
561
562 if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
563 return failure();
564
565 Type dstType = getTypeConverter()->convertType(op.getType());
566 if (!dstType)
567 return getTypeConversionFailure(rewriter, op);
568
569 rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
570 op, dstType, adaptor.getOperands());
571 return success();
572 }
573};
574
575//===----------------------------------------------------------------------===//
576// UIToFPOp
577//===----------------------------------------------------------------------===//
578
579/// Converts arith.uitofp to spirv.Select if the type of source is i1 or vector
580/// of i1.
581struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
582 using Base::Base;
583
584 LogicalResult
585 matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
586 ConversionPatternRewriter &rewriter) const override {
587 Type srcType = adaptor.getOperands().front().getType();
588 if (!isBoolScalarOrVector(srcType))
589 return failure();
590
591 Type dstType = getTypeConverter()->convertType(op.getType());
592 if (!dstType)
593 return getTypeConversionFailure(rewriter, op);
594
595 Location loc = op.getLoc();
596 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
597 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
598 rewriter.replaceOpWithNewOp<spirv::SelectOp>(
599 op, dstType, adaptor.getOperands().front(), one, zero);
600 return success();
601 }
602};
603
604//===----------------------------------------------------------------------===//
605// IndexCastOp
606//===----------------------------------------------------------------------===//
607
608/// Converts arith.index_cast to spirv.INotEqual if the target type is i1.
609struct IndexCastIndexI1Pattern final
610 : public OpConversionPattern<arith::IndexCastOp> {
611 using Base::Base;
612
613 LogicalResult
614 matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
615 ConversionPatternRewriter &rewriter) const override {
616 if (!isBoolScalarOrVector(op.getType()))
617 return failure();
618
619 Type dstType = getTypeConverter()->convertType(op.getType());
620 if (!dstType)
621 return getTypeConversionFailure(rewriter, op);
622
623 Location loc = op.getLoc();
624 Value zeroIdx =
625 spirv::ConstantOp::getZero(adaptor.getIn().getType(), loc, rewriter);
626 rewriter.replaceOpWithNewOp<spirv::INotEqualOp>(op, dstType, zeroIdx,
627 adaptor.getIn());
628 return success();
629 }
630};
631
632/// Converts arith.index_cast to spirv.Select if the source type is i1.
633struct IndexCastI1IndexPattern final
634 : public OpConversionPattern<arith::IndexCastOp> {
635 using Base::Base;
636
637 LogicalResult
638 matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
639 ConversionPatternRewriter &rewriter) const override {
640 if (!isBoolScalarOrVector(adaptor.getIn().getType()))
641 return failure();
642
643 Type dstType = getTypeConverter()->convertType(op.getType());
644 if (!dstType)
645 return getTypeConversionFailure(rewriter, op);
646
647 Location loc = op.getLoc();
648 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
649 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
650 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, adaptor.getIn(),
651 one, zero);
652 return success();
653 }
654};
655
656//===----------------------------------------------------------------------===//
657// ExtSIOp
658//===----------------------------------------------------------------------===//
659
660/// Converts arith.extsi to spirv.Select if the type of source is i1 or vector
661/// of i1.
662struct ExtSII1Pattern final : public OpConversionPattern<arith::ExtSIOp> {
663 using Base::Base;
664
665 LogicalResult
666 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
667 ConversionPatternRewriter &rewriter) const override {
668 Value operand = adaptor.getIn();
669 if (!isBoolScalarOrVector(operand.getType()))
670 return failure();
671
672 Location loc = op.getLoc();
673 Type dstType = getTypeConverter()->convertType(op.getType());
674 if (!dstType)
675 return getTypeConversionFailure(rewriter, op);
676
677 Value allOnes;
678 if (auto intTy = dyn_cast<IntegerType>(dstType)) {
679 unsigned componentBitwidth = intTy.getWidth();
680 allOnes = spirv::ConstantOp::create(
681 rewriter, loc, intTy,
682 rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth)));
683 } else if (auto vectorTy = dyn_cast<VectorType>(dstType)) {
684 unsigned componentBitwidth = vectorTy.getElementTypeBitWidth();
685 allOnes = spirv::ConstantOp::create(
686 rewriter, loc, vectorTy,
687 SplatElementsAttr::get(vectorTy,
688 APInt::getAllOnes(componentBitwidth)));
689 } else {
690 return rewriter.notifyMatchFailure(
691 loc, llvm::formatv("unhandled type: {0}", dstType));
692 }
693
694 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
695 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, operand, allOnes,
696 zero);
697 return success();
698 }
699};
700
701/// Converts arith.extsi to spirv.Select if the type of source is neither i1 nor
702/// vector of i1.
703struct ExtSIPattern final : public OpConversionPattern<arith::ExtSIOp> {
704 using Base::Base;
705
706 LogicalResult
707 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
708 ConversionPatternRewriter &rewriter) const override {
709 Type srcType = adaptor.getIn().getType();
710 if (isBoolScalarOrVector(srcType))
711 return failure();
712
713 Type dstType = getTypeConverter()->convertType(op.getType());
714 if (!dstType)
715 return getTypeConversionFailure(rewriter, op);
716
717 if (dstType == srcType) {
718 // We can have the same source and destination type due to type emulation.
719 // Perform bit shifting to make sure we have the proper leading set bits.
720
721 unsigned srcBW =
722 getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth();
723 unsigned dstBW =
725 assert(srcBW < dstBW);
726 Value shiftSize = getScalarOrVectorConstInt(dstType, dstBW - srcBW,
727 rewriter, op.getLoc());
728 if (!shiftSize)
729 return rewriter.notifyMatchFailure(op, "unsupported type for shift");
730
731 // First shift left to sequeeze out all leading bits beyond the original
732 // bitwidth. Here we need to use the original source and result type's
733 // bitwidth.
734 auto shiftLOp = spirv::ShiftLeftLogicalOp::create(
735 rewriter, op.getLoc(), dstType, adaptor.getIn(), shiftSize);
736
737 // Then we perform arithmetic right shift to make sure we have the right
738 // sign bits for negative values.
739 rewriter.replaceOpWithNewOp<spirv::ShiftRightArithmeticOp>(
740 op, dstType, shiftLOp, shiftSize);
741 } else {
742 rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
743 adaptor.getOperands());
744 }
745
746 return success();
747 }
748};
749
750//===----------------------------------------------------------------------===//
751// ExtUIOp
752//===----------------------------------------------------------------------===//
753
754/// Converts arith.extui to spirv.Select if the type of source is i1 or vector
755/// of i1.
756struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> {
757 using Base::Base;
758
759 LogicalResult
760 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
761 ConversionPatternRewriter &rewriter) const override {
762 Type srcType = adaptor.getOperands().front().getType();
763 if (!isBoolScalarOrVector(srcType))
764 return failure();
765
766 Type dstType = getTypeConverter()->convertType(op.getType());
767 if (!dstType)
768 return getTypeConversionFailure(rewriter, op);
769
770 Location loc = op.getLoc();
771 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
772 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
773 rewriter.replaceOpWithNewOp<spirv::SelectOp>(
774 op, dstType, adaptor.getOperands().front(), one, zero);
775 return success();
776 }
777};
778
779/// Converts arith.extui for cases where the type of source is neither i1 nor
780/// vector of i1.
781struct ExtUIPattern final : public OpConversionPattern<arith::ExtUIOp> {
782 using Base::Base;
783
784 LogicalResult
785 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
786 ConversionPatternRewriter &rewriter) const override {
787 Type srcType = adaptor.getIn().getType();
788 if (isBoolScalarOrVector(srcType))
789 return failure();
790
791 Type dstType = getTypeConverter()->convertType(op.getType());
792 if (!dstType)
793 return getTypeConversionFailure(rewriter, op);
794
795 if (dstType == srcType) {
796 // We can have the same source and destination type due to type emulation.
797 // Perform bit masking to make sure we don't pollute downstream consumers
798 // with unwanted bits. Here we need to use the original source type's
799 // bitwidth.
800 unsigned bitwidth =
801 getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth();
802 Value mask = getScalarOrVectorConstInt(
803 dstType, llvm::maskTrailingOnes<uint64_t>(bitwidth), rewriter,
804 op.getLoc());
805 if (!mask)
806 return rewriter.notifyMatchFailure(op, "unsupported type for mask");
807 rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
808 adaptor.getIn(), mask);
809 } else {
810 rewriter.replaceOpWithNewOp<spirv::UConvertOp>(op, dstType,
811 adaptor.getOperands());
812 }
813 return success();
814 }
815};
816
817//===----------------------------------------------------------------------===//
818// TruncIOp
819//===----------------------------------------------------------------------===//
820
821/// Converts arith.trunci to spirv.Select if the type of result is i1 or vector
822/// of i1.
823struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
824 using Base::Base;
825
826 LogicalResult
827 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
828 ConversionPatternRewriter &rewriter) const override {
829 Type dstType = getTypeConverter()->convertType(op.getType());
830 if (!dstType)
831 return getTypeConversionFailure(rewriter, op);
832
833 if (!isBoolScalarOrVector(dstType))
834 return failure();
835
836 Location loc = op.getLoc();
837 auto srcType = adaptor.getOperands().front().getType();
838 // Check if (x & 1) == 1.
839 Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
840 Value maskedSrc = spirv::BitwiseAndOp::create(
841 rewriter, loc, srcType, adaptor.getOperands()[0], mask);
842 Value isOne = spirv::IEqualOp::create(rewriter, loc, maskedSrc, mask);
843
844 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
845 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
846 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero);
847 return success();
848 }
849};
850
851/// Converts arith.trunci for cases where the type of result is neither i1
852/// nor vector of i1.
853struct TruncIPattern final : public OpConversionPattern<arith::TruncIOp> {
854 using Base::Base;
855
856 LogicalResult
857 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
858 ConversionPatternRewriter &rewriter) const override {
859 Type srcType = adaptor.getIn().getType();
860 Type dstType = getTypeConverter()->convertType(op.getType());
861 if (!dstType)
862 return getTypeConversionFailure(rewriter, op);
863
864 if (isBoolScalarOrVector(dstType))
865 return failure();
866
867 if (dstType == srcType) {
868 // We can have the same source and destination type due to type emulation.
869 // Perform bit masking to make sure we don't pollute downstream consumers
870 // with unwanted bits. Here we need to use the original result type's
871 // bitwidth.
872 unsigned bw = getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth();
873 Value mask = getScalarOrVectorConstInt(
874 dstType, llvm::maskTrailingOnes<uint64_t>(bw), rewriter, op.getLoc());
875 if (!mask)
876 return rewriter.notifyMatchFailure(op, "unsupported type for mask");
877 rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
878 adaptor.getIn(), mask);
879 } else {
880 // Given this is truncation, either SConvertOp or UConvertOp works.
881 rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
882 adaptor.getOperands());
883 }
884 return success();
885 }
886};
887
888//===----------------------------------------------------------------------===//
889// TypeCastingOp
890//===----------------------------------------------------------------------===//
891
892static std::optional<spirv::FPRoundingMode>
893convertArithRoundingModeToSPIRV(arith::RoundingMode roundingMode) {
894 switch (roundingMode) {
895 case arith::RoundingMode::downward:
896 return spirv::FPRoundingMode::RTN;
897 case arith::RoundingMode::to_nearest_even:
898 return spirv::FPRoundingMode::RTE;
899 case arith::RoundingMode::toward_zero:
900 return spirv::FPRoundingMode::RTZ;
901 case arith::RoundingMode::upward:
902 return spirv::FPRoundingMode::RTP;
903 case arith::RoundingMode::to_nearest_away:
904 // SPIR-V FPRoundingMode decoration has no ties-away-from-zero mode
905 // (as of SPIR-V 1.6)
906 return std::nullopt;
907 }
908 llvm_unreachable("Unhandled rounding mode");
909}
910
911/// Converts type-casting standard operations to SPIR-V operations.
912template <typename Op, typename SPIRVOp>
913struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
914 using OpConversionPattern<Op>::OpConversionPattern;
915
916 LogicalResult
917 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
918 ConversionPatternRewriter &rewriter) const override {
919 Type srcType = llvm::getSingleElement(adaptor.getOperands()).getType();
920 Type dstType = this->getTypeConverter()->convertType(op.getType());
921 if (!dstType)
922 return getTypeConversionFailure(rewriter, op);
923
924 if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
925 return failure();
926
927 if (dstType == srcType) {
928 // Due to type conversion, we are seeing the same source and target type.
929 // Then we can just erase this operation by forwarding its operand.
930 rewriter.replaceOp(op, adaptor.getOperands().front());
931 } else {
932 // Compute new rounding mode (if any).
933 std::optional<spirv::FPRoundingMode> rm = std::nullopt;
934 if (auto roundingModeOp =
935 dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
936 if (arith::RoundingModeAttr roundingMode =
937 roundingModeOp.getRoundingModeAttr()) {
938 if (!(rm =
939 convertArithRoundingModeToSPIRV(roundingMode.getValue()))) {
940 return rewriter.notifyMatchFailure(
941 op->getLoc(),
942 llvm::formatv("unsupported rounding mode '{0}'", roundingMode));
943 }
944 }
945 }
946 // Create replacement op and attach rounding mode attribute (if any).
947 auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
948 op, dstType, adaptor.getOperands());
949 if (rm) {
950 newOp->setAttr(
951 getDecorationString(spirv::Decoration::FPRoundingMode),
952 spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm));
953 }
954 }
955 return success();
956 }
957};
958
959//===----------------------------------------------------------------------===//
960// CmpIOp
961//===----------------------------------------------------------------------===//
962
963/// Converts integer compare operation on i1 type operands to SPIR-V ops.
964class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
965public:
966 using Base::Base;
967
968 LogicalResult
969 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
970 ConversionPatternRewriter &rewriter) const override {
971 Type srcType = op.getLhs().getType();
972 if (!isBoolScalarOrVector(srcType))
973 return failure();
974 Type dstType = getTypeConverter()->convertType(srcType);
975 if (!dstType)
976 return getTypeConversionFailure(rewriter, op, srcType);
977
978 switch (op.getPredicate()) {
979 case arith::CmpIPredicate::eq: {
980 rewriter.replaceOpWithNewOp<spirv::LogicalEqualOp>(op, adaptor.getLhs(),
981 adaptor.getRhs());
982 return success();
983 }
984 case arith::CmpIPredicate::ne: {
985 rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
986 op, adaptor.getLhs(), adaptor.getRhs());
987 return success();
988 }
989 case arith::CmpIPredicate::uge:
990 case arith::CmpIPredicate::ugt:
991 case arith::CmpIPredicate::ule:
992 case arith::CmpIPredicate::ult: {
993 // There are no direct corresponding instructions in SPIR-V for such
994 // cases. Extend them to 32-bit and do comparision then.
995 Type type = rewriter.getI32Type();
996 if (auto vectorType = dyn_cast<VectorType>(dstType))
997 type = VectorType::get(vectorType.getShape(), type);
998 Value extLhs =
999 arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getLhs());
1000 Value extRhs =
1001 arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getRhs());
1002
1003 rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs,
1004 extRhs);
1005 return success();
1006 }
1007 default:
1008 break;
1009 }
1010 return failure();
1011 }
1012};
1013
1014/// Converts integer compare operation to SPIR-V ops.
1015class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> {
1016public:
1017 using Base::Base;
1018
1019 LogicalResult
1020 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
1021 ConversionPatternRewriter &rewriter) const override {
1022 Type srcType = op.getLhs().getType();
1023 if (isBoolScalarOrVector(srcType))
1024 return failure();
1025 Type dstType = getTypeConverter()->convertType(srcType);
1026 if (!dstType)
1027 return getTypeConversionFailure(rewriter, op, srcType);
1028
1029 switch (op.getPredicate()) {
1030#define DISPATCH(cmpPredicate, spirvOp) \
1031 case cmpPredicate: \
1032 if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \
1033 !getElementTypeOrSelf(srcType).isIndex() && srcType != dstType && \
1034 !hasSameBitwidth(srcType, dstType)) { \
1035 return op.emitError( \
1036 "bitwidth emulation is not implemented yet on unsigned op"); \
1037 } \
1038 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
1039 adaptor.getRhs()); \
1040 return success();
1041
1042 DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
1043 DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
1044 DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
1045 DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
1046 DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
1047 DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
1048 DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
1049 DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
1050 DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
1051 DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
1052
1053#undef DISPATCH
1054 }
1055 return failure();
1056 }
1057};
1058
1059//===----------------------------------------------------------------------===//
1060// CmpFOpPattern
1061//===----------------------------------------------------------------------===//
1062
1063/// Converts floating-point comparison operations to SPIR-V ops.
1064class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> {
1065public:
1066 using Base::Base;
1067
1068 LogicalResult
1069 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1070 ConversionPatternRewriter &rewriter) const override {
1071 switch (op.getPredicate()) {
1072#define DISPATCH(cmpPredicate, spirvOp) \
1073 case cmpPredicate: \
1074 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
1075 adaptor.getRhs()); \
1076 return success();
1077
1078 // Ordered.
1079 DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
1080 DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
1081 DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
1082 DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
1083 DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
1084 DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
1085 // Unordered.
1086 DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
1087 DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
1088 DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
1089 DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
1090 DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
1091 DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
1092
1093#undef DISPATCH
1094
1095 default:
1096 break;
1097 }
1098 return failure();
1099 }
1100};
1101
1102/// Converts floating point NaN check to SPIR-V ops. This pattern requires
1103/// Kernel capability.
1104class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> {
1105public:
1106 using Base::Base;
1107
1108 LogicalResult
1109 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1110 ConversionPatternRewriter &rewriter) const override {
1111 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1112 rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(),
1113 adaptor.getRhs());
1114 return success();
1115 }
1116
1117 if (op.getPredicate() == arith::CmpFPredicate::UNO) {
1118 rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(),
1119 adaptor.getRhs());
1120 return success();
1121 }
1122
1123 return failure();
1124 }
1125};
1126
1127/// Converts floating point NaN check to SPIR-V ops. This pattern does not
1128/// require additional capability.
1129class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
1130public:
1131 using Base::Base;
1132
1133 LogicalResult
1134 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1135 ConversionPatternRewriter &rewriter) const override {
1136 if (op.getPredicate() != arith::CmpFPredicate::ORD &&
1137 op.getPredicate() != arith::CmpFPredicate::UNO)
1138 return failure();
1139
1140 Location loc = op.getLoc();
1141
1142 Value replace;
1143 if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1144 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1145 // Ordered comparsion checks if neither operand is NaN.
1146 replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter);
1147 } else {
1148 // Unordered comparsion checks if either operand is NaN.
1149 replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter);
1150 }
1151 } else {
1152 Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1153 Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1154
1155 replace = spirv::LogicalOrOp::create(rewriter, loc, lhsIsNan, rhsIsNan);
1156 if (op.getPredicate() == arith::CmpFPredicate::ORD)
1157 replace = spirv::LogicalNotOp::create(rewriter, loc, replace);
1158 }
1159
1160 rewriter.replaceOp(op, replace);
1161 return success();
1162 }
1163};
1164
1165//===----------------------------------------------------------------------===//
1166// AddUIExtendedOp
1167//===----------------------------------------------------------------------===//
1168
1169/// Converts arith.addui_extended to spirv.IAddCarry.
1170class AddUIExtendedOpPattern final
1171 : public OpConversionPattern<arith::AddUIExtendedOp> {
1172public:
1173 using Base::Base;
1174 LogicalResult
1175 matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
1176 ConversionPatternRewriter &rewriter) const override {
1177 Type dstElemTy = adaptor.getLhs().getType();
1178 Location loc = op->getLoc();
1179 Value result = spirv::IAddCarryOp::create(rewriter, loc, adaptor.getLhs(),
1180 adaptor.getRhs());
1181
1182 Value sumResult = spirv::CompositeExtractOp::create(rewriter, loc, result,
1183 llvm::ArrayRef(0));
1184 Value carryValue = spirv::CompositeExtractOp::create(rewriter, loc, result,
1185 llvm::ArrayRef(1));
1186
1187 // Convert the carry value to boolean.
1188 Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter);
1189 Value carryResult = spirv::IEqualOp::create(rewriter, loc, carryValue, one);
1190
1191 rewriter.replaceOp(op, {sumResult, carryResult});
1192 return success();
1193 }
1194};
1195
1196//===----------------------------------------------------------------------===//
1197// MulIExtendedOp
1198//===----------------------------------------------------------------------===//
1199
1200/// Converts arith.mul*i_extended to spirv.*MulExtended.
1201template <typename ArithMulOp, typename SPIRVMulOp>
1202class MulIExtendedOpPattern final : public OpConversionPattern<ArithMulOp> {
1203public:
1204 using OpConversionPattern<ArithMulOp>::OpConversionPattern;
1205 LogicalResult
1206 matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
1207 ConversionPatternRewriter &rewriter) const override {
1208 Location loc = op->getLoc();
1209 Value result =
1210 SPIRVMulOp::create(rewriter, loc, adaptor.getLhs(), adaptor.getRhs());
1211
1212 Value low = spirv::CompositeExtractOp::create(rewriter, loc, result,
1213 llvm::ArrayRef(0));
1214 Value high = spirv::CompositeExtractOp::create(rewriter, loc, result,
1215 llvm::ArrayRef(1));
1216
1217 rewriter.replaceOp(op, {low, high});
1218 return success();
1219 }
1220};
1221
1222//===----------------------------------------------------------------------===//
1223// SelectOp
1224//===----------------------------------------------------------------------===//
1225
1226/// Converts arith.select to spirv.Select.
1227class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> {
1228public:
1229 using Base::Base;
1230 LogicalResult
1231 matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
1232 ConversionPatternRewriter &rewriter) const override {
1233 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, adaptor.getCondition(),
1234 adaptor.getTrueValue(),
1235 adaptor.getFalseValue());
1236 return success();
1237 }
1238};
1239
1240//===----------------------------------------------------------------------===//
1241// MinimumFOp, MaximumFOp
1242//===----------------------------------------------------------------------===//
1243
1244/// Converts arith.maximumf/minimumf to spirv.GL.FMax/FMin or
1245/// spirv.CL.fmax/fmin.
1246template <typename Op, typename SPIRVOp>
1247class MinimumMaximumFOpPattern final : public OpConversionPattern<Op> {
1248public:
1249 using OpConversionPattern<Op>::OpConversionPattern;
1250 LogicalResult
1251 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
1252 ConversionPatternRewriter &rewriter) const override {
1253 auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
1254 Type dstType = converter->convertType(op.getType());
1255 if (!dstType)
1256 return getTypeConversionFailure(rewriter, op);
1257
1258 // arith.maximumf/minimumf:
1259 // "if one of the arguments is NaN, then the result is also NaN."
1260 // spirv.GL.FMax/FMin
1261 // "which operand is the result is undefined if one of the operands
1262 // is a NaN."
1263 // spirv.CL.fmax/fmin:
1264 // "If one argument is a NaN, Fmin returns the other argument."
1265
1266 Location loc = op.getLoc();
1267 Value spirvOp =
1268 SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands());
1269
1270 if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1271 rewriter.replaceOp(op, spirvOp);
1272 return success();
1273 }
1274
1275 Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1276 Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1277
1278 Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan,
1279 adaptor.getLhs(), spirvOp);
1280 Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan,
1281 adaptor.getRhs(), select1);
1282
1283 rewriter.replaceOp(op, select2);
1284 return success();
1285 }
1286};
1287
1288//===----------------------------------------------------------------------===//
1289// MinNumFOp, MaxNumFOp
1290//===----------------------------------------------------------------------===//
1291
1292/// Converts arith.maxnumf/minnumf to spirv.GL.FMax/FMin or
1293/// spirv.CL.fmax/fmin.
1294template <typename Op, typename SPIRVOp>
1295class MinNumMaxNumFOpPattern final : public OpConversionPattern<Op> {
1296 template <typename TargetOp>
1297 constexpr bool shouldInsertNanGuards() const {
1298 return llvm::is_one_of<TargetOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value;
1299 }
1300
1301public:
1302 using OpConversionPattern<Op>::OpConversionPattern;
1303 LogicalResult
1304 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
1305 ConversionPatternRewriter &rewriter) const override {
1306 auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
1307 Type dstType = converter->convertType(op.getType());
1308 if (!dstType)
1309 return getTypeConversionFailure(rewriter, op);
1310
1311 // arith.maxnumf/minnumf:
1312 // "If one of the arguments is NaN, then the result is the other
1313 // argument."
1314 // spirv.GL.FMax/FMin
1315 // "which operand is the result is undefined if one of the operands
1316 // is a NaN."
1317 // spirv.CL.fmax/fmin:
1318 // "If one argument is a NaN, Fmin returns the other argument."
1319
1320 Location loc = op.getLoc();
1321 Value spirvOp =
1322 SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands());
1323
1324 if (!shouldInsertNanGuards<SPIRVOp>() ||
1325 bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1326 rewriter.replaceOp(op, spirvOp);
1327 return success();
1328 }
1329
1330 Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1331 Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1332
1333 Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan,
1334 adaptor.getRhs(), spirvOp);
1335 Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan,
1336 adaptor.getLhs(), select1);
1337
1338 rewriter.replaceOp(op, select2);
1339 return success();
1340 }
1341};
1342
1343} // namespace
1344
1345//===----------------------------------------------------------------------===//
1346// Pattern Population
1347//===----------------------------------------------------------------------===//
1348
1350 const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
1351 // clang-format off
1352 patterns.add<
1353 ConstantCompositeOpPattern,
1354 ConstantScalarOpPattern,
1355 ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>,
1356 ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>,
1357 ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>,
1361 RemSIOpGLPattern, RemSIOpCLPattern,
1362 BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
1363 BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
1364 XOrIOpLogicalPattern, XOrIOpBooleanPattern,
1365 ElementwiseArithOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
1374 ExtUIPattern, ExtUII1Pattern,
1375 ExtSIPattern, ExtSII1Pattern,
1376 TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
1377 TruncIPattern, TruncII1Pattern,
1378 TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
1379 TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern,
1380 TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
1381 TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
1382 TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
1383 TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
1384 IndexCastIndexI1Pattern, IndexCastI1IndexPattern,
1385 TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
1386 TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
1387 CmpIOpBooleanPattern, CmpIOpPattern,
1388 CmpFOpNanNonePattern, CmpFOpPattern,
1389 AddUIExtendedOpPattern,
1390 MulIExtendedOpPattern<arith::MulSIExtendedOp, spirv::SMulExtendedOp>,
1391 MulIExtendedOpPattern<arith::MulUIExtendedOp, spirv::UMulExtendedOp>,
1392 SelectOpPattern,
1393
1394 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::GLFMaxOp>,
1395 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::GLFMinOp>,
1396 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::GLFMaxOp>,
1397 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::GLFMinOp>,
1402
1403 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::CLFMaxOp>,
1404 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::CLFMinOp>,
1405 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::CLFMaxOp>,
1406 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::CLFMinOp>,
1411 >(typeConverter, patterns.getContext());
1412 // clang-format on
1413
1414 // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel
1415 // capability is available.
1416 patterns.add<CmpFOpNanKernelPattern>(typeConverter, patterns.getContext(),
1417 /*benefit=*/2);
1418}
1419
1420//===----------------------------------------------------------------------===//
1421// Pass Definition
1422//===----------------------------------------------------------------------===//
1423
1424namespace {
1425struct ConvertArithToSPIRVPass
1426 : public impl::ConvertArithToSPIRVPassBase<ConvertArithToSPIRVPass> {
1427 using Base::Base;
1428
1429 void runOnOperation() override {
1430 Operation *op = getOperation();
1432 std::unique_ptr<SPIRVConversionTarget> target =
1433 SPIRVConversionTarget::get(targetAttr);
1434
1436 options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
1437 options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
1438 SPIRVTypeConverter typeConverter(targetAttr, options);
1439
1440 // Use UnrealizedConversionCast as the bridge so that we don't need to pull
1441 // in patterns for other dialects.
1442 target->addLegalOp<UnrealizedConversionCastOp>();
1443
1444 // Fail hard when there are any remaining 'arith' ops.
1445 target->addIllegalDialect<arith::ArithDialect>();
1446
1447 RewritePatternSet patterns(&getContext());
1448 arith::populateArithToSPIRVPatterns(typeConverter, patterns);
1449
1450 if (failed(applyPartialConversion(op, *target, std::move(patterns))))
1451 signalPassFailure();
1452 }
1453};
1454} // namespace
return success()
static bool hasSameBitwidth(Type a, Type b)
Returns true if scalar/vector type a and b have the same number of bitwidth.
static Value getScalarOrVectorConstInt(Type type, uint64_t value, OpBuilder &builder, Location loc)
Creates a scalar/vector integer constant.
static LogicalResult getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op, Type srcType)
Returns a source type conversion failure for srcType and operation op.
static IntegerAttr getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType, ConversionPatternRewriter &rewriter)
static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder)
Converts the given srcAttr into a boolean attribute if it holds an integral value.
static bool isBoolScalarOrVector(Type type)
Returns true if the given type is a boolean scalar or vector type.
#define DISPATCH(cmpPredicate, spirvOp)
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
static llvm::ManagedStatic< PassManagerOptions > options
This class represents a processed binary blob of data.
Definition AsmState.h:91
ArrayRef< char > getData() const
Return the raw underlying data of this blob.
Definition AsmState.h:145
Attributes are known-constant values of operations.
Definition Attributes.h:25
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
BoolAttr getBoolAttr(bool value)
Definition Builders.cpp:104
FloatAttr getF32FloatAttr(float value)
Definition Builders.cpp:250
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
static bool isValidRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Returns true if the given buffer is a valid raw buffer for the given type.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
DenseElementsAttr reshape(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but has been reshaped...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:209
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
result_type_range getResultTypes()
Definition Operation.h:436
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:412
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:56
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition Types.cpp:122
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
Type front()
Return first type in the range.
Definition TypeRange.h:152
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
An attribute that specifies the target version, allowed extensions and capabilities,...
NestedPattern Op(FilterFunctionType filter=defaultFilterFunction)
void populateArithToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
std::string getDecorationString(Decoration decoration)
Converts a SPIR-V Decoration enum value to its snake_case string representation for use in MLIR attri...
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.
Definition Pattern.h:24