MLIR  16.0.0git
SPIRVToLLVM.cpp
Go to the documentation of this file.
1 //===- SPIRVToLLVM.cpp - SPIR-V to LLVM Patterns --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert SPIR-V dialect to LLVM dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
21 #include "mlir/IR/BuiltinOps.h"
22 #include "mlir/IR/PatternMatch.h"
25 #include "llvm/Support/Debug.h"
26 #include "llvm/Support/FormatVariadic.h"
27 
28 #define DEBUG_TYPE "spirv-to-llvm-pattern"
29 
30 using namespace mlir;
31 
32 //===----------------------------------------------------------------------===//
33 // Utility functions
34 //===----------------------------------------------------------------------===//
35 
36 /// Returns true if the given type is a signed integer or vector type.
37 static bool isSignedIntegerOrVector(Type type) {
38  if (type.isSignedInteger())
39  return true;
40  if (auto vecType = type.dyn_cast<VectorType>())
41  return vecType.getElementType().isSignedInteger();
42  return false;
43 }
44 
45 /// Returns true if the given type is an unsigned integer or vector type
46 static bool isUnsignedIntegerOrVector(Type type) {
47  if (type.isUnsignedInteger())
48  return true;
49  if (auto vecType = type.dyn_cast<VectorType>())
50  return vecType.getElementType().isUnsignedInteger();
51  return false;
52 }
53 
54 /// Returns the bit width of integer, float or vector of float or integer values
55 static unsigned getBitWidth(Type type) {
56  assert((type.isIntOrFloat() || type.isa<VectorType>()) &&
57  "bitwidth is not supported for this type");
58  if (type.isIntOrFloat())
59  return type.getIntOrFloatBitWidth();
60  auto vecType = type.dyn_cast<VectorType>();
61  auto elementType = vecType.getElementType();
62  assert(elementType.isIntOrFloat() &&
63  "only integers and floats have a bitwidth");
64  return elementType.getIntOrFloatBitWidth();
65 }
66 
67 /// Returns the bit width of an LLVM-compatible integer type or (presumably,
/// via the ternary's elided first branch) of the element type of an
/// LLVM-compatible vector. The selected type is `cast` to `IntegerType`, so
/// unlike `getBitWidth`, floating-point types are not accepted here.
/// NOTE(review): the first line of the return expression (listing line 69)
/// is missing from this copy; restore it before relying on this block.
68 static unsigned getLLVMTypeBitWidth(Type type) {
70  : type)
71  .cast<IntegerType>()
72  .getWidth();
73 }
74 
75 /// Creates `IntegerAttribute` with all bits set for given type
76 static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) {
77  if (auto vecType = type.dyn_cast<VectorType>()) {
78  auto integerType = vecType.getElementType().cast<IntegerType>();
79  return builder.getIntegerAttr(integerType, -1);
80  }
81  auto integerType = type.cast<IntegerType>();
82  return builder.getIntegerAttr(integerType, -1);
83 }
84 
85 /// Creates `llvm.mlir.constant` with all bits set for the given type.
86 static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
87  PatternRewriter &rewriter) {
88  if (srcType.isa<VectorType>()) {
89  return rewriter.create<LLVM::ConstantOp>(
90  loc, dstType,
91  SplatElementsAttr::get(srcType.cast<ShapedType>(),
92  minusOneIntegerAttribute(srcType, rewriter)));
93  }
94  return rewriter.create<LLVM::ConstantOp>(
95  loc, dstType, minusOneIntegerAttribute(srcType, rewriter));
96 }
97 
98 /// Creates `llvm.mlir.constant` with a floating-point scalar or vector value.
99 static Value createFPConstant(Location loc, Type srcType, Type dstType,
100  PatternRewriter &rewriter, double value) {
101  if (auto vecType = srcType.dyn_cast<VectorType>()) {
102  auto floatType = vecType.getElementType().cast<FloatType>();
103  return rewriter.create<LLVM::ConstantOp>(
104  loc, dstType,
105  SplatElementsAttr::get(vecType,
106  rewriter.getFloatAttr(floatType, value)));
107  }
108  auto floatType = srcType.cast<FloatType>();
109  return rewriter.create<LLVM::ConstantOp>(
110  loc, dstType, rewriter.getFloatAttr(floatType, value));
111 }
112 
113 /// Utility function for bitfield ops:
114 /// - `BitFieldInsert`
115 /// - `BitFieldSExtract`
116 /// - `BitFieldUExtract`
117 /// Zero-extends (`LLVM::ZExtOp`) the value when it is narrower than `llvmType`
118 /// and truncates it (`LLVM::TruncOp`) when it is wider; equal widths leave it unchanged.
/// The value's width is measured with `getLLVMTypeBitWidth` when its type is
/// already LLVM-compatible, and with `getBitWidth` otherwise.
120  Type llvmType,
121  PatternRewriter &rewriter) {
122  auto srcType = value.getType();
123  unsigned targetBitWidth = getLLVMTypeBitWidth(llvmType);
124  unsigned valueBitWidth = LLVM::isCompatibleType(srcType)
125  ? getLLVMTypeBitWidth(srcType)
126  : getBitWidth(srcType);
127 
 // Narrower values are zero-extended, i.e. treated as unsigned quantities.
128  if (valueBitWidth < targetBitWidth)
129  return rewriter.create<LLVM::ZExtOp>(loc, llvmType, value);
130  // If the bit widths of `Count` and `Offset` are greater than the bit width
131  // of the target type, they are truncated. Truncation is safe since `Count`
132  // and `Offset` must be no more than 64 for op behaviour to be defined. Hence,
133  // both values can be expressed in 8 bits.
134  if (valueBitWidth > targetBitWidth)
135  return rewriter.create<LLVM::TruncOp>(loc, llvmType, value);
136  return value;
137 }
138 
139 /// Broadcasts the value to vector with `numElements` number of elements.
140 static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
141  LLVMTypeConverter &typeConverter,
142  ConversionPatternRewriter &rewriter) {
143  auto vectorType = VectorType::get(numElements, toBroadcast.getType());
144  auto llvmVectorType = typeConverter.convertType(vectorType);
145  auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
146  Value broadcasted = rewriter.create<LLVM::UndefOp>(loc, llvmVectorType);
147  for (unsigned i = 0; i < numElements; ++i) {
148  auto index = rewriter.create<LLVM::ConstantOp>(
149  loc, llvmI32Type, rewriter.getI32IntegerAttr(i));
150  broadcasted = rewriter.create<LLVM::InsertElementOp>(
151  loc, llvmVectorType, broadcasted, toBroadcast, index);
152  }
153  return broadcasted;
154 }
155 
156 /// Broadcasts `value` to a vector with as many elements as `srcType` when `srcType` is a vector; for any other `srcType` the value is returned unchanged.
158  LLVMTypeConverter &typeConverter,
159  ConversionPatternRewriter &rewriter) {
160  if (auto vectorType = srcType.dyn_cast<VectorType>()) {
161  unsigned numElements = vectorType.getNumElements();
162  return broadcast(loc, value, numElements, typeConverter, rewriter);
163  }
164  return value;
165 }
166 
167 /// Utility function for bitfield ops: `BitFieldInsert`, `BitFieldSExtract` and
168 /// `BitFieldUExtract`.
169 /// Broadcasts `Offset` or `Count` to match the type of `Base`. If `Base` is of
170 /// a vector type, constructs a vector that has:
171 /// - same number of elements as `Base`
172 /// - each element has the type that is the same as the type of `Offset` or
173 /// `Count`
174 /// - each element has the same value as `Offset` or `Count`
175 /// Then zero-extends or truncates the result to the bit width of `dstType`
176 /// (the converted type of `Base`) if the widths differ.
178  Type dstType, LLVMTypeConverter &converter,
179  ConversionPatternRewriter &rewriter) {
180  Value broadcasted =
181  optionallyBroadcast(loc, value, srcType, converter, rewriter);
182  return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter);
183 }
184 
185 /// Converts a SPIR-V struct whose offsets are the regular ones (the struct is
186 /// unchanged by `VulkanLayoutUtils::decorateType`) to a non-packed LLVM literal struct; any other layout yields `llvm::None`.
187 static Optional<Type>
189  LLVMTypeConverter &converter) {
 // Only layouts identical to the Vulkan-decorated form are supported.
190  if (type != VulkanLayoutUtils::decorateType(type))
191  return llvm::None;
192 
 // NOTE(review): element conversions are not null-checked here; a failed
 // element conversion would end up inside the struct type — confirm.
193  auto elementsVector = llvm::to_vector<8>(
194  llvm::map_range(type.getElementTypes(), [&](Type elementType) {
195  return converter.convertType(elementType);
196  }));
197  return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
198  /*isPacked=*/false);
199 }
200 
201 /// Converts a SPIR-V struct with no offsets to a packed LLVM literal struct whose elements are converted with `converter`.
203  LLVMTypeConverter &converter) {
204  auto elementsVector = llvm::to_vector<8>(
205  llvm::map_range(type.getElementTypes(), [&](Type elementType) {
206  return converter.convertType(elementType);
207  }));
208  return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
209  /*isPacked=*/true);
210 }
211 
212 /// Creates an `LLVM::ConstantOp` of type i32 holding `value`.
214  unsigned value) {
215  return rewriter.create<LLVM::ConstantOp>(
216  loc, IntegerType::get(rewriter.getContext(), 32),
217  rewriter.getIntegerAttr(rewriter.getI32Type(), value));
218 }
219 
220 /// Utility for `spirv.Load` and `spirv.Store` conversion.
/// A `spirv::LoadOp` is replaced with `LLVM::LoadOp`; this fails if its result
/// type cannot be converted. Any other op must be a `spirv::StoreOp` and is
/// replaced with `LLVM::StoreOp`. The pointer and stored value are read from
/// the already-converted `operands` through the op adaptors. `alignment`,
/// `isVolatile` and `isNonTemporal` are forwarded unchanged.
222  ConversionPatternRewriter &rewriter,
223  LLVMTypeConverter &typeConverter,
224  unsigned alignment, bool isVolatile,
225  bool isNonTemporal) {
226  if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
227  auto dstType = typeConverter.convertType(loadOp.getType());
228  if (!dstType)
229  return failure();
230  rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
231  loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment,
232  isVolatile, isNonTemporal);
233  return success();
234  }
 // Anything that is not a load must be a store (`cast` asserts this).
235  auto storeOp = cast<spirv::StoreOp>(op);
236  spirv::StoreOpAdaptor adaptor(operands);
237  rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValue(),
238  adaptor.getPtr(), alignment,
239  isVolatile, isNonTemporal);
240  return success();
241 }
242 
243 //===----------------------------------------------------------------------===//
244 // Type conversion
245 //===----------------------------------------------------------------------===//
246 
247 /// Converts SPIR-V array type to LLVM array. Natural stride (according to
248 /// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected
249 /// when converting ops that manipulate array types.
/// Returns `llvm::None` for a non-zero stride that differs from the element's
/// size in bytes, or when that size is unknown.
251  TypeConverter &converter) {
252  unsigned stride = type.getArrayStride();
253  Type elementType = type.getElementType();
254  auto sizeInBytes = elementType.cast<spirv::SPIRVType>().getSizeInBytes();
 // A non-zero stride is accepted only if it equals the element size.
255  if (stride != 0 && (!sizeInBytes || *sizeInBytes != stride))
256  return llvm::None;
257 
258  auto llvmElementType = converter.convertType(elementType);
259  unsigned numElements = type.getNumElements();
260  return LLVM::LLVMArrayType::get(llvmElementType, numElements);
261 }
262 
263 /// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not
264 /// modelled at the moment.
266  TypeConverter &converter) {
 // NOTE(review): the pointee conversion result is not null-checked before
 // building the pointer type — confirm a failed conversion cannot reach here.
267  auto pointeeType = converter.convertType(type.getPointeeType());
268  return LLVM::LLVMPointerType::get(pointeeType);
269 }
270 
271 /// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over
272 /// the bounds, the runtime array is converted to a 0-sized LLVM array. Array
273 /// stride is not modelled, so only arrays with stride 0 are converted; others yield `llvm::None`.
275  TypeConverter &converter) {
276  if (type.getArrayStride() != 0)
277  return llvm::None;
278  auto elementType = converter.convertType(type.getElementType());
279  return LLVM::LLVMArrayType::get(elementType, 0);
280 }
281 
282 /// Converts SPIR-V struct to LLVM struct. Structs with member decorations are
283 /// not supported (`llvm::None`); structs with offsets must use natural offsets.
/// Structs with offsets are handled by `convertStructTypeWithOffset`; structs
/// without offsets become packed LLVM structs.
285  LLVMTypeConverter &converter) {
287  type.getMemberDecorations(memberDecorations);
288  if (!memberDecorations.empty())
289  return llvm::None;
290  if (type.hasOffset())
291  return convertStructTypeWithOffset(type, converter);
292  return convertStructTypePacked(type, converter);
293 }
294 
295 //===----------------------------------------------------------------------===//
296 // Operation conversion
297 //===----------------------------------------------------------------------===//
298 
299 namespace {
300 
/// Converts `spirv::AccessChainOp` to `LLVM::GEPOp` on the converted base
/// pointer. A zero constant is prepended to the converted SPIR-V indices so
/// that the GEP first steps through the base pointer itself.
301 class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
302 public:
304 
306  matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor,
307  ConversionPatternRewriter &rewriter) const override {
308  auto dstType = typeConverter.convertType(op.getComponentPtr().getType());
309  if (!dstType)
310  return failure();
311  // To use GEP we need to add a first 0 index to go through the pointer.
312  auto indices = llvm::to_vector<4>(adaptor.getIndices());
 // NOTE(review): `front()` assumes at least one index; confirm the op
 // verifier guarantees a non-empty index list.
313  Type indexType = op.getIndices().front().getType();
314  auto llvmIndexType = typeConverter.convertType(indexType);
315  if (!llvmIndexType)
316  return failure();
 // NOTE(review): the zero attribute is typed with the SPIR-V `indexType`
 // while the constant's result type is `llvmIndexType`; these may differ
 // (e.g. for a signed/unsigned index type) — confirm.
317  Value zero = rewriter.create<LLVM::ConstantOp>(
318  op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
319  indices.insert(indices.begin(), zero);
320  rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, adaptor.getBasePtr(),
321  indices);
322  return success();
323  }
324 };
325 
/// Converts `spirv::AddressOfOp` to `LLVM::AddressOfOp` referring to the same
/// variable symbol, with the converted pointer type as its result.
326 class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {
327 public:
329 
331  matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor,
332  ConversionPatternRewriter &rewriter) const override {
333  auto dstType = typeConverter.convertType(op.getPointer().getType());
334  if (!dstType)
335  return failure();
336  rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType, op.getVariable());
337  return success();
338  }
339 };
340 
/// Converts `spirv::BitFieldInsertOp`. It builds a mask that clears bits
/// [Offset, Offset + Count - 1] of `Base`, then ORs in `Insert` shifted left
/// by `Offset`.
/// NOTE(review): the op's original operands (`op.getBase()`, `op.getInsert()`,
/// `op.getOffset()`, `op.getCount()`) are used rather than `adaptor`'s
/// converted ones — confirm this is intentional.
341 class BitFieldInsertPattern
342  : public SPIRVToLLVMConversion<spirv::BitFieldInsertOp> {
343 public:
345 
347  matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor,
348  ConversionPatternRewriter &rewriter) const override {
349  auto srcType = op.getType();
350  auto dstType = typeConverter.convertType(srcType);
351  if (!dstType)
352  return failure();
353  Location loc = op.getLoc();
354 
355  // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
356  Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
357  typeConverter, rewriter);
358  Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
359  typeConverter, rewriter);
360 
361  // Create a mask with bits set outside [Offset, Offset + Count - 1].
 // NOTE(review): when Count equals the bit width, `shl` by that amount is
 // poison in LLVM IR — confirm whether this case must be handled.
362  Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
363  Value maskShiftedByCount =
364  rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
365  Value negated = rewriter.create<LLVM::XOrOp>(loc, dstType,
366  maskShiftedByCount, minusOne);
367  Value maskShiftedByCountAndOffset =
368  rewriter.create<LLVM::ShlOp>(loc, dstType, negated, offset);
369  Value mask = rewriter.create<LLVM::XOrOp>(
370  loc, dstType, maskShiftedByCountAndOffset, minusOne);
371 
372  // Extract unchanged bits from the `Base` that are outside of
373  // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`.
 // NOTE(review): bits of `Insert` at or above `Count` are not masked off
 // before the `or`, so after shifting they land outside the field and are
 // OR-ed into `Base`. The SPIR-V spec inserts only bits [0, Count - 1] of
 // `Insert`; consider AND-ing `insertShiftedByOffset` with `negated << offset`.
374  Value baseAndMask =
375  rewriter.create<LLVM::AndOp>(loc, dstType, op.getBase(), mask);
376  Value insertShiftedByOffset =
377  rewriter.create<LLVM::ShlOp>(loc, dstType, op.getInsert(), offset);
378  rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask,
379  insertShiftedByOffset);
380  return success();
381  }
382 };
383 
384 /// Converts SPIR-V ConstantOp with scalar or vector type.
/// Signed/unsigned integer constants are re-emitted with signless integer
/// types. Other scalar/vector constants are forwarded with their attributes
/// unchanged. Constants of any other type are rejected (match failure).
385 class ConstantScalarAndVectorPattern
386  : public SPIRVToLLVMConversion<spirv::ConstantOp> {
387 public:
389 
391  matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor,
392  ConversionPatternRewriter &rewriter) const override {
393  auto srcType = constOp.getType();
394  if (!srcType.isa<VectorType>() && !srcType.isIntOrFloat())
395  return failure();
396 
397  auto dstType = typeConverter.convertType(srcType);
398  if (!dstType)
399  return failure();
400 
401  // SPIR-V constant can be a signed/unsigned integer, which has to be
402  // cast to a signless integer when converting to LLVM dialect. Dropping the
403  // signedness may have unexpected behaviour. However, it is better to handle
404  // it case-by-case, given that the purpose of the conversion is not to
405  // cover all possible corner cases.
406  if (isSignedIntegerOrVector(srcType) ||
407  isUnsignedIntegerOrVector(srcType)) {
408  auto signlessType = rewriter.getIntegerType(getBitWidth(srcType));
409 
410  if (srcType.isa<VectorType>()) {
 // Re-type every element to the signless type; the bit patterns are kept.
411  auto dstElementsAttr = constOp.getValue().cast<DenseIntElementsAttr>();
412  rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
413  constOp, dstType,
414  dstElementsAttr.mapValues(
415  signlessType, [&](const APInt &value) { return value; }));
416  return success();
417  }
418  auto srcAttr = constOp.getValue().cast<IntegerAttr>();
419  auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue());
420  rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr);
421  return success();
422  }
423  rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
424  constOp, dstType, adaptor.getOperands(), constOp->getAttrs());
425  return success();
426  }
427 };
428 
/// Converts `spirv::BitFieldSExtractOp`. It shifts `Base` left so that bit
/// Offset + Count - 1 becomes the most significant bit. It then
/// arithmetic-shifts right by Offset + (size - (Count + Offset)), which
/// sign-extends the extracted field.
/// NOTE(review): the op's original operands are used rather than `adaptor`'s
/// converted ones — confirm this is intentional.
429 class BitFieldSExtractPattern
430  : public SPIRVToLLVMConversion<spirv::BitFieldSExtractOp> {
431 public:
433 
435  matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor,
436  ConversionPatternRewriter &rewriter) const override {
437  auto srcType = op.getType();
438  auto dstType = typeConverter.convertType(srcType);
439  if (!dstType)
440  return failure();
441  Location loc = op.getLoc();
442 
443  // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
444  Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
445  typeConverter, rewriter);
446  Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
447  typeConverter, rewriter);
448 
449  // Create a constant that holds the size of the `Base`.
450  IntegerType integerType;
451  if (auto vecType = srcType.dyn_cast<VectorType>())
452  integerType = vecType.getElementType().cast<IntegerType>();
453  else
454  integerType = srcType.cast<IntegerType>();
455 
456  auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType));
457  Value size =
458  srcType.isa<VectorType>()
459  ? rewriter.create<LLVM::ConstantOp>(
460  loc, dstType,
461  SplatElementsAttr::get(srcType.cast<ShapedType>(), baseSize))
462  : rewriter.create<LLVM::ConstantOp>(loc, dstType, baseSize);
463 
464  // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit
465  // at Offset + Count - 1 is the most significant bit now.
466  Value countPlusOffset =
467  rewriter.create<LLVM::AddOp>(loc, dstType, count, offset);
468  Value amountToShiftLeft =
469  rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset);
470  Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>(
471  loc, dstType, op.getBase(), amountToShiftLeft);
472 
473  // Shift the result right, filling the bits with the sign bit.
 // The right-shift amount simplifies to size - Count.
 // NOTE(review): for Count == 0 that equals the bit width, which makes
 // `ashr` poison in LLVM IR — confirm whether Count == 0 must be handled.
474  Value amountToShiftRight =
475  rewriter.create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft);
476  rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft,
477  amountToShiftRight);
478  return success();
479  }
480 };
481 
/// Converts `spirv::BitFieldUExtractOp`. It logical-shifts `Base` right by
/// `Offset`, then masks off everything but the low `Count` bits.
/// NOTE(review): the op's original operands are used rather than `adaptor`'s
/// converted ones — confirm this is intentional.
482 class BitFieldUExtractPattern
483  : public SPIRVToLLVMConversion<spirv::BitFieldUExtractOp> {
484 public:
486 
488  matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor,
489  ConversionPatternRewriter &rewriter) const override {
490  auto srcType = op.getType();
491  auto dstType = typeConverter.convertType(srcType);
492  if (!dstType)
493  return failure();
494  Location loc = op.getLoc();
495 
496  // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
497  Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType,
498  typeConverter, rewriter);
499  Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType,
500  typeConverter, rewriter);
501 
502  // Create a mask with bits set at [0, Count - 1].
 // NOTE(review): when Count equals the bit width, `shl` by that amount is
 // poison in LLVM IR — confirm whether this case must be handled.
503  Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
504  Value maskShiftedByCount =
505  rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
506  Value mask = rewriter.create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount,
507  minusOne);
508 
509  // Shift `Base` by `Offset` and apply the mask on it.
510  Value shiftedBase =
511  rewriter.create<LLVM::LShrOp>(loc, dstType, op.getBase(), offset);
512  rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask);
513  return success();
514  }
515 };
516 
/// Converts `spirv::BranchOp` to `LLVM::BrOp` targeting the same block and
/// forwarding the converted block arguments.
517 class BranchConversionPattern : public SPIRVToLLVMConversion<spirv::BranchOp> {
518 public:
520 
522  matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor,
523  ConversionPatternRewriter &rewriter) const override {
524  rewriter.replaceOpWithNewOp<LLVM::BrOp>(branchOp, adaptor.getOperands(),
525  branchOp.getTarget());
526  return success();
527  }
528 };
529 
/// Converts `spirv::BranchConditionalOp` to `LLVM::CondBrOp`. Optional branch
/// weights are re-encoded as a 2-element i32 vector attribute.
/// NOTE(review): the condition and block arguments are taken from `op`
/// rather than `adaptor` — confirm this is intentional.
530 class BranchConditionalConversionPattern
531  : public SPIRVToLLVMConversion<spirv::BranchConditionalOp> {
532 public:
533  using SPIRVToLLVMConversion<
534  spirv::BranchConditionalOp>::SPIRVToLLVMConversion;
535 
537  matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor,
538  ConversionPatternRewriter &rewriter) const override {
539  // If branch weights exist, map them to 32-bit integer vector.
540  ElementsAttr branchWeights = nullptr;
541  if (auto weights = op.getBranchWeights()) {
542  VectorType weightType = VectorType::get(2, rewriter.getI32Type());
543  branchWeights = DenseElementsAttr::get(weightType, weights->getValue());
544  }
545 
546  rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
547  op, op.getCondition(), op.getTrueBlockArguments(),
548  op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
549  op.getFalseBlock());
550  return success();
551  }
552 };
553 
554 /// Converts `spirv.CompositeExtract` to `llvm.extractvalue` if the container
555 /// type is an aggregate type (struct or array). Otherwise, converts to
556 /// `llvm.extractelement` that operates on vectors.
/// For vector containers only the first index is used.
557 class CompositeExtractPattern
558  : public SPIRVToLLVMConversion<spirv::CompositeExtractOp> {
559 public:
561 
563  matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor,
564  ConversionPatternRewriter &rewriter) const override {
565  auto dstType = this->typeConverter.convertType(op.getType());
566  if (!dstType)
567  return failure();
568 
569  Type containerType = op.getComposite().getType();
570  if (containerType.isa<VectorType>()) {
571  Location loc = op.getLoc();
572  IntegerAttr value = op.getIndices()[0].cast<IntegerAttr>();
573  Value index = createI32ConstantOf(loc, rewriter, value.getInt());
574  rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
575  op, dstType, adaptor.getComposite(), index);
576  return success();
577  }
578 
579  rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
580  op, adaptor.getComposite(), LLVM::convertArrayToIndices(op.getIndices()));
581  return success();
582  }
583 };
584 
585 /// Converts `spirv.CompositeInsert` to `llvm.insertvalue` if the container
586 /// type is an aggregate type (struct or array). Otherwise, converts to
587 /// `llvm.insertelement` that operates on vectors.
/// For vector containers only the first index is used.
588 class CompositeInsertPattern
589  : public SPIRVToLLVMConversion<spirv::CompositeInsertOp> {
590 public:
592 
594  matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor,
595  ConversionPatternRewriter &rewriter) const override {
596  auto dstType = this->typeConverter.convertType(op.getType());
597  if (!dstType)
598  return failure();
599 
600  Type containerType = op.getComposite().getType();
601  if (containerType.isa<VectorType>()) {
602  Location loc = op.getLoc();
603  IntegerAttr value = op.getIndices()[0].cast<IntegerAttr>();
604  Value index = createI32ConstantOf(loc, rewriter, value.getInt());
605  rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
606  op, dstType, adaptor.getComposite(), adaptor.getObject(), index);
607  return success();
608  }
609 
610  rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
611  op, adaptor.getComposite(), adaptor.getObject(),
612  LLVM::convertArrayToIndices(op.getIndices()));
613  return success();
614  }
615 };
616 
617 /// Converts SPIR-V operations that have a straightforward LLVM equivalent
618 /// into LLVM dialect operations.
/// The result type is converted, the operands are taken from the adaptor, and
/// all of the SPIR-V op's attributes are forwarded unchanged to `LLVMOp`.
619 template <typename SPIRVOp, typename LLVMOp>
620 class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
621 public:
623 
625  matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
626  ConversionPatternRewriter &rewriter) const override {
627  auto dstType = this->typeConverter.convertType(operation.getType());
628  if (!dstType)
629  return failure();
630  rewriter.template replaceOpWithNewOp<LLVMOp>(
631  operation, dstType, adaptor.getOperands(), operation->getAttrs());
632  return success();
633  }
634 };
635 
636 /// Converts `spirv.ExecutionMode` into a global struct constant that holds
637 /// execution mode information.
/// The global is created at the start of the enclosing `ModuleOp`'s body with
/// external linkage. Its initializer region builds the struct with
/// `LLVM::UndefOp` and `LLVM::InsertValueOp`, then returns it. The original op
/// is erased.
638 class ExecutionModePattern
639  : public SPIRVToLLVMConversion<spirv::ExecutionModeOp> {
640 public:
642 
644  matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor,
645  ConversionPatternRewriter &rewriter) const override {
646  // First, create the global struct's name that would be associated with
647  // this entry point's execution mode. We set it to be:
648  // __spv__{SPIR-V module name}_{function name}_execution_mode_info_{mode}
 // where {mode} is the numeric value of the execution mode. If the module
 // has no name, the "{SPIR-V module name}_" part is omitted.
649  ModuleOp module = op->getParentOfType<ModuleOp>();
650  spirv::ExecutionModeAttr executionModeAttr = op.getExecutionModeAttr();
651  std::string moduleName;
652  if (module.getName().has_value())
653  moduleName = "_" + module.getName()->str();
654  else
655  moduleName = "";
656  std::string executionModeInfoName = llvm::formatv(
657  "__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.getFn().str(),
658  static_cast<uint32_t>(executionModeAttr.getValue()));
659 
660  MLIRContext *context = rewriter.getContext();
661  OpBuilder::InsertionGuard guard(rewriter);
662  rewriter.setInsertionPointToStart(module.getBody());
663 
664  // Create a struct type, corresponding to the C struct below.
665  // struct {
666  // int32_t executionMode;
667  // int32_t values[]; // optional values
668  // };
 // The `values` field is only added when the op carries extra values; it is
 // then a fixed-size array of `values.size()` elements.
669  auto llvmI32Type = IntegerType::get(context, 32);
670  SmallVector<Type, 2> fields;
671  fields.push_back(llvmI32Type);
672  ArrayAttr values = op.getValues();
673  if (!values.empty()) {
674  auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size());
675  fields.push_back(arrayType);
676  }
677  auto structType = LLVM::LLVMStructType::getLiteral(context, fields);
678 
679  // Create `llvm.mlir.global` with initializer region containing one block.
680  auto global = rewriter.create<LLVM::GlobalOp>(
681  UnknownLoc::get(context), structType, /*isConstant=*/true,
682  LLVM::Linkage::External, executionModeInfoName, Attribute(),
683  /*alignment=*/0);
684  Location loc = global.getLoc();
685  Region &region = global.getInitializerRegion();
686  Block *block = rewriter.createBlock(&region);
687 
688  // Initialize the struct and set the execution mode value.
689  rewriter.setInsertionPoint(block, block->begin());
690  Value structValue = rewriter.create<LLVM::UndefOp>(loc, structType);
691  Value executionMode = rewriter.create<LLVM::ConstantOp>(
692  loc, llvmI32Type,
693  rewriter.getI32IntegerAttr(
694  static_cast<uint32_t>(executionModeAttr.getValue())));
695  structValue = rewriter.create<LLVM::InsertValueOp>(loc, structValue,
696  executionMode, 0);
697 
698  // Insert extra operands if they exist into execution mode info struct.
 // Each value goes to position {1, i}: field 1 (the array), element i.
699  for (unsigned i = 0, e = values.size(); i < e; ++i) {
700  auto attr = values.getValue()[i];
701  Value entry = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, attr);
702  structValue = rewriter.create<LLVM::InsertValueOp>(
703  loc, structValue, entry, ArrayRef<int64_t>({1, i}));
704  }
705  rewriter.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue}));
706  rewriter.eraseOp(op);
707  return success();
708  }
709 };
710 
711 /// Converts `spirv.GlobalVariable` to `llvm.mlir.global`. Note that SPIR-V
712 /// global returns a pointer, whereas in LLVM dialect the global holds an actual
713 /// value. This difference is handled by `spirv.mlir.addressof` and
714 /// `llvm.mlir.addressof` ops that both return a pointer.
/// Globals with an initializer are not converted. Neither are globals whose
/// storage class is not Input, Private, Output, StorageBuffer or
/// UniformConstant.
715 class GlobalVariablePattern
716  : public SPIRVToLLVMConversion<spirv::GlobalVariableOp> {
717 public:
719 
721  matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor,
722  ConversionPatternRewriter &rewriter) const override {
723  // Currently, there is no support of initialization with a constant value in
724  // SPIR-V dialect. Specialization constants are not considered as well.
725  if (op.getInitializer())
726  return failure();
727 
728  auto srcType = op.getType().cast<spirv::PointerType>();
729  auto dstType = typeConverter.convertType(srcType.getPointeeType());
730  if (!dstType)
731  return failure();
732 
733  // Limit conversion to the current invocation only or `StorageBuffer`
734  // required by SPIR-V runner.
735  // This is okay because multiple invocations are not supported yet.
736  auto storageClass = srcType.getStorageClass();
737  switch (storageClass) {
738  case spirv::StorageClass::Input:
739  case spirv::StorageClass::Private:
740  case spirv::StorageClass::Output:
741  case spirv::StorageClass::StorageBuffer:
742  case spirv::StorageClass::UniformConstant:
743  break;
744  default:
745  return failure();
746  }
747 
748  // LLVM dialect spec: "If the global value is a constant, storing into it is
749  // not allowed.". This corresponds to SPIR-V 'Input' and 'UniformConstant'
750  // storage class that is read-only.
751  bool isConstant = (storageClass == spirv::StorageClass::Input) ||
752  (storageClass == spirv::StorageClass::UniformConstant);
753  // SPIR-V spec: "By default, functions and global variables are private to a
754  // module and cannot be accessed by other modules. However, a module may be
755  // written to export or import functions and global (module scope)
756  // variables.". Therefore, map 'Private' storage class to private linkage,
757  // and every other supported class ('Input', 'Output', 'StorageBuffer', 'UniformConstant') to external linkage.
758  auto linkage = storageClass == spirv::StorageClass::Private
759  ? LLVM::Linkage::Private
760  : LLVM::Linkage::External;
761  auto newGlobalOp = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
762  op, dstType, isConstant, linkage, op.getSymName(), Attribute(),
763  /*alignment=*/0);
764 
765  // Attach location attribute if applicable
766  if (op.getLocationAttr())
767  newGlobalOp->setAttr(op.getLocationAttrName(), op.getLocationAttr());
768 
769  return success();
770  }
771 };
772 
773 /// Converts SPIR-V cast ops that do not have a straightforward LLVM
774 /// equivalent in LLVM dialect.
/// `LLVMExtOp` is used when the source is narrower than the result, and
/// `LLVMTruncOp` when it is wider. Equal bit widths produce a match failure;
/// presumably another pattern handles that case — not visible here.
775 template <typename SPIRVOp, typename LLVMExtOp, typename LLVMTruncOp>
776 class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> {
777 public:
779 
781  matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
782  ConversionPatternRewriter &rewriter) const override {
783 
784  Type fromType = operation.getOperand().getType();
785  Type toType = operation.getType();
786 
787  auto dstType = this->typeConverter.convertType(toType);
788  if (!dstType)
789  return failure();
790 
791  if (getBitWidth(fromType) < getBitWidth(toType)) {
792  rewriter.template replaceOpWithNewOp<LLVMExtOp>(operation, dstType,
793  adaptor.getOperands());
794  return success();
795  }
796  if (getBitWidth(fromType) > getBitWidth(toType)) {
797  rewriter.template replaceOpWithNewOp<LLVMTruncOp>(operation, dstType,
798  adaptor.getOperands());
799  return success();
800  }
801  return failure();
802  }
803 };
804 
/// Converts `spirv::FunctionCallOp` to `LLVM::CallOp`, forwarding the
/// converted operands and all of the call's attributes. Calls with no results
/// and calls with one result are handled; only result 0's type is converted.
805 class FunctionCallPattern
806  : public SPIRVToLLVMConversion<spirv::FunctionCallOp> {
807 public:
809 
811  matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor,
812  ConversionPatternRewriter &rewriter) const override {
813  if (callOp.getNumResults() == 0) {
814  rewriter.replaceOpWithNewOp<LLVM::CallOp>(
815  callOp, llvm::None, adaptor.getOperands(), callOp->getAttrs());
816  return success();
817  }
818 
819  // Function returns a single result.
 // NOTE(review): unlike the other patterns in this file, a null result of
 // `convertType` is not checked here; consider returning failure() on it.
820  auto dstType = typeConverter.convertType(callOp.getType(0));
821  rewriter.replaceOpWithNewOp<LLVM::CallOp>(
822  callOp, dstType, adaptor.getOperands(), callOp->getAttrs());
823  return success();
824  }
825 };
826 
827 /// Converts SPIR-V floating-point comparisons to `llvm.fcmp` with the given `predicate`.
/// NOTE(review): the comparison uses `operation`'s original operands rather
/// than `adaptor`'s converted ones — confirm this is intentional.
828 template <typename SPIRVOp, LLVM::FCmpPredicate predicate>
829 class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
830 public:
832 
834  matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
835  ConversionPatternRewriter &rewriter) const override {
836 
837  auto dstType = this->typeConverter.convertType(operation.getType());
838  if (!dstType)
839  return failure();
840 
841  rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
842  operation, dstType, predicate, operation.getOperand1(),
843  operation.getOperand2());
844  return success();
845  }
846 };
847 
848 /// Converts SPIR-V integer comparisons to llvm.icmp "predicate"
849 template <typename SPIRVOp, LLVM::ICmpPredicate predicate>
850 class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
851 public:
853 
855  matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
856  ConversionPatternRewriter &rewriter) const override {
857 
858  auto dstType = this->typeConverter.convertType(operation.getType());
859  if (!dstType)
860  return failure();
861 
862  rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
863  operation, dstType, predicate, operation.getOperand1(),
864  operation.getOperand2());
865  return success();
866  }
867 };
868 
869 class InverseSqrtPattern
870  : public SPIRVToLLVMConversion<spirv::GLInverseSqrtOp> {
871 public:
873 
875  matchAndRewrite(spirv::GLInverseSqrtOp op, OpAdaptor adaptor,
876  ConversionPatternRewriter &rewriter) const override {
877  auto srcType = op.getType();
878  auto dstType = typeConverter.convertType(srcType);
879  if (!dstType)
880  return failure();
881 
882  Location loc = op.getLoc();
883  Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
884  Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.getOperand());
885  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt);
886  return success();
887  }
888 };
889 
890 /// Converts `spirv.Load` and `spirv.Store` to LLVM dialect.
891 template <typename SPIRVOp>
892 class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVOp> {
893 public:
895 
897  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
898  ConversionPatternRewriter &rewriter) const override {
899  if (!op.getMemoryAccess()) {
900  return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
901  this->typeConverter, /*alignment=*/0,
902  /*isVolatile=*/false,
903  /*isNonTemporal=*/false);
904  }
905  auto memoryAccess = *op.getMemoryAccess();
906  switch (memoryAccess) {
907  case spirv::MemoryAccess::Aligned:
909  case spirv::MemoryAccess::Nontemporal:
910  case spirv::MemoryAccess::Volatile: {
911  unsigned alignment =
912  memoryAccess == spirv::MemoryAccess::Aligned ? *op.getAlignment() : 0;
913  bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
914  bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
915  return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter,
916  this->typeConverter, alignment, isVolatile,
917  isNonTemporal);
918  }
919  default:
920  // There is no support of other memory access attributes.
921  return failure();
922  }
923  }
924 };
925 
926 /// Converts `spirv.Not` and `spirv.LogicalNot` into LLVM dialect.
927 template <typename SPIRVOp>
928 class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> {
929 public:
931 
933  matchAndRewrite(SPIRVOp notOp, typename SPIRVOp::Adaptor adaptor,
934  ConversionPatternRewriter &rewriter) const override {
935  auto srcType = notOp.getType();
936  auto dstType = this->typeConverter.convertType(srcType);
937  if (!dstType)
938  return failure();
939 
940  Location loc = notOp.getLoc();
941  IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter);
942  auto mask = srcType.template isa<VectorType>()
943  ? rewriter.create<LLVM::ConstantOp>(
944  loc, dstType,
946  srcType.template cast<VectorType>(), minusOne))
947  : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne);
948  rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
949  notOp.getOperand(), mask);
950  return success();
951  }
952 };
953 
954 /// A template pattern that erases the given `SPIRVOp`.
955 template <typename SPIRVOp>
956 class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> {
957 public:
959 
961  matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor,
962  ConversionPatternRewriter &rewriter) const override {
963  rewriter.eraseOp(op);
964  return success();
965  }
966 };
967 
968 class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> {
969 public:
971 
973  matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor,
974  ConversionPatternRewriter &rewriter) const override {
975  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, ArrayRef<Type>(),
976  ArrayRef<Value>());
977  return success();
978  }
979 };
980 
981 class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
982 public:
984 
986  matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor,
987  ConversionPatternRewriter &rewriter) const override {
988  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnValueOp, ArrayRef<Type>(),
989  adaptor.getOperands());
990  return success();
991  }
992 };
993 
994 /// Converts `spirv.mlir.loop` to LLVM dialect. All blocks within selection
995 /// should be reachable for conversion to succeed. The structure of the loop in
996 /// LLVM dialect will be the following:
997 ///
998 /// +------------------------------------+
999 /// | <code before spirv.mlir.loop> |
1000 /// | llvm.br ^header |
1001 /// +------------------------------------+
1002 /// |
1003 /// +----------------+ |
1004 /// | | |
1005 /// | V V
1006 /// | +------------------------------------+
1007 /// | | ^header: |
1008 /// | | <header code> |
1009 /// | | llvm.cond_br %cond, ^body, ^exit |
1010 /// | +------------------------------------+
1011 /// | |
1012 /// | |----------------------+
1013 /// | | |
1014 /// | V |
1015 /// | +------------------------------------+ |
1016 /// | | ^body: | |
1017 /// | | <body code> | |
1018 /// | | llvm.br ^continue | |
1019 /// | +------------------------------------+ |
1020 /// | | |
1021 /// | V |
1022 /// | +------------------------------------+ |
1023 /// | | ^continue: | |
1024 /// | | <continue code> | |
1025 /// | | llvm.br ^header | |
1026 /// | +------------------------------------+ |
1027 /// | | |
1028 /// +---------------+ +----------------------+
1029 /// |
1030 /// V
1031 /// +------------------------------------+
1032 /// | ^exit: |
1033 /// | llvm.br ^remaining |
1034 /// +------------------------------------+
1035 /// |
1036 /// V
1037 /// +------------------------------------+
1038 /// | ^remaining: |
1039 /// | <code after spirv.mlir.loop> |
1040 /// +------------------------------------+
1041 ///
1042 class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> {
1043 public:
1045 
1047  matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor,
1048  ConversionPatternRewriter &rewriter) const override {
1049  // There is no support of loop control at the moment.
1050  if (loopOp.getLoopControl() != spirv::LoopControl::None)
1051  return failure();
1052 
1053  Location loc = loopOp.getLoc();
1054 
1055  // Split the current block after `spirv.mlir.loop`. The remaining ops will
1056  // be used in `endBlock`.
1057  Block *currentBlock = rewriter.getBlock();
1058  auto position = Block::iterator(loopOp);
1059  Block *endBlock = rewriter.splitBlock(currentBlock, position);
1060 
1061  // Remove entry block and create a branch in the current block going to the
1062  // header block.
1063  Block *entryBlock = loopOp.getEntryBlock();
1064  assert(entryBlock->getOperations().size() == 1);
1065  auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->getOperations().front());
1066  if (!brOp)
1067  return failure();
1068  Block *headerBlock = loopOp.getHeaderBlock();
1069  rewriter.setInsertionPointToEnd(currentBlock);
1070  rewriter.create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock);
1071  rewriter.eraseBlock(entryBlock);
1072 
1073  // Branch from merge block to end block.
1074  Block *mergeBlock = loopOp.getMergeBlock();
1075  Operation *terminator = mergeBlock->getTerminator();
1076  ValueRange terminatorOperands = terminator->getOperands();
1077  rewriter.setInsertionPointToEnd(mergeBlock);
1078  rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock);
1079 
1080  rewriter.inlineRegionBefore(loopOp.getBody(), endBlock);
1081  rewriter.replaceOp(loopOp, endBlock->getArguments());
1082  return success();
1083  }
1084 };
1085 
1086 /// Converts `spirv.mlir.selection` with `spirv.BranchConditional` in its header
1087 /// block. All blocks within selection should be reachable for conversion to
1088 /// succeed.
1089 class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> {
1090 public:
1092 
1094  matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor,
1095  ConversionPatternRewriter &rewriter) const override {
1096  // There is no support for `Flatten` or `DontFlatten` selection control at
1097  // the moment. This are just compiler hints and can be performed during the
1098  // optimization passes.
1099  if (op.getSelectionControl() != spirv::SelectionControl::None)
1100  return failure();
1101 
1102  // `spirv.mlir.selection` should have at least two blocks: one selection
1103  // header block and one merge block. If no blocks are present, or control
1104  // flow branches straight to merge block (two blocks are present), the op is
1105  // redundant and it is erased.
1106  if (op.getBody().getBlocks().size() <= 2) {
1107  rewriter.eraseOp(op);
1108  return success();
1109  }
1110 
1111  Location loc = op.getLoc();
1112 
1113  // Split the current block after `spirv.mlir.selection`. The remaining ops
1114  // will be used in `continueBlock`.
1115  auto *currentBlock = rewriter.getInsertionBlock();
1116  rewriter.setInsertionPointAfter(op);
1117  auto position = rewriter.getInsertionPoint();
1118  auto *continueBlock = rewriter.splitBlock(currentBlock, position);
1119 
1120  // Extract conditional branch information from the header block. By SPIR-V
1121  // dialect spec, it should contain `spirv.BranchConditional` or
1122  // `spirv.Switch` op. Note that `spirv.Switch op` is not supported at the
1123  // moment in the SPIR-V dialect. Remove this block when finished.
1124  auto *headerBlock = op.getHeaderBlock();
1125  assert(headerBlock->getOperations().size() == 1);
1126  auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
1127  headerBlock->getOperations().front());
1128  if (!condBrOp)
1129  return failure();
1130  rewriter.eraseBlock(headerBlock);
1131 
1132  // Branch from merge block to continue block.
1133  auto *mergeBlock = op.getMergeBlock();
1134  Operation *terminator = mergeBlock->getTerminator();
1135  ValueRange terminatorOperands = terminator->getOperands();
1136  rewriter.setInsertionPointToEnd(mergeBlock);
1137  rewriter.create<LLVM::BrOp>(loc, terminatorOperands, continueBlock);
1138 
1139  // Link current block to `true` and `false` blocks within the selection.
1140  Block *trueBlock = condBrOp.getTrueBlock();
1141  Block *falseBlock = condBrOp.getFalseBlock();
1142  rewriter.setInsertionPointToEnd(currentBlock);
1143  rewriter.create<LLVM::CondBrOp>(loc, condBrOp.getCondition(), trueBlock,
1144  condBrOp.getTrueTargetOperands(), falseBlock,
1145  condBrOp.getFalseTargetOperands());
1146 
1147  rewriter.inlineRegionBefore(op.getBody(), continueBlock);
1148  rewriter.replaceOp(op, continueBlock->getArguments());
1149  return success();
1150  }
1151 };
1152 
1153 /// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect
1154 /// puts a restriction on `Shift` and `Base` to have the same bit width,
1155 /// `Shift` is zero or sign extended to match this specification. Cases when
1156 /// `Shift` bit width > `Base` bit width are considered to be illegal.
1157 template <typename SPIRVOp, typename LLVMOp>
1158 class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
1159 public:
1161 
1163  matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor,
1164  ConversionPatternRewriter &rewriter) const override {
1165 
1166  auto dstType = this->typeConverter.convertType(operation.getType());
1167  if (!dstType)
1168  return failure();
1169 
1170  Type op1Type = operation.getOperand1().getType();
1171  Type op2Type = operation.getOperand2().getType();
1172 
1173  if (op1Type == op2Type) {
1174  rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType,
1175  adaptor.getOperands());
1176  return success();
1177  }
1178 
1179  Location loc = operation.getLoc();
1180  Value extended;
1181  if (isUnsignedIntegerOrVector(op2Type)) {
1182  extended = rewriter.template create<LLVM::ZExtOp>(loc, dstType,
1183  adaptor.getOperand2());
1184  } else {
1185  extended = rewriter.template create<LLVM::SExtOp>(loc, dstType,
1186  adaptor.getOperand2());
1187  }
1188  Value result = rewriter.template create<LLVMOp>(
1189  loc, dstType, adaptor.getOperand1(), extended);
1190  rewriter.replaceOp(operation, result);
1191  return success();
1192  }
1193 };
1194 
1195 class TanPattern : public SPIRVToLLVMConversion<spirv::GLTanOp> {
1196 public:
1198 
1200  matchAndRewrite(spirv::GLTanOp tanOp, OpAdaptor adaptor,
1201  ConversionPatternRewriter &rewriter) const override {
1202  auto dstType = typeConverter.convertType(tanOp.getType());
1203  if (!dstType)
1204  return failure();
1205 
1206  Location loc = tanOp.getLoc();
1207  Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.getOperand());
1208  Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.getOperand());
1209  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos);
1210  return success();
1211  }
1212 };
1213 
1214 /// Convert `spirv.Tanh` to
1215 ///
1216 /// exp(2x) - 1
1217 /// -----------
1218 /// exp(2x) + 1
1219 ///
1220 class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> {
1221 public:
1223 
1225  matchAndRewrite(spirv::GLTanhOp tanhOp, OpAdaptor adaptor,
1226  ConversionPatternRewriter &rewriter) const override {
1227  auto srcType = tanhOp.getType();
1228  auto dstType = typeConverter.convertType(srcType);
1229  if (!dstType)
1230  return failure();
1231 
1232  Location loc = tanhOp.getLoc();
1233  Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
1234  Value multiplied =
1235  rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.getOperand());
1236  Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied);
1237  Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
1238  Value numerator =
1239  rewriter.create<LLVM::FSubOp>(loc, dstType, exponential, one);
1240  Value denominator =
1241  rewriter.create<LLVM::FAddOp>(loc, dstType, exponential, one);
1242  rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator,
1243  denominator);
1244  return success();
1245  }
1246 };
1247 
1248 class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
1249 public:
1251 
1253  matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor,
1254  ConversionPatternRewriter &rewriter) const override {
1255  auto srcType = varOp.getType();
1256  // Initialization is supported for scalars and vectors only.
1257  auto pointerTo = srcType.cast<spirv::PointerType>().getPointeeType();
1258  auto init = varOp.getInitializer();
1259  if (init && !pointerTo.isIntOrFloat() && !pointerTo.isa<VectorType>())
1260  return failure();
1261 
1262  auto dstType = typeConverter.convertType(srcType);
1263  if (!dstType)
1264  return failure();
1265 
1266  Location loc = varOp.getLoc();
1267  Value size = createI32ConstantOf(loc, rewriter, 1);
1268  if (!init) {
1269  rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, size);
1270  return success();
1271  }
1272  Value allocated = rewriter.create<LLVM::AllocaOp>(loc, dstType, size);
1273  rewriter.create<LLVM::StoreOp>(loc, adaptor.getInitializer(), allocated);
1274  rewriter.replaceOp(varOp, allocated);
1275  return success();
1276  }
1277 };
1278 
1279 //===----------------------------------------------------------------------===//
1280 // FuncOp conversion
1281 //===----------------------------------------------------------------------===//
1282 
1283 class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
1284 public:
1286 
1288  matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor,
1289  ConversionPatternRewriter &rewriter) const override {
1290 
1291  // Convert function signature. At the moment LLVMType converter is enough
1292  // for currently supported types.
1293  auto funcType = funcOp.getFunctionType();
1294  TypeConverter::SignatureConversion signatureConverter(
1295  funcType.getNumInputs());
1296  auto llvmType = typeConverter.convertFunctionSignature(
1297  funcType, /*isVariadic=*/false, signatureConverter);
1298  if (!llvmType)
1299  return failure();
1300 
1301  // Create a new `LLVMFuncOp`
1302  Location loc = funcOp.getLoc();
1303  StringRef name = funcOp.getName();
1304  auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, name, llvmType);
1305 
1306  // Convert SPIR-V Function Control to equivalent LLVM function attribute
1307  MLIRContext *context = funcOp.getContext();
1308  switch (funcOp.getFunctionControl()) {
1309 #define DISPATCH(functionControl, llvmAttr) \
1310  case functionControl: \
1311  newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \
1312  break;
1313 
1314  DISPATCH(spirv::FunctionControl::Inline,
1315  StringAttr::get(context, "alwaysinline"));
1316  DISPATCH(spirv::FunctionControl::DontInline,
1317  StringAttr::get(context, "noinline"));
1318  DISPATCH(spirv::FunctionControl::Pure,
1319  StringAttr::get(context, "readonly"));
1320  DISPATCH(spirv::FunctionControl::Const,
1321  StringAttr::get(context, "readnone"));
1322 
1323 #undef DISPATCH
1324 
1325  // Default: if `spirv::FunctionControl::None`, then no attributes are
1326  // needed.
1327  default:
1328  break;
1329  }
1330 
1331  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1332  newFuncOp.end());
1333  if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
1334  &signatureConverter))) {
1335  return failure();
1336  }
1337  rewriter.eraseOp(funcOp);
1338  return success();
1339  }
1340 };
1341 
1342 //===----------------------------------------------------------------------===//
1343 // ModuleOp conversion
1344 //===----------------------------------------------------------------------===//
1345 
1346 class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> {
1347 public:
1349 
1351  matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor,
1352  ConversionPatternRewriter &rewriter) const override {
1353 
1354  auto newModuleOp =
1355  rewriter.create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName());
1356  rewriter.inlineRegionBefore(spvModuleOp.getRegion(), newModuleOp.getBody());
1357 
1358  // Remove the terminator block that was automatically added by builder
1359  rewriter.eraseBlock(&newModuleOp.getBodyRegion().back());
1360  rewriter.eraseOp(spvModuleOp);
1361  return success();
1362  }
1363 };
1364 
1365 //===----------------------------------------------------------------------===//
1366 // VectorShuffleOp conversion
1367 //===----------------------------------------------------------------------===//
1368 
1369 class VectorShufflePattern
1370  : public SPIRVToLLVMConversion<spirv::VectorShuffleOp> {
1371 public:
1374  matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
1375  ConversionPatternRewriter &rewriter) const override {
1376  Location loc = op.getLoc();
1377  auto components = adaptor.getComponents();
1378  auto vector1 = adaptor.getVector1();
1379  auto vector2 = adaptor.getVector2();
1380  int vector1Size = vector1.getType().cast<VectorType>().getNumElements();
1381  int vector2Size = vector2.getType().cast<VectorType>().getNumElements();
1382  if (vector1Size == vector2Size) {
1383  rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
1384  op, vector1, vector2,
1385  LLVM::convertArrayToIndices<int32_t>(components));
1386  return success();
1387  }
1388 
1389  auto dstType = typeConverter.convertType(op.getType());
1390  auto scalarType = dstType.cast<VectorType>().getElementType();
1391  auto componentsArray = components.getValue();
1392  auto *context = rewriter.getContext();
1393  auto llvmI32Type = IntegerType::get(context, 32);
1394  Value targetOp = rewriter.create<LLVM::UndefOp>(loc, dstType);
1395  for (unsigned i = 0; i < componentsArray.size(); i++) {
1396  if (!componentsArray[i].isa<IntegerAttr>())
1397  return op.emitError("unable to support non-constant component");
1398 
1399  int indexVal = componentsArray[i].cast<IntegerAttr>().getInt();
1400  if (indexVal == -1)
1401  continue;
1402 
1403  int offsetVal = 0;
1404  Value baseVector = vector1;
1405  if (indexVal >= vector1Size) {
1406  offsetVal = vector1Size;
1407  baseVector = vector2;
1408  }
1409 
1410  Value dstIndex = rewriter.create<LLVM::ConstantOp>(
1411  loc, llvmI32Type, rewriter.getIntegerAttr(rewriter.getI32Type(), i));
1412  Value index = rewriter.create<LLVM::ConstantOp>(
1413  loc, llvmI32Type,
1414  rewriter.getIntegerAttr(rewriter.getI32Type(), indexVal - offsetVal));
1415 
1416  auto extractOp = rewriter.create<LLVM::ExtractElementOp>(
1417  loc, scalarType, baseVector, index);
1418  targetOp = rewriter.create<LLVM::InsertElementOp>(loc, dstType, targetOp,
1419  extractOp, dstIndex);
1420  }
1421  rewriter.replaceOp(op, targetOp);
1422  return success();
1423  }
1424 };
1425 } // namespace
1426 
1427 //===----------------------------------------------------------------------===//
1428 // Pattern population
1429 //===----------------------------------------------------------------------===//
1430 
1432  typeConverter.addConversion([&](spirv::ArrayType type) {
1433  return convertArrayType(type, typeConverter);
1434  });
1435  typeConverter.addConversion([&](spirv::PointerType type) {
1436  return convertPointerType(type, typeConverter);
1437  });
1438  typeConverter.addConversion([&](spirv::RuntimeArrayType type) {
1439  return convertRuntimeArrayType(type, typeConverter);
1440  });
1441  typeConverter.addConversion([&](spirv::StructType type) {
1442  return convertStructType(type, typeConverter);
1443  });
1444 }
1445 
1447  LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1448  patterns.add<
1449  // Arithmetic ops
1450  DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
1451  DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
1452  DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
1453  DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
1454  DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
1455  DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
1456  DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
1457  DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
1458  DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
1459  DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
1460  DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
1461  DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
1462  DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
1463 
1464  // Bitwise ops
1465  BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
1466  DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
1467  DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
1468  DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
1469  DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
1470  DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
1471  NotPattern<spirv::NotOp>,
1472 
1473  // Cast ops
1474  DirectConversionPattern<spirv::BitcastOp, LLVM::BitcastOp>,
1475  DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
1476  DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
1477  DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
1478  DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
1479  IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
1480  IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
1481  IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
1482 
1483  // Comparison ops
1484  IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
1485  IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
1486  FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
1487  FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
1488  FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
1489  FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
1490  FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
1491  FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
1492  FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
1493  FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
1494  FComparePattern<spirv::FUnordGreaterThanEqualOp,
1495  LLVM::FCmpPredicate::uge>,
1496  FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
1497  FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
1498  FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
1499  IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
1500  IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
1501  IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
1502  IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
1503  IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
1504  IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
1505  IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
1506  IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
1507 
1508  // Constant op
1509  ConstantScalarAndVectorPattern,
1510 
1511  // Control Flow ops
1512  BranchConversionPattern, BranchConditionalConversionPattern,
1513  FunctionCallPattern, LoopPattern, SelectionPattern,
1514  ErasePattern<spirv::MergeOp>,
1515 
1516  // Entry points and execution mode are handled separately.
1517  ErasePattern<spirv::EntryPointOp>, ExecutionModePattern,
1518 
1519  // GLSL extended instruction set ops
1520  DirectConversionPattern<spirv::GLCeilOp, LLVM::FCeilOp>,
1521  DirectConversionPattern<spirv::GLCosOp, LLVM::CosOp>,
1522  DirectConversionPattern<spirv::GLExpOp, LLVM::ExpOp>,
1523  DirectConversionPattern<spirv::GLFAbsOp, LLVM::FAbsOp>,
1524  DirectConversionPattern<spirv::GLFloorOp, LLVM::FFloorOp>,
1525  DirectConversionPattern<spirv::GLFMaxOp, LLVM::MaxNumOp>,
1526  DirectConversionPattern<spirv::GLFMinOp, LLVM::MinNumOp>,
1527  DirectConversionPattern<spirv::GLLogOp, LLVM::LogOp>,
1528  DirectConversionPattern<spirv::GLSinOp, LLVM::SinOp>,
1529  DirectConversionPattern<spirv::GLSMaxOp, LLVM::SMaxOp>,
1530  DirectConversionPattern<spirv::GLSMinOp, LLVM::SMinOp>,
1531  DirectConversionPattern<spirv::GLSqrtOp, LLVM::SqrtOp>,
1532  InverseSqrtPattern, TanPattern, TanhPattern,
1533 
1534  // Logical ops
1535  DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
1536  DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
1537  IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
1538  IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
1539  NotPattern<spirv::LogicalNotOp>,
1540 
1541  // Memory ops
1542  AccessChainPattern, AddressOfPattern, GlobalVariablePattern,
1543  LoadStorePattern<spirv::LoadOp>, LoadStorePattern<spirv::StoreOp>,
1544  VariablePattern,
1545 
1546  // Miscellaneous ops
1547  CompositeExtractPattern, CompositeInsertPattern,
1548  DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
1549  DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
1550  VectorShufflePattern,
1551 
1552  // Shift ops
1553  ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
1554  ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
1555  ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
1556 
1557  // Return ops
1558  ReturnPattern, ReturnValuePattern>(patterns.getContext(), typeConverter);
1559 }
1560 
1562  LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1563  patterns.add<FuncConversionPattern>(patterns.getContext(), typeConverter);
1564 }
1565 
1567  LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1568  patterns.add<ModuleConversionPattern>(patterns.getContext(), typeConverter);
1569 }
1570 
1571 //===----------------------------------------------------------------------===//
1572 // Pre-conversion hooks
1573 //===----------------------------------------------------------------------===//
1574 
1575 /// Hook for descriptor set and binding number encoding.
1576 static constexpr StringRef kBinding = "binding";
1577 static constexpr StringRef kDescriptorSet = "descriptor_set";
1578 void mlir::encodeBindAttribute(ModuleOp module) {
1579  auto spvModules = module.getOps<spirv::ModuleOp>();
1580  for (auto spvModule : spvModules) {
1581  spvModule.walk([&](spirv::GlobalVariableOp op) {
1582  IntegerAttr descriptorSet =
1583  op->getAttrOfType<IntegerAttr>(kDescriptorSet);
1584  IntegerAttr binding = op->getAttrOfType<IntegerAttr>(kBinding);
1585  // For every global variable in the module, get the ones with descriptor
1586  // set and binding numbers.
1587  if (descriptorSet && binding) {
1588  // Encode these numbers into the variable's symbolic name. If the
1589  // SPIR-V module has a name, add it at the beginning.
1590  auto moduleAndName =
1591  spvModule.getName().has_value()
1592  ? spvModule.getName()->str() + "_" + op.getSymName().str()
1593  : op.getSymName().str();
1594  std::string name =
1595  llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName,
1596  std::to_string(descriptorSet.getInt()),
1597  std::to_string(binding.getInt()));
1598  auto nameAttr = StringAttr::get(op->getContext(), name);
1599 
1600  // Replace all symbol uses and set the new symbol name. Finally, remove
1601  // descriptor set and binding attributes.
1602  if (failed(SymbolTable::replaceAllSymbolUses(op, nameAttr, spvModule)))
1603  op.emitError("unable to replace all symbol uses for ") << name;
1604  SymbolTable::setSymbolName(op, nameAttr);
1605  op->removeAttr(kDescriptorSet);
1606  op->removeAttr(kBinding);
1607  }
1608  });
1609  }
1610 }
static constexpr const bool value
@ None
static llvm::Value * getSizeInBytes(llvm::IRBuilderBase &builder, llvm::Value *basePtr)
Computes the size of type in bytes.
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:696
static Value optionallyTruncateOrExtend(Location loc, Value value, Type llvmType, PatternRewriter &rewriter)
Utility function for bitfield ops:
static Value createFPConstant(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter, double value)
Creates llvm.mlir.constant with a floating-point scalar or vector value.
Definition: SPIRVToLLVM.cpp:99
static constexpr StringRef kDescriptorSet
static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, unsigned value)
Creates LLVM dialect constant with the given value.
static Optional< Type > convertStructTypeWithOffset(spirv::StructType type, LLVMTypeConverter &converter)
Converts SPIR-V struct with a regular (according to VulkanLayoutUtils) offset to LLVM struct.
static Value processCountOrOffset(Location loc, Value value, Type srcType, Type dstType, LLVMTypeConverter &converter, ConversionPatternRewriter &rewriter)
Utility function for bitfield ops: BitFieldInsert, BitFieldSExtract and BitFieldUExtract.
static unsigned getBitWidth(Type type)
Returns the bit width of integer, float or vector of float or integer values.
Definition: SPIRVToLLVM.cpp:55
static Type convertPointerType(spirv::PointerType type, TypeConverter &converter)
Converts SPIR-V pointer type to LLVM pointer.
static Optional< Type > convertArrayType(spirv::ArrayType type, TypeConverter &converter)
Converts SPIR-V array type to LLVM array.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static bool isSignedIntegerOrVector(Type type)
Returns true if the given type is a signed integer or vector type.
Definition: SPIRVToLLVM.cpp:37
static bool isUnsignedIntegerOrVector(Type type)
Returns true if the given type is an unsigned integer or vector type.
Definition: SPIRVToLLVM.cpp:46
static Type convertStructTypePacked(spirv::StructType type, LLVMTypeConverter &converter)
Converts SPIR-V struct with no offset to packed LLVM struct.
static constexpr StringRef kBinding
Hook for descriptor set and binding number encoding.
static Optional< Type > convertRuntimeArrayType(spirv::RuntimeArrayType type, TypeConverter &converter)
Converts SPIR-V runtime array to LLVM array.
static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder)
Creates an IntegerAttr with all bits set for the given type.
Definition: SPIRVToLLVM.cpp:76
static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter, unsigned alignment, bool isVolatile, bool isNonTemporal)
Utility for spirv.Load and spirv.Store conversion.
static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter)
Creates llvm.mlir.constant with all bits set for the given type.
Definition: SPIRVToLLVM.cpp:86
static unsigned getLLVMTypeBitWidth(Type type)
Returns the bit width of LLVMType integer or vector.
Definition: SPIRVToLLVM.cpp:68
static Value optionallyBroadcast(Location loc, Value value, Type srcType, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value. If srcType is a scalar, the value remains unchanged.
#define DISPATCH(functionControl, llvmAttr)
static Optional< Type > convertStructType(spirv::StructType type, LLVMTypeConverter &converter)
Converts SPIR-V struct to LLVM struct.
static int64_t getNumElements(ShapedType type)
Definition: TensorOps.cpp:1252
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:30
OpListType::iterator iterator
Definition: Block.h:129
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:232
OpListType & getOperations()
Definition: Block.h:126
BlockArgListType getArguments()
Definition: Block.h:76
iterator begin()
Definition: Block.h:132
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:49
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:190
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:212
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:235
IntegerType getI32Type()
Definition: Builders.cpp:68
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:72
MLIRContext * getContext() const
Definition: Builders.h:54
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before) override
PatternRewriter hook for moving blocks out of a region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
FailureOr< Block * > convertRegionTypes(Region *region, TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
Block * splitBlock(Block *block, Block::iterator before) override
PatternRewriter hook for splitting a block into two parts.
void eraseBlock(Block *block) override
PatternRewriter hook for erasing all operations in a block.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense integer vector or tensor object.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:30
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results)
Convert the given type.
static LLVMStructType getLiteral(MLIRContext *context, ArrayRef< Type > types, bool isPacked=false)
Gets or creates a literal struct with the given body in the provided context.
Definition: LLVMTypes.cpp:449
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:64
U cast() const
Definition: Location.h:90
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:300
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:397
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:383
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:350
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:388
Block * getBlock() const
Returns the current block of the builder.
Definition: Builders.h:400
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=llvm::None, ArrayRef< Location > locs=llvm::None)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:395
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:422
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:364
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:394
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:31
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:225
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:295
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:605
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, StringAttr newSymbol, Operation *from)
Attempt to replace all uses of the given symbol 'oldSymbol' with the provided symbol 'newSymbol' that...
static void setSymbolName(Operation *symbol, StringAttr name)
Sets the name of the given symbol operation.
This class provides all of the information necessary to convert a type signature.
Type conversion class.
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results)
Convert the given type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
U cast() const
Definition: Types.h:280
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition: Types.cpp:51
U dyn_cast() const
Definition: Types.h:270
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition: Types.cpp:63
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:89
bool isa() const
Definition: Types.h:260
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type; asserts on other types.
Definition: Types.cpp:93
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:349
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
bool isa() const
Definition: Value.h:90
Type getType() const
Return the type of this value.
Definition: Value.h:114
U cast() const
Definition: Value.h:105
static spirv::StructType decorateType(spirv::StructType structType)
Returns a new StructType with layout decoration.
Definition: LayoutUtils.cpp:21
Type getElementType() const
Definition: SPIRVTypes.cpp:64
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:66
unsigned getNumElements() const
Definition: SPIRVTypes.cpp:62
Type getPointeeType() const
Definition: SPIRVTypes.cpp:480
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:543
SPIR-V struct type.
Definition: SPIRVTypes.h:281
void getMemberDecorations(SmallVectorImpl< StructType::MemberDecorationInfo > &memberDecorations) const
ElementTypeRange getElementTypes() const
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:857
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:839
SmallVector< IntT > convertArrayToIndices(ArrayRef< Attribute > attrs)
Convert an array of integer attributes to a vector of integers that can be used as indices in LLVM op...
Definition: LLVMDialect.h:221
Type getVectorElementType(Type type)
Returns the element type of any vector type compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:873
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
void populateSPIRVToLLVMFunctionConversionPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given list with patterns for function conversion from SPIR-V to LLVM.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void populateSPIRVToLLVMModuleConversionPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given patterns for module conversion from SPIR-V to LLVM.
void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter)
Populates type conversions with additional SPIR-V types.
void encodeBindAttribute(ModuleOp module)
Encodes global variable's descriptor set and binding into its name if they both exist.
void populateSPIRVToLLVMConversionPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Populates the given list with patterns that convert from SPIR-V to LLVM.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26