MLIR 22.0.0git
ArithToSPIRV.cpp
Go to the documentation of this file.
1//===- ArithToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
21#include "llvm/ADT/APInt.h"
22#include "llvm/ADT/ArrayRef.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/Support/Debug.h"
25#include "llvm/Support/MathExtras.h"
26#include <cassert>
27#include <memory>
28
29namespace mlir {
30#define GEN_PASS_DEF_CONVERTARITHTOSPIRVPASS
31#include "mlir/Conversion/Passes.h.inc"
32} // namespace mlir
33
34#define DEBUG_TYPE "arith-to-spirv-pattern"
35
36using namespace mlir;
37
38//===----------------------------------------------------------------------===//
39// Conversion Helpers
40//===----------------------------------------------------------------------===//
41
42/// Converts the given `srcAttr` into a boolean attribute if it holds an
43/// integral value. Returns null attribute if conversion fails.
44static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) {
45 if (auto boolAttr = dyn_cast<BoolAttr>(srcAttr))
46 return boolAttr;
47 if (auto intAttr = dyn_cast<IntegerAttr>(srcAttr))
48 return builder.getBoolAttr(intAttr.getValue().getBoolValue());
49 return {};
50}
51
52/// Converts the given `srcAttr` to a new attribute of the given `dstType`.
53/// Returns null attribute if conversion fails.
54static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType,
55 Builder builder) {
56 // If the source number uses less active bits than the target bitwidth, then
57 // it should be safe to convert.
58 if (srcAttr.getValue().isIntN(dstType.getWidth()))
59 return builder.getIntegerAttr(dstType, srcAttr.getInt());
60
61 // XXX: Try again by interpreting the source number as a signed value.
62 // Although integers in the standard dialect are signless, they can represent
63 // a signed number. It's the operation decides how to interpret. This is
64 // dangerous, but it seems there is no good way of handling this if we still
65 // want to change the bitwidth. Emit a message at least.
66 if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
67 auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt());
68 LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '"
69 << dstAttr << "' for type '" << dstType << "'\n");
70 return dstAttr;
71 }
72
73 LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr
74 << "' illegal: cannot fit into target type '"
75 << dstType << "'\n");
76 return {};
77}
78
79/// Converts the given `srcAttr` to a new attribute of the given `dstType`.
80/// Returns null attribute if `dstType` is not 32-bit or conversion fails.
81static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
82 Builder builder) {
83 // Only support converting to float for now.
84 if (!dstType.isF32())
85 return FloatAttr();
86
87 // Try to convert the source floating-point number to single precision.
88 APFloat dstVal = srcAttr.getValue();
89 bool losesInfo = false;
90 APFloat::opStatus status =
91 dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
92 if (status != APFloat::opOK || losesInfo) {
93 LLVM_DEBUG(llvm::dbgs()
94 << srcAttr << " illegal: cannot fit into converted type '"
95 << dstType << "'\n");
96 return FloatAttr();
97 }
98
99 return builder.getF32FloatAttr(dstVal.convertToFloat());
100}
101
102// Get in IntegerAttr from FloatAttr while preserving the bits.
103// Useful for converting float constants to integer constants while preserving
104// the bits.
105static IntegerAttr
106getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType,
107 ConversionPatternRewriter &rewriter) {
108 APFloat floatVal = floatAttr.getValue();
109 APInt intVal = floatVal.bitcastToAPInt();
110 return rewriter.getIntegerAttr(dstType, intVal);
111}
112
113/// Returns true if the given `type` is a boolean scalar or vector type.
114static bool isBoolScalarOrVector(Type type) {
115 assert(type && "Not a valid type");
116 if (type.isInteger(1))
117 return true;
118
119 if (auto vecType = dyn_cast<VectorType>(type))
120 return vecType.getElementType().isInteger(1);
121
122 return false;
123}
124
125/// Creates a scalar/vector integer constant.
126static Value getScalarOrVectorConstInt(Type type, uint64_t value,
127 OpBuilder &builder, Location loc) {
128 if (auto vectorType = dyn_cast<VectorType>(type)) {
129 Attribute element = IntegerAttr::get(vectorType.getElementType(), value);
130 auto attr = SplatElementsAttr::get(vectorType, element);
131 return spirv::ConstantOp::create(builder, loc, vectorType, attr);
132 }
133
134 if (auto intType = dyn_cast<IntegerType>(type))
135 return spirv::ConstantOp::create(builder, loc, type,
136 builder.getIntegerAttr(type, value));
137
138 return nullptr;
139}
140
141/// Returns true if scalar/vector type `a` and `b` have the same number of
142/// bitwidth.
143static bool hasSameBitwidth(Type a, Type b) {
144 auto getNumBitwidth = [](Type type) {
145 unsigned bw = 0;
146 if (type.isIntOrFloat())
147 bw = type.getIntOrFloatBitWidth();
148 else if (auto vecType = dyn_cast<VectorType>(type))
149 bw = vecType.getElementTypeBitWidth() * vecType.getNumElements();
150 return bw;
151 };
152 unsigned aBW = getNumBitwidth(a);
153 unsigned bBW = getNumBitwidth(b);
154 return aBW != 0 && bBW != 0 && aBW == bBW;
155}
156
157/// Returns a source type conversion failure for `srcType` and operation `op`.
158static LogicalResult
159getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op,
160 Type srcType) {
161 return rewriter.notifyMatchFailure(
162 op->getLoc(),
163 llvm::formatv("failed to convert source type '{0}'", srcType));
164}
165
166/// Returns a source type conversion failure for the result type of `op`.
167static LogicalResult
168getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op) {
169 assert(op->getNumResults() == 1);
170 return getTypeConversionFailure(rewriter, op, op->getResultTypes().front());
171}
172
173// TODO: Move to some common place?
174static std::string getDecorationString(spirv::Decoration decor) {
175 return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decor));
176}
177
178namespace {
179
180/// Converts elementwise unary, binary and ternary arith operations to SPIR-V
181/// operations. Op can potentially support overflow flags.
182template <typename Op, typename SPIRVOp>
183struct ElementwiseArithOpPattern final : OpConversionPattern<Op> {
184 using OpConversionPattern<Op>::OpConversionPattern;
185
186 LogicalResult
187 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
188 ConversionPatternRewriter &rewriter) const override {
189 assert(adaptor.getOperands().size() <= 3);
190 auto converter = this->template getTypeConverter<SPIRVTypeConverter>();
191 Type dstType = converter->convertType(op.getType());
192 if (!dstType) {
193 return rewriter.notifyMatchFailure(
194 op->getLoc(),
195 llvm::formatv("failed to convert type {0} for SPIR-V", op.getType()));
196 }
197
198 if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
199 !getElementTypeOrSelf(op.getType()).isIndex() &&
200 dstType != op.getType()) {
201 return op.emitError("bitwidth emulation is not implemented yet on "
202 "unsigned op pattern version");
203 }
204
205 auto overflowFlags = arith::IntegerOverflowFlags::none;
206 if (auto overflowIface =
207 dyn_cast<arith::ArithIntegerOverflowFlagsInterface>(*op)) {
208 if (converter->getTargetEnv().allows(
209 spirv::Extension::SPV_KHR_no_integer_wrap_decoration))
210 overflowFlags = overflowIface.getOverflowAttr().getValue();
211 }
212
213 auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
214 op, dstType, adaptor.getOperands());
215
216 if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nsw))
217 newOp->setAttr(getDecorationString(spirv::Decoration::NoSignedWrap),
218 rewriter.getUnitAttr());
219
220 if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nuw))
221 newOp->setAttr(getDecorationString(spirv::Decoration::NoUnsignedWrap),
222 rewriter.getUnitAttr());
223
224 return success();
225 }
226};
227
228//===----------------------------------------------------------------------===//
229// ConstantOp
230//===----------------------------------------------------------------------===//
231
232/// Converts composite arith.constant operation to spirv.Constant.
233struct ConstantCompositeOpPattern final
234 : public OpConversionPattern<arith::ConstantOp> {
235 using Base::Base;
236
237 LogicalResult
238 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
239 ConversionPatternRewriter &rewriter) const override {
240 auto srcType = dyn_cast<ShapedType>(constOp.getType());
241 if (!srcType || srcType.getNumElements() == 1)
242 return failure();
243
244 // arith.constant should only have vector or tensor types. This is a MLIR
245 // wide problem at the moment.
246 if (!isa<VectorType, RankedTensorType>(srcType))
247 return rewriter.notifyMatchFailure(constOp, "unsupported ShapedType");
248
249 Type dstType = getTypeConverter()->convertType(srcType);
250 if (!dstType)
251 return failure();
252
253 // Import the resource into the IR to make use of the special handling of
254 // element types later on.
255 mlir::DenseElementsAttr dstElementsAttr;
256 if (auto denseElementsAttr =
257 dyn_cast<DenseElementsAttr>(constOp.getValue())) {
258 dstElementsAttr = denseElementsAttr;
259 } else if (auto resourceAttr =
260 dyn_cast<DenseResourceElementsAttr>(constOp.getValue())) {
261
262 AsmResourceBlob *blob = resourceAttr.getRawHandle().getBlob();
263 if (!blob)
264 return constOp->emitError("could not find resource blob");
265
266 ArrayRef<char> ptr = blob->getData();
267
268 // Check that the buffer meets the requirements to get converted to a
269 // DenseElementsAttr
270 bool detectedSplat = false;
271 if (!DenseElementsAttr::isValidRawBuffer(srcType, ptr, detectedSplat))
272 return constOp->emitError("resource is not a valid buffer");
273
274 dstElementsAttr =
275 DenseElementsAttr::getFromRawBuffer(resourceAttr.getType(), ptr);
276 } else {
277 return constOp->emitError("unsupported elements attribute");
278 }
279
280 ShapedType dstAttrType = dstElementsAttr.getType();
281
282 // If the composite type has more than one dimensions, perform
283 // linearization.
284 if (srcType.getRank() > 1) {
285 if (isa<RankedTensorType>(srcType)) {
286 dstAttrType = RankedTensorType::get(srcType.getNumElements(),
287 srcType.getElementType());
288 dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
289 } else {
290 // TODO: add support for large vectors.
291 return failure();
292 }
293 }
294
295 Type srcElemType = srcType.getElementType();
296 Type dstElemType;
297 // Tensor types are converted to SPIR-V array types; vector types are
298 // converted to SPIR-V vector/array types.
299 if (auto arrayType = dyn_cast<spirv::ArrayType>(dstType))
300 dstElemType = arrayType.getElementType();
301 else
302 dstElemType = cast<VectorType>(dstType).getElementType();
303
304 // If the source and destination element types are different, perform
305 // attribute conversion.
306 if (srcElemType != dstElemType) {
308 if (isa<FloatType>(srcElemType)) {
309 for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
310 Attribute dstAttr = nullptr;
311 // Handle 8-bit float conversion to 8-bit integer.
312 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
313 if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
314 srcElemType.getIntOrFloatBitWidth() == 8 &&
315 isa<IntegerType>(dstElemType)) {
316 dstAttr =
317 getIntegerAttrFromFloatAttr(srcAttr, dstElemType, rewriter);
318 } else {
319 dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstElemType),
320 rewriter);
321 }
322 if (!dstAttr)
323 return failure();
324 elements.push_back(dstAttr);
325 }
326 } else if (srcElemType.isInteger(1)) {
327 return failure();
328 } else {
329 for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) {
330 IntegerAttr dstAttr = convertIntegerAttr(
331 srcAttr, cast<IntegerType>(dstElemType), rewriter);
332 if (!dstAttr)
333 return failure();
334 elements.push_back(dstAttr);
335 }
336 }
337
338 // Unfortunately, we cannot use dialect-specific types for element
339 // attributes; element attributes only works with builtin types. So we
340 // need to prepare another converted builtin types for the destination
341 // elements attribute.
342 if (isa<RankedTensorType>(dstAttrType))
343 dstAttrType =
344 RankedTensorType::get(dstAttrType.getShape(), dstElemType);
345 else
346 dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
347
348 dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
349 }
350
351 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
352 dstElementsAttr);
353 return success();
354 }
355};
356
357/// Converts scalar arith.constant operation to spirv.Constant.
358struct ConstantScalarOpPattern final
359 : public OpConversionPattern<arith::ConstantOp> {
360 using Base::Base;
361
362 LogicalResult
363 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
364 ConversionPatternRewriter &rewriter) const override {
365 Type srcType = constOp.getType();
366 if (auto shapedType = dyn_cast<ShapedType>(srcType)) {
367 if (shapedType.getNumElements() != 1)
368 return failure();
369 srcType = shapedType.getElementType();
370 }
371 if (!srcType.isIntOrIndexOrFloat())
372 return failure();
373
374 Attribute cstAttr = constOp.getValue();
375 if (auto elementsAttr = dyn_cast<DenseElementsAttr>(cstAttr))
376 cstAttr = elementsAttr.getSplatValue<Attribute>();
377
378 Type dstType = getTypeConverter()->convertType(srcType);
379 if (!dstType)
380 return failure();
381
382 // Floating-point types.
383 if (isa<FloatType>(srcType)) {
384 auto srcAttr = cast<FloatAttr>(cstAttr);
385 Attribute dstAttr = srcAttr;
386
387 // Floating-point types not supported in the target environment are all
388 // converted to float type.
389 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
390 if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
391 srcType.getIntOrFloatBitWidth() == 8 && isa<IntegerType>(dstType) &&
392 dstType.getIntOrFloatBitWidth() == 8) {
393 // If the source is an 8-bit float, convert it to a 8-bit integer.
394 dstAttr = getIntegerAttrFromFloatAttr(srcAttr, dstType, rewriter);
395 if (!dstAttr)
396 return failure();
397 } else if (srcType != dstType) {
398 dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter);
399 if (!dstAttr)
400 return failure();
401 }
402
403 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
404 return success();
405 }
406
407 // Bool type.
408 if (srcType.isInteger(1)) {
409 // arith.constant can use 0/1 instead of true/false for i1 values. We need
410 // to handle that here.
411 auto dstAttr = convertBoolAttr(cstAttr, rewriter);
412 if (!dstAttr)
413 return failure();
414 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
415 return success();
416 }
417
418 // IndexType or IntegerType. Index values are converted to 32-bit integer
419 // values when converting to SPIR-V.
420 auto srcAttr = cast<IntegerAttr>(cstAttr);
421 IntegerAttr dstAttr =
422 convertIntegerAttr(srcAttr, cast<IntegerType>(dstType), rewriter);
423 if (!dstAttr)
424 return failure();
425 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
426 return success();
427 }
428};
429
430//===----------------------------------------------------------------------===//
431// RemSIOp
432//===----------------------------------------------------------------------===//
433
434/// Returns signed remainder for `lhs` and `rhs` and lets the result follow
435/// the sign of `signOperand`.
436///
437/// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment
438/// spec, "for the OpSRem and OpSMod instructions, if either operand is negative
439/// the result is undefined." So we cannot directly use spirv.SRem/spirv.SMod
440/// if either operand can be negative. Emulate it via spirv.UMod.
441template <typename SignedAbsOp>
442static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
443 Value signOperand, OpBuilder &builder) {
444 assert(lhs.getType() == rhs.getType());
445 assert(lhs == signOperand || rhs == signOperand);
446
447 Type type = lhs.getType();
448
449 // Calculate the remainder with spirv.UMod.
450 Value lhsAbs = SignedAbsOp::create(builder, loc, type, lhs);
451 Value rhsAbs = SignedAbsOp::create(builder, loc, type, rhs);
452 Value abs = spirv::UModOp::create(builder, loc, lhsAbs, rhsAbs);
453
454 // Fix the sign.
455 Value isPositive;
456 if (lhs == signOperand)
457 isPositive = spirv::IEqualOp::create(builder, loc, lhs, lhsAbs);
458 else
459 isPositive = spirv::IEqualOp::create(builder, loc, rhs, rhsAbs);
460 Value absNegate = spirv::SNegateOp::create(builder, loc, type, abs);
461 return spirv::SelectOp::create(builder, loc, type, isPositive, abs,
462 absNegate);
463}
464
465/// Converts arith.remsi to GLSL SPIR-V ops.
466///
467/// This cannot be merged into the template unary/binary pattern due to Vulkan
468/// restrictions over spirv.SRem and spirv.SMod.
469struct RemSIOpGLPattern final : public OpConversionPattern<arith::RemSIOp> {
470 using Base::Base;
471
472 LogicalResult
473 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
474 ConversionPatternRewriter &rewriter) const override {
475 Value result = emulateSignedRemainder<spirv::CLSAbsOp>(
476 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
477 adaptor.getOperands()[0], rewriter);
478 rewriter.replaceOp(op, result);
479
480 return success();
481 }
482};
483
484/// Converts arith.remsi to OpenCL SPIR-V ops.
485struct RemSIOpCLPattern final : public OpConversionPattern<arith::RemSIOp> {
486 using Base::Base;
487
488 LogicalResult
489 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
490 ConversionPatternRewriter &rewriter) const override {
491 Value result = emulateSignedRemainder<spirv::GLSAbsOp>(
492 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
493 adaptor.getOperands()[0], rewriter);
494 rewriter.replaceOp(op, result);
495
496 return success();
497 }
498};
499
500//===----------------------------------------------------------------------===//
501// BitwiseOp
502//===----------------------------------------------------------------------===//
503
504/// Converts bitwise operations to SPIR-V operations. This is a special pattern
505/// other than the BinaryOpPatternPattern because if the operands are boolean
506/// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For
507/// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`.
508template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
509struct BitwiseOpPattern final : public OpConversionPattern<Op> {
510 using OpConversionPattern<Op>::OpConversionPattern;
511
512 LogicalResult
513 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
514 ConversionPatternRewriter &rewriter) const override {
515 assert(adaptor.getOperands().size() == 2);
516 Type dstType = this->getTypeConverter()->convertType(op.getType());
517 if (!dstType)
518 return getTypeConversionFailure(rewriter, op);
519
520 if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) {
521 rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(
522 op, dstType, adaptor.getOperands());
523 } else {
524 rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(
525 op, dstType, adaptor.getOperands());
526 }
527 return success();
528 }
529};
530
531//===----------------------------------------------------------------------===//
532// XOrIOp
533//===----------------------------------------------------------------------===//
534
535/// Converts arith.xori to SPIR-V operations.
536struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> {
537 using Base::Base;
538
539 LogicalResult
540 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
541 ConversionPatternRewriter &rewriter) const override {
542 assert(adaptor.getOperands().size() == 2);
543
544 if (isBoolScalarOrVector(adaptor.getOperands().front().getType()))
545 return failure();
546
547 Type dstType = getTypeConverter()->convertType(op.getType());
548 if (!dstType)
549 return getTypeConversionFailure(rewriter, op);
550
551 rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
552 adaptor.getOperands());
553
554 return success();
555 }
556};
557
558/// Converts arith.xori to SPIR-V operations if the type of source is i1 or
559/// vector of i1.
560struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
561 using Base::Base;
562
563 LogicalResult
564 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
565 ConversionPatternRewriter &rewriter) const override {
566 assert(adaptor.getOperands().size() == 2);
567
568 if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
569 return failure();
570
571 Type dstType = getTypeConverter()->convertType(op.getType());
572 if (!dstType)
573 return getTypeConversionFailure(rewriter, op);
574
575 rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
576 op, dstType, adaptor.getOperands());
577 return success();
578 }
579};
580
581//===----------------------------------------------------------------------===//
582// UIToFPOp
583//===----------------------------------------------------------------------===//
584
585/// Converts arith.uitofp to spirv.Select if the type of source is i1 or vector
586/// of i1.
587struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
588 using Base::Base;
589
590 LogicalResult
591 matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
592 ConversionPatternRewriter &rewriter) const override {
593 Type srcType = adaptor.getOperands().front().getType();
594 if (!isBoolScalarOrVector(srcType))
595 return failure();
596
597 Type dstType = getTypeConverter()->convertType(op.getType());
598 if (!dstType)
599 return getTypeConversionFailure(rewriter, op);
600
601 Location loc = op.getLoc();
602 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
603 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
604 rewriter.replaceOpWithNewOp<spirv::SelectOp>(
605 op, dstType, adaptor.getOperands().front(), one, zero);
606 return success();
607 }
608};
609
610//===----------------------------------------------------------------------===//
611// IndexCastOp
612//===----------------------------------------------------------------------===//
613
614/// Converts arith.index_cast to spirv.INotEqual if the target type is i1.
615struct IndexCastIndexI1Pattern final
616 : public OpConversionPattern<arith::IndexCastOp> {
617 using Base::Base;
618
619 LogicalResult
620 matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
621 ConversionPatternRewriter &rewriter) const override {
622 if (!isBoolScalarOrVector(op.getType()))
623 return failure();
624
625 Type dstType = getTypeConverter()->convertType(op.getType());
626 if (!dstType)
627 return getTypeConversionFailure(rewriter, op);
628
629 Location loc = op.getLoc();
630 Value zeroIdx =
631 spirv::ConstantOp::getZero(adaptor.getIn().getType(), loc, rewriter);
632 rewriter.replaceOpWithNewOp<spirv::INotEqualOp>(op, dstType, zeroIdx,
633 adaptor.getIn());
634 return success();
635 }
636};
637
638/// Converts arith.index_cast to spirv.Select if the source type is i1.
639struct IndexCastI1IndexPattern final
640 : public OpConversionPattern<arith::IndexCastOp> {
641 using Base::Base;
642
643 LogicalResult
644 matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
645 ConversionPatternRewriter &rewriter) const override {
646 if (!isBoolScalarOrVector(adaptor.getIn().getType()))
647 return failure();
648
649 Type dstType = getTypeConverter()->convertType(op.getType());
650 if (!dstType)
651 return getTypeConversionFailure(rewriter, op);
652
653 Location loc = op.getLoc();
654 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
655 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
656 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, adaptor.getIn(),
657 one, zero);
658 return success();
659 }
660};
661
662//===----------------------------------------------------------------------===//
663// ExtSIOp
664//===----------------------------------------------------------------------===//
665
666/// Converts arith.extsi to spirv.Select if the type of source is i1 or vector
667/// of i1.
668struct ExtSII1Pattern final : public OpConversionPattern<arith::ExtSIOp> {
669 using Base::Base;
670
671 LogicalResult
672 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
673 ConversionPatternRewriter &rewriter) const override {
674 Value operand = adaptor.getIn();
675 if (!isBoolScalarOrVector(operand.getType()))
676 return failure();
677
678 Location loc = op.getLoc();
679 Type dstType = getTypeConverter()->convertType(op.getType());
680 if (!dstType)
681 return getTypeConversionFailure(rewriter, op);
682
683 Value allOnes;
684 if (auto intTy = dyn_cast<IntegerType>(dstType)) {
685 unsigned componentBitwidth = intTy.getWidth();
686 allOnes = spirv::ConstantOp::create(
687 rewriter, loc, intTy,
688 rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth)));
689 } else if (auto vectorTy = dyn_cast<VectorType>(dstType)) {
690 unsigned componentBitwidth = vectorTy.getElementTypeBitWidth();
691 allOnes = spirv::ConstantOp::create(
692 rewriter, loc, vectorTy,
693 SplatElementsAttr::get(vectorTy,
694 APInt::getAllOnes(componentBitwidth)));
695 } else {
696 return rewriter.notifyMatchFailure(
697 loc, llvm::formatv("unhandled type: {0}", dstType));
698 }
699
700 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
701 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, operand, allOnes,
702 zero);
703 return success();
704 }
705};
706
707/// Converts arith.extsi to spirv.Select if the type of source is neither i1 nor
708/// vector of i1.
709struct ExtSIPattern final : public OpConversionPattern<arith::ExtSIOp> {
710 using Base::Base;
711
712 LogicalResult
713 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
714 ConversionPatternRewriter &rewriter) const override {
715 Type srcType = adaptor.getIn().getType();
716 if (isBoolScalarOrVector(srcType))
717 return failure();
718
719 Type dstType = getTypeConverter()->convertType(op.getType());
720 if (!dstType)
721 return getTypeConversionFailure(rewriter, op);
722
723 if (dstType == srcType) {
724 // We can have the same source and destination type due to type emulation.
725 // Perform bit shifting to make sure we have the proper leading set bits.
726
727 unsigned srcBW =
728 getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth();
729 unsigned dstBW =
731 assert(srcBW < dstBW);
732 Value shiftSize = getScalarOrVectorConstInt(dstType, dstBW - srcBW,
733 rewriter, op.getLoc());
734
735 // First shift left to sequeeze out all leading bits beyond the original
736 // bitwidth. Here we need to use the original source and result type's
737 // bitwidth.
738 auto shiftLOp = spirv::ShiftLeftLogicalOp::create(
739 rewriter, op.getLoc(), dstType, adaptor.getIn(), shiftSize);
741 // Then we perform arithmetic right shift to make sure we have the right
742 // sign bits for negative values.
743 rewriter.replaceOpWithNewOp<spirv::ShiftRightArithmeticOp>(
744 op, dstType, shiftLOp, shiftSize);
745 } else {
746 rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
747 adaptor.getOperands());
749
750 return success();
752};
753
754//===----------------------------------------------------------------------===//
755// ExtUIOp
756//===----------------------------------------------------------------------===//
758/// Converts arith.extui to spirv.Select if the type of source is i1 or vector
759/// of i1.
760struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> {
761 using Base::Base;
763 LogicalResult
764 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
765 ConversionPatternRewriter &rewriter) const override {
766 Type srcType = adaptor.getOperands().front().getType();
767 if (!isBoolScalarOrVector(srcType))
768 return failure();
769
770 Type dstType = getTypeConverter()->convertType(op.getType());
771 if (!dstType)
772 return getTypeConversionFailure(rewriter, op);
773
774 Location loc = op.getLoc();
775 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
776 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
777 rewriter.replaceOpWithNewOp<spirv::SelectOp>(
778 op, dstType, adaptor.getOperands().front(), one, zero);
779 return success();
780 }
783/// Converts arith.extui for cases where the type of source is neither i1 nor
784/// vector of i1.
785struct ExtUIPattern final : public OpConversionPattern<arith::ExtUIOp> {
786 using Base::Base;
787
788 LogicalResult
789 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
790 ConversionPatternRewriter &rewriter) const override {
791 Type srcType = adaptor.getIn().getType();
792 if (isBoolScalarOrVector(srcType))
793 return failure();
794
795 Type dstType = getTypeConverter()->convertType(op.getType());
796 if (!dstType)
797 return getTypeConversionFailure(rewriter, op);
798
799 if (dstType == srcType) {
800 // We can have the same source and destination type due to type emulation.
801 // Perform bit masking to make sure we don't pollute downstream consumers
802 // with unwanted bits. Here we need to use the original source type's
803 // bitwidth.
804 unsigned bitwidth =
805 getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth();
807 dstType, llvm::maskTrailingOnes<uint64_t>(bitwidth), rewriter,
808 op.getLoc());
809 rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
810 adaptor.getIn(), mask);
811 } else {
812 rewriter.replaceOpWithNewOp<spirv::UConvertOp>(op, dstType,
813 adaptor.getOperands());
814 }
815 return success();
816 }
817};
818
819//===----------------------------------------------------------------------===//
820// TruncIOp
821//===----------------------------------------------------------------------===//
822
823/// Converts arith.trunci to spirv.Select if the type of result is i1 or vector
824/// of i1.
825struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
826 using Base::Base;
827
828 LogicalResult
829 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
830 ConversionPatternRewriter &rewriter) const override {
831 Type dstType = getTypeConverter()->convertType(op.getType());
832 if (!dstType)
833 return getTypeConversionFailure(rewriter, op);
834
835 if (!isBoolScalarOrVector(dstType))
836 return failure();
837
838 Location loc = op.getLoc();
839 auto srcType = adaptor.getOperands().front().getType();
840 // Check if (x & 1) == 1.
841 Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
842 Value maskedSrc = spirv::BitwiseAndOp::create(
843 rewriter, loc, srcType, adaptor.getOperands()[0], mask);
844 Value isOne = spirv::IEqualOp::create(rewriter, loc, maskedSrc, mask);
845
846 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
847 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
848 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero);
849 return success();
850 }
851};
852
853/// Converts arith.trunci for cases where the type of result is neither i1
854/// nor vector of i1.
855struct TruncIPattern final : public OpConversionPattern<arith::TruncIOp> {
856 using Base::Base;
857
858 LogicalResult
859 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
860 ConversionPatternRewriter &rewriter) const override {
861 Type srcType = adaptor.getIn().getType();
862 Type dstType = getTypeConverter()->convertType(op.getType());
863 if (!dstType)
864 return getTypeConversionFailure(rewriter, op);
865
866 if (isBoolScalarOrVector(dstType))
867 return failure();
868
869 if (dstType == srcType) {
870 // We can have the same source and destination type due to type emulation.
871 // Perform bit masking to make sure we don't pollute downstream consumers
872 // with unwanted bits. Here we need to use the original result type's
873 // bitwidth.
874 unsigned bw = getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth();
875 Value mask = getScalarOrVectorConstInt(
876 dstType, llvm::maskTrailingOnes<uint64_t>(bw), rewriter, op.getLoc());
877 rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
878 adaptor.getIn(), mask);
879 } else {
880 // Given this is truncation, either SConvertOp or UConvertOp works.
881 rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
882 adaptor.getOperands());
883 }
884 return success();
885 }
886};
887
888//===----------------------------------------------------------------------===//
889// TypeCastingOp
890//===----------------------------------------------------------------------===//
891
892static std::optional<spirv::FPRoundingMode>
893convertArithRoundingModeToSPIRV(arith::RoundingMode roundingMode) {
894 switch (roundingMode) {
895 case arith::RoundingMode::downward:
896 return spirv::FPRoundingMode::RTN;
897 case arith::RoundingMode::to_nearest_even:
898 return spirv::FPRoundingMode::RTE;
899 case arith::RoundingMode::toward_zero:
900 return spirv::FPRoundingMode::RTZ;
901 case arith::RoundingMode::upward:
902 return spirv::FPRoundingMode::RTP;
903 case arith::RoundingMode::to_nearest_away:
904 // SPIR-V FPRoundingMode decoration has no ties-away-from-zero mode
905 // (as of SPIR-V 1.6)
906 return std::nullopt;
907 }
908 llvm_unreachable("Unhandled rounding mode");
909}
910
911/// Converts type-casting standard operations to SPIR-V operations.
912template <typename Op, typename SPIRVOp>
913struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
914 using OpConversionPattern<Op>::OpConversionPattern;
915
916 LogicalResult
917 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
918 ConversionPatternRewriter &rewriter) const override {
919 Type srcType = llvm::getSingleElement(adaptor.getOperands()).getType();
920 Type dstType = this->getTypeConverter()->convertType(op.getType());
921 if (!dstType)
922 return getTypeConversionFailure(rewriter, op);
923
924 if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
925 return failure();
926
927 if (dstType == srcType) {
928 // Due to type conversion, we are seeing the same source and target type.
929 // Then we can just erase this operation by forwarding its operand.
930 rewriter.replaceOp(op, adaptor.getOperands().front());
931 } else {
932 // Compute new rounding mode (if any).
933 std::optional<spirv::FPRoundingMode> rm = std::nullopt;
934 if (auto roundingModeOp =
935 dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
936 if (arith::RoundingModeAttr roundingMode =
937 roundingModeOp.getRoundingModeAttr()) {
938 if (!(rm =
939 convertArithRoundingModeToSPIRV(roundingMode.getValue()))) {
940 return rewriter.notifyMatchFailure(
941 op->getLoc(),
942 llvm::formatv("unsupported rounding mode '{0}'", roundingMode));
943 }
944 }
945 }
946 // Create replacement op and attach rounding mode attribute (if any).
947 auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
948 op, dstType, adaptor.getOperands());
949 if (rm) {
950 newOp->setAttr(
951 getDecorationString(spirv::Decoration::FPRoundingMode),
952 spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm));
953 }
954 }
955 return success();
956 }
957};
958
959//===----------------------------------------------------------------------===//
960// CmpIOp
961//===----------------------------------------------------------------------===//
962
963/// Converts integer compare operation on i1 type operands to SPIR-V ops.
964class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
965public:
966 using Base::Base;
967
968 LogicalResult
969 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
970 ConversionPatternRewriter &rewriter) const override {
971 Type srcType = op.getLhs().getType();
972 if (!isBoolScalarOrVector(srcType))
973 return failure();
974 Type dstType = getTypeConverter()->convertType(srcType);
975 if (!dstType)
976 return getTypeConversionFailure(rewriter, op, srcType);
977
978 switch (op.getPredicate()) {
979 case arith::CmpIPredicate::eq: {
980 rewriter.replaceOpWithNewOp<spirv::LogicalEqualOp>(op, adaptor.getLhs(),
981 adaptor.getRhs());
982 return success();
983 }
984 case arith::CmpIPredicate::ne: {
985 rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
986 op, adaptor.getLhs(), adaptor.getRhs());
987 return success();
988 }
989 case arith::CmpIPredicate::uge:
990 case arith::CmpIPredicate::ugt:
991 case arith::CmpIPredicate::ule:
992 case arith::CmpIPredicate::ult: {
993 // There are no direct corresponding instructions in SPIR-V for such
994 // cases. Extend them to 32-bit and do comparision then.
995 Type type = rewriter.getI32Type();
996 if (auto vectorType = dyn_cast<VectorType>(dstType))
997 type = VectorType::get(vectorType.getShape(), type);
998 Value extLhs =
999 arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getLhs());
1000 Value extRhs =
1001 arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getRhs());
1002
1003 rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs,
1004 extRhs);
1005 return success();
1006 }
1007 default:
1008 break;
1009 }
1010 return failure();
1011 }
1012};
1013
1014/// Converts integer compare operation to SPIR-V ops.
1015class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> {
1016public:
1017 using Base::Base;
1018
1019 LogicalResult
1020 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
1021 ConversionPatternRewriter &rewriter) const override {
1022 Type srcType = op.getLhs().getType();
1023 if (isBoolScalarOrVector(srcType))
1024 return failure();
1025 Type dstType = getTypeConverter()->convertType(srcType);
1026 if (!dstType)
1027 return getTypeConversionFailure(rewriter, op, srcType);
1028
1029 switch (op.getPredicate()) {
1030#define DISPATCH(cmpPredicate, spirvOp) \
1031 case cmpPredicate: \
1032 if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \
1033 !getElementTypeOrSelf(srcType).isIndex() && srcType != dstType && \
1034 !hasSameBitwidth(srcType, dstType)) { \
1035 return op.emitError( \
1036 "bitwidth emulation is not implemented yet on unsigned op"); \
1037 } \
1038 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
1039 adaptor.getRhs()); \
1040 return success();
1041
1042 DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
1043 DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
1044 DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
1045 DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
1046 DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
1047 DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
1048 DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
1049 DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
1050 DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
1051 DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
1052
1053#undef DISPATCH
1054 }
1055 return failure();
1056 }
1057};
1058
1059//===----------------------------------------------------------------------===//
1060// CmpFOpPattern
1061//===----------------------------------------------------------------------===//
1062
1063/// Converts floating-point comparison operations to SPIR-V ops.
1064class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> {
1065public:
1066 using Base::Base;
1067
1068 LogicalResult
1069 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1070 ConversionPatternRewriter &rewriter) const override {
1071 switch (op.getPredicate()) {
1072#define DISPATCH(cmpPredicate, spirvOp) \
1073 case cmpPredicate: \
1074 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
1075 adaptor.getRhs()); \
1076 return success();
1077
1078 // Ordered.
1079 DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
1080 DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
1081 DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
1082 DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
1083 DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
1084 DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
1085 // Unordered.
1086 DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
1087 DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
1088 DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
1089 DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
1090 DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
1091 DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
1092
1093#undef DISPATCH
1094
1095 default:
1096 break;
1097 }
1098 return failure();
1099 }
1100};
1101
1102/// Converts floating point NaN check to SPIR-V ops. This pattern requires
1103/// Kernel capability.
1104class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> {
1105public:
1106 using Base::Base;
1107
1108 LogicalResult
1109 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1110 ConversionPatternRewriter &rewriter) const override {
1111 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1112 rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(),
1113 adaptor.getRhs());
1114 return success();
1115 }
1116
1117 if (op.getPredicate() == arith::CmpFPredicate::UNO) {
1118 rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(),
1119 adaptor.getRhs());
1120 return success();
1121 }
1122
1123 return failure();
1124 }
1125};
1126
1127/// Converts floating point NaN check to SPIR-V ops. This pattern does not
1128/// require additional capability.
1129class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
1130public:
1131 using Base::Base;
1132
1133 LogicalResult
1134 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1135 ConversionPatternRewriter &rewriter) const override {
1136 if (op.getPredicate() != arith::CmpFPredicate::ORD &&
1137 op.getPredicate() != arith::CmpFPredicate::UNO)
1138 return failure();
1139
1140 Location loc = op.getLoc();
1141
1142 Value replace;
1143 if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1144 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1145 // Ordered comparsion checks if neither operand is NaN.
1146 replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter);
1147 } else {
1148 // Unordered comparsion checks if either operand is NaN.
1149 replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter);
1150 }
1151 } else {
1152 Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1153 Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1154
1155 replace = spirv::LogicalOrOp::create(rewriter, loc, lhsIsNan, rhsIsNan);
1156 if (op.getPredicate() == arith::CmpFPredicate::ORD)
1157 replace = spirv::LogicalNotOp::create(rewriter, loc, replace);
1158 }
1159
1160 rewriter.replaceOp(op, replace);
1161 return success();
1162 }
1163};
1164
1165//===----------------------------------------------------------------------===//
1166// AddUIExtendedOp
1167//===----------------------------------------------------------------------===//
1168
1169/// Converts arith.addui_extended to spirv.IAddCarry.
1170class AddUIExtendedOpPattern final
1171 : public OpConversionPattern<arith::AddUIExtendedOp> {
1172public:
1173 using Base::Base;
1174 LogicalResult
1175 matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
1176 ConversionPatternRewriter &rewriter) const override {
1177 Type dstElemTy = adaptor.getLhs().getType();
1178 Location loc = op->getLoc();
1179 Value result = spirv::IAddCarryOp::create(rewriter, loc, adaptor.getLhs(),
1180 adaptor.getRhs());
1181
1182 Value sumResult = spirv::CompositeExtractOp::create(rewriter, loc, result,
1183 llvm::ArrayRef(0));
1184 Value carryValue = spirv::CompositeExtractOp::create(rewriter, loc, result,
1185 llvm::ArrayRef(1));
1186
1187 // Convert the carry value to boolean.
1188 Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter);
1189 Value carryResult = spirv::IEqualOp::create(rewriter, loc, carryValue, one);
1190
1191 rewriter.replaceOp(op, {sumResult, carryResult});
1192 return success();
1193 }
1194};
1195
1196//===----------------------------------------------------------------------===//
1197// MulIExtendedOp
1198//===----------------------------------------------------------------------===//
1199
1200/// Converts arith.mul*i_extended to spirv.*MulExtended.
1201template <typename ArithMulOp, typename SPIRVMulOp>
1202class MulIExtendedOpPattern final : public OpConversionPattern<ArithMulOp> {
1203public:
1204 using OpConversionPattern<ArithMulOp>::OpConversionPattern;
1205 LogicalResult
1206 matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
1207 ConversionPatternRewriter &rewriter) const override {
1208 Location loc = op->getLoc();
1209 Value result =
1210 SPIRVMulOp::create(rewriter, loc, adaptor.getLhs(), adaptor.getRhs());
1211
1212 Value low = spirv::CompositeExtractOp::create(rewriter, loc, result,
1213 llvm::ArrayRef(0));
1214 Value high = spirv::CompositeExtractOp::create(rewriter, loc, result,
1215 llvm::ArrayRef(1));
1216
1217 rewriter.replaceOp(op, {low, high});
1218 return success();
1219 }
1220};
1221
1222//===----------------------------------------------------------------------===//
1223// SelectOp
1224//===----------------------------------------------------------------------===//
1225
1226/// Converts arith.select to spirv.Select.
1227class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> {
1228public:
1229 using Base::Base;
1230 LogicalResult
1231 matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
1232 ConversionPatternRewriter &rewriter) const override {
1233 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, adaptor.getCondition(),
1234 adaptor.getTrueValue(),
1235 adaptor.getFalseValue());
1236 return success();
1237 }
1238};
1239
1240//===----------------------------------------------------------------------===//
1241// MinimumFOp, MaximumFOp
1242//===----------------------------------------------------------------------===//
1243
1244/// Converts arith.maximumf/minimumf to spirv.GL.FMax/FMin or
1245/// spirv.CL.fmax/fmin.
1246template <typename Op, typename SPIRVOp>
1247class MinimumMaximumFOpPattern final : public OpConversionPattern<Op> {
1248public:
1249 using OpConversionPattern<Op>::OpConversionPattern;
1250 LogicalResult
1251 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
1252 ConversionPatternRewriter &rewriter) const override {
1253 auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
1254 Type dstType = converter->convertType(op.getType());
1255 if (!dstType)
1256 return getTypeConversionFailure(rewriter, op);
1257
1258 // arith.maximumf/minimumf:
1259 // "if one of the arguments is NaN, then the result is also NaN."
1260 // spirv.GL.FMax/FMin
1261 // "which operand is the result is undefined if one of the operands
1262 // is a NaN."
1263 // spirv.CL.fmax/fmin:
1264 // "If one argument is a NaN, Fmin returns the other argument."
1265
1266 Location loc = op.getLoc();
1267 Value spirvOp =
1268 SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands());
1269
1270 if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1271 rewriter.replaceOp(op, spirvOp);
1272 return success();
1273 }
1274
1275 Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1276 Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1277
1278 Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan,
1279 adaptor.getLhs(), spirvOp);
1280 Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan,
1281 adaptor.getRhs(), select1);
1282
1283 rewriter.replaceOp(op, select2);
1284 return success();
1285 }
1286};
1287
1288//===----------------------------------------------------------------------===//
1289// MinNumFOp, MaxNumFOp
1290//===----------------------------------------------------------------------===//
1291
1292/// Converts arith.maxnumf/minnumf to spirv.GL.FMax/FMin or
1293/// spirv.CL.fmax/fmin.
1294template <typename Op, typename SPIRVOp>
1295class MinNumMaxNumFOpPattern final : public OpConversionPattern<Op> {
1296 template <typename TargetOp>
1297 constexpr bool shouldInsertNanGuards() const {
1298 return llvm::is_one_of<TargetOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value;
1299 }
1300
1301public:
1302 using OpConversionPattern<Op>::OpConversionPattern;
1303 LogicalResult
1304 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
1305 ConversionPatternRewriter &rewriter) const override {
1306 auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
1307 Type dstType = converter->convertType(op.getType());
1308 if (!dstType)
1309 return getTypeConversionFailure(rewriter, op);
1310
1311 // arith.maxnumf/minnumf:
1312 // "If one of the arguments is NaN, then the result is the other
1313 // argument."
1314 // spirv.GL.FMax/FMin
1315 // "which operand is the result is undefined if one of the operands
1316 // is a NaN."
1317 // spirv.CL.fmax/fmin:
1318 // "If one argument is a NaN, Fmin returns the other argument."
1319
1320 Location loc = op.getLoc();
1321 Value spirvOp =
1322 SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands());
1323
1324 if (!shouldInsertNanGuards<SPIRVOp>() ||
1325 bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1326 rewriter.replaceOp(op, spirvOp);
1327 return success();
1328 }
1329
1330 Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1331 Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1332
1333 Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan,
1334 adaptor.getRhs(), spirvOp);
1335 Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan,
1336 adaptor.getLhs(), select1);
1337
1338 rewriter.replaceOp(op, select2);
1339 return success();
1340 }
1341};
1342
1343} // namespace
1344
1345//===----------------------------------------------------------------------===//
1346// Pattern Population
1347//===----------------------------------------------------------------------===//
1348
1350 const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
1351 // clang-format off
1352 patterns.add<
1353 ConstantCompositeOpPattern,
1354 ConstantScalarOpPattern,
1355 ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>,
1356 ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>,
1357 ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>,
1361 RemSIOpGLPattern, RemSIOpCLPattern,
1362 BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
1363 BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
1364 XOrIOpLogicalPattern, XOrIOpBooleanPattern,
1365 ElementwiseArithOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
1374 ExtUIPattern, ExtUII1Pattern,
1375 ExtSIPattern, ExtSII1Pattern,
1376 TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
1377 TruncIPattern, TruncII1Pattern,
1378 TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
1379 TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern,
1380 TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
1381 TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
1382 TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
1383 TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
1384 IndexCastIndexI1Pattern, IndexCastI1IndexPattern,
1385 TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
1386 TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
1387 CmpIOpBooleanPattern, CmpIOpPattern,
1388 CmpFOpNanNonePattern, CmpFOpPattern,
1389 AddUIExtendedOpPattern,
1390 MulIExtendedOpPattern<arith::MulSIExtendedOp, spirv::SMulExtendedOp>,
1391 MulIExtendedOpPattern<arith::MulUIExtendedOp, spirv::UMulExtendedOp>,
1392 SelectOpPattern,
1393
1394 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::GLFMaxOp>,
1395 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::GLFMinOp>,
1396 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::GLFMaxOp>,
1397 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::GLFMinOp>,
1402
1403 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::CLFMaxOp>,
1404 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::CLFMinOp>,
1405 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::CLFMaxOp>,
1406 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::CLFMinOp>,
1411 >(typeConverter, patterns.getContext());
1412 // clang-format on
1413
1414 // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel
1415 // capability is available.
1416 patterns.add<CmpFOpNanKernelPattern>(typeConverter, patterns.getContext(),
1417 /*benefit=*/2);
1418}
1419
1420//===----------------------------------------------------------------------===//
1421// Pass Definition
1422//===----------------------------------------------------------------------===//
1423
1424namespace {
1425struct ConvertArithToSPIRVPass
1426 : public impl::ConvertArithToSPIRVPassBase<ConvertArithToSPIRVPass> {
1427 using Base::Base;
1428
1429 void runOnOperation() override {
1430 Operation *op = getOperation();
1432 std::unique_ptr<SPIRVConversionTarget> target =
1433 SPIRVConversionTarget::get(targetAttr);
1434
1436 options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
1437 options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
1438 SPIRVTypeConverter typeConverter(targetAttr, options);
1439
1440 // Use UnrealizedConversionCast as the bridge so that we don't need to pull
1441 // in patterns for other dialects.
1442 target->addLegalOp<UnrealizedConversionCastOp>();
1443
1444 // Fail hard when there are any remaining 'arith' ops.
1445 target->addIllegalDialect<arith::ArithDialect>();
1446
1449
1450 if (failed(applyPartialConversion(op, *target, std::move(patterns))))
1451 signalPassFailure();
1452 }
1453};
1454} // namespace
return success()
static bool hasSameBitwidth(Type a, Type b)
Returns true if scalar/vector type a and b have the same number of bitwidth.
static Value getScalarOrVectorConstInt(Type type, uint64_t value, OpBuilder &builder, Location loc)
Creates a scalar/vector integer constant.
static LogicalResult getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op, Type srcType)
Returns a source type conversion failure for srcType and operation op.
static IntegerAttr getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType, ConversionPatternRewriter &rewriter)
static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder)
Converts the given srcAttr into a boolean attribute if it holds an integral value.
static bool isBoolScalarOrVector(Type type)
Returns true if the given type is a boolean scalar or vector type.
#define DISPATCH(cmpPredicate, spirvOp)
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static std::string getDecorationString(spirv::Decoration decor)
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
static llvm::ManagedStatic< PassManagerOptions > options
This class represents a processed binary blob of data.
Definition AsmState.h:91
ArrayRef< char > getData() const
Return the raw underlying data of this blob.
Definition AsmState.h:145
Attributes are known-constant values of operations.
Definition Attributes.h:25
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:228
BoolAttr getBoolAttr(bool value)
Definition Builders.cpp:100
FloatAttr getF32FloatAttr(float value)
Definition Builders.cpp:246
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
static bool isValidRawBuffer(ShapedType type, ArrayRef< char > rawBuffer, bool &detectedSplat)
Returns true if the given buffer is a valid raw buffer for the given type.
DenseElementsAttr reshape(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but has been reshaped...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:207
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
result_type_range getResultTypes()
Definition Operation.h:428
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:54
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition Types.cpp:120
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:56
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:122
Type front()
Return first type in the range.
Definition TypeRange.h:152
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
An attribute that specifies the target version, allowed extensions and capabilities,...
NestedPattern Op(FilterFunctionType filter=defaultFilterFunction)
void populateArithToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.
Definition Pattern.h:23