MLIR  17.0.0git
SPIRVToLLVM.cpp
Go to the documentation of this file.
1 //===- SPIRVToLLVM.cpp - SPIR-V to LLVM Patterns --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert SPIR-V dialect to LLVM dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
21 #include "mlir/IR/BuiltinOps.h"
22 #include "mlir/IR/PatternMatch.h"
25 #include "llvm/Support/Debug.h"
26 #include "llvm/Support/FormatVariadic.h"
27 
28 #define DEBUG_TYPE "spirv-to-llvm-pattern"
29 
30 using namespace mlir;
31 
32 //===----------------------------------------------------------------------===//
33 // Utility functions
34 //===----------------------------------------------------------------------===//
35 
36 /// Returns true if the given type is a signed integer or vector type.
37 static bool isSignedIntegerOrVector(Type type) {
38  if (type.isSignedInteger())
39  return true;
40  if (auto vecType = type.dyn_cast<VectorType>())
41  return vecType.getElementType().isSignedInteger();
42  return false;
43 }
44 
45 /// Returns true if the given type is an unsigned integer or vector type
46 static bool isUnsignedIntegerOrVector(Type type) {
47  if (type.isUnsignedInteger())
48  return true;
49  if (auto vecType = type.dyn_cast<VectorType>())
50  return vecType.getElementType().isUnsignedInteger();
51  return false;
52 }
53 
54 /// Returns the bit width of integer, float or vector of float or integer values
55 static unsigned getBitWidth(Type type) {
56  assert((type.isIntOrFloat() || type.isa<VectorType>()) &&
57  "bitwidth is not supported for this type");
58  if (type.isIntOrFloat())
59  return type.getIntOrFloatBitWidth();
60  auto vecType = type.dyn_cast<VectorType>();
61  auto elementType = vecType.getElementType();
62  assert(elementType.isIntOrFloat() &&
63  "only integers and floats have a bitwidth");
64  return elementType.getIntOrFloatBitWidth();
65 }
66 
67 /// Returns the bit width of LLVMType integer or vector.
/// The value selected by the conditional expression below (either some
/// derived type or `type` itself) is cast to `IntegerType`, so the selected
/// type must be an integer; presumably the condition picks the vector's
/// element type for vectors — TODO confirm once the missing line is restored.
/// NOTE(review): the opening line of the return expression (the condition of
/// the `?:` ending in `: type)`) is missing from this copy of the file.
68 static unsigned getLLVMTypeBitWidth(Type type) {
70  : type)
71  .cast<IntegerType>()
72  .getWidth();
73 }
74 
75 /// Creates `IntegerAttribute` with all bits set for given type
76 static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) {
77  if (auto vecType = type.dyn_cast<VectorType>()) {
78  auto integerType = vecType.getElementType().cast<IntegerType>();
79  return builder.getIntegerAttr(integerType, -1);
80  }
81  auto integerType = type.cast<IntegerType>();
82  return builder.getIntegerAttr(integerType, -1);
83 }
84 
85 /// Creates `llvm.mlir.constant` with all bits set for the given type.
86 static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
87  PatternRewriter &rewriter) {
88  if (srcType.isa<VectorType>()) {
89  return rewriter.create<LLVM::ConstantOp>(
90  loc, dstType,
91  SplatElementsAttr::get(srcType.cast<ShapedType>(),
92  minusOneIntegerAttribute(srcType, rewriter)));
93  }
94  return rewriter.create<LLVM::ConstantOp>(
95  loc, dstType, minusOneIntegerAttribute(srcType, rewriter));
96 }
97 
98 /// Creates `llvm.mlir.constant` with a floating-point scalar or vector value.
99 static Value createFPConstant(Location loc, Type srcType, Type dstType,
100  PatternRewriter &rewriter, double value) {
101  if (auto vecType = srcType.dyn_cast<VectorType>()) {
102  auto floatType = vecType.getElementType().cast<FloatType>();
103  return rewriter.create<LLVM::ConstantOp>(
104  loc, dstType,
105  SplatElementsAttr::get(vecType,
106  rewriter.getFloatAttr(floatType, value)));
107  }
108  auto floatType = srcType.cast<FloatType>();
109  return rewriter.create<LLVM::ConstantOp>(
110  loc, dstType, rewriter.getFloatAttr(floatType, value));
111 }
112 
113 /// Utility function for bitfield ops:
114 /// - `BitFieldInsert`
115 /// - `BitFieldSExtract`
116 /// - `BitFieldUExtract`
117 /// Truncates or extends the value. If the bitwidth of the value is the same as
118 /// `llvmType` bitwidth, the value remains unchanged.
/// Widening uses `llvm.zext` (zero extension); narrowing uses `llvm.trunc`.
/// The width of `value` is read with `getLLVMTypeBitWidth` when its type is
/// already LLVM-compatible, and with `getBitWidth` (SPIR-V type) otherwise.
/// NOTE(review): the line carrying the function name and its leading
/// parameters is missing from this copy; callers use
/// `optionallyTruncateOrExtend(loc, value, llvmType, rewriter)`.
120  Type llvmType,
121  PatternRewriter &rewriter) {
122  auto srcType = value.getType();
123  unsigned targetBitWidth = getLLVMTypeBitWidth(llvmType);
124  unsigned valueBitWidth = LLVM::isCompatibleType(srcType)
125  ? getLLVMTypeBitWidth(srcType)
126  : getBitWidth(srcType);
127 
128  if (valueBitWidth < targetBitWidth)
129  return rewriter.create<LLVM::ZExtOp>(loc, llvmType, value);
130  // If the bit widths of `Count` and `Offset` are greater than the bit width
131  // of the target type, they are truncated. Truncation is safe since `Count`
132  // and `Offset` must be no more than 64 for op behaviour to be defined. Hence,
133  // both values can be expressed in 8 bits.
134  if (valueBitWidth > targetBitWidth)
135  return rewriter.create<LLVM::TruncOp>(loc, llvmType, value);
136  return value;
137 }
138 
139 /// Broadcasts the value to vector with `numElements` number of elements.
140 static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
141  LLVMTypeConverter &typeConverter,
142  ConversionPatternRewriter &rewriter) {
143  auto vectorType = VectorType::get(numElements, toBroadcast.getType());
144  auto llvmVectorType = typeConverter.convertType(vectorType);
145  auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
146  Value broadcasted = rewriter.create<LLVM::UndefOp>(loc, llvmVectorType);
147  for (unsigned i = 0; i < numElements; ++i) {
148  auto index = rewriter.create<LLVM::ConstantOp>(
149  loc, llvmI32Type, rewriter.getI32IntegerAttr(i));
150  broadcasted = rewriter.create<LLVM::InsertElementOp>(
151  loc, llvmVectorType, broadcasted, toBroadcast, index);
152  }
153  return broadcasted;
154 }
155 
156 /// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged.
157 static Value optionallyBroadcast(Location loc, Value value, Type srcType,
158  LLVMTypeConverter &typeConverter,
159  ConversionPatternRewriter &rewriter) {
160  if (auto vectorType = srcType.dyn_cast<VectorType>()) {
161  unsigned numElements = vectorType.getNumElements();
162  return broadcast(loc, value, numElements, typeConverter, rewriter);
163  }
164  return value;
165 }
166 
167 /// Utility function for bitfield ops: `BitFieldInsert`, `BitFieldSExtract` and
168 /// `BitFieldUExtract`.
169 /// Broadcast `Offset` and `Count` to match the type of `Base`. If `Base` is of
170 /// a vector type, construct a vector that has:
171 /// - same number of elements as `Base`
172 /// - each element has the type that is the same as the type of `Offset` or
173 /// `Count`
174 /// - each element has the same value as `Offset` or `Count`
175 /// Then cast `Offset` and `Count` if their bit width is different
176 /// from `Base` bit width.
177 static Value processCountOrOffset(Location loc, Value value, Type srcType,
178  Type dstType, LLVMTypeConverter &converter,
179  ConversionPatternRewriter &rewriter) {
180  Value broadcasted =
181  optionallyBroadcast(loc, value, srcType, converter, rewriter);
182  return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter);
183 }
184 
185 /// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`)
186 /// offset to LLVM struct. Otherwise, the conversion is not supported.
/// "Regular" here means `type` is unchanged by
/// `VulkanLayoutUtils::decorateType`. Any other struct yields std::nullopt.
/// Member types are converted one by one and assembled into a non-packed
/// literal LLVM struct.
/// NOTE(review): a member whose conversion fails produces a null type that is
/// passed to `getLiteral` unchecked.
/// NOTE(review): the line with the function name and first parameter is
/// missing from this copy; callers use
/// `convertStructTypeWithOffset(type, converter)` with a `spirv::StructType`.
187 static std::optional<Type>
189  LLVMTypeConverter &converter) {
190  if (type != VulkanLayoutUtils::decorateType(type))
191  return std::nullopt;
192 
193  auto elementsVector = llvm::to_vector<8>(
194  llvm::map_range(type.getElementTypes(), [&](Type elementType) {
195  return converter.convertType(elementType);
196  }));
197  return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
198  /*isPacked=*/false);
199 }
200 
201 /// Converts SPIR-V struct with no offset to packed LLVM struct.
/// Each member type is converted with `converter`, and the results form a
/// literal LLVM struct marked `isPacked=true`.
/// NOTE(review): member conversions are not null-checked before `getLiteral`.
/// NOTE(review): the line with the function's return type, name and first
/// parameter is missing from this copy; callers use
/// `convertStructTypePacked(type, converter)` with a `spirv::StructType`.
203  LLVMTypeConverter &converter) {
204  auto elementsVector = llvm::to_vector<8>(
205  llvm::map_range(type.getElementTypes(), [&](Type elementType) {
206  return converter.convertType(elementType);
207  }));
208  return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
209  /*isPacked=*/true);
210 }
211 
212 /// Creates LLVM dialect constant with the given value.
/// The constant has result type i32 and an i32 `IntegerAttr` holding `value`.
/// `value` is `unsigned`, so a wider argument (e.g. an `int64_t` from
/// `IntegerAttr::getInt()`) is implicitly narrowed at the call site.
/// NOTE(review): the line with the function name and leading parameters is
/// missing from this copy; callers use `createI32ConstantOf(loc, rewriter, v)`.
214  unsigned value) {
215  return rewriter.create<LLVM::ConstantOp>(
216  loc, IntegerType::get(rewriter.getContext(), 32),
217  rewriter.getIntegerAttr(rewriter.getI32Type(), value));
218 }
219 
220 /// Utility for `spirv.Load` and `spirv.Store` conversion.
/// If `op` is a `spirv::LoadOp`, it is replaced by `llvm.load` of the
/// converted result type; failure is returned if that type cannot be
/// converted. Otherwise `op` is cast to `spirv::StoreOp` (asserts if it is
/// neither) and replaced by `llvm.store`. In both cases the pointer/value
/// operands are read through an adaptor over `operands`, and `alignment`,
/// `isVolatile` and `isNonTemporal` are forwarded unchanged.
/// NOTE(review): the line with the return type, function name and leading
/// parameters (`op`, `operands` as used below) is missing from this copy.
222  ConversionPatternRewriter &rewriter,
223  LLVMTypeConverter &typeConverter,
224  unsigned alignment, bool isVolatile,
225  bool isNonTemporal) {
226  if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
227  auto dstType = typeConverter.convertType(loadOp.getType());
228  if (!dstType)
229  return failure();
230  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
231  loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment,
232  isVolatile, isNonTemporal);
233  return success();
234  }
235  auto storeOp = cast<spirv::StoreOp>(op);
236  spirv::StoreOpAdaptor adaptor(operands);
237  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValue(),
238  adaptor.getPtr(), alignment,
239  isVolatile, isNonTemporal);
240  return success();
241 }
242 
243 //===----------------------------------------------------------------------===//
244 // Type conversion
245 //===----------------------------------------------------------------------===//
246 
247 /// Converts SPIR-V array type to LLVM array. Natural stride (according to
248 /// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected
249 /// when converting ops that manipulate array types.
250 static std::optional<Type> convertArrayType(spirv::ArrayType type,
251  TypeConverter &converter) {
252  unsigned stride = type.getArrayStride();
253  Type elementType = type.getElementType();
254  auto sizeInBytes = elementType.cast<spirv::SPIRVType>().getSizeInBytes();
255  if (stride != 0 && (!sizeInBytes || *sizeInBytes != stride))
256  return std::nullopt;
257 
258  auto llvmElementType = converter.convertType(elementType);
259  unsigned numElements = type.getNumElements();
260  return LLVM::LLVMArrayType::get(llvmElementType, numElements);
261 }
262 
263 /// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not
264 /// modelled at the moment.
/// The pointee type is converted first and handed to
/// `converter.getPointerType`.
/// NOTE(review): a failed pointee conversion (null type) is not checked here;
/// whether `getPointerType` tolerates it depends on the converter — confirm.
/// NOTE(review): the line with the return type, function name and first
/// parameter is missing from this copy of the file.
266  LLVMTypeConverter &converter) {
267  auto pointeeType = converter.convertType(type.getPointeeType());
268  return converter.getPointerType(pointeeType);
269 }
270 
271 /// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over
272 /// the bounds, the runtime array is converted to a 0-sized LLVM array. There is
273 /// no modelling of array stride at the moment.
274 static std::optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type,
275  TypeConverter &converter) {
276  if (type.getArrayStride() != 0)
277  return std::nullopt;
278  auto elementType = converter.convertType(type.getElementType());
279  return LLVM::LLVMArrayType::get(elementType, 0);
280 }
281 
282 /// Converts SPIR-V struct to LLVM struct. There is no support of structs with
283 /// member decorations. Also, only natural offset is supported.
/// Any member decoration yields std::nullopt. Structs with offsets go through
/// `convertStructTypeWithOffset`; structs without offsets become packed LLVM
/// structs via `convertStructTypePacked`.
/// NOTE(review): the declaration of the `memberDecorations` container filled
/// below is missing from this copy of the file.
284 static std::optional<Type> convertStructType(spirv::StructType type,
285  LLVMTypeConverter &converter) {
287  type.getMemberDecorations(memberDecorations);
288  if (!memberDecorations.empty())
289  return std::nullopt;
290  if (type.hasOffset())
291  return convertStructTypeWithOffset(type, converter);
292  return convertStructTypePacked(type, converter);
293 }
294 
295 //===----------------------------------------------------------------------===//
296 // Operation conversion
297 //===----------------------------------------------------------------------===//
298 
299 namespace {
300 
301 class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
302 public:
304 
306  matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
307  ConversionPatternRewriter &rewriter) const override {
308  auto dstType = typeConverter.convertType(op.getComponentPtr().getType());
309  if (!dstType)
310  return failure();
311  // To use GEP we need to add a first 0 index to go through the pointer.
312  auto indices = llvm::to_vector<4>(adaptor.getIndices());
313  Type indexType = op.getIndices().front().getType();
314  auto llvmIndexType = typeConverter.convertType(indexType);
315  if (!llvmIndexType)
316  return failure();
317  Value zero = rewriter.create<LLVM::ConstantOp>(
318  op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
319  indices.insert(indices.begin(), zero);
320  rewriter.replaceOpWithNewOp<LLVM::GEPOp>(
321  op, dstType,
322  typeConverter.convertType(op.getBasePtr()
323  .getType()
324  .cast<spirv::PointerType>()
325  .getPointeeType()),
326  adaptor.getBasePtr(), indices);
327  return success();
328  }
329 };
330 
/// Converts `spirv.mlir.addressof` to `llvm.mlir.addressof` on the same
/// global symbol (`op.getVariable()`), using the converted pointer type as
/// the result type. Fails if that type cannot be converted.
/// NOTE(review): the inherited-constructor `using` declaration and the
/// `LogicalResult` return type of `matchAndRewrite` appear to be missing from
/// this copy of the file.
331 class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {
332 public:
336  matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
337  ConversionPatternRewriter &rewriter) const override {
338  auto dstType = typeConverter.convertType(op.getPointer().getType());
339  if (!dstType)
340  return failure();
341  rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType, op.getVariable());
342  return success();
343  }
344 };
345 
346 class BitFieldInsertPattern
347  : public SPIRVToLLVMConversion<spirv::BitFieldInsertOp> {
348 public:
350 
352  matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor,
353  ConversionPatternRewriter &rewriter) const override {
354  auto srcType = op.getType();
355  auto dstType = typeConverter.convertType(srcType);
356  if (!dstType)
357  return failure();
358  Location loc = op.getLoc();
359 
360  // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
361  Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
362  typeConverter, rewriter);
363  Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
364  typeConverter, rewriter);
365 
366  // Create a mask with bits set outside [Offset, Offset + Count - 1].
367  Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
368  Value maskShiftedByCount =
369  rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
370  Value negated = rewriter.create<LLVM::XOrOp>(loc, dstType,
371  maskShiftedByCount, minusOne);
372  Value maskShiftedByCountAndOffset =
373  rewriter.create<LLVM::ShlOp>(loc, dstType, negated, offset);
374  Value mask = rewriter.create<LLVM::XOrOp>(
375  loc, dstType, maskShiftedByCountAndOffset, minusOne);
376 
377  // Extract unchanged bits from the `Base` that are outside of
378  // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`.
379  Value baseAndMask =
380  rewriter.create<LLVM::AndOp>(loc, dstType, op.getBase(), mask);
381  Value insertShiftedByOffset =
382  rewriter.create<LLVM::ShlOp>(loc, dstType, op.getInsert(), offset);
383  rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask,
384  insertShiftedByOffset);
385  return success();
386  }
387 };
388 
389 /// Converts SPIR-V ConstantOp with scalar or vector type.
/// Non-scalar, non-vector constants (e.g. arrays or structs) are rejected so
/// that other patterns can handle them. Signed/unsigned integer constants are
/// re-typed to a signless integer of the same bit width. All other constants
/// are rebuilt as `llvm.mlir.constant` with the original attributes.
/// NOTE(review): the inherited-constructor `using` declaration and the
/// `LogicalResult` return type of `matchAndRewrite` appear to be missing from
/// this copy of the file.
390 class ConstantScalarAndVectorPattern
391  : public SPIRVToLLVMConversion<spirv::ConstantOp> {
392 public:
394 
396  matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor,
397  ConversionPatternRewriter &rewriter) const override {
398  auto srcType = constOp.getType();
399  if (!srcType.isa<VectorType>() && !srcType.isIntOrFloat())
400  return failure();
401 
402  auto dstType = typeConverter.convertType(srcType);
403  if (!dstType)
404  return failure();
405 
406  // SPIR-V constant can be a signed/unsigned integer, which has to be
407  // casted to signless integer when converting to LLVM dialect. Removing the
408  // sign bit may have unexpected behaviour. However, it is better to handle
409  // it case-by-case, given that the purpose of the conversion is not to
410  // cover all possible corner cases.
    // (The bit pattern itself is kept as is; only the signedness of the
    // attribute's type is dropped.)
411  if (isSignedIntegerOrVector(srcType) ||
412  isUnsignedIntegerOrVector(srcType)) {
413  auto signlessType = rewriter.getIntegerType(getBitWidth(srcType));
414 
415  if (srcType.isa<VectorType>()) {
    // Identity mapping: copies each element's APInt into an elements
    // attribute whose element type is `signlessType`.
416  auto dstElementsAttr = constOp.getValue().cast<DenseIntElementsAttr>();
417  rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
418  constOp, dstType,
419  dstElementsAttr.mapValues(
420  signlessType, [&](const APInt &value) { return value; }));
421  return success();
422  }
423  auto srcAttr = constOp.getValue().cast<IntegerAttr>();
424  auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue());
425  rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr);
426  return success();
427  }
    // Signless integers and floats: reuse the op's attributes directly.
428  rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
429  constOp, dstType, adaptor.getOperands(), constOp->getAttrs());
430  return success();
431  }
432 };
433 
434 class BitFieldSExtractPattern
435  : public SPIRVToLLVMConversion<spirv::BitFieldSExtractOp> {
436 public:
438 
440  matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor,
441  ConversionPatternRewriter &rewriter) const override {
442  auto srcType = op.getType();
443  auto dstType = typeConverter.convertType(srcType);
444  if (!dstType)
445  return failure();
446  Location loc = op.getLoc();
447 
448  // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
449  Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
450  typeConverter, rewriter);
451  Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
452  typeConverter, rewriter);
453 
454  // Create a constant that holds the size of the `Base`.
455  IntegerType integerType;
456  if (auto vecType = srcType.dyn_cast<VectorType>())
457  integerType = vecType.getElementType().cast<IntegerType>();
458  else
459  integerType = srcType.cast<IntegerType>();
460 
461  auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType));
462  Value size =
463  srcType.isa<VectorType>()
464  ? rewriter.create<LLVM::ConstantOp>(
465  loc, dstType,
466  SplatElementsAttr::get(srcType.cast<ShapedType>(), baseSize))
467  : rewriter.create<LLVM::ConstantOp>(loc, dstType, baseSize);
468 
469  // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit
470  // at Offset + Count - 1 is the most significant bit now.
471  Value countPlusOffset =
472  rewriter.create<LLVM::AddOp>(loc, dstType, count, offset);
473  Value amountToShiftLeft =
474  rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset);
475  Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>(
476  loc, dstType, op.getBase(), amountToShiftLeft);
477 
478  // Shift the result right, filling the bits with the sign bit.
479  Value amountToShiftRight =
480  rewriter.create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft);
481  rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft,
482  amountToShiftRight);
483  return success();
484  }
485 };
486 
487 class BitFieldUExtractPattern
488  : public SPIRVToLLVMConversion<spirv::BitFieldUExtractOp> {
489 public:
491 
493  matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor,
494  ConversionPatternRewriter &rewriter) const override {
495  auto srcType = op.getType();
496  auto dstType = typeConverter.convertType(srcType);
497  if (!dstType)
498  return failure();
499  Location loc = op.getLoc();
500 
501  // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
502  Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
503  typeConverter, rewriter);
504  Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
505  typeConverter, rewriter);
506 
507  // Create a mask with bits set at [0, Count - 1].
508  Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
509  Value maskShiftedByCount =
510  rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
511  Value mask = rewriter.create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount,
512  minusOne);
513 
514  // Shift `Base` by `Offset` and apply the mask on it.
515  Value shiftedBase =
516  rewriter.create<LLVM::LShrOp>(loc, dstType, op.getBase(), offset);
517  rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask);
518  return success();
519  }
520 };
521 
/// Converts `spirv.Branch` to `llvm.br` to the same target block. The
/// adaptor's (converted) operands are forwarded as the successor's block
/// arguments.
/// NOTE(review): the inherited-constructor `using` declaration and the
/// `LogicalResult` return type of `matchAndRewrite` appear to be missing from
/// this copy of the file.
522 class BranchConversionPattern : public SPIRVToLLVMConversion<spirv::BranchOp> {
523 public:
525 
527  matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor,
528  ConversionPatternRewriter &rewriter) const override {
529  rewriter.replaceOpWithNewOp<LLVM::BrOp>(branchOp, adaptor.getOperands(),
530  branchOp.getTarget());
531  return success();
532  }
533 };
534 
535 class BranchConditionalConversionPattern
536  : public SPIRVToLLVMConversion<spirv::BranchConditionalOp> {
537 public:
538  using SPIRVToLLVMConversion<
539  spirv::BranchConditionalOp>::SPIRVToLLVMConversion;
540 
542  matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor,
543  ConversionPatternRewriter &rewriter) const override {
544  // If branch weights exist, map them to 32-bit integer vector.
545  ElementsAttr branchWeights = nullptr;
546  if (auto weights = op.getBranchWeights()) {
547  VectorType weightType = VectorType::get(2, rewriter.getI32Type());
548  branchWeights = DenseElementsAttr::get(weightType, weights->getValue());
549  }
550 
551  rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
552  op, op.getCondition(), op.getTrueBlockArguments(),
553  op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
554  op.getFalseBlock());
555  return success();
556  }
557 };
558 
559 /// Converts `spirv.CompositeExtract` to `llvm.extractvalue` if the container
560 /// type is an aggregate type (struct or array). Otherwise, converts to
561 /// `llvm.extractelement` that operates on vectors.
/// For a vector container only the first index is used; it is materialized
/// as an i32 constant.
/// NOTE(review): the inherited-constructor `using` declaration and the
/// `LogicalResult` return type of `matchAndRewrite` appear to be missing from
/// this copy of the file.
562 class CompositeExtractPattern
563  : public SPIRVToLLVMConversion<spirv::CompositeExtractOp> {
564 public:
566 
568  matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor,
569  ConversionPatternRewriter &rewriter) const override {
570  auto dstType = this->typeConverter.convertType(op.getType());
571  if (!dstType)
572  return failure();
573 
574  Type containerType = op.getComposite().getType();
575  if (containerType.isa<VectorType>()) {
576  Location loc = op.getLoc();
577  IntegerAttr value = op.getIndices()[0].cast<IntegerAttr>();
578  Value index = createI32ConstantOf(loc, rewriter, value.getInt());
579  rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
580  op, dstType, adaptor.getComposite(), index);
581  return success();
582  }
583 
    // Struct/array: the whole index list becomes the static position.
584  rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
585  op, adaptor.getComposite(), LLVM::convertArrayToIndices(op.getIndices()));
586  return success();
587  }
588 };
589 
590 /// Converts `spirv.CompositeInsert` to `llvm.insertvalue` if the container
591 /// type is an aggregate type (struct or array). Otherwise, converts to
592 /// `llvm.insertelement` that operates on vectors.
/// For a vector container only the first index is used; it is materialized
/// as an i32 constant.
/// NOTE(review): the inherited-constructor `using` declaration and the
/// `LogicalResult` return type of `matchAndRewrite` appear to be missing from
/// this copy of the file.
593 class CompositeInsertPattern
594  : public SPIRVToLLVMConversion<spirv::CompositeInsertOp> {
595 public:
597 
599  matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor,
600  ConversionPatternRewriter &rewriter) const override {
601  auto dstType = this->typeConverter.convertType(op.getType());
602  if (!dstType)
603  return failure();
604 
605  Type containerType = op.getComposite().getType();
606  if (containerType.isa<VectorType>()) {
607  Location loc = op.getLoc();
608  IntegerAttr value = op.getIndices()[0].cast<IntegerAttr>();
609  Value index = createI32ConstantOf(loc, rewriter, value.getInt());
610  rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
611  op, dstType, adaptor.getComposite(), adaptor.getObject(), index);
612  return success();
613  }
614 
    // Struct/array: the whole index list becomes the static position.
615  rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
616  op, adaptor.getComposite(), adaptor.getObject(),
617  LLVM::convertArrayToIndices(op.getIndices()));
618  return success();
619  }
620 };
621 
622 /// Converts SPIR-V operations that have straightforward LLVM equivalent
623 /// into LLVM dialect operations.
/// The new `LLVMOp` gets the converted result type, the adaptor's operands,
/// and all of the source op's attributes copied verbatim.
/// NOTE(review): because attributes are forwarded unchanged, this presumably
/// relies on `LLVMOp` accepting the SPIR-V op's attributes — confirm per
/// instantiation.
/// NOTE(review): the inherited-constructor `using` declaration and the
/// `LogicalResult` return type of `matchAndRewrite` appear to be missing from
/// this copy of the file.
624 template <typename SPIRVOp, typename LLVMOp>
625 class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
626 public:
628 
630  matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
631  ConversionPatternRewriter &rewriter) const override {
632  auto dstType = this->typeConverter.convertType(operation.getType());
633  if (!dstType)
634  return failure();
635  rewriter.template replaceOpWithNewOp<LLVMOp>(
636  operation, dstType, adaptor.getOperands(), operation->getAttrs());
637  return success();
638  }
639 };
640 
641 /// Converts `spirv.ExecutionMode` into a global struct constant that holds
642 /// execution mode information.
/// The global is an externally linked constant created at the start of the
/// enclosing `ModuleOp` body. Its initializer region builds the struct from
/// an undef value: field 0 holds the execution mode as i32, and an optional
/// field 1 holds an i32 array of the mode's extra `values`. The
/// `spirv.ExecutionMode` op is then erased.
/// NOTE(review): the inherited-constructor `using` declaration and the
/// `LogicalResult` return type of `matchAndRewrite` appear to be missing from
/// this copy of the file.
643 class ExecutionModePattern
644  : public SPIRVToLLVMConversion<spirv::ExecutionModeOp> {
645 public:
647 
649  matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor,
650  ConversionPatternRewriter &rewriter) const override {
651  // First, create the global struct's name that would be associated with
652  // this entry point's execution mode. We set it to be:
653  // __spv__{SPIR-V module name}_{function name}_execution_mode_info_{mode}
    // ({mode} is the numeric value of the execution mode. Without a module
    // name the result is `__spv__{function name}_execution_mode_info_{mode}`.)
    // NOTE(review): the name is read from the enclosing builtin `ModuleOp`;
    // presumably it carries the SPIR-V module's name by the time this pattern
    // runs — confirm.
654  ModuleOp module = op->getParentOfType<ModuleOp>();
655  spirv::ExecutionModeAttr executionModeAttr = op.getExecutionModeAttr();
656  std::string moduleName;
657  if (module.getName().has_value())
658  moduleName = "_" + module.getName()->str();
659  else
660  moduleName = "";
661  std::string executionModeInfoName = llvm::formatv(
662  "__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.getFn().str(),
663  static_cast<uint32_t>(executionModeAttr.getValue()));
664 
665  MLIRContext *context = rewriter.getContext();
    // The guard restores the rewriter's insertion point when this scope ends.
666  OpBuilder::InsertionGuard guard(rewriter);
667  rewriter.setInsertionPointToStart(module.getBody());
668 
669  // Create a struct type, corresponding to the C struct below.
670  // struct {
671  // int32_t executionMode;
672  // int32_t values[]; // optional values
673  // };
    // (In practice `values` is a fixed-size array of `values.size()` elements
    // and is omitted entirely when there are no extra values.)
674  auto llvmI32Type = IntegerType::get(context, 32);
675  SmallVector<Type, 2> fields;
676  fields.push_back(llvmI32Type);
677  ArrayAttr values = op.getValues();
678  if (!values.empty()) {
679  auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size());
680  fields.push_back(arrayType);
681  }
682  auto structType = LLVM::LLVMStructType::getLiteral(context, fields);
683 
684  // Create `llvm.mlir.global` with initializer region containing one block.
685  auto global = rewriter.create<LLVM::GlobalOp>(
686  UnknownLoc::get(context), structType, /*isConstant=*/true,
687  LLVM::Linkage::External, executionModeInfoName, Attribute(),
688  /*alignment=*/0);
689  Location loc = global.getLoc();
690  Region &region = global.getInitializerRegion();
691  Block *block = rewriter.createBlock(&region);
692 
693  // Initialize the struct and set the execution mode value.
694  rewriter.setInsertionPoint(block, block->begin());
695  Value structValue = rewriter.create<LLVM::UndefOp>(loc, structType);
696  Value executionMode = rewriter.create<LLVM::ConstantOp>(
697  loc, llvmI32Type,
698  rewriter.getI32IntegerAttr(
699  static_cast<uint32_t>(executionModeAttr.getValue())));
700  structValue = rewriter.create<LLVM::InsertValueOp>(loc, structValue,
701  executionMode, 0);
702 
703  // Insert extra operands if they exist into execution mode info struct.
    // Each value goes to position {1, i}: element i of the array in field 1.
704  for (unsigned i = 0, e = values.size(); i < e; ++i) {
705  auto attr = values.getValue()[i];
706  Value entry = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, attr);
707  structValue = rewriter.create<LLVM::InsertValueOp>(
708  loc, structValue, entry, ArrayRef<int64_t>({1, i}));
709  }
710  rewriter.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue}));
711  rewriter.eraseOp(op);
712  return success();
713  }
714 };
715 
716 /// Converts `spirv.GlobalVariable` to `llvm.mlir.global`. Note that SPIR-V
717 /// global returns a pointer, whereas in LLVM dialect the global holds an actual
718 /// value. This difference is handled by `spirv.mlir.addressof` and
719 /// `llvm.mlir.addressof`ops that both return a pointer.
/// Only variables without an initializer and in the Input, Private, Output,
/// StorageBuffer or UniformConstant storage classes are converted. Any
/// location attribute is copied to the new global.
/// NOTE(review): the inherited-constructor `using` declaration and the
/// `LogicalResult` return type of `matchAndRewrite` appear to be missing from
/// this copy of the file.
720 class GlobalVariablePattern
721  : public SPIRVToLLVMConversion<spirv::GlobalVariableOp> {
722 public:
724 
726  matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor,
727  ConversionPatternRewriter &rewriter) const override {
728  // Currently, there is no support of initialization with a constant value in
729  // SPIR-V dialect. Specialization constants are not considered as well.
730  if (op.getInitializer())
731  return failure();
732 
733  auto srcType = op.getType().cast<spirv::PointerType>();
734  auto dstType = typeConverter.convertType(srcType.getPointeeType());
735  if (!dstType)
736  return failure();
737 
738  // Limit conversion to the current invocation only or `StorageBuffer`
739  // required by SPIR-V runner.
740  // This is okay because multiple invocations are not supported yet.
741  auto storageClass = srcType.getStorageClass();
742  switch (storageClass) {
743  case spirv::StorageClass::Input:
744  case spirv::StorageClass::Private:
745  case spirv::StorageClass::Output:
746  case spirv::StorageClass::StorageBuffer:
747  case spirv::StorageClass::UniformConstant:
748  break;
749  default:
750  return failure();
751  }
752 
753  // LLVM dialect spec: "If the global value is a constant, storing into it is
754  // not allowed.". This corresponds to SPIR-V 'Input' and 'UniformConstant'
755  // storage class that is read-only.
756  bool isConstant = (storageClass == spirv::StorageClass::Input) ||
757  (storageClass == spirv::StorageClass::UniformConstant);
758  // SPIR-V spec: "By default, functions and global variables are private to a
759  // module and cannot be accessed by other modules. However, a module may be
760  // written to export or import functions and global (module scope)
761  // variables.". Therefore, map 'Private' storage class to private linkage,
762  // and every other accepted class ('Input', 'Output', 'StorageBuffer',
    // 'UniformConstant') to external linkage.
763  auto linkage = storageClass == spirv::StorageClass::Private
764  ? LLVM::Linkage::Private
765  : LLVM::Linkage::External;
766  auto newGlobalOp = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
767  op, dstType, isConstant, linkage, op.getSymName(), Attribute(),
768  /*alignment=*/0);
769 
770  // Attach location attribute if applicable
771  if (op.getLocationAttr())
772  newGlobalOp->setAttr(op.getLocationAttrName(), op.getLocationAttr());
773 
774  return success();
775  }
776 };
777 
778 /// Converts SPIR-V cast ops that do not have straightforward LLVM
779 /// equivalent in LLVM dialect.
/// The LLVM op is picked by comparing the bit widths of the operand and
/// result types: a wider result uses `LLVMExtOp`, a narrower one uses
/// `LLVMTruncOp`. Equal widths are rejected (`failure()`), leaving them to
/// other patterns.
/// NOTE(review): the inherited-constructor `using` declaration and the
/// `LogicalResult` return type of `matchAndRewrite` appear to be missing from
/// this copy of the file.
780 template <typename SPIRVOp, typename LLVMExtOp, typename LLVMTruncOp>
781 class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> {
782 public:
784 
786  matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
787  ConversionPatternRewriter &rewriter) const override {
788 
789  Type fromType = operation.getOperand().getType();
790  Type toType = operation.getType();
791 
792  auto dstType = this->typeConverter.convertType(toType);
793  if (!dstType)
794  return failure();
795 
796  if (getBitWidth(fromType) < getBitWidth(toType)) {
797  rewriter.template replaceOpWithNewOp<LLVMExtOp>(operation, dstType,
798  adaptor.getOperands());
799  return success();
800  }
801  if (getBitWidth(fromType) > getBitWidth(toType)) {
802  rewriter.template replaceOpWithNewOp<LLVMTruncOp>(operation, dstType,
803  adaptor.getOperands());
804  return success();
805  }
806  return failure();
807  }
808 };
809 
810 class FunctionCallPattern
811  : public SPIRVToLLVMConversion<spirv::FunctionCallOp> {
812 public:
814 
816  matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
817  ConversionPatternRewriter &rewriter) const override {
818  if (callOp.getNumResults() == 0) {
819  rewriter.replaceOpWithNewOp<LLVM::CallOp>(
820  callOp, std::nullopt, adaptor.getOperands(), callOp->getAttrs());
821  return success();
822  }
823 
824  // Function returns a single result.
825  auto dstType = typeConverter.convertType(callOp.getType(0));
826  rewriter.replaceOpWithNewOp<LLVM::CallOp>(
827  callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
828  return success();
829  }
830 };
831 
832 /// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate"
833 template <typename SPIRVOp, LLVM::FCmpPredicate predicate>
834 class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
835 public:
837 
839  matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
840  ConversionPatternRewriter &rewriter) const override {
841 
842  auto dstType = this->typeConverter.convertType(operation.getType());
843  if (!dstType)
844  return failure();
845 
846  rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
847  operation, dstType, predicate, operation.getOperand1(),
848  operation.getOperand2());
849  return success();
850  }
851 };
852 
853 /// Converts SPIR-V integer comparisons to llvm.icmp "predicate"
854 template <typename SPIRVOp, LLVM::ICmpPredicate predicate>
855 class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
856 public:
858 
860  matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
861  ConversionPatternRewriter &rewriter) const override {
862 
863  auto dstType = this->typeConverter.convertType(operation.getType());
864  if (!dstType)
865  return failure();
866 
867  rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
868  operation, dstType, predicate, operation.getOperand1(),
869  operation.getOperand2());
870  return success();
871  }
872 };
873 
874 class InverseSqrtPattern
875  : public SPIRVToLLVMConversion<spirv::GLInverseSqrtOp> {
876 public:
878 
880  matchAndRewrite(spirv::GLInverseSqrtOp op, OpAdaptor adaptor,
881  ConversionPatternRewriter &rewriter) const override {
882  auto srcType = op.getType();
883  auto dstType = typeConverter.convertType(srcType);
884  if (!dstType)
885  return failure();
886 
887  Location loc = op.getLoc();
888  Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
889  Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.getOperand());
890  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt);
891  return success();
892  }
893 };
894 
895 /// Converts `spirv.Load` and `spirv.Store` to LLVM dialect.
896 template <typename SPIRVOp>
897 class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVOp> {
898 public:
900 
902  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
903  ConversionPatternRewriter &rewriter) const override {
904  if (!op.getMemoryAccess()) {
905  return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
906  this->typeConverter, /*alignment=*/0,
907  /*isVolatile=*/false,
908  /*isNonTemporal=*/false);
909  }
910  auto memoryAccess = *op.getMemoryAccess();
911  switch (memoryAccess) {
912  case spirv::MemoryAccess::Aligned:
914  case spirv::MemoryAccess::Nontemporal:
915  case spirv::MemoryAccess::Volatile: {
916  unsigned alignment =
917  memoryAccess == spirv::MemoryAccess::Aligned ? *op.getAlignment() : 0;
918  bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
919  bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
920  return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
921  this->typeConverter, alignment, isVolatile,
922  isNonTemporal);
923  }
924  default:
925  // There is no support of other memory access attributes.
926  return failure();
927  }
928  }
929 };
930 
931 /// Converts `spirv.Not` and `spirv.LogicalNot` into LLVM dialect.
932 template <typename SPIRVOp>
933 class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> {
934 public:
936 
938  matchAndRewrite(SPIRVOp notOp, typename SPIRVOp::Adaptor adaptor,
939  ConversionPatternRewriter &rewriter) const override {
940  auto srcType = notOp.getType();
941  auto dstType = this->typeConverter.convertType(srcType);
942  if (!dstType)
943  return failure();
944 
945  Location loc = notOp.getLoc();
946  IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter);
947  auto mask = srcType.template isa<VectorType>()
948  ? rewriter.create<LLVM::ConstantOp>(
949  loc, dstType,
951  srcType.template cast<VectorType>(), minusOne))
952  : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne);
953  rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
954  notOp.getOperand(), mask);
955  return success();
956  }
957 };
958 
959 /// A template pattern that erases the given `SPIRVOp`.
960 template <typename SPIRVOp>
961 class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> {
962 public:
964 
966  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
967  ConversionPatternRewriter &rewriter) const override {
968  rewriter.eraseOp(op);
969  return success();
970  }
971 };
972 
973 class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> {
974 public:
976 
978  matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor,
979  ConversionPatternRewriter &rewriter) const override {
980  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, ArrayRef<Type>(),
981  ArrayRef<Value>());
982  return success();
983  }
984 };
985 
986 class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
987 public:
989 
991  matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor,
992  ConversionPatternRewriter &rewriter) const override {
993  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnValueOp, ArrayRef<Type>(),
994  adaptor.getOperands());
995  return success();
996  }
997 };
998 
999 /// Converts `spirv.mlir.loop` to LLVM dialect. All blocks within selection
1000 /// should be reachable for conversion to succeed. The structure of the loop in
1001 /// LLVM dialect will be the following:
1002 ///
1003 /// +------------------------------------+
1004 /// | <code before spirv.mlir.loop> |
1005 /// | llvm.br ^header |
1006 /// +------------------------------------+
1007 /// |
1008 /// +----------------+ |
1009 /// | | |
1010 /// | V V
1011 /// | +------------------------------------+
1012 /// | | ^header: |
1013 /// | | <header code> |
1014 /// | | llvm.cond_br %cond, ^body, ^exit |
1015 /// | +------------------------------------+
1016 /// | |
1017 /// | |----------------------+
1018 /// | | |
1019 /// | V |
1020 /// | +------------------------------------+ |
1021 /// | | ^body: | |
1022 /// | | <body code> | |
1023 /// | | llvm.br ^continue | |
1024 /// | +------------------------------------+ |
1025 /// | | |
1026 /// | V |
1027 /// | +------------------------------------+ |
1028 /// | | ^continue: | |
1029 /// | | <continue code> | |
1030 /// | | llvm.br ^header | |
1031 /// | +------------------------------------+ |
1032 /// | | |
1033 /// +---------------+ +----------------------+
1034 /// |
1035 /// V
1036 /// +------------------------------------+
1037 /// | ^exit: |
1038 /// | llvm.br ^remaining |
1039 /// +------------------------------------+
1040 /// |
1041 /// V
1042 /// +------------------------------------+
1043 /// | ^remaining: |
1044 /// | <code after spirv.mlir.loop> |
1045 /// +------------------------------------+
1046 ///
1047 class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> {
1048 public:
1050 
1052  matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,
1053  ConversionPatternRewriter &rewriter) const override {
1054  // There is no support of loop control at the moment.
1055  if (loopOp.getLoopControl() != spirv::LoopControl::None)
1056  return failure();
1057 
1058  Location loc = loopOp.getLoc();
1059 
1060  // Split the current block after `spirv.mlir.loop`. The remaining ops will
1061  // be used in `endBlock`.
1062  Block *currentBlock = rewriter.getBlock();
1063  auto position = Block::iterator(loopOp);
1064  Block *endBlock = rewriter.splitBlock(currentBlock, position);
1065 
1066  // Remove entry block and create a branch in the current block going to the
1067  // header block.
1068  Block *entryBlock = loopOp.getEntryBlock();
1069  assert(entryBlock->getOperations().size() == 1);
1070  auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->getOperations().front());
1071  if (!brOp)
1072  return failure();
1073  Block *headerBlock = loopOp.getHeaderBlock();
1074  rewriter.setInsertionPointToEnd(currentBlock);
1075  rewriter.create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock);
1076  rewriter.eraseBlock(entryBlock);
1077 
1078  // Branch from merge block to end block.
1079  Block *mergeBlock = loopOp.getMergeBlock();
1080  Operation *terminator = mergeBlock->getTerminator();
1081  ValueRange terminatorOperands = terminator->getOperands();
1082  rewriter.setInsertionPointToEnd(mergeBlock);
1083  rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock);
1084 
1085  rewriter.inlineRegionBefore(loopOp.getBody(), endBlock);
1086  rewriter.replaceOp(loopOp, endBlock->getArguments());
1087  return success();
1088  }
1089 };
1090 
1091 /// Converts `spirv.mlir.selection` with `spirv.BranchConditional` in its header
1092 /// block. All blocks within selection should be reachable for conversion to
1093 /// succeed.
1094 class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> {
1095 public:
1097 
1099  matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor,
1100  ConversionPatternRewriter &rewriter) const override {
1101  // There is no support for `Flatten` or `DontFlatten` selection control at
1102  // the moment. This are just compiler hints and can be performed during the
1103  // optimization passes.
1104  if (op.getSelectionControl() != spirv::SelectionControl::None)
1105  return failure();
1106 
1107  // `spirv.mlir.selection` should have at least two blocks: one selection
1108  // header block and one merge block. If no blocks are present, or control
1109  // flow branches straight to merge block (two blocks are present), the op is
1110  // redundant and it is erased.
1111  if (op.getBody().getBlocks().size() <= 2) {
1112  rewriter.eraseOp(op);
1113  return success();
1114  }
1115 
1116  Location loc = op.getLoc();
1117 
1118  // Split the current block after `spirv.mlir.selection`. The remaining ops
1119  // will be used in `continueBlock`.
1120  auto *currentBlock = rewriter.getInsertionBlock();
1121  rewriter.setInsertionPointAfter(op);
1122  auto position = rewriter.getInsertionPoint();
1123  auto *continueBlock = rewriter.splitBlock(currentBlock, position);
1124 
1125  // Extract conditional branch information from the header block. By SPIR-V
1126  // dialect spec, it should contain `spirv.BranchConditional` or
1127  // `spirv.Switch` op. Note that `spirv.Switch op` is not supported at the
1128  // moment in the SPIR-V dialect. Remove this block when finished.
1129  auto *headerBlock = op.getHeaderBlock();
1130  assert(headerBlock->getOperations().size() == 1);
1131  auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
1132  headerBlock->getOperations().front());
1133  if (!condBrOp)
1134  return failure();
1135  rewriter.eraseBlock(headerBlock);
1136 
1137  // Branch from merge block to continue block.
1138  auto *mergeBlock = op.getMergeBlock();
1139  Operation *terminator = mergeBlock->getTerminator();
1140  ValueRange terminatorOperands = terminator->getOperands();
1141  rewriter.setInsertionPointToEnd(mergeBlock);
1142  rewriter.create<LLVM::BrOp>(loc, terminatorOperands, continueBlock);
1143 
1144  // Link current block to `true` and `false` blocks within the selection.
1145  Block *trueBlock = condBrOp.getTrueBlock();
1146  Block *falseBlock = condBrOp.getFalseBlock();
1147  rewriter.setInsertionPointToEnd(currentBlock);
1148  rewriter.create<LLVM::CondBrOp>(loc, condBrOp.getCondition(), trueBlock,
1149  condBrOp.getTrueTargetOperands(), falseBlock,
1150  condBrOp.getFalseTargetOperands());
1151 
1152  rewriter.inlineRegionBefore(op.getBody(), continueBlock);
1153  rewriter.replaceOp(op, continueBlock->getArguments());
1154  return success();
1155  }
1156 };
1157 
1158 /// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect
1159 /// puts a restriction on `Shift` and `Base` to have the same bit width,
1160 /// `Shift` is zero or sign extended to match this specification. Cases when
1161 /// `Shift` bit width > `Base` bit width are considered to be illegal.
1162 template <typename SPIRVOp, typename LLVMOp>
1163 class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
1164 public:
1166 
1168  matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
1169  ConversionPatternRewriter &rewriter) const override {
1170 
1171  auto dstType = this->typeConverter.convertType(operation.getType());
1172  if (!dstType)
1173  return failure();
1174 
1175  Type op1Type = operation.getOperand1().getType();
1176  Type op2Type = operation.getOperand2().getType();
1177 
1178  if (op1Type == op2Type) {
1179  rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType,
1180  adaptor.getOperands());
1181  return success();
1182  }
1183 
1184  Location loc = operation.getLoc();
1185  Value extended;
1186  if (isUnsignedIntegerOrVector(op2Type)) {
1187  extended = rewriter.template create<LLVM::ZExtOp>(loc, dstType,
1188  adaptor.getOperand2());
1189  } else {
1190  extended = rewriter.template create<LLVM::SExtOp>(loc, dstType,
1191  adaptor.getOperand2());
1192  }
1193  Value result = rewriter.template create<LLVMOp>(
1194  loc, dstType, adaptor.getOperand1(), extended);
1195  rewriter.replaceOp(operation, result);
1196  return success();
1197  }
1198 };
1199 
1200 class TanPattern : public SPIRVToLLVMConversion<spirv::GLTanOp> {
1201 public:
1203 
1205  matchAndRewrite(spirv::GLTanOp tanOp, OpAdaptor adaptor,
1206  ConversionPatternRewriter &rewriter) const override {
1207  auto dstType = typeConverter.convertType(tanOp.getType());
1208  if (!dstType)
1209  return failure();
1210 
1211  Location loc = tanOp.getLoc();
1212  Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.getOperand());
1213  Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.getOperand());
1214  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos);
1215  return success();
1216  }
1217 };
1218 
1219 /// Convert `spirv.Tanh` to
1220 ///
1221 /// exp(2x) - 1
1222 /// -----------
1223 /// exp(2x) + 1
1224 ///
1225 class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> {
1226 public:
1228 
1230  matchAndRewrite(spirv::GLTanhOp tanhOp, OpAdaptor adaptor,
1231  ConversionPatternRewriter &rewriter) const override {
1232  auto srcType = tanhOp.getType();
1233  auto dstType = typeConverter.convertType(srcType);
1234  if (!dstType)
1235  return failure();
1236 
1237  Location loc = tanhOp.getLoc();
1238  Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
1239  Value multiplied =
1240  rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.getOperand());
1241  Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied);
1242  Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
1243  Value numerator =
1244  rewriter.create<LLVM::FSubOp>(loc, dstType, exponential, one);
1245  Value denominator =
1246  rewriter.create<LLVM::FAddOp>(loc, dstType, exponential, one);
1247  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator,
1248  denominator);
1249  return success();
1250  }
1251 };
1252 
1253 class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
1254 public:
1256 
1258  matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor,
1259  ConversionPatternRewriter &rewriter) const override {
1260  auto srcType = varOp.getType();
1261  // Initialization is supported for scalars and vectors only.
1262  auto pointerTo = srcType.cast<spirv::PointerType>().getPointeeType();
1263  auto init = varOp.getInitializer();
1264  if (init && !pointerTo.isIntOrFloat() && !pointerTo.isa<VectorType>())
1265  return failure();
1266 
1267  auto dstType = typeConverter.convertType(srcType);
1268  if (!dstType)
1269  return failure();
1270 
1271  Location loc = varOp.getLoc();
1272  Value size = createI32ConstantOf(loc, rewriter, 1);
1273  if (!init) {
1274  rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(
1275  varOp, dstType, typeConverter.convertType(pointerTo), size);
1276  return success();
1277  }
1278  Value allocated = rewriter.create<LLVM::AllocaOp>(
1279  loc, dstType, typeConverter.convertType(pointerTo), size);
1280  rewriter.create<LLVM::StoreOp>(loc, adaptor.getInitializer(), allocated);
1281  rewriter.replaceOp(varOp, allocated);
1282  return success();
1283  }
1284 };
1285 
1286 //===----------------------------------------------------------------------===//
1287 // BitcastOp conversion
1288 //===----------------------------------------------------------------------===//
1289 
1290 class BitcastConversionPattern
1291  : public SPIRVToLLVMConversion<spirv::BitcastOp> {
1292 public:
1294 
1296  matchAndRewrite(spirv::BitcastOp bitcastOp, OpAdaptor adaptor,
1297  ConversionPatternRewriter &rewriter) const override {
1298  auto dstType = typeConverter.convertType(bitcastOp.getType());
1299  if (!dstType)
1300  return failure();
1301 
1302  if (typeConverter.useOpaquePointers() &&
1303  dstType.isa<LLVM::LLVMPointerType>()) {
1304  rewriter.replaceOp(bitcastOp, adaptor.getOperand());
1305  return success();
1306  }
1307 
1308  rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
1309  bitcastOp, dstType, adaptor.getOperands(), bitcastOp->getAttrs());
1310  return success();
1311  }
1312 };
1313 
1314 //===----------------------------------------------------------------------===//
1315 // FuncOp conversion
1316 //===----------------------------------------------------------------------===//
1317 
1318 class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
1319 public:
1321 
1323  matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
1324  ConversionPatternRewriter &rewriter) const override {
1325 
1326  // Convert function signature. At the moment LLVMType converter is enough
1327  // for currently supported types.
1328  auto funcType = funcOp.getFunctionType();
1329  TypeConverter::SignatureConversion signatureConverter(
1330  funcType.getNumInputs());
1331  auto llvmType = typeConverter.convertFunctionSignature(
1332  funcType, /*isVariadic=*/false, signatureConverter);
1333  if (!llvmType)
1334  return failure();
1335 
1336  // Create a new `LLVMFuncOp`
1337  Location loc = funcOp.getLoc();
1338  StringRef name = funcOp.getName();
1339  auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, name, llvmType);
1340 
1341  // Convert SPIR-V Function Control to equivalent LLVM function attribute
1342  MLIRContext *context = funcOp.getContext();
1343  switch (funcOp.getFunctionControl()) {
1344 #define DISPATCH(functionControl, llvmAttr) \
1345  case functionControl: \
1346  newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \
1347  break;
1348 
1349  DISPATCH(spirv::FunctionControl::Inline,
1350  StringAttr::get(context, "alwaysinline"));
1351  DISPATCH(spirv::FunctionControl::DontInline,
1352  StringAttr::get(context, "noinline"));
1353  DISPATCH(spirv::FunctionControl::Pure,
1354  StringAttr::get(context, "readonly"));
1355  DISPATCH(spirv::FunctionControl::Const,
1356  StringAttr::get(context, "readnone"));
1357 
1358 #undef DISPATCH
1359 
1360  // Default: if `spirv::FunctionControl::None`, then no attributes are
1361  // needed.
1362  default:
1363  break;
1364  }
1365 
1366  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1367  newFuncOp.end());
1368  if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
1369  &signatureConverter))) {
1370  return failure();
1371  }
1372  rewriter.eraseOp(funcOp);
1373  return success();
1374  }
1375 };
1376 
1377 //===----------------------------------------------------------------------===//
1378 // ModuleOp conversion
1379 //===----------------------------------------------------------------------===//
1380 
1381 class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> {
1382 public:
1384 
1386  matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor,
1387  ConversionPatternRewriter &rewriter) const override {
1388 
1389  auto newModuleOp =
1390  rewriter.create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName());
1391  rewriter.inlineRegionBefore(spvModuleOp.getRegion(), newModuleOp.getBody());
1392 
1393  // Remove the terminator block that was automatically added by builder
1394  rewriter.eraseBlock(&newModuleOp.getBodyRegion().back());
1395  rewriter.eraseOp(spvModuleOp);
1396  return success();
1397  }
1398 };
1399 
1400 //===----------------------------------------------------------------------===//
1401 // VectorShuffleOp conversion
1402 //===----------------------------------------------------------------------===//
1403 
1404 class VectorShufflePattern
1405  : public SPIRVToLLVMConversion<spirv::VectorShuffleOp> {
1406 public:
1409  matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
1410  ConversionPatternRewriter &rewriter) const override {
1411  Location loc = op.getLoc();
1412  auto components = adaptor.getComponents();
1413  auto vector1 = adaptor.getVector1();
1414  auto vector2 = adaptor.getVector2();
1415  int vector1Size = vector1.getType().cast<VectorType>().getNumElements();
1416  int vector2Size = vector2.getType().cast<VectorType>().getNumElements();
1417  if (vector1Size == vector2Size) {
1418  rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
1419  op, vector1, vector2,
1420  LLVM::convertArrayToIndices<int32_t>(components));
1421  return success();
1422  }
1423 
1424  auto dstType = typeConverter.convertType(op.getType());
1425  auto scalarType = dstType.cast<VectorType>().getElementType();
1426  auto componentsArray = components.getValue();
1427  auto *context = rewriter.getContext();
1428  auto llvmI32Type = IntegerType::get(context, 32);
1429  Value targetOp = rewriter.create<LLVM::UndefOp>(loc, dstType);
1430  for (unsigned i = 0; i < componentsArray.size(); i++) {
1431  if (!componentsArray[i].isa<IntegerAttr>())
1432  return op.emitError("unable to support non-constant component");
1433 
1434  int indexVal = componentsArray[i].cast<IntegerAttr>().getInt();
1435  if (indexVal == -1)
1436  continue;
1437 
1438  int offsetVal = 0;
1439  Value baseVector = vector1;
1440  if (indexVal >= vector1Size) {
1441  offsetVal = vector1Size;
1442  baseVector = vector2;
1443  }
1444 
1445  Value dstIndex = rewriter.create<LLVM::ConstantOp>(
1446  loc, llvmI32Type, rewriter.getIntegerAttr(rewriter.getI32Type(), i));
1447  Value index = rewriter.create<LLVM::ConstantOp>(
1448  loc, llvmI32Type,
1449  rewriter.getIntegerAttr(rewriter.getI32Type(), indexVal - offsetVal));
1450 
1451  auto extractOp = rewriter.create<LLVM::ExtractElementOp>(
1452  loc, scalarType, baseVector, index);
1453  targetOp = rewriter.create<LLVM::InsertElementOp>(loc, dstType, targetOp,
1454  extractOp, dstIndex);
1455  }
1456  rewriter.replaceOp(op, targetOp);
1457  return success();
1458  }
1459 };
1460 } // namespace
1461 
1462 //===----------------------------------------------------------------------===//
1463 // Pattern population
1464 //===----------------------------------------------------------------------===//
1465 
// Registers conversions on `typeConverter` for SPIR-V array, pointer,
// runtime-array and struct types, each delegating to its convert* helper.
// NOTE(review): the function signature line is missing from this listing.
1467  typeConverter.addConversion([&](spirv::ArrayType type) {
1468  return convertArrayType(type, typeConverter);
1469  });
1470  typeConverter.addConversion([&](spirv::PointerType type) {
1471  return convertPointerType(type, typeConverter);
1472  });
1473  typeConverter.addConversion([&](spirv::RuntimeArrayType type) {
1474  return convertRuntimeArrayType(type, typeConverter);
1475  });
1476  typeConverter.addConversion([&](spirv::StructType type) {
1477  return convertStructType(type, typeConverter);
1478  });
1479 }
1480 
1482  LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
// Adds the op conversion patterns below to `patterns`; every pattern is
// constructed with the pattern set's context and `typeConverter`.
// spirv.func and spirv.module are not listed here (see the separate
// function/module populate functions).
// NOTE(review): the first signature line is missing from this listing.
1483  patterns.add<
1484  // Arithmetic ops
1485  DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
1486  DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
1487  DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
1488  DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
1489  DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
1490  DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
1491  DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
1492  DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
1493  DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
1494  DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
1495  DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
1496  DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
1497  DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
1498 
1499  // Bitwise ops
1500  BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
1501  DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
1502  DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
1503  DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
1504  DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
1505  DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
1506  NotPattern<spirv::NotOp>,
1507 
1508  // Cast ops
1509  BitcastConversionPattern,
1510  DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
1511  DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
1512  DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
1513  DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
1514  IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
1515  IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
1516  IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
1517 
1518  // Comparison ops
1519  IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
1520  IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
1521  FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
1522  FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
1523  FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
1524  FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
1525  FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
1526  FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
1527  FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
1528  FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
1529  FComparePattern<spirv::FUnordGreaterThanEqualOp,
1530  LLVM::FCmpPredicate::uge>,
1531  FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
1532  FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
1533  FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
1534  IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
1535  IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
1536  IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
1537  IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
1538  IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
1539  IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
1540  IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
1541  IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
1542 
1543  // Constant op
1544  ConstantScalarAndVectorPattern,
1545 
1546  // Control Flow ops
1547  BranchConversionPattern, BranchConditionalConversionPattern,
1548  FunctionCallPattern, LoopPattern, SelectionPattern,
1549  ErasePattern<spirv::MergeOp>,
1550 
1551  // Entry points and execution mode are handled separately.
1552  ErasePattern<spirv::EntryPointOp>, ExecutionModePattern,
1553 
1554  // GLSL extended instruction set ops
1555  DirectConversionPattern<spirv::GLCeilOp, LLVM::FCeilOp>,
1556  DirectConversionPattern<spirv::GLCosOp, LLVM::CosOp>,
1557  DirectConversionPattern<spirv::GLExpOp, LLVM::ExpOp>,
1558  DirectConversionPattern<spirv::GLFAbsOp, LLVM::FAbsOp>,
1559  DirectConversionPattern<spirv::GLFloorOp, LLVM::FFloorOp>,
1560  DirectConversionPattern<spirv::GLFMaxOp, LLVM::MaxNumOp>,
1561  DirectConversionPattern<spirv::GLFMinOp, LLVM::MinNumOp>,
1562  DirectConversionPattern<spirv::GLLogOp, LLVM::LogOp>,
1563  DirectConversionPattern<spirv::GLSinOp, LLVM::SinOp>,
1564  DirectConversionPattern<spirv::GLSMaxOp, LLVM::SMaxOp>,
1565  DirectConversionPattern<spirv::GLSMinOp, LLVM::SMinOp>,
1566  DirectConversionPattern<spirv::GLSqrtOp, LLVM::SqrtOp>,
1567  InverseSqrtPattern, TanPattern, TanhPattern,
1568 
1569  // Logical ops
1570  DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
1571  DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
1572  IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
1573  IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
1574  NotPattern<spirv::LogicalNotOp>,
1575 
1576  // Memory ops
1577  AccessChainPattern, AddressOfPattern, GlobalVariablePattern,
1578  LoadStorePattern<spirv::LoadOp>, LoadStorePattern<spirv::StoreOp>,
1579  VariablePattern,
1580 
1581  // Miscellaneous ops
1582  CompositeExtractPattern, CompositeInsertPattern,
1583  DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
1584  DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
1585  VectorShufflePattern,
1586 
1587  // Shift ops
1588  ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
1589  ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
1590  ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
1591 
1592  // Return ops
1593  ReturnPattern, ReturnValuePattern>(patterns.getContext(), typeConverter);
1594 }
1595 
1597  LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
// Adds only FuncConversionPattern, which lowers `spirv.func` to `llvm.func`.
// NOTE(review): the first signature line is missing from this listing.
1598  patterns.add<FuncConversionPattern>(patterns.getContext(), typeConverter);
1599 }
1600 
1602  LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
// Adds only ModuleConversionPattern, which turns `spirv.module` into a
// builtin `module`.
// NOTE(review): the first signature line is missing from this listing.
1603  patterns.add<ModuleConversionPattern>(patterns.getContext(), typeConverter);
1604 }
1605 
1606 //===----------------------------------------------------------------------===//
1607 // Pre-conversion hooks
1608 //===----------------------------------------------------------------------===//
1609 
1610 /// Hook for descriptor set and binding number encoding.
1611 static constexpr StringRef kBinding = "binding";
1612 static constexpr StringRef kDescriptorSet = "descriptor_set";
1613 void mlir::encodeBindAttribute(ModuleOp module) {
1614  auto spvModules = module.getOps<spirv::ModuleOp>();
1615  for (auto spvModule : spvModules) {
1616  spvModule.walk([&](spirv::GlobalVariableOp op) {
1617  IntegerAttr descriptorSet =
1618  op->getAttrOfType<IntegerAttr>(kDescriptorSet);
1619  IntegerAttr binding = op->getAttrOfType<IntegerAttr>(kBinding);
1620  // For every global variable in the module, get the ones with descriptor
1621  // set and binding numbers.
1622  if (descriptorSet && binding) {
1623  // Encode these numbers into the variable's symbolic name. If the
1624  // SPIR-V module has a name, add it at the beginning.
1625  auto moduleAndName =
1626  spvModule.getName().has_value()
1627  ? spvModule.getName()->str() + "_" + op.getSymName().str()
1628  : op.getSymName().str();
1629  std::string name =
1630  llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName,
1631  std::to_string(descriptorSet.getInt()),
1632  std::to_string(binding.getInt()));
1633  auto nameAttr = StringAttr::get(op->getContext(), name);
1634 
1635  // Replace all symbol uses and set the new symbol name. Finally, remove
1636  // descriptor set and binding attributes.
1637  if (failed(SymbolTable::replaceAllSymbolUses(op, nameAttr, spvModule)))
1638  op.emitError("unable to replace all symbol uses for ") << name;
1639  SymbolTable::setSymbolName(op, nameAttr);
1640  op->removeAttr(kDescriptorSet);
1641  op->removeAttr(kBinding);
1642  }
1643  });
1644  }
1645 }
@ None
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:698
static Value optionallyTruncateOrExtend(Location loc, Value value, Type llvmType, PatternRewriter &rewriter)
Utility function for bitfield ops:
static Value createFPConstant(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter, double value)
Creates llvm.mlir.constant with a floating-point scalar or vector value.
Definition: SPIRVToLLVM.cpp:99
static constexpr StringRef kDescriptorSet
static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, unsigned value)
Creates LLVM dialect constant with the given value.
static Value processCountOrOffset(Location loc, Value value, Type srcType, Type dstType, LLVMTypeConverter &converter, ConversionPatternRewriter &rewriter)
Utility function for bitfield ops: BitFieldInsert, BitFieldSExtract and BitFieldUExtract.
static unsigned getBitWidth(Type type)
Returns the bit width of integer, float or vector of float or integer values.
Definition: SPIRVToLLVM.cpp:55
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static Type convertPointerType(spirv::PointerType type, LLVMTypeConverter &converter)
Converts SPIR-V pointer type to LLVM pointer.
static std::optional< Type > convertRuntimeArrayType(spirv::RuntimeArrayType type, TypeConverter &converter)
Converts SPIR-V runtime array to LLVM array.
static bool isSignedIntegerOrVector(Type type)
Returns true if the given type is a signed integer or vector type.
Definition: SPIRVToLLVM.cpp:37
static bool isUnsignedIntegerOrVector(Type type)
Returns true if the given type is an unsigned integer or vector type.
Definition: SPIRVToLLVM.cpp:46
static Type convertStructTypePacked(spirv::StructType type, LLVMTypeConverter &converter)
Converts SPIR-V struct with no offset to packed LLVM struct.
static std::optional< Type > convertArrayType(spirv::ArrayType type, TypeConverter &converter)
Converts SPIR-V array type to LLVM array.
static std::optional< Type > convertStructType(spirv::StructType type, LLVMTypeConverter &converter)
Converts SPIR-V struct to LLVM struct.
static constexpr StringRef kBinding
Hook for descriptor set and binding number encoding.
static std::optional< Type > convertStructTypeWithOffset(spirv::StructType type, LLVMTypeConverter &converter)
Converts SPIR-V struct with a regular (according to VulkanLayoutUtils) offset to LLVM struct.
static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder)
Creates IntegerAttribute with all bits set for given type.
Definition: SPIRVToLLVM.cpp:76
static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter, unsigned alignment, bool isVolatile, bool isNonTemporal)
Utility for spirv.Load and spirv.Store conversion.
static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter)
Creates llvm.mlir.constant with all bits set for the given type.
Definition: SPIRVToLLVM.cpp:86
static unsigned getLLVMTypeBitWidth(Type type)
Returns the bit width of LLVMType integer or vector.
Definition: SPIRVToLLVM.cpp:68
static Value optionallyBroadcast(Location loc, Value value, Type srcType, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value. If srcType is a scalar, the value remains unchanged.
#define DISPATCH(functionControl, llvmAttr)
static int64_t getNumElements(ShapedType type)
Definition: TensorOps.cpp:1223
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:30
OpListType::iterator iterator
Definition: Block.h:129
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:232
OpListType & getOperations()
Definition: Block.h:126
BlockArgListType getArguments()
Definition: Block.h:76
iterator begin()
Definition: Block.h:132
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:202
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:224
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:247
IntegerType getI32Type()
Definition: Builders.cpp:80
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:84
MLIRContext * getContext() const
Definition: Builders.h:55
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before) override
PatternRewriter hook for moving blocks out of a region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
FailureOr< Block * > convertRegionTypes(Region *region, TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
Block * splitBlock(Block *block, Block::iterator before) override
PatternRewriter hook for splitting a block into two parts.
void eraseBlock(Block *block) override
PatternRewriter hook for erase all operations in a block.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense integer vector or tensor object.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:33
LLVM::LLVMPointerType getPointerType(Type elementType, unsigned addressSpace=0)
Creates an LLVM pointer type with the given element type and address space.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results)
Convert the given type.
static LLVMStructType getLiteral(MLIRContext *context, ArrayRef< Type > types, bool isPacked=false)
Gets or creates a literal struct with the given body in the provided context.
Definition: LLVMTypes.cpp:447
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
U cast() const
Definition: Location.h:89
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:329
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:426
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:412
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:379
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:417
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:405
Block * getBlock() const
Returns the current block of the builder.
Definition: Builders.h:429
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:432
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:393
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:423
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:75
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:234
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:357
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:668
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:482
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
This class provides all of the information necessary to convert a type signature.
Type conversion class.
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results)
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
U cast() const
Definition: Types.h:321
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition: Types.cpp:70
U dyn_cast() const
Definition: Types.h:311
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition: Types.cpp:82
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:108
bool isa() const
Definition: Types.h:301
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:112
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:370
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
bool isa() const
Definition: Value.h:98
Type getType() const
Return the type of this value.
Definition: Value.h:122
U cast() const
Definition: Value.h:113
static spirv::StructType decorateType(spirv::StructType structType)
Returns a new StructType with layout decoration.
Definition: LayoutUtils.cpp:21
Type getElementType() const
Definition: SPIRVTypes.cpp:65
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:67
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:63
Type getPointeeType() const
Definition: SPIRVTypes.cpp:482
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:545
SPIR-V struct type.
Definition: SPIRVTypes.h:282
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
ElementTypeRange getElementTypes() const
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:842
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:824
SmallVector< IntT > convertArrayToIndices(ArrayRef< Attribute > attrs)
Convert an array of integer attributes to a vector of integers that can be used as indices in LLVM op...
Definition: LLVMDialect.h:220
Type getVectorElementType(Type type)
Returns the element type of any vector type compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:858
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
void populateSPIRVToLLVMFunctionConversionPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given list with patterns for function conversion from SPIR-V to LLVM.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void populateSPIRVToLLVMModuleConversionPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given patterns for module conversion from SPIR-V to LLVM.
void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter)
Populates type conversions with additional SPIR-V types.
void encodeBindAttribute(ModuleOp module)
Encodes global variable's descriptor set and binding into its name if they both exist.
void populateSPIRVToLLVMConversionPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given list with patterns that convert from SPIR-V to LLVM.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26