MLIR 23.0.0git
ArithToSPIRV.cpp
Go to the documentation of this file.
1//===- ArithToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
21#include "llvm/ADT/APInt.h"
22#include "llvm/ADT/ArrayRef.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/Support/Debug.h"
25#include "llvm/Support/MathExtras.h"
26#include <cassert>
27#include <memory>
28
29namespace mlir {
30#define GEN_PASS_DEF_CONVERTARITHTOSPIRVPASS
31#include "mlir/Conversion/Passes.h.inc"
32} // namespace mlir
33
34#define DEBUG_TYPE "arith-to-spirv-pattern"
35
36using namespace mlir;
37
38//===----------------------------------------------------------------------===//
39// Conversion Helpers
40//===----------------------------------------------------------------------===//
41
42/// Converts the given `srcAttr` into a boolean attribute if it holds an
43/// integral value. Returns null attribute if conversion fails.
44static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) {
45 if (auto boolAttr = dyn_cast<BoolAttr>(srcAttr))
46 return boolAttr;
47 if (auto intAttr = dyn_cast<IntegerAttr>(srcAttr))
48 return builder.getBoolAttr(intAttr.getValue().getBoolValue());
49 return {};
50}
51
52/// Converts the given `srcAttr` to a new attribute of the given `dstType`.
53/// Returns null attribute if conversion fails.
54static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType,
55 Builder builder) {
56 // If the source number uses less active bits than the target bitwidth, then
57 // it should be safe to convert.
58 if (srcAttr.getValue().isIntN(dstType.getWidth()))
59 return builder.getIntegerAttr(dstType, srcAttr.getInt());
60
61 // XXX: Try again by interpreting the source number as a signed value.
62 // Although integers in the standard dialect are signless, they can represent
63 // a signed number. It's the operation decides how to interpret. This is
64 // dangerous, but it seems there is no good way of handling this if we still
65 // want to change the bitwidth. Emit a message at least.
66 if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
67 auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt());
68 LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '"
69 << dstAttr << "' for type '" << dstType << "'\n");
70 return dstAttr;
71 }
72
73 LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr
74 << "' illegal: cannot fit into target type '"
75 << dstType << "'\n");
76 return {};
77}
78
79/// Converts the given `srcAttr` to a new attribute of the given `dstType`.
80/// Returns null attribute if `dstType` is not 32-bit or conversion fails.
81static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
82 Builder builder) {
83 // Only support converting to float for now.
84 if (!dstType.isF32())
85 return FloatAttr();
86
87 // Try to convert the source floating-point number to single precision.
88 APFloat dstVal = srcAttr.getValue();
89 bool losesInfo = false;
90 APFloat::opStatus status =
91 dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
92 if (status != APFloat::opOK || losesInfo) {
93 LLVM_DEBUG(llvm::dbgs()
94 << srcAttr << " illegal: cannot fit into converted type '"
95 << dstType << "'\n");
96 return FloatAttr();
97 }
98
99 return builder.getF32FloatAttr(dstVal.convertToFloat());
100}
101
102// Get in IntegerAttr from FloatAttr while preserving the bits.
103// Useful for converting float constants to integer constants while preserving
104// the bits.
105static IntegerAttr
106getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType,
107 ConversionPatternRewriter &rewriter) {
108 APFloat floatVal = floatAttr.getValue();
109 APInt intVal = floatVal.bitcastToAPInt();
110 return rewriter.getIntegerAttr(dstType, intVal);
111}
112
113/// Returns true if the given `type` is a boolean scalar or vector type.
114static bool isBoolScalarOrVector(Type type) {
115 assert(type && "Not a valid type");
116 if (type.isInteger(1))
117 return true;
118
119 if (auto vecType = dyn_cast<VectorType>(type))
120 return vecType.getElementType().isInteger(1);
121
122 return false;
123}
124
125/// Creates a scalar/vector integer constant.
126static Value getScalarOrVectorConstInt(Type type, uint64_t value,
127 OpBuilder &builder, Location loc) {
128 if (auto vectorType = dyn_cast<VectorType>(type)) {
129 Attribute element = IntegerAttr::get(vectorType.getElementType(), value);
130 auto attr = SplatElementsAttr::get(vectorType, element);
131 return spirv::ConstantOp::create(builder, loc, vectorType, attr);
132 }
133
134 if (auto intType = dyn_cast<IntegerType>(type))
135 return spirv::ConstantOp::create(builder, loc, type,
136 builder.getIntegerAttr(type, value));
137
138 return nullptr;
139}
140
141/// Returns true if scalar/vector type `a` and `b` have the same number of
142/// bitwidth.
143static bool hasSameBitwidth(Type a, Type b) {
144 auto getNumBitwidth = [](Type type) {
145 unsigned bw = 0;
146 if (type.isIntOrFloat())
147 bw = type.getIntOrFloatBitWidth();
148 else if (auto vecType = dyn_cast<VectorType>(type))
149 bw = vecType.getElementTypeBitWidth() * vecType.getNumElements();
150 return bw;
151 };
152 unsigned aBW = getNumBitwidth(a);
153 unsigned bBW = getNumBitwidth(b);
154 return aBW != 0 && bBW != 0 && aBW == bBW;
155}
156
157/// Returns a source type conversion failure for `srcType` and operation `op`.
158static LogicalResult
159getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op,
160 Type srcType) {
161 return rewriter.notifyMatchFailure(
162 op->getLoc(),
163 llvm::formatv("failed to convert source type '{0}'", srcType));
164}
165
166/// Returns a source type conversion failure for the result type of `op`.
167static LogicalResult
168getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op) {
169 assert(op->getNumResults() == 1);
170 return getTypeConversionFailure(rewriter, op, op->getResultTypes().front());
171}
172
173namespace {
174
175/// Converts elementwise unary, binary and ternary arith operations to SPIR-V
176/// operations. Op can potentially support overflow flags.
177template <typename Op, typename SPIRVOp>
178struct ElementwiseArithOpPattern final : OpConversionPattern<Op> {
179 using OpConversionPattern<Op>::OpConversionPattern;
180
181 LogicalResult
182 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
183 ConversionPatternRewriter &rewriter) const override {
184 assert(adaptor.getOperands().size() <= 3);
185 // Reject boolean types to allow specialized boolean patterns to handle
186 // them (e.g., addi/subi on i1 should use LogicalNotEqual, not IAdd/ISub).
187 if (!adaptor.getOperands().empty() &&
188 isBoolScalarOrVector(adaptor.getOperands().front().getType()))
189 return failure();
190 auto converter = this->template getTypeConverter<SPIRVTypeConverter>();
191 Type dstType = converter->convertType(op.getType());
192 if (!dstType) {
193 return rewriter.notifyMatchFailure(
194 op->getLoc(),
195 llvm::formatv("failed to convert type {0} for SPIR-V", op.getType()));
196 }
197
198 if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
199 !getElementTypeOrSelf(op.getType()).isIndex() &&
200 dstType != op.getType()) {
201 return op.emitError("bitwidth emulation is not implemented yet on "
202 "unsigned op pattern version");
203 }
204
205 auto overflowFlags = arith::IntegerOverflowFlags::none;
206 if (auto overflowIface =
207 dyn_cast<arith::ArithIntegerOverflowFlagsInterface>(*op)) {
208 if (converter->getTargetEnv().allows(
209 spirv::Extension::SPV_KHR_no_integer_wrap_decoration))
210 overflowFlags = overflowIface.getOverflowAttr().getValue();
211 }
212
213 auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
214 op, dstType, adaptor.getOperands());
215
216 if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nsw))
217 newOp->setAttr(getDecorationString(spirv::Decoration::NoSignedWrap),
218 rewriter.getUnitAttr());
219
220 if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nuw))
221 newOp->setAttr(getDecorationString(spirv::Decoration::NoUnsignedWrap),
222 rewriter.getUnitAttr());
223
224 return success();
225 }
226};
227
228//===----------------------------------------------------------------------===//
229// ConstantOp
230//===----------------------------------------------------------------------===//
231
232/// Converts composite arith.constant operation to spirv.Constant.
233struct ConstantCompositeOpPattern final
234 : public OpConversionPattern<arith::ConstantOp> {
235 using Base::Base;
236
237 LogicalResult
238 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
239 ConversionPatternRewriter &rewriter) const override {
240 auto srcType = dyn_cast<ShapedType>(constOp.getType());
241 if (!srcType || srcType.getNumElements() == 1)
242 return failure();
243
244 // arith.constant should only have vector or tensor types. This is a MLIR
245 // wide problem at the moment.
246 if (!isa<VectorType, RankedTensorType>(srcType))
247 return rewriter.notifyMatchFailure(constOp, "unsupported ShapedType");
248
249 Type dstType = getTypeConverter()->convertType(srcType);
250 if (!dstType)
251 return failure();
252
253 // Import the resource into the IR to make use of the special handling of
254 // element types later on.
255 mlir::DenseElementsAttr dstElementsAttr;
256 if (auto denseElementsAttr =
257 dyn_cast<DenseElementsAttr>(constOp.getValue())) {
258 dstElementsAttr = denseElementsAttr;
259 } else if (auto resourceAttr =
260 dyn_cast<DenseResourceElementsAttr>(constOp.getValue())) {
261
262 AsmResourceBlob *blob = resourceAttr.getRawHandle().getBlob();
263 if (!blob)
264 return constOp->emitError("could not find resource blob");
265
266 ArrayRef<char> ptr = blob->getData();
267
268 // Check that the buffer meets the requirements to get converted to a
269 // DenseElementsAttr
271 return constOp->emitError("resource is not a valid buffer");
272
273 dstElementsAttr =
274 DenseElementsAttr::getFromRawBuffer(resourceAttr.getType(), ptr);
275 } else {
276 return constOp->emitError("unsupported elements attribute");
277 }
278
279 ShapedType dstAttrType = dstElementsAttr.getType();
280
281 // If the composite type has more than one dimensions, perform
282 // linearization.
283 if (srcType.getRank() > 1) {
284 if (isa<RankedTensorType>(srcType)) {
285 dstAttrType = RankedTensorType::get(srcType.getNumElements(),
286 srcType.getElementType());
287 dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
288 } else {
289 // TODO: add support for large vectors.
290 return failure();
291 }
292 }
293
294 Type srcElemType = srcType.getElementType();
295 Type dstElemType;
296 // Tensor types are converted to SPIR-V array types; vector types are
297 // converted to SPIR-V vector/array types.
298 if (auto arrayType = dyn_cast<spirv::ArrayType>(dstType))
299 dstElemType = arrayType.getElementType();
300 else
301 dstElemType = cast<VectorType>(dstType).getElementType();
302
303 // If the source and destination element types are different, perform
304 // attribute conversion.
305 if (srcElemType != dstElemType) {
307 if (isa<FloatType>(srcElemType)) {
308 for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
309 Attribute dstAttr = nullptr;
310 // Handle 8-bit float conversion to 8-bit integer.
311 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
312 if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
313 srcElemType.getIntOrFloatBitWidth() == 8 &&
314 isa<IntegerType>(dstElemType)) {
315 dstAttr =
316 getIntegerAttrFromFloatAttr(srcAttr, dstElemType, rewriter);
317 } else {
318 dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstElemType),
319 rewriter);
320 }
321 if (!dstAttr)
322 return failure();
323 elements.push_back(dstAttr);
324 }
325 } else if (srcElemType.isInteger(1)) {
326 return failure();
327 } else {
328 for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) {
329 IntegerAttr dstAttr = convertIntegerAttr(
330 srcAttr, cast<IntegerType>(dstElemType), rewriter);
331 if (!dstAttr)
332 return failure();
333 elements.push_back(dstAttr);
334 }
335 }
336
337 // Unfortunately, we cannot use dialect-specific types for element
338 // attributes; element attributes only works with builtin types. So we
339 // need to prepare another converted builtin types for the destination
340 // elements attribute.
341 if (isa<RankedTensorType>(dstAttrType))
342 dstAttrType =
343 RankedTensorType::get(dstAttrType.getShape(), dstElemType);
344 else
345 dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
346
347 dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
348 }
349
350 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
351 dstElementsAttr);
352 return success();
353 }
354};
355
356/// Converts scalar arith.constant operation to spirv.Constant.
357struct ConstantScalarOpPattern final
358 : public OpConversionPattern<arith::ConstantOp> {
359 using Base::Base;
360
361 LogicalResult
362 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
363 ConversionPatternRewriter &rewriter) const override {
364 Type srcType = constOp.getType();
365 if (auto shapedType = dyn_cast<ShapedType>(srcType)) {
366 if (shapedType.getNumElements() != 1)
367 return failure();
368 srcType = shapedType.getElementType();
369 }
370 if (!srcType.isIntOrIndexOrFloat())
371 return failure();
372
373 Attribute cstAttr = constOp.getValue();
374 if (auto elementsAttr = dyn_cast<DenseElementsAttr>(cstAttr))
375 cstAttr = elementsAttr.getSplatValue<Attribute>();
376
377 Type dstType = getTypeConverter()->convertType(srcType);
378 if (!dstType)
379 return failure();
380
381 // Floating-point types.
382 if (isa<FloatType>(srcType)) {
383 auto srcAttr = cast<FloatAttr>(cstAttr);
384 Attribute dstAttr = srcAttr;
385
386 // Floating-point types not supported in the target environment are all
387 // converted to float type.
388 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
389 if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
390 srcType.getIntOrFloatBitWidth() == 8 && isa<IntegerType>(dstType) &&
391 dstType.getIntOrFloatBitWidth() == 8) {
392 // If the source is an 8-bit float, convert it to a 8-bit integer.
393 dstAttr = getIntegerAttrFromFloatAttr(srcAttr, dstType, rewriter);
394 if (!dstAttr)
395 return failure();
396 } else if (srcType != dstType) {
397 dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter);
398 if (!dstAttr)
399 return failure();
400 }
401
402 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
403 return success();
404 }
405
406 // Bool type.
407 if (srcType.isInteger(1)) {
408 // arith.constant can use 0/1 instead of true/false for i1 values. We need
409 // to handle that here.
410 auto dstAttr = convertBoolAttr(cstAttr, rewriter);
411 if (!dstAttr)
412 return failure();
413 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
414 return success();
415 }
416
417 // IndexType or IntegerType. Index values are converted to 32-bit integer
418 // values when converting to SPIR-V.
419 auto srcAttr = cast<IntegerAttr>(cstAttr);
420 IntegerAttr dstAttr =
421 convertIntegerAttr(srcAttr, cast<IntegerType>(dstType), rewriter);
422 if (!dstAttr)
423 return failure();
424 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
425 return success();
426 }
427};
428
429//===----------------------------------------------------------------------===//
430// RemSIOp
431//===----------------------------------------------------------------------===//
432
433/// Returns signed remainder for `lhs` and `rhs` and lets the result follow
434/// the sign of `signOperand`.
435///
436/// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment
437/// spec, "for the OpSRem and OpSMod instructions, if either operand is negative
438/// the result is undefined." So we cannot directly use spirv.SRem/spirv.SMod
439/// if either operand can be negative. Emulate it via spirv.UMod.
440template <typename SignedAbsOp>
441static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
442 Value signOperand, OpBuilder &builder) {
443 assert(lhs.getType() == rhs.getType());
444 assert(lhs == signOperand || rhs == signOperand);
445
446 Type type = lhs.getType();
447
448 // Calculate the remainder with spirv.UMod.
449 Value lhsAbs = SignedAbsOp::create(builder, loc, type, lhs);
450 Value rhsAbs = SignedAbsOp::create(builder, loc, type, rhs);
451 Value abs = spirv::UModOp::create(builder, loc, lhsAbs, rhsAbs);
452
453 // Fix the sign.
454 Value isPositive;
455 if (lhs == signOperand)
456 isPositive = spirv::IEqualOp::create(builder, loc, lhs, lhsAbs);
457 else
458 isPositive = spirv::IEqualOp::create(builder, loc, rhs, rhsAbs);
459 Value absNegate = spirv::SNegateOp::create(builder, loc, type, abs);
460 return spirv::SelectOp::create(builder, loc, type, isPositive, abs,
461 absNegate);
462}
463
464/// Converts arith.remsi to GLSL SPIR-V ops.
465///
466/// This cannot be merged into the template unary/binary pattern due to Vulkan
467/// restrictions over spirv.SRem and spirv.SMod.
468struct RemSIOpGLPattern final : public OpConversionPattern<arith::RemSIOp> {
469 using Base::Base;
470
471 LogicalResult
472 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
473 ConversionPatternRewriter &rewriter) const override {
474 Value result = emulateSignedRemainder<spirv::CLSAbsOp>(
475 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
476 adaptor.getOperands()[0], rewriter);
477 rewriter.replaceOp(op, result);
478
479 return success();
480 }
481};
482
483/// Converts arith.remsi to OpenCL SPIR-V ops.
484struct RemSIOpCLPattern final : public OpConversionPattern<arith::RemSIOp> {
485 using Base::Base;
486
487 LogicalResult
488 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
489 ConversionPatternRewriter &rewriter) const override {
490 Value result = emulateSignedRemainder<spirv::GLSAbsOp>(
491 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
492 adaptor.getOperands()[0], rewriter);
493 rewriter.replaceOp(op, result);
494
495 return success();
496 }
497};
498
499//===----------------------------------------------------------------------===//
500// BitwiseOp
501//===----------------------------------------------------------------------===//
502
503/// Converts bitwise operations to SPIR-V operations. This is a special pattern
504/// other than the BinaryOpPatternPattern because if the operands are boolean
505/// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For
506/// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`.
507template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
508struct BitwiseOpPattern final : public OpConversionPattern<Op> {
509 using OpConversionPattern<Op>::OpConversionPattern;
510
511 LogicalResult
512 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
513 ConversionPatternRewriter &rewriter) const override {
514 assert(adaptor.getOperands().size() == 2);
515 Type dstType = this->getTypeConverter()->convertType(op.getType());
516 if (!dstType)
517 return getTypeConversionFailure(rewriter, op);
518
519 if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) {
520 rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(
521 op, dstType, adaptor.getOperands());
522 } else {
523 rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(
524 op, dstType, adaptor.getOperands());
525 }
526 return success();
527 }
528};
529
530//===----------------------------------------------------------------------===//
531// XOrIOp
532//===----------------------------------------------------------------------===//
533
534/// Converts arith.xori to SPIR-V operations.
535struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> {
536 using Base::Base;
537
538 LogicalResult
539 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
540 ConversionPatternRewriter &rewriter) const override {
541 assert(adaptor.getOperands().size() == 2);
542
543 if (isBoolScalarOrVector(adaptor.getOperands().front().getType()))
544 return failure();
545
546 Type dstType = getTypeConverter()->convertType(op.getType());
547 if (!dstType)
548 return getTypeConversionFailure(rewriter, op);
549
550 rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
551 adaptor.getOperands());
552
553 return success();
554 }
555};
556
557/// Converts arith.xori to SPIR-V operations if the type of source is i1 or
558/// vector of i1.
559struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
560 using Base::Base;
561
562 LogicalResult
563 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
564 ConversionPatternRewriter &rewriter) const override {
565 assert(adaptor.getOperands().size() == 2);
566
567 if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
568 return failure();
569
570 Type dstType = getTypeConverter()->convertType(op.getType());
571 if (!dstType)
572 return getTypeConversionFailure(rewriter, op);
573
574 rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
575 op, dstType, adaptor.getOperands());
576 return success();
577 }
578};
579
580/// Converts an arith integer op to the given SPIR-V boolean op if the type is
581/// i1 or vector of i1. Each mapping follows from the boolean truth table of
582/// the operation:
583/// addi(a, b) = a ^ b (add mod 2 = XOR = LogicalNotEqual)
584/// subi(a, b) = a ^ b (sub mod 2 = XOR = LogicalNotEqual)
585/// muli(a, b) = a & b (1*1=1, else 0 = LogicalAnd)
586/// divui(a, b) = a & b (a/1=a, a/0=UB; truth table matches AND)
587/// divsi(a, b) = a & b (same as divui on i1)
588/// maxsi(a, b) = a & b (signed i1: 1 represents -1, so max is 0 unless both
589/// are 1)
590/// maxui(a, b) = a | b (unsigned max on i1: 1 when either operand is 1)
591/// minsi(a, b) = a | b (signed i1: -1 < 0, so min is 1 when either operand
592/// is 1)
593/// minui(a, b) = a & b (unsigned min on i1: 1 only when both operands are
594/// 1)
595template <typename ArithOp, typename SPIRVOp>
596struct BoolIOpPattern final : public OpConversionPattern<ArithOp> {
597 BoolIOpPattern(const TypeConverter &converter, MLIRContext *context)
598 // benefit=2: takes priority over the generic ElementwiseArithOpPattern
599 // (benefit=1) when the operand type is i1.
600 : OpConversionPattern<ArithOp>(converter, context, /*benefit=*/2) {}
601
602 LogicalResult
603 matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
604 ConversionPatternRewriter &rewriter) const override {
605 if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
606 return failure();
607
608 Type dstType = this->getTypeConverter()->convertType(op.getType());
609 if (!dstType)
610 return getTypeConversionFailure(rewriter, op);
611
612 rewriter.replaceOpWithNewOp<SPIRVOp>(op, dstType, adaptor.getOperands());
613 return success();
614 }
615};
616
617/// Converts an arith binary op on i1 to spirv.LogicalAnd(lhs,
618/// spirv.LogicalNot(rhs)). This covers shift-left, shift-right-unsigned, and
619/// unsigned remainder on i1:
620/// shli(a, b) = a & ~b (shift left clears the bit when b=1)
621/// shrui(a, b) = a & ~b (shift right unsigned clears the bit when b=1)
622/// remui(a, b) = a & ~b (only defined when b=1; a%1=0, and ~b=~1=0, so AND
623/// gives 0)
624/// remsi(a, b) = a & ~b (only defined when b=1; a%1=0, and ~b=~1=0, so AND
625/// gives 0)
626template <typename ArithOp>
627struct BoolIOpAndNotPattern final : public OpConversionPattern<ArithOp> {
628 BoolIOpAndNotPattern(const TypeConverter &converter, MLIRContext *context)
629 // benefit=2: takes priority over the generic ElementwiseArithOpPattern
630 // (benefit=1) when the operand type is i1.
631 : OpConversionPattern<ArithOp>(converter, context, /*benefit=*/2) {}
632
633 LogicalResult
634 matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
635 ConversionPatternRewriter &rewriter) const override {
636 if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
637 return failure();
638
639 Type dstType = this->getTypeConverter()->convertType(op.getType());
640 if (!dstType)
641 return getTypeConversionFailure(rewriter, op);
642
643 Location loc = op.getLoc();
644 Value notRhs = spirv::LogicalNotOp::create(rewriter, loc, dstType,
645 adaptor.getOperands()[1]);
646 rewriter.replaceOpWithNewOp<spirv::LogicalAndOp>(
647 op, dstType, adaptor.getOperands()[0], notRhs);
648 return success();
649 }
650};
651
652/// Converts arith.shrsi on i1 to identity: arithmetic right shift of a 1-bit
653/// signed value always yields the original value (0 >> n = 0, -1 >> n = -1).
654struct ShRSIBoolPattern final : public OpConversionPattern<arith::ShRSIOp> {
655 ShRSIBoolPattern(const TypeConverter &converter, MLIRContext *context)
656 // benefit=2: takes priority over the generic spirv::ElementwiseOpPattern
657 // (benefit=1) when the operand type is i1.
658 : OpConversionPattern<arith::ShRSIOp>(converter, context,
659 /*benefit=*/2) {}
660
661 LogicalResult
662 matchAndRewrite(arith::ShRSIOp op, OpAdaptor adaptor,
663 ConversionPatternRewriter &rewriter) const override {
664 if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
665 return failure();
666
667 rewriter.replaceOp(op, adaptor.getOperands().front());
668 return success();
669 }
670};
671
672//===----------------------------------------------------------------------===//
673// UIToFPOp
674//===----------------------------------------------------------------------===//
675
676/// Converts arith.uitofp to spirv.Select if the type of source is i1 or vector
677/// of i1.
678struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
679 using Base::Base;
680
681 LogicalResult
682 matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
683 ConversionPatternRewriter &rewriter) const override {
684 Type srcType = adaptor.getOperands().front().getType();
685 if (!isBoolScalarOrVector(srcType))
686 return failure();
687
688 Type dstType = getTypeConverter()->convertType(op.getType());
689 if (!dstType)
690 return getTypeConversionFailure(rewriter, op);
691
692 Location loc = op.getLoc();
693 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
694 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
695 rewriter.replaceOpWithNewOp<spirv::SelectOp>(
696 op, dstType, adaptor.getOperands().front(), one, zero);
697 return success();
698 }
699};
700
701/// Converts arith.uitofp/arith.sitofp to spirv.ConvertUToF/spirv.ConvertSToF.
702/// When the source integer type was widened during type conversion (e.g., i8
703/// emulated as i32), the upper bits of the widened value may contain garbage.
704/// This pattern cleans the upper bits before the conversion:
705/// - For unsigned (IsSigned=false): mask with BitwiseAnd.
706/// - For signed (IsSigned=true): sign-extend via ShiftLeftLogical +
707/// ShiftRightArithmetic.
708template <typename ArithOp, typename SPIRVOp, bool IsSigned>
709struct IntToFPPattern final : public OpConversionPattern<ArithOp> {
710 using OpConversionPattern<ArithOp>::OpConversionPattern;
711
712 LogicalResult
713 matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
714 ConversionPatternRewriter &rewriter) const override {
715 Type srcType = adaptor.getOperands().front().getType();
716 if (isBoolScalarOrVector(srcType))
717 return failure();
718
719 Type dstType = this->getTypeConverter()->convertType(op.getType());
720 if (!dstType)
721 return getTypeConversionFailure(rewriter, op);
722
723 // Check if the source integer type was widened during type conversion.
724 unsigned originalBitwidth =
725 getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth();
726 unsigned convertedBitwidth =
728
729 if (originalBitwidth >= convertedBitwidth) {
730 rewriter.replaceOpWithNewOp<SPIRVOp>(op, dstType, adaptor.getOperands());
731 return success();
732 }
733
734 // The source was widened. Clean the upper bits before converting.
735 Location loc = op.getLoc();
736 Value cleaned;
737 if constexpr (IsSigned) {
738 // Sign-extend by shifting left then arithmetic right.
739 unsigned shiftAmount = convertedBitwidth - originalBitwidth;
740 Value shiftSize =
741 getScalarOrVectorConstInt(srcType, shiftAmount, rewriter, loc);
742 Value shifted = spirv::ShiftLeftLogicalOp::create(
743 rewriter, loc, srcType, adaptor.getIn(), shiftSize);
744 cleaned = spirv::ShiftRightArithmeticOp::create(rewriter, loc, srcType,
745 shifted, shiftSize);
746 } else {
747 // Zero-extend by masking off the upper bits.
749 srcType, llvm::maskTrailingOnes<uint64_t>(originalBitwidth), rewriter,
750 loc);
751 cleaned = spirv::BitwiseAndOp::create(rewriter, loc, srcType,
752 adaptor.getIn(), mask);
753 }
754 rewriter.replaceOpWithNewOp<SPIRVOp>(op, dstType, cleaned);
755 return success();
756 }
757};
759//===----------------------------------------------------------------------===//
760// IndexCastOp
761//===----------------------------------------------------------------------===//
762
763/// Converts arith.index_cast to spirv.INotEqual if the target type is i1.
764struct IndexCastIndexI1Pattern final
765 : public OpConversionPattern<arith::IndexCastOp> {
767
768 LogicalResult
769 matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
770 ConversionPatternRewriter &rewriter) const override {
771 if (!isBoolScalarOrVector(op.getType()))
772 return failure();
773
774 Type dstType = getTypeConverter()->convertType(op.getType());
775 if (!dstType)
776 return getTypeConversionFailure(rewriter, op);
777
778 Location loc = op.getLoc();
779 Value zeroIdx =
780 spirv::ConstantOp::getZero(adaptor.getIn().getType(), loc, rewriter);
781 rewriter.replaceOpWithNewOp<spirv::INotEqualOp>(op, dstType, zeroIdx,
782 adaptor.getIn());
783 return success();
784 }
787/// Converts arith.index_cast to spirv.Select if the source type is i1.
788struct IndexCastI1IndexPattern final
789 : public OpConversionPattern<arith::IndexCastOp> {
790 using Base::Base;
791
792 LogicalResult
793 matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
794 ConversionPatternRewriter &rewriter) const override {
795 if (!isBoolScalarOrVector(adaptor.getIn().getType()))
796 return failure();
797
798 Type dstType = getTypeConverter()->convertType(op.getType());
799 if (!dstType)
800 return getTypeConversionFailure(rewriter, op);
801
802 Location loc = op.getLoc();
803 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
804 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
805 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, adaptor.getIn(),
806 one, zero);
807 return success();
808 }
809};
810
811//===----------------------------------------------------------------------===//
812// ExtSIOp
813//===----------------------------------------------------------------------===//
814
815/// Converts arith.extsi to spirv.Select if the type of source is i1 or vector
816/// of i1.
817struct ExtSII1Pattern final : public OpConversionPattern<arith::ExtSIOp> {
818 using Base::Base;
819
820 LogicalResult
821 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
822 ConversionPatternRewriter &rewriter) const override {
823 Value operand = adaptor.getIn();
824 if (!isBoolScalarOrVector(operand.getType()))
825 return failure();
826
827 Location loc = op.getLoc();
828 Type dstType = getTypeConverter()->convertType(op.getType());
829 if (!dstType)
830 return getTypeConversionFailure(rewriter, op);
831
832 Value allOnes;
833 if (auto intTy = dyn_cast<IntegerType>(dstType)) {
834 unsigned componentBitwidth = intTy.getWidth();
835 allOnes = spirv::ConstantOp::create(
836 rewriter, loc, intTy,
837 rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth)));
838 } else if (auto vectorTy = dyn_cast<VectorType>(dstType)) {
839 unsigned componentBitwidth = vectorTy.getElementTypeBitWidth();
840 allOnes = spirv::ConstantOp::create(
841 rewriter, loc, vectorTy,
842 SplatElementsAttr::get(vectorTy,
843 APInt::getAllOnes(componentBitwidth)));
844 } else {
845 return rewriter.notifyMatchFailure(
846 loc, llvm::formatv("unhandled type: {0}", dstType));
847 }
848
849 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
850 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, operand, allOnes,
851 zero);
852 return success();
853 }
854};
855
856/// Converts arith.extsi to spirv.Select if the type of source is neither i1 nor
857/// vector of i1.
858struct ExtSIPattern final : public OpConversionPattern<arith::ExtSIOp> {
859 using Base::Base;
860
861 LogicalResult
862 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
863 ConversionPatternRewriter &rewriter) const override {
864 Type srcType = adaptor.getIn().getType();
865 if (isBoolScalarOrVector(srcType))
866 return failure();
867
868 Type dstType = getTypeConverter()->convertType(op.getType());
869 if (!dstType)
870 return getTypeConversionFailure(rewriter, op);
871
872 if (dstType == srcType) {
873 // We can have the same source and destination type due to type emulation.
874 // Perform bit shifting to make sure we have the proper leading set bits.
875
876 unsigned srcBW =
877 getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth();
878 unsigned dstBW =
880 assert(srcBW < dstBW);
881 Value shiftSize = getScalarOrVectorConstInt(dstType, dstBW - srcBW,
882 rewriter, op.getLoc());
883 if (!shiftSize)
884 return rewriter.notifyMatchFailure(op, "unsupported type for shift");
885
886 // First shift left to sequeeze out all leading bits beyond the original
887 // bitwidth. Here we need to use the original source and result type's
888 // bitwidth.
889 auto shiftLOp = spirv::ShiftLeftLogicalOp::create(
890 rewriter, op.getLoc(), dstType, adaptor.getIn(), shiftSize);
891
892 // Then we perform arithmetic right shift to make sure we have the right
893 // sign bits for negative values.
894 rewriter.replaceOpWithNewOp<spirv::ShiftRightArithmeticOp>(
895 op, dstType, shiftLOp, shiftSize);
896 } else {
897 rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
898 adaptor.getOperands());
899 }
900
901 return success();
902 }
903};
904
905//===----------------------------------------------------------------------===//
906// ExtUIOp
907//===----------------------------------------------------------------------===//
908
909/// Converts arith.extui to spirv.Select if the type of source is i1 or vector
910/// of i1.
911struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> {
912 using Base::Base;
913
914 LogicalResult
915 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
916 ConversionPatternRewriter &rewriter) const override {
917 Type srcType = adaptor.getOperands().front().getType();
918 if (!isBoolScalarOrVector(srcType))
919 return failure();
920
921 Type dstType = getTypeConverter()->convertType(op.getType());
922 if (!dstType)
923 return getTypeConversionFailure(rewriter, op);
924
925 Location loc = op.getLoc();
926 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
927 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
928 rewriter.replaceOpWithNewOp<spirv::SelectOp>(
929 op, dstType, adaptor.getOperands().front(), one, zero);
930 return success();
931 }
932};
933
934/// Converts arith.extui for cases where the type of source is neither i1 nor
935/// vector of i1.
936struct ExtUIPattern final : public OpConversionPattern<arith::ExtUIOp> {
937 using Base::Base;
938
939 LogicalResult
940 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
941 ConversionPatternRewriter &rewriter) const override {
942 Type srcType = adaptor.getIn().getType();
943 if (isBoolScalarOrVector(srcType))
944 return failure();
945
946 Type dstType = getTypeConverter()->convertType(op.getType());
947 if (!dstType)
948 return getTypeConversionFailure(rewriter, op);
949
950 if (dstType == srcType) {
951 // We can have the same source and destination type due to type emulation.
952 // Perform bit masking to make sure we don't pollute downstream consumers
953 // with unwanted bits. Here we need to use the original source type's
954 // bitwidth.
955 unsigned bitwidth =
956 getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth();
957 Value mask = getScalarOrVectorConstInt(
958 dstType, llvm::maskTrailingOnes<uint64_t>(bitwidth), rewriter,
959 op.getLoc());
960 if (!mask)
961 return rewriter.notifyMatchFailure(op, "unsupported type for mask");
962 rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
963 adaptor.getIn(), mask);
964 } else {
965 rewriter.replaceOpWithNewOp<spirv::UConvertOp>(op, dstType,
966 adaptor.getOperands());
967 }
968 return success();
969 }
970};
971
972//===----------------------------------------------------------------------===//
973// TruncIOp
974//===----------------------------------------------------------------------===//
975
976/// Converts arith.trunci to spirv.Select if the type of result is i1 or vector
977/// of i1.
978struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
979 using Base::Base;
980
981 LogicalResult
982 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
983 ConversionPatternRewriter &rewriter) const override {
984 Type dstType = getTypeConverter()->convertType(op.getType());
985 if (!dstType)
986 return getTypeConversionFailure(rewriter, op);
987
988 if (!isBoolScalarOrVector(dstType))
989 return failure();
990
991 Location loc = op.getLoc();
992 auto srcType = adaptor.getOperands().front().getType();
993 // Check if (x & 1) == 1.
994 Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
995 Value maskedSrc = spirv::BitwiseAndOp::create(
996 rewriter, loc, srcType, adaptor.getOperands()[0], mask);
997 Value isOne = spirv::IEqualOp::create(rewriter, loc, maskedSrc, mask);
998
999 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
1000 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
1001 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero);
1002 return success();
1003 }
1004};
1005
1006/// Converts arith.trunci for cases where the type of result is neither i1
1007/// nor vector of i1.
1008struct TruncIPattern final : public OpConversionPattern<arith::TruncIOp> {
1009 using Base::Base;
1010
1011 LogicalResult
1012 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
1013 ConversionPatternRewriter &rewriter) const override {
1014 Type srcType = adaptor.getIn().getType();
1015 Type dstType = getTypeConverter()->convertType(op.getType());
1016 if (!dstType)
1017 return getTypeConversionFailure(rewriter, op);
1018
1019 if (isBoolScalarOrVector(dstType))
1020 return failure();
1021
1022 if (dstType == srcType) {
1023 // We can have the same source and destination type due to type emulation.
1024 // Perform bit masking to make sure we don't pollute downstream consumers
1025 // with unwanted bits. Here we need to use the original result type's
1026 // bitwidth.
1027 unsigned bw = getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth();
1028 Value mask = getScalarOrVectorConstInt(
1029 dstType, llvm::maskTrailingOnes<uint64_t>(bw), rewriter, op.getLoc());
1030 if (!mask)
1031 return rewriter.notifyMatchFailure(op, "unsupported type for mask");
1032 rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
1033 adaptor.getIn(), mask);
1034 } else {
1035 // Given this is truncation, either SConvertOp or UConvertOp works.
1036 rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
1037 adaptor.getOperands());
1038 }
1039 return success();
1040 }
1041};
1042
1043//===----------------------------------------------------------------------===//
1044// TypeCastingOp
1045//===----------------------------------------------------------------------===//
1046
1047static std::optional<spirv::FPRoundingMode>
1048convertArithRoundingModeToSPIRV(arith::RoundingMode roundingMode) {
1049 switch (roundingMode) {
1050 case arith::RoundingMode::downward:
1051 return spirv::FPRoundingMode::RTN;
1052 case arith::RoundingMode::to_nearest_even:
1053 return spirv::FPRoundingMode::RTE;
1054 case arith::RoundingMode::toward_zero:
1055 return spirv::FPRoundingMode::RTZ;
1056 case arith::RoundingMode::upward:
1057 return spirv::FPRoundingMode::RTP;
1058 case arith::RoundingMode::to_nearest_away:
1059 // SPIR-V FPRoundingMode decoration has no ties-away-from-zero mode
1060 // (as of SPIR-V 1.6)
1061 return std::nullopt;
1062 }
1063 llvm_unreachable("Unhandled rounding mode");
1064}
1065
1066/// Converts type-casting standard operations to SPIR-V operations.
1067template <typename Op, typename SPIRVOp>
1068struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
1069 using OpConversionPattern<Op>::OpConversionPattern;
1070
1071 LogicalResult
1072 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
1073 ConversionPatternRewriter &rewriter) const override {
1074 Type srcType = llvm::getSingleElement(adaptor.getOperands()).getType();
1075 Type dstType = this->getTypeConverter()->convertType(op.getType());
1076 if (!dstType)
1077 return getTypeConversionFailure(rewriter, op);
1078
1079 if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
1080 return failure();
1081
1082 if (dstType == srcType) {
1083 // Due to type conversion, we are seeing the same source and target type.
1084 // Then we can just erase this operation by forwarding its operand.
1085 rewriter.replaceOp(op, adaptor.getOperands().front());
1086 } else {
1087 // Compute new rounding mode (if any).
1088 std::optional<spirv::FPRoundingMode> rm = std::nullopt;
1089 if (auto roundingModeOp =
1090 dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
1091 if (arith::RoundingModeAttr roundingMode =
1092 roundingModeOp.getRoundingModeAttr()) {
1093 if (!(rm =
1094 convertArithRoundingModeToSPIRV(roundingMode.getValue()))) {
1095 return rewriter.notifyMatchFailure(
1096 op->getLoc(),
1097 llvm::formatv("unsupported rounding mode '{0}'", roundingMode));
1098 }
1099 }
1100 }
1101 // Create replacement op and attach rounding mode attribute (if any).
1102 auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
1103 op, dstType, adaptor.getOperands());
1104 if (rm) {
1105 newOp->setAttr(
1106 getDecorationString(spirv::Decoration::FPRoundingMode),
1107 spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm));
1108 }
1109 }
1110 return success();
1111 }
1112};
1113
1114//===----------------------------------------------------------------------===//
1115// CmpIOp
1116//===----------------------------------------------------------------------===//
1117
1118/// Converts integer compare operation on i1 type operands to SPIR-V ops.
1119class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
1120public:
1121 using Base::Base;
1122
1123 LogicalResult
1124 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
1125 ConversionPatternRewriter &rewriter) const override {
1126 Type srcType = op.getLhs().getType();
1127 if (!isBoolScalarOrVector(srcType))
1128 return failure();
1129 Type dstType = getTypeConverter()->convertType(srcType);
1130 if (!dstType)
1131 return getTypeConversionFailure(rewriter, op, srcType);
1132
1133 switch (op.getPredicate()) {
1134 case arith::CmpIPredicate::eq: {
1135 rewriter.replaceOpWithNewOp<spirv::LogicalEqualOp>(op, adaptor.getLhs(),
1136 adaptor.getRhs());
1137 return success();
1138 }
1139 case arith::CmpIPredicate::ne: {
1140 rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
1141 op, adaptor.getLhs(), adaptor.getRhs());
1142 return success();
1143 }
1144 case arith::CmpIPredicate::uge:
1145 case arith::CmpIPredicate::ugt:
1146 case arith::CmpIPredicate::ule:
1147 case arith::CmpIPredicate::ult: {
1148 // There are no direct corresponding instructions in SPIR-V for such
1149 // cases. Extend them to 32-bit and do comparision then.
1150 Type type = rewriter.getI32Type();
1151 if (auto vectorType = dyn_cast<VectorType>(dstType))
1152 type = VectorType::get(vectorType.getShape(), type);
1153 Value extLhs =
1154 arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getLhs());
1155 Value extRhs =
1156 arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getRhs());
1157
1158 rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs,
1159 extRhs);
1160 return success();
1161 }
1162 default:
1163 break;
1164 }
1165 return failure();
1166 }
1167};
1168
1169/// Converts integer compare operation to SPIR-V ops.
1170class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> {
1171public:
1172 using Base::Base;
1173
1174 LogicalResult
1175 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
1176 ConversionPatternRewriter &rewriter) const override {
1177 Type srcType = op.getLhs().getType();
1178 if (isBoolScalarOrVector(srcType))
1179 return failure();
1180 Type dstType = getTypeConverter()->convertType(srcType);
1181 if (!dstType)
1182 return getTypeConversionFailure(rewriter, op, srcType);
1183
1184 switch (op.getPredicate()) {
1185#define DISPATCH(cmpPredicate, spirvOp) \
1186 case cmpPredicate: \
1187 if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \
1188 !getElementTypeOrSelf(srcType).isIndex() && srcType != dstType && \
1189 !hasSameBitwidth(srcType, dstType)) { \
1190 return op.emitError( \
1191 "bitwidth emulation is not implemented yet on unsigned op"); \
1192 } \
1193 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
1194 adaptor.getRhs()); \
1195 return success();
1196
1197 DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
1198 DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
1199 DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
1200 DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
1201 DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
1202 DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
1203 DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
1204 DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
1205 DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
1206 DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
1207
1208#undef DISPATCH
1209 }
1210 return failure();
1211 }
1212};
1213
1214//===----------------------------------------------------------------------===//
1215// CmpFOpPattern
1216//===----------------------------------------------------------------------===//
1217
1218/// Converts floating-point comparison operations to SPIR-V ops.
1219class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> {
1220public:
1221 using Base::Base;
1222
1223 LogicalResult
1224 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1225 ConversionPatternRewriter &rewriter) const override {
1226 switch (op.getPredicate()) {
1227#define DISPATCH(cmpPredicate, spirvOp) \
1228 case cmpPredicate: \
1229 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
1230 adaptor.getRhs()); \
1231 return success();
1232
1233 // Ordered.
1234 DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
1235 DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
1236 DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
1237 DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
1238 DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
1239 DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
1240 // Unordered.
1241 DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
1242 DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
1243 DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
1244 DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
1245 DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
1246 DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
1247
1248#undef DISPATCH
1249
1250 default:
1251 break;
1252 }
1253 return failure();
1254 }
1255};
1256
1257/// Converts floating point NaN check to SPIR-V ops. This pattern requires
1258/// Kernel capability.
1259class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> {
1260public:
1261 using Base::Base;
1262
1263 LogicalResult
1264 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1265 ConversionPatternRewriter &rewriter) const override {
1266 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1267 rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(),
1268 adaptor.getRhs());
1269 return success();
1270 }
1271
1272 if (op.getPredicate() == arith::CmpFPredicate::UNO) {
1273 rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(),
1274 adaptor.getRhs());
1275 return success();
1276 }
1277
1278 return failure();
1279 }
1280};
1281
1282/// Converts floating point NaN check to SPIR-V ops. This pattern does not
1283/// require additional capability.
1284class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
1285public:
1286 using Base::Base;
1287
1288 LogicalResult
1289 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1290 ConversionPatternRewriter &rewriter) const override {
1291 if (op.getPredicate() != arith::CmpFPredicate::ORD &&
1292 op.getPredicate() != arith::CmpFPredicate::UNO)
1293 return failure();
1294
1295 Location loc = op.getLoc();
1296
1297 Value replace;
1298 if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1299 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1300 // Ordered comparsion checks if neither operand is NaN.
1301 replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter);
1302 } else {
1303 // Unordered comparsion checks if either operand is NaN.
1304 replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter);
1305 }
1306 } else {
1307 Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1308 Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1309
1310 replace = spirv::LogicalOrOp::create(rewriter, loc, lhsIsNan, rhsIsNan);
1311 if (op.getPredicate() == arith::CmpFPredicate::ORD)
1312 replace = spirv::LogicalNotOp::create(rewriter, loc, replace);
1313 }
1314
1315 rewriter.replaceOp(op, replace);
1316 return success();
1317 }
1318};
1319
1320//===----------------------------------------------------------------------===//
1321// AddUIExtendedOp
1322//===----------------------------------------------------------------------===//
1323
1324/// Converts arith.addui_extended to spirv.IAddCarry.
1325class AddUIExtendedOpPattern final
1326 : public OpConversionPattern<arith::AddUIExtendedOp> {
1327public:
1328 using Base::Base;
1329 LogicalResult
1330 matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
1331 ConversionPatternRewriter &rewriter) const override {
1332 Type dstElemTy = adaptor.getLhs().getType();
1333 Location loc = op->getLoc();
1334 Value result = spirv::IAddCarryOp::create(rewriter, loc, adaptor.getLhs(),
1335 adaptor.getRhs());
1336
1337 Value sumResult = spirv::CompositeExtractOp::create(rewriter, loc, result,
1338 llvm::ArrayRef(0));
1339 Value carryValue = spirv::CompositeExtractOp::create(rewriter, loc, result,
1340 llvm::ArrayRef(1));
1341
1342 // Convert the carry value to boolean.
1343 Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter);
1344 Value carryResult = spirv::IEqualOp::create(rewriter, loc, carryValue, one);
1345
1346 rewriter.replaceOp(op, {sumResult, carryResult});
1347 return success();
1348 }
1349};
1350
1351//===----------------------------------------------------------------------===//
1352// MulIExtendedOp
1353//===----------------------------------------------------------------------===//
1354
1355/// Converts arith.mul*i_extended to spirv.*MulExtended.
1356template <typename ArithMulOp, typename SPIRVMulOp>
1357class MulIExtendedOpPattern final : public OpConversionPattern<ArithMulOp> {
1358public:
1359 using OpConversionPattern<ArithMulOp>::OpConversionPattern;
1360 LogicalResult
1361 matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
1362 ConversionPatternRewriter &rewriter) const override {
1363 Location loc = op->getLoc();
1364 Value result =
1365 SPIRVMulOp::create(rewriter, loc, adaptor.getLhs(), adaptor.getRhs());
1366
1367 Value low = spirv::CompositeExtractOp::create(rewriter, loc, result,
1368 llvm::ArrayRef(0));
1369 Value high = spirv::CompositeExtractOp::create(rewriter, loc, result,
1370 llvm::ArrayRef(1));
1371
1372 rewriter.replaceOp(op, {low, high});
1373 return success();
1374 }
1375};
1376
1377//===----------------------------------------------------------------------===//
1378// SelectOp
1379//===----------------------------------------------------------------------===//
1380
1381/// Converts arith.select to spirv.Select.
1382class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> {
1383public:
1384 using Base::Base;
1385 LogicalResult
1386 matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
1387 ConversionPatternRewriter &rewriter) const override {
1388 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, adaptor.getCondition(),
1389 adaptor.getTrueValue(),
1390 adaptor.getFalseValue());
1391 return success();
1392 }
1393};
1394
1395//===----------------------------------------------------------------------===//
1396// MinimumFOp, MaximumFOp
1397//===----------------------------------------------------------------------===//
1398
1399/// Converts arith.maximumf/minimumf to spirv.GL.FMax/FMin or
1400/// spirv.CL.fmax/fmin.
1401template <typename Op, typename SPIRVOp>
1402class MinimumMaximumFOpPattern final : public OpConversionPattern<Op> {
1403public:
1404 using OpConversionPattern<Op>::OpConversionPattern;
1405 LogicalResult
1406 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
1407 ConversionPatternRewriter &rewriter) const override {
1408 auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
1409 Type dstType = converter->convertType(op.getType());
1410 if (!dstType)
1411 return getTypeConversionFailure(rewriter, op);
1412
1413 // arith.maximumf/minimumf:
1414 // "if one of the arguments is NaN, then the result is also NaN."
1415 // spirv.GL.FMax/FMin
1416 // "which operand is the result is undefined if one of the operands
1417 // is a NaN."
1418 // spirv.CL.fmax/fmin:
1419 // "If one argument is a NaN, Fmin returns the other argument."
1420
1421 Location loc = op.getLoc();
1422 Value spirvOp =
1423 SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands());
1424
1425 if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1426 rewriter.replaceOp(op, spirvOp);
1427 return success();
1428 }
1429
1430 Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1431 Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1432
1433 Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan,
1434 adaptor.getLhs(), spirvOp);
1435 Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan,
1436 adaptor.getRhs(), select1);
1437
1438 rewriter.replaceOp(op, select2);
1439 return success();
1440 }
1441};
1442
1443//===----------------------------------------------------------------------===//
1444// MinNumFOp, MaxNumFOp
1445//===----------------------------------------------------------------------===//
1446
1447/// Converts arith.maxnumf/minnumf to spirv.GL.FMax/FMin or
1448/// spirv.CL.fmax/fmin.
1449template <typename Op, typename SPIRVOp>
1450class MinNumMaxNumFOpPattern final : public OpConversionPattern<Op> {
1451 template <typename TargetOp>
1452 constexpr bool shouldInsertNanGuards() const {
1453 return llvm::is_one_of<TargetOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value;
1454 }
1455
1456public:
1457 using OpConversionPattern<Op>::OpConversionPattern;
1458 LogicalResult
1459 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
1460 ConversionPatternRewriter &rewriter) const override {
1461 auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
1462 Type dstType = converter->convertType(op.getType());
1463 if (!dstType)
1464 return getTypeConversionFailure(rewriter, op);
1465
1466 // arith.maxnumf/minnumf:
1467 // "If one of the arguments is NaN, then the result is the other
1468 // argument."
1469 // spirv.GL.FMax/FMin
1470 // "which operand is the result is undefined if one of the operands
1471 // is a NaN."
1472 // spirv.CL.fmax/fmin:
1473 // "If one argument is a NaN, Fmin returns the other argument."
1474
1475 Location loc = op.getLoc();
1476 Value spirvOp =
1477 SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands());
1478
1479 if (!shouldInsertNanGuards<SPIRVOp>() ||
1480 bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1481 rewriter.replaceOp(op, spirvOp);
1482 return success();
1483 }
1484
1485 Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1486 Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1487
1488 Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan,
1489 adaptor.getRhs(), spirvOp);
1490 Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan,
1491 adaptor.getLhs(), select1);
1492
1493 rewriter.replaceOp(op, select2);
1494 return success();
1495 }
1496};
1497
1498} // namespace
1499
1500//===----------------------------------------------------------------------===//
1501// Pattern Population
1502//===----------------------------------------------------------------------===//
1503
1505 const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
1506 // clang-format off
1507 patterns.add<
1508 ConstantCompositeOpPattern,
1509 ConstantScalarOpPattern,
1510 BoolIOpPattern<arith::AddIOp, spirv::LogicalNotEqualOp>, // add mod 2 = XOR = not-equal
1511 ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>,
1512 BoolIOpPattern<arith::SubIOp, spirv::LogicalNotEqualOp>, // sub mod 2 = XOR = not-equal
1513 ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>,
1514 BoolIOpPattern<arith::MulIOp, spirv::LogicalAndOp>, // 1*1=1, else 0 = AND
1515 ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>,
1516 BoolIOpPattern<arith::DivUIOp, spirv::LogicalAndOp>, // a/1=a, a/0=UB; truth table = AND
1518 BoolIOpPattern<arith::DivSIOp, spirv::LogicalAndOp>, // same as divui on i1
1520 BoolIOpAndNotPattern<arith::RemUIOp>, // remui(a,b) = a & ~b (see pattern comment)
1522 BoolIOpAndNotPattern<arith::RemSIOp>, // remsi(a,b) = a & ~b (see pattern comment)
1523 RemSIOpGLPattern, RemSIOpCLPattern,
1524 BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
1525 BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
1526 XOrIOpLogicalPattern, XOrIOpBooleanPattern,
1527 BoolIOpAndNotPattern<arith::ShLIOp>, // shli(a,b) = a & ~b (see pattern comment)
1528 ElementwiseArithOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
1529 BoolIOpAndNotPattern<arith::ShRUIOp>, // shrui(a,b) = a & ~b (see pattern comment)
1531 ShRSIBoolPattern, // shrsi(a,b) = a (identity; see pattern comment)
1539 ExtUIPattern, ExtUII1Pattern,
1540 ExtSIPattern, ExtSII1Pattern,
1541 TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
1542 TruncIPattern, TruncII1Pattern,
1543 TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
1544 IntToFPPattern<arith::UIToFPOp, spirv::ConvertUToFOp, false>,
1545 UIToFPI1Pattern,
1546 IntToFPPattern<arith::SIToFPOp, spirv::ConvertSToFOp, true>,
1547 TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
1548 TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
1549 TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
1550 IndexCastIndexI1Pattern, IndexCastI1IndexPattern,
1551 TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
1552 TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
1553 CmpIOpBooleanPattern, CmpIOpPattern,
1554 CmpFOpNanNonePattern, CmpFOpPattern,
1555 AddUIExtendedOpPattern,
1556 MulIExtendedOpPattern<arith::MulSIExtendedOp, spirv::SMulExtendedOp>,
1557 MulIExtendedOpPattern<arith::MulUIExtendedOp, spirv::UMulExtendedOp>,
1558 SelectOpPattern,
1559
1560 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::GLFMaxOp>,
1561 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::GLFMinOp>,
1562 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::GLFMaxOp>,
1563 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::GLFMinOp>,
1564 BoolIOpPattern<arith::MaxSIOp, spirv::LogicalAndOp>, // signed i1: 1=-1, so max=0 unless both are 1
1565 BoolIOpPattern<arith::MaxUIOp, spirv::LogicalOrOp>, // unsigned max on i1: 1 when either is 1
1566 BoolIOpPattern<arith::MinSIOp, spirv::LogicalOrOp>, // signed i1: -1<0, so min=1 when either is 1
1567 BoolIOpPattern<arith::MinUIOp, spirv::LogicalAndOp>, // unsigned min on i1: 1 only when both are 1
1572
1573 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::CLFMaxOp>,
1574 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::CLFMinOp>,
1575 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::CLFMaxOp>,
1576 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::CLFMinOp>,
1581 >(typeConverter, patterns.getContext());
1582 // clang-format on
1583
1584 // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel
1585 // capability is available.
1586 patterns.add<CmpFOpNanKernelPattern>(typeConverter, patterns.getContext(),
1587 /*benefit=*/2);
1588}
1589
1590//===----------------------------------------------------------------------===//
1591// Pass Definition
1592//===----------------------------------------------------------------------===//
1593
1594namespace {
/// Pass that rewrites `arith` dialect ops into the SPIR-V dialect using
/// arith::populateArithToSPIRVPatterns.
///
/// The conversion is partial. `arith` ops are marked illegal, so any that
/// survive make the pass fail. UnrealizedConversionCastOp is explicitly legal
/// so that values flowing to or from not-yet-converted dialects can be bridged.
struct ConvertArithToSPIRVPass
    : public impl::ConvertArithToSPIRVPassBase<ConvertArithToSPIRVPass> {
  using Base::Base;

  void runOnOperation() override {
    Operation *op = getOperation();
    // Conversion target for the SPIR-V target environment `targetAttr`.
    std::unique_ptr<SPIRVConversionTarget> target =
        SPIRVConversionTarget::get(targetAttr);

    // Forward the pass's emulation flags into the type-converter options.
    options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
    options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
    SPIRVTypeConverter typeConverter(targetAttr, options);

    // Use UnrealizedConversionCast as the bridge so that we don't need to pull
    // in patterns for other dialects.
    target->addLegalOp<UnrealizedConversionCastOp>();

    // Fail hard when there are any remaining 'arith' ops.
    target->addIllegalDialect<arith::ArithDialect>();

    RewritePatternSet patterns(&getContext());
    arith::populateArithToSPIRVPatterns(typeConverter, patterns);

    // Any failure to legalize (e.g. an 'arith' op no pattern could convert)
    // fails the pass.
    if (failed(applyPartialConversion(op, *target, std::move(patterns))))
      signalPassFailure();
  }
};
1624} // namespace
return success()
static bool hasSameBitwidth(Type a, Type b)
Returns true if scalar/vector type a and b have the same number of bitwidth.
static Value getScalarOrVectorConstInt(Type type, uint64_t value, OpBuilder &builder, Location loc)
Creates a scalar/vector integer constant.
static LogicalResult getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op, Type srcType)
Returns a source type conversion failure for srcType and operation op.
static IntegerAttr getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType, ConversionPatternRewriter &rewriter)
static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder)
Converts the given srcAttr into a boolean attribute if it holds an integral value.
static bool isBoolScalarOrVector(Type type)
Returns true if the given type is a boolean scalar or vector type.
#define DISPATCH(cmpPredicate, spirvOp)
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
static llvm::ManagedStatic< PassManagerOptions > options
This class represents a processed binary blob of data.
Definition AsmState.h:91
ArrayRef< char > getData() const
Return the raw underlying data of this blob.
Definition AsmState.h:145
Attributes are known-constant values of operations.
Definition Attributes.h:25
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
BoolAttr getBoolAttr(bool value)
Definition Builders.cpp:104
FloatAttr getF32FloatAttr(float value)
Definition Builders.cpp:250
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
static bool isValidRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Returns true if the given buffer is a valid raw buffer for the given type.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
DenseElementsAttr reshape(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but has been reshaped...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:209
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
result_type_range getResultTypes()
Definition Operation.h:454
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:430
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:56
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition Types.cpp:122
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
Type front()
Return first type in the range.
Definition TypeRange.h:164
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
An attribute that specifies the target version, allowed extensions and capabilities,...
NestedPattern Op(FilterFunctionType filter=defaultFilterFunction)
void populateArithToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
std::string getDecorationString(Decoration decoration)
Converts a SPIR-V Decoration enum value to its snake_case string representation for use in MLIR attri...
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.
Definition Pattern.h:24