MLIR 22.0.0git
SPIRVToLLVM.cpp
Go to the documentation of this file.
//===- SPIRVToLLVM.cpp - SPIR-V to LLVM Patterns --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert SPIR-V dialect to LLVM dialect.
//
//===----------------------------------------------------------------------===//
12
20#include "mlir/IR/BuiltinOps.h"
23#include "llvm/ADT/TypeSwitch.h"
24#include "llvm/Support/FormatVariadic.h"
25
26#define DEBUG_TYPE "spirv-to-llvm-pattern"
27
28using namespace mlir;
29
//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//
33
34/// Returns true if the given type is a signed integer or vector type.
35static bool isSignedIntegerOrVector(Type type) {
36 if (type.isSignedInteger())
37 return true;
38 if (auto vecType = dyn_cast<VectorType>(type))
39 return vecType.getElementType().isSignedInteger();
40 return false;
41}
42
43/// Returns true if the given type is an unsigned integer or vector type
45 if (type.isUnsignedInteger())
46 return true;
47 if (auto vecType = dyn_cast<VectorType>(type))
48 return vecType.getElementType().isUnsignedInteger();
49 return false;
50}
51
52/// Returns the width of an integer or of the element type of an integer vector,
53/// if applicable.
54static std::optional<uint64_t> getIntegerOrVectorElementWidth(Type type) {
55 if (auto intType = dyn_cast<IntegerType>(type))
56 return intType.getWidth();
57 if (auto vecType = dyn_cast<VectorType>(type))
58 if (auto intType = dyn_cast<IntegerType>(vecType.getElementType()))
59 return intType.getWidth();
60 return std::nullopt;
61}
62
63/// Returns the bit width of integer, float or vector of float or integer values
64static unsigned getBitWidth(Type type) {
65 assert((type.isIntOrFloat() || isa<VectorType>(type)) &&
66 "bitwidth is not supported for this type");
67 if (type.isIntOrFloat())
68 return type.getIntOrFloatBitWidth();
69 auto vecType = dyn_cast<VectorType>(type);
70 auto elementType = vecType.getElementType();
71 assert(elementType.isIntOrFloat() &&
72 "only integers and floats have a bitwidth");
73 return elementType.getIntOrFloatBitWidth();
74}
75
76/// Returns the bit width of LLVMType integer or vector.
77static unsigned getLLVMTypeBitWidth(Type type) {
78 if (auto vecTy = dyn_cast<VectorType>(type))
79 type = vecTy.getElementType();
80 return cast<IntegerType>(type).getWidth();
81}
82
83/// Creates `IntegerAttribute` with all bits set for given type
84static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) {
85 if (auto vecType = dyn_cast<VectorType>(type)) {
86 auto integerType = cast<IntegerType>(vecType.getElementType());
87 return builder.getIntegerAttr(integerType, -1);
88 }
89 auto integerType = cast<IntegerType>(type);
90 return builder.getIntegerAttr(integerType, -1);
91}
92
93/// Creates `llvm.mlir.constant` with all bits set for the given type.
94static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
95 PatternRewriter &rewriter) {
96 if (isa<VectorType>(srcType)) {
97 return LLVM::ConstantOp::create(
98 rewriter, loc, dstType,
99 SplatElementsAttr::get(cast<ShapedType>(srcType),
100 minusOneIntegerAttribute(srcType, rewriter)));
101 }
102 return LLVM::ConstantOp::create(rewriter, loc, dstType,
103 minusOneIntegerAttribute(srcType, rewriter));
104}
105
106/// Creates `llvm.mlir.constant` with a floating-point scalar or vector value.
107static Value createFPConstant(Location loc, Type srcType, Type dstType,
108 PatternRewriter &rewriter, double value) {
109 if (auto vecType = dyn_cast<VectorType>(srcType)) {
110 auto floatType = cast<FloatType>(vecType.getElementType());
111 return LLVM::ConstantOp::create(
112 rewriter, loc, dstType,
114 rewriter.getFloatAttr(floatType, value)));
115 }
116 auto floatType = cast<FloatType>(srcType);
117 return LLVM::ConstantOp::create(rewriter, loc, dstType,
118 rewriter.getFloatAttr(floatType, value));
119}
120
121/// Utility function for bitfield ops:
122/// - `BitFieldInsert`
123/// - `BitFieldSExtract`
124/// - `BitFieldUExtract`
125/// Truncates or extends the value. If the bitwidth of the value is the same as
126/// `llvmType` bitwidth, the value remains unchanged.
128 Type llvmType,
129 PatternRewriter &rewriter) {
130 auto srcType = value.getType();
131 unsigned targetBitWidth = getLLVMTypeBitWidth(llvmType);
132 unsigned valueBitWidth = LLVM::isCompatibleType(srcType)
133 ? getLLVMTypeBitWidth(srcType)
134 : getBitWidth(srcType);
135
136 if (valueBitWidth < targetBitWidth)
137 return LLVM::ZExtOp::create(rewriter, loc, llvmType, value);
138 // If the bit widths of `Count` and `Offset` are greater than the bit width
139 // of the target type, they are truncated. Truncation is safe since `Count`
140 // and `Offset` must be no more than 64 for op behaviour to be defined. Hence,
141 // both values can be expressed in 8 bits.
142 if (valueBitWidth > targetBitWidth)
143 return LLVM::TruncOp::create(rewriter, loc, llvmType, value);
144 return value;
145}
146
147/// Broadcasts the value to vector with `numElements` number of elements.
148static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
149 const TypeConverter &typeConverter,
150 ConversionPatternRewriter &rewriter) {
151 auto vectorType = VectorType::get(numElements, toBroadcast.getType());
152 auto llvmVectorType = typeConverter.convertType(vectorType);
153 auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
154 Value broadcasted = LLVM::PoisonOp::create(rewriter, loc, llvmVectorType);
155 for (unsigned i = 0; i < numElements; ++i) {
156 auto index = LLVM::ConstantOp::create(rewriter, loc, llvmI32Type,
157 rewriter.getI32IntegerAttr(i));
158 broadcasted = LLVM::InsertElementOp::create(
159 rewriter, loc, llvmVectorType, broadcasted, toBroadcast, index);
160 }
161 return broadcasted;
162}
163
164/// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged.
165static Value optionallyBroadcast(Location loc, Value value, Type srcType,
166 const TypeConverter &typeConverter,
167 ConversionPatternRewriter &rewriter) {
168 if (auto vectorType = dyn_cast<VectorType>(srcType)) {
169 unsigned numElements = vectorType.getNumElements();
170 return broadcast(loc, value, numElements, typeConverter, rewriter);
171 }
172 return value;
173}
174
175/// Utility function for bitfield ops: `BitFieldInsert`, `BitFieldSExtract` and
176/// `BitFieldUExtract`.
177/// Broadcast `Offset` and `Count` to match the type of `Base`. If `Base` is of
178/// a vector type, construct a vector that has:
179/// - same number of elements as `Base`
180/// - each element has the type that is the same as the type of `Offset` or
181/// `Count`
182/// - each element has the same value as `Offset` or `Count`
183/// Then cast `Offset` and `Count` if their bit width is different
184/// from `Base` bit width.
185static Value processCountOrOffset(Location loc, Value value, Type srcType,
186 Type dstType, const TypeConverter &converter,
187 ConversionPatternRewriter &rewriter) {
188 Value broadcasted =
189 optionallyBroadcast(loc, value, srcType, converter, rewriter);
190 return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter);
191}
192
193/// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`)
194/// offset to LLVM struct. Otherwise, the conversion is not supported.
196 const TypeConverter &converter) {
197 if (type != VulkanLayoutUtils::decorateType(type))
198 return nullptr;
199
200 SmallVector<Type> elementsVector;
201 if (failed(converter.convertTypes(type.getElementTypes(), elementsVector)))
202 return nullptr;
203 return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
204 /*isPacked=*/false);
205}
206
207/// Converts SPIR-V struct with no offset to packed LLVM struct.
209 const TypeConverter &converter) {
210 SmallVector<Type> elementsVector;
211 if (failed(converter.convertTypes(type.getElementTypes(), elementsVector)))
212 return nullptr;
213 return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
214 /*isPacked=*/true);
215}
216
217/// Creates LLVM dialect constant with the given value.
219 unsigned value) {
220 return LLVM::ConstantOp::create(
221 rewriter, loc, IntegerType::get(rewriter.getContext(), 32),
222 rewriter.getIntegerAttr(rewriter.getI32Type(), value));
223}
224
225/// Utility for `spirv.Load` and `spirv.Store` conversion.
226static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands,
227 ConversionPatternRewriter &rewriter,
228 const TypeConverter &typeConverter,
229 unsigned alignment, bool isVolatile,
230 bool isNonTemporal) {
231 if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
232 auto dstType = typeConverter.convertType(loadOp.getType());
233 if (!dstType)
234 return rewriter.notifyMatchFailure(op, "type conversion failed");
235 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
236 loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment,
237 isVolatile, isNonTemporal);
238 return success();
239 }
240 auto storeOp = cast<spirv::StoreOp>(op);
241 spirv::StoreOpAdaptor adaptor(operands);
242 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValue(),
243 adaptor.getPtr(), alignment,
244 isVolatile, isNonTemporal);
245 return success();
246}
247
//===----------------------------------------------------------------------===//
// Type conversion
//===----------------------------------------------------------------------===//
251
252/// Converts SPIR-V array type to LLVM array. Natural stride (according to
253/// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected
254/// when converting ops that manipulate array types.
255static std::optional<Type> convertArrayType(spirv::ArrayType type,
256 TypeConverter &converter) {
257 unsigned stride = type.getArrayStride();
258 Type elementType = type.getElementType();
259 auto sizeInBytes = cast<spirv::SPIRVType>(elementType).getSizeInBytes();
260 if (stride != 0 && (!sizeInBytes || *sizeInBytes != stride))
261 return std::nullopt;
262
263 auto llvmElementType = converter.convertType(elementType);
264 unsigned numElements = type.getNumElements();
265 return LLVM::LLVMArrayType::get(llvmElementType, numElements);
266}
267
268/// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not
269/// modelled at the moment.
271 const TypeConverter &converter,
272 spirv::ClientAPI clientAPI) {
273 unsigned addressSpace =
275 return LLVM::LLVMPointerType::get(type.getContext(), addressSpace);
276}
277
278/// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over
279/// the bounds, the runtime array is converted to a 0-sized LLVM array. There is
280/// no modelling of array stride at the moment.
281static std::optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type,
282 TypeConverter &converter) {
283 if (type.getArrayStride() != 0)
284 return std::nullopt;
285 auto elementType = converter.convertType(type.getElementType());
286 return LLVM::LLVMArrayType::get(elementType, 0);
287}
288
289/// Converts SPIR-V struct to LLVM struct. There is no support of structs with
290/// member decorations. Also, only natural offset is supported.
292 const TypeConverter &converter) {
294 type.getMemberDecorations(memberDecorations);
295 if (!memberDecorations.empty())
296 return nullptr;
297 if (type.hasOffset())
298 return convertStructTypeWithOffset(type, converter);
299 return convertStructTypePacked(type, converter);
300}
301
//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//
305
306namespace {
307
308class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
309public:
310 using SPIRVToLLVMConversion<spirv::AccessChainOp>::SPIRVToLLVMConversion;
311
312 LogicalResult
313 matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
314 ConversionPatternRewriter &rewriter) const override {
315 auto dstType =
316 getTypeConverter()->convertType(op.getComponentPtr().getType());
317 if (!dstType)
318 return rewriter.notifyMatchFailure(op, "type conversion failed");
319 // To use GEP we need to add a first 0 index to go through the pointer.
320 auto indices = llvm::to_vector<4>(adaptor.getIndices());
321 Type indexType = op.getIndices().front().getType();
322 auto llvmIndexType = getTypeConverter()->convertType(indexType);
323 if (!llvmIndexType)
324 return rewriter.notifyMatchFailure(op, "type conversion failed");
325 Value zero =
326 LLVM::ConstantOp::create(rewriter, op.getLoc(), llvmIndexType,
327 rewriter.getIntegerAttr(indexType, 0));
328 indices.insert(indices.begin(), zero);
329
330 auto elementType = getTypeConverter()->convertType(
331 cast<spirv::PointerType>(op.getBasePtr().getType()).getPointeeType());
332 if (!elementType)
333 return rewriter.notifyMatchFailure(op, "type conversion failed");
334 rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, elementType,
335 adaptor.getBasePtr(), indices);
336 return success();
337 }
338};
339
340class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {
341public:
342 using SPIRVToLLVMConversion<spirv::AddressOfOp>::SPIRVToLLVMConversion;
343
344 LogicalResult
345 matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
346 ConversionPatternRewriter &rewriter) const override {
347 auto dstType = getTypeConverter()->convertType(op.getPointer().getType());
348 if (!dstType)
349 return rewriter.notifyMatchFailure(op, "type conversion failed");
350 rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType,
351 op.getVariable());
352 return success();
353 }
354};
355
356class BitFieldInsertPattern
357 : public SPIRVToLLVMConversion<spirv::BitFieldInsertOp> {
358public:
359 using SPIRVToLLVMConversion<spirv::BitFieldInsertOp>::SPIRVToLLVMConversion;
360
361 LogicalResult
362 matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor,
363 ConversionPatternRewriter &rewriter) const override {
364 auto srcType = op.getType();
365 auto dstType = getTypeConverter()->convertType(srcType);
366 if (!dstType)
367 return rewriter.notifyMatchFailure(op, "type conversion failed");
368 Location loc = op.getLoc();
369
370 // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
371 Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
372 *getTypeConverter(), rewriter);
373 Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
374 *getTypeConverter(), rewriter);
375
376 // Create a mask with bits set outside [Offset, Offset + Count - 1].
377 Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
378 Value maskShiftedByCount =
379 LLVM::ShlOp::create(rewriter, loc, dstType, minusOne, count);
380 Value negated = LLVM::XOrOp::create(rewriter, loc, dstType,
381 maskShiftedByCount, minusOne);
382 Value maskShiftedByCountAndOffset =
383 LLVM::ShlOp::create(rewriter, loc, dstType, negated, offset);
384 Value mask = LLVM::XOrOp::create(rewriter, loc, dstType,
385 maskShiftedByCountAndOffset, minusOne);
386
387 // Extract unchanged bits from the `Base` that are outside of
388 // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`.
389 Value baseAndMask =
390 LLVM::AndOp::create(rewriter, loc, dstType, op.getBase(), mask);
391 Value insertShiftedByOffset =
392 LLVM::ShlOp::create(rewriter, loc, dstType, op.getInsert(), offset);
393 rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask,
394 insertShiftedByOffset);
395 return success();
396 }
397};
398
399/// Converts SPIR-V ConstantOp with scalar or vector type.
400class ConstantScalarAndVectorPattern
401 : public SPIRVToLLVMConversion<spirv::ConstantOp> {
402public:
403 using SPIRVToLLVMConversion<spirv::ConstantOp>::SPIRVToLLVMConversion;
404
405 LogicalResult
406 matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor,
407 ConversionPatternRewriter &rewriter) const override {
408 auto srcType = constOp.getType();
409 if (!isa<VectorType>(srcType) && !srcType.isIntOrFloat())
410 return failure();
411
412 auto dstType = getTypeConverter()->convertType(srcType);
413 if (!dstType)
414 return rewriter.notifyMatchFailure(constOp, "type conversion failed");
415
416 // SPIR-V constant can be a signed/unsigned integer, which has to be
417 // casted to signless integer when converting to LLVM dialect. Removing the
418 // sign bit may have unexpected behaviour. However, it is better to handle
419 // it case-by-case, given that the purpose of the conversion is not to
420 // cover all possible corner cases.
421 if (isSignedIntegerOrVector(srcType) ||
422 isUnsignedIntegerOrVector(srcType)) {
423 auto signlessType = rewriter.getIntegerType(getBitWidth(srcType));
424
425 if (isa<VectorType>(srcType)) {
426 auto dstElementsAttr = cast<DenseIntElementsAttr>(constOp.getValue());
427 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
428 constOp, dstType,
429 dstElementsAttr.mapValues(
430 signlessType, [&](const APInt &value) { return value; }));
431 return success();
432 }
433 auto srcAttr = cast<IntegerAttr>(constOp.getValue());
434 auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue());
435 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr);
436 return success();
437 }
438 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
439 constOp, dstType, adaptor.getOperands(), constOp->getAttrs());
440 return success();
441 }
442};
443
444class BitFieldSExtractPattern
445 : public SPIRVToLLVMConversion<spirv::BitFieldSExtractOp> {
446public:
447 using SPIRVToLLVMConversion<spirv::BitFieldSExtractOp>::SPIRVToLLVMConversion;
448
449 LogicalResult
450 matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor,
451 ConversionPatternRewriter &rewriter) const override {
452 auto srcType = op.getType();
453 auto dstType = getTypeConverter()->convertType(srcType);
454 if (!dstType)
455 return rewriter.notifyMatchFailure(op, "type conversion failed");
456 Location loc = op.getLoc();
457
458 // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
459 Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
460 *getTypeConverter(), rewriter);
461 Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
462 *getTypeConverter(), rewriter);
463
464 // Create a constant that holds the size of the `Base`.
465 IntegerType integerType;
466 if (auto vecType = dyn_cast<VectorType>(srcType))
467 integerType = cast<IntegerType>(vecType.getElementType());
468 else
469 integerType = cast<IntegerType>(srcType);
470
471 auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType));
472 Value size =
473 isa<VectorType>(srcType)
474 ? LLVM::ConstantOp::create(
475 rewriter, loc, dstType,
476 SplatElementsAttr::get(cast<ShapedType>(srcType), baseSize))
477 : LLVM::ConstantOp::create(rewriter, loc, dstType, baseSize);
478
479 // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit
480 // at Offset + Count - 1 is the most significant bit now.
481 Value countPlusOffset =
482 LLVM::AddOp::create(rewriter, loc, dstType, count, offset);
483 Value amountToShiftLeft =
484 LLVM::SubOp::create(rewriter, loc, dstType, size, countPlusOffset);
485 Value baseShiftedLeft = LLVM::ShlOp::create(
486 rewriter, loc, dstType, op.getBase(), amountToShiftLeft);
487
488 // Shift the result right, filling the bits with the sign bit.
489 Value amountToShiftRight =
490 LLVM::AddOp::create(rewriter, loc, dstType, offset, amountToShiftLeft);
491 rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft,
492 amountToShiftRight);
493 return success();
494 }
495};
496
497class BitFieldUExtractPattern
498 : public SPIRVToLLVMConversion<spirv::BitFieldUExtractOp> {
499public:
500 using SPIRVToLLVMConversion<spirv::BitFieldUExtractOp>::SPIRVToLLVMConversion;
501
502 LogicalResult
503 matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor,
504 ConversionPatternRewriter &rewriter) const override {
505 auto srcType = op.getType();
506 auto dstType = getTypeConverter()->convertType(srcType);
507 if (!dstType)
508 return rewriter.notifyMatchFailure(op, "type conversion failed");
509 Location loc = op.getLoc();
510
511 // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
512 Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
513 *getTypeConverter(), rewriter);
514 Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
515 *getTypeConverter(), rewriter);
516
517 // Create a mask with bits set at [0, Count - 1].
518 Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
519 Value maskShiftedByCount =
520 LLVM::ShlOp::create(rewriter, loc, dstType, minusOne, count);
521 Value mask = LLVM::XOrOp::create(rewriter, loc, dstType, maskShiftedByCount,
522 minusOne);
523
524 // Shift `Base` by `Offset` and apply the mask on it.
525 Value shiftedBase =
526 LLVM::LShrOp::create(rewriter, loc, dstType, op.getBase(), offset);
527 rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask);
528 return success();
529 }
530};
531
532class BranchConversionPattern : public SPIRVToLLVMConversion<spirv::BranchOp> {
533public:
534 using SPIRVToLLVMConversion<spirv::BranchOp>::SPIRVToLLVMConversion;
535
536 LogicalResult
537 matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor,
538 ConversionPatternRewriter &rewriter) const override {
539 rewriter.replaceOpWithNewOp<LLVM::BrOp>(branchOp, adaptor.getOperands(),
540 branchOp.getTarget());
541 return success();
542 }
543};
544
545class BranchConditionalConversionPattern
546 : public SPIRVToLLVMConversion<spirv::BranchConditionalOp> {
547public:
548 using SPIRVToLLVMConversion<
549 spirv::BranchConditionalOp>::SPIRVToLLVMConversion;
550
551 LogicalResult
552 matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor,
553 ConversionPatternRewriter &rewriter) const override {
554 // If branch weights exist, map them to 32-bit integer vector.
555 DenseI32ArrayAttr branchWeights = nullptr;
556 if (auto weights = op.getBranchWeights()) {
557 SmallVector<int32_t> weightValues;
558 for (auto weight : weights->getAsRange<IntegerAttr>())
559 weightValues.push_back(weight.getInt());
560 branchWeights = DenseI32ArrayAttr::get(getContext(), weightValues);
561 }
562
563 rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
564 op, op.getCondition(), op.getTrueBlockArguments(),
565 op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
566 op.getFalseBlock());
567 return success();
568 }
569};
570
571/// Converts `spirv.getCompositeExtract` to `llvm.extractvalue` if the container
572/// type is an aggregate type (struct or array). Otherwise, converts to
573/// `llvm.extractelement` that operates on vectors.
574class CompositeExtractPattern
575 : public SPIRVToLLVMConversion<spirv::CompositeExtractOp> {
576public:
577 using SPIRVToLLVMConversion<spirv::CompositeExtractOp>::SPIRVToLLVMConversion;
578
579 LogicalResult
580 matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor,
581 ConversionPatternRewriter &rewriter) const override {
582 auto dstType = this->getTypeConverter()->convertType(op.getType());
583 if (!dstType)
584 return rewriter.notifyMatchFailure(op, "type conversion failed");
585
586 Type containerType = op.getComposite().getType();
587 if (isa<VectorType>(containerType)) {
588 Location loc = op.getLoc();
589 IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
590 Value index = createI32ConstantOf(loc, rewriter, value.getInt());
591 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
592 op, dstType, adaptor.getComposite(), index);
593 return success();
594 }
595
596 rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
597 op, adaptor.getComposite(),
598 LLVM::convertArrayToIndices(op.getIndices()));
599 return success();
600 }
601};
602
603/// Converts `spirv.getCompositeInsert` to `llvm.insertvalue` if the container
604/// type is an aggregate type (struct or array). Otherwise, converts to
605/// `llvm.insertelement` that operates on vectors.
606class CompositeInsertPattern
607 : public SPIRVToLLVMConversion<spirv::CompositeInsertOp> {
608public:
609 using SPIRVToLLVMConversion<spirv::CompositeInsertOp>::SPIRVToLLVMConversion;
610
611 LogicalResult
612 matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor,
613 ConversionPatternRewriter &rewriter) const override {
614 auto dstType = this->getTypeConverter()->convertType(op.getType());
615 if (!dstType)
616 return rewriter.notifyMatchFailure(op, "type conversion failed");
617
618 Type containerType = op.getComposite().getType();
619 if (isa<VectorType>(containerType)) {
620 Location loc = op.getLoc();
621 IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]);
622 Value index = createI32ConstantOf(loc, rewriter, value.getInt());
623 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
624 op, dstType, adaptor.getComposite(), adaptor.getObject(), index);
625 return success();
626 }
627
628 rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
629 op, adaptor.getComposite(), adaptor.getObject(),
630 LLVM::convertArrayToIndices(op.getIndices()));
631 return success();
632 }
633};
634
635/// Converts SPIR-V operations that have straightforward LLVM equivalent
636/// into LLVM dialect operations.
637template <typename SPIRVOp, typename LLVMOp>
638class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
639public:
640 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
641
642 LogicalResult
643 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
644 ConversionPatternRewriter &rewriter) const override {
645 auto dstType = this->getTypeConverter()->convertType(op.getType());
646 if (!dstType)
647 return rewriter.notifyMatchFailure(op, "type conversion failed");
648 rewriter.template replaceOpWithNewOp<LLVMOp>(
649 op, dstType, adaptor.getOperands(), op->getAttrs());
650 return success();
651 }
652};
653
654/// Converts `spirv.ExecutionMode` into a global struct constant that holds
655/// execution mode information.
656class ExecutionModePattern
657 : public SPIRVToLLVMConversion<spirv::ExecutionModeOp> {
658public:
659 using SPIRVToLLVMConversion<spirv::ExecutionModeOp>::SPIRVToLLVMConversion;
660
661 LogicalResult
662 matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor,
663 ConversionPatternRewriter &rewriter) const override {
664 // First, create the global struct's name that would be associated with
665 // this entry point's execution mode. We set it to be:
666 // __spv__{SPIR-V module name}_{function name}_execution_mode_info_{mode}
667 ModuleOp module = op->getParentOfType<ModuleOp>();
668 spirv::ExecutionModeAttr executionModeAttr = op.getExecutionModeAttr();
669 std::string moduleName;
670 if (module.getName().has_value())
671 moduleName = "_" + module.getName()->str();
672 else
673 moduleName = "";
674 std::string executionModeInfoName = llvm::formatv(
675 "__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.getFn().str(),
676 static_cast<uint32_t>(executionModeAttr.getValue()));
677
678 MLIRContext *context = rewriter.getContext();
679 OpBuilder::InsertionGuard guard(rewriter);
680 rewriter.setInsertionPointToStart(module.getBody());
681
682 // Create a struct type, corresponding to the C struct below.
683 // struct {
684 // int32_t executionMode;
685 // int32_t values[]; // optional values
686 // };
687 auto llvmI32Type = IntegerType::get(context, 32);
688 SmallVector<Type, 2> fields;
689 fields.push_back(llvmI32Type);
690 ArrayAttr values = op.getValues();
691 if (!values.empty()) {
692 auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size());
693 fields.push_back(arrayType);
694 }
695 auto structType = LLVM::LLVMStructType::getLiteral(context, fields);
696
697 // Create `llvm.mlir.global` with initializer region containing one block.
698 auto global = LLVM::GlobalOp::create(
699 rewriter, UnknownLoc::get(context), structType, /*isConstant=*/true,
700 LLVM::Linkage::External, executionModeInfoName, Attribute(),
701 /*alignment=*/0);
702 Location loc = global.getLoc();
703 Region &region = global.getInitializerRegion();
704 Block *block = rewriter.createBlock(&region);
705
706 // Initialize the struct and set the execution mode value.
707 rewriter.setInsertionPointToStart(block);
708 Value structValue = LLVM::PoisonOp::create(rewriter, loc, structType);
709 Value executionMode = LLVM::ConstantOp::create(
710 rewriter, loc, llvmI32Type,
711 rewriter.getI32IntegerAttr(
712 static_cast<uint32_t>(executionModeAttr.getValue())));
713 SmallVector<int64_t> position{0};
714 structValue = LLVM::InsertValueOp::create(rewriter, loc, structValue,
715 executionMode, position);
716
717 // Insert extra operands if they exist into execution mode info struct.
718 for (unsigned i = 0, e = values.size(); i < e; ++i) {
719 auto attr = values.getValue()[i];
720 Value entry = LLVM::ConstantOp::create(rewriter, loc, llvmI32Type, attr);
721 structValue = LLVM::InsertValueOp::create(
722 rewriter, loc, structValue, entry, ArrayRef<int64_t>({1, i}));
723 }
724 LLVM::ReturnOp::create(rewriter, loc, ArrayRef<Value>({structValue}));
725 rewriter.eraseOp(op);
726 return success();
727 }
728};
729
730/// Converts `spirv.GlobalVariable` to `llvm.mlir.global`. Note that SPIR-V
731/// global returns a pointer, whereas in LLVM dialect the global holds an actual
732/// value. This difference is handled by `spirv.mlir.addressof` and
733/// `llvm.mlir.addressof`ops that both return a pointer.
734class GlobalVariablePattern
735 : public SPIRVToLLVMConversion<spirv::GlobalVariableOp> {
736public:
737 template <typename... Args>
738 GlobalVariablePattern(spirv::ClientAPI clientAPI, Args &&...args)
739 : SPIRVToLLVMConversion<spirv::GlobalVariableOp>(
740 std::forward<Args>(args)...),
741 clientAPI(clientAPI) {}
742
743 LogicalResult
744 matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor,
745 ConversionPatternRewriter &rewriter) const override {
746 // Currently, there is no support of initialization with a constant value in
747 // SPIR-V dialect. Specialization constants are not considered as well.
748 if (op.getInitializer())
749 return failure();
750
751 auto srcType = cast<spirv::PointerType>(op.getType());
752 auto dstType = getTypeConverter()->convertType(srcType.getPointeeType());
753 if (!dstType)
754 return rewriter.notifyMatchFailure(op, "type conversion failed");
755
756 // Limit conversion to the current invocation only or `StorageBuffer`
757 // required by SPIR-V runner.
758 // This is okay because multiple invocations are not supported yet.
759 auto storageClass = srcType.getStorageClass();
760 switch (storageClass) {
761 case spirv::StorageClass::Input:
762 case spirv::StorageClass::Private:
763 case spirv::StorageClass::Output:
764 case spirv::StorageClass::StorageBuffer:
765 case spirv::StorageClass::UniformConstant:
766 break;
767 default:
768 return failure();
769 }
770
771 // LLVM dialect spec: "If the global value is a constant, storing into it is
772 // not allowed.". This corresponds to SPIR-V 'Input' and 'UniformConstant'
773 // storage class that is read-only.
774 bool isConstant = (storageClass == spirv::StorageClass::Input) ||
775 (storageClass == spirv::StorageClass::UniformConstant);
776 // SPIR-V spec: "By default, functions and global variables are private to a
777 // module and cannot be accessed by other modules. However, a module may be
778 // written to export or import functions and global (module scope)
779 // variables.". Therefore, map 'Private' storage class to private linkage,
780 // 'Input' and 'Output' to external linkage.
781 auto linkage = storageClass == spirv::StorageClass::Private
782 ? LLVM::Linkage::Private
783 : LLVM::Linkage::External;
784 StringAttr locationAttrName = op.getLocationAttrName();
785 IntegerAttr locationAttr = op.getLocationAttr();
786 auto newGlobalOp = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
787 op, dstType, isConstant, linkage, op.getSymName(), Attribute(),
788 /*alignment=*/0, storageClassToAddressSpace(clientAPI, storageClass));
789
790 // Attach location attribute if applicable
791 if (locationAttr)
792 newGlobalOp->setAttr(locationAttrName, locationAttr);
793
794 return success();
795 }
796
797private:
798 spirv::ClientAPI clientAPI;
799};
800
801/// Converts SPIR-V cast ops that do not have straightforward LLVM
802/// equivalent in LLVM dialect.
803template <typename SPIRVOp, typename LLVMExtOp, typename LLVMTruncOp>
804class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> {
805public:
806 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
807
808 LogicalResult
809 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
810 ConversionPatternRewriter &rewriter) const override {
811
812 Type fromType = op.getOperand().getType();
813 Type toType = op.getType();
814
815 auto dstType = this->getTypeConverter()->convertType(toType);
816 if (!dstType)
817 return rewriter.notifyMatchFailure(op, "type conversion failed");
818
819 if (getBitWidth(fromType) < getBitWidth(toType)) {
820 rewriter.template replaceOpWithNewOp<LLVMExtOp>(op, dstType,
821 adaptor.getOperands());
822 return success();
823 }
824 if (getBitWidth(fromType) > getBitWidth(toType)) {
825 rewriter.template replaceOpWithNewOp<LLVMTruncOp>(op, dstType,
826 adaptor.getOperands());
827 return success();
828 }
829 return failure();
830 }
831};
832
833class FunctionCallPattern
834 : public SPIRVToLLVMConversion<spirv::FunctionCallOp> {
835public:
836 using SPIRVToLLVMConversion<spirv::FunctionCallOp>::SPIRVToLLVMConversion;
837
838 LogicalResult
839 matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
840 ConversionPatternRewriter &rewriter) const override {
841 if (callOp.getNumResults() == 0) {
842 auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
843 callOp, TypeRange(), adaptor.getOperands(), callOp->getAttrs());
844 newOp.getProperties().operandSegmentSizes = {
845 static_cast<int32_t>(adaptor.getOperands().size()), 0};
846 newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
847 return success();
848 }
849
850 // Function returns a single result.
851 auto dstType = getTypeConverter()->convertType(callOp.getType(0));
852 if (!dstType)
853 return rewriter.notifyMatchFailure(callOp, "type conversion failed");
854 auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
855 callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
856 newOp.getProperties().operandSegmentSizes = {
857 static_cast<int32_t>(adaptor.getOperands().size()), 0};
858 newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
859 return success();
860 }
861};
862
863/// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate"
864template <typename SPIRVOp, LLVM::FCmpPredicate predicate>
865class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
866public:
867 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
868
869 LogicalResult
870 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
871 ConversionPatternRewriter &rewriter) const override {
872
873 auto dstType = this->getTypeConverter()->convertType(op.getType());
874 if (!dstType)
875 return rewriter.notifyMatchFailure(op, "type conversion failed");
876
877 rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
878 op, dstType, predicate, op.getOperand1(), op.getOperand2());
879 return success();
880 }
881};
882
883/// Converts SPIR-V integer comparisons to llvm.icmp "predicate"
884template <typename SPIRVOp, LLVM::ICmpPredicate predicate>
885class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
886public:
887 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
888
889 LogicalResult
890 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
891 ConversionPatternRewriter &rewriter) const override {
892
893 auto dstType = this->getTypeConverter()->convertType(op.getType());
894 if (!dstType)
895 return rewriter.notifyMatchFailure(op, "type conversion failed");
896
897 rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
898 op, dstType, predicate, op.getOperand1(), op.getOperand2());
899 return success();
900 }
901};
902
903class InverseSqrtPattern
904 : public SPIRVToLLVMConversion<spirv::GLInverseSqrtOp> {
905public:
906 using SPIRVToLLVMConversion<spirv::GLInverseSqrtOp>::SPIRVToLLVMConversion;
907
908 LogicalResult
909 matchAndRewrite(spirv::GLInverseSqrtOp op, OpAdaptor adaptor,
910 ConversionPatternRewriter &rewriter) const override {
911 auto srcType = op.getType();
912 auto dstType = getTypeConverter()->convertType(srcType);
913 if (!dstType)
914 return rewriter.notifyMatchFailure(op, "type conversion failed");
915
916 Location loc = op.getLoc();
917 Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
918 Value sqrt = LLVM::SqrtOp::create(rewriter, loc, dstType, op.getOperand());
919 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt);
920 return success();
921 }
922};
923
924/// Converts `spirv.Load` and `spirv.Store` to LLVM dialect.
925template <typename SPIRVOp>
926class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVOp> {
927public:
928 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
929
930 LogicalResult
931 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
932 ConversionPatternRewriter &rewriter) const override {
933 if (!op.getMemoryAccess()) {
934 return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
935 *this->getTypeConverter(), /*alignment=*/0,
936 /*isVolatile=*/false,
937 /*isNonTemporal=*/false);
938 }
939 auto memoryAccess = *op.getMemoryAccess();
940 switch (memoryAccess) {
941 case spirv::MemoryAccess::Aligned:
942 case spirv::MemoryAccess::None:
943 case spirv::MemoryAccess::Nontemporal:
944 case spirv::MemoryAccess::Volatile: {
945 unsigned alignment =
946 memoryAccess == spirv::MemoryAccess::Aligned ? *op.getAlignment() : 0;
947 bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
948 bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
949 return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
950 *this->getTypeConverter(), alignment,
951 isVolatile, isNonTemporal);
952 }
953 default:
954 // There is no support of other memory access attributes.
955 return failure();
956 }
957 }
958};
959
960/// Converts `spirv.Not` and `spirv.LogicalNot` into LLVM dialect.
961template <typename SPIRVOp>
962class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> {
963public:
964 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
965
966 LogicalResult
967 matchAndRewrite(SPIRVOp notOp, typename SPIRVOp::Adaptor adaptor,
968 ConversionPatternRewriter &rewriter) const override {
969 auto srcType = notOp.getType();
970 auto dstType = this->getTypeConverter()->convertType(srcType);
971 if (!dstType)
972 return rewriter.notifyMatchFailure(notOp, "type conversion failed");
973
974 Location loc = notOp.getLoc();
975 IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter);
976 auto mask =
977 isa<VectorType>(srcType)
978 ? LLVM::ConstantOp::create(
979 rewriter, loc, dstType,
980 SplatElementsAttr::get(cast<VectorType>(srcType), minusOne))
981 : LLVM::ConstantOp::create(rewriter, loc, dstType, minusOne);
982 rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
983 notOp.getOperand(), mask);
984 return success();
985 }
986};
987
988/// A template pattern that erases the given `SPIRVOp`.
989template <typename SPIRVOp>
990class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> {
991public:
992 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
993
994 LogicalResult
995 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
996 ConversionPatternRewriter &rewriter) const override {
997 rewriter.eraseOp(op);
998 return success();
999 }
1000};
1001
1002class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> {
1003public:
1004 using SPIRVToLLVMConversion<spirv::ReturnOp>::SPIRVToLLVMConversion;
1005
1006 LogicalResult
1007 matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor,
1008 ConversionPatternRewriter &rewriter) const override {
1009 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, ArrayRef<Type>(),
1010 ArrayRef<Value>());
1011 return success();
1012 }
1013};
1014
1015class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
1016public:
1017 using SPIRVToLLVMConversion<spirv::ReturnValueOp>::SPIRVToLLVMConversion;
1018
1019 LogicalResult
1020 matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor,
1021 ConversionPatternRewriter &rewriter) const override {
1022 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnValueOp, ArrayRef<Type>(),
1023 adaptor.getOperands());
1024 return success();
1025 }
1026};
1027
1028static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
1029 StringRef name,
1030 ArrayRef<Type> paramTypes,
1031 Type resultType,
1032 bool convergent = true) {
1033 auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
1034 SymbolTable::lookupSymbolIn(symbolTable, name));
1035 if (func)
1036 return func;
1037
1038 OpBuilder b(symbolTable->getRegion(0));
1039 func = LLVM::LLVMFuncOp::create(
1040 b, symbolTable->getLoc(), name,
1041 LLVM::LLVMFunctionType::get(resultType, paramTypes));
1042 func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
1043 func.setConvergent(convergent);
1044 func.setNoUnwind(true);
1045 func.setWillReturn(true);
1046 return func;
1047}
1048
1049static LLVM::CallOp createSPIRVBuiltinCall(Location loc, OpBuilder &builder,
1050 LLVM::LLVMFuncOp func,
1051 ValueRange args) {
1052 auto call = LLVM::CallOp::create(builder, loc, func, args);
1053 call.setCConv(func.getCConv());
1054 call.setConvergentAttr(func.getConvergentAttr());
1055 call.setNoUnwindAttr(func.getNoUnwindAttr());
1056 call.setWillReturnAttr(func.getWillReturnAttr());
1057 return call;
1058}
1059
1060template <typename BarrierOpTy>
1061class ControlBarrierPattern : public SPIRVToLLVMConversion<BarrierOpTy> {
1062public:
1063 using OpAdaptor = typename SPIRVToLLVMConversion<BarrierOpTy>::OpAdaptor;
1064
1065 using SPIRVToLLVMConversion<BarrierOpTy>::SPIRVToLLVMConversion;
1066
1067 static constexpr StringRef getFuncName();
1068
1069 LogicalResult
1070 matchAndRewrite(BarrierOpTy controlBarrierOp, OpAdaptor adaptor,
1071 ConversionPatternRewriter &rewriter) const override {
1072 constexpr StringRef funcName = getFuncName();
1073 Operation *symbolTable =
1074 controlBarrierOp->template getParentWithTrait<OpTrait::SymbolTable>();
1075
1076 Type i32 = rewriter.getI32Type();
1077
1078 Type voidTy = rewriter.getType<LLVM::LLVMVoidType>();
1079 LLVM::LLVMFuncOp func =
1080 lookupOrCreateSPIRVFn(symbolTable, funcName, {i32, i32, i32}, voidTy);
1081
1082 Location loc = controlBarrierOp->getLoc();
1083 Value execution = LLVM::ConstantOp::create(
1084 rewriter, loc, i32, static_cast<int32_t>(adaptor.getExecutionScope()));
1085 Value memory = LLVM::ConstantOp::create(
1086 rewriter, loc, i32, static_cast<int32_t>(adaptor.getMemoryScope()));
1087 Value semantics = LLVM::ConstantOp::create(
1088 rewriter, loc, i32, static_cast<int32_t>(adaptor.getMemorySemantics()));
1089
1090 auto call = createSPIRVBuiltinCall(loc, rewriter, func,
1091 {execution, memory, semantics});
1092
1093 rewriter.replaceOp(controlBarrierOp, call);
1094 return success();
1095 }
1096};
1097
1098namespace {
1099
1100StringRef getTypeMangling(Type type, bool isSigned) {
1102 .Case<Float16Type>([](auto) { return "Dh"; })
1103 .Case<Float32Type>([](auto) { return "f"; })
1104 .Case<Float64Type>([](auto) { return "d"; })
1105 .Case<IntegerType>([isSigned](IntegerType intTy) {
1106 switch (intTy.getWidth()) {
1107 case 1:
1108 return "b";
1109 case 8:
1110 return (isSigned) ? "a" : "c";
1111 case 16:
1112 return (isSigned) ? "s" : "t";
1113 case 32:
1114 return (isSigned) ? "i" : "j";
1115 case 64:
1116 return (isSigned) ? "l" : "m";
1117 default:
1118 llvm_unreachable("Unsupported integer width");
1119 }
1120 })
1121 .DefaultUnreachable("No mangling defined");
1122}
1123
1124template <typename ReduceOp>
1125constexpr StringLiteral getGroupFuncName();
1126
1127template <>
1128constexpr StringLiteral getGroupFuncName<spirv::GroupIAddOp>() {
1129 return "_Z17__spirv_GroupIAddii";
1130}
1131template <>
1132constexpr StringLiteral getGroupFuncName<spirv::GroupFAddOp>() {
1133 return "_Z17__spirv_GroupFAddii";
1134}
1135template <>
1136constexpr StringLiteral getGroupFuncName<spirv::GroupSMinOp>() {
1137 return "_Z17__spirv_GroupSMinii";
1138}
1139template <>
1140constexpr StringLiteral getGroupFuncName<spirv::GroupUMinOp>() {
1141 return "_Z17__spirv_GroupUMinii";
1142}
1143template <>
1144constexpr StringLiteral getGroupFuncName<spirv::GroupFMinOp>() {
1145 return "_Z17__spirv_GroupFMinii";
1146}
1147template <>
1148constexpr StringLiteral getGroupFuncName<spirv::GroupSMaxOp>() {
1149 return "_Z17__spirv_GroupSMaxii";
1150}
1151template <>
1152constexpr StringLiteral getGroupFuncName<spirv::GroupUMaxOp>() {
1153 return "_Z17__spirv_GroupUMaxii";
1154}
1155template <>
1156constexpr StringLiteral getGroupFuncName<spirv::GroupFMaxOp>() {
1157 return "_Z17__spirv_GroupFMaxii";
1158}
1159template <>
1160constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIAddOp>() {
1161 return "_Z27__spirv_GroupNonUniformIAddii";
1162}
1163template <>
1164constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFAddOp>() {
1165 return "_Z27__spirv_GroupNonUniformFAddii";
1166}
1167template <>
1168constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIMulOp>() {
1169 return "_Z27__spirv_GroupNonUniformIMulii";
1170}
1171template <>
1172constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMulOp>() {
1173 return "_Z27__spirv_GroupNonUniformFMulii";
1174}
1175template <>
1176constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMinOp>() {
1177 return "_Z27__spirv_GroupNonUniformSMinii";
1178}
1179template <>
1180constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMinOp>() {
1181 return "_Z27__spirv_GroupNonUniformUMinii";
1182}
1183template <>
1184constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMinOp>() {
1185 return "_Z27__spirv_GroupNonUniformFMinii";
1186}
1187template <>
1188constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMaxOp>() {
1189 return "_Z27__spirv_GroupNonUniformSMaxii";
1190}
1191template <>
1192constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMaxOp>() {
1193 return "_Z27__spirv_GroupNonUniformUMaxii";
1194}
1195template <>
1196constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMaxOp>() {
1197 return "_Z27__spirv_GroupNonUniformFMaxii";
1198}
1199template <>
1200constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseAndOp>() {
1201 return "_Z33__spirv_GroupNonUniformBitwiseAndii";
1202}
1203template <>
1204constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseOrOp>() {
1205 return "_Z32__spirv_GroupNonUniformBitwiseOrii";
1206}
1207template <>
1208constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseXorOp>() {
1209 return "_Z33__spirv_GroupNonUniformBitwiseXorii";
1210}
1211template <>
1212constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalAndOp>() {
1213 return "_Z33__spirv_GroupNonUniformLogicalAndii";
1214}
1215template <>
1216constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalOrOp>() {
1217 return "_Z32__spirv_GroupNonUniformLogicalOrii";
1218}
1219template <>
1220constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalXorOp>() {
1221 return "_Z33__spirv_GroupNonUniformLogicalXorii";
1222}
1223} // namespace
1224
1225template <typename ReduceOp, bool Signed = false, bool NonUniform = false>
1226class GroupReducePattern : public SPIRVToLLVMConversion<ReduceOp> {
1227public:
1228 using SPIRVToLLVMConversion<ReduceOp>::SPIRVToLLVMConversion;
1229
1230 LogicalResult
1231 matchAndRewrite(ReduceOp op, typename ReduceOp::Adaptor adaptor,
1232 ConversionPatternRewriter &rewriter) const override {
1233
1234 Type retTy = op.getResult().getType();
1235 if (!retTy.isIntOrFloat()) {
1236 return failure();
1237 }
1238 SmallString<36> funcName = getGroupFuncName<ReduceOp>();
1239 funcName += getTypeMangling(retTy, false);
1240
1241 Type i32Ty = rewriter.getI32Type();
1242 SmallVector<Type> paramTypes{i32Ty, i32Ty, retTy};
1243 if constexpr (NonUniform) {
1244 if (adaptor.getClusterSize()) {
1245 funcName += "j";
1246 paramTypes.push_back(i32Ty);
1247 }
1248 }
1249
1250 Operation *symbolTable =
1251 op->template getParentWithTrait<OpTrait::SymbolTable>();
1252
1253 LLVM::LLVMFuncOp func =
1254 lookupOrCreateSPIRVFn(symbolTable, funcName, paramTypes, retTy);
1255
1256 Location loc = op.getLoc();
1257 Value scope = LLVM::ConstantOp::create(
1258 rewriter, loc, i32Ty,
1259 static_cast<int32_t>(adaptor.getExecutionScope()));
1260 Value groupOp = LLVM::ConstantOp::create(
1261 rewriter, loc, i32Ty,
1262 static_cast<int32_t>(adaptor.getGroupOperation()));
1263 SmallVector<Value> operands{scope, groupOp};
1264 operands.append(adaptor.getOperands().begin(), adaptor.getOperands().end());
1265
1266 auto call = createSPIRVBuiltinCall(loc, rewriter, func, operands);
1267 rewriter.replaceOp(op, call);
1268 return success();
1269 }
1270};
1271
1272template <>
1273constexpr StringRef
1274ControlBarrierPattern<spirv::ControlBarrierOp>::getFuncName() {
1275 return "_Z22__spirv_ControlBarrieriii";
1276}
1277
1278template <>
1279constexpr StringRef
1280ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>::getFuncName() {
1281 return "_Z33__spirv_ControlBarrierArriveINTELiii";
1282}
1283
1284template <>
1285constexpr StringRef
1286ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>::getFuncName() {
1287 return "_Z31__spirv_ControlBarrierWaitINTELiii";
1288}
1289
1290/// Converts `spirv.mlir.loop` to LLVM dialect. All blocks within selection
1291/// should be reachable for conversion to succeed. The structure of the loop in
1292/// LLVM dialect will be the following:
1293///
1294/// +------------------------------------+
1295/// | <code before spirv.mlir.loop> |
1296/// | llvm.br ^header |
1297/// +------------------------------------+
1298/// |
1299/// +----------------+ |
1300/// | | |
1301/// | V V
1302/// | +------------------------------------+
1303/// | | ^header: |
1304/// | | <header code> |
1305/// | | llvm.cond_br %cond, ^body, ^exit |
1306/// | +------------------------------------+
1307/// | |
1308/// | |----------------------+
1309/// | | |
1310/// | V |
1311/// | +------------------------------------+ |
1312/// | | ^body: | |
1313/// | | <body code> | |
1314/// | | llvm.br ^continue | |
1315/// | +------------------------------------+ |
1316/// | | |
1317/// | V |
1318/// | +------------------------------------+ |
1319/// | | ^continue: | |
1320/// | | <continue code> | |
1321/// | | llvm.br ^header | |
1322/// | +------------------------------------+ |
1323/// | | |
1324/// +---------------+ +----------------------+
1325/// |
1326/// V
1327/// +------------------------------------+
1328/// | ^exit: |
1329/// | llvm.br ^remaining |
1330/// +------------------------------------+
1331/// |
1332/// V
1333/// +------------------------------------+
1334/// | ^remaining: |
1335/// | <code after spirv.mlir.loop> |
1336/// +------------------------------------+
1337///
1338class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> {
1339public:
1340 using SPIRVToLLVMConversion<spirv::LoopOp>::SPIRVToLLVMConversion;
1341
1342 LogicalResult
1343 matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,
1344 ConversionPatternRewriter &rewriter) const override {
1345 // There is no support of loop control at the moment.
1346 if (loopOp.getLoopControl() != spirv::LoopControl::None)
1347 return failure();
1348
1349 // `spirv.mlir.loop` with empty region is redundant and should be erased.
1350 if (loopOp.getBody().empty()) {
1351 rewriter.eraseOp(loopOp);
1352 return success();
1353 }
1354
1355 Location loc = loopOp.getLoc();
1356
1357 // Split the current block after `spirv.mlir.loop`. The remaining ops will
1358 // be used in `endBlock`.
1359 Block *currentBlock = rewriter.getBlock();
1360 auto position = Block::iterator(loopOp);
1361 Block *endBlock = rewriter.splitBlock(currentBlock, position);
1362
1363 // Remove entry block and create a branch in the current block going to the
1364 // header block.
1365 Block *entryBlock = loopOp.getEntryBlock();
1366 assert(entryBlock->getOperations().size() == 1);
1367 auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->getOperations().front());
1368 if (!brOp)
1369 return failure();
1370 Block *headerBlock = loopOp.getHeaderBlock();
1371 rewriter.setInsertionPointToEnd(currentBlock);
1372 LLVM::BrOp::create(rewriter, loc, brOp.getBlockArguments(), headerBlock);
1373 rewriter.eraseBlock(entryBlock);
1374
1375 // Branch from merge block to end block.
1376 Block *mergeBlock = loopOp.getMergeBlock();
1377 Operation *terminator = mergeBlock->getTerminator();
1378 ValueRange terminatorOperands = terminator->getOperands();
1379 rewriter.setInsertionPointToEnd(mergeBlock);
1380 LLVM::BrOp::create(rewriter, loc, terminatorOperands, endBlock);
1381
1382 rewriter.inlineRegionBefore(loopOp.getBody(), endBlock);
1383 rewriter.replaceOp(loopOp, endBlock->getArguments());
1384 return success();
1385 }
1386};
1387
1388/// Converts `spirv.mlir.selection` with `spirv.BranchConditional` in its header
1389/// block. All blocks within selection should be reachable for conversion to
1390/// succeed.
1391class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> {
1392public:
1393 using SPIRVToLLVMConversion<spirv::SelectionOp>::SPIRVToLLVMConversion;
1394
1395 LogicalResult
1396 matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor,
1397 ConversionPatternRewriter &rewriter) const override {
1398 // There is no support for `Flatten` or `DontFlatten` selection control at
1399 // the moment. This are just compiler hints and can be performed during the
1400 // optimization passes.
1401 if (op.getSelectionControl() != spirv::SelectionControl::None)
1402 return failure();
1403
1404 // `spirv.mlir.selection` should have at least two blocks: one selection
1405 // header block and one merge block. If no blocks are present, or control
1406 // flow branches straight to merge block (two blocks are present), the op is
1407 // redundant and it is erased.
1408 if (op.getBody().getBlocks().size() <= 2) {
1409 rewriter.eraseOp(op);
1410 return success();
1411 }
1412
1413 Location loc = op.getLoc();
1414
1415 // Split the current block after `spirv.mlir.selection`. The remaining ops
1416 // will be used in `continueBlock`.
1417 auto *currentBlock = rewriter.getInsertionBlock();
1418 rewriter.setInsertionPointAfter(op);
1419 auto position = rewriter.getInsertionPoint();
1420 auto *continueBlock = rewriter.splitBlock(currentBlock, position);
1421
1422 // Extract conditional branch information from the header block. By SPIR-V
1423 // dialect spec, it should contain `spirv.BranchConditional` or
1424 // `spirv.Switch` op. Note that `spirv.Switch op` is not supported at the
1425 // moment in the SPIR-V dialect. Remove this block when finished.
1426 auto *headerBlock = op.getHeaderBlock();
1427 assert(headerBlock->getOperations().size() == 1);
1428 auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
1429 headerBlock->getOperations().front());
1430 if (!condBrOp)
1431 return failure();
1432
1433 // Branch from merge block to continue block.
1434 auto *mergeBlock = op.getMergeBlock();
1435 Operation *terminator = mergeBlock->getTerminator();
1436 ValueRange terminatorOperands = terminator->getOperands();
1437 rewriter.setInsertionPointToEnd(mergeBlock);
1438 LLVM::BrOp::create(rewriter, loc, terminatorOperands, continueBlock);
1439
1440 // Link current block to `true` and `false` blocks within the selection.
1441 Block *trueBlock = condBrOp.getTrueBlock();
1442 Block *falseBlock = condBrOp.getFalseBlock();
1443 rewriter.setInsertionPointToEnd(currentBlock);
1444 LLVM::CondBrOp::create(rewriter, loc, condBrOp.getCondition(), trueBlock,
1445 condBrOp.getTrueTargetOperands(), falseBlock,
1446 condBrOp.getFalseTargetOperands());
1447
1448 rewriter.eraseBlock(headerBlock);
1449 rewriter.inlineRegionBefore(op.getBody(), continueBlock);
1450 rewriter.replaceOp(op, continueBlock->getArguments());
1451 return success();
1452 }
1453};
1454
1455/// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect
1456/// puts a restriction on `Shift` and `Base` to have the same bit width,
1457/// `Shift` is zero or sign extended to match this specification. Cases when
1458/// `Shift` bit width > `Base` bit width are considered to be illegal.
1459template <typename SPIRVOp, typename LLVMOp>
1460class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
1461public:
1462 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
1463
1464 LogicalResult
1465 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
1466 ConversionPatternRewriter &rewriter) const override {
1467
1468 auto dstType = this->getTypeConverter()->convertType(op.getType());
1469 if (!dstType)
1470 return rewriter.notifyMatchFailure(op, "type conversion failed");
1471
1472 Type op1Type = op.getOperand1().getType();
1473 Type op2Type = op.getOperand2().getType();
1474
1475 if (op1Type == op2Type) {
1476 rewriter.template replaceOpWithNewOp<LLVMOp>(op, dstType,
1477 adaptor.getOperands());
1478 return success();
1479 }
1480
1481 std::optional<uint64_t> dstTypeWidth =
1483 std::optional<uint64_t> op2TypeWidth =
1485
1486 if (!dstTypeWidth || !op2TypeWidth)
1487 return failure();
1488
1489 Location loc = op.getLoc();
1490 Value extended;
1491 if (op2TypeWidth < dstTypeWidth) {
1492 if (isUnsignedIntegerOrVector(op2Type)) {
1493 extended =
1494 LLVM::ZExtOp::create(rewriter, loc, dstType, adaptor.getOperand2());
1495 } else {
1496 extended =
1497 LLVM::SExtOp::create(rewriter, loc, dstType, adaptor.getOperand2());
1498 }
1499 } else if (op2TypeWidth == dstTypeWidth) {
1500 extended = adaptor.getOperand2();
1501 } else {
1502 return failure();
1503 }
1504
1505 Value result =
1506 LLVMOp::create(rewriter, loc, dstType, adaptor.getOperand1(), extended);
1507 rewriter.replaceOp(op, result);
1508 return success();
1509 }
1510};
1511
1512class TanPattern : public SPIRVToLLVMConversion<spirv::GLTanOp> {
1513public:
1514 using SPIRVToLLVMConversion<spirv::GLTanOp>::SPIRVToLLVMConversion;
1515
1516 LogicalResult
1517 matchAndRewrite(spirv::GLTanOp tanOp, OpAdaptor adaptor,
1518 ConversionPatternRewriter &rewriter) const override {
1519 auto dstType = getTypeConverter()->convertType(tanOp.getType());
1520 if (!dstType)
1521 return rewriter.notifyMatchFailure(tanOp, "type conversion failed");
1522
1523 rewriter.replaceOpWithNewOp<LLVM::TanOp>(tanOp, dstType,
1524 adaptor.getOperands());
1525 return success();
1526 }
1527};
1528
1529class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> {
1530public:
1531 using SPIRVToLLVMConversion<spirv::GLTanhOp>::SPIRVToLLVMConversion;
1532
1533 LogicalResult
1534 matchAndRewrite(spirv::GLTanhOp tanhOp, OpAdaptor adaptor,
1535 ConversionPatternRewriter &rewriter) const override {
1536 auto srcType = tanhOp.getType();
1537 auto dstType = getTypeConverter()->convertType(srcType);
1538 if (!dstType)
1539 return rewriter.notifyMatchFailure(tanhOp, "type conversion failed");
1540
1541 rewriter.replaceOpWithNewOp<LLVM::TanhOp>(tanhOp, dstType,
1542 adaptor.getOperands());
1543 return success();
1544 }
1545};
1546
1547class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
1548public:
1549 using SPIRVToLLVMConversion<spirv::VariableOp>::SPIRVToLLVMConversion;
1550
1551 LogicalResult
1552 matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor,
1553 ConversionPatternRewriter &rewriter) const override {
1554 auto srcType = varOp.getType();
1555 // Initialization is supported for scalars and vectors only.
1556 auto pointerTo = cast<spirv::PointerType>(srcType).getPointeeType();
1557 auto init = varOp.getInitializer();
1558 if (init && !pointerTo.isIntOrFloat() && !isa<VectorType>(pointerTo))
1559 return failure();
1560
1561 auto dstType = getTypeConverter()->convertType(srcType);
1562 if (!dstType)
1563 return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1564
1565 Location loc = varOp.getLoc();
1566 Value size = createI32ConstantOf(loc, rewriter, 1);
1567 if (!init) {
1568 auto elementType = getTypeConverter()->convertType(pointerTo);
1569 if (!elementType)
1570 return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1571 rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, elementType,
1572 size);
1573 return success();
1574 }
1575 auto elementType = getTypeConverter()->convertType(pointerTo);
1576 if (!elementType)
1577 return rewriter.notifyMatchFailure(varOp, "type conversion failed");
1578 Value allocated =
1579 LLVM::AllocaOp::create(rewriter, loc, dstType, elementType, size);
1580 LLVM::StoreOp::create(rewriter, loc, adaptor.getInitializer(), allocated);
1581 rewriter.replaceOp(varOp, allocated);
1582 return success();
1583 }
1584};
1585
1586//===----------------------------------------------------------------------===//
1587// BitcastOp conversion
1588//===----------------------------------------------------------------------===//
1589
1590class BitcastConversionPattern
1591 : public SPIRVToLLVMConversion<spirv::BitcastOp> {
1592public:
1593 using SPIRVToLLVMConversion<spirv::BitcastOp>::SPIRVToLLVMConversion;
1594
1595 LogicalResult
1596 matchAndRewrite(spirv::BitcastOp bitcastOp, OpAdaptor adaptor,
1597 ConversionPatternRewriter &rewriter) const override {
1598 auto dstType = getTypeConverter()->convertType(bitcastOp.getType());
1599 if (!dstType)
1600 return rewriter.notifyMatchFailure(bitcastOp, "type conversion failed");
1601
1602 // LLVM's opaque pointers do not require bitcasts.
1603 if (isa<LLVM::LLVMPointerType>(dstType)) {
1604 rewriter.replaceOp(bitcastOp, adaptor.getOperand());
1605 return success();
1606 }
1607
1608 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
1609 bitcastOp, dstType, adaptor.getOperands(), bitcastOp->getAttrs());
1610 return success();
1611 }
1612};
1613
1614//===----------------------------------------------------------------------===//
1615// FuncOp conversion
1616//===----------------------------------------------------------------------===//
1617
1618class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
1619public:
1620 using SPIRVToLLVMConversion<spirv::FuncOp>::SPIRVToLLVMConversion;
1621
1622 LogicalResult
1623 matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
1624 ConversionPatternRewriter &rewriter) const override {
1625
1626 // Convert function signature. At the moment LLVMType converter is enough
1627 // for currently supported types.
1628 auto funcType = funcOp.getFunctionType();
1629 TypeConverter::SignatureConversion signatureConverter(
1630 funcType.getNumInputs());
1631 auto llvmType = static_cast<const LLVMTypeConverter *>(getTypeConverter())
1632 ->convertFunctionSignature(
1633 funcType, /*isVariadic=*/false,
1634 /*useBarePtrCallConv=*/false, signatureConverter);
1635 if (!llvmType)
1636 return failure();
1637
1638 // Create a new `LLVMFuncOp`
1639 Location loc = funcOp.getLoc();
1640 StringRef name = funcOp.getName();
1641 auto newFuncOp = LLVM::LLVMFuncOp::create(rewriter, loc, name, llvmType);
1642
1643 // Convert SPIR-V Function Control to equivalent LLVM function attribute
1644 MLIRContext *context = funcOp.getContext();
1645 switch (funcOp.getFunctionControl()) {
1646 case spirv::FunctionControl::Inline:
1647 newFuncOp.setAlwaysInline(true);
1648 break;
1649 case spirv::FunctionControl::DontInline:
1650 newFuncOp.setNoInline(true);
1651 break;
1652
1653#define DISPATCH(functionControl, llvmAttr) \
1654 case functionControl: \
1655 newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \
1656 break;
1657
1658 DISPATCH(spirv::FunctionControl::Pure,
1659 StringAttr::get(context, "readonly"));
1660 DISPATCH(spirv::FunctionControl::Const,
1661 StringAttr::get(context, "readnone"));
1662
1663#undef DISPATCH
1664
1665 // Default: if `spirv::FunctionControl::None`, then no attributes are
1666 // needed.
1667 default:
1668 break;
1669 }
1670
1671 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1672 newFuncOp.end());
1673 if (failed(rewriter.convertRegionTypes(
1674 &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter))) {
1675 return failure();
1676 }
1677 rewriter.eraseOp(funcOp);
1678 return success();
1679 }
1680};
1681
1682//===----------------------------------------------------------------------===//
1683// ModuleOp conversion
1684//===----------------------------------------------------------------------===//
1685
1686class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> {
1687public:
1688 using SPIRVToLLVMConversion<spirv::ModuleOp>::SPIRVToLLVMConversion;
1689
1690 LogicalResult
1691 matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor,
1692 ConversionPatternRewriter &rewriter) const override {
1693
1694 auto newModuleOp =
1695 ModuleOp::create(rewriter, spvModuleOp.getLoc(), spvModuleOp.getName());
1696 rewriter.inlineRegionBefore(spvModuleOp.getRegion(), newModuleOp.getBody());
1697
1698 // Remove the terminator block that was automatically added by builder
1699 rewriter.eraseBlock(&newModuleOp.getBodyRegion().back());
1700 rewriter.eraseOp(spvModuleOp);
1701 return success();
1702 }
1703};
1704
1705//===----------------------------------------------------------------------===//
1706// VectorShuffleOp conversion
1707//===----------------------------------------------------------------------===//
1708
1709class VectorShufflePattern
1710 : public SPIRVToLLVMConversion<spirv::VectorShuffleOp> {
1711public:
1712 using SPIRVToLLVMConversion<spirv::VectorShuffleOp>::SPIRVToLLVMConversion;
1713 LogicalResult
1714 matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
1715 ConversionPatternRewriter &rewriter) const override {
1716 Location loc = op.getLoc();
1717 auto components = adaptor.getComponents();
1718 auto vector1 = adaptor.getVector1();
1719 auto vector2 = adaptor.getVector2();
1720 int vector1Size = cast<VectorType>(vector1.getType()).getNumElements();
1721 int vector2Size = cast<VectorType>(vector2.getType()).getNumElements();
1722 if (vector1Size == vector2Size) {
1723 rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
1724 op, vector1, vector2,
1726 return success();
1727 }
1728
1729 auto dstType = getTypeConverter()->convertType(op.getType());
1730 if (!dstType)
1731 return rewriter.notifyMatchFailure(op, "type conversion failed");
1732 auto scalarType = cast<VectorType>(dstType).getElementType();
1733 auto componentsArray = components.getValue();
1734 auto *context = rewriter.getContext();
1735 auto llvmI32Type = IntegerType::get(context, 32);
1736 Value targetOp = LLVM::PoisonOp::create(rewriter, loc, dstType);
1737 for (unsigned i = 0; i < componentsArray.size(); i++) {
1738 if (!isa<IntegerAttr>(componentsArray[i]))
1739 return op.emitError("unable to support non-constant component");
1740
1741 int indexVal = cast<IntegerAttr>(componentsArray[i]).getInt();
1742 if (indexVal == -1)
1743 continue;
1744
1745 int offsetVal = 0;
1746 Value baseVector = vector1;
1747 if (indexVal >= vector1Size) {
1748 offsetVal = vector1Size;
1749 baseVector = vector2;
1750 }
1751
1752 Value dstIndex = LLVM::ConstantOp::create(
1753 rewriter, loc, llvmI32Type,
1754 rewriter.getIntegerAttr(rewriter.getI32Type(), i));
1755 Value index = LLVM::ConstantOp::create(
1756 rewriter, loc, llvmI32Type,
1757 rewriter.getIntegerAttr(rewriter.getI32Type(), indexVal - offsetVal));
1758
1759 auto extractOp = LLVM::ExtractElementOp::create(rewriter, loc, scalarType,
1760 baseVector, index);
1761 targetOp = LLVM::InsertElementOp::create(rewriter, loc, dstType, targetOp,
1762 extractOp, dstIndex);
1763 }
1764 rewriter.replaceOp(op, targetOp);
1765 return success();
1766 }
1767};
1768} // namespace
1769
1770//===----------------------------------------------------------------------===//
1771// Pattern population
1772//===----------------------------------------------------------------------===//
1773
1775 spirv::ClientAPI clientAPI) {
1776 typeConverter.addConversion([&](spirv::ArrayType type) {
1777 return convertArrayType(type, typeConverter);
1778 });
1779 typeConverter.addConversion([&, clientAPI](spirv::PointerType type) {
1780 return convertPointerType(type, typeConverter, clientAPI);
1781 });
1782 typeConverter.addConversion([&](spirv::RuntimeArrayType type) {
1783 return convertRuntimeArrayType(type, typeConverter);
1784 });
1785 typeConverter.addConversion([&](spirv::StructType type) {
1786 return convertStructType(type, typeConverter);
1787 });
1788}
1789
1791 const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
1792 spirv::ClientAPI clientAPI) {
1793 patterns.add<
1794 // Arithmetic ops
1795 DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
1796 DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
1797 DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
1798 DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
1799 DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
1800 DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
1801 DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
1802 DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
1803 DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
1804 DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
1805 DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
1806 DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
1807 DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
1808
1809 // Bitwise ops
1810 BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
1811 DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
1812 DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
1813 DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
1814 DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
1815 DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
1816 NotPattern<spirv::NotOp>,
1817
1818 // Cast ops
1819 BitcastConversionPattern,
1820 DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
1821 DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
1822 DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
1823 DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
1824 IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
1825 IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
1826 IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
1827
1828 // Comparison ops
1829 IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
1830 IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
1831 FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
1832 FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
1833 FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
1834 FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
1835 FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
1836 FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
1837 FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
1838 FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
1839 FComparePattern<spirv::FUnordGreaterThanEqualOp,
1840 LLVM::FCmpPredicate::uge>,
1841 FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
1842 FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
1843 FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
1844 IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
1845 IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
1846 IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
1847 IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
1848 IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
1849 IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
1850 IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
1851 IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
1852
1853 // Constant op
1854 ConstantScalarAndVectorPattern,
1855
1856 // Control Flow ops
1857 BranchConversionPattern, BranchConditionalConversionPattern,
1858 FunctionCallPattern, LoopPattern, SelectionPattern,
1859 ErasePattern<spirv::MergeOp>,
1860
1861 // Entry points and execution mode are handled separately.
1862 ErasePattern<spirv::EntryPointOp>, ExecutionModePattern,
1863
1864 // GLSL extended instruction set ops
1865 DirectConversionPattern<spirv::GLCeilOp, LLVM::FCeilOp>,
1866 DirectConversionPattern<spirv::GLCosOp, LLVM::CosOp>,
1867 DirectConversionPattern<spirv::GLExpOp, LLVM::ExpOp>,
1868 DirectConversionPattern<spirv::GLFAbsOp, LLVM::FAbsOp>,
1869 DirectConversionPattern<spirv::GLFloorOp, LLVM::FFloorOp>,
1870 DirectConversionPattern<spirv::GLFMaxOp, LLVM::MaxNumOp>,
1871 DirectConversionPattern<spirv::GLFMinOp, LLVM::MinNumOp>,
1872 DirectConversionPattern<spirv::GLLogOp, LLVM::LogOp>,
1873 DirectConversionPattern<spirv::GLSinOp, LLVM::SinOp>,
1874 DirectConversionPattern<spirv::GLSMaxOp, LLVM::SMaxOp>,
1875 DirectConversionPattern<spirv::GLSMinOp, LLVM::SMinOp>,
1876 DirectConversionPattern<spirv::GLSqrtOp, LLVM::SqrtOp>,
1877 InverseSqrtPattern, TanPattern, TanhPattern,
1878
1879 // Logical ops
1880 DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
1881 DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
1882 IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
1883 IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
1884 NotPattern<spirv::LogicalNotOp>,
1885
1886 // Memory ops
1887 AccessChainPattern, AddressOfPattern, LoadStorePattern<spirv::LoadOp>,
1888 LoadStorePattern<spirv::StoreOp>, VariablePattern,
1889
1890 // Miscellaneous ops
1891 CompositeExtractPattern, CompositeInsertPattern,
1892 DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
1893 DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
1894 VectorShufflePattern,
1895
1896 // Shift ops
1897 ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
1898 ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
1899 ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
1900
1901 // Return ops
1902 ReturnPattern, ReturnValuePattern,
1903
1904 // Barrier ops
1905 ControlBarrierPattern<spirv::ControlBarrierOp>,
1906 ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>,
1907 ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>,
1908
1909 // Group reduction operations
1910 GroupReducePattern<spirv::GroupIAddOp>,
1911 GroupReducePattern<spirv::GroupFAddOp>,
1912 GroupReducePattern<spirv::GroupFMinOp>,
1913 GroupReducePattern<spirv::GroupUMinOp>,
1914 GroupReducePattern<spirv::GroupSMinOp, /*Signed=*/true>,
1915 GroupReducePattern<spirv::GroupFMaxOp>,
1916 GroupReducePattern<spirv::GroupUMaxOp>,
1917 GroupReducePattern<spirv::GroupSMaxOp, /*Signed=*/true>,
1918 GroupReducePattern<spirv::GroupNonUniformIAddOp, /*Signed=*/false,
1919 /*NonUniform=*/true>,
1920 GroupReducePattern<spirv::GroupNonUniformFAddOp, /*Signed=*/false,
1921 /*NonUniform=*/true>,
1922 GroupReducePattern<spirv::GroupNonUniformIMulOp, /*Signed=*/false,
1923 /*NonUniform=*/true>,
1924 GroupReducePattern<spirv::GroupNonUniformFMulOp, /*Signed=*/false,
1925 /*NonUniform=*/true>,
1926 GroupReducePattern<spirv::GroupNonUniformSMinOp, /*Signed=*/true,
1927 /*NonUniform=*/true>,
1928 GroupReducePattern<spirv::GroupNonUniformUMinOp, /*Signed=*/false,
1929 /*NonUniform=*/true>,
1930 GroupReducePattern<spirv::GroupNonUniformFMinOp, /*Signed=*/false,
1931 /*NonUniform=*/true>,
1932 GroupReducePattern<spirv::GroupNonUniformSMaxOp, /*Signed=*/true,
1933 /*NonUniform=*/true>,
1934 GroupReducePattern<spirv::GroupNonUniformUMaxOp, /*Signed=*/false,
1935 /*NonUniform=*/true>,
1936 GroupReducePattern<spirv::GroupNonUniformFMaxOp, /*Signed=*/false,
1937 /*NonUniform=*/true>,
1938 GroupReducePattern<spirv::GroupNonUniformBitwiseAndOp, /*Signed=*/false,
1939 /*NonUniform=*/true>,
1940 GroupReducePattern<spirv::GroupNonUniformBitwiseOrOp, /*Signed=*/false,
1941 /*NonUniform=*/true>,
1942 GroupReducePattern<spirv::GroupNonUniformBitwiseXorOp, /*Signed=*/false,
1943 /*NonUniform=*/true>,
1944 GroupReducePattern<spirv::GroupNonUniformLogicalAndOp, /*Signed=*/false,
1945 /*NonUniform=*/true>,
1946 GroupReducePattern<spirv::GroupNonUniformLogicalOrOp, /*Signed=*/false,
1947 /*NonUniform=*/true>,
1948 GroupReducePattern<spirv::GroupNonUniformLogicalXorOp, /*Signed=*/false,
1949 /*NonUniform=*/true>>(patterns.getContext(),
1950 typeConverter);
1951
1952 patterns.add<GlobalVariablePattern>(clientAPI, patterns.getContext(),
1953 typeConverter);
1954}
1955
1957 const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1958 patterns.add<FuncConversionPattern>(patterns.getContext(), typeConverter);
1959}
1960
1962 const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1963 patterns.add<ModuleConversionPattern>(patterns.getContext(), typeConverter);
1964}
1965
1966//===----------------------------------------------------------------------===//
1967// Pre-conversion hooks
1968//===----------------------------------------------------------------------===//
1969
1970/// Hook for descriptor set and binding number encoding.
1971static constexpr StringRef kBinding = "binding";
1972static constexpr StringRef kDescriptorSet = "descriptor_set";
1973void mlir::encodeBindAttribute(ModuleOp module) {
1974 auto spvModules = module.getOps<spirv::ModuleOp>();
1975 for (auto spvModule : spvModules) {
1976 spvModule.walk([&](spirv::GlobalVariableOp op) {
1977 IntegerAttr descriptorSet =
1978 op->getAttrOfType<IntegerAttr>(kDescriptorSet);
1979 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(kBinding);
1980 // For every global variable in the module, get the ones with descriptor
1981 // set and binding numbers.
1982 if (descriptorSet && binding) {
1983 // Encode these numbers into the variable's symbolic name. If the
1984 // SPIR-V module has a name, add it at the beginning.
1985 auto moduleAndName =
1986 spvModule.getName().has_value()
1987 ? spvModule.getName()->str() + "_" + op.getSymName().str()
1988 : op.getSymName().str();
1989 std::string name =
1990 llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName,
1991 std::to_string(descriptorSet.getInt()),
1992 std::to_string(binding.getInt()));
1993 auto nameAttr = StringAttr::get(op->getContext(), name);
1994
1995 // Replace all symbol uses and set the new symbol name. Finally, remove
1996 // descriptor set and binding attributes.
1997 if (failed(SymbolTable::replaceAllSymbolUses(op, nameAttr, spvModule)))
1998 op.emitError("unable to replace all symbol uses for ") << name;
1999 SymbolTable::setSymbolName(op, nameAttr);
2000 op->removeAttr(kDescriptorSet);
2001 op->removeAttr(kBinding);
2002 }
2003 });
2004 }
2005}
return success()
static LLVM::CallOp createSPIRVBuiltinCall(Location loc, ConversionPatternRewriter &rewriter, LLVM::LLVMFuncOp func, ValueRange args)
static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, StringRef name, ArrayRef< Type > paramTypes, Type resultType, bool isMemNone, bool isConvergent)
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
static Value optionallyTruncateOrExtend(Location loc, Value value, Type llvmType, PatternRewriter &rewriter)
Utility function for bitfield ops:
static Value createFPConstant(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter, double value)
Creates llvm.mlir.constant with a floating-point scalar or vector value.
static constexpr StringRef kDescriptorSet
static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, unsigned value)
Creates LLVM dialect constant with the given value.
static Type convertPointerType(spirv::PointerType type, const TypeConverter &converter, spirv::ClientAPI clientAPI)
Converts SPIR-V pointer type to LLVM pointer.
static Value processCountOrOffset(Location loc, Value value, Type srcType, Type dstType, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Utility function for bitfield ops: BitFieldInsert, BitFieldSExtract and BitFieldUExtract.
static unsigned getBitWidth(Type type)
Returns the bit width of integer, float or vector of float or integer values.
static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter, const TypeConverter &typeConverter, unsigned alignment, bool isVolatile, bool isNonTemporal)
Utility for spirv.Load and spirv.Store conversion.
static Type convertStructTypePacked(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct with no offset to packed LLVM struct.
static bool isSignedIntegerOrVector(Type type)
Returns true if the given type is a signed integer or vector type.
static bool isUnsignedIntegerOrVector(Type type)
Returns true if the given type is an unsigned integer or vector type.
static std::optional< Type > convertRuntimeArrayType(spirv::RuntimeArrayType type, TypeConverter &converter)
Converts SPIR-V runtime array to LLVM array.
static constexpr StringRef kBinding
Hook for descriptor set and binding number encoding.
static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder)
Creates IntegerAttribute with all bits set for given type.
static Value optionallyBroadcast(Location loc, Value value, Type srcType, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value. If srcType is a scalar, the value remains unchanged.
static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter)
Creates llvm.mlir.constant with all bits set for the given type.
static unsigned getLLVMTypeBitWidth(Type type)
Returns the bit width of LLVMType integer or vector.
static std::optional< uint64_t > getIntegerOrVectorElementWidth(Type type)
Returns the width of an integer or of the element type of an integer vector, if applicable.
#define DISPATCH(functionControl, llvmAttr)
static Type convertStructTypeWithOffset(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct with a regular (according to VulkanLayoutUtils) offset to LLVM struct.
static Type convertStructType(spirv::StructType type, const TypeConverter &converter)
Converts SPIR-V struct to LLVM struct.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static std::optional< Type > convertArrayType(spirv::ArrayType type, TypeConverter &converter)
Converts SPIR-V array type to LLVM array.
OpListType::iterator iterator
Definition Block.h:140
OpListType & getOperations()
Definition Block.h:137
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:244
BlockArgListType getArguments()
Definition Block.h:87
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:228
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:254
IntegerType getI32Type()
Definition Builders.cpp:63
MLIRContext * getContext() const
Definition Builders.h:56
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:207
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:686
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition Types.cpp:76
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition Types.cpp:88
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:116
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static spirv::StructType decorateType(spirv::StructType structType)
Returns a new StructType with layout decoration.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Type getElementType() const
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getNumElements() const
StorageClass getStorageClass() const
unsigned getArrayStride() const
Returns the array stride in bytes.
SPIR-V struct type.
Definition SPIRVTypes.h:251
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
TypeRange getElementTypes() const
SmallVector< IntT > convertArrayToIndices(ArrayRef< Attribute > attrs)
Convert an array of integer attributes to a vector of integers that can be used as indices in LLVM op...
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
unsigned storageClassToAddressSpace(spirv::ClientAPI clientAPI, spirv::StorageClass storageClass)
void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates type conversions with additional SPIR-V types.
void populateSPIRVToLLVMFunctionConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given list with patterns for function conversion from SPIR-V to LLVM.
const FrozenRewritePatternSet & patterns
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
void populateSPIRVToLLVMConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, spirv::ClientAPI clientAPIForAddressSpaceMapping=spirv::ClientAPI::Unknown)
Populates the given list with patterns that convert from SPIR-V to LLVM.
void encodeBindAttribute(ModuleOp module)
Encodes global variable's descriptor set and binding into its name if they both exist.
void populateSPIRVToLLVMModuleConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given patterns for module conversion from SPIR-V to LLVM.