MLIR 23.0.0git
ArithToSPIRV.cpp
Go to the documentation of this file.
1//===- ArithToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
21#include "llvm/ADT/APInt.h"
22#include "llvm/ADT/ArrayRef.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/Support/Debug.h"
25#include "llvm/Support/MathExtras.h"
26#include <cassert>
27#include <memory>
28
29namespace mlir {
30#define GEN_PASS_DEF_CONVERTARITHTOSPIRVPASS
31#include "mlir/Conversion/Passes.h.inc"
32} // namespace mlir
33
34#define DEBUG_TYPE "arith-to-spirv-pattern"
35
36using namespace mlir;
37
38//===----------------------------------------------------------------------===//
39// Conversion Helpers
40//===----------------------------------------------------------------------===//
41
42/// Converts the given `srcAttr` into a boolean attribute if it holds an
43/// integral value. Returns null attribute if conversion fails.
44static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) {
45 if (auto boolAttr = dyn_cast<BoolAttr>(srcAttr))
46 return boolAttr;
47 if (auto intAttr = dyn_cast<IntegerAttr>(srcAttr))
48 return builder.getBoolAttr(intAttr.getValue().getBoolValue());
49 return {};
50}
51
52/// Converts the given `srcAttr` to a new attribute of the given `dstType`.
53/// Returns null attribute if conversion fails.
54static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType,
55 Builder builder) {
56 // If the source number uses less active bits than the target bitwidth, then
57 // it should be safe to convert.
58 if (srcAttr.getValue().isIntN(dstType.getWidth()))
59 return builder.getIntegerAttr(dstType, srcAttr.getInt());
60
61 // XXX: Try again by interpreting the source number as a signed value.
62 // Although integers in the standard dialect are signless, they can represent
63 // a signed number. It's the operation decides how to interpret. This is
64 // dangerous, but it seems there is no good way of handling this if we still
65 // want to change the bitwidth. Emit a message at least.
66 if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
67 auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt());
68 LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '"
69 << dstAttr << "' for type '" << dstType << "'\n");
70 return dstAttr;
71 }
72
73 LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr
74 << "' illegal: cannot fit into target type '"
75 << dstType << "'\n");
76 return {};
77}
78
79/// Converts the given `srcAttr` to a new attribute of the given `dstType`.
80/// Returns null attribute if `dstType` is not 32-bit or conversion fails.
81static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
82 Builder builder) {
83 // Only support converting to float for now.
84 if (!dstType.isF32())
85 return FloatAttr();
86
87 // Try to convert the source floating-point number to single precision.
88 APFloat dstVal = srcAttr.getValue();
89 bool losesInfo = false;
90 APFloat::opStatus status =
91 dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
92 if (status != APFloat::opOK || losesInfo) {
93 LLVM_DEBUG(llvm::dbgs()
94 << srcAttr << " illegal: cannot fit into converted type '"
95 << dstType << "'\n");
96 return FloatAttr();
97 }
98
99 return builder.getF32FloatAttr(dstVal.convertToFloat());
100}
101
102// Get in IntegerAttr from FloatAttr while preserving the bits.
103// Useful for converting float constants to integer constants while preserving
104// the bits.
105static IntegerAttr
106getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType,
107 ConversionPatternRewriter &rewriter) {
108 APFloat floatVal = floatAttr.getValue();
109 APInt intVal = floatVal.bitcastToAPInt();
110 return rewriter.getIntegerAttr(dstType, intVal);
111}
112
113/// Returns true if the given `type` is a boolean scalar or vector type.
114static bool isBoolScalarOrVector(Type type) {
115 assert(type && "Not a valid type");
116 if (type.isInteger(1))
117 return true;
118
119 if (auto vecType = dyn_cast<VectorType>(type))
120 return vecType.getElementType().isInteger(1);
121
122 return false;
123}
124
125/// Creates a scalar/vector integer constant.
126static Value getScalarOrVectorConstInt(Type type, uint64_t value,
127 OpBuilder &builder, Location loc) {
128 if (auto vectorType = dyn_cast<VectorType>(type)) {
129 Attribute element = IntegerAttr::get(vectorType.getElementType(), value);
130 auto attr = SplatElementsAttr::get(vectorType, element);
131 return spirv::ConstantOp::create(builder, loc, vectorType, attr);
132 }
133
134 if (auto intType = dyn_cast<IntegerType>(type))
135 return spirv::ConstantOp::create(builder, loc, type,
136 builder.getIntegerAttr(type, value));
137
138 return nullptr;
139}
140
141/// Returns true if scalar/vector type `a` and `b` have the same number of
142/// bitwidth.
143static bool hasSameBitwidth(Type a, Type b) {
144 auto getNumBitwidth = [](Type type) {
145 unsigned bw = 0;
146 if (type.isIntOrFloat())
147 bw = type.getIntOrFloatBitWidth();
148 else if (auto vecType = dyn_cast<VectorType>(type))
149 bw = vecType.getElementTypeBitWidth() * vecType.getNumElements();
150 return bw;
151 };
152 unsigned aBW = getNumBitwidth(a);
153 unsigned bBW = getNumBitwidth(b);
154 return aBW != 0 && bBW != 0 && aBW == bBW;
155}
156
157/// Returns a source type conversion failure for `srcType` and operation `op`.
158static LogicalResult
159getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op,
160 Type srcType) {
161 return rewriter.notifyMatchFailure(
162 op->getLoc(),
163 llvm::formatv("failed to convert source type '{0}'", srcType));
164}
165
166/// Returns a source type conversion failure for the result type of `op`.
167static LogicalResult
168getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op) {
169 assert(op->getNumResults() == 1);
170 return getTypeConversionFailure(rewriter, op, op->getResultTypes().front());
171}
172
173namespace {
174
175/// Converts elementwise unary, binary and ternary arith operations to SPIR-V
176/// operations. Op can potentially support overflow flags.
177template <typename Op, typename SPIRVOp>
178struct ElementwiseArithOpPattern final : OpConversionPattern<Op> {
179 using OpConversionPattern<Op>::OpConversionPattern;
180
181 LogicalResult
182 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
183 ConversionPatternRewriter &rewriter) const override {
184 assert(adaptor.getOperands().size() <= 3);
185 // Reject boolean types to allow specialized boolean patterns to handle
186 // them (e.g., addi/subi on i1 should use LogicalNotEqual, not IAdd/ISub).
187 if (!adaptor.getOperands().empty() &&
188 isBoolScalarOrVector(adaptor.getOperands().front().getType()))
189 return failure();
190 auto converter = this->template getTypeConverter<SPIRVTypeConverter>();
191 Type dstType = converter->convertType(op.getType());
192 if (!dstType) {
193 return rewriter.notifyMatchFailure(
194 op->getLoc(),
195 llvm::formatv("failed to convert type {0} for SPIR-V", op.getType()));
196 }
197
198 if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
199 !getElementTypeOrSelf(op.getType()).isIndex() &&
200 dstType != op.getType()) {
201 return op.emitError("bitwidth emulation is not implemented yet on "
202 "unsigned op pattern version");
203 }
204
205 auto overflowFlags = arith::IntegerOverflowFlags::none;
206 if (auto overflowIface =
207 dyn_cast<arith::ArithIntegerOverflowFlagsInterface>(*op)) {
208 if (converter->getTargetEnv().allows(
209 spirv::Extension::SPV_KHR_no_integer_wrap_decoration))
210 overflowFlags = overflowIface.getOverflowAttr().getValue();
211 }
212
213 auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
214 op, dstType, adaptor.getOperands());
215
216 if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nsw))
217 newOp->setAttr(getDecorationString(spirv::Decoration::NoSignedWrap),
218 rewriter.getUnitAttr());
219
220 if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nuw))
221 newOp->setAttr(getDecorationString(spirv::Decoration::NoUnsignedWrap),
222 rewriter.getUnitAttr());
223
224 return success();
225 }
226};
227
228//===----------------------------------------------------------------------===//
229// ConstantOp
230//===----------------------------------------------------------------------===//
231
232/// Converts composite arith.constant operation to spirv.Constant.
233struct ConstantCompositeOpPattern final
234 : public OpConversionPattern<arith::ConstantOp> {
235 using Base::Base;
236
237 LogicalResult
238 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
239 ConversionPatternRewriter &rewriter) const override {
240 auto srcType = dyn_cast<ShapedType>(constOp.getType());
241 if (!srcType || srcType.getNumElements() == 1)
242 return failure();
243
244 // arith.constant should only have vector or tensor types. This is a MLIR
245 // wide problem at the moment.
246 if (!isa<VectorType, RankedTensorType>(srcType))
247 return rewriter.notifyMatchFailure(constOp, "unsupported ShapedType");
248
249 Type dstType = getTypeConverter()->convertType(srcType);
250 if (!dstType)
251 return failure();
252
253 // Import the resource into the IR to make use of the special handling of
254 // element types later on.
255 mlir::DenseElementsAttr dstElementsAttr;
256 if (auto denseElementsAttr =
257 dyn_cast<DenseElementsAttr>(constOp.getValue())) {
258 dstElementsAttr = denseElementsAttr;
259 } else if (auto resourceAttr =
260 dyn_cast<DenseResourceElementsAttr>(constOp.getValue())) {
261
262 AsmResourceBlob *blob = resourceAttr.getRawHandle().getBlob();
263 if (!blob)
264 return constOp->emitError("could not find resource blob");
265
266 ArrayRef<char> ptr = blob->getData();
267
268 // Check that the buffer meets the requirements to get converted to a
269 // DenseElementsAttr
271 return constOp->emitError("resource is not a valid buffer");
272
273 dstElementsAttr =
274 DenseElementsAttr::getFromRawBuffer(resourceAttr.getType(), ptr);
275 } else {
276 return constOp->emitError("unsupported elements attribute");
277 }
278
279 ShapedType dstAttrType = dstElementsAttr.getType();
280
281 // If the composite type has more than one dimensions, perform
282 // linearization.
283 if (srcType.getRank() > 1) {
284 if (isa<RankedTensorType>(srcType)) {
285 dstAttrType = RankedTensorType::get(srcType.getNumElements(),
286 srcType.getElementType());
287 dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
288 } else {
289 // TODO: add support for large vectors.
290 return failure();
291 }
292 }
293
294 Type srcElemType = srcType.getElementType();
295 Type dstElemType;
296 // Tensor types are converted to SPIR-V array types; vector types are
297 // converted to SPIR-V vector/array types.
298 if (auto arrayType = dyn_cast<spirv::ArrayType>(dstType))
299 dstElemType = arrayType.getElementType();
300 else
301 dstElemType = cast<VectorType>(dstType).getElementType();
302
303 // If the source and destination element types are different, perform
304 // attribute conversion.
305 if (srcElemType != dstElemType) {
307 if (isa<FloatType>(srcElemType)) {
308 for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
309 Attribute dstAttr = nullptr;
310 // Handle 8-bit float conversion to 8-bit integer.
311 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
312 if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
313 srcElemType.getIntOrFloatBitWidth() == 8 &&
314 isa<IntegerType>(dstElemType)) {
315 dstAttr =
316 getIntegerAttrFromFloatAttr(srcAttr, dstElemType, rewriter);
317 } else {
318 dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstElemType),
319 rewriter);
320 }
321 if (!dstAttr)
322 return failure();
323 elements.push_back(dstAttr);
324 }
325 } else if (srcElemType.isInteger(1)) {
326 return failure();
327 } else {
328 for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) {
329 IntegerAttr dstAttr = convertIntegerAttr(
330 srcAttr, cast<IntegerType>(dstElemType), rewriter);
331 if (!dstAttr)
332 return failure();
333 elements.push_back(dstAttr);
334 }
335 }
336
337 // Unfortunately, we cannot use dialect-specific types for element
338 // attributes; element attributes only works with builtin types. So we
339 // need to prepare another converted builtin types for the destination
340 // elements attribute.
341 if (isa<RankedTensorType>(dstAttrType))
342 dstAttrType =
343 RankedTensorType::get(dstAttrType.getShape(), dstElemType);
344 else
345 dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
346
347 dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
348 }
349
350 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
351 dstElementsAttr);
352 return success();
353 }
354};
355
356/// Converts scalar arith.constant operation to spirv.Constant.
357struct ConstantScalarOpPattern final
358 : public OpConversionPattern<arith::ConstantOp> {
359 using Base::Base;
360
361 LogicalResult
362 matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
363 ConversionPatternRewriter &rewriter) const override {
364 Type srcType = constOp.getType();
365 if (auto shapedType = dyn_cast<ShapedType>(srcType)) {
366 if (shapedType.getNumElements() != 1)
367 return failure();
368 srcType = shapedType.getElementType();
369 }
370 if (!srcType.isIntOrIndexOrFloat())
371 return failure();
372
373 Attribute cstAttr = constOp.getValue();
374 if (auto elementsAttr = dyn_cast<DenseElementsAttr>(cstAttr))
375 cstAttr = elementsAttr.getSplatValue<Attribute>();
376
377 Type dstType = getTypeConverter()->convertType(srcType);
378 if (!dstType)
379 return failure();
380
381 // Floating-point types.
382 if (isa<FloatType>(srcType)) {
383 auto srcAttr = cast<FloatAttr>(cstAttr);
384 Attribute dstAttr = srcAttr;
385
386 // Floating-point types not supported in the target environment are all
387 // converted to float type.
388 auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
389 if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
390 srcType.getIntOrFloatBitWidth() == 8 && isa<IntegerType>(dstType) &&
391 dstType.getIntOrFloatBitWidth() == 8) {
392 // If the source is an 8-bit float, convert it to a 8-bit integer.
393 dstAttr = getIntegerAttrFromFloatAttr(srcAttr, dstType, rewriter);
394 if (!dstAttr)
395 return failure();
396 } else if (srcType != dstType) {
397 dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter);
398 if (!dstAttr)
399 return failure();
400 }
401
402 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
403 return success();
404 }
405
406 // Bool type.
407 if (srcType.isInteger(1)) {
408 // arith.constant can use 0/1 instead of true/false for i1 values. We need
409 // to handle that here.
410 auto dstAttr = convertBoolAttr(cstAttr, rewriter);
411 if (!dstAttr)
412 return failure();
413 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
414 return success();
415 }
416
417 // IndexType or IntegerType. Index values are converted to 32-bit integer
418 // values when converting to SPIR-V.
419 auto srcAttr = cast<IntegerAttr>(cstAttr);
420 IntegerAttr dstAttr =
421 convertIntegerAttr(srcAttr, cast<IntegerType>(dstType), rewriter);
422 if (!dstAttr)
423 return failure();
424 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
425 return success();
426 }
427};
428
429//===----------------------------------------------------------------------===//
430// RemSIOp
431//===----------------------------------------------------------------------===//
432
433/// Returns signed remainder for `lhs` and `rhs` and lets the result follow
434/// the sign of `signOperand`.
435///
436/// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment
437/// spec, "for the OpSRem and OpSMod instructions, if either operand is negative
438/// the result is undefined." So we cannot directly use spirv.SRem/spirv.SMod
439/// if either operand can be negative. Emulate it via spirv.UMod.
440template <typename SignedAbsOp>
441static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
442 Value signOperand, OpBuilder &builder) {
443 assert(lhs.getType() == rhs.getType());
444 assert(lhs == signOperand || rhs == signOperand);
445
446 Type type = lhs.getType();
447
448 // Calculate the remainder with spirv.UMod.
449 Value lhsAbs = SignedAbsOp::create(builder, loc, type, lhs);
450 Value rhsAbs = SignedAbsOp::create(builder, loc, type, rhs);
451 Value abs = spirv::UModOp::create(builder, loc, lhsAbs, rhsAbs);
452
453 // Fix the sign.
454 Value isPositive;
455 if (lhs == signOperand)
456 isPositive = spirv::IEqualOp::create(builder, loc, lhs, lhsAbs);
457 else
458 isPositive = spirv::IEqualOp::create(builder, loc, rhs, rhsAbs);
459 Value absNegate = spirv::SNegateOp::create(builder, loc, type, abs);
460 return spirv::SelectOp::create(builder, loc, type, isPositive, abs,
461 absNegate);
462}
463
464/// Converts arith.remsi to GLSL SPIR-V ops.
465///
466/// This cannot be merged into the template unary/binary pattern due to Vulkan
467/// restrictions over spirv.SRem and spirv.SMod.
468struct RemSIOpGLPattern final : public OpConversionPattern<arith::RemSIOp> {
469 using Base::Base;
470
471 LogicalResult
472 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
473 ConversionPatternRewriter &rewriter) const override {
474 Value result = emulateSignedRemainder<spirv::GLSAbsOp>(
475 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
476 adaptor.getOperands()[0], rewriter);
477 rewriter.replaceOp(op, result);
478
479 return success();
480 }
481};
482
483/// Converts arith.remsi to OpenCL SPIR-V ops.
484struct RemSIOpCLPattern final : public OpConversionPattern<arith::RemSIOp> {
485 using Base::Base;
486
487 LogicalResult
488 matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
489 ConversionPatternRewriter &rewriter) const override {
490 Value result = emulateSignedRemainder<spirv::CLSAbsOp>(
491 op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
492 adaptor.getOperands()[0], rewriter);
493 rewriter.replaceOp(op, result);
494
495 return success();
496 }
497};
498
499//===----------------------------------------------------------------------===//
500// BitwiseOp
501//===----------------------------------------------------------------------===//
502
503/// Converts bitwise operations to SPIR-V operations. This is a special pattern
504/// other than the BinaryOpPatternPattern because if the operands are boolean
505/// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For
506/// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`.
507template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
508struct BitwiseOpPattern final : public OpConversionPattern<Op> {
509 using OpConversionPattern<Op>::OpConversionPattern;
510
511 LogicalResult
512 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
513 ConversionPatternRewriter &rewriter) const override {
514 assert(adaptor.getOperands().size() == 2);
515 Type dstType = this->getTypeConverter()->convertType(op.getType());
516 if (!dstType)
517 return getTypeConversionFailure(rewriter, op);
518
519 if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) {
520 rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(
521 op, dstType, adaptor.getOperands());
522 } else {
523 rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(
524 op, dstType, adaptor.getOperands());
525 }
526 return success();
527 }
528};
529
530//===----------------------------------------------------------------------===//
531// XOrIOp
532//===----------------------------------------------------------------------===//
533
534/// Converts arith.xori to SPIR-V operations.
535struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> {
536 using Base::Base;
537
538 LogicalResult
539 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
540 ConversionPatternRewriter &rewriter) const override {
541 assert(adaptor.getOperands().size() == 2);
542
543 if (isBoolScalarOrVector(adaptor.getOperands().front().getType()))
544 return failure();
545
546 Type dstType = getTypeConverter()->convertType(op.getType());
547 if (!dstType)
548 return getTypeConversionFailure(rewriter, op);
549
550 rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
551 adaptor.getOperands());
552
553 return success();
554 }
555};
556
557/// Converts arith.xori to SPIR-V operations if the type of source is i1 or
558/// vector of i1.
559struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
560 using Base::Base;
561
562 LogicalResult
563 matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
564 ConversionPatternRewriter &rewriter) const override {
565 assert(adaptor.getOperands().size() == 2);
566
567 if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
568 return failure();
569
570 Type dstType = getTypeConverter()->convertType(op.getType());
571 if (!dstType)
572 return getTypeConversionFailure(rewriter, op);
573
574 rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
575 op, dstType, adaptor.getOperands());
576 return success();
577 }
578};
579
580/// Converts an arith integer op to the given SPIR-V boolean op if the type is
581/// i1 or vector of i1. Each mapping follows from the boolean truth table of
582/// the operation:
583/// addi(a, b) = a ^ b (add mod 2 = XOR = LogicalNotEqual)
584/// subi(a, b) = a ^ b (sub mod 2 = XOR = LogicalNotEqual)
585/// muli(a, b) = a & b (1*1=1, else 0 = LogicalAnd)
586/// divui(a, b) = a & b (a/1=a, a/0=UB; truth table matches AND)
587/// divsi(a, b) = a & b (same as divui on i1)
588/// maxsi(a, b) = a & b (signed i1: 1 represents -1, so max is 0 unless both
589/// are 1)
590/// maxui(a, b) = a | b (unsigned max on i1: 1 when either operand is 1)
591/// minsi(a, b) = a | b (signed i1: -1 < 0, so min is 1 when either operand
592/// is 1)
593/// minui(a, b) = a & b (unsigned min on i1: 1 only when both operands are
594/// 1)
595template <typename ArithOp, typename SPIRVOp>
596struct BoolIOpPattern final : public OpConversionPattern<ArithOp> {
597 BoolIOpPattern(const TypeConverter &converter, MLIRContext *context)
598 // benefit=2: takes priority over the generic ElementwiseArithOpPattern
599 // (benefit=1) when the operand type is i1.
600 : OpConversionPattern<ArithOp>(converter, context, /*benefit=*/2) {}
601
602 LogicalResult
603 matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
604 ConversionPatternRewriter &rewriter) const override {
605 if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
606 return failure();
607
608 Type dstType = this->getTypeConverter()->convertType(op.getType());
609 if (!dstType)
610 return getTypeConversionFailure(rewriter, op);
611
612 rewriter.replaceOpWithNewOp<SPIRVOp>(op, dstType, adaptor.getOperands());
613 return success();
614 }
615};
616
617/// Converts an arith binary op on i1 to spirv.LogicalAnd(lhs,
618/// spirv.LogicalNot(rhs)). This covers shift-left, shift-right-unsigned, and
619/// unsigned remainder on i1:
620/// shli(a, b) = a & ~b (shift left clears the bit when b=1)
621/// shrui(a, b) = a & ~b (shift right unsigned clears the bit when b=1)
622/// remui(a, b) = a & ~b (only defined when b=1; a%1=0, and ~b=~1=0, so AND
623/// gives 0)
624/// remsi(a, b) = a & ~b (only defined when b=1; a%1=0, and ~b=~1=0, so AND
625/// gives 0)
626template <typename ArithOp>
627struct BoolIOpAndNotPattern final : public OpConversionPattern<ArithOp> {
628 BoolIOpAndNotPattern(const TypeConverter &converter, MLIRContext *context)
629 // benefit=2: takes priority over the generic ElementwiseArithOpPattern
630 // (benefit=1) when the operand type is i1.
631 : OpConversionPattern<ArithOp>(converter, context, /*benefit=*/2) {}
632
633 LogicalResult
634 matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
635 ConversionPatternRewriter &rewriter) const override {
636 if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
637 return failure();
638
639 Type dstType = this->getTypeConverter()->convertType(op.getType());
640 if (!dstType)
641 return getTypeConversionFailure(rewriter, op);
642
643 Location loc = op.getLoc();
644 Value notRhs = spirv::LogicalNotOp::create(rewriter, loc, dstType,
645 adaptor.getOperands()[1]);
646 rewriter.replaceOpWithNewOp<spirv::LogicalAndOp>(
647 op, dstType, adaptor.getOperands()[0], notRhs);
648 return success();
649 }
650};
651
652/// Converts arith.shrsi on i1 to identity: arithmetic right shift of a 1-bit
653/// signed value always yields the original value (0 >> n = 0, -1 >> n = -1).
654struct ShRSIBoolPattern final : public OpConversionPattern<arith::ShRSIOp> {
655 ShRSIBoolPattern(const TypeConverter &converter, MLIRContext *context)
656 // benefit=2: takes priority over the generic spirv::ElementwiseOpPattern
657 // (benefit=1) when the operand type is i1.
658 : OpConversionPattern<arith::ShRSIOp>(converter, context,
659 /*benefit=*/2) {}
660
661 LogicalResult
662 matchAndRewrite(arith::ShRSIOp op, OpAdaptor adaptor,
663 ConversionPatternRewriter &rewriter) const override {
664 if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
665 return failure();
666
667 rewriter.replaceOp(op, adaptor.getOperands().front());
668 return success();
669 }
670};
671
672//===----------------------------------------------------------------------===//
673// i1 source to value
674//===----------------------------------------------------------------------===//
675
676/// Converts an op whose i1 (or vector of i1) source selects between one and
677/// zero of the destination type, i.e. spirv.Select(src, one, zero). Shared by
678/// arith.uitofp, arith.extui, and arith.index_cast on boolean sources.
679template <typename ArithOp>
680struct BoolToValuePattern final : public OpConversionPattern<ArithOp> {
681 using OpConversionPattern<ArithOp>::OpConversionPattern;
682
683 LogicalResult
684 matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
685 ConversionPatternRewriter &rewriter) const override {
686 Type srcType = adaptor.getOperands().front().getType();
687 if (!isBoolScalarOrVector(srcType))
688 return failure();
689
690 Type dstType = this->getTypeConverter()->convertType(op.getType());
691 if (!dstType)
692 return getTypeConversionFailure(rewriter, op);
693
694 Location loc = op.getLoc();
695 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
696 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
697 rewriter.replaceOpWithNewOp<spirv::SelectOp>(
698 op, dstType, adaptor.getOperands().front(), one, zero);
699 return success();
700 }
701};
702
703//===----------------------------------------------------------------------===//
704// UIToFPOp
705//===----------------------------------------------------------------------===//
706
707/// Converts arith.uitofp/arith.sitofp to spirv.ConvertUToF/spirv.ConvertSToF.
708/// When the source integer type was widened during type conversion (e.g., i8
709/// emulated as i32), the upper bits of the widened value may contain garbage.
710/// This pattern cleans the upper bits before the conversion:
711/// - For unsigned (IsSigned=false): mask with BitwiseAnd.
712/// - For signed (IsSigned=true): sign-extend via ShiftLeftLogical +
713/// ShiftRightArithmetic.
714template <typename ArithOp, typename SPIRVOp, bool IsSigned>
715struct IntToFPPattern final : public OpConversionPattern<ArithOp> {
716 using OpConversionPattern<ArithOp>::OpConversionPattern;
717
718 LogicalResult
719 matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
720 ConversionPatternRewriter &rewriter) const override {
721 Type srcType = adaptor.getOperands().front().getType();
722 if (isBoolScalarOrVector(srcType))
723 return failure();
724
725 Type dstType = this->getTypeConverter()->convertType(op.getType());
726 if (!dstType)
727 return getTypeConversionFailure(rewriter, op);
728
729 // Check if the source integer type was widened during type conversion.
730 unsigned originalBitwidth =
731 getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth();
732 unsigned convertedBitwidth =
734
735 if (originalBitwidth >= convertedBitwidth) {
736 rewriter.replaceOpWithNewOp<SPIRVOp>(op, dstType, adaptor.getOperands());
737 return success();
738 }
739
740 // The source was widened. Clean the upper bits before converting.
741 Location loc = op.getLoc();
742 Value cleaned;
743 if constexpr (IsSigned) {
744 // Sign-extend by shifting left then arithmetic right.
745 unsigned shiftAmount = convertedBitwidth - originalBitwidth;
746 Value shiftSize =
747 getScalarOrVectorConstInt(srcType, shiftAmount, rewriter, loc);
748 Value shifted = spirv::ShiftLeftLogicalOp::create(
749 rewriter, loc, srcType, adaptor.getIn(), shiftSize);
750 cleaned = spirv::ShiftRightArithmeticOp::create(rewriter, loc, srcType,
751 shifted, shiftSize);
752 } else {
753 // Zero-extend by masking off the upper bits.
754 Value mask = getScalarOrVectorConstInt(
755 srcType, llvm::maskTrailingOnes<uint64_t>(originalBitwidth), rewriter,
756 loc);
757 cleaned = spirv::BitwiseAndOp::create(rewriter, loc, srcType,
758 adaptor.getIn(), mask);
759 }
760 rewriter.replaceOpWithNewOp<SPIRVOp>(op, dstType, cleaned);
761 return success();
762 }
763};
764
765//===----------------------------------------------------------------------===//
766// IndexCastOp
767//===----------------------------------------------------------------------===//
768
769/// Converts arith.index_cast to spirv.INotEqual if the target type is i1.
770struct IndexCastIndexI1Pattern final
771 : public OpConversionPattern<arith::IndexCastOp> {
772 using Base::Base;
773
774 LogicalResult
775 matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
776 ConversionPatternRewriter &rewriter) const override {
777 if (!isBoolScalarOrVector(op.getType()))
778 return failure();
779
780 Type dstType = getTypeConverter()->convertType(op.getType());
781 if (!dstType)
782 return getTypeConversionFailure(rewriter, op);
783
784 Location loc = op.getLoc();
785 Value zeroIdx =
786 spirv::ConstantOp::getZero(adaptor.getIn().getType(), loc, rewriter);
787 rewriter.replaceOpWithNewOp<spirv::INotEqualOp>(op, dstType, zeroIdx,
788 adaptor.getIn());
789 return success();
790 }
791};
792
793//===----------------------------------------------------------------------===//
794// ExtSIOp
795//===----------------------------------------------------------------------===//
796
797/// Converts arith.extsi to spirv.Select if the type of source is i1 or vector
798/// of i1.
799struct ExtSII1Pattern final : public OpConversionPattern<arith::ExtSIOp> {
800 using Base::Base;
801
802 LogicalResult
803 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
804 ConversionPatternRewriter &rewriter) const override {
805 Value operand = adaptor.getIn();
806 if (!isBoolScalarOrVector(operand.getType()))
807 return failure();
808
809 Location loc = op.getLoc();
810 Type dstType = getTypeConverter()->convertType(op.getType());
811 if (!dstType)
812 return getTypeConversionFailure(rewriter, op);
813
814 Value allOnes;
815 if (auto intTy = dyn_cast<IntegerType>(dstType)) {
816 unsigned componentBitwidth = intTy.getWidth();
817 allOnes = spirv::ConstantOp::create(
818 rewriter, loc, intTy,
819 rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth)));
820 } else if (auto vectorTy = dyn_cast<VectorType>(dstType)) {
821 unsigned componentBitwidth = vectorTy.getElementTypeBitWidth();
822 allOnes = spirv::ConstantOp::create(
823 rewriter, loc, vectorTy,
824 SplatElementsAttr::get(vectorTy,
825 APInt::getAllOnes(componentBitwidth)));
826 } else {
827 return rewriter.notifyMatchFailure(
828 loc, llvm::formatv("unhandled type: {0}", dstType));
829 }
830
831 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
832 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, operand, allOnes,
833 zero);
834 return success();
835 }
836};
837
838/// Converts arith.extsi to spirv.Select if the type of source is neither i1 nor
839/// vector of i1.
840struct ExtSIPattern final : public OpConversionPattern<arith::ExtSIOp> {
841 using Base::Base;
842
843 LogicalResult
844 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
845 ConversionPatternRewriter &rewriter) const override {
846 Type srcType = adaptor.getIn().getType();
847 if (isBoolScalarOrVector(srcType))
848 return failure();
849
850 Type dstType = getTypeConverter()->convertType(op.getType());
851 if (!dstType)
852 return getTypeConversionFailure(rewriter, op);
853
854 if (dstType == srcType) {
855 // We can have the same source and destination type due to type emulation.
856 // Perform bit shifting to make sure we have the proper leading set bits.
857
858 unsigned srcBW =
859 getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth();
860 unsigned dstBW =
862 assert(srcBW < dstBW);
863 Value shiftSize = getScalarOrVectorConstInt(dstType, dstBW - srcBW,
864 rewriter, op.getLoc());
865 if (!shiftSize)
866 return rewriter.notifyMatchFailure(op, "unsupported type for shift");
867
868 // First shift left to sequeeze out all leading bits beyond the original
869 // bitwidth. Here we need to use the original source and result type's
870 // bitwidth.
871 auto shiftLOp = spirv::ShiftLeftLogicalOp::create(
872 rewriter, op.getLoc(), dstType, adaptor.getIn(), shiftSize);
873
874 // Then we perform arithmetic right shift to make sure we have the right
875 // sign bits for negative values.
876 rewriter.replaceOpWithNewOp<spirv::ShiftRightArithmeticOp>(
877 op, dstType, shiftLOp, shiftSize);
878 } else {
879 rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
880 adaptor.getOperands());
881 }
882
883 return success();
884 }
885};
886
887//===----------------------------------------------------------------------===//
888// ExtUIOp
889//===----------------------------------------------------------------------===//
890
891/// Converts arith.extui for cases where the type of source is neither i1 nor
892/// vector of i1.
893struct ExtUIPattern final : public OpConversionPattern<arith::ExtUIOp> {
894 using Base::Base;
895
896 LogicalResult
897 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
898 ConversionPatternRewriter &rewriter) const override {
899 Type srcType = adaptor.getIn().getType();
900 if (isBoolScalarOrVector(srcType))
901 return failure();
902
903 Type dstType = getTypeConverter()->convertType(op.getType());
904 if (!dstType)
905 return getTypeConversionFailure(rewriter, op);
906
907 if (dstType == srcType) {
908 // We can have the same source and destination type due to type emulation.
909 // Perform bit masking to make sure we don't pollute downstream consumers
910 // with unwanted bits. Here we need to use the original source type's
911 // bitwidth.
912 unsigned bitwidth =
913 getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth();
914 Value mask = getScalarOrVectorConstInt(
915 dstType, llvm::maskTrailingOnes<uint64_t>(bitwidth), rewriter,
916 op.getLoc());
917 if (!mask)
918 return rewriter.notifyMatchFailure(op, "unsupported type for mask");
919 rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
920 adaptor.getIn(), mask);
921 } else {
922 rewriter.replaceOpWithNewOp<spirv::UConvertOp>(op, dstType,
923 adaptor.getOperands());
924 }
925 return success();
926 }
927};
928
929//===----------------------------------------------------------------------===//
930// TruncIOp
931//===----------------------------------------------------------------------===//
932
933/// Converts arith.trunci to spirv.Select if the type of result is i1 or vector
934/// of i1.
935struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
936 using Base::Base;
937
938 LogicalResult
939 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
940 ConversionPatternRewriter &rewriter) const override {
941 Type dstType = getTypeConverter()->convertType(op.getType());
942 if (!dstType)
943 return getTypeConversionFailure(rewriter, op);
944
945 if (!isBoolScalarOrVector(dstType))
946 return failure();
947
948 Location loc = op.getLoc();
949 auto srcType = adaptor.getOperands().front().getType();
950 // Check if (x & 1) == 1.
951 Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
952 Value maskedSrc = spirv::BitwiseAndOp::create(
953 rewriter, loc, srcType, adaptor.getOperands()[0], mask);
954 Value isOne = spirv::IEqualOp::create(rewriter, loc, maskedSrc, mask);
955
956 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
957 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
958 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero);
959 return success();
960 }
961};
962
963/// Converts arith.trunci for cases where the type of result is neither i1
964/// nor vector of i1.
965struct TruncIPattern final : public OpConversionPattern<arith::TruncIOp> {
966 using Base::Base;
967
968 LogicalResult
969 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
970 ConversionPatternRewriter &rewriter) const override {
971 Type srcType = adaptor.getIn().getType();
972 Type dstType = getTypeConverter()->convertType(op.getType());
973 if (!dstType)
974 return getTypeConversionFailure(rewriter, op);
975
976 if (isBoolScalarOrVector(dstType))
977 return failure();
978
979 if (dstType == srcType) {
980 // We can have the same source and destination type due to type emulation.
981 // Perform bit masking to make sure we don't pollute downstream consumers
982 // with unwanted bits. Here we need to use the original result type's
983 // bitwidth.
984 unsigned bw = getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth();
985 Value mask = getScalarOrVectorConstInt(
986 dstType, llvm::maskTrailingOnes<uint64_t>(bw), rewriter, op.getLoc());
987 if (!mask)
988 return rewriter.notifyMatchFailure(op, "unsupported type for mask");
989 rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
990 adaptor.getIn(), mask);
991 } else {
992 // Given this is truncation, either SConvertOp or UConvertOp works.
993 rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
994 adaptor.getOperands());
995 }
996 return success();
997 }
998};
999
1000//===----------------------------------------------------------------------===//
1001// TypeCastingOp
1002//===----------------------------------------------------------------------===//
1003
1004static std::optional<spirv::FPRoundingMode>
1005convertArithRoundingModeToSPIRV(arith::RoundingMode roundingMode) {
1006 switch (roundingMode) {
1007 case arith::RoundingMode::downward:
1008 return spirv::FPRoundingMode::RTN;
1009 case arith::RoundingMode::to_nearest_even:
1010 return spirv::FPRoundingMode::RTE;
1011 case arith::RoundingMode::toward_zero:
1012 return spirv::FPRoundingMode::RTZ;
1013 case arith::RoundingMode::upward:
1014 return spirv::FPRoundingMode::RTP;
1015 case arith::RoundingMode::to_nearest_away:
1016 // SPIR-V FPRoundingMode decoration has no ties-away-from-zero mode
1017 // (as of SPIR-V 1.6)
1018 return std::nullopt;
1019 }
1020 llvm_unreachable("Unhandled rounding mode");
1021}
1022
1023/// Converts type-casting standard operations to SPIR-V operations.
1024template <typename Op, typename SPIRVOp>
1025struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
1026 using OpConversionPattern<Op>::OpConversionPattern;
1027
1028 LogicalResult
1029 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
1030 ConversionPatternRewriter &rewriter) const override {
1031 Type srcType = llvm::getSingleElement(adaptor.getOperands()).getType();
1032 Type dstType = this->getTypeConverter()->convertType(op.getType());
1033 if (!dstType)
1034 return getTypeConversionFailure(rewriter, op);
1035
1036 if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
1037 return failure();
1038
1039 if (dstType == srcType) {
1040 // Due to type conversion, we are seeing the same source and target type.
1041 // Then we can just erase this operation by forwarding its operand.
1042 rewriter.replaceOp(op, adaptor.getOperands().front());
1043 } else {
1044 // Compute new rounding mode (if any).
1045 std::optional<spirv::FPRoundingMode> rm = std::nullopt;
1046 if (auto roundingModeOp =
1047 dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
1048 if (arith::RoundingModeAttr roundingMode =
1049 roundingModeOp.getRoundingModeAttr()) {
1050 if (!(rm =
1051 convertArithRoundingModeToSPIRV(roundingMode.getValue()))) {
1052 return rewriter.notifyMatchFailure(
1053 op->getLoc(),
1054 llvm::formatv("unsupported rounding mode '{0}'", roundingMode));
1055 }
1056 }
1057 }
1058 // Create replacement op and attach rounding mode attribute (if any).
1059 auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
1060 op, dstType, adaptor.getOperands());
1061 if (rm) {
1062 newOp->setAttr(
1063 getDecorationString(spirv::Decoration::FPRoundingMode),
1064 spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm));
1065 }
1066 }
1067 return success();
1068 }
1069};
1070
1071//===----------------------------------------------------------------------===//
1072// CmpIOp
1073//===----------------------------------------------------------------------===//
1074
1075/// Converts integer compare operation on i1 type operands to SPIR-V ops.
1076class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
1077public:
1078 using Base::Base;
1079
1080 LogicalResult
1081 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
1082 ConversionPatternRewriter &rewriter) const override {
1083 Type srcType = op.getLhs().getType();
1084 if (!isBoolScalarOrVector(srcType))
1085 return failure();
1086 Type dstType = getTypeConverter()->convertType(srcType);
1087 if (!dstType)
1088 return getTypeConversionFailure(rewriter, op, srcType);
1089
1090 switch (op.getPredicate()) {
1091 case arith::CmpIPredicate::eq: {
1092 rewriter.replaceOpWithNewOp<spirv::LogicalEqualOp>(op, adaptor.getLhs(),
1093 adaptor.getRhs());
1094 return success();
1095 }
1096 case arith::CmpIPredicate::ne: {
1097 rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
1098 op, adaptor.getLhs(), adaptor.getRhs());
1099 return success();
1100 }
1101 case arith::CmpIPredicate::uge:
1102 case arith::CmpIPredicate::ugt:
1103 case arith::CmpIPredicate::ule:
1104 case arith::CmpIPredicate::ult: {
1105 // There are no direct corresponding instructions in SPIR-V for such
1106 // cases. Extend them to 32-bit and do comparision then.
1107 Type type = rewriter.getI32Type();
1108 if (auto vectorType = dyn_cast<VectorType>(dstType))
1109 type = VectorType::get(vectorType.getShape(), type);
1110 Value extLhs =
1111 arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getLhs());
1112 Value extRhs =
1113 arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getRhs());
1114
1115 rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs,
1116 extRhs);
1117 return success();
1118 }
1119 default:
1120 break;
1121 }
1122 return failure();
1123 }
1124};
1125
1126/// Converts integer compare operation to SPIR-V ops.
1127class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> {
1128public:
1129 using Base::Base;
1130
1131 LogicalResult
1132 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
1133 ConversionPatternRewriter &rewriter) const override {
1134 Type srcType = op.getLhs().getType();
1135 if (isBoolScalarOrVector(srcType))
1136 return failure();
1137 Type dstType = getTypeConverter()->convertType(srcType);
1138 if (!dstType)
1139 return getTypeConversionFailure(rewriter, op, srcType);
1140
1141 switch (op.getPredicate()) {
1142#define DISPATCH(cmpPredicate, spirvOp) \
1143 case cmpPredicate: \
1144 if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \
1145 !getElementTypeOrSelf(srcType).isIndex() && srcType != dstType && \
1146 !hasSameBitwidth(srcType, dstType)) { \
1147 return op.emitError( \
1148 "bitwidth emulation is not implemented yet on unsigned op"); \
1149 } \
1150 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
1151 adaptor.getRhs()); \
1152 return success();
1153
1154 DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
1155 DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
1156 DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
1157 DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
1158 DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
1159 DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
1160 DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
1161 DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
1162 DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
1163 DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
1164
1165#undef DISPATCH
1166 }
1167 return failure();
1168 }
1169};
1170
1171//===----------------------------------------------------------------------===//
1172// CmpFOpPattern
1173//===----------------------------------------------------------------------===//
1174
1175/// Converts floating-point comparison operations to SPIR-V ops.
1176class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> {
1177public:
1178 using Base::Base;
1179
1180 LogicalResult
1181 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1182 ConversionPatternRewriter &rewriter) const override {
1183 switch (op.getPredicate()) {
1184#define DISPATCH(cmpPredicate, spirvOp) \
1185 case cmpPredicate: \
1186 rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
1187 adaptor.getRhs()); \
1188 return success();
1189
1190 // Ordered.
1191 DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
1192 DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
1193 DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
1194 DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
1195 DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
1196 DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
1197 // Unordered.
1198 DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
1199 DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
1200 DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
1201 DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
1202 DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
1203 DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
1204
1205#undef DISPATCH
1206
1207 default:
1208 break;
1209 }
1210 return failure();
1211 }
1212};
1213
1214/// Converts floating point NaN check to SPIR-V ops. This pattern requires
1215/// Kernel capability.
1216class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> {
1217public:
1218 using Base::Base;
1219
1220 LogicalResult
1221 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1222 ConversionPatternRewriter &rewriter) const override {
1223 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1224 rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(),
1225 adaptor.getRhs());
1226 return success();
1227 }
1228
1229 if (op.getPredicate() == arith::CmpFPredicate::UNO) {
1230 rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(),
1231 adaptor.getRhs());
1232 return success();
1233 }
1234
1235 return failure();
1236 }
1237};
1238
1239/// Converts floating point NaN check to SPIR-V ops. This pattern does not
1240/// require additional capability.
1241class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
1242public:
1243 using Base::Base;
1244
1245 LogicalResult
1246 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
1247 ConversionPatternRewriter &rewriter) const override {
1248 if (op.getPredicate() != arith::CmpFPredicate::ORD &&
1249 op.getPredicate() != arith::CmpFPredicate::UNO)
1250 return failure();
1251
1252 Location loc = op.getLoc();
1253
1254 Value replace;
1255 if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1256 if (op.getPredicate() == arith::CmpFPredicate::ORD) {
1257 // Ordered comparsion checks if neither operand is NaN.
1258 replace = spirv::ConstantOp::getOne(op.getType(), loc, rewriter);
1259 } else {
1260 // Unordered comparsion checks if either operand is NaN.
1261 replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter);
1262 }
1263 } else {
1264 Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1265 Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1266
1267 replace = spirv::LogicalOrOp::create(rewriter, loc, lhsIsNan, rhsIsNan);
1268 if (op.getPredicate() == arith::CmpFPredicate::ORD)
1269 replace = spirv::LogicalNotOp::create(rewriter, loc, replace);
1270 }
1271
1272 rewriter.replaceOp(op, replace);
1273 return success();
1274 }
1275};
1276
1277//===----------------------------------------------------------------------===//
1278// AddUIExtendedOp/SubUIExtendedOp
1279//===----------------------------------------------------------------------===//
1280
1281/// Converts arith.addui_extended/arith.subui_extended to spirv.IAddCarry/
1282/// spirv.ISubBorrow.
1283template <typename ArithExtendedOp, typename SPIRVExtendedOp>
1284class BinaryExtendedOpPattern final
1285 : public OpConversionPattern<ArithExtendedOp> {
1286public:
1287 using OpConversionPattern<ArithExtendedOp>::OpConversionPattern;
1288 LogicalResult
1289 matchAndRewrite(ArithExtendedOp op, typename ArithExtendedOp::Adaptor adaptor,
1290 ConversionPatternRewriter &rewriter) const override {
1291 Type dstElemTy = adaptor.getLhs().getType();
1292 Location loc = op->getLoc();
1293 Value result = SPIRVExtendedOp::create(rewriter, loc, adaptor.getLhs(),
1294 adaptor.getRhs());
1295
1296 Value valueResult = spirv::CompositeExtractOp::create(rewriter, loc, result,
1297 llvm::ArrayRef(0));
1298 Value flagValue = spirv::CompositeExtractOp::create(rewriter, loc, result,
1299 llvm::ArrayRef(1));
1300
1301 // Convert the carry/borrow value to boolean.
1302 Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter);
1303 Value flagResult = spirv::IEqualOp::create(rewriter, loc, flagValue, one);
1304
1305 rewriter.replaceOp(op, {valueResult, flagResult});
1306 return success();
1307 }
1308};
1309
1310//===----------------------------------------------------------------------===//
1311// MulIExtendedOp
1312//===----------------------------------------------------------------------===//
1313
1314/// Converts arith.mul*i_extended to spirv.*MulExtended.
1315template <typename ArithMulOp, typename SPIRVMulOp>
1316class MulIExtendedOpPattern final : public OpConversionPattern<ArithMulOp> {
1317public:
1318 using OpConversionPattern<ArithMulOp>::OpConversionPattern;
1319 LogicalResult
1320 matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
1321 ConversionPatternRewriter &rewriter) const override {
1322 Location loc = op->getLoc();
1323 Value result =
1324 SPIRVMulOp::create(rewriter, loc, adaptor.getLhs(), adaptor.getRhs());
1325
1326 Value low = spirv::CompositeExtractOp::create(rewriter, loc, result,
1327 llvm::ArrayRef(0));
1328 Value high = spirv::CompositeExtractOp::create(rewriter, loc, result,
1329 llvm::ArrayRef(1));
1330
1331 rewriter.replaceOp(op, {low, high});
1332 return success();
1333 }
1334};
1335
1336//===----------------------------------------------------------------------===//
1337// SelectOp
1338//===----------------------------------------------------------------------===//
1339
1340/// Converts arith.select to spirv.Select.
1341class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> {
1342public:
1343 using Base::Base;
1344 LogicalResult
1345 matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
1346 ConversionPatternRewriter &rewriter) const override {
1347 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, adaptor.getCondition(),
1348 adaptor.getTrueValue(),
1349 adaptor.getFalseValue());
1350 return success();
1351 }
1352};
1353
1354//===----------------------------------------------------------------------===//
1355// MinimumFOp, MaximumFOp
1356//===----------------------------------------------------------------------===//
1357
1358/// Converts arith.maximumf/minimumf to spirv.GL.FMax/FMin or
1359/// spirv.CL.fmax/fmin.
1360template <typename Op, typename SPIRVOp>
1361class MinimumMaximumFOpPattern final : public OpConversionPattern<Op> {
1362public:
1363 using OpConversionPattern<Op>::OpConversionPattern;
1364 LogicalResult
1365 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
1366 ConversionPatternRewriter &rewriter) const override {
1367 auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
1368 Type dstType = converter->convertType(op.getType());
1369 if (!dstType)
1370 return getTypeConversionFailure(rewriter, op);
1371
1372 // arith.maximumf/minimumf:
1373 // "if one of the arguments is NaN, then the result is also NaN."
1374 // spirv.GL.FMax/FMin
1375 // "which operand is the result is undefined if one of the operands
1376 // is a NaN."
1377 // spirv.CL.fmax/fmin:
1378 // "If one argument is a NaN, Fmin returns the other argument."
1379
1380 Location loc = op.getLoc();
1381 Value spirvOp =
1382 SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands());
1383
1384 if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1385 rewriter.replaceOp(op, spirvOp);
1386 return success();
1387 }
1388
1389 Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1390 Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1391
1392 Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan,
1393 adaptor.getLhs(), spirvOp);
1394 Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan,
1395 adaptor.getRhs(), select1);
1396
1397 rewriter.replaceOp(op, select2);
1398 return success();
1399 }
1400};
1401
1402//===----------------------------------------------------------------------===//
1403// MinNumFOp, MaxNumFOp
1404//===----------------------------------------------------------------------===//
1405
1406/// Converts arith.maxnumf/minnumf to spirv.GL.FMax/FMin or
1407/// spirv.CL.fmax/fmin.
1408template <typename Op, typename SPIRVOp>
1409class MinNumMaxNumFOpPattern final : public OpConversionPattern<Op> {
1410 template <typename TargetOp>
1411 constexpr bool shouldInsertNanGuards() const {
1412 return llvm::is_one_of<TargetOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value;
1413 }
1414
1415public:
1416 using OpConversionPattern<Op>::OpConversionPattern;
1417 LogicalResult
1418 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
1419 ConversionPatternRewriter &rewriter) const override {
1420 auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
1421 Type dstType = converter->convertType(op.getType());
1422 if (!dstType)
1423 return getTypeConversionFailure(rewriter, op);
1424
1425 // arith.maxnumf/minnumf:
1426 // "If one of the arguments is NaN, then the result is the other
1427 // argument."
1428 // spirv.GL.FMax/FMin
1429 // "which operand is the result is undefined if one of the operands
1430 // is a NaN."
1431 // spirv.CL.fmax/fmin:
1432 // "If one argument is a NaN, Fmin returns the other argument."
1433
1434 Location loc = op.getLoc();
1435 Value spirvOp =
1436 SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands());
1437
1438 if (!shouldInsertNanGuards<SPIRVOp>() ||
1439 bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) {
1440 rewriter.replaceOp(op, spirvOp);
1441 return success();
1442 }
1443
1444 Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs());
1445 Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs());
1446
1447 Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan,
1448 adaptor.getRhs(), spirvOp);
1449 Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan,
1450 adaptor.getLhs(), select1);
1451
1452 rewriter.replaceOp(op, select2);
1453 return success();
1454 }
1455};
1456
1457} // namespace
1458
1459//===----------------------------------------------------------------------===//
1460// Pattern Population
1461//===----------------------------------------------------------------------===//
1462
1464 const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
1465 // clang-format off
1466 patterns.add<
1467 ConstantCompositeOpPattern,
1468 ConstantScalarOpPattern,
1469 BoolIOpPattern<arith::AddIOp, spirv::LogicalNotEqualOp>, // add mod 2 = XOR = not-equal
1470 ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>,
1471 BoolIOpPattern<arith::SubIOp, spirv::LogicalNotEqualOp>, // sub mod 2 = XOR = not-equal
1472 ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>,
1473 BoolIOpPattern<arith::MulIOp, spirv::LogicalAndOp>, // 1*1=1, else 0 = AND
1474 ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>,
1475 BoolIOpPattern<arith::DivUIOp, spirv::LogicalAndOp>, // a/1=a, a/0=UB; truth table = AND
1477 BoolIOpPattern<arith::DivSIOp, spirv::LogicalAndOp>, // same as divui on i1
1479 BoolIOpAndNotPattern<arith::RemUIOp>, // remui(a,b) = a & ~b (see pattern comment)
1481 BoolIOpAndNotPattern<arith::RemSIOp>, // remsi(a,b) = a & ~b (see pattern comment)
1482 RemSIOpGLPattern, RemSIOpCLPattern,
1483 BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
1484 BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
1485 XOrIOpLogicalPattern, XOrIOpBooleanPattern,
1486 BoolIOpAndNotPattern<arith::ShLIOp>, // shli(a,b) = a & ~b (see pattern comment)
1487 ElementwiseArithOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
1488 BoolIOpAndNotPattern<arith::ShRUIOp>, // shrui(a,b) = a & ~b (see pattern comment)
1490 ShRSIBoolPattern, // shrsi(a,b) = a (identity; see pattern comment)
1498 ExtUIPattern, BoolToValuePattern<arith::ExtUIOp>,
1499 ExtSIPattern, ExtSII1Pattern,
1500 TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
1501 TruncIPattern, TruncII1Pattern,
1502 TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
1503 IntToFPPattern<arith::UIToFPOp, spirv::ConvertUToFOp, false>,
1504 BoolToValuePattern<arith::UIToFPOp>,
1505 IntToFPPattern<arith::SIToFPOp, spirv::ConvertSToFOp, true>,
1506 TypeCastingOpPattern<arith::FPToUIOp, spirv::ConvertFToUOp>,
1507 TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
1508 TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
1509 IndexCastIndexI1Pattern, BoolToValuePattern<arith::IndexCastOp>,
1510 TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
1511 TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
1512 CmpIOpBooleanPattern, CmpIOpPattern,
1513 CmpFOpNanNonePattern, CmpFOpPattern,
1514 BinaryExtendedOpPattern<arith::AddUIExtendedOp, spirv::IAddCarryOp>,
1515 BinaryExtendedOpPattern<arith::SubUIExtendedOp, spirv::ISubBorrowOp>,
1516 MulIExtendedOpPattern<arith::MulSIExtendedOp, spirv::SMulExtendedOp>,
1517 MulIExtendedOpPattern<arith::MulUIExtendedOp, spirv::UMulExtendedOp>,
1518 SelectOpPattern,
1519
1520 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::GLFMaxOp>,
1521 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::GLFMinOp>,
1522 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::GLFMaxOp>,
1523 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::GLFMinOp>,
1524 BoolIOpPattern<arith::MaxSIOp, spirv::LogicalAndOp>, // signed i1: 1=-1, so max=0 unless both are 1
1525 BoolIOpPattern<arith::MaxUIOp, spirv::LogicalOrOp>, // unsigned max on i1: 1 when either is 1
1526 BoolIOpPattern<arith::MinSIOp, spirv::LogicalOrOp>, // signed i1: -1<0, so min=1 when either is 1
1527 BoolIOpPattern<arith::MinUIOp, spirv::LogicalAndOp>, // unsigned min on i1: 1 only when both are 1
1532
1533 MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::CLFMaxOp>,
1534 MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::CLFMinOp>,
1535 MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::CLFMaxOp>,
1536 MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::CLFMinOp>,
1541 >(typeConverter, patterns.getContext());
1542 // clang-format on
1543
1544 // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel
1545 // capability is available.
1546 patterns.add<CmpFOpNanKernelPattern>(typeConverter, patterns.getContext(),
1547 /*benefit=*/2);
1548}
1549
1550//===----------------------------------------------------------------------===//
1551// Pass Definition
1552//===----------------------------------------------------------------------===//
1553
1554namespace {
1555struct ConvertArithToSPIRVPass
1556 : public impl::ConvertArithToSPIRVPassBase<ConvertArithToSPIRVPass> {
1557 using Base::Base;
1558
1559 void runOnOperation() override {
1560 Operation *op = getOperation();
1562 std::unique_ptr<SPIRVConversionTarget> target =
1563 SPIRVConversionTarget::get(targetAttr);
1564
1566 options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
1567 options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
1568 SPIRVTypeConverter typeConverter(targetAttr, options);
1569
1570 // Use UnrealizedConversionCast as the bridge so that we don't need to pull
1571 // in patterns for other dialects.
1572 target->addLegalOp<UnrealizedConversionCastOp>();
1573
1574 // Fail hard when there are any remaining 'arith' ops.
1575 target->addIllegalDialect<arith::ArithDialect>();
1576
1577 RewritePatternSet patterns(&getContext());
1578 arith::populateArithToSPIRVPatterns(typeConverter, patterns);
1579
1580 if (failed(applyPartialConversion(op, *target, std::move(patterns))))
1581 signalPassFailure();
1582 }
1583};
1584} // namespace
return success()
static bool hasSameBitwidth(Type a, Type b)
Returns true if scalar/vector type a and b have the same number of bitwidth.
static Value getScalarOrVectorConstInt(Type type, uint64_t value, OpBuilder &builder, Location loc)
Creates a scalar/vector integer constant.
static LogicalResult getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op, Type srcType)
Returns a source type conversion failure for srcType and operation op.
static IntegerAttr getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType, ConversionPatternRewriter &rewriter)
static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder)
Converts the given srcAttr into a boolean attribute if it holds an integral value.
static bool isBoolScalarOrVector(Type type)
Returns true if the given type is a boolean scalar or vector type.
#define DISPATCH(cmpPredicate, spirvOp)
static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType, Builder builder)
Converts the given srcAttr to a new attribute of the given dstType.
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
static llvm::ManagedStatic< PassManagerOptions > options
This class represents a processed binary blob of data.
Definition AsmState.h:91
ArrayRef< char > getData() const
Return the raw underlying data of this blob.
Definition AsmState.h:145
Attributes are known-constant values of operations.
Definition Attributes.h:25
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:233
BoolAttr getBoolAttr(bool value)
Definition Builders.cpp:104
FloatAttr getF32FloatAttr(float value)
Definition Builders.cpp:251
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
static bool isValidRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Returns true if the given buffer is a valid raw buffer for the given type.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
DenseElementsAttr reshape(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but has been reshaped...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:209
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:240
result_type_range getResultTypes()
Definition Operation.h:453
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:429
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:56
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition Types.cpp:122
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
Type front()
Return first type in the range.
Definition TypeRange.h:164
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
An attribute that specifies the target version, allowed extensions and capabilities,...
NestedPattern Op(FilterFunctionType filter=defaultFilterFunction)
void populateArithToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
std::string getDecorationString(Decoration decoration)
Converts a SPIR-V Decoration enum value to its snake_case string representation for use in MLIR attri...
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.
Definition Pattern.h:24