MLIR 23.0.0git
ArithToSPIRV.cpp
Go to the documentation of this file.
1//===- ArithToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
21#include "llvm/ADT/APInt.h"
22#include "llvm/ADT/ArrayRef.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/Support/Debug.h"
25#include "llvm/Support/MathExtras.h"
26#include <cassert>
27#include <memory>
28
29namespace mlir {
30#define GEN_PASS_DEF_CONVERTARITHTOSPIRVPASS
31#include "mlir/Conversion/Passes.h.inc"
32} // namespace mlir
33
34#define DEBUG_TYPE "arith-to-spirv-pattern"
35
36using namespace mlir;
37
38//===----------------------------------------------------------------------===//
39// Conversion Helpers
40//===----------------------------------------------------------------------===//
41
42/// Converts the given `srcAttr` into a boolean attribute if it holds an
43/// integral value. Returns null attribute if conversion fails.
44static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) {
45 if (auto boolAttr = dyn_cast<BoolAttr>(srcAttr))
46 return boolAttr;
47 if (auto intAttr = dyn_cast<IntegerAttr>(srcAttr))
48 return builder.getBoolAttr(intAttr.getValue().getBoolValue());
49 return {};
50}
51
52/// Converts the given `srcAttr` to a new attribute of the given `dstType`.
53/// Returns null attribute if conversion fails.
54static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType,
55 Builder builder) {
56 // If the source number uses less active bits than the target bitwidth, then
57 // it should be safe to convert.
58 if (srcAttr.getValue().isIntN(dstType.getWidth()))
59 return builder.getIntegerAttr(dstType, srcAttr.getInt());
60
61 // XXX: Try again by interpreting the source number as a signed value.
62 // Although integers in the standard dialect are signless, they can represent
63 // a signed number. It's the operation decides how to interpret. This is
64 // dangerous, but it seems there is no good way of handling this if we still
65 // want to change the bitwidth. Emit a message at least.
66 if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
67 auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt());
68 LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '"
69 << dstAttr << "' for type '" << dstType << "'\n");
70 return dstAttr;
71 }
72
73 LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr
74 << "' illegal: cannot fit into target type '"
75 << dstType << "'\n");
76 return {};
77}
78
79/// Converts the given `srcAttr` to a new attribute of the given `dstType`.
80/// Returns null attribute if `dstType` is not 32-bit or conversion fails.
81static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
82 Builder builder) {
83 // Only support converting to float for now.
84 if (!dstType.isF32())
85 return FloatAttr();
86
87 // Try to convert the source floating-point number to single precision.
88 APFloat dstVal = srcAttr.getValue();
89 bool losesInfo = false;
90 APFloat::opStatus status =
91 dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
92 if (status != APFloat::opOK || losesInfo) {
93 LLVM_DEBUG(llvm::dbgs()
94 << srcAttr << " illegal: cannot fit into converted type '"
95 << dstType << "'\n");
96 return FloatAttr();
97 }
98
99 return builder.getF32FloatAttr(dstVal.convertToFloat());
100}
101
102// Get in IntegerAttr from FloatAttr while preserving the bits.
103// Useful for converting float constants to integer constants while preserving
104// the bits.
105static IntegerAttr
106getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType,
107 ConversionPatternRewriter &rewriter) {
108 APFloat floatVal = floatAttr.getValue();
109 APInt intVal = floatVal.bitcastToAPInt();
110 return rewriter.getIntegerAttr(dstType, intVal);
111}
112
113/// Returns true if the given `type` is a boolean scalar or vector type.
114static bool isBoolScalarOrVector(Type type) {
115 assert(type && "Not a valid type");
116 if (type.isInteger(1))
117 return true;
118
119 if (auto vecType = dyn_cast<VectorType>(type))
120 return vecType.getElementType().isInteger(1);
121
122 return false;
123}
124
125/// Creates a scalar/vector integer constant.
126static Value getScalarOrVectorConstInt(Type type, uint64_t value,
127 OpBuilder &builder, Location loc) {
128 if (auto vectorType = dyn_cast<VectorType>(type)) {
129 Attribute element = IntegerAttr::get(vectorType.getElementType(), value);
130 auto attr = SplatElementsAttr::get(vectorType, element);
131 return spirv::ConstantOp::create(builder, loc, vectorType, attr);
132 }
133
134 if (auto intType = dyn_cast<IntegerType>(type))
135 return spirv::ConstantOp::create(builder, loc, type,
136 builder.getIntegerAttr(type, value));
137
138 return nullptr;
139}
140
141/// Returns true if scalar/vector type `a` and `b` have the same number of
142/// bitwidth.
143static bool hasSameBitwidth(Type a, Type b) {
144 auto getNumBitwidth = [](Type type) {
145 unsigned bw = 0;
146 if (type.isIntOrFloat())
147 bw = type.getIntOrFloatBitWidth();
148 else if (auto vecType = dyn_cast<VectorType>(type))
149 bw = vecType.getElementTypeBitWidth() * vecType.getNumElements();
150 return bw;
151 };
152 unsigned aBW = getNumBitwidth(a);
153 unsigned bBW = getNumBitwidth(b);
154 return aBW != 0 && bBW != 0 && aBW == bBW;
155}
156
157/// Returns a source type conversion failure for `srcType` and operation `op`.
158static LogicalResult
159getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op,
160 Type srcType) {
161 return rewriter.notifyMatchFailure(
162 op->getLoc(),
163 llvm::formatv("failed to convert source type '{0}'", srcType));
164}
165
166/// Returns a source type conversion failure for the result type of `op`.
167static LogicalResult
168getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op) {
169 assert(op->getNumResults() == 1);
170 return getTypeConversionFailure(rewriter, op, op->getResultTypes().front());
171}
172
173namespace {
174
175/// Converts elementwise unary, binary and ternary arith operations to SPIR-V
176/// operations. Op can potentially support overflow flags.
177template <typename Op, typename SPIRVOp>
178struct ElementwiseArithOpPattern final : OpConversionPattern<Op> {
179 using OpConversionPattern<Op>::OpConversionPattern;
180
181 LogicalResult
182 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
183 ConversionPatternRewriter &rewriter) const override {
184 assert(adaptor.getOperands().size() <= 3);
185 auto converter = this->template getTypeConverter<SPIRVTypeConverter>();
186 Type dstType = converter->convertType(op.getType());
187 if (!dstType) {
188 return rewriter.notifyMatchFailure(
189 op->getLoc(),
190 llvm::formatv("failed to convert type {0} for SPIR-V", op.getType()));
191 }
192
193 if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
194 !getElementTypeOrSelf(op.getType()).isIndex() &&
195 dstType != op.getType()) {
196 return op.emitError("bitwidth emulation is not implemented yet on "
197 "unsigned op pattern version");
198 }
199
200 auto overflowFlags = arith::IntegerOverflowFlags::none;
201 if (auto overflowIface =
202 dyn_cast<arith::ArithIntegerOverflowFlagsInterface>(*op)) {
203 if (converter->getTargetEnv().allows(
204 spirv::Extension::SPV_KHR_no_integer_wrap_decoration))
205 overflowFlags = overflowIface.getOverflowAttr().getValue();
206 }
207
208 auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
209 op, dstType, adaptor.getOperands());
210
211 if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nsw))
212 newOp->setAttr(getDecorationString(spirv::Decoration::NoSignedWrap),
213 rewriter.getUnitAttr());
214
215 if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nuw))
216 newOp->setAttr(getDecorationString(spirv::Decoration::NoUnsignedWrap),
217 rewriter.getUnitAttr());
218
219 return success();
220 }
221};
222
223//===----------------------------------------------------------------------===//
224// ConstantOp
225//===----------------------------------------------------------------------===//
226
227/// Converts composite arith.constant operation to spirv.Constant.
228struct ConstantCompositeOpPattern final
229 : public OpConversionPattern<arith::ConstantOp> {
230 using Base::Base;
231
232 LogicalResult
233 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
234 ConversionPatternRewriter &rewriter) const override {
235 auto srcType = dyn_cast<ShapedType>(constOp.getType());
236 if (!srcType || srcType.getNumElements() == 1)
237 return failure();
238
239 // arith.constant should only have vector or tensor types. This is a MLIR
240 // wide problem at the moment.
241 if (!isa<VectorType, RankedTensorType>(srcType))
242 return rewriter.notifyMatchFailure(constOp, "unsupported ShapedType");
243
244 Type dstType = getTypeConverter()->convertType(srcType);
245 if (!dstType)
246 return failure();
247
248 // Import the resource into the IR to make use of the special handling of
249 // element types later on.
250 mlir::DenseElementsAttr dstElementsAttr;
251 if (auto denseElementsAttr =
252 dyn_cast<DenseElementsAttr>(constOp.getValue())) {
253 dstElementsAttr = denseElementsAttr;
254 } else if (auto resourceAttr =
255 dyn_cast<DenseResourceElementsAttr>(constOp.getValue())) {
256
257 AsmResourceBlob *blob = resourceAttr.getRawHandle().getBlob();
258 if (!blob)
259 return constOp->emitError("could not find resource blob");
260
261 ArrayRef<char> ptr = blob->getData();
262
263 // Check that the buffer meets the requirements to get converted to a
264 // DenseElementsAttr
266 return constOp->emitError("resource is not a valid buffer");
267
268 dstElementsAttr =
269 DenseElementsAttr::getFromRawBuffer(resourceAttr.getType(), ptr);
270 } else {
271 return constOp->emitError("unsupported elements attribute");
272 }
273
274 ShapedType dstAttrType = dstElementsAttr.getType();
275
276 // If the composite type has more than one dimensions, perform
277 // linearization.
278 if (srcType.getRank() > 1) {
279 if (isa<RankedTensorType>(srcType)) {
280 dstAttrType = RankedTensorType::get(srcType.getNumElements(),
281 srcType.getElementType());
282 dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
283 } else {
284 // TODO: add support for large vectors.
285 return failure();
286 }
287 }
288
289 Type srcElemType = srcType.getElementType();
290 Type dstElemType;
291 // Tensor types are converted to SPIR-V array types; vector types are
292 // converted to SPIR-V vector/array types.
293 if (auto arrayType = dyn_cast<spirv::ArrayType>(dstType))
294 dstElemType = arrayType.getElementType();
295 else
296 dstElemType = cast<VectorType>(dstType).getElementType();
297
298 // If the source and destination element types are different, perform
299 // attribute conversion.
300 if (srcElemType != dstElemType) {
302 if (isa<FloatType>(srcElemType)) {
303 for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
304 Attribute dstAttr = nullptr;
305 // Handle 8-bit float conversion to 8-bit integer.
306 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
307 if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
308 srcElemType.getIntOrFloatBitWidth() == 8 &&
309 isa<IntegerType>(dstElemType)) {
310 dstAttr =
311 getIntegerAttrFromFloatAttr(srcAttr, dstElemType, rewriter);
312 } else {
313 dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstElemType),
314 rewriter);
315 }
316 if (!dstAttr)
317 return failure();
318 elements.push_back(dstAttr);
319 }
320 } else if (srcElemType.isInteger(1)) {
321 return failure();
322 } else {
323 for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) {
324 IntegerAttr dstAttr = convertIntegerAttr(
325 srcAttr, cast<IntegerType>(dstElemType), rewriter);
326 if (!dstAttr)
327 return failure();
328 elements.push_back(dstAttr);
329 }
330 }
331
332 // Unfortunately, we cannot use dialect-specific types for element
333 // attributes; element attributes only works with builtin types. So we
334 // need to prepare another converted builtin types for the destination
335 // elements attribute.
336 if (isa<RankedTensorType>(dstAttrType))
337 dstAttrType =
338 RankedTensorType::get(dstAttrType.getShape(), dstElemType);
339 else
340 dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
341
342 dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
343 }
344
345 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
346 dstElementsAttr);
347 return success();
348 }
349};
350
351/// Converts scalar arith.constant operation to spirv.Constant.
352struct ConstantScalarOpPattern final
353 : public OpConversionPattern<arith::ConstantOp> {
354 using Base::Base;
355
356 LogicalResult
357 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
358 ConversionPatternRewriter &rewriter) const override {
359 Type srcType = constOp.getType();
360 if (auto shapedType = dyn_cast<ShapedType>(srcType)) {
361 if (shapedType.getNumElements() != 1)
362 return failure();
363 srcType = shapedType.getElementType();
364 }
365 if (!srcType.isIntOrIndexOrFloat())
366 return failure();
367
368 Attribute cstAttr = constOp.getValue();
369 if (auto elementsAttr = dyn_cast<DenseElementsAttr>(cstAttr))
370 cstAttr = elementsAttr.getSplatValue<Attribute>();
371
372 Type dstType = getTypeConverter()->convertType(srcType);
373 if (!dstType)
374 return failure();
375
376 // Floating-point types.
377 if (isa<FloatType>(srcType)) {
378 auto srcAttr = cast<FloatAttr>(cstAttr);
379 Attribute dstAttr = srcAttr;
380
381 // Floating-point types not supported in the target environment are all
382 // converted to float type.
383 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
384 if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
385 srcType.getIntOrFloatBitWidth() == 8 && isa<IntegerType>(dstType) &&
386 dstType.getIntOrFloatBitWidth() == 8) {
387 // If the source is an 8-bit float, convert it to a 8-bit integer.
388 dstAttr = getIntegerAttrFromFloatAttr(srcAttr, dstType, rewriter);
389 if (!dstAttr)
390 return failure();
391 } else if (srcType != dstType) {
392 dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter);
393 if (!dstAttr)
394 return failure();
395 }
396
397 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
398 return success();
399 }
400
401 // Bool type.
402 if (srcType.isInteger(1)) {
403 // arith.constant can use 0/1 instead of true/false for i1 values. We need
404 // to handle that here.
405 auto dstAttr = convertBoolAttr(cstAttr, rewriter);
406 if (!dstAttr)
407 return failure();
408 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
409 return success();
410 }
411
412 // IndexType or IntegerType. Index values are converted to 32-bit integer
413 // values when converting to SPIR-V.
414 auto srcAttr = cast<IntegerAttr>(cstAttr);
415 IntegerAttr dstAttr =
416 convertIntegerAttr(srcAttr, cast<IntegerType>(dstType), rewriter);
417 if (!dstAttr)
418 return failure();
419 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
420 return success();
421 }
422};
423
424//===----------------------------------------------------------------------===//
425// RemSIOp
426//===----------------------------------------------------------------------===//
427
428/// Returns signed remainder for `lhs` and `rhs` and lets the result follow
429/// the sign of `signOperand`.
430///
431/// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment
432/// spec, "for the OpSRem and OpSMod instructions, if either operand is negative
433/// the result is undefined." So we cannot directly use spirv.SRem/spirv.SMod
434/// if either operand can be negative. Emulate it via spirv.UMod.
435template <typename SignedAbsOp>
436static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
437 Value signOperand, OpBuilder &builder) {
438 assert(lhs.getType() == rhs.getType());
439 assert(lhs == signOperand || rhs == signOperand);
440
441 Type type = lhs.getType();
442
443 // Calculate the remainder with spirv.UMod.
444 Value lhsAbs = SignedAbsOp::create(builder, loc, type, lhs);
445 Value rhsAbs = SignedAbsOp::create(builder, loc, type, rhs);
446 Value abs = spirv::UModOp::create(builder, loc, lhsAbs, rhsAbs);
447
448 // Fix the sign.
449 Value isPositive;
450 if (lhs == signOperand)
451 isPositive = spirv::IEqualOp::create(builder, loc, lhs, lhsAbs);
452 else
453 isPositive = spirv::IEqualOp::create(builder, loc, rhs, rhsAbs);
454 Value absNegate = spirv::SNegateOp::create(builder, loc, type, abs);
455 return spirv::SelectOp::create(builder, loc, type, isPositive, abs,
456 absNegate);
457}
458
459/// Converts arith.remsi to GLSL SPIR-V ops.
460///
461/// This cannot be merged into the template unary/binary pattern due to Vulkan
462/// restrictions over spirv.SRem and spirv.SMod.
463struct RemSIOpGLPattern final : public OpConversionPattern<arith::RemSIOp> {
464 using Base::Base;
465
466 LogicalResult
467 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
468 ConversionPatternRewriter &rewriter) const override {
469 Value result = emulateSignedRemainder<spirv::CLSAbsOp>(
470 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
471 adaptor.getOperands()[0], rewriter);
472 rewriter.replaceOp(op, result);
473
474 return success();
475 }
476};
477
478/// Converts arith.remsi to OpenCL SPIR-V ops.
479struct RemSIOpCLPattern final : public OpConversionPattern<arith::RemSIOp> {
480 using Base::Base;
481
482 LogicalResult
483 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
484 ConversionPatternRewriter &rewriter) const override {
485 Value result = emulateSignedRemainder<spirv::GLSAbsOp>(
486 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
487 adaptor.getOperands()[0], rewriter);
488 rewriter.replaceOp(op, result);
489
490 return success();
491 }
492};
493
494//===----------------------------------------------------------------------===//
495// BitwiseOp
496//===----------------------------------------------------------------------===//
497
498/// Converts bitwise operations to SPIR-V operations. This is a special pattern
499/// other than the BinaryOpPatternPattern because if the operands are boolean
500/// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For
501/// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`.
502template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
503struct BitwiseOpPattern final : public OpConversionPattern<Op> {
504 using OpConversionPattern<Op>::OpConversionPattern;
505
506 LogicalResult
507 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
508 ConversionPatternRewriter &rewriter) const override {
509 assert(adaptor.getOperands().size() == 2);
510 Type dstType = this->getTypeConverter()->convertType(op.getType());
511 if (!dstType)
512 return getTypeConversionFailure(rewriter, op);
513
514 if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) {
515 rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(
516 op, dstType, adaptor.getOperands());
517 } else {
518 rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(
519 op, dstType, adaptor.getOperands());
520 }
521 return success();
522 }
523};
524
525//===----------------------------------------------------------------------===//
526// XOrIOp
527//===----------------------------------------------------------------------===//
528
529/// Converts arith.xori to SPIR-V operations.
530struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> {
531 using Base::Base;
532
533 LogicalResult
534 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
535 ConversionPatternRewriter &rewriter) const override {
536 assert(adaptor.getOperands().size() == 2);
537
538 if (isBoolScalarOrVector(adaptor.getOperands().front().getType()))
539 return failure();
540
541 Type dstType = getTypeConverter()->convertType(op.getType());
542 if (!dstType)
543 return getTypeConversionFailure(rewriter, op);
544
545 rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
546 adaptor.getOperands());
547
548 return success();
549 }
550};
551
552/// Converts arith.xori to SPIR-V operations if the type of source is i1 or
553/// vector of i1.
554struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
555 using Base::Base;
556
557 LogicalResult
558 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
559 ConversionPatternRewriter &rewriter) const override {
560 assert(adaptor.getOperands().size() == 2);
561
562 if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
563 return failure();
564
565 Type dstType = getTypeConverter()->convertType(op.getType());
566 if (!dstType)
567 return getTypeConversionFailure(rewriter, op);
568
569 rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
570 op, dstType, adaptor.getOperands());
571 return success();
572 }
573};
574
575//===----------------------------------------------------------------------===//
576// UIToFPOp
577//===----------------------------------------------------------------------===//
578
579/// Converts arith.uitofp to spirv.Select if the type of source is i1 or vector
580/// of i1.
581struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
582 using Base::Base;
583
584 LogicalResult
585 matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
586 ConversionPatternRewriter &rewriter) const override {
587 Type srcType = adaptor.getOperands().front().getType();
588 if (!isBoolScalarOrVector(srcType))
589 return failure();
590
591 Type dstType = getTypeConverter()->convertType(op.getType());
592 if (!dstType)
593 return getTypeConversionFailure(rewriter, op);
594
595 Location loc = op.getLoc();
596 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
597 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
598 rewriter.replaceOpWithNewOp<spirv::SelectOp>(
599 op, dstType, adaptor.getOperands().front(), one, zero);
600 return success();
601 }
602};
603
604//===----------------------------------------------------------------------===//
605// IndexCastOp
606//===----------------------------------------------------------------------===//
607
608/// Converts arith.index_cast to spirv.INotEqual if the target type is i1.
609struct IndexCastIndexI1Pattern final
610 : public OpConversionPattern<arith::IndexCastOp> {
611 using Base::Base;
612
613 LogicalResult
614 matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
615 ConversionPatternRewriter &rewriter) const override {
616 if (!isBoolScalarOrVector(op.getType()))
617 return failure();
618
619 Type dstType = getTypeConverter()->convertType(op.getType());
620 if (!dstType)
621 return getTypeConversionFailure(rewriter, op);
622
623 Location loc = op.getLoc();
624 Value zeroIdx =
625 spirv::ConstantOp::getZero(adaptor.getIn().getType(), loc, rewriter);
626 rewriter.replaceOpWithNewOp<spirv::INotEqualOp>(op, dstType, zeroIdx,
627 adaptor.getIn());
628 return success();
629 }
630};
631
632/// Converts arith.index_cast to spirv.Select if the source type is i1.
633struct IndexCastI1IndexPattern final
634 : public OpConversionPattern<arith::IndexCastOp> {
635 using Base::Base;
636
637 LogicalResult
638 matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
639 ConversionPatternRewriter &rewriter) const override {
640 if (!isBoolScalarOrVector(adaptor.getIn().getType()))
641 return failure();
642
643 Type dstType = getTypeConverter()->convertType(op.getType());
644 if (!dstType)
645 return getTypeConversionFailure(rewriter, op);
646
647 Location loc = op.getLoc();
648 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
649 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
650 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, adaptor.getIn(),
651 one, zero);
652 return success();
653 }
654};
655
656//===----------------------------------------------------------------------===//
657// ExtSIOp
658//===----------------------------------------------------------------------===//
659
660/// Converts arith.extsi to spirv.Select if the type of source is i1 or vector
661/// of i1.
662struct ExtSII1Pattern final : public OpConversionPattern<arith::ExtSIOp> {
663 using Base::Base;
664
665 LogicalResult
666 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
667 ConversionPatternRewriter &rewriter) const override {
668 Value operand = adaptor.getIn();
669 if (!isBoolScalarOrVector(operand.getType()))
670 return failure();
671
672 Location loc = op.getLoc();
673 Type dstType = getTypeConverter()->convertType(op.getType());
674 if (!dstType)
675 return getTypeConversionFailure(rewriter, op);
676
677 Value allOnes;
678 if (auto intTy = dyn_cast<IntegerType>(dstType)) {
679 unsigned componentBitwidth = intTy.getWidth();
680 allOnes = spirv::ConstantOp::create(
681 rewriter, loc, intTy,
682 rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth)));
683 } else if (auto vectorTy = dyn_cast<VectorType>(dstType)) {
684 unsigned componentBitwidth = vectorTy.getElementTypeBitWidth();
685 allOnes = spirv::ConstantOp::create(
686 rewriter, loc, vectorTy,
687 SplatElementsAttr::get(vectorTy,
688 APInt::getAllOnes(componentBitwidth)));
689 } else {
690 return rewriter.notifyMatchFailure(
691 loc, llvm::formatv("unhandled type: {0}", dstType));
692 }
693
694 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
695 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, operand, allOnes,
696 zero);
697 return success();
698 }
699};
700
701/// Converts arith.extsi to spirv.Select if the type of source is neither i1 nor
702/// vector of i1.
703struct ExtSIPattern final : public OpConversionPattern<arith::ExtSIOp> {
704 using Base::Base;
705
706 LogicalResult
707 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
708 ConversionPatternRewriter &rewriter) const override {
709 Type srcType = adaptor.getIn().getType();
710 if (isBoolScalarOrVector(srcType))
711 return failure();
712
713 Type dstType = getTypeConverter()->convertType(op.getType());
714 if (!dstType)
715 return getTypeConversionFailure(rewriter, op);
716
717 if (dstType == srcType) {
718 // We can have the same source and destination type due to type emulation.
719 // Perform bit shifting to make sure we have the proper leading set bits.
720
721 unsigned srcBW =
722 getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth();
723 unsigned dstBW =
725 assert(srcBW < dstBW);
726 Value shiftSize = getScalarOrVectorConstInt(dstType, dstBW - srcBW,
727 rewriter, op.getLoc());
728
729 // First shift left to sequeeze out all leading bits beyond the original
730 // bitwidth. Here we need to use the original source and result type's
731 // bitwidth.
732 auto shiftLOp = spirv::ShiftLeftLogicalOp::create(
733 rewriter, op.getLoc(), dstType, adaptor.getIn(), shiftSize);
735 // Then we perform arithmetic right shift to make sure we have the right
736 // sign bits for negative values.
737 rewriter.replaceOpWithNewOp<spirv::ShiftRightArithmeticOp>(
738 op, dstType, shiftLOp, shiftSize);
739 } else {
740 rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
741 adaptor.getOperands());
744 return success();
745 }
747
748//===----------------------------------------------------------------------===//
749// ExtUIOp
750//===----------------------------------------------------------------------===//
752/// Converts arith.extui to spirv.Select if the type of source is i1 or vector
753/// of i1.
754struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> {
755 using Base::Base;
756
757 LogicalResult
758 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
759 ConversionPatternRewriter &rewriter) const override {
760 Type srcType = adaptor.getOperands().front().getType();
761 if (!isBoolScalarOrVector(srcType))
762 return failure();
763
764 Type dstType = getTypeConverter()->convertType(op.getType());
765 if (!dstType)
766 return getTypeConversionFailure(rewriter, op);
767
768 Location loc = op.getLoc();
769 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
770 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
771 rewriter.replaceOpWithNewOp<spirv::SelectOp>(
772 op, dstType, adaptor.getOperands().front(), one, zero);
773 return success();
774 }
775};
776
777/// Converts arith.extui for cases where the type of source is neither i1 nor
778/// vector of i1.
779struct ExtUIPattern final : public OpConversionPattern<arith::ExtUIOp> {
780 using Base::Base;
781
782 LogicalResult
783 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
784 ConversionPatternRewriter &rewriter) const override {
785 Type srcType = adaptor.getIn().getType();
786 if (isBoolScalarOrVector(srcType))
787 return failure();
789 Type dstType = getTypeConverter()->convertType(op.getType());
790 if (!dstType)
791 return getTypeConversionFailure(rewriter, op);
793 if (dstType == srcType) {
794 // We can have the same source and destination type due to type emulation.
795 // Perform bit masking to make sure we don't pollute downstream consumers
796 // with unwanted bits. Here we need to use the original source type's
797 // bitwidth.
798 unsigned bitwidth =
799 getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth();
801 dstType, llvm::maskTrailingOnes<uint64_t>(bitwidth), rewriter,
802 op.getLoc());
803 rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
804 adaptor.getIn(), mask);
805 } else {
806 rewriter.replaceOpWithNewOp<spirv::UConvertOp>(op, dstType,
807 adaptor.getOperands());
808 }
809 return success();
810 }
811};
812
813//===----------------------------------------------------------------------===//
814// TruncIOp
815//===----------------------------------------------------------------------===//
816
817/// Converts arith.trunci to spirv.Select if the type of result is i1 or vector
818/// of i1.
819struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
820 using Base::Base;
821
822 LogicalResult
823 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
824 ConversionPatternRewriter &rewriter) const override {
825 Type dstType = getTypeConverter()->convertType(op.getType());
826 if (!dstType)
827 return getTypeConversionFailure(rewriter, op);
828
829 if (!isBoolScalarOrVector(dstType))
830 return failure();
831
832 Location loc = op.getLoc();
833 auto srcType = adaptor.getOperands().front().getType();
834 // Check if (x & 1) == 1.
835 Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
836 Value maskedSrc = spirv::BitwiseAndOp::create(
837 rewriter, loc, srcType, adaptor.getOperands()[0], mask);
838 Value isOne = spirv::IEqualOp::create(rewriter, loc, maskedSrc, mask);
839
840 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
841 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
842 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero);
843 return success();
844 }
845};
846
847/// Converts arith.trunci for cases where the type of result is neither i1
848/// nor vector of i1.
849struct TruncIPattern final : public OpConversionPattern<arith::TruncIOp> {
850 using Base::Base;
851
852 LogicalResult
853 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
854 ConversionPatternRewriter &rewriter) const override {
855 Type srcType = adaptor.getIn().getType();
856 Type dstType = getTypeConverter()->convertType(op.getType());
857 if (!dstType)
858 return getTypeConversionFailure(rewriter, op);
859
860 if (isBoolScalarOrVector(dstType))
861 return failure();
862
863 if (dstType == srcType) {
864 // We can have the same source and destination type due to type emulation.
865 // Perform bit masking to make sure we don't pollute downstream consumers
866 // with unwanted bits. Here we need to use the original result type's
867 // bitwidth.
868 unsigned bw = getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth();
869 Value mask = getScalarOrVectorConstInt(
870 dstType, llvm::maskTrailingOnes<uint64_t>(bw), rewriter, op.getLoc());
871 rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
872 adaptor.getIn(), mask);
873 } else {
874 // Given this is truncation, either SConvertOp or UConvertOp works.
875 rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
876 adaptor.getOperands());
877 }
878 return success();
879 }
880};
881
882//===----------------------------------------------------------------------===//
883// TypeCastingOp
884//===----------------------------------------------------------------------===//
885
886static std::optional<spirv::FPRoundingMode>
887convertArithRoundingModeToSPIRV(arith::RoundingMode roundingMode) {
888 switch (roundingMode) {
889 case arith::RoundingMode::downward:
890 return spirv::FPRoundingMode::RTN;
891 case arith::RoundingMode::to_nearest_even:
892 return spirv::FPRoundingMode::RTE;
893 case arith::RoundingMode::toward_zero:
894 return spirv::FPRoundingMode::RTZ;
895 case arith::RoundingMode::upward:
896 return spirv::FPRoundingMode::RTP;
897 case arith::RoundingMode::to_nearest_away:
898 // SPIR-V FPRoundingMode decoration has no ties-away-from-zero mode
899 // (as of SPIR-V 1.6)
900 return std::nullopt;
901 }
902 llvm_unreachable("Unhandled rounding mode");
903}
904
905/// Converts type-casting standard operations to SPIR-V operations.
906template <typename Op, typename SPIRVOp>
907struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
908 using OpConversionPattern<Op>::OpConversionPattern;
909
910 LogicalResult
911 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
912 ConversionPatternRewriter &rewriter) const override {
913 Type srcType = llvm::getSingleElement(adaptor.getOperands()).getType();
914 Type dstType = this->getTypeConverter()->convertType(op.getType());
915 if (!dstType)
916 return getTypeConversionFailure(rewriter, op);
917
918 if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
919 return failure();
920
921 if (dstType == srcType) {
922 // Due to type conversion, we are seeing the same source and target type.
923 // Then we can just erase this operation by forwarding its operand.
924 rewriter.replaceOp(op, adaptor.getOperands().front());
925 } else {
926 // Compute new rounding mode (if any).
927 std::optional<spirv::FPRoundingMode> rm = std::nullopt;
928 if (auto roundingModeOp =
929 dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
930 if (arith::RoundingModeAttr roundingMode =
931 roundingModeOp.getRoundingModeAttr()) {
932 if (!(rm =
933 convertArithRoundingModeToSPIRV(roundingMode.getValue()))) {
934 return rewriter.notifyMatchFailure(
935 op->getLoc(),
936 llvm::formatv("unsupported rounding mode '{0}'", roundingMode));
937 }
938 }
939 }
940 // Create replacement op and attach rounding mode attribute (if any).
941 auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
942 op, dstType, adaptor.getOperands());
943 if (rm) {
944 newOp->setAttr(
945 getDecorationString(spirv::Decoration::FPRoundingMode),
946 spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm));
947 }
948 }
949 return success();
950 }
951};
952
953//===----------------------------------------------------------------------===//
954// CmpIOp
955//===----------------------------------------------------------------------===//
956
957/// Converts integer compare operation on i1 type operands to SPIR-V ops.
958class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
959public:
960 using Base::Base;
961
962 LogicalResult
963 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
964 ConversionPatternRewriter &rewriter) const override {
965 Type srcType = op.getLhs().getType();
966 if (!isBoolScalarOrVector(srcType))
967 return failure();
968 Type dstType = getTypeConverter()->convertType(srcType);
969 if (!dstType)
970 return getTypeConversionFailure(rewriter, op, srcType);
971
972 switch (op.getPredicate()) {
973 case arith::CmpIPredicate::eq: {
974 rewriter.replaceOpWithNewOp<spirv::LogicalEqualOp>(op, adaptor.getLhs(),
975 adaptor.getRhs());
976 return success();
977 }
978 case arith::CmpIPredicate::ne: {
979 rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
980 op, adaptor.getLhs(), adaptor.getRhs());
981 return success();
982 }
983 case arith::CmpIPredicate::uge:
984 case arith::CmpIPredicate::ugt:
985 case arith::CmpIPredicate::ule:
986 case arith::CmpIPredicate::ult: {
987 // There are no direct corresponding instructions in SPIR-V for such
988 // cases. Extend them to 32-bit and do comparision then.
989 Type type = rewriter.getI32Type();
990 if (auto vectorType = dyn_cast<VectorType>(dstType))
991 type = VectorType::get(vectorType.getShape(), type);
992 Value extLhs =
993 arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getLhs());
994 Value extRhs =
995 arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getRhs());
996
997 rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs,
998 extRhs);
999 return success();
1000 }
1001 default:
1002 break;
1003 }
1004 return failure();
1005 }
1006};
1007
1008/// Converts integer compare operation to SPIR-V ops.
1009class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> {
1010public:
1011 using Base::Base;
1012
1013 LogicalResult
1014 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
1015 ConversionPatternRewriter &rewriter) const override {
1016 Type srcType = op.getLhs().getType();
1017 if (isBoolScalarOrVector(srcType))
1018 return failure();
1019 Type dstType = getTypeConverter()->convertType(srcType);
1020 if (!dstType)
1021 return getTypeConversionFailure(rewriter, op, srcType);
1022
1023 switch (op.getPredicate()) {
1024#define DISPATCH(cmpPredicate, spirvOp) \
1025 case cmpPredicate: \
1026 if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \
1027 !getElementTypeOrSelf(srcType).isIndex() && srcType != dstType && \
1028 !hasSameBitwidth(srcType, dstType)) { \
1029 return op.emitError( \
1030 "bitwidth emulation is not implemented yet on unsigned op"); \
1031 } \
1032 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
1033 adaptor.getRhs()); \
1034 return success();
1035
1036 DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
1037 DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
1038 DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
1039 DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
1040 DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
1041 DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
1042 DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
1043 DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
1044 DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
1045 DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
1046
1047#undef DISPATCH
1048 }
1049 return failure();
1050 }
1051};
1052
1053//===----------------------------------------------------------------------===//
1054// CmpFOpPattern
1055//===----------------------------------------------------------------------===//
1056
1057/// Converts floating-point comparison operations to SPIR-V ops.
1058class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> {
1059public:
1060 using Base::Base;
1061
1062 LogicalResult
1063 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1064 ConversionPatternRewriter &rewriter) const override {
1065 switch (op.getPredicate()) {
1066#define DISPATCH(cmpPredicate, spirvOp) \
1067 case cmpPredicate: \
1068 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
1069 adaptor.getRhs()); \
1070 return success();
1071
1072 // Ordered.
1073 DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
1074 DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
1075 DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
1076 DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
1077 DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
1078 DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
1079 // Unordered.
1080 DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
1081 DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
1082 DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
1083 DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
1084 DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
1085 DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
1086
1087#undef DISPATCH
1088
1089 default:
1090 break;
1091 }
1092 return failure();
1093 }
1094};
1095
1096/// Converts floating point NaN check to SPIR-V ops. This pattern requires
1097/// Kernel capability.
1098class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> {
1099public:
1100 using Base::Base;
1101
1102 LogicalResult
1103 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1104 ConversionPatternRewriter &rewriter) const override {
1105 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1106 rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(),
1107 adaptor.getRhs());
1108 return success();
1109 }
1110
1111 if (op.getPredicate() == arith::CmpFPredicate::UNO) {
1112 rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(),
1113 adaptor.getRhs());
1114 return success();
1115 }
1116
1117 return failure();
1118 }
1119};
1120
1121/// Converts floating point NaN check to SPIR-V ops. This pattern does not
1122/// require additional capability.
1123class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
1124public:
1125 using Base::Base;
1126
1127 LogicalResult
1128 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1129 ConversionPatternRewriter &rewriter) const override {
1130 if (op.getPredicate() != arith::CmpFPredicate::ORD &&
1131 op.getPredicate() != arith::CmpFPredicate::UNO)
1132 return failure();
1133
1134 Location loc = op.getLoc();
1135
1136 Value replace;
1137 if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1138 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1139 // Ordered comparsion checks if neither operand is NaN.
1140 replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter);
1141 } else {
1142 // Unordered comparsion checks if either operand is NaN.
1143 replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter);
1144 }
1145 } else {
1146 Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1147 Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1148
1149 replace = spirv::LogicalOrOp::create(rewriter, loc, lhsIsNan, rhsIsNan);
1150 if (op.getPredicate() == arith::CmpFPredicate::ORD)
1151 replace = spirv::LogicalNotOp::create(rewriter, loc, replace);
1152 }
1153
1154 rewriter.replaceOp(op, replace);
1155 return success();
1156 }
1157};
1158
1159//===----------------------------------------------------------------------===//
1160// AddUIExtendedOp
1161//===----------------------------------------------------------------------===//
1162
1163/// Converts arith.addui_extended to spirv.IAddCarry.
1164class AddUIExtendedOpPattern final
1165 : public OpConversionPattern<arith::AddUIExtendedOp> {
1166public:
1167 using Base::Base;
1168 LogicalResult
1169 matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
1170 ConversionPatternRewriter &rewriter) const override {
1171 Type dstElemTy = adaptor.getLhs().getType();
1172 Location loc = op->getLoc();
1173 Value result = spirv::IAddCarryOp::create(rewriter, loc, adaptor.getLhs(),
1174 adaptor.getRhs());
1175
1176 Value sumResult = spirv::CompositeExtractOp::create(rewriter, loc, result,
1177 llvm::ArrayRef(0));
1178 Value carryValue = spirv::CompositeExtractOp::create(rewriter, loc, result,
1179 llvm::ArrayRef(1));
1180
1181 // Convert the carry value to boolean.
1182 Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter);
1183 Value carryResult = spirv::IEqualOp::create(rewriter, loc, carryValue, one);
1184
1185 rewriter.replaceOp(op, {sumResult, carryResult});
1186 return success();
1187 }
1188};
1189
1190//===----------------------------------------------------------------------===//
1191// MulIExtendedOp
1192//===----------------------------------------------------------------------===//
1193
1194/// Converts arith.mul*i_extended to spirv.*MulExtended.
1195template <typename ArithMulOp, typename SPIRVMulOp>
1196class MulIExtendedOpPattern final : public OpConversionPattern<ArithMulOp> {
1197public:
1198 using OpConversionPattern<ArithMulOp>::OpConversionPattern;
1199 LogicalResult
1200 matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
1201 ConversionPatternRewriter &rewriter) const override {
1202 Location loc = op->getLoc();
1203 Value result =
1204 SPIRVMulOp::create(rewriter, loc, adaptor.getLhs(), adaptor.getRhs());
1205
1206 Value low = spirv::CompositeExtractOp::create(rewriter, loc, result,
1207 llvm::ArrayRef(0));
1208 Value high = spirv::CompositeExtractOp::create(rewriter, loc, result,
1209 llvm::ArrayRef(1));
1210
1211 rewriter.replaceOp(op, {low, high});
1212 return success();
1213 }
1214};
1215
1216//===----------------------------------------------------------------------===//
1217// SelectOp
1218//===----------------------------------------------------------------------===//
1219
1220/// Converts arith.select to spirv.Select.
1221class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> {
1222public:
1223 using Base::Base;
1224 LogicalResult
1225 matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
1226 ConversionPatternRewriter &rewriter) const override {
1227 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, adaptor.getCondition(),
1228 adaptor.getTrueValue(),
1229 adaptor.getFalseValue());
1230 return success();
1231 }
1232};
1233
1234//===----------------------------------------------------------------------===//
1235// MinimumFOp, MaximumFOp
1236//===----------------------------------------------------------------------===//
1237
1238/// Converts arith.maximumf/minimumf to spirv.GL.FMax/FMin or
1239/// spirv.CL.fmax/fmin.
1240template <typename Op, typename SPIRVOp>
1241class MinimumMaximumFOpPattern final : public OpConversionPattern<Op> {
1242public:
1243 using OpConversionPattern<Op>::OpConversionPattern;
1244 LogicalResult
1245 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
1246 ConversionPatternRewriter &rewriter) const override {
1247 auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
1248 Type dstType = converter->convertType(op.getType());
1249 if (!dstType)
1250 return getTypeConversionFailure(rewriter, op);
1251
1252 // arith.maximumf/minimumf:
1253 // "if one of the arguments is NaN, then the result is also NaN."
1254 // spirv.GL.FMax/FMin
1255 // "which operand is the result is undefined if one of the operands
1256 // is a NaN."
1257 // spirv.CL.fmax/fmin:
1258 // "If one argument is a NaN, Fmin returns the other argument."
1259
1260 Location loc = op.getLoc();
1261 Value spirvOp =
1262 SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands());
1263
1264 if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1265 rewriter.replaceOp(op, spirvOp);
1266 return success();
1267 }
1268
1269 Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1270 Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1271
1272 Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan,
1273 adaptor.getLhs(), spirvOp);
1274 Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan,
1275 adaptor.getRhs(), select1);
1276
1277 rewriter.replaceOp(op, select2);
1278 return success();
1279 }
1280};
1281
1282//===----------------------------------------------------------------------===//
1283// MinNumFOp, MaxNumFOp
1284//===----------------------------------------------------------------------===//
1285
1286/// Converts arith.maxnumf/minnumf to spirv.GL.FMax/FMin or
1287/// spirv.CL.fmax/fmin.
1288template <typename Op, typename SPIRVOp>
1289class MinNumMaxNumFOpPattern final : public OpConversionPattern<Op> {
1290 template <typename TargetOp>
1291 constexpr bool shouldInsertNanGuards() const {
1292 return llvm::is_one_of<TargetOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value;
1293 }
1294
1295public:
1296 using OpConversionPattern<Op>::OpConversionPattern;
1297 LogicalResult
1298 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
1299 ConversionPatternRewriter &rewriter) const override {
1300 auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
1301 Type dstType = converter->convertType(op.getType());
1302 if (!dstType)
1303 return getTypeConversionFailure(rewriter, op);
1304
1305 // arith.maxnumf/minnumf:
1306 // "If one of the arguments is NaN, then the result is the other
1307 // argument."
1308 // spirv.GL.FMax/FMin
1309 // "which operand is the result is undefined if one of the operands
1310 // is a NaN."
1311 // spirv.CL.fmax/fmin:
1312 // "If one argument is a NaN, Fmin returns the other argument."
1313
1314 Location loc = op.getLoc();
1315 Value spirvOp =
1316 SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands());
1317
1318 if (!shouldInsertNanGuards<SPIRVOp>() ||
1319 bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1320 rewriter.replaceOp(op, spirvOp);
1321 return success();
1322 }
1323
1324 Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1325 Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1326
1327 Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan,
1328 adaptor.getRhs(), spirvOp);
1329 Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan,
1330 adaptor.getLhs(), select1);
1331
1332 rewriter.replaceOp(op, select2);
1333 return success();
1334 }
1335};
1336
1337} // namespace
1338
1339//===----------------------------------------------------------------------===//
1340// Pattern Population
1341//===----------------------------------------------------------------------===//
1342
1344 const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
1345 // clang-format off
1346 patterns.add<
1347 ConstantCompositeOpPattern,
1348 ConstantScalarOpPattern,
1349 ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>,
1350 ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>,
1351 ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>,
1355 RemSIOpGLPattern, RemSIOpCLPattern,
1356 BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
1357 BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
1358 XOrIOpLogicalPattern, XOrIOpBooleanPattern,
1359 ElementwiseArithOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
1368 ExtUIPattern, ExtUII1Pattern,
1369 ExtSIPattern, ExtSII1Pattern,
1370 TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
1371 TruncIPattern, TruncII1Pattern,
1372 TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
1373 TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern,
1374 TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
1375 TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
1376 TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
1377 TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
1378 IndexCastIndexI1Pattern, IndexCastI1IndexPattern,
1379 TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
1380 TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
1381 CmpIOpBooleanPattern, CmpIOpPattern,
1382 CmpFOpNanNonePattern, CmpFOpPattern,
1383 AddUIExtendedOpPattern,
1384 MulIExtendedOpPattern<arith::MulSIExtendedOp, spirv::SMulExtendedOp>,
1385 MulIExtendedOpPattern<arith::MulUIExtendedOp, spirv::UMulExtendedOp>,
1386 SelectOpPattern,
1387
1388 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::GLFMaxOp>,
1389 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::GLFMinOp>,
1390 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::GLFMaxOp>,
1391 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::GLFMinOp>,
1396
1397 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::CLFMaxOp>,
1398 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::CLFMinOp>,
1399 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::CLFMaxOp>,
1400 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::CLFMinOp>,
1405 >(typeConverter, patterns.getContext());
1406 // clang-format on
1407
1408 // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel
1409 // capability is available.
1410 patterns.add<CmpFOpNanKernelPattern>(typeConverter, patterns.getContext(),
1411 /*benefit=*/2);
1412}
1413
1414//===----------------------------------------------------------------------===//
1415// Pass Definition
1416//===----------------------------------------------------------------------===//
1417
1418namespace {
/// Pass that converts 'arith' dialect ops to the SPIR-V dialect.
/// The SPIR-V target and type converter are configured from the pass's
/// `targetAttr` and its emulation options.
1419struct ConvertArithToSPIRVPass
1420 : public impl::ConvertArithToSPIRVPassBase<ConvertArithToSPIRVPass> {
1421 using Base::Base;
1422
1423 void runOnOperation() override {
1424 Operation *op = getOperation();
// Build the SPIR-V conversion target for `targetAttr`.
1426 std::unique_ptr<SPIRVConversionTarget> target =
1427 SPIRVConversionTarget::get(targetAttr);
1428
// Forward the pass's narrow-integer and unsupported-float emulation flags to
// the type converter options.
1430 options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
1431 options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
1432 SPIRVTypeConverter typeConverter(targetAttr, options);
1433
1434 // Use UnrealizedConversionCast as the bridge so that we don't need to pull
1435 // in patterns for other dialects.
1436 target->addLegalOp<UnrealizedConversionCastOp>();
1437
1438 // Fail hard when there are any remaining 'arith' ops.
1439 target->addIllegalDialect<arith::ArithDialect>();
1440
1443
// A partial conversion leaves ops not marked illegal untouched; any failure
// (e.g. a leftover 'arith' op) fails the pass.
1444 if (failed(applyPartialConversion(op, *target, std::move(patterns))))
1445 signalPassFailure();
1446 }
1447};
1448} // namespace
return success()
static bool hasSameBitwidth(Type a, Type b)
Returns true if scalar/vector type a and b have the same number of bitwidth.
static Value getScalarOrVectorConstInt(Type type, uint64_t value, OpBuilder &builder, Location loc)
Creates a scalar/vector integer constant.
static LogicalResult getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op, Type srcType)
Returns a source type conversion failure for srcType and operation op.
static IntegerAttr getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType, ConversionPatternRewriter &rewriter)
static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder)
Converts the given srcAttr into a boolean attribute if it holds an integral value.
static bool isBoolScalarOrVector(Type type)
Returns true if the given type is a boolean scalar or vector type.
#define DISPATCH(cmpPredicate, spirvOp)
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
static llvm::ManagedStatic< PassManagerOptions > options
This class represents a processed binary blob of data.
Definition AsmState.h:91
ArrayRef< char > getData() const
Return the raw underlying data of this blob.
Definition AsmState.h:145
Attributes are known-constant values of operations.
Definition Attributes.h:25
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
BoolAttr getBoolAttr(bool value)
Definition Builders.cpp:104
FloatAttr getF32FloatAttr(float value)
Definition Builders.cpp:250
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
static bool isValidRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Returns true if the given buffer is a valid raw buffer for the given type.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
DenseElementsAttr reshape(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but has been reshaped...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:209
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
result_type_range getResultTypes()
Definition Operation.h:428
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:56
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition Types.cpp:122
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
Type front()
Return first type in the range.
Definition TypeRange.h:152
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
An attribute that specifies the target version, allowed extensions and capabilities,...
NestedPattern Op(FilterFunctionType filter=defaultFilterFunction)
void populateArithToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
std::string getDecorationString(Decoration decoration)
Converts a SPIR-V Decoration enum value to its snake_case string representation for use in MLIR attri...
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.
Definition Pattern.h:24